alma-memory 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alma/__init__.py +121 -45
- alma/confidence/__init__.py +1 -1
- alma/confidence/engine.py +92 -58
- alma/confidence/types.py +34 -14
- alma/config/loader.py +3 -2
- alma/consolidation/__init__.py +23 -0
- alma/consolidation/engine.py +678 -0
- alma/consolidation/prompts.py +84 -0
- alma/core.py +136 -28
- alma/domains/__init__.py +6 -6
- alma/domains/factory.py +12 -9
- alma/domains/schemas.py +17 -3
- alma/domains/types.py +8 -4
- alma/events/__init__.py +75 -0
- alma/events/emitter.py +284 -0
- alma/events/storage_mixin.py +246 -0
- alma/events/types.py +126 -0
- alma/events/webhook.py +425 -0
- alma/exceptions.py +49 -0
- alma/extraction/__init__.py +31 -0
- alma/extraction/auto_learner.py +265 -0
- alma/extraction/extractor.py +420 -0
- alma/graph/__init__.py +106 -0
- alma/graph/backends/__init__.py +32 -0
- alma/graph/backends/kuzu.py +624 -0
- alma/graph/backends/memgraph.py +432 -0
- alma/graph/backends/memory.py +236 -0
- alma/graph/backends/neo4j.py +417 -0
- alma/graph/base.py +159 -0
- alma/graph/extraction.py +198 -0
- alma/graph/store.py +860 -0
- alma/harness/__init__.py +4 -4
- alma/harness/base.py +18 -9
- alma/harness/domains.py +27 -11
- alma/initializer/__init__.py +1 -1
- alma/initializer/initializer.py +51 -43
- alma/initializer/types.py +25 -17
- alma/integration/__init__.py +9 -9
- alma/integration/claude_agents.py +32 -20
- alma/integration/helena.py +32 -22
- alma/integration/victor.py +57 -33
- alma/learning/__init__.py +27 -27
- alma/learning/forgetting.py +198 -148
- alma/learning/heuristic_extractor.py +40 -24
- alma/learning/protocols.py +65 -17
- alma/learning/validation.py +7 -2
- alma/mcp/__init__.py +4 -4
- alma/mcp/__main__.py +2 -1
- alma/mcp/resources.py +17 -16
- alma/mcp/server.py +102 -44
- alma/mcp/tools.py +180 -45
- alma/observability/__init__.py +84 -0
- alma/observability/config.py +302 -0
- alma/observability/logging.py +424 -0
- alma/observability/metrics.py +583 -0
- alma/observability/tracing.py +440 -0
- alma/progress/__init__.py +3 -3
- alma/progress/tracker.py +26 -20
- alma/progress/types.py +8 -12
- alma/py.typed +0 -0
- alma/retrieval/__init__.py +11 -11
- alma/retrieval/cache.py +20 -21
- alma/retrieval/embeddings.py +4 -4
- alma/retrieval/engine.py +179 -39
- alma/retrieval/scoring.py +73 -63
- alma/session/__init__.py +2 -2
- alma/session/manager.py +5 -5
- alma/session/types.py +5 -4
- alma/storage/__init__.py +70 -0
- alma/storage/azure_cosmos.py +414 -133
- alma/storage/base.py +215 -4
- alma/storage/chroma.py +1443 -0
- alma/storage/constants.py +103 -0
- alma/storage/file_based.py +59 -28
- alma/storage/migrations/__init__.py +21 -0
- alma/storage/migrations/base.py +321 -0
- alma/storage/migrations/runner.py +323 -0
- alma/storage/migrations/version_stores.py +337 -0
- alma/storage/migrations/versions/__init__.py +11 -0
- alma/storage/migrations/versions/v1_0_0.py +373 -0
- alma/storage/pinecone.py +1080 -0
- alma/storage/postgresql.py +1559 -0
- alma/storage/qdrant.py +1306 -0
- alma/storage/sqlite_local.py +504 -60
- alma/testing/__init__.py +46 -0
- alma/testing/factories.py +301 -0
- alma/testing/mocks.py +389 -0
- alma/types.py +62 -14
- alma_memory-0.5.1.dist-info/METADATA +939 -0
- alma_memory-0.5.1.dist-info/RECORD +93 -0
- {alma_memory-0.4.0.dist-info → alma_memory-0.5.1.dist-info}/WHEEL +1 -1
- alma_memory-0.4.0.dist-info/METADATA +0 -488
- alma_memory-0.4.0.dist-info/RECORD +0 -52
- {alma_memory-0.4.0.dist-info → alma_memory-0.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ALMA Storage Constants.
|
|
3
|
+
|
|
4
|
+
Canonical naming conventions for memory types across all storage backends.
|
|
5
|
+
This ensures consistency for:
|
|
6
|
+
- Data migration between backends
|
|
7
|
+
- Backend-agnostic code
|
|
8
|
+
- Documentation consistency
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from typing import Dict
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MemoryType:
    """
    Canonical memory type identifiers.

    These internal names are shared by every storage backend. A backend
    may prefix or otherwise transform them for its own storage format,
    but the canonical spellings never change, which keeps data migration
    between backends and backend-agnostic code consistent.
    """

    HEURISTICS = "heuristics"
    OUTCOMES = "outcomes"
    PREFERENCES = "preferences"
    DOMAIN_KNOWLEDGE = "domain_knowledge"
    ANTI_PATTERNS = "anti_patterns"

    # Every canonical memory type, in declaration order, for iteration.
    ALL = (
        HEURISTICS,
        OUTCOMES,
        PREFERENCES,
        DOMAIN_KNOWLEDGE,
        ANTI_PATTERNS,
    )

    # Subset of ALL that supports embeddings / vector search
    # (PREFERENCES is the one type excluded).
    VECTOR_ENABLED = (
        HEURISTICS,
        OUTCOMES,
        DOMAIN_KNOWLEDGE,
        ANTI_PATTERNS,
    )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_table_name(memory_type: str, prefix: str = "") -> str:
    """
    Build the table/container name for a single memory type.

    Args:
        memory_type: One of the MemoryType constants
        prefix: Optional backend-specific prefix (e.g., "alma_" for PostgreSQL)

    Returns:
        The prefixed table/container name

    Raises:
        ValueError: If memory_type is not one of MemoryType.ALL

    Example:
        >>> get_table_name(MemoryType.HEURISTICS, "alma_")
        'alma_heuristics'
        >>> get_table_name(MemoryType.DOMAIN_KNOWLEDGE)
        'domain_knowledge'
    """
    # Validate against the canonical list before composing the name.
    if memory_type in MemoryType.ALL:
        return prefix + memory_type
    raise ValueError(f"Unknown memory type: {memory_type}")
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def get_table_names(prefix: str = "") -> Dict[str, str]:
    """
    Map every canonical memory type to its table/container name.

    Args:
        prefix: Optional backend-specific prefix (e.g., "alma_" for PostgreSQL)

    Returns:
        Dict keyed by canonical memory type, valued with the prefixed name

    Example:
        >>> get_table_names("alma_")
        {
            'heuristics': 'alma_heuristics',
            'outcomes': 'alma_outcomes',
            'preferences': 'alma_preferences',
            'domain_knowledge': 'alma_domain_knowledge',
            'anti_patterns': 'alma_anti_patterns',
        }
    """
    names: Dict[str, str] = {}
    for memory_type in MemoryType.ALL:
        names[memory_type] = get_table_name(memory_type, prefix)
    return names
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# Canonical, pre-computed table-name mappings per backend.
# Use these rather than re-deriving names ad hoc.

# PostgreSQL: "alma_" prefix, underscore-separated.
POSTGRESQL_TABLE_NAMES = get_table_names("alma_")

# SQLite: local file-based storage, so no prefix (no collision risk).
SQLITE_TABLE_NAMES = get_table_names("")

# Azure Cosmos: "alma_" prefix with underscores (standardized; earlier
# releases used hyphens).
AZURE_COSMOS_CONTAINER_NAMES = get_table_names("alma_")
|
alma/storage/file_based.py
CHANGED
|
@@ -6,21 +6,20 @@ No vector search - uses basic text matching for retrieval.
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import json
|
|
9
|
-
import uuid
|
|
10
9
|
import logging
|
|
11
|
-
from pathlib import Path
|
|
12
10
|
from datetime import datetime, timezone
|
|
13
|
-
from
|
|
14
|
-
from
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any, Dict, List, Optional
|
|
15
13
|
|
|
14
|
+
from alma.storage.base import StorageBackend
|
|
15
|
+
from alma.storage.constants import MemoryType
|
|
16
16
|
from alma.types import (
|
|
17
|
+
AntiPattern,
|
|
18
|
+
DomainKnowledge,
|
|
17
19
|
Heuristic,
|
|
18
20
|
Outcome,
|
|
19
21
|
UserPreference,
|
|
20
|
-
DomainKnowledge,
|
|
21
|
-
AntiPattern,
|
|
22
22
|
)
|
|
23
|
-
from alma.storage.base import StorageBackend
|
|
24
23
|
|
|
25
24
|
logger = logging.getLogger(__name__)
|
|
26
25
|
|
|
@@ -51,14 +50,8 @@ class FileBasedStorage(StorageBackend):
|
|
|
51
50
|
self.storage_dir = Path(storage_dir)
|
|
52
51
|
self.storage_dir.mkdir(parents=True, exist_ok=True)
|
|
53
52
|
|
|
54
|
-
# File paths
|
|
55
|
-
self._files = {
|
|
56
|
-
"heuristics": self.storage_dir / "heuristics.json",
|
|
57
|
-
"outcomes": self.storage_dir / "outcomes.json",
|
|
58
|
-
"preferences": self.storage_dir / "preferences.json",
|
|
59
|
-
"domain_knowledge": self.storage_dir / "domain_knowledge.json",
|
|
60
|
-
"anti_patterns": self.storage_dir / "anti_patterns.json",
|
|
61
|
-
}
|
|
53
|
+
# File paths (using canonical memory type names)
|
|
54
|
+
self._files = {mt: self.storage_dir / f"{mt}.json" for mt in MemoryType.ALL}
|
|
62
55
|
|
|
63
56
|
# Initialize empty files if they don't exist
|
|
64
57
|
for file_path in self._files.values():
|
|
@@ -74,46 +67,86 @@ class FileBasedStorage(StorageBackend):
|
|
|
74
67
|
# ==================== WRITE OPERATIONS ====================
|
|
75
68
|
|
|
76
69
|
def save_heuristic(self, heuristic: Heuristic) -> str:
    """Save a heuristic (UPSERT - update if exists, insert if new)."""
    records = self._read_json(self._files["heuristics"])
    record = self._to_dict(heuristic)
    # Locate an existing record with the same id; None means insert.
    match_index = next(
        (i for i, existing in enumerate(records) if existing.get("id") == record["id"]),
        None,
    )
    if match_index is None:
        records.append(record)
    else:
        records[match_index] = record
    self._write_json(self._files["heuristics"], records)
    logger.debug(f"Saved heuristic: {heuristic.id}")
    return heuristic.id
|
|
84
85
|
|
|
85
86
|
def save_outcome(self, outcome: Outcome) -> str:
    """Save an outcome (UPSERT - update if exists, insert if new)."""
    records = self._read_json(self._files["outcomes"])
    record = self._to_dict(outcome)
    # Locate an existing record with the same id; None means insert.
    match_index = next(
        (i for i, existing in enumerate(records) if existing.get("id") == record["id"]),
        None,
    )
    if match_index is None:
        records.append(record)
    else:
        records[match_index] = record
    self._write_json(self._files["outcomes"], records)
    logger.debug(f"Saved outcome: {outcome.id}")
    return outcome.id
|
|
93
102
|
|
|
94
103
|
def save_user_preference(self, preference: UserPreference) -> str:
    """Save a user preference (UPSERT - update if exists, insert if new)."""
    records = self._read_json(self._files["preferences"])
    record = self._to_dict(preference)
    # Locate an existing record with the same id; None means insert.
    match_index = next(
        (i for i, existing in enumerate(records) if existing.get("id") == record["id"]),
        None,
    )
    if match_index is None:
        records.append(record)
    else:
        records[match_index] = record
    self._write_json(self._files["preferences"], records)
    logger.debug(f"Saved preference: {preference.id}")
    return preference.id
|
|
102
119
|
|
|
103
120
|
def save_domain_knowledge(self, knowledge: DomainKnowledge) -> str:
    """Save domain knowledge (UPSERT - update if exists, insert if new)."""
    records = self._read_json(self._files["domain_knowledge"])
    record = self._to_dict(knowledge)
    # Locate an existing record with the same id; None means insert.
    match_index = next(
        (i for i, existing in enumerate(records) if existing.get("id") == record["id"]),
        None,
    )
    if match_index is None:
        records.append(record)
    else:
        records[match_index] = record
    self._write_json(self._files["domain_knowledge"], records)
    logger.debug(f"Saved domain knowledge: {knowledge.id}")
    return knowledge.id
|
|
111
136
|
|
|
112
137
|
def save_anti_pattern(self, anti_pattern: AntiPattern) -> str:
    """Save an anti-pattern (UPSERT - update if exists, insert if new)."""
    records = self._read_json(self._files["anti_patterns"])
    record = self._to_dict(anti_pattern)
    # Locate an existing record with the same id; None means insert.
    match_index = next(
        (i for i, existing in enumerate(records) if existing.get("id") == record["id"]),
        None,
    )
    if match_index is None:
        records.append(record)
    else:
        records[match_index] = record
    self._write_json(self._files["anti_patterns"], records)
    logger.debug(f"Saved anti-pattern: {anti_pattern.id}")
    return anti_pattern.id
|
|
@@ -451,9 +484,7 @@ class FileBasedStorage(StorageBackend):
|
|
|
451
484
|
count += 1
|
|
452
485
|
stats[f"{name}_count"] = count
|
|
453
486
|
|
|
454
|
-
stats["total_count"] = sum(
|
|
455
|
-
stats[k] for k in stats if k.endswith("_count")
|
|
456
|
-
)
|
|
487
|
+
stats["total_count"] = sum(stats[k] for k in stats if k.endswith("_count"))
|
|
457
488
|
|
|
458
489
|
return stats
|
|
459
490
|
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ALMA Schema Migration Framework.
|
|
3
|
+
|
|
4
|
+
Provides version tracking and migration capabilities for storage backends.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from alma.storage.migrations.base import (
|
|
8
|
+
Migration,
|
|
9
|
+
MigrationError,
|
|
10
|
+
MigrationRegistry,
|
|
11
|
+
SchemaVersion,
|
|
12
|
+
)
|
|
13
|
+
from alma.storage.migrations.runner import MigrationRunner
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
"Migration",
|
|
17
|
+
"MigrationError",
|
|
18
|
+
"MigrationRegistry",
|
|
19
|
+
"MigrationRunner",
|
|
20
|
+
"SchemaVersion",
|
|
21
|
+
]
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ALMA Migration Framework - Base Classes.
|
|
3
|
+
|
|
4
|
+
Provides abstract migration classes and version tracking utilities.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from datetime import datetime, timezone
|
|
11
|
+
from typing import Any, Callable, Dict, List, Optional, Type
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MigrationError(Exception):
    """Raised when applying or rolling back a migration fails.

    Attributes:
        version: Version string of the migration that failed, if known.
        cause: Underlying exception that triggered the failure, if any.
    """

    def __init__(
        self,
        message: str,
        version: Optional[str] = None,
        cause: Optional[Exception] = None,
    ):
        super().__init__(message)
        self.version = version
        self.cause = cause
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class SchemaVersion:
    """
    Represents a schema version record.

    Attributes:
        version: Semantic version string (e.g., "1.0.0")
        applied_at: When the migration was applied (timezone-aware UTC)
        description: Human-readable description of changes
        checksum: Optional hash for integrity verification
    """

    version: str
    applied_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    description: str = ""
    checksum: Optional[str] = None

    def __lt__(self, other: "SchemaVersion") -> bool:
        """Compare versions for sorting."""
        return self._parse_version(self.version) < self._parse_version(other.version)

    @staticmethod
    def _parse_version(version: str) -> tuple:
        """Parse a version string into a totally ordered comparable tuple.

        Each dot-separated part becomes a (tag, value) pair: numeric parts
        are tagged 0 and compared as ints, non-numeric parts are tagged 1
        and compared as strings. The tag guarantees a total order — the
        previous implementation mixed bare ints and strs in one tuple,
        which raises TypeError in Python 3 when comparing versions such as
        "1.0.0" and "1.0.0-beta". Purely numeric versions compare exactly
        as before; numeric parts sort before non-numeric ones.
        """
        parts = []
        for part in version.split("."):
            try:
                parts.append((0, int(part)))
            except ValueError:
                # Non-numeric segment (e.g. "0-beta"): compare as a string,
                # tagged so it never collides with int comparisons.
                parts.append((1, part))
        return tuple(parts)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class Migration(ABC):
    """
    Abstract base class for schema migrations.

    Concrete migrations implement upgrade() and may implement downgrade().

    Example:
        class AddTagsColumn(Migration):
            version = "1.1.0"
            description = "Add tags column to heuristics table"

            def upgrade(self, connection):
                connection.execute(
                    "ALTER TABLE heuristics ADD COLUMN tags TEXT"
                )

            def downgrade(self, connection):
                connection.execute(
                    "ALTER TABLE heuristics DROP COLUMN tags"
                )
    """

    # Subclasses must override these class attributes.
    version: str = ""
    description: str = ""
    # Optional: version that must already be applied before this one.
    depends_on: Optional[str] = None

    @abstractmethod
    def upgrade(self, connection: Any) -> None:
        """
        Apply the migration.

        Args:
            connection: Database connection or storage instance
        """

    def downgrade(self, connection: Any) -> None:
        """
        Revert the migration (optional).

        Args:
            connection: Database connection or storage instance

        Raises:
            NotImplementedError: If downgrade is not supported
        """
        message = f"Downgrade not implemented for migration {self.version}"
        raise NotImplementedError(message)

    def pre_check(self, connection: Any) -> bool:
        """
        Optional pre-migration check; override to verify prerequisites.

        Args:
            connection: Database connection

        Returns:
            True if migration can proceed, False otherwise
        """
        return True

    def post_check(self, connection: Any) -> bool:
        """
        Optional post-migration verification; override to confirm success.

        Args:
            connection: Database connection

        Returns:
            True if migration was successful
        """
        return True
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class MigrationRegistry:
    """
    Registry of available migrations.

    Tracks global migrations plus backend-specific overrides, and answers
    questions about ordering, pending work, and rollback plans.
    """

    def __init__(self) -> None:
        # version -> migration class, shared by all backends
        self._migrations: Dict[str, Type[Migration]] = {}
        # backend name -> (version -> migration class) overrides
        self._backend_migrations: Dict[str, Dict[str, Type[Migration]]] = {}

    def register(
        self, migration_class: Type[Migration], backend: Optional[str] = None
    ) -> Type[Migration]:
        """
        Register a migration class.

        Args:
            migration_class: The migration class to register
            backend: Optional backend name (e.g., "sqlite", "postgresql")

        Returns:
            The migration class (for use as decorator)

        Raises:
            ValueError: If the class has an empty version attribute
        """
        version = migration_class.version
        if not version:
            raise ValueError(
                f"Migration {migration_class.__name__} must have a version"
            )

        if backend:
            self._backend_migrations.setdefault(backend, {})[version] = migration_class
            logger.debug(f"Registered migration {version} for backend {backend}")
        else:
            self._migrations[version] = migration_class
            logger.debug(f"Registered global migration {version}")

        return migration_class

    def get_migration(
        self, version: str, backend: Optional[str] = None
    ) -> Optional[Type[Migration]]:
        """
        Get a migration class by version.

        Backend-specific migrations take precedence over global ones.

        Args:
            version: Version string to look up
            backend: Optional backend name

        Returns:
            Migration class or None if not found
        """
        if backend:
            backend_specific = self._backend_migrations.get(backend, {}).get(version)
            if backend_specific is not None:
                return backend_specific
        return self._migrations.get(version)

    def get_all_migrations(
        self, backend: Optional[str] = None
    ) -> List[Type[Migration]]:
        """
        Get all migrations in version order.

        Args:
            backend: Optional backend name to include backend-specific
                migrations (which override globals of the same version)

        Returns:
            List of migration classes sorted by parsed version
        """
        merged: Dict[str, Type[Migration]] = dict(self._migrations)
        merged.update(self._backend_migrations.get(backend or "", {}))
        ordered_versions = sorted(merged, key=SchemaVersion._parse_version)
        return [merged[v] for v in ordered_versions]

    def get_pending_migrations(
        self,
        current_version: Optional[str],
        backend: Optional[str] = None,
    ) -> List[Type[Migration]]:
        """
        Get migrations that need to be applied.

        Args:
            current_version: Current schema version (None if fresh install)
            backend: Optional backend name

        Returns:
            List of migration classes that need to be applied
        """
        candidates = self.get_all_migrations(backend)
        if current_version is None:
            # Fresh install: everything is pending.
            return candidates

        threshold = SchemaVersion._parse_version(current_version)
        return [
            m
            for m in candidates
            if SchemaVersion._parse_version(m.version) > threshold
        ]

    def get_rollback_migrations(
        self,
        current_version: str,
        target_version: str,
        backend: Optional[str] = None,
    ) -> List[Type[Migration]]:
        """
        Get migrations that need to be rolled back.

        Args:
            current_version: Current schema version
            target_version: Target version to roll back to
            backend: Optional backend name

        Returns:
            List of migration classes to roll back (newest first)
        """
        current = SchemaVersion._parse_version(current_version)
        target = SchemaVersion._parse_version(target_version)

        to_roll_back = []
        for m in self.get_all_migrations(backend):
            parsed = SchemaVersion._parse_version(m.version)
            if target < parsed <= current:
                to_roll_back.append(m)

        # Newest first, so rollbacks unwind in reverse apply order.
        to_roll_back.reverse()
        return to_roll_back
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# Module-level registry shared by the decorator and accessor below.
_global_registry = MigrationRegistry()


def register_migration(
    backend: Optional[str] = None,
) -> Callable[[Type[Migration]], Type[Migration]]:
    """
    Decorator factory that registers a Migration subclass globally.

    Args:
        backend: Optional backend name to scope the migration to

    Example:
        @register_migration()
        class MyMigration(Migration):
            version = "1.0.0"
            ...

        @register_migration(backend="postgresql")
        class PostgresSpecificMigration(Migration):
            version = "1.0.1"
            ...
    """

    def decorator(migration_class: Type[Migration]) -> Type[Migration]:
        return _global_registry.register(migration_class, backend)

    return decorator


def get_registry() -> MigrationRegistry:
    """Return the process-wide migration registry."""
    return _global_registry
|