griptape-nodes 0.52.1__py3-none-any.whl → 0.54.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. griptape_nodes/__init__.py +8 -942
  2. griptape_nodes/__main__.py +6 -0
  3. griptape_nodes/app/app.py +48 -86
  4. griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +35 -5
  5. griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +15 -1
  6. griptape_nodes/cli/__init__.py +1 -0
  7. griptape_nodes/cli/commands/__init__.py +1 -0
  8. griptape_nodes/cli/commands/config.py +74 -0
  9. griptape_nodes/cli/commands/engine.py +80 -0
  10. griptape_nodes/cli/commands/init.py +550 -0
  11. griptape_nodes/cli/commands/libraries.py +96 -0
  12. griptape_nodes/cli/commands/models.py +504 -0
  13. griptape_nodes/cli/commands/self.py +120 -0
  14. griptape_nodes/cli/main.py +56 -0
  15. griptape_nodes/cli/shared.py +75 -0
  16. griptape_nodes/common/__init__.py +1 -0
  17. griptape_nodes/common/directed_graph.py +71 -0
  18. griptape_nodes/drivers/storage/base_storage_driver.py +40 -20
  19. griptape_nodes/drivers/storage/griptape_cloud_storage_driver.py +24 -29
  20. griptape_nodes/drivers/storage/local_storage_driver.py +23 -14
  21. griptape_nodes/exe_types/core_types.py +60 -2
  22. griptape_nodes/exe_types/node_types.py +257 -38
  23. griptape_nodes/exe_types/param_components/__init__.py +1 -0
  24. griptape_nodes/exe_types/param_components/execution_status_component.py +138 -0
  25. griptape_nodes/machines/control_flow.py +195 -94
  26. griptape_nodes/machines/dag_builder.py +207 -0
  27. griptape_nodes/machines/fsm.py +10 -1
  28. griptape_nodes/machines/parallel_resolution.py +558 -0
  29. griptape_nodes/machines/{node_resolution.py → sequential_resolution.py} +30 -57
  30. griptape_nodes/node_library/library_registry.py +34 -1
  31. griptape_nodes/retained_mode/events/app_events.py +5 -1
  32. griptape_nodes/retained_mode/events/base_events.py +9 -9
  33. griptape_nodes/retained_mode/events/config_events.py +30 -0
  34. griptape_nodes/retained_mode/events/execution_events.py +2 -2
  35. griptape_nodes/retained_mode/events/model_events.py +296 -0
  36. griptape_nodes/retained_mode/events/node_events.py +4 -3
  37. griptape_nodes/retained_mode/griptape_nodes.py +34 -12
  38. griptape_nodes/retained_mode/managers/agent_manager.py +23 -5
  39. griptape_nodes/retained_mode/managers/arbitrary_code_exec_manager.py +3 -1
  40. griptape_nodes/retained_mode/managers/config_manager.py +44 -3
  41. griptape_nodes/retained_mode/managers/context_manager.py +6 -5
  42. griptape_nodes/retained_mode/managers/event_manager.py +8 -2
  43. griptape_nodes/retained_mode/managers/flow_manager.py +150 -206
  44. griptape_nodes/retained_mode/managers/library_lifecycle/library_directory.py +1 -1
  45. griptape_nodes/retained_mode/managers/library_manager.py +35 -25
  46. griptape_nodes/retained_mode/managers/model_manager.py +1107 -0
  47. griptape_nodes/retained_mode/managers/node_manager.py +102 -220
  48. griptape_nodes/retained_mode/managers/object_manager.py +11 -5
  49. griptape_nodes/retained_mode/managers/os_manager.py +28 -13
  50. griptape_nodes/retained_mode/managers/secrets_manager.py +8 -4
  51. griptape_nodes/retained_mode/managers/settings.py +116 -7
  52. griptape_nodes/retained_mode/managers/static_files_manager.py +85 -12
  53. griptape_nodes/retained_mode/managers/sync_manager.py +17 -9
  54. griptape_nodes/retained_mode/managers/workflow_manager.py +186 -192
  55. griptape_nodes/retained_mode/retained_mode.py +19 -0
  56. griptape_nodes/servers/__init__.py +1 -0
  57. griptape_nodes/{mcp_server/server.py → servers/mcp.py} +1 -1
  58. griptape_nodes/{app/api.py → servers/static.py} +43 -40
  59. griptape_nodes/traits/add_param_button.py +1 -1
  60. griptape_nodes/traits/button.py +334 -6
  61. griptape_nodes/traits/color_picker.py +66 -0
  62. griptape_nodes/traits/multi_options.py +188 -0
  63. griptape_nodes/traits/numbers_selector.py +77 -0
  64. griptape_nodes/traits/options.py +93 -2
  65. griptape_nodes/traits/traits.json +4 -0
  66. griptape_nodes/utils/async_utils.py +31 -0
  67. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/METADATA +4 -1
  68. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/RECORD +71 -48
  69. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/WHEEL +1 -1
  70. /griptape_nodes/{mcp_server → servers}/ws_request_manager.py +0 -0
  71. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/entry_points.txt +0 -0
@@ -9,7 +9,8 @@ from griptape_nodes.exe_types.core_types import Parameter
9
9
  from griptape_nodes.exe_types.node_types import BaseNode, NodeResolutionState
10
10
  from griptape_nodes.exe_types.type_validator import TypeValidator
11
11
  from griptape_nodes.machines.fsm import FSM, State
12
- from griptape_nodes.machines.node_resolution import NodeResolutionMachine
12
+ from griptape_nodes.machines.parallel_resolution import ParallelResolutionMachine
13
+ from griptape_nodes.machines.sequential_resolution import SequentialResolutionMachine
13
14
  from griptape_nodes.retained_mode.events.base_events import ExecutionEvent, ExecutionGriptapeNodeEvent
14
15
  from griptape_nodes.retained_mode.events.execution_events import (
15
16
  ControlFlowResolvedEvent,
@@ -17,6 +18,7 @@ from griptape_nodes.retained_mode.events.execution_events import (
17
18
  SelectedControlOutputEvent,
18
19
  )
19
20
  from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
21
+ from griptape_nodes.retained_mode.managers.settings import WorkflowExecutionMode
20
22
 
21
23
 
22
24
  @dataclass
@@ -37,40 +39,73 @@ logger = logging.getLogger("griptape_nodes")
37
39
# This is the control flow context. Owns the Resolution Machine
class ControlFlowContext:
    """Mutable state shared by the control-flow FSM states for one flow.

    Tracks the node(s) currently driven by control flow and owns the resolution
    machine (sequential or parallel) used to resolve them.
    """

    flow: ControlFlow
    current_nodes: list[BaseNode]
    resolution_machine: ParallelResolutionMachine | SequentialResolutionMachine
    selected_output: Parameter | None
    paused: bool = False
    flow_name: str

    def __init__(
        self,
        flow_name: str,
        max_nodes_in_parallel: int,
        *,
        execution_type: WorkflowExecutionMode = WorkflowExecutionMode.SEQUENTIAL,
    ) -> None:
        """Create the context and pick a resolution machine for ``execution_type``.

        Args:
            flow_name: Name of the flow this context drives.
            max_nodes_in_parallel: Concurrency limit; only used in PARALLEL mode.
            execution_type: PARALLEL selects a ParallelResolutionMachine built on the
                FlowManager's global DAG builder; anything else selects a
                SequentialResolutionMachine.
        """
        self.flow_name = flow_name
        if execution_type == WorkflowExecutionMode.PARALLEL:
            # Get the global DagBuilder from FlowManager
            from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

            dag_builder = GriptapeNodes.FlowManager().global_dag_builder
            self.resolution_machine = ParallelResolutionMachine(
                flow_name, max_nodes_in_parallel, dag_builder=dag_builder
            )
        else:
            self.resolution_machine = SequentialResolutionMachine()
        self.current_nodes = []
        # Initialise eagerly: the attribute is otherwise only created by NextNodeState
        # or reset(), so an earlier read would raise AttributeError.
        self.selected_output = None

    def get_next_nodes(self, output_parameter: Parameter | None = None) -> list[NextNodeInfo]:
        """Get all next nodes reachable via control flow from the current nodes.

        For each current node, follow the control connection leaving
        ``output_parameter`` (or, when it is None, the node's own next control
        output). If no connected node is found at all, fall back to the next node
        from the FlowManager's execution queue, with no entry parameter.

        Returns:
            list[NextNodeInfo]: List of next nodes to process (may be empty).
        """
        next_nodes: list[NextNodeInfo] = []
        for current_node in self.current_nodes:
            control_output = (
                output_parameter if output_parameter is not None else current_node.get_next_control_output()
            )
            if control_output is None:
                continue
            node_connection = GriptapeNodes.FlowManager().get_connections().get_connected_node(
                current_node, control_output
            )
            if node_connection is not None:
                node, entry_parameter = node_connection
                next_nodes.append(NextNodeInfo(node=node, entry_parameter=entry_parameter))

        # If no connections found, check execution queue
        if not next_nodes:
            node = GriptapeNodes.FlowManager().get_next_node_from_execution_queue()
            if node is not None:
                next_nodes.append(NextNodeInfo(node=node, entry_parameter=None))

        return next_nodes

    def reset(self, *, cancel: bool = False) -> None:
        """Clear current nodes, reset the resolution machine, and unpause.

        Args:
            cancel: Forwarded to the resolution machine's ``reset_machine``.
        """
        # Tolerate a None list defensively; normally this is always a list.
        for node in self.current_nodes or []:
            node.clear_node()
        self.current_nodes = []
        self.resolution_machine.reset_machine(cancel=cancel)
        self.selected_output = None
        self.paused = False
76
111
 
@@ -80,24 +115,25 @@ class ResolveNodeState(State):
80
115
  @staticmethod
81
116
  async def on_enter(context: ControlFlowContext) -> type[State] | None:
82
117
  # The state machine has started, but it hasn't began to execute yet.
83
- if context.current_node is None:
118
+ if len(context.current_nodes) == 0:
84
119
  # We don't have anything else to do. Move back to Complete State so it has to restart.
85
120
  return CompleteState
86
121
 
87
- # Mark the node unresolved, and broadcast an event to the GUI.
88
- if not context.current_node.lock:
89
- context.current_node.make_node_unresolved(
90
- current_states_to_trigger_change_event=set(
91
- {NodeResolutionState.UNRESOLVED, NodeResolutionState.RESOLVED, NodeResolutionState.RESOLVING}
122
+ # Mark all current nodes unresolved and broadcast events
123
+ for current_node in context.current_nodes:
124
+ if not current_node.lock:
125
+ current_node.make_node_unresolved(
126
+ current_states_to_trigger_change_event=set(
127
+ {NodeResolutionState.UNRESOLVED, NodeResolutionState.RESOLVED, NodeResolutionState.RESOLVING}
128
+ )
129
+ )
130
+ # Now broadcast that we have a current control node.
131
+ GriptapeNodes.EventManager().put_event(
132
+ ExecutionGriptapeNodeEvent(
133
+ wrapped_event=ExecutionEvent(payload=CurrentControlNodeEvent(node_name=current_node.name))
92
134
  )
93
135
  )
94
- # Now broadcast that we have a current control node.
95
- GriptapeNodes.EventManager().put_event(
96
- ExecutionGriptapeNodeEvent(
97
- wrapped_event=ExecutionEvent(payload=CurrentControlNodeEvent(node_name=context.current_node.name))
98
- )
99
- )
100
- logger.info("Resolving %s", context.current_node.name)
136
+ logger.info("Resolving %s", current_node.name)
101
137
  if not context.paused:
102
138
  # Call the update. Otherwise wait
103
139
  return ResolveNodeState
@@ -106,13 +142,17 @@ class ResolveNodeState(State):
106
142
  # This is necessary to transition to the next step.
107
143
  @staticmethod
108
144
  async def on_update(context: ControlFlowContext) -> type[State] | None:
109
- # If node has not already been resolved!
110
- if context.current_node is None:
145
+ # If no current nodes, we're done
146
+ if len(context.current_nodes) == 0:
111
147
  return CompleteState
112
- if context.current_node.state != NodeResolutionState.RESOLVED:
113
- await context.resolution_machine.resolve_node(context.current_node)
148
+
149
+ # Resolve nodes - pass first node for sequential resolution
150
+ current_node = context.current_nodes[0] if context.current_nodes else None
151
+ await context.resolution_machine.resolve_node(current_node)
114
152
 
115
153
  if context.resolution_machine.is_complete():
154
+ if isinstance(context.resolution_machine, ParallelResolutionMachine):
155
+ return CompleteState
116
156
  return NextNodeState
117
157
  return None
118
158
 
@@ -120,44 +160,49 @@ class ResolveNodeState(State):
120
160
  class NextNodeState(State):
121
161
  @staticmethod
122
162
  async def on_enter(context: ControlFlowContext) -> type[State] | None:
123
- if context.current_node is None:
163
+ if len(context.current_nodes) == 0:
124
164
  return CompleteState
125
- # I did define this on the ControlNode.
126
- if context.current_node.stop_flow:
127
- # We're done here.
128
- context.current_node.stop_flow = False
165
+
166
+ # Check for stop_flow on any current nodes
167
+ for current_node in context.current_nodes[:]:
168
+ if current_node.stop_flow:
169
+ current_node.stop_flow = False
170
+ context.current_nodes.remove(current_node)
171
+
172
+ # If all nodes stopped flow, complete
173
+ if len(context.current_nodes) == 0:
129
174
  return CompleteState
130
- next_output = context.current_node.get_next_control_output()
131
- next_node_info = None
132
175
 
133
- if next_output is not None:
134
- context.selected_output = next_output
135
- next_node_info = context.get_next_node(context.selected_output)
136
- GriptapeNodes.EventManager().put_event(
137
- ExecutionGriptapeNodeEvent(
138
- wrapped_event=ExecutionEvent(
139
- payload=SelectedControlOutputEvent(
140
- node_name=context.current_node.name,
141
- selected_output_parameter_name=next_output.name,
176
+ # Get all next nodes from current nodes
177
+ next_node_infos = context.get_next_nodes()
178
+
179
+ # Broadcast selected control output events for nodes with outputs
180
+ for current_node in context.current_nodes:
181
+ next_output = current_node.get_next_control_output()
182
+ if next_output is not None:
183
+ context.selected_output = next_output
184
+ GriptapeNodes.EventManager().put_event(
185
+ ExecutionGriptapeNodeEvent(
186
+ wrapped_event=ExecutionEvent(
187
+ payload=SelectedControlOutputEvent(
188
+ node_name=current_node.name,
189
+ selected_output_parameter_name=next_output.name,
190
+ )
142
191
  )
143
192
  )
144
193
  )
145
- )
146
- else:
147
- # Get the next node in the execution queue, or None if queue is empty
148
- next_node = GriptapeNodes.FlowManager().get_next_node_from_execution_queue()
149
- if next_node is not None:
150
- next_node_info = NextNodeInfo(node=next_node, entry_parameter=None)
151
-
152
- # The parameter that will be evaluated next
153
- if next_node_info is None:
154
- # If no node attached
194
+
195
+ # If no next nodes, we're complete
196
+ if not next_node_infos:
155
197
  return CompleteState
156
198
 
157
- # Always set the entry control parameter (None for execution queue nodes)
158
- next_node_info.node.set_entry_control_parameter(next_node_info.entry_parameter)
199
+ # Set up next nodes as current nodes
200
+ next_nodes = []
201
+ for next_node_info in next_node_infos:
202
+ next_node_info.node.set_entry_control_parameter(next_node_info.entry_parameter)
203
+ next_nodes.append(next_node_info.node)
159
204
 
160
- context.current_node = next_node_info.node
205
+ context.current_nodes = next_nodes
161
206
  context.selected_output = None
162
207
  if not context.paused:
163
208
  return ResolveNodeState
@@ -171,15 +216,14 @@ class NextNodeState(State):
171
216
  class CompleteState(State):
172
217
  @staticmethod
173
218
  async def on_enter(context: ControlFlowContext) -> type[State] | None:
174
- if context.current_node is not None:
219
+ # Broadcast completion events for any remaining current nodes
220
+ for current_node in context.current_nodes:
175
221
  GriptapeNodes.EventManager().put_event(
176
222
  ExecutionGriptapeNodeEvent(
177
223
  wrapped_event=ExecutionEvent(
178
224
  payload=ControlFlowResolvedEvent(
179
- end_node_name=context.current_node.name,
180
- parameter_output_values=TypeValidator.safe_serialize(
181
- context.current_node.parameter_output_values
182
- ),
225
+ end_node_name=current_node.name,
226
+ parameter_output_values=TypeValidator.safe_serialize(current_node.parameter_output_values),
183
227
  )
184
228
  )
185
229
  )
@@ -194,14 +238,24 @@ class CompleteState(State):
194
238
 
195
239
  # MACHINE TIME!!!
196
240
  class ControlFlowMachine(FSM[ControlFlowContext]):
197
- def __init__(self) -> None:
198
- context = ControlFlowContext()
241
+ def __init__(self, flow_name: str) -> None:
242
+ execution_type = GriptapeNodes.ConfigManager().get_config_value(
243
+ "workflow_execution_mode", default=WorkflowExecutionMode.SEQUENTIAL
244
+ )
245
+ max_nodes_in_parallel = GriptapeNodes.ConfigManager().get_config_value("max_nodes_in_parallel", default=5)
246
+ context = ControlFlowContext(flow_name, max_nodes_in_parallel, execution_type=execution_type)
199
247
  super().__init__(context)
200
248
 
201
249
  async def start_flow(self, start_node: BaseNode, debug_mode: bool = False) -> None: # noqa: FBT001, FBT002
202
- self._context.current_node = start_node
250
+ # If using DAG resolution, process data_nodes from queue first
251
+ if isinstance(self._context.resolution_machine, ParallelResolutionMachine):
252
+ current_nodes = await self._process_nodes_for_dag(start_node)
253
+ else:
254
+ current_nodes = [start_node]
255
+ self._context.current_nodes = current_nodes
203
256
  # Set entry control parameter for initial node (None for workflow start)
204
- start_node.set_entry_control_parameter(None)
257
+ for node in current_nodes:
258
+ node.set_entry_control_parameter(None)
205
259
  # Set up to debug
206
260
  self._context.paused = debug_mode
207
261
  await self.start(ResolveNodeState) # Begins the flow
@@ -214,31 +268,78 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
214
268
 
215
269
  def change_debug_mode(self, debug_mode: bool) -> None: # noqa: FBT001
216
270
  self._context.paused = debug_mode
217
- self._context.resolution_machine.change_debug_mode(debug_mode)
271
+ self._context.resolution_machine.change_debug_mode(debug_mode=debug_mode)
218
272
 
219
273
  async def granular_step(self, change_debug_mode: bool) -> None: # noqa: FBT001
220
274
  resolution_machine = self._context.resolution_machine
275
+
221
276
  if change_debug_mode:
222
- resolution_machine.change_debug_mode(True)
277
+ resolution_machine.change_debug_mode(debug_mode=True)
223
278
  await resolution_machine.update()
224
279
 
225
- # Tick the control flow if the resolution machine inside it isn't busy.
226
- if resolution_machine.is_complete() or not resolution_machine.is_started(): # noqa: SIM102
280
+ # Tick the control flow if the current machine isn't busy
281
+ if self._current_state is ResolveNodeState and ( # noqa: SIM102
282
+ resolution_machine.is_complete() or not resolution_machine.is_started()
283
+ ):
227
284
  # Don't tick ourselves if we are already complete.
228
285
  if self._current_state is not None:
229
286
  await self.update()
230
287
 
231
288
  async def node_step(self) -> None:
232
289
  resolution_machine = self._context.resolution_machine
233
- resolution_machine.change_debug_mode(False)
234
- await resolution_machine.update()
235
290
 
236
- # Tick the control flow if the resolution machine inside it isn't busy.
237
- if resolution_machine.is_complete() or not resolution_machine.is_started(): # noqa: SIM102
238
- # Don't tick ourselves if we are already complete.
239
- if self._current_state is not None:
240
- await self.update()
291
+ resolution_machine.change_debug_mode(debug_mode=False)
292
+
293
+ # If we're in the resolution phase, step the resolution machine
294
+ if self._current_state is ResolveNodeState:
295
+ await resolution_machine.update()
241
296
 
242
- def reset_machine(self) -> None:
243
- self._context.reset()
297
+ # Tick the control flow if the current machine isn't busy
298
+ if self._current_state is ResolveNodeState and (
299
+ resolution_machine.is_complete() or not resolution_machine.is_started()
300
+ ):
301
+ await self.update()
302
+
303
+ async def _process_nodes_for_dag(self, start_node: BaseNode) -> list[BaseNode]:
304
+ """Process data_nodes from the global queue to build unified DAG.
305
+
306
+ This method identifies data_nodes in the execution queue and processes
307
+ their dependencies into the DAG resolution machine.
308
+ """
309
+ if not isinstance(self._context.resolution_machine, ParallelResolutionMachine):
310
+ return []
311
+ # Get the global flow queue
312
+ flow_manager = GriptapeNodes.FlowManager()
313
+ dag_builder = flow_manager.global_dag_builder
314
+ if dag_builder is None:
315
+ msg = "DAG builder is not initialized."
316
+ raise ValueError(msg)
317
+ # Build with the first node:
318
+ dag_builder.add_node_with_dependencies(start_node, start_node.name)
319
+ queue_items = list(flow_manager.global_flow_queue.queue)
320
+ start_nodes = [start_node]
321
+ # Find data_nodes and remove them from queue
322
+ for item in queue_items:
323
+ from griptape_nodes.retained_mode.managers.flow_manager import DagExecutionType
324
+
325
+ if item.dag_execution_type in (DagExecutionType.CONTROL_NODE, DagExecutionType.START_NODE):
326
+ node = item.node
327
+ node.state = NodeResolutionState.UNRESOLVED
328
+ dag_builder.add_node_with_dependencies(node, node.name)
329
+ flow_manager.global_flow_queue.queue.remove(item)
330
+ start_nodes.append(node)
331
+ elif item.dag_execution_type == DagExecutionType.DATA_NODE:
332
+ node = item.node
333
+ node.state = NodeResolutionState.UNRESOLVED
334
+ # Build here.
335
+ dag_builder.add_node_with_dependencies(node, node.name)
336
+ flow_manager.global_flow_queue.queue.remove(item)
337
+ return start_nodes
338
+
339
+ def reset_machine(self, *, cancel: bool = False) -> None:
340
+ self._context.reset(cancel=cancel)
244
341
  self._current_state = None
342
+
343
+ @property
344
+ def resolution_machine(self) -> ParallelResolutionMachine | SequentialResolutionMachine:
345
+ return self._context.resolution_machine
@@ -0,0 +1,207 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from dataclasses import dataclass, field
5
+ from enum import StrEnum
6
+ from typing import TYPE_CHECKING
7
+
8
+ from griptape_nodes.common.directed_graph import DirectedGraph
9
+ from griptape_nodes.exe_types.core_types import ParameterTypeBuiltin
10
+ from griptape_nodes.exe_types.node_types import NodeResolutionState
11
+
12
+ if TYPE_CHECKING:
13
+ import asyncio
14
+
15
+ from griptape_nodes.exe_types.connections import Connections
16
+ from griptape_nodes.exe_types.node_types import BaseNode
17
+
18
+ logger = logging.getLogger("griptape_nodes")
19
+
20
+
21
+ class NodeState(StrEnum):
22
+ """Individual node execution states."""
23
+
24
+ QUEUED = "queued"
25
+ PROCESSING = "processing"
26
+ DONE = "done"
27
+ CANCELED = "canceled"
28
+ ERRORED = "errored"
29
+ WAITING = "waiting"
30
+
31
+
32
@dataclass(kw_only=True)
class DagNode:
    """A node registered in the DAG together with its runtime bookkeeping.

    All fields are keyword-only. This is what allows the required
    ``node_reference`` to be declared after fields that have defaults.
    """

    # Presumably the asyncio task executing this node once scheduled (set elsewhere).
    task_reference: asyncio.Task | None = None
    # Execution state within the DAG; entries start out WAITING.
    node_state: NodeState = NodeState.WAITING
    # The underlying node this DAG entry wraps (required).
    node_reference: BaseNode
39
+
40
+
41
class DagBuilder:
    """Handles DAG construction independently of the execution state machine.

    Nodes are tracked globally in ``node_to_reference`` (one DagNode per node
    name). Dependency edges are recorded per named graph in ``graphs``. Edges point
    from an upstream (data-providing) node to the node that depends on it.
    """

    # Keyed by the name of the start node the graph was built for.
    graphs: dict[str, DirectedGraph]
    # Node name -> DagNode, shared across all graphs.
    node_to_reference: dict[str, DagNode]

    def __init__(self) -> None:
        self.graphs = {}
        self.node_to_reference = {}

    def _get_or_create_graph(self, graph_name: str) -> DirectedGraph:
        """Return the graph named ``graph_name``, creating an empty one if it does not exist."""
        graph = self.graphs.get(graph_name)
        if graph is None:
            graph = DirectedGraph()
            self.graphs[graph_name] = graph
        return graph

    def add_node_with_dependencies(self, node: BaseNode, graph_name: str = "default") -> list[BaseNode]:
        """Add ``node`` and its unresolved data dependencies to the DAG.

        Dependencies are walked depth-first through non-control parameters.
        - Upstream nodes already in state RESOLVED are not added.
        - Nodes already present in the DAG (in any graph) are not re-added, but the
          dependency edge to them is still recorded in this graph.
        - Nodes exposing an ``ignore_dependencies`` attribute are added without
          walking their data inputs.

        Args:
            node: The node to add.
            graph_name: Graph to record nodes and edges in; created if missing.

        Returns:
            The nodes newly added to the DAG, each added after its dependencies.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        connections = GriptapeNodes.FlowManager().get_connections()
        added_nodes: list[BaseNode] = []
        graph = self._get_or_create_graph(graph_name)

        def _add_node_recursive(current_node: BaseNode, visited: set[str], graph: DirectedGraph) -> None:
            if current_node.name in visited:
                return
            visited.add(current_node.name)

            # Skip if already in DAG (use DAG membership, not resolved state).
            if current_node.name in self.node_to_reference:
                return

            # Specifically for output_selector: overriding 'initialize_spotlight' doesn't work anymore,
            # so such nodes opt out of data dependencies via an `ignore_dependencies` attribute.
            if not hasattr(current_node, "ignore_dependencies"):
                for param in current_node.parameters:
                    if param.type == ParameterTypeBuiltin.CONTROL_TYPE:
                        continue
                    upstream_connection = connections.get_connected_node(current_node, param)
                    if not upstream_connection:
                        continue
                    upstream_node, _ = upstream_connection
                    # Don't add nodes that have already been resolved.
                    if upstream_node.state == NodeResolutionState.RESOLVED:
                        continue
                    # Ensure the upstream node is in the DAG first (a no-op if it already is,
                    # possibly in another graph), then record the dependency edge in this graph.
                    _add_node_recursive(upstream_node, visited, graph)
                    graph.add_edge(upstream_node.name, current_node.name)

            # Add current node to DAG but keep its original resolution state;
            # marking it resolved happens during actual execution.
            self.node_to_reference[current_node.name] = DagNode(
                node_reference=current_node, node_state=NodeState.WAITING
            )
            graph.add_node(node_for_adding=current_node.name)
            added_nodes.append(current_node)

        _add_node_recursive(node, set(), graph)
        return added_nodes

    def add_node(self, node: BaseNode, graph_name: str = "default") -> DagNode:
        """Add just one node to the DAG without dependencies (assumes they already exist).

        Returns:
            The existing DagNode if ``node`` is already registered, otherwise the new one.
        """
        existing = self.node_to_reference.get(node.name)
        if existing is not None:
            return existing

        dag_node = DagNode(node_reference=node, node_state=NodeState.WAITING)
        self.node_to_reference[node.name] = dag_node
        self._get_or_create_graph(graph_name).add_node(node_for_adding=node.name)
        return dag_node

    def clear(self) -> None:
        """Clear all graphs and node references from the DAG builder."""
        self.graphs.clear()
        self.node_to_reference.clear()

    def can_queue_control_node(self, node: DagNode) -> bool:
        """Return whether the control node wrapped by ``node`` may be queued now.

        Always True when there is exactly one graph, or when the node has at most one
        incoming control connection. Otherwise, returns False if any graph's sink
        node(s) can still reach the target through control connections. Sinks are
        nodes with out-degree 0; with upstream -> dependent edges, this is typically
        the node the graph was built for. A sink that is the target itself is ignored.
        """
        if len(self.graphs) == 1:
            return True

        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        connections = GriptapeNodes.FlowManager().get_connections()
        if self.get_number_incoming_control_connections(node.node_reference, connections) <= 1:
            return True

        target = node.node_reference
        for graph in self.graphs.values():
            graph_nodes = graph.nodes()
            # An empty graph has either already reached its target or is a dead end.
            if len(graph_nodes) == 0:
                continue
            for sink_name in [n for n in graph_nodes if graph.out_degree(n) == 0]:
                dag_node = self.node_to_reference.get(sink_name)
                if dag_node is None:
                    continue
                sink = dag_node.node_reference
                # A node can't reach itself.
                if sink == target:
                    continue
                if self._is_node_in_forward_path(sink, target, connections):
                    return False  # This graph could still reach the target node.
        return True

    def get_number_incoming_control_connections(self, node: BaseNode, connections: Connections) -> int:
        """Count incoming connections on ``node``'s parameters whose input types include the control type."""
        if node.name not in connections.incoming_index:
            return 0

        control_type = ParameterTypeBuiltin.CONTROL_TYPE.value
        total = 0
        for param_name, connection_ids in connections.incoming_index[node.name].items():
            param = node.get_parameter_by_name(param_name)
            if param and control_type in param.input_types:
                total += len(connection_ids)
        return total

    @staticmethod
    def _control_successors(node: BaseNode, connections: Connections) -> list[BaseNode]:
        """Return the target nodes of every outgoing control connection of ``node``.

        All control outputs are considered, not only get_next_control_output(), so
        branching nodes (e.g. IfElse) contribute every possible successor.
        """
        successors: list[BaseNode] = []
        if node.name not in connections.outgoing_index:
            return successors
        control_type = ParameterTypeBuiltin.CONTROL_TYPE.value
        for param_name, connection_ids in connections.outgoing_index[node.name].items():
            param = node.get_parameter_by_name(param_name)
            if not param or param.output_type != control_type:
                continue
            successors.extend(
                connections.connections[connection_id].target_node
                for connection_id in connection_ids
                if connection_id in connections.connections
            )
        return successors

    def _is_node_in_forward_path(
        self, start_node: BaseNode, target_node: BaseNode, connections: Connections, visited: set[str] | None = None
    ) -> bool:
        """Check if ``target_node`` is reachable from ``start_node`` via one or more control connections.

        Uses an explicit stack rather than recursion so long control chains cannot
        exceed Python's recursion limit. Names of expanded nodes are added to
        ``visited``.
        """
        if visited is None:
            visited = set()

        stack = [start_node]
        while stack:
            current = stack.pop()
            if current.name in visited:
                continue
            visited.add(current.name)
            for next_node in self._control_successors(current, connections):
                if next_node.name == target_node.name:
                    return True
                stack.append(next_node)
        return False
@@ -34,9 +34,18 @@ class FSM[T]:
34
34
  # Enter the initial state.
35
35
  await self.transition_state(initial_state)
36
36
 
37
- def get_current_state(self) -> type[State] | None:
37
+ @property
38
+ def current_state(self) -> type[State] | None:
38
39
  return self._current_state
39
40
 
41
+ @current_state.setter
42
+ def current_state(self, value: type[State] | None) -> None:
43
+ self._current_state = value
44
+
45
+ @property
46
+ def context(self) -> T:
47
+ return self._context
48
+
40
49
  async def transition_state(self, new_state: type[State] | None) -> None:
41
50
  while new_state is not None:
42
51
  # Exit the current state.