lean-explore 0.1.1 py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lean_explore/__init__.py +1 -0
- lean_explore/api/__init__.py +1 -0
- lean_explore/api/client.py +124 -0
- lean_explore/cli/__init__.py +1 -0
- lean_explore/cli/agent.py +781 -0
- lean_explore/cli/config_utils.py +408 -0
- lean_explore/cli/data_commands.py +506 -0
- lean_explore/cli/main.py +659 -0
- lean_explore/defaults.py +117 -0
- lean_explore/local/__init__.py +1 -0
- lean_explore/local/search.py +921 -0
- lean_explore/local/service.py +394 -0
- lean_explore/mcp/__init__.py +1 -0
- lean_explore/mcp/app.py +107 -0
- lean_explore/mcp/server.py +247 -0
- lean_explore/mcp/tools.py +242 -0
- lean_explore/shared/__init__.py +1 -0
- lean_explore/shared/models/__init__.py +1 -0
- lean_explore/shared/models/api.py +117 -0
- lean_explore/shared/models/db.py +411 -0
- lean_explore-0.1.1.dist-info/METADATA +277 -0
- lean_explore-0.1.1.dist-info/RECORD +26 -0
- lean_explore-0.1.1.dist-info/WHEEL +5 -0
- lean_explore-0.1.1.dist-info/entry_points.txt +2 -0
- lean_explore-0.1.1.dist-info/licenses/LICENSE +201 -0
- lean_explore-0.1.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,921 @@
# src/lean_explore/local/search.py

"""Performs semantic search and ranked retrieval of StatementGroups.

Combines semantic similarity from FAISS and pre-scaled PageRank scores
to rank StatementGroups. It loads necessary assets (embedding model,
FAISS index, ID map) using default configurations, embeds the user query,
performs FAISS search, filters based on a similarity threshold,
retrieves group details from the database, normalizes semantic similarity scores,
and then combines these scores using configurable weights to produce a final
ranked list. It also logs search performance statistics to a dedicated
JSONL file.
"""

import argparse
import datetime
import json
import logging
import os
import pathlib
import sys
import time
from typing import Any, Dict, List, Optional, Tuple

from filelock import FileLock, Timeout

# --- Dependency Imports ---
try:
    import faiss
    import numpy as np
    from sentence_transformers import SentenceTransformer
    from sqlalchemy import create_engine, or_, select
    from sqlalchemy.exc import OperationalError, SQLAlchemyError
    from sqlalchemy.orm import Session, joinedload, sessionmaker
except ImportError as e:
    # pylint: disable=broad-exception-raised
    print(
        f"Error: Missing required libraries ({e}).\n"
        "Please install them: pip install SQLAlchemy faiss-cpu "
        "sentence-transformers numpy filelock rapidfuzz",
        file=sys.stderr,
    )
    sys.exit(1)

# --- Project Model & Default Config Imports ---
try:
    from lean_explore import defaults  # Using the new defaults module
    from lean_explore.shared.models.db import StatementGroup
except ImportError as e:
    # pylint: disable=broad-exception-raised
    print(
        f"Error: Could not import project modules (StatementGroup, defaults): {e}\n"
        "Ensure 'lean_explore' is installed (e.g., 'pip install -e .') "
        "and all dependencies are met.",
        file=sys.stderr,
    )
    sys.exit(1)


# --- Logging Setup ---
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

# --- Constants ---
NEWLINE = os.linesep
EPSILON = 1e-9
# PROJECT_ROOT might be less relevant for asset paths if defaults.py
# provides absolute paths
PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent

# --- Performance Logging Path Setup ---
# Logs will be stored in a user-writable directory, e.g., ~/.lean_explore/logs/
# defaults.LEAN_EXPLORE_USER_DATA_DIR is ~/.lean_explore/data/
# So, its parent is ~/.lean_explore/
_USER_LOGS_BASE_DIR = defaults.LEAN_EXPLORE_USER_DATA_DIR.parent / "logs"
PERFORMANCE_LOG_DIR = str(_USER_LOGS_BASE_DIR)
PERFORMANCE_LOG_FILENAME = "search_stats.jsonl"
PERFORMANCE_LOG_PATH = os.path.join(PERFORMANCE_LOG_DIR, PERFORMANCE_LOG_FILENAME)
LOCK_PATH = os.path.join(PERFORMANCE_LOG_DIR, f"{PERFORMANCE_LOG_FILENAME}.lock")


# --- Performance Logging Helper ---


def log_search_event_to_json(
    status: str,
    duration_ms: float,
    results_count: int,
    error_type: Optional[str] = None,
) -> None:
    """Logs a search event as a JSON line to a dedicated performance log file.

    Args:
        status: A string code indicating the outcome of the search.
        duration_ms: The total duration of the search processing in milliseconds.
        results_count: The number of search results returned.
        error_type: Optional. The type of error if the status indicates an error.
    """
    log_entry = {
        "timestamp": datetime.datetime.utcnow().isoformat() + "Z",
        "event": "search_processed",
        "status": status,
        "duration_ms": round(duration_ms, 2),
        "results_count": results_count,
    }
    if error_type:
        log_entry["error_type"] = error_type

    try:
        os.makedirs(PERFORMANCE_LOG_DIR, exist_ok=True)
    except OSError as e:
        # This error is critical for logging but should not stop main search flow.
        # The fallback print helps retain info if file logging fails.
        logger.error(
            "Performance logging error: Could not create log directory %s: %s. "
            "Log entry: %s",
            PERFORMANCE_LOG_DIR,
            e,
            log_entry,
            exc_info=False,  # Keep exc_info False to avoid spamming user console
        )
        print(
            f"FALLBACK_PERF_LOG (DIR_ERROR): {json.dumps(log_entry)}", file=sys.stderr
        )
        return

    lock = FileLock(LOCK_PATH, timeout=2)
    try:
        with lock:
            with open(PERFORMANCE_LOG_PATH, "a", encoding="utf-8") as f:
                f.write(json.dumps(log_entry) + "\n")
    except Timeout:
        logger.warning(
            "Performance logging error: Timeout acquiring lock for %s. "
            "Log entry lost: %s",
            LOCK_PATH,
            log_entry,
        )
        print(
            f"FALLBACK_PERF_LOG (LOCK_TIMEOUT): {json.dumps(log_entry)}",
            file=sys.stderr,
        )
    except Exception as e:
        logger.error(  # Keep as error for unexpected write issues
            "Performance logging error: Failed to write to %s: %s. Log entry: %s",
            PERFORMANCE_LOG_PATH,
            e,
            log_entry,
            exc_info=False,
        )
        print(
            f"FALLBACK_PERF_LOG (WRITE_ERROR): {json.dumps(log_entry)}", file=sys.stderr
        )


# --- Asset Loading Functions ---
def load_faiss_assets(
    index_path_str: str, map_path_str: str
) -> Tuple[Optional[faiss.Index], Optional[List[str]]]:
    """Loads the FAISS index and ID map from specified file paths.

    Args:
        index_path_str: String path to the FAISS index file.
        map_path_str: String path to the JSON ID map file.

    Returns:
        A tuple (faiss.Index or None, list_of_IDs or None).
    """
    index_path = pathlib.Path(index_path_str).resolve()
    map_path = pathlib.Path(map_path_str).resolve()

    if not index_path.exists():
        logger.error("FAISS index file not found: %s", index_path)
        return None, None
    if not map_path.exists():
        logger.error("FAISS ID map file not found: %s", map_path)
        return None, None

    faiss_index_obj: Optional[faiss.Index] = None
    id_map_list: Optional[List[str]] = None

    try:
        logger.info("Loading FAISS index from %s...", index_path)
        faiss_index_obj = faiss.read_index(str(index_path))
        logger.info(
            "Loaded FAISS index with %d vectors (Metric Type: %s).",
            faiss_index_obj.ntotal,
            faiss_index_obj.metric_type,
        )
    except Exception as e:
        logger.error(
            "Failed to load FAISS index from %s: %s", index_path, e, exc_info=True
        )
        return None, id_map_list  # Return None for index if loading failed

    try:
        logger.info("Loading ID map from %s...", map_path)
        with open(map_path, encoding="utf-8") as f:
            id_map_list = json.load(f)
        if not isinstance(id_map_list, list):
            logger.error(
                "ID map file (%s) does not contain a valid JSON list.", map_path
            )
            return faiss_index_obj, None  # Return None for map if parsing failed
        logger.info("Loaded ID map with %d entries.", len(id_map_list))
    except Exception as e:
        logger.error(
            "Failed to load or parse ID map file %s: %s", map_path, e, exc_info=True
        )
        return faiss_index_obj, None  # Return None for map if loading/parsing failed

    if (
        faiss_index_obj is not None
        and id_map_list is not None
        and faiss_index_obj.ntotal != len(id_map_list)
    ):
        logger.warning(
            "Mismatch: FAISS index size (%d) vs ID map size (%d). "
            "Results may be inconsistent.",
            faiss_index_obj.ntotal,
            len(id_map_list),
        )
    return faiss_index_obj, id_map_list


def load_embedding_model(model_name: str) -> Optional[SentenceTransformer]:
    """Loads the specified Sentence Transformer model.

    Args:
        model_name: The name or path of the sentence-transformer model.

    Returns:
        The loaded model, or None if loading fails.
    """
    logger.info("Loading sentence transformer model '%s'...", model_name)
    try:
        model = SentenceTransformer(model_name)
        logger.info(
            "Model '%s' loaded successfully. Max sequence length: %d.",
            model_name,
            model.max_seq_length,
        )
        return model
    except Exception as e:  # Broad exception for any model loading issue
        logger.error(
            "Failed to load sentence transformer model '%s': %s",
            model_name,
            e,
            exc_info=True,
        )
        return None


# --- Main Search Function ---


def perform_search(
    session: Session,
    query_string: str,
    model: SentenceTransformer,
    faiss_index: faiss.Index,
    text_chunk_id_map: List[str],
    faiss_k: int,
    pagerank_weight: float,
    text_relevance_weight: float,
    log_searches: bool,  # Added parameter
    selected_packages: Optional[List[str]] = None,
    semantic_similarity_threshold: float = defaults.DEFAULT_SEM_SIM_THRESHOLD,
    faiss_nprobe: int = defaults.DEFAULT_FAISS_NPROBE,
) -> List[Tuple[StatementGroup, Dict[str, float]]]:
    """Performs semantic search and ranking.

    Args:
        session: SQLAlchemy session for database access.
        query_string: The user's search query string.
        model: The loaded SentenceTransformer embedding model.
        faiss_index: The loaded FAISS index for text chunks.
        text_chunk_id_map: A list mapping FAISS internal indices to text chunk IDs.
        faiss_k: The number of nearest neighbors to retrieve from FAISS.
        pagerank_weight: Weight for the pre-scaled PageRank score.
        text_relevance_weight: Weight for the normalized semantic similarity score.
        log_searches: If True, search performance data will be logged.
        selected_packages: Optional list of package names to filter search by.
        semantic_similarity_threshold: Minimum similarity for a result to be considered.
        faiss_nprobe: Number of closest cells/clusters to search for IVF-type FAISS
            indexes.

    Returns:
        A list of tuples, sorted by final_score, containing a
        `StatementGroup` object and its scores.

    Raises:
        Exception: If critical errors like query embedding or FAISS search fail.
    """
    overall_start_time = time.time()

    logger.info("Search request event initiated.")
    if semantic_similarity_threshold > 0.0 + EPSILON:
        logger.info(
            "Applying semantic similarity threshold: %.3f",
            semantic_similarity_threshold,
        )

    if not query_string.strip():
        logger.warning("Empty query provided. Returning no results.")
        if log_searches:
            duration_ms = (time.time() - overall_start_time) * 1000
            log_search_event_to_json(
                status="EMPTY_QUERY_SUBMITTED", duration_ms=duration_ms, results_count=0
            )
        return []

    try:
        query_embedding = model.encode([query_string.strip()], convert_to_numpy=True)[
            0
        ].astype(np.float32)
        query_embedding_reshaped = np.expand_dims(query_embedding, axis=0)
        if faiss_index.metric_type == faiss.METRIC_INNER_PRODUCT:
            logger.debug(
                "Normalizing query embedding for Inner Product (cosine) search."
            )
            faiss.normalize_L2(query_embedding_reshaped)
    except Exception as e:
        logger.error("Failed to embed query: %s", e, exc_info=True)
        if log_searches:
            duration_ms = (time.time() - overall_start_time) * 1000
            log_search_event_to_json(
                status="EMBEDDING_ERROR",
                duration_ms=duration_ms,
                results_count=0,
                error_type=type(e).__name__,
            )
        raise Exception(f"Query embedding failed: {e}") from e

    try:
        logger.debug(
            "Searching FAISS index for top %d text chunk neighbors...", faiss_k
        )
        if hasattr(faiss_index, "nprobe") and isinstance(
            faiss_index.nprobe, int
        ):  # Check if index is IVF
            if faiss_nprobe > 0:
                faiss_index.nprobe = faiss_nprobe
                logger.debug(f"Set FAISS nprobe to: {faiss_index.nprobe}")
            else:  # faiss_nprobe from config is invalid
                logger.warning(
                    f"Configured faiss_nprobe is {faiss_nprobe}. Must be > 0. "
                    "Using FAISS default or previously set nprobe for this IVF index."
                )
        distances, indices = faiss_index.search(query_embedding_reshaped, faiss_k)
    except Exception as e:
        logger.error("FAISS search failed: %s", e, exc_info=True)
        if log_searches:
            duration_ms = (time.time() - overall_start_time) * 1000
            log_search_event_to_json(
                status="FAISS_SEARCH_ERROR",
                duration_ms=duration_ms,
                results_count=0,
                error_type=type(e).__name__,
            )
        raise Exception(f"FAISS search failed: {e}") from e

    sg_candidates_raw_similarity: Dict[int, float] = {}
    if indices.size > 0 and distances.size > 0:
        for i, faiss_internal_idx in enumerate(indices[0]):
            if faiss_internal_idx == -1:  # FAISS can return -1 for no neighbor
                continue
            try:
                text_chunk_id_str = text_chunk_id_map[faiss_internal_idx]
                raw_faiss_score = distances[0][i]
                similarity_score: float

                if faiss_index.metric_type == faiss.METRIC_L2:
                    similarity_score = 1.0 / (1.0 + np.sqrt(max(0, raw_faiss_score)))
                elif faiss_index.metric_type == faiss.METRIC_INNER_PRODUCT:
                    # Assuming normalized vectors, inner product is cosine similarity
                    similarity_score = raw_faiss_score
                else:  # Default or unknown metric, treat score as distance-like
                    similarity_score = 1.0 / (1.0 + max(0, raw_faiss_score))
                    logger.warning(
                        "Unhandled FAISS metric type %d for text chunk. "
                        "Using 1/(1+score) for similarity.",
                        faiss_index.metric_type,
                    )
                similarity_score = max(
                    0.0, min(1.0, similarity_score)
                )  # Clamp to [0,1]

                parts = text_chunk_id_str.split("_")
                if len(parts) >= 2 and parts[0] == "sg":
                    try:
                        sg_id = int(parts[1])
                        # If multiple chunks from the same StatementGroup are retrieved,
                        # keep the one with the highest similarity to the query.
                        if (
                            sg_id not in sg_candidates_raw_similarity
                            or similarity_score > sg_candidates_raw_similarity[sg_id]
                        ):
                            sg_candidates_raw_similarity[sg_id] = similarity_score
                    except ValueError:
                        logger.warning(
                            "Could not parse StatementGroup ID from chunk_id: %s",
                            text_chunk_id_str,
                        )
                else:
                    logger.warning(
                        "Malformed text_chunk_id format: %s", text_chunk_id_str
                    )
            except IndexError:
                logger.warning(
                    "FAISS internal index %d out of bounds for ID map (size %d). "
                    "Possible data inconsistency.",
                    faiss_internal_idx,
                    len(text_chunk_id_map),
                )
            except (
                Exception
            ) as e:  # Catch any other unexpected errors during result processing
                logger.warning(
                    "Error processing FAISS result for internal index %d "
                    "(chunk_id '%s'): %s",
                    faiss_internal_idx,
                    text_chunk_id_str if "text_chunk_id_str" in locals() else "N/A",
                    e,
                )

    if not sg_candidates_raw_similarity:
        logger.info(
            "No valid StatementGroup candidates found after FAISS search and parsing."
        )
        if log_searches:
            duration_ms = (time.time() - overall_start_time) * 1000
            log_search_event_to_json(
                status="NO_FAISS_CANDIDATES", duration_ms=duration_ms, results_count=0
            )
        return []
    logger.info(
        "Aggregated %d unique StatementGroup candidates from FAISS results.",
        len(sg_candidates_raw_similarity),
    )

    if semantic_similarity_threshold > 0.0 + EPSILON:
        initial_candidate_count = len(sg_candidates_raw_similarity)
        sg_candidates_raw_similarity = {
            sg_id: sim
            for sg_id, sim in sg_candidates_raw_similarity.items()
            if sim >= semantic_similarity_threshold
        }
        logger.info(
            "Post-thresholding: %d of %d candidates remaining (threshold: %.3f).",
            len(sg_candidates_raw_similarity),
            initial_candidate_count,
            semantic_similarity_threshold,
        )

        if not sg_candidates_raw_similarity:
            logger.info(
                "No StatementGroup candidates met the semantic similarity "
                "threshold of %.3f.",
                semantic_similarity_threshold,
            )
            if log_searches:
                duration_ms = (time.time() - overall_start_time) * 1000
                log_search_event_to_json(
                    status="NO_CANDIDATES_POST_THRESHOLD",
                    duration_ms=duration_ms,
                    results_count=0,
                )
            return []

    candidate_sg_ids = list(sg_candidates_raw_similarity.keys())
    sg_objects_map: Dict[int, StatementGroup] = {}
    try:
        logger.debug(
            "Fetching StatementGroup details from DB for %d IDs...",
            len(candidate_sg_ids),
        )
        stmt = select(StatementGroup).where(StatementGroup.id.in_(candidate_sg_ids))

        if selected_packages:
            logger.info("Filtering search by packages: %s", selected_packages)
            package_filters_sqla = []
            # Assuming package names in selected_packages are like "Mathlib", "Std"
            # And source_file in DB is like
            # "Mathlib/CategoryTheory/Adjunction/Basic.lean"
            for pkg_name in selected_packages:
                # Ensure exact package match at the start of the file path
                # component
                package_filters_sqla.append(
                    StatementGroup.source_file.startswith(pkg_name + "/")
                )

            if package_filters_sqla:
                stmt = stmt.where(or_(*package_filters_sqla))

        # Eagerly load primary_declaration to avoid N+1 queries later if
        # accessing lean_name
        stmt = stmt.options(joinedload(StatementGroup.primary_declaration))
        db_results = session.execute(stmt).scalars().unique().all()
        for sg_obj in db_results:
            sg_objects_map[sg_obj.id] = sg_obj

        logger.debug(
            "Fetched details for %d StatementGroups from DB that matched filters.",
            len(sg_objects_map),
        )
        # Log if some IDs from FAISS (post-threshold and package filter if
        # applied) were not found in DB. This check is more informative if
        # done *after* any package filtering logic in the query
        final_candidate_ids_after_db_match = set(sg_objects_map.keys())
        original_faiss_candidate_ids = set(candidate_sg_ids)

        if len(final_candidate_ids_after_db_match) < len(original_faiss_candidate_ids):
            missing_from_db_or_filtered_out = (
                original_faiss_candidate_ids - final_candidate_ids_after_db_match
            )
            logger.info(
                "%d candidates from FAISS (post-threshold) were not found in DB "
                "or excluded by package filters: (e.g., %s).",
                len(missing_from_db_or_filtered_out),
                list(missing_from_db_or_filtered_out)[:5],
            )

    except SQLAlchemyError as e:
        logger.error(
            "Database query for StatementGroup details failed: %s", e, exc_info=True
        )
        if log_searches:
            duration_ms = (time.time() - overall_start_time) * 1000
            log_search_event_to_json(
                status="DB_FETCH_ERROR",
                duration_ms=duration_ms,
                results_count=0,
                error_type=type(e).__name__,
            )
        raise  # Re-raise to be handled by the caller

    results_with_scores: List[Tuple[StatementGroup, Dict[str, float]]] = []
    candidate_semantic_similarities: List[float] = []  # For normalization range
    processed_candidates_data: List[
        Dict[str, Any]
    ] = []  # Temp store for data to be scored

    # Iterate over IDs that were confirmed to exist in the DB and match filters
    for sg_id in final_candidate_ids_after_db_match:  # Use keys from sg_objects_map
        sg_obj = sg_objects_map[sg_id]  # We know this exists
        raw_sem_sim = sg_candidates_raw_similarity[
            sg_id
        ]  # This ID came from FAISS initially

        processed_candidates_data.append(
            {
                "sg_obj": sg_obj,
                "raw_sem_sim": raw_sem_sim,
            }
        )
        candidate_semantic_similarities.append(raw_sem_sim)

    if not processed_candidates_data:
        logger.info(
            "No candidates remaining after matching with DB data or other "
            "processing steps."
        )
        if log_searches:
            duration_ms = (time.time() - overall_start_time) * 1000
            log_search_event_to_json(
                status="NO_CANDIDATES_POST_PROCESSING",
                duration_ms=duration_ms,
                results_count=0,
            )
        return []

    # Normalize semantic similarity scores for the retrieved candidates
    min_sem_sim = (
        min(candidate_semantic_similarities) if candidate_semantic_similarities else 0.0
    )
    max_sem_sim = (
        max(candidate_semantic_similarities) if candidate_semantic_similarities else 0.0
    )
    range_sem_sim = max_sem_sim - min_sem_sim
    logger.debug(
        "Raw semantic similarity range for normalization: [%.4f, %.4f]",
        min_sem_sim,
        max_sem_sim,
    )

    for candidate_data in processed_candidates_data:
        sg_obj = candidate_data["sg_obj"]
        current_raw_sem_sim = candidate_data["raw_sem_sim"]

        # Normalize semantic similarity: scale to [0,1]
        norm_sem_sim = 0.5  # Default if range is zero (e.g., only one candidate)
        if range_sem_sim > EPSILON:
            norm_sem_sim = (current_raw_sem_sim - min_sem_sim) / range_sem_sim
        elif (
            len(candidate_semantic_similarities) == 1
            and candidate_semantic_similarities[0] > 0
        ):  # Single candidate
            # If only one candidate, its normalized score should be high if
            # its raw score is non-zero.
            norm_sem_sim = 1.0
        elif (
            len(candidate_semantic_similarities) == 0
        ):  # Should not happen given previous check
            norm_sem_sim = 0.0

        current_scaled_pagerank = (
            sg_obj.scaled_pagerank_score
            if sg_obj.scaled_pagerank_score is not None
            else 0.0
        )

        # Combine scores using weights
        weighted_norm_similarity = text_relevance_weight * norm_sem_sim
        weighted_scaled_pagerank = pagerank_weight * current_scaled_pagerank
        final_score = weighted_norm_similarity + weighted_scaled_pagerank

        score_dict = {
            "final_score": final_score,
            "norm_similarity": norm_sem_sim,
            "scaled_pagerank": current_scaled_pagerank,
            "weighted_norm_similarity": weighted_norm_similarity,
            "weighted_scaled_pagerank": weighted_scaled_pagerank,
            "raw_similarity": current_raw_sem_sim,  # Keep raw similarity for inspection
        }
        results_with_scores.append((sg_obj, score_dict))

    results_with_scores.sort(key=lambda item: item[1]["final_score"], reverse=True)

    final_status = "SUCCESS"
    results_count = len(results_with_scores)
    if (
        not results_with_scores and processed_candidates_data
    ):  # Had candidates, but scoring/sorting yielded none (unlikely)
        final_status = "NO_RESULTS_FINAL_SCORED"
    elif (
        not results_with_scores and not processed_candidates_data
    ):  # No candidates from the start essentially
        # This case should have been caught earlier, but as a safeguard for logging
        if not candidate_sg_ids:
            final_status = "NO_FAISS_CANDIDATES"
        elif not sg_candidates_raw_similarity:
            final_status = "NO_CANDIDATES_POST_THRESHOLD"

    if log_searches:
        duration_ms = (time.time() - overall_start_time) * 1000
        log_search_event_to_json(
            status=final_status, duration_ms=duration_ms, results_count=results_count
        )

    return results_with_scores


# --- Output Formatting ---


def print_results(results: List[Tuple[StatementGroup, Dict[str, float]]]) -> None:
    """Formats and prints the search results to the console.

    Args:
        results: A list of tuples, each containing a StatementGroup
            object and its scores, sorted by final_score.
    """
    if not results:
        print("\nNo results found.")
        return

    print(f"\n--- Top {len(results)} Search Results (StatementGroups) ---")
    for i, (sg_obj, scores) in enumerate(results):
        primary_decl_name = (
            sg_obj.primary_declaration.lean_name
            if sg_obj.primary_declaration and sg_obj.primary_declaration.lean_name
            else "N/A"
        )
        print(
            f"\n{i + 1}. Lean Name: {primary_decl_name} (SG ID: {sg_obj.id})\n"
            f" Final Score: {scores['final_score']:.4f} ("
            f"NormSim*W: {scores['weighted_norm_similarity']:.4f}, "
            f"ScaledPR*W: {scores['weighted_scaled_pagerank']:.4f})"
        )
        print(
            f" Scores: [NormSim: {scores['norm_similarity']:.4f}, "
            f"ScaledPR: {scores['scaled_pagerank']:.4f}, "
            f"RawSim: {scores['raw_similarity']:.4f}]"
        )

        lean_display = (
            sg_obj.display_statement_text or sg_obj.statement_text or "[No Lean code]"
        )
        lean_display_short = (
            (lean_display[:200] + "...") if len(lean_display) > 200 else lean_display
        )
        print(f" Lean Code: {lean_display_short.replace(NEWLINE, ' ')}")

        desc_display = (
            sg_obj.informal_description or sg_obj.docstring or "[No description]"
        )
        desc_display_short = (
            (desc_display[:150] + "...") if len(desc_display) > 150 else desc_display
        )
        print(f" Description: {desc_display_short.replace(NEWLINE, ' ')}")

        source_loc = sg_obj.source_file or "[No source file]"
        if source_loc.startswith("Mathlib/"):  # Simplify Mathlib paths
            source_loc = source_loc[len("Mathlib/") :]
        print(f" File: {source_loc}:{sg_obj.range_start_line}")

    print("\n---------------------------------------------------")


# --- Argument Parsing & Main Execution ---


def parse_arguments() -> argparse.Namespace:
    """Parses command-line arguments for the search script.

    Returns:
        An object containing the parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description="Search Lean StatementGroups using combined scoring.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("query", type=str, help="The search query string.")
    parser.add_argument(
        "--limit",
        "-n",
        type=int,
        default=None,  # Will use DEFAULT_RESULTS_LIMIT from defaults if None
        help="Maximum number of final results to display. Overrides default if set.",
    )
    parser.add_argument(
        "--packages",
        metavar="PKG",
        type=str,
        nargs="*",  # Allows zero or more package names
        default=None,  # No filter if not provided
        help="Filter search results by specific package names (e.g., Mathlib Std). "
        "If not provided, searches all packages.",
    )
    return parser.parse_args()


def main():
    """Main execution function for the search script."""
    args = parse_arguments()

    logger.info(
        "Using default configurations for paths and parameters from "
        "lean_explore.defaults."
    )

    # These now point to the versioned paths, e.g., .../toolchains/0.1.0/file.db
    db_url = defaults.DEFAULT_DB_URL
    embedding_model_name = defaults.DEFAULT_EMBEDDING_MODEL_NAME
    resolved_idx_path = str(defaults.DEFAULT_FAISS_INDEX_PATH.resolve())
    resolved_map_path = str(defaults.DEFAULT_FAISS_MAP_PATH.resolve())

    faiss_k_cand = defaults.DEFAULT_FAISS_K
    pr_weight = defaults.DEFAULT_PAGERANK_WEIGHT
    sem_sim_weight = defaults.DEFAULT_TEXT_RELEVANCE_WEIGHT
    results_disp_limit = (
        args.limit if args.limit is not None else defaults.DEFAULT_RESULTS_LIMIT
    )
    semantic_sim_thresh = defaults.DEFAULT_SEM_SIM_THRESHOLD
    faiss_nprobe_val = defaults.DEFAULT_FAISS_NPROBE

    db_url_display = (
        f"...{str(defaults.DEFAULT_DB_PATH.resolve())[-30:]}"
        if len(str(defaults.DEFAULT_DB_PATH.resolve())) > 30
        else str(defaults.DEFAULT_DB_PATH.resolve())
    )
    logger.info("--- Starting Search (Direct Script Execution) ---")
    logger.info("Query: '%s'", args.query)
    logger.info("Displaying Top: %d results", results_disp_limit)
    if args.packages:
        logger.info("Filtering by user-specified packages: %s", args.packages)
    else:
        logger.info("No package filter specified, searching all packages.")
    logger.info("FAISS k (candidates): %d", faiss_k_cand)
    logger.info("FAISS nprobe (from defaults): %d", faiss_nprobe_val)
    logger.info(
        "Semantic Similarity Threshold (from defaults): %.3f", semantic_sim_thresh
    )
    logger.info(
        "Weights -> NormTextSim: %.2f, ScaledPR: %.2f",
        sem_sim_weight,
        pr_weight,
    )
    logger.info("Using FAISS index: %s", resolved_idx_path)
    logger.info("Using ID map: %s", resolved_map_path)
    logger.info(
        "Database path: %s", db_url_display
    )  # Changed from URL for clarity with file paths

    # Ensure user data directory and toolchain directory exist for logs etc.
    # The fetch command handles creation of the specific toolchain version dir.
    # Here, we ensure the base log directory can be created by performance logger.
    try:
        _USER_LOGS_BASE_DIR.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        logger.warning(
            f"Could not create user log directory {_USER_LOGS_BASE_DIR}: {e}"
        )

    engine = None
    try:
        # Asset loading with improved error potential
        s_transformer_model = load_embedding_model(embedding_model_name)
        if s_transformer_model is None:
            # load_embedding_model already logs the error
            logger.error(
                "Sentence transformer model loading failed. Cannot proceed with search."
            )
            sys.exit(1)

        faiss_idx, id_map = load_faiss_assets(resolved_idx_path, resolved_map_path)
        if faiss_idx is None or id_map is None:
            # load_faiss_assets already logs details
            logger.error(
                "Failed to load critical FAISS assets (index or ID map).\n"
                f"Expected at:\n Index path: {resolved_idx_path}\n"
                f" ID map path: {resolved_map_path}\n"
                "Please ensure these files exist or run 'leanexplore data fetch' "
                "to download the data toolchain."
            )
            sys.exit(1)

        # Database connection
        # Check for DB file existence before creating engine if it's a
        # file-based SQLite DB
        is_file_db = db_url.startswith("sqlite:///")
        db_file_path = None
        if is_file_db:
            # Extract file path from sqlite:/// URL
            db_file_path_str = db_url[len("sqlite:///") :]
            db_file_path = pathlib.Path(db_file_path_str)
            if not db_file_path.exists():
                logger.error(
                    f"Database file not found at the expected location: "
                    f"{db_file_path}\n"
                    "Please run 'leanexplore data fetch' to download the data "
                    "toolchain."
                )
                sys.exit(1)

        engine = create_engine(db_url, echo=False)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

        with SessionLocal() as session:
            ranked_results = perform_search(
                session=session,
                query_string=args.query,
                model=s_transformer_model,
                faiss_index=faiss_idx,
                text_chunk_id_map=id_map,
                faiss_k=faiss_k_cand,
                pagerank_weight=pr_weight,
                text_relevance_weight=sem_sim_weight,
                log_searches=True,
                selected_packages=args.packages,
                semantic_similarity_threshold=semantic_sim_thresh,  # from defaults
                faiss_nprobe=faiss_nprobe_val,  # from defaults
            )

            print_results(ranked_results[:results_disp_limit])

    except FileNotFoundError as e:  # Should be less common now with explicit checks
        logger.error(
            f"A required file was not found: {e.filename}.\n"
            "This could be an issue with configured paths or missing data.\n"
            "If this relates to core data assets, please try running "
            "'leanexplore data fetch'."
        )
        sys.exit(1)
    except OperationalError as e_db:
        is_file_db_op_err = defaults.DEFAULT_DB_URL.startswith("sqlite:///")
        db_file_path_op_err = defaults.DEFAULT_DB_PATH
        if is_file_db_op_err and (
            "unable to open database file" in str(e_db).lower()
            or (db_file_path_op_err and not db_file_path_op_err.exists())
        ):
            p = str(db_file_path_op_err.resolve())
            logger.error(
                f"Database connection failed: {e_db}\n"
                f"The database file appears to be missing or inaccessible at: "
                f"{p if db_file_path_op_err else 'Unknown Path'}\n"
                "Please run 'leanexplore data fetch' to download or update the "
                "data toolchain."
            )
        else:
            logger.error(
                f"Database connection/operational error: {e_db}", exc_info=True
            )
        sys.exit(1)
    except SQLAlchemyError as e_sqla:  # Catch other SQLAlchemy errors
        logger.error(
            "A database error occurred during search: %s", e_sqla, exc_info=True
        )
        sys.exit(1)
    except Exception as e_general:  # Catch-all for other unexpected critical errors
        logger.critical(
            "An unexpected critical error occurred during search: %s",
            e_general,
            exc_info=True,
        )
        sys.exit(1)
    finally:
        if engine:
            engine.dispose()
            logger.debug("Database engine disposed.")


if __name__ == "__main__":
    main()
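The listing above is the complete lean_explore/local/search.py shipped in the wheel. For illustration only (this sketch is not part of the package contents), the same pieces can be wired together programmatically the way main() does. It assumes the data toolchain has already been downloaded with 'leanexplore data fetch' so that the paths and constants in lean_explore.defaults resolve; the query string and the Mathlib package filter are arbitrary example values.

# Illustrative usage sketch -- not included in the lean-explore 0.1.1 wheel.
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from lean_explore import defaults
from lean_explore.local.search import (
    load_embedding_model,
    load_faiss_assets,
    perform_search,
    print_results,
)

# Load the embedding model, FAISS index, and chunk-ID map from the default paths.
model = load_embedding_model(defaults.DEFAULT_EMBEDDING_MODEL_NAME)
faiss_index, id_map = load_faiss_assets(
    str(defaults.DEFAULT_FAISS_INDEX_PATH), str(defaults.DEFAULT_FAISS_MAP_PATH)
)
if model is None or faiss_index is None or id_map is None:
    raise RuntimeError("Search assets missing; run 'leanexplore data fetch' first.")

# Open the bundled SQLite database and run one ranked search.
engine = create_engine(defaults.DEFAULT_DB_URL)
SessionLocal = sessionmaker(bind=engine)
with SessionLocal() as session:
    results = perform_search(
        session=session,
        query_string="continuous functions on compact sets are bounded",
        model=model,
        faiss_index=faiss_index,
        text_chunk_id_map=id_map,
        faiss_k=defaults.DEFAULT_FAISS_K,
        pagerank_weight=defaults.DEFAULT_PAGERANK_WEIGHT,
        text_relevance_weight=defaults.DEFAULT_TEXT_RELEVANCE_WEIGHT,
        log_searches=False,
        selected_packages=["Mathlib"],
    )
    print_results(results[: defaults.DEFAULT_RESULTS_LIMIT])
engine.dispose()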