Coverage for databooks/data_models/base.py: 88%

58 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-10-03 12:27 +0000

1"""Data models - Base Pydantic model with custom methods.""" 

2from __future__ import annotations 

3 

4from abc import abstractmethod 

5from collections import UserList 

6from typing import Any, Dict, Generic, Iterable, List, TypeVar, cast, overload 

7 

8from pydantic import BaseModel, ConfigDict, create_model 

9from typing_extensions import Protocol, runtime_checkable 

10 

11T = TypeVar("T") 

12 

13 

@runtime_checkable
class DiffModel(Protocol, Iterable):
    """
    Structural type for 'diff model' objects.

    Exists so static type checkers (mypy) know which members a diff model exposes;
    being runtime-checkable, it can also be used with `isinstance`.
    """

    # Flag attribute every diff model is expected to carry.
    is_diff: bool

    def resolve(self, *args: Any, **kwargs: Any) -> DatabooksBase:
        """Return a valid base object (protocol method, no implementation here)."""
        ...

22 

23 

class BaseCells(UserList, Generic[T]):
    """
    List-like base class for notebook cells.

    Concrete subclasses are expected to provide their own `resolve`.
    """

    @abstractmethod
    def resolve(self, **kwargs: Any) -> list:
        """
        Build the valid notebook cells out of the differences.

        :param kwargs: Options consumed by the concrete implementation
        :return: List with the resolved cells
        :raises NotImplementedError: Always, in this base implementation
        """
        raise NotImplementedError()

31 

32 

@overload
def resolve(model: DiffModel, **kwargs: Any) -> DatabooksBase:
    ...


@overload
def resolve(model: BaseCells, **kwargs: Any) -> List[T]:
    ...


def resolve(
    model: DiffModel | BaseCells,
    *,
    keep_first: bool = True,
    ignore_none: bool = True,
    **kwargs: Any,
) -> DatabooksBase | List[T]:
    """
    Collapse a 'diff model' back into a regular model.

    Every field of a diff model other than `is_diff` is either a nested diff
    (resolved recursively through its own `resolve`) or a pair of values that is
    indexed with `keep_first` / `not keep_first` to pick a single one.

    :param model: Diff model to resolve (`self` when this function is attached to a
     class as a method)
    :param keep_first: Prefer the first value of each pair when `True`, the second
     one otherwise
    :param ignore_none: When the preferred value is `None`, fall back to the other
     value of the pair
    :param kwargs: Extra keyword arguments forwarded to nested `resolve` calls
    :return: Instance of the second class in the MRO of `type(model)`, built from
     the selected values
    :raises TypeError: If `is_diff` is falsy
    """
    pending = dict(model)
    if not pending.pop("is_diff"):
        raise TypeError("Can only resolve dynamic 'diff models' (when `is_diff=True`).")

    resolved: Dict[str, Any] = {}
    for field_name, field_value in pending.items():
        if isinstance(field_value, (DiffModel, BaseCells)):
            # Nested diffs know how to resolve themselves
            resolved[field_name] = field_value.resolve(
                keep_first=keep_first, ignore_none=ignore_none, **kwargs
            )
            continue

        # Indexing with a bool: `False` -> element 0, `True` -> element 1
        preferred = field_value[not keep_first]
        if preferred is None and ignore_none:
            resolved[field_name] = field_value[keep_first]
        else:
            resolved[field_name] = preferred

    # The entry right after the model's own class in the MRO is instantiated
    parent_cls = type(model).mro()[1]
    return parent_cls(**resolved)

87 

88 

class DatabooksBase(BaseModel):
    """Base Pydantic class with extras on managing fields."""

    # Accept (and keep) fields that are not explicitly declared on the model
    model_config = ConfigDict(extra="allow")

    def remove_fields(
        self,
        fields: Iterable[str],
        *,
        recursive: bool = False,
        missing_ok: bool = False,
    ) -> None:
        """
        Remove selected fields.

        :param fields: Names of the fields to remove
        :param recursive: When a named field holds a nested `DatabooksBase`, remove
         `fields` from that nested model instead of deleting the field itself. The
         nested call is made with the default `recursive` and `missing_ok` values
        :param missing_ok: Whether to silently skip names that are not present on the
         model; when `False` a missing name raises `KeyError`
        :return:
        """
        # Materialise once: the same names are handed to nested models below, and a
        # one-shot iterator would otherwise be exhausted by the first nested call,
        # silently skipping the remaining fields of this model.
        field_names = list(fields)
        d_model = dict(self)
        for field in field_names:
            field_val = d_model.get(field) if missing_ok else d_model[field]
            if recursive and isinstance(field_val, DatabooksBase):
                field_val.remove_fields(field_names)
            elif field in d_model:
                delattr(self, field)

    def __str__(self) -> str:
        """Return outputs of __repr__."""
        return repr(self)

    def __sub__(self, other: DatabooksBase) -> DiffModel:
        """
        Subtraction between `databooks.data_models.base.DatabooksBase` objects.

        The difference returns a dynamically created model where each field is
        replaced by a tuple, `field = (self_value, other_value)`. Fields whose values
        on both sides are `DatabooksBase` or `BaseCells` instances of the same type
        hold the nested difference (`self_value - other_value`) instead.

        :param other: Model to compare against, must be of the exact same type
        :return: Instance of the generated `Diff<ClassName>` model
        :raises TypeError: If `self` and `other` are not of the same type
        """
        if type(self) is not type(other):
            raise TypeError(
                f"Unsupported operand types for `-`: `{type(self).__name__}` and"
                f" `{type(other).__name__}`"
            )

        # Get field and values for each instance
        self_d = dict(self)
        other_d = dict(other)

        # Build dict with {field: (type, value)} for each field
        fields_d: Dict[str, Any] = {}
        # Ordered union of the names (first seen wins); a plain set union would make
        # the field order depend on string hashing and thus vary between runs.
        for name in dict.fromkeys((*self_d, *other_d)):
            self_val = self_d.get(name)
            other_val = other_d.get(name)
            if type(self_val) is type(other_val) and all(
                isinstance(val, (DatabooksBase, BaseCells))
                for val in (self_val, other_val)
            ):
                # Recursively get the diffs for nested models
                fields_d[name] = (Any, self_val - other_val)  # type: ignore
            else:
                fields_d[name] = (tuple, (self_val, other_val))

        # Build Pydantic models dynamically
        DiffInstance = create_model(
            "Diff" + type(self).__name__,
            __base__=type(self),
            resolve=resolve,
            is_diff=(bool, True),
            **fields_d,
        )

        return cast(DiffModel, DiffInstance())  # it'll be filled in with the defaults