Coverage for databooks/data_models/base.py: 88%
59 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-11 20:30 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-11 20:30 +0000
1"""Data models - Base Pydantic model with custom methods."""
2from __future__ import annotations
4from abc import abstractmethod
5from collections import UserList
6from typing import Any, Dict, Generic, Iterable, List, TypeVar, cast, overload
8from pydantic import BaseModel, Extra, create_model
9from typing_extensions import Protocol, runtime_checkable
# Generic type parameter of `BaseCells` (presumably the cell type) and of the list
# returned by `resolve` when it is given a `BaseCells` instance
T = TypeVar("T")
@runtime_checkable
class DiffModel(Protocol, Iterable):
    """
    Structural type describing the dynamic 'diff models' (for mypy checking).

    A diff model is iterable, carries an `is_diff` flag and exposes a `resolve`
    method. Being `runtime_checkable`, `isinstance` checks only test for the
    presence of these members.
    """

    def resolve(self, *args: Any, **kwargs: Any) -> DatabooksBase:
        """Build a valid base object back from the stored differences."""
        ...

    is_diff: bool
class BaseCells(UserList, Generic[T]):
    """
    Abstract list-like base for collections of notebook cells.

    Concrete subclasses are expected to provide `resolve`; the base version only
    raises `NotImplementedError`.
    """

    @abstractmethod
    def resolve(self, **kwargs: Any) -> list:
        """Build and return valid notebook cells out of the cell differences."""
        raise NotImplementedError
@overload
def resolve(
    model: DiffModel,
    **kwargs: Any,
) -> DatabooksBase:
    ...


@overload
def resolve(
    model: BaseCells,
    **kwargs: Any,
) -> List[T]:
    ...


def resolve(
    model: DiffModel | BaseCells,
    *,
    keep_first: bool = True,
    ignore_none: bool = True,
    **kwargs: Any,
) -> DatabooksBase | List[T]:
    """
    Resolve differences for 'diff models'.

    Return an instance of the parent class of the diff model (i.e. the class of the
    models that were subtracted, such as `databooks.data_models.base.DatabooksBase`
    subclasses). This function is attached as the `resolve` method of the models
    built by `DatabooksBase.__sub__`, in which case `model` is `self`.

    :param model: DiffModel that is to be resolved (self when added as a method to a
     class)
    :param keep_first: Whether to keep the information from the prior in the
     'diff model' or the latter
    :param ignore_none: Whether to ignore `None` values if encountered, and use the
     other field value
    :param kwargs: Extra keyword arguments forwarded to nested `resolve` calls
    :return: Model with selected fields from the differences
    :raises TypeError: If `model` is not a 'diff model' (`is_diff` missing or falsy)
    """
    field_d = dict(model)
    # A model without the flag is not a diff model: raise the documented
    # `TypeError` instead of letting `pop` fail with a bare `KeyError`
    if not field_d.pop("is_diff", False):
        raise TypeError("Can only resolve dynamic 'diff models' (when `is_diff=True`).")

    res_vals: Dict[str, Any] = {}
    for name, value in field_d.items():
        if isinstance(value, (DiffModel, BaseCells)):
            # Nested diff models / cells know how to resolve themselves
            res_vals[name] = value.resolve(
                keep_first=keep_first, ignore_none=ignore_none, **kwargs
            )
        else:
            # Leaf fields hold the `(first_value, last_value)` pair of the diff
            first, last = value[0], value[1]
            kept, fallback = (first, last) if keep_first else (last, first)
            res_vals[name] = fallback if kept is None and ignore_none else kept

    # The diff model subclasses the compared model's class, so the first parent in
    # its MRO is the class to rebuild
    return type(model).mro()[1](**res_vals)
class DatabooksBase(BaseModel):
    """Base Pydantic class with extras on managing fields."""

    class Config:
        """Default configuration for base class."""

        # Accept and keep fields that are not declared on the model
        extra = Extra.allow

    def remove_fields(
        self,
        fields: Iterable[str],
        *,
        recursive: bool = False,
        missing_ok: bool = False,
    ) -> None:
        """
        Remove selected fields.

        :param fields: Fields to remove
        :param recursive: Whether to remove the fields recursively in case of nested
         models - when a selected field holds a nested
         `databooks.data_models.base.DatabooksBase`, the same `fields` are removed
         from inside it (at every depth) instead of removing the nested model itself
        :param missing_ok: Whether to silently skip missing fields; when `False`, a
         `KeyError` is raised for a field that is not present
        :return:
        """
        # Materialize once: `fields` is iterated here and again in nested calls, which
        # would otherwise see an already (partially) consumed iterator
        fields = tuple(fields)
        d_model = dict(self)
        for field in fields:
            field_val = d_model.get(field) if missing_ok else d_model[field]
            if recursive and isinstance(field_val, DatabooksBase):
                # Propagate the options so nested models honor `missing_ok` and
                # recursion continues below the first level
                field_val.remove_fields(
                    fields, recursive=recursive, missing_ok=missing_ok
                )
            elif field in d_model:
                delattr(self, field)

    def __str__(self) -> str:
        """Return outputs of __repr__."""
        return repr(self)

    def __sub__(self, other: DatabooksBase) -> DiffModel:
        """
        Subtraction between `databooks.data_models.base.DatabooksBase` objects.

        The difference returns a dynamically created 'diff model' that replaces each
        field by a tuple, where for each field we have
        `field = (self_value, other_value)`. Nested models (or cells) of the same
        type are subtracted recursively instead. The diff model has `is_diff=True`
        and a `resolve` method to build a regular model back from it.

        :param other: Model of the exact same type to compare against
        :return: Instance of the dynamically created diff model
        :raises TypeError: If `self` and `other` are not of the same type
        """
        if type(self) is not type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # Get field and values for each instance
        self_d = dict(self)
        other_d = dict(other)

        # Union of field names in a deterministic order: `self`'s fields first, then
        # those only present in `other`. A set union would order the names by their
        # per-process randomized string hash, so the field order of the diff model
        # (and thus the iteration order seen by `resolve`) would vary between runs.
        names = list(self_d) + [name for name in other_d if name not in self_d]

        # Build dict with {field: (type, value)} for each field
        fields_d: Dict[str, Any] = {}
        for name in names:
            self_val = self_d.get(name)
            other_val = other_d.get(name)
            if type(self_val) is type(other_val) and all(
                isinstance(val, (DatabooksBase, BaseCells))
                for val in (self_val, other_val)
            ):
                # Recursively get the diffs for nested models
                fields_d[name] = (Any, self_val - other_val)  # type: ignore
            else:
                fields_d[name] = (tuple, (self_val, other_val))

        # Build Pydantic models dynamically; `resolve` becomes a method (not a field)
        # and every field's default is the diff value computed above
        DiffInstance = create_model(
            "Diff" + type(self).__name__,
            __base__=type(self),
            resolve=resolve,
            is_diff=True,
            **fields_d,
        )
        return cast(DiffModel, DiffInstance())  # it'll be filled in with the defaults