Coverage for databooks/data_models/cell.py: 95%

129 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-09 13:11 +0000

1"""Data models - Cells and components.""" 

2from __future__ import annotations 

3 

4from typing import Any, Dict, Iterable, List, Optional, Sequence, Union 

5 

6from pydantic import PositiveInt, validator 

7from rich.console import Console, ConsoleOptions, ConsoleRenderable, RenderResult 

8from rich.markdown import Markdown 

9from rich.panel import Panel 

10from rich.syntax import Syntax 

11from rich.text import Text 

12 

13from databooks.data_models.base import DatabooksBase 

14from databooks.logging import get_logger 

15 

logger = get_logger(__file__)


class CellMetadata(DatabooksBase):
    """
    Metadata attached to a notebook cell.

    No fields are declared, so it is empty by default; extra fields can still be
    supplied and are accepted.
    """

21 

22 

class BaseCell(DatabooksBase):
    """
    Jupyter notebook cells.

    Fields `outputs` and `execution_count` are not included since they should only be
    present in code cells - thus are treated as extra fields.
    """

    metadata: CellMetadata
    source: Union[List[str], str]
    cell_type: str

    def __hash__(self) -> int:
        """
        Cells must be hashable for `difflib.SequenceMatcher`.

        The hash is built from the cell's type and a hashable ("frozen") snapshot of
        every stored field value, so cells with the same contents hash the same.
        """

        def _freeze(value: Any) -> Any:
            """Recursively convert `value` into an equivalent hashable object."""
            if isinstance(value, DatabooksBase):
                # Nested models (metadata, outputs, ...) are compared by content.
                return _freeze(value.dict())
            if isinstance(value, dict):
                # Order-insensitive, like dict equality.
                return frozenset((key, _freeze(val)) for key, val in value.items())
            if isinstance(value, (list, tuple)):
                return tuple(_freeze(item) for item in value)
            if isinstance(value, set):
                return frozenset(_freeze(item) for item in value)
            return value

        # Hash a concrete tuple: passing a generator expression to `hash()` would hash
        # the generator object itself (by identity) instead of the cell contents.
        return hash((type(self), _freeze(self.__dict__)))

    def remove_fields(
        self, fields: Iterable[str] = (), missing_ok: bool = True, **kwargs: Any
    ) -> None:
        """
        Remove cell fields.

        Similar to `databooks.data_models.base.remove_fields`, but will ignore required
        fields for cell type. For code cells, a removed `outputs` is restored as an
        empty `CellOutputs` and a removed `execution_count` is restored as `None`.

        :param fields: Names of the fields to remove
        :param missing_ok: Forwarded to the base class `remove_fields`
        :param kwargs: Accepted for signature compatibility; not used here
        :return:
        """
        # Materialize once: `fields` may be a one-shot iterator and it is traversed
        # several times below (membership check, filtering, removal).
        fields = list(fields)

        # Ignore required `BaseCell` fields
        cell_fields = BaseCell.__fields__  # required fields
        required = [f for f in fields if f in cell_fields]
        if required:
            logger.debug(
                "Ignoring removal of required fields "
                + str(required)
                + f" in `{type(self).__name__}`."
            )
            fields = [f for f in fields if f not in cell_fields]

        super(BaseCell, self).remove_fields(fields, missing_ok=missing_ok)

        if self.cell_type == "code":
            # Code cells keep `outputs`/`execution_count`, reset to empty values.
            self.outputs: CellOutputs = (
                CellOutputs(__root__=[])
                if "outputs" not in dict(self)
                else self.outputs
            )
            self.execution_count: Optional[PositiveInt] = (
                None if "execution_count" not in dict(self) else self.execution_count
            )

    def clear_fields(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_remove_fields: Sequence[str] = (),
    ) -> None:
        """
        Clear cell metadata, execution count, outputs or other desired fields (id, ...).

        You can also specify metadata to keep or remove from the `metadata` property of
        `databooks.data_models.cell.BaseCell`.
        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_remove_fields: Fields to remove from cell
        :raises ValueError: If not exactly one of `cell_metadata_keep` or
         `cell_metadata_remove` is passed
        :return:
        """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )

        if cell_metadata_keep is not None:
            # Translate "keep" into the complementary set of metadata keys to remove.
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        # The `nargs` check guarantees `cell_metadata_remove` is set at this point,
        # which the type checker cannot infer.
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(fields=cell_remove_fields, missing_ok=True)
50 """ 

51 # Ignore required `BaseCell` fields 

52 cell_fields = BaseCell.__fields__ # required fields 

53 if any(field in fields for field in cell_fields): 

54 logger.debug( 

55 "Ignoring removal of required fields " 

56 + str([f for f in fields if f in cell_fields]) 

57 + f" in `{type(self).__name__}`." 

58 ) 

59 fields = [f for f in fields if f not in cell_fields] 

60 

61 super(BaseCell, self).remove_fields(fields, missing_ok=missing_ok) 

62 

63 if self.cell_type == "code": 

64 self.outputs: CellOutputs = ( 

65 CellOutputs(__root__=[]) 

66 if "outputs" not in dict(self) 

67 else self.outputs 

68 ) 

69 self.execution_count: Optional[PositiveInt] = ( 

70 None if "execution_count" not in dict(self) else self.execution_count 

71 ) 

72 

73 def clear_fields( 

74 self, 

75 *, 

76 cell_metadata_keep: Sequence[str] = None, 

77 cell_metadata_remove: Sequence[str] = None, 

78 cell_remove_fields: Sequence[str] = (), 

79 ) -> None: 

80 """ 

81 Clear cell metadata, execution count, outputs or other desired fields (id, ...). 

82 

83 You can also specify metadata to keep or remove from the `metadata` property of 

84 `databooks.data_models.cell.BaseCell`. 

85 :param cell_metadata_keep: Metadata values to keep - simply pass an empty 

86 sequence (i.e.: `()`) to remove all extra fields. 

87 :param cell_metadata_remove: Metadata values to remove 

88 :param cell_remove_fields: Fields to remove from cell 

89 :return: 

90 """ 

91 nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None)) 

92 if nargs != 1: 

93 raise ValueError( 

94 "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must" 

95 f" be passed, got {nargs} arguments." 

96 ) 

97 

98 if cell_metadata_keep is not None: 

99 cell_metadata_remove = tuple( 

100 field for field, _ in self.metadata if field not in cell_metadata_keep 

101 ) 

102 self.metadata.remove_fields(cell_metadata_remove) # type: ignore 

103 

104 self.remove_fields(fields=cell_remove_fields, missing_ok=True) 

105 

106 

107class CellStreamOutput(DatabooksBase): 

108 """Cell output of type `stream`.""" 

109 

110 output_type: str 

111 name: str 

112 text: List[str] 

113 

114 def __rich__( 

115 self, 

116 ) -> ConsoleRenderable: 

117 """Rich display of cell stream outputs.""" 

118 return Text("".join(self.text)) 

119 

120 @validator("output_type") 

121 def output_type_must_be_stream(cls, v: str) -> str: 

122 """Check if stream has `stream` type.""" 

123 if v != "stream": 

124 raise ValueError(f"Invalid output type. Expected `stream`, got {v}.") 

125 return v 

126 

127 @validator("name") 

128 def stream_name_must_match(cls, v: str) -> str: 

129 """Check if stream name is either `stdout` or `stderr`.""" 

130 valid_names = ("stdout", "stderr") 

131 if v not in valid_names: 

132 raise ValueError( 

133 f"Invalid stream name. Expected one of {valid_names}, got {v}." 

134 ) 

135 return v 

136 

137 

138class CellDisplayDataOutput(DatabooksBase): 

139 """Cell output of type `display_data`.""" 

140 

141 output_type: str 

142 data: Dict[str, Any] 

143 metadata: Dict[str, Any] 

144 

145 @property 

146 def rich_output(self) -> Sequence[ConsoleRenderable]: 

147 """Dynamically compute the rich output - also in `CellExecuteResultOutput`.""" 

148 mime_func = { 

149 "image/png": None, 

150 "text/html": None, 

151 "text/plain": lambda s: Text("".join(s)), 

152 } 

153 supported = [k for k, v in mime_func.items() if v is not None] 

154 not_supported = [ 

155 Text(f"<✨Rich✨ `{mime}` not currently supported 😢>") 

156 for mime in self.data.keys() 

157 if mime not in supported 

158 ] 

159 return not_supported + [ 

160 next( 

161 mime_func[mime](content) # type: ignore 

162 for mime, content in self.data.items() 

163 if mime in supported 

164 ) 

165 ] 

166 

167 def __rich_console__( 

168 self, console: Console, options: ConsoleOptions 

169 ) -> RenderResult: 

170 """Rich display of data display outputs.""" 

171 yield from self.rich_output 

172 

173 @validator("output_type") 

174 def output_type_must_match(cls, v: str) -> str: 

175 """Check if stream has `display_data` type.""" 

176 if v != "display_data": 

177 raise ValueError(f"Invalid output type. Expected `display_data`, got {v}.") 

178 return v 

179 

180 

181class CellExecuteResultOutput(CellDisplayDataOutput): 

182 """Cell output of type `execute_result`.""" 

183 

184 execution_count: PositiveInt 

185 

186 def __rich_console__( 

187 self, console: Console, options: ConsoleOptions 

188 ) -> RenderResult: 

189 """Rich display of executed cell outputs.""" 

190 yield Text(f"Out [{self.execution_count or ' '}]:", style="out_count") 

191 yield from self.rich_output 

192 

193 @validator("output_type") 

194 def output_type_must_match(cls, v: str) -> str: 

195 """Check if stream has `execute_result` type.""" 

196 if v != "execute_result": 

197 raise ValueError( 

198 f"Invalid output type. Expected `execute_result`, got {v}." 

199 ) 

200 return v 

201 

202 

203class CellErrorOutput(DatabooksBase): 

204 """Cell output of type `error`.""" 

205 

206 output_type: str 

207 ename: str 

208 evalue: str 

209 traceback: List[str] 

210 

211 def __rich__( 

212 self, 

213 ) -> ConsoleRenderable: 

214 """Rich display of error outputs.""" 

215 return Text.from_ansi("\n".join(self.traceback)) 

216 

217 @validator("output_type") 

218 def output_type_must_match(cls, v: str) -> str: 

219 """Check if stream has `error` type.""" 

220 if v != "error": 

221 raise ValueError(f"Invalid output type. Expected `error`, got {v}.") 

222 return v 

223 

224 

225CellOutputType = Union[ 

226 CellStreamOutput, CellDisplayDataOutput, CellExecuteResultOutput, CellErrorOutput 

227] 

228 

229 

230class CellOutputs(DatabooksBase): 

231 """Outputs of notebook code cells.""" 

232 

233 __root__: List[CellOutputType] 

234 

235 def __rich_console__( 

236 self, console: Console, options: ConsoleOptions 

237 ) -> RenderResult: 

238 """Rich display of code cell outputs.""" 

239 yield from self.values 

240 

241 @property 

242 def values( 

243 self, 

244 ) -> List[CellOutputType]: 

245 """Alias `__root__` with outputs for easy referencing.""" 

246 return self.__root__ 

247 

248 

249class CodeCell(BaseCell): 

250 """Cell of type `code` - defined for rich displaying in terminal.""" 

251 

252 outputs: CellOutputs 

253 cell_type: str = "code" 

254 

255 def __rich_console__( 

256 self, console: Console, options: ConsoleOptions 

257 ) -> RenderResult: 

258 """Rich display of code cells.""" 

259 yield Text(f"In [{self.execution_count or ' '}]:", style="in_count") 

260 yield Panel( 

261 Syntax( 

262 "".join(self.source) if isinstance(self.source, list) else self.source, 

263 getattr(self.metadata, "lang", "text"), 

264 ) 

265 ) 

266 yield self.outputs 

267 

268 @validator("cell_type") 

269 def cell_has_code_type(cls, v: str) -> str: 

270 """Extract the list values from the __root__ attribute of `CellOutputs`.""" 

271 if v != "code": 

272 raise ValueError(f"Expected code of type `code`, got `{v}`.") 

273 return v 

274 

275 

276class MarkdownCell(BaseCell): 

277 """Cell of type `markdown` - defined for rich displaying in terminal.""" 

278 

279 cell_type: str = "markdown" 

280 

281 def __rich__( 

282 self, 

283 ) -> ConsoleRenderable: 

284 """Rich display of markdown cells.""" 

285 return Panel(Markdown("".join(self.source))) 

286 

287 @validator("cell_type") 

288 def cell_has_md_type(cls, v: str) -> str: 

289 """Extract the list values from the __root__ attribute of `CellOutputs`.""" 

290 if v != "markdown": 

291 raise ValueError(f"Expected code of type `markdown`, got {v}.") 

292 return v 

293 

294 

295class RawCell(BaseCell): 

296 """Cell of type `raw` - defined for rich displaying in terminal.""" 

297 

298 cell_type: str = "raw" 

299 

300 def __rich__( 

301 self, 

302 ) -> ConsoleRenderable: 

303 """Rich display of raw cells.""" 

304 return Panel(Text("".join(self.source))) 

305 

306 @validator("cell_type") 

307 def cell_has_md_type(cls, v: str) -> str: 

308 """Extract the list values from the __root__ attribute of `CellOutputs`.""" 

309 if v != "raw": 

310 raise ValueError(f"Expected code of type `raw`, got {v}.") 

311 return v