Coverage for databooks/git_utils.py: 92%

78 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-09 13:11 +0000

1"""Git helper functions.""" 

2from dataclasses import dataclass 

3from enum import Enum 

4from pathlib import Path 

5from typing import Dict, List, Optional, Sequence, Union, cast, overload 

6 

7from git import Git 

8from git.diff import DiffIndex 

9from git.objects.blob import Blob 

10from git.objects.commit import Commit 

11from git.objects.tree import Tree 

12from git.repo import Repo 

13 

14from databooks.common import find_common_parent, find_obj 

15from databooks.logging import get_logger, set_verbose 

16 

logger = get_logger(name=__file__)

# The functional `Enum` API with dynamically built members is not understood by mypy
# (hence the ignore below), see https://github.com/python/mypy/issues/5317
# Members are every change type known to `DiffIndex` plus an extra "U" member.
ChangeType = Enum(  # type: ignore[misc]
    "ChangeType", (*DiffIndex.change_type, "U")
)

21 

22 

@dataclass
class UnmergedBlob:
    """Unmerged git blobs belonging to a single file."""

    # Path of the file the blobs belong to
    filename: Path
    # Mapping from an integer key to the blob stored under it
    # (presumably the git index stage number - confirm against the caller)
    stage: Dict[int, Blob]

29 

30 

@dataclass
class ConflictFile:
    """Path of a conflicted notebook together with both of its versions."""

    # Location of the conflicted file
    filename: Path
    # Log entries identifying the first and the last version
    first_log: str
    last_log: str
    # File contents of the first and the last version
    first_contents: str
    last_contents: str

40 

41 

@dataclass
class Contents:
    """One version of a file: its path and its contents."""

    # Path of the file version (may be `None`)
    path: Optional[Path]
    # Text of the file version (may be `None`)
    contents: Optional[str]

48 

49 

@dataclass
class DiffContents:
    """Container for both sides of a file diff and the kind of change between them."""

    # Version of the file on the `a` side of the diff
    a: Contents
    # Version of the file on the `b` side of the diff
    b: Contents
    # How the file changed between `a` and `b`
    change_type: ChangeType

57 

58 

@overload
def blob2str(blob: None) -> None:
    ...


@overload
def blob2str(blob: Blob) -> str:
    ...


def blob2str(blob: Optional[Blob]) -> Optional[str]:
    """
    Get the blob contents as text if the blob exists (otherwise return `None`).

    :param blob: git blob to read, or `None` when there is no object
    :return: blob contents as a string, or `None` if no blob was given
    :raises UnicodeDecodeError: if the blob bytes are not valid UTF-8
    """
    if blob is None:
        return None

    raw = blob.data_stream.read()
    # The blob stream yields bytes, while the signature promises `str`: decode so
    # callers always receive text. UTF-8 is assumed (the encoding of notebook JSON).
    return raw.decode("utf-8") if isinstance(raw, bytes) else raw

72 

73 

def blob2commit(blob: Blob, repo: Repo) -> str:
    """
    Get the one-line log entry found for a blob.

    Runs `git log --find-object` over all refs, limited to a single entry. When the
    log yields nothing, the stash list is searched with the same filter instead.
    :param blob: git blob to look for
    :param repo: repo whose working directory the git commands run in
    :return: output of the git command that was used
    """
    git_cmd = Git(working_dir=repo.working_dir)
    log_entry = git_cmd.log(find_object=blob, max_count=1, all=True, oneline=True)
    if log_entry:
        return log_entry
    return git_cmd.stash(
        "list", "--oneline", "--max-count", "1", "--find-object", blob
    )

83 

84 

def diff2contents(
    blob: Blob,
    ref: Optional[Union[Tree, Commit, str]],
    path: Path,
    not_exists: bool = False,
) -> Optional[str]:
    """
    Resolve the contents of one side of a diff.

    Three cases are handled, in this order: the object is absent on this side of the
    diff (added or deleted objects only exist on one side), the diff is against the
    current working tree (`ref=None`, the file is read from disk), or the contents
    are taken from the git blob.
    :param blob: git diff blob
    :param ref: git reference (`None` means the working tree)
    :param path: path to object, only read when `ref` is `None`
    :param not_exists: `True` when the object does NOT exist on this side of the diff
    :return: contents as a string, or `None` if the object does not exist
    """
    if not_exists:
        return None
    if ref is None:
        # No reference to read from: use the file as it currently is on disk
        return path.read_text()
    return blob2str(blob)

110 

111 

def get_repo(path: Path) -> Optional[Repo]:
    """
    Find the git repo for a path.

    Looks for a `.git` directory via `find_obj`, searching between the filesystem
    anchor of `path` and `path` itself.
    :param path: location to look for a repo from
    :return: the repo, or `None` when no `.git` directory was found
    """
    git_dir = find_obj(
        obj_name=".git", start=Path(path.anchor), finish=path, is_dir=True
    )
    if git_dir is None:
        logger.debug(f"No repo found at {path}.")
        return None

    found = Repo(path=git_dir)
    logger.debug(f"Repo found at: {found.working_dir}.")
    return found

124 

125 

def get_conflict_blobs(repo: Repo) -> List[ConflictFile]:
    """
    Get the source files for conflicts.

    :param repo: repo whose index is inspected for unmerged blobs
    :return: one `ConflictFile` per unmerged file, built from blobs `2` and `3`
    :raises RuntimeError: if `repo.working_dir` is neither a `Path` nor a `str`
    """
    unmerged = repo.index.unmerged_blobs()

    if not isinstance(repo.working_dir, (Path, str)):
        raise RuntimeError(
            "Expected `repo` to be `pathlib.Path` or `str`, got"
            f" {type(repo.working_dir)}."
        )

    conflicts: List[ConflictFile] = []
    for name, staged in unmerged.items():
        stages = dict(staged)
        if 0 in stages:
            # Key `0` present: not a blob that could not be merged, skip it
            continue
        # NOTE(review): keys `2` and `3` are assumed to be present - a file with
        # only one of them raises `KeyError` here.
        conflicts.append(
            ConflictFile(
                filename=repo.working_dir / Path(name),
                first_log=blob2commit(blob=stages[2], repo=repo),
                last_log=blob2commit(blob=stages[3], repo=repo),
                first_contents=blob2str(stages[2]),
                last_contents=blob2str(stages[3]),
            )
        )
    return conflicts

150 

151 

def get_nb_diffs(
    ref_base: Optional[str] = None,
    ref_remote: Optional[str] = None,
    paths: Sequence[Path] = (),
    *,
    repo: Optional[Repo] = None,
    verbose: bool = False,
) -> List[DiffContents]:
    """
    Get the notebook(s) git diff(s).

    By default (`ref_base=None` and `ref_remote=None`) the index is diffed against the
    current working directory. When no `paths` are given, every `*.ipynb` file under
    the repo root is considered.
    :param ref_base: base git reference (`None` uses the repo index)
    :param ref_remote: git reference to compare against (`None` uses the working tree)
    :param paths: paths to restrict the diff to
    :param repo: repo to use; when `None` it is looked up from the common parent of
     `paths` (or the current working directory if there are no paths)
    :param verbose: whether to enable verbose logging
    :return: one `DiffContents` per diff entry
    :raises ValueError: if no repo (or no repo working directory) could be found
    """
    if verbose:
        set_verbose(logger)

    common_path = find_common_parent(paths or [Path.cwd()])
    repo = get_repo(path=common_path) if repo is None else repo
    if repo is None or repo.working_dir is None:
        raise ValueError("No repo found - cannot compute diffs.")

    base_ref = repo.index if ref_base is None else repo.tree(ref_base)
    remote_ref = None if ref_remote is None else repo.tree(ref_remote)

    logger.debug(
        f"Looking for diffs on path(s) {[p.resolve() for p in paths]}.\n"
        f"Comparing `{base_ref}` and `{remote_ref}`."
    )
    repo_root_dir = Path(repo.working_dir)
    diff_paths = list(paths) or list(repo_root_dir.rglob("*.ipynb"))

    nb_diffs: List[DiffContents] = []
    for d in base_ref.diff(other=remote_ref, paths=diff_paths):
        # `d.change_type` is a plain string: convert it to the enum before comparing.
        # Comparing the raw string to an enum member by identity is always `False`,
        # which made deleted files get read from disk.
        change = ChangeType[d.change_type]
        nb_diffs.append(
            DiffContents(
                a=Contents(
                    path=Path(d.a_path),
                    contents=diff2contents(
                        blob=cast(Blob, d.a_blob),
                        ref=base_ref,
                        path=repo_root_dir / d.a_path,
                        # Added files do not exist on the `a` side
                        not_exists=change is ChangeType.A,  # type: ignore
                    ),
                ),
                b=Contents(
                    path=Path(d.b_path),
                    contents=diff2contents(
                        blob=cast(Blob, d.b_blob),
                        ref=remote_ref,
                        path=repo_root_dir / d.b_path,
                        # Deleted files do not exist on the `b` side
                        not_exists=change is ChangeType.D,  # type: ignore
                    ),
                ),
                change_type=change,
            )
        )
    return nb_diffs