Coverage for databooks/cli.py: 93%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

72 statements  

1"""Main CLI application.""" 

2from itertools import compress 

3from pathlib import Path 

4from typing import List, Optional 

5 

6import tomli 

7from rich.progress import ( 

8 BarColumn, 

9 Progress, 

10 SpinnerColumn, 

11 TextColumn, 

12 TimeElapsedColumn, 

13) 

14from typer import Argument, BadParameter, Context, Exit, Option, Typer, echo 

15 

16from databooks.common import expand_paths 

17from databooks.config import TOML_CONFIG_FILE, get_config 

18from databooks.conflicts import conflicts2nbs, path2conflicts 

19from databooks.logging import get_logger 

20from databooks.metadata import clear_all 

21from databooks.version import __version__ 

22 

23logger = get_logger(__file__) 

24 

25app = Typer() 

26 

27 

28def _version_callback(show_version: bool) -> None: 

29 """Return application version.""" 

30 if show_version: 

31 echo("databooks version: " + __version__) 

32 raise Exit() 

33 

34 

def _help_callback(ctx: Context, show_help: Optional[bool]) -> None:
    """
    Print the help text of the current command and exit, if it was requested.

    Stands in for the built-in help option so it can be declared eagerly.
    """
    if not show_help:
        return
    help_text = ctx.command.get_help(ctx)
    echo(help_text)
    raise Exit()

40 

41 

def _config_callback(ctx: Context, config_path: Optional[Path]) -> Optional[Path]:
    """
    Get config file and inject values into context to override default args.

    When no configuration file is passed explicitly, one is looked up with
    `get_config`, starting from the `paths` already parsed into the context. The
    values under `[tool.databooks.<command name>]` are merged into
    `ctx.default_map`, with dashes in the keys replaced by underscores.
    :param ctx: Context of the command being invoked
    :param config_path: Configuration file passed via `--config`, if any
    :return: Path of the configuration file that was used, or `None` if there is none
    """
    target_paths = expand_paths(
        paths=[Path(p) for p in ctx.params.get("paths", ())], rglob="*"
    )
    # Only search for a configuration file when none was given explicitly and
    # there are target paths to search from
    if config_path is None and target_paths:
        config_path = get_config(
            target_paths=target_paths,
            config_filename=TOML_CONFIG_FILE,
        )
    logger.debug(f"Loading config file from: {config_path}")

    ctx.default_map = ctx.default_map or {}  # initialize defaults

    if config_path is not None:  # config may not be specified
        # TOML documents are UTF-8 by specification. Parsing the decoded text with
        # `tomli.loads` avoids handing a text-mode file handle to `tomli.load`,
        # which newer `tomli` releases reject (binary handles only).
        toml_doc = tomli.loads(config_path.read_text(encoding="utf-8"))
        conf = toml_doc.get("tool", {}).get("databooks", {}).get(ctx.command.name, {})
        # Merge configuration
        ctx.default_map.update({k.replace("-", "_"): v for k, v in conf.items()})
    return config_path

70 

71 

@app.callback()
def callback(  # noqa: D103
    version: Optional[bool] = Option(
        None,
        "--version",
        callback=_version_callback,
        is_eager=True,
    ),
) -> None:
    """CLI tool to resolve git conflicts and remove metadata in notebooks."""
    # Intentionally empty: this function only hosts the application-level
    # `--version` option, whose work is done by its eager callback.

79 

80 

@app.command(add_help_option=False)
def meta(
    paths: List[Path] = Argument(..., is_eager=True, help="Path(s) of notebook files"),
    ignore: List[str] = Option(["!*"], help="Glob expression(s) of files to ignore"),
    prefix: str = Option("", help="Prefix to add to filepath when writing files"),
    suffix: str = Option("", help="Suffix to add to filepath when writing files"),
    rm_outs: bool = Option(False, help="Whether to remove cell outputs"),
    rm_exec: bool = Option(True, help="Whether to remove the cell execution counts"),
    nb_meta_keep: List[str] = Option([], help="Notebook metadata fields to keep"),
    cell_meta_keep: List[str] = Option([], help="Cells metadata fields to keep"),
    cell_fields_keep: List[str] = Option(
        [],
        help="Other (excluding `execution_counts` and `outputs`) cell fields to keep",
    ),
    overwrite: bool = Option(
        False, "--overwrite", "-w", help="Confirm overwrite of files"
    ),
    check: bool = Option(
        False,
        "--check",
        help="Don't write files but check whether there is unwanted metadata",
    ),
    verbose: bool = Option(
        False, "--verbose", "-v", help="Log processed files in console"
    ),
    config: Optional[Path] = Option(
        None,
        "--config",
        "-c",
        is_eager=True,
        callback=_config_callback,
        resolve_path=True,
        exists=True,
        help="Get CLI options from configuration file",
    ),
    help: Optional[bool] = Option(
        None, is_eager=True, callback=_help_callback, help="Show this message and exit"
    ),
) -> None:
    """Clear both notebook and cell metadata."""
    # Accept only paths without extension (directories/globs) or notebook files
    for path in paths:
        if path.suffix not in ("", ".ipynb"):
            raise BadParameter(
                "Expected either notebook files, a directory or glob expression."
            )

    nb_paths = expand_paths(paths=paths, ignore=ignore)
    if not nb_paths:
        logger.info(f"No notebooks found in {paths}. Nothing to do.")
        raise Exit()

    # Without prefix and suffix the write paths equal the read paths, so
    # overwriting has to be confirmed (unless we are only checking)
    in_place = not (prefix or suffix)
    if in_place and not check:
        if overwrite:
            logger.warning(f"{len(nb_paths)} files will be overwritten")
        else:
            raise BadParameter(
                "No prefix nor suffix were passed."
                " Please specify `--overwrite` or `-w` to overwrite files."
            )

    write_paths = [
        nb.parent / f"{prefix}{nb.stem}{suffix}{nb.suffix}" for nb in nb_paths
    ]
    # `outputs` and `execution_count` are kept unless their removal was requested
    cell_fields_keep = [
        *compress(("outputs", "execution_count"), (not rm_outs, not rm_exec)),
        *cell_fields_keep,
    ]

    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
    )
    with Progress(*columns) as progress:
        task_id = progress.add_task("[yellow]Removing metadata", total=len(nb_paths))

        def _advance() -> None:
            """Move the progress bar one step forward."""
            progress.update(task_id, advance=1)

        are_equal = clear_all(
            read_paths=nb_paths,
            write_paths=write_paths,
            progress_callback=_advance,
            notebook_metadata_keep=nb_meta_keep,
            cell_metadata_keep=cell_meta_keep,
            cell_fields_keep=cell_fields_keep,
            check=check,
            verbose=verbose,
        )

    n_files = len(are_equal)
    n_changed = sum(not eq for eq in are_equal)
    if not check:
        logger.info(
            f"The metadata of {n_changed} out of {n_files} notebooks were removed!"
        )
        return
    if all(are_equal):
        logger.info("No unwanted metadata!")
        return
    # Checking only: report the offending notebooks through a failing exit code
    logger.info(f"Found unwanted metadata in {n_changed} out of {n_files} files")
    raise Exit(code=1)

176 

177 

@app.command(add_help_option=False)
def fix(
    paths: List[Path] = Argument(
        ..., is_eager=True, help="Path(s) of notebook files with conflicts"
    ),
    ignore: List[str] = Option(["!*"], help="Glob expression(s) of files to ignore"),
    metadata_head: bool = Option(
        True, help="Whether or not to keep the metadata from the head/current notebook"
    ),
    cells_head: Optional[bool] = Option(
        None,
        help="Whether to keep the cells from the head/base notebook. Omit to keep both",
    ),
    cell_fields_ignore: List[str] = Option(
        [
            "id",
            "execution_count",
        ],
        help="Cell fields to remove before comparing cells",
    ),
    interactive: bool = Option(
        False,
        "--interactive",
        "-i",
        help="Interactively resolve the conflicts (not implemented)",
    ),
    verbose: bool = Option(False, help="Log processed files in console"),
    config: Optional[Path] = Option(
        None,
        "--config",
        "-c",
        is_eager=True,
        callback=_config_callback,
        resolve_path=True,
        exists=True,
        help="Get CLI options from configuration file",
    ),
    help: Optional[bool] = Option(
        None, is_eager=True, callback=_help_callback, help="Show this message and exit"
    ),
) -> None:
    """
    Fix git conflicts for notebooks.

    Perform by getting the unmerged blobs from git index, comparing them and returning
    a valid notebook summarizing the differences - see
    [git docs](https://git-scm.com/docs/git-ls-files).
    """
    # NOTE: the docstring above is kept verbatim since Typer shows it as the
    # command's help text; parameters are documented by their `help=` strings.
    filepaths = expand_paths(paths=paths, ignore=ignore)
    conflict_files = path2conflicts(nb_paths=filepaths)
    if not conflict_files:
        raise BadParameter(
            f"No conflicts found at {', '.join(str(p) for p in filepaths)}."
        )
    if interactive:  # option is exposed in the CLI but has no implementation yet
        raise NotImplementedError

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
    ) as progress:
        # Label describes this command's work (it used to read "Removing metadata")
        conflicts = progress.add_task(
            "[yellow]Resolving conflicts", total=len(conflict_files)
        )
        conflicts2nbs(
            conflict_files=conflict_files,
            meta_first=metadata_head,
            cells_first=cells_head,
            cell_fields_ignore=cell_fields_ignore,
            verbose=verbose,
            # advance the bar by one step per callback invocation
            progress_callback=lambda: progress.update(conflicts, advance=1),
        )
    logger.info(f"Resolved the conflicts of {len(conflict_files)} files!")

254 

255 

@app.command()
def diff() -> None:
    """Show differences between notebooks (not implemented)."""
    # Placeholder: the sub-command is registered but has no behaviour yet
    raise NotImplementedError()