Coverage for databooks/cli.py: 91%

119 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-09 13:11 +0000

1"""Main CLI application.""" 

2from itertools import compress 

3from pathlib import Path 

4from typing import List, Optional, Tuple 

5 

6import tomli 

7from rich.progress import ( 

8 BarColumn, 

9 Progress, 

10 SpinnerColumn, 

11 TextColumn, 

12 TimeElapsedColumn, 

13) 

14from rich.prompt import Confirm 

15from typer import Argument, BadParameter, Context, Exit, Option, Typer, echo 

16 

17from databooks.affirm import affirm_all 

18from databooks.common import expand_paths 

19from databooks.config import TOML_CONFIG_FILE, get_config 

20from databooks.conflicts import conflicts2nbs, path2conflicts 

21from databooks.git_utils import get_nb_diffs 

22from databooks.logging import get_logger 

23from databooks.metadata import clear_all 

24from databooks.recipes import Recipe 

25from databooks.tui import print_diffs, print_nbs 

26from databooks.version import __version__ 

27 

# Module-level logger shared by the option callbacks and commands below.
logger = get_logger(__file__)

# Typer application; the commands below register themselves via `@app.command`
# and the top-level options via `@app.callback`.
app = Typer()

31 

32 

33def _version_callback(show_version: bool) -> None: 

34 """Return application version.""" 

35 if show_version: 

36 echo("databooks version: " + __version__) 

37 raise Exit() 

38 

39 

def _help_callback(ctx: Context, show_help: Optional[bool]) -> None:
    """Print the command's help text and stop, as an eager replacement for `--help`.

    :param ctx: Click/typer context of the command being invoked.
    :param show_help: Value of the custom `--help` flag; falsy means "do nothing".
    :raises Exit: After printing the help text, to end the CLI invocation.
    """
    if not show_help:
        return
    help_text = ctx.command.get_help(ctx)
    echo(help_text)
    raise Exit()

45 

46 

def _config_callback(ctx: Context, config_path: Optional[Path]) -> Optional[Path]:
    """Find the TOML config file and merge its command section into the defaults.

    When no `--config` path is given, the file is looked up with `get_config`,
    starting from the (already parsed, eager) `paths` parameter of the command, or
    from the current working directory when no notebook paths are found. The table
    `[tool.databooks.<command name>]` is then merged over `ctx.default_map`, with
    dashes in keys converted to underscores so they match parameter names.

    :param ctx: Context of the command being invoked.
    :param config_path: Explicit `--config` value, or `None` to search for one.
    :return: The config path that was used, or `None` if none was given or found.
    """
    cli_paths = [Path(raw_path).resolve() for raw_path in ctx.params.get("paths", ())]
    search_paths = expand_paths(paths=cli_paths) or [Path.cwd()]

    if config_path is None and search_paths:
        config_path = get_config(
            target_paths=search_paths,
            config_filename=TOML_CONFIG_FILE,
        )
    logger.debug(f"Loading config file from: {config_path}")

    if config_path is None:  # no config given and none found - keep defaults
        return config_path

    with config_path.open("rb") as config_file:
        tool_table = tomli.load(config_file).get("tool", {})
        command_conf = tool_table.get("databooks", {}).get(ctx.command.name, {})

    # Config keys may use dashes (`rm-outs`) while parameters use underscores.
    overrides = {key.replace("-", "_"): value for key, value in command_conf.items()}
    ctx.default_map = {**(ctx.default_map or {}), **overrides}
    return config_path

75 

76 

def _check_paths(paths: List[Path], ignore: List[str]) -> List[Path]:
    """Validate the notebook path arguments and expand them into notebook files.

    Each argument must be an `.ipynb` file, a suffix-less path (directory or glob
    expression) or an existing directory. Existing directories are accepted
    regardless of suffix, because a directory name containing a dot
    (e.g. `notebooks.v2`) has a non-empty `Path.suffix`. Such directories are passed
    on to `expand_paths` just like suffix-less ones.

    :param paths: Notebook files, directories or glob expressions from the CLI.
    :param ignore: Glob expressions of files to exclude.
    :return: The expanded notebook paths (never empty).
    :raises BadParameter: If an argument is neither a notebook, a directory nor a
        suffix-less path.
    :raises Exit: If no notebooks are found, to stop the command.
    """
    # `is_dir()` is only evaluated for the unusual suffixes (short-circuit).
    if any(
        path.suffix not in ("", ".ipynb") and not path.is_dir() for path in paths
    ):
        raise BadParameter(
            "Expected either notebook files, a directory or glob expression."
        )
    nb_paths = expand_paths(paths=paths, ignore=ignore)
    if not nb_paths:
        logger.info(
            f"No notebooks found in {[p.resolve() for p in paths]}. Nothing to do."
        )
        raise Exit()
    return nb_paths

90 

91 

92def _parse_paths( 

93 *refs: Optional[str], paths: List[Path] 

94) -> Tuple[Tuple[Optional[str], ...], List[Path]]: 

95 """Detect paths from `refs` and add to `paths`.""" 

96 first, *rest = refs 

97 if first is not None and Path(first).exists(): 

98 paths += [Path(first)] 

99 first = None 

100 if rest: 

101 _refs, _paths = _parse_paths(*rest, paths=paths) 

102 return (first, *_refs), _paths 

103 return (first,), paths 

104 

105 

@app.callback()
def callback(  # noqa: D103
    # `--version` is eager and fully handled by `_version_callback`, which prints
    # the version and exits before any command runs.
    version: Optional[bool] = Option(
        None, "--version", callback=_version_callback, is_eager=True
    )
) -> None:
    """CLI tool to resolve git conflicts and remove metadata in notebooks."""
    # The docstring above is used by typer as the top-level `--help` description,
    # so this callback intentionally has no body.

113 

114 

@app.command(add_help_option=False)
def meta(
    paths: List[Path] = Argument(..., is_eager=True, help="Path(s) of notebook files"),
    ignore: List[str] = Option(["!*"], help="Glob expression(s) of files to ignore"),
    prefix: str = Option("", help="Prefix to add to filepath when writing files"),
    suffix: str = Option("", help="Suffix to add to filepath when writing files"),
    rm_outs: bool = Option(False, help="Whether to remove cell outputs"),
    rm_exec: bool = Option(True, help="Whether to remove the cell execution counts"),
    nb_meta_keep: List[str] = Option((), help="Notebook metadata fields to keep"),
    cell_meta_keep: List[str] = Option((), help="Cells metadata fields to keep"),
    cell_fields_keep: List[str] = Option(
        (),
        help="Other (excluding `execution_counts` and `outputs`) cell fields to keep",
    ),
    overwrite: bool = Option(False, "--yes", "-y", help="Confirm overwrite of files"),
    check: bool = Option(
        False,
        "--check",
        help="Don't write files but check whether there is unwanted metadata",
    ),
    verbose: bool = Option(
        False, "--verbose", "-v", help="Log processed files in console"
    ),
    config: Optional[Path] = Option(
        None,
        "--config",
        "-c",
        is_eager=True,
        callback=_config_callback,
        resolve_path=True,
        exists=True,
        help="Get CLI options from configuration file",
    ),
    help: Optional[bool] = Option(
        None,
        "--help",
        is_eager=True,
        callback=_help_callback,
        help="Show this message and exit",
    ),
) -> None:
    """Clear both notebook and cell metadata."""
    # NOTE: the docstring above is the command's `--help` text; implementation
    # notes are kept in comments.
    nb_paths = _check_paths(paths=paths, ignore=ignore)
    n_notebooks = len(nb_paths)

    # Without prefix/suffix the cleaned notebooks replace the originals, so ask
    # for confirmation unless `--yes` was passed. `--check` writes nothing.
    in_place = not (prefix + suffix)
    if in_place and not check:
        if not overwrite:
            overwrite = Confirm.ask(
                f"{n_notebooks} files will be overwritten"
                " (no prefix nor suffix was passed). Continue?"
            )
        if not overwrite:
            raise Exit()
        logger.warning(f"{n_notebooks} files will be overwritten")

    write_paths = [
        nb_path.parent / f"{prefix}{nb_path.stem}{suffix}{nb_path.suffix}"
        for nb_path in nb_paths
    ]

    # `outputs` and `execution_count` are kept unless their removal flag is set;
    # any user-requested fields follow them.
    removable_fields = (("outputs", rm_outs), ("execution_count", rm_exec))
    fields_to_keep = [field for field, remove in removable_fields if not remove]
    fields_to_keep.extend(cell_fields_keep)

    progress_columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
    )
    with Progress(*progress_columns) as progress:
        task_id = progress.add_task("[yellow]Removing metadata", total=n_notebooks)

        def _advance() -> None:
            progress.update(task_id, advance=1)

        are_equal = clear_all(
            read_paths=nb_paths,
            write_paths=write_paths,
            progress_callback=_advance,
            notebook_metadata_keep=nb_meta_keep,
            cell_metadata_keep=cell_meta_keep,
            cell_fields_keep=fields_to_keep,
            check=check,
            verbose=verbose,
            overwrite=overwrite,
        )

    if not check:
        logger.info(
            f"The metadata of {sum(not eq for eq in are_equal)} out of {len(are_equal)}"
            " notebooks were removed!"
        )
        return

    # In `--check` mode a falsy entry means the notebook has unwanted metadata.
    if all(are_equal):
        logger.info("No unwanted metadata!")
        return
    n_dirty = sum(not eq for eq in are_equal)
    logger.info(
        f"Found unwanted metadata in {n_dirty} out of"
        f" {len(are_equal)} files."
    )
    raise Exit(code=1)

211 

212 

@app.command("assert", add_help_option=False)
def affirm_meta(
    paths: List[Path] = Argument(..., is_eager=True, help="Path(s) of notebook files"),
    ignore: List[str] = Option(["!*"], help="Glob expression(s) of files to ignore"),
    expr: List[str] = Option(
        (), "--expr", "-x", help="Expressions to assert on notebooks"
    ),
    recipe: List[Recipe] = Option(
        (),
        "--recipe",
        "-r",
        help="Common recipes of expressions - see"
        " https://databooks.dev/latest/usage/overview/#recipes",
    ),
    verbose: bool = Option(
        False, "--verbose", "-v", help="Log processed files in console"
    ),
    config: Optional[Path] = Option(
        None,
        "--config",
        "-c",
        is_eager=True,
        callback=_config_callback,
        resolve_path=True,
        exists=True,
        help="Get CLI options from configuration file",
    ),
    help: Optional[bool] = Option(
        None,
        "--help",
        is_eager=True,
        callback=_help_callback,
        help="Show this message and exit",
    ),
) -> None:
    """
    Assert notebook metadata has desired values.

    Pass one (or multiple) strings or recipes. The available variables in scope include
    `nb` (notebook), `raw_cells` (notebook cells of `raw` type), `md_cells` (notebook
    cells of `markdown` type), `code_cells` (notebook cells of `code` type) and
    `exec_cells` (notebook cells of `code` type that were executed - have an `execution
    count` value). Recipes can be found on `databooks.recipes.CookBook`.
    """
    # NOTE: the docstring above is the `--help` text of `databooks assert`.
    # Paths are validated first, so "no notebooks" exits before the expr check.
    nb_paths = _check_paths(paths=paths, ignore=ignore)

    # Recipes are passed by their member name, followed by the raw expressions.
    exprs = [*(r.name for r in recipe), *expr]
    if not exprs:
        raise BadParameter("Must specify at least one of `expr` or `recipe`.")

    progress_columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
    )
    with Progress(*progress_columns) as progress:
        task_id = progress.add_task(
            "[yellow]Running assert checks", total=len(nb_paths)
        )

        def _advance() -> None:
            progress.update(task_id, advance=1)

        are_ok = affirm_all(
            nb_paths=nb_paths,
            progress_callback=_advance,
            exprs=exprs,
            verbose=verbose,
        )

    if not all(are_ok):
        n_failed = sum(not ok for ok in are_ok)
        logger.info(
            f"Found issues in notebook metadata for {n_failed} out"
            f" of {len(are_ok)} notebooks."
        )
        raise Exit(code=1)
    logger.info("All notebooks comply with the desired metadata!")

288 

289 

@app.command(add_help_option=False)
def fix(
    paths: List[Path] = Argument(
        ..., is_eager=True, help="Path(s) of notebook files with conflicts"
    ),
    ignore: List[str] = Option(["!*"], help="Glob expression(s) of files to ignore"),
    metadata_head: bool = Option(
        True, help="Whether or not to keep the metadata from the head/current notebook"
    ),
    cells_head: Optional[bool] = Option(
        None,
        help="Whether to keep the cells from the head/base notebook. Omit to keep both",
    ),
    cell_fields_ignore: List[str] = Option(
        [
            "id",
            "execution_count",
        ],
        help="Cell fields to remove before comparing cells",
    ),
    interactive: bool = Option(
        False,
        "--interactive",
        "-i",
        help="Interactively resolve the conflicts (not implemented)",
    ),
    verbose: bool = Option(
        False, "--verbose", "-v", help="Log processed files in console"
    ),
    config: Optional[Path] = Option(
        None,
        "--config",
        "-c",
        is_eager=True,
        callback=_config_callback,
        resolve_path=True,
        exists=True,
        help="Get CLI options from configuration file",
    ),
    help: Optional[bool] = Option(
        None,
        "--help",
        is_eager=True,
        callback=_help_callback,
        help="Show this message and exit",
    ),
) -> None:
    """
    Fix git conflicts for notebooks.

    Perform by getting the unmerged blobs from git index, comparing them and returning
    a valid notebook summarizing the differences - see
    [git docs](https://git-scm.com/docs/git-ls-files).
    """
    # NOTE: the docstring above is the `--help` text of `databooks fix`.
    filepaths = expand_paths(paths=paths, ignore=ignore)
    # Defensive check that `expand_paths` produced a list of paths.
    if filepaths is None:
        raise RuntimeError(
            f"Expected `filepaths` to be list of paths, got {filepaths}."
        )
    conflict_files = path2conflicts(nb_paths=filepaths)
    if not conflict_files:
        raise BadParameter(
            f"No conflicts found at {', '.join(str(p) for p in filepaths)}."
        )
    if interactive:
        raise NotImplementedError

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
    ) as progress:
        conflicts = progress.add_task(
            "[yellow]Resolving conflicts", total=len(conflict_files)
        )
        conflicts2nbs(
            conflict_files=conflict_files,
            meta_first=metadata_head,
            cells_first=cells_head,
            cell_fields_ignore=cell_fields_ignore,
            verbose=verbose,
            progress_callback=lambda: progress.update(conflicts, advance=1),
        )
    # The previous message omitted the noun ("Resolved the conflicts of 3!").
    logger.info(f"Resolved the conflicts of {len(conflict_files)} notebooks!")

376 

377 

@app.command(add_help_option=False)
def show(
    # Help text previously said "with conflicts" (copied from `fix`); `show` only
    # renders notebooks, so it now matches `meta`/`assert`.
    paths: List[Path] = Argument(..., is_eager=True, help="Path(s) of notebook files"),
    ignore: List[str] = Option(["!*"], help="Glob expression(s) of files to ignore"),
    pager: bool = Option(
        False, "--pager", "-p", help="Use pager instead of printing to terminal"
    ),
    verbose: bool = Option(
        False, "--verbose", "-v", help="Increase verbosity for debugging"
    ),
    multiple: bool = Option(False, "--yes", "-y", help="Show multiple files"),
    config: Optional[Path] = Option(
        None,
        "--config",
        "-c",
        is_eager=True,
        callback=_config_callback,
        resolve_path=True,
        exists=True,
        help="Get CLI options from configuration file",
    ),
    help: Optional[bool] = Option(
        None,
        "--help",
        is_eager=True,
        callback=_help_callback,
        help="Show this message and exit",
    ),
) -> None:
    """Show rich representation of notebook."""
    # NOTE(review): `verbose` is accepted but not used by this command body.
    nb_paths = _check_paths(paths=paths, ignore=ignore)
    # Ask before rendering several notebooks unless `--yes` was passed.
    if len(nb_paths) > 1 and not multiple:
        if not Confirm.ask(f"Show {len(nb_paths)} notebooks?"):
            raise Exit()

    print_nbs(nb_paths, use_pager=pager)

416 

417 

418@app.command() 

419def diff( 

420 ref_base: Optional[str] = Argument( 

421 None, help="Base reference (hash, branch, etc.), defaults to index" 

422 ), 

423 ref_remote: Optional[str] = Argument( 

424 None, help="Remote reference (hash, branch, etc.), defaults to working tree" 

425 ), 

426 paths: List[Path] = Argument( 

427 None, is_eager=True, help="Path(s) of notebook files to compare" 

428 ), 

429 ignore: List[str] = Option(["!*"], help="Glob expression(s) of files to ignore"), 

430 pager: bool = Option( 

431 False, "--pager", "-p", help="Use pager instead of printing to terminal" 

432 ), 

433 verbose: bool = Option( 

434 False, "--verbose", "-v", help="Increase verbosity for debugging" 

435 ), 

436 multiple: bool = Option(False, "--yes", "-y", help="Show multiple files"), 

437 config: Optional[Path] = Option( 

438 None, 

439 "--config", 

440 "-c", 

441 is_eager=True, 

442 callback=_config_callback, 

443 resolve_path=True, 

444 exists=True, 

445 help="Get CLI options from configuration file", 

446 ), 

447 help: Optional[bool] = Option( 

448 None, 

449 "--help", 

450 is_eager=True, 

451 callback=_help_callback, 

452 help="Show this message and exit", 

453 ), 

454) -> None: 

455 """ 

456 Show differences between notebooks. 

457 

458 This is similar to `git-diff`, but in practice it is a subset of `git-diff` 

459 features - only exception is that we cannot compare diffs between local files. That 

460 means we can compare files that are staged with other branches, hashes, etc., or 

461 compare the current directory with the current index. 

462 """ 

463 (ref_base, ref_remote), paths = _parse_paths(ref_base, ref_remote, paths=paths) 

464 diffs = get_nb_diffs( 

465 ref_base=ref_base, ref_remote=ref_remote, paths=paths, verbose=verbose 

466 ) 

467 if not diffs: 

468 logger.info("No notebook diffs found. Nothing to do.") 

469 raise Exit() 

470 if len(diffs) > 1 and not multiple: 

471 if not Confirm.ask(f"Show {len(diffs)} notebook diffs?"): 

472 raise Exit() 

473 print_diffs(diffs=diffs, use_pager=pager)