Coverage for databooks/git_utils.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Git helper functions."""
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

from git import Blob, Git, Repo  # type: ignore

from databooks.common import find_obj
from databooks.logging import get_logger
# Module-level logger. Note it is named after this file's path (``__file__``)
# rather than the dotted module name (``__name__``).
logger = get_logger(name=__file__)
@dataclass()
class UnmergedBlob:
    """An unmerged path from the git index together with its staged blobs.

    In git's index, an unmerged path is stored once per stage: 1 is the
    common ancestor, 2 is "ours" and 3 is "theirs".
    """

    # Path of the entry, relative to the repository root.
    filename: Path
    # Maps index stage number -> blob recorded at that stage.
    stage: Dict[int, Blob]
@dataclass()
class ConflictFile:
    """Both competing versions of a conflicted file, plus where each came from.

    As populated by ``get_conflict_blobs``, the ``first_*`` fields describe the
    index stage-2 blob and the ``last_*`` fields the stage-3 blob.
    """

    # Location of the conflicted file.
    filename: Path
    # One-line log entries identifying the commit behind each version.
    first_log: str
    last_log: str
    # Full text of each version.
    first_contents: str
    last_contents: str
def get_repo(path: Optional[Path] = None) -> Repo:
    """Find the git repository for ``path`` by locating its ``.git`` directory.

    :param path: Directory to search from. Defaults to the current working
        directory, looked up when the function is called.
    :return: ``git.Repo`` opened at the located ``.git`` directory.
    """
    if path is None:
        # Resolve the cwd here rather than in the signature: a default of
        # ``Path.cwd()`` would be evaluated once at import time and go stale
        # if the process changes directory afterwards.
        path = Path.cwd()
    # NOTE(review): assumes ``find_obj`` looks for a ``.git`` directory in the
    # directories between the filesystem root (``path.anchor``) and ``path`` —
    # confirm against ``databooks.common.find_obj``.
    repo_dir = find_obj(
        obj_name=".git", start=Path(path.anchor), finish=path, is_dir=True
    )
    repo = Repo(path=repo_dir)
    logger.debug(f"Repo found at: {repo.working_dir}")
    return repo
def blob2commit(blob: Blob, repo: Repo) -> str:
    """Describe the commit associated with ``blob`` as a one-line log entry.

    Runs ``git log --all --find-object=<blob> --max-count=1 --oneline`` and, if
    that prints nothing, falls back to the same search over the stash list.

    :param blob: Blob to look up.
    :param repo: Repository whose working directory the git commands run in.
    :return: The ``--oneline`` output of the log search, or of the stash search
        when the log search is empty (which may itself be empty).
    """
    git_cmd = Git(working_dir=repo.working_dir)
    log_line = git_cmd.log(find_object=blob, max_count=1, all=True, oneline=True)
    if log_line:
        return log_line
    # Nothing in the regular history: the blob may only exist in a stash entry.
    return git_cmd.stash(
        "list", "--oneline", "--max-count", "1", "--find-object", blob
    )
def get_conflict_blobs(repo: Repo) -> List[ConflictFile]:
    """Collect both sides of every conflicted path in ``repo``'s index.

    For each unmerged path, the blobs at index stage 2 ("ours") and stage 3
    ("theirs") are read, together with the one-line log entry returned by
    ``blob2commit`` for each of them.

    :param repo: Repository whose index may contain unmerged entries.
    :return: One ``ConflictFile`` per unmerged path that has both a stage-2
        and a stage-3 blob. Paths missing either side (e.g. a modify/delete
        conflict) are skipped with a warning instead of raising ``KeyError``.
    :raises RuntimeError: If ``repo.working_dir`` is not a ``Path`` or ``str``
        (e.g. ``None``).
    """
    working_dir = repo.working_dir
    if not isinstance(working_dir, (Path, str)):
        raise RuntimeError(
            "Expected `repo.working_dir` to be `pathlib.Path` or `str`, got"
            f" {type(working_dir)}."
        )
    root = Path(working_dir)

    conflicts: List[ConflictFile] = []
    for filename, stages in repo.index.unmerged_blobs().items():
        blob = UnmergedBlob(filename=Path(filename), stage=dict(stages))
        if 0 in blob.stage:
            # Only keep entries that could not be merged.
            continue
        ours = blob.stage.get(2)
        theirs = blob.stage.get(3)
        if ours is None or theirs is None:
            # One side has no version of the file (e.g. it was deleted there),
            # so there are not two contents to compare.
            logger.warning(
                f"Skipping {blob.filename}: expected index stages 2 and 3,"
                f" found {sorted(blob.stage)}."
            )
            continue
        conflicts.append(
            ConflictFile(
                filename=root / blob.filename,
                first_log=blob2commit(blob=ours, repo=repo),
                last_log=blob2commit(blob=theirs, repo=repo),
                first_contents=repo.git.show(ours),
                last_contents=repo.git.show(theirs),
            )
        )
    return conflicts