Coverage for databooks/data_models/notebook.py: 92%
113 statements
coverage.py v7.2.7, created at 2023-10-03 12:27 +0000
1"""Data models - Jupyter Notebooks and components."""
2from __future__ import annotations
4import json
5from copy import deepcopy
6from difflib import SequenceMatcher
7from itertools import chain
8from pathlib import Path
9from typing import (
10 Any,
11 Callable,
12 Generator,
13 Iterable,
14 List,
15 Optional,
16 Sequence,
17 Tuple,
18 TypeVar,
19 Union,
20 cast,
21)
23from pydantic import Extra, RootModel
24from rich import box
25from rich.columns import Columns
26from rich.console import Console, ConsoleOptions, Group, RenderableType, RenderResult
27from rich.panel import Panel
28from rich.text import Text
30from databooks.data_models.base import BaseCells, DatabooksBase
31from databooks.data_models.cell import CellMetadata, CodeCell, MarkdownCell, RawCell
32from databooks.logging import get_logger
34logger = get_logger(__file__)
36Cell = Union[CodeCell, RawCell, MarkdownCell]
37CellsPair = Tuple[List[Cell], List[Cell]]
38T = TypeVar("T", Cell, CellsPair)


class Cells(RootModel[Sequence[T]], BaseCells[T]):
    """Similar to `list`, with `-` operator using `difflib.SequenceMatcher`."""

    root: Sequence[T]

    @property
    def data(self) -> List[T]:  # type: ignore
        """Define property `data` required for `collections.UserList` class."""
        return list(self.root)

    def __iter__(self) -> Generator[Any, None, None]:
        """Use list property as iterable."""
        return (el for el in self.data)

    def __sub__(self: Cells[Cell], other: Cells[Cell]) -> Cells[CellsPair]:
        """Return the difference using `difflib.SequenceMatcher`."""
        if type(self) != type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # By setting the context to the max number of cells and using
        # `difflib.SequenceMatcher.get_grouped_opcodes` we essentially get the same
        # result as `difflib.SequenceMatcher.get_opcodes`, but in smaller chunks
        n_context = max(len(self), len(other))
        diff_opcodes = list(
            SequenceMatcher(
                isjunk=None, a=self, b=other, autojunk=False
            ).get_grouped_opcodes(n_context)
        )

        if len(diff_opcodes) > 1:
            raise RuntimeError(
                "Expected one group of opcodes when context size is"
                f" {n_context} for {len(self)} and {len(other)} cells in"
                " notebooks."
            )

        return Cells[CellsPair](
            [
                # https://github.com/python/mypy/issues/9459
                tuple((self.data[i1:j1], other.data[i2:j2]))  # type: ignore
                for _, i1, j1, i2, j2 in chain.from_iterable(diff_opcodes)
            ]
        )
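
    # Illustrative usage sketch (not part of the class API; file names are
    # hypothetical). Subtracting the cells of two parsed notebooks yields a
    # `Cells[CellsPair]`, where each element pairs the first and last versions
    # of a chunk of cells:
    #
    #     nb_a = JupyterNotebook.parse_file("nb_a.ipynb")
    #     nb_b = JupyterNotebook.parse_file("nb_b.ipynb")
    #     diff = nb_a.cells - nb_b.cells
    #     assert all(isinstance(pair, tuple) for pair in diff.data)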

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of all cells in notebook."""
        yield from self._get_renderables(expand=True, width=options.max_width // 3)

    def _get_renderables(self, **wrap_cols_kwargs: Any) -> Iterable[RenderableType]:
        """Get the Rich renderables, depending on whether `Cells` is a diff or not."""
        if all(isinstance(el, tuple) for el in self.data):
            return chain.from_iterable(
                Cells.wrap_cols(val[0], val[1], **wrap_cols_kwargs)
                if val[0] != val[1]
                else val[0]
                for val in cast(List[CellsPair], self.data)
            )
        return cast(List[Cell], self.data)

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Get validators for custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """Ensure object is custom defined container."""
        if not isinstance(v, cls):
            return cls(v)
        else:
            return v

    @classmethod
    def wrap_cols(
        cls, first_cells: List[Cell], last_cells: List[Cell], **cols_kwargs: Any
    ) -> Sequence[Columns]:
        """Wrap the first and last cells into columns for iterable."""
        _empty = [Panel(Text("<None>", justify="center"), box=box.SIMPLE)]
        _first = Group(*first_cells or _empty)
        _last = Group(*last_cells or _empty)
        return [Columns([_first, _last], **cols_kwargs)]

    @staticmethod
    def wrap_git(
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> Sequence[Cell]:
        """Wrap git-diff cells in existing notebook."""
        return [
            MarkdownCell(
                metadata=CellMetadata(git_hash=hash_first),
                source=[f"`<<<<<<< {hash_first}`"],
            ),
            *first_cells,
            MarkdownCell(
                source=["`=======`"],
                metadata=CellMetadata(),
            ),
            *last_cells,
            MarkdownCell(
                metadata=CellMetadata(git_hash=hash_last),
                source=[f"`>>>>>>> {hash_last}`"],
            ),
        ]
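
    # Illustrative sketch (hypothetical cell lists `first` and `last`, placeholder
    # hashes). The result mirrors git conflict markers as Markdown cells around
    # both versions of the conflicting cells:
    #
    #     merged = Cells.wrap_git(first, last, hash_first="1a2b3c", hash_last="4d5e6f")
    #     # -> [`<<<<<<< 1a2b3c` cell, *first, `=======` cell, *last, `>>>>>>> 4d5e6f` cell]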

    def resolve(
        self: Cells[CellsPair],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook or
         not. If `None`, keep both and wrap them in git-diff tags
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
         `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            return list(
                chain.from_iterable(pairs[not keep_first_cells] for pairs in self.data)
            )
        return list(
            chain.from_iterable(
                Cells.wrap_git(
                    first_cells=val[0],
                    last_cells=val[1],
                    hash_first=first_id,
                    hash_last=last_id,
                )
                if val[0] != val[1]
                else val[0]
                for val in self.data
            )
        )
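
    # Illustrative sketch (not part of the class): resolving a diff produced by `-`.
    # `diff` is assumed to be a `Cells[CellsPair]`; the hash strings are placeholders.
    #
    #     kept = diff.resolve(keep_first_cells=True)  # keep only the first version
    #     both = diff.resolve(first_id="1a2b3c", last_id="4d5e6f")  # wrap conflicts in git tags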


class NotebookMetadata(DatabooksBase):
    """Notebook metadata. Empty by default but can accept extra fields."""


class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display notebook."""

        def _rich(kernel: str) -> Text:
            """Display with `kernel` theme, horizontal padding and right-justified."""
            return Text(kernel, style="kernel", justify="right")

        kernelspec = self.metadata.dict().get("kernelspec", {})
        if isinstance(kernelspec, tuple):  # check if this is a `DiffCells`
            kernelspec = tuple(
                ks or {"language": "text", "display_name": "null"} for ks in kernelspec
            )
            lang_first, lang_last = (ks.get("language", "text") for ks in kernelspec)
            nb_lang = lang_first if lang_first == lang_last else "text"
            if any("display_name" in ks.keys() for ks in kernelspec):
                kernel_first, kernel_last = [
                    _rich(ks["display_name"]) for ks in kernelspec
                ]
                yield Columns(
                    [kernel_first, kernel_last],
                    expand=True,
                    width=options.max_width // 3,
                ) if kernel_first != kernel_last else kernel_first
        else:
            nb_lang = kernelspec.get("language", "text")
            if "display_name" in kernelspec.keys():
                yield _rich(kernelspec["display_name"])

        for cell in self.cells:
            if isinstance(cell, CodeCell):
                cell.metadata = CellMetadata(**cell.metadata.dict(), lang=nb_lang)
        yield self.cells
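
    # Illustrative sketch: rendering a notebook in the terminal via rich. Assumes
    # `nb` is a parsed `JupyterNotebook` and that the "kernel" style referenced
    # above is defined in the console theme (as in databooks' own CLI).
    #
    #     from rich.console import Console
    #     Console().print(nb)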

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """Parse notebook from a path."""
        content_arg = parse_kwargs.pop("content_type", None)
        if content_arg is not None:
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )

        path = Path(path) if not isinstance(path, Path) else path
        return JupyterNotebook.model_validate_json(json_data=path.read_text())
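
    # Illustrative sketch (hypothetical path): parsing a notebook from disk.
    #
    #     nb = JupyterNotebook.parse_file("notebooks/analysis.ipynb")
    #     print(nb.nbformat, len(nb.cells.data))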

    def write(
        self, path: Path | str, overwrite: bool = False, **json_kwargs: Any
    ) -> None:
        """Write notebook to disk."""
        path = Path(path) if not isinstance(path, Path) else path
        json_kwargs = {"indent": 2, **json_kwargs}
        if path.is_file() and not overwrite:
            raise ValueError(
                f"File already exists at {path}. Specify `overwrite = True`."
            )

        self.__class__.model_validate(self.dict())

        with path.open("w") as f:
            json.dump(self.dict(), fp=f, **json_kwargs)
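
    # Illustrative sketch (hypothetical paths): round-tripping a notebook.
    # Extra keyword arguments are forwarded to `json.dump`.
    #
    #     nb = JupyterNotebook.parse_file("in.ipynb")
    #     nb.write("out.ipynb", overwrite=True, indent=1)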

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: keyword arguments to be passed to each cell's
         `databooks.data_models.cell.BaseCell.clear_fields`
        :return:
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if len(cell_kwargs) > 0:
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_fields(**cell_kwargs)
            self.cells = _clean_cells
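
    # Illustrative sketch: stripping metadata before committing a notebook. Assumes
    # `nb` is a parsed `JupyterNotebook`; the cell keyword argument shown is only an
    # example of what gets forwarded to each cell's field-clearing method and may
    # differ from the actual parameter names.
    #
    #     nb.clear_metadata(
    #         notebook_metadata_keep=("kernelspec",),
    #         cell_metadata_keep=(),  # hypothetical kwarg forwarded to cells
    #     )
    #     nb.write("clean.ipynb", overwrite=True)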