Coverage for databooks/affirm.py: 98%
84 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-10-03 12:27 +0000
1"""Functions to safely evaluate strings and inspect notebook."""
2import ast
3from collections import abc
4from copy import deepcopy
5from itertools import compress
6from pathlib import Path
7from typing import Any, Callable, Dict, Iterable, List, Tuple
9from databooks import JupyterNotebook
10from databooks.data_models.base import DatabooksBase
11from databooks.logging import get_logger, set_verbose
13logger = get_logger(__file__)
# Whitelist of built-in callables exposed to user expressions. These are made
# available (keyed by `__name__`) as the `__builtins__` of the evaluation scope
# in `DatabooksParser.__init__`; anything not listed here is unreachable.
_ALLOWED_BUILTINS = (
    all,
    any,
    enumerate,
    filter,
    getattr,
    hasattr,
    len,
    list,
    range,
    sorted,
)
# Whitelist of AST node types a user expression may contain. Enforced by
# `DatabooksParser.generic_visit`, which raises `ValueError` on any other node.
# Covers literals, comparisons, boolean/arithmetic/bitwise operators,
# subscripts/slices, and comprehensions — but deliberately no calls to
# arbitrary functions, assignments, imports, or lambdas.
# NOTE(review): `ast.Num` and `ast.Str` are deprecated aliases of
# `ast.Constant` on modern Python — presumably kept for older versions;
# confirm the supported Python range before removing.
_ALLOWED_NODES = (
    ast.Add,
    ast.And,
    ast.BinOp,
    ast.BitAnd,
    ast.BitOr,
    ast.BitXor,
    ast.BoolOp,
    ast.boolop,
    ast.cmpop,
    ast.Compare,
    ast.comprehension,
    ast.Constant,
    ast.Dict,
    ast.Eq,
    ast.Expr,
    ast.expr,
    ast.expr_context,
    ast.Expression,
    ast.For,
    ast.Gt,
    ast.GtE,
    ast.In,
    ast.Is,
    ast.IsNot,
    ast.List,
    ast.ListComp,
    ast.Load,
    ast.LShift,
    ast.Lt,
    ast.LtE,
    ast.Mod,
    ast.Name,
    ast.Not,
    ast.NotEq,
    ast.NotIn,
    ast.Num,
    ast.operator,
    ast.Or,
    ast.RShift,
    ast.Set,
    ast.Slice,
    ast.slice,
    ast.Str,
    ast.Sub,
    ast.Subscript,
    ast.Tuple,
    ast.UAdd,
    ast.UnaryOp,
    ast.unaryop,
    ast.USub,
)
class DatabooksParser(ast.NodeVisitor):
    """
    AST parser that disallows unsafe nodes/values.

    Walks an expression's AST and raises on any node type outside
    `_ALLOWED_NODES`, any name outside the provided variables (plus
    `_ALLOWED_BUILTINS` and comprehension targets), and any attribute
    access that is not a Pydantic field of a `DatabooksBase` object.
    Only after the full tree has been validated is it compiled and
    evaluated (see `safe_eval_ast`).
    """

    def __init__(self, **variables: Any) -> None:
        """Instantiate with variables and callables (built-ins) scope."""
        # https://github.com/python/mypy/issues/3728
        self.builtins = {b.__name__: b for b in _ALLOWED_BUILTINS}  # type: ignore
        # Deep-copied so that evaluation cannot mutate the caller's objects.
        # `names` also accumulates comprehension targets during the visit.
        self.names = deepcopy(variables) or {}
        # `scope` is the globals mapping handed to `eval`; restricting
        # `__builtins__` here is what keeps arbitrary built-ins unreachable.
        self.scope = {
            **self.names,
            "__builtins__": self.builtins,
        }

    @staticmethod
    def _prioritize(field: Tuple[str, Any]) -> bool:
        """Prioritize `ast.comprehension` nodes when expanding the AST tree."""
        # Used as a sort key: fields containing comprehensions sort first
        # (False < True), so their targets get registered before the
        # expressions that reference them are visited.
        _, value = field
        if not isinstance(value, list):
            return True
        return not any(isinstance(f, ast.comprehension) for f in value)

    @staticmethod
    def _allowed_attr(obj: Any, attr: str, is_dynamic: bool = False) -> None:
        """
        Check that attribute is a key of `databooks.data_models.base.DatabooksBase`.

        If `obj` is an iterable and was computed dynamically (that is, not originally in
        scope but computed from a comprehension), check attributes for all elements in
        the iterable.

        :param obj: Object whose attribute is being accessed
        :param attr: Attribute name requested by the expression
        :param is_dynamic: Whether `obj` came from a comprehension variable
        :raises ValueError: If `attr` is not a Pydantic field of `obj`
        """
        # For non-`DatabooksBase` objects the allowed set is empty, so any
        # attribute access on them is rejected below.
        allowed_attrs = list(dict(obj).keys()) if isinstance(obj, DatabooksBase) else ()
        if isinstance(obj, abc.Iterable) and is_dynamic:
            # Validate the attribute element-wise (e.g. for `cell in nb.cells`).
            for el in obj:
                DatabooksParser._allowed_attr(obj=el, attr=attr)
        else:
            if attr not in allowed_attrs:
                raise ValueError(
                    "Expected attribute to be one of"
                    f" `{allowed_attrs}`, got `{attr}` for {obj}."
                )

    def _get_iter(self, node: ast.AST) -> Iterable:
        """Use `DatabooksParser.safe_eval_ast` to get the iterable object."""
        # Wrapping in `ast.Expression` lets the comprehension's source be
        # validated and evaluated on its own before its values enter scope.
        tree = ast.Expression(body=node)
        return iter(self.safe_eval_ast(tree))

    def generic_visit(self, node: ast.AST) -> None:
        """
        Prioritize `ast.comprehension` nodes when expanding tree.

        Similar to `NodeVisitor.generic_visit`, but favor comprehensions when multiple
        nodes on the same level. In comprehensions, we have a generator argument that
        includes names that are stored. By visiting them first we avoid 'running into'
        unknown names.

        :raises ValueError: If `node` is not in `_ALLOWED_NODES`
        """
        if not isinstance(node, _ALLOWED_NODES):
            raise ValueError(f"Invalid node `{node}`.")

        for field, value in sorted(ast.iter_fields(node), key=self._prioritize):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)
            elif isinstance(value, ast.AST):
                self.visit(value)

    def visit_comprehension(self, node: ast.comprehension) -> None:
        """
        Add variable from a comprehension to list of allowed names.

        :raises RuntimeError: If the comprehension target is not a plain name
            (e.g. tuple unpacking), which this parser does not support
        """
        if not isinstance(node.target, ast.Name):
            raise RuntimeError(
                "Expected `ast.comprehension`'s target to be `ast.Name`, got"
                f" `ast.{type(node.target).__name__}`."
            )
        # Registering in `names` (but not `scope`) marks the variable as
        # dynamic — see `visit_Attribute`'s `is_dynamic` computation.
        self.names[node.target.id] = self._get_iter(node.iter)
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Allow attributes for Pydantic fields only."""
        if not isinstance(node.value, (ast.Attribute, ast.Name, ast.Subscript)):
            raise ValueError(
                "Expected attribute to be one of `ast.Name`, `ast.Attribute` or"
                f" `ast.Subscript`, got `ast.{type(node.value).__name__}`."
            )
        if isinstance(node.value, ast.Name):
            # A name in `names` but not `scope` was added by a comprehension,
            # so its value is an iterable to be checked element-wise.
            self._allowed_attr(
                obj=self.names[node.value.id],
                attr=node.attr,
                is_dynamic=node.value.id in (self.names.keys() - self.scope.keys()),
            )
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        """
        Only allow names from scope or comprehension variables.

        :raises ValueError: If the name is not a known variable or allowed built-in
        """
        valid_names = {**self.names, **self.builtins}
        if node.id not in valid_names:
            raise ValueError(
                f"Expected `name` to be one of `{valid_names.keys()}`, got `{node.id}`."
            )
        self.generic_visit(node)

    def safe_eval_ast(self, ast_tree: ast.AST) -> Any:
        """
        Evaluate safe AST trees only (raise errors otherwise).

        The tree is fully validated by the visitor before being compiled,
        and evaluated against the restricted `self.scope` only.
        """
        self.visit(ast_tree)
        exe = compile(ast_tree, filename="", mode="eval")
        return eval(exe, self.scope)

    def safe_eval(self, src: str) -> Any:
        """
        Evaluate strings that are safe only (raise errors otherwise).

        A "safe" string or node provided may only consist of nodes in
        `databooks.affirm._ALLOWED_NODES` and built-ins from
        `databooks.affirm._ALLOWED_BUILTINS`.

        :param src: Expression source to parse and evaluate
        :return: Result of evaluating `src` in the restricted scope
        """
        ast_tree = ast.parse(src, mode="eval")
        return self.safe_eval_ast(ast_tree)
def affirm(nb_path: Path, exprs: List[str], verbose: bool = False) -> bool:
    """
    Return whether notebook passed all checks (expressions).

    Each expression is evaluated by `DatabooksParser.safe_eval` against a
    scope exposing the notebook (`nb`) and convenience cell lists
    (`raw_cells`, `md_cells`, `code_cells`, `exec_cells`).

    :param nb_path: Path of notebook file
    :param exprs: Expression with check to be evaluated on notebook
    :param verbose: Log failed tests for notebook
    :return: Evaluated expression cast as a `bool`
    """
    if verbose:
        set_verbose(logger)

    nb = JupyterNotebook.parse_file(nb_path)
    variables: Dict[str, Any] = {
        "nb": nb,
        "raw_cells": [c for c in nb.cells if c.cell_type == "raw"],
        "md_cells": [c for c in nb.cells if c.cell_type == "markdown"],
        "code_cells": [c for c in nb.cells if c.cell_type == "code"],
        # "Executed" cells: code cells that have an execution count.
        "exec_cells": [
            c
            for c in nb.cells
            if c.cell_type == "code" and c.execution_count is not None
        ],
    }
    databooks_parser = DatabooksParser(**variables)
    is_ok = [bool(databooks_parser.safe_eval(expr)) for expr in exprs]
    # `count(False)` avoids building an intermediate list just to sum it.
    n_fail = is_ok.count(False)

    logger.info(f"{nb_path} failed {n_fail} of {len(is_ok)} checks.")
    logger.debug(
        str(nb_path)
        + (
            f" failed {list(compress(exprs, (not ok for ok in is_ok)))}."
            if n_fail > 0
            else " succeeded all checks."
        )
    )
    return all(is_ok)
def affirm_all(
    nb_paths: List[Path],
    *,
    progress_callback: Callable[[], None] = lambda: None,
    **affirm_kwargs: Any,
) -> List[bool]:
    """
    Affirm that multiple notebooks pass all the given checks (expressions).

    Runs `databooks.affirm.affirm` on each notebook, invoking
    `progress_callback` after each one.

    :param nb_paths: Paths of notebooks to run the checks on
    :param progress_callback: Callback function to report progress
    :param affirm_kwargs: Keyword arguments to be passed to `databooks.affirm.affirm`
    :return: Whether each notebook passed all checks
    """
    checks = []
    for nb_path in nb_paths:
        checks.append(affirm(nb_path, **affirm_kwargs))
        progress_callback()
    return checks