griptape-nodes 0.52.0__py3-none-any.whl → 0.53.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- griptape_nodes/__init__.py +6 -943
- griptape_nodes/__main__.py +6 -0
- griptape_nodes/app/api.py +1 -12
- griptape_nodes/app/app.py +256 -209
- griptape_nodes/cli/__init__.py +1 -0
- griptape_nodes/cli/commands/__init__.py +1 -0
- griptape_nodes/cli/commands/config.py +71 -0
- griptape_nodes/cli/commands/engine.py +80 -0
- griptape_nodes/cli/commands/init.py +548 -0
- griptape_nodes/cli/commands/libraries.py +90 -0
- griptape_nodes/cli/commands/self.py +117 -0
- griptape_nodes/cli/main.py +46 -0
- griptape_nodes/cli/shared.py +84 -0
- griptape_nodes/common/__init__.py +1 -0
- griptape_nodes/common/directed_graph.py +55 -0
- griptape_nodes/drivers/storage/local_storage_driver.py +7 -2
- griptape_nodes/exe_types/core_types.py +60 -2
- griptape_nodes/exe_types/node_types.py +38 -24
- griptape_nodes/machines/control_flow.py +86 -22
- griptape_nodes/machines/fsm.py +10 -1
- griptape_nodes/machines/parallel_resolution.py +570 -0
- griptape_nodes/machines/{node_resolution.py → sequential_resolution.py} +22 -51
- griptape_nodes/mcp_server/server.py +1 -1
- griptape_nodes/retained_mode/events/base_events.py +2 -2
- griptape_nodes/retained_mode/events/node_events.py +4 -3
- griptape_nodes/retained_mode/griptape_nodes.py +25 -12
- griptape_nodes/retained_mode/managers/agent_manager.py +9 -5
- griptape_nodes/retained_mode/managers/arbitrary_code_exec_manager.py +3 -1
- griptape_nodes/retained_mode/managers/context_manager.py +6 -5
- griptape_nodes/retained_mode/managers/flow_manager.py +117 -204
- griptape_nodes/retained_mode/managers/library_lifecycle/library_directory.py +1 -1
- griptape_nodes/retained_mode/managers/library_manager.py +35 -25
- griptape_nodes/retained_mode/managers/node_manager.py +81 -199
- griptape_nodes/retained_mode/managers/object_manager.py +11 -5
- griptape_nodes/retained_mode/managers/os_manager.py +24 -9
- griptape_nodes/retained_mode/managers/secrets_manager.py +8 -4
- griptape_nodes/retained_mode/managers/settings.py +32 -1
- griptape_nodes/retained_mode/managers/static_files_manager.py +8 -3
- griptape_nodes/retained_mode/managers/sync_manager.py +8 -5
- griptape_nodes/retained_mode/managers/workflow_manager.py +110 -122
- griptape_nodes/traits/add_param_button.py +1 -1
- griptape_nodes/traits/button.py +216 -6
- griptape_nodes/traits/color_picker.py +66 -0
- griptape_nodes/traits/traits.json +4 -0
- {griptape_nodes-0.52.0.dist-info → griptape_nodes-0.53.0.dist-info}/METADATA +2 -1
- {griptape_nodes-0.52.0.dist-info → griptape_nodes-0.53.0.dist-info}/RECORD +48 -34
- {griptape_nodes-0.52.0.dist-info → griptape_nodes-0.53.0.dist-info}/WHEEL +0 -0
- {griptape_nodes-0.52.0.dist-info → griptape_nodes-0.53.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,570 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from enum import StrEnum
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from griptape_nodes.common.directed_graph import DirectedGraph
|
|
10
|
+
from griptape_nodes.exe_types.core_types import ParameterTypeBuiltin
|
|
11
|
+
from griptape_nodes.exe_types.node_types import BaseNode, NodeResolutionState
|
|
12
|
+
from griptape_nodes.exe_types.type_validator import TypeValidator
|
|
13
|
+
from griptape_nodes.machines.fsm import FSM, State
|
|
14
|
+
from griptape_nodes.node_library.library_registry import LibraryRegistry
|
|
15
|
+
from griptape_nodes.retained_mode.events.base_events import (
|
|
16
|
+
ExecutionEvent,
|
|
17
|
+
ExecutionGriptapeNodeEvent,
|
|
18
|
+
)
|
|
19
|
+
from griptape_nodes.retained_mode.events.execution_events import (
|
|
20
|
+
CurrentDataNodeEvent,
|
|
21
|
+
NodeResolvedEvent,
|
|
22
|
+
ParameterSpotlightEvent,
|
|
23
|
+
ParameterValueUpdateEvent,
|
|
24
|
+
)
|
|
25
|
+
from griptape_nodes.retained_mode.events.parameter_events import SetParameterValueRequest
|
|
26
|
+
|
|
27
|
+
# Module logger. It uses the shared "griptape_nodes" logger name rather than __name__, so its
# level and handlers are whatever that package-wide logger is configured with.
logger = logging.getLogger(name="griptape_nodes")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class NodeState(StrEnum):
|
|
31
|
+
"""Individual node execution states."""
|
|
32
|
+
|
|
33
|
+
QUEUED = "queued"
|
|
34
|
+
PROCESSING = "processing"
|
|
35
|
+
DONE = "done"
|
|
36
|
+
CANCELED = "canceled"
|
|
37
|
+
ERRORED = "errored"
|
|
38
|
+
WAITING = "waiting"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass(kw_only=True)
class DagNode:
    """A vertex of the execution DAG: a graph node plus its runtime bookkeeping.

    Attributes:
        task_reference: The asyncio task executing the node, once one has been scheduled.
        node_state: Where the node currently is in the DAG execution lifecycle.
        node_reference: The graph node this entry wraps (required keyword argument).
    """

    task_reference: asyncio.Task | None = None
    node_state: NodeState = NodeState.WAITING
    node_reference: BaseNode
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class Focus:
    """One entry of the resolution focus stack: a node whose dependencies are being explored."""

    node: BaseNode
    # Optional payload carried alongside the node; the DAG states in this module never read it.
    scheduled_value: Any | None = field(default=None)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class WorkflowState(StrEnum):
|
|
57
|
+
"""Workflow execution states."""
|
|
58
|
+
|
|
59
|
+
NO_ERROR = "no_error"
|
|
60
|
+
WORKFLOW_COMPLETE = "workflow_complete"
|
|
61
|
+
ERRORED = "errored"
|
|
62
|
+
CANCELED = "canceled"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ParallelResolutionContext:
    """Mutable state shared by every state of the parallel resolution machine.

    Holds the focus stack used while building the DAG, the DAG itself (node names as
    vertices), the per-node DagNode records, and the asyncio bookkeeping used while
    executing nodes.
    """

    focus_stack: list[Focus]
    paused: bool
    flow_name: str
    build_only: bool
    batched_nodes: list[BaseNode]
    error_message: str | None
    workflow_state: WorkflowState
    # DAG fields moved from DagOrchestrator
    network: DirectedGraph
    node_to_reference: dict[str, DagNode]
    async_semaphore: asyncio.Semaphore
    task_to_node: dict[asyncio.Task, DagNode]

    def __init__(self, flow_name: str, max_nodes_in_parallel: int | None = None) -> None:
        """Create an empty context for the flow named ``flow_name``.

        Args:
            flow_name: Name of the flow being resolved.
            max_nodes_in_parallel: Maximum number of node tasks allowed to execute at once.
                ``None`` selects the default of 5.

        Raises:
            ValueError: If ``max_nodes_in_parallel`` is less than 1. A semaphore with zero
                permits would block every node task forever, so the DAG could never finish.
        """
        if max_nodes_in_parallel is None:
            max_nodes_in_parallel = 5
        elif max_nodes_in_parallel < 1:
            msg = f"max_nodes_in_parallel must be at least 1, got {max_nodes_in_parallel}."
            raise ValueError(msg)

        self.flow_name = flow_name
        self.focus_stack = []
        self.paused = False
        self.build_only = False
        self.batched_nodes = []
        self.error_message = None
        self.workflow_state = WorkflowState.NO_ERROR

        # Initialize DAG fields
        self.network = DirectedGraph()
        self.node_to_reference = {}
        # Bounds how many node tasks may run their aprocess() concurrently.
        self.async_semaphore = asyncio.Semaphore(max_nodes_in_parallel)
        self.task_to_node = {}

    def reset(self, *, cancel: bool = False) -> None:
        """Return the context to an idle state.

        The node on top of the focus stack has ``clear_node()`` called on it, the focus
        stack is emptied and the context is unpaused.

        Args:
            cancel: When True, mark the workflow and every tracked DagNode as canceled but
                keep the DAG, node records and task map. When False, clear the error message
                and discard all DAG bookkeeping.
        """
        if self.focus_stack:
            # Clear the node that was being resolved when the reset happened.
            self.focus_stack[-1].node.clear_node()
        self.focus_stack.clear()
        self.paused = False
        if cancel:
            self.workflow_state = WorkflowState.CANCELED
            for dag_node in self.node_to_reference.values():
                dag_node.node_state = NodeState.CANCELED
        else:
            self.workflow_state = WorkflowState.NO_ERROR
            self.error_message = None
            self.network.clear()
            self.node_to_reference.clear()
            self.task_to_node.clear()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class InitializeDagSpotlightState(State):
    """Bring the node on top of the focus stack into the spotlight and select its first parameter."""

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        """Announce the focused node via a CurrentDataNodeEvent; return None while paused."""
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        focused = context.focus_stack[-1].node
        event_manager = GriptapeNodes.EventManager()
        payload = CurrentDataNodeEvent(node_name=focused.name)
        event_manager.put_event(ExecutionGriptapeNodeEvent(wrapped_event=ExecutionEvent(payload=payload)))
        return None if context.paused else InitializeDagSpotlightState

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        """Prepare the focused node for parameter evaluation, or finish when the stack is empty."""
        if not context.focus_stack:
            return DagCompleteState
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        focused = context.focus_stack[-1].node
        if focused.state == NodeResolutionState.UNRESOLVED:
            # First visit to this node: call unresolve_future_nodes on it and set up its spotlight.
            GriptapeNodes.FlowManager().get_connections().unresolve_future_nodes(focused)
            focused.initialize_spotlight()
        focused.state = NodeResolutionState.RESOLVING

        has_parameter = focused.get_current_parameter() is not None
        if has_parameter or focused.advance_parameter():
            # Either a parameter is already selected or we just advanced to one.
            return EvaluateDagParameterState
        # No parameters left to inspect: the node can be added to the DAG.
        return BuildDagNodeState
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class EvaluateDagParameterState(State):
    """Follow the spotlighted parameter upstream, growing the DAG one dependency at a time."""

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        """Announce the parameter under evaluation; go straight to building if there is none."""
        focused = context.focus_stack[-1].node
        parameter = focused.get_current_parameter()
        if parameter is None:
            return BuildDagNodeState
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        event_manager = GriptapeNodes.EventManager()
        spotlight = ParameterSpotlightEvent(node_name=focused.name, parameter_name=parameter.name)
        event_manager.put_event(ExecutionGriptapeNodeEvent(wrapped_event=ExecutionEvent(payload=spotlight)))
        if context.paused:
            return None
        return EvaluateDagParameterState

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        """Descend into an unresolved upstream node, or record an edge and move to the next parameter.

        Raises:
            ValueError: If the focused node has no current parameter.
            RuntimeError: If the upstream node is already on the focus stack.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        focused = context.focus_stack[-1].node
        parameter = focused.get_current_parameter()
        connections = GriptapeNodes.FlowManager().get_connections()
        if parameter is None:
            msg = "No current parameter set."
            raise ValueError(msg)

        upstream = connections.get_connected_node(focused, parameter)
        if upstream:
            # A connection comes back as a (node, parameter) pair; only the node matters here.
            upstream, _ = upstream

        if upstream and upstream.state == NodeResolutionState.UNRESOLVED:
            names_in_focus = {entry.node.name for entry in context.focus_stack}
            if upstream.name in names_in_focus:
                msg = f"Cycle detected between node '{focused.name}' and '{upstream.name}'."
                raise RuntimeError(msg)
            # Edges run from the producing node to the node that consumes its value.
            context.network.add_edge(upstream.name, focused.name)
            # Descend into the dependency before continuing with this node.
            context.focus_stack.append(Focus(node=upstream))
            return InitializeDagSpotlightState

        if upstream and upstream.state == NodeResolutionState.RESOLVED and upstream in context.batched_nodes:
            # Already part of this batch: just record the dependency edge.
            context.network.add_edge(upstream.name, focused.name)

        return InitializeDagSpotlightState if focused.advance_parameter() else BuildDagNodeState
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class BuildDagNodeState(State):
    """Register the focused node as a DAG vertex, then pop it off the focus stack."""

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        """Create the node's DagNode record and add its name as a vertex of the DAG."""
        focused = context.focus_stack[-1].node
        context.node_to_reference[focused.name] = DagNode(node_reference=focused)
        # The graph stores node names because its vertices must be hashable.
        context.network.add_node(node_for_adding=focused.name)
        return None if context.paused else BuildDagNodeState

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        """Mark the node batched and decide where DAG construction continues."""
        focused = context.focus_stack[-1].node
        # Treat the node as resolved for DAG-building purposes only; execution happens later.
        focused.state = NodeResolutionState.RESOLVED
        context.batched_nodes.append(focused)
        context.focus_stack.pop()

        if context.focus_stack:
            # Resume evaluating the parameter of the node that depended on this one.
            return EvaluateDagParameterState
        return DagCompleteState if context.build_only else ExecuteDagState
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class ExecuteDagState(State):
    """Execute the DAG built by the earlier states, running ready nodes concurrently.

    Each ``on_update`` call retires finished leaf nodes, schedules every queued leaf node as
    an asyncio task (concurrency bounded by ``context.async_semaphore``) and then waits until
    at least one running task finishes.
    """

    @staticmethod
    def handle_done_nodes(done_node: DagNode) -> None:
        """Mark a finished node as resolved and publish its results.

        Emits one ParameterValueUpdateEvent per entry in ``parameter_output_values``, then a
        NodeResolvedEvent for the node.

        Args:
            done_node: DAG entry of the node that finished executing (or was skipped because
                it is locked).

        Raises:
            KeyError: If an output value names a parameter the node does not have.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        current_node = done_node.node_reference
        current_node.state = NodeResolutionState.RESOLVED
        # Serialization can be slow, so only do it if the user wants debug details.
        # isEnabledFor() honors the effective (inherited) level. Comparing logger.level directly
        # is always true while the level is NOTSET (0), which would serialize for every node.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "INPUTS: %s\nOUTPUTS: %s",
                TypeValidator.safe_serialize(current_node.parameter_values),
                TypeValidator.safe_serialize(current_node.parameter_output_values),
            )

        # Publish all parameter updates.
        for parameter_name, value in current_node.parameter_output_values.items():
            parameter = current_node.get_parameter_by_name(parameter_name)
            if parameter is None:
                err = f"Canceling flow run. Node '{current_node.name}' specified a Parameter '{parameter_name}', but no such Parameter could be found on that Node."
                raise KeyError(err)
            data_type = parameter.type
            if data_type is None:
                data_type = ParameterTypeBuiltin.NONE.value
            GriptapeNodes.EventManager().put_event(
                ExecutionGriptapeNodeEvent(
                    wrapped_event=ExecutionEvent(
                        payload=ParameterValueUpdateEvent(
                            node_name=current_node.name,
                            parameter_name=parameter_name,
                            data_type=data_type,
                            value=TypeValidator.safe_serialize(value),
                        )
                    ),
                )
            )
        # Output values should already be saved!
        library = LibraryRegistry.get_libraries_with_node_type(current_node.__class__.__name__)
        # Only attribute the node to a library when exactly one library provides its type.
        library_name = library[0] if len(library) == 1 else None
        GriptapeNodes.EventManager().put_event(
            ExecutionGriptapeNodeEvent(
                wrapped_event=ExecutionEvent(
                    payload=NodeResolvedEvent(
                        node_name=current_node.name,
                        parameter_output_values=TypeValidator.safe_serialize(current_node.parameter_output_values),
                        node_type=current_node.__class__.__name__,
                        specific_library_name=library_name,
                    )
                )
            )
        )

    @staticmethod
    def collect_values_from_upstream_nodes(node_reference: DagNode) -> None:
        """Collect output values from upstream nodes and pass them to the current node.

        For every non-control parameter that has an incoming connection, take the upstream
        parameter's output value (falling back to its regular parameter value when no output
        was produced) and set it on this node with a SetParameterValueRequest.

        Args:
            node_reference: The DAG entry of the node to collect values for.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        current_node = node_reference.node_reference
        connections = GriptapeNodes.FlowManager().get_connections()

        for parameter in current_node.parameters:
            # Skip control type parameters
            if ParameterTypeBuiltin.CONTROL_TYPE.value.lower() == parameter.output_type:
                continue

            # Get the connected upstream node for this parameter
            upstream_connection = connections.get_connected_node(current_node, parameter)
            if upstream_connection:
                upstream_node, upstream_parameter = upstream_connection

                # Prefer the value the upstream node produced; otherwise use its parameter value.
                if upstream_parameter.name in upstream_node.parameter_output_values:
                    output_value = upstream_node.parameter_output_values[upstream_parameter.name]
                else:
                    output_value = upstream_node.get_parameter_value(upstream_parameter.name)

                # Pass the value through using the same mechanism as normal resolution
                GriptapeNodes.get_instance().handle_request(
                    SetParameterValueRequest(
                        parameter_name=parameter.name,
                        node_name=current_node.name,
                        value=output_value,
                        data_type=upstream_parameter.output_type,
                        incoming_connection_source_node_name=upstream_node.name,
                        incoming_connection_source_parameter_name=upstream_parameter.name,
                    )
                )

    @staticmethod
    def clear_parameter_output_values(node_reference: DagNode) -> None:
        """Clear all parameter output values for the given node and publish events.

        A ParameterValueUpdateEvent with value None is published for each stored output
        value, then the node's parameter_output_values dictionary is cleared.

        Args:
            node_reference: The DAG node to clear values for.

        Raises:
            ValueError: If a parameter name in parameter_output_values doesn't correspond
                to an actual parameter in the node.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        current_node = node_reference.node_reference
        for parameter_name in current_node.parameter_output_values:
            parameter = current_node.get_parameter_by_name(parameter_name)
            if parameter is None:
                err = f"Attempted to clear output values for node '{current_node.name}' but could not find parameter '{parameter_name}' that was indicated as having a value."
                raise ValueError(err)
            parameter_type = parameter.type
            if parameter_type is None:
                parameter_type = ParameterTypeBuiltin.NONE.value
            payload = ParameterValueUpdateEvent(
                node_name=current_node.name,
                parameter_name=parameter_name,
                data_type=parameter_type,
                value=None,
            )
            GriptapeNodes.EventManager().put_event(
                ExecutionGriptapeNodeEvent(wrapped_event=ExecutionEvent(payload=payload))
            )
        current_node.parameter_output_values.clear()

    @staticmethod
    def build_node_states(context: ParallelResolutionContext) -> tuple[list[str], list[str], list[str], list[str]]:
        """Classify the DAG's current leaf nodes (vertices with no incoming edges left).

        Locked leaf nodes are marked DONE here so they skip execution.

        Returns:
            ``(done_nodes, canceled_nodes, queued_nodes, leaf_nodes)`` as lists of node names.
            Leaves in any other state appear only in ``leaf_nodes``.
        """
        network = context.network
        leaf_nodes = [n for n in network.nodes() if network.in_degree(n) == 0]
        done_nodes = []
        canceled_nodes = []
        queued_nodes = []
        for node in leaf_nodes:
            node_reference = context.node_to_reference[node]
            # If the node is locked, mark it as done so it skips execution
            if node_reference.node_reference.lock:
                node_reference.node_state = NodeState.DONE
                done_nodes.append(node)
                continue
            node_state = node_reference.node_state
            if node_state == NodeState.DONE:
                done_nodes.append(node)
            elif node_state == NodeState.CANCELED:
                canceled_nodes.append(node)
            elif node_state == NodeState.QUEUED:
                queued_nodes.append(node)
        return done_nodes, canceled_nodes, queued_nodes, leaf_nodes

    @staticmethod
    async def execute_node(current_node: DagNode, semaphore: asyncio.Semaphore) -> None:
        """Run the node's aprocess() while holding one permit of ``semaphore``."""
        async with semaphore:
            await current_node.node_reference.aprocess()

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        """Queue every DAG node for execution; return None while paused."""
        # Start DAG execution after resolution is complete
        context.batched_nodes.clear()
        for node in context.node_to_reference.values():
            # We have a DAG. Flag all nodes in DAG as queued. Workflow state is NO_ERROR
            node.node_state = NodeState.QUEUED
        context.workflow_state = WorkflowState.NO_ERROR
        if not context.paused:
            return ExecuteDagState
        return None

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        """Advance DAG execution by one step.

        Returns:
            DagCompleteState when the DAG is exhausted or every leaf was canceled,
            ErrorState on a failure, otherwise ExecuteDagState to keep going.
        """
        network = context.network

        # Retire finished leaf nodes. Removing a node can expose new leaves that are already
        # done (locked nodes are marked DONE as soon as they become leaves), so repeat until no
        # done leaves remain. Otherwise we could end up waiting on an empty set of tasks below.
        while True:
            done_nodes, canceled_nodes, queued_nodes, leaf_nodes = ExecuteDagState.build_node_states(context)
            if not done_nodes:
                break
            for node in done_nodes:
                network.remove_node(node)
                ExecuteDagState.handle_done_nodes(context.node_to_reference[node])

        # A node whose task raised must stop the run. Its dependents would otherwise execute
        # without the values it was supposed to produce.
        for node in leaf_nodes:
            dag_node = context.node_to_reference[node]
            if dag_node.node_state != NodeState.ERRORED:
                continue
            task = dag_node.task_reference
            cause = task.exception() if task is not None and task.done() and not task.cancelled() else None
            context.error_message = f"Task execution failed for node '{dag_node.node_reference.name}': {cause}"
            context.workflow_state = WorkflowState.ERRORED
            return ErrorState

        # We have no more leaf nodes. Quit early.
        if not leaf_nodes:
            context.workflow_state = WorkflowState.WORKFLOW_COMPLETE
            return DagCompleteState
        if len(canceled_nodes) == len(leaf_nodes):
            # All leaf nodes are cancelled.
            context.workflow_state = WorkflowState.CANCELED
            return DagCompleteState

        def on_task_done(task: asyncio.Task) -> None:
            """Record the outcome of a finished node task on its DagNode."""
            dag_node = context.task_to_node.pop(task, None)
            if dag_node is None:
                # Already reaped elsewhere (for example by ErrorState) or the context was reset.
                return
            node_name = dag_node.node_reference.name
            if task.cancelled():
                dag_node.node_state = NodeState.CANCELED
                logger.info("Task canceled: %s", node_name)
                return
            # Calling exception() also marks it retrieved, so asyncio won't warn about it later.
            exc = task.exception()
            if exc is not None:
                dag_node.node_state = NodeState.ERRORED
                logger.error("Task failed: %s", node_name, exc_info=exc)
                return
            dag_node.node_state = NodeState.DONE
            logger.info("Task done: %s", node_name)

        # Schedule every queued leaf. The async semaphore enforces the concurrency limit.
        for node in queued_nodes:
            node_reference = context.node_to_reference[node]

            # Collect parameter values from upstream nodes before executing
            try:
                ExecuteDagState.collect_values_from_upstream_nodes(node_reference)
            except Exception as e:
                logger.exception("Error collecting parameter values for node '%s'", node_reference.node_reference.name)
                context.error_message = (
                    f"Parameter passthrough failed for node '{node_reference.node_reference.name}': {e}"
                )
                context.workflow_state = WorkflowState.ERRORED
                return ErrorState

            # Clear parameter output values before execution
            try:
                ExecuteDagState.clear_parameter_output_values(node_reference)
            except Exception as e:
                logger.exception(
                    "Error clearing parameter output values for node '%s'", node_reference.node_reference.name
                )
                context.error_message = (
                    f"Parameter clearing failed for node '{node_reference.node_reference.name}': {e}"
                )
                context.workflow_state = WorkflowState.ERRORED
                return ErrorState

            # Execute the node asynchronously; the callback records its outcome when it finishes.
            node_task = asyncio.create_task(ExecuteDagState.execute_node(node_reference, context.async_semaphore))
            context.task_to_node[node_task] = node_reference
            node_reference.task_reference = node_task
            node_task.add_done_callback(on_task_done)
            node_reference.node_state = NodeState.PROCESSING
            node_reference.node_reference.state = NodeResolutionState.RESOLVING

        if not context.task_to_node:
            # Leaves remain, yet none is done, queued or running, so nothing can make progress.
            # (asyncio.wait would also raise ValueError on an empty set.)
            context.error_message = "DAG execution stalled: leaf nodes remain but none are queued or running."
            context.workflow_state = WorkflowState.ERRORED
            return ErrorState

        # Wait for a task to finish, then loop back to the top.
        await asyncio.wait(list(context.task_to_node), return_when=asyncio.FIRST_COMPLETED)
        return ExecuteDagState
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
class ErrorState(State):
    """Abort DAG execution: cancel outstanding work, then wait for in-flight tasks to settle."""

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        """Log the error, cancel queued nodes and request cancellation of unfinished tasks."""
        if context.error_message:
            logger.error("DAG execution error: %s", context.error_message)
        for node in context.node_to_reference.values():
            # Cancel all nodes that haven't yet begun processing.
            if node.node_state == NodeState.QUEUED:
                node.node_state = NodeState.CANCELED
        # Request cancellation of every task that has not finished, including ones that are
        # executing right now. asyncio raises CancelledError at the task's next await point,
        # although a task may choose to suppress it.
        for task in list(context.task_to_node.keys()):
            if not task.done():
                task.cancel()
        return ErrorState

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        """Reap settled tasks; finish with DagCompleteState once none remain.

        Returns:
            ErrorState while tasks are still outstanding, otherwise DagCompleteState with the
            workflow marked ERRORED and all DAG bookkeeping cleared.
        """
        task_to_node = context.task_to_node
        if task_to_node:
            # Yield to the event loop until at least one task settles. Cancelled tasks only
            # finish when the loop gets to run them, and this also lets their done-callbacks
            # run before we reap below.
            await asyncio.wait(list(task_to_node), return_when=asyncio.FIRST_COMPLETED)

        # Reap settled tasks; iterate over a snapshot because entries are removed as we go.
        for task, node in list(task_to_node.items()):
            if not task.done():
                continue
            # A cancelled task also reports done(), so cancelled() must be checked explicitly.
            node.node_state = NodeState.CANCELED if task.cancelled() else NodeState.DONE
            task_to_node.pop(task, None)

        if task_to_node:
            # Let's continue going through until everything is cancelled.
            return ErrorState
        # Finish up. We failed.
        context.workflow_state = WorkflowState.ERRORED
        context.network.clear()
        context.node_to_reference.clear()
        context.task_to_node.clear()
        return DagCompleteState
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
class DagCompleteState(State):
    """Terminal state, reached after building/executing the DAG, on cancellation, or after an error."""

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        """Reset the per-run build_only flag; there is no follow-up state."""
        # resolve_node sets build_only for each run, so clear it once the run is over.
        context.build_only = False
        return None

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:  # noqa: ARG004
        """Nothing left to do; remain in this state."""
        return None
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
class ParallelResolutionMachine(FSM[ParallelResolutionContext]):
    """State machine that builds the dependency DAG for a node and, unless build-only, executes it in parallel."""

    def __init__(self, flow_name: str, max_nodes_in_parallel: int | None = None) -> None:
        """Create the machine with a fresh ParallelResolutionContext for ``flow_name``."""
        super().__init__(ParallelResolutionContext(flow_name, max_nodes_in_parallel=max_nodes_in_parallel))

    async def resolve_node(self, node: BaseNode, *, build_only: bool = False) -> None:
        """Build the DAG rooted at ``node`` and, unless ``build_only`` is set, execute it.

        Args:
            node: The node whose upstream dependencies should be resolved.
            build_only: When True, stop once the DAG has been built, without executing nodes.
        """
        ctx = self._context
        ctx.focus_stack.append(Focus(node=node))
        ctx.build_only = build_only
        await self.start(InitializeDagSpotlightState)

    async def build_dag_for_node(self, node: BaseNode) -> None:
        """Deprecated alias for ``resolve_node(node)`` (which also executes the DAG)."""
        await self.resolve_node(node)

    def change_debug_mode(self, *, debug_mode: bool) -> None:
        """Set the context's paused flag; while set, the states' on_enter return None."""
        self._context.paused = debug_mode

    def is_complete(self) -> bool:
        """Return True once the machine has reached DagCompleteState."""
        return self._current_state is DagCompleteState

    def is_started(self) -> bool:
        """Return True if the machine currently has a state."""
        return self._current_state is not None

    def reset_machine(self, *, cancel: bool = False) -> None:
        """Reset the context (optionally as a cancellation) and forget the current state."""
        self._context.reset(cancel=cancel)
        self._current_state = None
|