lean-explore 0.2.2__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lean_explore/__init__.py +14 -1
- lean_explore/api/__init__.py +12 -1
- lean_explore/api/client.py +60 -80
- lean_explore/cli/__init__.py +10 -1
- lean_explore/cli/data_commands.py +157 -479
- lean_explore/cli/display.py +171 -0
- lean_explore/cli/main.py +51 -608
- lean_explore/config.py +244 -0
- lean_explore/extract/__init__.py +5 -0
- lean_explore/extract/__main__.py +368 -0
- lean_explore/extract/doc_gen4.py +200 -0
- lean_explore/extract/doc_parser.py +499 -0
- lean_explore/extract/embeddings.py +371 -0
- lean_explore/extract/github.py +110 -0
- lean_explore/extract/index.py +317 -0
- lean_explore/extract/informalize.py +653 -0
- lean_explore/extract/package_config.py +59 -0
- lean_explore/extract/package_registry.py +45 -0
- lean_explore/extract/package_utils.py +105 -0
- lean_explore/extract/types.py +25 -0
- lean_explore/mcp/__init__.py +11 -1
- lean_explore/mcp/app.py +14 -46
- lean_explore/mcp/server.py +20 -35
- lean_explore/mcp/tools.py +70 -177
- lean_explore/models/__init__.py +9 -0
- lean_explore/models/search_db.py +76 -0
- lean_explore/models/search_types.py +53 -0
- lean_explore/search/__init__.py +32 -0
- lean_explore/search/engine.py +655 -0
- lean_explore/search/scoring.py +156 -0
- lean_explore/search/service.py +68 -0
- lean_explore/search/tokenization.py +71 -0
- lean_explore/util/__init__.py +28 -0
- lean_explore/util/embedding_client.py +92 -0
- lean_explore/util/logging.py +22 -0
- lean_explore/util/openrouter_client.py +63 -0
- lean_explore/util/reranker_client.py +189 -0
- {lean_explore-0.2.2.dist-info → lean_explore-1.0.0.dist-info}/METADATA +55 -10
- lean_explore-1.0.0.dist-info/RECORD +43 -0
- {lean_explore-0.2.2.dist-info → lean_explore-1.0.0.dist-info}/WHEEL +1 -1
- lean_explore-1.0.0.dist-info/entry_points.txt +2 -0
- lean_explore/cli/agent.py +0 -781
- lean_explore/cli/config_utils.py +0 -481
- lean_explore/defaults.py +0 -114
- lean_explore/local/__init__.py +0 -1
- lean_explore/local/search.py +0 -1050
- lean_explore/local/service.py +0 -392
- lean_explore/shared/__init__.py +0 -1
- lean_explore/shared/models/__init__.py +0 -1
- lean_explore/shared/models/api.py +0 -117
- lean_explore/shared/models/db.py +0 -396
- lean_explore-0.2.2.dist-info/RECORD +0 -26
- lean_explore-0.2.2.dist-info/entry_points.txt +0 -2
- {lean_explore-0.2.2.dist-info → lean_explore-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {lean_explore-0.2.2.dist-info → lean_explore-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,653 @@
|
|
|
1
|
+
"""Generate informal natural language descriptions for Lean declarations.
|
|
2
|
+
|
|
3
|
+
Reads declarations from the database, generates informal descriptions using
|
|
4
|
+
an LLM via OpenRouter, and updates the informalization field.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
from collections import defaultdict
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
from rich.progress import (
|
|
15
|
+
BarColumn,
|
|
16
|
+
Progress,
|
|
17
|
+
SpinnerColumn,
|
|
18
|
+
TaskProgressColumn,
|
|
19
|
+
TextColumn,
|
|
20
|
+
TimeRemainingColumn,
|
|
21
|
+
)
|
|
22
|
+
from sqlalchemy import select, update
|
|
23
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
|
24
|
+
|
|
25
|
+
from lean_explore.config import Config
|
|
26
|
+
from lean_explore.models import Declaration
|
|
27
|
+
from lean_explore.util import OpenRouterClient
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# --- Data Classes ---
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class InformalizationResult:
    """Result of processing a single declaration.

    Carries only plain values (no ORM references), so instances can be
    created inside concurrent tasks and consumed later without touching a
    SQLAlchemy session.
    """

    # Primary key of the Declaration row this result refers to.
    declaration_id: int
    # Fully-qualified declaration name.
    declaration_name: str
    # Newly generated (or cache-served) informal description. None means
    # "nothing to write": the declaration was already informalized or the
    # LLM returned an empty response.
    informalization: str | None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class DeclarationData:
    """Plain data extracted from Declaration ORM object for async processing.

    Snapshotting these fields up front avoids lazy-loading attributes on an
    ORM instance from multiple concurrent tasks.
    """

    # Primary key of the source Declaration row.
    id: int
    # Fully-qualified declaration name.
    name: str
    # Lean source text of the declaration.
    source_text: str
    # Docstring from the Lean source, if any.
    docstring: str | None
    # Dependencies as stored in the database (JSON-encoded string) — see
    # _parse_dependencies for decoding.
    dependencies: str | None
    # Existing informal description, if one is already present.
    informalization: str | None
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# --- Utility Functions ---
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _parse_dependencies(dependencies: str | list[str] | None) -> list[str]:
|
|
60
|
+
"""Parse dependencies field which may be JSON string or list.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
dependencies: Dependencies as JSON string, list, or None
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
List of dependency names
|
|
67
|
+
"""
|
|
68
|
+
if not dependencies:
|
|
69
|
+
return []
|
|
70
|
+
if isinstance(dependencies, str):
|
|
71
|
+
return json.loads(dependencies)
|
|
72
|
+
return dependencies
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _build_dependency_layers(
    declarations: list[Declaration],
) -> list[list[Declaration]]:
    """Partition declarations into topological layers (Kahn's algorithm).

    Layer 0 has no in-set dependencies; each subsequent layer depends only
    on earlier layers. Declarations stuck on dependency cycles are emitted
    together as one final layer (their relative order is arbitrary) after a
    warning is logged. Dependencies outside the given set are ignored.
    """
    by_name = {decl.name: decl for decl in declarations}

    dependents = defaultdict(list)  # dependency name -> names that need it
    pending = defaultdict(int)  # name -> count of unresolved in-set deps

    for decl in declarations:
        pending[decl.name] = 0

    for decl in declarations:
        for dep_name in _parse_dependencies(decl.dependencies):
            if dep_name in by_name:
                dependents[dep_name].append(decl.name)
                pending[decl.name] += 1

    layers = []
    frontier = [by_name[name] for name in pending if pending[name] == 0]

    while frontier:
        layers.append(frontier)
        upcoming = []
        for decl in frontier:
            for dependent_name in dependents[decl.name]:
                pending[dependent_name] -= 1
                if pending[dependent_name] == 0:
                    upcoming.append(by_name[dependent_name])
        frontier = upcoming

    # Any name still pending sits on a cycle; lump all such declarations
    # into one last layer so they are still processed.
    stuck = [by_name[name] for name in pending if pending[name] > 0]
    if stuck:
        logger.warning(
            f"Found {len(stuck)} declarations in cycles, adding as final layer"
        )
        layers.append(stuck)

    return layers
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# --- Database Loading ---
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
async def _load_existing_informalizations(
    session: AsyncSession,
) -> list[InformalizationResult]:
    """Fetch every declaration that already has an informalization.

    Args:
        session: Async session bound to the search database.

    Returns:
        One InformalizationResult per already-informalized declaration.
    """
    logger.info("Loading existing informalizations...")
    query = select(Declaration).where(Declaration.informalization.isnot(None))
    rows = (await session.execute(query)).scalars().all()
    loaded = [
        InformalizationResult(
            declaration_id=row.id,
            declaration_name=row.name,
            informalization=row.informalization,
        )
        for row in rows
    ]
    logger.info(f"Loaded {len(loaded)} existing informalizations")
    return loaded
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
async def _get_declarations_to_process(
    session: AsyncSession, limit: int | None
) -> list[Declaration]:
    """Return declarations whose informalization is still missing.

    Args:
        session: Async session bound to the search database.
        limit: Cap on rows returned; None (or 0) means no cap.
    """
    query = select(Declaration).where(Declaration.informalization.is_(None))
    if limit:
        query = query.limit(limit)
    rows = await session.execute(query)
    return list(rows.scalars().all())
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
# --- Cross-Database Cache Loading ---
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _discover_database_files() -> list[Path]:
    """Discover all lean_explore.db files in data/ and cache/ directories.

    Returns:
        Paths of every discovered database file (data directory first,
        then cache directory).
    """
    found: list[Path] = []
    # Either configured location may legitimately be absent (fresh install).
    for root in (Config.DATA_DIRECTORY, Config.CACHE_DIRECTORY):
        if root.exists():
            found.extend(root.rglob("lean_explore.db"))
    logger.info(f"Discovered {len(found)} database files")
    return found
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
async def _load_cache_from_databases(
    database_files: list[Path],
) -> dict[tuple[str, str], str]:
    """Load informalizations from all discovered databases.

    Builds a cache mapping (name, source_text) -> informalization by scanning
    all databases for declarations that have informalizations. Loading is
    best-effort: a database that cannot be read is skipped with a warning.
    On duplicate keys, the first database scanned wins.

    Args:
        database_files: List of database file paths to scan

    Returns:
        Dictionary mapping (name, source_text) -> informalization
    """
    cache: dict[tuple[str, str], str] = {}

    for db_path in database_files:
        db_url = f"sqlite+aiosqlite:///{db_path}"
        logger.info(f"Loading cache from {db_path}")

        engine = None
        try:
            engine = create_async_engine(db_url)
            async with AsyncSession(engine) as session:
                stmt = select(Declaration).where(
                    Declaration.informalization.isnot(None)
                )
                result = await session.execute(stmt)
                declarations = result.scalars().all()

                for declaration in declarations:
                    cache_key = (declaration.name, declaration.source_text)
                    if cache_key not in cache:
                        cache[cache_key] = declaration.informalization

                logger.info(
                    f"Loaded {len(declarations)} informalizations from {db_path}"
                )
        except Exception as e:
            logger.warning(f"Failed to load cache from {db_path}: {e}")
            continue
        finally:
            # Bug fix: the original only disposed the engine on success,
            # leaking the connection pool whenever a database failed to
            # load. Dispose unconditionally once an engine was created.
            if engine is not None:
                await engine.dispose()

    logger.info(f"Total cache size: {len(cache)} unique (name, source_text) pairs")
    return cache
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
# --- Processing Functions ---
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
async def _process_one_declaration(
    *,
    declaration_data: DeclarationData,
    client: OpenRouterClient,
    model: str,
    prompt_template: str,
    informalizations_by_name: dict[str, str],
    cache: dict[tuple[str, str], str],
    semaphore: asyncio.Semaphore,
) -> InformalizationResult:
    """Process a single declaration and generate its informalization.

    Resolution order: skip if already informalized, serve from the
    cross-database cache if possible, otherwise build a prompt (including
    up to 20 dependency summaries) and call the LLM.

    Args:
        declaration_data: Plain data extracted from Declaration ORM object
        client: OpenRouter client
        model: Model name to use
        prompt_template: Prompt template string
        informalizations_by_name: Map of declaration names to informalizations
        cache: Map of (name, source_text) to cached informalizations
        semaphore: Concurrency control semaphore

    Returns:
        InformalizationResult with declaration info and generated informalization
    """
    # Already informalized: informalization=None tells the caller no
    # database update is needed.
    if declaration_data.informalization is not None:
        return InformalizationResult(
            declaration_id=declaration_data.id,
            declaration_name=declaration_data.name,
            informalization=None,
        )

    # Check cross-database cache first
    cache_key = (declaration_data.name, declaration_data.source_text)
    if cache_key in cache:
        return InformalizationResult(
            declaration_id=declaration_data.id,
            declaration_name=declaration_data.name,
            informalization=cache[cache_key],
        )

    # The semaphore bounds concurrent LLM requests; prompt assembly is
    # cheap, so holding the slot for it as well is acceptable.
    async with semaphore:
        dependencies_text = ""
        dependencies = _parse_dependencies(declaration_data.dependencies)
        if dependencies:
            dependency_informalizations = []
            # Limit to first 20 dependencies (keeps the prompt bounded).
            for dependency_name in dependencies[:20]:
                # Only dependencies already informalized in this run (or
                # preloaded) can contribute context.
                if dependency_name in informalizations_by_name:
                    informal_description = informalizations_by_name[dependency_name]
                    # Truncate description to 256 characters
                    if len(informal_description) > 256:
                        informal_description = informal_description[:253] + "..."
                    dependency_informalizations.append(
                        f"- {dependency_name}: {informal_description}"
                    )

            if dependency_informalizations:
                dependencies_text = "Dependencies:\n" + "\n".join(
                    dependency_informalizations
                )

        # Template placeholders: {name}, {source_text}, {docstring},
        # {dependencies} — supplied by prompt.txt.
        prompt = prompt_template.format(
            name=declaration_data.name,
            source_text=declaration_data.source_text,
            docstring=declaration_data.docstring or "No docstring available",
            dependencies=dependencies_text,
        )

        response = await client.generate(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
        )

        if response.choices and response.choices[0].message.content:
            result = response.choices[0].message.content.strip()
            return InformalizationResult(
                declaration_id=declaration_data.id,
                declaration_name=declaration_data.name,
                informalization=result,
            )

        # Empty/blank completion: record nothing so a later run retries.
        logger.warning(f"Empty response for declaration {declaration_data.name}")
        return InformalizationResult(
            declaration_id=declaration_data.id,
            declaration_name=declaration_data.name,
            informalization=None,
        )
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
async def _process_layer(
    *,
    session: AsyncSession,
    layer: list[Declaration],
    client: OpenRouterClient,
    model: str,
    prompt_template: str,
    informalizations_by_name: dict[str, str],
    cache: dict[tuple[str, str], str],
    semaphore: asyncio.Semaphore,
    progress,
    total_task,
    batch_task,
    commit_batch_size: int,
) -> int:
    """Process a single dependency layer.

    Fans out one task per declaration (bounded by the semaphore), collects
    results as they complete, and commits updates to the database in
    batches of ``commit_batch_size``.

    Args:
        session: Async database session for search database
        layer: List of declarations in this layer
        client: OpenRouter client
        model: Model name to use
        prompt_template: Prompt template string
        informalizations_by_name: Map of declaration names to informalizations
            (mutated in place as new informalizations are produced)
        cache: Map of (name, source_text) to cached informalizations
        semaphore: Concurrency control semaphore
        progress: Rich progress bar
        total_task: Progress task ID for total progress
        batch_task: Progress task ID for batch progress
        commit_batch_size: Number of updates to batch before committing

    Returns:
        Number of declarations processed in this layer
    """
    processed = 0
    pending_updates = []

    # Extract data from ORM objects before creating async tasks
    # This avoids SQLAlchemy session issues with concurrent access
    declaration_data_list = [
        DeclarationData(
            id=d.id,
            name=d.name,
            source_text=d.source_text,
            docstring=d.docstring,
            dependencies=d.dependencies,
            informalization=d.informalization,
        )
        for d in layer
    ]

    # Create tasks for all declarations in this layer
    tasks = [
        asyncio.create_task(
            _process_one_declaration(
                declaration_data=data,
                client=client,
                model=model,
                prompt_template=prompt_template,
                informalizations_by_name=informalizations_by_name,
                cache=cache,
                semaphore=semaphore,
            )
        )
        for data in declaration_data_list
    ]

    # Process results as they complete
    for coro in asyncio.as_completed(tasks):
        result = await coro

        if result.informalization:
            pending_updates.append(
                {
                    "id": result.declaration_id,
                    "informalization": result.informalization,
                }
            )
            # Make the new description visible to later layers' prompts.
            informalizations_by_name[result.declaration_name] = result.informalization
            processed += 1

        progress.update(total_task, advance=1)
        progress.update(batch_task, advance=1)

        if len(pending_updates) >= commit_batch_size:
            # Bulk UPDATE by primary key via executemany-style parameter list.
            await session.execute(update(Declaration), pending_updates)
            await session.commit()
            logger.info(f"Committed batch of {len(pending_updates)} updates")
            pending_updates.clear()
            progress.reset(batch_task)

    # Flush any remainder smaller than a full batch.
    if pending_updates:
        await session.execute(update(Declaration), pending_updates)
        await session.commit()
        logger.info(f"Committed batch of {len(pending_updates)} updates")
        progress.reset(batch_task)

    return processed
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
async def _process_layers(
    *,
    session: AsyncSession,
    layers: list[list[Declaration]],
    client: OpenRouterClient,
    model: str,
    prompt_template: str,
    existing_informalizations: list[InformalizationResult],
    cache: dict[tuple[str, str], str],
    semaphore: asyncio.Semaphore,
    commit_batch_size: int,
) -> int:
    """Process declarations layer by layer with progress tracking.

    Layers are processed in order so each declaration's dependencies are
    informalized before the declaration itself (modulo cycles).

    Args:
        session: Async database session for search database
        layers: List of dependency layers to process
        client: OpenRouter client
        model: Model name to use
        prompt_template: Prompt template string
        existing_informalizations: List of existing informalizations
        cache: Map of (name, source_text) to cached informalizations
        semaphore: Concurrency control semaphore
        commit_batch_size: Number of updates to batch before committing to database

    Returns:
        Number of declarations processed
    """
    total = sum(len(layer) for layer in layers)
    processed = 0

    # Seed the name -> description map from what is already in the
    # database so early layers can reference pre-existing informalizations.
    informalizations_by_name = {
        inf.declaration_name: inf.informalization
        for inf in existing_informalizations
        if inf.informalization is not None
    }

    # Two bars: overall progress, and progress within the current commit
    # batch (reset on each commit by _process_layer).
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TimeRemainingColumn(),
    ) as progress:
        total_task = progress.add_task(f"[cyan]Total ({total:,})", total=total)
        batch_task = progress.add_task(
            f"[green]Batch ({commit_batch_size:,})", total=commit_batch_size
        )

        for layer_num, layer in enumerate(layers):
            logger.info(
                f"Processing layer {layer_num + 1}/{len(layers)} "
                f"({len(layer)} declarations)"
            )
            layer_processed = await _process_layer(
                session=session,
                layer=layer,
                client=client,
                model=model,
                prompt_template=prompt_template,
                informalizations_by_name=informalizations_by_name,
                cache=cache,
                semaphore=semaphore,
                progress=progress,
                total_task=total_task,
                batch_task=batch_task,
                commit_batch_size=commit_batch_size,
            )
            processed += layer_processed
            logger.info(
                f"Completed layer {layer_num + 1}: "
                f"{layer_processed}/{len(layer)} declarations informalized"
            )

    return processed
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
# --- Public API ---
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
async def _apply_cache_to_declarations(
    session: AsyncSession,
    declarations: list[Declaration],
    cache: dict[tuple[str, str], str],
    commit_batch_size: int = 1000,
) -> tuple[int, list[Declaration]]:
    """Apply cached informalizations to declarations.

    This is a fast first pass that applies all cache hits before making any
    API calls, allowing the user to see exactly how many API calls will be needed.

    Args:
        session: Async database session
        declarations: List of declarations to check against cache
        cache: Map of (name, source_text) to cached informalizations
        commit_batch_size: Number of updates to batch before committing

    Returns:
        Tuple of (cache_hits_count, list of declarations still needing API calls)
    """
    # Imported locally; the module otherwise only uses the ORM query API.
    from sqlalchemy import text

    # Phase 1: Match all declarations against cache in memory
    updates_to_apply: list[tuple[int, str]] = []
    remaining: list[Declaration] = []

    for declaration in declarations:
        cache_key = (declaration.name, declaration.source_text)
        if cache_key in cache:
            updates_to_apply.append((declaration.id, cache[cache_key]))
        else:
            remaining.append(declaration)

    logger.info(
        f"Cache matching complete: {len(updates_to_apply)} hits, "
        f"{len(remaining)} misses"
    )

    # Phase 2: Apply updates in batches using raw SQL for efficiency
    if updates_to_apply:
        num_updates = len(updates_to_apply)
        # Ceiling division: the last batch may be partial.
        total_batches = (num_updates + commit_batch_size - 1) // commit_batch_size
        logger.info(
            f"Applying {len(updates_to_apply)} cached informalizations "
            f"in {total_batches} batches..."
        )
        # NOTE: table name "declarations" is hard-coded here; it must match
        # the Declaration model's table.
        stmt = text("UPDATE declarations SET informalization = :inf WHERE id = :id")
        for i in range(0, len(updates_to_apply), commit_batch_size):
            batch = updates_to_apply[i : i + commit_batch_size]
            params = [{"id": decl_id, "inf": inf} for decl_id, inf in batch]
            conn = await session.connection()
            # executemany: one statement, one parameter dict per row.
            await conn.execute(stmt, params)
            await session.commit()
            batch_num = i // commit_batch_size + 1
            # Log every 10th batch (and the last) to keep output manageable.
            if batch_num % 10 == 0 or batch_num == total_batches:
                logger.info(f"Committed batch {batch_num}/{total_batches}")

    return len(updates_to_apply), remaining
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
async def informalize_declarations(
    search_db_engine: AsyncEngine,
    *,
    model: str = "google/gemini-3-flash-preview",
    commit_batch_size: int = 1000,
    max_concurrent: int = 100,
    limit: int | None = None,
) -> None:
    """Generate informalizations for declarations missing them.

    Two-phase pipeline: (1) serve as many declarations as possible from a
    cache built out of every previously extracted database, then (2) call
    the LLM for the remainder, processing them in dependency order so each
    prompt can include its dependencies' descriptions.

    Args:
        search_db_engine: Async database engine for search database (Declaration table)
        model: LLM model to use for generation
        commit_batch_size: Number of updates to batch before committing to database
        max_concurrent: Maximum number of concurrent LLM API calls
        limit: Maximum number of declarations to process (None for all)
    """
    # Prompt template ships alongside this module.
    prompt_template = (Path(__file__).parent / "prompt.txt").read_text()
    logger.info("Starting informalization process...")
    logger.info(
        f"Model: {model}, Max concurrent: {max_concurrent}, "
        f"Commit batch size: {commit_batch_size}"
    )

    # Discover and load cache from all existing databases
    logger.info("Discovering existing databases for cache...")
    database_files = _discover_database_files()
    cache = await _load_cache_from_databases(database_files)

    # expire_on_commit=False keeps ORM attributes readable after the
    # intermediate commits performed by the helpers below.
    async with AsyncSession(search_db_engine, expire_on_commit=False) as search_session:
        existing_informalizations = await _load_existing_informalizations(
            search_session
        )
        declarations = await _get_declarations_to_process(search_session, limit)

        logger.info(f"Found {len(declarations)} declarations needing informalization")
        if not declarations:
            logger.info("No declarations to process")
            return

        # Phase 1: Apply all cache hits first
        logger.info("Phase 1: Applying cached informalizations...")
        cache_hits, remaining_declarations = await _apply_cache_to_declarations(
            search_session, declarations, cache, commit_batch_size
        )
        logger.info(
            f"Applied {cache_hits} informalizations from cache, "
            f"{len(remaining_declarations)} remaining need API calls"
        )

        if not remaining_declarations:
            logger.info("All declarations served from cache, no API calls needed")
            return

        # Phase 2: Process remaining declarations with API calls
        logger.info("Phase 2: Making API calls for remaining declarations...")
        client = OpenRouterClient()
        semaphore = asyncio.Semaphore(max_concurrent)

        # Reload existing informalizations (now includes cache hits)
        existing_informalizations = await _load_existing_informalizations(
            search_session
        )

        logger.info("Building dependency layers for remaining declarations...")
        layers = _build_dependency_layers(remaining_declarations)
        logger.info(f"Built {len(layers)} dependency layers")

        processed = await _process_layers(
            session=search_session,
            layers=layers,
            client=client,
            model=model,
            prompt_template=prompt_template,
            existing_informalizations=existing_informalizations,
            cache=cache,
            semaphore=semaphore,
            commit_batch_size=commit_batch_size,
        )

        logger.info(
            f"Informalization complete. Processed {processed}/"
            f"{len(remaining_declarations)} remaining declarations via API"
        )
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""Package configuration for Lean extraction.
|
|
2
|
+
|
|
3
|
+
This module defines the configuration dataclass and version strategy enum
|
|
4
|
+
for Lean packages to extract.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class VersionStrategy(Enum):
    """Strategy for selecting which version of a package to extract.

    The enum values are the strings used in configuration.
    """

    LATEST = "latest"
    """Use HEAD/main branch - for packages with CI that ensures main compiles."""

    TAGGED = "tagged"
    """Use the latest git tag - safer for downstream packages."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class PackageConfig:
    """Configuration for a Lean package extraction.

    Describes where the package lives, which modules belong to it, and how
    its version and toolchain are chosen during extraction.
    """

    name: str
    """Package name (e.g., 'mathlib', 'physlean')."""

    git_url: str
    """GitHub repository URL."""

    module_prefixes: list[str]
    """Module name prefixes that belong to this package (e.g., ['Mathlib'])."""

    version_strategy: VersionStrategy = VersionStrategy.TAGGED
    """Strategy for selecting the version to extract."""

    lean_toolchain: str | None = None
    """Override Lean toolchain version. If None, determined from package."""

    depends_on: list[str] = field(default_factory=list)
    """List of package names this package depends on (for extraction ordering)."""

    extract_core: bool = False
    """If True, also extract Init/Lean/Std modules from this package's toolchain."""

    def workspace_path(self, base_path: Path) -> Path:
        """Get the workspace path for this package.

        Args:
            base_path: Root directory containing all package workspaces.

        Returns:
            ``base_path / name`` — one subdirectory per package.
        """
        return base_path / self.name

    def should_include_module(self, module_name: str) -> bool:
        """Check if a module belongs to this package based on prefixes.

        Uses exact match or prefix + "." to avoid "Lean" matching "LeanSearchClient".

        Args:
            module_name: Dotted Lean module name (e.g., 'Mathlib.Data.List').

        Returns:
            True if the module matches any configured prefix.
        """
        return any(
            module_name == prefix or module_name.startswith(prefix + ".")
            for prefix in self.module_prefixes
        )
|