Coverage for databooks/data_models/notebook.py: 89%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Data models - Jupyter Notebooks and components."""
2from __future__ import annotations
4import json
5from copy import deepcopy
6from difflib import SequenceMatcher
7from itertools import chain
8from pathlib import Path
9from typing import (
10 Any,
11 Callable,
12 Dict,
13 Generator,
14 Iterable,
15 List,
16 Optional,
17 Sequence,
18 Tuple,
19 TypeVar,
20 Union,
21)
23from pydantic import Extra, PositiveInt, root_validator, validate_model, validator
24from pydantic.generics import GenericModel
26from databooks.data_models.base import BaseCells, DatabooksBase
27from databooks.logging import get_logger
# Module-level logger (used e.g. to report ignored field removals in `Cell`).
logger = get_logger(__file__)
class NotebookMetadata(DatabooksBase):
    """Notebook metadata. Empty by default but can accept extra fields."""
class CellMetadata(DatabooksBase):
    """Cell metadata. Empty by default but can accept extra fields."""
class Cell(DatabooksBase):
    """
    Jupyter notebook cells.

    Fields `outputs` and `execution_count` are not included since they should only be
    present in code cells - thus are treated as extra fields.
    """

    metadata: CellMetadata
    source: Union[List[str], str]
    cell_type: str

    def __hash__(self) -> int:
        """Cells must be hashable for `difflib.SequenceMatcher`."""
        # Build one tuple of all field values (lists converted to tuples so they
        # are hashable) and hash that. The previous version passed a generator
        # expression straight to `hash()`, which hashed the generator *object*
        # (identity-based) - equal cells got different hashes, breaking the
        # hash/eq contract that `difflib.SequenceMatcher` relies on.
        return hash(
            (type(self),)
            + tuple(
                tuple(v) if isinstance(v, list) else v
                for v in self.__dict__.values()
            )
        )

    def remove_fields(
        self, fields: Iterable[str] = (), missing_ok: bool = True, **kwargs: Any
    ) -> None:
        """
        Remove Cell fields.

        Similar to `databooks.data_models.base.remove_fields`, but will ignore required
        fields for `databooks.data_models.notebook.Cell`.

        :param fields: Fields to remove
        :param missing_ok: Whether or not to raise errors if field is missing
        :param kwargs: (Unused) keyword arguments kept for compatibility
        :return:
        """
        # Ignore required `Cell` fields
        cell_fields = self.__fields__  # required fields specified in class definition
        if any(field in fields for field in cell_fields):
            logger.debug(
                "Ignoring removal of required fields "
                + str([f for f in fields if f in cell_fields])
                + f" in `{type(self).__name__}`."
            )
            fields = [f for f in fields if f not in cell_fields]

        super(Cell, self).remove_fields(fields, missing_ok=missing_ok)

        # Code cells must always carry `outputs` and `execution_count`, so restore
        # empty values in case they were removed above.
        if self.cell_type == "code":
            self.outputs: List[Dict[str, Any]] = (
                [] if "outputs" not in dict(self) else self.outputs
            )
            self.execution_count: Optional[PositiveInt] = (
                None if "execution_count" not in dict(self) else self.execution_count
            )

    def clear_fields(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_remove_fields: Sequence[str] = (),
    ) -> None:
        """
        Clear cell metadata, execution count, outputs or other desired fields (id, ...).

        You can also specify metadata to keep or remove from the `metadata` property of
        `databooks.data_models.notebook.Cell`. Exactly one of `cell_metadata_keep` or
        `cell_metadata_remove` must be passed.

        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_remove_fields: Fields to remove from cell
        :return:
        """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )

        if cell_metadata_keep is not None:
            # Keep-list is translated into the complementary remove-list
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(fields=cell_remove_fields, missing_ok=True)

    @validator("cell_type")
    def cell_has_valid_type(cls, v: str) -> str:
        """Check if cell has one of the three predefined types."""
        valid_cell_types = ("raw", "markdown", "code")
        if v not in valid_cell_types:
            raise ValueError(f"Invalid cell type. Must be one of {valid_cell_types}")
        return v

    @root_validator
    def code_cell_has_valid_outputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Check that code cells have list-type outputs."""
        # Both checks must apply only to code cells: the previous version checked
        # `values["outputs"]` unconditionally, raising `KeyError` (which pydantic
        # does not convert into a validation error) for markdown/raw cells that
        # legitimately have no `outputs` key.
        if values.get("cell_type") == "code":
            if "outputs" not in values:
                raise ValueError(
                    f"All code cells must have an `outputs` property, got {values}"
                )
            if not isinstance(values["outputs"], list):
                raise ValueError(
                    f"Cell outputs must be a list, got {type(values['outputs'])}"
                )
        return values

    @root_validator
    def only_code_cells_have_outputs_and_execution_count(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Check that only code cells have outputs and execution count."""
        if values.get("cell_type") != "code" and (
            ("outputs" in values) or ("execution_count" in values)
        ):
            raise ValueError(
                "Found `outputs` or `execution_count` for cell of type"
                f" `{values['cell_type']}`"
            )
        return values
# `Cells` containers hold either plain cells, or - after a diff (`-` operator) -
# pairs of cell lists (cells from the first notebook, cells from the last).
T = TypeVar("T", Cell, Tuple[List[Cell], List[Cell]])
class Cells(GenericModel, BaseCells[T]):
    """Similar to `list`, with `-` operator using `difflib.SequenceMatcher`."""

    __root__: Sequence[T] = []

    def __init__(self, elements: Sequence[T] = ()) -> None:
        """Allow passing data as a positional argument when instantiating class."""
        super(Cells, self).__init__(__root__=elements)

    @property
    def data(self) -> List[T]:  # type: ignore
        """Define property `data` required for `collections.UserList` class."""
        return list(self.__root__)

    def __iter__(self) -> Generator[Any, None, None]:
        """Use list property as iterable."""
        return (el for el in self.data)

    def __sub__(
        self: Cells[Cell], other: Cells[Cell]
    ) -> Cells[Tuple[List[Cell], List[Cell]]]:
        """
        Return the difference using `difflib.SequenceMatcher`.

        :param other: Cells to compare against
        :return: Cells of cell-list pairs - `(cells from self, cells from other)`
         for each diff chunk
        :raises TypeError: If `other` is not also a `Cells` instance
        """
        if type(self) != type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # By setting the context to the max number of cells and using
        # `difflib.SequenceMatcher.get_grouped_opcodes` we essentially get the same
        # result as `difflib.SequenceMatcher.get_opcodes` but in smaller chunks
        # (the original comments said `pathlib`, but `SequenceMatcher` lives in
        # `difflib`)
        n_context = max(len(self), len(other))
        diff_opcodes = list(
            SequenceMatcher(
                isjunk=None, a=self, b=other, autojunk=False
            ).get_grouped_opcodes(n_context)
        )

        if len(diff_opcodes) > 1:
            raise RuntimeError(
                "Expected one group for opcodes when context size is"
                f" {n_context} for {len(self)} and {len(other)} cells in"
                " notebooks."
            )
        return Cells[Tuple[List[Cell], List[Cell]]](
            [
                # https://github.com/python/mypy/issues/9459
                tuple((self.data[i1:j1], other.data[i2:j2]))  # type: ignore
                for _, i1, j1, i2, j2 in chain.from_iterable(diff_opcodes)
            ]
        )

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Get validators for custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """Ensure object is custom defined container."""
        if not isinstance(v, cls):
            return cls(v)
        else:
            return v

    @staticmethod
    def wrap_git(
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> List[Cell]:
        """
        Wrap git-diff cells in existing notebook.

        :param first_cells: Cells from the first version of the notebook
        :param last_cells: Cells from the last version of the notebook
        :param hash_first: Git hash of first file in conflict
        :param hash_last: Git hash of last file in conflict
        :return: Cells wrapped in markdown cells carrying git-conflict markers
        """
        return (
            [
                Cell(
                    metadata=CellMetadata(git_hash=hash_first),
                    source=[f"`<<<<<<< {hash_first}`"],
                    cell_type="markdown",
                )
            ]
            + first_cells
            + [
                Cell(
                    source=["`=======`"],
                    cell_type="markdown",
                    metadata=CellMetadata(),
                )
            ]
            + last_cells
            + [
                Cell(
                    metadata=CellMetadata(git_hash=hash_last),
                    source=[f"`>>>>>>> {hash_last}`"],
                    cell_type="markdown",
                )
            ]
        )

    def resolve(
        self: Cells[Tuple[List[Cell], List[Cell]]],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook or not.
         If `None`, then keep both wrapping the git-diff tags
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
         `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            # `pairs[0]` holds cells from the first notebook, `pairs[1]` from the last
            return list(
                chain.from_iterable(pairs[not keep_first_cells] for pairs in self.data)
            )
        # Keep both versions, wrapping each conflicting chunk in git-style markers
        return list(
            chain.from_iterable(
                Cells.wrap_git(
                    first_cells=val[0],
                    last_cells=val[1],
                    hash_first=first_id,
                    hash_last=last_id,
                )
                if val[0] != val[1]
                else val[0]
                for val in self.data
            )
        )
class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """
        Parse notebook from a path.

        :param path: Path of the notebook JSON file
        :param parse_kwargs: Keyword arguments forwarded to `pydantic`'s
         `parse_file` (`content_type` is forced to `json` and cannot be overridden)
        :return: Parsed notebook
        :raises ValueError: If a non-`None` `content_type` is passed
        """
        content_arg = parse_kwargs.pop("content_type", None)
        if content_arg is not None:
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )
        return super(JupyterNotebook, cls).parse_file(
            path=path, content_type="json", **parse_kwargs
        )

    def write(
        self, path: Path | str, overwrite: bool = False, **json_kwargs: Any
    ) -> None:
        """
        Write notebook to disk.

        :param path: File path to write to
        :param overwrite: Whether to overwrite an existing file
        :param json_kwargs: Keyword arguments for `json.dump` (`indent=2` by default)
        :raises ValueError: If file already exists and `overwrite` is `False`
        """
        path = Path(path) if not isinstance(path, Path) else path
        json_kwargs = {"indent": 2, **json_kwargs}
        if path.is_file() and not overwrite:
            # Original message had a duplicated "exists" - fixed
            raise ValueError(
                f"File exists at {path}. Specify `overwrite = True`."
            )

        # Re-validate the model before writing, in case fields were mutated in place
        _, _, validation_error = validate_model(self.__class__, self.dict())
        if validation_error:
            raise validation_error
        with path.open("w") as f:
            json.dump(self.dict(), fp=f, **json_kwargs)

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove` must be
        passed.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: keyword arguments to be passed to each cell's
         `databooks.data_models.notebook.Cell.clear_fields`
        :return:
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            # Keep-list is translated into the complementary remove-list
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if len(cell_kwargs) > 0:
            # Mutate a deep copy so a mid-loop failure leaves `self.cells` untouched
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_fields(**cell_kwargs)
            self.cells = _clean_cells