Coverage for databooks/data_models/notebook.py: 92%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Data models - Jupyter Notebooks and components."""
2from __future__ import annotations
4from copy import deepcopy
5from difflib import SequenceMatcher
6from itertools import chain
7from pathlib import Path
8from typing import (
9 Any,
10 Callable,
11 Dict,
12 Generator,
13 Iterable,
14 List,
15 Optional,
16 Sequence,
17 Tuple,
18 TypeVar,
19 Union,
20)
22from pydantic import Extra, PositiveInt, root_validator, validator
23from pydantic.generics import GenericModel
25from databooks.data_models.base import BaseCells, DatabooksBase
26from databooks.logging import get_logger
# Module-level logger (used e.g. by `Cell.remove_fields` to report ignored removals)
logger = get_logger(__file__)
class NotebookMetadata(DatabooksBase):
    """Metadata of a whole notebook - declares no fields but accepts extra ones."""
class CellMetadata(DatabooksBase):
    """Metadata of a single cell - declares no fields but accepts extra ones."""
class Cell(DatabooksBase):
    """
    Jupyter notebook cells.

    Fields `outputs` and `execution_count` are not included since they should only be
    present in code cells - thus are treated as extra fields.
    """

    metadata: CellMetadata
    source: Union[List[str], str]
    cell_type: str

    def __hash__(self) -> int:
        """
        Cells must be hashable for `difflib.SequenceMatcher`.

        The hash is built from the cell class, `cell_type` and `source` only. Those
        values are always hashable (a list `source` is turned into a tuple). Cells
        that compare equal are expected to share them, so they get equal hashes.
        NOTE(review): this assumes `__eq__` compares field values - confirm in
        `DatabooksBase`. Other fields (metadata, outputs, ...) may hold unhashable
        values such as dicts, so they are left to `__eq__` to distinguish.
        """
        source = tuple(self.source) if isinstance(self.source, list) else self.source
        return hash((type(self), self.cell_type, source))

    def remove_fields(
        self, fields: Iterable[str] = (), missing_ok: bool = True, **kwargs: Any
    ) -> None:
        """
        Remove Cell fields.

        Similar to `databooks.data_models.base.remove_fields`, but will ignore required
        fields for `databooks.data_models.notebook.Cell`. Code cells always end up with
        `outputs` and `execution_count` attributes (`[]` / `None` if they were removed).

        :param fields: Names of the fields to remove
        :param missing_ok: Passed on to the parent `remove_fields`
        :param kwargs: (Unused) keyword arguments kept for interface compatibility
        :return:
        """
        # Materialize once: `fields` may be a one-shot iterator (e.g. a generator) and
        # it is traversed several times below
        fields = list(fields)

        # Ignore required `Cell` fields (those declared in the class definition)
        cell_fields = self.__fields__
        ignored = [f for f in fields if f in cell_fields]
        if ignored:
            logger.debug(
                "Ignoring removal of required fields "
                + str(ignored)
                + f" in `{type(self).__name__}`."
            )
            fields = [f for f in fields if f not in cell_fields]

        super(Cell, self).remove_fields(fields, missing_ok=missing_ok)

        # Restore defaults so code cells keep an `outputs` list (required by
        # `must_not_be_list_for_code_cells`) and an `execution_count` attribute
        if self.cell_type == "code":
            self.outputs: List[Dict[str, Any]] = (
                [] if "outputs" not in dict(self) else self.outputs
            )
            self.execution_count: Optional[PositiveInt] = (
                None if "execution_count" not in dict(self) else self.execution_count
            )

    def clear_fields(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_remove_fields: Sequence[str] = (),
    ) -> None:
        """
        Clear cell metadata, execution count, outputs or other desired fields (id, ...).

        You can also specify metadata to keep or remove from the `metadata` property of
        `databooks.data_models.notebook.Cell`.
        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_remove_fields: Fields to remove from cell
        :return:
        :raises ValueError: If not exactly one of `cell_metadata_keep` or
         `cell_metadata_remove` is passed
        """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )

        # Translate a "keep" list into the equivalent "remove" list
        if cell_metadata_keep is not None:
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(fields=cell_remove_fields, missing_ok=True)

    @validator("cell_type")
    def cell_has_valid_type(cls, v: str) -> str:
        """Check if cell has one of the three predefined types."""
        valid_cell_types = ("raw", "markdown", "code")
        if v not in valid_cell_types:
            raise ValueError(f"Invalid cell type. Must be one of {valid_cell_types}")
        return v

    @root_validator
    def must_not_be_list_for_code_cells(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Check that code cells have list-type outputs."""
        # Fields that failed validation (e.g. `cell_type`) are absent from `values`,
        # and `outputs` may simply be missing - use `.get` so the result is a
        # validation error instead of a bare `KeyError`
        if values.get("cell_type") == "code" and not isinstance(
            values.get("outputs"), list
        ):
            raise ValueError(
                "All code cells must have a list output property, got"
                f" {type(values.get('outputs'))}"
            )
        return values

    @root_validator
    def only_code_cells_have_outputs_and_execution_count(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Check that only code cells have outputs and execution count."""
        cell_type = values.get("cell_type")
        if cell_type is None:
            # `cell_type` failed its own validation; that error is already reported
            return values
        if cell_type != "code" and (
            ("outputs" in values) or ("execution_count" in values)
        ):
            raise ValueError(
                "Found `outputs` or `execution_count` for cell of type"
                f" `{cell_type}`"
            )
        return values
# Element type of `Cells`: either single cells, or `(first_cells, last_cells)` pairs as
# returned by `Cells.__sub__` (the difference between two sequences of cells)
T = TypeVar("T", Cell, Tuple[List[Cell], List[Cell]])
class Cells(GenericModel, BaseCells[T]):
    """Similar to `list`, with `-` operator using `difflib.SequenceMatcher`."""

    __root__: Sequence[T] = []

    def __init__(self, elements: Sequence[T] = ()) -> None:
        """Allow passing data as a positional argument when instantiating class."""
        super(Cells, self).__init__(__root__=elements)

    @property
    def data(self) -> List[T]:  # type: ignore
        """Define property `data` required for `collections.UserList` class."""
        return list(self.__root__)

    def __iter__(self) -> Generator[Any, None, None]:
        """Use list property as iterable."""
        return (el for el in self.data)

    def __sub__(
        self: Cells[Cell], other: Cells[Cell]
    ) -> Cells[Tuple[List[Cell], List[Cell]]]:
        """
        Return the difference using `difflib.SequenceMatcher`.

        Each element of the result is a `(first_cells, last_cells)` pair with the
        slices of `self` and `other` covered by one opcode, in order.

        :param other: Cells to compare against
        :return: Cells of `(first_cells, last_cells)` pairs
        :raises TypeError: If `other` is not of the same type as `self`
        :raises RuntimeError: If the opcodes are unexpectedly split into several groups
        """
        if type(self) != type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # With a context as large as the longest input, no run of equal cells is long
        # enough to split or trim the result. `difflib.SequenceMatcher.
        # get_grouped_opcodes` therefore returns a single group holding the same
        # opcodes as `difflib.SequenceMatcher.get_opcodes`.
        n_context = max(len(self), len(other))
        diff_opcodes = list(
            SequenceMatcher(
                isjunk=None, a=self, b=other, autojunk=False
            ).get_grouped_opcodes(n_context)
        )

        if len(diff_opcodes) > 1:
            raise RuntimeError(
                "Expected one group for opcodes when context size is"
                f" {n_context} for {len(self)} and {len(other)} cells in"
                " notebooks."
            )
        return Cells[Tuple[List[Cell], List[Cell]]](
            [
                # https://github.com/python/mypy/issues/9459
                tuple((self.data[i1:j1], other.data[i2:j2]))  # type: ignore
                for _, i1, j1, i2, j2 in chain.from_iterable(diff_opcodes)
            ]
        )

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Get validators for custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """Ensure object is custom defined container."""
        if not isinstance(v, cls):
            return cls(v)
        else:
            return v

    @classmethod
    def wrap_git(
        cls,
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> List[Cell]:
        """
        Wrap git-diff cells in existing notebook.

        The result is the first cells and last cells surrounded by markdown cells
        holding git-style conflict markers (`<<<<<<<`, `=======`, `>>>>>>>`).

        :param first_cells: Cells of the first notebook in conflict
        :param last_cells: Cells of the last notebook in conflict
        :param hash_first: Git hash of first file, shown in the opening marker
        :param hash_last: Git hash of last file, shown in the closing marker
        :return: List of cells with the conflict markers
        """
        return (
            [
                Cell(
                    metadata=CellMetadata(git_hash=hash_first),
                    source=[f"`<<<<<<< {hash_first}`"],
                    cell_type="markdown",
                )
            ]
            + first_cells
            + [
                Cell(
                    source=["`=======`"],
                    cell_type="markdown",
                    metadata=CellMetadata(),
                )
            ]
            + last_cells
            + [
                Cell(
                    metadata=CellMetadata(git_hash=hash_last),
                    source=[f"`>>>>>>> {hash_last}`"],
                    cell_type="markdown",
                )
            ]
        )

    def resolve(
        self: Cells[Tuple[List[Cell], List[Cell]]],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook or not.
         If `None`, then keep both wrapping the git-diff tags
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
         `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            # Each pair is `(first_cells, last_cells)`: keep one side of every pair
            return list(
                chain.from_iterable(
                    pair[0] if keep_first_cells else pair[1] for pair in self.data
                )
            )
        # Keep both sides: identical pairs are kept once, differing ones are wrapped
        # in git-style conflict markers
        return list(
            chain.from_iterable(
                Cells.wrap_git(
                    first_cells=val[0],
                    last_cells=val[1],
                    hash_first=first_id,
                    hash_last=last_id,
                )
                if val[0] != val[1]
                else val[0]
                for val in self.data
            )
        )
class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """
        Parse notebook from a path.

        Notebooks are always parsed as JSON: `content_type` may be omitted or set to
        `"json"`, any other value is rejected.

        :param path: Path to the notebook file
        :param parse_kwargs: Keyword arguments forwarded to the parent `parse_file`
        :return: Parsed notebook
        :raises ValueError: If a `content_type` other than `json` is passed
        """
        content_arg = parse_kwargs.pop("content_type", None)
        # An explicit `content_type="json"` is the default and therefore accepted
        if content_arg is not None and content_arg != "json":
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )
        return super(JupyterNotebook, cls).parse_file(
            path=path, content_type="json", **parse_kwargs
        )

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: keyword arguments to be passed to each cell's
         `databooks.data_models.notebook.Cell.clear_fields`; cells are left untouched
         when none are given
        :return:
        :raises ValueError: If not exactly one of `notebook_metadata_keep` or
         `notebook_metadata_remove` is passed
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )

        # Translate a "keep" list into the equivalent "remove" list
        if notebook_metadata_keep is not None:
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if cell_kwargs:
            # Clear a copy of the cells, then replace `self.cells` with it as a whole
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_fields(**cell_kwargs)
            self.cells = _clean_cells