Coverage for databooks/data_models/notebook.py: 92%

108 statements  

1"""Data models - Jupyter Notebooks and components.""" 

2from __future__ import annotations 

3 

4from copy import deepcopy 

5from difflib import SequenceMatcher 

6from itertools import chain 

7from pathlib import Path 

8from typing import ( 

9 Any, 

10 Callable, 

11 Dict, 

12 Generator, 

13 Iterable, 

14 List, 

15 Optional, 

16 Sequence, 

17 Tuple, 

18 TypeVar, 

19 Union, 

20) 

21 

22from pydantic import Extra, PositiveInt, root_validator, validator 

23from pydantic.generics import GenericModel 

24 

25from databooks.data_models.base import BaseCells, DatabooksBase 

26from databooks.logging import get_logger 

27 

28logger = get_logger(__file__) 

29 

30 

31class NotebookMetadata(DatabooksBase): 

32 """Notebook metadata. Empty by default but can accept extra fields.""" 

33 

34 

35class CellMetadata(DatabooksBase): 

36 """Cell metadata. Empty by default but can accept extra fields.""" 

37 

38 

39class Cell(DatabooksBase): 

40 """ 

41 Jupyter notebook cells. 

42 

43 Fields `outputs` and `execution_count` are not included since they should only be 

44 present in code cells - thus are treated as extra fields. 

45 """ 

46 

47 metadata: CellMetadata 

48 source: Union[List[str], str] 

49 cell_type: str 

    def __hash__(self) -> int:
        """Cells must be hashable for `difflib.SequenceMatcher`."""
        return hash(
            (type(self),)
            + tuple(
                tuple(v) if isinstance(v, list) else v for v in self.__dict__.values()
            )
        )

    def remove_fields(
        self, fields: Iterable[str] = (), missing_ok: bool = True, **kwargs: Any
    ) -> None:
        """
        Remove Cell fields.

        Similar to `databooks.data_models.base.remove_fields`, but will ignore required
        fields for `databooks.data_models.notebook.Cell`.
        """
        # Ignore required `Cell` fields
        cell_fields = self.__fields__  # required fields specified in class definition
        if any(field in fields for field in cell_fields):
            logger.debug(
                "Ignoring removal of required fields "
                + str([f for f in fields if f in cell_fields])
                + f" in `{type(self).__name__}`."
            )
            fields = [f for f in fields if f not in cell_fields]

        super(Cell, self).remove_fields(fields, missing_ok=missing_ok)

        if self.cell_type == "code":
            self.outputs: List[Dict[str, Any]] = (
                [] if "outputs" not in dict(self) else self.outputs
            )
            self.execution_count: Optional[PositiveInt] = (
                None if "execution_count" not in dict(self) else self.execution_count
            )

    def clear_fields(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_remove_fields: Sequence[str] = (),
    ) -> None:
        """
        Clear cell metadata, execution count, outputs or other desired fields (id, ...).

        You can also specify metadata to keep or remove from the `metadata` property of
        `databooks.data_models.notebook.Cell`.
        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_remove_fields: Fields to remove from cell
        :return:
        """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )

        if cell_metadata_keep is not None:
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(fields=cell_remove_fields, missing_ok=True)

    @validator("cell_type")
    def cell_has_valid_type(cls, v: str) -> str:
        """Check if cell has one of the three predefined types."""
        valid_cell_types = ("raw", "markdown", "code")
        if v not in valid_cell_types:
            raise ValueError(f"Invalid cell type. Must be one of {valid_cell_types}")
        return v

    @root_validator
    def must_not_be_list_for_code_cells(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Check that code cells have list-type outputs."""
        if values["cell_type"] == "code" and not isinstance(values["outputs"], list):
            raise ValueError(
                "All code cells must have a list output property, got"
                f" {type(values.get('outputs'))}"
            )
        return values

    @root_validator
    def only_code_cells_have_outputs_and_execution_count(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Check that only code cells have outputs and execution count."""
        if values["cell_type"] != "code" and (
            ("outputs" in values) or ("execution_count" in values)
        ):
            raise ValueError(
                "Found `outputs` or `execution_count` for cell of type"
                f" `{values['cell_type']}`"
            )
        return values
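
# Usage sketch for the `Cell` model above: a hypothetical code cell is built and
# then cleaned. The metadata keys and the output payload are made-up illustration
# values, not anything required by the model.
example_cell = Cell(
    cell_type="code",
    metadata=CellMetadata(tags=["demo"], collapsed=False),
    source=["print('hello')\n"],
    outputs=[{"output_type": "stream", "name": "stdout", "text": ["hello\n"]}],
    execution_count=1,
)
example_cell.clear_fields(
    cell_metadata_keep=(),  # drop all extra metadata fields
    cell_remove_fields=["outputs", "execution_count"],
)
# For a code cell, `outputs` is reset to `[]` and `execution_count` to `None`
# rather than removed, since both fields are expected in code cells.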

T = TypeVar("T", Cell, Tuple[List[Cell], List[Cell]])


class Cells(GenericModel, BaseCells[T]):
    """Similar to `list`, with `-` operator using `difflib.SequenceMatcher`."""

    __root__: Sequence[T] = []

    def __init__(self, elements: Sequence[T] = ()) -> None:
        """Allow passing data as a positional argument when instantiating class."""
        super(Cells, self).__init__(__root__=elements)

    @property
    def data(self) -> List[T]:  # type: ignore
        """Define property `data` required for `collections.UserList` class."""
        return list(self.__root__)

    def __iter__(self) -> Generator[Any, None, None]:
        """Use list property as iterable."""
        return (el for el in self.data)

    def __sub__(
        self: Cells[Cell], other: Cells[Cell]
    ) -> Cells[Tuple[List[Cell], List[Cell]]]:
        """Return the difference using `difflib.SequenceMatcher`."""
        if type(self) != type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # By setting the context to the max number of cells and using
        # `difflib.SequenceMatcher.get_grouped_opcodes` we essentially get the same
        # result as `difflib.SequenceMatcher.get_opcodes` but in smaller chunks
        n_context = max(len(self), len(other))
        diff_opcodes = list(
            SequenceMatcher(
                isjunk=None, a=self, b=other, autojunk=False
            ).get_grouped_opcodes(n_context)
        )

        if len(diff_opcodes) > 1:
            raise RuntimeError(
                "Expected one group for opcodes when context size is"
                f" {n_context} for {len(self)} and {len(other)} cells in"
                " notebooks."
            )
        return Cells[Tuple[List[Cell], List[Cell]]](
            [
                # https://github.com/python/mypy/issues/9459
                tuple((self.data[i1:j1], other.data[i2:j2]))  # type: ignore
                for _, i1, j1, i2, j2 in chain.from_iterable(diff_opcodes)
            ]
        )

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Get validators for custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """Ensure object is custom defined container."""
        if not isinstance(v, cls):
            return cls(v)
        else:
            return v

    @classmethod
    def wrap_git(
        cls,
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> List[Cell]:
        """Wrap git-diff cells in existing notebook."""
        return (
            [
                Cell(
                    metadata=CellMetadata(git_hash=hash_first),
                    source=[f"`<<<<<<< {hash_first}`"],
                    cell_type="markdown",
                )
            ]
            + first_cells
            + [
                Cell(
                    source=["`=======`"],
                    cell_type="markdown",
                    metadata=CellMetadata(),
                )
            ]
            + last_cells
            + [
                Cell(
                    metadata=CellMetadata(git_hash=hash_last),
                    source=[f"`>>>>>>> {hash_last}`"],
                    cell_type="markdown",
                )
            ]
        )

    def resolve(
        self: Cells[Tuple[List[Cell], List[Cell]]],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook or not.
         If `None`, then keep both, wrapping them in git-diff tags
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
         `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            return list(
                chain.from_iterable(pairs[not keep_first_cells] for pairs in self.data)
            )
        return list(
            chain.from_iterable(
                Cells.wrap_git(
                    first_cells=val[0],
                    last_cells=val[1],
                    hash_first=first_id,
                    hash_last=last_id,
                )
                if val[0] != val[1]
                else val[0]
                for val in self.data
            )
        )
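
# Usage sketch for `Cells`: the `-` operator pairs up differing chunks via
# `difflib.SequenceMatcher`, and `resolve` flattens the pairs back into cells.
# The sources and ids below are made-up illustration values.
first = Cells(
    [Cell(cell_type="markdown", source=["# Title"], metadata=CellMetadata())]
)
last = Cells(
    [Cell(cell_type="markdown", source=["# New title"], metadata=CellMetadata())]
)
diff = first - last  # Cells of `(cells_from_first, cells_from_last)` pairs
merged = diff.resolve(first_id="HEAD", last_id="feature")  # keep both, git-diff wrapped
first_only = diff.resolve(keep_first_cells=True)  # or keep a single side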

class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """Parse notebook from a path."""
        content_arg = parse_kwargs.pop("content_type", None)
        if content_arg is not None:
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )
        return super(JupyterNotebook, cls).parse_file(
            path=path, content_type="json", **parse_kwargs
        )

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: keyword arguments to be passed to each cell's
         `databooks.data_models.notebook.Cell.clear_fields`
        :return:
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if len(cell_kwargs) > 0:
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_fields(**cell_kwargs)
            self.cells = _clean_cells
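
# Usage sketch for `JupyterNotebook`: parse a notebook, strip notebook and cell
# metadata, and reset outputs and execution counts. The path and the kept metadata
# keys are made-up illustration values; serializing with `json()` is just one way
# to persist the cleaned model.
nb = JupyterNotebook.parse_file("notebook.ipynb")
nb.clear_metadata(
    notebook_metadata_keep=("kernelspec",),  # keep only the kernel spec
    cell_metadata_keep=(),  # forwarded to `Cell.clear_fields`: drop extra cell metadata
    cell_remove_fields=["outputs", "execution_count"],
)
Path("notebook.ipynb").write_text(nb.json(indent=2), encoding="utf-8")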