spatial-memory-mcp 1.6.1 (spatial_memory_mcp-1.6.1-py3-none-any.whl)
This diff shows the content of package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects changes between versions as they appear in their public registries.
Potentially problematic release: this version of spatial-memory-mcp may be problematic.
- spatial_memory/__init__.py +97 -0
- spatial_memory/__main__.py +270 -0
- spatial_memory/adapters/__init__.py +7 -0
- spatial_memory/adapters/lancedb_repository.py +878 -0
- spatial_memory/config.py +728 -0
- spatial_memory/core/__init__.py +118 -0
- spatial_memory/core/cache.py +317 -0
- spatial_memory/core/circuit_breaker.py +297 -0
- spatial_memory/core/connection_pool.py +220 -0
- spatial_memory/core/consolidation_strategies.py +402 -0
- spatial_memory/core/database.py +3069 -0
- spatial_memory/core/db_idempotency.py +242 -0
- spatial_memory/core/db_indexes.py +575 -0
- spatial_memory/core/db_migrations.py +584 -0
- spatial_memory/core/db_search.py +509 -0
- spatial_memory/core/db_versioning.py +177 -0
- spatial_memory/core/embeddings.py +557 -0
- spatial_memory/core/errors.py +317 -0
- spatial_memory/core/file_security.py +702 -0
- spatial_memory/core/filesystem.py +178 -0
- spatial_memory/core/health.py +289 -0
- spatial_memory/core/helpers.py +79 -0
- spatial_memory/core/import_security.py +432 -0
- spatial_memory/core/lifecycle_ops.py +1067 -0
- spatial_memory/core/logging.py +194 -0
- spatial_memory/core/metrics.py +192 -0
- spatial_memory/core/models.py +628 -0
- spatial_memory/core/rate_limiter.py +326 -0
- spatial_memory/core/response_types.py +497 -0
- spatial_memory/core/security.py +588 -0
- spatial_memory/core/spatial_ops.py +426 -0
- spatial_memory/core/tracing.py +300 -0
- spatial_memory/core/utils.py +110 -0
- spatial_memory/core/validation.py +403 -0
- spatial_memory/factory.py +407 -0
- spatial_memory/migrations/__init__.py +40 -0
- spatial_memory/ports/__init__.py +11 -0
- spatial_memory/ports/repositories.py +631 -0
- spatial_memory/py.typed +0 -0
- spatial_memory/server.py +1141 -0
- spatial_memory/services/__init__.py +70 -0
- spatial_memory/services/export_import.py +1023 -0
- spatial_memory/services/lifecycle.py +1120 -0
- spatial_memory/services/memory.py +412 -0
- spatial_memory/services/spatial.py +1147 -0
- spatial_memory/services/utility.py +409 -0
- spatial_memory/tools/__init__.py +5 -0
- spatial_memory/tools/definitions.py +695 -0
- spatial_memory/verify.py +140 -0
- spatial_memory_mcp-1.6.1.dist-info/METADATA +499 -0
- spatial_memory_mcp-1.6.1.dist-info/RECORD +54 -0
- spatial_memory_mcp-1.6.1.dist-info/WHEEL +4 -0
- spatial_memory_mcp-1.6.1.dist-info/entry_points.txt +2 -0
- spatial_memory_mcp-1.6.1.dist-info/licenses/LICENSE +21 -0
spatial_memory/core/embeddings.py (new file, 557 lines)

@@ -0,0 +1,557 @@

```python
"""Embedding service for Spatial Memory MCP Server."""

import hashlib
import logging
import re
import threading
import time
from collections import OrderedDict
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any, Literal, TypeVar

import numpy as np

from spatial_memory.core.circuit_breaker import (
    CircuitBreaker,
    CircuitOpenError,
    CircuitState,
)
from spatial_memory.core.errors import ConfigurationError, EmbeddingError

if TYPE_CHECKING:
    from openai import OpenAI
    from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)

# Backend type for embedding inference
EmbeddingBackend = Literal["auto", "onnx", "pytorch"]


def _is_onnx_available() -> bool:
    """Check if ONNX Runtime and Optimum are available.

    Sentence-transformers requires both onnxruntime and optimum for ONNX support.
    """
    try:
        import onnxruntime  # noqa: F401
        import optimum.onnxruntime  # noqa: F401
        return True
    except ImportError:
        return False


def _detect_backend(requested: EmbeddingBackend) -> Literal["onnx", "pytorch"]:
    """Detect which backend to use.

    Args:
        requested: The requested backend ('auto', 'onnx', or 'pytorch').

    Returns:
        The actual backend to use ('onnx' or 'pytorch').
    """
    if requested == "pytorch":
        return "pytorch"
    elif requested == "onnx":
        if not _is_onnx_available():
            raise ConfigurationError(
                "ONNX Runtime requested but not fully installed. "
                "Install with: pip install sentence-transformers[onnx]"
            )
        return "onnx"
    else:  # auto
        if _is_onnx_available():
            return "onnx"
        return "pytorch"
```
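Worth noting for reviewers: `"auto"` degrades silently to PyTorch, while an explicit `"onnx"` request fails loudly. A minimal probe sketch, assuming the wheel is installed and the private helpers above are importable as shown:

```python
# Probe sketch (assumes the wheel is installed); _detect_backend and
# _is_onnx_available are the private helpers defined above.
from spatial_memory.core.embeddings import _detect_backend, _is_onnx_available

print("onnx extras present:", _is_onnx_available())
# "auto" resolves to "pytorch" when onnxruntime/optimum are missing;
# an explicit "onnx" request raises ConfigurationError instead.
print("auto resolves to:", _detect_backend("auto"))
```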
```python
# Type variable for retry decorator
F = TypeVar("F", bound=Callable[..., Any])


def _mask_api_key(text: str) -> str:
    """Mask API keys in error messages.

    Args:
        text: Error message text that might contain API keys.

    Returns:
        Text with API keys masked.
    """
    # Mask OpenAI keys (sk-...)
    text = re.sub(r'sk-[a-zA-Z0-9]{20,}', '***OPENAI_KEY***', text)
    # Mask generic api_key patterns
    text = re.sub(
        r'api[_-]?key["\']?\s*[:=]\s*["\']?[\w-]+',
        'api_key=***MASKED***',
        text,
        flags=re.IGNORECASE,
    )
    return text
```
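The two patterns interact: an `sk-` token is rewritten by the first rule before the generic rule can see it. A small sketch of the observable behavior (both values below are fabricated):

```python
from spatial_memory.core.embeddings import _mask_api_key

# Fabricated keys, for illustration only.
print(_mask_api_key("401 from upstream: sk-" + "a" * 24))
# -> 401 from upstream: ***OPENAI_KEY***
print(_mask_api_key("client error: api_key: topsecret123"))
# -> client error: api_key=***MASKED***
```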
```python
def retry_on_api_error(
    max_attempts: int = 3,
    backoff: float = 1.0,
    retryable_status_codes: tuple[int, ...] = (429, 500, 502, 503, 504),
) -> Callable[[F], F]:
    """Retry decorator for transient API errors.

    Args:
        max_attempts: Maximum number of retry attempts.
        backoff: Initial backoff time in seconds (doubles each attempt).
        retryable_status_codes: HTTP status codes that should trigger retry.

    Returns:
        Decorated function with retry logic.
    """
    # Non-retryable auth errors
    non_retryable_codes = (401, 403)

    def decorator(func: F) -> F:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            last_error: Exception | None = None
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    last_error = e

                    # Check for OpenAI-specific errors
                    status_code = None
                    if hasattr(e, "status_code"):
                        status_code = e.status_code
                    elif hasattr(e, "response") and hasattr(e.response, "status_code"):
                        status_code = e.response.status_code

                    # Don't retry auth errors
                    if status_code in non_retryable_codes:
                        logger.warning(f"Non-retryable API error (status {status_code}): {e}")
                        raise

                    # Check if we should retry
                    should_retry = (
                        status_code in retryable_status_codes
                        or "rate" in str(e).lower()
                        or "timeout" in str(e).lower()
                        or "connection" in str(e).lower()
                    )

                    if not should_retry or attempt == max_attempts - 1:
                        raise

                    # Retry with exponential backoff
                    wait_time = backoff * (2 ** attempt)
                    logger.warning(
                        f"API call failed (attempt {attempt + 1}/{max_attempts}): {e}. "
                        f"Retrying in {wait_time:.1f}s..."
                    )
                    time.sleep(wait_time)

            if last_error:
                raise last_error
            return None

        return wrapper  # type: ignore

    return decorator
```
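Because retryability is decided per exception (status code first, then substring heuristics on the message), the decorator works on any plain callable. A self-contained sketch; `flaky` is hypothetical:

```python
from spatial_memory.core.embeddings import retry_on_api_error

attempts = {"n": 0}

@retry_on_api_error(max_attempts=3, backoff=0.1)
def flaky() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:
        # "timeout" in the message matches the substring heuristic above,
        # so this exception counts as retryable.
        raise RuntimeError("connection timeout")
    return "ok"

print(flaky(), "after", attempts["n"], "attempts")  # ok after 3 attempts
```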
```python
class EmbeddingService:
    """Service for generating text embeddings.

    Supports local sentence-transformers models and optional OpenAI API.
    Uses ONNX Runtime by default for 2-3x faster inference.
    Optionally uses a circuit breaker for fault tolerance with external services.
    """

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        openai_api_key: str | Any | None = None,
        backend: EmbeddingBackend = "auto",
        circuit_breaker: CircuitBreaker | None = None,
        circuit_breaker_enabled: bool = True,
        circuit_breaker_failure_threshold: int = 5,
        circuit_breaker_reset_timeout: float = 60.0,
        cache_max_size: int = 1000,
    ) -> None:
        """Initialize the embedding service.

        Args:
            model_name: Model name. Use 'openai:model-name' for OpenAI models.
            openai_api_key: OpenAI API key (required for OpenAI models).
                Can be a string or a SecretStr (pydantic).
            backend: Inference backend. 'auto' uses ONNX if available (default),
                'onnx' forces ONNX Runtime, 'pytorch' forces PyTorch.
            circuit_breaker: Optional pre-configured circuit breaker instance.
                If provided, other circuit breaker parameters are ignored.
            circuit_breaker_enabled: Whether to enable circuit breaker for OpenAI calls.
                Defaults to True. Only applies to OpenAI models.
            circuit_breaker_failure_threshold: Number of consecutive failures before
                opening the circuit. Default is 5.
            circuit_breaker_reset_timeout: Seconds to wait before attempting recovery.
                Default is 60.0 seconds.
            cache_max_size: Maximum number of embeddings to cache (LRU eviction).
                Default is 1000. Set to 0 to disable caching.
        """
        self.model_name = model_name
        # Handle both plain strings and SecretStr (pydantic)
        if openai_api_key is not None and hasattr(openai_api_key, 'get_secret_value'):
            self._openai_api_key: str | None = openai_api_key.get_secret_value()
        else:
            self._openai_api_key = openai_api_key
        self._model: SentenceTransformer | None = None
        self._openai_client: OpenAI | None = None
        self._dimensions: int | None = None

        # Determine backend for local models
        self._requested_backend = backend
        self._active_backend: Literal["onnx", "pytorch"] | None = None

        # Embedding cache (LRU with max size)
        self._embed_cache: OrderedDict[str, np.ndarray] = OrderedDict()
        self._cache_max_size = cache_max_size
        self._cache_lock = threading.Lock()

        # Determine if using OpenAI
        self.use_openai = model_name.startswith("openai:")
        if self.use_openai:
            self.openai_model = model_name.split(":", 1)[1]
            if not self._openai_api_key:
                raise ConfigurationError(
                    "OpenAI API key required for OpenAI embedding models"
                )

        # Circuit breaker for OpenAI API calls (optional)
        if circuit_breaker is not None:
            self._circuit_breaker: CircuitBreaker | None = circuit_breaker
        elif circuit_breaker_enabled and self.use_openai:
            self._circuit_breaker = CircuitBreaker(
                failure_threshold=circuit_breaker_failure_threshold,
                reset_timeout=circuit_breaker_reset_timeout,
                name=f"embedding_service_{model_name}",
            )
            logger.info(
                f"Circuit breaker enabled for embedding service "
                f"(threshold={circuit_breaker_failure_threshold}, "
                f"timeout={circuit_breaker_reset_timeout}s)"
            )
        else:
            self._circuit_breaker = None
```
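Construction is cheap: model and client loading is deferred until first use. A sketch of both configurations (the key is a placeholder):

```python
from spatial_memory.core.embeddings import EmbeddingService

# Local model; nothing is downloaded until the first embed() or
# dimensions access, since loading is lazy.
local = EmbeddingService(model_name="all-MiniLM-L6-v2", backend="auto")

# OpenAI-backed service; the "openai:" prefix triggers both the key check
# and the circuit breaker wiring shown in __init__.
remote = EmbeddingService(
    model_name="openai:text-embedding-3-small",
    openai_api_key="sk-placeholder",  # placeholder, not a real key
)
```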
```python
    def _load_local_model(self) -> None:
        """Load local sentence-transformers model with ONNX or PyTorch backend."""
        if self._model is not None:
            return

        try:
            from sentence_transformers import SentenceTransformer

            # Detect which backend to use
            self._active_backend = _detect_backend(self._requested_backend)

            logger.info(
                f"Loading embedding model: {self.model_name} "
                f"(backend: {self._active_backend})"
            )

            # Load model with appropriate backend
            if self._active_backend == "onnx":
                # Use ONNX Runtime backend for faster inference
                self._model = SentenceTransformer(
                    self.model_name,
                    backend="onnx",
                )
                logger.info("Using ONNX Runtime backend (2-3x faster inference)")
            else:
                # Use default PyTorch backend
                self._model = SentenceTransformer(self.model_name)
                logger.info("Using PyTorch backend")

            self._dimensions = self._model.get_sentence_embedding_dimension()
            logger.info(f"Loaded model with {self._dimensions} dimensions")
        except Exception as e:
            masked_error = _mask_api_key(str(e))
            raise EmbeddingError(f"Failed to load embedding model: {masked_error}") from e

    def _load_openai_client(self) -> None:
        """Load OpenAI client."""
        if self._openai_client is not None:
            return

        try:
            from openai import OpenAI

            self._openai_client = OpenAI(api_key=self._openai_api_key)
            # Set dimensions based on model
            model_dimensions = {
                "text-embedding-3-small": 1536,
                "text-embedding-3-large": 3072,
                "text-embedding-ada-002": 1536,
            }
            self._dimensions = model_dimensions.get(self.openai_model, 1536)
            logger.info(
                f"Initialized OpenAI client for {self.openai_model} "
                f"({self._dimensions} dimensions)"
            )
        except ImportError:
            raise ConfigurationError(
                "OpenAI package not installed. Run: pip install openai"
            )
        except Exception as e:
            masked_error = _mask_api_key(str(e))
            raise EmbeddingError(f"Failed to initialize OpenAI client: {masked_error}") from e

    def _get_cache_key(self, text: str) -> str:
        """Generate cache key from text content.

        Uses MD5 for speed (not security) - collisions are acceptable for cache.
        """
        return hashlib.md5(text.encode(), usedforsecurity=False).hexdigest()

    @property
    def dimensions(self) -> int:
        """Get the embedding dimensions."""
        if self._dimensions is None:
            if self.use_openai:
                self._load_openai_client()
            else:
                self._load_local_model()
        return self._dimensions  # type: ignore

    @property
    def backend(self) -> str:
        """Get the active embedding backend.

        Returns:
            'openai' for OpenAI API, 'onnx' or 'pytorch' for local models.
        """
        if self.use_openai:
            return "openai"
        if self._active_backend is None:
            # Force model load to determine backend
            self._load_local_model()
        return self._active_backend or "pytorch"

    @property
    def circuit_state(self) -> CircuitState | None:
        """Get the current circuit breaker state.

        Returns:
            CircuitState if circuit breaker is enabled, None otherwise.
        """
        if self._circuit_breaker is None:
            return None
        return self._circuit_breaker.state

    @property
    def circuit_breaker(self) -> CircuitBreaker | None:
        """Get the circuit breaker instance.

        Returns:
            CircuitBreaker if enabled, None otherwise.
        """
        return self._circuit_breaker
```
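Note that the read-only properties double as lazy triggers: touching `dimensions` or `backend` loads the model as a side effect. Continuing the construction sketch above (values are indicative):

```python
# Continues the `local` / `remote` sketch above.
print(local.dimensions)      # 384 for all-MiniLM-L6-v2 (loads the model)
print(local.backend)         # "onnx" or "pytorch", depending on installed extras
print(local.circuit_state)   # None: breakers are created only for OpenAI models
print(remote.circuit_state)  # presumably CircuitState.CLOSED at first; the
                             # initial state lives in circuit_breaker.py, not shown
```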
```python
    def embed(self, text: str) -> np.ndarray:
        """Generate embedding for a single text.

        Args:
            text: Text to embed.

        Returns:
            Embedding vector as numpy array.
        """
        cache_key = self._get_cache_key(text)

        # Check cache first
        with self._cache_lock:
            if cache_key in self._embed_cache:
                # Move to end (most recently used) and return copy
                self._embed_cache.move_to_end(cache_key)
                return self._embed_cache[cache_key].copy()

        # Generate embedding (outside lock to allow concurrent generation)
        if self.use_openai:
            embedding = self._embed_openai_with_circuit_breaker([text])[0]
        else:
            embedding = self._embed_local([text])[0]

        # Cache the result (if caching enabled)
        if self._cache_max_size > 0:
            with self._cache_lock:
                # Check if another thread already cached it
                if cache_key not in self._embed_cache:
                    # Evict oldest entries if at capacity
                    while len(self._embed_cache) >= self._cache_max_size:
                        self._embed_cache.popitem(last=False)
                    self._embed_cache[cache_key] = embedding.copy()
                else:
                    # Another thread cached it, move to end
                    self._embed_cache.move_to_end(cache_key)

        return embedding
```
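The copy-on-read/copy-on-write discipline means callers can mutate returned vectors without corrupting the cache. Observable behavior, continuing the sketch:

```python
import numpy as np

v1 = local.embed("hello world")  # cache miss: runs the model
v2 = local.embed("hello world")  # cache hit: copy served from the LRU cache
assert np.allclose(v1, v2)
assert v1 is not v2  # .copy() keeps the cached array isolated from callers
```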
```python
    def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
        """Generate embeddings for multiple texts.

        Uses cache for already-embedded texts and only generates
        embeddings for texts not in cache.

        Args:
            texts: List of texts to embed.

        Returns:
            List of embedding vectors.
        """
        if not texts:
            logger.debug("embed_batch called with empty input, returning empty list")
            return []

        # If caching disabled, generate all embeddings directly
        if self._cache_max_size <= 0:
            if self.use_openai:
                return self._embed_openai_with_circuit_breaker(texts)
            else:
                return self._embed_local(texts)

        # Check cache for each text
        results: list[np.ndarray | None] = [None] * len(texts)
        texts_to_embed: list[tuple[int, str]] = []  # (index, text)

        with self._cache_lock:
            for i, text in enumerate(texts):
                cache_key = self._get_cache_key(text)
                if cache_key in self._embed_cache:
                    self._embed_cache.move_to_end(cache_key)
                    results[i] = self._embed_cache[cache_key].copy()
                else:
                    texts_to_embed.append((i, text))

        # Generate embeddings for uncached texts
        if texts_to_embed:
            uncached_texts = [t for _, t in texts_to_embed]
            if self.use_openai:
                new_embeddings = self._embed_openai_with_circuit_breaker(uncached_texts)
            else:
                new_embeddings = self._embed_local(uncached_texts)

            # Store results and cache them
            with self._cache_lock:
                for (idx, text), embedding in zip(texts_to_embed, new_embeddings):
                    results[idx] = embedding
                    cache_key = self._get_cache_key(text)
                    if cache_key not in self._embed_cache:
                        # Evict oldest entries if at capacity
                        while len(self._embed_cache) >= self._cache_max_size:
                            self._embed_cache.popitem(last=False)
                        self._embed_cache[cache_key] = embedding.copy()

        # Type assertion - all results should be filled
        return [r for r in results if r is not None]
```
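`embed_batch` shares the same LRU cache, so texts repeated across calls hit the model only once; it does not, however, dedupe duplicates within a single call, and the final `is not None` filter would silently drop entries if a slot were ever left unfilled. A quick check, continuing the sketch:

```python
first = local.embed_batch(["alpha", "beta"])   # both encoded, both cached
second = local.embed_batch(["beta", "gamma"])  # "beta" served from cache
assert len(first) == 2 and len(second) == 2
```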
```python
    def clear_cache(self) -> int:
        """Clear embedding cache. Returns number of entries cleared."""
        with self._cache_lock:
            count = len(self._embed_cache)
            self._embed_cache.clear()
            return count

    def _embed_local(self, texts: list[str]) -> list[np.ndarray]:
        """Generate embeddings using local model.

        Args:
            texts: List of texts to embed.

        Returns:
            List of embedding vectors.
        """
        self._load_local_model()
        assert self._model is not None  # _load_local_model() sets this or raises

        try:
            embeddings = self._model.encode(
                texts,
                convert_to_numpy=True,
                normalize_embeddings=True,
                show_progress_bar=False,
            )
            return [emb for emb in embeddings]
        except Exception as e:
            masked_error = _mask_api_key(str(e))
            raise EmbeddingError(f"Failed to generate embeddings: {masked_error}") from e

    def _embed_openai_with_circuit_breaker(self, texts: list[str]) -> list[np.ndarray]:
        """Generate embeddings using OpenAI API with circuit breaker protection.

        Wraps the OpenAI embedding call with a circuit breaker to prevent
        cascading failures when the API is unavailable.

        Args:
            texts: List of texts to embed.

        Returns:
            List of embedding vectors.

        Raises:
            EmbeddingError: If circuit is open or embedding generation fails.
        """
        if self._circuit_breaker is None:
            # No circuit breaker, call directly
            return self._embed_openai(texts)

        try:
            return self._circuit_breaker.call(self._embed_openai, texts)
        except CircuitOpenError as e:
            logger.warning(
                f"Circuit breaker is open for embedding service, "
                f"time until retry: {e.time_until_retry:.1f}s"
                if e.time_until_retry is not None
                else "Circuit breaker is open for embedding service"
            )
            raise EmbeddingError(
                f"Embedding service temporarily unavailable (circuit open). "
                f"Try again in {e.time_until_retry:.0f} seconds."
                if e.time_until_retry is not None
                else "Embedding service temporarily unavailable (circuit open)."
            ) from e

    @retry_on_api_error(max_attempts=3, backoff=1.0)
    def _embed_openai(self, texts: list[str]) -> list[np.ndarray]:
        """Generate embeddings using OpenAI API with retry logic.

        Automatically retries on transient errors (429 rate limit, 5xx server errors).
        Does not retry on auth errors (401, 403).

        Args:
            texts: List of texts to embed.

        Returns:
            List of embedding vectors.
        """
        self._load_openai_client()
        assert self._openai_client is not None  # _load_openai_client() sets this or raises

        try:
            response = self._openai_client.embeddings.create(
                model=self.openai_model,
                input=texts,
            )
            embeddings = []
            for item in response.data:
                emb = np.array(item.embedding, dtype=np.float32)
                # Normalize
                emb = emb / np.linalg.norm(emb)
                embeddings.append(emb)
            return embeddings
        except Exception as e:
            masked_error = _mask_api_key(str(e))
            raise EmbeddingError(f"Failed to generate OpenAI embeddings: {masked_error}") from e
```
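End to end, a remote failure surfaces to callers as `EmbeddingError` whether it comes from exhausted retries or an open circuit; `CircuitOpenError` never escapes this module. A sketch of the degraded path, assuming `remote` from the construction sketch above and an invalid key or unreachable API:

```python
from spatial_memory.core.errors import EmbeddingError

try:
    remote.embed("probe")  # fails with a placeholder key
except EmbeddingError as exc:
    # Either a masked upstream error or, once failure_threshold consecutive
    # failures have opened the breaker, the "temporarily unavailable
    # (circuit open)" message raised above.
    print(f"degraded: {exc}")
```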