griptape-nodes 0.53.0__py3-none-any.whl → 0.54.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- griptape_nodes/__init__.py +5 -2
- griptape_nodes/app/app.py +4 -26
- griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +35 -5
- griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +15 -1
- griptape_nodes/cli/commands/config.py +4 -1
- griptape_nodes/cli/commands/init.py +5 -3
- griptape_nodes/cli/commands/libraries.py +14 -8
- griptape_nodes/cli/commands/models.py +504 -0
- griptape_nodes/cli/commands/self.py +5 -2
- griptape_nodes/cli/main.py +11 -1
- griptape_nodes/cli/shared.py +0 -9
- griptape_nodes/common/directed_graph.py +17 -1
- griptape_nodes/drivers/storage/base_storage_driver.py +40 -20
- griptape_nodes/drivers/storage/griptape_cloud_storage_driver.py +24 -29
- griptape_nodes/drivers/storage/local_storage_driver.py +17 -13
- griptape_nodes/exe_types/node_types.py +219 -14
- griptape_nodes/exe_types/param_components/__init__.py +1 -0
- griptape_nodes/exe_types/param_components/execution_status_component.py +138 -0
- griptape_nodes/machines/control_flow.py +129 -92
- griptape_nodes/machines/dag_builder.py +207 -0
- griptape_nodes/machines/parallel_resolution.py +264 -276
- griptape_nodes/machines/sequential_resolution.py +9 -7
- griptape_nodes/node_library/library_registry.py +34 -1
- griptape_nodes/retained_mode/events/app_events.py +5 -1
- griptape_nodes/retained_mode/events/base_events.py +7 -7
- griptape_nodes/retained_mode/events/config_events.py +30 -0
- griptape_nodes/retained_mode/events/execution_events.py +2 -2
- griptape_nodes/retained_mode/events/model_events.py +296 -0
- griptape_nodes/retained_mode/griptape_nodes.py +10 -1
- griptape_nodes/retained_mode/managers/agent_manager.py +14 -0
- griptape_nodes/retained_mode/managers/config_manager.py +44 -3
- griptape_nodes/retained_mode/managers/event_manager.py +8 -2
- griptape_nodes/retained_mode/managers/flow_manager.py +45 -14
- griptape_nodes/retained_mode/managers/library_manager.py +3 -3
- griptape_nodes/retained_mode/managers/model_manager.py +1107 -0
- griptape_nodes/retained_mode/managers/node_manager.py +26 -26
- griptape_nodes/retained_mode/managers/object_manager.py +1 -1
- griptape_nodes/retained_mode/managers/os_manager.py +6 -6
- griptape_nodes/retained_mode/managers/settings.py +87 -9
- griptape_nodes/retained_mode/managers/static_files_manager.py +77 -9
- griptape_nodes/retained_mode/managers/sync_manager.py +10 -5
- griptape_nodes/retained_mode/managers/workflow_manager.py +98 -92
- griptape_nodes/retained_mode/retained_mode.py +19 -0
- griptape_nodes/servers/__init__.py +1 -0
- griptape_nodes/{mcp_server/server.py → servers/mcp.py} +1 -1
- griptape_nodes/{app/api.py → servers/static.py} +43 -40
- griptape_nodes/traits/button.py +124 -6
- griptape_nodes/traits/multi_options.py +188 -0
- griptape_nodes/traits/numbers_selector.py +77 -0
- griptape_nodes/traits/options.py +93 -2
- griptape_nodes/utils/async_utils.py +31 -0
- {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.0.dist-info}/METADATA +3 -1
- {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.0.dist-info}/RECORD +56 -47
- {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.0.dist-info}/WHEEL +1 -1
- /griptape_nodes/{mcp_server → servers}/ws_request_manager.py +0 -0
- {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.0.dist-info}/entry_points.txt +0 -0
|
@@ -2,14 +2,13 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import logging
|
|
5
|
-
from dataclasses import dataclass, field
|
|
6
5
|
from enum import StrEnum
|
|
7
|
-
from typing import
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
8
7
|
|
|
9
|
-
from griptape_nodes.
|
|
10
|
-
from griptape_nodes.exe_types.core_types import ParameterTypeBuiltin
|
|
8
|
+
from griptape_nodes.exe_types.core_types import Parameter, ParameterType, ParameterTypeBuiltin
|
|
11
9
|
from griptape_nodes.exe_types.node_types import BaseNode, NodeResolutionState
|
|
12
10
|
from griptape_nodes.exe_types.type_validator import TypeValidator
|
|
11
|
+
from griptape_nodes.machines.dag_builder import NodeState
|
|
13
12
|
from griptape_nodes.machines.fsm import FSM, State
|
|
14
13
|
from griptape_nodes.node_library.library_registry import LibraryRegistry
|
|
15
14
|
from griptape_nodes.retained_mode.events.base_events import (
|
|
@@ -19,38 +18,16 @@ from griptape_nodes.retained_mode.events.base_events import (
|
|
|
19
18
|
from griptape_nodes.retained_mode.events.execution_events import (
|
|
20
19
|
CurrentDataNodeEvent,
|
|
21
20
|
NodeResolvedEvent,
|
|
22
|
-
ParameterSpotlightEvent,
|
|
23
21
|
ParameterValueUpdateEvent,
|
|
24
22
|
)
|
|
25
23
|
from griptape_nodes.retained_mode.events.parameter_events import SetParameterValueRequest
|
|
26
24
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
"""Individual node execution states."""
|
|
32
|
-
|
|
33
|
-
QUEUED = "queued"
|
|
34
|
-
PROCESSING = "processing"
|
|
35
|
-
DONE = "done"
|
|
36
|
-
CANCELED = "canceled"
|
|
37
|
-
ERRORED = "errored"
|
|
38
|
-
WAITING = "waiting"
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
@dataclass(kw_only=True)
|
|
42
|
-
class DagNode:
|
|
43
|
-
"""Represents a node in the DAG with runtime references."""
|
|
44
|
-
|
|
45
|
-
task_reference: asyncio.Task | None = field(default=None)
|
|
46
|
-
node_state: NodeState = field(default=NodeState.WAITING)
|
|
47
|
-
node_reference: BaseNode
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from griptape_nodes.common.directed_graph import DirectedGraph
|
|
27
|
+
from griptape_nodes.machines.dag_builder import DagBuilder, DagNode
|
|
28
|
+
from griptape_nodes.retained_mode.managers.flow_manager import FlowManager
|
|
48
29
|
|
|
49
|
-
|
|
50
|
-
@dataclass
|
|
51
|
-
class Focus:
|
|
52
|
-
node: BaseNode
|
|
53
|
-
scheduled_value: Any | None = None
|
|
30
|
+
logger = logging.getLogger("griptape_nodes")
|
|
54
31
|
|
|
55
32
|
|
|
56
33
|
class WorkflowState(StrEnum):
|
|
@@ -63,175 +40,77 @@ class WorkflowState(StrEnum):
|
|
|
63
40
|
|
|
64
41
|
|
|
65
42
|
class ParallelResolutionContext:
|
|
66
|
-
focus_stack: list[Focus]
|
|
67
43
|
paused: bool
|
|
68
44
|
flow_name: str
|
|
69
|
-
build_only: bool
|
|
70
|
-
batched_nodes: list[BaseNode]
|
|
71
45
|
error_message: str | None
|
|
72
46
|
workflow_state: WorkflowState
|
|
73
|
-
#
|
|
74
|
-
network: DirectedGraph
|
|
75
|
-
node_to_reference: dict[str, DagNode]
|
|
47
|
+
# Execution fields
|
|
76
48
|
async_semaphore: asyncio.Semaphore
|
|
77
49
|
task_to_node: dict[asyncio.Task, DagNode]
|
|
50
|
+
dag_builder: DagBuilder | None
|
|
78
51
|
|
|
79
|
-
def __init__(
|
|
52
|
+
def __init__(
|
|
53
|
+
self, flow_name: str, max_nodes_in_parallel: int | None = None, dag_builder: DagBuilder | None = None
|
|
54
|
+
) -> None:
|
|
80
55
|
self.flow_name = flow_name
|
|
81
|
-
self.focus_stack = []
|
|
82
56
|
self.paused = False
|
|
83
|
-
self.build_only = False
|
|
84
|
-
self.batched_nodes = []
|
|
85
57
|
self.error_message = None
|
|
86
58
|
self.workflow_state = WorkflowState.NO_ERROR
|
|
59
|
+
self.dag_builder = dag_builder
|
|
87
60
|
|
|
88
|
-
# Initialize
|
|
89
|
-
self.network = DirectedGraph()
|
|
90
|
-
self.node_to_reference = {}
|
|
61
|
+
# Initialize execution fields
|
|
91
62
|
max_nodes_in_parallel = max_nodes_in_parallel if max_nodes_in_parallel is not None else 5
|
|
92
63
|
self.async_semaphore = asyncio.Semaphore(max_nodes_in_parallel)
|
|
93
64
|
self.task_to_node = {}
|
|
94
65
|
|
|
66
|
+
@property
|
|
67
|
+
def node_to_reference(self) -> dict[str, DagNode]:
|
|
68
|
+
"""Get node_to_reference from dag_builder if available."""
|
|
69
|
+
if not self.dag_builder:
|
|
70
|
+
msg = "DagBuilder is not initialized"
|
|
71
|
+
raise ValueError(msg)
|
|
72
|
+
return self.dag_builder.node_to_reference
|
|
73
|
+
|
|
74
|
+
@property
|
|
75
|
+
def networks(self) -> dict[str, DirectedGraph]:
|
|
76
|
+
"""Get node_to_reference from dag_builder if available."""
|
|
77
|
+
if not self.dag_builder:
|
|
78
|
+
msg = "DagBuilder is not initialized"
|
|
79
|
+
raise ValueError(msg)
|
|
80
|
+
return self.dag_builder.graphs
|
|
81
|
+
|
|
95
82
|
def reset(self, *, cancel: bool = False) -> None:
|
|
96
|
-
if self.focus_stack:
|
|
97
|
-
node = self.focus_stack[-1].node
|
|
98
|
-
node.clear_node()
|
|
99
|
-
self.focus_stack.clear()
|
|
100
83
|
self.paused = False
|
|
101
84
|
if cancel:
|
|
102
85
|
self.workflow_state = WorkflowState.CANCELED
|
|
103
|
-
|
|
104
|
-
|
|
86
|
+
# Only access node_to_reference if dag_builder exists
|
|
87
|
+
if self.dag_builder:
|
|
88
|
+
for node in self.node_to_reference.values():
|
|
89
|
+
node.node_state = NodeState.CANCELED
|
|
105
90
|
else:
|
|
106
91
|
self.workflow_state = WorkflowState.NO_ERROR
|
|
107
92
|
self.error_message = None
|
|
108
|
-
self.network.clear()
|
|
109
|
-
self.node_to_reference.clear()
|
|
110
93
|
self.task_to_node.clear()
|
|
111
94
|
|
|
95
|
+
# Clear DAG builder state to allow re-adding nodes on subsequent runs
|
|
96
|
+
if self.dag_builder:
|
|
97
|
+
self.dag_builder.clear()
|
|
112
98
|
|
|
113
|
-
class InitializeDagSpotlightState(State):
|
|
114
|
-
@staticmethod
|
|
115
|
-
async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
|
|
116
|
-
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
117
|
-
|
|
118
|
-
current_node = context.focus_stack[-1].node
|
|
119
|
-
GriptapeNodes.EventManager().put_event(
|
|
120
|
-
ExecutionGriptapeNodeEvent(
|
|
121
|
-
wrapped_event=ExecutionEvent(payload=CurrentDataNodeEvent(node_name=current_node.name))
|
|
122
|
-
)
|
|
123
|
-
)
|
|
124
|
-
if not context.paused:
|
|
125
|
-
return InitializeDagSpotlightState
|
|
126
|
-
return None
|
|
127
99
|
|
|
100
|
+
class ExecuteDagState(State):
|
|
128
101
|
@staticmethod
|
|
129
|
-
async def
|
|
130
|
-
|
|
131
|
-
return DagCompleteState
|
|
132
|
-
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
133
|
-
|
|
134
|
-
current_node = context.focus_stack[-1].node
|
|
135
|
-
if current_node.state == NodeResolutionState.UNRESOLVED:
|
|
136
|
-
GriptapeNodes.FlowManager().get_connections().unresolve_future_nodes(current_node)
|
|
137
|
-
current_node.initialize_spotlight()
|
|
138
|
-
current_node.state = NodeResolutionState.RESOLVING
|
|
139
|
-
if current_node.get_current_parameter() is None:
|
|
140
|
-
if current_node.advance_parameter():
|
|
141
|
-
return EvaluateDagParameterState
|
|
142
|
-
return BuildDagNodeState
|
|
143
|
-
return EvaluateDagParameterState
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
class EvaluateDagParameterState(State):
|
|
147
|
-
@staticmethod
|
|
148
|
-
async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
|
|
149
|
-
current_node = context.focus_stack[-1].node
|
|
150
|
-
current_parameter = current_node.get_current_parameter()
|
|
151
|
-
if current_parameter is None:
|
|
152
|
-
return BuildDagNodeState
|
|
153
|
-
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
102
|
+
async def handle_done_nodes(context: ParallelResolutionContext, done_node: DagNode, network_name: str) -> None:
|
|
103
|
+
current_node = done_node.node_reference
|
|
154
104
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
)
|
|
162
|
-
)
|
|
105
|
+
# Check if node was already resolved (shouldn't happen)
|
|
106
|
+
if current_node.state == NodeResolutionState.RESOLVED:
|
|
107
|
+
logger.error(
|
|
108
|
+
"DUPLICATE COMPLETION DETECTED: Node '%s' was already RESOLVED but handle_done_nodes was called again from network '%s'. This should not happen!",
|
|
109
|
+
current_node.name,
|
|
110
|
+
network_name,
|
|
163
111
|
)
|
|
164
|
-
|
|
165
|
-
if not context.paused:
|
|
166
|
-
return EvaluateDagParameterState
|
|
167
|
-
return None
|
|
168
|
-
|
|
169
|
-
@staticmethod
|
|
170
|
-
async def on_update(context: ParallelResolutionContext) -> type[State] | None:
|
|
171
|
-
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
172
|
-
|
|
173
|
-
current_node = context.focus_stack[-1].node
|
|
174
|
-
current_parameter = current_node.get_current_parameter()
|
|
175
|
-
connections = GriptapeNodes.FlowManager().get_connections()
|
|
176
|
-
if current_parameter is None:
|
|
177
|
-
msg = "No current parameter set."
|
|
178
|
-
raise ValueError(msg)
|
|
179
|
-
next_node = connections.get_connected_node(current_node, current_parameter)
|
|
180
|
-
if next_node:
|
|
181
|
-
next_node, _ = next_node
|
|
182
|
-
if next_node:
|
|
183
|
-
if next_node.state == NodeResolutionState.UNRESOLVED:
|
|
184
|
-
focus_stack_names = {focus.node.name for focus in context.focus_stack}
|
|
185
|
-
if next_node.name in focus_stack_names:
|
|
186
|
-
msg = f"Cycle detected between node '{current_node.name}' and '{next_node.name}'."
|
|
187
|
-
raise RuntimeError(msg)
|
|
188
|
-
context.network.add_edge(next_node.name, current_node.name)
|
|
189
|
-
context.focus_stack.append(Focus(node=next_node))
|
|
190
|
-
return InitializeDagSpotlightState
|
|
191
|
-
if next_node.state == NodeResolutionState.RESOLVED and next_node in context.batched_nodes:
|
|
192
|
-
context.network.add_edge(next_node.name, current_node.name)
|
|
193
|
-
if current_node.advance_parameter():
|
|
194
|
-
return InitializeDagSpotlightState
|
|
195
|
-
return BuildDagNodeState
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
class BuildDagNodeState(State):
|
|
199
|
-
@staticmethod
|
|
200
|
-
async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
|
|
201
|
-
current_node = context.focus_stack[-1].node
|
|
202
|
-
|
|
203
|
-
# Add the current node to the DAG
|
|
204
|
-
node_reference = DagNode(node_reference=current_node)
|
|
205
|
-
context.node_to_reference[current_node.name] = node_reference
|
|
206
|
-
# Add node name to DAG (has to be a hashable value)
|
|
207
|
-
context.network.add_node(node_for_adding=current_node.name)
|
|
208
|
-
|
|
209
|
-
if not context.paused:
|
|
210
|
-
return BuildDagNodeState
|
|
211
|
-
return None
|
|
212
|
-
|
|
213
|
-
@staticmethod
|
|
214
|
-
async def on_update(context: ParallelResolutionContext) -> type[State] | None:
|
|
215
|
-
current_node = context.focus_stack[-1].node
|
|
112
|
+
return
|
|
216
113
|
|
|
217
|
-
# Mark node as resolved for DAG building purposes
|
|
218
|
-
current_node.state = NodeResolutionState.RESOLVED
|
|
219
|
-
# Add to batched nodes
|
|
220
|
-
context.batched_nodes.append(current_node)
|
|
221
|
-
|
|
222
|
-
context.focus_stack.pop()
|
|
223
|
-
if len(context.focus_stack):
|
|
224
|
-
return EvaluateDagParameterState
|
|
225
|
-
|
|
226
|
-
if context.build_only:
|
|
227
|
-
return DagCompleteState
|
|
228
|
-
return ExecuteDagState
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
class ExecuteDagState(State):
|
|
232
|
-
@staticmethod
|
|
233
|
-
def handle_done_nodes(done_node: DagNode) -> None:
|
|
234
|
-
current_node = done_node.node_reference
|
|
235
114
|
# Publish all parameter updates.
|
|
236
115
|
current_node.state = NodeResolutionState.RESOLVED
|
|
237
116
|
# Serialization can be slow so only do it if the user wants debug details.
|
|
@@ -252,7 +131,7 @@ class ExecuteDagState(State):
|
|
|
252
131
|
data_type = ParameterTypeBuiltin.NONE.value
|
|
253
132
|
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
254
133
|
|
|
255
|
-
GriptapeNodes.EventManager().
|
|
134
|
+
await GriptapeNodes.EventManager().aput_event(
|
|
256
135
|
ExecutionGriptapeNodeEvent(
|
|
257
136
|
wrapped_event=ExecutionEvent(
|
|
258
137
|
payload=ParameterValueUpdateEvent(
|
|
@@ -272,7 +151,7 @@ class ExecuteDagState(State):
|
|
|
272
151
|
library_name = None
|
|
273
152
|
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
274
153
|
|
|
275
|
-
GriptapeNodes.EventManager().
|
|
154
|
+
await GriptapeNodes.EventManager().aput_event(
|
|
276
155
|
ExecutionGriptapeNodeEvent(
|
|
277
156
|
wrapped_event=ExecutionEvent(
|
|
278
157
|
payload=NodeResolvedEvent(
|
|
@@ -284,9 +163,104 @@ class ExecuteDagState(State):
|
|
|
284
163
|
)
|
|
285
164
|
)
|
|
286
165
|
)
|
|
166
|
+
# Now the final thing to do, is to take their directed graph and update it.
|
|
167
|
+
ExecuteDagState.get_next_control_graph(context, current_node, network_name)
|
|
168
|
+
|
|
169
|
+
@staticmethod
|
|
170
|
+
def get_next_control_graph(context: ParallelResolutionContext, node: BaseNode, network_name: str) -> None:
|
|
171
|
+
"""Get next control flow nodes and add them to the DAG graph."""
|
|
172
|
+
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
173
|
+
|
|
174
|
+
flow_manager = GriptapeNodes.FlowManager()
|
|
175
|
+
|
|
176
|
+
# Early returns for various conditions
|
|
177
|
+
if ExecuteDagState._should_skip_control_flow(context, node, network_name, flow_manager):
|
|
178
|
+
return
|
|
179
|
+
|
|
180
|
+
next_output = node.get_next_control_output()
|
|
181
|
+
if next_output is not None:
|
|
182
|
+
ExecuteDagState._process_next_control_node(context, node, next_output, network_name, flow_manager)
|
|
183
|
+
|
|
184
|
+
@staticmethod
|
|
185
|
+
def _should_skip_control_flow(
|
|
186
|
+
context: ParallelResolutionContext, node: BaseNode, network_name: str, flow_manager: FlowManager
|
|
187
|
+
) -> bool:
|
|
188
|
+
"""Check if control flow processing should be skipped."""
|
|
189
|
+
if flow_manager.global_single_node_resolution:
|
|
190
|
+
return True
|
|
191
|
+
|
|
192
|
+
if context.dag_builder is not None:
|
|
193
|
+
network = context.dag_builder.graphs.get(network_name, None)
|
|
194
|
+
if network is not None and len(network) > 0:
|
|
195
|
+
return True
|
|
196
|
+
|
|
197
|
+
if node.stop_flow:
|
|
198
|
+
node.stop_flow = False
|
|
199
|
+
return True
|
|
200
|
+
|
|
201
|
+
return False
|
|
202
|
+
|
|
203
|
+
@staticmethod
|
|
204
|
+
def _process_next_control_node(
|
|
205
|
+
context: ParallelResolutionContext,
|
|
206
|
+
node: BaseNode,
|
|
207
|
+
next_output: Parameter,
|
|
208
|
+
network_name: str,
|
|
209
|
+
flow_manager: FlowManager,
|
|
210
|
+
) -> None:
|
|
211
|
+
"""Process the next control node in the flow."""
|
|
212
|
+
node_connection = flow_manager.get_connections().get_connected_node(node, next_output)
|
|
213
|
+
if node_connection is not None:
|
|
214
|
+
next_node, _ = node_connection
|
|
215
|
+
|
|
216
|
+
# Prepare next node for execution
|
|
217
|
+
if not next_node.lock:
|
|
218
|
+
next_node.make_node_unresolved(
|
|
219
|
+
current_states_to_trigger_change_event=set(
|
|
220
|
+
{
|
|
221
|
+
NodeResolutionState.UNRESOLVED,
|
|
222
|
+
NodeResolutionState.RESOLVED,
|
|
223
|
+
NodeResolutionState.RESOLVING,
|
|
224
|
+
}
|
|
225
|
+
)
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
ExecuteDagState._add_and_queue_nodes(context, next_node, network_name)
|
|
229
|
+
|
|
230
|
+
@staticmethod
|
|
231
|
+
def _add_and_queue_nodes(context: ParallelResolutionContext, next_node: BaseNode, network_name: str) -> None:
|
|
232
|
+
"""Add nodes to DAG and queue them if ready."""
|
|
233
|
+
if context.dag_builder is not None:
|
|
234
|
+
added_nodes = context.dag_builder.add_node_with_dependencies(next_node, network_name)
|
|
235
|
+
if next_node not in added_nodes:
|
|
236
|
+
added_nodes.append(next_node)
|
|
237
|
+
|
|
238
|
+
# Queue nodes that are ready for execution
|
|
239
|
+
if added_nodes:
|
|
240
|
+
for added_node in added_nodes:
|
|
241
|
+
ExecuteDagState._try_queue_waiting_node(context, added_node.name)
|
|
287
242
|
|
|
288
243
|
@staticmethod
|
|
289
|
-
def
|
|
244
|
+
def _try_queue_waiting_node(context: ParallelResolutionContext, node_name: str) -> None:
|
|
245
|
+
"""Try to queue a specific waiting node if it can now be queued."""
|
|
246
|
+
if context.dag_builder is None:
|
|
247
|
+
logger.warning("DAG builder is None - cannot check queueing for node '%s'", node_name)
|
|
248
|
+
return
|
|
249
|
+
|
|
250
|
+
if node_name not in context.node_to_reference:
|
|
251
|
+
logger.warning("Node '%s' not found in node_to_reference - cannot check queueing", node_name)
|
|
252
|
+
return
|
|
253
|
+
|
|
254
|
+
dag_node = context.node_to_reference[node_name]
|
|
255
|
+
|
|
256
|
+
# Only check nodes that are currently waiting
|
|
257
|
+
if dag_node.node_state == NodeState.WAITING:
|
|
258
|
+
can_queue = context.dag_builder.can_queue_control_node(dag_node)
|
|
259
|
+
if can_queue:
|
|
260
|
+
dag_node.node_state = NodeState.QUEUED
|
|
261
|
+
|
|
262
|
+
@staticmethod
|
|
263
|
+
async def collect_values_from_upstream_nodes(node_reference: DagNode) -> None:
|
|
290
264
|
"""Collect output values from resolved upstream nodes and pass them to the current node.
|
|
291
265
|
|
|
292
266
|
This method iterates through all input parameters of the current node, finds their
|
|
@@ -318,76 +292,77 @@ class ExecuteDagState(State):
|
|
|
318
292
|
output_value = upstream_node.get_parameter_value(upstream_parameter.name)
|
|
319
293
|
|
|
320
294
|
# Pass the value through using the same mechanism as normal resolution
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
295
|
+
# Skip propagation for Control Parameters as they should not receive values
|
|
296
|
+
if (
|
|
297
|
+
ParameterType.attempt_get_builtin(upstream_parameter.output_type)
|
|
298
|
+
!= ParameterTypeBuiltin.CONTROL_TYPE
|
|
299
|
+
):
|
|
300
|
+
await GriptapeNodes.get_instance().ahandle_request(
|
|
301
|
+
SetParameterValueRequest(
|
|
302
|
+
parameter_name=parameter.name,
|
|
303
|
+
node_name=current_node.name,
|
|
304
|
+
value=output_value,
|
|
305
|
+
data_type=upstream_parameter.output_type,
|
|
306
|
+
incoming_connection_source_node_name=upstream_node.name,
|
|
307
|
+
incoming_connection_source_parameter_name=upstream_parameter.name,
|
|
308
|
+
)
|
|
329
309
|
)
|
|
330
|
-
)
|
|
331
|
-
|
|
332
|
-
@staticmethod
|
|
333
|
-
def clear_parameter_output_values(node_reference: DagNode) -> None:
|
|
334
|
-
"""Clear all parameter output values for the given node and publish events.
|
|
335
|
-
|
|
336
|
-
This method iterates through each parameter output value stored in the node,
|
|
337
|
-
removes it from the node's parameter_output_values dictionary, and publishes an event
|
|
338
|
-
to notify the system about the parameter value being set to None.
|
|
339
|
-
|
|
340
|
-
Args:
|
|
341
|
-
node_reference (DagOrchestrator.DagNode): The DAG node to clear values for.
|
|
342
|
-
|
|
343
|
-
Raises:
|
|
344
|
-
ValueError: If a parameter name in parameter_output_values doesn't correspond
|
|
345
|
-
to an actual parameter in the node.
|
|
346
|
-
"""
|
|
347
|
-
current_node = node_reference.node_reference
|
|
348
|
-
for parameter_name in current_node.parameter_output_values:
|
|
349
|
-
parameter = current_node.get_parameter_by_name(parameter_name)
|
|
350
|
-
if parameter is None:
|
|
351
|
-
err = f"Attempted to clear output values for node '{current_node.name}' but could not find parameter '{parameter_name}' that was indicated as having a value."
|
|
352
|
-
raise ValueError(err)
|
|
353
|
-
parameter_type = parameter.type
|
|
354
|
-
if parameter_type is None:
|
|
355
|
-
parameter_type = ParameterTypeBuiltin.NONE.value
|
|
356
|
-
payload = ParameterValueUpdateEvent(
|
|
357
|
-
node_name=current_node.name,
|
|
358
|
-
parameter_name=parameter_name,
|
|
359
|
-
data_type=parameter_type,
|
|
360
|
-
value=None,
|
|
361
|
-
)
|
|
362
|
-
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
363
|
-
|
|
364
|
-
GriptapeNodes.EventManager().put_event(
|
|
365
|
-
ExecutionGriptapeNodeEvent(wrapped_event=ExecutionEvent(payload=payload))
|
|
366
|
-
)
|
|
367
|
-
current_node.parameter_output_values.clear()
|
|
368
310
|
|
|
369
311
|
@staticmethod
|
|
370
|
-
def build_node_states(context: ParallelResolutionContext) -> tuple[
|
|
371
|
-
|
|
372
|
-
leaf_nodes =
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
312
|
+
def build_node_states(context: ParallelResolutionContext) -> tuple[set[str], set[str], set[str]]:
|
|
313
|
+
networks = context.networks
|
|
314
|
+
leaf_nodes = set()
|
|
315
|
+
for network in networks.values():
|
|
316
|
+
# Check and see if there are leaf nodes that are cancelled.
|
|
317
|
+
# Reinitialize leaf nodes since maybe we changed things up.
|
|
318
|
+
# We removed nodes from the network. There may be new leaf nodes.
|
|
319
|
+
# Add all leaf nodes from all networks (using set union to avoid duplicates)
|
|
320
|
+
leaf_nodes.update([n for n in network.nodes() if network.in_degree(n) == 0])
|
|
321
|
+
canceled_nodes = set()
|
|
322
|
+
queued_nodes = set()
|
|
376
323
|
for node in leaf_nodes:
|
|
377
324
|
node_reference = context.node_to_reference[node]
|
|
378
325
|
# If the node is locked, mark it as done so it skips execution
|
|
379
326
|
if node_reference.node_reference.lock:
|
|
380
327
|
node_reference.node_state = NodeState.DONE
|
|
381
|
-
done_nodes.append(node)
|
|
382
328
|
continue
|
|
383
329
|
node_state = node_reference.node_state
|
|
384
|
-
if node_state == NodeState.
|
|
385
|
-
|
|
386
|
-
elif node_state == NodeState.CANCELED:
|
|
387
|
-
canceled_nodes.append(node)
|
|
330
|
+
if node_state == NodeState.CANCELED:
|
|
331
|
+
canceled_nodes.add(node)
|
|
388
332
|
elif node_state == NodeState.QUEUED:
|
|
389
|
-
queued_nodes.
|
|
390
|
-
return
|
|
333
|
+
queued_nodes.add(node)
|
|
334
|
+
return canceled_nodes, queued_nodes, leaf_nodes
|
|
335
|
+
|
|
336
|
+
@staticmethod
|
|
337
|
+
async def pop_done_states(context: ParallelResolutionContext) -> None:
|
|
338
|
+
networks = context.networks
|
|
339
|
+
handled_nodes = set() # Track nodes we've already processed to avoid duplicates
|
|
340
|
+
|
|
341
|
+
for network_name, network in networks.items():
|
|
342
|
+
# Check and see if there are leaf nodes that are cancelled.
|
|
343
|
+
# Reinitialize leaf nodes since maybe we changed things up.
|
|
344
|
+
# We removed nodes from the network. There may be new leaf nodes.
|
|
345
|
+
leaf_nodes = [n for n in network.nodes() if network.in_degree(n) == 0]
|
|
346
|
+
for node in leaf_nodes:
|
|
347
|
+
node_reference = context.node_to_reference[node]
|
|
348
|
+
node_state = node_reference.node_state
|
|
349
|
+
# If the node is locked, mark it as done so it skips execution
|
|
350
|
+
if node_reference.node_reference.lock or node_state == NodeState.DONE:
|
|
351
|
+
node_reference.node_state = NodeState.DONE
|
|
352
|
+
network.remove_node(node)
|
|
353
|
+
|
|
354
|
+
# Only call handle_done_nodes once per node (first network that processes it)
|
|
355
|
+
if node not in handled_nodes:
|
|
356
|
+
handled_nodes.add(node)
|
|
357
|
+
await ExecuteDagState.handle_done_nodes(context, context.node_to_reference[node], network_name)
|
|
358
|
+
|
|
359
|
+
# After processing completions in this network, check if any remaining leaf nodes can now be queued
|
|
360
|
+
remaining_leaf_nodes = [n for n in network.nodes() if network.in_degree(n) == 0]
|
|
361
|
+
|
|
362
|
+
for leaf_node in remaining_leaf_nodes:
|
|
363
|
+
if leaf_node in context.node_to_reference:
|
|
364
|
+
node_state = context.node_to_reference[leaf_node].node_state
|
|
365
|
+
ExecuteDagState._try_queue_waiting_node(context, leaf_node)
|
|
391
366
|
|
|
392
367
|
@staticmethod
|
|
393
368
|
async def execute_node(current_node: DagNode, semaphore: asyncio.Semaphore) -> None:
|
|
@@ -397,32 +372,28 @@ class ExecuteDagState(State):
|
|
|
397
372
|
@staticmethod
|
|
398
373
|
async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
|
|
399
374
|
# Start DAG execution after resolution is complete
|
|
400
|
-
context.batched_nodes.clear()
|
|
401
375
|
for node in context.node_to_reference.values():
|
|
402
|
-
#
|
|
403
|
-
node.node_state
|
|
376
|
+
# Only queue nodes that are waiting - preserve state of already processed nodes.
|
|
377
|
+
if node.node_state == NodeState.WAITING:
|
|
378
|
+
node.node_state = NodeState.QUEUED
|
|
379
|
+
|
|
404
380
|
context.workflow_state = WorkflowState.NO_ERROR
|
|
381
|
+
|
|
405
382
|
if not context.paused:
|
|
406
383
|
return ExecuteDagState
|
|
407
384
|
return None
|
|
408
385
|
|
|
409
386
|
@staticmethod
|
|
410
|
-
async def on_update(context: ParallelResolutionContext) -> type[State] | None:
|
|
387
|
+
async def on_update(context: ParallelResolutionContext) -> type[State] | None: # noqa: C901, PLR0911
|
|
388
|
+
# Check if execution is paused
|
|
389
|
+
if context.paused:
|
|
390
|
+
return None
|
|
391
|
+
|
|
411
392
|
# Check if DAG execution is complete
|
|
412
|
-
network = context.network
|
|
413
393
|
# Check and see if there are leaf nodes that are cancelled.
|
|
414
|
-
done_nodes, canceled_nodes, queued_nodes, leaf_nodes = ExecuteDagState.build_node_states(context)
|
|
415
|
-
# Are there any nodes in Done state?
|
|
416
|
-
for node in done_nodes:
|
|
417
|
-
# We have nodes in done state.
|
|
418
|
-
# Remove the leaf node from the graph.
|
|
419
|
-
network.remove_node(node)
|
|
420
|
-
# Return thread to thread pool.
|
|
421
|
-
ExecuteDagState.handle_done_nodes(context.node_to_reference[node])
|
|
422
394
|
# Reinitialize leaf nodes since maybe we changed things up.
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
done_nodes, canceled_nodes, queued_nodes, leaf_nodes = ExecuteDagState.build_node_states(context)
|
|
395
|
+
# We removed nodes from the network. There may be new leaf nodes.
|
|
396
|
+
canceled_nodes, queued_nodes, leaf_nodes = ExecuteDagState.build_node_states(context)
|
|
426
397
|
# We have no more leaf nodes. Quit early.
|
|
427
398
|
if not leaf_nodes:
|
|
428
399
|
context.workflow_state = WorkflowState.WORKFLOW_COMPLETE
|
|
@@ -439,7 +410,7 @@ class ExecuteDagState(State):
|
|
|
439
410
|
|
|
440
411
|
# Collect parameter values from upstream nodes before executing
|
|
441
412
|
try:
|
|
442
|
-
ExecuteDagState.collect_values_from_upstream_nodes(node_reference)
|
|
413
|
+
await ExecuteDagState.collect_values_from_upstream_nodes(node_reference)
|
|
443
414
|
except Exception as e:
|
|
444
415
|
logger.exception("Error collecting parameter values for node '%s'", node_reference.node_reference.name)
|
|
445
416
|
context.error_message = (
|
|
@@ -448,25 +419,25 @@ class ExecuteDagState(State):
|
|
|
448
419
|
context.workflow_state = WorkflowState.ERRORED
|
|
449
420
|
return ErrorState
|
|
450
421
|
|
|
451
|
-
# Clear
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
)
|
|
458
|
-
context.error_message = (
|
|
459
|
-
f"Parameter clearing failed for node '{node_reference.node_reference.name}': {e}"
|
|
460
|
-
)
|
|
461
|
-
context.workflow_state = WorkflowState.ERRORED
|
|
422
|
+
# Clear all of the current output values but don't broadcast the clearing.
|
|
423
|
+
# to avoid any flickering in subscribers (UI).
|
|
424
|
+
node_reference.node_reference.parameter_output_values.silent_clear()
|
|
425
|
+
exceptions = node_reference.node_reference.validate_before_node_run()
|
|
426
|
+
if exceptions:
|
|
427
|
+
msg = f"Canceling flow run. Node '{node_reference.node_reference.name}' encountered problems: {exceptions}"
|
|
428
|
+
logger.error(msg)
|
|
462
429
|
return ErrorState
|
|
463
430
|
|
|
464
431
|
def on_task_done(task: asyncio.Task) -> None:
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
432
|
+
if task in context.task_to_node:
|
|
433
|
+
node = context.task_to_node[task]
|
|
434
|
+
node.node_state = NodeState.DONE
|
|
468
435
|
|
|
469
436
|
# Execute the node asynchronously
|
|
437
|
+
logger.debug(
|
|
438
|
+
"CREATING EXECUTION TASK for node '%s' - this should only happen once per node!",
|
|
439
|
+
node_reference.node_reference.name,
|
|
440
|
+
)
|
|
470
441
|
node_task = asyncio.create_task(ExecuteDagState.execute_node(node_reference, context.async_semaphore))
|
|
471
442
|
# Add a callback to set node to done when task has finished.
|
|
472
443
|
context.task_to_node[node_task] = node_reference
|
|
@@ -474,9 +445,23 @@ class ExecuteDagState(State):
|
|
|
474
445
|
node_task.add_done_callback(lambda t: on_task_done(t))
|
|
475
446
|
node_reference.node_state = NodeState.PROCESSING
|
|
476
447
|
node_reference.node_reference.state = NodeResolutionState.RESOLVING
|
|
448
|
+
|
|
449
|
+
# Send an event that this is a current data node:
|
|
450
|
+
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
451
|
+
|
|
452
|
+
await GriptapeNodes.EventManager().aput_event(
|
|
453
|
+
ExecutionGriptapeNodeEvent(wrapped_event=ExecutionEvent(payload=CurrentDataNodeEvent(node_name=node)))
|
|
454
|
+
)
|
|
477
455
|
# Wait for a task to finish
|
|
478
|
-
await asyncio.wait(context.task_to_node.keys(), return_when=asyncio.FIRST_COMPLETED)
|
|
456
|
+
done, _ = await asyncio.wait(context.task_to_node.keys(), return_when=asyncio.FIRST_COMPLETED)
|
|
457
|
+
# Prevent task being removed before return
|
|
458
|
+
for task in done:
|
|
459
|
+
context.task_to_node.pop(task)
|
|
479
460
|
# Once a task has finished, loop back to the top.
|
|
461
|
+
await ExecuteDagState.pop_done_states(context)
|
|
462
|
+
# Remove all nodes that are done
|
|
463
|
+
if context.paused:
|
|
464
|
+
return None
|
|
480
465
|
return ExecuteDagState
|
|
481
466
|
|
|
482
467
|
|
|
@@ -519,7 +504,7 @@ class ErrorState(State):
|
|
|
519
504
|
if len(task_to_node) == 0:
|
|
520
505
|
# Finish up. We failed.
|
|
521
506
|
context.workflow_state = WorkflowState.ERRORED
|
|
522
|
-
context.
|
|
507
|
+
context.networks.clear()
|
|
523
508
|
context.node_to_reference.clear()
|
|
524
509
|
context.task_to_node.clear()
|
|
525
510
|
return DagCompleteState
|
|
@@ -530,8 +515,9 @@ class ErrorState(State):
|
|
|
530
515
|
class DagCompleteState(State):
|
|
531
516
|
@staticmethod
|
|
532
517
|
async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
|
|
533
|
-
#
|
|
534
|
-
context.
|
|
518
|
+
# Clear the DAG builder so we don't have any leftover nodes in node_to_reference.
|
|
519
|
+
if context.dag_builder is not None:
|
|
520
|
+
context.dag_builder.clear()
|
|
535
521
|
return None
|
|
536
522
|
|
|
537
523
|
@staticmethod
|
|
@@ -542,19 +528,21 @@ class DagCompleteState(State):
|
|
|
542
528
|
class ParallelResolutionMachine(FSM[ParallelResolutionContext]):
|
|
543
529
|
"""State machine for building DAG structure without execution."""
|
|
544
530
|
|
|
545
|
-
def __init__(
|
|
546
|
-
|
|
531
|
+
def __init__(
|
|
532
|
+
self, flow_name: str, max_nodes_in_parallel: int | None = None, dag_builder: DagBuilder | None = None
|
|
533
|
+
) -> None:
|
|
534
|
+
resolution_context = ParallelResolutionContext(
|
|
535
|
+
flow_name, max_nodes_in_parallel=max_nodes_in_parallel, dag_builder=dag_builder
|
|
536
|
+
)
|
|
547
537
|
super().__init__(resolution_context)
|
|
548
538
|
|
|
549
|
-
async def resolve_node(self, node: BaseNode
|
|
550
|
-
"""
|
|
551
|
-
|
|
552
|
-
self._context.build_only = build_only
|
|
553
|
-
await self.start(InitializeDagSpotlightState)
|
|
539
|
+
async def resolve_node(self, node: BaseNode | None = None) -> None: # noqa: ARG002
|
|
540
|
+
"""Execute the DAG structure using the existing DagBuilder."""
|
|
541
|
+
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
554
542
|
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
await self.
|
|
543
|
+
if self.context.dag_builder is None:
|
|
544
|
+
self.context.dag_builder = GriptapeNodes.FlowManager().global_dag_builder
|
|
545
|
+
await self.start(ExecuteDagState)
|
|
558
546
|
|
|
559
547
|
def change_debug_mode(self, *, debug_mode: bool) -> None:
|
|
560
548
|
self._context.paused = debug_mode
|