griptape-nodes 0.52.1__py3-none-any.whl → 0.54.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. griptape_nodes/__init__.py +8 -942
  2. griptape_nodes/__main__.py +6 -0
  3. griptape_nodes/app/app.py +48 -86
  4. griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +35 -5
  5. griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +15 -1
  6. griptape_nodes/cli/__init__.py +1 -0
  7. griptape_nodes/cli/commands/__init__.py +1 -0
  8. griptape_nodes/cli/commands/config.py +74 -0
  9. griptape_nodes/cli/commands/engine.py +80 -0
  10. griptape_nodes/cli/commands/init.py +550 -0
  11. griptape_nodes/cli/commands/libraries.py +96 -0
  12. griptape_nodes/cli/commands/models.py +504 -0
  13. griptape_nodes/cli/commands/self.py +120 -0
  14. griptape_nodes/cli/main.py +56 -0
  15. griptape_nodes/cli/shared.py +75 -0
  16. griptape_nodes/common/__init__.py +1 -0
  17. griptape_nodes/common/directed_graph.py +71 -0
  18. griptape_nodes/drivers/storage/base_storage_driver.py +40 -20
  19. griptape_nodes/drivers/storage/griptape_cloud_storage_driver.py +24 -29
  20. griptape_nodes/drivers/storage/local_storage_driver.py +23 -14
  21. griptape_nodes/exe_types/core_types.py +60 -2
  22. griptape_nodes/exe_types/node_types.py +257 -38
  23. griptape_nodes/exe_types/param_components/__init__.py +1 -0
  24. griptape_nodes/exe_types/param_components/execution_status_component.py +138 -0
  25. griptape_nodes/machines/control_flow.py +195 -94
  26. griptape_nodes/machines/dag_builder.py +207 -0
  27. griptape_nodes/machines/fsm.py +10 -1
  28. griptape_nodes/machines/parallel_resolution.py +558 -0
  29. griptape_nodes/machines/{node_resolution.py → sequential_resolution.py} +30 -57
  30. griptape_nodes/node_library/library_registry.py +34 -1
  31. griptape_nodes/retained_mode/events/app_events.py +5 -1
  32. griptape_nodes/retained_mode/events/base_events.py +9 -9
  33. griptape_nodes/retained_mode/events/config_events.py +30 -0
  34. griptape_nodes/retained_mode/events/execution_events.py +2 -2
  35. griptape_nodes/retained_mode/events/model_events.py +296 -0
  36. griptape_nodes/retained_mode/events/node_events.py +4 -3
  37. griptape_nodes/retained_mode/griptape_nodes.py +34 -12
  38. griptape_nodes/retained_mode/managers/agent_manager.py +23 -5
  39. griptape_nodes/retained_mode/managers/arbitrary_code_exec_manager.py +3 -1
  40. griptape_nodes/retained_mode/managers/config_manager.py +44 -3
  41. griptape_nodes/retained_mode/managers/context_manager.py +6 -5
  42. griptape_nodes/retained_mode/managers/event_manager.py +8 -2
  43. griptape_nodes/retained_mode/managers/flow_manager.py +150 -206
  44. griptape_nodes/retained_mode/managers/library_lifecycle/library_directory.py +1 -1
  45. griptape_nodes/retained_mode/managers/library_manager.py +35 -25
  46. griptape_nodes/retained_mode/managers/model_manager.py +1107 -0
  47. griptape_nodes/retained_mode/managers/node_manager.py +102 -220
  48. griptape_nodes/retained_mode/managers/object_manager.py +11 -5
  49. griptape_nodes/retained_mode/managers/os_manager.py +28 -13
  50. griptape_nodes/retained_mode/managers/secrets_manager.py +8 -4
  51. griptape_nodes/retained_mode/managers/settings.py +116 -7
  52. griptape_nodes/retained_mode/managers/static_files_manager.py +85 -12
  53. griptape_nodes/retained_mode/managers/sync_manager.py +17 -9
  54. griptape_nodes/retained_mode/managers/workflow_manager.py +186 -192
  55. griptape_nodes/retained_mode/retained_mode.py +19 -0
  56. griptape_nodes/servers/__init__.py +1 -0
  57. griptape_nodes/{mcp_server/server.py → servers/mcp.py} +1 -1
  58. griptape_nodes/{app/api.py → servers/static.py} +43 -40
  59. griptape_nodes/traits/add_param_button.py +1 -1
  60. griptape_nodes/traits/button.py +334 -6
  61. griptape_nodes/traits/color_picker.py +66 -0
  62. griptape_nodes/traits/multi_options.py +188 -0
  63. griptape_nodes/traits/numbers_selector.py +77 -0
  64. griptape_nodes/traits/options.py +93 -2
  65. griptape_nodes/traits/traits.json +4 -0
  66. griptape_nodes/utils/async_utils.py +31 -0
  67. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/METADATA +4 -1
  68. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/RECORD +71 -48
  69. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/WHEEL +1 -1
  70. /griptape_nodes/{mcp_server → servers}/ws_request_manager.py +0 -0
  71. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,558 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ from enum import StrEnum
6
+ from typing import TYPE_CHECKING
7
+
8
+ from griptape_nodes.exe_types.core_types import Parameter, ParameterType, ParameterTypeBuiltin
9
+ from griptape_nodes.exe_types.node_types import BaseNode, NodeResolutionState
10
+ from griptape_nodes.exe_types.type_validator import TypeValidator
11
+ from griptape_nodes.machines.dag_builder import NodeState
12
+ from griptape_nodes.machines.fsm import FSM, State
13
+ from griptape_nodes.node_library.library_registry import LibraryRegistry
14
+ from griptape_nodes.retained_mode.events.base_events import (
15
+ ExecutionEvent,
16
+ ExecutionGriptapeNodeEvent,
17
+ )
18
+ from griptape_nodes.retained_mode.events.execution_events import (
19
+ CurrentDataNodeEvent,
20
+ NodeResolvedEvent,
21
+ ParameterValueUpdateEvent,
22
+ )
23
+ from griptape_nodes.retained_mode.events.parameter_events import SetParameterValueRequest
24
+
25
+ if TYPE_CHECKING:
26
+ from griptape_nodes.common.directed_graph import DirectedGraph
27
+ from griptape_nodes.machines.dag_builder import DagBuilder, DagNode
28
+ from griptape_nodes.retained_mode.managers.flow_manager import FlowManager
29
+
30
# Logger used by every state in this module; it is named after the package, not the module.
logger = logging.getLogger(name="griptape_nodes")
31
+
32
+
33
+ class WorkflowState(StrEnum):
34
+ """Workflow execution states."""
35
+
36
+ NO_ERROR = "no_error"
37
+ WORKFLOW_COMPLETE = "workflow_complete"
38
+ ERRORED = "errored"
39
+ CANCELED = "canceled"
40
+
41
+
42
class ParallelResolutionContext:
    """Mutable state shared by the parallel-resolution FSM states during one flow run."""

    # Concurrency limit used when the caller passes max_nodes_in_parallel=None.
    DEFAULT_MAX_NODES_IN_PARALLEL = 5

    paused: bool
    flow_name: str
    error_message: str | None
    workflow_state: WorkflowState
    # Execution fields
    async_semaphore: asyncio.Semaphore
    task_to_node: dict[asyncio.Task, DagNode]
    dag_builder: DagBuilder | None

    def __init__(
        self, flow_name: str, max_nodes_in_parallel: int | None = None, dag_builder: DagBuilder | None = None
    ) -> None:
        """Create a fresh, unpaused context with no recorded error.

        Args:
            flow_name: Name of the flow being resolved.
            max_nodes_in_parallel: Maximum number of node tasks allowed to hold the
                semaphore at once; DEFAULT_MAX_NODES_IN_PARALLEL when None.
            dag_builder: Builder that owns the DAG networks; may be None until attached.
        """
        self.flow_name = flow_name
        self.paused = False
        self.error_message = None
        self.workflow_state = WorkflowState.NO_ERROR
        self.dag_builder = dag_builder

        # Initialize execution fields
        if max_nodes_in_parallel is None:
            max_nodes_in_parallel = self.DEFAULT_MAX_NODES_IN_PARALLEL
        self.async_semaphore = asyncio.Semaphore(max_nodes_in_parallel)
        self.task_to_node = {}

    @property
    def node_to_reference(self) -> dict[str, DagNode]:
        """Return the DagBuilder's mapping of node name to DagNode.

        Raises:
            ValueError: If no DagBuilder is attached.
        """
        if not self.dag_builder:
            msg = "DagBuilder is not initialized"
            raise ValueError(msg)
        return self.dag_builder.node_to_reference

    @property
    def networks(self) -> dict[str, DirectedGraph]:
        """Return the DagBuilder's DAG networks (graphs), keyed by network name.

        Raises:
            ValueError: If no DagBuilder is attached.
        """
        if not self.dag_builder:
            msg = "DagBuilder is not initialized"
            raise ValueError(msg)
        return self.dag_builder.graphs

    def reset(self, *, cancel: bool = False) -> None:
        """Return the context to a reusable state and clear the DagBuilder.

        Args:
            cancel: When True, mark the workflow and every known DAG node CANCELED
                (error_message and task_to_node are left untouched). When False,
                reset the workflow state, error message and task map.
        """
        self.paused = False
        if cancel:
            self.workflow_state = WorkflowState.CANCELED
            # Only access node_to_reference if dag_builder exists (the property raises otherwise).
            if self.dag_builder:
                for node in self.node_to_reference.values():
                    node.node_state = NodeState.CANCELED
        else:
            self.workflow_state = WorkflowState.NO_ERROR
            self.error_message = None
            self.task_to_node.clear()

        # Clear DAG builder state to allow re-adding nodes on subsequent runs
        if self.dag_builder:
            self.dag_builder.clear()
98
+
99
+
100
class ExecuteDagState(State):
    """FSM state that runs the DAG.

    Each update starts every queued leaf node as an asyncio task and waits for at least
    one running task to finish. It then removes finished nodes from the DAG networks and
    queues whatever they unblocked.
    """

    @staticmethod
    async def handle_done_nodes(context: ParallelResolutionContext, done_node: DagNode, network_name: str) -> None:
        """Mark a finished node RESOLVED, publish its outputs, and schedule its control successor.

        Args:
            context: Resolution context of the running flow.
            done_node: DAG node whose execution task has completed.
            network_name: Name of the DAG network the node was removed from.

        Raises:
            KeyError: If the node reports an output value for a parameter it does not define.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        current_node = done_node.node_reference

        # A node should complete only once; a second completion indicates a scheduling bug.
        if current_node.state == NodeResolutionState.RESOLVED:
            logger.error(
                "DUPLICATE COMPLETION DETECTED: Node '%s' was already RESOLVED but handle_done_nodes was called again from network '%s'. This should not happen!",
                current_node.name,
                network_name,
            )
            return

        current_node.state = NodeResolutionState.RESOLVED
        # Serialization can be slow, so only do it when a DEBUG record would actually be emitted.
        # isEnabledFor() uses the effective level; `logger.level` is NOTSET (0) unless set
        # explicitly, so comparing it directly would pass even with DEBUG output disabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "INPUTS: %s\nOUTPUTS: %s",
                TypeValidator.safe_serialize(current_node.parameter_values),
                TypeValidator.safe_serialize(current_node.parameter_output_values),
            )

        # Publish one update event per output parameter.
        for parameter_name, value in current_node.parameter_output_values.items():
            parameter = current_node.get_parameter_by_name(parameter_name)
            if parameter is None:
                err = f"Canceling flow run. Node '{current_node.name}' specified a Parameter '{parameter_name}', but no such Parameter could be found on that Node."
                raise KeyError(err)
            data_type = parameter.type
            if data_type is None:
                data_type = ParameterTypeBuiltin.NONE.value
            await GriptapeNodes.EventManager().aput_event(
                ExecutionGriptapeNodeEvent(
                    wrapped_event=ExecutionEvent(
                        payload=ParameterValueUpdateEvent(
                            node_name=current_node.name,
                            parameter_name=parameter_name,
                            data_type=data_type,
                            value=TypeValidator.safe_serialize(value),
                        )
                    ),
                )
            )

        # Output values should already be saved at this point.
        # A library is only attributed when exactly one registered library provides this node type.
        libraries = LibraryRegistry.get_libraries_with_node_type(current_node.__class__.__name__)
        library_name = libraries[0] if len(libraries) == 1 else None

        await GriptapeNodes.EventManager().aput_event(
            ExecutionGriptapeNodeEvent(
                wrapped_event=ExecutionEvent(
                    payload=NodeResolvedEvent(
                        node_name=current_node.name,
                        parameter_output_values=TypeValidator.safe_serialize(current_node.parameter_output_values),
                        node_type=current_node.__class__.__name__,
                        specific_library_name=library_name,
                    )
                )
            )
        )
        # Finally, advance this node's control flow within its network.
        ExecuteDagState.get_next_control_graph(context, current_node, network_name)

    @staticmethod
    def get_next_control_graph(context: ParallelResolutionContext, node: BaseNode, network_name: str) -> None:
        """Get next control flow nodes and add them to the DAG graph."""
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        flow_manager = GriptapeNodes.FlowManager()

        # Early returns for various conditions
        if ExecuteDagState._should_skip_control_flow(context, node, network_name, flow_manager):
            return

        next_output = node.get_next_control_output()
        if next_output is not None:
            ExecuteDagState._process_next_control_node(context, node, next_output, network_name, flow_manager)

    @staticmethod
    def _should_skip_control_flow(
        context: ParallelResolutionContext, node: BaseNode, network_name: str, flow_manager: FlowManager
    ) -> bool:
        """Return True when control flow should not advance past ``node``.

        Skips in single-node resolution mode, while ``network_name`` still has pending nodes,
        or when the node requested a stop (the stop_flow flag is consumed here).
        """
        if flow_manager.global_single_node_resolution:
            return True

        if context.dag_builder is not None:
            network = context.dag_builder.graphs.get(network_name, None)
            if network is not None and len(network) > 0:
                return True

        if node.stop_flow:
            # One-shot flag: reset it so the next run is not affected.
            node.stop_flow = False
            return True

        return False

    @staticmethod
    def _process_next_control_node(
        context: ParallelResolutionContext,
        node: BaseNode,
        next_output: Parameter,
        network_name: str,
        flow_manager: FlowManager,
    ) -> None:
        """Add the node connected to ``next_output`` (if any) to the DAG and try to queue it."""
        node_connection = flow_manager.get_connections().get_connected_node(node, next_output)
        if node_connection is None:
            return
        next_node, _ = node_connection

        # Locked nodes are not reset to unresolved; all others are, so they execute again.
        if not next_node.lock:
            next_node.make_node_unresolved(
                current_states_to_trigger_change_event={
                    NodeResolutionState.UNRESOLVED,
                    NodeResolutionState.RESOLVED,
                    NodeResolutionState.RESOLVING,
                }
            )

        ExecuteDagState._add_and_queue_nodes(context, next_node, network_name)

    @staticmethod
    def _add_and_queue_nodes(context: ParallelResolutionContext, next_node: BaseNode, network_name: str) -> None:
        """Add ``next_node`` and its dependencies to the DAG, then queue those that are ready."""
        if context.dag_builder is None:
            return
        # Copy so appending never mutates whatever list the DagBuilder returned.
        added_nodes = list(context.dag_builder.add_node_with_dependencies(next_node, network_name))
        if next_node not in added_nodes:
            added_nodes.append(next_node)

        # Queue nodes that are ready for execution
        for added_node in added_nodes:
            ExecuteDagState._try_queue_waiting_node(context, added_node.name)

    @staticmethod
    def _try_queue_waiting_node(context: ParallelResolutionContext, node_name: str) -> None:
        """Move a WAITING node to QUEUED if the DagBuilder says it can run now."""
        if context.dag_builder is None:
            logger.warning("DAG builder is None - cannot check queueing for node '%s'", node_name)
            return

        if node_name not in context.node_to_reference:
            logger.warning("Node '%s' not found in node_to_reference - cannot check queueing", node_name)
            return

        dag_node = context.node_to_reference[node_name]

        # Only check nodes that are currently waiting
        if dag_node.node_state == NodeState.WAITING and context.dag_builder.can_queue_control_node(dag_node):
            dag_node.node_state = NodeState.QUEUED

    @staticmethod
    async def collect_values_from_upstream_nodes(node_reference: DagNode) -> None:
        """Collect output values from upstream nodes and pass them to the current node.

        For every non-control parameter with an incoming connection, take the upstream
        node's output value (falling back to its parameter value). Then forward it with a
        SetParameterValueRequest, the same mechanism normal resolution uses.

        Args:
            node_reference: The DAG node to collect values for.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        current_node = node_reference.node_reference
        connections = GriptapeNodes.FlowManager().get_connections()

        for parameter in current_node.parameters:
            # Skip control type parameters
            if ParameterTypeBuiltin.CONTROL_TYPE.value.lower() == parameter.output_type:
                continue

            # Get the connected upstream node for this parameter
            upstream_connection = connections.get_connected_node(current_node, parameter)
            if not upstream_connection:
                continue
            upstream_node, upstream_parameter = upstream_connection

            # Prefer the value the upstream node produced; otherwise use its parameter value.
            if upstream_parameter.name in upstream_node.parameter_output_values:
                output_value = upstream_node.parameter_output_values[upstream_parameter.name]
            else:
                output_value = upstream_node.get_parameter_value(upstream_parameter.name)

            # Skip propagation from Control Parameters as they should not carry values.
            if ParameterType.attempt_get_builtin(upstream_parameter.output_type) == ParameterTypeBuiltin.CONTROL_TYPE:
                continue
            await GriptapeNodes.get_instance().ahandle_request(
                SetParameterValueRequest(
                    parameter_name=parameter.name,
                    node_name=current_node.name,
                    value=output_value,
                    data_type=upstream_parameter.output_type,
                    incoming_connection_source_node_name=upstream_node.name,
                    incoming_connection_source_parameter_name=upstream_parameter.name,
                )
            )

    @staticmethod
    def build_node_states(context: ParallelResolutionContext) -> tuple[set[str], set[str], set[str]]:
        """Classify the current leaf nodes of all networks.

        Locked leaf nodes are marked DONE so they skip execution.

        Returns:
            A tuple ``(canceled_nodes, queued_nodes, leaf_nodes)`` of node names. Here
            ``leaf_nodes`` is the union of zero in-degree nodes across all networks.
        """
        leaf_nodes: set[str] = set()
        for network in context.networks.values():
            # Removing finished nodes can expose new leaves, so recompute them on every pass.
            leaf_nodes.update(n for n in network.nodes() if network.in_degree(n) == 0)

        canceled_nodes: set[str] = set()
        queued_nodes: set[str] = set()
        for node in leaf_nodes:
            node_reference = context.node_to_reference[node]
            # If the node is locked, mark it as done so it skips execution
            if node_reference.node_reference.lock:
                node_reference.node_state = NodeState.DONE
                continue
            if node_reference.node_state == NodeState.CANCELED:
                canceled_nodes.add(node)
            elif node_reference.node_state == NodeState.QUEUED:
                queued_nodes.add(node)
        return canceled_nodes, queued_nodes, leaf_nodes

    @staticmethod
    async def pop_done_states(context: ParallelResolutionContext) -> None:
        """Remove DONE (or locked) leaf nodes from every network and queue newly unblocked leaves.

        ``handle_done_nodes`` runs once per node, even if the node appears in several networks.
        """
        handled_nodes: set[str] = set()

        # Snapshot the items: handling a finished node may add nodes to the DAG.
        for network_name, network in list(context.networks.items()):
            leaf_nodes = [n for n in network.nodes() if network.in_degree(n) == 0]
            for node in leaf_nodes:
                node_reference = context.node_to_reference[node]
                # If the node is locked, mark it as done so it skips execution
                if node_reference.node_reference.lock or node_reference.node_state == NodeState.DONE:
                    node_reference.node_state = NodeState.DONE
                    network.remove_node(node)

                    # Only call handle_done_nodes once per node (first network that processes it)
                    if node not in handled_nodes:
                        handled_nodes.add(node)
                        await ExecuteDagState.handle_done_nodes(context, context.node_to_reference[node], network_name)

            # After processing completions in this network, check if any remaining leaf nodes can now be queued
            for leaf_node in [n for n in network.nodes() if network.in_degree(n) == 0]:
                if leaf_node in context.node_to_reference:
                    ExecuteDagState._try_queue_waiting_node(context, leaf_node)

    @staticmethod
    async def execute_node(current_node: DagNode, semaphore: asyncio.Semaphore) -> None:
        """Run the node's ``aprocess()`` while holding a slot of the concurrency semaphore."""
        async with semaphore:
            await current_node.node_reference.aprocess()

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        """Queue every WAITING node and begin execution unless the run is paused."""
        # Only queue nodes that are waiting - preserve state of already processed nodes.
        for node in context.node_to_reference.values():
            if node.node_state == NodeState.WAITING:
                node.node_state = NodeState.QUEUED

        context.workflow_state = WorkflowState.NO_ERROR

        if not context.paused:
            return ExecuteDagState
        return None

    @staticmethod
    async def _start_queued_node(context: ParallelResolutionContext, node_name: str) -> type[State] | None:
        """Pass upstream values into a queued node, validate it, and launch its execution task.

        Returns:
            ErrorState if value passthrough or validation fails, otherwise None.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        dag_node = context.node_to_reference[node_name]
        node = dag_node.node_reference

        # Collect parameter values from upstream nodes before executing
        try:
            await ExecuteDagState.collect_values_from_upstream_nodes(dag_node)
        except Exception as e:
            logger.exception("Error collecting parameter values for node '%s'", node.name)
            context.error_message = f"Parameter passthrough failed for node '{node.name}': {e}"
            context.workflow_state = WorkflowState.ERRORED
            return ErrorState

        # Clear the previous outputs without broadcasting, to avoid flicker in subscribers (UI).
        node.parameter_output_values.silent_clear()
        exceptions = node.validate_before_node_run()
        if exceptions:
            msg = f"Canceling flow run. Node '{node.name}' encountered problems: {exceptions}"
            logger.error(msg)
            # Record the failure like the passthrough branch does, so callers can see why the run stopped.
            context.error_message = msg
            context.workflow_state = WorkflowState.ERRORED
            return ErrorState

        def on_task_done(task: asyncio.Task) -> None:
            # Done callback: flag the node DONE so the next pop_done_states() pass removes it.
            finished = context.task_to_node.get(task)
            if finished is not None:
                finished.node_state = NodeState.DONE

        logger.debug(
            "CREATING EXECUTION TASK for node '%s' - this should only happen once per node!",
            node.name,
        )
        node_task = asyncio.create_task(ExecuteDagState.execute_node(dag_node, context.async_semaphore))
        context.task_to_node[node_task] = dag_node
        dag_node.task_reference = node_task
        node_task.add_done_callback(on_task_done)
        dag_node.node_state = NodeState.PROCESSING
        node.state = NodeResolutionState.RESOLVING

        # Send an event that this is a current data node.
        await GriptapeNodes.EventManager().aput_event(
            ExecutionGriptapeNodeEvent(wrapped_event=ExecutionEvent(payload=CurrentDataNodeEvent(node_name=node_name)))
        )
        return None

    @staticmethod
    async def _wait_for_running_tasks(context: ParallelResolutionContext, leaf_nodes: set[str]) -> type[State] | None:
        """Wait for at least one running node task to finish and drop finished tasks from the map.

        Returns:
            ErrorState if a finished task raised, or if nothing is running and nothing is
            ready to be popped; otherwise None.
        """
        if not context.task_to_node:
            # asyncio.wait() rejects an empty collection. With nothing running, progress is
            # only possible if some leaf is already DONE (e.g. a locked node) and can be popped.
            if any(context.node_to_reference[name].node_state == NodeState.DONE for name in leaf_nodes):
                return None
            context.error_message = (
                f"Flow '{context.flow_name}' cannot make progress: no nodes are running or done. "
                f"Remaining leaf nodes: {sorted(leaf_nodes)}"
            )
            context.workflow_state = WorkflowState.ERRORED
            return ErrorState

        done, _ = await asyncio.wait(context.task_to_node.keys(), return_when=asyncio.FIRST_COMPLETED)
        first_failure: str | None = None
        for task in done:
            dag_node = context.task_to_node.pop(task)
            # task.exception() raises on a cancelled task, so check cancellation first.
            if task.cancelled():
                continue
            exc = task.exception()
            if exc is None:
                continue
            node_name = dag_node.node_reference.name
            logger.error("Node '%s' raised an exception during execution", node_name, exc_info=exc)
            if first_failure is None:
                first_failure = f"Node '{node_name}' failed during execution: {exc}"

        if first_failure is None:
            return None
        # Do not treat a failed node as resolved; stop the run instead.
        context.error_message = first_failure
        context.workflow_state = WorkflowState.ERRORED
        return ErrorState

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        """Run one scheduling step: start queued nodes, await a completion, pop finished nodes."""
        if context.paused:
            return None

        canceled_nodes, queued_nodes, leaf_nodes = ExecuteDagState.build_node_states(context)
        # We have no more leaf nodes: every network has drained.
        if not leaf_nodes:
            context.workflow_state = WorkflowState.WORKFLOW_COMPLETE
            return DagCompleteState
        if len(canceled_nodes) == len(leaf_nodes):
            # Every remaining leaf node was canceled.
            context.workflow_state = WorkflowState.CANCELED
            return DagCompleteState

        # Start every queued node; the async semaphore enforces the concurrency limit.
        for node_name in queued_nodes:
            error_state = await ExecuteDagState._start_queued_node(context, node_name)
            if error_state is not None:
                return error_state

        error_state = await ExecuteDagState._wait_for_running_tasks(context, leaf_nodes)
        if error_state is not None:
            return error_state

        # Remove finished nodes and queue whatever they unblocked.
        await ExecuteDagState.pop_done_states(context)
        return None if context.paused else ExecuteDagState
466
+
467
+
468
class ErrorState(State):
    """Failure state: cancel outstanding work, wait for tasks to settle, then complete as ERRORED."""

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        """Log the recorded error, cancel queued nodes, and request cancellation of unfinished tasks."""
        if context.error_message:
            logger.error("DAG execution error: %s", context.error_message)
        for node in context.node_to_reference.values():
            # Cancel all nodes that haven't yet begun processing.
            if node.node_state == NodeState.QUEUED:
                node.node_state = NodeState.CANCELED
        # Request cancellation of every unfinished task; it takes effect at the task's next await.
        for task in list(context.task_to_node):
            if not task.done():
                task.cancel()
        return ErrorState

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        """Drop settled tasks; once none remain, mark the run ERRORED and move to DagCompleteState."""
        task_to_node = context.task_to_node

        pending = [task for task in task_to_node if not task.done()]
        if pending:
            # Give the event loop a chance to run the cancelled tasks instead of spinning.
            await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)

        # Iterate over a snapshot because settled entries are removed as we go.
        for task, node in list(task_to_node.items()):
            if not task.done():
                continue
            # cancelled() implies done(), so it must be distinguished explicitly.
            node.node_state = NodeState.CANCELED if task.cancelled() else NodeState.DONE
            del task_to_node[task]

        if task_to_node:
            # Let's continue going through until everything has settled.
            return ErrorState

        # Finish up. We failed.
        context.workflow_state = WorkflowState.ERRORED
        context.networks.clear()
        context.node_to_reference.clear()
        context.task_to_node.clear()
        return DagCompleteState
513
+
514
+
515
class DagCompleteState(State):
    """Terminal state entered when DAG execution ends (complete, canceled, or errored)."""

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        """Clear the DagBuilder so no nodes linger in node_to_reference; stay in this state."""
        builder = context.dag_builder
        if builder is None:
            return None
        builder.clear()
        return None

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:  # noqa: ARG004
        """Do nothing; this state is terminal."""
        return None
526
+
527
+
528
class ParallelResolutionMachine(FSM[ParallelResolutionContext]):
    """State machine that executes a flow's DAG, running ready nodes concurrently."""

    def __init__(
        self, flow_name: str, max_nodes_in_parallel: int | None = None, dag_builder: DagBuilder | None = None
    ) -> None:
        """Create the machine around a fresh ParallelResolutionContext for ``flow_name``."""
        super().__init__(
            ParallelResolutionContext(flow_name, max_nodes_in_parallel=max_nodes_in_parallel, dag_builder=dag_builder)
        )

    async def resolve_node(self, node: BaseNode | None = None) -> None:  # noqa: ARG002
        """Execute the DAG held by the context's DagBuilder.

        Falls back to the FlowManager's global DagBuilder when none is attached.
        ``node`` is unused.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        ctx = self.context
        if ctx.dag_builder is None:
            ctx.dag_builder = GriptapeNodes.FlowManager().global_dag_builder
        await self.start(ExecuteDagState)

    def change_debug_mode(self, *, debug_mode: bool) -> None:
        """Pause execution when ``debug_mode`` is True; resume scheduling when False."""
        self._context.paused = debug_mode

    def is_complete(self) -> bool:
        """Return True once the machine has reached DagCompleteState."""
        return self._current_state is DagCompleteState

    def is_started(self) -> bool:
        """Return True while the machine has a current state."""
        has_state = self._current_state is not None
        return has_state

    def reset_machine(self, *, cancel: bool = False) -> None:
        """Reset the context (optionally as canceled) and forget the current state."""
        self._context.reset(cancel=cancel)
        self._current_state = None