Coverage for databooks/data_models/notebook.py: 94%

116 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-04 16:41 +0000

1"""Data models - Jupyter Notebooks and components.""" 

2from __future__ import annotations 

3 

4import json 

5from copy import deepcopy 

6from difflib import SequenceMatcher 

7from itertools import chain 

8from pathlib import Path 

9from typing import ( 

10 Any, 

11 Callable, 

12 Generator, 

13 Iterable, 

14 List, 

15 Optional, 

16 Sequence, 

17 Tuple, 

18 TypeVar, 

19 Union, 

20 cast, 

21) 

22 

23from pydantic import Extra, validate_model 

24from pydantic.generics import GenericModel 

25from rich import box 

26from rich.columns import Columns 

27from rich.console import Console, ConsoleOptions, Group, RenderableType, RenderResult 

28from rich.panel import Panel 

29from rich.text import Text 

30 

31from databooks.data_models.base import BaseCells, DatabooksBase 

32from databooks.data_models.cell import CellMetadata, CodeCell, MarkdownCell, RawCell 

33from databooks.logging import get_logger 

34 

logger = get_logger(__file__)

# Any single notebook cell
Cell = Union[CodeCell, RawCell, MarkdownCell]
# One diff chunk, as produced by `Cells.__sub__`: a slice of cells from the first
# notebook and the corresponding slice from the last notebook
CellsPair = Tuple[List[Cell], List[Cell]]
# Element type of `Cells`: plain cells, or (after diffing) pairs of cell lists
T = TypeVar("T", Cell, CellsPair)

40 

41 

class Cells(GenericModel, BaseCells[T]):
    """
    Similar to `list`, with `-` operator using `difflib.SequenceMatcher`.

    Holds either plain notebook cells (`Cells[Cell]`) or, as the result of
    subtracting two `Cells`, diff chunks (`Cells[CellsPair]`).
    """

    __root__: Sequence[T] = ()

    def __init__(self, elements: Sequence[T] = ()) -> None:
        """Allow passing data as a positional argument when instantiating class."""
        super(Cells, self).__init__(__root__=elements)

    @property
    def data(self) -> List[T]:  # type: ignore
        """Define property `data` required for `collections.UserList` class."""
        return list(self.__root__)

    def __iter__(self) -> Generator[Any, None, None]:
        """Use list property as iterable."""
        return (el for el in self.data)

    def __sub__(self: Cells[Cell], other: Cells[Cell]) -> Cells[CellsPair]:
        """
        Return the difference using `difflib.SequenceMatcher`.

        :param other: Cells to compare against; must be of the same type as `self`
        :return: One `(self_cells, other_cells)` pair per opcode; for "equal" chunks
         both elements compare equal
        :raises TypeError: If `other` is not of the same type as `self`
        :raises RuntimeError: If the matcher returns more than one opcode group
        """
        if type(self) is not type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # Use a context as large as the longest sequence with
        # `difflib.SequenceMatcher.get_grouped_opcodes`. No "equal" run can then
        # exceed twice the context, so all opcodes stay in a single group. That
        # group essentially matches `difflib.SequenceMatcher.get_opcodes`.
        n_context = max(len(self), len(other))
        diff_opcodes = list(
            SequenceMatcher(
                isjunk=None, a=self, b=other, autojunk=False
            ).get_grouped_opcodes(n_context)
        )

        if len(diff_opcodes) > 1:
            raise RuntimeError(
                "Expected one group for opcodes when context size is"
                f" {n_context} for {len(self)} and {len(other)} cells in"
                " notebooks."
            )
        return Cells[CellsPair](
            [
                # https://github.com/python/mypy/issues/9459
                tuple((self.data[i1:j1], other.data[i2:j2]))  # type: ignore
                for _, i1, j1, i2, j2 in chain.from_iterable(diff_opcodes)
            ]
        )

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of all cells in notebook."""
        # Each side of a diff gets a third of the available width
        yield from self._get_renderables(expand=True, width=options.max_width // 3)

    def _get_renderables(self, **wrap_cols_kwargs: Any) -> Iterable[RenderableType]:
        """
        Get the Rich renderables, depending on whether `Cells` is a diff or not.

        :param wrap_cols_kwargs: Keyword arguments passed to `Cells.wrap_cols`, used
         only when the elements are diff pairs
        :return: The cells themselves, or, for a diff, differing chunks wrapped in
         side-by-side columns and equal chunks rendered once
        """
        if all(isinstance(el, tuple) for el in self.data):
            return chain.from_iterable(
                Cells.wrap_cols(val[0], val[1], **wrap_cols_kwargs)
                if val[0] != val[1]
                else val[0]
                for val in cast(List[CellsPair], self.data)
            )
        return cast(List[Cell], self.data)

    @classmethod
    def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
        """Get validators for custom class."""
        yield cls.validate

    @classmethod
    def validate(cls, v: List[T]) -> Cells[T]:
        """
        Ensure object is custom defined container.

        :param v: Value to validate
        :return: `v` unchanged if it is already an instance of `cls`, otherwise `v`
         wrapped in `cls`
        """
        if not isinstance(v, cls):
            return cls(v)
        else:
            return v

    @classmethod
    def wrap_cols(
        cls, first_cells: List[Cell], last_cells: List[Cell], **cols_kwargs: Any
    ) -> Sequence[Columns]:
        """
        Wrap the first and second cells into columns for iterable.

        :param first_cells: Cells shown in the left column
        :param last_cells: Cells shown in the right column
        :param cols_kwargs: Keyword arguments passed to `rich.columns.Columns`
        :return: A single-element list holding the two-column renderable; an empty
         side is shown as a centered `<None>` placeholder
        """
        _empty = [Panel(Text("<None>", justify="center"), box=box.SIMPLE)]
        _first = Group(*first_cells or _empty)
        _last = Group(*last_cells or _empty)
        return [Columns([_first, _last], **cols_kwargs)]

    @staticmethod
    def wrap_git(
        first_cells: List[Cell],
        last_cells: List[Cell],
        hash_first: Optional[str] = None,
        hash_last: Optional[str] = None,
    ) -> Sequence[Cell]:
        """
        Wrap git-diff cells in existing notebook.

        Surround the two versions with markdown cells that mimic git conflict
        markers (`<<<<<<<`, `=======`, `>>>>>>>`).

        :param first_cells: Cells from the first version
        :param last_cells: Cells from the last version
        :param hash_first: Git hash of the first version, shown in the opening marker
        :param hash_last: Git hash of the last version, shown in the closing marker
        :return: Marker cells interleaved with `first_cells` and `last_cells`
        """
        return [
            MarkdownCell(
                metadata=CellMetadata(git_hash=hash_first),
                source=[f"`<<<<<<< {hash_first}`"],
                cell_type="markdown",
            ),
            *first_cells,
            MarkdownCell(
                source=["`=======`"],
                cell_type="markdown",
                metadata=CellMetadata(),
            ),
            *last_cells,
            MarkdownCell(
                metadata=CellMetadata(git_hash=hash_last),
                source=[f"`>>>>>>> {hash_last}`"],
                cell_type="markdown",
            ),
        ]

    def resolve(
        self: Cells[CellsPair],
        *,
        keep_first_cells: Optional[bool] = None,
        first_id: Optional[str] = None,
        last_id: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Cell]:
        """
        Resolve differences between `databooks.data_models.notebook.Cells`.

        :param keep_first_cells: Whether to keep the cells of the first notebook or not.
         If `None`, then keep both wrapping the git-diff tags
        :param first_id: Git hash of first file in conflict
        :param last_id: Git hash of last file in conflict
        :param kwargs: (Unused) keyword arguments to keep compatibility with
         `databooks.data_models.base.resolve`
        :return: List of cells
        """
        if keep_first_cells is not None:
            # `not keep_first_cells` is False (index 0, first) or True (index 1, last)
            return list(
                chain.from_iterable(pairs[not keep_first_cells] for pairs in self.data)
            )
        return list(
            chain.from_iterable(
                Cells.wrap_git(
                    first_cells=val[0],
                    last_cells=val[1],
                    hash_first=first_id,
                    hash_last=last_id,
                )
                if val[0] != val[1]
                else val[0]
                for val in self.data
            )
        )

196 

197 

class NotebookMetadata(DatabooksBase):
    """
    Notebook-level metadata.

    No fields are declared here, so an instance is empty by default; any extra
    fields found in the notebook are accepted.
    """

200 

201 

class JupyterNotebook(DatabooksBase, extra=Extra.forbid):
    """Jupyter notebook. Extra fields yield invalid notebook."""

    nbformat: int
    nbformat_minor: int
    metadata: NotebookMetadata
    cells: Cells[Cell]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """
        Rich display notebook.

        Yield the kernel display name (if present in the notebook metadata), then
        the cells. Before rendering, each code cell's metadata is updated in place
        with the notebook language as `lang`.
        """

        def _rich(kernel: str) -> Text:
            """Display with `kernel` theme, right-justified."""
            return Text(kernel, style="kernel", justify="right")

        kernelspec = self.metadata.dict().get("kernelspec", {})
        if isinstance(kernelspec, tuple):  # check if this is a `DiffCells`
            lang_first, lang_last = (ks.get("language", "text") for ks in kernelspec)
            nb_lang = lang_first if lang_first == lang_last else "text"
            if any("display_name" in ks.keys() for ks in kernelspec):
                # Only one side may define `display_name`; don't fail on the other
                kernel_first, kernel_last = [
                    _rich(ks.get("display_name", "")) for ks in kernelspec
                ]
                yield Columns(
                    [kernel_first, kernel_last],
                    expand=True,
                    width=options.max_width // 3,
                ) if kernel_first != kernel_last else kernel_first
        else:
            nb_lang = kernelspec.get("language", "text")
            if "display_name" in kernelspec.keys():
                yield _rich(kernelspec["display_name"])

        for cell in self.cells:
            if isinstance(cell, CodeCell):
                # Merge rather than passing `lang=` next to `**metadata`. Once a cell
                # has been rendered its metadata already holds `lang`, and passing
                # the keyword twice would raise `TypeError`.
                cell.metadata = CellMetadata(**{**cell.metadata.dict(), "lang": nb_lang})
        yield self.cells

    @classmethod
    def parse_file(cls, path: Path | str, **parse_kwargs: Any) -> JupyterNotebook:
        """
        Parse notebook from a path.

        :param path: Path to the notebook file
        :param parse_kwargs: Keyword arguments forwarded to the parent `parse_file`;
         `content_type`, if given, must be `"json"` (the only supported format)
        :return: Parsed notebook
        :raises ValueError: If a `content_type` other than `"json"` is passed
        """
        content_arg = parse_kwargs.pop("content_type", None)
        if content_arg not in (None, "json"):
            raise ValueError(
                f"Value of `content_type` must be `json` (default), got `{content_arg}`"
            )
        return super(JupyterNotebook, cls).parse_file(
            path=path, content_type="json", **parse_kwargs
        )

    def write(
        self, path: Path | str, overwrite: bool = False, **json_kwargs: Any
    ) -> None:
        """
        Write notebook to disk.

        :param path: Destination file path
        :param overwrite: Whether to replace an existing file at `path`
        :param json_kwargs: Keyword arguments passed to `json.dump` (`indent`
         defaults to 2)
        :raises ValueError: If a file exists at `path` and `overwrite` is `False`
        :raises ValidationError: If the current notebook content fails validation
        """
        path = Path(path)
        json_kwargs = {"indent": 2, **json_kwargs}
        if path.is_file() and not overwrite:
            raise ValueError(f"File exists at {path}. Specify `overwrite = True`.")

        # Validate the current content: attributes (e.g. `cells`) may have been
        # reassigned since the notebook was parsed
        _, _, validation_error = validate_model(self.__class__, self.dict())
        if validation_error:
            raise validation_error
        # Write UTF-8 explicitly (matching how notebooks are read) instead of relying
        # on the platform's default encoding
        with path.open("w", encoding="utf-8") as f:
            json.dump(self.dict(), fp=f, **json_kwargs)

    def clear_metadata(
        self,
        *,
        notebook_metadata_keep: Optional[Sequence[str]] = None,
        notebook_metadata_remove: Optional[Sequence[str]] = None,
        **cell_kwargs: Any,
    ) -> None:
        """
        Clear notebook and cell metadata.

        :param notebook_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param notebook_metadata_remove: Metadata values to remove
        :param cell_kwargs: keyword arguments to be passed to each cell's
         `databooks.data_models.cell.BaseCell.clear_metadata`
        :raises ValueError: If not exactly one of `notebook_metadata_keep` and
         `notebook_metadata_remove` is given
        """
        nargs = sum(
            (notebook_metadata_keep is not None, notebook_metadata_remove is not None)
        )
        if nargs != 1:
            raise ValueError(
                "Exactly one of `notebook_metadata_keep` or `notebook_metadata_remove`"
                f" must be passed, got {nargs} arguments."
            )
        if notebook_metadata_keep is not None:
            # Translate "keep" into the complementary set of fields to remove
            notebook_metadata_remove = tuple(
                field
                for field, _ in self.metadata
                if field not in notebook_metadata_keep
            )
        self.metadata.remove_fields(notebook_metadata_remove)  # type: ignore

        if cell_kwargs:
            # Clean a copy and reassign it, so `self.cells` is replaced as a whole
            _clean_cells = deepcopy(self.cells)
            for cell in _clean_cells:
                cell.clear_fields(**cell_kwargs)
            self.cells = _clean_cells
308 self.cells = _clean_cells