lean-explore 0.3.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. lean_explore/__init__.py +14 -1
  2. lean_explore/api/__init__.py +12 -1
  3. lean_explore/api/client.py +64 -176
  4. lean_explore/cli/__init__.py +10 -1
  5. lean_explore/cli/data_commands.py +184 -489
  6. lean_explore/cli/display.py +171 -0
  7. lean_explore/cli/main.py +51 -608
  8. lean_explore/config.py +244 -0
  9. lean_explore/extract/__init__.py +5 -0
  10. lean_explore/extract/__main__.py +368 -0
  11. lean_explore/extract/doc_gen4.py +200 -0
  12. lean_explore/extract/doc_parser.py +499 -0
  13. lean_explore/extract/embeddings.py +369 -0
  14. lean_explore/extract/github.py +110 -0
  15. lean_explore/extract/index.py +316 -0
  16. lean_explore/extract/informalize.py +653 -0
  17. lean_explore/extract/package_config.py +59 -0
  18. lean_explore/extract/package_registry.py +45 -0
  19. lean_explore/extract/package_utils.py +105 -0
  20. lean_explore/extract/types.py +25 -0
  21. lean_explore/mcp/__init__.py +11 -1
  22. lean_explore/mcp/app.py +14 -46
  23. lean_explore/mcp/server.py +20 -35
  24. lean_explore/mcp/tools.py +71 -205
  25. lean_explore/models/__init__.py +9 -0
  26. lean_explore/models/search_db.py +76 -0
  27. lean_explore/models/search_types.py +53 -0
  28. lean_explore/search/__init__.py +32 -0
  29. lean_explore/search/engine.py +651 -0
  30. lean_explore/search/scoring.py +156 -0
  31. lean_explore/search/service.py +68 -0
  32. lean_explore/search/tokenization.py +71 -0
  33. lean_explore/util/__init__.py +28 -0
  34. lean_explore/util/embedding_client.py +92 -0
  35. lean_explore/util/logging.py +22 -0
  36. lean_explore/util/openrouter_client.py +63 -0
  37. lean_explore/util/reranker_client.py +187 -0
  38. {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/METADATA +32 -9
  39. lean_explore-1.0.1.dist-info/RECORD +43 -0
  40. {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/WHEEL +1 -1
  41. lean_explore-1.0.1.dist-info/entry_points.txt +2 -0
  42. lean_explore/cli/agent.py +0 -788
  43. lean_explore/cli/config_utils.py +0 -481
  44. lean_explore/defaults.py +0 -114
  45. lean_explore/local/__init__.py +0 -1
  46. lean_explore/local/search.py +0 -1050
  47. lean_explore/local/service.py +0 -479
  48. lean_explore/shared/__init__.py +0 -1
  49. lean_explore/shared/models/__init__.py +0 -1
  50. lean_explore/shared/models/api.py +0 -117
  51. lean_explore/shared/models/db.py +0 -396
  52. lean_explore-0.3.0.dist-info/RECORD +0 -26
  53. lean_explore-0.3.0.dist-info/entry_points.txt +0 -2
  54. {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/licenses/LICENSE +0 -0
  55. {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,651 @@
1
+ """Core search engine for Lean declarations.
2
+
3
+ This module provides the SearchEngine class that implements hybrid search using
4
+ BM25 lexical matching and FAISS semantic search, combined via Reciprocal Rank
5
+ Fusion (RRF) and cross-encoder reranking.
6
+
7
+ Note: On macOS, torch and FAISS have OpenMP library conflicts. To avoid segfaults:
8
+ - FAISS is imported lazily (not at module level)
9
+ - When semantic search is needed, torch/embeddings are loaded FIRST, then FAISS
10
+ """
11
+
12
+ import json
13
+ import logging
14
+ from pathlib import Path
15
+ from typing import TYPE_CHECKING
16
+
17
+ import bm25s
18
+ import numpy as np
19
+ from sqlalchemy import select
20
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
21
+
22
+ from lean_explore.config import Config
23
+ from lean_explore.models import Declaration, SearchResult
24
+ from lean_explore.search.scoring import (
25
+ fuzzy_name_score,
26
+ normalize_dependency_counts,
27
+ normalize_scores,
28
+ )
29
+ from lean_explore.search.tokenization import (
30
+ is_autogenerated,
31
+ tokenize_raw,
32
+ tokenize_spaced,
33
+ tokenize_words,
34
+ )
35
+
36
+ if TYPE_CHECKING:
37
+ import faiss
38
+
39
+ from lean_explore.util import EmbeddingClient, RerankerClient
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
class SearchEngine:
    """Core search engine for Lean declarations.

    Uses two-stage retrieval:
    1. FAISS semantic search on informalizations
    2. BM25 lexical search on declaration names (independent)
    Then merges the two candidate lists via Reciprocal Rank Fusion and
    optionally reranks the top candidates with a cross-encoder.
    """
52
+
53
    def __init__(
        self,
        db_url: str | None = None,
        embedding_client: "EmbeddingClient | None" = None,
        embedding_model_name: str = "Qwen/Qwen3-Embedding-0.6B",
        reranker_client: "RerankerClient | None" = None,
        reranker_model_name: str = "Qwen/Qwen3-Reranker-0.6B",
        faiss_index_path: Path | None = None,
        faiss_ids_map_path: Path | None = None,
        use_local_data: bool = True,
    ):
        """Initialize the search engine.

        Args:
            db_url: Database URL. Defaults to configured URL.
            embedding_client: Client for generating embeddings. Created lazily if None.
            embedding_model_name: Name of the embedding model to use.
            reranker_client: Client for reranking results. Created lazily if None.
            reranker_model_name: Name of the reranker model to use.
            faiss_index_path: Path to FAISS index. Defaults to config path.
            faiss_ids_map_path: Path to FAISS ID mapping. Defaults to config path.
            use_local_data: If True, use DATA_DIRECTORY paths. If False, use
                CACHE_DIRECTORY paths (for downloaded remote data).

        Raises:
            FileNotFoundError: If any required data file is missing
                (via _validate_paths).
        """
        # None means "create lazily on first use" — see the embedding_client /
        # reranker_client properties, which defer the torch import.
        self._embedding_client = embedding_client
        self._embedding_model_name = embedding_model_name
        self._reranker_client = reranker_client
        self._reranker_model_name = reranker_model_name

        # Locally-extracted data and downloaded (cached) data live under
        # different roots and use different database URLs.
        if use_local_data:
            base_path = Config.ACTIVE_DATA_PATH
            default_db_url = Config.EXTRACTION_DATABASE_URL
        else:
            base_path = Config.ACTIVE_CACHE_PATH
            default_db_url = Config.DATABASE_URL

        self.db_url = db_url or default_db_url
        self.engine: AsyncEngine = create_async_engine(self.db_url)

        # FAISS artifacts; the index itself is loaded lazily (see module note
        # about torch/FAISS OpenMP import-order constraints on macOS).
        self._faiss_informal_path = faiss_index_path or (
            base_path / "informalization_faiss.index"
        )
        self._faiss_informal_ids_path = faiss_ids_map_path or (
            base_path / "informalization_faiss_ids_map.json"
        )
        self._faiss_informal_index: faiss.Index | None = None
        self._faiss_informal_id_map: list[int] | None = None

        # BM25 artifacts, also loaded lazily on first query.
        self._bm25_spaced_path = base_path / "bm25_name_spaced"
        self._bm25_raw_path = base_path / "bm25_name_raw"
        self._bm25_ids_map_path = base_path / "bm25_ids_map.json"
        self._all_declaration_ids: list[int] | None = None
        self._bm25_name_spaced: bm25s.BM25 | None = None
        self._bm25_name_raw: bm25s.BM25 | None = None

        # Fail fast if the data files have not been fetched yet.
        self._validate_paths()
109
+
110
+ def _validate_paths(self) -> None:
111
+ """Validate that required data files exist."""
112
+ required_paths = [
113
+ self._faiss_informal_path,
114
+ self._faiss_informal_ids_path,
115
+ self._bm25_spaced_path,
116
+ self._bm25_raw_path,
117
+ self._bm25_ids_map_path,
118
+ ]
119
+ for path in required_paths:
120
+ if not path.exists():
121
+ raise FileNotFoundError(
122
+ f"Required file not found at {path}. "
123
+ "Please run 'lean-explore data fetch' to download the data."
124
+ )
125
+
126
+ @property
127
+ def embedding_client(self) -> "EmbeddingClient":
128
+ """Lazily create the embedding client to avoid loading torch at import time."""
129
+ if self._embedding_client is None:
130
+ from lean_explore.util import EmbeddingClient
131
+
132
+ self._embedding_client = EmbeddingClient(
133
+ model_name=self._embedding_model_name,
134
+ max_length=512,
135
+ )
136
+ return self._embedding_client
137
+
138
+ @property
139
+ def reranker_client(self) -> "RerankerClient":
140
+ """Lazily create the reranker client to avoid loading torch at import time."""
141
+ if self._reranker_client is None:
142
+ from lean_explore.util import RerankerClient
143
+
144
+ self._reranker_client = RerankerClient(
145
+ model_name=self._reranker_model_name,
146
+ max_length=256,
147
+ )
148
+ return self._reranker_client
149
+
150
+ def _ensure_faiss_loaded(self) -> None:
151
+ """Load the FAISS index if not already loaded."""
152
+ if self._faiss_informal_index is not None:
153
+ return
154
+
155
+ import faiss
156
+
157
+ logger.info(f"Loading FAISS index from {self._faiss_informal_path}")
158
+ self._faiss_informal_index = faiss.read_index(str(self._faiss_informal_path))
159
+ with open(self._faiss_informal_ids_path) as f:
160
+ self._faiss_informal_id_map = json.load(f)
161
+
162
    @property
    def faiss_informal_index(self) -> "faiss.Index":
        """Get the informalization FAISS index, loading it on first access."""
        self._ensure_faiss_loaded()
        return self._faiss_informal_index  # type: ignore[return-value]

    @property
    def faiss_informal_id_map(self) -> list[int]:
        """Get the FAISS row -> declaration ID map, loading it on first access."""
        self._ensure_faiss_loaded()
        return self._faiss_informal_id_map  # type: ignore[return-value]
173
+
174
+ def _ensure_bm25_loaded(self) -> None:
175
+ """Load pre-built BM25 indices from disk."""
176
+ if self._bm25_name_spaced is not None:
177
+ return
178
+
179
+ logger.info(f"Loading BM25 indices from {self._bm25_spaced_path.parent}")
180
+
181
+ self._bm25_name_spaced = bm25s.BM25.load(str(self._bm25_spaced_path))
182
+ self._bm25_name_raw = bm25s.BM25.load(str(self._bm25_raw_path))
183
+
184
+ with open(self._bm25_ids_map_path) as f:
185
+ self._all_declaration_ids = json.load(f)
186
+
187
+ logger.info(f"BM25 indices loaded ({len(self._all_declaration_ids)} decls)")
188
+
189
+ def _retrieve_bm25_candidates(self, query: str, bm25_k: int) -> dict[int, float]:
190
+ """Retrieve candidates using BM25 on declaration names.
191
+
192
+ Args:
193
+ query: Search query string.
194
+ bm25_k: Number of candidates to retrieve.
195
+
196
+ Returns:
197
+ Map of declaration ID to BM25 score.
198
+ """
199
+ self._ensure_bm25_loaded()
200
+
201
+ query_tokens_spaced = tokenize_spaced(query)
202
+ query_tokens_raw = tokenize_raw(query)
203
+
204
+ results_spaced, scores_spaced = self._bm25_name_spaced.retrieve(
205
+ [query_tokens_spaced], k=bm25_k
206
+ )
207
+ results_raw, scores_raw = self._bm25_name_raw.retrieve(
208
+ [query_tokens_raw], k=bm25_k
209
+ )
210
+
211
+ bm25_map: dict[int, float] = {}
212
+ for idx, score in zip(results_spaced[0], scores_spaced[0]):
213
+ decl_id = self._all_declaration_ids[idx]
214
+ bm25_map[decl_id] = max(bm25_map.get(decl_id, 0.0), float(score))
215
+ for idx, score in zip(results_raw[0], scores_raw[0]):
216
+ decl_id = self._all_declaration_ids[idx]
217
+ bm25_map[decl_id] = max(bm25_map.get(decl_id, 0.0), float(score))
218
+
219
+ logger.info(f"BM25 name: {len(bm25_map)} candidates")
220
+ return bm25_map
221
+
222
    async def _retrieve_semantic_candidates(
        self, query: str, faiss_k: int
    ) -> dict[int, float]:
        """Retrieve candidates using semantic search on informalizations.

        Args:
            query: Search query string.
            faiss_k: Number of candidates to retrieve from FAISS.

        Returns:
            Map of declaration ID to semantic similarity score.
        """
        # IMPORTANT: embed first so torch is loaded BEFORE faiss is imported;
        # see the module docstring about OpenMP conflicts on macOS.
        embedding_response = await self.embedding_client.embed([query], is_query=True)
        query_embedding = np.array([embedding_response.embeddings[0]], dtype=np.float32)

        import faiss as faiss_module

        # L2-normalize the query vector; with an inner-product index this
        # makes the returned distances behave like cosine similarity
        # (assumes the index was built with that metric — confirm at build).
        faiss_module.normalize_L2(query_embedding)

        informal_index = self.faiss_informal_index
        informal_id_map = self.faiss_informal_id_map

        # IVF-style indices expose nprobe; widen it to trade speed for recall.
        if hasattr(informal_index, "nprobe"):
            informal_index.nprobe = 64

        distances, indices = informal_index.search(query_embedding, faiss_k)

        semantic_map: dict[int, float] = {}
        for idx, dist in zip(indices[0], distances[0]):
            # FAISS pads short result lists with -1; also guard against an
            # index row that is out of range of the ID map.
            if idx == -1 or idx >= len(informal_id_map):
                continue
            decl_id = informal_id_map[idx]
            similarity = float(dist)
            semantic_map[decl_id] = max(semantic_map.get(decl_id, 0.0), similarity)

        logger.info(f"FAISS informal: {len(semantic_map)} candidates")
        return semantic_map
259
+
260
+ def _compute_rrf_scores(
261
+ self,
262
+ bm25_map: dict[int, float],
263
+ semantic_map: dict[int, float],
264
+ ) -> list[tuple[int, float]]:
265
+ """Compute RRF scores from BM25 and semantic retrieval signals.
266
+
267
+ Args:
268
+ bm25_map: Map of declaration ID to BM25 score.
269
+ semantic_map: Map of declaration ID to semantic similarity score.
270
+
271
+ Returns:
272
+ List of (declaration_id, rrf_score) sorted by score descending.
273
+ """
274
+ all_candidate_ids = set(bm25_map.keys()) | set(semantic_map.keys())
275
+ logger.info(f"Total merged candidates: {len(all_candidate_ids)}")
276
+
277
+ if not all_candidate_ids:
278
+ return []
279
+
280
+ bm25_sorted = sorted(bm25_map.items(), key=lambda x: x[1], reverse=True)
281
+ sem_sorted = sorted(semantic_map.items(), key=lambda x: x[1], reverse=True)
282
+
283
+ bm25_rank_map = {cid: rank + 1 for rank, (cid, _) in enumerate(bm25_sorted)}
284
+ sem_rank_map = {cid: rank + 1 for rank, (cid, _) in enumerate(sem_sorted)}
285
+
286
+ default_bm25_rank = len(bm25_sorted) + 1
287
+ default_sem_rank = len(sem_sorted) + 1
288
+
289
+ rrf_scores: list[tuple[int, float]] = []
290
+ for cid in all_candidate_ids:
291
+ name_rank = bm25_rank_map.get(cid, default_bm25_rank)
292
+ inf_rank = sem_rank_map.get(cid, default_sem_rank)
293
+ rrf_score = 1.0 / name_rank + 1.0 / inf_rank
294
+ rrf_scores.append((cid, rrf_score))
295
+
296
+ rrf_scores.sort(key=lambda x: x[1], reverse=True)
297
+ return rrf_scores
298
+
299
    async def _apply_dependency_boost(
        self,
        rrf_scores: list[tuple[int, float]],
        top_n: int = 500,
    ) -> tuple[list[tuple[int, float]], dict[int, Declaration]]:
        """Apply dependency-based boost to RRF scores.

        Declarations that are dependencies of other top candidates get a boost.

        Args:
            rrf_scores: List of (declaration_id, rrf_score) sorted by score.
            top_n: Number of top candidates to consider for dependency analysis.

        Returns:
            Tuple of (boosted_scores, declarations_map). Only the top_n
            candidates are scored; declarations_map may lack IDs not found
            in the database.
        """
        top_ids = [cid for cid, _ in rrf_scores[:top_n]]

        # Fetch ORM rows for all top candidates in a single query.
        async with AsyncSession(self.engine) as session:
            stmt = select(Declaration).where(Declaration.id.in_(top_ids))
            result = await session.execute(stmt)
            declarations_map = {d.id: d for d in result.scalars().all()}

        # Map candidate names back to IDs so dependency names can be resolved
        # to candidates within this pool.
        name_to_id = {
            declarations_map[cid].name: cid
            for cid in top_ids
            if cid in declarations_map
        }
        # dep_counts[cid] = number of other top candidates listing cid as a dep.
        dep_counts: dict[int, int] = {cid: 0 for cid in top_ids}

        for cid in top_ids:
            decl = declarations_map.get(cid)
            if decl and decl.dependencies:
                try:
                    # dependencies is a JSON-encoded list of declaration names;
                    # malformed JSON is skipped (best-effort signal only).
                    deps = json.loads(decl.dependencies)
                    for dep_name in deps:
                        if dep_name in name_to_id:
                            dep_counts[name_to_id[dep_name]] += 1
                except json.JSONDecodeError:
                    pass

        max_deps = max(dep_counts.values()) if dep_counts else 0
        boosted_scores: list[tuple[int, float]] = []

        # Re-fuse the original RRF rank with a dependency rank: more
        # depended-upon candidates get a smaller dep_rank, hence a larger
        # 1/dep_rank contribution.
        for rank, (cid, _) in enumerate(rrf_scores[:top_n], 1):
            dep_count = dep_counts.get(cid, 0)
            if max_deps > 0 and dep_count > 0:
                dep_rank = (max_deps - dep_count) + 1
            else:
                # No dependency signal for this candidate: assign a rank just
                # past the worst observed (or past top_n if none observed).
                dep_rank = max_deps + 1 if max_deps > 0 else top_n + 1

            boosted_score = 1.0 / rank + 1.0 / dep_rank
            boosted_scores.append((cid, boosted_score))

        boosted_scores.sort(key=lambda x: x[1], reverse=True)
        logger.info(f"Applied dependency boost to top {top_n} candidates")
        return boosted_scores, declarations_map
356
+
357
    async def _rerank_candidates(
        self,
        query: str,
        scored_results: list[tuple[Declaration, float]],
        limit: int,
    ) -> list[SearchResult]:
        """Apply cross-encoder reranking with additional signals.

        Combines four normalized signals: cross-encoder relevance, BM25 over
        informalizations, in-pool dependency counts, and a conditional
        fuzzy-name bonus.

        Args:
            query: Search query string.
            scored_results: List of (declaration, score) tuples.
            limit: Maximum number of results to return.

        Returns:
            List of SearchResult objects after reranking.
        """
        logger.info(f"Reranking top {len(scored_results)} candidates")

        # Rerank against "name: informalization"; fall back to the bare name
        # when no informalization exists.
        documents = [
            f"{decl.name}: {decl.informalization}"
            if decl.informalization
            else decl.name
            for decl, _ in scored_results
        ]

        rerank_response = await self.reranker_client.rerank(query, documents)
        reranker_scores = rerank_response.scores

        fuzzy_scores = [
            fuzzy_name_score(query, decl.name) for decl, _ in scored_results
        ]

        bm25_informal_scores = self._compute_bm25_on_informalizations(
            query, scored_results
        )

        dep_counts = self._compute_candidate_dependency_counts(scored_results)

        # Normalize each signal into a comparable range before weighting.
        norm_reranker = normalize_scores(reranker_scores)
        norm_fuzzy = normalize_scores(fuzzy_scores)
        norm_bm25 = normalize_scores(bm25_informal_scores)
        norm_dep = normalize_dependency_counts(dep_counts)

        # Weights (presumably hand-tuned — confirm): cross-encoder dominates;
        # the fuzzy-name bonus only applies on a strong raw match (>= 0.7) so
        # near-exact name queries surface the exact declaration.
        final_scores = []
        for i, (decl, _) in enumerate(scored_results):
            score = 1.0 * norm_reranker[i] + 0.4 * norm_bm25[i] + 0.2 * norm_dep[i]
            if fuzzy_scores[i] >= 0.7:
                score += 1.0 * norm_fuzzy[i]
            final_scores.append(score)

        combined = sorted(
            zip(scored_results, final_scores),
            key=lambda x: x[1],
            reverse=True,
        )

        return self._filter_and_convert_results(combined, limit)
414
+
415
+ def _compute_bm25_on_informalizations(
416
+ self,
417
+ query: str,
418
+ scored_results: list[tuple[Declaration, float]],
419
+ ) -> list[float]:
420
+ """Compute BM25 scores on informalizations for reranking.
421
+
422
+ Args:
423
+ query: Search query string.
424
+ scored_results: List of (declaration, score) tuples.
425
+
426
+ Returns:
427
+ List of BM25 scores for each candidate.
428
+ """
429
+ informalizations = [
430
+ decl.informalization if decl.informalization else decl.name
431
+ for decl, _ in scored_results
432
+ ]
433
+ informal_tokens = [tokenize_words(text) for text in informalizations]
434
+ query_tokens = tokenize_words(query)
435
+
436
+ bm25_informal = bm25s.BM25(method="bm25+")
437
+ bm25_informal.index(informal_tokens)
438
+ results, scores = bm25_informal.retrieve([query_tokens], k=len(informal_tokens))
439
+
440
+ bm25_scores = [0.0] * len(scored_results)
441
+ for idx, score in zip(results[0], scores[0]):
442
+ if int(idx) < len(bm25_scores):
443
+ bm25_scores[int(idx)] = float(score)
444
+
445
+ return bm25_scores
446
+
447
+ def _compute_candidate_dependency_counts(
448
+ self,
449
+ scored_results: list[tuple[Declaration, float]],
450
+ ) -> list[int]:
451
+ """Count how many candidates depend on each declaration.
452
+
453
+ Args:
454
+ scored_results: List of (declaration, score) tuples.
455
+
456
+ Returns:
457
+ List of dependency counts for each candidate.
458
+ """
459
+ candidate_names = {decl.name for decl, _ in scored_results}
460
+ dep_counts_map: dict[str, int] = {name: 0 for name in candidate_names}
461
+
462
+ for decl, _ in scored_results:
463
+ if decl.dependencies:
464
+ try:
465
+ deps = json.loads(decl.dependencies)
466
+ for dep_name in deps:
467
+ if dep_name in dep_counts_map:
468
+ dep_counts_map[dep_name] += 1
469
+ except json.JSONDecodeError:
470
+ pass
471
+
472
+ return [dep_counts_map.get(decl.name, 0) for decl, _ in scored_results]
473
+
474
+ def _filter_and_convert_results(
475
+ self,
476
+ combined: list[tuple[tuple[Declaration, float], float]],
477
+ limit: int,
478
+ ) -> list[SearchResult]:
479
+ """Filter auto-generated declarations and convert to SearchResult.
480
+
481
+ Args:
482
+ combined: List of ((declaration, old_score), final_score) tuples.
483
+ limit: Maximum number of results to return.
484
+
485
+ Returns:
486
+ List of SearchResult objects.
487
+ """
488
+ results = []
489
+ for (decl, _), _ in combined:
490
+ if not is_autogenerated(decl.name):
491
+ results.append(self._to_search_result(decl))
492
+ if len(results) >= limit:
493
+ break
494
+ return results
495
+
496
+ def _extract_package(self, module: str) -> str:
497
+ """Extract package name from module path.
498
+
499
+ Args:
500
+ module: Full module path (e.g., "Mathlib.Algebra.Group").
501
+
502
+ Returns:
503
+ Package name (first component of module path).
504
+ """
505
+ return module.split(".")[0] if module else ""
506
+
507
+ def _filter_by_packages(
508
+ self,
509
+ declarations_map: dict[int, Declaration],
510
+ packages: list[str],
511
+ ) -> dict[int, Declaration]:
512
+ """Filter declarations to only include specified packages.
513
+
514
+ Args:
515
+ declarations_map: Map of declaration ID to Declaration.
516
+ packages: List of package names to include.
517
+
518
+ Returns:
519
+ Filtered declarations map.
520
+ """
521
+ if not packages:
522
+ return declarations_map
523
+
524
+ package_set = set(packages)
525
+ return {
526
+ cid: decl
527
+ for cid, decl in declarations_map.items()
528
+ if self._extract_package(decl.module) in package_set
529
+ }
530
+
531
    async def search(
        self,
        query: str,
        limit: int = 50,
        faiss_k: int = 1000,
        bm25_k: int = 1000,
        rerank_top: int | None = 25,
        packages: list[str] | None = None,
    ) -> list[SearchResult]:
        """Search for Lean declarations using Reciprocal Rank Fusion.

        Two-signal approach:
        1. BM25+ on declaration names (lexical match)
        2. Semantic search on informalizations (meaning match)

        Combined via RRF: score = 1/name_rank + 1/informal_rank

        Optionally applies cross-encoder reranking to the top candidates.

        Args:
            query: Search query string.
            limit: Maximum number of results to return. Defaults to 50.
            faiss_k: Number of candidates from FAISS index. Defaults to 1000.
            bm25_k: Number of candidates from BM25 index. Defaults to 1000.
            rerank_top: If set, apply cross-encoder reranking to top N candidates.
                Set to 0 or None to skip reranking.
            packages: Optional list of package names to filter by. If provided,
                only declarations from these packages will be returned.

        Returns:
            List of SearchResult objects, ranked by combined score.
        """
        # Blank/whitespace-only queries short-circuit to no results.
        if not query.strip():
            return []

        # Gather candidates from both independent signals, then fuse by rank.
        bm25_map = self._retrieve_bm25_candidates(query, bm25_k)
        semantic_map = await self._retrieve_semantic_candidates(query, faiss_k)
        rrf_scores = self._compute_rrf_scores(bm25_map, semantic_map)

        if not rrf_scores:
            return []

        boosted_scores, declarations_map = await self._apply_dependency_boost(
            rrf_scores
        )

        # Apply package filtering if specified
        if packages:
            declarations_map = self._filter_by_packages(declarations_map, packages)
            # Filter boosted_scores to only include filtered declarations
            boosted_scores = [
                (cid, score) for cid, score in boosted_scores if cid in declarations_map
            ]
            logger.info(f"Filtered to {len(declarations_map)} in {packages}")

        # When reranking, take rerank_top candidates; otherwise take limit.
        top_n = rerank_top if rerank_top and rerank_top > 0 else limit

        # IDs absent from declarations_map (not in DB, or filtered out above)
        # are silently dropped here.
        scored_results: list[tuple[Declaration, float]] = [
            (declarations_map[cid], score)
            for cid, score in boosted_scores[:top_n]
            if cid in declarations_map
        ]

        if rerank_top and rerank_top > 0:
            return await self._rerank_candidates(query, scored_results, limit)

        # NOTE(review): without reranking, autogenerated names are filtered
        # AFTER truncating to `limit` (via top_n), so fewer than `limit`
        # results may be returned even when more candidates exist — confirm
        # this is intended.
        results = []
        for decl, _ in scored_results:
            if not is_autogenerated(decl.name):
                results.append(self._to_search_result(decl))
                if len(results) >= limit:
                    break
        return results
604
+
605
+ async def get_by_id(self, declaration_id: int) -> SearchResult | None:
606
+ """Retrieve a declaration by ID.
607
+
608
+ Args:
609
+ declaration_id: The declaration ID.
610
+
611
+ Returns:
612
+ SearchResult if found, None otherwise.
613
+ """
614
+ async with AsyncSession(self.engine) as session:
615
+ decl = await session.get(Declaration, declaration_id)
616
+ return self._to_search_result(decl) if decl else None
617
+
618
+ async def get_by_name(self, name: str) -> SearchResult | None:
619
+ """Retrieve a declaration by its exact name.
620
+
621
+ Args:
622
+ name: The exact declaration name (e.g., "AlgebraicGeometry.Scheme").
623
+
624
+ Returns:
625
+ SearchResult if found, None otherwise.
626
+ """
627
+ async with AsyncSession(self.engine) as session:
628
+ stmt = select(Declaration).where(Declaration.name == name)
629
+ result = await session.execute(stmt)
630
+ decl = result.scalar_one_or_none()
631
+ return self._to_search_result(decl) if decl else None
632
+
633
+ def _to_search_result(self, decl: Declaration) -> SearchResult:
634
+ """Convert Declaration ORM object to SearchResult.
635
+
636
+ Args:
637
+ decl: Declaration ORM object.
638
+
639
+ Returns:
640
+ SearchResult pydantic model.
641
+ """
642
+ return SearchResult(
643
+ id=decl.id,
644
+ name=decl.name,
645
+ module=decl.module,
646
+ docstring=decl.docstring,
647
+ source_text=decl.source_text,
648
+ source_link=decl.source_link,
649
+ dependencies=decl.dependencies,
650
+ informalization=decl.informalization,
651
+ )