kailash 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +31 -0
- kailash/__main__.py +11 -0
- kailash/cli/__init__.py +5 -0
- kailash/cli/commands.py +563 -0
- kailash/manifest.py +778 -0
- kailash/nodes/__init__.py +23 -0
- kailash/nodes/ai/__init__.py +26 -0
- kailash/nodes/ai/agents.py +417 -0
- kailash/nodes/ai/models.py +488 -0
- kailash/nodes/api/__init__.py +52 -0
- kailash/nodes/api/auth.py +567 -0
- kailash/nodes/api/graphql.py +480 -0
- kailash/nodes/api/http.py +598 -0
- kailash/nodes/api/rate_limiting.py +572 -0
- kailash/nodes/api/rest.py +665 -0
- kailash/nodes/base.py +1032 -0
- kailash/nodes/base_async.py +128 -0
- kailash/nodes/code/__init__.py +32 -0
- kailash/nodes/code/python.py +1021 -0
- kailash/nodes/data/__init__.py +125 -0
- kailash/nodes/data/readers.py +496 -0
- kailash/nodes/data/sharepoint_graph.py +623 -0
- kailash/nodes/data/sql.py +380 -0
- kailash/nodes/data/streaming.py +1168 -0
- kailash/nodes/data/vector_db.py +964 -0
- kailash/nodes/data/writers.py +529 -0
- kailash/nodes/logic/__init__.py +6 -0
- kailash/nodes/logic/async_operations.py +702 -0
- kailash/nodes/logic/operations.py +551 -0
- kailash/nodes/transform/__init__.py +5 -0
- kailash/nodes/transform/processors.py +379 -0
- kailash/runtime/__init__.py +6 -0
- kailash/runtime/async_local.py +356 -0
- kailash/runtime/docker.py +697 -0
- kailash/runtime/local.py +434 -0
- kailash/runtime/parallel.py +557 -0
- kailash/runtime/runner.py +110 -0
- kailash/runtime/testing.py +347 -0
- kailash/sdk_exceptions.py +307 -0
- kailash/tracking/__init__.py +7 -0
- kailash/tracking/manager.py +885 -0
- kailash/tracking/metrics_collector.py +342 -0
- kailash/tracking/models.py +535 -0
- kailash/tracking/storage/__init__.py +0 -0
- kailash/tracking/storage/base.py +113 -0
- kailash/tracking/storage/database.py +619 -0
- kailash/tracking/storage/filesystem.py +543 -0
- kailash/utils/__init__.py +0 -0
- kailash/utils/export.py +924 -0
- kailash/utils/templates.py +680 -0
- kailash/visualization/__init__.py +62 -0
- kailash/visualization/api.py +732 -0
- kailash/visualization/dashboard.py +951 -0
- kailash/visualization/performance.py +808 -0
- kailash/visualization/reports.py +1471 -0
- kailash/workflow/__init__.py +15 -0
- kailash/workflow/builder.py +245 -0
- kailash/workflow/graph.py +827 -0
- kailash/workflow/mermaid_visualizer.py +628 -0
- kailash/workflow/mock_registry.py +63 -0
- kailash/workflow/runner.py +302 -0
- kailash/workflow/state.py +238 -0
- kailash/workflow/visualization.py +588 -0
- kailash-0.1.0.dist-info/METADATA +710 -0
- kailash-0.1.0.dist-info/RECORD +69 -0
- kailash-0.1.0.dist-info/WHEEL +5 -0
- kailash-0.1.0.dist-info/entry_points.txt +2 -0
- kailash-0.1.0.dist-info/licenses/LICENSE +21 -0
- kailash-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,302 @@
|
|
1
|
+
"""Workflow runner for executing connected workflows.
|
2
|
+
|
3
|
+
This module provides tools for connecting and executing multiple workflows,
|
4
|
+
allowing for complex multi-stage processing pipelines.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Any, Dict, List, Optional, Tuple
|
9
|
+
|
10
|
+
from pydantic import BaseModel
|
11
|
+
|
12
|
+
from kailash.sdk_exceptions import WorkflowExecutionError
|
13
|
+
from kailash.tracking import TaskManager
|
14
|
+
from kailash.workflow.graph import Workflow
|
15
|
+
|
16
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class WorkflowConnection:
    """Defines a connection between two workflows.

    A connection can carry a ``condition``, evaluated against the source
    workflow's state to decide whether the connection is followed, and a
    ``state_mapping`` that selects and renames state fields for the target.
    """

    # Supported condition operators: operator string -> fn(field_value, expected_value).
    _OPERATORS = {
        "==": lambda actual, expected: actual == expected,
        "!=": lambda actual, expected: actual != expected,
        ">": lambda actual, expected: actual > expected,
        ">=": lambda actual, expected: actual >= expected,
        "<": lambda actual, expected: actual < expected,
        "<=": lambda actual, expected: actual <= expected,
        "in": lambda actual, expected: actual in expected,
        "not in": lambda actual, expected: actual not in expected,
    }

    def __init__(
        self,
        source_workflow_id: str,
        target_workflow_id: str,
        condition: Optional[Dict[str, Any]] = None,
        state_mapping: Optional[Dict[str, str]] = None,
    ):
        """Initialize a workflow connection.

        Args:
            source_workflow_id: ID of the source workflow
            target_workflow_id: ID of the target workflow
            condition: Optional condition for when this connection should be
                followed, read as ``{"field": ..., "operator": ..., "value": ...}``
                (``operator`` defaults to ``"=="``)
            state_mapping: Optional mapping of state fields between workflows,
                as ``{source_field: target_key}``
        """
        self.source_workflow_id = source_workflow_id
        self.target_workflow_id = target_workflow_id
        self.condition = condition or {}
        self.state_mapping = state_mapping or {}

    def should_follow(self, state: BaseModel) -> bool:
        """Check if this connection should be followed based on state.

        The field value is read with ``getattr(state, field, None)``, so a
        field missing from ``state`` is compared as ``None``.

        Args:
            state: The current state object

        Returns:
            True if the connection should be followed, False otherwise.
            Connections with no condition (or no ``"field"``), and conditions
            with an unknown operator, are always followed. A comparison that
            raises ``TypeError`` (e.g. ``None > 5``) counts as not satisfied.
        """
        if not self.condition:
            # If no condition is specified, always follow the connection
            return True

        # Extract condition field and value from the state
        field_name = self.condition.get("field")
        operator = self.condition.get("operator", "==")
        expected_value = self.condition.get("value")

        if not field_name:
            # If no field name is specified, always follow the connection
            return True

        field_value = getattr(state, field_name, None)

        # The isinstance guard keeps an unhashable operator value from making
        # the dict lookup itself raise; such values are "unknown" operators.
        compare = self._OPERATORS.get(operator) if isinstance(operator, str) else None
        if compare is None:
            # Unknown operator, default to always follow
            logger.warning(
                f"Unknown condition operator: {operator}. Always following connection."
            )
            return True

        try:
            return compare(field_value, expected_value)
        except TypeError as exc:
            # Ordering against a missing (None) field, or ``in`` against a
            # non-container, is treated as "condition not met" rather than
            # propagating the error to the caller.
            logger.warning(
                f"Could not evaluate condition {field_name!r} {operator} "
                f"{expected_value!r} (field value {field_value!r}): {exc}. "
                "Not following connection."
            )
            return False

    def map_state(self, state: BaseModel) -> Dict[str, Any]:
        """Map state fields according to the mapping configuration.

        Args:
            state: The current state object

        Returns:
            ``{"state": state}`` when no mapping is configured; otherwise a
            dict ``{target_key: getattr(state, source_key)}`` containing only
            the mapped source fields that exist on ``state`` (missing ones
            are skipped).
        """
        if not self.state_mapping:
            # If no mapping is specified, use the state as is
            return {"state": state}

        return {
            target_key: getattr(state, source_key)
            for source_key, target_key in self.state_mapping.items()
            if hasattr(state, source_key)
        }
|
111
|
+
|
112
|
+
|
113
|
+
class WorkflowRunner:
    """Manages execution across multiple connected workflows.

    This class allows building complex processing pipelines by connecting
    multiple workflows together, with conditional branching based on state.
    When several connections out of a workflow match, the one registered
    first is followed.
    """

    def __init__(self):
        """Initialize a workflow runner."""
        # workflow_id -> Workflow, populated by add_workflow()
        self.workflows = {}
        # WorkflowConnection objects in registration order; execute() follows
        # the first one that matches.
        self.connections = []

    def add_workflow(self, workflow_id: str, workflow: Workflow) -> None:
        """Add a workflow to the runner.

        Args:
            workflow_id: Unique identifier for the workflow
            workflow: Workflow instance

        Raises:
            ValueError: If a workflow with the given ID already exists
        """
        if workflow_id in self.workflows:
            raise ValueError(f"Workflow with ID '{workflow_id}' already exists")

        self.workflows[workflow_id] = workflow
        logger.info(f"Added workflow '{workflow.name}' with ID '{workflow_id}'")

    def connect_workflows(
        self,
        source_workflow_id: str,
        target_workflow_id: str,
        condition: Optional[Dict[str, Any]] = None,
        state_mapping: Optional[Dict[str, str]] = None,
    ) -> None:
        """Connect two workflows.

        Args:
            source_workflow_id: ID of the source workflow
            target_workflow_id: ID of the target workflow
            condition: Optional condition for when this connection should be followed
            state_mapping: Optional mapping of state fields between workflows;
                when the connection is followed, the mapped values are merged
                into the state passed to the target workflow

        Raises:
            ValueError: If any workflow ID is invalid
        """
        # Validate workflow IDs
        if source_workflow_id not in self.workflows:
            raise ValueError(
                f"Source workflow with ID '{source_workflow_id}' not found"
            )

        if target_workflow_id not in self.workflows:
            raise ValueError(
                f"Target workflow with ID '{target_workflow_id}' not found"
            )

        connection = WorkflowConnection(
            source_workflow_id=source_workflow_id,
            target_workflow_id=target_workflow_id,
            condition=condition,
            state_mapping=state_mapping,
        )

        self.connections.append(connection)
        logger.info(
            f"Connected workflow '{source_workflow_id}' to '{target_workflow_id}'"
        )

    def get_next_workflows(
        self, current_workflow_id: str, state: BaseModel
    ) -> List[Tuple[str, Dict[str, Any]]]:
        """Get the next workflows to execute based on current state.

        Args:
            current_workflow_id: ID of the current workflow
            state: Current state object

        Returns:
            List of (workflow_id, mapped_state) tuples for next workflows,
            in connection registration order
        """
        return [
            (connection.target_workflow_id, connection.map_state(state))
            for connection in self.connections
            if connection.source_workflow_id == current_workflow_id
            and connection.should_follow(state)
        ]

    def execute(
        self,
        entry_workflow_id: str,
        initial_state: BaseModel,
        task_manager: Optional[TaskManager] = None,
        max_steps: int = 10,  # Prevent infinite loops
    ) -> Tuple[BaseModel, Dict[str, Dict[str, Any]]]:
        """Execute a sequence of connected workflows.

        After each workflow runs, the first matching outgoing connection
        decides the next workflow; its state mapping (if any) is merged into
        the state handed to that workflow. Execution stops when no connection
        matches or after ``max_steps`` workflow executions.

        Args:
            entry_workflow_id: ID of the first workflow to execute
            initial_state: Initial state for workflow execution
            task_manager: Optional task manager for tracking
            max_steps: Maximum number of workflow steps to execute

        Returns:
            Tuple of (final state, all results by workflow). A workflow that
            runs more than once (a cycle) keeps only its latest results.

        Raises:
            WorkflowExecutionError: If workflow execution fails
            ValueError: If entry workflow is not found
        """
        if entry_workflow_id not in self.workflows:
            raise ValueError(f"Entry workflow with ID '{entry_workflow_id}' not found")

        current_workflow_id = entry_workflow_id
        current_state = initial_state
        all_results = {}
        executed_workflows = set()
        step_count = 0

        # Execute workflows until no more connections to follow
        while current_workflow_id and step_count < max_steps:
            step_count += 1
            logger.info(
                f"Executing workflow '{current_workflow_id}' (step {step_count}/{max_steps})"
            )

            workflow = self.workflows[current_workflow_id]

            # Track executed workflows to detect cycles
            if current_workflow_id in executed_workflows:
                logger.warning(
                    f"Cycle detected in workflow execution: already executed '{current_workflow_id}'"
                )
                # Continue to next workflow rather than stopping, to handle intentional cycles

            executed_workflows.add(current_workflow_id)

            try:
                final_state, workflow_results = workflow.execute_with_state(
                    state_model=current_state, task_manager=task_manager
                )

                all_results[current_workflow_id] = workflow_results
                current_state = final_state

                next_workflows = self.get_next_workflows(
                    current_workflow_id, current_state
                )

                if not next_workflows:
                    logger.info(
                        f"No more workflows to execute after '{current_workflow_id}'"
                    )
                    break

                # Take the first matching workflow as the next one
                next_workflow_id, mapped_state = next_workflows[0]
                if len(next_workflows) > 1:
                    logger.debug(
                        f"Multiple connections matched after '{current_workflow_id}'; "
                        f"following '{next_workflow_id}', ignoring "
                        f"{[target for target, _ in next_workflows[1:]]}"
                    )

                current_state = self._apply_state_mapping(current_state, mapped_state)
                current_workflow_id = next_workflow_id

            except Exception as e:
                logger.error(f"Error executing workflow '{current_workflow_id}': {e}")
                raise WorkflowExecutionError(
                    f"Failed to execute workflow '{current_workflow_id}': {e}"
                ) from e
        else:
            # Only reached when the loop ends without ``break``, i.e. not
            # because the pipeline finished on its own. A natural finish on
            # the last allowed step therefore no longer triggers this warning.
            if step_count >= max_steps:
                logger.warning(f"Reached maximum steps ({max_steps}) in workflow execution")

        return current_state, all_results

    @staticmethod
    def _apply_state_mapping(
        current_state: BaseModel, mapped_state: Dict[str, Any]
    ) -> BaseModel:
        """Build the state handed to the next workflow.

        Args:
            current_state: Final state of the workflow that just ran
            mapped_state: Output of ``WorkflowConnection.map_state`` for the
                connection being followed

        Returns:
            ``mapped_state["state"]`` when it holds a model; otherwise a copy
            of ``current_state`` with the mapped values merged in. An empty
            mapping, or a ``current_state`` that is not a model, leaves the
            state unchanged.
        """
        if not mapped_state:
            return current_state

        # If a complete state object is provided, use it
        whole_state = mapped_state.get("state")
        if isinstance(whole_state, BaseModel):
            return whole_state

        if not isinstance(current_state, BaseModel):
            logger.warning(
                "Cannot apply state mapping to non-model state of type "
                f"{type(current_state).__name__}; passing it through unchanged"
            )
            return current_state

        # Merge the mapped values into a copy so the previous workflow's final
        # state object is not mutated. model_copy(update=...) does not re-run
        # validation on the merged values.
        return current_state.model_copy(update=mapped_state)
|
@@ -0,0 +1,238 @@
|
|
1
|
+
"""State management for workflow execution.
|
2
|
+
|
3
|
+
This module provides tools for managing immutable state throughout workflow execution,
|
4
|
+
making it easier to handle state transitions in a predictable manner.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from copy import deepcopy
|
9
|
+
from typing import Any, Generic, List, Tuple, TypeVar
|
10
|
+
|
11
|
+
from pydantic import BaseModel
|
12
|
+
|
13
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)

# Type variable for the wrapped state model; the bound restricts it to
# pydantic BaseModel subclasses.
StateT = TypeVar("StateT", bound=BaseModel)
|
17
|
+
|
18
|
+
|
19
|
+
class StateManager:
    """Manages immutable state operations for workflow execution.

    This class provides utilities for updating Pydantic model state objects
    immutably. Every update works on a deep copy made with
    ``model_copy(deep=True)``, so the object passed in is never modified.

    Paths are lists of keys. Each step is resolved as a dict key when the
    object at that level is a ``dict`` and as an attribute otherwise.
    """

    @staticmethod
    def _require_model(state_obj: Any) -> None:
        """Raise TypeError unless ``state_obj`` is a Pydantic BaseModel."""
        if not isinstance(state_obj, BaseModel):
            raise TypeError(f"Expected BaseModel, got {type(state_obj)}")

    @staticmethod
    def _has_step(obj: Any, key: str) -> bool:
        """Return True if ``key`` resolves on ``obj`` (dict key or attribute)."""
        if isinstance(obj, dict):
            return key in obj
        return hasattr(obj, key)

    @staticmethod
    def _get_step(obj: Any, key: str) -> Any:
        """Resolve one path step on ``obj`` (dict key or attribute)."""
        return obj[key] if isinstance(obj, dict) else getattr(obj, key)

    @staticmethod
    def _assign_in_place(target: BaseModel, path: List[str], value: Any) -> None:
        """Set ``value`` at ``path`` inside ``target``, mutating ``target``.

        Callers pass a private deep copy, so the mutation never reaches the
        caller's original object.

        Raises:
            KeyError: If the path is empty or cannot be resolved
        """
        if not path:
            raise KeyError("Invalid path: path must not be empty")

        # For simple top-level updates, defer to the model's own __setattr__
        # (it decides how unknown names are handled).
        if len(path) == 1:
            setattr(target, path[0], value)
            return

        # Walk to the parent of the final key. The deep copy already gives us
        # private nested objects, so no further copying is needed here.
        current: Any = target
        for depth, key in enumerate(path[:-1]):
            if not StateManager._has_step(current, key):
                raise KeyError(f"Invalid path: {'.'.join(path[:depth+1])}")
            current = StateManager._get_step(current, key)

        last = path[-1]
        if isinstance(current, dict):
            current[last] = value
        elif hasattr(current, last):
            setattr(current, last, value)
        else:
            raise KeyError(f"Invalid path: {'.'.join(path)}")

    @staticmethod
    def update_in(state_obj: BaseModel, path: List[str], value: Any) -> BaseModel:
        """Update a nested property in the state and return a new state object.

        Args:
            state_obj: The Pydantic model state object
            path: List of attribute names (or dict keys) forming a path to the
                property to update
            value: The new value to set

        Returns:
            A new state object with the update applied

        Raises:
            TypeError: If state_obj is not a Pydantic BaseModel
            KeyError: If the path is empty or invalid
        """
        StateManager._require_model(state_obj)

        new_state = state_obj.model_copy(deep=True)
        StateManager._assign_in_place(new_state, path, value)
        return new_state

    @staticmethod
    def batch_update(
        state_obj: BaseModel, updates: List[Tuple[List[str], Any]]
    ) -> BaseModel:
        """Apply multiple updates to the state atomically.

        All updates are applied, in order, to a single deep copy. If any path
        is invalid the copy is discarded and ``state_obj`` is left untouched.

        Args:
            state_obj: The Pydantic model state object
            updates: List of (path, value) tuples with updates to apply

        Returns:
            A new state object with all updates applied

        Raises:
            TypeError: If state_obj is not a Pydantic BaseModel
            KeyError: If any path is empty or invalid
        """
        StateManager._require_model(state_obj)

        # One copy for the whole batch instead of one per update.
        new_state = state_obj.model_copy(deep=True)
        for path, value in updates:
            StateManager._assign_in_place(new_state, path, value)

        return new_state

    @staticmethod
    def get_in(state_obj: BaseModel, path: List[str]) -> Any:
        """Get the value at a nested path.

        Args:
            state_obj: The Pydantic model state object
            path: List of attribute names (or dict keys) forming a path to the
                property to retrieve; an empty path returns ``state_obj``

        Returns:
            The value at the specified path

        Raises:
            TypeError: If state_obj is not a Pydantic BaseModel
            KeyError: If the path is invalid
        """
        StateManager._require_model(state_obj)

        current: Any = state_obj
        for depth, key in enumerate(path):
            if not StateManager._has_step(current, key):
                raise KeyError(f"Invalid path: {'.'.join(path[:depth+1])}")
            current = StateManager._get_step(current, key)

        return current

    @staticmethod
    def merge(state_obj: BaseModel, **updates) -> BaseModel:
        """Merge flat updates into state and return a new state.

        Args:
            state_obj: The Pydantic model state object
            **updates: Attribute updates to apply to the top level

        Returns:
            A new state object with the updates applied (a shallow
            ``model_copy``; the updates are not re-validated)

        Raises:
            TypeError: If state_obj is not a Pydantic BaseModel
        """
        StateManager._require_model(state_obj)

        return state_obj.model_copy(update=updates)
|
161
|
+
|
162
|
+
|
163
|
+
class WorkflowStateWrapper(Generic[StateT]):
    """Wraps a state object with convenient update methods for use in workflows.

    Each update method delegates to ``StateManager`` and returns a new
    wrapper around the resulting state; the wrapped state is never mutated
    by these methods.
    """

    def __init__(self, state: StateT):
        """Initialize the state wrapper.

        Args:
            state: The Pydantic model state object to wrap
        """
        # The wrapped state; replaced (never mutated) by returning new wrappers.
        self._state = state

    def update_in(self, path: List[str], value: Any) -> "WorkflowStateWrapper[StateT]":
        """Update state at path and return new wrapper.

        Args:
            path: List of attribute names forming a path to the property to update
            value: The new value to set

        Returns:
            A new state wrapper with the update applied
        """
        new_state = StateManager.update_in(self._state, path, value)
        return WorkflowStateWrapper(new_state)

    def batch_update(
        self, updates: List[Tuple[List[str], Any]]
    ) -> "WorkflowStateWrapper[StateT]":
        """Apply multiple updates to the state atomically.

        Args:
            updates: List of (path, value) tuples with updates to apply

        Returns:
            A new state wrapper with all updates applied
        """
        new_state = StateManager.batch_update(self._state, updates)
        return WorkflowStateWrapper(new_state)

    def get_in(self, path: List[str]) -> Any:
        """Get the value at a nested path.

        Args:
            path: List of attribute names forming a path to the property to retrieve

        Returns:
            The value at the specified path
        """
        return StateManager.get_in(self._state, path)

    def merge(self, **updates) -> "WorkflowStateWrapper[StateT]":
        """Merge flat updates into state and return a new wrapper.

        Args:
            **updates: Attribute updates to apply to the top level

        Returns:
            A new state wrapper with the updates applied
        """
        new_state = StateManager.merge(self._state, **updates)
        return WorkflowStateWrapper(new_state)

    def get_state(self) -> StateT:
        """Get the wrapped state object.

        Returns:
            The current state object
        """
        return self._state

    def __repr__(self) -> str:
        """Return a debug representation built from the wrapped state's repr.

        Using ``!r`` keeps the model class name visible, whereas ``str()`` of
        a pydantic model renders only its field values.
        """
        return f"{type(self).__name__}({self._state!r})"
|