hindsight-api 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. hindsight_api/admin/__init__.py +1 -0
  2. hindsight_api/admin/cli.py +252 -0
  3. hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
  4. hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
  5. hindsight_api/api/http.py +282 -20
  6. hindsight_api/api/mcp.py +47 -52
  7. hindsight_api/config.py +238 -6
  8. hindsight_api/engine/cross_encoder.py +599 -86
  9. hindsight_api/engine/db_budget.py +284 -0
  10. hindsight_api/engine/db_utils.py +11 -0
  11. hindsight_api/engine/embeddings.py +453 -26
  12. hindsight_api/engine/entity_resolver.py +8 -5
  13. hindsight_api/engine/interface.py +8 -4
  14. hindsight_api/engine/llm_wrapper.py +241 -27
  15. hindsight_api/engine/memory_engine.py +609 -122
  16. hindsight_api/engine/query_analyzer.py +4 -3
  17. hindsight_api/engine/response_models.py +38 -0
  18. hindsight_api/engine/retain/fact_extraction.py +388 -192
  19. hindsight_api/engine/retain/fact_storage.py +34 -8
  20. hindsight_api/engine/retain/link_utils.py +24 -16
  21. hindsight_api/engine/retain/orchestrator.py +52 -17
  22. hindsight_api/engine/retain/types.py +9 -0
  23. hindsight_api/engine/search/graph_retrieval.py +42 -13
  24. hindsight_api/engine/search/link_expansion_retrieval.py +256 -0
  25. hindsight_api/engine/search/mpfp_retrieval.py +362 -117
  26. hindsight_api/engine/search/reranking.py +2 -2
  27. hindsight_api/engine/search/retrieval.py +847 -200
  28. hindsight_api/engine/search/tags.py +172 -0
  29. hindsight_api/engine/search/think_utils.py +1 -1
  30. hindsight_api/engine/search/trace.py +12 -0
  31. hindsight_api/engine/search/tracer.py +24 -1
  32. hindsight_api/engine/search/types.py +21 -0
  33. hindsight_api/engine/task_backend.py +109 -18
  34. hindsight_api/engine/utils.py +1 -1
  35. hindsight_api/extensions/context.py +10 -1
  36. hindsight_api/main.py +56 -4
  37. hindsight_api/metrics.py +433 -48
  38. hindsight_api/migrations.py +141 -1
  39. hindsight_api/models.py +3 -1
  40. hindsight_api/pg0.py +53 -0
  41. hindsight_api/server.py +39 -2
  42. {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/METADATA +5 -1
  43. hindsight_api-0.3.0.dist-info/RECORD +82 -0
  44. {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/entry_points.txt +1 -0
  45. hindsight_api-0.2.0.dist-info/RECORD +0 -75
  46. {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/WHEEL +0 -0
@@ -6,17 +6,38 @@ Provides an interface for reranking with different backends.
  Configuration via environment variables - see hindsight_api.config for all env var names.
  """

+ import asyncio
  import logging
  import os
  from abc import ABC, abstractmethod
+ from concurrent.futures import ThreadPoolExecutor

  import httpx

  from ..config import (
+ DEFAULT_LITELLM_API_BASE,
+ DEFAULT_RERANKER_COHERE_MODEL,
+ DEFAULT_RERANKER_FLASHRANK_CACHE_DIR,
+ DEFAULT_RERANKER_FLASHRANK_MODEL,
+ DEFAULT_RERANKER_LITELLM_MODEL,
+ DEFAULT_RERANKER_LOCAL_MAX_CONCURRENT,
  DEFAULT_RERANKER_LOCAL_MODEL,
  DEFAULT_RERANKER_PROVIDER,
+ DEFAULT_RERANKER_TEI_BATCH_SIZE,
+ DEFAULT_RERANKER_TEI_MAX_CONCURRENT,
+ ENV_COHERE_API_KEY,
+ ENV_LITELLM_API_BASE,
+ ENV_LITELLM_API_KEY,
+ ENV_RERANKER_COHERE_BASE_URL,
+ ENV_RERANKER_COHERE_MODEL,
+ ENV_RERANKER_FLASHRANK_CACHE_DIR,
+ ENV_RERANKER_FLASHRANK_MODEL,
+ ENV_RERANKER_LITELLM_MODEL,
+ ENV_RERANKER_LOCAL_MAX_CONCURRENT,
  ENV_RERANKER_LOCAL_MODEL,
  ENV_RERANKER_PROVIDER,
+ ENV_RERANKER_TEI_BATCH_SIZE,
+ ENV_RERANKER_TEI_MAX_CONCURRENT,
  ENV_RERANKER_TEI_URL,
  )

@@ -47,7 +68,7 @@ class CrossEncoderModel(ABC):
  pass

  @abstractmethod
- def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
  """
  Score query-document pairs for relevance.
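
With this signature change, predict() on the CrossEncoderModel base class becomes a coroutine in 0.3.0, so every implementation and every call site has to await it. A minimal caller-side sketch; the scorer parameter is a placeholder for any initialized implementation, and only the predict() signature is taken from the diff:

async def rerank(scorer) -> list[float]:
    pairs = [("what is the capital of France?", "Paris is the capital of France.")]
    # 0.2.0: scores = scorer.predict(pairs)  # blocking call
    # 0.3.0: predict() is a coroutine and must be awaited
    return await scorer.predict(pairs)
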
@@ -70,25 +91,34 @@ class LocalSTCrossEncoder(CrossEncoderModel):
  - Fast inference (~80ms for 100 pairs on CPU)
  - Small model (80MB)
  - Trained for passage re-ranking
+
+ Uses a dedicated thread pool to limit concurrent CPU-bound work.
  """

- def __init__(self, model_name: str | None = None):
+ # Shared executor across all instances (one model loaded anyway)
+ _executor: ThreadPoolExecutor | None = None
+ _max_concurrent: int = 4 # Limit concurrent CPU-bound reranking calls
+
+ def __init__(self, model_name: str | None = None, max_concurrent: int = 4):
  """
  Initialize local SentenceTransformers cross-encoder.

  Args:
  model_name: Name of the CrossEncoder model to use.
  Default: cross-encoder/ms-marco-MiniLM-L-6-v2
+ max_concurrent: Maximum concurrent reranking calls (default: 2).
+ Higher values may cause CPU thrashing under load.
  """
  self.model_name = model_name or DEFAULT_RERANKER_LOCAL_MODEL
  self._model = None
+ LocalSTCrossEncoder._max_concurrent = max_concurrent

  @property
  def provider_name(self) -> str:
  return "local"

  async def initialize(self) -> None:
- """Load the cross-encoder model."""
+ """Load the cross-encoder model and initialize the executor."""
  if self._model is not None:
  return

@@ -100,14 +130,30 @@ class LocalSTCrossEncoder(CrossEncoderModel):
  "Install it with: pip install sentence-transformers"
  )

+ # Note: We use CPU even when GPU/MPS is available because:
+ # 1. The reranker model (MiniLM) is tiny (~22M params)
+ # 2. Batch sizes are small (~100-200 pairs)
+ # 3. Data transfer overhead to GPU outweighs compute benefit
+ # 4. CPU inference is actually faster for this workload
  logger.info(f"Reranker: initializing local provider with model {self.model_name}")
  self._model = CrossEncoder(self.model_name)
- logger.info("Reranker: local provider initialized")

- def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+ # Initialize shared executor (limited workers naturally limits concurrency)
+ if LocalSTCrossEncoder._executor is None:
+ LocalSTCrossEncoder._executor = ThreadPoolExecutor(
+ max_workers=LocalSTCrossEncoder._max_concurrent,
+ thread_name_prefix="reranker",
+ )
+ logger.info(f"Reranker: local provider initialized (max_concurrent={LocalSTCrossEncoder._max_concurrent})")
+ else:
+ logger.info("Reranker: local provider initialized (using existing executor)")
+
+ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
  """
  Score query-document pairs for relevance.

+ Uses a dedicated thread pool with limited workers to prevent CPU thrashing.
+
  Args:
  pairs: List of (query, document) tuples to score

@@ -116,7 +162,13 @@ class LocalSTCrossEncoder(CrossEncoderModel):
  """
  if self._model is None:
  raise RuntimeError("Reranker not initialized. Call initialize() first.")
- scores = self._model.predict(pairs, show_progress_bar=False)
+
+ # Use dedicated executor - limited workers naturally limits concurrency
+ loop = asyncio.get_event_loop()
+ scores = await loop.run_in_executor(
+ LocalSTCrossEncoder._executor,
+ lambda: self._model.predict(pairs, show_progress_bar=False),
+ )
  return scores.tolist() if hasattr(scores, "tolist") else list(scores)

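
The local provider now hands the blocking model call to a small shared ThreadPoolExecutor, keeping the event loop responsive while capping how many reranks run at once. A condensed sketch of that pattern in isolation; the names and the scoring stub are illustrative, not package code:

import asyncio
from concurrent.futures import ThreadPoolExecutor

# Bounded executor: at most max_workers CPU-bound calls run concurrently.
_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="reranker")

def _score_blocking(pairs: list[tuple[str, str]]) -> list[float]:
    # Stand-in for the CPU-bound CrossEncoder.predict() call.
    return [float(len(query) + len(doc)) for query, doc in pairs]

async def score(pairs: list[tuple[str, str]]) -> list[float]:
    loop = asyncio.get_event_loop()
    # Extra calls queue inside the pool instead of thrashing the CPU.
    return await loop.run_in_executor(_executor, _score_blocking, pairs)
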
121
173
 
122
174
 
@@ -128,13 +180,21 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
  See: https://github.com/huggingface/text-embeddings-inference

  Note: The TEI server must be running a cross-encoder/reranker model.
+
+ Requests are made in parallel with configurable batch size and max concurrency (backpressure).
+ Uses a GLOBAL semaphore to limit concurrent requests across ALL recall operations.
  """

+ # Global semaphore shared across all instances and calls to prevent thundering herd
+ _global_semaphore: asyncio.Semaphore | None = None
+ _global_max_concurrent: int = DEFAULT_RERANKER_TEI_MAX_CONCURRENT
+
  def __init__(
  self,
  base_url: str,
  timeout: float = 30.0,
- batch_size: int = 32,
+ batch_size: int = DEFAULT_RERANKER_TEI_BATCH_SIZE,
+ max_concurrent: int = DEFAULT_RERANKER_TEI_MAX_CONCURRENT,
  max_retries: int = 3,
  retry_delay: float = 0.5,
  ):
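
As the docstring notes, rerank requests are now issued in parallel and throttled by a semaphore shared across every instance and every recall operation. A condensed, self-contained sketch of that fan-out/fan-in shape; the fake scoring inside _rerank_group is illustrative, and only the grouping and index mapping mirror the _predict_async code in the next hunk:

import asyncio

# One shared semaphore caps in-flight requests across all callers.
_semaphore = asyncio.Semaphore(8)

async def _rerank_group(query: str, texts: list[str]) -> list[tuple[int, float]]:
    async with _semaphore:
        # Stand-in for the HTTP call to the TEI /rerank endpoint.
        return [(i, 1.0 / (i + 1)) for i in range(len(texts))]

async def score_pairs(pairs: list[tuple[str, str]]) -> list[float]:
    # Group by query, remembering each pair's original position.
    groups: dict[str, list[tuple[int, str]]] = {}
    for idx, (query, text) in enumerate(pairs):
        groups.setdefault(query, []).append((idx, text))

    index_lists = [[i for i, _ in items] for items in groups.values()]
    tasks = [_rerank_group(query, [t for _, t in items]) for query, items in groups.items()]
    results = await asyncio.gather(*tasks)

    # Scatter the per-group scores back into the original pair order.
    scores = [0.0] * len(pairs)
    for indices, group_scores in zip(index_lists, results):
        for local_idx, score in group_scores:
            scores[indices[local_idx]] = score
    return scores
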
@@ -144,80 +204,246 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
  Args:
  base_url: Base URL of the TEI server (e.g., "http://localhost:8080")
  timeout: Request timeout in seconds (default: 30.0)
- batch_size: Maximum batch size for rerank requests (default: 32)
+ batch_size: Maximum batch size for rerank requests (default: 128)
+ max_concurrent: Maximum concurrent requests for backpressure (default: 8).
+ This is a GLOBAL limit across all parallel recall operations.
  max_retries: Maximum number of retries for failed requests (default: 3)
  retry_delay: Initial delay between retries in seconds, doubles each retry (default: 0.5)
  """
  self.base_url = base_url.rstrip("/")
  self.timeout = timeout
  self.batch_size = batch_size
+ self.max_concurrent = max_concurrent
  self.max_retries = max_retries
  self.retry_delay = retry_delay
- self._client: httpx.Client | None = None
+ self._async_client: httpx.AsyncClient | None = None
  self._model_id: str | None = None

+ # Update global semaphore if max_concurrent changed
+ if (
+ RemoteTEICrossEncoder._global_semaphore is None
+ or RemoteTEICrossEncoder._global_max_concurrent != max_concurrent
+ ):
+ RemoteTEICrossEncoder._global_max_concurrent = max_concurrent
+ RemoteTEICrossEncoder._global_semaphore = asyncio.Semaphore(max_concurrent)
+
  @property
  def provider_name(self) -> str:
  return "tei"

- def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
- """Make an HTTP request with automatic retries on transient errors."""
- import time
-
+ async def _async_request_with_retry(
+ self,
+ client: httpx.AsyncClient,
+ semaphore: asyncio.Semaphore,
+ method: str,
+ url: str,
+ **kwargs,
+ ) -> httpx.Response:
+ """Make an async HTTP request with automatic retries on transient errors and semaphore for backpressure."""
  last_error = None
  delay = self.retry_delay

- for attempt in range(self.max_retries + 1):
- try:
- if method == "GET":
- response = self._client.get(url, **kwargs)
- else:
- response = self._client.post(url, **kwargs)
- response.raise_for_status()
- return response
- except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
- last_error = e
- if attempt < self.max_retries:
- logger.warning(
- f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
- )
- time.sleep(delay)
- delay *= 2 # Exponential backoff
- except httpx.HTTPStatusError as e:
- # Retry on 5xx server errors
- if e.response.status_code >= 500 and attempt < self.max_retries:
+ async with semaphore:
+ for attempt in range(self.max_retries + 1):
+ try:
+ if method == "GET":
+ response = await client.get(url, **kwargs)
+ else:
+ response = await client.post(url, **kwargs)
+ response.raise_for_status()
+ return response
+ except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
  last_error = e
- logger.warning(
- f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
- )
- time.sleep(delay)
- delay *= 2
- else:
- raise
+ if attempt < self.max_retries:
+ logger.warning(
+ f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. "
+ f"Retrying in {delay}s..."
+ )
+ await asyncio.sleep(delay)
+ delay *= 2 # Exponential backoff
+ except httpx.HTTPStatusError as e:
+ # Retry on 5xx server errors
+ if e.response.status_code >= 500 and attempt < self.max_retries:
+ last_error = e
+ logger.warning(
+ f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. "
+ f"Retrying in {delay}s..."
+ )
+ await asyncio.sleep(delay)
+ delay *= 2
+ else:
+ raise

  raise last_error

  async def initialize(self) -> None:
  """Initialize the HTTP client and verify server connectivity."""
- if self._client is not None:
+ if self._async_client is not None:
  return

- logger.info(f"Reranker: initializing TEI provider at {self.base_url}")
- self._client = httpx.Client(timeout=self.timeout)
+ logger.info(
+ f"Reranker: initializing TEI provider at {self.base_url} "
+ f"(batch_size={self.batch_size}, max_concurrent={self.max_concurrent})"
+ )
+ self._async_client = httpx.AsyncClient(timeout=self.timeout)

  # Verify server is reachable and get model info
+ # Use a temporary semaphore for initialization
+ init_semaphore = asyncio.Semaphore(1)
  try:
- response = self._request_with_retry("GET", f"{self.base_url}/info")
+ response = await self._async_request_with_retry(
+ self._async_client, init_semaphore, "GET", f"{self.base_url}/info"
+ )
  info = response.json()
  self._model_id = info.get("model_id", "unknown")
  logger.info(f"Reranker: TEI provider initialized (model: {self._model_id})")
  except httpx.HTTPError as e:
+ self._async_client = None
  raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")

- def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+ async def _rerank_query_group(
+ self,
+ client: httpx.AsyncClient,
+ semaphore: asyncio.Semaphore,
+ query: str,
+ texts: list[str],
+ ) -> list[tuple[int, float]]:
+ """Rerank a single query group and return list of (original_index, score) tuples."""
+ try:
+ response = await self._async_request_with_retry(
+ client,
+ semaphore,
+ "POST",
+ f"{self.base_url}/rerank",
+ json={
+ "query": query,
+ "texts": texts,
+ "return_text": False,
+ },
+ )
+ results = response.json()
+ # TEI returns results sorted by score descending, with original index
+ return [(result["index"], result["score"]) for result in results]
+ except httpx.HTTPError as e:
+ raise RuntimeError(f"TEI rerank request failed: {e}")
+
+ async def _predict_async(self, pairs: list[tuple[str, str]]) -> list[float]:
+ """Async implementation of predict that runs requests in parallel with backpressure."""
+ if not pairs:
+ return []
+
+ # Group all pairs by query
+ query_groups: dict[str, list[tuple[int, str]]] = {}
+ for idx, (query, text) in enumerate(pairs):
+ if query not in query_groups:
+ query_groups[query] = []
+ query_groups[query].append((idx, text))
+
+ # Split each query group into batches
+ tasks_info: list[tuple[str, list[int], list[str]]] = [] # (query, indices, texts)
+ for query, indexed_texts in query_groups.items():
+ indices = [idx for idx, _ in indexed_texts]
+ texts = [text for _, text in indexed_texts]
+
+ # Split into batches
+ for i in range(0, len(texts), self.batch_size):
+ batch_indices = indices[i : i + self.batch_size]
+ batch_texts = texts[i : i + self.batch_size]
+ tasks_info.append((query, batch_indices, batch_texts))
+
+ # Run all requests in parallel with GLOBAL semaphore for backpressure
+ # This ensures max_concurrent is respected across ALL parallel recall operations
+ all_scores = [0.0] * len(pairs)
+ semaphore = RemoteTEICrossEncoder._global_semaphore
+
+ tasks = [
+ self._rerank_query_group(self._async_client, semaphore, query, texts) for query, _, texts in tasks_info
+ ]
+ results = await asyncio.gather(*tasks)
+
+ # Map scores back to original positions
+ for (_, indices, _), result_scores in zip(tasks_info, results):
+ for original_idx_in_batch, score in result_scores:
+ global_idx = indices[original_idx_in_batch]
+ all_scores[global_idx] = score
+
+ return all_scores
+
+ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
  """
  Score query-document pairs using the remote TEI reranker.

+ Requests are made in parallel with configurable backpressure.
+
+ Args:
+ pairs: List of (query, document) tuples to score
+
+ Returns:
+ List of relevance scores
+ """
+ if self._async_client is None:
+ raise RuntimeError("Reranker not initialized. Call initialize() first.")
+
+ return await self._predict_async(pairs)
+
+
+ class CohereCrossEncoder(CrossEncoderModel):
+ """
+ Cohere cross-encoder implementation using the Cohere Rerank API.
+
+ Supports rerank-english-v3.0 and rerank-multilingual-v3.0 models.
+ """
+
+ def __init__(
+ self,
+ api_key: str,
+ model: str = DEFAULT_RERANKER_COHERE_MODEL,
+ base_url: str | None = None,
+ timeout: float = 60.0,
+ ):
+ """
+ Initialize Cohere cross-encoder client.
+
+ Args:
+ api_key: Cohere API key
+ model: Cohere rerank model name (default: rerank-english-v3.0)
+ base_url: Custom base URL for Cohere-compatible API (e.g., Azure-hosted endpoint)
+ timeout: Request timeout in seconds (default: 60.0)
+ """
+ self.api_key = api_key
+ self.model = model
+ self.base_url = base_url
+ self.timeout = timeout
+ self._client = None
+
+ @property
+ def provider_name(self) -> str:
+ return "cohere"
+
+ async def initialize(self) -> None:
+ """Initialize the Cohere client."""
+ if self._client is not None:
+ return
+
+ try:
+ import cohere
+ except ImportError:
+ raise ImportError("cohere is required for CohereCrossEncoder. Install it with: pip install cohere")
+
+ base_url_msg = f" at {self.base_url}" if self.base_url else ""
+ logger.info(f"Reranker: initializing Cohere provider with model {self.model}{base_url_msg}")
+
+ # Build client kwargs, only including base_url if set (for Azure or custom endpoints)
+ client_kwargs = {"api_key": self.api_key, "timeout": self.timeout}
+ if self.base_url:
+ client_kwargs["base_url"] = self.base_url
+ self._client = cohere.Client(**client_kwargs)
+ logger.info("Reranker: Cohere provider initialized")
+
+ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+ """
+ Score query-document pairs using the Cohere Rerank API.
+
  Args:
  pairs: List of (query, document) tuples to score

@@ -230,50 +456,312 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
  if not pairs:
  return []

- all_scores = []
+ # Run sync Cohere API calls in thread pool
+ loop = asyncio.get_event_loop()
+ return await loop.run_in_executor(None, self._predict_sync, pairs)
+
+ def _predict_sync(self, pairs: list[tuple[str, str]]) -> list[float]:
+ """Synchronous predict implementation for Cohere API."""
+ # Group pairs by query for efficient batching
+ # Cohere rerank expects one query with multiple documents
+ query_groups: dict[str, list[tuple[int, str]]] = {}
+ for idx, (query, text) in enumerate(pairs):
+ if query not in query_groups:
+ query_groups[query] = []
+ query_groups[query].append((idx, text))
+
+ all_scores = [0.0] * len(pairs)
+
+ for query, indexed_texts in query_groups.items():
+ texts = [text for _, text in indexed_texts]
+ indices = [idx for idx, _ in indexed_texts]
+
+ response = self._client.rerank(
+ query=query,
+ documents=texts,
+ model=self.model,
+ return_documents=False,
+ )

- # Process in batches
- for i in range(0, len(pairs), self.batch_size):
- batch = pairs[i : i + self.batch_size]
+ # Map scores back to original positions
+ for result in response.results:
+ original_idx = result.index
+ score = result.relevance_score
+ all_scores[indices[original_idx]] = score

- # TEI rerank endpoint expects query and texts separately
- # All pairs in a batch should have the same query for optimal performance
- # but we handle mixed queries by making separate requests per unique query
- query_groups: dict[str, list[tuple[int, str]]] = {}
- for idx, (query, text) in enumerate(batch):
- if query not in query_groups:
- query_groups[query] = []
- query_groups[query].append((idx, text))
+ return all_scores

- batch_scores = [0.0] * len(batch)

- for query, indexed_texts in query_groups.items():
- texts = [text for _, text in indexed_texts]
- indices = [idx for idx, _ in indexed_texts]
+ class RRFPassthroughCrossEncoder(CrossEncoderModel):
+ """
+ Passthrough cross-encoder that preserves RRF scores without neural reranking.

- try:
- response = self._request_with_retry(
- "POST",
- f"{self.base_url}/rerank",
- json={
- "query": query,
- "texts": texts,
- "return_text": False,
- },
- )
- results = response.json()
-
- # TEI returns results sorted by score descending, with original index
- for result in results:
- original_idx = result["index"]
- score = result["score"]
- # Map back to batch position
- batch_scores[indices[original_idx]] = score
-
- except httpx.HTTPError as e:
- raise RuntimeError(f"TEI rerank request failed: {e}")
-
- all_scores.extend(batch_scores)
+ This is useful for:
+ - Testing retrieval quality without reranking overhead
+ - Deployments where reranking latency is unacceptable
+ - Debugging to isolate retrieval vs reranking issues
+ """
+
+ def __init__(self):
+ """Initialize RRF passthrough cross-encoder."""
+ pass
+
+ @property
+ def provider_name(self) -> str:
+ return "rrf"
+
+ async def initialize(self) -> None:
+ """No initialization needed."""
+ logger.info("Reranker: RRF passthrough provider initialized (neural reranking disabled)")
+
+ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+ """
+ Return neutral scores - actual ranking uses RRF scores from retrieval.
+
+ Args:
+ pairs: List of (query, document) tuples (ignored)
+
+ Returns:
+ List of 0.5 scores (neutral, lets RRF scores dominate)
+ """
+ # Return neutral scores so RRF ranking is preserved
+ return [0.5] * len(pairs)
+
+
+ class FlashRankCrossEncoder(CrossEncoderModel):
+ """
+ FlashRank cross-encoder implementation.
+
+ FlashRank is an ultra-lite reranking library that runs on CPU without
+ requiring PyTorch or Transformers. It's ideal for serverless deployments
+ with minimal cold-start overhead.
+
+ Available models:
+ - ms-marco-TinyBERT-L-2-v2: Fastest, ~4MB
+ - ms-marco-MiniLM-L-12-v2: Best quality, ~34MB (default)
+ - rank-T5-flan: Best zero-shot, ~110MB
+ - ms-marco-MultiBERT-L-12: Multi-lingual, ~150MB
+ """
+
+ # Shared executor for CPU-bound reranking
+ _executor: ThreadPoolExecutor | None = None
+ _max_concurrent: int = 4
+
+ def __init__(
+ self,
+ model_name: str | None = None,
+ cache_dir: str | None = None,
+ max_length: int = 512,
+ max_concurrent: int = 4,
+ ):
+ """
+ Initialize FlashRank cross-encoder.
+
+ Args:
+ model_name: FlashRank model name. Default: ms-marco-MiniLM-L-12-v2
+ cache_dir: Directory to cache downloaded models. Default: system cache
+ max_length: Maximum sequence length for reranking. Default: 512
+ max_concurrent: Maximum concurrent reranking calls. Default: 4
+ """
+ self.model_name = model_name or DEFAULT_RERANKER_FLASHRANK_MODEL
+ self.cache_dir = cache_dir or DEFAULT_RERANKER_FLASHRANK_CACHE_DIR
+ self.max_length = max_length
+ self._ranker = None
+ FlashRankCrossEncoder._max_concurrent = max_concurrent
+
+ @property
+ def provider_name(self) -> str:
+ return "flashrank"
+
+ async def initialize(self) -> None:
+ """Load the FlashRank model."""
+ if self._ranker is not None:
+ return
+
+ try:
+ from flashrank import Ranker # type: ignore[import-untyped]
+ except ImportError:
+ raise ImportError("flashrank is required for FlashRankCrossEncoder. Install it with: pip install flashrank")
+
+ logger.info(f"Reranker: initializing FlashRank provider with model {self.model_name}")
+
+ # Initialize ranker with optional cache directory
+ ranker_kwargs = {"model_name": self.model_name, "max_length": self.max_length}
+ if self.cache_dir:
+ ranker_kwargs["cache_dir"] = self.cache_dir
+
+ self._ranker = Ranker(**ranker_kwargs)
+
+ # Initialize shared executor
+ if FlashRankCrossEncoder._executor is None:
+ FlashRankCrossEncoder._executor = ThreadPoolExecutor(
+ max_workers=FlashRankCrossEncoder._max_concurrent,
+ thread_name_prefix="flashrank",
+ )
+ logger.info(
+ f"Reranker: FlashRank provider initialized (max_concurrent={FlashRankCrossEncoder._max_concurrent})"
+ )
+ else:
+ logger.info("Reranker: FlashRank provider initialized (using existing executor)")
+
+ def _predict_sync(self, pairs: list[tuple[str, str]]) -> list[float]:
+ """Synchronous predict - processes each query group."""
+ from flashrank import RerankRequest # type: ignore[import-untyped]
+
+ if not pairs:
+ return []
+
+ # Group pairs by query
+ query_groups: dict[str, list[tuple[int, str]]] = {}
+ for idx, (query, text) in enumerate(pairs):
+ if query not in query_groups:
+ query_groups[query] = []
+ query_groups[query].append((idx, text))
+
+ all_scores = [0.0] * len(pairs)
+
+ for query, indexed_texts in query_groups.items():
+ # Build passages list for FlashRank
+ passages = [{"id": i, "text": text} for i, (_, text) in enumerate(indexed_texts)]
+ global_indices = [idx for idx, _ in indexed_texts]
+
+ # Create rerank request
+ request = RerankRequest(query=query, passages=passages)
+ results = self._ranker.rerank(request)
+
+ # Map scores back to original positions
+ for result in results:
+ local_idx = result["id"]
+ score = result["score"]
+ global_idx = global_indices[local_idx]
+ all_scores[global_idx] = score
+
+ return all_scores
+
+ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+ """
+ Score query-document pairs using FlashRank.
+
+ Args:
+ pairs: List of (query, document) tuples to score
+
+ Returns:
+ List of relevance scores (higher = more relevant)
+ """
+ if self._ranker is None:
+ raise RuntimeError("Reranker not initialized. Call initialize() first.")
+
+ # Run in thread pool to avoid blocking event loop
+ loop = asyncio.get_event_loop()
+ return await loop.run_in_executor(FlashRankCrossEncoder._executor, self._predict_sync, pairs)
+
+
+ class LiteLLMCrossEncoder(CrossEncoderModel):
+ """
+ LiteLLM cross-encoder implementation using LiteLLM proxy's /rerank endpoint.
+
+ LiteLLM provides a unified interface for multiple reranking providers via
+ the Cohere-compatible /rerank endpoint.
+ See: https://docs.litellm.ai/docs/rerank
+
+ Supported providers via LiteLLM:
+ - Cohere (rerank-english-v3.0, etc.) - prefix with cohere/
+ - Together AI - prefix with together_ai/
+ - Azure AI - prefix with azure_ai/
+ - Jina AI - prefix with jina_ai/
+ - AWS Bedrock - prefix with bedrock/
+ - Voyage AI - prefix with voyage/
+ """
+
+ def __init__(
+ self,
+ api_base: str = DEFAULT_LITELLM_API_BASE,
+ api_key: str | None = None,
+ model: str = DEFAULT_RERANKER_LITELLM_MODEL,
+ timeout: float = 60.0,
+ ):
+ """
+ Initialize LiteLLM cross-encoder client.
+
+ Args:
+ api_base: Base URL of the LiteLLM proxy (default: http://localhost:4000)
+ api_key: API key for the LiteLLM proxy (optional, depends on proxy config)
+ model: Reranking model name (default: cohere/rerank-english-v3.0)
+ Use provider prefix (e.g., cohere/, together_ai/, voyage/)
+ timeout: Request timeout in seconds (default: 60.0)
+ """
+ self.api_base = api_base.rstrip("/")
+ self.api_key = api_key
+ self.model = model
+ self.timeout = timeout
+ self._async_client: httpx.AsyncClient | None = None
+
+ @property
+ def provider_name(self) -> str:
+ return "litellm"
+
+ async def initialize(self) -> None:
+ """Initialize the async HTTP client."""
+ if self._async_client is not None:
+ return
+
+ logger.info(f"Reranker: initializing LiteLLM provider at {self.api_base} with model {self.model}")
+
+ headers = {"Content-Type": "application/json"}
+ if self.api_key:
+ headers["Authorization"] = f"Bearer {self.api_key}"
+
+ self._async_client = httpx.AsyncClient(timeout=self.timeout, headers=headers)
+ logger.info("Reranker: LiteLLM provider initialized")
+
+ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+ """
+ Score query-document pairs using the LiteLLM proxy's /rerank endpoint.
+
+ Args:
+ pairs: List of (query, document) tuples to score
+
+ Returns:
+ List of relevance scores
+ """
+ if self._async_client is None:
+ raise RuntimeError("Reranker not initialized. Call initialize() first.")
+
+ if not pairs:
+ return []
+
+ # Group pairs by query (LiteLLM rerank expects one query with multiple documents)
+ query_groups: dict[str, list[tuple[int, str]]] = {}
+ for idx, (query, text) in enumerate(pairs):
+ if query not in query_groups:
+ query_groups[query] = []
+ query_groups[query].append((idx, text))
+
+ all_scores = [0.0] * len(pairs)
+
+ for query, indexed_texts in query_groups.items():
+ texts = [text for _, text in indexed_texts]
+ indices = [idx for idx, _ in indexed_texts]
+
+ # LiteLLM /rerank follows Cohere API format
+ response = await self._async_client.post(
+ f"{self.api_base}/rerank",
+ json={
+ "model": self.model,
+ "query": query,
+ "documents": texts,
+ "top_n": len(texts), # Return all scores
+ },
+ )
+ response.raise_for_status()
+ result = response.json()
+
+ # Map scores back to original positions
+ # Response format: {"results": [{"index": 0, "relevance_score": 0.9}, ...]}
+ for item in result.get("results", []):
+ original_idx = item["index"]
+ score = item.get("relevance_score", item.get("score", 0.0))
+ all_scores[indices[original_idx]] = score

  return all_scores

@@ -293,10 +781,35 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
  url = os.environ.get(ENV_RERANKER_TEI_URL)
  if not url:
  raise ValueError(f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'")
- return RemoteTEICrossEncoder(base_url=url)
+ batch_size = int(os.environ.get(ENV_RERANKER_TEI_BATCH_SIZE, str(DEFAULT_RERANKER_TEI_BATCH_SIZE)))
+ max_concurrent = int(os.environ.get(ENV_RERANKER_TEI_MAX_CONCURRENT, str(DEFAULT_RERANKER_TEI_MAX_CONCURRENT)))
+ return RemoteTEICrossEncoder(base_url=url, batch_size=batch_size, max_concurrent=max_concurrent)
  elif provider == "local":
  model = os.environ.get(ENV_RERANKER_LOCAL_MODEL)
  model_name = model or DEFAULT_RERANKER_LOCAL_MODEL
- return LocalSTCrossEncoder(model_name=model_name)
+ max_concurrent = int(
+ os.environ.get(ENV_RERANKER_LOCAL_MAX_CONCURRENT, str(DEFAULT_RERANKER_LOCAL_MAX_CONCURRENT))
+ )
+ return LocalSTCrossEncoder(model_name=model_name, max_concurrent=max_concurrent)
+ elif provider == "cohere":
+ api_key = os.environ.get(ENV_COHERE_API_KEY)
+ if not api_key:
+ raise ValueError(f"{ENV_COHERE_API_KEY} is required when {ENV_RERANKER_PROVIDER} is 'cohere'")
+ model = os.environ.get(ENV_RERANKER_COHERE_MODEL, DEFAULT_RERANKER_COHERE_MODEL)
+ base_url = os.environ.get(ENV_RERANKER_COHERE_BASE_URL) or None
+ return CohereCrossEncoder(api_key=api_key, model=model, base_url=base_url)
+ elif provider == "flashrank":
+ model = os.environ.get(ENV_RERANKER_FLASHRANK_MODEL, DEFAULT_RERANKER_FLASHRANK_MODEL)
+ cache_dir = os.environ.get(ENV_RERANKER_FLASHRANK_CACHE_DIR, DEFAULT_RERANKER_FLASHRANK_CACHE_DIR)
+ return FlashRankCrossEncoder(model_name=model, cache_dir=cache_dir)
+ elif provider == "litellm":
+ api_base = os.environ.get(ENV_LITELLM_API_BASE, DEFAULT_LITELLM_API_BASE)
+ api_key = os.environ.get(ENV_LITELLM_API_KEY)
+ model = os.environ.get(ENV_RERANKER_LITELLM_MODEL, DEFAULT_RERANKER_LITELLM_MODEL)
+ return LiteLLMCrossEncoder(api_base=api_base, api_key=api_key, model=model)
+ elif provider == "rrf":
+ return RRFPassthroughCrossEncoder()
  else:
- raise ValueError(f"Unknown reranker provider: {provider}. Supported: 'local', 'tei'")
+ raise ValueError(
+ f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'flashrank', 'litellm', 'rrf'"
+ )
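
With the factory now covering six providers ('local', 'tei', 'cohere', 'flashrank', 'litellm', 'rrf'), end-to-end use stays a three-step flow: build from the environment, initialize, then await predict(). A sketch, assuming the provider choice is read from the environment variable named by the ENV_RERANKER_PROVIDER constant, as the error messages above imply:

import asyncio
import os

from hindsight_api.config import ENV_RERANKER_PROVIDER
from hindsight_api.engine.cross_encoder import create_cross_encoder_from_env

async def main() -> None:
    # 'rrf' skips neural reranking; the other values require their
    # provider-specific variables (TEI URL, Cohere API key, and so on).
    os.environ[ENV_RERANKER_PROVIDER] = "rrf"

    reranker = create_cross_encoder_from_env()
    await reranker.initialize()
    scores = await reranker.predict([("what changed in 0.3.0?", "predict() is now async.")])
    print(reranker.provider_name, scores)

asyncio.run(main())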