spatial-memory-mcp 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spatial_memory/__init__.py +97 -0
- spatial_memory/__main__.py +271 -0
- spatial_memory/adapters/__init__.py +7 -0
- spatial_memory/adapters/lancedb_repository.py +880 -0
- spatial_memory/config.py +769 -0
- spatial_memory/core/__init__.py +118 -0
- spatial_memory/core/cache.py +317 -0
- spatial_memory/core/circuit_breaker.py +297 -0
- spatial_memory/core/connection_pool.py +220 -0
- spatial_memory/core/consolidation_strategies.py +401 -0
- spatial_memory/core/database.py +3072 -0
- spatial_memory/core/db_idempotency.py +242 -0
- spatial_memory/core/db_indexes.py +576 -0
- spatial_memory/core/db_migrations.py +588 -0
- spatial_memory/core/db_search.py +512 -0
- spatial_memory/core/db_versioning.py +178 -0
- spatial_memory/core/embeddings.py +558 -0
- spatial_memory/core/errors.py +317 -0
- spatial_memory/core/file_security.py +701 -0
- spatial_memory/core/filesystem.py +178 -0
- spatial_memory/core/health.py +289 -0
- spatial_memory/core/helpers.py +79 -0
- spatial_memory/core/import_security.py +433 -0
- spatial_memory/core/lifecycle_ops.py +1067 -0
- spatial_memory/core/logging.py +194 -0
- spatial_memory/core/metrics.py +192 -0
- spatial_memory/core/models.py +660 -0
- spatial_memory/core/rate_limiter.py +326 -0
- spatial_memory/core/response_types.py +500 -0
- spatial_memory/core/security.py +588 -0
- spatial_memory/core/spatial_ops.py +430 -0
- spatial_memory/core/tracing.py +300 -0
- spatial_memory/core/utils.py +110 -0
- spatial_memory/core/validation.py +406 -0
- spatial_memory/factory.py +444 -0
- spatial_memory/migrations/__init__.py +40 -0
- spatial_memory/ports/__init__.py +11 -0
- spatial_memory/ports/repositories.py +630 -0
- spatial_memory/py.typed +0 -0
- spatial_memory/server.py +1214 -0
- spatial_memory/services/__init__.py +70 -0
- spatial_memory/services/decay_manager.py +411 -0
- spatial_memory/services/export_import.py +1031 -0
- spatial_memory/services/lifecycle.py +1139 -0
- spatial_memory/services/memory.py +412 -0
- spatial_memory/services/spatial.py +1152 -0
- spatial_memory/services/utility.py +429 -0
- spatial_memory/tools/__init__.py +5 -0
- spatial_memory/tools/definitions.py +695 -0
- spatial_memory/verify.py +140 -0
- spatial_memory_mcp-1.9.1.dist-info/METADATA +509 -0
- spatial_memory_mcp-1.9.1.dist-info/RECORD +55 -0
- spatial_memory_mcp-1.9.1.dist-info/WHEEL +4 -0
- spatial_memory_mcp-1.9.1.dist-info/entry_points.txt +2 -0
- spatial_memory_mcp-1.9.1.dist-info/licenses/LICENSE +21 -0
spatial_memory/core/database.py
@@ -0,0 +1,3072 @@
"""LanceDB database wrapper for Spatial Memory MCP Server.

Enterprise-grade implementation with:
- Connection pooling (singleton pattern)
- Automatic index creation (IVF-PQ, FTS, scalar)
- Hybrid search with RRF reranking
- Batch operations and streaming
- Maintenance and optimization utilities
- Health metrics and monitoring
- Retry logic for transient errors
"""

from __future__ import annotations

import json
import logging
import threading
import time
import uuid
from collections.abc import Callable, Generator, Iterator
from dataclasses import dataclass
from datetime import timedelta
from functools import wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar, cast

import lancedb
import lancedb.index
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from filelock import FileLock
from filelock import Timeout as FileLockTimeout

from spatial_memory.core.connection_pool import ConnectionPool
from spatial_memory.core.db_idempotency import IdempotencyManager, IdempotencyRecord
from spatial_memory.core.db_indexes import IndexManager
from spatial_memory.core.db_migrations import CURRENT_SCHEMA_VERSION, MigrationManager
from spatial_memory.core.db_search import SearchManager
from spatial_memory.core.db_versioning import VersionManager
from spatial_memory.core.errors import (
    DimensionMismatchError,
    FileLockError,
    MemoryNotFoundError,
    PartialBatchInsertError,
    StorageError,
    ValidationError,
)
from spatial_memory.core.filesystem import (
    detect_filesystem_type,
    get_filesystem_warning_message,
    is_network_filesystem,
)
from spatial_memory.core.utils import utc_now

# Import centralized validation functions
from spatial_memory.core.validation import (
    sanitize_string as _sanitize_string_impl,
)
from spatial_memory.core.validation import (
    validate_metadata as _validate_metadata_impl,
)
from spatial_memory.core.validation import (
    validate_namespace as _validate_namespace_impl,
)
from spatial_memory.core.validation import (
    validate_tags as _validate_tags_impl,
)
from spatial_memory.core.validation import (
    validate_uuid as _validate_uuid_impl,
)

if TYPE_CHECKING:
    from lancedb.table import Table as LanceTable

logger = logging.getLogger(__name__)

# Type variable for retry decorator
F = TypeVar("F", bound=Callable[..., Any])

# All known vector index types for detection
VECTOR_INDEX_TYPES = frozenset({
    "IVF_PQ", "IVF_FLAT", "HNSW",
    "IVF_HNSW_PQ", "IVF_HNSW_SQ",
    "HNSW_PQ", "HNSW_SQ",
})

# ============================================================================
# Connection Pool (Singleton Pattern with LRU Eviction)
# ============================================================================

# Global connection pool instance
_connection_pool = ConnectionPool(max_size=10)


def set_connection_pool_max_size(max_size: int) -> None:
    """Set the maximum connection pool size.

    Args:
        max_size: Maximum number of connections to cache.
    """
    _connection_pool.max_size = max_size


def _get_or_create_connection(
    storage_path: Path,
    read_consistency_interval_ms: int = 0,
) -> lancedb.DBConnection:
    """Get cached connection or create new one (thread-safe with LRU eviction).

    Args:
        storage_path: Path to LanceDB storage directory.
        read_consistency_interval_ms: Read consistency interval in milliseconds.

    Returns:
        LanceDB connection instance.
    """
    path_key = str(storage_path.absolute())
    return _connection_pool.get_or_create(path_key, read_consistency_interval_ms)


def clear_connection_cache() -> None:
    """Clear the connection cache, properly closing connections.

    Should be called during shutdown or testing cleanup.
    """
    _connection_pool.close_all()


def invalidate_connection(storage_path: Path) -> bool:
    """Invalidate a specific cached connection.

    Use when a database connection becomes stale (e.g., database was
    deleted and recreated externally).

    Args:
        storage_path: Path to the database to invalidate.

    Returns:
        True if a connection was invalidated, False if not found in cache.
    """
    path_key = str(storage_path.absolute())
    return _connection_pool.invalidate(path_key)
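
# --- Illustrative usage sketch (editor's addition, not part of the package) ---
# A minimal example of how the module-level pool helpers above might be wired
# into application startup/shutdown. The storage path is hypothetical.
def _example_pool_lifecycle() -> None:
    """Sketch: size the pool at startup, drop a stale entry, clean up at exit."""
    set_connection_pool_max_size(4)              # cap cached connections
    db_path = Path("/tmp/spatial-memory-demo")   # hypothetical storage location
    invalidate_connection(db_path)               # force a fresh connection next use
    clear_connection_cache()                     # close everything during shutdown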


# ============================================================================
# Retry Decorator
# ============================================================================

# Default retry settings (can be overridden per-call)
DEFAULT_RETRY_MAX_ATTEMPTS = 3
DEFAULT_RETRY_BACKOFF_SECONDS = 0.5


def retry_on_storage_error(
    max_attempts: int = DEFAULT_RETRY_MAX_ATTEMPTS,
    backoff: float = DEFAULT_RETRY_BACKOFF_SECONDS,
) -> Callable[[F], F]:
    """Retry decorator for transient storage errors.

    Args:
        max_attempts: Maximum number of retry attempts.
        backoff: Initial backoff time in seconds (doubles each attempt).

    Returns:
        Decorated function with retry logic.

    Note:
        - Decorator values are STATIC: Parameters are fixed at class definition
          time, not instance creation time. This means the instance config values
          (max_retry_attempts, retry_backoff_seconds) exist for external tooling
          or future dynamic use, but do NOT affect this decorator's behavior.
        - Does NOT retry concurrent modification or conflict errors as these
          require application-level resolution (e.g., refresh and retry).
    """
    # Patterns indicating non-retryable errors
    non_retryable_patterns = ("concurrent", "conflict", "version mismatch")

    def decorator(func: F) -> F:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            last_error: Exception | None = None
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except (
                    StorageError, OSError, ConnectionError, TimeoutError
                ) as e:
                    last_error = e
                    error_str = str(e).lower()

                    # Check for non-retryable errors - raise immediately
                    if any(pattern in error_str for pattern in non_retryable_patterns):
                        logger.warning(
                            f"Non-retryable error in {func.__name__}: {e}"
                        )
                        raise

                    # Check if we've exhausted retries
                    if attempt == max_attempts - 1:
                        raise

                    # Retry with exponential backoff
                    wait_time = backoff * (2 ** attempt)
                    logger.warning(
                        f"{func.__name__} failed (attempt {attempt + 1}/{max_attempts})"
                        f": {e}. Retrying in {wait_time:.1f}s..."
                    )
                    time.sleep(wait_time)
            # Should never reach here, but satisfy type checker
            if last_error:
                raise last_error
            return None
        return cast(F, wrapper)
    return decorator
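
# --- Illustrative usage sketch (editor's addition, not part of the package) ---
# With the defaults above (max_attempts=3, backoff=0.5), a transient OSError is
# retried after 0.5s and then 1.0s before the third failure propagates. The
# function below and its body are hypothetical.
@retry_on_storage_error(max_attempts=3, backoff=0.5)
def _example_flaky_read(path: Path) -> bytes:
    """Sketch: a read that may hit transient filesystem errors."""
    return path.read_bytes()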


def with_write_lock(func: F) -> F:
    """Decorator to acquire write lock for mutation operations.

    Serializes write operations per Database instance to prevent
    LanceDB version conflicts during concurrent writes.

    Uses RLock to allow nested calls (e.g., bulk_import -> insert_batch).
    """
    @wraps(func)
    def wrapper(self: Database, *args: Any, **kwargs: Any) -> Any:
        with self._write_lock:
            return func(self, *args, **kwargs)
    return cast(F, wrapper)


def with_stale_connection_recovery(func: F) -> F:
    """Decorator to auto-recover from stale connection errors.

    Detects when a database operation fails due to stale metadata
    (e.g., database was recreated while connection was cached),
    reconnects, and retries the operation once.
    """
    @wraps(func)
    def wrapper(self: Database, *args: Any, **kwargs: Any) -> Any:
        try:
            return func(self, *args, **kwargs)
        except Exception as e:
            if _connection_pool.is_stale_connection_error(e):
                logger.warning(
                    f"Stale connection detected in {func.__name__}, reconnecting..."
                )
                self.reconnect()
                return func(self, *args, **kwargs)
            raise
    return cast(F, wrapper)


# ============================================================================
# Cross-Process Lock Manager
# ============================================================================


class ProcessLockManager:
    """Cross-process file lock manager with reentrant support.

    Wraps FileLock with thread-local depth tracking to support nested calls
    (e.g., bulk_import() -> insert_batch()). Each thread can re-acquire the
    lock without blocking.

    Thread Safety:
        - Lock depth is tracked per-thread using threading.local
        - The underlying FileLock handles cross-process synchronization
        - Multiple threads in the same process can hold the lock via RLock behavior

    Example:
        lock = ProcessLockManager(Path("/tmp/db.lock"), timeout=30.0)
        with lock:
            # Protected region
            with lock:  # Nested call - same thread can re-acquire
                pass
    """

    def __init__(
        self,
        lock_path: Path,
        timeout: float = 30.0,
        poll_interval: float = 0.1,
        enabled: bool = True,
    ) -> None:
        """Initialize the process lock manager.

        Args:
            lock_path: Path to the lock file.
            timeout: Maximum seconds to wait for lock acquisition.
            poll_interval: Seconds between lock acquisition attempts.
            enabled: If False, all lock operations are no-ops.
        """
        self.lock_path = lock_path
        self.timeout = timeout
        self.poll_interval = poll_interval
        self.enabled = enabled

        # Create FileLock only if enabled
        self._lock: FileLock | None = None
        if enabled:
            try:
                self._lock = FileLock(str(lock_path), timeout=timeout)
            except Exception as e:
                # Fallback to disabled mode if lock file can't be created
                # (e.g., read-only filesystem)
                logger.warning(
                    f"Could not create file lock at {lock_path}: {e}. "
                    "Falling back to disabled mode."
                )
                self.enabled = False

        # Thread-local storage for lock depth tracking
        self._local = threading.local()

    def _get_depth(self) -> int:
        """Get current lock depth for this thread."""
        return getattr(self._local, "depth", 0)

    def _set_depth(self, depth: int) -> None:
        """Set lock depth for this thread."""
        self._local.depth = depth

    def acquire(self) -> bool:
        """Acquire the lock (reentrant for same thread).

        Returns:
            True if lock was newly acquired, False if already held by this thread.

        Raises:
            FileLockError: If lock cannot be acquired within timeout.
        """
        if not self.enabled or self._lock is None:
            return True

        depth = self._get_depth()
        if depth > 0:
            # Already held by this thread - increment depth
            self._set_depth(depth + 1)
            return False  # Not newly acquired

        try:
            self._lock.acquire(timeout=self.timeout, poll_interval=self.poll_interval)
            self._set_depth(1)
            return True
        except FileLockTimeout:
            raise FileLockError(
                lock_path=str(self.lock_path),
                timeout=self.timeout,
                message=(
                    f"Timed out waiting {self.timeout}s for file lock at "
                    f"{self.lock_path}. Another process may be holding the lock."
                ),
            )

    def release(self) -> bool:
        """Release the lock (decrements depth, releases when depth reaches 0).

        Returns:
            True if lock was released, False if still held (depth > 0).
        """
        if not self.enabled or self._lock is None:
            return True

        depth = self._get_depth()
        if depth <= 0:
            return True  # Not holding the lock

        if depth == 1:
            self._lock.release()
            self._set_depth(0)
            return True
        else:
            self._set_depth(depth - 1)
            return False  # Still holding

    def __enter__(self) -> ProcessLockManager:
        """Enter context manager - acquire lock."""
        self.acquire()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Exit context manager - release lock."""
        self.release()
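

# --- Illustrative usage sketch (editor's addition, not part of the package) ---
# Demonstrates the reentrant semantics documented above: the same thread can
# nest `with lock:` blocks, and the file lock is only released when the
# outermost block exits. The lock path is hypothetical.
def _example_reentrant_lock() -> None:
    lock = ProcessLockManager(Path("/tmp/demo.lock"), timeout=5.0)
    with lock:          # depth 1 - file lock acquired
        with lock:      # depth 2 - no second acquire, no deadlock
            pass        # protected region
        # inner exit: depth back to 1, file lock still held
    # outer exit: depth 0, file lock released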


def with_process_lock(func: F) -> F:
    """Decorator to acquire process-level file lock for write operations.

    Must be applied BEFORE (outer) @with_write_lock to ensure:
    1. Cross-process lock acquired first
    2. Then intra-process thread lock
    3. Releases in reverse order

    Usage:
        @with_process_lock  # Outer - cross-process
        @with_write_lock    # Inner - intra-process
        def insert(self, ...):
            ...
    """
    @wraps(func)
    def wrapper(self: Database, *args: Any, **kwargs: Any) -> Any:
        if self._process_lock is None:
            return func(self, *args, **kwargs)
        with self._process_lock:
            return func(self, *args, **kwargs)
    return cast(F, wrapper)


# ============================================================================
# Health Metrics
# ============================================================================

@dataclass
class IndexStats:
    """Statistics for a single index."""
    name: str
    index_type: str
    num_indexed_rows: int
    num_unindexed_rows: int
    needs_update: bool


@dataclass
class HealthMetrics:
    """Database health and performance metrics."""
    total_rows: int
    total_bytes: int
    total_bytes_mb: float
    num_fragments: int
    num_small_fragments: int
    needs_compaction: bool
    has_vector_index: bool
    has_fts_index: bool
    indices: list[IndexStats]
    version: int
    error: str | None = None


# Backward compatibility aliases - use centralized validation module
_sanitize_string = _sanitize_string_impl
_validate_uuid = _validate_uuid_impl


def _get_index_attr(idx: Any, attr: str, default: Any = None) -> Any:
    """Get an attribute from an index object (handles both dict and IndexConfig).

    LanceDB 0.27+ returns IndexConfig objects, while older versions use dicts.

    Args:
        idx: Index object (dict or IndexConfig).
        attr: Attribute name to retrieve.
        default: Default value if attribute not found.

    Returns:
        The attribute value or default.
    """
    if isinstance(idx, dict):
        return idx.get(attr, default)
    return getattr(idx, attr, default)
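
# --- Illustrative usage sketch (editor's addition, not part of the package) ---
# `_get_index_attr` smooths over the LanceDB API change noted in its docstring:
# both shapes below resolve the same way. The example objects are stand-ins.
def _example_index_attr() -> None:
    from types import SimpleNamespace

    assert _get_index_attr({"name": "fts_idx"}, "name") == "fts_idx"
    assert _get_index_attr(SimpleNamespace(name="fts_idx"), "name") == "fts_idx"
    assert _get_index_attr({}, "index_type", "unknown") == "unknown"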


_validate_namespace = _validate_namespace_impl
_validate_tags = _validate_tags_impl
_validate_metadata = _validate_metadata_impl


class Database:
    """LanceDB wrapper for memory storage and retrieval.

    Enterprise-grade features:
    - Connection pooling via singleton pattern with LRU eviction
    - Automatic index creation based on dataset size
    - Hybrid search with RRF reranking and alpha parameter
    - Batch operations for efficiency
    - Row count caching for search performance (thread-safe)
    - Maintenance and optimization utilities

    Thread Safety:
        The module-level connection pool is thread-safe. However, individual
        Database instances should NOT be shared across threads without external
        synchronization. Each thread should create its own Database instance,
        which will share the underlying pooled connection safely.

    Supports context manager protocol for safe resource management.

    Example:
        with Database(path) as db:
            db.insert(content="Hello", vector=vec)
    """

    # Cache refresh interval for row count (seconds)
    _COUNT_CACHE_TTL = 60.0
    # Cache refresh interval for namespaces (seconds) - longer because namespaces change less often
    _NAMESPACE_CACHE_TTL = 300.0

    def __init__(
        self,
        storage_path: Path,
        embedding_dim: int = 384,
        auto_create_indexes: bool = True,
        vector_index_threshold: int = 10_000,
        enable_fts: bool = True,
        index_nprobes: int = 20,
        index_refine_factor: int = 5,
        max_retry_attempts: int = DEFAULT_RETRY_MAX_ATTEMPTS,
        retry_backoff_seconds: float = DEFAULT_RETRY_BACKOFF_SECONDS,
        read_consistency_interval_ms: int = 0,
        index_wait_timeout_seconds: float = 30.0,
        fts_stem: bool = True,
        fts_remove_stop_words: bool = True,
        fts_language: str = "English",
        index_type: str = "IVF_PQ",
        hnsw_m: int = 20,
        hnsw_ef_construction: int = 300,
        enable_memory_expiration: bool = False,
        default_memory_ttl_days: int | None = None,
        filelock_enabled: bool = True,
        filelock_timeout: float = 30.0,
        filelock_poll_interval: float = 0.1,
        acknowledge_network_filesystem_risk: bool = False,
    ) -> None:
        """Initialize the database connection.

        Args:
            storage_path: Path to LanceDB storage directory.
            embedding_dim: Dimension of embedding vectors.
            auto_create_indexes: Automatically create indexes when thresholds met.
            vector_index_threshold: Row count to trigger vector index creation.
            enable_fts: Enable full-text search index.
            index_nprobes: Number of partitions to search (higher = better recall).
            index_refine_factor: Re-rank top (refine_factor * limit) for accuracy.
            max_retry_attempts: Maximum retry attempts for transient errors.
            retry_backoff_seconds: Initial backoff time for retries.
            read_consistency_interval_ms: Read consistency interval (0 = strong).
            index_wait_timeout_seconds: Timeout for waiting on index creation.
            fts_stem: Enable stemming in FTS (running -> run).
            fts_remove_stop_words: Remove stop words in FTS (the, is, etc.).
            filelock_enabled: Enable cross-process file locking.
            filelock_timeout: Timeout in seconds for acquiring filelock.
            filelock_poll_interval: Interval between lock acquisition attempts.
            fts_language: Language for FTS stemming.
            index_type: Vector index type (IVF_PQ, IVF_FLAT, or HNSW_SQ).
            hnsw_m: HNSW connections per node (4-64).
            hnsw_ef_construction: HNSW build-time search width (100-1000).
            enable_memory_expiration: Enable automatic memory expiration.
            default_memory_ttl_days: Default TTL for memories in days (None = no expiration).
            acknowledge_network_filesystem_risk: Suppress network filesystem warnings.
        """
        self.storage_path = Path(storage_path)
        self.embedding_dim = embedding_dim
        self.auto_create_indexes = auto_create_indexes
        self.vector_index_threshold = vector_index_threshold
        self.enable_fts = enable_fts
        self.index_nprobes = index_nprobes
        self.index_refine_factor = index_refine_factor
        self.max_retry_attempts = max_retry_attempts
        self.retry_backoff_seconds = retry_backoff_seconds
        self.read_consistency_interval_ms = read_consistency_interval_ms
        self.index_wait_timeout_seconds = index_wait_timeout_seconds
        self.fts_stem = fts_stem
        self.fts_remove_stop_words = fts_remove_stop_words
        self.fts_language = fts_language
        self.index_type = index_type
        self.hnsw_m = hnsw_m
        self.hnsw_ef_construction = hnsw_ef_construction
        self.enable_memory_expiration = enable_memory_expiration
        self.default_memory_ttl_days = default_memory_ttl_days
        self.filelock_enabled = filelock_enabled
        self.filelock_timeout = filelock_timeout
        self.filelock_poll_interval = filelock_poll_interval
        self.acknowledge_network_filesystem_risk = acknowledge_network_filesystem_risk
        self._db: lancedb.DBConnection | None = None
        self._table: LanceTable | None = None
        self._has_vector_index: bool | None = None
        self._has_fts_index: bool | None = None
        # Row count cache for performance (avoid count_rows() on every search)
        self._cached_row_count: int | None = None
        self._count_cache_time: float = 0.0
        # Thread-safe lock for row count cache
        self._cache_lock = threading.Lock()
        # Namespace cache for performance
        self._cached_namespaces: set[str] | None = None
        self._namespace_cache_time: float = 0.0
        self._namespace_cache_lock = threading.Lock()
        # Write lock for serializing mutations (prevents LanceDB version conflicts)
        self._write_lock = threading.RLock()
        # Cross-process lock (initialized in connect())
        self._process_lock: ProcessLockManager | None = None
        # Auto-compaction tracking
        self._modification_count: int = 0
        self._auto_compaction_threshold: int = 100  # Compact after this many modifications
        self._auto_compaction_enabled: bool = True
        # Version manager (initialized in connect())
        self._version_manager: VersionManager | None = None
        # Index manager (initialized in connect())
        self._index_manager: IndexManager | None = None
        # Search manager (initialized in connect())
        self._search_manager: SearchManager | None = None
        # Idempotency manager (initialized in connect())
        self._idempotency_manager: IdempotencyManager | None = None

    def __enter__(self) -> Database:
        """Enter context manager."""
        self.connect()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Exit context manager."""
        self.close()

    def connect(self) -> None:
        """Connect to the database using pooled connections."""
        try:
            self.storage_path.mkdir(parents=True, exist_ok=True)

            # Check for network filesystem and warn if detected
            if not self.acknowledge_network_filesystem_risk:
                if is_network_filesystem(self.storage_path):
                    fs_type = detect_filesystem_type(self.storage_path)
                    warning_msg = get_filesystem_warning_message(fs_type, self.storage_path)
                    logger.warning(warning_msg)

            # Initialize cross-process lock manager
            if self.filelock_enabled:
                lock_path = self.storage_path / ".spatial-memory.lock"
                self._process_lock = ProcessLockManager(
                    lock_path=lock_path,
                    timeout=self.filelock_timeout,
                    poll_interval=self.filelock_poll_interval,
                    enabled=self.filelock_enabled,
                )
            else:
                self._process_lock = None

            # Use connection pooling with read consistency support
            self._db = _get_or_create_connection(
                self.storage_path,
                read_consistency_interval_ms=self.read_consistency_interval_ms,
            )
            self._ensure_table()
            # Initialize remaining managers (IndexManager already initialized in _ensure_table)
            self._version_manager = VersionManager(self)
            self._search_manager = SearchManager(self)
            self._idempotency_manager = IdempotencyManager(self)
            logger.info(f"Connected to LanceDB at {self.storage_path}")

            # Check for pending schema migrations
            self._check_pending_migrations()
        except Exception as e:
            raise StorageError(f"Failed to connect to database: {e}") from e

    def _check_pending_migrations(self) -> None:
        """Check for pending migrations and warn if any exist.

        This method checks the schema version and logs a warning if there
        are pending migrations. It does not auto-apply migrations - that
        requires explicit user action via the CLI.
        """
        try:
            manager = MigrationManager(self, embeddings=None)
            manager.register_builtin_migrations()

            current_version = manager.get_current_version()
            pending = manager.get_pending_migrations()

            if pending:
                pending_versions = [m.version for m in pending]
                logger.warning(
                    f"Database schema version {current_version} is outdated. "
                    f"{len(pending)} migration(s) pending: {', '.join(pending_versions)}. "
                    f"Target version: {CURRENT_SCHEMA_VERSION}. "
                    f"Run 'spatial-memory migrate' to apply migrations."
                )
        except Exception as e:
            # Don't fail connection due to migration check errors
            logger.debug(f"Migration check skipped: {e}")

    def _ensure_table(self) -> None:
        """Ensure the memories table exists with appropriate indexes.

        Uses retry logic to handle race conditions when multiple processes
        attempt to create/open the table simultaneously.
        """
        if self._db is None:
            raise StorageError("Database not connected")

        max_retries = 3
        retry_delay = 0.1  # Start with 100ms

        for attempt in range(max_retries):
            try:
                existing_tables_result = self._db.list_tables()
                # Handle both old (list) and new (object with .tables) LanceDB API
                if hasattr(existing_tables_result, 'tables'):
                    existing_tables = existing_tables_result.tables
                else:
                    existing_tables = existing_tables_result

                if "memories" not in existing_tables:
                    # Create table with schema
                    schema = pa.schema([
                        pa.field("id", pa.string()),
                        pa.field("content", pa.string()),
                        pa.field("vector", pa.list_(pa.float32(), self.embedding_dim)),
                        pa.field("created_at", pa.timestamp("us")),
                        pa.field("updated_at", pa.timestamp("us")),
                        pa.field("last_accessed", pa.timestamp("us")),
                        pa.field("access_count", pa.int32()),
                        pa.field("importance", pa.float32()),
                        pa.field("namespace", pa.string()),
                        pa.field("tags", pa.list_(pa.string())),
                        pa.field("source", pa.string()),
                        pa.field("metadata", pa.string()),
                        pa.field("expires_at", pa.timestamp("us")),  # TTL support - nullable
                    ])
                    try:
                        self._table = self._db.create_table("memories", schema=schema)
                        logger.info("Created memories table")
                    except Exception as create_err:
                        # Table might have been created by another process
                        if "already exists" in str(create_err).lower():
                            logger.debug("Table created by another process, opening it")
                            self._table = self._db.open_table("memories")
                        else:
                            raise

                    # Initialize IndexManager immediately after table is set
                    self._index_manager = IndexManager(self)

                    # Create FTS index on new table if enabled
                    if self.enable_fts:
                        self._index_manager.create_fts_index()
                else:
                    self._table = self._db.open_table("memories")
                    logger.debug("Opened existing memories table")

                    # Initialize IndexManager immediately after table is set
                    self._index_manager = IndexManager(self)

                    # Check existing indexes
                    self._index_manager.check_existing_indexes()

                # Success - exit retry loop
                return

            except Exception as e:
                error_msg = str(e).lower()
                # Retry on transient race conditions
                if attempt < max_retries - 1 and (
                    "not found" in error_msg
                    or "does not exist" in error_msg
                    or "already exists" in error_msg
                ):
                    logger.debug(
                        f"Table operation failed (attempt {attempt + 1}/{max_retries}), "
                        f"retrying in {retry_delay}s: {e}"
                    )
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    raise

    def _check_existing_indexes(self) -> None:
        """Check which indexes already exist. Delegates to IndexManager."""
        if self._index_manager is None:
            raise StorageError("Database not connected")
        self._index_manager.check_existing_indexes()
        # Sync local state for backward compatibility
        self._has_vector_index = self._index_manager.has_vector_index
        self._has_fts_index = self._index_manager.has_fts_index

    def _create_fts_index(self) -> None:
        """Create FTS index. Delegates to IndexManager."""
        if self._index_manager is None:
            raise StorageError("Database not connected")
        self._index_manager.create_fts_index()
        # Sync local state for backward compatibility
        self._has_fts_index = self._index_manager.has_fts_index

    @property
    def table(self) -> LanceTable:
        """Get the memories table, connecting if needed."""
        if self._table is None:
            self.connect()
        assert self._table is not None  # connect() sets this or raises
        return self._table

    def close(self) -> None:
        """Close the database connection and remove from pool.

        This invalidates the pooled connection so that subsequent
        Database instances will create fresh connections.
        """
        # Invalidate pooled connection first
        invalidate_connection(self.storage_path)

        # Clear local state
        self._table = None
        self._db = None
        self._has_vector_index = None
        self._has_fts_index = None
        self._version_manager = None
        self._index_manager = None
        self._search_manager = None
        self._idempotency_manager = None
        with self._cache_lock:
            self._cached_row_count = None
            self._count_cache_time = 0.0
        with self._namespace_cache_lock:
            self._cached_namespaces = None
            self._namespace_cache_time = 0.0
        logger.debug("Database connection closed and removed from pool")

    def reconnect(self) -> None:
        """Invalidate cached connection and reconnect.

        Use when the database connection becomes stale (e.g., database was
        deleted and recreated externally, or metadata references missing files).

        This method:
        1. Closes the current connection state
        2. Invalidates the pooled connection for this path
        3. Creates a fresh connection to the database
        """
        logger.info(f"Reconnecting to database at {self.storage_path}")
        self.close()
        invalidate_connection(self.storage_path)
        self.connect()

    def _is_stale_connection_error(self, error: Exception) -> bool:
        """Check if an error indicates a stale/corrupted connection.

        Args:
            error: The exception to check.

        Returns:
            True if the error indicates a stale connection.
        """
        return ConnectionPool.is_stale_connection_error(error)

    def _get_cached_row_count(self) -> int:
        """Get row count with caching for performance (thread-safe).

        Avoids calling count_rows() on every search operation.
        Cache is invalidated on insert/delete or after TTL expires.

        Returns:
            Cached or fresh row count.
        """
        now = time.time()
        with self._cache_lock:
            if (
                self._cached_row_count is None
                or (now - self._count_cache_time) > self._COUNT_CACHE_TTL
            ):
                self._cached_row_count = self.table.count_rows()
                self._count_cache_time = now
            return self._cached_row_count

    def _invalidate_count_cache(self) -> None:
        """Invalidate the row count cache after modifications (thread-safe)."""
        with self._cache_lock:
            self._cached_row_count = None
            self._count_cache_time = 0.0

    def _invalidate_namespace_cache(self) -> None:
        """Invalidate the namespace cache after modifications (thread-safe)."""
        with self._namespace_cache_lock:
            self._cached_namespaces = None
            self._namespace_cache_time = 0.0

    def _track_modification(self, count: int = 1) -> None:
        """Track database modifications and trigger auto-compaction if threshold reached.

        Args:
            count: Number of modifications to track (default 1).
        """
        if not self._auto_compaction_enabled:
            return

        self._modification_count += count
        if self._modification_count >= self._auto_compaction_threshold:
            # Reset counter before compacting to avoid re-triggering
            self._modification_count = 0
            try:
                stats = self._get_table_stats()
                # Only compact if there are enough fragments to justify it
                if stats.get("num_small_fragments", 0) >= 5:
                    logger.info(
                        f"Auto-compaction triggered after {self._auto_compaction_threshold} "
                        f"modifications ({stats.get('num_small_fragments', 0)} small fragments)"
                    )
                    self.table.compact_files()
                    logger.debug("Auto-compaction completed")
            except Exception as e:
                # Don't fail operations due to compaction issues
                logger.debug(f"Auto-compaction skipped: {e}")

    def set_auto_compaction(
        self,
        enabled: bool = True,
        threshold: int | None = None,
    ) -> None:
        """Configure auto-compaction behavior.

        Args:
            enabled: Whether auto-compaction is enabled.
            threshold: Number of modifications before auto-compact (default: 100).
        """
        self._auto_compaction_enabled = enabled
        if threshold is not None:
            if threshold < 10:
                raise ValueError("Auto-compaction threshold must be at least 10")
            self._auto_compaction_threshold = threshold
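
    # --- Illustrative usage sketch (editor's addition, not part of the package) ---
    # Tuning the auto-compaction knobs configured above on a connected Database
    # instance `db` (hypothetical):
    #
    #     db.set_auto_compaction(enabled=True, threshold=500)  # compact less often
    #     db.set_auto_compaction(enabled=False)                # disable entirely
    #     db.set_auto_compaction(threshold=5)                  # ValueError: must be >= 10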

    # ========================================================================
    # Index Management (delegates to IndexManager)
    # ========================================================================

    def create_vector_index(self, force: bool = False) -> bool:
        """Create vector index for similarity search. Delegates to IndexManager.

        Args:
            force: Force index creation regardless of dataset size.

        Returns:
            True if index was created, False if skipped.

        Raises:
            StorageError: If index creation fails.
        """
        if self._index_manager is None:
            raise StorageError("Database not connected")
        result = self._index_manager.create_vector_index(force=force)
        # Sync local state only when index was created or modified
        if result:
            self._has_vector_index = self._index_manager.has_vector_index
        return result

    def create_scalar_indexes(self) -> None:
        """Create scalar indexes for frequently filtered columns. Delegates to IndexManager.

        Raises:
            StorageError: If index creation fails critically.
        """
        if self._index_manager is None:
            raise StorageError("Database not connected")
        self._index_manager.create_scalar_indexes()

    def ensure_indexes(self, force: bool = False) -> dict[str, bool]:
        """Ensure all appropriate indexes exist. Delegates to IndexManager.

        Args:
            force: Force index creation regardless of thresholds.

        Returns:
            Dict indicating which indexes were created.
        """
        if self._index_manager is None:
            raise StorageError("Database not connected")
        results = self._index_manager.ensure_indexes(force=force)
        # Sync local state for backward compatibility
        self._has_vector_index = self._index_manager.has_vector_index
        self._has_fts_index = self._index_manager.has_fts_index
        return results

    # ========================================================================
    # Maintenance & Optimization
    # ========================================================================

    def optimize(self) -> dict[str, Any]:
        """Run optimization and maintenance tasks.

        Performs:
        - File compaction (merges small fragments)
        - Index optimization

        Returns:
            Statistics about optimization performed.
        """
        try:
            stats_before = self._get_table_stats()

            # Compact small fragments
            needs_compaction = stats_before.get("num_small_fragments", 0) > 10
            if needs_compaction:
                logger.info("Compacting fragments...")
                self.table.compact_files()

            # Optimize indexes
            logger.info("Optimizing indexes...")
            self.table.optimize()

            stats_after = self._get_table_stats()

            return {
                "fragments_before": stats_before.get("num_fragments", 0),
                "fragments_after": stats_after.get("num_fragments", 0),
                "compaction_performed": needs_compaction,
                "total_rows": stats_after.get("num_rows", 0),
            }

        except Exception as e:
            logger.error(f"Optimization failed: {e}")
            return {"error": str(e)}

    def _get_table_stats(self) -> dict[str, Any]:
        """Get table statistics with best-effort fragment info."""
        try:
            count = self.table.count_rows()
            stats: dict[str, Any] = {
                "num_rows": count,
                "num_fragments": 0,
                "num_small_fragments": 0,
            }

            # Try to get fragment stats from table.stats() if available
            try:
                if hasattr(self.table, "stats"):
                    table_stats = self.table.stats()
                    if isinstance(table_stats, dict):
                        stats["num_fragments"] = table_stats.get("num_fragments", 0)
                        stats["num_small_fragments"] = table_stats.get("num_small_fragments", 0)
                    elif hasattr(table_stats, "num_fragments"):
                        stats["num_fragments"] = table_stats.num_fragments
                        stats["num_small_fragments"] = getattr(
                            table_stats, "num_small_fragments", 0
                        )
            except Exception as e:
                logger.debug(f"Could not get fragment stats: {e}")

            return stats
        except Exception as e:
            logger.warning(f"Could not get table stats: {e}")
            return {}

    @with_stale_connection_recovery
    def get_health_metrics(self) -> HealthMetrics:
        """Get comprehensive health and performance metrics.

        Returns:
            HealthMetrics dataclass with all metrics.
        """
        try:
            count = self.table.count_rows()

            # Estimate size (rough approximation)
            # vector (dim * 4 bytes) + avg content size estimate
            estimated_bytes = count * (self.embedding_dim * 4 + 1000)

            # Check indexes
            indices: list[IndexStats] = []
            try:
                for idx in self.table.list_indices():
                    indices.append(IndexStats(
                        name=str(_get_index_attr(idx, "name", "unknown")),
                        index_type=str(_get_index_attr(idx, "index_type", "unknown")),
                        num_indexed_rows=count,  # Approximate
                        num_unindexed_rows=0,
                        needs_update=False,
                    ))
            except Exception as e:
                logger.warning(f"Could not get index stats: {e}")

            return HealthMetrics(
                total_rows=count,
                total_bytes=estimated_bytes,
                total_bytes_mb=estimated_bytes / (1024 * 1024),
                num_fragments=0,
                num_small_fragments=0,
                needs_compaction=False,
                has_vector_index=self._has_vector_index or False,
                has_fts_index=self._has_fts_index or False,
                indices=indices,
                version=0,
            )

        except Exception as e:
            return HealthMetrics(
                total_rows=0,
                total_bytes=0,
                total_bytes_mb=0,
                num_fragments=0,
                num_small_fragments=0,
                needs_compaction=False,
                has_vector_index=False,
                has_fts_index=False,
                indices=[],
                version=0,
                error=str(e),
            )

    @with_process_lock
    @with_write_lock
    @retry_on_storage_error(max_attempts=3, backoff=0.5)
    def insert(
        self,
        content: str,
        vector: np.ndarray,
        namespace: str = "default",
        tags: list[str] | None = None,
        importance: float = 0.5,
        source: str = "manual",
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Insert a new memory.

        Args:
            content: Text content of the memory.
            vector: Embedding vector.
            namespace: Namespace for organization.
            tags: List of tags.
            importance: Importance score (0-1).
            source: Source of the memory.
            metadata: Additional metadata.

        Returns:
            The generated memory ID.

        Raises:
            ValidationError: If input validation fails.
            StorageError: If database operation fails.
        """
        # Validate inputs
        namespace = _validate_namespace(namespace)
        tags = _validate_tags(tags)
        metadata = _validate_metadata(metadata)
        if not content or len(content) > 100000:
            raise ValidationError("Content must be between 1 and 100000 characters")
        if not 0.0 <= importance <= 1.0:
            raise ValidationError("Importance must be between 0.0 and 1.0")

        # Validate vector dimensions
        if len(vector) != self.embedding_dim:
            raise DimensionMismatchError(
                expected_dim=self.embedding_dim,
                actual_dim=len(vector),
            )

        memory_id = str(uuid.uuid4())
        now = utc_now()

        # Calculate expires_at if default TTL is configured
        expires_at = None
        if self.default_memory_ttl_days is not None:
            expires_at = now + timedelta(days=self.default_memory_ttl_days)

        record = {
            "id": memory_id,
            "content": content,
            "vector": vector.tolist(),
            "created_at": now,
            "updated_at": now,
            "last_accessed": now,
            "access_count": 0,
            "importance": importance,
            "namespace": namespace,
            "tags": tags,
            "source": source,
            "metadata": json.dumps(metadata),
            "expires_at": expires_at,
        }

        try:
            self.table.add([record])
            self._invalidate_count_cache()
            self._track_modification()
            self._invalidate_namespace_cache()
            logger.debug(f"Inserted memory {memory_id}")
            return memory_id
        except Exception as e:
            raise StorageError(f"Failed to insert memory: {e}") from e

    # Maximum batch size to prevent memory exhaustion
    MAX_BATCH_SIZE = 10_000

    @with_process_lock
    @with_write_lock
    @retry_on_storage_error(max_attempts=3, backoff=0.5)
    def insert_batch(
        self,
        records: list[dict[str, Any]],
        batch_size: int = 1000,
        atomic: bool = False,
    ) -> list[str]:
        """Insert multiple memories efficiently with batching.

        Args:
            records: List of memory records with content, vector, and optional fields.
            batch_size: Records per batch (default: 1000, max: 10000).
            atomic: If True, rollback all inserts on partial failure.
                When atomic=True and a batch fails:
                - Attempts to delete already-inserted records
                - If rollback succeeds, raises the original StorageError
                - If rollback fails, raises PartialBatchInsertError with succeeded_ids

        Returns:
            List of generated memory IDs.

        Raises:
            ValidationError: If input validation fails or batch_size exceeds maximum.
            StorageError: If database operation fails (and rollback succeeds when atomic=True).
            PartialBatchInsertError: If atomic=True and rollback fails after partial insert.
        """
        if batch_size > self.MAX_BATCH_SIZE:
            raise ValidationError(
                f"batch_size ({batch_size}) exceeds maximum {self.MAX_BATCH_SIZE}"
            )

        all_ids: list[str] = []
        total_requested = len(records)

        # Process in batches for large inserts
        for batch_index, i in enumerate(range(0, len(records), batch_size)):
            batch = records[i:i + batch_size]
            now = utc_now()
            memory_ids: list[str] = []
            prepared_records: list[dict[str, Any]] = []

            for record in batch:
                # Validate each record
                namespace = _validate_namespace(record.get("namespace", "default"))
                tags = _validate_tags(record.get("tags"))
                metadata = _validate_metadata(record.get("metadata"))
                content = record.get("content", "")
                if not content or len(content) > 100000:
                    raise ValidationError("Content must be between 1 and 100000 characters")

                importance = record.get("importance", 0.5)
                if not 0.0 <= importance <= 1.0:
                    raise ValidationError("Importance must be between 0.0 and 1.0")

                memory_id = str(uuid.uuid4())
                memory_ids.append(memory_id)

                raw_vector = record["vector"]
                if isinstance(raw_vector, np.ndarray):
                    vector_list = raw_vector.tolist()
                else:
                    vector_list = raw_vector

                # Validate vector dimensions
                if len(vector_list) != self.embedding_dim:
                    raise DimensionMismatchError(
                        expected_dim=self.embedding_dim,
                        actual_dim=len(vector_list),
                        record_index=i + len(memory_ids),
                    )

                # Calculate expires_at if default TTL is configured
                expires_at = None
                if self.default_memory_ttl_days is not None:
                    expires_at = now + timedelta(days=self.default_memory_ttl_days)

                prepared = {
                    "id": memory_id,
                    "content": content,
                    "vector": vector_list,
                    "created_at": now,
                    "updated_at": now,
                    "last_accessed": now,
                    "access_count": 0,
                    "importance": importance,
                    "namespace": namespace,
                    "tags": tags,
                    "source": record.get("source", "manual"),
                    "metadata": json.dumps(metadata),
                    "expires_at": expires_at,
                }
                prepared_records.append(prepared)

            try:
                self.table.add(prepared_records)
                all_ids.extend(memory_ids)
                self._invalidate_count_cache()
                self._track_modification(len(memory_ids))
                self._invalidate_namespace_cache()
                logger.debug(f"Inserted batch {batch_index + 1}: {len(memory_ids)} memories")
            except Exception as e:
                if atomic and all_ids:
                    # Attempt rollback of previously inserted records
                    logger.warning(
                        f"Batch {batch_index + 1} failed, attempting rollback of "
                        f"{len(all_ids)} previously inserted records"
                    )
                    rollback_error = self._rollback_batch_insert(all_ids)
                    if rollback_error:
                        # Rollback failed - raise PartialBatchInsertError
                        raise PartialBatchInsertError(
                            message=f"Batch insert failed and rollback also failed: {e}",
                            succeeded_ids=all_ids,
                            total_requested=total_requested,
                            failed_batch_index=batch_index,
                        ) from e
                    else:
                        # Rollback succeeded - raise original error
                        logger.info(f"Rollback successful, deleted {len(all_ids)} records")
                        raise StorageError(f"Failed to insert batch (rolled back): {e}") from e
                raise StorageError(f"Failed to insert batch: {e}") from e

        # Check if we should create indexes after large insert
        if self.auto_create_indexes and len(all_ids) >= 1000:
            count = self._get_cached_row_count()
            if count >= self.vector_index_threshold and not self._has_vector_index:
                logger.info("Dataset crossed index threshold, creating indexes...")
                try:
                    self.ensure_indexes()
                except Exception as e:
                    logger.warning(f"Auto-index creation failed: {e}")

        logger.debug(f"Inserted {len(all_ids)} memories total")
        return all_ids
|
|
1315
|
+
|
|
1316
|
+
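    # Illustrative usage sketch (not part of the packaged module): assuming a
    # connected database wrapper `db` with embedding_dim=384, an atomic batch
    # insert that rolls back previously inserted batches on failure might look like:
    #
    #   records = [
    #       {"content": "meeting notes", "vector": np.zeros(384, dtype=np.float32)},
    #       {"content": "todo list", "vector": np.ones(384, dtype=np.float32)},
    #   ]
    #   ids = db.insert_batch(records, batch_size=500, atomic=True)
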
    def _rollback_batch_insert(self, memory_ids: list[str]) -> Exception | None:
        """Attempt to delete records inserted during a failed batch operation.

        Args:
            memory_ids: List of memory IDs to delete.

        Returns:
            None if rollback succeeded, Exception if it failed.
        """
        try:
            if not memory_ids:
                return None

            # Use delete_batch for efficient rollback
            id_list = ", ".join(f"'{_sanitize_string(mid)}'" for mid in memory_ids)
            self.table.delete(f"id IN ({id_list})")
            self._invalidate_count_cache()
            self._track_modification(len(memory_ids))
            self._invalidate_namespace_cache()
            logger.debug(f"Rolled back {len(memory_ids)} records")
            return None
        except Exception as e:
            logger.error(f"Rollback failed: {e}")
            return e

    @with_stale_connection_recovery
    def get(self, memory_id: str) -> dict[str, Any]:
        """Get a memory by ID.

        Args:
            memory_id: The memory ID.

        Returns:
            The memory record.

        Raises:
            ValidationError: If memory_id is invalid.
            MemoryNotFoundError: If memory doesn't exist.
            StorageError: If database operation fails.
        """
        # Validate and sanitize memory_id
        memory_id = _validate_uuid(memory_id)
        safe_id = _sanitize_string(memory_id)

        try:
            results = self.table.search().where(f"id = '{safe_id}'").limit(1).to_list()
            if not results:
                raise MemoryNotFoundError(memory_id)

            record: dict[str, Any] = results[0]
            record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
            return record
        except MemoryNotFoundError:
            raise
        except ValidationError:
            raise
        except Exception as e:
            raise StorageError(f"Failed to get memory: {e}") from e

    def get_batch(self, memory_ids: list[str]) -> dict[str, dict[str, Any]]:
        """Get multiple memories by ID in a single query.

        Args:
            memory_ids: List of memory UUIDs to retrieve.

        Returns:
            Dict mapping memory_id to memory record. Missing IDs are not included.

        Raises:
            ValidationError: If any memory_id format is invalid.
            StorageError: If database operation fails.
        """
        if not memory_ids:
            return {}

        # Validate all IDs first
        validated_ids: list[str] = []
        for memory_id in memory_ids:
            try:
                validated_id = _validate_uuid(memory_id)
                validated_ids.append(_sanitize_string(validated_id))
            except Exception as e:
                logger.debug(f"Invalid memory ID {memory_id}: {e}")
                continue

        if not validated_ids:
            return {}

        try:
            # Batch fetch with single IN query
            id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
            results = self.table.search().where(f"id IN ({id_list})").to_list()

            # Build result map
            result_map: dict[str, dict[str, Any]] = {}
            for record in results:
                # Deserialize metadata
                record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
                result_map[record["id"]] = record

            return result_map
        except Exception as e:
            raise StorageError(f"Failed to batch get memories: {e}") from e

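    # Illustrative usage sketch (hypothetical IDs, not part of the packaged
    # module): single lookups raise MemoryNotFoundError, while get_batch
    # silently skips missing or malformed IDs:
    #
    #   record = db.get("4f9c6d0e-0000-4000-8000-000000000000")
    #   found = db.get_batch([record["id"], "not-a-uuid"])  # -> {record["id"]: {...}}
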
    @with_process_lock
    @with_write_lock
    def update(self, memory_id: str, updates: dict[str, Any]) -> None:
        """Update a memory using atomic merge_insert.

        Uses LanceDB's merge_insert API for atomic upserts, eliminating
        race conditions from delete-then-insert patterns.

        Args:
            memory_id: The memory ID.
            updates: Fields to update.

        Raises:
            ValidationError: If input validation fails.
            MemoryNotFoundError: If memory doesn't exist.
            StorageError: If database operation fails.
        """
        # Validate memory_id
        memory_id = _validate_uuid(memory_id)

        # First verify the memory exists
        existing = self.get(memory_id)

        # Prepare updates
        updates["updated_at"] = utc_now()
        if "metadata" in updates and isinstance(updates["metadata"], dict):
            updates["metadata"] = json.dumps(updates["metadata"])
        if "vector" in updates and isinstance(updates["vector"], np.ndarray):
            updates["vector"] = updates["vector"].tolist()

        # Merge existing with updates
        for key, value in updates.items():
            existing[key] = value

        # Ensure metadata is serialized as JSON string for storage
        if isinstance(existing.get("metadata"), dict):
            existing["metadata"] = json.dumps(existing["metadata"])

        # Ensure vector is a list, not numpy array
        if isinstance(existing.get("vector"), np.ndarray):
            existing["vector"] = existing["vector"].tolist()

        try:
            # Atomic upsert using merge_insert
            # Requires BTREE index on 'id' column (created in create_scalar_indexes)
            (
                self.table.merge_insert("id")
                .when_matched_update_all()
                .when_not_matched_insert_all()
                .execute([existing])
            )
            logger.debug(f"Updated memory {memory_id} (atomic merge_insert)")
        except Exception as e:
            raise StorageError(f"Failed to update memory: {e}") from e

    @with_process_lock
    @with_write_lock
    def update_batch(
        self, updates: list[tuple[str, dict[str, Any]]]
    ) -> tuple[int, list[str]]:
        """Update multiple memories using atomic merge_insert.

        Args:
            updates: List of (memory_id, updates_dict) tuples.

        Returns:
            Tuple of (success_count, list of failed memory_ids).

        Raises:
            StorageError: If database operation fails completely.
        """
        if not updates:
            return 0, []

        now = utc_now()
        records_to_update: list[dict[str, Any]] = []
        failed_ids: list[str] = []

        # Validate all IDs and collect them
        validated_updates: list[tuple[str, dict[str, Any]]] = []
        for memory_id, update_dict in updates:
            try:
                validated_id = _validate_uuid(memory_id)
                validated_updates.append((_sanitize_string(validated_id), update_dict))
            except Exception as e:
                logger.debug(f"Invalid memory ID {memory_id}: {e}")
                failed_ids.append(memory_id)

        if not validated_updates:
            return 0, failed_ids

        # Batch fetch all records
        validated_ids = [vid for vid, _ in validated_updates]
        try:
            id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
            all_records = self.table.search().where(f"id IN ({id_list})").to_list()
        except Exception as e:
            logger.error(f"Failed to batch fetch records for update: {e}")
            raise StorageError(f"Failed to batch fetch for update: {e}") from e

        # Build lookup map
        record_map: dict[str, dict[str, Any]] = {}
        for record in all_records:
            record_map[record["id"]] = record

        # Apply updates to found records
        update_dict_map = dict(validated_updates)
        for memory_id in validated_ids:
            if memory_id not in record_map:
                logger.debug(f"Memory {memory_id} not found for batch update")
                failed_ids.append(memory_id)
                continue

            record = record_map[memory_id]
            update_dict = update_dict_map[memory_id]

            # Apply updates
            record["updated_at"] = now
            for key, value in update_dict.items():
                if key == "metadata" and isinstance(value, dict):
                    record[key] = json.dumps(value)
                elif key == "vector" and isinstance(value, np.ndarray):
                    record[key] = value.tolist()
                else:
                    record[key] = value

            # Ensure metadata is serialized
            if isinstance(record.get("metadata"), dict):
                record["metadata"] = json.dumps(record["metadata"])

            # Ensure vector is a list
            if isinstance(record.get("vector"), np.ndarray):
                record["vector"] = record["vector"].tolist()

            records_to_update.append(record)

        if not records_to_update:
            return 0, failed_ids

        try:
            # Atomic batch upsert
            (
                self.table.merge_insert("id")
                .when_matched_update_all()
                .when_not_matched_insert_all()
                .execute(records_to_update)
            )
            success_count = len(records_to_update)
            logger.debug(
                f"Batch updated {success_count}/{len(updates)} memories "
                "(atomic merge_insert)"
            )
            return success_count, failed_ids
        except Exception as e:
            logger.error(f"Failed to batch update: {e}")
            raise StorageError(f"Failed to batch update: {e}") from e

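    # Illustrative usage sketch (hypothetical values, not part of the packaged
    # module): merge_insert makes both calls upserts, so a partial failure never
    # leaves a deleted-but-not-reinserted row behind:
    #
    #   db.update(memory_id, {"importance": 0.9, "metadata": {"reviewed": True}})
    #   ok, failed = db.update_batch([(memory_id, {"tags": ["urgent"]})])
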
    @with_process_lock
    @with_write_lock
    def delete(self, memory_id: str) -> None:
        """Delete a memory.

        Args:
            memory_id: The memory ID.

        Raises:
            ValidationError: If memory_id is invalid.
            MemoryNotFoundError: If memory doesn't exist.
            StorageError: If database operation fails.
        """
        # Validate memory_id
        memory_id = _validate_uuid(memory_id)
        safe_id = _sanitize_string(memory_id)

        # First verify the memory exists
        self.get(memory_id)

        try:
            self.table.delete(f"id = '{safe_id}'")
            self._invalidate_count_cache()
            self._track_modification()
            self._invalidate_namespace_cache()
            logger.debug(f"Deleted memory {memory_id}")
        except Exception as e:
            raise StorageError(f"Failed to delete memory: {e}") from e

    @with_process_lock
    @with_write_lock
    def delete_by_namespace(self, namespace: str) -> int:
        """Delete all memories in a namespace.

        Args:
            namespace: The namespace to delete.

        Returns:
            Number of deleted records.

        Raises:
            ValidationError: If namespace is invalid.
            StorageError: If database operation fails.
        """
        namespace = _validate_namespace(namespace)
        safe_ns = _sanitize_string(namespace)

        try:
            count_before: int = self.table.count_rows()
            self.table.delete(f"namespace = '{safe_ns}'")
            self._invalidate_count_cache()
            self._track_modification()
            self._invalidate_namespace_cache()
            count_after: int = self.table.count_rows()
            deleted = count_before - count_after
            logger.debug(f"Deleted {deleted} memories in namespace '{namespace}'")
            return deleted
        except Exception as e:
            raise StorageError(f"Failed to delete by namespace: {e}") from e

    @with_process_lock
    @with_write_lock
    def clear_all(self, reset_indexes: bool = True) -> int:
        """Clear all memories from the database.

        This is primarily for testing purposes to reset database state
        between tests while maintaining the connection.

        Args:
            reset_indexes: If True, also reset index tracking flags.
                This allows tests to verify index creation behavior.

        Returns:
            Number of deleted records.

        Raises:
            StorageError: If database operation fails.
        """
        try:
            count: int = self.table.count_rows()
            if count > 0:
                # Delete all rows - use simpler predicate that definitely matches
                self.table.delete("true")

                # Verify deletion worked
                remaining = self.table.count_rows()
                if remaining > 0:
                    logger.warning(
                        f"clear_all: {remaining} records remain after delete, "
                        f"attempting cleanup again"
                    )
                    # Try alternative delete approach
                    self.table.delete("id IS NOT NULL")

            self._invalidate_count_cache()
            self._track_modification()
            self._invalidate_namespace_cache()

            # Reset index tracking flags for test isolation
            if reset_indexes:
                self._has_vector_index = None
                self._has_fts_index = False
                self._has_scalar_indexes = False

            logger.debug(f"Cleared all {count} memories from database")
            return count
        except Exception as e:
            raise StorageError(f"Failed to clear all memories: {e}") from e

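    # Illustrative usage sketch (hypothetical namespace, not part of the
    # packaged module): single deletes validate existence first, while
    # namespace-level cleanup reports how many rows were removed:
    #
    #   db.delete(memory_id)
    #   removed = db.delete_by_namespace("scratch")
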
    @with_process_lock
    @with_write_lock
    def rename_namespace(self, old_namespace: str, new_namespace: str) -> int:
        """Rename all memories from one namespace to another.

        Uses atomic batch update via merge_insert for data integrity.
        On partial failure, attempts to rollback renamed records to original namespace.

        Args:
            old_namespace: Source namespace name.
            new_namespace: Target namespace name.

        Returns:
            Number of memories renamed.

        Raises:
            ValidationError: If namespace names are invalid.
            NamespaceNotFoundError: If old_namespace doesn't exist.
            StorageError: If database operation fails.
        """
        from spatial_memory.core.errors import NamespaceNotFoundError

        old_namespace = _validate_namespace(old_namespace)
        new_namespace = _validate_namespace(new_namespace)
        safe_old = _sanitize_string(old_namespace)
        _sanitize_string(new_namespace)  # Validate but don't store unused result

        try:
            # Check if source namespace exists
            existing = self.get_namespaces()
            if old_namespace not in existing:
                raise NamespaceNotFoundError(old_namespace)

            # Short-circuit if renaming to same namespace (no-op)
            if old_namespace == new_namespace:
                count = self.count(namespace=old_namespace)
                logger.debug(f"Namespace '{old_namespace}' renamed to itself ({count} records)")
                return count

            # Track renamed IDs for rollback capability
            renamed_ids: list[str] = []

            # Fetch all records in batches with iteration safeguards
            batch_size = 1000
            max_iterations = 10000  # Safety cap: 10M records at 1000/batch
            updated = 0
            iteration = 0
            previous_updated = 0

            while True:
                iteration += 1

                # Safety limit to prevent infinite loops
                if iteration > max_iterations:
                    raise StorageError(
                        f"rename_namespace exceeded maximum iterations ({max_iterations}). "
                        f"Updated {updated} records before stopping. "
                        "This may indicate a database consistency issue."
                    )

                records = (
                    self.table.search()
                    .where(f"namespace = '{safe_old}'")
                    .limit(batch_size)
                    .to_list()
                )

                if not records:
                    break

                # Track IDs in this batch for potential rollback
                batch_ids = [r["id"] for r in records]

                # Update namespace field
                for r in records:
                    r["namespace"] = new_namespace
                    r["updated_at"] = utc_now()
                    if isinstance(r.get("metadata"), dict):
                        r["metadata"] = json.dumps(r["metadata"])
                    if isinstance(r.get("vector"), np.ndarray):
                        r["vector"] = r["vector"].tolist()

                try:
                    # Atomic upsert
                    (
                        self.table.merge_insert("id")
                        .when_matched_update_all()
                        .when_not_matched_insert_all()
                        .execute(records)
                    )
                    # Only track as renamed after successful update
                    renamed_ids.extend(batch_ids)
                except Exception as batch_error:
                    # Batch failed - attempt rollback of previously renamed records
                    if renamed_ids:
                        logger.warning(
                            f"Batch {iteration} failed, attempting rollback of "
                            f"{len(renamed_ids)} previously renamed records"
                        )
                        rollback_error = self._rollback_namespace_rename(
                            renamed_ids, old_namespace
                        )
                        if rollback_error:
                            raise StorageError(
                                f"Namespace rename failed at batch {iteration} and "
                                f"rollback also failed. {len(renamed_ids)} records may be "
                                f"in inconsistent state (partially in '{new_namespace}'). "
                                f"Original error: {batch_error}. Rollback error: {rollback_error}"
                            ) from batch_error
                        else:
                            logger.info(
                                f"Rollback successful, reverted {len(renamed_ids)} records "
                                f"back to namespace '{old_namespace}'"
                            )
                    raise StorageError(
                        f"Failed to rename namespace (rolled back): {batch_error}"
                    ) from batch_error

                updated += len(records)

                # Detect stalled progress (same batch being processed repeatedly)
                if updated == previous_updated:
                    raise StorageError(
                        f"rename_namespace stalled at {updated} records. "
                        "merge_insert may have failed silently."
                    )
                previous_updated = updated

            self._invalidate_namespace_cache()
            logger.debug(
                f"Renamed {updated} memories from '{old_namespace}' to '{new_namespace}'"
            )
            return updated

        except (ValidationError, NamespaceNotFoundError):
            raise
        except Exception as e:
            raise StorageError(f"Failed to rename namespace: {e}") from e

    def _rollback_namespace_rename(
        self, memory_ids: list[str], target_namespace: str
    ) -> Exception | None:
        """Attempt to revert renamed records back to original namespace.

        Args:
            memory_ids: List of memory IDs to revert.
            target_namespace: Namespace to revert records to.

        Returns:
            None if rollback succeeded, Exception if it failed.
        """
        try:
            if not memory_ids:
                return None

            _sanitize_string(target_namespace)  # Validate namespace
            now = utc_now()

            # Process in batches for large rollbacks
            batch_size = 1000
            for i in range(0, len(memory_ids), batch_size):
                batch_ids = memory_ids[i:i + batch_size]
                id_list = ", ".join(f"'{_sanitize_string(mid)}'" for mid in batch_ids)

                # Fetch records that need rollback
                records = (
                    self.table.search()
                    .where(f"id IN ({id_list})")
                    .to_list()
                )

                if not records:
                    continue

                # Revert namespace
                for r in records:
                    r["namespace"] = target_namespace
                    r["updated_at"] = now
                    if isinstance(r.get("metadata"), dict):
                        r["metadata"] = json.dumps(r["metadata"])
                    if isinstance(r.get("vector"), np.ndarray):
                        r["vector"] = r["vector"].tolist()

                # Atomic upsert to restore original namespace
                (
                    self.table.merge_insert("id")
                    .when_matched_update_all()
                    .when_not_matched_insert_all()
                    .execute(records)
                )

            self._invalidate_namespace_cache()
            logger.debug(f"Rolled back {len(memory_ids)} records to namespace '{target_namespace}'")
            return None

        except Exception as e:
            logger.error(f"Namespace rename rollback failed: {e}")
            return e

    @with_stale_connection_recovery
    def get_stats(self, namespace: str | None = None) -> dict[str, Any]:
        """Get comprehensive database statistics.

        Uses efficient LanceDB queries for aggregations.

        Args:
            namespace: Filter stats to specific namespace (None = all).

        Returns:
            Dictionary with statistics including:
            - total_memories: Total count of memories
            - namespaces: Dict mapping namespace to count
            - storage_bytes: Total storage size in bytes
            - storage_mb: Total storage size in megabytes
            - has_vector_index: Whether vector index exists
            - has_fts_index: Whether full-text search index exists
            - num_fragments: Number of storage fragments
            - needs_compaction: Whether compaction is recommended
            - table_version: Current table version number
            - indices: List of index information dicts

        Raises:
            ValidationError: If namespace is invalid.
            StorageError: If database operation fails.
        """
        try:
            metrics = self.get_health_metrics()

            # Get memory counts by namespace using efficient Arrow aggregation
            # Use pure Arrow operations (no pandas dependency)
            ns_arrow = self.table.search().select(["namespace"]).to_arrow()

            # Count by namespace using Arrow's to_pylist()
            ns_counts: dict[str, int] = {}
            for record in ns_arrow.to_pylist():
                ns = record["namespace"]
                ns_counts[ns] = ns_counts.get(ns, 0) + 1

            # Filter if namespace specified
            if namespace:
                namespace = _validate_namespace(namespace)
                if namespace in ns_counts:
                    ns_counts = {namespace: ns_counts[namespace]}
                else:
                    ns_counts = {}

            total = sum(ns_counts.values()) if ns_counts else 0

            return {
                "total_memories": total if namespace else metrics.total_rows,
                "namespaces": ns_counts,
                "storage_bytes": metrics.total_bytes,
                "storage_mb": metrics.total_bytes_mb,
                "num_fragments": metrics.num_fragments,
                "needs_compaction": metrics.needs_compaction,
                "has_vector_index": metrics.has_vector_index,
                "has_fts_index": metrics.has_fts_index,
                "table_version": metrics.version,
                "indices": [
                    {
                        "name": idx.name,
                        "index_type": idx.index_type,
                        "num_indexed_rows": idx.num_indexed_rows,
                        "status": "ready" if not idx.needs_update else "needs_update",
                    }
                    for idx in metrics.indices
                ],
            }
        except ValidationError:
            raise
        except Exception as e:
            raise StorageError(f"Failed to get stats: {e}") from e

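    # Illustrative usage sketch (not part of the packaged module): the stats
    # payload can drive maintenance decisions such as compaction:
    #
    #   stats = db.get_stats()
    #   if stats["needs_compaction"]:
    #       logger.info(f"{stats['num_fragments']} fragments - consider optimizing")
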
    def get_namespace_stats(self, namespace: str) -> dict[str, Any]:
        """Get statistics for a specific namespace.

        Args:
            namespace: The namespace to get statistics for.

        Returns:
            Dictionary containing:
            - namespace: The namespace name
            - memory_count: Number of memories in namespace
            - oldest_memory: Datetime of oldest memory (or None)
            - newest_memory: Datetime of newest memory (or None)
            - avg_content_length: Average content length (or None if empty)

        Raises:
            ValidationError: If namespace is invalid.
            StorageError: If database operation fails.
        """
        namespace = _validate_namespace(namespace)
        safe_ns = _sanitize_string(namespace)

        try:
            # Get count efficiently
            filter_expr = f"namespace = '{safe_ns}'"
            count_results = (
                self.table.search()
                .where(filter_expr)
                .select(["id"])
                .limit(1000000)  # High limit to count all
                .to_list()
            )
            memory_count = len(count_results)

            if memory_count == 0:
                return {
                    "namespace": namespace,
                    "memory_count": 0,
                    "oldest_memory": None,
                    "newest_memory": None,
                    "avg_content_length": None,
                }

            # Get oldest memory (sort ascending, limit 1)
            oldest_records = (
                self.table.search()
                .where(filter_expr)
                .select(["created_at"])
                .limit(1)
                .to_list()
            )
            oldest = oldest_records[0]["created_at"] if oldest_records else None

            # Get newest memory - need to fetch more and find max since LanceDB
            # doesn't support ORDER BY DESC efficiently
            # Sample up to 1000 records for stats to avoid loading everything
            sample_size = min(memory_count, 1000)
            sample_records = (
                self.table.search()
                .where(filter_expr)
                .select(["created_at", "content"])
                .limit(sample_size)
                .to_list()
            )

            # Find newest from sample (for large namespaces this is approximate)
            if sample_records:
                created_times = [r["created_at"] for r in sample_records]
                newest = max(created_times)
                # Calculate average content length from sample
                content_lengths = [len(r.get("content", "")) for r in sample_records]
                avg_content_length = sum(content_lengths) / len(content_lengths)
            else:
                newest = oldest
                avg_content_length = None

            return {
                "namespace": namespace,
                "memory_count": memory_count,
                "oldest_memory": oldest,
                "newest_memory": newest,
                "avg_content_length": avg_content_length,
            }

        except ValidationError:
            raise
        except Exception as e:
            raise StorageError(f"Failed to get namespace stats: {e}") from e

    def get_all_for_export(
        self,
        namespace: str | None = None,
        batch_size: int = 1000,
    ) -> Generator[list[dict[str, Any]], None, None]:
        """Stream all memories for export in batches.

        Memory-efficient export using generator pattern.

        Args:
            namespace: Optional namespace filter.
            batch_size: Records per batch.

        Yields:
            Batches of memory dictionaries.

        Raises:
            ValidationError: If namespace is invalid.
            StorageError: If database operation fails.
        """
        try:
            search = self.table.search()

            if namespace is not None:
                namespace = _validate_namespace(namespace)
                safe_ns = _sanitize_string(namespace)
                search = search.where(f"namespace = '{safe_ns}'")

            # Use Arrow for efficient streaming
            arrow_table = search.to_arrow()
            records = arrow_table.to_pylist()

            # Yield in batches
            for i in range(0, len(records), batch_size):
                batch = records[i : i + batch_size]

                # Process metadata
                for record in batch:
                    if isinstance(record.get("metadata"), str):
                        try:
                            record["metadata"] = json.loads(record["metadata"])
                        except json.JSONDecodeError:
                            record["metadata"] = {}

                yield batch

        except ValidationError:
            raise
        except Exception as e:
            raise StorageError(f"Failed to stream export: {e}") from e

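    # Illustrative usage sketch (hypothetical output path, not part of the
    # packaged module): the generator keeps memory bounded while writing
    # newline-delimited JSON:
    #
    #   with open("export.jsonl", "w") as fh:
    #       for batch in db.get_all_for_export(namespace="default", batch_size=500):
    #           for memory in batch:
    #               fh.write(json.dumps(memory, default=str) + "\n")
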
    @with_process_lock
    @with_write_lock
    def bulk_import(
        self,
        records: Iterator[dict[str, Any]],
        batch_size: int = 1000,
        namespace_override: str | None = None,
    ) -> tuple[int, list[str]]:
        """Import memories from an iterator of records.

        Supports streaming import for large datasets.

        Args:
            records: Iterator of memory dictionaries.
            batch_size: Records per database insert batch.
            namespace_override: Override namespace for all records.

        Returns:
            Tuple of (records_imported, list_of_new_ids).

        Raises:
            ValidationError: If records contain invalid data.
            StorageError: If database operation fails.
        """
        if namespace_override is not None:
            namespace_override = _validate_namespace(namespace_override)

        imported = 0
        all_ids: list[str] = []
        batch: list[dict[str, Any]] = []

        try:
            for record in records:
                prepared = self._prepare_import_record(record, namespace_override)
                batch.append(prepared)

                if len(batch) >= batch_size:
                    ids = self.insert_batch(batch, batch_size=batch_size)
                    all_ids.extend(ids)
                    imported += len(ids)
                    batch = []

            # Import remaining
            if batch:
                ids = self.insert_batch(batch, batch_size=batch_size)
                all_ids.extend(ids)
                imported += len(ids)

            return imported, all_ids

        except (ValidationError, StorageError):
            raise
        except Exception as e:
            raise StorageError(f"Bulk import failed: {e}") from e

    def _prepare_import_record(
        self,
        record: dict[str, Any],
        namespace_override: str | None = None,
    ) -> dict[str, Any]:
        """Prepare a record for import.

        Args:
            record: The raw record from import file.
            namespace_override: Optional namespace override.

        Returns:
            Prepared record suitable for insert_batch.
        """
        # Required fields
        content = record.get("content", "")
        vector = record.get("vector", [])

        # Convert vector to numpy if needed
        if isinstance(vector, list):
            vector = np.array(vector, dtype=np.float32)

        # Get namespace (override if specified)
        namespace = namespace_override or record.get("namespace", "default")

        # Optional fields with defaults
        tags = record.get("tags", [])
        importance = record.get("importance", 0.5)
        source = record.get("source", "import")
        metadata = record.get("metadata", {})

        return {
            "content": content,
            "vector": vector,
            "namespace": namespace,
            "tags": tags,
            "importance": importance,
            "source": source,
            "metadata": metadata,
        }

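    # Illustrative usage sketch (hypothetical file name, not part of the
    # packaged module): because bulk_import accepts any iterator, a JSONL
    # file can be streamed without loading it fully into memory:
    #
    #   def read_jsonl(path):
    #       with open(path) as fh:
    #           for line in fh:
    #               yield json.loads(line)
    #
    #   imported, new_ids = db.bulk_import(read_jsonl("export.jsonl"),
    #                                      namespace_override="restored")
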
    @with_process_lock
    @with_write_lock
    def delete_batch(self, memory_ids: list[str]) -> tuple[int, list[str]]:
        """Delete multiple memories atomically using IN clause.

        Args:
            memory_ids: List of memory UUIDs to delete.

        Returns:
            Tuple of (count_deleted, list_of_deleted_ids) where:
            - count_deleted: Number of memories actually deleted
            - list_of_deleted_ids: IDs that were actually deleted

        Raises:
            ValidationError: If any memory_id is invalid.
            StorageError: If database operation fails.
        """
        if not memory_ids:
            return (0, [])

        # Validate all IDs first (fail fast)
        validated_ids: list[str] = []
        for memory_id in memory_ids:
            validated_id = _validate_uuid(memory_id)
            validated_ids.append(_sanitize_string(validated_id))

        try:
            # First, check which IDs actually exist
            id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
            filter_expr = f"id IN ({id_list})"
            existing_records = (
                self.table.search()
                .where(filter_expr)
                .select(["id"])
                .limit(len(validated_ids))
                .to_list()
            )
            existing_ids = [r["id"] for r in existing_records]

            if not existing_ids:
                return (0, [])

            # Delete only existing IDs
            existing_id_list = ", ".join(f"'{mid}'" for mid in existing_ids)
            delete_expr = f"id IN ({existing_id_list})"
            self.table.delete(delete_expr)

            self._invalidate_count_cache()
            self._track_modification()
            self._invalidate_namespace_cache()

            logger.debug(f"Batch deleted {len(existing_ids)} memories")
            return (len(existing_ids), existing_ids)
        except ValidationError:
            raise
        except Exception as e:
            raise StorageError(f"Failed to delete batch: {e}") from e

    @with_process_lock
    @with_write_lock
    def update_access_batch(self, memory_ids: list[str]) -> int:
        """Update access timestamp and count for multiple memories using atomic merge_insert.

        Uses LanceDB's merge_insert API for atomic batch upserts, eliminating
        race conditions from delete-then-insert patterns.

        Args:
            memory_ids: List of memory UUIDs to update.

        Returns:
            Number of memories successfully updated.
        """
        if not memory_ids:
            return 0

        now = utc_now()
        records_to_update: list[dict[str, Any]] = []

        # Validate all IDs first
        validated_ids: list[str] = []
        for memory_id in memory_ids:
            try:
                validated_id = _validate_uuid(memory_id)
                validated_ids.append(_sanitize_string(validated_id))
            except Exception as e:
                logger.debug(f"Invalid memory ID {memory_id}: {e}")
                continue

        if not validated_ids:
            return 0

        # Batch fetch all records with single IN query (fixes N+1 pattern)
        try:
            id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
            all_records = self.table.search().where(f"id IN ({id_list})").to_list()
        except Exception as e:
            logger.error(f"Failed to batch fetch records for access update: {e}")
            return 0

        # Build lookup map for found records
        found_ids = set()
        for record in all_records:
            found_ids.add(record["id"])
            record["last_accessed"] = now
            record["access_count"] = record["access_count"] + 1

            # Ensure proper serialization for metadata
            if isinstance(record.get("metadata"), dict):
                record["metadata"] = json.dumps(record["metadata"])

            # Ensure vector is a list, not numpy array
            if isinstance(record.get("vector"), np.ndarray):
                record["vector"] = record["vector"].tolist()

            records_to_update.append(record)

        # Log any IDs that weren't found
        missing_ids = set(validated_ids) - found_ids
        for missing_id in missing_ids:
            logger.debug(f"Memory {missing_id} not found for access update")

        if not records_to_update:
            return 0

        try:
            # Atomic batch upsert using merge_insert
            # Requires BTREE index on 'id' column (created in create_scalar_indexes)
            (
                self.table.merge_insert("id")
                .when_matched_update_all()
                .when_not_matched_insert_all()
                .execute(records_to_update)
            )
            updated = len(records_to_update)
            logger.debug(
                f"Batch updated access for {updated}/{len(memory_ids)} memories "
                "(atomic merge_insert)"
            )
            return updated
        except Exception as e:
            logger.error(f"Failed to batch update access: {e}")
            return 0

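    # Illustrative usage sketch (hypothetical variables, not part of the
    # packaged module): after a recall touches several memories, their access
    # metadata can be bumped in one upsert; invalid or missing IDs are skipped:
    #
    #   touched = db.update_access_batch([m["id"] for m in results])
    #   deleted, deleted_ids = db.delete_batch(stale_ids)
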
    def _create_retry_decorator(self) -> Callable[[F], F]:
        """Create a retry decorator using instance settings."""
        return retry_on_storage_error(
            max_attempts=self.max_retry_attempts,
            backoff=self.retry_backoff_seconds,
        )

    # ========================================================================
    # Search Operations (delegates to SearchManager)
    # ========================================================================

    def _calculate_search_params(
        self,
        count: int,
        limit: int,
        nprobes_override: int | None = None,
        refine_factor_override: int | None = None,
    ) -> tuple[int, int]:
        """Calculate optimal search parameters. Delegates to SearchManager."""
        if self._search_manager is None:
            raise StorageError("Database not connected")
        return self._search_manager.calculate_search_params(
            count, limit, nprobes_override, refine_factor_override
        )

    @with_stale_connection_recovery
    @retry_on_storage_error(max_attempts=3, backoff=0.5)
    def vector_search(
        self,
        query_vector: np.ndarray,
        limit: int = 5,
        namespace: str | None = None,
        min_similarity: float = 0.0,
        nprobes: int | None = None,
        refine_factor: int | None = None,
        include_vector: bool = False,
    ) -> list[dict[str, Any]]:
        """Search for similar memories by vector. Delegates to SearchManager.

        Args:
            query_vector: Query embedding vector.
            limit: Maximum number of results.
            namespace: Filter to specific namespace.
            min_similarity: Minimum similarity threshold (0-1).
            nprobes: Number of partitions to search.
            refine_factor: Re-rank top (refine_factor * limit) for accuracy.
            include_vector: Whether to include vector embeddings in results.

        Returns:
            List of memory records with similarity scores.

        Raises:
            ValidationError: If input validation fails.
            StorageError: If database operation fails.
        """
        if self._search_manager is None:
            raise StorageError("Database not connected")
        return self._search_manager.vector_search(
            query_vector=query_vector,
            limit=limit,
            namespace=namespace,
            min_similarity=min_similarity,
            nprobes=nprobes,
            refine_factor=refine_factor,
            include_vector=include_vector,
        )

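    # Illustrative usage sketch (hypothetical embedder, not part of the
    # packaged module): the query vector must match embedding_dim, and
    # min_similarity trims weak matches before they reach the caller:
    #
    #   qvec = embedder.embed("where did we store the API keys?")
    #   hits = db.vector_search(qvec, limit=5, namespace="default", min_similarity=0.3)
    #   for hit in hits:
    #       print(hit["content"])
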
    @with_stale_connection_recovery
    @retry_on_storage_error(max_attempts=3, backoff=0.5)
    def batch_vector_search_native(
        self,
        query_vectors: list[np.ndarray],
        limit_per_query: int = 3,
        namespace: str | None = None,
        min_similarity: float = 0.0,
        include_vector: bool = False,
    ) -> list[list[dict[str, Any]]]:
        """Batch search using native LanceDB. Delegates to SearchManager.

        Args:
            query_vectors: List of query embedding vectors.
            limit_per_query: Maximum number of results per query.
            namespace: Filter to specific namespace.
            min_similarity: Minimum similarity threshold (0-1).
            include_vector: Whether to include vector embeddings in results.

        Returns:
            List of result lists, one per query vector.

        Raises:
            ValidationError: If input validation fails.
            StorageError: If database operation fails.
        """
        if self._search_manager is None:
            raise StorageError("Database not connected")
        return self._search_manager.batch_vector_search_native(
            query_vectors=query_vectors,
            limit_per_query=limit_per_query,
            namespace=namespace,
            min_similarity=min_similarity,
            include_vector=include_vector,
        )

    @with_stale_connection_recovery
    @retry_on_storage_error(max_attempts=3, backoff=0.5)
    def hybrid_search(
        self,
        query: str,
        query_vector: np.ndarray,
        limit: int = 5,
        namespace: str | None = None,
        alpha: float = 0.5,
        min_similarity: float = 0.0,
    ) -> list[dict[str, Any]]:
        """Hybrid search combining vector and keyword. Delegates to SearchManager.

        Args:
            query: Text query for full-text search.
            query_vector: Embedding vector for semantic search.
            limit: Number of results.
            namespace: Filter to namespace.
            alpha: Balance between vector (1.0) and keyword (0.0).
            min_similarity: Minimum similarity threshold.

        Returns:
            List of memory records with combined scores.

        Raises:
            ValidationError: If input validation fails.
            StorageError: If database operation fails.
        """
        if self._search_manager is None:
            raise StorageError("Database not connected")
        return self._search_manager.hybrid_search(
            query=query,
            query_vector=query_vector,
            limit=limit,
            namespace=namespace,
            alpha=alpha,
            min_similarity=min_similarity,
        )

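    # Illustrative usage sketch (hypothetical embedder, not part of the
    # packaged module): alpha steers the blend, with 1.0 leaning fully on the
    # vector score and 0.0 on the keyword match:
    #
    #   text = "postgres connection pooling tips"
    #   results = db.hybrid_search(text, embedder.embed(text), limit=5, alpha=0.7)
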
    @with_stale_connection_recovery
    @retry_on_storage_error(max_attempts=3, backoff=0.5)
    def batch_vector_search(
        self,
        query_vectors: list[np.ndarray],
        limit_per_query: int = 3,
        namespace: str | None = None,
        parallel: bool = False,  # Deprecated
        max_workers: int = 4,  # Deprecated
        include_vector: bool = False,
    ) -> list[list[dict[str, Any]]]:
        """Search using multiple query vectors. Delegates to SearchManager.

        Args:
            query_vectors: List of query embedding vectors.
            limit_per_query: Maximum results per query vector.
            namespace: Filter to specific namespace.
            parallel: Deprecated, kept for backward compatibility.
            max_workers: Deprecated, kept for backward compatibility.
            include_vector: Whether to include vector embeddings in results.

        Returns:
            List of result lists (one per query vector).

        Raises:
            StorageError: If database operation fails.
        """
        if self._search_manager is None:
            raise StorageError("Database not connected")
        return self._search_manager.batch_vector_search(
            query_vectors=query_vectors,
            limit_per_query=limit_per_query,
            namespace=namespace,
            parallel=parallel,
            max_workers=max_workers,
            include_vector=include_vector,
        )

    def get_vectors_for_clustering(
        self,
        namespace: str | None = None,
        max_memories: int = 10_000,
    ) -> tuple[list[str], np.ndarray]:
        """Fetch all vectors for clustering operations (e.g., HDBSCAN).

        Optimized for memory efficiency with large datasets.

        Args:
            namespace: Filter to specific namespace.
            max_memories: Maximum memories to fetch.

        Returns:
            Tuple of (memory_ids, vectors_array).

        Raises:
            ValidationError: If input validation fails.
            StorageError: If database operation fails.
        """
        try:
            # Build query selecting only needed columns
            search = self.table.search()

            if namespace:
                namespace = _validate_namespace(namespace)
                safe_ns = _sanitize_string(namespace)
                search = search.where(f"namespace = '{safe_ns}'")

            # Select only id and vector to minimize memory usage
            search = search.select(["id", "vector"]).limit(max_memories)

            results = search.to_list()

            if not results:
                return [], np.array([], dtype=np.float32).reshape(0, self.embedding_dim)

            ids = [r["id"] for r in results]
            vectors = np.array([r["vector"] for r in results], dtype=np.float32)

            return ids, vectors

        except ValidationError:
            raise
        except Exception as e:
            raise StorageError(f"Failed to fetch vectors for clustering: {e}") from e

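    # Illustrative usage sketch (assumes an optional clustering library such as
    # scikit-learn; not part of the packaged module, which mentions HDBSCAN as
    # one possible consumer): the (ids, vectors) pair feeds any array-based
    # clustering routine directly:
    #
    #   from sklearn.cluster import KMeans  # assumed dependency for this sketch
    #
    #   ids, vectors = db.get_vectors_for_clustering(namespace="default")
    #   if len(ids) >= 10:
    #       labels = KMeans(n_clusters=5, n_init="auto").fit_predict(vectors)
    #       clusters = dict(zip(ids, labels))
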
def get_vectors_as_arrow(
|
|
2565
|
+
self,
|
|
2566
|
+
namespace: str | None = None,
|
|
2567
|
+
columns: list[str] | None = None,
|
|
2568
|
+
) -> pa.Table:
|
|
2569
|
+
"""Get memories as Arrow table for efficient processing.
|
|
2570
|
+
|
|
2571
|
+
Arrow tables enable zero-copy data sharing and efficient columnar
|
|
2572
|
+
operations. Use this for large-scale analytics.
|
|
2573
|
+
|
|
2574
|
+
Args:
|
|
2575
|
+
namespace: Filter to specific namespace.
|
|
2576
|
+
columns: Columns to select (None = all).
|
|
2577
|
+
|
|
2578
|
+
Returns:
|
|
2579
|
+
PyArrow Table with selected data.
|
|
2580
|
+
|
|
2581
|
+
Raises:
|
|
2582
|
+
StorageError: If database operation fails.
|
|
2583
|
+
"""
|
|
2584
|
+
try:
|
|
2585
|
+
search = self.table.search()
|
|
2586
|
+
|
|
2587
|
+
if namespace:
|
|
2588
|
+
namespace = _validate_namespace(namespace)
|
|
2589
|
+
safe_ns = _sanitize_string(namespace)
|
|
2590
|
+
search = search.where(f"namespace = '{safe_ns}'")
|
|
2591
|
+
|
|
2592
|
+
if columns:
|
|
2593
|
+
search = search.select(columns)
|
|
2594
|
+
|
|
2595
|
+
return search.to_arrow()
|
|
2596
|
+
|
|
2597
|
+
except Exception as e:
|
|
2598
|
+
raise StorageError(f"Failed to get Arrow table: {e}") from e
|
|
2599
|
+
|
|
2600
|
+
def get_all(
|
|
2601
|
+
self,
|
|
2602
|
+
namespace: str | None = None,
|
|
2603
|
+
limit: int | None = None,
|
|
2604
|
+
) -> list[dict[str, Any]]:
|
|
2605
|
+
"""Get all memories, optionally filtered by namespace.
|
|
2606
|
+
|
|
2607
|
+
Args:
|
|
2608
|
+
namespace: Filter to specific namespace.
|
|
2609
|
+
limit: Maximum number of results.
|
|
2610
|
+
|
|
2611
|
+
Returns:
|
|
2612
|
+
List of memory records.
|
|
2613
|
+
|
|
2614
|
+
Raises:
|
|
2615
|
+
ValidationError: If input validation fails.
|
|
2616
|
+
StorageError: If database operation fails.
|
|
2617
|
+
"""
|
|
2618
|
+
try:
|
|
2619
|
+
search = self.table.search()
|
|
2620
|
+
|
|
2621
|
+
if namespace:
|
|
2622
|
+
namespace = _validate_namespace(namespace)
|
|
2623
|
+
safe_ns = _sanitize_string(namespace)
|
|
2624
|
+
search = search.where(f"namespace = '{safe_ns}'")
|
|
2625
|
+
|
|
2626
|
+
if limit:
|
|
2627
|
+
search = search.limit(limit)
|
|
2628
|
+
|
|
2629
|
+
results: list[dict[str, Any]] = search.to_list()
|
|
2630
|
+
|
|
2631
|
+
for record in results:
|
|
2632
|
+
record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
|
|
2633
|
+
|
|
2634
|
+
return results
|
|
2635
|
+
except ValidationError:
|
|
2636
|
+
raise
|
|
2637
|
+
except Exception as e:
|
|
2638
|
+
raise StorageError(f"Failed to get all memories: {e}") from e
|
|
2639
|
+
|
|
2640
|
+
@with_stale_connection_recovery
|
|
2641
|
+
def count(self, namespace: str | None = None) -> int:
|
|
2642
|
+
"""Count memories.
|
|
2643
|
+
|
|
2644
|
+
Args:
|
|
2645
|
+
namespace: Filter to specific namespace.
|
|
2646
|
+
|
|
2647
|
+
Returns:
|
|
2648
|
+
Number of memories.
|
|
2649
|
+
|
|
2650
|
+
Raises:
|
|
2651
|
+
ValidationError: If input validation fails.
|
|
2652
|
+
StorageError: If database operation fails.
|
|
2653
|
+
"""
|
|
2654
|
+
try:
|
|
2655
|
+
if namespace:
|
|
2656
|
+
namespace = _validate_namespace(namespace)
|
|
2657
|
+
safe_ns = _sanitize_string(namespace)
|
|
2658
|
+
# Use count_rows with filter predicate for efficiency
|
|
2659
|
+
count: int = self.table.count_rows(f"namespace = '{safe_ns}'")
|
|
2660
|
+
return count
|
|
2661
|
+
count = self.table.count_rows()
|
|
2662
|
+
return count
|
|
2663
|
+
except ValidationError:
|
|
2664
|
+
raise
|
|
2665
|
+
except Exception as e:
|
|
2666
|
+
raise StorageError(f"Failed to count memories: {e}") from e
|
|
2667
|
+
|
|
2668
|
+
@with_stale_connection_recovery
|
|
2669
|
+
def get_namespaces(self) -> list[str]:
|
|
2670
|
+
"""Get all unique namespaces (cached with TTL, thread-safe).
|
|
2671
|
+
|
|
2672
|
+
Uses double-checked locking to avoid race conditions where another
|
|
2673
|
+
thread could see stale data between cache check and update.
|
|
2674
|
+
|
|
2675
|
+
Returns:
|
|
2676
|
+
Sorted list of namespace names.
|
|
2677
|
+
|
|
2678
|
+
Raises:
|
|
2679
|
+
StorageError: If database operation fails.
|
|
2680
|
+
"""
|
|
2681
|
+
try:
|
|
2682
|
+
now = time.time()
|
|
2683
|
+
|
|
2684
|
+
# First check with lock (quick path if cache is valid)
|
|
2685
|
+
with self._namespace_cache_lock:
|
|
2686
|
+
if (
|
|
2687
|
+
self._cached_namespaces is not None
|
|
2688
|
+
and (now - self._namespace_cache_time) <= self._NAMESPACE_CACHE_TTL
|
|
2689
|
+
):
|
|
2690
|
+
return sorted(self._cached_namespaces)
|
|
2691
|
+
|
|
2692
|
+
# Fetch from database (outside lock to avoid blocking)
|
|
2693
|
+
results = self.table.search().select(["namespace"]).to_list()
|
|
2694
|
+
namespaces = set(r["namespace"] for r in results)
|
|
2695
|
+
|
|
2696
|
+
# Double-checked locking: re-check and update atomically
|
|
2697
|
+
with self._namespace_cache_lock:
|
|
2698
|
+
# Another thread may have populated cache while we were fetching
|
|
2699
|
+
if self._cached_namespaces is None:
|
|
2700
|
+
self._cached_namespaces = namespaces
|
|
2701
|
+
self._namespace_cache_time = now
|
|
2702
|
+
# Return fresh data regardless (it's at least as current)
|
|
2703
|
+
return sorted(namespaces)
|
|
2704
|
+
|
|
2705
|
+
except Exception as e:
|
|
2706
|
+
raise StorageError(f"Failed to get namespaces: {e}") from e
|
|
2707
|
+
|
|
2708
|
+
    @with_process_lock
    @with_write_lock
    def update_access(self, memory_id: str) -> None:
        """Update access timestamp and count for a memory.

        Args:
            memory_id: The memory ID.

        Raises:
            ValidationError: If memory_id is invalid.
            MemoryNotFoundError: If memory doesn't exist.
            StorageError: If database operation fails.
        """
        existing = self.get(memory_id)
        # Note: self.update() also has @with_process_lock and @with_write_lock,
        # but both support reentrancy within the same thread (no deadlock):
        # - ProcessLockManager tracks depth via threading.local
        # - RLock allows same thread to re-acquire
        self.update(memory_id, {
            "last_accessed": utc_now(),
            "access_count": existing["access_count"] + 1,
        })

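    # Illustrative reentrancy sketch (standalone, not this module's locking
    # code): threading.RLock lets the thread that already holds a lock
    # re-acquire it, which is why update_access() can call update() without
    # deadlocking even though both are decorated with the write lock.
    #
    #     import threading
    #
    #     _lock = threading.RLock()
    #
    #     def outer() -> None:
    #         with _lock:      # first acquisition
    #             inner()
    #
    #     def inner() -> None:
    #         with _lock:      # same thread re-acquires without blocking
    #             ...
    #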
    # ========================================================================
    # Backup & Export
    # ========================================================================

    def export_to_parquet(
        self,
        output_path: Path,
        namespace: str | None = None,
    ) -> dict[str, Any]:
        """Export memories to Parquet file for backup.

        Parquet provides efficient compression and fast read performance
        for large datasets.

        Args:
            output_path: Path to save Parquet file.
            namespace: Export only this namespace (None = all).

        Returns:
            Export statistics (rows_exported, output_path, size_mb).

        Raises:
            StorageError: If export fails.
        """
        try:
            # Get all data as Arrow table (efficient)
            arrow_table = self.get_vectors_as_arrow(namespace=namespace)

            # Ensure parent directory exists
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            # Write to Parquet with compression
            pq.write_table(
                arrow_table,
                output_path,
                compression="zstd",  # Good compression + fast decompression
            )

            size_bytes = output_path.stat().st_size

            logger.info(
                f"Exported {arrow_table.num_rows} memories to {output_path} "
                f"({size_bytes / (1024 * 1024):.2f} MB)"
            )

            return {
                "rows_exported": arrow_table.num_rows,
                "output_path": str(output_path),
                "size_bytes": size_bytes,
                "size_mb": size_bytes / (1024 * 1024),
            }

        except Exception as e:
            raise StorageError(f"Export failed: {e}") from e

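    # Illustrative backup sketch (caller-side, not part of this class):
    # assuming `db` is a connected instance; the output path and namespace
    # below are hypothetical.
    #
    #     stats = db.export_to_parquet(
    #         Path("backups/memories.parquet"),
    #         namespace="default",
    #     )
    #     print(f"{stats['rows_exported']} rows, {stats['size_mb']:.2f} MB")
    #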
    def import_from_parquet(
        self,
        parquet_path: Path,
        namespace_override: str | None = None,
        batch_size: int = 1000,
    ) -> dict[str, Any]:
        """Import memories from Parquet backup.

        Args:
            parquet_path: Path to Parquet file.
            namespace_override: Override namespace for all imported memories.
            batch_size: Records per batch during import.

        Returns:
            Import statistics (rows_imported, source).

        Raises:
            StorageError: If import fails.
        """
        try:
            parquet_path = Path(parquet_path)
            if not parquet_path.exists():
                raise StorageError(f"Parquet file not found: {parquet_path}")

            table = pq.read_table(parquet_path)
            total_rows = table.num_rows

            logger.info(f"Importing {total_rows} memories from {parquet_path}")

            # Convert to list of dicts for processing
            records = table.to_pylist()

            # Override namespace if requested
            if namespace_override:
                namespace_override = _validate_namespace(namespace_override)
                for record in records:
                    record["namespace"] = namespace_override

            # Regenerate IDs to avoid conflicts
            for record in records:
                record["id"] = str(uuid.uuid4())
                # Ensure metadata is properly formatted
                if isinstance(record.get("metadata"), str):
                    try:
                        record["metadata"] = json.loads(record["metadata"])
                    except json.JSONDecodeError:
                        record["metadata"] = {}

            # After reading from parquet, serialize metadata back to JSON string
            # Parquet may read metadata as dict/struct, but the database expects JSON string
            for record in records:
                if "metadata" in record and isinstance(record["metadata"], dict):
                    record["metadata"] = json.dumps(record["metadata"])

            # Insert in batches
            imported = 0
            for i in range(0, len(records), batch_size):
                batch = records[i:i + batch_size]
                # Convert to format expected by insert
                prepared = []
                for r in batch:
                    # Ensure metadata is a JSON string for storage
                    metadata = r.get("metadata", {})
                    if isinstance(metadata, dict):
                        metadata = json.dumps(metadata)
                    elif metadata is None:
                        metadata = "{}"

                    prepared.append({
                        "content": r["content"],
                        "vector": r["vector"],
                        "namespace": r["namespace"],
                        "tags": r.get("tags", []),
                        "importance": r.get("importance", 0.5),
                        "source": r.get("source", "import"),
                        "metadata": metadata,
                        "expires_at": r.get("expires_at"),  # Preserve TTL from source
                    })
                self.table.add(prepared)
                imported += len(batch)
                logger.debug(f"Imported batch: {imported}/{total_rows}")

            logger.info(f"Successfully imported {imported} memories")

            return {
                "rows_imported": imported,
                "source": str(parquet_path),
            }

        except StorageError:
            raise
        except Exception as e:
            raise StorageError(f"Import failed: {e}") from e

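    # Illustrative restore sketch (caller-side, not part of this class):
    # re-importing the backup from the export sketch above into a hypothetical
    # "restored" namespace. IDs are regenerated on import, so existing rows
    # are never overwritten.
    #
    #     result = db.import_from_parquet(
    #         Path("backups/memories.parquet"),
    #         namespace_override="restored",
    #         batch_size=500,
    #     )
    #     print(result["rows_imported"], "rows from", result["source"])
    #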
    # ========================================================================
    # TTL (Time-To-Live) Management
    # ========================================================================

    def set_memory_ttl(self, memory_id: str, ttl_days: int | None) -> None:
        """Set TTL for a specific memory.

        Args:
            memory_id: Memory ID.
            ttl_days: Days until expiration, or None to remove TTL.

        Raises:
            ValidationError: If memory_id is invalid.
            MemoryNotFoundError: If memory doesn't exist.
            StorageError: If database operation fails.
        """
        memory_id = _validate_uuid(memory_id)

        # Verify memory exists
        existing = self.get(memory_id)

        if ttl_days is not None:
            if ttl_days <= 0:
                raise ValidationError("TTL days must be positive")
            expires_at = utc_now() + timedelta(days=ttl_days)
        else:
            expires_at = None

        # Prepare record with TTL update
        existing["expires_at"] = expires_at
        existing["updated_at"] = utc_now()

        # Ensure proper serialization for LanceDB
        if isinstance(existing.get("metadata"), dict):
            existing["metadata"] = json.dumps(existing["metadata"])
        if isinstance(existing.get("vector"), np.ndarray):
            existing["vector"] = existing["vector"].tolist()

        try:
            # Atomic upsert using merge_insert (same pattern as update() method)
            # This prevents data loss if the operation fails partway through
            (
                self.table.merge_insert("id")
                .when_matched_update_all()
                .when_not_matched_insert_all()
                .execute([existing])
            )
            logger.debug(f"Set TTL for memory {memory_id}: expires_at={expires_at}")
        except Exception as e:
            raise StorageError(f"Failed to set memory TTL: {e}") from e

    def cleanup_expired_memories(self) -> int:
        """Delete memories that have passed their expiration time.

        Returns:
            Number of deleted memories.

        Raises:
            StorageError: If cleanup fails.
        """
        if not self.enable_memory_expiration:
            logger.debug("Memory expiration is disabled, skipping cleanup")
            return 0

        try:
            now = utc_now()
            count_before: int = self.table.count_rows()

            # Delete expired memories using timestamp comparison
            # LanceDB uses ISO 8601 format for timestamp comparisons
            predicate = (
                f"expires_at IS NOT NULL AND expires_at < timestamp '{now.isoformat()}'"
            )
            self.table.delete(predicate)

            count_after: int = self.table.count_rows()
            deleted: int = count_before - count_after

            if deleted > 0:
                self._invalidate_count_cache()
                self._track_modification(deleted)
                self._invalidate_namespace_cache()
                logger.info(f"Cleaned up {deleted} expired memories")

            return deleted
        except Exception as e:
            raise StorageError(f"Failed to cleanup expired memories: {e}") from e

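    # Illustrative TTL workflow (caller-side sketch, not part of this class):
    # give one memory a 7-day lifetime, then sweep expired rows, typically
    # from a periodic maintenance job. `memory_id` is a hypothetical UUID.
    #
    #     db.set_memory_ttl(memory_id, ttl_days=7)     # expires_at = now + 7 days
    #     db.set_memory_ttl(memory_id, ttl_days=None)  # remove the TTL again
    #     removed = db.cleanup_expired_memories()      # 0 if expiration disabled
    #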
    # ========================================================================
    # Snapshot / Version Management (delegated to VersionManager)
    # ========================================================================

    def create_snapshot(self, tag: str) -> int:
        """Create a named snapshot of the current table state.

        Delegates to VersionManager. See VersionManager.create_snapshot for details.
        """
        if self._version_manager is None:
            raise StorageError("Database not connected")
        return self._version_manager.create_snapshot(tag)

    def list_snapshots(self) -> list[dict[str, Any]]:
        """List available versions/snapshots.

        Delegates to VersionManager. See VersionManager.list_snapshots for details.
        """
        if self._version_manager is None:
            raise StorageError("Database not connected")
        return self._version_manager.list_snapshots()

    def restore_snapshot(self, version: int) -> None:
        """Restore table to a specific version.

        Delegates to VersionManager. See VersionManager.restore_snapshot for details.
        """
        if self._version_manager is None:
            raise StorageError("Database not connected")
        self._version_manager.restore_snapshot(version)

    def get_current_version(self) -> int:
        """Get the current table version number.

        Delegates to VersionManager. See VersionManager.get_current_version for details.
        """
        if self._version_manager is None:
            raise StorageError("Database not connected")
        return self._version_manager.get_current_version()

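    # Illustrative snapshot workflow (caller-side sketch, not part of this
    # class): tag the current state, inspect available versions, and roll
    # back if a later change needs to be undone. The tag string is hypothetical.
    #
    #     version = db.create_snapshot("before-bulk-import")
    #     print(db.list_snapshots())     # available version/snapshot entries
    #     db.restore_snapshot(version)   # revert to the tagged state
    #     print(db.get_current_version())
    #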
    # ========================================================================
    # Idempotency Key Management (delegates to IdempotencyManager)
    # ========================================================================

    @property
    def idempotency_table(self) -> LanceTable:
        """Get the idempotency keys table. Delegates to IdempotencyManager."""
        if self._idempotency_manager is None:
            raise StorageError("Database not connected")
        return self._idempotency_manager.idempotency_table

    def get_by_idempotency_key(self, key: str) -> IdempotencyRecord | None:
        """Look up an idempotency record by key. Delegates to IdempotencyManager.

        Args:
            key: The idempotency key to look up.

        Returns:
            IdempotencyRecord if found and not expired, None otherwise.

        Raises:
            StorageError: If database operation fails.
        """
        if self._idempotency_manager is None:
            raise StorageError("Database not connected")
        return self._idempotency_manager.get_by_idempotency_key(key)

    @with_process_lock
    @with_write_lock
    def store_idempotency_key(
        self,
        key: str,
        memory_id: str,
        ttl_hours: float = 24.0,
    ) -> None:
        """Store an idempotency key mapping. Delegates to IdempotencyManager.

        Args:
            key: The idempotency key.
            memory_id: The memory ID that was created.
            ttl_hours: Time-to-live in hours (default: 24 hours).

        Raises:
            ValidationError: If inputs are invalid.
            StorageError: If database operation fails.
        """
        if self._idempotency_manager is None:
            raise StorageError("Database not connected")
        self._idempotency_manager.store_idempotency_key(key, memory_id, ttl_hours)

    @with_process_lock
    @with_write_lock
    def cleanup_expired_idempotency_keys(self) -> int:
        """Remove expired idempotency keys. Delegates to IdempotencyManager.

        Returns:
            Number of keys removed.

        Raises:
            StorageError: If cleanup fails.
        """
        if self._idempotency_manager is None:
            raise StorageError("Database not connected")
        return self._idempotency_manager.cleanup_expired_idempotency_keys()
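    # Illustrative idempotent-write pattern (caller-side sketch, not part of
    # this class): before creating a memory for a client-supplied key, check
    # whether the key was already used; afterwards, record the mapping.
    # `key`, `create_memory`, and `payload` are hypothetical caller names,
    # and the `memory_id` attribute on the record is assumed from the
    # IdempotencyRecord name.
    #
    #     record = db.get_by_idempotency_key(key)
    #     if record is not None:
    #         memory_id = record.memory_id        # reuse the earlier result
    #     else:
    #         memory_id = create_memory(payload)  # caller's own insert path
    #         db.store_idempotency_key(key, memory_id, ttl_hours=24.0)
    #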