Coverage for databooks/data_models/notebook.py: 94%
117 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-11 20:30 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-11 20:30 +0000
1"""Data models - Jupyter Notebooks and components."""
2from __future__ import annotations
4import json
5from copy import deepcopy
6from difflib import SequenceMatcher
7from itertools import chain
8from pathlib import Path
9from typing import (
10 Any,
11 Callable,
12 Generator,
13 Iterable,
14 List,
15 Optional,
16 Sequence,
17 Tuple,
18 TypeVar,
19 Union,
20 cast,
21)
23from pydantic import Extra, validate_model
24from pydantic.generics import GenericModel
25from rich import box
26from rich.columns import Columns
27from rich.console import Console, ConsoleOptions, Group, RenderableType, RenderResult
28from rich.panel import Panel
29from rich.text import Text
31from databooks.data_models.base import BaseCells, DatabooksBase
32from databooks.data_models.cell import CellMetadata, CodeCell, MarkdownCell, RawCell
33from databooks.logging import get_logger
# Any single notebook cell model (member order matters for pydantic validation)
Cell = Union[CodeCell, RawCell, MarkdownCell]
# One chunk of a cells diff: `(first_cells, last_cells)`
CellsPair = Tuple[List[Cell], List[Cell]]
# `Cells` is generic over either plain cells or diff pairs
T = TypeVar("T", Cell, CellsPair)

logger = get_logger(__file__)
class Cells(GenericModel, BaseCells[T]):
    """Similar to `list`, with `-` operator using `difflib.SequenceMatcher`."""

    __root__: Sequence[T] = ()

    def __init__(self, elements: Sequence[T] = ()) -> None:
        """Allow passing data as a positional argument when instantiating class."""
        super(Cells, self).__init__(__root__=elements)

    @property
    def data(self) -> List[T]:  # type: ignore
        """Define property `data` required for `collections.UserList` class."""
        # Note: builds a new list on every access
        return list(self.__root__)

    def __iter__(self) -> Generator[Any, None, None]:
        """Use list property as iterable."""
        return (el for el in self.data)

    def __sub__(self: Cells[Cell], other: Cells[Cell]) -> Cells[CellsPair]:
        """
        Return the difference using `difflib.SequenceMatcher`.

        The result holds one `(first_cells, last_cells)` pair per opcode of
        `SequenceMatcher.get_opcodes`, so together the pairs cover every cell of both
        operands, in order. Unchanged runs give pairs whose two lists are equal;
        changed runs give pairs that differ (one side is empty for pure
        insertions/deletions).

        :param other: Cells to compare with, of the same type as `self`
        :return: Cells of `(first_cells, last_cells)` pairs
        :raises TypeError: If `self` and `other` are of different types
        """
        if type(self) != type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # Use the plain opcodes instead of `get_grouped_opcodes`: with a context as
        #  large as the inputs, grouping only differs by yielding *nothing* when both
        #  sequences are identical - which made the diff (and any notebook resolved
        #  from it) lose every cell
        opcodes = SequenceMatcher(
            isjunk=None, a=self, b=other, autojunk=False
        ).get_opcodes()

        # `data` builds a new list on each access, so fetch it once, not per opcode
        first_data, last_data = self.data, other.data
        return Cells[CellsPair](
            [
                # https://github.com/python/mypy/issues/9459
                tuple((first_data[i1:i2], last_data[j1:j2]))  # type: ignore
                for _, i1, i2, j1, j2 in opcodes
            ]
        )

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of all cells in notebook."""
        yield from self._get_renderables(expand=True, width=options.max_width // 3)

    def _get_renderables(self, **wrap_cols_kwargs: Any) -> Iterable[RenderableType]:
        """
        Get the Rich renderables, depending on whether `Cells` is a diff or not.

        For a diff (all elements are tuples), differing pairs are shown side by side
        via `wrap_cols` and equal pairs are shown once. Otherwise the cells are
        returned as they are.

        :param wrap_cols_kwargs: Keyword arguments forwarded to `wrap_cols`
        :return: Iterable of renderables
        """
        data = self.data
        if all(isinstance(el, tuple) for el in data):
            return chain.from_iterable(
                Cells.wrap_cols(val[0], val[1], **wrap_cols_kwargs)
                if val[0] != val[1]
                else val[0]
                for val in cast(List[CellsPair], data)
            )
        return cast(List[Cell], data)

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Get validators for custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """Ensure object is custom defined container (wrap it if it is not)."""
        if not isinstance(v, cls):
            return cls(v)
        else:
            return v

    @classmethod
    def wrap_cols(
        cls, first_cells: List[Cell], last_cells: List[Cell], **cols_kwargs: Any
    ) -> Sequence[Columns]:
        """
        Wrap the first and last cells into columns for iterable.

        An empty side is displayed as a centered `<None>` panel.

        :param first_cells: Cells shown in the left column
        :param last_cells: Cells shown in the right column
        :param cols_kwargs: Keyword arguments forwarded to `rich.columns.Columns`
        :return: Single-element list with the two-column renderable
        """
        _empty = [Panel(Text("<None>", justify="center"), box=box.SIMPLE)]
        _first = Group(*first_cells or _empty)
        _last = Group(*last_cells or _empty)
        return [Columns([_first, _last], **cols_kwargs)]

    @staticmethod
    def wrap_git(
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> Sequence[Cell]:
        """
        Wrap git-diff cells in existing notebook.

        Surrounds both sides with markdown cells mimicking git conflict markers
        (`<<<<<<<`, `=======`, `>>>>>>>`).

        :param first_cells: Cells of the first side of the conflict
        :param last_cells: Cells of the last side of the conflict
        :param hash_first: Git hash of the first side, shown in the opening marker
        :param hash_last: Git hash of the last side, shown in the closing marker
        :return: Marker cells interleaved with both sides' cells
        """
        return [
            MarkdownCell(
                metadata=CellMetadata(git_hash=hash_first),
                source=[f"`<<<<<<< {hash_first}`"],
                cell_type="markdown",
            ),
            *first_cells,
            MarkdownCell(
                source=["`=======`"],
                cell_type="markdown",
                metadata=CellMetadata(),
            ),
            *last_cells,
            MarkdownCell(
                metadata=CellMetadata(git_hash=hash_last),
                source=[f"`>>>>>>> {hash_last}`"],
                cell_type="markdown",
            ),
        ]

    def resolve(
        self: Cells[CellsPair],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook or not.
         If `None`, then keep both wrapping the git-diff tags
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
         `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            # Index 0 (first) when keeping first cells, 1 (last) otherwise
            return list(
                chain.from_iterable(pairs[not keep_first_cells] for pairs in self.data)
            )
        return list(
            chain.from_iterable(
                Cells.wrap_git(
                    first_cells=val[0],
                    last_cells=val[1],
                    hash_first=first_id,
                    hash_last=last_id,
                )
                if val[0] != val[1]
                else val[0]
                for val in self.data
            )
        )
class NotebookMetadata(DatabooksBase):
    """
    Notebook-level metadata.

    No fields are declared here, so an instance is empty by default; extra fields
    can still be passed and stored.
    """
class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """
        Rich display notebook.

        Yields the kernel display name (right-justified, side by side for a diff whose
        kernels differ), then the cells. Code cells are rendered with the notebook
        language set as `lang` in their metadata. The notebook itself is not modified.
        """

        def _rich(kernel: str) -> Text:
            """Display with `kernel` style, right-justified."""
            return Text(kernel, style="kernel", justify="right")

        # A missing or `null` kernelspec is treated as empty
        kernelspec = self.metadata.dict().get("kernelspec") or {}
        if isinstance(kernelspec, tuple):  # diff of notebooks: `(first, last)` values
            kernelspec = tuple(
                ks or {"language": "text", "display_name": "null"} for ks in kernelspec
            )
            lang_first, lang_last = (ks.get("language", "text") for ks in kernelspec)
            nb_lang = lang_first if lang_first == lang_last else "text"
            if any("display_name" in ks for ks in kernelspec):
                # Only one side may define a display name - don't `KeyError` on the
                #  other one
                name_first, name_last = (
                    ks.get("display_name", "null") for ks in kernelspec
                )
                if name_first != name_last:
                    yield Columns(
                        [_rich(name_first), _rich(name_last)],
                        expand=True,
                        width=options.max_width // 3,
                    )
                else:
                    yield _rich(name_first)
        else:
            nb_lang = kernelspec.get("language", "text")
            if "display_name" in kernelspec:
                yield _rich(kernelspec["display_name"])

        # Render a copy so that displaying the notebook never leaks `lang` into it
        cells = deepcopy(self.cells)
        for cell in cells:
            if isinstance(cell, CodeCell):
                # Merge instead of passing `lang=` separately: an existing `lang` key
                #  would otherwise raise "multiple values for keyword argument"
                cell.metadata = CellMetadata(**{**cell.metadata.dict(), "lang": nb_lang})
        yield cells

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """
        Parse notebook from a path.

        :param path: Path to the notebook file
        :param parse_kwargs: Keyword arguments forwarded to the parent `parse_file`;
         `content_type`, if passed, must be `json`
        :return: Parsed notebook
        :raises ValueError: If a `content_type` other than `json` is passed
        """
        content_arg = parse_kwargs.pop("content_type", None)
        # Content is always parsed as JSON; an explicit `json` is accepted as well
        if content_arg not in (None, "json"):
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )
        return super(JupyterNotebook, cls).parse_file(
            path=path, content_type="json", **parse_kwargs
        )

    def write(
        self, path: Path | str, overwrite: bool = False, **json_kwargs: Any
    ) -> None:
        """
        Write notebook to disk.

        :param path: Destination file path
        :param overwrite: Whether to overwrite an existing file at `path`
        :param json_kwargs: Keyword arguments passed to `json.dump` (`indent=2` unless
         overridden)
        :raises ValueError: If a file exists at `path` and `overwrite` is `False`
        :raises: The validation error returned by pydantic's `validate_model`, if the
         notebook is invalid
        """
        path = Path(path)
        json_kwargs = {"indent": 2, **json_kwargs}
        if path.is_file() and not overwrite:
            raise ValueError(f"File exists at {path}. Specify `overwrite = True`.")

        # Fields can be reassigned after instantiation (e.g. `clear_metadata`), so
        #  validate the current state before writing it
        _, _, validation_error = validate_model(self.__class__, self.dict())
        if validation_error:
            raise validation_error
        # Explicit UTF-8 so `ensure_ascii=False` doesn't depend on the platform locale
        with path.open("w", encoding="utf-8") as f:
            json.dump(self.dict(), fp=f, **json_kwargs)

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata, in place.

        Exactly one of `notebook_metadata_keep` and `notebook_metadata_remove` must be
        passed.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: keyword arguments to be passed to each cell's
         `clear_fields` method; cells are left untouched if none are given
        :raises ValueError: If not exactly one of `notebook_metadata_keep` and
         `notebook_metadata_remove` is passed
        :return: None
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            # Translate the "keep" list into the complementary "remove" list
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if cell_kwargs:
            # Clear a copy, then reassign `cells` as a whole
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_fields(**cell_kwargs)
            self.cells = _clean_cells