hindsight-api 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/admin/__init__.py +1 -0
- hindsight_api/admin/cli.py +311 -0
- hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
- hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
- hindsight_api/alembic/versions/h3c4d5e6f7g8_mental_models_v4.py +112 -0
- hindsight_api/alembic/versions/i4d5e6f7g8h9_delete_opinions.py +41 -0
- hindsight_api/alembic/versions/j5e6f7g8h9i0_mental_model_versions.py +95 -0
- hindsight_api/alembic/versions/k6f7g8h9i0j1_add_directive_subtype.py +58 -0
- hindsight_api/alembic/versions/l7g8h9i0j1k2_add_worker_columns.py +109 -0
- hindsight_api/alembic/versions/m8h9i0j1k2l3_mental_model_id_to_text.py +41 -0
- hindsight_api/alembic/versions/n9i0j1k2l3m4_learnings_and_pinned_reflections.py +134 -0
- hindsight_api/alembic/versions/o0j1k2l3m4n5_migrate_mental_models_data.py +113 -0
- hindsight_api/alembic/versions/p1k2l3m4n5o6_new_knowledge_architecture.py +194 -0
- hindsight_api/alembic/versions/q2l3m4n5o6p7_fix_mental_model_fact_type.py +50 -0
- hindsight_api/alembic/versions/r3m4n5o6p7q8_add_reflect_response_to_reflections.py +47 -0
- hindsight_api/alembic/versions/s4n5o6p7q8r9_add_consolidated_at_to_memory_units.py +53 -0
- hindsight_api/alembic/versions/t5o6p7q8r9s0_rename_mental_models_to_observations.py +134 -0
- hindsight_api/alembic/versions/u6p7q8r9s0t1_mental_models_text_id.py +41 -0
- hindsight_api/alembic/versions/v7q8r9s0t1u2_add_max_tokens_to_mental_models.py +50 -0
- hindsight_api/api/http.py +1406 -118
- hindsight_api/api/mcp.py +11 -196
- hindsight_api/config.py +359 -27
- hindsight_api/engine/consolidation/__init__.py +5 -0
- hindsight_api/engine/consolidation/consolidator.py +859 -0
- hindsight_api/engine/consolidation/prompts.py +69 -0
- hindsight_api/engine/cross_encoder.py +706 -88
- hindsight_api/engine/db_budget.py +284 -0
- hindsight_api/engine/db_utils.py +11 -0
- hindsight_api/engine/directives/__init__.py +5 -0
- hindsight_api/engine/directives/models.py +37 -0
- hindsight_api/engine/embeddings.py +553 -29
- hindsight_api/engine/entity_resolver.py +8 -5
- hindsight_api/engine/interface.py +40 -17
- hindsight_api/engine/llm_wrapper.py +744 -68
- hindsight_api/engine/memory_engine.py +2505 -1017
- hindsight_api/engine/mental_models/__init__.py +14 -0
- hindsight_api/engine/mental_models/models.py +53 -0
- hindsight_api/engine/query_analyzer.py +4 -3
- hindsight_api/engine/reflect/__init__.py +18 -0
- hindsight_api/engine/reflect/agent.py +933 -0
- hindsight_api/engine/reflect/models.py +109 -0
- hindsight_api/engine/reflect/observations.py +186 -0
- hindsight_api/engine/reflect/prompts.py +483 -0
- hindsight_api/engine/reflect/tools.py +437 -0
- hindsight_api/engine/reflect/tools_schema.py +250 -0
- hindsight_api/engine/response_models.py +168 -4
- hindsight_api/engine/retain/bank_utils.py +79 -201
- hindsight_api/engine/retain/fact_extraction.py +424 -195
- hindsight_api/engine/retain/fact_storage.py +35 -12
- hindsight_api/engine/retain/link_utils.py +29 -24
- hindsight_api/engine/retain/orchestrator.py +24 -43
- hindsight_api/engine/retain/types.py +11 -2
- hindsight_api/engine/search/graph_retrieval.py +43 -14
- hindsight_api/engine/search/link_expansion_retrieval.py +391 -0
- hindsight_api/engine/search/mpfp_retrieval.py +362 -117
- hindsight_api/engine/search/reranking.py +2 -2
- hindsight_api/engine/search/retrieval.py +848 -201
- hindsight_api/engine/search/tags.py +172 -0
- hindsight_api/engine/search/think_utils.py +42 -141
- hindsight_api/engine/search/trace.py +12 -1
- hindsight_api/engine/search/tracer.py +26 -6
- hindsight_api/engine/search/types.py +21 -3
- hindsight_api/engine/task_backend.py +113 -106
- hindsight_api/engine/utils.py +1 -152
- hindsight_api/extensions/__init__.py +10 -1
- hindsight_api/extensions/builtin/tenant.py +5 -1
- hindsight_api/extensions/context.py +10 -1
- hindsight_api/extensions/operation_validator.py +81 -4
- hindsight_api/extensions/tenant.py +26 -0
- hindsight_api/main.py +69 -6
- hindsight_api/mcp_local.py +12 -53
- hindsight_api/mcp_tools.py +494 -0
- hindsight_api/metrics.py +433 -48
- hindsight_api/migrations.py +141 -1
- hindsight_api/models.py +3 -3
- hindsight_api/pg0.py +53 -0
- hindsight_api/server.py +39 -2
- hindsight_api/worker/__init__.py +11 -0
- hindsight_api/worker/main.py +296 -0
- hindsight_api/worker/poller.py +486 -0
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/METADATA +16 -6
- hindsight_api-0.4.0.dist-info/RECORD +112 -0
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/entry_points.txt +2 -0
- hindsight_api/engine/retain/observation_regeneration.py +0 -254
- hindsight_api/engine/search/observation_utils.py +0 -125
- hindsight_api/engine/search/scoring.py +0 -159
- hindsight_api-0.2.1.dist-info/RECORD +0 -75
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/WHEEL +0 -0
hindsight_api/engine/cross_encoder.py

@@ -6,17 +6,38 @@ Provides an interface for reranking with different backends.
 Configuration via environment variables - see hindsight_api.config for all env var names.
 """

+import asyncio
 import logging
 import os
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor

 import httpx

 from ..config import (
+    DEFAULT_LITELLM_API_BASE,
+    DEFAULT_RERANKER_COHERE_MODEL,
+    DEFAULT_RERANKER_FLASHRANK_CACHE_DIR,
+    DEFAULT_RERANKER_FLASHRANK_MODEL,
+    DEFAULT_RERANKER_LITELLM_MODEL,
+    DEFAULT_RERANKER_LOCAL_MAX_CONCURRENT,
     DEFAULT_RERANKER_LOCAL_MODEL,
     DEFAULT_RERANKER_PROVIDER,
+    DEFAULT_RERANKER_TEI_BATCH_SIZE,
+    DEFAULT_RERANKER_TEI_MAX_CONCURRENT,
+    ENV_COHERE_API_KEY,
+    ENV_LITELLM_API_BASE,
+    ENV_LITELLM_API_KEY,
+    ENV_RERANKER_COHERE_BASE_URL,
+    ENV_RERANKER_COHERE_MODEL,
+    ENV_RERANKER_FLASHRANK_CACHE_DIR,
+    ENV_RERANKER_FLASHRANK_MODEL,
+    ENV_RERANKER_LITELLM_MODEL,
+    ENV_RERANKER_LOCAL_MAX_CONCURRENT,
     ENV_RERANKER_LOCAL_MODEL,
     ENV_RERANKER_PROVIDER,
+    ENV_RERANKER_TEI_BATCH_SIZE,
+    ENV_RERANKER_TEI_MAX_CONCURRENT,
     ENV_RERANKER_TEI_URL,
 )

@@ -47,7 +68,7 @@ class CrossEncoderModel(ABC):
         pass

     @abstractmethod
-    def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
         """
         Score query-document pairs for relevance.

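The abstract `predict` contract is now a coroutine, so any custom `CrossEncoderModel` subclass written against 0.2.1 must be updated. A minimal sketch of a conforming subclass, assuming the abstract surface beyond `predict` matches what the implementations in this diff show (`provider_name`, `initialize`); the class name and scoring logic are illustrative, not part of the package:

    import asyncio

    from hindsight_api.engine.cross_encoder import CrossEncoderModel


    class ConstantCrossEncoder(CrossEncoderModel):
        """Toy reranker: every pair gets the same score (illustrative only)."""

        @property
        def provider_name(self) -> str:
            return "constant"

        async def initialize(self) -> None:
            pass  # nothing to load

        async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
            # The 0.4.0 contract: a coroutine returning one score per (query, document) pair.
            return [0.5] * len(pairs)


    print(asyncio.run(ConstantCrossEncoder().predict([("query", "document")])))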
@@ -70,25 +91,34 @@ class LocalSTCrossEncoder(CrossEncoderModel):
     - Fast inference (~80ms for 100 pairs on CPU)
     - Small model (80MB)
     - Trained for passage re-ranking
+
+    Uses a dedicated thread pool to limit concurrent CPU-bound work.
     """

-
+    # Shared executor across all instances (one model loaded anyway)
+    _executor: ThreadPoolExecutor | None = None
+    _max_concurrent: int = 4  # Limit concurrent CPU-bound reranking calls
+
+    def __init__(self, model_name: str | None = None, max_concurrent: int = 4):
         """
         Initialize local SentenceTransformers cross-encoder.

         Args:
             model_name: Name of the CrossEncoder model to use.
                 Default: cross-encoder/ms-marco-MiniLM-L-6-v2
+            max_concurrent: Maximum concurrent reranking calls (default: 2).
+                Higher values may cause CPU thrashing under load.
         """
         self.model_name = model_name or DEFAULT_RERANKER_LOCAL_MODEL
         self._model = None
+        LocalSTCrossEncoder._max_concurrent = max_concurrent

     @property
     def provider_name(self) -> str:
         return "local"

     async def initialize(self) -> None:
-        """Load the cross-encoder model."""
+        """Load the cross-encoder model and initialize the executor."""
         if self._model is not None:
             return

@@ -101,13 +131,134 @@ class LocalSTCrossEncoder(CrossEncoderModel):
             )

         logger.info(f"Reranker: initializing local provider with model {self.model_name}")
-        self._model = CrossEncoder(self.model_name)
-        logger.info("Reranker: local provider initialized")

-
+        # Determine device based on hardware availability.
+        # We always set low_cpu_mem_usage=False to prevent lazy loading (meta tensors)
+        # which can cause issues when accelerate is installed but no GPU is available.
+        # Note: We do NOT use device_map because CrossEncoder internally calls .to(device)
+        # after loading, which conflicts with accelerate's device_map handling.
+        import torch
+
+        # Check for GPU (CUDA) or Apple Silicon (MPS)
+        has_gpu = torch.cuda.is_available() or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
+
+        if has_gpu:
+            device = None  # Let sentence-transformers auto-detect GPU/MPS
+        else:
+            device = "cpu"
+
+        self._model = CrossEncoder(
+            self.model_name,
+            device=device,
+            model_kwargs={"low_cpu_mem_usage": False},
+        )
+
+        # Initialize shared executor (limited workers naturally limits concurrency)
+        if LocalSTCrossEncoder._executor is None:
+            LocalSTCrossEncoder._executor = ThreadPoolExecutor(
+                max_workers=LocalSTCrossEncoder._max_concurrent,
+                thread_name_prefix="reranker",
+            )
+            logger.info(f"Reranker: local provider initialized (max_concurrent={LocalSTCrossEncoder._max_concurrent})")
+        else:
+            logger.info("Reranker: local provider initialized (using existing executor)")
+
+    def _is_xpc_error(self, error: Exception) -> bool:
+        """
+        Check if an error is an XPC connection error (macOS daemon issue).
+
+        On macOS, long-running daemons can lose XPC connections to system services
+        when the process is idle for extended periods.
+        """
+        error_str = str(error).lower()
+        return "xpc_error_connection_invalid" in error_str or "xpc error" in error_str
+
+    def _reinitialize_model_sync(self) -> None:
+        """
+        Clear and reinitialize the cross-encoder model synchronously.
+
+        This is used to recover from XPC errors on macOS where the
+        PyTorch/MPS backend loses its connection to system services.
+        """
+        logger.warning(f"Reinitializing reranker model {self.model_name} due to backend error")
+
+        # Clear existing model
+        self._model = None
+
+        # Force garbage collection to free resources
+        import gc
+
+        import torch
+
+        gc.collect()
+
+        # If using CUDA/MPS, clear the cache
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            try:
+                torch.mps.empty_cache()
+            except AttributeError:
+                pass  # Method might not exist in all PyTorch versions
+
+        # Reinitialize the model
+        try:
+            from sentence_transformers import CrossEncoder
+        except ImportError:
+            raise ImportError(
+                "sentence-transformers is required for LocalSTCrossEncoder. "
+                "Install it with: pip install sentence-transformers"
+            )
+
+        # Determine device based on hardware availability
+        has_gpu = torch.cuda.is_available() or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
+
+        if has_gpu:
+            device = None  # Let sentence-transformers auto-detect GPU/MPS
+        else:
+            device = "cpu"
+
+        self._model = CrossEncoder(
+            self.model_name,
+            device=device,
+            model_kwargs={"low_cpu_mem_usage": False},
+        )
+
+        logger.info("Reranker: local provider reinitialized successfully")
+
+    def _predict_with_recovery(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """
+        Predict with automatic recovery from XPC errors.
+
+        This runs synchronously in the thread pool.
+        """
+        max_retries = 1
+        for attempt in range(max_retries + 1):
+            try:
+                scores = self._model.predict(pairs, show_progress_bar=False)
+                return scores.tolist() if hasattr(scores, "tolist") else list(scores)
+            except Exception as e:
+                # Check if this is an XPC error (macOS daemon issue)
+                if self._is_xpc_error(e) and attempt < max_retries:
+                    logger.warning(f"XPC error detected in reranker (attempt {attempt + 1}): {e}")
+                    try:
+                        self._reinitialize_model_sync()
+                        logger.info("Reranker reinitialized successfully, retrying prediction")
+                        continue
+                    except Exception as reinit_error:
+                        logger.error(f"Failed to reinitialize reranker: {reinit_error}")
+                        raise Exception(f"Failed to recover from XPC error: {str(e)}")
+                else:
+                    # Not an XPC error or out of retries
+                    raise
+
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
         """
         Score query-document pairs for relevance.

+        Uses a dedicated thread pool with limited workers to prevent CPU thrashing.
+        Automatically recovers from XPC errors on macOS by reinitializing the model.
+
         Args:
             pairs: List of (query, document) tuples to score

@@ -116,8 +267,14 @@ class LocalSTCrossEncoder(CrossEncoderModel):
         """
         if self._model is None:
             raise RuntimeError("Reranker not initialized. Call initialize() first.")
-
-
+
+        # Use dedicated executor - limited workers naturally limits concurrency
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(
+            LocalSTCrossEncoder._executor,
+            self._predict_with_recovery,
+            pairs,
+        )


 class RemoteTEICrossEncoder(CrossEncoderModel):
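With 0.4.0 the local provider scores pairs off the event loop through a small shared ThreadPoolExecutor, so callers now await it. A minimal usage sketch, assuming sentence-transformers is installed and the model downloads on first initialize; the pair data is illustrative:

    import asyncio

    from hindsight_api.engine.cross_encoder import LocalSTCrossEncoder


    async def main() -> None:
        reranker = LocalSTCrossEncoder(max_concurrent=2)  # bounds CPU-bound reranking threads
        await reranker.initialize()  # loads the sentence-transformers CrossEncoder once
        scores = await reranker.predict([
            ("what is hindsight?", "Hindsight is a memory engine."),
            ("what is hindsight?", "Unrelated text about the weather."),
        ])
        print(scores)  # one relevance score per (query, document) pair


    asyncio.run(main())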
@@ -128,13 +285,21 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
     See: https://github.com/huggingface/text-embeddings-inference

     Note: The TEI server must be running a cross-encoder/reranker model.
+
+    Requests are made in parallel with configurable batch size and max concurrency (backpressure).
+    Uses a GLOBAL semaphore to limit concurrent requests across ALL recall operations.
     """

+    # Global semaphore shared across all instances and calls to prevent thundering herd
+    _global_semaphore: asyncio.Semaphore | None = None
+    _global_max_concurrent: int = DEFAULT_RERANKER_TEI_MAX_CONCURRENT
+
     def __init__(
         self,
         base_url: str,
         timeout: float = 30.0,
-        batch_size: int =
+        batch_size: int = DEFAULT_RERANKER_TEI_BATCH_SIZE,
+        max_concurrent: int = DEFAULT_RERANKER_TEI_MAX_CONCURRENT,
         max_retries: int = 3,
         retry_delay: float = 0.5,
     ):
@@ -144,80 +309,246 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
         Args:
             base_url: Base URL of the TEI server (e.g., "http://localhost:8080")
             timeout: Request timeout in seconds (default: 30.0)
-            batch_size: Maximum batch size for rerank requests (default:
+            batch_size: Maximum batch size for rerank requests (default: 128)
+            max_concurrent: Maximum concurrent requests for backpressure (default: 8).
+                This is a GLOBAL limit across all parallel recall operations.
             max_retries: Maximum number of retries for failed requests (default: 3)
             retry_delay: Initial delay between retries in seconds, doubles each retry (default: 0.5)
         """
         self.base_url = base_url.rstrip("/")
         self.timeout = timeout
         self.batch_size = batch_size
+        self.max_concurrent = max_concurrent
         self.max_retries = max_retries
         self.retry_delay = retry_delay
-        self.
+        self._async_client: httpx.AsyncClient | None = None
         self._model_id: str | None = None

+        # Update global semaphore if max_concurrent changed
+        if (
+            RemoteTEICrossEncoder._global_semaphore is None
+            or RemoteTEICrossEncoder._global_max_concurrent != max_concurrent
+        ):
+            RemoteTEICrossEncoder._global_max_concurrent = max_concurrent
+            RemoteTEICrossEncoder._global_semaphore = asyncio.Semaphore(max_concurrent)
+
     @property
     def provider_name(self) -> str:
         return "tei"

-    def
-
-
-
+    async def _async_request_with_retry(
+        self,
+        client: httpx.AsyncClient,
+        semaphore: asyncio.Semaphore,
+        method: str,
+        url: str,
+        **kwargs,
+    ) -> httpx.Response:
+        """Make an async HTTP request with automatic retries on transient errors and semaphore for backpressure."""
         last_error = None
         delay = self.retry_delay

-
-
-
-
-
-
-
-
-
-
-                if attempt < self.max_retries:
-                    logger.warning(
-                        f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
-                    )
-                    time.sleep(delay)
-                    delay *= 2  # Exponential backoff
-            except httpx.HTTPStatusError as e:
-                # Retry on 5xx server errors
-                if e.response.status_code >= 500 and attempt < self.max_retries:
+        async with semaphore:
+            for attempt in range(self.max_retries + 1):
+                try:
+                    if method == "GET":
+                        response = await client.get(url, **kwargs)
+                    else:
+                        response = await client.post(url, **kwargs)
+                    response.raise_for_status()
+                    return response
+                except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
                     last_error = e
-
-
-
-
-
-
-
+                    if attempt < self.max_retries:
+                        logger.warning(
+                            f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. "
+                            f"Retrying in {delay}s..."
+                        )
+                        await asyncio.sleep(delay)
+                        delay *= 2  # Exponential backoff
+                except httpx.HTTPStatusError as e:
+                    # Retry on 5xx server errors
+                    if e.response.status_code >= 500 and attempt < self.max_retries:
+                        last_error = e
+                        logger.warning(
+                            f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. "
+                            f"Retrying in {delay}s..."
+                        )
+                        await asyncio.sleep(delay)
+                        delay *= 2
+                    else:
+                        raise

         raise last_error

     async def initialize(self) -> None:
         """Initialize the HTTP client and verify server connectivity."""
-        if self.
+        if self._async_client is not None:
             return

-        logger.info(
-
+        logger.info(
+            f"Reranker: initializing TEI provider at {self.base_url} "
+            f"(batch_size={self.batch_size}, max_concurrent={self.max_concurrent})"
+        )
+        self._async_client = httpx.AsyncClient(timeout=self.timeout)

         # Verify server is reachable and get model info
+        # Use a temporary semaphore for initialization
+        init_semaphore = asyncio.Semaphore(1)
         try:
-            response = self.
+            response = await self._async_request_with_retry(
+                self._async_client, init_semaphore, "GET", f"{self.base_url}/info"
+            )
             info = response.json()
             self._model_id = info.get("model_id", "unknown")
             logger.info(f"Reranker: TEI provider initialized (model: {self._model_id})")
         except httpx.HTTPError as e:
+            self._async_client = None
             raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")

-    def
+    async def _rerank_query_group(
+        self,
+        client: httpx.AsyncClient,
+        semaphore: asyncio.Semaphore,
+        query: str,
+        texts: list[str],
+    ) -> list[tuple[int, float]]:
+        """Rerank a single query group and return list of (original_index, score) tuples."""
+        try:
+            response = await self._async_request_with_retry(
+                client,
+                semaphore,
+                "POST",
+                f"{self.base_url}/rerank",
+                json={
+                    "query": query,
+                    "texts": texts,
+                    "return_text": False,
+                },
+            )
+            results = response.json()
+            # TEI returns results sorted by score descending, with original index
+            return [(result["index"], result["score"]) for result in results]
+        except httpx.HTTPError as e:
+            raise RuntimeError(f"TEI rerank request failed: {e}")
+
+    async def _predict_async(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """Async implementation of predict that runs requests in parallel with backpressure."""
+        if not pairs:
+            return []
+
+        # Group all pairs by query
+        query_groups: dict[str, list[tuple[int, str]]] = {}
+        for idx, (query, text) in enumerate(pairs):
+            if query not in query_groups:
+                query_groups[query] = []
+            query_groups[query].append((idx, text))
+
+        # Split each query group into batches
+        tasks_info: list[tuple[str, list[int], list[str]]] = []  # (query, indices, texts)
+        for query, indexed_texts in query_groups.items():
+            indices = [idx for idx, _ in indexed_texts]
+            texts = [text for _, text in indexed_texts]
+
+            # Split into batches
+            for i in range(0, len(texts), self.batch_size):
+                batch_indices = indices[i : i + self.batch_size]
+                batch_texts = texts[i : i + self.batch_size]
+                tasks_info.append((query, batch_indices, batch_texts))
+
+        # Run all requests in parallel with GLOBAL semaphore for backpressure
+        # This ensures max_concurrent is respected across ALL parallel recall operations
+        all_scores = [0.0] * len(pairs)
+        semaphore = RemoteTEICrossEncoder._global_semaphore
+
+        tasks = [
+            self._rerank_query_group(self._async_client, semaphore, query, texts) for query, _, texts in tasks_info
+        ]
+        results = await asyncio.gather(*tasks)
+
+        # Map scores back to original positions
+        for (_, indices, _), result_scores in zip(tasks_info, results):
+            for original_idx_in_batch, score in result_scores:
+                global_idx = indices[original_idx_in_batch]
+                all_scores[global_idx] = score
+
+        return all_scores
+
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
         """
         Score query-document pairs using the remote TEI reranker.

+        Requests are made in parallel with configurable backpressure.
+
+        Args:
+            pairs: List of (query, document) tuples to score
+
+        Returns:
+            List of relevance scores
+        """
+        if self._async_client is None:
+            raise RuntimeError("Reranker not initialized. Call initialize() first.")
+
+        return await self._predict_async(pairs)
+
+
+class CohereCrossEncoder(CrossEncoderModel):
+    """
+    Cohere cross-encoder implementation using the Cohere Rerank API.
+
+    Supports rerank-english-v3.0 and rerank-multilingual-v3.0 models.
+    """
+
+    def __init__(
+        self,
+        api_key: str,
+        model: str = DEFAULT_RERANKER_COHERE_MODEL,
+        base_url: str | None = None,
+        timeout: float = 60.0,
+    ):
+        """
+        Initialize Cohere cross-encoder client.
+
+        Args:
+            api_key: Cohere API key
+            model: Cohere rerank model name (default: rerank-english-v3.0)
+            base_url: Custom base URL for Cohere-compatible API (e.g., Azure-hosted endpoint)
+            timeout: Request timeout in seconds (default: 60.0)
+        """
+        self.api_key = api_key
+        self.model = model
+        self.base_url = base_url
+        self.timeout = timeout
+        self._client = None
+
+    @property
+    def provider_name(self) -> str:
+        return "cohere"
+
+    async def initialize(self) -> None:
+        """Initialize the Cohere client."""
+        if self._client is not None:
+            return
+
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError("cohere is required for CohereCrossEncoder. Install it with: pip install cohere")
+
+        base_url_msg = f" at {self.base_url}" if self.base_url else ""
+        logger.info(f"Reranker: initializing Cohere provider with model {self.model}{base_url_msg}")
+
+        # Build client kwargs, only including base_url if set (for Azure or custom endpoints)
+        client_kwargs = {"api_key": self.api_key, "timeout": self.timeout}
+        if self.base_url:
+            client_kwargs["base_url"] = self.base_url
+        self._client = cohere.Client(**client_kwargs)
+        logger.info("Reranker: Cohere provider initialized")
+
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """
+        Score query-document pairs using the Cohere Rerank API.
+
         Args:
             pairs: List of (query, document) tuples to score

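The TEI provider now fans requests out per query group and per batch, but throttles them all through one class-level asyncio.Semaphore so every concurrent recall operation shares the same global limit. A stripped-down sketch of that backpressure pattern outside the class (function names and the sleep stand-in are illustrative, not from the package):

    import asyncio


    async def call_with_backpressure(semaphore: asyncio.Semaphore, task_id: int) -> int:
        async with semaphore:            # at most `max_concurrent` requests in flight
            await asyncio.sleep(0.1)     # stand-in for the HTTP POST to /rerank
            return task_id


    async def main() -> None:
        # One semaphore shared by every task gives a process-wide concurrency cap,
        # mirroring the class-level _global_semaphore in the diff (default limit 8).
        semaphore = asyncio.Semaphore(8)
        results = await asyncio.gather(*(call_with_backpressure(semaphore, i) for i in range(100)))
        print(len(results))


    asyncio.run(main())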
@@ -230,50 +561,312 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
         if not pairs:
             return []

-
+        # Run sync Cohere API calls in thread pool
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self._predict_sync, pairs)
+
+    def _predict_sync(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """Synchronous predict implementation for Cohere API."""
+        # Group pairs by query for efficient batching
+        # Cohere rerank expects one query with multiple documents
+        query_groups: dict[str, list[tuple[int, str]]] = {}
+        for idx, (query, text) in enumerate(pairs):
+            if query not in query_groups:
+                query_groups[query] = []
+            query_groups[query].append((idx, text))
+
+        all_scores = [0.0] * len(pairs)
+
+        for query, indexed_texts in query_groups.items():
+            texts = [text for _, text in indexed_texts]
+            indices = [idx for idx, _ in indexed_texts]
+
+            response = self._client.rerank(
+                query=query,
+                documents=texts,
+                model=self.model,
+                return_documents=False,
+            )

-
-
-
+            # Map scores back to original positions
+            for result in response.results:
+                original_idx = result.index
+                score = result.relevance_score
+                all_scores[indices[original_idx]] = score

-
-        # All pairs in a batch should have the same query for optimal performance
-        # but we handle mixed queries by making separate requests per unique query
-        query_groups: dict[str, list[tuple[int, str]]] = {}
-        for idx, (query, text) in enumerate(batch):
-            if query not in query_groups:
-                query_groups[query] = []
-            query_groups[query].append((idx, text))
+        return all_scores

-        batch_scores = [0.0] * len(batch)

-
-
-
+class RRFPassthroughCrossEncoder(CrossEncoderModel):
+    """
+    Passthrough cross-encoder that preserves RRF scores without neural reranking.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    This is useful for:
+    - Testing retrieval quality without reranking overhead
+    - Deployments where reranking latency is unacceptable
+    - Debugging to isolate retrieval vs reranking issues
+    """
+
+    def __init__(self):
+        """Initialize RRF passthrough cross-encoder."""
+        pass
+
+    @property
+    def provider_name(self) -> str:
+        return "rrf"
+
+    async def initialize(self) -> None:
+        """No initialization needed."""
+        logger.info("Reranker: RRF passthrough provider initialized (neural reranking disabled)")
+
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """
+        Return neutral scores - actual ranking uses RRF scores from retrieval.
+
+        Args:
+            pairs: List of (query, document) tuples (ignored)
+
+        Returns:
+            List of 0.5 scores (neutral, lets RRF scores dominate)
+        """
+        # Return neutral scores so RRF ranking is preserved
+        return [0.5] * len(pairs)
+
+
+class FlashRankCrossEncoder(CrossEncoderModel):
+    """
+    FlashRank cross-encoder implementation.
+
+    FlashRank is an ultra-lite reranking library that runs on CPU without
+    requiring PyTorch or Transformers. It's ideal for serverless deployments
+    with minimal cold-start overhead.
+
+    Available models:
+    - ms-marco-TinyBERT-L-2-v2: Fastest, ~4MB
+    - ms-marco-MiniLM-L-12-v2: Best quality, ~34MB (default)
+    - rank-T5-flan: Best zero-shot, ~110MB
+    - ms-marco-MultiBERT-L-12: Multi-lingual, ~150MB
+    """
+
+    # Shared executor for CPU-bound reranking
+    _executor: ThreadPoolExecutor | None = None
+    _max_concurrent: int = 4
+
+    def __init__(
+        self,
+        model_name: str | None = None,
+        cache_dir: str | None = None,
+        max_length: int = 512,
+        max_concurrent: int = 4,
+    ):
+        """
+        Initialize FlashRank cross-encoder.
+
+        Args:
+            model_name: FlashRank model name. Default: ms-marco-MiniLM-L-12-v2
+            cache_dir: Directory to cache downloaded models. Default: system cache
+            max_length: Maximum sequence length for reranking. Default: 512
+            max_concurrent: Maximum concurrent reranking calls. Default: 4
+        """
+        self.model_name = model_name or DEFAULT_RERANKER_FLASHRANK_MODEL
+        self.cache_dir = cache_dir or DEFAULT_RERANKER_FLASHRANK_CACHE_DIR
+        self.max_length = max_length
+        self._ranker = None
+        FlashRankCrossEncoder._max_concurrent = max_concurrent
+
+    @property
+    def provider_name(self) -> str:
+        return "flashrank"
+
+    async def initialize(self) -> None:
+        """Load the FlashRank model."""
+        if self._ranker is not None:
+            return
+
+        try:
+            from flashrank import Ranker  # type: ignore[import-untyped]
+        except ImportError:
+            raise ImportError("flashrank is required for FlashRankCrossEncoder. Install it with: pip install flashrank")
+
+        logger.info(f"Reranker: initializing FlashRank provider with model {self.model_name}")
+
+        # Initialize ranker with optional cache directory
+        ranker_kwargs = {"model_name": self.model_name, "max_length": self.max_length}
+        if self.cache_dir:
+            ranker_kwargs["cache_dir"] = self.cache_dir
+
+        self._ranker = Ranker(**ranker_kwargs)
+
+        # Initialize shared executor
+        if FlashRankCrossEncoder._executor is None:
+            FlashRankCrossEncoder._executor = ThreadPoolExecutor(
+                max_workers=FlashRankCrossEncoder._max_concurrent,
+                thread_name_prefix="flashrank",
+            )
+            logger.info(
+                f"Reranker: FlashRank provider initialized (max_concurrent={FlashRankCrossEncoder._max_concurrent})"
+            )
+        else:
+            logger.info("Reranker: FlashRank provider initialized (using existing executor)")
+
+    def _predict_sync(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """Synchronous predict - processes each query group."""
+        from flashrank import RerankRequest  # type: ignore[import-untyped]
+
+        if not pairs:
+            return []
+
+        # Group pairs by query
+        query_groups: dict[str, list[tuple[int, str]]] = {}
+        for idx, (query, text) in enumerate(pairs):
+            if query not in query_groups:
+                query_groups[query] = []
+            query_groups[query].append((idx, text))
+
+        all_scores = [0.0] * len(pairs)
+
+        for query, indexed_texts in query_groups.items():
+            # Build passages list for FlashRank
+            passages = [{"id": i, "text": text} for i, (_, text) in enumerate(indexed_texts)]
+            global_indices = [idx for idx, _ in indexed_texts]
+
+            # Create rerank request
+            request = RerankRequest(query=query, passages=passages)
+            results = self._ranker.rerank(request)
+
+            # Map scores back to original positions
+            for result in results:
+                local_idx = result["id"]
+                score = result["score"]
+                global_idx = global_indices[local_idx]
+                all_scores[global_idx] = score
+
+        return all_scores
+
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """
+        Score query-document pairs using FlashRank.
+
+        Args:
+            pairs: List of (query, document) tuples to score
+
+        Returns:
+            List of relevance scores (higher = more relevant)
+        """
+        if self._ranker is None:
+            raise RuntimeError("Reranker not initialized. Call initialize() first.")
+
+        # Run in thread pool to avoid blocking event loop
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(FlashRankCrossEncoder._executor, self._predict_sync, pairs)
+
+
+class LiteLLMCrossEncoder(CrossEncoderModel):
+    """
+    LiteLLM cross-encoder implementation using LiteLLM proxy's /rerank endpoint.
+
+    LiteLLM provides a unified interface for multiple reranking providers via
+    the Cohere-compatible /rerank endpoint.
+    See: https://docs.litellm.ai/docs/rerank
+
+    Supported providers via LiteLLM:
+    - Cohere (rerank-english-v3.0, etc.) - prefix with cohere/
+    - Together AI - prefix with together_ai/
+    - Azure AI - prefix with azure_ai/
+    - Jina AI - prefix with jina_ai/
+    - AWS Bedrock - prefix with bedrock/
+    - Voyage AI - prefix with voyage/
+    """
+
+    def __init__(
+        self,
+        api_base: str = DEFAULT_LITELLM_API_BASE,
+        api_key: str | None = None,
+        model: str = DEFAULT_RERANKER_LITELLM_MODEL,
+        timeout: float = 60.0,
+    ):
+        """
+        Initialize LiteLLM cross-encoder client.
+
+        Args:
+            api_base: Base URL of the LiteLLM proxy (default: http://localhost:4000)
+            api_key: API key for the LiteLLM proxy (optional, depends on proxy config)
+            model: Reranking model name (default: cohere/rerank-english-v3.0)
+                Use provider prefix (e.g., cohere/, together_ai/, voyage/)
+            timeout: Request timeout in seconds (default: 60.0)
+        """
+        self.api_base = api_base.rstrip("/")
+        self.api_key = api_key
+        self.model = model
+        self.timeout = timeout
+        self._async_client: httpx.AsyncClient | None = None
+
+    @property
+    def provider_name(self) -> str:
+        return "litellm"
+
+    async def initialize(self) -> None:
+        """Initialize the async HTTP client."""
+        if self._async_client is not None:
+            return
+
+        logger.info(f"Reranker: initializing LiteLLM provider at {self.api_base} with model {self.model}")
+
+        headers = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+
+        self._async_client = httpx.AsyncClient(timeout=self.timeout, headers=headers)
+        logger.info("Reranker: LiteLLM provider initialized")
+
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """
+        Score query-document pairs using the LiteLLM proxy's /rerank endpoint.
+
+        Args:
+            pairs: List of (query, document) tuples to score
+
+        Returns:
+            List of relevance scores
+        """
+        if self._async_client is None:
+            raise RuntimeError("Reranker not initialized. Call initialize() first.")
+
+        if not pairs:
+            return []
+
+        # Group pairs by query (LiteLLM rerank expects one query with multiple documents)
+        query_groups: dict[str, list[tuple[int, str]]] = {}
+        for idx, (query, text) in enumerate(pairs):
+            if query not in query_groups:
+                query_groups[query] = []
+            query_groups[query].append((idx, text))
+
+        all_scores = [0.0] * len(pairs)
+
+        for query, indexed_texts in query_groups.items():
+            texts = [text for _, text in indexed_texts]
+            indices = [idx for idx, _ in indexed_texts]
+
+            # LiteLLM /rerank follows Cohere API format
+            response = await self._async_client.post(
+                f"{self.api_base}/rerank",
+                json={
+                    "model": self.model,
+                    "query": query,
+                    "documents": texts,
+                    "top_n": len(texts),  # Return all scores
+                },
+            )
+            response.raise_for_status()
+            result = response.json()
+
+            # Map scores back to original positions
+            # Response format: {"results": [{"index": 0, "relevance_score": 0.9}, ...]}
+            for item in result.get("results", []):
+                original_idx = item["index"]
+                score = item.get("relevance_score", item.get("score", 0.0))
+                all_scores[indices[original_idx]] = score

         return all_scores

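Besides the local and TEI providers, 0.4.0 adds Cohere, FlashRank, LiteLLM, and an RRF passthrough that skips neural reranking entirely. A minimal sketch of constructing the two lightest ones directly, assuming flashrank is installed; import paths and constructor arguments follow the diff, and the pair data is illustrative:

    import asyncio

    from hindsight_api.engine.cross_encoder import (
        FlashRankCrossEncoder,
        RRFPassthroughCrossEncoder,
    )


    async def main() -> None:
        # RRF passthrough: no model at all, returns neutral 0.5 scores so RRF ordering wins.
        rrf = RRFPassthroughCrossEncoder()
        await rrf.initialize()
        print(await rrf.predict([("q", "doc a"), ("q", "doc b")]))  # [0.5, 0.5]

        # FlashRank: ultra-lite CPU reranker; requires `pip install flashrank`.
        flashrank = FlashRankCrossEncoder(model_name="ms-marco-TinyBERT-L-2-v2")
        await flashrank.initialize()
        print(await flashrank.predict([("q", "doc a"), ("q", "doc b")]))


    asyncio.run(main())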
@@ -293,10 +886,35 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
         url = os.environ.get(ENV_RERANKER_TEI_URL)
         if not url:
             raise ValueError(f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'")
-
+        batch_size = int(os.environ.get(ENV_RERANKER_TEI_BATCH_SIZE, str(DEFAULT_RERANKER_TEI_BATCH_SIZE)))
+        max_concurrent = int(os.environ.get(ENV_RERANKER_TEI_MAX_CONCURRENT, str(DEFAULT_RERANKER_TEI_MAX_CONCURRENT)))
+        return RemoteTEICrossEncoder(base_url=url, batch_size=batch_size, max_concurrent=max_concurrent)
     elif provider == "local":
         model = os.environ.get(ENV_RERANKER_LOCAL_MODEL)
         model_name = model or DEFAULT_RERANKER_LOCAL_MODEL
-
+        max_concurrent = int(
+            os.environ.get(ENV_RERANKER_LOCAL_MAX_CONCURRENT, str(DEFAULT_RERANKER_LOCAL_MAX_CONCURRENT))
+        )
+        return LocalSTCrossEncoder(model_name=model_name, max_concurrent=max_concurrent)
+    elif provider == "cohere":
+        api_key = os.environ.get(ENV_COHERE_API_KEY)
+        if not api_key:
+            raise ValueError(f"{ENV_COHERE_API_KEY} is required when {ENV_RERANKER_PROVIDER} is 'cohere'")
+        model = os.environ.get(ENV_RERANKER_COHERE_MODEL, DEFAULT_RERANKER_COHERE_MODEL)
+        base_url = os.environ.get(ENV_RERANKER_COHERE_BASE_URL) or None
+        return CohereCrossEncoder(api_key=api_key, model=model, base_url=base_url)
+    elif provider == "flashrank":
+        model = os.environ.get(ENV_RERANKER_FLASHRANK_MODEL, DEFAULT_RERANKER_FLASHRANK_MODEL)
+        cache_dir = os.environ.get(ENV_RERANKER_FLASHRANK_CACHE_DIR, DEFAULT_RERANKER_FLASHRANK_CACHE_DIR)
+        return FlashRankCrossEncoder(model_name=model, cache_dir=cache_dir)
+    elif provider == "litellm":
+        api_base = os.environ.get(ENV_LITELLM_API_BASE, DEFAULT_LITELLM_API_BASE)
+        api_key = os.environ.get(ENV_LITELLM_API_KEY)
+        model = os.environ.get(ENV_RERANKER_LITELLM_MODEL, DEFAULT_RERANKER_LITELLM_MODEL)
+        return LiteLLMCrossEncoder(api_base=api_base, api_key=api_key, model=model)
+    elif provider == "rrf":
+        return RRFPassthroughCrossEncoder()
     else:
-        raise ValueError(
+        raise ValueError(
+            f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'flashrank', 'litellm', 'rrf'"
+        )
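Provider selection stays environment-driven through `create_cross_encoder_from_env()`. A minimal sketch of wiring it up via the config constants rather than hard-coded variable names (the actual env var strings live in `hindsight_api.config` and are not shown in this diff):

    import asyncio
    import os

    from hindsight_api import config
    from hindsight_api.engine.cross_encoder import create_cross_encoder_from_env

    # Select the provider through its env var; the constant holds the variable's name.
    os.environ[config.ENV_RERANKER_PROVIDER] = "rrf"  # cheapest provider, no model download


    async def main() -> None:
        reranker = create_cross_encoder_from_env()
        await reranker.initialize()
        print(reranker.provider_name, await reranker.predict([("query", "some document")]))


    asyncio.run(main())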