Coverage for databooks/data_models/cell.py: 94%
143 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-10-03 12:27 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-10-03 12:27 +0000
1"""Data models - Cells and components."""
2from __future__ import annotations
4from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union
6from pydantic import PositiveInt, RootModel, field_validator
7from rich.console import Console, ConsoleOptions, ConsoleRenderable, RenderResult
8from rich.markdown import Markdown
9from rich.panel import Panel
10from rich.syntax import Syntax
11from rich.text import Text
13from databooks.data_models.base import DatabooksBase
14from databooks.data_models.rich_helpers import HtmlTable, RichHtmlTableError
15from databooks.logging import get_logger
# Module-level logger for this file, created by `get_logger` from this file's path.
logger = get_logger(__file__)
class CellMetadata(DatabooksBase):
    """
    Cell metadata. Empty by default but can accept extra fields.

    No fields are declared here; acceptance of arbitrary extra keys presumably comes
    from the `DatabooksBase` model configuration (defined elsewhere - confirm there).
    """
class BaseCell(DatabooksBase):
    """
    Jupyter notebook cells.

    Fields `outputs` and `execution_count` are not included since they should only be
    present in code cells - thus are treated as extra fields.
    """

    metadata: CellMetadata
    source: Union[List[str], str]
    cell_type: str

    def __hash__(self) -> int:
        """
        Cells must be hashable for `difflib.SequenceMatcher`.

        Only `cell_type` and the joined `source` text are hashed. Both are required
        fields that `remove_fields` never deletes, and both are always hashable.
        Cells that differ only in metadata or outputs share a hash and are told apart
        by `__eq__`. Assuming equality compares field values (pydantic default - confirm
        against `DatabooksBase`), equal cells always get equal hashes.
        """
        # Hash a concrete tuple, never a generator expression: `hash(<generator>)`
        # hashes the generator object itself, not the cell's contents.
        return hash((self.cell_type, "".join(self.source)))

    def remove_fields(
        self, fields: Iterable[str] = (), missing_ok: bool = True, **kwargs: Any
    ) -> None:
        """
        Remove cell fields.

        Similar to `databooks.data_models.base.remove_fields`, but will ignore required
        fields for cell type.

        :param fields: Names of fields to remove; the required `BaseCell` fields
            (`metadata`, `source`, `cell_type`) are skipped with a debug log message
        :param missing_ok: Forwarded to the parent `remove_fields`
        :param kwargs: Accepted for signature compatibility; not used
        :return:
        """
        # Materialize first: `fields` is scanned more than once below, which would
        # silently drop names if a one-shot iterator (e.g. a generator) were passed.
        fields = list(fields)

        # Ignore required `BaseCell` fields
        cell_fields = BaseCell.model_fields  # required fields
        required = [f for f in fields if f in cell_fields]
        if required:
            logger.debug(
                "Ignoring removal of required fields "
                + str(required)
                + f" in `{type(self).__name__}`."
            )
            fields = [f for f in fields if f not in cell_fields]

        super(BaseCell, self).remove_fields(fields, missing_ok=missing_ok)

        # For code cells, re-create `outputs` (empty) and `execution_count` (`None`)
        # when they are no longer present after the removal above.
        if self.cell_type == "code":
            present = dict(self)
            self.outputs: CellOutputs = (
                CellOutputs([]) if "outputs" not in present else self.outputs
            )
            self.execution_count: Optional[PositiveInt] = (
                None if "execution_count" not in present else self.execution_count
            )

    def clear_fields(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_remove_fields: Sequence[str] = (),
    ) -> None:
        """
        Clear cell metadata, execution count, outputs or other desired fields (id, ...).

        You can also specify metadata to keep or remove from the `metadata` property of
        `databooks.data_models.cell.BaseCell`.
        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
            sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_remove_fields: Fields to remove from cell
        :raises ValueError: If not exactly one of `cell_metadata_keep` and
            `cell_metadata_remove` is passed
        :return:
        """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )

        if cell_metadata_keep is not None:
            # Turn the allow-list into the list of metadata keys to drop.
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(fields=cell_remove_fields, missing_ok=True)
class CellStreamOutput(DatabooksBase):
    """Code cell output written to a stream (its `output_type` is `stream`)."""

    output_type: str
    name: str
    text: List[str]

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Render the captured stream chunks as one rich `Text`."""
        joined = "".join(self.text)
        return Text(joined)

    @field_validator("output_type")
    @classmethod
    def output_type_must_be_stream(cls, v: str) -> str:
        """Accept only the `stream` output type."""
        if v == "stream":
            return v
        raise ValueError(f"Invalid output type. Expected `stream`, got {v}.")

    @field_validator("name")
    @classmethod
    def stream_name_must_match(cls, v: str) -> str:
        """Accept only the `stdout` and `stderr` stream names."""
        allowed = ("stdout", "stderr")
        if v in allowed:
            return v
        raise ValueError(
            f"Invalid stream name. Expected one of {allowed}, got {v}."
        )
class CellDisplayDataOutput(DatabooksBase):
    """Cell output of type `display_data`."""

    output_type: str
    data: Dict[str, Any]
    metadata: Dict[str, Any]

    @property
    def rich_output(self) -> Sequence[ConsoleRenderable]:
        """
        Dynamically compute the rich output - also in `CellExecuteResultOutput`.

        Each entry of `data` is rendered by a handler chosen by its MIME type:
        `text/html` as a rich table when it parses, `text/plain` as text. Any other
        type gets no handler. A placeholder `Text` is emitted for every entry that
        could not be rendered. It is followed by the first entry (in `data` order)
        that did render, if there is one.

        :return: Placeholders for unrenderable MIME types, then at most one renderable
        """

        def _try_parse_html(s: str) -> Optional[ConsoleRenderable]:
            """Try to parse HTML table, return `None` if `RichHtmlTableError` is raised."""
            try:
                return HtmlTable("".join(s)).rich()
            except RichHtmlTableError:
                logger.debug("Could not generate rich HTML table.")
                return None

        mime_func: Dict[str, Callable[[str], Optional[ConsoleRenderable]]] = {
            "text/html": lambda s: _try_parse_html(s),
            "text/plain": lambda s: Text("".join(s)),
        }
        _rich = {
            mime: mime_func.get(mime, lambda s: None)(content)  # try to render element
            for mime, content in self.data.items()
        }
        placeholders: List[ConsoleRenderable] = [
            Text(f"<✨Rich✨ `{mime}` not available 😢>")
            for mime, renderable in _rich.items()
            if renderable is None
        ]
        # `next` needs a default. When nothing rendered (e.g. only `image/png`, HTML
        # that failed to parse with no `text/plain` fallback, or empty `data`), a bare
        # `next` raises `StopIteration`. Inside the `__rich_console__` generators that
        # surfaces as a `RuntimeError` (PEP 479).
        first_rendered = next(
            (renderable for renderable in _rich.values() if renderable is not None),
            None,
        )
        if first_rendered is None:
            return placeholders
        return [*placeholders, first_rendered]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of data display outputs."""
        yield from self.rich_output

    @field_validator("output_type")
    @classmethod
    def output_type_must_match(cls, v: str) -> str:
        """Check if stream has `display_data` type."""
        if v != "display_data":
            raise ValueError(f"Invalid output type. Expected `display_data`, got {v}.")
        return v
class CellExecuteResultOutput(CellDisplayDataOutput):
    """Cell output of type `execute_result`."""

    # Required but nullable: the nbformat v4 schema allows a null prompt number for
    # `execute_result` outputs. The `or ' '` in `__rich_console__` renders that case
    # as a blank count.
    execution_count: Optional[PositiveInt]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of executed cell outputs: an `Out [n]:` prompt, then the data."""
        yield Text(f"Out [{self.execution_count or ' '}]:", style="out_count")
        yield from self.rich_output

    @field_validator("output_type")
    @classmethod
    def output_type_must_match(cls, v: str) -> str:
        """
        Check if stream has `execute_result` type.

        Shares its name with the parent's validator so that it replaces the parent's
        `display_data` check instead of running alongside it.
        """
        if v != "execute_result":
            raise ValueError(
                f"Invalid output type. Expected `execute_result`, got {v}."
            )
        return v
class CellErrorOutput(DatabooksBase):
    """Output of a code cell that raised an exception (its `output_type` is `error`)."""

    output_type: str
    ename: str
    evalue: str
    traceback: List[str]

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Render the traceback lines, interpreting their ANSI escape codes."""
        full_traceback = "\n".join(self.traceback)
        return Text.from_ansi(full_traceback)

    @field_validator("output_type")
    @classmethod
    def output_type_must_match(cls, v: str) -> str:
        """Accept only the `error` output type."""
        if v == "error":
            return v
        raise ValueError(f"Invalid output type. Expected `error`, got {v}.")
# Any one of the supported code cell output models. Each model's `output_type`
# validator rejects the other output types.
CellOutputType = Union[
    CellStreamOutput, CellDisplayDataOutput, CellExecuteResultOutput, CellErrorOutput
]
class CellOutputs(RootModel):
    """The list of outputs attached to a notebook code cell."""

    root: List[CellOutputType]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Render every output of the code cell, in order."""
        for output in self.values:
            yield output

    @property
    def values(
        self,
    ) -> List[CellOutputType]:
        """Shorthand for `root`: the outputs as a plain list."""
        outputs = self.root
        return outputs
class CodeCell(BaseCell):
    """Cell of type `code` - defined for rich displaying in terminal."""

    outputs: CellOutputs
    cell_type: str = "code"

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """
        Rich display of code cells.

        Yields the `In [n]:` prompt, then the source in a syntax-highlighted panel, then
        the outputs. The lexer comes from a `lang` metadata attribute and defaults to
        `"text"`.
        """
        # `execution_count` is not a declared field of this model, so it may be
        # absent. Fall back to `None` so the prompt renders as `In [ ]:` instead of
        # raising `AttributeError`.
        execution_count = getattr(self, "execution_count", None)
        yield Text(f"In [{execution_count or ' '}]:", style="in_count")
        yield Panel(
            Syntax(
                "".join(self.source) if isinstance(self.source, list) else self.source,
                getattr(self.metadata, "lang", "text"),
            )
        )
        yield self.outputs

    @field_validator("cell_type")
    @classmethod
    def cell_has_code_type(cls, v: str) -> str:
        """Check that the cell type is `code`."""
        if v != "code":
            raise ValueError(f"Expected code of type `code`, got `{v}`.")
        return v
class MarkdownCell(BaseCell):
    """Cell of type `markdown` - defined for rich displaying in terminal."""

    cell_type: str = "markdown"

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Rich display of markdown cells: the joined source rendered as Markdown in a panel."""
        return Panel(Markdown("".join(self.source)))

    @field_validator("cell_type")
    @classmethod
    def cell_has_md_type(cls, v: str) -> str:
        """Check that the cell type is `markdown`."""
        if v != "markdown":
            raise ValueError(f"Expected code of type `markdown`, got {v}.")
        return v
class RawCell(BaseCell):
    """Cell of type `raw` - defined for rich displaying in terminal."""

    cell_type: str = "raw"

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Rich display of raw cells: the joined source as plain text in a panel."""
        return Panel(Text("".join(self.source)))

    # NOTE: the name says "md" but this validates the `raw` type; the name is kept
    # unchanged because it is a public class attribute.
    @field_validator("cell_type")
    @classmethod
    def cell_has_md_type(cls, v: str) -> str:
        """Check that the cell type is `raw`."""
        if v != "raw":
            raise ValueError(f"Expected code of type `raw`, got {v}.")
        return v