griptape-nodes 0.52.0__py3-none-any.whl → 0.53.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. griptape_nodes/__init__.py +6 -943
  2. griptape_nodes/__main__.py +6 -0
  3. griptape_nodes/app/api.py +1 -12
  4. griptape_nodes/app/app.py +256 -209
  5. griptape_nodes/cli/__init__.py +1 -0
  6. griptape_nodes/cli/commands/__init__.py +1 -0
  7. griptape_nodes/cli/commands/config.py +71 -0
  8. griptape_nodes/cli/commands/engine.py +80 -0
  9. griptape_nodes/cli/commands/init.py +548 -0
  10. griptape_nodes/cli/commands/libraries.py +90 -0
  11. griptape_nodes/cli/commands/self.py +117 -0
  12. griptape_nodes/cli/main.py +46 -0
  13. griptape_nodes/cli/shared.py +84 -0
  14. griptape_nodes/common/__init__.py +1 -0
  15. griptape_nodes/common/directed_graph.py +55 -0
  16. griptape_nodes/drivers/storage/local_storage_driver.py +7 -2
  17. griptape_nodes/exe_types/core_types.py +60 -2
  18. griptape_nodes/exe_types/node_types.py +38 -24
  19. griptape_nodes/machines/control_flow.py +86 -22
  20. griptape_nodes/machines/fsm.py +10 -1
  21. griptape_nodes/machines/parallel_resolution.py +570 -0
  22. griptape_nodes/machines/{node_resolution.py → sequential_resolution.py} +22 -51
  23. griptape_nodes/mcp_server/server.py +1 -1
  24. griptape_nodes/retained_mode/events/base_events.py +2 -2
  25. griptape_nodes/retained_mode/events/node_events.py +4 -3
  26. griptape_nodes/retained_mode/griptape_nodes.py +25 -12
  27. griptape_nodes/retained_mode/managers/agent_manager.py +9 -5
  28. griptape_nodes/retained_mode/managers/arbitrary_code_exec_manager.py +3 -1
  29. griptape_nodes/retained_mode/managers/context_manager.py +6 -5
  30. griptape_nodes/retained_mode/managers/flow_manager.py +117 -204
  31. griptape_nodes/retained_mode/managers/library_lifecycle/library_directory.py +1 -1
  32. griptape_nodes/retained_mode/managers/library_manager.py +35 -25
  33. griptape_nodes/retained_mode/managers/node_manager.py +81 -199
  34. griptape_nodes/retained_mode/managers/object_manager.py +11 -5
  35. griptape_nodes/retained_mode/managers/os_manager.py +24 -9
  36. griptape_nodes/retained_mode/managers/secrets_manager.py +8 -4
  37. griptape_nodes/retained_mode/managers/settings.py +32 -1
  38. griptape_nodes/retained_mode/managers/static_files_manager.py +8 -3
  39. griptape_nodes/retained_mode/managers/sync_manager.py +8 -5
  40. griptape_nodes/retained_mode/managers/workflow_manager.py +110 -122
  41. griptape_nodes/traits/add_param_button.py +1 -1
  42. griptape_nodes/traits/button.py +216 -6
  43. griptape_nodes/traits/color_picker.py +66 -0
  44. griptape_nodes/traits/traits.json +4 -0
  45. {griptape_nodes-0.52.0.dist-info → griptape_nodes-0.53.0.dist-info}/METADATA +2 -1
  46. {griptape_nodes-0.52.0.dist-info → griptape_nodes-0.53.0.dist-info}/RECORD +48 -34
  47. {griptape_nodes-0.52.0.dist-info → griptape_nodes-0.53.0.dist-info}/WHEEL +0 -0
  48. {griptape_nodes-0.52.0.dist-info → griptape_nodes-0.53.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,570 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ from dataclasses import dataclass, field
6
+ from enum import StrEnum
7
+ from typing import Any
8
+
9
+ from griptape_nodes.common.directed_graph import DirectedGraph
10
+ from griptape_nodes.exe_types.core_types import ParameterTypeBuiltin
11
+ from griptape_nodes.exe_types.node_types import BaseNode, NodeResolutionState
12
+ from griptape_nodes.exe_types.type_validator import TypeValidator
13
+ from griptape_nodes.machines.fsm import FSM, State
14
+ from griptape_nodes.node_library.library_registry import LibraryRegistry
15
+ from griptape_nodes.retained_mode.events.base_events import (
16
+ ExecutionEvent,
17
+ ExecutionGriptapeNodeEvent,
18
+ )
19
+ from griptape_nodes.retained_mode.events.execution_events import (
20
+ CurrentDataNodeEvent,
21
+ NodeResolvedEvent,
22
+ ParameterSpotlightEvent,
23
+ ParameterValueUpdateEvent,
24
+ )
25
+ from griptape_nodes.retained_mode.events.parameter_events import SetParameterValueRequest
26
+
27
# Module logger: uses the package-wide "griptape_nodes" logger name rather than this module's __name__.
logger: logging.Logger = logging.getLogger(name="griptape_nodes")
28
+
29
+
30
+ class NodeState(StrEnum):
31
+ """Individual node execution states."""
32
+
33
+ QUEUED = "queued"
34
+ PROCESSING = "processing"
35
+ DONE = "done"
36
+ CANCELED = "canceled"
37
+ ERRORED = "errored"
38
+ WAITING = "waiting"
39
+
40
+
41
@dataclass(kw_only=True)
class DagNode:
    """Runtime record pairing a graph node with its execution bookkeeping."""

    # asyncio task executing this node; None until a task has been created for it.
    task_reference: asyncio.Task | None = None
    # Current execution status; new records start out WAITING.
    node_state: NodeState = NodeState.WAITING
    # The node this record tracks (required, keyword-only).
    node_reference: BaseNode
48
+
49
+
50
@dataclass
class Focus:
    """A single entry on the resolution focus stack."""

    # The node being examined while this entry is on the stack.
    node: BaseNode
    # Optional value attached to the entry; None unless supplied.
    scheduled_value: Any | None = field(default=None)
54
+
55
+
56
+ class WorkflowState(StrEnum):
57
+ """Workflow execution states."""
58
+
59
+ NO_ERROR = "no_error"
60
+ WORKFLOW_COMPLETE = "workflow_complete"
61
+ ERRORED = "errored"
62
+ CANCELED = "canceled"
63
+
64
+
65
class ParallelResolutionContext:
    """Mutable state shared by every state of the parallel resolution machine.

    Holds the focus stack walked while building the dependency DAG, and the
    DAG plus the task bookkeeping used to execute its nodes concurrently.
    """

    focus_stack: list[Focus]
    paused: bool
    flow_name: str
    build_only: bool
    batched_nodes: list[BaseNode]
    error_message: str | None
    workflow_state: WorkflowState
    # DAG fields moved from DagOrchestrator
    network: DirectedGraph
    node_to_reference: dict[str, DagNode]
    async_semaphore: asyncio.Semaphore
    task_to_node: dict[asyncio.Task, DagNode]

    def __init__(self, flow_name: str, max_nodes_in_parallel: int | None = None) -> None:
        """Create an idle context for ``flow_name``.

        Args:
            flow_name: Name of the flow being resolved.
            max_nodes_in_parallel: Maximum number of nodes executing at once
                (enforced by ``async_semaphore``); 5 when None.
        """
        # Build-phase and control state.
        self.flow_name = flow_name
        self.focus_stack = []
        self.paused = False
        self.build_only = False
        self.batched_nodes = []
        self.error_message = None
        self.workflow_state = WorkflowState.NO_ERROR

        # Execution-phase (DAG) state.
        self.network = DirectedGraph()
        self.node_to_reference = {}
        parallel_limit = 5 if max_nodes_in_parallel is None else max_nodes_in_parallel
        self.async_semaphore = asyncio.Semaphore(parallel_limit)
        self.task_to_node = {}

    def reset(self, *, cancel: bool = False) -> None:
        """Return the context to an idle state.

        The node on top of the focus stack (if any) is cleared, the stack is
        emptied and the context is unpaused.

        Args:
            cancel: When True, mark the workflow and every tracked DAG node as
                CANCELED while keeping the DAG bookkeeping. When False, clear the
                error, the DAG and the node/task mappings.
        """
        if self.focus_stack:
            self.focus_stack[-1].node.clear_node()
        self.focus_stack.clear()
        self.paused = False

        if not cancel:
            self.workflow_state = WorkflowState.NO_ERROR
            self.error_message = None
            self.network.clear()
            self.node_to_reference.clear()
            self.task_to_node.clear()
            return

        # NOTE(review): tasks still in task_to_node are not cancelled here; confirm callers do that.
        self.workflow_state = WorkflowState.CANCELED
        for dag_node in self.node_to_reference.values():
            dag_node.node_state = NodeState.CANCELED
111
+
112
+
113
class InitializeDagSpotlightState(State):
    """Announce the node on top of the focus stack and prepare it for parameter evaluation."""

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        focused = context.focus_stack[-1].node
        announcement = CurrentDataNodeEvent(node_name=focused.name)
        GriptapeNodes.EventManager().put_event(
            ExecutionGriptapeNodeEvent(wrapped_event=ExecutionEvent(payload=announcement))
        )
        # While the context is paused, stop here and wait for an explicit step.
        return None if context.paused else InitializeDagSpotlightState

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        if not context.focus_stack:
            return DagCompleteState
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        focused = context.focus_stack[-1].node
        if focused.state == NodeResolutionState.UNRESOLVED:
            connections = GriptapeNodes.FlowManager().get_connections()
            connections.unresolve_future_nodes(focused)
            focused.initialize_spotlight()
        focused.state = NodeResolutionState.RESOLVING

        # Evaluate the current parameter, or advance to the first one; with no
        # parameters at all the node goes straight to DAG registration.
        has_parameter = focused.get_current_parameter() is not None
        if has_parameter or focused.advance_parameter():
            return EvaluateDagParameterState
        return BuildDagNodeState
144
+
145
+
146
class EvaluateDagParameterState(State):
    """Inspect the focused node's current parameter and follow its upstream connection."""

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        focused = context.focus_stack[-1].node
        parameter = focused.get_current_parameter()
        if parameter is None:
            return BuildDagNodeState
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        spotlight = ParameterSpotlightEvent(
            node_name=focused.name,
            parameter_name=parameter.name,
        )
        GriptapeNodes.EventManager().put_event(
            ExecutionGriptapeNodeEvent(wrapped_event=ExecutionEvent(payload=spotlight))
        )
        return None if context.paused else EvaluateDagParameterState

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        focused = context.focus_stack[-1].node
        parameter = focused.get_current_parameter()
        connections = GriptapeNodes.FlowManager().get_connections()
        if parameter is None:
            msg = "No current parameter set."
            raise ValueError(msg)

        connection = connections.get_connected_node(focused, parameter)
        upstream = None
        if connection:
            upstream, _ = connection

        if upstream:
            if upstream.state == NodeResolutionState.UNRESOLVED:
                # NOTE(review): this cycle check only runs for UNRESOLVED upstream nodes. Nodes already
                # on the focus stack have presumably been set to RESOLVING by InitializeDagSpotlightState,
                # so a dependency back onto the stack is skipped silently (no edge, no error). Confirm intended.
                if any(entry.node.name == upstream.name for entry in context.focus_stack):
                    msg = f"Cycle detected between node '{focused.name}' and '{upstream.name}'."
                    raise RuntimeError(msg)
                # Edge points upstream -> downstream: the focused node depends on `upstream`.
                context.network.add_edge(upstream.name, focused.name)
                context.focus_stack.append(Focus(node=upstream))
                return InitializeDagSpotlightState
            if upstream.state == NodeResolutionState.RESOLVED and upstream in context.batched_nodes:
                # Upstream was resolved as part of this batch, so it is in the DAG: record the dependency.
                context.network.add_edge(upstream.name, focused.name)

        return InitializeDagSpotlightState if focused.advance_parameter() else BuildDagNodeState
196
+
197
+
198
class BuildDagNodeState(State):
    """Register the focused node in the DAG, then pop it off the focus stack."""

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        focused = context.focus_stack[-1].node
        context.node_to_reference[focused.name] = DagNode(node_reference=focused)
        # The graph stores node names because its entries must be hashable.
        context.network.add_node(node_for_adding=focused.name)
        return None if context.paused else BuildDagNodeState

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        finished = context.focus_stack[-1].node
        # Treat the node as resolved for the purposes of DAG construction.
        finished.state = NodeResolutionState.RESOLVED
        context.batched_nodes.append(finished)
        context.focus_stack.pop()

        if context.focus_stack:
            # Resume evaluating parameters of the node that pushed this one as an upstream dependency.
            return EvaluateDagParameterState
        return DagCompleteState if context.build_only else ExecuteDagState
229
+
230
+
231
class ExecuteDagState(State):
    """Run the built DAG, executing each node once all of its upstream nodes have finished.

    Each update:
    1. Removes finished nodes from ``context.network`` and publishes their outputs.
    2. Starts an asyncio task for each queued node that has no remaining incoming edges.
    3. Waits until at least one in-flight task completes.

    Concurrency is bounded by ``context.async_semaphore``. A task that raises moves
    its node to ERRORED and the machine to ErrorState.
    """

    @staticmethod
    def handle_done_nodes(done_node: DagNode) -> None:
        """Mark a finished node RESOLVED and publish its outputs.

        Emits one ``ParameterValueUpdateEvent`` per entry in the node's
        ``parameter_output_values``, followed by a ``NodeResolvedEvent``.

        Args:
            done_node: DAG record of the node that finished executing.

        Raises:
            KeyError: If an output value names a parameter the node does not have.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        current_node = done_node.node_reference
        current_node.state = NodeResolutionState.RESOLVED
        # Serialization can be slow, so only do it when DEBUG records will actually be emitted.
        # isEnabledFor() checks the effective level. Comparing ``logger.level`` directly is true
        # for a NOTSET (0) logger even when its ancestors filter DEBUG out.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "INPUTS: %s\nOUTPUTS: %s",
                TypeValidator.safe_serialize(current_node.parameter_values),
                TypeValidator.safe_serialize(current_node.parameter_output_values),
            )

        # Publish all parameter updates.
        for parameter_name, value in current_node.parameter_output_values.items():
            parameter = current_node.get_parameter_by_name(parameter_name)
            if parameter is None:
                err = f"Canceling flow run. Node '{current_node.name}' specified a Parameter '{parameter_name}', but no such Parameter could be found on that Node."
                raise KeyError(err)
            data_type = parameter.type
            if data_type is None:
                data_type = ParameterTypeBuiltin.NONE.value
            GriptapeNodes.EventManager().put_event(
                ExecutionGriptapeNodeEvent(
                    wrapped_event=ExecutionEvent(
                        payload=ParameterValueUpdateEvent(
                            node_name=current_node.name,
                            parameter_name=parameter_name,
                            data_type=data_type,
                            value=TypeValidator.safe_serialize(value),
                        )
                    ),
                )
            )
        # Output values should already be saved!
        # Only report a library name when exactly one library provides this node type.
        library = LibraryRegistry.get_libraries_with_node_type(current_node.__class__.__name__)
        library_name = library[0] if len(library) == 1 else None
        GriptapeNodes.EventManager().put_event(
            ExecutionGriptapeNodeEvent(
                wrapped_event=ExecutionEvent(
                    payload=NodeResolvedEvent(
                        node_name=current_node.name,
                        parameter_output_values=TypeValidator.safe_serialize(current_node.parameter_output_values),
                        node_type=current_node.__class__.__name__,
                        specific_library_name=library_name,
                    )
                )
            )
        )

    @staticmethod
    def collect_values_from_upstream_nodes(node_reference: DagNode) -> None:
        """Copy values from connected upstream nodes into the node's input parameters.

        For every non-control parameter with an incoming connection, take the
        upstream node's output value for the connected parameter if it has one,
        otherwise that parameter's current value. Pass it in through a
        ``SetParameterValueRequest``, the same path normal resolution uses.

        Args:
            node_reference: DAG record of the node about to execute.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        current_node = node_reference.node_reference
        connections = GriptapeNodes.FlowManager().get_connections()

        for parameter in current_node.parameters:
            # Skip control type parameters
            if ParameterTypeBuiltin.CONTROL_TYPE.value.lower() == parameter.output_type:
                continue

            upstream_connection = connections.get_connected_node(current_node, parameter)
            if not upstream_connection:
                continue
            upstream_node, upstream_parameter = upstream_connection

            # Prefer a freshly produced output value; fall back to the parameter's stored value.
            if upstream_parameter.name in upstream_node.parameter_output_values:
                output_value = upstream_node.parameter_output_values[upstream_parameter.name]
            else:
                output_value = upstream_node.get_parameter_value(upstream_parameter.name)

            GriptapeNodes.get_instance().handle_request(
                SetParameterValueRequest(
                    parameter_name=parameter.name,
                    node_name=current_node.name,
                    value=output_value,
                    data_type=upstream_parameter.output_type,
                    incoming_connection_source_node_name=upstream_node.name,
                    incoming_connection_source_parameter_name=upstream_parameter.name,
                )
            )

    @staticmethod
    def clear_parameter_output_values(node_reference: DagNode) -> None:
        """Clear the node's output values, publishing a None update for each.

        Args:
            node_reference: DAG record of the node whose outputs are cleared.

        Raises:
            ValueError: If an output value names a parameter the node does not have.
        """
        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

        current_node = node_reference.node_reference
        for parameter_name in current_node.parameter_output_values:
            parameter = current_node.get_parameter_by_name(parameter_name)
            if parameter is None:
                err = f"Attempted to clear output values for node '{current_node.name}' but could not find parameter '{parameter_name}' that was indicated as having a value."
                raise ValueError(err)
            parameter_type = parameter.type
            if parameter_type is None:
                parameter_type = ParameterTypeBuiltin.NONE.value
            payload = ParameterValueUpdateEvent(
                node_name=current_node.name,
                parameter_name=parameter_name,
                data_type=parameter_type,
                value=None,
            )
            GriptapeNodes.EventManager().put_event(
                ExecutionGriptapeNodeEvent(wrapped_event=ExecutionEvent(payload=payload))
            )
        current_node.parameter_output_values.clear()

    @staticmethod
    def build_node_states(context: ParallelResolutionContext) -> tuple[list[str], list[str], list[str], list[str]]:
        """Classify the DAG nodes that have no incoming edges ("leaf" nodes here).

        Locked nodes are forced to DONE so they skip execution.

        Returns:
            ``(done_nodes, canceled_nodes, queued_nodes, leaf_nodes)`` as lists of node names.
        """
        network = context.network
        leaf_nodes = [n for n in network.nodes() if network.in_degree(n) == 0]
        done_nodes = []
        canceled_nodes = []
        queued_nodes = []
        for node in leaf_nodes:
            node_reference = context.node_to_reference[node]
            # If the node is locked, mark it as done so it skips execution
            if node_reference.node_reference.lock:
                node_reference.node_state = NodeState.DONE
                done_nodes.append(node)
                continue
            node_state = node_reference.node_state
            if node_state == NodeState.DONE:
                done_nodes.append(node)
            elif node_state == NodeState.CANCELED:
                canceled_nodes.append(node)
            elif node_state == NodeState.QUEUED:
                queued_nodes.append(node)
        return done_nodes, canceled_nodes, queued_nodes, leaf_nodes

    @staticmethod
    async def execute_node(current_node: DagNode, semaphore: asyncio.Semaphore) -> None:
        """Run the node's ``aprocess()`` while holding a slot of ``semaphore``."""
        async with semaphore:
            await current_node.node_reference.aprocess()

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        # Start DAG execution after resolution is complete
        context.batched_nodes.clear()
        for node in context.node_to_reference.values():
            # We have a DAG. Flag all nodes in DAG as queued. Workflow state is NO_ERROR
            node.node_state = NodeState.QUEUED
        context.workflow_state = WorkflowState.NO_ERROR
        if not context.paused:
            return ExecuteDagState
        return None

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        network = context.network

        def on_task_done(task: asyncio.Task) -> None:
            """Record the outcome of a finished node task on its DAG record."""
            dag_node = context.task_to_node.pop(task, None)
            if dag_node is None:
                # The mapping was already cleared (e.g. by a reset) before this task finished.
                return
            name = dag_node.node_reference.name
            # cancelled() must be checked first: exception() raises CancelledError on a cancelled task.
            if task.cancelled():
                dag_node.node_state = NodeState.CANCELED
                logger.info("Task canceled: %s", name)
                return
            exc = task.exception()
            if exc is not None:
                # Retrieving the exception also stops asyncio from reporting it as never retrieved.
                dag_node.node_state = NodeState.ERRORED
                context.error_message = f"Node '{name}' failed during execution: {exc}"
                logger.error("Task failed: %s", name, exc_info=exc)
                return
            dag_node.node_state = NodeState.DONE
            logger.info("Task done: %s", name)

        done_nodes, canceled_nodes, queued_nodes, leaf_nodes = ExecuteDagState.build_node_states(context)
        # Retire finished nodes. Removing them can expose new edge-free nodes that already count as
        # done (locked nodes), so repeat until none remain. Otherwise we could later wait on an empty set.
        while done_nodes:
            for node in done_nodes:
                network.remove_node(node)
                ExecuteDagState.handle_done_nodes(context.node_to_reference[node])
            done_nodes, canceled_nodes, queued_nodes, leaf_nodes = ExecuteDagState.build_node_states(context)

        # Nothing left in the graph: every node has finished.
        if not leaf_nodes:
            context.workflow_state = WorkflowState.WORKFLOW_COMPLETE
            return DagCompleteState

        # A node whose task raised stays in the graph as ERRORED: stop the run.
        errored_nodes = [
            node for node in leaf_nodes if context.node_to_reference[node].node_state == NodeState.ERRORED
        ]
        if errored_nodes:
            if not context.error_message:
                context.error_message = f"Node execution failed for: {', '.join(errored_nodes)}"
            context.workflow_state = WorkflowState.ERRORED
            return ErrorState

        if len(canceled_nodes) == len(leaf_nodes):
            # All leaf nodes are cancelled.
            context.workflow_state = WorkflowState.CANCELED
            return DagCompleteState

        # Start every queued node that is ready; the semaphore enforces the concurrency limit.
        for node in queued_nodes:
            node_reference = context.node_to_reference[node]

            # Collect parameter values from upstream nodes before executing
            try:
                ExecuteDagState.collect_values_from_upstream_nodes(node_reference)
            except Exception as e:
                logger.exception("Error collecting parameter values for node '%s'", node_reference.node_reference.name)
                context.error_message = (
                    f"Parameter passthrough failed for node '{node_reference.node_reference.name}': {e}"
                )
                context.workflow_state = WorkflowState.ERRORED
                return ErrorState

            # Clear parameter output values before execution
            try:
                ExecuteDagState.clear_parameter_output_values(node_reference)
            except Exception as e:
                logger.exception(
                    "Error clearing parameter output values for node '%s'", node_reference.node_reference.name
                )
                context.error_message = (
                    f"Parameter clearing failed for node '{node_reference.node_reference.name}': {e}"
                )
                context.workflow_state = WorkflowState.ERRORED
                return ErrorState

            node_task = asyncio.create_task(ExecuteDagState.execute_node(node_reference, context.async_semaphore))
            context.task_to_node[node_task] = node_reference
            node_reference.task_reference = node_task
            node_task.add_done_callback(on_task_done)
            node_reference.node_state = NodeState.PROCESSING
            node_reference.node_reference.state = NodeResolutionState.RESOLVING

        if not context.task_to_node:
            # Nodes remain but none is running or runnable. asyncio.wait() would raise
            # ValueError on an empty set, and no further progress is possible.
            context.error_message = (
                f"DAG execution for flow '{context.flow_name}' stalled: no node is running or ready to run."
            )
            context.workflow_state = WorkflowState.ERRORED
            return ErrorState

        # Wait for a task to finish, then loop back to the top.
        await asyncio.wait(list(context.task_to_node), return_when=asyncio.FIRST_COMPLETED)
        return ExecuteDagState
481
+
482
+
483
class ErrorState(State):
    """Wind down DAG execution after a failure.

    On entry, nodes that never started are marked CANCELED and every unfinished
    task is asked to cancel. Each update then records the outcome of finished
    tasks and waits for the rest. Once none remain, the machine moves to
    DagCompleteState with the workflow marked ERRORED and the DAG bookkeeping cleared.
    """

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        if context.error_message:
            logger.error("DAG execution error: %s", context.error_message)
        for dag_node in context.node_to_reference.values():
            # Cancel all nodes that haven't yet begun processing.
            if dag_node.node_state == NodeState.QUEUED:
                dag_node.node_state = NodeState.CANCELED
        # Request cancellation of every unfinished task. asyncio raises CancelledError
        # inside a task at its next await, so already-running nodes are interrupted too.
        for task in list(context.task_to_node):
            if not task.done():
                task.cancel()
        return ErrorState

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:
        task_to_node = context.task_to_node
        # Iterate over a snapshot because finished tasks are removed from the mapping.
        for task, dag_node in list(task_to_node.items()):
            if not task.done():
                continue
            # cancelled() implies done(), so it has to be tested before anything else.
            if task.cancelled():
                dag_node.node_state = NodeState.CANCELED
            elif task.exception() is not None:
                dag_node.node_state = NodeState.ERRORED
            else:
                dag_node.node_state = NodeState.DONE
            task_to_node.pop(task)

        if task_to_node:
            # Yield to the event loop so the cancelled tasks can unwind, then check again.
            await asyncio.wait(list(task_to_node), return_when=asyncio.FIRST_COMPLETED)
            return ErrorState

        # Finish up. We failed.
        context.workflow_state = WorkflowState.ERRORED
        context.network.clear()
        context.node_to_reference.clear()
        context.task_to_node.clear()
        return DagCompleteState
528
+
529
+
530
class DagCompleteState(State):
    """Terminal state: nothing further happens once the machine arrives here."""

    @staticmethod
    async def on_enter(context: ParallelResolutionContext) -> type[State] | None:
        # Build-only mode never carries over past a finished run.
        context.build_only = False
        return None

    @staticmethod
    async def on_update(context: ParallelResolutionContext) -> type[State] | None:  # noqa: ARG004
        # Stay put; there is no follow-up state.
        return None
540
+
541
+
542
class ParallelResolutionMachine(FSM[ParallelResolutionContext]):
    """State machine that builds a node's dependency DAG and, unless build-only, executes it."""

    def __init__(self, flow_name: str, max_nodes_in_parallel: int | None = None) -> None:
        super().__init__(ParallelResolutionContext(flow_name, max_nodes_in_parallel=max_nodes_in_parallel))

    async def resolve_node(self, node: BaseNode, *, build_only: bool = False) -> None:
        """Build the DAG rooted at ``node`` and, unless ``build_only``, execute it.

        Args:
            node: Node whose upstream dependencies should be resolved.
            build_only: When True, stop after the DAG has been built.
        """
        ctx = self._context
        ctx.focus_stack.append(Focus(node=node))
        ctx.build_only = build_only
        await self.start(InitializeDagSpotlightState)

    async def build_dag_for_node(self, node: BaseNode) -> None:
        """Deprecated alias for ``resolve_node(node)``."""
        await self.resolve_node(node)

    def change_debug_mode(self, *, debug_mode: bool) -> None:
        """Pause (debug mode on) or unpause the per-state progression."""
        self._context.paused = debug_mode

    def is_complete(self) -> bool:
        """True once the machine has reached DagCompleteState."""
        return self._current_state is DagCompleteState

    def is_started(self) -> bool:
        """True while the machine has a current state."""
        return self._current_state is not None

    def reset_machine(self, *, cancel: bool = False) -> None:
        """Reset the shared context (see ``ParallelResolutionContext.reset``) and drop the current state."""
        self._context.reset(cancel=cancel)
        self._current_state = None