griptape-nodes 0.58.1__py3-none-any.whl → 0.59.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- griptape_nodes/bootstrap/utils/python_subprocess_executor.py +2 -2
- griptape_nodes/bootstrap/workflow_executors/local_session_workflow_executor.py +0 -5
- griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +9 -5
- griptape_nodes/bootstrap/workflow_executors/subprocess_workflow_executor.py +0 -1
- griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +1 -3
- griptape_nodes/bootstrap/workflow_publishers/local_workflow_publisher.py +1 -1
- griptape_nodes/cli/commands/init.py +53 -7
- griptape_nodes/cli/shared.py +1 -0
- griptape_nodes/common/node_executor.py +216 -40
- griptape_nodes/exe_types/core_types.py +46 -0
- griptape_nodes/exe_types/node_types.py +272 -0
- griptape_nodes/machines/control_flow.py +222 -16
- griptape_nodes/machines/dag_builder.py +212 -1
- griptape_nodes/machines/parallel_resolution.py +237 -4
- griptape_nodes/node_library/workflow_registry.py +1 -1
- griptape_nodes/retained_mode/events/execution_events.py +5 -4
- griptape_nodes/retained_mode/events/flow_events.py +17 -67
- griptape_nodes/retained_mode/events/parameter_events.py +122 -1
- griptape_nodes/retained_mode/managers/event_manager.py +17 -13
- griptape_nodes/retained_mode/managers/flow_manager.py +316 -573
- griptape_nodes/retained_mode/managers/library_manager.py +32 -20
- griptape_nodes/retained_mode/managers/model_manager.py +19 -8
- griptape_nodes/retained_mode/managers/node_manager.py +463 -3
- griptape_nodes/retained_mode/managers/object_manager.py +2 -2
- griptape_nodes/retained_mode/managers/workflow_manager.py +37 -46
- griptape_nodes/retained_mode/retained_mode.py +297 -3
- {griptape_nodes-0.58.1.dist-info → griptape_nodes-0.59.0.dist-info}/METADATA +3 -2
- {griptape_nodes-0.58.1.dist-info → griptape_nodes-0.59.0.dist-info}/RECORD +30 -30
- {griptape_nodes-0.58.1.dist-info → griptape_nodes-0.59.0.dist-info}/WHEEL +1 -1
- {griptape_nodes-0.58.1.dist-info → griptape_nodes-0.59.0.dist-info}/entry_points.txt +0 -0
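
The bulk of this release is node-group execution: before a flow runs, nodes sharing a "job_group" value are collapsed behind a single proxy node, external connections are remapped to that proxy, and the proxy is what enters the DAG. A minimal sketch of the node-to-proxy mapping idea follows; SimpleNode and group_nodes_by_id are hypothetical stand-ins, not the package's real BaseNode / NodeGroup / NodeGroupProxyNode classes.

# Minimal sketch of the node-to-proxy mapping pattern this release introduces.
# SimpleNode and group_nodes_by_id are hypothetical stand-ins, not the
# package's real BaseNode / NodeGroup / NodeGroupProxyNode classes.
from dataclasses import dataclass


@dataclass(frozen=True)
class SimpleNode:
    name: str
    job_group: str | None = None  # mirrors the "job_group" parameter the diff scans


def group_nodes_by_id(nodes: list[SimpleNode]) -> dict[SimpleNode, str]:
    """Map each grouped node to the name of the proxy that replaces it."""
    node_to_proxy: dict[SimpleNode, str] = {}
    for node in nodes:
        if node.job_group:  # skip nodes with no group or an empty group id
            node_to_proxy[node] = f"__group_proxy_{node.job_group}"
    return node_to_proxy


nodes = [SimpleNode("a", "g1"), SimpleNode("b", "g1"), SimpleNode("c")]
assert group_nodes_by_id(nodes) == {nodes[0]: "__group_proxy_g1", nodes[1]: "__group_proxy_g1"}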
--- a/griptape_nodes/machines/control_flow.py
+++ b/griptape_nodes/machines/control_flow.py
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
 from griptape_nodes.exe_types.core_types import Parameter, ParameterTypeBuiltin
 from griptape_nodes.exe_types.node_types import CONTROL_INPUT_PARAMETER, LOCAL_EXECUTION, BaseNode, NodeResolutionState
 from griptape_nodes.machines.fsm import FSM, State
-from griptape_nodes.machines.parallel_resolution import ParallelResolutionMachine
+from griptape_nodes.machines.parallel_resolution import ExecuteDagState, ParallelResolutionMachine
 from griptape_nodes.machines.sequential_resolution import SequentialResolutionMachine
 from griptape_nodes.retained_mode.events.base_events import ExecutionEvent, ExecutionGriptapeNodeEvent
 from griptape_nodes.retained_mode.events.execution_events import (
@@ -20,6 +20,11 @@ from griptape_nodes.retained_mode.events.execution_events import (
 from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
 from griptape_nodes.retained_mode.managers.settings import WorkflowExecutionMode
 
+if TYPE_CHECKING:
+    from griptape_nodes.exe_types.connections import Connections
+    from griptape_nodes.exe_types.flow import ControlFlow
+    from griptape_nodes.exe_types.node_types import NodeGroup
+
 
 @dataclass
 class NextNodeInfo:
@@ -45,6 +50,8 @@ class ControlFlowContext:
     paused: bool = False
     flow_name: str
     pickle_control_flow_result: bool
+    node_to_proxy_map: dict[BaseNode, BaseNode]
+    end_node: BaseNode | None = None
 
     def __init__(
         self,
@@ -67,6 +74,7 @@ class ControlFlowContext:
         self.resolution_machine = SequentialResolutionMachine()
         self.current_nodes = []
         self.pickle_control_flow_result = pickle_control_flow_result
+        self.node_to_proxy_map = {}
 
     def get_next_nodes(self, output_parameter: Parameter | None = None) -> list[NextNodeInfo]:
         """Get all next nodes from the current nodes.
@@ -171,7 +179,6 @@ class ResolveNodeState(State):
         # Resolve nodes - pass first node for sequential resolution
         current_node = context.current_nodes[0] if context.current_nodes else None
         await context.resolution_machine.resolve_node(current_node)
-
         if context.resolution_machine.is_complete():
             # Get the last resolved node from the DAG and set it as current
             if isinstance(context.resolution_machine, ParallelResolutionMachine):
@@ -179,6 +186,8 @@ class ResolveNodeState(State):
                 if last_resolved_node:
                     context.current_nodes = [last_resolved_node]
                 return CompleteState
+            if context.end_node == current_node:
+                return CompleteState
             return NextNodeState
         return None
 
@@ -261,6 +270,7 @@ class CompleteState(State):
                 )
             )
         )
+        context.end_node = None
         logger.info("Flow is complete.")
         return None
 
@@ -284,14 +294,40 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
         )
         super().__init__(context)
 
-    async def start_flow(
+    async def start_flow(
+        self, start_node: BaseNode, end_node: BaseNode | None = None, *, debug_mode: bool = False
+    ) -> None:
+        # FIRST: Scan all nodes in the flow and create node groups BEFORE any resolution
+        flow_manager = GriptapeNodes.FlowManager()
+        flow = flow_manager.get_flow_by_name(self._context.flow_name)
+        logger.debug("Scanning flow '%s' for node groups before execution", self._context.flow_name)
+
+        try:
+            node_to_proxy_map = self._identify_and_create_node_group_proxies(flow, flow_manager.get_connections())
+            if node_to_proxy_map:
+                logger.info(
+                    "Created %d proxy nodes for %d grouped nodes in flow '%s'",
+                    len(set(node_to_proxy_map.values())),
+                    len(node_to_proxy_map),
+                    self._context.flow_name,
+                )
+                # Store the mapping in context so it can be used by resolution machines
+                self._context.node_to_proxy_map = node_to_proxy_map
+        except ValueError as e:
+            logger.error("Failed to process node groups: %s", e)
+            raise
+
+        # Determine the actual start node (use proxy if it's part of a group)
+        actual_start_node = node_to_proxy_map.get(start_node, start_node)
+
         # If using DAG resolution, process data_nodes from queue first
         if isinstance(self._context.resolution_machine, ParallelResolutionMachine):
-            current_nodes = await self._process_nodes_for_dag(
+            current_nodes = await self._process_nodes_for_dag(actual_start_node)
         else:
-            current_nodes = [
+            current_nodes = [actual_start_node]
         # For control flow/sequential: emit all nodes in flow as involved
         self._context.current_nodes = current_nodes
+        self._context.end_node = end_node
         # Set entry control parameter for initial node (None for workflow start)
         for node in current_nodes:
             node.set_entry_control_parameter(None)
@@ -299,12 +335,14 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
         self._context.paused = debug_mode
         flow_manager = GriptapeNodes.FlowManager()
         flow = flow_manager.get_flow_by_name(self._context.flow_name)
-
-
-
-
+        if start_node != end_node:
+            # This blocks all nodes in the entire flow from running. If we're just resolving one node, we don't want to block that.
+            involved_nodes = list(flow.nodes.keys())
+            GriptapeNodes.EventManager().put_event(
+                ExecutionGriptapeNodeEvent(
+                    wrapped_event=ExecutionEvent(payload=InvolvedNodesEvent(involved_nodes=involved_nodes))
+                )
             )
-        )
         await self.start(ResolveNodeState)  # Begins the flow
 
     async def update(self) -> None:
@@ -347,7 +385,7 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
         ):
             await self.update()
 
-    async def _process_nodes_for_dag(self, start_node: BaseNode) -> list[BaseNode]:
+    async def _process_nodes_for_dag(self, start_node: BaseNode) -> list[BaseNode]:  # noqa: C901
         """Process data_nodes from the global queue to build unified DAG.
 
         This method identifies data_nodes in the execution queue and processes
@@ -361,7 +399,11 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
         if dag_builder is None:
             msg = "DAG builder is not initialized."
             raise ValueError(msg)
-
+
+        # Use the node-to-proxy map that was created in start_flow
+        node_to_proxy_map = self._context.node_to_proxy_map
+
+        # Build with the first node (it should already be the proxy if it's part of a group)
         dag_builder.add_node_with_dependencies(start_node, start_node.name)
         queue_items = list(flow_manager.global_flow_queue.queue)
         start_nodes = [start_node]
@@ -372,17 +414,165 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
             if item.dag_execution_type in (DagExecutionType.CONTROL_NODE, DagExecutionType.START_NODE):
                 node = item.node
                 node.state = NodeResolutionState.UNRESOLVED
-
+                # Use proxy node if this node is part of a group, otherwise use original node
+                if node in node_to_proxy_map:
+                    node_to_add = node_to_proxy_map[node]
+                else:
+                    node_to_add = node
+                # Only add if not already added (proxy might already be in DAG)
+                if node_to_add.name not in dag_builder.node_to_reference:
+                    dag_builder.add_node_with_dependencies(node_to_add, node_to_add.name)
+                if node_to_add not in start_nodes:
+                    start_nodes.append(node_to_add)
                 flow_manager.global_flow_queue.queue.remove(item)
-                start_nodes.append(node)
             elif item.dag_execution_type == DagExecutionType.DATA_NODE:
                 node = item.node
                 node.state = NodeResolutionState.UNRESOLVED
-                #
-
+                # Use proxy node if this node is part of a group, otherwise use original node
+                if node in node_to_proxy_map:
+                    node_to_add = node_to_proxy_map[node]
+                else:
+                    node_to_add = node
+                # Only add if not already added (proxy might already be in DAG)
+                if node_to_add.name not in dag_builder.node_to_reference:
+                    dag_builder.add_node_with_dependencies(node_to_add, node_to_add.name)
                 flow_manager.global_flow_queue.queue.remove(item)
+
 
         return start_nodes
+
+    def _identify_and_create_node_group_proxies(
+        self, flow: ControlFlow, connections: Connections
+    ) -> dict[BaseNode, BaseNode]:
+        """Scan all nodes in flow, identify groups, and create proxy nodes.
+
+        Returns:
+            Dictionary mapping original nodes to their proxy nodes (only for grouped nodes)
+        """
+        from griptape_nodes.exe_types.node_types import NodeGroup, NodeGroupProxyNode
+
+        # Step 1: Identify groups by scanning all nodes in the flow
+        groups: dict[str, NodeGroup] = {}
+        for node in flow.nodes.values():
+            group_id = node.get_parameter_value("job_group")
+
+            # Skip nodes without group assignment, empty group ID, or locked nodes
+            if not group_id or group_id == "" or node.lock:
+                continue
+
+            # Create group if it doesn't exist
+            if group_id not in groups:
+                groups[group_id] = NodeGroup(group_id=group_id)
+
+            # Add node to group
+            groups[group_id].add_node(node)
+
+        if not groups:
+            return {}
+
+        # Step 2: Analyze connections for each group
+        for group in groups.values():
+            self._analyze_group_connections(group, connections)
+
+        # Step 3: Validate each group
+        for group in groups.values():
+            group.validate_no_intermediate_nodes(connections.connections)
+
+        # Step 4: Create proxy nodes and build mapping
+        node_to_proxy_map: dict[BaseNode, BaseNode] = {}
+        for group_id, group in groups.items():
+            # Create proxy node
+            proxy_name = f"__group_proxy_{group_id}"
+            proxy_node = NodeGroupProxyNode(name=proxy_name, node_group=group)
+
+            # Register the proxy node with ObjectManager so it can be found during parameter updates
+            obj_manager = GriptapeNodes.ObjectManager()
+            obj_manager.add_object_by_name(proxy_name, proxy_node)
+
+            # Map all grouped nodes to this proxy
+            for node in group.nodes.values():
+                node_to_proxy_map[node] = proxy_node
+
+            # Remap connections to point to proxy
+            self._remap_connections_to_proxy_node(group, proxy_node, connections)
+
+            # Now create proxy parameters (after remapping so original references are saved)
+            proxy_node.create_proxy_parameters()
+
+        return node_to_proxy_map
+
+    def _analyze_group_connections(self, group: NodeGroup, connections: Connections) -> None:
+        """Analyze and categorize connections for a node group."""
+        node_names_in_group = group.nodes.keys()
+
+        # Analyze all connections in the flow
+        for conn in connections.connections.values():
+            source_in_group = conn.source_node.name in node_names_in_group
+            target_in_group = conn.target_node.name in node_names_in_group
+
+            if source_in_group and target_in_group:
+                # Both endpoints in group - internal connection
+                group.internal_connections.append(conn)
+            elif source_in_group and not target_in_group:
+                # From group to outside - external outgoing
+                group.external_outgoing_connections.append(conn)
+            elif not source_in_group and target_in_group:
+                # From outside to group - external incoming
+                group.external_incoming_connections.append(conn)
+
+    def _remap_connections_to_proxy_node(
+        self, group: NodeGroup, proxy_node: BaseNode, connections: Connections
+    ) -> None:
+        """Remap external connections from group nodes to the proxy node."""
+        # Remap external incoming connections (from outside -> group becomes outside -> proxy)
+        for conn in group.external_incoming_connections:
+            conn_id = id(conn)
+
+            # Save original target node before remapping (for cleanup later)
+            original_target_node = conn.target_node
+            group.original_incoming_targets[conn_id] = original_target_node
+
+            # Remove old incoming index entry
+            if (
+                conn.target_node.name in connections.incoming_index
+                and conn.target_parameter.name in connections.incoming_index[conn.target_node.name]
+            ):
+                connections.incoming_index[conn.target_node.name][conn.target_parameter.name].remove(conn_id)
+
+            # Update connection target to proxy
+            conn.target_node = proxy_node
+
+            # Create proxy parameter name using original node name
+            sanitized_node_name = original_target_node.name.replace(" ", "_")
+            proxy_param_name = f"{sanitized_node_name}__{conn.target_parameter.name}"
+
+            # Add new incoming index entry with proxy parameter name
+            connections.incoming_index.setdefault(proxy_node.name, {}).setdefault(proxy_param_name, []).append(conn_id)
+
+        # Remap external outgoing connections (group -> outside becomes proxy -> outside)
+        for conn in group.external_outgoing_connections:
+            conn_id = id(conn)
+
+            # Save original source node before remapping (for cleanup later)
+            original_source_node = conn.source_node
+            group.original_outgoing_sources[conn_id] = original_source_node
+
+            # Remove old outgoing index entry
+            if (
+                conn.source_node.name in connections.outgoing_index
+                and conn.source_parameter.name in connections.outgoing_index[conn.source_node.name]
+            ):
+                connections.outgoing_index[conn.source_node.name][conn.source_parameter.name].remove(conn_id)
+
+            # Update connection source to proxy
+            conn.source_node = proxy_node
+
+            # Create proxy parameter name using original node name
+            sanitized_node_name = original_source_node.name.replace(" ", "_")
+            proxy_param_name = f"{sanitized_node_name}__{conn.source_parameter.name}"
+
+            # Add new outgoing index entry with proxy parameter name
+            connections.outgoing_index.setdefault(proxy_node.name, {}).setdefault(proxy_param_name, []).append(conn_id)
+
     async def cancel_flow(self) -> None:
         """Cancel all nodes in the flow by delegating to the resolution machine."""
         await self.resolution_machine.cancel_all_nodes()
@@ -391,6 +581,22 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
         self._context.reset(cancel=cancel)
         self._current_state = None
 
+    def cleanup_proxy_nodes(self) -> None:
+        """Cleanup all proxy nodes and restore original connections."""
+        if not self._context.node_to_proxy_map:
+            # If we're calling cleanup, but it's already been cleaned up, we just want to return.
+            return
+
+        # Get all unique proxy nodes
+        proxy_nodes = set(self._context.node_to_proxy_map.values())
+
+        # Cleanup each proxy node using the existing method
+        for proxy_node in proxy_nodes:
+            ExecuteDagState._cleanup_proxy_node(proxy_node)
+
+        # Clear the proxy mapping
+        self._context.node_to_proxy_map.clear()
+
     @property
     def resolution_machine(self) -> ParallelResolutionMachine | SequentialResolutionMachine:
         return self._context.resolution_machine
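
The two remapping loops above share one bookkeeping pattern: remove the connection id from the original node's index bucket, repoint the Connection at the proxy, then re-file the id under a parameter name namespaced with the original node's name. A dict-only sketch of that bookkeeping follows; the {node: {param: [conn_id]}} index shape is an assumption read off the diff, not the package's real Connections class.

# Dict-only sketch of the index remap performed by _remap_connections_to_proxy_node.
# The {node_name: {param_name: [conn_id, ...]}} shape is an assumption based on
# the diff; the real Connections class carries more state.
def remap_incoming(incoming_index, conn_id, old_node, old_param, proxy_name):
    # Drop the id from the original target's bucket, if present.
    bucket = incoming_index.get(old_node, {}).get(old_param)
    if bucket and conn_id in bucket:
        bucket.remove(conn_id)
    # Re-file under the proxy, namespacing the parameter with the original
    # node name so two grouped nodes with the same parameter name don't collide.
    proxy_param = f"{old_node.replace(' ', '_')}__{old_param}"
    incoming_index.setdefault(proxy_name, {}).setdefault(proxy_param, []).append(conn_id)


index = {"My Node": {"value": [101]}}
remap_incoming(index, 101, "My Node", "value", "__group_proxy_g1")
assert index["__group_proxy_g1"]["My_Node__value"] == [101]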
--- a/griptape_nodes/machines/dag_builder.py
+++ b/griptape_nodes/machines/dag_builder.py
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
     import asyncio
 
     from griptape_nodes.exe_types.connections import Connections
-    from griptape_nodes.exe_types.node_types import BaseNode
+    from griptape_nodes.exe_types.node_types import BaseNode, NodeGroup
 
 logger = logging.getLogger("griptape_nodes")
 
@@ -224,3 +224,214 @@ class DagBuilder:
         for node_name in self.graph_to_nodes[graph_name]:
             self.node_to_reference.pop(node_name, None)
         self.graph_to_nodes.pop(graph_name, None)
+
+    def identify_and_process_node_groups(self) -> dict[str, BaseNode]:
+        """Identify node groups, validate them, and replace with proxy nodes.
+
+        Scans all nodes in the DAG for non-empty node_group parameter values,
+        creates NodeGroup instances, validates they have no intermediate ungrouped nodes,
+        and replaces them with NodeGroupProxyNode instances in the DAG.
+
+        Returns:
+            Dictionary mapping group IDs to their proxy nodes
+
+        Raises:
+            ValueError: If validation fails (e.g., ungrouped nodes between grouped nodes)
+        """
+        from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
+
+        connections = GriptapeNodes.FlowManager().get_connections()
+
+        # Step 1: Identify groups by scanning all nodes
+        groups = self._identify_node_groups(connections)
+
+        if not groups:
+            return {}
+
+        # Step 2: Validate each group
+        for group in groups.values():
+            group.validate_no_intermediate_nodes(connections.connections)
+
+        # Step 3: Create proxy nodes and replace groups in DAG
+        proxy_nodes = {}
+        for group_id, group in groups.items():
+            proxy_node = self._create_and_install_proxy_node(group_id, group, connections)
+            proxy_nodes[group_id] = proxy_node
+
+        return proxy_nodes
+
+    def _identify_node_groups(self, connections: Connections) -> dict[str, NodeGroup]:
+        """Identify and build NodeGroup instances from nodes in the DAG.
+
+        Args:
+            connections: Connections object for analyzing connection topology
+
+        Returns:
+            Dictionary mapping group IDs to NodeGroup instances
+        """
+        from griptape_nodes.exe_types.node_types import NodeGroup
+
+        groups: dict[str, NodeGroup] = {}
+
+        # Scan all nodes in DAG for group membership
+        for dag_node in self.node_to_reference.values():
+            node = dag_node.node_reference
+            group_id = node.get_parameter_value("job_group")
+
+            # Skip nodes without group assignment or empty group ID
+            if not group_id or group_id == "":
+                continue
+
+            # Create group if it doesn't exist
+            if group_id not in groups:
+                groups[group_id] = NodeGroup(group_id=group_id)
+
+            # Add node to group
+            groups[group_id].add_node(node)
+
+        # Analyze connections for each group
+        for group in groups.values():
+            self._analyze_group_connections(group, connections)
+
+        return groups
+
+    def _analyze_group_connections(self, group: NodeGroup, connections: Connections) -> None:
+        """Analyze and categorize connections for a node group.
+
+        Categorizes all connections involving group nodes as either:
+        - Internal: Both endpoints within the group
+        - External incoming: From outside node to group node
+        - External outgoing: From group node to outside node
+
+        Args:
+            group: NodeGroup to analyze
+            connections: Connections object containing all flow connections
+        """
+        node_names_in_group = set(group.nodes.keys())
+
+        # Analyze all connections in the flow
+        for conn in connections.connections.values():
+            source_in_group = conn.source_node.name in node_names_in_group
+            target_in_group = conn.target_node.name in node_names_in_group
+
+            if source_in_group and target_in_group:
+                # Both endpoints in group - internal connection
+                group.internal_connections.append(conn)
+            elif source_in_group and not target_in_group:
+                # From group to outside - external outgoing
+                group.external_outgoing_connections.append(conn)
+            elif not source_in_group and target_in_group:
+                # From outside to group - external incoming
+                group.external_incoming_connections.append(conn)
+
+    def _create_and_install_proxy_node(self, group_id: str, group: NodeGroup, connections: Connections) -> BaseNode:
+        """Create a proxy node for a group and install it in the DAG.
+
+        Creates a NodeGroupProxyNode, adds it to the DAG, remaps all external
+        connections to point to the proxy, and removes the original grouped
+        nodes from the DAG.
+
+        Args:
+            group_id: Unique identifier for the group
+            group: NodeGroup instance to replace
+            connections: Connections object for remapping connections
+
+        Returns:
+            The created NodeGroupProxyNode
+        """
+        from griptape_nodes.exe_types.node_types import NodeGroupProxyNode
+
+        # Create proxy node with unique name
+        proxy_name = f"__group_proxy_{group_id}"
+        proxy_node = NodeGroupProxyNode(name=proxy_name, node_group=group)
+
+        # Determine which graph to add proxy to (use first grouped node's graph)
+        target_graph_name = None
+        for graph_name, node_set in self.graph_to_nodes.items():
+            if any(node_name in node_set for node_name in group.nodes):
+                target_graph_name = graph_name
+                break
+
+        if target_graph_name is None:
+            target_graph_name = "default"
+
+        # Add proxy node to DAG
+        self.add_node(proxy_node, target_graph_name)
+
+        # Remap external connections to proxy
+        self._remap_connections_to_proxy(group, proxy_node, connections)
+
+        # Remove grouped nodes from DAG
+        self._remove_grouped_nodes_from_dag(group)
+
+        return proxy_node
+
+    def _remap_connections_to_proxy(self, group: NodeGroup, proxy_node: BaseNode, connections: Connections) -> None:
+        """Remap external connections from group nodes to the proxy node.
+
+        Updates the connection indices and Connection objects to redirect
+        external connections through the proxy node instead of the original
+        grouped nodes.
+
+        Args:
+            group: NodeGroup being replaced
+            proxy_node: Proxy node that will handle external connections
+            connections: Connections object to update
+        """
+        # Remap external incoming connections (from outside -> group becomes outside -> proxy)
+        for conn in group.external_incoming_connections:
+            conn_id = id(conn)
+
+            # Remove old incoming index entry
+            if (
+                conn.target_node.name in connections.incoming_index
+                and conn.target_parameter.name in connections.incoming_index[conn.target_node.name]
+            ):
+                connections.incoming_index[conn.target_node.name][conn.target_parameter.name].remove(conn_id)
+
+            # Update connection target to proxy
+            conn.target_node = proxy_node
+
+            # Add new incoming index entry
+            connections.incoming_index.setdefault(proxy_node.name, {}).setdefault(
+                conn.target_parameter.name, []
+            ).append(conn_id)
+
+        # Remap external outgoing connections (group -> outside becomes proxy -> outside)
+        for conn in group.external_outgoing_connections:
+            conn_id = id(conn)
+
+            # Remove old outgoing index entry
+            if (
+                conn.source_node.name in connections.outgoing_index
+                and conn.source_parameter.name in connections.outgoing_index[conn.source_node.name]
+            ):
+                connections.outgoing_index[conn.source_node.name][conn.source_parameter.name].remove(conn_id)
+
+            # Update connection source to proxy
+            conn.source_node = proxy_node
+
+            # Add new outgoing index entry
+            connections.outgoing_index.setdefault(proxy_node.name, {}).setdefault(
+                conn.source_parameter.name, []
+            ).append(conn_id)
+
+    def _remove_grouped_nodes_from_dag(self, group: NodeGroup) -> None:
+        """Remove all nodes in a group from the DAG graphs and references.
+
+        Args:
+            group: NodeGroup whose nodes should be removed from the DAG
+        """
+        for node_name in group.nodes:
+            # Remove from node_to_reference
+            if node_name in self.node_to_reference:
+                del self.node_to_reference[node_name]
+
+            # Remove from all graphs
+            for graph in self.graphs.values():
+                if node_name in graph.nodes():
+                    graph.remove_node(node_name)
+
+            # Remove from graph_to_nodes tracking
+            for node_set in self.graph_to_nodes.values():
+                node_set.discard(node_name)
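
Taken together, the new DagBuilder helpers amount to graph surgery: add a proxy vertex, reroute the group's boundary edges onto it, then delete the grouped vertices. A rough, self-contained illustration with a bare networkx DiGraph follows; the graph.nodes() / remove_node calls in the diff match the networkx API, but that this package uses networkx here is an assumption.

# Rough end-to-end illustration of the DAG surgery the new DagBuilder helpers
# perform, using a bare networkx DiGraph. (That the package uses networkx is
# an assumption; only the call shapes in the diff suggest it.)
import networkx as nx

g = nx.DiGraph()
g.add_edges_from([("source", "a"), ("a", "b"), ("b", "sink")])
group = {"a", "b"}  # nodes sharing a job_group value
proxy = "__group_proxy_g1"

g.add_node(proxy)
# Reroute boundary edges onto the proxy, mirroring _remap_connections_to_proxy;
# the internal a -> b edge simply disappears with its endpoints.
for u, v in list(g.edges()):
    if u in group and v not in group:
        g.add_edge(proxy, v)
    elif u not in group and v in group:
        g.add_edge(u, proxy)
# Drop the grouped nodes, mirroring _remove_grouped_nodes_from_dag.
for name in group:
    if name in g.nodes():
        g.remove_node(name)

assert list(nx.topological_sort(g)) == ["source", "__group_proxy_g1", "sink"]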