alma-memory 0.5.1-py3-none-any.whl → 0.7.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. alma/__init__.py +296 -226
  2. alma/compression/__init__.py +33 -0
  3. alma/compression/pipeline.py +980 -0
  4. alma/confidence/__init__.py +47 -47
  5. alma/confidence/engine.py +540 -540
  6. alma/confidence/types.py +351 -351
  7. alma/config/loader.py +157 -157
  8. alma/consolidation/__init__.py +23 -23
  9. alma/consolidation/engine.py +678 -678
  10. alma/consolidation/prompts.py +84 -84
  11. alma/core.py +1189 -430
  12. alma/domains/__init__.py +30 -30
  13. alma/domains/factory.py +359 -359
  14. alma/domains/schemas.py +448 -448
  15. alma/domains/types.py +272 -272
  16. alma/events/__init__.py +75 -75
  17. alma/events/emitter.py +285 -284
  18. alma/events/storage_mixin.py +246 -246
  19. alma/events/types.py +126 -126
  20. alma/events/webhook.py +425 -425
  21. alma/exceptions.py +49 -49
  22. alma/extraction/__init__.py +31 -31
  23. alma/extraction/auto_learner.py +265 -265
  24. alma/extraction/extractor.py +420 -420
  25. alma/graph/__init__.py +106 -106
  26. alma/graph/backends/__init__.py +32 -32
  27. alma/graph/backends/kuzu.py +624 -624
  28. alma/graph/backends/memgraph.py +432 -432
  29. alma/graph/backends/memory.py +236 -236
  30. alma/graph/backends/neo4j.py +417 -417
  31. alma/graph/base.py +159 -159
  32. alma/graph/extraction.py +198 -198
  33. alma/graph/store.py +860 -860
  34. alma/harness/__init__.py +35 -35
  35. alma/harness/base.py +386 -386
  36. alma/harness/domains.py +705 -705
  37. alma/initializer/__init__.py +37 -37
  38. alma/initializer/initializer.py +418 -418
  39. alma/initializer/types.py +250 -250
  40. alma/integration/__init__.py +62 -62
  41. alma/integration/claude_agents.py +444 -444
  42. alma/integration/helena.py +423 -423
  43. alma/integration/victor.py +471 -471
  44. alma/learning/__init__.py +101 -86
  45. alma/learning/decay.py +878 -0
  46. alma/learning/forgetting.py +1446 -1446
  47. alma/learning/heuristic_extractor.py +390 -390
  48. alma/learning/protocols.py +374 -374
  49. alma/learning/validation.py +346 -346
  50. alma/mcp/__init__.py +123 -45
  51. alma/mcp/__main__.py +156 -156
  52. alma/mcp/resources.py +122 -122
  53. alma/mcp/server.py +955 -591
  54. alma/mcp/tools.py +3254 -509
  55. alma/observability/__init__.py +91 -84
  56. alma/observability/config.py +302 -302
  57. alma/observability/guidelines.py +170 -0
  58. alma/observability/logging.py +424 -424
  59. alma/observability/metrics.py +583 -583
  60. alma/observability/tracing.py +440 -440
  61. alma/progress/__init__.py +21 -21
  62. alma/progress/tracker.py +607 -607
  63. alma/progress/types.py +250 -250
  64. alma/retrieval/__init__.py +134 -53
  65. alma/retrieval/budget.py +525 -0
  66. alma/retrieval/cache.py +1304 -1061
  67. alma/retrieval/embeddings.py +202 -202
  68. alma/retrieval/engine.py +850 -427
  69. alma/retrieval/modes.py +365 -0
  70. alma/retrieval/progressive.py +560 -0
  71. alma/retrieval/scoring.py +344 -344
  72. alma/retrieval/trust_scoring.py +637 -0
  73. alma/retrieval/verification.py +797 -0
  74. alma/session/__init__.py +19 -19
  75. alma/session/manager.py +442 -399
  76. alma/session/types.py +288 -288
  77. alma/storage/__init__.py +101 -90
  78. alma/storage/archive.py +233 -0
  79. alma/storage/azure_cosmos.py +1259 -1259
  80. alma/storage/base.py +1083 -583
  81. alma/storage/chroma.py +1443 -1443
  82. alma/storage/constants.py +103 -103
  83. alma/storage/file_based.py +614 -614
  84. alma/storage/migrations/__init__.py +21 -21
  85. alma/storage/migrations/base.py +321 -321
  86. alma/storage/migrations/runner.py +323 -323
  87. alma/storage/migrations/version_stores.py +337 -337
  88. alma/storage/migrations/versions/__init__.py +11 -11
  89. alma/storage/migrations/versions/v1_0_0.py +373 -373
  90. alma/storage/migrations/versions/v1_1_0_workflow_context.py +551 -0
  91. alma/storage/pinecone.py +1080 -1080
  92. alma/storage/postgresql.py +1948 -1559
  93. alma/storage/qdrant.py +1306 -1306
  94. alma/storage/sqlite_local.py +3041 -1457
  95. alma/testing/__init__.py +46 -46
  96. alma/testing/factories.py +301 -301
  97. alma/testing/mocks.py +389 -389
  98. alma/types.py +292 -264
  99. alma/utils/__init__.py +19 -0
  100. alma/utils/tokenizer.py +521 -0
  101. alma/workflow/__init__.py +83 -0
  102. alma/workflow/artifacts.py +170 -0
  103. alma/workflow/checkpoint.py +311 -0
  104. alma/workflow/context.py +228 -0
  105. alma/workflow/outcomes.py +189 -0
  106. alma/workflow/reducers.py +393 -0
  107. {alma_memory-0.5.1.dist-info → alma_memory-0.7.0.dist-info}/METADATA +210 -72
  108. alma_memory-0.7.0.dist-info/RECORD +112 -0
  109. alma_memory-0.5.1.dist-info/RECORD +0 -93
  110. {alma_memory-0.5.1.dist-info → alma_memory-0.7.0.dist-info}/WHEEL +0 -0
  111. {alma_memory-0.5.1.dist-info → alma_memory-0.7.0.dist-info}/top_level.txt +0 -0
alma/workflow/artifacts.py (new file)
@@ -0,0 +1,170 @@
+"""
+ALMA Workflow Artifacts.
+
+Provides artifact reference dataclass for linking external artifacts
+(files, screenshots, logs, etc.) to memories.
+
+Sprint 1 Task 1.4
+"""
+
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Any, Dict, Optional
+from uuid import uuid4
+
+
+class ArtifactType(Enum):
+    """Types of artifacts that can be linked to memories."""
+
+    # Files and documents
+    FILE = "file"
+    DOCUMENT = "document"
+    IMAGE = "image"
+    VIDEO = "video"
+
+    # Development artifacts
+    SCREENSHOT = "screenshot"
+    LOG = "log"
+    TRACE = "trace"
+    DIFF = "diff"
+
+    # Test artifacts
+    TEST_RESULT = "test_result"
+    COVERAGE_REPORT = "coverage_report"
+
+    # Analysis artifacts
+    REPORT = "report"
+    METRICS = "metrics"
+
+    # Generic
+    OTHER = "other"
+
+
+@dataclass
+class ArtifactRef:
+    """
+    Reference to an external artifact linked to a memory.
+
+    Artifacts are stored externally (e.g., Cloudflare R2, S3, local filesystem)
+    and referenced by URL/path. This allows memories to reference large files
+    without bloating the memory database.
+
+    Attributes:
+        id: Unique artifact reference identifier
+        memory_id: The memory this artifact is linked to
+        artifact_type: Type of artifact (screenshot, log, etc.)
+        storage_url: URL or path to the artifact in storage
+        filename: Original filename
+        mime_type: MIME type of the artifact
+        size_bytes: Size of the artifact in bytes
+        checksum: SHA256 checksum for integrity verification
+        metadata: Additional artifact metadata
+        created_at: When this reference was created
+    """
+
+    id: str = field(default_factory=lambda: str(uuid4()))
+    memory_id: str = ""
+    artifact_type: ArtifactType = ArtifactType.OTHER
+    storage_url: str = ""
+    filename: Optional[str] = None
+    mime_type: Optional[str] = None
+    size_bytes: Optional[int] = None
+    checksum: Optional[str] = None
+    metadata: Dict[str, Any] = field(default_factory=dict)
+    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+    def validate(self) -> None:
+        """
+        Validate the artifact reference.
+
+        Raises:
+            ValueError: If validation fails.
+        """
+        if not self.memory_id:
+            raise ValueError("memory_id is required")
+        if not self.storage_url:
+            raise ValueError("storage_url is required")
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "id": self.id,
+            "memory_id": self.memory_id,
+            "artifact_type": self.artifact_type.value,
+            "storage_url": self.storage_url,
+            "filename": self.filename,
+            "mime_type": self.mime_type,
+            "size_bytes": self.size_bytes,
+            "checksum": self.checksum,
+            "metadata": self.metadata,
+            "created_at": self.created_at.isoformat(),
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ArtifactRef":
+        """Create from dictionary."""
+        created_at = data.get("created_at")
+        if isinstance(created_at, str):
+            created_at = datetime.fromisoformat(created_at)
+        elif created_at is None:
+            created_at = datetime.now(timezone.utc)
+
+        artifact_type = data.get("artifact_type", "other")
+        if isinstance(artifact_type, str):
+            artifact_type = ArtifactType(artifact_type)
+
+        return cls(
+            id=data.get("id", str(uuid4())),
+            memory_id=data.get("memory_id", ""),
+            artifact_type=artifact_type,
+            storage_url=data.get("storage_url", ""),
+            filename=data.get("filename"),
+            mime_type=data.get("mime_type"),
+            size_bytes=data.get("size_bytes"),
+            checksum=data.get("checksum"),
+            metadata=data.get("metadata", {}),
+            created_at=created_at,
+        )
+
+
+def link_artifact(
+    memory_id: str,
+    artifact_type: ArtifactType,
+    storage_url: str,
+    filename: Optional[str] = None,
+    mime_type: Optional[str] = None,
+    size_bytes: Optional[int] = None,
+    checksum: Optional[str] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+) -> ArtifactRef:
+    """
+    Create an artifact reference linked to a memory.
+
+    This is a convenience function for creating ArtifactRef instances.
+
+    Args:
+        memory_id: The memory to link the artifact to.
+        artifact_type: Type of artifact.
+        storage_url: URL or path to the artifact.
+        filename: Original filename.
+        mime_type: MIME type.
+        size_bytes: Size in bytes.
+        checksum: SHA256 checksum.
+        metadata: Additional metadata.
+
+    Returns:
+        A validated ArtifactRef instance.
+    """
+    ref = ArtifactRef(
+        memory_id=memory_id,
+        artifact_type=artifact_type,
+        storage_url=storage_url,
+        filename=filename,
+        mime_type=mime_type,
+        size_bytes=size_bytes,
+        checksum=checksum,
+        metadata=metadata or {},
+    )
+    ref.validate()
+    return ref
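
For orientation, a minimal usage sketch of the new artifacts module as added above; the memory ID and storage URL are hypothetical placeholder values:

from alma.workflow.artifacts import ArtifactRef, ArtifactType, link_artifact

# Link an externally stored screenshot to an existing memory.
# "mem-1234" and the S3 URL are placeholders for illustration.
ref = link_artifact(
    memory_id="mem-1234",
    artifact_type=ArtifactType.SCREENSHOT,
    storage_url="s3://bucket/screenshots/login-failure.png",
    filename="login-failure.png",
    mime_type="image/png",
)

# to_dict()/from_dict() round-trip, as a storage backend would use it:
# the enum serializes to its string value and back.
restored = ArtifactRef.from_dict(ref.to_dict())
assert restored.artifact_type is ArtifactType.SCREENSHOT
assert restored.created_at == ref.created_at
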
alma/workflow/checkpoint.py (new file)
@@ -0,0 +1,311 @@
+"""
+ALMA Workflow Checkpoints.
+
+Provides checkpoint dataclass and CheckpointManager for crash recovery
+and state persistence in workflow orchestration.
+
+Sprint 1 Task 1.3, Sprint 3 Tasks 3.1-3.4
+"""
+
+import hashlib
+import json
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional
+from uuid import uuid4
+
+# Default maximum state size (1MB)
+DEFAULT_MAX_STATE_SIZE = 1024 * 1024
+
+
+@dataclass
+class Checkpoint:
+    """
+    Represents a workflow execution checkpoint.
+
+    Checkpoints enable crash recovery by persisting state at key points
+    during workflow execution. They support parallel execution through
+    branch tracking and parent references.
+
+    Attributes:
+        id: Unique checkpoint identifier
+        run_id: The workflow run this checkpoint belongs to
+        node_id: The node that created this checkpoint
+        state: The serializable state data
+        sequence_number: Ordering within the run (monotonically increasing)
+        branch_id: Identifier for parallel branch (None for main branch)
+        parent_checkpoint_id: Previous checkpoint in the chain
+        state_hash: SHA256 hash for change detection
+        metadata: Additional checkpoint metadata
+        created_at: When this checkpoint was created
+    """
+
+    id: str = field(default_factory=lambda: str(uuid4()))
+    run_id: str = ""
+    node_id: str = ""
+    state: Dict[str, Any] = field(default_factory=dict)
+    sequence_number: int = 0
+    branch_id: Optional[str] = None
+    parent_checkpoint_id: Optional[str] = None
+    state_hash: str = ""
+    metadata: Dict[str, Any] = field(default_factory=dict)
+    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+    def __post_init__(self):
+        """Compute state hash if not provided."""
+        if not self.state_hash and self.state:
+            self.state_hash = self._compute_hash(self.state)
+
+    @staticmethod
+    def _compute_hash(state: Dict[str, Any]) -> str:
+        """Compute SHA256 hash of state for change detection."""
+        # Sort keys for consistent hashing
+        state_json = json.dumps(state, sort_keys=True, default=str)
+        return hashlib.sha256(state_json.encode()).hexdigest()
+
+    def has_changed(self, other_state: Dict[str, Any]) -> bool:
+        """Check if state has changed compared to another state."""
+        other_hash = self._compute_hash(other_state)
+        return self.state_hash != other_hash
+
+    def get_state_size(self) -> int:
+        """Get the size of the state in bytes."""
+        return len(json.dumps(self.state, default=str).encode())
+
+    def validate(self, max_state_size: int = DEFAULT_MAX_STATE_SIZE) -> None:
+        """
+        Validate the checkpoint.
+
+        Args:
+            max_state_size: Maximum allowed state size in bytes.
+
+        Raises:
+            ValueError: If validation fails.
+        """
+        if not self.run_id:
+            raise ValueError("run_id is required")
+        if not self.node_id:
+            raise ValueError("node_id is required")
+        if self.sequence_number < 0:
+            raise ValueError("sequence_number must be non-negative")
+
+        state_size = self.get_state_size()
+        if state_size > max_state_size:
+            raise ValueError(
+                f"State size ({state_size} bytes) exceeds maximum "
+                f"({max_state_size} bytes). Consider storing large data "
+                "in artifact storage and linking via ArtifactRef."
+            )
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "id": self.id,
+            "run_id": self.run_id,
+            "node_id": self.node_id,
+            "state": self.state,
+            "sequence_number": self.sequence_number,
+            "branch_id": self.branch_id,
+            "parent_checkpoint_id": self.parent_checkpoint_id,
+            "state_hash": self.state_hash,
+            "metadata": self.metadata,
+            "created_at": self.created_at.isoformat(),
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "Checkpoint":
+        """Create from dictionary."""
+        created_at = data.get("created_at")
+        if isinstance(created_at, str):
+            created_at = datetime.fromisoformat(created_at)
+        elif created_at is None:
+            created_at = datetime.now(timezone.utc)
+
+        return cls(
+            id=data.get("id", str(uuid4())),
+            run_id=data.get("run_id", ""),
+            node_id=data.get("node_id", ""),
+            state=data.get("state", {}),
+            sequence_number=data.get("sequence_number", 0),
+            branch_id=data.get("branch_id"),
+            parent_checkpoint_id=data.get("parent_checkpoint_id"),
+            state_hash=data.get("state_hash", ""),
+            metadata=data.get("metadata", {}),
+            created_at=created_at,
+        )
+
+
+class CheckpointManager:
+    """
+    Manages checkpoint operations for workflow execution.
+
+    Provides methods for creating, retrieving, and cleaning up checkpoints.
+    Supports concurrent access patterns and branch management.
+
+    Sprint 3 Tasks 3.1-3.4
+    """
+
+    def __init__(
+        self,
+        storage: Any,  # StorageBackend
+        max_state_size: int = DEFAULT_MAX_STATE_SIZE,
+    ):
+        """
+        Initialize the checkpoint manager.
+
+        Args:
+            storage: Storage backend for persisting checkpoints.
+            max_state_size: Maximum allowed state size in bytes.
+        """
+        self._storage = storage
+        self._max_state_size = max_state_size
+        self._sequence_cache: Dict[str, int] = {}  # run_id -> last sequence
+
+    def create_checkpoint(
+        self,
+        run_id: str,
+        node_id: str,
+        state: Dict[str, Any],
+        branch_id: Optional[str] = None,
+        parent_checkpoint_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        skip_if_unchanged: bool = True,
+    ) -> Optional[Checkpoint]:
+        """
+        Create a new checkpoint.
+
+        Args:
+            run_id: The workflow run identifier.
+            node_id: The node creating this checkpoint.
+            state: The state to persist.
+            branch_id: Optional branch identifier for parallel execution.
+            parent_checkpoint_id: Previous checkpoint in the chain.
+            metadata: Additional checkpoint metadata.
+            skip_if_unchanged: If True, skip creating checkpoint if state
+                hasn't changed from the last checkpoint.
+
+        Returns:
+            The created Checkpoint, or None if skipped due to no changes.
+
+        Raises:
+            ValueError: If state exceeds max_state_size.
+        """
+        # Get next sequence number
+        sequence_number = self._get_next_sequence(run_id)
+
+        # Create checkpoint
+        checkpoint = Checkpoint(
+            run_id=run_id,
+            node_id=node_id,
+            state=state,
+            sequence_number=sequence_number,
+            branch_id=branch_id,
+            parent_checkpoint_id=parent_checkpoint_id,
+            metadata=metadata or {},
+        )
+
+        # Validate
+        checkpoint.validate(self._max_state_size)
+
+        # Check if state has changed (optional optimization)
+        if skip_if_unchanged and parent_checkpoint_id:
+            parent = self.get_checkpoint(parent_checkpoint_id)
+            if parent and not parent.has_changed(state):
+                return None
+
+        # Persist
+        self._storage.save_checkpoint(checkpoint)
+
+        # Update sequence cache
+        self._sequence_cache[run_id] = sequence_number
+
+        return checkpoint
+
+    def get_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]:
+        """Get a checkpoint by ID."""
+        return self._storage.get_checkpoint(checkpoint_id)
+
+    def get_latest_checkpoint(
+        self,
+        run_id: str,
+        branch_id: Optional[str] = None,
+    ) -> Optional[Checkpoint]:
+        """
+        Get the most recent checkpoint for a run.
+
+        Args:
+            run_id: The workflow run identifier.
+            branch_id: Optional branch to filter by.
+
+        Returns:
+            The latest checkpoint, or None if no checkpoints exist.
+        """
+        return self._storage.get_latest_checkpoint(run_id, branch_id)
+
+    def get_resume_point(self, run_id: str) -> Optional[Checkpoint]:
+        """
+        Get the checkpoint to resume from after a crash.
+
+        This is an alias for get_latest_checkpoint for clarity.
+
+        Args:
+            run_id: The workflow run identifier.
+
+        Returns:
+            The checkpoint to resume from, or None if no checkpoints.
+        """
+        return self.get_latest_checkpoint(run_id)
+
+    def get_branch_checkpoints(
+        self,
+        run_id: str,
+        branch_ids: List[str],
+    ) -> Dict[str, Checkpoint]:
+        """
+        Get the latest checkpoint for each branch.
+
+        Used for parallel merge operations.
+
+        Args:
+            run_id: The workflow run identifier.
+            branch_ids: List of branch identifiers.
+
+        Returns:
+            Dictionary mapping branch_id to its latest checkpoint.
+        """
+        result = {}
+        for branch_id in branch_ids:
+            checkpoint = self.get_latest_checkpoint(run_id, branch_id)
+            if checkpoint:
+                result[branch_id] = checkpoint
+        return result
+
+    def cleanup_checkpoints(
+        self,
+        run_id: str,
+        keep_latest: int = 1,
+    ) -> int:
+        """
+        Clean up old checkpoints for a completed run.
+
+        Args:
+            run_id: The workflow run identifier.
+            keep_latest: Number of latest checkpoints to keep.
+
+        Returns:
+            Number of checkpoints deleted.
+        """
+        return self._storage.cleanup_checkpoints(run_id, keep_latest)
+
+    def _get_next_sequence(self, run_id: str) -> int:
+        """Get the next sequence number for a run."""
+        if run_id in self._sequence_cache:
+            return self._sequence_cache[run_id] + 1
+
+        # Query storage for latest sequence
+        latest = self.get_latest_checkpoint(run_id)
+        if latest:
+            self._sequence_cache[run_id] = latest.sequence_number
+            return latest.sequence_number + 1
+
+        return 0
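
To make the manager's storage contract concrete, here is a minimal sketch against a hypothetical in-memory store. The store class is invented for illustration and implements only the four methods CheckpointManager actually calls above (save_checkpoint, get_checkpoint, get_latest_checkpoint, cleanup_checkpoints); the real StorageBackend implementations live under alma/storage/.

from typing import Optional

from alma.workflow.checkpoint import Checkpoint, CheckpointManager

class InMemoryCheckpointStore:
    """Hypothetical stand-in for a real StorageBackend."""

    def __init__(self) -> None:
        self._by_id: dict = {}

    def save_checkpoint(self, checkpoint: Checkpoint) -> None:
        self._by_id[checkpoint.id] = checkpoint

    def get_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]:
        return self._by_id.get(checkpoint_id)

    def get_latest_checkpoint(self, run_id, branch_id=None):
        candidates = [
            c for c in self._by_id.values()
            if c.run_id == run_id and (branch_id is None or c.branch_id == branch_id)
        ]
        return max(candidates, key=lambda c: c.sequence_number, default=None)

    def cleanup_checkpoints(self, run_id: str, keep_latest: int = 1) -> int:
        ordered = sorted(
            (c for c in self._by_id.values() if c.run_id == run_id),
            key=lambda c: c.sequence_number,
        )
        stale = ordered[:-keep_latest] if keep_latest else ordered
        for c in stale:
            del self._by_id[c.id]
        return len(stale)

manager = CheckpointManager(storage=InMemoryCheckpointStore())

first = manager.create_checkpoint(run_id="run-1", node_id="fetch", state={"page": 1})

# An identical state checkpointed against the same parent is skipped.
assert manager.create_checkpoint(
    run_id="run-1", node_id="fetch", state={"page": 1},
    parent_checkpoint_id=first.id,
) is None

# After a crash, resume from the last persisted checkpoint.
assert manager.get_resume_point("run-1").id == first.id
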
alma/workflow/context.py (new file)
@@ -0,0 +1,228 @@
+"""
+ALMA Workflow Context.
+
+Defines the WorkflowContext dataclass and RetrievalScope enum for
+scoped memory retrieval in workflow orchestration systems.
+
+Sprint 1 Tasks 1.1, 1.2
+"""
+
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Any, Dict, Optional
+
+
+class RetrievalScope(Enum):
+    """
+    Defines the scope for memory retrieval operations.
+
+    The hierarchy from most specific to most general:
+    NODE -> RUN -> WORKFLOW -> AGENT -> TENANT -> GLOBAL
+
+    Note: Named RetrievalScope (not MemoryScope) to avoid collision
+    with the existing MemoryScope dataclass in alma/types.py which
+    defines what an agent is *allowed* to learn, not *where* to search.
+    """
+
+    # Most specific - only memories from this specific node execution
+    NODE = "node"
+
+    # Memories from this specific workflow run
+    RUN = "run"
+
+    # Memories from all runs of this workflow definition
+    WORKFLOW = "workflow"
+
+    # Memories from all workflows for this agent (default)
+    AGENT = "agent"
+
+    # Memories from all agents within this tenant
+    TENANT = "tenant"
+
+    # All memories across all tenants (admin only)
+    GLOBAL = "global"
+
+    @classmethod
+    def from_string(cls, value: str) -> "RetrievalScope":
+        """Convert string to RetrievalScope enum."""
+        try:
+            return cls(value.lower())
+        except ValueError as err:
+            raise ValueError(
+                f"Invalid RetrievalScope: '{value}'. "
+                f"Valid options: {[s.value for s in cls]}"
+            ) from err
+
+    def is_broader_than(self, other: "RetrievalScope") -> bool:
+        """Check if this scope is broader than another scope."""
+        hierarchy = [
+            RetrievalScope.NODE,
+            RetrievalScope.RUN,
+            RetrievalScope.WORKFLOW,
+            RetrievalScope.AGENT,
+            RetrievalScope.TENANT,
+            RetrievalScope.GLOBAL,
+        ]
+        return hierarchy.index(self) > hierarchy.index(other)
+
+
+@dataclass
+class WorkflowContext:
+    """
+    Context for workflow-scoped memory operations.
+
+    Provides hierarchical scoping for AGtestari and similar workflow
+    orchestration systems. All fields are optional except when
+    require_tenant=True is passed to validate().
+
+    Attributes:
+        tenant_id: Multi-tenant isolation identifier
+        workflow_id: Workflow definition identifier
+        run_id: Specific workflow execution identifier
+        node_id: Current node within the workflow
+        branch_id: Parallel branch identifier (for fan-out patterns)
+        metadata: Additional context data
+        created_at: When this context was created
+    """
+
+    tenant_id: Optional[str] = None
+    workflow_id: Optional[str] = None
+    run_id: Optional[str] = None
+    node_id: Optional[str] = None
+    branch_id: Optional[str] = None
+    metadata: Dict[str, Any] = field(default_factory=dict)
+    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+    def validate(self, require_tenant: bool = False) -> None:
+        """
+        Validate the workflow context.
+
+        Args:
+            require_tenant: If True, tenant_id must be provided.
+                Use for multi-tenant deployments.
+
+        Raises:
+            ValueError: If validation fails.
+        """
+        if require_tenant and not self.tenant_id:
+            raise ValueError(
+                "tenant_id is required for multi-tenant deployments. "
+                "Set require_tenant=False for single-tenant mode."
+            )
+
+        # If node_id is provided, run_id should also be provided
+        if self.node_id and not self.run_id:
+            raise ValueError("node_id requires run_id to be set")
+
+        # If run_id is provided, workflow_id should also be provided
+        if self.run_id and not self.workflow_id:
+            raise ValueError("run_id requires workflow_id to be set")
+
+        # If branch_id is provided, run_id should also be provided
+        if self.branch_id and not self.run_id:
+            raise ValueError("branch_id requires run_id to be set")
+
+    def get_scope_filter(self, scope: RetrievalScope) -> Dict[str, Any]:
+        """
+        Build a filter dict for the given retrieval scope.
+
+        Returns a dictionary that can be used to filter memories
+        based on the workflow context and requested scope.
+
+        Args:
+            scope: The retrieval scope to filter by.
+
+        Returns:
+            Dictionary with filter criteria.
+        """
+        filters: Dict[str, Any] = {}
+
+        if scope == RetrievalScope.GLOBAL:
+            # No filters - return everything
+            pass
+        elif scope == RetrievalScope.TENANT:
+            if self.tenant_id:
+                filters["tenant_id"] = self.tenant_id
+        elif scope == RetrievalScope.AGENT:
+            if self.tenant_id:
+                filters["tenant_id"] = self.tenant_id
+            # Agent filtering is done separately via the agent parameter
+        elif scope == RetrievalScope.WORKFLOW:
+            if self.tenant_id:
+                filters["tenant_id"] = self.tenant_id
+            if self.workflow_id:
+                filters["workflow_id"] = self.workflow_id
+        elif scope == RetrievalScope.RUN:
+            if self.tenant_id:
+                filters["tenant_id"] = self.tenant_id
+            if self.workflow_id:
+                filters["workflow_id"] = self.workflow_id
+            if self.run_id:
+                filters["run_id"] = self.run_id
+        elif scope == RetrievalScope.NODE:
+            if self.tenant_id:
+                filters["tenant_id"] = self.tenant_id
+            if self.workflow_id:
+                filters["workflow_id"] = self.workflow_id
+            if self.run_id:
+                filters["run_id"] = self.run_id
+            if self.node_id:
+                filters["node_id"] = self.node_id
+
+        return filters
+
+    def with_node(self, node_id: str) -> "WorkflowContext":
+        """Create a new context with a different node_id."""
+        return WorkflowContext(
+            tenant_id=self.tenant_id,
+            workflow_id=self.workflow_id,
+            run_id=self.run_id,
+            node_id=node_id,
+            branch_id=self.branch_id,
+            metadata=self.metadata.copy(),
+            created_at=self.created_at,
+        )
+
+    def with_branch(self, branch_id: str) -> "WorkflowContext":
+        """Create a new context for a parallel branch."""
+        return WorkflowContext(
+            tenant_id=self.tenant_id,
+            workflow_id=self.workflow_id,
+            run_id=self.run_id,
+            node_id=self.node_id,
+            branch_id=branch_id,
+            metadata=self.metadata.copy(),
+            created_at=self.created_at,
+        )
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "tenant_id": self.tenant_id,
+            "workflow_id": self.workflow_id,
+            "run_id": self.run_id,
+            "node_id": self.node_id,
+            "branch_id": self.branch_id,
+            "metadata": self.metadata,
+            "created_at": self.created_at.isoformat(),
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "WorkflowContext":
+        """Create from dictionary."""
+        created_at = data.get("created_at")
+        if isinstance(created_at, str):
+            created_at = datetime.fromisoformat(created_at)
+        elif created_at is None:
+            created_at = datetime.now(timezone.utc)
+
+        return cls(
+            tenant_id=data.get("tenant_id"),
+            workflow_id=data.get("workflow_id"),
+            run_id=data.get("run_id"),
+            node_id=data.get("node_id"),
+            branch_id=data.get("branch_id"),
+            metadata=data.get("metadata", {}),
+            created_at=created_at,
+        )
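
A short sketch of how the context and scope types compose, based only on the module above; the identifier values are hypothetical placeholders:

from alma.workflow.context import RetrievalScope, WorkflowContext

ctx = WorkflowContext(
    tenant_id="acme",            # placeholder values
    workflow_id="checkout-v2",
    run_id="run-42",
    node_id="payment",
)
ctx.validate(require_tenant=True)

# Narrower scopes include every identifier above them in the hierarchy.
assert ctx.get_scope_filter(RetrievalScope.WORKFLOW) == {
    "tenant_id": "acme",
    "workflow_id": "checkout-v2",
}
assert ctx.get_scope_filter(RetrievalScope.NODE)["node_id"] == "payment"

# Fan-out: derive a per-branch context that shares the same run.
branch_ctx = ctx.with_branch("branch-a")
assert branch_ctx.run_id == ctx.run_id

assert RetrievalScope.TENANT.is_broader_than(RetrievalScope.RUN)
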