hindsight-api 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +10 -2
- hindsight_api/alembic/README +1 -0
- hindsight_api/alembic/env.py +146 -0
- hindsight_api/alembic/script.py.mako +28 -0
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +274 -0
- hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +70 -0
- hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +39 -0
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +48 -0
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +62 -0
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +65 -0
- hindsight_api/api/http.py +84 -86
- hindsight_api/config.py +154 -0
- hindsight_api/engine/__init__.py +7 -2
- hindsight_api/engine/cross_encoder.py +219 -15
- hindsight_api/engine/embeddings.py +192 -18
- hindsight_api/engine/llm_wrapper.py +88 -139
- hindsight_api/engine/memory_engine.py +71 -51
- hindsight_api/engine/retain/bank_utils.py +2 -2
- hindsight_api/engine/retain/fact_extraction.py +1 -1
- hindsight_api/engine/search/reranking.py +6 -10
- hindsight_api/engine/search/tracer.py +1 -1
- hindsight_api/main.py +201 -0
- hindsight_api/migrations.py +7 -7
- hindsight_api/server.py +43 -0
- {hindsight_api-0.1.0.dist-info → hindsight_api-0.1.2.dist-info}/METADATA +1 -1
- {hindsight_api-0.1.0.dist-info → hindsight_api-0.1.2.dist-info}/RECORD +28 -19
- hindsight_api-0.1.2.dist-info/entry_points.txt +2 -0
- hindsight_api/cli.py +0 -127
- hindsight_api/web/__init__.py +0 -12
- hindsight_api/web/server.py +0 -109
- hindsight_api-0.1.0.dist-info/entry_points.txt +0 -2
- {hindsight_api-0.1.0.dist-info → hindsight_api-0.1.2.dist-info}/WHEEL +0 -0
|
@@ -2,10 +2,23 @@
|
|
|
2
2
|
Cross-encoder abstraction for reranking.
|
|
3
3
|
|
|
4
4
|
Provides an interface for reranking with different backends.
|
|
5
|
+
|
|
6
|
+
Configuration via environment variables - see hindsight_api.config for all env var names.
|
|
5
7
|
"""
|
|
6
8
|
from abc import ABC, abstractmethod
|
|
7
|
-
from typing import List, Tuple
|
|
9
|
+
from typing import List, Tuple, Optional
|
|
8
10
|
import logging
|
|
11
|
+
import os
|
|
12
|
+
|
|
13
|
+
import httpx
|
|
14
|
+
|
|
15
|
+
from ..config import (
|
|
16
|
+
ENV_RERANKER_PROVIDER,
|
|
17
|
+
ENV_RERANKER_LOCAL_MODEL,
|
|
18
|
+
ENV_RERANKER_TEI_URL,
|
|
19
|
+
DEFAULT_RERANKER_PROVIDER,
|
|
20
|
+
DEFAULT_RERANKER_LOCAL_MODEL,
|
|
21
|
+
)
|
|
9
22
|
|
|
10
23
|
logger = logging.getLogger(__name__)
|
|
11
24
|
|
|
@@ -17,12 +30,18 @@ class CrossEncoderModel(ABC):
|
|
|
17
30
|
Cross-encoders take query-document pairs and return relevance scores.
|
|
18
31
|
"""
|
|
19
32
|
|
|
33
|
+
@property
|
|
20
34
|
@abstractmethod
|
|
21
|
-
def
|
|
35
|
+
def provider_name(self) -> str:
|
|
36
|
+
"""Return a human-readable name for this provider (e.g., 'local', 'tei')."""
|
|
37
|
+
pass
|
|
38
|
+
|
|
39
|
+
@abstractmethod
|
|
40
|
+
async def initialize(self) -> None:
|
|
22
41
|
"""
|
|
23
|
-
|
|
42
|
+
Initialize the cross-encoder model asynchronously.
|
|
24
43
|
|
|
25
|
-
This should be called during
|
|
44
|
+
This should be called during startup to load/connect to the model
|
|
26
45
|
and avoid cold start latency on first predict() call.
|
|
27
46
|
"""
|
|
28
47
|
pass
|
|
@@ -41,11 +60,11 @@ class CrossEncoderModel(ABC):
|
|
|
41
60
|
pass
|
|
42
61
|
|
|
43
62
|
|
|
44
|
-
class
|
|
63
|
+
class LocalSTCrossEncoder(CrossEncoderModel):
|
|
45
64
|
"""
|
|
46
|
-
|
|
65
|
+
Local cross-encoder implementation using SentenceTransformers.
|
|
47
66
|
|
|
48
|
-
Call
|
|
67
|
+
Call initialize() during startup to load the model and avoid cold starts.
|
|
49
68
|
|
|
50
69
|
Default model is cross-encoder/ms-marco-MiniLM-L-6-v2:
|
|
51
70
|
- Fast inference (~80ms for 100 pairs on CPU)
|
|
@@ -53,18 +72,22 @@ class SentenceTransformersCrossEncoder(CrossEncoderModel):
|
|
|
53
72
|
- Trained for passage re-ranking
|
|
54
73
|
"""
|
|
55
74
|
|
|
56
|
-
def __init__(self, model_name: str =
|
|
75
|
+
def __init__(self, model_name: Optional[str] = None):
|
|
57
76
|
"""
|
|
58
|
-
Initialize SentenceTransformers cross-encoder.
|
|
77
|
+
Initialize local SentenceTransformers cross-encoder.
|
|
59
78
|
|
|
60
79
|
Args:
|
|
61
80
|
model_name: Name of the CrossEncoder model to use.
|
|
62
81
|
Default: cross-encoder/ms-marco-MiniLM-L-6-v2
|
|
63
82
|
"""
|
|
64
|
-
self.model_name = model_name
|
|
83
|
+
self.model_name = model_name or DEFAULT_RERANKER_LOCAL_MODEL
|
|
65
84
|
self._model = None
|
|
66
85
|
|
|
67
|
-
|
|
86
|
+
@property
|
|
87
|
+
def provider_name(self) -> str:
|
|
88
|
+
return "local"
|
|
89
|
+
|
|
90
|
+
async def initialize(self) -> None:
|
|
68
91
|
"""Load the cross-encoder model."""
|
|
69
92
|
if self._model is not None:
|
|
70
93
|
return
|
|
@@ -73,18 +96,18 @@ class SentenceTransformersCrossEncoder(CrossEncoderModel):
|
|
|
73
96
|
from sentence_transformers import CrossEncoder
|
|
74
97
|
except ImportError:
|
|
75
98
|
raise ImportError(
|
|
76
|
-
"sentence-transformers is required for
|
|
99
|
+
"sentence-transformers is required for LocalSTCrossEncoder. "
|
|
77
100
|
"Install it with: pip install sentence-transformers"
|
|
78
101
|
)
|
|
79
102
|
|
|
80
|
-
logger.info(f"
|
|
103
|
+
logger.info(f"Reranker: initializing local provider with model {self.model_name}")
|
|
81
104
|
# Disable lazy loading (meta tensors) which causes issues with newer transformers/accelerate
|
|
82
105
|
# Setting low_cpu_mem_usage=False and device_map=None ensures tensors are fully materialized
|
|
83
106
|
self._model = CrossEncoder(
|
|
84
107
|
self.model_name,
|
|
85
108
|
model_kwargs={"low_cpu_mem_usage": False, "device_map": None},
|
|
86
109
|
)
|
|
87
|
-
logger.info("
|
|
110
|
+
logger.info("Reranker: local provider initialized")
|
|
88
111
|
|
|
89
112
|
def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
|
|
90
113
|
"""
|
|
@@ -97,6 +120,187 @@ class SentenceTransformersCrossEncoder(CrossEncoderModel):
|
|
|
97
120
|
List of relevance scores (raw logits from the model)
|
|
98
121
|
"""
|
|
99
122
|
if self._model is None:
|
|
100
|
-
|
|
123
|
+
raise RuntimeError("Reranker not initialized. Call initialize() first.")
|
|
101
124
|
scores = self._model.predict(pairs, show_progress_bar=False)
|
|
102
125
|
return scores.tolist() if hasattr(scores, 'tolist') else list(scores)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class RemoteTEICrossEncoder(CrossEncoderModel):
    """
    Remote cross-encoder implementation using HuggingFace Text Embeddings Inference (TEI) HTTP API.

    TEI supports reranking via the /rerank endpoint.
    See: https://github.com/huggingface/text-embeddings-inference

    Note: The TEI server must be running a cross-encoder/reranker model.
    """

    def __init__(
        self,
        base_url: str,
        timeout: float = 30.0,
        batch_size: int = 32,
        max_retries: int = 3,
        retry_delay: float = 0.5,
    ):
        """
        Initialize remote TEI cross-encoder client.

        No network traffic happens here; the HTTP client is created by initialize().

        Args:
            base_url: Base URL of the TEI server (e.g., "http://localhost:8080")
            timeout: Request timeout in seconds (default: 30.0)
            batch_size: Maximum batch size for rerank requests (default: 32)
            max_retries: Maximum number of retries for failed requests (default: 3).
                Values below zero are treated as zero (a single attempt).
            retry_delay: Initial delay between retries in seconds, doubles each retry (default: 0.5)

        Raises:
            ValueError: If batch_size is smaller than 1.
        """
        if batch_size < 1:
            # A zero or negative step would make predict() fail inside range() or
            # silently return no scores, so reject it up front.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self._client: Optional[httpx.Client] = None
        self._model_id: Optional[str] = None

    @property
    def provider_name(self) -> str:
        """Return the provider identifier, always "tei" for this class."""
        return "tei"

    def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
        """
        Make an HTTP request with automatic retries on transient errors.

        Connection failures, connect/read/write timeouts and 5xx responses are
        retried with exponential backoff. Any other HTTP status error is raised
        immediately.

        Args:
            method: "GET" issues a GET request; any other value issues a POST.
            url: Absolute URL to request.
            **kwargs: Passed through unchanged to the httpx client call.

        Returns:
            The response, once a request succeeded with a non-error status.

        Raises:
            httpx.HTTPError: The error of the final attempt, or a non-retryable error.
        """
        import time

        # Always perform at least one attempt, even if max_retries is negative.
        total_attempts = max(self.max_retries, 0) + 1
        delay = self.retry_delay

        for attempt in range(total_attempts):
            is_last_attempt = attempt == total_attempts - 1
            try:
                if method == "GET":
                    response = self._client.get(url, **kwargs)
                else:
                    response = self._client.post(url, **kwargs)
                response.raise_for_status()
                return response
            except (
                httpx.ConnectError,
                httpx.ConnectTimeout,
                httpx.ReadTimeout,
                httpx.WriteTimeout,
            ) as e:
                if is_last_attempt:
                    raise
                logger.warning(f"TEI request failed (attempt {attempt + 1}/{total_attempts}): {e}. Retrying in {delay}s...")
            except httpx.HTTPStatusError as e:
                # Only 5xx server errors are worth retrying.
                if e.response.status_code < 500 or is_last_attempt:
                    raise
                logger.warning(f"TEI server error (attempt {attempt + 1}/{total_attempts}): {e}. Retrying in {delay}s...")

            time.sleep(delay)
            delay *= 2  # Exponential backoff

        # Not reachable: the final attempt either returns or raises above.
        raise RuntimeError("TEI request retry loop ended without a result")

    async def initialize(self) -> None:
        """
        Initialize the HTTP client and verify server connectivity.

        Queries the server's /info endpoint and records the reported model id.
        If verification fails, the client is closed and discarded so that a
        later initialize() call starts from scratch.

        NOTE(review): the HTTP call and the retry sleeps are synchronous, so this
        coroutine blocks while they run.

        Raises:
            RuntimeError: If the TEI server cannot be reached.
        """
        if self._client is not None:
            return

        logger.info(f"Reranker: initializing TEI provider at {self.base_url}")
        self._client = httpx.Client(timeout=self.timeout)

        # Verify server is reachable and get model info
        verified = False
        try:
            response = self._request_with_retry("GET", f"{self.base_url}/info")
            info = response.json()
            self._model_id = info.get("model_id", "unknown")
            verified = True
        except httpx.HTTPError as e:
            raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}") from e
        finally:
            if not verified:
                # Roll back: do not keep (or leak) a client that never passed the check.
                self._client.close()
                self._client = None

        logger.info(f"Reranker: TEI provider initialized (model: {self._model_id})")

    def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
        """
        Score query-document pairs using the remote TEI reranker.

        Pairs are processed in chunks of batch_size. Within a chunk, pairs are
        grouped by query and one /rerank request is sent per distinct query.

        Args:
            pairs: List of (query, document) tuples to score

        Returns:
            List of relevance scores, in the same order as pairs. A pair for which
            the server returned no result keeps the placeholder score 0.0.

        Raises:
            RuntimeError: If initialize() was not called or a rerank request failed.
        """
        if self._client is None:
            raise RuntimeError("Reranker not initialized. Call initialize() first.")

        if not pairs:
            return []

        all_scores: List[float] = []

        # Process in batches
        for start in range(0, len(pairs), self.batch_size):
            batch = pairs[start:start + self.batch_size]

            # The rerank endpoint takes one query plus a list of texts, so group the
            # batch by query, remembering each text's position within the batch.
            query_groups: dict[str, list[tuple[int, str]]] = {}
            for position, (query, text) in enumerate(batch):
                query_groups.setdefault(query, []).append((position, text))

            batch_scores = [0.0] * len(batch)

            for query, indexed_texts in query_groups.items():
                positions = [position for position, _ in indexed_texts]
                texts = [text for _, text in indexed_texts]

                try:
                    response = self._request_with_retry(
                        "POST",
                        f"{self.base_url}/rerank",
                        json={
                            "query": query,
                            "texts": texts,
                            "return_text": False,
                        },
                    )
                    results = response.json()
                except httpx.HTTPError as e:
                    raise RuntimeError(f"TEI rerank request failed: {e}") from e

                # Each result names the position of its text within the request
                # ("index"); map that back to the position within the batch.
                for result in results:
                    batch_scores[positions[result["index"]]] = result["score"]

            all_scores.extend(batch_scores)

        return all_scores
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def create_cross_encoder_from_env() -> CrossEncoderModel:
    """
    Build the reranker selected by the environment.

    The provider name is read from the ENV_RERANKER_PROVIDER variable (falling
    back to DEFAULT_RERANKER_PROVIDER) and compared case-insensitively.
    See hindsight_api.config for environment variable names and defaults.

    Returns:
        Configured CrossEncoderModel instance

    Raises:
        ValueError: If the provider is unknown, or "tei" is selected without a URL.
    """
    selected = os.environ.get(ENV_RERANKER_PROVIDER, DEFAULT_RERANKER_PROVIDER).lower()

    if selected == "local":
        configured_model = os.environ.get(ENV_RERANKER_LOCAL_MODEL)
        # An unset or empty variable falls back to the default model.
        return LocalSTCrossEncoder(model_name=configured_model or DEFAULT_RERANKER_LOCAL_MODEL)

    if selected != "tei":
        raise ValueError(
            f"Unknown reranker provider: {selected}. Supported: 'local', 'tei'"
        )

    tei_url = os.environ.get(ENV_RERANKER_TEI_URL)
    if not tei_url:
        raise ValueError(
            f"{ENV_RERANKER_TEI_URL} is required when {ENV_RERANKER_PROVIDER} is 'tei'"
        )
    return RemoteTEICrossEncoder(base_url=tei_url)
|
|
@@ -5,15 +5,26 @@ Provides an interface for generating embeddings with different backends.
|
|
|
5
5
|
|
|
6
6
|
IMPORTANT: All embeddings must produce 384-dimensional vectors to match
|
|
7
7
|
the database schema (pgvector column defined as vector(384)).
|
|
8
|
+
|
|
9
|
+
Configuration via environment variables - see hindsight_api.config for all env var names.
|
|
8
10
|
"""
|
|
9
11
|
from abc import ABC, abstractmethod
|
|
10
|
-
from typing import List
|
|
12
|
+
from typing import List, Optional
|
|
11
13
|
import logging
|
|
14
|
+
import os
|
|
12
15
|
|
|
13
|
-
|
|
16
|
+
import httpx
|
|
14
17
|
|
|
15
|
-
|
|
16
|
-
|
|
18
|
+
from ..config import (
|
|
19
|
+
ENV_EMBEDDINGS_PROVIDER,
|
|
20
|
+
ENV_EMBEDDINGS_LOCAL_MODEL,
|
|
21
|
+
ENV_EMBEDDINGS_TEI_URL,
|
|
22
|
+
DEFAULT_EMBEDDINGS_PROVIDER,
|
|
23
|
+
DEFAULT_EMBEDDINGS_LOCAL_MODEL,
|
|
24
|
+
EMBEDDING_DIMENSION,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
17
28
|
|
|
18
29
|
|
|
19
30
|
class Embeddings(ABC):
|
|
@@ -24,12 +35,18 @@ class Embeddings(ABC):
|
|
|
24
35
|
the database schema.
|
|
25
36
|
"""
|
|
26
37
|
|
|
38
|
+
@property
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def provider_name(self) -> str:
|
|
41
|
+
"""Return a human-readable name for this provider (e.g., 'local', 'tei')."""
|
|
42
|
+
pass
|
|
43
|
+
|
|
27
44
|
@abstractmethod
|
|
28
|
-
def
|
|
45
|
+
async def initialize(self) -> None:
|
|
29
46
|
"""
|
|
30
|
-
|
|
47
|
+
Initialize the embedding model asynchronously.
|
|
31
48
|
|
|
32
|
-
This should be called during
|
|
49
|
+
This should be called during startup to load/connect to the model
|
|
33
50
|
and avoid cold start latency on first encode() call.
|
|
34
51
|
"""
|
|
35
52
|
pass
|
|
@@ -48,29 +65,33 @@ class Embeddings(ABC):
|
|
|
48
65
|
pass
|
|
49
66
|
|
|
50
67
|
|
|
51
|
-
class
|
|
68
|
+
class LocalSTEmbeddings(Embeddings):
|
|
52
69
|
"""
|
|
53
|
-
|
|
70
|
+
Local embeddings implementation using SentenceTransformers.
|
|
54
71
|
|
|
55
|
-
Call
|
|
72
|
+
Call initialize() during startup to load the model and avoid cold starts.
|
|
56
73
|
|
|
57
74
|
Default model is BAAI/bge-small-en-v1.5 which produces 384-dimensional
|
|
58
75
|
embeddings matching the database schema.
|
|
59
76
|
"""
|
|
60
77
|
|
|
61
|
-
def __init__(self, model_name: str =
|
|
78
|
+
def __init__(self, model_name: Optional[str] = None):
|
|
62
79
|
"""
|
|
63
|
-
Initialize SentenceTransformers embeddings.
|
|
80
|
+
Initialize local SentenceTransformers embeddings.
|
|
64
81
|
|
|
65
82
|
Args:
|
|
66
83
|
model_name: Name of the SentenceTransformer model to use.
|
|
67
84
|
Must produce 384-dimensional embeddings.
|
|
68
85
|
Default: BAAI/bge-small-en-v1.5
|
|
69
86
|
"""
|
|
70
|
-
self.model_name = model_name
|
|
87
|
+
self.model_name = model_name or DEFAULT_EMBEDDINGS_LOCAL_MODEL
|
|
71
88
|
self._model = None
|
|
72
89
|
|
|
73
|
-
|
|
90
|
+
@property
|
|
91
|
+
def provider_name(self) -> str:
|
|
92
|
+
return "local"
|
|
93
|
+
|
|
94
|
+
async def initialize(self) -> None:
|
|
74
95
|
"""Load the embedding model."""
|
|
75
96
|
if self._model is not None:
|
|
76
97
|
return
|
|
@@ -79,11 +100,11 @@ class SentenceTransformersEmbeddings(Embeddings):
|
|
|
79
100
|
from sentence_transformers import SentenceTransformer
|
|
80
101
|
except ImportError:
|
|
81
102
|
raise ImportError(
|
|
82
|
-
"sentence-transformers is required for
|
|
103
|
+
"sentence-transformers is required for LocalSTEmbeddings. "
|
|
83
104
|
"Install it with: pip install sentence-transformers"
|
|
84
105
|
)
|
|
85
106
|
|
|
86
|
-
logger.info(f"
|
|
107
|
+
logger.info(f"Embeddings: initializing local provider with model {self.model_name}")
|
|
87
108
|
# Disable lazy loading (meta tensors) which causes issues with newer transformers/accelerate
|
|
88
109
|
# Setting low_cpu_mem_usage=False and device_map=None ensures tensors are fully materialized
|
|
89
110
|
self._model = SentenceTransformer(
|
|
@@ -100,7 +121,7 @@ class SentenceTransformersEmbeddings(Embeddings):
|
|
|
100
121
|
f"Use a model that produces {EMBEDDING_DIMENSION}-dimensional embeddings."
|
|
101
122
|
)
|
|
102
123
|
|
|
103
|
-
logger.info(f"
|
|
124
|
+
logger.info(f"Embeddings: local provider initialized (dim: {model_dim})")
|
|
104
125
|
|
|
105
126
|
def encode(self, texts: List[str]) -> List[List[float]]:
|
|
106
127
|
"""
|
|
@@ -113,6 +134,159 @@ class SentenceTransformersEmbeddings(Embeddings):
|
|
|
113
134
|
List of 384-dimensional embedding vectors
|
|
114
135
|
"""
|
|
115
136
|
if self._model is None:
|
|
116
|
-
|
|
137
|
+
raise RuntimeError("Embeddings not initialized. Call initialize() first.")
|
|
117
138
|
embeddings = self._model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
|
|
118
139
|
return [emb.tolist() for emb in embeddings]
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class RemoteTEIEmbeddings(Embeddings):
    """
    Remote embeddings implementation using HuggingFace Text Embeddings Inference (TEI) HTTP API.

    TEI provides a high-performance inference server for embedding models.
    See: https://github.com/huggingface/text-embeddings-inference

    The server should be running a model that produces 384-dimensional embeddings.
    """

    def __init__(
        self,
        base_url: str,
        timeout: float = 30.0,
        batch_size: int = 32,
        max_retries: int = 3,
        retry_delay: float = 0.5,
    ):
        """
        Initialize remote TEI embeddings client.

        No network traffic happens here; the HTTP client is created by initialize().

        Args:
            base_url: Base URL of the TEI server (e.g., "http://localhost:8080")
            timeout: Request timeout in seconds (default: 30.0)
            batch_size: Maximum batch size for embedding requests (default: 32)
            max_retries: Maximum number of retries for failed requests (default: 3).
                Values below zero are treated as zero (a single attempt).
            retry_delay: Initial delay between retries in seconds, doubles each retry (default: 0.5)

        Raises:
            ValueError: If batch_size is smaller than 1.
        """
        if batch_size < 1:
            # A zero or negative step would make encode() fail inside range() or
            # silently return no embeddings, so reject it up front.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self._client: Optional[httpx.Client] = None
        self._model_id: Optional[str] = None

    @property
    def provider_name(self) -> str:
        """Return the provider identifier, always "tei" for this class."""
        return "tei"

    def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
        """
        Make an HTTP request with automatic retries on transient errors.

        Connection failures, connect/read/write timeouts and 5xx responses are
        retried with exponential backoff. Any other HTTP status error is raised
        immediately.

        Args:
            method: "GET" issues a GET request; any other value issues a POST.
            url: Absolute URL to request.
            **kwargs: Passed through unchanged to the httpx client call.

        Returns:
            The response, once a request succeeded with a non-error status.

        Raises:
            httpx.HTTPError: The error of the final attempt, or a non-retryable error.
        """
        import time

        # Always perform at least one attempt, even if max_retries is negative.
        total_attempts = max(self.max_retries, 0) + 1
        delay = self.retry_delay

        for attempt in range(total_attempts):
            is_last_attempt = attempt == total_attempts - 1
            try:
                if method == "GET":
                    response = self._client.get(url, **kwargs)
                else:
                    response = self._client.post(url, **kwargs)
                response.raise_for_status()
                return response
            except (
                httpx.ConnectError,
                httpx.ConnectTimeout,
                httpx.ReadTimeout,
                httpx.WriteTimeout,
            ) as e:
                if is_last_attempt:
                    raise
                logger.warning(f"TEI request failed (attempt {attempt + 1}/{total_attempts}): {e}. Retrying in {delay}s...")
            except httpx.HTTPStatusError as e:
                # Only 5xx server errors are worth retrying.
                if e.response.status_code < 500 or is_last_attempt:
                    raise
                logger.warning(f"TEI server error (attempt {attempt + 1}/{total_attempts}): {e}. Retrying in {delay}s...")

            time.sleep(delay)
            delay *= 2  # Exponential backoff

        # Not reachable: the final attempt either returns or raises above.
        raise RuntimeError("TEI request retry loop ended without a result")

    async def initialize(self) -> None:
        """
        Initialize the HTTP client and verify server connectivity.

        Queries the server's /info endpoint and records the reported model id.
        If verification fails, the client is closed and discarded so that a
        later initialize() call starts from scratch.

        NOTE(review): the HTTP call and the retry sleeps are synchronous, so this
        coroutine blocks while they run.

        Raises:
            RuntimeError: If the TEI server cannot be reached.
        """
        if self._client is not None:
            return

        logger.info(f"Embeddings: initializing TEI provider at {self.base_url}")
        self._client = httpx.Client(timeout=self.timeout)

        # Verify server is reachable and get model info
        verified = False
        try:
            response = self._request_with_retry("GET", f"{self.base_url}/info")
            info = response.json()
            self._model_id = info.get("model_id", "unknown")
            verified = True
        except httpx.HTTPError as e:
            raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}") from e
        finally:
            if not verified:
                # Roll back: do not keep (or leak) a client that never passed the check.
                self._client.close()
                self._client = None

        logger.info(f"Embeddings: TEI provider initialized (model: {self._model_id})")

    def encode(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings using the remote TEI server.

        Texts are sent to the /embed endpoint in chunks of batch_size and the
        per-chunk responses are concatenated in order.

        NOTE(review): the returned vectors are not checked against
        EMBEDDING_DIMENSION here — confirm the server's model matches the schema.

        Args:
            texts: List of text strings to encode

        Returns:
            List of embedding vectors

        Raises:
            RuntimeError: If initialize() was not called or an embedding request failed.
        """
        if self._client is None:
            raise RuntimeError("Embeddings not initialized. Call initialize() first.")

        if not texts:
            return []

        all_embeddings: List[List[float]] = []

        # Process in batches
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]

            try:
                response = self._request_with_retry(
                    "POST",
                    f"{self.base_url}/embed",
                    json={"inputs": batch},
                )
                all_embeddings.extend(response.json())
            except httpx.HTTPError as e:
                raise RuntimeError(f"TEI embedding request failed: {e}") from e

        return all_embeddings
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def create_embeddings_from_env() -> Embeddings:
    """
    Build the embeddings backend selected by the environment.

    The provider name is read from the ENV_EMBEDDINGS_PROVIDER variable (falling
    back to DEFAULT_EMBEDDINGS_PROVIDER) and compared case-insensitively.
    See hindsight_api.config for environment variable names and defaults.

    Returns:
        Configured Embeddings instance

    Raises:
        ValueError: If the provider is unknown, or "tei" is selected without a URL.
    """
    selected = os.environ.get(ENV_EMBEDDINGS_PROVIDER, DEFAULT_EMBEDDINGS_PROVIDER).lower()

    if selected == "local":
        configured_model = os.environ.get(ENV_EMBEDDINGS_LOCAL_MODEL)
        # An unset or empty variable falls back to the default model.
        return LocalSTEmbeddings(model_name=configured_model or DEFAULT_EMBEDDINGS_LOCAL_MODEL)

    if selected != "tei":
        raise ValueError(
            f"Unknown embeddings provider: {selected}. Supported: 'local', 'tei'"
        )

    tei_url = os.environ.get(ENV_EMBEDDINGS_TEI_URL)
    if not tei_url:
        raise ValueError(
            f"{ENV_EMBEDDINGS_TEI_URL} is required when {ENV_EMBEDDINGS_PROVIDER} is 'tei'"
        )
    return RemoteTEIEmbeddings(base_url=tei_url)
|