spatial-memory-mcp 1.0.3-py3-none-any.whl → 1.6.0-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

This version of spatial-memory-mcp might be problematic.

Files changed (39)
  1. spatial_memory/__init__.py +97 -97
  2. spatial_memory/__main__.py +241 -2
  3. spatial_memory/adapters/lancedb_repository.py +74 -5
  4. spatial_memory/config.py +115 -2
  5. spatial_memory/core/__init__.py +35 -0
  6. spatial_memory/core/cache.py +317 -0
  7. spatial_memory/core/circuit_breaker.py +297 -0
  8. spatial_memory/core/connection_pool.py +41 -3
  9. spatial_memory/core/consolidation_strategies.py +402 -0
  10. spatial_memory/core/database.py +791 -769
  11. spatial_memory/core/db_idempotency.py +242 -0
  12. spatial_memory/core/db_indexes.py +575 -0
  13. spatial_memory/core/db_migrations.py +584 -0
  14. spatial_memory/core/db_search.py +509 -0
  15. spatial_memory/core/db_versioning.py +177 -0
  16. spatial_memory/core/embeddings.py +156 -19
  17. spatial_memory/core/errors.py +75 -3
  18. spatial_memory/core/filesystem.py +178 -0
  19. spatial_memory/core/logging.py +194 -103
  20. spatial_memory/core/models.py +4 -0
  21. spatial_memory/core/rate_limiter.py +326 -105
  22. spatial_memory/core/response_types.py +497 -0
  23. spatial_memory/core/tracing.py +300 -0
  24. spatial_memory/core/validation.py +403 -319
  25. spatial_memory/factory.py +407 -0
  26. spatial_memory/migrations/__init__.py +40 -0
  27. spatial_memory/ports/repositories.py +52 -2
  28. spatial_memory/server.py +329 -188
  29. spatial_memory/services/export_import.py +61 -43
  30. spatial_memory/services/lifecycle.py +397 -122
  31. spatial_memory/services/memory.py +81 -4
  32. spatial_memory/services/spatial.py +129 -46
  33. spatial_memory/tools/definitions.py +695 -671
  34. {spatial_memory_mcp-1.0.3.dist-info → spatial_memory_mcp-1.6.0.dist-info}/METADATA +83 -3
  35. spatial_memory_mcp-1.6.0.dist-info/RECORD +54 -0
  36. spatial_memory_mcp-1.0.3.dist-info/RECORD +0 -41
  37. {spatial_memory_mcp-1.0.3.dist-info → spatial_memory_mcp-1.6.0.dist-info}/WHEEL +0 -0
  38. {spatial_memory_mcp-1.0.3.dist-info → spatial_memory_mcp-1.6.0.dist-info}/entry_points.txt +0 -0
  39. {spatial_memory_mcp-1.0.3.dist-info → spatial_memory_mcp-1.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -34,8 +34,21 @@ import pyarrow.parquet as pq
34
34
  from filelock import FileLock, Timeout as FileLockTimeout
35
35
 
36
36
  from spatial_memory.core.connection_pool import ConnectionPool
37
- from spatial_memory.core.errors import FileLockError, MemoryNotFoundError, StorageError, ValidationError
38
- from spatial_memory.core.utils import utc_now
37
+ from spatial_memory.core.db_idempotency import IdempotencyManager, IdempotencyRecord
38
+ from spatial_memory.core.db_indexes import IndexManager
39
+ from spatial_memory.core.db_migrations import CURRENT_SCHEMA_VERSION, MigrationManager
40
+ from spatial_memory.core.db_search import SearchManager
41
+ from spatial_memory.core.db_versioning import VersionManager
42
+ from spatial_memory.core.errors import (
43
+ DimensionMismatchError,
44
+ FileLockError,
45
+ MemoryNotFoundError,
46
+ PartialBatchInsertError,
47
+ StorageError,
48
+ ValidationError,
49
+ )
50
+ from spatial_memory.core.filesystem import detect_filesystem_type, get_filesystem_warning_message, is_network_filesystem
51
+ from spatial_memory.core.utils import to_aware_utc, utc_now
39
52
 
40
53
  # Import centralized validation functions
41
54
  from spatial_memory.core.validation import (
@@ -131,9 +144,14 @@ def invalidate_connection(storage_path: Path) -> bool:
131
144
  # Retry Decorator
132
145
  # ============================================================================
133
146
 
147
+ # Default retry settings (can be overridden per-call)
148
+ DEFAULT_RETRY_MAX_ATTEMPTS = 3
149
+ DEFAULT_RETRY_BACKOFF_SECONDS = 0.5
150
+
151
+
134
152
  def retry_on_storage_error(
135
- max_attempts: int = 3,
136
- backoff: float = 0.5,
153
+ max_attempts: int = DEFAULT_RETRY_MAX_ATTEMPTS,
154
+ backoff: float = DEFAULT_RETRY_BACKOFF_SECONDS,
137
155
  ) -> Callable[[F], F]:
138
156
  """Retry decorator for transient storage errors.
139
157
 
@@ -483,8 +501,8 @@ class Database:
483
501
  enable_fts: bool = True,
484
502
  index_nprobes: int = 20,
485
503
  index_refine_factor: int = 5,
486
- max_retry_attempts: int = 3,
487
- retry_backoff_seconds: float = 0.5,
504
+ max_retry_attempts: int = DEFAULT_RETRY_MAX_ATTEMPTS,
505
+ retry_backoff_seconds: float = DEFAULT_RETRY_BACKOFF_SECONDS,
488
506
  read_consistency_interval_ms: int = 0,
489
507
  index_wait_timeout_seconds: float = 30.0,
490
508
  fts_stem: bool = True,
@@ -498,6 +516,7 @@ class Database:
498
516
  filelock_enabled: bool = True,
499
517
  filelock_timeout: float = 30.0,
500
518
  filelock_poll_interval: float = 0.1,
519
+ acknowledge_network_filesystem_risk: bool = False,
501
520
  ) -> None:
502
521
  """Initialize the database connection.
503
522
 
@@ -524,6 +543,7 @@ class Database:
524
543
  hnsw_ef_construction: HNSW build-time search width (100-1000).
525
544
  enable_memory_expiration: Enable automatic memory expiration.
526
545
  default_memory_ttl_days: Default TTL for memories in days (None = no expiration).
546
+ acknowledge_network_filesystem_risk: Suppress network filesystem warnings.
527
547
  """
528
548
  self.storage_path = Path(storage_path)
529
549
  self.embedding_dim = embedding_dim
@@ -547,6 +567,7 @@ class Database:
547
567
  self.filelock_enabled = filelock_enabled
548
568
  self.filelock_timeout = filelock_timeout
549
569
  self.filelock_poll_interval = filelock_poll_interval
570
+ self.acknowledge_network_filesystem_risk = acknowledge_network_filesystem_risk
550
571
  self._db: lancedb.DBConnection | None = None
551
572
  self._table: LanceTable | None = None
552
573
  self._has_vector_index: bool | None = None
@@ -564,6 +585,18 @@ class Database:
564
585
  self._write_lock = threading.RLock()
565
586
  # Cross-process lock (initialized in connect())
566
587
  self._process_lock: ProcessLockManager | None = None
588
+ # Auto-compaction tracking
589
+ self._modification_count: int = 0
590
+ self._auto_compaction_threshold: int = 100 # Compact after this many modifications
591
+ self._auto_compaction_enabled: bool = True
592
+ # Version manager (initialized in connect())
593
+ self._version_manager: VersionManager | None = None
594
+ # Index manager (initialized in connect())
595
+ self._index_manager: IndexManager | None = None
596
+ # Search manager (initialized in connect())
597
+ self._search_manager: SearchManager | None = None
598
+ # Idempotency manager (initialized in connect())
599
+ self._idempotency_manager: IdempotencyManager | None = None
567
600
 
568
601
  def __enter__(self) -> Database:
569
602
  """Enter context manager."""
@@ -579,6 +612,13 @@ class Database:
579
612
  try:
580
613
  self.storage_path.mkdir(parents=True, exist_ok=True)
581
614
 
615
+ # Check for network filesystem and warn if detected
616
+ if not self.acknowledge_network_filesystem_risk:
617
+ if is_network_filesystem(self.storage_path):
618
+ fs_type = detect_filesystem_type(self.storage_path)
619
+ warning_msg = get_filesystem_warning_message(fs_type, self.storage_path)
620
+ logger.warning(warning_msg)
621
+
582
622
  # Initialize cross-process lock manager
583
623
  if self.filelock_enabled:
584
624
  lock_path = self.storage_path / ".spatial-memory.lock"
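The check above only warns when the storage path sits on a network filesystem; the connection still proceeds, and the warning can be suppressed with the new `acknowledge_network_filesystem_risk` constructor flag. A hedged usage sketch — the import path is inferred from this file's location, and all other constructor arguments keep their defaults:

```python
from spatial_memory.core.database import Database  # module path inferred from this file

# Acknowledging the risk suppresses the network-filesystem warning at connect().
db = Database(
    storage_path="/mnt/nfs/spatial-memory",
    embedding_dim=384,
    acknowledge_network_filesystem_risk=True,
)
db.connect()
```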
@@ -597,109 +637,144 @@ class Database:
597
637
  read_consistency_interval_ms=self.read_consistency_interval_ms,
598
638
  )
599
639
  self._ensure_table()
640
+ # Initialize remaining managers (IndexManager already initialized in _ensure_table)
641
+ self._version_manager = VersionManager(self)
642
+ self._search_manager = SearchManager(self)
643
+ self._idempotency_manager = IdempotencyManager(self)
600
644
  logger.info(f"Connected to LanceDB at {self.storage_path}")
645
+
646
+ # Check for pending schema migrations
647
+ self._check_pending_migrations()
601
648
  except Exception as e:
602
649
  raise StorageError(f"Failed to connect to database: {e}") from e
603
650
 
651
+ def _check_pending_migrations(self) -> None:
652
+ """Check for pending migrations and warn if any exist.
653
+
654
+ This method checks the schema version and logs a warning if there
655
+ are pending migrations. It does not auto-apply migrations - that
656
+ requires explicit user action via the CLI.
657
+ """
658
+ try:
659
+ manager = MigrationManager(self, embeddings=None)
660
+ manager.register_builtin_migrations()
661
+
662
+ current_version = manager.get_current_version()
663
+ pending = manager.get_pending_migrations()
664
+
665
+ if pending:
666
+ pending_versions = [m.version for m in pending]
667
+ logger.warning(
668
+ f"Database schema version {current_version} is outdated. "
669
+ f"{len(pending)} migration(s) pending: {', '.join(pending_versions)}. "
670
+ f"Target version: {CURRENT_SCHEMA_VERSION}. "
671
+ f"Run 'spatial-memory migrate' to apply migrations."
672
+ )
673
+ except Exception as e:
674
+ # Don't fail connection due to migration check errors
675
+ logger.debug(f"Migration check skipped: {e}")
676
+
604
677
  def _ensure_table(self) -> None:
605
- """Ensure the memories table exists with appropriate indexes."""
678
+ """Ensure the memories table exists with appropriate indexes.
679
+
680
+ Uses retry logic to handle race conditions when multiple processes
681
+ attempt to create/open the table simultaneously.
682
+ """
606
683
  if self._db is None:
607
684
  raise StorageError("Database not connected")
608
685
 
609
- existing_tables_result = self._db.list_tables()
610
- # Handle both old (list) and new (object with .tables) LanceDB API
611
- if hasattr(existing_tables_result, 'tables'):
612
- existing_tables = existing_tables_result.tables
613
- else:
614
- existing_tables = existing_tables_result
615
- if "memories" not in existing_tables:
616
- # Create table with schema
617
- schema = pa.schema([
618
- pa.field("id", pa.string()),
619
- pa.field("content", pa.string()),
620
- pa.field("vector", pa.list_(pa.float32(), self.embedding_dim)),
621
- pa.field("created_at", pa.timestamp("us")),
622
- pa.field("updated_at", pa.timestamp("us")),
623
- pa.field("last_accessed", pa.timestamp("us")),
624
- pa.field("access_count", pa.int32()),
625
- pa.field("importance", pa.float32()),
626
- pa.field("namespace", pa.string()),
627
- pa.field("tags", pa.list_(pa.string())),
628
- pa.field("source", pa.string()),
629
- pa.field("metadata", pa.string()),
630
- pa.field("expires_at", pa.timestamp("us")), # TTL support - nullable
631
- ])
632
- self._table = self._db.create_table("memories", schema=schema)
633
- logger.info("Created memories table")
634
-
635
- # Create FTS index on new table if enabled
636
- if self.enable_fts:
637
- self._create_fts_index()
638
- else:
639
- self._table = self._db.open_table("memories")
640
- logger.debug("Opened existing memories table")
686
+ max_retries = 3
687
+ retry_delay = 0.1 # Start with 100ms
641
688
 
642
- # Check existing indexes
643
- self._check_existing_indexes()
689
+ for attempt in range(max_retries):
690
+ try:
691
+ existing_tables_result = self._db.list_tables()
692
+ # Handle both old (list) and new (object with .tables) LanceDB API
693
+ if hasattr(existing_tables_result, 'tables'):
694
+ existing_tables = existing_tables_result.tables
695
+ else:
696
+ existing_tables = existing_tables_result
697
+
698
+ if "memories" not in existing_tables:
699
+ # Create table with schema
700
+ schema = pa.schema([
701
+ pa.field("id", pa.string()),
702
+ pa.field("content", pa.string()),
703
+ pa.field("vector", pa.list_(pa.float32(), self.embedding_dim)),
704
+ pa.field("created_at", pa.timestamp("us")),
705
+ pa.field("updated_at", pa.timestamp("us")),
706
+ pa.field("last_accessed", pa.timestamp("us")),
707
+ pa.field("access_count", pa.int32()),
708
+ pa.field("importance", pa.float32()),
709
+ pa.field("namespace", pa.string()),
710
+ pa.field("tags", pa.list_(pa.string())),
711
+ pa.field("source", pa.string()),
712
+ pa.field("metadata", pa.string()),
713
+ pa.field("expires_at", pa.timestamp("us")), # TTL support - nullable
714
+ ])
715
+ try:
716
+ self._table = self._db.create_table("memories", schema=schema)
717
+ logger.info("Created memories table")
718
+ except Exception as create_err:
719
+ # Table might have been created by another process
720
+ if "already exists" in str(create_err).lower():
721
+ logger.debug("Table created by another process, opening it")
722
+ self._table = self._db.open_table("memories")
723
+ else:
724
+ raise
644
725
 
645
- def _check_existing_indexes(self) -> None:
646
- """Check which indexes already exist using robust detection."""
647
- try:
648
- indices = self.table.list_indices()
726
+ # Initialize IndexManager immediately after table is set
727
+ self._index_manager = IndexManager(self)
649
728
 
650
- self._has_vector_index = False
651
- self._has_fts_index = False
729
+ # Create FTS index on new table if enabled
730
+ if self.enable_fts:
731
+ self._index_manager.create_fts_index()
732
+ else:
733
+ self._table = self._db.open_table("memories")
734
+ logger.debug("Opened existing memories table")
652
735
 
653
- for idx in indices:
654
- index_name = str(_get_index_attr(idx, "name", "")).lower()
655
- index_type = str(_get_index_attr(idx, "index_type", "")).upper()
656
- columns = _get_index_attr(idx, "columns", [])
736
+ # Initialize IndexManager immediately after table is set
737
+ self._index_manager = IndexManager(self)
657
738
 
658
- # Vector index detection: check index_type or column name
659
- if index_type in VECTOR_INDEX_TYPES:
660
- self._has_vector_index = True
661
- elif "vector" in columns or "vector" in index_name:
662
- self._has_vector_index = True
739
+ # Check existing indexes
740
+ self._index_manager.check_existing_indexes()
663
741
 
664
- # FTS index detection: check index_type or name patterns
665
- if index_type == "FTS":
666
- self._has_fts_index = True
667
- elif "fts" in index_name or "content" in index_name:
668
- self._has_fts_index = True
742
+ # Success - exit retry loop
743
+ return
669
744
 
670
- logger.debug(
671
- f"Existing indexes: vector={self._has_vector_index}, "
672
- f"fts={self._has_fts_index}"
673
- )
674
- except Exception as e:
675
- logger.warning(f"Could not check existing indexes: {e}")
676
- self._has_vector_index = None
677
- self._has_fts_index = None
745
+ except Exception as e:
746
+ error_msg = str(e).lower()
747
+ # Retry on transient race conditions
748
+ if attempt < max_retries - 1 and (
749
+ "not found" in error_msg
750
+ or "does not exist" in error_msg
751
+ or "already exists" in error_msg
752
+ ):
753
+ logger.debug(
754
+ f"Table operation failed (attempt {attempt + 1}/{max_retries}), "
755
+ f"retrying in {retry_delay}s: {e}"
756
+ )
757
+ time.sleep(retry_delay)
758
+ retry_delay *= 2 # Exponential backoff
759
+ else:
760
+ raise
761
+
762
+ def _check_existing_indexes(self) -> None:
763
+ """Check which indexes already exist. Delegates to IndexManager."""
764
+ if self._index_manager is None:
765
+ raise StorageError("Database not connected")
766
+ self._index_manager.check_existing_indexes()
767
+ # Sync local state for backward compatibility
768
+ self._has_vector_index = self._index_manager.has_vector_index
769
+ self._has_fts_index = self._index_manager.has_fts_index
678
770
 
679
771
  def _create_fts_index(self) -> None:
680
- """Create full-text search index with optimized settings."""
681
- try:
682
- self.table.create_fts_index(
683
- "content",
684
- use_tantivy=False, # Use Lance native FTS
685
- language=self.fts_language,
686
- stem=self.fts_stem,
687
- remove_stop_words=self.fts_remove_stop_words,
688
- with_position=True, # Enable phrase queries
689
- lower_case=True, # Case-insensitive search
690
- )
691
- self._has_fts_index = True
692
- logger.info(
693
- f"Created FTS index with stemming={self.fts_stem}, "
694
- f"stop_words={self.fts_remove_stop_words}"
695
- )
696
- except Exception as e:
697
- # Check if index already exists (not an error)
698
- if "already exists" in str(e).lower():
699
- self._has_fts_index = True
700
- logger.debug("FTS index already exists")
701
- else:
702
- logger.warning(f"FTS index creation failed: {e}")
772
+ """Create FTS index. Delegates to IndexManager."""
773
+ if self._index_manager is None:
774
+ raise StorageError("Database not connected")
775
+ self._index_manager.create_fts_index()
776
+ # Sync local state for backward compatibility
777
+ self._has_fts_index = self._index_manager.has_fts_index
703
778
 
704
779
  @property
705
780
  def table(self) -> LanceTable:
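`connect()` now wires up the version, search, and idempotency managers and then calls `_check_pending_migrations()`, which only warns and points at the `spatial-memory migrate` CLI command. For a programmatic check, the calls visible in this hunk suggest a sketch like the following (illustrative; `db` is assumed to be a connected `Database`):

```python
from spatial_memory.core.db_migrations import CURRENT_SCHEMA_VERSION, MigrationManager


def report_pending_migrations(db) -> None:
    # Mirrors Database._check_pending_migrations() above, minus the logging.
    manager = MigrationManager(db, embeddings=None)
    manager.register_builtin_migrations()
    pending = manager.get_pending_migrations()
    print(f"schema version {manager.get_current_version()} (target {CURRENT_SCHEMA_VERSION})")
    for migration in pending:
        print(f"pending migration: {migration.version}")
```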
@@ -710,18 +785,30 @@ class Database:
710
785
  return self._table
711
786
 
712
787
  def close(self) -> None:
713
- """Close the database connection (connection remains pooled)."""
788
+ """Close the database connection and remove from pool.
789
+
790
+ This invalidates the pooled connection so that subsequent
791
+ Database instances will create fresh connections.
792
+ """
793
+ # Invalidate pooled connection first
794
+ invalidate_connection(self.storage_path)
795
+
796
+ # Clear local state
714
797
  self._table = None
715
798
  self._db = None
716
799
  self._has_vector_index = None
717
800
  self._has_fts_index = None
801
+ self._version_manager = None
802
+ self._index_manager = None
803
+ self._search_manager = None
804
+ self._idempotency_manager = None
718
805
  with self._cache_lock:
719
806
  self._cached_row_count = None
720
807
  self._count_cache_time = 0.0
721
808
  with self._namespace_cache_lock:
722
809
  self._cached_namespaces = None
723
810
  self._namespace_cache_time = 0.0
724
- logger.debug("Database connection closed")
811
+ logger.debug("Database connection closed and removed from pool")
725
812
 
726
813
  def reconnect(self) -> None:
727
814
  """Invalidate cached connection and reconnect.
@@ -781,313 +868,86 @@ class Database:
781
868
  self._cached_namespaces = None
782
869
  self._namespace_cache_time = 0.0
783
870
 
784
- # ========================================================================
785
- # Index Management
786
- # ========================================================================
787
-
788
- def create_vector_index(self, force: bool = False) -> bool:
789
- """Create vector index for similarity search.
790
-
791
- Supports IVF_PQ, IVF_FLAT, and HNSW_SQ index types based on configuration.
792
- Automatically determines optimal parameters based on dataset size.
871
+ def _track_modification(self, count: int = 1) -> None:
872
+ """Track database modifications and trigger auto-compaction if threshold reached.
793
873
 
794
874
  Args:
795
- force: Force index creation regardless of dataset size.
796
-
797
- Returns:
798
- True if index was created, False if skipped.
799
-
800
- Raises:
801
- StorageError: If index creation fails.
875
+ count: Number of modifications to track (default 1).
802
876
  """
803
- count = self.table.count_rows()
804
-
805
- # Check threshold
806
- if count < self.vector_index_threshold and not force:
807
- logger.info(
808
- f"Dataset has {count} rows, below threshold {self.vector_index_threshold}. "
809
- "Skipping vector index creation."
810
- )
811
- return False
812
-
813
- # Check if already exists
814
- if self._has_vector_index and not force:
815
- logger.info("Vector index already exists")
816
- return False
817
-
818
- # Handle HNSW_SQ index type
819
- if self.index_type == "HNSW_SQ":
820
- return self._create_hnsw_index(count)
821
-
822
- # IVF-based index creation (IVF_PQ or IVF_FLAT)
823
- return self._create_ivf_index(count)
877
+ if not self._auto_compaction_enabled:
878
+ return
824
879
 
825
- def _create_hnsw_index(self, count: int) -> bool:
826
- """Create HNSW-SQ vector index.
880
+ self._modification_count += count
881
+ if self._modification_count >= self._auto_compaction_threshold:
882
+ # Reset counter before compacting to avoid re-triggering
883
+ self._modification_count = 0
884
+ try:
885
+ stats = self._get_table_stats()
886
+ # Only compact if there are enough fragments to justify it
887
+ if stats.get("num_small_fragments", 0) >= 5:
888
+ logger.info(
889
+ f"Auto-compaction triggered after {self._auto_compaction_threshold} "
890
+ f"modifications ({stats.get('num_small_fragments', 0)} small fragments)"
891
+ )
892
+ self.table.compact_files()
893
+ logger.debug("Auto-compaction completed")
894
+ except Exception as e:
895
+ # Don't fail operations due to compaction issues
896
+ logger.debug(f"Auto-compaction skipped: {e}")
827
897
 
828
- HNSW (Hierarchical Navigable Small World) provides better recall than IVF
829
- at the cost of higher memory usage. Good for datasets where recall is critical.
898
+ def set_auto_compaction(
899
+ self,
900
+ enabled: bool = True,
901
+ threshold: int | None = None,
902
+ ) -> None:
903
+ """Configure auto-compaction behavior.
830
904
 
831
905
  Args:
832
- count: Number of rows in the table.
833
-
834
- Returns:
835
- True if index was created.
836
-
837
- Raises:
838
- StorageError: If index creation fails.
906
+ enabled: Whether auto-compaction is enabled.
907
+ threshold: Number of modifications before auto-compact (default: 100).
839
908
  """
840
- logger.info(
841
- f"Creating HNSW_SQ vector index: m={self.hnsw_m}, "
842
- f"ef_construction={self.hnsw_ef_construction} for {count} rows"
843
- )
909
+ self._auto_compaction_enabled = enabled
910
+ if threshold is not None:
911
+ if threshold < 10:
912
+ raise ValueError("Auto-compaction threshold must be at least 10")
913
+ self._auto_compaction_threshold = threshold
844
914
 
845
- try:
846
- self.table.create_index(
847
- metric="cosine",
848
- vector_column_name="vector",
849
- index_type="HNSW_SQ",
850
- replace=True,
851
- m=self.hnsw_m,
852
- ef_construction=self.hnsw_ef_construction,
853
- )
854
-
855
- # Wait for index to be ready with configurable timeout
856
- self._wait_for_index_ready("vector", self.index_wait_timeout_seconds)
857
-
858
- self._has_vector_index = True
859
- logger.info("HNSW_SQ vector index created successfully")
860
-
861
- # Optimize after index creation (may fail in some environments)
862
- try:
863
- self.table.optimize()
864
- except Exception as optimize_error:
865
- logger.debug(f"Optimization after index creation skipped: {optimize_error}")
866
-
867
- return True
868
-
869
- except Exception as e:
870
- logger.error(f"Failed to create HNSW_SQ vector index: {e}")
871
- raise StorageError(f"HNSW_SQ vector index creation failed: {e}") from e
872
-
873
- def _create_ivf_index(self, count: int) -> bool:
874
- """Create IVF-PQ or IVF-FLAT vector index.
915
+ # ========================================================================
916
+ # Index Management (delegates to IndexManager)
917
+ # ========================================================================
875
918
 
876
- Uses sqrt rule for partitions: num_partitions = sqrt(count), clamped to [16, 4096].
877
- Uses 48 sub-vectors for <500K rows (8 dims each for 384-dim vectors),
878
- 96 sub-vectors for >=500K rows (4 dims each).
919
+ def create_vector_index(self, force: bool = False) -> bool:
920
+ """Create vector index for similarity search. Delegates to IndexManager.
879
921
 
880
922
  Args:
881
- count: Number of rows in the table.
923
+ force: Force index creation regardless of dataset size.
882
924
 
883
925
  Returns:
884
- True if index was created.
926
+ True if index was created, False if skipped.
885
927
 
886
928
  Raises:
887
929
  StorageError: If index creation fails.
888
930
  """
889
- # Use sqrt rule for partitions, clamped to [16, 4096]
890
- num_partitions = int(math.sqrt(count))
891
- num_partitions = max(16, min(num_partitions, 4096))
892
-
893
- # Choose num_sub_vectors based on dataset size
894
- # <500K: 48 sub-vectors (8 dims each for 384-dim, more precision)
895
- # >=500K: 96 sub-vectors (4 dims each, more compression)
896
- if count < 500_000:
897
- num_sub_vectors = 48
898
- else:
899
- num_sub_vectors = 96
900
-
901
- # Validate embedding_dim % num_sub_vectors == 0 (required for IVF-PQ)
902
- if self.embedding_dim % num_sub_vectors != 0:
903
- # Find a valid divisor from common sub-vector counts
904
- valid_divisors = [96, 48, 32, 24, 16, 12, 8, 4]
905
- found_divisor = False
906
- for divisor in valid_divisors:
907
- if self.embedding_dim % divisor == 0:
908
- logger.info(
909
- f"Adjusted num_sub_vectors from {num_sub_vectors} to {divisor} "
910
- f"for embedding_dim={self.embedding_dim}"
911
- )
912
- num_sub_vectors = divisor
913
- found_divisor = True
914
- break
915
-
916
- if not found_divisor:
917
- raise StorageError(
918
- f"Cannot create IVF-PQ index: embedding_dim={self.embedding_dim} "
919
- "has no suitable divisor for sub-vectors. "
920
- f"Tried divisors: {valid_divisors}"
921
- )
922
-
923
- # IVF-PQ requires minimum rows for training (sample_rate * num_partitions / 256)
924
- # Default sample_rate=256, so we need at least 256 rows
925
- # Also, IVF requires num_partitions < num_vectors for KMeans training
926
- sample_rate = 256 # default
927
- if count < 256:
928
- # Use IVF_FLAT for very small datasets (no PQ training required)
929
- logger.info(
930
- f"Dataset too small for IVF-PQ ({count} rows < 256). "
931
- "Using IVF_FLAT index instead."
932
- )
933
- index_type = "IVF_FLAT"
934
- sample_rate = max(16, count // 4) # Lower sample rate for small data
935
- else:
936
- index_type = self.index_type if self.index_type in ("IVF_PQ", "IVF_FLAT") else "IVF_PQ"
937
-
938
- # Ensure num_partitions < num_vectors for KMeans clustering
939
- if num_partitions >= count:
940
- num_partitions = max(1, count // 4) # Use 1/4 of count, minimum 1
941
- logger.info(f"Adjusted num_partitions to {num_partitions} for {count} rows")
942
-
943
- logger.info(
944
- f"Creating {index_type} vector index: {num_partitions} partitions, "
945
- f"{num_sub_vectors} sub-vectors for {count} rows"
946
- )
947
-
948
- try:
949
- # LanceDB 0.27+ API: parameters passed directly to create_index
950
- index_kwargs: dict[str, Any] = {
951
- "metric": "cosine",
952
- "num_partitions": num_partitions,
953
- "vector_column_name": "vector",
954
- "index_type": index_type,
955
- "replace": True,
956
- "sample_rate": sample_rate,
957
- }
958
-
959
- # num_sub_vectors only applies to PQ-based indexes
960
- if "PQ" in index_type:
961
- index_kwargs["num_sub_vectors"] = num_sub_vectors
962
-
963
- self.table.create_index(**index_kwargs)
964
-
965
- # Wait for index to be ready with configurable timeout
966
- self._wait_for_index_ready("vector", self.index_wait_timeout_seconds)
967
-
968
- self._has_vector_index = True
969
- logger.info(f"{index_type} vector index created successfully")
970
-
971
- # Optimize after index creation (may fail in some environments)
972
- try:
973
- self.table.optimize()
974
- except Exception as optimize_error:
975
- logger.debug(f"Optimization after index creation skipped: {optimize_error}")
976
-
977
- return True
978
-
979
- except Exception as e:
980
- logger.error(f"Failed to create {index_type} vector index: {e}")
981
- raise StorageError(f"{index_type} vector index creation failed: {e}") from e
982
-
983
- def _wait_for_index_ready(
984
- self,
985
- column_name: str,
986
- timeout_seconds: float,
987
- poll_interval: float = 0.5,
988
- ) -> None:
989
- """Wait for an index on the specified column to be ready.
990
-
991
- Args:
992
- column_name: Name of the column the index is on (e.g., "vector").
993
- LanceDB typically names indexes as "{column_name}_idx".
994
- timeout_seconds: Maximum time to wait.
995
- poll_interval: Time between status checks.
996
- """
997
- if timeout_seconds <= 0:
998
- return
999
-
1000
- start_time = time.time()
1001
- while time.time() - start_time < timeout_seconds:
1002
- try:
1003
- indices = self.table.list_indices()
1004
- for idx in indices:
1005
- idx_name = str(_get_index_attr(idx, "name", "")).lower()
1006
- idx_columns = _get_index_attr(idx, "columns", [])
1007
-
1008
- # Match by column name in index metadata, or index name contains column
1009
- if column_name in idx_columns or column_name in idx_name:
1010
- # Index exists, check if it's ready
1011
- status = str(_get_index_attr(idx, "status", "ready"))
1012
- if status.lower() in ("ready", "complete", "built"):
1013
- logger.debug(f"Index on {column_name} is ready")
1014
- return
1015
- break
1016
- except Exception as e:
1017
- logger.debug(f"Error checking index status: {e}")
1018
-
1019
- time.sleep(poll_interval)
1020
-
1021
- logger.warning(
1022
- f"Timeout waiting for index on {column_name} after {timeout_seconds}s"
1023
- )
931
+ if self._index_manager is None:
932
+ raise StorageError("Database not connected")
933
+ result = self._index_manager.create_vector_index(force=force)
934
+ # Sync local state only when index was created or modified
935
+ if result:
936
+ self._has_vector_index = self._index_manager.has_vector_index
937
+ return result
1024
938
 
1025
939
  def create_scalar_indexes(self) -> None:
1026
- """Create scalar indexes for frequently filtered columns.
1027
-
1028
- Creates:
1029
- - BTREE on id (fast lookups, upserts)
1030
- - BTREE on timestamps and importance (range queries)
1031
- - BITMAP on namespace and source (low cardinality)
1032
- - LABEL_LIST on tags (array contains queries)
940
+ """Create scalar indexes for frequently filtered columns. Delegates to IndexManager.
1033
941
 
1034
942
  Raises:
1035
943
  StorageError: If index creation fails critically.
1036
944
  """
1037
- # BTREE indexes for range queries and lookups
1038
- btree_columns = [
1039
- "id", # Fast lookups and merge_insert
1040
- "created_at",
1041
- "updated_at",
1042
- "last_accessed",
1043
- "importance",
1044
- "access_count",
1045
- "expires_at", # TTL expiration queries
1046
- ]
1047
-
1048
- for column in btree_columns:
1049
- try:
1050
- self.table.create_scalar_index(
1051
- column,
1052
- index_type="BTREE",
1053
- replace=True,
1054
- )
1055
- logger.debug(f"Created BTREE index on {column}")
1056
- except Exception as e:
1057
- if "already exists" not in str(e).lower():
1058
- logger.warning(f"Could not create BTREE index on {column}: {e}")
1059
-
1060
- # BITMAP indexes for low-cardinality columns
1061
- bitmap_columns = ["namespace", "source"]
1062
-
1063
- for column in bitmap_columns:
1064
- try:
1065
- self.table.create_scalar_index(
1066
- column,
1067
- index_type="BITMAP",
1068
- replace=True,
1069
- )
1070
- logger.debug(f"Created BITMAP index on {column}")
1071
- except Exception as e:
1072
- if "already exists" not in str(e).lower():
1073
- logger.warning(f"Could not create BITMAP index on {column}: {e}")
1074
-
1075
- # LABEL_LIST index for tags array (supports array_has_any queries)
1076
- try:
1077
- self.table.create_scalar_index(
1078
- "tags",
1079
- index_type="LABEL_LIST",
1080
- replace=True,
1081
- )
1082
- logger.debug("Created LABEL_LIST index on tags")
1083
- except Exception as e:
1084
- if "already exists" not in str(e).lower():
1085
- logger.warning(f"Could not create LABEL_LIST index on tags: {e}")
1086
-
1087
- logger.info("Scalar indexes created")
945
+ if self._index_manager is None:
946
+ raise StorageError("Database not connected")
947
+ self._index_manager.create_scalar_indexes()
1088
948
 
1089
949
  def ensure_indexes(self, force: bool = False) -> dict[str, bool]:
1090
- """Ensure all appropriate indexes exist.
950
+ """Ensure all appropriate indexes exist. Delegates to IndexManager.
1091
951
 
1092
952
  Args:
1093
953
  force: Force index creation regardless of thresholds.
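The auto-compaction tracking introduced earlier in this hunk compacts the table after a configurable number of tracked modifications (default 100), and only once at least five small fragments have accumulated. `set_auto_compaction()` is the public knob; thresholds below 10 are rejected, per the validation shown above. Usage sketch (`db` is a connected `Database`):

```python
# Relax compaction for a write-heavy batch job.
db.set_auto_compaction(enabled=True, threshold=500)

# Thresholds below 10 raise ValueError (see set_auto_compaction above).
try:
    db.set_auto_compaction(threshold=5)
except ValueError as err:
    print(err)  # "Auto-compaction threshold must be at least 10"
```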
@@ -1095,35 +955,12 @@ class Database:
1095
955
  Returns:
1096
956
  Dict indicating which indexes were created.
1097
957
  """
1098
- results = {
1099
- "vector_index": False,
1100
- "scalar_indexes": False,
1101
- "fts_index": False,
1102
- }
1103
-
1104
- count = self.table.count_rows()
1105
-
1106
- # Vector index
1107
- if self.auto_create_indexes or force:
1108
- if count >= self.vector_index_threshold or force:
1109
- results["vector_index"] = self.create_vector_index(force=force)
1110
-
1111
- # Scalar indexes (always create if > 1000 rows)
1112
- if count >= 1000 or force:
1113
- try:
1114
- self.create_scalar_indexes()
1115
- results["scalar_indexes"] = True
1116
- except Exception as e:
1117
- logger.warning(f"Scalar index creation partially failed: {e}")
1118
-
1119
- # FTS index
1120
- if self.enable_fts and not self._has_fts_index:
1121
- try:
1122
- self._create_fts_index()
1123
- results["fts_index"] = True
1124
- except Exception as e:
1125
- logger.warning(f"FTS index creation failed in ensure_indexes: {e}")
1126
-
958
+ if self._index_manager is None:
959
+ raise StorageError("Database not connected")
960
+ results = self._index_manager.ensure_indexes(force=force)
961
+ # Sync local state for backward compatibility
962
+ self._has_vector_index = self._index_manager.has_vector_index
963
+ self._has_fts_index = self._index_manager.has_fts_index
1127
964
  return results
1128
965
 
1129
966
  # ========================================================================
@@ -1292,6 +1129,13 @@ class Database:
1292
1129
  if not 0.0 <= importance <= 1.0:
1293
1130
  raise ValidationError("Importance must be between 0.0 and 1.0")
1294
1131
 
1132
+ # Validate vector dimensions
1133
+ if len(vector) != self.embedding_dim:
1134
+ raise DimensionMismatchError(
1135
+ expected_dim=self.embedding_dim,
1136
+ actual_dim=len(vector),
1137
+ )
1138
+
1295
1139
  memory_id = str(uuid.uuid4())
1296
1140
  now = utc_now()
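`insert()` now fails fast with the new `DimensionMismatchError` when a vector's length does not match the configured `embedding_dim`, instead of surfacing a lower-level LanceDB error later. A hedged handling sketch — the keyword arguments to `insert()` are assumed from the record fields in this file:

```python
import numpy as np
from spatial_memory.core.errors import DimensionMismatchError

wrong_vector = np.zeros(512, dtype=np.float32)  # table configured for 384 dimensions

try:
    db.insert(content="example memory", vector=wrong_vector)  # argument names assumed
except DimensionMismatchError as err:
    print(f"vector rejected: {err}")
```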
1297
1141
 
@@ -1319,6 +1163,7 @@ class Database:
1319
1163
  try:
1320
1164
  self.table.add([record])
1321
1165
  self._invalidate_count_cache()
1166
+ self._track_modification()
1322
1167
  self._invalidate_namespace_cache()
1323
1168
  logger.debug(f"Inserted memory {memory_id}")
1324
1169
  return memory_id
@@ -1335,23 +1180,26 @@ class Database:
1335
1180
  self,
1336
1181
  records: list[dict[str, Any]],
1337
1182
  batch_size: int = 1000,
1183
+ atomic: bool = False,
1338
1184
  ) -> list[str]:
1339
1185
  """Insert multiple memories efficiently with batching.
1340
1186
 
1341
- Note: Batch insert is NOT atomic. Partial failures may leave some
1342
- records inserted. If atomicity is required, use individual inserts
1343
- with transaction management at the application layer.
1344
-
1345
1187
  Args:
1346
1188
  records: List of memory records with content, vector, and optional fields.
1347
1189
  batch_size: Records per batch (default: 1000, max: 10000).
1190
+ atomic: If True, rollback all inserts on partial failure.
1191
+ When atomic=True and a batch fails:
1192
+ - Attempts to delete already-inserted records
1193
+ - If rollback succeeds, raises the original StorageError
1194
+ - If rollback fails, raises PartialBatchInsertError with succeeded_ids
1348
1195
 
1349
1196
  Returns:
1350
1197
  List of generated memory IDs.
1351
1198
 
1352
1199
  Raises:
1353
1200
  ValidationError: If input validation fails or batch_size exceeds maximum.
1354
- StorageError: If database operation fails.
1201
+ StorageError: If database operation fails (and rollback succeeds when atomic=True).
1202
+ PartialBatchInsertError: If atomic=True and rollback fails after partial insert.
1355
1203
  """
1356
1204
  if batch_size > self.MAX_BATCH_SIZE:
1357
1205
  raise ValidationError(
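`insert_batch()` gains an `atomic` flag with the semantics spelled out in the docstring above: if a later batch fails, already-inserted rows are deleted and the original `StorageError` is re-raised; if that cleanup also fails, a `PartialBatchInsertError` carrying the inserted ids is raised instead. A caller distinguishing the two might look like this (field names taken from this file; the error's attribute name is assumed from its constructor call further down):

```python
from spatial_memory.core.errors import PartialBatchInsertError, StorageError

records = [
    {"content": "note one", "vector": [0.0] * 384},
    {"content": "note two", "vector": [0.1] * 384},
]

try:
    ids = db.insert_batch(records, atomic=True)
except PartialBatchInsertError as err:
    # Rollback failed: some rows were persisted; err.succeeded_ids (attribute name
    # assumed from the constructor call below) tells you which ones to clean up.
    print(f"partial insert: {err}")
except StorageError:
    # Insert failed but rollback succeeded, so nothing was persisted.
    raise
```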
@@ -1359,9 +1207,10 @@ class Database:
1359
1207
  )
1360
1208
 
1361
1209
  all_ids: list[str] = []
1210
+ total_requested = len(records)
1362
1211
 
1363
1212
  # Process in batches for large inserts
1364
- for i in range(0, len(records), batch_size):
1213
+ for batch_index, i in enumerate(range(0, len(records), batch_size)):
1365
1214
  batch = records[i:i + batch_size]
1366
1215
  now = utc_now()
1367
1216
  memory_ids: list[str] = []
@@ -1389,6 +1238,14 @@ class Database:
1389
1238
  else:
1390
1239
  vector_list = raw_vector
1391
1240
 
1241
+ # Validate vector dimensions
1242
+ if len(vector_list) != self.embedding_dim:
1243
+ raise DimensionMismatchError(
1244
+ expected_dim=self.embedding_dim,
1245
+ actual_dim=len(vector_list),
1246
+ record_index=i + len(memory_ids),
1247
+ )
1248
+
1392
1249
  # Calculate expires_at if default TTL is configured
1393
1250
  expires_at = None
1394
1251
  if self.default_memory_ttl_days is not None:
@@ -1415,9 +1272,29 @@ class Database:
1415
1272
  self.table.add(prepared_records)
1416
1273
  all_ids.extend(memory_ids)
1417
1274
  self._invalidate_count_cache()
1275
+ self._track_modification(len(memory_ids))
1418
1276
  self._invalidate_namespace_cache()
1419
- logger.debug(f"Inserted batch {i // batch_size + 1}: {len(memory_ids)} memories")
1277
+ logger.debug(f"Inserted batch {batch_index + 1}: {len(memory_ids)} memories")
1420
1278
  except Exception as e:
1279
+ if atomic and all_ids:
1280
+ # Attempt rollback of previously inserted records
1281
+ logger.warning(
1282
+ f"Batch {batch_index + 1} failed, attempting rollback of "
1283
+ f"{len(all_ids)} previously inserted records"
1284
+ )
1285
+ rollback_error = self._rollback_batch_insert(all_ids)
1286
+ if rollback_error:
1287
+ # Rollback failed - raise PartialBatchInsertError
1288
+ raise PartialBatchInsertError(
1289
+ message=f"Batch insert failed and rollback also failed: {e}",
1290
+ succeeded_ids=all_ids,
1291
+ total_requested=total_requested,
1292
+ failed_batch_index=batch_index,
1293
+ ) from e
1294
+ else:
1295
+ # Rollback succeeded - raise original error
1296
+ logger.info(f"Rollback successful, deleted {len(all_ids)} records")
1297
+ raise StorageError(f"Failed to insert batch (rolled back): {e}") from e
1421
1298
  raise StorageError(f"Failed to insert batch: {e}") from e
1422
1299
 
1423
1300
  # Check if we should create indexes after large insert
@@ -1433,6 +1310,31 @@ class Database:
1433
1310
  logger.debug(f"Inserted {len(all_ids)} memories total")
1434
1311
  return all_ids
1435
1312
 
1313
+ def _rollback_batch_insert(self, memory_ids: list[str]) -> Exception | None:
1314
+ """Attempt to delete records inserted during a failed batch operation.
1315
+
1316
+ Args:
1317
+ memory_ids: List of memory IDs to delete.
1318
+
1319
+ Returns:
1320
+ None if rollback succeeded, Exception if it failed.
1321
+ """
1322
+ try:
1323
+ if not memory_ids:
1324
+ return None
1325
+
1326
+ # Use delete_batch for efficient rollback
1327
+ id_list = ", ".join(f"'{_sanitize_string(mid)}'" for mid in memory_ids)
1328
+ self.table.delete(f"id IN ({id_list})")
1329
+ self._invalidate_count_cache()
1330
+ self._track_modification(len(memory_ids))
1331
+ self._invalidate_namespace_cache()
1332
+ logger.debug(f"Rolled back {len(memory_ids)} records")
1333
+ return None
1334
+ except Exception as e:
1335
+ logger.error(f"Rollback failed: {e}")
1336
+ return e
1337
+
1436
1338
  @with_stale_connection_recovery
1437
1339
  def get(self, memory_id: str) -> dict[str, Any]:
1438
1340
  """Get a memory by ID.
@@ -1467,6 +1369,51 @@ class Database:
1467
1369
  except Exception as e:
1468
1370
  raise StorageError(f"Failed to get memory: {e}") from e
1469
1371
 
1372
+ def get_batch(self, memory_ids: list[str]) -> dict[str, dict[str, Any]]:
1373
+ """Get multiple memories by ID in a single query.
1374
+
1375
+ Args:
1376
+ memory_ids: List of memory UUIDs to retrieve.
1377
+
1378
+ Returns:
1379
+ Dict mapping memory_id to memory record. Missing IDs are not included.
1380
+
1381
+ Raises:
1382
+ ValidationError: If any memory_id format is invalid.
1383
+ StorageError: If database operation fails.
1384
+ """
1385
+ if not memory_ids:
1386
+ return {}
1387
+
1388
+ # Validate all IDs first
1389
+ validated_ids: list[str] = []
1390
+ for memory_id in memory_ids:
1391
+ try:
1392
+ validated_id = _validate_uuid(memory_id)
1393
+ validated_ids.append(_sanitize_string(validated_id))
1394
+ except Exception as e:
1395
+ logger.debug(f"Invalid memory ID {memory_id}: {e}")
1396
+ continue
1397
+
1398
+ if not validated_ids:
1399
+ return {}
1400
+
1401
+ try:
1402
+ # Batch fetch with single IN query
1403
+ id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
1404
+ results = self.table.search().where(f"id IN ({id_list})").to_list()
1405
+
1406
+ # Build result map
1407
+ result_map: dict[str, dict[str, Any]] = {}
1408
+ for record in results:
1409
+ # Deserialize metadata
1410
+ record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
1411
+ result_map[record["id"]] = record
1412
+
1413
+ return result_map
1414
+ except Exception as e:
1415
+ raise StorageError(f"Failed to batch get memories: {e}") from e
1416
+
1470
1417
  @with_process_lock
1471
1418
  @with_write_lock
1472
1419
  def update(self, memory_id: str, updates: dict[str, Any]) -> None:
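The new `get_batch()` fetches many memories with a single `IN` query and returns a dict keyed by id; per the loop above, ids that fail validation are skipped and missing ids are simply absent from the result. For example:

```python
found = db.get_batch([id_a, id_b, id_missing])  # ids from earlier inserts
for memory_id in (id_a, id_b, id_missing):
    record = found.get(memory_id)
    if record is None:
        print(f"{memory_id}: not found")
    else:
        print(f"{memory_id}: {record['content'][:40]!r}")
```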
@@ -1522,6 +1469,108 @@ class Database:
1522
1469
  except Exception as e:
1523
1470
  raise StorageError(f"Failed to update memory: {e}") from e
1524
1471
 
1472
+ @with_process_lock
1473
+ @with_write_lock
1474
+ def update_batch(
1475
+ self, updates: list[tuple[str, dict[str, Any]]]
1476
+ ) -> tuple[int, list[str]]:
1477
+ """Update multiple memories using atomic merge_insert.
1478
+
1479
+ Args:
1480
+ updates: List of (memory_id, updates_dict) tuples.
1481
+
1482
+ Returns:
1483
+ Tuple of (success_count, list of failed memory_ids).
1484
+
1485
+ Raises:
1486
+ StorageError: If database operation fails completely.
1487
+ """
1488
+ if not updates:
1489
+ return 0, []
1490
+
1491
+ now = utc_now()
1492
+ records_to_update: list[dict[str, Any]] = []
1493
+ failed_ids: list[str] = []
1494
+
1495
+ # Validate all IDs and collect them
1496
+ validated_updates: list[tuple[str, dict[str, Any]]] = []
1497
+ for memory_id, update_dict in updates:
1498
+ try:
1499
+ validated_id = _validate_uuid(memory_id)
1500
+ validated_updates.append((_sanitize_string(validated_id), update_dict))
1501
+ except Exception as e:
1502
+ logger.debug(f"Invalid memory ID {memory_id}: {e}")
1503
+ failed_ids.append(memory_id)
1504
+
1505
+ if not validated_updates:
1506
+ return 0, failed_ids
1507
+
1508
+ # Batch fetch all records
1509
+ validated_ids = [vid for vid, _ in validated_updates]
1510
+ try:
1511
+ id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
1512
+ all_records = self.table.search().where(f"id IN ({id_list})").to_list()
1513
+ except Exception as e:
1514
+ logger.error(f"Failed to batch fetch records for update: {e}")
1515
+ raise StorageError(f"Failed to batch fetch for update: {e}") from e
1516
+
1517
+ # Build lookup map
1518
+ record_map: dict[str, dict[str, Any]] = {}
1519
+ for record in all_records:
1520
+ record_map[record["id"]] = record
1521
+
1522
+ # Apply updates to found records
1523
+ update_dict_map = dict(validated_updates)
1524
+ for memory_id in validated_ids:
1525
+ if memory_id not in record_map:
1526
+ logger.debug(f"Memory {memory_id} not found for batch update")
1527
+ failed_ids.append(memory_id)
1528
+ continue
1529
+
1530
+ record = record_map[memory_id]
1531
+ update_dict = update_dict_map[memory_id]
1532
+
1533
+ # Apply updates
1534
+ record["updated_at"] = now
1535
+ for key, value in update_dict.items():
1536
+ if key == "metadata" and isinstance(value, dict):
1537
+ record[key] = json.dumps(value)
1538
+ elif key == "vector" and isinstance(value, np.ndarray):
1539
+ record[key] = value.tolist()
1540
+ else:
1541
+ record[key] = value
1542
+
1543
+ # Ensure metadata is serialized
1544
+ if isinstance(record.get("metadata"), dict):
1545
+ record["metadata"] = json.dumps(record["metadata"])
1546
+
1547
+ # Ensure vector is a list
1548
+ if isinstance(record.get("vector"), np.ndarray):
1549
+ record["vector"] = record["vector"].tolist()
1550
+
1551
+ records_to_update.append(record)
1552
+
1553
+ if not records_to_update:
1554
+ return 0, failed_ids
1555
+
1556
+ try:
1557
+ # Atomic batch upsert
1558
+ (
1559
+ self.table.merge_insert("id")
1560
+ .when_matched_update_all()
1561
+ .when_not_matched_insert_all()
1562
+ .execute(records_to_update)
1563
+ )
1564
+ success_count = len(records_to_update)
1565
+ logger.debug(
1566
+ f"Batch updated {success_count}/{len(updates)} memories "
1567
+ "(atomic merge_insert)"
1568
+ )
1569
+ return success_count, failed_ids
1570
+ except Exception as e:
1571
+ logger.error(f"Failed to batch update: {e}")
1572
+ raise StorageError(f"Failed to batch update: {e}") from e
1573
+
1525
1574
  @with_process_lock
1526
1575
  @with_write_lock
1527
1576
  def delete(self, memory_id: str) -> None:
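`update_batch()` fetches all target rows in one query, applies the per-record changes, and writes them back through a single `merge_insert`, returning `(success_count, failed_ids)` where the failed ids are those that were invalid or not found. Usage sketch (field names from the table schema above):

```python
updated, failed = db.update_batch([
    (id_a, {"importance": 0.9, "tags": ["pinned"]}),
    (id_b, {"metadata": {"review": "2024-q3"}}),  # dicts are JSON-serialized before storage
])
print(f"updated {updated} memories; failed ids: {failed}")
```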
@@ -1545,6 +1594,7 @@ class Database:
1545
1594
  try:
1546
1595
  self.table.delete(f"id = '{safe_id}'")
1547
1596
  self._invalidate_count_cache()
1597
+ self._track_modification()
1548
1598
  self._invalidate_namespace_cache()
1549
1599
  logger.debug(f"Deleted memory {memory_id}")
1550
1600
  except Exception as e:
@@ -1572,6 +1622,7 @@ class Database:
1572
1622
  count_before: int = self.table.count_rows()
1573
1623
  self.table.delete(f"namespace = '{safe_ns}'")
1574
1624
  self._invalidate_count_cache()
1625
+ self._track_modification()
1575
1626
  self._invalidate_namespace_cache()
1576
1627
  count_after: int = self.table.count_rows()
1577
1628
  deleted = count_before - count_after
@@ -1615,6 +1666,7 @@ class Database:
1615
1666
  self.table.delete("id IS NOT NULL")
1616
1667
 
1617
1668
  self._invalidate_count_cache()
1669
+ self._track_modification()
1618
1670
  self._invalidate_namespace_cache()
1619
1671
 
1620
1672
  # Reset index tracking flags for test isolation
@@ -1634,6 +1686,7 @@ class Database:
1634
1686
  """Rename all memories from one namespace to another.
1635
1687
 
1636
1688
  Uses atomic batch update via merge_insert for data integrity.
1689
+ On partial failure, attempts to rollback renamed records to original namespace.
1637
1690
 
1638
1691
  Args:
1639
1692
  old_namespace: Source namespace name.
@@ -1652,6 +1705,7 @@ class Database:
1652
1705
  old_namespace = _validate_namespace(old_namespace)
1653
1706
  new_namespace = _validate_namespace(new_namespace)
1654
1707
  safe_old = _sanitize_string(old_namespace)
1708
+ safe_new = _sanitize_string(new_namespace)
1655
1709
 
1656
1710
  try:
1657
1711
  # Check if source namespace exists
@@ -1665,6 +1719,9 @@ class Database:
1665
1719
  logger.debug(f"Namespace '{old_namespace}' renamed to itself ({count} records)")
1666
1720
  return count
1667
1721
 
1722
+ # Track renamed IDs for rollback capability
1723
+ renamed_ids: list[str] = []
1724
+
1668
1725
  # Fetch all records in batches with iteration safeguards
1669
1726
  batch_size = 1000
1670
1727
  max_iterations = 10000 # Safety cap: 10M records at 1000/batch
@@ -1693,6 +1750,9 @@ class Database:
1693
1750
  if not records:
1694
1751
  break
1695
1752
 
1753
+ # Track IDs in this batch for potential rollback
1754
+ batch_ids = [r["id"] for r in records]
1755
+
1696
1756
  # Update namespace field
1697
1757
  for r in records:
1698
1758
  r["namespace"] = new_namespace
@@ -1702,13 +1762,41 @@ class Database:
1702
1762
  if isinstance(r.get("vector"), np.ndarray):
1703
1763
  r["vector"] = r["vector"].tolist()
1704
1764
 
1705
- # Atomic upsert
1706
- (
1707
- self.table.merge_insert("id")
1708
- .when_matched_update_all()
1709
- .when_not_matched_insert_all()
1710
- .execute(records)
1711
- )
1765
+ try:
1766
+ # Atomic upsert
1767
+ (
1768
+ self.table.merge_insert("id")
1769
+ .when_matched_update_all()
1770
+ .when_not_matched_insert_all()
1771
+ .execute(records)
1772
+ )
1773
+ # Only track as renamed after successful update
1774
+ renamed_ids.extend(batch_ids)
1775
+ except Exception as batch_error:
1776
+ # Batch failed - attempt rollback of previously renamed records
1777
+ if renamed_ids:
1778
+ logger.warning(
1779
+ f"Batch {iteration} failed, attempting rollback of "
1780
+ f"{len(renamed_ids)} previously renamed records"
1781
+ )
1782
+ rollback_error = self._rollback_namespace_rename(
1783
+ renamed_ids, old_namespace
1784
+ )
1785
+ if rollback_error:
1786
+ raise StorageError(
1787
+ f"Namespace rename failed at batch {iteration} and "
1788
+ f"rollback also failed. {len(renamed_ids)} records may be "
1789
+ f"in inconsistent state (partially in '{new_namespace}'). "
1790
+ f"Original error: {batch_error}. Rollback error: {rollback_error}"
1791
+ ) from batch_error
1792
+ else:
1793
+ logger.info(
1794
+ f"Rollback successful, reverted {len(renamed_ids)} records "
1795
+ f"back to namespace '{old_namespace}'"
1796
+ )
1797
+ raise StorageError(
1798
+ f"Failed to rename namespace (rolled back): {batch_error}"
1799
+ ) from batch_error
1712
1800
 
1713
1801
  updated += len(records)
1714
1802
 
@@ -1731,6 +1819,66 @@ class Database:
1731
1819
  except Exception as e:
1732
1820
  raise StorageError(f"Failed to rename namespace: {e}") from e
1733
1821
 
1822
+ def _rollback_namespace_rename(
1823
+ self, memory_ids: list[str], target_namespace: str
1824
+ ) -> Exception | None:
1825
+ """Attempt to revert renamed records back to original namespace.
1826
+
1827
+ Args:
1828
+ memory_ids: List of memory IDs to revert.
1829
+ target_namespace: Namespace to revert records to.
1830
+
1831
+ Returns:
1832
+ None if rollback succeeded, Exception if it failed.
1833
+ """
1834
+ try:
1835
+ if not memory_ids:
1836
+ return None
1837
+
1838
+ safe_namespace = _sanitize_string(target_namespace)
1839
+ now = utc_now()
1840
+
1841
+ # Process in batches for large rollbacks
1842
+ batch_size = 1000
1843
+ for i in range(0, len(memory_ids), batch_size):
1844
+ batch_ids = memory_ids[i:i + batch_size]
1845
+ id_list = ", ".join(f"'{_sanitize_string(mid)}'" for mid in batch_ids)
1846
+
1847
+ # Fetch records that need rollback
1848
+ records = (
1849
+ self.table.search()
1850
+ .where(f"id IN ({id_list})")
1851
+ .to_list()
1852
+ )
1853
+
1854
+ if not records:
1855
+ continue
1856
+
1857
+ # Revert namespace
1858
+ for r in records:
1859
+ r["namespace"] = target_namespace
1860
+ r["updated_at"] = now
1861
+ if isinstance(r.get("metadata"), dict):
1862
+ r["metadata"] = json.dumps(r["metadata"])
1863
+ if isinstance(r.get("vector"), np.ndarray):
1864
+ r["vector"] = r["vector"].tolist()
1865
+
1866
+ # Atomic upsert to restore original namespace
1867
+ (
1868
+ self.table.merge_insert("id")
1869
+ .when_matched_update_all()
1870
+ .when_not_matched_insert_all()
1871
+ .execute(records)
1872
+ )
1873
+
1874
+ self._invalidate_namespace_cache()
1875
+ logger.debug(f"Rolled back {len(memory_ids)} records to namespace '{target_namespace}'")
1876
+ return None
1877
+
1878
+ except Exception as e:
1879
+ logger.error(f"Namespace rename rollback failed: {e}")
1880
+ return e
1881
+
1734
1882
  @with_stale_connection_recovery
1735
1883
  def get_stats(self, namespace: str | None = None) -> dict[str, Any]:
1736
1884
  """Get comprehensive database statistics.
@@ -1827,15 +1975,18 @@ class Database:
1827
1975
  safe_ns = _sanitize_string(namespace)
1828
1976
 
1829
1977
  try:
1830
- # Get records for this namespace (select created_at and content for stats)
1831
- records = (
1978
+ # Get count efficiently
1979
+ filter_expr = f"namespace = '{safe_ns}'"
1980
+ count_results = (
1832
1981
  self.table.search()
1833
- .where(f"namespace = '{safe_ns}'")
1834
- .select(["created_at", "content"])
1982
+ .where(filter_expr)
1983
+ .select(["id"])
1984
+ .limit(1000000) # High limit to count all
1835
1985
  .to_list()
1836
1986
  )
1987
+ memory_count = len(count_results)
1837
1988
 
1838
- if not records:
1989
+ if memory_count == 0:
1839
1990
  return {
1840
1991
  "namespace": namespace,
1841
1992
  "memory_count": 0,
@@ -1844,18 +1995,42 @@ class Database:
1844
1995
  "avg_content_length": None,
1845
1996
  }
1846
1997
 
1847
- # Find oldest and newest
1848
- created_times = [r["created_at"] for r in records]
1849
- oldest = min(created_times)
1850
- newest = max(created_times)
1998
+ # Get oldest memory (sort ascending, limit 1)
1999
+ oldest_records = (
2000
+ self.table.search()
2001
+ .where(filter_expr)
2002
+ .select(["created_at"])
2003
+ .limit(1)
2004
+ .to_list()
2005
+ )
2006
+ oldest = oldest_records[0]["created_at"] if oldest_records else None
1851
2007
 
1852
- # Calculate average content length
1853
- content_lengths = [len(r.get("content", "")) for r in records]
1854
- avg_content_length = sum(content_lengths) / len(content_lengths)
2008
+ # Get newest memory - need to fetch more and find max since LanceDB
2009
+ # doesn't support ORDER BY DESC efficiently
2010
+ # Sample up to 1000 records for stats to avoid loading everything
2011
+ sample_size = min(memory_count, 1000)
2012
+ sample_records = (
2013
+ self.table.search()
2014
+ .where(filter_expr)
2015
+ .select(["created_at", "content"])
2016
+ .limit(sample_size)
2017
+ .to_list()
2018
+ )
2019
+
2020
+ # Find newest from sample (for large namespaces this is approximate)
2021
+ if sample_records:
2022
+ created_times = [r["created_at"] for r in sample_records]
2023
+ newest = max(created_times)
2024
+ # Calculate average content length from sample
2025
+ content_lengths = [len(r.get("content", "")) for r in sample_records]
2026
+ avg_content_length = sum(content_lengths) / len(content_lengths)
2027
+ else:
2028
+ newest = oldest
2029
+ avg_content_length = None
1855
2030
 
1856
2031
  return {
1857
2032
  "namespace": namespace,
1858
- "memory_count": len(records),
2033
+ "memory_count": memory_count,
1859
2034
  "oldest_memory": oldest,
1860
2035
  "newest_memory": newest,
1861
2036
  "avg_content_length": avg_content_length,
@@ -2015,21 +2190,23 @@ class Database:
2015
2190
 
2016
2191
  @with_process_lock
2017
2192
  @with_write_lock
2018
- def delete_batch(self, memory_ids: list[str]) -> int:
2193
+ def delete_batch(self, memory_ids: list[str]) -> tuple[int, list[str]]:
2019
2194
  """Delete multiple memories atomically using IN clause.
2020
2195
 
2021
2196
  Args:
2022
2197
  memory_ids: List of memory UUIDs to delete.
2023
2198
 
2024
2199
  Returns:
2025
- Number of memories actually deleted.
2200
+ Tuple of (count_deleted, list_of_deleted_ids) where:
2201
+ - count_deleted: Number of memories actually deleted
2202
+ - list_of_deleted_ids: IDs that were actually deleted
2026
2203
 
2027
2204
  Raises:
2028
2205
  ValidationError: If any memory_id is invalid.
2029
2206
  StorageError: If database operation fails.
2030
2207
  """
2031
2208
  if not memory_ids:
2032
- return 0
2209
+ return (0, [])
2033
2210
 
2034
2211
  # Validate all IDs first (fail fast)
2035
2212
  validated_ids: list[str] = []
@@ -2038,21 +2215,32 @@ class Database:
2038
2215
  validated_ids.append(_sanitize_string(validated_id))
2039
2216
 
2040
2217
  try:
2041
- count_before: int = self.table.count_rows()
2042
-
2043
- # Build IN clause for atomic batch delete
2218
+ # First, check which IDs actually exist
2044
2219
  id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
2045
2220
  filter_expr = f"id IN ({id_list})"
2046
- self.table.delete(filter_expr)
2221
+ existing_records = (
2222
+ self.table.search()
2223
+ .where(filter_expr)
2224
+ .select(["id"])
2225
+ .limit(len(validated_ids))
2226
+ .to_list()
2227
+ )
2228
+ existing_ids = [r["id"] for r in existing_records]
2229
+
2230
+ if not existing_ids:
2231
+ return (0, [])
2232
+
2233
+ # Delete only existing IDs
2234
+ existing_id_list = ", ".join(f"'{mid}'" for mid in existing_ids)
2235
+ delete_expr = f"id IN ({existing_id_list})"
2236
+ self.table.delete(delete_expr)
2047
2237
 
2048
2238
  self._invalidate_count_cache()
2239
+ self._track_modification()
2049
2240
  self._invalidate_namespace_cache()
2050
2241
 
2051
- count_after: int = self.table.count_rows()
2052
- deleted = count_before - count_after
2053
-
2054
- logger.debug(f"Batch deleted {deleted} memories")
2055
- return deleted
2242
+ logger.debug(f"Batch deleted {len(existing_ids)} memories")
2243
+ return (len(existing_ids), existing_ids)
2056
2244
  except ValidationError:
2057
2245
  raise
2058
2246
  except Exception as e:
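Note the signature change: `delete_batch()` returned a bare count in 1.0.3 and now returns `(count_deleted, deleted_ids)`, computed from the ids that actually existed rather than from a before/after row count. Callers must unpack the tuple:

```python
deleted_count, deleted_ids = db.delete_batch([id_a, id_b, id_unknown])
# ids that never existed are simply absent from deleted_ids
print(f"deleted {deleted_count} memories: {deleted_ids}")
```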
@@ -2150,6 +2338,10 @@ class Database:
2150
2338
  backoff=self.retry_backoff_seconds,
2151
2339
  )
2152
2340
 
2341
+ # ========================================================================
2342
+ # Search Operations (delegates to SearchManager)
2343
+ # ========================================================================
2344
+
2153
2345
  def _calculate_search_params(
2154
2346
  self,
2155
2347
  count: int,
@@ -2157,59 +2349,12 @@ class Database:
2157
2349
  nprobes_override: int | None = None,
2158
2350
  refine_factor_override: int | None = None,
2159
2351
  ) -> tuple[int, int]:
2160
- """Calculate optimal search parameters based on dataset size and limit.
2161
-
2162
- Dynamically tunes nprobes and refine_factor for optimal recall/speed tradeoff.
2163
-
2164
- Args:
2165
- count: Number of rows in the dataset.
2166
- limit: Number of results requested.
2167
- nprobes_override: Optional override for nprobes (uses this if provided).
2168
- refine_factor_override: Optional override for refine_factor.
2169
-
2170
- Returns:
2171
- Tuple of (nprobes, refine_factor).
2172
-
2173
- Scaling rules:
2174
- - nprobes: Base from config, scaled up for larger datasets
2175
- - <100K: config value (default 20)
2176
- - 100K-1M: max(config, 30)
2177
- - 1M-10M: max(config, 50)
2178
- - >10M: max(config, 100)
2179
- - refine_factor: Base from config, scaled up for small limits
2180
- - limit <= 5: config value * 2
2181
- - limit <= 20: config value
2182
- - limit > 20: max(config // 2, 2)
2183
- """
2184
- # Calculate nprobes based on dataset size
2185
- if nprobes_override is not None:
2186
- nprobes = nprobes_override
2187
- else:
2188
- base_nprobes = self.index_nprobes
2189
- if count < 100_000:
2190
- nprobes = base_nprobes
2191
- elif count < 1_000_000:
2192
- nprobes = max(base_nprobes, 30)
2193
- elif count < 10_000_000:
2194
- nprobes = max(base_nprobes, 50)
2195
- else:
2196
- nprobes = max(base_nprobes, 100)
2197
-
2198
- # Calculate refine_factor based on limit
2199
- if refine_factor_override is not None:
2200
- refine_factor = refine_factor_override
2201
- else:
2202
- base_refine = self.index_refine_factor
2203
- if limit <= 5:
2204
- # Small limits need more refinement for accuracy
2205
- refine_factor = base_refine * 2
2206
- elif limit <= 20:
2207
- refine_factor = base_refine
2208
- else:
2209
- # Large limits can use less refinement
2210
- refine_factor = max(base_refine // 2, 2)
2211
-
2212
- return nprobes, refine_factor
2352
+ """Calculate optimal search parameters. Delegates to SearchManager."""
2353
+ if self._search_manager is None:
2354
+ raise StorageError("Database not connected")
2355
+ return self._search_manager.calculate_search_params(
2356
+ count, limit, nprobes_override, refine_factor_override
2357
+ )
2213
2358
 
2214
2359
  @with_stale_connection_recovery
2215
2360
  @retry_on_storage_error(max_attempts=3, backoff=0.5)
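_calculate_search_params keeps its signature but now defers to SearchManager.calculate_search_params. The removed docstring spelled out the tuning rules; the sketch below restates them as a standalone function, assuming the delegated implementation preserves the same thresholds (the config-derived base values shown as defaults are illustrative):

    def calculate_search_params(count: int, limit: int,
                                base_nprobes: int = 20,
                                base_refine: int = 10) -> tuple[int, int]:
        # nprobes scales up with dataset size.
        if count < 100_000:
            nprobes = base_nprobes
        elif count < 1_000_000:
            nprobes = max(base_nprobes, 30)
        elif count < 10_000_000:
            nprobes = max(base_nprobes, 50)
        else:
            nprobes = max(base_nprobes, 100)
        # refine_factor scales up as the requested limit shrinks.
        if limit <= 5:
            refine_factor = base_refine * 2
        elif limit <= 20:
            refine_factor = base_refine
        else:
            refine_factor = max(base_refine // 2, 2)
        return nprobes, refine_factor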
@@ -2223,19 +2368,16 @@ class Database:
2223
2368
  refine_factor: int | None = None,
2224
2369
  include_vector: bool = False,
2225
2370
  ) -> list[dict[str, Any]]:
2226
- """Search for similar memories by vector with performance tuning.
2371
+ """Search for similar memories by vector. Delegates to SearchManager.
2227
2372
 
2228
2373
  Args:
2229
2374
  query_vector: Query embedding vector.
2230
2375
  limit: Maximum number of results.
2231
2376
  namespace: Filter to specific namespace.
2232
2377
  min_similarity: Minimum similarity threshold (0-1).
2233
- nprobes: Number of partitions to search (higher = better recall).
2234
- Only effective when vector index exists. Defaults to dynamic calculation.
2378
+ nprobes: Number of partitions to search.
2235
2379
  refine_factor: Re-rank top (refine_factor * limit) for accuracy.
2236
- Defaults to dynamic calculation based on limit.
2237
2380
  include_vector: Whether to include vector embeddings in results.
2238
- Defaults to False to reduce response size.
2239
2381
 
2240
2382
  Returns:
2241
2383
  List of memory records with similarity scores.
@@ -2244,66 +2386,53 @@ class Database:
2244
2386
  ValidationError: If input validation fails.
2245
2387
  StorageError: If database operation fails.
2246
2388
  """
2247
- try:
2248
- search = self.table.search(query_vector.tolist())
2389
+ if self._search_manager is None:
2390
+ raise StorageError("Database not connected")
2391
+ return self._search_manager.vector_search(
2392
+ query_vector=query_vector,
2393
+ limit=limit,
2394
+ namespace=namespace,
2395
+ min_similarity=min_similarity,
2396
+ nprobes=nprobes,
2397
+ refine_factor=refine_factor,
2398
+ include_vector=include_vector,
2399
+ )
2249
2400
 
2250
- # Distance type for queries (cosine for semantic similarity)
2251
- # Note: When vector index exists, the index's metric is used
2252
- search = search.distance_type("cosine")
2401
+ @with_stale_connection_recovery
2402
+ @retry_on_storage_error(max_attempts=3, backoff=0.5)
2403
+ def batch_vector_search_native(
2404
+ self,
2405
+ query_vectors: list[np.ndarray],
2406
+ limit_per_query: int = 3,
2407
+ namespace: str | None = None,
2408
+ min_similarity: float = 0.0,
2409
+ include_vector: bool = False,
2410
+ ) -> list[list[dict[str, Any]]]:
2411
+ """Batch search using native LanceDB. Delegates to SearchManager.
2253
2412
 
2254
- # Apply performance tuning when index exists (use cached count)
2255
- count = self._get_cached_row_count()
2256
- if count > self.vector_index_threshold and self._has_vector_index:
2257
- # Use dynamic calculation for search params
2258
- actual_nprobes, actual_refine = self._calculate_search_params(
2259
- count, limit, nprobes, refine_factor
2260
- )
2261
- search = search.nprobes(actual_nprobes)
2262
- search = search.refine_factor(actual_refine)
2413
+ Args:
2414
+ query_vectors: List of query embedding vectors.
2415
+ limit_per_query: Maximum number of results per query.
2416
+ namespace: Filter to specific namespace.
2417
+ min_similarity: Minimum similarity threshold (0-1).
2418
+ include_vector: Whether to include vector embeddings in results.
2263
2419
 
2264
- # Build filter with sanitized namespace
2265
- # prefilter=True applies namespace filter BEFORE vector search for better performance
2266
- if namespace:
2267
- namespace = _validate_namespace(namespace)
2268
- safe_ns = _sanitize_string(namespace)
2269
- search = search.where(f"namespace = '{safe_ns}'", prefilter=True)
2270
-
2271
- # Vector projection: exclude vector column to reduce response size
2272
- if not include_vector:
2273
- search = search.select([
2274
- "id", "content", "namespace", "metadata",
2275
- "created_at", "updated_at", "last_accessed",
2276
- "importance", "tags", "source", "access_count",
2277
- "expires_at",
2278
- ])
2279
-
2280
- # Fetch extra if filtering by similarity
2281
- fetch_limit = limit * 2 if min_similarity > 0.0 else limit
2282
- results: list[dict[str, Any]] = search.limit(fetch_limit).to_list()
2283
-
2284
- # Process results
2285
- filtered_results: list[dict[str, Any]] = []
2286
- for record in results:
2287
- record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
2288
- # LanceDB returns _distance, convert to similarity
2289
- if "_distance" in record:
2290
- # Cosine distance to similarity: 1 - distance
2291
- # Clamp to [0, 1] (cosine distance can exceed 1 for unnormalized)
2292
- similarity = max(0.0, min(1.0, 1 - record["_distance"]))
2293
- record["similarity"] = similarity
2294
- del record["_distance"]
2295
-
2296
- # Apply similarity threshold
2297
- if record.get("similarity", 0) >= min_similarity:
2298
- filtered_results.append(record)
2299
- if len(filtered_results) >= limit:
2300
- break
2301
-
2302
- return filtered_results
2303
- except ValidationError:
2304
- raise
2305
- except Exception as e:
2306
- raise StorageError(f"Failed to search: {e}") from e
2420
+ Returns:
2421
+ List of result lists, one per query vector.
2422
+
2423
+ Raises:
2424
+ ValidationError: If input validation fails.
2425
+ StorageError: If database operation fails.
2426
+ """
2427
+ if self._search_manager is None:
2428
+ raise StorageError("Database not connected")
2429
+ return self._search_manager.batch_vector_search_native(
2430
+ query_vectors=query_vectors,
2431
+ limit_per_query=limit_per_query,
2432
+ namespace=namespace,
2433
+ min_similarity=min_similarity,
2434
+ include_vector=include_vector,
2435
+ )
2307
2436
 
2308
2437
  @with_stale_connection_recovery
2309
2438
  @retry_on_storage_error(max_attempts=3, backoff=0.5)
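The inline vector_search body removed above decoded the JSON metadata column, converted LanceDB's cosine `_distance` into a clamped similarity, and applied the min_similarity cutoff before returning. A small sketch of that post-processing step, on the assumption that SearchManager keeps the same convention:

    import json

    def postprocess(record: dict, min_similarity: float) -> dict | None:
        # Decode the stored JSON metadata column.
        record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
        if "_distance" in record:
            # Cosine distance -> similarity, clamped to [0, 1].
            record["similarity"] = max(0.0, min(1.0, 1 - record.pop("_distance")))
        # Drop results below the caller's similarity threshold.
        return record if record.get("similarity", 0) >= min_similarity else None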
@@ -2316,10 +2445,7 @@ class Database:
2316
2445
  alpha: float = 0.5,
2317
2446
  min_similarity: float = 0.0,
2318
2447
  ) -> list[dict[str, Any]]:
2319
- """Hybrid search combining vector similarity and keyword matching.
2320
-
2321
- Uses LinearCombinationReranker to balance vector and keyword scores
2322
- based on the alpha parameter.
2448
+ """Hybrid search combining vector and keyword. Delegates to SearchManager.
2323
2449
 
2324
2450
  Args:
2325
2451
  query: Text query for full-text search.
@@ -2327,9 +2453,7 @@ class Database:
2327
2453
  limit: Number of results.
2328
2454
  namespace: Filter to namespace.
2329
2455
  alpha: Balance between vector (1.0) and keyword (0.0).
2330
- 0.5 = balanced (recommended).
2331
- min_similarity: Minimum similarity threshold (0.0-1.0).
2332
- Results below this threshold are filtered out.
2456
+ min_similarity: Minimum similarity threshold.
2333
2457
 
2334
2458
  Returns:
2335
2459
  List of memory records with combined scores.
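hybrid_search still weights vector similarity (alpha toward 1.0) against keyword matching (alpha toward 0.0). A hypothetical call site; `db` and `embed()` are placeholders for a connected Database and an external embedding function:

    vec = embed("postgres connection pooling")
    keyword_heavy = db.hybrid_search(query="postgres connection pooling",
                                     query_vector=vec, limit=5, alpha=0.2)
    balanced = db.hybrid_search(query="postgres connection pooling",
                                query_vector=vec, limit=5, alpha=0.5)
    vector_heavy = db.hybrid_search(query="postgres connection pooling",
                                    query_vector=vec, limit=5, alpha=0.9)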
@@ -2338,80 +2462,16 @@ class Database:
2338
2462
  ValidationError: If input validation fails.
2339
2463
  StorageError: If database operation fails.
2340
2464
  """
2341
- try:
2342
- # Check if FTS is available
2343
- if not self._has_fts_index:
2344
- logger.debug("FTS index not available, falling back to vector search")
2345
- return self.vector_search(query_vector, limit=limit, namespace=namespace)
2346
-
2347
- # Create hybrid search with explicit vector column specification
2348
- # Required when using external embeddings (not LanceDB built-in)
2349
- search = (
2350
- self.table.search(query, query_type="hybrid")
2351
- .vector(query_vector.tolist())
2352
- .vector_column_name("vector")
2353
- )
2354
-
2355
- # Apply alpha parameter using LinearCombinationReranker
2356
- # alpha=1.0 means full vector, alpha=0.0 means full FTS
2357
- try:
2358
- from lancedb.rerankers import LinearCombinationReranker
2359
-
2360
- reranker = LinearCombinationReranker(weight=alpha)
2361
- search = search.rerank(reranker)
2362
- except ImportError:
2363
- logger.debug("LinearCombinationReranker not available, using default reranking")
2364
- except Exception as e:
2365
- logger.debug(f"Could not apply reranker: {e}")
2366
-
2367
- # Apply namespace filter
2368
- if namespace:
2369
- namespace = _validate_namespace(namespace)
2370
- safe_ns = _sanitize_string(namespace)
2371
- search = search.where(f"namespace = '{safe_ns}'")
2372
-
2373
- results: list[dict[str, Any]] = search.limit(limit).to_list()
2374
-
2375
- # Process results - normalize scores and clean up internal columns
2376
- processed_results: list[dict[str, Any]] = []
2377
- for record in results:
2378
- record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
2379
-
2380
- # Compute similarity from various score columns
2381
- # Priority: _relevance_score > _distance > _score > default
2382
- similarity: float
2383
- if "_relevance_score" in record:
2384
- # Reranker output - use directly (already 0-1 range)
2385
- similarity = float(record["_relevance_score"])
2386
- del record["_relevance_score"]
2387
- elif "_distance" in record:
2388
- # Vector distance - convert to similarity
2389
- similarity = max(0.0, min(1.0, 1 - float(record["_distance"])))
2390
- del record["_distance"]
2391
- elif "_score" in record:
2392
- # BM25 score - normalize using score/(1+score)
2393
- score = float(record["_score"])
2394
- similarity = score / (1.0 + score)
2395
- del record["_score"]
2396
- else:
2397
- # No score column - use default
2398
- similarity = 0.5
2399
-
2400
- record["similarity"] = similarity
2401
-
2402
- # Mark as hybrid result with alpha value
2403
- record["search_type"] = "hybrid"
2404
- record["alpha"] = alpha
2405
-
2406
- # Apply min_similarity filter
2407
- if similarity >= min_similarity:
2408
- processed_results.append(record)
2409
-
2410
- return processed_results
2411
-
2412
- except Exception as e:
2413
- logger.warning(f"Hybrid search failed, falling back to vector search: {e}")
2414
- return self.vector_search(query_vector, limit=limit, namespace=namespace)
2465
+ if self._search_manager is None:
2466
+ raise StorageError("Database not connected")
2467
+ return self._search_manager.hybrid_search(
2468
+ query=query,
2469
+ query_vector=query_vector,
2470
+ limit=limit,
2471
+ namespace=namespace,
2472
+ alpha=alpha,
2473
+ min_similarity=min_similarity,
2474
+ )
2415
2475
 
2416
2476
  @with_stale_connection_recovery
2417
2477
  @retry_on_storage_error(max_attempts=3, backoff=0.5)
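The removed hybrid_search body normalized whichever score column LanceDB returned into a single similarity value, in priority order _relevance_score, then _distance, then _score (BM25). A sketch of that normalization, assumed to carry over into SearchManager:

    def normalize_score(record: dict) -> float:
        if "_relevance_score" in record:
            return float(record.pop("_relevance_score"))    # reranker output, already 0-1
        if "_distance" in record:
            return max(0.0, min(1.0, 1 - float(record.pop("_distance"))))
        if "_score" in record:
            score = float(record.pop("_score"))
            return score / (1.0 + score)                    # squash BM25 into (0, 1)
        return 0.5                                          # no score column present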
@@ -2420,20 +2480,19 @@ class Database:
2420
2480
  query_vectors: list[np.ndarray],
2421
2481
  limit_per_query: int = 3,
2422
2482
  namespace: str | None = None,
2423
- parallel: bool = False,
2424
- max_workers: int = 4,
2483
+ parallel: bool = False, # Deprecated
2484
+ max_workers: int = 4, # Deprecated
2485
+ include_vector: bool = False,
2425
2486
  ) -> list[list[dict[str, Any]]]:
2426
- """Search for similar memories using multiple query vectors.
2427
-
2428
- Efficient for operations like journey interpolation where multiple
2429
- points need to find nearby memories.
2487
+ """Search using multiple query vectors. Delegates to SearchManager.
2430
2488
 
2431
2489
  Args:
2432
2490
  query_vectors: List of query embedding vectors.
2433
2491
  limit_per_query: Maximum results per query vector.
2434
2492
  namespace: Filter to specific namespace.
2435
- parallel: Execute searches in parallel using ThreadPoolExecutor.
2436
- max_workers: Maximum worker threads for parallel execution.
2493
+ parallel: Deprecated, kept for backward compatibility.
2494
+ max_workers: Deprecated, kept for backward compatibility.
2495
+ include_vector: Whether to include vector embeddings in results.
2437
2496
 
2438
2497
  Returns:
2439
2498
  List of result lists (one per query vector).
@@ -2441,52 +2500,16 @@ class Database:
2441
2500
  Raises:
2442
2501
  StorageError: If database operation fails.
2443
2502
  """
2444
- if not query_vectors:
2445
- return []
2446
-
2447
- # Build namespace filter once
2448
- where_clause: str | None = None
2449
- if namespace:
2450
- namespace = _validate_namespace(namespace)
2451
- safe_ns = _sanitize_string(namespace)
2452
- where_clause = f"namespace = '{safe_ns}'"
2453
-
2454
- def search_single(vec: np.ndarray) -> list[dict[str, Any]]:
2455
- """Execute a single vector search."""
2456
- search = self.table.search(vec.tolist()).distance_type("cosine")
2457
-
2458
- if where_clause:
2459
- search = search.where(where_clause)
2460
-
2461
- results: list[dict[str, Any]] = search.limit(limit_per_query).to_list()
2462
-
2463
- # Process results
2464
- for record in results:
2465
- meta = record["metadata"]
2466
- record["metadata"] = json.loads(meta) if meta else {}
2467
- if "_distance" in record:
2468
- record["similarity"] = max(0.0, min(1.0, 1 - record["_distance"]))
2469
- del record["_distance"]
2470
-
2471
- return results
2472
-
2473
- try:
2474
- if parallel and len(query_vectors) > 1:
2475
- # Use ThreadPoolExecutor for parallel execution
2476
- from concurrent.futures import ThreadPoolExecutor
2477
-
2478
- workers = min(max_workers, len(query_vectors))
2479
- with ThreadPoolExecutor(max_workers=workers) as executor:
2480
- # Map preserves order
2481
- all_results = list(executor.map(search_single, query_vectors))
2482
- else:
2483
- # Sequential execution
2484
- all_results = [search_single(vec) for vec in query_vectors]
2485
-
2486
- return all_results
2487
-
2488
- except Exception as e:
2489
- raise StorageError(f"Batch vector search failed: {e}") from e
2503
+ if self._search_manager is None:
2504
+ raise StorageError("Database not connected")
2505
+ return self._search_manager.batch_vector_search(
2506
+ query_vectors=query_vectors,
2507
+ limit_per_query=limit_per_query,
2508
+ namespace=namespace,
2509
+ parallel=parallel,
2510
+ max_workers=max_workers,
2511
+ include_vector=include_vector,
2512
+ )
2490
2513
 
2491
2514
  def get_vectors_for_clustering(
2492
2515
  self,
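batch_vector_search keeps parallel and max_workers only for backward compatibility; the ThreadPoolExecutor fan-out removed above is replaced by the SearchManager delegation, and results still come back as one list per query vector, in order. A hypothetical caller; `db`, `embed()`, `waypoints`, and the namespace value are placeholders:

    results_per_query = db.batch_vector_search(
        query_vectors=[embed(p) for p in waypoints],   # deprecated parallel/max_workers omitted
        limit_per_query=3,
        namespace="projects",                          # illustrative namespace
    )
    for waypoint, hits in zip(waypoints, results_per_query):
        print(waypoint, [hit["id"] for hit in hits])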
@@ -2932,6 +2955,7 @@ class Database:
2932
2955
 
2933
2956
  if deleted > 0:
2934
2957
  self._invalidate_count_cache()
2958
+ self._track_modification(deleted)
2935
2959
  self._invalidate_namespace_cache()
2936
2960
  logger.info(f"Cleaned up {deleted} expired memories")
2937
2961
 
@@ -2940,108 +2964,106 @@ class Database:
2940
2964
  raise StorageError(f"Failed to cleanup expired memories: {e}") from e
2941
2965
 
2942
2966
  # ========================================================================
2943
- # Snapshot / Version Management
2967
+ # Snapshot / Version Management (delegates to VersionManager)
2944
2968
  # ========================================================================
2945
2969
 
2946
2970
  def create_snapshot(self, tag: str) -> int:
2947
2971
  """Create a named snapshot of the current table state.
2948
2972
 
2949
- LanceDB automatically versions data on every write. This method
2950
- returns the current version number which can be used with restore_snapshot().
2951
-
2952
- Args:
2953
- tag: Semantic version tag (e.g., "v1.0.0", "backup-2024-01").
2954
- Note: Tag is logged for reference but LanceDB tracks versions
2955
- numerically. Consider storing tag->version mappings externally
2956
- if tag-based retrieval is needed.
2957
-
2958
- Returns:
2959
- Version number of the snapshot.
2960
-
2961
- Raises:
2962
- StorageError: If snapshot creation fails.
2973
+ Delegates to VersionManager. See VersionManager.create_snapshot for details.
2963
2974
  """
2964
- try:
2965
- version = self.table.version
2966
- logger.info(f"Created snapshot '{tag}' at version {version}")
2967
- return version
2968
- except Exception as e:
2969
- raise StorageError(f"Failed to create snapshot: {e}") from e
2975
+ if self._version_manager is None:
2976
+ raise StorageError("Database not connected")
2977
+ return self._version_manager.create_snapshot(tag)
2970
2978
 
2971
2979
  def list_snapshots(self) -> list[dict[str, Any]]:
2972
2980
  """List available versions/snapshots.
2973
2981
 
2974
- Returns:
2975
- List of version information dictionaries. Each dict contains
2976
- at minimum 'version' key. Additional fields depend on LanceDB
2977
- version and available metadata.
2982
+ Delegates to VersionManager. See VersionManager.list_snapshots for details.
2983
+ """
2984
+ if self._version_manager is None:
2985
+ raise StorageError("Database not connected")
2986
+ return self._version_manager.list_snapshots()
2978
2987
 
2979
- Raises:
2980
- StorageError: If listing fails.
2988
+ def restore_snapshot(self, version: int) -> None:
2989
+ """Restore table to a specific version.
2990
+
2991
+ Delegates to VersionManager. See VersionManager.restore_snapshot for details.
2981
2992
  """
2982
- try:
2983
- versions_info: list[dict[str, Any]] = []
2993
+ if self._version_manager is None:
2994
+ raise StorageError("Database not connected")
2995
+ self._version_manager.restore_snapshot(version)
2984
2996
 
2985
- # Try to get version history if available
2986
- if hasattr(self.table, "list_versions"):
2987
- try:
2988
- versions = self.table.list_versions()
2989
- for v in versions:
2990
- if isinstance(v, dict):
2991
- versions_info.append(v)
2992
- elif hasattr(v, "version"):
2993
- versions_info.append({
2994
- "version": v.version,
2995
- "timestamp": getattr(v, "timestamp", None),
2996
- })
2997
- else:
2998
- versions_info.append({"version": v})
2999
- except Exception as e:
3000
- logger.debug(f"list_versions not fully supported: {e}")
2997
+ def get_current_version(self) -> int:
2998
+ """Get the current table version number.
3001
2999
 
3002
- # Always include current version
3003
- if not versions_info:
3004
- versions_info.append({"version": self.table.version})
3000
+ Delegates to VersionManager. See VersionManager.get_current_version for details.
3001
+ """
3002
+ if self._version_manager is None:
3003
+ raise StorageError("Database not connected")
3004
+ return self._version_manager.get_current_version()
3005
3005
 
3006
- return versions_info
3007
- except Exception as e:
3008
- logger.warning(f"Could not list snapshots: {e}")
3009
- return [{"version": 0, "error": str(e)}]
3006
+ # ========================================================================
3007
+ # Idempotency Key Management (delegates to IdempotencyManager)
3008
+ # ========================================================================
3010
3009
 
3011
- def restore_snapshot(self, version: int) -> None:
3012
- """Restore table to a specific version.
3010
+ @property
3011
+ def idempotency_table(self) -> LanceTable:
3012
+ """Get the idempotency keys table. Delegates to IdempotencyManager."""
3013
+ if self._idempotency_manager is None:
3014
+ raise StorageError("Database not connected")
3015
+ return self._idempotency_manager.idempotency_table
3013
3016
 
3014
- This creates a NEW version that reflects the old state
3015
- (doesn't delete history).
3017
+ def get_by_idempotency_key(self, key: str) -> IdempotencyRecord | None:
3018
+ """Look up an idempotency record by key. Delegates to IdempotencyManager.
3016
3019
 
3017
3020
  Args:
3018
- version: The version number to restore to.
3021
+ key: The idempotency key to look up.
3022
+
3023
+ Returns:
3024
+ IdempotencyRecord if found and not expired, None otherwise.
3019
3025
 
3020
3026
  Raises:
3021
- ValidationError: If version is invalid.
3022
- StorageError: If restore fails.
3027
+ StorageError: If database operation fails.
3023
3028
  """
3024
- if version < 0:
3025
- raise ValidationError("Version must be non-negative")
3029
+ if self._idempotency_manager is None:
3030
+ raise StorageError("Database not connected")
3031
+ return self._idempotency_manager.get_by_idempotency_key(key)
3026
3032
 
3027
- try:
3028
- self.table.restore(version)
3029
- self._invalidate_count_cache()
3030
- self._invalidate_namespace_cache()
3031
- logger.info(f"Restored to version {version}")
3032
- except Exception as e:
3033
- raise StorageError(f"Failed to restore snapshot: {e}") from e
3033
+ @with_process_lock
3034
+ @with_write_lock
3035
+ def store_idempotency_key(
3036
+ self,
3037
+ key: str,
3038
+ memory_id: str,
3039
+ ttl_hours: float = 24.0,
3040
+ ) -> None:
3041
+ """Store an idempotency key mapping. Delegates to IdempotencyManager.
3034
3042
 
3035
- def get_current_version(self) -> int:
3036
- """Get the current table version number.
3043
+ Args:
3044
+ key: The idempotency key.
3045
+ memory_id: The memory ID that was created.
3046
+ ttl_hours: Time-to-live in hours (default: 24 hours).
3047
+
3048
+ Raises:
3049
+ ValidationError: If inputs are invalid.
3050
+ StorageError: If database operation fails.
3051
+ """
3052
+ if self._idempotency_manager is None:
3053
+ raise StorageError("Database not connected")
3054
+ self._idempotency_manager.store_idempotency_key(key, memory_id, ttl_hours)
3055
+
3056
+ @with_process_lock
3057
+ @with_write_lock
3058
+ def cleanup_expired_idempotency_keys(self) -> int:
3059
+ """Remove expired idempotency keys. Delegates to IdempotencyManager.
3037
3060
 
3038
3061
  Returns:
3039
- Current version number.
3062
+ Number of keys removed.
3040
3063
 
3041
3064
  Raises:
3042
- StorageError: If version cannot be retrieved.
3065
+ StorageError: If cleanup fails.
3043
3066
  """
3044
- try:
3045
- return self.table.version
3046
- except Exception as e:
3047
- raise StorageError(f"Failed to get current version: {e}") from e
3067
+ if self._idempotency_manager is None:
3068
+ raise StorageError("Database not connected")
3069
+ return self._idempotency_manager.cleanup_expired_idempotency_keys()
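The snapshot/version and idempotency helpers above are now thin facades over VersionManager and IdempotencyManager, each guarded by a "Database not connected" check. A hedged sketch of how they compose; `db` is an already-connected Database and the key and memory-id values are placeholders, not values from this package:

    version = db.create_snapshot("pre-bulk-import")          # numeric LanceDB version

    if db.get_by_idempotency_key("import-batch-42") is None:
        memory_id = "..."                                     # created by a separate insert
        db.store_idempotency_key("import-batch-42", memory_id, ttl_hours=24.0)

    removed = db.cleanup_expired_idempotency_keys()           # periodic housekeeping

    # Roll back if the import went wrong; per the removed docstring, restoring
    # creates a new version that mirrors the old state rather than erasing history.
    db.restore_snapshot(version)
    current = db.get_current_version()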