Coverage for databooks/data_models/cell.py: 95%
129 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-04 16:41 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-04 16:41 +0000
1"""Data models - Cells and components."""
2from __future__ import annotations
4from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
6from pydantic import PositiveInt, validator
7from rich.console import Console, ConsoleOptions, ConsoleRenderable, RenderResult
8from rich.markdown import Markdown
9from rich.panel import Panel
10from rich.syntax import Syntax
11from rich.text import Text
13from databooks.data_models.base import DatabooksBase
14from databooks.logging import get_logger
# Module-level logger, named after this file's path.
logger = get_logger(__file__)
class CellMetadata(DatabooksBase):
    """
    Metadata attached to a notebook cell.

    No fields are declared, so an instance is empty by default; extra fields found in
    the notebook are accepted.
    """
class BaseCell(DatabooksBase):
    """
    Jupyter notebook cells.

    Fields `outputs` and `execution_count` are not included since they should only be
    present in code cells - thus are treated as extra fields.
    """

    metadata: CellMetadata
    source: Union[List[str], str]
    cell_type: str

    def __hash__(self) -> int:
        """
        Cells must be hashable for `difflib.SequenceMatcher`.

        The hash combines the cell class with the *values* of every entry in
        `self.__dict__`, so cells holding identical content hash identically. Values
        that are not hashable themselves (lists, dicts, sets, nested models) are first
        converted recursively into hashable equivalents.
        """

        def _freeze(value: Any) -> Any:
            """Return a hashable stand-in for `value` (recursively)."""
            if isinstance(value, (list, tuple)):
                # Order matters for sequences, so keep it
                return tuple(_freeze(item) for item in value)
            if isinstance(value, (set, frozenset)):
                return frozenset(_freeze(item) for item in value)
            if isinstance(value, dict):
                # Key order does not matter for dict equality, so ignore it
                return frozenset((key, _freeze(item)) for key, item in value.items())
            try:
                hash(value)
            except TypeError:
                # Unhashable objects (e.g. nested models): use their attributes instead
                if hasattr(value, "__dict__"):
                    return _freeze(vars(value))
                return repr(value)
            return value

        # Build a concrete tuple: passing a generator expression straight to `hash`
        #  would hash the generator object itself (identity-based), not the values
        frozen_values = tuple(_freeze(v) for v in self.__dict__.values())
        return hash((type(self),) + frozen_values)

    def remove_fields(
        self, fields: Iterable[str] = (), missing_ok: bool = True, **kwargs: Any
    ) -> None:
        """
        Remove cell fields.

        Similar to `databooks.data_models.base.remove_fields`, but will ignore required
        fields for cell type. Code cells get their `outputs` and `execution_count`
        reset to defaults (`CellOutputs(__root__=[])` and `None`) if those end up
        missing after the removal.

        :param fields: Names of the fields to remove
        :param missing_ok: Passed through to `DatabooksBase.remove_fields`
        :param kwargs: Accepted but not used
        :return:
        """
        # Materialize once: `fields` may be a one-shot iterator (e.g. a generator),
        #  which would otherwise be (partially) consumed by the membership check below
        fields = list(fields)

        # Ignore required `BaseCell` fields
        cell_fields = self.__fields__  # required fields specified in class definition
        ignored = [f for f in fields if f in cell_fields]
        if ignored:
            logger.debug(
                "Ignoring removal of required fields "
                + str(ignored)
                + f" in `{type(self).__name__}`."
            )
            fields = [f for f in fields if f not in cell_fields]

        super(BaseCell, self).remove_fields(fields, missing_ok=missing_ok)

        if self.cell_type == "code":
            # Restore defaults for the code-cell-only (extra) fields if they were removed
            self.outputs: CellOutputs = (
                CellOutputs(__root__=[])
                if "outputs" not in dict(self)
                else self.outputs
            )
            self.execution_count: Optional[PositiveInt] = (
                None if "execution_count" not in dict(self) else self.execution_count
            )

    def clear_fields(
        self,
        *,
        cell_metadata_keep: Optional[Sequence[str]] = None,
        cell_metadata_remove: Optional[Sequence[str]] = None,
        cell_remove_fields: Sequence[str] = (),
    ) -> None:
        """
        Clear cell metadata, execution count, outputs or other desired fields (id, ...).

        You can also specify metadata to keep or remove from the `metadata` property of
        `databooks.data_models.cell.BaseCell`.
        :param cell_metadata_keep: Metadata values to keep - simply pass an empty
            sequence (i.e.: `()`) to remove all extra fields.
        :param cell_metadata_remove: Metadata values to remove
        :param cell_remove_fields: Fields to remove from cell
        :raises ValueError: If not exactly one of `cell_metadata_keep` and
            `cell_metadata_remove` is given
        :return:
        """
        nargs = sum((cell_metadata_keep is not None, cell_metadata_remove is not None))
        if nargs != 1:
            raise ValueError(
                "Exactly one of `cell_metadata_keep` or `cell_metadata_remove` must"
                f" be passed, got {nargs} arguments."
            )

        if cell_metadata_keep is not None:
            # Translate the "keep" list into the complementary "remove" list
            cell_metadata_remove = tuple(
                field for field, _ in self.metadata if field not in cell_metadata_keep
            )
        self.metadata.remove_fields(cell_metadata_remove)  # type: ignore

        self.remove_fields(fields=cell_remove_fields, missing_ok=True)
class CellStreamOutput(DatabooksBase):
    """Output of a code cell written to a stream (`output_type` is `stream`)."""

    output_type: str
    name: str
    text: List[str]

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Render all streamed text chunks as one rich `Text`."""
        joined = "".join(self.text)
        return Text(joined)

    @validator("output_type")
    def output_type_must_be_stream(cls, v: str) -> str:
        """Reject any output type other than `stream`."""
        if v == "stream":
            return v
        raise ValueError(f"Invalid output type. Expected `stream`, got {v}.")

    @validator("name")
    def stream_name_must_match(cls, v: str) -> str:
        """Accept only the `stdout` and `stderr` stream names."""
        allowed = ("stdout", "stderr")
        if v in allowed:
            return v
        raise ValueError(
            f"Invalid stream name. Expected one of {allowed}, got {v}."
        )
class CellDisplayDataOutput(DatabooksBase):
    """Cell output of type `display_data`."""

    output_type: str
    data: Dict[str, Any]
    metadata: Dict[str, Any]

    @property
    def rich_output(self) -> Sequence[ConsoleRenderable]:
        """
        Dynamically compute the rich output - also in `CellExecuteResultOutput`.

        Produces a placeholder `Text` for every MIME type in `data` that cannot be
        rendered, followed by the rendering of the first renderable MIME type, if any.
        """
        # MIME type -> renderer; `None` marks types that cannot be rendered yet
        mime_func = {
            "image/png": None,
            "text/html": None,
            "text/plain": lambda s: Text("".join(s)),
        }
        supported = [k for k, v in mime_func.items() if v is not None]
        not_supported = [
            Text(f"<✨Rich✨ `{mime}` not currently supported 😢>")
            for mime in self.data.keys()
            if mime not in supported
        ]
        # `next` needs a default: without one, data with no supported MIME type (e.g.
        #  only `image/png`) raises `StopIteration`, which turns into a `RuntimeError`
        #  when it escapes the `__rich_console__` generators (PEP 479)
        rendered = next(
            (
                mime_func[mime](content)  # type: ignore
                for mime, content in self.data.items()
                if mime in supported
            ),
            None,
        )
        return not_supported + ([rendered] if rendered is not None else [])

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of data display outputs."""
        yield from self.rich_output

    @validator("output_type")
    def output_type_must_match(cls, v: str) -> str:
        """Check that the output has `display_data` type."""
        if v != "display_data":
            raise ValueError(f"Invalid output type. Expected `display_data`, got {v}.")
        return v
class CellExecuteResultOutput(CellDisplayDataOutput):
    """Cell output of type `execute_result`."""

    execution_count: PositiveInt

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Rich display of executed cell outputs: an `Out [n]:` prompt, then the data."""
        yield Text(f"Out [{self.execution_count or ' '}]:", style="out_count")
        yield from self.rich_output

    @validator("output_type")
    def output_type_must_match(cls, v: str) -> str:
        """
        Check that the output has `execute_result` type.

        Named like the parent class's validator - presumably so it overrides the
        parent's `display_data` check (which would otherwise reject this type); keep
        the two names in sync.
        """
        if v != "execute_result":
            raise ValueError(
                f"Invalid output type. Expected `execute_result`, got {v}."
            )
        return v
class CellErrorOutput(DatabooksBase):
    """Cell output of type `error`."""

    output_type: str
    ename: str
    evalue: str
    traceback: List[str]

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Rich display of error outputs: the traceback lines, parsed from ANSI text."""
        return Text.from_ansi("\n".join(self.traceback))

    @validator("output_type")
    def output_type_must_match(cls, v: str) -> str:
        """Check that the output has `error` type."""
        if v != "error":
            raise ValueError(f"Invalid output type. Expected `error`, got {v}.")
        return v
# Every output model a code cell may contain (listed in the original order).
CellOutputType = Union[
    CellStreamOutput,
    CellDisplayDataOutput,
    CellExecuteResultOutput,
    CellErrorOutput,
]
class CellOutputs(DatabooksBase):
    """Collection of the outputs stored on a notebook code cell."""

    __root__: List[CellOutputType]

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Render each stored output, in order."""
        for output in self.values:
            yield output

    @property
    def values(
        self,
    ) -> List[CellOutputType]:
        """Shortcut to the list of outputs held in `__root__`."""
        return self.__root__
class CodeCell(BaseCell):
    """Cell of type `code` - defined for rich displaying in terminal."""

    outputs: CellOutputs
    cell_type: str = "code"

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """
        Rich display of code cells.

        Renders an `In [n]:` prompt, the source as highlighted code inside a panel,
        and finally the cell outputs.
        """
        # `execution_count` is not a declared field (it is an extra field - see
        #  `BaseCell`), so fall back to a blank prompt when it is absent
        execution_count = getattr(self, "execution_count", None)
        yield Text(f"In [{execution_count or ' '}]:", style="in_count")
        source = "".join(self.source) if isinstance(self.source, list) else self.source
        yield Panel(
            Syntax(
                source,
                # Lexer taken from the metadata's `lang` entry, defaulting to "text"
                getattr(self.metadata, "lang", "text"),
            )
        )
        yield self.outputs

    @validator("cell_type")
    def cell_has_code_type(cls, v: str) -> str:
        """Check that the cell type is `code`."""
        if v != "code":
            raise ValueError(f"Expected code of type `code`, got `{v}`.")
        return v
class MarkdownCell(BaseCell):
    """Cell of type `markdown` - defined for rich displaying in terminal."""

    cell_type: str = "markdown"

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Rich display of markdown cells: the rendered markdown inside a panel."""
        # `source` may be a list of lines or a single string (see `BaseCell`)
        source = "".join(self.source) if isinstance(self.source, list) else self.source
        return Panel(Markdown(source))

    @validator("cell_type")
    def cell_has_md_type(cls, v: str) -> str:
        """Check that the cell type is `markdown`."""
        if v != "markdown":
            raise ValueError(f"Expected code of type `markdown`, got {v}.")
        return v
class RawCell(BaseCell):
    """Cell of type `raw` - defined for rich displaying in terminal."""

    cell_type: str = "raw"

    def __rich__(
        self,
    ) -> ConsoleRenderable:
        """Rich display of raw cells: the plain source text inside a panel."""
        # `source` may be a list of lines or a single string (see `BaseCell`)
        source = "".join(self.source) if isinstance(self.source, list) else self.source
        return Panel(Text(source))

    # NOTE: name kept as `cell_has_md_type` (not `..._raw_type`) for compatibility
    @validator("cell_type")
    def cell_has_md_type(cls, v: str) -> str:
        """Check that the cell type is `raw`."""
        if v != "raw":
            raise ValueError(f"Expected code of type `raw`, got {v}.")
        return v