Coverage for databooks/git_utils.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

30 statements  

1"""Git helper functions.""" 

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, cast

from git import Blob, Git, Repo  # type: ignore

from databooks.common import get_logger

9 

# Module-level logger, named after this file's path
logger = get_logger(name=__file__)

11 

12 

@dataclass
class UnmergedBlob:
    """Pair a file path with its unmerged git blobs, keyed by stage number."""

    # Path of the file the blobs belong to
    filename: Path
    # Mapping of stage number -> blob recorded for that stage
    stage: Dict[int, Blob]

19 

20 

@dataclass
class ConflictFile:
    """Hold a conflicted notebook's path plus both versions and their logs."""

    # Location of the conflicted file
    filename: Path
    # Log entry describing the first version
    first_log: str
    # Log entry describing the last version
    last_log: str
    # Text contents of the first version
    first_contents: str
    # Text contents of the last version
    last_contents: str

30 

31 

def get_repo(path: Optional[Path] = None) -> Repo:
    """
    Find git repo in current or parent directories.

    :param path: Directory to start the search from. When omitted, the working
     directory at call time is used.
    :return: The repository found by searching `path` and its parents. Errors
     raised by `Repo` propagate unchanged.
    """
    # Resolve the default here rather than in the signature: a `Path.cwd()`
    # default is evaluated once at definition time and would go stale if the
    # process changes directory afterwards.
    start = Path.cwd() if path is None else path
    repo = Repo(path=start, search_parent_directories=True)
    logger.info(f"Repo found at: {repo.working_dir}")
    return repo

37 

38 

def blob2commit(blob: Blob, repo: Repo) -> str:
    """
    Get the one-line log entry for a commit that references a blob.

    :param blob: Blob to look up
    :param repo: Repository whose working directory the git commands run in
    :return: Output of `git log --oneline` limited to one entry for the blob,
     or, when that output is empty, the matching `git stash list` output
    """
    git_cmd = Git(working_dir=repo.working_dir)
    log_line = git_cmd.log(find_object=blob, max_count=1, all=True, oneline=True)
    if log_line:
        return log_line
    # No log entry references the blob, so fall back to searching the stash
    return git_cmd.stash(
        "list", "--oneline", "--max-count", "1", "--find-object", blob
    )

48 

49 

def get_conflict_blobs(repo: Repo) -> List[ConflictFile]:
    """
    Get the source files for conflicts.

    :param repo: Repository whose index is inspected for unmerged blobs
    :return: One `ConflictFile` per unmerged file, with the log entry and
     contents of the stage 2 blob as "first" and the stage 3 blob as "last"
    """
    # Type checking: `repo.working_dir` is not None if repo is passed. Bind a
    # local `Path` instead of assigning back onto the caller's `Repo`.
    root = Path(cast(str, repo.working_dir))

    # Build each stage mapping once and reuse it for both the filter and the
    # container (the original constructed `dict(v)` twice per entry).
    staged = (
        (Path(name), dict(entries))
        for name, entries in repo.index.unmerged_blobs().items()
    )
    blobs = (
        UnmergedBlob(filename=name, stage=stages)
        for name, stages in staged
        if 0 not in stages  # only get blobs that could not be merged
    )

    # NOTE(review): stages 2 and 3 are indexed unconditionally; an entry missing
    # either one would raise `KeyError` - confirm whether that can occur.
    return [
        ConflictFile(
            filename=root / blob.filename,
            first_log=blob2commit(blob=blob.stage[2], repo=repo),
            last_log=blob2commit(blob=blob.stage[3], repo=repo),
            first_contents=repo.git.show(blob.stage[2]),
            last_contents=repo.git.show(blob.stage[3]),
        )
        for blob in blobs
    ]

69 ]