lean-explore 0.3.0-py3-none-any.whl → 1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lean_explore/__init__.py +14 -1
- lean_explore/api/__init__.py +12 -1
- lean_explore/api/client.py +64 -176
- lean_explore/cli/__init__.py +10 -1
- lean_explore/cli/data_commands.py +157 -479
- lean_explore/cli/display.py +171 -0
- lean_explore/cli/main.py +51 -608
- lean_explore/config.py +244 -0
- lean_explore/extract/__init__.py +5 -0
- lean_explore/extract/__main__.py +368 -0
- lean_explore/extract/doc_gen4.py +200 -0
- lean_explore/extract/doc_parser.py +499 -0
- lean_explore/extract/embeddings.py +371 -0
- lean_explore/extract/github.py +110 -0
- lean_explore/extract/index.py +317 -0
- lean_explore/extract/informalize.py +653 -0
- lean_explore/extract/package_config.py +59 -0
- lean_explore/extract/package_registry.py +45 -0
- lean_explore/extract/package_utils.py +105 -0
- lean_explore/extract/types.py +25 -0
- lean_explore/mcp/__init__.py +11 -1
- lean_explore/mcp/app.py +14 -46
- lean_explore/mcp/server.py +20 -35
- lean_explore/mcp/tools.py +70 -205
- lean_explore/models/__init__.py +9 -0
- lean_explore/models/search_db.py +76 -0
- lean_explore/models/search_types.py +53 -0
- lean_explore/search/__init__.py +32 -0
- lean_explore/search/engine.py +655 -0
- lean_explore/search/scoring.py +156 -0
- lean_explore/search/service.py +68 -0
- lean_explore/search/tokenization.py +71 -0
- lean_explore/util/__init__.py +28 -0
- lean_explore/util/embedding_client.py +92 -0
- lean_explore/util/logging.py +22 -0
- lean_explore/util/openrouter_client.py +63 -0
- lean_explore/util/reranker_client.py +189 -0
- {lean_explore-0.3.0.dist-info → lean_explore-1.0.0.dist-info}/METADATA +32 -9
- lean_explore-1.0.0.dist-info/RECORD +43 -0
- {lean_explore-0.3.0.dist-info → lean_explore-1.0.0.dist-info}/WHEEL +1 -1
- lean_explore-1.0.0.dist-info/entry_points.txt +2 -0
- lean_explore/cli/agent.py +0 -788
- lean_explore/cli/config_utils.py +0 -481
- lean_explore/defaults.py +0 -114
- lean_explore/local/__init__.py +0 -1
- lean_explore/local/search.py +0 -1050
- lean_explore/local/service.py +0 -479
- lean_explore/shared/__init__.py +0 -1
- lean_explore/shared/models/__init__.py +0 -1
- lean_explore/shared/models/api.py +0 -117
- lean_explore/shared/models/db.py +0 -396
- lean_explore-0.3.0.dist-info/RECORD +0 -26
- lean_explore-0.3.0.dist-info/entry_points.txt +0 -2
- {lean_explore-0.3.0.dist-info → lean_explore-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {lean_explore-0.3.0.dist-info → lean_explore-1.0.0.dist-info}/top_level.txt +0 -0
lean_explore/search/scoring.py
@@ -0,0 +1,156 @@
+"""Score normalization and fusion algorithms for search ranking.
+
+This module provides utilities for combining multiple retrieval signals into
+a unified ranking using techniques like Reciprocal Rank Fusion (RRF) and
+weighted score fusion.
+"""
+
+import difflib
+import math
+
+EPSILON = 1e-9
+
+
+def normalize_scores(scores: list[float]) -> list[float]:
+    """Min-max normalize scores to [0, 1] range.
+
+    Args:
+        scores: List of raw scores.
+
+    Returns:
+        List of normalized scores.
+    """
+    if not scores:
+        return []
+
+    min_score = min(scores)
+    max_score = max(scores)
+    score_range = max_score - min_score
+
+    if score_range < EPSILON:
+        if max_score > EPSILON:
+            return [1.0] * len(scores)
+        return [0.0] * len(scores)
+
+    return [(s - min_score) / score_range for s in scores]
+
+
+def normalize_dependency_counts(counts: list[int]) -> list[float]:
+    """Log-scale normalization for dependency counts.
+
+    Uses log(1 + count) / log(1 + max_count) to compress the range
+    and give more credit to items with moderate dependency counts.
+
+    Args:
+        counts: List of dependency counts.
+
+    Returns:
+        List of normalized scores in [0, 1] range.
+    """
+    if not counts:
+        return []
+
+    max_count = max(counts)
+    if max_count == 0:
+        return [0.0] * len(counts)
+
+    log_max = math.log(1 + max_count)
+    return [math.log(1 + c) / log_max for c in counts]
+
+
+def compute_ranks(scores: list[float]) -> list[int]:
+    """Compute ranks for a list of scores (1-indexed, higher score = lower rank).
+
+    Candidates with score 0 get rank len(scores)+1 (worst possible).
+
+    Args:
+        scores: List of raw scores.
+
+    Returns:
+        List of ranks (1 = best).
+    """
+    n = len(scores)
+    indexed = [(i, s) for i, s in enumerate(scores)]
+    indexed.sort(key=lambda x: x[1], reverse=True)
+
+    ranks = [0] * n
+    for rank, (idx, score) in enumerate(indexed, 1):
+        if score > 0:
+            ranks[idx] = rank
+        else:
+            ranks[idx] = n + 1
+
+    return ranks
+
+
+def reciprocal_rank_fusion(rank_lists: list[list[int]], k: int = 0) -> list[float]:
+    """Compute RRF scores from multiple rank lists.
+
+    RRF(d) = sum(1 / (k + rank_i(d)) for each signal i)
+
+    Args:
+        rank_lists: List of rank lists, one per signal.
+        k: Constant to prevent top rank from dominating. Default 0 means 1/rank.
+
+    Returns:
+        List of RRF scores for each candidate.
+    """
+    n = len(rank_lists[0])
+    rrf_scores = []
+
+    for i in range(n):
+        score = sum(1.0 / (k + ranks[i]) for ranks in rank_lists)
+        rrf_scores.append(score)
+
+    return rrf_scores
+
+
+def weighted_score_fusion(
+    score_lists: list[list[float]],
+    weights: list[float],
+) -> list[float]:
+    """Combine multiple score lists using weighted normalized scores.
+
+    Each score list is normalized to [0, 1] using min-max scaling,
+    then combined with the given weights.
+
+    Args:
+        score_lists: List of score lists, one per signal.
+        weights: Weight for each signal (should sum to 1.0 for interpretability).
+
+    Returns:
+        List of combined scores for each candidate.
+    """
+    if not score_lists:
+        return []
+
+    n = len(score_lists[0])
+    if n == 0:
+        return []
+
+    normalized_lists = [normalize_scores(scores) for scores in score_lists]
+
+    combined = []
+    for i in range(n):
+        score = sum(w * normalized_lists[j][i] for j, w in enumerate(weights))
+        combined.append(score)
+
+    return combined
+
+
+def fuzzy_name_score(query: str, name: str) -> float:
+    """Compute fuzzy match score between query and declaration name.
+
+    Normalizes both strings (dots/underscores -> spaces) and uses
+    SequenceMatcher ratio for character-level similarity.
+
+    Args:
+        query: Search query string.
+        name: Declaration name to match against.
+
+    Returns:
+        Similarity score between 0 and 1.
+    """
+    normalized_query = query.lower().replace(".", " ").replace("_", " ")
+    normalized_name = name.lower().replace(".", " ").replace("_", " ")
+    return difflib.SequenceMatcher(None, normalized_query, normalized_name).ratio()
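
A minimal sketch of how these fusion helpers compose; the scores below are made-up values for three candidates, not output from the package:

```python
from lean_explore.search.scoring import (
    compute_ranks,
    fuzzy_name_score,
    reciprocal_rank_fusion,
    weighted_score_fusion,
)

# Made-up raw scores for three candidates from two signals.
semantic = [0.82, 0.40, 0.65]  # e.g. embedding similarity
keyword = [3.1, 0.0, 1.7]      # e.g. BM25

# Rank-based fusion: zero-scored candidates are pushed to the worst rank (n + 1).
ranks = [compute_ranks(semantic), compute_ranks(keyword)]
print(reciprocal_rank_fusion(ranks))  # [2.0, 0.58..., 1.0]

# Score-based fusion: each list is min-max normalized, then weighted.
print(weighted_score_fusion([semantic, keyword], weights=[0.7, 0.3]))

# Name matching after dot/underscore normalization.
print(fuzzy_name_score("nat add comm", "Nat.add_comm"))  # 1.0
```

Note that the default `k=0` gives plain 1/rank weighting; the constant most often cited for RRF is `k=60`, which a caller can pass explicitly.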
lean_explore/search/service.py
@@ -0,0 +1,68 @@
+"""Service layer for search operations."""
+
+import time
+
+from lean_explore.models import SearchResponse, SearchResult
+from lean_explore.search.engine import SearchEngine
+
+
+class Service:
+    """Service wrapper for search operations.
+
+    Provides a clean interface for searching and retrieving declarations.
+    """
+
+    def __init__(self, engine: SearchEngine | None = None):
+        """Initialize the search service.
+
+        Args:
+            engine: SearchEngine instance. Defaults to new engine.
+        """
+        self.engine = engine or SearchEngine()
+
+    async def search(
+        self,
+        query: str,
+        limit: int = 20,
+        rerank_top: int | None = 50,
+        packages: list[str] | None = None,
+    ) -> SearchResponse:
+        """Search for Lean declarations.
+
+        Args:
+            query: Search query string.
+            limit: Maximum number of results to return.
+            rerank_top: Number of candidates to rerank with cross-encoder.
+            packages: Filter results to specific packages (e.g., ["Mathlib"]).
+
+        Returns:
+            SearchResponse containing results and metadata.
+        """
+        start_time = time.time()
+
+        results = await self.engine.search(
+            query=query,
+            limit=limit,
+            rerank_top=rerank_top,
+            packages=packages,
+        )
+
+        processing_time_ms = int((time.time() - start_time) * 1000)
+
+        return SearchResponse(
+            query=query,
+            results=results,
+            count=len(results),
+            processing_time_ms=processing_time_ms,
+        )
+
+    async def get_by_id(self, declaration_id: int) -> SearchResult | None:
+        """Retrieve a declaration by ID.
+
+        Args:
+            declaration_id: The declaration ID.
+
+        Returns:
+            SearchResult if found, None otherwise.
+        """
+        return await self.engine.get_by_id(declaration_id)
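
A minimal async usage sketch for `Service`; it assumes a local search index is already available to `SearchEngine()`, and only accesses `SearchResponse` fields shown in this diff:

```python
import asyncio

from lean_explore.search.service import Service


async def main() -> None:
    service = Service()  # wraps a default SearchEngine
    response = await service.search(
        "sum of two even numbers is even",
        limit=5,
        packages=["Mathlib"],
    )
    print(f"{response.count} results in {response.processing_time_ms} ms")


asyncio.run(main())
```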
lean_explore/search/tokenization.py
@@ -0,0 +1,71 @@
+"""Text tokenization utilities for search indexing and querying.
+
+This module provides tokenization strategies for Lean declaration names,
+supporting both spaced tokenization (splits on dots, underscores, camelCase)
+and raw tokenization (preserves structure for exact matching).
+"""
+
+import re
+
+
+def tokenize_spaced(text: str) -> list[str]:
+    """Tokenize text with spacing on dots, underscores, and camelCase.
+
+    Args:
+        text: Input text to tokenize.
+
+    Returns:
+        List of lowercase word tokens.
+    """
+    if not text:
+        return []
+    # Replace dots and underscores with spaces
+    text = text.replace(".", " ").replace("_", " ")
+    # Split camelCase: insert space before uppercase letters
+    text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
+    return re.findall(r"\w+", text.lower())
+
+
+def tokenize_raw(text: str) -> list[str]:
+    """Tokenize text as single token (preserves dots).
+
+    Args:
+        text: Input text to tokenize.
+
+    Returns:
+        List with the full text as a single lowercase token.
+    """
+    if not text:
+        return []
+    return [text.lower()]
+
+
+def tokenize_words(text: str) -> list[str]:
+    """Simple word tokenization for natural language text.
+
+    Splits on whitespace and punctuation, returns lowercase tokens.
+
+    Args:
+        text: Input text to tokenize.
+
+    Returns:
+        List of lowercase word tokens.
+    """
+    if not text:
+        return []
+    return [w.lower() for w in re.findall(r"\w+", text)]
+
+
+def is_autogenerated(name: str) -> bool:
+    """Check if a declaration name is auto-generated by Lean.
+
+    Auto-generated declarations include:
+    - .mk constructors (e.g., Nat.mk)
+
+    Args:
+        name: Fully qualified declaration name.
+
+    Returns:
+        True if the declaration is auto-generated.
+    """
+    return name.endswith(".mk")
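
For example, the strategies differ as follows on typical declaration names (illustrative inputs only):

```python
from lean_explore.search.tokenization import (
    is_autogenerated,
    tokenize_raw,
    tokenize_spaced,
    tokenize_words,
)

print(tokenize_spaced("UniformSpace.toTopologicalSpace"))
# ['uniform', 'space', 'to', 'topological', 'space']

print(tokenize_raw("Nat.succ_le_iff"))            # ['nat.succ_le_iff']
print(tokenize_words("n is even iff n % 2 = 0"))  # ['n', 'is', 'even', 'iff', 'n', '2', '0']
print(is_autogenerated("Prod.mk"))                # True
```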
lean_explore/util/__init__.py
@@ -0,0 +1,28 @@
+"""Shared utilities for lean_explore.
+
+Imports are lazy to avoid loading torch when not needed.
+"""
+
+
+def __getattr__(name: str):
+    """Lazy import attributes to avoid loading torch unnecessarily."""
+    if name == "EmbeddingClient":
+        from lean_explore.util.embedding_client import EmbeddingClient
+
+        return EmbeddingClient
+    if name == "RerankerClient":
+        from lean_explore.util.reranker_client import RerankerClient
+
+        return RerankerClient
+    if name == "OpenRouterClient":
+        from lean_explore.util.openrouter_client import OpenRouterClient
+
+        return OpenRouterClient
+    if name == "setup_logging":
+        from lean_explore.util.logging import setup_logging
+
+        return setup_logging
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+__all__ = ["EmbeddingClient", "RerankerClient", "OpenRouterClient", "setup_logging"]
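
A short sketch of the lazy loading (module-level `__getattr__`, PEP 562): attribute access triggers the submodule import, so the torch-backed clients stay unloaded until first use.

```python
import lean_explore.util as util

# Resolving the name imports lean_explore.util.logging only at this point;
# EmbeddingClient / RerankerClient (and torch) are untouched until first access.
setup_logging = util.setup_logging
setup_logging(verbose=True)
```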
lean_explore/util/embedding_client.py
@@ -0,0 +1,92 @@
+"""Embedding generation client using sentence transformers."""
+
+import asyncio
+import logging
+
+import torch
+from pydantic import BaseModel
+from sentence_transformers import SentenceTransformer
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingResponse(BaseModel):
+    """Response from embedding generation."""
+
+    texts: list[str]
+    """Original input texts."""
+
+    embeddings: list[list[float]]
+    """List of embeddings (one per input text)."""
+
+    model: str
+    """Model name used for generation."""
+
+
+class EmbeddingClient:
+    """Client for generating text embeddings."""
+
+    def __init__(
+        self, model_name: str, device: str | None = None, max_length: int | None = None
+    ):
+        """Initialize the embedding client.
+
+        Args:
+            model_name: Name of the sentence transformer model
+            device: Device to use ("cuda", "mps", "cpu"). Auto-detects if None.
+            max_length: Maximum sequence length for tokenization. If None, uses
+                model default. Lower values reduce memory usage.
+        """
+        self.model_name = model_name
+        self.device = device or self._select_device()
+        self.max_length = max_length
+        logger.info(f"Loading embedding model {model_name} on {self.device}")
+        self.model = SentenceTransformer(model_name, device=self.device)
+
+        # Set max sequence length if specified
+        if max_length is not None:
+            self.model.max_seq_length = max_length
+            logger.info(f"Set max sequence length to {max_length}")
+
+    def _select_device(self) -> str:
+        """Select best available device."""
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+        return "cpu"
+
+    async def embed(
+        self, texts: list[str], is_query: bool = False
+    ) -> EmbeddingResponse:
+        """Generate embeddings for a list of texts.
+
+        Args:
+            texts: List of text strings to embed
+            is_query: If True, encode as search queries using the model's query
+                prompt. If False (default), encode as documents without prompt.
+                Qwen3-Embedding models are asymmetric and perform better when
+                queries use prompt_name="query".
+
+        Returns:
+            EmbeddingResponse with texts, embeddings, and model info
+        """
+        loop = asyncio.get_event_loop()
+
+        def _encode():
+            # Use query prompt for search queries, no prompt for documents
+            encode_kwargs = {
+                "show_progress_bar": False,
+                "convert_to_numpy": True,
+                "batch_size": 256,  # Larger batches for GPU utilization
+            }
+            if is_query:
+                encode_kwargs["prompt_name"] = "query"
+            return self.model.encode(texts, **encode_kwargs)
+
+        embeddings = await loop.run_in_executor(None, _encode)
+        return EmbeddingResponse(
+            texts=texts,
+            embeddings=[emb.tolist() for emb in embeddings],
+            model=self.model_name,
+        )
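
An illustrative caller; the model id below is a placeholder (this diff does not pin a default embedding model, though the docstring references Qwen3-Embedding-style query prompts):

```python
import asyncio

from lean_explore.util.embedding_client import EmbeddingClient


async def main() -> None:
    # Placeholder model id; any sentence-transformers model works, and
    # Qwen3-Embedding models additionally support the "query" prompt.
    client = EmbeddingClient("Qwen/Qwen3-Embedding-0.6B", max_length=512)

    docs = await client.embed(["Commutativity of addition on Nat."])
    query = await client.embed(["is addition commutative?"], is_query=True)

    print(len(docs.embeddings[0]), query.model)


asyncio.run(main())
```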
lean_explore/util/logging.py
@@ -0,0 +1,22 @@
+"""Logging utilities for lean_explore."""
+
+import logging
+import sys
+
+
+def setup_logging(verbose: bool = False) -> None:
+    """Configure logging for lean_explore applications.
+
+    Args:
+        verbose: If True, set level to DEBUG; otherwise INFO.
+    """
+    level = logging.DEBUG if verbose else logging.INFO
+    logging.basicConfig(
+        level=level,
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+
+    # Suppress noisy third-party loggers
+    logging.getLogger("httpx").setLevel(logging.WARNING)
+    logging.getLogger("httpcore").setLevel(logging.WARNING)
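
Typical use at an entry point (illustrative):

```python
from lean_explore.util.logging import setup_logging

setup_logging(verbose=False)  # INFO to stdout; httpx/httpcore capped at WARNING
```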
lean_explore/util/openrouter_client.py
@@ -0,0 +1,63 @@
+"""OpenRouter API wrapper using OpenAI SDK types.
+
+Provides a client class that wraps OpenRouter's API and returns OpenAI SDK-compatible
+types for easy integration.
+"""
+
+from __future__ import annotations
+
+import os
+
+from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
+from tenacity import retry, stop_after_attempt, wait_exponential
+
+
+class OpenRouterClient:
+    """Client for interacting with OpenRouter API using OpenAI SDK types."""
+
+    def __init__(self):
+        """Initialize OpenRouter client.
+
+        Reads API key from OPENROUTER_API_KEY environment variable.
+        """
+        api_key = os.getenv("OPENROUTER_API_KEY")
+        if not api_key:
+            raise ValueError("OPENROUTER_API_KEY environment variable not set")
+
+        self.client = AsyncOpenAI(
+            base_url="https://openrouter.ai/api/v1",
+            api_key=api_key,
+        )
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=2, max=10),
+    )
+    async def generate(
+        self,
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+        **kwargs,
+    ) -> ChatCompletion:
+        """Generate a chat completion using OpenRouter.
+
+        Args:
+            model: Model name (e.g., "anthropic/claude-3.5-sonnet")
+            messages: List of message dicts with "role" and "content"
+            temperature: Sampling temperature
+            max_tokens: Maximum tokens to generate
+            **kwargs: Additional parameters to pass to the API
+
+        Returns:
+            ChatCompletion object from OpenAI SDK
+        """
+        return await self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            **kwargs,
+        )
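
An illustrative call; it requires `OPENROUTER_API_KEY` in the environment, and the model id simply follows the docstring's example:

```python
import asyncio

from lean_explore.util.openrouter_client import OpenRouterClient


async def main() -> None:
    client = OpenRouterClient()  # raises if OPENROUTER_API_KEY is unset
    completion = await client.generate(
        model="anthropic/claude-3.5-sonnet",
        messages=[{"role": "user", "content": "State Nat.add_comm informally."}],
        temperature=0.2,
        max_tokens=100,
    )
    print(completion.choices[0].message.content)


asyncio.run(main())
```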
lean_explore/util/reranker_client.py
@@ -0,0 +1,189 @@
+"""Reranker client using Qwen3-Reranker for query-document scoring."""
+
+import asyncio
+import logging
+
+import torch
+from pydantic import BaseModel
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_INSTRUCTION = "Find relevant Lean 4 math declarations"
+
+
+class RerankerResponse(BaseModel):
+    """Response from reranking operation."""
+
+    query: str
+    """The original query."""
+
+    scores: list[float]
+    """Relevance scores for each document (same order as input)."""
+
+    model: str
+    """Model name used for reranking."""
+
+
+class RerankerClient:
+    """Client for reranking query-document pairs using Qwen3-Reranker."""
+
+    def __init__(
+        self,
+        model_name: str = "Qwen/Qwen3-Reranker-0.6B",
+        device: str | None = None,
+        max_length: int = 512,
+        instruction: str = DEFAULT_INSTRUCTION,
+    ):
+        """Initialize the reranker client.
+
+        Args:
+            model_name: Name of the reranker model from HuggingFace.
+            device: Device to use ("cuda", "mps", "cpu"). Auto-detects if None.
+            max_length: Maximum sequence length for tokenization.
+            instruction: Task instruction prepended to each query-document pair.
+        """
+        self.model_name = model_name
+        self.device = device or self._select_device()
+        self.max_length = max_length
+        self.instruction = instruction
+
+        logger.info(f"Loading reranker model {model_name} on {self.device}")
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name, padding_side="left", trust_remote_code=True
+        )
+
+        # Use float32 on CPU (faster than float16 which gets emulated)
+        # Use float16 on GPU for memory efficiency
+        dtype = torch.float16 if self.device == "cuda" else torch.float32
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype=dtype, trust_remote_code=True
+        ).to(self.device)
+        self.model.eval()
+
+        # Get token IDs for true/false classification
+        self._token_true_id = self.tokenizer.convert_tokens_to_ids("true")
+        self._token_false_id = self.tokenizer.convert_tokens_to_ids("false")
+
+        logger.info("Reranker model loaded successfully")
+
+    def _select_device(self) -> str:
+        """Select best available device."""
+        if torch.cuda.is_available():
+            return "cuda"
+        return "cpu"
+
+    def _format_pair(self, query: str, document: str) -> str:
+        """Format a query-document pair with instruction.
+
+        Args:
+            query: The search query.
+            document: The document text to score.
+
+        Returns:
+            Formatted string for the reranker model.
+        """
+        return (
+            f"<Instruct>: {self.instruction}\n"
+            f"<Query>: {query}\n"
+            f"<Document>: {document}"
+        )
+
+    @torch.no_grad()
+    def _compute_scores_sync(self, pairs: list[str]) -> list[float]:
+        """Compute relevance scores for formatted pairs synchronously.
+
+        Args:
+            pairs: List of formatted query-document strings.
+
+        Returns:
+            List of relevance scores in [0, 1].
+        """
+        inputs = self.tokenizer(
+            pairs,
+            padding=True,
+            truncation=True,
+            max_length=self.max_length,
+            return_tensors="pt",
+        ).to(self.device)
+
+        # Get logits for last token
+        outputs = self.model(**inputs)
+        logits = outputs.logits[:, -1, :]
+
+        # Extract true/false logits
+        true_logits = logits[:, self._token_true_id]
+        false_logits = logits[:, self._token_false_id]
+
+        # Compute probability of "true" using softmax
+        stacked = torch.stack([false_logits, true_logits], dim=1)
+        log_probs = torch.nn.functional.log_softmax(stacked, dim=1)
+        scores = log_probs[:, 1].exp()
+
+        return scores.cpu().tolist()
+
+    def rerank_sync(
+        self,
+        query: str,
+        documents: list[str],
+    ) -> RerankerResponse:
+        """Rerank documents synchronously (faster for small batches).
+
+        Args:
+            query: The search query.
+            documents: List of document texts to rerank.
+
+        Returns:
+            RerankerResponse with scores for each document.
+        """
+        if not documents:
+            return RerankerResponse(query=query, scores=[], model=self.model_name)
+
+        pairs = [self._format_pair(query, doc) for doc in documents]
+        scores = self._compute_scores_sync(pairs)
+        return RerankerResponse(query=query, scores=scores, model=self.model_name)
+
+    async def rerank(
+        self,
+        query: str,
+        documents: list[str],
+        batch_size: int | None = None,
+    ) -> RerankerResponse:
+        """Rerank documents by relevance to query.
+
+        Args:
+            query: The search query.
+            documents: List of document texts to rerank.
+            batch_size: Number of pairs to process at once.
+
+        Returns:
+            RerankerResponse with scores for each document.
+        """
+        if not documents:
+            return RerankerResponse(query=query, scores=[], model=self.model_name)
+
+        # Default batch size: 16 on GPU (fits 8GB VRAM), 32 on CPU
+        if batch_size is None:
+            batch_size = 16 if self.device == "cuda" else 32
+
+        # For small batches, run synchronously to avoid executor overhead
+        if len(documents) <= batch_size:
+            return self.rerank_sync(query, documents)
+
+        # Format all pairs
+        pairs = [self._format_pair(query, doc) for doc in documents]
+
+        # Process in batches
+        loop = asyncio.get_event_loop()
+        all_scores: list[float] = []
+
+        for i in range(0, len(pairs), batch_size):
+            batch = pairs[i : i + batch_size]
+            batch_scores = await loop.run_in_executor(
+                None, self._compute_scores_sync, batch
+            )
+            all_scores.extend(batch_scores)
+
+        return RerankerResponse(query=query, scores=all_scores, model=self.model_name)
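
An illustrative reranking call using the default `Qwen/Qwen3-Reranker-0.6B` weights (the documents are made up):

```python
import asyncio

from lean_explore.util.reranker_client import RerankerClient


async def main() -> None:
    reranker = RerankerClient()  # downloads Qwen/Qwen3-Reranker-0.6B on first use
    docs = [
        "Nat.add_comm: addition of natural numbers is commutative",
        "List.length_append: length of the concatenation of two lists",
    ]
    response = await reranker.rerank("commutativity of addition", docs)

    # Each score is P("true") for the (query, document) pair, so higher is better.
    for doc, score in sorted(zip(docs, response.scores), key=lambda t: t[1], reverse=True):
        print(f"{score:.3f}  {doc}")


asyncio.run(main())
```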