Coverage for databooks/affirm.py: 98%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

83 statements  

1"""Functions to safely evaluate strings and inspect notebook.""" 

2import ast 

3from copy import deepcopy 

4from functools import reduce 

5from itertools import compress 

6from pathlib import Path 

7from typing import Any, Callable, Dict, Iterable, List, Tuple 

8 

9from databooks import JupyterNotebook 

10from databooks.data_models.base import DatabooksBase 

11from databooks.logging import get_logger, set_verbose 

12 

13logger = get_logger(__file__) 

14 

# Built-in callables that user-supplied expressions may invoke.
# `DatabooksParser.__init__` builds the evaluation scope's `__builtins__`
# from exactly this tuple, so anything not listed here is simply not in
# scope when `safe_eval` runs.
_ALLOWED_BUILTINS = (
    all,
    any,
    enumerate,
    filter,
    getattr,
    hasattr,
    len,
    list,
    range,
    sorted,
)

# Whitelist of AST node types accepted by `DatabooksParser.generic_visit`,
# which raises `ValueError` for any node that is not an instance of one of
# these. NOTE: membership is tested with `isinstance`, so the abstract base
# classes listed here (`ast.expr`, `ast.operator`, `ast.boolop`, `ast.cmpop`,
# `ast.unaryop`, `ast.slice`, `ast.expr_context`) also admit their concrete
# subclasses — e.g. `ast.Call` passes because it subclasses `ast.expr`. The
# effective restrictions therefore come from the overridden `visit_Name` /
# `visit_Attribute` checks and the reduced `__builtins__`, not from this
# tuple alone.
_ALLOWED_NODES = (
    ast.Add,
    ast.And,
    ast.BinOp,
    ast.BitAnd,
    ast.BitOr,
    ast.BitXor,
    ast.BoolOp,
    ast.boolop,
    ast.cmpop,
    ast.Compare,
    ast.comprehension,
    ast.Constant,
    ast.Dict,
    ast.Eq,
    ast.Expr,
    ast.expr,
    ast.expr_context,
    ast.Expression,
    ast.For,
    ast.Gt,
    ast.GtE,
    ast.In,
    ast.Is,
    ast.IsNot,
    ast.List,
    ast.ListComp,
    ast.Load,
    ast.LShift,
    ast.Lt,
    ast.LtE,
    ast.Mod,
    ast.Name,
    ast.Not,
    ast.NotEq,
    ast.NotIn,
    ast.Num,
    ast.operator,
    ast.Or,
    ast.RShift,
    ast.Set,
    ast.Slice,
    ast.slice,
    ast.Str,
    ast.Sub,
    ast.Tuple,
    ast.UAdd,
    ast.UnaryOp,
    ast.unaryop,
    ast.USub,
)

78 

79 

class DatabooksParser(ast.NodeVisitor):
    """AST parser that disallows unsafe nodes/values."""

    def __init__(self, **variables: Any) -> None:
        """Instantiate with variables and callables (built-ins) scope."""
        # https://github.com/python/mypy/issues/3728
        self.builtins = {b.__name__: b for b in _ALLOWED_BUILTINS}  # type: ignore
        # Deep-copy so evaluated expressions cannot mutate the caller's objects.
        self.names = deepcopy(variables) or {}
        # `scope` is the globals mapping handed to `eval`; replacing
        # `__builtins__` restricts the callables available to expressions.
        self.scope = {
            **self.names,
            "__builtins__": self.builtins,
        }

    @staticmethod
    def _prioritize(field: Tuple[str, Any]) -> bool:
        """Prioritize `ast.comprehension` nodes when expanding the AST tree."""
        # Sort key: `False` sorts before `True`, so fields whose value is a
        # list containing an `ast.comprehension` are visited first.
        _, value = field
        if not isinstance(value, list):
            return True
        return not any(isinstance(f, ast.comprehension) for f in value)

    def _get_iter(self, node: ast.AST) -> Iterable:
        """Use `DatabooksParser.safe_eval_ast` to get the iterable object."""
        # Wrap the sub-expression in an `ast.Expression` so it can be
        # validated and evaluated on its own.
        tree = ast.Expression(body=node)
        return iter(self.safe_eval_ast(tree))

    def generic_visit(self, node: ast.AST) -> None:
        """
        Prioritize `ast.comprehension` nodes when expanding tree.

        Similar to `NodeVisitor.generic_visit`, but favor comprehensions when multiple
        nodes on the same level. In comprehensions, we have a generator argument that
        includes names that are stored. By visiting them first we avoid 'running into'
        unknown names.
        """
        # Whitelist gate: every node reaching here must be an instance of an
        # allowed type (abstract bases admit their subclasses).
        if not isinstance(node, _ALLOWED_NODES):
            raise ValueError(f"Invalid node `{node}`.")

        # Visit comprehension-bearing fields first (see `_prioritize`) so
        # their target names are registered before sibling nodes use them.
        for field, value in sorted(ast.iter_fields(node), key=self._prioritize):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)
            elif isinstance(value, ast.AST):
                self.visit(value)

    def visit_comprehension(self, node: ast.comprehension) -> None:
        """Add variable from a comprehension to list of allowed names."""
        if not isinstance(node.target, ast.Name):
            raise RuntimeError(
                "Expected `ast.comprehension`'s target to be `ast.Name`, got"
                f" `ast.{type(node.target).__name__}`."
            )
        # If any elements in the comprehension are a `DatabooksBase` instance, then
        # pass down the attributes as valid
        iterable = self._get_iter(node.iter)
        databooks_el = [el for el in iterable if isinstance(el, DatabooksBase)]
        if databooks_el:
            d_attrs = reduce(lambda a, b: {**a, **b}, [dict(el) for el in databooks_el])
        # Registered unconditionally (note: this line is OUTSIDE the `if`):
        # when no `DatabooksBase` elements exist the target is bound to the
        # `...` sentinel, so `visit_Name` still accepts the loop variable.
        self.names[node.target.id] = DatabooksBase(**d_attrs) if databooks_el else ...
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Allow attributes for Pydantic fields only."""
        if not isinstance(node.value, (ast.Attribute, ast.Name)):
            raise ValueError(
                "Expected attribute to be one of `ast.Name` or `ast.Attribute`, got"
                f" `ast.{type(node.value).__name__}`"
            )
        # Only attributes taken directly off an `ast.Name` are checked here;
        # NOTE(review): for chained access (`a.b.c`) the outer attribute is
        # not validated at this level — the inner `Attribute` is checked when
        # `generic_visit` recurses into it. Confirm this is intended.
        if not isinstance(node.value, ast.Attribute):
            obj = self.names[node.value.id]
            allowed_attrs = dict(obj).keys() if isinstance(obj, DatabooksBase) else ()
            if node.attr not in allowed_attrs:
                raise ValueError(
                    "Expected attribute to be one of"
                    f" `{allowed_attrs}`, got `{node.attr}`"
                )
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        """Only allow names from scope or comprehension variables."""
        valid_names = {**self.names, **self.builtins}
        if node.id not in valid_names:
            raise ValueError(
                f"Expected `name` to be one of `{valid_names.keys()}`, got `{node.id}`."
            )
        self.generic_visit(node)

    def safe_eval_ast(self, ast_tree: ast.AST) -> Any:
        """Evaluate safe AST trees only (raise errors otherwise)."""
        # `visit` raises before anything is compiled/evaluated if the tree
        # contains disallowed nodes, names, or attributes.
        self.visit(ast_tree)
        exe = compile(ast_tree, filename="", mode="eval")
        return eval(exe, self.scope)

    def safe_eval(self, src: str) -> Any:
        """
        Evaluate strings that are safe only (raise errors otherwise).

        A "safe" string or node provided may only consist of nodes in
        `databooks.affirm._ALLOWED_NODES` and built-ins from
        `databooks.affirm._ALLOWED_BUILTINS`.
        """
        ast_tree = ast.parse(src, mode="eval")
        return self.safe_eval_ast(ast_tree)

184 

185 

def affirm(nb_path: Path, exprs: List[str], verbose: bool = False) -> bool:
    """
    Return whether notebook passed all checks (expressions).

    Each expression is evaluated with `DatabooksParser.safe_eval` against a
    scope of convenience variables describing the notebook, so checks can be
    written as e.g. ``len(code_cells) > 0`` instead of filtering ``nb.cells``.

    :param nb_path: Path of notebook file
    :param exprs: Expression with check to be evaluated on notebook
    :param verbose: Log failed tests for notebook
    :return: Evaluated expression cast as a `bool`
    """
    if verbose:
        set_verbose(logger)

    nb = JupyterNotebook.parse_file(nb_path)
    # Variables exposed to the user expressions via the parser's scope.
    variables: Dict[str, Any] = {
        "nb": nb,
        "raw_cells": [c for c in nb.cells if c.cell_type == "raw"],
        "md_cells": [c for c in nb.cells if c.cell_type == "markdown"],
        "code_cells": [c for c in nb.cells if c.cell_type == "code"],
        "exec_cells": [
            c
            for c in nb.cells
            if c.cell_type == "code" and c.execution_count is not None
        ],
    }
    databooks_parser = DatabooksParser(**variables)
    is_ok = [bool(databooks_parser.safe_eval(expr)) for expr in exprs]
    # Generator expression: no need to materialize a throwaway list for `sum`.
    n_fail = sum(not ok for ok in is_ok)

    logger.info(f"{nb_path} failed {n_fail} of {len(is_ok)} checks.")
    logger.debug(
        str(nb_path)
        + (
            # `compress` selects the expressions whose check evaluated falsy.
            f" failed {list(compress(exprs, (not ok for ok in is_ok)))}."
            if n_fail > 0
            else " succeeded all checks."
        )
    )
    return all(is_ok)

224 

225 

def affirm_all(
    nb_paths: List[Path],
    *,
    progress_callback: Callable[[], None] = lambda: None,
    **affirm_kwargs: Any,
) -> List[bool]:
    """
    Run `affirm` checks for multiple notebooks.

    (Docstring fixed: it previously claimed this function clears metadata,
    which was copy-pasted from an unrelated function.)

    :param nb_paths: Paths of notebooks to be checked
    :param progress_callback: Callback function to report progress
    :param affirm_kwargs: Keyword arguments to be passed to `databooks.affirm.affirm`
    :return: Whether each notebook passed all its checks
    """
    checks = []
    for nb_path in nb_paths:
        checks.append(affirm(nb_path, **affirm_kwargs))
        # Invoked once per notebook, after its checks complete.
        progress_callback()
    return checks