griptape-nodes 0.52.1__py3-none-any.whl → 0.54.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- griptape_nodes/__init__.py +8 -942
- griptape_nodes/__main__.py +6 -0
- griptape_nodes/app/app.py +48 -86
- griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +35 -5
- griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +15 -1
- griptape_nodes/cli/__init__.py +1 -0
- griptape_nodes/cli/commands/__init__.py +1 -0
- griptape_nodes/cli/commands/config.py +74 -0
- griptape_nodes/cli/commands/engine.py +80 -0
- griptape_nodes/cli/commands/init.py +550 -0
- griptape_nodes/cli/commands/libraries.py +96 -0
- griptape_nodes/cli/commands/models.py +504 -0
- griptape_nodes/cli/commands/self.py +120 -0
- griptape_nodes/cli/main.py +56 -0
- griptape_nodes/cli/shared.py +75 -0
- griptape_nodes/common/__init__.py +1 -0
- griptape_nodes/common/directed_graph.py +71 -0
- griptape_nodes/drivers/storage/base_storage_driver.py +40 -20
- griptape_nodes/drivers/storage/griptape_cloud_storage_driver.py +24 -29
- griptape_nodes/drivers/storage/local_storage_driver.py +23 -14
- griptape_nodes/exe_types/core_types.py +60 -2
- griptape_nodes/exe_types/node_types.py +257 -38
- griptape_nodes/exe_types/param_components/__init__.py +1 -0
- griptape_nodes/exe_types/param_components/execution_status_component.py +138 -0
- griptape_nodes/machines/control_flow.py +195 -94
- griptape_nodes/machines/dag_builder.py +207 -0
- griptape_nodes/machines/fsm.py +10 -1
- griptape_nodes/machines/parallel_resolution.py +558 -0
- griptape_nodes/machines/{node_resolution.py → sequential_resolution.py} +30 -57
- griptape_nodes/node_library/library_registry.py +34 -1
- griptape_nodes/retained_mode/events/app_events.py +5 -1
- griptape_nodes/retained_mode/events/base_events.py +9 -9
- griptape_nodes/retained_mode/events/config_events.py +30 -0
- griptape_nodes/retained_mode/events/execution_events.py +2 -2
- griptape_nodes/retained_mode/events/model_events.py +296 -0
- griptape_nodes/retained_mode/events/node_events.py +4 -3
- griptape_nodes/retained_mode/griptape_nodes.py +34 -12
- griptape_nodes/retained_mode/managers/agent_manager.py +23 -5
- griptape_nodes/retained_mode/managers/arbitrary_code_exec_manager.py +3 -1
- griptape_nodes/retained_mode/managers/config_manager.py +44 -3
- griptape_nodes/retained_mode/managers/context_manager.py +6 -5
- griptape_nodes/retained_mode/managers/event_manager.py +8 -2
- griptape_nodes/retained_mode/managers/flow_manager.py +150 -206
- griptape_nodes/retained_mode/managers/library_lifecycle/library_directory.py +1 -1
- griptape_nodes/retained_mode/managers/library_manager.py +35 -25
- griptape_nodes/retained_mode/managers/model_manager.py +1107 -0
- griptape_nodes/retained_mode/managers/node_manager.py +102 -220
- griptape_nodes/retained_mode/managers/object_manager.py +11 -5
- griptape_nodes/retained_mode/managers/os_manager.py +28 -13
- griptape_nodes/retained_mode/managers/secrets_manager.py +8 -4
- griptape_nodes/retained_mode/managers/settings.py +116 -7
- griptape_nodes/retained_mode/managers/static_files_manager.py +85 -12
- griptape_nodes/retained_mode/managers/sync_manager.py +17 -9
- griptape_nodes/retained_mode/managers/workflow_manager.py +186 -192
- griptape_nodes/retained_mode/retained_mode.py +19 -0
- griptape_nodes/servers/__init__.py +1 -0
- griptape_nodes/{mcp_server/server.py → servers/mcp.py} +1 -1
- griptape_nodes/{app/api.py → servers/static.py} +43 -40
- griptape_nodes/traits/add_param_button.py +1 -1
- griptape_nodes/traits/button.py +334 -6
- griptape_nodes/traits/color_picker.py +66 -0
- griptape_nodes/traits/multi_options.py +188 -0
- griptape_nodes/traits/numbers_selector.py +77 -0
- griptape_nodes/traits/options.py +93 -2
- griptape_nodes/traits/traits.json +4 -0
- griptape_nodes/utils/async_utils.py +31 -0
- {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/METADATA +4 -1
- {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/RECORD +71 -48
- {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/WHEEL +1 -1
- /griptape_nodes/{mcp_server → servers}/ws_request_manager.py +0 -0
- {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/entry_points.txt +0 -0
|
@@ -9,7 +9,8 @@ from griptape_nodes.exe_types.core_types import Parameter
|
|
|
9
9
|
from griptape_nodes.exe_types.node_types import BaseNode, NodeResolutionState
|
|
10
10
|
from griptape_nodes.exe_types.type_validator import TypeValidator
|
|
11
11
|
from griptape_nodes.machines.fsm import FSM, State
|
|
12
|
-
from griptape_nodes.machines.
|
|
12
|
+
from griptape_nodes.machines.parallel_resolution import ParallelResolutionMachine
|
|
13
|
+
from griptape_nodes.machines.sequential_resolution import SequentialResolutionMachine
|
|
13
14
|
from griptape_nodes.retained_mode.events.base_events import ExecutionEvent, ExecutionGriptapeNodeEvent
|
|
14
15
|
from griptape_nodes.retained_mode.events.execution_events import (
|
|
15
16
|
ControlFlowResolvedEvent,
|
|
@@ -17,6 +18,7 @@ from griptape_nodes.retained_mode.events.execution_events import (
|
|
|
17
18
|
SelectedControlOutputEvent,
|
|
18
19
|
)
|
|
19
20
|
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
21
|
+
from griptape_nodes.retained_mode.managers.settings import WorkflowExecutionMode
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
@dataclass
|
|
@@ -37,40 +39,73 @@ logger = logging.getLogger("griptape_nodes")
|
|
|
37
39
|
# This is the control flow context. Owns the Resolution Machine
|
|
38
40
|
class ControlFlowContext:
|
|
39
41
|
flow: ControlFlow
|
|
40
|
-
|
|
41
|
-
resolution_machine:
|
|
42
|
+
current_nodes: list[BaseNode]
|
|
43
|
+
resolution_machine: ParallelResolutionMachine | SequentialResolutionMachine
|
|
42
44
|
selected_output: Parameter | None
|
|
43
45
|
paused: bool = False
|
|
46
|
+
flow_name: str
|
|
47
|
+
|
|
48
|
+
def __init__(
|
|
49
|
+
self,
|
|
50
|
+
flow_name: str,
|
|
51
|
+
max_nodes_in_parallel: int,
|
|
52
|
+
*,
|
|
53
|
+
execution_type: WorkflowExecutionMode = WorkflowExecutionMode.SEQUENTIAL,
|
|
54
|
+
) -> None:
|
|
55
|
+
self.flow_name = flow_name
|
|
56
|
+
if execution_type == WorkflowExecutionMode.PARALLEL:
|
|
57
|
+
# Get the global DagBuilder from FlowManager
|
|
58
|
+
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
59
|
+
|
|
60
|
+
dag_builder = GriptapeNodes.FlowManager().global_dag_builder
|
|
61
|
+
self.resolution_machine = ParallelResolutionMachine(
|
|
62
|
+
flow_name, max_nodes_in_parallel, dag_builder=dag_builder
|
|
63
|
+
)
|
|
64
|
+
else:
|
|
65
|
+
self.resolution_machine = SequentialResolutionMachine()
|
|
66
|
+
self.current_nodes = []
|
|
44
67
|
|
|
45
|
-
def
|
|
46
|
-
|
|
47
|
-
self.current_node = None
|
|
48
|
-
|
|
49
|
-
def get_next_node(self, output_parameter: Parameter) -> NextNodeInfo | None:
|
|
50
|
-
"""Get the next node and the target parameter that will receive the control flow.
|
|
68
|
+
def get_next_nodes(self, output_parameter: Parameter | None = None) -> list[NextNodeInfo]:
|
|
69
|
+
"""Get all next nodes from the current nodes.
|
|
51
70
|
|
|
52
71
|
Returns:
|
|
53
|
-
NextNodeInfo
|
|
72
|
+
list[NextNodeInfo]: List of next nodes to process
|
|
54
73
|
"""
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
74
|
+
next_nodes = []
|
|
75
|
+
for current_node in self.current_nodes:
|
|
76
|
+
if output_parameter is not None:
|
|
77
|
+
# Get connected node from control flow
|
|
78
|
+
node_connection = (
|
|
79
|
+
GriptapeNodes.FlowManager().get_connections().get_connected_node(current_node, output_parameter)
|
|
80
|
+
)
|
|
81
|
+
if node_connection is not None:
|
|
82
|
+
node, entry_parameter = node_connection
|
|
83
|
+
next_nodes.append(NextNodeInfo(node=node, entry_parameter=entry_parameter))
|
|
84
|
+
else:
|
|
85
|
+
# Get next control output for this node
|
|
86
|
+
next_output = current_node.get_next_control_output()
|
|
87
|
+
if next_output is not None:
|
|
88
|
+
node_connection = (
|
|
89
|
+
GriptapeNodes.FlowManager().get_connections().get_connected_node(current_node, next_output)
|
|
90
|
+
)
|
|
91
|
+
if node_connection is not None:
|
|
92
|
+
node, entry_parameter = node_connection
|
|
93
|
+
next_nodes.append(NextNodeInfo(node=node, entry_parameter=entry_parameter))
|
|
94
|
+
|
|
95
|
+
# If no connections found, check execution queue
|
|
96
|
+
if not next_nodes:
|
|
64
97
|
node = GriptapeNodes.FlowManager().get_next_node_from_execution_queue()
|
|
65
98
|
if node is not None:
|
|
66
|
-
|
|
67
|
-
|
|
99
|
+
next_nodes.append(NextNodeInfo(node=node, entry_parameter=None))
|
|
100
|
+
|
|
101
|
+
return next_nodes
|
|
68
102
|
|
|
69
|
-
def reset(self) -> None:
|
|
70
|
-
if self.
|
|
71
|
-
self.
|
|
72
|
-
|
|
73
|
-
self.
|
|
103
|
+
def reset(self, *, cancel: bool = False) -> None:
|
|
104
|
+
if self.current_nodes is not None:
|
|
105
|
+
for node in self.current_nodes:
|
|
106
|
+
node.clear_node()
|
|
107
|
+
self.current_nodes = []
|
|
108
|
+
self.resolution_machine.reset_machine(cancel=cancel)
|
|
74
109
|
self.selected_output = None
|
|
75
110
|
self.paused = False
|
|
76
111
|
|
|
@@ -80,24 +115,25 @@ class ResolveNodeState(State):
|
|
|
80
115
|
@staticmethod
|
|
81
116
|
async def on_enter(context: ControlFlowContext) -> type[State] | None:
|
|
82
117
|
# The state machine has started, but it hasn't began to execute yet.
|
|
83
|
-
if context.
|
|
118
|
+
if len(context.current_nodes) == 0:
|
|
84
119
|
# We don't have anything else to do. Move back to Complete State so it has to restart.
|
|
85
120
|
return CompleteState
|
|
86
121
|
|
|
87
|
-
# Mark
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
122
|
+
# Mark all current nodes unresolved and broadcast events
|
|
123
|
+
for current_node in context.current_nodes:
|
|
124
|
+
if not current_node.lock:
|
|
125
|
+
current_node.make_node_unresolved(
|
|
126
|
+
current_states_to_trigger_change_event=set(
|
|
127
|
+
{NodeResolutionState.UNRESOLVED, NodeResolutionState.RESOLVED, NodeResolutionState.RESOLVING}
|
|
128
|
+
)
|
|
129
|
+
)
|
|
130
|
+
# Now broadcast that we have a current control node.
|
|
131
|
+
GriptapeNodes.EventManager().put_event(
|
|
132
|
+
ExecutionGriptapeNodeEvent(
|
|
133
|
+
wrapped_event=ExecutionEvent(payload=CurrentControlNodeEvent(node_name=current_node.name))
|
|
92
134
|
)
|
|
93
135
|
)
|
|
94
|
-
|
|
95
|
-
GriptapeNodes.EventManager().put_event(
|
|
96
|
-
ExecutionGriptapeNodeEvent(
|
|
97
|
-
wrapped_event=ExecutionEvent(payload=CurrentControlNodeEvent(node_name=context.current_node.name))
|
|
98
|
-
)
|
|
99
|
-
)
|
|
100
|
-
logger.info("Resolving %s", context.current_node.name)
|
|
136
|
+
logger.info("Resolving %s", current_node.name)
|
|
101
137
|
if not context.paused:
|
|
102
138
|
# Call the update. Otherwise wait
|
|
103
139
|
return ResolveNodeState
|
|
@@ -106,13 +142,17 @@ class ResolveNodeState(State):
|
|
|
106
142
|
# This is necessary to transition to the next step.
|
|
107
143
|
@staticmethod
|
|
108
144
|
async def on_update(context: ControlFlowContext) -> type[State] | None:
|
|
109
|
-
# If
|
|
110
|
-
if context.
|
|
145
|
+
# If no current nodes, we're done
|
|
146
|
+
if len(context.current_nodes) == 0:
|
|
111
147
|
return CompleteState
|
|
112
|
-
|
|
113
|
-
|
|
148
|
+
|
|
149
|
+
# Resolve nodes - pass first node for sequential resolution
|
|
150
|
+
current_node = context.current_nodes[0] if context.current_nodes else None
|
|
151
|
+
await context.resolution_machine.resolve_node(current_node)
|
|
114
152
|
|
|
115
153
|
if context.resolution_machine.is_complete():
|
|
154
|
+
if isinstance(context.resolution_machine, ParallelResolutionMachine):
|
|
155
|
+
return CompleteState
|
|
116
156
|
return NextNodeState
|
|
117
157
|
return None
|
|
118
158
|
|
|
@@ -120,44 +160,49 @@ class ResolveNodeState(State):
|
|
|
120
160
|
class NextNodeState(State):
|
|
121
161
|
@staticmethod
|
|
122
162
|
async def on_enter(context: ControlFlowContext) -> type[State] | None:
|
|
123
|
-
if context.
|
|
163
|
+
if len(context.current_nodes) == 0:
|
|
124
164
|
return CompleteState
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
165
|
+
|
|
166
|
+
# Check for stop_flow on any current nodes
|
|
167
|
+
for current_node in context.current_nodes[:]:
|
|
168
|
+
if current_node.stop_flow:
|
|
169
|
+
current_node.stop_flow = False
|
|
170
|
+
context.current_nodes.remove(current_node)
|
|
171
|
+
|
|
172
|
+
# If all nodes stopped flow, complete
|
|
173
|
+
if len(context.current_nodes) == 0:
|
|
129
174
|
return CompleteState
|
|
130
|
-
next_output = context.current_node.get_next_control_output()
|
|
131
|
-
next_node_info = None
|
|
132
175
|
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
176
|
+
# Get all next nodes from current nodes
|
|
177
|
+
next_node_infos = context.get_next_nodes()
|
|
178
|
+
|
|
179
|
+
# Broadcast selected control output events for nodes with outputs
|
|
180
|
+
for current_node in context.current_nodes:
|
|
181
|
+
next_output = current_node.get_next_control_output()
|
|
182
|
+
if next_output is not None:
|
|
183
|
+
context.selected_output = next_output
|
|
184
|
+
GriptapeNodes.EventManager().put_event(
|
|
185
|
+
ExecutionGriptapeNodeEvent(
|
|
186
|
+
wrapped_event=ExecutionEvent(
|
|
187
|
+
payload=SelectedControlOutputEvent(
|
|
188
|
+
node_name=current_node.name,
|
|
189
|
+
selected_output_parameter_name=next_output.name,
|
|
190
|
+
)
|
|
142
191
|
)
|
|
143
192
|
)
|
|
144
193
|
)
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
next_node = GriptapeNodes.FlowManager().get_next_node_from_execution_queue()
|
|
149
|
-
if next_node is not None:
|
|
150
|
-
next_node_info = NextNodeInfo(node=next_node, entry_parameter=None)
|
|
151
|
-
|
|
152
|
-
# The parameter that will be evaluated next
|
|
153
|
-
if next_node_info is None:
|
|
154
|
-
# If no node attached
|
|
194
|
+
|
|
195
|
+
# If no next nodes, we're complete
|
|
196
|
+
if not next_node_infos:
|
|
155
197
|
return CompleteState
|
|
156
198
|
|
|
157
|
-
#
|
|
158
|
-
|
|
199
|
+
# Set up next nodes as current nodes
|
|
200
|
+
next_nodes = []
|
|
201
|
+
for next_node_info in next_node_infos:
|
|
202
|
+
next_node_info.node.set_entry_control_parameter(next_node_info.entry_parameter)
|
|
203
|
+
next_nodes.append(next_node_info.node)
|
|
159
204
|
|
|
160
|
-
context.
|
|
205
|
+
context.current_nodes = next_nodes
|
|
161
206
|
context.selected_output = None
|
|
162
207
|
if not context.paused:
|
|
163
208
|
return ResolveNodeState
|
|
@@ -171,15 +216,14 @@ class NextNodeState(State):
|
|
|
171
216
|
class CompleteState(State):
|
|
172
217
|
@staticmethod
|
|
173
218
|
async def on_enter(context: ControlFlowContext) -> type[State] | None:
|
|
174
|
-
|
|
219
|
+
# Broadcast completion events for any remaining current nodes
|
|
220
|
+
for current_node in context.current_nodes:
|
|
175
221
|
GriptapeNodes.EventManager().put_event(
|
|
176
222
|
ExecutionGriptapeNodeEvent(
|
|
177
223
|
wrapped_event=ExecutionEvent(
|
|
178
224
|
payload=ControlFlowResolvedEvent(
|
|
179
|
-
end_node_name=
|
|
180
|
-
parameter_output_values=TypeValidator.safe_serialize(
|
|
181
|
-
context.current_node.parameter_output_values
|
|
182
|
-
),
|
|
225
|
+
end_node_name=current_node.name,
|
|
226
|
+
parameter_output_values=TypeValidator.safe_serialize(current_node.parameter_output_values),
|
|
183
227
|
)
|
|
184
228
|
)
|
|
185
229
|
)
|
|
@@ -194,14 +238,24 @@ class CompleteState(State):
|
|
|
194
238
|
|
|
195
239
|
# MACHINE TIME!!!
|
|
196
240
|
class ControlFlowMachine(FSM[ControlFlowContext]):
|
|
197
|
-
def __init__(self) -> None:
|
|
198
|
-
|
|
241
|
+
def __init__(self, flow_name: str) -> None:
|
|
242
|
+
execution_type = GriptapeNodes.ConfigManager().get_config_value(
|
|
243
|
+
"workflow_execution_mode", default=WorkflowExecutionMode.SEQUENTIAL
|
|
244
|
+
)
|
|
245
|
+
max_nodes_in_parallel = GriptapeNodes.ConfigManager().get_config_value("max_nodes_in_parallel", default=5)
|
|
246
|
+
context = ControlFlowContext(flow_name, max_nodes_in_parallel, execution_type=execution_type)
|
|
199
247
|
super().__init__(context)
|
|
200
248
|
|
|
201
249
|
async def start_flow(self, start_node: BaseNode, debug_mode: bool = False) -> None: # noqa: FBT001, FBT002
|
|
202
|
-
|
|
250
|
+
# If using DAG resolution, process data_nodes from queue first
|
|
251
|
+
if isinstance(self._context.resolution_machine, ParallelResolutionMachine):
|
|
252
|
+
current_nodes = await self._process_nodes_for_dag(start_node)
|
|
253
|
+
else:
|
|
254
|
+
current_nodes = [start_node]
|
|
255
|
+
self._context.current_nodes = current_nodes
|
|
203
256
|
# Set entry control parameter for initial node (None for workflow start)
|
|
204
|
-
|
|
257
|
+
for node in current_nodes:
|
|
258
|
+
node.set_entry_control_parameter(None)
|
|
205
259
|
# Set up to debug
|
|
206
260
|
self._context.paused = debug_mode
|
|
207
261
|
await self.start(ResolveNodeState) # Begins the flow
|
|
@@ -214,31 +268,78 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
|
|
|
214
268
|
|
|
215
269
|
def change_debug_mode(self, debug_mode: bool) -> None: # noqa: FBT001
|
|
216
270
|
self._context.paused = debug_mode
|
|
217
|
-
self._context.resolution_machine.change_debug_mode(debug_mode)
|
|
271
|
+
self._context.resolution_machine.change_debug_mode(debug_mode=debug_mode)
|
|
218
272
|
|
|
219
273
|
async def granular_step(self, change_debug_mode: bool) -> None: # noqa: FBT001
|
|
220
274
|
resolution_machine = self._context.resolution_machine
|
|
275
|
+
|
|
221
276
|
if change_debug_mode:
|
|
222
|
-
resolution_machine.change_debug_mode(True)
|
|
277
|
+
resolution_machine.change_debug_mode(debug_mode=True)
|
|
223
278
|
await resolution_machine.update()
|
|
224
279
|
|
|
225
|
-
# Tick the control flow if the
|
|
226
|
-
if
|
|
280
|
+
# Tick the control flow if the current machine isn't busy
|
|
281
|
+
if self._current_state is ResolveNodeState and ( # noqa: SIM102
|
|
282
|
+
resolution_machine.is_complete() or not resolution_machine.is_started()
|
|
283
|
+
):
|
|
227
284
|
# Don't tick ourselves if we are already complete.
|
|
228
285
|
if self._current_state is not None:
|
|
229
286
|
await self.update()
|
|
230
287
|
|
|
231
288
|
async def node_step(self) -> None:
|
|
232
289
|
resolution_machine = self._context.resolution_machine
|
|
233
|
-
resolution_machine.change_debug_mode(False)
|
|
234
|
-
await resolution_machine.update()
|
|
235
290
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
291
|
+
resolution_machine.change_debug_mode(debug_mode=False)
|
|
292
|
+
|
|
293
|
+
# If we're in the resolution phase, step the resolution machine
|
|
294
|
+
if self._current_state is ResolveNodeState:
|
|
295
|
+
await resolution_machine.update()
|
|
241
296
|
|
|
242
|
-
|
|
243
|
-
self.
|
|
297
|
+
# Tick the control flow if the current machine isn't busy
|
|
298
|
+
if self._current_state is ResolveNodeState and (
|
|
299
|
+
resolution_machine.is_complete() or not resolution_machine.is_started()
|
|
300
|
+
):
|
|
301
|
+
await self.update()
|
|
302
|
+
|
|
303
|
+
async def _process_nodes_for_dag(self, start_node: BaseNode) -> list[BaseNode]:
|
|
304
|
+
"""Process data_nodes from the global queue to build unified DAG.
|
|
305
|
+
|
|
306
|
+
This method identifies data_nodes in the execution queue and processes
|
|
307
|
+
their dependencies into the DAG resolution machine.
|
|
308
|
+
"""
|
|
309
|
+
if not isinstance(self._context.resolution_machine, ParallelResolutionMachine):
|
|
310
|
+
return []
|
|
311
|
+
# Get the global flow queue
|
|
312
|
+
flow_manager = GriptapeNodes.FlowManager()
|
|
313
|
+
dag_builder = flow_manager.global_dag_builder
|
|
314
|
+
if dag_builder is None:
|
|
315
|
+
msg = "DAG builder is not initialized."
|
|
316
|
+
raise ValueError(msg)
|
|
317
|
+
# Build with the first node:
|
|
318
|
+
dag_builder.add_node_with_dependencies(start_node, start_node.name)
|
|
319
|
+
queue_items = list(flow_manager.global_flow_queue.queue)
|
|
320
|
+
start_nodes = [start_node]
|
|
321
|
+
# Find data_nodes and remove them from queue
|
|
322
|
+
for item in queue_items:
|
|
323
|
+
from griptape_nodes.retained_mode.managers.flow_manager import DagExecutionType
|
|
324
|
+
|
|
325
|
+
if item.dag_execution_type in (DagExecutionType.CONTROL_NODE, DagExecutionType.START_NODE):
|
|
326
|
+
node = item.node
|
|
327
|
+
node.state = NodeResolutionState.UNRESOLVED
|
|
328
|
+
dag_builder.add_node_with_dependencies(node, node.name)
|
|
329
|
+
flow_manager.global_flow_queue.queue.remove(item)
|
|
330
|
+
start_nodes.append(node)
|
|
331
|
+
elif item.dag_execution_type == DagExecutionType.DATA_NODE:
|
|
332
|
+
node = item.node
|
|
333
|
+
node.state = NodeResolutionState.UNRESOLVED
|
|
334
|
+
# Build here.
|
|
335
|
+
dag_builder.add_node_with_dependencies(node, node.name)
|
|
336
|
+
flow_manager.global_flow_queue.queue.remove(item)
|
|
337
|
+
return start_nodes
|
|
338
|
+
|
|
339
|
+
def reset_machine(self, *, cancel: bool = False) -> None:
|
|
340
|
+
self._context.reset(cancel=cancel)
|
|
244
341
|
self._current_state = None
|
|
342
|
+
|
|
343
|
+
@property
|
|
344
|
+
def resolution_machine(self) -> ParallelResolutionMachine | SequentialResolutionMachine:
|
|
345
|
+
return self._context.resolution_machine
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from enum import StrEnum
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from griptape_nodes.common.directed_graph import DirectedGraph
|
|
9
|
+
from griptape_nodes.exe_types.core_types import ParameterTypeBuiltin
|
|
10
|
+
from griptape_nodes.exe_types.node_types import NodeResolutionState
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
import asyncio
|
|
14
|
+
|
|
15
|
+
from griptape_nodes.exe_types.connections import Connections
|
|
16
|
+
from griptape_nodes.exe_types.node_types import BaseNode
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger("griptape_nodes")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class NodeState(StrEnum):
|
|
22
|
+
"""Individual node execution states."""
|
|
23
|
+
|
|
24
|
+
QUEUED = "queued"
|
|
25
|
+
PROCESSING = "processing"
|
|
26
|
+
DONE = "done"
|
|
27
|
+
CANCELED = "canceled"
|
|
28
|
+
ERRORED = "errored"
|
|
29
|
+
WAITING = "waiting"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass(kw_only=True)
|
|
33
|
+
class DagNode:
|
|
34
|
+
"""Represents a node in the DAG with runtime references."""
|
|
35
|
+
|
|
36
|
+
task_reference: asyncio.Task | None = field(default=None)
|
|
37
|
+
node_state: NodeState = field(default=NodeState.WAITING)
|
|
38
|
+
node_reference: BaseNode
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class DagBuilder:
|
|
42
|
+
"""Handles DAG construction independently of execution state machine."""
|
|
43
|
+
|
|
44
|
+
graphs: dict[str, DirectedGraph] # Str is the name of the start node associated here.
|
|
45
|
+
node_to_reference: dict[str, DagNode]
|
|
46
|
+
|
|
47
|
+
def __init__(self) -> None:
|
|
48
|
+
self.graphs = {}
|
|
49
|
+
self.node_to_reference: dict[str, DagNode] = {}
|
|
50
|
+
|
|
51
|
+
# Complex with the inner recursive method, but it needs connections and added_nodes.
|
|
52
|
+
def add_node_with_dependencies(self, node: BaseNode, graph_name: str = "default") -> list[BaseNode]: # noqa: C901
|
|
53
|
+
"""Add node and all its dependencies to DAG. Returns list of added nodes."""
|
|
54
|
+
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
55
|
+
|
|
56
|
+
connections = GriptapeNodes.FlowManager().get_connections()
|
|
57
|
+
added_nodes = []
|
|
58
|
+
graph = self.graphs.get(graph_name, None)
|
|
59
|
+
if graph is None:
|
|
60
|
+
graph = DirectedGraph()
|
|
61
|
+
self.graphs[graph_name] = graph
|
|
62
|
+
|
|
63
|
+
def _add_node_recursive(current_node: BaseNode, visited: set[str], graph: DirectedGraph) -> None:
|
|
64
|
+
if current_node.name in visited:
|
|
65
|
+
return
|
|
66
|
+
visited.add(current_node.name)
|
|
67
|
+
|
|
68
|
+
# Skip if already in DAG (use DAG membership, not resolved state)
|
|
69
|
+
if current_node.name in self.node_to_reference:
|
|
70
|
+
return
|
|
71
|
+
|
|
72
|
+
# Process dependencies first (depth-first)
|
|
73
|
+
ignore_data_dependencies = False
|
|
74
|
+
# This is specifically for output_selector. Overriding 'initialize_spotlight' doesn't work anymore.
|
|
75
|
+
if hasattr(current_node, "ignore_dependencies"):
|
|
76
|
+
ignore_data_dependencies = True
|
|
77
|
+
for param in current_node.parameters:
|
|
78
|
+
if param.type == ParameterTypeBuiltin.CONTROL_TYPE:
|
|
79
|
+
continue
|
|
80
|
+
if ignore_data_dependencies:
|
|
81
|
+
continue
|
|
82
|
+
upstream_connection = connections.get_connected_node(current_node, param)
|
|
83
|
+
if upstream_connection:
|
|
84
|
+
upstream_node, _ = upstream_connection
|
|
85
|
+
# Don't add nodes that have already been resolved.
|
|
86
|
+
if upstream_node.state == NodeResolutionState.RESOLVED:
|
|
87
|
+
continue
|
|
88
|
+
# If upstream is already in DAG, skip creating edge (it's in another graph)
|
|
89
|
+
if upstream_node.name in self.node_to_reference:
|
|
90
|
+
graph.add_edge(upstream_node.name, current_node.name)
|
|
91
|
+
# Otherwise, add it to DAG first then create edge
|
|
92
|
+
else:
|
|
93
|
+
_add_node_recursive(upstream_node, visited, graph)
|
|
94
|
+
graph.add_edge(upstream_node.name, current_node.name)
|
|
95
|
+
|
|
96
|
+
# Add current node to DAG (but keep original resolution state)
|
|
97
|
+
|
|
98
|
+
dag_node = DagNode(node_reference=current_node, node_state=NodeState.WAITING)
|
|
99
|
+
self.node_to_reference[current_node.name] = dag_node
|
|
100
|
+
graph.add_node(node_for_adding=current_node.name)
|
|
101
|
+
# DON'T mark as resolved - that happens during actual execution
|
|
102
|
+
added_nodes.append(current_node)
|
|
103
|
+
|
|
104
|
+
_add_node_recursive(node, set(), graph)
|
|
105
|
+
|
|
106
|
+
return added_nodes
|
|
107
|
+
|
|
108
|
+
def add_node(self, node: BaseNode, graph_name: str = "default") -> DagNode:
|
|
109
|
+
"""Add just one node to DAG without dependencies (assumes dependencies already exist)."""
|
|
110
|
+
if node.name in self.node_to_reference:
|
|
111
|
+
return self.node_to_reference[node.name]
|
|
112
|
+
|
|
113
|
+
dag_node = DagNode(node_reference=node, node_state=NodeState.WAITING)
|
|
114
|
+
self.node_to_reference[node.name] = dag_node
|
|
115
|
+
graph = self.graphs.get(graph_name, None)
|
|
116
|
+
if graph is None:
|
|
117
|
+
graph = DirectedGraph()
|
|
118
|
+
self.graphs[graph_name] = graph
|
|
119
|
+
graph.add_node(node_for_adding=node.name)
|
|
120
|
+
return dag_node
|
|
121
|
+
|
|
122
|
+
def clear(self) -> None:
|
|
123
|
+
"""Clear all nodes and references from the DAG builder."""
|
|
124
|
+
self.graphs.clear()
|
|
125
|
+
self.node_to_reference.clear()
|
|
126
|
+
|
|
127
|
+
def can_queue_control_node(self, node: DagNode) -> bool:
|
|
128
|
+
if len(self.graphs) == 1:
|
|
129
|
+
return True
|
|
130
|
+
|
|
131
|
+
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
132
|
+
|
|
133
|
+
connections = GriptapeNodes.FlowManager().get_connections()
|
|
134
|
+
|
|
135
|
+
control_connections = self.get_number_incoming_control_connections(node.node_reference, connections)
|
|
136
|
+
if control_connections <= 1:
|
|
137
|
+
return True
|
|
138
|
+
|
|
139
|
+
for graph in self.graphs.values():
|
|
140
|
+
# If the length of the graph is 0, skip it. it's either reached it or it's a dead end.
|
|
141
|
+
if len(graph.nodes()) == 0:
|
|
142
|
+
continue
|
|
143
|
+
|
|
144
|
+
# If graph has nodes, the root node (not the leaf, the root), check forward path from that
|
|
145
|
+
root_nodes = [n for n in graph.nodes() if graph.out_degree(n) == 0]
|
|
146
|
+
for root_node_name in root_nodes:
|
|
147
|
+
if root_node_name in self.node_to_reference:
|
|
148
|
+
root_node = self.node_to_reference[root_node_name].node_reference
|
|
149
|
+
|
|
150
|
+
# Skip if the root node is the same as the target node - it can't reach itself
|
|
151
|
+
if root_node == node.node_reference:
|
|
152
|
+
continue
|
|
153
|
+
|
|
154
|
+
# Check if the target node is in the forward path from this root
|
|
155
|
+
if self._is_node_in_forward_path(root_node, node.node_reference, connections):
|
|
156
|
+
return False # This graph could still reach the target node
|
|
157
|
+
|
|
158
|
+
# Otherwise, return true at the end of the function
|
|
159
|
+
return True
|
|
160
|
+
|
|
161
|
+
def get_number_incoming_control_connections(self, node: BaseNode, connections: Connections) -> int:
|
|
162
|
+
if node.name not in connections.incoming_index:
|
|
163
|
+
return 0
|
|
164
|
+
|
|
165
|
+
control_connection_count = 0
|
|
166
|
+
node_connections = connections.incoming_index[node.name]
|
|
167
|
+
|
|
168
|
+
for param_name, connection_ids in node_connections.items():
|
|
169
|
+
# Find the parameter to check if it's a control type
|
|
170
|
+
param = node.get_parameter_by_name(param_name)
|
|
171
|
+
if param and ParameterTypeBuiltin.CONTROL_TYPE.value in param.input_types:
|
|
172
|
+
control_connection_count += len(connection_ids)
|
|
173
|
+
|
|
174
|
+
return control_connection_count
|
|
175
|
+
|
|
176
|
+
def _is_node_in_forward_path(
|
|
177
|
+
self, start_node: BaseNode, target_node: BaseNode, connections: Connections, visited: set[str] | None = None
|
|
178
|
+
) -> bool:
|
|
179
|
+
"""Check if target_node is reachable from start_node through control flow connections."""
|
|
180
|
+
if visited is None:
|
|
181
|
+
visited = set()
|
|
182
|
+
|
|
183
|
+
if start_node.name in visited:
|
|
184
|
+
return False
|
|
185
|
+
visited.add(start_node.name)
|
|
186
|
+
|
|
187
|
+
# Check ALL outgoing control connections, not just get_next_control_output()
|
|
188
|
+
# This handles IfElse nodes that have multiple possible control outputs
|
|
189
|
+
if start_node.name in connections.outgoing_index:
|
|
190
|
+
for param_name, connection_ids in connections.outgoing_index[start_node.name].items():
|
|
191
|
+
# Find the parameter to check if it's a control type
|
|
192
|
+
param = start_node.get_parameter_by_name(param_name)
|
|
193
|
+
if param and param.output_type == ParameterTypeBuiltin.CONTROL_TYPE.value:
|
|
194
|
+
# This is a control parameter - check all its connections
|
|
195
|
+
for connection_id in connection_ids:
|
|
196
|
+
if connection_id in connections.connections:
|
|
197
|
+
connection = connections.connections[connection_id]
|
|
198
|
+
next_node = connection.target_node
|
|
199
|
+
|
|
200
|
+
if next_node.name == target_node.name:
|
|
201
|
+
return True
|
|
202
|
+
|
|
203
|
+
# Recursively check the forward path
|
|
204
|
+
if self._is_node_in_forward_path(next_node, target_node, connections, visited):
|
|
205
|
+
return True
|
|
206
|
+
|
|
207
|
+
return False
|
griptape_nodes/machines/fsm.py
CHANGED
|
@@ -34,9 +34,18 @@ class FSM[T]:
|
|
|
34
34
|
# Enter the initial state.
|
|
35
35
|
await self.transition_state(initial_state)
|
|
36
36
|
|
|
37
|
-
|
|
37
|
+
@property
|
|
38
|
+
def current_state(self) -> type[State] | None:
|
|
38
39
|
return self._current_state
|
|
39
40
|
|
|
41
|
+
@current_state.setter
|
|
42
|
+
def current_state(self, value: type[State] | None) -> None:
|
|
43
|
+
self._current_state = value
|
|
44
|
+
|
|
45
|
+
@property
|
|
46
|
+
def context(self) -> T:
|
|
47
|
+
return self._context
|
|
48
|
+
|
|
40
49
|
async def transition_state(self, new_state: type[State] | None) -> None:
|
|
41
50
|
while new_state is not None:
|
|
42
51
|
# Exit the current state.
|