Coverage for databooks/cli.py: 93%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

90 statements  

1"""Main CLI application.""" 

2from itertools import compress 

3from pathlib import Path 

4from typing import List, Optional 

5 

6import tomli 

7from rich.progress import ( 

8 BarColumn, 

9 Progress, 

10 SpinnerColumn, 

11 TextColumn, 

12 TimeElapsedColumn, 

13) 

14from typer import Argument, BadParameter, Context, Exit, Option, Typer, echo 

15 

16from databooks.affirm import affirm_all 

17from databooks.common import expand_paths 

18from databooks.config import TOML_CONFIG_FILE, get_config 

19from databooks.conflicts import conflicts2nbs, path2conflicts 

20from databooks.logging import get_logger 

21from databooks.metadata import clear_all 

22from databooks.recipes import Recipe 

23from databooks.version import __version__ 

24 

# Module-level logger, obtained from the project's logging helper with this file's path
logger = get_logger(__file__)

# Typer application object; the commands below are registered on it via decorators
app = Typer()

28 

29 

30def _version_callback(show_version: bool) -> None: 

31 """Return application version.""" 

32 if show_version: 

33 echo("databooks version: " + __version__) 

34 raise Exit() 

35 

36 

def _help_callback(ctx: Context, show_help: Optional[bool]) -> None:
    """
    Echo the help of the current command and stop the CLI when requested.

    The commands in this module are registered with `add_help_option=False` and
    attach this callback to their own eager `help` option instead.

    :param ctx: Context of the command being invoked
    :param show_help: Value of the `help` option; falsy values are a no-op
    :raises Exit: Right after echoing the help text
    """
    if not show_help:
        return
    help_text = ctx.command.get_help(ctx)
    echo(help_text)
    raise Exit()

42 

43 

def _config_callback(ctx: Context, config_path: Optional[Path]) -> Optional[Path]:
    """
    Get config file and inject values into context to override default args.

    When no config file is passed explicitly, one is looked up with `get_config`
    starting from the `paths` already parsed into the context. The values found
    under `[tool.databooks.<command name>]` are merged into `ctx.default_map`, with
    dashes in the keys replaced by underscores.

    :param ctx: Context of the command being invoked
    :param config_path: Config file passed via `--config`/`-c`, if any
    :return: Path of the config file that was used, or `None` when there is none
    """
    target_paths = expand_paths(
        paths=[Path(p) for p in ctx.params.get("paths", ())], rglob="*"
    )
    if config_path is None and target_paths:
        # No explicit config file: look one up based on the target paths
        config_path = get_config(
            target_paths=target_paths,
            config_filename=TOML_CONFIG_FILE,
        )
    logger.debug(f"Loading config file from: {config_path}")

    ctx.default_map = ctx.default_map or {}  # initialize defaults

    if config_path is not None:  # config may not be specified
        # `tomli.load` expects a binary file object: text-mode handles are
        # deprecated since tomli 1.2 and raise `TypeError` from tomli 2.0 onwards
        with config_path.open("rb") as f:
            conf = (
                tomli.load(f)
                .get("tool", {})
                .get("databooks", {})
                .get(ctx.command.name, {})
            )
        # Merge configuration
        ctx.default_map.update({k.replace("-", "_"): v for k, v in conf.items()})
    return config_path

72 

73 

def _check_paths(paths: List[Path], ignore: List[str]) -> List[Path]:
    """
    Validate the paths passed by the user and expand them into notebook paths.

    :param paths: Paths passed in the command line
    :param ignore: Glob expression(s) forwarded to `expand_paths` as `ignore`
    :return: Paths returned by `expand_paths`
    :raises BadParameter: If any path has a suffix other than `.ipynb` (or none)
    :raises Exit: If the expansion yields no paths (there is nothing to do)
    """
    allowed_suffixes = ("", ".ipynb")
    for candidate in paths:
        if candidate.suffix not in allowed_suffixes:
            raise BadParameter(
                "Expected either notebook files, a directory or glob expression."
            )

    notebooks = expand_paths(paths=paths, ignore=ignore)
    if notebooks:
        return notebooks

    logger.info(f"No notebooks found in {paths}. Nothing to do.")
    raise Exit()

84 

85 

# Application-level callback: only declares the eager `--version` option
@app.callback()
def callback(  # noqa: D103
    version: Optional[bool] = Option(
        None, "--version", callback=_version_callback, is_eager=True
    )
) -> None:
    """CLI tool to resolve git conflicts and remove metadata in notebooks."""
    # Intentionally empty: `--version` is fully handled by `_version_callback`.
    # NOTE(review): Typer renders the docstring above as the CLI help text, so
    # treat it as user-facing output when editing.

93 

94 

@app.command(add_help_option=False)
def meta(
    paths: List[Path] = Argument(..., is_eager=True, help="Path(s) of notebook files"),
    ignore: List[str] = Option(["!*"], help="Glob expression(s) of files to ignore"),
    prefix: str = Option("", help="Prefix to add to filepath when writing files"),
    suffix: str = Option("", help="Suffix to add to filepath when writing files"),
    rm_outs: bool = Option(False, help="Whether to remove cell outputs"),
    rm_exec: bool = Option(True, help="Whether to remove the cell execution counts"),
    nb_meta_keep: List[str] = Option((), help="Notebook metadata fields to keep"),
    cell_meta_keep: List[str] = Option((), help="Cells metadata fields to keep"),
    cell_fields_keep: List[str] = Option(
        (),
        help="Other (excluding `execution_counts` and `outputs`) cell fields to keep",
    ),
    overwrite: bool = Option(
        False, "--overwrite", "-w", help="Confirm overwrite of files"
    ),
    check: bool = Option(
        False,
        "--check",
        help="Don't write files but check whether there is unwanted metadata",
    ),
    verbose: bool = Option(
        False, "--verbose", "-v", help="Log processed files in console"
    ),
    config: Optional[Path] = Option(
        None,
        "--config",
        "-c",
        is_eager=True,
        callback=_config_callback,
        resolve_path=True,
        exists=True,
        help="Get CLI options from configuration file",
    ),
    help: Optional[bool] = Option(
        None, is_eager=True, callback=_help_callback, help="Show this message and exit"
    ),
) -> None:
    """Clear both notebook and cell metadata."""
    # NOTE: the docstring above doubles as the command's help text - keep it short.
    nb_paths = _check_paths(paths=paths, ignore=ignore)

    # Without a prefix or suffix the output path equals the input path, so an
    # actual write (i.e. not `--check`) must be explicitly confirmed.
    in_place = not (prefix + suffix)
    if in_place and not check:
        if not overwrite:
            raise BadParameter(
                "No prefix nor suffix were passed."
                " Please specify `--overwrite` or `-w` to overwrite files."
            )
        logger.warning(f"{len(nb_paths)} files will be overwritten")

    def _target(source: Path) -> Path:
        """Build the output path: same directory, decorated file name."""
        return source.parent / (prefix + source.stem + suffix + source.suffix)

    write_paths = [_target(nb_path) for nb_path in nb_paths]

    # `outputs`/`execution_count` are kept only when their `rm_*` flag is off;
    # user-provided fields are appended after them.
    kept_cell_fields = [
        *compress(("outputs", "execution_count"), (not rm_outs, not rm_exec)),
        *cell_fields_keep,
    ]

    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
    )
    with Progress(*columns) as progress:
        task_id = progress.add_task("[yellow]Removing metadata", total=len(nb_paths))

        def _advance() -> None:
            """Move the progress bar one step forward."""
            progress.update(task_id, advance=1)

        are_equal = clear_all(
            read_paths=nb_paths,
            write_paths=write_paths,
            progress_callback=_advance,
            notebook_metadata_keep=nb_meta_keep,
            cell_metadata_keep=cell_meta_keep,
            cell_fields_keep=kept_cell_fields,
            check=check,
            verbose=verbose,
        )

    if not check:
        logger.info(
            f"The metadata of {sum(not eq for eq in are_equal)} out of {len(are_equal)}"
            " notebooks were removed!"
        )
        return

    if all(are_equal):
        logger.info("No unwanted metadata!")
        return

    # `--check` found differences: report them and signal failure via exit code
    logger.info(
        f"Found unwanted metadata in {sum(not eq for eq in are_equal)} out of"
        f" {len(are_equal)} files."
    )
    raise Exit(code=1)

183 

184 

@app.command("assert", add_help_option=False)
def affirm_meta(
    paths: List[Path] = Argument(..., is_eager=True, help="Path(s) of notebook files"),
    ignore: List[str] = Option(["!*"], help="Glob expression(s) of files to ignore"),
    expr: List[str] = Option((), help="Expressions to assert on notebooks"),
    recipe: List[Recipe] = Option((), help="Common recipes of expressions"),
    verbose: bool = Option(
        False, "--verbose", "-v", help="Log processed files in console"
    ),
    config: Optional[Path] = Option(
        None,
        "--config",
        "-c",
        is_eager=True,
        callback=_config_callback,
        resolve_path=True,
        exists=True,
        help="Get CLI options from configuration file",
    ),
    help: Optional[bool] = Option(
        None, is_eager=True, callback=_help_callback, help="Show this message and exit"
    ),
) -> None:
    """
    Assert notebook metadata has desired values.

    Pass one (or multiple) strings or recipes. The available variables in scope include
    `nb` (notebook), `raw_cells` (notebook cells of `raw` type), `md_cells` (notebook
    cells of `markdown` type), `code_cells` (notebook cells of `code` type) and
    `exec_cells` (notebook cells of `code` type that were executed - have an `execution
    count` value). Recipes can be found on `databooks.recipes.CookBook`.
    """
    # NOTE: the docstring above doubles as the command's help text.
    nb_paths = _check_paths(paths=paths, ignore=ignore)

    # Recipes are passed on by name, ahead of the free-form expressions
    exprs = [*(chosen.name for chosen in recipe), *expr]
    if not exprs:
        raise BadParameter("Must specify at least one of `expr` or `recipe`.")

    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
    )
    with Progress(*columns) as progress:
        task_id = progress.add_task(
            "[yellow]Running assert checks", total=len(nb_paths)
        )

        def _advance() -> None:
            """Move the progress bar one step forward."""
            progress.update(task_id, advance=1)

        are_ok = affirm_all(
            nb_paths=nb_paths,
            progress_callback=_advance,
            exprs=exprs,
            verbose=verbose,
        )

    if all(are_ok):
        logger.info("All notebooks comply with the desired metadata!")
        return

    # At least one notebook failed: report the count and exit with an error code
    logger.info(
        f"Found issues in notebook metadata for {sum(not ok for ok in are_ok)} out"
        f" of {len(are_ok)} notebooks."
    )
    raise Exit(code=1)

248 

249 

@app.command(add_help_option=False)
def fix(
    paths: List[Path] = Argument(
        ..., is_eager=True, help="Path(s) of notebook files with conflicts"
    ),
    ignore: List[str] = Option(["!*"], help="Glob expression(s) of files to ignore"),
    metadata_head: bool = Option(
        True, help="Whether or not to keep the metadata from the head/current notebook"
    ),
    cells_head: Optional[bool] = Option(
        None,
        help="Whether to keep the cells from the head/base notebook. Omit to keep both",
    ),
    cell_fields_ignore: List[str] = Option(
        [
            "id",
            "execution_count",
        ],
        help="Cell fields to remove before comparing cells",
    ),
    interactive: bool = Option(
        False,
        "--interactive",
        "-i",
        help="Interactively resolve the conflicts (not implemented)",
    ),
    verbose: bool = Option(False, help="Log processed files in console"),
    config: Optional[Path] = Option(
        None,
        "--config",
        "-c",
        is_eager=True,
        callback=_config_callback,
        resolve_path=True,
        exists=True,
        help="Get CLI options from configuration file",
    ),
    help: Optional[bool] = Option(
        None, is_eager=True, callback=_help_callback, help="Show this message and exit"
    ),
) -> None:
    """
    Fix git conflicts for notebooks.

    Perform by getting the unmerged blobs from git index, comparing them and returning
    a valid notebook summarizing the differences - see
    [git docs](https://git-scm.com/docs/git-ls-files).
    """
    # NOTE: the docstring above doubles as the command's help text.
    filepaths = expand_paths(paths=paths, ignore=ignore)
    conflict_files = path2conflicts(nb_paths=filepaths)
    if not conflict_files:
        raise BadParameter(
            f"No conflicts found at {', '.join([str(p) for p in filepaths])}."
        )
    if interactive:
        # The flag is accepted by the CLI but there is no implementation yet
        raise NotImplementedError

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
    ) as progress:
        conflicts = progress.add_task(
            "[yellow]Resolving conflicts", total=len(conflict_files)
        )
        conflicts2nbs(
            conflict_files=conflict_files,
            meta_first=metadata_head,
            cells_first=cells_head,
            cell_fields_ignore=cell_fields_ignore,
            verbose=verbose,
            progress_callback=lambda: progress.update(conflicts, advance=1),
        )
    # State what the count refers to (previously "... conflicts of 3!")
    logger.info(f"Resolved the conflicts of {len(conflict_files)} files!")

326 

327 

@app.command()
def diff() -> None:
    """Show differences between notebooks (not implemented)."""
    # Placeholder command: it is registered on the app but always raises when invoked
    raise NotImplementedError