Coverage for databooks/data_models/notebook.py: 89%

121 statements  

1"""Data models - Jupyter Notebooks and components.""" 

2from __future__ import annotations 

3 

4import json 

5from copy import deepcopy 

6from difflib import SequenceMatcher 

7from itertools import chain 

8from pathlib import Path 

9from typing import ( 

10 Any, 

11 Callable, 

12 Dict, 

13 Generator, 

14 Iterable, 

15 List, 

16 Optional, 

17 Sequence, 

18 Tuple, 

19 TypeVar, 

20 Union, 

21) 

22 

23from pydantic import Extra, PositiveInt, root_validator, validate_model, validator 

24from pydantic.generics import GenericModel 

25 

26from databooks.data_models.base import BaseCells, DatabooksBase 

27from databooks.logging import get_logger 

28 

29logger = get_logger(__file__) 

30 

31 

32class NotebookMetadata(DatabooksBase): 

33 """Notebook metadata. Empty by default but can accept extra fields.""" 

34 

35 

36class CellMetadata(DatabooksBase): 

37 """Cell metadata. Empty by default but can accept extra fields.""" 

38 

39 

40class Cell(DatabooksBase): 

41 """ 

42 Jupyter notebook cells. 

43 

44 Fields `outputs` and `execution_count` are not included since they should only be 

45 present in code cells - thus are treated as extra fields. 

46 """ 

47 

48 metadata: CellMetadata 

49 source: Union[List[str], str] 

50 cell_type: str 

51 

    def __hash__(self) -> int:
        """Cells must be hashable for `difflib.SequenceMatcher`."""
        return hash(
            (type(self),)
            + tuple(
                tuple(v) if isinstance(v, list) else v for v in self.__dict__.values()
            )
        )
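        # Lists (e.g. `source`) are cast to tuples above because lists are not
        # hashable; the hash is therefore content-based, which lets
        # `difflib.SequenceMatcher` treat equal cells as matching elements.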

    def remove_fields(
        self, fields: Iterable[str] = (), missing_ok: bool = True, **kwargs: Any
    ) -> None:
        """
        Remove Cell fields.

        Similar to `databooks.data_models.base.remove_fields`, but will ignore required
        fields for `databooks.data_models.notebook.Cell`.
        """
        # Ignore required `Cell` fields
        cell_fields = self.__fields__  # required fields specified in class definition
        if any(field in fields for field in cell_fields):
            logger.debug(
                "Ignoring removal of required fields "
                + str([f for f in fields if f in cell_fields])
                + f" in `{type(self).__name__}`."
            )
            fields = [f for f in fields if f not in cell_fields]

        super(Cell, self).remove_fields(fields, missing_ok=missing_ok)

        if self.cell_type == "code":
            self.outputs: List[Dict[str, Any]] = (
                [] if "outputs" not in dict(self) else self.outputs
            )
            self.execution_count: Optional[PositiveInt] = (
                None if "execution_count" not in dict(self) else self.execution_count
            )
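        # Note: the defaults above mean that, e.g., `cell.remove_fields(["outputs"])`
        # on a code cell resets `outputs` to `[]` instead of dropping it, keeping
        # the cell valid for the root validators below.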

    def clear_fields(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_remove_fields: Sequence[str] = (),
    ) -> None:
        """
        Clear cell metadata, execution count, outputs or other desired fields (id, ...).

        You can also specify metadata to keep or remove from the `metadata` property of
        `databooks.data_models.notebook.Cell`.

        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
            sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_remove_fields: Fields to remove from cell
        :return:
        """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )

        if cell_metadata_keep is not None:
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(fields=cell_remove_fields, missing_ok=True)

    @validator("cell_type")
    def cell_has_valid_type(cls, v: str) -> str:
        """Check if cell has one of the three predefined types."""
        valid_cell_types = ("raw", "markdown", "code")
        if v not in valid_cell_types:
            raise ValueError(f"Invalid cell type. Must be one of {valid_cell_types}")
        return v

    @root_validator
    def code_cell_has_valid_outputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Check that code cells have list-type outputs."""
        if values.get("cell_type") == "code":
            if "outputs" not in values:
                raise ValueError(
                    f"All code cells must have an `outputs` property, got {values}"
                )
            if not isinstance(values["outputs"], list):
                raise ValueError(
                    f"Cell outputs must be a list, got {type(values['outputs'])}"
                )
        return values

    @root_validator
    def only_code_cells_have_outputs_and_execution_count(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Check that only code cells have outputs and execution count."""
        if values.get("cell_type") != "code" and (
            ("outputs" in values) or ("execution_count" in values)
        ):
            raise ValueError(
                "Found `outputs` or `execution_count` for cell of type"
                f" `{values['cell_type']}`"
            )
        return values


T = TypeVar("T", Cell, Tuple[List[Cell], List[Cell]])
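# `T` is either a plain `Cell` (the cells of a parsed notebook) or a pair of
# cell lists (a diff chunk, as produced by `Cells.__sub__` below).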


class Cells(GenericModel, BaseCells[T]):
    """Similar to `list`, with `-` operator using `difflib.SequenceMatcher`."""

    __root__: Sequence[T] = []

    def __init__(self, elements: Sequence[T] = ()) -> None:
        """Allow passing data as a positional argument when instantiating class."""
        super(Cells, self).__init__(__root__=elements)

    @property
    def data(self) -> List[T]:  # type: ignore
        """Define property `data` required for `collections.UserList` class."""
        return list(self.__root__)

    def __iter__(self) -> Generator[Any, None, None]:
        """Use list property as iterable."""
        return (el for el in self.data)

    def __sub__(
        self: Cells[Cell], other: Cells[Cell]
    ) -> Cells[Tuple[List[Cell], List[Cell]]]:
        """Return the difference using `difflib.SequenceMatcher`."""
        if type(self) != type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # By setting the context to the max number of cells and using
        # `difflib.SequenceMatcher.get_grouped_opcodes` we essentially get the same
        # result as `difflib.SequenceMatcher.get_opcodes` but in smaller chunks
        n_context = max(len(self), len(other))
        diff_opcodes = list(
            SequenceMatcher(
                isjunk=None, a=self, b=other, autojunk=False
            ).get_grouped_opcodes(n_context)
        )

        if len(diff_opcodes) > 1:
            raise RuntimeError(
                "Expected one group of opcodes when context size is"
                f" {n_context} for {len(self)} and {len(other)} cells in"
                " notebooks."
            )
        return Cells[Tuple[List[Cell], List[Cell]]](
            [
                # https://github.com/python/mypy/issues/9459
                tuple((self.data[i1:j1], other.data[i2:j2]))  # type: ignore
                for _, i1, j1, i2, j2 in chain.from_iterable(diff_opcodes)
            ]
        )
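    # Illustrative note: for notebooks `a` and `b`, `a.cells - b.cells` yields
    # pairs of cell chunks - matching chunks appear as identical pairs, while
    # differing chunks pair the cells unique to `a` with those unique to `b`
    # (either side may be an empty list).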

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Get validators for custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """Ensure object is custom defined container."""
        if not isinstance(v, cls):
            return cls(v)
        else:
            return v

    @staticmethod
    def wrap_git(
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> List[Cell]:
        """Wrap git-diff cells in existing notebook."""
        return (
            [
                Cell(
                    metadata=CellMetadata(git_hash=hash_first),
                    source=[f"`<<<<<<< {hash_first}`"],
                    cell_type="markdown",
                )
            ]
            + first_cells
            + [
                Cell(
                    source=["`=======`"],
                    cell_type="markdown",
                    metadata=CellMetadata(),
                )
            ]
            + last_cells
            + [
                Cell(
                    metadata=CellMetadata(git_hash=hash_last),
                    source=[f"`>>>>>>> {hash_last}`"],
                    cell_type="markdown",
                )
            ]
        )
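    # The wrapped cells mirror git conflict markers as markdown cells:
    # `<<<<<<< <hash_first>`, the first notebook's cells, `=======`, the last
    # notebook's cells, then `>>>>>>> <hash_last>`.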

    def resolve(
        self: Cells[Tuple[List[Cell], List[Cell]]],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook or not.
            If `None`, keep both and wrap them in git-diff markers
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
            `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            return list(
                chain.from_iterable(pairs[not keep_first_cells] for pairs in self.data)
            )
        return list(
            chain.from_iterable(
                Cells.wrap_git(
                    first_cells=val[0],
                    last_cells=val[1],
                    hash_first=first_id,
                    hash_last=last_id,
                )
                if val[0] != val[1]
                else val[0]
                for val in self.data
            )
        )
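    # For example, `(a.cells - b.cells).resolve(keep_first_cells=True)` keeps
    # only `a`'s cells, while the default `keep_first_cells=None` keeps both
    # versions of each conflicting chunk, wrapped via `wrap_git`.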


class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """Parse notebook from a path."""
        content_arg = parse_kwargs.pop("content_type", None)
        if content_arg is not None:
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )
        return super(JupyterNotebook, cls).parse_file(
            path=path, content_type="json", **parse_kwargs
        )

    def write(
        self, path: Path | str, overwrite: bool = False, **json_kwargs: Any
    ) -> None:
        """Write notebook to disk."""
        path = Path(path) if not isinstance(path, Path) else path
        json_kwargs = {"indent": 2, **json_kwargs}
        if path.is_file() and not overwrite:
            raise ValueError(f"File exists at {path}. Specify `overwrite=True`.")

        _, _, validation_error = validate_model(self.__class__, self.dict())
        if validation_error:
            raise validation_error
        with path.open("w") as f:
            json.dump(self.dict(), fp=f, **json_kwargs)

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
            sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: Keyword arguments to be passed to each cell's
            `databooks.data_models.notebook.Cell.clear_fields`
        :return:
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if len(cell_kwargs) > 0:
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_fields(**cell_kwargs)
            self.cells = _clean_cells
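

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The file paths
# below are hypothetical; any `.ipynb` file parsed by `JupyterNotebook`
# behaves the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    nb = JupyterNotebook.parse_file("example.ipynb")  # hypothetical input path
    # Drop all extra notebook/cell metadata and reset code-cell outputs.
    nb.clear_metadata(
        notebook_metadata_keep=(),
        cell_metadata_keep=(),
        cell_remove_fields=["outputs", "execution_count"],
    )
    nb.write("example_clean.ipynb", overwrite=True)  # hypothetical output path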