hindsight-api 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (88)
  1. hindsight_api/admin/__init__.py +1 -0
  2. hindsight_api/admin/cli.py +311 -0
  3. hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
  4. hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
  5. hindsight_api/alembic/versions/h3c4d5e6f7g8_mental_models_v4.py +112 -0
  6. hindsight_api/alembic/versions/i4d5e6f7g8h9_delete_opinions.py +41 -0
  7. hindsight_api/alembic/versions/j5e6f7g8h9i0_mental_model_versions.py +95 -0
  8. hindsight_api/alembic/versions/k6f7g8h9i0j1_add_directive_subtype.py +58 -0
  9. hindsight_api/alembic/versions/l7g8h9i0j1k2_add_worker_columns.py +109 -0
  10. hindsight_api/alembic/versions/m8h9i0j1k2l3_mental_model_id_to_text.py +41 -0
  11. hindsight_api/alembic/versions/n9i0j1k2l3m4_learnings_and_pinned_reflections.py +134 -0
  12. hindsight_api/alembic/versions/o0j1k2l3m4n5_migrate_mental_models_data.py +113 -0
  13. hindsight_api/alembic/versions/p1k2l3m4n5o6_new_knowledge_architecture.py +194 -0
  14. hindsight_api/alembic/versions/q2l3m4n5o6p7_fix_mental_model_fact_type.py +50 -0
  15. hindsight_api/alembic/versions/r3m4n5o6p7q8_add_reflect_response_to_reflections.py +47 -0
  16. hindsight_api/alembic/versions/s4n5o6p7q8r9_add_consolidated_at_to_memory_units.py +53 -0
  17. hindsight_api/alembic/versions/t5o6p7q8r9s0_rename_mental_models_to_observations.py +134 -0
  18. hindsight_api/alembic/versions/u6p7q8r9s0t1_mental_models_text_id.py +41 -0
  19. hindsight_api/alembic/versions/v7q8r9s0t1u2_add_max_tokens_to_mental_models.py +50 -0
  20. hindsight_api/api/http.py +1406 -118
  21. hindsight_api/api/mcp.py +11 -196
  22. hindsight_api/config.py +359 -27
  23. hindsight_api/engine/consolidation/__init__.py +5 -0
  24. hindsight_api/engine/consolidation/consolidator.py +859 -0
  25. hindsight_api/engine/consolidation/prompts.py +69 -0
  26. hindsight_api/engine/cross_encoder.py +706 -88
  27. hindsight_api/engine/db_budget.py +284 -0
  28. hindsight_api/engine/db_utils.py +11 -0
  29. hindsight_api/engine/directives/__init__.py +5 -0
  30. hindsight_api/engine/directives/models.py +37 -0
  31. hindsight_api/engine/embeddings.py +553 -29
  32. hindsight_api/engine/entity_resolver.py +8 -5
  33. hindsight_api/engine/interface.py +40 -17
  34. hindsight_api/engine/llm_wrapper.py +744 -68
  35. hindsight_api/engine/memory_engine.py +2505 -1017
  36. hindsight_api/engine/mental_models/__init__.py +14 -0
  37. hindsight_api/engine/mental_models/models.py +53 -0
  38. hindsight_api/engine/query_analyzer.py +4 -3
  39. hindsight_api/engine/reflect/__init__.py +18 -0
  40. hindsight_api/engine/reflect/agent.py +933 -0
  41. hindsight_api/engine/reflect/models.py +109 -0
  42. hindsight_api/engine/reflect/observations.py +186 -0
  43. hindsight_api/engine/reflect/prompts.py +483 -0
  44. hindsight_api/engine/reflect/tools.py +437 -0
  45. hindsight_api/engine/reflect/tools_schema.py +250 -0
  46. hindsight_api/engine/response_models.py +168 -4
  47. hindsight_api/engine/retain/bank_utils.py +79 -201
  48. hindsight_api/engine/retain/fact_extraction.py +424 -195
  49. hindsight_api/engine/retain/fact_storage.py +35 -12
  50. hindsight_api/engine/retain/link_utils.py +29 -24
  51. hindsight_api/engine/retain/orchestrator.py +24 -43
  52. hindsight_api/engine/retain/types.py +11 -2
  53. hindsight_api/engine/search/graph_retrieval.py +43 -14
  54. hindsight_api/engine/search/link_expansion_retrieval.py +391 -0
  55. hindsight_api/engine/search/mpfp_retrieval.py +362 -117
  56. hindsight_api/engine/search/reranking.py +2 -2
  57. hindsight_api/engine/search/retrieval.py +848 -201
  58. hindsight_api/engine/search/tags.py +172 -0
  59. hindsight_api/engine/search/think_utils.py +42 -141
  60. hindsight_api/engine/search/trace.py +12 -1
  61. hindsight_api/engine/search/tracer.py +26 -6
  62. hindsight_api/engine/search/types.py +21 -3
  63. hindsight_api/engine/task_backend.py +113 -106
  64. hindsight_api/engine/utils.py +1 -152
  65. hindsight_api/extensions/__init__.py +10 -1
  66. hindsight_api/extensions/builtin/tenant.py +5 -1
  67. hindsight_api/extensions/context.py +10 -1
  68. hindsight_api/extensions/operation_validator.py +81 -4
  69. hindsight_api/extensions/tenant.py +26 -0
  70. hindsight_api/main.py +69 -6
  71. hindsight_api/mcp_local.py +12 -53
  72. hindsight_api/mcp_tools.py +494 -0
  73. hindsight_api/metrics.py +433 -48
  74. hindsight_api/migrations.py +141 -1
  75. hindsight_api/models.py +3 -3
  76. hindsight_api/pg0.py +53 -0
  77. hindsight_api/server.py +39 -2
  78. hindsight_api/worker/__init__.py +11 -0
  79. hindsight_api/worker/main.py +296 -0
  80. hindsight_api/worker/poller.py +486 -0
  81. {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/METADATA +16 -6
  82. hindsight_api-0.4.0.dist-info/RECORD +112 -0
  83. {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/entry_points.txt +2 -0
  84. hindsight_api/engine/retain/observation_regeneration.py +0 -254
  85. hindsight_api/engine/search/observation_utils.py +0 -125
  86. hindsight_api/engine/search/scoring.py +0 -159
  87. hindsight_api-0.2.1.dist-info/RECORD +0 -75
  88. {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/WHEEL +0 -0
hindsight_api/engine/cross_encoder.py

@@ -6,17 +6,38 @@ Provides an interface for reranking with different backends.
 Configuration via environment variables - see hindsight_api.config for all env var names.
 """

+import asyncio
 import logging
 import os
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor

 import httpx

 from ..config import (
+    DEFAULT_LITELLM_API_BASE,
+    DEFAULT_RERANKER_COHERE_MODEL,
+    DEFAULT_RERANKER_FLASHRANK_CACHE_DIR,
+    DEFAULT_RERANKER_FLASHRANK_MODEL,
+    DEFAULT_RERANKER_LITELLM_MODEL,
+    DEFAULT_RERANKER_LOCAL_MAX_CONCURRENT,
     DEFAULT_RERANKER_LOCAL_MODEL,
     DEFAULT_RERANKER_PROVIDER,
+    DEFAULT_RERANKER_TEI_BATCH_SIZE,
+    DEFAULT_RERANKER_TEI_MAX_CONCURRENT,
+    ENV_COHERE_API_KEY,
+    ENV_LITELLM_API_BASE,
+    ENV_LITELLM_API_KEY,
+    ENV_RERANKER_COHERE_BASE_URL,
+    ENV_RERANKER_COHERE_MODEL,
+    ENV_RERANKER_FLASHRANK_CACHE_DIR,
+    ENV_RERANKER_FLASHRANK_MODEL,
+    ENV_RERANKER_LITELLM_MODEL,
+    ENV_RERANKER_LOCAL_MAX_CONCURRENT,
     ENV_RERANKER_LOCAL_MODEL,
     ENV_RERANKER_PROVIDER,
+    ENV_RERANKER_TEI_BATCH_SIZE,
+    ENV_RERANKER_TEI_MAX_CONCURRENT,
     ENV_RERANKER_TEI_URL,
 )

@@ -47,7 +68,7 @@ class CrossEncoderModel(ABC):
         pass

     @abstractmethod
-    def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
         """
         Score query-document pairs for relevance.

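The hunk above is the central interface change in 0.4.0: CrossEncoderModel.predict is now an async coroutine, so every backend implements it with async def and every caller must await it. As a rough, hypothetical sketch (not part of the package), a minimal custom backend conforming to the new contract could look like the following, assuming the base class also exposes the provider_name property and the async initialize() seen in the built-in providers:

    import asyncio

    from hindsight_api.engine.cross_encoder import CrossEncoderModel


    class LengthPriorCrossEncoder(CrossEncoderModel):
        """Toy backend that scores by document length; only illustrates the async contract."""

        @property
        def provider_name(self) -> str:
            return "length-prior"

        async def initialize(self) -> None:
            pass  # nothing to load for this toy backend

        async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
            # Real backends await network I/O here or offload CPU-bound scoring to a
            # thread pool, as the local and TEI providers further down in this diff do.
            return [min(len(doc) / 1000.0, 1.0) for _, doc in pairs]


    async def demo() -> None:
        model = LengthPriorCrossEncoder()
        await model.initialize()
        print(await model.predict([("query", "a short document")]))


    asyncio.run(demo())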
@@ -70,25 +91,34 @@ class LocalSTCrossEncoder(CrossEncoderModel):
     - Fast inference (~80ms for 100 pairs on CPU)
     - Small model (80MB)
     - Trained for passage re-ranking
+
+    Uses a dedicated thread pool to limit concurrent CPU-bound work.
     """

-    def __init__(self, model_name: str | None = None):
+    # Shared executor across all instances (one model loaded anyway)
+    _executor: ThreadPoolExecutor | None = None
+    _max_concurrent: int = 4  # Limit concurrent CPU-bound reranking calls
+
+    def __init__(self, model_name: str | None = None, max_concurrent: int = 4):
         """
         Initialize local SentenceTransformers cross-encoder.

         Args:
             model_name: Name of the CrossEncoder model to use.
                 Default: cross-encoder/ms-marco-MiniLM-L-6-v2
+            max_concurrent: Maximum concurrent reranking calls (default: 2).
+                Higher values may cause CPU thrashing under load.
         """
         self.model_name = model_name or DEFAULT_RERANKER_LOCAL_MODEL
         self._model = None
+        LocalSTCrossEncoder._max_concurrent = max_concurrent

     @property
     def provider_name(self) -> str:
         return "local"

     async def initialize(self) -> None:
-        """Load the cross-encoder model."""
+        """Load the cross-encoder model and initialize the executor."""
         if self._model is not None:
             return

@@ -101,13 +131,134 @@ class LocalSTCrossEncoder(CrossEncoderModel):
             )

         logger.info(f"Reranker: initializing local provider with model {self.model_name}")
-        self._model = CrossEncoder(self.model_name)
-        logger.info("Reranker: local provider initialized")

-    def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+        # Determine device based on hardware availability.
+        # We always set low_cpu_mem_usage=False to prevent lazy loading (meta tensors)
+        # which can cause issues when accelerate is installed but no GPU is available.
+        # Note: We do NOT use device_map because CrossEncoder internally calls .to(device)
+        # after loading, which conflicts with accelerate's device_map handling.
+        import torch
+
+        # Check for GPU (CUDA) or Apple Silicon (MPS)
+        has_gpu = torch.cuda.is_available() or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
+
+        if has_gpu:
+            device = None  # Let sentence-transformers auto-detect GPU/MPS
+        else:
+            device = "cpu"
+
+        self._model = CrossEncoder(
+            self.model_name,
+            device=device,
+            model_kwargs={"low_cpu_mem_usage": False},
+        )
+
+        # Initialize shared executor (limited workers naturally limits concurrency)
+        if LocalSTCrossEncoder._executor is None:
+            LocalSTCrossEncoder._executor = ThreadPoolExecutor(
+                max_workers=LocalSTCrossEncoder._max_concurrent,
+                thread_name_prefix="reranker",
+            )
+            logger.info(f"Reranker: local provider initialized (max_concurrent={LocalSTCrossEncoder._max_concurrent})")
+        else:
+            logger.info("Reranker: local provider initialized (using existing executor)")
+
+    def _is_xpc_error(self, error: Exception) -> bool:
+        """
+        Check if an error is an XPC connection error (macOS daemon issue).
+
+        On macOS, long-running daemons can lose XPC connections to system services
+        when the process is idle for extended periods.
+        """
+        error_str = str(error).lower()
+        return "xpc_error_connection_invalid" in error_str or "xpc error" in error_str
+
+    def _reinitialize_model_sync(self) -> None:
+        """
+        Clear and reinitialize the cross-encoder model synchronously.
+
+        This is used to recover from XPC errors on macOS where the
+        PyTorch/MPS backend loses its connection to system services.
+        """
+        logger.warning(f"Reinitializing reranker model {self.model_name} due to backend error")
+
+        # Clear existing model
+        self._model = None
+
+        # Force garbage collection to free resources
+        import gc
+
+        import torch
+
+        gc.collect()
+
+        # If using CUDA/MPS, clear the cache
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            try:
+                torch.mps.empty_cache()
+            except AttributeError:
+                pass  # Method might not exist in all PyTorch versions
+
+        # Reinitialize the model
+        try:
+            from sentence_transformers import CrossEncoder
+        except ImportError:
+            raise ImportError(
+                "sentence-transformers is required for LocalSTCrossEncoder. "
+                "Install it with: pip install sentence-transformers"
+            )
+
+        # Determine device based on hardware availability
+        has_gpu = torch.cuda.is_available() or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
+
+        if has_gpu:
+            device = None  # Let sentence-transformers auto-detect GPU/MPS
+        else:
+            device = "cpu"
+
+        self._model = CrossEncoder(
+            self.model_name,
+            device=device,
+            model_kwargs={"low_cpu_mem_usage": False},
+        )
+
+        logger.info("Reranker: local provider reinitialized successfully")
+
+    def _predict_with_recovery(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """
+        Predict with automatic recovery from XPC errors.
+
+        This runs synchronously in the thread pool.
+        """
+        max_retries = 1
+        for attempt in range(max_retries + 1):
+            try:
+                scores = self._model.predict(pairs, show_progress_bar=False)
+                return scores.tolist() if hasattr(scores, "tolist") else list(scores)
+            except Exception as e:
+                # Check if this is an XPC error (macOS daemon issue)
+                if self._is_xpc_error(e) and attempt < max_retries:
+                    logger.warning(f"XPC error detected in reranker (attempt {attempt + 1}): {e}")
+                    try:
+                        self._reinitialize_model_sync()
+                        logger.info("Reranker reinitialized successfully, retrying prediction")
+                        continue
+                    except Exception as reinit_error:
+                        logger.error(f"Failed to reinitialize reranker: {reinit_error}")
+                        raise Exception(f"Failed to recover from XPC error: {str(e)}")
+                else:
+                    # Not an XPC error or out of retries
+                    raise
+
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
         """
         Score query-document pairs for relevance.

+        Uses a dedicated thread pool with limited workers to prevent CPU thrashing.
+        Automatically recovers from XPC errors on macOS by reinitializing the model.
+
         Args:
             pairs: List of (query, document) tuples to score

@@ -116,8 +267,14 @@ class LocalSTCrossEncoder(CrossEncoderModel):
         """
         if self._model is None:
             raise RuntimeError("Reranker not initialized. Call initialize() first.")
-        scores = self._model.predict(pairs, show_progress_bar=False)
-        return scores.tolist() if hasattr(scores, "tolist") else list(scores)
+
+        # Use dedicated executor - limited workers naturally limits concurrency
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(
+            LocalSTCrossEncoder._executor,
+            self._predict_with_recovery,
+            pairs,
+        )


 class RemoteTEICrossEncoder(CrossEncoderModel):
@@ -128,13 +285,21 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
     See: https://github.com/huggingface/text-embeddings-inference

     Note: The TEI server must be running a cross-encoder/reranker model.
+
+    Requests are made in parallel with configurable batch size and max concurrency (backpressure).
+    Uses a GLOBAL semaphore to limit concurrent requests across ALL recall operations.
     """

+    # Global semaphore shared across all instances and calls to prevent thundering herd
+    _global_semaphore: asyncio.Semaphore | None = None
+    _global_max_concurrent: int = DEFAULT_RERANKER_TEI_MAX_CONCURRENT
+
     def __init__(
         self,
         base_url: str,
         timeout: float = 30.0,
-        batch_size: int = 32,
+        batch_size: int = DEFAULT_RERANKER_TEI_BATCH_SIZE,
+        max_concurrent: int = DEFAULT_RERANKER_TEI_MAX_CONCURRENT,
         max_retries: int = 3,
         retry_delay: float = 0.5,
     ):
@@ -144,80 +309,246 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
         Args:
             base_url: Base URL of the TEI server (e.g., "http://localhost:8080")
             timeout: Request timeout in seconds (default: 30.0)
-            batch_size: Maximum batch size for rerank requests (default: 32)
+            batch_size: Maximum batch size for rerank requests (default: 128)
+            max_concurrent: Maximum concurrent requests for backpressure (default: 8).
+                This is a GLOBAL limit across all parallel recall operations.
             max_retries: Maximum number of retries for failed requests (default: 3)
             retry_delay: Initial delay between retries in seconds, doubles each retry (default: 0.5)
         """
         self.base_url = base_url.rstrip("/")
         self.timeout = timeout
         self.batch_size = batch_size
+        self.max_concurrent = max_concurrent
         self.max_retries = max_retries
         self.retry_delay = retry_delay
-        self._client: httpx.Client | None = None
+        self._async_client: httpx.AsyncClient | None = None
         self._model_id: str | None = None

+        # Update global semaphore if max_concurrent changed
+        if (
+            RemoteTEICrossEncoder._global_semaphore is None
+            or RemoteTEICrossEncoder._global_max_concurrent != max_concurrent
+        ):
+            RemoteTEICrossEncoder._global_max_concurrent = max_concurrent
+            RemoteTEICrossEncoder._global_semaphore = asyncio.Semaphore(max_concurrent)
+
     @property
     def provider_name(self) -> str:
         return "tei"

-    def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
-        """Make an HTTP request with automatic retries on transient errors."""
-        import time
-
+    async def _async_request_with_retry(
+        self,
+        client: httpx.AsyncClient,
+        semaphore: asyncio.Semaphore,
+        method: str,
+        url: str,
+        **kwargs,
+    ) -> httpx.Response:
+        """Make an async HTTP request with automatic retries on transient errors and semaphore for backpressure."""
         last_error = None
         delay = self.retry_delay

-        for attempt in range(self.max_retries + 1):
-            try:
-                if method == "GET":
-                    response = self._client.get(url, **kwargs)
-                else:
-                    response = self._client.post(url, **kwargs)
-                response.raise_for_status()
-                return response
-            except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
-                last_error = e
-                if attempt < self.max_retries:
-                    logger.warning(
-                        f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
-                    )
-                    time.sleep(delay)
-                    delay *= 2  # Exponential backoff
-            except httpx.HTTPStatusError as e:
-                # Retry on 5xx server errors
-                if e.response.status_code >= 500 and attempt < self.max_retries:
+        async with semaphore:
+            for attempt in range(self.max_retries + 1):
+                try:
+                    if method == "GET":
+                        response = await client.get(url, **kwargs)
+                    else:
+                        response = await client.post(url, **kwargs)
+                    response.raise_for_status()
+                    return response
+                except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
                     last_error = e
-                    logger.warning(
-                        f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s..."
-                    )
-                    time.sleep(delay)
-                    delay *= 2
-                else:
-                    raise
+                    if attempt < self.max_retries:
+                        logger.warning(
+                            f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. "
+                            f"Retrying in {delay}s..."
+                        )
+                        await asyncio.sleep(delay)
+                        delay *= 2  # Exponential backoff
+                except httpx.HTTPStatusError as e:
+                    # Retry on 5xx server errors
+                    if e.response.status_code >= 500 and attempt < self.max_retries:
+                        last_error = e
+                        logger.warning(
+                            f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. "
+                            f"Retrying in {delay}s..."
+                        )
+                        await asyncio.sleep(delay)
+                        delay *= 2
+                    else:
+                        raise

         raise last_error

     async def initialize(self) -> None:
         """Initialize the HTTP client and verify server connectivity."""
-        if self._client is not None:
+        if self._async_client is not None:
             return

-        logger.info(f"Reranker: initializing TEI provider at {self.base_url}")
-        self._client = httpx.Client(timeout=self.timeout)
+        logger.info(
+            f"Reranker: initializing TEI provider at {self.base_url} "
+            f"(batch_size={self.batch_size}, max_concurrent={self.max_concurrent})"
+        )
+        self._async_client = httpx.AsyncClient(timeout=self.timeout)

         # Verify server is reachable and get model info
+        # Use a temporary semaphore for initialization
+        init_semaphore = asyncio.Semaphore(1)
         try:
-            response = self._request_with_retry("GET", f"{self.base_url}/info")
+            response = await self._async_request_with_retry(
+                self._async_client, init_semaphore, "GET", f"{self.base_url}/info"
+            )
             info = response.json()
             self._model_id = info.get("model_id", "unknown")
             logger.info(f"Reranker: TEI provider initialized (model: {self._model_id})")
         except httpx.HTTPError as e:
+            self._async_client = None
             raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")

-    def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+    async def _rerank_query_group(
+        self,
+        client: httpx.AsyncClient,
+        semaphore: asyncio.Semaphore,
+        query: str,
+        texts: list[str],
+    ) -> list[tuple[int, float]]:
+        """Rerank a single query group and return list of (original_index, score) tuples."""
+        try:
+            response = await self._async_request_with_retry(
+                client,
+                semaphore,
+                "POST",
+                f"{self.base_url}/rerank",
+                json={
+                    "query": query,
+                    "texts": texts,
+                    "return_text": False,
+                },
+            )
+            results = response.json()
+            # TEI returns results sorted by score descending, with original index
+            return [(result["index"], result["score"]) for result in results]
+        except httpx.HTTPError as e:
+            raise RuntimeError(f"TEI rerank request failed: {e}")
+
+    async def _predict_async(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """Async implementation of predict that runs requests in parallel with backpressure."""
+        if not pairs:
+            return []
+
+        # Group all pairs by query
+        query_groups: dict[str, list[tuple[int, str]]] = {}
+        for idx, (query, text) in enumerate(pairs):
+            if query not in query_groups:
+                query_groups[query] = []
+            query_groups[query].append((idx, text))
+
+        # Split each query group into batches
+        tasks_info: list[tuple[str, list[int], list[str]]] = []  # (query, indices, texts)
+        for query, indexed_texts in query_groups.items():
+            indices = [idx for idx, _ in indexed_texts]
+            texts = [text for _, text in indexed_texts]
+
+            # Split into batches
+            for i in range(0, len(texts), self.batch_size):
+                batch_indices = indices[i : i + self.batch_size]
+                batch_texts = texts[i : i + self.batch_size]
+                tasks_info.append((query, batch_indices, batch_texts))
+
+        # Run all requests in parallel with GLOBAL semaphore for backpressure
+        # This ensures max_concurrent is respected across ALL parallel recall operations
+        all_scores = [0.0] * len(pairs)
+        semaphore = RemoteTEICrossEncoder._global_semaphore
+
+        tasks = [
+            self._rerank_query_group(self._async_client, semaphore, query, texts) for query, _, texts in tasks_info
+        ]
+        results = await asyncio.gather(*tasks)
+
+        # Map scores back to original positions
+        for (_, indices, _), result_scores in zip(tasks_info, results):
+            for original_idx_in_batch, score in result_scores:
+                global_idx = indices[original_idx_in_batch]
+                all_scores[global_idx] = score
+
+        return all_scores
+
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
         """
         Score query-document pairs using the remote TEI reranker.

+        Requests are made in parallel with configurable backpressure.
+
+        Args:
+            pairs: List of (query, document) tuples to score
+
+        Returns:
+            List of relevance scores
+        """
+        if self._async_client is None:
+            raise RuntimeError("Reranker not initialized. Call initialize() first.")
+
+        return await self._predict_async(pairs)
+
+
+class CohereCrossEncoder(CrossEncoderModel):
+    """
+    Cohere cross-encoder implementation using the Cohere Rerank API.
+
+    Supports rerank-english-v3.0 and rerank-multilingual-v3.0 models.
+    """
+
+    def __init__(
+        self,
+        api_key: str,
+        model: str = DEFAULT_RERANKER_COHERE_MODEL,
+        base_url: str | None = None,
+        timeout: float = 60.0,
+    ):
+        """
+        Initialize Cohere cross-encoder client.
+
+        Args:
+            api_key: Cohere API key
+            model: Cohere rerank model name (default: rerank-english-v3.0)
+            base_url: Custom base URL for Cohere-compatible API (e.g., Azure-hosted endpoint)
+            timeout: Request timeout in seconds (default: 60.0)
+        """
+        self.api_key = api_key
+        self.model = model
+        self.base_url = base_url
+        self.timeout = timeout
+        self._client = None
+
+    @property
+    def provider_name(self) -> str:
+        return "cohere"
+
+    async def initialize(self) -> None:
+        """Initialize the Cohere client."""
+        if self._client is not None:
+            return
+
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError("cohere is required for CohereCrossEncoder. Install it with: pip install cohere")
+
+        base_url_msg = f" at {self.base_url}" if self.base_url else ""
+        logger.info(f"Reranker: initializing Cohere provider with model {self.model}{base_url_msg}")
+
+        # Build client kwargs, only including base_url if set (for Azure or custom endpoints)
+        client_kwargs = {"api_key": self.api_key, "timeout": self.timeout}
+        if self.base_url:
+            client_kwargs["base_url"] = self.base_url
+        self._client = cohere.Client(**client_kwargs)
+        logger.info("Reranker: Cohere provider initialized")
+
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """
+        Score query-document pairs using the Cohere Rerank API.
+
         Args:
             pairs: List of (query, document) tuples to score

@@ -230,50 +561,312 @@ class RemoteTEICrossEncoder(CrossEncoderModel):
         if not pairs:
             return []

-        all_scores = []
+        # Run sync Cohere API calls in thread pool
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self._predict_sync, pairs)
+
+    def _predict_sync(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """Synchronous predict implementation for Cohere API."""
+        # Group pairs by query for efficient batching
+        # Cohere rerank expects one query with multiple documents
+        query_groups: dict[str, list[tuple[int, str]]] = {}
+        for idx, (query, text) in enumerate(pairs):
+            if query not in query_groups:
+                query_groups[query] = []
+            query_groups[query].append((idx, text))
+
+        all_scores = [0.0] * len(pairs)
+
+        for query, indexed_texts in query_groups.items():
+            texts = [text for _, text in indexed_texts]
+            indices = [idx for idx, _ in indexed_texts]
+
+            response = self._client.rerank(
+                query=query,
+                documents=texts,
+                model=self.model,
+                return_documents=False,
+            )

-        # Process in batches
-        for i in range(0, len(pairs), self.batch_size):
-            batch = pairs[i : i + self.batch_size]
+            # Map scores back to original positions
+            for result in response.results:
+                original_idx = result.index
+                score = result.relevance_score
+                all_scores[indices[original_idx]] = score

-            # TEI rerank endpoint expects query and texts separately
-            # All pairs in a batch should have the same query for optimal performance
-            # but we handle mixed queries by making separate requests per unique query
-            query_groups: dict[str, list[tuple[int, str]]] = {}
-            for idx, (query, text) in enumerate(batch):
-                if query not in query_groups:
-                    query_groups[query] = []
-                query_groups[query].append((idx, text))
+        return all_scores

-            batch_scores = [0.0] * len(batch)

-            for query, indexed_texts in query_groups.items():
-                texts = [text for _, text in indexed_texts]
-                indices = [idx for idx, _ in indexed_texts]
+class RRFPassthroughCrossEncoder(CrossEncoderModel):
+    """
+    Passthrough cross-encoder that preserves RRF scores without neural reranking.

-                try:
-                    response = self._request_with_retry(
-                        "POST",
-                        f"{self.base_url}/rerank",
-                        json={
-                            "query": query,
-                            "texts": texts,
-                            "return_text": False,
-                        },
-                    )
-                    results = response.json()
-
-                    # TEI returns results sorted by score descending, with original index
-                    for result in results:
-                        original_idx = result["index"]
-                        score = result["score"]
-                        # Map back to batch position
-                        batch_scores[indices[original_idx]] = score
-
-                except httpx.HTTPError as e:
-                    raise RuntimeError(f"TEI rerank request failed: {e}")
-
-            all_scores.extend(batch_scores)
+    This is useful for:
+    - Testing retrieval quality without reranking overhead
+    - Deployments where reranking latency is unacceptable
+    - Debugging to isolate retrieval vs reranking issues
+    """
+
+    def __init__(self):
+        """Initialize RRF passthrough cross-encoder."""
+        pass
+
+    @property
+    def provider_name(self) -> str:
+        return "rrf"
+
+    async def initialize(self) -> None:
+        """No initialization needed."""
+        logger.info("Reranker: RRF passthrough provider initialized (neural reranking disabled)")
+
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """
+        Return neutral scores - actual ranking uses RRF scores from retrieval.
+
+        Args:
+            pairs: List of (query, document) tuples (ignored)
+
+        Returns:
+            List of 0.5 scores (neutral, lets RRF scores dominate)
+        """
+        # Return neutral scores so RRF ranking is preserved
+        return [0.5] * len(pairs)
+
+
+class FlashRankCrossEncoder(CrossEncoderModel):
+    """
+    FlashRank cross-encoder implementation.
+
+    FlashRank is an ultra-lite reranking library that runs on CPU without
+    requiring PyTorch or Transformers. It's ideal for serverless deployments
+    with minimal cold-start overhead.
+
+    Available models:
+    - ms-marco-TinyBERT-L-2-v2: Fastest, ~4MB
+    - ms-marco-MiniLM-L-12-v2: Best quality, ~34MB (default)
+    - rank-T5-flan: Best zero-shot, ~110MB
+    - ms-marco-MultiBERT-L-12: Multi-lingual, ~150MB
+    """
+
+    # Shared executor for CPU-bound reranking
+    _executor: ThreadPoolExecutor | None = None
+    _max_concurrent: int = 4
+
+    def __init__(
+        self,
+        model_name: str | None = None,
+        cache_dir: str | None = None,
+        max_length: int = 512,
+        max_concurrent: int = 4,
+    ):
+        """
+        Initialize FlashRank cross-encoder.
+
+        Args:
+            model_name: FlashRank model name. Default: ms-marco-MiniLM-L-12-v2
+            cache_dir: Directory to cache downloaded models. Default: system cache
+            max_length: Maximum sequence length for reranking. Default: 512
+            max_concurrent: Maximum concurrent reranking calls. Default: 4
+        """
+        self.model_name = model_name or DEFAULT_RERANKER_FLASHRANK_MODEL
+        self.cache_dir = cache_dir or DEFAULT_RERANKER_FLASHRANK_CACHE_DIR
+        self.max_length = max_length
+        self._ranker = None
+        FlashRankCrossEncoder._max_concurrent = max_concurrent
+
+    @property
+    def provider_name(self) -> str:
+        return "flashrank"
+
+    async def initialize(self) -> None:
+        """Load the FlashRank model."""
+        if self._ranker is not None:
+            return
+
+        try:
+            from flashrank import Ranker  # type: ignore[import-untyped]
+        except ImportError:
+            raise ImportError("flashrank is required for FlashRankCrossEncoder. Install it with: pip install flashrank")
+
+        logger.info(f"Reranker: initializing FlashRank provider with model {self.model_name}")
+
+        # Initialize ranker with optional cache directory
+        ranker_kwargs = {"model_name": self.model_name, "max_length": self.max_length}
+        if self.cache_dir:
+            ranker_kwargs["cache_dir"] = self.cache_dir
+
+        self._ranker = Ranker(**ranker_kwargs)
+
+        # Initialize shared executor
+        if FlashRankCrossEncoder._executor is None:
+            FlashRankCrossEncoder._executor = ThreadPoolExecutor(
+                max_workers=FlashRankCrossEncoder._max_concurrent,
+                thread_name_prefix="flashrank",
+            )
+            logger.info(
+                f"Reranker: FlashRank provider initialized (max_concurrent={FlashRankCrossEncoder._max_concurrent})"
+            )
+        else:
+            logger.info("Reranker: FlashRank provider initialized (using existing executor)")
+
+    def _predict_sync(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """Synchronous predict - processes each query group."""
+        from flashrank import RerankRequest  # type: ignore[import-untyped]
+
+        if not pairs:
+            return []
+
+        # Group pairs by query
+        query_groups: dict[str, list[tuple[int, str]]] = {}
+        for idx, (query, text) in enumerate(pairs):
+            if query not in query_groups:
+                query_groups[query] = []
+            query_groups[query].append((idx, text))
+
+        all_scores = [0.0] * len(pairs)
+
+        for query, indexed_texts in query_groups.items():
+            # Build passages list for FlashRank
+            passages = [{"id": i, "text": text} for i, (_, text) in enumerate(indexed_texts)]
+            global_indices = [idx for idx, _ in indexed_texts]
+
+            # Create rerank request
+            request = RerankRequest(query=query, passages=passages)
+            results = self._ranker.rerank(request)
+
+            # Map scores back to original positions
+            for result in results:
+                local_idx = result["id"]
+                score = result["score"]
+                global_idx = global_indices[local_idx]
+                all_scores[global_idx] = score
+
+        return all_scores
+
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """
+        Score query-document pairs using FlashRank.
+
+        Args:
+            pairs: List of (query, document) tuples to score
+
+        Returns:
+            List of relevance scores (higher = more relevant)
+        """
+        if self._ranker is None:
+            raise RuntimeError("Reranker not initialized. Call initialize() first.")
+
+        # Run in thread pool to avoid blocking event loop
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(FlashRankCrossEncoder._executor, self._predict_sync, pairs)
+
+
+class LiteLLMCrossEncoder(CrossEncoderModel):
+    """
+    LiteLLM cross-encoder implementation using LiteLLM proxy's /rerank endpoint.
+
+    LiteLLM provides a unified interface for multiple reranking providers via
+    the Cohere-compatible /rerank endpoint.
+    See: https://docs.litellm.ai/docs/rerank
+
+    Supported providers via LiteLLM:
+    - Cohere (rerank-english-v3.0, etc.) - prefix with cohere/
+    - Together AI - prefix with together_ai/
+    - Azure AI - prefix with azure_ai/
+    - Jina AI - prefix with jina_ai/
+    - AWS Bedrock - prefix with bedrock/
+    - Voyage AI - prefix with voyage/
+    """
+
+    def __init__(
+        self,
+        api_base: str = DEFAULT_LITELLM_API_BASE,
+        api_key: str | None = None,
+        model: str = DEFAULT_RERANKER_LITELLM_MODEL,
+        timeout: float = 60.0,
+    ):
+        """
+        Initialize LiteLLM cross-encoder client.
+
+        Args:
+            api_base: Base URL of the LiteLLM proxy (default: http://localhost:4000)
+            api_key: API key for the LiteLLM proxy (optional, depends on proxy config)
+            model: Reranking model name (default: cohere/rerank-english-v3.0)
+                Use provider prefix (e.g., cohere/, together_ai/, voyage/)
+            timeout: Request timeout in seconds (default: 60.0)
+        """
+        self.api_base = api_base.rstrip("/")
+        self.api_key = api_key
+        self.model = model
+        self.timeout = timeout
+        self._async_client: httpx.AsyncClient | None = None
+
+    @property
+    def provider_name(self) -> str:
+        return "litellm"
+
+    async def initialize(self) -> None:
+        """Initialize the async HTTP client."""
+        if self._async_client is not None:
+            return
+
+        logger.info(f"Reranker: initializing LiteLLM provider at {self.api_base} with model {self.model}")
+
+        headers = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+
+        self._async_client = httpx.AsyncClient(timeout=self.timeout, headers=headers)
+        logger.info("Reranker: LiteLLM provider initialized")
+
+    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+        """
+        Score query-document pairs using the LiteLLM proxy's /rerank endpoint.
+
+        Args:
+            pairs: List of (query, document) tuples to score
+
+        Returns:
+            List of relevance scores
+        """
+        if self._async_client is None:
+            raise RuntimeError("Reranker not initialized. Call initialize() first.")
+
+        if not pairs:
+            return []
+
+        # Group pairs by query (LiteLLM rerank expects one query with multiple documents)
+        query_groups: dict[str, list[tuple[int, str]]] = {}
+        for idx, (query, text) in enumerate(pairs):
+            if query not in query_groups:
+                query_groups[query] = []
+            query_groups[query].append((idx, text))
+
+        all_scores = [0.0] * len(pairs)
+
+        for query, indexed_texts in query_groups.items():
+            texts = [text for _, text in indexed_texts]
+            indices = [idx for idx, _ in indexed_texts]
+
+            # LiteLLM /rerank follows Cohere API format
+            response = await self._async_client.post(
+                f"{self.api_base}/rerank",
+                json={
+                    "model": self.model,
+                    "query": query,
+                    "documents": texts,
+                    "top_n": len(texts),  # Return all scores
+                },
+            )
+            response.raise_for_status()
+            result = response.json()
+
+            # Map scores back to original positions
+            # Response format: {"results": [{"index": 0, "relevance_score": 0.9}, ...]}
+            for item in result.get("results", []):
+                original_idx = item["index"]
+                score = item.get("relevance_score", item.get("score", 0.0))
+                all_scores[indices[original_idx]] = score

         return all_scores

@@ -293,10 +886,35 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
         url = os.environ.get(ENV_RERANKER_TEI_URL)
         if not url:
             raise ValueError(f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'")
-        return RemoteTEICrossEncoder(base_url=url)
+        batch_size = int(os.environ.get(ENV_RERANKER_TEI_BATCH_SIZE, str(DEFAULT_RERANKER_TEI_BATCH_SIZE)))
+        max_concurrent = int(os.environ.get(ENV_RERANKER_TEI_MAX_CONCURRENT, str(DEFAULT_RERANKER_TEI_MAX_CONCURRENT)))
+        return RemoteTEICrossEncoder(base_url=url, batch_size=batch_size, max_concurrent=max_concurrent)
     elif provider == "local":
         model = os.environ.get(ENV_RERANKER_LOCAL_MODEL)
         model_name = model or DEFAULT_RERANKER_LOCAL_MODEL
-        return LocalSTCrossEncoder(model_name=model_name)
+        max_concurrent = int(
+            os.environ.get(ENV_RERANKER_LOCAL_MAX_CONCURRENT, str(DEFAULT_RERANKER_LOCAL_MAX_CONCURRENT))
+        )
+        return LocalSTCrossEncoder(model_name=model_name, max_concurrent=max_concurrent)
+    elif provider == "cohere":
+        api_key = os.environ.get(ENV_COHERE_API_KEY)
+        if not api_key:
+            raise ValueError(f"{ENV_COHERE_API_KEY} is required when {ENV_RERANKER_PROVIDER} is 'cohere'")
+        model = os.environ.get(ENV_RERANKER_COHERE_MODEL, DEFAULT_RERANKER_COHERE_MODEL)
+        base_url = os.environ.get(ENV_RERANKER_COHERE_BASE_URL) or None
+        return CohereCrossEncoder(api_key=api_key, model=model, base_url=base_url)
+    elif provider == "flashrank":
+        model = os.environ.get(ENV_RERANKER_FLASHRANK_MODEL, DEFAULT_RERANKER_FLASHRANK_MODEL)
+        cache_dir = os.environ.get(ENV_RERANKER_FLASHRANK_CACHE_DIR, DEFAULT_RERANKER_FLASHRANK_CACHE_DIR)
+        return FlashRankCrossEncoder(model_name=model, cache_dir=cache_dir)
+    elif provider == "litellm":
+        api_base = os.environ.get(ENV_LITELLM_API_BASE, DEFAULT_LITELLM_API_BASE)
+        api_key = os.environ.get(ENV_LITELLM_API_KEY)
+        model = os.environ.get(ENV_RERANKER_LITELLM_MODEL, DEFAULT_RERANKER_LITELLM_MODEL)
+        return LiteLLMCrossEncoder(api_base=api_base, api_key=api_key, model=model)
+    elif provider == "rrf":
+        return RRFPassthroughCrossEncoder()
     else:
-        raise ValueError(f"Unknown reranker provider: {provider}. Supported: 'local', 'tei'")
+        raise ValueError(
+            f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'flashrank', 'litellm', 'rrf'"
+        )
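With this last hunk the factory recognizes six providers: local, tei, cohere, flashrank, litellm, and rrf. The sketch below is a hedged usage example, not taken from the package documentation: it assumes the factory reads the provider name from the environment variable whose name is stored in the ENV_RERANKER_PROVIDER constant of hindsight_api.config (the literal variable names are not shown in this diff, so the sketch sets them through the config constants), and that the chosen provider's optional dependency (for example the flashrank package) is installed.

    # Hedged sketch: configure a reranker provider via the environment and score
    # pairs through the now-async predict(). Constant names come from
    # hindsight_api.config; the example provider choice is an assumption.
    import asyncio
    import os

    from hindsight_api import config
    from hindsight_api.engine.cross_encoder import create_cross_encoder_from_env


    async def main() -> None:
        # Pick a provider; "flashrank" requires `pip install flashrank`.
        os.environ[config.ENV_RERANKER_PROVIDER] = "flashrank"

        reranker = create_cross_encoder_from_env()
        await reranker.initialize()

        scores = await reranker.predict(
            [
                ("what is hindsight?", "Hindsight is a memory engine."),
                ("what is hindsight?", "Bananas are yellow."),
            ]
        )
        print(scores)  # higher score = more relevant document


    asyncio.run(main())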