griptape-nodes 0.53.0__py3-none-any.whl → 0.54.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. griptape_nodes/__init__.py +5 -2
  2. griptape_nodes/app/app.py +4 -26
  3. griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +35 -5
  4. griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +15 -1
  5. griptape_nodes/cli/commands/config.py +4 -1
  6. griptape_nodes/cli/commands/init.py +5 -3
  7. griptape_nodes/cli/commands/libraries.py +14 -8
  8. griptape_nodes/cli/commands/models.py +504 -0
  9. griptape_nodes/cli/commands/self.py +5 -2
  10. griptape_nodes/cli/main.py +11 -1
  11. griptape_nodes/cli/shared.py +0 -9
  12. griptape_nodes/common/directed_graph.py +17 -1
  13. griptape_nodes/drivers/storage/base_storage_driver.py +40 -20
  14. griptape_nodes/drivers/storage/griptape_cloud_storage_driver.py +24 -29
  15. griptape_nodes/drivers/storage/local_storage_driver.py +17 -13
  16. griptape_nodes/exe_types/node_types.py +219 -14
  17. griptape_nodes/exe_types/param_components/__init__.py +1 -0
  18. griptape_nodes/exe_types/param_components/execution_status_component.py +138 -0
  19. griptape_nodes/machines/control_flow.py +129 -92
  20. griptape_nodes/machines/dag_builder.py +207 -0
  21. griptape_nodes/machines/parallel_resolution.py +264 -276
  22. griptape_nodes/machines/sequential_resolution.py +9 -7
  23. griptape_nodes/node_library/library_registry.py +34 -1
  24. griptape_nodes/retained_mode/events/app_events.py +5 -1
  25. griptape_nodes/retained_mode/events/base_events.py +7 -7
  26. griptape_nodes/retained_mode/events/config_events.py +30 -0
  27. griptape_nodes/retained_mode/events/execution_events.py +2 -2
  28. griptape_nodes/retained_mode/events/model_events.py +296 -0
  29. griptape_nodes/retained_mode/griptape_nodes.py +10 -1
  30. griptape_nodes/retained_mode/managers/agent_manager.py +14 -0
  31. griptape_nodes/retained_mode/managers/config_manager.py +44 -3
  32. griptape_nodes/retained_mode/managers/event_manager.py +8 -2
  33. griptape_nodes/retained_mode/managers/flow_manager.py +45 -14
  34. griptape_nodes/retained_mode/managers/library_manager.py +3 -3
  35. griptape_nodes/retained_mode/managers/model_manager.py +1107 -0
  36. griptape_nodes/retained_mode/managers/node_manager.py +26 -26
  37. griptape_nodes/retained_mode/managers/object_manager.py +1 -1
  38. griptape_nodes/retained_mode/managers/os_manager.py +6 -6
  39. griptape_nodes/retained_mode/managers/settings.py +87 -9
  40. griptape_nodes/retained_mode/managers/static_files_manager.py +77 -9
  41. griptape_nodes/retained_mode/managers/sync_manager.py +10 -5
  42. griptape_nodes/retained_mode/managers/workflow_manager.py +101 -92
  43. griptape_nodes/retained_mode/retained_mode.py +19 -0
  44. griptape_nodes/servers/__init__.py +1 -0
  45. griptape_nodes/{mcp_server/server.py → servers/mcp.py} +1 -1
  46. griptape_nodes/{app/api.py → servers/static.py} +43 -40
  47. griptape_nodes/traits/button.py +124 -6
  48. griptape_nodes/traits/multi_options.py +188 -0
  49. griptape_nodes/traits/numbers_selector.py +77 -0
  50. griptape_nodes/traits/options.py +93 -2
  51. griptape_nodes/utils/async_utils.py +31 -0
  52. {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.1.dist-info}/METADATA +3 -1
  53. {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.1.dist-info}/RECORD +56 -47
  54. {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.1.dist-info}/WHEEL +1 -1
  55. griptape_nodes/{mcp_server → servers}/ws_request_manager.py +0 -0
  56. {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.1.dist-info}/entry_points.txt +0 -0
@@ -2,14 +2,13 @@ from __future__ import annotations
2
2
 
3
3
  import asyncio
4
4
  import logging
5
- from dataclasses import dataclass, field
6
5
  from enum import StrEnum
7
- from typing import Any
6
+ from typing import TYPE_CHECKING
8
7
 
9
- from griptape_nodes.common.directed_graph import DirectedGraph
10
- from griptape_nodes.exe_types.core_types import ParameterTypeBuiltin
8
+ from griptape_nodes.exe_types.core_types import Parameter, ParameterType, ParameterTypeBuiltin
11
9
  from griptape_nodes.exe_types.node_types import BaseNode, NodeResolutionState
12
10
  from griptape_nodes.exe_types.type_validator import TypeValidator
11
+ from griptape_nodes.machines.dag_builder import NodeState
13
12
  from griptape_nodes.machines.fsm import FSM, State
14
13
  from griptape_nodes.node_library.library_registry import LibraryRegistry
15
14
  from griptape_nodes.retained_mode.events.base_events import (
@@ -19,38 +18,16 @@ from griptape_nodes.retained_mode.events.base_events import (
19
18
  from griptape_nodes.retained_mode.events.execution_events import (
20
19
  CurrentDataNodeEvent,
21
20
  NodeResolvedEvent,
22
- ParameterSpotlightEvent,
23
21
  ParameterValueUpdateEvent,
24
22
  )
25
23
  from griptape_nodes.retained_mode.events.parameter_events import SetParameterValueRequest
26
24
 
27
- logger = logging.getLogger("griptape_nodes")
28
-
29
-
30
- class NodeState(StrEnum):
31
- """Individual node execution states."""
32
-
33
- QUEUED = "queued"
34
- PROCESSING = "processing"
35
- DONE = "done"
36
- CANCELED = "canceled"
37
- ERRORED = "errored"
38
- WAITING = "waiting"
39
-
40
-
41
- @dataclass(kw_only=True)
42
- class DagNode:
43
- """Represents a node in the DAG with runtime references."""
44
-
45
- task_reference: asyncio.Task | None = field(default=None)
46
- node_state: NodeState = field(default=NodeState.WAITING)
47
- node_reference: BaseNode
25
+ if TYPE_CHECKING:
26
+ from griptape_nodes.common.directed_graph import DirectedGraph
27
+ from griptape_nodes.machines.dag_builder import DagBuilder, DagNode
28
+ from griptape_nodes.retained_mode.managers.flow_manager import FlowManager
48
29
 
49
-
50
- @dataclass
51
- class Focus:
52
- node: BaseNode
53
- scheduled_value: Any | None = None
30
+ logger = logging.getLogger("griptape_nodes")
54
31
 
55
32
 
56
33
  class WorkflowState(StrEnum):
@@ -63,175 +40,77 @@ class WorkflowState(StrEnum):
63
40
 
64
41
 
65
42
  class ParallelResolutionContext:
66
- focus_stack: list[Focus]
67
43
  paused: bool
68
44
  flow_name: str
69
- build_only: bool
70
- batched_nodes: list[BaseNode]
71
45
  error_message: str | None
72
46
  workflow_state: WorkflowState
73
- # DAG fields moved from DagOrchestrator
74
- network: DirectedGraph
75
- node_to_reference: dict[str, DagNode]
47
+ # Execution fields
76
48
  async_semaphore: asyncio.Semaphore
77
49
  task_to_node: dict[asyncio.Task, DagNode]
50
+ dag_builder: DagBuilder | None
78
51
 
79
- def __init__(self, flow_name: str, max_nodes_in_parallel: int | None = None) -> None:
52
+ def __init__(
53
+ self, flow_name: str, max_nodes_in_parallel: int | None = None, dag_builder: DagBuilder | None = None
54
+ ) -> None:
80
55
  self.flow_name = flow_name
81
- self.focus_stack = []
82
56
  self.paused = False
83
- self.build_only = False
84
- self.batched_nodes = []
85
57
  self.error_message = None
86
58
  self.workflow_state = WorkflowState.NO_ERROR
59
+ self.dag_builder = dag_builder
87
60
 
88
- # Initialize DAG fields
89
- self.network = DirectedGraph()
90
- self.node_to_reference = {}
61
+ # Initialize execution fields
91
62
  max_nodes_in_parallel = max_nodes_in_parallel if max_nodes_in_parallel is not None else 5
92
63
  self.async_semaphore = asyncio.Semaphore(max_nodes_in_parallel)
93
64
  self.task_to_node = {}
94
65
 
66
+ @property
67
+ def node_to_reference(self) -> dict[str, DagNode]:
68
+ """Get node_to_reference from dag_builder if available."""
69
+ if not self.dag_builder:
70
+ msg = "DagBuilder is not initialized"
71
+ raise ValueError(msg)
72
+ return self.dag_builder.node_to_reference
73
+
74
+ @property
75
+ def networks(self) -> dict[str, DirectedGraph]:
76
+ """Get node_to_reference from dag_builder if available."""
77
+ if not self.dag_builder:
78
+ msg = "DagBuilder is not initialized"
79
+ raise ValueError(msg)
80
+ return self.dag_builder.graphs
81
+
95
82
  def reset(self, *, cancel: bool = False) -> None:
96
- if self.focus_stack:
97
- node = self.focus_stack[-1].node
98
- node.clear_node()
99
- self.focus_stack.clear()
100
83
  self.paused = False
101
84
  if cancel:
102
85
  self.workflow_state = WorkflowState.CANCELED
103
- for node in self.node_to_reference.values():
104
- node.node_state = NodeState.CANCELED
86
+ # Only access node_to_reference if dag_builder exists
87
+ if self.dag_builder:
88
+ for node in self.node_to_reference.values():
89
+ node.node_state = NodeState.CANCELED
105
90
  else:
106
91
  self.workflow_state = WorkflowState.NO_ERROR
107
92
  self.error_message = None
108
- self.network.clear()
109
- self.node_to_reference.clear()
110
93
  self.task_to_node.clear()
111
94
 
95
+ # Clear DAG builder state to allow re-adding nodes on subsequent runs
96
+ if self.dag_builder:
97
+ self.dag_builder.clear()
112
98
 
113
- class InitializeDagSpotlightState(State):
114
- @staticmethod
115
- async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
116
- from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
117
-
118
- current_node = context.focus_stack[-1].node
119
- GriptapeNodes.EventManager().put_event(
120
- ExecutionGriptapeNodeEvent(
121
- wrapped_event=ExecutionEvent(payload=CurrentDataNodeEvent(node_name=current_node.name))
122
- )
123
- )
124
- if not context.paused:
125
- return InitializeDagSpotlightState
126
- return None
127
99
 
100
+ class ExecuteDagState(State):
128
101
  @staticmethod
129
- async def on_update(context: ParallelResolutionContext) -> type[State] | None:
130
- if not len(context.focus_stack):
131
- return DagCompleteState
132
- from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
133
-
134
- current_node = context.focus_stack[-1].node
135
- if current_node.state == NodeResolutionState.UNRESOLVED:
136
- GriptapeNodes.FlowManager().get_connections().unresolve_future_nodes(current_node)
137
- current_node.initialize_spotlight()
138
- current_node.state = NodeResolutionState.RESOLVING
139
- if current_node.get_current_parameter() is None:
140
- if current_node.advance_parameter():
141
- return EvaluateDagParameterState
142
- return BuildDagNodeState
143
- return EvaluateDagParameterState
144
-
145
-
146
- class EvaluateDagParameterState(State):
147
- @staticmethod
148
- async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
149
- current_node = context.focus_stack[-1].node
150
- current_parameter = current_node.get_current_parameter()
151
- if current_parameter is None:
152
- return BuildDagNodeState
153
- from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
102
+ async def handle_done_nodes(context: ParallelResolutionContext, done_node: DagNode, network_name: str) -> None:
103
+ current_node = done_node.node_reference
154
104
 
155
- GriptapeNodes.EventManager().put_event(
156
- ExecutionGriptapeNodeEvent(
157
- wrapped_event=ExecutionEvent(
158
- payload=ParameterSpotlightEvent(
159
- node_name=current_node.name,
160
- parameter_name=current_parameter.name,
161
- )
162
- )
105
+ # Check if node was already resolved (shouldn't happen)
106
+ if current_node.state == NodeResolutionState.RESOLVED:
107
+ logger.error(
108
+ "DUPLICATE COMPLETION DETECTED: Node '%s' was already RESOLVED but handle_done_nodes was called again from network '%s'. This should not happen!",
109
+ current_node.name,
110
+ network_name,
163
111
  )
164
- )
165
- if not context.paused:
166
- return EvaluateDagParameterState
167
- return None
168
-
169
- @staticmethod
170
- async def on_update(context: ParallelResolutionContext) -> type[State] | None:
171
- from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
172
-
173
- current_node = context.focus_stack[-1].node
174
- current_parameter = current_node.get_current_parameter()
175
- connections = GriptapeNodes.FlowManager().get_connections()
176
- if current_parameter is None:
177
- msg = "No current parameter set."
178
- raise ValueError(msg)
179
- next_node = connections.get_connected_node(current_node, current_parameter)
180
- if next_node:
181
- next_node, _ = next_node
182
- if next_node:
183
- if next_node.state == NodeResolutionState.UNRESOLVED:
184
- focus_stack_names = {focus.node.name for focus in context.focus_stack}
185
- if next_node.name in focus_stack_names:
186
- msg = f"Cycle detected between node '{current_node.name}' and '{next_node.name}'."
187
- raise RuntimeError(msg)
188
- context.network.add_edge(next_node.name, current_node.name)
189
- context.focus_stack.append(Focus(node=next_node))
190
- return InitializeDagSpotlightState
191
- if next_node.state == NodeResolutionState.RESOLVED and next_node in context.batched_nodes:
192
- context.network.add_edge(next_node.name, current_node.name)
193
- if current_node.advance_parameter():
194
- return InitializeDagSpotlightState
195
- return BuildDagNodeState
196
-
197
-
198
- class BuildDagNodeState(State):
199
- @staticmethod
200
- async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
201
- current_node = context.focus_stack[-1].node
202
-
203
- # Add the current node to the DAG
204
- node_reference = DagNode(node_reference=current_node)
205
- context.node_to_reference[current_node.name] = node_reference
206
- # Add node name to DAG (has to be a hashable value)
207
- context.network.add_node(node_for_adding=current_node.name)
208
-
209
- if not context.paused:
210
- return BuildDagNodeState
211
- return None
212
-
213
- @staticmethod
214
- async def on_update(context: ParallelResolutionContext) -> type[State] | None:
215
- current_node = context.focus_stack[-1].node
112
+ return
216
113
 
217
- # Mark node as resolved for DAG building purposes
218
- current_node.state = NodeResolutionState.RESOLVED
219
- # Add to batched nodes
220
- context.batched_nodes.append(current_node)
221
-
222
- context.focus_stack.pop()
223
- if len(context.focus_stack):
224
- return EvaluateDagParameterState
225
-
226
- if context.build_only:
227
- return DagCompleteState
228
- return ExecuteDagState
229
-
230
-
231
- class ExecuteDagState(State):
232
- @staticmethod
233
- def handle_done_nodes(done_node: DagNode) -> None:
234
- current_node = done_node.node_reference
235
114
  # Publish all parameter updates.
236
115
  current_node.state = NodeResolutionState.RESOLVED
237
116
  # Serialization can be slow so only do it if the user wants debug details.
@@ -252,7 +131,7 @@ class ExecuteDagState(State):
252
131
  data_type = ParameterTypeBuiltin.NONE.value
253
132
  from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
254
133
 
255
- GriptapeNodes.EventManager().put_event(
134
+ await GriptapeNodes.EventManager().aput_event(
256
135
  ExecutionGriptapeNodeEvent(
257
136
  wrapped_event=ExecutionEvent(
258
137
  payload=ParameterValueUpdateEvent(
@@ -272,7 +151,7 @@ class ExecuteDagState(State):
272
151
  library_name = None
273
152
  from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
274
153
 
275
- GriptapeNodes.EventManager().put_event(
154
+ await GriptapeNodes.EventManager().aput_event(
276
155
  ExecutionGriptapeNodeEvent(
277
156
  wrapped_event=ExecutionEvent(
278
157
  payload=NodeResolvedEvent(
@@ -284,9 +163,104 @@ class ExecuteDagState(State):
284
163
  )
285
164
  )
286
165
  )
166
+ # Now the final thing to do, is to take their directed graph and update it.
167
+ ExecuteDagState.get_next_control_graph(context, current_node, network_name)
168
+
169
+ @staticmethod
170
+ def get_next_control_graph(context: ParallelResolutionContext, node: BaseNode, network_name: str) -> None:
171
+ """Get next control flow nodes and add them to the DAG graph."""
172
+ from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
173
+
174
+ flow_manager = GriptapeNodes.FlowManager()
175
+
176
+ # Early returns for various conditions
177
+ if ExecuteDagState._should_skip_control_flow(context, node, network_name, flow_manager):
178
+ return
179
+
180
+ next_output = node.get_next_control_output()
181
+ if next_output is not None:
182
+ ExecuteDagState._process_next_control_node(context, node, next_output, network_name, flow_manager)
183
+
184
+ @staticmethod
185
+ def _should_skip_control_flow(
186
+ context: ParallelResolutionContext, node: BaseNode, network_name: str, flow_manager: FlowManager
187
+ ) -> bool:
188
+ """Check if control flow processing should be skipped."""
189
+ if flow_manager.global_single_node_resolution:
190
+ return True
191
+
192
+ if context.dag_builder is not None:
193
+ network = context.dag_builder.graphs.get(network_name, None)
194
+ if network is not None and len(network) > 0:
195
+ return True
196
+
197
+ if node.stop_flow:
198
+ node.stop_flow = False
199
+ return True
200
+
201
+ return False
202
+
203
+ @staticmethod
204
+ def _process_next_control_node(
205
+ context: ParallelResolutionContext,
206
+ node: BaseNode,
207
+ next_output: Parameter,
208
+ network_name: str,
209
+ flow_manager: FlowManager,
210
+ ) -> None:
211
+ """Process the next control node in the flow."""
212
+ node_connection = flow_manager.get_connections().get_connected_node(node, next_output)
213
+ if node_connection is not None:
214
+ next_node, _ = node_connection
215
+
216
+ # Prepare next node for execution
217
+ if not next_node.lock:
218
+ next_node.make_node_unresolved(
219
+ current_states_to_trigger_change_event=set(
220
+ {
221
+ NodeResolutionState.UNRESOLVED,
222
+ NodeResolutionState.RESOLVED,
223
+ NodeResolutionState.RESOLVING,
224
+ }
225
+ )
226
+ )
227
+
228
+ ExecuteDagState._add_and_queue_nodes(context, next_node, network_name)
229
+
230
+ @staticmethod
231
+ def _add_and_queue_nodes(context: ParallelResolutionContext, next_node: BaseNode, network_name: str) -> None:
232
+ """Add nodes to DAG and queue them if ready."""
233
+ if context.dag_builder is not None:
234
+ added_nodes = context.dag_builder.add_node_with_dependencies(next_node, network_name)
235
+ if next_node not in added_nodes:
236
+ added_nodes.append(next_node)
237
+
238
+ # Queue nodes that are ready for execution
239
+ if added_nodes:
240
+ for added_node in added_nodes:
241
+ ExecuteDagState._try_queue_waiting_node(context, added_node.name)
287
242
 
288
243
  @staticmethod
289
- def collect_values_from_upstream_nodes(node_reference: DagNode) -> None:
244
+ def _try_queue_waiting_node(context: ParallelResolutionContext, node_name: str) -> None:
245
+ """Try to queue a specific waiting node if it can now be queued."""
246
+ if context.dag_builder is None:
247
+ logger.warning("DAG builder is None - cannot check queueing for node '%s'", node_name)
248
+ return
249
+
250
+ if node_name not in context.node_to_reference:
251
+ logger.warning("Node '%s' not found in node_to_reference - cannot check queueing", node_name)
252
+ return
253
+
254
+ dag_node = context.node_to_reference[node_name]
255
+
256
+ # Only check nodes that are currently waiting
257
+ if dag_node.node_state == NodeState.WAITING:
258
+ can_queue = context.dag_builder.can_queue_control_node(dag_node)
259
+ if can_queue:
260
+ dag_node.node_state = NodeState.QUEUED
261
+
262
+ @staticmethod
263
+ async def collect_values_from_upstream_nodes(node_reference: DagNode) -> None:
290
264
  """Collect output values from resolved upstream nodes and pass them to the current node.
291
265
 
292
266
  This method iterates through all input parameters of the current node, finds their
@@ -318,76 +292,77 @@ class ExecuteDagState(State):
318
292
  output_value = upstream_node.get_parameter_value(upstream_parameter.name)
319
293
 
320
294
  # Pass the value through using the same mechanism as normal resolution
321
- GriptapeNodes.get_instance().handle_request(
322
- SetParameterValueRequest(
323
- parameter_name=parameter.name,
324
- node_name=current_node.name,
325
- value=output_value,
326
- data_type=upstream_parameter.output_type,
327
- incoming_connection_source_node_name=upstream_node.name,
328
- incoming_connection_source_parameter_name=upstream_parameter.name,
295
+ # Skip propagation for Control Parameters as they should not receive values
296
+ if (
297
+ ParameterType.attempt_get_builtin(upstream_parameter.output_type)
298
+ != ParameterTypeBuiltin.CONTROL_TYPE
299
+ ):
300
+ await GriptapeNodes.get_instance().ahandle_request(
301
+ SetParameterValueRequest(
302
+ parameter_name=parameter.name,
303
+ node_name=current_node.name,
304
+ value=output_value,
305
+ data_type=upstream_parameter.output_type,
306
+ incoming_connection_source_node_name=upstream_node.name,
307
+ incoming_connection_source_parameter_name=upstream_parameter.name,
308
+ )
329
309
  )
330
- )
331
-
332
- @staticmethod
333
- def clear_parameter_output_values(node_reference: DagNode) -> None:
334
- """Clear all parameter output values for the given node and publish events.
335
-
336
- This method iterates through each parameter output value stored in the node,
337
- removes it from the node's parameter_output_values dictionary, and publishes an event
338
- to notify the system about the parameter value being set to None.
339
-
340
- Args:
341
- node_reference (DagOrchestrator.DagNode): The DAG node to clear values for.
342
-
343
- Raises:
344
- ValueError: If a parameter name in parameter_output_values doesn't correspond
345
- to an actual parameter in the node.
346
- """
347
- current_node = node_reference.node_reference
348
- for parameter_name in current_node.parameter_output_values:
349
- parameter = current_node.get_parameter_by_name(parameter_name)
350
- if parameter is None:
351
- err = f"Attempted to clear output values for node '{current_node.name}' but could not find parameter '{parameter_name}' that was indicated as having a value."
352
- raise ValueError(err)
353
- parameter_type = parameter.type
354
- if parameter_type is None:
355
- parameter_type = ParameterTypeBuiltin.NONE.value
356
- payload = ParameterValueUpdateEvent(
357
- node_name=current_node.name,
358
- parameter_name=parameter_name,
359
- data_type=parameter_type,
360
- value=None,
361
- )
362
- from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
363
-
364
- GriptapeNodes.EventManager().put_event(
365
- ExecutionGriptapeNodeEvent(wrapped_event=ExecutionEvent(payload=payload))
366
- )
367
- current_node.parameter_output_values.clear()
368
310
 
369
311
  @staticmethod
370
- def build_node_states(context: ParallelResolutionContext) -> tuple[list[str], list[str], list[str], list[str]]:
371
- network = context.network
372
- leaf_nodes = [n for n in network.nodes() if network.in_degree(n) == 0]
373
- done_nodes = []
374
- canceled_nodes = []
375
- queued_nodes = []
312
+ def build_node_states(context: ParallelResolutionContext) -> tuple[set[str], set[str], set[str]]:
313
+ networks = context.networks
314
+ leaf_nodes = set()
315
+ for network in networks.values():
316
+ # Check and see if there are leaf nodes that are cancelled.
317
+ # Reinitialize leaf nodes since maybe we changed things up.
318
+ # We removed nodes from the network. There may be new leaf nodes.
319
+ # Add all leaf nodes from all networks (using set union to avoid duplicates)
320
+ leaf_nodes.update([n for n in network.nodes() if network.in_degree(n) == 0])
321
+ canceled_nodes = set()
322
+ queued_nodes = set()
376
323
  for node in leaf_nodes:
377
324
  node_reference = context.node_to_reference[node]
378
325
  # If the node is locked, mark it as done so it skips execution
379
326
  if node_reference.node_reference.lock:
380
327
  node_reference.node_state = NodeState.DONE
381
- done_nodes.append(node)
382
328
  continue
383
329
  node_state = node_reference.node_state
384
- if node_state == NodeState.DONE:
385
- done_nodes.append(node)
386
- elif node_state == NodeState.CANCELED:
387
- canceled_nodes.append(node)
330
+ if node_state == NodeState.CANCELED:
331
+ canceled_nodes.add(node)
388
332
  elif node_state == NodeState.QUEUED:
389
- queued_nodes.append(node)
390
- return done_nodes, canceled_nodes, queued_nodes, leaf_nodes
333
+ queued_nodes.add(node)
334
+ return canceled_nodes, queued_nodes, leaf_nodes
335
+
336
+ @staticmethod
337
+ async def pop_done_states(context: ParallelResolutionContext) -> None:
338
+ networks = context.networks
339
+ handled_nodes = set() # Track nodes we've already processed to avoid duplicates
340
+
341
+ for network_name, network in networks.items():
342
+ # Check and see if there are leaf nodes that are cancelled.
343
+ # Reinitialize leaf nodes since maybe we changed things up.
344
+ # We removed nodes from the network. There may be new leaf nodes.
345
+ leaf_nodes = [n for n in network.nodes() if network.in_degree(n) == 0]
346
+ for node in leaf_nodes:
347
+ node_reference = context.node_to_reference[node]
348
+ node_state = node_reference.node_state
349
+ # If the node is locked, mark it as done so it skips execution
350
+ if node_reference.node_reference.lock or node_state == NodeState.DONE:
351
+ node_reference.node_state = NodeState.DONE
352
+ network.remove_node(node)
353
+
354
+ # Only call handle_done_nodes once per node (first network that processes it)
355
+ if node not in handled_nodes:
356
+ handled_nodes.add(node)
357
+ await ExecuteDagState.handle_done_nodes(context, context.node_to_reference[node], network_name)
358
+
359
+ # After processing completions in this network, check if any remaining leaf nodes can now be queued
360
+ remaining_leaf_nodes = [n for n in network.nodes() if network.in_degree(n) == 0]
361
+
362
+ for leaf_node in remaining_leaf_nodes:
363
+ if leaf_node in context.node_to_reference:
364
+ node_state = context.node_to_reference[leaf_node].node_state
365
+ ExecuteDagState._try_queue_waiting_node(context, leaf_node)
391
366
 
392
367
  @staticmethod
393
368
  async def execute_node(current_node: DagNode, semaphore: asyncio.Semaphore) -> None:
@@ -397,32 +372,28 @@ class ExecuteDagState(State):
397
372
  @staticmethod
398
373
  async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
399
374
  # Start DAG execution after resolution is complete
400
- context.batched_nodes.clear()
401
375
  for node in context.node_to_reference.values():
402
- # We have a DAG. Flag all nodes in DAG as queued. Workflow state is NO_ERROR
403
- node.node_state = NodeState.QUEUED
376
+ # Only queue nodes that are waiting - preserve state of already processed nodes.
377
+ if node.node_state == NodeState.WAITING:
378
+ node.node_state = NodeState.QUEUED
379
+
404
380
  context.workflow_state = WorkflowState.NO_ERROR
381
+
405
382
  if not context.paused:
406
383
  return ExecuteDagState
407
384
  return None
408
385
 
409
386
  @staticmethod
410
- async def on_update(context: ParallelResolutionContext) -> type[State] | None:
387
+ async def on_update(context: ParallelResolutionContext) -> type[State] | None: # noqa: C901, PLR0911
388
+ # Check if execution is paused
389
+ if context.paused:
390
+ return None
391
+
411
392
  # Check if DAG execution is complete
412
- network = context.network
413
393
  # Check and see if there are leaf nodes that are cancelled.
414
- done_nodes, canceled_nodes, queued_nodes, leaf_nodes = ExecuteDagState.build_node_states(context)
415
- # Are there any nodes in Done state?
416
- for node in done_nodes:
417
- # We have nodes in done state.
418
- # Remove the leaf node from the graph.
419
- network.remove_node(node)
420
- # Return thread to thread pool.
421
- ExecuteDagState.handle_done_nodes(context.node_to_reference[node])
422
394
  # Reinitialize leaf nodes since maybe we changed things up.
423
- if len(done_nodes) > 0:
424
- # We removed nodes from the network. There may be new leaf nodes.
425
- done_nodes, canceled_nodes, queued_nodes, leaf_nodes = ExecuteDagState.build_node_states(context)
395
+ # We removed nodes from the network. There may be new leaf nodes.
396
+ canceled_nodes, queued_nodes, leaf_nodes = ExecuteDagState.build_node_states(context)
426
397
  # We have no more leaf nodes. Quit early.
427
398
  if not leaf_nodes:
428
399
  context.workflow_state = WorkflowState.WORKFLOW_COMPLETE
@@ -439,7 +410,7 @@ class ExecuteDagState(State):
439
410
 
440
411
  # Collect parameter values from upstream nodes before executing
441
412
  try:
442
- ExecuteDagState.collect_values_from_upstream_nodes(node_reference)
413
+ await ExecuteDagState.collect_values_from_upstream_nodes(node_reference)
443
414
  except Exception as e:
444
415
  logger.exception("Error collecting parameter values for node '%s'", node_reference.node_reference.name)
445
416
  context.error_message = (
@@ -448,25 +419,25 @@ class ExecuteDagState(State):
448
419
  context.workflow_state = WorkflowState.ERRORED
449
420
  return ErrorState
450
421
 
451
- # Clear parameter output values before execution
452
- try:
453
- ExecuteDagState.clear_parameter_output_values(node_reference)
454
- except Exception as e:
455
- logger.exception(
456
- "Error clearing parameter output values for node '%s'", node_reference.node_reference.name
457
- )
458
- context.error_message = (
459
- f"Parameter clearing failed for node '{node_reference.node_reference.name}': {e}"
460
- )
461
- context.workflow_state = WorkflowState.ERRORED
422
+ # Clear all of the current output values but don't broadcast the clearing.
423
+ # to avoid any flickering in subscribers (UI).
424
+ node_reference.node_reference.parameter_output_values.silent_clear()
425
+ exceptions = node_reference.node_reference.validate_before_node_run()
426
+ if exceptions:
427
+ msg = f"Canceling flow run. Node '{node_reference.node_reference.name}' encountered problems: {exceptions}"
428
+ logger.error(msg)
462
429
  return ErrorState
463
430
 
464
431
  def on_task_done(task: asyncio.Task) -> None:
465
- node = context.task_to_node.pop(task)
466
- node.node_state = NodeState.DONE
467
- logger.info("Task done: %s", node.node_reference.name)
432
+ if task in context.task_to_node:
433
+ node = context.task_to_node[task]
434
+ node.node_state = NodeState.DONE
468
435
 
469
436
  # Execute the node asynchronously
437
+ logger.debug(
438
+ "CREATING EXECUTION TASK for node '%s' - this should only happen once per node!",
439
+ node_reference.node_reference.name,
440
+ )
470
441
  node_task = asyncio.create_task(ExecuteDagState.execute_node(node_reference, context.async_semaphore))
471
442
  # Add a callback to set node to done when task has finished.
472
443
  context.task_to_node[node_task] = node_reference
@@ -474,9 +445,23 @@ class ExecuteDagState(State):
474
445
  node_task.add_done_callback(lambda t: on_task_done(t))
475
446
  node_reference.node_state = NodeState.PROCESSING
476
447
  node_reference.node_reference.state = NodeResolutionState.RESOLVING
448
+
449
+ # Send an event that this is a current data node:
450
+ from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
451
+
452
+ await GriptapeNodes.EventManager().aput_event(
453
+ ExecutionGriptapeNodeEvent(wrapped_event=ExecutionEvent(payload=CurrentDataNodeEvent(node_name=node)))
454
+ )
477
455
  # Wait for a task to finish
478
- await asyncio.wait(context.task_to_node.keys(), return_when=asyncio.FIRST_COMPLETED)
456
+ done, _ = await asyncio.wait(context.task_to_node.keys(), return_when=asyncio.FIRST_COMPLETED)
457
+ # Prevent task being removed before return
458
+ for task in done:
459
+ context.task_to_node.pop(task)
479
460
  # Once a task has finished, loop back to the top.
461
+ await ExecuteDagState.pop_done_states(context)
462
+ # Remove all nodes that are done
463
+ if context.paused:
464
+ return None
480
465
  return ExecuteDagState
481
466
 
482
467
 
@@ -519,7 +504,7 @@ class ErrorState(State):
519
504
  if len(task_to_node) == 0:
520
505
  # Finish up. We failed.
521
506
  context.workflow_state = WorkflowState.ERRORED
522
- context.network.clear()
507
+ context.networks.clear()
523
508
  context.node_to_reference.clear()
524
509
  context.task_to_node.clear()
525
510
  return DagCompleteState
@@ -530,8 +515,9 @@ class ErrorState(State):
530
515
  class DagCompleteState(State):
531
516
  @staticmethod
532
517
  async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
533
- # Set build_only back to False.
534
- context.build_only = False
518
+ # Clear the DAG builder so we don't have any leftover nodes in node_to_reference.
519
+ if context.dag_builder is not None:
520
+ context.dag_builder.clear()
535
521
  return None
536
522
 
537
523
  @staticmethod
@@ -542,19 +528,21 @@ class DagCompleteState(State):
542
528
  class ParallelResolutionMachine(FSM[ParallelResolutionContext]):
543
529
  """State machine for building DAG structure without execution."""
544
530
 
545
- def __init__(self, flow_name: str, max_nodes_in_parallel: int | None = None) -> None:
546
- resolution_context = ParallelResolutionContext(flow_name, max_nodes_in_parallel=max_nodes_in_parallel)
531
+ def __init__(
532
+ self, flow_name: str, max_nodes_in_parallel: int | None = None, dag_builder: DagBuilder | None = None
533
+ ) -> None:
534
+ resolution_context = ParallelResolutionContext(
535
+ flow_name, max_nodes_in_parallel=max_nodes_in_parallel, dag_builder=dag_builder
536
+ )
547
537
  super().__init__(resolution_context)
548
538
 
549
- async def resolve_node(self, node: BaseNode, *, build_only: bool = False) -> None:
550
- """Build DAG structure starting from the given node."""
551
- self._context.focus_stack.append(Focus(node=node))
552
- self._context.build_only = build_only
553
- await self.start(InitializeDagSpotlightState)
539
+ async def resolve_node(self, node: BaseNode | None = None) -> None: # noqa: ARG002
540
+ """Execute the DAG structure using the existing DagBuilder."""
541
+ from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
554
542
 
555
- async def build_dag_for_node(self, node: BaseNode) -> None:
556
- """Build DAG structure starting from the given node. (Deprecated: use resolve_node)."""
557
- await self.resolve_node(node)
543
+ if self.context.dag_builder is None:
544
+ self.context.dag_builder = GriptapeNodes.FlowManager().global_dag_builder
545
+ await self.start(ExecuteDagState)
558
546
 
559
547
  def change_debug_mode(self, *, debug_mode: bool) -> None:
560
548
  self._context.paused = debug_mode