Coverage for databooks/git_utils.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

33 statements  

1"""Git helper functions.""" 

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

from git import Blob, Git, Repo  # type: ignore

from databooks.common import find_obj
from databooks.logging import get_logger

10 

# Module-level logger; this file's path is passed as the logger name.
logger = get_logger(name=__file__)

12 

13 

@dataclass
class UnmergedBlob:
    """
    Container for git unmerged blobs.

    One instance describes a single path from the index's unmerged entries.

    :param filename: Path of the unmerged file, as keyed in the unmerged entries
    :param stage: Mapping of index stage number to the blob stored at that stage
    """

    filename: Path
    stage: Dict[int, Blob]

20 

21 

@dataclass
class ConflictFile:
    """
    Container for path and different versions of conflicted notebooks.

    As populated by `get_conflict_blobs`, the `first_*` fields come from the
    stage-2 blob of the index entry and the `last_*` fields from the stage-3 blob.

    :param filename: Path of the conflicted file
    :param first_log: One-line log entry found for the first version
    :param last_log: One-line log entry found for the last version
    :param first_contents: File contents of the first version
    :param last_contents: File contents of the last version
    """

    filename: Path
    first_log: str
    last_log: str
    first_contents: str
    last_contents: str

31 

32 

def get_repo(path: Optional[Path] = None) -> Repo:
    """
    Find git repo in current or parent directories.

    :param path: Directory the search is bounded by: `find_obj` is asked to look
     for a `.git` directory from the filesystem anchor of `path` to `path` itself.
     Defaults to the current working directory at the time of the call.
    :return: Repo built from the location returned by `find_obj`
    """
    # Resolve the default here instead of in the signature: a `Path.cwd()` default
    # is evaluated only once, when the module is imported, and would ignore any
    # later change of working directory.
    if path is None:
        path = Path.cwd()
    repo_dir = find_obj(
        obj_name=".git", start=Path(path.anchor), finish=path, is_dir=True
    )
    repo = Repo(path=repo_dir)
    logger.debug(f"Repo found at: {repo.working_dir}")
    return repo

41 

42 

def blob2commit(blob: Blob, repo: Repo) -> str:
    """
    Get the short commit message from blob hash.

    Runs a one-line `git log` over all refs, limited to a single entry, looking
    for `blob`. When that output is empty, the stash list is searched instead.

    :param blob: Blob to look up
    :param repo: Repo whose working directory the git commands run in
    :return: Output of the `git log` call or, if empty, of the `git stash list` call
    """
    git_cmd = Git(working_dir=repo.working_dir)
    oneline_log = git_cmd.log(find_object=blob, max_count=1, all=True, oneline=True)
    if len(oneline_log) == 0:
        # Nothing found in the log - fall back to the stash entries
        return git_cmd.stash(
            "list", "--oneline", "--max-count", "1", "--find-object", blob
        )
    return oneline_log

52 

53 

def get_conflict_blobs(repo: Repo) -> List[ConflictFile]:
    """
    Get the source files for conflicts.

    Every unmerged index entry without a stage-0 blob is turned into a
    `ConflictFile` holding the log entry and contents of its stage-2 and stage-3
    blobs.

    :param repo: Repo whose index is inspected
    :return: One `ConflictFile` per conflicted path, with the path joined to the
     repo working directory
    :raises RuntimeError: If `repo.working_dir` is neither a `pathlib.Path` nor `str`
    """
    unmerged = repo.index.unmerged_blobs()

    working_dir = repo.working_dir
    if not isinstance(working_dir, (Path, str)):
        raise RuntimeError(
            f"Expected `repo` to be `pathlib.Path` or `str`, got {type(working_dir)}."
        )

    conflict_files: List[ConflictFile] = []
    for name, stage_blobs in unmerged.items():
        stages = dict(stage_blobs)
        if 0 in stages:
            continue  # only keep blobs that could not be merged
        unmerged_blob = UnmergedBlob(filename=Path(name), stage=stages)
        # NOTE(review): an entry lacking stage 2 or stage 3 raises `KeyError` in the
        # lookups below - confirm whether such entries can reach this point.
        conflict_files.append(
            ConflictFile(
                filename=working_dir / unmerged_blob.filename,
                first_log=blob2commit(blob=unmerged_blob.stage[2], repo=repo),
                last_log=blob2commit(blob=unmerged_blob.stage[3], repo=repo),
                first_contents=repo.git.show(unmerged_blob.stage[2]),
                last_contents=repo.git.show(unmerged_blob.stage[3]),
            )
        )
    return conflict_files