Coverage for databooks/data_models/cell.py: 95%
129 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-11 20:30 +0000
1"""Data models - Cells and components."""
2from __future__ import annotations
4from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union
6from pydantic import PositiveInt, validator
7from rich.console import Console, ConsoleOptions, ConsoleRenderable, RenderResult
8from rich.markdown import Markdown
9from rich.panel import Panel
10from rich.syntax import Syntax
11from rich.text import Text
13from databooks.data_models.base import DatabooksBase
14from databooks.data_models.rich_helpers import HtmlTable
15from databooks.logging import get_logger
# Module-level logger; `BaseCell.remove_fields` uses it to report ignored removals
# of required fields.
logger = get_logger(__file__)
class CellMetadata(DatabooksBase):
    """
    Metadata attached to a notebook cell.

    No fields are declared here, so instances are empty by default; any keys found in
    the notebook are kept as extra fields.
    """
class BaseCell(DatabooksBase):
    """
    Jupyter notebook cells.

    Fields `outputs` and `execution_count` are not included since they should only be
    present in code cells - thus are treated as extra fields.
    """

    metadata: CellMetadata
    source: Union[List[str], str]
    cell_type: str

    def __hash__(self) -> int:
        """
        Cells must be hashable for `difflib.SequenceMatcher`.

        The hash is derived solely from the cell's field values (as returned by
        `.dict()`, extra fields included), recursively converted into hashable
        equivalents - so cells with equal contents produce equal hashes.
        """

        def _freeze(obj: Any) -> Any:
            # dicts -> frozenset of (key, value) pairs (order-insensitive, like dict
            # equality); lists/tuples -> tuples; anything else is returned as-is
            if isinstance(obj, dict):
                return frozenset((key, _freeze(val)) for key, val in obj.items())
            if isinstance(obj, (list, tuple)):
                return tuple(_freeze(item) for item in obj)
            return obj

        # Hash the contents, not the identity of some temporary object, so that equal
        # cells land in the same bucket of `SequenceMatcher`'s lookup table
        return hash(_freeze(self.dict()))

    def remove_fields(
        self, fields: Iterable[str] = (), missing_ok: bool = True, **kwargs: Any
    ) -> None:
        """
        Remove cell fields.

        Similar to `databooks.data_models.base.remove_fields`, but will ignore required
        fields for cell type.
        :param fields: Names of the fields to remove - fields declared on `BaseCell`
         are skipped (and logged at debug level)
        :param missing_ok: Forwarded to the parent `remove_fields`
        :param kwargs: Accepted for signature compatibility; not used
        """
        # Materialize once: `fields` may be a one-shot iterator and is traversed
        # more than once below
        fields = list(fields)
        cell_fields = BaseCell.__fields__  # required fields
        ignored = [f for f in fields if f in cell_fields]
        if ignored:
            logger.debug(
                f"Ignoring removal of required fields {ignored}"
                f" in `{type(self).__name__}`."
            )
            fields = [f for f in fields if f not in cell_fields]

        super(BaseCell, self).remove_fields(fields, missing_ok=missing_ok)

        # Code cells must always expose `outputs` and `execution_count`, even if they
        # were just removed - restore empty defaults for whichever is missing
        if self.cell_type == "code":
            present = dict(self)
            if "outputs" not in present:
                self.outputs: CellOutputs = CellOutputs(__root__=[])
            if "execution_count" not in present:
                self.execution_count: Optional[PositiveInt] = None

    def clear_fields(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_remove_fields: Sequence[str] = (),
    ) -> None:
        """
        Clear cell metadata, execution count, outputs or other desired fields (id, ...).

        You can also specify metadata to keep or remove from the `metadata` property of
        `databooks.data_models.cell.BaseCell`.
        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
         sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_remove_fields: Fields to remove from cell
        :raises ValueError: If not exactly one of `cell_metadata_keep` and
         `cell_metadata_remove` is passed
        """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )

        # A "keep" list is translated into the complementary "remove" list
        if cell_metadata_keep is not None:
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(fields=cell_remove_fields, missing_ok=True)
class CellStreamOutput(DatabooksBase):
    """Cell output of type `stream`, written to either `stdout` or `stderr`."""

    output_type: str
    name: str
    text: List[str]

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Render the stream as a single text made of all its chunks."""
        content = "".join(self.text)
        return Text(content)

    @validator("output_type")
    def output_type_must_be_stream(cls, v: str) -> str:
        """Validate that the output is of the `stream` type."""
        if v == "stream":
            return v
        raise ValueError(f"Invalid output type. Expected `stream`, got {v}.")

    @validator("name")
    def stream_name_must_match(cls, v: str) -> str:
        """Validate that the stream name is one of `stdout` / `stderr`."""
        valid_names = ("stdout", "stderr")
        if v in valid_names:
            return v
        raise ValueError(
            f"Invalid stream name. Expected one of {valid_names}, got {v}."
        )
class CellDisplayDataOutput(DatabooksBase):
    """Cell output of type `display_data`."""

    output_type: str
    data: Dict[str, Any]
    metadata: Dict[str, Any]

    @property
    def rich_output(self) -> Sequence[ConsoleRenderable]:
        """
        Dynamically compute the rich output - also in `CellExecuteResultOutput`.

        Every MIME type in `data` is tried: a placeholder text is returned for each one
        that cannot be rendered, followed by the first successfully rendered element
        (when there is one).
        """
        mime_func: Dict[str, Callable[[str], Optional[ConsoleRenderable]]] = {
            "image/png": lambda s: None,
            "text/html": lambda s: HtmlTable("".join(s)).rich(),
            "text/plain": lambda s: Text("".join(s)),
        }
        _rich = {
            mime: mime_func.get(mime, lambda s: None)(content)  # try to render element
            for mime, content in self.data.items()
        }
        placeholders: List[ConsoleRenderable] = [
            Text(f"<✨Rich✨ `{mime}` not available 😢>")
            for mime, renderable in _rich.items()
            if renderable is None
        ]
        # Default to `None` so outputs with no renderable MIME type (e.g. only
        # `image/png`) don't raise `StopIteration` - which would surface as a
        # `RuntimeError` from the `__rich_console__` generator (PEP 479)
        rendered = next(
            (renderable for renderable in _rich.values() if renderable is not None),
            None,
        )
        return placeholders if rendered is None else [*placeholders, rendered]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of data display outputs."""
        yield from self.rich_output

    # NOTE: `CellExecuteResultOutput` defines a validator with this same name,
    # presumably to replace this check - keep the names in sync
    @validator("output_type")
    def output_type_must_match(cls, v: str) -> str:
        """Check if output has `display_data` type."""
        if v != "display_data":
            raise ValueError(f"Invalid output type. Expected `display_data`, got {v}.")
        return v
class CellExecuteResultOutput(CellDisplayDataOutput):
    """Cell output of type `execute_result` - display data plus an execution count."""

    execution_count: PositiveInt

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Render the `Out [n]:` prompt, then the rendered output data."""
        prompt = f"Out [{self.execution_count or ' '}]:"
        yield Text(prompt, style="out_count")
        for renderable in self.rich_output:
            yield renderable

    @validator("output_type")
    def output_type_must_match(cls, v: str) -> str:
        """Validate that the output is of the `execute_result` type."""
        if v == "execute_result":
            return v
        raise ValueError(
            f"Invalid output type. Expected `execute_result`, got {v}."
        )
class CellErrorOutput(DatabooksBase):
    """Cell output of type `error`, holding `ename`, `evalue` and `traceback` lines."""

    output_type: str
    ename: str
    evalue: str
    traceback: List[str]

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Render the traceback lines, interpreting ANSI escape codes."""
        joined = "\n".join(self.traceback)
        return Text.from_ansi(joined)

    @validator("output_type")
    def output_type_must_match(cls, v: str) -> str:
        """Validate that the output is of the `error` type."""
        if v == "error":
            return v
        raise ValueError(f"Invalid output type. Expected `error`, got {v}.")
# Any single code-cell output model. Member order is kept as-is: the union may be
# tried left to right when parsing (assumption - confirm against pydantic version).
CellOutputType = Union[
    CellStreamOutput,
    CellDisplayDataOutput,
    CellExecuteResultOutput,
    CellErrorOutput,
]
class CellOutputs(DatabooksBase):
    """Outputs of notebook code cells, stored as the model's `__root__` list."""

    __root__: List[CellOutputType]

    @property
    def values(
        self,
    ) -> List[CellOutputType]:
        """Shortcut returning the `__root__` list of outputs."""
        return self.__root__

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Render each output, in order."""
        for output in self.values:
            yield output
class CodeCell(BaseCell):
    """Cell of type `code` - defined for rich displaying in terminal."""

    outputs: CellOutputs
    cell_type: str = "code"

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of code cells: input prompt, highlighted source, outputs."""
        # `execution_count` is an extra (undeclared) field, so it is not guaranteed
        # to be set - fall back to a blank prompt instead of raising AttributeError
        count = getattr(self, "execution_count", None)
        yield Text(f"In [{count or ' '}]:", style="in_count")
        source = "".join(self.source) if isinstance(self.source, list) else self.source
        # Lexer name comes from the `lang` metadata entry, defaulting to plain text
        yield Panel(Syntax(source, getattr(self.metadata, "lang", "text")))
        yield self.outputs

    @validator("cell_type")
    def cell_has_code_type(cls, v: str) -> str:
        """Check that the cell has `code` type."""
        if v != "code":
            raise ValueError(f"Expected code of type `code`, got `{v}`.")
        return v
class MarkdownCell(BaseCell):
    """Cell of type `markdown` - defined for rich displaying in terminal."""

    cell_type: str = "markdown"

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Rich display of markdown cells."""
        # `"".join` handles both list-of-lines and plain-string `source`
        return Panel(Markdown("".join(self.source)))

    @validator("cell_type")
    def cell_has_md_type(cls, v: str) -> str:
        """Check that the cell has `markdown` type."""
        if v != "markdown":
            raise ValueError(f"Expected code of type `markdown`, got {v}.")
        return v
class RawCell(BaseCell):
    """Cell of type `raw` - defined for rich displaying in terminal."""

    cell_type: str = "raw"

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Rich display of raw cells."""
        # `"".join` handles both list-of-lines and plain-string `source`
        return Panel(Text("".join(self.source)))

    # NOTE: the `md` in this validator's name is a misnomer (it checks for `raw`);
    # the name is kept unchanged for backward compatibility
    @validator("cell_type")
    def cell_has_md_type(cls, v: str) -> str:
        """Check that the cell has `raw` type."""
        if v != "raw":
            raise ValueError(f"Expected code of type `raw`, got {v}.")
        return v