Coverage for databooks/affirm.py: 98%

84 statements  

coverage.py v7.2.7, created at 2023-10-03 12:27 +0000

1"""Functions to safely evaluate strings and inspect notebook.""" 

2import ast 

3from collections import abc 

4from copy import deepcopy 

5from itertools import compress 

6from pathlib import Path 

7from typing import Any, Callable, Dict, Iterable, List, Tuple 

8 

9from databooks import JupyterNotebook 

10from databooks.data_models.base import DatabooksBase 

11from databooks.logging import get_logger, set_verbose 

12 

13logger = get_logger(__file__) 

14 

_ALLOWED_BUILTINS = (
    all,
    any,
    enumerate,
    filter,
    getattr,
    hasattr,
    len,
    list,
    range,
    sorted,
)
_ALLOWED_NODES = (
    ast.Add,
    ast.And,
    ast.BinOp,
    ast.BitAnd,
    ast.BitOr,
    ast.BitXor,
    ast.BoolOp,
    ast.boolop,
    ast.cmpop,
    ast.Compare,
    ast.comprehension,
    ast.Constant,
    ast.Dict,
    ast.Eq,
    ast.Expr,
    ast.expr,
    ast.expr_context,
    ast.Expression,
    ast.For,
    ast.Gt,
    ast.GtE,
    ast.In,
    ast.Is,
    ast.IsNot,
    ast.List,
    ast.ListComp,
    ast.Load,
    ast.LShift,
    ast.Lt,
    ast.LtE,
    ast.Mod,
    ast.Name,
    ast.Not,
    ast.NotEq,
    ast.NotIn,
    ast.Num,
    ast.operator,
    ast.Or,
    ast.RShift,
    ast.Set,
    ast.Slice,
    ast.slice,
    ast.Str,
    ast.Sub,
    ast.Subscript,
    ast.Tuple,
    ast.UAdd,
    ast.UnaryOp,
    ast.unaryop,
    ast.USub,
)


class DatabooksParser(ast.NodeVisitor):
    """AST parser that disallows unsafe nodes/values."""

    def __init__(self, **variables: Any) -> None:
        """Instantiate with variables and callables (built-ins) scope."""
        # https://github.com/python/mypy/issues/3728
        self.builtins = {b.__name__: b for b in _ALLOWED_BUILTINS}  # type: ignore
        self.names = deepcopy(variables) or {}
        self.scope = {
            **self.names,
            "__builtins__": self.builtins,
        }

    @staticmethod
    def _prioritize(field: Tuple[str, Any]) -> bool:
        """Prioritize `ast.comprehension` nodes when expanding the AST tree."""
        _, value = field
        if not isinstance(value, list):
            return True
        return not any(isinstance(f, ast.comprehension) for f in value)

    @staticmethod
    def _allowed_attr(obj: Any, attr: str, is_dynamic: bool = False) -> None:
        """
        Check that attribute is a key of `databooks.data_models.base.DatabooksBase`.

        If `obj` is an iterable and was computed dynamically (that is, not originally
        in scope but computed from a comprehension), check attributes for all elements
        in the iterable.
        """
        allowed_attrs = list(dict(obj).keys()) if isinstance(obj, DatabooksBase) else ()
        if isinstance(obj, abc.Iterable) and is_dynamic:
            for el in obj:
                DatabooksParser._allowed_attr(obj=el, attr=attr)
        else:
            if attr not in allowed_attrs:
                raise ValueError(
                    "Expected attribute to be one of"
                    f" `{allowed_attrs}`, got `{attr}` for {obj}."
                )

    def _get_iter(self, node: ast.AST) -> Iterable:
        """Use `DatabooksParser.safe_eval_ast` to get the iterable object."""
        tree = ast.Expression(body=node)
        return iter(self.safe_eval_ast(tree))

    def generic_visit(self, node: ast.AST) -> None:
        """
        Prioritize `ast.comprehension` nodes when expanding the tree.

        Similar to `NodeVisitor.generic_visit`, but favor comprehensions when there
        are multiple nodes on the same level. In comprehensions, we have a generator
        argument that includes names that are stored. By visiting them first we avoid
        'running into' unknown names.
        """
        if not isinstance(node, _ALLOWED_NODES):
            raise ValueError(f"Invalid node `{node}`.")

        for field, value in sorted(ast.iter_fields(node), key=self._prioritize):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)
            elif isinstance(value, ast.AST):
                self.visit(value)

    def visit_comprehension(self, node: ast.comprehension) -> None:
        """Add a variable from a comprehension to the list of allowed names."""
        if not isinstance(node.target, ast.Name):
            raise RuntimeError(
                "Expected `ast.comprehension`'s target to be `ast.Name`, got"
                f" `ast.{type(node.target).__name__}`."
            )
        self.names[node.target.id] = self._get_iter(node.iter)
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Allow attributes for Pydantic fields only."""
        if not isinstance(node.value, (ast.Attribute, ast.Name, ast.Subscript)):
            raise ValueError(
                "Expected attribute to be one of `ast.Name`, `ast.Attribute` or"
                f" `ast.Subscript`, got `ast.{type(node.value).__name__}`."
            )
        if isinstance(node.value, ast.Name):
            self._allowed_attr(
                obj=self.names[node.value.id],
                attr=node.attr,
                is_dynamic=node.value.id in (self.names.keys() - self.scope.keys()),
            )
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        """Only allow names from the scope or comprehension variables."""
        valid_names = {**self.names, **self.builtins}
        if node.id not in valid_names:
            raise ValueError(
                f"Expected `name` to be one of `{valid_names.keys()}`, got `{node.id}`."
            )
        self.generic_visit(node)

    def safe_eval_ast(self, ast_tree: ast.AST) -> Any:
        """Evaluate safe AST trees only (raise errors otherwise)."""
        self.visit(ast_tree)
        exe = compile(ast_tree, filename="", mode="eval")
        return eval(exe, self.scope)

    def safe_eval(self, src: str) -> Any:
        """
        Evaluate only strings that are safe (raise errors otherwise).

        A "safe" string or node provided may only consist of nodes in
        `databooks.affirm._ALLOWED_NODES` and built-ins from
        `databooks.affirm._ALLOWED_BUILTINS`.
        """
        ast_tree = ast.parse(src, mode="eval")
        return self.safe_eval_ast(ast_tree)

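For illustration only (this snippet is not part of `databooks/affirm.py`), a minimal sketch of how the parser behaves; the variable `x` and the expressions are made up for the example:

parser = DatabooksParser(x=[1, 2, 3])
parser.safe_eval("len(x) > 2")         # True: `len` is a whitelisted built-in
parser.safe_eval("sorted(x)[0] == 1")  # True: names, calls, subscripts and comparisons pass
parser.safe_eval("__import__('os')")   # raises ValueError: `__import__` is not a known name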

def affirm(nb_path: Path, exprs: List[str], verbose: bool = False) -> bool:
    """
    Return whether the notebook passed all checks (expressions).

    :param nb_path: Path of the notebook file
    :param exprs: Expressions with checks to be evaluated on the notebook
    :param verbose: Log failed checks for the notebook
    :return: Whether all expressions evaluate to `True` (cast as `bool`)
    """
    if verbose:
        set_verbose(logger)

    nb = JupyterNotebook.parse_file(nb_path)
    variables: Dict[str, Any] = {
        "nb": nb,
        "raw_cells": [c for c in nb.cells if c.cell_type == "raw"],
        "md_cells": [c for c in nb.cells if c.cell_type == "markdown"],
        "code_cells": [c for c in nb.cells if c.cell_type == "code"],
        "exec_cells": [
            c
            for c in nb.cells
            if c.cell_type == "code" and c.execution_count is not None
        ],
    }
    databooks_parser = DatabooksParser(**variables)
    is_ok = [bool(databooks_parser.safe_eval(expr)) for expr in exprs]
    n_fail = sum(not ok for ok in is_ok)

    logger.info(f"{nb_path} failed {n_fail} of {len(is_ok)} checks.")
    logger.debug(
        str(nb_path)
        + (
            f" failed {list(compress(exprs, (not ok for ok in is_ok)))}."
            if n_fail > 0
            else " succeeded all checks."
        )
    )
    return all(is_ok)

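Again for illustration only, a sketch of calling `affirm` directly; the notebook path is hypothetical, and the expressions use the variables exposed above (`nb`, `raw_cells`, `md_cells`, `code_cells`, `exec_cells`):

from pathlib import Path

ok = affirm(
    Path("notebooks/demo.ipynb"),  # hypothetical notebook path
    exprs=[
        "len(code_cells) > 0",                 # notebook has at least one code cell
        "len(exec_cells) == len(code_cells)",  # every code cell was executed
    ],
    verbose=True,
)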

239def affirm_all( 

240 nb_paths: List[Path], 

241 *, 

242 progress_callback: Callable[[], None] = lambda: None, 

243 **affirm_kwargs: Any, 

244) -> List[bool]: 

245 """ 

246 Clear metadata for multiple notebooks at notebooks and cell level. 

247 

248 :param nb_paths: Paths of notebooks to assert metadata 

249 :param progress_callback: Callback function to report progress 

250 :param affirm_kwargs: Keyword arguments to be passed to `databooks.affirm.affirm` 

251 :return: Whether the notebooks contained or not the desired metadata 

252 """ 

253 checks = [] 

254 for nb_path in nb_paths: 

255 checks.append(affirm(nb_path, **affirm_kwargs)) 

256 progress_callback() 

257 return checks
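Finally, an illustrative sketch of `affirm_all` with a progress callback; the paths are hypothetical and `Path` is reused from the previous sketch:

paths = [Path("nb1.ipynb"), Path("nb2.ipynb")]
results = affirm_all(
    paths,
    progress_callback=lambda: print("checked one notebook"),
    exprs=["len(md_cells) > 0"],  # each notebook has at least one markdown cell
)
passed_all = all(results)  # True only if every notebook passed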