lean-explore 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -2,13 +2,15 @@
2
2
 
3
3
  """Performs semantic search and ranked retrieval of StatementGroups.
4
4
 
5
- Combines semantic similarity from FAISS and pre-scaled PageRank scores
5
+ Combines semantic similarity from FAISS, pre-scaled PageRank scores, and
6
+ lexical word matching (on Lean name, docstring, and informal descriptions)
6
7
  to rank StatementGroups. It loads necessary assets (embedding model,
7
8
  FAISS index, ID map) using default configurations, embeds the user query,
8
9
  performs FAISS search, filters based on a similarity threshold,
9
- retrieves group details from the database, normalizes semantic similarity scores,
10
- and then combines these scores using configurable weights to produce a final
11
- ranked list. It also logs search performance statistics to a dedicated
10
+ retrieves group details from the database, normalizes semantic similarity,
11
+ PageRank, and BM25 scores based on the current candidate set, and then
12
+ combines these normalized scores using configurable weights to produce a
13
+ final ranked list. It also logs search performance statistics to a dedicated
12
14
  JSONL file.
13
15
  """
14
16
 
@@ -18,36 +20,35 @@ import json
18
20
  import logging
19
21
  import os
20
22
  import pathlib
23
+ import re
21
24
  import sys
22
25
  import time
23
26
  from typing import Any, Dict, List, Optional, Tuple
24
27
 
25
28
  from filelock import FileLock, Timeout
26
29
 
27
- # --- Dependency Imports ---
28
30
  try:
29
31
  import faiss
30
32
  import numpy as np
33
+ from nltk.stem.porter import PorterStemmer
34
+ from rank_bm25 import BM25Plus
31
35
  from sentence_transformers import SentenceTransformer
32
36
  from sqlalchemy import create_engine, or_, select
33
37
  from sqlalchemy.exc import OperationalError, SQLAlchemyError
34
38
  from sqlalchemy.orm import Session, joinedload, sessionmaker
35
39
  except ImportError as e:
36
- # pylint: disable=broad-exception-raised
37
40
  print(
38
41
  f"Error: Missing required libraries ({e}).\n"
39
42
  "Please install them: pip install SQLAlchemy faiss-cpu "
40
- "sentence-transformers numpy filelock rapidfuzz",
43
+ "sentence-transformers numpy filelock rapidfuzz rank_bm25 nltk",
41
44
  file=sys.stderr,
42
45
  )
43
46
  sys.exit(1)
44
47
 
45
- # --- Project Model & Default Config Imports ---
46
48
  try:
47
- from lean_explore import defaults # Using the new defaults module
49
+ from lean_explore import defaults
48
50
  from lean_explore.shared.models.db import StatementGroup
49
51
  except ImportError as e:
50
- # pylint: disable=broad-exception-raised
51
52
  print(
52
53
  f"Error: Could not import project modules (StatementGroup, defaults): {e}\n"
53
54
  "Ensure 'lean_explore' is installed (e.g., 'pip install -e .') "
@@ -57,7 +58,6 @@ except ImportError as e:
57
58
  sys.exit(1)
58
59
 
59
60
 
60
- # --- Logging Setup ---
61
61
  logging.basicConfig(
62
62
  level=logging.WARNING,
63
63
  format="%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s",
@@ -67,17 +67,10 @@ logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
67
67
  logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
68
68
  logger = logging.getLogger(__name__)
69
69
 
70
- # --- Constants ---
71
70
  NEWLINE = os.linesep
72
71
  EPSILON = 1e-9
73
- # PROJECT_ROOT might be less relevant for asset paths if defaults.py
74
- # provides absolute paths
75
72
  PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
76
73
 
77
- # --- Performance Logging Path Setup ---
78
- # Logs will be stored in a user-writable directory, e.g., ~/.lean_explore/logs/
79
- # defaults.LEAN_EXPLORE_USER_DATA_DIR is ~/.lean_explore/data/
80
- # So, its parent is ~/.lean_explore/
81
74
  _USER_LOGS_BASE_DIR = defaults.LEAN_EXPLORE_USER_DATA_DIR.parent / "logs"
82
75
  PERFORMANCE_LOG_DIR = str(_USER_LOGS_BASE_DIR)
83
76
  PERFORMANCE_LOG_FILENAME = "search_stats.jsonl"
@@ -85,9 +78,6 @@ PERFORMANCE_LOG_PATH = os.path.join(PERFORMANCE_LOG_DIR, PERFORMANCE_LOG_FILENAM
85
78
  LOCK_PATH = os.path.join(PERFORMANCE_LOG_DIR, f"{PERFORMANCE_LOG_FILENAME}.lock")
86
79
 
87
80
 
88
- # --- Performance Logging Helper ---
89
-
90
-
91
81
  def log_search_event_to_json(
92
82
  status: str,
93
83
  duration_ms: float,
@@ -115,15 +105,13 @@ def log_search_event_to_json(
115
105
  try:
116
106
  os.makedirs(PERFORMANCE_LOG_DIR, exist_ok=True)
117
107
  except OSError as e:
118
- # This error is critical for logging but should not stop main search flow.
119
- # The fallback print helps retain info if file logging fails.
120
108
  logger.error(
121
109
  "Performance logging error: Could not create log directory %s: %s. "
122
110
  "Log entry: %s",
123
111
  PERFORMANCE_LOG_DIR,
124
112
  e,
125
113
  log_entry,
126
- exc_info=False, # Keep exc_info False to avoid spamming user console
114
+ exc_info=False,
127
115
  )
128
116
  print(
129
117
  f"FALLBACK_PERF_LOG (DIR_ERROR): {json.dumps(log_entry)}", file=sys.stderr
@@ -147,7 +135,7 @@ def log_search_event_to_json(
147
135
  file=sys.stderr,
148
136
  )
149
137
  except Exception as e:
150
- logger.error( # Keep as error for unexpected write issues
138
+ logger.error(
151
139
  "Performance logging error: Failed to write to %s: %s. Log entry: %s",
152
140
  PERFORMANCE_LOG_PATH,
153
141
  e,
@@ -159,7 +147,6 @@ def log_search_event_to_json(
159
147
  )
160
148
 
161
149
 
162
- # --- Asset Loading Functions ---
163
150
  def load_faiss_assets(
164
151
  index_path_str: str, map_path_str: str
165
152
  ) -> Tuple[Optional[faiss.Index], Optional[List[str]]]:
@@ -197,7 +184,7 @@ def load_faiss_assets(
197
184
  logger.error(
198
185
  "Failed to load FAISS index from %s: %s", index_path, e, exc_info=True
199
186
  )
200
- return None, id_map_list # Return None for index if loading failed
187
+ return None, id_map_list
201
188
 
202
189
  try:
203
190
  logger.info("Loading ID map from %s...", map_path)
@@ -207,13 +194,13 @@ def load_faiss_assets(
207
194
  logger.error(
208
195
  "ID map file (%s) does not contain a valid JSON list.", map_path
209
196
  )
210
- return faiss_index_obj, None # Return None for map if parsing failed
197
+ return faiss_index_obj, None
211
198
  logger.info("Loaded ID map with %d entries.", len(id_map_list))
212
199
  except Exception as e:
213
200
  logger.error(
214
201
  "Failed to load or parse ID map file %s: %s", map_path, e, exc_info=True
215
202
  )
216
- return faiss_index_obj, None # Return None for map if loading/parsing failed
203
+ return faiss_index_obj, None
217
204
 
218
205
  if (
219
206
  faiss_index_obj is not None
@@ -247,7 +234,7 @@ def load_embedding_model(model_name: str) -> Optional[SentenceTransformer]:
247
234
  model.max_seq_length,
248
235
  )
249
236
  return model
250
- except Exception as e: # Broad exception for any model loading issue
237
+ except Exception as e:
251
238
  logger.error(
252
239
  "Failed to load sentence transformer model '%s': %s",
253
240
  model_name,
@@ -257,7 +244,41 @@ def load_embedding_model(model_name: str) -> Optional[SentenceTransformer]:
257
244
  return None
258
245
 
259
246
 
260
- # --- Main Search Function ---
247
+ def spacify_text(text: str) -> str:
248
+ """Converts a string by adding spaces around delimiters and camelCase.
249
+
250
+ This function takes a string, typically a file path or a name with
251
+ camelCase, and transforms it to a more human-readable format by:
252
+ - Dropping any leading component up to the first '/' and a trailing
+ '.lean' extension.
+ - Replacing hyphens and underscores with single spaces.
253
+ - Inserting spaces to separate words in camelCase (e.g.,
254
+ 'CamelCaseWord' becomes 'Camel Case Word').
255
+ - Replacing the path delimiters '/' and '.' with spaces.
256
+ - Normalizing multiple consecutive spaces into single spaces.
257
+ - Stripping leading and trailing whitespace and lowercasing the result.
258
+
259
+ Args:
260
+ text: The input string to be transformed.
261
+
262
+ Returns:
263
+ The transformed string with spaces inserted for improved readability.
264
+ """
265
+ text_str = str(text)
266
+
267
+ first_slash_index = text_str.find("/")
268
+ if first_slash_index != -1:
269
+ text_str = text_str[first_slash_index + 1 :]
270
+
271
+ text_str = text_str.replace("-", " ").replace("_", " ").replace(".lean", "")
272
+
273
+ text_str = re.sub(r"([a-z0-9])([A-Z])", r"\1 \2", text_str)
274
+ text_str = re.sub(r"([A-Z])([A-Z][a-z])", r"\1 \2", text_str)
275
+
276
+ text_str = text_str.replace("/", " ")
277
+ text_str = text_str.replace(".", " ")
278
+
279
+ text_str = re.sub(r"\s+", " ", text_str).strip()
280
+ text_str = text_str.lower()
281
+ return text_str
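
For instance, tracing the helper above on a typical Mathlib source path (illustrative input):

    # spacify_text("Mathlib/CategoryTheory/Adjunction/Basic.lean")
    # drops the leading "Mathlib/", removes ".lean", splits camelCase and '/',
    # and lowercases, yielding "category theory adjunction basic".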
261
282
 
262
283
 
263
284
  def perform_search(
@@ -269,12 +290,19 @@ def perform_search(
269
290
  faiss_k: int,
270
291
  pagerank_weight: float,
271
292
  text_relevance_weight: float,
272
- log_searches: bool, # Added parameter
293
+ log_searches: bool,
294
+ name_match_weight: float = defaults.DEFAULT_NAME_MATCH_WEIGHT,
273
295
  selected_packages: Optional[List[str]] = None,
274
296
  semantic_similarity_threshold: float = defaults.DEFAULT_SEM_SIM_THRESHOLD,
275
297
  faiss_nprobe: int = defaults.DEFAULT_FAISS_NPROBE,
298
+ faiss_oversampling_factor: int = defaults.DEFAULT_FAISS_OVERSAMPLING_FACTOR,
276
299
  ) -> List[Tuple[StatementGroup, Dict[str, float]]]:
277
- """Performs semantic search and ranking.
300
+ """Performs semantic and lexical search, then ranks results.
301
+
302
+ Scores (semantic similarity, PageRank, BM25) are normalized to a 0-1
303
+ range based on the current set of candidates before being weighted and
304
+ combined. If `selected_packages` are specified, `faiss_k` is multiplied
305
+ by `faiss_oversampling_factor` to retrieve more initial candidates.
278
306
 
279
307
  Args:
280
308
  session: SQLAlchemy session for database access.
@@ -282,18 +310,40 @@ def perform_search(
282
310
  model: The loaded SentenceTransformer embedding model.
283
311
  faiss_index: The loaded FAISS index for text chunks.
284
312
  text_chunk_id_map: A list mapping FAISS internal indices to text chunk IDs.
285
- faiss_k: The number of nearest neighbors to retrieve from FAISS.
286
- pagerank_weight: Weight for the pre-scaled PageRank score.
287
- text_relevance_weight: Weight for the normalized semantic similarity score.
313
+ faiss_k: The base number of nearest neighbors to retrieve from FAISS.
314
+ pagerank_weight: Weight for the PageRank score.
315
+ text_relevance_weight: Weight for the semantic similarity score.
288
316
  log_searches: If True, search performance data will be logged.
317
+ name_match_weight: Weight for the lexical word match score (BM25).
318
+ Defaults to `defaults.DEFAULT_NAME_MATCH_WEIGHT`.
289
319
  selected_packages: Optional list of package names to filter search by.
290
- semantic_similarity_threshold: Minimum similarity for a result to be considered.
291
- faiss_nprobe: Number of closest cells/clusters to search for IVF-type FAISS
292
- indexes.
320
+ Defaults to None.
321
+ semantic_similarity_threshold: Minimum similarity for a result to be
322
+ considered. Defaults to `defaults.DEFAULT_SEM_SIM_THRESHOLD`.
323
+ faiss_nprobe: Number of closest cells/clusters to search for IVF-type
324
+ FAISS indexes. Defaults to `defaults.DEFAULT_FAISS_NPROBE`.
325
+ faiss_oversampling_factor: Factor to multiply `faiss_k` by when
326
+ `selected_packages` are active.
327
+ Defaults to `defaults.DEFAULT_FAISS_OVERSAMPLING_FACTOR`.
328
+
293
329
 
294
330
  Returns:
295
- A list of tuples, sorted by final_score, containing a
296
- `StatementGroup` object and its scores.
331
+ A list of tuples, sorted by `final_score`. Each tuple contains a
332
+ `StatementGroup` object and a dictionary of its scores.
333
+ The score dictionary includes:
334
+ - 'final_score': The combined weighted score.
335
+ - 'raw_similarity': Original FAISS similarity (0-1).
336
+ - 'norm_similarity': `raw_similarity` normalized across current results.
337
+ - 'original_pagerank_score': PageRank score from the database.
338
+ - 'scaled_pagerank': `original_pagerank_score` normalized across current
339
+ results (this key is kept for compatibility, but
340
+ now holds the normalized PageRank).
341
+ - 'raw_word_match_score': Original BM25 score.
342
+ - 'norm_word_match_score': `raw_word_match_score` normalized across
343
+ current results.
344
+ - Weighted components: `weighted_norm_similarity`,
345
+ `weighted_scaled_pagerank` (uses normalized PageRank),
346
+ `weighted_word_match_score` (uses normalized BM25 score).
297
347
 
298
348
  Raises:
299
349
  Exception: If critical errors like query embedding or FAISS search fail.
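
For orientation, a minimal sketch of calling the function and reading its score dictionaries; the session, model, and FAISS assets are assumed to be loaded as in main() further below, and the query and the numbers in the comments are illustrative:

    results = perform_search(
        session=session,
        query_string="continuous functions on compact sets",
        model=s_transformer_model,
        faiss_index=faiss_idx,
        text_chunk_id_map=id_map,
        faiss_k=defaults.DEFAULT_FAISS_K,
        pagerank_weight=defaults.DEFAULT_PAGERANK_WEIGHT,
        text_relevance_weight=defaults.DEFAULT_TEXT_RELEVANCE_WEIGHT,
        log_searches=False,
    )
    for sg, scores in results[:3]:
        # e.g. norm_similarity 0.80, scaled_pagerank 0.50, norm_word_match_score 1.00
        # with weights 1.0, 0.3 and 0.5 give final_score 0.80 + 0.15 + 0.50 = 1.45.
        print(sg.id, f"{scores['final_score']:.4f}")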
@@ -338,22 +388,34 @@ def perform_search(
338
388
  )
339
389
  raise Exception(f"Query embedding failed: {e}") from e
340
390
 
391
+ actual_faiss_k_to_use = faiss_k
392
+ if selected_packages and faiss_oversampling_factor > 1:
393
+ actual_faiss_k_to_use = faiss_k * faiss_oversampling_factor
394
+ logger.info(
395
+ f"Package filter active. "
396
+ f"Using oversampled FAISS K: {actual_faiss_k_to_use} "
397
+ f"(base K: {faiss_k}, factor: {faiss_oversampling_factor})"
398
+ )
399
+ else:
400
+ logger.info(f"Using FAISS K: {actual_faiss_k_to_use} for initial retrieval.")
401
+
341
402
  try:
342
403
  logger.debug(
343
- "Searching FAISS index for top %d text chunk neighbors...", faiss_k
404
+ "Searching FAISS index for top %d text chunk neighbors...",
405
+ actual_faiss_k_to_use,
344
406
  )
345
- if hasattr(faiss_index, "nprobe") and isinstance(
346
- faiss_index.nprobe, int
347
- ): # Check if index is IVF
407
+ if hasattr(faiss_index, "nprobe") and isinstance(faiss_index.nprobe, int):
348
408
  if faiss_nprobe > 0:
349
409
  faiss_index.nprobe = faiss_nprobe
350
410
  logger.debug(f"Set FAISS nprobe to: {faiss_index.nprobe}")
351
- else: # faiss_nprobe from config is invalid
411
+ else:
352
412
  logger.warning(
353
413
  f"Configured faiss_nprobe is {faiss_nprobe}. Must be > 0. "
354
414
  "Using FAISS default or previously set nprobe for this IVF index."
355
415
  )
356
- distances, indices = faiss_index.search(query_embedding_reshaped, faiss_k)
416
+ distances, indices = faiss_index.search(
417
+ query_embedding_reshaped, actual_faiss_k_to_use
418
+ )
357
419
  except Exception as e:
358
420
  logger.error("FAISS search failed: %s", e, exc_info=True)
359
421
  if log_searches:
@@ -369,7 +431,7 @@ def perform_search(
369
431
  sg_candidates_raw_similarity: Dict[int, float] = {}
370
432
  if indices.size > 0 and distances.size > 0:
371
433
  for i, faiss_internal_idx in enumerate(indices[0]):
372
- if faiss_internal_idx == -1: # FAISS can return -1 for no neighbor
434
+ if faiss_internal_idx == -1:
373
435
  continue
374
436
  try:
375
437
  text_chunk_id_str = text_chunk_id_map[faiss_internal_idx]
@@ -379,25 +441,20 @@ def perform_search(
379
441
  if faiss_index.metric_type == faiss.METRIC_L2:
380
442
  similarity_score = 1.0 / (1.0 + np.sqrt(max(0, raw_faiss_score)))
381
443
  elif faiss_index.metric_type == faiss.METRIC_INNER_PRODUCT:
382
- # Assuming normalized vectors, inner product is cosine similarity
383
444
  similarity_score = raw_faiss_score
384
- else: # Default or unknown metric, treat score as distance-like
445
+ else:
385
446
  similarity_score = 1.0 / (1.0 + max(0, raw_faiss_score))
386
447
  logger.warning(
387
448
  "Unhandled FAISS metric type %d for text chunk. "
388
449
  "Using 1/(1+score) for similarity.",
389
450
  faiss_index.metric_type,
390
451
  )
391
- similarity_score = max(
392
- 0.0, min(1.0, similarity_score)
393
- ) # Clamp to [0,1]
452
+ similarity_score = max(0.0, min(1.0, similarity_score))
394
453
 
395
454
  parts = text_chunk_id_str.split("_")
396
455
  if len(parts) >= 2 and parts[0] == "sg":
397
456
  try:
398
457
  sg_id = int(parts[1])
399
- # If multiple chunks from the same StatementGroup are retrieved,
400
- # keep the one with the highest similarity to the query.
401
458
  if (
402
459
  sg_id not in sg_candidates_raw_similarity
403
460
  or similarity_score > sg_candidates_raw_similarity[sg_id]
@@ -419,9 +476,7 @@ def perform_search(
419
476
  faiss_internal_idx,
420
477
  len(text_chunk_id_map),
421
478
  )
422
- except (
423
- Exception
424
- ) as e: # Catch any other unexpected errors during result processing
479
+ except Exception as e:
425
480
  logger.warning(
426
481
  "Error processing FAISS result for internal index %d "
427
482
  "(chunk_id '%s'): %s",
@@ -486,21 +541,15 @@ def perform_search(
486
541
  if selected_packages:
487
542
  logger.info("Filtering search by packages: %s", selected_packages)
488
543
  package_filters_sqla = []
489
- # Assuming package names in selected_packages are like "Mathlib", "Std"
490
- # And source_file in DB is like
491
- # "Mathlib/CategoryTheory/Adjunction/Basic.lean"
492
544
  for pkg_name in selected_packages:
493
- # Ensure exact package match at the start of the file path
494
- # component
495
- package_filters_sqla.append(
496
- StatementGroup.source_file.startswith(pkg_name + "/")
497
- )
545
+ if pkg_name.strip():
546
+ package_filters_sqla.append(
547
+ StatementGroup.source_file.startswith(pkg_name.strip() + "/")
548
+ )
498
549
 
499
550
  if package_filters_sqla:
500
551
  stmt = stmt.where(or_(*package_filters_sqla))
501
552
 
502
- # Eagerly load primary_declaration to avoid N+1 queries later if
503
- # accessing lean_name
504
553
  stmt = stmt.options(joinedload(StatementGroup.primary_declaration))
505
554
  db_results = session.execute(stmt).scalars().unique().all()
506
555
  for sg_obj in db_results:
@@ -510,9 +559,6 @@ def perform_search(
510
559
  "Fetched details for %d StatementGroups from DB that matched filters.",
511
560
  len(sg_objects_map),
512
561
  )
513
- # Log if some IDs from FAISS (post-threshold and package filter if
514
- # applied) were not found in DB. This check is more informative if
515
- # done *after* any package filtering logic in the query
516
562
  final_candidate_ids_after_db_match = set(sg_objects_map.keys())
517
563
  original_faiss_candidate_ids = set(candidate_sg_ids)
518
564
 
@@ -539,28 +585,33 @@ def perform_search(
539
585
  results_count=0,
540
586
  error_type=type(e).__name__,
541
587
  )
542
- raise # Re-raise to be handled by the caller
588
+ raise
543
589
 
544
590
  results_with_scores: List[Tuple[StatementGroup, Dict[str, float]]] = []
545
- candidate_semantic_similarities: List[float] = [] # For normalization range
546
- processed_candidates_data: List[
547
- Dict[str, Any]
548
- ] = [] # Temp store for data to be scored
549
-
550
- # Iterate over IDs that were confirmed to exist in the DB and match filters
551
- for sg_id in final_candidate_ids_after_db_match: # Use keys from sg_objects_map
552
- sg_obj = sg_objects_map[sg_id] # We know this exists
553
- raw_sem_sim = sg_candidates_raw_similarity[
554
- sg_id
555
- ] # This ID came from FAISS initially
591
+ candidate_semantic_similarities: List[float] = []
592
+ candidate_pagerank_scores: List[float] = []
593
+
594
+ processed_candidates_data: List[Dict[str, Any]] = []
595
+
596
+ for sg_id in final_candidate_ids_after_db_match:
597
+ sg_obj = sg_objects_map[sg_id]
598
+ raw_sem_sim = sg_candidates_raw_similarity[sg_id]
556
599
 
557
600
  processed_candidates_data.append(
558
601
  {
559
602
  "sg_obj": sg_obj,
560
603
  "raw_sem_sim": raw_sem_sim,
604
+ "original_pagerank": sg_obj.scaled_pagerank_score
605
+ if sg_obj.scaled_pagerank_score is not None
606
+ else 0.0,
561
607
  }
562
608
  )
563
609
  candidate_semantic_similarities.append(raw_sem_sim)
610
+ candidate_pagerank_scores.append(
611
+ sg_obj.scaled_pagerank_score
612
+ if sg_obj.scaled_pagerank_score is not None
613
+ else 0.0
614
+ )
564
615
 
565
616
  if not processed_candidates_data:
566
617
  logger.info(
@@ -576,7 +627,63 @@ def perform_search(
576
627
  )
577
628
  return []
578
629
 
579
- # Normalize semantic similarity scores for the retrieved candidates
630
+ stemmer = PorterStemmer()
631
+
632
+ def _get_tokenized_list(text_to_tokenize: str) -> List[str]:
633
+ if not text_to_tokenize:
634
+ return []
635
+ tokens = re.findall(r"\w+", text_to_tokenize.lower())
636
+ return [stemmer.stem(token) for token in tokens]
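
As a standalone illustration of this tokenization feeding the BM25 scoring below (the document and query strings are made up for the example):

    docs = [
        "Continuous.comp: a composition of continuous functions is continuous",
        "Nat.succ_le_iff: successor less-than-or-equal iff",
    ]
    corpus = [[stemmer.stem(t) for t in re.findall(r"\w+", d.lower())] for d in docs]
    query = [stemmer.stem(t) for t in re.findall(r"\w+", "continuous composition")]
    scores = BM25Plus(corpus).get_scores(query)  # one relevance score per document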
637
+
638
+ tokenized_query = _get_tokenized_list(query_string.strip())
639
+ bm25_corpus: List[List[str]] = []
640
+ for candidate_item_data in processed_candidates_data:
641
+ sg_obj_for_corpus = candidate_item_data["sg_obj"]
642
+ combined_text_for_bm25 = " ".join(
643
+ filter(
644
+ None,
645
+ [
646
+ (
647
+ sg_obj_for_corpus.primary_declaration.lean_name
648
+ if sg_obj_for_corpus.primary_declaration
649
+ else None
650
+ ),
651
+ sg_obj_for_corpus.docstring,
652
+ sg_obj_for_corpus.informal_description,
653
+ sg_obj_for_corpus.informal_summary,
654
+ sg_obj_for_corpus.display_statement_text,
655
+ (
656
+ sg_obj_for_corpus.primary_declaration.lean_name
657
+ if sg_obj_for_corpus.primary_declaration
658
+ else None
659
+ ),
660
+ (
661
+ spacify_text(sg_obj_for_corpus.primary_declaration.source_file)
662
+ if sg_obj_for_corpus.primary_declaration
663
+ and sg_obj_for_corpus.primary_declaration.source_file
664
+ else None
665
+ ),
666
+ ],
667
+ )
668
+ )
669
+ bm25_corpus.append(_get_tokenized_list(combined_text_for_bm25))
670
+
671
+ raw_bm25_scores_list: List[float] = [0.0] * len(processed_candidates_data)
672
+ if tokenized_query and any(bm25_corpus):
673
+ try:
674
+ bm25_model = BM25Plus(bm25_corpus)
675
+ raw_bm25_scores_list = bm25_model.get_scores(tokenized_query)
676
+ raw_bm25_scores_list = [
677
+ max(0.0, float(score)) for score in raw_bm25_scores_list
678
+ ]
679
+ except Exception as e:
680
+ logger.warning(
681
+ "BM25Plus scoring failed: %s. Word match scores defaulted to 0.",
682
+ e,
683
+ exc_info=False,
684
+ )
685
+ raw_bm25_scores_list = [0.0] * len(processed_candidates_data)
686
+
580
687
  min_sem_sim = (
581
688
  min(candidate_semantic_similarities) if candidate_semantic_similarities else 0.0
582
689
  )
@@ -590,44 +697,80 @@ def perform_search(
590
697
  max_sem_sim,
591
698
  )
592
699
 
593
- for candidate_data in processed_candidates_data:
700
+ min_pr = min(candidate_pagerank_scores) if candidate_pagerank_scores else 0.0
701
+ max_pr = max(candidate_pagerank_scores) if candidate_pagerank_scores else 0.0
702
+ range_pr = max_pr - min_pr
703
+ logger.debug(
704
+ "Original PageRank score range for normalization: [%.4f, %.4f]", min_pr, max_pr
705
+ )
706
+
707
+ min_bm25 = min(raw_bm25_scores_list) if raw_bm25_scores_list else 0.0
708
+ max_bm25 = max(raw_bm25_scores_list) if raw_bm25_scores_list else 0.0
709
+ range_bm25 = max_bm25 - min_bm25
710
+ logger.debug(
711
+ "Raw BM25 score range for normalization: [%.4f, %.4f]", min_bm25, max_bm25
712
+ )
713
+
714
+ for i, candidate_data in enumerate(processed_candidates_data):
594
715
  sg_obj = candidate_data["sg_obj"]
595
716
  current_raw_sem_sim = candidate_data["raw_sem_sim"]
596
-
597
- # Normalize semantic similarity: scale to [0,1]
598
- norm_sem_sim = 0.5 # Default if range is zero (e.g., only one candidate)
599
- if range_sem_sim > EPSILON:
600
- norm_sem_sim = (current_raw_sem_sim - min_sem_sim) / range_sem_sim
601
- elif (
602
- len(candidate_semantic_similarities) == 1
603
- and candidate_semantic_similarities[0] > 0
604
- ): # Single candidate
605
- # If only one candidate, its normalized score should be high if
606
- # its raw score is non-zero.
607
- norm_sem_sim = 1.0
608
- elif (
609
- len(candidate_semantic_similarities) == 0
610
- ): # Should not happen given previous check
717
+ original_pagerank_score = candidate_data["original_pagerank"]
718
+ original_bm25_score = raw_bm25_scores_list[i]
719
+
720
+ norm_sem_sim = 0.5
721
+ if candidate_semantic_similarities:
722
+ if range_sem_sim > EPSILON:
723
+ norm_sem_sim = (current_raw_sem_sim - min_sem_sim) / range_sem_sim
724
+ elif (
725
+ len(candidate_semantic_similarities) == 1
726
+ and candidate_semantic_similarities[0] > EPSILON
727
+ ):
728
+ norm_sem_sim = 1.0
729
+ elif (
730
+ len(candidate_semantic_similarities) > 0
731
+ and range_sem_sim <= EPSILON
732
+ and max_sem_sim <= EPSILON
733
+ ):
734
+ norm_sem_sim = 0.0
735
+ else:
611
736
  norm_sem_sim = 0.0
737
+ norm_sem_sim = max(0.0, min(1.0, norm_sem_sim))
738
+
739
+ norm_pagerank_score = 0.0
740
+ if candidate_pagerank_scores:
741
+ if range_pr > EPSILON:
742
+ norm_pagerank_score = (original_pagerank_score - min_pr) / range_pr
743
+ elif max_pr > EPSILON:
744
+ norm_pagerank_score = 1.0
745
+ norm_pagerank_score = max(0.0, min(1.0, norm_pagerank_score))
746
+
747
+ norm_bm25_score = 0.0
748
+ if raw_bm25_scores_list:
749
+ if range_bm25 > EPSILON:
750
+ norm_bm25_score = (original_bm25_score - min_bm25) / range_bm25
751
+ elif max_bm25 > EPSILON:
752
+ norm_bm25_score = 1.0
753
+ norm_bm25_score = max(0.0, min(1.0, norm_bm25_score))
612
754
 
613
- current_scaled_pagerank = (
614
- sg_obj.scaled_pagerank_score
615
- if sg_obj.scaled_pagerank_score is not None
616
- else 0.0
617
- )
618
-
619
- # Combine scores using weights
620
755
  weighted_norm_similarity = text_relevance_weight * norm_sem_sim
621
- weighted_scaled_pagerank = pagerank_weight * current_scaled_pagerank
622
- final_score = weighted_norm_similarity + weighted_scaled_pagerank
756
+ weighted_norm_pagerank = pagerank_weight * norm_pagerank_score
757
+ weighted_norm_bm25_score = name_match_weight * norm_bm25_score
758
+
759
+ final_score = (
760
+ weighted_norm_similarity + weighted_norm_pagerank + weighted_norm_bm25_score
761
+ )
623
762
 
624
763
  score_dict = {
625
764
  "final_score": final_score,
765
+ "raw_similarity": current_raw_sem_sim,
626
766
  "norm_similarity": norm_sem_sim,
627
- "scaled_pagerank": current_scaled_pagerank,
767
+ "original_pagerank_score": original_pagerank_score,
768
+ "scaled_pagerank": norm_pagerank_score,
769
+ "raw_word_match_score": original_bm25_score,
770
+ "norm_word_match_score": norm_bm25_score,
628
771
  "weighted_norm_similarity": weighted_norm_similarity,
629
- "weighted_scaled_pagerank": weighted_scaled_pagerank,
630
- "raw_similarity": current_raw_sem_sim, # Keep raw similarity for inspection
772
+ "weighted_scaled_pagerank": weighted_norm_pagerank,
773
+ "weighted_word_match_score": weighted_norm_bm25_score,
631
774
  }
632
775
  results_with_scores.append((sg_obj, score_dict))
633
776
 
@@ -635,17 +778,10 @@ def perform_search(
635
778
 
636
779
  final_status = "SUCCESS"
637
780
  results_count = len(results_with_scores)
638
- if (
639
- not results_with_scores and processed_candidates_data
640
- ): # Had candidates, but scoring/sorting yielded none (unlikely)
781
+ if not results_with_scores and processed_candidates_data:
641
782
  final_status = "NO_RESULTS_FINAL_SCORED"
642
- elif (
643
- not results_with_scores and not processed_candidates_data
644
- ): # No candidates from the start essentially
645
- # This case should have been caught earlier, but as a safeguard for logging
646
- if not candidate_sg_ids:
647
- final_status = "NO_FAISS_CANDIDATES"
648
- elif not sg_candidates_raw_similarity:
783
+ elif not results_with_scores and not processed_candidates_data:
784
+ if not sg_candidates_raw_similarity:
649
785
  final_status = "NO_CANDIDATES_POST_THRESHOLD"
650
786
 
651
787
  if log_searches:
@@ -657,9 +793,6 @@ def perform_search(
657
793
  return results_with_scores
658
794
 
659
795
 
660
- # --- Output Formatting ---
661
-
662
-
663
796
  def print_results(results: List[Tuple[StatementGroup, Dict[str, float]]]) -> None:
664
797
  """Formats and prints the search results to the console.
665
798
 
@@ -682,12 +815,16 @@ def print_results(results: List[Tuple[StatementGroup, Dict[str, float]]]) -> Non
682
815
  f"\n{i + 1}. Lean Name: {primary_decl_name} (SG ID: {sg_obj.id})\n"
683
816
  f" Final Score: {scores['final_score']:.4f} ("
684
817
  f"NormSim*W: {scores['weighted_norm_similarity']:.4f}, "
685
- f"ScaledPR*W: {scores['weighted_scaled_pagerank']:.4f})"
818
+ f"NormPR*W: {scores['weighted_scaled_pagerank']:.4f}, "
819
+ f"NormWordMatch*W: {scores['weighted_word_match_score']:.4f})"
686
820
  )
687
821
  print(
688
- f" Scores: [NormSim: {scores['norm_similarity']:.4f}, "
689
- f"ScaledPR: {scores['scaled_pagerank']:.4f}, "
690
- f"RawSim: {scores['raw_similarity']:.4f}]"
822
+ f" Scores: [NormSim: {scores['norm_similarity']:.4f} "
823
+ f"(Raw: {scores['raw_similarity']:.4f}), "
824
+ f"NormPR: {scores['scaled_pagerank']:.4f} "
825
+ f"(Original: {scores['original_pagerank_score']:.4f}), "
826
+ f"NormWordMatch: {scores['norm_word_match_score']:.4f} "
827
+ f"(OriginalBM25: {scores['raw_word_match_score']:.2f})]"
691
828
  )
692
829
 
693
830
  lean_display = (
@@ -707,16 +844,13 @@ def print_results(results: List[Tuple[StatementGroup, Dict[str, float]]]) -> Non
707
844
  print(f" Description: {desc_display_short.replace(NEWLINE, ' ')}")
708
845
 
709
846
  source_loc = sg_obj.source_file or "[No source file]"
710
- if source_loc.startswith("Mathlib/"): # Simplify Mathlib paths
847
+ if source_loc.startswith("Mathlib/"):
711
848
  source_loc = source_loc[len("Mathlib/") :]
712
849
  print(f" File: {source_loc}:{sg_obj.range_start_line}")
713
850
 
714
851
  print("\n---------------------------------------------------")
715
852
 
716
853
 
717
- # --- Argument Parsing & Main Execution ---
718
-
719
-
720
854
  def parse_arguments() -> argparse.Namespace:
721
855
  """Parses command-line arguments for the search script.
722
856
 
@@ -732,15 +866,15 @@ def parse_arguments() -> argparse.Namespace:
732
866
  "--limit",
733
867
  "-n",
734
868
  type=int,
735
- default=None, # Will use DEFAULT_RESULTS_LIMIT from defaults if None
869
+ default=None,
736
870
  help="Maximum number of final results to display. Overrides default if set.",
737
871
  )
738
872
  parser.add_argument(
739
873
  "--packages",
740
874
  metavar="PKG",
741
875
  type=str,
742
- nargs="*", # Allows zero or more package names
743
- default=None, # No filter if not provided
876
+ nargs="*",
877
+ default=None,
744
878
  help="Filter search results by specific package names (e.g., Mathlib Std). "
745
879
  "If not provided, searches all packages.",
746
880
  )
@@ -756,7 +890,6 @@ def main():
756
890
  "lean_explore.defaults."
757
891
  )
758
892
 
759
- # These now point to the versioned paths, e.g., .../toolchains/0.1.0/file.db
760
893
  db_url = defaults.DEFAULT_DB_URL
761
894
  embedding_model_name = defaults.DEFAULT_EMBEDDING_MODEL_NAME
762
895
  resolved_idx_path = str(defaults.DEFAULT_FAISS_INDEX_PATH.resolve())
@@ -765,11 +898,13 @@ def main():
765
898
  faiss_k_cand = defaults.DEFAULT_FAISS_K
766
899
  pr_weight = defaults.DEFAULT_PAGERANK_WEIGHT
767
900
  sem_sim_weight = defaults.DEFAULT_TEXT_RELEVANCE_WEIGHT
901
+ name_match_w = defaults.DEFAULT_NAME_MATCH_WEIGHT
768
902
  results_disp_limit = (
769
903
  args.limit if args.limit is not None else defaults.DEFAULT_RESULTS_LIMIT
770
904
  )
771
905
  semantic_sim_thresh = defaults.DEFAULT_SEM_SIM_THRESHOLD
772
906
  faiss_nprobe_val = defaults.DEFAULT_FAISS_NPROBE
907
+ faiss_oversampling_factor_val = defaults.DEFAULT_FAISS_OVERSAMPLING_FACTOR
773
908
 
774
909
  db_url_display = (
775
910
  f"...{str(defaults.DEFAULT_DB_PATH.resolve())[-30:]}"
@@ -785,23 +920,22 @@ def main():
785
920
  logger.info("No package filter specified, searching all packages.")
786
921
  logger.info("FAISS k (candidates): %d", faiss_k_cand)
787
922
  logger.info("FAISS nprobe (from defaults): %d", faiss_nprobe_val)
923
+ logger.info(
924
+ "FAISS Oversampling Factor (from defaults): %d", faiss_oversampling_factor_val
925
+ )
788
926
  logger.info(
789
927
  "Semantic Similarity Threshold (from defaults): %.3f", semantic_sim_thresh
790
928
  )
791
929
  logger.info(
792
- "Weights -> NormTextSim: %.2f, ScaledPR: %.2f",
930
+ "Weights -> NormTextSim: %.2f, NormPR: %.2f, NormWordMatch (BM25): %.2f",
793
931
  sem_sim_weight,
794
932
  pr_weight,
933
+ name_match_w,
795
934
  )
796
935
  logger.info("Using FAISS index: %s", resolved_idx_path)
797
936
  logger.info("Using ID map: %s", resolved_map_path)
798
- logger.info(
799
- "Database path: %s", db_url_display
800
- ) # Changed from URL for clarity with file paths
937
+ logger.info("Database path: %s", db_url_display)
801
938
 
802
- # Ensure user data directory and toolchain directory exist for logs etc.
803
- # The fetch command handles creation of the specific toolchain version dir.
804
- # Here, we ensure the base log directory can be created by performance logger.
805
939
  try:
806
940
  _USER_LOGS_BASE_DIR.mkdir(parents=True, exist_ok=True)
807
941
  except OSError as e:
@@ -811,10 +945,8 @@ def main():
811
945
 
812
946
  engine = None
813
947
  try:
814
- # Asset loading with improved error potential
815
948
  s_transformer_model = load_embedding_model(embedding_model_name)
816
949
  if s_transformer_model is None:
817
- # load_embedding_model already logs the error
818
950
  logger.error(
819
951
  "Sentence transformer model loading failed. Cannot proceed with search."
820
952
  )
@@ -822,7 +954,6 @@ def main():
822
954
 
823
955
  faiss_idx, id_map = load_faiss_assets(resolved_idx_path, resolved_map_path)
824
956
  if faiss_idx is None or id_map is None:
825
- # load_faiss_assets already logs details
826
957
  logger.error(
827
958
  "Failed to load critical FAISS assets (index or ID map).\n"
828
959
  f"Expected at:\n Index path: {resolved_idx_path}\n"
@@ -832,13 +963,9 @@ def main():
832
963
  )
833
964
  sys.exit(1)
834
965
 
835
- # Database connection
836
- # Check for DB file existence before creating engine if it's a
837
- # file-based SQLite DB
838
966
  is_file_db = db_url.startswith("sqlite:///")
839
967
  db_file_path = None
840
968
  if is_file_db:
841
- # Extract file path from sqlite:/// URL
842
969
  db_file_path_str = db_url[len("sqlite:///") :]
843
970
  db_file_path = pathlib.Path(db_file_path_str)
844
971
  if not db_file_path.exists():
@@ -864,14 +991,16 @@ def main():
864
991
  pagerank_weight=pr_weight,
865
992
  text_relevance_weight=sem_sim_weight,
866
993
  log_searches=True,
994
+ name_match_weight=name_match_w,
867
995
  selected_packages=args.packages,
868
- semantic_similarity_threshold=semantic_sim_thresh, # from defaults
869
- faiss_nprobe=faiss_nprobe_val, # from defaults
996
+ semantic_similarity_threshold=semantic_sim_thresh,
997
+ faiss_nprobe=faiss_nprobe_val,
998
+ faiss_oversampling_factor=faiss_oversampling_factor_val,
870
999
  )
871
1000
 
872
1001
  print_results(ranked_results[:results_disp_limit])
873
1002
 
874
- except FileNotFoundError as e: # Should be less common now with explicit checks
1003
+ except FileNotFoundError as e:
875
1004
  logger.error(
876
1005
  f"A required file was not found: {e.filename}.\n"
877
1006
  "This could be an issue with configured paths or missing data.\n"
@@ -899,12 +1028,12 @@ def main():
899
1028
  f"Database connection/operational error: {e_db}", exc_info=True
900
1029
  )
901
1030
  sys.exit(1)
902
- except SQLAlchemyError as e_sqla: # Catch other SQLAlchemy errors
1031
+ except SQLAlchemyError as e_sqla:
903
1032
  logger.error(
904
1033
  "A database error occurred during search: %s", e_sqla, exc_info=True
905
1034
  )
906
1035
  sys.exit(1)
907
- except Exception as e_general: # Catch-all for other unexpected critical errors
1036
+ except Exception as e_general:
908
1037
  logger.critical(
909
1038
  "An unexpected critical error occurred during search: %s",
910
1039
  e_general,