alma-memory: alma_memory-0.5.0-py3-none-any.whl → alma_memory-0.7.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alma/__init__.py +296 -194
- alma/compression/__init__.py +33 -0
- alma/compression/pipeline.py +980 -0
- alma/confidence/__init__.py +47 -47
- alma/confidence/engine.py +540 -540
- alma/confidence/types.py +351 -351
- alma/config/loader.py +157 -157
- alma/consolidation/__init__.py +23 -23
- alma/consolidation/engine.py +678 -678
- alma/consolidation/prompts.py +84 -84
- alma/core.py +1189 -322
- alma/domains/__init__.py +30 -30
- alma/domains/factory.py +359 -359
- alma/domains/schemas.py +448 -448
- alma/domains/types.py +272 -272
- alma/events/__init__.py +75 -75
- alma/events/emitter.py +285 -284
- alma/events/storage_mixin.py +246 -246
- alma/events/types.py +126 -126
- alma/events/webhook.py +425 -425
- alma/exceptions.py +49 -49
- alma/extraction/__init__.py +31 -31
- alma/extraction/auto_learner.py +265 -264
- alma/extraction/extractor.py +420 -420
- alma/graph/__init__.py +106 -81
- alma/graph/backends/__init__.py +32 -18
- alma/graph/backends/kuzu.py +624 -0
- alma/graph/backends/memgraph.py +432 -0
- alma/graph/backends/memory.py +236 -236
- alma/graph/backends/neo4j.py +417 -417
- alma/graph/base.py +159 -159
- alma/graph/extraction.py +198 -198
- alma/graph/store.py +860 -860
- alma/harness/__init__.py +35 -35
- alma/harness/base.py +386 -386
- alma/harness/domains.py +705 -705
- alma/initializer/__init__.py +37 -37
- alma/initializer/initializer.py +418 -418
- alma/initializer/types.py +250 -250
- alma/integration/__init__.py +62 -62
- alma/integration/claude_agents.py +444 -432
- alma/integration/helena.py +423 -423
- alma/integration/victor.py +471 -471
- alma/learning/__init__.py +101 -86
- alma/learning/decay.py +878 -0
- alma/learning/forgetting.py +1446 -1446
- alma/learning/heuristic_extractor.py +390 -390
- alma/learning/protocols.py +374 -374
- alma/learning/validation.py +346 -346
- alma/mcp/__init__.py +123 -45
- alma/mcp/__main__.py +156 -156
- alma/mcp/resources.py +122 -122
- alma/mcp/server.py +955 -591
- alma/mcp/tools.py +3254 -511
- alma/observability/__init__.py +91 -0
- alma/observability/config.py +302 -0
- alma/observability/guidelines.py +170 -0
- alma/observability/logging.py +424 -0
- alma/observability/metrics.py +583 -0
- alma/observability/tracing.py +440 -0
- alma/progress/__init__.py +21 -21
- alma/progress/tracker.py +607 -607
- alma/progress/types.py +250 -250
- alma/retrieval/__init__.py +134 -53
- alma/retrieval/budget.py +525 -0
- alma/retrieval/cache.py +1304 -1061
- alma/retrieval/embeddings.py +202 -202
- alma/retrieval/engine.py +850 -366
- alma/retrieval/modes.py +365 -0
- alma/retrieval/progressive.py +560 -0
- alma/retrieval/scoring.py +344 -344
- alma/retrieval/trust_scoring.py +637 -0
- alma/retrieval/verification.py +797 -0
- alma/session/__init__.py +19 -19
- alma/session/manager.py +442 -399
- alma/session/types.py +288 -288
- alma/storage/__init__.py +101 -61
- alma/storage/archive.py +233 -0
- alma/storage/azure_cosmos.py +1259 -1048
- alma/storage/base.py +1083 -525
- alma/storage/chroma.py +1443 -1443
- alma/storage/constants.py +103 -0
- alma/storage/file_based.py +614 -619
- alma/storage/migrations/__init__.py +21 -0
- alma/storage/migrations/base.py +321 -0
- alma/storage/migrations/runner.py +323 -0
- alma/storage/migrations/version_stores.py +337 -0
- alma/storage/migrations/versions/__init__.py +11 -0
- alma/storage/migrations/versions/v1_0_0.py +373 -0
- alma/storage/migrations/versions/v1_1_0_workflow_context.py +551 -0
- alma/storage/pinecone.py +1080 -1080
- alma/storage/postgresql.py +1948 -1452
- alma/storage/qdrant.py +1306 -1306
- alma/storage/sqlite_local.py +3041 -1358
- alma/testing/__init__.py +46 -0
- alma/testing/factories.py +301 -0
- alma/testing/mocks.py +389 -0
- alma/types.py +292 -264
- alma/utils/__init__.py +19 -0
- alma/utils/tokenizer.py +521 -0
- alma/workflow/__init__.py +83 -0
- alma/workflow/artifacts.py +170 -0
- alma/workflow/checkpoint.py +311 -0
- alma/workflow/context.py +228 -0
- alma/workflow/outcomes.py +189 -0
- alma/workflow/reducers.py +393 -0
- {alma_memory-0.5.0.dist-info → alma_memory-0.7.0.dist-info}/METADATA +244 -72
- alma_memory-0.7.0.dist-info/RECORD +112 -0
- alma_memory-0.5.0.dist-info/RECORD +0 -76
- {alma_memory-0.5.0.dist-info → alma_memory-0.7.0.dist-info}/WHEEL +0 -0
- {alma_memory-0.5.0.dist-info → alma_memory-0.7.0.dist-info}/top_level.txt +0 -0
alma/workflow/artifacts.py
ADDED
@@ -0,0 +1,170 @@
+"""
+ALMA Workflow Artifacts.
+
+Provides artifact reference dataclass for linking external artifacts
+(files, screenshots, logs, etc.) to memories.
+
+Sprint 1 Task 1.4
+"""
+
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Any, Dict, Optional
+from uuid import uuid4
+
+
+class ArtifactType(Enum):
+    """Types of artifacts that can be linked to memories."""
+
+    # Files and documents
+    FILE = "file"
+    DOCUMENT = "document"
+    IMAGE = "image"
+    VIDEO = "video"
+
+    # Development artifacts
+    SCREENSHOT = "screenshot"
+    LOG = "log"
+    TRACE = "trace"
+    DIFF = "diff"
+
+    # Test artifacts
+    TEST_RESULT = "test_result"
+    COVERAGE_REPORT = "coverage_report"
+
+    # Analysis artifacts
+    REPORT = "report"
+    METRICS = "metrics"
+
+    # Generic
+    OTHER = "other"
+
+
+@dataclass
+class ArtifactRef:
+    """
+    Reference to an external artifact linked to a memory.
+
+    Artifacts are stored externally (e.g., Cloudflare R2, S3, local filesystem)
+    and referenced by URL/path. This allows memories to reference large files
+    without bloating the memory database.
+
+    Attributes:
+        id: Unique artifact reference identifier
+        memory_id: The memory this artifact is linked to
+        artifact_type: Type of artifact (screenshot, log, etc.)
+        storage_url: URL or path to the artifact in storage
+        filename: Original filename
+        mime_type: MIME type of the artifact
+        size_bytes: Size of the artifact in bytes
+        checksum: SHA256 checksum for integrity verification
+        metadata: Additional artifact metadata
+        created_at: When this reference was created
+    """
+
+    id: str = field(default_factory=lambda: str(uuid4()))
+    memory_id: str = ""
+    artifact_type: ArtifactType = ArtifactType.OTHER
+    storage_url: str = ""
+    filename: Optional[str] = None
+    mime_type: Optional[str] = None
+    size_bytes: Optional[int] = None
+    checksum: Optional[str] = None
+    metadata: Dict[str, Any] = field(default_factory=dict)
+    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+    def validate(self) -> None:
+        """
+        Validate the artifact reference.
+
+        Raises:
+            ValueError: If validation fails.
+        """
+        if not self.memory_id:
+            raise ValueError("memory_id is required")
+        if not self.storage_url:
+            raise ValueError("storage_url is required")
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "id": self.id,
+            "memory_id": self.memory_id,
+            "artifact_type": self.artifact_type.value,
+            "storage_url": self.storage_url,
+            "filename": self.filename,
+            "mime_type": self.mime_type,
+            "size_bytes": self.size_bytes,
+            "checksum": self.checksum,
+            "metadata": self.metadata,
+            "created_at": self.created_at.isoformat(),
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ArtifactRef":
+        """Create from dictionary."""
+        created_at = data.get("created_at")
+        if isinstance(created_at, str):
+            created_at = datetime.fromisoformat(created_at)
+        elif created_at is None:
+            created_at = datetime.now(timezone.utc)
+
+        artifact_type = data.get("artifact_type", "other")
+        if isinstance(artifact_type, str):
+            artifact_type = ArtifactType(artifact_type)
+
+        return cls(
+            id=data.get("id", str(uuid4())),
+            memory_id=data.get("memory_id", ""),
+            artifact_type=artifact_type,
+            storage_url=data.get("storage_url", ""),
+            filename=data.get("filename"),
+            mime_type=data.get("mime_type"),
+            size_bytes=data.get("size_bytes"),
+            checksum=data.get("checksum"),
+            metadata=data.get("metadata", {}),
+            created_at=created_at,
+        )
+
+
+def link_artifact(
+    memory_id: str,
+    artifact_type: ArtifactType,
+    storage_url: str,
+    filename: Optional[str] = None,
+    mime_type: Optional[str] = None,
+    size_bytes: Optional[int] = None,
+    checksum: Optional[str] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+) -> ArtifactRef:
+    """
+    Create an artifact reference linked to a memory.
+
+    This is a convenience function for creating ArtifactRef instances.
+
+    Args:
+        memory_id: The memory to link the artifact to.
+        artifact_type: Type of artifact.
+        storage_url: URL or path to the artifact.
+        filename: Original filename.
+        mime_type: MIME type.
+        size_bytes: Size in bytes.
+        checksum: SHA256 checksum.
+        metadata: Additional metadata.
+
+    Returns:
+        A validated ArtifactRef instance.
+    """
+    ref = ArtifactRef(
+        memory_id=memory_id,
+        artifact_type=artifact_type,
+        storage_url=storage_url,
+        filename=filename,
+        mime_type=mime_type,
+        size_bytes=size_bytes,
+        checksum=checksum,
+        metadata=metadata or {},
+    )
+    ref.validate()
+    return ref
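A minimal usage sketch for the new artifacts module, based only on the code above. The module path is inferred from alma/workflow/artifacts.py; the memory id, bucket URL, and size are illustrative values, not taken from the package.

# Illustrative only: link a screenshot stored in object storage to a memory.
from alma.workflow.artifacts import ArtifactRef, ArtifactType, link_artifact

ref = link_artifact(
    memory_id="mem-1234",                                         # hypothetical memory id
    artifact_type=ArtifactType.SCREENSHOT,
    storage_url="s3://example-bucket/runs/42/login-failure.png",  # hypothetical URL
    filename="login-failure.png",
    mime_type="image/png",
    size_bytes=48_213,
)

# Round-trip through the serialization helpers defined above.
restored = ArtifactRef.from_dict(ref.to_dict())
assert restored.artifact_type is ArtifactType.SCREENSHOT
assert restored.storage_url == ref.storage_url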
alma/workflow/checkpoint.py
ADDED
@@ -0,0 +1,311 @@
+"""
+ALMA Workflow Checkpoints.
+
+Provides checkpoint dataclass and CheckpointManager for crash recovery
+and state persistence in workflow orchestration.
+
+Sprint 1 Task 1.3, Sprint 3 Tasks 3.1-3.4
+"""
+
+import hashlib
+import json
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional
+from uuid import uuid4
+
+# Default maximum state size (1MB)
+DEFAULT_MAX_STATE_SIZE = 1024 * 1024
+
+
+@dataclass
+class Checkpoint:
+    """
+    Represents a workflow execution checkpoint.
+
+    Checkpoints enable crash recovery by persisting state at key points
+    during workflow execution. They support parallel execution through
+    branch tracking and parent references.
+
+    Attributes:
+        id: Unique checkpoint identifier
+        run_id: The workflow run this checkpoint belongs to
+        node_id: The node that created this checkpoint
+        state: The serializable state data
+        sequence_number: Ordering within the run (monotonically increasing)
+        branch_id: Identifier for parallel branch (None for main branch)
+        parent_checkpoint_id: Previous checkpoint in the chain
+        state_hash: SHA256 hash for change detection
+        metadata: Additional checkpoint metadata
+        created_at: When this checkpoint was created
+    """
+
+    id: str = field(default_factory=lambda: str(uuid4()))
+    run_id: str = ""
+    node_id: str = ""
+    state: Dict[str, Any] = field(default_factory=dict)
+    sequence_number: int = 0
+    branch_id: Optional[str] = None
+    parent_checkpoint_id: Optional[str] = None
+    state_hash: str = ""
+    metadata: Dict[str, Any] = field(default_factory=dict)
+    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+    def __post_init__(self):
+        """Compute state hash if not provided."""
+        if not self.state_hash and self.state:
+            self.state_hash = self._compute_hash(self.state)
+
+    @staticmethod
+    def _compute_hash(state: Dict[str, Any]) -> str:
+        """Compute SHA256 hash of state for change detection."""
+        # Sort keys for consistent hashing
+        state_json = json.dumps(state, sort_keys=True, default=str)
+        return hashlib.sha256(state_json.encode()).hexdigest()
+
+    def has_changed(self, other_state: Dict[str, Any]) -> bool:
+        """Check if state has changed compared to another state."""
+        other_hash = self._compute_hash(other_state)
+        return self.state_hash != other_hash
+
+    def get_state_size(self) -> int:
+        """Get the size of the state in bytes."""
+        return len(json.dumps(self.state, default=str).encode())
+
+    def validate(self, max_state_size: int = DEFAULT_MAX_STATE_SIZE) -> None:
+        """
+        Validate the checkpoint.
+
+        Args:
+            max_state_size: Maximum allowed state size in bytes.
+
+        Raises:
+            ValueError: If validation fails.
+        """
+        if not self.run_id:
+            raise ValueError("run_id is required")
+        if not self.node_id:
+            raise ValueError("node_id is required")
+        if self.sequence_number < 0:
+            raise ValueError("sequence_number must be non-negative")
+
+        state_size = self.get_state_size()
+        if state_size > max_state_size:
+            raise ValueError(
+                f"State size ({state_size} bytes) exceeds maximum "
+                f"({max_state_size} bytes). Consider storing large data "
+                "in artifact storage and linking via ArtifactRef."
+            )
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "id": self.id,
+            "run_id": self.run_id,
+            "node_id": self.node_id,
+            "state": self.state,
+            "sequence_number": self.sequence_number,
+            "branch_id": self.branch_id,
+            "parent_checkpoint_id": self.parent_checkpoint_id,
+            "state_hash": self.state_hash,
+            "metadata": self.metadata,
+            "created_at": self.created_at.isoformat(),
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "Checkpoint":
+        """Create from dictionary."""
+        created_at = data.get("created_at")
+        if isinstance(created_at, str):
+            created_at = datetime.fromisoformat(created_at)
+        elif created_at is None:
+            created_at = datetime.now(timezone.utc)
+
+        return cls(
+            id=data.get("id", str(uuid4())),
+            run_id=data.get("run_id", ""),
+            node_id=data.get("node_id", ""),
+            state=data.get("state", {}),
+            sequence_number=data.get("sequence_number", 0),
+            branch_id=data.get("branch_id"),
+            parent_checkpoint_id=data.get("parent_checkpoint_id"),
+            state_hash=data.get("state_hash", ""),
+            metadata=data.get("metadata", {}),
+            created_at=created_at,
+        )
+
+
+class CheckpointManager:
+    """
+    Manages checkpoint operations for workflow execution.
+
+    Provides methods for creating, retrieving, and cleaning up checkpoints.
+    Supports concurrent access patterns and branch management.
+
+    Sprint 3 Tasks 3.1-3.4
+    """
+
+    def __init__(
+        self,
+        storage: Any,  # StorageBackend
+        max_state_size: int = DEFAULT_MAX_STATE_SIZE,
+    ):
+        """
+        Initialize the checkpoint manager.
+
+        Args:
+            storage: Storage backend for persisting checkpoints.
+            max_state_size: Maximum allowed state size in bytes.
+        """
+        self._storage = storage
+        self._max_state_size = max_state_size
+        self._sequence_cache: Dict[str, int] = {}  # run_id -> last sequence
+
+    def create_checkpoint(
+        self,
+        run_id: str,
+        node_id: str,
+        state: Dict[str, Any],
+        branch_id: Optional[str] = None,
+        parent_checkpoint_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        skip_if_unchanged: bool = True,
+    ) -> Optional[Checkpoint]:
+        """
+        Create a new checkpoint.
+
+        Args:
+            run_id: The workflow run identifier.
+            node_id: The node creating this checkpoint.
+            state: The state to persist.
+            branch_id: Optional branch identifier for parallel execution.
+            parent_checkpoint_id: Previous checkpoint in the chain.
+            metadata: Additional checkpoint metadata.
+            skip_if_unchanged: If True, skip creating checkpoint if state
+                hasn't changed from the last checkpoint.
+
+        Returns:
+            The created Checkpoint, or None if skipped due to no changes.
+
+        Raises:
+            ValueError: If state exceeds max_state_size.
+        """
+        # Get next sequence number
+        sequence_number = self._get_next_sequence(run_id)
+
+        # Create checkpoint
+        checkpoint = Checkpoint(
+            run_id=run_id,
+            node_id=node_id,
+            state=state,
+            sequence_number=sequence_number,
+            branch_id=branch_id,
+            parent_checkpoint_id=parent_checkpoint_id,
+            metadata=metadata or {},
+        )
+
+        # Validate
+        checkpoint.validate(self._max_state_size)
+
+        # Check if state has changed (optional optimization)
+        if skip_if_unchanged and parent_checkpoint_id:
+            parent = self.get_checkpoint(parent_checkpoint_id)
+            if parent and not parent.has_changed(state):
+                return None
+
+        # Persist
+        self._storage.save_checkpoint(checkpoint)
+
+        # Update sequence cache
+        self._sequence_cache[run_id] = sequence_number
+
+        return checkpoint
+
+    def get_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]:
+        """Get a checkpoint by ID."""
+        return self._storage.get_checkpoint(checkpoint_id)
+
+    def get_latest_checkpoint(
+        self,
+        run_id: str,
+        branch_id: Optional[str] = None,
+    ) -> Optional[Checkpoint]:
+        """
+        Get the most recent checkpoint for a run.
+
+        Args:
+            run_id: The workflow run identifier.
+            branch_id: Optional branch to filter by.
+
+        Returns:
+            The latest checkpoint, or None if no checkpoints exist.
+        """
+        return self._storage.get_latest_checkpoint(run_id, branch_id)
+
+    def get_resume_point(self, run_id: str) -> Optional[Checkpoint]:
+        """
+        Get the checkpoint to resume from after a crash.
+
+        This is an alias for get_latest_checkpoint for clarity.
+
+        Args:
+            run_id: The workflow run identifier.
+
+        Returns:
+            The checkpoint to resume from, or None if no checkpoints.
+        """
+        return self.get_latest_checkpoint(run_id)
+
+    def get_branch_checkpoints(
+        self,
+        run_id: str,
+        branch_ids: List[str],
+    ) -> Dict[str, Checkpoint]:
+        """
+        Get the latest checkpoint for each branch.
+
+        Used for parallel merge operations.
+
+        Args:
+            run_id: The workflow run identifier.
+            branch_ids: List of branch identifiers.
+
+        Returns:
+            Dictionary mapping branch_id to its latest checkpoint.
+        """
+        result = {}
+        for branch_id in branch_ids:
+            checkpoint = self.get_latest_checkpoint(run_id, branch_id)
+            if checkpoint:
+                result[branch_id] = checkpoint
+        return result
+
+    def cleanup_checkpoints(
+        self,
+        run_id: str,
+        keep_latest: int = 1,
+    ) -> int:
+        """
+        Clean up old checkpoints for a completed run.
+
+        Args:
+            run_id: The workflow run identifier.
+            keep_latest: Number of latest checkpoints to keep.
+
+        Returns:
+            Number of checkpoints deleted.
+        """
+        return self._storage.cleanup_checkpoints(run_id, keep_latest)
+
+    def _get_next_sequence(self, run_id: str) -> int:
+        """Get the next sequence number for a run."""
+        if run_id in self._sequence_cache:
+            return self._sequence_cache[run_id] + 1
+
+        # Query storage for latest sequence
+        latest = self.get_latest_checkpoint(run_id)
+        if latest:
+            self._sequence_cache[run_id] = latest.sequence_number
+            return latest.sequence_number + 1
+
+        return 0
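CheckpointManager only assumes the storage backend exposes save_checkpoint, get_checkpoint, get_latest_checkpoint, and cleanup_checkpoints. A minimal sketch with a hypothetical in-memory stand-in for that backend (the real backends live under alma/storage/), showing checkpoint creation, the skip-if-unchanged optimization, and crash resume; all names below besides the checkpoint module itself are illustrative.

# Hypothetical stand-in implementing only the methods CheckpointManager calls.
from typing import Dict, List, Optional

from alma.workflow.checkpoint import Checkpoint, CheckpointManager


class InMemoryCheckpointStore:
    def __init__(self) -> None:
        self._by_id: Dict[str, Checkpoint] = {}

    def save_checkpoint(self, checkpoint: Checkpoint) -> None:
        self._by_id[checkpoint.id] = checkpoint

    def get_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]:
        return self._by_id.get(checkpoint_id)

    def get_latest_checkpoint(
        self, run_id: str, branch_id: Optional[str] = None
    ) -> Optional[Checkpoint]:
        # In this sketch, branch_id=None matches any branch.
        candidates = [
            cp for cp in self._by_id.values()
            if cp.run_id == run_id and (branch_id is None or cp.branch_id == branch_id)
        ]
        return max(candidates, key=lambda cp: cp.sequence_number, default=None)

    def cleanup_checkpoints(self, run_id: str, keep_latest: int = 1) -> int:
        run_cps = sorted(
            (cp for cp in self._by_id.values() if cp.run_id == run_id),
            key=lambda cp: cp.sequence_number,
        )
        stale = run_cps[:-keep_latest] if keep_latest else run_cps
        for cp in stale:
            del self._by_id[cp.id]
        return len(stale)


manager = CheckpointManager(storage=InMemoryCheckpointStore())

cp1 = manager.create_checkpoint(run_id="run-1", node_id="plan", state={"step": 1})
cp2 = manager.create_checkpoint(
    run_id="run-1",
    node_id="plan",
    state={"step": 1},               # identical state
    parent_checkpoint_id=cp1.id,     # -> skipped by the hash comparison
)
assert cp2 is None

resume = manager.get_resume_point("run-1")  # after a simulated crash
assert resume is not None and resume.id == cp1.id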
alma/workflow/context.py
ADDED
@@ -0,0 +1,228 @@
+"""
+ALMA Workflow Context.
+
+Defines the WorkflowContext dataclass and RetrievalScope enum for
+scoped memory retrieval in workflow orchestration systems.
+
+Sprint 1 Tasks 1.1, 1.2
+"""
+
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Any, Dict, Optional
+
+
+class RetrievalScope(Enum):
+    """
+    Defines the scope for memory retrieval operations.
+
+    The hierarchy from most specific to most general:
+    NODE -> RUN -> WORKFLOW -> AGENT -> TENANT -> GLOBAL
+
+    Note: Named RetrievalScope (not MemoryScope) to avoid collision
+    with the existing MemoryScope dataclass in alma/types.py which
+    defines what an agent is *allowed* to learn, not *where* to search.
+    """
+
+    # Most specific - only memories from this specific node execution
+    NODE = "node"
+
+    # Memories from this specific workflow run
+    RUN = "run"
+
+    # Memories from all runs of this workflow definition
+    WORKFLOW = "workflow"
+
+    # Memories from all workflows for this agent (default)
+    AGENT = "agent"
+
+    # Memories from all agents within this tenant
+    TENANT = "tenant"
+
+    # All memories across all tenants (admin only)
+    GLOBAL = "global"
+
+    @classmethod
+    def from_string(cls, value: str) -> "RetrievalScope":
+        """Convert string to RetrievalScope enum."""
+        try:
+            return cls(value.lower())
+        except ValueError as err:
+            raise ValueError(
+                f"Invalid RetrievalScope: '{value}'. "
+                f"Valid options: {[s.value for s in cls]}"
+            ) from err
+
+    def is_broader_than(self, other: "RetrievalScope") -> bool:
+        """Check if this scope is broader than another scope."""
+        hierarchy = [
+            RetrievalScope.NODE,
+            RetrievalScope.RUN,
+            RetrievalScope.WORKFLOW,
+            RetrievalScope.AGENT,
+            RetrievalScope.TENANT,
+            RetrievalScope.GLOBAL,
+        ]
+        return hierarchy.index(self) > hierarchy.index(other)
+
+
+@dataclass
+class WorkflowContext:
+    """
+    Context for workflow-scoped memory operations.
+
+    Provides hierarchical scoping for AGtestari and similar workflow
+    orchestration systems. All fields are optional except when
+    require_tenant=True is passed to validate().
+
+    Attributes:
+        tenant_id: Multi-tenant isolation identifier
+        workflow_id: Workflow definition identifier
+        run_id: Specific workflow execution identifier
+        node_id: Current node within the workflow
+        branch_id: Parallel branch identifier (for fan-out patterns)
+        metadata: Additional context data
+        created_at: When this context was created
+    """
+
+    tenant_id: Optional[str] = None
+    workflow_id: Optional[str] = None
+    run_id: Optional[str] = None
+    node_id: Optional[str] = None
+    branch_id: Optional[str] = None
+    metadata: Dict[str, Any] = field(default_factory=dict)
+    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+    def validate(self, require_tenant: bool = False) -> None:
+        """
+        Validate the workflow context.
+
+        Args:
+            require_tenant: If True, tenant_id must be provided.
+                Use for multi-tenant deployments.
+
+        Raises:
+            ValueError: If validation fails.
+        """
+        if require_tenant and not self.tenant_id:
+            raise ValueError(
+                "tenant_id is required for multi-tenant deployments. "
+                "Set require_tenant=False for single-tenant mode."
+            )
+
+        # If node_id is provided, run_id should also be provided
+        if self.node_id and not self.run_id:
+            raise ValueError("node_id requires run_id to be set")
+
+        # If run_id is provided, workflow_id should also be provided
+        if self.run_id and not self.workflow_id:
+            raise ValueError("run_id requires workflow_id to be set")
+
+        # If branch_id is provided, run_id should also be provided
+        if self.branch_id and not self.run_id:
+            raise ValueError("branch_id requires run_id to be set")
+
+    def get_scope_filter(self, scope: RetrievalScope) -> Dict[str, Any]:
+        """
+        Build a filter dict for the given retrieval scope.
+
+        Returns a dictionary that can be used to filter memories
+        based on the workflow context and requested scope.
+
+        Args:
+            scope: The retrieval scope to filter by.
+
+        Returns:
+            Dictionary with filter criteria.
+        """
+        filters: Dict[str, Any] = {}
+
+        if scope == RetrievalScope.GLOBAL:
+            # No filters - return everything
+            pass
+        elif scope == RetrievalScope.TENANT:
+            if self.tenant_id:
+                filters["tenant_id"] = self.tenant_id
+        elif scope == RetrievalScope.AGENT:
+            if self.tenant_id:
+                filters["tenant_id"] = self.tenant_id
+            # Agent filtering is done separately via the agent parameter
+        elif scope == RetrievalScope.WORKFLOW:
+            if self.tenant_id:
+                filters["tenant_id"] = self.tenant_id
+            if self.workflow_id:
+                filters["workflow_id"] = self.workflow_id
+        elif scope == RetrievalScope.RUN:
+            if self.tenant_id:
+                filters["tenant_id"] = self.tenant_id
+            if self.workflow_id:
+                filters["workflow_id"] = self.workflow_id
+            if self.run_id:
+                filters["run_id"] = self.run_id
+        elif scope == RetrievalScope.NODE:
+            if self.tenant_id:
+                filters["tenant_id"] = self.tenant_id
+            if self.workflow_id:
+                filters["workflow_id"] = self.workflow_id
+            if self.run_id:
+                filters["run_id"] = self.run_id
+            if self.node_id:
+                filters["node_id"] = self.node_id
+
+        return filters
+
+    def with_node(self, node_id: str) -> "WorkflowContext":
+        """Create a new context with a different node_id."""
+        return WorkflowContext(
+            tenant_id=self.tenant_id,
+            workflow_id=self.workflow_id,
+            run_id=self.run_id,
+            node_id=node_id,
+            branch_id=self.branch_id,
+            metadata=self.metadata.copy(),
+            created_at=self.created_at,
+        )
+
+    def with_branch(self, branch_id: str) -> "WorkflowContext":
+        """Create a new context for a parallel branch."""
+        return WorkflowContext(
+            tenant_id=self.tenant_id,
+            workflow_id=self.workflow_id,
+            run_id=self.run_id,
+            node_id=self.node_id,
+            branch_id=branch_id,
+            metadata=self.metadata.copy(),
+            created_at=self.created_at,
+        )
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "tenant_id": self.tenant_id,
+            "workflow_id": self.workflow_id,
+            "run_id": self.run_id,
+            "node_id": self.node_id,
+            "branch_id": self.branch_id,
+            "metadata": self.metadata,
+            "created_at": self.created_at.isoformat(),
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "WorkflowContext":
+        """Create from dictionary."""
+        created_at = data.get("created_at")
+        if isinstance(created_at, str):
+            created_at = datetime.fromisoformat(created_at)
+        elif created_at is None:
+            created_at = datetime.now(timezone.utc)
+
+        return cls(
+            tenant_id=data.get("tenant_id"),
+            workflow_id=data.get("workflow_id"),
+            run_id=data.get("run_id"),
+            node_id=data.get("node_id"),
+            branch_id=data.get("branch_id"),
+            metadata=data.get("metadata", {}),
+            created_at=created_at,
+        )
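A minimal scoping sketch based on the context module above: one run-level context is narrowed per node and turned into the filter dict that a retrieval layer would consume. The tenant, workflow, run, and node identifiers are illustrative.

# Illustrative only: derive per-node retrieval filters from a run-level context.
from alma.workflow.context import RetrievalScope, WorkflowContext

ctx = WorkflowContext(
    tenant_id="acme",              # hypothetical identifiers
    workflow_id="nightly-e2e",
    run_id="run-2024-06-01",
)
ctx.validate(require_tenant=True)

node_ctx = ctx.with_node("login-check")
filters = node_ctx.get_scope_filter(RetrievalScope.NODE)
# -> {"tenant_id": "acme", "workflow_id": "nightly-e2e",
#     "run_id": "run-2024-06-01", "node_id": "login-check"}

assert RetrievalScope.TENANT.is_broader_than(RetrievalScope.RUN)
assert RetrievalScope.from_string("workflow") is RetrievalScope.WORKFLOW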