Coverage for databooks/data_models/notebook.py: 92%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Data models - Jupyter Notebooks and components."""
2from __future__ import annotations
4from copy import deepcopy
5from difflib import SequenceMatcher
6from itertools import chain
7from pathlib import Path
8from typing import (
9 Any,
10 Callable,
11 Dict,
12 Generator,
13 Iterable,
14 List,
15 Optional,
16 Sequence,
17 Tuple,
18 TypeVar,
19 Union,
20)
22from pydantic import Extra, PositiveInt, root_validator, validator
23from pydantic.generics import GenericModel
25from databooks.data_models.base import BaseCells, DatabooksBase
26from databooks.logging import get_logger
# Module-level logger, obtained from the project's `databooks.logging.get_logger`
# helper and keyed on this file's path.
logger = get_logger(__file__)
class NotebookMetadata(DatabooksBase):
    """
    Notebook metadata. Empty by default but can accept extra fields.

    No fields are declared here; everything stored on an instance is therefore an
    "extra" field. NOTE(review): acceptance of extra fields is presumably configured
    on `DatabooksBase` - confirm there.
    """
class CellMetadata(DatabooksBase):
    """
    Cell metadata. Empty by default but can accept extra fields.

    No fields are declared here; everything stored on an instance is therefore an
    "extra" field. NOTE(review): acceptance of extra fields is presumably configured
    on `DatabooksBase` - confirm there.
    """
class Cell(DatabooksBase):
    """
    Jupyter notebook cells.

    Fields `outputs` and `execution_count` are not included since they should only be
    present in code cells - thus are treated as extra fields.
    """

    metadata: CellMetadata
    source: Union[List[str], str]
    cell_type: str

    def __hash__(self) -> int:
        """
        Return a content-based hash so cells can be used by `difflib.SequenceMatcher`.

        Only the required `cell_type` and `source` fields take part in the hash: they
        are always present and hashable, whereas `metadata` and extra fields (such as
        `outputs`) may hold unhashable values. A list `source` is converted to a
        tuple so that it can be hashed.

        NOTE(review): assumes equality inherited from the base class is
        content-based, so that equal cells share `cell_type` and `source` - confirm.
        """
        source = tuple(self.source) if isinstance(self.source, list) else self.source
        return hash((self.cell_type, source))

    def remove_fields(
        self, fields: Iterable[str] = (), missing_ok: bool = True, **kwargs: Any
    ) -> None:
        """
        Remove Cell fields.

        Similar to `databooks.data_models.base.remove_fields`, but will ignore required
        fields for `databooks.data_models.notebook.Cell`.

        :param fields: Names of the fields to remove
        :param missing_ok: Forwarded to the parent implementation
        :param kwargs: (Unused) extra keyword arguments
        :return:
        """
        # `fields` is walked more than once below - materialise it so that one-shot
        # iterators (e.g. generators) are not silently exhausted by the first pass
        requested = list(fields)

        # Ignore required `Cell` fields (the ones specified in the class definition)
        cell_fields = self.__fields__
        required = [f for f in requested if f in cell_fields]
        if required:
            logger.debug(
                "Ignoring removal of required fields "
                + str(required)
                + f" in `{type(self).__name__}`."
            )
        removable = [f for f in requested if f not in cell_fields]

        super(Cell, self).remove_fields(removable, missing_ok=missing_ok)

        # Code cells always carry `outputs` and `execution_count`: if they were just
        # removed, put them back with their "empty" values
        if self.cell_type == "code":
            self.outputs: List[Dict[str, Any]] = (
                [] if "outputs" not in dict(self) else self.outputs
            )
            self.execution_count: Optional[PositiveInt] = (
                None if "execution_count" not in dict(self) else self.execution_count
            )

    def clear_fields(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_remove_fields: Sequence[str] = (),
    ) -> None:
        """
        Clear cell metadata, execution count, outputs or other desired fields (id, ...).

        You can also specify metadata to keep or remove from the `metadata` property of
        `databooks.data_models.notebook.Cell`.
        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_remove_fields: Fields to remove from cell
        :raises ValueError: If not exactly one of `cell_metadata_keep` and
         `cell_metadata_remove` is passed
        :return:
        """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )

        if cell_metadata_keep is not None:
            # Translate the "keep" list into the complementary "remove" list
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(fields=cell_remove_fields, missing_ok=True)

    @validator("cell_type")
    def cell_has_valid_type(cls, v: str) -> str:
        """Check if cell has one of the three predefined types."""
        valid_cell_types = ("raw", "markdown", "code")
        if v not in valid_cell_types:
            raise ValueError(f"Invalid cell type. Must be one of {valid_cell_types}")
        return v

    @root_validator
    def must_not_be_list_for_code_cells(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Check that code cells have list-type outputs.

        Keys are read with `.get` so that a missing `cell_type` or `outputs` key does
        not raise a `KeyError` here; a missing `outputs` on a code cell is reported
        as a `ValueError` instead.
        """
        cell_type = values.get("cell_type")
        if cell_type is None:
            # Nothing to check without a cell type
            return values
        if cell_type == "code" and not isinstance(values.get("outputs"), list):
            raise ValueError(
                "All code cells must have a list output property, got"
                f" {type(values.get('outputs'))}"
            )
        return values

    @root_validator
    def only_code_cells_have_outputs_and_execution_count(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Check that only code cells have outputs and execution count.

        A missing `cell_type` key is skipped rather than raising a `KeyError`.
        """
        cell_type = values.get("cell_type")
        if cell_type is None:
            # Nothing to check without a cell type
            return values
        if cell_type != "code" and (
            ("outputs" in values) or ("execution_count" in values)
        ):
            raise ValueError(
                "Found `outputs` or `execution_count` for cell of type"
                f" `{cell_type}`"
            )
        return values
# Element type of `Cells`: either a single `Cell`, or a pair of cell lists - the
# shape of the elements built by `Cells.__sub__`.
T = TypeVar("T", Cell, Tuple[List[Cell], List[Cell]])
class Cells(GenericModel, BaseCells[T]):
    """List-like container of cells; `-` diffs two containers via `difflib`."""

    __root__: Sequence[T] = []

    def __init__(self, elements: Sequence[T] = ()) -> None:
        """Accept the wrapped sequence positionally rather than as `__root__`."""
        super(Cells, self).__init__(__root__=elements)

    @property
    def data(self) -> List[T]:  # type: ignore
        """Return the root sequence as a list (`collections.UserList` needs `data`)."""
        return list(self.__root__)

    def __iter__(self) -> Generator[Any, None, None]:
        """Iterate over the elements exposed by `data`."""
        return (item for item in self.data)

    def __sub__(
        self: Cells[Cell], other: Cells[Cell]
    ) -> Cells[Tuple[List[Cell], List[Cell]]]:
        """
        Diff two cell containers with `difflib.SequenceMatcher`.

        Each element of the result is a pair `(cells of self, cells of other)`
        covering one opcode range reported by the matcher.

        :raises TypeError: If both operands are not of the exact same type
        :raises RuntimeError: If the matcher yields more than one opcode group
        """
        if type(self) != type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # With the context as large as the longest operand,
        # `difflib.SequenceMatcher.get_grouped_opcodes` yields at most one group,
        # i.e. the same opcodes as `difflib.SequenceMatcher.get_opcodes`
        n_context = max(len(self), len(other))
        matcher = SequenceMatcher(isjunk=None, a=self, b=other, autojunk=False)
        opcode_groups = list(matcher.get_grouped_opcodes(n_context))

        if len(opcode_groups) > 1:
            raise RuntimeError(
                "Expected one group for opcodes when context size is "
                f" {n_context} for {len(self)} and {len(other)} cells in"
                " notebooks."
            )

        first_cells, last_cells = self.data, other.data
        pairs = [
            # https://github.com/python/mypy/issues/9459
            tuple((first_cells[i1:j1], last_cells[i2:j2]))  # type: ignore
            for _, i1, j1, i2, j2 in chain.from_iterable(opcode_groups)
        ]
        return Cells[Tuple[List[Cell], List[Cell]]](pairs)

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Yield the validators for this custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """Return `v` as-is when it is already a `cls` instance, else wrap it."""
        return v if isinstance(v, cls) else cls(v)

    @staticmethod
    def wrap_git(
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> List[Cell]:
        """
        Surround two lists of cells with git-conflict style markdown marker cells.

        :param first_cells: Cells placed between the opening marker and separator
        :param last_cells: Cells placed between the separator and closing marker
        :param hash_first: Value shown in (and stored on) the opening marker
        :param hash_last: Value shown in (and stored on) the closing marker
        :return: Opening marker, first cells, separator, last cells, closing marker
        """
        opening = Cell(
            metadata=CellMetadata(git_hash=hash_first),
            source=[f"`<<<<<<< {hash_first}`"],
            cell_type="markdown",
        )
        separator = Cell(
            source=["`=======`"],
            cell_type="markdown",
            metadata=CellMetadata(),
        )
        closing = Cell(
            metadata=CellMetadata(git_hash=hash_last),
            source=[f"`>>>>>>> {hash_last}`"],
            cell_type="markdown",
        )
        return [opening] + first_cells + [separator] + last_cells + [closing]

    def resolve(
        self: Cells[Tuple[List[Cell], List[Cell]]],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Flatten the pairs of a `databooks.data_models.notebook.Cells` diff into cells.

        :param keep_first_cells: Whether to keep the cells of the first notebook or not.
         If `None`, then keep both wrapping the git-diff tags
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
         `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is None:
            resolved: List[Cell] = []
            for pair in self.data:
                first, last = pair[0], pair[1]
                if first != last:
                    # Both sides differ: keep them, delimited by the marker cells
                    resolved.extend(
                        Cells.wrap_git(
                            first_cells=first,
                            last_cells=last,
                            hash_first=first_id,
                            hash_last=last_id,
                        )
                    )
                else:
                    resolved.extend(first)
            return resolved

        # Index 0 holds the first notebook's cells, index 1 the last notebook's
        side = int(not keep_first_cells)
        return [cell for pair in self.data for cell in pair[side]]
class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """
        Parse notebook from a path.

        :param path: Location of the notebook file
        :param parse_kwargs: Keyword arguments forwarded to the parent `parse_file`.
         `content_type`, if given, must be `"json"` since it is always set to that
         value here.
        :raises ValueError: If `content_type` is passed with a value other than
         `"json"`
        :return: Parsed notebook
        """
        content_arg = parse_kwargs.pop("content_type", None)
        # Explicitly passing the only supported value is accepted, not an error
        if content_arg not in (None, "json"):
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )
        return super(JupyterNotebook, cls).parse_file(
            path=path, content_type="json", **parse_kwargs
        )

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: keyword arguments to be passed to each cell's
         `databooks.data_models.notebook.Cell.clear_fields`
        :raises ValueError: If not exactly one of `notebook_metadata_keep` and
         `notebook_metadata_remove` is passed
        :return:
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            # Translate the "keep" list into the complementary "remove" list
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if cell_kwargs:
            # Clean a copy and assign it back in one step, so `self.cells` is only
            # replaced once every cell has been cleared successfully
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_fields(**cell_kwargs)
            self.cells = _clean_cells