spatial-memory-mcp 1.9.1: spatial_memory_mcp-1.9.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. spatial_memory/__init__.py +97 -0
  2. spatial_memory/__main__.py +271 -0
  3. spatial_memory/adapters/__init__.py +7 -0
  4. spatial_memory/adapters/lancedb_repository.py +880 -0
  5. spatial_memory/config.py +769 -0
  6. spatial_memory/core/__init__.py +118 -0
  7. spatial_memory/core/cache.py +317 -0
  8. spatial_memory/core/circuit_breaker.py +297 -0
  9. spatial_memory/core/connection_pool.py +220 -0
  10. spatial_memory/core/consolidation_strategies.py +401 -0
  11. spatial_memory/core/database.py +3072 -0
  12. spatial_memory/core/db_idempotency.py +242 -0
  13. spatial_memory/core/db_indexes.py +576 -0
  14. spatial_memory/core/db_migrations.py +588 -0
  15. spatial_memory/core/db_search.py +512 -0
  16. spatial_memory/core/db_versioning.py +178 -0
  17. spatial_memory/core/embeddings.py +558 -0
  18. spatial_memory/core/errors.py +317 -0
  19. spatial_memory/core/file_security.py +701 -0
  20. spatial_memory/core/filesystem.py +178 -0
  21. spatial_memory/core/health.py +289 -0
  22. spatial_memory/core/helpers.py +79 -0
  23. spatial_memory/core/import_security.py +433 -0
  24. spatial_memory/core/lifecycle_ops.py +1067 -0
  25. spatial_memory/core/logging.py +194 -0
  26. spatial_memory/core/metrics.py +192 -0
  27. spatial_memory/core/models.py +660 -0
  28. spatial_memory/core/rate_limiter.py +326 -0
  29. spatial_memory/core/response_types.py +500 -0
  30. spatial_memory/core/security.py +588 -0
  31. spatial_memory/core/spatial_ops.py +430 -0
  32. spatial_memory/core/tracing.py +300 -0
  33. spatial_memory/core/utils.py +110 -0
  34. spatial_memory/core/validation.py +406 -0
  35. spatial_memory/factory.py +444 -0
  36. spatial_memory/migrations/__init__.py +40 -0
  37. spatial_memory/ports/__init__.py +11 -0
  38. spatial_memory/ports/repositories.py +630 -0
  39. spatial_memory/py.typed +0 -0
  40. spatial_memory/server.py +1214 -0
  41. spatial_memory/services/__init__.py +70 -0
  42. spatial_memory/services/decay_manager.py +411 -0
  43. spatial_memory/services/export_import.py +1031 -0
  44. spatial_memory/services/lifecycle.py +1139 -0
  45. spatial_memory/services/memory.py +412 -0
  46. spatial_memory/services/spatial.py +1152 -0
  47. spatial_memory/services/utility.py +429 -0
  48. spatial_memory/tools/__init__.py +5 -0
  49. spatial_memory/tools/definitions.py +695 -0
  50. spatial_memory/verify.py +140 -0
  51. spatial_memory_mcp-1.9.1.dist-info/METADATA +509 -0
  52. spatial_memory_mcp-1.9.1.dist-info/RECORD +55 -0
  53. spatial_memory_mcp-1.9.1.dist-info/WHEEL +4 -0
  54. spatial_memory_mcp-1.9.1.dist-info/entry_points.txt +2 -0
  55. spatial_memory_mcp-1.9.1.dist-info/licenses/LICENSE +21 -0
spatial_memory/core/database.py
@@ -0,0 +1,3072 @@
1
+ """LanceDB database wrapper for Spatial Memory MCP Server.
2
+
3
+ Enterprise-grade implementation with:
4
+ - Connection pooling (singleton pattern)
5
+ - Automatic index creation (IVF-PQ, FTS, scalar)
6
+ - Hybrid search with RRF reranking
7
+ - Batch operations and streaming
8
+ - Maintenance and optimization utilities
9
+ - Health metrics and monitoring
10
+ - Retry logic for transient errors
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import logging
17
+ import threading
18
+ import time
19
+ import uuid
20
+ from collections.abc import Callable, Generator, Iterator
21
+ from dataclasses import dataclass
22
+ from datetime import timedelta
23
+ from functools import wraps
24
+ from pathlib import Path
25
+ from typing import TYPE_CHECKING, Any, TypeVar, cast
26
+
27
+ import lancedb
28
+ import lancedb.index
29
+ import numpy as np
30
+ import pyarrow as pa
31
+ import pyarrow.parquet as pq
32
+ from filelock import FileLock
33
+ from filelock import Timeout as FileLockTimeout
34
+
35
+ from spatial_memory.core.connection_pool import ConnectionPool
36
+ from spatial_memory.core.db_idempotency import IdempotencyManager, IdempotencyRecord
37
+ from spatial_memory.core.db_indexes import IndexManager
38
+ from spatial_memory.core.db_migrations import CURRENT_SCHEMA_VERSION, MigrationManager
39
+ from spatial_memory.core.db_search import SearchManager
40
+ from spatial_memory.core.db_versioning import VersionManager
41
+ from spatial_memory.core.errors import (
42
+ DimensionMismatchError,
43
+ FileLockError,
44
+ MemoryNotFoundError,
45
+ PartialBatchInsertError,
46
+ StorageError,
47
+ ValidationError,
48
+ )
49
+ from spatial_memory.core.filesystem import (
50
+ detect_filesystem_type,
51
+ get_filesystem_warning_message,
52
+ is_network_filesystem,
53
+ )
54
+ from spatial_memory.core.utils import utc_now
55
+
56
+ # Import centralized validation functions
57
+ from spatial_memory.core.validation import (
58
+ sanitize_string as _sanitize_string_impl,
59
+ )
60
+ from spatial_memory.core.validation import (
61
+ validate_metadata as _validate_metadata_impl,
62
+ )
63
+ from spatial_memory.core.validation import (
64
+ validate_namespace as _validate_namespace_impl,
65
+ )
66
+ from spatial_memory.core.validation import (
67
+ validate_tags as _validate_tags_impl,
68
+ )
69
+ from spatial_memory.core.validation import (
70
+ validate_uuid as _validate_uuid_impl,
71
+ )
72
+
73
+ if TYPE_CHECKING:
74
+ from lancedb.table import Table as LanceTable
75
+
76
+ logger = logging.getLogger(__name__)
77
+
78
+ # Type variable for retry decorator
79
+ F = TypeVar("F", bound=Callable[..., Any])
80
+
81
+ # All known vector index types for detection
82
+ VECTOR_INDEX_TYPES = frozenset({
83
+ "IVF_PQ", "IVF_FLAT", "HNSW",
84
+ "IVF_HNSW_PQ", "IVF_HNSW_SQ",
85
+ "HNSW_PQ", "HNSW_SQ",
86
+ })
87
+
88
+ # ============================================================================
89
+ # Connection Pool (Singleton Pattern with LRU Eviction)
90
+ # ============================================================================
91
+
92
+ # Global connection pool instance
93
+ _connection_pool = ConnectionPool(max_size=10)
94
+
95
+
96
+ def set_connection_pool_max_size(max_size: int) -> None:
97
+ """Set the maximum connection pool size.
98
+
99
+ Args:
100
+ max_size: Maximum number of connections to cache.
101
+ """
102
+ _connection_pool.max_size = max_size
103
+
104
+
105
+ def _get_or_create_connection(
106
+ storage_path: Path,
107
+ read_consistency_interval_ms: int = 0,
108
+ ) -> lancedb.DBConnection:
109
+ """Get cached connection or create new one (thread-safe with LRU eviction).
110
+
111
+ Args:
112
+ storage_path: Path to LanceDB storage directory.
113
+ read_consistency_interval_ms: Read consistency interval in milliseconds.
114
+
115
+ Returns:
116
+ LanceDB connection instance.
117
+ """
118
+ path_key = str(storage_path.absolute())
119
+ return _connection_pool.get_or_create(path_key, read_consistency_interval_ms)
120
+
121
+
122
+ def clear_connection_cache() -> None:
123
+ """Clear the connection cache, properly closing connections.
124
+
125
+ Should be called during shutdown or testing cleanup.
126
+ """
127
+ _connection_pool.close_all()
128
+
129
+
130
+ def invalidate_connection(storage_path: Path) -> bool:
131
+ """Invalidate a specific cached connection.
132
+
133
+ Use when a database connection becomes stale (e.g., database was
134
+ deleted and recreated externally).
135
+
136
+ Args:
137
+ storage_path: Path to the database to invalidate.
138
+
139
+ Returns:
140
+ True if a connection was invalidated, False if not found in cache.
141
+ """
142
+ path_key = str(storage_path.absolute())
143
+ return _connection_pool.invalidate(path_key)
144
+
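A minimal sketch of how the module-level pool is meant to be driven from application code; the storage path below is a placeholder, and the import path assumes this module lives at spatial_memory.core.database:

from pathlib import Path

from spatial_memory.core.database import (
    clear_connection_cache,
    invalidate_connection,
    set_connection_pool_max_size,
)

set_connection_pool_max_size(4)        # cap the number of cached connections
db_path = Path("./memories.lancedb")   # placeholder storage directory

# ... Database instances opened for db_path now share one pooled connection ...

invalidate_connection(db_path)         # drop a connection that went stale
clear_connection_cache()               # close everything at shutdown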
145
+
146
+ # ============================================================================
147
+ # Retry Decorator
148
+ # ============================================================================
149
+
150
+ # Default retry settings (can be overridden per-call)
151
+ DEFAULT_RETRY_MAX_ATTEMPTS = 3
152
+ DEFAULT_RETRY_BACKOFF_SECONDS = 0.5
153
+
154
+
155
+ def retry_on_storage_error(
156
+ max_attempts: int = DEFAULT_RETRY_MAX_ATTEMPTS,
157
+ backoff: float = DEFAULT_RETRY_BACKOFF_SECONDS,
158
+ ) -> Callable[[F], F]:
159
+ """Retry decorator for transient storage errors.
160
+
161
+ Args:
162
+ max_attempts: Maximum number of retry attempts.
163
+ backoff: Initial backoff time in seconds (doubles each attempt).
164
+
165
+ Returns:
166
+ Decorated function with retry logic.
167
+
168
+ Note:
169
+ - Decorator values are STATIC: Parameters are fixed at class definition
170
+ time, not instance creation time. This means the instance config values
171
+ (max_retry_attempts, retry_backoff_seconds) exist for external tooling
172
+ or future dynamic use, but do NOT affect this decorator's behavior.
173
+ - Does NOT retry concurrent modification or conflict errors as these
174
+ require application-level resolution (e.g., refresh and retry).
175
+ """
176
+ # Patterns indicating non-retryable errors
177
+ non_retryable_patterns = ("concurrent", "conflict", "version mismatch")
178
+
179
+ def decorator(func: F) -> F:
180
+ @wraps(func)
181
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
182
+ last_error: Exception | None = None
183
+ for attempt in range(max_attempts):
184
+ try:
185
+ return func(*args, **kwargs)
186
+ except (
187
+ StorageError, OSError, ConnectionError, TimeoutError
188
+ ) as e:
189
+ last_error = e
190
+ error_str = str(e).lower()
191
+
192
+ # Check for non-retryable errors - raise immediately
193
+ if any(pattern in error_str for pattern in non_retryable_patterns):
194
+ logger.warning(
195
+ f"Non-retryable error in {func.__name__}: {e}"
196
+ )
197
+ raise
198
+
199
+ # Check if we've exhausted retries
200
+ if attempt == max_attempts - 1:
201
+ raise
202
+
203
+ # Retry with exponential backoff
204
+ wait_time = backoff * (2 ** attempt)
205
+ logger.warning(
206
+ f"{func.__name__} failed (attempt {attempt + 1}/{max_attempts})"
207
+ f": {e}. Retrying in {wait_time:.1f}s..."
208
+ )
209
+ time.sleep(wait_time)
210
+ # Should never reach here, but satisfy type checker
211
+ if last_error:
212
+ raise last_error
213
+ return None
214
+ return cast(F, wrapper)
215
+ return decorator
216
+
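A minimal usage sketch (not part of the package): the decorator can wrap any storage-touching callable. With the defaults it sleeps 0.5 s, then 1.0 s between attempts, and re-raises "concurrent"/"conflict"/"version mismatch" errors immediately.

@retry_on_storage_error(max_attempts=3, backoff=0.5)
def flaky_storage_call() -> None:
    # StorageError / OSError / ConnectionError / TimeoutError raised here is
    # retried with exponential backoff; other exceptions propagate unchanged.
    ...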
217
+
218
+ def with_write_lock(func: F) -> F:
219
+ """Decorator to acquire write lock for mutation operations.
220
+
221
+ Serializes write operations per Database instance to prevent
222
+ LanceDB version conflicts during concurrent writes.
223
+
224
+ Uses RLock to allow nested calls (e.g., bulk_import -> insert_batch).
225
+ """
226
+ @wraps(func)
227
+ def wrapper(self: Database, *args: Any, **kwargs: Any) -> Any:
228
+ with self._write_lock:
229
+ return func(self, *args, **kwargs)
230
+ return cast(F, wrapper)
231
+
232
+
233
+ def with_stale_connection_recovery(func: F) -> F:
234
+ """Decorator to auto-recover from stale connection errors.
235
+
236
+ Detects when a database operation fails due to stale metadata
237
+ (e.g., database was recreated while connection was cached),
238
+ reconnects, and retries the operation once.
239
+ """
240
+ @wraps(func)
241
+ def wrapper(self: Database, *args: Any, **kwargs: Any) -> Any:
242
+ try:
243
+ return func(self, *args, **kwargs)
244
+ except Exception as e:
245
+ if _connection_pool.is_stale_connection_error(e):
246
+ logger.warning(
247
+ f"Stale connection detected in {func.__name__}, reconnecting..."
248
+ )
249
+ self.reconnect()
250
+ return func(self, *args, **kwargs)
251
+ raise
252
+ return cast(F, wrapper)
253
+
254
+
255
+ # ============================================================================
256
+ # Cross-Process Lock Manager
257
+ # ============================================================================
258
+
259
+
260
+ class ProcessLockManager:
261
+ """Cross-process file lock manager with reentrant support.
262
+
263
+ Wraps FileLock with thread-local depth tracking to support nested calls
264
+ (e.g., bulk_import() -> insert_batch()). Each thread can re-acquire the
265
+ lock without blocking.
266
+
267
+ Thread Safety:
268
+ - Lock depth is tracked per-thread using threading.local
269
+ - The underlying FileLock handles cross-process synchronization
270
+ - Multiple threads in the same process can hold the lock via RLock behavior
271
+
272
+ Example:
273
+ lock = ProcessLockManager(Path("/tmp/db.lock"), timeout=30.0)
274
+ with lock:
275
+ # Protected region
276
+ with lock: # Nested call - same thread can re-acquire
277
+ pass
278
+ """
279
+
280
+ def __init__(
281
+ self,
282
+ lock_path: Path,
283
+ timeout: float = 30.0,
284
+ poll_interval: float = 0.1,
285
+ enabled: bool = True,
286
+ ) -> None:
287
+ """Initialize the process lock manager.
288
+
289
+ Args:
290
+ lock_path: Path to the lock file.
291
+ timeout: Maximum seconds to wait for lock acquisition.
292
+ poll_interval: Seconds between lock acquisition attempts.
293
+ enabled: If False, all lock operations are no-ops.
294
+ """
295
+ self.lock_path = lock_path
296
+ self.timeout = timeout
297
+ self.poll_interval = poll_interval
298
+ self.enabled = enabled
299
+
300
+ # Create FileLock only if enabled
301
+ self._lock: FileLock | None = None
302
+ if enabled:
303
+ try:
304
+ self._lock = FileLock(str(lock_path), timeout=timeout)
305
+ except Exception as e:
306
+ # Fallback to disabled mode if lock file can't be created
307
+ # (e.g., read-only filesystem)
308
+ logger.warning(
309
+ f"Could not create file lock at {lock_path}: {e}. "
310
+ "Falling back to disabled mode."
311
+ )
312
+ self.enabled = False
313
+
314
+ # Thread-local storage for lock depth tracking
315
+ self._local = threading.local()
316
+
317
+ def _get_depth(self) -> int:
318
+ """Get current lock depth for this thread."""
319
+ return getattr(self._local, "depth", 0)
320
+
321
+ def _set_depth(self, depth: int) -> None:
322
+ """Set lock depth for this thread."""
323
+ self._local.depth = depth
324
+
325
+ def acquire(self) -> bool:
326
+ """Acquire the lock (reentrant for same thread).
327
+
328
+ Returns:
329
+ True if lock was newly acquired, False if already held by this thread.
330
+
331
+ Raises:
332
+ FileLockError: If lock cannot be acquired within timeout.
333
+ """
334
+ if not self.enabled or self._lock is None:
335
+ return True
336
+
337
+ depth = self._get_depth()
338
+ if depth > 0:
339
+ # Already held by this thread - increment depth
340
+ self._set_depth(depth + 1)
341
+ return False # Not newly acquired
342
+
343
+ try:
344
+ self._lock.acquire(timeout=self.timeout, poll_interval=self.poll_interval)
345
+ self._set_depth(1)
346
+ return True
347
+ except FileLockTimeout:
348
+ raise FileLockError(
349
+ lock_path=str(self.lock_path),
350
+ timeout=self.timeout,
351
+ message=(
352
+ f"Timed out waiting {self.timeout}s for file lock at "
353
+ f"{self.lock_path}. Another process may be holding the lock."
354
+ ),
355
+ )
356
+
357
+ def release(self) -> bool:
358
+ """Release the lock (decrements depth, releases when depth reaches 0).
359
+
360
+ Returns:
361
+ True if lock was released, False if still held (depth > 0).
362
+ """
363
+ if not self.enabled or self._lock is None:
364
+ return True
365
+
366
+ depth = self._get_depth()
367
+ if depth <= 0:
368
+ return True # Not holding the lock
369
+
370
+ if depth == 1:
371
+ self._lock.release()
372
+ self._set_depth(0)
373
+ return True
374
+ else:
375
+ self._set_depth(depth - 1)
376
+ return False # Still holding
377
+
378
+ def __enter__(self) -> ProcessLockManager:
379
+ """Enter context manager - acquire lock."""
380
+ self.acquire()
381
+ return self
382
+
383
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
384
+ """Exit context manager - release lock."""
385
+ self.release()
386
+
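A short sketch of the reentrant semantics documented above; the lock path is a placeholder.

lock = ProcessLockManager(Path("/tmp/spatial-memory.lock"), timeout=5.0)

assert lock.acquire() is True    # file lock taken, depth = 1
assert lock.acquire() is False   # same thread re-enters, depth = 2, no blocking
assert lock.release() is False   # depth back to 1, lock still held
assert lock.release() is True    # depth 0, file lock actually released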
387
+
388
+ def with_process_lock(func: F) -> F:
389
+ """Decorator to acquire process-level file lock for write operations.
390
+
391
+ Must be applied BEFORE (outer) @with_write_lock to ensure:
392
+ 1. Cross-process lock acquired first
393
+ 2. Then intra-process thread lock
394
+ 3. Releases in reverse order
395
+
396
+ Usage:
397
+ @with_process_lock # Outer - cross-process
398
+ @with_write_lock # Inner - intra-process
399
+ def insert(self, ...):
400
+ ...
401
+ """
402
+ @wraps(func)
403
+ def wrapper(self: Database, *args: Any, **kwargs: Any) -> Any:
404
+ if self._process_lock is None:
405
+ return func(self, *args, **kwargs)
406
+ with self._process_lock:
407
+ return func(self, *args, **kwargs)
408
+ return cast(F, wrapper)
409
+
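A sketch of the stacking order this docstring prescribes, mirroring how the write methods further down (insert, insert_batch) are decorated; the subclass and method name are purely illustrative.

class _LockedWrites(Database):
    @with_process_lock                                    # outer: cross-process file lock
    @with_write_lock                                      # inner: per-instance thread lock
    @retry_on_storage_error(max_attempts=3, backoff=0.5)  # innermost: transient-error retries
    def touch(self) -> None:
        """Hypothetical mutation protected by the full lock stack."""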
410
+
411
+ # ============================================================================
412
+ # Health Metrics
413
+ # ============================================================================
414
+
415
+ @dataclass
416
+ class IndexStats:
417
+ """Statistics for a single index."""
418
+ name: str
419
+ index_type: str
420
+ num_indexed_rows: int
421
+ num_unindexed_rows: int
422
+ needs_update: bool
423
+
424
+
425
+ @dataclass
426
+ class HealthMetrics:
427
+ """Database health and performance metrics."""
428
+ total_rows: int
429
+ total_bytes: int
430
+ total_bytes_mb: float
431
+ num_fragments: int
432
+ num_small_fragments: int
433
+ needs_compaction: bool
434
+ has_vector_index: bool
435
+ has_fts_index: bool
436
+ indices: list[IndexStats]
437
+ version: int
438
+ error: str | None = None
439
+
440
+
441
+ # Backward compatibility aliases - use centralized validation module
442
+ _sanitize_string = _sanitize_string_impl
443
+ _validate_uuid = _validate_uuid_impl
444
+
445
+
446
+ def _get_index_attr(idx: Any, attr: str, default: Any = None) -> Any:
447
+ """Get an attribute from an index object (handles both dict and IndexConfig).
448
+
449
+ LanceDB 0.27+ returns IndexConfig objects, while older versions use dicts.
450
+
451
+ Args:
452
+ idx: Index object (dict or IndexConfig).
453
+ attr: Attribute name to retrieve.
454
+ default: Default value if attribute not found.
455
+
456
+ Returns:
457
+ The attribute value or default.
458
+ """
459
+ if isinstance(idx, dict):
460
+ return idx.get(attr, default)
461
+ return getattr(idx, attr, default)
462
+
463
+
464
+ _validate_namespace = _validate_namespace_impl
465
+ _validate_tags = _validate_tags_impl
466
+ _validate_metadata = _validate_metadata_impl
467
+
468
+
469
+ class Database:
470
+ """LanceDB wrapper for memory storage and retrieval.
471
+
472
+ Enterprise-grade features:
473
+ - Connection pooling via singleton pattern with LRU eviction
474
+ - Automatic index creation based on dataset size
475
+ - Hybrid search with RRF reranking and alpha parameter
476
+ - Batch operations for efficiency
477
+ - Row count caching for search performance (thread-safe)
478
+ - Maintenance and optimization utilities
479
+
480
+ Thread Safety:
481
+ The module-level connection pool is thread-safe. However, individual
482
+ Database instances should NOT be shared across threads without external
483
+ synchronization. Each thread should create its own Database instance,
484
+ which will share the underlying pooled connection safely.
485
+
486
+ Supports context manager protocol for safe resource management.
487
+
488
+ Example:
489
+ with Database(path) as db:
490
+ db.insert(content="Hello", vector=vec)
491
+ """
492
+
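A sketch of the per-thread pattern described in the Thread Safety note above; the path and worker arguments are placeholders.

import threading
from pathlib import Path

def _worker(path: Path, content: str, vector: np.ndarray) -> None:
    with Database(path, embedding_dim=384) as db:  # one Database instance per thread
        db.insert(content=content, vector=vector)  # pooled connection is shared safely

# threading.Thread(target=_worker, args=(Path("./memories.lancedb"), text, vec)).start()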
493
+ # Cache refresh interval for row count (seconds)
494
+ _COUNT_CACHE_TTL = 60.0
495
+ # Cache refresh interval for namespaces (seconds) - longer because namespaces change less often
496
+ _NAMESPACE_CACHE_TTL = 300.0
497
+
498
+ def __init__(
499
+ self,
500
+ storage_path: Path,
501
+ embedding_dim: int = 384,
502
+ auto_create_indexes: bool = True,
503
+ vector_index_threshold: int = 10_000,
504
+ enable_fts: bool = True,
505
+ index_nprobes: int = 20,
506
+ index_refine_factor: int = 5,
507
+ max_retry_attempts: int = DEFAULT_RETRY_MAX_ATTEMPTS,
508
+ retry_backoff_seconds: float = DEFAULT_RETRY_BACKOFF_SECONDS,
509
+ read_consistency_interval_ms: int = 0,
510
+ index_wait_timeout_seconds: float = 30.0,
511
+ fts_stem: bool = True,
512
+ fts_remove_stop_words: bool = True,
513
+ fts_language: str = "English",
514
+ index_type: str = "IVF_PQ",
515
+ hnsw_m: int = 20,
516
+ hnsw_ef_construction: int = 300,
517
+ enable_memory_expiration: bool = False,
518
+ default_memory_ttl_days: int | None = None,
519
+ filelock_enabled: bool = True,
520
+ filelock_timeout: float = 30.0,
521
+ filelock_poll_interval: float = 0.1,
522
+ acknowledge_network_filesystem_risk: bool = False,
523
+ ) -> None:
524
+ """Initialize the database connection.
525
+
526
+ Args:
527
+ storage_path: Path to LanceDB storage directory.
528
+ embedding_dim: Dimension of embedding vectors.
529
+ auto_create_indexes: Automatically create indexes when thresholds met.
530
+ vector_index_threshold: Row count to trigger vector index creation.
531
+ enable_fts: Enable full-text search index.
532
+ index_nprobes: Number of partitions to search (higher = better recall).
533
+ index_refine_factor: Re-rank top (refine_factor * limit) for accuracy.
534
+ max_retry_attempts: Maximum retry attempts for transient errors.
535
+ retry_backoff_seconds: Initial backoff time for retries.
536
+ read_consistency_interval_ms: Read consistency interval (0 = strong).
537
+ index_wait_timeout_seconds: Timeout for waiting on index creation.
538
+ fts_stem: Enable stemming in FTS (running -> run).
539
+ fts_remove_stop_words: Remove stop words in FTS (the, is, etc.).
540
+ filelock_enabled: Enable cross-process file locking.
541
+ filelock_timeout: Timeout in seconds for acquiring filelock.
542
+ filelock_poll_interval: Interval between lock acquisition attempts.
543
+ fts_language: Language for FTS stemming.
544
+ index_type: Vector index type (IVF_PQ, IVF_FLAT, or HNSW_SQ).
545
+ hnsw_m: HNSW connections per node (4-64).
546
+ hnsw_ef_construction: HNSW build-time search width (100-1000).
547
+ enable_memory_expiration: Enable automatic memory expiration.
548
+ default_memory_ttl_days: Default TTL for memories in days (None = no expiration).
549
+ acknowledge_network_filesystem_risk: Suppress network filesystem warnings.
550
+ """
551
+ self.storage_path = Path(storage_path)
552
+ self.embedding_dim = embedding_dim
553
+ self.auto_create_indexes = auto_create_indexes
554
+ self.vector_index_threshold = vector_index_threshold
555
+ self.enable_fts = enable_fts
556
+ self.index_nprobes = index_nprobes
557
+ self.index_refine_factor = index_refine_factor
558
+ self.max_retry_attempts = max_retry_attempts
559
+ self.retry_backoff_seconds = retry_backoff_seconds
560
+ self.read_consistency_interval_ms = read_consistency_interval_ms
561
+ self.index_wait_timeout_seconds = index_wait_timeout_seconds
562
+ self.fts_stem = fts_stem
563
+ self.fts_remove_stop_words = fts_remove_stop_words
564
+ self.fts_language = fts_language
565
+ self.index_type = index_type
566
+ self.hnsw_m = hnsw_m
567
+ self.hnsw_ef_construction = hnsw_ef_construction
568
+ self.enable_memory_expiration = enable_memory_expiration
569
+ self.default_memory_ttl_days = default_memory_ttl_days
570
+ self.filelock_enabled = filelock_enabled
571
+ self.filelock_timeout = filelock_timeout
572
+ self.filelock_poll_interval = filelock_poll_interval
573
+ self.acknowledge_network_filesystem_risk = acknowledge_network_filesystem_risk
574
+ self._db: lancedb.DBConnection | None = None
575
+ self._table: LanceTable | None = None
576
+ self._has_vector_index: bool | None = None
577
+ self._has_fts_index: bool | None = None
578
+ # Row count cache for performance (avoid count_rows() on every search)
579
+ self._cached_row_count: int | None = None
580
+ self._count_cache_time: float = 0.0
581
+ # Thread-safe lock for row count cache
582
+ self._cache_lock = threading.Lock()
583
+ # Namespace cache for performance
584
+ self._cached_namespaces: set[str] | None = None
585
+ self._namespace_cache_time: float = 0.0
586
+ self._namespace_cache_lock = threading.Lock()
587
+ # Write lock for serializing mutations (prevents LanceDB version conflicts)
588
+ self._write_lock = threading.RLock()
589
+ # Cross-process lock (initialized in connect())
590
+ self._process_lock: ProcessLockManager | None = None
591
+ # Auto-compaction tracking
592
+ self._modification_count: int = 0
593
+ self._auto_compaction_threshold: int = 100 # Compact after this many modifications
594
+ self._auto_compaction_enabled: bool = True
595
+ # Version manager (initialized in connect())
596
+ self._version_manager: VersionManager | None = None
597
+ # Index manager (initialized in connect())
598
+ self._index_manager: IndexManager | None = None
599
+ # Search manager (initialized in connect())
600
+ self._search_manager: SearchManager | None = None
601
+ # Idempotency manager (initialized in connect())
602
+ self._idempotency_manager: IdempotencyManager | None = None
603
+
604
+ def __enter__(self) -> Database:
605
+ """Enter context manager."""
606
+ self.connect()
607
+ return self
608
+
609
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
610
+ """Exit context manager."""
611
+ self.close()
612
+
613
+ def connect(self) -> None:
614
+ """Connect to the database using pooled connections."""
615
+ try:
616
+ self.storage_path.mkdir(parents=True, exist_ok=True)
617
+
618
+ # Check for network filesystem and warn if detected
619
+ if not self.acknowledge_network_filesystem_risk:
620
+ if is_network_filesystem(self.storage_path):
621
+ fs_type = detect_filesystem_type(self.storage_path)
622
+ warning_msg = get_filesystem_warning_message(fs_type, self.storage_path)
623
+ logger.warning(warning_msg)
624
+
625
+ # Initialize cross-process lock manager
626
+ if self.filelock_enabled:
627
+ lock_path = self.storage_path / ".spatial-memory.lock"
628
+ self._process_lock = ProcessLockManager(
629
+ lock_path=lock_path,
630
+ timeout=self.filelock_timeout,
631
+ poll_interval=self.filelock_poll_interval,
632
+ enabled=self.filelock_enabled,
633
+ )
634
+ else:
635
+ self._process_lock = None
636
+
637
+ # Use connection pooling with read consistency support
638
+ self._db = _get_or_create_connection(
639
+ self.storage_path,
640
+ read_consistency_interval_ms=self.read_consistency_interval_ms,
641
+ )
642
+ self._ensure_table()
643
+ # Initialize remaining managers (IndexManager already initialized in _ensure_table)
644
+ self._version_manager = VersionManager(self)
645
+ self._search_manager = SearchManager(self)
646
+ self._idempotency_manager = IdempotencyManager(self)
647
+ logger.info(f"Connected to LanceDB at {self.storage_path}")
648
+
649
+ # Check for pending schema migrations
650
+ self._check_pending_migrations()
651
+ except Exception as e:
652
+ raise StorageError(f"Failed to connect to database: {e}") from e
653
+
654
+ def _check_pending_migrations(self) -> None:
655
+ """Check for pending migrations and warn if any exist.
656
+
657
+ This method checks the schema version and logs a warning if there
658
+ are pending migrations. It does not auto-apply migrations - that
659
+ requires explicit user action via the CLI.
660
+ """
661
+ try:
662
+ manager = MigrationManager(self, embeddings=None)
663
+ manager.register_builtin_migrations()
664
+
665
+ current_version = manager.get_current_version()
666
+ pending = manager.get_pending_migrations()
667
+
668
+ if pending:
669
+ pending_versions = [m.version for m in pending]
670
+ logger.warning(
671
+ f"Database schema version {current_version} is outdated. "
672
+ f"{len(pending)} migration(s) pending: {', '.join(pending_versions)}. "
673
+ f"Target version: {CURRENT_SCHEMA_VERSION}. "
674
+ f"Run 'spatial-memory migrate' to apply migrations."
675
+ )
676
+ except Exception as e:
677
+ # Don't fail connection due to migration check errors
678
+ logger.debug(f"Migration check skipped: {e}")
679
+
680
+ def _ensure_table(self) -> None:
681
+ """Ensure the memories table exists with appropriate indexes.
682
+
683
+ Uses retry logic to handle race conditions when multiple processes
684
+ attempt to create/open the table simultaneously.
685
+ """
686
+ if self._db is None:
687
+ raise StorageError("Database not connected")
688
+
689
+ max_retries = 3
690
+ retry_delay = 0.1 # Start with 100ms
691
+
692
+ for attempt in range(max_retries):
693
+ try:
694
+ existing_tables_result = self._db.list_tables()
695
+ # Handle both old (list) and new (object with .tables) LanceDB API
696
+ if hasattr(existing_tables_result, 'tables'):
697
+ existing_tables = existing_tables_result.tables
698
+ else:
699
+ existing_tables = existing_tables_result
700
+
701
+ if "memories" not in existing_tables:
702
+ # Create table with schema
703
+ schema = pa.schema([
704
+ pa.field("id", pa.string()),
705
+ pa.field("content", pa.string()),
706
+ pa.field("vector", pa.list_(pa.float32(), self.embedding_dim)),
707
+ pa.field("created_at", pa.timestamp("us")),
708
+ pa.field("updated_at", pa.timestamp("us")),
709
+ pa.field("last_accessed", pa.timestamp("us")),
710
+ pa.field("access_count", pa.int32()),
711
+ pa.field("importance", pa.float32()),
712
+ pa.field("namespace", pa.string()),
713
+ pa.field("tags", pa.list_(pa.string())),
714
+ pa.field("source", pa.string()),
715
+ pa.field("metadata", pa.string()),
716
+ pa.field("expires_at", pa.timestamp("us")), # TTL support - nullable
717
+ ])
718
+ try:
719
+ self._table = self._db.create_table("memories", schema=schema)
720
+ logger.info("Created memories table")
721
+ except Exception as create_err:
722
+ # Table might have been created by another process
723
+ if "already exists" in str(create_err).lower():
724
+ logger.debug("Table created by another process, opening it")
725
+ self._table = self._db.open_table("memories")
726
+ else:
727
+ raise
728
+
729
+ # Initialize IndexManager immediately after table is set
730
+ self._index_manager = IndexManager(self)
731
+
732
+ # Create FTS index on new table if enabled
733
+ if self.enable_fts:
734
+ self._index_manager.create_fts_index()
735
+ else:
736
+ self._table = self._db.open_table("memories")
737
+ logger.debug("Opened existing memories table")
738
+
739
+ # Initialize IndexManager immediately after table is set
740
+ self._index_manager = IndexManager(self)
741
+
742
+ # Check existing indexes
743
+ self._index_manager.check_existing_indexes()
744
+
745
+ # Success - exit retry loop
746
+ return
747
+
748
+ except Exception as e:
749
+ error_msg = str(e).lower()
750
+ # Retry on transient race conditions
751
+ if attempt < max_retries - 1 and (
752
+ "not found" in error_msg
753
+ or "does not exist" in error_msg
754
+ or "already exists" in error_msg
755
+ ):
756
+ logger.debug(
757
+ f"Table operation failed (attempt {attempt + 1}/{max_retries}), "
758
+ f"retrying in {retry_delay}s: {e}"
759
+ )
760
+ time.sleep(retry_delay)
761
+ retry_delay *= 2 # Exponential backoff
762
+ else:
763
+ raise
764
+
765
+ def _check_existing_indexes(self) -> None:
766
+ """Check which indexes already exist. Delegates to IndexManager."""
767
+ if self._index_manager is None:
768
+ raise StorageError("Database not connected")
769
+ self._index_manager.check_existing_indexes()
770
+ # Sync local state for backward compatibility
771
+ self._has_vector_index = self._index_manager.has_vector_index
772
+ self._has_fts_index = self._index_manager.has_fts_index
773
+
774
+ def _create_fts_index(self) -> None:
775
+ """Create FTS index. Delegates to IndexManager."""
776
+ if self._index_manager is None:
777
+ raise StorageError("Database not connected")
778
+ self._index_manager.create_fts_index()
779
+ # Sync local state for backward compatibility
780
+ self._has_fts_index = self._index_manager.has_fts_index
781
+
782
+ @property
783
+ def table(self) -> LanceTable:
784
+ """Get the memories table, connecting if needed."""
785
+ if self._table is None:
786
+ self.connect()
787
+ assert self._table is not None # connect() sets this or raises
788
+ return self._table
789
+
790
+ def close(self) -> None:
791
+ """Close the database connection and remove from pool.
792
+
793
+ This invalidates the pooled connection so that subsequent
794
+ Database instances will create fresh connections.
795
+ """
796
+ # Invalidate pooled connection first
797
+ invalidate_connection(self.storage_path)
798
+
799
+ # Clear local state
800
+ self._table = None
801
+ self._db = None
802
+ self._has_vector_index = None
803
+ self._has_fts_index = None
804
+ self._version_manager = None
805
+ self._index_manager = None
806
+ self._search_manager = None
807
+ self._idempotency_manager = None
808
+ with self._cache_lock:
809
+ self._cached_row_count = None
810
+ self._count_cache_time = 0.0
811
+ with self._namespace_cache_lock:
812
+ self._cached_namespaces = None
813
+ self._namespace_cache_time = 0.0
814
+ logger.debug("Database connection closed and removed from pool")
815
+
816
+ def reconnect(self) -> None:
817
+ """Invalidate cached connection and reconnect.
818
+
819
+ Use when the database connection becomes stale (e.g., database was
820
+ deleted and recreated externally, or metadata references missing files).
821
+
822
+ This method:
823
+ 1. Closes the current connection state
824
+ 2. Invalidates the pooled connection for this path
825
+ 3. Creates a fresh connection to the database
826
+ """
827
+ logger.info(f"Reconnecting to database at {self.storage_path}")
828
+ self.close()
829
+ invalidate_connection(self.storage_path)
830
+ self.connect()
831
+
832
+ def _is_stale_connection_error(self, error: Exception) -> bool:
833
+ """Check if an error indicates a stale/corrupted connection.
834
+
835
+ Args:
836
+ error: The exception to check.
837
+
838
+ Returns:
839
+ True if the error indicates a stale connection.
840
+ """
841
+ return ConnectionPool.is_stale_connection_error(error)
842
+
843
+ def _get_cached_row_count(self) -> int:
844
+ """Get row count with caching for performance (thread-safe).
845
+
846
+ Avoids calling count_rows() on every search operation.
847
+ Cache is invalidated on insert/delete or after TTL expires.
848
+
849
+ Returns:
850
+ Cached or fresh row count.
851
+ """
852
+ now = time.time()
853
+ with self._cache_lock:
854
+ if (
855
+ self._cached_row_count is None
856
+ or (now - self._count_cache_time) > self._COUNT_CACHE_TTL
857
+ ):
858
+ self._cached_row_count = self.table.count_rows()
859
+ self._count_cache_time = now
860
+ return self._cached_row_count
861
+
862
+ def _invalidate_count_cache(self) -> None:
863
+ """Invalidate the row count cache after modifications (thread-safe)."""
864
+ with self._cache_lock:
865
+ self._cached_row_count = None
866
+ self._count_cache_time = 0.0
867
+
868
+ def _invalidate_namespace_cache(self) -> None:
869
+ """Invalidate the namespace cache after modifications (thread-safe)."""
870
+ with self._namespace_cache_lock:
871
+ self._cached_namespaces = None
872
+ self._namespace_cache_time = 0.0
873
+
874
+ def _track_modification(self, count: int = 1) -> None:
875
+ """Track database modifications and trigger auto-compaction if threshold reached.
876
+
877
+ Args:
878
+ count: Number of modifications to track (default 1).
879
+ """
880
+ if not self._auto_compaction_enabled:
881
+ return
882
+
883
+ self._modification_count += count
884
+ if self._modification_count >= self._auto_compaction_threshold:
885
+ # Reset counter before compacting to avoid re-triggering
886
+ self._modification_count = 0
887
+ try:
888
+ stats = self._get_table_stats()
889
+ # Only compact if there are enough fragments to justify it
890
+ if stats.get("num_small_fragments", 0) >= 5:
891
+ logger.info(
892
+ f"Auto-compaction triggered after {self._auto_compaction_threshold} "
893
+ f"modifications ({stats.get('num_small_fragments', 0)} small fragments)"
894
+ )
895
+ self.table.compact_files()
896
+ logger.debug("Auto-compaction completed")
897
+ except Exception as e:
898
+ # Don't fail operations due to compaction issues
899
+ logger.debug(f"Auto-compaction skipped: {e}")
900
+
901
+ def set_auto_compaction(
902
+ self,
903
+ enabled: bool = True,
904
+ threshold: int | None = None,
905
+ ) -> None:
906
+ """Configure auto-compaction behavior.
907
+
908
+ Args:
909
+ enabled: Whether auto-compaction is enabled.
910
+ threshold: Number of modifications before auto-compact (default: 100).
911
+ """
912
+ self._auto_compaction_enabled = enabled
913
+ if threshold is not None:
914
+ if threshold < 10:
915
+ raise ValueError("Auto-compaction threshold must be at least 10")
916
+ self._auto_compaction_threshold = threshold
917
+
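A brief usage sketch, assuming `db` is a connected Database instance: thresholds below 10 are rejected, and compaction only actually runs once at least five small fragments have accumulated.

db.set_auto_compaction(enabled=True, threshold=500)  # compact roughly every 500 modifications
db.set_auto_compaction(enabled=False)                # or switch auto-compaction off entirely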
918
+ # ========================================================================
919
+ # Index Management (delegates to IndexManager)
920
+ # ========================================================================
921
+
922
+ def create_vector_index(self, force: bool = False) -> bool:
923
+ """Create vector index for similarity search. Delegates to IndexManager.
924
+
925
+ Args:
926
+ force: Force index creation regardless of dataset size.
927
+
928
+ Returns:
929
+ True if index was created, False if skipped.
930
+
931
+ Raises:
932
+ StorageError: If index creation fails.
933
+ """
934
+ if self._index_manager is None:
935
+ raise StorageError("Database not connected")
936
+ result = self._index_manager.create_vector_index(force=force)
937
+ # Sync local state only when index was created or modified
938
+ if result:
939
+ self._has_vector_index = self._index_manager.has_vector_index
940
+ return result
941
+
942
+ def create_scalar_indexes(self) -> None:
943
+ """Create scalar indexes for frequently filtered columns. Delegates to IndexManager.
944
+
945
+ Raises:
946
+ StorageError: If index creation fails critically.
947
+ """
948
+ if self._index_manager is None:
949
+ raise StorageError("Database not connected")
950
+ self._index_manager.create_scalar_indexes()
951
+
952
+ def ensure_indexes(self, force: bool = False) -> dict[str, bool]:
953
+ """Ensure all appropriate indexes exist. Delegates to IndexManager.
954
+
955
+ Args:
956
+ force: Force index creation regardless of thresholds.
957
+
958
+ Returns:
959
+ Dict indicating which indexes were created.
960
+ """
961
+ if self._index_manager is None:
962
+ raise StorageError("Database not connected")
963
+ results = self._index_manager.ensure_indexes(force=force)
964
+ # Sync local state for backward compatibility
965
+ self._has_vector_index = self._index_manager.has_vector_index
966
+ self._has_fts_index = self._index_manager.has_fts_index
967
+ return results
968
+
969
+ # ========================================================================
970
+ # Maintenance & Optimization
971
+ # ========================================================================
972
+
973
+ def optimize(self) -> dict[str, Any]:
974
+ """Run optimization and maintenance tasks.
975
+
976
+ Performs:
977
+ - File compaction (merges small fragments)
978
+ - Index optimization
979
+
980
+ Returns:
981
+ Statistics about optimization performed.
982
+ """
983
+ try:
984
+ stats_before = self._get_table_stats()
985
+
986
+ # Compact small fragments
987
+ needs_compaction = stats_before.get("num_small_fragments", 0) > 10
988
+ if needs_compaction:
989
+ logger.info("Compacting fragments...")
990
+ self.table.compact_files()
991
+
992
+ # Optimize indexes
993
+ logger.info("Optimizing indexes...")
994
+ self.table.optimize()
995
+
996
+ stats_after = self._get_table_stats()
997
+
998
+ return {
999
+ "fragments_before": stats_before.get("num_fragments", 0),
1000
+ "fragments_after": stats_after.get("num_fragments", 0),
1001
+ "compaction_performed": needs_compaction,
1002
+ "total_rows": stats_after.get("num_rows", 0),
1003
+ }
1004
+
1005
+ except Exception as e:
1006
+ logger.error(f"Optimization failed: {e}")
1007
+ return {"error": str(e)}
1008
+
1009
+ def _get_table_stats(self) -> dict[str, Any]:
1010
+ """Get table statistics with best-effort fragment info."""
1011
+ try:
1012
+ count = self.table.count_rows()
1013
+ stats: dict[str, Any] = {
1014
+ "num_rows": count,
1015
+ "num_fragments": 0,
1016
+ "num_small_fragments": 0,
1017
+ }
1018
+
1019
+ # Try to get fragment stats from table.stats() if available
1020
+ try:
1021
+ if hasattr(self.table, "stats"):
1022
+ table_stats = self.table.stats()
1023
+ if isinstance(table_stats, dict):
1024
+ stats["num_fragments"] = table_stats.get("num_fragments", 0)
1025
+ stats["num_small_fragments"] = table_stats.get("num_small_fragments", 0)
1026
+ elif hasattr(table_stats, "num_fragments"):
1027
+ stats["num_fragments"] = table_stats.num_fragments
1028
+ stats["num_small_fragments"] = getattr(
1029
+ table_stats, "num_small_fragments", 0
1030
+ )
1031
+ except Exception as e:
1032
+ logger.debug(f"Could not get fragment stats: {e}")
1033
+
1034
+ return stats
1035
+ except Exception as e:
1036
+ logger.warning(f"Could not get table stats: {e}")
1037
+ return {}
1038
+
1039
+ @with_stale_connection_recovery
1040
+ def get_health_metrics(self) -> HealthMetrics:
1041
+ """Get comprehensive health and performance metrics.
1042
+
1043
+ Returns:
1044
+ HealthMetrics dataclass with all metrics.
1045
+ """
1046
+ try:
1047
+ count = self.table.count_rows()
1048
+
1049
+ # Estimate size (rough approximation)
1050
+ # vector (dim * 4 bytes) + avg content size estimate
1051
+ estimated_bytes = count * (self.embedding_dim * 4 + 1000)
1052
+
1053
+ # Check indexes
1054
+ indices: list[IndexStats] = []
1055
+ try:
1056
+ for idx in self.table.list_indices():
1057
+ indices.append(IndexStats(
1058
+ name=str(_get_index_attr(idx, "name", "unknown")),
1059
+ index_type=str(_get_index_attr(idx, "index_type", "unknown")),
1060
+ num_indexed_rows=count, # Approximate
1061
+ num_unindexed_rows=0,
1062
+ needs_update=False,
1063
+ ))
1064
+ except Exception as e:
1065
+ logger.warning(f"Could not get index stats: {e}")
1066
+
1067
+ return HealthMetrics(
1068
+ total_rows=count,
1069
+ total_bytes=estimated_bytes,
1070
+ total_bytes_mb=estimated_bytes / (1024 * 1024),
1071
+ num_fragments=0,
1072
+ num_small_fragments=0,
1073
+ needs_compaction=False,
1074
+ has_vector_index=self._has_vector_index or False,
1075
+ has_fts_index=self._has_fts_index or False,
1076
+ indices=indices,
1077
+ version=0,
1078
+ )
1079
+
1080
+ except Exception as e:
1081
+ return HealthMetrics(
1082
+ total_rows=0,
1083
+ total_bytes=0,
1084
+ total_bytes_mb=0,
1085
+ num_fragments=0,
1086
+ num_small_fragments=0,
1087
+ needs_compaction=False,
1088
+ has_vector_index=False,
1089
+ has_fts_index=False,
1090
+ indices=[],
1091
+ version=0,
1092
+ error=str(e),
1093
+ )
1094
+
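A sketch of reading the metrics, assuming `db` is a connected Database instance; sizes are the rough estimate described above, not exact on-disk usage.

metrics = db.get_health_metrics()
if metrics.error is None:
    print(f"{metrics.total_rows} rows, ~{metrics.total_bytes_mb:.1f} MB (estimated)")
    print(f"vector index: {metrics.has_vector_index}, FTS index: {metrics.has_fts_index}")
    for idx in metrics.indices:
        print(f"  {idx.name}: {idx.index_type}")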
1095
+ @with_process_lock
1096
+ @with_write_lock
1097
+ @retry_on_storage_error(max_attempts=3, backoff=0.5)
1098
+ def insert(
1099
+ self,
1100
+ content: str,
1101
+ vector: np.ndarray,
1102
+ namespace: str = "default",
1103
+ tags: list[str] | None = None,
1104
+ importance: float = 0.5,
1105
+ source: str = "manual",
1106
+ metadata: dict[str, Any] | None = None,
1107
+ ) -> str:
1108
+ """Insert a new memory.
1109
+
1110
+ Args:
1111
+ content: Text content of the memory.
1112
+ vector: Embedding vector.
1113
+ namespace: Namespace for organization.
1114
+ tags: List of tags.
1115
+ importance: Importance score (0-1).
1116
+ source: Source of the memory.
1117
+ metadata: Additional metadata.
1118
+
1119
+ Returns:
1120
+ The generated memory ID.
1121
+
1122
+ Raises:
1123
+ ValidationError: If input validation fails.
1124
+ StorageError: If database operation fails.
1125
+ """
1126
+ # Validate inputs
1127
+ namespace = _validate_namespace(namespace)
1128
+ tags = _validate_tags(tags)
1129
+ metadata = _validate_metadata(metadata)
1130
+ if not content or len(content) > 100000:
1131
+ raise ValidationError("Content must be between 1 and 100000 characters")
1132
+ if not 0.0 <= importance <= 1.0:
1133
+ raise ValidationError("Importance must be between 0.0 and 1.0")
1134
+
1135
+ # Validate vector dimensions
1136
+ if len(vector) != self.embedding_dim:
1137
+ raise DimensionMismatchError(
1138
+ expected_dim=self.embedding_dim,
1139
+ actual_dim=len(vector),
1140
+ )
1141
+
1142
+ memory_id = str(uuid.uuid4())
1143
+ now = utc_now()
1144
+
1145
+ # Calculate expires_at if default TTL is configured
1146
+ expires_at = None
1147
+ if self.default_memory_ttl_days is not None:
1148
+ expires_at = now + timedelta(days=self.default_memory_ttl_days)
1149
+
1150
+ record = {
1151
+ "id": memory_id,
1152
+ "content": content,
1153
+ "vector": vector.tolist(),
1154
+ "created_at": now,
1155
+ "updated_at": now,
1156
+ "last_accessed": now,
1157
+ "access_count": 0,
1158
+ "importance": importance,
1159
+ "namespace": namespace,
1160
+ "tags": tags,
1161
+ "source": source,
1162
+ "metadata": json.dumps(metadata),
1163
+ "expires_at": expires_at,
1164
+ }
1165
+
1166
+ try:
1167
+ self.table.add([record])
1168
+ self._invalidate_count_cache()
1169
+ self._track_modification()
1170
+ self._invalidate_namespace_cache()
1171
+ logger.debug(f"Inserted memory {memory_id}")
1172
+ return memory_id
1173
+ except Exception as e:
1174
+ raise StorageError(f"Failed to insert memory: {e}") from e
1175
+
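A minimal insert sketch, assuming `db` is a connected Database with the default 384-dimension schema; the random vector stands in for a real embedding.

vec = np.random.rand(384).astype(np.float32)
memory_id = db.insert(
    content="Prefer pathlib over os.path in new code",
    vector=vec,
    namespace="default",
    tags=["style", "python"],
    importance=0.7,
    metadata={"source_file": "CONTRIBUTING.md"},
)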
1176
+ # Maximum batch size to prevent memory exhaustion
1177
+ MAX_BATCH_SIZE = 10_000
1178
+
1179
+ @with_process_lock
1180
+ @with_write_lock
1181
+ @retry_on_storage_error(max_attempts=3, backoff=0.5)
1182
+ def insert_batch(
1183
+ self,
1184
+ records: list[dict[str, Any]],
1185
+ batch_size: int = 1000,
1186
+ atomic: bool = False,
1187
+ ) -> list[str]:
1188
+ """Insert multiple memories efficiently with batching.
1189
+
1190
+ Args:
1191
+ records: List of memory records with content, vector, and optional fields.
1192
+ batch_size: Records per batch (default: 1000, max: 10000).
1193
+ atomic: If True, rollback all inserts on partial failure.
1194
+ When atomic=True and a batch fails:
1195
+ - Attempts to delete already-inserted records
1196
+ - If rollback succeeds, raises the original StorageError
1197
+ - If rollback fails, raises PartialBatchInsertError with succeeded_ids
1198
+
1199
+ Returns:
1200
+ List of generated memory IDs.
1201
+
1202
+ Raises:
1203
+ ValidationError: If input validation fails or batch_size exceeds maximum.
1204
+ StorageError: If database operation fails (and rollback succeeds when atomic=True).
1205
+ PartialBatchInsertError: If atomic=True and rollback fails after partial insert.
1206
+ """
1207
+ if batch_size > self.MAX_BATCH_SIZE:
1208
+ raise ValidationError(
1209
+ f"batch_size ({batch_size}) exceeds maximum {self.MAX_BATCH_SIZE}"
1210
+ )
1211
+
1212
+ all_ids: list[str] = []
1213
+ total_requested = len(records)
1214
+
1215
+ # Process in batches for large inserts
1216
+ for batch_index, i in enumerate(range(0, len(records), batch_size)):
1217
+ batch = records[i:i + batch_size]
1218
+ now = utc_now()
1219
+ memory_ids: list[str] = []
1220
+ prepared_records: list[dict[str, Any]] = []
1221
+
1222
+ for record in batch:
1223
+ # Validate each record
1224
+ namespace = _validate_namespace(record.get("namespace", "default"))
1225
+ tags = _validate_tags(record.get("tags"))
1226
+ metadata = _validate_metadata(record.get("metadata"))
1227
+ content = record.get("content", "")
1228
+ if not content or len(content) > 100000:
1229
+ raise ValidationError("Content must be between 1 and 100000 characters")
1230
+
1231
+ importance = record.get("importance", 0.5)
1232
+ if not 0.0 <= importance <= 1.0:
1233
+ raise ValidationError("Importance must be between 0.0 and 1.0")
1234
+
1235
+ memory_id = str(uuid.uuid4())
1236
+ memory_ids.append(memory_id)
1237
+
1238
+ raw_vector = record["vector"]
1239
+ if isinstance(raw_vector, np.ndarray):
1240
+ vector_list = raw_vector.tolist()
1241
+ else:
1242
+ vector_list = raw_vector
1243
+
1244
+ # Validate vector dimensions
1245
+ if len(vector_list) != self.embedding_dim:
1246
+ raise DimensionMismatchError(
1247
+ expected_dim=self.embedding_dim,
1248
+ actual_dim=len(vector_list),
1249
+ record_index=i + len(memory_ids) - 1,
1250
+ )
1251
+
1252
+ # Calculate expires_at if default TTL is configured
1253
+ expires_at = None
1254
+ if self.default_memory_ttl_days is not None:
1255
+ expires_at = now + timedelta(days=self.default_memory_ttl_days)
1256
+
1257
+ prepared = {
1258
+ "id": memory_id,
1259
+ "content": content,
1260
+ "vector": vector_list,
1261
+ "created_at": now,
1262
+ "updated_at": now,
1263
+ "last_accessed": now,
1264
+ "access_count": 0,
1265
+ "importance": importance,
1266
+ "namespace": namespace,
1267
+ "tags": tags,
1268
+ "source": record.get("source", "manual"),
1269
+ "metadata": json.dumps(metadata),
1270
+ "expires_at": expires_at,
1271
+ }
1272
+ prepared_records.append(prepared)
1273
+
1274
+ try:
1275
+ self.table.add(prepared_records)
1276
+ all_ids.extend(memory_ids)
1277
+ self._invalidate_count_cache()
1278
+ self._track_modification(len(memory_ids))
1279
+ self._invalidate_namespace_cache()
1280
+ logger.debug(f"Inserted batch {batch_index + 1}: {len(memory_ids)} memories")
1281
+ except Exception as e:
1282
+ if atomic and all_ids:
1283
+ # Attempt rollback of previously inserted records
1284
+ logger.warning(
1285
+ f"Batch {batch_index + 1} failed, attempting rollback of "
1286
+ f"{len(all_ids)} previously inserted records"
1287
+ )
1288
+ rollback_error = self._rollback_batch_insert(all_ids)
1289
+ if rollback_error:
1290
+ # Rollback failed - raise PartialBatchInsertError
1291
+ raise PartialBatchInsertError(
1292
+ message=f"Batch insert failed and rollback also failed: {e}",
1293
+ succeeded_ids=all_ids,
1294
+ total_requested=total_requested,
1295
+ failed_batch_index=batch_index,
1296
+ ) from e
1297
+ else:
1298
+ # Rollback succeeded - raise original error
1299
+ logger.info(f"Rollback successful, deleted {len(all_ids)} records")
1300
+ raise StorageError(f"Failed to insert batch (rolled back): {e}") from e
1301
+ raise StorageError(f"Failed to insert batch: {e}") from e
1302
+
1303
+ # Check if we should create indexes after large insert
1304
+ if self.auto_create_indexes and len(all_ids) >= 1000:
1305
+ count = self._get_cached_row_count()
1306
+ if count >= self.vector_index_threshold and not self._has_vector_index:
1307
+ logger.info("Dataset crossed index threshold, creating indexes...")
1308
+ try:
1309
+ self.ensure_indexes()
1310
+ except Exception as e:
1311
+ logger.warning(f"Auto-index creation failed: {e}")
1312
+
1313
+ logger.debug(f"Inserted {len(all_ids)} memories total")
1314
+ return all_ids
1315
+
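A sketch of the atomic path, assuming `db` is connected: a successful rollback surfaces as StorageError, while a failed rollback surfaces as PartialBatchInsertError carrying the IDs that were left behind.

records = [
    {"content": f"note {i}", "vector": np.random.rand(384).astype(np.float32)}
    for i in range(5_000)
]
try:
    ids = db.insert_batch(records, batch_size=1000, atomic=True)
except PartialBatchInsertError as err:
    orphaned = err.succeeded_ids   # assumes the error exposes the IDs passed to it
except StorageError:
    pass                           # batch failed, but rollback cleaned up after itself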
1316
+ def _rollback_batch_insert(self, memory_ids: list[str]) -> Exception | None:
1317
+ """Attempt to delete records inserted during a failed batch operation.
1318
+
1319
+ Args:
1320
+ memory_ids: List of memory IDs to delete.
1321
+
1322
+ Returns:
1323
+ None if rollback succeeded, Exception if it failed.
1324
+ """
1325
+ try:
1326
+ if not memory_ids:
1327
+ return None
1328
+
1329
+ # Use delete_batch for efficient rollback
1330
+ id_list = ", ".join(f"'{_sanitize_string(mid)}'" for mid in memory_ids)
1331
+ self.table.delete(f"id IN ({id_list})")
1332
+ self._invalidate_count_cache()
1333
+ self._track_modification(len(memory_ids))
1334
+ self._invalidate_namespace_cache()
1335
+ logger.debug(f"Rolled back {len(memory_ids)} records")
1336
+ return None
1337
+ except Exception as e:
1338
+ logger.error(f"Rollback failed: {e}")
1339
+ return e
1340
+
1341
+ @with_stale_connection_recovery
1342
+ def get(self, memory_id: str) -> dict[str, Any]:
1343
+ """Get a memory by ID.
1344
+
1345
+ Args:
1346
+ memory_id: The memory ID.
1347
+
1348
+ Returns:
1349
+ The memory record.
1350
+
1351
+ Raises:
1352
+ ValidationError: If memory_id is invalid.
1353
+ MemoryNotFoundError: If memory doesn't exist.
1354
+ StorageError: If database operation fails.
1355
+ """
1356
+ # Validate and sanitize memory_id
1357
+ memory_id = _validate_uuid(memory_id)
1358
+ safe_id = _sanitize_string(memory_id)
1359
+
1360
+ try:
1361
+ results = self.table.search().where(f"id = '{safe_id}'").limit(1).to_list()
1362
+ if not results:
1363
+ raise MemoryNotFoundError(memory_id)
1364
+
1365
+ record: dict[str, Any] = results[0]
1366
+ record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
1367
+ return record
1368
+ except MemoryNotFoundError:
1369
+ raise
1370
+ except ValidationError:
1371
+ raise
1372
+ except Exception as e:
1373
+ raise StorageError(f"Failed to get memory: {e}") from e
1374
+
1375
+ def get_batch(self, memory_ids: list[str]) -> dict[str, dict[str, Any]]:
1376
+ """Get multiple memories by ID in a single query.
1377
+
1378
+ Args:
1379
+ memory_ids: List of memory UUIDs to retrieve.
1380
+
1381
+ Returns:
1382
+ Dict mapping memory_id to memory record. Missing IDs are not included.
1383
+
1384
+ Raises:
1385
+ ValidationError: If any memory_id format is invalid.
1386
+ StorageError: If database operation fails.
1387
+ """
1388
+ if not memory_ids:
1389
+ return {}
1390
+
1391
+ # Validate all IDs first
1392
+ validated_ids: list[str] = []
1393
+ for memory_id in memory_ids:
1394
+ try:
1395
+ validated_id = _validate_uuid(memory_id)
1396
+ validated_ids.append(_sanitize_string(validated_id))
1397
+ except Exception as e:
1398
+ logger.debug(f"Invalid memory ID {memory_id}: {e}")
1399
+ continue
1400
+
1401
+ if not validated_ids:
1402
+ return {}
1403
+
1404
+ try:
1405
+ # Batch fetch with single IN query
1406
+ id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
1407
+ results = self.table.search().where(f"id IN ({id_list})").to_list()
1408
+
1409
+ # Build result map
1410
+ result_map: dict[str, dict[str, Any]] = {}
1411
+ for record in results:
1412
+ # Deserialize metadata
1413
+ record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
1414
+ result_map[record["id"]] = record
1415
+
1416
+ return result_map
1417
+ except Exception as e:
1418
+ raise StorageError(f"Failed to batch get memories: {e}") from e
1419
+
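A short lookup sketch, assuming `db` is connected and `id_a`/`id_b` are existing memory IDs; anything invalid or missing is simply absent from the result.

found = db.get_batch([id_a, id_b, "not-a-uuid"])
missing = [mid for mid in (id_a, id_b) if mid not in found]
for mid, record in found.items():
    print(mid, record["namespace"], record["metadata"])  # metadata is already a dict here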
1420
+ @with_process_lock
1421
+ @with_write_lock
1422
+ def update(self, memory_id: str, updates: dict[str, Any]) -> None:
1423
+ """Update a memory using atomic merge_insert.
1424
+
1425
+ Uses LanceDB's merge_insert API for atomic upserts, eliminating
1426
+ race conditions from delete-then-insert patterns.
1427
+
1428
+ Args:
1429
+ memory_id: The memory ID.
1430
+ updates: Fields to update.
1431
+
1432
+ Raises:
1433
+ ValidationError: If input validation fails.
1434
+ MemoryNotFoundError: If memory doesn't exist.
1435
+ StorageError: If database operation fails.
1436
+ """
1437
+ # Validate memory_id
1438
+ memory_id = _validate_uuid(memory_id)
1439
+
1440
+ # First verify the memory exists
1441
+ existing = self.get(memory_id)
1442
+
1443
+ # Prepare updates
1444
+ updates["updated_at"] = utc_now()
1445
+ if "metadata" in updates and isinstance(updates["metadata"], dict):
1446
+ updates["metadata"] = json.dumps(updates["metadata"])
1447
+ if "vector" in updates and isinstance(updates["vector"], np.ndarray):
1448
+ updates["vector"] = updates["vector"].tolist()
1449
+
1450
+ # Merge existing with updates
1451
+ for key, value in updates.items():
1452
+ existing[key] = value
1453
+
1454
+ # Ensure metadata is serialized as JSON string for storage
1455
+ if isinstance(existing.get("metadata"), dict):
1456
+ existing["metadata"] = json.dumps(existing["metadata"])
1457
+
1458
+ # Ensure vector is a list, not numpy array
1459
+ if isinstance(existing.get("vector"), np.ndarray):
1460
+ existing["vector"] = existing["vector"].tolist()
1461
+
1462
+ try:
1463
+ # Atomic upsert using merge_insert
1464
+ # Requires BTREE index on 'id' column (created in create_scalar_indexes)
1465
+ (
1466
+ self.table.merge_insert("id")
1467
+ .when_matched_update_all()
1468
+ .when_not_matched_insert_all()
1469
+ .execute([existing])
1470
+ )
1471
+ logger.debug(f"Updated memory {memory_id} (atomic merge_insert)")
1472
+ except Exception as e:
1473
+ raise StorageError(f"Failed to update memory: {e}") from e
1474
+
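A sketch of a typical update, assuming `db` is connected and `memory_id` exists; metadata dicts and numpy vectors are serialized on the way in, and updated_at is set automatically.

db.update(
    memory_id,
    {
        "importance": 0.9,
        "tags": ["style", "python", "reviewed"],
        "metadata": {"reviewed_by": "alice"},
    },
)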
1475
+ @with_process_lock
1476
+ @with_write_lock
1477
+ def update_batch(
1478
+ self, updates: list[tuple[str, dict[str, Any]]]
1479
+ ) -> tuple[int, list[str]]:
1480
+ """Update multiple memories using atomic merge_insert.
1481
+
1482
+ Args:
1483
+ updates: List of (memory_id, updates_dict) tuples.
1484
+
1485
+ Returns:
1486
+ Tuple of (success_count, list of failed memory_ids).
1487
+
1488
+ Raises:
1489
+ StorageError: If database operation fails completely.
1490
+ """
1491
+ if not updates:
1492
+ return 0, []
1493
+
1494
+ now = utc_now()
1495
+ records_to_update: list[dict[str, Any]] = []
1496
+ failed_ids: list[str] = []
1497
+
1498
+ # Validate all IDs and collect them
1499
+ validated_updates: list[tuple[str, dict[str, Any]]] = []
1500
+ for memory_id, update_dict in updates:
1501
+ try:
1502
+ validated_id = _validate_uuid(memory_id)
1503
+ validated_updates.append((_sanitize_string(validated_id), update_dict))
1504
+ except Exception as e:
1505
+ logger.debug(f"Invalid memory ID {memory_id}: {e}")
1506
+ failed_ids.append(memory_id)
1507
+
1508
+ if not validated_updates:
1509
+ return 0, failed_ids
1510
+
1511
+ # Batch fetch all records
1512
+ validated_ids = [vid for vid, _ in validated_updates]
1513
+ try:
1514
+ id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
1515
+ all_records = self.table.search().where(f"id IN ({id_list})").to_list()
1516
+ except Exception as e:
1517
+ logger.error(f"Failed to batch fetch records for update: {e}")
1518
+ raise StorageError(f"Failed to batch fetch for update: {e}") from e
1519
+
1520
+ # Build lookup map
1521
+ record_map: dict[str, dict[str, Any]] = {}
1522
+ for record in all_records:
1523
+ record_map[record["id"]] = record
1524
+
1525
+ # Apply updates to found records
1526
+ update_dict_map = dict(validated_updates)
1527
+ for memory_id in validated_ids:
1528
+ if memory_id not in record_map:
1529
+ logger.debug(f"Memory {memory_id} not found for batch update")
1530
+ failed_ids.append(memory_id)
1531
+ continue
1532
+
1533
+ record = record_map[memory_id]
1534
+ update_dict = update_dict_map[memory_id]
1535
+
1536
+ # Apply updates
1537
+ record["updated_at"] = now
1538
+ for key, value in update_dict.items():
1539
+ if key == "metadata" and isinstance(value, dict):
1540
+ record[key] = json.dumps(value)
1541
+ elif key == "vector" and isinstance(value, np.ndarray):
1542
+ record[key] = value.tolist()
1543
+ else:
1544
+ record[key] = value
1545
+
1546
+ # Ensure metadata is serialized
1547
+ if isinstance(record.get("metadata"), dict):
1548
+ record["metadata"] = json.dumps(record["metadata"])
1549
+
1550
+ # Ensure vector is a list
1551
+ if isinstance(record.get("vector"), np.ndarray):
1552
+ record["vector"] = record["vector"].tolist()
1553
+
1554
+ records_to_update.append(record)
1555
+
1556
+ if not records_to_update:
1557
+ return 0, failed_ids
1558
+
1559
+ try:
1560
+ # Atomic batch upsert
1561
+ (
1562
+ self.table.merge_insert("id")
1563
+ .when_matched_update_all()
1564
+ .when_not_matched_insert_all()
1565
+ .execute(records_to_update)
1566
+ )
1567
+ success_count = len(records_to_update)
1568
+ logger.debug(
1569
+ f"Batch updated {success_count}/{len(updates)} memories "
1570
+ "(atomic merge_insert)"
1571
+ )
1572
+ return success_count, failed_ids
1573
+ except Exception as e:
1574
+ logger.error(f"Failed to batch update: {e}")
1575
+ raise StorageError(f"Failed to batch update: {e}") from e
1576
+
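update_batch() reports partial failures instead of raising for them: invalid or missing IDs come back in the second element of the returned tuple. A sketch under the same assumptions (`db` connected, `id_a`/`id_b` are placeholder UUID strings):

    updates = [
        (id_a, {"importance": 0.9}),
        (id_b, {"tags": ["archived"], "metadata": {"reason": "stale"}}),
    ]
    success_count, failed_ids = db.update_batch(updates)
    if failed_ids:
        # invalid UUIDs and IDs not present in the table end up here
        print(f"{success_count} updated, {len(failed_ids)} skipped: {failed_ids}")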
1577
+ @with_process_lock
1578
+ @with_write_lock
1579
+ def delete(self, memory_id: str) -> None:
1580
+ """Delete a memory.
1581
+
1582
+ Args:
1583
+ memory_id: The memory ID.
1584
+
1585
+ Raises:
1586
+ ValidationError: If memory_id is invalid.
1587
+ MemoryNotFoundError: If memory doesn't exist.
1588
+ StorageError: If database operation fails.
1589
+ """
1590
+ # Validate memory_id
1591
+ memory_id = _validate_uuid(memory_id)
1592
+ safe_id = _sanitize_string(memory_id)
1593
+
1594
+ # First verify the memory exists
1595
+ self.get(memory_id)
1596
+
1597
+ try:
1598
+ self.table.delete(f"id = '{safe_id}'")
1599
+ self._invalidate_count_cache()
1600
+ self._track_modification()
1601
+ self._invalidate_namespace_cache()
1602
+ logger.debug(f"Deleted memory {memory_id}")
1603
+ except Exception as e:
1604
+ raise StorageError(f"Failed to delete memory: {e}") from e
1605
+
1606
+ @with_process_lock
1607
+ @with_write_lock
1608
+ def delete_by_namespace(self, namespace: str) -> int:
1609
+ """Delete all memories in a namespace.
1610
+
1611
+ Args:
1612
+ namespace: The namespace to delete.
1613
+
1614
+ Returns:
1615
+ Number of deleted records.
1616
+
1617
+ Raises:
1618
+ ValidationError: If namespace is invalid.
1619
+ StorageError: If database operation fails.
1620
+ """
1621
+ namespace = _validate_namespace(namespace)
1622
+ safe_ns = _sanitize_string(namespace)
1623
+
1624
+ try:
1625
+ count_before: int = self.table.count_rows()
1626
+ self.table.delete(f"namespace = '{safe_ns}'")
1627
+ self._invalidate_count_cache()
1628
+ self._track_modification()
1629
+ self._invalidate_namespace_cache()
1630
+ count_after: int = self.table.count_rows()
1631
+ deleted = count_before - count_after
1632
+ logger.debug(f"Deleted {deleted} memories in namespace '{namespace}'")
1633
+ return deleted
1634
+ except Exception as e:
1635
+ raise StorageError(f"Failed to delete by namespace: {e}") from e
1636
+
1637
+ @with_process_lock
1638
+ @with_write_lock
1639
+ def clear_all(self, reset_indexes: bool = True) -> int:
1640
+ """Clear all memories from the database.
1641
+
1642
+ This is primarily for testing purposes to reset database state
1643
+ between tests while maintaining the connection.
1644
+
1645
+ Args:
1646
+ reset_indexes: If True, also reset index tracking flags.
1647
+ This allows tests to verify index creation behavior.
1648
+
1649
+ Returns:
1650
+ Number of deleted records.
1651
+
1652
+ Raises:
1653
+ StorageError: If database operation fails.
1654
+ """
1655
+ try:
1656
+ count: int = self.table.count_rows()
1657
+ if count > 0:
1658
+ # Delete all rows - use simpler predicate that definitely matches
1659
+ self.table.delete("true")
1660
+
1661
+ # Verify deletion worked
1662
+ remaining = self.table.count_rows()
1663
+ if remaining > 0:
1664
+ logger.warning(
1665
+ f"clear_all: {remaining} records remain after delete, "
1666
+ f"attempting cleanup again"
1667
+ )
1668
+ # Try alternative delete approach
1669
+ self.table.delete("id IS NOT NULL")
1670
+
1671
+ self._invalidate_count_cache()
1672
+ self._track_modification()
1673
+ self._invalidate_namespace_cache()
1674
+
1675
+ # Reset index tracking flags for test isolation
1676
+ if reset_indexes:
1677
+ self._has_vector_index = None
1678
+ self._has_fts_index = False
1679
+ self._has_scalar_indexes = False
1680
+
1681
+ logger.debug(f"Cleared all {count} memories from database")
1682
+ return count
1683
+ except Exception as e:
1684
+ raise StorageError(f"Failed to clear all memories: {e}") from e
1685
+
1686
+ @with_process_lock
1687
+ @with_write_lock
1688
+ def rename_namespace(self, old_namespace: str, new_namespace: str) -> int:
1689
+ """Rename all memories from one namespace to another.
1690
+
1691
+ Uses atomic batch update via merge_insert for data integrity.
1692
+ On partial failure, attempts to roll back renamed records to the original namespace.
1693
+
1694
+ Args:
1695
+ old_namespace: Source namespace name.
1696
+ new_namespace: Target namespace name.
1697
+
1698
+ Returns:
1699
+ Number of memories renamed.
1700
+
1701
+ Raises:
1702
+ ValidationError: If namespace names are invalid.
1703
+ NamespaceNotFoundError: If old_namespace doesn't exist.
1704
+ StorageError: If database operation fails.
1705
+ """
1706
+ from spatial_memory.core.errors import NamespaceNotFoundError
1707
+
1708
+ old_namespace = _validate_namespace(old_namespace)
1709
+ new_namespace = _validate_namespace(new_namespace)
1710
+ safe_old = _sanitize_string(old_namespace)
1711
+ _sanitize_string(new_namespace) # Validate but don't store unused result
1712
+
1713
+ try:
1714
+ # Check if source namespace exists
1715
+ existing = self.get_namespaces()
1716
+ if old_namespace not in existing:
1717
+ raise NamespaceNotFoundError(old_namespace)
1718
+
1719
+ # Short-circuit if renaming to same namespace (no-op)
1720
+ if old_namespace == new_namespace:
1721
+ count = self.count(namespace=old_namespace)
1722
+ logger.debug(f"Namespace '{old_namespace}' renamed to itself ({count} records)")
1723
+ return count
1724
+
1725
+ # Track renamed IDs for rollback capability
1726
+ renamed_ids: list[str] = []
1727
+
1728
+ # Fetch all records in batches with iteration safeguards
1729
+ batch_size = 1000
1730
+ max_iterations = 10000 # Safety cap: 10M records at 1000/batch
1731
+ updated = 0
1732
+ iteration = 0
1733
+ previous_updated = 0
1734
+
1735
+ while True:
1736
+ iteration += 1
1737
+
1738
+ # Safety limit to prevent infinite loops
1739
+ if iteration > max_iterations:
1740
+ raise StorageError(
1741
+ f"rename_namespace exceeded maximum iterations ({max_iterations}). "
1742
+ f"Updated {updated} records before stopping. "
1743
+ "This may indicate a database consistency issue."
1744
+ )
1745
+
1746
+ records = (
1747
+ self.table.search()
1748
+ .where(f"namespace = '{safe_old}'")
1749
+ .limit(batch_size)
1750
+ .to_list()
1751
+ )
1752
+
1753
+ if not records:
1754
+ break
1755
+
1756
+ # Track IDs in this batch for potential rollback
1757
+ batch_ids = [r["id"] for r in records]
1758
+
1759
+ # Update namespace field
1760
+ for r in records:
1761
+ r["namespace"] = new_namespace
1762
+ r["updated_at"] = utc_now()
1763
+ if isinstance(r.get("metadata"), dict):
1764
+ r["metadata"] = json.dumps(r["metadata"])
1765
+ if isinstance(r.get("vector"), np.ndarray):
1766
+ r["vector"] = r["vector"].tolist()
1767
+
1768
+ try:
1769
+ # Atomic upsert
1770
+ (
1771
+ self.table.merge_insert("id")
1772
+ .when_matched_update_all()
1773
+ .when_not_matched_insert_all()
1774
+ .execute(records)
1775
+ )
1776
+ # Only track as renamed after successful update
1777
+ renamed_ids.extend(batch_ids)
1778
+ except Exception as batch_error:
1779
+ # Batch failed - attempt rollback of previously renamed records
1780
+ if renamed_ids:
1781
+ logger.warning(
1782
+ f"Batch {iteration} failed, attempting rollback of "
1783
+ f"{len(renamed_ids)} previously renamed records"
1784
+ )
1785
+ rollback_error = self._rollback_namespace_rename(
1786
+ renamed_ids, old_namespace
1787
+ )
1788
+ if rollback_error:
1789
+ raise StorageError(
1790
+ f"Namespace rename failed at batch {iteration} and "
1791
+ f"rollback also failed. {len(renamed_ids)} records may be "
1792
+ f"in inconsistent state (partially in '{new_namespace}'). "
1793
+ f"Original error: {batch_error}. Rollback error: {rollback_error}"
1794
+ ) from batch_error
1795
+ else:
1796
+ logger.info(
1797
+ f"Rollback successful, reverted {len(renamed_ids)} records "
1798
+ f"back to namespace '{old_namespace}'"
1799
+ )
1800
+ raise StorageError(
1801
+ f"Failed to rename namespace (rolled back): {batch_error}"
1802
+ ) from batch_error
1803
+
1804
+ updated += len(records)
1805
+
1806
+ # Detect stalled progress (same batch being processed repeatedly)
1807
+ if updated == previous_updated:
1808
+ raise StorageError(
1809
+ f"rename_namespace stalled at {updated} records. "
1810
+ "merge_insert may have failed silently."
1811
+ )
1812
+ previous_updated = updated
1813
+
1814
+ self._invalidate_namespace_cache()
1815
+ logger.debug(
1816
+ f"Renamed {updated} memories from '{old_namespace}' to '{new_namespace}'"
1817
+ )
1818
+ return updated
1819
+
1820
+ except (ValidationError, NamespaceNotFoundError):
1821
+ raise
1822
+ except Exception as e:
1823
+ raise StorageError(f"Failed to rename namespace: {e}") from e
1824
+
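Because rename_namespace() validates the source namespace first, callers can distinguish a missing namespace from a storage failure. A usage sketch (namespace names are placeholders; the error import mirrors the one used inside the method):

    from spatial_memory.core.errors import NamespaceNotFoundError

    try:
        moved = db.rename_namespace("project-x", "project-x-archived")
        print(f"moved {moved} memories")
    except NamespaceNotFoundError:
        print("source namespace does not exist")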
1825
+ def _rollback_namespace_rename(
1826
+ self, memory_ids: list[str], target_namespace: str
1827
+ ) -> Exception | None:
1828
+ """Attempt to revert renamed records back to original namespace.
1829
+
1830
+ Args:
1831
+ memory_ids: List of memory IDs to revert.
1832
+ target_namespace: Namespace to revert records to.
1833
+
1834
+ Returns:
1835
+ None if rollback succeeded, Exception if it failed.
1836
+ """
1837
+ try:
1838
+ if not memory_ids:
1839
+ return None
1840
+
1841
+ _sanitize_string(target_namespace) # Validate namespace
1842
+ now = utc_now()
1843
+
1844
+ # Process in batches for large rollbacks
1845
+ batch_size = 1000
1846
+ for i in range(0, len(memory_ids), batch_size):
1847
+ batch_ids = memory_ids[i:i + batch_size]
1848
+ id_list = ", ".join(f"'{_sanitize_string(mid)}'" for mid in batch_ids)
1849
+
1850
+ # Fetch records that need rollback
1851
+ records = (
1852
+ self.table.search()
1853
+ .where(f"id IN ({id_list})")
1854
+ .to_list()
1855
+ )
1856
+
1857
+ if not records:
1858
+ continue
1859
+
1860
+ # Revert namespace
1861
+ for r in records:
1862
+ r["namespace"] = target_namespace
1863
+ r["updated_at"] = now
1864
+ if isinstance(r.get("metadata"), dict):
1865
+ r["metadata"] = json.dumps(r["metadata"])
1866
+ if isinstance(r.get("vector"), np.ndarray):
1867
+ r["vector"] = r["vector"].tolist()
1868
+
1869
+ # Atomic upsert to restore original namespace
1870
+ (
1871
+ self.table.merge_insert("id")
1872
+ .when_matched_update_all()
1873
+ .when_not_matched_insert_all()
1874
+ .execute(records)
1875
+ )
1876
+
1877
+ self._invalidate_namespace_cache()
1878
+ logger.debug(f"Rolled back {len(memory_ids)} records to namespace '{target_namespace}'")
1879
+ return None
1880
+
1881
+ except Exception as e:
1882
+ logger.error(f"Namespace rename rollback failed: {e}")
1883
+ return e
1884
+
1885
+ @with_stale_connection_recovery
1886
+ def get_stats(self, namespace: str | None = None) -> dict[str, Any]:
1887
+ """Get comprehensive database statistics.
1888
+
1889
+ Uses efficient LanceDB queries for aggregations.
1890
+
1891
+ Args:
1892
+ namespace: Filter stats to specific namespace (None = all).
1893
+
1894
+ Returns:
1895
+ Dictionary with statistics including:
1896
+ - total_memories: Total count of memories
1897
+ - namespaces: Dict mapping namespace to count
1898
+ - storage_bytes: Total storage size in bytes
1899
+ - storage_mb: Total storage size in megabytes
1900
+ - has_vector_index: Whether vector index exists
1901
+ - has_fts_index: Whether full-text search index exists
1902
+ - num_fragments: Number of storage fragments
1903
+ - needs_compaction: Whether compaction is recommended
1904
+ - table_version: Current table version number
1905
+ - indices: List of index information dicts
1906
+
1907
+ Raises:
1908
+ ValidationError: If namespace is invalid.
1909
+ StorageError: If database operation fails.
1910
+ """
1911
+ try:
1912
+ metrics = self.get_health_metrics()
1913
+
1914
+ # Count memories per namespace from the namespace column
1915
+ # (fetched via Arrow, tallied in Python; no pandas dependency)
1916
+ ns_arrow = self.table.search().select(["namespace"]).to_arrow()
1917
+
1918
+ # Count by namespace using Arrow's to_pylist()
1919
+ ns_counts: dict[str, int] = {}
1920
+ for record in ns_arrow.to_pylist():
1921
+ ns = record["namespace"]
1922
+ ns_counts[ns] = ns_counts.get(ns, 0) + 1
1923
+
1924
+ # Filter if namespace specified
1925
+ if namespace:
1926
+ namespace = _validate_namespace(namespace)
1927
+ if namespace in ns_counts:
1928
+ ns_counts = {namespace: ns_counts[namespace]}
1929
+ else:
1930
+ ns_counts = {}
1931
+
1932
+ total = sum(ns_counts.values()) if ns_counts else 0
1933
+
1934
+ return {
1935
+ "total_memories": total if namespace else metrics.total_rows,
1936
+ "namespaces": ns_counts,
1937
+ "storage_bytes": metrics.total_bytes,
1938
+ "storage_mb": metrics.total_bytes_mb,
1939
+ "num_fragments": metrics.num_fragments,
1940
+ "needs_compaction": metrics.needs_compaction,
1941
+ "has_vector_index": metrics.has_vector_index,
1942
+ "has_fts_index": metrics.has_fts_index,
1943
+ "table_version": metrics.version,
1944
+ "indices": [
1945
+ {
1946
+ "name": idx.name,
1947
+ "index_type": idx.index_type,
1948
+ "num_indexed_rows": idx.num_indexed_rows,
1949
+ "status": "ready" if not idx.needs_update else "needs_update",
1950
+ }
1951
+ for idx in metrics.indices
1952
+ ],
1953
+ }
1954
+ except ValidationError:
1955
+ raise
1956
+ except Exception as e:
1957
+ raise StorageError(f"Failed to get stats: {e}") from e
1958
+
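The flat dictionary returned by get_stats() is convenient for health dashboards or maintenance scheduling. A sketch of reading it, with field names taken from the docstring above and `db` still a placeholder instance:

    stats = db.get_stats()  # or db.get_stats(namespace="project-x")
    print(f"{stats['total_memories']} memories, {stats['storage_mb']:.1f} MB")
    for name, count in sorted(stats["namespaces"].items()):
        print(f"  {name}: {count}")
    if stats["needs_compaction"]:
        pass  # a good moment to trigger whatever compaction/optimize routine is configured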
1959
+ def get_namespace_stats(self, namespace: str) -> dict[str, Any]:
1960
+ """Get statistics for a specific namespace.
1961
+
1962
+ Args:
1963
+ namespace: The namespace to get statistics for.
1964
+
1965
+ Returns:
1966
+ Dictionary containing:
1967
+ - namespace: The namespace name
1968
+ - memory_count: Number of memories in namespace
1969
+ - oldest_memory: Datetime of oldest memory (or None)
1970
+ - newest_memory: Datetime of newest memory (or None)
1971
+ - avg_content_length: Average content length (or None if empty)
1972
+
1973
+ Raises:
1974
+ ValidationError: If namespace is invalid.
1975
+ StorageError: If database operation fails.
1976
+ """
1977
+ namespace = _validate_namespace(namespace)
1978
+ safe_ns = _sanitize_string(namespace)
1979
+
1980
+ try:
1981
+ # Get count efficiently
1982
+ filter_expr = f"namespace = '{safe_ns}'"
1983
+ count_results = (
1984
+ self.table.search()
1985
+ .where(filter_expr)
1986
+ .select(["id"])
1987
+ .limit(1000000) # High limit to count all
1988
+ .to_list()
1989
+ )
1990
+ memory_count = len(count_results)
1991
+
1992
+ if memory_count == 0:
1993
+ return {
1994
+ "namespace": namespace,
1995
+ "memory_count": 0,
1996
+ "oldest_memory": None,
1997
+ "newest_memory": None,
1998
+ "avg_content_length": None,
1999
+ }
2000
+
2001
+ # Oldest-memory estimate: take the first matching record's created_at (no explicit sort is applied)
2002
+ oldest_records = (
2003
+ self.table.search()
2004
+ .where(filter_expr)
2005
+ .select(["created_at"])
2006
+ .limit(1)
2007
+ .to_list()
2008
+ )
2009
+ oldest = oldest_records[0]["created_at"] if oldest_records else None
2010
+
2011
+ # Get newest memory - need to fetch more and find max since LanceDB
2012
+ # doesn't support ORDER BY DESC efficiently
2013
+ # Sample up to 1000 records for stats to avoid loading everything
2014
+ sample_size = min(memory_count, 1000)
2015
+ sample_records = (
2016
+ self.table.search()
2017
+ .where(filter_expr)
2018
+ .select(["created_at", "content"])
2019
+ .limit(sample_size)
2020
+ .to_list()
2021
+ )
2022
+
2023
+ # Find newest from sample (for large namespaces this is approximate)
2024
+ if sample_records:
2025
+ created_times = [r["created_at"] for r in sample_records]
2026
+ newest = max(created_times)
2027
+ # Calculate average content length from sample
2028
+ content_lengths = [len(r.get("content", "")) for r in sample_records]
2029
+ avg_content_length = sum(content_lengths) / len(content_lengths)
2030
+ else:
2031
+ newest = oldest
2032
+ avg_content_length = None
2033
+
2034
+ return {
2035
+ "namespace": namespace,
2036
+ "memory_count": memory_count,
2037
+ "oldest_memory": oldest,
2038
+ "newest_memory": newest,
2039
+ "avg_content_length": avg_content_length,
2040
+ }
2041
+
2042
+ except ValidationError:
2043
+ raise
2044
+ except Exception as e:
2045
+ raise StorageError(f"Failed to get namespace stats: {e}") from e
2046
+
2047
+ def get_all_for_export(
2048
+ self,
2049
+ namespace: str | None = None,
2050
+ batch_size: int = 1000,
2051
+ ) -> Generator[list[dict[str, Any]], None, None]:
2052
+ """Stream all memories for export in batches.
2053
+
2054
+ Memory-efficient export using generator pattern.
2055
+
2056
+ Args:
2057
+ namespace: Optional namespace filter.
2058
+ batch_size: Records per batch.
2059
+
2060
+ Yields:
2061
+ Batches of memory dictionaries.
2062
+
2063
+ Raises:
2064
+ ValidationError: If namespace is invalid.
2065
+ StorageError: If database operation fails.
2066
+ """
2067
+ try:
2068
+ search = self.table.search()
2069
+
2070
+ if namespace is not None:
2071
+ namespace = _validate_namespace(namespace)
2072
+ safe_ns = _sanitize_string(namespace)
2073
+ search = search.where(f"namespace = '{safe_ns}'")
2074
+
2075
+ # Use Arrow for efficient streaming
2076
+ arrow_table = search.to_arrow()
2077
+ records = arrow_table.to_pylist()
2078
+
2079
+ # Yield in batches
2080
+ for i in range(0, len(records), batch_size):
2081
+ batch = records[i : i + batch_size]
2082
+
2083
+ # Process metadata
2084
+ for record in batch:
2085
+ if isinstance(record.get("metadata"), str):
2086
+ try:
2087
+ record["metadata"] = json.loads(record["metadata"])
2088
+ except json.JSONDecodeError:
2089
+ record["metadata"] = {}
2090
+
2091
+ yield batch
2092
+
2093
+ except ValidationError:
2094
+ raise
2095
+ except Exception as e:
2096
+ raise StorageError(f"Failed to stream export: {e}") from e
2097
+
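Since the generator yields plain dictionaries batch by batch, it pairs naturally with line-oriented formats. A sketch that writes JSONL, assuming `db` is a connected instance (the output path and namespace are arbitrary placeholders):

    import json

    with open("memories.jsonl", "w", encoding="utf-8") as fh:
        for batch in db.get_all_for_export(namespace="project-x", batch_size=500):
            for record in batch:
                record.pop("vector", None)  # drop embeddings if only text is needed
                fh.write(json.dumps(record, default=str) + "\n")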
2098
+ @with_process_lock
2099
+ @with_write_lock
2100
+ def bulk_import(
2101
+ self,
2102
+ records: Iterator[dict[str, Any]],
2103
+ batch_size: int = 1000,
2104
+ namespace_override: str | None = None,
2105
+ ) -> tuple[int, list[str]]:
2106
+ """Import memories from an iterator of records.
2107
+
2108
+ Supports streaming import for large datasets.
2109
+
2110
+ Args:
2111
+ records: Iterator of memory dictionaries.
2112
+ batch_size: Records per database insert batch.
2113
+ namespace_override: Override namespace for all records.
2114
+
2115
+ Returns:
2116
+ Tuple of (records_imported, list_of_new_ids).
2117
+
2118
+ Raises:
2119
+ ValidationError: If records contain invalid data.
2120
+ StorageError: If database operation fails.
2121
+ """
2122
+ if namespace_override is not None:
2123
+ namespace_override = _validate_namespace(namespace_override)
2124
+
2125
+ imported = 0
2126
+ all_ids: list[str] = []
2127
+ batch: list[dict[str, Any]] = []
2128
+
2129
+ try:
2130
+ for record in records:
2131
+ prepared = self._prepare_import_record(record, namespace_override)
2132
+ batch.append(prepared)
2133
+
2134
+ if len(batch) >= batch_size:
2135
+ ids = self.insert_batch(batch, batch_size=batch_size)
2136
+ all_ids.extend(ids)
2137
+ imported += len(ids)
2138
+ batch = []
2139
+
2140
+ # Import remaining
2141
+ if batch:
2142
+ ids = self.insert_batch(batch, batch_size=batch_size)
2143
+ all_ids.extend(ids)
2144
+ imported += len(ids)
2145
+
2146
+ return imported, all_ids
2147
+
2148
+ except (ValidationError, StorageError):
2149
+ raise
2150
+ except Exception as e:
2151
+ raise StorageError(f"Bulk import failed: {e}") from e
2152
+
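bulk_import() consumes the iterator lazily and flushes every batch_size records, so a generator over a file keeps memory usage flat regardless of dataset size. A sketch that reads the JSONL produced above (path and namespace are placeholders; list-valued vectors are converted internally by _prepare_import_record()):

    import json

    def read_records(path: str):
        with open(path, encoding="utf-8") as fh:
            for line in fh:
                yield json.loads(line)

    imported, new_ids = db.bulk_import(
        read_records("memories.jsonl"),
        batch_size=500,
        namespace_override="imported",
    )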
2153
+ def _prepare_import_record(
2154
+ self,
2155
+ record: dict[str, Any],
2156
+ namespace_override: str | None = None,
2157
+ ) -> dict[str, Any]:
2158
+ """Prepare a record for import.
2159
+
2160
+ Args:
2161
+ record: The raw record from import file.
2162
+ namespace_override: Optional namespace override.
2163
+
2164
+ Returns:
2165
+ Prepared record suitable for insert_batch.
2166
+ """
2167
+ # Required fields
2168
+ content = record.get("content", "")
2169
+ vector = record.get("vector", [])
2170
+
2171
+ # Convert vector to numpy if needed
2172
+ if isinstance(vector, list):
2173
+ vector = np.array(vector, dtype=np.float32)
2174
+
2175
+ # Get namespace (override if specified)
2176
+ namespace = namespace_override or record.get("namespace", "default")
2177
+
2178
+ # Optional fields with defaults
2179
+ tags = record.get("tags", [])
2180
+ importance = record.get("importance", 0.5)
2181
+ source = record.get("source", "import")
2182
+ metadata = record.get("metadata", {})
2183
+
2184
+ return {
2185
+ "content": content,
2186
+ "vector": vector,
2187
+ "namespace": namespace,
2188
+ "tags": tags,
2189
+ "importance": importance,
2190
+ "source": source,
2191
+ "metadata": metadata,
2192
+ }
2193
+
2194
+ @with_process_lock
2195
+ @with_write_lock
2196
+ def delete_batch(self, memory_ids: list[str]) -> tuple[int, list[str]]:
2197
+ """Delete multiple memories atomically using IN clause.
2198
+
2199
+ Args:
2200
+ memory_ids: List of memory UUIDs to delete.
2201
+
2202
+ Returns:
2203
+ Tuple of (count_deleted, list_of_deleted_ids) where:
2204
+ - count_deleted: Number of memories actually deleted
2205
+ - list_of_deleted_ids: IDs that were actually deleted
2206
+
2207
+ Raises:
2208
+ ValidationError: If any memory_id is invalid.
2209
+ StorageError: If database operation fails.
2210
+ """
2211
+ if not memory_ids:
2212
+ return (0, [])
2213
+
2214
+ # Validate all IDs first (fail fast)
2215
+ validated_ids: list[str] = []
2216
+ for memory_id in memory_ids:
2217
+ validated_id = _validate_uuid(memory_id)
2218
+ validated_ids.append(_sanitize_string(validated_id))
2219
+
2220
+ try:
2221
+ # First, check which IDs actually exist
2222
+ id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
2223
+ filter_expr = f"id IN ({id_list})"
2224
+ existing_records = (
2225
+ self.table.search()
2226
+ .where(filter_expr)
2227
+ .select(["id"])
2228
+ .limit(len(validated_ids))
2229
+ .to_list()
2230
+ )
2231
+ existing_ids = [r["id"] for r in existing_records]
2232
+
2233
+ if not existing_ids:
2234
+ return (0, [])
2235
+
2236
+ # Delete only existing IDs
2237
+ existing_id_list = ", ".join(f"'{mid}'" for mid in existing_ids)
2238
+ delete_expr = f"id IN ({existing_id_list})"
2239
+ self.table.delete(delete_expr)
2240
+
2241
+ self._invalidate_count_cache()
2242
+ self._track_modification()
2243
+ self._invalidate_namespace_cache()
2244
+
2245
+ logger.debug(f"Batch deleted {len(existing_ids)} memories")
2246
+ return (len(existing_ids), existing_ids)
2247
+ except ValidationError:
2248
+ raise
2249
+ except Exception as e:
2250
+ raise StorageError(f"Failed to delete batch: {e}") from e
2251
+
2252
+ @with_process_lock
2253
+ @with_write_lock
2254
+ def update_access_batch(self, memory_ids: list[str]) -> int:
2255
+ """Update access timestamp and count for multiple memories using atomic merge_insert.
2256
+
2257
+ Uses LanceDB's merge_insert API for atomic batch upserts, eliminating
2258
+ race conditions from delete-then-insert patterns.
2259
+
2260
+ Args:
2261
+ memory_ids: List of memory UUIDs to update.
2262
+
2263
+ Returns:
2264
+ Number of memories successfully updated.
2265
+ """
2266
+ if not memory_ids:
2267
+ return 0
2268
+
2269
+ now = utc_now()
2270
+ records_to_update: list[dict[str, Any]] = []
2271
+
2272
+ # Validate all IDs first
2273
+ validated_ids: list[str] = []
2274
+ for memory_id in memory_ids:
2275
+ try:
2276
+ validated_id = _validate_uuid(memory_id)
2277
+ validated_ids.append(_sanitize_string(validated_id))
2278
+ except Exception as e:
2279
+ logger.debug(f"Invalid memory ID {memory_id}: {e}")
2280
+ continue
2281
+
2282
+ if not validated_ids:
2283
+ return 0
2284
+
2285
+ # Batch fetch all records with single IN query (fixes N+1 pattern)
2286
+ try:
2287
+ id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
2288
+ all_records = self.table.search().where(f"id IN ({id_list})").to_list()
2289
+ except Exception as e:
2290
+ logger.error(f"Failed to batch fetch records for access update: {e}")
2291
+ return 0
2292
+
2293
+ # Build lookup map for found records
2294
+ found_ids = set()
2295
+ for record in all_records:
2296
+ found_ids.add(record["id"])
2297
+ record["last_accessed"] = now
2298
+ record["access_count"] = record["access_count"] + 1
2299
+
2300
+ # Ensure proper serialization for metadata
2301
+ if isinstance(record.get("metadata"), dict):
2302
+ record["metadata"] = json.dumps(record["metadata"])
2303
+
2304
+ # Ensure vector is a list, not numpy array
2305
+ if isinstance(record.get("vector"), np.ndarray):
2306
+ record["vector"] = record["vector"].tolist()
2307
+
2308
+ records_to_update.append(record)
2309
+
2310
+ # Log any IDs that weren't found
2311
+ missing_ids = set(validated_ids) - found_ids
2312
+ for missing_id in missing_ids:
2313
+ logger.debug(f"Memory {missing_id} not found for access update")
2314
+
2315
+ if not records_to_update:
2316
+ return 0
2317
+
2318
+ try:
2319
+ # Atomic batch upsert using merge_insert
2320
+ # Requires BTREE index on 'id' column (created in create_scalar_indexes)
2321
+ (
2322
+ self.table.merge_insert("id")
2323
+ .when_matched_update_all()
2324
+ .when_not_matched_insert_all()
2325
+ .execute(records_to_update)
2326
+ )
2327
+ updated = len(records_to_update)
2328
+ logger.debug(
2329
+ f"Batch updated access for {updated}/{len(memory_ids)} memories "
2330
+ "(atomic merge_insert)"
2331
+ )
2332
+ return updated
2333
+ except Exception as e:
2334
+ logger.error(f"Failed to batch update access: {e}")
2335
+ return 0
2336
+
2337
+ def _create_retry_decorator(self) -> Callable[[F], F]:
2338
+ """Create a retry decorator using instance settings."""
2339
+ return retry_on_storage_error(
2340
+ max_attempts=self.max_retry_attempts,
2341
+ backoff=self.retry_backoff_seconds,
2342
+ )
2343
+
2344
+ # ========================================================================
2345
+ # Search Operations (delegates to SearchManager)
2346
+ # ========================================================================
2347
+
2348
+ def _calculate_search_params(
2349
+ self,
2350
+ count: int,
2351
+ limit: int,
2352
+ nprobes_override: int | None = None,
2353
+ refine_factor_override: int | None = None,
2354
+ ) -> tuple[int, int]:
2355
+ """Calculate optimal search parameters. Delegates to SearchManager."""
2356
+ if self._search_manager is None:
2357
+ raise StorageError("Database not connected")
2358
+ return self._search_manager.calculate_search_params(
2359
+ count, limit, nprobes_override, refine_factor_override
2360
+ )
2361
+
2362
+ @with_stale_connection_recovery
2363
+ @retry_on_storage_error(max_attempts=3, backoff=0.5)
2364
+ def vector_search(
2365
+ self,
2366
+ query_vector: np.ndarray,
2367
+ limit: int = 5,
2368
+ namespace: str | None = None,
2369
+ min_similarity: float = 0.0,
2370
+ nprobes: int | None = None,
2371
+ refine_factor: int | None = None,
2372
+ include_vector: bool = False,
2373
+ ) -> list[dict[str, Any]]:
2374
+ """Search for similar memories by vector. Delegates to SearchManager.
2375
+
2376
+ Args:
2377
+ query_vector: Query embedding vector.
2378
+ limit: Maximum number of results.
2379
+ namespace: Filter to specific namespace.
2380
+ min_similarity: Minimum similarity threshold (0-1).
2381
+ nprobes: Number of partitions to search.
2382
+ refine_factor: Re-rank top (refine_factor * limit) for accuracy.
2383
+ include_vector: Whether to include vector embeddings in results.
2384
+
2385
+ Returns:
2386
+ List of memory records with similarity scores.
2387
+
2388
+ Raises:
2389
+ ValidationError: If input validation fails.
2390
+ StorageError: If database operation fails.
2391
+ """
2392
+ if self._search_manager is None:
2393
+ raise StorageError("Database not connected")
2394
+ return self._search_manager.vector_search(
2395
+ query_vector=query_vector,
2396
+ limit=limit,
2397
+ namespace=namespace,
2398
+ min_similarity=min_similarity,
2399
+ nprobes=nprobes,
2400
+ refine_factor=refine_factor,
2401
+ include_vector=include_vector,
2402
+ )
2403
+
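A typical call embeds the query text with the same model that produced the stored vectors and filters weak matches with min_similarity. Sketch below; `embed()` is a stand-in for that embedding function and is not part of this module:

    import numpy as np

    query_vec = np.asarray(embed("postgres connection pooling"), dtype=np.float32)
    hits = db.vector_search(query_vec, limit=5, namespace="project-x", min_similarity=0.35)
    for hit in hits:
        # each record also carries the similarity score computed by the SearchManager
        print(hit["id"], hit["content"][:80])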
2404
+ @with_stale_connection_recovery
2405
+ @retry_on_storage_error(max_attempts=3, backoff=0.5)
2406
+ def batch_vector_search_native(
2407
+ self,
2408
+ query_vectors: list[np.ndarray],
2409
+ limit_per_query: int = 3,
2410
+ namespace: str | None = None,
2411
+ min_similarity: float = 0.0,
2412
+ include_vector: bool = False,
2413
+ ) -> list[list[dict[str, Any]]]:
2414
+ """Batch search using native LanceDB. Delegates to SearchManager.
2415
+
2416
+ Args:
2417
+ query_vectors: List of query embedding vectors.
2418
+ limit_per_query: Maximum number of results per query.
2419
+ namespace: Filter to specific namespace.
2420
+ min_similarity: Minimum similarity threshold (0-1).
2421
+ include_vector: Whether to include vector embeddings in results.
2422
+
2423
+ Returns:
2424
+ List of result lists, one per query vector.
2425
+
2426
+ Raises:
2427
+ ValidationError: If input validation fails.
2428
+ StorageError: If database operation fails.
2429
+ """
2430
+ if self._search_manager is None:
2431
+ raise StorageError("Database not connected")
2432
+ return self._search_manager.batch_vector_search_native(
2433
+ query_vectors=query_vectors,
2434
+ limit_per_query=limit_per_query,
2435
+ namespace=namespace,
2436
+ min_similarity=min_similarity,
2437
+ include_vector=include_vector,
2438
+ )
2439
+
2440
+ @with_stale_connection_recovery
2441
+ @retry_on_storage_error(max_attempts=3, backoff=0.5)
2442
+ def hybrid_search(
2443
+ self,
2444
+ query: str,
2445
+ query_vector: np.ndarray,
2446
+ limit: int = 5,
2447
+ namespace: str | None = None,
2448
+ alpha: float = 0.5,
2449
+ min_similarity: float = 0.0,
2450
+ ) -> list[dict[str, Any]]:
2451
+ """Hybrid search combining vector and keyword. Delegates to SearchManager.
2452
+
2453
+ Args:
2454
+ query: Text query for full-text search.
2455
+ query_vector: Embedding vector for semantic search.
2456
+ limit: Number of results.
2457
+ namespace: Filter to namespace.
2458
+ alpha: Balance between vector (1.0) and keyword (0.0).
2459
+ min_similarity: Minimum similarity threshold.
2460
+
2461
+ Returns:
2462
+ List of memory records with combined scores.
2463
+
2464
+ Raises:
2465
+ ValidationError: If input validation fails.
2466
+ StorageError: If database operation fails.
2467
+ """
2468
+ if self._search_manager is None:
2469
+ raise StorageError("Database not connected")
2470
+ return self._search_manager.hybrid_search(
2471
+ query=query,
2472
+ query_vector=query_vector,
2473
+ limit=limit,
2474
+ namespace=namespace,
2475
+ alpha=alpha,
2476
+ min_similarity=min_similarity,
2477
+ )
2478
+
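hybrid_search() takes both the raw query text (for the full-text side) and its embedding (for the vector side), with alpha steering the blend. A sketch using the same placeholder embed() helper:

    import numpy as np

    query = "retry backoff configuration"
    hits = db.hybrid_search(
        query=query,
        query_vector=np.asarray(embed(query), dtype=np.float32),
        limit=5,
        alpha=0.7,  # closer to 1.0 favors the vector ranking, closer to 0.0 the keyword ranking
        namespace="project-x",
    )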
2479
+ @with_stale_connection_recovery
2480
+ @retry_on_storage_error(max_attempts=3, backoff=0.5)
2481
+ def batch_vector_search(
2482
+ self,
2483
+ query_vectors: list[np.ndarray],
2484
+ limit_per_query: int = 3,
2485
+ namespace: str | None = None,
2486
+ parallel: bool = False, # Deprecated
2487
+ max_workers: int = 4, # Deprecated
2488
+ include_vector: bool = False,
2489
+ ) -> list[list[dict[str, Any]]]:
2490
+ """Search using multiple query vectors. Delegates to SearchManager.
2491
+
2492
+ Args:
2493
+ query_vectors: List of query embedding vectors.
2494
+ limit_per_query: Maximum results per query vector.
2495
+ namespace: Filter to specific namespace.
2496
+ parallel: Deprecated, kept for backward compatibility.
2497
+ max_workers: Deprecated, kept for backward compatibility.
2498
+ include_vector: Whether to include vector embeddings in results.
2499
+
2500
+ Returns:
2501
+ List of result lists (one per query vector).
2502
+
2503
+ Raises:
2504
+ StorageError: If database operation fails.
2505
+ """
2506
+ if self._search_manager is None:
2507
+ raise StorageError("Database not connected")
2508
+ return self._search_manager.batch_vector_search(
2509
+ query_vectors=query_vectors,
2510
+ limit_per_query=limit_per_query,
2511
+ namespace=namespace,
2512
+ parallel=parallel,
2513
+ max_workers=max_workers,
2514
+ include_vector=include_vector,
2515
+ )
2516
+
2517
+ def get_vectors_for_clustering(
2518
+ self,
2519
+ namespace: str | None = None,
2520
+ max_memories: int = 10_000,
2521
+ ) -> tuple[list[str], np.ndarray]:
2522
+ """Fetch all vectors for clustering operations (e.g., HDBSCAN).
2523
+
2524
+ Optimized for memory efficiency with large datasets.
2525
+
2526
+ Args:
2527
+ namespace: Filter to specific namespace.
2528
+ max_memories: Maximum memories to fetch.
2529
+
2530
+ Returns:
2531
+ Tuple of (memory_ids, vectors_array).
2532
+
2533
+ Raises:
2534
+ ValidationError: If input validation fails.
2535
+ StorageError: If database operation fails.
2536
+ """
2537
+ try:
2538
+ # Build query selecting only needed columns
2539
+ search = self.table.search()
2540
+
2541
+ if namespace:
2542
+ namespace = _validate_namespace(namespace)
2543
+ safe_ns = _sanitize_string(namespace)
2544
+ search = search.where(f"namespace = '{safe_ns}'")
2545
+
2546
+ # Select only id and vector to minimize memory usage
2547
+ search = search.select(["id", "vector"]).limit(max_memories)
2548
+
2549
+ results = search.to_list()
2550
+
2551
+ if not results:
2552
+ return [], np.array([], dtype=np.float32).reshape(0, self.embedding_dim)
2553
+
2554
+ ids = [r["id"] for r in results]
2555
+ vectors = np.array([r["vector"] for r in results], dtype=np.float32)
2556
+
2557
+ return ids, vectors
2558
+
2559
+ except ValidationError:
2560
+ raise
2561
+ except Exception as e:
2562
+ raise StorageError(f"Failed to fetch vectors for clustering: {e}") from e
2563
+
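The (ids, vectors) pair is shaped for exactly the HDBSCAN-style workflow the docstring mentions. A sketch assuming the third-party hdbscan package is installed (an assumption; any clustering library that accepts an (n, dim) float32 array would do):

    import hdbscan

    ids, vectors = db.get_vectors_for_clustering(namespace="project-x", max_memories=5_000)
    if len(ids):
        labels = hdbscan.HDBSCAN(min_cluster_size=5).fit_predict(vectors)
        clusters = dict(zip(ids, labels))  # label -1 marks noise points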
2564
+ def get_vectors_as_arrow(
2565
+ self,
2566
+ namespace: str | None = None,
2567
+ columns: list[str] | None = None,
2568
+ ) -> pa.Table:
2569
+ """Get memories as Arrow table for efficient processing.
2570
+
2571
+ Arrow tables enable zero-copy data sharing and efficient columnar
2572
+ operations. Use this for large-scale analytics.
2573
+
2574
+ Args:
2575
+ namespace: Filter to specific namespace.
2576
+ columns: Columns to select (None = all).
2577
+
2578
+ Returns:
2579
+ PyArrow Table with selected data.
2580
+
2581
+ Raises:
2582
+ StorageError: If database operation fails.
2583
+ """
2584
+ try:
2585
+ search = self.table.search()
2586
+
2587
+ if namespace:
2588
+ namespace = _validate_namespace(namespace)
2589
+ safe_ns = _sanitize_string(namespace)
2590
+ search = search.where(f"namespace = '{safe_ns}'")
2591
+
2592
+ if columns:
2593
+ search = search.select(columns)
2594
+
2595
+ return search.to_arrow()
2596
+
2597
+ except Exception as e:
2598
+ raise StorageError(f"Failed to get Arrow table: {e}") from e
2599
+
2600
+ def get_all(
2601
+ self,
2602
+ namespace: str | None = None,
2603
+ limit: int | None = None,
2604
+ ) -> list[dict[str, Any]]:
2605
+ """Get all memories, optionally filtered by namespace.
2606
+
2607
+ Args:
2608
+ namespace: Filter to specific namespace.
2609
+ limit: Maximum number of results.
2610
+
2611
+ Returns:
2612
+ List of memory records.
2613
+
2614
+ Raises:
2615
+ ValidationError: If input validation fails.
2616
+ StorageError: If database operation fails.
2617
+ """
2618
+ try:
2619
+ search = self.table.search()
2620
+
2621
+ if namespace:
2622
+ namespace = _validate_namespace(namespace)
2623
+ safe_ns = _sanitize_string(namespace)
2624
+ search = search.where(f"namespace = '{safe_ns}'")
2625
+
2626
+ if limit:
2627
+ search = search.limit(limit)
2628
+
2629
+ results: list[dict[str, Any]] = search.to_list()
2630
+
2631
+ for record in results:
2632
+ record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
2633
+
2634
+ return results
2635
+ except ValidationError:
2636
+ raise
2637
+ except Exception as e:
2638
+ raise StorageError(f"Failed to get all memories: {e}") from e
2639
+
2640
+ @with_stale_connection_recovery
2641
+ def count(self, namespace: str | None = None) -> int:
2642
+ """Count memories.
2643
+
2644
+ Args:
2645
+ namespace: Filter to specific namespace.
2646
+
2647
+ Returns:
2648
+ Number of memories.
2649
+
2650
+ Raises:
2651
+ ValidationError: If input validation fails.
2652
+ StorageError: If database operation fails.
2653
+ """
2654
+ try:
2655
+ if namespace:
2656
+ namespace = _validate_namespace(namespace)
2657
+ safe_ns = _sanitize_string(namespace)
2658
+ # Use count_rows with filter predicate for efficiency
2659
+ count: int = self.table.count_rows(f"namespace = '{safe_ns}'")
2660
+ return count
2661
+ count = self.table.count_rows()
2662
+ return count
2663
+ except ValidationError:
2664
+ raise
2665
+ except Exception as e:
2666
+ raise StorageError(f"Failed to count memories: {e}") from e
2667
+
2668
+ @with_stale_connection_recovery
2669
+ def get_namespaces(self) -> list[str]:
2670
+ """Get all unique namespaces (cached with TTL, thread-safe).
2671
+
2672
+ Uses double-checked locking to avoid race conditions where another
2673
+ thread could see stale data between cache check and update.
2674
+
2675
+ Returns:
2676
+ Sorted list of namespace names.
2677
+
2678
+ Raises:
2679
+ StorageError: If database operation fails.
2680
+ """
2681
+ try:
2682
+ now = time.time()
2683
+
2684
+ # First check with lock (quick path if cache is valid)
2685
+ with self._namespace_cache_lock:
2686
+ if (
2687
+ self._cached_namespaces is not None
2688
+ and (now - self._namespace_cache_time) <= self._NAMESPACE_CACHE_TTL
2689
+ ):
2690
+ return sorted(self._cached_namespaces)
2691
+
2692
+ # Fetch from database (outside lock to avoid blocking)
2693
+ results = self.table.search().select(["namespace"]).to_list()
2694
+ namespaces = set(r["namespace"] for r in results)
2695
+
2696
+ # Double-checked locking: re-check and update atomically
2697
+ with self._namespace_cache_lock:
2698
+ # Another thread may have populated cache while we were fetching
2699
+ if self._cached_namespaces is None:
2700
+ self._cached_namespaces = namespaces
2701
+ self._namespace_cache_time = now
2702
+ # Return fresh data regardless (it's at least as current)
2703
+ return sorted(namespaces)
2704
+
2705
+ except Exception as e:
2706
+ raise StorageError(f"Failed to get namespaces: {e}") from e
2707
+
2708
+ @with_process_lock
2709
+ @with_write_lock
2710
+ def update_access(self, memory_id: str) -> None:
2711
+ """Update access timestamp and count for a memory.
2712
+
2713
+ Args:
2714
+ memory_id: The memory ID.
2715
+
2716
+ Raises:
2717
+ ValidationError: If memory_id is invalid.
2718
+ MemoryNotFoundError: If memory doesn't exist.
2719
+ StorageError: If database operation fails.
2720
+ """
2721
+ existing = self.get(memory_id)
2722
+ # Note: self.update() also has @with_process_lock and @with_write_lock,
2723
+ # but both support reentrancy within the same thread (no deadlock):
2724
+ # - ProcessLockManager tracks depth via threading.local
2725
+ # - RLock allows same thread to re-acquire
2726
+ self.update(memory_id, {
2727
+ "last_accessed": utc_now(),
2728
+ "access_count": existing["access_count"] + 1,
2729
+ })
2730
+
2731
+ # ========================================================================
2732
+ # Backup & Export
2733
+ # ========================================================================
2734
+
2735
+ def export_to_parquet(
2736
+ self,
2737
+ output_path: Path,
2738
+ namespace: str | None = None,
2739
+ ) -> dict[str, Any]:
2740
+ """Export memories to Parquet file for backup.
2741
+
2742
+ Parquet provides efficient compression and fast read performance
2743
+ for large datasets.
2744
+
2745
+ Args:
2746
+ output_path: Path to save Parquet file.
2747
+ namespace: Export only this namespace (None = all).
2748
+
2749
+ Returns:
2750
+ Export statistics (rows_exported, output_path, size_mb).
2751
+
2752
+ Raises:
2753
+ StorageError: If export fails.
2754
+ """
2755
+ try:
2756
+ # Get all data as Arrow table (efficient)
2757
+ arrow_table = self.get_vectors_as_arrow(namespace=namespace)
2758
+
2759
+ # Ensure parent directory exists
2760
+ output_path = Path(output_path)
2761
+ output_path.parent.mkdir(parents=True, exist_ok=True)
2762
+
2763
+ # Write to Parquet with compression
2764
+ pq.write_table(
2765
+ arrow_table,
2766
+ output_path,
2767
+ compression="zstd", # Good compression + fast decompression
2768
+ )
2769
+
2770
+ size_bytes = output_path.stat().st_size
2771
+
2772
+ logger.info(
2773
+ f"Exported {arrow_table.num_rows} memories to {output_path} "
2774
+ f"({size_bytes / (1024 * 1024):.2f} MB)"
2775
+ )
2776
+
2777
+ return {
2778
+ "rows_exported": arrow_table.num_rows,
2779
+ "output_path": str(output_path),
2780
+ "size_bytes": size_bytes,
2781
+ "size_mb": size_bytes / (1024 * 1024),
2782
+ }
2783
+
2784
+ except Exception as e:
2785
+ raise StorageError(f"Export failed: {e}") from e
2786
+
2787
+ def import_from_parquet(
2788
+ self,
2789
+ parquet_path: Path,
2790
+ namespace_override: str | None = None,
2791
+ batch_size: int = 1000,
2792
+ ) -> dict[str, Any]:
2793
+ """Import memories from Parquet backup.
2794
+
2795
+ Args:
2796
+ parquet_path: Path to Parquet file.
2797
+ namespace_override: Override namespace for all imported memories.
2798
+ batch_size: Records per batch during import.
2799
+
2800
+ Returns:
2801
+ Import statistics (rows_imported, source).
2802
+
2803
+ Raises:
2804
+ StorageError: If import fails.
2805
+ """
2806
+ try:
2807
+ parquet_path = Path(parquet_path)
2808
+ if not parquet_path.exists():
2809
+ raise StorageError(f"Parquet file not found: {parquet_path}")
2810
+
2811
+ table = pq.read_table(parquet_path)
2812
+ total_rows = table.num_rows
2813
+
2814
+ logger.info(f"Importing {total_rows} memories from {parquet_path}")
2815
+
2816
+ # Convert to list of dicts for processing
2817
+ records = table.to_pylist()
2818
+
2819
+ # Override namespace if requested
2820
+ if namespace_override:
2821
+ namespace_override = _validate_namespace(namespace_override)
2822
+ for record in records:
2823
+ record["namespace"] = namespace_override
2824
+
2825
+ # Regenerate IDs to avoid conflicts
2826
+ for record in records:
2827
+ record["id"] = str(uuid.uuid4())
2828
+ # Ensure metadata is properly formatted
2829
+ if isinstance(record.get("metadata"), str):
2830
+ try:
2831
+ record["metadata"] = json.loads(record["metadata"])
2832
+ except json.JSONDecodeError:
2833
+ record["metadata"] = {}
2834
+
2835
+ # After reading from parquet, serialize metadata back to JSON string
2836
+ # Parquet may read metadata as dict/struct, but the database expects JSON string
2837
+ for record in records:
2838
+ if "metadata" in record and isinstance(record["metadata"], dict):
2839
+ record["metadata"] = json.dumps(record["metadata"])
2840
+
2841
+ # Insert in batches
2842
+ imported = 0
2843
+ for i in range(0, len(records), batch_size):
2844
+ batch = records[i:i + batch_size]
2845
+ # Convert to format expected by insert
2846
+ prepared = []
2847
+ for r in batch:
2848
+ # Ensure metadata is a JSON string for storage
2849
+ metadata = r.get("metadata", {})
2850
+ if isinstance(metadata, dict):
2851
+ metadata = json.dumps(metadata)
2852
+ elif metadata is None:
2853
+ metadata = "{}"
2854
+
2855
+ prepared.append({
2856
+ "content": r["content"],
2857
+ "vector": r["vector"],
2858
+ "namespace": r["namespace"],
2859
+ "tags": r.get("tags", []),
2860
+ "importance": r.get("importance", 0.5),
2861
+ "source": r.get("source", "import"),
2862
+ "metadata": metadata,
2863
+ "expires_at": r.get("expires_at"), # Preserve TTL from source
2864
+ })
2865
+ self.table.add(prepared)
2866
+ imported += len(batch)
2867
+ logger.debug(f"Imported batch: {imported}/{total_rows}")
2868
+
2869
+ logger.info(f"Successfully imported {imported} memories")
2870
+
2871
+ return {
2872
+ "rows_imported": imported,
2873
+ "source": str(parquet_path),
2874
+ }
2875
+
2876
+ except StorageError:
2877
+ raise
2878
+ except Exception as e:
2879
+ raise StorageError(f"Import failed: {e}") from e
2880
+
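A backup/restore round trip with the two methods above; note that import_from_parquet() regenerates IDs, so restoring into a fresh namespace avoids mixing copies with the originals. Paths and namespace names are placeholders:

    from pathlib import Path

    info = db.export_to_parquet(Path("backups/memories.parquet"), namespace="project-x")
    print(f"exported {info['rows_exported']} rows ({info['size_mb']:.2f} MB)")

    result = db.import_from_parquet(
        Path("backups/memories.parquet"),
        namespace_override="project-x-restored",
    )
    print(f"restored {result['rows_imported']} rows")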
2881
+ # ========================================================================
2882
+ # TTL (Time-To-Live) Management
2883
+ # ========================================================================
2884
+
2885
+ def set_memory_ttl(self, memory_id: str, ttl_days: int | None) -> None:
2886
+ """Set TTL for a specific memory.
2887
+
2888
+ Args:
2889
+ memory_id: Memory ID.
2890
+ ttl_days: Days until expiration, or None to remove TTL.
2891
+
2892
+ Raises:
2893
+ ValidationError: If memory_id is invalid or ttl_days is not positive.
2894
+ MemoryNotFoundError: If memory doesn't exist.
2895
+ StorageError: If database operation fails.
2896
+ """
2897
+ memory_id = _validate_uuid(memory_id)
2898
+
2899
+ # Verify memory exists
2900
+ existing = self.get(memory_id)
2901
+
2902
+ if ttl_days is not None:
2903
+ if ttl_days <= 0:
2904
+ raise ValidationError("TTL days must be positive")
2905
+ expires_at = utc_now() + timedelta(days=ttl_days)
2906
+ else:
2907
+ expires_at = None
2908
+
2909
+ # Prepare record with TTL update
2910
+ existing["expires_at"] = expires_at
2911
+ existing["updated_at"] = utc_now()
2912
+
2913
+ # Ensure proper serialization for LanceDB
2914
+ if isinstance(existing.get("metadata"), dict):
2915
+ existing["metadata"] = json.dumps(existing["metadata"])
2916
+ if isinstance(existing.get("vector"), np.ndarray):
2917
+ existing["vector"] = existing["vector"].tolist()
2918
+
2919
+ try:
2920
+ # Atomic upsert using merge_insert (same pattern as update() method)
2921
+ # This prevents data loss if the operation fails partway through
2922
+ (
2923
+ self.table.merge_insert("id")
2924
+ .when_matched_update_all()
2925
+ .when_not_matched_insert_all()
2926
+ .execute([existing])
2927
+ )
2928
+ logger.debug(f"Set TTL for memory {memory_id}: expires_at={expires_at}")
2929
+ except Exception as e:
2930
+ raise StorageError(f"Failed to set memory TTL: {e}") from e
2931
+
2932
+ def cleanup_expired_memories(self) -> int:
2933
+ """Delete memories that have passed their expiration time.
2934
+
2935
+ Returns:
2936
+ Number of deleted memories.
2937
+
2938
+ Raises:
2939
+ StorageError: If cleanup fails.
2940
+ """
2941
+ if not self.enable_memory_expiration:
2942
+ logger.debug("Memory expiration is disabled, skipping cleanup")
2943
+ return 0
2944
+
2945
+ try:
2946
+ now = utc_now()
2947
+ count_before: int = self.table.count_rows()
2948
+
2949
+ # Delete expired memories using timestamp comparison
2950
+ # LanceDB uses ISO 8601 format for timestamp comparisons
2951
+ predicate = (
2952
+ f"expires_at IS NOT NULL AND expires_at < timestamp '{now.isoformat()}'"
2953
+ )
2954
+ self.table.delete(predicate)
2955
+
2956
+ count_after: int = self.table.count_rows()
2957
+ deleted: int = count_before - count_after
2958
+
2959
+ if deleted > 0:
2960
+ self._invalidate_count_cache()
2961
+ self._track_modification(deleted)
2962
+ self._invalidate_namespace_cache()
2963
+ logger.info(f"Cleaned up {deleted} expired memories")
2964
+
2965
+ return deleted
2966
+ except Exception as e:
2967
+ raise StorageError(f"Failed to cleanup expired memories: {e}") from e
2968
+
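Setting a TTL only stamps expires_at on the row; removing expired rows is a separate sweep, so cleanup_expired_memories() is meant to run periodically and is a no-op while expiration is disabled. Sketch with a placeholder memory_id:

    db.set_memory_ttl(memory_id, ttl_days=30)    # expire 30 days from now
    db.set_memory_ttl(memory_id, ttl_days=None)  # or clear the expiration again

    removed = db.cleanup_expired_memories()      # run from a scheduler / maintenance job
    print(f"removed {removed} expired memories")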
2969
+ # ========================================================================
2970
+ # Snapshot / Version Management (delegated to VersionManager)
2971
+ # ========================================================================
2972
+
2973
+ def create_snapshot(self, tag: str) -> int:
2974
+ """Create a named snapshot of the current table state.
2975
+
2976
+ Delegates to VersionManager. See VersionManager.create_snapshot for details.
2977
+ """
2978
+ if self._version_manager is None:
2979
+ raise StorageError("Database not connected")
2980
+ return self._version_manager.create_snapshot(tag)
2981
+
2982
+ def list_snapshots(self) -> list[dict[str, Any]]:
2983
+ """List available versions/snapshots.
2984
+
2985
+ Delegates to VersionManager. See VersionManager.list_snapshots for details.
2986
+ """
2987
+ if self._version_manager is None:
2988
+ raise StorageError("Database not connected")
2989
+ return self._version_manager.list_snapshots()
2990
+
2991
+ def restore_snapshot(self, version: int) -> None:
2992
+ """Restore table to a specific version.
2993
+
2994
+ Delegates to VersionManager. See VersionManager.restore_snapshot for details.
2995
+ """
2996
+ if self._version_manager is None:
2997
+ raise StorageError("Database not connected")
2998
+ self._version_manager.restore_snapshot(version)
2999
+
3000
+ def get_current_version(self) -> int:
3001
+ """Get the current table version number.
3002
+
3003
+ Delegates to VersionManager. See VersionManager.get_current_version for details.
3004
+ """
3005
+ if self._version_manager is None:
3006
+ raise StorageError("Database not connected")
3007
+ return self._version_manager.get_current_version()
3008
+
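The snapshot helpers give a guard rail around risky bulk operations: tag the current version, do the work, and roll back if it goes wrong. A sketch in which run_bulk_import() is a placeholder for whatever risky operation follows:

    version = db.create_snapshot("before-bulk-import")
    try:
        run_bulk_import()
    except Exception:
        db.restore_snapshot(version)
        raise

    for snap in db.list_snapshots():
        print(snap)
    print("now at version", db.get_current_version())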
3009
+ # ========================================================================
3010
+ # Idempotency Key Management (delegates to IdempotencyManager)
3011
+ # ========================================================================
3012
+
3013
+ @property
3014
+ def idempotency_table(self) -> LanceTable:
3015
+ """Get the idempotency keys table. Delegates to IdempotencyManager."""
3016
+ if self._idempotency_manager is None:
3017
+ raise StorageError("Database not connected")
3018
+ return self._idempotency_manager.idempotency_table
3019
+
3020
+ def get_by_idempotency_key(self, key: str) -> IdempotencyRecord | None:
3021
+ """Look up an idempotency record by key. Delegates to IdempotencyManager.
3022
+
3023
+ Args:
3024
+ key: The idempotency key to look up.
3025
+
3026
+ Returns:
3027
+ IdempotencyRecord if found and not expired, None otherwise.
3028
+
3029
+ Raises:
3030
+ StorageError: If database operation fails.
3031
+ """
3032
+ if self._idempotency_manager is None:
3033
+ raise StorageError("Database not connected")
3034
+ return self._idempotency_manager.get_by_idempotency_key(key)
3035
+
3036
+ @with_process_lock
3037
+ @with_write_lock
3038
+ def store_idempotency_key(
3039
+ self,
3040
+ key: str,
3041
+ memory_id: str,
3042
+ ttl_hours: float = 24.0,
3043
+ ) -> None:
3044
+ """Store an idempotency key mapping. Delegates to IdempotencyManager.
3045
+
3046
+ Args:
3047
+ key: The idempotency key.
3048
+ memory_id: The memory ID that was created.
3049
+ ttl_hours: Time-to-live in hours (default: 24 hours).
3050
+
3051
+ Raises:
3052
+ ValidationError: If inputs are invalid.
3053
+ StorageError: If database operation fails.
3054
+ """
3055
+ if self._idempotency_manager is None:
3056
+ raise StorageError("Database not connected")
3057
+ self._idempotency_manager.store_idempotency_key(key, memory_id, ttl_hours)
3058
+
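The two idempotency calls above combine into the usual check-then-record pattern for retried client requests. A sketch in which create_memory_record() is a placeholder for the actual insert path, and the memory_id attribute is assumed from the IdempotencyRecord type:

    def create_once(client_key: str, payload: dict) -> str:
        record = db.get_by_idempotency_key(client_key)
        if record is not None:
            return record.memory_id  # request already processed; reuse the result
        memory_id = create_memory_record(payload)
        db.store_idempotency_key(client_key, memory_id, ttl_hours=24.0)
        return memory_id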
3059
+ @with_process_lock
3060
+ @with_write_lock
3061
+ def cleanup_expired_idempotency_keys(self) -> int:
3062
+ """Remove expired idempotency keys. Delegates to IdempotencyManager.
3063
+
3064
+ Returns:
3065
+ Number of keys removed.
3066
+
3067
+ Raises:
3068
+ StorageError: If cleanup fails.
3069
+ """
3070
+ if self._idempotency_manager is None:
3071
+ raise StorageError("Database not connected")
3072
+ return self._idempotency_manager.cleanup_expired_idempotency_keys()