hindsight-api 0.0.21__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +10 -2
- hindsight_api/alembic/README +1 -0
- hindsight_api/alembic/env.py +146 -0
- hindsight_api/alembic/script.py.mako +28 -0
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +274 -0
- hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +70 -0
- hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +39 -0
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +48 -0
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +62 -0
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +65 -0
- hindsight_api/api/__init__.py +2 -4
- hindsight_api/api/http.py +112 -164
- hindsight_api/api/mcp.py +2 -1
- hindsight_api/config.py +154 -0
- hindsight_api/engine/__init__.py +7 -2
- hindsight_api/engine/cross_encoder.py +225 -16
- hindsight_api/engine/embeddings.py +198 -19
- hindsight_api/engine/entity_resolver.py +56 -29
- hindsight_api/engine/llm_wrapper.py +147 -106
- hindsight_api/engine/memory_engine.py +337 -192
- hindsight_api/engine/response_models.py +15 -17
- hindsight_api/engine/retain/bank_utils.py +25 -35
- hindsight_api/engine/retain/entity_processing.py +5 -5
- hindsight_api/engine/retain/fact_extraction.py +86 -24
- hindsight_api/engine/retain/fact_storage.py +1 -1
- hindsight_api/engine/retain/link_creation.py +12 -6
- hindsight_api/engine/retain/link_utils.py +50 -56
- hindsight_api/engine/retain/observation_regeneration.py +264 -0
- hindsight_api/engine/retain/orchestrator.py +31 -44
- hindsight_api/engine/retain/types.py +14 -0
- hindsight_api/engine/search/reranking.py +6 -10
- hindsight_api/engine/search/retrieval.py +2 -2
- hindsight_api/engine/search/think_utils.py +59 -30
- hindsight_api/engine/search/tracer.py +1 -1
- hindsight_api/main.py +201 -0
- hindsight_api/migrations.py +61 -39
- hindsight_api/models.py +1 -2
- hindsight_api/pg0.py +17 -36
- hindsight_api/server.py +43 -0
- {hindsight_api-0.0.21.dist-info → hindsight_api-0.1.1.dist-info}/METADATA +2 -3
- hindsight_api-0.1.1.dist-info/RECORD +60 -0
- hindsight_api-0.1.1.dist-info/entry_points.txt +2 -0
- hindsight_api/cli.py +0 -128
- hindsight_api/web/__init__.py +0 -12
- hindsight_api/web/server.py +0 -109
- hindsight_api-0.0.21.dist-info/RECORD +0 -50
- hindsight_api-0.0.21.dist-info/entry_points.txt +0 -2
- {hindsight_api-0.0.21.dist-info → hindsight_api-0.1.1.dist-info}/WHEEL +0 -0
hindsight_api/config.py
ADDED
@@ -0,0 +1,154 @@
+"""
+Centralized configuration for Hindsight API.
+
+All environment variables and their defaults are defined here.
+"""
+import os
+from dataclasses import dataclass
+from typing import Optional
+import logging
+
+logger = logging.getLogger(__name__)
+
+# Environment variable names
+ENV_DATABASE_URL = "HINDSIGHT_API_DATABASE_URL"
+ENV_LLM_PROVIDER = "HINDSIGHT_API_LLM_PROVIDER"
+ENV_LLM_API_KEY = "HINDSIGHT_API_LLM_API_KEY"
+ENV_LLM_MODEL = "HINDSIGHT_API_LLM_MODEL"
+ENV_LLM_BASE_URL = "HINDSIGHT_API_LLM_BASE_URL"
+
+ENV_EMBEDDINGS_PROVIDER = "HINDSIGHT_API_EMBEDDINGS_PROVIDER"
+ENV_EMBEDDINGS_LOCAL_MODEL = "HINDSIGHT_API_EMBEDDINGS_LOCAL_MODEL"
+ENV_EMBEDDINGS_TEI_URL = "HINDSIGHT_API_EMBEDDINGS_TEI_URL"
+
+ENV_RERANKER_PROVIDER = "HINDSIGHT_API_RERANKER_PROVIDER"
+ENV_RERANKER_LOCAL_MODEL = "HINDSIGHT_API_RERANKER_LOCAL_MODEL"
+ENV_RERANKER_TEI_URL = "HINDSIGHT_API_RERANKER_TEI_URL"
+
+ENV_HOST = "HINDSIGHT_API_HOST"
+ENV_PORT = "HINDSIGHT_API_PORT"
+ENV_LOG_LEVEL = "HINDSIGHT_API_LOG_LEVEL"
+ENV_MCP_ENABLED = "HINDSIGHT_API_MCP_ENABLED"
+
+# Default values
+DEFAULT_DATABASE_URL = "pg0"
+DEFAULT_LLM_PROVIDER = "groq"
+DEFAULT_LLM_MODEL = "openai/gpt-oss-20b"
+
+DEFAULT_EMBEDDINGS_PROVIDER = "local"
+DEFAULT_EMBEDDINGS_LOCAL_MODEL = "BAAI/bge-small-en-v1.5"
+
+DEFAULT_RERANKER_PROVIDER = "local"
+DEFAULT_RERANKER_LOCAL_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+
+DEFAULT_HOST = "0.0.0.0"
+DEFAULT_PORT = 8888
+DEFAULT_LOG_LEVEL = "info"
+DEFAULT_MCP_ENABLED = True
+
+# Required embedding dimension for database schema
+EMBEDDING_DIMENSION = 384
+
+
+@dataclass
+class HindsightConfig:
+    """Configuration container for Hindsight API."""
+
+    # Database
+    database_url: str
+
+    # LLM
+    llm_provider: str
+    llm_api_key: Optional[str]
+    llm_model: str
+    llm_base_url: Optional[str]
+
+    # Embeddings
+    embeddings_provider: str
+    embeddings_local_model: str
+    embeddings_tei_url: Optional[str]
+
+    # Reranker
+    reranker_provider: str
+    reranker_local_model: str
+    reranker_tei_url: Optional[str]
+
+    # Server
+    host: str
+    port: int
+    log_level: str
+    mcp_enabled: bool
+
+    @classmethod
+    def from_env(cls) -> "HindsightConfig":
+        """Create configuration from environment variables."""
+        return cls(
+            # Database
+            database_url=os.getenv(ENV_DATABASE_URL, DEFAULT_DATABASE_URL),
+
+            # LLM
+            llm_provider=os.getenv(ENV_LLM_PROVIDER, DEFAULT_LLM_PROVIDER),
+            llm_api_key=os.getenv(ENV_LLM_API_KEY),
+            llm_model=os.getenv(ENV_LLM_MODEL, DEFAULT_LLM_MODEL),
+            llm_base_url=os.getenv(ENV_LLM_BASE_URL) or None,
+
+            # Embeddings
+            embeddings_provider=os.getenv(ENV_EMBEDDINGS_PROVIDER, DEFAULT_EMBEDDINGS_PROVIDER),
+            embeddings_local_model=os.getenv(ENV_EMBEDDINGS_LOCAL_MODEL, DEFAULT_EMBEDDINGS_LOCAL_MODEL),
+            embeddings_tei_url=os.getenv(ENV_EMBEDDINGS_TEI_URL),
+
+            # Reranker
+            reranker_provider=os.getenv(ENV_RERANKER_PROVIDER, DEFAULT_RERANKER_PROVIDER),
+            reranker_local_model=os.getenv(ENV_RERANKER_LOCAL_MODEL, DEFAULT_RERANKER_LOCAL_MODEL),
+            reranker_tei_url=os.getenv(ENV_RERANKER_TEI_URL),
+
+            # Server
+            host=os.getenv(ENV_HOST, DEFAULT_HOST),
+            port=int(os.getenv(ENV_PORT, DEFAULT_PORT)),
+            log_level=os.getenv(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL),
+            mcp_enabled=os.getenv(ENV_MCP_ENABLED, str(DEFAULT_MCP_ENABLED)).lower() == "true",
+        )
+
+    def get_llm_base_url(self) -> str:
+        """Get the LLM base URL, with provider-specific defaults."""
+        if self.llm_base_url:
+            return self.llm_base_url
+
+        provider = self.llm_provider.lower()
+        if provider == "groq":
+            return "https://api.groq.com/openai/v1"
+        elif provider == "ollama":
+            return "http://localhost:11434/v1"
+        else:
+            return ""
+
+    def get_python_log_level(self) -> int:
+        """Get the Python logging level from the configured log level string."""
+        log_level_map = {
+            "critical": logging.CRITICAL,
+            "error": logging.ERROR,
+            "warning": logging.WARNING,
+            "info": logging.INFO,
+            "debug": logging.DEBUG,
+            "trace": logging.DEBUG,  # Python doesn't have TRACE, use DEBUG
+        }
+        return log_level_map.get(self.log_level.lower(), logging.INFO)
+
+    def configure_logging(self) -> None:
+        """Configure Python logging based on the log level."""
+        logging.basicConfig(
+            level=self.get_python_log_level(),
+            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
+        )
+
+    def log_config(self) -> None:
+        """Log the current configuration (without sensitive values)."""
+        logger.info(f"Database: {self.database_url}")
+        logger.info(f"LLM: provider={self.llm_provider}, model={self.llm_model}")
+        logger.info(f"Embeddings: provider={self.embeddings_provider}")
+        logger.info(f"Reranker: provider={self.reranker_provider}")
+
+
+def get_config() -> HindsightConfig:
+    """Get the current configuration from environment variables."""
+    return HindsightConfig.from_env()
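For orientation, a minimal sketch of how this new module can be exercised at startup; the two environment values below are illustrative overrides, not package defaults:

```python
import os

# Illustrative overrides; unset variables fall back to the defaults defined above.
os.environ["HINDSIGHT_API_LLM_PROVIDER"] = "ollama"
os.environ["HINDSIGHT_API_PORT"] = "9000"

from hindsight_api.config import get_config

config = get_config()       # snapshots the environment into a HindsightConfig
config.configure_logging()  # applies the configured level to Python logging
config.log_config()         # logs provider/model choices, never the API key
assert config.port == 9000
assert config.get_llm_base_url() == "http://localhost:11434/v1"  # ollama default
```

Note that get_config() calls from_env() each time, so the snapshot always reflects the current environment.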
hindsight_api/engine/__init__.py
CHANGED
@@ -9,7 +9,8 @@ This package contains all the implementation details of the memory engine:
 
 from .memory_engine import MemoryEngine
 from .db_utils import acquire_with_retry
-from .embeddings import Embeddings,
+from .embeddings import Embeddings, LocalSTEmbeddings, RemoteTEIEmbeddings
+from .cross_encoder import CrossEncoderModel, LocalSTCrossEncoder, RemoteTEICrossEncoder
 from .search.trace import (
     SearchTrace,
     QueryInfo,
@@ -29,7 +30,11 @@ __all__ = [
     "MemoryEngine",
     "acquire_with_retry",
     "Embeddings",
-    "
+    "LocalSTEmbeddings",
+    "RemoteTEIEmbeddings",
+    "CrossEncoderModel",
+    "LocalSTCrossEncoder",
+    "RemoteTEICrossEncoder",
     "SearchTrace",
     "SearchTracer",
     "QueryInfo",
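Since the new classes are re-exported here, downstream code can import everything from the engine package root; a small sketch:

```python
# Sketch: the five new names all resolve from hindsight_api.engine after this change.
from hindsight_api.engine import (
    CrossEncoderModel,
    LocalSTCrossEncoder,
    RemoteTEICrossEncoder,
    LocalSTEmbeddings,
    RemoteTEIEmbeddings,
)

# model_name is optional; it falls back to the configured default reranker model.
reranker: CrossEncoderModel = LocalSTCrossEncoder()
```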
hindsight_api/engine/cross_encoder.py
CHANGED

@@ -2,10 +2,23 @@
 Cross-encoder abstraction for reranking.
 
 Provides an interface for reranking with different backends.
+
+Configuration via environment variables - see hindsight_api.config for all env var names.
 """
 from abc import ABC, abstractmethod
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 import logging
+import os
+
+import httpx
+
+from ..config import (
+    ENV_RERANKER_PROVIDER,
+    ENV_RERANKER_LOCAL_MODEL,
+    ENV_RERANKER_TEI_URL,
+    DEFAULT_RERANKER_PROVIDER,
+    DEFAULT_RERANKER_LOCAL_MODEL,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -17,12 +30,18 @@ class CrossEncoderModel(ABC):
     Cross-encoders take query-document pairs and return relevance scores.
     """
 
+    @property
     @abstractmethod
-    def
+    def provider_name(self) -> str:
+        """Return a human-readable name for this provider (e.g., 'local', 'tei')."""
+        pass
+
+    @abstractmethod
+    async def initialize(self) -> None:
         """
-
+        Initialize the cross-encoder model asynchronously.
 
-        This should be called during
+        This should be called during startup to load/connect to the model
         and avoid cold start latency on first predict() call.
         """
         pass
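To make the reshaped contract concrete, here is a minimal hypothetical implementation; ConstantCrossEncoder is not part of the package and only illustrates the abstract members shown above:

```python
from typing import List, Tuple

from hindsight_api.engine.cross_encoder import CrossEncoderModel


class ConstantCrossEncoder(CrossEncoderModel):
    """Hypothetical provider used only to illustrate the interface."""

    @property
    def provider_name(self) -> str:
        return "constant"

    async def initialize(self) -> None:
        # Real providers load a model or open an HTTP client here.
        pass

    def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
        # Give every (query, document) pair the same neutral score.
        return [0.0] * len(pairs)
```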
@@ -41,11 +60,11 @@ class CrossEncoderModel(ABC):
         pass
 
 
-class
+class LocalSTCrossEncoder(CrossEncoderModel):
     """
-
+    Local cross-encoder implementation using SentenceTransformers.
 
-    Call
+    Call initialize() during startup to load the model and avoid cold starts.
 
     Default model is cross-encoder/ms-marco-MiniLM-L-6-v2:
     - Fast inference (~80ms for 100 pairs on CPU)
@@ -53,18 +72,22 @@ class SentenceTransformersCrossEncoder(CrossEncoderModel):
     - Trained for passage re-ranking
     """
 
-    def __init__(self, model_name: str =
+    def __init__(self, model_name: Optional[str] = None):
         """
-        Initialize SentenceTransformers cross-encoder.
+        Initialize local SentenceTransformers cross-encoder.
 
         Args:
             model_name: Name of the CrossEncoder model to use.
                 Default: cross-encoder/ms-marco-MiniLM-L-6-v2
         """
-        self.model_name = model_name
+        self.model_name = model_name or DEFAULT_RERANKER_LOCAL_MODEL
         self._model = None
 
-
+    @property
+    def provider_name(self) -> str:
+        return "local"
+
+    async def initialize(self) -> None:
         """Load the cross-encoder model."""
         if self._model is not None:
             return
@@ -73,13 +96,18 @@ class SentenceTransformersCrossEncoder(CrossEncoderModel):
             from sentence_transformers import CrossEncoder
         except ImportError:
             raise ImportError(
-                "sentence-transformers is required for
+                "sentence-transformers is required for LocalSTCrossEncoder. "
                 "Install it with: pip install sentence-transformers"
             )
 
-        logger.info(f"
-
-
+        logger.info(f"Reranker: initializing local provider with model {self.model_name}")
+        # Disable lazy loading (meta tensors) which causes issues with newer transformers/accelerate
+        # Setting low_cpu_mem_usage=False and device_map=None ensures tensors are fully materialized
+        self._model = CrossEncoder(
+            self.model_name,
+            model_kwargs={"low_cpu_mem_usage": False, "device_map": None},
+        )
+        logger.info("Reranker: local provider initialized")
 
     def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
         """
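Driving the local provider end to end might look like this sketch (asyncio is only needed to await initialize(); assumes sentence-transformers is installed):

```python
import asyncio

from hindsight_api.engine.cross_encoder import LocalSTCrossEncoder


async def main() -> None:
    reranker = LocalSTCrossEncoder()  # defaults to cross-encoder/ms-marco-MiniLM-L-6-v2
    await reranker.initialize()       # loads the model once, avoiding a cold start later
    scores = reranker.predict([
        ("what is hindsight?", "Hindsight is a memory API."),
        ("what is hindsight?", "Bananas are yellow."),
    ])
    print(scores)  # raw logits; higher means more relevant


asyncio.run(main())
```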
@@ -92,6 +120,187 @@ class SentenceTransformersCrossEncoder(CrossEncoderModel):
             List of relevance scores (raw logits from the model)
         """
         if self._model is None:
-
+            raise RuntimeError("Reranker not initialized. Call initialize() first.")
         scores = self._model.predict(pairs, show_progress_bar=False)
         return scores.tolist() if hasattr(scores, 'tolist') else list(scores)
+
+
+class RemoteTEICrossEncoder(CrossEncoderModel):
+    """
+    Remote cross-encoder implementation using HuggingFace Text Embeddings Inference (TEI) HTTP API.
+
+    TEI supports reranking via the /rerank endpoint.
+    See: https://github.com/huggingface/text-embeddings-inference
+
+    Note: The TEI server must be running a cross-encoder/reranker model.
+    """
+
+    def __init__(
+        self,
+        base_url: str,
+        timeout: float = 30.0,
+        batch_size: int = 32,
+        max_retries: int = 3,
+        retry_delay: float = 0.5,
+    ):
+        """
+        Initialize remote TEI cross-encoder client.
+
+        Args:
+            base_url: Base URL of the TEI server (e.g., "http://localhost:8080")
+            timeout: Request timeout in seconds (default: 30.0)
+            batch_size: Maximum batch size for rerank requests (default: 32)
+            max_retries: Maximum number of retries for failed requests (default: 3)
+            retry_delay: Initial delay between retries in seconds, doubles each retry (default: 0.5)
+        """
+        self.base_url = base_url.rstrip("/")
+        self.timeout = timeout
+        self.batch_size = batch_size
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+        self._client: Optional[httpx.Client] = None
+        self._model_id: Optional[str] = None
+
+    @property
+    def provider_name(self) -> str:
+        return "tei"
+
+    def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
+        """Make an HTTP request with automatic retries on transient errors."""
+        import time
+        last_error = None
+        delay = self.retry_delay
+
+        for attempt in range(self.max_retries + 1):
+            try:
+                if method == "GET":
+                    response = self._client.get(url, **kwargs)
+                else:
+                    response = self._client.post(url, **kwargs)
+                response.raise_for_status()
+                return response
+            except (httpx.ConnectError, httpx.ReadTimeout, httpx.WriteTimeout) as e:
+                last_error = e
+                if attempt < self.max_retries:
+                    logger.warning(f"TEI request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s...")
+                    time.sleep(delay)
+                    delay *= 2  # Exponential backoff
+            except httpx.HTTPStatusError as e:
+                # Retry on 5xx server errors
+                if e.response.status_code >= 500 and attempt < self.max_retries:
+                    last_error = e
+                    logger.warning(f"TEI server error (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay}s...")
+                    time.sleep(delay)
+                    delay *= 2
+                else:
+                    raise
+
+        raise last_error
+
+    async def initialize(self) -> None:
+        """Initialize the HTTP client and verify server connectivity."""
+        if self._client is not None:
+            return
+
+        logger.info(f"Reranker: initializing TEI provider at {self.base_url}")
+        self._client = httpx.Client(timeout=self.timeout)
+
+        # Verify server is reachable and get model info
+        try:
+            response = self._request_with_retry("GET", f"{self.base_url}/info")
+            info = response.json()
+            self._model_id = info.get("model_id", "unknown")
+            logger.info(f"Reranker: TEI provider initialized (model: {self._model_id})")
+        except httpx.HTTPError as e:
+            raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")
+
+    def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
+        """
+        Score query-document pairs using the remote TEI reranker.
+
+        Args:
+            pairs: List of (query, document) tuples to score
+
+        Returns:
+            List of relevance scores
+        """
+        if self._client is None:
+            raise RuntimeError("Reranker not initialized. Call initialize() first.")
+
+        if not pairs:
+            return []
+
+        all_scores = []
+
+        # Process in batches
+        for i in range(0, len(pairs), self.batch_size):
+            batch = pairs[i:i + self.batch_size]
+
+            # TEI rerank endpoint expects query and texts separately
+            # All pairs in a batch should have the same query for optimal performance
+            # but we handle mixed queries by making separate requests per unique query
+            query_groups: dict[str, list[tuple[int, str]]] = {}
+            for idx, (query, text) in enumerate(batch):
+                if query not in query_groups:
+                    query_groups[query] = []
+                query_groups[query].append((idx, text))
+
+            batch_scores = [0.0] * len(batch)
+
+            for query, indexed_texts in query_groups.items():
+                texts = [text for _, text in indexed_texts]
+                indices = [idx for idx, _ in indexed_texts]
+
+                try:
+                    response = self._request_with_retry(
+                        "POST",
+                        f"{self.base_url}/rerank",
+                        json={
+                            "query": query,
+                            "texts": texts,
+                            "return_text": False,
+                        },
+                    )
+                    results = response.json()
+
+                    # TEI returns results sorted by score descending, with original index
+                    for result in results:
+                        original_idx = result["index"]
+                        score = result["score"]
+                        # Map back to batch position
+                        batch_scores[indices[original_idx]] = score
+
+                except httpx.HTTPError as e:
+                    raise RuntimeError(f"TEI rerank request failed: {e}")
+
+            all_scores.extend(batch_scores)
+
+        return all_scores
+
+
+def create_cross_encoder_from_env() -> CrossEncoderModel:
+    """
+    Create a CrossEncoderModel instance based on environment variables.
+
+    See hindsight_api.config for environment variable names and defaults.
+
+    Returns:
+        Configured CrossEncoderModel instance
+    """
+    provider = os.environ.get(ENV_RERANKER_PROVIDER, DEFAULT_RERANKER_PROVIDER).lower()
+
+    if provider == "tei":
+        url = os.environ.get(ENV_RERANKER_TEI_URL)
+        if not url:
+            raise ValueError(
+                f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'"
+            )
+        return RemoteTEICrossEncoder(base_url=url)
+    elif provider == "local":
+        model = os.environ.get(ENV_RERANKER_LOCAL_MODEL)
+        model_name = model or DEFAULT_RERANKER_LOCAL_MODEL
+        return LocalSTCrossEncoder(model_name=model_name)
+    else:
+        raise ValueError(
+            f"Unknown reranker provider: {provider}. Supported: 'local', 'tei'"
+        )