lean-explore 0.3.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. lean_explore/__init__.py +14 -1
  2. lean_explore/api/__init__.py +12 -1
  3. lean_explore/api/client.py +64 -176
  4. lean_explore/cli/__init__.py +10 -1
  5. lean_explore/cli/data_commands.py +184 -489
  6. lean_explore/cli/display.py +171 -0
  7. lean_explore/cli/main.py +51 -608
  8. lean_explore/config.py +244 -0
  9. lean_explore/extract/__init__.py +5 -0
  10. lean_explore/extract/__main__.py +368 -0
  11. lean_explore/extract/doc_gen4.py +200 -0
  12. lean_explore/extract/doc_parser.py +499 -0
  13. lean_explore/extract/embeddings.py +369 -0
  14. lean_explore/extract/github.py +110 -0
  15. lean_explore/extract/index.py +316 -0
  16. lean_explore/extract/informalize.py +653 -0
  17. lean_explore/extract/package_config.py +59 -0
  18. lean_explore/extract/package_registry.py +45 -0
  19. lean_explore/extract/package_utils.py +105 -0
  20. lean_explore/extract/types.py +25 -0
  21. lean_explore/mcp/__init__.py +11 -1
  22. lean_explore/mcp/app.py +14 -46
  23. lean_explore/mcp/server.py +20 -35
  24. lean_explore/mcp/tools.py +71 -205
  25. lean_explore/models/__init__.py +9 -0
  26. lean_explore/models/search_db.py +76 -0
  27. lean_explore/models/search_types.py +53 -0
  28. lean_explore/search/__init__.py +32 -0
  29. lean_explore/search/engine.py +651 -0
  30. lean_explore/search/scoring.py +156 -0
  31. lean_explore/search/service.py +68 -0
  32. lean_explore/search/tokenization.py +71 -0
  33. lean_explore/util/__init__.py +28 -0
  34. lean_explore/util/embedding_client.py +92 -0
  35. lean_explore/util/logging.py +22 -0
  36. lean_explore/util/openrouter_client.py +63 -0
  37. lean_explore/util/reranker_client.py +187 -0
  38. {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/METADATA +32 -9
  39. lean_explore-1.0.1.dist-info/RECORD +43 -0
  40. {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/WHEEL +1 -1
  41. lean_explore-1.0.1.dist-info/entry_points.txt +2 -0
  42. lean_explore/cli/agent.py +0 -788
  43. lean_explore/cli/config_utils.py +0 -481
  44. lean_explore/defaults.py +0 -114
  45. lean_explore/local/__init__.py +0 -1
  46. lean_explore/local/search.py +0 -1050
  47. lean_explore/local/service.py +0 -479
  48. lean_explore/shared/__init__.py +0 -1
  49. lean_explore/shared/models/__init__.py +0 -1
  50. lean_explore/shared/models/api.py +0 -117
  51. lean_explore/shared/models/db.py +0 -396
  52. lean_explore-0.3.0.dist-info/RECORD +0 -26
  53. lean_explore-0.3.0.dist-info/entry_points.txt +0 -2
  54. {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/licenses/LICENSE +0 -0
  55. {lean_explore-0.3.0.dist-info → lean_explore-1.0.1.dist-info}/top_level.txt +0 -0

lean_explore/search/scoring.py
@@ -0,0 +1,156 @@
+ """Score normalization and fusion algorithms for search ranking.
+
+ This module provides utilities for combining multiple retrieval signals into
+ a unified ranking using techniques like Reciprocal Rank Fusion (RRF) and
+ weighted score fusion.
+ """
+
+ import difflib
+ import math
+
+ EPSILON = 1e-9
+
+
+ def normalize_scores(scores: list[float]) -> list[float]:
+     """Min-max normalize scores to [0, 1] range.
+
+     Args:
+         scores: List of raw scores.
+
+     Returns:
+         List of normalized scores.
+     """
+     if not scores:
+         return []
+
+     min_score = min(scores)
+     max_score = max(scores)
+     score_range = max_score - min_score
+
+     if score_range < EPSILON:
+         if max_score > EPSILON:
+             return [1.0] * len(scores)
+         return [0.0] * len(scores)
+
+     return [(s - min_score) / score_range for s in scores]
+
+
+ def normalize_dependency_counts(counts: list[int]) -> list[float]:
+     """Log-scale normalization for dependency counts.
+
+     Uses log(1 + count) / log(1 + max_count) to compress the range
+     and give more credit to items with moderate dependency counts.
+
+     Args:
+         counts: List of dependency counts.
+
+     Returns:
+         List of normalized scores in [0, 1] range.
+     """
+     if not counts:
+         return []
+
+     max_count = max(counts)
+     if max_count == 0:
+         return [0.0] * len(counts)
+
+     log_max = math.log(1 + max_count)
+     return [math.log(1 + c) / log_max for c in counts]
+
+
+ def compute_ranks(scores: list[float]) -> list[int]:
+     """Compute ranks for a list of scores (1-indexed, higher score = lower rank).
+
+     Candidates with score 0 get rank len(scores)+1 (worst possible).
+
+     Args:
+         scores: List of raw scores.
+
+     Returns:
+         List of ranks (1 = best).
+     """
+     n = len(scores)
+     indexed = [(i, s) for i, s in enumerate(scores)]
+     indexed.sort(key=lambda x: x[1], reverse=True)
+
+     ranks = [0] * n
+     for rank, (idx, score) in enumerate(indexed, 1):
+         if score > 0:
+             ranks[idx] = rank
+         else:
+             ranks[idx] = n + 1
+
+     return ranks
+
+
+ def reciprocal_rank_fusion(rank_lists: list[list[int]], k: int = 0) -> list[float]:
+     """Compute RRF scores from multiple rank lists.
+
+     RRF(d) = sum(1 / (k + rank_i(d)) for each signal i)
+
+     Args:
+         rank_lists: List of rank lists, one per signal.
+         k: Constant to prevent top rank from dominating. Default 0 means 1/rank.
+
+     Returns:
+         List of RRF scores for each candidate.
+     """
+     n = len(rank_lists[0])
+     rrf_scores = []
+
+     for i in range(n):
+         score = sum(1.0 / (k + ranks[i]) for ranks in rank_lists)
+         rrf_scores.append(score)
+
+     return rrf_scores
+
+
+ def weighted_score_fusion(
+     score_lists: list[list[float]],
+     weights: list[float],
+ ) -> list[float]:
+     """Combine multiple score lists using weighted normalized scores.
+
+     Each score list is normalized to [0, 1] using min-max scaling,
+     then combined with the given weights.
+
+     Args:
+         score_lists: List of score lists, one per signal.
+         weights: Weight for each signal (should sum to 1.0 for interpretability).
+
+     Returns:
+         List of combined scores for each candidate.
+     """
+     if not score_lists:
+         return []
+
+     n = len(score_lists[0])
+     if n == 0:
+         return []
+
+     normalized_lists = [normalize_scores(scores) for scores in score_lists]
+
+     combined = []
+     for i in range(n):
+         score = sum(w * normalized_lists[j][i] for j, w in enumerate(weights))
+         combined.append(score)
+
+     return combined
+
+
+ def fuzzy_name_score(query: str, name: str) -> float:
+     """Compute fuzzy match score between query and declaration name.
+
+     Normalizes both strings (dots/underscores -> spaces) and uses
+     SequenceMatcher ratio for character-level similarity.
+
+     Args:
+         query: Search query string.
+         name: Declaration name to match against.
+
+     Returns:
+         Similarity score between 0 and 1.
+     """
+     normalized_query = query.lower().replace(".", " ").replace("_", " ")
+     normalized_name = name.lower().replace(".", " ").replace("_", " ")
+     return difflib.SequenceMatcher(None, normalized_query, normalized_name).ratio()
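
A worked example (not part of the package) makes the fusion arithmetic above concrete: compute_ranks turns each signal's raw scores into 1-indexed ranks, and reciprocal_rank_fusion with the default k=0 sums their reciprocals per candidate. The signal names and score values below are made up for illustration.

    from lean_explore.search.scoring import compute_ranks, reciprocal_rank_fusion

    bm25_scores = [12.0, 0.0, 7.5]           # signal 1 (a zero score gets the worst rank, n+1)
    embedding_scores = [0.31, 0.42, 0.28]    # signal 2

    bm25_ranks = compute_ranks(bm25_scores)            # [1, 4, 2]
    embedding_ranks = compute_ranks(embedding_scores)  # [2, 1, 3]

    # With k=0: candidate 0 -> 1/1 + 1/2 = 1.5, candidate 1 -> 1/4 + 1/1 = 1.25,
    # candidate 2 -> 1/2 + 1/3 ≈ 0.833
    print(reciprocal_rank_fusion([bm25_ranks, embedding_ranks]))
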

lean_explore/search/service.py
@@ -0,0 +1,68 @@
+ """Service layer for search operations."""
+
+ import time
+
+ from lean_explore.models import SearchResponse, SearchResult
+ from lean_explore.search.engine import SearchEngine
+
+
+ class Service:
+     """Service wrapper for search operations.
+
+     Provides a clean interface for searching and retrieving declarations.
+     """
+
+     def __init__(self, engine: SearchEngine | None = None):
+         """Initialize the search service.
+
+         Args:
+             engine: SearchEngine instance. Defaults to new engine.
+         """
+         self.engine = engine or SearchEngine()
+
+     async def search(
+         self,
+         query: str,
+         limit: int = 20,
+         rerank_top: int | None = 50,
+         packages: list[str] | None = None,
+     ) -> SearchResponse:
+         """Search for Lean declarations.
+
+         Args:
+             query: Search query string.
+             limit: Maximum number of results to return.
+             rerank_top: Number of candidates to rerank with cross-encoder.
+             packages: Filter results to specific packages (e.g., ["Mathlib"]).
+
+         Returns:
+             SearchResponse containing results and metadata.
+         """
+         start_time = time.time()
+
+         results = await self.engine.search(
+             query=query,
+             limit=limit,
+             rerank_top=rerank_top,
+             packages=packages,
+         )
+
+         processing_time_ms = int((time.time() - start_time) * 1000)
+
+         return SearchResponse(
+             query=query,
+             results=results,
+             count=len(results),
+             processing_time_ms=processing_time_ms,
+         )
+
+     async def get_by_id(self, declaration_id: int) -> SearchResult | None:
+         """Retrieve a declaration by ID.
+
+         Args:
+             declaration_id: The declaration ID.
+
+         Returns:
+             SearchResult if found, None otherwise.
+         """
+         return await self.engine.get_by_id(declaration_id)
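
A minimal usage sketch for the service layer (not from the package docs). It assumes the default SearchEngine() can be constructed as-is with a local index already in place (engine.py is not shown in this diff), and the query text is illustrative.

    import asyncio

    from lean_explore.search.service import Service


    async def main() -> None:
        service = Service()  # wraps a default SearchEngine()
        response = await service.search("continuous functions on compact sets", limit=5)
        print(response.count, "results in", response.processing_time_ms, "ms")
        for result in response.results:
            print(result)


    asyncio.run(main())
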

lean_explore/search/tokenization.py
@@ -0,0 +1,71 @@
+ """Text tokenization utilities for search indexing and querying.
+
+ This module provides tokenization strategies for Lean declaration names,
+ supporting both spaced tokenization (splits on dots, underscores, camelCase)
+ and raw tokenization (preserves structure for exact matching).
+ """
+
+ import re
+
+
+ def tokenize_spaced(text: str) -> list[str]:
+     """Tokenize text with spacing on dots, underscores, and camelCase.
+
+     Args:
+         text: Input text to tokenize.
+
+     Returns:
+         List of lowercase word tokens.
+     """
+     if not text:
+         return []
+     # Replace dots and underscores with spaces
+     text = text.replace(".", " ").replace("_", " ")
+     # Split camelCase: insert space before uppercase letters
+     text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
+     return re.findall(r"\w+", text.lower())
+
+
+ def tokenize_raw(text: str) -> list[str]:
+     """Tokenize text as single token (preserves dots).
+
+     Args:
+         text: Input text to tokenize.
+
+     Returns:
+         List with the full text as a single lowercase token.
+     """
+     if not text:
+         return []
+     return [text.lower()]
+
+
+ def tokenize_words(text: str) -> list[str]:
+     """Simple word tokenization for natural language text.
+
+     Splits on whitespace and punctuation, returns lowercase tokens.
+
+     Args:
+         text: Input text to tokenize.
+
+     Returns:
+         List of lowercase word tokens.
+     """
+     if not text:
+         return []
+     return [w.lower() for w in re.findall(r"\w+", text)]
+
+
+ def is_autogenerated(name: str) -> bool:
+     """Check if a declaration name is auto-generated by Lean.
+
+     Auto-generated declarations include:
+     - .mk constructors (e.g., Nat.mk)
+
+     Args:
+         name: Fully qualified declaration name.
+
+     Returns:
+         True if the declaration is auto-generated.
+     """
+     return name.endswith(".mk")
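
A short sketch (not part of the package) of how the tokenizers above treat a Mathlib-style name; the inputs are illustrative and the expected outputs follow directly from the replacements and regexes in the code.

    from lean_explore.search.tokenization import (
        is_autogenerated,
        tokenize_raw,
        tokenize_spaced,
        tokenize_words,
    )

    print(tokenize_spaced("HasSum.tsum_eq"))            # ['has', 'sum', 'tsum', 'eq']
    print(tokenize_raw("HasSum.tsum_eq"))               # ['hassum.tsum_eq']
    print(tokenize_words("sum of a geometric series"))  # ['sum', 'of', 'a', 'geometric', 'series']
    print(is_autogenerated("Prod.mk"))                  # True
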

lean_explore/util/__init__.py
@@ -0,0 +1,28 @@
+ """Shared utilities for lean_explore.
+
+ Imports are lazy to avoid loading torch when not needed.
+ """
+
+
+ def __getattr__(name: str):
+     """Lazy import attributes to avoid loading torch unnecessarily."""
+     if name == "EmbeddingClient":
+         from lean_explore.util.embedding_client import EmbeddingClient
+
+         return EmbeddingClient
+     if name == "RerankerClient":
+         from lean_explore.util.reranker_client import RerankerClient
+
+         return RerankerClient
+     if name == "OpenRouterClient":
+         from lean_explore.util.openrouter_client import OpenRouterClient
+
+         return OpenRouterClient
+     if name == "setup_logging":
+         from lean_explore.util.logging import setup_logging
+
+         return setup_logging
+     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+ __all__ = ["EmbeddingClient", "RerankerClient", "OpenRouterClient", "setup_logging"]
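
The module-level __getattr__ above relies on PEP 562 lazy attribute lookup, so the torch-backed modules are imported only when first accessed. A minimal sketch of the effect (not from the package docs):

    import lean_explore.util as util

    setup_logging = util.setup_logging  # imports only lean_explore.util.logging
    # util.EmbeddingClient would trigger the sentence_transformers/torch import here
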

lean_explore/util/embedding_client.py
@@ -0,0 +1,92 @@
+ """Embedding generation client using sentence transformers."""
+
+ import asyncio
+ import logging
+
+ import torch
+ from pydantic import BaseModel
+ from sentence_transformers import SentenceTransformer
+
+ logger = logging.getLogger(__name__)
+
+
+ class EmbeddingResponse(BaseModel):
+     """Response from embedding generation."""
+
+     texts: list[str]
+     """Original input texts."""
+
+     embeddings: list[list[float]]
+     """List of embeddings (one per input text)."""
+
+     model: str
+     """Model name used for generation."""
+
+
+ class EmbeddingClient:
+     """Client for generating text embeddings."""
+
+     def __init__(
+         self, model_name: str, device: str | None = None, max_length: int | None = None
+     ):
+         """Initialize the embedding client.
+
+         Args:
+             model_name: Name of the sentence transformer model
+             device: Device to use ("cuda", "mps", "cpu"). Auto-detects if None.
+             max_length: Maximum sequence length for tokenization. If None, uses
+                 model default. Lower values reduce memory usage.
+         """
+         self.model_name = model_name
+         self.device = device or self._select_device()
+         self.max_length = max_length
+         logger.info(f"Loading embedding model {model_name} on {self.device}")
+         self.model = SentenceTransformer(model_name, device=self.device)
+
+         # Set max sequence length if specified
+         if max_length is not None:
+             self.model.max_seq_length = max_length
+             logger.info(f"Set max sequence length to {max_length}")
+
+     def _select_device(self) -> str:
+         """Select best available device."""
+         if torch.cuda.is_available():
+             return "cuda"
+         if torch.backends.mps.is_available():
+             return "mps"
+         return "cpu"
+
+     async def embed(
+         self, texts: list[str], is_query: bool = False
+     ) -> EmbeddingResponse:
+         """Generate embeddings for a list of texts.
+
+         Args:
+             texts: List of text strings to embed
+             is_query: If True, encode as search queries using the model's query
+                 prompt. If False (default), encode as documents without prompt.
+                 Qwen3-Embedding models are asymmetric and perform better when
+                 queries use prompt_name="query".
+
+         Returns:
+             EmbeddingResponse with texts, embeddings, and model info
+         """
+         loop = asyncio.get_event_loop()
+
+         def _encode():
+             # Use query prompt for search queries, no prompt for documents
+             encode_kwargs = {
+                 "show_progress_bar": False,
+                 "convert_to_numpy": True,
+                 "batch_size": 256,  # Larger batches for GPU utilization
+             }
+             if is_query:
+                 encode_kwargs["prompt_name"] = "query"
+             return self.model.encode(texts, **encode_kwargs)
+
+         embeddings = await loop.run_in_executor(None, _encode)
+         return EmbeddingResponse(
+             texts=texts,
+             embeddings=[emb.tolist() for emb in embeddings],
+             model=self.model_name,
+         )
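
A minimal usage sketch (not from the package docs). The model name below is an assumption chosen for illustration; the release's actual default embedding model is configured elsewhere in the package.

    import asyncio

    from lean_explore.util import EmbeddingClient


    async def main() -> None:
        client = EmbeddingClient("Qwen/Qwen3-Embedding-0.6B", max_length=512)  # assumed model id
        docs = await client.embed(["theorem: every prime greater than two is odd"])
        query = await client.embed(["odd primes"], is_query=True)  # uses the model's query prompt
        print(len(docs.embeddings[0]), docs.model)


    asyncio.run(main())
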

lean_explore/util/logging.py
@@ -0,0 +1,22 @@
+ """Logging utilities for lean_explore."""
+
+ import logging
+ import sys
+
+
+ def setup_logging(verbose: bool = False) -> None:
+     """Configure logging for lean_explore applications.
+
+     Args:
+         verbose: If True, set level to DEBUG; otherwise INFO.
+     """
+     level = logging.DEBUG if verbose else logging.INFO
+     logging.basicConfig(
+         level=level,
+         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+         handlers=[logging.StreamHandler(sys.stdout)],
+     )
+
+     # Suppress noisy third-party loggers
+     logging.getLogger("httpx").setLevel(logging.WARNING)
+     logging.getLogger("httpcore").setLevel(logging.WARNING)

lean_explore/util/openrouter_client.py
@@ -0,0 +1,63 @@
+ """OpenRouter API wrapper using OpenAI SDK types.
+
+ Provides a client class that wraps OpenRouter's API and returns OpenAI SDK-compatible
+ types for easy integration.
+ """
+
+ from __future__ import annotations
+
+ import os
+
+ from openai import AsyncOpenAI
+ from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
+ from tenacity import retry, stop_after_attempt, wait_exponential
+
+
+ class OpenRouterClient:
+     """Client for interacting with OpenRouter API using OpenAI SDK types."""
+
+     def __init__(self):
+         """Initialize OpenRouter client.
+
+         Reads API key from OPENROUTER_API_KEY environment variable.
+         """
+         api_key = os.getenv("OPENROUTER_API_KEY")
+         if not api_key:
+             raise ValueError("OPENROUTER_API_KEY environment variable not set")
+
+         self.client = AsyncOpenAI(
+             base_url="https://openrouter.ai/api/v1",
+             api_key=api_key,
+         )
+
+     @retry(
+         stop=stop_after_attempt(3),
+         wait=wait_exponential(multiplier=1, min=2, max=10),
+     )
+     async def generate(
+         self,
+         model: str,
+         messages: list[ChatCompletionMessageParam],
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+         **kwargs,
+     ) -> ChatCompletion:
+         """Generate a chat completion using OpenRouter.
+
+         Args:
+             model: Model name (e.g., "anthropic/claude-3.5-sonnet")
+             messages: List of message dicts with "role" and "content"
+             temperature: Sampling temperature
+             max_tokens: Maximum tokens to generate
+             **kwargs: Additional parameters to pass to the API
+
+         Returns:
+             ChatCompletion object from OpenAI SDK
+         """
+         return await self.client.chat.completions.create(
+             model=model,
+             messages=messages,
+             temperature=temperature,
+             max_tokens=max_tokens,
+             **kwargs,
+         )
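
A minimal usage sketch (not from the package docs), assuming OPENROUTER_API_KEY is set; the model id is the example from the docstring above and the prompt text is illustrative.

    import asyncio

    from lean_explore.util import OpenRouterClient


    async def main() -> None:
        client = OpenRouterClient()
        completion = await client.generate(
            model="anthropic/claude-3.5-sonnet",
            messages=[{"role": "user", "content": "Summarize Reciprocal Rank Fusion in one sentence."}],
            max_tokens=100,
        )
        print(completion.choices[0].message.content)


    asyncio.run(main())
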

lean_explore/util/reranker_client.py
@@ -0,0 +1,187 @@
+ """Reranker client using Qwen3-Reranker for query-document scoring."""
+
+ import asyncio
+ import logging
+
+ import torch
+ from pydantic import BaseModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ logger = logging.getLogger(__name__)
+
+ DEFAULT_INSTRUCTION = "Find relevant Lean 4 math declarations"
+
+
+ class RerankerResponse(BaseModel):
+     """Response from reranking operation."""
+
+     query: str
+     """The original query."""
+
+     scores: list[float]
+     """Relevance scores for each document (same order as input)."""
+
+     model: str
+     """Model name used for reranking."""
+
+
+ class RerankerClient:
+     """Client for reranking query-document pairs using Qwen3-Reranker."""
+
+     def __init__(
+         self,
+         model_name: str = "Qwen/Qwen3-Reranker-0.6B",
+         device: str | None = None,
+         max_length: int = 512,
+         instruction: str = DEFAULT_INSTRUCTION,
+     ):
+         """Initialize the reranker client.
+
+         Args:
+             model_name: Name of the reranker model from HuggingFace.
+             device: Device to use ("cuda", "mps", "cpu"). Auto-detects if None.
+             max_length: Maximum sequence length for tokenization.
+             instruction: Task instruction prepended to each query-document pair.
+         """
+         self.model_name = model_name
+         self.device = device or self._select_device()
+         self.max_length = max_length
+         self.instruction = instruction
+
+         logger.info(f"Loading reranker model {model_name} on {self.device}")
+
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_name, padding_side="left", trust_remote_code=True
+         )
+
+         # Use float32 on CPU (faster than float16 which gets emulated)
+         # Use float16 on GPU for memory efficiency
+         dtype = torch.float16 if self.device == "cuda" else torch.float32
+
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_name, torch_dtype=dtype, trust_remote_code=True
+         ).to(self.device)
+         self.model.eval()
+
+         # Get token IDs for true/false classification
+         self._token_true_id = self.tokenizer.convert_tokens_to_ids("true")
+         self._token_false_id = self.tokenizer.convert_tokens_to_ids("false")
+
+         logger.info("Reranker model loaded successfully")
+
+     def _select_device(self) -> str:
+         """Select best available device."""
+         if torch.cuda.is_available():
+             return "cuda"
+         return "cpu"
+
+     def _format_pair(self, query: str, document: str) -> str:
+         """Format a query-document pair with instruction.
+
+         Args:
+             query: The search query.
+             document: The document text to score.
+
+         Returns:
+             Formatted string for the reranker model.
+         """
+         return (
+             f"<Instruct>: {self.instruction}\n<Query>: {query}\n<Document>: {document}"
+         )
+
+     @torch.no_grad()
+     def _compute_scores_sync(self, pairs: list[str]) -> list[float]:
+         """Compute relevance scores for formatted pairs synchronously.
+
+         Args:
+             pairs: List of formatted query-document strings.
+
+         Returns:
+             List of relevance scores in [0, 1].
+         """
+         inputs = self.tokenizer(
+             pairs,
+             padding=True,
+             truncation=True,
+             max_length=self.max_length,
+             return_tensors="pt",
+         ).to(self.device)
+
+         # Get logits for last token
+         outputs = self.model(**inputs)
+         logits = outputs.logits[:, -1, :]
+
+         # Extract true/false logits
+         true_logits = logits[:, self._token_true_id]
+         false_logits = logits[:, self._token_false_id]
+
+         # Compute probability of "true" using softmax
+         stacked = torch.stack([false_logits, true_logits], dim=1)
+         log_probs = torch.nn.functional.log_softmax(stacked, dim=1)
+         scores = log_probs[:, 1].exp()
+
+         return scores.cpu().tolist()
+
+     def rerank_sync(
+         self,
+         query: str,
+         documents: list[str],
+     ) -> RerankerResponse:
+         """Rerank documents synchronously (faster for small batches).
+
+         Args:
+             query: The search query.
+             documents: List of document texts to rerank.
+
+         Returns:
+             RerankerResponse with scores for each document.
+         """
+         if not documents:
+             return RerankerResponse(query=query, scores=[], model=self.model_name)
+
+         pairs = [self._format_pair(query, doc) for doc in documents]
+         scores = self._compute_scores_sync(pairs)
+         return RerankerResponse(query=query, scores=scores, model=self.model_name)
+
+     async def rerank(
+         self,
+         query: str,
+         documents: list[str],
+         batch_size: int | None = None,
+     ) -> RerankerResponse:
+         """Rerank documents by relevance to query.
+
+         Args:
+             query: The search query.
+             documents: List of document texts to rerank.
+             batch_size: Number of pairs to process at once.
+
+         Returns:
+             RerankerResponse with scores for each document.
+         """
+         if not documents:
+             return RerankerResponse(query=query, scores=[], model=self.model_name)
+
+         # Default batch size: 16 on GPU (fits 8GB VRAM), 32 on CPU
+         if batch_size is None:
+             batch_size = 16 if self.device == "cuda" else 32
+
+         # For small batches, run synchronously to avoid executor overhead
+         if len(documents) <= batch_size:
+             return self.rerank_sync(query, documents)
+
+         # Format all pairs
+         pairs = [self._format_pair(query, doc) for doc in documents]
+
+         # Process in batches
+         loop = asyncio.get_event_loop()
+         all_scores: list[float] = []
+
+         for i in range(0, len(pairs), batch_size):
+             batch = pairs[i : i + batch_size]
+             batch_scores = await loop.run_in_executor(
+                 None, self._compute_scores_sync, batch
+             )
+             all_scores.extend(batch_scores)
+
+         return RerankerResponse(query=query, scores=all_scores, model=self.model_name)
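
A minimal usage sketch (not from the package docs); the documents are illustrative strings, and the default Qwen/Qwen3-Reranker-0.6B weights are downloaded from HuggingFace on first use.

    from lean_explore.util import RerankerClient

    reranker = RerankerClient()
    response = reranker.rerank_sync(
        query="sum of a geometric series",
        documents=[
            "theorem tsum_geometric_of_lt_one : ...",
            "theorem Nat.add_comm : ...",
        ],
    )
    print(response.scores)  # one score in [0, 1] per document, input order preserved
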