Coverage for databooks/data_models/base.py: 88%
58 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-10-03 12:27 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-10-03 12:27 +0000
1"""Data models - Base Pydantic model with custom methods."""
2from __future__ import annotations
4from abc import abstractmethod
5from collections import UserList
6from typing import Any, Dict, Generic, Iterable, List, TypeVar, cast, overload
8from pydantic import BaseModel, ConfigDict, create_model
9from typing_extensions import Protocol, runtime_checkable
# Generic type parameter of `BaseCells` (presumably the cell element type - TODO
# confirm); also used as the element type in the `resolve` overload returning `List[T]`.
T = TypeVar("T")
@runtime_checkable
class DiffModel(Protocol, Iterable):
    """
    Structural type of the dynamic 'diff models' produced by model subtraction.

    Serves mypy static type checking and, being runtime-checkable, `isinstance`
    checks: an object matches when it is iterable and exposes both an `is_diff`
    attribute and a `resolve` method.
    """

    # Marker flag set on models that hold differences rather than plain values
    is_diff: bool

    def resolve(self, *args: Any, **kwargs: Any) -> DatabooksBase:
        """Turn the stored differences back into a valid base object."""
        ...
class BaseCells(UserList, Generic[T]):
    """
    Abstract list-like container of notebook cells.

    Concrete subclasses must provide `resolve`, which builds regular notebook
    cells out of the differences they hold.
    """

    @abstractmethod
    def resolve(self, **kwargs: Any) -> list:
        """Build the list of valid notebook cells from the stored differences."""
        raise NotImplementedError()
@overload
def resolve(
    model: DiffModel,
    **kwargs: Any,
) -> DatabooksBase:
    ...


@overload
def resolve(
    model: BaseCells,
    **kwargs: Any,
) -> List[T]:
    ...


def resolve(
    model: DiffModel | BaseCells,
    *,
    keep_first: bool = True,
    ignore_none: bool = True,
    **kwargs: Any,
) -> DatabooksBase | List[T]:
    """
    Resolve differences for 'diff models'.

    Return an instance of the class the 'diff model' was derived from (the class
    right after it in the MRO, e.g. a `databooks.data_models.base.DatabooksBase`
    subclass). Each field of a 'diff model' holds either a nested diff, which is
    resolved recursively, or a `(first_value, last_value)` tuple from which one
    value is selected.

    :param model: DiffModel that is to be resolved (`self` when added as a method to
     a class), or `BaseCells`, whose own `resolve` method is delegated to
    :param keep_first: Whether to keep the information from the prior in the
     'diff model' or the latter
    :param ignore_none: Whether to ignore `None` values if encountered, and use the
     other field value
    :return: Model with selected fields from the differences (a list of cells when
     `model` is `BaseCells`)
    :raise TypeError: If `model` is not a 'diff model' (`is_diff` missing or falsy)
    """
    if isinstance(model, BaseCells):
        # Cells carry their own resolution logic; `dict(model)` cannot be built
        # from a list of cells.
        return model.resolve(
            keep_first=keep_first, ignore_none=ignore_none, **kwargs
        )

    field_d = dict(model)
    # A regular (non-diff) model has no `is_diff` field: report it with the same
    # `TypeError` as an explicit `is_diff=False` rather than a bare `KeyError`.
    is_diff = field_d.pop("is_diff", False)
    if not is_diff:
        raise TypeError("Can only resolve dynamic 'diff models' (when `is_diff=True`).")

    res_vals: Dict[str, Any] = {}
    for name, value in field_d.items():
        if isinstance(value, (DiffModel, BaseCells)):
            res_vals[name] = value.resolve(
                keep_first=keep_first, ignore_none=ignore_none, **kwargs
            )
        else:
            # Plain fields are `(first_value, last_value)` pairs.
            first, last = value[0], value[1]
            kept, other = (first, last) if keep_first else (last, first)
            # Fall back to the other value only when the kept one is `None`.
            res_vals[name] = other if kept is None and ignore_none else kept

    # Diff models are created with `__base__=type(original)`, so the class right
    # after the diff class in the MRO is the original model class.
    return type(model).mro()[1](**res_vals)
class DatabooksBase(BaseModel):
    """Base Pydantic class with extras on managing fields."""

    model_config = ConfigDict(extra="allow")

    def remove_fields(
        self,
        fields: Iterable[str],
        *,
        recursive: bool = False,
        missing_ok: bool = False,
    ) -> None:
        """
        Remove selected fields.

        :param fields: Fields to remove
        :param recursive: Whether to remove the fields recursively in case of nested
         models - when a selected field holds a `DatabooksBase` model, the same
         fields are removed from within that model (at every nesting level)
         instead of removing the field itself
        :param missing_ok: Whether missing fields are tolerated; when `False`, a
         `KeyError` is raised for a selected field that is not present
        :return:
        """
        # Materialize once: the same fields are reused by every recursive call, and
        # a one-shot iterator would otherwise already be consumed.
        fields = tuple(fields)
        d_model = dict(self)
        for field in fields:
            field_val = d_model.get(field) if missing_ok else d_model[field]
            if recursive and isinstance(field_val, DatabooksBase):
                # Propagate the options so deeper levels are recursed into as well,
                # and fields absent from the nested model honour `missing_ok`.
                field_val.remove_fields(
                    fields, recursive=recursive, missing_ok=missing_ok
                )
            elif field in d_model:
                delattr(self, field)

    def __str__(self) -> str:
        """Return outputs of __repr__."""
        return repr(self)

    def __sub__(self, other: DatabooksBase) -> DiffModel:
        """
        Subtraction between `databooks.data_models.base.DatabooksBase` objects.

        The difference returns a dynamically created 'diff model' (a subclass of the
        operands' class) in which each field is replaced by a tuple
        `field = (self_value, other_value)`, or - when both values are models or
        cells of the same type - by their recursive difference.

        :param other: Model of the exact same class as `self`
        :return: 'Diff model' with `is_diff=True` and a `resolve` method
        :raise TypeError: If `self` and `other` are not of the same class
        """
        if type(self) is not type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # Get field and values for each instance
        self_d = dict(self)
        other_d = dict(other)

        # Union of field names in a deterministic order: `self`'s fields first, then
        # those only present in `other`. A set union would order them by string
        # hash, which changes between interpreter runs (hash randomization).
        names = dict.fromkeys([*self_d, *other_d])

        # Build dict with {field: (type, value)} for each field
        fields_d: Dict[str, Any] = {}
        for name in names:
            self_val = self_d.get(name)
            other_val = other_d.get(name)
            if type(self_val) is type(other_val) and all(
                isinstance(val, (DatabooksBase, BaseCells))
                for val in (self_val, other_val)
            ):
                # Recursively get the diffs for nested models
                fields_d[name] = (Any, self_val - other_val)  # type: ignore
            else:
                fields_d[name] = (tuple, (self_val, other_val))

        # Build Pydantic models dynamically; `resolve` is passed as a plain
        # (non-tuple) attribute so that it ends up as a method of the diff class.
        DiffInstance = create_model(
            "Diff" + type(self).__name__,
            __base__=type(self),
            resolve=resolve,
            is_diff=(bool, True),
            **fields_d,
        )

        return cast(DiffModel, DiffInstance())  # it'll be filled in with the defaults