Coverage for databooks/data_models/notebook.py: 94%
116 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-04 16:41 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-04 16:41 +0000
1"""Data models - Jupyter Notebooks and components."""
2from __future__ import annotations
4import json
5from copy import deepcopy
6from difflib import SequenceMatcher
7from itertools import chain
8from pathlib import Path
9from typing import (
10 Any,
11 Callable,
12 Generator,
13 Iterable,
14 List,
15 Optional,
16 Sequence,
17 Tuple,
18 TypeVar,
19 Union,
20 cast,
21)
23from pydantic import Extra, validate_model
24from pydantic.generics import GenericModel
25from rich import box
26from rich.columns import Columns
27from rich.console import Console, ConsoleOptions, Group, RenderableType, RenderResult
28from rich.panel import Panel
29from rich.text import Text
31from databooks.data_models.base import BaseCells, DatabooksBase
32from databooks.data_models.cell import CellMetadata, CodeCell, MarkdownCell, RawCell
33from databooks.logging import get_logger
logger = get_logger(__file__)

# A single notebook cell, of any of the supported cell types
Cell = Union[CodeCell, RawCell, MarkdownCell]
# Aligned runs of cells from two notebooks, as produced by `Cells.__sub__`
CellsPair = Tuple[List[Cell], List[Cell]]
# Element type of `Cells`: either plain cells, or pairs of cell runs for a diff
T = TypeVar("T", Cell, CellsPair)
class Cells(GenericModel, BaseCells[T]):
    """
    Similar to `list`, with `-` operator using `difflib.SequenceMatcher`.

    Subtracting two `Cells` (`first - last`) returns a `Cells` of pairs
     `(first_run, last_run)`. Concatenating the first (resp. last) elements of all
     pairs reproduces `first` (resp. `last`) in order.
    """

    __root__: Sequence[T] = ()

    def __init__(self, elements: Sequence[T] = ()) -> None:
        """Allow passing data as a positional argument when instantiating class."""
        super(Cells, self).__init__(__root__=elements)

    @property
    def data(self) -> List[T]:  # type: ignore
        """Define property `data` required for `collections.UserList` class."""
        return list(self.__root__)

    def __iter__(self) -> Generator[Any, None, None]:
        """Use list property as iterable."""
        return (el for el in self.data)

    def __sub__(self: Cells[Cell], other: Cells[Cell]) -> Cells[CellsPair]:
        """
        Return the difference using `difflib.SequenceMatcher`.

        :param other: Cells to compare against (must be of the same type as `self`)
        :raises TypeError: If `self` and `other` are of different types
        :return: One pair `(self_cells, other_cells)` per opcode, unchanged runs
         included (as pairs of equal lists). Non-empty identical operands give a
         single pair; two empty operands give an empty result.
        """
        if type(self) is not type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # `get_opcodes` spans both sequences end to end, equal runs included, so the
        #  pairs can be concatenated back into either operand. `get_grouped_opcodes`
        #  must not be used here: it yields no group at all when the sequences are
        #  identical, which made `a - a` empty and `resolve` drop every cell.
        opcodes = SequenceMatcher(
            isjunk=None, a=self, b=other, autojunk=False
        ).get_opcodes()
        return Cells[CellsPair](
            [
                # https://github.com/python/mypy/issues/9459
                tuple((self.data[i1:j1], other.data[i2:j2]))  # type: ignore
                for _, i1, j1, i2, j2 in opcodes
            ]
        )

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of all cells in notebook."""
        yield from self._get_renderables(expand=True, width=options.max_width // 3)

    def _get_renderables(self, **wrap_cols_kwargs: Any) -> Iterable[RenderableType]:
        """
        Get the Rich renderables, depending on whether `Cells` is a diff or not.

        When every element is a tuple (a diff), differing pairs are shown side by side
         via `wrap_cols` and equal pairs as their cells; otherwise the cells are
         returned as they are.

        :param wrap_cols_kwargs: Keyword arguments forwarded to `wrap_cols`
        :return: Iterable of renderables
        """
        if all(isinstance(el, tuple) for el in self.data):
            return chain.from_iterable(
                Cells.wrap_cols(val[0], val[1], **wrap_cols_kwargs)
                if val[0] != val[1]
                else val[0]
                for val in cast(List[CellsPair], self.data)
            )
        return cast(List[Cell], self.data)

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Get validators for custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """Ensure object is custom defined container (wrap it if it is not)."""
        if not isinstance(v, cls):
            return cls(v)
        else:
            return v

    @classmethod
    def wrap_cols(
        cls, first_cells: List[Cell], last_cells: List[Cell], **cols_kwargs: Any
    ) -> Sequence[Columns]:
        """
        Wrap the first and second cells into columns for iterable.

        An empty side is shown as a centered `<None>` panel.

        :param first_cells: Cells for the left column
        :param last_cells: Cells for the right column
        :param cols_kwargs: Keyword arguments forwarded to `rich.columns.Columns`
        :return: Single-element list with the two columns
        """
        _empty = [Panel(Text("<None>", justify="center"), box=box.SIMPLE)]
        _first = Group(*first_cells or _empty)
        _last = Group(*last_cells or _empty)
        return [Columns([_first, _last], **cols_kwargs)]

    @staticmethod
    def wrap_git(
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> Sequence[Cell]:
        """
        Wrap git-diff cells in existing notebook.

        Surrounds both runs of cells with markdown cells mimicking git conflict
         markers (`<<<<<<<`, `=======`, `>>>>>>>`).

        :param first_cells: Cells from the first notebook
        :param last_cells: Cells from the last notebook
        :param hash_first: Git hash shown on the opening marker
        :param hash_last: Git hash shown on the closing marker
        :return: Marker cells interleaved with both runs of cells
        """
        return [
            MarkdownCell(
                metadata=CellMetadata(git_hash=hash_first),
                source=[f"`<<<<<<< {hash_first}`"],
                cell_type="markdown",
            ),
            *first_cells,
            MarkdownCell(
                source=["`=======`"],
                cell_type="markdown",
                metadata=CellMetadata(),
            ),
            *last_cells,
            MarkdownCell(
                metadata=CellMetadata(git_hash=hash_last),
                source=[f"`>>>>>>> {hash_last}`"],
                cell_type="markdown",
            ),
        ]

    def resolve(
        self: Cells[CellsPair],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook or not.
         If `None`, then keep both wrapping the git-diff tags
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
         `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            return list(
                chain.from_iterable(
                    pairs[0] if keep_first_cells else pairs[1] for pairs in self.data
                )
            )
        return list(
            chain.from_iterable(
                Cells.wrap_git(
                    first_cells=val[0],
                    last_cells=val[1],
                    hash_first=first_id,
                    hash_last=last_id,
                )
                if val[0] != val[1]
                else val[0]
                for val in self.data
            )
        )
class NotebookMetadata(DatabooksBase):
    """
    Notebook-level metadata.

    No fields are declared here, so an instance is empty by default; arbitrary extra
     fields are nonetheless accepted.
    """
class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """
        Rich display notebook.

        Yields the kernel display name (side by side for diffs whose kernels differ),
         then the cells. Code cells get the notebook language as `lang` metadata on a
         copy, so displaying never modifies the notebook itself.
        """

        def _rich(kernel: str) -> Text:
            """Display with `kernel` theme, horizontal padding and right-justified."""
            return Text(kernel, style="kernel", justify="right")

        kernelspec = self.metadata.dict().get("kernelspec") or {}
        if isinstance(kernelspec, tuple):  # a `(first, last)` pair means a diff model
            # Either side of the diff may lack a kernelspec altogether
            kernelspecs = [ks or {} for ks in kernelspec]
            lang_first, lang_last = (ks.get("language", "text") for ks in kernelspecs)
            nb_lang = lang_first if lang_first == lang_last else "text"
            if any("display_name" in ks for ks in kernelspecs):
                # Only one side may define `display_name`; show an empty name then
                kernel_first, kernel_last = [
                    _rich(ks.get("display_name", "")) for ks in kernelspecs
                ]
                yield Columns(
                    [kernel_first, kernel_last],
                    expand=True,
                    width=options.max_width // 3,
                ) if kernel_first != kernel_last else kernel_first
        else:
            nb_lang = kernelspec.get("language", "text")
            if "display_name" in kernelspec:
                yield _rich(kernelspec["display_name"])

        # Work on a copy: rendering must not leak `lang` into the notebook's metadata
        cells = deepcopy(self.cells)
        for cell in cells:
            if isinstance(cell, CodeCell):
                # Merge instead of passing `lang=` alongside `**dict()`, which raises
                #  `TypeError` (repeated keyword) when `lang` is already present
                cell.metadata = CellMetadata(
                    **{**cell.metadata.dict(), "lang": nb_lang}
                )
        yield cells

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """
        Parse notebook from a path.

        :param path: Path to the notebook file
        :param parse_kwargs: Keyword arguments forwarded to the parent `parse_file`;
         `content_type` is always `json` and may only be passed with that value
        :raises ValueError: If a `content_type` other than `json` is passed
        :return: Parsed notebook
        """
        content_arg = parse_kwargs.pop("content_type", None)
        if content_arg not in (None, "json"):
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )
        return super(JupyterNotebook, cls).parse_file(
            path=path, content_type="json", **parse_kwargs
        )

    def write(
        self, path: Path | str, overwrite: bool = False, **json_kwargs: Any
    ) -> None:
        """
        Write notebook to disk as UTF-8 encoded JSON.

        :param path: Destination file path
        :param overwrite: Whether to overwrite an existing file at `path`
        :param json_kwargs: Keyword arguments forwarded to `json.dump` (`indent`
         defaults to 2)
        :raises ValueError: If a file exists at `path` and `overwrite` is `False`
        :raises ValidationError: If the notebook content does not validate against
         this model (the error returned by `validate_model` is raised)
        """
        path = Path(path)
        json_kwargs = {"indent": 2, **json_kwargs}
        if path.is_file() and not overwrite:
            raise ValueError(f"File exists at {path}. Specify `overwrite = True`.")

        nb_dict = self.dict()
        _, _, validation_error = validate_model(self.__class__, nb_dict)
        if validation_error:
            raise validation_error
        # Notebooks are UTF-8 JSON; don't depend on the platform default encoding
        with path.open("w", encoding="utf-8") as f:
            json.dump(nb_dict, fp=f, **json_kwargs)

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: keyword arguments to be passed to each cell's
         `clear_fields` method; cells are left untouched when none are given
        :raises ValueError: If not exactly one of `notebook_metadata_keep` and
         `notebook_metadata_remove` is passed
        :return: None (the notebook is modified in place)
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if cell_kwargs:
            # Clean a copy and swap it in, so a failure doesn't leave cells half-cleaned
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_fields(**cell_kwargs)
            self.cells = _clean_cells