Coverage for databooks/git_utils.py: 92%
78 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-09 13:11 +0000
1"""Git helper functions."""
2from dataclasses import dataclass
3from enum import Enum
4from pathlib import Path
5from typing import Dict, List, Optional, Sequence, Union, cast, overload
7from git import Git
8from git.diff import DiffIndex
9from git.objects.blob import Blob
10from git.objects.commit import Commit
11from git.objects.tree import Tree
12from git.repo import Repo
14from databooks.common import find_common_parent, find_obj
15from databooks.logging import get_logger, set_verbose
logger = get_logger(name=__file__)

# Enum of the change-type codes listed in GitPython's `DiffIndex.change_type`, extended
# with "U" so that `ChangeType["U"]` also resolves (presumably git's "unmerged" status -
# confirm). mypy cannot type this functional Enum form, hence the ignore; see
# https://github.com/python/mypy/issues/5317
ChangeType = Enum("ChangeType", list(DiffIndex.change_type) + ["U"])  # type: ignore[misc]
@dataclass
class UnmergedBlob:
    """Hold one unmerged path from the git index along with its staged blobs."""

    # Path of the file, as reported by the index.
    filename: Path
    # Blobs of this path, keyed by index stage number.
    stage: Dict[int, Blob]
@dataclass
class ConflictFile:
    """Conflicted notebook: its path plus two competing versions and their log entries."""

    filename: Path  # location of the conflicted file
    first_log: str  # log entry describing the first version
    last_log: str  # log entry describing the last version
    first_contents: str  # contents of the first version
    last_contents: str  # contents of the last version
@dataclass
class Contents:
    """One side of a file comparison: where the file lives and what it contains."""

    path: Optional[Path]  # file path, or `None` when there is none
    contents: Optional[str]  # file contents, or `None` when unavailable
@dataclass
class DiffContents:
    """Container for both sides of a file diff and the type of change between them."""

    a: Contents  # "a" (base) side of the diff
    b: Contents  # "b" (remote / compared) side of the diff
    change_type: ChangeType  # kind of change git reported for the file
@overload
def blob2str(blob: None) -> None:
    ...


@overload
def blob2str(blob: Blob) -> str:
    ...


def blob2str(blob: Optional[Blob]) -> Optional[str]:
    """
    Get the blob contents as text if the blob exists (otherwise return `None`).

    The blob's data stream yields raw bytes, so they are decoded here to honour the
    declared `str` return type (notebook files are UTF-8 encoded JSON).
    :param blob: git blob, or `None` when there is no object to read
    :return: decoded blob contents, or `None` if `blob` is `None`
    """
    if blob is None:
        return None
    return blob.data_stream.read().decode("utf-8")
def blob2commit(blob: Blob, repo: Repo) -> str:
    """
    Get the short (one-line) log entry of a commit containing `blob`.

    Commits are searched first (`git log --all`); if that yields nothing, the stash
    list is searched instead and its output is returned.
    :param blob: git blob to look up
    :param repo: repository whose working directory the `git` commands run in
    :return: one-line log entry from `git log`, otherwise the `git stash list` output
    """
    git_cmd = Git(working_dir=repo.working_dir)
    log_entry = git_cmd.log(find_object=blob, max_count=1, all=True, oneline=True)
    if len(log_entry) > 0:
        return log_entry
    return git_cmd.stash("list", "--oneline", "--max-count", "1", "--find-object", blob)
def diff2contents(
    blob: Blob,
    ref: Optional[Union[Tree, Commit, str]],
    path: Path,
    not_exists: bool = False,
) -> Optional[str]:
    """
    Get the blob contents from the diff.

    Depends on whether we are diffing against current working tree and if object exists
    at diff time (added or deleted objects only exist at one side). If comparing
    against working tree (`ref=None`) we return the current file contents.
    :param blob: git diff blob
    :param ref: git reference (`None` means the working tree)
    :param path: path to object on disk, used when `ref` is `None`
    :param not_exists: whether the object is missing at 'diff time' (added or removed
        objects do not exist on one side)
    :return: blob contents as a string (if exists)
    """
    if not_exists:
        return None
    if ref is None:
        # Read the file on disk as UTF-8 explicitly: the default `read_text` encoding
        # is locale-dependent and can fail or garble UTF-8 notebooks on some platforms.
        return path.read_text(encoding="utf-8")
    return blob2str(blob)
def get_repo(path: Path) -> Optional[Repo]:
    """
    Locate the git repository containing `path`.

    Looks for a `.git` directory using `find_obj`, searching between `path.anchor`
    and `path` itself.
    :param path: path from which the repository is searched for
    :return: the repository if a `.git` directory was found, otherwise `None`
    """
    git_dir = find_obj(
        obj_name=".git", start=Path(path.anchor), finish=path, is_dir=True
    )
    if git_dir is None:
        logger.debug(f"No repo found at {path}.")
        return None
    found_repo = Repo(path=git_dir)
    logger.debug(f"Repo found at: {found_repo.working_dir}.")
    return found_repo
def get_conflict_blobs(repo: Repo) -> List[ConflictFile]:
    """
    Get the source files for conflicts.

    Only index entries that have no stage-0 blob (i.e. could not be merged) are
    returned.
    :param repo: repository whose index is inspected for unmerged blobs
    :return: one `ConflictFile` per unmerged path, built from stages 2 and 3
    :raises RuntimeError: if `repo.working_dir` is neither a `pathlib.Path` nor `str`
    """
    if not isinstance(repo.working_dir, (Path, str)):
        raise RuntimeError(
            "Expected `repo.working_dir` to be `pathlib.Path` or `str`, got"
            f" {type(repo.working_dir)}."
        )
    repo_root_dir = Path(repo.working_dir)

    conflict_files = []
    for filename, stage_blobs in repo.index.unmerged_blobs().items():
        stages = dict(stage_blobs)
        if 0 in stages:  # only keep blobs that could not be merged
            continue
        blob = UnmergedBlob(filename=Path(filename), stage=stages)
        # NOTE(review): assumes both stage 2 and stage 3 are present (git's "ours" and
        # "theirs" stages); an entry missing either one raises KeyError here.
        conflict_files.append(
            ConflictFile(
                filename=repo_root_dir / blob.filename,
                first_log=blob2commit(blob=blob.stage[2], repo=repo),
                last_log=blob2commit(blob=blob.stage[3], repo=repo),
                first_contents=blob2str(blob.stage[2]),
                last_contents=blob2str(blob.stage[3]),
            )
        )
    return conflict_files
def get_nb_diffs(
    ref_base: Optional[str] = None,
    ref_remote: Optional[str] = None,
    paths: Sequence[Path] = (),
    *,
    repo: Optional[Repo] = None,
    verbose: bool = False,
) -> List[DiffContents]:
    """
    Get the notebook(s) git diff(s).

    By default (`ref_base=None`, `ref_remote=None`) the git index is compared with the
    current working directory. When no `paths` are given, only notebook (`*.ipynb`)
    files under the repository root are diffed.
    :param ref_base: git reference for the "a" side; the index when `None`
    :param ref_remote: git reference for the "b" side; the working tree when `None`
    :param paths: paths to diff; defaults to every `*.ipynb` under the repository root
    :param repo: repository to use; when `None` it is searched for from the common
        parent of `paths` (or the current directory)
    :param verbose: whether to enable verbose logging
    :return: one `DiffContents` per changed file
    :raises ValueError: if no repository could be found
    """
    if verbose:
        set_verbose(logger)

    common_path = find_common_parent(paths or [Path.cwd()])
    repo = get_repo(path=common_path) if repo is None else repo
    if repo is None or repo.working_dir is None:
        raise ValueError("No repo found - cannot compute diffs.")

    base = repo.index if ref_base is None else repo.tree(ref_base)
    remote = None if ref_remote is None else repo.tree(ref_remote)

    logger.debug(
        f"Looking for diffs on path(s) {[p.resolve() for p in paths]}.\n"
        f"Comparing `{base}` and `{remote}`."
    )
    repo_root_dir = Path(repo.working_dir)
    diff_paths = list(paths) or list(repo_root_dir.rglob("*.ipynb"))

    nb_diffs = []
    for d in base.diff(other=remote, paths=diff_paths):
        # `d.change_type` is a plain string (e.g. "A"), so convert it to the enum
        # before comparing - `d.change_type is ChangeType.A` is never true, which
        # previously made deleted working-tree files crash on `read_text`.
        change_type = ChangeType[d.change_type]
        nb_diffs.append(
            DiffContents(
                a=Contents(
                    path=Path(d.a_path),
                    contents=diff2contents(
                        blob=cast(Blob, d.a_blob),
                        ref=base,
                        path=repo_root_dir / d.a_path,
                        # added files have no "a" side
                        not_exists=change_type is ChangeType.A,  # type: ignore
                    ),
                ),
                b=Contents(
                    path=Path(d.b_path),
                    contents=diff2contents(
                        blob=cast(Blob, d.b_blob),
                        ref=remote,
                        path=repo_root_dir / d.b_path,
                        # deleted files have no "b" side
                        not_exists=change_type is ChangeType.D,  # type: ignore
                    ),
                ),
                change_type=change_type,
            )
        )
    return nb_diffs