Coverage for databooks/data_models/cell.py: 95%

129 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-11 20:30 +0000

1"""Data models - Cells and components.""" 

2from __future__ import annotations 

3 

4from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union 

5 

6from pydantic import PositiveInt, validator 

7from rich.console import Console, ConsoleOptions, ConsoleRenderable, RenderResult 

8from rich.markdown import Markdown 

9from rich.panel import Panel 

10from rich.syntax import Syntax 

11from rich.text import Text 

12 

13from databooks.data_models.base import DatabooksBase 

14from databooks.data_models.rich_helpers import HtmlTable 

15from databooks.logging import get_logger 

16 

# Module-level logger: this module's file path (`__file__`) is handed to the
# project's `get_logger` helper.
logger = get_logger(__file__)

18 

19 

class CellMetadata(DatabooksBase):
    """
    Metadata attached to a single notebook cell.

    No fields are declared here: the model is empty by default but can accept
    extra fields.
    """

22 

23 

class BaseCell(DatabooksBase):
    """
    Jupyter notebook cells.

    Fields `outputs` and `execution_count` are not included since they should only be
    present in code cells - thus are treated as extra fields.
    """

    metadata: CellMetadata
    source: Union[List[str], str]
    cell_type: str

    def __hash__(self) -> int:
        """
        Cells must be hashable for `difflib.SequenceMatcher`.

        The hash is derived from `cell_type` and `source` only. Passing a generator
        expression to `hash` hashes the generator object itself (identity based), so
        the result would not depend on the cell contents at all.
        NOTE(review): hashing a subset of the fields keeps "equal cells hash equal"
        provided equality compares field values - confirm against `DatabooksBase`.
        """
        # Lists are unhashable, so a list-valued source is frozen into a tuple
        source = tuple(self.source) if isinstance(self.source, list) else self.source
        return hash((self.cell_type, source))

    def remove_fields(
        self, fields: Iterable[str] = (), missing_ok: bool = True, **kwargs: Any
    ) -> None:
        """
        Remove cell fields.

        Similar to `databooks.data_models.base.remove_fields`, but will ignore required
        fields for cell type.
        :param fields: Names of the fields to remove
        :param missing_ok: Forwarded to the parent implementation of `remove_fields`
        :param kwargs: Accepted for signature compatibility, not used here
        :return:
        """
        # `fields` is scanned more than once below - materialise it so that one-shot
        # iterables (i.e.: generators) are not exhausted by the first pass
        fields = list(fields)

        # Ignore required `BaseCell` fields
        cell_fields = BaseCell.__fields__  # required fields
        ignored = [f for f in fields if f in cell_fields]
        if ignored:
            logger.debug(
                "Ignoring removal of required fields "
                + str(ignored)
                + f" in `{type(self).__name__}`."
            )
            fields = [f for f in fields if f not in cell_fields]

        super(BaseCell, self).remove_fields(fields, missing_ok=missing_ok)

        # Code cells always end up with `outputs` and `execution_count` set: when
        # either was removed above it is re-created with an "empty" value
        if self.cell_type == "code":
            self.outputs: CellOutputs = (
                CellOutputs(__root__=[])
                if "outputs" not in dict(self)
                else self.outputs
            )
            self.execution_count: Optional[PositiveInt] = (
                None if "execution_count" not in dict(self) else self.execution_count
            )

    def clear_fields(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_remove_fields: Sequence[str] = (),
    ) -> None:
        """
        Clear cell metadata, execution count, outputs or other desired fields (id, ...).

        You can also specify metadata to keep or remove from the `metadata` property of
        `databooks.data_models.cell.BaseCell`.
        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_remove_fields: Fields to remove from cell
        :raises ValueError: If both or neither of `cell_metadata_keep` and
         `cell_metadata_remove` are passed
        :return:
        """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )

        if cell_metadata_keep is not None:
            # Translate the "keep" list into the complementary "remove" list
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(fields=cell_remove_fields, missing_ok=True)

106 

107 

class CellStreamOutput(DatabooksBase):
    """Model for a cell output whose `output_type` is `stream`."""

    output_type: str
    name: str
    text: List[str]

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Render the stream as rich text made of the concatenated `text` chunks."""
        joined = "".join(self.text)
        return Text(joined)

    @validator("output_type")
    def output_type_must_be_stream(cls, v: str) -> str:
        """Accept the output type only when it is exactly `stream`."""
        if v == "stream":
            return v
        raise ValueError(f"Invalid output type. Expected `stream`, got {v}.")

    @validator("name")
    def stream_name_must_match(cls, v: str) -> str:
        """Accept the stream name only when it is `stdout` or `stderr`."""
        valid_names = ("stdout", "stderr")
        if v in valid_names:
            return v
        raise ValueError(
            f"Invalid stream name. Expected one of {valid_names}, got {v}."
        )

137 

138 

class CellDisplayDataOutput(DatabooksBase):
    """Cell output of type `display_data`."""

    output_type: str
    data: Dict[str, Any]
    metadata: Dict[str, Any]

    @property
    def rich_output(self) -> Sequence[ConsoleRenderable]:
        """
        Dynamically compute the rich output - also in `CellExecuteResultOutput`.

        :return: One placeholder text per mime type that could not be rendered,
         followed by the first renderable that could be produced - if there is one.
        """
        mime_func: Dict[str, Callable[[str], Optional[ConsoleRenderable]]] = {
            "image/png": lambda s: None,
            "text/html": lambda s: HtmlTable("".join(s)).rich(),
            "text/plain": lambda s: Text("".join(s)),
        }
        _rich = {
            mime: mime_func.get(mime, lambda s: None)(content)  # try to render element
            for mime, content in self.data.items()
        }
        placeholders: List[ConsoleRenderable] = [
            Text(f"<✨Rich✨ `{mime}` not available 😢>")
            for mime, renderable in _rich.items()
            if renderable is None
        ]
        available = [renderable for renderable in _rich.values() if renderable is not None]
        # Only the first available renderable is shown. Slicing (instead of a bare
        # `next(...)`) avoids a `StopIteration` when nothing could be rendered - that
        # exception would surface as `RuntimeError` inside `__rich_console__`
        return [*placeholders, *available[:1]]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of data display outputs."""
        yield from self.rich_output

    @validator("output_type")
    def output_type_must_match(cls, v: str) -> str:
        """Check if output has `display_data` type."""
        if v != "display_data":
            raise ValueError(f"Invalid output type. Expected `display_data`, got {v}.")
        return v

179 

180 

class CellExecuteResultOutput(CellDisplayDataOutput):
    """Model for a cell output whose `output_type` is `execute_result`."""

    execution_count: PositiveInt

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Yield the `Out [n]:` prompt and then everything in `rich_output`."""
        prompt = f"Out [{self.execution_count or ' '}]:"
        yield Text(prompt, style="out_count")
        for renderable in self.rich_output:
            yield renderable

    @validator("output_type")
    def output_type_must_match(cls, v: str) -> str:
        """Accept the output type only when it is exactly `execute_result`."""
        if v == "execute_result":
            return v
        raise ValueError(
            f"Invalid output type. Expected `execute_result`, got {v}."
        )

201 

202 

class CellErrorOutput(DatabooksBase):
    """Model for a cell output whose `output_type` is `error`."""

    output_type: str
    ename: str
    evalue: str
    traceback: List[str]

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Render the newline-joined traceback, parsed as ANSI text."""
        full_traceback = "\n".join(self.traceback)
        return Text.from_ansi(full_traceback)

    @validator("output_type")
    def output_type_must_match(cls, v: str) -> str:
        """Accept the output type only when it is exactly `error`."""
        if v == "error":
            return v
        raise ValueError(f"Invalid output type. Expected `error`, got {v}.")

223 

224 

# Any one of the output models defined above: stream, display data, execute result
# or error. Used as the element type of `CellOutputs.__root__`.
CellOutputType = Union[
    CellStreamOutput, CellDisplayDataOutput, CellExecuteResultOutput, CellErrorOutput
]

228 

229 

class CellOutputs(DatabooksBase):
    """Container model holding the list of outputs of a code cell."""

    __root__: List[CellOutputType]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Yield each stored output, in order, for rich to render."""
        for output in self.values:
            yield output

    @property
    def values(
        self,
    ) -> List[CellOutputType]:
        """Return the list stored in `__root__` under a friendlier name."""
        outputs = self.__root__
        return outputs

247 

248 

class CodeCell(BaseCell):
    """Cell of type `code` - defined for rich displaying in terminal."""

    outputs: CellOutputs
    cell_type: str = "code"

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """
        Rich display of code cells.

        Yields the `In [n]:` prompt, a panel with the cell source and the outputs.
        """
        # `execution_count` is not a declared field of this model, so it may be
        # missing on the instance - fall back to an empty prompt instead of raising
        execution_count = getattr(self, "execution_count", None)
        yield Text(f"In [{execution_count or ' '}]:", style="in_count")
        yield Panel(
            Syntax(
                "".join(self.source) if isinstance(self.source, list) else self.source,
                getattr(self.metadata, "lang", "text"),
            )
        )
        yield self.outputs

    @validator("cell_type")
    def cell_has_code_type(cls, v: str) -> str:
        """Check if cell has `code` type."""
        if v != "code":
            raise ValueError(f"Expected code of type `code`, got `{v}`.")
        return v

274 

275 

class MarkdownCell(BaseCell):
    """Model for cells of type `markdown`, renderable by rich in a terminal."""

    cell_type: str = "markdown"

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Render the concatenated source as markdown inside a panel."""
        markdown_source = "".join(self.source)
        return Panel(Markdown(markdown_source))

    @validator("cell_type")
    def cell_has_md_type(cls, v: str) -> str:
        """Accept the cell type only when it is exactly `markdown`."""
        if v == "markdown":
            return v
        raise ValueError(f"Expected code of type `markdown`, got {v}.")

293 

294 

class RawCell(BaseCell):
    """Model for cells of type `raw`, renderable by rich in a terminal."""

    cell_type: str = "raw"

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Render the concatenated source as plain text inside a panel."""
        raw_source = "".join(self.source)
        return Panel(Text(raw_source))

    @validator("cell_type")
    def cell_has_md_type(cls, v: str) -> str:
        """Accept the cell type only when it is exactly `raw`."""
        if v == "raw":
            return v
        raise ValueError(f"Expected code of type `raw`, got {v}.")