Coverage for databooks/data_models/notebook.py: 92%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

109 statements  

1"""Data models - Jupyter Notebooks and components.""" 

2from __future__ import annotations 

3 

4from copy import deepcopy 

5from difflib import SequenceMatcher 

6from itertools import chain 

7from pathlib import Path 

8from typing import ( 

9 Any, 

10 Callable, 

11 Dict, 

12 Generator, 

13 List, 

14 Optional, 

15 Sequence, 

16 Tuple, 

17 TypeVar, 

18 Union, 

19) 

20 

21from pydantic import Extra, root_validator, validator 

22from pydantic.generics import GenericModel 

23 

24from databooks.data_models.base import BaseCells, DatabooksBase 

25 

26 

class NotebookMetadata(DatabooksBase):
    """Notebook-level metadata: declares no fields, but extra fields are accepted."""

31 

32 

class CellMetadata(DatabooksBase):
    """Cell-level metadata: declares no fields, but extra fields are accepted."""

37 

38 

class Cell(DatabooksBase):
    """
    Jupyter notebook cells.

    Fields `outputs` and `execution_count` are not included since they should only be
    present in code cells - thus are treated as extra fields.
    """

    metadata: CellMetadata
    source: Union[List[str], str]
    cell_type: str

    def __hash__(self) -> int:
        """
        Return a content-based hash (needed by `difflib.SequenceMatcher`).

        The hash is derived from the cell's type and a recursively "frozen" view of
        its attribute values, so two cells holding the same content always hash
        alike, regardless of the order of dictionary keys.
        """

        def _freeze(value: Any) -> Any:
            """Convert `value` into an equivalent hashable structure."""
            if isinstance(value, dict):
                # `frozenset` keeps the result independent of key order
                return frozenset((key, _freeze(item)) for key, item in value.items())
            if isinstance(value, (list, tuple)):
                return tuple(_freeze(item) for item in value)
            if isinstance(value, (set, frozenset)):
                return frozenset(_freeze(item) for item in value)
            try:
                hash(value)
            except TypeError:
                # Unhashable objects (e.g. nested models): hash their attributes,
                # falling back to `repr` when there is no attribute dictionary
                attrs = getattr(value, "__dict__", None)
                return _freeze(attrs) if isinstance(attrs, dict) else repr(value)
            return value

        return hash((type(self), _freeze(self.__dict__)))

    def clear_metadata(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_execution_count: bool = True,
        cell_outputs: bool = False,
        remove_fields: Optional[List[str]] = None,
    ) -> None:
        """
        Clear cell metadata, execution count and outputs.

        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_execution_count: Whether or not to clear the execution count
         (code cells only)
        :param cell_outputs: Whether or not to clear the cell outputs (code cells
         only)
        :param remove_fields: Cell fields to remove, fields that are not present are
         ignored. Defaults to `["id"]`
        :raises ValueError: If not exactly one of `cell_metadata_keep` or
         `cell_metadata_remove` is passed
        :return:
        """
        if remove_fields is None:
            # Fresh list on every call instead of a shared mutable default
            remove_fields = ["id"]

        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )
        if cell_metadata_keep is not None:
            # Express "keep" as its complement: remove everything not listed
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(remove_fields, missing_ok=True)
        if self.cell_type == "code":
            if cell_outputs:
                self.outputs: List[Dict[str, Any]] = []
            if cell_execution_count:
                self.execution_count = None

    @validator("cell_type")
    def cell_has_valid_type(cls, v: str) -> str:
        """Check if cell has one of the three predefined types."""
        valid_cell_types = ("raw", "markdown", "code")
        if v not in valid_cell_types:
            raise ValueError(f"Invalid cell type. Must be one of {valid_cell_types}")
        return v

    @root_validator
    def must_not_be_list_for_code_cells(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Check that code cells have list-type outputs."""
        # Use `.get`: `cell_type` is absent when its own validation failed and
        # `outputs` may be missing altogether - both must yield a validation error
        # instead of an uncaught `KeyError`
        outputs = values.get("outputs")
        if values.get("cell_type") == "code" and not isinstance(outputs, list):
            raise ValueError(
                "All code cells must have a list output property, got"
                f" {type(outputs)}"
            )
        return values

    @root_validator
    def only_code_cells_have_outputs_and_execution_count(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Check that only code cells have outputs and execution count."""
        cell_type = values.get("cell_type")
        if cell_type is None:
            # `cell_type` already failed validation, that error is reported on its own
            return values
        if cell_type != "code" and (
            ("outputs" in values) or ("execution_count" in values)
        ):
            raise ValueError(
                "Found `outputs` or `execution_count` for cell of type"
                f" `{cell_type}`"
            )
        return values

127 

128 

# Element type of `Cells`: either a single cell or, for the result of a diff, a pair
# of cell lists (one list per notebook being compared).
T = TypeVar(
    "T",
    Cell,
    Tuple[List[Cell], List[Cell]],
)

130 

131 

class Cells(GenericModel, BaseCells[T]):
    """List-like container of cells; `-` diffs two of them via `difflib`."""

    __root__: Sequence[T] = []

    def __init__(self, elements: Sequence[T] = ()) -> None:
        """Build the container, accepting the elements as a positional argument."""
        super(Cells, self).__init__(__root__=elements)

    @property
    def data(self) -> List[T]:  # type: ignore
        """Return the elements as a list, as `collections.UserList` requires."""
        return [*self.__root__]

    def __iter__(self) -> Generator[Any, None, None]:
        """Iterate over the elements of the `data` list."""
        elements = self.data
        return (element for element in elements)

    def __sub__(
        self: Cells[Cell], other: Cells[Cell]
    ) -> Cells[Tuple[List[Cell], List[Cell]]]:
        """Diff two cell containers with `difflib.SequenceMatcher`."""
        own_type, other_type = type(self), type(other)
        if own_type != other_type:
            raise TypeError(
                f"Unsupported operand types for `-`: `{own_type.__name__}` and"
                f" `{other_type.__name__}`"
            )

        # Compare copies without the `id` field, leaving the operands untouched
        left, right = deepcopy(self), deepcopy(other)
        for cell in chain(left, right):
            cell.remove_fields(["id"], missing_ok=True)

        # With the context set to the largest number of cells,
        # `difflib.SequenceMatcher.get_grouped_opcodes` gives essentially the same
        # result as `difflib.SequenceMatcher.get_opcodes`, but in smaller chunks
        n_context = max(len(left), len(right))
        matcher = SequenceMatcher(isjunk=None, a=left, b=right, autojunk=False)
        opcode_groups = list(matcher.get_grouped_opcodes(n_context))

        if len(opcode_groups) > 1:
            raise RuntimeError(
                "Expected one group for opcodes when context size is "
                f" {n_context} for {len(left)} and {len(right)} cells in"
                " notebooks."
            )

        # Each opcode is `(tag, a_start, a_end, b_start, b_end)`
        pairs = []
        for _tag, a_start, a_end, b_start, b_end in chain.from_iterable(opcode_groups):
            pairs.append((left.data[a_start:a_end], right.data[b_start:b_end]))
        # https://github.com/python/mypy/issues/9459
        return Cells[Tuple[List[Cell], List[Cell]]](pairs)  # type: ignore

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Yield the validators of this custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """Return `v` as an instance of this container, wrapping it if needed."""
        return v if isinstance(v, cls) else cls(v)

    @classmethod
    def wrap_git(
        cls,
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> List[Cell]:
        """Surround both groups of cells with git-conflict marker cells."""
        opening = Cell(
            metadata=CellMetadata(git_hash=hash_first),
            source=[f"`<<<<<<< {hash_first}`"],
            cell_type="markdown",
        )
        divider = Cell(
            source=["`=======`"],
            cell_type="markdown",
            metadata=CellMetadata(),
        )
        closing = Cell(
            metadata=CellMetadata(git_hash=hash_last),
            source=[f"`>>>>>>> {hash_last}`"],
            cell_type="markdown",
        )
        return [opening] + first_cells + [divider] + last_cells + [closing]

    def resolve(
        self: Cells[Tuple[List[Cell], List[Cell]]],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook or
         the cells of the last one. If `None`, keep both, wrapped in git-diff tags
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
         `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            # Index 0 holds the cells of the first notebook, index 1 of the last
            side = int(not keep_first_cells)
            return list(chain.from_iterable(pair[side] for pair in self.data))

        resolved: List[Cell] = []
        for pair in self.data:
            first, last = pair[0], pair[1]
            if first != last:
                resolved.extend(
                    Cells.wrap_git(
                        first_cells=first,
                        last_cells=last,
                        hash_first=first_id,
                        hash_last=last_id,
                    )
                )
            else:
                resolved.extend(first)
        return resolved

274 

275 

class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """
        Parse notebook from a path.

        :param path: Path of the notebook file
        :param parse_kwargs: Keyword arguments forwarded to the parent `parse_file`.
         `content_type` is always `json` and may only be passed with that value
        :raises ValueError: If `content_type` is passed with a value other than
         `json`
        :return: Parsed notebook
        """
        content_arg = parse_kwargs.pop("content_type", None)
        # Explicitly passing the only supported value is accepted
        if content_arg not in (None, "json"):
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )
        return super(JupyterNotebook, cls).parse_file(
            path=path, content_type="json", **parse_kwargs
        )

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: keyword arguments to be passed to each cell's
         `databooks.data_models.Cell.clear_metadata`
        :raises ValueError: If not exactly one of `notebook_metadata_keep` or
         `notebook_metadata_remove` is passed
        :return:
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            # Express "keep" as its complement: remove everything not listed
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if cell_kwargs:
            # Clean a copy and swap it in only after every cell was processed, so
            # an error in one cell does not leave `self.cells` partially cleaned
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_metadata(**cell_kwargs)
            self.cells = _clean_cells