griptape-nodes 0.52.1__py3-none-any.whl → 0.54.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- griptape_nodes/__init__.py +8 -942
- griptape_nodes/__main__.py +6 -0
- griptape_nodes/app/app.py +48 -86
- griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +35 -5
- griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +15 -1
- griptape_nodes/cli/__init__.py +1 -0
- griptape_nodes/cli/commands/__init__.py +1 -0
- griptape_nodes/cli/commands/config.py +74 -0
- griptape_nodes/cli/commands/engine.py +80 -0
- griptape_nodes/cli/commands/init.py +550 -0
- griptape_nodes/cli/commands/libraries.py +96 -0
- griptape_nodes/cli/commands/models.py +504 -0
- griptape_nodes/cli/commands/self.py +120 -0
- griptape_nodes/cli/main.py +56 -0
- griptape_nodes/cli/shared.py +75 -0
- griptape_nodes/common/__init__.py +1 -0
- griptape_nodes/common/directed_graph.py +71 -0
- griptape_nodes/drivers/storage/base_storage_driver.py +40 -20
- griptape_nodes/drivers/storage/griptape_cloud_storage_driver.py +24 -29
- griptape_nodes/drivers/storage/local_storage_driver.py +23 -14
- griptape_nodes/exe_types/core_types.py +60 -2
- griptape_nodes/exe_types/node_types.py +257 -38
- griptape_nodes/exe_types/param_components/__init__.py +1 -0
- griptape_nodes/exe_types/param_components/execution_status_component.py +138 -0
- griptape_nodes/machines/control_flow.py +195 -94
- griptape_nodes/machines/dag_builder.py +207 -0
- griptape_nodes/machines/fsm.py +10 -1
- griptape_nodes/machines/parallel_resolution.py +558 -0
- griptape_nodes/machines/{node_resolution.py → sequential_resolution.py} +30 -57
- griptape_nodes/node_library/library_registry.py +34 -1
- griptape_nodes/retained_mode/events/app_events.py +5 -1
- griptape_nodes/retained_mode/events/base_events.py +9 -9
- griptape_nodes/retained_mode/events/config_events.py +30 -0
- griptape_nodes/retained_mode/events/execution_events.py +2 -2
- griptape_nodes/retained_mode/events/model_events.py +296 -0
- griptape_nodes/retained_mode/events/node_events.py +4 -3
- griptape_nodes/retained_mode/griptape_nodes.py +34 -12
- griptape_nodes/retained_mode/managers/agent_manager.py +23 -5
- griptape_nodes/retained_mode/managers/arbitrary_code_exec_manager.py +3 -1
- griptape_nodes/retained_mode/managers/config_manager.py +44 -3
- griptape_nodes/retained_mode/managers/context_manager.py +6 -5
- griptape_nodes/retained_mode/managers/event_manager.py +8 -2
- griptape_nodes/retained_mode/managers/flow_manager.py +150 -206
- griptape_nodes/retained_mode/managers/library_lifecycle/library_directory.py +1 -1
- griptape_nodes/retained_mode/managers/library_manager.py +35 -25
- griptape_nodes/retained_mode/managers/model_manager.py +1107 -0
- griptape_nodes/retained_mode/managers/node_manager.py +102 -220
- griptape_nodes/retained_mode/managers/object_manager.py +11 -5
- griptape_nodes/retained_mode/managers/os_manager.py +28 -13
- griptape_nodes/retained_mode/managers/secrets_manager.py +8 -4
- griptape_nodes/retained_mode/managers/settings.py +116 -7
- griptape_nodes/retained_mode/managers/static_files_manager.py +85 -12
- griptape_nodes/retained_mode/managers/sync_manager.py +17 -9
- griptape_nodes/retained_mode/managers/workflow_manager.py +186 -192
- griptape_nodes/retained_mode/retained_mode.py +19 -0
- griptape_nodes/servers/__init__.py +1 -0
- griptape_nodes/{mcp_server/server.py → servers/mcp.py} +1 -1
- griptape_nodes/{app/api.py → servers/static.py} +43 -40
- griptape_nodes/traits/add_param_button.py +1 -1
- griptape_nodes/traits/button.py +334 -6
- griptape_nodes/traits/color_picker.py +66 -0
- griptape_nodes/traits/multi_options.py +188 -0
- griptape_nodes/traits/numbers_selector.py +77 -0
- griptape_nodes/traits/options.py +93 -2
- griptape_nodes/traits/traits.json +4 -0
- griptape_nodes/utils/async_utils.py +31 -0
- {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/METADATA +4 -1
- {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/RECORD +71 -48
- {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/WHEEL +1 -1
- /griptape_nodes/{mcp_server → servers}/ws_request_manager.py +0 -0
- {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,558 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
from enum import StrEnum
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from griptape_nodes.exe_types.core_types import Parameter, ParameterType, ParameterTypeBuiltin
|
|
9
|
+
from griptape_nodes.exe_types.node_types import BaseNode, NodeResolutionState
|
|
10
|
+
from griptape_nodes.exe_types.type_validator import TypeValidator
|
|
11
|
+
from griptape_nodes.machines.dag_builder import NodeState
|
|
12
|
+
from griptape_nodes.machines.fsm import FSM, State
|
|
13
|
+
from griptape_nodes.node_library.library_registry import LibraryRegistry
|
|
14
|
+
from griptape_nodes.retained_mode.events.base_events import (
|
|
15
|
+
ExecutionEvent,
|
|
16
|
+
ExecutionGriptapeNodeEvent,
|
|
17
|
+
)
|
|
18
|
+
from griptape_nodes.retained_mode.events.execution_events import (
|
|
19
|
+
CurrentDataNodeEvent,
|
|
20
|
+
NodeResolvedEvent,
|
|
21
|
+
ParameterValueUpdateEvent,
|
|
22
|
+
)
|
|
23
|
+
from griptape_nodes.retained_mode.events.parameter_events import SetParameterValueRequest
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from griptape_nodes.common.directed_graph import DirectedGraph
|
|
27
|
+
from griptape_nodes.machines.dag_builder import DagBuilder, DagNode
|
|
28
|
+
from griptape_nodes.retained_mode.managers.flow_manager import FlowManager
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger("griptape_nodes")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class WorkflowState(StrEnum):
    """Workflow execution states."""

    # Normal operation; also the initial state (set in ParallelResolutionContext.__init__
    # and re-asserted in ExecuteDagState.on_enter).
    NO_ERROR = "no_error"
    # All networks have drained their leaf nodes; the run finished.
    WORKFLOW_COMPLETE = "workflow_complete"
    # A node failed (set on parameter-passthrough failure and by ErrorState).
    ERRORED = "errored"
    # Run was canceled (set by reset(cancel=True) or when every leaf node is canceled).
    CANCELED = "canceled"
42
|
+
class ParallelResolutionContext:
    """Shared mutable state for one parallel DAG resolution run.

    Holds the pause flag, error/workflow status, the concurrency semaphore,
    the running-task bookkeeping, and an optional DagBuilder that owns the
    actual graphs and node references.
    """

    paused: bool
    flow_name: str
    error_message: str | None
    workflow_state: WorkflowState
    # Execution fields
    async_semaphore: asyncio.Semaphore
    task_to_node: dict[asyncio.Task, DagNode]
    dag_builder: DagBuilder | None

    def __init__(
        self, flow_name: str, max_nodes_in_parallel: int | None = None, dag_builder: DagBuilder | None = None
    ) -> None:
        """Initialize the context.

        Args:
            flow_name: Name of the flow being resolved.
            max_nodes_in_parallel: Concurrency cap for node execution; defaults to 5 when None.
            dag_builder: Optional DAG builder supplying graphs and node references.
        """
        self.flow_name = flow_name
        self.dag_builder = dag_builder
        self.paused = False
        self.error_message = None
        self.workflow_state = WorkflowState.NO_ERROR

        # Initialize execution fields; the semaphore bounds concurrent node tasks.
        if max_nodes_in_parallel is None:
            max_nodes_in_parallel = 5
        self.async_semaphore = asyncio.Semaphore(max_nodes_in_parallel)
        self.task_to_node = {}

    @property
    def node_to_reference(self) -> dict[str, DagNode]:
        """Mapping of node name -> DagNode, delegated to the DAG builder.

        Raises:
            ValueError: If no DAG builder has been set.
        """
        builder = self.dag_builder
        if not builder:
            msg = "DagBuilder is not initialized"
            raise ValueError(msg)
        return builder.node_to_reference

    @property
    def networks(self) -> dict[str, DirectedGraph]:
        """Mapping of network name -> DirectedGraph, delegated to the DAG builder.

        Raises:
            ValueError: If no DAG builder has been set.
        """
        builder = self.dag_builder
        if not builder:
            msg = "DagBuilder is not initialized"
            raise ValueError(msg)
        return builder.graphs

    def reset(self, *, cancel: bool = False) -> None:
        """Return the context to a pristine state.

        Args:
            cancel: When True, mark the workflow CANCELED and set every known
                node's state to CANCELED; otherwise clear back to NO_ERROR.
        """
        self.paused = False
        if not cancel:
            self.workflow_state = WorkflowState.NO_ERROR
        else:
            self.workflow_state = WorkflowState.CANCELED
            # Only touch node references when a builder is present.
            if self.dag_builder:
                for dag_node in self.node_to_reference.values():
                    dag_node.node_state = NodeState.CANCELED
        self.error_message = None
        self.task_to_node.clear()

        # Clear DAG builder state to allow re-adding nodes on subsequent runs.
        if self.dag_builder:
            self.dag_builder.clear()
|
+
class ExecuteDagState(State):
    """FSM state that drives parallel execution of the DAG.

    On each update it launches every QUEUED node as an asyncio task (the
    context's semaphore bounds concurrency), waits for the first task to
    complete, publishes results for finished nodes, and extends the DAG by
    following control-flow connections out of completed nodes.
    """

    @staticmethod
    async def handle_done_nodes(context: ParallelResolutionContext, done_node: DagNode, network_name: str) -> None:
        """Publish results for a completed node and extend the control-flow graph.

        Marks the node RESOLVED, emits a ParameterValueUpdateEvent per output
        value and a NodeResolvedEvent, then calls get_next_control_graph to
        append the node's next control targets to the DAG.

        Args:
            context: The shared resolution context.
            done_node: The DAG node that finished executing.
            network_name: Name of the network the completion came from.

        Raises:
            KeyError: If an output value names a Parameter that does not exist
                on the node.
        """
        current_node = done_node.node_reference

        # Check if node was already resolved (shouldn't happen)
        if current_node.state == NodeResolutionState.RESOLVED:
            logger.error(
                "DUPLICATE COMPLETION DETECTED: Node '%s' was already RESOLVED but handle_done_nodes was called again from network '%s'. This should not happen!",
                current_node.name,
                network_name,
            )
            return

        # Publish all parameter updates.
        current_node.state = NodeResolutionState.RESOLVED
        # Serialization can be slow so only do it if the user wants debug details.
        if logger.level <= logging.DEBUG:
            logger.debug(
                "INPUTS: %s\nOUTPUTS: %s",
                TypeValidator.safe_serialize(current_node.parameter_values),
                TypeValidator.safe_serialize(current_node.parameter_output_values),
            )

        for parameter_name, value in current_node.parameter_output_values.items():
            parameter = current_node.get_parameter_by_name(parameter_name)
            if parameter is None:
                err = f"Canceling flow run. Node '{current_node.name}' specified a Parameter '{parameter_name}', but no such Parameter could be found on that Node."
                raise KeyError(err)
            data_type = parameter.type
            if data_type is None:
                data_type = ParameterTypeBuiltin.NONE.value
            # Local import — presumably avoids a circular import at module load; confirm.
            from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

            await GriptapeNodes.EventManager().aput_event(
                ExecutionGriptapeNodeEvent(
                    wrapped_event=ExecutionEvent(
                        payload=ParameterValueUpdateEvent(
                            node_name=current_node.name,
                            parameter_name=parameter_name,
                            data_type=data_type,
                            value=TypeValidator.safe_serialize(value),
                        )
                    ),
                )
            )
        # Output values should already be saved!
        # Only attach a specific library name when exactly one library provides this node type.
        library = LibraryRegistry.get_libraries_with_node_type(current_node.__class__.__name__)
        if len(library) == 1:
            library_name = library[0]
        else:
            library_name = None
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        await GriptapeNodes.EventManager().aput_event(
            ExecutionGriptapeNodeEvent(
                wrapped_event=ExecutionEvent(
                    payload=NodeResolvedEvent(
                        node_name=current_node.name,
                        parameter_output_values=TypeValidator.safe_serialize(current_node.parameter_output_values),
                        node_type=current_node.__class__.__name__,
                        specific_library_name=library_name,
                    )
                )
            )
        )
        # Now the final thing to do, is to take their directed graph and update it.
        ExecuteDagState.get_next_control_graph(context, current_node, network_name)

    @staticmethod
    def get_next_control_graph(context: ParallelResolutionContext, node: BaseNode, network_name: str) -> None:
        """Get next control flow nodes and add them to the DAG graph."""
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        flow_manager = GriptapeNodes.FlowManager()

        # Early returns for various conditions
        if ExecuteDagState._should_skip_control_flow(context, node, network_name, flow_manager):
            return

        next_output = node.get_next_control_output()
        if next_output is not None:
            ExecuteDagState._process_next_control_node(context, node, next_output, network_name, flow_manager)

    @staticmethod
    def _should_skip_control_flow(
        context: ParallelResolutionContext, node: BaseNode, network_name: str, flow_manager: FlowManager
    ) -> bool:
        """Check if control flow processing should be skipped.

        Skips when: single-node resolution is active globally, the named
        network still has unexecuted nodes, or the node requested a flow stop
        (the stop_flow flag is consumed here).
        """
        if flow_manager.global_single_node_resolution:
            return True

        if context.dag_builder is not None:
            network = context.dag_builder.graphs.get(network_name, None)
            if network is not None and len(network) > 0:
                return True

        if node.stop_flow:
            # Consume the one-shot stop request.
            node.stop_flow = False
            return True

        return False

    @staticmethod
    def _process_next_control_node(
        context: ParallelResolutionContext,
        node: BaseNode,
        next_output: Parameter,
        network_name: str,
        flow_manager: FlowManager,
    ) -> None:
        """Process the next control node in the flow.

        Follows the control connection out of `next_output`, marks the target
        node unresolved (unless locked), and adds it (plus its dependencies)
        to the DAG for queueing.
        """
        node_connection = flow_manager.get_connections().get_connected_node(node, next_output)
        if node_connection is not None:
            next_node, _ = node_connection

            # Prepare next node for execution
            if not next_node.lock:
                next_node.make_node_unresolved(
                    current_states_to_trigger_change_event=set(
                        {
                            NodeResolutionState.UNRESOLVED,
                            NodeResolutionState.RESOLVED,
                            NodeResolutionState.RESOLVING,
                        }
                    )
                )

            ExecuteDagState._add_and_queue_nodes(context, next_node, network_name)

    @staticmethod
    def _add_and_queue_nodes(context: ParallelResolutionContext, next_node: BaseNode, network_name: str) -> None:
        """Add nodes to DAG and queue them if ready."""
        if context.dag_builder is not None:
            added_nodes = context.dag_builder.add_node_with_dependencies(next_node, network_name)
            if next_node not in added_nodes:
                added_nodes.append(next_node)

            # Queue nodes that are ready for execution
            if added_nodes:
                for added_node in added_nodes:
                    ExecuteDagState._try_queue_waiting_node(context, added_node.name)

    @staticmethod
    def _try_queue_waiting_node(context: ParallelResolutionContext, node_name: str) -> None:
        """Try to queue a specific waiting node if it can now be queued."""
        if context.dag_builder is None:
            logger.warning("DAG builder is None - cannot check queueing for node '%s'", node_name)
            return

        if node_name not in context.node_to_reference:
            logger.warning("Node '%s' not found in node_to_reference - cannot check queueing", node_name)
            return

        dag_node = context.node_to_reference[node_name]

        # Only check nodes that are currently waiting
        if dag_node.node_state == NodeState.WAITING:
            can_queue = context.dag_builder.can_queue_control_node(dag_node)
            if can_queue:
                dag_node.node_state = NodeState.QUEUED

    @staticmethod
    async def collect_values_from_upstream_nodes(node_reference: DagNode) -> None:
        """Collect output values from resolved upstream nodes and pass them to the current node.

        This method iterates through all input parameters of the current node, finds their
        connected upstream nodes, and if those nodes are resolved, retrieves their output
        values and passes them through using SetParameterValueRequest.

        Args:
            node_reference (DagOrchestrator.DagNode): The node to collect values for.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        current_node = node_reference.node_reference
        connections = GriptapeNodes.FlowManager().get_connections()

        for parameter in current_node.parameters:
            # Skip control type parameters
            if ParameterTypeBuiltin.CONTROL_TYPE.value.lower() == parameter.output_type:
                continue

            # Get the connected upstream node for this parameter
            upstream_connection = connections.get_connected_node(current_node, parameter)
            if upstream_connection:
                upstream_node, upstream_parameter = upstream_connection

                # If the upstream node is resolved, collect its output value;
                # otherwise fall back to its stored parameter value.
                if upstream_parameter.name in upstream_node.parameter_output_values:
                    output_value = upstream_node.parameter_output_values[upstream_parameter.name]
                else:
                    output_value = upstream_node.get_parameter_value(upstream_parameter.name)

                # Pass the value through using the same mechanism as normal resolution
                # Skip propagation for Control Parameters as they should not receive values
                if (
                    ParameterType.attempt_get_builtin(upstream_parameter.output_type)
                    != ParameterTypeBuiltin.CONTROL_TYPE
                ):
                    await GriptapeNodes.get_instance().ahandle_request(
                        SetParameterValueRequest(
                            parameter_name=parameter.name,
                            node_name=current_node.name,
                            value=output_value,
                            data_type=upstream_parameter.output_type,
                            incoming_connection_source_node_name=upstream_node.name,
                            incoming_connection_source_parameter_name=upstream_parameter.name,
                        )
                    )

    @staticmethod
    def build_node_states(context: ParallelResolutionContext) -> tuple[set[str], set[str], set[str]]:
        """Classify the current leaf nodes across all networks.

        Returns:
            A tuple of (canceled_nodes, queued_nodes, leaf_nodes) node-name sets.
            Locked leaf nodes are marked DONE as a side effect so they skip execution.
        """
        networks = context.networks
        leaf_nodes = set()
        for network in networks.values():
            # Check and see if there are leaf nodes that are cancelled.
            # Reinitialize leaf nodes since maybe we changed things up.
            # We removed nodes from the network. There may be new leaf nodes.
            # Add all leaf nodes from all networks (using set union to avoid duplicates)
            leaf_nodes.update([n for n in network.nodes() if network.in_degree(n) == 0])
        canceled_nodes = set()
        queued_nodes = set()
        for node in leaf_nodes:
            node_reference = context.node_to_reference[node]
            # If the node is locked, mark it as done so it skips execution
            if node_reference.node_reference.lock:
                node_reference.node_state = NodeState.DONE
                continue
            node_state = node_reference.node_state
            if node_state == NodeState.CANCELED:
                canceled_nodes.add(node)
            elif node_state == NodeState.QUEUED:
                queued_nodes.add(node)
        return canceled_nodes, queued_nodes, leaf_nodes

    @staticmethod
    async def pop_done_states(context: ParallelResolutionContext) -> None:
        """Remove DONE (or locked) leaf nodes from every network and publish their results.

        Each completed node is handled at most once even if it appears in
        multiple networks; afterwards, newly exposed leaf nodes are checked
        for queueing.
        """
        networks = context.networks
        handled_nodes = set()  # Track nodes we've already processed to avoid duplicates

        for network_name, network in networks.items():
            # Check and see if there are leaf nodes that are cancelled.
            # Reinitialize leaf nodes since maybe we changed things up.
            # We removed nodes from the network. There may be new leaf nodes.
            leaf_nodes = [n for n in network.nodes() if network.in_degree(n) == 0]
            for node in leaf_nodes:
                node_reference = context.node_to_reference[node]
                node_state = node_reference.node_state
                # If the node is locked, mark it as done so it skips execution
                if node_reference.node_reference.lock or node_state == NodeState.DONE:
                    node_reference.node_state = NodeState.DONE
                    network.remove_node(node)

                    # Only call handle_done_nodes once per node (first network that processes it)
                    if node not in handled_nodes:
                        handled_nodes.add(node)
                        await ExecuteDagState.handle_done_nodes(context, context.node_to_reference[node], network_name)

            # After processing completions in this network, check if any remaining leaf nodes can now be queued
            remaining_leaf_nodes = [n for n in network.nodes() if network.in_degree(n) == 0]

            for leaf_node in remaining_leaf_nodes:
                if leaf_node in context.node_to_reference:
                    # NOTE(review): node_state is assigned but never used here — candidate for removal.
                    node_state = context.node_to_reference[leaf_node].node_state
                    ExecuteDagState._try_queue_waiting_node(context, leaf_node)

    @staticmethod
    async def execute_node(current_node: DagNode, semaphore: asyncio.Semaphore) -> None:
        """Run a single node's async process under the shared concurrency semaphore."""
        async with semaphore:
            await current_node.node_reference.aprocess()

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        """Queue all waiting nodes and begin execution unless paused."""
        # Start DAG execution after resolution is complete
        for node in context.node_to_reference.values():
            # Only queue nodes that are waiting - preserve state of already processed nodes.
            if node.node_state == NodeState.WAITING:
                node.node_state = NodeState.QUEUED

        context.workflow_state = WorkflowState.NO_ERROR

        if not context.paused:
            return ExecuteDagState
        return None

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:  # noqa: C901, PLR0911
        """Launch queued nodes, wait for the first completion, and re-enter this state.

        Returns DagCompleteState when the DAG drains (or is fully canceled),
        ErrorState on a failure, None when paused, or ExecuteDagState to loop.
        """
        # Check if execution is paused
        if context.paused:
            return None

        # Check if DAG execution is complete
        # Check and see if there are leaf nodes that are cancelled.
        # Reinitialize leaf nodes since maybe we changed things up.
        # We removed nodes from the network. There may be new leaf nodes.
        canceled_nodes, queued_nodes, leaf_nodes = ExecuteDagState.build_node_states(context)
        # We have no more leaf nodes. Quit early.
        if not leaf_nodes:
            context.workflow_state = WorkflowState.WORKFLOW_COMPLETE
            return DagCompleteState
        if len(canceled_nodes) == len(leaf_nodes):
            # All leaf nodes are cancelled.
            # Set state to workflow complete.
            context.workflow_state = WorkflowState.CANCELED
            return DagCompleteState
        # Are there any in the queued state?
        for node in queued_nodes:
            # Process all queued nodes - the async semaphore will handle concurrency limits
            node_reference = context.node_to_reference[node]

            # Collect parameter values from upstream nodes before executing
            try:
                await ExecuteDagState.collect_values_from_upstream_nodes(node_reference)
            except Exception as e:
                logger.exception("Error collecting parameter values for node '%s'", node_reference.node_reference.name)
                context.error_message = (
                    f"Parameter passthrough failed for node '{node_reference.node_reference.name}': {e}"
                )
                context.workflow_state = WorkflowState.ERRORED
                return ErrorState

            # Clear all of the current output values but don't broadcast the clearing.
            # to avoid any flickering in subscribers (UI).
            node_reference.node_reference.parameter_output_values.silent_clear()
            exceptions = node_reference.node_reference.validate_before_node_run()
            if exceptions:
                msg = f"Canceling flow run. Node '{node_reference.node_reference.name}' encountered problems: {exceptions}"
                logger.error(msg)
                # NOTE(review): unlike the passthrough-failure branch above, this path does not
                # set context.error_message / context.workflow_state before returning — confirm intended.
                return ErrorState

            def on_task_done(task: asyncio.Task) -> None:
                # Done-callback: flip the owning DAG node to DONE when its task finishes.
                if task in context.task_to_node:
                    node = context.task_to_node[task]
                    node.node_state = NodeState.DONE

            # Execute the node asynchronously
            logger.debug(
                "CREATING EXECUTION TASK for node '%s' - this should only happen once per node!",
                node_reference.node_reference.name,
            )
            node_task = asyncio.create_task(ExecuteDagState.execute_node(node_reference, context.async_semaphore))
            # Add a callback to set node to done when task has finished.
            context.task_to_node[node_task] = node_reference
            node_reference.task_reference = node_task
            node_task.add_done_callback(lambda t: on_task_done(t))
            node_reference.node_state = NodeState.PROCESSING
            node_reference.node_reference.state = NodeResolutionState.RESOLVING

            # Send an event that this is a current data node:
            from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

            await GriptapeNodes.EventManager().aput_event(
                ExecutionGriptapeNodeEvent(wrapped_event=ExecutionEvent(payload=CurrentDataNodeEvent(node_name=node)))
            )
        # Wait for a task to finish
        done, _ = await asyncio.wait(context.task_to_node.keys(), return_when=asyncio.FIRST_COMPLETED)
        # Prevent task being removed before return
        for task in done:
            context.task_to_node.pop(task)
        # Once a task has finished, loop back to the top.
        await ExecuteDagState.pop_done_states(context)
        # Remove all nodes that are done
        if context.paused:
            return None
        return ExecuteDagState
|
+
class ErrorState(State):
    """Failure state: cancels all outstanding work, then drains remaining tasks.

    Entered when a node errors during DAG execution. on_enter requests
    cancellation; on_update repeatedly drains finished/cancelled tasks until
    none remain, then transitions to DagCompleteState with state ERRORED.
    """

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        """Log the failure, cancel queued nodes, and request cancellation of in-flight tasks."""
        if context.error_message:
            logger.error("DAG execution error: %s", context.error_message)
        # Cancel all nodes that haven't yet begun processing.
        for node in context.node_to_reference.values():
            if node.node_state == NodeState.QUEUED:
                node.node_state = NodeState.CANCELED
        # Request cancellation of tasks that haven't finished. Currently running
        # tasks receive CancelledError at their next await point; already-done
        # tasks are left alone.
        for task in list(context.task_to_node.keys()):
            if not task.done():
                task.cancel()
        return ErrorState

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        """Drain completed/cancelled tasks; finish once the task table is empty."""
        # Don't modify the dict while iterating it — iterate a copy.
        task_to_node = context.task_to_node
        for task, node in task_to_node.copy().items():
            # BUGFIX: a cancelled asyncio Task also reports done() == True, so the
            # cancelled check must come first; otherwise cancelled nodes were
            # incorrectly marked DONE and the CANCELED branch was unreachable.
            if task.cancelled():
                node.node_state = NodeState.CANCELED
            elif task.done():
                node.node_state = NodeState.DONE
            task_to_node.pop(task)
        # (A second, verbatim copy of the drain loop was removed here — it was a
        # copy-paste duplicate operating on the same dict.)

        if len(task_to_node) == 0:
            # Finish up. We failed.
            context.workflow_state = WorkflowState.ERRORED
            context.networks.clear()
            context.node_to_reference.clear()
            context.task_to_node.clear()
            return DagCompleteState
        # Let's continue going through until everything is cancelled.
        return ErrorState
|
+
class DagCompleteState(State):
    """Terminal state reached when the DAG run ends (complete, errored, or canceled)."""

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        """Drop leftover DAG-builder state so node_to_reference starts clean next run."""
        builder = context.dag_builder
        if builder is not None:
            builder.clear()
        return None

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:  # noqa: ARG004
        """Terminal state: no transitions."""
        return None
|
+
class ParallelResolutionMachine(FSM[ParallelResolutionContext]):
    """State machine for building DAG structure without execution."""

    def __init__(
        self, flow_name: str, max_nodes_in_parallel: int | None = None, dag_builder: DagBuilder | None = None
    ) -> None:
        """Create the machine with a fresh ParallelResolutionContext.

        Args:
            flow_name: Name of the flow to resolve.
            max_nodes_in_parallel: Concurrency cap forwarded to the context.
            dag_builder: Optional DAG builder; resolved lazily in resolve_node if omitted.
        """
        ctx = ParallelResolutionContext(flow_name, max_nodes_in_parallel=max_nodes_in_parallel, dag_builder=dag_builder)
        super().__init__(ctx)

    async def resolve_node(self, node: BaseNode | None = None) -> None:  # noqa: ARG002
        """Execute the DAG structure using the existing DagBuilder."""
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        # Fall back to the flow manager's global builder when none was supplied.
        if self.context.dag_builder is None:
            self.context.dag_builder = GriptapeNodes.FlowManager().global_dag_builder
        await self.start(ExecuteDagState)

    def change_debug_mode(self, *, debug_mode: bool) -> None:
        """Pause (debug_mode=True) or resume (False) execution via the context flag."""
        self._context.paused = debug_mode

    def is_complete(self) -> bool:
        """Return True once the machine has reached the terminal state."""
        return self._current_state is DagCompleteState

    def is_started(self) -> bool:
        """Return True if the machine has entered any state."""
        return self._current_state is not None

    def reset_machine(self, *, cancel: bool = False) -> None:
        """Reset the context (optionally cancelling) and clear the current state."""
        self._context.reset(cancel=cancel)
        self._current_state = None