spatial-memory-mcp 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of spatial-memory-mcp might be problematic. Click here for more details.
- spatial_memory/__init__.py +1 -1
- spatial_memory/__main__.py +241 -2
- spatial_memory/adapters/lancedb_repository.py +74 -5
- spatial_memory/config.py +10 -2
- spatial_memory/core/__init__.py +9 -0
- spatial_memory/core/connection_pool.py +41 -3
- spatial_memory/core/consolidation_strategies.py +402 -0
- spatial_memory/core/database.py +774 -918
- spatial_memory/core/db_idempotency.py +242 -0
- spatial_memory/core/db_indexes.py +575 -0
- spatial_memory/core/db_migrations.py +584 -0
- spatial_memory/core/db_search.py +509 -0
- spatial_memory/core/db_versioning.py +177 -0
- spatial_memory/core/embeddings.py +65 -18
- spatial_memory/core/errors.py +75 -3
- spatial_memory/core/filesystem.py +178 -0
- spatial_memory/core/models.py +4 -0
- spatial_memory/core/rate_limiter.py +26 -9
- spatial_memory/core/response_types.py +497 -0
- spatial_memory/core/validation.py +86 -2
- spatial_memory/factory.py +407 -0
- spatial_memory/migrations/__init__.py +40 -0
- spatial_memory/ports/repositories.py +52 -2
- spatial_memory/server.py +131 -189
- spatial_memory/services/export_import.py +61 -43
- spatial_memory/services/lifecycle.py +397 -122
- spatial_memory/services/memory.py +2 -2
- spatial_memory/services/spatial.py +129 -46
- {spatial_memory_mcp-1.5.3.dist-info → spatial_memory_mcp-1.6.0.dist-info}/METADATA +83 -3
- spatial_memory_mcp-1.6.0.dist-info/RECORD +54 -0
- spatial_memory_mcp-1.5.3.dist-info/RECORD +0 -44
- {spatial_memory_mcp-1.5.3.dist-info → spatial_memory_mcp-1.6.0.dist-info}/WHEEL +0 -0
- {spatial_memory_mcp-1.5.3.dist-info → spatial_memory_mcp-1.6.0.dist-info}/entry_points.txt +0 -0
- {spatial_memory_mcp-1.5.3.dist-info → spatial_memory_mcp-1.6.0.dist-info}/licenses/LICENSE +0 -0
spatial_memory/core/database.py
CHANGED
|
@@ -34,7 +34,20 @@ import pyarrow.parquet as pq
|
|
|
34
34
|
from filelock import FileLock, Timeout as FileLockTimeout
|
|
35
35
|
|
|
36
36
|
from spatial_memory.core.connection_pool import ConnectionPool
|
|
37
|
-
from spatial_memory.core.
|
|
37
|
+
from spatial_memory.core.db_idempotency import IdempotencyManager, IdempotencyRecord
|
|
38
|
+
from spatial_memory.core.db_indexes import IndexManager
|
|
39
|
+
from spatial_memory.core.db_migrations import CURRENT_SCHEMA_VERSION, MigrationManager
|
|
40
|
+
from spatial_memory.core.db_search import SearchManager
|
|
41
|
+
from spatial_memory.core.db_versioning import VersionManager
|
|
42
|
+
from spatial_memory.core.errors import (
|
|
43
|
+
DimensionMismatchError,
|
|
44
|
+
FileLockError,
|
|
45
|
+
MemoryNotFoundError,
|
|
46
|
+
PartialBatchInsertError,
|
|
47
|
+
StorageError,
|
|
48
|
+
ValidationError,
|
|
49
|
+
)
|
|
50
|
+
from spatial_memory.core.filesystem import detect_filesystem_type, get_filesystem_warning_message, is_network_filesystem
|
|
38
51
|
from spatial_memory.core.utils import to_aware_utc, utc_now
|
|
39
52
|
|
|
40
53
|
# Import centralized validation functions
|
|
@@ -131,9 +144,14 @@ def invalidate_connection(storage_path: Path) -> bool:
|
|
|
131
144
|
# Retry Decorator
|
|
132
145
|
# ============================================================================
|
|
133
146
|
|
|
147
|
+
# Default retry settings (can be overridden per-call)
|
|
148
|
+
DEFAULT_RETRY_MAX_ATTEMPTS = 3
|
|
149
|
+
DEFAULT_RETRY_BACKOFF_SECONDS = 0.5
|
|
150
|
+
|
|
151
|
+
|
|
134
152
|
def retry_on_storage_error(
|
|
135
|
-
max_attempts: int =
|
|
136
|
-
backoff: float =
|
|
153
|
+
max_attempts: int = DEFAULT_RETRY_MAX_ATTEMPTS,
|
|
154
|
+
backoff: float = DEFAULT_RETRY_BACKOFF_SECONDS,
|
|
137
155
|
) -> Callable[[F], F]:
|
|
138
156
|
"""Retry decorator for transient storage errors.
|
|
139
157
|
|
|
@@ -391,15 +409,6 @@ def with_process_lock(func: F) -> F:
|
|
|
391
409
|
# Health Metrics
|
|
392
410
|
# ============================================================================
|
|
393
411
|
|
|
394
|
-
@dataclass
|
|
395
|
-
class IdempotencyRecord:
|
|
396
|
-
"""Record for idempotency key tracking."""
|
|
397
|
-
key: str
|
|
398
|
-
memory_id: str
|
|
399
|
-
created_at: Any # datetime
|
|
400
|
-
expires_at: Any # datetime
|
|
401
|
-
|
|
402
|
-
|
|
403
412
|
@dataclass
|
|
404
413
|
class IndexStats:
|
|
405
414
|
"""Statistics for a single index."""
|
|
@@ -492,8 +501,8 @@ class Database:
|
|
|
492
501
|
enable_fts: bool = True,
|
|
493
502
|
index_nprobes: int = 20,
|
|
494
503
|
index_refine_factor: int = 5,
|
|
495
|
-
max_retry_attempts: int =
|
|
496
|
-
retry_backoff_seconds: float =
|
|
504
|
+
max_retry_attempts: int = DEFAULT_RETRY_MAX_ATTEMPTS,
|
|
505
|
+
retry_backoff_seconds: float = DEFAULT_RETRY_BACKOFF_SECONDS,
|
|
497
506
|
read_consistency_interval_ms: int = 0,
|
|
498
507
|
index_wait_timeout_seconds: float = 30.0,
|
|
499
508
|
fts_stem: bool = True,
|
|
@@ -507,6 +516,7 @@ class Database:
|
|
|
507
516
|
filelock_enabled: bool = True,
|
|
508
517
|
filelock_timeout: float = 30.0,
|
|
509
518
|
filelock_poll_interval: float = 0.1,
|
|
519
|
+
acknowledge_network_filesystem_risk: bool = False,
|
|
510
520
|
) -> None:
|
|
511
521
|
"""Initialize the database connection.
|
|
512
522
|
|
|
@@ -533,6 +543,7 @@ class Database:
|
|
|
533
543
|
hnsw_ef_construction: HNSW build-time search width (100-1000).
|
|
534
544
|
enable_memory_expiration: Enable automatic memory expiration.
|
|
535
545
|
default_memory_ttl_days: Default TTL for memories in days (None = no expiration).
|
|
546
|
+
acknowledge_network_filesystem_risk: Suppress network filesystem warnings.
|
|
536
547
|
"""
|
|
537
548
|
self.storage_path = Path(storage_path)
|
|
538
549
|
self.embedding_dim = embedding_dim
|
|
@@ -556,6 +567,7 @@ class Database:
|
|
|
556
567
|
self.filelock_enabled = filelock_enabled
|
|
557
568
|
self.filelock_timeout = filelock_timeout
|
|
558
569
|
self.filelock_poll_interval = filelock_poll_interval
|
|
570
|
+
self.acknowledge_network_filesystem_risk = acknowledge_network_filesystem_risk
|
|
559
571
|
self._db: lancedb.DBConnection | None = None
|
|
560
572
|
self._table: LanceTable | None = None
|
|
561
573
|
self._has_vector_index: bool | None = None
|
|
@@ -573,6 +585,18 @@ class Database:
|
|
|
573
585
|
self._write_lock = threading.RLock()
|
|
574
586
|
# Cross-process lock (initialized in connect())
|
|
575
587
|
self._process_lock: ProcessLockManager | None = None
|
|
588
|
+
# Auto-compaction tracking
|
|
589
|
+
self._modification_count: int = 0
|
|
590
|
+
self._auto_compaction_threshold: int = 100 # Compact after this many modifications
|
|
591
|
+
self._auto_compaction_enabled: bool = True
|
|
592
|
+
# Version manager (initialized in connect())
|
|
593
|
+
self._version_manager: VersionManager | None = None
|
|
594
|
+
# Index manager (initialized in connect())
|
|
595
|
+
self._index_manager: IndexManager | None = None
|
|
596
|
+
# Search manager (initialized in connect())
|
|
597
|
+
self._search_manager: SearchManager | None = None
|
|
598
|
+
# Idempotency manager (initialized in connect())
|
|
599
|
+
self._idempotency_manager: IdempotencyManager | None = None
|
|
576
600
|
|
|
577
601
|
def __enter__(self) -> Database:
|
|
578
602
|
"""Enter context manager."""
|
|
@@ -588,6 +612,13 @@ class Database:
|
|
|
588
612
|
try:
|
|
589
613
|
self.storage_path.mkdir(parents=True, exist_ok=True)
|
|
590
614
|
|
|
615
|
+
# Check for network filesystem and warn if detected
|
|
616
|
+
if not self.acknowledge_network_filesystem_risk:
|
|
617
|
+
if is_network_filesystem(self.storage_path):
|
|
618
|
+
fs_type = detect_filesystem_type(self.storage_path)
|
|
619
|
+
warning_msg = get_filesystem_warning_message(fs_type, self.storage_path)
|
|
620
|
+
logger.warning(warning_msg)
|
|
621
|
+
|
|
591
622
|
# Initialize cross-process lock manager
|
|
592
623
|
if self.filelock_enabled:
|
|
593
624
|
lock_path = self.storage_path / ".spatial-memory.lock"
|
|
@@ -606,109 +637,144 @@ class Database:
|
|
|
606
637
|
read_consistency_interval_ms=self.read_consistency_interval_ms,
|
|
607
638
|
)
|
|
608
639
|
self._ensure_table()
|
|
640
|
+
# Initialize remaining managers (IndexManager already initialized in _ensure_table)
|
|
641
|
+
self._version_manager = VersionManager(self)
|
|
642
|
+
self._search_manager = SearchManager(self)
|
|
643
|
+
self._idempotency_manager = IdempotencyManager(self)
|
|
609
644
|
logger.info(f"Connected to LanceDB at {self.storage_path}")
|
|
645
|
+
|
|
646
|
+
# Check for pending schema migrations
|
|
647
|
+
self._check_pending_migrations()
|
|
610
648
|
except Exception as e:
|
|
611
649
|
raise StorageError(f"Failed to connect to database: {e}") from e
|
|
612
650
|
|
|
651
|
+
def _check_pending_migrations(self) -> None:
|
|
652
|
+
"""Check for pending migrations and warn if any exist.
|
|
653
|
+
|
|
654
|
+
This method checks the schema version and logs a warning if there
|
|
655
|
+
are pending migrations. It does not auto-apply migrations - that
|
|
656
|
+
requires explicit user action via the CLI.
|
|
657
|
+
"""
|
|
658
|
+
try:
|
|
659
|
+
manager = MigrationManager(self, embeddings=None)
|
|
660
|
+
manager.register_builtin_migrations()
|
|
661
|
+
|
|
662
|
+
current_version = manager.get_current_version()
|
|
663
|
+
pending = manager.get_pending_migrations()
|
|
664
|
+
|
|
665
|
+
if pending:
|
|
666
|
+
pending_versions = [m.version for m in pending]
|
|
667
|
+
logger.warning(
|
|
668
|
+
f"Database schema version {current_version} is outdated. "
|
|
669
|
+
f"{len(pending)} migration(s) pending: {', '.join(pending_versions)}. "
|
|
670
|
+
f"Target version: {CURRENT_SCHEMA_VERSION}. "
|
|
671
|
+
f"Run 'spatial-memory migrate' to apply migrations."
|
|
672
|
+
)
|
|
673
|
+
except Exception as e:
|
|
674
|
+
# Don't fail connection due to migration check errors
|
|
675
|
+
logger.debug(f"Migration check skipped: {e}")
|
|
676
|
+
|
|
613
677
|
def _ensure_table(self) -> None:
|
|
614
|
-
"""Ensure the memories table exists with appropriate indexes.
|
|
678
|
+
"""Ensure the memories table exists with appropriate indexes.
|
|
679
|
+
|
|
680
|
+
Uses retry logic to handle race conditions when multiple processes
|
|
681
|
+
attempt to create/open the table simultaneously.
|
|
682
|
+
"""
|
|
615
683
|
if self._db is None:
|
|
616
684
|
raise StorageError("Database not connected")
|
|
617
685
|
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
if hasattr(existing_tables_result, 'tables'):
|
|
621
|
-
existing_tables = existing_tables_result.tables
|
|
622
|
-
else:
|
|
623
|
-
existing_tables = existing_tables_result
|
|
624
|
-
if "memories" not in existing_tables:
|
|
625
|
-
# Create table with schema
|
|
626
|
-
schema = pa.schema([
|
|
627
|
-
pa.field("id", pa.string()),
|
|
628
|
-
pa.field("content", pa.string()),
|
|
629
|
-
pa.field("vector", pa.list_(pa.float32(), self.embedding_dim)),
|
|
630
|
-
pa.field("created_at", pa.timestamp("us")),
|
|
631
|
-
pa.field("updated_at", pa.timestamp("us")),
|
|
632
|
-
pa.field("last_accessed", pa.timestamp("us")),
|
|
633
|
-
pa.field("access_count", pa.int32()),
|
|
634
|
-
pa.field("importance", pa.float32()),
|
|
635
|
-
pa.field("namespace", pa.string()),
|
|
636
|
-
pa.field("tags", pa.list_(pa.string())),
|
|
637
|
-
pa.field("source", pa.string()),
|
|
638
|
-
pa.field("metadata", pa.string()),
|
|
639
|
-
pa.field("expires_at", pa.timestamp("us")), # TTL support - nullable
|
|
640
|
-
])
|
|
641
|
-
self._table = self._db.create_table("memories", schema=schema)
|
|
642
|
-
logger.info("Created memories table")
|
|
643
|
-
|
|
644
|
-
# Create FTS index on new table if enabled
|
|
645
|
-
if self.enable_fts:
|
|
646
|
-
self._create_fts_index()
|
|
647
|
-
else:
|
|
648
|
-
self._table = self._db.open_table("memories")
|
|
649
|
-
logger.debug("Opened existing memories table")
|
|
686
|
+
max_retries = 3
|
|
687
|
+
retry_delay = 0.1 # Start with 100ms
|
|
650
688
|
|
|
651
|
-
|
|
652
|
-
|
|
689
|
+
for attempt in range(max_retries):
|
|
690
|
+
try:
|
|
691
|
+
existing_tables_result = self._db.list_tables()
|
|
692
|
+
# Handle both old (list) and new (object with .tables) LanceDB API
|
|
693
|
+
if hasattr(existing_tables_result, 'tables'):
|
|
694
|
+
existing_tables = existing_tables_result.tables
|
|
695
|
+
else:
|
|
696
|
+
existing_tables = existing_tables_result
|
|
697
|
+
|
|
698
|
+
if "memories" not in existing_tables:
|
|
699
|
+
# Create table with schema
|
|
700
|
+
schema = pa.schema([
|
|
701
|
+
pa.field("id", pa.string()),
|
|
702
|
+
pa.field("content", pa.string()),
|
|
703
|
+
pa.field("vector", pa.list_(pa.float32(), self.embedding_dim)),
|
|
704
|
+
pa.field("created_at", pa.timestamp("us")),
|
|
705
|
+
pa.field("updated_at", pa.timestamp("us")),
|
|
706
|
+
pa.field("last_accessed", pa.timestamp("us")),
|
|
707
|
+
pa.field("access_count", pa.int32()),
|
|
708
|
+
pa.field("importance", pa.float32()),
|
|
709
|
+
pa.field("namespace", pa.string()),
|
|
710
|
+
pa.field("tags", pa.list_(pa.string())),
|
|
711
|
+
pa.field("source", pa.string()),
|
|
712
|
+
pa.field("metadata", pa.string()),
|
|
713
|
+
pa.field("expires_at", pa.timestamp("us")), # TTL support - nullable
|
|
714
|
+
])
|
|
715
|
+
try:
|
|
716
|
+
self._table = self._db.create_table("memories", schema=schema)
|
|
717
|
+
logger.info("Created memories table")
|
|
718
|
+
except Exception as create_err:
|
|
719
|
+
# Table might have been created by another process
|
|
720
|
+
if "already exists" in str(create_err).lower():
|
|
721
|
+
logger.debug("Table created by another process, opening it")
|
|
722
|
+
self._table = self._db.open_table("memories")
|
|
723
|
+
else:
|
|
724
|
+
raise
|
|
653
725
|
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
try:
|
|
657
|
-
indices = self.table.list_indices()
|
|
726
|
+
# Initialize IndexManager immediately after table is set
|
|
727
|
+
self._index_manager = IndexManager(self)
|
|
658
728
|
|
|
659
|
-
|
|
660
|
-
|
|
729
|
+
# Create FTS index on new table if enabled
|
|
730
|
+
if self.enable_fts:
|
|
731
|
+
self._index_manager.create_fts_index()
|
|
732
|
+
else:
|
|
733
|
+
self._table = self._db.open_table("memories")
|
|
734
|
+
logger.debug("Opened existing memories table")
|
|
661
735
|
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
index_type = str(_get_index_attr(idx, "index_type", "")).upper()
|
|
665
|
-
columns = _get_index_attr(idx, "columns", [])
|
|
736
|
+
# Initialize IndexManager immediately after table is set
|
|
737
|
+
self._index_manager = IndexManager(self)
|
|
666
738
|
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
self._has_vector_index = True
|
|
670
|
-
elif "vector" in columns or "vector" in index_name:
|
|
671
|
-
self._has_vector_index = True
|
|
739
|
+
# Check existing indexes
|
|
740
|
+
self._index_manager.check_existing_indexes()
|
|
672
741
|
|
|
673
|
-
#
|
|
674
|
-
|
|
675
|
-
self._has_fts_index = True
|
|
676
|
-
elif "fts" in index_name or "content" in index_name:
|
|
677
|
-
self._has_fts_index = True
|
|
742
|
+
# Success - exit retry loop
|
|
743
|
+
return
|
|
678
744
|
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
745
|
+
except Exception as e:
|
|
746
|
+
error_msg = str(e).lower()
|
|
747
|
+
# Retry on transient race conditions
|
|
748
|
+
if attempt < max_retries - 1 and (
|
|
749
|
+
"not found" in error_msg
|
|
750
|
+
or "does not exist" in error_msg
|
|
751
|
+
or "already exists" in error_msg
|
|
752
|
+
):
|
|
753
|
+
logger.debug(
|
|
754
|
+
f"Table operation failed (attempt {attempt + 1}/{max_retries}), "
|
|
755
|
+
f"retrying in {retry_delay}s: {e}"
|
|
756
|
+
)
|
|
757
|
+
time.sleep(retry_delay)
|
|
758
|
+
retry_delay *= 2 # Exponential backoff
|
|
759
|
+
else:
|
|
760
|
+
raise
|
|
761
|
+
|
|
762
|
+
def _check_existing_indexes(self) -> None:
|
|
763
|
+
"""Check which indexes already exist. Delegates to IndexManager."""
|
|
764
|
+
if self._index_manager is None:
|
|
765
|
+
raise StorageError("Database not connected")
|
|
766
|
+
self._index_manager.check_existing_indexes()
|
|
767
|
+
# Sync local state for backward compatibility
|
|
768
|
+
self._has_vector_index = self._index_manager.has_vector_index
|
|
769
|
+
self._has_fts_index = self._index_manager.has_fts_index
|
|
687
770
|
|
|
688
771
|
def _create_fts_index(self) -> None:
|
|
689
|
-
"""Create
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
stem=self.fts_stem,
|
|
696
|
-
remove_stop_words=self.fts_remove_stop_words,
|
|
697
|
-
with_position=True, # Enable phrase queries
|
|
698
|
-
lower_case=True, # Case-insensitive search
|
|
699
|
-
)
|
|
700
|
-
self._has_fts_index = True
|
|
701
|
-
logger.info(
|
|
702
|
-
f"Created FTS index with stemming={self.fts_stem}, "
|
|
703
|
-
f"stop_words={self.fts_remove_stop_words}"
|
|
704
|
-
)
|
|
705
|
-
except Exception as e:
|
|
706
|
-
# Check if index already exists (not an error)
|
|
707
|
-
if "already exists" in str(e).lower():
|
|
708
|
-
self._has_fts_index = True
|
|
709
|
-
logger.debug("FTS index already exists")
|
|
710
|
-
else:
|
|
711
|
-
logger.warning(f"FTS index creation failed: {e}")
|
|
772
|
+
"""Create FTS index. Delegates to IndexManager."""
|
|
773
|
+
if self._index_manager is None:
|
|
774
|
+
raise StorageError("Database not connected")
|
|
775
|
+
self._index_manager.create_fts_index()
|
|
776
|
+
# Sync local state for backward compatibility
|
|
777
|
+
self._has_fts_index = self._index_manager.has_fts_index
|
|
712
778
|
|
|
713
779
|
@property
|
|
714
780
|
def table(self) -> LanceTable:
|
|
@@ -719,18 +785,30 @@ class Database:
|
|
|
719
785
|
return self._table
|
|
720
786
|
|
|
721
787
|
def close(self) -> None:
|
|
722
|
-
"""Close the database connection
|
|
788
|
+
"""Close the database connection and remove from pool.
|
|
789
|
+
|
|
790
|
+
This invalidates the pooled connection so that subsequent
|
|
791
|
+
Database instances will create fresh connections.
|
|
792
|
+
"""
|
|
793
|
+
# Invalidate pooled connection first
|
|
794
|
+
invalidate_connection(self.storage_path)
|
|
795
|
+
|
|
796
|
+
# Clear local state
|
|
723
797
|
self._table = None
|
|
724
798
|
self._db = None
|
|
725
799
|
self._has_vector_index = None
|
|
726
800
|
self._has_fts_index = None
|
|
801
|
+
self._version_manager = None
|
|
802
|
+
self._index_manager = None
|
|
803
|
+
self._search_manager = None
|
|
804
|
+
self._idempotency_manager = None
|
|
727
805
|
with self._cache_lock:
|
|
728
806
|
self._cached_row_count = None
|
|
729
807
|
self._count_cache_time = 0.0
|
|
730
808
|
with self._namespace_cache_lock:
|
|
731
809
|
self._cached_namespaces = None
|
|
732
810
|
self._namespace_cache_time = 0.0
|
|
733
|
-
logger.debug("Database connection closed")
|
|
811
|
+
logger.debug("Database connection closed and removed from pool")
|
|
734
812
|
|
|
735
813
|
def reconnect(self) -> None:
|
|
736
814
|
"""Invalidate cached connection and reconnect.
|
|
@@ -790,313 +868,86 @@ class Database:
|
|
|
790
868
|
self._cached_namespaces = None
|
|
791
869
|
self._namespace_cache_time = 0.0
|
|
792
870
|
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
# ========================================================================
|
|
796
|
-
|
|
797
|
-
def create_vector_index(self, force: bool = False) -> bool:
|
|
798
|
-
"""Create vector index for similarity search.
|
|
799
|
-
|
|
800
|
-
Supports IVF_PQ, IVF_FLAT, and HNSW_SQ index types based on configuration.
|
|
801
|
-
Automatically determines optimal parameters based on dataset size.
|
|
871
|
+
def _track_modification(self, count: int = 1) -> None:
|
|
872
|
+
"""Track database modifications and trigger auto-compaction if threshold reached.
|
|
802
873
|
|
|
803
874
|
Args:
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
Returns:
|
|
807
|
-
True if index was created, False if skipped.
|
|
808
|
-
|
|
809
|
-
Raises:
|
|
810
|
-
StorageError: If index creation fails.
|
|
875
|
+
count: Number of modifications to track (default 1).
|
|
811
876
|
"""
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
# Check threshold
|
|
815
|
-
if count < self.vector_index_threshold and not force:
|
|
816
|
-
logger.info(
|
|
817
|
-
f"Dataset has {count} rows, below threshold {self.vector_index_threshold}. "
|
|
818
|
-
"Skipping vector index creation."
|
|
819
|
-
)
|
|
820
|
-
return False
|
|
821
|
-
|
|
822
|
-
# Check if already exists
|
|
823
|
-
if self._has_vector_index and not force:
|
|
824
|
-
logger.info("Vector index already exists")
|
|
825
|
-
return False
|
|
826
|
-
|
|
827
|
-
# Handle HNSW_SQ index type
|
|
828
|
-
if self.index_type == "HNSW_SQ":
|
|
829
|
-
return self._create_hnsw_index(count)
|
|
830
|
-
|
|
831
|
-
# IVF-based index creation (IVF_PQ or IVF_FLAT)
|
|
832
|
-
return self._create_ivf_index(count)
|
|
877
|
+
if not self._auto_compaction_enabled:
|
|
878
|
+
return
|
|
833
879
|
|
|
834
|
-
|
|
835
|
-
|
|
880
|
+
self._modification_count += count
|
|
881
|
+
if self._modification_count >= self._auto_compaction_threshold:
|
|
882
|
+
# Reset counter before compacting to avoid re-triggering
|
|
883
|
+
self._modification_count = 0
|
|
884
|
+
try:
|
|
885
|
+
stats = self._get_table_stats()
|
|
886
|
+
# Only compact if there are enough fragments to justify it
|
|
887
|
+
if stats.get("num_small_fragments", 0) >= 5:
|
|
888
|
+
logger.info(
|
|
889
|
+
f"Auto-compaction triggered after {self._auto_compaction_threshold} "
|
|
890
|
+
f"modifications ({stats.get('num_small_fragments', 0)} small fragments)"
|
|
891
|
+
)
|
|
892
|
+
self.table.compact_files()
|
|
893
|
+
logger.debug("Auto-compaction completed")
|
|
894
|
+
except Exception as e:
|
|
895
|
+
# Don't fail operations due to compaction issues
|
|
896
|
+
logger.debug(f"Auto-compaction skipped: {e}")
|
|
836
897
|
|
|
837
|
-
|
|
838
|
-
|
|
898
|
+
def set_auto_compaction(
|
|
899
|
+
self,
|
|
900
|
+
enabled: bool = True,
|
|
901
|
+
threshold: int | None = None,
|
|
902
|
+
) -> None:
|
|
903
|
+
"""Configure auto-compaction behavior.
|
|
839
904
|
|
|
840
905
|
Args:
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
Returns:
|
|
844
|
-
True if index was created.
|
|
845
|
-
|
|
846
|
-
Raises:
|
|
847
|
-
StorageError: If index creation fails.
|
|
906
|
+
enabled: Whether auto-compaction is enabled.
|
|
907
|
+
threshold: Number of modifications before auto-compact (default: 100).
|
|
848
908
|
"""
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
try:
|
|
855
|
-
self.table.create_index(
|
|
856
|
-
metric="cosine",
|
|
857
|
-
vector_column_name="vector",
|
|
858
|
-
index_type="HNSW_SQ",
|
|
859
|
-
replace=True,
|
|
860
|
-
m=self.hnsw_m,
|
|
861
|
-
ef_construction=self.hnsw_ef_construction,
|
|
862
|
-
)
|
|
863
|
-
|
|
864
|
-
# Wait for index to be ready with configurable timeout
|
|
865
|
-
self._wait_for_index_ready("vector", self.index_wait_timeout_seconds)
|
|
866
|
-
|
|
867
|
-
self._has_vector_index = True
|
|
868
|
-
logger.info("HNSW_SQ vector index created successfully")
|
|
869
|
-
|
|
870
|
-
# Optimize after index creation (may fail in some environments)
|
|
871
|
-
try:
|
|
872
|
-
self.table.optimize()
|
|
873
|
-
except Exception as optimize_error:
|
|
874
|
-
logger.debug(f"Optimization after index creation skipped: {optimize_error}")
|
|
909
|
+
self._auto_compaction_enabled = enabled
|
|
910
|
+
if threshold is not None:
|
|
911
|
+
if threshold < 10:
|
|
912
|
+
raise ValueError("Auto-compaction threshold must be at least 10")
|
|
913
|
+
self._auto_compaction_threshold = threshold
|
|
875
914
|
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
logger.error(f"Failed to create HNSW_SQ vector index: {e}")
|
|
880
|
-
raise StorageError(f"HNSW_SQ vector index creation failed: {e}") from e
|
|
881
|
-
|
|
882
|
-
def _create_ivf_index(self, count: int) -> bool:
|
|
883
|
-
"""Create IVF-PQ or IVF-FLAT vector index.
|
|
915
|
+
# ========================================================================
|
|
916
|
+
# Index Management (delegates to IndexManager)
|
|
917
|
+
# ========================================================================
|
|
884
918
|
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
96 sub-vectors for >=500K rows (4 dims each).
|
|
919
|
+
def create_vector_index(self, force: bool = False) -> bool:
|
|
920
|
+
"""Create vector index for similarity search. Delegates to IndexManager.
|
|
888
921
|
|
|
889
922
|
Args:
|
|
890
|
-
|
|
923
|
+
force: Force index creation regardless of dataset size.
|
|
891
924
|
|
|
892
925
|
Returns:
|
|
893
|
-
True if index was created.
|
|
926
|
+
True if index was created, False if skipped.
|
|
894
927
|
|
|
895
928
|
Raises:
|
|
896
929
|
StorageError: If index creation fails.
|
|
897
930
|
"""
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
if count < 500_000:
|
|
906
|
-
num_sub_vectors = 48
|
|
907
|
-
else:
|
|
908
|
-
num_sub_vectors = 96
|
|
909
|
-
|
|
910
|
-
# Validate embedding_dim % num_sub_vectors == 0 (required for IVF-PQ)
|
|
911
|
-
if self.embedding_dim % num_sub_vectors != 0:
|
|
912
|
-
# Find a valid divisor from common sub-vector counts
|
|
913
|
-
valid_divisors = [96, 48, 32, 24, 16, 12, 8, 4]
|
|
914
|
-
found_divisor = False
|
|
915
|
-
for divisor in valid_divisors:
|
|
916
|
-
if self.embedding_dim % divisor == 0:
|
|
917
|
-
logger.info(
|
|
918
|
-
f"Adjusted num_sub_vectors from {num_sub_vectors} to {divisor} "
|
|
919
|
-
f"for embedding_dim={self.embedding_dim}"
|
|
920
|
-
)
|
|
921
|
-
num_sub_vectors = divisor
|
|
922
|
-
found_divisor = True
|
|
923
|
-
break
|
|
924
|
-
|
|
925
|
-
if not found_divisor:
|
|
926
|
-
raise StorageError(
|
|
927
|
-
f"Cannot create IVF-PQ index: embedding_dim={self.embedding_dim} "
|
|
928
|
-
"has no suitable divisor for sub-vectors. "
|
|
929
|
-
f"Tried divisors: {valid_divisors}"
|
|
930
|
-
)
|
|
931
|
-
|
|
932
|
-
# IVF-PQ requires minimum rows for training (sample_rate * num_partitions / 256)
|
|
933
|
-
# Default sample_rate=256, so we need at least 256 rows
|
|
934
|
-
# Also, IVF requires num_partitions < num_vectors for KMeans training
|
|
935
|
-
sample_rate = 256 # default
|
|
936
|
-
if count < 256:
|
|
937
|
-
# Use IVF_FLAT for very small datasets (no PQ training required)
|
|
938
|
-
logger.info(
|
|
939
|
-
f"Dataset too small for IVF-PQ ({count} rows < 256). "
|
|
940
|
-
"Using IVF_FLAT index instead."
|
|
941
|
-
)
|
|
942
|
-
index_type = "IVF_FLAT"
|
|
943
|
-
sample_rate = max(16, count // 4) # Lower sample rate for small data
|
|
944
|
-
else:
|
|
945
|
-
index_type = self.index_type if self.index_type in ("IVF_PQ", "IVF_FLAT") else "IVF_PQ"
|
|
946
|
-
|
|
947
|
-
# Ensure num_partitions < num_vectors for KMeans clustering
|
|
948
|
-
if num_partitions >= count:
|
|
949
|
-
num_partitions = max(1, count // 4) # Use 1/4 of count, minimum 1
|
|
950
|
-
logger.info(f"Adjusted num_partitions to {num_partitions} for {count} rows")
|
|
951
|
-
|
|
952
|
-
logger.info(
|
|
953
|
-
f"Creating {index_type} vector index: {num_partitions} partitions, "
|
|
954
|
-
f"{num_sub_vectors} sub-vectors for {count} rows"
|
|
955
|
-
)
|
|
956
|
-
|
|
957
|
-
try:
|
|
958
|
-
# LanceDB 0.27+ API: parameters passed directly to create_index
|
|
959
|
-
index_kwargs: dict[str, Any] = {
|
|
960
|
-
"metric": "cosine",
|
|
961
|
-
"num_partitions": num_partitions,
|
|
962
|
-
"vector_column_name": "vector",
|
|
963
|
-
"index_type": index_type,
|
|
964
|
-
"replace": True,
|
|
965
|
-
"sample_rate": sample_rate,
|
|
966
|
-
}
|
|
967
|
-
|
|
968
|
-
# num_sub_vectors only applies to PQ-based indexes
|
|
969
|
-
if "PQ" in index_type:
|
|
970
|
-
index_kwargs["num_sub_vectors"] = num_sub_vectors
|
|
971
|
-
|
|
972
|
-
self.table.create_index(**index_kwargs)
|
|
973
|
-
|
|
974
|
-
# Wait for index to be ready with configurable timeout
|
|
975
|
-
self._wait_for_index_ready("vector", self.index_wait_timeout_seconds)
|
|
976
|
-
|
|
977
|
-
self._has_vector_index = True
|
|
978
|
-
logger.info(f"{index_type} vector index created successfully")
|
|
979
|
-
|
|
980
|
-
# Optimize after index creation (may fail in some environments)
|
|
981
|
-
try:
|
|
982
|
-
self.table.optimize()
|
|
983
|
-
except Exception as optimize_error:
|
|
984
|
-
logger.debug(f"Optimization after index creation skipped: {optimize_error}")
|
|
985
|
-
|
|
986
|
-
return True
|
|
987
|
-
|
|
988
|
-
except Exception as e:
|
|
989
|
-
logger.error(f"Failed to create {index_type} vector index: {e}")
|
|
990
|
-
raise StorageError(f"{index_type} vector index creation failed: {e}") from e
|
|
991
|
-
|
|
992
|
-
def _wait_for_index_ready(
|
|
993
|
-
self,
|
|
994
|
-
column_name: str,
|
|
995
|
-
timeout_seconds: float,
|
|
996
|
-
poll_interval: float = 0.5,
|
|
997
|
-
) -> None:
|
|
998
|
-
"""Wait for an index on the specified column to be ready.
|
|
999
|
-
|
|
1000
|
-
Args:
|
|
1001
|
-
column_name: Name of the column the index is on (e.g., "vector").
|
|
1002
|
-
LanceDB typically names indexes as "{column_name}_idx".
|
|
1003
|
-
timeout_seconds: Maximum time to wait.
|
|
1004
|
-
poll_interval: Time between status checks.
|
|
1005
|
-
"""
|
|
1006
|
-
if timeout_seconds <= 0:
|
|
1007
|
-
return
|
|
1008
|
-
|
|
1009
|
-
start_time = time.time()
|
|
1010
|
-
while time.time() - start_time < timeout_seconds:
|
|
1011
|
-
try:
|
|
1012
|
-
indices = self.table.list_indices()
|
|
1013
|
-
for idx in indices:
|
|
1014
|
-
idx_name = str(_get_index_attr(idx, "name", "")).lower()
|
|
1015
|
-
idx_columns = _get_index_attr(idx, "columns", [])
|
|
1016
|
-
|
|
1017
|
-
# Match by column name in index metadata, or index name contains column
|
|
1018
|
-
if column_name in idx_columns or column_name in idx_name:
|
|
1019
|
-
# Index exists, check if it's ready
|
|
1020
|
-
status = str(_get_index_attr(idx, "status", "ready"))
|
|
1021
|
-
if status.lower() in ("ready", "complete", "built"):
|
|
1022
|
-
logger.debug(f"Index on {column_name} is ready")
|
|
1023
|
-
return
|
|
1024
|
-
break
|
|
1025
|
-
except Exception as e:
|
|
1026
|
-
logger.debug(f"Error checking index status: {e}")
|
|
1027
|
-
|
|
1028
|
-
time.sleep(poll_interval)
|
|
1029
|
-
|
|
1030
|
-
logger.warning(
|
|
1031
|
-
f"Timeout waiting for index on {column_name} after {timeout_seconds}s"
|
|
1032
|
-
)
|
|
931
|
+
if self._index_manager is None:
|
|
932
|
+
raise StorageError("Database not connected")
|
|
933
|
+
result = self._index_manager.create_vector_index(force=force)
|
|
934
|
+
# Sync local state only when index was created or modified
|
|
935
|
+
if result:
|
|
936
|
+
self._has_vector_index = self._index_manager.has_vector_index
|
|
937
|
+
return result
|
|
1033
938
|
|
|
1034
939
|
def create_scalar_indexes(self) -> None:
|
|
1035
|
-
"""Create scalar indexes for frequently filtered columns.
|
|
1036
|
-
|
|
1037
|
-
Creates:
|
|
1038
|
-
- BTREE on id (fast lookups, upserts)
|
|
1039
|
-
- BTREE on timestamps and importance (range queries)
|
|
1040
|
-
- BITMAP on namespace and source (low cardinality)
|
|
1041
|
-
- LABEL_LIST on tags (array contains queries)
|
|
940
|
+
"""Create scalar indexes for frequently filtered columns. Delegates to IndexManager.
|
|
1042
941
|
|
|
1043
942
|
Raises:
|
|
1044
943
|
StorageError: If index creation fails critically.
|
|
1045
944
|
"""
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
"created_at",
|
|
1050
|
-
"updated_at",
|
|
1051
|
-
"last_accessed",
|
|
1052
|
-
"importance",
|
|
1053
|
-
"access_count",
|
|
1054
|
-
"expires_at", # TTL expiration queries
|
|
1055
|
-
]
|
|
1056
|
-
|
|
1057
|
-
for column in btree_columns:
|
|
1058
|
-
try:
|
|
1059
|
-
self.table.create_scalar_index(
|
|
1060
|
-
column,
|
|
1061
|
-
index_type="BTREE",
|
|
1062
|
-
replace=True,
|
|
1063
|
-
)
|
|
1064
|
-
logger.debug(f"Created BTREE index on {column}")
|
|
1065
|
-
except Exception as e:
|
|
1066
|
-
if "already exists" not in str(e).lower():
|
|
1067
|
-
logger.warning(f"Could not create BTREE index on {column}: {e}")
|
|
1068
|
-
|
|
1069
|
-
# BITMAP indexes for low-cardinality columns
|
|
1070
|
-
bitmap_columns = ["namespace", "source"]
|
|
1071
|
-
|
|
1072
|
-
for column in bitmap_columns:
|
|
1073
|
-
try:
|
|
1074
|
-
self.table.create_scalar_index(
|
|
1075
|
-
column,
|
|
1076
|
-
index_type="BITMAP",
|
|
1077
|
-
replace=True,
|
|
1078
|
-
)
|
|
1079
|
-
logger.debug(f"Created BITMAP index on {column}")
|
|
1080
|
-
except Exception as e:
|
|
1081
|
-
if "already exists" not in str(e).lower():
|
|
1082
|
-
logger.warning(f"Could not create BITMAP index on {column}: {e}")
|
|
1083
|
-
|
|
1084
|
-
# LABEL_LIST index for tags array (supports array_has_any queries)
|
|
1085
|
-
try:
|
|
1086
|
-
self.table.create_scalar_index(
|
|
1087
|
-
"tags",
|
|
1088
|
-
index_type="LABEL_LIST",
|
|
1089
|
-
replace=True,
|
|
1090
|
-
)
|
|
1091
|
-
logger.debug("Created LABEL_LIST index on tags")
|
|
1092
|
-
except Exception as e:
|
|
1093
|
-
if "already exists" not in str(e).lower():
|
|
1094
|
-
logger.warning(f"Could not create LABEL_LIST index on tags: {e}")
|
|
1095
|
-
|
|
1096
|
-
logger.info("Scalar indexes created")
|
|
945
|
+
if self._index_manager is None:
|
|
946
|
+
raise StorageError("Database not connected")
|
|
947
|
+
self._index_manager.create_scalar_indexes()
|
|
1097
948
|
|
|
1098
949
|
def ensure_indexes(self, force: bool = False) -> dict[str, bool]:
|
|
1099
|
-
"""Ensure all appropriate indexes exist.
|
|
950
|
+
"""Ensure all appropriate indexes exist. Delegates to IndexManager.
|
|
1100
951
|
|
|
1101
952
|
Args:
|
|
1102
953
|
force: Force index creation regardless of thresholds.
|
|
@@ -1104,35 +955,12 @@ class Database:
|
|
|
1104
955
|
Returns:
|
|
1105
956
|
Dict indicating which indexes were created.
|
|
1106
957
|
"""
|
|
1107
|
-
|
|
1108
|
-
"
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
count = self.table.count_rows()
|
|
1114
|
-
|
|
1115
|
-
# Vector index
|
|
1116
|
-
if self.auto_create_indexes or force:
|
|
1117
|
-
if count >= self.vector_index_threshold or force:
|
|
1118
|
-
results["vector_index"] = self.create_vector_index(force=force)
|
|
1119
|
-
|
|
1120
|
-
# Scalar indexes (always create if > 1000 rows)
|
|
1121
|
-
if count >= 1000 or force:
|
|
1122
|
-
try:
|
|
1123
|
-
self.create_scalar_indexes()
|
|
1124
|
-
results["scalar_indexes"] = True
|
|
1125
|
-
except Exception as e:
|
|
1126
|
-
logger.warning(f"Scalar index creation partially failed: {e}")
|
|
1127
|
-
|
|
1128
|
-
# FTS index
|
|
1129
|
-
if self.enable_fts and not self._has_fts_index:
|
|
1130
|
-
try:
|
|
1131
|
-
self._create_fts_index()
|
|
1132
|
-
results["fts_index"] = True
|
|
1133
|
-
except Exception as e:
|
|
1134
|
-
logger.warning(f"FTS index creation failed in ensure_indexes: {e}")
|
|
1135
|
-
|
|
958
|
+
if self._index_manager is None:
|
|
959
|
+
raise StorageError("Database not connected")
|
|
960
|
+
results = self._index_manager.ensure_indexes(force=force)
|
|
961
|
+
# Sync local state for backward compatibility
|
|
962
|
+
self._has_vector_index = self._index_manager.has_vector_index
|
|
963
|
+
self._has_fts_index = self._index_manager.has_fts_index
|
|
1136
964
|
return results
|
|
1137
965
|
|
|
1138
966
|
# ========================================================================
|
|
@@ -1301,6 +1129,13 @@ class Database:
|
|
|
1301
1129
|
if not 0.0 <= importance <= 1.0:
|
|
1302
1130
|
raise ValidationError("Importance must be between 0.0 and 1.0")
|
|
1303
1131
|
|
|
1132
|
+
# Validate vector dimensions
|
|
1133
|
+
if len(vector) != self.embedding_dim:
|
|
1134
|
+
raise DimensionMismatchError(
|
|
1135
|
+
expected_dim=self.embedding_dim,
|
|
1136
|
+
actual_dim=len(vector),
|
|
1137
|
+
)
|
|
1138
|
+
|
|
1304
1139
|
memory_id = str(uuid.uuid4())
|
|
1305
1140
|
now = utc_now()
|
|
1306
1141
|
|
|
@@ -1328,6 +1163,7 @@ class Database:
|
|
|
1328
1163
|
try:
|
|
1329
1164
|
self.table.add([record])
|
|
1330
1165
|
self._invalidate_count_cache()
|
|
1166
|
+
self._track_modification()
|
|
1331
1167
|
self._invalidate_namespace_cache()
|
|
1332
1168
|
logger.debug(f"Inserted memory {memory_id}")
|
|
1333
1169
|
return memory_id
|
|
@@ -1344,23 +1180,26 @@ class Database:
|
|
|
1344
1180
|
self,
|
|
1345
1181
|
records: list[dict[str, Any]],
|
|
1346
1182
|
batch_size: int = 1000,
|
|
1183
|
+
atomic: bool = False,
|
|
1347
1184
|
) -> list[str]:
|
|
1348
1185
|
"""Insert multiple memories efficiently with batching.
|
|
1349
1186
|
|
|
1350
|
-
Note: Batch insert is NOT atomic. Partial failures may leave some
|
|
1351
|
-
records inserted. If atomicity is required, use individual inserts
|
|
1352
|
-
with transaction management at the application layer.
|
|
1353
|
-
|
|
1354
1187
|
Args:
|
|
1355
1188
|
records: List of memory records with content, vector, and optional fields.
|
|
1356
1189
|
batch_size: Records per batch (default: 1000, max: 10000).
|
|
1190
|
+
atomic: If True, rollback all inserts on partial failure.
|
|
1191
|
+
When atomic=True and a batch fails:
|
|
1192
|
+
- Attempts to delete already-inserted records
|
|
1193
|
+
- If rollback succeeds, raises the original StorageError
|
|
1194
|
+
- If rollback fails, raises PartialBatchInsertError with succeeded_ids
|
|
1357
1195
|
|
|
1358
1196
|
Returns:
|
|
1359
1197
|
List of generated memory IDs.
|
|
1360
1198
|
|
|
1361
1199
|
Raises:
|
|
1362
1200
|
ValidationError: If input validation fails or batch_size exceeds maximum.
|
|
1363
|
-
StorageError: If database operation fails.
|
|
1201
|
+
StorageError: If database operation fails (and rollback succeeds when atomic=True).
|
|
1202
|
+
PartialBatchInsertError: If atomic=True and rollback fails after partial insert.
|
|
1364
1203
|
"""
|
|
1365
1204
|
if batch_size > self.MAX_BATCH_SIZE:
|
|
1366
1205
|
raise ValidationError(
|
|
@@ -1368,9 +1207,10 @@ class Database:
|
|
|
1368
1207
|
)
|
|
1369
1208
|
|
|
1370
1209
|
all_ids: list[str] = []
|
|
1210
|
+
total_requested = len(records)
|
|
1371
1211
|
|
|
1372
1212
|
# Process in batches for large inserts
|
|
1373
|
-
for i in range(0, len(records), batch_size):
|
|
1213
|
+
for batch_index, i in enumerate(range(0, len(records), batch_size)):
|
|
1374
1214
|
batch = records[i:i + batch_size]
|
|
1375
1215
|
now = utc_now()
|
|
1376
1216
|
memory_ids: list[str] = []
|
|
@@ -1398,6 +1238,14 @@ class Database:
|
|
|
1398
1238
|
else:
|
|
1399
1239
|
vector_list = raw_vector
|
|
1400
1240
|
|
|
1241
|
+
# Validate vector dimensions
|
|
1242
|
+
if len(vector_list) != self.embedding_dim:
|
|
1243
|
+
raise DimensionMismatchError(
|
|
1244
|
+
expected_dim=self.embedding_dim,
|
|
1245
|
+
actual_dim=len(vector_list),
|
|
1246
|
+
record_index=i + len(memory_ids),
|
|
1247
|
+
)
|
|
1248
|
+
|
|
1401
1249
|
# Calculate expires_at if default TTL is configured
|
|
1402
1250
|
expires_at = None
|
|
1403
1251
|
if self.default_memory_ttl_days is not None:
|
|
@@ -1424,9 +1272,29 @@ class Database:
|
|
|
1424
1272
|
self.table.add(prepared_records)
|
|
1425
1273
|
all_ids.extend(memory_ids)
|
|
1426
1274
|
self._invalidate_count_cache()
|
|
1275
|
+
self._track_modification(len(memory_ids))
|
|
1427
1276
|
self._invalidate_namespace_cache()
|
|
1428
|
-
logger.debug(f"Inserted batch {
|
|
1277
|
+
logger.debug(f"Inserted batch {batch_index + 1}: {len(memory_ids)} memories")
|
|
1429
1278
|
except Exception as e:
|
|
1279
|
+
if atomic and all_ids:
|
|
1280
|
+
# Attempt rollback of previously inserted records
|
|
1281
|
+
logger.warning(
|
|
1282
|
+
f"Batch {batch_index + 1} failed, attempting rollback of "
|
|
1283
|
+
f"{len(all_ids)} previously inserted records"
|
|
1284
|
+
)
|
|
1285
|
+
rollback_error = self._rollback_batch_insert(all_ids)
|
|
1286
|
+
if rollback_error:
|
|
1287
|
+
# Rollback failed - raise PartialBatchInsertError
|
|
1288
|
+
raise PartialBatchInsertError(
|
|
1289
|
+
message=f"Batch insert failed and rollback also failed: {e}",
|
|
1290
|
+
succeeded_ids=all_ids,
|
|
1291
|
+
total_requested=total_requested,
|
|
1292
|
+
failed_batch_index=batch_index,
|
|
1293
|
+
) from e
|
|
1294
|
+
else:
|
|
1295
|
+
# Rollback succeeded - raise original error
|
|
1296
|
+
logger.info(f"Rollback successful, deleted {len(all_ids)} records")
|
|
1297
|
+
raise StorageError(f"Failed to insert batch (rolled back): {e}") from e
|
|
1430
1298
|
raise StorageError(f"Failed to insert batch: {e}") from e
|
|
1431
1299
|
|
|
1432
1300
|
# Check if we should create indexes after large insert
|
|
@@ -1442,6 +1310,31 @@ class Database:
|
|
|
1442
1310
|
logger.debug(f"Inserted {len(all_ids)} memories total")
|
|
1443
1311
|
return all_ids
|
|
1444
1312
|
|
|
1313
|
+
def _rollback_batch_insert(self, memory_ids: list[str]) -> Exception | None:
|
|
1314
|
+
"""Attempt to delete records inserted during a failed batch operation.
|
|
1315
|
+
|
|
1316
|
+
Args:
|
|
1317
|
+
memory_ids: List of memory IDs to delete.
|
|
1318
|
+
|
|
1319
|
+
Returns:
|
|
1320
|
+
None if rollback succeeded, Exception if it failed.
|
|
1321
|
+
"""
|
|
1322
|
+
try:
|
|
1323
|
+
if not memory_ids:
|
|
1324
|
+
return None
|
|
1325
|
+
|
|
1326
|
+
# Use delete_batch for efficient rollback
|
|
1327
|
+
id_list = ", ".join(f"'{_sanitize_string(mid)}'" for mid in memory_ids)
|
|
1328
|
+
self.table.delete(f"id IN ({id_list})")
|
|
1329
|
+
self._invalidate_count_cache()
|
|
1330
|
+
self._track_modification(len(memory_ids))
|
|
1331
|
+
self._invalidate_namespace_cache()
|
|
1332
|
+
logger.debug(f"Rolled back {len(memory_ids)} records")
|
|
1333
|
+
return None
|
|
1334
|
+
except Exception as e:
|
|
1335
|
+
logger.error(f"Rollback failed: {e}")
|
|
1336
|
+
return e
|
|
1337
|
+
|
|
1445
1338
|
@with_stale_connection_recovery
|
|
1446
1339
|
def get(self, memory_id: str) -> dict[str, Any]:
|
|
1447
1340
|
"""Get a memory by ID.
|
|
@@ -1476,6 +1369,51 @@ class Database:
|
|
|
1476
1369
|
except Exception as e:
|
|
1477
1370
|
raise StorageError(f"Failed to get memory: {e}") from e
|
|
1478
1371
|
|
|
1372
|
+
def get_batch(self, memory_ids: list[str]) -> dict[str, dict[str, Any]]:
|
|
1373
|
+
"""Get multiple memories by ID in a single query.
|
|
1374
|
+
|
|
1375
|
+
Args:
|
|
1376
|
+
memory_ids: List of memory UUIDs to retrieve.
|
|
1377
|
+
|
|
1378
|
+
Returns:
|
|
1379
|
+
Dict mapping memory_id to memory record. Missing IDs are not included.
|
|
1380
|
+
|
|
1381
|
+
Raises:
|
|
1382
|
+
ValidationError: If any memory_id format is invalid.
|
|
1383
|
+
StorageError: If database operation fails.
|
|
1384
|
+
"""
|
|
1385
|
+
if not memory_ids:
|
|
1386
|
+
return {}
|
|
1387
|
+
|
|
1388
|
+
# Validate all IDs first
|
|
1389
|
+
validated_ids: list[str] = []
|
|
1390
|
+
for memory_id in memory_ids:
|
|
1391
|
+
try:
|
|
1392
|
+
validated_id = _validate_uuid(memory_id)
|
|
1393
|
+
validated_ids.append(_sanitize_string(validated_id))
|
|
1394
|
+
except Exception as e:
|
|
1395
|
+
logger.debug(f"Invalid memory ID {memory_id}: {e}")
|
|
1396
|
+
continue
|
|
1397
|
+
|
|
1398
|
+
if not validated_ids:
|
|
1399
|
+
return {}
|
|
1400
|
+
|
|
1401
|
+
try:
|
|
1402
|
+
# Batch fetch with single IN query
|
|
1403
|
+
id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
|
|
1404
|
+
results = self.table.search().where(f"id IN ({id_list})").to_list()
|
|
1405
|
+
|
|
1406
|
+
# Build result map
|
|
1407
|
+
result_map: dict[str, dict[str, Any]] = {}
|
|
1408
|
+
for record in results:
|
|
1409
|
+
# Deserialize metadata
|
|
1410
|
+
record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
|
|
1411
|
+
result_map[record["id"]] = record
|
|
1412
|
+
|
|
1413
|
+
return result_map
|
|
1414
|
+
except Exception as e:
|
|
1415
|
+
raise StorageError(f"Failed to batch get memories: {e}") from e
|
|
1416
|
+
|
|
1479
1417
|
@with_process_lock
|
|
1480
1418
|
@with_write_lock
|
|
1481
1419
|
def update(self, memory_id: str, updates: dict[str, Any]) -> None:
|
|
@@ -1533,42 +1471,145 @@ class Database:
|
|
|
1533
1471
|
|
|
1534
1472
|
@with_process_lock
|
|
1535
1473
|
@with_write_lock
|
|
1536
|
-
def
|
|
1537
|
-
|
|
1474
|
+
def update_batch(
|
|
1475
|
+
self, updates: list[tuple[str, dict[str, Any]]]
|
|
1476
|
+
) -> tuple[int, list[str]]:
|
|
1477
|
+
"""Update multiple memories using atomic merge_insert.
|
|
1538
1478
|
|
|
1539
1479
|
Args:
|
|
1540
|
-
|
|
1480
|
+
updates: List of (memory_id, updates_dict) tuples.
|
|
1481
|
+
|
|
1482
|
+
Returns:
|
|
1483
|
+
Tuple of (success_count, list of failed memory_ids).
|
|
1541
1484
|
|
|
1542
1485
|
Raises:
|
|
1543
|
-
|
|
1544
|
-
MemoryNotFoundError: If memory doesn't exist.
|
|
1545
|
-
StorageError: If database operation fails.
|
|
1486
|
+
StorageError: If database operation fails completely.
|
|
1546
1487
|
"""
|
|
1547
|
-
|
|
1548
|
-
|
|
1549
|
-
safe_id = _sanitize_string(memory_id)
|
|
1488
|
+
if not updates:
|
|
1489
|
+
return 0, []
|
|
1550
1490
|
|
|
1551
|
-
|
|
1552
|
-
|
|
1491
|
+
now = utc_now()
|
|
1492
|
+
records_to_update: list[dict[str, Any]] = []
|
|
1493
|
+
failed_ids: list[str] = []
|
|
1553
1494
|
|
|
1495
|
+
# Validate all IDs and collect them
|
|
1496
|
+
validated_updates: list[tuple[str, dict[str, Any]]] = []
|
|
1497
|
+
for memory_id, update_dict in updates:
|
|
1498
|
+
try:
|
|
1499
|
+
validated_id = _validate_uuid(memory_id)
|
|
1500
|
+
validated_updates.append((_sanitize_string(validated_id), update_dict))
|
|
1501
|
+
except Exception as e:
|
|
1502
|
+
logger.debug(f"Invalid memory ID {memory_id}: {e}")
|
|
1503
|
+
failed_ids.append(memory_id)
|
|
1504
|
+
|
|
1505
|
+
if not validated_updates:
|
|
1506
|
+
return 0, failed_ids
|
|
1507
|
+
|
|
1508
|
+
# Batch fetch all records
|
|
1509
|
+
validated_ids = [vid for vid, _ in validated_updates]
|
|
1554
1510
|
try:
|
|
1555
|
-
|
|
1556
|
-
self.
|
|
1557
|
-
self._invalidate_namespace_cache()
|
|
1558
|
-
logger.debug(f"Deleted memory {memory_id}")
|
|
1511
|
+
id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
|
|
1512
|
+
all_records = self.table.search().where(f"id IN ({id_list})").to_list()
|
|
1559
1513
|
except Exception as e:
|
|
1560
|
-
|
|
1514
|
+
logger.error(f"Failed to batch fetch records for update: {e}")
|
|
1515
|
+
raise StorageError(f"Failed to batch fetch for update: {e}") from e
|
|
1561
1516
|
|
|
1562
|
-
|
|
1563
|
-
|
|
1564
|
-
|
|
1565
|
-
|
|
1517
|
+
# Build lookup map
|
|
1518
|
+
record_map: dict[str, dict[str, Any]] = {}
|
|
1519
|
+
for record in all_records:
|
|
1520
|
+
record_map[record["id"]] = record
|
|
1521
|
+
|
|
1522
|
+
# Apply updates to found records
|
|
1523
|
+
update_dict_map = dict(validated_updates)
|
|
1524
|
+
for memory_id in validated_ids:
|
|
1525
|
+
if memory_id not in record_map:
|
|
1526
|
+
logger.debug(f"Memory {memory_id} not found for batch update")
|
|
1527
|
+
failed_ids.append(memory_id)
|
|
1528
|
+
continue
|
|
1566
1529
|
|
|
1567
|
-
|
|
1568
|
-
|
|
1530
|
+
record = record_map[memory_id]
|
|
1531
|
+
update_dict = update_dict_map[memory_id]
|
|
1569
1532
|
|
|
1570
|
-
|
|
1571
|
-
|
|
1533
|
+
# Apply updates
|
|
1534
|
+
record["updated_at"] = now
|
|
1535
|
+
for key, value in update_dict.items():
|
|
1536
|
+
if key == "metadata" and isinstance(value, dict):
|
|
1537
|
+
record[key] = json.dumps(value)
|
|
1538
|
+
elif key == "vector" and isinstance(value, np.ndarray):
|
|
1539
|
+
record[key] = value.tolist()
|
|
1540
|
+
else:
|
|
1541
|
+
record[key] = value
|
|
1542
|
+
|
|
1543
|
+
# Ensure metadata is serialized
|
|
1544
|
+
if isinstance(record.get("metadata"), dict):
|
|
1545
|
+
record["metadata"] = json.dumps(record["metadata"])
|
|
1546
|
+
|
|
1547
|
+
# Ensure vector is a list
|
|
1548
|
+
if isinstance(record.get("vector"), np.ndarray):
|
|
1549
|
+
record["vector"] = record["vector"].tolist()
|
|
1550
|
+
|
|
1551
|
+
records_to_update.append(record)
|
|
1552
|
+
|
|
1553
|
+
if not records_to_update:
|
|
1554
|
+
return 0, failed_ids
|
|
1555
|
+
|
|
1556
|
+
try:
|
|
1557
|
+
# Atomic batch upsert
|
|
1558
|
+
(
|
|
1559
|
+
self.table.merge_insert("id")
|
|
1560
|
+
.when_matched_update_all()
|
|
1561
|
+
.when_not_matched_insert_all()
|
|
1562
|
+
.execute(records_to_update)
|
|
1563
|
+
)
|
|
1564
|
+
success_count = len(records_to_update)
|
|
1565
|
+
logger.debug(
|
|
1566
|
+
f"Batch updated {success_count}/{len(updates)} memories "
|
|
1567
|
+
"(atomic merge_insert)"
|
|
1568
|
+
)
|
|
1569
|
+
return success_count, failed_ids
|
|
1570
|
+
except Exception as e:
|
|
1571
|
+
logger.error(f"Failed to batch update: {e}")
|
|
1572
|
+
raise StorageError(f"Failed to batch update: {e}") from e
|
|
1573
|
+
|
|
1574
|
+
@with_process_lock
|
|
1575
|
+
@with_write_lock
|
|
1576
|
+
def delete(self, memory_id: str) -> None:
|
|
1577
|
+
"""Delete a memory.
|
|
1578
|
+
|
|
1579
|
+
Args:
|
|
1580
|
+
memory_id: The memory ID.
|
|
1581
|
+
|
|
1582
|
+
Raises:
|
|
1583
|
+
ValidationError: If memory_id is invalid.
|
|
1584
|
+
MemoryNotFoundError: If memory doesn't exist.
|
|
1585
|
+
StorageError: If database operation fails.
|
|
1586
|
+
"""
|
|
1587
|
+
# Validate memory_id
|
|
1588
|
+
memory_id = _validate_uuid(memory_id)
|
|
1589
|
+
safe_id = _sanitize_string(memory_id)
|
|
1590
|
+
|
|
1591
|
+
# First verify the memory exists
|
|
1592
|
+
self.get(memory_id)
|
|
1593
|
+
|
|
1594
|
+
try:
|
|
1595
|
+
self.table.delete(f"id = '{safe_id}'")
|
|
1596
|
+
self._invalidate_count_cache()
|
|
1597
|
+
self._track_modification()
|
|
1598
|
+
self._invalidate_namespace_cache()
|
|
1599
|
+
logger.debug(f"Deleted memory {memory_id}")
|
|
1600
|
+
except Exception as e:
|
|
1601
|
+
raise StorageError(f"Failed to delete memory: {e}") from e
|
|
1602
|
+
|
|
1603
|
+
@with_process_lock
|
|
1604
|
+
@with_write_lock
|
|
1605
|
+
def delete_by_namespace(self, namespace: str) -> int:
|
|
1606
|
+
"""Delete all memories in a namespace.
|
|
1607
|
+
|
|
1608
|
+
Args:
|
|
1609
|
+
namespace: The namespace to delete.
|
|
1610
|
+
|
|
1611
|
+
Returns:
|
|
1612
|
+
Number of deleted records.
|
|
1572
1613
|
|
|
1573
1614
|
Raises:
|
|
1574
1615
|
ValidationError: If namespace is invalid.
|
|
@@ -1581,6 +1622,7 @@ class Database:
|
|
|
1581
1622
|
count_before: int = self.table.count_rows()
|
|
1582
1623
|
self.table.delete(f"namespace = '{safe_ns}'")
|
|
1583
1624
|
self._invalidate_count_cache()
|
|
1625
|
+
self._track_modification()
|
|
1584
1626
|
self._invalidate_namespace_cache()
|
|
1585
1627
|
count_after: int = self.table.count_rows()
|
|
1586
1628
|
deleted = count_before - count_after
|
|
@@ -1624,6 +1666,7 @@ class Database:
|
|
|
1624
1666
|
self.table.delete("id IS NOT NULL")
|
|
1625
1667
|
|
|
1626
1668
|
self._invalidate_count_cache()
|
|
1669
|
+
self._track_modification()
|
|
1627
1670
|
self._invalidate_namespace_cache()
|
|
1628
1671
|
|
|
1629
1672
|
# Reset index tracking flags for test isolation
|
|
@@ -1643,6 +1686,7 @@ class Database:
|
|
|
1643
1686
|
"""Rename all memories from one namespace to another.
|
|
1644
1687
|
|
|
1645
1688
|
Uses atomic batch update via merge_insert for data integrity.
|
|
1689
|
+
On partial failure, attempts to rollback renamed records to original namespace.
|
|
1646
1690
|
|
|
1647
1691
|
Args:
|
|
1648
1692
|
old_namespace: Source namespace name.
|
|
@@ -1661,6 +1705,7 @@ class Database:
|
|
|
1661
1705
|
old_namespace = _validate_namespace(old_namespace)
|
|
1662
1706
|
new_namespace = _validate_namespace(new_namespace)
|
|
1663
1707
|
safe_old = _sanitize_string(old_namespace)
|
|
1708
|
+
safe_new = _sanitize_string(new_namespace)
|
|
1664
1709
|
|
|
1665
1710
|
try:
|
|
1666
1711
|
# Check if source namespace exists
|
|
@@ -1674,6 +1719,9 @@ class Database:
|
|
|
1674
1719
|
logger.debug(f"Namespace '{old_namespace}' renamed to itself ({count} records)")
|
|
1675
1720
|
return count
|
|
1676
1721
|
|
|
1722
|
+
# Track renamed IDs for rollback capability
|
|
1723
|
+
renamed_ids: list[str] = []
|
|
1724
|
+
|
|
1677
1725
|
# Fetch all records in batches with iteration safeguards
|
|
1678
1726
|
batch_size = 1000
|
|
1679
1727
|
max_iterations = 10000 # Safety cap: 10M records at 1000/batch
|
|
@@ -1702,6 +1750,9 @@ class Database:
|
|
|
1702
1750
|
if not records:
|
|
1703
1751
|
break
|
|
1704
1752
|
|
|
1753
|
+
# Track IDs in this batch for potential rollback
|
|
1754
|
+
batch_ids = [r["id"] for r in records]
|
|
1755
|
+
|
|
1705
1756
|
# Update namespace field
|
|
1706
1757
|
for r in records:
|
|
1707
1758
|
r["namespace"] = new_namespace
|
|
@@ -1711,13 +1762,41 @@ class Database:
|
|
|
1711
1762
|
if isinstance(r.get("vector"), np.ndarray):
|
|
1712
1763
|
r["vector"] = r["vector"].tolist()
|
|
1713
1764
|
|
|
1714
|
-
|
|
1715
|
-
|
|
1716
|
-
|
|
1717
|
-
|
|
1718
|
-
|
|
1719
|
-
|
|
1720
|
-
|
|
1765
|
+
try:
|
|
1766
|
+
# Atomic upsert
|
|
1767
|
+
(
|
|
1768
|
+
self.table.merge_insert("id")
|
|
1769
|
+
.when_matched_update_all()
|
|
1770
|
+
.when_not_matched_insert_all()
|
|
1771
|
+
.execute(records)
|
|
1772
|
+
)
|
|
1773
|
+
# Only track as renamed after successful update
|
|
1774
|
+
renamed_ids.extend(batch_ids)
|
|
1775
|
+
except Exception as batch_error:
|
|
1776
|
+
# Batch failed - attempt rollback of previously renamed records
|
|
1777
|
+
if renamed_ids:
|
|
1778
|
+
logger.warning(
|
|
1779
|
+
f"Batch {iteration} failed, attempting rollback of "
|
|
1780
|
+
f"{len(renamed_ids)} previously renamed records"
|
|
1781
|
+
)
|
|
1782
|
+
rollback_error = self._rollback_namespace_rename(
|
|
1783
|
+
renamed_ids, old_namespace
|
|
1784
|
+
)
|
|
1785
|
+
if rollback_error:
|
|
1786
|
+
raise StorageError(
|
|
1787
|
+
f"Namespace rename failed at batch {iteration} and "
|
|
1788
|
+
f"rollback also failed. {len(renamed_ids)} records may be "
|
|
1789
|
+
f"in inconsistent state (partially in '{new_namespace}'). "
|
|
1790
|
+
f"Original error: {batch_error}. Rollback error: {rollback_error}"
|
|
1791
|
+
) from batch_error
|
|
1792
|
+
else:
|
|
1793
|
+
logger.info(
|
|
1794
|
+
f"Rollback successful, reverted {len(renamed_ids)} records "
|
|
1795
|
+
f"back to namespace '{old_namespace}'"
|
|
1796
|
+
)
|
|
1797
|
+
raise StorageError(
|
|
1798
|
+
f"Failed to rename namespace (rolled back): {batch_error}"
|
|
1799
|
+
) from batch_error
|
|
1721
1800
|
|
|
1722
1801
|
updated += len(records)
|
|
1723
1802
|
|
|
@@ -1740,6 +1819,66 @@ class Database:
|
|
|
1740
1819
|
except Exception as e:
|
|
1741
1820
|
raise StorageError(f"Failed to rename namespace: {e}") from e
|
|
1742
1821
|
|
|
1822
|
+
def _rollback_namespace_rename(
|
|
1823
|
+
self, memory_ids: list[str], target_namespace: str
|
|
1824
|
+
) -> Exception | None:
|
|
1825
|
+
"""Attempt to revert renamed records back to original namespace.
|
|
1826
|
+
|
|
1827
|
+
Args:
|
|
1828
|
+
memory_ids: List of memory IDs to revert.
|
|
1829
|
+
target_namespace: Namespace to revert records to.
|
|
1830
|
+
|
|
1831
|
+
Returns:
|
|
1832
|
+
None if rollback succeeded, Exception if it failed.
|
|
1833
|
+
"""
|
|
1834
|
+
try:
|
|
1835
|
+
if not memory_ids:
|
|
1836
|
+
return None
|
|
1837
|
+
|
|
1838
|
+
safe_namespace = _sanitize_string(target_namespace)
|
|
1839
|
+
now = utc_now()
|
|
1840
|
+
|
|
1841
|
+
# Process in batches for large rollbacks
|
|
1842
|
+
batch_size = 1000
|
|
1843
|
+
for i in range(0, len(memory_ids), batch_size):
|
|
1844
|
+
batch_ids = memory_ids[i:i + batch_size]
|
|
1845
|
+
id_list = ", ".join(f"'{_sanitize_string(mid)}'" for mid in batch_ids)
|
|
1846
|
+
|
|
1847
|
+
# Fetch records that need rollback
|
|
1848
|
+
records = (
|
|
1849
|
+
self.table.search()
|
|
1850
|
+
.where(f"id IN ({id_list})")
|
|
1851
|
+
.to_list()
|
|
1852
|
+
)
|
|
1853
|
+
|
|
1854
|
+
if not records:
|
|
1855
|
+
continue
|
|
1856
|
+
|
|
1857
|
+
# Revert namespace
|
|
1858
|
+
for r in records:
|
|
1859
|
+
r["namespace"] = target_namespace
|
|
1860
|
+
r["updated_at"] = now
|
|
1861
|
+
if isinstance(r.get("metadata"), dict):
|
|
1862
|
+
r["metadata"] = json.dumps(r["metadata"])
|
|
1863
|
+
if isinstance(r.get("vector"), np.ndarray):
|
|
1864
|
+
r["vector"] = r["vector"].tolist()
|
|
1865
|
+
|
|
1866
|
+
# Atomic upsert to restore original namespace
|
|
1867
|
+
(
|
|
1868
|
+
self.table.merge_insert("id")
|
|
1869
|
+
.when_matched_update_all()
|
|
1870
|
+
.when_not_matched_insert_all()
|
|
1871
|
+
.execute(records)
|
|
1872
|
+
)
|
|
1873
|
+
|
|
1874
|
+
self._invalidate_namespace_cache()
|
|
1875
|
+
logger.debug(f"Rolled back {len(memory_ids)} records to namespace '{target_namespace}'")
|
|
1876
|
+
return None
|
|
1877
|
+
|
|
1878
|
+
except Exception as e:
|
|
1879
|
+
logger.error(f"Namespace rename rollback failed: {e}")
|
|
1880
|
+
return e
|
|
1881
|
+
|
|
1743
1882
|
@with_stale_connection_recovery
|
|
1744
1883
|
def get_stats(self, namespace: str | None = None) -> dict[str, Any]:
|
|
1745
1884
|
"""Get comprehensive database statistics.
|
|
@@ -1836,15 +1975,18 @@ class Database:
|
|
|
1836
1975
|
safe_ns = _sanitize_string(namespace)
|
|
1837
1976
|
|
|
1838
1977
|
try:
|
|
1839
|
-
# Get
|
|
1840
|
-
|
|
1978
|
+
# Get count efficiently
|
|
1979
|
+
filter_expr = f"namespace = '{safe_ns}'"
|
|
1980
|
+
count_results = (
|
|
1841
1981
|
self.table.search()
|
|
1842
|
-
.where(
|
|
1843
|
-
.select(["
|
|
1982
|
+
.where(filter_expr)
|
|
1983
|
+
.select(["id"])
|
|
1984
|
+
.limit(1000000) # High limit to count all
|
|
1844
1985
|
.to_list()
|
|
1845
1986
|
)
|
|
1987
|
+
memory_count = len(count_results)
|
|
1846
1988
|
|
|
1847
|
-
if
|
|
1989
|
+
if memory_count == 0:
|
|
1848
1990
|
return {
|
|
1849
1991
|
"namespace": namespace,
|
|
1850
1992
|
"memory_count": 0,
|
|
@@ -1853,18 +1995,42 @@ class Database:
|
|
|
1853
1995
|
"avg_content_length": None,
|
|
1854
1996
|
}
|
|
1855
1997
|
|
|
1856
|
-
#
|
|
1857
|
-
|
|
1858
|
-
|
|
1859
|
-
|
|
1998
|
+
# Get oldest memory (sort ascending, limit 1)
|
|
1999
|
+
oldest_records = (
|
|
2000
|
+
self.table.search()
|
|
2001
|
+
.where(filter_expr)
|
|
2002
|
+
.select(["created_at"])
|
|
2003
|
+
.limit(1)
|
|
2004
|
+
.to_list()
|
|
2005
|
+
)
|
|
2006
|
+
oldest = oldest_records[0]["created_at"] if oldest_records else None
|
|
2007
|
+
|
|
2008
|
+
# Get newest memory - need to fetch more and find max since LanceDB
|
|
2009
|
+
# doesn't support ORDER BY DESC efficiently
|
|
2010
|
+
# Sample up to 1000 records for stats to avoid loading everything
|
|
2011
|
+
sample_size = min(memory_count, 1000)
|
|
2012
|
+
sample_records = (
|
|
2013
|
+
self.table.search()
|
|
2014
|
+
.where(filter_expr)
|
|
2015
|
+
.select(["created_at", "content"])
|
|
2016
|
+
.limit(sample_size)
|
|
2017
|
+
.to_list()
|
|
2018
|
+
)
|
|
1860
2019
|
|
|
1861
|
-
#
|
|
1862
|
-
|
|
1863
|
-
|
|
2020
|
+
# Find newest from sample (for large namespaces this is approximate)
|
|
2021
|
+
if sample_records:
|
|
2022
|
+
created_times = [r["created_at"] for r in sample_records]
|
|
2023
|
+
newest = max(created_times)
|
|
2024
|
+
# Calculate average content length from sample
|
|
2025
|
+
content_lengths = [len(r.get("content", "")) for r in sample_records]
|
|
2026
|
+
avg_content_length = sum(content_lengths) / len(content_lengths)
|
|
2027
|
+
else:
|
|
2028
|
+
newest = oldest
|
|
2029
|
+
avg_content_length = None
|
|
1864
2030
|
|
|
1865
2031
|
return {
|
|
1866
2032
|
"namespace": namespace,
|
|
1867
|
-
"memory_count":
|
|
2033
|
+
"memory_count": memory_count,
|
|
1868
2034
|
"oldest_memory": oldest,
|
|
1869
2035
|
"newest_memory": newest,
|
|
1870
2036
|
"avg_content_length": avg_content_length,
|
|
@@ -2024,21 +2190,23 @@ class Database:
|
|
|
2024
2190
|
|
|
2025
2191
|
@with_process_lock
|
|
2026
2192
|
@with_write_lock
|
|
2027
|
-
def delete_batch(self, memory_ids: list[str]) -> int:
|
|
2193
|
+
def delete_batch(self, memory_ids: list[str]) -> tuple[int, list[str]]:
|
|
2028
2194
|
"""Delete multiple memories atomically using IN clause.
|
|
2029
2195
|
|
|
2030
2196
|
Args:
|
|
2031
2197
|
memory_ids: List of memory UUIDs to delete.
|
|
2032
2198
|
|
|
2033
2199
|
Returns:
|
|
2034
|
-
|
|
2200
|
+
Tuple of (count_deleted, list_of_deleted_ids) where:
|
|
2201
|
+
- count_deleted: Number of memories actually deleted
|
|
2202
|
+
- list_of_deleted_ids: IDs that were actually deleted
|
|
2035
2203
|
|
|
2036
2204
|
Raises:
|
|
2037
2205
|
ValidationError: If any memory_id is invalid.
|
|
2038
2206
|
StorageError: If database operation fails.
|
|
2039
2207
|
"""
|
|
2040
2208
|
if not memory_ids:
|
|
2041
|
-
return 0
|
|
2209
|
+
return (0, [])
|
|
2042
2210
|
|
|
2043
2211
|
# Validate all IDs first (fail fast)
|
|
2044
2212
|
validated_ids: list[str] = []
|
|
@@ -2047,21 +2215,32 @@ class Database:
|
|
|
2047
2215
|
validated_ids.append(_sanitize_string(validated_id))
|
|
2048
2216
|
|
|
2049
2217
|
try:
|
|
2050
|
-
|
|
2051
|
-
|
|
2052
|
-
# Build IN clause for atomic batch delete
|
|
2218
|
+
# First, check which IDs actually exist
|
|
2053
2219
|
id_list = ", ".join(f"'{mid}'" for mid in validated_ids)
|
|
2054
2220
|
filter_expr = f"id IN ({id_list})"
|
|
2055
|
-
|
|
2221
|
+
existing_records = (
|
|
2222
|
+
self.table.search()
|
|
2223
|
+
.where(filter_expr)
|
|
2224
|
+
.select(["id"])
|
|
2225
|
+
.limit(len(validated_ids))
|
|
2226
|
+
.to_list()
|
|
2227
|
+
)
|
|
2228
|
+
existing_ids = [r["id"] for r in existing_records]
|
|
2229
|
+
|
|
2230
|
+
if not existing_ids:
|
|
2231
|
+
return (0, [])
|
|
2232
|
+
|
|
2233
|
+
# Delete only existing IDs
|
|
2234
|
+
existing_id_list = ", ".join(f"'{mid}'" for mid in existing_ids)
|
|
2235
|
+
delete_expr = f"id IN ({existing_id_list})"
|
|
2236
|
+
self.table.delete(delete_expr)
|
|
2056
2237
|
|
|
2057
2238
|
self._invalidate_count_cache()
|
|
2239
|
+
self._track_modification()
|
|
2058
2240
|
self._invalidate_namespace_cache()
|
|
2059
2241
|
|
|
2060
|
-
|
|
2061
|
-
|
|
2062
|
-
|
|
2063
|
-
logger.debug(f"Batch deleted {deleted} memories")
|
|
2064
|
-
return deleted
|
|
2242
|
+
logger.debug(f"Batch deleted {len(existing_ids)} memories")
|
|
2243
|
+
return (len(existing_ids), existing_ids)
|
|
2065
2244
|
except ValidationError:
|
|
2066
2245
|
raise
|
|
2067
2246
|
except Exception as e:
|
|
@@ -2159,6 +2338,10 @@ class Database:
|
|
|
2159
2338
|
backoff=self.retry_backoff_seconds,
|
|
2160
2339
|
)
|
|
2161
2340
|
|
|
2341
|
+
# ========================================================================
|
|
2342
|
+
# Search Operations (delegates to SearchManager)
|
|
2343
|
+
# ========================================================================
|
|
2344
|
+
|
|
2162
2345
|
def _calculate_search_params(
|
|
2163
2346
|
self,
|
|
2164
2347
|
count: int,
|
|
@@ -2166,59 +2349,12 @@ class Database:
|
|
|
2166
2349
|
nprobes_override: int | None = None,
|
|
2167
2350
|
refine_factor_override: int | None = None,
|
|
2168
2351
|
) -> tuple[int, int]:
|
|
2169
|
-
"""Calculate optimal search parameters
|
|
2170
|
-
|
|
2171
|
-
|
|
2172
|
-
|
|
2173
|
-
|
|
2174
|
-
|
|
2175
|
-
limit: Number of results requested.
|
|
2176
|
-
nprobes_override: Optional override for nprobes (uses this if provided).
|
|
2177
|
-
refine_factor_override: Optional override for refine_factor.
|
|
2178
|
-
|
|
2179
|
-
Returns:
|
|
2180
|
-
Tuple of (nprobes, refine_factor).
|
|
2181
|
-
|
|
2182
|
-
Scaling rules:
|
|
2183
|
-
- nprobes: Base from config, scaled up for larger datasets
|
|
2184
|
-
- <100K: config value (default 20)
|
|
2185
|
-
- 100K-1M: max(config, 30)
|
|
2186
|
-
- 1M-10M: max(config, 50)
|
|
2187
|
-
- >10M: max(config, 100)
|
|
2188
|
-
- refine_factor: Base from config, scaled up for small limits
|
|
2189
|
-
- limit <= 5: config value * 2
|
|
2190
|
-
- limit <= 20: config value
|
|
2191
|
-
- limit > 20: max(config // 2, 2)
|
|
2192
|
-
"""
|
|
2193
|
-
# Calculate nprobes based on dataset size
|
|
2194
|
-
if nprobes_override is not None:
|
|
2195
|
-
nprobes = nprobes_override
|
|
2196
|
-
else:
|
|
2197
|
-
base_nprobes = self.index_nprobes
|
|
2198
|
-
if count < 100_000:
|
|
2199
|
-
nprobes = base_nprobes
|
|
2200
|
-
elif count < 1_000_000:
|
|
2201
|
-
nprobes = max(base_nprobes, 30)
|
|
2202
|
-
elif count < 10_000_000:
|
|
2203
|
-
nprobes = max(base_nprobes, 50)
|
|
2204
|
-
else:
|
|
2205
|
-
nprobes = max(base_nprobes, 100)
|
|
2206
|
-
|
|
2207
|
-
# Calculate refine_factor based on limit
|
|
2208
|
-
if refine_factor_override is not None:
|
|
2209
|
-
refine_factor = refine_factor_override
|
|
2210
|
-
else:
|
|
2211
|
-
base_refine = self.index_refine_factor
|
|
2212
|
-
if limit <= 5:
|
|
2213
|
-
# Small limits need more refinement for accuracy
|
|
2214
|
-
refine_factor = base_refine * 2
|
|
2215
|
-
elif limit <= 20:
|
|
2216
|
-
refine_factor = base_refine
|
|
2217
|
-
else:
|
|
2218
|
-
# Large limits can use less refinement
|
|
2219
|
-
refine_factor = max(base_refine // 2, 2)
|
|
2220
|
-
|
|
2221
|
-
return nprobes, refine_factor
|
|
2352
|
+
"""Calculate optimal search parameters. Delegates to SearchManager."""
|
|
2353
|
+
if self._search_manager is None:
|
|
2354
|
+
raise StorageError("Database not connected")
|
|
2355
|
+
return self._search_manager.calculate_search_params(
|
|
2356
|
+
count, limit, nprobes_override, refine_factor_override
|
|
2357
|
+
)
|
|
2222
2358
|
|
|
2223
2359
|
@with_stale_connection_recovery
|
|
2224
2360
|
@retry_on_storage_error(max_attempts=3, backoff=0.5)
|
|
@@ -2232,19 +2368,16 @@ class Database:
|
|
|
2232
2368
|
refine_factor: int | None = None,
|
|
2233
2369
|
include_vector: bool = False,
|
|
2234
2370
|
) -> list[dict[str, Any]]:
|
|
2235
|
-
"""Search for similar memories by vector
|
|
2371
|
+
"""Search for similar memories by vector. Delegates to SearchManager.
|
|
2236
2372
|
|
|
2237
2373
|
Args:
|
|
2238
2374
|
query_vector: Query embedding vector.
|
|
2239
2375
|
limit: Maximum number of results.
|
|
2240
2376
|
namespace: Filter to specific namespace.
|
|
2241
2377
|
min_similarity: Minimum similarity threshold (0-1).
|
|
2242
|
-
nprobes: Number of partitions to search
|
|
2243
|
-
Only effective when vector index exists. Defaults to dynamic calculation.
|
|
2378
|
+
nprobes: Number of partitions to search.
|
|
2244
2379
|
refine_factor: Re-rank top (refine_factor * limit) for accuracy.
|
|
2245
|
-
Defaults to dynamic calculation based on limit.
|
|
2246
2380
|
include_vector: Whether to include vector embeddings in results.
|
|
2247
|
-
Defaults to False to reduce response size.
|
|
2248
2381
|
|
|
2249
2382
|
Returns:
|
|
2250
2383
|
List of memory records with similarity scores.
|
|
@@ -2253,66 +2386,53 @@ class Database:
|
|
|
2253
2386
|
ValidationError: If input validation fails.
|
|
2254
2387
|
StorageError: If database operation fails.
|
|
2255
2388
|
"""
|
|
2256
|
-
|
|
2257
|
-
|
|
2389
|
+
if self._search_manager is None:
|
|
2390
|
+
raise StorageError("Database not connected")
|
|
2391
|
+
return self._search_manager.vector_search(
|
|
2392
|
+
query_vector=query_vector,
|
|
2393
|
+
limit=limit,
|
|
2394
|
+
namespace=namespace,
|
|
2395
|
+
min_similarity=min_similarity,
|
|
2396
|
+
nprobes=nprobes,
|
|
2397
|
+
refine_factor=refine_factor,
|
|
2398
|
+
include_vector=include_vector,
|
|
2399
|
+
)
|
|
2258
2400
|
|
|
2259
|
-
|
|
2260
|
-
|
|
2261
|
-
|
|
2401
|
+
@with_stale_connection_recovery
|
|
2402
|
+
@retry_on_storage_error(max_attempts=3, backoff=0.5)
|
|
2403
|
+
def batch_vector_search_native(
|
|
2404
|
+
self,
|
|
2405
|
+
query_vectors: list[np.ndarray],
|
|
2406
|
+
limit_per_query: int = 3,
|
|
2407
|
+
namespace: str | None = None,
|
|
2408
|
+
min_similarity: float = 0.0,
|
|
2409
|
+
include_vector: bool = False,
|
|
2410
|
+
) -> list[list[dict[str, Any]]]:
|
|
2411
|
+
"""Batch search using native LanceDB. Delegates to SearchManager.
|
|
2262
2412
|
|
|
2263
|
-
|
|
2264
|
-
|
|
2265
|
-
|
|
2266
|
-
|
|
2267
|
-
|
|
2268
|
-
|
|
2269
|
-
)
|
|
2270
|
-
search = search.nprobes(actual_nprobes)
|
|
2271
|
-
search = search.refine_factor(actual_refine)
|
|
2413
|
+
Args:
|
|
2414
|
+
query_vectors: List of query embedding vectors.
|
|
2415
|
+
limit_per_query: Maximum number of results per query.
|
|
2416
|
+
namespace: Filter to specific namespace.
|
|
2417
|
+
min_similarity: Minimum similarity threshold (0-1).
|
|
2418
|
+
include_vector: Whether to include vector embeddings in results.
|
|
2272
2419
|
|
|
2273
|
-
|
|
2274
|
-
|
|
2275
|
-
|
|
2276
|
-
|
|
2277
|
-
|
|
2278
|
-
|
|
2279
|
-
|
|
2280
|
-
|
|
2281
|
-
|
|
2282
|
-
|
|
2283
|
-
|
|
2284
|
-
|
|
2285
|
-
|
|
2286
|
-
|
|
2287
|
-
|
|
2288
|
-
|
|
2289
|
-
# Fetch extra if filtering by similarity
|
|
2290
|
-
fetch_limit = limit * 2 if min_similarity > 0.0 else limit
|
|
2291
|
-
results: list[dict[str, Any]] = search.limit(fetch_limit).to_list()
|
|
2292
|
-
|
|
2293
|
-
# Process results
|
|
2294
|
-
filtered_results: list[dict[str, Any]] = []
|
|
2295
|
-
for record in results:
|
|
2296
|
-
record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
|
|
2297
|
-
# LanceDB returns _distance, convert to similarity
|
|
2298
|
-
if "_distance" in record:
|
|
2299
|
-
# Cosine distance to similarity: 1 - distance
|
|
2300
|
-
# Clamp to [0, 1] (cosine distance can exceed 1 for unnormalized)
|
|
2301
|
-
similarity = max(0.0, min(1.0, 1 - record["_distance"]))
|
|
2302
|
-
record["similarity"] = similarity
|
|
2303
|
-
del record["_distance"]
|
|
2304
|
-
|
|
2305
|
-
# Apply similarity threshold
|
|
2306
|
-
if record.get("similarity", 0) >= min_similarity:
|
|
2307
|
-
filtered_results.append(record)
|
|
2308
|
-
if len(filtered_results) >= limit:
|
|
2309
|
-
break
|
|
2310
|
-
|
|
2311
|
-
return filtered_results
|
|
2312
|
-
except ValidationError:
|
|
2313
|
-
raise
|
|
2314
|
-
except Exception as e:
|
|
2315
|
-
raise StorageError(f"Failed to search: {e}") from e
|
|
2420
|
+
Returns:
|
|
2421
|
+
List of result lists, one per query vector.
|
|
2422
|
+
|
|
2423
|
+
Raises:
|
|
2424
|
+
ValidationError: If input validation fails.
|
|
2425
|
+
StorageError: If database operation fails.
|
|
2426
|
+
"""
|
|
2427
|
+
if self._search_manager is None:
|
|
2428
|
+
raise StorageError("Database not connected")
|
|
2429
|
+
return self._search_manager.batch_vector_search_native(
|
|
2430
|
+
query_vectors=query_vectors,
|
|
2431
|
+
limit_per_query=limit_per_query,
|
|
2432
|
+
namespace=namespace,
|
|
2433
|
+
min_similarity=min_similarity,
|
|
2434
|
+
include_vector=include_vector,
|
|
2435
|
+
)
|
|
2316
2436
|
|
|
2317
2437
|
@with_stale_connection_recovery
|
|
2318
2438
|
@retry_on_storage_error(max_attempts=3, backoff=0.5)
|
|
@@ -2325,10 +2445,7 @@ class Database:
|
|
|
2325
2445
|
alpha: float = 0.5,
|
|
2326
2446
|
min_similarity: float = 0.0,
|
|
2327
2447
|
) -> list[dict[str, Any]]:
|
|
2328
|
-
"""Hybrid search combining vector
|
|
2329
|
-
|
|
2330
|
-
Uses LinearCombinationReranker to balance vector and keyword scores
|
|
2331
|
-
based on the alpha parameter.
|
|
2448
|
+
"""Hybrid search combining vector and keyword. Delegates to SearchManager.
|
|
2332
2449
|
|
|
2333
2450
|
Args:
|
|
2334
2451
|
query: Text query for full-text search.
|
|
@@ -2336,9 +2453,7 @@ class Database:
|
|
|
2336
2453
|
limit: Number of results.
|
|
2337
2454
|
namespace: Filter to namespace.
|
|
2338
2455
|
alpha: Balance between vector (1.0) and keyword (0.0).
|
|
2339
|
-
|
|
2340
|
-
min_similarity: Minimum similarity threshold (0.0-1.0).
|
|
2341
|
-
Results below this threshold are filtered out.
|
|
2456
|
+
min_similarity: Minimum similarity threshold.
|
|
2342
2457
|
|
|
2343
2458
|
Returns:
|
|
2344
2459
|
List of memory records with combined scores.
|
|
@@ -2347,80 +2462,16 @@ class Database:
|
|
|
2347
2462
|
ValidationError: If input validation fails.
|
|
2348
2463
|
StorageError: If database operation fails.
|
|
2349
2464
|
"""
|
|
2350
|
-
|
|
2351
|
-
|
|
2352
|
-
|
|
2353
|
-
|
|
2354
|
-
|
|
2355
|
-
|
|
2356
|
-
|
|
2357
|
-
|
|
2358
|
-
|
|
2359
|
-
|
|
2360
|
-
.vector(query_vector.tolist())
|
|
2361
|
-
.vector_column_name("vector")
|
|
2362
|
-
)
|
|
2363
|
-
|
|
2364
|
-
# Apply alpha parameter using LinearCombinationReranker
|
|
2365
|
-
# alpha=1.0 means full vector, alpha=0.0 means full FTS
|
|
2366
|
-
try:
|
|
2367
|
-
from lancedb.rerankers import LinearCombinationReranker
|
|
2368
|
-
|
|
2369
|
-
reranker = LinearCombinationReranker(weight=alpha)
|
|
2370
|
-
search = search.rerank(reranker)
|
|
2371
|
-
except ImportError:
|
|
2372
|
-
logger.debug("LinearCombinationReranker not available, using default reranking")
|
|
2373
|
-
except Exception as e:
|
|
2374
|
-
logger.debug(f"Could not apply reranker: {e}")
|
|
2375
|
-
|
|
2376
|
-
# Apply namespace filter
|
|
2377
|
-
if namespace:
|
|
2378
|
-
namespace = _validate_namespace(namespace)
|
|
2379
|
-
safe_ns = _sanitize_string(namespace)
|
|
2380
|
-
search = search.where(f"namespace = '{safe_ns}'")
|
|
2381
|
-
|
|
2382
|
-
results: list[dict[str, Any]] = search.limit(limit).to_list()
|
|
2383
|
-
|
|
2384
|
-
# Process results - normalize scores and clean up internal columns
|
|
2385
|
-
processed_results: list[dict[str, Any]] = []
|
|
2386
|
-
for record in results:
|
|
2387
|
-
record["metadata"] = json.loads(record["metadata"]) if record["metadata"] else {}
|
|
2388
|
-
|
|
2389
|
-
# Compute similarity from various score columns
|
|
2390
|
-
# Priority: _relevance_score > _distance > _score > default
|
|
2391
|
-
similarity: float
|
|
2392
|
-
if "_relevance_score" in record:
|
|
2393
|
-
# Reranker output - use directly (already 0-1 range)
|
|
2394
|
-
similarity = float(record["_relevance_score"])
|
|
2395
|
-
del record["_relevance_score"]
|
|
2396
|
-
elif "_distance" in record:
|
|
2397
|
-
# Vector distance - convert to similarity
|
|
2398
|
-
similarity = max(0.0, min(1.0, 1 - float(record["_distance"])))
|
|
2399
|
-
del record["_distance"]
|
|
2400
|
-
elif "_score" in record:
|
|
2401
|
-
# BM25 score - normalize using score/(1+score)
|
|
2402
|
-
score = float(record["_score"])
|
|
2403
|
-
similarity = score / (1.0 + score)
|
|
2404
|
-
del record["_score"]
|
|
2405
|
-
else:
|
|
2406
|
-
# No score column - use default
|
|
2407
|
-
similarity = 0.5
|
|
2408
|
-
|
|
2409
|
-
record["similarity"] = similarity
|
|
2410
|
-
|
|
2411
|
-
# Mark as hybrid result with alpha value
|
|
2412
|
-
record["search_type"] = "hybrid"
|
|
2413
|
-
record["alpha"] = alpha
|
|
2414
|
-
|
|
2415
|
-
# Apply min_similarity filter
|
|
2416
|
-
if similarity >= min_similarity:
|
|
2417
|
-
processed_results.append(record)
|
|
2418
|
-
|
|
2419
|
-
return processed_results
|
|
2420
|
-
|
|
2421
|
-
except Exception as e:
|
|
2422
|
-
logger.warning(f"Hybrid search failed, falling back to vector search: {e}")
|
|
2423
|
-
return self.vector_search(query_vector, limit=limit, namespace=namespace)
|
|
2465
|
+
if self._search_manager is None:
|
|
2466
|
+
raise StorageError("Database not connected")
|
|
2467
|
+
return self._search_manager.hybrid_search(
|
|
2468
|
+
query=query,
|
|
2469
|
+
query_vector=query_vector,
|
|
2470
|
+
limit=limit,
|
|
2471
|
+
namespace=namespace,
|
|
2472
|
+
alpha=alpha,
|
|
2473
|
+
min_similarity=min_similarity,
|
|
2474
|
+
)
|
|
2424
2475
|
|
|
2425
2476
|
@with_stale_connection_recovery
|
|
2426
2477
|
@retry_on_storage_error(max_attempts=3, backoff=0.5)
|
|
@@ -2429,20 +2480,19 @@ class Database:
|
|
|
2429
2480
|
query_vectors: list[np.ndarray],
|
|
2430
2481
|
limit_per_query: int = 3,
|
|
2431
2482
|
namespace: str | None = None,
|
|
2432
|
-
parallel: bool = False,
|
|
2433
|
-
max_workers: int = 4,
|
|
2483
|
+
parallel: bool = False, # Deprecated
|
|
2484
|
+
max_workers: int = 4, # Deprecated
|
|
2485
|
+
include_vector: bool = False,
|
|
2434
2486
|
) -> list[list[dict[str, Any]]]:
|
|
2435
|
-
"""Search
|
|
2436
|
-
|
|
2437
|
-
Efficient for operations like journey interpolation where multiple
|
|
2438
|
-
points need to find nearby memories.
|
|
2487
|
+
"""Search using multiple query vectors. Delegates to SearchManager.
|
|
2439
2488
|
|
|
2440
2489
|
Args:
|
|
2441
2490
|
query_vectors: List of query embedding vectors.
|
|
2442
2491
|
limit_per_query: Maximum results per query vector.
|
|
2443
2492
|
namespace: Filter to specific namespace.
|
|
2444
|
-
parallel:
|
|
2445
|
-
max_workers:
|
|
2493
|
+
parallel: Deprecated, kept for backward compatibility.
|
|
2494
|
+
max_workers: Deprecated, kept for backward compatibility.
|
|
2495
|
+
include_vector: Whether to include vector embeddings in results.
|
|
2446
2496
|
|
|
2447
2497
|
Returns:
|
|
2448
2498
|
List of result lists (one per query vector).
|
|
@@ -2450,52 +2500,16 @@ class Database:
|
|
|
2450
2500
|
Raises:
|
|
2451
2501
|
StorageError: If database operation fails.
|
|
2452
2502
|
"""
|
|
2453
|
-
if
|
|
2454
|
-
|
|
2455
|
-
|
|
2456
|
-
|
|
2457
|
-
|
|
2458
|
-
|
|
2459
|
-
|
|
2460
|
-
|
|
2461
|
-
|
|
2462
|
-
|
|
2463
|
-
def search_single(vec: np.ndarray) -> list[dict[str, Any]]:
|
|
2464
|
-
"""Execute a single vector search."""
|
|
2465
|
-
search = self.table.search(vec.tolist()).distance_type("cosine")
|
|
2466
|
-
|
|
2467
|
-
if where_clause:
|
|
2468
|
-
search = search.where(where_clause)
|
|
2469
|
-
|
|
2470
|
-
results: list[dict[str, Any]] = search.limit(limit_per_query).to_list()
|
|
2471
|
-
|
|
2472
|
-
# Process results
|
|
2473
|
-
for record in results:
|
|
2474
|
-
meta = record["metadata"]
|
|
2475
|
-
record["metadata"] = json.loads(meta) if meta else {}
|
|
2476
|
-
if "_distance" in record:
|
|
2477
|
-
record["similarity"] = max(0.0, min(1.0, 1 - record["_distance"]))
|
|
2478
|
-
del record["_distance"]
|
|
2479
|
-
|
|
2480
|
-
return results
|
|
2481
|
-
|
|
2482
|
-
try:
|
|
2483
|
-
if parallel and len(query_vectors) > 1:
|
|
2484
|
-
# Use ThreadPoolExecutor for parallel execution
|
|
2485
|
-
from concurrent.futures import ThreadPoolExecutor
|
|
2486
|
-
|
|
2487
|
-
workers = min(max_workers, len(query_vectors))
|
|
2488
|
-
with ThreadPoolExecutor(max_workers=workers) as executor:
|
|
2489
|
-
# Map preserves order
|
|
2490
|
-
all_results = list(executor.map(search_single, query_vectors))
|
|
2491
|
-
else:
|
|
2492
|
-
# Sequential execution
|
|
2493
|
-
all_results = [search_single(vec) for vec in query_vectors]
|
|
2494
|
-
|
|
2495
|
-
return all_results
|
|
2496
|
-
|
|
2497
|
-
except Exception as e:
|
|
2498
|
-
raise StorageError(f"Batch vector search failed: {e}") from e
|
|
2503
|
+
if self._search_manager is None:
|
|
2504
|
+
raise StorageError("Database not connected")
|
|
2505
|
+
return self._search_manager.batch_vector_search(
|
|
2506
|
+
query_vectors=query_vectors,
|
|
2507
|
+
limit_per_query=limit_per_query,
|
|
2508
|
+
namespace=namespace,
|
|
2509
|
+
parallel=parallel,
|
|
2510
|
+
max_workers=max_workers,
|
|
2511
|
+
include_vector=include_vector,
|
|
2512
|
+
)
|
|
2499
2513
|
|
|
2500
2514
|
def get_vectors_for_clustering(
|
|
2501
2515
|
self,
|
|
@@ -2941,6 +2955,7 @@ class Database:
|
|
|
2941
2955
|
|
|
2942
2956
|
if deleted > 0:
|
|
2943
2957
|
self._invalidate_count_cache()
|
|
2958
|
+
self._track_modification(deleted)
|
|
2944
2959
|
self._invalidate_namespace_cache()
|
|
2945
2960
|
logger.info(f"Cleaned up {deleted} expired memories")
|
|
2946
2961
|
|
|
@@ -2949,148 +2964,58 @@ class Database:
|
|
|
2949
2964
|
raise StorageError(f"Failed to cleanup expired memories: {e}") from e
|
|
2950
2965
|
|
|
2951
2966
|
# ========================================================================
|
|
2952
|
-
# Snapshot / Version Management
|
|
2967
|
+
# Snapshot / Version Management (delegated to VersionManager)
|
|
2953
2968
|
# ========================================================================
|
|
2954
2969
|
|
|
2955
2970
|
def create_snapshot(self, tag: str) -> int:
|
|
2956
2971
|
"""Create a named snapshot of the current table state.
|
|
2957
2972
|
|
|
2958
|
-
|
|
2959
|
-
returns the current version number which can be used with restore_snapshot().
|
|
2960
|
-
|
|
2961
|
-
Args:
|
|
2962
|
-
tag: Semantic version tag (e.g., "v1.0.0", "backup-2024-01").
|
|
2963
|
-
Note: Tag is logged for reference but LanceDB tracks versions
|
|
2964
|
-
numerically. Consider storing tag->version mappings externally
|
|
2965
|
-
if tag-based retrieval is needed.
|
|
2966
|
-
|
|
2967
|
-
Returns:
|
|
2968
|
-
Version number of the snapshot.
|
|
2969
|
-
|
|
2970
|
-
Raises:
|
|
2971
|
-
StorageError: If snapshot creation fails.
|
|
2973
|
+
Delegates to VersionManager. See VersionManager.create_snapshot for details.
|
|
2972
2974
|
"""
|
|
2973
|
-
|
|
2974
|
-
|
|
2975
|
-
|
|
2976
|
-
return version
|
|
2977
|
-
except Exception as e:
|
|
2978
|
-
raise StorageError(f"Failed to create snapshot: {e}") from e
|
|
2975
|
+
if self._version_manager is None:
|
|
2976
|
+
raise StorageError("Database not connected")
|
|
2977
|
+
return self._version_manager.create_snapshot(tag)
|
|
2979
2978
|
|
|
2980
2979
|
def list_snapshots(self) -> list[dict[str, Any]]:
|
|
2981
2980
|
"""List available versions/snapshots.
|
|
2982
2981
|
|
|
2983
|
-
|
|
2984
|
-
List of version information dictionaries. Each dict contains
|
|
2985
|
-
at minimum 'version' key. Additional fields depend on LanceDB
|
|
2986
|
-
version and available metadata.
|
|
2987
|
-
|
|
2988
|
-
Raises:
|
|
2989
|
-
StorageError: If listing fails.
|
|
2982
|
+
Delegates to VersionManager. See VersionManager.list_snapshots for details.
|
|
2990
2983
|
"""
|
|
2991
|
-
|
|
2992
|
-
|
|
2993
|
-
|
|
2994
|
-
# Try to get version history if available
|
|
2995
|
-
if hasattr(self.table, "list_versions"):
|
|
2996
|
-
try:
|
|
2997
|
-
versions = self.table.list_versions()
|
|
2998
|
-
for v in versions:
|
|
2999
|
-
if isinstance(v, dict):
|
|
3000
|
-
versions_info.append(v)
|
|
3001
|
-
elif hasattr(v, "version"):
|
|
3002
|
-
versions_info.append({
|
|
3003
|
-
"version": v.version,
|
|
3004
|
-
"timestamp": getattr(v, "timestamp", None),
|
|
3005
|
-
})
|
|
3006
|
-
else:
|
|
3007
|
-
versions_info.append({"version": v})
|
|
3008
|
-
except Exception as e:
|
|
3009
|
-
logger.debug(f"list_versions not fully supported: {e}")
|
|
3010
|
-
|
|
3011
|
-
# Always include current version
|
|
3012
|
-
if not versions_info:
|
|
3013
|
-
versions_info.append({"version": self.table.version})
|
|
3014
|
-
|
|
3015
|
-
return versions_info
|
|
3016
|
-
except Exception as e:
|
|
3017
|
-
logger.warning(f"Could not list snapshots: {e}")
|
|
3018
|
-
return [{"version": 0, "error": str(e)}]
|
|
2984
|
+
if self._version_manager is None:
|
|
2985
|
+
raise StorageError("Database not connected")
|
|
2986
|
+
return self._version_manager.list_snapshots()
|
|
3019
2987
|
|
|
3020
2988
|
def restore_snapshot(self, version: int) -> None:
|
|
3021
2989
|
"""Restore table to a specific version.
|
|
3022
2990
|
|
|
3023
|
-
|
|
3024
|
-
(doesn't delete history).
|
|
3025
|
-
|
|
3026
|
-
Args:
|
|
3027
|
-
version: The version number to restore to.
|
|
3028
|
-
|
|
3029
|
-
Raises:
|
|
3030
|
-
ValidationError: If version is invalid.
|
|
3031
|
-
StorageError: If restore fails.
|
|
2991
|
+
Delegates to VersionManager. See VersionManager.restore_snapshot for details.
|
|
3032
2992
|
"""
|
|
3033
|
-
if
|
|
3034
|
-
raise
|
|
3035
|
-
|
|
3036
|
-
try:
|
|
3037
|
-
self.table.restore(version)
|
|
3038
|
-
self._invalidate_count_cache()
|
|
3039
|
-
self._invalidate_namespace_cache()
|
|
3040
|
-
logger.info(f"Restored to version {version}")
|
|
3041
|
-
except Exception as e:
|
|
3042
|
-
raise StorageError(f"Failed to restore snapshot: {e}") from e
|
|
2993
|
+
if self._version_manager is None:
|
|
2994
|
+
raise StorageError("Database not connected")
|
|
2995
|
+
self._version_manager.restore_snapshot(version)
|
|
3043
2996
|
|
|
3044
2997
|
def get_current_version(self) -> int:
|
|
3045
2998
|
"""Get the current table version number.
|
|
3046
2999
|
|
|
3047
|
-
|
|
3048
|
-
Current version number.
|
|
3049
|
-
|
|
3050
|
-
Raises:
|
|
3051
|
-
StorageError: If version cannot be retrieved.
|
|
3000
|
+
Delegates to VersionManager. See VersionManager.get_current_version for details.
|
|
3052
3001
|
"""
|
|
3053
|
-
|
|
3054
|
-
|
|
3055
|
-
|
|
3056
|
-
raise StorageError(f"Failed to get current version: {e}") from e
|
|
3002
|
+
if self._version_manager is None:
|
|
3003
|
+
raise StorageError("Database not connected")
|
|
3004
|
+
return self._version_manager.get_current_version()
|
|
3057
3005
|
|
|
3058
3006
|
# ========================================================================
|
|
3059
|
-
# Idempotency Key Management
|
|
3007
|
+
# Idempotency Key Management (delegates to IdempotencyManager)
|
|
3060
3008
|
# ========================================================================
|
|
3061
3009
|
|
|
3062
|
-
def _ensure_idempotency_table(self) -> None:
|
|
3063
|
-
"""Ensure the idempotency keys table exists."""
|
|
3064
|
-
if self._db is None:
|
|
3065
|
-
raise StorageError("Database not connected")
|
|
3066
|
-
|
|
3067
|
-
existing_tables_result = self._db.list_tables()
|
|
3068
|
-
if hasattr(existing_tables_result, 'tables'):
|
|
3069
|
-
existing_tables = existing_tables_result.tables
|
|
3070
|
-
else:
|
|
3071
|
-
existing_tables = existing_tables_result
|
|
3072
|
-
|
|
3073
|
-
if "idempotency_keys" not in existing_tables:
|
|
3074
|
-
schema = pa.schema([
|
|
3075
|
-
pa.field("key", pa.string()),
|
|
3076
|
-
pa.field("memory_id", pa.string()),
|
|
3077
|
-
pa.field("created_at", pa.timestamp("us")),
|
|
3078
|
-
pa.field("expires_at", pa.timestamp("us")),
|
|
3079
|
-
])
|
|
3080
|
-
self._db.create_table("idempotency_keys", schema=schema)
|
|
3081
|
-
logger.info("Created idempotency_keys table")
|
|
3082
|
-
|
|
3083
3010
|
@property
|
|
3084
3011
|
def idempotency_table(self) -> LanceTable:
|
|
3085
|
-
"""Get the idempotency keys table
|
|
3086
|
-
if self.
|
|
3087
|
-
|
|
3088
|
-
self.
|
|
3089
|
-
assert self._db is not None
|
|
3090
|
-
return self._db.open_table("idempotency_keys")
|
|
3012
|
+
"""Get the idempotency keys table. Delegates to IdempotencyManager."""
|
|
3013
|
+
if self._idempotency_manager is None:
|
|
3014
|
+
raise StorageError("Database not connected")
|
|
3015
|
+
return self._idempotency_manager.idempotency_table
|
|
3091
3016
|
|
|
3092
3017
|
def get_by_idempotency_key(self, key: str) -> IdempotencyRecord | None:
|
|
3093
|
-
"""Look up an idempotency record by key.
|
|
3018
|
+
"""Look up an idempotency record by key. Delegates to IdempotencyManager.
|
|
3094
3019
|
|
|
3095
3020
|
Args:
|
|
3096
3021
|
key: The idempotency key to look up.
|
|
@@ -3101,42 +3026,9 @@ class Database:
|
|
|
3101
3026
|
Raises:
|
|
3102
3027
|
StorageError: If database operation fails.
|
|
3103
3028
|
"""
|
|
3104
|
-
if
|
|
3105
|
-
|
|
3106
|
-
|
|
3107
|
-
try:
|
|
3108
|
-
safe_key = _sanitize_string(key)
|
|
3109
|
-
results = (
|
|
3110
|
-
self.idempotency_table.search()
|
|
3111
|
-
.where(f"key = '{safe_key}'")
|
|
3112
|
-
.limit(1)
|
|
3113
|
-
.to_list()
|
|
3114
|
-
)
|
|
3115
|
-
|
|
3116
|
-
if not results:
|
|
3117
|
-
return None
|
|
3118
|
-
|
|
3119
|
-
record = results[0]
|
|
3120
|
-
now = utc_now()
|
|
3121
|
-
|
|
3122
|
-
# Check if expired (convert DB naive datetime to aware for comparison)
|
|
3123
|
-
expires_at = record.get("expires_at")
|
|
3124
|
-
if expires_at is not None:
|
|
3125
|
-
expires_at_aware = to_aware_utc(expires_at)
|
|
3126
|
-
if expires_at_aware < now:
|
|
3127
|
-
# Expired - clean it up and return None
|
|
3128
|
-
logger.debug(f"Idempotency key '{key}' has expired")
|
|
3129
|
-
return None
|
|
3130
|
-
|
|
3131
|
-
return IdempotencyRecord(
|
|
3132
|
-
key=record["key"],
|
|
3133
|
-
memory_id=record["memory_id"],
|
|
3134
|
-
created_at=record["created_at"],
|
|
3135
|
-
expires_at=record["expires_at"],
|
|
3136
|
-
)
|
|
3137
|
-
|
|
3138
|
-
except Exception as e:
|
|
3139
|
-
raise StorageError(f"Failed to look up idempotency key: {e}") from e
|
|
3029
|
+
if self._idempotency_manager is None:
|
|
3030
|
+
raise StorageError("Database not connected")
|
|
3031
|
+
return self._idempotency_manager.get_by_idempotency_key(key)
|
|
3140
3032
|
|
|
3141
3033
|
@with_process_lock
|
|
3142
3034
|
@with_write_lock
|
|
@@ -3146,7 +3038,7 @@ class Database:
|
|
|
3146
3038
|
memory_id: str,
|
|
3147
3039
|
ttl_hours: float = 24.0,
|
|
3148
3040
|
) -> None:
|
|
3149
|
-
"""Store an idempotency key mapping.
|
|
3041
|
+
"""Store an idempotency key mapping. Delegates to IdempotencyManager.
|
|
3150
3042
|
|
|
3151
3043
|
Args:
|
|
3152
3044
|
key: The idempotency key.
|
|
@@ -3157,36 +3049,14 @@ class Database:
|
|
|
3157
3049
|
ValidationError: If inputs are invalid.
|
|
3158
3050
|
StorageError: If database operation fails.
|
|
3159
3051
|
"""
|
|
3160
|
-
if
|
|
3161
|
-
raise
|
|
3162
|
-
|
|
3163
|
-
raise ValidationError("Memory ID cannot be empty")
|
|
3164
|
-
if ttl_hours <= 0:
|
|
3165
|
-
raise ValidationError("TTL must be positive")
|
|
3166
|
-
|
|
3167
|
-
now = utc_now()
|
|
3168
|
-
expires_at = now + timedelta(hours=ttl_hours)
|
|
3169
|
-
|
|
3170
|
-
record = {
|
|
3171
|
-
"key": key,
|
|
3172
|
-
"memory_id": memory_id,
|
|
3173
|
-
"created_at": now,
|
|
3174
|
-
"expires_at": expires_at,
|
|
3175
|
-
}
|
|
3176
|
-
|
|
3177
|
-
try:
|
|
3178
|
-
self.idempotency_table.add([record])
|
|
3179
|
-
logger.debug(
|
|
3180
|
-
f"Stored idempotency key '{key}' -> memory '{memory_id}' "
|
|
3181
|
-
f"(expires in {ttl_hours}h)"
|
|
3182
|
-
)
|
|
3183
|
-
except Exception as e:
|
|
3184
|
-
raise StorageError(f"Failed to store idempotency key: {e}") from e
|
|
3052
|
+
if self._idempotency_manager is None:
|
|
3053
|
+
raise StorageError("Database not connected")
|
|
3054
|
+
self._idempotency_manager.store_idempotency_key(key, memory_id, ttl_hours)
|
|
3185
3055
|
|
|
3186
3056
|
@with_process_lock
|
|
3187
3057
|
@with_write_lock
|
|
3188
3058
|
def cleanup_expired_idempotency_keys(self) -> int:
|
|
3189
|
-
"""Remove expired idempotency keys.
|
|
3059
|
+
"""Remove expired idempotency keys. Delegates to IdempotencyManager.
|
|
3190
3060
|
|
|
3191
3061
|
Returns:
|
|
3192
3062
|
Number of keys removed.
|
|
@@ -3194,20 +3064,6 @@ class Database:
|
|
|
3194
3064
|
Raises:
|
|
3195
3065
|
StorageError: If cleanup fails.
|
|
3196
3066
|
"""
|
|
3197
|
-
|
|
3198
|
-
|
|
3199
|
-
|
|
3200
|
-
|
|
3201
|
-
# Delete expired keys
|
|
3202
|
-
predicate = f"expires_at < timestamp '{now.isoformat()}'"
|
|
3203
|
-
self.idempotency_table.delete(predicate)
|
|
3204
|
-
|
|
3205
|
-
count_after = self.idempotency_table.count_rows()
|
|
3206
|
-
deleted = count_before - count_after
|
|
3207
|
-
|
|
3208
|
-
if deleted > 0:
|
|
3209
|
-
logger.info(f"Cleaned up {deleted} expired idempotency keys")
|
|
3210
|
-
|
|
3211
|
-
return deleted
|
|
3212
|
-
except Exception as e:
|
|
3213
|
-
raise StorageError(f"Failed to cleanup idempotency keys: {e}") from e
|
|
3067
|
+
if self._idempotency_manager is None:
|
|
3068
|
+
raise StorageError("Database not connected")
|
|
3069
|
+
return self._idempotency_manager.cleanup_expired_idempotency_keys()
|