lean-explore 0.2.2-py3-none-any.whl → 1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. lean_explore/__init__.py +14 -1
  2. lean_explore/api/__init__.py +12 -1
  3. lean_explore/api/client.py +60 -80
  4. lean_explore/cli/__init__.py +10 -1
  5. lean_explore/cli/data_commands.py +157 -479
  6. lean_explore/cli/display.py +171 -0
  7. lean_explore/cli/main.py +51 -608
  8. lean_explore/config.py +244 -0
  9. lean_explore/extract/__init__.py +5 -0
  10. lean_explore/extract/__main__.py +368 -0
  11. lean_explore/extract/doc_gen4.py +200 -0
  12. lean_explore/extract/doc_parser.py +499 -0
  13. lean_explore/extract/embeddings.py +371 -0
  14. lean_explore/extract/github.py +110 -0
  15. lean_explore/extract/index.py +317 -0
  16. lean_explore/extract/informalize.py +653 -0
  17. lean_explore/extract/package_config.py +59 -0
  18. lean_explore/extract/package_registry.py +45 -0
  19. lean_explore/extract/package_utils.py +105 -0
  20. lean_explore/extract/types.py +25 -0
  21. lean_explore/mcp/__init__.py +11 -1
  22. lean_explore/mcp/app.py +14 -46
  23. lean_explore/mcp/server.py +20 -35
  24. lean_explore/mcp/tools.py +70 -177
  25. lean_explore/models/__init__.py +9 -0
  26. lean_explore/models/search_db.py +76 -0
  27. lean_explore/models/search_types.py +53 -0
  28. lean_explore/search/__init__.py +32 -0
  29. lean_explore/search/engine.py +655 -0
  30. lean_explore/search/scoring.py +156 -0
  31. lean_explore/search/service.py +68 -0
  32. lean_explore/search/tokenization.py +71 -0
  33. lean_explore/util/__init__.py +28 -0
  34. lean_explore/util/embedding_client.py +92 -0
  35. lean_explore/util/logging.py +22 -0
  36. lean_explore/util/openrouter_client.py +63 -0
  37. lean_explore/util/reranker_client.py +189 -0
  38. {lean_explore-0.2.2.dist-info → lean_explore-1.0.0.dist-info}/METADATA +55 -10
  39. lean_explore-1.0.0.dist-info/RECORD +43 -0
  40. {lean_explore-0.2.2.dist-info → lean_explore-1.0.0.dist-info}/WHEEL +1 -1
  41. lean_explore-1.0.0.dist-info/entry_points.txt +2 -0
  42. lean_explore/cli/agent.py +0 -781
  43. lean_explore/cli/config_utils.py +0 -481
  44. lean_explore/defaults.py +0 -114
  45. lean_explore/local/__init__.py +0 -1
  46. lean_explore/local/search.py +0 -1050
  47. lean_explore/local/service.py +0 -392
  48. lean_explore/shared/__init__.py +0 -1
  49. lean_explore/shared/models/__init__.py +0 -1
  50. lean_explore/shared/models/api.py +0 -117
  51. lean_explore/shared/models/db.py +0 -396
  52. lean_explore-0.2.2.dist-info/RECORD +0 -26
  53. lean_explore-0.2.2.dist-info/entry_points.txt +0 -2
  54. {lean_explore-0.2.2.dist-info → lean_explore-1.0.0.dist-info}/licenses/LICENSE +0 -0
  55. {lean_explore-0.2.2.dist-info → lean_explore-1.0.0.dist-info}/top_level.txt +0 -0
lean_explore/local/search.py (deleted)
@@ -1,1050 +0,0 @@
1
- # src/lean_explore/local/search.py
2
-
3
- """Performs semantic search and ranked retrieval of StatementGroups.
4
-
5
- Combines semantic similarity from FAISS, pre-scaled PageRank scores, and
6
- lexical word matching (on Lean name, docstring, and informal descriptions)
7
- to rank StatementGroups. It loads necessary assets (embedding model,
8
- FAISS index, ID map) using default configurations, embeds the user query,
9
- performs FAISS search, filters based on a similarity threshold,
10
- retrieves group details from the database, normalizes semantic similarity,
11
- PageRank, and BM25 scores based on the current candidate set, and then
12
- combines these normalized scores using configurable weights to produce a
13
- final ranked list. It also logs search performance statistics to a dedicated
14
- JSONL file.
15
- """
16
-
17
- import argparse
18
- import datetime
19
- import json
20
- import logging
21
- import os
22
- import pathlib
23
- import re
24
- import sys
25
- import time
26
- from typing import Any, Dict, List, Optional, Tuple
27
-
28
- from filelock import FileLock, Timeout
29
-
30
- try:
31
- import faiss
32
- import numpy as np
33
- from nltk.stem.porter import PorterStemmer
34
- from rank_bm25 import BM25Plus
35
- from sentence_transformers import SentenceTransformer
36
- from sqlalchemy import create_engine, or_, select
37
- from sqlalchemy.exc import OperationalError, SQLAlchemyError
38
- from sqlalchemy.orm import Session, joinedload, sessionmaker
39
- except ImportError as e:
40
- print(
41
- f"Error: Missing required libraries ({e}).\n"
42
- "Please install them: pip install SQLAlchemy faiss-cpu "
43
- "sentence-transformers numpy filelock rapidfuzz rank_bm25 nltk",
44
- file=sys.stderr,
45
- )
46
- sys.exit(1)
47
-
48
- try:
49
- from lean_explore import defaults
50
- from lean_explore.shared.models.db import StatementGroup
51
- except ImportError as e:
52
- print(
53
- f"Error: Could not import project modules (StatementGroup, defaults): {e}\n"
54
- "Ensure 'lean_explore' is installed (e.g., 'pip install -e .') "
55
- "and all dependencies are met.",
56
- file=sys.stderr,
57
- )
58
- sys.exit(1)
59
-
60
-
61
- logging.basicConfig(
62
- level=logging.WARNING,
63
- format="%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s",
64
- datefmt="%Y-%m-%d %H:%M:%S",
65
- )
66
- logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
67
- logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
68
- logger = logging.getLogger(__name__)
69
-
70
- NEWLINE = os.linesep
71
- EPSILON = 1e-9
72
- PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
73
-
74
- _USER_LOGS_BASE_DIR = defaults.LEAN_EXPLORE_USER_DATA_DIR.parent / "logs"
75
- PERFORMANCE_LOG_DIR = str(_USER_LOGS_BASE_DIR)
76
- PERFORMANCE_LOG_FILENAME = "search_stats.jsonl"
77
- PERFORMANCE_LOG_PATH = os.path.join(PERFORMANCE_LOG_DIR, PERFORMANCE_LOG_FILENAME)
78
- LOCK_PATH = os.path.join(PERFORMANCE_LOG_DIR, f"{PERFORMANCE_LOG_FILENAME}.lock")
79
-
80
-
81
- def log_search_event_to_json(
82
- status: str,
83
- duration_ms: float,
84
- results_count: int,
85
- error_type: Optional[str] = None,
86
- ) -> None:
87
- """Logs a search event as a JSON line to a dedicated performance log file.
88
-
89
- Args:
90
- status: A string code indicating the outcome of the search.
91
- duration_ms: The total duration of the search processing in milliseconds.
92
- results_count: The number of search results returned.
93
- error_type: Optional. The type of error if the status indicates an error.
94
- """
95
- log_entry = {
96
- "timestamp": datetime.datetime.utcnow().isoformat() + "Z",
97
- "event": "search_processed",
98
- "status": status,
99
- "duration_ms": round(duration_ms, 2),
100
- "results_count": results_count,
101
- }
102
- if error_type:
103
- log_entry["error_type"] = error_type
104
-
105
- try:
106
- os.makedirs(PERFORMANCE_LOG_DIR, exist_ok=True)
107
- except OSError as e:
108
- logger.error(
109
- "Performance logging error: Could not create log directory %s: %s. "
110
- "Log entry: %s",
111
- PERFORMANCE_LOG_DIR,
112
- e,
113
- log_entry,
114
- exc_info=False,
115
- )
116
- print(
117
- f"FALLBACK_PERF_LOG (DIR_ERROR): {json.dumps(log_entry)}", file=sys.stderr
118
- )
119
- return
120
-
121
- lock = FileLock(LOCK_PATH, timeout=2)
122
- try:
123
- with lock:
124
- with open(PERFORMANCE_LOG_PATH, "a", encoding="utf-8") as f:
125
- f.write(json.dumps(log_entry) + "\n")
126
- except Timeout:
127
- logger.warning(
128
- "Performance logging error: Timeout acquiring lock for %s. "
129
- "Log entry lost: %s",
130
- LOCK_PATH,
131
- log_entry,
132
- )
133
- print(
134
- f"FALLBACK_PERF_LOG (LOCK_TIMEOUT): {json.dumps(log_entry)}",
135
- file=sys.stderr,
136
- )
137
- except Exception as e:
138
- logger.error(
139
- "Performance logging error: Failed to write to %s: %s. Log entry: %s",
140
- PERFORMANCE_LOG_PATH,
141
- e,
142
- log_entry,
143
- exc_info=False,
144
- )
145
- print(
146
- f"FALLBACK_PERF_LOG (WRITE_ERROR): {json.dumps(log_entry)}", file=sys.stderr
147
- )
148
-
149
-
150
- def load_faiss_assets(
151
- index_path_str: str, map_path_str: str
152
- ) -> Tuple[Optional[faiss.Index], Optional[List[str]]]:
153
- """Loads the FAISS index and ID map from specified file paths.
154
-
155
- Args:
156
- index_path_str: String path to the FAISS index file.
157
- map_path_str: String path to the JSON ID map file.
158
-
159
- Returns:
160
- A tuple (faiss.Index or None, list_of_IDs or None).
161
- """
162
- index_path = pathlib.Path(index_path_str).resolve()
163
- map_path = pathlib.Path(map_path_str).resolve()
164
-
165
- if not index_path.exists():
166
- logger.error("FAISS index file not found: %s", index_path)
167
- return None, None
168
- if not map_path.exists():
169
- logger.error("FAISS ID map file not found: %s", map_path)
170
- return None, None
171
-
172
- faiss_index_obj: Optional[faiss.Index] = None
173
- id_map_list: Optional[List[str]] = None
174
-
175
- try:
176
- logger.info("Loading FAISS index from %s...", index_path)
177
- faiss_index_obj = faiss.read_index(str(index_path))
178
- logger.info(
179
- "Loaded FAISS index with %d vectors (Metric Type: %s).",
180
- faiss_index_obj.ntotal,
181
- faiss_index_obj.metric_type,
182
- )
183
- except Exception as e:
184
- logger.error(
185
- "Failed to load FAISS index from %s: %s", index_path, e, exc_info=True
186
- )
187
- return None, id_map_list
188
-
189
- try:
190
- logger.info("Loading ID map from %s...", map_path)
191
- with open(map_path, encoding="utf-8") as f:
192
- id_map_list = json.load(f)
193
- if not isinstance(id_map_list, list):
194
- logger.error(
195
- "ID map file (%s) does not contain a valid JSON list.", map_path
196
- )
197
- return faiss_index_obj, None
198
- logger.info("Loaded ID map with %d entries.", len(id_map_list))
199
- except Exception as e:
200
- logger.error(
201
- "Failed to load or parse ID map file %s: %s", map_path, e, exc_info=True
202
- )
203
- return faiss_index_obj, None
204
-
205
- if (
206
- faiss_index_obj is not None
207
- and id_map_list is not None
208
- and faiss_index_obj.ntotal != len(id_map_list)
209
- ):
210
- logger.warning(
211
- "Mismatch: FAISS index size (%d) vs ID map size (%d). "
212
- "Results may be inconsistent.",
213
- faiss_index_obj.ntotal,
214
- len(id_map_list),
215
- )
216
- return faiss_index_obj, id_map_list
217
-
218
-
219
- def load_embedding_model(model_name: str) -> Optional[SentenceTransformer]:
220
- """Loads the specified Sentence Transformer model.
221
-
222
- Args:
223
- model_name: The name or path of the sentence-transformer model.
224
-
225
- Returns:
226
- The loaded model, or None if loading fails.
227
- """
228
- logger.info("Loading sentence transformer model '%s'...", model_name)
229
- try:
230
- model = SentenceTransformer(model_name)
231
- logger.info(
232
- "Model '%s' loaded successfully. Max sequence length: %d.",
233
- model_name,
234
- model.max_seq_length,
235
- )
236
- return model
237
- except Exception as e:
238
- logger.error(
239
- "Failed to load sentence transformer model '%s': %s",
240
- model_name,
241
- e,
242
- exc_info=True,
243
- )
244
- return None
245
-
246
-
247
- def spacify_text(text: str) -> str:
248
- """Converts a string by adding spaces around delimiters and camelCase.
249
-
250
- This function takes a string, typically a file path or a name with
251
- camelCase, and transforms it to a more human-readable format by:
252
- - Replacing hyphens and underscores with single spaces.
253
- - Inserting spaces to separate words in camelCase (e.g.,
254
- 'CamelCaseWord' becomes 'Camel Case Word').
255
- - Adding spaces around common path delimiters such as '/' and '.'.
256
- - Normalizing multiple consecutive spaces into single spaces.
257
- - Stripping leading and trailing whitespace from the final string.
258
-
259
- Args:
260
- text: The input string to be transformed.
261
-
262
- Returns:
263
- The transformed string with spaces inserted for improved readability.
264
- """
265
- text_str = str(text)
266
-
267
- first_slash_index = text_str.find("/")
268
- if first_slash_index != -1:
269
- text_str = text_str[first_slash_index + 1 :]
270
-
271
- text_str = text_str.replace("-", " ").replace("_", " ").replace(".lean", "")
272
-
273
- text_str = re.sub(r"([a-z0-9])([A-Z])", r"\1 \2", text_str)
274
- text_str = re.sub(r"([A-Z])([A-Z][a-z])", r"\1 \2", text_str)
275
-
276
- text_str = text_str.replace("/", " ")
277
- text_str = text_str.replace(".", " ")
278
-
279
- text_str = re.sub(r"\s+", " ", text_str).strip()
280
- text_str = text_str.lower()
281
- return text_str
282
-
283
-
284
- def perform_search(
285
- session: Session,
286
- query_string: str,
287
- model: SentenceTransformer,
288
- faiss_index: faiss.Index,
289
- text_chunk_id_map: List[str],
290
- faiss_k: int,
291
- pagerank_weight: float,
292
- text_relevance_weight: float,
293
- log_searches: bool,
294
- name_match_weight: float = defaults.DEFAULT_NAME_MATCH_WEIGHT,
295
- selected_packages: Optional[List[str]] = None,
296
- semantic_similarity_threshold: float = defaults.DEFAULT_SEM_SIM_THRESHOLD,
297
- faiss_nprobe: int = defaults.DEFAULT_FAISS_NPROBE,
298
- faiss_oversampling_factor: int = defaults.DEFAULT_FAISS_OVERSAMPLING_FACTOR,
299
- ) -> List[Tuple[StatementGroup, Dict[str, float]]]:
300
- """Performs semantic and lexical search, then ranks results.
301
-
302
- Scores (semantic similarity, PageRank, BM25) are normalized to a 0-1
303
- range based on the current set of candidates before being weighted and
304
- combined. If `selected_packages` are specified, `faiss_k` is multiplied
305
- by `faiss_oversampling_factor` to retrieve more initial candidates.
306
-
307
- Args:
308
- session: SQLAlchemy session for database access.
309
- query_string: The user's search query string.
310
- model: The loaded SentenceTransformer embedding model.
311
- faiss_index: The loaded FAISS index for text chunks.
312
- text_chunk_id_map: A list mapping FAISS internal indices to text chunk IDs.
313
- faiss_k: The base number of nearest neighbors to retrieve from FAISS.
314
- pagerank_weight: Weight for the PageRank score.
315
- text_relevance_weight: Weight for the semantic similarity score.
316
- log_searches: If True, search performance data will be logged.
317
- name_match_weight: Weight for the lexical word match score (BM25).
318
- Defaults to `defaults.DEFAULT_NAME_MATCH_WEIGHT`.
319
- selected_packages: Optional list of package names to filter search by.
320
- Defaults to None.
321
- semantic_similarity_threshold: Minimum similarity for a result to be
322
- considered. Defaults to `defaults.DEFAULT_SEM_SIM_THRESHOLD`.
323
- faiss_nprobe: Number of closest cells/clusters to search for IVF-type
324
- FAISS indexes. Defaults to `defaults.DEFAULT_FAISS_NPROBE`.
325
- faiss_oversampling_factor: Factor to multiply `faiss_k` by when
326
- `selected_packages` are active.
327
- Defaults to `defaults.DEFAULT_FAISS_OVERSAMPLING_FACTOR`.
328
-
329
-
330
- Returns:
331
- A list of tuples, sorted by `final_score`. Each tuple contains a
332
- `StatementGroup` object and a dictionary of its scores.
333
- The score dictionary includes:
334
- - 'final_score': The combined weighted score.
335
- - 'raw_similarity': Original FAISS similarity (0-1).
336
- - 'norm_similarity': `raw_similarity` normalized across current results.
337
- - 'original_pagerank_score': PageRank score from the database.
338
- - 'scaled_pagerank': `original_pagerank_score` normalized across current
339
- results (this key is kept for compatibility, but
340
- now holds the normalized PageRank).
341
- - 'raw_word_match_score': Original BM25 score.
342
- - 'norm_word_match_score': `raw_word_match_score` normalized across
343
- current results.
344
- - Weighted components: `weighted_norm_similarity`,
345
- `weighted_scaled_pagerank` (uses normalized PageRank),
346
- `weighted_word_match_score` (uses normalized BM25 score).
347
-
348
- Raises:
349
- Exception: If critical errors like query embedding or FAISS search fail.
350
- """
351
- overall_start_time = time.time()
352
-
353
- logger.info("Search request event initiated.")
354
- if semantic_similarity_threshold > 0.0 + EPSILON:
355
- logger.info(
356
- "Applying semantic similarity threshold: %.3f",
357
- semantic_similarity_threshold,
358
- )
359
-
360
- if not query_string.strip():
361
- logger.warning("Empty query provided. Returning no results.")
362
- if log_searches:
363
- duration_ms = (time.time() - overall_start_time) * 1000
364
- log_search_event_to_json(
365
- status="EMPTY_QUERY_SUBMITTED", duration_ms=duration_ms, results_count=0
366
- )
367
- return []
368
-
369
- try:
370
- query_embedding = model.encode([query_string.strip()], convert_to_numpy=True)[
371
- 0
372
- ].astype(np.float32)
373
- query_embedding_reshaped = np.expand_dims(query_embedding, axis=0)
374
- if faiss_index.metric_type == faiss.METRIC_INNER_PRODUCT:
375
- logger.debug(
376
- "Normalizing query embedding for Inner Product (cosine) search."
377
- )
378
- faiss.normalize_L2(query_embedding_reshaped)
379
- except Exception as e:
380
- logger.error("Failed to embed query: %s", e, exc_info=True)
381
- if log_searches:
382
- duration_ms = (time.time() - overall_start_time) * 1000
383
- log_search_event_to_json(
384
- status="EMBEDDING_ERROR",
385
- duration_ms=duration_ms,
386
- results_count=0,
387
- error_type=type(e).__name__,
388
- )
389
- raise Exception(f"Query embedding failed: {e}") from e
390
-
391
- actual_faiss_k_to_use = faiss_k
392
- if selected_packages and faiss_oversampling_factor > 1:
393
- actual_faiss_k_to_use = faiss_k * faiss_oversampling_factor
394
- logger.info(
395
- f"Package filter active. "
396
- f"Using oversampled FAISS K: {actual_faiss_k_to_use} "
397
- f"(base K: {faiss_k}, factor: {faiss_oversampling_factor})"
398
- )
399
- else:
400
- logger.info(f"Using FAISS K: {actual_faiss_k_to_use} for initial retrieval.")
401
-
402
- try:
403
- logger.debug(
404
- "Searching FAISS index for top %d text chunk neighbors...",
405
- actual_faiss_k_to_use,
406
- )
407
- if hasattr(faiss_index, "nprobe") and isinstance(faiss_index.nprobe, int):
408
- if faiss_nprobe > 0:
409
- faiss_index.nprobe = faiss_nprobe
410
- logger.debug(f"Set FAISS nprobe to: {faiss_index.nprobe}")
411
- else:
412
- logger.warning(
413
- f"Configured faiss_nprobe is {faiss_nprobe}. Must be > 0. "
414
- "Using FAISS default or previously set nprobe for this IVF index."
415
- )
416
- distances, indices = faiss_index.search(
417
- query_embedding_reshaped, actual_faiss_k_to_use
418
- )
419
- except Exception as e:
420
- logger.error("FAISS search failed: %s", e, exc_info=True)
421
- if log_searches:
422
- duration_ms = (time.time() - overall_start_time) * 1000
423
- log_search_event_to_json(
424
- status="FAISS_SEARCH_ERROR",
425
- duration_ms=duration_ms,
426
- results_count=0,
427
- error_type=type(e).__name__,
428
- )
429
- raise Exception(f"FAISS search failed: {e}") from e
430
-
431
- sg_candidates_raw_similarity: Dict[int, float] = {}
432
- if indices.size > 0 and distances.size > 0:
433
- for i, faiss_internal_idx in enumerate(indices[0]):
434
- if faiss_internal_idx == -1:
435
- continue
436
- try:
437
- text_chunk_id_str = text_chunk_id_map[faiss_internal_idx]
438
- raw_faiss_score = distances[0][i]
439
- similarity_score: float
440
-
441
- if faiss_index.metric_type == faiss.METRIC_L2:
442
- similarity_score = 1.0 / (1.0 + np.sqrt(max(0, raw_faiss_score)))
443
- elif faiss_index.metric_type == faiss.METRIC_INNER_PRODUCT:
444
- similarity_score = raw_faiss_score
445
- else:
446
- similarity_score = 1.0 / (1.0 + max(0, raw_faiss_score))
447
- logger.warning(
448
- "Unhandled FAISS metric type %d for text chunk. "
449
- "Using 1/(1+score) for similarity.",
450
- faiss_index.metric_type,
451
- )
452
- similarity_score = max(0.0, min(1.0, similarity_score))
453
-
454
- parts = text_chunk_id_str.split("_")
455
- if len(parts) >= 2 and parts[0] == "sg":
456
- try:
457
- sg_id = int(parts[1])
458
- if (
459
- sg_id not in sg_candidates_raw_similarity
460
- or similarity_score > sg_candidates_raw_similarity[sg_id]
461
- ):
462
- sg_candidates_raw_similarity[sg_id] = similarity_score
463
- except ValueError:
464
- logger.warning(
465
- "Could not parse StatementGroup ID from chunk_id: %s",
466
- text_chunk_id_str,
467
- )
468
- else:
469
- logger.warning(
470
- "Malformed text_chunk_id format: %s", text_chunk_id_str
471
- )
472
- except IndexError:
473
- logger.warning(
474
- "FAISS internal index %d out of bounds for ID map (size %d). "
475
- "Possible data inconsistency.",
476
- faiss_internal_idx,
477
- len(text_chunk_id_map),
478
- )
479
- except Exception as e:
480
- logger.warning(
481
- "Error processing FAISS result for internal index %d "
482
- "(chunk_id '%s'): %s",
483
- faiss_internal_idx,
484
- text_chunk_id_str if "text_chunk_id_str" in locals() else "N/A",
485
- e,
486
- )
487
-
488
- if not sg_candidates_raw_similarity:
489
- logger.info(
490
- "No valid StatementGroup candidates found after FAISS search and parsing."
491
- )
492
- if log_searches:
493
- duration_ms = (time.time() - overall_start_time) * 1000
494
- log_search_event_to_json(
495
- status="NO_FAISS_CANDIDATES", duration_ms=duration_ms, results_count=0
496
- )
497
- return []
498
- logger.info(
499
- "Aggregated %d unique StatementGroup candidates from FAISS results.",
500
- len(sg_candidates_raw_similarity),
501
- )
502
-
503
- if semantic_similarity_threshold > 0.0 + EPSILON:
504
- initial_candidate_count = len(sg_candidates_raw_similarity)
505
- sg_candidates_raw_similarity = {
506
- sg_id: sim
507
- for sg_id, sim in sg_candidates_raw_similarity.items()
508
- if sim >= semantic_similarity_threshold
509
- }
510
- logger.info(
511
- "Post-thresholding: %d of %d candidates remaining (threshold: %.3f).",
512
- len(sg_candidates_raw_similarity),
513
- initial_candidate_count,
514
- semantic_similarity_threshold,
515
- )
516
-
517
- if not sg_candidates_raw_similarity:
518
- logger.info(
519
- "No StatementGroup candidates met the semantic similarity "
520
- "threshold of %.3f.",
521
- semantic_similarity_threshold,
522
- )
523
- if log_searches:
524
- duration_ms = (time.time() - overall_start_time) * 1000
525
- log_search_event_to_json(
526
- status="NO_CANDIDATES_POST_THRESHOLD",
527
- duration_ms=duration_ms,
528
- results_count=0,
529
- )
530
- return []
531
-
532
- candidate_sg_ids = list(sg_candidates_raw_similarity.keys())
533
- sg_objects_map: Dict[int, StatementGroup] = {}
534
- try:
535
- logger.debug(
536
- "Fetching StatementGroup details from DB for %d IDs...",
537
- len(candidate_sg_ids),
538
- )
539
- stmt = select(StatementGroup).where(StatementGroup.id.in_(candidate_sg_ids))
540
-
541
- if selected_packages:
542
- logger.info("Filtering search by packages: %s", selected_packages)
543
- package_filters_sqla = []
544
- for pkg_name in selected_packages:
545
- if pkg_name.strip():
546
- package_filters_sqla.append(
547
- StatementGroup.source_file.startswith(pkg_name.strip() + "/")
548
- )
549
-
550
- if package_filters_sqla:
551
- stmt = stmt.where(or_(*package_filters_sqla))
552
-
553
- stmt = stmt.options(joinedload(StatementGroup.primary_declaration))
554
- db_results = session.execute(stmt).scalars().unique().all()
555
- for sg_obj in db_results:
556
- sg_objects_map[sg_obj.id] = sg_obj
557
-
558
- logger.debug(
559
- "Fetched details for %d StatementGroups from DB that matched filters.",
560
- len(sg_objects_map),
561
- )
562
- final_candidate_ids_after_db_match = set(sg_objects_map.keys())
563
- original_faiss_candidate_ids = set(candidate_sg_ids)
564
-
565
- if len(final_candidate_ids_after_db_match) < len(original_faiss_candidate_ids):
566
- missing_from_db_or_filtered_out = (
567
- original_faiss_candidate_ids - final_candidate_ids_after_db_match
568
- )
569
- logger.info(
570
- "%d candidates from FAISS (post-threshold) were not found in DB "
571
- "or excluded by package filters: (e.g., %s).",
572
- len(missing_from_db_or_filtered_out),
573
- list(missing_from_db_or_filtered_out)[:5],
574
- )
575
-
576
- except SQLAlchemyError as e:
577
- logger.error(
578
- "Database query for StatementGroup details failed: %s", e, exc_info=True
579
- )
580
- if log_searches:
581
- duration_ms = (time.time() - overall_start_time) * 1000
582
- log_search_event_to_json(
583
- status="DB_FETCH_ERROR",
584
- duration_ms=duration_ms,
585
- results_count=0,
586
- error_type=type(e).__name__,
587
- )
588
- raise
589
-
590
- results_with_scores: List[Tuple[StatementGroup, Dict[str, float]]] = []
591
- candidate_semantic_similarities: List[float] = []
592
- candidate_pagerank_scores: List[float] = []
593
-
594
- processed_candidates_data: List[Dict[str, Any]] = []
595
-
596
- for sg_id in final_candidate_ids_after_db_match:
597
- sg_obj = sg_objects_map[sg_id]
598
- raw_sem_sim = sg_candidates_raw_similarity[sg_id]
599
-
600
- processed_candidates_data.append(
601
- {
602
- "sg_obj": sg_obj,
603
- "raw_sem_sim": raw_sem_sim,
604
- "original_pagerank": sg_obj.scaled_pagerank_score
605
- if sg_obj.scaled_pagerank_score is not None
606
- else 0.0,
607
- }
608
- )
609
- candidate_semantic_similarities.append(raw_sem_sim)
610
- candidate_pagerank_scores.append(
611
- sg_obj.scaled_pagerank_score
612
- if sg_obj.scaled_pagerank_score is not None
613
- else 0.0
614
- )
615
-
616
- if not processed_candidates_data:
617
- logger.info(
618
- "No candidates remaining after matching with DB data or other "
619
- "processing steps."
620
- )
621
- if log_searches:
622
- duration_ms = (time.time() - overall_start_time) * 1000
623
- log_search_event_to_json(
624
- status="NO_CANDIDATES_POST_PROCESSING",
625
- duration_ms=duration_ms,
626
- results_count=0,
627
- )
628
- return []
629
-
630
- stemmer = PorterStemmer()
631
-
632
- def _get_tokenized_list(text_to_tokenize: str) -> List[str]:
633
- if not text_to_tokenize:
634
- return []
635
- tokens = re.findall(r"\w+", text_to_tokenize.lower())
636
- return [stemmer.stem(token) for token in tokens]
637
-
638
- tokenized_query = _get_tokenized_list(query_string.strip())
639
- bm25_corpus: List[List[str]] = []
640
- for candidate_item_data in processed_candidates_data:
641
- sg_obj_for_corpus = candidate_item_data["sg_obj"]
642
- combined_text_for_bm25 = " ".join(
643
- filter(
644
- None,
645
- [
646
- (
647
- sg_obj_for_corpus.primary_declaration.lean_name
648
- if sg_obj_for_corpus.primary_declaration
649
- else None
650
- ),
651
- sg_obj_for_corpus.docstring,
652
- sg_obj_for_corpus.informal_description,
653
- sg_obj_for_corpus.informal_summary,
654
- sg_obj_for_corpus.display_statement_text,
655
- (
656
- sg_obj_for_corpus.primary_declaration.lean_name
657
- if sg_obj_for_corpus.primary_declaration
658
- else None
659
- ),
660
- (
661
- spacify_text(sg_obj_for_corpus.primary_declaration.source_file)
662
- if sg_obj_for_corpus.primary_declaration
663
- and sg_obj_for_corpus.primary_declaration.source_file
664
- else None
665
- ),
666
- ],
667
- )
668
- )
669
- bm25_corpus.append(_get_tokenized_list(combined_text_for_bm25))
670
-
671
- raw_bm25_scores_list: List[float] = [0.0] * len(processed_candidates_data)
672
- if tokenized_query and any(bm25_corpus):
673
- try:
674
- bm25_model = BM25Plus(bm25_corpus)
675
- raw_bm25_scores_list = bm25_model.get_scores(tokenized_query)
676
- raw_bm25_scores_list = [
677
- max(0.0, float(score)) for score in raw_bm25_scores_list
678
- ]
679
- except Exception as e:
680
- logger.warning(
681
- "BM25Plus scoring failed: %s. Word match scores defaulted to 0.",
682
- e,
683
- exc_info=False,
684
- )
685
- raw_bm25_scores_list = [0.0] * len(processed_candidates_data)
686
-
687
- min_sem_sim = (
688
- min(candidate_semantic_similarities) if candidate_semantic_similarities else 0.0
689
- )
690
- max_sem_sim = (
691
- max(candidate_semantic_similarities) if candidate_semantic_similarities else 0.0
692
- )
693
- range_sem_sim = max_sem_sim - min_sem_sim
694
- logger.debug(
695
- "Raw semantic similarity range for normalization: [%.4f, %.4f]",
696
- min_sem_sim,
697
- max_sem_sim,
698
- )
699
-
700
- min_pr = min(candidate_pagerank_scores) if candidate_pagerank_scores else 0.0
701
- max_pr = max(candidate_pagerank_scores) if candidate_pagerank_scores else 0.0
702
- range_pr = max_pr - min_pr
703
- logger.debug(
704
- "Original PageRank score range for normalization: [%.4f, %.4f]", min_pr, max_pr
705
- )
706
-
707
- min_bm25 = min(raw_bm25_scores_list) if raw_bm25_scores_list else 0.0
708
- max_bm25 = max(raw_bm25_scores_list) if raw_bm25_scores_list else 0.0
709
- range_bm25 = max_bm25 - min_bm25
710
- logger.debug(
711
- "Raw BM25 score range for normalization: [%.4f, %.4f]", min_bm25, max_bm25
712
- )
713
-
714
- for i, candidate_data in enumerate(processed_candidates_data):
715
- sg_obj = candidate_data["sg_obj"]
716
- current_raw_sem_sim = candidate_data["raw_sem_sim"]
717
- original_pagerank_score = candidate_data["original_pagerank"]
718
- original_bm25_score = raw_bm25_scores_list[i]
719
-
720
- norm_sem_sim = 0.5
721
- if candidate_semantic_similarities:
722
- if range_sem_sim > EPSILON:
723
- norm_sem_sim = (current_raw_sem_sim - min_sem_sim) / range_sem_sim
724
- elif (
725
- len(candidate_semantic_similarities) == 1
726
- and candidate_semantic_similarities[0] > EPSILON
727
- ):
728
- norm_sem_sim = 1.0
729
- elif (
730
- len(candidate_semantic_similarities) > 0
731
- and range_sem_sim <= EPSILON
732
- and max_sem_sim <= EPSILON
733
- ):
734
- norm_sem_sim = 0.0
735
- else:
736
- norm_sem_sim = 0.0
737
- norm_sem_sim = max(0.0, min(1.0, norm_sem_sim))
738
-
739
- norm_pagerank_score = 0.0
740
- if candidate_pagerank_scores:
741
- if range_pr > EPSILON:
742
- norm_pagerank_score = (original_pagerank_score - min_pr) / range_pr
743
- elif max_pr > EPSILON:
744
- norm_pagerank_score = 1.0
745
- norm_pagerank_score = max(0.0, min(1.0, norm_pagerank_score))
746
-
747
- norm_bm25_score = 0.0
748
- if raw_bm25_scores_list:
749
- if range_bm25 > EPSILON:
750
- norm_bm25_score = (original_bm25_score - min_bm25) / range_bm25
751
- elif max_bm25 > EPSILON:
752
- norm_bm25_score = 1.0
753
- norm_bm25_score = max(0.0, min(1.0, norm_bm25_score))
754
-
755
- weighted_norm_similarity = text_relevance_weight * norm_sem_sim
756
- weighted_norm_pagerank = pagerank_weight * norm_pagerank_score
757
- weighted_norm_bm25_score = name_match_weight * norm_bm25_score
758
-
759
- final_score = (
760
- weighted_norm_similarity + weighted_norm_pagerank + weighted_norm_bm25_score
761
- )
762
-
763
- score_dict = {
764
- "final_score": final_score,
765
- "raw_similarity": current_raw_sem_sim,
766
- "norm_similarity": norm_sem_sim,
767
- "original_pagerank_score": original_pagerank_score,
768
- "scaled_pagerank": norm_pagerank_score,
769
- "raw_word_match_score": original_bm25_score,
770
- "norm_word_match_score": norm_bm25_score,
771
- "weighted_norm_similarity": weighted_norm_similarity,
772
- "weighted_scaled_pagerank": weighted_norm_pagerank,
773
- "weighted_word_match_score": weighted_norm_bm25_score,
774
- }
775
- results_with_scores.append((sg_obj, score_dict))
776
-
777
- results_with_scores.sort(key=lambda item: item[1]["final_score"], reverse=True)
778
-
779
- final_status = "SUCCESS"
780
- results_count = len(results_with_scores)
781
- if not results_with_scores and processed_candidates_data:
782
- final_status = "NO_RESULTS_FINAL_SCORED"
783
- elif not results_with_scores and not processed_candidates_data:
784
- if not sg_candidates_raw_similarity:
785
- final_status = "NO_CANDIDATES_POST_THRESHOLD"
786
-
787
- if log_searches:
788
- duration_ms = (time.time() - overall_start_time) * 1000
789
- log_search_event_to_json(
790
- status=final_status, duration_ms=duration_ms, results_count=results_count
791
- )
792
-
793
- return results_with_scores
794
-
795
-
796
- def print_results(results: List[Tuple[StatementGroup, Dict[str, float]]]) -> None:
797
- """Formats and prints the search results to the console.
798
-
799
- Args:
800
- results: A list of tuples, each containing a StatementGroup
801
- object and its scores, sorted by final_score.
802
- """
803
- if not results:
804
- print("\nNo results found.")
805
- return
806
-
807
- print(f"\n--- Top {len(results)} Search Results (StatementGroups) ---")
808
- for i, (sg_obj, scores) in enumerate(results):
809
- primary_decl_name = (
810
- sg_obj.primary_declaration.lean_name
811
- if sg_obj.primary_declaration and sg_obj.primary_declaration.lean_name
812
- else "N/A"
813
- )
814
- print(
815
- f"\n{i + 1}. Lean Name: {primary_decl_name} (SG ID: {sg_obj.id})\n"
816
- f" Final Score: {scores['final_score']:.4f} ("
817
- f"NormSim*W: {scores['weighted_norm_similarity']:.4f}, "
818
- f"NormPR*W: {scores['weighted_scaled_pagerank']:.4f}, "
819
- f"NormWordMatch*W: {scores['weighted_word_match_score']:.4f})"
820
- )
821
- print(
822
- f" Scores: [NormSim: {scores['norm_similarity']:.4f} "
823
- f"(Raw: {scores['raw_similarity']:.4f}), "
824
- f"NormPR: {scores['scaled_pagerank']:.4f} "
825
- f"(Original: {scores['original_pagerank_score']:.4f}), "
826
- f"NormWordMatch: {scores['norm_word_match_score']:.4f} "
827
- f"(OriginalBM25: {scores['raw_word_match_score']:.2f})]"
828
- )
829
-
830
- lean_display = (
831
- sg_obj.display_statement_text or sg_obj.statement_text or "[No Lean code]"
832
- )
833
- lean_display_short = (
834
- (lean_display[:200] + "...") if len(lean_display) > 200 else lean_display
835
- )
836
- print(f" Lean Code: {lean_display_short.replace(NEWLINE, ' ')}")
837
-
838
- desc_display = (
839
- sg_obj.informal_description or sg_obj.docstring or "[No description]"
840
- )
841
- desc_display_short = (
842
- (desc_display[:150] + "...") if len(desc_display) > 150 else desc_display
843
- )
844
- print(f" Description: {desc_display_short.replace(NEWLINE, ' ')}")
845
-
846
- source_loc = sg_obj.source_file or "[No source file]"
847
- if source_loc.startswith("Mathlib/"):
848
- source_loc = source_loc[len("Mathlib/") :]
849
- print(f" File: {source_loc}:{sg_obj.range_start_line}")
850
-
851
- print("\n---------------------------------------------------")
852
-
853
-
854
- def parse_arguments() -> argparse.Namespace:
855
- """Parses command-line arguments for the search script.
856
-
857
- Returns:
858
- An object containing the parsed arguments.
859
- """
860
- parser = argparse.ArgumentParser(
861
- description="Search Lean StatementGroups using combined scoring.",
862
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
863
- )
864
- parser.add_argument("query", type=str, help="The search query string.")
865
- parser.add_argument(
866
- "--limit",
867
- "-n",
868
- type=int,
869
- default=None,
870
- help="Maximum number of final results to display. Overrides default if set.",
871
- )
872
- parser.add_argument(
873
- "--packages",
874
- metavar="PKG",
875
- type=str,
876
- nargs="*",
877
- default=None,
878
- help="Filter search results by specific package names (e.g., Mathlib Std). "
879
- "If not provided, searches all packages.",
880
- )
881
- return parser.parse_args()
882
-
883
-
884
- def main():
885
- """Main execution function for the search script."""
886
- args = parse_arguments()
887
-
888
- logger.info(
889
- "Using default configurations for paths and parameters from "
890
- "lean_explore.defaults."
891
- )
892
-
893
- db_url = defaults.DEFAULT_DB_URL
894
- embedding_model_name = defaults.DEFAULT_EMBEDDING_MODEL_NAME
895
- resolved_idx_path = str(defaults.DEFAULT_FAISS_INDEX_PATH.resolve())
896
- resolved_map_path = str(defaults.DEFAULT_FAISS_MAP_PATH.resolve())
897
-
898
- faiss_k_cand = defaults.DEFAULT_FAISS_K
899
- pr_weight = defaults.DEFAULT_PAGERANK_WEIGHT
900
- sem_sim_weight = defaults.DEFAULT_TEXT_RELEVANCE_WEIGHT
901
- name_match_w = defaults.DEFAULT_NAME_MATCH_WEIGHT
902
- results_disp_limit = (
903
- args.limit if args.limit is not None else defaults.DEFAULT_RESULTS_LIMIT
904
- )
905
- semantic_sim_thresh = defaults.DEFAULT_SEM_SIM_THRESHOLD
906
- faiss_nprobe_val = defaults.DEFAULT_FAISS_NPROBE
907
- faiss_oversampling_factor_val = defaults.DEFAULT_FAISS_OVERSAMPLING_FACTOR
908
-
909
- db_url_display = (
910
- f"...{str(defaults.DEFAULT_DB_PATH.resolve())[-30:]}"
911
- if len(str(defaults.DEFAULT_DB_PATH.resolve())) > 30
912
- else str(defaults.DEFAULT_DB_PATH.resolve())
913
- )
914
- logger.info("--- Starting Search (Direct Script Execution) ---")
915
- logger.info("Query: '%s'", args.query)
916
- logger.info("Displaying Top: %d results", results_disp_limit)
917
- if args.packages:
918
- logger.info("Filtering by user-specified packages: %s", args.packages)
919
- else:
920
- logger.info("No package filter specified, searching all packages.")
921
- logger.info("FAISS k (candidates): %d", faiss_k_cand)
922
- logger.info("FAISS nprobe (from defaults): %d", faiss_nprobe_val)
923
- logger.info(
924
- "FAISS Oversampling Factor (from defaults): %d", faiss_oversampling_factor_val
925
- )
926
- logger.info(
927
- "Semantic Similarity Threshold (from defaults): %.3f", semantic_sim_thresh
928
- )
929
- logger.info(
930
- "Weights -> NormTextSim: %.2f, NormPR: %.2f, NormWordMatch (BM25): %.2f",
931
- sem_sim_weight,
932
- pr_weight,
933
- name_match_w,
934
- )
935
- logger.info("Using FAISS index: %s", resolved_idx_path)
936
- logger.info("Using ID map: %s", resolved_map_path)
937
- logger.info("Database path: %s", db_url_display)
938
-
939
- try:
940
- _USER_LOGS_BASE_DIR.mkdir(parents=True, exist_ok=True)
941
- except OSError as e:
942
- logger.warning(
943
- f"Could not create user log directory {_USER_LOGS_BASE_DIR}: {e}"
944
- )
945
-
946
- engine = None
947
- try:
948
- s_transformer_model = load_embedding_model(embedding_model_name)
949
- if s_transformer_model is None:
950
- logger.error(
951
- "Sentence transformer model loading failed. Cannot proceed with search."
952
- )
953
- sys.exit(1)
954
-
955
- faiss_idx, id_map = load_faiss_assets(resolved_idx_path, resolved_map_path)
956
- if faiss_idx is None or id_map is None:
957
- logger.error(
958
- "Failed to load critical FAISS assets (index or ID map).\n"
959
- f"Expected at:\n Index path: {resolved_idx_path}\n"
960
- f" ID map path: {resolved_map_path}\n"
961
- "Please ensure these files exist or run 'leanexplore data fetch' "
962
- "to download the data toolchain."
963
- )
964
- sys.exit(1)
965
-
966
- is_file_db = db_url.startswith("sqlite:///")
967
- db_file_path = None
968
- if is_file_db:
969
- db_file_path_str = db_url[len("sqlite:///") :]
970
- db_file_path = pathlib.Path(db_file_path_str)
971
- if not db_file_path.exists():
972
- logger.error(
973
- f"Database file not found at the expected location: "
974
- f"{db_file_path}\n"
975
- "Please run 'leanexplore data fetch' to download the data "
976
- "toolchain."
977
- )
978
- sys.exit(1)
979
-
980
- engine = create_engine(db_url, echo=False)
981
- SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
982
-
983
- with SessionLocal() as session:
984
- ranked_results = perform_search(
985
- session=session,
986
- query_string=args.query,
987
- model=s_transformer_model,
988
- faiss_index=faiss_idx,
989
- text_chunk_id_map=id_map,
990
- faiss_k=faiss_k_cand,
991
- pagerank_weight=pr_weight,
992
- text_relevance_weight=sem_sim_weight,
993
- log_searches=True,
994
- name_match_weight=name_match_w,
995
- selected_packages=args.packages,
996
- semantic_similarity_threshold=semantic_sim_thresh,
997
- faiss_nprobe=faiss_nprobe_val,
998
- faiss_oversampling_factor=faiss_oversampling_factor_val,
999
- )
1000
-
1001
- print_results(ranked_results[:results_disp_limit])
1002
-
1003
- except FileNotFoundError as e:
1004
- logger.error(
1005
- f"A required file was not found: {e.filename}.\n"
1006
- "This could be an issue with configured paths or missing data.\n"
1007
- "If this relates to core data assets, please try running "
1008
- "'leanexplore data fetch'."
1009
- )
1010
- sys.exit(1)
1011
- except OperationalError as e_db:
1012
- is_file_db_op_err = defaults.DEFAULT_DB_URL.startswith("sqlite:///")
1013
- db_file_path_op_err = defaults.DEFAULT_DB_PATH
1014
- if is_file_db_op_err and (
1015
- "unable to open database file" in str(e_db).lower()
1016
- or (db_file_path_op_err and not db_file_path_op_err.exists())
1017
- ):
1018
- p = str(db_file_path_op_err.resolve())
1019
- logger.error(
1020
- f"Database connection failed: {e_db}\n"
1021
- f"The database file appears to be missing or inaccessible at: "
1022
- f"{p if db_file_path_op_err else 'Unknown Path'}\n"
1023
- "Please run 'leanexplore data fetch' to download or update the "
1024
- "data toolchain."
1025
- )
1026
- else:
1027
- logger.error(
1028
- f"Database connection/operational error: {e_db}", exc_info=True
1029
- )
1030
- sys.exit(1)
1031
- except SQLAlchemyError as e_sqla:
1032
- logger.error(
1033
- "A database error occurred during search: %s", e_sqla, exc_info=True
1034
- )
1035
- sys.exit(1)
1036
- except Exception as e_general:
1037
- logger.critical(
1038
- "An unexpected critical error occurred during search: %s",
1039
- e_general,
1040
- exc_info=True,
1041
- )
1042
- sys.exit(1)
1043
- finally:
1044
- if engine:
1045
- engine.dispose()
1046
- logger.debug("Database engine disposed.")
1047
-
1048
-
1049
- if __name__ == "__main__":
1050
- main()
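
For orientation, here is a minimal, self-contained sketch (not taken from either package version) of the score-combination step that the removed perform_search implemented: each signal (semantic similarity, PageRank, BM25) is min-max normalized over the current candidate set and then combined with configurable weights. The weight values and the simplified handling of the all-equal degenerate case below are illustrative placeholders, not the package's defaults.

from typing import Dict, List


def min_max_normalize(scores: List[float], eps: float = 1e-9) -> List[float]:
    """Scale a score list to [0, 1] across the current candidate set."""
    if not scores:
        return []
    lo, hi = min(scores), max(scores)
    if hi - lo <= eps:
        # Degenerate all-equal case, simplified here; the removed module used
        # slightly different per-signal defaults (e.g. 0.5 for semantic scores).
        return [1.0 if hi > eps else 0.0 for _ in scores]
    return [(s - lo) / (hi - lo) for s in scores]


def combine_scores(
    semantic: List[float],
    pagerank: List[float],
    bm25: List[float],
    w_sem: float = 0.6,   # placeholder for defaults.DEFAULT_TEXT_RELEVANCE_WEIGHT
    w_pr: float = 0.2,    # placeholder for defaults.DEFAULT_PAGERANK_WEIGHT
    w_bm25: float = 0.2,  # placeholder for defaults.DEFAULT_NAME_MATCH_WEIGHT
) -> List[Dict[str, float]]:
    """Return per-candidate score breakdowns mirroring perform_search's output."""
    norm_sem = min_max_normalize(semantic)
    norm_pr = min_max_normalize(pagerank)
    norm_bm = min_max_normalize(bm25)
    return [
        {
            "norm_similarity": s,
            "scaled_pagerank": p,
            "norm_word_match_score": b,
            "final_score": w_sem * s + w_pr * p + w_bm25 * b,
        }
        for s, p, b in zip(norm_sem, norm_pr, norm_bm)
    ]


if __name__ == "__main__":
    # Three hypothetical candidates: raw FAISS similarity, PageRank, raw BM25.
    breakdowns = combine_scores(
        [0.82, 0.74, 0.91], [0.01, 0.30, 0.02], [4.1, 0.0, 2.2]
    )
    for row in sorted(breakdowns, key=lambda r: r["final_score"], reverse=True):
        print(row)

In the 1.0.0 layout, the corresponding logic presumably lives in the new lean_explore/search/engine.py and lean_explore/search/scoring.py modules listed above.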