spatial-memory-mcp 1.0.2__py3-none-any.whl → 1.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of spatial-memory-mcp might be problematic. Click here for more details.
- spatial_memory/__init__.py +97 -97
- spatial_memory/config.py +105 -0
- spatial_memory/core/__init__.py +26 -0
- spatial_memory/core/cache.py +317 -0
- spatial_memory/core/circuit_breaker.py +297 -0
- spatial_memory/core/database.py +167 -1
- spatial_memory/core/embeddings.py +92 -2
- spatial_memory/core/logging.py +194 -103
- spatial_memory/core/rate_limiter.py +309 -105
- spatial_memory/core/tracing.py +300 -0
- spatial_memory/core/validation.py +319 -319
- spatial_memory/server.py +230 -29
- spatial_memory/services/memory.py +79 -2
- spatial_memory/tools/definitions.py +695 -671
- {spatial_memory_mcp-1.0.2.dist-info → spatial_memory_mcp-1.5.3.dist-info}/METADATA +1 -1
- {spatial_memory_mcp-1.0.2.dist-info → spatial_memory_mcp-1.5.3.dist-info}/RECORD +19 -16
- {spatial_memory_mcp-1.0.2.dist-info → spatial_memory_mcp-1.5.3.dist-info}/WHEEL +0 -0
- {spatial_memory_mcp-1.0.2.dist-info → spatial_memory_mcp-1.5.3.dist-info}/entry_points.txt +0 -0
- {spatial_memory_mcp-1.0.2.dist-info → spatial_memory_mcp-1.5.3.dist-info}/licenses/LICENSE +0 -0
|
spatial_memory/core/circuit_breaker.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
"""Circuit breaker pattern implementation for fault tolerance.
|
|
2
|
+
|
|
3
|
+
The circuit breaker prevents cascading failures by fast-failing requests
|
|
4
|
+
when a service is unhealthy, allowing time for recovery.
|
|
5
|
+
|
|
6
|
+
State transitions:
|
|
7
|
+
CLOSED (normal) -> OPEN (failures >= threshold)
|
|
8
|
+
OPEN -> HALF_OPEN (after reset_timeout)
|
|
9
|
+
HALF_OPEN -> CLOSED (probe succeeds)
|
|
10
|
+
HALF_OPEN -> OPEN (probe fails)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import threading
|
|
17
|
+
import time
|
|
18
|
+
from collections.abc import Callable
|
|
19
|
+
from enum import Enum
|
|
20
|
+
from typing import TypeVar
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
T = TypeVar("T")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class CircuitState(Enum):
|
|
28
|
+
"""Circuit breaker states.
|
|
29
|
+
|
|
30
|
+
CLOSED: Normal operation, requests pass through.
|
|
31
|
+
OPEN: Circuit is tripped, requests are rejected immediately.
|
|
32
|
+
HALF_OPEN: Testing recovery, limited requests allowed.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
CLOSED = "closed"
|
|
36
|
+
OPEN = "open"
|
|
37
|
+
HALF_OPEN = "half_open"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class CircuitOpenError(Exception):
|
|
41
|
+
"""Raised when circuit is open and call is rejected.
|
|
42
|
+
|
|
43
|
+
This exception indicates that the circuit breaker is preventing
|
|
44
|
+
requests from reaching a failing service.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
def __init__(
|
|
48
|
+
self,
|
|
49
|
+
message: str = "Circuit breaker is open",
|
|
50
|
+
time_until_retry: float | None = None,
|
|
51
|
+
) -> None:
|
|
52
|
+
"""Initialize CircuitOpenError.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
message: Error description.
|
|
56
|
+
time_until_retry: Seconds until the circuit will transition to HALF_OPEN.
|
|
57
|
+
"""
|
|
58
|
+
super().__init__(message)
|
|
59
|
+
self.time_until_retry = time_until_retry
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class CircuitBreaker:
|
|
63
|
+
"""Circuit breaker for protecting external service calls.
|
|
64
|
+
|
|
65
|
+
Monitors failures and opens the circuit when failures exceed a threshold,
|
|
66
|
+
preventing further calls until a reset timeout has elapsed.
|
|
67
|
+
|
|
68
|
+
Example:
|
|
69
|
+
breaker = CircuitBreaker(failure_threshold=5, reset_timeout=60.0)
|
|
70
|
+
|
|
71
|
+
try:
|
|
72
|
+
result = breaker.call(my_api_call, arg1, arg2)
|
|
73
|
+
except CircuitOpenError:
|
|
74
|
+
# Service is unhealthy, use fallback
|
|
75
|
+
result = fallback_value
|
|
76
|
+
|
|
77
|
+
Thread Safety:
|
|
78
|
+
This class is thread-safe. All state transitions and counters
|
|
79
|
+
are protected by a lock.
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
def __init__(
|
|
83
|
+
self,
|
|
84
|
+
failure_threshold: int = 5,
|
|
85
|
+
reset_timeout: float = 60.0,
|
|
86
|
+
half_open_max_calls: int = 1,
|
|
87
|
+
name: str | None = None,
|
|
88
|
+
) -> None:
|
|
89
|
+
"""Initialize the circuit breaker.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
failure_threshold: Number of consecutive failures before opening circuit.
|
|
93
|
+
reset_timeout: Seconds to wait before transitioning from OPEN to HALF_OPEN.
|
|
94
|
+
half_open_max_calls: Maximum concurrent calls allowed in HALF_OPEN state.
|
|
95
|
+
name: Optional name for logging purposes.
|
|
96
|
+
"""
|
|
97
|
+
if failure_threshold < 1:
|
|
98
|
+
raise ValueError("failure_threshold must be at least 1")
|
|
99
|
+
if reset_timeout <= 0:
|
|
100
|
+
raise ValueError("reset_timeout must be positive")
|
|
101
|
+
if half_open_max_calls < 1:
|
|
102
|
+
raise ValueError("half_open_max_calls must be at least 1")
|
|
103
|
+
|
|
104
|
+
self._failure_threshold = failure_threshold
|
|
105
|
+
self._reset_timeout = reset_timeout
|
|
106
|
+
self._half_open_max_calls = half_open_max_calls
|
|
107
|
+
self._name = name or "circuit_breaker"
|
|
108
|
+
|
|
109
|
+
self._state = CircuitState.CLOSED
|
|
110
|
+
self._failure_count = 0
|
|
111
|
+
self._last_failure_time: float | None = None
|
|
112
|
+
self._half_open_calls = 0
|
|
113
|
+
self._lock = threading.Lock()
|
|
114
|
+
|
|
115
|
+
# Statistics
|
|
116
|
+
self._total_calls = 0
|
|
117
|
+
self._total_failures = 0
|
|
118
|
+
self._total_rejections = 0
|
|
119
|
+
|
|
120
|
+
@property
|
|
121
|
+
def state(self) -> CircuitState:
|
|
122
|
+
"""Get current circuit state.
|
|
123
|
+
|
|
124
|
+
This property also handles automatic transition from OPEN to HALF_OPEN
|
|
125
|
+
when the reset timeout has elapsed.
|
|
126
|
+
|
|
127
|
+
Returns:
|
|
128
|
+
Current circuit state.
|
|
129
|
+
"""
|
|
130
|
+
with self._lock:
|
|
131
|
+
return self._get_state_unlocked()
|
|
132
|
+
|
|
133
|
+
def _get_state_unlocked(self) -> CircuitState:
|
|
134
|
+
"""Get state without acquiring lock (must be called with lock held)."""
|
|
135
|
+
if self._state == CircuitState.OPEN:
|
|
136
|
+
if self._should_transition_to_half_open():
|
|
137
|
+
self._transition_to_half_open()
|
|
138
|
+
return self._state
|
|
139
|
+
|
|
140
|
+
def _should_transition_to_half_open(self) -> bool:
|
|
141
|
+
"""Check if reset timeout has elapsed (must be called with lock held)."""
|
|
142
|
+
if self._last_failure_time is None:
|
|
143
|
+
return False
|
|
144
|
+
elapsed = time.monotonic() - self._last_failure_time
|
|
145
|
+
return elapsed >= self._reset_timeout
|
|
146
|
+
|
|
147
|
+
def _transition_to_half_open(self) -> None:
|
|
148
|
+
"""Transition to HALF_OPEN state (must be called with lock held)."""
|
|
149
|
+
logger.info(f"[{self._name}] Circuit transitioning from OPEN to HALF_OPEN")
|
|
150
|
+
self._state = CircuitState.HALF_OPEN
|
|
151
|
+
self._half_open_calls = 0
|
|
152
|
+
|
|
153
|
+
@property
|
|
154
|
+
def failure_count(self) -> int:
|
|
155
|
+
"""Get current failure count."""
|
|
156
|
+
with self._lock:
|
|
157
|
+
return self._failure_count
|
|
158
|
+
|
|
159
|
+
@property
|
|
160
|
+
def stats(self) -> dict[str, int | str | float | None]:
|
|
161
|
+
"""Get circuit breaker statistics.
|
|
162
|
+
|
|
163
|
+
Returns:
|
|
164
|
+
Dictionary with state, counters, and timing information.
|
|
165
|
+
"""
|
|
166
|
+
with self._lock:
|
|
167
|
+
state = self._get_state_unlocked()
|
|
168
|
+
time_until_retry = None
|
|
169
|
+
if state == CircuitState.OPEN and self._last_failure_time is not None:
|
|
170
|
+
elapsed = time.monotonic() - self._last_failure_time
|
|
171
|
+
time_until_retry = max(0.0, self._reset_timeout - elapsed)
|
|
172
|
+
|
|
173
|
+
return {
|
|
174
|
+
"state": state.value,
|
|
175
|
+
"failure_count": self._failure_count,
|
|
176
|
+
"failure_threshold": self._failure_threshold,
|
|
177
|
+
"total_calls": self._total_calls,
|
|
178
|
+
"total_failures": self._total_failures,
|
|
179
|
+
"total_rejections": self._total_rejections,
|
|
180
|
+
"reset_timeout": self._reset_timeout,
|
|
181
|
+
"time_until_retry": time_until_retry,
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
def call(self, func: Callable[..., T], *args: object, **kwargs: object) -> T:
|
|
185
|
+
"""Execute function with circuit breaker protection.
|
|
186
|
+
|
|
187
|
+
Args:
|
|
188
|
+
func: Function to execute.
|
|
189
|
+
*args: Positional arguments for the function.
|
|
190
|
+
**kwargs: Keyword arguments for the function.
|
|
191
|
+
|
|
192
|
+
Returns:
|
|
193
|
+
Return value of the function.
|
|
194
|
+
|
|
195
|
+
Raises:
|
|
196
|
+
CircuitOpenError: If circuit is OPEN and not ready for retry.
|
|
197
|
+
Exception: Any exception raised by the function.
|
|
198
|
+
"""
|
|
199
|
+
with self._lock:
|
|
200
|
+
self._total_calls += 1
|
|
201
|
+
current_state = self._get_state_unlocked()
|
|
202
|
+
|
|
203
|
+
if current_state == CircuitState.OPEN:
|
|
204
|
+
self._total_rejections += 1
|
|
205
|
+
time_until_retry = None
|
|
206
|
+
if self._last_failure_time is not None:
|
|
207
|
+
elapsed = time.monotonic() - self._last_failure_time
|
|
208
|
+
time_until_retry = max(0.0, self._reset_timeout - elapsed)
|
|
209
|
+
raise CircuitOpenError(
|
|
210
|
+
f"[{self._name}] Circuit is OPEN, rejecting call",
|
|
211
|
+
time_until_retry=time_until_retry,
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
if current_state == CircuitState.HALF_OPEN:
|
|
215
|
+
if self._half_open_calls >= self._half_open_max_calls:
|
|
216
|
+
self._total_rejections += 1
|
|
217
|
+
raise CircuitOpenError(
|
|
218
|
+
f"[{self._name}] Circuit is HALF_OPEN, max probe calls reached",
|
|
219
|
+
time_until_retry=0.0,
|
|
220
|
+
)
|
|
221
|
+
self._half_open_calls += 1
|
|
222
|
+
|
|
223
|
+
# Execute function outside the lock
|
|
224
|
+
try:
|
|
225
|
+
result = func(*args, **kwargs)
|
|
226
|
+
self._on_success()
|
|
227
|
+
return result
|
|
228
|
+
except Exception as e:
|
|
229
|
+
self._on_failure(e)
|
|
230
|
+
raise
|
|
231
|
+
|
|
232
|
+
def _on_success(self) -> None:
|
|
233
|
+
"""Handle successful call.
|
|
234
|
+
|
|
235
|
+
In CLOSED state: Reset failure count.
|
|
236
|
+
In HALF_OPEN state: Transition to CLOSED.
|
|
237
|
+
"""
|
|
238
|
+
with self._lock:
|
|
239
|
+
if self._state == CircuitState.HALF_OPEN:
|
|
240
|
+
logger.info(
|
|
241
|
+
f"[{self._name}] Probe succeeded, circuit transitioning to CLOSED"
|
|
242
|
+
)
|
|
243
|
+
self._state = CircuitState.CLOSED
|
|
244
|
+
self._failure_count = 0
|
|
245
|
+
self._half_open_calls = 0
|
|
246
|
+
elif self._state == CircuitState.CLOSED:
|
|
247
|
+
self._failure_count = 0
|
|
248
|
+
|
|
249
|
+
def _on_failure(self, error: Exception) -> None:
|
|
250
|
+
"""Handle failed call.
|
|
251
|
+
|
|
252
|
+
In CLOSED state: Increment failure count, open circuit if threshold reached.
|
|
253
|
+
In HALF_OPEN state: Transition back to OPEN.
|
|
254
|
+
|
|
255
|
+
Args:
|
|
256
|
+
error: The exception that was raised.
|
|
257
|
+
"""
|
|
258
|
+
with self._lock:
|
|
259
|
+
self._total_failures += 1
|
|
260
|
+
self._failure_count += 1
|
|
261
|
+
self._last_failure_time = time.monotonic()
|
|
262
|
+
|
|
263
|
+
if self._state == CircuitState.HALF_OPEN:
|
|
264
|
+
logger.warning(
|
|
265
|
+
f"[{self._name}] Probe failed ({error!r}), "
|
|
266
|
+
f"circuit transitioning back to OPEN"
|
|
267
|
+
)
|
|
268
|
+
self._state = CircuitState.OPEN
|
|
269
|
+
self._half_open_calls = 0
|
|
270
|
+
elif self._state == CircuitState.CLOSED:
|
|
271
|
+
if self._failure_count >= self._failure_threshold:
|
|
272
|
+
logger.warning(
|
|
273
|
+
f"[{self._name}] Failure threshold reached "
|
|
274
|
+
f"({self._failure_count}/{self._failure_threshold}), "
|
|
275
|
+
f"circuit transitioning to OPEN"
|
|
276
|
+
)
|
|
277
|
+
self._state = CircuitState.OPEN
|
|
278
|
+
|
|
279
|
+
def reset(self) -> None:
|
|
280
|
+
"""Manually reset circuit to CLOSED state.
|
|
281
|
+
|
|
282
|
+
This clears all failure counters and transitions the circuit
|
|
283
|
+
to CLOSED state regardless of current state.
|
|
284
|
+
"""
|
|
285
|
+
with self._lock:
|
|
286
|
+
logger.info(f"[{self._name}] Circuit manually reset to CLOSED")
|
|
287
|
+
self._state = CircuitState.CLOSED
|
|
288
|
+
self._failure_count = 0
|
|
289
|
+
self._half_open_calls = 0
|
|
290
|
+
self._last_failure_time = None
|
|
291
|
+
|
|
292
|
+
def __repr__(self) -> str:
|
|
293
|
+
"""Return string representation."""
|
|
294
|
+
return (
|
|
295
|
+
f"CircuitBreaker(name={self._name!r}, state={self.state.value}, "
|
|
296
|
+
f"failures={self.failure_count}/{self._failure_threshold})"
|
|
297
|
+
)
|
spatial_memory/core/database.py
CHANGED
|
@@ -35,7 +35,7 @@ from filelock import FileLock, Timeout as FileLockTimeout
|
|
|
35
35
|
|
|
36
36
|
from spatial_memory.core.connection_pool import ConnectionPool
|
|
37
37
|
from spatial_memory.core.errors import FileLockError, MemoryNotFoundError, StorageError, ValidationError
|
|
38
|
-
from spatial_memory.core.utils import utc_now
|
|
38
|
+
from spatial_memory.core.utils import to_aware_utc, utc_now
|
|
39
39
|
|
|
40
40
|
# Import centralized validation functions
|
|
41
41
|
from spatial_memory.core.validation import (
|
|
@@ -391,6 +391,15 @@ def with_process_lock(func: F) -> F:
|
|
|
391
391
|
# Health Metrics
|
|
392
392
|
# ============================================================================
|
|
393
393
|
|
|
394
|
+
@dataclass
|
|
395
|
+
class IdempotencyRecord:
|
|
396
|
+
"""Record for idempotency key tracking."""
|
|
397
|
+
key: str
|
|
398
|
+
memory_id: str
|
|
399
|
+
created_at: Any # datetime
|
|
400
|
+
expires_at: Any # datetime
|
|
401
|
+
|
|
402
|
+
|
|
394
403
|
@dataclass
|
|
395
404
|
class IndexStats:
|
|
396
405
|
"""Statistics for a single index."""
|
|
@@ -3045,3 +3054,160 @@ class Database:
|
|
|
3045
3054
|
return self.table.version
|
|
3046
3055
|
except Exception as e:
|
|
3047
3056
|
raise StorageError(f"Failed to get current version: {e}") from e
|
|
3057
|
+
|
|
3058
|
+
# ========================================================================
|
|
3059
|
+
# Idempotency Key Management
|
|
3060
|
+
# ========================================================================
|
|
3061
|
+
|
|
3062
|
+
def _ensure_idempotency_table(self) -> None:
|
|
3063
|
+
"""Ensure the idempotency keys table exists."""
|
|
3064
|
+
if self._db is None:
|
|
3065
|
+
raise StorageError("Database not connected")
|
|
3066
|
+
|
|
3067
|
+
existing_tables_result = self._db.list_tables()
|
|
3068
|
+
if hasattr(existing_tables_result, 'tables'):
|
|
3069
|
+
existing_tables = existing_tables_result.tables
|
|
3070
|
+
else:
|
|
3071
|
+
existing_tables = existing_tables_result
|
|
3072
|
+
|
|
3073
|
+
if "idempotency_keys" not in existing_tables:
|
|
3074
|
+
schema = pa.schema([
|
|
3075
|
+
pa.field("key", pa.string()),
|
|
3076
|
+
pa.field("memory_id", pa.string()),
|
|
3077
|
+
pa.field("created_at", pa.timestamp("us")),
|
|
3078
|
+
pa.field("expires_at", pa.timestamp("us")),
|
|
3079
|
+
])
|
|
3080
|
+
self._db.create_table("idempotency_keys", schema=schema)
|
|
3081
|
+
logger.info("Created idempotency_keys table")
|
|
3082
|
+
|
|
3083
|
+
@property
|
|
3084
|
+
def idempotency_table(self) -> LanceTable:
|
|
3085
|
+
"""Get the idempotency keys table, creating if needed."""
|
|
3086
|
+
if self._db is None:
|
|
3087
|
+
self.connect()
|
|
3088
|
+
self._ensure_idempotency_table()
|
|
3089
|
+
assert self._db is not None
|
|
3090
|
+
return self._db.open_table("idempotency_keys")
|
|
3091
|
+
|
|
3092
|
+
def get_by_idempotency_key(self, key: str) -> IdempotencyRecord | None:
|
|
3093
|
+
"""Look up an idempotency record by key.
|
|
3094
|
+
|
|
3095
|
+
Args:
|
|
3096
|
+
key: The idempotency key to look up.
|
|
3097
|
+
|
|
3098
|
+
Returns:
|
|
3099
|
+
IdempotencyRecord if found and not expired, None otherwise.
|
|
3100
|
+
|
|
3101
|
+
Raises:
|
|
3102
|
+
StorageError: If database operation fails.
|
|
3103
|
+
"""
|
|
3104
|
+
if not key:
|
|
3105
|
+
return None
|
|
3106
|
+
|
|
3107
|
+
try:
|
|
3108
|
+
safe_key = _sanitize_string(key)
|
|
3109
|
+
results = (
|
|
3110
|
+
self.idempotency_table.search()
|
|
3111
|
+
.where(f"key = '{safe_key}'")
|
|
3112
|
+
.limit(1)
|
|
3113
|
+
.to_list()
|
|
3114
|
+
)
|
|
3115
|
+
|
|
3116
|
+
if not results:
|
|
3117
|
+
return None
|
|
3118
|
+
|
|
3119
|
+
record = results[0]
|
|
3120
|
+
now = utc_now()
|
|
3121
|
+
|
|
3122
|
+
# Check if expired (convert DB naive datetime to aware for comparison)
|
|
3123
|
+
expires_at = record.get("expires_at")
|
|
3124
|
+
if expires_at is not None:
|
|
3125
|
+
expires_at_aware = to_aware_utc(expires_at)
|
|
3126
|
+
if expires_at_aware < now:
|
|
3127
|
+
# Expired - clean it up and return None
|
|
3128
|
+
logger.debug(f"Idempotency key '{key}' has expired")
|
|
3129
|
+
return None
|
|
3130
|
+
|
|
3131
|
+
return IdempotencyRecord(
|
|
3132
|
+
key=record["key"],
|
|
3133
|
+
memory_id=record["memory_id"],
|
|
3134
|
+
created_at=record["created_at"],
|
|
3135
|
+
expires_at=record["expires_at"],
|
|
3136
|
+
)
|
|
3137
|
+
|
|
3138
|
+
except Exception as e:
|
|
3139
|
+
raise StorageError(f"Failed to look up idempotency key: {e}") from e
|
|
3140
|
+
|
|
3141
|
+
@with_process_lock
|
|
3142
|
+
@with_write_lock
|
|
3143
|
+
def store_idempotency_key(
|
|
3144
|
+
self,
|
|
3145
|
+
key: str,
|
|
3146
|
+
memory_id: str,
|
|
3147
|
+
ttl_hours: float = 24.0,
|
|
3148
|
+
) -> None:
|
|
3149
|
+
"""Store an idempotency key mapping.
|
|
3150
|
+
|
|
3151
|
+
Args:
|
|
3152
|
+
key: The idempotency key.
|
|
3153
|
+
memory_id: The memory ID that was created.
|
|
3154
|
+
ttl_hours: Time-to-live in hours (default: 24 hours).
|
|
3155
|
+
|
|
3156
|
+
Raises:
|
|
3157
|
+
ValidationError: If inputs are invalid.
|
|
3158
|
+
StorageError: If database operation fails.
|
|
3159
|
+
"""
|
|
3160
|
+
if not key:
|
|
3161
|
+
raise ValidationError("Idempotency key cannot be empty")
|
|
3162
|
+
if not memory_id:
|
|
3163
|
+
raise ValidationError("Memory ID cannot be empty")
|
|
3164
|
+
if ttl_hours <= 0:
|
|
3165
|
+
raise ValidationError("TTL must be positive")
|
|
3166
|
+
|
|
3167
|
+
now = utc_now()
|
|
3168
|
+
expires_at = now + timedelta(hours=ttl_hours)
|
|
3169
|
+
|
|
3170
|
+
record = {
|
|
3171
|
+
"key": key,
|
|
3172
|
+
"memory_id": memory_id,
|
|
3173
|
+
"created_at": now,
|
|
3174
|
+
"expires_at": expires_at,
|
|
3175
|
+
}
|
|
3176
|
+
|
|
3177
|
+
try:
|
|
3178
|
+
self.idempotency_table.add([record])
|
|
3179
|
+
logger.debug(
|
|
3180
|
+
f"Stored idempotency key '{key}' -> memory '{memory_id}' "
|
|
3181
|
+
f"(expires in {ttl_hours}h)"
|
|
3182
|
+
)
|
|
3183
|
+
except Exception as e:
|
|
3184
|
+
raise StorageError(f"Failed to store idempotency key: {e}") from e
|
|
3185
|
+
|
|
3186
|
+
@with_process_lock
|
|
3187
|
+
@with_write_lock
|
|
3188
|
+
def cleanup_expired_idempotency_keys(self) -> int:
|
|
3189
|
+
"""Remove expired idempotency keys.
|
|
3190
|
+
|
|
3191
|
+
Returns:
|
|
3192
|
+
Number of keys removed.
|
|
3193
|
+
|
|
3194
|
+
Raises:
|
|
3195
|
+
StorageError: If cleanup fails.
|
|
3196
|
+
"""
|
|
3197
|
+
try:
|
|
3198
|
+
now = utc_now()
|
|
3199
|
+
count_before = self.idempotency_table.count_rows()
|
|
3200
|
+
|
|
3201
|
+
# Delete expired keys
|
|
3202
|
+
predicate = f"expires_at < timestamp '{now.isoformat()}'"
|
|
3203
|
+
self.idempotency_table.delete(predicate)
|
|
3204
|
+
|
|
3205
|
+
count_after = self.idempotency_table.count_rows()
|
|
3206
|
+
deleted = count_before - count_after
|
|
3207
|
+
|
|
3208
|
+
if deleted > 0:
|
|
3209
|
+
logger.info(f"Cleaned up {deleted} expired idempotency keys")
|
|
3210
|
+
|
|
3211
|
+
return deleted
|
|
3212
|
+
except Exception as e:
|
|
3213
|
+
raise StorageError(f"Failed to cleanup idempotency keys: {e}") from e
|
|
spatial_memory/core/embeddings.py
CHANGED
|
@@ -12,6 +12,11 @@ from typing import TYPE_CHECKING, Any, Literal, TypeVar
|
|
|
12
12
|
|
|
13
13
|
import numpy as np
|
|
14
14
|
|
|
15
|
+
from spatial_memory.core.circuit_breaker import (
|
|
16
|
+
CircuitBreaker,
|
|
17
|
+
CircuitOpenError,
|
|
18
|
+
CircuitState,
|
|
19
|
+
)
|
|
15
20
|
from spatial_memory.core.errors import ConfigurationError, EmbeddingError
|
|
16
21
|
|
|
17
22
|
if TYPE_CHECKING:
|
|
@@ -158,6 +163,7 @@ class EmbeddingService:
|
|
|
158
163
|
|
|
159
164
|
Supports local sentence-transformers models and optional OpenAI API.
|
|
160
165
|
Uses ONNX Runtime by default for 2-3x faster inference.
|
|
166
|
+
Optionally uses a circuit breaker for fault tolerance with external services.
|
|
161
167
|
"""
|
|
162
168
|
|
|
163
169
|
def __init__(
|
|
@@ -165,6 +171,10 @@ class EmbeddingService:
|
|
|
165
171
|
model_name: str = "all-MiniLM-L6-v2",
|
|
166
172
|
openai_api_key: str | Any | None = None,
|
|
167
173
|
backend: EmbeddingBackend = "auto",
|
|
174
|
+
circuit_breaker: CircuitBreaker | None = None,
|
|
175
|
+
circuit_breaker_enabled: bool = True,
|
|
176
|
+
circuit_breaker_failure_threshold: int = 5,
|
|
177
|
+
circuit_breaker_reset_timeout: float = 60.0,
|
|
168
178
|
) -> None:
|
|
169
179
|
"""Initialize the embedding service.
|
|
170
180
|
|
|
@@ -174,6 +184,14 @@ class EmbeddingService:
|
|
|
174
184
|
Can be a string or a SecretStr (pydantic).
|
|
175
185
|
backend: Inference backend. 'auto' uses ONNX if available (default),
|
|
176
186
|
'onnx' forces ONNX Runtime, 'pytorch' forces PyTorch.
|
|
187
|
+
circuit_breaker: Optional pre-configured circuit breaker instance.
|
|
188
|
+
If provided, other circuit breaker parameters are ignored.
|
|
189
|
+
circuit_breaker_enabled: Whether to enable circuit breaker for OpenAI calls.
|
|
190
|
+
Defaults to True. Only applies to OpenAI models.
|
|
191
|
+
circuit_breaker_failure_threshold: Number of consecutive failures before
|
|
192
|
+
opening the circuit. Default is 5.
|
|
193
|
+
circuit_breaker_reset_timeout: Seconds to wait before attempting recovery.
|
|
194
|
+
Default is 60.0 seconds.
|
|
177
195
|
"""
|
|
178
196
|
self.model_name = model_name
|
|
179
197
|
# Handle both plain strings and SecretStr (pydantic)
|
|
@@ -203,6 +221,23 @@ class EmbeddingService:
|
|
|
203
221
|
"OpenAI API key required for OpenAI embedding models"
|
|
204
222
|
)
|
|
205
223
|
|
|
224
|
+
# Circuit breaker for OpenAI API calls (optional)
|
|
225
|
+
if circuit_breaker is not None:
|
|
226
|
+
self._circuit_breaker: CircuitBreaker | None = circuit_breaker
|
|
227
|
+
elif circuit_breaker_enabled and self.use_openai:
|
|
228
|
+
self._circuit_breaker = CircuitBreaker(
|
|
229
|
+
failure_threshold=circuit_breaker_failure_threshold,
|
|
230
|
+
reset_timeout=circuit_breaker_reset_timeout,
|
|
231
|
+
name=f"embedding_service_{model_name}",
|
|
232
|
+
)
|
|
233
|
+
logger.info(
|
|
234
|
+
f"Circuit breaker enabled for embedding service "
|
|
235
|
+
f"(threshold={circuit_breaker_failure_threshold}, "
|
|
236
|
+
f"timeout={circuit_breaker_reset_timeout}s)"
|
|
237
|
+
)
|
|
238
|
+
else:
|
|
239
|
+
self._circuit_breaker = None
|
|
240
|
+
|
|
206
241
|
def _load_local_model(self) -> None:
|
|
207
242
|
"""Load local sentence-transformers model with ONNX or PyTorch backend."""
|
|
208
243
|
if self._model is not None:
|
|
@@ -300,6 +335,26 @@ class EmbeddingService:
|
|
|
300
335
|
self._load_local_model()
|
|
301
336
|
return self._active_backend or "pytorch"
|
|
302
337
|
|
|
338
|
+
@property
|
|
339
|
+
def circuit_state(self) -> CircuitState | None:
|
|
340
|
+
"""Get the current circuit breaker state.
|
|
341
|
+
|
|
342
|
+
Returns:
|
|
343
|
+
CircuitState if circuit breaker is enabled, None otherwise.
|
|
344
|
+
"""
|
|
345
|
+
if self._circuit_breaker is None:
|
|
346
|
+
return None
|
|
347
|
+
return self._circuit_breaker.state
|
|
348
|
+
|
|
349
|
+
@property
|
|
350
|
+
def circuit_breaker(self) -> CircuitBreaker | None:
|
|
351
|
+
"""Get the circuit breaker instance.
|
|
352
|
+
|
|
353
|
+
Returns:
|
|
354
|
+
CircuitBreaker if enabled, None otherwise.
|
|
355
|
+
"""
|
|
356
|
+
return self._circuit_breaker
|
|
357
|
+
|
|
303
358
|
def embed(self, text: str) -> np.ndarray:
|
|
304
359
|
"""Generate embedding for a single text.
|
|
305
360
|
|
|
@@ -320,7 +375,7 @@ class EmbeddingService:
|
|
|
320
375
|
|
|
321
376
|
# Generate embedding (outside lock to allow concurrent generation)
|
|
322
377
|
if self.use_openai:
|
|
323
|
-
embedding = self._embed_openai([text])[0]
|
|
378
|
+
embedding = self._embed_openai_with_circuit_breaker([text])[0]
|
|
324
379
|
else:
|
|
325
380
|
embedding = self._embed_local([text])[0]
|
|
326
381
|
|
|
@@ -352,7 +407,7 @@ class EmbeddingService:
|
|
|
352
407
|
return []
|
|
353
408
|
|
|
354
409
|
if self.use_openai:
|
|
355
|
-
return self._embed_openai(texts)
|
|
410
|
+
return self._embed_openai_with_circuit_breaker(texts)
|
|
356
411
|
else:
|
|
357
412
|
return self._embed_local(texts)
|
|
358
413
|
|
|
@@ -387,6 +442,41 @@ class EmbeddingService:
|
|
|
387
442
|
masked_error = _mask_api_key(str(e))
|
|
388
443
|
raise EmbeddingError(f"Failed to generate embeddings: {masked_error}") from e
|
|
389
444
|
|
|
445
|
+
def _embed_openai_with_circuit_breaker(self, texts: list[str]) -> list[np.ndarray]:
|
|
446
|
+
"""Generate embeddings using OpenAI API with circuit breaker protection.
|
|
447
|
+
|
|
448
|
+
Wraps the OpenAI embedding call with a circuit breaker to prevent
|
|
449
|
+
cascading failures when the API is unavailable.
|
|
450
|
+
|
|
451
|
+
Args:
|
|
452
|
+
texts: List of texts to embed.
|
|
453
|
+
|
|
454
|
+
Returns:
|
|
455
|
+
List of embedding vectors.
|
|
456
|
+
|
|
457
|
+
Raises:
|
|
458
|
+
EmbeddingError: If circuit is open or embedding generation fails.
|
|
459
|
+
"""
|
|
460
|
+
if self._circuit_breaker is None:
|
|
461
|
+
# No circuit breaker, call directly
|
|
462
|
+
return self._embed_openai(texts)
|
|
463
|
+
|
|
464
|
+
try:
|
|
465
|
+
return self._circuit_breaker.call(self._embed_openai, texts)
|
|
466
|
+
except CircuitOpenError as e:
|
|
467
|
+
logger.warning(
|
|
468
|
+
f"Circuit breaker is open for embedding service, "
|
|
469
|
+
f"time until retry: {e.time_until_retry:.1f}s"
|
|
470
|
+
if e.time_until_retry is not None
|
|
471
|
+
else "Circuit breaker is open for embedding service"
|
|
472
|
+
)
|
|
473
|
+
raise EmbeddingError(
|
|
474
|
+
f"Embedding service temporarily unavailable (circuit open). "
|
|
475
|
+
f"Try again in {e.time_until_retry:.0f} seconds."
|
|
476
|
+
if e.time_until_retry is not None
|
|
477
|
+
else "Embedding service temporarily unavailable (circuit open)."
|
|
478
|
+
) from e
|
|
479
|
+
|
|
390
480
|
@retry_on_api_error(max_attempts=3, backoff=1.0)
|
|
391
481
|
def _embed_openai(self, texts: list[str]) -> list[np.ndarray]:
|
|
392
482
|
"""Generate embeddings using OpenAI API with retry logic.
|