Coverage for databooks/affirm.py: 98%
1"""Functions to safely evaluate strings and inspect notebook."""
2import ast
3from copy import deepcopy
4from functools import reduce
5from itertools import compress
6from pathlib import Path
7from typing import Any, Callable, Dict, Iterable, List, Tuple
9from databooks import JupyterNotebook
10from databooks.data_models.base import DatabooksBase
11from databooks.logging import get_logger, set_verbose
13logger = get_logger(__file__)
_ALLOWED_BUILTINS = (
    all,
    any,
    enumerate,
    filter,
    getattr,
    hasattr,
    len,
    list,
    range,
    sorted,
)
_ALLOWED_NODES = (
    ast.Add,
    ast.And,
    ast.BinOp,
    ast.BitAnd,
    ast.BitOr,
    ast.BitXor,
    ast.BoolOp,
    ast.boolop,
    ast.cmpop,
    ast.Compare,
    ast.comprehension,
    ast.Constant,
    ast.Dict,
    ast.Eq,
    ast.Expr,
    ast.expr,
    ast.expr_context,
    ast.Expression,
    ast.For,
    ast.Gt,
    ast.GtE,
    ast.In,
    ast.Is,
    ast.IsNot,
    ast.List,
    ast.ListComp,
    ast.Load,
    ast.LShift,
    ast.Lt,
    ast.LtE,
    ast.Mod,
    ast.Name,
    ast.Not,
    ast.NotEq,
    ast.NotIn,
    ast.Num,
    ast.operator,
    ast.Or,
    ast.RShift,
    ast.Set,
    ast.Slice,
    ast.slice,
    ast.Str,
    ast.Sub,
    ast.Tuple,
    ast.UAdd,
    ast.UnaryOp,
    ast.unaryop,
    ast.USub,
)
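A note on the allowlist above: because the abstract base `ast.expr` is included, any concrete expression node (such as `ast.Call`) passes the structural check in `generic_visit` below. The effective safety boundary is therefore the name and attribute validation further down. A minimal illustration of this (not part of the module):

import ast

# `ast.Call` subclasses `ast.expr`, which is allow-listed, so a call
# expression is structurally acceptable; whether it may evaluate is then
# decided by the `visit_Name` check against the parser's scope.
assert isinstance(ast.parse("f(1)", mode="eval").body, _ALLOWED_NODES)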
class DatabooksParser(ast.NodeVisitor):
    """AST parser that disallows unsafe nodes/values."""

    def __init__(self, **variables: Any) -> None:
        """Instantiate with variables and callables (built-ins) scope."""
        # https://github.com/python/mypy/issues/3728
        self.builtins = {b.__name__: b for b in _ALLOWED_BUILTINS}  # type: ignore
        self.names = deepcopy(variables) or {}
        self.scope = {
            **self.names,
            "__builtins__": self.builtins,
        }

    @staticmethod
    def _prioritize(field: Tuple[str, Any]) -> bool:
        """Prioritize `ast.comprehension` nodes when expanding the AST tree."""
        _, value = field
        if not isinstance(value, list):
            return True
        return not any(isinstance(f, ast.comprehension) for f in value)
    def _get_iter(self, node: ast.AST) -> Iterable:
        """Use `DatabooksParser.safe_eval_ast` to get the iterable object."""
        tree = ast.Expression(body=node)
        return iter(self.safe_eval_ast(tree))

    def generic_visit(self, node: ast.AST) -> None:
        """
        Prioritize `ast.comprehension` nodes when expanding the tree.

        Similar to `NodeVisitor.generic_visit`, but favor comprehensions when
        multiple nodes are on the same level. Comprehensions have a generator
        argument that includes names that are stored. By visiting them first we
        avoid 'running into' unknown names.
        """
        if not isinstance(node, _ALLOWED_NODES):
            raise ValueError(f"Invalid node `{node}`.")

        for field, value in sorted(ast.iter_fields(node), key=self._prioritize):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)
            elif isinstance(value, ast.AST):
                self.visit(value)
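As an aside, a small illustration (assumed, not from the module) of the ordering `_prioritize` produces: since `False` sorts before `True`, fields containing `ast.comprehension` nodes come first, so in a list comprehension the `generators` field, which binds the loop variable, is visited before the `elt` field that uses it:

import ast

comp = ast.parse("[x for x in nums]", mode="eval").body  # an ast.ListComp
fields = sorted(ast.iter_fields(comp), key=DatabooksParser._prioritize)
# The binding `generators` field sorts ahead of `elt`, so `x` is registered
# as a known name before the expression that references it is visited.
assert fields[0][0] == "generators"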
    def visit_comprehension(self, node: ast.comprehension) -> None:
        """Add variable from a comprehension to list of allowed names."""
        if not isinstance(node.target, ast.Name):
            raise RuntimeError(
                "Expected `ast.comprehension`'s target to be `ast.Name`, got"
                f" `ast.{type(node.target).__name__}`."
            )
        # If any elements in the comprehension are a `DatabooksBase` instance,
        # then pass down the attributes as valid
        iterable = self._get_iter(node.iter)
        databooks_el = [el for el in iterable if isinstance(el, DatabooksBase)]
        if databooks_el:
            d_attrs = reduce(lambda a, b: {**a, **b}, [dict(el) for el in databooks_el])
        self.names[node.target.id] = DatabooksBase(**d_attrs) if databooks_el else ...
        self.generic_visit(node)
    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Allow attributes for Pydantic fields only."""
        if not isinstance(node.value, (ast.Attribute, ast.Name)):
            raise ValueError(
                "Expected attribute to be one of `ast.Name` or `ast.Attribute`, got"
                f" `ast.{type(node.value).__name__}`"
            )
        if not isinstance(node.value, ast.Attribute):
            obj = self.names[node.value.id]
            allowed_attrs = dict(obj).keys() if isinstance(obj, DatabooksBase) else ()
            if node.attr not in allowed_attrs:
                raise ValueError(
                    "Expected attribute to be one of"
                    f" `{allowed_attrs}`, got `{node.attr}`"
                )
        self.generic_visit(node)
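A hedged sketch of the attribute restriction (the object and field names here are made up): attribute access is permitted only on the fields of a `DatabooksBase` model in scope.

# `DatabooksBase` accepts arbitrary keyword fields (the module itself builds
# instances via `DatabooksBase(**d_attrs)` above), so `a` is a valid field.
parser = DatabooksParser(obj=DatabooksBase(a=1))
assert parser.safe_eval("obj.a == 1")
# `parser.safe_eval("obj.b == 1")` would raise a ValueError, since `b` is
# not among the object's fields.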
    def visit_Name(self, node: ast.Name) -> None:
        """Only allow names from scope or comprehension variables."""
        valid_names = {**self.names, **self.builtins}
        if node.id not in valid_names:
            raise ValueError(
                f"Expected `name` to be one of `{valid_names.keys()}`, got `{node.id}`."
            )
        self.generic_visit(node)
    def safe_eval_ast(self, ast_tree: ast.AST) -> Any:
        """Evaluate safe AST trees only (raise errors otherwise)."""
        self.visit(ast_tree)
        exe = compile(ast_tree, filename="", mode="eval")
        return eval(exe, self.scope)

    def safe_eval(self, src: str) -> Any:
        """
        Evaluate only safe strings (raise errors otherwise).

        A string is "safe" if it consists only of nodes in
        `databooks.affirm._ALLOWED_NODES` and built-ins from
        `databooks.affirm._ALLOWED_BUILTINS`.
        """
        ast_tree = ast.parse(src, mode="eval")
        return self.safe_eval_ast(ast_tree)
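For reference, a hedged usage sketch of the parser (the variable names here are invented): keyword arguments passed to the constructor become the only names an expression may reference, alongside the allow-listed built-ins.

parser = DatabooksParser(nums=[1, 2, 3])
assert parser.safe_eval("len([n for n in nums if n > 1]) == 2")
# By contrast, `parser.safe_eval("__import__('os')")` raises a ValueError,
# since `__import__` is neither a passed-in name nor an allowed built-in.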
def affirm(nb_path: Path, exprs: List[str], verbose: bool = False) -> bool:
    """
    Return whether notebook passed all checks (expressions).

    :param nb_path: Path of notebook file
    :param exprs: Expressions (checks) to be evaluated on the notebook
    :param verbose: Log failed checks for notebook
    :return: Whether all expressions evaluated to `True`, cast as a `bool`
    """
    if verbose:
        set_verbose(logger)

    nb = JupyterNotebook.parse_file(nb_path)
    variables: Dict[str, Any] = {
        "nb": nb,
        "raw_cells": [c for c in nb.cells if c.cell_type == "raw"],
        "md_cells": [c for c in nb.cells if c.cell_type == "markdown"],
        "code_cells": [c for c in nb.cells if c.cell_type == "code"],
        "exec_cells": [
            c
            for c in nb.cells
            if c.cell_type == "code" and c.execution_count is not None
        ],
    }
    databooks_parser = DatabooksParser(**variables)
    is_ok = [bool(databooks_parser.safe_eval(expr)) for expr in exprs]
    n_fail = sum(not ok for ok in is_ok)

    logger.info(f"{nb_path} failed {n_fail} of {len(is_ok)} checks.")
    logger.debug(
        str(nb_path)
        + (
            f" failed {list(compress(exprs, (not ok for ok in is_ok)))}."
            if n_fail > 0
            else " succeeded all checks."
        )
    )
    return all(is_ok)
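A hedged usage sketch (the notebook path is hypothetical): the `variables` above, such as `code_cells` and `exec_cells`, are the names available inside each expression.

passed = affirm(
    Path("demo.ipynb"),
    exprs=["len(code_cells) == len(exec_cells)"],  # all code cells were run
    verbose=True,
)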
def affirm_all(
    nb_paths: List[Path],
    *,
    progress_callback: Callable[[], None] = lambda: None,
    **affirm_kwargs: Any,
) -> List[bool]:
    """
    Check multiple notebooks against the given expressions.

    :param nb_paths: Paths of notebooks to check
    :param progress_callback: Callback function to report progress
    :param affirm_kwargs: Keyword arguments to be passed to `databooks.affirm.affirm`
    :return: Whether each notebook passed all checks
    """
    checks = []
    for nb_path in nb_paths:
        checks.append(affirm(nb_path, **affirm_kwargs))
        progress_callback()
    return checks
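Finally, a hedged sketch of `affirm_all` (paths hypothetical): keyword arguments are forwarded to `affirm`, and the callback could, for example, advance a progress bar.

results = affirm_all(
    [Path("a.ipynb"), Path("b.ipynb")],
    exprs=["len(md_cells) > 0"],  # every notebook has markdown cells
)
print(dict(zip(["a.ipynb", "b.ipynb"], results)))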