griptape-nodes 0.63.10__py3-none-any.whl → 0.64.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- griptape_nodes/common/node_executor.py +95 -171
- griptape_nodes/exe_types/connections.py +51 -2
- griptape_nodes/exe_types/flow.py +3 -3
- griptape_nodes/exe_types/node_types.py +330 -202
- griptape_nodes/exe_types/param_components/artifact_url/__init__.py +1 -0
- griptape_nodes/exe_types/param_components/artifact_url/public_artifact_url_parameter.py +155 -0
- griptape_nodes/exe_types/param_components/progress_bar_component.py +1 -1
- griptape_nodes/exe_types/param_types/parameter_string.py +27 -0
- griptape_nodes/machines/control_flow.py +64 -203
- griptape_nodes/machines/dag_builder.py +85 -238
- griptape_nodes/machines/parallel_resolution.py +9 -236
- griptape_nodes/machines/sequential_resolution.py +133 -11
- griptape_nodes/retained_mode/events/agent_events.py +2 -0
- griptape_nodes/retained_mode/events/flow_events.py +5 -6
- griptape_nodes/retained_mode/events/node_events.py +151 -1
- griptape_nodes/retained_mode/events/workflow_events.py +10 -0
- griptape_nodes/retained_mode/managers/agent_manager.py +33 -1
- griptape_nodes/retained_mode/managers/flow_manager.py +213 -290
- griptape_nodes/retained_mode/managers/library_manager.py +24 -7
- griptape_nodes/retained_mode/managers/node_manager.py +400 -77
- griptape_nodes/retained_mode/managers/version_compatibility_manager.py +113 -69
- griptape_nodes/retained_mode/managers/workflow_manager.py +45 -10
- griptape_nodes/servers/mcp.py +32 -0
- griptape_nodes/version_compatibility/versions/v0_63_8/__init__.py +1 -0
- griptape_nodes/version_compatibility/versions/v0_63_8/deprecated_nodegroup_parameters.py +105 -0
- {griptape_nodes-0.63.10.dist-info → griptape_nodes-0.64.1.dist-info}/METADATA +3 -1
- {griptape_nodes-0.63.10.dist-info → griptape_nodes-0.64.1.dist-info}/RECORD +31 -28
- griptape_nodes/version_compatibility/workflow_versions/__init__.py +0 -1
- /griptape_nodes/version_compatibility/{workflow_versions → versions}/v0_7_0/__init__.py +0 -0
- /griptape_nodes/version_compatibility/{workflow_versions → versions}/v0_7_0/local_executor_argument_addition.py +0 -0
- {griptape_nodes-0.63.10.dist-info → griptape_nodes-0.64.1.dist-info}/WHEEL +0 -0
- {griptape_nodes-0.63.10.dist-info → griptape_nodes-0.64.1.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Reusable artifact URL parameters."""
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, ClassVar
|
|
4
|
+
from urllib.parse import urlparse
|
|
5
|
+
from uuid import uuid4
|
|
6
|
+
|
|
7
|
+
from griptape.artifacts.audio_url_artifact import AudioUrlArtifact
|
|
8
|
+
from griptape.artifacts.image_url_artifact import ImageUrlArtifact
|
|
9
|
+
from griptape.artifacts.url_artifact import UrlArtifact
|
|
10
|
+
from griptape.artifacts.video_url_artifact import VideoUrlArtifact
|
|
11
|
+
|
|
12
|
+
from griptape_nodes.drivers.storage.griptape_cloud_storage_driver import GriptapeCloudStorageDriver
|
|
13
|
+
from griptape_nodes.exe_types.core_types import NodeMessageResult, Parameter, ParameterMessage
|
|
14
|
+
from griptape_nodes.exe_types.node_types import BaseNode
|
|
15
|
+
from griptape_nodes.retained_mode.events.config_events import GetConfigValueRequest, GetConfigValueResultSuccess
|
|
16
|
+
from griptape_nodes.retained_mode.events.secrets_events import GetSecretValueRequest, GetSecretValueResultSuccess
|
|
17
|
+
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
18
|
+
from griptape_nodes.traits.button import Button, ButtonDetailsMessagePayload
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class PublicArtifactUrlParameter:
    """A reusable component for managing artifact URLs and ensuring public internet accessibility.

    This component utilizes Griptape Cloud to provide public URLs for artifact parameters if needed.
    A URL that already points at a non-local http(s) host is passed through unchanged. Otherwise the
    referenced file is read from the workspace's static files directory and uploaded to a Griptape
    Cloud storage bucket, and the URL returned by the storage driver is used instead. The uploaded
    copy can be removed with ``delete_uploaded_artifact``.
    """

    API_KEY_NAME = "GT_CLOUD_API_KEY"
    BUCKET_ID_NAME = "GT_CLOUD_BUCKET_ID"
    supported_artifact_types: ClassVar[list[type]] = [ImageUrlArtifact, VideoUrlArtifact, AudioUrlArtifact]
    supported_artifact_type_names: ClassVar[list[str]] = [cls.__name__ for cls in supported_artifact_types]
    # Hostnames that only reach this machine; URLs pointing at them are never treated as public.
    _LOCAL_HOSTNAMES: ClassVar[frozenset[str]] = frozenset({"localhost", "127.0.0.1", "::1", "0.0.0.0"})  # noqa: S104
    # Remote path of the most recent upload, or None when nothing is pending cleanup.
    gtc_file_path: Path | None = None

    def __init__(
        self, node: BaseNode, artifact_url_parameter: Parameter, disclaimer_message: str | None = None
    ) -> None:
        """Validate the parameter type and create the Griptape Cloud storage driver.

        When the bucket-ID secret is not set, this lists the account's buckets through
        ``GriptapeCloudStorageDriver.list_buckets`` to pick one.

        Args:
            node: The node that owns ``artifact_url_parameter``.
            artifact_url_parameter: The parameter holding the artifact URL. Its ``type`` must match
                (case-insensitively) one of ``supported_artifact_type_names``.
            disclaimer_message: Optional extra text included in the node's help message.

        Raises:
            ValueError: If the parameter type is unsupported, or the API-key secret is not set.
            RuntimeError: If no bucket ID is configured and the account has no storage buckets.
        """
        self._node = node
        self._parameter = artifact_url_parameter
        self._disclaimer_message = disclaimer_message
        self.gtc_file_path = None

        supported_names = {name.lower() for name in self.supported_artifact_type_names}
        if artifact_url_parameter.type.lower() not in supported_names:
            msg = (
                f"Unsupported artifact type '{artifact_url_parameter.type}' for "
                f"artifact URL parameter '{artifact_url_parameter.name}'. "
                f"Supported types: {', '.join(self.supported_artifact_type_names)}"
            )
            raise ValueError(msg)

        api_key = self._get_secret_value(self.API_KEY_NAME)
        if not api_key:
            # A missing secret would otherwise be stringified to "None" and sent as a credential.
            msg = (
                f"The '{self.API_KEY_NAME}' secret is required to create public URLs for "
                f"artifact URL parameter '{artifact_url_parameter.name}', but it is not set."
            )
            raise ValueError(msg)
        api_key = str(api_key)

        base = os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")
        self._storage_driver = GriptapeCloudStorageDriver(
            workspace_directory=GriptapeNodes.ConfigManager().workspace_path,
            bucket_id=self._get_bucket_id(base, api_key),
            api_key=api_key,
            base_url=base,
        )

    @classmethod
    def _get_bucket_id(cls, base_url: str, api_key: str) -> str:
        """Return the configured bucket ID, falling back to the first bucket on the account.

        Args:
            base_url: Griptape Cloud base URL used when listing buckets.
            api_key: Griptape Cloud API key used when listing buckets.

        Raises:
            RuntimeError: If no bucket ID secret is set and the account has no buckets.
        """
        bucket_id: str | None = cls._get_secret_value(cls.BUCKET_ID_NAME, should_error_on_not_found=False)

        # An empty secret is treated the same as an unset one.
        if bucket_id:
            return bucket_id

        buckets = GriptapeCloudStorageDriver.list_buckets(
            base_url=base_url,
            api_key=api_key,
        )
        if not buckets:
            msg = "No Griptape Cloud storage buckets found!"
            raise RuntimeError(msg)

        return buckets[0]["bucket_id"]

    @classmethod
    def _get_config_value(cls, key: str, default: Any | None = None) -> Any | None:
        """Return the config value for ``key``, or ``default`` if the request does not succeed."""
        request = GetConfigValueRequest(category_and_key=key)
        result_event = GriptapeNodes.handle_request(request)

        if isinstance(result_event, GetConfigValueResultSuccess):
            return result_event.value

        return default

    @classmethod
    def _get_secret_value(
        cls, key: str, default: Any | None = None, *, should_error_on_not_found: bool = False
    ) -> Any | None:
        """Return the secret value for ``key``, or ``default`` if the request does not succeed."""
        request = GetSecretValueRequest(key=key, should_error_on_not_found=should_error_on_not_found)
        result_event = GriptapeNodes.handle_request(request)

        if isinstance(result_event, GetSecretValueResultSuccess):
            return result_event.value

        return default

    def _message_element_name(self) -> str:
        """Return the name of the help-message element tied to the wrapped parameter."""
        return f"artifact_url_parameter_message_{self._parameter.name}"

    def add_input_parameters(self) -> None:
        """Add the dismissible warning message, then the artifact URL parameter, to the node."""
        self._node.add_node_element(
            ParameterMessage(
                name=self._message_element_name(),
                title="Media Upload",
                variant="warning",
                value=self.get_help_message(),
                traits={
                    Button(
                        full_width=True,
                        on_click=self._onparameter_message_button_click,
                    )
                },
                button_text="Hide this message",
            )
        )
        self._node.add_parameter(self._parameter)

    def _onparameter_message_button_click(
        self,
        button: Button,  # noqa: ARG002
        button_payload: ButtonDetailsMessagePayload,  # noqa: ARG002
    ) -> NodeMessageResult | None:
        """Hide the help message when its button is clicked."""
        self._node.hide_message_by_name(self._message_element_name())
        return None

    def get_help_message(self) -> str:
        """Return the help text explaining that a temporary public URL will be generated."""
        return (
            f"The {self._node.name} node requires a public URL for the parameter: {self._parameter.name}.\n\n"
            f"{self._disclaimer_message or ''}\n"
            "Executing this node will generate a short lived, public URL for the media artifact, which will be cleaned up after execution.\n"
        )

    @classmethod
    def _is_public_url(cls, url: str) -> bool:
        """Return True if ``url`` is an http(s) URL whose host is not a localhost/loopback name."""
        parsed = urlparse(url)
        if parsed.scheme not in ("http", "https"):
            return False
        host = (parsed.hostname or "").lower()
        if not host:
            return False
        return host not in cls._LOCAL_HOSTNAMES and not host.endswith(".localhost")

    def get_public_url_for_parameter(self) -> str:
        """Return a publicly reachable URL for the artifact held by the wrapped parameter.

        The URL comes from the parameter value itself, or from its ``value`` if it is a
        ``UrlArtifact``. An http(s) URL on a non-local host is returned unchanged. Otherwise:

        - the file named by the URL's last path segment is read from the workspace's static
          files directory;
        - it is uploaded to Griptape Cloud under a unique path;
        - the URL returned by the storage driver is returned.

        The uploaded path is stored in ``gtc_file_path`` so ``delete_uploaded_artifact`` can
        remove it.

        Raises:
            ValueError: If the parameter holds no URL string, or the URL has no file name.
            OSError: If the local file cannot be read (e.g. ``FileNotFoundError``).
        """
        from urllib.parse import unquote

        parameter_value = self._node.get_parameter_value(self._parameter.name)
        url = parameter_value.value if isinstance(parameter_value, UrlArtifact) else parameter_value
        if not isinstance(url, str) or not url:
            msg = f"Parameter '{self._parameter.name}' on node '{self._node.name}' does not contain an artifact URL."
            raise ValueError(msg)

        # check if the URL is already public
        if self._is_public_url(url):
            return url

        workspace_path = GriptapeNodes.ConfigManager().workspace_path
        # A missing/None config value falls back to the default instead of becoming the directory "None".
        static_files_dir = str(self._get_config_value("static_files_directory", default="staticfiles") or "staticfiles")
        static_files_path = workspace_path / static_files_dir

        # Decode percent-escapes (e.g. "%20") so the name matches the file on disk.
        filename = Path(unquote(urlparse(url).path)).name
        if not filename:
            msg = f"Could not determine a file name from artifact URL '{url}'."
            raise ValueError(msg)
        file_contents = (static_files_path / filename).read_bytes()

        self.gtc_file_path = Path(static_files_dir) / "artifact_url_storage" / uuid4().hex / filename

        # upload to Griptape Cloud and get a public URL
        return self._storage_driver.upload_file(path=self.gtc_file_path, file_content=file_contents)

    def delete_uploaded_artifact(self) -> None:
        """Delete the file uploaded by the last ``get_public_url_for_parameter`` call, if any.

        The stored path is cleared before deleting. Repeated calls, or a later run that returned an
        already-public URL without uploading, therefore never delete the same remote file twice.
        """
        uploaded_path = self.gtc_file_path
        if uploaded_path is None:
            return
        self.gtc_file_path = None
        self._storage_driver.delete_file(uploaded_path)
|
|
@@ -35,7 +35,7 @@ class ProgressBarComponent:
|
|
|
35
35
|
self._current_step += steps
|
|
36
36
|
if self._current_step > self._total_steps:
|
|
37
37
|
logger.warning(
|
|
38
|
-
"Current step %i exceeds total steps %i. Progress will not exceed 100
|
|
38
|
+
"Current step %i exceeds total steps %i. Progress will not exceed 100%%.",
|
|
39
39
|
self._current_step,
|
|
40
40
|
self._total_steps,
|
|
41
41
|
)
|
|
@@ -44,6 +44,7 @@ class ParameterString(Parameter):
|
|
|
44
44
|
markdown: bool = False,
|
|
45
45
|
multiline: bool = False,
|
|
46
46
|
placeholder_text: str | None = None,
|
|
47
|
+
is_full_width: bool = False,
|
|
47
48
|
accept_any: bool = True,
|
|
48
49
|
hide: bool = False,
|
|
49
50
|
hide_label: bool = False,
|
|
@@ -78,6 +79,7 @@ class ParameterString(Parameter):
|
|
|
78
79
|
markdown: Whether to enable markdown rendering
|
|
79
80
|
multiline: Whether to use multiline input
|
|
80
81
|
placeholder_text: Placeholder text for the input field
|
|
82
|
+
is_full_width: Whether the parameter should take full width in the UI
|
|
81
83
|
accept_any: Whether to accept any input type and convert to string (default: True)
|
|
82
84
|
hide: Whether to hide the entire parameter
|
|
83
85
|
hide_label: Whether to hide the parameter label
|
|
@@ -105,6 +107,8 @@ class ParameterString(Parameter):
|
|
|
105
107
|
ui_options["multiline"] = multiline
|
|
106
108
|
if placeholder_text is not None:
|
|
107
109
|
ui_options["placeholder_text"] = placeholder_text
|
|
110
|
+
if is_full_width:
|
|
111
|
+
ui_options["is_full_width"] = is_full_width
|
|
108
112
|
|
|
109
113
|
# Set up string conversion based on accept_any setting
|
|
110
114
|
if converters is None:
|
|
@@ -230,3 +234,26 @@ class ParameterString(Parameter):
|
|
|
230
234
|
self.ui_options = ui_options
|
|
231
235
|
else:
|
|
232
236
|
self.update_ui_options_key("placeholder_text", value)
|
|
237
|
+
|
|
238
|
+
@property
def is_full_width(self) -> bool:
    """Whether this parameter is rendered at full width in the UI.

    Returns:
        The ``"is_full_width"`` entry of ``ui_options``, or False when that key is absent.
    """
    options = self.ui_options
    return options.get("is_full_width", False)

@is_full_width.setter
def is_full_width(self, value: bool) -> None:
    """Enable or disable full-width rendering in the UI.

    A truthy value is stored under ``"is_full_width"`` through ``update_ui_options_key``. A falsy
    value removes that key from a copy of ``ui_options``, and the copy is assigned back.

    Args:
        value: Whether to enable full width.
    """
    if not value:
        trimmed_options = self.ui_options.copy()
        trimmed_options.pop("is_full_width", None)
        self.ui_options = trimmed_options
        return
    self.update_ui_options_key("is_full_width", value)
|
|
@@ -6,9 +6,15 @@ from dataclasses import dataclass
|
|
|
6
6
|
from typing import TYPE_CHECKING
|
|
7
7
|
|
|
8
8
|
from griptape_nodes.exe_types.core_types import Parameter, ParameterTypeBuiltin
|
|
9
|
-
from griptape_nodes.exe_types.node_types import
|
|
9
|
+
from griptape_nodes.exe_types.node_types import (
|
|
10
|
+
CONTROL_INPUT_PARAMETER,
|
|
11
|
+
LOCAL_EXECUTION,
|
|
12
|
+
BaseNode,
|
|
13
|
+
NodeGroupNode,
|
|
14
|
+
NodeResolutionState,
|
|
15
|
+
)
|
|
10
16
|
from griptape_nodes.machines.fsm import FSM, State
|
|
11
|
-
from griptape_nodes.machines.parallel_resolution import
|
|
17
|
+
from griptape_nodes.machines.parallel_resolution import ParallelResolutionMachine
|
|
12
18
|
from griptape_nodes.machines.sequential_resolution import SequentialResolutionMachine
|
|
13
19
|
from griptape_nodes.retained_mode.events.base_events import ExecutionEvent, ExecutionGriptapeNodeEvent
|
|
14
20
|
from griptape_nodes.retained_mode.events.execution_events import (
|
|
@@ -18,12 +24,11 @@ from griptape_nodes.retained_mode.events.execution_events import (
|
|
|
18
24
|
SelectedControlOutputEvent,
|
|
19
25
|
)
|
|
20
26
|
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
27
|
+
from griptape_nodes.retained_mode.managers.node_manager import NodeManager
|
|
21
28
|
from griptape_nodes.retained_mode.managers.settings import WorkflowExecutionMode
|
|
22
29
|
|
|
23
30
|
if TYPE_CHECKING:
|
|
24
|
-
from griptape_nodes.exe_types.connections import Connections
|
|
25
31
|
from griptape_nodes.exe_types.flow import ControlFlow
|
|
26
|
-
from griptape_nodes.exe_types.node_types import NodeGroup
|
|
27
32
|
|
|
28
33
|
|
|
29
34
|
@dataclass
|
|
@@ -34,10 +39,6 @@ class NextNodeInfo:
|
|
|
34
39
|
entry_parameter: Parameter | None
|
|
35
40
|
|
|
36
41
|
|
|
37
|
-
if TYPE_CHECKING:
|
|
38
|
-
from griptape_nodes.exe_types.core_types import Parameter
|
|
39
|
-
from griptape_nodes.exe_types.flow import ControlFlow
|
|
40
|
-
|
|
41
42
|
logger = logging.getLogger("griptape_nodes")
|
|
42
43
|
|
|
43
44
|
|
|
@@ -50,7 +51,6 @@ class ControlFlowContext:
|
|
|
50
51
|
paused: bool = False
|
|
51
52
|
flow_name: str
|
|
52
53
|
pickle_control_flow_result: bool
|
|
53
|
-
node_to_proxy_map: dict[BaseNode, BaseNode]
|
|
54
54
|
end_node: BaseNode | None = None
|
|
55
55
|
|
|
56
56
|
def __init__(
|
|
@@ -64,7 +64,6 @@ class ControlFlowContext:
|
|
|
64
64
|
self.flow_name = flow_name
|
|
65
65
|
if execution_type == WorkflowExecutionMode.PARALLEL:
|
|
66
66
|
# Get the global DagBuilder from FlowManager
|
|
67
|
-
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
68
67
|
|
|
69
68
|
dag_builder = GriptapeNodes.FlowManager().global_dag_builder
|
|
70
69
|
self.resolution_machine = ParallelResolutionMachine(
|
|
@@ -74,7 +73,6 @@ class ControlFlowContext:
|
|
|
74
73
|
self.resolution_machine = SequentialResolutionMachine()
|
|
75
74
|
self.current_nodes = []
|
|
76
75
|
self.pickle_control_flow_result = pickle_control_flow_result
|
|
77
|
-
self.node_to_proxy_map = {}
|
|
78
76
|
|
|
79
77
|
def get_next_nodes(self, output_parameter: Parameter | None = None) -> list[NextNodeInfo]:
|
|
80
78
|
"""Get all next nodes from the current nodes.
|
|
@@ -94,7 +92,11 @@ class ControlFlowContext:
|
|
|
94
92
|
next_nodes.append(NextNodeInfo(node=node, entry_parameter=entry_parameter))
|
|
95
93
|
else:
|
|
96
94
|
# Get next control output for this node
|
|
97
|
-
|
|
95
|
+
|
|
96
|
+
if (
|
|
97
|
+
isinstance(current_node, NodeGroupNode)
|
|
98
|
+
and current_node.get_parameter_value(current_node.execution_environment.name) != LOCAL_EXECUTION
|
|
99
|
+
):
|
|
98
100
|
next_output = self.get_next_control_output_for_non_local_execution(current_node)
|
|
99
101
|
else:
|
|
100
102
|
next_output = current_node.get_next_control_output()
|
|
@@ -192,6 +194,39 @@ class ResolveNodeState(State):
|
|
|
192
194
|
return None
|
|
193
195
|
|
|
194
196
|
|
|
197
|
+
def _resolve_target_node_for_control_flow(next_node_info: NextNodeInfo) -> tuple[BaseNode, Parameter | None]:
    """Pick the node that control flow should actually move to for ``next_node_info``.

    Suppose the node's direct ``parent_group`` is a ``NodeGroupNode`` whose execution-environment
    parameter is not ``LOCAL_EXECUTION``. Then that group is returned instead, with no entry
    parameter. Any other node is returned unchanged together with its original entry parameter.

    Args:
        next_node_info: Information about the next node to process.

    Returns:
        Tuple of (resolved_node, entry_parameter).
    """
    child = next_node_info.node
    group = child.parent_group

    # Only a NodeGroupNode parent can take over execution; anything else (including None) is left alone.
    if not isinstance(group, NodeGroupNode):
        return child, next_node_info.entry_parameter

    environment = group.get_parameter_value(group.execution_environment.name)
    if environment == LOCAL_EXECUTION:
        return child, next_node_info.entry_parameter

    logger.info(
        "Control Flow: Redirecting from child node '%s' to parent node group '%s' (execution environment: %s)",
        child.name,
        group.name,
        environment,
    )
    # Control enters the group itself, so no child entry parameter applies.
    return group, None
|
|
228
|
+
|
|
229
|
+
|
|
195
230
|
class NextNodeState(State):
|
|
196
231
|
@staticmethod
|
|
197
232
|
async def on_enter(context: ControlFlowContext) -> type[State] | None:
|
|
@@ -232,10 +267,12 @@ class NextNodeState(State):
|
|
|
232
267
|
return CompleteState
|
|
233
268
|
|
|
234
269
|
# Set up next nodes as current nodes
|
|
270
|
+
# If a node has a parent (is in a node group), move to the parent instead
|
|
235
271
|
next_nodes = []
|
|
236
272
|
for next_node_info in next_node_infos:
|
|
237
|
-
|
|
238
|
-
|
|
273
|
+
target_node, entry_parameter = _resolve_target_node_for_control_flow(next_node_info)
|
|
274
|
+
target_node.set_entry_control_parameter(entry_parameter)
|
|
275
|
+
next_nodes.append(target_node)
|
|
239
276
|
|
|
240
277
|
context.current_nodes = next_nodes
|
|
241
278
|
context.selected_output = None
|
|
@@ -254,7 +291,6 @@ class CompleteState(State):
|
|
|
254
291
|
# Broadcast completion events for any remaining current nodes
|
|
255
292
|
for current_node in context.current_nodes:
|
|
256
293
|
# Use pickle-based serialization for complex parameter output values
|
|
257
|
-
from griptape_nodes.retained_mode.managers.node_manager import NodeManager
|
|
258
294
|
|
|
259
295
|
parameter_output_values, unique_uuid_to_values = NodeManager.serialize_parameter_output_values(
|
|
260
296
|
current_node, use_pickling=context.pickle_control_flow_result
|
|
@@ -297,34 +333,17 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
|
|
|
297
333
|
async def start_flow(
|
|
298
334
|
self, start_node: BaseNode, end_node: BaseNode | None = None, *, debug_mode: bool = False
|
|
299
335
|
) -> None:
|
|
300
|
-
# FIRST: Scan all nodes in the flow and create node groups BEFORE any resolution
|
|
301
|
-
flow_manager = GriptapeNodes.FlowManager()
|
|
302
|
-
flow = flow_manager.get_flow_by_name(self._context.flow_name)
|
|
303
|
-
logger.debug("Scanning flow '%s' for node groups before execution", self._context.flow_name)
|
|
304
|
-
|
|
305
|
-
try:
|
|
306
|
-
node_to_proxy_map = self._identify_and_create_node_group_proxies(flow, flow_manager.get_connections())
|
|
307
|
-
if node_to_proxy_map:
|
|
308
|
-
logger.info(
|
|
309
|
-
"Created %d proxy nodes for %d grouped nodes in flow '%s'",
|
|
310
|
-
len(set(node_to_proxy_map.values())),
|
|
311
|
-
len(node_to_proxy_map),
|
|
312
|
-
self._context.flow_name,
|
|
313
|
-
)
|
|
314
|
-
# Store the mapping in context so it can be used by resolution machines
|
|
315
|
-
self._context.node_to_proxy_map = node_to_proxy_map
|
|
316
|
-
except ValueError as e:
|
|
317
|
-
logger.error("Failed to process node groups: %s", e)
|
|
318
|
-
raise
|
|
319
|
-
|
|
320
|
-
# Determine the actual start node (use proxy if it's part of a group)
|
|
321
|
-
actual_start_node = node_to_proxy_map.get(start_node, start_node)
|
|
322
|
-
|
|
323
336
|
# If using DAG resolution, process data_nodes from queue first
|
|
324
337
|
if isinstance(self._context.resolution_machine, ParallelResolutionMachine):
|
|
325
|
-
current_nodes = await self._process_nodes_for_dag(
|
|
338
|
+
current_nodes = await self._process_nodes_for_dag(start_node)
|
|
326
339
|
else:
|
|
327
|
-
current_nodes = [
|
|
340
|
+
current_nodes = [start_node]
|
|
341
|
+
if isinstance(start_node.parent_group, NodeGroupNode):
|
|
342
|
+
# In sequential mode, we aren't going to run this. Just continue.
|
|
343
|
+
node = GriptapeNodes.FlowManager().get_next_node_from_execution_queue()
|
|
344
|
+
if node is not None:
|
|
345
|
+
await self.start_flow(node, end_node, debug_mode=debug_mode)
|
|
346
|
+
return
|
|
328
347
|
# For control flow/sequential: emit all nodes in flow as involved
|
|
329
348
|
self._context.current_nodes = current_nodes
|
|
330
349
|
self._context.end_node = end_node
|
|
@@ -385,7 +404,7 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
|
|
|
385
404
|
):
|
|
386
405
|
await self.update()
|
|
387
406
|
|
|
388
|
-
async def _process_nodes_for_dag(self, start_node: BaseNode) -> list[BaseNode]:
|
|
407
|
+
async def _process_nodes_for_dag(self, start_node: BaseNode) -> list[BaseNode]:
|
|
389
408
|
"""Process data_nodes from the global queue to build unified DAG.
|
|
390
409
|
|
|
391
410
|
This method identifies data_nodes in the execution queue and processes
|
|
@@ -400,25 +419,19 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
|
|
|
400
419
|
msg = "DAG builder is not initialized."
|
|
401
420
|
raise ValueError(msg)
|
|
402
421
|
|
|
403
|
-
# Use the node-to-proxy map that was created in start_flow
|
|
404
|
-
node_to_proxy_map = self._context.node_to_proxy_map
|
|
405
|
-
|
|
406
422
|
# Build with the first node (it should already be the proxy if it's part of a group)
|
|
407
423
|
dag_builder.add_node_with_dependencies(start_node, start_node.name)
|
|
408
424
|
queue_items = list(flow_manager.global_flow_queue.queue)
|
|
409
425
|
start_nodes = [start_node]
|
|
426
|
+
from griptape_nodes.retained_mode.managers.flow_manager import DagExecutionType
|
|
427
|
+
|
|
410
428
|
# Find data_nodes and remove them from queue
|
|
411
429
|
for item in queue_items:
|
|
412
|
-
from griptape_nodes.retained_mode.managers.flow_manager import DagExecutionType
|
|
413
|
-
|
|
414
430
|
if item.dag_execution_type in (DagExecutionType.CONTROL_NODE, DagExecutionType.START_NODE):
|
|
415
431
|
node = item.node
|
|
416
432
|
node.state = NodeResolutionState.UNRESOLVED
|
|
417
433
|
# Use proxy node if this node is part of a group, otherwise use original node
|
|
418
|
-
|
|
419
|
-
node_to_add = node_to_proxy_map[node]
|
|
420
|
-
else:
|
|
421
|
-
node_to_add = node
|
|
434
|
+
node_to_add = node
|
|
422
435
|
# Only add if not already added (proxy might already be in DAG)
|
|
423
436
|
if node_to_add.name not in dag_builder.node_to_reference:
|
|
424
437
|
dag_builder.add_node_with_dependencies(node_to_add, node_to_add.name)
|
|
@@ -429,10 +442,7 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
|
|
|
429
442
|
node = item.node
|
|
430
443
|
node.state = NodeResolutionState.UNRESOLVED
|
|
431
444
|
# Use proxy node if this node is part of a group, otherwise use original node
|
|
432
|
-
|
|
433
|
-
node_to_add = node_to_proxy_map[node]
|
|
434
|
-
else:
|
|
435
|
-
node_to_add = node
|
|
445
|
+
node_to_add = node
|
|
436
446
|
# Only add if not already added (proxy might already be in DAG)
|
|
437
447
|
if node_to_add.name not in dag_builder.node_to_reference:
|
|
438
448
|
dag_builder.add_node_with_dependencies(node_to_add, node_to_add.name)
|
|
@@ -440,139 +450,6 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
|
|
|
440
450
|
|
|
441
451
|
return start_nodes
|
|
442
452
|
|
|
443
|
-
def _identify_and_create_node_group_proxies(
|
|
444
|
-
self, flow: ControlFlow, connections: Connections
|
|
445
|
-
) -> dict[BaseNode, BaseNode]:
|
|
446
|
-
"""Scan all nodes in flow, identify groups, and create proxy nodes.
|
|
447
|
-
|
|
448
|
-
Returns:
|
|
449
|
-
Dictionary mapping original nodes to their proxy nodes (only for grouped nodes)
|
|
450
|
-
"""
|
|
451
|
-
from griptape_nodes.exe_types.node_types import NodeGroup, NodeGroupProxyNode
|
|
452
|
-
|
|
453
|
-
# Step 1: Identify groups by scanning all nodes in the flow
|
|
454
|
-
groups: dict[str, NodeGroup] = {}
|
|
455
|
-
for node in flow.nodes.values():
|
|
456
|
-
group_id = node.get_parameter_value("job_group")
|
|
457
|
-
|
|
458
|
-
# Skip nodes without group assignment, empty group ID, or locked nodes
|
|
459
|
-
if not group_id or group_id == "" or node.lock:
|
|
460
|
-
continue
|
|
461
|
-
|
|
462
|
-
# Create group if it doesn't exist
|
|
463
|
-
if group_id not in groups:
|
|
464
|
-
groups[group_id] = NodeGroup(group_id=group_id)
|
|
465
|
-
|
|
466
|
-
# Add node to group
|
|
467
|
-
groups[group_id].add_node(node)
|
|
468
|
-
|
|
469
|
-
if not groups:
|
|
470
|
-
return {}
|
|
471
|
-
|
|
472
|
-
# Step 2: Analyze connections for each group
|
|
473
|
-
for group in groups.values():
|
|
474
|
-
self._analyze_group_connections(group, connections)
|
|
475
|
-
|
|
476
|
-
# Step 3: Validate each group
|
|
477
|
-
for group in groups.values():
|
|
478
|
-
group.validate_no_intermediate_nodes(connections.connections)
|
|
479
|
-
|
|
480
|
-
# Step 4: Create proxy nodes and build mapping
|
|
481
|
-
node_to_proxy_map: dict[BaseNode, BaseNode] = {}
|
|
482
|
-
for group_id, group in groups.items():
|
|
483
|
-
# Create proxy node
|
|
484
|
-
proxy_name = f"__group_proxy_{group_id}"
|
|
485
|
-
proxy_node = NodeGroupProxyNode(name=proxy_name, node_group=group)
|
|
486
|
-
|
|
487
|
-
# Register the proxy node with ObjectManager so it can be found during parameter updates
|
|
488
|
-
obj_manager = GriptapeNodes.ObjectManager()
|
|
489
|
-
obj_manager.add_object_by_name(proxy_name, proxy_node)
|
|
490
|
-
|
|
491
|
-
# Map all grouped nodes to this proxy
|
|
492
|
-
for node in group.nodes.values():
|
|
493
|
-
node_to_proxy_map[node] = proxy_node
|
|
494
|
-
|
|
495
|
-
# Remap connections to point to proxy
|
|
496
|
-
self._remap_connections_to_proxy_node(group, proxy_node, connections)
|
|
497
|
-
|
|
498
|
-
# Now create proxy parameters (after remapping so original references are saved)
|
|
499
|
-
proxy_node.create_proxy_parameters()
|
|
500
|
-
|
|
501
|
-
return node_to_proxy_map
|
|
502
|
-
|
|
503
|
-
def _analyze_group_connections(self, group: NodeGroup, connections: Connections) -> None:
|
|
504
|
-
"""Analyze and categorize connections for a node group."""
|
|
505
|
-
node_names_in_group = group.nodes.keys()
|
|
506
|
-
|
|
507
|
-
# Analyze all connections in the flow
|
|
508
|
-
for conn in connections.connections.values():
|
|
509
|
-
source_in_group = conn.source_node.name in node_names_in_group
|
|
510
|
-
target_in_group = conn.target_node.name in node_names_in_group
|
|
511
|
-
|
|
512
|
-
if source_in_group and target_in_group:
|
|
513
|
-
# Both endpoints in group - internal connection
|
|
514
|
-
group.internal_connections.append(conn)
|
|
515
|
-
elif source_in_group and not target_in_group:
|
|
516
|
-
# From group to outside - external outgoing
|
|
517
|
-
group.external_outgoing_connections.append(conn)
|
|
518
|
-
elif not source_in_group and target_in_group:
|
|
519
|
-
# From outside to group - external incoming
|
|
520
|
-
group.external_incoming_connections.append(conn)
|
|
521
|
-
|
|
522
|
-
def _remap_connections_to_proxy_node(
|
|
523
|
-
self, group: NodeGroup, proxy_node: BaseNode, connections: Connections
|
|
524
|
-
) -> None:
|
|
525
|
-
"""Remap external connections from group nodes to the proxy node."""
|
|
526
|
-
# Remap external incoming connections (from outside -> group becomes outside -> proxy)
|
|
527
|
-
for conn in group.external_incoming_connections:
|
|
528
|
-
conn_id = id(conn)
|
|
529
|
-
|
|
530
|
-
# Save original target node before remapping (for cleanup later)
|
|
531
|
-
original_target_node = conn.target_node
|
|
532
|
-
group.original_incoming_targets[conn_id] = original_target_node
|
|
533
|
-
|
|
534
|
-
# Remove old incoming index entry
|
|
535
|
-
if (
|
|
536
|
-
conn.target_node.name in connections.incoming_index
|
|
537
|
-
and conn.target_parameter.name in connections.incoming_index[conn.target_node.name]
|
|
538
|
-
):
|
|
539
|
-
connections.incoming_index[conn.target_node.name][conn.target_parameter.name].remove(conn_id)
|
|
540
|
-
|
|
541
|
-
# Update connection target to proxy
|
|
542
|
-
conn.target_node = proxy_node
|
|
543
|
-
|
|
544
|
-
# Create proxy parameter name using original node name
|
|
545
|
-
sanitized_node_name = original_target_node.name.replace(" ", "_")
|
|
546
|
-
proxy_param_name = f"{sanitized_node_name}__{conn.target_parameter.name}"
|
|
547
|
-
|
|
548
|
-
# Add new incoming index entry with proxy parameter name
|
|
549
|
-
connections.incoming_index.setdefault(proxy_node.name, {}).setdefault(proxy_param_name, []).append(conn_id)
|
|
550
|
-
|
|
551
|
-
# Remap external outgoing connections (group -> outside becomes proxy -> outside)
|
|
552
|
-
for conn in group.external_outgoing_connections:
|
|
553
|
-
conn_id = id(conn)
|
|
554
|
-
|
|
555
|
-
# Save original source node before remapping (for cleanup later)
|
|
556
|
-
original_source_node = conn.source_node
|
|
557
|
-
group.original_outgoing_sources[conn_id] = original_source_node
|
|
558
|
-
|
|
559
|
-
# Remove old outgoing index entry
|
|
560
|
-
if (
|
|
561
|
-
conn.source_node.name in connections.outgoing_index
|
|
562
|
-
and conn.source_parameter.name in connections.outgoing_index[conn.source_node.name]
|
|
563
|
-
):
|
|
564
|
-
connections.outgoing_index[conn.source_node.name][conn.source_parameter.name].remove(conn_id)
|
|
565
|
-
|
|
566
|
-
# Update connection source to proxy
|
|
567
|
-
conn.source_node = proxy_node
|
|
568
|
-
|
|
569
|
-
# Create proxy parameter name using original node name
|
|
570
|
-
sanitized_node_name = original_source_node.name.replace(" ", "_")
|
|
571
|
-
proxy_param_name = f"{sanitized_node_name}__{conn.source_parameter.name}"
|
|
572
|
-
|
|
573
|
-
# Add new outgoing index entry with proxy parameter name
|
|
574
|
-
connections.outgoing_index.setdefault(proxy_node.name, {}).setdefault(proxy_param_name, []).append(conn_id)
|
|
575
|
-
|
|
576
453
|
async def cancel_flow(self) -> None:
|
|
577
454
|
"""Cancel all nodes in the flow by delegating to the resolution machine."""
|
|
578
455
|
await self.resolution_machine.cancel_all_nodes()
|
|
@@ -581,22 +458,6 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
|
|
|
581
458
|
self._context.reset(cancel=cancel)
|
|
582
459
|
self._current_state = None
|
|
583
460
|
|
|
584
|
-
def cleanup_proxy_nodes(self) -> None:
|
|
585
|
-
"""Cleanup all proxy nodes and restore original connections."""
|
|
586
|
-
if not self._context.node_to_proxy_map:
|
|
587
|
-
# If we're calling cleanup, but it's already been cleaned up, we just want to return.
|
|
588
|
-
return
|
|
589
|
-
|
|
590
|
-
# Get all unique proxy nodes
|
|
591
|
-
proxy_nodes = set(self._context.node_to_proxy_map.values())
|
|
592
|
-
|
|
593
|
-
# Cleanup each proxy node using the existing method
|
|
594
|
-
for proxy_node in proxy_nodes:
|
|
595
|
-
ExecuteDagState._cleanup_proxy_node(proxy_node)
|
|
596
|
-
|
|
597
|
-
# Clear the proxy mapping
|
|
598
|
-
self._context.node_to_proxy_map.clear()
|
|
599
|
-
|
|
600
461
|
@property
|
|
601
462
|
def resolution_machine(self) -> ParallelResolutionMachine | SequentialResolutionMachine:
|
|
602
463
|
return self._context.resolution_machine
|