alma-memory 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alma/__init__.py +296 -226
- alma/compression/__init__.py +33 -0
- alma/compression/pipeline.py +980 -0
- alma/confidence/__init__.py +47 -47
- alma/confidence/engine.py +540 -540
- alma/confidence/types.py +351 -351
- alma/config/loader.py +157 -157
- alma/consolidation/__init__.py +23 -23
- alma/consolidation/engine.py +678 -678
- alma/consolidation/prompts.py +84 -84
- alma/core.py +1189 -430
- alma/domains/__init__.py +30 -30
- alma/domains/factory.py +359 -359
- alma/domains/schemas.py +448 -448
- alma/domains/types.py +272 -272
- alma/events/__init__.py +75 -75
- alma/events/emitter.py +285 -284
- alma/events/storage_mixin.py +246 -246
- alma/events/types.py +126 -126
- alma/events/webhook.py +425 -425
- alma/exceptions.py +49 -49
- alma/extraction/__init__.py +31 -31
- alma/extraction/auto_learner.py +265 -265
- alma/extraction/extractor.py +420 -420
- alma/graph/__init__.py +106 -106
- alma/graph/backends/__init__.py +32 -32
- alma/graph/backends/kuzu.py +624 -624
- alma/graph/backends/memgraph.py +432 -432
- alma/graph/backends/memory.py +236 -236
- alma/graph/backends/neo4j.py +417 -417
- alma/graph/base.py +159 -159
- alma/graph/extraction.py +198 -198
- alma/graph/store.py +860 -860
- alma/harness/__init__.py +35 -35
- alma/harness/base.py +386 -386
- alma/harness/domains.py +705 -705
- alma/initializer/__init__.py +37 -37
- alma/initializer/initializer.py +418 -418
- alma/initializer/types.py +250 -250
- alma/integration/__init__.py +62 -62
- alma/integration/claude_agents.py +444 -444
- alma/integration/helena.py +423 -423
- alma/integration/victor.py +471 -471
- alma/learning/__init__.py +101 -86
- alma/learning/decay.py +878 -0
- alma/learning/forgetting.py +1446 -1446
- alma/learning/heuristic_extractor.py +390 -390
- alma/learning/protocols.py +374 -374
- alma/learning/validation.py +346 -346
- alma/mcp/__init__.py +123 -45
- alma/mcp/__main__.py +156 -156
- alma/mcp/resources.py +122 -122
- alma/mcp/server.py +955 -591
- alma/mcp/tools.py +3254 -509
- alma/observability/__init__.py +91 -84
- alma/observability/config.py +302 -302
- alma/observability/guidelines.py +170 -0
- alma/observability/logging.py +424 -424
- alma/observability/metrics.py +583 -583
- alma/observability/tracing.py +440 -440
- alma/progress/__init__.py +21 -21
- alma/progress/tracker.py +607 -607
- alma/progress/types.py +250 -250
- alma/retrieval/__init__.py +134 -53
- alma/retrieval/budget.py +525 -0
- alma/retrieval/cache.py +1304 -1061
- alma/retrieval/embeddings.py +202 -202
- alma/retrieval/engine.py +850 -427
- alma/retrieval/modes.py +365 -0
- alma/retrieval/progressive.py +560 -0
- alma/retrieval/scoring.py +344 -344
- alma/retrieval/trust_scoring.py +637 -0
- alma/retrieval/verification.py +797 -0
- alma/session/__init__.py +19 -19
- alma/session/manager.py +442 -399
- alma/session/types.py +288 -288
- alma/storage/__init__.py +101 -90
- alma/storage/archive.py +233 -0
- alma/storage/azure_cosmos.py +1259 -1259
- alma/storage/base.py +1083 -583
- alma/storage/chroma.py +1443 -1443
- alma/storage/constants.py +103 -103
- alma/storage/file_based.py +614 -614
- alma/storage/migrations/__init__.py +21 -21
- alma/storage/migrations/base.py +321 -321
- alma/storage/migrations/runner.py +323 -323
- alma/storage/migrations/version_stores.py +337 -337
- alma/storage/migrations/versions/__init__.py +11 -11
- alma/storage/migrations/versions/v1_0_0.py +373 -373
- alma/storage/migrations/versions/v1_1_0_workflow_context.py +551 -0
- alma/storage/pinecone.py +1080 -1080
- alma/storage/postgresql.py +1948 -1559
- alma/storage/qdrant.py +1306 -1306
- alma/storage/sqlite_local.py +3041 -1457
- alma/testing/__init__.py +46 -46
- alma/testing/factories.py +301 -301
- alma/testing/mocks.py +389 -389
- alma/types.py +292 -264
- alma/utils/__init__.py +19 -0
- alma/utils/tokenizer.py +521 -0
- alma/workflow/__init__.py +83 -0
- alma/workflow/artifacts.py +170 -0
- alma/workflow/checkpoint.py +311 -0
- alma/workflow/context.py +228 -0
- alma/workflow/outcomes.py +189 -0
- alma/workflow/reducers.py +393 -0
- {alma_memory-0.5.1.dist-info → alma_memory-0.7.0.dist-info}/METADATA +210 -72
- alma_memory-0.7.0.dist-info/RECORD +112 -0
- alma_memory-0.5.1.dist-info/RECORD +0 -93
- {alma_memory-0.5.1.dist-info → alma_memory-0.7.0.dist-info}/WHEEL +0 -0
- {alma_memory-0.5.1.dist-info → alma_memory-0.7.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ALMA Workflow Outcomes.
|
|
3
|
+
|
|
4
|
+
Provides the WorkflowOutcome dataclass for capturing learnings
|
|
5
|
+
from completed workflow executions.
|
|
6
|
+
|
|
7
|
+
Sprint 1 Task 1.5
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from datetime import datetime, timezone
|
|
12
|
+
from enum import Enum
|
|
13
|
+
from typing import Any, Dict, List, Optional
|
|
14
|
+
from uuid import uuid4
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class WorkflowResult(Enum):
    """Terminal status of a workflow execution."""

    SUCCESS = "success"
    FAILURE = "failure"
    PARTIAL = "partial"  # some of the workflow succeeded, the rest did not
    CANCELLED = "cancelled"
    TIMEOUT = "timeout"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class WorkflowOutcome:
    """
    Captures learnings from a completed workflow execution.

    WorkflowOutcome records what was learned from running a workflow,
    including the strategies used, what worked, what didn't, and any
    extracted heuristics or anti-patterns.

    Attributes:
        id: Unique outcome identifier
        tenant_id: Multi-tenant isolation identifier
        workflow_id: The workflow definition that was executed
        run_id: The specific run this outcome is from
        agent: The agent that executed the workflow
        project_id: Project scope identifier
        result: Overall result status
        summary: Human-readable summary of what happened
        strategies_used: List of strategies/approaches attempted
        successful_patterns: Patterns that worked well
        failed_patterns: Patterns that didn't work
        extracted_heuristics: IDs of heuristics created from this run
        extracted_anti_patterns: IDs of anti-patterns created from this run
        duration_seconds: How long the workflow took
        node_count: Number of nodes executed
        error_message: Error details if failed
        embedding: Vector embedding for semantic search
        metadata: Additional outcome metadata
        created_at: When this outcome was recorded
    """

    id: str = field(default_factory=lambda: str(uuid4()))
    tenant_id: Optional[str] = None
    workflow_id: str = ""
    run_id: str = ""
    agent: str = ""
    project_id: str = ""
    result: WorkflowResult = WorkflowResult.SUCCESS
    summary: str = ""
    strategies_used: List[str] = field(default_factory=list)
    successful_patterns: List[str] = field(default_factory=list)
    failed_patterns: List[str] = field(default_factory=list)
    extracted_heuristics: List[str] = field(default_factory=list)
    extracted_anti_patterns: List[str] = field(default_factory=list)
    duration_seconds: Optional[float] = None
    node_count: Optional[int] = None
    error_message: Optional[str] = None
    embedding: Optional[List[float]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def validate(self, require_tenant: bool = False) -> None:
        """
        Validate the workflow outcome.

        Args:
            require_tenant: If True, tenant_id must be provided.

        Raises:
            ValueError: If validation fails.
        """
        if require_tenant and not self.tenant_id:
            raise ValueError("tenant_id is required for multi-tenant deployments")
        if not self.workflow_id:
            raise ValueError("workflow_id is required")
        if not self.run_id:
            raise ValueError("run_id is required")
        if not self.agent:
            raise ValueError("agent is required")
        if not self.project_id:
            raise ValueError("project_id is required")

    @property
    def is_success(self) -> bool:
        """Check if the workflow succeeded."""
        return self.result == WorkflowResult.SUCCESS

    @property
    def is_failure(self) -> bool:
        """Check if the workflow failed (hard failure or timeout)."""
        return self.result in (WorkflowResult.FAILURE, WorkflowResult.TIMEOUT)

    def get_searchable_text(self) -> str:
        """
        Get text suitable for embedding generation.

        Combines summary, strategies, patterns, and any error message
        into a single " | "-separated searchable string.
        """
        parts = [self.summary]

        if self.strategies_used:
            parts.append("Strategies: " + ", ".join(self.strategies_used))

        if self.successful_patterns:
            parts.append("Successful: " + ", ".join(self.successful_patterns))

        if self.failed_patterns:
            parts.append("Failed: " + ", ".join(self.failed_patterns))

        if self.error_message:
            parts.append("Error: " + self.error_message)

        return " | ".join(parts)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "workflow_id": self.workflow_id,
            "run_id": self.run_id,
            "agent": self.agent,
            "project_id": self.project_id,
            "result": self.result.value,
            "summary": self.summary,
            "strategies_used": self.strategies_used,
            "successful_patterns": self.successful_patterns,
            "failed_patterns": self.failed_patterns,
            "extracted_heuristics": self.extracted_heuristics,
            "extracted_anti_patterns": self.extracted_anti_patterns,
            "duration_seconds": self.duration_seconds,
            "node_count": self.node_count,
            "error_message": self.error_message,
            "embedding": self.embedding,
            "metadata": self.metadata,
            "created_at": self.created_at.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WorkflowOutcome":
        """
        Create from dictionary.

        Robust to partially populated payloads: missing keys fall back to
        field defaults, and an explicit None for a list/dict field is
        normalized to an empty container so that get_searchable_text()
        (", ".join) and metadata consumers never see None.
        """
        created_at = data.get("created_at")
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)
        elif created_at is None:
            created_at = datetime.now(timezone.utc)

        result = data.get("result", "success")
        if isinstance(result, str):
            # Raises ValueError on unknown status strings.
            result = WorkflowResult(result)

        return cls(
            id=data.get("id", str(uuid4())),
            tenant_id=data.get("tenant_id"),
            workflow_id=data.get("workflow_id", ""),
            run_id=data.get("run_id", ""),
            agent=data.get("agent", ""),
            project_id=data.get("project_id", ""),
            result=result,
            summary=data.get("summary", ""),
            # `or []` / `or {}` guards against explicit None in the payload.
            strategies_used=data.get("strategies_used") or [],
            successful_patterns=data.get("successful_patterns") or [],
            failed_patterns=data.get("failed_patterns") or [],
            extracted_heuristics=data.get("extracted_heuristics") or [],
            extracted_anti_patterns=data.get("extracted_anti_patterns") or [],
            duration_seconds=data.get("duration_seconds"),
            node_count=data.get("node_count"),
            error_message=data.get("error_message"),
            embedding=data.get("embedding"),
            metadata=data.get("metadata") or {},
            created_at=created_at,
        )
|
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ALMA State Reducers.
|
|
3
|
+
|
|
4
|
+
Provides state reducers for merging parallel branch states in workflow
|
|
5
|
+
orchestration. Each reducer defines how to combine values from multiple
|
|
6
|
+
branches into a single value.
|
|
7
|
+
|
|
8
|
+
Sprint 1 Tasks 1.8-1.10
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar, Union
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from alma.workflow.checkpoint import Checkpoint
|
|
17
|
+
|
|
18
|
+
T = TypeVar("T")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class StateReducer(ABC):
    """
    Abstract base class for state reducers.

    A reducer collapses the values a single field took across several
    parallel branches into one value during a state-merge operation.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique identifier of this reducer."""
        ...

    @abstractmethod
    def reduce(self, values: List[Any]) -> Any:
        """
        Collapse multiple branch values into a single value.

        Args:
            values: One entry per branch; entries may be None.

        Returns:
            The reduced single value.
        """
        ...
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class AppendReducer(StateReducer):
    """
    Concatenates list values from every branch, in branch order.

    Use for: messages, logs, notes, events
    """

    @property
    def name(self) -> str:
        return "append"

    def reduce(self, values: List[Any]) -> List[Any]:
        """Concatenate all values; non-list items are appended as-is."""
        merged: List[Any] = []
        for entry in values:
            if entry is None:
                continue
            merged.extend(entry if isinstance(entry, list) else [entry])
        return merged
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class MergeDictReducer(StateReducer):
    """
    Merges dictionaries; keys from later branches overwrite earlier ones.

    Use for: context, metadata, configuration
    """

    @property
    def name(self) -> str:
        return "merge_dict"

    def reduce(self, values: List[Any]) -> Dict[str, Any]:
        """Merge every dict value in order; non-dict entries are ignored."""
        merged: Dict[str, Any] = {}
        for entry in values:
            # isinstance is False for None, so None entries are skipped too.
            if isinstance(entry, dict):
                merged.update(entry)
        return merged
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class LastValueReducer(StateReducer):
    """
    Keeps the last non-None value.

    Use for: single values where the most recent is preferred
    """

    @property
    def name(self) -> str:
        return "last_value"

    def reduce(self, values: List[Any]) -> Any:
        """Return the last non-None value, or None if all are None."""
        return next((v for v in reversed(values) if v is not None), None)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class FirstValueReducer(StateReducer):
    """
    Keeps the first non-None value.

    Use for: priority values, initial state
    """

    @property
    def name(self) -> str:
        return "first_value"

    def reduce(self, values: List[Any]) -> Any:
        """Return the first non-None value, or None if all are None."""
        return next((v for v in values if v is not None), None)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class SumReducer(StateReducer):
    """
    Sums numeric (int/float) values; everything else is ignored.

    Use for: counters, scores, totals
    """

    @property
    def name(self) -> str:
        return "sum"

    def reduce(self, values: List[Any]) -> Union[int, float]:
        """Sum all int/float entries; returns 0 when none are numeric."""
        # isinstance is False for None, so None entries are filtered out.
        return sum(v for v in values if isinstance(v, (int, float)))
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class MaxReducer(StateReducer):
    """
    Keeps the maximum numeric value.

    Use for: high scores, limits, thresholds
    """

    @property
    def name(self) -> str:
        return "max"

    def reduce(self, values: List[Any]) -> Optional[Union[int, float]]:
        """Return the largest int/float entry, or None if there is none."""
        candidates = [v for v in values if isinstance(v, (int, float))]
        if not candidates:
            return None
        return max(candidates)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class MinReducer(StateReducer):
    """
    Keeps the minimum numeric value.

    Use for: low scores, minimums
    """

    @property
    def name(self) -> str:
        return "min"

    def reduce(self, values: List[Any]) -> Optional[Union[int, float]]:
        """Return the smallest int/float entry, or None if there is none."""
        candidates = [v for v in values if isinstance(v, (int, float))]
        if not candidates:
            return None
        return min(candidates)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class UnionReducer(StateReducer):
    """
    Produces the order-preserving union of all values.

    Use for: tags, categories, unique items
    """

    @property
    def name(self) -> str:
        return "union"

    def reduce(self, values: List[Any]) -> List[Any]:
        """
        Return the union of all entries as a list.

        Hashable items are de-duplicated (first occurrence wins);
        unhashable items cannot be tracked in a set, so every
        occurrence is kept.
        """
        deduped: List[Any] = []
        hashable_seen: set = set()
        for entry in values:
            if entry is None:
                continue
            items = entry if isinstance(entry, (list, set)) else [entry]
            for item in items:
                try:
                    is_new = item not in hashable_seen
                    if is_new:
                        hashable_seen.add(item)
                except TypeError:
                    # Unhashable item: cannot dedupe, always keep it.
                    is_new = True
                if is_new:
                    deduped.append(item)
        return deduped
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
# Built-in reducer instances
# Registry of the shared reducer singletons, keyed by the names accepted by
# get_reducer() and ReducerConfig.field_reducers. The reducer classes above
# define no instance state, so sharing one instance per type is safe.
BUILTIN_REDUCERS: Dict[str, StateReducer] = {
    "append": AppendReducer(),
    "merge_dict": MergeDictReducer(),
    "last_value": LastValueReducer(),
    "first_value": FirstValueReducer(),
    "sum": SumReducer(),
    "max": MaxReducer(),
    "min": MinReducer(),
    "union": UnionReducer(),
}
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def get_reducer(name: str) -> StateReducer:
    """
    Look up a built-in reducer by name.

    Args:
        name: The reducer name.

    Returns:
        The reducer instance.

    Raises:
        ValueError: If reducer name is unknown.
    """
    try:
        return BUILTIN_REDUCERS[name]
    except KeyError:
        raise ValueError(
            f"Unknown reducer: '{name}'. "
            f"Available reducers: {list(BUILTIN_REDUCERS.keys())}"
        ) from None
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
@dataclass
class ReducerConfig:
    """
    Per-field configuration for state merging.

    Attributes:
        field_reducers: Mapping of field names to reducer names.
        default_reducer: Reducer name used for fields not listed in
            field_reducers.
        custom_reducers: Caller-supplied reducer instances by name;
            these take precedence over built-ins of the same name.
    """

    field_reducers: Dict[str, str] = field(default_factory=dict)
    default_reducer: str = "last_value"
    custom_reducers: Dict[str, StateReducer] = field(default_factory=dict)

    def get_reducer_for_field(self, field_name: str) -> StateReducer:
        """
        Resolve the reducer to apply to *field_name*.

        Args:
            field_name: The field name.

        Returns:
            The reducer to use for this field.
        """
        chosen = self.field_reducers.get(field_name, self.default_reducer)
        # Custom reducers shadow built-ins with the same name.
        if chosen in self.custom_reducers:
            return self.custom_reducers[chosen]
        return get_reducer(chosen)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class StateMerger:
    """
    Merges states produced by multiple parallel branches.

    The per-field merge behavior is driven by a ReducerConfig.
    """

    def __init__(self, config: Optional[ReducerConfig] = None):
        """
        Initialize the state merger.

        Args:
            config: Reducer configuration. Uses defaults if not provided.
        """
        self.config = config or ReducerConfig()

    def merge(self, states: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Merge multiple states into a single state.

        Args:
            states: List of state dictionaries from parallel branches.

        Returns:
            The merged state dictionary.
        """
        if not states:
            return {}

        # Single branch: nothing to reduce, hand back a shallow copy.
        if len(states) == 1:
            return states[0].copy()

        # Union of every field name seen in any branch.
        field_names: set = set()
        for state in states:
            field_names.update(state)

        merged: Dict[str, Any] = {}
        for key in field_names:
            # Missing keys contribute None; reducers are None-aware.
            branch_values = [state.get(key) for state in states]
            reducer = self.config.get_reducer_for_field(key)
            merged[key] = reducer.reduce(branch_values)
        return merged

    def merge_checkpoints(
        self,
        checkpoints: List["Checkpoint"],  # type: ignore
    ) -> Dict[str, Any]:
        """
        Merge the states carried by multiple checkpoints.

        Args:
            checkpoints: List of Checkpoint objects.

        Returns:
            The merged state dictionary.
        """
        return self.merge([cp.state for cp in checkpoints])
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
# Convenience function for simple merges
def merge_states(
    states: List[Dict[str, Any]],
    config: Optional[ReducerConfig] = None,
) -> Dict[str, Any]:
    """
    Merge multiple states into a single state.

    Thin wrapper that builds a StateMerger and delegates to it.

    Args:
        states: List of state dictionaries.
        config: Optional reducer configuration.

    Returns:
        The merged state dictionary.

    Example:
        >>> config = ReducerConfig(
        ...     field_reducers={
        ...         "messages": "append",
        ...         "context": "merge_dict",
        ...         "total_score": "sum",
        ...     },
        ...     default_reducer="last_value",
        ... )
        >>> result = merge_states([state1, state2, state3], config)
    """
    return StateMerger(config).merge(states)
|