spatial-memory-mcp 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of spatial-memory-mcp has been flagged as a potentially problematic release.

Files changed (34)
  1. spatial_memory/__init__.py +1 -1
  2. spatial_memory/__main__.py +241 -2
  3. spatial_memory/adapters/lancedb_repository.py +74 -5
  4. spatial_memory/config.py +10 -2
  5. spatial_memory/core/__init__.py +9 -0
  6. spatial_memory/core/connection_pool.py +41 -3
  7. spatial_memory/core/consolidation_strategies.py +402 -0
  8. spatial_memory/core/database.py +774 -918
  9. spatial_memory/core/db_idempotency.py +242 -0
  10. spatial_memory/core/db_indexes.py +575 -0
  11. spatial_memory/core/db_migrations.py +584 -0
  12. spatial_memory/core/db_search.py +509 -0
  13. spatial_memory/core/db_versioning.py +177 -0
  14. spatial_memory/core/embeddings.py +65 -18
  15. spatial_memory/core/errors.py +75 -3
  16. spatial_memory/core/filesystem.py +178 -0
  17. spatial_memory/core/models.py +4 -0
  18. spatial_memory/core/rate_limiter.py +26 -9
  19. spatial_memory/core/response_types.py +497 -0
  20. spatial_memory/core/validation.py +86 -2
  21. spatial_memory/factory.py +407 -0
  22. spatial_memory/migrations/__init__.py +40 -0
  23. spatial_memory/ports/repositories.py +52 -2
  24. spatial_memory/server.py +131 -189
  25. spatial_memory/services/export_import.py +61 -43
  26. spatial_memory/services/lifecycle.py +397 -122
  27. spatial_memory/services/memory.py +2 -2
  28. spatial_memory/services/spatial.py +129 -46
  29. {spatial_memory_mcp-1.5.3.dist-info → spatial_memory_mcp-1.6.0.dist-info}/METADATA +83 -3
  30. spatial_memory_mcp-1.6.0.dist-info/RECORD +54 -0
  31. spatial_memory_mcp-1.5.3.dist-info/RECORD +0 -44
  32. {spatial_memory_mcp-1.5.3.dist-info → spatial_memory_mcp-1.6.0.dist-info}/WHEEL +0 -0
  33. {spatial_memory_mcp-1.5.3.dist-info → spatial_memory_mcp-1.6.0.dist-info}/entry_points.txt +0 -0
  34. {spatial_memory_mcp-1.5.3.dist-info → spatial_memory_mcp-1.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -34,7 +34,20 @@ import pyarrow.parquet as pq
34
34
  from filelock import FileLock, Timeout as FileLockTimeout
35
35
 
36
36
  from spatial_memory.core.connection_pool import ConnectionPool
37
- from spatial_memory.core.errors import FileLockError, MemoryNotFoundError, StorageError, ValidationError
37
+ from spatial_memory.core.db_idempotency import IdempotencyManager, IdempotencyRecord
38
+ from spatial_memory.core.db_indexes import IndexManager
39
+ from spatial_memory.core.db_migrations import CURRENT_SCHEMA_VERSION, MigrationManager
40
+ from spatial_memory.core.db_search import SearchManager
41
+ from spatial_memory.core.db_versioning import VersionManager
42
+ from spatial_memory.core.errors import (
43
+ DimensionMismatchError,
44
+ FileLockError,
45
+ MemoryNotFoundError,
46
+ PartialBatchInsertError,
47
+ StorageError,
48
+ ValidationError,
49
+ )
50
+ from spatial_memory.core.filesystem import detect_filesystem_type, get_filesystem_warning_message, is_network_filesystem
38
51
  from spatial_memory.core.utils import to_aware_utc, utc_now
39
52
 
40
53
  # Import centralized validation functions
@@ -131,9 +144,14 @@ def invalidate_connection(storage_path: Path) -> bool:
131
144
  # Retry Decorator
132
145
  # ============================================================================
133
146
 
147
+ # Default retry settings (can be overridden per-call)
148
+ DEFAULT_RETRY_MAX_ATTEMPTS = 3
149
+ DEFAULT_RETRY_BACKOFF_SECONDS = 0.5
150
+
151
+
134
152
  def retry_on_storage_error(
135
- max_attempts: int = 3,
136
- backoff: float = 0.5,
153
+ max_attempts: int = DEFAULT_RETRY_MAX_ATTEMPTS,
154
+ backoff: float = DEFAULT_RETRY_BACKOFF_SECONDS,
137
155
  ) -> Callable[[F], F]:
138
156
  """Retry decorator for transient storage errors.
139
157
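
The hard-coded retry values are now named module-level constants, so this decorator and the Database constructor defaults (changed further down) share one source of truth. A minimal caller-side sketch, assuming retry_on_storage_error and the constants remain importable from spatial_memory.core.database as defined in this file:

from spatial_memory.core.database import (
    DEFAULT_RETRY_BACKOFF_SECONDS,
    DEFAULT_RETRY_MAX_ATTEMPTS,
    retry_on_storage_error,
)

@retry_on_storage_error(
    max_attempts=DEFAULT_RETRY_MAX_ATTEMPTS,   # 3
    backoff=DEFAULT_RETRY_BACKOFF_SECONDS,     # 0.5 seconds
)
def persist(db, record):
    # Transient storage errors raised here are retried per the decorator's docstring.
    return db.insert_batch([record])
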
 
@@ -391,15 +409,6 @@ def with_process_lock(func: F) -> F:
391
409
  # Health Metrics
392
410
  # ============================================================================
393
411
 
394
- @dataclass
395
- class IdempotencyRecord:
396
- """Record for idempotency key tracking."""
397
- key: str
398
- memory_id: str
399
- created_at: Any # datetime
400
- expires_at: Any # datetime
401
-
402
-
403
412
  @dataclass
404
413
  class IndexStats:
405
414
  """Statistics for a single index."""
@@ -492,8 +501,8 @@ class Database:
492
501
  enable_fts: bool = True,
493
502
  index_nprobes: int = 20,
494
503
  index_refine_factor: int = 5,
495
- max_retry_attempts: int = 3,
496
- retry_backoff_seconds: float = 0.5,
504
+ max_retry_attempts: int = DEFAULT_RETRY_MAX_ATTEMPTS,
505
+ retry_backoff_seconds: float = DEFAULT_RETRY_BACKOFF_SECONDS,
497
506
  read_consistency_interval_ms: int = 0,
498
507
  index_wait_timeout_seconds: float = 30.0,
499
508
  fts_stem: bool = True,
@@ -507,6 +516,7 @@ class Database:
507
516
  filelock_enabled: bool = True,
508
517
  filelock_timeout: float = 30.0,
509
518
  filelock_poll_interval: float = 0.1,
519
+ acknowledge_network_filesystem_risk: bool = False,
510
520
  ) -> None:
511
521
  """Initialize the database connection.
512
522
 
@@ -533,6 +543,7 @@ class Database:
533
543
  hnsw_ef_construction: HNSW build-time search width (100-1000).
534
544
  enable_memory_expiration: Enable automatic memory expiration.
535
545
  default_memory_ttl_days: Default TTL for memories in days (None = no expiration).
546
+ acknowledge_network_filesystem_risk: Suppress network filesystem warnings.
536
547
  """
537
548
  self.storage_path = Path(storage_path)
538
549
  self.embedding_dim = embedding_dim
@@ -556,6 +567,7 @@ class Database:
556
567
  self.filelock_enabled = filelock_enabled
557
568
  self.filelock_timeout = filelock_timeout
558
569
  self.filelock_poll_interval = filelock_poll_interval
570
+ self.acknowledge_network_filesystem_risk = acknowledge_network_filesystem_risk
559
571
  self._db: lancedb.DBConnection | None = None
560
572
  self._table: LanceTable | None = None
561
573
  self._has_vector_index: bool | None = None
@@ -573,6 +585,18 @@ class Database:
573
585
  self._write_lock = threading.RLock()
574
586
  # Cross-process lock (initialized in connect())
575
587
  self._process_lock: ProcessLockManager | None = None
588
+ # Auto-compaction tracking
589
+ self._modification_count: int = 0
590
+ self._auto_compaction_threshold: int = 100 # Compact after this many modifications
591
+ self._auto_compaction_enabled: bool = True
592
+ # Version manager (initialized in connect())
593
+ self._version_manager: VersionManager | None = None
594
+ # Index manager (initialized in connect())
595
+ self._index_manager: IndexManager | None = None
596
+ # Search manager (initialized in connect())
597
+ self._search_manager: SearchManager | None = None
598
+ # Idempotency manager (initialized in connect())
599
+ self._idempotency_manager: IdempotencyManager | None = None
576
600
 
577
601
  def __enter__(self) -> Database:
578
602
  """Enter context manager."""
@@ -588,6 +612,13 @@ class Database:
588
612
  try:
589
613
  self.storage_path.mkdir(parents=True, exist_ok=True)
590
614
 
615
+ # Check for network filesystem and warn if detected
616
+ if not self.acknowledge_network_filesystem_risk:
617
+ if is_network_filesystem(self.storage_path):
618
+ fs_type = detect_filesystem_type(self.storage_path)
619
+ warning_msg = get_filesystem_warning_message(fs_type, self.storage_path)
620
+ logger.warning(warning_msg)
621
+
591
622
  # Initialize cross-process lock manager
592
623
  if self.filelock_enabled:
593
624
  lock_path = self.storage_path / ".spatial-memory.lock"
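
A short sketch of the same pre-flight check from caller code, assuming the helpers keep the signatures imported above; passing acknowledge_network_filesystem_risk=True to Database suppresses this warning at connect time:

from pathlib import Path

from spatial_memory.core.filesystem import (
    detect_filesystem_type,
    get_filesystem_warning_message,
    is_network_filesystem,
)

storage = Path("/mnt/shared/spatial-memory")  # hypothetical NFS-backed location
if is_network_filesystem(storage):
    fs_type = detect_filesystem_type(storage)
    print(get_filesystem_warning_message(fs_type, storage))
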
@@ -606,109 +637,144 @@ class Database:
606
637
  read_consistency_interval_ms=self.read_consistency_interval_ms,
607
638
  )
608
639
  self._ensure_table()
640
+ # Initialize remaining managers (IndexManager already initialized in _ensure_table)
641
+ self._version_manager = VersionManager(self)
642
+ self._search_manager = SearchManager(self)
643
+ self._idempotency_manager = IdempotencyManager(self)
609
644
  logger.info(f"Connected to LanceDB at {self.storage_path}")
645
+
646
+ # Check for pending schema migrations
647
+ self._check_pending_migrations()
610
648
  except Exception as e:
611
649
  raise StorageError(f"Failed to connect to database: {e}") from e
612
650
 
651
+ def _check_pending_migrations(self) -> None:
652
+ """Check for pending migrations and warn if any exist.
653
+
654
+ This method checks the schema version and logs a warning if there
655
+ are pending migrations. It does not auto-apply migrations - that
656
+ requires explicit user action via the CLI.
657
+ """
658
+ try:
659
+ manager = MigrationManager(self, embeddings=None)
660
+ manager.register_builtin_migrations()
661
+
662
+ current_version = manager.get_current_version()
663
+ pending = manager.get_pending_migrations()
664
+
665
+ if pending:
666
+ pending_versions = [m.version for m in pending]
667
+ logger.warning(
668
+ f"Database schema version {current_version} is outdated. "
669
+ f"{len(pending)} migration(s) pending: {', '.join(pending_versions)}. "
670
+ f"Target version: {CURRENT_SCHEMA_VERSION}. "
671
+ f"Run 'spatial-memory migrate' to apply migrations."
672
+ )
673
+ except Exception as e:
674
+ # Don't fail connection due to migration check errors
675
+ logger.debug(f"Migration check skipped: {e}")
676
+
613
677
  def _ensure_table(self) -> None:
614
- """Ensure the memories table exists with appropriate indexes."""
678
+ """Ensure the memories table exists with appropriate indexes.
679
+
680
+ Uses retry logic to handle race conditions when multiple processes
681
+ attempt to create/open the table simultaneously.
682
+ """
615
683
  if self._db is None:
616
684
  raise StorageError("Database not connected")
617
685
 
618
- existing_tables_result = self._db.list_tables()
619
- # Handle both old (list) and new (object with .tables) LanceDB API
620
- if hasattr(existing_tables_result, 'tables'):
621
- existing_tables = existing_tables_result.tables
622
- else:
623
- existing_tables = existing_tables_result
624
- if "memories" not in existing_tables:
625
- # Create table with schema
626
- schema = pa.schema([
627
- pa.field("id", pa.string()),
628
- pa.field("content", pa.string()),
629
- pa.field("vector", pa.list_(pa.float32(), self.embedding_dim)),
630
- pa.field("created_at", pa.timestamp("us")),
631
- pa.field("updated_at", pa.timestamp("us")),
632
- pa.field("last_accessed", pa.timestamp("us")),
633
- pa.field("access_count", pa.int32()),
634
- pa.field("importance", pa.float32()),
635
- pa.field("namespace", pa.string()),
636
- pa.field("tags", pa.list_(pa.string())),
637
- pa.field("source", pa.string()),
638
- pa.field("metadata", pa.string()),
639
- pa.field("expires_at", pa.timestamp("us")), # TTL support - nullable
640
- ])
641
- self._table = self._db.create_table("memories", schema=schema)
642
- logger.info("Created memories table")
643
-
644
- # Create FTS index on new table if enabled
645
- if self.enable_fts:
646
- self._create_fts_index()
647
- else:
648
- self._table = self._db.open_table("memories")
649
- logger.debug("Opened existing memories table")
686
+ max_retries = 3
687
+ retry_delay = 0.1 # Start with 100ms
650
688
 
651
- # Check existing indexes
652
- self._check_existing_indexes()
689
+ for attempt in range(max_retries):
690
+ try:
691
+ existing_tables_result = self._db.list_tables()
692
+ # Handle both old (list) and new (object with .tables) LanceDB API
693
+ if hasattr(existing_tables_result, 'tables'):
694
+ existing_tables = existing_tables_result.tables
695
+ else:
696
+ existing_tables = existing_tables_result
697
+
698
+ if "memories" not in existing_tables:
699
+ # Create table with schema
700
+ schema = pa.schema([
701
+ pa.field("id", pa.string()),
702
+ pa.field("content", pa.string()),
703
+ pa.field("vector", pa.list_(pa.float32(), self.embedding_dim)),
704
+ pa.field("created_at", pa.timestamp("us")),
705
+ pa.field("updated_at", pa.timestamp("us")),
706
+ pa.field("last_accessed", pa.timestamp("us")),
707
+ pa.field("access_count", pa.int32()),
708
+ pa.field("importance", pa.float32()),
709
+ pa.field("namespace", pa.string()),
710
+ pa.field("tags", pa.list_(pa.string())),
711
+ pa.field("source", pa.string()),
712
+ pa.field("metadata", pa.string()),
713
+ pa.field("expires_at", pa.timestamp("us")), # TTL support - nullable
714
+ ])
715
+ try:
716
+ self._table = self._db.create_table("memories", schema=schema)
717
+ logger.info("Created memories table")
718
+ except Exception as create_err:
719
+ # Table might have been created by another process
720
+ if "already exists" in str(create_err).lower():
721
+ logger.debug("Table created by another process, opening it")
722
+ self._table = self._db.open_table("memories")
723
+ else:
724
+ raise
653
725
 
654
- def _check_existing_indexes(self) -> None:
655
- """Check which indexes already exist using robust detection."""
656
- try:
657
- indices = self.table.list_indices()
726
+ # Initialize IndexManager immediately after table is set
727
+ self._index_manager = IndexManager(self)
658
728
 
659
- self._has_vector_index = False
660
- self._has_fts_index = False
729
+ # Create FTS index on new table if enabled
730
+ if self.enable_fts:
731
+ self._index_manager.create_fts_index()
732
+ else:
733
+ self._table = self._db.open_table("memories")
734
+ logger.debug("Opened existing memories table")
661
735
 
662
- for idx in indices:
663
- index_name = str(_get_index_attr(idx, "name", "")).lower()
664
- index_type = str(_get_index_attr(idx, "index_type", "")).upper()
665
- columns = _get_index_attr(idx, "columns", [])
736
+ # Initialize IndexManager immediately after table is set
737
+ self._index_manager = IndexManager(self)
666
738
 
667
- # Vector index detection: check index_type or column name
668
- if index_type in VECTOR_INDEX_TYPES:
669
- self._has_vector_index = True
670
- elif "vector" in columns or "vector" in index_name:
671
- self._has_vector_index = True
739
+ # Check existing indexes
740
+ self._index_manager.check_existing_indexes()
672
741
 
673
- # FTS index detection: check index_type or name patterns
674
- if index_type == "FTS":
675
- self._has_fts_index = True
676
- elif "fts" in index_name or "content" in index_name:
677
- self._has_fts_index = True
742
+ # Success - exit retry loop
743
+ return
678
744
 
679
- logger.debug(
680
- f"Existing indexes: vector={self._has_vector_index}, "
681
- f"fts={self._has_fts_index}"
682
- )
683
- except Exception as e:
684
- logger.warning(f"Could not check existing indexes: {e}")
685
- self._has_vector_index = None
686
- self._has_fts_index = None
745
+ except Exception as e:
746
+ error_msg = str(e).lower()
747
+ # Retry on transient race conditions
748
+ if attempt < max_retries - 1 and (
749
+ "not found" in error_msg
750
+ or "does not exist" in error_msg
751
+ or "already exists" in error_msg
752
+ ):
753
+ logger.debug(
754
+ f"Table operation failed (attempt {attempt + 1}/{max_retries}), "
755
+ f"retrying in {retry_delay}s: {e}"
756
+ )
757
+ time.sleep(retry_delay)
758
+ retry_delay *= 2 # Exponential backoff
759
+ else:
760
+ raise
761
+
762
+ def _check_existing_indexes(self) -> None:
763
+ """Check which indexes already exist. Delegates to IndexManager."""
764
+ if self._index_manager is None:
765
+ raise StorageError("Database not connected")
766
+ self._index_manager.check_existing_indexes()
767
+ # Sync local state for backward compatibility
768
+ self._has_vector_index = self._index_manager.has_vector_index
769
+ self._has_fts_index = self._index_manager.has_fts_index
687
770
 
688
771
  def _create_fts_index(self) -> None:
689
- """Create full-text search index with optimized settings."""
690
- try:
691
- self.table.create_fts_index(
692
- "content",
693
- use_tantivy=False, # Use Lance native FTS
694
- language=self.fts_language,
695
- stem=self.fts_stem,
696
- remove_stop_words=self.fts_remove_stop_words,
697
- with_position=True, # Enable phrase queries
698
- lower_case=True, # Case-insensitive search
699
- )
700
- self._has_fts_index = True
701
- logger.info(
702
- f"Created FTS index with stemming={self.fts_stem}, "
703
- f"stop_words={self.fts_remove_stop_words}"
704
- )
705
- except Exception as e:
706
- # Check if index already exists (not an error)
707
- if "already exists" in str(e).lower():
708
- self._has_fts_index = True
709
- logger.debug("FTS index already exists")
710
- else:
711
- logger.warning(f"FTS index creation failed: {e}")
772
+ """Create FTS index. Delegates to IndexManager."""
773
+ if self._index_manager is None:
774
+ raise StorageError("Database not connected")
775
+ self._index_manager.create_fts_index()
776
+ # Sync local state for backward compatibility
777
+ self._has_fts_index = self._index_manager.has_fts_index
712
778
 
713
779
  @property
714
780
  def table(self) -> LanceTable:
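
connect() now only warns about pending schema migrations; applying them stays an explicit step. A sketch of inspecting that state from operator code, assuming a connected Database instance named db and the manager API used in _check_pending_migrations above:

from spatial_memory.core.db_migrations import CURRENT_SCHEMA_VERSION, MigrationManager

manager = MigrationManager(db, embeddings=None)
manager.register_builtin_migrations()
pending = manager.get_pending_migrations()
if pending:
    versions = ", ".join(m.version for m in pending)
    print(
        f"schema {manager.get_current_version()} -> {CURRENT_SCHEMA_VERSION}: "
        f"pending {versions}; run 'spatial-memory migrate'"
    )
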
@@ -719,18 +785,30 @@ class Database:
719
785
  return self._table
720
786
 
721
787
  def close(self) -> None:
722
- """Close the database connection (connection remains pooled)."""
788
+ """Close the database connection and remove from pool.
789
+
790
+ This invalidates the pooled connection so that subsequent
791
+ Database instances will create fresh connections.
792
+ """
793
+ # Invalidate pooled connection first
794
+ invalidate_connection(self.storage_path)
795
+
796
+ # Clear local state
723
797
  self._table = None
724
798
  self._db = None
725
799
  self._has_vector_index = None
726
800
  self._has_fts_index = None
801
+ self._version_manager = None
802
+ self._index_manager = None
803
+ self._search_manager = None
804
+ self._idempotency_manager = None
727
805
  with self._cache_lock:
728
806
  self._cached_row_count = None
729
807
  self._count_cache_time = 0.0
730
808
  with self._namespace_cache_lock:
731
809
  self._cached_namespaces = None
732
810
  self._namespace_cache_time = 0.0
733
- logger.debug("Database connection closed")
811
+ logger.debug("Database connection closed and removed from pool")
734
812
 
735
813
  def reconnect(self) -> None:
736
814
  """Invalidate cached connection and reconnect.
@@ -790,313 +868,86 @@ class Database:
790
868
  self._cached_namespaces = None
791
869
  self._namespace_cache_time = 0.0
792
870
 
793
- # ========================================================================
794
- # Index Management
795
- # ========================================================================
796
-
797
- def create_vector_index(self, force: bool = False) -> bool:
798
- """Create vector index for similarity search.
799
-
800
- Supports IVF_PQ, IVF_FLAT, and HNSW_SQ index types based on configuration.
801
- Automatically determines optimal parameters based on dataset size.
871
+ def _track_modification(self, count: int = 1) -> None:
872
+ """Track database modifications and trigger auto-compaction if threshold reached.
802
873
 
803
874
  Args:
804
- force: Force index creation regardless of dataset size.
805
-
806
- Returns:
807
- True if index was created, False if skipped.
808
-
809
- Raises:
810
- StorageError: If index creation fails.
875
+ count: Number of modifications to track (default 1).
811
876
  """
812
- count = self.table.count_rows()
813
-
814
- # Check threshold
815
- if count < self.vector_index_threshold and not force:
816
- logger.info(
817
- f"Dataset has {count} rows, below threshold {self.vector_index_threshold}. "
818
- "Skipping vector index creation."
819
- )
820
- return False
821
-
822
- # Check if already exists
823
- if self._has_vector_index and not force:
824
- logger.info("Vector index already exists")
825
- return False
826
-
827
- # Handle HNSW_SQ index type
828
- if self.index_type == "HNSW_SQ":
829
- return self._create_hnsw_index(count)
830
-
831
- # IVF-based index creation (IVF_PQ or IVF_FLAT)
832
- return self._create_ivf_index(count)
877
+ if not self._auto_compaction_enabled:
878
+ return
833
879
 
834
- def _create_hnsw_index(self, count: int) -> bool:
835
- """Create HNSW-SQ vector index.
880
+ self._modification_count += count
881
+ if self._modification_count >= self._auto_compaction_threshold:
882
+ # Reset counter before compacting to avoid re-triggering
883
+ self._modification_count = 0
884
+ try:
885
+ stats = self._get_table_stats()
886
+ # Only compact if there are enough fragments to justify it
887
+ if stats.get("num_small_fragments", 0) >= 5:
888
+ logger.info(
889
+ f"Auto-compaction triggered after {self._auto_compaction_threshold} "
890
+ f"modifications ({stats.get('num_small_fragments', 0)} small fragments)"
891
+ )
892
+ self.table.compact_files()
893
+ logger.debug("Auto-compaction completed")
894
+ except Exception as e:
895
+ # Don't fail operations due to compaction issues
896
+ logger.debug(f"Auto-compaction skipped: {e}")
836
897
 
837
- HNSW (Hierarchical Navigable Small World) provides better recall than IVF
838
- at the cost of higher memory usage. Good for datasets where recall is critical.
898
+ def set_auto_compaction(
899
+ self,
900
+ enabled: bool = True,
901
+ threshold: int | None = None,
902
+ ) -> None:
903
+ """Configure auto-compaction behavior.
839
904
 
840
905
  Args:
841
- count: Number of rows in the table.
842
-
843
- Returns:
844
- True if index was created.
845
-
846
- Raises:
847
- StorageError: If index creation fails.
906
+ enabled: Whether auto-compaction is enabled.
907
+ threshold: Number of modifications before auto-compact (default: 100).
848
908
  """
849
- logger.info(
850
- f"Creating HNSW_SQ vector index: m={self.hnsw_m}, "
851
- f"ef_construction={self.hnsw_ef_construction} for {count} rows"
852
- )
853
-
854
- try:
855
- self.table.create_index(
856
- metric="cosine",
857
- vector_column_name="vector",
858
- index_type="HNSW_SQ",
859
- replace=True,
860
- m=self.hnsw_m,
861
- ef_construction=self.hnsw_ef_construction,
862
- )
863
-
864
- # Wait for index to be ready with configurable timeout
865
- self._wait_for_index_ready("vector", self.index_wait_timeout_seconds)
866
-
867
- self._has_vector_index = True
868
- logger.info("HNSW_SQ vector index created successfully")
869
-
870
- # Optimize after index creation (may fail in some environments)
871
- try:
872
- self.table.optimize()
873
- except Exception as optimize_error:
874
- logger.debug(f"Optimization after index creation skipped: {optimize_error}")
909
+ self._auto_compaction_enabled = enabled
910
+ if threshold is not None:
911
+ if threshold < 10:
912
+ raise ValueError("Auto-compaction threshold must be at least 10")
913
+ self._auto_compaction_threshold = threshold
875
914
 
876
- return True
877
-
878
- except Exception as e:
879
- logger.error(f"Failed to create HNSW_SQ vector index: {e}")
880
- raise StorageError(f"HNSW_SQ vector index creation failed: {e}") from e
881
-
882
- def _create_ivf_index(self, count: int) -> bool:
883
- """Create IVF-PQ or IVF-FLAT vector index.
915
+ # ========================================================================
916
+ # Index Management (delegates to IndexManager)
917
+ # ========================================================================
884
918
 
885
- Uses sqrt rule for partitions: num_partitions = sqrt(count), clamped to [16, 4096].
886
- Uses 48 sub-vectors for <500K rows (8 dims each for 384-dim vectors),
887
- 96 sub-vectors for >=500K rows (4 dims each).
919
+ def create_vector_index(self, force: bool = False) -> bool:
920
+ """Create vector index for similarity search. Delegates to IndexManager.
888
921
 
889
922
  Args:
890
- count: Number of rows in the table.
923
+ force: Force index creation regardless of dataset size.
891
924
 
892
925
  Returns:
893
- True if index was created.
926
+ True if index was created, False if skipped.
894
927
 
895
928
  Raises:
896
929
  StorageError: If index creation fails.
897
930
  """
898
- # Use sqrt rule for partitions, clamped to [16, 4096]
899
- num_partitions = int(math.sqrt(count))
900
- num_partitions = max(16, min(num_partitions, 4096))
901
-
902
- # Choose num_sub_vectors based on dataset size
903
- # <500K: 48 sub-vectors (8 dims each for 384-dim, more precision)
904
- # >=500K: 96 sub-vectors (4 dims each, more compression)
905
- if count < 500_000:
906
- num_sub_vectors = 48
907
- else:
908
- num_sub_vectors = 96
909
-
910
- # Validate embedding_dim % num_sub_vectors == 0 (required for IVF-PQ)
911
- if self.embedding_dim % num_sub_vectors != 0:
912
- # Find a valid divisor from common sub-vector counts
913
- valid_divisors = [96, 48, 32, 24, 16, 12, 8, 4]
914
- found_divisor = False
915
- for divisor in valid_divisors:
916
- if self.embedding_dim % divisor == 0:
917
- logger.info(
918
- f"Adjusted num_sub_vectors from {num_sub_vectors} to {divisor} "
919
- f"for embedding_dim={self.embedding_dim}"
920
- )
921
- num_sub_vectors = divisor
922
- found_divisor = True
923
- break
924
-
925
- if not found_divisor:
926
- raise StorageError(
927
- f"Cannot create IVF-PQ index: embedding_dim={self.embedding_dim} "
928
- "has no suitable divisor for sub-vectors. "
929
- f"Tried divisors: {valid_divisors}"
930
- )
931
-
932
- # IVF-PQ requires minimum rows for training (sample_rate * num_partitions / 256)
933
- # Default sample_rate=256, so we need at least 256 rows
934
- # Also, IVF requires num_partitions < num_vectors for KMeans training
935
- sample_rate = 256 # default
936
- if count < 256:
937
- # Use IVF_FLAT for very small datasets (no PQ training required)
938
- logger.info(
939
- f"Dataset too small for IVF-PQ ({count} rows < 256). "
940
- "Using IVF_FLAT index instead."
941
- )
942
- index_type = "IVF_FLAT"
943
- sample_rate = max(16, count // 4) # Lower sample rate for small data
944
- else:
945
- index_type = self.index_type if self.index_type in ("IVF_PQ", "IVF_FLAT") else "IVF_PQ"
946
-
947
- # Ensure num_partitions < num_vectors for KMeans clustering
948
- if num_partitions >= count:
949
- num_partitions = max(1, count // 4) # Use 1/4 of count, minimum 1
950
- logger.info(f"Adjusted num_partitions to {num_partitions} for {count} rows")
951
-
952
- logger.info(
953
- f"Creating {index_type} vector index: {num_partitions} partitions, "
954
- f"{num_sub_vectors} sub-vectors for {count} rows"
955
- )
956
-
957
- try:
958
- # LanceDB 0.27+ API: parameters passed directly to create_index
959
- index_kwargs: dict[str, Any] = {
960
- "metric": "cosine",
961
- "num_partitions": num_partitions,
962
- "vector_column_name": "vector",
963
- "index_type": index_type,
964
- "replace": True,
965
- "sample_rate": sample_rate,
966
- }
967
-
968
- # num_sub_vectors only applies to PQ-based indexes
969
- if "PQ" in index_type:
970
- index_kwargs["num_sub_vectors"] = num_sub_vectors
971
-
972
- self.table.create_index(**index_kwargs)
973
-
974
- # Wait for index to be ready with configurable timeout
975
- self._wait_for_index_ready("vector", self.index_wait_timeout_seconds)
976
-
977
- self._has_vector_index = True
978
- logger.info(f"{index_type} vector index created successfully")
979
-
980
- # Optimize after index creation (may fail in some environments)
981
- try:
982
- self.table.optimize()
983
- except Exception as optimize_error:
984
- logger.debug(f"Optimization after index creation skipped: {optimize_error}")
985
-
986
- return True
987
-
988
- except Exception as e:
989
- logger.error(f"Failed to create {index_type} vector index: {e}")
990
- raise StorageError(f"{index_type} vector index creation failed: {e}") from e
991
-
992
- def _wait_for_index_ready(
993
- self,
994
- column_name: str,
995
- timeout_seconds: float,
996
- poll_interval: float = 0.5,
997
- ) -> None:
998
- """Wait for an index on the specified column to be ready.
999
-
1000
- Args:
1001
- column_name: Name of the column the index is on (e.g., "vector").
1002
- LanceDB typically names indexes as "{column_name}_idx".
1003
- timeout_seconds: Maximum time to wait.
1004
- poll_interval: Time between status checks.
1005
- """
1006
- if timeout_seconds <= 0:
1007
- return
1008
-
1009
- start_time = time.time()
1010
- while time.time() - start_time < timeout_seconds:
1011
- try:
1012
- indices = self.table.list_indices()
1013
- for idx in indices:
1014
- idx_name = str(_get_index_attr(idx, "name", "")).lower()
1015
- idx_columns = _get_index_attr(idx, "columns", [])
1016
-
1017
- # Match by column name in index metadata, or index name contains column
1018
- if column_name in idx_columns or column_name in idx_name:
1019
- # Index exists, check if it's ready
1020
- status = str(_get_index_attr(idx, "status", "ready"))
1021
- if status.lower() in ("ready", "complete", "built"):
1022
- logger.debug(f"Index on {column_name} is ready")
1023
- return
1024
- break
1025
- except Exception as e:
1026
- logger.debug(f"Error checking index status: {e}")
1027
-
1028
- time.sleep(poll_interval)
1029
-
1030
- logger.warning(
1031
- f"Timeout waiting for index on {column_name} after {timeout_seconds}s"
1032
- )
931
+ if self._index_manager is None:
932
+ raise StorageError("Database not connected")
933
+ result = self._index_manager.create_vector_index(force=force)
934
+ # Sync local state only when index was created or modified
935
+ if result:
936
+ self._has_vector_index = self._index_manager.has_vector_index
937
+ return result
1033
938
 
1034
939
  def create_scalar_indexes(self) -> None:
1035
- """Create scalar indexes for frequently filtered columns.
1036
-
1037
- Creates:
1038
- - BTREE on id (fast lookups, upserts)
1039
- - BTREE on timestamps and importance (range queries)
1040
- - BITMAP on namespace and source (low cardinality)
1041
- - LABEL_LIST on tags (array contains queries)
940
+ """Create scalar indexes for frequently filtered columns. Delegates to IndexManager.
1042
941
 
1043
942
  Raises:
1044
943
  StorageError: If index creation fails critically.
1045
944
  """
1046
- # BTREE indexes for range queries and lookups
1047
- btree_columns = [
1048
- "id", # Fast lookups and merge_insert
1049
- "created_at",
1050
- "updated_at",
1051
- "last_accessed",
1052
- "importance",
1053
- "access_count",
1054
- "expires_at", # TTL expiration queries
1055
- ]
1056
-
1057
- for column in btree_columns:
1058
- try:
1059
- self.table.create_scalar_index(
1060
- column,
1061
- index_type="BTREE",
1062
- replace=True,
1063
- )
1064
- logger.debug(f"Created BTREE index on {column}")
1065
- except Exception as e:
1066
- if "already exists" not in str(e).lower():
1067
- logger.warning(f"Could not create BTREE index on {column}: {e}")
1068
-
1069
- # BITMAP indexes for low-cardinality columns
1070
- bitmap_columns = ["namespace", "source"]
1071
-
1072
- for column in bitmap_columns:
1073
- try:
1074
- self.table.create_scalar_index(
1075
- column,
1076
- index_type="BITMAP",
1077
- replace=True,
1078
- )
1079
- logger.debug(f"Created BITMAP index on {column}")
1080
- except Exception as e:
1081
- if "already exists" not in str(e).lower():
1082
- logger.warning(f"Could not create BITMAP index on {column}: {e}")
1083
-
1084
- # LABEL_LIST index for tags array (supports array_has_any queries)
1085
- try:
1086
- self.table.create_scalar_index(
1087
- "tags",
1088
- index_type="LABEL_LIST",
1089
- replace=True,
1090
- )
1091
- logger.debug("Created LABEL_LIST index on tags")
1092
- except Exception as e:
1093
- if "already exists" not in str(e).lower():
1094
- logger.warning(f"Could not create LABEL_LIST index on tags: {e}")
1095
-
1096
- logger.info("Scalar indexes created")
945
+ if self._index_manager is None:
946
+ raise StorageError("Database not connected")
947
+ self._index_manager.create_scalar_indexes()
1097
948
 
1098
949
  def ensure_indexes(self, force: bool = False) -> dict[str, bool]:
1099
- """Ensure all appropriate indexes exist.
950
+ """Ensure all appropriate indexes exist. Delegates to IndexManager.
1100
951
 
1101
952
  Args:
1102
953
  force: Force index creation regardless of thresholds.
@@ -1104,35 +955,12 @@ class Database:
1104
955
  Returns:
1105
956
  Dict indicating which indexes were created.
1106
957
  """
1107
- results = {
1108
- "vector_index": False,
1109
- "scalar_indexes": False,
1110
- "fts_index": False,
1111
- }
1112
-
1113
- count = self.table.count_rows()
1114
-
1115
- # Vector index
1116
- if self.auto_create_indexes or force:
1117
- if count >= self.vector_index_threshold or force:
1118
- results["vector_index"] = self.create_vector_index(force=force)
1119
-
1120
- # Scalar indexes (always create if > 1000 rows)
1121
- if count >= 1000 or force:
1122
- try:
1123
- self.create_scalar_indexes()
1124
- results["scalar_indexes"] = True
1125
- except Exception as e:
1126
- logger.warning(f"Scalar index creation partially failed: {e}")
1127
-
1128
- # FTS index
1129
- if self.enable_fts and not self._has_fts_index:
1130
- try:
1131
- self._create_fts_index()
1132
- results["fts_index"] = True
1133
- except Exception as e:
1134
- logger.warning(f"FTS index creation failed in ensure_indexes: {e}")
1135
-
958
+ if self._index_manager is None:
959
+ raise StorageError("Database not connected")
960
+ results = self._index_manager.ensure_indexes(force=force)
961
+ # Sync local state for backward compatibility
962
+ self._has_vector_index = self._index_manager.has_vector_index
963
+ self._has_fts_index = self._index_manager.has_fts_index
1136
964
  return results
1137
965
 
1138
966
  # ========================================================================
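
Auto-compaction fires after a threshold of tracked modifications, and only when the table stats report at least five small fragments. A sketch of tuning it through the new setter, assuming a connected Database named db:

db.set_auto_compaction(enabled=True, threshold=500)  # compact roughly every 500 writes/deletes
db.set_auto_compaction(enabled=False)                # pause during a bulk import
db.set_auto_compaction(enabled=True)                 # threshold keeps its prior value when omitted
# Thresholds below 10 raise ValueError per the validation above.
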
@@ -1301,6 +1129,13 @@ class Database:
1301
1129
  if not 0.0 <= importance <= 1.0:
1302
1130
  raise ValidationError("Importance must be between 0.0 and 1.0")
1303
1131
 
1132
+ # Validate vector dimensions
1133
+ if len(vector) != self.embedding_dim:
1134
+ raise DimensionMismatchError(
1135
+ expected_dim=self.embedding_dim,
1136
+ actual_dim=len(vector),
1137
+ )
1138
+
1304
1139
  memory_id = str(uuid.uuid4())
1305
1140
  now = utc_now()
1306
1141
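
Vectors of the wrong width now fail fast instead of surfacing as a LanceDB error later. A caller-side sketch, assuming a connected Database named db whose embedding_dim is not 512; per this diff the same check also runs inside insert_batch:

from spatial_memory.core.errors import DimensionMismatchError

try:
    db.insert_batch([{"content": "note", "vector": [0.0] * 512}])  # deliberately wrong width
except DimensionMismatchError as err:
    print(f"rejected before any write: {err}")
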
 
@@ -1328,6 +1163,7 @@ class Database:
1328
1163
  try:
1329
1164
  self.table.add([record])
1330
1165
  self._invalidate_count_cache()
1166
+ self._track_modification()
1331
1167
  self._invalidate_namespace_cache()
1332
1168
  logger.debug(f"Inserted memory {memory_id}")
1333
1169
  return memory_id
@@ -1344,23 +1180,26 @@ class Database:
1344
1180
  self,
1345
1181
  records: list[dict[str, Any]],
1346
1182
  batch_size: int = 1000,
1183
+ atomic: bool = False,
1347
1184
  ) -> list[str]:
1348
1185
  """Insert multiple memories efficiently with batching.
1349
1186
 
1350
- Note: Batch insert is NOT atomic. Partial failures may leave some
1351
- records inserted. If atomicity is required, use individual inserts
1352
- with transaction management at the application layer.
1353
-
1354
1187
  Args:
1355
1188
  records: List of memory records with content, vector, and optional fields.
1356
1189
  batch_size: Records per batch (default: 1000, max: 10000).
1190
+ atomic: If True, rollback all inserts on partial failure.
1191
+ When atomic=True and a batch fails:
1192
+ - Attempts to delete already-inserted records
1193
+ - If rollback succeeds, raises the original StorageError
1194
+ - If rollback fails, raises PartialBatchInsertError with succeeded_ids
1357
1195
 
1358
1196
  Returns:
1359
1197
  List of generated memory IDs.
1360
1198
 
1361
1199
  Raises:
1362
1200
  ValidationError: If input validation fails or batch_size exceeds maximum.
1363
- StorageError: If database operation fails.
1201
+ StorageError: If database operation fails (and rollback succeeds when atomic=True).
1202
+ PartialBatchInsertError: If atomic=True and rollback fails after partial insert.
1364
1203
  """
1365
1204
  if batch_size > self.MAX_BATCH_SIZE:
1366
1205
  raise ValidationError(
@@ -1368,9 +1207,10 @@ class Database:
1368
1207
  )
1369
1208
 
1370
1209
  all_ids: list[str] = []
1210
+ total_requested = len(records)
1371
1211
 
1372
1212
  # Process in batches for large inserts
1373
- for i in range(0, len(records), batch_size):
1213
+ for batch_index, i in enumerate(range(0, len(records), batch_size)):
1374
1214
  batch = records[i:i + batch_size]
1375
1215
  now = utc_now()
1376
1216
  memory_ids: list[str] = []
@@ -1398,6 +1238,14 @@ class Database:
1398
1238
  else:
1399
1239
  vector_list = raw_vector
1400
1240
 
1241
+ # Validate vector dimensions
1242
+ if len(vector_list) != self.embedding_dim:
1243
+ raise DimensionMismatchError(
1244
+ expected_dim=self.embedding_dim,
1245
+ actual_dim=len(vector_list),
1246
+ record_index=i + len(memory_ids),
1247
+ )
1248
+
1401
1249
  # Calculate expires_at if default TTL is configured
1402
1250
  expires_at = None
1403
1251
  if self.default_memory_ttl_days is not None:
@@ -1424,9 +1272,29 @@ class Database:
1424
1272
  self.table.add(prepared_records)
1425
1273
  all_ids.extend(memory_ids)
1426
1274
  self._invalidate_count_cache()
1275
+ self._track_modification(len(memory_ids))
1427
1276
  self._invalidate_namespace_cache()
1428
- logger.debug(f"Inserted batch {i // batch_size + 1}: {len(memory_ids)} memories")
1277
+ logger.debug(f"Inserted batch {batch_index + 1}: {len(memory_ids)} memories")
1429
1278
  except Exception as e:
1279
+ if atomic and all_ids:
1280
+ # Attempt rollback of previously inserted records
1281
+ logger.warning(
1282
+ f"Batch {batch_index + 1} failed, attempting rollback of "
1283
+ f"{len(all_ids)} previously inserted records"
1284
+ )
1285
+ rollback_error = self._rollback_batch_insert(all_ids)
1286
+ if rollback_error:
1287
+ # Rollback failed - raise PartialBatchInsertError
1288
+ raise PartialBatchInsertError(
1289
+ message=f"Batch insert failed and rollback also failed: {e}",
1290
+ succeeded_ids=all_ids,
1291
+ total_requested=total_requested,
1292
+ failed_batch_index=batch_index,
1293
+ ) from e
1294
+ else:
1295
+ # Rollback succeeded - raise original error
1296
+ logger.info(f"Rollback successful, deleted {len(all_ids)} records")
1297
+ raise StorageError(f"Failed to insert batch (rolled back): {e}") from e
1430
1298
  raise StorageError(f"Failed to insert batch: {e}") from e
1431
1299
 
1432
1300
  # Check if we should create indexes after large insert
@@ -1442,6 +1310,31 @@ class Database:
1442
1310
  logger.debug(f"Inserted {len(all_ids)} memories total")
1443
1311
  return all_ids
1444
1312
 
1313
+ def _rollback_batch_insert(self, memory_ids: list[str]) -> Exception | None:
1314
+ """Attempt to delete records inserted during a failed batch operation.
1315
+
1316
+ Args:
1317
+ memory_ids: List of memory IDs to delete.
1318
+
1319
+ Returns:
1320
+ None if rollback succeeded, Exception if it failed.
1321
+ """
1322
+ try:
1323
+ if not memory_ids:
1324
+ return None
1325
+
1326
+ # Use delete_batch for efficient rollback
1327
+ id_list = ", ".join(f"'{_sanitize_string(mid)}'" for mid in memory_ids)
1328
+ self.table.delete(f"id IN ({id_list})")
1329
+ self._invalidate_count_cache()
1330
+ self._track_modification(len(memory_ids))
1331
+ self._invalidate_namespace_cache()
1332
+ logger.debug(f"Rolled back {len(memory_ids)} records")
1333
+ return None
1334
+ except Exception as e:
1335
+ logger.error(f"Rollback failed: {e}")
1336
+ return e
1337
+
1445
1338
  @with_stale_connection_recovery
1446
1339
  def get(self, memory_id: str) -> dict[str, Any]:
1447
1340
  """Get a memory by ID.
@@ -1476,6 +1369,51 @@ class Database:
1476
1369
  except Exception as e:
1477
1370
  raise StorageError(f"Failed to get memory: {e}") from e
1478
1371
 
1372
+ def get_batch(self, memory_ids: list[str]) -> dict[str, dict[str, Any]]:
1373
+ """Get multiple memories by ID in a single query.
1374
+
1375
+ Args:
1376
+ memory_ids: List of memory UUIDs to retrieve.
1377
+
1378
+ Returns:
1379
+ Dict mapping memory_id to memory record. Missing IDs are not included.
1380
+
1381
+ Raises:
1382
+ ValidationError: If any memory_id format is invalid.
1383
+ StorageError: If database operation fails.
1384
+ """
1385
+ if not memory_ids:
1386
+ return {}
1387
+
1388
+ # Validate all IDs first
1389
+ validated_ids: list[str] = []
1390
+ for memory_id in memory_ids:
1391
+ try:
1392
+ validated_id = _validate_uuid(memory_id)
1393
+ validated_ids.append(_sanitize_string(validated_id))
1394
+ except Exception as e:
1395
+ logger.debug(f"Invalid memory ID {memory_id}: {e}")
1396
+ continue
1397
+
1398
+ if not validated_ids:
1399
+ return {}
1400
+
1401
+ try:
1402
+ # Batch fetch with single IN query
1403
+ id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
1404
+ results = self.table.search().where(f"id IN ({id_list})").to_list()
1405
+
1406
+ # Build result map
1407
+ result_map: dict[str, dict[str, Any]] = {}
1408
+ for record in results:
1409
+ # Deserialize metadata
1410
+ record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
1411
+ result_map[record["id"]] = record
1412
+
1413
+ return result_map
1414
+ except Exception as e:
1415
+ raise StorageError(f"Failed to batch get memories: {e}") from e
1416
+
1479
1417
  @with_process_lock
1480
1418
  @with_write_lock
1481
1419
  def update(self, memory_id: str, updates: dict[str, Any]) -> None:
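
A sketch of the new batched read path, assuming db is connected and ids is a list of UUID strings returned by an earlier insert_batch call:

found = db.get_batch(ids)                     # one IN query instead of N get() calls
for memory_id in ids:
    record = found.get(memory_id)
    if record is None:
        print(f"{memory_id}: missing or invalid, silently skipped by get_batch")
    else:
        print(memory_id, record["content"], record["metadata"])
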
@@ -1533,42 +1471,145 @@ class Database:
1533
1471
 
1534
1472
  @with_process_lock
1535
1473
  @with_write_lock
1536
- def delete(self, memory_id: str) -> None:
1537
- """Delete a memory.
1474
+ def update_batch(
1475
+ self, updates: list[tuple[str, dict[str, Any]]]
1476
+ ) -> tuple[int, list[str]]:
1477
+ """Update multiple memories using atomic merge_insert.
1538
1478
 
1539
1479
  Args:
1540
- memory_id: The memory ID.
1480
+ updates: List of (memory_id, updates_dict) tuples.
1481
+
1482
+ Returns:
1483
+ Tuple of (success_count, list of failed memory_ids).
1541
1484
 
1542
1485
  Raises:
1543
- ValidationError: If memory_id is invalid.
1544
- MemoryNotFoundError: If memory doesn't exist.
1545
- StorageError: If database operation fails.
1486
+ StorageError: If database operation fails completely.
1546
1487
  """
1547
- # Validate memory_id
1548
- memory_id = _validate_uuid(memory_id)
1549
- safe_id = _sanitize_string(memory_id)
1488
+ if not updates:
1489
+ return 0, []
1550
1490
 
1551
- # First verify the memory exists
1552
- self.get(memory_id)
1491
+ now = utc_now()
1492
+ records_to_update: list[dict[str, Any]] = []
1493
+ failed_ids: list[str] = []
1553
1494
 
1495
+ # Validate all IDs and collect them
1496
+ validated_updates: list[tuple[str, dict[str, Any]]] = []
1497
+ for memory_id, update_dict in updates:
1498
+ try:
1499
+ validated_id = _validate_uuid(memory_id)
1500
+ validated_updates.append((_sanitize_string(validated_id), update_dict))
1501
+ except Exception as e:
1502
+ logger.debug(f"Invalid memory ID {memory_id}: {e}")
1503
+ failed_ids.append(memory_id)
1504
+
1505
+ if not validated_updates:
1506
+ return 0, failed_ids
1507
+
1508
+ # Batch fetch all records
1509
+ validated_ids = [vid for vid, _ in validated_updates]
1554
1510
  try:
1555
- self.table.delete(f"id = '{safe_id}'")
1556
- self._invalidate_count_cache()
1557
- self._invalidate_namespace_cache()
1558
- logger.debug(f"Deleted memory {memory_id}")
1511
+ id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
1512
+ all_records = self.table.search().where(f"id IN ({id_list})").to_list()
1559
1513
  except Exception as e:
1560
- raise StorageError(f"Failed to delete memory: {e}") from e
1514
+ logger.error(f"Failed to batch fetch records for update: {e}")
1515
+ raise StorageError(f"Failed to batch fetch for update: {e}") from e
1561
1516
 
1562
- @with_process_lock
1563
- @with_write_lock
1564
- def delete_by_namespace(self, namespace: str) -> int:
1565
- """Delete all memories in a namespace.
1517
+ # Build lookup map
1518
+ record_map: dict[str, dict[str, Any]] = {}
1519
+ for record in all_records:
1520
+ record_map[record["id"]] = record
1521
+
1522
+ # Apply updates to found records
1523
+ update_dict_map = dict(validated_updates)
1524
+ for memory_id in validated_ids:
1525
+ if memory_id not in record_map:
1526
+ logger.debug(f"Memory {memory_id} not found for batch update")
1527
+ failed_ids.append(memory_id)
1528
+ continue
1566
1529
 
1567
- Args:
1568
- namespace: The namespace to delete.
1530
+ record = record_map[memory_id]
1531
+ update_dict = update_dict_map[memory_id]
1569
1532
 
1570
- Returns:
1571
- Number of deleted records.
1533
+ # Apply updates
1534
+ record["updated_at"] = now
1535
+ for key, value in update_dict.items():
1536
+ if key == "metadata" and isinstance(value, dict):
1537
+ record[key] = json.dumps(value)
1538
+ elif key == "vector" and isinstance(value, np.ndarray):
1539
+ record[key] = value.tolist()
1540
+ else:
1541
+ record[key] = value
1542
+
1543
+ # Ensure metadata is serialized
1544
+ if isinstance(record.get("metadata"), dict):
1545
+ record["metadata"] = json.dumps(record["metadata"])
1546
+
1547
+ # Ensure vector is a list
1548
+ if isinstance(record.get("vector"), np.ndarray):
1549
+ record["vector"] = record["vector"].tolist()
1550
+
1551
+ records_to_update.append(record)
1552
+
1553
+ if not records_to_update:
1554
+ return 0, failed_ids
1555
+
1556
+ try:
1557
+ # Atomic batch upsert
1558
+ (
1559
+ self.table.merge_insert("id")
1560
+ .when_matched_update_all()
1561
+ .when_not_matched_insert_all()
1562
+ .execute(records_to_update)
1563
+ )
1564
+ success_count = len(records_to_update)
1565
+ logger.debug(
1566
+ f"Batch updated {success_count}/{len(updates)} memories "
1567
+ "(atomic merge_insert)"
1568
+ )
1569
+ return success_count, failed_ids
1570
+ except Exception as e:
1571
+ logger.error(f"Failed to batch update: {e}")
1572
+ raise StorageError(f"Failed to batch update: {e}") from e
1573
+
1574
+ @with_process_lock
1575
+ @with_write_lock
1576
+ def delete(self, memory_id: str) -> None:
1577
+ """Delete a memory.
1578
+
1579
+ Args:
1580
+ memory_id: The memory ID.
1581
+
1582
+ Raises:
1583
+ ValidationError: If memory_id is invalid.
1584
+ MemoryNotFoundError: If memory doesn't exist.
1585
+ StorageError: If database operation fails.
1586
+ """
1587
+ # Validate memory_id
1588
+ memory_id = _validate_uuid(memory_id)
1589
+ safe_id = _sanitize_string(memory_id)
1590
+
1591
+ # First verify the memory exists
1592
+ self.get(memory_id)
1593
+
1594
+ try:
1595
+ self.table.delete(f"id = '{safe_id}'")
1596
+ self._invalidate_count_cache()
1597
+ self._track_modification()
1598
+ self._invalidate_namespace_cache()
1599
+ logger.debug(f"Deleted memory {memory_id}")
1600
+ except Exception as e:
1601
+ raise StorageError(f"Failed to delete memory: {e}") from e
1602
+
1603
+ @with_process_lock
1604
+ @with_write_lock
1605
+ def delete_by_namespace(self, namespace: str) -> int:
1606
+ """Delete all memories in a namespace.
1607
+
1608
+ Args:
1609
+ namespace: The namespace to delete.
1610
+
1611
+ Returns:
1612
+ Number of deleted records.
1572
1613
 
1573
1614
  Raises:
1574
1615
  ValidationError: If namespace is invalid.
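
A sketch of the new batched write path, assuming db is connected and memory_id / other_id are valid UUID strings from earlier inserts; the update keys map onto the table columns shown in the schema earlier in this diff:

updated, failed = db.update_batch([
    (memory_id, {"importance": 0.9, "tags": ["pinned"]}),
    (other_id, {"metadata": {"reviewed": True}}),
])
# updated: rows upserted in one merge_insert call; failed: IDs that were invalid or not found.
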
@@ -1581,6 +1622,7 @@ class Database:
1581
1622
  count_before: int = self.table.count_rows()
1582
1623
  self.table.delete(f"namespace = '{safe_ns}'")
1583
1624
  self._invalidate_count_cache()
1625
+ self._track_modification()
1584
1626
  self._invalidate_namespace_cache()
1585
1627
  count_after: int = self.table.count_rows()
1586
1628
  deleted = count_before - count_after
@@ -1624,6 +1666,7 @@ class Database:
1624
1666
  self.table.delete("id IS NOT NULL")
1625
1667
 
1626
1668
  self._invalidate_count_cache()
1669
+ self._track_modification()
1627
1670
  self._invalidate_namespace_cache()
1628
1671
 
1629
1672
  # Reset index tracking flags for test isolation
@@ -1643,6 +1686,7 @@ class Database:
1643
1686
  """Rename all memories from one namespace to another.
1644
1687
 
1645
1688
  Uses atomic batch update via merge_insert for data integrity.
1689
+ On partial failure, attempts to rollback renamed records to original namespace.
1646
1690
 
1647
1691
  Args:
1648
1692
  old_namespace: Source namespace name.
@@ -1661,6 +1705,7 @@ class Database:
1661
1705
  old_namespace = _validate_namespace(old_namespace)
1662
1706
  new_namespace = _validate_namespace(new_namespace)
1663
1707
  safe_old = _sanitize_string(old_namespace)
1708
+ safe_new = _sanitize_string(new_namespace)
1664
1709
 
1665
1710
  try:
1666
1711
  # Check if source namespace exists
@@ -1674,6 +1719,9 @@ class Database:
1674
1719
  logger.debug(f"Namespace '{old_namespace}' renamed to itself ({count} records)")
1675
1720
  return count
1676
1721
 
1722
+ # Track renamed IDs for rollback capability
1723
+ renamed_ids: list[str] = []
1724
+
1677
1725
  # Fetch all records in batches with iteration safeguards
1678
1726
  batch_size = 1000
1679
1727
  max_iterations = 10000 # Safety cap: 10M records at 1000/batch
@@ -1702,6 +1750,9 @@ class Database:
1702
1750
  if not records:
1703
1751
  break
1704
1752
 
1753
+ # Track IDs in this batch for potential rollback
1754
+ batch_ids = [r["id"] for r in records]
1755
+
1705
1756
  # Update namespace field
1706
1757
  for r in records:
1707
1758
  r["namespace"] = new_namespace
@@ -1711,13 +1762,41 @@ class Database:
1711
1762
  if isinstance(r.get("vector"), np.ndarray):
1712
1763
  r["vector"] = r["vector"].tolist()
1713
1764
 
1714
- # Atomic upsert
1715
- (
1716
- self.table.merge_insert("id")
1717
- .when_matched_update_all()
1718
- .when_not_matched_insert_all()
1719
- .execute(records)
1720
- )
1765
+ try:
1766
+ # Atomic upsert
1767
+ (
1768
+ self.table.merge_insert("id")
1769
+ .when_matched_update_all()
1770
+ .when_not_matched_insert_all()
1771
+ .execute(records)
1772
+ )
1773
+ # Only track as renamed after successful update
1774
+ renamed_ids.extend(batch_ids)
1775
+ except Exception as batch_error:
1776
+ # Batch failed - attempt rollback of previously renamed records
1777
+ if renamed_ids:
1778
+ logger.warning(
1779
+ f"Batch {iteration} failed, attempting rollback of "
1780
+ f"{len(renamed_ids)} previously renamed records"
1781
+ )
1782
+ rollback_error = self._rollback_namespace_rename(
1783
+ renamed_ids, old_namespace
1784
+ )
1785
+ if rollback_error:
1786
+ raise StorageError(
1787
+ f"Namespace rename failed at batch {iteration} and "
1788
+ f"rollback also failed. {len(renamed_ids)} records may be "
1789
+ f"in inconsistent state (partially in '{new_namespace}'). "
1790
+ f"Original error: {batch_error}. Rollback error: {rollback_error}"
1791
+ ) from batch_error
1792
+ else:
1793
+ logger.info(
1794
+ f"Rollback successful, reverted {len(renamed_ids)} records "
1795
+ f"back to namespace '{old_namespace}'"
1796
+ )
1797
+ raise StorageError(
1798
+ f"Failed to rename namespace (rolled back): {batch_error}"
1799
+ ) from batch_error
1721
1800
 
1722
1801
  updated += len(records)
1723
1802
 
@@ -1740,6 +1819,66 @@ class Database:
1740
1819
  except Exception as e:
1741
1820
  raise StorageError(f"Failed to rename namespace: {e}") from e
1742
1821
 
1822
+ def _rollback_namespace_rename(
1823
+ self, memory_ids: list[str], target_namespace: str
1824
+ ) -> Exception | None:
1825
+ """Attempt to revert renamed records back to original namespace.
1826
+
1827
+ Args:
1828
+ memory_ids: List of memory IDs to revert.
1829
+ target_namespace: Namespace to revert records to.
1830
+
1831
+ Returns:
1832
+ None if rollback succeeded, Exception if it failed.
1833
+ """
1834
+ try:
1835
+ if not memory_ids:
1836
+ return None
1837
+
1838
+ safe_namespace = _sanitize_string(target_namespace)
1839
+ now = utc_now()
1840
+
1841
+ # Process in batches for large rollbacks
1842
+ batch_size = 1000
1843
+ for i in range(0, len(memory_ids), batch_size):
1844
+ batch_ids = memory_ids[i:i + batch_size]
1845
+ id_list = ", ".join(f"'{_sanitize_string(mid)}'" for mid in batch_ids)
1846
+
1847
+ # Fetch records that need rollback
1848
+ records = (
1849
+ self.table.search()
1850
+ .where(f"id IN ({id_list})")
1851
+ .to_list()
1852
+ )
1853
+
1854
+ if not records:
1855
+ continue
1856
+
1857
+ # Revert namespace
1858
+ for r in records:
1859
+ r["namespace"] = target_namespace
1860
+ r["updated_at"] = now
1861
+ if isinstance(r.get("metadata"), dict):
1862
+ r["metadata"] = json.dumps(r["metadata"])
1863
+ if isinstance(r.get("vector"), np.ndarray):
1864
+ r["vector"] = r["vector"].tolist()
1865
+
1866
+ # Atomic upsert to restore original namespace
1867
+ (
1868
+ self.table.merge_insert("id")
1869
+ .when_matched_update_all()
1870
+ .when_not_matched_insert_all()
1871
+ .execute(records)
1872
+ )
1873
+
1874
+ self._invalidate_namespace_cache()
1875
+ logger.debug(f"Rolled back {len(memory_ids)} records to namespace '{target_namespace}'")
1876
+ return None
1877
+
1878
+ except Exception as e:
1879
+ logger.error(f"Namespace rename rollback failed: {e}")
1880
+ return e
1881
+
1743
1882
  @with_stale_connection_recovery
1744
1883
  def get_stats(self, namespace: str | None = None) -> dict[str, Any]:
1745
1884
  """Get comprehensive database statistics.
@@ -1836,15 +1975,18 @@ class Database:
1836
1975
  safe_ns = _sanitize_string(namespace)
1837
1976
 
1838
1977
  try:
1839
- # Get records for this namespace (select created_at and content for stats)
1840
- records = (
1978
+ # Get count efficiently
1979
+ filter_expr = f"namespace = '{safe_ns}'"
1980
+ count_results = (
1841
1981
  self.table.search()
1842
- .where(f"namespace = '{safe_ns}'")
1843
- .select(["created_at", "content"])
1982
+ .where(filter_expr)
1983
+ .select(["id"])
1984
+ .limit(1000000) # High limit to count all
1844
1985
  .to_list()
1845
1986
  )
1987
+ memory_count = len(count_results)
1846
1988
 
1847
- if not records:
1989
+ if memory_count == 0:
1848
1990
  return {
1849
1991
  "namespace": namespace,
1850
1992
  "memory_count": 0,
@@ -1853,18 +1995,42 @@ class Database:
1853
1995
  "avg_content_length": None,
1854
1996
  }
1855
1997
 
1856
- # Find oldest and newest
1857
- created_times = [r["created_at"] for r in records]
1858
- oldest = min(created_times)
1859
- newest = max(created_times)
1998
+ # Get oldest memory (sort ascending, limit 1)
1999
+ oldest_records = (
2000
+ self.table.search()
2001
+ .where(filter_expr)
2002
+ .select(["created_at"])
2003
+ .limit(1)
2004
+ .to_list()
2005
+ )
2006
+ oldest = oldest_records[0]["created_at"] if oldest_records else None
2007
+
2008
+ # Get newest memory - need to fetch more and find max since LanceDB
2009
+ # doesn't support ORDER BY DESC efficiently
2010
+ # Sample up to 1000 records for stats to avoid loading everything
2011
+ sample_size = min(memory_count, 1000)
2012
+ sample_records = (
2013
+ self.table.search()
2014
+ .where(filter_expr)
2015
+ .select(["created_at", "content"])
2016
+ .limit(sample_size)
2017
+ .to_list()
2018
+ )
1860
2019
 
1861
- # Calculate average content length
1862
- content_lengths = [len(r.get("content", "")) for r in records]
1863
- avg_content_length = sum(content_lengths) / len(content_lengths)
2020
+ # Find newest from sample (for large namespaces this is approximate)
2021
+ if sample_records:
2022
+ created_times = [r["created_at"] for r in sample_records]
2023
+ newest = max(created_times)
2024
+ # Calculate average content length from sample
2025
+ content_lengths = [len(r.get("content", "")) for r in sample_records]
2026
+ avg_content_length = sum(content_lengths) / len(content_lengths)
2027
+ else:
2028
+ newest = oldest
2029
+ avg_content_length = None
1864
2030
 
1865
2031
  return {
1866
2032
  "namespace": namespace,
1867
- "memory_count": len(records),
2033
+ "memory_count": memory_count,
1868
2034
  "oldest_memory": oldest,
1869
2035
  "newest_memory": newest,
1870
2036
  "avg_content_length": avg_content_length,
@@ -2024,21 +2190,23 @@ class Database:
2024
2190
 
2025
2191
  @with_process_lock
2026
2192
  @with_write_lock
2027
- def delete_batch(self, memory_ids: list[str]) -> int:
2193
+ def delete_batch(self, memory_ids: list[str]) -> tuple[int, list[str]]:
2028
2194
  """Delete multiple memories atomically using IN clause.
2029
2195
 
2030
2196
  Args:
2031
2197
  memory_ids: List of memory UUIDs to delete.
2032
2198
 
2033
2199
  Returns:
2034
- Number of memories actually deleted.
2200
+ Tuple of (count_deleted, list_of_deleted_ids) where:
2201
+ - count_deleted: Number of memories actually deleted
2202
+ - list_of_deleted_ids: IDs that were actually deleted
2035
2203
 
2036
2204
  Raises:
2037
2205
  ValidationError: If any memory_id is invalid.
2038
2206
  StorageError: If database operation fails.
2039
2207
  """
2040
2208
  if not memory_ids:
2041
- return 0
2209
+ return (0, [])
2042
2210
 
2043
2211
  # Validate all IDs first (fail fast)
2044
2212
  validated_ids: list[str] = []
@@ -2047,21 +2215,32 @@ class Database:
2047
2215
  validated_ids.append(_sanitize_string(validated_id))
2048
2216
 
2049
2217
  try:
2050
- count_before: int = self.table.count_rows()
2051
-
2052
- # Build IN clause for atomic batch delete
2218
+ # First, check which IDs actually exist
2053
2219
  id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
2054
2220
  filter_expr = f"id IN ({id_list})"
2055
- self.table.delete(filter_expr)
2221
+ existing_records = (
2222
+ self.table.search()
2223
+ .where(filter_expr)
2224
+ .select(["id"])
2225
+ .limit(len(validated_ids))
2226
+ .to_list()
2227
+ )
2228
+ existing_ids = [r["id"] for r in existing_records]
2229
+
2230
+ if not existing_ids:
2231
+ return (0, [])
2232
+
2233
+ # Delete only existing IDs
2234
+ existing_id_list = ", ".join(f"'{mid}'" for mid in existing_ids)
2235
+ delete_expr = f"id IN ({existing_id_list})"
2236
+ self.table.delete(delete_expr)
2056
2237
 
2057
2238
  self._invalidate_count_cache()
2239
+ self._track_modification()
2058
2240
  self._invalidate_namespace_cache()
2059
2241
 
2060
- count_after: int = self.table.count_rows()
2061
- deleted = count_before - count_after
2062
-
2063
- logger.debug(f"Batch deleted {deleted} memories")
2064
- return deleted
2242
+ logger.debug(f"Batch deleted {len(existing_ids)} memories")
2243
+ return (len(existing_ids), existing_ids)
2065
2244
  except ValidationError:
2066
2245
  raise
2067
2246
  except Exception as e:
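Because delete_batch now pre-filters to IDs that actually exist and returns them, callers can report which requested IDs were missing instead of inferring a count from before/after row totals. A minimal usage sketch, assuming a connected Database instance `db` (illustrative variable name); the UUIDs are stand-ins for IDs obtained from earlier inserts or searches:

    import uuid

    # Illustrative IDs; in practice these come from prior inserts or search results.
    requested = [str(uuid.uuid4()), str(uuid.uuid4())]

    deleted_count, deleted_ids = db.delete_batch(requested)
    missing = set(requested) - set(deleted_ids)   # IDs that were not found

    print(f"deleted {deleted_count} of {len(requested)} requested")
    if missing:
        print(f"not found: {sorted(missing)}")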
@@ -2159,6 +2338,10 @@ class Database:
2159
2338
  backoff=self.retry_backoff_seconds,
2160
2339
  )
2161
2340
 
2341
+ # ========================================================================
2342
+ # Search Operations (delegates to SearchManager)
2343
+ # ========================================================================
2344
+
2162
2345
  def _calculate_search_params(
2163
2346
  self,
2164
2347
  count: int,
@@ -2166,59 +2349,12 @@ class Database:
2166
2349
  nprobes_override: int | None = None,
2167
2350
  refine_factor_override: int | None = None,
2168
2351
  ) -> tuple[int, int]:
2169
- """Calculate optimal search parameters based on dataset size and limit.
2170
-
2171
- Dynamically tunes nprobes and refine_factor for optimal recall/speed tradeoff.
2172
-
2173
- Args:
2174
- count: Number of rows in the dataset.
2175
- limit: Number of results requested.
2176
- nprobes_override: Optional override for nprobes (uses this if provided).
2177
- refine_factor_override: Optional override for refine_factor.
2178
-
2179
- Returns:
2180
- Tuple of (nprobes, refine_factor).
2181
-
2182
- Scaling rules:
2183
- - nprobes: Base from config, scaled up for larger datasets
2184
- - <100K: config value (default 20)
2185
- - 100K-1M: max(config, 30)
2186
- - 1M-10M: max(config, 50)
2187
- - >10M: max(config, 100)
2188
- - refine_factor: Base from config, scaled up for small limits
2189
- - limit <= 5: config value * 2
2190
- - limit <= 20: config value
2191
- - limit > 20: max(config // 2, 2)
2192
- """
2193
- # Calculate nprobes based on dataset size
2194
- if nprobes_override is not None:
2195
- nprobes = nprobes_override
2196
- else:
2197
- base_nprobes = self.index_nprobes
2198
- if count < 100_000:
2199
- nprobes = base_nprobes
2200
- elif count < 1_000_000:
2201
- nprobes = max(base_nprobes, 30)
2202
- elif count < 10_000_000:
2203
- nprobes = max(base_nprobes, 50)
2204
- else:
2205
- nprobes = max(base_nprobes, 100)
2206
-
2207
- # Calculate refine_factor based on limit
2208
- if refine_factor_override is not None:
2209
- refine_factor = refine_factor_override
2210
- else:
2211
- base_refine = self.index_refine_factor
2212
- if limit <= 5:
2213
- # Small limits need more refinement for accuracy
2214
- refine_factor = base_refine * 2
2215
- elif limit <= 20:
2216
- refine_factor = base_refine
2217
- else:
2218
- # Large limits can use less refinement
2219
- refine_factor = max(base_refine // 2, 2)
2220
-
2221
- return nprobes, refine_factor
2352
+ """Calculate optimal search parameters. Delegates to SearchManager."""
2353
+ if self._search_manager is None:
2354
+ raise StorageError("Database not connected")
2355
+ return self._search_manager.calculate_search_params(
2356
+ count, limit, nprobes_override, refine_factor_override
2357
+ )
2222
2358
 
2223
2359
  @with_stale_connection_recovery
2224
2360
  @retry_on_storage_error(max_attempts=3, backoff=0.5)
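The scaling rules that used to live in this docstring are still visible in the removed body above; presumably SearchManager.calculate_search_params applies the same logic. A standalone restatement of those rules for reference, where the defaults of 20 and 10 stand in for the index_nprobes / index_refine_factor config values (the refine default is an assumption, only the nprobes default of 20 is stated in the removed docstring):

    def calculate_search_params(count: int, limit: int,
                                base_nprobes: int = 20,
                                base_refine: int = 10) -> tuple[int, int]:
        """Scaling rules as documented in the removed Database method."""
        # nprobes grows with dataset size to preserve recall on large tables.
        if count < 100_000:
            nprobes = base_nprobes
        elif count < 1_000_000:
            nprobes = max(base_nprobes, 30)
        elif count < 10_000_000:
            nprobes = max(base_nprobes, 50)
        else:
            nprobes = max(base_nprobes, 100)

        # refine_factor shrinks as the requested limit grows.
        if limit <= 5:
            refine_factor = base_refine * 2      # small limits need more refinement
        elif limit <= 20:
            refine_factor = base_refine
        else:
            refine_factor = max(base_refine // 2, 2)

        return nprobes, refine_factor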
@@ -2232,19 +2368,16 @@ class Database:
2232
2368
  refine_factor: int | None = None,
2233
2369
  include_vector: bool = False,
2234
2370
  ) -> list[dict[str, Any]]:
2235
- """Search for similar memories by vector with performance tuning.
2371
+ """Search for similar memories by vector. Delegates to SearchManager.
2236
2372
 
2237
2373
  Args:
2238
2374
  query_vector: Query embedding vector.
2239
2375
  limit: Maximum number of results.
2240
2376
  namespace: Filter to specific namespace.
2241
2377
  min_similarity: Minimum similarity threshold (0-1).
2242
- nprobes: Number of partitions to search (higher = better recall).
2243
- Only effective when vector index exists. Defaults to dynamic calculation.
2378
+ nprobes: Number of partitions to search.
2244
2379
  refine_factor: Re-rank top (refine_factor * limit) for accuracy.
2245
- Defaults to dynamic calculation based on limit.
2246
2380
  include_vector: Whether to include vector embeddings in results.
2247
- Defaults to False to reduce response size.
2248
2381
 
2249
2382
  Returns:
2250
2383
  List of memory records with similarity scores.
@@ -2253,66 +2386,53 @@ class Database:
2253
2386
  ValidationError: If input validation fails.
2254
2387
  StorageError: If database operation fails.
2255
2388
  """
2256
- try:
2257
- search = self.table.search(query_vector.tolist())
2389
+ if self._search_manager is None:
2390
+ raise StorageError("Database not connected")
2391
+ return self._search_manager.vector_search(
2392
+ query_vector=query_vector,
2393
+ limit=limit,
2394
+ namespace=namespace,
2395
+ min_similarity=min_similarity,
2396
+ nprobes=nprobes,
2397
+ refine_factor=refine_factor,
2398
+ include_vector=include_vector,
2399
+ )
2258
2400
 
2259
- # Distance type for queries (cosine for semantic similarity)
2260
- # Note: When vector index exists, the index's metric is used
2261
- search = search.distance_type("cosine")
2401
+ @with_stale_connection_recovery
2402
+ @retry_on_storage_error(max_attempts=3, backoff=0.5)
2403
+ def batch_vector_search_native(
2404
+ self,
2405
+ query_vectors: list[np.ndarray],
2406
+ limit_per_query: int = 3,
2407
+ namespace: str | None = None,
2408
+ min_similarity: float = 0.0,
2409
+ include_vector: bool = False,
2410
+ ) -> list[list[dict[str, Any]]]:
2411
+ """Batch search using native LanceDB. Delegates to SearchManager.
2262
2412
 
2263
- # Apply performance tuning when index exists (use cached count)
2264
- count = self._get_cached_row_count()
2265
- if count > self.vector_index_threshold and self._has_vector_index:
2266
- # Use dynamic calculation for search params
2267
- actual_nprobes, actual_refine = self._calculate_search_params(
2268
- count, limit, nprobes, refine_factor
2269
- )
2270
- search = search.nprobes(actual_nprobes)
2271
- search = search.refine_factor(actual_refine)
2413
+ Args:
2414
+ query_vectors: List of query embedding vectors.
2415
+ limit_per_query: Maximum number of results per query.
2416
+ namespace: Filter to specific namespace.
2417
+ min_similarity: Minimum similarity threshold (0-1).
2418
+ include_vector: Whether to include vector embeddings in results.
2272
2419
 
2273
- # Build filter with sanitized namespace
2274
- # prefilter=True applies namespace filter BEFORE vector search for better performance
2275
- if namespace:
2276
- namespace = _validate_namespace(namespace)
2277
- safe_ns = _sanitize_string(namespace)
2278
- search = search.where(f"namespace = '{safe_ns}'", prefilter=True)
2279
-
2280
- # Vector projection: exclude vector column to reduce response size
2281
- if not include_vector:
2282
- search = search.select([
2283
- "id", "content", "namespace", "metadata",
2284
- "created_at", "updated_at", "last_accessed",
2285
- "importance", "tags", "source", "access_count",
2286
- "expires_at",
2287
- ])
2288
-
2289
- # Fetch extra if filtering by similarity
2290
- fetch_limit = limit * 2 if min_similarity > 0.0 else limit
2291
- results: list[dict[str, Any]] = search.limit(fetch_limit).to_list()
2292
-
2293
- # Process results
2294
- filtered_results: list[dict[str, Any]] = []
2295
- for record in results:
2296
- record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
2297
- # LanceDB returns _distance, convert to similarity
2298
- if "_distance" in record:
2299
- # Cosine distance to similarity: 1 - distance
2300
- # Clamp to [0, 1] (cosine distance can exceed 1 for unnormalized)
2301
- similarity = max(0.0, min(1.0, 1 - record["_distance"]))
2302
- record["similarity"] = similarity
2303
- del record["_distance"]
2304
-
2305
- # Apply similarity threshold
2306
- if record.get("similarity", 0) >= min_similarity:
2307
- filtered_results.append(record)
2308
- if len(filtered_results) >= limit:
2309
- break
2310
-
2311
- return filtered_results
2312
- except ValidationError:
2313
- raise
2314
- except Exception as e:
2315
- raise StorageError(f"Failed to search: {e}") from e
2420
+ Returns:
2421
+ List of result lists, one per query vector.
2422
+
2423
+ Raises:
2424
+ ValidationError: If input validation fails.
2425
+ StorageError: If database operation fails.
2426
+ """
2427
+ if self._search_manager is None:
2428
+ raise StorageError("Database not connected")
2429
+ return self._search_manager.batch_vector_search_native(
2430
+ query_vectors=query_vectors,
2431
+ limit_per_query=limit_per_query,
2432
+ namespace=namespace,
2433
+ min_similarity=min_similarity,
2434
+ include_vector=include_vector,
2435
+ )
2316
2436
 
2317
2437
  @with_stale_connection_recovery
2318
2438
  @retry_on_storage_error(max_attempts=3, backoff=0.5)
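Single-query and native batch search now share the same guard: both raise StorageError until connect() has set up the SearchManager. A minimal usage sketch, assuming `db` is a connected Database; the 384-dimension random vectors and the namespace name are stand-ins, real callers pass embedding vectors of whatever dimension the store was configured with:

    import numpy as np

    # Stand-in query vector; real callers pass embeddings from their model.
    query = np.random.rand(384).astype(np.float32)

    hits = db.vector_search(query, limit=5, namespace="work-notes", min_similarity=0.3)
    for hit in hits:
        print(f"{hit['similarity']:.2f}  {hit['content'][:60]}")

    # Native batch variant: one result list per query vector, in input order.
    per_query = db.batch_vector_search_native(
        [query, np.random.rand(384).astype(np.float32)],
        limit_per_query=3,
    )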
@@ -2325,10 +2445,7 @@ class Database:
2325
2445
  alpha: float = 0.5,
2326
2446
  min_similarity: float = 0.0,
2327
2447
  ) -> list[dict[str, Any]]:
2328
- """Hybrid search combining vector similarity and keyword matching.
2329
-
2330
- Uses LinearCombinationReranker to balance vector and keyword scores
2331
- based on the alpha parameter.
2448
+ """Hybrid search combining vector and keyword. Delegates to SearchManager.
2332
2449
 
2333
2450
  Args:
2334
2451
  query: Text query for full-text search.
@@ -2336,9 +2453,7 @@ class Database:
2336
2453
  limit: Number of results.
2337
2454
  namespace: Filter to namespace.
2338
2455
  alpha: Balance between vector (1.0) and keyword (0.0).
2339
- 0.5 = balanced (recommended).
2340
- min_similarity: Minimum similarity threshold (0.0-1.0).
2341
- Results below this threshold are filtered out.
2456
+ min_similarity: Minimum similarity threshold.
2342
2457
 
2343
2458
  Returns:
2344
2459
  List of memory records with combined scores.
@@ -2347,80 +2462,16 @@ class Database:
2347
2462
  ValidationError: If input validation fails.
2348
2463
  StorageError: If database operation fails.
2349
2464
  """
2350
- try:
2351
- # Check if FTS is available
2352
- if not self._has_fts_index:
2353
- logger.debug("FTS index not available, falling back to vector search")
2354
- return self.vector_search(query_vector, limit=limit, namespace=namespace)
2355
-
2356
- # Create hybrid search with explicit vector column specification
2357
- # Required when using external embeddings (not LanceDB built-in)
2358
- search = (
2359
- self.table.search(query, query_type="hybrid")
2360
- .vector(query_vector.tolist())
2361
- .vector_column_name("vector")
2362
- )
2363
-
2364
- # Apply alpha parameter using LinearCombinationReranker
2365
- # alpha=1.0 means full vector, alpha=0.0 means full FTS
2366
- try:
2367
- from lancedb.rerankers import LinearCombinationReranker
2368
-
2369
- reranker = LinearCombinationReranker(weight=alpha)
2370
- search = search.rerank(reranker)
2371
- except ImportError:
2372
- logger.debug("LinearCombinationReranker not available, using default reranking")
2373
- except Exception as e:
2374
- logger.debug(f"Could not apply reranker: {e}")
2375
-
2376
- # Apply namespace filter
2377
- if namespace:
2378
- namespace = _validate_namespace(namespace)
2379
- safe_ns = _sanitize_string(namespace)
2380
- search = search.where(f"namespace = '{safe_ns}'")
2381
-
2382
- results: list[dict[str, Any]] = search.limit(limit).to_list()
2383
-
2384
- # Process results - normalize scores and clean up internal columns
2385
- processed_results: list[dict[str, Any]] = []
2386
- for record in results:
2387
- record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
2388
-
2389
- # Compute similarity from various score columns
2390
- # Priority: _relevance_score > _distance > _score > default
2391
- similarity: float
2392
- if "_relevance_score" in record:
2393
- # Reranker output - use directly (already 0-1 range)
2394
- similarity = float(record["_relevance_score"])
2395
- del record["_relevance_score"]
2396
- elif "_distance" in record:
2397
- # Vector distance - convert to similarity
2398
- similarity = max(0.0, min(1.0, 1 - float(record["_distance"])))
2399
- del record["_distance"]
2400
- elif "_score" in record:
2401
- # BM25 score - normalize using score/(1+score)
2402
- score = float(record["_score"])
2403
- similarity = score / (1.0 + score)
2404
- del record["_score"]
2405
- else:
2406
- # No score column - use default
2407
- similarity = 0.5
2408
-
2409
- record["similarity"] = similarity
2410
-
2411
- # Mark as hybrid result with alpha value
2412
- record["search_type"] = "hybrid"
2413
- record["alpha"] = alpha
2414
-
2415
- # Apply min_similarity filter
2416
- if similarity >= min_similarity:
2417
- processed_results.append(record)
2418
-
2419
- return processed_results
2420
-
2421
- except Exception as e:
2422
- logger.warning(f"Hybrid search failed, falling back to vector search: {e}")
2423
- return self.vector_search(query_vector, limit=limit, namespace=namespace)
2465
+ if self._search_manager is None:
2466
+ raise StorageError("Database not connected")
2467
+ return self._search_manager.hybrid_search(
2468
+ query=query,
2469
+ query_vector=query_vector,
2470
+ limit=limit,
2471
+ namespace=namespace,
2472
+ alpha=alpha,
2473
+ min_similarity=min_similarity,
2474
+ )
2424
2475
 
2425
2476
  @with_stale_connection_recovery
2426
2477
  @retry_on_storage_error(max_attempts=3, backoff=0.5)
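The score handling the old inline hybrid_search applied to LanceDB's result columns is worth keeping in mind when reading hybrid results, since SearchManager presumably reproduces it: reranker output is used as-is, vector distance is inverted and clamped, and BM25 scores are squashed into 0-1. A rough restatement of that removed logic:

    def normalize_hybrid_score(record: dict) -> float:
        """Score normalization as it appeared in the removed inline hybrid_search."""
        if "_relevance_score" in record:
            # Reranker output is already in the 0-1 range.
            return float(record.pop("_relevance_score"))
        if "_distance" in record:
            # Cosine distance -> similarity, clamped to [0, 1].
            return max(0.0, min(1.0, 1.0 - float(record.pop("_distance"))))
        if "_score" in record:
            # BM25 score normalized with score / (1 + score).
            score = float(record.pop("_score"))
            return score / (1.0 + score)
        return 0.5  # no score column present

The removed method also tagged each record with search_type="hybrid" and the alpha used, and fell back to plain vector_search whenever the FTS index was missing or the hybrid path failed.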
@@ -2429,20 +2480,19 @@ class Database:
2429
2480
  query_vectors: list[np.ndarray],
2430
2481
  limit_per_query: int = 3,
2431
2482
  namespace: str | None = None,
2432
- parallel: bool = False,
2433
- max_workers: int = 4,
2483
+ parallel: bool = False, # Deprecated
2484
+ max_workers: int = 4, # Deprecated
2485
+ include_vector: bool = False,
2434
2486
  ) -> list[list[dict[str, Any]]]:
2435
- """Search for similar memories using multiple query vectors.
2436
-
2437
- Efficient for operations like journey interpolation where multiple
2438
- points need to find nearby memories.
2487
+ """Search using multiple query vectors. Delegates to SearchManager.
2439
2488
 
2440
2489
  Args:
2441
2490
  query_vectors: List of query embedding vectors.
2442
2491
  limit_per_query: Maximum results per query vector.
2443
2492
  namespace: Filter to specific namespace.
2444
- parallel: Execute searches in parallel using ThreadPoolExecutor.
2445
- max_workers: Maximum worker threads for parallel execution.
2493
+ parallel: Deprecated, kept for backward compatibility.
2494
+ max_workers: Deprecated, kept for backward compatibility.
2495
+ include_vector: Whether to include vector embeddings in results.
2446
2496
 
2447
2497
  Returns:
2448
2498
  List of result lists (one per query vector).
@@ -2450,52 +2500,16 @@ class Database:
2450
2500
  Raises:
2451
2501
  StorageError: If database operation fails.
2452
2502
  """
2453
- if not query_vectors:
2454
- return []
2455
-
2456
- # Build namespace filter once
2457
- where_clause: str | None = None
2458
- if namespace:
2459
- namespace = _validate_namespace(namespace)
2460
- safe_ns = _sanitize_string(namespace)
2461
- where_clause = f"namespace = '{safe_ns}'"
2462
-
2463
- def search_single(vec: np.ndarray) -> list[dict[str, Any]]:
2464
- """Execute a single vector search."""
2465
- search = self.table.search(vec.tolist()).distance_type("cosine")
2466
-
2467
- if where_clause:
2468
- search = search.where(where_clause)
2469
-
2470
- results: list[dict[str, Any]] = search.limit(limit_per_query).to_list()
2471
-
2472
- # Process results
2473
- for record in results:
2474
- meta = record["metadata"]
2475
- record["metadata"] = json.loads(meta) if meta else {}
2476
- if "_distance" in record:
2477
- record["similarity"] = max(0.0, min(1.0, 1 - record["_distance"]))
2478
- del record["_distance"]
2479
-
2480
- return results
2481
-
2482
- try:
2483
- if parallel and len(query_vectors) > 1:
2484
- # Use ThreadPoolExecutor for parallel execution
2485
- from concurrent.futures import ThreadPoolExecutor
2486
-
2487
- workers = min(max_workers, len(query_vectors))
2488
- with ThreadPoolExecutor(max_workers=workers) as executor:
2489
- # Map preserves order
2490
- all_results = list(executor.map(search_single, query_vectors))
2491
- else:
2492
- # Sequential execution
2493
- all_results = [search_single(vec) for vec in query_vectors]
2494
-
2495
- return all_results
2496
-
2497
- except Exception as e:
2498
- raise StorageError(f"Batch vector search failed: {e}") from e
2503
+ if self._search_manager is None:
2504
+ raise StorageError("Database not connected")
2505
+ return self._search_manager.batch_vector_search(
2506
+ query_vectors=query_vectors,
2507
+ limit_per_query=limit_per_query,
2508
+ namespace=namespace,
2509
+ parallel=parallel,
2510
+ max_workers=max_workers,
2511
+ include_vector=include_vector,
2512
+ )
2499
2513
 
2500
2514
  def get_vectors_for_clustering(
2501
2515
  self,
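batch_vector_search keeps its old signature, so existing callers that pass parallel or max_workers still run; per the new docstring those arguments are only kept for backward compatibility now that the ThreadPoolExecutor path is gone. A minimal sketch, assuming `db` is connected (vectors are stand-ins as before):

    import numpy as np

    points = [np.random.rand(384).astype(np.float32) for _ in range(4)]  # stand-ins

    # parallel/max_workers are accepted but deprecated; results still come back
    # as one list per query vector, in the order the vectors were given.
    per_point = db.batch_vector_search(points, limit_per_query=3, parallel=True)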
@@ -2941,6 +2955,7 @@ class Database:
2941
2955
 
2942
2956
  if deleted > 0:
2943
2957
  self._invalidate_count_cache()
2958
+ self._track_modification(deleted)
2944
2959
  self._invalidate_namespace_cache()
2945
2960
  logger.info(f"Cleaned up {deleted} expired memories")
2946
2961
 
@@ -2949,148 +2964,58 @@ class Database:
2949
2964
  raise StorageError(f"Failed to cleanup expired memories: {e}") from e
2950
2965
 
2951
2966
  # ========================================================================
2952
- # Snapshot / Version Management
2967
+ # Snapshot / Version Management (delegated to VersionManager)
2953
2968
  # ========================================================================
2954
2969
 
2955
2970
  def create_snapshot(self, tag: str) -> int:
2956
2971
  """Create a named snapshot of the current table state.
2957
2972
 
2958
- LanceDB automatically versions data on every write. This method
2959
- returns the current version number which can be used with restore_snapshot().
2960
-
2961
- Args:
2962
- tag: Semantic version tag (e.g., "v1.0.0", "backup-2024-01").
2963
- Note: Tag is logged for reference but LanceDB tracks versions
2964
- numerically. Consider storing tag->version mappings externally
2965
- if tag-based retrieval is needed.
2966
-
2967
- Returns:
2968
- Version number of the snapshot.
2969
-
2970
- Raises:
2971
- StorageError: If snapshot creation fails.
2973
+ Delegates to VersionManager. See VersionManager.create_snapshot for details.
2972
2974
  """
2973
- try:
2974
- version = self.table.version
2975
- logger.info(f"Created snapshot '{tag}' at version {version}")
2976
- return version
2977
- except Exception as e:
2978
- raise StorageError(f"Failed to create snapshot: {e}") from e
2975
+ if self._version_manager is None:
2976
+ raise StorageError("Database not connected")
2977
+ return self._version_manager.create_snapshot(tag)
2979
2978
 
2980
2979
  def list_snapshots(self) -> list[dict[str, Any]]:
2981
2980
  """List available versions/snapshots.
2982
2981
 
2983
- Returns:
2984
- List of version information dictionaries. Each dict contains
2985
- at minimum 'version' key. Additional fields depend on LanceDB
2986
- version and available metadata.
2987
-
2988
- Raises:
2989
- StorageError: If listing fails.
2982
+ Delegates to VersionManager. See VersionManager.list_snapshots for details.
2990
2983
  """
2991
- try:
2992
- versions_info: list[dict[str, Any]] = []
2993
-
2994
- # Try to get version history if available
2995
- if hasattr(self.table, "list_versions"):
2996
- try:
2997
- versions = self.table.list_versions()
2998
- for v in versions:
2999
- if isinstance(v, dict):
3000
- versions_info.append(v)
3001
- elif hasattr(v, "version"):
3002
- versions_info.append({
3003
- "version": v.version,
3004
- "timestamp": getattr(v, "timestamp", None),
3005
- })
3006
- else:
3007
- versions_info.append({"version": v})
3008
- except Exception as e:
3009
- logger.debug(f"list_versions not fully supported: {e}")
3010
-
3011
- # Always include current version
3012
- if not versions_info:
3013
- versions_info.append({"version": self.table.version})
3014
-
3015
- return versions_info
3016
- except Exception as e:
3017
- logger.warning(f"Could not list snapshots: {e}")
3018
- return [{"version": 0, "error": str(e)}]
2984
+ if self._version_manager is None:
2985
+ raise StorageError("Database not connected")
2986
+ return self._version_manager.list_snapshots()
3019
2987
 
3020
2988
  def restore_snapshot(self, version: int) -> None:
3021
2989
  """Restore table to a specific version.
3022
2990
 
3023
- This creates a NEW version that reflects the old state
3024
- (doesn't delete history).
3025
-
3026
- Args:
3027
- version: The version number to restore to.
3028
-
3029
- Raises:
3030
- ValidationError: If version is invalid.
3031
- StorageError: If restore fails.
2991
+ Delegates to VersionManager. See VersionManager.restore_snapshot for details.
3032
2992
  """
3033
- if version < 0:
3034
- raise ValidationError("Version must be non-negative")
3035
-
3036
- try:
3037
- self.table.restore(version)
3038
- self._invalidate_count_cache()
3039
- self._invalidate_namespace_cache()
3040
- logger.info(f"Restored to version {version}")
3041
- except Exception as e:
3042
- raise StorageError(f"Failed to restore snapshot: {e}") from e
2993
+ if self._version_manager is None:
2994
+ raise StorageError("Database not connected")
2995
+ self._version_manager.restore_snapshot(version)
3043
2996
 
3044
2997
  def get_current_version(self) -> int:
3045
2998
  """Get the current table version number.
3046
2999
 
3047
- Returns:
3048
- Current version number.
3049
-
3050
- Raises:
3051
- StorageError: If version cannot be retrieved.
3000
+ Delegates to VersionManager. See VersionManager.get_current_version for details.
3052
3001
  """
3053
- try:
3054
- return self.table.version
3055
- except Exception as e:
3056
- raise StorageError(f"Failed to get current version: {e}") from e
3002
+ if self._version_manager is None:
3003
+ raise StorageError("Database not connected")
3004
+ return self._version_manager.get_current_version()
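Per the docstrings that used to live here, the snapshot API is a thin layer over LanceDB's numeric versioning: create_snapshot records and returns the current version (the tag is only logged), and restore_snapshot writes a new version mirroring the old state rather than rewinding history. A minimal sketch, assuming a connected `db`:

    version = db.create_snapshot("pre-migration-backup")   # illustrative tag

    # ... writes happen here ...

    print(db.list_snapshots())      # list of dicts, each with at least a "version" key
    db.restore_snapshot(version)    # creates a new version reflecting the old state
    assert db.get_current_version() >= version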
3057
3005
 
3058
3006
  # ========================================================================
3059
- # Idempotency Key Management
3007
+ # Idempotency Key Management (delegates to IdempotencyManager)
3060
3008
  # ========================================================================
3061
3009
 
3062
- def _ensure_idempotency_table(self) -> None:
3063
- """Ensure the idempotency keys table exists."""
3064
- if self._db is None:
3065
- raise StorageError("Database not connected")
3066
-
3067
- existing_tables_result = self._db.list_tables()
3068
- if hasattr(existing_tables_result, 'tables'):
3069
- existing_tables = existing_tables_result.tables
3070
- else:
3071
- existing_tables = existing_tables_result
3072
-
3073
- if "idempotency_keys" not in existing_tables:
3074
- schema = pa.schema([
3075
- pa.field("key", pa.string()),
3076
- pa.field("memory_id", pa.string()),
3077
- pa.field("created_at", pa.timestamp("us")),
3078
- pa.field("expires_at", pa.timestamp("us")),
3079
- ])
3080
- self._db.create_table("idempotency_keys", schema=schema)
3081
- logger.info("Created idempotency_keys table")
3082
-
3083
3010
  @property
3084
3011
  def idempotency_table(self) -> LanceTable:
3085
- """Get the idempotency keys table, creating if needed."""
3086
- if self._db is None:
3087
- self.connect()
3088
- self._ensure_idempotency_table()
3089
- assert self._db is not None
3090
- return self._db.open_table("idempotency_keys")
3012
+ """Get the idempotency keys table. Delegates to IdempotencyManager."""
3013
+ if self._idempotency_manager is None:
3014
+ raise StorageError("Database not connected")
3015
+ return self._idempotency_manager.idempotency_table
3091
3016
 
3092
3017
  def get_by_idempotency_key(self, key: str) -> IdempotencyRecord | None:
3093
- """Look up an idempotency record by key.
3018
+ """Look up an idempotency record by key. Delegates to IdempotencyManager.
3094
3019
 
3095
3020
  Args:
3096
3021
  key: The idempotency key to look up.
@@ -3101,42 +3026,9 @@ class Database:
3101
3026
  Raises:
3102
3027
  StorageError: If database operation fails.
3103
3028
  """
3104
- if not key:
3105
- return None
3106
-
3107
- try:
3108
- safe_key = _sanitize_string(key)
3109
- results = (
3110
- self.idempotency_table.search()
3111
- .where(f"key = '{safe_key}'")
3112
- .limit(1)
3113
- .to_list()
3114
- )
3115
-
3116
- if not results:
3117
- return None
3118
-
3119
- record = results[0]
3120
- now = utc_now()
3121
-
3122
- # Check if expired (convert DB naive datetime to aware for comparison)
3123
- expires_at = record.get("expires_at")
3124
- if expires_at is not None:
3125
- expires_at_aware = to_aware_utc(expires_at)
3126
- if expires_at_aware < now:
3127
- # Expired - clean it up and return None
3128
- logger.debug(f"Idempotency key '{key}' has expired")
3129
- return None
3130
-
3131
- return IdempotencyRecord(
3132
- key=record["key"],
3133
- memory_id=record["memory_id"],
3134
- created_at=record["created_at"],
3135
- expires_at=record["expires_at"],
3136
- )
3137
-
3138
- except Exception as e:
3139
- raise StorageError(f"Failed to look up idempotency key: {e}") from e
3029
+ if self._idempotency_manager is None:
3030
+ raise StorageError("Database not connected")
3031
+ return self._idempotency_manager.get_by_idempotency_key(key)
3140
3032
 
3141
3033
  @with_process_lock
3142
3034
  @with_write_lock
@@ -3146,7 +3038,7 @@ class Database:
3146
3038
  memory_id: str,
3147
3039
  ttl_hours: float = 24.0,
3148
3040
  ) -> None:
3149
- """Store an idempotency key mapping.
3041
+ """Store an idempotency key mapping. Delegates to IdempotencyManager.
3150
3042
 
3151
3043
  Args:
3152
3044
  key: The idempotency key.
@@ -3157,36 +3049,14 @@ class Database:
3157
3049
  ValidationError: If inputs are invalid.
3158
3050
  StorageError: If database operation fails.
3159
3051
  """
3160
- if not key:
3161
- raise ValidationError("Idempotency key cannot be empty")
3162
- if not memory_id:
3163
- raise ValidationError("Memory ID cannot be empty")
3164
- if ttl_hours <= 0:
3165
- raise ValidationError("TTL must be positive")
3166
-
3167
- now = utc_now()
3168
- expires_at = now + timedelta(hours=ttl_hours)
3169
-
3170
- record = {
3171
- "key": key,
3172
- "memory_id": memory_id,
3173
- "created_at": now,
3174
- "expires_at": expires_at,
3175
- }
3176
-
3177
- try:
3178
- self.idempotency_table.add([record])
3179
- logger.debug(
3180
- f"Stored idempotency key '{key}' -> memory '{memory_id}' "
3181
- f"(expires in {ttl_hours}h)"
3182
- )
3183
- except Exception as e:
3184
- raise StorageError(f"Failed to store idempotency key: {e}") from e
3052
+ if self._idempotency_manager is None:
3053
+ raise StorageError("Database not connected")
3054
+ self._idempotency_manager.store_idempotency_key(key, memory_id, ttl_hours)
3185
3055
 
3186
3056
  @with_process_lock
3187
3057
  @with_write_lock
3188
3058
  def cleanup_expired_idempotency_keys(self) -> int:
3189
- """Remove expired idempotency keys.
3059
+ """Remove expired idempotency keys. Delegates to IdempotencyManager.
3190
3060
 
3191
3061
  Returns:
3192
3062
  Number of keys removed.
@@ -3194,20 +3064,6 @@ class Database:
3194
3064
  Raises:
3195
3065
  StorageError: If cleanup fails.
3196
3066
  """
3197
- try:
3198
- now = utc_now()
3199
- count_before = self.idempotency_table.count_rows()
3200
-
3201
- # Delete expired keys
3202
- predicate = f"expires_at < timestamp '{now.isoformat()}'"
3203
- self.idempotency_table.delete(predicate)
3204
-
3205
- count_after = self.idempotency_table.count_rows()
3206
- deleted = count_before - count_after
3207
-
3208
- if deleted > 0:
3209
- logger.info(f"Cleaned up {deleted} expired idempotency keys")
3210
-
3211
- return deleted
3212
- except Exception as e:
3213
- raise StorageError(f"Failed to cleanup idempotency keys: {e}") from e
3067
+ if self._idempotency_manager is None:
3068
+ raise StorageError("Database not connected")
3069
+ return self._idempotency_manager.cleanup_expired_idempotency_keys()
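Taken together, the delegated helpers still support the flow the removed bodies implemented: a key maps to a memory_id with a TTL (24 hours by default), an expired key reads back as a miss, and expired rows can be swept in bulk. A minimal end-to-end sketch, assuming `db` is connected; `store_memory` is a hypothetical stand-in for whatever insert path yields a memory id:

    key = "client-request-8f2d"                 # illustrative idempotency key

    existing = db.get_by_idempotency_key(key)
    if existing is not None:
        memory_id = existing.memory_id          # duplicate request: reuse prior result
    else:
        memory_id = store_memory("...")         # hypothetical insert producing an id
        db.store_idempotency_key(key, memory_id, ttl_hours=24.0)

    # Periodic maintenance: drop keys whose TTL has elapsed.
    removed = db.cleanup_expired_idempotency_keys()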