lean-explore 0.3.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lean_explore/__init__.py +14 -1
- lean_explore/api/__init__.py +12 -1
- lean_explore/api/client.py +64 -176
- lean_explore/cli/__init__.py +10 -1
- lean_explore/cli/data_commands.py +157 -479
- lean_explore/cli/display.py +171 -0
- lean_explore/cli/main.py +51 -608
- lean_explore/config.py +244 -0
- lean_explore/extract/__init__.py +5 -0
- lean_explore/extract/__main__.py +368 -0
- lean_explore/extract/doc_gen4.py +200 -0
- lean_explore/extract/doc_parser.py +499 -0
- lean_explore/extract/embeddings.py +371 -0
- lean_explore/extract/github.py +110 -0
- lean_explore/extract/index.py +317 -0
- lean_explore/extract/informalize.py +653 -0
- lean_explore/extract/package_config.py +59 -0
- lean_explore/extract/package_registry.py +45 -0
- lean_explore/extract/package_utils.py +105 -0
- lean_explore/extract/types.py +25 -0
- lean_explore/mcp/__init__.py +11 -1
- lean_explore/mcp/app.py +14 -46
- lean_explore/mcp/server.py +20 -35
- lean_explore/mcp/tools.py +70 -205
- lean_explore/models/__init__.py +9 -0
- lean_explore/models/search_db.py +76 -0
- lean_explore/models/search_types.py +53 -0
- lean_explore/search/__init__.py +32 -0
- lean_explore/search/engine.py +655 -0
- lean_explore/search/scoring.py +156 -0
- lean_explore/search/service.py +68 -0
- lean_explore/search/tokenization.py +71 -0
- lean_explore/util/__init__.py +28 -0
- lean_explore/util/embedding_client.py +92 -0
- lean_explore/util/logging.py +22 -0
- lean_explore/util/openrouter_client.py +63 -0
- lean_explore/util/reranker_client.py +189 -0
- {lean_explore-0.3.0.dist-info → lean_explore-1.0.0.dist-info}/METADATA +32 -9
- lean_explore-1.0.0.dist-info/RECORD +43 -0
- {lean_explore-0.3.0.dist-info → lean_explore-1.0.0.dist-info}/WHEEL +1 -1
- lean_explore-1.0.0.dist-info/entry_points.txt +2 -0
- lean_explore/cli/agent.py +0 -788
- lean_explore/cli/config_utils.py +0 -481
- lean_explore/defaults.py +0 -114
- lean_explore/local/__init__.py +0 -1
- lean_explore/local/search.py +0 -1050
- lean_explore/local/service.py +0 -479
- lean_explore/shared/__init__.py +0 -1
- lean_explore/shared/models/__init__.py +0 -1
- lean_explore/shared/models/api.py +0 -117
- lean_explore/shared/models/db.py +0 -396
- lean_explore-0.3.0.dist-info/RECORD +0 -26
- lean_explore-0.3.0.dist-info/entry_points.txt +0 -2
- {lean_explore-0.3.0.dist-info → lean_explore-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {lean_explore-0.3.0.dist-info → lean_explore-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,655 @@
|
|
|
1
|
+
"""Core search engine for Lean declarations.
|
|
2
|
+
|
|
3
|
+
This module provides the SearchEngine class that implements hybrid search using
|
|
4
|
+
BM25 lexical matching and FAISS semantic search, combined via Reciprocal Rank
|
|
5
|
+
Fusion (RRF) and cross-encoder reranking.
|
|
6
|
+
|
|
7
|
+
Note: On macOS, torch and FAISS have OpenMP library conflicts. To avoid segfaults:
|
|
8
|
+
- FAISS is imported lazily (not at module level)
|
|
9
|
+
- When semantic search is needed, torch/embeddings are loaded FIRST, then FAISS
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
import logging
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
import bm25s
|
|
18
|
+
import numpy as np
|
|
19
|
+
from sqlalchemy import select
|
|
20
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
|
21
|
+
|
|
22
|
+
from lean_explore.config import Config
|
|
23
|
+
from lean_explore.models import Declaration, SearchResult
|
|
24
|
+
from lean_explore.search.scoring import (
|
|
25
|
+
fuzzy_name_score,
|
|
26
|
+
normalize_dependency_counts,
|
|
27
|
+
normalize_scores,
|
|
28
|
+
)
|
|
29
|
+
from lean_explore.search.tokenization import (
|
|
30
|
+
is_autogenerated,
|
|
31
|
+
tokenize_raw,
|
|
32
|
+
tokenize_spaced,
|
|
33
|
+
tokenize_words,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
import faiss
|
|
38
|
+
|
|
39
|
+
from lean_explore.util import EmbeddingClient, RerankerClient
|
|
40
|
+
|
|
41
|
+
logger = logging.getLogger(__name__)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class SearchEngine:
    """Core search engine for Lean declarations.

    Uses two-stage retrieval:
    1. FAISS semantic search on informalizations
    2. BM25 lexical search on declaration names (independent)
    Then merges and reranks candidates.
    """

    def __init__(
        self,
        db_url: str | None = None,
        embedding_client: "EmbeddingClient | None" = None,
        embedding_model_name: str = "Qwen/Qwen3-Embedding-0.6B",
        reranker_client: "RerankerClient | None" = None,
        reranker_model_name: str = "Qwen/Qwen3-Reranker-0.6B",
        faiss_index_path: Path | None = None,
        faiss_ids_map_path: Path | None = None,
        use_local_data: bool = True,
    ):
        """Initialize the search engine.

        Args:
            db_url: Database URL. Defaults to configured URL.
            embedding_client: Client for generating embeddings. Created lazily if None.
            embedding_model_name: Name of the embedding model to use.
            reranker_client: Client for reranking results. Created lazily if None.
            reranker_model_name: Name of the reranker model to use.
            faiss_index_path: Path to FAISS index. Defaults to config path.
            faiss_ids_map_path: Path to FAISS ID mapping. Defaults to config path.
            use_local_data: If True, use DATA_DIRECTORY paths. If False, use
                CACHE_DIRECTORY paths (for downloaded remote data).

        Raises:
            FileNotFoundError: If any required on-disk data file is missing
                (raised by _validate_paths at the end of construction).
        """
        # Clients may be None here; the embedding_client/reranker_client
        # properties construct them on first use so torch is not loaded at
        # import time (see module docstring for the macOS OpenMP concern).
        self._embedding_client = embedding_client
        self._embedding_model_name = embedding_model_name
        self._reranker_client = reranker_client
        self._reranker_model_name = reranker_model_name

        # Locally extracted data and downloaded (cached) data live under
        # different roots and use different default database URLs.
        if use_local_data:
            base_path = Config.ACTIVE_DATA_PATH
            default_db_url = Config.EXTRACTION_DATABASE_URL
        else:
            base_path = Config.ACTIVE_CACHE_PATH
            default_db_url = Config.DATABASE_URL

        self.db_url = db_url or default_db_url
        self.engine: AsyncEngine = create_async_engine(self.db_url)

        # FAISS index over informalization embeddings plus the JSON map from
        # FAISS row position to declaration id. Both are loaded lazily by
        # _ensure_faiss_loaded, not here.
        self._faiss_informal_path = faiss_index_path or (
            base_path / "informalization_faiss.index"
        )
        self._faiss_informal_ids_path = faiss_ids_map_path or (
            base_path / "informalization_faiss_ids_map.json"
        )
        self._faiss_informal_index: faiss.Index | None = None
        self._faiss_informal_id_map: list[int] | None = None

        # Pre-built BM25 indices over declaration names (two tokenizations)
        # and their shared row-position -> declaration-id map; loaded lazily
        # by _ensure_bm25_loaded.
        self._bm25_spaced_path = base_path / "bm25_name_spaced"
        self._bm25_raw_path = base_path / "bm25_name_raw"
        self._bm25_ids_map_path = base_path / "bm25_ids_map.json"
        self._all_declaration_ids: list[int] | None = None
        self._bm25_name_spaced: bm25s.BM25 | None = None
        self._bm25_name_raw: bm25s.BM25 | None = None

        # Fail fast if any required on-disk artifact is missing.
        self._validate_paths()
|
|
109
|
+
|
|
110
|
+
def _validate_paths(self) -> None:
|
|
111
|
+
"""Validate that required data files exist."""
|
|
112
|
+
required_paths = [
|
|
113
|
+
self._faiss_informal_path,
|
|
114
|
+
self._faiss_informal_ids_path,
|
|
115
|
+
self._bm25_spaced_path,
|
|
116
|
+
self._bm25_raw_path,
|
|
117
|
+
self._bm25_ids_map_path,
|
|
118
|
+
]
|
|
119
|
+
for path in required_paths:
|
|
120
|
+
if not path.exists():
|
|
121
|
+
raise FileNotFoundError(
|
|
122
|
+
f"Required file not found at {path}. "
|
|
123
|
+
"Please run 'lean-explore data fetch' to download the data."
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
@property
|
|
127
|
+
def embedding_client(self) -> "EmbeddingClient":
|
|
128
|
+
"""Lazily create the embedding client to avoid loading torch at import time."""
|
|
129
|
+
if self._embedding_client is None:
|
|
130
|
+
from lean_explore.util import EmbeddingClient
|
|
131
|
+
|
|
132
|
+
self._embedding_client = EmbeddingClient(
|
|
133
|
+
model_name=self._embedding_model_name,
|
|
134
|
+
max_length=512,
|
|
135
|
+
)
|
|
136
|
+
return self._embedding_client
|
|
137
|
+
|
|
138
|
+
@property
|
|
139
|
+
def reranker_client(self) -> "RerankerClient":
|
|
140
|
+
"""Lazily create the reranker client to avoid loading torch at import time."""
|
|
141
|
+
if self._reranker_client is None:
|
|
142
|
+
from lean_explore.util import RerankerClient
|
|
143
|
+
|
|
144
|
+
self._reranker_client = RerankerClient(
|
|
145
|
+
model_name=self._reranker_model_name,
|
|
146
|
+
max_length=256,
|
|
147
|
+
)
|
|
148
|
+
return self._reranker_client
|
|
149
|
+
|
|
150
|
+
def _ensure_faiss_loaded(self) -> None:
|
|
151
|
+
"""Load the FAISS index if not already loaded."""
|
|
152
|
+
if self._faiss_informal_index is not None:
|
|
153
|
+
return
|
|
154
|
+
|
|
155
|
+
import faiss
|
|
156
|
+
|
|
157
|
+
logger.info(f"Loading FAISS index from {self._faiss_informal_path}")
|
|
158
|
+
self._faiss_informal_index = faiss.read_index(str(self._faiss_informal_path))
|
|
159
|
+
with open(self._faiss_informal_ids_path) as f:
|
|
160
|
+
self._faiss_informal_id_map = json.load(f)
|
|
161
|
+
|
|
162
|
+
@property
|
|
163
|
+
def faiss_informal_index(self) -> "faiss.Index":
|
|
164
|
+
"""Get the informalization FAISS index."""
|
|
165
|
+
self._ensure_faiss_loaded()
|
|
166
|
+
return self._faiss_informal_index # type: ignore[return-value]
|
|
167
|
+
|
|
168
|
+
@property
|
|
169
|
+
def faiss_informal_id_map(self) -> list[int]:
|
|
170
|
+
"""Get the informalization FAISS ID mapping."""
|
|
171
|
+
self._ensure_faiss_loaded()
|
|
172
|
+
return self._faiss_informal_id_map # type: ignore[return-value]
|
|
173
|
+
|
|
174
|
+
def _ensure_bm25_loaded(self) -> None:
|
|
175
|
+
"""Load pre-built BM25 indices from disk."""
|
|
176
|
+
if self._bm25_name_spaced is not None:
|
|
177
|
+
return
|
|
178
|
+
|
|
179
|
+
logger.info(f"Loading BM25 indices from {self._bm25_spaced_path.parent}")
|
|
180
|
+
|
|
181
|
+
self._bm25_name_spaced = bm25s.BM25.load(str(self._bm25_spaced_path))
|
|
182
|
+
self._bm25_name_raw = bm25s.BM25.load(str(self._bm25_raw_path))
|
|
183
|
+
|
|
184
|
+
with open(self._bm25_ids_map_path) as f:
|
|
185
|
+
self._all_declaration_ids = json.load(f)
|
|
186
|
+
|
|
187
|
+
logger.info(f"BM25 indices loaded ({len(self._all_declaration_ids)} decls)")
|
|
188
|
+
|
|
189
|
+
def _retrieve_bm25_candidates(self, query: str, bm25_k: int) -> dict[int, float]:
|
|
190
|
+
"""Retrieve candidates using BM25 on declaration names.
|
|
191
|
+
|
|
192
|
+
Args:
|
|
193
|
+
query: Search query string.
|
|
194
|
+
bm25_k: Number of candidates to retrieve.
|
|
195
|
+
|
|
196
|
+
Returns:
|
|
197
|
+
Map of declaration ID to BM25 score.
|
|
198
|
+
"""
|
|
199
|
+
self._ensure_bm25_loaded()
|
|
200
|
+
|
|
201
|
+
query_tokens_spaced = tokenize_spaced(query)
|
|
202
|
+
query_tokens_raw = tokenize_raw(query)
|
|
203
|
+
|
|
204
|
+
results_spaced, scores_spaced = self._bm25_name_spaced.retrieve(
|
|
205
|
+
[query_tokens_spaced], k=bm25_k
|
|
206
|
+
)
|
|
207
|
+
results_raw, scores_raw = self._bm25_name_raw.retrieve(
|
|
208
|
+
[query_tokens_raw], k=bm25_k
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
bm25_map: dict[int, float] = {}
|
|
212
|
+
for idx, score in zip(results_spaced[0], scores_spaced[0]):
|
|
213
|
+
decl_id = self._all_declaration_ids[idx]
|
|
214
|
+
bm25_map[decl_id] = max(bm25_map.get(decl_id, 0.0), float(score))
|
|
215
|
+
for idx, score in zip(results_raw[0], scores_raw[0]):
|
|
216
|
+
decl_id = self._all_declaration_ids[idx]
|
|
217
|
+
bm25_map[decl_id] = max(bm25_map.get(decl_id, 0.0), float(score))
|
|
218
|
+
|
|
219
|
+
logger.info(f"BM25 name: {len(bm25_map)} candidates")
|
|
220
|
+
return bm25_map
|
|
221
|
+
|
|
222
|
+
    async def _retrieve_semantic_candidates(
        self, query: str, faiss_k: int
    ) -> dict[int, float]:
        """Retrieve candidates using semantic search on informalizations.

        Args:
            query: Search query string.
            faiss_k: Number of candidates to retrieve from FAISS.

        Returns:
            Map of declaration ID to semantic similarity score.
        """
        # Embed the query FIRST so torch is initialized before FAISS is
        # imported — per the module docstring, importing FAISS before torch
        # can segfault on macOS due to conflicting OpenMP libraries.
        embedding_response = await self.embedding_client.embed([query], is_query=True)
        query_embedding = np.array(
            [embedding_response.embeddings[0]], dtype=np.float32
        )

        # Deliberately imported here, after the embedding step (see above).
        import faiss as faiss_module

        # Normalizes in place. NOTE(review): treating the raw FAISS distance
        # below as a similarity assumes an inner-product index over
        # normalized vectors — confirm against the index build step.
        faiss_module.normalize_L2(query_embedding)

        informal_index = self.faiss_informal_index
        informal_id_map = self.faiss_informal_id_map

        # IVF-style indices expose nprobe; widen the probe count for recall.
        if hasattr(informal_index, "nprobe"):
            informal_index.nprobe = 64

        distances, indices = informal_index.search(query_embedding, faiss_k)

        semantic_map: dict[int, float] = {}
        for idx, dist in zip(indices[0], distances[0]):
            # FAISS pads short result lists with -1; also guard against
            # positions outside the id map.
            if idx == -1 or idx >= len(informal_id_map):
                continue
            decl_id = informal_id_map[idx]
            similarity = float(dist)
            # Keep the best score if the same declaration shows up twice.
            semantic_map[decl_id] = max(semantic_map.get(decl_id, 0.0), similarity)

        logger.info(f"FAISS informal: {len(semantic_map)} candidates")
        return semantic_map
|
|
261
|
+
|
|
262
|
+
def _compute_rrf_scores(
|
|
263
|
+
self,
|
|
264
|
+
bm25_map: dict[int, float],
|
|
265
|
+
semantic_map: dict[int, float],
|
|
266
|
+
) -> list[tuple[int, float]]:
|
|
267
|
+
"""Compute RRF scores from BM25 and semantic retrieval signals.
|
|
268
|
+
|
|
269
|
+
Args:
|
|
270
|
+
bm25_map: Map of declaration ID to BM25 score.
|
|
271
|
+
semantic_map: Map of declaration ID to semantic similarity score.
|
|
272
|
+
|
|
273
|
+
Returns:
|
|
274
|
+
List of (declaration_id, rrf_score) sorted by score descending.
|
|
275
|
+
"""
|
|
276
|
+
all_candidate_ids = set(bm25_map.keys()) | set(semantic_map.keys())
|
|
277
|
+
logger.info(f"Total merged candidates: {len(all_candidate_ids)}")
|
|
278
|
+
|
|
279
|
+
if not all_candidate_ids:
|
|
280
|
+
return []
|
|
281
|
+
|
|
282
|
+
bm25_sorted = sorted(bm25_map.items(), key=lambda x: x[1], reverse=True)
|
|
283
|
+
sem_sorted = sorted(semantic_map.items(), key=lambda x: x[1], reverse=True)
|
|
284
|
+
|
|
285
|
+
bm25_rank_map = {cid: rank + 1 for rank, (cid, _) in enumerate(bm25_sorted)}
|
|
286
|
+
sem_rank_map = {cid: rank + 1 for rank, (cid, _) in enumerate(sem_sorted)}
|
|
287
|
+
|
|
288
|
+
default_bm25_rank = len(bm25_sorted) + 1
|
|
289
|
+
default_sem_rank = len(sem_sorted) + 1
|
|
290
|
+
|
|
291
|
+
rrf_scores: list[tuple[int, float]] = []
|
|
292
|
+
for cid in all_candidate_ids:
|
|
293
|
+
name_rank = bm25_rank_map.get(cid, default_bm25_rank)
|
|
294
|
+
inf_rank = sem_rank_map.get(cid, default_sem_rank)
|
|
295
|
+
rrf_score = 1.0 / name_rank + 1.0 / inf_rank
|
|
296
|
+
rrf_scores.append((cid, rrf_score))
|
|
297
|
+
|
|
298
|
+
rrf_scores.sort(key=lambda x: x[1], reverse=True)
|
|
299
|
+
return rrf_scores
|
|
300
|
+
|
|
301
|
+
    async def _apply_dependency_boost(
        self,
        rrf_scores: list[tuple[int, float]],
        top_n: int = 500,
    ) -> tuple[list[tuple[int, float]], dict[int, Declaration]]:
        """Apply dependency-based boost to RRF scores.

        Declarations that are dependencies of other top candidates get a boost.

        Args:
            rrf_scores: List of (declaration_id, rrf_score) sorted by score.
            top_n: Number of top candidates to consider for dependency analysis.

        Returns:
            Tuple of (boosted_scores, declarations_map). boosted_scores covers
            only the top_n candidates, re-sorted by boosted score;
            declarations_map holds the ORM rows fetched for those candidates.
        """
        top_ids = [cid for cid, _ in rrf_scores[:top_n]]

        # Fetch all top candidates' ORM rows in a single query.
        async with AsyncSession(self.engine) as session:
            stmt = select(Declaration).where(Declaration.id.in_(top_ids))
            result = await session.execute(stmt)
            declarations_map = {d.id: d for d in result.scalars().all()}

        # Dependencies are stored as names (JSON), so map candidate names
        # back to their ids for counting.
        name_to_id = {
            declarations_map[cid].name: cid
            for cid in top_ids
            if cid in declarations_map
        }
        dep_counts: dict[int, int] = {cid: 0 for cid in top_ids}

        # Count, among the top candidates only, how many list each candidate
        # as a dependency.
        for cid in top_ids:
            decl = declarations_map.get(cid)
            if decl and decl.dependencies:
                try:
                    deps = json.loads(decl.dependencies)
                    for dep_name in deps:
                        if dep_name in name_to_id:
                            dep_counts[name_to_id[dep_name]] += 1
                except json.JSONDecodeError:
                    # Malformed dependency JSON: treat as having no deps.
                    pass

        max_deps = max(dep_counts.values()) if dep_counts else 0
        boosted_scores: list[tuple[int, float]] = []

        for rank, (cid, _) in enumerate(rrf_scores[:top_n], 1):
            dep_count = dep_counts.get(cid, 0)
            if max_deps > 0 and dep_count > 0:
                # Most-depended-upon candidate gets dep_rank 1.
                dep_rank = (max_deps - dep_count) + 1
            else:
                # No dependency signal: rank just past all counted candidates.
                dep_rank = max_deps + 1 if max_deps > 0 else top_n + 1

            # RRF-style fusion of the existing rank with the dependency rank.
            boosted_score = 1.0 / rank + 1.0 / dep_rank
            boosted_scores.append((cid, boosted_score))

        boosted_scores.sort(key=lambda x: x[1], reverse=True)
        logger.info(f"Applied dependency boost to top {top_n} candidates")
        return boosted_scores, declarations_map
|
|
358
|
+
|
|
359
|
+
    async def _rerank_candidates(
        self,
        query: str,
        scored_results: list[tuple[Declaration, float]],
        limit: int,
    ) -> list[SearchResult]:
        """Apply cross-encoder reranking with additional signals.

        Combines four normalized signals per candidate: the cross-encoder
        score (weight 1.0), BM25 over informalizations (0.4), a
        dependency-count prior (0.2), and — only for strong name matches —
        fuzzy name similarity (1.0).

        Args:
            query: Search query string.
            scored_results: List of (declaration, score) tuples.
            limit: Maximum number of results to return.

        Returns:
            List of SearchResult objects after reranking.
        """
        logger.info(f"Reranking top {len(scored_results)} candidates")

        # Rerank against "name: informalization" when available, else name only.
        documents = [
            f"{decl.name}: {decl.informalization}"
            if decl.informalization
            else decl.name
            for decl, _ in scored_results
        ]

        rerank_response = await self.reranker_client.rerank(query, documents)
        reranker_scores = rerank_response.scores

        fuzzy_scores = [
            fuzzy_name_score(query, decl.name) for decl, _ in scored_results
        ]

        bm25_informal_scores = self._compute_bm25_on_informalizations(
            query, scored_results
        )

        dep_counts = self._compute_candidate_dependency_counts(scored_results)

        # Normalize each signal into a comparable range before weighting.
        norm_reranker = normalize_scores(reranker_scores)
        norm_fuzzy = normalize_scores(fuzzy_scores)
        norm_bm25 = normalize_scores(bm25_informal_scores)
        norm_dep = normalize_dependency_counts(dep_counts)

        final_scores = []
        for i, (decl, _) in enumerate(scored_results):
            # Reranker dominates (1.0); BM25 (0.4) and the dependency prior
            # (0.2) contribute smaller corrections.
            score = 1.0 * norm_reranker[i] + 0.4 * norm_bm25[i] + 0.2 * norm_dep[i]
            if fuzzy_scores[i] >= 0.7:
                # Only a strong (>= 0.7) raw name match earns the fuzzy
                # bonus, so weak name similarity never overrides relevance.
                score += 1.0 * norm_fuzzy[i]
            final_scores.append(score)

        combined = sorted(
            zip(scored_results, final_scores),
            key=lambda x: x[1],
            reverse=True,
        )

        return self._filter_and_convert_results(combined, limit)
|
|
416
|
+
|
|
417
|
+
def _compute_bm25_on_informalizations(
|
|
418
|
+
self,
|
|
419
|
+
query: str,
|
|
420
|
+
scored_results: list[tuple[Declaration, float]],
|
|
421
|
+
) -> list[float]:
|
|
422
|
+
"""Compute BM25 scores on informalizations for reranking.
|
|
423
|
+
|
|
424
|
+
Args:
|
|
425
|
+
query: Search query string.
|
|
426
|
+
scored_results: List of (declaration, score) tuples.
|
|
427
|
+
|
|
428
|
+
Returns:
|
|
429
|
+
List of BM25 scores for each candidate.
|
|
430
|
+
"""
|
|
431
|
+
informalizations = [
|
|
432
|
+
decl.informalization if decl.informalization else decl.name
|
|
433
|
+
for decl, _ in scored_results
|
|
434
|
+
]
|
|
435
|
+
informal_tokens = [tokenize_words(text) for text in informalizations]
|
|
436
|
+
query_tokens = tokenize_words(query)
|
|
437
|
+
|
|
438
|
+
bm25_informal = bm25s.BM25(method="bm25+")
|
|
439
|
+
bm25_informal.index(informal_tokens)
|
|
440
|
+
results, scores = bm25_informal.retrieve([query_tokens], k=len(informal_tokens))
|
|
441
|
+
|
|
442
|
+
bm25_scores = [0.0] * len(scored_results)
|
|
443
|
+
for idx, score in zip(results[0], scores[0]):
|
|
444
|
+
if int(idx) < len(bm25_scores):
|
|
445
|
+
bm25_scores[int(idx)] = float(score)
|
|
446
|
+
|
|
447
|
+
return bm25_scores
|
|
448
|
+
|
|
449
|
+
def _compute_candidate_dependency_counts(
|
|
450
|
+
self,
|
|
451
|
+
scored_results: list[tuple[Declaration, float]],
|
|
452
|
+
) -> list[int]:
|
|
453
|
+
"""Count how many candidates depend on each declaration.
|
|
454
|
+
|
|
455
|
+
Args:
|
|
456
|
+
scored_results: List of (declaration, score) tuples.
|
|
457
|
+
|
|
458
|
+
Returns:
|
|
459
|
+
List of dependency counts for each candidate.
|
|
460
|
+
"""
|
|
461
|
+
candidate_names = {decl.name for decl, _ in scored_results}
|
|
462
|
+
dep_counts_map: dict[str, int] = {name: 0 for name in candidate_names}
|
|
463
|
+
|
|
464
|
+
for decl, _ in scored_results:
|
|
465
|
+
if decl.dependencies:
|
|
466
|
+
try:
|
|
467
|
+
deps = json.loads(decl.dependencies)
|
|
468
|
+
for dep_name in deps:
|
|
469
|
+
if dep_name in dep_counts_map:
|
|
470
|
+
dep_counts_map[dep_name] += 1
|
|
471
|
+
except json.JSONDecodeError:
|
|
472
|
+
pass
|
|
473
|
+
|
|
474
|
+
return [dep_counts_map.get(decl.name, 0) for decl, _ in scored_results]
|
|
475
|
+
|
|
476
|
+
def _filter_and_convert_results(
|
|
477
|
+
self,
|
|
478
|
+
combined: list[tuple[tuple[Declaration, float], float]],
|
|
479
|
+
limit: int,
|
|
480
|
+
) -> list[SearchResult]:
|
|
481
|
+
"""Filter auto-generated declarations and convert to SearchResult.
|
|
482
|
+
|
|
483
|
+
Args:
|
|
484
|
+
combined: List of ((declaration, old_score), final_score) tuples.
|
|
485
|
+
limit: Maximum number of results to return.
|
|
486
|
+
|
|
487
|
+
Returns:
|
|
488
|
+
List of SearchResult objects.
|
|
489
|
+
"""
|
|
490
|
+
results = []
|
|
491
|
+
for (decl, _), _ in combined:
|
|
492
|
+
if not is_autogenerated(decl.name):
|
|
493
|
+
results.append(self._to_search_result(decl))
|
|
494
|
+
if len(results) >= limit:
|
|
495
|
+
break
|
|
496
|
+
return results
|
|
497
|
+
|
|
498
|
+
def _extract_package(self, module: str) -> str:
|
|
499
|
+
"""Extract package name from module path.
|
|
500
|
+
|
|
501
|
+
Args:
|
|
502
|
+
module: Full module path (e.g., "Mathlib.Algebra.Group").
|
|
503
|
+
|
|
504
|
+
Returns:
|
|
505
|
+
Package name (first component of module path).
|
|
506
|
+
"""
|
|
507
|
+
return module.split(".")[0] if module else ""
|
|
508
|
+
|
|
509
|
+
def _filter_by_packages(
|
|
510
|
+
self,
|
|
511
|
+
declarations_map: dict[int, Declaration],
|
|
512
|
+
packages: list[str],
|
|
513
|
+
) -> dict[int, Declaration]:
|
|
514
|
+
"""Filter declarations to only include specified packages.
|
|
515
|
+
|
|
516
|
+
Args:
|
|
517
|
+
declarations_map: Map of declaration ID to Declaration.
|
|
518
|
+
packages: List of package names to include.
|
|
519
|
+
|
|
520
|
+
Returns:
|
|
521
|
+
Filtered declarations map.
|
|
522
|
+
"""
|
|
523
|
+
if not packages:
|
|
524
|
+
return declarations_map
|
|
525
|
+
|
|
526
|
+
package_set = set(packages)
|
|
527
|
+
return {
|
|
528
|
+
cid: decl
|
|
529
|
+
for cid, decl in declarations_map.items()
|
|
530
|
+
if self._extract_package(decl.module) in package_set
|
|
531
|
+
}
|
|
532
|
+
|
|
533
|
+
async def search(
|
|
534
|
+
self,
|
|
535
|
+
query: str,
|
|
536
|
+
limit: int = 50,
|
|
537
|
+
faiss_k: int = 1000,
|
|
538
|
+
bm25_k: int = 1000,
|
|
539
|
+
rerank_top: int | None = 25,
|
|
540
|
+
packages: list[str] | None = None,
|
|
541
|
+
) -> list[SearchResult]:
|
|
542
|
+
"""Search for Lean declarations using Reciprocal Rank Fusion.
|
|
543
|
+
|
|
544
|
+
Two-signal approach:
|
|
545
|
+
1. BM25+ on declaration names (lexical match)
|
|
546
|
+
2. Semantic search on informalizations (meaning match)
|
|
547
|
+
|
|
548
|
+
Combined via RRF: score = 1/name_rank + 1/informal_rank
|
|
549
|
+
|
|
550
|
+
Optionally applies cross-encoder reranking to the top candidates.
|
|
551
|
+
|
|
552
|
+
Args:
|
|
553
|
+
query: Search query string.
|
|
554
|
+
limit: Maximum number of results to return. Defaults to 50.
|
|
555
|
+
faiss_k: Number of candidates from FAISS index. Defaults to 1000.
|
|
556
|
+
bm25_k: Number of candidates from BM25 index. Defaults to 1000.
|
|
557
|
+
rerank_top: If set, apply cross-encoder reranking to top N candidates.
|
|
558
|
+
Set to 0 or None to skip reranking.
|
|
559
|
+
packages: Optional list of package names to filter by. If provided,
|
|
560
|
+
only declarations from these packages will be returned.
|
|
561
|
+
|
|
562
|
+
Returns:
|
|
563
|
+
List of SearchResult objects, ranked by combined score.
|
|
564
|
+
"""
|
|
565
|
+
if not query.strip():
|
|
566
|
+
return []
|
|
567
|
+
|
|
568
|
+
bm25_map = self._retrieve_bm25_candidates(query, bm25_k)
|
|
569
|
+
semantic_map = await self._retrieve_semantic_candidates(query, faiss_k)
|
|
570
|
+
rrf_scores = self._compute_rrf_scores(bm25_map, semantic_map)
|
|
571
|
+
|
|
572
|
+
if not rrf_scores:
|
|
573
|
+
return []
|
|
574
|
+
|
|
575
|
+
boosted_scores, declarations_map = await self._apply_dependency_boost(
|
|
576
|
+
rrf_scores
|
|
577
|
+
)
|
|
578
|
+
|
|
579
|
+
# Apply package filtering if specified
|
|
580
|
+
if packages:
|
|
581
|
+
declarations_map = self._filter_by_packages(declarations_map, packages)
|
|
582
|
+
# Filter boosted_scores to only include filtered declarations
|
|
583
|
+
boosted_scores = [
|
|
584
|
+
(cid, score)
|
|
585
|
+
for cid, score in boosted_scores
|
|
586
|
+
if cid in declarations_map
|
|
587
|
+
]
|
|
588
|
+
logger.info(f"Filtered to {len(declarations_map)} in {packages}")
|
|
589
|
+
|
|
590
|
+
top_n = rerank_top if rerank_top and rerank_top > 0 else limit
|
|
591
|
+
|
|
592
|
+
scored_results: list[tuple[Declaration, float]] = [
|
|
593
|
+
(declarations_map[cid], score)
|
|
594
|
+
for cid, score in boosted_scores[:top_n]
|
|
595
|
+
if cid in declarations_map
|
|
596
|
+
]
|
|
597
|
+
|
|
598
|
+
if rerank_top and rerank_top > 0:
|
|
599
|
+
return await self._rerank_candidates(query, scored_results, limit)
|
|
600
|
+
|
|
601
|
+
results = []
|
|
602
|
+
for decl, _ in scored_results:
|
|
603
|
+
if not is_autogenerated(decl.name):
|
|
604
|
+
results.append(self._to_search_result(decl))
|
|
605
|
+
if len(results) >= limit:
|
|
606
|
+
break
|
|
607
|
+
return results
|
|
608
|
+
|
|
609
|
+
async def get_by_id(self, declaration_id: int) -> SearchResult | None:
|
|
610
|
+
"""Retrieve a declaration by ID.
|
|
611
|
+
|
|
612
|
+
Args:
|
|
613
|
+
declaration_id: The declaration ID.
|
|
614
|
+
|
|
615
|
+
Returns:
|
|
616
|
+
SearchResult if found, None otherwise.
|
|
617
|
+
"""
|
|
618
|
+
async with AsyncSession(self.engine) as session:
|
|
619
|
+
decl = await session.get(Declaration, declaration_id)
|
|
620
|
+
return self._to_search_result(decl) if decl else None
|
|
621
|
+
|
|
622
|
+
async def get_by_name(self, name: str) -> SearchResult | None:
|
|
623
|
+
"""Retrieve a declaration by its exact name.
|
|
624
|
+
|
|
625
|
+
Args:
|
|
626
|
+
name: The exact declaration name (e.g., "AlgebraicGeometry.Scheme").
|
|
627
|
+
|
|
628
|
+
Returns:
|
|
629
|
+
SearchResult if found, None otherwise.
|
|
630
|
+
"""
|
|
631
|
+
async with AsyncSession(self.engine) as session:
|
|
632
|
+
stmt = select(Declaration).where(Declaration.name == name)
|
|
633
|
+
result = await session.execute(stmt)
|
|
634
|
+
decl = result.scalar_one_or_none()
|
|
635
|
+
return self._to_search_result(decl) if decl else None
|
|
636
|
+
|
|
637
|
+
def _to_search_result(self, decl: Declaration) -> SearchResult:
|
|
638
|
+
"""Convert Declaration ORM object to SearchResult.
|
|
639
|
+
|
|
640
|
+
Args:
|
|
641
|
+
decl: Declaration ORM object.
|
|
642
|
+
|
|
643
|
+
Returns:
|
|
644
|
+
SearchResult pydantic model.
|
|
645
|
+
"""
|
|
646
|
+
return SearchResult(
|
|
647
|
+
id=decl.id,
|
|
648
|
+
name=decl.name,
|
|
649
|
+
module=decl.module,
|
|
650
|
+
docstring=decl.docstring,
|
|
651
|
+
source_text=decl.source_text,
|
|
652
|
+
source_link=decl.source_link,
|
|
653
|
+
dependencies=decl.dependencies,
|
|
654
|
+
informalization=decl.informalization,
|
|
655
|
+
)
|