hindsight-api 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/admin/__init__.py +1 -0
- hindsight_api/admin/cli.py +252 -0
- hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
- hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
- hindsight_api/api/http.py +282 -20
- hindsight_api/api/mcp.py +47 -52
- hindsight_api/config.py +238 -6
- hindsight_api/engine/cross_encoder.py +599 -86
- hindsight_api/engine/db_budget.py +284 -0
- hindsight_api/engine/db_utils.py +11 -0
- hindsight_api/engine/embeddings.py +453 -26
- hindsight_api/engine/entity_resolver.py +8 -5
- hindsight_api/engine/interface.py +8 -4
- hindsight_api/engine/llm_wrapper.py +241 -27
- hindsight_api/engine/memory_engine.py +609 -122
- hindsight_api/engine/query_analyzer.py +4 -3
- hindsight_api/engine/response_models.py +38 -0
- hindsight_api/engine/retain/fact_extraction.py +388 -192
- hindsight_api/engine/retain/fact_storage.py +34 -8
- hindsight_api/engine/retain/link_utils.py +24 -16
- hindsight_api/engine/retain/orchestrator.py +52 -17
- hindsight_api/engine/retain/types.py +9 -0
- hindsight_api/engine/search/graph_retrieval.py +42 -13
- hindsight_api/engine/search/link_expansion_retrieval.py +256 -0
- hindsight_api/engine/search/mpfp_retrieval.py +362 -117
- hindsight_api/engine/search/reranking.py +2 -2
- hindsight_api/engine/search/retrieval.py +847 -200
- hindsight_api/engine/search/tags.py +172 -0
- hindsight_api/engine/search/think_utils.py +1 -1
- hindsight_api/engine/search/trace.py +12 -0
- hindsight_api/engine/search/tracer.py +24 -1
- hindsight_api/engine/search/types.py +21 -0
- hindsight_api/engine/task_backend.py +109 -18
- hindsight_api/engine/utils.py +1 -1
- hindsight_api/extensions/context.py +10 -1
- hindsight_api/main.py +56 -4
- hindsight_api/metrics.py +433 -48
- hindsight_api/migrations.py +141 -1
- hindsight_api/models.py +3 -1
- hindsight_api/pg0.py +53 -0
- hindsight_api/server.py +39 -2
- {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/METADATA +5 -1
- hindsight_api-0.3.0.dist-info/RECORD +82 -0
- {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.2.0.dist-info/RECORD +0 -75
- {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/WHEEL +0 -0
|
@@ -3,8 +3,8 @@ Embeddings abstraction for the memory system.
|
|
|
3
3
|
|
|
4
4
|
Provides an interface for generating embeddings with different backends.
|
|
5
5
|
|
|
6
|
-
|
|
7
|
-
|
|
6
|
+
The embedding dimension is auto-detected from the model at initialization.
|
|
7
|
+
The database schema is automatically adjusted to match the model's dimension.
|
|
8
8
|
|
|
9
9
|
Configuration via environment variables - see hindsight_api.config for all env var names.
|
|
10
10
|
"""
|
|
@@ -16,12 +16,25 @@ from abc import ABC, abstractmethod
|
|
|
16
16
|
import httpx
|
|
17
17
|
|
|
18
18
|
from ..config import (
|
|
19
|
+
DEFAULT_EMBEDDINGS_COHERE_MODEL,
|
|
20
|
+
DEFAULT_EMBEDDINGS_LITELLM_MODEL,
|
|
19
21
|
DEFAULT_EMBEDDINGS_LOCAL_MODEL,
|
|
22
|
+
DEFAULT_EMBEDDINGS_OPENAI_MODEL,
|
|
20
23
|
DEFAULT_EMBEDDINGS_PROVIDER,
|
|
21
|
-
|
|
24
|
+
DEFAULT_LITELLM_API_BASE,
|
|
25
|
+
ENV_COHERE_API_KEY,
|
|
26
|
+
ENV_EMBEDDINGS_COHERE_BASE_URL,
|
|
27
|
+
ENV_EMBEDDINGS_COHERE_MODEL,
|
|
28
|
+
ENV_EMBEDDINGS_LITELLM_MODEL,
|
|
22
29
|
ENV_EMBEDDINGS_LOCAL_MODEL,
|
|
30
|
+
ENV_EMBEDDINGS_OPENAI_API_KEY,
|
|
31
|
+
ENV_EMBEDDINGS_OPENAI_BASE_URL,
|
|
32
|
+
ENV_EMBEDDINGS_OPENAI_MODEL,
|
|
23
33
|
ENV_EMBEDDINGS_PROVIDER,
|
|
24
34
|
ENV_EMBEDDINGS_TEI_URL,
|
|
35
|
+
ENV_LITELLM_API_BASE,
|
|
36
|
+
ENV_LITELLM_API_KEY,
|
|
37
|
+
ENV_LLM_API_KEY,
|
|
25
38
|
)
|
|
26
39
|
|
|
27
40
|
logger = logging.getLogger(__name__)
|
|
@@ -31,8 +44,8 @@ class Embeddings(ABC):
|
|
|
31
44
|
"""
|
|
32
45
|
Abstract base class for embedding generation.
|
|
33
46
|
|
|
34
|
-
|
|
35
|
-
|
|
47
|
+
The embedding dimension is determined by the model and detected at initialization.
|
|
48
|
+
The database schema is automatically adjusted to match the model's dimension.
|
|
36
49
|
"""
|
|
37
50
|
|
|
38
51
|
@property
|
|
@@ -41,6 +54,12 @@ class Embeddings(ABC):
|
|
|
41
54
|
"""Return a human-readable name for this provider (e.g., 'local', 'tei')."""
|
|
42
55
|
pass
|
|
43
56
|
|
|
57
|
+
@property
|
|
58
|
+
@abstractmethod
|
|
59
|
+
def dimension(self) -> int:
|
|
60
|
+
"""Return the embedding dimension produced by this model."""
|
|
61
|
+
pass
|
|
62
|
+
|
|
44
63
|
@abstractmethod
|
|
45
64
|
async def initialize(self) -> None:
|
|
46
65
|
"""
|
|
@@ -54,13 +73,13 @@ class Embeddings(ABC):
|
|
|
54
73
|
@abstractmethod
|
|
55
74
|
def encode(self, texts: list[str]) -> list[list[float]]:
|
|
56
75
|
"""
|
|
57
|
-
Generate
|
|
76
|
+
Generate embeddings for a list of texts.
|
|
58
77
|
|
|
59
78
|
Args:
|
|
60
79
|
texts: List of text strings to encode
|
|
61
80
|
|
|
62
81
|
Returns:
|
|
63
|
-
List of
|
|
82
|
+
List of embedding vectors (each is a list of floats)
|
|
64
83
|
"""
|
|
65
84
|
pass
|
|
66
85
|
|
|
@@ -70,9 +89,7 @@ class LocalSTEmbeddings(Embeddings):
|
|
|
70
89
|
Local embeddings implementation using SentenceTransformers.
|
|
71
90
|
|
|
72
91
|
Call initialize() during startup to load the model and avoid cold starts.
|
|
73
|
-
|
|
74
|
-
Default model is BAAI/bge-small-en-v1.5 which produces 384-dimensional
|
|
75
|
-
embeddings matching the database schema.
|
|
92
|
+
The embedding dimension is auto-detected from the model.
|
|
76
93
|
"""
|
|
77
94
|
|
|
78
95
|
def __init__(self, model_name: str | None = None):
|
|
@@ -81,16 +98,22 @@ class LocalSTEmbeddings(Embeddings):
|
|
|
81
98
|
|
|
82
99
|
Args:
|
|
83
100
|
model_name: Name of the SentenceTransformer model to use.
|
|
84
|
-
Must produce 384-dimensional embeddings.
|
|
85
101
|
Default: BAAI/bge-small-en-v1.5
|
|
86
102
|
"""
|
|
87
103
|
self.model_name = model_name or DEFAULT_EMBEDDINGS_LOCAL_MODEL
|
|
88
104
|
self._model = None
|
|
105
|
+
self._dimension: int | None = None
|
|
89
106
|
|
|
90
107
|
@property
|
|
91
108
|
def provider_name(self) -> str:
|
|
92
109
|
return "local"
|
|
93
110
|
|
|
111
|
+
@property
|
|
112
|
+
def dimension(self) -> int:
|
|
113
|
+
if self._dimension is None:
|
|
114
|
+
raise RuntimeError("Embeddings not initialized. Call initialize() first.")
|
|
115
|
+
return self._dimension
|
|
116
|
+
|
|
94
117
|
async def initialize(self) -> None:
|
|
95
118
|
"""Load the embedding model."""
|
|
96
119
|
if self._model is not None:
|
|
@@ -112,26 +135,18 @@ class LocalSTEmbeddings(Embeddings):
|
|
|
112
135
|
model_kwargs={"low_cpu_mem_usage": False, "device_map": None},
|
|
113
136
|
)
|
|
114
137
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
if model_dim != EMBEDDING_DIMENSION:
|
|
118
|
-
raise ValueError(
|
|
119
|
-
f"Model {self.model_name} produces {model_dim}-dimensional embeddings, "
|
|
120
|
-
f"but database schema requires {EMBEDDING_DIMENSION} dimensions. "
|
|
121
|
-
f"Use a model that produces {EMBEDDING_DIMENSION}-dimensional embeddings."
|
|
122
|
-
)
|
|
123
|
-
|
|
124
|
-
logger.info(f"Embeddings: local provider initialized (dim: {model_dim})")
|
|
138
|
+
self._dimension = self._model.get_sentence_embedding_dimension()
|
|
139
|
+
logger.info(f"Embeddings: local provider initialized (dim: {self._dimension})")
|
|
125
140
|
|
|
126
141
|
def encode(self, texts: list[str]) -> list[list[float]]:
|
|
127
142
|
"""
|
|
128
|
-
Generate
|
|
143
|
+
Generate embeddings for a list of texts.
|
|
129
144
|
|
|
130
145
|
Args:
|
|
131
146
|
texts: List of text strings to encode
|
|
132
147
|
|
|
133
148
|
Returns:
|
|
134
|
-
List of
|
|
149
|
+
List of embedding vectors
|
|
135
150
|
"""
|
|
136
151
|
if self._model is None:
|
|
137
152
|
raise RuntimeError("Embeddings not initialized. Call initialize() first.")
|
|
@@ -146,7 +161,7 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
146
161
|
TEI provides a high-performance inference server for embedding models.
|
|
147
162
|
See: https://github.com/huggingface/text-embeddings-inference
|
|
148
163
|
|
|
149
|
-
The
|
|
164
|
+
The embedding dimension is auto-detected from the server at initialization.
|
|
150
165
|
"""
|
|
151
166
|
|
|
152
167
|
def __init__(
|
|
@@ -174,11 +189,18 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
174
189
|
self.retry_delay = retry_delay
|
|
175
190
|
self._client: httpx.Client | None = None
|
|
176
191
|
self._model_id: str | None = None
|
|
192
|
+
self._dimension: int | None = None
|
|
177
193
|
|
|
178
194
|
@property
|
|
179
195
|
def provider_name(self) -> str:
|
|
180
196
|
return "tei"
|
|
181
197
|
|
|
198
|
+
@property
|
|
199
|
+
def dimension(self) -> int:
|
|
200
|
+
if self._dimension is None:
|
|
201
|
+
raise RuntimeError("Embeddings not initialized. Call initialize() first.")
|
|
202
|
+
return self._dimension
|
|
203
|
+
|
|
182
204
|
def _request_with_retry(self, method: str, url: str, **kwargs) -> httpx.Response:
|
|
183
205
|
"""Make an HTTP request with automatic retries on transient errors."""
|
|
184
206
|
import time
|
|
@@ -229,7 +251,24 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
229
251
|
response = self._request_with_retry("GET", f"{self.base_url}/info")
|
|
230
252
|
info = response.json()
|
|
231
253
|
self._model_id = info.get("model_id", "unknown")
|
|
232
|
-
|
|
254
|
+
|
|
255
|
+
# Get dimension from server info or by doing a test embedding
|
|
256
|
+
if "max_input_length" in info and "model_dtype" in info:
|
|
257
|
+
# Try to get dimension from info endpoint (some TEI versions expose it)
|
|
258
|
+
# If not available, do a test embedding
|
|
259
|
+
pass
|
|
260
|
+
|
|
261
|
+
# Do a test embedding to detect dimension
|
|
262
|
+
test_response = self._request_with_retry(
|
|
263
|
+
"POST",
|
|
264
|
+
f"{self.base_url}/embed",
|
|
265
|
+
json={"inputs": ["test"]},
|
|
266
|
+
)
|
|
267
|
+
test_embeddings = test_response.json()
|
|
268
|
+
if test_embeddings and len(test_embeddings) > 0:
|
|
269
|
+
self._dimension = len(test_embeddings[0])
|
|
270
|
+
|
|
271
|
+
logger.info(f"Embeddings: TEI provider initialized (model: {self._model_id}, dim: {self._dimension})")
|
|
233
272
|
except httpx.HTTPError as e:
|
|
234
273
|
raise RuntimeError(f"Failed to connect to TEI server at {self.base_url}: {e}")
|
|
235
274
|
|
|
@@ -269,6 +308,369 @@ class RemoteTEIEmbeddings(Embeddings):
|
|
|
269
308
|
return all_embeddings
|
|
270
309
|
|
|
271
310
|
|
|
311
|
+
class OpenAIEmbeddings(Embeddings):
    """
    OpenAI embeddings implementation using the OpenAI API.

    Supports text-embedding-3-small (1536 dims), text-embedding-3-large (3072 dims),
    and text-embedding-ada-002 (1536 dims, legacy).

    The embedding dimension is taken from MODEL_DIMENSIONS for known models and
    detected with a single test embedding for any other model during initialize().
    """

    # Known dimensions for OpenAI embedding models
    MODEL_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(
        self,
        api_key: str,
        model: str = DEFAULT_EMBEDDINGS_OPENAI_MODEL,
        base_url: str | None = None,
        batch_size: int = 100,
        max_retries: int = 3,
    ):
        """
        Initialize OpenAI embeddings client.

        Args:
            api_key: OpenAI API key
            model: OpenAI embedding model name (default: text-embedding-3-small)
            base_url: Custom base URL for OpenAI-compatible API (e.g., Azure OpenAI endpoint)
            batch_size: Maximum batch size for embedding requests (default: 100)
            max_retries: Maximum number of retries for failed requests (default: 3)
        """
        self.api_key = api_key
        self.model = model
        self.base_url = base_url
        self.batch_size = batch_size
        self.max_retries = max_retries
        self._client = None
        self._dimension: int | None = None

    @property
    def provider_name(self) -> str:
        return "openai"

    @property
    def dimension(self) -> int:
        """Return the embedding dimension; raises RuntimeError before initialize()."""
        if self._dimension is None:
            raise RuntimeError("Embeddings not initialized. Call initialize() first.")
        return self._dimension

    async def initialize(self) -> None:
        """
        Initialize the OpenAI client and detect dimension.

        Raises:
            ImportError: If the openai package is not installed.
            RuntimeError: If the test embedding returns no data, so the
                dimension cannot be detected.
        """
        if self._client is not None:
            return

        try:
            from openai import OpenAI
        except ImportError as e:
            raise ImportError(
                "openai is required for OpenAIEmbeddings. Install it with: pip install openai"
            ) from e

        base_url_msg = f" at {self.base_url}" if self.base_url else ""
        logger.info(f"Embeddings: initializing OpenAI provider with model {self.model}{base_url_msg}")

        # Build client kwargs, only including base_url if set (for Azure or custom endpoints)
        client_kwargs = {"api_key": self.api_key, "max_retries": self.max_retries}
        if self.base_url:
            client_kwargs["base_url"] = self.base_url
        client = OpenAI(**client_kwargs)

        # Try to get dimension from known models, otherwise do a test embedding
        dimension = self.MODEL_DIMENSIONS.get(self.model)
        if dimension is None:
            response = client.embeddings.create(
                model=self.model,
                input=["test"],
            )
            if not response.data:
                raise RuntimeError(
                    f"OpenAI returned no embedding data for model {self.model}; cannot detect dimension"
                )
            dimension = len(response.data[0].embedding)

        # Publish state only once everything succeeded. Setting self._client
        # earlier would make a retry of initialize() return immediately while
        # self._dimension is still None.
        self._dimension = dimension
        self._client = client

        logger.info(f"Embeddings: OpenAI provider initialized (model: {self.model}, dim: {self._dimension})")

    def encode(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings using the OpenAI API.

        Args:
            texts: List of text strings to encode

        Returns:
            List of embedding vectors, in the same order as ``texts``

        Raises:
            RuntimeError: If initialize() has not been called.
        """
        if self._client is None:
            raise RuntimeError("Embeddings not initialized. Call initialize() first.")

        if not texts:
            return []

        all_embeddings = []

        # Process in batches
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i : i + self.batch_size]

            response = self._client.embeddings.create(
                model=self.model,
                input=batch,
            )

            # Sort by index to ensure correct order
            batch_embeddings = sorted(response.data, key=lambda x: x.index)
            all_embeddings.extend(e.embedding for e in batch_embeddings)

        return all_embeddings
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
class CohereEmbeddings(Embeddings):
    """
    Cohere embeddings implementation using the Cohere API.

    Supports embed-english-v3.0 (1024 dims) and embed-multilingual-v3.0 (1024 dims).

    The embedding dimension is taken from MODEL_DIMENSIONS for known models and
    detected with a single test embedding for any other model during initialize().
    """

    # Known dimensions for Cohere embedding models
    MODEL_DIMENSIONS = {
        "embed-english-v3.0": 1024,
        "embed-multilingual-v3.0": 1024,
        "embed-english-light-v3.0": 384,
        "embed-multilingual-light-v3.0": 384,
        "embed-english-v2.0": 4096,
        "embed-multilingual-v2.0": 768,
    }

    def __init__(
        self,
        api_key: str,
        model: str = DEFAULT_EMBEDDINGS_COHERE_MODEL,
        base_url: str | None = None,
        batch_size: int = 96,
        timeout: float = 60.0,
        input_type: str = "search_document",
    ):
        """
        Initialize Cohere embeddings client.

        Args:
            api_key: Cohere API key
            model: Cohere embedding model name (default: embed-english-v3.0)
            base_url: Custom base URL for Cohere-compatible API (e.g., Azure-hosted endpoint)
            batch_size: Maximum batch size for embedding requests (default: 96, Cohere's limit)
            timeout: Request timeout in seconds (default: 60.0)
            input_type: Input type for embeddings (default: search_document).
                Options: search_document, search_query, classification, clustering
        """
        self.api_key = api_key
        self.model = model
        self.base_url = base_url
        self.batch_size = batch_size
        self.timeout = timeout
        self.input_type = input_type
        self._client = None
        self._dimension: int | None = None

    @property
    def provider_name(self) -> str:
        return "cohere"

    @property
    def dimension(self) -> int:
        """Return the embedding dimension; raises RuntimeError before initialize()."""
        if self._dimension is None:
            raise RuntimeError("Embeddings not initialized. Call initialize() first.")
        return self._dimension

    async def initialize(self) -> None:
        """
        Initialize the Cohere client and detect dimension.

        Raises:
            ImportError: If the cohere package is not installed.
            RuntimeError: If the test embedding returns no vectors, so the
                dimension cannot be detected.
        """
        if self._client is not None:
            return

        try:
            import cohere
        except ImportError as e:
            raise ImportError(
                "cohere is required for CohereEmbeddings. Install it with: pip install cohere"
            ) from e

        base_url_msg = f" at {self.base_url}" if self.base_url else ""
        logger.info(f"Embeddings: initializing Cohere provider with model {self.model}{base_url_msg}")

        # Build client kwargs, only including base_url if set (for Azure or custom endpoints)
        client_kwargs = {"api_key": self.api_key, "timeout": self.timeout}
        if self.base_url:
            client_kwargs["base_url"] = self.base_url
        client = cohere.Client(**client_kwargs)

        # Try to get dimension from known models, otherwise do a test embedding
        dimension = self.MODEL_DIMENSIONS.get(self.model)
        if dimension is None:
            response = client.embed(
                texts=["test"],
                model=self.model,
                input_type=self.input_type,
            )
            if not response.embeddings:
                raise RuntimeError(
                    f"Cohere returned no embeddings for model {self.model}; cannot detect dimension"
                )
            dimension = len(response.embeddings[0])

        # Publish state only once everything succeeded. Setting self._client
        # earlier would make a retry of initialize() return immediately while
        # self._dimension is still None.
        self._dimension = dimension
        self._client = client

        logger.info(f"Embeddings: Cohere provider initialized (model: {self.model}, dim: {self._dimension})")

    def encode(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings using the Cohere API.

        Every request uses the ``input_type`` chosen at construction time.

        Args:
            texts: List of text strings to encode

        Returns:
            List of embedding vectors

        Raises:
            RuntimeError: If initialize() has not been called.
        """
        if self._client is None:
            raise RuntimeError("Embeddings not initialized. Call initialize() first.")

        if not texts:
            return []

        all_embeddings = []

        # Process in batches
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i : i + self.batch_size]

            response = self._client.embed(
                texts=batch,
                model=self.model,
                input_type=self.input_type,
            )

            all_embeddings.extend(response.embeddings)

        return all_embeddings
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
class LiteLLMEmbeddings(Embeddings):
    """
    LiteLLM embeddings implementation using LiteLLM proxy's /embeddings endpoint.

    LiteLLM provides a unified interface for multiple embedding providers.
    The proxy exposes an OpenAI-compatible /embeddings endpoint.
    See: https://docs.litellm.ai/docs/embedding/supported_embedding

    Supported providers via LiteLLM:
    - OpenAI (text-embedding-3-small, text-embedding-ada-002, etc.)
    - Cohere (embed-english-v3.0, etc.) - prefix with cohere/
    - Vertex AI (textembedding-gecko, etc.) - prefix with vertex_ai/
    - HuggingFace, Mistral, Voyage AI, etc.

    The embedding dimension is detected with a test embedding at initialization.
    """

    def __init__(
        self,
        api_base: str = DEFAULT_LITELLM_API_BASE,
        api_key: str | None = None,
        model: str = DEFAULT_EMBEDDINGS_LITELLM_MODEL,
        batch_size: int = 100,
        timeout: float = 60.0,
    ):
        """
        Initialize LiteLLM embeddings client.

        Args:
            api_base: Base URL of the LiteLLM proxy (default: http://localhost:4000)
            api_key: API key for the LiteLLM proxy (optional, depends on proxy config)
            model: Embedding model name (default: text-embedding-3-small)
                Use provider prefix for non-OpenAI models (e.g., cohere/embed-english-v3.0)
            batch_size: Maximum batch size for embedding requests (default: 100)
            timeout: Request timeout in seconds (default: 60.0)
        """
        self.api_base = api_base.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.batch_size = batch_size
        self.timeout = timeout
        self._client: httpx.Client | None = None
        self._dimension: int | None = None

    @property
    def provider_name(self) -> str:
        return "litellm"

    @property
    def dimension(self) -> int:
        """Return the embedding dimension; raises RuntimeError before initialize()."""
        if self._dimension is None:
            raise RuntimeError("Embeddings not initialized. Call initialize() first.")
        return self._dimension

    async def initialize(self) -> None:
        """
        Initialize the HTTP client and detect embedding dimension.

        Raises:
            RuntimeError: If the proxy request fails, or the response carries
                no embedding data to detect the dimension from.
        """
        if self._client is not None:
            return

        logger.info(f"Embeddings: initializing LiteLLM provider at {self.api_base} with model {self.model}")

        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        client = httpx.Client(timeout=self.timeout, headers=headers)

        # Do a test embedding to detect dimension. The client is kept in a
        # local until detection succeeds, so a failed attempt neither leaks the
        # connection pool nor turns a later initialize() retry into a no-op.
        try:
            try:
                response = client.post(
                    f"{self.api_base}/embeddings",
                    json={"model": self.model, "input": ["test"]},
                )
                response.raise_for_status()
            except httpx.HTTPError as e:
                raise RuntimeError(f"Failed to connect to LiteLLM proxy at {self.api_base}: {e}") from e

            data = response.json().get("data")
            if not data:
                raise RuntimeError(
                    f"LiteLLM proxy at {self.api_base} returned no embedding data for model {self.model}; "
                    "cannot detect dimension"
                )
            dimension = len(data[0]["embedding"])
        except Exception:
            client.close()
            raise

        self._dimension = dimension
        self._client = client
        logger.info(f"Embeddings: LiteLLM provider initialized (model: {self.model}, dim: {self._dimension})")

    def encode(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings using the LiteLLM proxy.

        Args:
            texts: List of text strings to encode

        Returns:
            List of embedding vectors, in the same order as ``texts``

        Raises:
            RuntimeError: If initialize() has not been called.
            httpx.HTTPError: If a batch request fails or returns an error status.
        """
        if self._client is None:
            raise RuntimeError("Embeddings not initialized. Call initialize() first.")

        if not texts:
            return []

        all_embeddings = []

        # Process in batches
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i : i + self.batch_size]

            response = self._client.post(
                f"{self.api_base}/embeddings",
                json={"model": self.model, "input": batch},
            )
            response.raise_for_status()
            result = response.json()

            # Sort by index to ensure correct order
            batch_embeddings = sorted(result["data"], key=lambda x: x["index"])
            all_embeddings.extend(e["embedding"] for e in batch_embeddings)

        return all_embeddings
|
|
672
|
+
|
|
673
|
+
|
|
272
674
|
def create_embeddings_from_env() -> Embeddings:
|
|
273
675
|
"""
|
|
274
676
|
Create an Embeddings instance based on environment variables.
|
|
@@ -289,5 +691,30 @@ def create_embeddings_from_env() -> Embeddings:
|
|
|
289
691
|
model = os.environ.get(ENV_EMBEDDINGS_LOCAL_MODEL)
|
|
290
692
|
model_name = model or DEFAULT_EMBEDDINGS_LOCAL_MODEL
|
|
291
693
|
return LocalSTEmbeddings(model_name=model_name)
|
|
694
|
+
elif provider == "openai":
|
|
695
|
+
# Use dedicated embeddings API key, or fall back to LLM API key
|
|
696
|
+
api_key = os.environ.get(ENV_EMBEDDINGS_OPENAI_API_KEY) or os.environ.get(ENV_LLM_API_KEY)
|
|
697
|
+
if not api_key:
|
|
698
|
+
raise ValueError(
|
|
699
|
+
f"{ENV_EMBEDDINGS_OPENAI_API_KEY} or {ENV_LLM_API_KEY} is required "
|
|
700
|
+
f"when {ENV_EMBEDDINGS_PROVIDER} is 'openai'"
|
|
701
|
+
)
|
|
702
|
+
model = os.environ.get(ENV_EMBEDDINGS_OPENAI_MODEL, DEFAULT_EMBEDDINGS_OPENAI_MODEL)
|
|
703
|
+
base_url = os.environ.get(ENV_EMBEDDINGS_OPENAI_BASE_URL) or None
|
|
704
|
+
return OpenAIEmbeddings(api_key=api_key, model=model, base_url=base_url)
|
|
705
|
+
elif provider == "cohere":
|
|
706
|
+
api_key = os.environ.get(ENV_COHERE_API_KEY)
|
|
707
|
+
if not api_key:
|
|
708
|
+
raise ValueError(f"{ENV_COHERE_API_KEY} is required when {ENV_EMBEDDINGS_PROVIDER} is 'cohere'")
|
|
709
|
+
model = os.environ.get(ENV_EMBEDDINGS_COHERE_MODEL, DEFAULT_EMBEDDINGS_COHERE_MODEL)
|
|
710
|
+
base_url = os.environ.get(ENV_EMBEDDINGS_COHERE_BASE_URL) or None
|
|
711
|
+
return CohereEmbeddings(api_key=api_key, model=model, base_url=base_url)
|
|
712
|
+
elif provider == "litellm":
|
|
713
|
+
api_base = os.environ.get(ENV_LITELLM_API_BASE, DEFAULT_LITELLM_API_BASE)
|
|
714
|
+
api_key = os.environ.get(ENV_LITELLM_API_KEY)
|
|
715
|
+
model = os.environ.get(ENV_EMBEDDINGS_LITELLM_MODEL, DEFAULT_EMBEDDINGS_LITELLM_MODEL)
|
|
716
|
+
return LiteLLMEmbeddings(api_base=api_base, api_key=api_key, model=model)
|
|
292
717
|
else:
|
|
293
|
-
raise ValueError(
|
|
718
|
+
raise ValueError(
|
|
719
|
+
f"Unknown embeddings provider: {provider}. Supported: 'local', 'tei', 'openai', 'cohere', 'litellm'"
|
|
720
|
+
)
|
|
@@ -209,7 +209,7 @@ class EntityResolver:
|
|
|
209
209
|
# This handles duplicates via ON CONFLICT and returns all IDs
|
|
210
210
|
if entities_to_create:
|
|
211
211
|
# Group entities by canonical name (lowercase) to handle duplicates within batch
|
|
212
|
-
# For duplicates, we only insert once and reuse the ID
|
|
212
|
+
# For duplicates, we only insert once and reuse the ID, but track the count
|
|
213
213
|
unique_entities = {} # lowercase_name -> (entity_data, event_date, [indices])
|
|
214
214
|
for idx, entity_data, event_date in entities_to_create:
|
|
215
215
|
name_lower = entity_data["text"].lower()
|
|
@@ -223,29 +223,32 @@ class EntityResolver:
|
|
|
223
223
|
# Use a single query with unnest for speed
|
|
224
224
|
entity_names = []
|
|
225
225
|
entity_dates = []
|
|
226
|
+
entity_counts = [] # Track how many times each entity appears in this batch
|
|
226
227
|
indices_map = [] # Maps result index -> list of original indices
|
|
227
228
|
|
|
228
229
|
for name_lower, (entity_data, event_date, indices) in unique_entities.items():
|
|
229
230
|
entity_names.append(entity_data["text"])
|
|
230
231
|
entity_dates.append(event_date)
|
|
232
|
+
entity_counts.append(len(indices)) # Count of occurrences in this batch
|
|
231
233
|
indices_map.append(indices)
|
|
232
234
|
|
|
233
235
|
# Batch INSERT ... ON CONFLICT with RETURNING
|
|
234
|
-
#
|
|
236
|
+
# Uses the batch count for mention_count instead of always 1
|
|
235
237
|
rows = await conn.fetch(
|
|
236
238
|
f"""
|
|
237
239
|
INSERT INTO {fq_table("entities")} (bank_id, canonical_name, first_seen, last_seen, mention_count)
|
|
238
|
-
SELECT $1, name, event_date, event_date,
|
|
239
|
-
FROM unnest($2::text[], $3::timestamptz[]) AS t(name, event_date)
|
|
240
|
+
SELECT $1, name, event_date, event_date, cnt
|
|
241
|
+
FROM unnest($2::text[], $3::timestamptz[], $4::int[]) AS t(name, event_date, cnt)
|
|
240
242
|
ON CONFLICT (bank_id, LOWER(canonical_name))
|
|
241
243
|
DO UPDATE SET
|
|
242
|
-
mention_count = {fq_table("entities")}.mention_count +
|
|
244
|
+
mention_count = {fq_table("entities")}.mention_count + EXCLUDED.mention_count,
|
|
243
245
|
last_seen = EXCLUDED.last_seen
|
|
244
246
|
RETURNING id
|
|
245
247
|
""",
|
|
246
248
|
bank_id,
|
|
247
249
|
entity_names,
|
|
248
250
|
entity_dates,
|
|
251
|
+
entity_counts,
|
|
249
252
|
)
|
|
250
253
|
|
|
251
254
|
# Map returned IDs back to original indices
|
|
@@ -289,6 +289,7 @@ class MemoryEngineInterface(ABC):
|
|
|
289
289
|
bank_id: str,
|
|
290
290
|
*,
|
|
291
291
|
fact_type: str | None = None,
|
|
292
|
+
limit: int = 1000,
|
|
292
293
|
request_context: "RequestContext",
|
|
293
294
|
) -> dict[str, Any]:
|
|
294
295
|
"""
|
|
@@ -297,10 +298,11 @@ class MemoryEngineInterface(ABC):
|
|
|
297
298
|
Args:
|
|
298
299
|
bank_id: The memory bank ID.
|
|
299
300
|
fact_type: Filter by fact type.
|
|
301
|
+
limit: Maximum number of items to return (default: 1000).
|
|
300
302
|
request_context: Request context for authentication.
|
|
301
303
|
|
|
302
304
|
Returns:
|
|
303
|
-
Dict with nodes, edges, table_rows, total_units.
|
|
305
|
+
Dict with nodes, edges, table_rows, total_units, limit.
|
|
304
306
|
"""
|
|
305
307
|
...
|
|
306
308
|
|
|
@@ -404,18 +406,20 @@ class MemoryEngineInterface(ABC):
|
|
|
404
406
|
bank_id: str,
|
|
405
407
|
*,
|
|
406
408
|
limit: int = 100,
|
|
409
|
+
offset: int = 0,
|
|
407
410
|
request_context: "RequestContext",
|
|
408
|
-
) ->
|
|
411
|
+
) -> dict[str, Any]:
|
|
409
412
|
"""
|
|
410
|
-
List entities for a bank.
|
|
413
|
+
List entities for a bank with pagination.
|
|
411
414
|
|
|
412
415
|
Args:
|
|
413
416
|
bank_id: The memory bank ID.
|
|
414
417
|
limit: Maximum results.
|
|
418
|
+
offset: Offset for pagination.
|
|
415
419
|
request_context: Request context for authentication.
|
|
416
420
|
|
|
417
421
|
Returns:
|
|
418
|
-
|
|
422
|
+
Dict with items, total, limit, offset.
|
|
419
423
|
"""
|
|
420
424
|
...
|
|
421
425
|
|