lean_explore-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,506 @@
# src/lean_explore/cli/data_commands.py

"""Provides CLI commands for managing local Lean exploration data toolchains.

This module includes functions to fetch toolchain data (database, FAISS index, etc.)
from a remote source (Cloudflare R2), verify its integrity, decompress it,
and place it in the appropriate local directory for the application to use.
"""

import gzip
import hashlib
import json
import pathlib
import shutil
from typing import Any, Dict, List, Optional

import requests
import typer
from rich.console import Console
from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    TextColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)

from lean_explore import defaults  # For R2 URLs and local paths

# Typer application for data commands
app = typer.Typer(
    name="data",
    help="Manage local data toolchains for Lean Explore (e.g., download, list, "
    "select).",
    no_args_is_help=True,
)

# Initialize console for rich output
console = Console()


# --- Internal Helper Functions ---


def _fetch_remote_json(url: str, timeout: int = 10) -> Optional[Dict[str, Any]]:
    """Fetches JSON data from a remote URL.

    Args:
        url: The URL to fetch JSON from.
        timeout: Request timeout in seconds.

    Returns:
        A dictionary parsed from JSON, or None if an error occurs.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json()
    except requests.exceptions.RequestException as e:
        console.print(f"[bold red]Error fetching manifest from {url}: {e}[/bold red]")
    except json.JSONDecodeError as e:
        console.print(f"[bold red]Error parsing JSON from {url}: {e}[/bold red]")
    return None


def _resolve_toolchain_version_info(
    manifest_data: Dict[str, Any], requested_identifier: str
) -> Optional[Dict[str, Any]]:
    """Resolves a requested version identifier to its concrete toolchain info.

    Handles aliases like "stable" by looking up "default_toolchain" in the manifest.

    Args:
        manifest_data: The parsed manifest dictionary.
        requested_identifier: The version string requested by the user (e.g., "stable",
            "0.1.0").

    Returns:
        The dictionary containing information for the resolved concrete toolchain
        version, or None if not found or resolvable.
    """
    toolchains_dict = manifest_data.get("toolchains")
    if not isinstance(toolchains_dict, dict):
        console.print(
            "[bold red]Error: Manifest is missing 'toolchains' dictionary.[/bold red]"
        )
        return None

    target_version_key = requested_identifier
    if requested_identifier.lower() == "stable":
        stable_alias_target = manifest_data.get("default_toolchain")
        if not stable_alias_target:
            console.print(
                "[bold red]Error: Manifest does not define a 'default_toolchain' "
                "for 'stable'.[/bold red]"
            )
            return None
        target_version_key = stable_alias_target
        console.print(
            f"Note: 'stable' currently points to version '{target_version_key}'."
        )

    version_info = toolchains_dict.get(target_version_key)
    if not version_info:
        console.print(
            f"[bold red]Error: Version '{target_version_key}' (resolved from "
            f"'{requested_identifier}') not found in the manifest.[/bold red]"
        )
        return None

    # Store the resolved key for easier access by the caller
    version_info["_resolved_key"] = target_version_key
    return version_info

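# Illustrative manifest sketch (hypothetical values, not the published manifest):
# the resolver above and the `fetch` command below only read the keys shown here,
# so a compatible manifest could look like the following.
#
#     {
#         "default_toolchain": "0.1.0",
#         "toolchains": {
#             "0.1.0": {
#                 "description": "Example toolchain",
#                 "assets_base_path_r2": "assets/0.1.0/",
#                 "files": [
#                     {
#                         "local_name": "example.db",
#                         "remote_name": "example.db.gz",
#                         "sha256": "<hex digest of the compressed .gz file>",
#                         "size_bytes_compressed": 123456
#                     }
#                 ]
#             }
#         }
#     }
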
def _download_file_with_progress(
    url: str,
    destination_path: pathlib.Path,
    description: str,
    expected_size_bytes: Optional[int] = None,
    timeout: int = 30,
) -> bool:
    """Downloads a file from a URL with a progress bar, saving raw bytes.

    This function attempts to download the raw bytes from the server,
    especially to handle pre-gzipped files correctly without interference
    from the requests library's automatic content decoding.

    Args:
        url: The URL to download from.
        destination_path: The local path to save the downloaded file.
        description: A description of the file for the progress bar.
        expected_size_bytes: The expected size of the file in bytes for progress
            tracking. This should typically be the size of the compressed file if
            downloading a gzipped file.
        timeout: Request timeout in seconds for establishing connection and for read.

    Returns:
        True if download was successful, False otherwise.
    """
    console.print(f"Downloading [cyan]{description}[/cyan] from {url}...")
    try:
        # By not setting 'Accept-Encoding', we let the server decide if it wants
        # to send a Content-Encoding. We will handle the raw stream.
        r = requests.get(url, stream=True, timeout=timeout)
        try:
            r.raise_for_status()

            # Content-Length should refer to the size of the entity on the wire.
            # If the server sends Content-Encoding: gzip, this should be the
            # gzipped size.
            total_size_from_header = int(r.headers.get("content-length", 0))

            display_size = total_size_from_header
            if expected_size_bytes is not None:
                if (
                    total_size_from_header > 0
                    and expected_size_bytes != total_size_from_header
                ):
                    console.print(
                        f"[yellow]Warning: Expected size for "
                        f"[cyan]{description}[/cyan] "
                        f"is {expected_size_bytes} bytes, but server reports "
                        f"Content-Length: {total_size_from_header} bytes. Using server "
                        "reported size for progress bar if available, otherwise "
                        "expected size.[/yellow]"
                    )
                # Prefer expected_size_bytes if it's provided and server doesn't send
                # Content-Length or if we want to strictly adhere to manifest size for
                # progress. However, for live progress, server's content-length is
                # usually more accurate for what's being transferred.
                if total_size_from_header == 0:
                    # Server didn't provide a Content-Length header.
                    display_size = expected_size_bytes
            elif total_size_from_header == 0 and expected_size_bytes is None:
                # Cannot determine size for progress bar
                display_size = None

            with Progress(
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                DownloadColumn(),
                TransferSpeedColumn(),
                TimeRemainingColumn(),
                console=console,
                transient=False,
            ) as progress:
                task_id = progress.add_task(description, total=display_size)
                destination_path.parent.mkdir(parents=True, exist_ok=True)
                downloaded_bytes_count = 0
                with open(destination_path, "wb") as f:
                    # Iterate over the raw stream to prevent requests from
                    # auto-decompressing based on Content-Encoding headers.
                    for chunk in r.raw.stream(decode_content=False, amt=8192):
                        f.write(chunk)
                        downloaded_bytes_count += len(chunk)
                        progress.update(task_id, advance=len(chunk))
        finally:
            r.close()

        # Sanity check after download
        actual_downloaded_size = destination_path.stat().st_size
        if (
            total_size_from_header > 0
            and actual_downloaded_size != total_size_from_header
        ):
            # This might indicate an incomplete download if not all bytes were written.
            console.print(
                f"[orange3]Warning: For [cyan]{description}[/cyan], downloaded size "
                f"({actual_downloaded_size} bytes) differs from Content-Length header "
                f"({total_size_from_header} bytes). Checksum verification will be the "
                "final arbiter.[/orange3]"
            )
        elif (
            expected_size_bytes is not None
            and actual_downloaded_size != expected_size_bytes
        ):
            console.print(
                f"[orange3]Warning: For [cyan]{description}[/cyan], downloaded size "
                f"({actual_downloaded_size} bytes) differs from manifest expected "
                f"size ({expected_size_bytes} bytes). Checksum verification will be "
                "the final arbiter.[/orange3]"
            )

        console.print(
            f"[green]Downloaded raw content for {description} successfully.[/green]"
        )
        return True
    except requests.exceptions.RequestException as e:
        console.print(f"[bold red]Error downloading {description}: {e}[/bold red]")
    except OSError as e:
        console.print(f"[bold red]Error writing {description} to disk: {e}[/bold red]")
    except Exception as e:  # Catch any other unexpected errors during download
        console.print(
            f"[bold red]An unexpected error occurred during download of {description}:"
            f" {e}[/bold red]"
        )

    if destination_path.exists():
        destination_path.unlink(missing_ok=True)
    return False

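# Hypothetical usage of the download helper above (placeholder URL, path, and size;
# in practice the real values are derived from the manifest inside `fetch` below):
#
#     _download_file_with_progress(
#         url="https://assets.example.com/assets/0.1.0/example.db.gz",
#         destination_path=pathlib.Path("/tmp/example.db.gz"),
#         description="example.db",
#         expected_size_bytes=123456,
#     )
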
def _verify_sha256_checksum(file_path: pathlib.Path, expected_checksum: str) -> bool:
    """Verifies the SHA256 checksum of a file.

    Args:
        file_path: Path to the file to verify.
        expected_checksum: The expected SHA256 checksum string (hex digest).

    Returns:
        True if the checksum matches, False otherwise.
    """
    console.print(f"Verifying checksum for [cyan]{file_path.name}[/cyan]...")
    sha256_hash = hashlib.sha256()
    try:
        with open(file_path, "rb") as f:
            # Read and update the hash in blocks of 4K
            for byte_block in iter(lambda: f.read(4096), b""):
                sha256_hash.update(byte_block)
        calculated_checksum = sha256_hash.hexdigest()
        if calculated_checksum == expected_checksum.lower():
            console.print(f"[green]Checksum verified for {file_path.name}.[/green]")
            return True
        else:
            console.print(
                f"[bold red]Checksum mismatch for {file_path.name}:[/bold red]\n"
                f" Expected: {expected_checksum.lower()}\n"
                f" Got: {calculated_checksum}"
            )
            return False
    except OSError as e:
        console.print(
            "[bold red]Error reading file "
            f"{file_path.name} for checksum: {e}[/bold red]"
        )
        return False

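# In `fetch` below this check runs against the downloaded .gz file, so the manifest's
# `sha256` must be the digest of the compressed asset. A hypothetical one-liner that
# produces such a digest for a manifest entry (placeholder file name; reads the whole
# file into memory, unlike the streaming helper above):
#
#     hashlib.sha256(pathlib.Path("example.db.gz").read_bytes()).hexdigest()
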
def _decompress_gzipped_file(
    gzipped_file_path: pathlib.Path, output_file_path: pathlib.Path
) -> bool:
    """Decompresses a .gz file.

    Args:
        gzipped_file_path: Path to the .gz file.
        output_file_path: Path to save the decompressed output.

    Returns:
        True if decompression was successful, False otherwise.
    """
    console.print(
        f"Decompressing [cyan]{gzipped_file_path.name}[/cyan] to "
        f"{output_file_path.name}..."
    )
    try:
        output_file_path.parent.mkdir(parents=True, exist_ok=True)
        with gzip.open(gzipped_file_path, "rb") as f_in:
            with open(output_file_path, "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)
        console.print(
            f"[green]Decompressed {gzipped_file_path.name} successfully.[/green]"
        )
        return True
    except (OSError, gzip.BadGzipFile, EOFError) as e:
        console.print(
            f"[bold red]Error decompressing {gzipped_file_path.name}: {e}[/bold red]"
        )
        if output_file_path.exists():  # Clean up partial decompression
            output_file_path.unlink(missing_ok=True)
        return False


# --- CLI Command Functions ---


@app.callback()
def main() -> None:
    """Lean-Explore data CLI.

    This callback exists only to prevent Typer from treating the first
    sub-command as a *default* command when there is otherwise just one.
    """
    pass


@app.command()
def fetch(
    version: Optional[str] = typer.Argument(
        None,
        help=(
            "The toolchain version to fetch (e.g., 'stable', '0.1.0'). "
            "'stable' will attempt to use the 'default_toolchain' from the manifest."
        ),
        show_default=False,
    ),
) -> None:
    """Fetches and installs a specified data version from the remote repository.

    Downloads necessary assets like the database and FAISS index, verifies their
    integrity via SHA256 checksums, decompresses them, and places them into the
    appropriate local directory (e.g., ~/.lean_explore/data/toolchains/<version>/).
    """
    if version is None:
        version = "stable"

    console.rule(
        f"[bold blue]Fetching Lean Explore Data Toolchain: {version}[/bold blue]"
    )

    # 1. Fetch and Parse Manifest
    console.print(f"Fetching data manifest from {defaults.R2_MANIFEST_DEFAULT_URL}...")
    manifest_data = _fetch_remote_json(defaults.R2_MANIFEST_DEFAULT_URL)
    if not manifest_data:
        console.print(
            "[bold red]Failed to fetch or parse the manifest. Aborting.[/bold red]"
        )
        raise typer.Exit(code=1)
    console.print("[green]Manifest fetched successfully.[/green]")

    # 2. Resolve Target Version from Manifest
    version_info = _resolve_toolchain_version_info(manifest_data, version)
    if not version_info:
        # _resolve_toolchain_version_info already prints detailed errors
        raise typer.Exit(code=1)

    resolved_version_key = version_info["_resolved_key"]  # Key like "0.1.0"
    console.print(
        f"Processing toolchain version: [bold yellow]{resolved_version_key}"
        "[/bold yellow] "
        f"('{version_info.get('description', 'N/A')}')"
    )

    # 3. Determine Local Paths and Ensure Directory Exists
    local_version_dir = defaults.LEAN_EXPLORE_TOOLCHAINS_BASE_DIR / resolved_version_key
    try:
        local_version_dir.mkdir(parents=True, exist_ok=True)
        console.print(f"Data will be stored in: [dim]{local_version_dir}[/dim]")
    except OSError as e:
        console.print(
            f"[bold red]Error creating local directory {local_version_dir}: {e}"
            "[/bold red]"
        )
        raise typer.Exit(code=1)

    # 4. Process Files for the Target Version
    files_to_process: List[Dict[str, Any]] = version_info.get("files", [])
    if not files_to_process:
        console.print(
            f"[yellow]No files listed in the manifest for version "
            f"'{resolved_version_key}'. Nothing to do.[/yellow]"
        )
        raise typer.Exit(code=0)

    all_files_successful = True
    for file_entry in files_to_process:
        local_name = file_entry.get("local_name")
        remote_name = file_entry.get("remote_name")
        expected_checksum = file_entry.get("sha256")
        expected_size_compressed = file_entry.get(
            "size_bytes_compressed"
        )  # This is the size of the .gz file
        assets_r2_path_prefix = version_info.get(
            "assets_base_path_r2", ""
        )  # e.g., "assets/0.1.0/"

        if not all([local_name, remote_name, expected_checksum]):
            console.print(
                f"[bold red]Skipping invalid file entry in manifest: {file_entry}. "
                "Missing local name, remote name, or checksum.[/bold red]"
            )
            all_files_successful = False
            continue

        console.rule(f"[bold cyan]Processing: {local_name}[/bold cyan]")

        final_local_path = local_version_dir / local_name
        temp_download_path = local_version_dir / remote_name  # Path for the .gz file

        remote_url = (
            defaults.R2_ASSETS_BASE_URL.rstrip("/")
            + "/"
            + assets_r2_path_prefix.strip("/")
            + "/"
            + remote_name
        )

        if final_local_path.exists():
            console.print(
                f"[yellow]'{local_name}' already exists at {final_local_path}. "
                "Skipping download.[/yellow]\n"
                f"[dim] (Checksum verification for existing files is not yet "
                "implemented. Delete the file to re-download.)[/dim]"
            )
            continue

        if temp_download_path.exists():
            temp_download_path.unlink(missing_ok=True)

        download_ok = _download_file_with_progress(
            remote_url,
            temp_download_path,
            description=local_name,
            expected_size_bytes=expected_size_compressed,
        )
        if not download_ok:
            all_files_successful = False
            console.print(
                f"[bold red]Failed to download {remote_name}. Halting for this file."
                "[/bold red]"
            )
            continue

        checksum_ok = _verify_sha256_checksum(temp_download_path, expected_checksum)
        if not checksum_ok:
            all_files_successful = False
            console.print(
                f"[bold red]Checksum verification failed for {remote_name}. "
                "Deleting downloaded file.[/bold red]"
            )
            temp_download_path.unlink(missing_ok=True)
            continue

        decompress_ok = _decompress_gzipped_file(temp_download_path, final_local_path)
        if not decompress_ok:
            all_files_successful = False
            console.print(
                f"[bold red]Failed to decompress {remote_name}. "
                "Cleaning up temporary files.[/bold red]"
            )
            if final_local_path.exists():
                final_local_path.unlink(missing_ok=True)
            if temp_download_path.exists():
                # Ensure the .gz is also removed on decompression failure
                temp_download_path.unlink(missing_ok=True)
            continue

        if temp_download_path.exists():
            temp_download_path.unlink()
        console.print(
            f"[green]Successfully installed and verified {local_name} to "
            f"{final_local_path}[/green]\n"
        )

    console.rule()
    if all_files_successful:
        console.print(
            f"[bold green]Toolchain '{resolved_version_key}' fetch process completed "
            "successfully.[/bold green]"
        )
    else:
        console.print(
            f"[bold orange3]Toolchain '{resolved_version_key}' fetch process completed "
            "with some errors. Please review the output above.[/bold orange3]"
        )
        raise typer.Exit(code=1)

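# Sketch of the resulting on-disk layout after a successful `fetch` (file names are
# placeholders; actual names come from the manifest's `local_name` entries, and the
# intermediate .gz download is deleted once decompression succeeds):
#
#     ~/.lean_explore/data/toolchains/0.1.0/
#         example.db    <- decompressed from example.db.gz
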
if __name__ == "__main__":
    # This allows testing `python -m lean_explore.cli.data_commands fetch stable`.
    # For actual CLI use, this app will be mounted in `main.py`.
    app()