Coverage for databooks/data_models/notebook.py: 92%
1"""Data models - Jupyter Notebooks and components."""
2from __future__ import annotations
4from copy import deepcopy
5from difflib import SequenceMatcher
6from itertools import chain
7from pathlib import Path
8from typing import (
9 Any,
10 Callable,
11 Dict,
12 Generator,
13 List,
14 Optional,
15 Sequence,
16 Tuple,
17 TypeVar,
18 Union,
19)
21from pydantic import Extra, root_validator, validator
22from pydantic.generics import GenericModel
24from databooks.data_models.base import BaseCells, DatabooksBase
27class NotebookMetadata(DatabooksBase):
28 """Notebook metadata. Empty by default but can accept extra fields."""
30 ...
33class CellMetadata(DatabooksBase):
34 """Cell metadata. Empty by default but can accept extra fields."""
36 ...


class Cell(DatabooksBase):
    """
    Jupyter notebook cells.

    Fields `outputs` and `execution_count` are not included since they should only be
    present in code cells, and are therefore treated as extra fields.
    """

    metadata: CellMetadata
    source: Union[List[str], str]
    cell_type: str
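
    # Note: because the base model accepts extra fields, code cells may carry
    # `outputs` and `execution_count` even though they are not declared above.
    # Hedged example (illustrative only, not from the source):
    #   Cell(
    #       cell_type="code",
    #       metadata=CellMetadata(),
    #       source=["1 + 1"],
    #       outputs=[],
    #       execution_count=None,
    #   )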

    def __hash__(self) -> int:
        """Cells must be hashable for `difflib.SequenceMatcher`."""
        return hash(
            (type(self),)
            + tuple(
                tuple(v) if isinstance(v, list) else v for v in self.__dict__.values()
            )
        )

    def clear_metadata(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_execution_count: bool = True,
        cell_outputs: bool = False,
        remove_fields: List[str] = ["id"],
    ) -> None:
67 """
68 Clear cell metadata, execution count and outputs.
70 :param cell_metadata_keep: Metadata values to keep - simply pass an empty
71 sequence (i.e.: `()`) to remove all extra fields.
72 :param cell_metadata_remove: Metadata values to remove
73 :param cell_execution_count: Whether or not to keep the execution count
74 :param cell_outputs: whether or not to keep the cell outputs
75 :return:
76 """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )
        if cell_metadata_keep is not None:
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(remove_fields, missing_ok=True)
        if self.cell_type == "code":
            if cell_outputs:
                self.outputs: List[Dict[str, Any]] = []
            if cell_execution_count:
                self.execution_count = None

    @validator("cell_type")
    def cell_has_valid_type(cls, v: str) -> str:
        """Check if cell has one of the three predefined types."""
        valid_cell_types = ("raw", "markdown", "code")
        if v not in valid_cell_types:
            raise ValueError(f"Invalid cell type. Must be one of {valid_cell_types}")
        return v

    @root_validator
    def must_not_be_list_for_code_cells(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Check that code cells have list-type outputs."""
        if values["cell_type"] == "code" and not isinstance(
            values.get("outputs"), list
        ):
            raise ValueError(
                "All code cells must have a list output property, got"
                f" {type(values.get('outputs'))}"
            )
        return values

    @root_validator
    def only_code_cells_have_outputs_and_execution_count(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Check that only code cells have outputs and execution count."""
        if values["cell_type"] != "code" and (
            ("outputs" in values) or ("execution_count" in values)
        ):
            raise ValueError(
                "Found `outputs` or `execution_count` for cell of type"
                f" `{values['cell_type']}`"
            )
        return values
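

# `T` constrains `Cells` to hold either plain cells (a notebook's contents) or
# pairs of cell lists - the result of diffing two notebooks with `-`.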
T = TypeVar("T", Cell, Tuple[List[Cell], List[Cell]])


class Cells(GenericModel, BaseCells[T]):
    """Similar to `list`, with `-` operator using `difflib.SequenceMatcher`."""

    __root__: Sequence[T] = []

    def __init__(self, elements: Sequence[T] = ()) -> None:
        """Allow passing data as a positional argument when instantiating class."""
        super(Cells, self).__init__(__root__=elements)

    @property
    def data(self) -> List[T]:  # type: ignore
        """Define property `data` required for `collections.UserList` class."""
        return list(self.__root__)

    def __iter__(self) -> Generator[Any, None, None]:
        """Use list property as iterable."""
        return (el for el in self.data)

    def __sub__(
        self: Cells[Cell], other: Cells[Cell]
    ) -> Cells[Tuple[List[Cell], List[Cell]]]:
        """Return the difference using `difflib.SequenceMatcher`."""
        if type(self) != type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        _self = deepcopy(self)
        _other = deepcopy(other)
        for cells in (_self, _other):
            for cell in cells:
                cell.remove_fields(["id"], missing_ok=True)

        # By setting the context to the max number of cells and using
        # `difflib.SequenceMatcher.get_grouped_opcodes` we essentially get the same
        # result as `difflib.SequenceMatcher.get_opcodes` but in smaller chunks
        n_context = max(len(_self), len(_other))
        diff_opcodes = list(
            SequenceMatcher(
                isjunk=None, a=_self, b=_other, autojunk=False
            ).get_grouped_opcodes(n_context)
        )

        if len(diff_opcodes) > 1:
            raise RuntimeError(
                "Expected one group for opcodes when context size is"
                f" {n_context} for {len(_self)} and {len(_other)} cells in"
                " notebooks."
            )
        return Cells[Tuple[List[Cell], List[Cell]]](
            [
                # https://github.com/python/mypy/issues/9459
                tuple((_self.data[i1:j1], _other.data[i2:j2]))  # type: ignore
                for _, i1, j1, i2, j2 in chain.from_iterable(diff_opcodes)
            ]
        )

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Get validators for custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """Ensure object is custom defined container."""
        if not isinstance(v, cls):
            return cls(v)
        else:
            return v

    @classmethod
    def wrap_git(
        cls,
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> List[Cell]:
        """Wrap git-diff cells in existing notebook."""
        return (
            [
                Cell(
                    metadata=CellMetadata(git_hash=hash_first),
                    source=[f"`<<<<<<< {hash_first}`"],
                    cell_type="markdown",
                )
            ]
            + first_cells
            + [
                Cell(
                    source=["`=======`"],
                    cell_type="markdown",
                    metadata=CellMetadata(),
                )
            ]
            + last_cells
            + [
                Cell(
                    metadata=CellMetadata(git_hash=hash_last),
                    source=[f"`>>>>>>> {hash_last}`"],
                    cell_type="markdown",
                )
            ]
        )

    def resolve(
        self: Cells[Tuple[List[Cell], List[Cell]]],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook or
            not. If `None`, keep both and wrap the differences in git-diff markers
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
            `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            return list(
                chain.from_iterable(pairs[not keep_first_cells] for pairs in self.data)
            )
        return list(
            chain.from_iterable(
                Cells.wrap_git(
                    first_cells=val[0],
                    last_cells=val[1],
                    hash_first=first_id,
                    hash_last=last_id,
                )
                if val[0] != val[1]
                else val[0]
                for val in self.data
            )
        )


class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """Parse notebook from a path."""
        content_arg = parse_kwargs.pop("content_type", None)
        if content_arg is not None:
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )
        return super(JupyterNotebook, cls).parse_file(
            path=path, content_type="json", **parse_kwargs
        )

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
            sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: keyword arguments to be passed to each cell's
            `databooks.data_models.Cell.clear_metadata`
        :return:
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if len(cell_kwargs) > 0:
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_metadata(**cell_kwargs)
            self.cells = _clean_cells
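

# Hedged end-to-end sketch (paths and hashes are illustrative, not from the source):
#   nb1 = JupyterNotebook.parse_file("a/notebook.ipynb")
#   nb2 = JupyterNotebook.parse_file("b/notebook.ipynb")
#   diff = nb1.cells - nb2.cells
#   merged_cells = diff.resolve(first_id="a1b2c3d", last_id="e4f5a6b")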