griptape-nodes 0.64.10__py3-none-any.whl → 0.65.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- griptape_nodes/app/app.py +25 -5
- griptape_nodes/cli/commands/init.py +65 -54
- griptape_nodes/cli/commands/libraries.py +92 -85
- griptape_nodes/cli/commands/self.py +121 -0
- griptape_nodes/common/node_executor.py +2142 -101
- griptape_nodes/exe_types/base_iterative_nodes.py +1004 -0
- griptape_nodes/exe_types/connections.py +114 -19
- griptape_nodes/exe_types/core_types.py +225 -7
- griptape_nodes/exe_types/flow.py +3 -3
- griptape_nodes/exe_types/node_types.py +681 -225
- griptape_nodes/exe_types/param_components/README.md +414 -0
- griptape_nodes/exe_types/param_components/api_key_provider_parameter.py +200 -0
- griptape_nodes/exe_types/param_components/huggingface/huggingface_model_parameter.py +2 -0
- griptape_nodes/exe_types/param_components/huggingface/huggingface_repo_file_parameter.py +79 -5
- griptape_nodes/exe_types/param_types/parameter_button.py +443 -0
- griptape_nodes/machines/control_flow.py +77 -38
- griptape_nodes/machines/dag_builder.py +148 -70
- griptape_nodes/machines/parallel_resolution.py +61 -35
- griptape_nodes/machines/sequential_resolution.py +11 -113
- griptape_nodes/retained_mode/events/app_events.py +1 -0
- griptape_nodes/retained_mode/events/base_events.py +16 -13
- griptape_nodes/retained_mode/events/connection_events.py +3 -0
- griptape_nodes/retained_mode/events/execution_events.py +35 -0
- griptape_nodes/retained_mode/events/flow_events.py +15 -2
- griptape_nodes/retained_mode/events/library_events.py +347 -0
- griptape_nodes/retained_mode/events/node_events.py +48 -0
- griptape_nodes/retained_mode/events/os_events.py +86 -3
- griptape_nodes/retained_mode/events/project_events.py +15 -1
- griptape_nodes/retained_mode/events/workflow_events.py +48 -1
- griptape_nodes/retained_mode/griptape_nodes.py +6 -2
- griptape_nodes/retained_mode/managers/config_manager.py +10 -8
- griptape_nodes/retained_mode/managers/event_manager.py +168 -0
- griptape_nodes/retained_mode/managers/fitness_problems/libraries/__init__.py +2 -0
- griptape_nodes/retained_mode/managers/fitness_problems/libraries/old_xdg_location_warning_problem.py +43 -0
- griptape_nodes/retained_mode/managers/flow_manager.py +664 -123
- griptape_nodes/retained_mode/managers/library_manager.py +1143 -139
- griptape_nodes/retained_mode/managers/model_manager.py +2 -3
- griptape_nodes/retained_mode/managers/node_manager.py +148 -25
- griptape_nodes/retained_mode/managers/object_manager.py +3 -1
- griptape_nodes/retained_mode/managers/operation_manager.py +3 -1
- griptape_nodes/retained_mode/managers/os_manager.py +1158 -122
- griptape_nodes/retained_mode/managers/secrets_manager.py +2 -3
- griptape_nodes/retained_mode/managers/settings.py +21 -1
- griptape_nodes/retained_mode/managers/sync_manager.py +2 -3
- griptape_nodes/retained_mode/managers/workflow_manager.py +358 -104
- griptape_nodes/retained_mode/retained_mode.py +3 -3
- griptape_nodes/traits/button.py +44 -2
- griptape_nodes/traits/file_system_picker.py +2 -2
- griptape_nodes/utils/file_utils.py +101 -0
- griptape_nodes/utils/git_utils.py +1226 -0
- griptape_nodes/utils/library_utils.py +122 -0
- {griptape_nodes-0.64.10.dist-info → griptape_nodes-0.65.0.dist-info}/METADATA +2 -1
- {griptape_nodes-0.64.10.dist-info → griptape_nodes-0.65.0.dist-info}/RECORD +55 -47
- {griptape_nodes-0.64.10.dist-info → griptape_nodes-0.65.0.dist-info}/WHEEL +1 -1
- {griptape_nodes-0.64.10.dist-info → griptape_nodes-0.65.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import ast
|
|
4
|
+
import asyncio
|
|
4
5
|
import logging
|
|
5
6
|
import pickle
|
|
6
7
|
from dataclasses import dataclass
|
|
@@ -9,6 +10,10 @@ from typing import TYPE_CHECKING, Any, NamedTuple
|
|
|
9
10
|
|
|
10
11
|
from griptape_nodes.bootstrap.workflow_publishers.subprocess_workflow_publisher import SubprocessWorkflowPublisher
|
|
11
12
|
from griptape_nodes.drivers.storage.storage_backend import StorageBackend
|
|
13
|
+
from griptape_nodes.exe_types.base_iterative_nodes import (
|
|
14
|
+
BaseIterativeEndNode,
|
|
15
|
+
BaseIterativeStartNode,
|
|
16
|
+
)
|
|
12
17
|
from griptape_nodes.exe_types.core_types import ParameterTypeBuiltin
|
|
13
18
|
from griptape_nodes.exe_types.node_types import (
|
|
14
19
|
CONTROL_INPUT_PARAMETER,
|
|
@@ -17,25 +22,82 @@ from griptape_nodes.exe_types.node_types import (
|
|
|
17
22
|
BaseNode,
|
|
18
23
|
EndNode,
|
|
19
24
|
NodeGroupNode,
|
|
25
|
+
NodeResolutionState,
|
|
20
26
|
StartNode,
|
|
21
27
|
)
|
|
28
|
+
from griptape_nodes.machines.dag_builder import DagBuilder
|
|
22
29
|
from griptape_nodes.node_library.library_registry import Library, LibraryRegistry
|
|
23
30
|
from griptape_nodes.node_library.workflow_registry import WorkflowRegistry
|
|
31
|
+
from griptape_nodes.retained_mode.events.agent_events import AgentStreamEvent
|
|
32
|
+
from griptape_nodes.retained_mode.events.base_events import ProgressEvent
|
|
33
|
+
from griptape_nodes.retained_mode.events.connection_events import (
|
|
34
|
+
CreateConnectionResultFailure,
|
|
35
|
+
CreateConnectionResultSuccess,
|
|
36
|
+
ListConnectionsForNodeRequest,
|
|
37
|
+
ListConnectionsForNodeResultSuccess,
|
|
38
|
+
)
|
|
39
|
+
from griptape_nodes.retained_mode.events.execution_events import (
|
|
40
|
+
ControlFlowCancelledEvent,
|
|
41
|
+
ControlFlowResolvedEvent,
|
|
42
|
+
CurrentControlNodeEvent,
|
|
43
|
+
CurrentDataNodeEvent,
|
|
44
|
+
GriptapeEvent,
|
|
45
|
+
InvolvedNodesEvent,
|
|
46
|
+
NodeFinishProcessEvent,
|
|
47
|
+
NodeResolvedEvent,
|
|
48
|
+
NodeStartProcessEvent,
|
|
49
|
+
NodeUnresolvedEvent,
|
|
50
|
+
ParameterSpotlightEvent,
|
|
51
|
+
ParameterValueUpdateEvent,
|
|
52
|
+
SelectedControlOutputEvent,
|
|
53
|
+
StartLocalSubflowRequest,
|
|
54
|
+
StartLocalSubflowResultFailure,
|
|
55
|
+
StartLocalSubflowResultSuccess,
|
|
56
|
+
)
|
|
24
57
|
from griptape_nodes.retained_mode.events.flow_events import (
|
|
58
|
+
CreateFlowResultFailure,
|
|
59
|
+
CreateFlowResultSuccess,
|
|
60
|
+
DeleteFlowRequest,
|
|
61
|
+
DeleteFlowResultFailure,
|
|
62
|
+
DeleteFlowResultSuccess,
|
|
63
|
+
DeserializeFlowFromCommandsRequest,
|
|
64
|
+
DeserializeFlowFromCommandsResultFailure,
|
|
65
|
+
DeserializeFlowFromCommandsResultSuccess,
|
|
66
|
+
PackagedNodeParameterMapping,
|
|
25
67
|
PackageNodesAsSerializedFlowRequest,
|
|
26
68
|
PackageNodesAsSerializedFlowResultSuccess,
|
|
27
69
|
)
|
|
70
|
+
from griptape_nodes.retained_mode.events.node_events import (
|
|
71
|
+
DeserializeNodeFromCommandsResultFailure,
|
|
72
|
+
DeserializeNodeFromCommandsResultSuccess,
|
|
73
|
+
SetLockNodeStateResultFailure,
|
|
74
|
+
SetLockNodeStateResultSuccess,
|
|
75
|
+
)
|
|
76
|
+
from griptape_nodes.retained_mode.events.parameter_events import (
|
|
77
|
+
AlterElementEvent,
|
|
78
|
+
RemoveElementEvent,
|
|
79
|
+
SetParameterValueRequest,
|
|
80
|
+
SetParameterValueResultFailure,
|
|
81
|
+
SetParameterValueResultSuccess,
|
|
82
|
+
)
|
|
28
83
|
from griptape_nodes.retained_mode.events.workflow_events import (
|
|
29
84
|
DeleteWorkflowRequest,
|
|
30
85
|
DeleteWorkflowResultFailure,
|
|
86
|
+
ImportWorkflowAsReferencedSubFlowResultFailure,
|
|
87
|
+
ImportWorkflowAsReferencedSubFlowResultSuccess,
|
|
31
88
|
LoadWorkflowMetadata,
|
|
32
89
|
LoadWorkflowMetadataResultSuccess,
|
|
90
|
+
PublishWorkflowProgressEvent,
|
|
33
91
|
PublishWorkflowRegisteredEventData,
|
|
34
92
|
PublishWorkflowRequest,
|
|
35
93
|
SaveWorkflowFileFromSerializedFlowRequest,
|
|
36
94
|
SaveWorkflowFileFromSerializedFlowResultSuccess,
|
|
37
95
|
)
|
|
38
96
|
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
97
|
+
from griptape_nodes.retained_mode.managers.event_manager import (
|
|
98
|
+
EventSuppressionContext,
|
|
99
|
+
EventTranslationContext,
|
|
100
|
+
)
|
|
39
101
|
|
|
40
102
|
if TYPE_CHECKING:
|
|
41
103
|
from griptape_nodes.retained_mode.events.node_events import SerializedNodeCommands
|
|
@@ -43,6 +105,46 @@ if TYPE_CHECKING:
|
|
|
43
105
|
|
|
44
106
|
logger = logging.getLogger("griptape_nodes")
|
|
45
107
|
|
|
108
|
+
LOOP_EVENTS_TO_SUPPRESS = {
|
|
109
|
+
CreateFlowResultSuccess,
|
|
110
|
+
CreateFlowResultFailure,
|
|
111
|
+
ImportWorkflowAsReferencedSubFlowResultSuccess,
|
|
112
|
+
ImportWorkflowAsReferencedSubFlowResultFailure,
|
|
113
|
+
DeserializeNodeFromCommandsResultSuccess,
|
|
114
|
+
DeserializeNodeFromCommandsResultFailure,
|
|
115
|
+
CreateConnectionResultSuccess,
|
|
116
|
+
CreateConnectionResultFailure,
|
|
117
|
+
SetParameterValueResultSuccess,
|
|
118
|
+
SetParameterValueResultFailure,
|
|
119
|
+
SetLockNodeStateResultSuccess,
|
|
120
|
+
SetLockNodeStateResultFailure,
|
|
121
|
+
DeserializeFlowFromCommandsResultSuccess,
|
|
122
|
+
DeserializeFlowFromCommandsResultFailure,
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
EXECUTION_EVENTS_TO_SUPPRESS = {
|
|
126
|
+
CurrentControlNodeEvent,
|
|
127
|
+
CurrentDataNodeEvent,
|
|
128
|
+
SelectedControlOutputEvent,
|
|
129
|
+
ParameterSpotlightEvent,
|
|
130
|
+
ControlFlowResolvedEvent,
|
|
131
|
+
ControlFlowCancelledEvent,
|
|
132
|
+
NodeResolvedEvent,
|
|
133
|
+
ParameterValueUpdateEvent,
|
|
134
|
+
NodeUnresolvedEvent,
|
|
135
|
+
NodeStartProcessEvent,
|
|
136
|
+
NodeFinishProcessEvent,
|
|
137
|
+
InvolvedNodesEvent,
|
|
138
|
+
GriptapeEvent,
|
|
139
|
+
PublishWorkflowProgressEvent,
|
|
140
|
+
AgentStreamEvent,
|
|
141
|
+
AlterElementEvent,
|
|
142
|
+
RemoveElementEvent,
|
|
143
|
+
StartLocalSubflowResultSuccess,
|
|
144
|
+
StartLocalSubflowResultFailure,
|
|
145
|
+
ProgressEvent,
|
|
146
|
+
}
|
|
147
|
+
|
|
46
148
|
|
|
47
149
|
@dataclass
|
|
48
150
|
class PublishWorkflowStartEndNodes:
|
|
@@ -61,6 +163,21 @@ class PublishLocalWorkflowResult(NamedTuple):
|
|
|
61
163
|
package_result: PackageNodesAsSerializedFlowResultSuccess
|
|
62
164
|
|
|
63
165
|
|
|
166
|
+
class EntryNodeParameter(NamedTuple):
|
|
167
|
+
"""Entry node and Entry Parameter."""
|
|
168
|
+
|
|
169
|
+
entry_node: str | None
|
|
170
|
+
entry_parameter: str | None
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class LoopBodyNodes(NamedTuple):
|
|
174
|
+
"""Result of collecting loop body nodes."""
|
|
175
|
+
|
|
176
|
+
all_nodes: set[str]
|
|
177
|
+
execution_type: str
|
|
178
|
+
node_group_name: str | None
|
|
179
|
+
|
|
180
|
+
|
|
64
181
|
class NodeExecutor:
|
|
65
182
|
"""Singleton executor that executes nodes dynamically."""
|
|
66
183
|
|
|
@@ -93,7 +210,13 @@ class NodeExecutor:
|
|
|
93
210
|
# If it isn't Local or Private, it must be a library name. We'll try to execute it, and if the library name doesn't exist, it'll raise an error.
|
|
94
211
|
await self._execute_library_workflow(node, execution_type)
|
|
95
212
|
return
|
|
96
|
-
|
|
213
|
+
|
|
214
|
+
# Handle iterative loop nodes - check if we need to package and execute the loop
|
|
215
|
+
if isinstance(node, BaseIterativeEndNode):
|
|
216
|
+
await self.handle_loop_execution(node)
|
|
217
|
+
return
|
|
218
|
+
|
|
219
|
+
# We default to local execution if it is not a NodeGroupNode or BaseIterativeEndNode!
|
|
97
220
|
await node.aprocess()
|
|
98
221
|
|
|
99
222
|
async def _execute_and_apply_workflow(
|
|
@@ -301,6 +424,9 @@ class NodeExecutor:
|
|
|
301
424
|
if len(node_names) == 0:
|
|
302
425
|
return None
|
|
303
426
|
|
|
427
|
+
# Pass node_group_name if we're packaging a NodeGroupNode
|
|
428
|
+
node_group_name = node.name if isinstance(node, NodeGroupNode) else None
|
|
429
|
+
|
|
304
430
|
request = PackageNodesAsSerializedFlowRequest(
|
|
305
431
|
node_names=node_names,
|
|
306
432
|
start_node_type=workflow_start_end_nodes.start_flow_node_type,
|
|
@@ -310,6 +436,7 @@ class NodeExecutor:
|
|
|
310
436
|
output_parameter_prefix=output_parameter_prefix,
|
|
311
437
|
entry_control_node_name=None,
|
|
312
438
|
entry_control_parameter_name=None,
|
|
439
|
+
node_group_name=node_group_name,
|
|
313
440
|
)
|
|
314
441
|
package_result = GriptapeNodes.handle_request(request)
|
|
315
442
|
if not isinstance(package_result, PackageNodesAsSerializedFlowResultSuccess):
|
|
@@ -362,6 +489,7 @@ class NodeExecutor:
|
|
|
362
489
|
published_workflow_filename: Path,
|
|
363
490
|
file_name: str,
|
|
364
491
|
pickle_control_flow_result: bool = True, # noqa: FBT001, FBT002
|
|
492
|
+
flow_input: dict[str, Any] | None = None,
|
|
365
493
|
) -> dict[str, dict[str | SerializedNodeCommands.UniqueParameterValueUUID, Any] | None]:
|
|
366
494
|
"""Execute the published workflow in a subprocess.
|
|
367
495
|
|
|
@@ -369,6 +497,7 @@ class NodeExecutor:
|
|
|
369
497
|
published_workflow_filename: Path to the workflow file to execute
|
|
370
498
|
file_name: Name of the workflow for logging
|
|
371
499
|
pickle_control_flow_result: Whether to pickle control flow results (defaults to True)
|
|
500
|
+
flow_input: Optional dictionary of parameter values to pass to the workflow's StartFlow node
|
|
372
501
|
|
|
373
502
|
Returns:
|
|
374
503
|
The subprocess execution output dictionary
|
|
@@ -378,11 +507,10 @@ class NodeExecutor:
|
|
|
378
507
|
)
|
|
379
508
|
|
|
380
509
|
subprocess_executor = SubprocessWorkflowExecutor(workflow_path=str(published_workflow_filename))
|
|
381
|
-
|
|
382
510
|
try:
|
|
383
511
|
async with subprocess_executor as executor:
|
|
384
512
|
await executor.arun(
|
|
385
|
-
flow_input={},
|
|
513
|
+
flow_input=flow_input or {},
|
|
386
514
|
storage_backend=await self._get_storage_backend(),
|
|
387
515
|
pickle_control_flow_result=pickle_control_flow_result,
|
|
388
516
|
)
|
|
@@ -403,133 +531,2046 @@ class NodeExecutor:
|
|
|
403
531
|
raise ValueError(msg)
|
|
404
532
|
return my_subprocess_result
|
|
405
533
|
|
|
406
|
-
def
|
|
407
|
-
self,
|
|
408
|
-
) ->
|
|
409
|
-
"""
|
|
534
|
+
def _find_loop_entry_node(
|
|
535
|
+
self, start_node: BaseIterativeStartNode, node_group_name: str | None, connections: Any
|
|
536
|
+
) -> EntryNodeParameter:
|
|
537
|
+
"""Find the entry control node and parameter for a loop body.
|
|
538
|
+
|
|
539
|
+
Args:
|
|
540
|
+
start_node: The loop start node
|
|
541
|
+
node_group_name: Name of NodeGroup if loop body is a NodeGroup, None otherwise
|
|
542
|
+
connections: Connections object from FlowManager
|
|
410
543
|
|
|
411
544
|
Returns:
|
|
412
|
-
|
|
545
|
+
Tuple of (entry_node_name, entry_parameter_name) or (None, None) if not found
|
|
413
546
|
"""
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
547
|
+
entry_control_node_name = None
|
|
548
|
+
entry_control_parameter_name = None
|
|
549
|
+
exec_out_param_name = start_node.exec_out.name
|
|
550
|
+
|
|
551
|
+
if start_node.name not in connections.outgoing_index:
|
|
552
|
+
return EntryNodeParameter(None, None)
|
|
553
|
+
|
|
554
|
+
exec_out_connections = connections.outgoing_index[start_node.name].get(exec_out_param_name, [])
|
|
555
|
+
if not exec_out_connections:
|
|
556
|
+
return EntryNodeParameter(None, None)
|
|
557
|
+
|
|
558
|
+
first_conn_id = exec_out_connections[0]
|
|
559
|
+
first_conn = connections.connections[first_conn_id]
|
|
560
|
+
|
|
561
|
+
# If connecting to a NodeGroup, find the actual internal entry node
|
|
562
|
+
if node_group_name is not None and first_conn.target_node.name == node_group_name:
|
|
563
|
+
# The connection goes to a proxy parameter on the NodeGroup
|
|
564
|
+
# Find the internal connection from that proxy parameter to the actual entry node
|
|
565
|
+
proxy_param = first_conn.target_parameter
|
|
566
|
+
if node_group_name in connections.outgoing_index:
|
|
567
|
+
proxy_connections = connections.outgoing_index[node_group_name].get(proxy_param.name, [])
|
|
568
|
+
if proxy_connections:
|
|
569
|
+
internal_conn_id = proxy_connections[0]
|
|
570
|
+
internal_conn = connections.connections[internal_conn_id]
|
|
571
|
+
if internal_conn.is_node_group_internal:
|
|
572
|
+
entry_control_node_name = internal_conn.target_node.name
|
|
573
|
+
entry_control_parameter_name = internal_conn.target_parameter.name
|
|
574
|
+
else:
|
|
575
|
+
# Direct connection to a regular node
|
|
576
|
+
entry_control_node_name = first_conn.target_node.name
|
|
577
|
+
entry_control_parameter_name = first_conn.target_parameter.name
|
|
578
|
+
# If the connection is just to the End Node, then we don't have an entry control connection.
|
|
579
|
+
if first_conn.target_node == start_node.end_node:
|
|
580
|
+
return EntryNodeParameter(None, None)
|
|
420
581
|
|
|
421
|
-
|
|
422
|
-
unique_uuid_to_values = result_dict.get("unique_parameter_uuid_to_values")
|
|
582
|
+
return EntryNodeParameter(entry_node=entry_control_node_name, entry_parameter=entry_control_parameter_name)
|
|
423
583
|
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
584
|
+
def _collect_loop_body_nodes(
|
|
585
|
+
self,
|
|
586
|
+
start_node: BaseIterativeStartNode,
|
|
587
|
+
end_node: BaseIterativeEndNode,
|
|
588
|
+
nodes_in_control_flow: set[str],
|
|
589
|
+
connections: Any,
|
|
590
|
+
) -> LoopBodyNodes:
|
|
591
|
+
"""Collect all nodes in the loop body, including data dependencies.
|
|
428
592
|
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
593
|
+
Returns:
|
|
594
|
+
LoopBodyNodes containing all_nodes, execution_type, and node_group_name
|
|
595
|
+
"""
|
|
596
|
+
all_nodes: set[str] = set()
|
|
597
|
+
visited_deps: set[str] = set()
|
|
598
|
+
|
|
599
|
+
node_manager = GriptapeNodes.NodeManager()
|
|
600
|
+
# Exclude the start node from packaging. And, we don't want their dependencies.
|
|
601
|
+
nodes_in_control_flow.discard(start_node.name)
|
|
602
|
+
for node_name in nodes_in_control_flow:
|
|
603
|
+
# Add ALL nodes in control flow for removal from parent DAG
|
|
604
|
+
all_nodes.add(node_name)
|
|
605
|
+
node_obj = node_manager.get_node_by_name(node_name)
|
|
606
|
+
deps = DagBuilder.collect_data_dependencies_for_node(
|
|
607
|
+
node_obj, connections, nodes_in_control_flow, visited_deps
|
|
608
|
+
)
|
|
609
|
+
all_nodes.update(deps)
|
|
610
|
+
# Discard the end node from packaging.
|
|
611
|
+
all_nodes.discard(end_node.name)
|
|
612
|
+
# Make sure the start node wasn't added in the dependencies.
|
|
613
|
+
all_nodes.discard(start_node.name)
|
|
614
|
+
|
|
615
|
+
# See if they're all in one NodeGroup
|
|
616
|
+
execution_type = LOCAL_EXECUTION
|
|
617
|
+
node_group_name = None
|
|
618
|
+
if len(all_nodes) == 1:
|
|
619
|
+
node_inside = all_nodes.pop()
|
|
620
|
+
node_obj = node_manager.get_node_by_name(node_inside)
|
|
621
|
+
if isinstance(node_obj, NodeGroupNode):
|
|
622
|
+
execution_type = node_obj.get_parameter_value(node_obj.execution_environment.name)
|
|
623
|
+
all_nodes.update(node_obj.get_all_nodes())
|
|
624
|
+
node_group_name = node_obj.name
|
|
625
|
+
else:
|
|
626
|
+
all_nodes.add(node_inside)
|
|
435
627
|
|
|
436
|
-
|
|
437
|
-
|
|
628
|
+
return LoopBodyNodes(all_nodes=all_nodes, execution_type=execution_type, node_group_name=node_group_name)
|
|
629
|
+
|
|
630
|
+
async def _package_loop_body(
|
|
631
|
+
self,
|
|
632
|
+
start_node: BaseIterativeStartNode,
|
|
633
|
+
end_node: BaseIterativeEndNode,
|
|
634
|
+
) -> tuple[PackageNodesAsSerializedFlowResultSuccess, str] | None:
|
|
635
|
+
"""Package the loop body (nodes between start and end) into a serialized flow.
|
|
438
636
|
|
|
439
637
|
Args:
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
638
|
+
start_node: The BaseIterativeStartNode marking the start of the loop
|
|
639
|
+
end_node: The BaseIterativeEndNode marking the end of the loop
|
|
640
|
+
execution_type: The execution environment type
|
|
443
641
|
|
|
444
642
|
Returns:
|
|
445
|
-
|
|
643
|
+
PackageNodesAsSerializedFlowResultSuccess if successful, None if empty loop body
|
|
446
644
|
"""
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
return param_value
|
|
645
|
+
flow_manager = GriptapeNodes.FlowManager()
|
|
646
|
+
connections = flow_manager.get_connections()
|
|
450
647
|
|
|
451
|
-
|
|
648
|
+
# Collect all nodes in the forward control path from start to end
|
|
649
|
+
nodes_in_control_flow = DagBuilder.collect_nodes_in_forward_control_path(start_node, end_node, connections)
|
|
452
650
|
|
|
453
|
-
#
|
|
454
|
-
|
|
455
|
-
|
|
651
|
+
# Filter out nodes already in the current DAG and collect data dependencies
|
|
652
|
+
loop_body_result = self._collect_loop_body_nodes(start_node, end_node, nodes_in_control_flow, connections)
|
|
653
|
+
all_nodes = loop_body_result.all_nodes
|
|
654
|
+
execution_type = loop_body_result.execution_type
|
|
655
|
+
node_group_name = loop_body_result.node_group_name
|
|
456
656
|
|
|
457
|
-
#
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
657
|
+
# Handle empty loop body (no nodes between start and end)
|
|
658
|
+
if not all_nodes:
|
|
659
|
+
await self._handle_empty_loop_body(start_node, end_node)
|
|
660
|
+
return None
|
|
661
|
+
# Find the first node in the loop body (where start_node.exec_out connects to)
|
|
662
|
+
entry_node_parameter = self._find_loop_entry_node(start_node, node_group_name, connections)
|
|
663
|
+
entry_control_node_name = entry_node_parameter.entry_node
|
|
664
|
+
entry_control_parameter_name = entry_node_parameter.entry_parameter
|
|
665
|
+
# Determine library and node types based on execution_type
|
|
666
|
+
library = None
|
|
667
|
+
if execution_type not in (LOCAL_EXECUTION, PRIVATE_EXECUTION):
|
|
668
|
+
try:
|
|
669
|
+
library = LibraryRegistry.get_library(name=execution_type)
|
|
670
|
+
except KeyError:
|
|
671
|
+
msg = "Could not find library '%s' for loop execution", execution_type
|
|
672
|
+
logger.error(msg)
|
|
673
|
+
raise RuntimeError(msg) # noqa: B904
|
|
470
674
|
|
|
471
|
-
|
|
675
|
+
library_name = library.get_library_data().name
|
|
676
|
+
workflow_start_end_nodes = await self._get_workflow_start_end_nodes(library)
|
|
677
|
+
start_node_type = workflow_start_end_nodes.start_flow_node_type
|
|
678
|
+
end_node_type = workflow_start_end_nodes.end_flow_node_type
|
|
679
|
+
library_name = workflow_start_end_nodes.start_flow_node_library_name
|
|
680
|
+
|
|
681
|
+
# Create the packaging request
|
|
682
|
+
request = PackageNodesAsSerializedFlowRequest(
|
|
683
|
+
node_names=list(all_nodes),
|
|
684
|
+
start_node_type=start_node_type,
|
|
685
|
+
end_node_type=end_node_type,
|
|
686
|
+
start_node_library_name=library_name,
|
|
687
|
+
end_node_library_name=library_name,
|
|
688
|
+
entry_control_node_name=entry_control_node_name,
|
|
689
|
+
entry_control_parameter_name=entry_control_parameter_name,
|
|
690
|
+
output_parameter_prefix=f"{end_node.name.replace(' ', '_')}_loop_",
|
|
691
|
+
node_group_name=node_group_name,
|
|
692
|
+
)
|
|
693
|
+
|
|
694
|
+
package_result = GriptapeNodes.handle_request(request)
|
|
695
|
+
if not isinstance(package_result, PackageNodesAsSerializedFlowResultSuccess):
|
|
696
|
+
msg = f"Failed to package loop nodes for '{end_node.name}'. Error: {package_result.result_details}"
|
|
697
|
+
raise TypeError(msg)
|
|
698
|
+
|
|
699
|
+
logger.info(
|
|
700
|
+
"Successfully packaged %d nodes for loop execution from '%s' to '%s'",
|
|
701
|
+
len(all_nodes),
|
|
702
|
+
start_node.name,
|
|
703
|
+
end_node.name,
|
|
704
|
+
)
|
|
705
|
+
|
|
706
|
+
# Remove packaged nodes from global queue since they will be copied into loop iterations
|
|
707
|
+
self._remove_packaged_nodes_from_queue(all_nodes)
|
|
708
|
+
|
|
709
|
+
return package_result, execution_type
|
|
710
|
+
|
|
711
|
+
async def _handle_empty_loop_body(
|
|
472
712
|
self,
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
package_result: PackageNodesAsSerializedFlowResultSuccess,
|
|
713
|
+
start_node: BaseIterativeStartNode,
|
|
714
|
+
end_node: BaseIterativeEndNode,
|
|
476
715
|
) -> None:
|
|
477
|
-
"""
|
|
716
|
+
"""Handle empty loop body (no nodes between start and end).
|
|
478
717
|
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
718
|
+
Args:
|
|
719
|
+
start_node: The BaseIterativeStartNode
|
|
720
|
+
end_node: The BaseIterativeEndNode
|
|
482
721
|
"""
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
722
|
+
total_iterations = start_node._get_total_iterations()
|
|
723
|
+
logger.info(
|
|
724
|
+
"No nodes found between '%s' and '%s'. Processing empty loop body.",
|
|
725
|
+
start_node.name,
|
|
726
|
+
end_node.name,
|
|
727
|
+
)
|
|
487
728
|
|
|
488
|
-
#
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
729
|
+
# Check if there are direct data connections from start to end
|
|
730
|
+
list_connections_request = ListConnectionsForNodeRequest(node_name=start_node.name)
|
|
731
|
+
list_connections_result = GriptapeNodes.handle_request(list_connections_request)
|
|
732
|
+
|
|
733
|
+
connected_source_param = None
|
|
734
|
+
if isinstance(list_connections_result, ListConnectionsForNodeResultSuccess):
|
|
735
|
+
for conn in list_connections_result.outgoing_connections:
|
|
736
|
+
if conn.target_node_name == end_node.name and conn.target_parameter_name == "new_item_to_add":
|
|
737
|
+
connected_source_param = conn.source_parameter_name
|
|
738
|
+
break
|
|
739
|
+
|
|
740
|
+
logger.info(
|
|
741
|
+
"Processing %d iterations for empty loop from '%s' to '%s' (connected param: %s)",
|
|
742
|
+
total_iterations,
|
|
743
|
+
start_node.name,
|
|
744
|
+
end_node.name,
|
|
745
|
+
connected_source_param,
|
|
746
|
+
)
|
|
494
747
|
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
748
|
+
# Process iterations to collect results from direct connections
|
|
749
|
+
end_node._results_list = []
|
|
750
|
+
if connected_source_param:
|
|
751
|
+
for iteration_index in range(total_iterations):
|
|
752
|
+
start_node._current_iteration_count = iteration_index
|
|
498
753
|
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
754
|
+
# Get the value based on which parameter is connected
|
|
755
|
+
if connected_source_param == "current_item":
|
|
756
|
+
value = start_node._get_current_item_value()
|
|
757
|
+
elif connected_source_param == "index":
|
|
758
|
+
value = start_node.get_current_index()
|
|
759
|
+
else:
|
|
760
|
+
start_node._get_current_item_value()
|
|
761
|
+
value = start_node.parameter_output_values.get(connected_source_param)
|
|
762
|
+
|
|
763
|
+
if value is not None:
|
|
764
|
+
end_node._results_list.append(value)
|
|
765
|
+
|
|
766
|
+
end_node._output_results_list()
|
|
767
|
+
|
|
768
|
+
def _should_break_loop(
|
|
769
|
+
self,
|
|
770
|
+
node_name_mappings: dict[str, str],
|
|
771
|
+
package_result: PackageNodesAsSerializedFlowResultSuccess,
|
|
772
|
+
) -> bool:
|
|
773
|
+
"""Check if the loop should break based on the end node's control output.
|
|
774
|
+
|
|
775
|
+
Args:
|
|
776
|
+
node_name_mappings: Mapping from original to deserialized node names
|
|
777
|
+
package_result: The package result containing parameter mappings
|
|
778
|
+
|
|
779
|
+
Returns:
|
|
780
|
+
True if the end node signaled a break, False otherwise
|
|
781
|
+
"""
|
|
782
|
+
node_manager = GriptapeNodes.NodeManager()
|
|
783
|
+
|
|
784
|
+
# Get the End node mapping
|
|
785
|
+
end_node_mapping = self.get_node_parameter_mappings(package_result, "end")
|
|
786
|
+
end_node_name = end_node_mapping.node_name
|
|
787
|
+
|
|
788
|
+
# Get the deserialized end node name
|
|
789
|
+
packaged_end_node_name = node_name_mappings.get(end_node_name)
|
|
790
|
+
if packaged_end_node_name is None:
|
|
791
|
+
logger.warning("Could not find deserialized End node name for %s", end_node_name)
|
|
792
|
+
return False
|
|
793
|
+
|
|
794
|
+
# Get the deserialized end node instance
|
|
795
|
+
deserialized_end_node = node_manager.get_node_by_name(packaged_end_node_name)
|
|
796
|
+
if deserialized_end_node is None:
|
|
797
|
+
logger.warning("Could not find deserialized End node instance for %s", packaged_end_node_name)
|
|
798
|
+
return False
|
|
799
|
+
|
|
800
|
+
# Check if this is a BaseIterativeEndNode
|
|
801
|
+
if not isinstance(deserialized_end_node, BaseIterativeEndNode):
|
|
802
|
+
return False
|
|
803
|
+
|
|
804
|
+
# Check if end node would emit break_loop_signal_output
|
|
805
|
+
next_control_output = deserialized_end_node.get_next_control_output()
|
|
806
|
+
if next_control_output is None:
|
|
807
|
+
return False
|
|
808
|
+
|
|
809
|
+
# Check if it's the break signal
|
|
810
|
+
return next_control_output == deserialized_end_node.break_loop_signal_output
|
|
811
|
+
|
|
812
|
+
async def _execute_loop_iterations_sequentially( # noqa: PLR0915
|
|
813
|
+
self,
|
|
814
|
+
package_result: PackageNodesAsSerializedFlowResultSuccess,
|
|
815
|
+
total_iterations: int,
|
|
816
|
+
parameter_values_per_iteration: dict[int, dict[str, Any]],
|
|
817
|
+
end_loop_node: BaseIterativeEndNode,
|
|
818
|
+
) -> tuple[dict[int, Any], list[int], dict[str, Any]]:
|
|
819
|
+
"""Execute loop iterations sequentially by running one flow instance N times.
|
|
820
|
+
|
|
821
|
+
Args:
|
|
822
|
+
package_result: The packaged flow with parameter mappings
|
|
823
|
+
total_iterations: Number of iterations to run
|
|
824
|
+
parameter_values_per_iteration: Dict mapping iteration_index -> parameter values
|
|
825
|
+
end_loop_node: The End Loop Node to extract results for
|
|
826
|
+
|
|
827
|
+
Returns:
|
|
828
|
+
Tuple of:
|
|
829
|
+
- iteration_results: Dict mapping iteration_index -> result value
|
|
830
|
+
- successful_iterations: List of iteration indices that succeeded
|
|
831
|
+
- last_iteration_values: Dict mapping parameter names -> values from last iteration
|
|
832
|
+
"""
|
|
833
|
+
# Deserialize flow once
|
|
834
|
+
context_manager = GriptapeNodes.ContextManager()
|
|
835
|
+
event_manager = GriptapeNodes.EventManager()
|
|
836
|
+
with EventSuppressionContext(event_manager, LOOP_EVENTS_TO_SUPPRESS):
|
|
837
|
+
deserialize_request = DeserializeFlowFromCommandsRequest(
|
|
838
|
+
serialized_flow_commands=package_result.serialized_flow_commands
|
|
839
|
+
)
|
|
840
|
+
deserialize_result = GriptapeNodes.handle_request(deserialize_request)
|
|
841
|
+
if not isinstance(deserialize_result, DeserializeFlowFromCommandsResultSuccess):
|
|
842
|
+
msg = f"Failed to deserialize flow for sequential loop. Error: {deserialize_result.result_details}"
|
|
843
|
+
raise TypeError(msg)
|
|
844
|
+
|
|
845
|
+
flow_name = deserialize_result.flow_name
|
|
846
|
+
node_name_mappings = deserialize_result.node_name_mappings
|
|
847
|
+
|
|
848
|
+
# Pop the deserialized flow from context stack
|
|
849
|
+
if context_manager.has_current_flow() and context_manager.get_current_flow().name == flow_name:
|
|
850
|
+
context_manager.pop_flow()
|
|
851
|
+
|
|
852
|
+
logger.info("Successfully deserialized flow for sequential execution: %s", flow_name)
|
|
853
|
+
# Get node mappings
|
|
854
|
+
start_node_mapping = self.get_node_parameter_mappings(package_result, "start")
|
|
855
|
+
start_node_name = start_node_mapping.node_name
|
|
856
|
+
packaged_start_node_name = node_name_mappings.get(start_node_name)
|
|
857
|
+
|
|
858
|
+
if packaged_start_node_name is None:
|
|
859
|
+
msg = f"Could not find deserialized Start node (original: '{start_node_name}') for sequential loop"
|
|
860
|
+
raise TypeError(msg)
|
|
861
|
+
|
|
862
|
+
iteration_results: dict[int, Any] = {}
|
|
863
|
+
successful_iterations: list[int] = []
|
|
864
|
+
|
|
865
|
+
# Build reverse mapping: packaged_name → original_name for event translation
|
|
866
|
+
reverse_node_mapping = {
|
|
867
|
+
packaged_name: original_name for original_name, packaged_name in node_name_mappings.items()
|
|
868
|
+
}
|
|
869
|
+
|
|
870
|
+
try:
|
|
871
|
+
# Execute iterations one at a time
|
|
872
|
+
for iteration_index in range(total_iterations):
|
|
873
|
+
logger.info(
|
|
874
|
+
"Starting sequential iteration %d/%d for loop ending at '%s'",
|
|
875
|
+
iteration_index,
|
|
876
|
+
total_iterations,
|
|
877
|
+
end_loop_node.name,
|
|
878
|
+
)
|
|
879
|
+
# Set input values for this iteration
|
|
880
|
+
parameter_values = parameter_values_per_iteration[iteration_index]
|
|
881
|
+
|
|
882
|
+
for startflow_param_name, value_to_set in parameter_values.items():
|
|
883
|
+
set_value_request = SetParameterValueRequest(
|
|
884
|
+
node_name=packaged_start_node_name,
|
|
885
|
+
parameter_name=startflow_param_name,
|
|
886
|
+
value=value_to_set,
|
|
506
887
|
)
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
888
|
+
set_value_result = await GriptapeNodes.ahandle_request(set_value_request)
|
|
889
|
+
if not isinstance(set_value_result, SetParameterValueResultSuccess):
|
|
890
|
+
logger.warning(
|
|
891
|
+
"Failed to set parameter '%s' on Start node '%s' for iteration %d: %s",
|
|
892
|
+
startflow_param_name,
|
|
893
|
+
packaged_start_node_name,
|
|
894
|
+
iteration_index,
|
|
895
|
+
set_value_result.result_details,
|
|
896
|
+
)
|
|
897
|
+
|
|
898
|
+
# Execute this iteration with event translation instead of suppression
|
|
899
|
+
# This allows the UI to show the original nodes highlighting during loop execution
|
|
900
|
+
logger.info(
|
|
901
|
+
"Executing subflow for iteration %d - flow: '%s', start_node: '%s'",
|
|
902
|
+
iteration_index,
|
|
903
|
+
flow_name,
|
|
904
|
+
packaged_start_node_name,
|
|
905
|
+
)
|
|
906
|
+
with EventTranslationContext(event_manager, reverse_node_mapping):
|
|
907
|
+
start_subflow_request = StartLocalSubflowRequest(
|
|
908
|
+
flow_name=flow_name,
|
|
909
|
+
start_node=packaged_start_node_name,
|
|
910
|
+
pickle_control_flow_result=False,
|
|
911
|
+
)
|
|
912
|
+
start_subflow_result = await GriptapeNodes.ahandle_request(start_subflow_request)
|
|
511
913
|
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
914
|
+
if not isinstance(start_subflow_result, StartLocalSubflowResultSuccess):
|
|
915
|
+
msg = f"Sequential loop iteration {iteration_index} failed: {start_subflow_result.result_details}"
|
|
916
|
+
logger.error(
|
|
917
|
+
"Sequential iteration %d failed for loop ending at '%s'", iteration_index, end_loop_node.name
|
|
918
|
+
)
|
|
919
|
+
raise RuntimeError(msg) # noqa: TRY004 - This is a runtime execution error, not a type error
|
|
920
|
+
|
|
921
|
+
successful_iterations.append(iteration_index)
|
|
922
|
+
|
|
923
|
+
# Extract result from this iteration
|
|
924
|
+
deserialized_flows = [(iteration_index, flow_name, node_name_mappings)]
|
|
925
|
+
single_iteration_results = self.get_parameter_values_from_iterations(
|
|
926
|
+
end_loop_node=end_loop_node,
|
|
927
|
+
deserialized_flows=deserialized_flows,
|
|
928
|
+
package_flow_result_success=package_result,
|
|
519
929
|
)
|
|
520
|
-
|
|
930
|
+
iteration_results.update(single_iteration_results)
|
|
521
931
|
|
|
522
|
-
|
|
523
|
-
if target_param.type != ParameterTypeBuiltin.CONTROL_TYPE:
|
|
524
|
-
target_node.set_parameter_value(target_param_name, param_value)
|
|
525
|
-
target_node.parameter_output_values[target_param_name] = param_value
|
|
932
|
+
logger.info("Completed sequential iteration %d/%d", iteration_index + 1, total_iterations)
|
|
526
933
|
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
934
|
+
# Check if the end node signaled a break
|
|
935
|
+
if self._should_break_loop(node_name_mappings, package_result):
|
|
936
|
+
logger.info(
|
|
937
|
+
"Loop break detected at iteration %d/%d - stopping execution early",
|
|
938
|
+
iteration_index + 1,
|
|
939
|
+
total_iterations,
|
|
940
|
+
)
|
|
941
|
+
break
|
|
942
|
+
|
|
943
|
+
# Extract last iteration values from the last successful iteration
|
|
944
|
+
last_successful_iteration = successful_iterations[-1] if successful_iterations else 0
|
|
945
|
+
deserialized_flows = [(last_successful_iteration, flow_name, node_name_mappings)]
|
|
946
|
+
last_iteration_values = self.get_last_iteration_values_for_packaged_nodes(
|
|
947
|
+
deserialized_flows=deserialized_flows,
|
|
948
|
+
package_result=package_result,
|
|
949
|
+
total_iterations=len(successful_iterations),
|
|
950
|
+
)
|
|
951
|
+
|
|
952
|
+
return iteration_results, successful_iterations, last_iteration_values
|
|
953
|
+
|
|
954
|
+
finally:
|
|
955
|
+
# Cleanup - delete the flow
|
|
956
|
+
with EventSuppressionContext(event_manager, {DeleteFlowResultSuccess, DeleteFlowResultFailure}):
|
|
957
|
+
delete_request = DeleteFlowRequest(flow_name=flow_name)
|
|
958
|
+
delete_result = await GriptapeNodes.ahandle_request(delete_request)
|
|
959
|
+
if not isinstance(delete_result, DeleteFlowResultSuccess):
|
|
960
|
+
logger.warning(
|
|
961
|
+
"Failed to delete sequential loop flow '%s': %s",
|
|
962
|
+
flow_name,
|
|
963
|
+
delete_result.result_details,
|
|
964
|
+
)
|
|
965
|
+
|
|
966
|
+
async def _handle_sequential_loop_execution( # noqa: C901
|
|
967
|
+
self, start_node: BaseIterativeStartNode, end_node: BaseIterativeEndNode
|
|
968
|
+
) -> None:
|
|
969
|
+
"""Handle sequential loop execution by running iterations one at a time.
|
|
970
|
+
|
|
971
|
+
Args:
|
|
972
|
+
start_node: The BaseIterativeStartNode marking the start of the loop
|
|
973
|
+
end_node: The BaseIterativeEndNode marking the end of the loop
|
|
974
|
+
"""
|
|
975
|
+
total_iterations = start_node._get_total_iterations()
|
|
976
|
+
logger.info(
|
|
977
|
+
"Executing loop sequentially from '%s' to '%s' for %d iterations",
|
|
978
|
+
start_node.name,
|
|
979
|
+
end_node.name,
|
|
980
|
+
total_iterations,
|
|
981
|
+
)
|
|
982
|
+
|
|
983
|
+
# Package the loop body (nodes between start and end)
|
|
984
|
+
package_result_and_execution = await self._package_loop_body(start_node, end_node)
|
|
985
|
+
|
|
986
|
+
# Handle empty loop body (no nodes between start and end)
|
|
987
|
+
if package_result_and_execution is None:
|
|
988
|
+
logger.info("Empty loop body - results already set by _package_loop_body")
|
|
989
|
+
return
|
|
990
|
+
package_result, execution_type = package_result_and_execution
|
|
991
|
+
|
|
992
|
+
# Get parameter values per iteration
|
|
993
|
+
parameter_values_per_iteration = self.get_parameter_values_per_iteration(start_node, package_result)
|
|
994
|
+
|
|
995
|
+
# Get resolved upstream values (constant across all iterations)
|
|
996
|
+
# Reuse the packaged_node_names from package_result instead of recalculating
|
|
997
|
+
resolved_upstream_values = self.get_resolved_upstream_values(
|
|
998
|
+
packaged_node_names=package_result.packaged_node_names, package_result=package_result
|
|
999
|
+
)
|
|
1000
|
+
|
|
1001
|
+
# Merge upstream values into each iteration (only if parameter doesn't already exist)
|
|
1002
|
+
if resolved_upstream_values:
|
|
1003
|
+
for iteration_index in parameter_values_per_iteration:
|
|
1004
|
+
for param_name, param_value in resolved_upstream_values.items():
|
|
1005
|
+
if param_name not in parameter_values_per_iteration[iteration_index]:
|
|
1006
|
+
parameter_values_per_iteration[iteration_index][param_name] = param_value
|
|
1007
|
+
|
|
1008
|
+
# Execute iterations sequentially based on execution environment
|
|
1009
|
+
if execution_type == LOCAL_EXECUTION:
|
|
1010
|
+
(
|
|
1011
|
+
iteration_results,
|
|
1012
|
+
successful_iterations,
|
|
1013
|
+
last_iteration_values,
|
|
1014
|
+
) = await self._execute_loop_iterations_sequentially(
|
|
1015
|
+
package_result=package_result,
|
|
1016
|
+
total_iterations=total_iterations,
|
|
1017
|
+
parameter_values_per_iteration=parameter_values_per_iteration,
|
|
1018
|
+
end_loop_node=end_node,
|
|
1019
|
+
)
|
|
1020
|
+
elif execution_type == PRIVATE_EXECUTION:
|
|
1021
|
+
(
|
|
1022
|
+
iteration_results,
|
|
1023
|
+
successful_iterations,
|
|
1024
|
+
last_iteration_values,
|
|
1025
|
+
) = await self._execute_loop_iterations_sequentially_private(
|
|
1026
|
+
package_result=package_result,
|
|
1027
|
+
total_iterations=total_iterations,
|
|
1028
|
+
parameter_values_per_iteration=parameter_values_per_iteration,
|
|
1029
|
+
end_loop_node=end_node,
|
|
1030
|
+
)
|
|
1031
|
+
else:
|
|
1032
|
+
# Cloud publisher execution (Deadline Cloud, etc.)
|
|
1033
|
+
(
|
|
1034
|
+
iteration_results,
|
|
1035
|
+
successful_iterations,
|
|
1036
|
+
last_iteration_values,
|
|
1037
|
+
) = await self._execute_loop_iterations_sequentially_via_publisher(
|
|
1038
|
+
package_result=package_result,
|
|
1039
|
+
total_iterations=total_iterations,
|
|
1040
|
+
parameter_values_per_iteration=parameter_values_per_iteration,
|
|
1041
|
+
end_loop_node=end_node,
|
|
1042
|
+
execution_type=execution_type,
|
|
1043
|
+
)
|
|
1044
|
+
# Check if execution stopped early due to break (not failure)
|
|
1045
|
+
if len(successful_iterations) < total_iterations:
|
|
1046
|
+
# Only raise an error if there were actual failures (not just early termination)
|
|
1047
|
+
# If iterations stopped due to break, the last successful iteration count matches
|
|
1048
|
+
expected_count = len(successful_iterations)
|
|
1049
|
+
actual_count = len(iteration_results)
|
|
1050
|
+
if expected_count != actual_count:
|
|
1051
|
+
failed_count = expected_count - actual_count
|
|
1052
|
+
msg = f"Loop execution failed: {failed_count} of {expected_count} iterations failed"
|
|
1053
|
+
raise RuntimeError(msg)
|
|
1054
|
+
logger.info(
|
|
1055
|
+
"Loop execution stopped early at %d of %d iterations (break signal)",
|
|
1056
|
+
len(successful_iterations),
|
|
1057
|
+
total_iterations,
|
|
1058
|
+
)
|
|
1059
|
+
|
|
1060
|
+
# Build results list in iteration order
|
|
1061
|
+
end_node._results_list = []
|
|
1062
|
+
for iteration_index in sorted(iteration_results.keys()):
|
|
1063
|
+
value = iteration_results[iteration_index]
|
|
1064
|
+
end_node._results_list.append(value)
|
|
1065
|
+
|
|
1066
|
+
logger.info(
|
|
1067
|
+
"Loop '%s': Built results list with %d items from sequential iterations",
|
|
1068
|
+
end_node.name,
|
|
1069
|
+
len(end_node._results_list),
|
|
1070
|
+
)
|
|
1071
|
+
|
|
1072
|
+
# Output final results to the results parameter
|
|
1073
|
+
end_node._output_results_list()
|
|
1074
|
+
logger.info("Loop '%s': Outputted final results list", end_node.name)
|
|
1075
|
+
|
|
1076
|
+
# Apply last iteration values to the original packaged nodes
|
|
1077
|
+
self._apply_last_iteration_to_packaged_nodes(
|
|
1078
|
+
last_iteration_values=last_iteration_values,
|
|
1079
|
+
package_result=package_result,
|
|
1080
|
+
)
|
|
1081
|
+
logger.info("Loop '%s': Applied last iteration values to packaged nodes", end_node.name)
|
|
1082
|
+
|
|
1083
|
+
logger.info(
|
|
1084
|
+
"Completed sequential loop execution from '%s' to '%s' with %d results",
|
|
1085
|
+
start_node.name,
|
|
1086
|
+
end_node.name,
|
|
1087
|
+
len(iteration_results),
|
|
1088
|
+
)
|
|
1089
|
+
|
|
1090
|
+
def _get_merged_parameter_values_for_iterations(
|
|
1091
|
+
self, start_node: BaseIterativeStartNode, package_result: PackageNodesAsSerializedFlowResultSuccess
|
|
1092
|
+
) -> dict[int, dict[str, Any]]:
|
|
1093
|
+
"""Get parameter values for each iteration with resolved upstream values merged in.
|
|
1094
|
+
|
|
1095
|
+
Args:
|
|
1096
|
+
start_node: The start node for the loop
|
|
1097
|
+
package_result: The packaged flow result containing parameter mappings
|
|
1098
|
+
|
|
1099
|
+
Returns:
|
|
1100
|
+
Dict mapping iteration_index -> {parameter_name: value}
|
|
1101
|
+
"""
|
|
1102
|
+
# Get parameter values from start node (vary per iteration)
|
|
1103
|
+
parameter_values_per_iteration = self.get_parameter_values_per_iteration(start_node, package_result)
|
|
1104
|
+
|
|
1105
|
+
# Get resolved upstream values (constant across all iterations)
|
|
1106
|
+
resolved_upstream_values = self.get_resolved_upstream_values(
|
|
1107
|
+
packaged_node_names=package_result.packaged_node_names, package_result=package_result
|
|
1108
|
+
)
|
|
1109
|
+
|
|
1110
|
+
# Merge upstream values into each iteration (only if parameter doesn't already exist)
|
|
1111
|
+
if resolved_upstream_values:
|
|
1112
|
+
for iteration_index in parameter_values_per_iteration:
|
|
1113
|
+
for param_name, param_value in resolved_upstream_values.items():
|
|
1114
|
+
if param_name not in parameter_values_per_iteration[iteration_index]:
|
|
1115
|
+
parameter_values_per_iteration[iteration_index][param_name] = param_value
|
|
1116
|
+
logger.info(
|
|
1117
|
+
"Added %d resolved upstream values to %d iterations",
|
|
1118
|
+
len(resolved_upstream_values),
|
|
1119
|
+
len(parameter_values_per_iteration),
|
|
1120
|
+
)
|
|
1121
|
+
|
|
1122
|
+
return parameter_values_per_iteration
|
|
1123
|
+
|
|
1124
|
+
async def handle_loop_execution(self, node: BaseIterativeEndNode) -> None:
|
|
1125
|
+
"""Handle execution of a loop by packaging nodes from start to end and running them.
|
|
1126
|
+
|
|
1127
|
+
Args:
|
|
1128
|
+
node: The BaseIterativeEndNode marking the end of the loop
|
|
1129
|
+
execution_type: The execution environment type
|
|
1130
|
+
"""
|
|
1131
|
+
# Validate start node exists
|
|
1132
|
+
if node.start_node is None:
|
|
1133
|
+
msg = f"BaseIterativeEndNode '{node.name}' has no start_node reference"
|
|
1134
|
+
raise ValueError(msg)
|
|
1135
|
+
|
|
1136
|
+
start_node = node.start_node
|
|
1137
|
+
|
|
1138
|
+
# Initialize iteration data to determine total iterations
|
|
1139
|
+
start_node._initialize_iteration_data()
|
|
1140
|
+
|
|
1141
|
+
total_iterations = start_node._get_total_iterations()
|
|
1142
|
+
if total_iterations == 0:
|
|
1143
|
+
logger.info("No iterations for empty loop from '%s' to '%s'", start_node.name, node.name)
|
|
1144
|
+
return
|
|
1145
|
+
|
|
1146
|
+
# Check if we should run in parallel (default is sequential/False)
|
|
1147
|
+
run_in_parallel = start_node.get_parameter_value("run_in_parallel")
|
|
1148
|
+
if not run_in_parallel:
|
|
1149
|
+
# Sequential execution - run iterations one at a time in the main execution flow
|
|
1150
|
+
await self._handle_sequential_loop_execution(start_node, node)
|
|
1151
|
+
return
|
|
1152
|
+
|
|
1153
|
+
# Parallel execution - package and run all iterations concurrently
|
|
1154
|
+
# Package the loop body (nodes between start and end)
|
|
1155
|
+
package_result_and_execution_type = await self._package_loop_body(start_node, node)
|
|
1156
|
+
|
|
1157
|
+
# Handle empty loop body (no nodes between start and end)
|
|
1158
|
+
if package_result_and_execution_type is None:
|
|
1159
|
+
logger.info("Empty loop body - results already set by _package_loop_body")
|
|
1160
|
+
return
|
|
1161
|
+
package_result, execution_type = package_result_and_execution_type
|
|
1162
|
+
# Get parameter values for each iteration
|
|
1163
|
+
parameter_values_to_set_before_run = self._get_merged_parameter_values_for_iterations(
|
|
1164
|
+
start_node, package_result
|
|
1165
|
+
)
|
|
1166
|
+
|
|
1167
|
+
# Step 5: Execute all iterations based on execution environment
|
|
1168
|
+
if execution_type == LOCAL_EXECUTION:
|
|
1169
|
+
(
|
|
1170
|
+
iteration_results,
|
|
1171
|
+
successful_iterations,
|
|
1172
|
+
last_iteration_values,
|
|
1173
|
+
) = await self._execute_loop_iterations_locally(
|
|
1174
|
+
package_result=package_result,
|
|
1175
|
+
total_iterations=total_iterations,
|
|
1176
|
+
parameter_values_per_iteration=parameter_values_to_set_before_run,
|
|
1177
|
+
end_loop_node=node,
|
|
1178
|
+
)
|
|
1179
|
+
elif execution_type == PRIVATE_EXECUTION:
|
|
1180
|
+
(
|
|
1181
|
+
iteration_results,
|
|
1182
|
+
successful_iterations,
|
|
1183
|
+
last_iteration_values,
|
|
1184
|
+
) = await self._execute_loop_iterations_privately(
|
|
1185
|
+
package_result=package_result,
|
|
1186
|
+
total_iterations=total_iterations,
|
|
1187
|
+
parameter_values_per_iteration=parameter_values_to_set_before_run,
|
|
1188
|
+
end_loop_node=node,
|
|
532
1189
|
)
|
|
1190
|
+
else:
|
|
1191
|
+
# Cloud publisher execution (Deadline Cloud, etc.)
|
|
1192
|
+
(
|
|
1193
|
+
iteration_results,
|
|
1194
|
+
successful_iterations,
|
|
1195
|
+
last_iteration_values,
|
|
1196
|
+
) = await self._execute_loop_iterations_via_publisher(
|
|
1197
|
+
package_result=package_result,
|
|
1198
|
+
total_iterations=total_iterations,
|
|
1199
|
+
parameter_values_per_iteration=parameter_values_to_set_before_run,
|
|
1200
|
+
end_loop_node=node,
|
|
1201
|
+
execution_type=execution_type,
|
|
1202
|
+
)
|
|
1203
|
+
|
|
1204
|
+
if len(successful_iterations) != total_iterations:
|
|
1205
|
+
failed_count = total_iterations - len(successful_iterations)
|
|
1206
|
+
msg = f"Loop execution failed: {failed_count} of {total_iterations} iterations failed"
|
|
1207
|
+
raise RuntimeError(msg)
|
|
1208
|
+
|
|
1209
|
+
logger.info(
|
|
1210
|
+
"Successfully completed parallel execution of %d iterations for loop '%s'",
|
|
1211
|
+
total_iterations,
|
|
1212
|
+
start_node.name,
|
|
1213
|
+
)
|
|
1214
|
+
|
|
1215
|
+
# Step 6: Build results list in iteration order
|
|
1216
|
+
node._results_list = []
|
|
1217
|
+
for iteration_index in sorted(iteration_results.keys()):
|
|
1218
|
+
value = iteration_results[iteration_index]
|
|
1219
|
+
node._results_list.append(value)
|
|
1220
|
+
|
|
1221
|
+
# Step 7: Output final results to the results parameter
|
|
1222
|
+
node._output_results_list()
|
|
1223
|
+
|
|
1224
|
+
# Step 8: Apply last iteration values to the original packaged nodes in main flow
|
|
1225
|
+
self._apply_last_iteration_to_packaged_nodes(
|
|
1226
|
+
last_iteration_values=last_iteration_values,
|
|
1227
|
+
package_result=package_result,
|
|
1228
|
+
)
|
|
1229
|
+
|
|
1230
|
+
logger.info(
|
|
1231
|
+
"Successfully aggregated %d results for loop '%s' to '%s'",
|
|
1232
|
+
len(iteration_results),
|
|
1233
|
+
start_node.name,
|
|
1234
|
+
node.name,
|
|
1235
|
+
)
|
|
1236
|
+
|
|
1237
|
+
def _get_iteration_value_for_parameter(
|
|
1238
|
+
self,
|
|
1239
|
+
source_param_name: str,
|
|
1240
|
+
iteration_index: int,
|
|
1241
|
+
index_values: list[int],
|
|
1242
|
+
current_item_values: list[Any],
|
|
1243
|
+
) -> Any:
|
|
1244
|
+
"""Get the value for a specific parameter at a given iteration.
|
|
1245
|
+
|
|
1246
|
+
Args:
|
|
1247
|
+
source_param_name: Name of the source parameter (e.g., "index" or "current_item")
|
|
1248
|
+
iteration_index: 0-based iteration index
|
|
1249
|
+
index_values: List of actual loop values for ForLoop nodes
|
|
1250
|
+
current_item_values: List of items for ForEach nodes
|
|
1251
|
+
|
|
1252
|
+
Returns:
|
|
1253
|
+
The value to set for this parameter at this iteration
|
|
1254
|
+
"""
|
|
1255
|
+
if source_param_name == "index":
|
|
1256
|
+
# For ForLoop nodes, use actual loop value; otherwise use iteration_index
|
|
1257
|
+
if index_values and iteration_index < len(index_values):
|
|
1258
|
+
return index_values[iteration_index]
|
|
1259
|
+
return iteration_index
|
|
1260
|
+
if source_param_name == "current_item" and iteration_index < len(current_item_values):
|
|
1261
|
+
return current_item_values[iteration_index]
|
|
1262
|
+
return None
|
|
1263
|
+
|
|
1264
|
+
def get_parameter_values_per_iteration( # noqa: C901, Needed to add special handling for node groups.
|
|
1265
|
+
self,
|
|
1266
|
+
start_node: BaseIterativeStartNode,
|
|
1267
|
+
package_result: PackageNodesAsSerializedFlowResultSuccess,
|
|
1268
|
+
) -> dict[int, dict[str, Any]]:
|
|
1269
|
+
"""Get parameter values for each iteration of the loop.
|
|
1270
|
+
|
|
1271
|
+
This maps iteration index to parameter values that should be set on the packaged flow's StartFlow node.
|
|
1272
|
+
Useful for: setting local values, sending as input for cloud publishing, or private workflow execution.
|
|
1273
|
+
|
|
1274
|
+
Args:
|
|
1275
|
+
start_node: The start loop node (ForEach or ForLoop)
|
|
1276
|
+
|
|
1277
|
+
package_result: PackageNodesAsSerializedFlowResultSuccess containing parameter_name_mappings
|
|
1278
|
+
|
|
1279
|
+
Returns:
|
|
1280
|
+
Dict mapping iteration_index -> {startflow_param_name: value}
|
|
1281
|
+
"""
|
|
1282
|
+
total_iterations = start_node._get_total_iterations()
|
|
1283
|
+
|
|
1284
|
+
# Calculate current_item values for ForEach nodes
|
|
1285
|
+
current_item_values = []
|
|
1286
|
+
iteration_items = start_node._get_iteration_items()
|
|
1287
|
+
current_item_values = list(iteration_items)
|
|
1288
|
+
|
|
1289
|
+
# Calculate index values for ForLoop nodes
|
|
1290
|
+
# For ForLoop, we need actual loop values (start, start+step, start+2*step, ...)
|
|
1291
|
+
# not just 0-based iteration indices
|
|
1292
|
+
index_values = []
|
|
1293
|
+
index_values = start_node.get_all_iteration_values()
|
|
1294
|
+
|
|
1295
|
+
list_connections_request = ListConnectionsForNodeRequest(node_name=start_node.name)
|
|
1296
|
+
list_connections_result = GriptapeNodes.handle_request(list_connections_request)
|
|
1297
|
+
if not isinstance(list_connections_result, ListConnectionsForNodeResultSuccess):
|
|
1298
|
+
msg = f"Failed to list connections for node {start_node.name}: {list_connections_result.result_details}"
|
|
1299
|
+
raise RuntimeError(msg) # noqa: TRY004 This should be a runtime error because it happens during execution.
|
|
1300
|
+
# Build parameter values for each iteration
|
|
1301
|
+
outgoing_connections = list_connections_result.outgoing_connections
|
|
1302
|
+
|
|
1303
|
+
# Get Start node's parameter mappings (index 0 in the list)
|
|
1304
|
+
start_node_mapping = self.get_node_parameter_mappings(package_result, "start")
|
|
1305
|
+
start_node_param_mappings = start_node_mapping.parameter_mappings
|
|
1306
|
+
|
|
1307
|
+
# For each outgoing connection from start_node, find the corresponding StartFlow parameter
|
|
1308
|
+
# The start_node_param_mappings tells us: startflow_param_name -> OriginalNodeParameter(target_node, target_param)
|
|
1309
|
+
# We need to match the target of each connection to find the right startflow parameter
|
|
1310
|
+
parameter_val_mappings = {}
|
|
1311
|
+
for iteration_index in range(total_iterations):
|
|
1312
|
+
iteration_values = {}
|
|
1313
|
+
# iteration_values is going to be startflow parameter name -> value to set
|
|
1314
|
+
|
|
1315
|
+
# For each outgoing data connection from start_node
|
|
1316
|
+
for conn in outgoing_connections:
|
|
1317
|
+
source_param_name = conn.source_parameter_name
|
|
1318
|
+
target_node_name = conn.target_node_name
|
|
1319
|
+
target_param_name = conn.target_parameter_name
|
|
1320
|
+
|
|
1321
|
+
# If target is a NodeGroup, follow the internal connection to get the actual target
|
|
1322
|
+
node_manager = GriptapeNodes.NodeManager()
|
|
1323
|
+
flow_manager = GriptapeNodes.FlowManager()
|
|
1324
|
+
try:
|
|
1325
|
+
target_node = node_manager.get_node_by_name(target_node_name)
|
|
1326
|
+
except ValueError:
|
|
1327
|
+
msg = f"Failed to get node {target_node_name} for connection {conn} from start node {start_node.name}. Can't get parameter value iterations."
|
|
1328
|
+
logger.error(msg)
|
|
1329
|
+
raise RuntimeError(msg) # noqa: B904
|
|
1330
|
+
if isinstance(target_node, NodeGroupNode):
|
|
1331
|
+
# Get connections from this proxy parameter to find the actual internal target
|
|
1332
|
+
connections = flow_manager.get_connections()
|
|
1333
|
+
proxy_param = target_node.get_parameter_by_name(target_param_name)
|
|
1334
|
+
if proxy_param:
|
|
1335
|
+
internal_connections = connections.get_all_outgoing_connections(target_node)
|
|
1336
|
+
for internal_conn in internal_connections:
|
|
1337
|
+
if (
|
|
1338
|
+
internal_conn.source_parameter.name == target_param_name
|
|
1339
|
+
and internal_conn.is_node_group_internal
|
|
1340
|
+
):
|
|
1341
|
+
target_node_name = internal_conn.target_node.name
|
|
1342
|
+
target_param_name = internal_conn.target_parameter.name
|
|
1343
|
+
break
|
|
1344
|
+
|
|
1345
|
+
# Find the target parameter that corresponds to this target
|
|
1346
|
+
for startflow_param_name, original_node_param in start_node_param_mappings.items():
|
|
1347
|
+
if (
|
|
1348
|
+
original_node_param.node_name == target_node_name
|
|
1349
|
+
and original_node_param.parameter_name == target_param_name
|
|
1350
|
+
):
|
|
1351
|
+
# This StartFlow parameter feeds the target - set the appropriate value
|
|
1352
|
+
value = self._get_iteration_value_for_parameter(
|
|
1353
|
+
source_param_name, iteration_index, index_values, current_item_values
|
|
1354
|
+
)
|
|
1355
|
+
if value is not None:
|
|
1356
|
+
iteration_values[startflow_param_name] = value
|
|
1357
|
+
break
|
|
1358
|
+
|
|
1359
|
+
parameter_val_mappings[iteration_index] = iteration_values
|
|
1360
|
+
|
|
1361
|
+
return parameter_val_mappings
|
|
1362
|
+
|
|
1363
|
+
def get_resolved_upstream_values(
|
|
1364
|
+
self,
|
|
1365
|
+
packaged_node_names: list[str],
|
|
1366
|
+
package_result: PackageNodesAsSerializedFlowResultSuccess,
|
|
1367
|
+
) -> dict[str, Any]:
|
|
1368
|
+
"""Collect parameter values from resolved upstream nodes outside the loop.
|
|
1369
|
+
|
|
1370
|
+
When nodes inside the loop have connections to nodes outside that have already
|
|
1371
|
+
executed (RESOLVED state), we need to pass those values into the packaged flow
|
|
1372
|
+
via the StartFlow node parameters.
|
|
1373
|
+
|
|
1374
|
+
Args:
|
|
1375
|
+
packaged_node_names: List of node names being packaged in the loop
|
|
1376
|
+
package_result: PackageNodesAsSerializedFlowResultSuccess containing parameter_name_mappings
|
|
1377
|
+
|
|
1378
|
+
Returns:
|
|
1379
|
+
Dict mapping startflow_param_name -> value from resolved upstream node
|
|
1380
|
+
"""
|
|
1381
|
+
flow_manager = GriptapeNodes.FlowManager()
|
|
1382
|
+
connections = flow_manager.get_connections()
|
|
1383
|
+
node_manager = GriptapeNodes.NodeManager()
|
|
1384
|
+
|
|
1385
|
+
# Get Start node's parameter mappings (index 0 in the list)
|
|
1386
|
+
start_node_mapping = self.get_node_parameter_mappings(package_result, "start")
|
|
1387
|
+
start_node_param_mappings = start_node_mapping.parameter_mappings
|
|
1388
|
+
|
|
1389
|
+
resolved_upstream_values = {}
|
|
1390
|
+
|
|
1391
|
+
# For each packaged node, check its incoming data connections
|
|
1392
|
+
for packaged_node_name in packaged_node_names:
|
|
1393
|
+
try:
|
|
1394
|
+
packaged_node = node_manager.get_node_by_name(packaged_node_name)
|
|
1395
|
+
except Exception:
|
|
1396
|
+
logger.warning("Could not find packaged node '%s' to check upstream connections", packaged_node_name)
|
|
1397
|
+
continue
|
|
1398
|
+
|
|
1399
|
+
# Check each parameter for incoming connections
|
|
1400
|
+
for param in packaged_node.parameters:
|
|
1401
|
+
# Skip control parameters
|
|
1402
|
+
if param.type == ParameterTypeBuiltin.CONTROL_TYPE:
|
|
1403
|
+
continue
|
|
1404
|
+
|
|
1405
|
+
# Get upstream connection
|
|
1406
|
+
upstream_connection = connections.get_connected_node(packaged_node, param)
|
|
1407
|
+
if not upstream_connection:
|
|
1408
|
+
continue
|
|
1409
|
+
|
|
1410
|
+
upstream_node, upstream_param = upstream_connection
|
|
1411
|
+
|
|
1412
|
+
# Get upstream value if it meets criteria (resolved, not internal)
|
|
1413
|
+
upstream_value = self._get_upstream_connection_value(upstream_node, upstream_param, packaged_node_names)
|
|
1414
|
+
if upstream_value is None:
|
|
1415
|
+
continue
|
|
1416
|
+
|
|
1417
|
+
# Find the corresponding StartFlow parameter name
|
|
1418
|
+
startflow_param_name = self._map_to_startflow_parameter(
|
|
1419
|
+
packaged_node_name, param.name, start_node_param_mappings
|
|
1420
|
+
)
|
|
1421
|
+
if startflow_param_name:
|
|
1422
|
+
resolved_upstream_values[startflow_param_name] = upstream_value
|
|
1423
|
+
logger.debug(
|
|
1424
|
+
"Collected resolved upstream value: %s.%s -> StartFlow.%s = %s",
|
|
1425
|
+
upstream_node.name,
|
|
1426
|
+
upstream_param.name,
|
|
1427
|
+
startflow_param_name,
|
|
1428
|
+
upstream_value,
|
|
1429
|
+
)
|
|
1430
|
+
|
|
1431
|
+
logger.info("Collected %d resolved upstream values for loop execution", len(resolved_upstream_values))
|
|
1432
|
+
return resolved_upstream_values
|
|
1433
|
+
|
|
1434
|
+
def _get_upstream_connection_value(
|
|
1435
|
+
self,
|
|
1436
|
+
upstream_node: BaseNode,
|
|
1437
|
+
upstream_param: Any,
|
|
1438
|
+
packaged_node_names: list[str],
|
|
1439
|
+
) -> Any | None:
|
|
1440
|
+
"""Extract value from upstream node if it meets criteria.
|
|
1441
|
+
|
|
1442
|
+
Args:
|
|
1443
|
+
upstream_node: The upstream node that provides the value
|
|
1444
|
+
upstream_param: The parameter on the upstream node
|
|
1445
|
+
packaged_node_names: List of packaged node names to exclude internal connections
|
|
1446
|
+
|
|
1447
|
+
Returns:
|
|
1448
|
+
The upstream value if criteria met, None otherwise
|
|
1449
|
+
"""
|
|
1450
|
+
if upstream_node.state != NodeResolutionState.RESOLVED:
|
|
1451
|
+
return None
|
|
1452
|
+
|
|
1453
|
+
if upstream_node.name in packaged_node_names:
|
|
1454
|
+
return None
|
|
1455
|
+
|
|
1456
|
+
if upstream_param.name in upstream_node.parameter_output_values:
|
|
1457
|
+
return upstream_node.parameter_output_values[upstream_param.name]
|
|
1458
|
+
|
|
1459
|
+
return upstream_node.get_parameter_value(upstream_param.name)
|
|
1460
|
+
|
|
1461
|
+
def _map_to_startflow_parameter(
|
|
1462
|
+
self,
|
|
1463
|
+
packaged_node_name: str,
|
|
1464
|
+
param_name: str,
|
|
1465
|
+
start_node_param_mappings: dict[str, Any],
|
|
1466
|
+
) -> str | None:
|
|
1467
|
+
"""Find the StartFlow parameter name that maps to a packaged node parameter.
|
|
1468
|
+
|
|
1469
|
+
Args:
|
|
1470
|
+
packaged_node_name: Name of the packaged node
|
|
1471
|
+
param_name: Name of the parameter on the packaged node
|
|
1472
|
+
start_node_param_mappings: Dict mapping startflow_param_name -> OriginalNodeParameter
|
|
1473
|
+
|
|
1474
|
+
Returns:
|
|
1475
|
+
The StartFlow parameter name if found, None otherwise
|
|
1476
|
+
"""
|
|
1477
|
+
for startflow_param_name, original_node_param in start_node_param_mappings.items():
|
|
1478
|
+
if original_node_param.node_name == packaged_node_name and original_node_param.parameter_name == param_name:
|
|
1479
|
+
return startflow_param_name
|
|
1480
|
+
return None
|
|
1481
|
+
|
|
1482
|
+
def _find_endflow_param_for_end_loop_node(
    self,
    incoming_connections: list,
    end_node_param_mappings: dict,
) -> str | None:
    """Find the EndFlow parameter name that corresponds to BaseIterativeEndNode's new_item_to_add.

    Walks the incoming connections of the end-loop node looking for the one
    feeding its "new_item_to_add" parameter. When the connection's source is a
    NodeGroup proxy, the node-group-internal connection is followed to the real
    producing node before matching against the EndFlow parameter mappings.

    Args:
        incoming_connections: List of incoming connections to end_loop_node
        end_node_param_mappings: Parameter mappings from EndFlow node
            (sanitized name -> OriginalNodeParameter)

    Returns:
        Sanitized parameter name on EndFlow node, or None if not found
    """
    # Manager lookups are loop-invariant; resolve them once up front instead of
    # re-fetching them for every matching connection.
    node_manager = GriptapeNodes.NodeManager()
    flow_manager = GriptapeNodes.FlowManager()

    for conn in incoming_connections:
        # Only the connection into "new_item_to_add" is relevant.
        if conn.target_parameter_name != "new_item_to_add":
            continue

        source_node_name = conn.source_node_name
        source_param_name = conn.source_parameter_name

        try:
            source_node = node_manager.get_node_by_name(source_node_name)
        except ValueError:
            # Source node no longer exists; try the next connection.
            continue

        if isinstance(source_node, NodeGroupNode):
            # The source is a NodeGroup proxy parameter: follow the internal
            # connection to find the actual node that produces the value.
            connections = flow_manager.get_connections()
            proxy_param = source_node.get_parameter_by_name(source_param_name)
            if proxy_param:
                internal_connections = connections.get_all_incoming_connections(source_node)
                for internal_conn in internal_connections:
                    if (
                        internal_conn.target_parameter.name == source_param_name
                        and internal_conn.is_node_group_internal
                    ):
                        source_node_name = internal_conn.source_node.name
                        source_param_name = internal_conn.source_parameter.name
                        break

        # Find the EndFlow parameter that corresponds to this source.
        for sanitized_param_name, original_node_param in end_node_param_mappings.items():
            if (
                original_node_param.node_name == source_node_name
                and original_node_param.parameter_name == source_param_name
            ):
                return sanitized_param_name

    return None
|
|
1533
|
+
def get_node_parameter_mappings(
    self, package_result: PackageNodesAsSerializedFlowResultSuccess, start_or_end: str
) -> PackagedNodeParameterMapping:
    """Return the Start or End node parameter mapping from a packaged flow result.

    The package result stores the Start mapping at index 0 and the End mapping
    at index 1 of parameter_name_mappings.

    Args:
        package_result: Packaged flow result holding parameter_name_mappings.
        start_or_end: Which mapping to return; case-insensitive "start" or "end".

    Returns:
        The requested PackagedNodeParameterMapping.

    Raises:
        ValueError: If start_or_end is neither "start" nor "end".
    """
    which = start_or_end.lower()  # normalize once instead of per-comparison
    if which == "start":
        return package_result.parameter_name_mappings[0]
    if which == "end":
        return package_result.parameter_name_mappings[1]
    msg = f"start_or_end must be 'start' or 'end', got {start_or_end}"
    raise ValueError(msg)
|
|
1543
|
+
def get_parameter_values_from_iterations(
    self,
    end_loop_node: BaseIterativeEndNode,
    deserialized_flows: list[tuple[int, str, dict[str, str]]],
    package_flow_result_success: PackageNodesAsSerializedFlowResultSuccess,
) -> dict[int, Any]:
    """Extract parameter values from each iteration's EndFlow node.

    The BaseIterativeEndNode itself is never packaged, so its per-iteration
    value is read from the packaged EndFlow node instead: locate the EndFlow
    parameter wired into the loop node's "new_item_to_add" input, then read
    that parameter's output from every iteration's deserialized EndFlow node.

    Args:
        end_loop_node: The End Loop Node (NOT packaged, just used for reference)
        deserialized_flows: List of (iteration_index, flow_name, node_name_mappings)
        package_flow_result_success: PackageNodesAsSerializedFlowResultSuccess containing parameter_name_mappings

    Returns:
        Dict mapping iteration_index -> value for that iteration

    Raises:
        RuntimeError: If the loop node's connections cannot be listed.
    """
    # Discover which EndFlow parameter feeds the loop node's result input.
    connections_result = GriptapeNodes.handle_request(
        ListConnectionsForNodeRequest(node_name=end_loop_node.name)
    )
    if not isinstance(connections_result, ListConnectionsForNodeResultSuccess):
        msg = f"Failed to list connections for node {end_loop_node.name}: {connections_result.result_details}"
        raise RuntimeError(msg)  # noqa: TRY004

    end_node_mapping = self.get_node_parameter_mappings(package_flow_result_success, "end")
    endflow_param_name = self._find_endflow_param_for_end_loop_node(
        connections_result.incoming_connections, end_node_mapping.parameter_mappings
    )
    if endflow_param_name is None:
        logger.warning(
            "No connections found to BaseIterativeEndNode '%s' new_item_to_add parameter. No results will be collected.",
            end_loop_node.name,
        )
        return {}

    # Read that parameter's output value from every iteration's EndFlow node.
    packaged_end_node_name = end_node_mapping.node_name
    node_manager = GriptapeNodes.NodeManager()
    results: dict[int, Any] = {}

    for iteration_index, flow_name, name_mappings in deserialized_flows:
        end_node_name = name_mappings.get(packaged_end_node_name)
        if end_node_name is None:
            logger.warning(
                "Could not find deserialized End node for iteration %d in flow '%s'",
                iteration_index,
                flow_name,
            )
            continue

        try:
            end_node = node_manager.get_node_by_name(end_node_name)
            outputs = end_node.parameter_output_values
            if endflow_param_name in outputs:
                results[iteration_index] = outputs[endflow_param_name]
        except Exception as e:
            # Best-effort: a broken iteration should not sink the others.
            logger.warning(
                "Failed to extract result from End node for iteration %d: %s",
                iteration_index,
                e,
            )

    logger.info("Collected %d resolved upstream values for loop execution", len(results)) if False else None
    return results
|
|
1618
|
+
def get_last_iteration_values_for_packaged_nodes(
    self,
    deserialized_flows: list[tuple[int, str, dict[str, str]]],
    package_result: PackageNodesAsSerializedFlowResultSuccess,
    total_iterations: int,
) -> dict[str, Any]:
    """Extract parameter values from the LAST iteration's End Flow node for all output parameters.

    The returned dict has the same shape as _extract_parameter_output_values()
    and is ready to pass to _apply_parameter_values_to_node(), so the packaged
    nodes reflect their final state once the loop completes.

    Args:
        deserialized_flows: List of (iteration_index, flow_name, node_name_mappings)
        package_result: PackageNodesAsSerializedFlowResultSuccess containing parameter mappings
        total_iterations: Total number of iterations that were executed

    Returns:
        Dict mapping sanitized parameter names -> values from last iteration's End node
    """
    if total_iterations == 0:
        return {}

    last_index = total_iterations - 1

    # Locate the entry for the final iteration.
    last_entry = next((entry for entry in deserialized_flows if entry[0] == last_index), None)
    if last_entry is None:
        logger.warning(
            "Could not find last iteration (index %d) in deserialized flows. Cannot extract final values.",
            last_index,
        )
        return {}

    # Resolve the deserialized End node's name for that iteration.
    end_node_mapping = self.get_node_parameter_mappings(package_result, "end")
    packaged_end_node_name = end_node_mapping.node_name
    _, _, name_mappings = last_entry
    end_node_name = name_mappings.get(packaged_end_node_name)
    if end_node_name is None:
        logger.warning(
            "Could not find deserialized End node (packaged name: '%s') in last iteration",
            packaged_end_node_name,
        )
        return {}

    try:
        end_node = GriptapeNodes.NodeManager().get_node_by_name(end_node_name)
    except Exception as e:
        logger.warning("Failed to get End node '%s' for last iteration: %s", end_node_name, e)
        return {}

    # Collect every mapped parameter that actually produced an output value,
    # keyed by the sanitized name as it appears on the End node.
    outputs = end_node.parameter_output_values
    final_values = {
        name: outputs[name]
        for name in end_node_mapping.parameter_mappings
        if name in outputs
    }

    logger.debug(
        "Extracted %d parameter values from last iteration's End node '%s'",
        len(final_values),
        end_node_name,
    )

    return final_values
|
|
1696
|
+
async def _execute_loop_iterations_locally(  # noqa: C901, PLR0912, PLR0915
    self,
    package_result: PackageNodesAsSerializedFlowResultSuccess,
    total_iterations: int,
    parameter_values_per_iteration: dict[int, dict[str, Any]],
    end_loop_node: BaseIterativeEndNode,
) -> tuple[dict[int, Any], list[int], dict[str, Any]]:
    """Execute loop iterations locally by deserializing and running flows.

    This method handles LOCAL execution of loop iterations. Other libraries
    can implement their own execution strategies (cloud, remote, etc.) by
    creating similar methods with the same signature.

    Args:
        package_result: The packaged flow with parameter mappings
        total_iterations: Number of iterations to run
        parameter_values_per_iteration: Dict mapping iteration_index -> parameter values
        end_loop_node: The End Loop Node to extract results for

    Returns:
        Tuple of:
        - iteration_results: Dict mapping iteration_index -> result value
        - successful_iterations: List of iteration indices that succeeded
        - last_iteration_values: Dict mapping parameter names -> values from last iteration

    Raises:
        TypeError: If any iteration flow fails to deserialize.
        RuntimeError: If one or more iterations fail to execute.
    """
    # Step 1: Deserialize N flow instances from the serialized flow.
    # Save the current context and restore it after each deserialization to prevent
    # iteration flows from becoming children of each other.
    deserialized_flows: list[tuple[int, str, dict[str, str]]] = []
    context_manager = GriptapeNodes.ContextManager()
    saved_context_flow = context_manager.get_current_flow() if context_manager.has_current_flow() else None

    # Suppress events during deserialization to prevent sending them to websockets.
    event_manager = GriptapeNodes.EventManager()
    with EventSuppressionContext(event_manager, LOOP_EVENTS_TO_SUPPRESS):
        for iteration_index in range(total_iterations):
            # Restore context before each deserialization to ensure all iteration flows
            # are created at the same level (not as children of each other).
            if saved_context_flow is not None:
                # Pop any flows that were pushed during previous iteration.
                while (
                    context_manager.has_current_flow() and context_manager.get_current_flow() != saved_context_flow
                ):
                    context_manager.pop_flow()

            deserialize_request = DeserializeFlowFromCommandsRequest(
                serialized_flow_commands=package_result.serialized_flow_commands
            )
            deserialize_result = GriptapeNodes.handle_request(deserialize_request)
            if not isinstance(deserialize_result, DeserializeFlowFromCommandsResultSuccess):
                msg = f"Failed to deserialize flow for iteration {iteration_index}. Error: {deserialize_result.result_details}"
                raise TypeError(msg)

            deserialized_flows.append(
                (iteration_index, deserialize_result.flow_name, deserialize_result.node_name_mappings)
            )

            # Pop the deserialized flow from the context stack to prevent it from staying there.
            # Deserialization pushes the flow onto the stack, but we don't want iteration flows
            # to remain on the stack after deserialization.
            if (
                context_manager.has_current_flow()
                and context_manager.get_current_flow().name == deserialize_result.flow_name
            ):
                context_manager.pop_flow()
    logger.info("Successfully deserialized %d flow instances for parallel execution", total_iterations)
    # Step 2: Set input values on start nodes for each iteration.
    for iteration_index, _, node_name_mappings in deserialized_flows:
        parameter_values = parameter_values_per_iteration[iteration_index]

        # Get Start node mapping (index 0 in the list).
        # NOTE(review): this mapping is loop-invariant and could be hoisted above the loop.
        start_node_mapping = self.get_node_parameter_mappings(package_result, "start")
        start_node_name = start_node_mapping.node_name
        start_params = start_node_mapping.parameter_mappings

        # Find the deserialized name for the Start node.
        deserialized_start_node_name = node_name_mappings.get(start_node_name)
        if deserialized_start_node_name is None:
            logger.warning(
                "Could not find deserialized Start node (original: '%s') for iteration %d",
                start_node_name,
                iteration_index,
            )
            continue

        # Set all parameter values on the deserialized Start node.
        for startflow_param_name in start_params:
            if startflow_param_name not in parameter_values:
                continue

            value_to_set = parameter_values[startflow_param_name]

            set_value_request = SetParameterValueRequest(
                node_name=deserialized_start_node_name,
                parameter_name=startflow_param_name,
                value=value_to_set,
            )
            set_value_result = await GriptapeNodes.ahandle_request(set_value_request)
            if not isinstance(set_value_result, SetParameterValueResultSuccess):
                # Best-effort: log and keep going so one bad parameter doesn't abort the loop.
                logger.warning(
                    "Failed to set parameter '%s' on Start node '%s' for iteration %d: %s",
                    startflow_param_name,
                    deserialized_start_node_name,
                    iteration_index,
                    set_value_result.result_details,
                )

    logger.info("Successfully set input values for %d iterations", total_iterations)

    # Step 3: Run all flows concurrently.
    packaged_start_node_name = self.get_node_parameter_mappings(package_result, "start").node_name

    async def run_single_iteration(flow_name: str, iteration_index: int, start_node_name: str) -> tuple[int, bool]:
        """Run a single iteration flow and return success status."""
        # Suppress execution events during parallel iteration to prevent flooding websockets.
        with EventSuppressionContext(event_manager, EXECUTION_EVENTS_TO_SUPPRESS):
            start_subflow_request = StartLocalSubflowRequest(
                flow_name=flow_name,
                start_node=start_node_name,
                pickle_control_flow_result=False,
            )
            start_subflow_result = await GriptapeNodes.ahandle_request(start_subflow_request)
            success = isinstance(start_subflow_result, StartLocalSubflowResultSuccess)
            return iteration_index, success

    try:
        # Run all iterations concurrently; return_exceptions keeps one failure
        # from cancelling the sibling iterations.
        iteration_tasks = [
            run_single_iteration(
                flow_name,
                iteration_index,
                node_name_mappings.get(packaged_start_node_name),
            )
            for iteration_index, flow_name, node_name_mappings in deserialized_flows
        ]
        iteration_results = await asyncio.gather(*iteration_tasks, return_exceptions=True)

        # Step 4: Collect successful and failed iterations.
        successful_iterations = []
        failed_iterations = []

        for result in iteration_results:
            if isinstance(result, Exception):
                failed_iterations.append(result)
                continue
            if isinstance(result, tuple):
                iteration_index, success = result
                if success:
                    successful_iterations.append(iteration_index)
                else:
                    failed_iterations.append(iteration_index)

        if failed_iterations:
            msg = f"Loop execution failed: {len(failed_iterations)} of {total_iterations} iterations failed"
            raise RuntimeError(msg)

        # Step 5: Extract parameter values from iterations BEFORE cleanup.
        iteration_results = self.get_parameter_values_from_iterations(
            end_loop_node=end_loop_node,
            deserialized_flows=deserialized_flows,
            package_flow_result_success=package_result,
        )

        # Step 6: Extract last iteration values BEFORE cleanup (flows deleted in finally block).
        last_iteration_values = self.get_last_iteration_values_for_packaged_nodes(
            deserialized_flows=deserialized_flows,
            package_result=package_result,
            total_iterations=total_iterations,
        )

        return iteration_results, successful_iterations, last_iteration_values

    finally:
        # Cleanup: delete all iteration flows, even when extraction above raised.
        # Suppress events during deletion to prevent sending them to websockets.
        with EventSuppressionContext(event_manager, {DeleteFlowResultSuccess, DeleteFlowResultFailure}):
            for iteration_index, flow_name, _ in deserialized_flows:
                delete_request = DeleteFlowRequest(flow_name=flow_name)
                delete_result = await GriptapeNodes.ahandle_request(delete_request)
                if not isinstance(delete_result, DeleteFlowResultSuccess):
                    logger.warning(
                        "Failed to delete iteration flow '%s' (iteration %d): %s",
                        flow_name,
                        iteration_index,
                        delete_result.result_details,
                    )
|
1883
|
+
async def _execute_loop_iterations_via_subprocess(  # noqa: PLR0913
    self,
    package_result: PackageNodesAsSerializedFlowResultSuccess,
    total_iterations: int,
    parameter_values_per_iteration: dict[int, dict[str, Any]],
    end_loop_node: BaseIterativeEndNode,
    workflow_path: Path,
    workflow_result: Any,  # noqa: ARG002 - Used by wrapper methods for cleanup
    file_name_prefix: str,
    execution_type: str,
    *,
    run_sequentially: bool,
) -> tuple[dict[int, Any], list[int], dict[str, Any]]:
    """Execute loop iterations via subprocess (unified helper for private/cloud execution).

    This unified helper handles both sequential and parallel execution modes for
    workflows that run as subprocesses (PRIVATE or CLOUD publishers). Both modes
    share a single per-iteration coroutine: sequential mode awaits it
    one-at-a-time, parallel mode gathers all iterations concurrently.

    Args:
        package_result: The packaged flow with parameter mappings
        total_iterations: Number of iterations to run
        parameter_values_per_iteration: Dict mapping iteration_index -> parameter values
        end_loop_node: The End Loop Node to extract results for
        workflow_path: Path to the saved/published workflow file
        workflow_result: Result from saving/publishing the workflow
        file_name_prefix: Prefix for iteration-specific file names
        execution_type: Human-readable execution mode name for logging
        run_sequentially: If True, run iterations one-at-a-time; if False, run concurrently

    Returns:
        Tuple of (iteration_results, successful_iterations, last_iteration_values)
    """
    # If it's private execution, we aren't republishing it in a library.
    # So our original package is what is running, and we can count on using these mappings.
    if execution_type == PRIVATE_EXECUTION:
        start_node_mapping = self.get_node_parameter_mappings(package_result, "start")
        start_node_name = start_node_mapping.node_name
    # For published libraries, we need to get the new Start Node name, based on what their registered nodes are.
    else:
        library = LibraryRegistry.get_library(execution_type)
        node_details = await self._get_workflow_start_end_nodes(library)
        start_node_type = node_details.start_flow_node_type
        node_metadata = library.get_node_metadata(start_node_type)
        start_node_name = node_metadata.display_name

    mode_str = "sequentially" if run_sequentially else "concurrently"
    logger.info(
        "Executing %d iterations %s in %s for loop '%s'",
        total_iterations,
        mode_str,
        execution_type,
        end_loop_node.name,
    )

    async def run_single_iteration(iteration_index: int) -> tuple[int, bool, dict[str, Any] | None]:
        """Run one iteration in a subprocess; returns (index, success, result) and never raises."""
        try:
            flow_input = {start_node_name: parameter_values_per_iteration[iteration_index]}
            logger.info(
                "Executing iteration %d/%d for loop '%s'",
                iteration_index + 1,
                total_iterations,
                end_loop_node.name,
            )

            subprocess_result = await self._execute_subprocess(
                published_workflow_filename=workflow_path,
                file_name=f"{file_name_prefix}_iteration_{iteration_index}",
                pickle_control_flow_result=True,
                flow_input=flow_input,
            )
        except Exception:
            logger.exception("Iteration %d failed for loop '%s'", iteration_index, end_loop_node.name)
            return iteration_index, False, None
        else:
            return iteration_index, True, subprocess_result

    # Sequential and concurrent modes share the exact same per-iteration body;
    # only the scheduling differs.
    iteration_outputs: list[tuple[int, bool, dict[str, Any] | None]]
    if run_sequentially:
        # Await each iteration before starting the next.
        iteration_outputs = [await run_single_iteration(i) for i in range(total_iterations)]
    else:
        # Launch every iteration at once; run_single_iteration never raises,
        # so gather() cannot surface exceptions here.
        iteration_outputs = await asyncio.gather(*[run_single_iteration(i) for i in range(total_iterations)])

    # Extract results
    iteration_results, successful_iterations, last_iteration_values = (
        self._extract_iteration_results_from_subprocess(
            iteration_outputs=iteration_outputs,
            package_result=package_result,
            end_loop_node=end_loop_node,
        )
    )

    logger.info(
        "Successfully completed %d/%d iterations %s in %s for loop '%s'",
        len(successful_iterations),
        total_iterations,
        mode_str,
        execution_type,
        end_loop_node.name,
    )

    # Cleanup (workflow file deletion) is handled by the wrapper methods.
    return iteration_results, successful_iterations, last_iteration_values
|
|
2011
|
+
async def _execute_loop_iterations_sequentially_private(
    self,
    package_result: PackageNodesAsSerializedFlowResultSuccess,
    total_iterations: int,
    parameter_values_per_iteration: dict[int, dict[str, Any]],
    end_loop_node: BaseIterativeEndNode,
) -> tuple[dict[int, Any], list[int], dict[str, Any]]:
    """Execute loop iterations sequentially in private subprocesses (no cloud publishing)."""
    # Persist the packaged flow to disk so each iteration can run it as a subprocess.
    workflow_path, workflow_result = await self._save_workflow_file_for_loop(
        end_loop_node=end_loop_node,
        package_result=package_result,
        pickle_control_flow_result=True,
    )
    loop_name = end_loop_node.name.replace(" ", "_")

    try:
        return await self._execute_loop_iterations_via_subprocess(
            package_result=package_result,
            total_iterations=total_iterations,
            parameter_values_per_iteration=parameter_values_per_iteration,
            end_loop_node=end_loop_node,
            workflow_path=workflow_path,
            workflow_result=workflow_result,
            file_name_prefix=f"{loop_name}_private_sequential_loop_flow",
            execution_type=PRIVATE_EXECUTION,
            run_sequentially=True,
        )
    finally:
        # Best-effort cleanup of the temporary workflow file.
        try:
            await self._delete_workflow(
                workflow_name=workflow_result.workflow_metadata.name, workflow_path=workflow_path
            )
        except Exception as e:
            logger.warning("Failed to cleanup workflow file: %s", e)
|
|
2047
|
+
async def _execute_loop_iterations_privately(
    self,
    package_result: PackageNodesAsSerializedFlowResultSuccess,
    total_iterations: int,
    parameter_values_per_iteration: dict[int, dict[str, Any]],
    end_loop_node: BaseIterativeEndNode,
) -> tuple[dict[int, Any], list[int], dict[str, Any]]:
    """Execute loop iterations in parallel via private subprocesses (no cloud publishing)."""
    # Persist the packaged flow to disk so each iteration can run it as a subprocess.
    workflow_path, workflow_result = await self._save_workflow_file_for_loop(
        end_loop_node=end_loop_node,
        package_result=package_result,
        pickle_control_flow_result=True,
    )
    loop_name = end_loop_node.name.replace(" ", "_")

    try:
        return await self._execute_loop_iterations_via_subprocess(
            package_result=package_result,
            total_iterations=total_iterations,
            parameter_values_per_iteration=parameter_values_per_iteration,
            end_loop_node=end_loop_node,
            workflow_path=workflow_path,
            workflow_result=workflow_result,
            file_name_prefix=f"{loop_name}_private_loop_flow",
            execution_type=PRIVATE_EXECUTION,
            run_sequentially=False,
        )
    finally:
        # Best-effort cleanup of the temporary workflow file.
        try:
            await self._delete_workflow(
                workflow_name=workflow_result.workflow_metadata.name, workflow_path=workflow_path
            )
        except Exception as e:
            logger.warning("Failed to cleanup workflow file: %s", e)
|
|
2083
|
+
async def _save_workflow_file_for_loop(
    self,
    end_loop_node: BaseIterativeEndNode,
    package_result: PackageNodesAsSerializedFlowResultSuccess,
    *,
    pickle_control_flow_result: bool,
) -> tuple[Path, Any]:
    """Save workflow file for loop execution.

    Args:
        end_loop_node: The end loop node
        package_result: The packaged flow
        pickle_control_flow_result: Whether to pickle the control flow result

    Returns:
        Tuple of (workflow_path, workflow_result)

    Raises:
        TypeError: If the workflow file could not be saved.
    """
    # Derive a filesystem-safe file name from the loop node's name.
    save_request = SaveWorkflowFileFromSerializedFlowRequest(
        file_name=f"{end_loop_node.name.replace(' ', '_')}_private_loop_flow",
        serialized_flow_commands=package_result.serialized_flow_commands,
        workflow_shape=package_result.workflow_shape,
        pickle_control_flow_result=pickle_control_flow_result,
    )

    save_result = await GriptapeNodes.ahandle_request(save_request)
    if not isinstance(save_result, SaveWorkflowFileFromSerializedFlowResultSuccess):
        msg = f"Failed to save workflow file for private loop execution: {save_result.result_details}"
        raise TypeError(msg)

    saved_path = Path(save_result.file_path)
    logger.info("Saved workflow to '%s'", saved_path)

    return saved_path, save_result
|
|
2120
|
+
def _extract_iteration_results_from_subprocess(
|
|
2121
|
+
self,
|
|
2122
|
+
iteration_outputs: list[tuple[int, bool, dict[str, Any] | None]],
|
|
2123
|
+
package_result: PackageNodesAsSerializedFlowResultSuccess,
|
|
2124
|
+
end_loop_node: BaseIterativeEndNode,
|
|
2125
|
+
) -> tuple[dict[int, Any], list[int], dict[str, Any]]:
|
|
2126
|
+
"""Extract results from subprocess iteration outputs.
|
|
2127
|
+
|
|
2128
|
+
Args:
|
|
2129
|
+
iteration_outputs: List of (iteration_index, success, subprocess_result) tuples
|
|
2130
|
+
package_result: The packaged flow
|
|
2131
|
+
end_loop_node: The end loop node
|
|
2132
|
+
|
|
2133
|
+
Returns:
|
|
2134
|
+
Tuple of (iteration_results, successful_iterations, last_iteration_values)
|
|
2135
|
+
"""
|
|
2136
|
+
successful_iterations = []
|
|
2137
|
+
iteration_subprocess_outputs = {}
|
|
2138
|
+
|
|
2139
|
+
for iteration_index, success, subprocess_result in iteration_outputs:
|
|
2140
|
+
if success and subprocess_result is not None:
|
|
2141
|
+
successful_iterations.append(iteration_index)
|
|
2142
|
+
iteration_subprocess_outputs[iteration_index] = subprocess_result
|
|
2143
|
+
|
|
2144
|
+
# Extract the actual result values from subprocess outputs
|
|
2145
|
+
end_node_mapping = self.get_node_parameter_mappings(package_result, "end")
|
|
2146
|
+
end_node_param_mappings = end_node_mapping.parameter_mappings
|
|
2147
|
+
|
|
2148
|
+
# Find which EndFlow parameter corresponds to new_item_to_add
|
|
2149
|
+
list_connections_request = ListConnectionsForNodeRequest(node_name=end_loop_node.name)
|
|
2150
|
+
list_connections_result = GriptapeNodes.handle_request(list_connections_request)
|
|
2151
|
+
|
|
2152
|
+
endflow_param_name = None
|
|
2153
|
+
if isinstance(list_connections_result, ListConnectionsForNodeResultSuccess):
|
|
2154
|
+
endflow_param_name = self._find_endflow_param_for_end_loop_node(
|
|
2155
|
+
list_connections_result.incoming_connections, end_node_param_mappings
|
|
2156
|
+
)
|
|
2157
|
+
|
|
2158
|
+
# Extract iteration results from subprocess outputs
|
|
2159
|
+
iteration_results = {}
|
|
2160
|
+
for iteration_index in successful_iterations:
|
|
2161
|
+
subprocess_result = iteration_subprocess_outputs[iteration_index]
|
|
2162
|
+
parameter_output_values = self._extract_parameter_output_values(subprocess_result)
|
|
2163
|
+
|
|
2164
|
+
if endflow_param_name and endflow_param_name in parameter_output_values:
|
|
2165
|
+
iteration_results[iteration_index] = parameter_output_values[endflow_param_name]
|
|
2166
|
+
|
|
2167
|
+
# Get last iteration values from the last successful iteration
|
|
2168
|
+
last_iteration_values = {}
|
|
2169
|
+
if successful_iterations:
|
|
2170
|
+
last_iteration_index = max(successful_iterations)
|
|
2171
|
+
last_subprocess_result = iteration_subprocess_outputs[last_iteration_index]
|
|
2172
|
+
last_iteration_values = self._extract_parameter_output_values(last_subprocess_result)
|
|
2173
|
+
|
|
2174
|
+
return iteration_results, successful_iterations, last_iteration_values
|
|
2175
|
+
|
|
2176
|
+
async def _execute_loop_iterations_sequentially_via_publisher(
|
|
2177
|
+
self,
|
|
2178
|
+
package_result: PackageNodesAsSerializedFlowResultSuccess,
|
|
2179
|
+
total_iterations: int,
|
|
2180
|
+
parameter_values_per_iteration: dict[int, dict[str, Any]],
|
|
2181
|
+
end_loop_node: BaseIterativeEndNode,
|
|
2182
|
+
execution_type: str,
|
|
2183
|
+
) -> tuple[dict[int, Any], list[int], dict[str, Any]]:
|
|
2184
|
+
"""Execute loop iterations sequentially via cloud publisher (Deadline Cloud, etc.)."""
|
|
2185
|
+
try:
|
|
2186
|
+
library = LibraryRegistry.get_library(name=execution_type)
|
|
2187
|
+
except KeyError:
|
|
2188
|
+
msg = f"Could not find library for execution environment {execution_type}"
|
|
2189
|
+
raise RuntimeError(msg) # noqa: B904
|
|
2190
|
+
|
|
2191
|
+
library_name = library.get_library_data().name
|
|
2192
|
+
sanitized_loop_name = end_loop_node.name.replace(" ", "_")
|
|
2193
|
+
file_name_prefix = f"{sanitized_loop_name}_{library_name.replace(' ', '_')}_sequential_loop_flow"
|
|
2194
|
+
|
|
2195
|
+
published_workflow_filename, workflow_result = await self._publish_workflow_for_loop_execution(
|
|
2196
|
+
package_result=package_result,
|
|
2197
|
+
library_name=library_name,
|
|
2198
|
+
file_name=file_name_prefix,
|
|
2199
|
+
)
|
|
2200
|
+
|
|
2201
|
+
try:
|
|
2202
|
+
return await self._execute_loop_iterations_via_subprocess(
|
|
2203
|
+
package_result=package_result,
|
|
2204
|
+
total_iterations=total_iterations,
|
|
2205
|
+
parameter_values_per_iteration=parameter_values_per_iteration,
|
|
2206
|
+
end_loop_node=end_loop_node,
|
|
2207
|
+
workflow_path=Path(published_workflow_filename),
|
|
2208
|
+
workflow_result=workflow_result,
|
|
2209
|
+
file_name_prefix=file_name_prefix,
|
|
2210
|
+
execution_type=execution_type,
|
|
2211
|
+
run_sequentially=True,
|
|
2212
|
+
)
|
|
2213
|
+
finally:
|
|
2214
|
+
await self._cleanup_published_workflows(
|
|
2215
|
+
workflow_result=workflow_result,
|
|
2216
|
+
published_workflow_filename=published_workflow_filename,
|
|
2217
|
+
)
|
|
2218
|
+
|
|
2219
|
+
    async def _execute_loop_iterations_via_publisher(
        self,
        package_result: PackageNodesAsSerializedFlowResultSuccess,
        total_iterations: int,
        parameter_values_per_iteration: dict[int, dict[str, Any]],
        end_loop_node: BaseIterativeEndNode,
        execution_type: str,
    ) -> tuple[dict[int, Any], list[int], dict[str, Any]]:
        """Execute loop iterations in parallel via cloud publisher (Deadline Cloud, etc.).

        Args:
            package_result: The packaged flow to publish and execute
            total_iterations: Number of loop iterations to run
            parameter_values_per_iteration: Per-iteration parameter values, keyed by iteration index
            end_loop_node: The end loop node
            execution_type: Name of the execution environment / publisher library

        Returns:
            Tuple of (iteration_results, successful_iterations, last_iteration_values)

        Raises:
            RuntimeError: If no library is registered for *execution_type*.
        """
        # The execution environment name is used to look up the publisher library.
        try:
            library = LibraryRegistry.get_library(name=execution_type)
        except KeyError:
            msg = f"Could not find library for execution environment {execution_type}"
            raise RuntimeError(msg)  # noqa: B904

        library_name = library.get_library_data().name
        # Spaces are not valid in workflow file names.
        sanitized_loop_name = end_loop_node.name.replace(" ", "_")
        file_name_prefix = f"{sanitized_loop_name}_{library_name.replace(' ', '_')}_loop_flow"

        published_workflow_filename, workflow_result = await self._publish_workflow_for_loop_execution(
            package_result=package_result,
            library_name=library_name,
            file_name=file_name_prefix,
        )

        # Cleanup of the published workflow files must run even if execution fails.
        try:
            return await self._execute_loop_iterations_via_subprocess(
                package_result=package_result,
                total_iterations=total_iterations,
                parameter_values_per_iteration=parameter_values_per_iteration,
                end_loop_node=end_loop_node,
                workflow_path=Path(published_workflow_filename),
                workflow_result=workflow_result,
                file_name_prefix=file_name_prefix,
                # NOTE(review): the sequential sibling passes execution_type=execution_type here,
                # while this passes the resolved library_name — confirm which is intended.
                execution_type=library_name,
                run_sequentially=False,
            )
        finally:
            await self._cleanup_published_workflows(
                workflow_result=workflow_result,
                published_workflow_filename=published_workflow_filename,
            )
|
+
async def _publish_workflow_for_loop_execution(
|
|
2263
|
+
self,
|
|
2264
|
+
package_result: PackageNodesAsSerializedFlowResultSuccess,
|
|
2265
|
+
library_name: str,
|
|
2266
|
+
file_name: str,
|
|
2267
|
+
) -> tuple[Path, Any]:
|
|
2268
|
+
"""Save and publish workflow for loop execution via publisher.
|
|
2269
|
+
|
|
2270
|
+
Args:
|
|
2271
|
+
package_result: The packaged flow
|
|
2272
|
+
library_name: Name of the library to publish to
|
|
2273
|
+
file_name: Base file name for the workflow
|
|
2274
|
+
|
|
2275
|
+
Returns:
|
|
2276
|
+
Tuple of (published_workflow_filename, workflow_result)
|
|
2277
|
+
"""
|
|
2278
|
+
workflow_file_request = SaveWorkflowFileFromSerializedFlowRequest(
|
|
2279
|
+
file_name=file_name,
|
|
2280
|
+
serialized_flow_commands=package_result.serialized_flow_commands,
|
|
2281
|
+
workflow_shape=package_result.workflow_shape,
|
|
2282
|
+
pickle_control_flow_result=True,
|
|
2283
|
+
)
|
|
2284
|
+
|
|
2285
|
+
workflow_result = await GriptapeNodes.ahandle_request(workflow_file_request)
|
|
2286
|
+
if not isinstance(workflow_result, SaveWorkflowFileFromSerializedFlowResultSuccess):
|
|
2287
|
+
msg = f"Failed to save workflow file for loop: {workflow_result.result_details}"
|
|
2288
|
+
raise RuntimeError(msg) # noqa: TRY004 - This is a runtime failure, not a type validation error
|
|
2289
|
+
|
|
2290
|
+
# Publish to the library
|
|
2291
|
+
published_workflow_filename = await self._publish_library_workflow(workflow_result, library_name, file_name)
|
|
2292
|
+
|
|
2293
|
+
logger.info("Successfully published workflow to '%s'", published_workflow_filename)
|
|
2294
|
+
|
|
2295
|
+
return published_workflow_filename, workflow_result
|
|
2296
|
+
|
|
2297
|
+
async def _cleanup_published_workflows(
|
|
2298
|
+
self,
|
|
2299
|
+
workflow_result: Any,
|
|
2300
|
+
published_workflow_filename: Path,
|
|
2301
|
+
) -> None:
|
|
2302
|
+
"""Clean up published workflow files.
|
|
2303
|
+
|
|
2304
|
+
Args:
|
|
2305
|
+
workflow_result: The workflow result containing metadata
|
|
2306
|
+
published_workflow_filename: Path to the published workflow file
|
|
2307
|
+
"""
|
|
2308
|
+
try:
|
|
2309
|
+
await self._delete_workflow(
|
|
2310
|
+
workflow_name=workflow_result.workflow_metadata.name,
|
|
2311
|
+
workflow_path=Path(workflow_result.file_path),
|
|
2312
|
+
)
|
|
2313
|
+
published_filename = published_workflow_filename.stem
|
|
2314
|
+
await self._delete_workflow(workflow_name=published_filename, workflow_path=published_workflow_filename)
|
|
2315
|
+
except Exception as e:
|
|
2316
|
+
logger.warning("Failed to cleanup workflow files: %s", e)
|
|
2317
|
+
|
|
2318
|
+
    def set_parameter_output_values_for_loops(
        self, subprocess_result: dict[str, dict[str | SerializedNodeCommands.UniqueParameterValueUUID, Any] | None]
    ) -> None:
        """Apply subprocess parameter output values during loop execution.

        Intentionally a no-op in this class; presumably a hook for subclasses or
        future use — TODO confirm intended override point.
        """
        pass
|
+
def _extract_parameter_output_values(
|
|
2324
|
+
self, subprocess_result: dict[str, dict[str | SerializedNodeCommands.UniqueParameterValueUUID, Any] | None]
|
|
2325
|
+
) -> dict[str, Any]:
|
|
2326
|
+
"""Extract and deserialize parameter output values from subprocess result.
|
|
2327
|
+
|
|
2328
|
+
Returns:
|
|
2329
|
+
Dictionary of parameter names to their deserialized values
|
|
2330
|
+
"""
|
|
2331
|
+
parameter_output_values = {}
|
|
2332
|
+
for result_dict in subprocess_result.values():
|
|
2333
|
+
# Handle backward compatibility: old flat structure
|
|
2334
|
+
if not isinstance(result_dict, dict) or "parameter_output_values" not in result_dict:
|
|
2335
|
+
parameter_output_values.update(result_dict) # type: ignore[arg-type]
|
|
2336
|
+
continue
|
|
2337
|
+
|
|
2338
|
+
param_output_vals = result_dict["parameter_output_values"]
|
|
2339
|
+
unique_uuid_to_values = result_dict.get("unique_parameter_uuid_to_values")
|
|
2340
|
+
|
|
2341
|
+
# No UUID mapping - use values directly
|
|
2342
|
+
if not unique_uuid_to_values:
|
|
2343
|
+
parameter_output_values.update(param_output_vals)
|
|
2344
|
+
continue
|
|
2345
|
+
|
|
2346
|
+
# Deserialize UUID-referenced values
|
|
2347
|
+
for param_name, param_value in param_output_vals.items():
|
|
2348
|
+
parameter_output_values[param_name] = self._deserialize_parameter_value(
|
|
2349
|
+
param_name, param_value, unique_uuid_to_values
|
|
2350
|
+
)
|
|
2351
|
+
return parameter_output_values
|
|
2352
|
+
|
|
2353
|
+
def _remove_packaged_nodes_from_queue(self, packaged_node_names: set[str]) -> None:
|
|
2354
|
+
"""Remove nodes from global flow queue after they've been packaged for loop execution.
|
|
2355
|
+
|
|
2356
|
+
When nodes are packaged for For Each loops, they will be deserialized into separate
|
|
2357
|
+
flow instances. We need to remove them from the global queue to prevent them from
|
|
2358
|
+
being executed in the main flow while also being copied into loop iterations.
|
|
2359
|
+
|
|
2360
|
+
Args:
|
|
2361
|
+
packaged_node_names: Set of node names that were packaged
|
|
2362
|
+
"""
|
|
2363
|
+
flow_manager = GriptapeNodes.FlowManager()
|
|
2364
|
+
node_manager = GriptapeNodes.NodeManager()
|
|
2365
|
+
|
|
2366
|
+
# Get the nodes from the names
|
|
2367
|
+
packaged_nodes = set()
|
|
2368
|
+
for node_name in packaged_node_names:
|
|
2369
|
+
node = node_manager.get_node_by_name(node_name)
|
|
2370
|
+
if node:
|
|
2371
|
+
packaged_nodes.add(node)
|
|
2372
|
+
|
|
2373
|
+
# Remove matching queue items from global queue
|
|
2374
|
+
items_to_remove = [item for item in flow_manager.global_flow_queue.queue if item.node in packaged_nodes]
|
|
2375
|
+
|
|
2376
|
+
for item in items_to_remove:
|
|
2377
|
+
flow_manager.global_flow_queue.queue.remove(item)
|
|
2378
|
+
|
|
2379
|
+
# Remove from DAG builder to prevent parallel execution in parent flow
|
|
2380
|
+
dag_builder = flow_manager.global_dag_builder
|
|
2381
|
+
if dag_builder:
|
|
2382
|
+
for node_name in packaged_node_names:
|
|
2383
|
+
# Remove from node_to_reference
|
|
2384
|
+
if node_name in dag_builder.node_to_reference:
|
|
2385
|
+
dag_builder.node_to_reference.pop(node_name)
|
|
2386
|
+
|
|
2387
|
+
# Remove from all networks and check if any become empty
|
|
2388
|
+
for network in list(dag_builder.graphs.values()):
|
|
2389
|
+
if node_name in network.nodes():
|
|
2390
|
+
network.remove_node(node_name)
|
|
2391
|
+
|
|
2392
|
+
def _deserialize_parameter_value(self, param_name: str, param_value: Any, unique_uuid_to_values: dict) -> Any:
|
|
2393
|
+
"""Deserialize a single parameter value, handling UUID references and pickling.
|
|
2394
|
+
|
|
2395
|
+
Args:
|
|
2396
|
+
param_name: Parameter name for logging
|
|
2397
|
+
param_value: Either a direct value or UUID reference
|
|
2398
|
+
unique_uuid_to_values: Mapping of UUIDs to pickled values
|
|
2399
|
+
|
|
2400
|
+
Returns:
|
|
2401
|
+
Deserialized parameter value
|
|
2402
|
+
"""
|
|
2403
|
+
# Direct value (not a UUID reference)
|
|
2404
|
+
if param_value not in unique_uuid_to_values:
|
|
2405
|
+
return param_value
|
|
2406
|
+
|
|
2407
|
+
stored_value = unique_uuid_to_values[param_value]
|
|
2408
|
+
|
|
2409
|
+
# Non-string stored values are used directly
|
|
2410
|
+
if not isinstance(stored_value, str):
|
|
2411
|
+
return stored_value
|
|
2412
|
+
|
|
2413
|
+
# Attempt to unpickle string-represented bytes
|
|
2414
|
+
try:
|
|
2415
|
+
actual_bytes = ast.literal_eval(stored_value)
|
|
2416
|
+
if isinstance(actual_bytes, bytes):
|
|
2417
|
+
return pickle.loads(actual_bytes) # noqa: S301
|
|
2418
|
+
except (ValueError, SyntaxError, pickle.UnpicklingError) as e:
|
|
2419
|
+
logger.warning(
|
|
2420
|
+
"Failed to unpickle string-represented bytes for parameter '%s': %s",
|
|
2421
|
+
param_name,
|
|
2422
|
+
e,
|
|
2423
|
+
)
|
|
2424
|
+
return stored_value
|
|
2425
|
+
return stored_value
|
|
2426
|
+
|
|
2427
|
+
    def _apply_parameter_values_to_node(
        self,
        node: BaseNode,
        parameter_output_values: dict[str, Any],
        package_result: PackageNodesAsSerializedFlowResultSuccess,
    ) -> None:
        """Apply deserialized parameter values back to the node.

        Sets parameter values on the node and updates parameter_output_values dictionary.
        Uses parameter_name_mappings from package_result to map packaged parameters back to original nodes.
        Works for both single-node and multi-node packages (NodeGroupNode).

        Args:
            node: The node (or NodeGroupNode) to receive the values
            parameter_output_values: Deserialized End-node parameter values keyed by packaged name
            package_result: Packaged flow result providing the parameter name mappings

        Raises:
            RuntimeError: If the packaged flow reported failure via its 'failed' control output.
        """
        # If the packaged flow fails, the End Flow Node in the library published workflow will have entered from 'failed'
        if "failed" in parameter_output_values and parameter_output_values["failed"] == CONTROL_INPUT_PARAMETER:
            msg = f"Failed to execute node: {node.name}, with exception: {parameter_output_values.get('result_details', 'No result details were returned.')}"
            raise RuntimeError(msg)

        # Use parameter mappings to apply values back to original nodes
        # Output values come from the End node (index 1 in the list)
        end_node_mapping = self.get_node_parameter_mappings(package_result, "end")
        end_node_param_mappings = end_node_mapping.parameter_mappings

        for param_name, param_value in parameter_output_values.items():
            # Check if this parameter has a mapping in the End node
            if param_name not in end_node_param_mappings:
                continue

            original_node_param = end_node_param_mappings[param_name]
            target_node_name = original_node_param.node_name
            target_param_name = original_node_param.parameter_name

            # Determine the target node - if this is a NodeGroupNode, look up the child node
            if isinstance(node, NodeGroupNode):
                if target_node_name not in node.nodes:
                    logger.warning(
                        "Node '%s' not found in NodeGroupNode '%s', skipping value application",
                        target_node_name,
                        node.name,
                    )
                    continue
                target_node = node.nodes[target_node_name]
            else:
                target_node = node

            # Get the parameter from the target node
            target_param = target_node.get_parameter_by_name(target_param_name)
            if target_param is None:
                logger.warning(
                    "Parameter '%s' not found on node '%s', skipping value application",
                    target_param_name,
                    target_node_name,
                )
                continue

            # Set the value on the target node
            # Provide source node/parameter to bypass connection conflict validation
            # These values are coming from execution results, treat as upstream values
            if target_param.type != ParameterTypeBuiltin.CONTROL_TYPE:
                GriptapeNodes.NodeManager().on_set_parameter_value_request(
                    SetParameterValueRequest(
                        node_name=target_node_name,
                        parameter_name=target_param_name,
                        value=param_value,
                        incoming_connection_source_node_name=node.name,
                        incoming_connection_source_parameter_name=target_param_name,
                    )
                )
                target_node.parameter_output_values[target_param_name] = param_value

                logger.debug(
                    "Set parameter '%s' on node '%s' to value: %s",
                    target_param_name,
                    target_node_name,
                    param_value,
                )
|
+
    def _apply_last_iteration_to_packaged_nodes(
        self,
        last_iteration_values: dict[str, Any],
        package_result: PackageNodesAsSerializedFlowResultSuccess,
    ) -> None:
        """Apply last iteration values to the original packaged nodes in main flow.

        After parallel loop execution, this sets the final state of each packaged node
        to match the last iteration's execution results. This is important for nodes that
        output values or produce artifacts during loop execution.

        Args:
            last_iteration_values: Dict mapping sanitized End node parameter names to values
            package_result: PackageNodesAsSerializedFlowResultSuccess containing parameter mappings and node names
        """
        # Nothing to do when the last iteration produced no values (e.g. zero iterations ran).
        if not last_iteration_values:
            logger.debug("No last iteration values to apply to packaged nodes")
            return

        # Get End node parameter mappings (index 1 in the list)
        end_node_mapping = self.get_node_parameter_mappings(package_result, "end")
        end_node_param_mappings = end_node_mapping.parameter_mappings

        node_manager = GriptapeNodes.NodeManager()

        # For each parameter in the End node, map it back to the original node and set the value
        for sanitized_param_name, param_value in last_iteration_values.items():
            # Check if this parameter has a mapping in the End node
            if sanitized_param_name not in end_node_param_mappings:
                continue

            original_node_param = end_node_param_mappings[sanitized_param_name]
            target_node_name = original_node_param.node_name
            target_param_name = original_node_param.parameter_name

            # Get the original packaged node in the main flow
            try:
                target_node = node_manager.get_node_by_name(target_node_name)
            except Exception:
                # Best-effort: a missing node is logged and skipped rather than aborting the whole pass.
                logger.warning(
                    "Could not find packaged node '%s' in main flow to apply last iteration values", target_node_name
                )
                continue

            # Get the parameter from the target node
            target_param = target_node.get_parameter_by_name(target_param_name)

            # Skip if parameter not found or is special parameter
            if target_param is None:
                logger.debug("Skipping missing parameter '%s' on node '%s'", target_param_name, target_node_name)
                continue

            # Skip control parameters
            if target_param.type == ParameterTypeBuiltin.CONTROL_TYPE:
                logger.debug("Skipping control parameter '%s' on node '%s'", target_param_name, target_node_name)
                continue

            # Set the value on the target node
            target_node.set_parameter_value(target_param_name, param_value)
            target_node.parameter_output_values[target_param_name] = param_value

            logger.debug(
                "Applied last iteration value to packaged node '%s' parameter '%s'",
                target_node_name,
                target_param_name,
            )

        # NOTE: this count includes entries that were skipped above (unmapped, missing,
        # or control parameters), so it is an upper bound on values actually applied.
        logger.info(
            "Successfully applied %d parameter values from last iteration to packaged nodes",
            len(last_iteration_values),
        )
533
2574
|
|
|
534
2575
|
async def _delete_workflow(self, workflow_name: str, workflow_path: Path) -> None:
|
|
535
2576
|
try:
|