griptape-nodes 0.64.11__py3-none-any.whl → 0.65.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- griptape_nodes/app/app.py +25 -5
- griptape_nodes/cli/commands/init.py +65 -54
- griptape_nodes/cli/commands/libraries.py +92 -85
- griptape_nodes/cli/commands/self.py +121 -0
- griptape_nodes/common/node_executor.py +2142 -101
- griptape_nodes/exe_types/base_iterative_nodes.py +1004 -0
- griptape_nodes/exe_types/connections.py +114 -19
- griptape_nodes/exe_types/core_types.py +225 -7
- griptape_nodes/exe_types/flow.py +3 -3
- griptape_nodes/exe_types/node_types.py +681 -225
- griptape_nodes/exe_types/param_components/README.md +414 -0
- griptape_nodes/exe_types/param_components/api_key_provider_parameter.py +200 -0
- griptape_nodes/exe_types/param_components/huggingface/huggingface_model_parameter.py +2 -0
- griptape_nodes/exe_types/param_components/huggingface/huggingface_repo_file_parameter.py +79 -5
- griptape_nodes/exe_types/param_types/parameter_button.py +443 -0
- griptape_nodes/machines/control_flow.py +77 -38
- griptape_nodes/machines/dag_builder.py +148 -70
- griptape_nodes/machines/parallel_resolution.py +61 -35
- griptape_nodes/machines/sequential_resolution.py +11 -113
- griptape_nodes/retained_mode/events/app_events.py +1 -0
- griptape_nodes/retained_mode/events/base_events.py +16 -13
- griptape_nodes/retained_mode/events/connection_events.py +3 -0
- griptape_nodes/retained_mode/events/execution_events.py +35 -0
- griptape_nodes/retained_mode/events/flow_events.py +15 -2
- griptape_nodes/retained_mode/events/library_events.py +347 -0
- griptape_nodes/retained_mode/events/node_events.py +48 -0
- griptape_nodes/retained_mode/events/os_events.py +86 -3
- griptape_nodes/retained_mode/events/project_events.py +15 -1
- griptape_nodes/retained_mode/events/workflow_events.py +48 -1
- griptape_nodes/retained_mode/griptape_nodes.py +6 -2
- griptape_nodes/retained_mode/managers/config_manager.py +10 -8
- griptape_nodes/retained_mode/managers/event_manager.py +168 -0
- griptape_nodes/retained_mode/managers/fitness_problems/libraries/__init__.py +2 -0
- griptape_nodes/retained_mode/managers/fitness_problems/libraries/old_xdg_location_warning_problem.py +43 -0
- griptape_nodes/retained_mode/managers/flow_manager.py +664 -123
- griptape_nodes/retained_mode/managers/library_manager.py +1134 -138
- griptape_nodes/retained_mode/managers/model_manager.py +2 -3
- griptape_nodes/retained_mode/managers/node_manager.py +148 -25
- griptape_nodes/retained_mode/managers/object_manager.py +3 -1
- griptape_nodes/retained_mode/managers/operation_manager.py +3 -1
- griptape_nodes/retained_mode/managers/os_manager.py +1158 -122
- griptape_nodes/retained_mode/managers/secrets_manager.py +2 -3
- griptape_nodes/retained_mode/managers/settings.py +21 -1
- griptape_nodes/retained_mode/managers/sync_manager.py +2 -3
- griptape_nodes/retained_mode/managers/workflow_manager.py +358 -104
- griptape_nodes/retained_mode/retained_mode.py +3 -3
- griptape_nodes/traits/button.py +44 -2
- griptape_nodes/traits/file_system_picker.py +2 -2
- griptape_nodes/utils/file_utils.py +101 -0
- griptape_nodes/utils/git_utils.py +1226 -0
- griptape_nodes/utils/library_utils.py +122 -0
- {griptape_nodes-0.64.11.dist-info → griptape_nodes-0.65.0.dist-info}/METADATA +2 -1
- {griptape_nodes-0.64.11.dist-info → griptape_nodes-0.65.0.dist-info}/RECORD +55 -47
- {griptape_nodes-0.64.11.dist-info → griptape_nodes-0.65.0.dist-info}/WHEEL +1 -1
- {griptape_nodes-0.64.11.dist-info → griptape_nodes-0.65.0.dist-info}/entry_points.txt +0 -0
|
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, NamedTuple, cast
|
|
|
7
7
|
from uuid import uuid4
|
|
8
8
|
|
|
9
9
|
from griptape_nodes.common.node_executor import NodeExecutor
|
|
10
|
+
from griptape_nodes.exe_types.base_iterative_nodes import BaseIterativeStartNode
|
|
10
11
|
from griptape_nodes.exe_types.connections import Connections
|
|
11
12
|
from griptape_nodes.exe_types.core_types import (
|
|
12
13
|
Parameter,
|
|
@@ -22,7 +23,6 @@ from griptape_nodes.exe_types.node_types import (
|
|
|
22
23
|
NodeDependencies,
|
|
23
24
|
NodeGroupNode,
|
|
24
25
|
NodeResolutionState,
|
|
25
|
-
StartLoopNode,
|
|
26
26
|
StartNode,
|
|
27
27
|
)
|
|
28
28
|
from griptape_nodes.machines.control_flow import CompleteState, ControlFlowMachine
|
|
@@ -75,6 +75,9 @@ from griptape_nodes.retained_mode.events.execution_events import (
|
|
|
75
75
|
StartFlowRequest,
|
|
76
76
|
StartFlowResultFailure,
|
|
77
77
|
StartFlowResultSuccess,
|
|
78
|
+
StartLocalSubflowRequest,
|
|
79
|
+
StartLocalSubflowResultFailure,
|
|
80
|
+
StartLocalSubflowResultSuccess,
|
|
78
81
|
UnresolveFlowRequest,
|
|
79
82
|
UnresolveFlowResultFailure,
|
|
80
83
|
UnresolveFlowResultSuccess,
|
|
@@ -125,6 +128,7 @@ from griptape_nodes.retained_mode.events.node_events import (
|
|
|
125
128
|
DeleteNodeRequest,
|
|
126
129
|
DeleteNodeResultFailure,
|
|
127
130
|
DeserializeNodeFromCommandsRequest,
|
|
131
|
+
DeserializeNodeFromCommandsResultSuccess,
|
|
128
132
|
SerializedNodeCommands,
|
|
129
133
|
SerializedParameterValueTracker,
|
|
130
134
|
SerializeNodeToCommandsRequest,
|
|
@@ -199,6 +203,8 @@ class PackagingStartNodeResult(NamedTuple):
|
|
|
199
203
|
start_to_package_connections: list[SerializedFlowCommands.IndirectConnectionSerialization]
|
|
200
204
|
input_shape_data: WorkflowShapeNodes
|
|
201
205
|
start_node_parameter_value_commands: list[SerializedNodeCommands.IndirectSetParameterValueCommand]
|
|
206
|
+
parameter_name_mappings: dict[SanitizedParameterName, OriginalNodeParameter]
|
|
207
|
+
start_node_name: str
|
|
202
208
|
|
|
203
209
|
|
|
204
210
|
class PackagingEndNodeResult(NamedTuple):
|
|
@@ -215,6 +221,7 @@ class MultiNodeEndNodeResult(NamedTuple):
|
|
|
215
221
|
packaging_result: PackagingEndNodeResult
|
|
216
222
|
parameter_name_mappings: dict[SanitizedParameterName, OriginalNodeParameter]
|
|
217
223
|
alter_parameter_commands: list[AlterParameterDetailsRequest]
|
|
224
|
+
end_node_name: str
|
|
218
225
|
|
|
219
226
|
|
|
220
227
|
class FlowManager:
|
|
@@ -265,7 +272,7 @@ class FlowManager:
|
|
|
265
272
|
event_manager.assign_manager_to_request_type(
|
|
266
273
|
PackageNodesAsSerializedFlowRequest, self.on_package_nodes_as_serialized_flow_request
|
|
267
274
|
)
|
|
268
|
-
|
|
275
|
+
event_manager.assign_manager_to_request_type(StartLocalSubflowRequest, self.on_start_local_subflow_request)
|
|
269
276
|
self._name_to_parent_name = {}
|
|
270
277
|
self._flow_to_referenced_workflow_name = {}
|
|
271
278
|
self._connections = Connections()
|
|
@@ -323,14 +330,35 @@ class FlowManager:
|
|
|
323
330
|
return connections
|
|
324
331
|
|
|
325
332
|
def _get_connections_for_flow(self, flow: ControlFlow) -> list:
|
|
326
|
-
"""Get connections where both nodes are in the specified flow.
|
|
333
|
+
"""Get connections where both nodes are in the specified flow or its child flows.
|
|
334
|
+
|
|
335
|
+
For parent flows, this includes cross-flow connections between the parent and its children.
|
|
336
|
+
For child flows, this only includes connections within that specific flow.
|
|
337
|
+
"""
|
|
327
338
|
flow_connections = []
|
|
339
|
+
flow_name = flow.name
|
|
340
|
+
|
|
341
|
+
# Get all child flow names for this flow
|
|
342
|
+
child_flow_names = set()
|
|
343
|
+
for child_name, parent_name in self._name_to_parent_name.items():
|
|
344
|
+
if parent_name == flow_name:
|
|
345
|
+
child_flow_names.add(child_name)
|
|
346
|
+
|
|
347
|
+
# Build set of all node names in this flow and its direct children
|
|
348
|
+
all_node_names = set(flow.nodes.keys())
|
|
349
|
+
for child_flow_name in child_flow_names:
|
|
350
|
+
child_flow = GriptapeNodes.ObjectManager().attempt_get_object_by_name_as_type(child_flow_name, ControlFlow)
|
|
351
|
+
if child_flow is not None:
|
|
352
|
+
all_node_names.update(child_flow.nodes.keys())
|
|
353
|
+
|
|
354
|
+
# Include connections where both nodes are in this flow hierarchy
|
|
328
355
|
for connection in self._connections.connections.values():
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
if
|
|
356
|
+
source_in_hierarchy = connection.source_node.name in all_node_names
|
|
357
|
+
target_in_hierarchy = connection.target_node.name in all_node_names
|
|
358
|
+
|
|
359
|
+
if source_in_hierarchy and target_in_hierarchy:
|
|
333
360
|
flow_connections.append(connection)
|
|
361
|
+
|
|
334
362
|
return flow_connections
|
|
335
363
|
|
|
336
364
|
def get_parent_flow(self, flow_name: str) -> str | None:
|
|
@@ -550,9 +578,17 @@ class FlowManager:
|
|
|
550
578
|
details = f"Attempted to delete Flow '{flow_name}', but no Flow with that name could be found."
|
|
551
579
|
result = DeleteFlowResultFailure(result_details=details)
|
|
552
580
|
return result
|
|
553
|
-
|
|
581
|
+
|
|
582
|
+
# Only cancel if the flow being deleted is the one tracked by the global control flow machine.
|
|
583
|
+
# Isolated subflows (e.g., ForEach loop iterations) have their own separate ControlFlowMachine
|
|
584
|
+
# and should not trigger cancellation of the global machine.
|
|
585
|
+
if (
|
|
586
|
+
self.check_for_existing_running_flow()
|
|
587
|
+
and self._global_control_flow_machine is not None
|
|
588
|
+
and self._global_control_flow_machine.context.flow_name == flow.name
|
|
589
|
+
):
|
|
554
590
|
result = GriptapeNodes.handle_request(CancelFlowRequest(flow_name=flow.name))
|
|
555
|
-
if
|
|
591
|
+
if result.failed():
|
|
556
592
|
details = f"Attempted to delete flow '{flow_name}'. Failed because running flow could not cancel."
|
|
557
593
|
return DeleteFlowResultFailure(result_details=details)
|
|
558
594
|
|
|
@@ -612,9 +648,14 @@ class FlowManager:
|
|
|
612
648
|
if flow in self._flow_to_referenced_workflow_name:
|
|
613
649
|
del self._flow_to_referenced_workflow_name[flow]
|
|
614
650
|
|
|
615
|
-
# Clean up ControlFlowMachine and DAG orchestrator
|
|
616
|
-
|
|
617
|
-
|
|
651
|
+
# Clean up ControlFlowMachine and DAG orchestrator only if this is the global flow.
|
|
652
|
+
# Isolated subflows have their own machines and should not clear the global state.
|
|
653
|
+
if (
|
|
654
|
+
self._global_control_flow_machine is not None
|
|
655
|
+
and self._global_control_flow_machine.context.flow_name == flow.name
|
|
656
|
+
):
|
|
657
|
+
self._global_control_flow_machine = None
|
|
658
|
+
self._global_dag_builder.clear()
|
|
618
659
|
|
|
619
660
|
details = f"Successfully deleted Flow '{flow_name}'."
|
|
620
661
|
result = DeleteFlowResultSuccess(result_details=details)
|
|
@@ -876,13 +917,23 @@ class FlowManager:
|
|
|
876
917
|
details = f"Deleted the previous connection from '{old_source_node_name}.{old_source_param_name}' to '{old_target_node_name}.{old_target_param_name}' to make room for the new connection."
|
|
877
918
|
try:
|
|
878
919
|
# Actually create the Connection.
|
|
920
|
+
if (isinstance(source_node, NodeGroupNode) and target_node.parent_group == source_node) or (
|
|
921
|
+
isinstance(target_node, NodeGroupNode) and source_node.parent_group == target_node
|
|
922
|
+
):
|
|
923
|
+
# Here we're checking if it's an internal connection. (from the NodeGroup to a node within it.)
|
|
924
|
+
# If that's true, we set that automatically.
|
|
925
|
+
is_node_group_internal = True
|
|
926
|
+
else:
|
|
927
|
+
# If not true, we default to the request
|
|
928
|
+
is_node_group_internal = request.is_node_group_internal
|
|
879
929
|
conn = self._connections.add_connection(
|
|
880
930
|
source_node=source_node,
|
|
881
931
|
source_parameter=source_param,
|
|
882
932
|
target_node=target_node,
|
|
883
933
|
target_parameter=target_param,
|
|
934
|
+
is_node_group_internal=is_node_group_internal,
|
|
884
935
|
)
|
|
885
|
-
|
|
936
|
+
id(conn)
|
|
886
937
|
except ValueError as e:
|
|
887
938
|
details = f'Connection failed: "{e}"'
|
|
888
939
|
|
|
@@ -923,26 +974,36 @@ class FlowManager:
|
|
|
923
974
|
target_parent = target_node.parent_group
|
|
924
975
|
|
|
925
976
|
# If source is in a group, this is an outgoing external connection
|
|
926
|
-
if
|
|
927
|
-
source_parent
|
|
977
|
+
if (
|
|
978
|
+
source_parent is not None
|
|
979
|
+
and isinstance(source_parent, NodeGroupNode)
|
|
980
|
+
and source_parent not in (target_parent, target_node)
|
|
981
|
+
):
|
|
982
|
+
success = source_parent.map_external_connection(
|
|
928
983
|
conn=conn,
|
|
929
|
-
conn_id=conn_id,
|
|
930
984
|
is_incoming=False,
|
|
931
|
-
grouped_node=source_node,
|
|
932
985
|
)
|
|
986
|
+
if success:
|
|
987
|
+
details = f'Connected "{source_node_name}.{request.source_parameter_name}" to "{target_node_name}.{request.target_parameter_name}, remapped with proxy parameter."'
|
|
988
|
+
return CreateConnectionResultSuccess(result_details=details)
|
|
989
|
+
details = f'Failed to connect "{source_node_name}.{request.source_parameter_name}" to "{target_node_name}.{request.target_parameter_name} by remapping to proxy."'
|
|
990
|
+
return CreateConnectionResultFailure(result_details=details)
|
|
933
991
|
|
|
934
992
|
# If target is in a group, this is an incoming external connection
|
|
935
|
-
if
|
|
936
|
-
target_parent
|
|
993
|
+
if (
|
|
994
|
+
target_parent is not None
|
|
995
|
+
and isinstance(target_parent, NodeGroupNode)
|
|
996
|
+
and target_parent not in (source_parent, source_node)
|
|
997
|
+
):
|
|
998
|
+
success = target_parent.map_external_connection(
|
|
937
999
|
conn=conn,
|
|
938
|
-
conn_id=conn_id,
|
|
939
1000
|
is_incoming=True,
|
|
940
|
-
grouped_node=target_node,
|
|
941
1001
|
)
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
1002
|
+
if success:
|
|
1003
|
+
details = f'Connected "{source_node_name}.{request.source_parameter_name}" to "{target_node_name}.{request.target_parameter_name}, remapped with proxy parameter."'
|
|
1004
|
+
return CreateConnectionResultSuccess(result_details=details)
|
|
1005
|
+
details = f'Failed to connect "{source_node_name}.{request.source_parameter_name}" to "{target_node_name}.{request.target_parameter_name} by remapping to proxy."'
|
|
1006
|
+
return CreateConnectionResultFailure(result_details=details)
|
|
946
1007
|
|
|
947
1008
|
details = f'Connected "{source_node_name}.{request.source_parameter_name}" to "{target_node_name}.{request.target_parameter_name}"'
|
|
948
1009
|
|
|
@@ -1084,12 +1145,7 @@ class FlowManager:
|
|
|
1084
1145
|
|
|
1085
1146
|
# Check if either node is in a NodeGroup and untrack connections BEFORE removing connection
|
|
1086
1147
|
|
|
1087
|
-
source_parent = source_node.parent_group
|
|
1088
|
-
target_parent = target_node.parent_group
|
|
1089
|
-
|
|
1090
1148
|
# Find the connection before it's deleted
|
|
1091
|
-
conn = None
|
|
1092
|
-
conn_id = None
|
|
1093
1149
|
if (
|
|
1094
1150
|
source_node.name in self._connections.outgoing_index
|
|
1095
1151
|
and source_param.name in self._connections.outgoing_index[source_node.name]
|
|
@@ -1101,47 +1157,8 @@ class FlowManager:
|
|
|
1101
1157
|
candidate.target_node.name == target_node.name
|
|
1102
1158
|
and candidate.target_parameter.name == target_param.name
|
|
1103
1159
|
):
|
|
1104
|
-
conn = candidate
|
|
1105
|
-
conn_id = candidate_id
|
|
1106
1160
|
break
|
|
1107
1161
|
|
|
1108
|
-
# If source is in a group, untrack outgoing external connection
|
|
1109
|
-
if (
|
|
1110
|
-
conn
|
|
1111
|
-
and conn_id
|
|
1112
|
-
and source_parent is not None
|
|
1113
|
-
and isinstance(source_parent, NodeGroupNode)
|
|
1114
|
-
and target_parent != source_parent
|
|
1115
|
-
):
|
|
1116
|
-
source_parent.untrack_external_connection(
|
|
1117
|
-
conn=conn,
|
|
1118
|
-
conn_id=conn_id,
|
|
1119
|
-
is_incoming=False,
|
|
1120
|
-
)
|
|
1121
|
-
|
|
1122
|
-
# If target is in a group, untrack incoming external connection
|
|
1123
|
-
if (
|
|
1124
|
-
conn
|
|
1125
|
-
and conn_id
|
|
1126
|
-
and target_parent is not None
|
|
1127
|
-
and isinstance(target_parent, NodeGroupNode)
|
|
1128
|
-
and source_parent != target_parent
|
|
1129
|
-
):
|
|
1130
|
-
target_parent.untrack_external_connection(
|
|
1131
|
-
conn=conn,
|
|
1132
|
-
conn_id=conn_id,
|
|
1133
|
-
is_incoming=True,
|
|
1134
|
-
)
|
|
1135
|
-
|
|
1136
|
-
# If both in same group, untrack internal connection
|
|
1137
|
-
if (
|
|
1138
|
-
conn
|
|
1139
|
-
and source_parent is not None
|
|
1140
|
-
and source_parent == target_parent
|
|
1141
|
-
and isinstance(source_parent, NodeGroupNode)
|
|
1142
|
-
):
|
|
1143
|
-
source_parent.untrack_internal_connection(conn)
|
|
1144
|
-
|
|
1145
1162
|
# Remove the connection.
|
|
1146
1163
|
if not self._connections.remove_connection(
|
|
1147
1164
|
source_node=source_node.name,
|
|
@@ -1204,7 +1221,7 @@ class FlowManager:
|
|
|
1204
1221
|
result = DeleteConnectionResultSuccess(result_details=details)
|
|
1205
1222
|
return result
|
|
1206
1223
|
|
|
1207
|
-
def on_package_nodes_as_serialized_flow_request( # noqa: C901, PLR0911, PLR0912
|
|
1224
|
+
def on_package_nodes_as_serialized_flow_request( # noqa: C901, PLR0911, PLR0912, PLR0915
|
|
1208
1225
|
self, request: PackageNodesAsSerializedFlowRequest
|
|
1209
1226
|
) -> ResultPayload:
|
|
1210
1227
|
"""Handle request to package multiple nodes as a serialized flow.
|
|
@@ -1274,7 +1291,17 @@ class FlowManager:
|
|
|
1274
1291
|
if isinstance(node_connections_dict, PackageNodesAsSerializedFlowResultFailure):
|
|
1275
1292
|
return node_connections_dict
|
|
1276
1293
|
|
|
1277
|
-
# Step 8:
|
|
1294
|
+
# Step 8: Retrieve NodeGroupNode if node_group_name was provided
|
|
1295
|
+
node_group_node: NodeGroupNode | None = None
|
|
1296
|
+
if request.node_group_name:
|
|
1297
|
+
try:
|
|
1298
|
+
node = GriptapeNodes.NodeManager().get_node_by_name(request.node_group_name)
|
|
1299
|
+
if isinstance(node, NodeGroupNode):
|
|
1300
|
+
node_group_node = node
|
|
1301
|
+
except Exception as e:
|
|
1302
|
+
logger.debug("Failed to retrieve NodeGroupNode '%s': %s", request.node_group_name, e)
|
|
1303
|
+
|
|
1304
|
+
# Step 9: Create start node with parameters for external incoming connections
|
|
1278
1305
|
start_node_result = self._create_multi_node_start_node_with_connections(
|
|
1279
1306
|
request=request,
|
|
1280
1307
|
library_version=library_version,
|
|
@@ -1282,11 +1309,12 @@ class FlowManager:
|
|
|
1282
1309
|
serialized_parameter_value_tracker=serialized_parameter_value_tracker,
|
|
1283
1310
|
node_name_to_uuid=node_name_to_uuid,
|
|
1284
1311
|
external_connections_dict=node_connections_dict,
|
|
1312
|
+
node_group_node=node_group_node,
|
|
1285
1313
|
)
|
|
1286
1314
|
if isinstance(start_node_result, PackageNodesAsSerializedFlowResultFailure):
|
|
1287
1315
|
return start_node_result
|
|
1288
1316
|
|
|
1289
|
-
# Step
|
|
1317
|
+
# Step 10: Create end node with parameters for external outgoing connections and parameter mappings
|
|
1290
1318
|
end_node_result = self._create_multi_node_end_node_with_connections(
|
|
1291
1319
|
request=request,
|
|
1292
1320
|
package_nodes=nodes_to_package,
|
|
@@ -1298,9 +1326,32 @@ class FlowManager:
|
|
|
1298
1326
|
return end_node_result
|
|
1299
1327
|
|
|
1300
1328
|
end_node_packaging_result = end_node_result.packaging_result
|
|
1301
|
-
parameter_name_mappings = end_node_result.parameter_name_mappings
|
|
1302
1329
|
|
|
1303
|
-
#
|
|
1330
|
+
# If no entry control node specified, connect start directly to end
|
|
1331
|
+
if not request.entry_control_node_name and not request.entry_control_parameter_name:
|
|
1332
|
+
start_to_end_control_connection = SerializedFlowCommands.IndirectConnectionSerialization(
|
|
1333
|
+
source_node_uuid=start_node_result.start_node_commands.node_uuid,
|
|
1334
|
+
source_parameter_name="exec_out",
|
|
1335
|
+
target_node_uuid=end_node_packaging_result.end_node_commands.node_uuid,
|
|
1336
|
+
target_parameter_name="exec_in",
|
|
1337
|
+
)
|
|
1338
|
+
start_node_result.start_to_package_connections.append(start_to_end_control_connection)
|
|
1339
|
+
|
|
1340
|
+
# Combine parameter mappings as a list: [Start node (index 0), End node (index 1)]
|
|
1341
|
+
from griptape_nodes.retained_mode.events.flow_events import PackagedNodeParameterMapping
|
|
1342
|
+
|
|
1343
|
+
parameter_name_mappings = [
|
|
1344
|
+
PackagedNodeParameterMapping(
|
|
1345
|
+
node_name=start_node_result.start_node_name,
|
|
1346
|
+
parameter_mappings=start_node_result.parameter_name_mappings,
|
|
1347
|
+
),
|
|
1348
|
+
PackagedNodeParameterMapping(
|
|
1349
|
+
node_name=end_node_result.end_node_name,
|
|
1350
|
+
parameter_mappings=end_node_result.parameter_name_mappings,
|
|
1351
|
+
),
|
|
1352
|
+
]
|
|
1353
|
+
|
|
1354
|
+
# Step 11: Assemble final SerializedFlowCommands
|
|
1304
1355
|
# Collect all connections from start/end nodes and internal package connections
|
|
1305
1356
|
all_connections = self._collect_all_connections_for_multi_node_package(
|
|
1306
1357
|
start_node_result=start_node_result,
|
|
@@ -1432,7 +1483,7 @@ class FlowManager:
|
|
|
1432
1483
|
|
|
1433
1484
|
return None
|
|
1434
1485
|
|
|
1435
|
-
def _serialize_package_nodes_for_local_execution( # noqa: PLR0913
|
|
1486
|
+
def _serialize_package_nodes_for_local_execution( # noqa: PLR0913, C901
|
|
1436
1487
|
self,
|
|
1437
1488
|
nodes_to_package: list[BaseNode],
|
|
1438
1489
|
unique_parameter_uuid_to_values: dict[SerializedNodeCommands.UniqueParameterValueUUID, Any],
|
|
@@ -1496,6 +1547,21 @@ class FlowManager:
|
|
|
1496
1547
|
serialize_result.set_parameter_value_commands
|
|
1497
1548
|
)
|
|
1498
1549
|
|
|
1550
|
+
# Update NodeGroupNode commands to use UUIDs instead of names in node_names_to_add
|
|
1551
|
+
# This allows workflow generation to directly look up variable names from UUIDs
|
|
1552
|
+
|
|
1553
|
+
for node_group_command in serialized_node_group_commands:
|
|
1554
|
+
create_cmd = node_group_command.create_node_command
|
|
1555
|
+
|
|
1556
|
+
if isinstance(create_cmd, CreateNodeGroupRequest) and create_cmd.node_names_to_add:
|
|
1557
|
+
node_uuids = []
|
|
1558
|
+
for child_node_name in create_cmd.node_names_to_add:
|
|
1559
|
+
if child_node_name in node_name_to_uuid:
|
|
1560
|
+
uuid = node_name_to_uuid[child_node_name]
|
|
1561
|
+
node_uuids.append(uuid)
|
|
1562
|
+
# Replace the list with UUIDs (as strings since that's what the field expects)
|
|
1563
|
+
create_cmd.node_names_to_add = node_uuids
|
|
1564
|
+
|
|
1499
1565
|
# Build internal connections between package nodes
|
|
1500
1566
|
package_node_names_set = {n.name for n in nodes_to_package}
|
|
1501
1567
|
|
|
@@ -1801,6 +1867,7 @@ class FlowManager:
|
|
|
1801
1867
|
packaging_result=end_node_result,
|
|
1802
1868
|
parameter_name_mappings=parameter_name_mappings,
|
|
1803
1869
|
alter_parameter_commands=[],
|
|
1870
|
+
end_node_name=end_node_name,
|
|
1804
1871
|
)
|
|
1805
1872
|
|
|
1806
1873
|
def _create_end_node_control_connections( # noqa: PLR0913
|
|
@@ -1947,10 +2014,15 @@ class FlowManager:
|
|
|
1947
2014
|
output_shape_data[end_node_name][sanitized_param_name] = param_shape_info
|
|
1948
2015
|
|
|
1949
2016
|
# Create parameter command for end node
|
|
2017
|
+
# Use flexible input types for data parameters to prevent type mismatch errors
|
|
2018
|
+
# Control parameters must keep their exact types (cannot use "any")
|
|
2019
|
+
is_control_param = parameter.output_type == ParameterTypeBuiltin.CONTROL_TYPE.value
|
|
2020
|
+
|
|
1950
2021
|
add_param_request = AddParameterToNodeRequest(
|
|
1951
2022
|
node_name=end_node_name,
|
|
1952
2023
|
parameter_name=sanitized_param_name,
|
|
1953
|
-
|
|
2024
|
+
input_types=parameter.input_types if is_control_param else ["any"], # Control: exact types; Data: any
|
|
2025
|
+
output_type=parameter.output_type, # Preserve original output type
|
|
1954
2026
|
default_value=None,
|
|
1955
2027
|
tooltip=tooltip,
|
|
1956
2028
|
initial_setup=True,
|
|
@@ -1976,11 +2048,14 @@ class FlowManager:
|
|
|
1976
2048
|
external_connections_dict: dict[
|
|
1977
2049
|
str, ConnectionAnalysis
|
|
1978
2050
|
], # Contains EXTERNAL connections only - used to determine which parameters need start node inputs
|
|
2051
|
+
node_group_node: NodeGroupNode | None = None,
|
|
1979
2052
|
) -> PackagingStartNodeResult | PackageNodesAsSerializedFlowResultFailure:
|
|
1980
2053
|
"""Create start node commands and connections for external incoming connections."""
|
|
1981
2054
|
# Generate UUID and name for start node
|
|
1982
2055
|
start_node_uuid = SerializedNodeCommands.NodeUUID(str(uuid4()))
|
|
1983
2056
|
start_node_name = "Start_Package_MultiNode"
|
|
2057
|
+
# Parameter name mappings are essential to know which inputs are necessary on the start node given.
|
|
2058
|
+
parameter_name_mappings: dict[SanitizedParameterName, OriginalNodeParameter] = {}
|
|
1984
2059
|
|
|
1985
2060
|
# Build start node CreateNodeRequest
|
|
1986
2061
|
start_create_node_command = CreateNodeRequest(
|
|
@@ -2021,6 +2096,7 @@ class FlowManager:
|
|
|
2021
2096
|
node_name_to_uuid=node_name_to_uuid,
|
|
2022
2097
|
unique_parameter_uuid_to_values=unique_parameter_uuid_to_values,
|
|
2023
2098
|
serialized_parameter_value_tracker=serialized_parameter_value_tracker,
|
|
2099
|
+
parameter_name_mappings=parameter_name_mappings,
|
|
2024
2100
|
)
|
|
2025
2101
|
if isinstance(result, PackageNodesAsSerializedFlowResultFailure):
|
|
2026
2102
|
return result
|
|
@@ -2047,6 +2123,17 @@ class FlowManager:
|
|
|
2047
2123
|
# Add control connections to the same list as data connections
|
|
2048
2124
|
start_to_package_connections.extend(control_connections)
|
|
2049
2125
|
|
|
2126
|
+
# Set parameter values from NodeGroupNode if provided
|
|
2127
|
+
if node_group_node is not None:
|
|
2128
|
+
self._apply_node_group_parameters_to_start_node(
|
|
2129
|
+
node_group_node=node_group_node,
|
|
2130
|
+
start_node_library_name=request.start_node_library_name,
|
|
2131
|
+
start_node_type=request.start_node_type, # type: ignore[arg-type] # Guaranteed non-None
|
|
2132
|
+
start_node_parameter_value_commands=start_node_parameter_value_commands,
|
|
2133
|
+
unique_parameter_uuid_to_values=unique_parameter_uuid_to_values,
|
|
2134
|
+
serialized_parameter_value_tracker=serialized_parameter_value_tracker,
|
|
2135
|
+
)
|
|
2136
|
+
|
|
2050
2137
|
# Build complete SerializedNodeCommands for start node
|
|
2051
2138
|
start_node_dependencies = NodeDependencies()
|
|
2052
2139
|
start_node_dependencies.libraries.add(start_node_library_details)
|
|
@@ -2063,8 +2150,101 @@ class FlowManager:
|
|
|
2063
2150
|
start_to_package_connections=start_to_package_connections,
|
|
2064
2151
|
input_shape_data=input_shape_data,
|
|
2065
2152
|
start_node_parameter_value_commands=start_node_parameter_value_commands,
|
|
2153
|
+
parameter_name_mappings=parameter_name_mappings,
|
|
2154
|
+
start_node_name=start_node_name,
|
|
2066
2155
|
)
|
|
2067
2156
|
|
|
2157
|
+
def _apply_node_group_parameters_to_start_node( # noqa: PLR0913
|
|
2158
|
+
self,
|
|
2159
|
+
node_group_node: NodeGroupNode,
|
|
2160
|
+
start_node_library_name: str,
|
|
2161
|
+
start_node_type: str,
|
|
2162
|
+
start_node_parameter_value_commands: list[SerializedNodeCommands.IndirectSetParameterValueCommand],
|
|
2163
|
+
unique_parameter_uuid_to_values: dict[SerializedNodeCommands.UniqueParameterValueUUID, Any],
|
|
2164
|
+
serialized_parameter_value_tracker: SerializedParameterValueTracker,
|
|
2165
|
+
) -> None:
|
|
2166
|
+
"""Apply parameter values from NodeGroupNode to the StartFlow node.
|
|
2167
|
+
|
|
2168
|
+
This method reads the execution environment metadata from the NodeGroupNode,
|
|
2169
|
+
extracts parameter values for the specified StartFlow node type, and creates
|
|
2170
|
+
set parameter value commands for those parameters.
|
|
2171
|
+
|
|
2172
|
+
Args:
|
|
2173
|
+
node_group_node: The NodeGroupNode containing parameter values
|
|
2174
|
+
start_node_library_name: Name of the library containing the StartFlow node
|
|
2175
|
+
start_node_type: Type of the StartFlow node
|
|
2176
|
+
start_node_parameter_value_commands: List to append parameter value commands to
|
|
2177
|
+
unique_parameter_uuid_to_values: Dict to track unique parameter values
|
|
2178
|
+
serialized_parameter_value_tracker: Tracker for serialized parameter values
|
|
2179
|
+
|
|
2180
|
+
Raises:
|
|
2181
|
+
ValueError: If required metadata is missing from NodeGroupNode
|
|
2182
|
+
"""
|
|
2183
|
+
# Get execution environment metadata from NodeGroupNode
|
|
2184
|
+
if not node_group_node.metadata:
|
|
2185
|
+
msg = f"NodeGroupNode '{node_group_node.name}' is missing metadata. Cannot apply parameters to StartFlow node."
|
|
2186
|
+
raise ValueError(msg)
|
|
2187
|
+
|
|
2188
|
+
execution_env_metadata = node_group_node.metadata.get("execution_environment")
|
|
2189
|
+
if not execution_env_metadata:
|
|
2190
|
+
msg = f"NodeGroupNode '{node_group_node.name}' metadata is missing 'execution_environment'. Cannot apply parameters to StartFlow node."
|
|
2191
|
+
raise ValueError(msg)
|
|
2192
|
+
|
|
2193
|
+
# Find the metadata for the current library
|
|
2194
|
+
library_metadata = execution_env_metadata.get(start_node_library_name)
|
|
2195
|
+
if library_metadata is None:
|
|
2196
|
+
msg = f"NodeGroupNode '{node_group_node.name}' metadata does not contain library '{start_node_library_name}'. Available libraries: {list(execution_env_metadata.keys())}"
|
|
2197
|
+
raise ValueError(msg)
|
|
2198
|
+
|
|
2199
|
+
# Verify this is the correct StartFlow node type
|
|
2200
|
+
registered_start_flow_node = library_metadata.get("start_flow_node")
|
|
2201
|
+
if registered_start_flow_node != start_node_type:
|
|
2202
|
+
msg = f"NodeGroupNode '{node_group_node.name}' has mismatched StartFlow node type. Expected '{start_node_type}', but metadata has '{registered_start_flow_node}'"
|
|
2203
|
+
raise ValueError(msg)
|
|
2204
|
+
|
|
2205
|
+
# Get the list of parameter names that belong to this StartFlow node
|
|
2206
|
+
parameter_names = library_metadata.get("parameter_names", [])
|
|
2207
|
+
if not parameter_names:
|
|
2208
|
+
# This is not an error - it's valid for a StartFlow node to have no parameters
|
|
2209
|
+
logger.debug(
|
|
2210
|
+
"NodeGroupNode '%s' has no parameters registered for StartFlow node '%s'",
|
|
2211
|
+
node_group_node.name,
|
|
2212
|
+
start_node_type,
|
|
2213
|
+
)
|
|
2214
|
+
return
|
|
2215
|
+
|
|
2216
|
+
# For each parameter, get its value from the NodeGroupNode and create a set value command
|
|
2217
|
+
for prefixed_param_name in parameter_names:
|
|
2218
|
+
# Get the value from the NodeGroupNode parameter
|
|
2219
|
+
param_value = node_group_node.get_parameter_value(param_name=prefixed_param_name)
|
|
2220
|
+
|
|
2221
|
+
# Skip if no value is set
|
|
2222
|
+
if param_value is None:
|
|
2223
|
+
continue
|
|
2224
|
+
|
|
2225
|
+
# Strip the prefix to get the original parameter name for the StartFlow node
|
|
2226
|
+
class_name_prefix = start_node_type.lower()
|
|
2227
|
+
original_param_name = prefixed_param_name.removeprefix(f"{class_name_prefix}_")
|
|
2228
|
+
|
|
2229
|
+
# Create unique parameter UUID for this value
|
|
2230
|
+
value_id = id(param_value)
|
|
2231
|
+
unique_param_uuid = SerializedNodeCommands.UniqueParameterValueUUID(str(uuid4()))
|
|
2232
|
+
unique_parameter_uuid_to_values[unique_param_uuid] = param_value
|
|
2233
|
+
serialized_parameter_value_tracker.add_as_serializable(value_id, unique_param_uuid)
|
|
2234
|
+
|
|
2235
|
+
# Create set parameter value command
|
|
2236
|
+
set_value_request = SetParameterValueRequest(
|
|
2237
|
+
parameter_name=original_param_name,
|
|
2238
|
+
value=None, # Will be overridden when instantiated
|
|
2239
|
+
is_output=False,
|
|
2240
|
+
initial_setup=True,
|
|
2241
|
+
)
|
|
2242
|
+
indirect_set_value_command = SerializedNodeCommands.IndirectSetParameterValueCommand(
|
|
2243
|
+
set_parameter_value_command=set_value_request,
|
|
2244
|
+
unique_value_uuid=unique_param_uuid,
|
|
2245
|
+
)
|
|
2246
|
+
start_node_parameter_value_commands.append(indirect_set_value_command)
|
|
2247
|
+
|
|
2068
2248
|
def _create_start_node_parameters_and_connections_for_incoming_data( # noqa: PLR0913
|
|
2069
2249
|
self,
|
|
2070
2250
|
target_node_name: str,
|
|
@@ -2076,6 +2256,7 @@ class FlowManager:
|
|
|
2076
2256
|
node_name_to_uuid: dict[str, SerializedNodeCommands.NodeUUID],
|
|
2077
2257
|
unique_parameter_uuid_to_values: dict[SerializedNodeCommands.UniqueParameterValueUUID, Any],
|
|
2078
2258
|
serialized_parameter_value_tracker: SerializedParameterValueTracker,
|
|
2259
|
+
parameter_name_mappings: dict[SanitizedParameterName, OriginalNodeParameter],
|
|
2079
2260
|
) -> StartNodeIncomingDataResult | PackageNodesAsSerializedFlowResultFailure:
|
|
2080
2261
|
"""Create parameters and connections for incoming data connections to a specific target node."""
|
|
2081
2262
|
start_node_parameter_commands = []
|
|
@@ -2093,6 +2274,10 @@ class FlowManager:
|
|
|
2093
2274
|
parameter_name=target_parameter_name,
|
|
2094
2275
|
)
|
|
2095
2276
|
|
|
2277
|
+
parameter_name_mappings[param_name] = OriginalNodeParameter(
|
|
2278
|
+
node_name=target_node_name, parameter_name=target_parameter_name
|
|
2279
|
+
)
|
|
2280
|
+
|
|
2096
2281
|
# Get the source node to determine parameter type (from the external connection)
|
|
2097
2282
|
try:
|
|
2098
2283
|
source_node = GriptapeNodes.NodeManager().get_node_by_name(connection.source_node_name)
|
|
@@ -2339,7 +2524,7 @@ class FlowManager:
|
|
|
2339
2524
|
ValidateFlowDependenciesRequest(flow_name=flow_name, flow_node_name=start_node.name if start_node else None)
|
|
2340
2525
|
)
|
|
2341
2526
|
try:
|
|
2342
|
-
if
|
|
2527
|
+
if result.failed():
|
|
2343
2528
|
details = f"Couldn't start flow with name {flow_name}. Flow Validation Failed"
|
|
2344
2529
|
return StartFlowResultFailure(validation_exceptions=[], result_details=details)
|
|
2345
2530
|
result = cast("ValidateFlowDependenciesResultSuccess", result)
|
|
@@ -2398,7 +2583,7 @@ class FlowManager:
|
|
|
2398
2583
|
ValidateFlowDependenciesRequest(flow_name=flow_name, flow_node_name=start_node.name if start_node else None)
|
|
2399
2584
|
)
|
|
2400
2585
|
try:
|
|
2401
|
-
if
|
|
2586
|
+
if result.failed():
|
|
2402
2587
|
details = f"Couldn't start flow with name {flow_name}. Flow Validation Failed"
|
|
2403
2588
|
return StartFlowFromNodeResultFailure(validation_exceptions=[], result_details=details)
|
|
2404
2589
|
result = cast("ValidateFlowDependenciesResultSuccess", result)
|
|
@@ -2428,6 +2613,151 @@ class FlowManager:
|
|
|
2428
2613
|
|
|
2429
2614
|
return StartFlowFromNodeResultSuccess(result_details=details)
|
|
2430
2615
|
|
|
2616
|
+
def get_start_nodes_in_flow(self, flow: ControlFlow) -> list[BaseNode]:
    """Find candidate start nodes in a specific flow.

    Nodes of ``flow`` are classified as follows:

    1. An explicit StartNode instance is always a start node.
    2. A node with at least one connected control parameter is a control node. It is a
       start node when none of its incoming control connections is external. Connections
       flagged ``is_node_group_internal`` are ignored, and for a BaseIterativeStartNode the
       loop-back connection from its own ``end_node`` is ignored as well.
    3. Any other node is a data node. It is a start node when it has no outgoing
       connections other than ``is_node_group_internal`` ones.

    Note: unlike ``get_start_node_queue``, this method does not skip nodes whose
    ``parent_group`` is a NodeGroupNode; only internal NodeGroup connections are ignored.

    Args:
        flow: The flow to search for start nodes.

    Returns:
        Start nodes ordered as: StartNodes, then control nodes, then data nodes.
        Returns an empty list when the flow has no nodes.
    """
    connections = self.get_connections()
    all_nodes = list(flow.nodes.values())
    if not all_nodes:
        return []

    control_type = ParameterTypeBuiltin.CONTROL_TYPE.value

    def _index_entry(index: dict, node_name: str) -> dict:
        """Return the per-parameter connection-id map for node_name, or {} if it has none."""
        # Membership test first so a defaultdict-style index is never grown by a lookup.
        return index[node_name] if node_name in index else {}

    def _has_connected_control_param(node: BaseNode) -> bool:
        """True if any control parameter of node has a connection in either direction."""
        incoming = _index_entry(connections.incoming_index, node.name)
        outgoing = _index_entry(connections.outgoing_index, node.name)
        return any(
            control_type == parameter.output_type and (parameter.name in incoming or parameter.name in outgoing)
            for parameter in node.parameters
        )

    def _has_external_incoming_control(node: BaseNode) -> bool:
        """True if a control parameter of node receives a connection that disqualifies it as a start."""
        is_iterative_start = isinstance(node, BaseIterativeStartNode)
        for param_name, connection_ids in _index_entry(connections.incoming_index, node.name).items():
            param = node.get_parameter_by_name(param_name)
            if not (param and control_type == param.output_type):
                continue
            for connection_id in connection_ids:
                connection = connections.connections[connection_id]
                # Connections inside a NodeGroup never disqualify a node from being a start.
                if connection.is_node_group_internal:
                    continue
                # An iterative start node is fed back by its own end node on each iteration;
                # that hidden loop-back connection does not count as an incoming control edge.
                if is_iterative_start and connection.get_source_node() == node.end_node:
                    continue
                return True
        return False

    def _has_external_outgoing(node: BaseNode) -> bool:
        """True if node has any outgoing connection that is not NodeGroup-internal."""
        outgoing = _index_entry(connections.outgoing_index, node.name)
        return any(
            not connections.connections[connection_id].is_node_group_internal
            for connection_ids in outgoing.values()
            for connection_id in connection_ids
        )

    start_nodes: list[BaseNode] = []
    control_nodes: list[BaseNode] = []
    data_nodes: list[BaseNode] = []

    for node in all_nodes:
        if isinstance(node, StartNode):
            start_nodes.append(node)
        elif not _has_connected_control_param(node):
            # No wired control parameters: decided below by its outgoing connections.
            data_nodes.append(node)
        elif not _has_external_incoming_control(node):
            # Control node with no external incoming control edge is a start candidate,
            # whether or not it has outgoing control connections.
            control_nodes.append(node)

    valid_data_nodes = [node for node in data_nodes if not _has_external_outgoing(node)]

    return start_nodes + control_nodes + valid_data_nodes
|
|
2720
|
+
|
|
2721
|
+
async def on_start_local_subflow_request(self, request: StartLocalSubflowRequest) -> ResultPayload:
    """Run a flow as an isolated local subflow while another flow is already running.

    A dedicated ControlFlowMachine is created for the subflow with ``is_isolated=True``,
    and this coroutine awaits ``start_flow`` on it before reporting success.

    Args:
        request: Supplies ``flow_name`` (required), an optional ``start_node`` name, and
            ``pickle_control_flow_result`` which is forwarded to the subflow machine.

    Returns:
        StartLocalSubflowResultSuccess after the subflow machine's ``start_flow`` returns.
        StartLocalSubflowResultFailure when no flow name is given, the flow cannot be
        found, no flow is currently running, no start node can be determined, or the
        requested start node does not exist.
    """
    flow_name = request.flow_name
    if not flow_name:
        details = "Must provide flow name to start a flow."
        # Use this handler's own failure payload so every failure path is uniform.
        return StartLocalSubflowResultFailure(result_details=details)

    try:
        flow = self.get_flow_by_name(flow_name)
    except KeyError as err:
        details = f"Cannot start flow. Error: {err}"
        return StartLocalSubflowResultFailure(result_details=details)

    # A local subflow is only meaningful as part of an execution that is already in progress.
    if not self.check_for_existing_running_flow():
        msg = "There must be a flow going to start a Subflow"
        return StartLocalSubflowResultFailure(result_details=msg)

    start_node_name = request.start_node
    if start_node_name is None:
        # No explicit start node: take the highest-priority candidate in the flow.
        start_nodes = self.get_start_nodes_in_flow(flow)
        if not start_nodes:
            details = f"Cannot start subflow '{flow_name}'. No start nodes found in flow."
            return StartLocalSubflowResultFailure(result_details=details)
        start_node = start_nodes[0]
    else:
        try:
            start_node = GriptapeNodes.NodeManager().get_node_by_name(start_node_name)
        except ValueError as err:
            details = f"Cannot start subflow '{flow_name}'. Start node '{start_node_name}' not found: {err}"
            return StartLocalSubflowResultFailure(result_details=details)

    subflow_machine = ControlFlowMachine(
        flow.name,
        pickle_control_flow_result=request.pickle_control_flow_result,
        is_isolated=True,
    )

    await subflow_machine.start_flow(start_node)

    return StartLocalSubflowResultSuccess(result_details=f"Successfully executed local subflow '{flow_name}'")
|
|
2760
|
+
|
|
2431
2761
|
def on_get_flow_state_request(self, event: GetFlowStateRequest) -> ResultPayload:
|
|
2432
2762
|
flow_name = event.flow_name
|
|
2433
2763
|
if not flow_name:
|
|
@@ -2674,6 +3004,78 @@ class FlowManager:
|
|
|
2674
3004
|
|
|
2675
3005
|
return node_types_used
|
|
2676
3006
|
|
|
3007
|
+
def _aggregate_connections(
|
|
3008
|
+
self,
|
|
3009
|
+
flow_connections: list[SerializedFlowCommands.IndirectConnectionSerialization],
|
|
3010
|
+
sub_flows_commands: list[SerializedFlowCommands],
|
|
3011
|
+
) -> list[SerializedFlowCommands.IndirectConnectionSerialization]:
|
|
3012
|
+
"""Aggregate connections from this flow and all sub-flows into a single list.
|
|
3013
|
+
|
|
3014
|
+
Args:
|
|
3015
|
+
flow_connections: List of connections from the current flow
|
|
3016
|
+
sub_flows_commands: List of sub-flow commands to aggregate from
|
|
3017
|
+
|
|
3018
|
+
Returns:
|
|
3019
|
+
List of all connections from this flow and all sub-flows combined
|
|
3020
|
+
"""
|
|
3021
|
+
aggregated_connections = list(flow_connections)
|
|
3022
|
+
|
|
3023
|
+
# Aggregate connections from all sub-flows
|
|
3024
|
+
for sub_flow_cmd in sub_flows_commands:
|
|
3025
|
+
aggregated_connections.extend(sub_flow_cmd.serialized_connections)
|
|
3026
|
+
|
|
3027
|
+
return aggregated_connections
|
|
3028
|
+
|
|
3029
|
+
def _aggregate_unique_parameter_values(
|
|
3030
|
+
self,
|
|
3031
|
+
unique_parameter_uuid_to_values: dict[SerializedNodeCommands.UniqueParameterValueUUID, Any],
|
|
3032
|
+
sub_flows_commands: list[SerializedFlowCommands],
|
|
3033
|
+
) -> dict[SerializedNodeCommands.UniqueParameterValueUUID, Any]:
|
|
3034
|
+
"""Aggregate unique parameter values from this flow and all sub-flows.
|
|
3035
|
+
|
|
3036
|
+
Args:
|
|
3037
|
+
unique_parameter_uuid_to_values: Unique parameter values from current flow
|
|
3038
|
+
sub_flows_commands: List of sub-flow commands to aggregate from
|
|
3039
|
+
|
|
3040
|
+
Returns:
|
|
3041
|
+
Dictionary with all unique parameter values merged
|
|
3042
|
+
"""
|
|
3043
|
+
aggregated_values = dict(unique_parameter_uuid_to_values)
|
|
3044
|
+
|
|
3045
|
+
# Merge unique values from all sub-flows
|
|
3046
|
+
for sub_flow_cmd in sub_flows_commands:
|
|
3047
|
+
aggregated_values.update(sub_flow_cmd.unique_parameter_uuid_to_values)
|
|
3048
|
+
|
|
3049
|
+
return aggregated_values
|
|
3050
|
+
|
|
3051
|
+
def _aggregate_set_parameter_value_commands(
|
|
3052
|
+
self,
|
|
3053
|
+
set_parameter_value_commands: dict[
|
|
3054
|
+
SerializedNodeCommands.NodeUUID, list[SerializedNodeCommands.IndirectSetParameterValueCommand]
|
|
3055
|
+
],
|
|
3056
|
+
sub_flows_commands: list[SerializedFlowCommands],
|
|
3057
|
+
) -> dict[SerializedNodeCommands.NodeUUID, list[SerializedNodeCommands.IndirectSetParameterValueCommand]]:
|
|
3058
|
+
"""Aggregate set parameter value commands from this flow and all sub-flows.
|
|
3059
|
+
|
|
3060
|
+
Args:
|
|
3061
|
+
set_parameter_value_commands: Set parameter value commands from current flow
|
|
3062
|
+
sub_flows_commands: List of sub-flow commands to aggregate from
|
|
3063
|
+
|
|
3064
|
+
Returns:
|
|
3065
|
+
Dictionary with all set parameter value commands merged
|
|
3066
|
+
"""
|
|
3067
|
+
aggregated_commands = dict(set_parameter_value_commands)
|
|
3068
|
+
|
|
3069
|
+
# Merge commands from all sub-flows
|
|
3070
|
+
for sub_flow_cmd in sub_flows_commands:
|
|
3071
|
+
for node_uuid, commands in sub_flow_cmd.set_parameter_value_commands.items():
|
|
3072
|
+
if node_uuid in aggregated_commands:
|
|
3073
|
+
aggregated_commands[node_uuid].extend(commands)
|
|
3074
|
+
else:
|
|
3075
|
+
aggregated_commands[node_uuid] = list(commands)
|
|
3076
|
+
|
|
3077
|
+
return aggregated_commands
|
|
3078
|
+
|
|
2677
3079
|
# TODO: https://github.com/griptape-ai/griptape-nodes/issues/861
|
|
2678
3080
|
# similar manager refactors: https://github.com/griptape-ai/griptape-nodes/issues/806
|
|
2679
3081
|
def on_serialize_flow_to_commands(self, request: SerializeFlowToCommandsRequest) -> ResultPayload: # noqa: C901, PLR0911, PLR0912, PLR0915
|
|
@@ -2713,8 +3115,13 @@ class FlowManager:
|
|
|
2713
3115
|
else:
|
|
2714
3116
|
# Always set set_as_new_context=False during serialization - let the workflow manager
|
|
2715
3117
|
# that loads this serialized flow decide whether to push it to context or not
|
|
3118
|
+
# Get parent flow name from the flow manager's tracking
|
|
3119
|
+
parent_name = self.get_parent_flow(flow_name)
|
|
2716
3120
|
create_flow_request = CreateFlowRequest(
|
|
2717
|
-
|
|
3121
|
+
flow_name=flow_name,
|
|
3122
|
+
parent_flow_name=parent_name,
|
|
3123
|
+
set_as_new_context=False,
|
|
3124
|
+
metadata=flow.metadata,
|
|
2718
3125
|
)
|
|
2719
3126
|
else:
|
|
2720
3127
|
create_flow_request = None
|
|
@@ -2772,22 +3179,8 @@ class FlowManager:
|
|
|
2772
3179
|
)
|
|
2773
3180
|
set_parameter_value_commands_per_node[serialized_node.node_uuid] = set_value_commands_list
|
|
2774
3181
|
|
|
2775
|
-
#
|
|
2776
|
-
#
|
|
2777
|
-
# Create all of the connections
|
|
2778
|
-
create_connection_commands = []
|
|
2779
|
-
for connection in self._get_connections_for_flow(flow):
|
|
2780
|
-
source_node_uuid = node_name_to_uuid[connection.source_node.name]
|
|
2781
|
-
target_node_uuid = node_name_to_uuid[connection.target_node.name]
|
|
2782
|
-
create_connection_command = SerializedFlowCommands.IndirectConnectionSerialization(
|
|
2783
|
-
source_node_uuid=source_node_uuid,
|
|
2784
|
-
source_parameter_name=connection.source_parameter.name,
|
|
2785
|
-
target_node_uuid=target_node_uuid,
|
|
2786
|
-
target_parameter_name=connection.target_parameter.name,
|
|
2787
|
-
)
|
|
2788
|
-
create_connection_commands.append(create_connection_command)
|
|
2789
|
-
|
|
2790
|
-
# Now sub-flows.
|
|
3182
|
+
# Serialize sub-flows first, before creating connections.
|
|
3183
|
+
# We need the complete UUID map from all flows to handle cross-flow connections.
|
|
2791
3184
|
parent_flow = GriptapeNodes.ContextManager().get_current_flow()
|
|
2792
3185
|
parent_flow_name = parent_flow.name
|
|
2793
3186
|
flows_in_flow_request = ListFlowsInFlowRequest(parent_flow_name=parent_flow_name)
|
|
@@ -2798,18 +3191,20 @@ class FlowManager:
|
|
|
2798
3191
|
|
|
2799
3192
|
sub_flow_commands = []
|
|
2800
3193
|
for child_flow in flows_in_flow_result.flow_names:
|
|
2801
|
-
|
|
2802
|
-
|
|
3194
|
+
child_flow_obj = GriptapeNodes.ObjectManager().attempt_get_object_by_name_as_type(
|
|
3195
|
+
child_flow, ControlFlow
|
|
3196
|
+
)
|
|
3197
|
+
if child_flow_obj is None:
|
|
2803
3198
|
details = f"Attempted to serialize Flow '{flow_name}', but no Flow with that name could be found."
|
|
2804
3199
|
return SerializeFlowToCommandsResultFailure(result_details=details)
|
|
2805
3200
|
|
|
2806
3201
|
# Check if this is a referenced workflow
|
|
2807
|
-
if self.is_referenced_workflow(
|
|
3202
|
+
if self.is_referenced_workflow(child_flow_obj):
|
|
2808
3203
|
# For referenced workflows, create a minimal SerializedFlowCommands with just the import command
|
|
2809
|
-
referenced_workflow_name = self.get_referenced_workflow_name(
|
|
3204
|
+
referenced_workflow_name = self.get_referenced_workflow_name(child_flow_obj)
|
|
2810
3205
|
import_command = ImportWorkflowAsReferencedSubFlowRequest(
|
|
2811
3206
|
workflow_name=referenced_workflow_name, # type: ignore[arg-type] # is_referenced_workflow() guarantees this is not None
|
|
2812
|
-
imported_flow_metadata=
|
|
3207
|
+
imported_flow_metadata=child_flow_obj.metadata,
|
|
2813
3208
|
)
|
|
2814
3209
|
|
|
2815
3210
|
# Create NodeDependencies with just the referenced workflow
|
|
@@ -2831,7 +3226,7 @@ class FlowManager:
|
|
|
2831
3226
|
sub_flow_commands.append(serialized_flow)
|
|
2832
3227
|
else:
|
|
2833
3228
|
# For standalone sub-flows, use the existing recursive serialization
|
|
2834
|
-
with GriptapeNodes.ContextManager().flow(flow=
|
|
3229
|
+
with GriptapeNodes.ContextManager().flow(flow=child_flow_obj):
|
|
2835
3230
|
child_flow_request = SerializeFlowToCommandsRequest()
|
|
2836
3231
|
child_flow_result = GriptapeNodes().handle_request(child_flow_request)
|
|
2837
3232
|
if not isinstance(child_flow_result, SerializeFlowToCommandsResultSuccess):
|
|
@@ -2844,6 +3239,65 @@ class FlowManager:
|
|
|
2844
3239
|
# This ensures child nodes exist before their parent NodeGroups are created during deserialization
|
|
2845
3240
|
serialized_node_commands.extend(serialized_node_group_commands)
|
|
2846
3241
|
|
|
3242
|
+
# Update NodeGroupNode commands to use UUIDs instead of names in node_names_to_add
|
|
3243
|
+
# This allows workflow generation to directly look up variable names from UUIDs
|
|
3244
|
+
# Build a complete node name to UUID map including nodes from all subflows
|
|
3245
|
+
complete_node_name_to_uuid = dict(node_name_to_uuid) # Start with current flow's nodes
|
|
3246
|
+
|
|
3247
|
+
def collect_subflow_node_uuids(subflow_commands_list: list[SerializedFlowCommands]) -> None:
|
|
3248
|
+
"""Recursively collect node name-to-UUID mappings from subflows."""
|
|
3249
|
+
for subflow_cmd in subflow_commands_list:
|
|
3250
|
+
for node_cmd in subflow_cmd.serialized_node_commands:
|
|
3251
|
+
# Extract node name from the create command
|
|
3252
|
+
create_cmd = node_cmd.create_node_command
|
|
3253
|
+
if isinstance(create_cmd, CreateNodeRequest) and create_cmd.node_name:
|
|
3254
|
+
complete_node_name_to_uuid[create_cmd.node_name] = node_cmd.node_uuid
|
|
3255
|
+
elif isinstance(create_cmd, CreateNodeGroupRequest) and create_cmd.node_group_name:
|
|
3256
|
+
complete_node_name_to_uuid[create_cmd.node_group_name] = node_cmd.node_uuid
|
|
3257
|
+
# Recursively process nested subflows
|
|
3258
|
+
if subflow_cmd.sub_flows_commands:
|
|
3259
|
+
collect_subflow_node_uuids(subflow_cmd.sub_flows_commands)
|
|
3260
|
+
|
|
3261
|
+
collect_subflow_node_uuids(sub_flow_commands)
|
|
3262
|
+
|
|
3263
|
+
for node_group_command in serialized_node_group_commands:
|
|
3264
|
+
create_cmd = node_group_command.create_node_command
|
|
3265
|
+
|
|
3266
|
+
if isinstance(create_cmd, CreateNodeGroupRequest) and create_cmd.node_names_to_add:
|
|
3267
|
+
# Convert node names to UUIDs using the complete map (including subflows)
|
|
3268
|
+
node_uuids = []
|
|
3269
|
+
for child_node_name in create_cmd.node_names_to_add:
|
|
3270
|
+
if child_node_name in complete_node_name_to_uuid:
|
|
3271
|
+
uuid = complete_node_name_to_uuid[child_node_name]
|
|
3272
|
+
node_uuids.append(uuid)
|
|
3273
|
+
# Replace the list with UUIDs (as strings since that's what the field expects)
|
|
3274
|
+
create_cmd.node_names_to_add = node_uuids
|
|
3275
|
+
|
|
3276
|
+
# Now create the connections using the complete UUID map (includes all flows).
|
|
3277
|
+
# This must happen after subflows are serialized so we have all UUIDs available.
|
|
3278
|
+
create_connection_commands = []
|
|
3279
|
+
for connection in self._get_connections_for_flow(flow):
|
|
3280
|
+
source_node_name = connection.source_node.name
|
|
3281
|
+
target_node_name = connection.target_node.name
|
|
3282
|
+
|
|
3283
|
+
# Use the complete UUID map that includes nodes from all subflows
|
|
3284
|
+
if source_node_name not in complete_node_name_to_uuid:
|
|
3285
|
+
details = f"Attempted to serialize Flow '{flow_name}'. Connection source node '{source_node_name}' not found in UUID map."
|
|
3286
|
+
return SerializeFlowToCommandsResultFailure(result_details=details)
|
|
3287
|
+
if target_node_name not in complete_node_name_to_uuid:
|
|
3288
|
+
details = f"Attempted to serialize Flow '{flow_name}'. Connection target node '{target_node_name}' not found in UUID map."
|
|
3289
|
+
return SerializeFlowToCommandsResultFailure(result_details=details)
|
|
3290
|
+
|
|
3291
|
+
source_node_uuid = complete_node_name_to_uuid[source_node_name]
|
|
3292
|
+
target_node_uuid = complete_node_name_to_uuid[target_node_name]
|
|
3293
|
+
create_connection_command = SerializedFlowCommands.IndirectConnectionSerialization(
|
|
3294
|
+
source_node_uuid=source_node_uuid,
|
|
3295
|
+
source_parameter_name=connection.source_parameter.name,
|
|
3296
|
+
target_node_uuid=target_node_uuid,
|
|
3297
|
+
target_parameter_name=connection.target_parameter.name,
|
|
3298
|
+
)
|
|
3299
|
+
create_connection_commands.append(create_connection_command)
|
|
3300
|
+
|
|
2847
3301
|
# Aggregate all dependencies from nodes and sub-flows
|
|
2848
3302
|
aggregated_dependencies = self._aggregate_flow_dependencies(serialized_node_commands, sub_flow_commands)
|
|
2849
3303
|
|
|
@@ -2854,16 +3308,30 @@ class FlowManager:
|
|
|
2854
3308
|
details = f"Attempted to serialize Flow '{flow_name}' to commands. Failed while aggregating node types: {e}"
|
|
2855
3309
|
return SerializeFlowToCommandsResultFailure(result_details=details)
|
|
2856
3310
|
|
|
3311
|
+
# Aggregate unique parameter values from this flow and all sub-flows
|
|
3312
|
+
aggregated_unique_values = self._aggregate_unique_parameter_values(
|
|
3313
|
+
unique_parameter_uuid_to_values, sub_flow_commands
|
|
3314
|
+
)
|
|
3315
|
+
|
|
3316
|
+
# Aggregate all connections from this flow and all sub-flows
|
|
3317
|
+
aggregated_connections = self._aggregate_connections(create_connection_commands, sub_flow_commands)
|
|
3318
|
+
|
|
3319
|
+
# Extract flow name from initialization command if available
|
|
3320
|
+
extracted_flow_name = None
|
|
3321
|
+
if create_flow_request is not None and hasattr(create_flow_request, "flow_name"):
|
|
3322
|
+
extracted_flow_name = create_flow_request.flow_name
|
|
3323
|
+
|
|
2857
3324
|
serialized_flow = SerializedFlowCommands(
|
|
2858
3325
|
flow_initialization_command=create_flow_request,
|
|
2859
3326
|
serialized_node_commands=serialized_node_commands,
|
|
2860
|
-
serialized_connections=
|
|
2861
|
-
unique_parameter_uuid_to_values=
|
|
3327
|
+
serialized_connections=aggregated_connections,
|
|
3328
|
+
unique_parameter_uuid_to_values=aggregated_unique_values,
|
|
2862
3329
|
set_parameter_value_commands=set_parameter_value_commands_per_node,
|
|
2863
3330
|
set_lock_commands_per_node=set_lock_commands_per_node,
|
|
2864
3331
|
sub_flows_commands=sub_flow_commands,
|
|
2865
3332
|
node_dependencies=aggregated_dependencies,
|
|
2866
3333
|
node_types_used=aggregated_node_types_used,
|
|
3334
|
+
flow_name=extracted_flow_name,
|
|
2867
3335
|
)
|
|
2868
3336
|
details = f"Successfully serialized Flow '{flow_name}' into commands."
|
|
2869
3337
|
result = SerializeFlowToCommandsResultSuccess(serialized_flow_commands=serialized_flow, result_details=details)
|
|
@@ -2910,16 +3378,52 @@ class FlowManager:
|
|
|
2910
3378
|
|
|
2911
3379
|
# Create the nodes.
|
|
2912
3380
|
# Preserve the node UUIDs because we will need to tie these back together with the Connections later.
|
|
3381
|
+
# Also build a mapping from original node names to deserialized node names.
|
|
2913
3382
|
node_uuid_to_deserialized_node_result = {}
|
|
3383
|
+
node_name_mappings = {}
|
|
2914
3384
|
for serialized_node in request.serialized_flow_commands.serialized_node_commands:
|
|
2915
|
-
|
|
3385
|
+
# Get the node name from the CreateNodeGroupRequest command if necessary
|
|
3386
|
+
create_cmd = serialized_node.create_node_command
|
|
3387
|
+
original_node_name = (
|
|
3388
|
+
create_cmd.node_group_name if isinstance(create_cmd, CreateNodeGroupRequest) else create_cmd.node_name
|
|
3389
|
+
)
|
|
3390
|
+
|
|
3391
|
+
# For NodeGroupNodes, remap node_names_to_add from UUIDs to actual node names
|
|
3392
|
+
# Create a copy to avoid mutating the original serialized data
|
|
3393
|
+
serialized_node_for_deserialization = serialized_node
|
|
3394
|
+
if isinstance(create_cmd, CreateNodeGroupRequest) and create_cmd.node_names_to_add:
|
|
3395
|
+
# Use list comprehension to remap UUIDs to deserialized node names
|
|
3396
|
+
remapped_names = [
|
|
3397
|
+
node_uuid_to_deserialized_node_result[node_uuid].node_name
|
|
3398
|
+
for node_uuid in create_cmd.node_names_to_add
|
|
3399
|
+
if node_uuid in node_uuid_to_deserialized_node_result
|
|
3400
|
+
]
|
|
3401
|
+
# Create a copy of the command with remapped names instead of mutating original
|
|
3402
|
+
create_cmd_copy = CreateNodeGroupRequest(
|
|
3403
|
+
node_group_name=create_cmd.node_group_name,
|
|
3404
|
+
node_names_to_add=remapped_names,
|
|
3405
|
+
metadata=create_cmd.metadata,
|
|
3406
|
+
)
|
|
3407
|
+
# Create a copy of serialized_node with the new command
|
|
3408
|
+
serialized_node_for_deserialization = SerializedNodeCommands(
|
|
3409
|
+
node_uuid=serialized_node.node_uuid,
|
|
3410
|
+
create_node_command=create_cmd_copy,
|
|
3411
|
+
element_modification_commands=serialized_node.element_modification_commands,
|
|
3412
|
+
node_dependencies=serialized_node.node_dependencies,
|
|
3413
|
+
lock_node_command=serialized_node.lock_node_command,
|
|
3414
|
+
)
|
|
3415
|
+
|
|
3416
|
+
deserialize_node_request = DeserializeNodeFromCommandsRequest(
|
|
3417
|
+
serialized_node_commands=serialized_node_for_deserialization
|
|
3418
|
+
)
|
|
2916
3419
|
deserialized_node_result = GriptapeNodes.handle_request(deserialize_node_request)
|
|
2917
|
-
if deserialized_node_result
|
|
3420
|
+
if not isinstance(deserialized_node_result, DeserializeNodeFromCommandsResultSuccess):
|
|
2918
3421
|
details = (
|
|
2919
3422
|
f"Attempted to deserialize a Flow '{flow_name}'. Failed while deserializing a node within the flow."
|
|
2920
3423
|
)
|
|
2921
3424
|
return DeserializeFlowFromCommandsResultFailure(result_details=details)
|
|
2922
3425
|
node_uuid_to_deserialized_node_result[serialized_node.node_uuid] = deserialized_node_result
|
|
3426
|
+
node_name_mappings[original_node_name] = deserialized_node_result.node_name
|
|
2923
3427
|
|
|
2924
3428
|
# Now apply the connections.
|
|
2925
3429
|
# We didn't know the exact name that would be used for the nodes, but we knew the node's creation UUID.
|
|
@@ -2992,7 +3496,9 @@ class FlowManager:
|
|
|
2992
3496
|
return DeserializeFlowFromCommandsResultFailure(result_details=details)
|
|
2993
3497
|
|
|
2994
3498
|
details = f"Successfully deserialized Flow '{flow_name}'."
|
|
2995
|
-
return DeserializeFlowFromCommandsResultSuccess(
|
|
3499
|
+
return DeserializeFlowFromCommandsResultSuccess(
|
|
3500
|
+
flow_name=flow_name, result_details=details, node_name_mappings=node_name_mappings
|
|
3501
|
+
)
|
|
2996
3502
|
|
|
2997
3503
|
async def start_flow(
|
|
2998
3504
|
self,
|
|
@@ -3098,9 +3604,10 @@ class FlowManager:
|
|
|
3098
3604
|
self._global_flow_queue.task_done()
|
|
3099
3605
|
return queue_item.node
|
|
3100
3606
|
|
|
3101
|
-
def clear_execution_queue(self) -> None:
|
|
3607
|
+
def clear_execution_queue(self, flow: ControlFlow) -> None:
    """Clear the global execution queue, but only if it belongs to ``flow``.

    The queue is emptied only when a global control flow machine exists and its
    context's ``flow_name`` equals ``flow.name``. Otherwise this is a no-op.

    Args:
        flow: The flow whose pending global-queue entries should be discarded.
    """
    machine = self._global_control_flow_machine
    if not machine or machine.context.flow_name != flow.name:
        return
    # NOTE(review): this clears the Queue's underlying deque directly, bypassing its mutex
    # and leaving its unfinished-task counter untouched; confirm nothing join()s on it.
    self._global_flow_queue.queue.clear()
|
|
3104
3611
|
|
|
3105
3612
|
def has_connection(
|
|
3106
3613
|
self,
|
|
@@ -3303,10 +3810,13 @@ class FlowManager:
|
|
|
3303
3810
|
connection_ids = connections.incoming_index[node.name][parameter_name]
|
|
3304
3811
|
for connection_id in connection_ids:
|
|
3305
3812
|
connection = connections.connections[connection_id]
|
|
3813
|
+
# Skip internal NodeGroup connections
|
|
3814
|
+
if connection.is_node_group_internal:
|
|
3815
|
+
continue
|
|
3306
3816
|
return connection.get_source_node()
|
|
3307
3817
|
return None
|
|
3308
3818
|
|
|
3309
|
-
def get_start_node_queue(self) -> Queue | None: # noqa: C901, PLR0912
|
|
3819
|
+
def get_start_node_queue(self) -> Queue | None: # noqa: C901, PLR0912, PLR0915
|
|
3310
3820
|
# For cross-flow execution, we need to consider ALL nodes across ALL flows
|
|
3311
3821
|
# Clear and use the global execution queue
|
|
3312
3822
|
self._global_flow_queue.queue.clear()
|
|
@@ -3327,6 +3837,10 @@ class FlowManager:
|
|
|
3327
3837
|
control_nodes = []
|
|
3328
3838
|
cn_mgr = self.get_connections()
|
|
3329
3839
|
for node in all_nodes:
|
|
3840
|
+
# Skip nodes that are children of a NodeGroupNode - they should not be start nodes
|
|
3841
|
+
if node.parent_group is not None and isinstance(node.parent_group, NodeGroupNode):
|
|
3842
|
+
continue
|
|
3843
|
+
|
|
3330
3844
|
# if it's a start node, start here! Return the first one!
|
|
3331
3845
|
if isinstance(node, StartNode):
|
|
3332
3846
|
start_nodes.append(node)
|
|
@@ -3358,17 +3872,26 @@ class FlowManager:
|
|
|
3358
3872
|
param = node.get_parameter_by_name(param_name)
|
|
3359
3873
|
if param and ParameterTypeBuiltin.CONTROL_TYPE.value == param.output_type:
|
|
3360
3874
|
# there is a control connection coming in
|
|
3361
|
-
#
|
|
3362
|
-
|
|
3363
|
-
|
|
3875
|
+
# Check each connection to see if it's an internal NodeGroup connection
|
|
3876
|
+
connection_ids = cn_mgr.incoming_index[node.name][param_name]
|
|
3877
|
+
has_external_control_connection = False
|
|
3878
|
+
for connection_id in connection_ids:
|
|
3364
3879
|
connection = cn_mgr.connections[connection_id]
|
|
3365
|
-
|
|
3366
|
-
|
|
3367
|
-
# If it is, then this could still be the first node in the control flow.
|
|
3368
|
-
if connected_node == node.end_node:
|
|
3880
|
+
# Skip internal NodeGroup connections - they shouldn't disqualify a node from being a start node
|
|
3881
|
+
if connection.is_node_group_internal:
|
|
3369
3882
|
continue
|
|
3370
|
-
|
|
3371
|
-
|
|
3883
|
+
# If the node is a BaseIterativeStartNode, it may have an incoming hidden connection from it's BaseIterativeEndNode for iteration.
|
|
3884
|
+
if isinstance(node, BaseIterativeStartNode):
|
|
3885
|
+
connected_node = connection.get_source_node()
|
|
3886
|
+
# Check if the source node is the end loop node associated with this BaseIterativeStartNode.
|
|
3887
|
+
# If it is, then this could still be the first node in the control flow.
|
|
3888
|
+
if connected_node == node.end_node:
|
|
3889
|
+
continue
|
|
3890
|
+
has_external_control_connection = True
|
|
3891
|
+
break
|
|
3892
|
+
if has_external_control_connection:
|
|
3893
|
+
has_control_connection = True
|
|
3894
|
+
break
|
|
3372
3895
|
# if there is a connection coming in, isn't a start.
|
|
3373
3896
|
if has_control_connection:
|
|
3374
3897
|
continue
|
|
@@ -3387,8 +3910,22 @@ class FlowManager:
|
|
|
3387
3910
|
# Let's return a data node that has no OUTGOING data connections!
|
|
3388
3911
|
for node in data_nodes:
|
|
3389
3912
|
cn_mgr = self.get_connections()
|
|
3390
|
-
#
|
|
3391
|
-
|
|
3913
|
+
# Check if the node has any non-internal outgoing connections
|
|
3914
|
+
has_external_outgoing = False
|
|
3915
|
+
if node.name in cn_mgr.outgoing_index:
|
|
3916
|
+
for param_name in cn_mgr.outgoing_index[node.name]:
|
|
3917
|
+
connection_ids = cn_mgr.outgoing_index[node.name][param_name]
|
|
3918
|
+
for connection_id in connection_ids:
|
|
3919
|
+
connection = cn_mgr.connections[connection_id]
|
|
3920
|
+
# Skip internal NodeGroup connections
|
|
3921
|
+
if connection.is_node_group_internal:
|
|
3922
|
+
continue
|
|
3923
|
+
has_external_outgoing = True
|
|
3924
|
+
break
|
|
3925
|
+
if has_external_outgoing:
|
|
3926
|
+
break
|
|
3927
|
+
# Only add nodes that have no non-internal outgoing connections
|
|
3928
|
+
if not has_external_outgoing:
|
|
3392
3929
|
valid_data_nodes.append(node)
|
|
3393
3930
|
# ok now - populate the global flow queue with node type information
|
|
3394
3931
|
for node in start_nodes:
|
|
@@ -3440,7 +3977,7 @@ class FlowManager:
|
|
|
3440
3977
|
connections.append((connection.source_node, connection.source_parameter))
|
|
3441
3978
|
return connections
|
|
3442
3979
|
|
|
3443
|
-
def get_connections_on_node(self,
|
|
3980
|
+
def get_connections_on_node(self, node: BaseNode) -> list[BaseNode] | None:
|
|
3444
3981
|
connections = self.get_connections()
|
|
3445
3982
|
# get all of the connection ids
|
|
3446
3983
|
connected_nodes = []
|
|
@@ -3467,7 +4004,7 @@ class FlowManager:
|
|
|
3467
4004
|
# Return all connected nodes. No duplicates
|
|
3468
4005
|
return connected_nodes
|
|
3469
4006
|
|
|
3470
|
-
def get_all_connected_nodes(self,
|
|
4007
|
+
def get_all_connected_nodes(self, node: BaseNode) -> list[BaseNode]:
|
|
3471
4008
|
discovered = {}
|
|
3472
4009
|
processed = {}
|
|
3473
4010
|
queue = Queue()
|
|
@@ -3476,7 +4013,7 @@ class FlowManager:
|
|
|
3476
4013
|
while not queue.empty():
|
|
3477
4014
|
curr_node = queue.get()
|
|
3478
4015
|
processed[curr_node] = True
|
|
3479
|
-
next_nodes = self.get_connections_on_node(
|
|
4016
|
+
next_nodes = self.get_connections_on_node(curr_node)
|
|
3480
4017
|
if next_nodes:
|
|
3481
4018
|
for next_node in next_nodes:
|
|
3482
4019
|
if next_node not in discovered:
|
|
@@ -3484,6 +4021,10 @@ class FlowManager:
|
|
|
3484
4021
|
queue.put(next_node)
|
|
3485
4022
|
return list(processed.keys())
|
|
3486
4023
|
|
|
4024
|
+
def is_node_connected(self, start_node: BaseNode, node: BaseNode) -> bool:
|
|
4025
|
+
nodes = self.get_all_connected_nodes(start_node)
|
|
4026
|
+
return node in nodes
|
|
4027
|
+
|
|
3487
4028
|
def get_node_dependencies(self, flow: ControlFlow, node: BaseNode) -> list[BaseNode]:
|
|
3488
4029
|
"""Get all upstream nodes that the given node depends on.
|
|
3489
4030
|
|