Coverage for databooks/data_models/notebook.py: 92%

113 statements  

coverage.py v7.2.7, created at 2023-10-03 12:27 +0000

1"""Data models - Jupyter Notebooks and components.""" 

2from __future__ import annotations 

3 

4import json 

5from copy import deepcopy 

6from difflib import SequenceMatcher 

7from itertools import chain 

8from pathlib import Path 

9from typing import ( 

10 Any, 

11 Callable, 

12 Generator, 

13 Iterable, 

14 List, 

15 Optional, 

16 Sequence, 

17 Tuple, 

18 TypeVar, 

19 Union, 

20 cast, 

21) 

22 

23from pydantic import Extra, RootModel 

24from rich import box 

25from rich.columns import Columns 

26from rich.console import Console, ConsoleOptions, Group, RenderableType, RenderResult 

27from rich.panel import Panel 

28from rich.text import Text 

29 

30from databooks.data_models.base import BaseCells, DatabooksBase 

31from databooks.data_models.cell import CellMetadata, CodeCell, MarkdownCell, RawCell 

32from databooks.logging import get_logger 

33 

34logger = get_logger(__file__) 

35 

36Cell = Union[CodeCell, RawCell, MarkdownCell] 

37CellsPair = Tuple[List[Cell], List[Cell]] 

38T = TypeVar("T", Cell, CellsPair) 

39 

40 

class Cells(RootModel[Sequence[T]], BaseCells[T]):
    """Similar to `list`, with `-` operator using `difflib.SequenceMatcher`."""

    root: Sequence[T]

    @property
    def data(self) -> List[T]:  # type: ignore
        """Define property `data` required for `collections.UserList` class."""
        return list(self.root)

    def __iter__(self) -> Generator[Any, None, None]:
        """Use list property as iterable."""
        return (el for el in self.data)

    def __sub__(self: Cells[Cell], other: Cells[Cell]) -> Cells[CellsPair]:
        """Return the difference using `difflib.SequenceMatcher`."""
        if type(self) != type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # By setting the context to the max number of cells and using
        # `difflib.SequenceMatcher.get_grouped_opcodes` we essentially get the same
        # result as `difflib.SequenceMatcher.get_opcodes`, but in smaller chunks
        n_context = max(len(self), len(other))
        diff_opcodes = list(
            SequenceMatcher(
                isjunk=None, a=self, b=other, autojunk=False
            ).get_grouped_opcodes(n_context)
        )

        if len(diff_opcodes) > 1:
            raise RuntimeError(
                "Expected one group of opcodes when context size is"
                f" {n_context} for {len(self)} and {len(other)} cells in"
                " notebooks."
            )

        return Cells[CellsPair](
            [
                # https://github.com/python/mypy/issues/9459
                tuple((self.data[i1:j1], other.data[i2:j2]))  # type: ignore
                for _, i1, j1, i2, j2 in chain.from_iterable(diff_opcodes)
            ]
        )
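
    # Example (illustration only, not in the original source): for cell sequences
    # that compare like the strings "XAB" and "XCB", the single opcode group is
    # [("equal", 0, 1, 0, 1), ("replace", 1, 2, 1, 2), ("equal", 2, 3, 2, 3)],
    # so the resulting `Cells[CellsPair]` holds one `(first, last)` pair per
    # contiguous run of equal or differing cells.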

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of all cells in notebook."""
        yield from self._get_renderables(expand=True, width=options.max_width // 3)

    def _get_renderables(self, **wrap_cols_kwargs: Any) -> Iterable[RenderableType]:
        """Get the Rich renderables, depending on whether `Cells` is a diff or not."""
        if all(isinstance(el, tuple) for el in self.data):
            return chain.from_iterable(
                Cells.wrap_cols(val[0], val[1], **wrap_cols_kwargs)
                if val[0] != val[1]
                else val[0]
                for val in cast(List[CellsPair], self.data)
            )
        return cast(List[Cell], self.data)

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Get validators for custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """Ensure object is custom defined container."""
        if not isinstance(v, cls):
            return cls(v)
        else:
            return v

    @classmethod
    def wrap_cols(
        cls, first_cells: List[Cell], last_cells: List[Cell], **cols_kwargs: Any
    ) -> Sequence[Columns]:
        """Wrap the first and second cells into columns for the iterable."""
        _empty = [Panel(Text("<None>", justify="center"), box=box.SIMPLE)]
        _first = Group(*first_cells or _empty)
        _last = Group(*last_cells or _empty)
        return [Columns([_first, _last], **cols_kwargs)]

    @staticmethod
    def wrap_git(
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> Sequence[Cell]:
        """Wrap git-diff cells in existing notebook."""
        return [
            MarkdownCell(
                metadata=CellMetadata(git_hash=hash_first),
                source=[f"`<<<<<<< {hash_first}`"],
            ),
            *first_cells,
            MarkdownCell(
                source=["`=======`"],
                metadata=CellMetadata(),
            ),
            *last_cells,
            MarkdownCell(
                metadata=CellMetadata(git_hash=hash_last),
                source=[f"`>>>>>>> {hash_last}`"],
            ),
        ]
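
    # Example (illustration only, not in the original source): with
    # `hash_first="HEAD"` and `hash_last="1a2b3c"`, the wrapped sequence reads
    # `<<<<<<< HEAD`, the first cells, `=======`, the last cells, `>>>>>>> 1a2b3c`,
    # mirroring git's conflict markers as markdown cells.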

    def resolve(
        self: Cells[CellsPair],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook.
         If `None`, keep both and wrap them in git-diff tags
        :param first_id: Git hash of the first file in the conflict
        :param last_id: Git hash of the last file in the conflict
        :param kwargs: (Unused) keyword arguments kept for compatibility with
         `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            return list(
                chain.from_iterable(pairs[not keep_first_cells] for pairs in self.data)
            )
        return list(
            chain.from_iterable(
                Cells.wrap_git(
                    first_cells=val[0],
                    last_cells=val[1],
                    hash_first=first_id,
                    hash_last=last_id,
                )
                if val[0] != val[1]
                else val[0]
                for val in self.data
            )
        )

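
# Minimal usage sketch (not part of the original module): diff two small cell
# sequences with `-` and resolve the conflict by keeping the first notebook's cells.
def _example_cells_diff() -> List[Cell]:
    title = MarkdownCell(metadata=CellMetadata(), source=["# Title"])
    old = MarkdownCell(metadata=CellMetadata(), source=["Old text"])
    new = MarkdownCell(metadata=CellMetadata(), source=["New text"])
    diff = Cells[Cell]([title, old]) - Cells[Cell]([title, new])  # Cells[CellsPair]
    return diff.resolve(keep_first_cells=True)  # -> [title, old]
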

class NotebookMetadata(DatabooksBase):
    """Notebook metadata. Empty by default but can accept extra fields."""


class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield an invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of the notebook."""

        def _rich(kernel: str) -> Text:
            """Display with `kernel` theme, horizontal padding and right-justified."""
            return Text(kernel, style="kernel", justify="right")

        kernelspec = self.metadata.dict().get("kernelspec", {})
        if isinstance(kernelspec, tuple):  # check if this is a `DiffCells`
            kernelspec = tuple(
                ks or {"language": "text", "display_name": "null"} for ks in kernelspec
            )
            lang_first, lang_last = (ks.get("language", "text") for ks in kernelspec)
            nb_lang = lang_first if lang_first == lang_last else "text"
            if any("display_name" in ks.keys() for ks in kernelspec):
                kernel_first, kernel_last = [
                    _rich(ks["display_name"]) for ks in kernelspec
                ]
                yield Columns(
                    [kernel_first, kernel_last],
                    expand=True,
                    width=options.max_width // 3,
                ) if kernel_first != kernel_last else kernel_first
        else:
            nb_lang = kernelspec.get("language", "text")
            if "display_name" in kernelspec.keys():
                yield _rich(kernelspec["display_name"])

        for cell in self.cells:
            if isinstance(cell, CodeCell):
                cell.metadata = CellMetadata(**cell.metadata.dict(), lang=nb_lang)
        yield self.cells

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """Parse notebook from a path."""
        content_arg = parse_kwargs.pop("content_type", None)
        if content_arg is not None:
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )

        path = Path(path) if not isinstance(path, Path) else path
        return JupyterNotebook.model_validate_json(json_data=path.read_text())

    def write(
        self, path: Path | str, overwrite: bool = False, **json_kwargs: Any
    ) -> None:
        """Write notebook to disk."""
        path = Path(path) if not isinstance(path, Path) else path
        json_kwargs = {"indent": 2, **json_kwargs}
        if path.is_file() and not overwrite:
            raise ValueError(
                f"File already exists at {path}. Specify `overwrite = True`."
            )

        self.__class__.model_validate(self.dict())

        with path.open("w") as f:
            json.dump(self.dict(), fp=f, **json_kwargs)

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: Keyword arguments passed to each cell's `clear_fields`
        :return:
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if len(cell_kwargs) > 0:
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_fields(**cell_kwargs)
            self.cells = _clean_cells
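

# Minimal usage sketch (not part of the original module). The notebook paths are
# hypothetical and only illustrate the `parse_file` -> `clear_metadata` -> `write` flow.
def _example_roundtrip() -> None:
    nb = JupyterNotebook.parse_file("notebook.ipynb")  # hypothetical input path
    nb.clear_metadata(notebook_metadata_keep=("kernelspec",))  # keep only `kernelspec`
    nb.write("notebook-clean.ipynb", overwrite=True)  # hypothetical output path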