Coverage for databooks/git_utils.py: 92%
77 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-10-03 12:27 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-10-03 12:27 +0000
1"""Git helper functions."""
2from dataclasses import dataclass
3from enum import Enum
4from pathlib import Path
5from typing import Dict, List, Optional, Sequence, Union, cast, overload
7from git import Git
8from git.diff import DiffIndex
9from git.objects.blob import Blob
10from git.objects.commit import Commit
11from git.objects.tree import Tree
12from git.repo import Repo
14from databooks.common import find_common_parent, find_obj
15from databooks.logging import get_logger, set_verbose
logger = get_logger(name=__file__)

# Enum of every change type GitPython's `DiffIndex` knows about, plus an extra "U"
# (presumably "unmerged", as in `git diff --name-status`) -- TODO confirm.
# The functional `Enum` API is not understood by mypy, hence the ignore:
# https://github.com/python/mypy/issues/5317
ChangeType = Enum("ChangeType", (*DiffIndex.change_type, "U"))  # type: ignore[misc]
@dataclass
class UnmergedBlob:
    """
    Container for git unmerged blobs.

    Pairs a path git left unmerged with the blobs stored for it in the index, keyed
    by index stage number.
    """

    # path of the unmerged file, relative to the repository working directory
    filename: Path
    # index stage number -> blob recorded at that stage
    stage: Dict[int, Blob]
@dataclass
class ConflictFile:
    """
    Container for path and different versions of conflicted notebooks.

    The `first_*` fields describe one side of the conflict and the `last_*` fields the
    other; each side carries a one-line log entry and the raw file contents.
    """

    # location of the conflicted file
    filename: Path
    # one-line log entry describing where each version came from
    first_log: str
    last_log: str
    # file contents of each version
    first_contents: str
    last_contents: str
@dataclass
class Contents:
    """
    Container for path of file versions.

    Describes a single side of a diff: where the file lives and its text. Either field
    may be `None`, e.g. `contents` when the file does not exist on that side.
    """

    # path of the file on this side of the diff
    path: Optional[Path]
    # file contents on this side of the diff, if any
    contents: Optional[str]
@dataclass
class DiffContents:
    """Container for both sides of a single file's git diff and its change type."""

    # "a" (base) side of the diff
    a: Contents
    # "b" (remote) side of the diff
    b: Contents
    # kind of change git reported for the file (added, deleted, modified, ...)
    change_type: ChangeType
@overload
def blob2str(blob: None) -> None:
    ...


@overload
def blob2str(blob: Blob) -> str:
    ...


def blob2str(blob: Optional[Blob]) -> Optional[str]:
    """
    Get the blob contents as text if they exist (otherwise return `None`).

    GitPython's `data_stream.read()` yields raw `bytes`; they are decoded as UTF-8 (the
    encoding Jupyter notebooks are stored in) so the result really is a `str`, as the
    signature promises, and matches text read from the working tree.

    :param blob: git blob, or `None` when the object does not exist
    :return: decoded blob contents, or `None` if `blob` is `None`
    :raises UnicodeDecodeError: if the blob contents are not valid UTF-8
    """
    if blob is None:
        return None
    return blob.data_stream.read().decode("utf-8")
def blob2commit(blob: Blob, repo: Repo) -> str:
    """
    Get the short commit message from blob hash.

    Asks `git log --all --find-object=<blob>` for one `--oneline` entry. If that finds
    nothing, the stash list is searched the same way instead.

    :param blob: git blob to look up
    :param repo: repository the blob belongs to
    :return: one-line log entry (presumably empty if neither search matches)
    """
    git_cmd = Git(working_dir=repo.working_dir)
    log_line = git_cmd.log(find_object=blob, max_count=1, all=True, oneline=True)
    if log_line:
        return log_line
    # Nothing in the log history -- fall back to looking through stashes.
    return git_cmd.stash(
        "list", "--oneline", "--max-count", "1", "--find-object", blob
    )
def diff2contents(
    blob: Blob,
    ref: Optional[Union[Tree, Commit, str]],
    path: Path,
    not_exists: bool = False,
) -> Optional[str]:
    """
    Get the blob contents from the diff.

    Depends on whether we are diffing against current working tree and if object exists
    at diff time (added or deleted objects only exist at one side). If comparing
    against working tree (`ref=None`) we return the current file contents.

    :param blob: git diff blob
    :param ref: git reference (`None` for the working tree)
    :param path: path to object on disk
    :param not_exists: whether the object is missing at 'diff time' (added or removed
        objects do not exist on one of the sides)
    :return: blob contents as a string (if exists)
    """
    if not_exists:
        return None
    if ref is None:
        # Notebooks are UTF-8 JSON; do not depend on the platform's locale encoding
        # (e.g. cp1252 on Windows), which would garble or reject non-ASCII content.
        return path.read_text(encoding="utf-8")
    return blob2str(blob)
def get_repo(path: Path) -> Optional[Repo]:
    """
    Find git repo in current or parent directories.

    Looks for a `.git` directory starting at `path` and walking up towards the
    filesystem anchor of `path`.

    :param path: location whose repository should be found
    :return: the repository, or `None` if no `.git` directory is found
    """
    git_dir = find_obj(
        obj_name=".git", start=Path(path.anchor), finish=path, is_dir=True
    )
    if git_dir is None:
        logger.debug(f"No repo found at {path}.")
        return None

    found = Repo(path=git_dir)
    logger.debug(f"Repo found at: {found.working_dir}.")
    return found
def get_conflict_blobs(repo: Repo) -> List[ConflictFile]:
    """
    Get the source files for conflicts.

    For every path git left unmerged in the index, read the versions stored at index
    stages 2 and 3 together with a one-line log entry for each. Paths missing one of
    those stages (e.g. modify/delete conflicts, where one side removed the file) have
    no pair of versions to compare and are skipped with a warning.

    :param repo: repository with unmerged paths in its index
    :return: one `ConflictFile` per conflicted path with both stage 2 and 3 versions
    :raises RuntimeError: if `repo.working_dir` is neither `pathlib.Path` nor `str`
    """
    working_dir = repo.working_dir
    if not isinstance(working_dir, (Path, str)):
        raise RuntimeError(
            "Expected `repo.working_dir` to be `pathlib.Path` or `str`, got"
            f" {type(working_dir)}."
        )
    root_dir = Path(working_dir)

    conflicts: List[ConflictFile] = []
    for filename, stage_blobs in repo.index.unmerged_blobs().items():
        unmerged = UnmergedBlob(filename=Path(filename), stage=dict(stage_blobs))
        if 0 in unmerged.stage:
            # only get blobs that could not be merged (stage 0 is a resolved entry)
            continue
        if 2 not in unmerged.stage or 3 not in unmerged.stage:
            # Without both sides there is nothing to reconcile; previously this
            # surfaced as an unexplained `KeyError`.
            logger.warning(
                f"Skipping {unmerged.filename}: expected index stages 2 and 3, got"
                f" {sorted(unmerged.stage)}."
            )
            continue
        first_blob, last_blob = unmerged.stage[2], unmerged.stage[3]
        conflicts.append(
            ConflictFile(
                filename=root_dir / unmerged.filename,
                first_log=blob2commit(blob=first_blob, repo=repo),
                last_log=blob2commit(blob=last_blob, repo=repo),
                first_contents=blob2str(first_blob),
                last_contents=blob2str(last_blob),
            )
        )
    return conflicts
def get_nb_diffs(
    ref_base: Optional[str] = None,
    ref_remote: Optional[str] = None,
    paths: Sequence[Path] = (),
    *,
    repo: Optional[Repo] = None,
    verbose: bool = False,
) -> List[DiffContents]:
    """
    Get the notebook(s) git diff(s).

    By default, diffs are compared with the current working directory. That is, staged
    files will still show up in the diffs. When no `paths` are given, only notebook
    (`*.ipynb`) files under the repository root are diffed.

    :param ref_base: git reference for the "a" side; `None` uses the git index
    :param ref_remote: git reference for the "b" side; `None` uses the working tree
    :param paths: paths to restrict the diff to
    :param repo: repository to diff; when `None` it is looked up from the common
        parent of `paths` (or the current directory)
    :param verbose: whether to enable debug logging
    :return: contents of both sides of every changed file, with its change type
    :raises ValueError: if no git repository can be found
    """
    if verbose:
        set_verbose(logger)

    common_path = find_common_parent(paths or [Path.cwd()])
    repo = get_repo(path=common_path) if repo is None else repo
    if repo is None or repo.working_dir is None:
        raise ValueError("No repo found - cannot compute diffs.")

    # Use new names instead of rebinding the `str` parameters to git objects.
    base = repo.index if ref_base is None else repo.tree(ref_base)
    remote = None if ref_remote is None else repo.tree(ref_remote)

    logger.debug(
        f"Looking for diffs on path(s) {[p.resolve() for p in paths]}.\n"
        f"Comparing `{base}` and `{remote}`."
    )
    repo_root_dir = Path(repo.working_dir)
    diff_paths = list(paths) or list(repo_root_dir.rglob("*.ipynb"))

    diffs: List[DiffContents] = []
    for d in base.diff(other=remote, paths=diff_paths):
        # `d.change_type` is a plain string (e.g. "A"); convert it to the enum before
        # comparing -- `"A" is ChangeType.A` is always False, which made deleted files
        # be read from disk (and crash) when diffing against the working tree.
        change_type = ChangeType[d.change_type]
        diffs.append(
            DiffContents(
                a=Contents(
                    path=Path(d.a_path),
                    contents=diff2contents(
                        blob=cast(Blob, d.a_blob),
                        ref=base,
                        path=repo_root_dir / d.a_path,
                        # added files have no "a" side
                        not_exists=change_type is ChangeType.A,  # type: ignore
                    ),
                ),
                b=Contents(
                    path=Path(d.b_path),
                    contents=diff2contents(
                        blob=cast(Blob, d.b_blob),
                        ref=remote,
                        path=repo_root_dir / d.b_path,
                        # deleted files have no "b" side
                        not_exists=change_type is ChangeType.D,  # type: ignore
                    ),
                ),
                change_type=change_type,
            )
        )
    return diffs