Coverage for databooks/data_models/notebook.py: 94%

117 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-09 13:11 +0000

1"""Data models - Jupyter Notebooks and components.""" 

2from __future__ import annotations 

3 

4import json 

5from copy import deepcopy 

6from difflib import SequenceMatcher 

7from itertools import chain 

8from pathlib import Path 

9from typing import ( 

10 Any, 

11 Callable, 

12 Generator, 

13 Iterable, 

14 List, 

15 Optional, 

16 Sequence, 

17 Tuple, 

18 TypeVar, 

19 Union, 

20 cast, 

21) 

22 

23from pydantic import Extra, validate_model 

24from pydantic.generics import GenericModel 

25from rich import box 

26from rich.columns import Columns 

27from rich.console import Console, ConsoleOptions, Group, RenderableType, RenderResult 

28from rich.panel import Panel 

29from rich.text import Text 

30 

31from databooks.data_models.base import BaseCells, DatabooksBase 

32from databooks.data_models.cell import CellMetadata, CodeCell, MarkdownCell, RawCell 

33from databooks.logging import get_logger 

34 

logger = get_logger(__file__)

# A single notebook cell of any supported type.
# NOTE(review): pydantic v1 tries `Union` members in order when validating, so this
#  order may affect how raw cells are parsed - keep it unless that is confirmed.
Cell = Union[CodeCell, RawCell, MarkdownCell]
# One diff chunk as built by `Cells.__sub__`: (cells from first, cells from last).
CellsPair = Tuple[List[Cell], List[Cell]]
# `Cells` holds either plain cells or diff chunks (the result of `Cells.__sub__`).
T = TypeVar("T", Cell, CellsPair)

40 

41 

class Cells(GenericModel, BaseCells[T]):
    """Similar to `list`, with `-` operator using `difflib.SequenceMatcher`."""

    __root__: Sequence[T] = ()

    def __init__(self, elements: Sequence[T] = ()) -> None:
        """Allow passing data as a positional argument when instantiating class."""
        super(Cells, self).__init__(__root__=elements)

    @property
    def data(self) -> List[T]:  # type: ignore
        """Define property `data` required for `collections.UserList` class."""
        return list(self.__root__)

    def __iter__(self) -> Generator[Any, None, None]:
        """Use list property as iterable."""
        return (el for el in self.data)

    def __sub__(self: Cells[Cell], other: Cells[Cell]) -> Cells[CellsPair]:
        """
        Return the difference using `difflib.SequenceMatcher`.

        The result holds `(first_cells, last_cells)` pairs that, read in order, cover
         both operands completely: "equal" chunks contain the same cells on both
         sides, other chunks differ (one of the sides may be empty).

        :param other: Cells to compare against
        :return: Cells of pairs with the matching chunks of `self` and `other`
        :raises TypeError: if `other` is not exactly the same type as `self`
        """
        if type(self) != type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # `difflib.SequenceMatcher.get_opcodes` spans both sequences end to end.
        #  `get_grouped_opcodes` with a context as large as the longest sequence
        #  returns the same chunks, *except* for identical sequences: there it
        #  yields no group at all, so the diff came out empty and `resolve` then
        #  returned no cells.
        diff_opcodes = SequenceMatcher(
            isjunk=None, a=self, b=other, autojunk=False
        ).get_opcodes()
        return Cells[CellsPair](
            [
                # https://github.com/python/mypy/issues/9459
                tuple((self.data[i1:j1], other.data[i2:j2]))  # type: ignore
                for _, i1, j1, i2, j2 in diff_opcodes
            ]
        )

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of all cells in notebook."""
        yield from self._get_renderables(expand=True, width=options.max_width // 3)

    def _get_renderables(self, **wrap_cols_kwargs: Any) -> Iterable[RenderableType]:
        """
        Get the Rich renderables, depending on whether `Cells` is a diff or not.

        When every element is a `(first, last)` tuple (a diff), differing pairs are
         shown side by side via `wrap_cols` and equal pairs as the plain cells.
         Otherwise the cells themselves are returned.

        :param wrap_cols_kwargs: Keyword arguments forwarded to `wrap_cols`
        :return: Iterable of Rich renderables
        """
        if all(isinstance(el, tuple) for el in self.data):
            return chain.from_iterable(
                Cells.wrap_cols(val[0], val[1], **wrap_cols_kwargs)
                if val[0] != val[1]
                else val[0]
                for val in cast(List[CellsPair], self.data)
            )
        return cast(List[Cell], self.data)

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Get validators for custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """
        Ensure object is custom defined container.

        :param v: Value to validate - an instance of `cls` or a list of elements
        :return: `v` itself if already an instance of `cls`, otherwise `cls(v)`
        """
        if not isinstance(v, cls):
            return cls(v)
        else:
            return v

    @classmethod
    def wrap_cols(
        cls, first_cells: List[Cell], last_cells: List[Cell], **cols_kwargs: Any
    ) -> Sequence[Columns]:
        """
        Wrap the first and second cells into columns for iterable.

        An empty side is shown as a centered `<None>` panel.

        :param first_cells: Cells for the left column
        :param last_cells: Cells for the right column
        :param cols_kwargs: Keyword arguments forwarded to `rich.columns.Columns`
        :return: Single-element list with the two-column renderable
        """
        _empty = [Panel(Text("<None>", justify="center"), box=box.SIMPLE)]
        _first = Group(*first_cells or _empty)
        _last = Group(*last_cells or _empty)
        return [Columns([_first, _last], **cols_kwargs)]

    @staticmethod
    def wrap_git(
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> Sequence[Cell]:
        """
        Wrap git-diff cells in existing notebook.

        The cells are surrounded by markdown cells mimicking git conflict markers:
         `<<<<<<< {hash_first}`, first cells, `=======`, last cells and
         `>>>>>>> {hash_last}`.

        :param first_cells: Cells from the first version
        :param last_cells: Cells from the last version
        :param hash_first: Git hash of the first version, stored in marker metadata
        :param hash_last: Git hash of the last version, stored in marker metadata
        :return: List of marker and content cells
        """
        return [
            MarkdownCell(
                metadata=CellMetadata(git_hash=hash_first),
                source=[f"`<<<<<<< {hash_first}`"],
                cell_type="markdown",
            ),
            *first_cells,
            MarkdownCell(
                source=["`=======`"],
                cell_type="markdown",
                metadata=CellMetadata(),
            ),
            *last_cells,
            MarkdownCell(
                metadata=CellMetadata(git_hash=hash_last),
                source=[f"`>>>>>>> {hash_last}`"],
                cell_type="markdown",
            ),
        ]

    def resolve(
        self: Cells[CellsPair],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook or not.
         If `None`, then keep both wrapping the git-diff tags
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
         `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            # `not keep_first_cells` is 0 (first side) or 1 (last side) as an index
            return list(
                chain.from_iterable(pairs[not keep_first_cells] for pairs in self.data)
            )
        # Equal chunks are kept once; differing chunks get git-style markers
        return list(
            chain.from_iterable(
                Cells.wrap_git(
                    first_cells=val[0],
                    last_cells=val[1],
                    hash_first=first_id,
                    hash_last=last_id,
                )
                if val[0] != val[1]
                else val[0]
                for val in self.data
            )
        )

196 

197 

class NotebookMetadata(DatabooksBase):
    """
    Metadata attached to a notebook.

    No fields are declared here, so it is empty by default; extra fields can be
     passed and are accepted.
    """

200 

201 

class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """
        Rich display notebook.

        Yields the kernel display name(s), when available, followed by the cells.
         Code cells are rendered from a copy with the notebook language set in their
         metadata, so the notebook itself is left unchanged.
        """

        def _rich(kernel: str) -> Text:
            """Display with `kernel` theme and right-justified."""
            return Text(kernel, style="kernel", justify="right")

        kernelspec = self.metadata.dict().get("kernelspec", {})
        if isinstance(kernelspec, tuple):  # check if this is a `DiffCells`
            kernelspec = tuple(
                ks or {"language": "text", "display_name": "null"} for ks in kernelspec
            )
            lang_first, lang_last = (ks.get("language", "text") for ks in kernelspec)
            nb_lang = lang_first if lang_first == lang_last else "text"
            if any("display_name" in ks for ks in kernelspec):
                # Only one side may define a display name; don't fail on the other
                kernel_first, kernel_last = [
                    _rich(ks.get("display_name", "null")) for ks in kernelspec
                ]
                yield Columns(
                    [kernel_first, kernel_last],
                    expand=True,
                    width=options.max_width // 3,
                ) if kernel_first != kernel_last else kernel_first
        else:
            nb_lang = kernelspec.get("language", "text")
            if "display_name" in kernelspec:
                yield _rich(kernelspec["display_name"])

        # Work on a copy: setting `lang` on the notebook's own cells would end up in
        #  the output of `write`, and a second render would pass `lang` twice to
        #  `CellMetadata` (a `TypeError`).
        cells = deepcopy(self.cells)
        for cell in cells:
            if isinstance(cell, CodeCell):
                cell.metadata = CellMetadata(
                    **{**cell.metadata.dict(), "lang": nb_lang}
                )
        yield cells

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """
        Parse notebook from a path.

        :param path: Path to the notebook file
        :param parse_kwargs: Keyword arguments forwarded to the parent `parse_file`;
         `content_type` may only be `"json"` (the default)
        :return: Parsed notebook
        :raises ValueError: if a `content_type` other than `"json"` is passed
        """
        content_arg = parse_kwargs.pop("content_type", None)
        if content_arg not in (None, "json"):
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )
        return super(JupyterNotebook, cls).parse_file(
            path=path, content_type="json", **parse_kwargs
        )

    def write(
        self, path: Path | str, overwrite: bool = False, **json_kwargs: Any
    ) -> None:
        """
        Write notebook to disk.

        :param path: Destination file
        :param overwrite: Whether to replace the file if it already exists
        :param json_kwargs: Keyword arguments for `json.dump` (`indent=2` by default)
        :raises ValueError: if `path` is an existing file and `overwrite` is `False`
        :raises: the error returned by `validate_model` if the notebook is invalid
        """
        path = Path(path)
        json_kwargs = {"indent": 2, **json_kwargs}
        if path.is_file() and not overwrite:
            raise ValueError(f"File exists at {path}. Specify `overwrite = True`.")

        _, _, validation_error = validate_model(self.__class__, self.dict())
        if validation_error:
            raise validation_error
        # Notebooks are UTF-8 JSON - don't depend on the platform default encoding
        #  (matters when callers pass `ensure_ascii=False`)
        with path.open("w", encoding="utf-8") as f:
            json.dump(self.dict(), fp=f, **json_kwargs)

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: keyword arguments to be passed to each cell's
         `databooks.data_models.cell.BaseCell.clear_metadata`
        :raises ValueError: unless exactly one of `notebook_metadata_keep` or
         `notebook_metadata_remove` is passed
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            # Translate "fields to keep" into "fields to remove"
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if cell_kwargs:
            # Clear a deep copy and assign it back as a whole
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_fields(**cell_kwargs)
            self.cells = _clean_cells