Coverage for databooks/git_utils.py: 92%

77 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-10-03 12:27 +0000

1"""Git helper functions.""" 

2from dataclasses import dataclass 

3from enum import Enum 

4from pathlib import Path 

5from typing import Dict, List, Optional, Sequence, Union, cast, overload 

6 

7from git import Git 

8from git.diff import DiffIndex 

9from git.objects.blob import Blob 

10from git.objects.commit import Commit 

11from git.objects.tree import Tree 

12from git.repo import Repo 

13 

14from databooks.common import find_common_parent, find_obj 

15from databooks.logging import get_logger, set_verbose 

16 

logger = get_logger(name=__file__)

# Enum with one member per change-type letter listed in GitPython's
# `DiffIndex.change_type`, plus an extra "U" member (presumably "unmerged", as in
# `git diff --name-status` - TODO confirm). Built via the functional Enum API, which
# mypy cannot check with a non-literal member list:
# https://github.com/python/mypy/issues/5317
ChangeType = Enum("ChangeType", (*DiffIndex.change_type, "U"))  # type: ignore[misc]

21 

22 

@dataclass()
class UnmergedBlob:
    """
    Container for git unmerged blobs.

    Pairs the index path of a file git could not merge with the blobs staged for it.
    """

    # path of the unmerged file as listed in the git index
    filename: Path
    # git merge stage number -> blob staged at that stage for `filename`
    stage: Dict[int, Blob]

29 

30 

@dataclass()
class ConflictFile:
    """
    Container for path and different versions of conflicted notebooks.

    Holds the two competing versions of a conflicted file ("first" and "last"),
    each with its contents and a one-line commit log describing where it came from.
    """

    # path of the conflicted file
    filename: Path
    # one-line commit log associated with the "first" version
    first_log: str
    # one-line commit log associated with the "last" version
    last_log: str
    # text of the "first" version
    first_contents: str
    # text of the "last" version
    last_contents: str

40 

41 

@dataclass()
class Contents:
    """
    Container for path of file versions.

    One side of a diff: where the file lives and its text, either of which may be
    absent (`None`).
    """

    # path of this version of the file, if any
    path: Optional[Path]
    # text of this version of the file, or `None` when it does not exist on this side
    contents: Optional[str]

48 

49 

@dataclass
class DiffContents:
    """
    Container for the two sides of a git diff of a file and its change type.

    `a` is the "before" side and `b` the "after" side; either side's `contents` is
    `None` when the file does not exist there (e.g. added or deleted files).
    """

    # "before" side of the diff
    a: Contents
    # "after" side of the diff
    b: Contents
    # kind of change git reported for this file
    change_type: ChangeType

57 

58 

@overload
def blob2str(blob: None) -> None:
    ...


@overload
def blob2str(blob: Blob) -> str:
    ...


def blob2str(blob: Optional[Blob]) -> Optional[str]:
    """
    Get the blob contents as text if they exist (otherwise return `None`).

    The blob's data stream yields raw bytes, which are decoded as UTF-8 (the encoding
    notebook files use) so the result really is a `str`, as annotated. Undecodable
    bytes (e.g. a conflicted binary file) are replaced rather than raising, so callers
    that process every conflicted path do not crash on non-text files.
    :param blob: git blob, or `None`
    :return: the decoded blob contents, or `None` when `blob` is `None`
    """
    if blob is None:
        return None
    data = blob.data_stream.read()
    # GitPython/gitdb streams return bytes; only decode when that is what we got.
    return data.decode("utf-8", errors="replace") if isinstance(data, bytes) else data

72 

73 

def blob2commit(blob: Blob, repo: Repo) -> str:
    """
    Get the short commit message from blob hash.

    Runs `git log --all --oneline --max-count=1 --find-object=<blob>` in the repo's
    working directory. If that produces no output, the same one-entry
    `--find-object` search is run over `git stash list` instead.
    :param blob: blob to look up
    :param repo: repository whose working directory git is run in
    :return: the one-line log entry found (presumably empty if neither search
     matches - TODO confirm)
    """
    git_cmd = Git(working_dir=repo.working_dir)
    log_line = git_cmd.log(find_object=blob, max_count=1, all=True, oneline=True)
    if len(log_line) == 0:
        # Nothing in the log: fall back to searching the stash entries.
        return git_cmd.stash(
            "list", "--oneline", "--max-count", "1", "--find-object", blob
        )
    return log_line

83 

84 

def diff2contents(
    blob: Blob,
    ref: Optional[Union[Tree, Commit, str]],
    path: Path,
    not_exists: bool = False,
) -> Optional[str]:
    """
    Get the blob contents from the diff.

    Depends on whether we are diffing against current working tree and if object exists
    at diff time (added or deleted objects only exist at one side). If comparing
    against working tree (`ref=None`) we return the current file contents.
    :param blob: git diff blob
    :param ref: git reference (`None` means the working tree)
    :param path: path to object on disk (only read when `ref` is `None`)
    :param not_exists: whether object does not exist at 'diff time' (added or removed
     objects do not exist on one side)
    :return: blob contents as a string (if exists)
    """
    if not_exists:
        return None
    if ref is None:
        # Notebooks are UTF-8 JSON; do not depend on the platform's locale encoding
        # (the default for `read_text`), which mis-decodes them on e.g. Windows.
        return path.read_text(encoding="utf-8")
    return blob2str(blob)

110 

111 

def get_repo(path: Path) -> Optional[Repo]:
    """
    Find git repo in current or parent directories.

    Looks for a `.git` directory using `find_obj`, searching between the filesystem
    root of `path` and `path` itself.
    :param path: directory to start from
    :return: the repository found, or `None` when there is no `.git` directory
    """
    git_dir = find_obj(
        obj_name=".git", start=Path(path.anchor), finish=path, is_dir=True
    )
    if git_dir is None:
        logger.debug(f"No repo found at {path}.")
        return None

    found = Repo(path=git_dir)
    logger.debug(f"Repo found at: {found.working_dir}.")
    return found

124 

def get_conflict_blobs(repo: Repo) -> List[ConflictFile]:
    """
    Get the source files for conflicts.

    Only index paths that git could not merge (those without a stage-0 entry) are
    considered. Each result holds the stage-2 ("first") and stage-3 ("last") versions
    of the file, with the commit log line found for each. Paths lacking one of those
    stages (e.g. modify/delete conflicts, where one side removed the file) are skipped,
    since there is no pair of versions to compare.
    :param repo: repository whose index contains the unmerged paths
    :return: one `ConflictFile` per conflicted path that has both versions
    :raises RuntimeError: if `repo.working_dir` is neither a `pathlib.Path` nor a `str`
    """
    if not isinstance(repo.working_dir, (Path, str)):
        raise RuntimeError(
            "Expected `repo.working_dir` to be `pathlib.Path` or `str`, got"
            f" {type(repo.working_dir)}."
        )
    repo_root_dir = Path(repo.working_dir)

    conflicts: List[ConflictFile] = []
    for filename, staged in repo.index.unmerged_blobs().items():
        blob = UnmergedBlob(filename=Path(filename), stage=dict(staged))
        if 0 in blob.stage:  # only get blobs that could not be merged
            continue
        missing = [stage for stage in (2, 3) if stage not in blob.stage]
        if missing:
            # Previously this raised `KeyError`; one side of the conflict was removed.
            logger.debug(
                f"Skipping conflict in {blob.filename}: missing stage(s) {missing}."
            )
            continue
        first, last = blob.stage[2], blob.stage[3]
        conflicts.append(
            ConflictFile(
                filename=repo_root_dir / blob.filename,
                first_log=blob2commit(blob=first, repo=repo),
                last_log=blob2commit(blob=last, repo=repo),
                first_contents=blob2str(first),
                last_contents=blob2str(last),
            )
        )
    return conflicts

149 

150 

def _is_notebook_path(path: Optional[str]) -> bool:
    """Whether a (possibly missing) diff path points to a `.ipynb` file."""
    return path is not None and Path(path).suffix == ".ipynb"


def get_nb_diffs(
    ref_base: Optional[str] = None,
    ref_remote: Optional[str] = None,
    paths: Sequence[Path] = (),
    *,
    repo: Optional[Repo] = None,
    verbose: bool = False,
) -> List[DiffContents]:
    """
    Get the notebook(s) git diff(s).

    By default, diffs are compared with the current working directory. That is, staged
    files will still show up in the diffs. Only return the diffs for notebook files.
    :param ref_base: git reference for the "a" side; `None` uses the git index
    :param ref_remote: git reference for the "b" side; `None` uses the working tree
    :param paths: paths to restrict the diff to; defaults to every `*.ipynb` file found
     under the repository root
    :param repo: repository to use; when `None` it is searched for from the common
     parent of `paths` (or the current directory)
    :param verbose: whether to enable verbose logging
    :return: one `DiffContents` per changed notebook
    :raises ValueError: if no git repository can be found
    """
    if verbose:
        set_verbose(logger)

    common_path = find_common_parent(paths or [Path.cwd()])
    repo = get_repo(path=common_path) if repo is None else repo
    if repo is None or repo.working_dir is None:
        raise ValueError("No repo found - cannot compute diffs.")

    # New names: these are resolved git objects, not the `str` refs passed in.
    base = repo.index if ref_base is None else repo.tree(ref_base)
    remote = None if ref_remote is None else repo.tree(ref_remote)

    logger.debug(
        f"Looking for diffs on path(s) {[p.resolve() for p in paths]}.\n"
        f"Comparing `{base}` and `{remote}`."
    )
    repo_root_dir = Path(repo.working_dir)
    diffs = base.diff(
        other=remote, paths=list(paths) or list(repo_root_dir.rglob("*.ipynb"))
    )
    # GitPython applies no path filter for an empty `paths` list (i.e. when no notebook
    # exists on disk), and directory paths match every file inside them - so keep
    # notebook diffs only, before any contents are read.
    return [
        DiffContents(
            a=Contents(
                path=Path(d.a_path),
                contents=diff2contents(
                    blob=cast(Blob, d.a_blob),
                    ref=base,
                    path=repo_root_dir / d.a_path,
                    not_exists=d.change_type is ChangeType.A,  # type: ignore
                ),
            ),
            b=Contents(
                path=Path(d.b_path),
                contents=diff2contents(
                    blob=cast(Blob, d.b_blob),
                    ref=remote,
                    path=repo_root_dir / d.b_path,
                    not_exists=d.change_type is ChangeType.D,  # type: ignore
                ),
            ),
            change_type=ChangeType[d.change_type],
        )
        for d in diffs
        if _is_notebook_path(d.a_path) or _is_notebook_path(d.b_path)
    ]