lean-explore 0.3.0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lean_explore/__init__.py +14 -1
- lean_explore/api/__init__.py +12 -1
- lean_explore/api/client.py +64 -176
- lean_explore/cli/__init__.py +10 -1
- lean_explore/cli/data_commands.py +184 -489
- lean_explore/cli/display.py +171 -0
- lean_explore/cli/main.py +51 -608
- lean_explore/config.py +244 -0
- lean_explore/extract/__init__.py +5 -0
- lean_explore/extract/__main__.py +368 -0
- lean_explore/extract/doc_gen4.py +200 -0
- lean_explore/extract/doc_parser.py +499 -0
- lean_explore/extract/embeddings.py +369 -0
- lean_explore/extract/github.py +110 -0
- lean_explore/extract/index.py +316 -0
- lean_explore/extract/informalize.py +653 -0
- lean_explore/extract/package_config.py +59 -0
- lean_explore/extract/package_registry.py +45 -0
- lean_explore/extract/package_utils.py +105 -0
- lean_explore/extract/types.py +25 -0
- lean_explore/mcp/__init__.py +11 -1
- lean_explore/mcp/app.py +14 -46
- lean_explore/mcp/server.py +20 -35
- lean_explore/mcp/tools.py +71 -205
- lean_explore/models/__init__.py +9 -0
- lean_explore/models/search_db.py +76 -0
- lean_explore/models/search_types.py +53 -0
- lean_explore/search/__init__.py +32 -0
- lean_explore/search/engine.py +651 -0
- lean_explore/search/scoring.py +156 -0
- lean_explore/search/service.py +68 -0
- lean_explore/search/tokenization.py +71 -0
- lean_explore/util/__init__.py +28 -0
- lean_explore/util/embedding_client.py +92 -0
- lean_explore/util/logging.py +22 -0
- lean_explore/util/openrouter_client.py +63 -0
- lean_explore/util/reranker_client.py +187 -0
- {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/METADATA +32 -9
- lean_explore-1.0.1.dist-info/RECORD +43 -0
- {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/WHEEL +1 -1
- lean_explore-1.0.1.dist-info/entry_points.txt +2 -0
- lean_explore/cli/agent.py +0 -788
- lean_explore/cli/config_utils.py +0 -481
- lean_explore/defaults.py +0 -114
- lean_explore/local/__init__.py +0 -1
- lean_explore/local/search.py +0 -1050
- lean_explore/local/service.py +0 -479
- lean_explore/shared/__init__.py +0 -1
- lean_explore/shared/models/__init__.py +0 -1
- lean_explore/shared/models/api.py +0 -117
- lean_explore/shared/models/db.py +0 -396
- lean_explore-0.3.0.dist-info/RECORD +0 -26
- lean_explore-0.3.0.dist-info/entry_points.txt +0 -2
- {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/top_level.txt +0 -0
lean_explore/local/search.py
DELETED
|
@@ -1,1050 +0,0 @@
|
|
|
1
|
-
# src/lean_explore/local/search.py
|
|
2
|
-
|
|
3
|
-
"""Performs semantic search and ranked retrieval of StatementGroups.
|
|
4
|
-
|
|
5
|
-
Combines semantic similarity from FAISS, pre-scaled PageRank scores, and
|
|
6
|
-
lexical word matching (on Lean name, docstring, and informal descriptions)
|
|
7
|
-
to rank StatementGroups. It loads necessary assets (embedding model,
|
|
8
|
-
FAISS index, ID map) using default configurations, embeds the user query,
|
|
9
|
-
performs FAISS search, filters based on a similarity threshold,
|
|
10
|
-
retrieves group details from the database, normalizes semantic similarity,
|
|
11
|
-
PageRank, and BM25 scores based on the current candidate set, and then
|
|
12
|
-
combines these normalized scores using configurable weights to produce a
|
|
13
|
-
final ranked list. It also logs search performance statistics to a dedicated
|
|
14
|
-
JSONL file.
|
|
15
|
-
"""
|
|
16
|
-
|
|
17
|
-
import argparse
|
|
18
|
-
import datetime
|
|
19
|
-
import json
|
|
20
|
-
import logging
|
|
21
|
-
import os
|
|
22
|
-
import pathlib
|
|
23
|
-
import re
|
|
24
|
-
import sys
|
|
25
|
-
import time
|
|
26
|
-
from typing import Any, Dict, List, Optional, Tuple
|
|
27
|
-
|
|
28
|
-
from filelock import FileLock, Timeout
|
|
29
|
-
|
|
30
|
-
try:
|
|
31
|
-
import faiss
|
|
32
|
-
import numpy as np
|
|
33
|
-
from nltk.stem.porter import PorterStemmer
|
|
34
|
-
from rank_bm25 import BM25Plus
|
|
35
|
-
from sentence_transformers import SentenceTransformer
|
|
36
|
-
from sqlalchemy import create_engine, or_, select
|
|
37
|
-
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
|
38
|
-
from sqlalchemy.orm import Session, joinedload, sessionmaker
|
|
39
|
-
except ImportError as e:
|
|
40
|
-
print(
|
|
41
|
-
f"Error: Missing required libraries ({e}).\n"
|
|
42
|
-
"Please install them: pip install SQLAlchemy faiss-cpu "
|
|
43
|
-
"sentence-transformers numpy filelock rapidfuzz rank_bm25 nltk",
|
|
44
|
-
file=sys.stderr,
|
|
45
|
-
)
|
|
46
|
-
sys.exit(1)
|
|
47
|
-
|
|
48
|
-
try:
|
|
49
|
-
from lean_explore import defaults
|
|
50
|
-
from lean_explore.shared.models.db import StatementGroup
|
|
51
|
-
except ImportError as e:
|
|
52
|
-
print(
|
|
53
|
-
f"Error: Could not import project modules (StatementGroup, defaults): {e}\n"
|
|
54
|
-
"Ensure 'lean_explore' is installed (e.g., 'pip install -e .') "
|
|
55
|
-
"and all dependencies are met.",
|
|
56
|
-
file=sys.stderr,
|
|
57
|
-
)
|
|
58
|
-
sys.exit(1)
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
# Root logging is configured at import time: WARNING and above, with the
# emitting module and line number included in every record.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s",
    level=logging.WARNING,
)
# The SQLAlchemy engine and sentence-transformers loggers are pinned to
# WARNING as well.
logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

EPSILON = 1e-9
NEWLINE = os.linesep
PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent

# Search statistics go to a "logs" directory that is a sibling of the user
# data directory provided by `defaults`.
_USER_LOGS_BASE_DIR = defaults.LEAN_EXPLORE_USER_DATA_DIR.parent / "logs"
PERFORMANCE_LOG_FILENAME = "search_stats.jsonl"
PERFORMANCE_LOG_DIR = str(_USER_LOGS_BASE_DIR)
PERFORMANCE_LOG_PATH = os.path.join(PERFORMANCE_LOG_DIR, PERFORMANCE_LOG_FILENAME)
# Lock file guarding appends to the statistics file.
LOCK_PATH = os.path.join(PERFORMANCE_LOG_DIR, f"{PERFORMANCE_LOG_FILENAME}.lock")
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
def log_search_event_to_json(
    status: str,
    duration_ms: float,
    results_count: int,
    error_type: Optional[str] = None,
) -> None:
    """Append one search event as a JSON line to the performance log file.

    This is best-effort logging: every failure (directory creation, lock
    timeout, write error) is reported through the module logger and the
    entry is echoed to stderr with a ``FALLBACK_PERF_LOG`` prefix instead of
    raising to the caller.

    Args:
        status: A string code indicating the outcome of the search.
        duration_ms: Total duration of the search processing in milliseconds;
            rounded to two decimals in the log entry.
        results_count: The number of search results returned.
        error_type: Optional. The type of error if the status indicates an
            error. Only included in the entry when truthy.
    """
    # `datetime.utcnow()` is deprecated since Python 3.12. Take an aware UTC
    # timestamp and drop the tzinfo so the serialized value keeps the exact
    # historical shape: a naive ISO-8601 string followed by "Z".
    utc_now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    log_entry = {
        "timestamp": utc_now.isoformat() + "Z",
        "event": "search_processed",
        "status": status,
        "duration_ms": round(duration_ms, 2),
        "results_count": results_count,
    }
    if error_type:
        log_entry["error_type"] = error_type

    try:
        os.makedirs(PERFORMANCE_LOG_DIR, exist_ok=True)
    except OSError as e:
        logger.error(
            "Performance logging error: Could not create log directory %s: %s. "
            "Log entry: %s",
            PERFORMANCE_LOG_DIR,
            e,
            log_entry,
            exc_info=False,
        )
        print(
            f"FALLBACK_PERF_LOG (DIR_ERROR): {json.dumps(log_entry)}", file=sys.stderr
        )
        return

    # The file lock serializes appends to the shared log file; if it cannot
    # be acquired within the configured timeout the entry is dropped from
    # the file and only echoed to stderr.
    lock = FileLock(LOCK_PATH, timeout=2)
    try:
        with lock:
            with open(PERFORMANCE_LOG_PATH, "a", encoding="utf-8") as f:
                f.write(json.dumps(log_entry) + "\n")
    except Timeout:
        logger.warning(
            "Performance logging error: Timeout acquiring lock for %s. "
            "Log entry lost: %s",
            LOCK_PATH,
            log_entry,
        )
        print(
            f"FALLBACK_PERF_LOG (LOCK_TIMEOUT): {json.dumps(log_entry)}",
            file=sys.stderr,
        )
    except Exception as e:
        # Deliberately broad: statistics logging must never break a search.
        logger.error(
            "Performance logging error: Failed to write to %s: %s. Log entry: %s",
            PERFORMANCE_LOG_PATH,
            e,
            log_entry,
            exc_info=False,
        )
        print(
            f"FALLBACK_PERF_LOG (WRITE_ERROR): {json.dumps(log_entry)}", file=sys.stderr
        )
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
def load_faiss_assets(
    index_path_str: str, map_path_str: str
) -> Tuple[Optional[faiss.Index], Optional[List[str]]]:
    """Read a FAISS index and its companion JSON ID map from disk.

    Both paths are resolved first. If either file is missing nothing is
    loaded. A size mismatch between the index and the ID map is only logged
    as a warning; both objects are still returned.

    Args:
        index_path_str: String path to the FAISS index file.
        map_path_str: String path to the JSON ID map file.

    Returns:
        A tuple ``(index, id_map)``:
          - ``(None, None)`` when either file is missing or the index
            cannot be read;
          - ``(index, None)`` when the index loads but the ID map cannot be
            read, cannot be parsed, or is not a JSON list;
          - ``(index, id_map)`` otherwise.
    """
    index_file = pathlib.Path(index_path_str).resolve()
    map_file = pathlib.Path(map_path_str).resolve()

    # Existence checks run in the same order as before: index, then map.
    required_files = (
        ("FAISS index file not found: %s", index_file),
        ("FAISS ID map file not found: %s", map_file),
    )
    for missing_message, required_file in required_files:
        if not required_file.exists():
            logger.error(missing_message, required_file)
            return None, None

    try:
        logger.info("Loading FAISS index from %s...", index_file)
        loaded_index = faiss.read_index(str(index_file))
        logger.info(
            "Loaded FAISS index with %d vectors (Metric Type: %s).",
            loaded_index.ntotal,
            loaded_index.metric_type,
        )
    except Exception as index_error:
        logger.error(
            "Failed to load FAISS index from %s: %s",
            index_file,
            index_error,
            exc_info=True,
        )
        return None, None

    try:
        logger.info("Loading ID map from %s...", map_file)
        with open(map_file, encoding="utf-8") as map_handle:
            chunk_ids = json.load(map_handle)
        if not isinstance(chunk_ids, list):
            logger.error(
                "ID map file (%s) does not contain a valid JSON list.", map_file
            )
            return loaded_index, None
        logger.info("Loaded ID map with %d entries.", len(chunk_ids))
    except Exception as map_error:
        logger.error(
            "Failed to load or parse ID map file %s: %s",
            map_file,
            map_error,
            exc_info=True,
        )
        return loaded_index, None

    # Both assets are loaded at this point; report, but tolerate, a size
    # disagreement between them.
    if loaded_index.ntotal != len(chunk_ids):
        logger.warning(
            "Mismatch: FAISS index size (%d) vs ID map size (%d). "
            "Results may be inconsistent.",
            loaded_index.ntotal,
            len(chunk_ids),
        )
    return loaded_index, chunk_ids
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
def load_embedding_model(model_name: str) -> Optional[SentenceTransformer]:
    """Instantiate the Sentence Transformer model identified by `model_name`.

    Any exception raised while constructing the model, or while reading its
    `max_seq_length` for the success log message, is logged with a traceback
    and swallowed.

    Args:
        model_name: The name or path of the sentence-transformer model.

    Returns:
        The loaded model, or None if loading fails.
    """
    logger.info("Loading sentence transformer model '%s'...", model_name)
    loaded_model: Optional[SentenceTransformer] = None
    try:
        candidate = SentenceTransformer(model_name)
        logger.info(
            "Model '%s' loaded successfully. Max sequence length: %d.",
            model_name,
            candidate.max_seq_length,
        )
        # Only publish the model once the success message has been built.
        loaded_model = candidate
    except Exception as load_error:
        logger.error(
            "Failed to load sentence transformer model '%s': %s",
            model_name,
            load_error,
            exc_info=True,
        )
    return loaded_model
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
def spacify_text(text: str) -> str:
    """Turn a path or identifier into lower-case, space-separated words.

    The input is converted with `str()` and then processed in this order:
      1. If it contains a "/", everything up to and including the first
         "/" is discarded.
      2. Hyphens and underscores become spaces and every occurrence of
         ".lean" is removed.
      3. A space is inserted at camelCase boundaries (e.g. 'CamelCaseWord'
         becomes 'Camel Case Word', 'HTTPServer' becomes 'HTTP Server').
      4. Remaining "/" and "." characters become spaces.
      5. Runs of whitespace collapse to one space, leading and trailing
         whitespace is stripped, and the result is lower-cased.

    Args:
        text: The input string to be transformed.

    Returns:
        The transformed, lower-cased string.
    """
    working = str(text)

    # Keep only what follows the first slash, when there is one.
    _, first_slash, after_slash = working.partition("/")
    if first_slash:
        working = after_slash

    working = working.translate(str.maketrans("-_", "  "))
    working = working.replace(".lean", "")

    # lower/digit followed by upper, then an upper followed by Upper+lower.
    for boundary_pattern in (r"([a-z0-9])([A-Z])", r"([A-Z])([A-Z][a-z])"):
        working = re.sub(boundary_pattern, r"\1 \2", working)

    working = working.translate(str.maketrans("/.", "  "))

    return re.sub(r"\s+", " ", working).strip().lower()
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
def perform_search(
|
|
285
|
-
session: Session,
|
|
286
|
-
query_string: str,
|
|
287
|
-
model: SentenceTransformer,
|
|
288
|
-
faiss_index: faiss.Index,
|
|
289
|
-
text_chunk_id_map: List[str],
|
|
290
|
-
faiss_k: int,
|
|
291
|
-
pagerank_weight: float,
|
|
292
|
-
text_relevance_weight: float,
|
|
293
|
-
log_searches: bool,
|
|
294
|
-
name_match_weight: float = defaults.DEFAULT_NAME_MATCH_WEIGHT,
|
|
295
|
-
selected_packages: Optional[List[str]] = None,
|
|
296
|
-
semantic_similarity_threshold: float = defaults.DEFAULT_SEM_SIM_THRESHOLD,
|
|
297
|
-
faiss_nprobe: int = defaults.DEFAULT_FAISS_NPROBE,
|
|
298
|
-
faiss_oversampling_factor: int = defaults.DEFAULT_FAISS_OVERSAMPLING_FACTOR,
|
|
299
|
-
) -> List[Tuple[StatementGroup, Dict[str, float]]]:
|
|
300
|
-
"""Performs semantic and lexical search, then ranks results.
|
|
301
|
-
|
|
302
|
-
Scores (semantic similarity, PageRank, BM25) are normalized to a 0-1
|
|
303
|
-
range based on the current set of candidates before being weighted and
|
|
304
|
-
combined. If `selected_packages` are specified, `faiss_k` is multiplied
|
|
305
|
-
by `faiss_oversampling_factor` to retrieve more initial candidates.
|
|
306
|
-
|
|
307
|
-
Args:
|
|
308
|
-
session: SQLAlchemy session for database access.
|
|
309
|
-
query_string: The user's search query string.
|
|
310
|
-
model: The loaded SentenceTransformer embedding model.
|
|
311
|
-
faiss_index: The loaded FAISS index for text chunks.
|
|
312
|
-
text_chunk_id_map: A list mapping FAISS internal indices to text chunk IDs.
|
|
313
|
-
faiss_k: The base number of nearest neighbors to retrieve from FAISS.
|
|
314
|
-
pagerank_weight: Weight for the PageRank score.
|
|
315
|
-
text_relevance_weight: Weight for the semantic similarity score.
|
|
316
|
-
log_searches: If True, search performance data will be logged.
|
|
317
|
-
name_match_weight: Weight for the lexical word match score (BM25).
|
|
318
|
-
Defaults to `defaults.DEFAULT_NAME_MATCH_WEIGHT`.
|
|
319
|
-
selected_packages: Optional list of package names to filter search by.
|
|
320
|
-
Defaults to None.
|
|
321
|
-
semantic_similarity_threshold: Minimum similarity for a result to be
|
|
322
|
-
considered. Defaults to `defaults.DEFAULT_SEM_SIM_THRESHOLD`.
|
|
323
|
-
faiss_nprobe: Number of closest cells/clusters to search for IVF-type
|
|
324
|
-
FAISS indexes. Defaults to `defaults.DEFAULT_FAISS_NPROBE`.
|
|
325
|
-
faiss_oversampling_factor: Factor to multiply `faiss_k` by when
|
|
326
|
-
`selected_packages` are active.
|
|
327
|
-
Defaults to `defaults.DEFAULT_FAISS_OVERSAMPLING_FACTOR`.
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
Returns:
|
|
331
|
-
A list of tuples, sorted by `final_score`. Each tuple contains a
|
|
332
|
-
`StatementGroup` object and a dictionary of its scores.
|
|
333
|
-
The score dictionary includes:
|
|
334
|
-
- 'final_score': The combined weighted score.
|
|
335
|
-
- 'raw_similarity': Original FAISS similarity (0-1).
|
|
336
|
-
- 'norm_similarity': `raw_similarity` normalized across current results.
|
|
337
|
-
- 'original_pagerank_score': PageRank score from the database.
|
|
338
|
-
- 'scaled_pagerank': `original_pagerank_score` normalized across current
|
|
339
|
-
results (this key is kept for compatibility, but
|
|
340
|
-
now holds the normalized PageRank).
|
|
341
|
-
- 'raw_word_match_score': Original BM25 score.
|
|
342
|
-
- 'norm_word_match_score': `raw_word_match_score` normalized across
|
|
343
|
-
current results.
|
|
344
|
-
- Weighted components: `weighted_norm_similarity`,
|
|
345
|
-
`weighted_scaled_pagerank` (uses normalized PageRank),
|
|
346
|
-
`weighted_word_match_score` (uses normalized BM25 score).
|
|
347
|
-
|
|
348
|
-
Raises:
|
|
349
|
-
Exception: If critical errors like query embedding or FAISS search fail.
|
|
350
|
-
"""
|
|
351
|
-
overall_start_time = time.time()
|
|
352
|
-
|
|
353
|
-
logger.info("Search request event initiated.")
|
|
354
|
-
if semantic_similarity_threshold > 0.0 + EPSILON:
|
|
355
|
-
logger.info(
|
|
356
|
-
"Applying semantic similarity threshold: %.3f",
|
|
357
|
-
semantic_similarity_threshold,
|
|
358
|
-
)
|
|
359
|
-
|
|
360
|
-
if not query_string.strip():
|
|
361
|
-
logger.warning("Empty query provided. Returning no results.")
|
|
362
|
-
if log_searches:
|
|
363
|
-
duration_ms = (time.time() - overall_start_time) * 1000
|
|
364
|
-
log_search_event_to_json(
|
|
365
|
-
status="EMPTY_QUERY_SUBMITTED", duration_ms=duration_ms, results_count=0
|
|
366
|
-
)
|
|
367
|
-
return []
|
|
368
|
-
|
|
369
|
-
try:
|
|
370
|
-
query_embedding = model.encode([query_string.strip()], convert_to_numpy=True)[
|
|
371
|
-
0
|
|
372
|
-
].astype(np.float32)
|
|
373
|
-
query_embedding_reshaped = np.expand_dims(query_embedding, axis=0)
|
|
374
|
-
if faiss_index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
|
375
|
-
logger.debug(
|
|
376
|
-
"Normalizing query embedding for Inner Product (cosine) search."
|
|
377
|
-
)
|
|
378
|
-
faiss.normalize_L2(query_embedding_reshaped)
|
|
379
|
-
except Exception as e:
|
|
380
|
-
logger.error("Failed to embed query: %s", e, exc_info=True)
|
|
381
|
-
if log_searches:
|
|
382
|
-
duration_ms = (time.time() - overall_start_time) * 1000
|
|
383
|
-
log_search_event_to_json(
|
|
384
|
-
status="EMBEDDING_ERROR",
|
|
385
|
-
duration_ms=duration_ms,
|
|
386
|
-
results_count=0,
|
|
387
|
-
error_type=type(e).__name__,
|
|
388
|
-
)
|
|
389
|
-
raise Exception(f"Query embedding failed: {e}") from e
|
|
390
|
-
|
|
391
|
-
actual_faiss_k_to_use = faiss_k
|
|
392
|
-
if selected_packages and faiss_oversampling_factor > 1:
|
|
393
|
-
actual_faiss_k_to_use = faiss_k * faiss_oversampling_factor
|
|
394
|
-
logger.info(
|
|
395
|
-
f"Package filter active. "
|
|
396
|
-
f"Using oversampled FAISS K: {actual_faiss_k_to_use} "
|
|
397
|
-
f"(base K: {faiss_k}, factor: {faiss_oversampling_factor})"
|
|
398
|
-
)
|
|
399
|
-
else:
|
|
400
|
-
logger.info(f"Using FAISS K: {actual_faiss_k_to_use} for initial retrieval.")
|
|
401
|
-
|
|
402
|
-
try:
|
|
403
|
-
logger.debug(
|
|
404
|
-
"Searching FAISS index for top %d text chunk neighbors...",
|
|
405
|
-
actual_faiss_k_to_use,
|
|
406
|
-
)
|
|
407
|
-
if hasattr(faiss_index, "nprobe") and isinstance(faiss_index.nprobe, int):
|
|
408
|
-
if faiss_nprobe > 0:
|
|
409
|
-
faiss_index.nprobe = faiss_nprobe
|
|
410
|
-
logger.debug(f"Set FAISS nprobe to: {faiss_index.nprobe}")
|
|
411
|
-
else:
|
|
412
|
-
logger.warning(
|
|
413
|
-
f"Configured faiss_nprobe is {faiss_nprobe}. Must be > 0. "
|
|
414
|
-
"Using FAISS default or previously set nprobe for this IVF index."
|
|
415
|
-
)
|
|
416
|
-
distances, indices = faiss_index.search(
|
|
417
|
-
query_embedding_reshaped, actual_faiss_k_to_use
|
|
418
|
-
)
|
|
419
|
-
except Exception as e:
|
|
420
|
-
logger.error("FAISS search failed: %s", e, exc_info=True)
|
|
421
|
-
if log_searches:
|
|
422
|
-
duration_ms = (time.time() - overall_start_time) * 1000
|
|
423
|
-
log_search_event_to_json(
|
|
424
|
-
status="FAISS_SEARCH_ERROR",
|
|
425
|
-
duration_ms=duration_ms,
|
|
426
|
-
results_count=0,
|
|
427
|
-
error_type=type(e).__name__,
|
|
428
|
-
)
|
|
429
|
-
raise Exception(f"FAISS search failed: {e}") from e
|
|
430
|
-
|
|
431
|
-
sg_candidates_raw_similarity: Dict[int, float] = {}
|
|
432
|
-
if indices.size > 0 and distances.size > 0:
|
|
433
|
-
for i, faiss_internal_idx in enumerate(indices[0]):
|
|
434
|
-
if faiss_internal_idx == -1:
|
|
435
|
-
continue
|
|
436
|
-
try:
|
|
437
|
-
text_chunk_id_str = text_chunk_id_map[faiss_internal_idx]
|
|
438
|
-
raw_faiss_score = distances[0][i]
|
|
439
|
-
similarity_score: float
|
|
440
|
-
|
|
441
|
-
if faiss_index.metric_type == faiss.METRIC_L2:
|
|
442
|
-
similarity_score = 1.0 / (1.0 + np.sqrt(max(0, raw_faiss_score)))
|
|
443
|
-
elif faiss_index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
|
444
|
-
similarity_score = raw_faiss_score
|
|
445
|
-
else:
|
|
446
|
-
similarity_score = 1.0 / (1.0 + max(0, raw_faiss_score))
|
|
447
|
-
logger.warning(
|
|
448
|
-
"Unhandled FAISS metric type %d for text chunk. "
|
|
449
|
-
"Using 1/(1+score) for similarity.",
|
|
450
|
-
faiss_index.metric_type,
|
|
451
|
-
)
|
|
452
|
-
similarity_score = max(0.0, min(1.0, similarity_score))
|
|
453
|
-
|
|
454
|
-
parts = text_chunk_id_str.split("_")
|
|
455
|
-
if len(parts) >= 2 and parts[0] == "sg":
|
|
456
|
-
try:
|
|
457
|
-
sg_id = int(parts[1])
|
|
458
|
-
if (
|
|
459
|
-
sg_id not in sg_candidates_raw_similarity
|
|
460
|
-
or similarity_score > sg_candidates_raw_similarity[sg_id]
|
|
461
|
-
):
|
|
462
|
-
sg_candidates_raw_similarity[sg_id] = similarity_score
|
|
463
|
-
except ValueError:
|
|
464
|
-
logger.warning(
|
|
465
|
-
"Could not parse StatementGroup ID from chunk_id: %s",
|
|
466
|
-
text_chunk_id_str,
|
|
467
|
-
)
|
|
468
|
-
else:
|
|
469
|
-
logger.warning(
|
|
470
|
-
"Malformed text_chunk_id format: %s", text_chunk_id_str
|
|
471
|
-
)
|
|
472
|
-
except IndexError:
|
|
473
|
-
logger.warning(
|
|
474
|
-
"FAISS internal index %d out of bounds for ID map (size %d). "
|
|
475
|
-
"Possible data inconsistency.",
|
|
476
|
-
faiss_internal_idx,
|
|
477
|
-
len(text_chunk_id_map),
|
|
478
|
-
)
|
|
479
|
-
except Exception as e:
|
|
480
|
-
logger.warning(
|
|
481
|
-
"Error processing FAISS result for internal index %d "
|
|
482
|
-
"(chunk_id '%s'): %s",
|
|
483
|
-
faiss_internal_idx,
|
|
484
|
-
text_chunk_id_str if "text_chunk_id_str" in locals() else "N/A",
|
|
485
|
-
e,
|
|
486
|
-
)
|
|
487
|
-
|
|
488
|
-
if not sg_candidates_raw_similarity:
|
|
489
|
-
logger.info(
|
|
490
|
-
"No valid StatementGroup candidates found after FAISS search and parsing."
|
|
491
|
-
)
|
|
492
|
-
if log_searches:
|
|
493
|
-
duration_ms = (time.time() - overall_start_time) * 1000
|
|
494
|
-
log_search_event_to_json(
|
|
495
|
-
status="NO_FAISS_CANDIDATES", duration_ms=duration_ms, results_count=0
|
|
496
|
-
)
|
|
497
|
-
return []
|
|
498
|
-
logger.info(
|
|
499
|
-
"Aggregated %d unique StatementGroup candidates from FAISS results.",
|
|
500
|
-
len(sg_candidates_raw_similarity),
|
|
501
|
-
)
|
|
502
|
-
|
|
503
|
-
if semantic_similarity_threshold > 0.0 + EPSILON:
|
|
504
|
-
initial_candidate_count = len(sg_candidates_raw_similarity)
|
|
505
|
-
sg_candidates_raw_similarity = {
|
|
506
|
-
sg_id: sim
|
|
507
|
-
for sg_id, sim in sg_candidates_raw_similarity.items()
|
|
508
|
-
if sim >= semantic_similarity_threshold
|
|
509
|
-
}
|
|
510
|
-
logger.info(
|
|
511
|
-
"Post-thresholding: %d of %d candidates remaining (threshold: %.3f).",
|
|
512
|
-
len(sg_candidates_raw_similarity),
|
|
513
|
-
initial_candidate_count,
|
|
514
|
-
semantic_similarity_threshold,
|
|
515
|
-
)
|
|
516
|
-
|
|
517
|
-
if not sg_candidates_raw_similarity:
|
|
518
|
-
logger.info(
|
|
519
|
-
"No StatementGroup candidates met the semantic similarity "
|
|
520
|
-
"threshold of %.3f.",
|
|
521
|
-
semantic_similarity_threshold,
|
|
522
|
-
)
|
|
523
|
-
if log_searches:
|
|
524
|
-
duration_ms = (time.time() - overall_start_time) * 1000
|
|
525
|
-
log_search_event_to_json(
|
|
526
|
-
status="NO_CANDIDATES_POST_THRESHOLD",
|
|
527
|
-
duration_ms=duration_ms,
|
|
528
|
-
results_count=0,
|
|
529
|
-
)
|
|
530
|
-
return []
|
|
531
|
-
|
|
532
|
-
candidate_sg_ids = list(sg_candidates_raw_similarity.keys())
|
|
533
|
-
sg_objects_map: Dict[int, StatementGroup] = {}
|
|
534
|
-
try:
|
|
535
|
-
logger.debug(
|
|
536
|
-
"Fetching StatementGroup details from DB for %d IDs...",
|
|
537
|
-
len(candidate_sg_ids),
|
|
538
|
-
)
|
|
539
|
-
stmt = select(StatementGroup).where(StatementGroup.id.in_(candidate_sg_ids))
|
|
540
|
-
|
|
541
|
-
if selected_packages:
|
|
542
|
-
logger.info("Filtering search by packages: %s", selected_packages)
|
|
543
|
-
package_filters_sqla = []
|
|
544
|
-
for pkg_name in selected_packages:
|
|
545
|
-
if pkg_name.strip():
|
|
546
|
-
package_filters_sqla.append(
|
|
547
|
-
StatementGroup.source_file.startswith(pkg_name.strip() + "/")
|
|
548
|
-
)
|
|
549
|
-
|
|
550
|
-
if package_filters_sqla:
|
|
551
|
-
stmt = stmt.where(or_(*package_filters_sqla))
|
|
552
|
-
|
|
553
|
-
stmt = stmt.options(joinedload(StatementGroup.primary_declaration))
|
|
554
|
-
db_results = session.execute(stmt).scalars().unique().all()
|
|
555
|
-
for sg_obj in db_results:
|
|
556
|
-
sg_objects_map[sg_obj.id] = sg_obj
|
|
557
|
-
|
|
558
|
-
logger.debug(
|
|
559
|
-
"Fetched details for %d StatementGroups from DB that matched filters.",
|
|
560
|
-
len(sg_objects_map),
|
|
561
|
-
)
|
|
562
|
-
final_candidate_ids_after_db_match = set(sg_objects_map.keys())
|
|
563
|
-
original_faiss_candidate_ids = set(candidate_sg_ids)
|
|
564
|
-
|
|
565
|
-
if len(final_candidate_ids_after_db_match) < len(original_faiss_candidate_ids):
|
|
566
|
-
missing_from_db_or_filtered_out = (
|
|
567
|
-
original_faiss_candidate_ids - final_candidate_ids_after_db_match
|
|
568
|
-
)
|
|
569
|
-
logger.info(
|
|
570
|
-
"%d candidates from FAISS (post-threshold) were not found in DB "
|
|
571
|
-
"or excluded by package filters: (e.g., %s).",
|
|
572
|
-
len(missing_from_db_or_filtered_out),
|
|
573
|
-
list(missing_from_db_or_filtered_out)[:5],
|
|
574
|
-
)
|
|
575
|
-
|
|
576
|
-
except SQLAlchemyError as e:
|
|
577
|
-
logger.error(
|
|
578
|
-
"Database query for StatementGroup details failed: %s", e, exc_info=True
|
|
579
|
-
)
|
|
580
|
-
if log_searches:
|
|
581
|
-
duration_ms = (time.time() - overall_start_time) * 1000
|
|
582
|
-
log_search_event_to_json(
|
|
583
|
-
status="DB_FETCH_ERROR",
|
|
584
|
-
duration_ms=duration_ms,
|
|
585
|
-
results_count=0,
|
|
586
|
-
error_type=type(e).__name__,
|
|
587
|
-
)
|
|
588
|
-
raise
|
|
589
|
-
|
|
590
|
-
results_with_scores: List[Tuple[StatementGroup, Dict[str, float]]] = []
|
|
591
|
-
candidate_semantic_similarities: List[float] = []
|
|
592
|
-
candidate_pagerank_scores: List[float] = []
|
|
593
|
-
|
|
594
|
-
processed_candidates_data: List[Dict[str, Any]] = []
|
|
595
|
-
|
|
596
|
-
for sg_id in final_candidate_ids_after_db_match:
|
|
597
|
-
sg_obj = sg_objects_map[sg_id]
|
|
598
|
-
raw_sem_sim = sg_candidates_raw_similarity[sg_id]
|
|
599
|
-
|
|
600
|
-
processed_candidates_data.append(
|
|
601
|
-
{
|
|
602
|
-
"sg_obj": sg_obj,
|
|
603
|
-
"raw_sem_sim": raw_sem_sim,
|
|
604
|
-
"original_pagerank": sg_obj.scaled_pagerank_score
|
|
605
|
-
if sg_obj.scaled_pagerank_score is not None
|
|
606
|
-
else 0.0,
|
|
607
|
-
}
|
|
608
|
-
)
|
|
609
|
-
candidate_semantic_similarities.append(raw_sem_sim)
|
|
610
|
-
candidate_pagerank_scores.append(
|
|
611
|
-
sg_obj.scaled_pagerank_score
|
|
612
|
-
if sg_obj.scaled_pagerank_score is not None
|
|
613
|
-
else 0.0
|
|
614
|
-
)
|
|
615
|
-
|
|
616
|
-
if not processed_candidates_data:
|
|
617
|
-
logger.info(
|
|
618
|
-
"No candidates remaining after matching with DB data or other "
|
|
619
|
-
"processing steps."
|
|
620
|
-
)
|
|
621
|
-
if log_searches:
|
|
622
|
-
duration_ms = (time.time() - overall_start_time) * 1000
|
|
623
|
-
log_search_event_to_json(
|
|
624
|
-
status="NO_CANDIDATES_POST_PROCESSING",
|
|
625
|
-
duration_ms=duration_ms,
|
|
626
|
-
results_count=0,
|
|
627
|
-
)
|
|
628
|
-
return []
|
|
629
|
-
|
|
630
|
-
stemmer = PorterStemmer()
|
|
631
|
-
|
|
632
|
-
def _get_tokenized_list(text_to_tokenize: str) -> List[str]:
|
|
633
|
-
if not text_to_tokenize:
|
|
634
|
-
return []
|
|
635
|
-
tokens = re.findall(r"\w+", text_to_tokenize.lower())
|
|
636
|
-
return [stemmer.stem(token) for token in tokens]
|
|
637
|
-
|
|
638
|
-
tokenized_query = _get_tokenized_list(query_string.strip())
|
|
639
|
-
bm25_corpus: List[List[str]] = []
|
|
640
|
-
for candidate_item_data in processed_candidates_data:
|
|
641
|
-
sg_obj_for_corpus = candidate_item_data["sg_obj"]
|
|
642
|
-
combined_text_for_bm25 = " ".join(
|
|
643
|
-
filter(
|
|
644
|
-
None,
|
|
645
|
-
[
|
|
646
|
-
(
|
|
647
|
-
sg_obj_for_corpus.primary_declaration.lean_name
|
|
648
|
-
if sg_obj_for_corpus.primary_declaration
|
|
649
|
-
else None
|
|
650
|
-
),
|
|
651
|
-
sg_obj_for_corpus.docstring,
|
|
652
|
-
sg_obj_for_corpus.informal_description,
|
|
653
|
-
sg_obj_for_corpus.informal_summary,
|
|
654
|
-
sg_obj_for_corpus.display_statement_text,
|
|
655
|
-
(
|
|
656
|
-
sg_obj_for_corpus.primary_declaration.lean_name
|
|
657
|
-
if sg_obj_for_corpus.primary_declaration
|
|
658
|
-
else None
|
|
659
|
-
),
|
|
660
|
-
(
|
|
661
|
-
spacify_text(sg_obj_for_corpus.primary_declaration.source_file)
|
|
662
|
-
if sg_obj_for_corpus.primary_declaration
|
|
663
|
-
and sg_obj_for_corpus.primary_declaration.source_file
|
|
664
|
-
else None
|
|
665
|
-
),
|
|
666
|
-
],
|
|
667
|
-
)
|
|
668
|
-
)
|
|
669
|
-
bm25_corpus.append(_get_tokenized_list(combined_text_for_bm25))
|
|
670
|
-
|
|
671
|
-
raw_bm25_scores_list: List[float] = [0.0] * len(processed_candidates_data)
|
|
672
|
-
if tokenized_query and any(bm25_corpus):
|
|
673
|
-
try:
|
|
674
|
-
bm25_model = BM25Plus(bm25_corpus)
|
|
675
|
-
raw_bm25_scores_list = bm25_model.get_scores(tokenized_query)
|
|
676
|
-
raw_bm25_scores_list = [
|
|
677
|
-
max(0.0, float(score)) for score in raw_bm25_scores_list
|
|
678
|
-
]
|
|
679
|
-
except Exception as e:
|
|
680
|
-
logger.warning(
|
|
681
|
-
"BM25Plus scoring failed: %s. Word match scores defaulted to 0.",
|
|
682
|
-
e,
|
|
683
|
-
exc_info=False,
|
|
684
|
-
)
|
|
685
|
-
raw_bm25_scores_list = [0.0] * len(processed_candidates_data)
|
|
686
|
-
|
|
687
|
-
min_sem_sim = (
|
|
688
|
-
min(candidate_semantic_similarities) if candidate_semantic_similarities else 0.0
|
|
689
|
-
)
|
|
690
|
-
max_sem_sim = (
|
|
691
|
-
max(candidate_semantic_similarities) if candidate_semantic_similarities else 0.0
|
|
692
|
-
)
|
|
693
|
-
range_sem_sim = max_sem_sim - min_sem_sim
|
|
694
|
-
logger.debug(
|
|
695
|
-
"Raw semantic similarity range for normalization: [%.4f, %.4f]",
|
|
696
|
-
min_sem_sim,
|
|
697
|
-
max_sem_sim,
|
|
698
|
-
)
|
|
699
|
-
|
|
700
|
-
min_pr = min(candidate_pagerank_scores) if candidate_pagerank_scores else 0.0
|
|
701
|
-
max_pr = max(candidate_pagerank_scores) if candidate_pagerank_scores else 0.0
|
|
702
|
-
range_pr = max_pr - min_pr
|
|
703
|
-
logger.debug(
|
|
704
|
-
"Original PageRank score range for normalization: [%.4f, %.4f]", min_pr, max_pr
|
|
705
|
-
)
|
|
706
|
-
|
|
707
|
-
min_bm25 = min(raw_bm25_scores_list) if raw_bm25_scores_list else 0.0
|
|
708
|
-
max_bm25 = max(raw_bm25_scores_list) if raw_bm25_scores_list else 0.0
|
|
709
|
-
range_bm25 = max_bm25 - min_bm25
|
|
710
|
-
logger.debug(
|
|
711
|
-
"Raw BM25 score range for normalization: [%.4f, %.4f]", min_bm25, max_bm25
|
|
712
|
-
)
|
|
713
|
-
|
|
714
|
-
for i, candidate_data in enumerate(processed_candidates_data):
|
|
715
|
-
sg_obj = candidate_data["sg_obj"]
|
|
716
|
-
current_raw_sem_sim = candidate_data["raw_sem_sim"]
|
|
717
|
-
original_pagerank_score = candidate_data["original_pagerank"]
|
|
718
|
-
original_bm25_score = raw_bm25_scores_list[i]
|
|
719
|
-
|
|
720
|
-
norm_sem_sim = 0.5
|
|
721
|
-
if candidate_semantic_similarities:
|
|
722
|
-
if range_sem_sim > EPSILON:
|
|
723
|
-
norm_sem_sim = (current_raw_sem_sim - min_sem_sim) / range_sem_sim
|
|
724
|
-
elif (
|
|
725
|
-
len(candidate_semantic_similarities) == 1
|
|
726
|
-
and candidate_semantic_similarities[0] > EPSILON
|
|
727
|
-
):
|
|
728
|
-
norm_sem_sim = 1.0
|
|
729
|
-
elif (
|
|
730
|
-
len(candidate_semantic_similarities) > 0
|
|
731
|
-
and range_sem_sim <= EPSILON
|
|
732
|
-
and max_sem_sim <= EPSILON
|
|
733
|
-
):
|
|
734
|
-
norm_sem_sim = 0.0
|
|
735
|
-
else:
|
|
736
|
-
norm_sem_sim = 0.0
|
|
737
|
-
norm_sem_sim = max(0.0, min(1.0, norm_sem_sim))
|
|
738
|
-
|
|
739
|
-
norm_pagerank_score = 0.0
|
|
740
|
-
if candidate_pagerank_scores:
|
|
741
|
-
if range_pr > EPSILON:
|
|
742
|
-
norm_pagerank_score = (original_pagerank_score - min_pr) / range_pr
|
|
743
|
-
elif max_pr > EPSILON:
|
|
744
|
-
norm_pagerank_score = 1.0
|
|
745
|
-
norm_pagerank_score = max(0.0, min(1.0, norm_pagerank_score))
|
|
746
|
-
|
|
747
|
-
norm_bm25_score = 0.0
|
|
748
|
-
if raw_bm25_scores_list:
|
|
749
|
-
if range_bm25 > EPSILON:
|
|
750
|
-
norm_bm25_score = (original_bm25_score - min_bm25) / range_bm25
|
|
751
|
-
elif max_bm25 > EPSILON:
|
|
752
|
-
norm_bm25_score = 1.0
|
|
753
|
-
norm_bm25_score = max(0.0, min(1.0, norm_bm25_score))
|
|
754
|
-
|
|
755
|
-
weighted_norm_similarity = text_relevance_weight * norm_sem_sim
|
|
756
|
-
weighted_norm_pagerank = pagerank_weight * norm_pagerank_score
|
|
757
|
-
weighted_norm_bm25_score = name_match_weight * norm_bm25_score
|
|
758
|
-
|
|
759
|
-
final_score = (
|
|
760
|
-
weighted_norm_similarity + weighted_norm_pagerank + weighted_norm_bm25_score
|
|
761
|
-
)
|
|
762
|
-
|
|
763
|
-
score_dict = {
|
|
764
|
-
"final_score": final_score,
|
|
765
|
-
"raw_similarity": current_raw_sem_sim,
|
|
766
|
-
"norm_similarity": norm_sem_sim,
|
|
767
|
-
"original_pagerank_score": original_pagerank_score,
|
|
768
|
-
"scaled_pagerank": norm_pagerank_score,
|
|
769
|
-
"raw_word_match_score": original_bm25_score,
|
|
770
|
-
"norm_word_match_score": norm_bm25_score,
|
|
771
|
-
"weighted_norm_similarity": weighted_norm_similarity,
|
|
772
|
-
"weighted_scaled_pagerank": weighted_norm_pagerank,
|
|
773
|
-
"weighted_word_match_score": weighted_norm_bm25_score,
|
|
774
|
-
}
|
|
775
|
-
results_with_scores.append((sg_obj, score_dict))
|
|
776
|
-
|
|
777
|
-
results_with_scores.sort(key=lambda item: item[1]["final_score"], reverse=True)
|
|
778
|
-
|
|
779
|
-
final_status = "SUCCESS"
|
|
780
|
-
results_count = len(results_with_scores)
|
|
781
|
-
if not results_with_scores and processed_candidates_data:
|
|
782
|
-
final_status = "NO_RESULTS_FINAL_SCORED"
|
|
783
|
-
elif not results_with_scores and not processed_candidates_data:
|
|
784
|
-
if not sg_candidates_raw_similarity:
|
|
785
|
-
final_status = "NO_CANDIDATES_POST_THRESHOLD"
|
|
786
|
-
|
|
787
|
-
if log_searches:
|
|
788
|
-
duration_ms = (time.time() - overall_start_time) * 1000
|
|
789
|
-
log_search_event_to_json(
|
|
790
|
-
status=final_status, duration_ms=duration_ms, results_count=results_count
|
|
791
|
-
)
|
|
792
|
-
|
|
793
|
-
return results_with_scores
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
def print_results(results: List[Tuple[StatementGroup, Dict[str, float]]]) -> None:
    """Writes a human-readable report of ranked search hits to the console.

    Args:
        results: (StatementGroup, scores) pairs. They are printed in the
            order given; callers are expected to pass them already sorted
            by ``final_score``.
    """

    def _one_line(text: str, limit: int) -> str:
        # Clip to `limit` characters (marking the cut with "...") and only
        # then flatten line breaks, so each field stays on one console line.
        clipped = text if len(text) <= limit else text[:limit] + "..."
        return clipped.replace(NEWLINE, " ")

    if not results:
        print("\nNo results found.")
        return

    print(f"\n--- Top {len(results)} Search Results (StatementGroups) ---")
    for rank, (group, score) in enumerate(results, start=1):
        # Fall back to "N/A" when there is no primary declaration or it has
        # no Lean name.
        declaration = group.primary_declaration
        if declaration and declaration.lean_name:
            lean_name = declaration.lean_name
        else:
            lean_name = "N/A"

        title_line = f"\n{rank}. Lean Name: {lean_name} (SG ID: {group.id})\n"
        weighted_breakdown = (
            f" Final Score: {score['final_score']:.4f} ("
            f"NormSim*W: {score['weighted_norm_similarity']:.4f}, "
            f"NormPR*W: {score['weighted_scaled_pagerank']:.4f}, "
            f"NormWordMatch*W: {score['weighted_word_match_score']:.4f})"
        )
        print(title_line + weighted_breakdown)

        component_breakdown = (
            f" Scores: [NormSim: {score['norm_similarity']:.4f} "
            f"(Raw: {score['raw_similarity']:.4f}), "
            f"NormPR: {score['scaled_pagerank']:.4f} "
            f"(Original: {score['original_pagerank_score']:.4f}), "
            f"NormWordMatch: {score['norm_word_match_score']:.4f} "
            f"(OriginalBM25: {score['raw_word_match_score']:.2f})]"
        )
        print(component_breakdown)

        # Prefer the display form of the statement, then the raw statement.
        lean_code = (
            group.display_statement_text or group.statement_text or "[No Lean code]"
        )
        print(f" Lean Code: {_one_line(lean_code, 200)}")

        # Prefer the informal description, then the docstring.
        description = (
            group.informal_description or group.docstring or "[No description]"
        )
        print(f" Description: {_one_line(description, 150)}")

        # Drop a leading "Mathlib/" so the printed path is shorter.
        location = group.source_file or "[No source file]"
        mathlib_prefix = "Mathlib/"
        if location.startswith(mathlib_prefix):
            location = location[len(mathlib_prefix) :]
        print(f" File: {location}:{group.range_start_line}")

    print("\n---------------------------------------------------")
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
def parse_arguments() -> argparse.Namespace:
    """Builds the command-line interface and parses the process arguments.

    The interface has one positional ``query`` string, an optional integer
    ``--limit``/``-n`` and an optional ``--packages`` list that takes zero
    or more names. Both options default to ``None`` when omitted.

    Returns:
        The namespace produced by argparse for the current command line.
    """
    packages_help = (
        "Filter search results by specific package names (e.g., Mathlib Std). "
        "If not provided, searches all packages."
    )

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Search Lean StatementGroups using combined scoring.",
    )
    cli.add_argument("query", help="The search query string.", type=str)
    cli.add_argument(
        "--limit",
        "-n",
        default=None,
        type=int,
        help="Maximum number of final results to display. Overrides default if set.",
    )
    cli.add_argument(
        "--packages",
        default=None,
        nargs="*",
        type=str,
        metavar="PKG",
        help=packages_help,
    )

    namespace = cli.parse_args()
    return namespace
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
def main():
    """Main execution function for the search script.

    Parses the command line, takes every path and tuning parameter from
    ``defaults`` (only the result limit can be overridden, via ``--limit``),
    loads the embedding model and the FAISS index / ID map, runs
    ``perform_search`` inside a database session and prints the top results.

    The process exits with status 1 when the embedding model or the FAISS
    assets cannot be loaded, when a file-based database is missing, or when
    any error is raised while searching. The database engine, if one was
    created, is always disposed before returning or exiting.
    """
    args = parse_arguments()

    logger.info(
        "Using default configurations for paths and parameters from "
        "lean_explore.defaults."
    )

    db_url = defaults.DEFAULT_DB_URL
    embedding_model_name = defaults.DEFAULT_EMBEDDING_MODEL_NAME
    resolved_idx_path = str(defaults.DEFAULT_FAISS_INDEX_PATH.resolve())
    resolved_map_path = str(defaults.DEFAULT_FAISS_MAP_PATH.resolve())

    faiss_k_cand = defaults.DEFAULT_FAISS_K
    pr_weight = defaults.DEFAULT_PAGERANK_WEIGHT
    sem_sim_weight = defaults.DEFAULT_TEXT_RELEVANCE_WEIGHT
    name_match_w = defaults.DEFAULT_NAME_MATCH_WEIGHT
    results_disp_limit = (
        args.limit if args.limit is not None else defaults.DEFAULT_RESULTS_LIMIT
    )
    semantic_sim_thresh = defaults.DEFAULT_SEM_SIM_THRESHOLD
    faiss_nprobe_val = defaults.DEFAULT_FAISS_NPROBE
    faiss_oversampling_factor_val = defaults.DEFAULT_FAISS_OVERSAMPLING_FACTOR

    # Only the last 30 characters of the database path are logged, to keep
    # the log line short.
    resolved_db_path = str(defaults.DEFAULT_DB_PATH.resolve())
    db_url_display = (
        f"...{resolved_db_path[-30:]}"
        if len(resolved_db_path) > 30
        else resolved_db_path
    )
    logger.info("--- Starting Search (Direct Script Execution) ---")
    logger.info("Query: '%s'", args.query)
    logger.info("Displaying Top: %d results", results_disp_limit)
    if args.packages:
        logger.info("Filtering by user-specified packages: %s", args.packages)
    else:
        logger.info("No package filter specified, searching all packages.")
    logger.info("FAISS k (candidates): %d", faiss_k_cand)
    logger.info("FAISS nprobe (from defaults): %d", faiss_nprobe_val)
    logger.info(
        "FAISS Oversampling Factor (from defaults): %d", faiss_oversampling_factor_val
    )
    logger.info(
        "Semantic Similarity Threshold (from defaults): %.3f", semantic_sim_thresh
    )
    logger.info(
        "Weights -> NormTextSim: %.2f, NormPR: %.2f, NormWordMatch (BM25): %.2f",
        sem_sim_weight,
        pr_weight,
        name_match_w,
    )
    logger.info("Using FAISS index: %s", resolved_idx_path)
    logger.info("Using ID map: %s", resolved_map_path)
    logger.info("Database path: %s", db_url_display)

    # Best effort: a missing log directory is reported but does not stop
    # the search.
    try:
        _USER_LOGS_BASE_DIR.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        logger.warning(
            f"Could not create user log directory {_USER_LOGS_BASE_DIR}: {e}"
        )

    engine = None
    try:
        s_transformer_model = load_embedding_model(embedding_model_name)
        if s_transformer_model is None:
            logger.error(
                "Sentence transformer model loading failed. Cannot proceed with search."
            )
            sys.exit(1)

        faiss_idx, id_map = load_faiss_assets(resolved_idx_path, resolved_map_path)
        if faiss_idx is None or id_map is None:
            logger.error(
                "Failed to load critical FAISS assets (index or ID map).\n"
                f"Expected at:\n Index path: {resolved_idx_path}\n"
                f" ID map path: {resolved_map_path}\n"
                "Please ensure these files exist or run 'leanexplore data fetch' "
                "to download the data toolchain."
            )
            sys.exit(1)

        # The same prefix is used for the check and for the slice. The
        # previous code sliced with a 9-character prefix (colon missing),
        # which left one extra leading "/" on the path and turned relative
        # database paths into absolute ones.
        sqlite_file_prefix = "sqlite:///"
        is_file_db = db_url.startswith(sqlite_file_prefix)
        db_file_path = None
        if is_file_db:
            db_file_path_str = db_url[len(sqlite_file_prefix) :]
            db_file_path = pathlib.Path(db_file_path_str)
            if not db_file_path.exists():
                logger.error(
                    f"Database file not found at the expected location: "
                    f"{db_file_path}\n"
                    "Please run 'leanexplore data fetch' to download the data "
                    "toolchain."
                )
                sys.exit(1)

        engine = create_engine(db_url, echo=False)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

        with SessionLocal() as session:
            ranked_results = perform_search(
                session=session,
                query_string=args.query,
                model=s_transformer_model,
                faiss_index=faiss_idx,
                text_chunk_id_map=id_map,
                faiss_k=faiss_k_cand,
                pagerank_weight=pr_weight,
                text_relevance_weight=sem_sim_weight,
                log_searches=True,
                name_match_weight=name_match_w,
                selected_packages=args.packages,
                semantic_similarity_threshold=semantic_sim_thresh,
                faiss_nprobe=faiss_nprobe_val,
                faiss_oversampling_factor=faiss_oversampling_factor_val,
            )

            # Printed while the session is still open so that attribute
            # access on the returned objects cannot hit a closed session.
            print_results(ranked_results[:results_disp_limit])

    except FileNotFoundError as e:
        logger.error(
            f"A required file was not found: {e.filename}.\n"
            "This could be an issue with configured paths or missing data.\n"
            "If this relates to core data assets, please try running "
            "'leanexplore data fetch'."
        )
        sys.exit(1)
    except OperationalError as e_db:
        is_file_db_op_err = defaults.DEFAULT_DB_URL.startswith("sqlite:///")
        db_file_path_op_err = defaults.DEFAULT_DB_PATH
        if is_file_db_op_err and (
            "unable to open database file" in str(e_db).lower()
            or (db_file_path_op_err and not db_file_path_op_err.exists())
        ):
            # Resolve only when a default path is actually configured, so
            # building the message cannot itself raise.
            p = (
                str(db_file_path_op_err.resolve())
                if db_file_path_op_err
                else "Unknown Path"
            )
            logger.error(
                f"Database connection failed: {e_db}\n"
                f"The database file appears to be missing or inaccessible at: "
                f"{p}\n"
                "Please run 'leanexplore data fetch' to download or update the "
                "data toolchain."
            )
        else:
            logger.error(
                f"Database connection/operational error: {e_db}", exc_info=True
            )
        sys.exit(1)
    except SQLAlchemyError as e_sqla:
        logger.error(
            "A database error occurred during search: %s", e_sqla, exc_info=True
        )
        sys.exit(1)
    except Exception as e_general:
        logger.critical(
            "An unexpected critical error occurred during search: %s",
            e_general,
            exc_info=True,
        )
        sys.exit(1)
    finally:
        # Runs on every path, including the sys.exit() calls above.
        if engine:
            engine.dispose()
            logger.debug("Database engine disposed.")


if __name__ == "__main__":
    main()
|