lean-explore 0.3.0-py3-none-any.whl → 1.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. lean_explore/__init__.py +14 -1
  2. lean_explore/api/__init__.py +12 -1
  3. lean_explore/api/client.py +64 -176
  4. lean_explore/cli/__init__.py +10 -1
  5. lean_explore/cli/data_commands.py +184 -489
  6. lean_explore/cli/display.py +171 -0
  7. lean_explore/cli/main.py +51 -608
  8. lean_explore/config.py +244 -0
  9. lean_explore/extract/__init__.py +5 -0
  10. lean_explore/extract/__main__.py +368 -0
  11. lean_explore/extract/doc_gen4.py +200 -0
  12. lean_explore/extract/doc_parser.py +499 -0
  13. lean_explore/extract/embeddings.py +369 -0
  14. lean_explore/extract/github.py +110 -0
  15. lean_explore/extract/index.py +316 -0
  16. lean_explore/extract/informalize.py +653 -0
  17. lean_explore/extract/package_config.py +59 -0
  18. lean_explore/extract/package_registry.py +45 -0
  19. lean_explore/extract/package_utils.py +105 -0
  20. lean_explore/extract/types.py +25 -0
  21. lean_explore/mcp/__init__.py +11 -1
  22. lean_explore/mcp/app.py +14 -46
  23. lean_explore/mcp/server.py +20 -35
  24. lean_explore/mcp/tools.py +71 -205
  25. lean_explore/models/__init__.py +9 -0
  26. lean_explore/models/search_db.py +76 -0
  27. lean_explore/models/search_types.py +53 -0
  28. lean_explore/search/__init__.py +32 -0
  29. lean_explore/search/engine.py +651 -0
  30. lean_explore/search/scoring.py +156 -0
  31. lean_explore/search/service.py +68 -0
  32. lean_explore/search/tokenization.py +71 -0
  33. lean_explore/util/__init__.py +28 -0
  34. lean_explore/util/embedding_client.py +92 -0
  35. lean_explore/util/logging.py +22 -0
  36. lean_explore/util/openrouter_client.py +63 -0
  37. lean_explore/util/reranker_client.py +187 -0
  38. {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/METADATA +32 -9
  39. lean_explore-1.0.1.dist-info/RECORD +43 -0
  40. {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/WHEEL +1 -1
  41. lean_explore-1.0.1.dist-info/entry_points.txt +2 -0
  42. lean_explore/cli/agent.py +0 -788
  43. lean_explore/cli/config_utils.py +0 -481
  44. lean_explore/defaults.py +0 -114
  45. lean_explore/local/__init__.py +0 -1
  46. lean_explore/local/search.py +0 -1050
  47. lean_explore/local/service.py +0 -479
  48. lean_explore/shared/__init__.py +0 -1
  49. lean_explore/shared/models/__init__.py +0 -1
  50. lean_explore/shared/models/api.py +0 -117
  51. lean_explore/shared/models/db.py +0 -396
  52. lean_explore-0.3.0.dist-info/RECORD +0 -26
  53. lean_explore-0.3.0.dist-info/entry_points.txt +0 -2
  54. {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/licenses/LICENSE +0 -0
  55. {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/top_level.txt +0 -0
lean_explore/extract/embeddings.py
@@ -0,0 +1,369 @@
+ """Generate embeddings for Lean declarations.
+
+ Reads declarations from the database and generates informalization embeddings
+ for semantic search.
+ """
+
+ import logging
+ import sqlite3
+ import struct
+ import time
+ from collections import deque
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ from rich.progress import (
+     BarColumn,
+     Progress,
+     ProgressColumn,
+     SpinnerColumn,
+     Task,
+     TaskProgressColumn,
+     TextColumn,
+     TimeRemainingColumn,
+ )
+ from rich.text import Text
+ from sqlalchemy import select
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
+
+ from lean_explore.config import Config
+ from lean_explore.models import Declaration
+ from lean_explore.util import EmbeddingClient
+
+ logger = logging.getLogger(__name__)
+
+
+ class RateColumn(ProgressColumn):
+     """Custom column showing embeddings per second over a rolling window."""
+
+     def __init__(self, window_seconds: int = 300):
+         """Initialize rate column.
+
+         Args:
+             window_seconds: Rolling window size in seconds for rate calculation
+         """
+         super().__init__()
+         self.window_seconds = window_seconds
+         self.history: deque[tuple[float, int]] = deque()
+         self.total_count = 0
+
+     def add_count(self, count: int) -> None:
+         """Add embedding count with timestamp."""
+         now = time.time()
+         self.history.append((now, count))
+         self.total_count += count
+         # Remove old entries outside window
+         cutoff = now - self.window_seconds
+         while self.history and self.history[0][0] < cutoff:
+             self.history.popleft()
+
+     def render(self, task: Task) -> Text:
+         """Render the rate column."""
+         if not self.history:
+             return Text("-- emb/s", style="cyan")
+
+         now = time.time()
+         cutoff = now - self.window_seconds
+         # Sum counts within window
+         window_count = sum(c for t, c in self.history if t >= cutoff)
+         # Calculate elapsed time in window
+         if self.history:
+             oldest_in_window = max(self.history[0][0], cutoff)
+             elapsed = now - oldest_in_window
+             if elapsed > 0:
+                 rate = window_count / elapsed
+                 return Text(f"{rate:.1f} emb/s", style="cyan")
+
+         return Text("-- emb/s", style="cyan")
+
+
+ # --- Data Classes ---
+
+
+ @dataclass
+ class EmbeddingCaches:
+     """Container for embedding caches.
+
+     Stores embeddings as raw bytes for efficiency. Use _deserialize_embedding()
+     to convert to list[float] when actually needed.
+     """
+
+     by_informalization: dict[str, bytes]
+
+
+ def _deserialize_embedding(data: bytes) -> list[float]:
+     """Convert raw binary embedding to list[float].
+
+     Args:
+         data: Binary embedding data (float32 packed)
+
+     Returns:
+         List of float values
+     """
+     num_floats = len(data) // 4
+     return list(struct.unpack(f"{num_floats}f", data))
+
+
+ # --- Cross-Database Cache Loading ---
+
+
+ def _discover_database_files() -> list[Path]:
+     """Discover all lean_explore.db files in data/ and cache/ directories.
+
+     Returns:
+         List of paths to discovered database files
+     """
+     database_files = []
+
+     # Search in data directory
+     data_dir = Config.DATA_DIRECTORY
+     if data_dir.exists():
+         database_files.extend(data_dir.rglob("lean_explore.db"))
+
+     # Search in cache directory
+     cache_dir = Config.CACHE_DIRECTORY
+     if cache_dir.exists():
+         database_files.extend(cache_dir.rglob("lean_explore.db"))
+
+     logger.info(f"Discovered {len(database_files)} database files")
+     return database_files
+
+
+ def _load_embedding_caches(database_files: list[Path]) -> EmbeddingCaches:
+     """Load embeddings from all discovered databases.
+
+     Builds a cache mapping informalization text to raw embedding bytes by scanning
+     all databases for declarations that have embeddings.
+
+     Uses sync sqlite3 directly to avoid SQLAlchemy ORM overhead and TypeDecorator
+     deserialization. Embeddings are stored as raw bytes and only deserialized
+     when actually used.
+
+     Args:
+         database_files: List of database file paths to scan
+
+     Returns:
+         EmbeddingCaches with cache dictionary populated (as bytes)
+     """
+     cache_by_informalization: dict[str, bytes] = {}
+
+     for db_path in database_files:
+         logger.info(f"Loading embedding cache from {db_path}")
+
+         try:
+             connection = sqlite3.connect(db_path)
+             cursor = connection.execute(
+                 """
+                 SELECT informalization, informalization_embedding
+                 FROM declarations
+                 WHERE informalization_embedding IS NOT NULL
+                 """
+             )
+
+             count = 0
+             for row in cursor:
+                 count += 1
+                 (informalization, informalization_embedding) = row
+
+                 # Cache informalization embedding
+                 if (
+                     informalization is not None
+                     and informalization not in cache_by_informalization
+                 ):
+                     cache_by_informalization[informalization] = (
+                         informalization_embedding
+                     )
+
+             connection.close()
+             logger.info(f"Loaded {count} declarations from {db_path}")
+
+         except Exception as e:
+             logger.warning(f"Failed to load embedding cache from {db_path}: {e}")
+             continue
+
+     logger.info(f"Total cache size - informalization: {len(cache_by_informalization)}")
+
+     return EmbeddingCaches(by_informalization=cache_by_informalization)
+
+
+ async def _get_declarations_needing_embeddings(
+     session: AsyncSession, limit: int | None
+ ) -> list[Declaration]:
+     """Get declarations that need informalization embeddings.
+
+     Only returns declarations that have an informalization but no embedding yet.
+
+     Args:
+         session: Async database session
+         limit: Maximum number of declarations to retrieve (None for all)
+
+     Returns:
+         List of declarations needing embeddings
+     """
+     stmt = select(Declaration).where(
+         Declaration.informalization.isnot(None),
+         Declaration.informalization_embedding.is_(None),
+     )
+     if limit:
+         stmt = stmt.limit(limit)
+     result = await session.execute(stmt)
+     return list(result.scalars().all())
+
+
+ async def _apply_cache_to_declarations(
+     session: AsyncSession,
+     declarations: list[Declaration],
+     caches: EmbeddingCaches,
+     commit_batch_size: int = 1000,
+ ) -> tuple[int, list[Declaration]]:
+     """Apply cached embeddings to declarations.
+
+     This is a fast first pass that applies all cache hits before generating
+     new embeddings, allowing the user to see exactly how many need generation.
+
+     Args:
+         session: Async database session
+         declarations: List of declarations to check against cache
+         caches: Embedding caches from cross-database loading
+         commit_batch_size: Number of updates to batch before committing
+
+     Returns:
+         Tuple of (cache_hits_count, list of declarations still needing generation)
+     """
+     cache_hits = 0
+     remaining: list[Declaration] = []
+     batch_count = 0
+
+     for declaration in declarations:
+         if not declaration.informalization:
+             continue
+
+         if declaration.informalization in caches.by_informalization:
+             declaration.informalization_embedding = _deserialize_embedding(
+                 caches.by_informalization[declaration.informalization]
+             )
+             cache_hits += 1
+             batch_count += 1
+
+             if batch_count >= commit_batch_size:
+                 await session.commit()
+                 batch_count = 0
+         else:
+             remaining.append(declaration)
+
+     if batch_count > 0:
+         await session.commit()
+
+     return cache_hits, remaining
+
+
+ async def _process_batch(
+     session: AsyncSession,
+     declarations: list[Declaration],
+     client: EmbeddingClient,
+ ) -> int:
+     """Process a batch of declarations and generate informalization embeddings.
+
+     Args:
+         session: Async database session
+         declarations: List of declarations to process (already filtered, no cache)
+         client: Embedding client for generating embeddings
+
+     Returns:
+         Number of embeddings generated
+     """
+     texts_to_embed = []
+     declarations_to_embed = []
+
+     for declaration in declarations:
+         if not declaration.informalization:
+             continue
+         if declaration.informalization_embedding is not None:
+             continue
+         texts_to_embed.append(declaration.informalization)
+         declarations_to_embed.append(declaration)
+
+     if texts_to_embed:
+         response = await client.embed(texts_to_embed)
+
+         for declaration, embedding in zip(declarations_to_embed, response.embeddings):
+             declaration.informalization_embedding = embedding
+
+         await session.commit()
+
+     return len(texts_to_embed)
+
+
+ async def generate_embeddings(
+     engine: AsyncEngine,
+     model_name: str,
+     batch_size: int = 128,
+     limit: int | None = None,
+     max_seq_length: int = 512,
+ ) -> None:
+     """Generate embeddings for all declarations.
+
+     Args:
+         engine: Async database engine
+         model_name: Name of the sentence transformer model to use
+         batch_size: Number of declarations to process in each batch (default 128)
+         limit: Maximum number of declarations to process (None for all)
+         max_seq_length: Maximum sequence length for tokenization (default 512).
+             Lower values reduce memory usage but may truncate long texts.
+     """
+     # Discover and load embedding caches from all existing databases
+     logger.info("Discovering existing databases for embedding cache...")
+     database_files = _discover_database_files()
+     caches = _load_embedding_caches(database_files)
+
+     async with AsyncSession(engine, expire_on_commit=False) as session:
+         declarations = await _get_declarations_needing_embeddings(session, limit)
+         logger.info(f"Found {len(declarations)} declarations needing embeddings")
+
+         if not declarations:
+             logger.info("No declarations to process")
+             return
+
+         # Phase 1: Apply all cache hits first
+         logger.info("Phase 1: Applying cached embeddings...")
+         cache_hits, remaining = await _apply_cache_to_declarations(
+             session, declarations, caches
+         )
+         logger.info(
+             f"Applied {cache_hits} embeddings from cache, "
+             f"{len(remaining)} remaining need generation"
+         )
+
+         if not remaining:
+             logger.info("All embeddings served from cache, no generation needed")
+             return
+
+         # Phase 2: Generate embeddings for remaining declarations
+         logger.info("Phase 2: Generating embeddings for remaining declarations...")
+         client = EmbeddingClient(model_name=model_name, max_length=max_seq_length)
+         logger.info(f"Using {client.model_name} on {client.device}")
+
+         total = len(remaining)
+         total_embeddings = 0
+         rate_column = RateColumn(window_seconds=60)
+         with Progress(
+             SpinnerColumn(),
+             TextColumn("[progress.description]{task.description}"),
+             BarColumn(),
+             TaskProgressColumn(),
+             rate_column,
+             TimeRemainingColumn(),
+         ) as progress:
+             task = progress.add_task("Generating embeddings", total=total)
+
+             for i in range(0, total, batch_size):
+                 batch = remaining[i : i + batch_size]
+                 count = await _process_batch(session, batch, client)
+                 total_embeddings += count
+                 rate_column.add_count(count)
+                 progress.update(task, advance=len(batch))
+
+         logger.info(
+             f"Embedding generation complete: {cache_hits} applied from cache, "
+             f"{total_embeddings} newly generated"
+         )
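For orientation, the new module is driven through a SQLAlchemy AsyncEngine. Below is a minimal sketch of calling generate_embeddings directly, assuming an aiosqlite-backed database; the database path and model name are illustrative placeholders, not values shipped with the package.

import asyncio

from sqlalchemy.ext.asyncio import create_async_engine

from lean_explore.extract.embeddings import generate_embeddings

# Hypothetical database path and embedding model; substitute your own.
engine = create_async_engine("sqlite+aiosqlite:///lean_explore.db")
asyncio.run(
    generate_embeddings(
        engine,
        model_name="BAAI/bge-base-en-v1.5",  # assumed model, not the package default
        batch_size=128,
    )
)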
lean_explore/extract/github.py
@@ -0,0 +1,110 @@
+ """GitHub utilities for fetching package metadata.
+
+ This module provides functions to interact with GitHub repositories
+ for fetching toolchain versions and release tags.
+ """
+
+ import json
+ import logging
+ import re
+ import urllib.request
+
+ logger = logging.getLogger(__name__)
+
+
+ def github_url_to_raw(git_url: str, branch: str, file_path: str) -> str:
+     """Convert GitHub repo URL to raw file URL.
+
+     Args:
+         git_url: GitHub repository URL (e.g., https://github.com/owner/repo)
+         branch: Branch or tag name
+         file_path: Path to file in repo
+
+     Returns:
+         Raw GitHub URL for the file.
+     """
+     match = re.search(r"github\.com/([^/]+)/([^/]+?)(?:\.git)?$", git_url)
+     if not match:
+         raise ValueError(f"Could not parse GitHub URL: {git_url}")
+     owner, repo = match.groups()
+     return f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{file_path}"
+
+
+ def fetch_lean_toolchain(git_url: str, ref: str = "main") -> str:
+     """Fetch lean-toolchain content from a GitHub repository.
+
+     Args:
+         git_url: GitHub repository URL
+         ref: Branch name or tag (default: main)
+
+     Returns:
+         Content of the lean-toolchain file (e.g., 'leanprover/lean4:v4.27.0')
+     """
+     raw_url = github_url_to_raw(git_url, ref, "lean-toolchain")
+     logger.info(f"Fetching lean-toolchain from {raw_url}")
+
+     try:
+         with urllib.request.urlopen(raw_url, timeout=30) as response:
+             return response.read().decode("utf-8").strip()
+     except Exception as e:
+         raise RuntimeError(f"Failed to fetch lean-toolchain from {raw_url}: {e}") from e
+
+
+ def fetch_latest_tag(git_url: str) -> str:
+     """Fetch the latest semver tag from a GitHub repository.
+
+     Args:
+         git_url: GitHub repository URL
+
+     Returns:
+         Latest tag name (e.g., 'v4.26.0')
+     """
+     match = re.search(r"github\.com/([^/]+)/([^/]+?)(?:\.git)?$", git_url)
+     if not match:
+         raise ValueError(f"Could not parse GitHub URL: {git_url}")
+     owner, repo = match.groups()
+
+     api_url = f"https://api.github.com/repos/{owner}/{repo}/tags?per_page=100"
+     logger.info(f"Fetching tags from {api_url}")
+
+     try:
+         request = urllib.request.Request(
+             api_url,
+             headers={"Accept": "application/vnd.github.v3+json"},
+         )
+         with urllib.request.urlopen(request, timeout=30) as response:
+             tags = json.loads(response.read().decode("utf-8"))
+     except Exception as e:
+         raise RuntimeError(f"Failed to fetch tags from {api_url}: {e}") from e
+
+     if not tags:
+         raise RuntimeError(f"No tags found for {git_url}")
+
+     # Filter to semver-like tags (v*.*.*)
+     semver_pattern = re.compile(r"^v?\d+\.\d+\.\d+")
+     semver_tags = [t["name"] for t in tags if semver_pattern.match(t["name"])]
+
+     if not semver_tags:
+         return tags[0]["name"]
+
+     def semver_key(tag: str) -> list[int]:
+         return [int(x) for x in re.findall(r"\d+", tag)]
+
+     semver_tags.sort(key=semver_key, reverse=True)
+     return semver_tags[0]
+
+
+ def extract_lean_version(toolchain: str) -> str:
+     """Extract version from lean-toolchain content.
+
+     Args:
+         toolchain: Toolchain content like 'leanprover/lean4:v4.27.0'
+             or 'leanprover/lean4:v4.28.0-rc1'.
+
+     Returns:
+         Version string like 'v4.27.0' or 'v4.28.0-rc1'
+     """
+     match = re.search(r"v\d+\.\d+\.\d+(?:-rc\d+)?", toolchain)
+     if not match:
+         raise ValueError(f"Could not extract version from toolchain: {toolchain}")
+     return match.group()
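A usage sketch of the helpers above; the mathlib4 URL is just an example repository and the printed values are illustrative, not guaranteed outputs.

from lean_explore.extract.github import (
    extract_lean_version,
    fetch_latest_tag,
    fetch_lean_toolchain,
)

repo = "https://github.com/leanprover-community/mathlib4"
tag = fetch_latest_tag(repo)  # latest semver-like tag, e.g. 'v4.26.0'
toolchain = fetch_lean_toolchain(repo, ref=tag)  # lean-toolchain contents at that tag
print(extract_lean_version(toolchain))  # e.g. 'v4.27.0' or 'v4.28.0-rc1'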