Coverage for databooks/data_models/cell.py: 94%

143 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-10-03 12:27 +0000

1"""Data models - Cells and components.""" 

2from __future__ import annotations 

3 

4from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union 

5 

6from pydantic import PositiveInt, RootModel, field_validator 

7from rich.console import Console, ConsoleOptions, ConsoleRenderable, RenderResult 

8from rich.markdown import Markdown 

9from rich.panel import Panel 

10from rich.syntax import Syntax 

11from rich.text import Text 

12 

13from databooks.data_models.base import DatabooksBase 

14from databooks.data_models.rich_helpers import HtmlTable, RichHtmlTableError 

15from databooks.logging import get_logger 

16 

# Module-level logger for this file, obtained from the project's `get_logger` helper
# (keyed by this file's path).
logger = get_logger(__file__)

18 

19 

class CellMetadata(DatabooksBase):
    """
    Cell metadata. Empty by default but can accept extra fields.

    No fields are declared on this model itself; any metadata keys present in a
    notebook cell are expected to be kept as extra fields.
    """

22 

23 

class BaseCell(DatabooksBase):
    """
    Jupyter notebook cells.

    Fields `outputs` and `execution_count` are not included since they should only be
    present in code cells - thus are treated as extra fields.
    """

    metadata: CellMetadata
    source: Union[List[str], str]
    cell_type: str

    def __hash__(self) -> int:
        """
        Cells must be hashable for `difflib.SequenceMatcher`.

        The hash is derived from the cell's type and a hashable snapshot of its dumped
        contents, so that cells which compare equal also hash equal. Nested containers
        (lists, dicts and dumped sub-models such as `metadata` or `outputs`) are
        recursively converted to hashable equivalents, since pydantic models and
        builtin containers are not hashable themselves.
        """

        def _freeze(value: Any) -> Any:
            """Recursively convert dicts, lists, tuples and sets into hashables."""
            if isinstance(value, dict):
                # `frozenset` makes the result independent of key insertion order,
                # matching dict equality semantics.
                return frozenset((key, _freeze(item)) for key, item in value.items())
            if isinstance(value, (list, tuple)):
                return tuple(_freeze(item) for item in value)
            if isinstance(value, set):
                return frozenset(_freeze(item) for item in value)
            return value

        return hash((type(self), _freeze(self.model_dump())))

    def remove_fields(
        self, fields: Iterable[str] = (), missing_ok: bool = True, **kwargs: Any
    ) -> None:
        """
        Remove cell fields.

        Similar to `databooks.data_models.base.remove_fields`, but will ignore required
        fields for cell type.

        :param fields: Names of the fields to remove - the required `BaseCell` fields
         (`metadata`, `source` and `cell_type`) are skipped and logged at debug level
        :param missing_ok: Forwarded to the parent `remove_fields`
        :param kwargs: Accepted for interface compatibility; not used
        :return:
        """
        # Materialize once: `fields` may be a one-shot iterator (e.g. a generator) and
        # it is traversed several times below.
        fields = list(fields)

        # Ignore required `BaseCell` fields
        cell_fields = BaseCell.model_fields  # required fields
        ignored = [f for f in fields if f in cell_fields]
        if ignored:
            logger.debug(
                "Ignoring removal of required fields "
                + str(ignored)
                + f" in `{type(self).__name__}`."
            )
            fields = [f for f in fields if f not in cell_fields]

        super(BaseCell, self).remove_fields(fields, missing_ok=missing_ok)

        # Code cells always need `outputs` and `execution_count`: if they were removed
        # above, put them back with empty values.
        if self.cell_type == "code":
            self.outputs: CellOutputs = (
                CellOutputs([]) if "outputs" not in dict(self) else self.outputs
            )
            self.execution_count: Optional[PositiveInt] = (
                None if "execution_count" not in dict(self) else self.execution_count
            )

    def clear_fields(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_remove_fields: Sequence[str] = (),
    ) -> None:
        """
        Clear cell metadata, execution count, outputs or other desired fields (id, ...).

        You can also specify metadata to keep or remove from the `metadata` property of
        `databooks.data_models.cell.BaseCell`.
        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_remove_fields: Fields to remove from cell
        :raises ValueError: If not exactly one of `cell_metadata_keep` or
         `cell_metadata_remove` is passed
        :return: None - the cell's metadata and fields are modified in place
        """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )

        if cell_metadata_keep is not None:
            # Translate the "keep" list into the complementary "remove" list
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        # Not `None` here thanks to the check above; mypy cannot narrow it though
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(fields=cell_remove_fields, missing_ok=True)

104 

105 

class CellStreamOutput(DatabooksBase):
    """
    Cell output of type `stream`.

    The `text` chunks are what a cell wrote to the stream named in `name`.
    """

    output_type: str
    name: str
    text: List[str]

    def __rich__(self) -> ConsoleRenderable:
        """Render the stream output as plain text, concatenating all text chunks."""
        joined_text = "".join(self.text)
        return Text(joined_text)

    @field_validator("output_type")
    @classmethod
    def output_type_must_be_stream(cls, v: str) -> str:
        """Validate that `output_type` is exactly `stream`."""
        if v == "stream":
            return v
        raise ValueError(f"Invalid output type. Expected `stream`, got {v}.")

    @field_validator("name")
    @classmethod
    def stream_name_must_match(cls, v: str) -> str:
        """Validate that the stream `name` is `stdout` or `stderr`."""
        allowed_names = ("stdout", "stderr")
        if v in allowed_names:
            return v
        raise ValueError(
            f"Invalid stream name. Expected one of {allowed_names}, got {v}."
        )

137 

138 

class CellDisplayDataOutput(DatabooksBase):
    """Cell output of type `display_data`."""

    output_type: str
    data: Dict[str, Any]
    metadata: Dict[str, Any]

    @property
    def rich_output(self) -> Sequence[ConsoleRenderable]:
        """
        Dynamically compute the rich output - also in `CellExecuteResultOutput`.

        Each MIME type in `data` is rendered if supported (`text/html` as a rich table,
        `text/plain` as text). The result holds a placeholder message for every MIME
        type that could not be rendered, followed by the first successfully rendered
        element - if there is one.
        """

        def _try_parse_html(s: Union[str, List[str]]) -> Optional[ConsoleRenderable]:
            """Try to parse HTML table, return `None` if any errors are raised."""
            try:
                return HtmlTable("".join(s)).rich()
            except RichHtmlTableError:
                logger.debug("Could not generate rich HTML table.")
                return None

        mime_func: Dict[str, Callable[[Any], Optional[ConsoleRenderable]]] = {
            "text/html": _try_parse_html,
            "text/plain": lambda s: Text("".join(s)),
        }
        _rich = {
            mime: mime_func.get(mime, lambda s: None)(content)  # try to render element
            for mime, content in self.data.items()
        }
        unavailable = [
            Text(f"<✨Rich✨ `{mime}` not available 😢>")
            for mime, renderable in _rich.items()
            if renderable is None
        ]
        # Only the first renderable element is displayed. Without a default, `next`
        # would raise `StopIteration` for outputs with no renderable MIME type (e.g.
        # image-only outputs), which turns into a `RuntimeError` inside the
        # `__rich_console__` generator.
        first_renderable = next(
            (renderable for renderable in _rich.values() if renderable is not None),
            None,
        )
        if first_renderable is None:
            return unavailable
        return [*unavailable, first_renderable]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of data display outputs."""
        yield from self.rich_output

    @field_validator("output_type")
    @classmethod
    def output_type_must_match(cls, v: str) -> str:
        """Check if output has `display_data` type."""
        if v != "display_data":
            raise ValueError(f"Invalid output type. Expected `display_data`, got {v}.")
        return v

188 

189 

class CellExecuteResultOutput(CellDisplayDataOutput):
    """Cell output of type `execute_result`."""

    # Required, but may be `null`: the nbformat v4 schema types an execution result's
    # `execution_count` as integer-or-null.
    execution_count: Optional[PositiveInt]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of executed cell outputs - a blank prompt if count is unset."""
        yield Text(f"Out [{self.execution_count or ' '}]:", style="out_count")
        yield from self.rich_output

    @field_validator("output_type")
    @classmethod
    def output_type_must_match(cls, v: str) -> str:
        """Check if output has `execute_result` type (overrides the parent check)."""
        if v != "execute_result":
            raise ValueError(
                f"Invalid output type. Expected `execute_result`, got {v}."
            )
        return v

211 

212 

class CellErrorOutput(DatabooksBase):
    """Cell output of type `error`, with `ename`, `evalue` and `traceback` fields."""

    output_type: str
    ename: str
    evalue: str
    traceback: List[str]

    def __rich__(self) -> ConsoleRenderable:
        """Render the traceback lines, one per line, interpreting ANSI escape codes."""
        traceback_text = "\n".join(self.traceback)
        return Text.from_ansi(traceback_text)

    @field_validator("output_type")
    @classmethod
    def output_type_must_match(cls, v: str) -> str:
        """Validate that `output_type` is exactly `error`."""
        if v == "error":
            return v
        raise ValueError(f"Invalid output type. Expected `error`, got {v}.")

234 

235 

# Union of all supported cell output models; each model's `output_type` validator
# only accepts its own output type.
CellOutputType = Union[
    CellStreamOutput, CellDisplayDataOutput, CellExecuteResultOutput, CellErrorOutput
]

239 

240 

class CellOutputs(RootModel):
    """Outputs of notebook code cells, stored as a list in the `root` field."""

    root: List[CellOutputType]

    @property
    def values(self) -> List[CellOutputType]:
        """Alias for `root` - the list of outputs of the cell."""
        return self.root

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Render each code cell output in order."""
        for output in self.values:
            yield output

258 

259 

class CodeCell(BaseCell):
    """Cell of type `code` - defined for rich displaying in terminal."""

    outputs: CellOutputs
    cell_type: str = "code"

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """
        Rich display of code cells.

        Yields the input prompt, the syntax-highlighted source inside a panel, and the
        cell outputs.
        """
        # `execution_count` is not a declared field (it is handled as an extra field),
        # so fall back to a blank prompt when it is absent instead of raising.
        execution_count = getattr(self, "execution_count", None)
        yield Text(f"In [{execution_count or ' '}]:", style="in_count")
        yield Panel(
            Syntax(
                "".join(self.source) if isinstance(self.source, list) else self.source,
                getattr(self.metadata, "lang", "text"),
            )
        )
        yield self.outputs

    @field_validator("cell_type")
    @classmethod
    def cell_has_code_type(cls, v: str) -> str:
        """Check that the cell has `code` type."""
        if v != "code":
            raise ValueError(f"Expected code of type `code`, got `{v}`.")
        return v

286 

287 

class MarkdownCell(BaseCell):
    """Notebook cell of type `markdown` - rendered as markdown in the terminal."""

    cell_type: str = "markdown"

    def __rich__(self) -> ConsoleRenderable:
        """Render the cell source as markdown, wrapped in a panel."""
        markdown_source = "".join(self.source)
        return Panel(Markdown(markdown_source))

    @field_validator("cell_type")
    @classmethod
    def cell_has_md_type(cls, v: str) -> str:
        """Validate that the cell has `markdown` type."""
        if v == "markdown":
            return v
        raise ValueError(f"Expected code of type `markdown`, got {v}.")

306 

307 

class RawCell(BaseCell):
    """Notebook cell of type `raw` - rendered as plain text in the terminal."""

    cell_type: str = "raw"

    def __rich__(self) -> ConsoleRenderable:
        """Render the cell source verbatim as text, wrapped in a panel."""
        raw_source = "".join(self.source)
        return Panel(Text(raw_source))

    # NOTE: the validator name mirrors `MarkdownCell`'s; kept as-is for compatibility.
    @field_validator("cell_type")
    @classmethod
    def cell_has_md_type(cls, v: str) -> str:
        """Validate that the cell has `raw` type."""
        if v == "raw":
            return v
        raise ValueError(f"Expected code of type `raw`, got {v}.")