Coverage for databooks/data_models/notebook.py: 92%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

108 statements  

1"""Data models - Jupyter Notebooks and components.""" 

2from __future__ import annotations 

3 

4from copy import deepcopy 

5from difflib import SequenceMatcher 

6from itertools import chain 

7from pathlib import Path 

8from typing import ( 

9 Any, 

10 Callable, 

11 Dict, 

12 Generator, 

13 Iterable, 

14 List, 

15 Optional, 

16 Sequence, 

17 Tuple, 

18 TypeVar, 

19 Union, 

20) 

21 

22from pydantic import Extra, PositiveInt, root_validator, validator 

23from pydantic.generics import GenericModel 

24 

25from databooks.data_models.base import BaseCells, DatabooksBase 

26from databooks.logging import get_logger 

27 

# Module-level logger, created by the project's `get_logger` helper from this
# file's path.
logger = get_logger(__file__)

29 

30 

class NotebookMetadata(DatabooksBase):
    """Notebook-level metadata. Declares no fields itself; extras may be added."""

33 

34 

class CellMetadata(DatabooksBase):
    """Cell-level metadata. Declares no fields itself; extras (e.g. `git_hash`) may be added."""

37 

38 

class Cell(DatabooksBase):
    """
    Jupyter notebook cells.

    Fields `outputs` and `execution_count` are not included since they should only be
    present in code cells - thus are treated as extra fields.
    """

    metadata: CellMetadata
    source: Union[List[str], str]
    cell_type: str

    def __hash__(self) -> int:
        """
        Hash the cell by its contents, as required by `difflib.SequenceMatcher`.

        Values that are not hashable by themselves (lists, dictionaries and nested
         models) are recursively converted to hashable equivalents, so that cells
         with equal contents always produce the same hash.
        :return: Hash computed from the cell type and all of its field values
        """

        def _freeze(value: Any) -> Any:
            """Recursively convert `value` into a hashable equivalent."""
            if isinstance(value, DatabooksBase):
                # Nested models are hashed like the dictionary of their values
                return _freeze(value.__dict__)
            if isinstance(value, dict):
                # `frozenset` makes the result independent of key ordering
                return frozenset((key, _freeze(val)) for key, val in value.items())
            if isinstance(value, (list, tuple)):
                return tuple(_freeze(val) for val in value)
            return value

        return hash((type(self), _freeze(self.__dict__)))

    def remove_fields(
        self, fields: Iterable[str] = (), missing_ok: bool = True, **kwargs: Any
    ) -> None:
        """
        Remove Cell fields.

        Similar to `databooks.data_models.base.remove_fields`, but will ignore required
         fields for `databooks.data_models.notebook.Cell`. For code cells, `outputs`
         and `execution_count` are reset to `[]` and `None` if they were removed.
        :param fields: Names of the fields to remove
        :param missing_ok: Passed on to the parent class' `remove_fields`
        :param kwargs: (Unused) additional keyword arguments
        :return:
        """
        # `fields` is iterated more than once - materialize it so that one-shot
        # iterators (i.e.: generators) are not exhausted by the first pass
        fields = list(fields)

        # Ignore required `Cell` fields
        cell_fields = self.__fields__  # required fields specified in class definition
        ignored = [f for f in fields if f in cell_fields]
        if ignored:
            logger.debug(
                "Ignoring removal of required fields "
                + str(ignored)
                + f" in `{type(self).__name__}`."
            )
            fields = [f for f in fields if f not in cell_fields]

        super(Cell, self).remove_fields(fields, missing_ok=missing_ok)

        if self.cell_type == "code":
            # Code cells must always have these properties - restore the empty
            # defaults in case they were just removed
            remaining = dict(self)
            self.outputs: List[Dict[str, Any]] = remaining.get("outputs", [])
            self.execution_count: Optional[PositiveInt] = remaining.get(
                "execution_count"
            )

    def clear_fields(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_remove_fields: Sequence[str] = (),
    ) -> None:
        """
        Clear cell metadata, execution count, outputs or other desired fields (id, ...).

        You can also specify metadata to keep or remove from the `metadata` property of
        `databooks.data_models.notebook.Cell`.
        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_remove_fields: Fields to remove from cell
        :raises ValueError: If not exactly one of `cell_metadata_keep` or
         `cell_metadata_remove` is passed
        :return:
        """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )

        if cell_metadata_keep is not None:
            # Translate the "keep" list into the complementary "remove" list
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(fields=cell_remove_fields, missing_ok=True)

    @validator("cell_type")
    def cell_has_valid_type(cls, v: str) -> str:
        """Check if cell has one of the three predefined types."""
        valid_cell_types = ("raw", "markdown", "code")
        if v not in valid_cell_types:
            raise ValueError(f"Invalid cell type. Must be one of {valid_cell_types}")
        return v

    @root_validator
    def must_not_be_list_for_code_cells(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Check that code cells have list-type outputs."""
        # Use `.get`: `cell_type` is missing from `values` when its own validation
        # failed, and `outputs` may be absent altogether. A `KeyError` would not be
        # converted into a validation error by pydantic.
        if values.get("cell_type") == "code" and not isinstance(
            values.get("outputs"), list
        ):
            raise ValueError(
                "All code cells must have a list output property, got"
                f" {type(values.get('outputs'))}"
            )
        return values

    @root_validator
    def only_code_cells_have_outputs_and_execution_count(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Check that only code cells have outputs and execution count."""
        cell_type = values.get("cell_type")
        if cell_type is None:
            # `cell_type` failed its own validation - that error is already being
            # reported, so there is nothing meaningful to check here
            return values
        if cell_type != "code" and (
            ("outputs" in values) or ("execution_count" in values)
        ):
            raise ValueError(
                "Found `outputs` or `execution_count` for cell of type"
                f" `{cell_type}`"
            )
        return values

151 

152 

# Element type held by `Cells`: either a single `Cell`, or a pair of cell lists
# (the result of diffing two `Cells` instances with `-`).
T = TypeVar(
    "T",
    Cell,
    Tuple[List[Cell], List[Cell]],
)

154 

155 

class Cells(GenericModel, BaseCells[T]):
    """List-like container of cells; `-` diffs two containers with `difflib`."""

    __root__: Sequence[T] = []

    def __init__(self, elements: Sequence[T] = ()) -> None:
        """Take the contained elements positionally instead of via `__root__`."""
        super(Cells, self).__init__(__root__=elements)

    @property
    def data(self) -> List[T]:  # type: ignore
        """Expose the elements as the `data` list that `collections.UserList` needs."""
        return [*self.__root__]

    def __iter__(self) -> Generator[Any, None, None]:
        """Iterate over the contained elements."""
        return (item for item in self.data)

    def __sub__(
        self: Cells[Cell], other: Cells[Cell]
    ) -> Cells[Tuple[List[Cell], List[Cell]]]:
        """
        Diff two cell containers using `difflib.SequenceMatcher`.

        :param other: Cells to compare against
        :raises TypeError: If both operands are not of the exact same type
        :raises RuntimeError: If the opcodes come back in more than one group
        :return: Pairs with the slice of `self` and the matching slice of `other`
        """
        if type(self) != type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # With a context as large as the longest operand,
        # `difflib.SequenceMatcher.get_grouped_opcodes` yields essentially the same
        # result as `difflib.SequenceMatcher.get_opcodes`, but in smaller chunks
        n_context = max(len(self), len(other))
        matcher = SequenceMatcher(isjunk=None, a=self, b=other, autojunk=False)
        opcode_groups = list(matcher.get_grouped_opcodes(n_context))

        if len(opcode_groups) > 1:
            raise RuntimeError(
                "Expected one group for opcodes when context size is "
                f" {n_context} for {len(self)} and {len(other)} cells in"
                " notebooks."
            )

        ours, theirs = self.data, other.data
        # https://github.com/python/mypy/issues/9459
        paired_slices = [
            (ours[i1:j1], theirs[i2:j2])  # type: ignore
            for _, i1, j1, i2, j2 in chain.from_iterable(opcode_groups)
        ]
        return Cells[Tuple[List[Cell], List[Cell]]](paired_slices)

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Yield the validators pydantic should run for this custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """Wrap the value in this container unless it already is one."""
        return v if isinstance(v, cls) else cls(v)

    @staticmethod
    def wrap_git(
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> List[Cell]:
        """
        Surround both versions of conflicting cells with git-conflict marker cells.

        :param first_cells: Cells of the first notebook
        :param last_cells: Cells of the last notebook
        :param hash_first: Git hash shown in (and stored on) the opening marker
        :param hash_last: Git hash shown in (and stored on) the closing marker
        :return: Opening marker, first cells, separator, last cells, closing marker
        """
        opening_marker = Cell(
            metadata=CellMetadata(git_hash=hash_first),
            source=[f"`<<<<<<< {hash_first}`"],
            cell_type="markdown",
        )
        separator = Cell(
            metadata=CellMetadata(),
            source=["`=======`"],
            cell_type="markdown",
        )
        closing_marker = Cell(
            metadata=CellMetadata(git_hash=hash_last),
            source=[f"`>>>>>>> {hash_last}`"],
            cell_type="markdown",
        )
        return (
            [opening_marker] + first_cells + [separator] + last_cells + [closing_marker]
        )

    def resolve(
        self: Cells[Tuple[List[Cell], List[Cell]]],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook or not.
         If `None`, then keep both wrapping the git-diff tags
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
         `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            # `False` indexes the first element of each pair, `True` the last
            side = not keep_first_cells
            return [cell for pair in self.data for cell in pair[side]]

        resolved: List[Cell] = []
        for val in self.data:
            first, last = val[0], val[1]
            if first != last:
                # Conflict - keep both versions, delimited by marker cells
                resolved.extend(
                    Cells.wrap_git(
                        first_cells=first,
                        last_cells=last,
                        hash_first=first_id,
                        hash_last=last_id,
                    )
                )
            else:
                resolved.extend(first)
        return resolved

291 

292 

class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """
        Parse notebook from a path.

        :param path: Path of the notebook file
        :param parse_kwargs: Keyword arguments passed to the parent class'
         `parse_file`. `content_type` may only be omitted or set to `json`
        :raises ValueError: If `content_type` is passed with a value other than `json`
        :return: Parsed notebook
        """
        content_arg = parse_kwargs.pop("content_type", None)
        # Notebooks are always JSON - explicitly passing `json` is redundant but valid
        if content_arg not in (None, "json"):
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )
        return super(JupyterNotebook, cls).parse_file(
            path=path, content_type="json", **parse_kwargs
        )

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: keyword arguments to be passed to each cell's
         `databooks.data_models.notebook.Cell.clear_fields`. Cells are left untouched
         if none are passed
        :raises ValueError: If not exactly one of `notebook_metadata_keep` or
         `notebook_metadata_remove` is passed
        :return:
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            # Translate the "keep" list into the complementary "remove" list
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if cell_kwargs:
            # Clear a copy and assign it back only once every cell was processed, so
            # `self.cells` is not left half-cleared if clearing a cell raises
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_fields(**cell_kwargs)
            self.cells = _clean_cells