griptape-nodes 0.63.10__py3-none-any.whl → 0.64.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- griptape_nodes/common/node_executor.py +95 -171
- griptape_nodes/exe_types/connections.py +51 -2
- griptape_nodes/exe_types/flow.py +3 -3
- griptape_nodes/exe_types/node_types.py +330 -202
- griptape_nodes/exe_types/param_components/artifact_url/__init__.py +1 -0
- griptape_nodes/exe_types/param_components/artifact_url/public_artifact_url_parameter.py +155 -0
- griptape_nodes/exe_types/param_components/progress_bar_component.py +1 -1
- griptape_nodes/exe_types/param_types/parameter_string.py +27 -0
- griptape_nodes/machines/control_flow.py +64 -203
- griptape_nodes/machines/dag_builder.py +85 -238
- griptape_nodes/machines/parallel_resolution.py +9 -236
- griptape_nodes/machines/sequential_resolution.py +133 -11
- griptape_nodes/retained_mode/events/agent_events.py +2 -0
- griptape_nodes/retained_mode/events/flow_events.py +5 -6
- griptape_nodes/retained_mode/events/node_events.py +151 -1
- griptape_nodes/retained_mode/events/workflow_events.py +10 -0
- griptape_nodes/retained_mode/managers/agent_manager.py +33 -1
- griptape_nodes/retained_mode/managers/flow_manager.py +213 -290
- griptape_nodes/retained_mode/managers/library_manager.py +24 -7
- griptape_nodes/retained_mode/managers/node_manager.py +400 -77
- griptape_nodes/retained_mode/managers/version_compatibility_manager.py +113 -69
- griptape_nodes/retained_mode/managers/workflow_manager.py +45 -10
- griptape_nodes/servers/mcp.py +32 -0
- griptape_nodes/version_compatibility/versions/v0_63_8/__init__.py +1 -0
- griptape_nodes/version_compatibility/versions/v0_63_8/deprecated_nodegroup_parameters.py +105 -0
- {griptape_nodes-0.63.10.dist-info → griptape_nodes-0.64.1.dist-info}/METADATA +3 -1
- {griptape_nodes-0.63.10.dist-info → griptape_nodes-0.64.1.dist-info}/RECORD +31 -28
- griptape_nodes/version_compatibility/workflow_versions/__init__.py +0 -1
- /griptape_nodes/version_compatibility/{workflow_versions → versions}/v0_7_0/__init__.py +0 -0
- /griptape_nodes/version_compatibility/{workflow_versions → versions}/v0_7_0/local_executor_argument_addition.py +0 -0
- {griptape_nodes-0.63.10.dist-info → griptape_nodes-0.64.1.dist-info}/WHEEL +0 -0
- {griptape_nodes-0.63.10.dist-info → griptape_nodes-0.64.1.dist-info}/entry_points.txt +0 -0
|
@@ -40,7 +40,6 @@ from griptape_nodes.traits.options import Options
|
|
|
40
40
|
from griptape_nodes.utils import async_utils
|
|
41
41
|
|
|
42
42
|
if TYPE_CHECKING:
|
|
43
|
-
from griptape_nodes.exe_types.connections import Connections
|
|
44
43
|
from griptape_nodes.exe_types.core_types import NodeMessagePayload
|
|
45
44
|
from griptape_nodes.node_library.library_registry import LibraryNameAndVersion
|
|
46
45
|
|
|
@@ -129,7 +128,7 @@ class BaseNode(ABC):
|
|
|
129
128
|
# Owned by a flow
|
|
130
129
|
name: str
|
|
131
130
|
metadata: dict[Any, Any]
|
|
132
|
-
|
|
131
|
+
_parent_group: BaseNode | None
|
|
133
132
|
# Node Context Fields
|
|
134
133
|
current_spotlight_parameter: Parameter | None = None
|
|
135
134
|
parameter_values: dict[str, Any]
|
|
@@ -171,26 +170,8 @@ class BaseNode(ABC):
|
|
|
171
170
|
self.process_generator = None
|
|
172
171
|
self._tracked_parameters = []
|
|
173
172
|
self._cancellation_requested = threading.Event()
|
|
173
|
+
self._parent_group = None
|
|
174
174
|
self.set_entry_control_parameter(None)
|
|
175
|
-
self.execution_environment = Parameter(
|
|
176
|
-
name="execution_environment",
|
|
177
|
-
tooltip="Environment that the node should execute in",
|
|
178
|
-
type=ParameterTypeBuiltin.STR,
|
|
179
|
-
allowed_modes={ParameterMode.PROPERTY},
|
|
180
|
-
default_value=LOCAL_EXECUTION,
|
|
181
|
-
traits={Options(choices=get_library_names_with_publish_handlers())},
|
|
182
|
-
ui_options={"hide": True},
|
|
183
|
-
)
|
|
184
|
-
self.add_parameter(self.execution_environment)
|
|
185
|
-
self.node_group = Parameter(
|
|
186
|
-
name="job_group",
|
|
187
|
-
tooltip="Groupings of multiple nodes to send up as a Deadline Cloud job.",
|
|
188
|
-
type=ParameterTypeBuiltin.STR,
|
|
189
|
-
allowed_modes={ParameterMode.PROPERTY},
|
|
190
|
-
default_value="",
|
|
191
|
-
ui_options={"hide": True},
|
|
192
|
-
)
|
|
193
|
-
self.add_parameter(self.node_group)
|
|
194
175
|
|
|
195
176
|
@property
|
|
196
177
|
def state(self) -> NodeResolutionState:
|
|
@@ -204,6 +185,14 @@ class BaseNode(ABC):
|
|
|
204
185
|
def state(self, new_state: NodeResolutionState) -> None:
|
|
205
186
|
self._state = new_state
|
|
206
187
|
|
|
188
|
+
@property
|
|
189
|
+
def parent_group(self) -> BaseNode | None:
|
|
190
|
+
return self._parent_group
|
|
191
|
+
|
|
192
|
+
@parent_group.setter
|
|
193
|
+
def parent_group(self, parent_group: BaseNode | None) -> None:
|
|
194
|
+
self._parent_group = parent_group
|
|
195
|
+
|
|
207
196
|
# This is gross and we need to have a universal pass on resolution state changes and emission of events. That's what this ticket does!
|
|
208
197
|
# https://github.com/griptape-ai/griptape-nodes/issues/994
|
|
209
198
|
def make_node_unresolved(self, current_states_to_trigger_change_event: set[NodeResolutionState] | None) -> None:
|
|
@@ -315,6 +304,15 @@ class BaseNode(ABC):
|
|
|
315
304
|
"""Callback after a Connection has been established OUT of this Node."""
|
|
316
305
|
return
|
|
317
306
|
|
|
307
|
+
def before_incoming_connection_removed(
|
|
308
|
+
self,
|
|
309
|
+
source_node: BaseNode, # noqa: ARG002
|
|
310
|
+
source_parameter: Parameter, # noqa: ARG002
|
|
311
|
+
target_parameter: Parameter, # noqa: ARG002
|
|
312
|
+
) -> None:
|
|
313
|
+
"""Callback before a Connection TO this Node is REMOVED."""
|
|
314
|
+
return
|
|
315
|
+
|
|
318
316
|
def after_incoming_connection_removed(
|
|
319
317
|
self,
|
|
320
318
|
source_node: BaseNode, # noqa: ARG002
|
|
@@ -324,6 +322,15 @@ class BaseNode(ABC):
|
|
|
324
322
|
"""Callback after a Connection TO this Node was REMOVED."""
|
|
325
323
|
return
|
|
326
324
|
|
|
325
|
+
def before_outgoing_connection_removed(
|
|
326
|
+
self,
|
|
327
|
+
source_parameter: Parameter, # noqa: ARG002
|
|
328
|
+
target_node: BaseNode, # noqa: ARG002
|
|
329
|
+
target_parameter: Parameter, # noqa: ARG002
|
|
330
|
+
) -> None:
|
|
331
|
+
"""Callback before a Connection OUT of this Node is REMOVED."""
|
|
332
|
+
return
|
|
333
|
+
|
|
327
334
|
def after_outgoing_connection_removed(
|
|
328
335
|
self,
|
|
329
336
|
source_parameter: Parameter, # noqa: ARG002
|
|
@@ -829,12 +836,16 @@ class BaseNode(ABC):
|
|
|
829
836
|
err = f"Attempted to remove value for Parameter '{param_name}' but parameter doesn't exist."
|
|
830
837
|
raise KeyError(err)
|
|
831
838
|
if param_name in self.parameter_values:
|
|
832
|
-
|
|
839
|
+
# Reset the parameter to default.
|
|
840
|
+
default_val = parameter.default_value
|
|
841
|
+
self.set_parameter_value(param_name, default_val)
|
|
842
|
+
|
|
833
843
|
# special handling if it's in a container.
|
|
834
844
|
if parameter.parent_container_name and parameter.parent_container_name in self.parameter_values:
|
|
835
845
|
del self.parameter_values[parameter.parent_container_name]
|
|
836
846
|
new_val = self.get_parameter_value(parameter.parent_container_name)
|
|
837
847
|
if new_val is not None:
|
|
848
|
+
# Don't set the container to None (that would make it empty)
|
|
838
849
|
self.set_parameter_value(parameter.parent_container_name, new_val)
|
|
839
850
|
else:
|
|
840
851
|
err = f"Attempted to remove value for Parameter '{param_name}' but no value was set."
|
|
@@ -1848,86 +1859,74 @@ class ErrorProxyNode(BaseNode):
|
|
|
1848
1859
|
|
|
1849
1860
|
|
|
1850
1861
|
@dataclass
|
|
1851
|
-
class
|
|
1852
|
-
"""
|
|
1862
|
+
class NodeGroupStoredConnections:
|
|
1863
|
+
"""Stores all of the connections for when we create/remove connections on a node group based on the parameters."""
|
|
1853
1864
|
|
|
1854
|
-
|
|
1855
|
-
|
|
1856
|
-
|
|
1865
|
+
@dataclass
|
|
1866
|
+
class ExternalConnections:
|
|
1867
|
+
"""Represents the External connections to/from the node group."""
|
|
1857
1868
|
|
|
1858
|
-
|
|
1859
|
-
|
|
1860
|
-
nodes: Set of BaseNode instances that belong to this group
|
|
1861
|
-
internal_connections: Connections between nodes within the group
|
|
1862
|
-
external_incoming_connections: Connections from outside nodes into the group
|
|
1863
|
-
external_outgoing_connections: Connections from group nodes to outside nodes
|
|
1864
|
-
"""
|
|
1869
|
+
incoming_connections: list[Connection] = field(default_factory=list)
|
|
1870
|
+
outgoing_connections: list[Connection] = field(default_factory=list)
|
|
1865
1871
|
|
|
1866
|
-
|
|
1867
|
-
|
|
1868
|
-
|
|
1869
|
-
external_incoming_connections: list[Connection] = field(default_factory=list)
|
|
1870
|
-
external_outgoing_connections: list[Connection] = field(default_factory=list)
|
|
1871
|
-
# Store original node references before remapping to proxy (for cleanup)
|
|
1872
|
-
original_incoming_targets: dict[int, BaseNode] = field(default_factory=dict) # conn_id -> original target
|
|
1873
|
-
original_outgoing_sources: dict[int, BaseNode] = field(default_factory=dict) # conn_id -> original source
|
|
1872
|
+
@dataclass
|
|
1873
|
+
class OriginalTargets:
|
|
1874
|
+
"""Represents the connections before they were remapped."""
|
|
1874
1875
|
|
|
1875
|
-
|
|
1876
|
-
|
|
1877
|
-
self.nodes[node.name] = node
|
|
1876
|
+
incoming_sources: dict[int, BaseNode] = field(default_factory=dict)
|
|
1877
|
+
outgoing_targets: dict[int, BaseNode] = field(default_factory=dict)
|
|
1878
1878
|
|
|
1879
|
-
|
|
1880
|
-
|
|
1881
|
-
|
|
1882
|
-
This method checks the dependency graph to ensure that all nodes that lie
|
|
1883
|
-
on paths between grouped nodes are also part of the group. If ungrouped
|
|
1884
|
-
nodes are found between grouped nodes, this indicates a logical error in
|
|
1885
|
-
the group definition.
|
|
1886
|
-
|
|
1887
|
-
Args:
|
|
1888
|
-
all_connections: Dictionary mapping connection IDs to Connection objects
|
|
1879
|
+
internal_connections: list[Connection] = field(default_factory=list)
|
|
1880
|
+
external_connections: ExternalConnections = field(default_factory=ExternalConnections)
|
|
1881
|
+
original_targets: OriginalTargets = field(default_factory=OriginalTargets)
|
|
1889
1882
|
|
|
1890
|
-
Raises:
|
|
1891
|
-
ValueError: If ungrouped nodes are found between grouped nodes
|
|
1892
|
-
"""
|
|
1893
|
-
from griptape_nodes.exe_types.connections import Connections
|
|
1894
1883
|
|
|
1895
|
-
|
|
1896
|
-
|
|
1897
|
-
connections.connections = all_connections
|
|
1884
|
+
class NodeGroupNode(BaseNode):
|
|
1885
|
+
"""Proxy node that represents a group of nodes during DAG execution.
|
|
1898
1886
|
|
|
1899
|
-
|
|
1900
|
-
|
|
1901
|
-
|
|
1902
|
-
|
|
1903
|
-
).append(conn_id)
|
|
1904
|
-
connections.incoming_index.setdefault(conn.target_node.name, {}).setdefault(
|
|
1905
|
-
conn.target_parameter.name, []
|
|
1906
|
-
).append(conn_id)
|
|
1887
|
+
This node acts as a single execution unit for a group of nodes that should
|
|
1888
|
+
be executed in parallel. When the DAG executor encounters this proxy node,
|
|
1889
|
+
it passes the entire NodeGroup to the NodeExecutor which handles parallel
|
|
1890
|
+
execution of all grouped nodes.
|
|
1907
1891
|
|
|
1908
|
-
|
|
1909
|
-
|
|
1910
|
-
|
|
1911
|
-
if node_a == node_b:
|
|
1912
|
-
continue
|
|
1892
|
+
The proxy node has parameters that mirror the external connections to/from
|
|
1893
|
+
the group, allowing it to seamlessly integrate into the DAG structure.
|
|
1894
|
+
"""
|
|
1913
1895
|
|
|
1914
|
-
|
|
1915
|
-
|
|
1896
|
+
nodes: dict[str, BaseNode]
|
|
1897
|
+
stored_connections: NodeGroupStoredConnections
|
|
1898
|
+
_proxy_param_to_node_param: dict[str, tuple[BaseNode, str]]
|
|
1916
1899
|
|
|
1917
|
-
|
|
1918
|
-
|
|
1900
|
+
def __init__(
|
|
1901
|
+
self,
|
|
1902
|
+
name: str,
|
|
1903
|
+
metadata: dict[Any, Any] | None = None,
|
|
1904
|
+
) -> None:
|
|
1905
|
+
super().__init__(name, metadata)
|
|
1906
|
+
self.execution_environment = Parameter(
|
|
1907
|
+
name="execution_environment",
|
|
1908
|
+
tooltip="Environment that the group should execute in",
|
|
1909
|
+
type=ParameterTypeBuiltin.STR,
|
|
1910
|
+
allowed_modes={ParameterMode.PROPERTY},
|
|
1911
|
+
default_value=LOCAL_EXECUTION,
|
|
1912
|
+
traits={Options(choices=get_library_names_with_publish_handlers())},
|
|
1913
|
+
)
|
|
1914
|
+
self.add_parameter(self.execution_environment)
|
|
1915
|
+
self.nodes = {}
|
|
1916
|
+
# Track mapping from proxy parameter name to (original_node, original_param_name)
|
|
1917
|
+
self._proxy_param_to_node_param = {}
|
|
1918
|
+
self.stored_connections = NodeGroupStoredConnections()
|
|
1919
1919
|
|
|
1920
|
-
|
|
1921
|
-
|
|
1922
|
-
|
|
1923
|
-
|
|
1924
|
-
|
|
1925
|
-
|
|
1926
|
-
|
|
1927
|
-
raise ValueError(msg)
|
|
1920
|
+
def get_all_nodes(self) -> dict[str, BaseNode]:
|
|
1921
|
+
all_nodes = {}
|
|
1922
|
+
for node_name, node in self.nodes.items():
|
|
1923
|
+
all_nodes[node_name] = node
|
|
1924
|
+
if isinstance(node, NodeGroupNode):
|
|
1925
|
+
all_nodes.update(node.nodes)
|
|
1926
|
+
return all_nodes
|
|
1928
1927
|
|
|
1929
1928
|
def _find_intermediate_nodes( # noqa: C901
|
|
1930
|
-
self, start_node: BaseNode, end_node: BaseNode
|
|
1929
|
+
self, start_node: BaseNode, end_node: BaseNode
|
|
1931
1930
|
) -> set[BaseNode]:
|
|
1932
1931
|
"""Find all nodes on paths between start_node and end_node (excluding endpoints).
|
|
1933
1932
|
|
|
@@ -1937,13 +1936,19 @@ class NodeGroup:
|
|
|
1937
1936
|
Args:
|
|
1938
1937
|
start_node: Starting node for path search
|
|
1939
1938
|
end_node: Target node for path search
|
|
1940
|
-
connections: Connections object for graph traversal
|
|
1941
1939
|
|
|
1942
1940
|
Returns:
|
|
1943
1941
|
Set of nodes found on paths between start and end (excluding endpoints)
|
|
1944
1942
|
"""
|
|
1945
|
-
|
|
1946
|
-
|
|
1943
|
+
# Build a lookup dictionary for faster connection queries
|
|
1944
|
+
# Map from (source_node_name, source_param_name) -> list of connections
|
|
1945
|
+
outgoing_lookup: dict[tuple[str, str], list[Connection]] = {}
|
|
1946
|
+
|
|
1947
|
+
for conn in self.stored_connections.internal_connections:
|
|
1948
|
+
key = (conn.source_node.name, conn.source_parameter.name)
|
|
1949
|
+
if key not in outgoing_lookup:
|
|
1950
|
+
outgoing_lookup[key] = []
|
|
1951
|
+
outgoing_lookup[key].append(conn)
|
|
1947
1952
|
|
|
1948
1953
|
visited = set()
|
|
1949
1954
|
intermediate = set()
|
|
@@ -1957,141 +1962,268 @@ class NodeGroup:
|
|
|
1957
1962
|
visited.add(current_node.name)
|
|
1958
1963
|
|
|
1959
1964
|
# Process outgoing connections from current node
|
|
1960
|
-
|
|
1961
|
-
|
|
1965
|
+
current_outgoing = []
|
|
1966
|
+
for param_name in [p.name for p in current_node.parameters]:
|
|
1967
|
+
key = (current_node.name, param_name)
|
|
1968
|
+
if key in outgoing_lookup:
|
|
1969
|
+
current_outgoing.extend(outgoing_lookup[key])
|
|
1970
|
+
|
|
1971
|
+
for conn in current_outgoing:
|
|
1972
|
+
next_node = conn.target_node
|
|
1973
|
+
|
|
1974
|
+
# If we reached the end node, record intermediate nodes
|
|
1975
|
+
if next_node == end_node:
|
|
1976
|
+
for node in path[1:]:
|
|
1977
|
+
intermediate.add(node)
|
|
1978
|
+
continue
|
|
1979
|
+
|
|
1980
|
+
# Continue exploring if not already visited
|
|
1981
|
+
if next_node.name not in visited:
|
|
1982
|
+
queue.append((next_node, [*path, next_node]))
|
|
1962
1983
|
|
|
1963
|
-
|
|
1964
|
-
for conn_id in conn_ids:
|
|
1965
|
-
if conn_id not in connections.connections:
|
|
1966
|
-
continue
|
|
1984
|
+
return intermediate
|
|
1967
1985
|
|
|
1968
|
-
|
|
1969
|
-
|
|
1986
|
+
def validate_no_intermediate_nodes(self) -> None:
|
|
1987
|
+
"""Validate that no ungrouped nodes exist between grouped nodes.
|
|
1970
1988
|
|
|
1971
|
-
|
|
1972
|
-
|
|
1973
|
-
|
|
1974
|
-
|
|
1975
|
-
continue
|
|
1989
|
+
This method checks the dependency graph to ensure that all nodes that lie
|
|
1990
|
+
on paths between grouped nodes are also part of the group. If ungrouped
|
|
1991
|
+
nodes are found between grouped nodes, this indicates a logical error in
|
|
1992
|
+
the group definition.
|
|
1976
1993
|
|
|
1977
|
-
|
|
1978
|
-
|
|
1979
|
-
|
|
1994
|
+
Raises:
|
|
1995
|
+
ValueError: If ungrouped nodes are found between grouped nodes
|
|
1996
|
+
"""
|
|
1997
|
+
# Check each pair of nodes in the group
|
|
1998
|
+
for node_a in self.nodes.values():
|
|
1999
|
+
for node_b in self.nodes.values():
|
|
2000
|
+
if node_a == node_b:
|
|
2001
|
+
continue
|
|
1980
2002
|
|
|
1981
|
-
|
|
2003
|
+
# Check if there's a path from node_a to node_b
|
|
2004
|
+
intermediate_nodes = self._find_intermediate_nodes(node_a, node_b)
|
|
1982
2005
|
|
|
2006
|
+
# Check if any intermediate nodes are not in the group
|
|
2007
|
+
ungrouped_intermediates = [n for n in intermediate_nodes if n.name not in self.nodes]
|
|
1983
2008
|
|
|
1984
|
-
|
|
1985
|
-
|
|
2009
|
+
if ungrouped_intermediates:
|
|
2010
|
+
ungrouped_names = [n.name for n in ungrouped_intermediates]
|
|
2011
|
+
msg = (
|
|
2012
|
+
f"Invalid node group '{self.name}': Found ungrouped nodes between grouped nodes. "
|
|
2013
|
+
f"Ungrouped nodes {ungrouped_names} exist on the path from '{node_a.name}' to '{node_b.name}'. "
|
|
2014
|
+
f"All nodes on paths between grouped nodes must be part of the same group."
|
|
2015
|
+
)
|
|
2016
|
+
raise ValueError(msg)
|
|
1986
2017
|
|
|
1987
|
-
|
|
1988
|
-
|
|
1989
|
-
it passes the entire NodeGroup to the NodeExecutor which handles parallel
|
|
1990
|
-
execution of all grouped nodes.
|
|
2018
|
+
def track_internal_connection(self, conn: Connection) -> None:
|
|
2019
|
+
"""Track a connection between nodes within the group.
|
|
1991
2020
|
|
|
1992
|
-
|
|
1993
|
-
|
|
2021
|
+
Args:
|
|
2022
|
+
conn: The internal connection to track
|
|
2023
|
+
"""
|
|
2024
|
+
if conn not in self.stored_connections.internal_connections:
|
|
2025
|
+
self.stored_connections.internal_connections.append(conn)
|
|
1994
2026
|
|
|
1995
|
-
|
|
1996
|
-
|
|
1997
|
-
|
|
2027
|
+
def track_external_connection(
|
|
2028
|
+
self,
|
|
2029
|
+
conn: Connection,
|
|
2030
|
+
conn_id: int,
|
|
2031
|
+
is_incoming: bool, # noqa: FBT001
|
|
2032
|
+
grouped_node: BaseNode,
|
|
2033
|
+
) -> None:
|
|
2034
|
+
"""Track a connection to/from a node in the group.
|
|
1998
2035
|
|
|
1999
|
-
|
|
2036
|
+
Args:
|
|
2037
|
+
conn: The external connection to track
|
|
2038
|
+
conn_id: ID of the connection
|
|
2039
|
+
is_incoming: True if connection is coming INTO the group
|
|
2040
|
+
grouped_node: The node in the group involved in the connection
|
|
2041
|
+
"""
|
|
2042
|
+
if is_incoming:
|
|
2043
|
+
if conn not in self.stored_connections.external_connections.incoming_connections:
|
|
2044
|
+
self.stored_connections.external_connections.incoming_connections.append(conn)
|
|
2045
|
+
self.stored_connections.original_targets.incoming_sources[conn_id] = grouped_node
|
|
2046
|
+
else:
|
|
2047
|
+
if conn not in self.stored_connections.external_connections.outgoing_connections:
|
|
2048
|
+
self.stored_connections.external_connections.outgoing_connections.append(conn)
|
|
2049
|
+
self.stored_connections.original_targets.outgoing_targets[conn_id] = grouped_node
|
|
2050
|
+
|
|
2051
|
+
def untrack_internal_connection(self, conn: Connection) -> None:
|
|
2052
|
+
"""Remove tracking of an internal connection.
|
|
2053
|
+
|
|
2054
|
+
Args:
|
|
2055
|
+
conn: The internal connection to untrack
|
|
2056
|
+
"""
|
|
2057
|
+
if conn in self.stored_connections.internal_connections:
|
|
2058
|
+
self.stored_connections.internal_connections.remove(conn)
|
|
2059
|
+
|
|
2060
|
+
def untrack_external_connection(
|
|
2000
2061
|
self,
|
|
2001
|
-
|
|
2002
|
-
|
|
2003
|
-
|
|
2062
|
+
conn: Connection,
|
|
2063
|
+
conn_id: int,
|
|
2064
|
+
is_incoming: bool, # noqa: FBT001
|
|
2004
2065
|
) -> None:
|
|
2005
|
-
|
|
2006
|
-
self.node_group_data = node_group
|
|
2066
|
+
"""Remove tracking of an external connection.
|
|
2007
2067
|
|
|
2008
|
-
|
|
2009
|
-
|
|
2010
|
-
|
|
2011
|
-
|
|
2012
|
-
|
|
2013
|
-
|
|
2014
|
-
|
|
2015
|
-
|
|
2016
|
-
|
|
2017
|
-
|
|
2018
|
-
|
|
2019
|
-
|
|
2020
|
-
|
|
2021
|
-
|
|
2068
|
+
Args:
|
|
2069
|
+
conn: The external connection to untrack
|
|
2070
|
+
conn_id: ID of the connection
|
|
2071
|
+
is_incoming: True if connection was coming INTO the group
|
|
2072
|
+
"""
|
|
2073
|
+
if is_incoming:
|
|
2074
|
+
if conn in self.stored_connections.external_connections.incoming_connections:
|
|
2075
|
+
self.stored_connections.external_connections.incoming_connections.remove(conn)
|
|
2076
|
+
if conn_id in self.stored_connections.original_targets.incoming_sources:
|
|
2077
|
+
del self.stored_connections.original_targets.incoming_sources[conn_id]
|
|
2078
|
+
else:
|
|
2079
|
+
if conn in self.stored_connections.external_connections.outgoing_connections:
|
|
2080
|
+
self.stored_connections.external_connections.outgoing_connections.remove(conn)
|
|
2081
|
+
if conn_id in self.stored_connections.original_targets.outgoing_targets:
|
|
2082
|
+
del self.stored_connections.original_targets.outgoing_targets[conn_id]
|
|
2083
|
+
|
|
2084
|
+
def _remove_nodes_from_existing_parents(self, nodes: list[BaseNode]) -> None:
|
|
2085
|
+
"""Remove nodes from their existing parent groups."""
|
|
2086
|
+
child_nodes = {}
|
|
2087
|
+
for node in nodes:
|
|
2088
|
+
if node.parent_group is not None:
|
|
2089
|
+
existing_parent_group = node.parent_group
|
|
2090
|
+
if isinstance(existing_parent_group, NodeGroupNode):
|
|
2091
|
+
child_nodes.setdefault(existing_parent_group, []).append(node)
|
|
2092
|
+
for parent_group, node_list in child_nodes.items():
|
|
2093
|
+
parent_group.remove_nodes_from_group(node_list)
|
|
2094
|
+
|
|
2095
|
+
def _add_nodes_to_group_dict(self, nodes: list[BaseNode]) -> None:
|
|
2096
|
+
"""Add nodes to the group's node dictionary."""
|
|
2097
|
+
for node in nodes:
|
|
2098
|
+
node.parent_group = self
|
|
2099
|
+
self.nodes[node.name] = node
|
|
2100
|
+
|
|
2101
|
+
def _track_incoming_connections(self, node: BaseNode, connections: Any, node_names_in_group: set[str]) -> None:
|
|
2102
|
+
"""Track incoming external connections for a node."""
|
|
2103
|
+
if node.name not in connections.incoming_index:
|
|
2104
|
+
return
|
|
2105
|
+
|
|
2106
|
+
for connection_ids in connections.incoming_index[node.name].values():
|
|
2107
|
+
for conn_id in connection_ids:
|
|
2108
|
+
if conn_id not in connections.connections:
|
|
2109
|
+
continue
|
|
2110
|
+
conn = connections.connections[conn_id]
|
|
2022
2111
|
|
|
2023
|
-
|
|
2024
|
-
|
|
2112
|
+
if conn.source_node.name not in node_names_in_group:
|
|
2113
|
+
self.track_external_connection(conn, conn_id, is_incoming=True, grouped_node=node)
|
|
2114
|
+
elif conn not in self.stored_connections.internal_connections:
|
|
2115
|
+
self.track_internal_connection(conn)
|
|
2025
2116
|
|
|
2026
|
-
|
|
2027
|
-
|
|
2028
|
-
|
|
2117
|
+
def _track_outgoing_connections(self, node: BaseNode, connections: Any, node_names_in_group: set[str]) -> None:
|
|
2118
|
+
"""Track outgoing external connections for a node."""
|
|
2119
|
+
if node.name not in connections.outgoing_index:
|
|
2120
|
+
return
|
|
2121
|
+
|
|
2122
|
+
for connection_ids in connections.outgoing_index[node.name].values():
|
|
2123
|
+
for conn_id in connection_ids:
|
|
2124
|
+
if conn_id not in connections.connections:
|
|
2125
|
+
continue
|
|
2126
|
+
conn = connections.connections[conn_id]
|
|
2029
2127
|
|
|
2030
|
-
|
|
2031
|
-
|
|
2128
|
+
if conn.target_node.name not in node_names_in_group:
|
|
2129
|
+
self.track_external_connection(conn, conn_id, is_incoming=False, grouped_node=node)
|
|
2130
|
+
|
|
2131
|
+
def add_nodes_to_group(self, nodes: list[BaseNode]) -> None:
|
|
2132
|
+
"""Add nodes to the group and track their connections.
|
|
2133
|
+
|
|
2134
|
+
Args:
|
|
2135
|
+
nodes: List of nodes to add to the group
|
|
2032
2136
|
"""
|
|
2033
|
-
|
|
2034
|
-
created_params = set()
|
|
2137
|
+
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
2035
2138
|
|
|
2036
|
-
|
|
2037
|
-
|
|
2038
|
-
|
|
2039
|
-
|
|
2040
|
-
|
|
2041
|
-
|
|
2042
|
-
|
|
2043
|
-
|
|
2139
|
+
self._remove_nodes_from_existing_parents(nodes)
|
|
2140
|
+
self._add_nodes_to_group_dict(nodes)
|
|
2141
|
+
|
|
2142
|
+
connections = GriptapeNodes.FlowManager().get_connections()
|
|
2143
|
+
node_names_in_group = set(self.nodes.keys())
|
|
2144
|
+
self.metadata["node_names_in_group"] = list(node_names_in_group)
|
|
2145
|
+
|
|
2146
|
+
nodes_being_added = {node.name for node in nodes}
|
|
2147
|
+
internal_conns = connections.get_connections_between_nodes(nodes_being_added)
|
|
2148
|
+
for conn in internal_conns:
|
|
2149
|
+
self.track_internal_connection(conn)
|
|
2150
|
+
|
|
2151
|
+
for node in nodes:
|
|
2152
|
+
self._track_incoming_connections(node, connections, node_names_in_group)
|
|
2153
|
+
self._track_outgoing_connections(node, connections, node_names_in_group)
|
|
2154
|
+
|
|
2155
|
+
def _validate_nodes_in_group(self, nodes: list[BaseNode]) -> None:
|
|
2156
|
+
"""Validate that all nodes are in the group."""
|
|
2157
|
+
for node in nodes:
|
|
2158
|
+
if node.name not in self.nodes:
|
|
2159
|
+
msg = f"Node {node.name} is not in node group {self.name}"
|
|
2044
2160
|
raise ValueError(msg)
|
|
2045
|
-
target_param = conn.target_parameter
|
|
2046
|
-
|
|
2047
|
-
# Create proxy parameter name: {sanitized_node_name}__{param_name}
|
|
2048
|
-
sanitized_node_name = target_node.name.replace(" ", "_")
|
|
2049
|
-
proxy_param_name = f"{sanitized_node_name}__{target_param.name}"
|
|
2050
|
-
|
|
2051
|
-
if proxy_param_name not in created_params:
|
|
2052
|
-
proxy_param = Parameter(
|
|
2053
|
-
name=proxy_param_name,
|
|
2054
|
-
type=target_param.type,
|
|
2055
|
-
input_types=target_param.input_types,
|
|
2056
|
-
output_type=target_param.output_type,
|
|
2057
|
-
tooltip=f"Proxy input for {target_node.name}.{target_param.name}",
|
|
2058
|
-
allowed_modes={ParameterMode.INPUT},
|
|
2059
|
-
)
|
|
2060
|
-
self.add_parameter(proxy_param)
|
|
2061
|
-
created_params.add(proxy_param_name)
|
|
2062
2161
|
|
|
2063
|
-
|
|
2064
|
-
|
|
2162
|
+
def _untrack_external_incoming_for_node(self, node: BaseNode) -> None:
|
|
2163
|
+
"""Untrack external incoming connections for a node."""
|
|
2164
|
+
for conn in list(self.stored_connections.external_connections.incoming_connections):
|
|
2165
|
+
conn_id = id(conn)
|
|
2166
|
+
original_target = self.stored_connections.original_targets.incoming_sources.get(conn_id)
|
|
2167
|
+
if original_target and original_target.name == node.name:
|
|
2168
|
+
self.untrack_external_connection(conn, conn_id, is_incoming=True)
|
|
2065
2169
|
|
|
2066
|
-
|
|
2067
|
-
for
|
|
2170
|
+
def _untrack_external_outgoing_for_node(self, node: BaseNode) -> None:
|
|
2171
|
+
"""Untrack external outgoing connections for a node."""
|
|
2172
|
+
for conn in list(self.stored_connections.external_connections.outgoing_connections):
|
|
2068
2173
|
conn_id = id(conn)
|
|
2069
|
-
|
|
2070
|
-
|
|
2071
|
-
|
|
2072
|
-
|
|
2174
|
+
original_source = self.stored_connections.original_targets.outgoing_targets.get(conn_id)
|
|
2175
|
+
if original_source and original_source.name == node.name:
|
|
2176
|
+
self.untrack_external_connection(conn, conn_id, is_incoming=False)
|
|
2177
|
+
|
|
2178
|
+
def _untrack_internal_for_node(self, node: BaseNode, nodes_being_removed: set[str]) -> None:
|
|
2179
|
+
"""Untrack internal connections for a node."""
|
|
2180
|
+
for conn in list(self.stored_connections.internal_connections):
|
|
2181
|
+
if node.name not in (conn.source_node.name, conn.target_node.name):
|
|
2073
2182
|
continue
|
|
2074
2183
|
|
|
2075
|
-
|
|
2184
|
+
other_node_name = conn.target_node.name if conn.source_node.name == node.name else conn.source_node.name
|
|
2185
|
+
if other_node_name in nodes_being_removed or other_node_name not in self.nodes:
|
|
2186
|
+
self.untrack_internal_connection(conn)
|
|
2076
2187
|
|
|
2077
|
-
|
|
2078
|
-
|
|
2079
|
-
proxy_param_name = f"{sanitized_node_name}__{source_param.name}"
|
|
2188
|
+
def has_external_control_input(self) -> bool:
|
|
2189
|
+
"""Check if this NodeGroup has any external incoming control connections.
|
|
2080
2190
|
|
|
2081
|
-
|
|
2082
|
-
|
|
2083
|
-
|
|
2084
|
-
|
|
2085
|
-
|
|
2086
|
-
|
|
2087
|
-
|
|
2088
|
-
|
|
2089
|
-
|
|
2090
|
-
|
|
2091
|
-
|
|
2191
|
+
Returns:
|
|
2192
|
+
True if any external incoming connection is a control input, False otherwise
|
|
2193
|
+
"""
|
|
2194
|
+
from griptape_nodes.exe_types.core_types import ParameterTypeBuiltin
|
|
2195
|
+
|
|
2196
|
+
for conn in self.stored_connections.external_connections.incoming_connections:
|
|
2197
|
+
if conn.target_parameter.type == ParameterTypeBuiltin.CONTROL_TYPE:
|
|
2198
|
+
return True
|
|
2199
|
+
if ParameterTypeBuiltin.CONTROL_TYPE.value in conn.target_parameter.input_types:
|
|
2200
|
+
return True
|
|
2201
|
+
|
|
2202
|
+
return False
|
|
2203
|
+
|
|
2204
|
+
def remove_nodes_from_group(self, nodes: list[BaseNode]) -> None:
|
|
2205
|
+
"""Remove nodes from the group and untrack their connections.
|
|
2206
|
+
|
|
2207
|
+
Args:
|
|
2208
|
+
nodes: List of nodes to remove from the group
|
|
2209
|
+
"""
|
|
2210
|
+
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
2211
|
+
|
|
2212
|
+
self._validate_nodes_in_group(nodes)
|
|
2213
|
+
|
|
2214
|
+
GriptapeNodes.FlowManager().get_connections()
|
|
2215
|
+
nodes_being_removed = {node.name for node in nodes}
|
|
2216
|
+
|
|
2217
|
+
for node in nodes:
|
|
2218
|
+
self._untrack_external_incoming_for_node(node)
|
|
2219
|
+
self._untrack_external_outgoing_for_node(node)
|
|
2220
|
+
self._untrack_internal_for_node(node, nodes_being_removed)
|
|
2221
|
+
|
|
2222
|
+
for node in nodes:
|
|
2223
|
+
node.parent_group = None
|
|
2224
|
+
self.nodes.pop(node.name)
|
|
2092
2225
|
|
|
2093
|
-
|
|
2094
|
-
self._proxy_param_to_node_param[proxy_param_name] = (source_node, source_param.name)
|
|
2226
|
+
self.metadata["node_names_in_group"] = list(self.nodes.keys())
|
|
2095
2227
|
|
|
2096
2228
|
async def aprocess(self) -> None:
|
|
2097
2229
|
"""Execute all nodes in the group in parallel.
|
|
@@ -2100,13 +2232,9 @@ class NodeGroupProxyNode(BaseNode):
|
|
|
2100
2232
|
group concurrently using asyncio.gather and handles propagating input
|
|
2101
2233
|
values from the proxy to the grouped nodes.
|
|
2102
2234
|
"""
|
|
2103
|
-
msg = "NodeGroupProxyNode should not be executed locally."
|
|
2104
|
-
raise NotImplementedError(msg)
|
|
2105
2235
|
|
|
2106
2236
|
def process(self) -> Any:
|
|
2107
2237
|
"""Synchronous process method - not used for proxy nodes."""
|
|
2108
|
-
msg = "NodeGroupProxyNode should use aprocess() for async execution."
|
|
2109
|
-
raise NotImplementedError(msg)
|
|
2110
2238
|
|
|
2111
2239
|
|
|
2112
2240
|
class Connection:
|