Coverage for databooks/data_models/base.py: 87%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Data models - Base Pydantic model with custom methods."""
2from __future__ import annotations
4from abc import abstractmethod
5from collections import UserList
6from typing import (
7 Any,
8 Dict,
9 Generic,
10 Iterable,
11 List,
12 Protocol,
13 TypeVar,
14 cast,
15 overload,
16 runtime_checkable,
17)
19from pydantic import BaseModel, Extra, create_model
# Type variable shared by the generic cell container (`BaseCells`) and the
# `resolve` overloads, where it stands for the type of the resolved cells.
T = TypeVar("T")
24@runtime_checkable
25class DiffModel(Protocol, Iterable):
26 """Protocol for mypy static type checking."""
28 is_diff: bool
30 def resolve(self, *args: Any, **kwargs: Any) -> DatabooksBase:
31 """Return a valid base object."""
32 ...
class BaseCells(UserList, Generic[T]):
    """
    Abstract list-like container of notebook cells.

    Subclasses are expected to implement `resolve`, which turns a container of cell
    differences back into valid notebook cells.
    """

    @abstractmethod
    def resolve(self, **kwargs: Any) -> list:
        """Build valid notebook cells out of the stored differences."""
        raise NotImplementedError
@overload
def resolve(
    model: DiffModel,
    **kwargs: Any,
) -> DatabooksBase:
    ...


@overload
def resolve(
    model: BaseCells[T],
    **kwargs: Any,
) -> List[T]:
    ...


def resolve(
    model: DiffModel | BaseCells,
    *,
    keep_first: bool = True,
    ignore_none: bool = True,
    **kwargs: Any,
) -> DatabooksBase | List[T]:
    """
    Resolve differences for 'diff models'.

    Return instance alike the parent class `databooks.data_models.base.DatabooksBase`.

    :param model: DiffModel that is to be resolved (self when added as a method to a
     class)
    :param keep_first: Whether to keep the information from the prior in the
     'diff model' or the later
    :param ignore_none: Whether or not to ignore `None` values if encountered, and
     use the other field value
    :param kwargs: Extra keyword arguments forwarded to the `resolve` method of nested
     diff models/cells
    :return: Model with selected fields from the differences
    :raise TypeError: If `model` is not a 'diff model' (`is_diff` missing or falsy)
    """
    field_d = dict(model)
    # Default to `False` so that plain (non-diff) models raise the intended
    # `TypeError` below instead of a bare `KeyError`.
    is_diff = field_d.pop("is_diff", False)
    if not is_diff:
        raise TypeError("Can only resolve dynamic 'diff models' (when `is_diff=True`).")

    res_vals = cast(Dict[str, Any], {})
    for name, value in field_d.items():
        if isinstance(value, (DiffModel, BaseCells)):
            # Nested differences know how to resolve themselves
            res_vals[name] = value.resolve(
                keep_first=keep_first, ignore_none=ignore_none, **kwargs
            )
        else:
            # Leaf fields hold the `(self_value, other_value)` pair built by
            # `DatabooksBase.__sub__`: pick the preferred side, falling back to the
            # other one when the preferred value is `None` and `ignore_none` is set.
            first, last = value[0], value[1]
            preferred, fallback = (first, last) if keep_first else (last, first)
            res_vals[name] = (
                fallback if preferred is None and ignore_none else preferred
            )

    # Instantiate the class the diff model was derived from (its first parent)
    return type(model).mro()[1](**res_vals)
class DatabooksBase(BaseModel):
    """Base Pydantic class with extras on managing fields."""

    class Config:
        """Default configuration for base class."""

        # Keep fields that are not declared on the model instead of rejecting them
        extra = Extra.allow

    def remove_fields(
        self,
        fields: Iterable[str],
        *,
        recursive: bool = False,
        missing_ok: bool = False,
    ) -> None:
        """
        Remove selected fields.

        :param fields: Fields to remove
        :param recursive: Whether or not to remove the fields recursively in case of
         nested models - when a field holds a nested
         `databooks.data_models.base.DatabooksBase`, the same fields are removed from
         that nested model (at any depth) instead of removing the field itself
        :param missing_ok: Whether to silently skip fields that are not present
         (otherwise a `KeyError` is raised)
        :return:
        """
        # Materialize once: `fields` is iterated here and again by the nested calls,
        # so a one-shot iterator (e.g. a generator) would otherwise be exhausted early.
        fields = list(fields)
        d_model = dict(self)
        for field in fields:
            field_val = d_model.get(field) if missing_ok else d_model[field]
            if recursive and isinstance(field_val, DatabooksBase):
                # Propagate the options so deeper nesting and missing fields in
                # nested models are handled the same way as at this level.
                field_val.remove_fields(
                    fields, recursive=recursive, missing_ok=missing_ok
                )
            elif field in d_model:
                delattr(self, field)

    def __str__(self) -> str:
        """Return outputs of __repr__."""
        return repr(self)

    def __sub__(self, other: DatabooksBase) -> DatabooksBase:
        """
        Subtraction between `databooks.data_models.base.DatabooksBase` objects.

        The difference returns a dynamically created 'diff model' where each field is
        replaced by a tuple `field = (self_value, other_value)`. Fields holding nested
        models/cells of the same type are diffed recursively instead, and fields that
        exist on only one side are paired with `None`.

        :param other: Model of the exact same type to compare against
        :return: Instance of a dynamically created `Diff<ClassName>` subclass of
         `type(self)`, with `is_diff=True` and a `resolve` method
        :raise TypeError: If `self` and `other` are not of the same type
        """
        if type(self) is not type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # Get field and values for each instance
        self_d = dict(self)
        other_d = dict(other)

        # Build dict with {field: (type, value)} for each field. Field order follows
        # `self`, then fields only present in `other`; a set union would make the
        # order (and thus the diff model's repr) vary between interpreter runs.
        fields_d = {}
        for name in dict.fromkeys([*self_d, *other_d]):
            self_val = self_d.get(name)
            other_val = other_d.get(name)
            if type(self_val) is type(other_val) and all(
                isinstance(val, (DatabooksBase, BaseCells))
                for val in (self_val, other_val)
            ):
                # Recursively get the diffs for nested models
                fields_d[name] = (Any, self_val - other_val)  # type: ignore
            else:
                fields_d[name] = (tuple, (self_val, other_val))

        # Build Pydantic models dynamically (named so it does not shadow the
        # module-level `DiffModel` protocol)
        diff_model_cls = create_model(
            "Diff" + type(self).__name__,
            __base__=type(self),
            resolve=resolve,
            is_diff=True,
            **cast(Dict[str, Any], fields_d),
        )
        return diff_model_cls()  # it'll be filled in with the defaults