lean-explore 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,921 @@
1
+ # src/lean_explore/local/search.py
2
+
3
+ """Performs semantic search and ranked retrieval of StatementGroups.
4
+
5
+ Combines semantic similarity from FAISS and pre-scaled PageRank scores
6
+ to rank StatementGroups. It loads necessary assets (embedding model,
7
+ FAISS index, ID map) using default configurations, embeds the user query,
8
+ performs FAISS search, filters based on a similarity threshold,
9
+ retrieves group details from the database, normalizes semantic similarity scores,
10
+ and then combines these scores using configurable weights to produce a final
11
+ ranked list. It also logs search performance statistics to a dedicated
12
+ JSONL file.
13
+ """
14
+
15
+ import argparse
16
+ import datetime
17
+ import json
18
+ import logging
19
+ import os
20
+ import pathlib
21
+ import sys
22
+ import time
23
+ from typing import Any, Dict, List, Optional, Tuple
24
+
25
+ from filelock import FileLock, Timeout
26
+
27
+ # --- Dependency Imports ---
28
+ try:
29
+ import faiss
30
+ import numpy as np
31
+ from sentence_transformers import SentenceTransformer
32
+ from sqlalchemy import create_engine, or_, select
33
+ from sqlalchemy.exc import OperationalError, SQLAlchemyError
34
+ from sqlalchemy.orm import Session, joinedload, sessionmaker
35
+ except ImportError as e:
36
+ # pylint: disable=broad-exception-raised
37
+ print(
38
+ f"Error: Missing required libraries ({e}).\n"
39
+ "Please install them: pip install SQLAlchemy faiss-cpu "
40
+ "sentence-transformers numpy filelock rapidfuzz",
41
+ file=sys.stderr,
42
+ )
43
+ sys.exit(1)
44
+
45
+ # --- Project Model & Default Config Imports ---
46
+ try:
47
+ from lean_explore import defaults # Using the new defaults module
48
+ from lean_explore.shared.models.db import StatementGroup
49
+ except ImportError as e:
50
+ # pylint: disable=broad-exception-raised
51
+ print(
52
+ f"Error: Could not import project modules (StatementGroup, defaults): {e}\n"
53
+ "Ensure 'lean_explore' is installed (e.g., 'pip install -e .') "
54
+ "and all dependencies are met.",
55
+ file=sys.stderr,
56
+ )
57
+ sys.exit(1)
58
+
59
+
60
+ # --- Logging Setup ---
61
+ logging.basicConfig(
62
+ level=logging.INFO,
63
+ format="%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s",
64
+ datefmt="%Y-%m-%d %H:%M:%S",
65
+ )
66
+ logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
67
+ logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
68
+ logger = logging.getLogger(__name__)
69
+
70
+ # --- Constants ---
71
+ NEWLINE = os.linesep
72
+ EPSILON = 1e-9
73
+ # PROJECT_ROOT might be less relevant for asset paths if defaults.py
74
+ # provides absolute paths
75
+ PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
76
+
77
+ # --- Performance Logging Path Setup ---
78
+ # Logs will be stored in a user-writable directory, e.g., ~/.lean_explore/logs/
79
+ # defaults.LEAN_EXPLORE_USER_DATA_DIR is ~/.lean_explore/data/
80
+ # So, its parent is ~/.lean_explore/
81
+ _USER_LOGS_BASE_DIR = defaults.LEAN_EXPLORE_USER_DATA_DIR.parent / "logs"
82
+ PERFORMANCE_LOG_DIR = str(_USER_LOGS_BASE_DIR)
83
+ PERFORMANCE_LOG_FILENAME = "search_stats.jsonl"
84
+ PERFORMANCE_LOG_PATH = os.path.join(PERFORMANCE_LOG_DIR, PERFORMANCE_LOG_FILENAME)
85
+ LOCK_PATH = os.path.join(PERFORMANCE_LOG_DIR, f"{PERFORMANCE_LOG_FILENAME}.lock")
86
+
87
+
88
+ # --- Performance Logging Helper ---
89
+
90
+
91
+ def log_search_event_to_json(
92
+ status: str,
93
+ duration_ms: float,
94
+ results_count: int,
95
+ error_type: Optional[str] = None,
96
+ ) -> None:
97
+ """Logs a search event as a JSON line to a dedicated performance log file.
98
+
99
+ Args:
100
+ status: A string code indicating the outcome of the search.
101
+ duration_ms: The total duration of the search processing in milliseconds.
102
+ results_count: The number of search results returned.
103
+ error_type: Optional. The type of error if the status indicates an error.
104
+ """
105
+ log_entry = {
106
+ "timestamp": datetime.datetime.utcnow().isoformat() + "Z",
107
+ "event": "search_processed",
108
+ "status": status,
109
+ "duration_ms": round(duration_ms, 2),
110
+ "results_count": results_count,
111
+ }
112
+ if error_type:
113
+ log_entry["error_type"] = error_type
114
+
115
+ try:
116
+ os.makedirs(PERFORMANCE_LOG_DIR, exist_ok=True)
117
+ except OSError as e:
118
+ # This error is critical for logging but should not stop main search flow.
119
+ # The fallback print helps retain info if file logging fails.
120
+ logger.error(
121
+ "Performance logging error: Could not create log directory %s: %s. "
122
+ "Log entry: %s",
123
+ PERFORMANCE_LOG_DIR,
124
+ e,
125
+ log_entry,
126
+ exc_info=False, # Keep exc_info False to avoid spamming user console
127
+ )
128
+ print(
129
+ f"FALLBACK_PERF_LOG (DIR_ERROR): {json.dumps(log_entry)}", file=sys.stderr
130
+ )
131
+ return
132
+
133
+ lock = FileLock(LOCK_PATH, timeout=2)
134
+ try:
135
+ with lock:
136
+ with open(PERFORMANCE_LOG_PATH, "a", encoding="utf-8") as f:
137
+ f.write(json.dumps(log_entry) + "\n")
138
+ except Timeout:
139
+ logger.warning(
140
+ "Performance logging error: Timeout acquiring lock for %s. "
141
+ "Log entry lost: %s",
142
+ LOCK_PATH,
143
+ log_entry,
144
+ )
145
+ print(
146
+ f"FALLBACK_PERF_LOG (LOCK_TIMEOUT): {json.dumps(log_entry)}",
147
+ file=sys.stderr,
148
+ )
149
+ except Exception as e:
150
+ logger.error( # Keep as error for unexpected write issues
151
+ "Performance logging error: Failed to write to %s: %s. Log entry: %s",
152
+ PERFORMANCE_LOG_PATH,
153
+ e,
154
+ log_entry,
155
+ exc_info=False,
156
+ )
157
+ print(
158
+ f"FALLBACK_PERF_LOG (WRITE_ERROR): {json.dumps(log_entry)}", file=sys.stderr
159
+ )
160
+
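+ # When logging is enabled, each processed search appends one JSON object per
+ # line to search_stats.jsonl. An illustrative (made-up) entry:
+ # {"timestamp": "2025-01-01T12:00:00.000000Z", "event": "search_processed",
+ #  "status": "SUCCESS", "duration_ms": 123.45, "results_count": 10}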
161
+
162
+ # --- Asset Loading Functions ---
163
+ def load_faiss_assets(
164
+ index_path_str: str, map_path_str: str
165
+ ) -> Tuple[Optional[faiss.Index], Optional[List[str]]]:
166
+ """Loads the FAISS index and ID map from specified file paths.
167
+
168
+ Args:
169
+ index_path_str: String path to the FAISS index file.
170
+ map_path_str: String path to the JSON ID map file.
171
+
172
+ Returns:
173
+ A tuple (faiss.Index or None, list_of_IDs or None).
174
+ """
175
+ index_path = pathlib.Path(index_path_str).resolve()
176
+ map_path = pathlib.Path(map_path_str).resolve()
177
+
178
+ if not index_path.exists():
179
+ logger.error("FAISS index file not found: %s", index_path)
180
+ return None, None
181
+ if not map_path.exists():
182
+ logger.error("FAISS ID map file not found: %s", map_path)
183
+ return None, None
184
+
185
+ faiss_index_obj: Optional[faiss.Index] = None
186
+ id_map_list: Optional[List[str]] = None
187
+
188
+ try:
189
+ logger.info("Loading FAISS index from %s...", index_path)
190
+ faiss_index_obj = faiss.read_index(str(index_path))
191
+ logger.info(
192
+ "Loaded FAISS index with %d vectors (Metric Type: %s).",
193
+ faiss_index_obj.ntotal,
194
+ faiss_index_obj.metric_type,
195
+ )
196
+ except Exception as e:
197
+ logger.error(
198
+ "Failed to load FAISS index from %s: %s", index_path, e, exc_info=True
199
+ )
200
+ return None, id_map_list # Return None for index if loading failed
201
+
202
+ try:
203
+ logger.info("Loading ID map from %s...", map_path)
204
+ with open(map_path, encoding="utf-8") as f:
205
+ id_map_list = json.load(f)
206
+ if not isinstance(id_map_list, list):
207
+ logger.error(
208
+ "ID map file (%s) does not contain a valid JSON list.", map_path
209
+ )
210
+ return faiss_index_obj, None # Return None for map if parsing failed
211
+ logger.info("Loaded ID map with %d entries.", len(id_map_list))
212
+ except Exception as e:
213
+ logger.error(
214
+ "Failed to load or parse ID map file %s: %s", map_path, e, exc_info=True
215
+ )
216
+ return faiss_index_obj, None # Return None for map if loading/parsing failed
217
+
218
+ if (
219
+ faiss_index_obj is not None
220
+ and id_map_list is not None
221
+ and faiss_index_obj.ntotal != len(id_map_list)
222
+ ):
223
+ logger.warning(
224
+ "Mismatch: FAISS index size (%d) vs ID map size (%d). "
225
+ "Results may be inconsistent.",
226
+ faiss_index_obj.ntotal,
227
+ len(id_map_list),
228
+ )
229
+ return faiss_index_obj, id_map_list
230
+
231
+
232
+ def load_embedding_model(model_name: str) -> Optional[SentenceTransformer]:
233
+ """Loads the specified Sentence Transformer model.
234
+
235
+ Args:
236
+ model_name: The name or path of the sentence-transformer model.
237
+
238
+ Returns:
239
+ The loaded model, or None if loading fails.
240
+ """
241
+ logger.info("Loading sentence transformer model '%s'...", model_name)
242
+ try:
243
+ model = SentenceTransformer(model_name)
244
+ logger.info(
245
+ "Model '%s' loaded successfully. Max sequence length: %d.",
246
+ model_name,
247
+ model.max_seq_length,
248
+ )
249
+ return model
250
+ except Exception as e: # Broad exception for any model loading issue
251
+ logger.error(
252
+ "Failed to load sentence transformer model '%s': %s",
253
+ model_name,
254
+ e,
255
+ exc_info=True,
256
+ )
257
+ return None
258
+
259
+
260
+ # --- Main Search Function ---
261
+
262
+
263
+ def perform_search(
264
+ session: Session,
265
+ query_string: str,
266
+ model: SentenceTransformer,
267
+ faiss_index: faiss.Index,
268
+ text_chunk_id_map: List[str],
269
+ faiss_k: int,
270
+ pagerank_weight: float,
271
+ text_relevance_weight: float,
272
+ log_searches: bool, # Added parameter
273
+ selected_packages: Optional[List[str]] = None,
274
+ semantic_similarity_threshold: float = defaults.DEFAULT_SEM_SIM_THRESHOLD,
275
+ faiss_nprobe: int = defaults.DEFAULT_FAISS_NPROBE,
276
+ ) -> List[Tuple[StatementGroup, Dict[str, float]]]:
277
+ """Performs semantic search and ranking.
278
+
279
+ Args:
280
+ session: SQLAlchemy session for database access.
281
+ query_string: The user's search query string.
282
+ model: The loaded SentenceTransformer embedding model.
283
+ faiss_index: The loaded FAISS index for text chunks.
284
+ text_chunk_id_map: A list mapping FAISS internal indices to text chunk IDs.
285
+ faiss_k: The number of nearest neighbors to retrieve from FAISS.
286
+ pagerank_weight: Weight for the pre-scaled PageRank score.
287
+ text_relevance_weight: Weight for the normalized semantic similarity score.
288
+ log_searches: If True, search performance data will be logged.
289
+ selected_packages: Optional list of package names to filter search by.
290
+ semantic_similarity_threshold: Minimum similarity for a result to be considered.
291
+ faiss_nprobe: Number of closest cells/clusters to search for IVF-type FAISS
292
+ indexes.
293
+
294
+ Returns:
295
+ A list of tuples, sorted by final_score, containing a
296
+ `StatementGroup` object and its scores.
297
+
298
+ Raises:
299
+ Exception: If critical errors like query embedding or FAISS search fail.
300
+ """
301
+ overall_start_time = time.time()
302
+
303
+ logger.info("Search request event initiated.")
304
+ if semantic_similarity_threshold > 0.0 + EPSILON:
305
+ logger.info(
306
+ "Applying semantic similarity threshold: %.3f",
307
+ semantic_similarity_threshold,
308
+ )
309
+
310
+ if not query_string.strip():
311
+ logger.warning("Empty query provided. Returning no results.")
312
+ if log_searches:
313
+ duration_ms = (time.time() - overall_start_time) * 1000
314
+ log_search_event_to_json(
315
+ status="EMPTY_QUERY_SUBMITTED", duration_ms=duration_ms, results_count=0
316
+ )
317
+ return []
318
+
319
+ try:
320
+ query_embedding = model.encode([query_string.strip()], convert_to_numpy=True)[
321
+ 0
322
+ ].astype(np.float32)
323
+ query_embedding_reshaped = np.expand_dims(query_embedding, axis=0)
324
+ if faiss_index.metric_type == faiss.METRIC_INNER_PRODUCT:
325
+ logger.debug(
326
+ "Normalizing query embedding for Inner Product (cosine) search."
327
+ )
328
+ faiss.normalize_L2(query_embedding_reshaped)
329
+ except Exception as e:
330
+ logger.error("Failed to embed query: %s", e, exc_info=True)
331
+ if log_searches:
332
+ duration_ms = (time.time() - overall_start_time) * 1000
333
+ log_search_event_to_json(
334
+ status="EMBEDDING_ERROR",
335
+ duration_ms=duration_ms,
336
+ results_count=0,
337
+ error_type=type(e).__name__,
338
+ )
339
+ raise Exception(f"Query embedding failed: {e}") from e
340
+
341
+ try:
342
+ logger.debug(
343
+ "Searching FAISS index for top %d text chunk neighbors...", faiss_k
344
+ )
345
+ if hasattr(faiss_index, "nprobe") and isinstance(
346
+ faiss_index.nprobe, int
347
+ ): # Check if index is IVF
348
+ if faiss_nprobe > 0:
349
+ faiss_index.nprobe = faiss_nprobe
350
+ logger.debug(f"Set FAISS nprobe to: {faiss_index.nprobe}")
351
+ else: # faiss_nprobe from config is invalid
352
+ logger.warning(
353
+ f"Configured faiss_nprobe is {faiss_nprobe}. Must be > 0. "
354
+ "Using FAISS default or previously set nprobe for this IVF index."
355
+ )
356
+ distances, indices = faiss_index.search(query_embedding_reshaped, faiss_k)
357
+ except Exception as e:
358
+ logger.error("FAISS search failed: %s", e, exc_info=True)
359
+ if log_searches:
360
+ duration_ms = (time.time() - overall_start_time) * 1000
361
+ log_search_event_to_json(
362
+ status="FAISS_SEARCH_ERROR",
363
+ duration_ms=duration_ms,
364
+ results_count=0,
365
+ error_type=type(e).__name__,
366
+ )
367
+ raise Exception(f"FAISS search failed: {e}") from e
368
+
369
+ sg_candidates_raw_similarity: Dict[int, float] = {}
370
+ if indices.size > 0 and distances.size > 0:
371
+ for i, faiss_internal_idx in enumerate(indices[0]):
372
+ if faiss_internal_idx == -1: # FAISS can return -1 for no neighbor
373
+ continue
374
+ try:
375
+ text_chunk_id_str = text_chunk_id_map[faiss_internal_idx]
376
+ raw_faiss_score = distances[0][i]
377
+ similarity_score: float
378
+
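+ # FAISS returns squared L2 distances for METRIC_L2 indexes, hence the
+ # square root below; for inner-product indexes over L2-normalized vectors
+ # the raw score is already a cosine similarity.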
379
+ if faiss_index.metric_type == faiss.METRIC_L2:
380
+ similarity_score = 1.0 / (1.0 + np.sqrt(max(0, raw_faiss_score)))
381
+ elif faiss_index.metric_type == faiss.METRIC_INNER_PRODUCT:
382
+ # Assuming normalized vectors, inner product is cosine similarity
383
+ similarity_score = raw_faiss_score
384
+ else: # Default or unknown metric, treat score as distance-like
385
+ similarity_score = 1.0 / (1.0 + max(0, raw_faiss_score))
386
+ logger.warning(
387
+ "Unhandled FAISS metric type %d for text chunk. "
388
+ "Using 1/(1+score) for similarity.",
389
+ faiss_index.metric_type,
390
+ )
391
+ similarity_score = max(
392
+ 0.0, min(1.0, similarity_score)
393
+ ) # Clamp to [0,1]
394
+
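+ # Text chunk IDs are expected to start with "sg_<statement_group_id>";
+ # anything that does not parse that way is logged and skipped.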
395
+ parts = text_chunk_id_str.split("_")
396
+ if len(parts) >= 2 and parts[0] == "sg":
397
+ try:
398
+ sg_id = int(parts[1])
399
+ # If multiple chunks from the same StatementGroup are retrieved,
400
+ # keep the one with the highest similarity to the query.
401
+ if (
402
+ sg_id not in sg_candidates_raw_similarity
403
+ or similarity_score > sg_candidates_raw_similarity[sg_id]
404
+ ):
405
+ sg_candidates_raw_similarity[sg_id] = similarity_score
406
+ except ValueError:
407
+ logger.warning(
408
+ "Could not parse StatementGroup ID from chunk_id: %s",
409
+ text_chunk_id_str,
410
+ )
411
+ else:
412
+ logger.warning(
413
+ "Malformed text_chunk_id format: %s", text_chunk_id_str
414
+ )
415
+ except IndexError:
416
+ logger.warning(
417
+ "FAISS internal index %d out of bounds for ID map (size %d). "
418
+ "Possible data inconsistency.",
419
+ faiss_internal_idx,
420
+ len(text_chunk_id_map),
421
+ )
422
+ except (
423
+ Exception
424
+ ) as e: # Catch any other unexpected errors during result processing
425
+ logger.warning(
426
+ "Error processing FAISS result for internal index %d "
427
+ "(chunk_id '%s'): %s",
428
+ faiss_internal_idx,
429
+ text_chunk_id_str if "text_chunk_id_str" in locals() else "N/A",
430
+ e,
431
+ )
432
+
433
+ if not sg_candidates_raw_similarity:
434
+ logger.info(
435
+ "No valid StatementGroup candidates found after FAISS search and parsing."
436
+ )
437
+ if log_searches:
438
+ duration_ms = (time.time() - overall_start_time) * 1000
439
+ log_search_event_to_json(
440
+ status="NO_FAISS_CANDIDATES", duration_ms=duration_ms, results_count=0
441
+ )
442
+ return []
443
+ logger.info(
444
+ "Aggregated %d unique StatementGroup candidates from FAISS results.",
445
+ len(sg_candidates_raw_similarity),
446
+ )
447
+
448
+ if semantic_similarity_threshold > 0.0 + EPSILON:
449
+ initial_candidate_count = len(sg_candidates_raw_similarity)
450
+ sg_candidates_raw_similarity = {
451
+ sg_id: sim
452
+ for sg_id, sim in sg_candidates_raw_similarity.items()
453
+ if sim >= semantic_similarity_threshold
454
+ }
455
+ logger.info(
456
+ "Post-thresholding: %d of %d candidates remaining (threshold: %.3f).",
457
+ len(sg_candidates_raw_similarity),
458
+ initial_candidate_count,
459
+ semantic_similarity_threshold,
460
+ )
461
+
462
+ if not sg_candidates_raw_similarity:
463
+ logger.info(
464
+ "No StatementGroup candidates met the semantic similarity "
465
+ "threshold of %.3f.",
466
+ semantic_similarity_threshold,
467
+ )
468
+ if log_searches:
469
+ duration_ms = (time.time() - overall_start_time) * 1000
470
+ log_search_event_to_json(
471
+ status="NO_CANDIDATES_POST_THRESHOLD",
472
+ duration_ms=duration_ms,
473
+ results_count=0,
474
+ )
475
+ return []
476
+
477
+ candidate_sg_ids = list(sg_candidates_raw_similarity.keys())
478
+ sg_objects_map: Dict[int, StatementGroup] = {}
479
+ try:
480
+ logger.debug(
481
+ "Fetching StatementGroup details from DB for %d IDs...",
482
+ len(candidate_sg_ids),
483
+ )
484
+ stmt = select(StatementGroup).where(StatementGroup.id.in_(candidate_sg_ids))
485
+
486
+ if selected_packages:
487
+ logger.info("Filtering search by packages: %s", selected_packages)
488
+ package_filters_sqla = []
489
+ # Assuming package names in selected_packages are like "Mathlib", "Std"
490
+ # And source_file in DB is like
491
+ # "Mathlib/CategoryTheory/Adjunction/Basic.lean"
492
+ for pkg_name in selected_packages:
493
+ # Ensure exact package match at the start of the file path
494
+ # component
495
+ package_filters_sqla.append(
496
+ StatementGroup.source_file.startswith(pkg_name + "/")
497
+ )
498
+
499
+ if package_filters_sqla:
500
+ stmt = stmt.where(or_(*package_filters_sqla))
501
+
502
+ # Eagerly load primary_declaration to avoid N+1 queries later if
503
+ # accessing lean_name
504
+ stmt = stmt.options(joinedload(StatementGroup.primary_declaration))
505
+ db_results = session.execute(stmt).scalars().unique().all()
506
+ for sg_obj in db_results:
507
+ sg_objects_map[sg_obj.id] = sg_obj
508
+
509
+ logger.debug(
510
+ "Fetched details for %d StatementGroups from DB that matched filters.",
511
+ len(sg_objects_map),
512
+ )
513
+ # Log if some IDs from FAISS (post-threshold and package filter if
514
+ # applied) were not found in DB. This check is more informative if
515
+ # done *after* any package filtering logic in the query
516
+ final_candidate_ids_after_db_match = set(sg_objects_map.keys())
517
+ original_faiss_candidate_ids = set(candidate_sg_ids)
518
+
519
+ if len(final_candidate_ids_after_db_match) < len(original_faiss_candidate_ids):
520
+ missing_from_db_or_filtered_out = (
521
+ original_faiss_candidate_ids - final_candidate_ids_after_db_match
522
+ )
523
+ logger.info(
524
+ "%d candidates from FAISS (post-threshold) were not found in DB "
525
+ "or excluded by package filters: (e.g., %s).",
526
+ len(missing_from_db_or_filtered_out),
527
+ list(missing_from_db_or_filtered_out)[:5],
528
+ )
529
+
530
+ except SQLAlchemyError as e:
531
+ logger.error(
532
+ "Database query for StatementGroup details failed: %s", e, exc_info=True
533
+ )
534
+ if log_searches:
535
+ duration_ms = (time.time() - overall_start_time) * 1000
536
+ log_search_event_to_json(
537
+ status="DB_FETCH_ERROR",
538
+ duration_ms=duration_ms,
539
+ results_count=0,
540
+ error_type=type(e).__name__,
541
+ )
542
+ raise # Re-raise to be handled by the caller
543
+
544
+ results_with_scores: List[Tuple[StatementGroup, Dict[str, float]]] = []
545
+ candidate_semantic_similarities: List[float] = [] # For normalization range
546
+ processed_candidates_data: List[
547
+ Dict[str, Any]
548
+ ] = [] # Temp store for data to be scored
549
+
550
+ # Iterate over IDs that were confirmed to exist in the DB and match filters
551
+ for sg_id in final_candidate_ids_after_db_match: # Use keys from sg_objects_map
552
+ sg_obj = sg_objects_map[sg_id] # We know this exists
553
+ raw_sem_sim = sg_candidates_raw_similarity[
554
+ sg_id
555
+ ] # This ID came from FAISS initially
556
+
557
+ processed_candidates_data.append(
558
+ {
559
+ "sg_obj": sg_obj,
560
+ "raw_sem_sim": raw_sem_sim,
561
+ }
562
+ )
563
+ candidate_semantic_similarities.append(raw_sem_sim)
564
+
565
+ if not processed_candidates_data:
566
+ logger.info(
567
+ "No candidates remaining after matching with DB data or other "
568
+ "processing steps."
569
+ )
570
+ if log_searches:
571
+ duration_ms = (time.time() - overall_start_time) * 1000
572
+ log_search_event_to_json(
573
+ status="NO_CANDIDATES_POST_PROCESSING",
574
+ duration_ms=duration_ms,
575
+ results_count=0,
576
+ )
577
+ return []
578
+
579
+ # Normalize semantic similarity scores for the retrieved candidates
580
+ min_sem_sim = (
581
+ min(candidate_semantic_similarities) if candidate_semantic_similarities else 0.0
582
+ )
583
+ max_sem_sim = (
584
+ max(candidate_semantic_similarities) if candidate_semantic_similarities else 0.0
585
+ )
586
+ range_sem_sim = max_sem_sim - min_sem_sim
587
+ logger.debug(
588
+ "Raw semantic similarity range for normalization: [%.4f, %.4f]",
589
+ min_sem_sim,
590
+ max_sem_sim,
591
+ )
592
+
593
+ for candidate_data in processed_candidates_data:
594
+ sg_obj = candidate_data["sg_obj"]
595
+ current_raw_sem_sim = candidate_data["raw_sem_sim"]
596
+
597
+ # Normalize semantic similarity: scale to [0,1]
598
+ norm_sem_sim = 0.5 # Default if range is zero (e.g., only one candidate)
599
+ if range_sem_sim > EPSILON:
600
+ norm_sem_sim = (current_raw_sem_sim - min_sem_sim) / range_sem_sim
601
+ elif (
602
+ len(candidate_semantic_similarities) == 1
603
+ and candidate_semantic_similarities[0] > 0
604
+ ): # Single candidate
605
+ # If only one candidate, its normalized score should be high if
606
+ # its raw score is non-zero.
607
+ norm_sem_sim = 1.0
608
+ elif (
609
+ len(candidate_semantic_similarities) == 0
610
+ ): # Should not happen given previous check
611
+ norm_sem_sim = 0.0
612
+
613
+ current_scaled_pagerank = (
614
+ sg_obj.scaled_pagerank_score
615
+ if sg_obj.scaled_pagerank_score is not None
616
+ else 0.0
617
+ )
618
+
619
+ # Combine scores using weights
620
+ weighted_norm_similarity = text_relevance_weight * norm_sem_sim
621
+ weighted_scaled_pagerank = pagerank_weight * current_scaled_pagerank
622
+ final_score = weighted_norm_similarity + weighted_scaled_pagerank
623
+
624
+ score_dict = {
625
+ "final_score": final_score,
626
+ "norm_similarity": norm_sem_sim,
627
+ "scaled_pagerank": current_scaled_pagerank,
628
+ "weighted_norm_similarity": weighted_norm_similarity,
629
+ "weighted_scaled_pagerank": weighted_scaled_pagerank,
630
+ "raw_similarity": current_raw_sem_sim, # Keep raw similarity for inspection
631
+ }
632
+ results_with_scores.append((sg_obj, score_dict))
633
+
634
+ results_with_scores.sort(key=lambda item: item[1]["final_score"], reverse=True)
635
+
636
+ final_status = "SUCCESS"
637
+ results_count = len(results_with_scores)
638
+ if (
639
+ not results_with_scores and processed_candidates_data
640
+ ): # Had candidates, but scoring/sorting yielded none (unlikely)
641
+ final_status = "NO_RESULTS_FINAL_SCORED"
642
+ elif (
643
+ not results_with_scores and not processed_candidates_data
644
+ ): # No candidates from the start essentially
645
+ # This case should have been caught earlier, but as a safeguard for logging
646
+ if not candidate_sg_ids:
647
+ final_status = "NO_FAISS_CANDIDATES"
648
+ elif not sg_candidates_raw_similarity:
649
+ final_status = "NO_CANDIDATES_POST_THRESHOLD"
650
+
651
+ if log_searches:
652
+ duration_ms = (time.time() - overall_start_time) * 1000
653
+ log_search_event_to_json(
654
+ status=final_status, duration_ms=duration_ms, results_count=results_count
655
+ )
656
+
657
+ return results_with_scores
658
+
659
+
660
+ # --- Output Formatting ---
661
+
662
+
663
+ def print_results(results: List[Tuple[StatementGroup, Dict[str, float]]]) -> None:
664
+ """Formats and prints the search results to the console.
665
+
666
+ Args:
667
+ results: A list of tuples, each containing a StatementGroup
668
+ object and its scores, sorted by final_score.
669
+ """
670
+ if not results:
671
+ print("\nNo results found.")
672
+ return
673
+
674
+ print(f"\n--- Top {len(results)} Search Results (StatementGroups) ---")
675
+ for i, (sg_obj, scores) in enumerate(results):
676
+ primary_decl_name = (
677
+ sg_obj.primary_declaration.lean_name
678
+ if sg_obj.primary_declaration and sg_obj.primary_declaration.lean_name
679
+ else "N/A"
680
+ )
681
+ print(
682
+ f"\n{i + 1}. Lean Name: {primary_decl_name} (SG ID: {sg_obj.id})\n"
683
+ f" Final Score: {scores['final_score']:.4f} ("
684
+ f"NormSim*W: {scores['weighted_norm_similarity']:.4f}, "
685
+ f"ScaledPR*W: {scores['weighted_scaled_pagerank']:.4f})"
686
+ )
687
+ print(
688
+ f" Scores: [NormSim: {scores['norm_similarity']:.4f}, "
689
+ f"ScaledPR: {scores['scaled_pagerank']:.4f}, "
690
+ f"RawSim: {scores['raw_similarity']:.4f}]"
691
+ )
692
+
693
+ lean_display = (
694
+ sg_obj.display_statement_text or sg_obj.statement_text or "[No Lean code]"
695
+ )
696
+ lean_display_short = (
697
+ (lean_display[:200] + "...") if len(lean_display) > 200 else lean_display
698
+ )
699
+ print(f" Lean Code: {lean_display_short.replace(NEWLINE, ' ')}")
700
+
701
+ desc_display = (
702
+ sg_obj.informal_description or sg_obj.docstring or "[No description]"
703
+ )
704
+ desc_display_short = (
705
+ (desc_display[:150] + "...") if len(desc_display) > 150 else desc_display
706
+ )
707
+ print(f" Description: {desc_display_short.replace(NEWLINE, ' ')}")
708
+
709
+ source_loc = sg_obj.source_file or "[No source file]"
710
+ if source_loc.startswith("Mathlib/"): # Simplify Mathlib paths
711
+ source_loc = source_loc[len("Mathlib/") :]
712
+ print(f" File: {source_loc}:{sg_obj.range_start_line}")
713
+
714
+ print("\n---------------------------------------------------")
715
+
716
+
717
+ # --- Argument Parsing & Main Execution ---
718
+
719
+
720
+ def parse_arguments() -> argparse.Namespace:
721
+ """Parses command-line arguments for the search script.
722
+
723
+ Returns:
724
+ An object containing the parsed arguments.
725
+ """
726
+ parser = argparse.ArgumentParser(
727
+ description="Search Lean StatementGroups using combined scoring.",
728
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
729
+ )
730
+ parser.add_argument("query", type=str, help="The search query string.")
731
+ parser.add_argument(
732
+ "--limit",
733
+ "-n",
734
+ type=int,
735
+ default=None, # Will use DEFAULT_RESULTS_LIMIT from defaults if None
736
+ help="Maximum number of final results to display. Overrides default if set.",
737
+ )
738
+ parser.add_argument(
739
+ "--packages",
740
+ metavar="PKG",
741
+ type=str,
742
+ nargs="*", # Allows zero or more package names
743
+ default=None, # No filter if not provided
744
+ help="Filter search results by specific package names (e.g., Mathlib Std). "
745
+ "If not provided, searches all packages.",
746
+ )
747
+ return parser.parse_args()
748
+
749
+
750
+ def main():
751
+ """Main execution function for the search script."""
752
+ args = parse_arguments()
753
+
754
+ logger.info(
755
+ "Using default configurations for paths and parameters from "
756
+ "lean_explore.defaults."
757
+ )
758
+
759
+ # These now point to the versioned paths, e.g., .../toolchains/0.1.0/file.db
760
+ db_url = defaults.DEFAULT_DB_URL
761
+ embedding_model_name = defaults.DEFAULT_EMBEDDING_MODEL_NAME
762
+ resolved_idx_path = str(defaults.DEFAULT_FAISS_INDEX_PATH.resolve())
763
+ resolved_map_path = str(defaults.DEFAULT_FAISS_MAP_PATH.resolve())
764
+
765
+ faiss_k_cand = defaults.DEFAULT_FAISS_K
766
+ pr_weight = defaults.DEFAULT_PAGERANK_WEIGHT
767
+ sem_sim_weight = defaults.DEFAULT_TEXT_RELEVANCE_WEIGHT
768
+ results_disp_limit = (
769
+ args.limit if args.limit is not None else defaults.DEFAULT_RESULTS_LIMIT
770
+ )
771
+ semantic_sim_thresh = defaults.DEFAULT_SEM_SIM_THRESHOLD
772
+ faiss_nprobe_val = defaults.DEFAULT_FAISS_NPROBE
773
+
774
+ db_url_display = (
775
+ f"...{str(defaults.DEFAULT_DB_PATH.resolve())[-30:]}"
776
+ if len(str(defaults.DEFAULT_DB_PATH.resolve())) > 30
777
+ else str(defaults.DEFAULT_DB_PATH.resolve())
778
+ )
779
+ logger.info("--- Starting Search (Direct Script Execution) ---")
780
+ logger.info("Query: '%s'", args.query)
781
+ logger.info("Displaying Top: %d results", results_disp_limit)
782
+ if args.packages:
783
+ logger.info("Filtering by user-specified packages: %s", args.packages)
784
+ else:
785
+ logger.info("No package filter specified, searching all packages.")
786
+ logger.info("FAISS k (candidates): %d", faiss_k_cand)
787
+ logger.info("FAISS nprobe (from defaults): %d", faiss_nprobe_val)
788
+ logger.info(
789
+ "Semantic Similarity Threshold (from defaults): %.3f", semantic_sim_thresh
790
+ )
791
+ logger.info(
792
+ "Weights -> NormTextSim: %.2f, ScaledPR: %.2f",
793
+ sem_sim_weight,
794
+ pr_weight,
795
+ )
796
+ logger.info("Using FAISS index: %s", resolved_idx_path)
797
+ logger.info("Using ID map: %s", resolved_map_path)
798
+ logger.info(
799
+ "Database path: %s", db_url_display
800
+ ) # Changed from URL for clarity with file paths
801
+
802
+ # Ensure user data directory and toolchain directory exist for logs etc.
803
+ # The fetch command handles creation of the specific toolchain version dir.
804
+ # Here, we create the base log directory used by the performance logger if needed.
805
+ try:
806
+ _USER_LOGS_BASE_DIR.mkdir(parents=True, exist_ok=True)
807
+ except OSError as e:
808
+ logger.warning(
809
+ f"Could not create user log directory {_USER_LOGS_BASE_DIR}: {e}"
810
+ )
811
+
812
+ engine = None
813
+ try:
814
+ # Asset loading with explicit error reporting
815
+ s_transformer_model = load_embedding_model(embedding_model_name)
816
+ if s_transformer_model is None:
817
+ # load_embedding_model already logs the error
818
+ logger.error(
819
+ "Sentence transformer model loading failed. Cannot proceed with search."
820
+ )
821
+ sys.exit(1)
822
+
823
+ faiss_idx, id_map = load_faiss_assets(resolved_idx_path, resolved_map_path)
824
+ if faiss_idx is None or id_map is None:
825
+ # load_faiss_assets already logs details
826
+ logger.error(
827
+ "Failed to load critical FAISS assets (index or ID map).\n"
828
+ f"Expected at:\n Index path: {resolved_idx_path}\n"
829
+ f" ID map path: {resolved_map_path}\n"
830
+ "Please ensure these files exist or run 'leanexplore data fetch' "
831
+ "to download the data toolchain."
832
+ )
833
+ sys.exit(1)
834
+
835
+ # Database connection
836
+ # Check for DB file existence before creating engine if it's a
837
+ # file-based SQLite DB
838
+ is_file_db = db_url.startswith("sqlite:///")
839
+ db_file_path = None
840
+ if is_file_db:
841
+ # Extract file path from sqlite:/// URL
842
+ db_file_path_str = db_url[len("sqlite:///") :]
843
+ db_file_path = pathlib.Path(db_file_path_str)
844
+ if not db_file_path.exists():
845
+ logger.error(
846
+ f"Database file not found at the expected location: "
847
+ f"{db_file_path}\n"
848
+ "Please run 'leanexplore data fetch' to download the data "
849
+ "toolchain."
850
+ )
851
+ sys.exit(1)
852
+
853
+ engine = create_engine(db_url, echo=False)
854
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
855
+
856
+ with SessionLocal() as session:
857
+ ranked_results = perform_search(
858
+ session=session,
859
+ query_string=args.query,
860
+ model=s_transformer_model,
861
+ faiss_index=faiss_idx,
862
+ text_chunk_id_map=id_map,
863
+ faiss_k=faiss_k_cand,
864
+ pagerank_weight=pr_weight,
865
+ text_relevance_weight=sem_sim_weight,
866
+ log_searches=True,
867
+ selected_packages=args.packages,
868
+ semantic_similarity_threshold=semantic_sim_thresh, # from defaults
869
+ faiss_nprobe=faiss_nprobe_val, # from defaults
870
+ )
871
+
872
+ print_results(ranked_results[:results_disp_limit])
873
+
874
+ except FileNotFoundError as e: # Should be less common now with explicit checks
875
+ logger.error(
876
+ f"A required file was not found: {e.filename}.\n"
877
+ "This could be an issue with configured paths or missing data.\n"
878
+ "If this relates to core data assets, please try running "
879
+ "'leanexplore data fetch'."
880
+ )
881
+ sys.exit(1)
882
+ except OperationalError as e_db:
883
+ is_file_db_op_err = defaults.DEFAULT_DB_URL.startswith("sqlite:///")
884
+ db_file_path_op_err = defaults.DEFAULT_DB_PATH
885
+ if is_file_db_op_err and (
886
+ "unable to open database file" in str(e_db).lower()
887
+ or (db_file_path_op_err and not db_file_path_op_err.exists())
888
+ ):
889
+ p = str(db_file_path_op_err.resolve())
890
+ logger.error(
891
+ f"Database connection failed: {e_db}\n"
892
+ f"The database file appears to be missing or inaccessible at: "
893
+ f"{p if db_file_path_op_err else 'Unknown Path'}\n"
894
+ "Please run 'leanexplore data fetch' to download or update the "
895
+ "data toolchain."
896
+ )
897
+ else:
898
+ logger.error(
899
+ f"Database connection/operational error: {e_db}", exc_info=True
900
+ )
901
+ sys.exit(1)
902
+ except SQLAlchemyError as e_sqla: # Catch other SQLAlchemy errors
903
+ logger.error(
904
+ "A database error occurred during search: %s", e_sqla, exc_info=True
905
+ )
906
+ sys.exit(1)
907
+ except Exception as e_general: # Catch-all for other unexpected critical errors
908
+ logger.critical(
909
+ "An unexpected critical error occurred during search: %s",
910
+ e_general,
911
+ exc_info=True,
912
+ )
913
+ sys.exit(1)
914
+ finally:
915
+ if engine:
916
+ engine.dispose()
917
+ logger.debug("Database engine disposed.")
918
+
919
+
920
+ if __name__ == "__main__":
921
+ main()
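
For reference, the script above is meant to be run directly with a query string. A hedged example invocation, assuming the package is installed and the data toolchain has been fetched with 'leanexplore data fetch' (the module path is taken from the header comment in the file, not confirmed elsewhere in this diff):

    python -m lean_explore.local.search "derivative of a composition" --limit 5 --packages Mathlib

The same search can also be driven programmatically with the helpers defined above. A minimal sketch, assuming the defaults module exposes the DEFAULT_* attributes referenced in main():

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    from lean_explore import defaults
    from lean_explore.local.search import (
        load_embedding_model,
        load_faiss_assets,
        perform_search,
    )

    # Load the embedding model, FAISS index, and chunk-ID map from the defaults.
    model = load_embedding_model(defaults.DEFAULT_EMBEDDING_MODEL_NAME)
    index, id_map = load_faiss_assets(
        str(defaults.DEFAULT_FAISS_INDEX_PATH), str(defaults.DEFAULT_FAISS_MAP_PATH)
    )

    engine = create_engine(defaults.DEFAULT_DB_URL)
    SessionLocal = sessionmaker(bind=engine)
    with SessionLocal() as session:
        # Returns (StatementGroup, score_dict) tuples sorted by final_score.
        results = perform_search(
            session=session,
            query_string="continuity of composition",
            model=model,
            faiss_index=index,
            text_chunk_id_map=id_map,
            faiss_k=defaults.DEFAULT_FAISS_K,
            pagerank_weight=defaults.DEFAULT_PAGERANK_WEIGHT,
            text_relevance_weight=defaults.DEFAULT_TEXT_RELEVANCE_WEIGHT,
            log_searches=False,
        )
    engine.dispose()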