griptape-nodes 0.58.1__py3-none-any.whl → 0.59.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. griptape_nodes/bootstrap/utils/python_subprocess_executor.py +2 -2
  2. griptape_nodes/bootstrap/workflow_executors/local_session_workflow_executor.py +0 -5
  3. griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +9 -5
  4. griptape_nodes/bootstrap/workflow_executors/subprocess_workflow_executor.py +0 -1
  5. griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +1 -3
  6. griptape_nodes/bootstrap/workflow_publishers/local_workflow_publisher.py +1 -1
  7. griptape_nodes/cli/commands/init.py +53 -7
  8. griptape_nodes/cli/shared.py +1 -0
  9. griptape_nodes/common/node_executor.py +216 -40
  10. griptape_nodes/exe_types/core_types.py +46 -0
  11. griptape_nodes/exe_types/node_types.py +272 -0
  12. griptape_nodes/machines/control_flow.py +222 -16
  13. griptape_nodes/machines/dag_builder.py +212 -1
  14. griptape_nodes/machines/parallel_resolution.py +237 -4
  15. griptape_nodes/node_library/workflow_registry.py +1 -1
  16. griptape_nodes/retained_mode/events/execution_events.py +5 -4
  17. griptape_nodes/retained_mode/events/flow_events.py +17 -67
  18. griptape_nodes/retained_mode/events/parameter_events.py +122 -1
  19. griptape_nodes/retained_mode/managers/event_manager.py +17 -13
  20. griptape_nodes/retained_mode/managers/flow_manager.py +316 -573
  21. griptape_nodes/retained_mode/managers/library_manager.py +32 -20
  22. griptape_nodes/retained_mode/managers/model_manager.py +19 -8
  23. griptape_nodes/retained_mode/managers/node_manager.py +463 -3
  24. griptape_nodes/retained_mode/managers/object_manager.py +2 -2
  25. griptape_nodes/retained_mode/managers/workflow_manager.py +37 -46
  26. griptape_nodes/retained_mode/retained_mode.py +297 -3
  27. {griptape_nodes-0.58.1.dist-info → griptape_nodes-0.59.0.dist-info}/METADATA +3 -2
  28. {griptape_nodes-0.58.1.dist-info → griptape_nodes-0.59.0.dist-info}/RECORD +30 -30
  29. {griptape_nodes-0.58.1.dist-info → griptape_nodes-0.59.0.dist-info}/WHEEL +1 -1
  30. {griptape_nodes-0.58.1.dist-info → griptape_nodes-0.59.0.dist-info}/entry_points.txt +0 -0
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
8
8
  from griptape_nodes.exe_types.core_types import Parameter, ParameterTypeBuiltin
9
9
  from griptape_nodes.exe_types.node_types import CONTROL_INPUT_PARAMETER, LOCAL_EXECUTION, BaseNode, NodeResolutionState
10
10
  from griptape_nodes.machines.fsm import FSM, State
11
- from griptape_nodes.machines.parallel_resolution import ParallelResolutionMachine
11
+ from griptape_nodes.machines.parallel_resolution import ExecuteDagState, ParallelResolutionMachine
12
12
  from griptape_nodes.machines.sequential_resolution import SequentialResolutionMachine
13
13
  from griptape_nodes.retained_mode.events.base_events import ExecutionEvent, ExecutionGriptapeNodeEvent
14
14
  from griptape_nodes.retained_mode.events.execution_events import (
@@ -20,6 +20,11 @@ from griptape_nodes.retained_mode.events.execution_events import (
20
20
  from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
21
21
  from griptape_nodes.retained_mode.managers.settings import WorkflowExecutionMode
22
22
 
23
+ if TYPE_CHECKING:
24
+ from griptape_nodes.exe_types.connections import Connections
25
+ from griptape_nodes.exe_types.flow import ControlFlow
26
+ from griptape_nodes.exe_types.node_types import NodeGroup
27
+
23
28
 
24
29
  @dataclass
25
30
  class NextNodeInfo:
@@ -45,6 +50,8 @@ class ControlFlowContext:
45
50
  paused: bool = False
46
51
  flow_name: str
47
52
  pickle_control_flow_result: bool
53
+ node_to_proxy_map: dict[BaseNode, BaseNode]
54
+ end_node: BaseNode | None = None
48
55
 
49
56
  def __init__(
50
57
  self,
@@ -67,6 +74,7 @@ class ControlFlowContext:
67
74
  self.resolution_machine = SequentialResolutionMachine()
68
75
  self.current_nodes = []
69
76
  self.pickle_control_flow_result = pickle_control_flow_result
77
+ self.node_to_proxy_map = {}
70
78
 
71
79
  def get_next_nodes(self, output_parameter: Parameter | None = None) -> list[NextNodeInfo]:
72
80
  """Get all next nodes from the current nodes.
@@ -171,7 +179,6 @@ class ResolveNodeState(State):
171
179
  # Resolve nodes - pass first node for sequential resolution
172
180
  current_node = context.current_nodes[0] if context.current_nodes else None
173
181
  await context.resolution_machine.resolve_node(current_node)
174
-
175
182
  if context.resolution_machine.is_complete():
176
183
  # Get the last resolved node from the DAG and set it as current
177
184
  if isinstance(context.resolution_machine, ParallelResolutionMachine):
@@ -179,6 +186,8 @@ class ResolveNodeState(State):
179
186
  if last_resolved_node:
180
187
  context.current_nodes = [last_resolved_node]
181
188
  return CompleteState
189
+ if context.end_node == current_node:
190
+ return CompleteState
182
191
  return NextNodeState
183
192
  return None
184
193
 
@@ -261,6 +270,7 @@ class CompleteState(State):
261
270
  )
262
271
  )
263
272
  )
273
+ context.end_node = None
264
274
  logger.info("Flow is complete.")
265
275
  return None
266
276
 
@@ -284,14 +294,40 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
284
294
  )
285
295
  super().__init__(context)
286
296
 
287
- async def start_flow(self, start_node: BaseNode, debug_mode: bool = False) -> None: # noqa: FBT001, FBT002
297
+ async def start_flow(
298
+ self, start_node: BaseNode, end_node: BaseNode | None = None, *, debug_mode: bool = False
299
+ ) -> None:
300
+ # FIRST: Scan all nodes in the flow and create node groups BEFORE any resolution
301
+ flow_manager = GriptapeNodes.FlowManager()
302
+ flow = flow_manager.get_flow_by_name(self._context.flow_name)
303
+ logger.debug("Scanning flow '%s' for node groups before execution", self._context.flow_name)
304
+
305
+ try:
306
+ node_to_proxy_map = self._identify_and_create_node_group_proxies(flow, flow_manager.get_connections())
307
+ if node_to_proxy_map:
308
+ logger.info(
309
+ "Created %d proxy nodes for %d grouped nodes in flow '%s'",
310
+ len(set(node_to_proxy_map.values())),
311
+ len(node_to_proxy_map),
312
+ self._context.flow_name,
313
+ )
314
+ # Store the mapping in context so it can be used by resolution machines
315
+ self._context.node_to_proxy_map = node_to_proxy_map
316
+ except ValueError as e:
317
+ logger.error("Failed to process node groups: %s", e)
318
+ raise
319
+
320
+ # Determine the actual start node (use proxy if it's part of a group)
321
+ actual_start_node = node_to_proxy_map.get(start_node, start_node)
322
+
288
323
  # If using DAG resolution, process data_nodes from queue first
289
324
  if isinstance(self._context.resolution_machine, ParallelResolutionMachine):
290
- current_nodes = await self._process_nodes_for_dag(start_node)
325
+ current_nodes = await self._process_nodes_for_dag(actual_start_node)
291
326
  else:
292
- current_nodes = [start_node]
327
+ current_nodes = [actual_start_node]
293
328
  # For control flow/sequential: emit all nodes in flow as involved
294
329
  self._context.current_nodes = current_nodes
330
+ self._context.end_node = end_node
295
331
  # Set entry control parameter for initial node (None for workflow start)
296
332
  for node in current_nodes:
297
333
  node.set_entry_control_parameter(None)
@@ -299,12 +335,14 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
299
335
  self._context.paused = debug_mode
300
336
  flow_manager = GriptapeNodes.FlowManager()
301
337
  flow = flow_manager.get_flow_by_name(self._context.flow_name)
302
- involved_nodes = list(flow.nodes.keys())
303
- GriptapeNodes.EventManager().put_event(
304
- ExecutionGriptapeNodeEvent(
305
- wrapped_event=ExecutionEvent(payload=InvolvedNodesEvent(involved_nodes=involved_nodes))
338
+ if start_node != end_node:
339
+ # This blocks all nodes in the entire flow from running. If we're just resolving one node, we don't want to block that.
340
+ involved_nodes = list(flow.nodes.keys())
341
+ GriptapeNodes.EventManager().put_event(
342
+ ExecutionGriptapeNodeEvent(
343
+ wrapped_event=ExecutionEvent(payload=InvolvedNodesEvent(involved_nodes=involved_nodes))
344
+ )
306
345
  )
307
- )
308
346
  await self.start(ResolveNodeState) # Begins the flow
309
347
 
310
348
  async def update(self) -> None:
@@ -347,7 +385,7 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
347
385
  ):
348
386
  await self.update()
349
387
 
350
- async def _process_nodes_for_dag(self, start_node: BaseNode) -> list[BaseNode]:
388
+ async def _process_nodes_for_dag(self, start_node: BaseNode) -> list[BaseNode]: # noqa: C901
351
389
  """Process data_nodes from the global queue to build unified DAG.
352
390
 
353
391
  This method identifies data_nodes in the execution queue and processes
@@ -361,7 +399,11 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
361
399
  if dag_builder is None:
362
400
  msg = "DAG builder is not initialized."
363
401
  raise ValueError(msg)
364
- # Build with the first node:
402
+
403
+ # Use the node-to-proxy map that was created in start_flow
404
+ node_to_proxy_map = self._context.node_to_proxy_map
405
+
406
+ # Build with the first node (it should already be the proxy if it's part of a group)
365
407
  dag_builder.add_node_with_dependencies(start_node, start_node.name)
366
408
  queue_items = list(flow_manager.global_flow_queue.queue)
367
409
  start_nodes = [start_node]
@@ -372,17 +414,165 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
372
414
  if item.dag_execution_type in (DagExecutionType.CONTROL_NODE, DagExecutionType.START_NODE):
373
415
  node = item.node
374
416
  node.state = NodeResolutionState.UNRESOLVED
375
- dag_builder.add_node_with_dependencies(node, node.name)
417
+ # Use proxy node if this node is part of a group, otherwise use original node
418
+ if node in node_to_proxy_map:
419
+ node_to_add = node_to_proxy_map[node]
420
+ else:
421
+ node_to_add = node
422
+ # Only add if not already added (proxy might already be in DAG)
423
+ if node_to_add.name not in dag_builder.node_to_reference:
424
+ dag_builder.add_node_with_dependencies(node_to_add, node_to_add.name)
425
+ if node_to_add not in start_nodes:
426
+ start_nodes.append(node_to_add)
376
427
  flow_manager.global_flow_queue.queue.remove(item)
377
- start_nodes.append(node)
378
428
  elif item.dag_execution_type == DagExecutionType.DATA_NODE:
379
429
  node = item.node
380
430
  node.state = NodeResolutionState.UNRESOLVED
381
- # Build here.
382
- dag_builder.add_node_with_dependencies(node, node.name)
431
+ # Use proxy node if this node is part of a group, otherwise use original node
432
+ if node in node_to_proxy_map:
433
+ node_to_add = node_to_proxy_map[node]
434
+ else:
435
+ node_to_add = node
436
+ # Only add if not already added (proxy might already be in DAG)
437
+ if node_to_add.name not in dag_builder.node_to_reference:
438
+ dag_builder.add_node_with_dependencies(node_to_add, node_to_add.name)
383
439
  flow_manager.global_flow_queue.queue.remove(item)
440
+
384
441
  return start_nodes
385
442
 
443
def _identify_and_create_node_group_proxies(
    self, flow: ControlFlow, connections: Connections
) -> dict[BaseNode, BaseNode]:
    """Scan all nodes in *flow*, build node groups, and install one proxy node per group.

    Nodes are grouped by the value of their ``job_group`` parameter. Nodes whose
    value is falsy (``None`` or ``""``) and locked nodes are left ungrouped.

    For the groups found, this method:
      1. categorizes the flow's connections relative to each group,
      2. validates every group with ``NodeGroup.validate_no_intermediate_nodes``,
      3. creates a ``NodeGroupProxyNode`` named ``__group_proxy_<group_id>`` per group,
         registers it with the ObjectManager, remaps the group's external
         connections onto it, and only then builds the proxy's parameters.

    Every group is validated before any proxy is created or any connection is
    remapped.

    Args:
        flow: The flow whose nodes are scanned.
        connections: The flow manager's Connections object. Its Connection objects
            and its incoming/outgoing indices are mutated in place by the remap step.

    Returns:
        Mapping from every grouped original node to its group's proxy node.
        Empty when the flow contains no groups.

    Raises:
        ValueError: Expected from ``validate_no_intermediate_nodes`` when a group is
            invalid (the ``start_flow`` caller catches ValueError). The exact
            conditions live in ``NodeGroup``.
    """
    from griptape_nodes.exe_types.node_types import NodeGroup, NodeGroupProxyNode

    # Step 1: bucket nodes by group id.
    groups: dict[str, NodeGroup] = {}
    for node in flow.nodes.values():
        group_id = node.get_parameter_value("job_group")

        # A falsy id (None or "") means "not grouped"; locked nodes are never grouped.
        if not group_id or node.lock:
            continue

        if group_id not in groups:
            groups[group_id] = NodeGroup(group_id=group_id)
        groups[group_id].add_node(node)

    if not groups:
        return {}

    # Step 2: categorize connections (internal / incoming / outgoing) for every group.
    for group in groups.values():
        self._analyze_group_connections(group, connections)

    # Step 3: validate every group before anything is mutated.
    for group in groups.values():
        group.validate_no_intermediate_nodes(connections.connections)

    # Step 4: create and register proxies, then redirect external connections to them.
    obj_manager = GriptapeNodes.ObjectManager()
    node_to_proxy_map: dict[BaseNode, BaseNode] = {}
    for group_id, group in groups.items():
        proxy_name = f"__group_proxy_{group_id}"
        proxy_node = NodeGroupProxyNode(name=proxy_name, node_group=group)

        # Registered by name so the proxy can be found later (e.g. during parameter updates).
        obj_manager.add_object_by_name(proxy_name, proxy_node)

        for node in group.nodes.values():
            node_to_proxy_map[node] = proxy_node

        self._remap_connections_to_proxy_node(group, proxy_node, connections)

        # Runs after remapping so the original connection endpoints are already
        # recorded on the group when the proxy parameters are built.
        proxy_node.create_proxy_parameters()

    return node_to_proxy_map
502
+
503
def _analyze_group_connections(self, group: NodeGroup, connections: Connections) -> None:
    """Sort every flow connection into the group's internal/outgoing/incoming lists.

    Membership is decided by node name. A connection is appended to:
      * ``group.internal_connections`` when both endpoints are group members,
      * ``group.external_outgoing_connections`` when only its source is a member,
      * ``group.external_incoming_connections`` when only its target is a member.
    Connections touching no member are ignored. The lists are appended to, not
    reset, so analyzing the same group twice duplicates entries.
    """
    members = group.nodes.keys()

    for connection in connections.connections.values():
        from_member = connection.source_node.name in members
        to_member = connection.target_node.name in members

        if from_member and to_member:
            bucket = group.internal_connections
        elif from_member:
            bucket = group.external_outgoing_connections
        elif to_member:
            bucket = group.external_incoming_connections
        else:
            continue

        bucket.append(connection)
521
+
522
def _remap_connections_to_proxy_node(
    self, group: NodeGroup, proxy_node: BaseNode, connections: Connections
) -> None:
    """Redirect a group's external connections so they attach to *proxy_node*.

    Incoming connections (outside -> member) get ``target_node`` replaced by the
    proxy. Outgoing connections (member -> outside) get ``source_node`` replaced.
    The original member node is recorded on the group (``original_incoming_targets``
    / ``original_outgoing_sources``, keyed by ``id(conn)``) for later cleanup.

    ``connections.incoming_index`` / ``outgoing_index`` are kept in sync. The
    connection id is removed from the member's ``[node name][parameter name]`` list,
    if present, and appended under the proxy's name with a proxy parameter name of
    the form ``"<member name, spaces -> _>__<parameter name>"``.

    The Connection objects are mutated in place. Nothing is returned.
    """

    def _proxy_param_name(member: BaseNode, param_name: str) -> str:
        # NOTE(review): presumably must match the names that
        # NodeGroupProxyNode.create_proxy_parameters() generates — confirm.
        sanitized = member.name.replace(" ", "_")
        return f"{sanitized}__{param_name}"

    def _discard(index: dict[str, dict[str, list[int]]], node_name: str, param_name: str, conn_id: int) -> None:
        # Remove conn_id from index[node_name][param_name] only if it is actually there.
        # Checking the keys alone is not enough: list.remove raises ValueError on a missing id.
        ids = index.get(node_name, {}).get(param_name)
        if ids is not None and conn_id in ids:
            ids.remove(conn_id)

    # outside -> member  becomes  outside -> proxy
    for conn in group.external_incoming_connections:
        conn_id = id(conn)
        member = conn.target_node

        # Save the real target before remapping (used for cleanup later).
        group.original_incoming_targets[conn_id] = member

        _discard(connections.incoming_index, member.name, conn.target_parameter.name, conn_id)
        conn.target_node = proxy_node
        connections.incoming_index.setdefault(proxy_node.name, {}).setdefault(
            _proxy_param_name(member, conn.target_parameter.name), []
        ).append(conn_id)

    # member -> outside  becomes  proxy -> outside
    for conn in group.external_outgoing_connections:
        conn_id = id(conn)
        member = conn.source_node

        # Save the real source before remapping (used for cleanup later).
        group.original_outgoing_sources[conn_id] = member

        _discard(connections.outgoing_index, member.name, conn.source_parameter.name, conn_id)
        conn.source_node = proxy_node
        connections.outgoing_index.setdefault(proxy_node.name, {}).setdefault(
            _proxy_param_name(member, conn.source_parameter.name), []
        ).append(conn_id)
575
+
386
576
  async def cancel_flow(self) -> None:
387
577
  """Cancel all nodes in the flow by delegating to the resolution machine."""
388
578
  await self.resolution_machine.cancel_all_nodes()
@@ -391,6 +581,22 @@ class ControlFlowMachine(FSM[ControlFlowContext]):
391
581
  self._context.reset(cancel=cancel)
392
582
  self._current_state = None
393
583
 
584
def cleanup_proxy_nodes(self) -> None:
    """Tear down every group proxy created for this run and forget the mapping.

    Several grouped nodes share one proxy, so each distinct proxy is passed to
    ``ExecuteDagState._cleanup_proxy_node`` exactly once. Afterwards the
    context's ``node_to_proxy_map`` is emptied in place.

    Calling this when the mapping is already empty does nothing.
    """
    mapping = self._context.node_to_proxy_map
    if mapping:
        distinct_proxies = set(mapping.values())
        for proxy in distinct_proxies:
            ExecuteDagState._cleanup_proxy_node(proxy)
        self._context.node_to_proxy_map.clear()
599
+
394
600
  @property
395
601
  def resolution_machine(self) -> ParallelResolutionMachine | SequentialResolutionMachine:
396
602
  return self._context.resolution_machine
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
13
13
  import asyncio
14
14
 
15
15
  from griptape_nodes.exe_types.connections import Connections
16
- from griptape_nodes.exe_types.node_types import BaseNode
16
+ from griptape_nodes.exe_types.node_types import BaseNode, NodeGroup
17
17
 
18
18
  logger = logging.getLogger("griptape_nodes")
19
19
 
@@ -224,3 +224,214 @@ class DagBuilder:
224
224
  for node_name in self.graph_to_nodes[graph_name]:
225
225
  self.node_to_reference.pop(node_name, None)
226
226
  self.graph_to_nodes.pop(graph_name, None)
227
+
228
def identify_and_process_node_groups(self) -> dict[str, BaseNode]:
    """Identify node groups in the DAG, validate them, and replace each with a proxy node.

    Scans every node in ``node_to_reference`` for a non-empty ``job_group``
    parameter value and builds ``NodeGroup`` instances from them. Every group is
    validated with ``NodeGroup.validate_no_intermediate_nodes`` before any proxy
    is installed. Each group is then replaced by a ``NodeGroupProxyNode``: the
    proxy is added to the DAG, the group's external connections are remapped onto
    it (mutating the flow manager's Connections), and the grouped nodes are
    removed from the DAG.

    Returns:
        Dictionary mapping group IDs to their proxy nodes. Empty when no group exists.

    Raises:
        ValueError: If validation fails (e.g., ungrouped nodes between grouped nodes).
    """
    from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes

    connections = GriptapeNodes.FlowManager().get_connections()

    # Step 1: build groups (connection categorization happens here too).
    groups = self._identify_node_groups(connections)
    if not groups:
        return {}

    # Step 2: validate every group before any DAG or connection mutation.
    for group in groups.values():
        group.validate_no_intermediate_nodes(connections.connections)

    # Step 3: replace each group with a proxy node in the DAG.
    proxy_nodes: dict[str, BaseNode] = {}
    for group_id, group in groups.items():
        proxy_nodes[group_id] = self._create_and_install_proxy_node(group_id, group, connections)

    return proxy_nodes
262
+
263
def _identify_node_groups(self, connections: Connections) -> dict[str, NodeGroup]:
    """Collect the DAG's nodes into NodeGroup instances keyed by ``job_group``.

    Every node referenced from ``node_to_reference`` whose ``job_group`` value is
    truthy joins the group with that id (groups are created on first use). Each
    resulting group then has its connections categorized by
    ``self._analyze_group_connections``.

    NOTE(review): unlike ControlFlowMachine._identify_and_create_node_group_proxies,
    locked nodes are not excluded here — confirm this is intended.

    Args:
        connections: Connections object used to categorize each group's connections.

    Returns:
        Dictionary mapping group IDs to NodeGroup instances (empty if none).
    """
    from griptape_nodes.exe_types.node_types import NodeGroup

    found: dict[str, NodeGroup] = {}

    for reference in self.node_to_reference.values():
        member = reference.node_reference
        group_id = member.get_parameter_value("job_group")
        if not group_id:
            # None or "" -> the node is not part of any group.
            continue

        group = found.get(group_id)
        if group is None:
            group = NodeGroup(group_id=group_id)
            found[group_id] = group
        group.add_node(member)

    for group in found.values():
        self._analyze_group_connections(group, connections)

    return found
297
+
298
def _analyze_group_connections(self, group: NodeGroup, connections: Connections) -> None:
    """Classify each flow connection relative to *group* by endpoint membership.

    Membership is decided by node name. For every connection in
    ``connections.connections``:
      - source and target both in the group -> ``group.internal_connections``
      - only the source in the group       -> ``group.external_outgoing_connections``
      - only the target in the group       -> ``group.external_incoming_connections``
      - neither endpoint in the group      -> ignored

    Args:
        group: NodeGroup whose connection lists are appended to.
        connections: Connections object containing all flow connections.
    """
    members = set(group.nodes)

    # (source is member, target is member) -> list that receives the connection.
    destination = {
        (True, True): group.internal_connections,
        (True, False): group.external_outgoing_connections,
        (False, True): group.external_incoming_connections,
    }

    for conn in connections.connections.values():
        key = (conn.source_node.name in members, conn.target_node.name in members)
        bucket = destination.get(key)
        if bucket is not None:
            bucket.append(conn)
326
+
327
def _create_and_install_proxy_node(self, group_id: str, group: NodeGroup, connections: Connections) -> BaseNode:
    """Replace *group* in the DAG with a freshly built NodeGroupProxyNode.

    The proxy is named ``__group_proxy_<group_id>``. It is added to the first
    graph (in ``graph_to_nodes`` iteration order) that contains any group member,
    or to ``"default"`` if none does. The group's external connections are then
    remapped onto the proxy, and the member nodes are removed from the DAG.

    Args:
        group_id: Unique identifier for the group (used in the proxy's name).
        group: NodeGroup instance to replace.
        connections: Connections object whose external connections are remapped.

    Returns:
        The created NodeGroupProxyNode.
    """
    from griptape_nodes.exe_types.node_types import NodeGroupProxyNode

    proxy = NodeGroupProxyNode(name=f"__group_proxy_{group_id}", node_group=group)

    host_graph = next(
        (
            graph_name
            for graph_name, names_in_graph in self.graph_to_nodes.items()
            if any(member_name in names_in_graph for member_name in group.nodes)
        ),
        None,
    )
    if host_graph is None:
        host_graph = "default"

    self.add_node(proxy, host_graph)
    self._remap_connections_to_proxy(group, proxy, connections)
    self._remove_grouped_nodes_from_dag(group)

    return proxy
368
+
369
def _remap_connections_to_proxy(self, group: NodeGroup, proxy_node: BaseNode, connections: Connections) -> None:
    """Remap external connections from group nodes to the proxy node.

    Incoming connections (outside -> member) get ``target_node`` set to the proxy.
    Outgoing connections (member -> outside) get ``source_node`` set to the proxy.
    In ``connections.incoming_index`` / ``outgoing_index`` the connection id is
    removed from the member's ``[node name][parameter name]`` list, if present,
    and appended under ``[proxy name][parameter name]``.

    NOTE(review): unlike ControlFlowMachine._remap_connections_to_proxy_node, this
    does not record the original endpoints on the group, and it indexes the proxy
    under the member's raw parameter name. Two members with same-named parameters
    therefore share one index list. Confirm this is intended.

    Args:
        group: NodeGroup being replaced.
        proxy_node: Proxy node that will handle external connections.
        connections: Connections object to update (mutated in place).
    """

    def _discard(index: dict[str, dict[str, list[int]]], node_name: str, param_name: str, conn_id: int) -> None:
        # Remove conn_id only if it is actually listed. Checking the keys alone is
        # not enough: list.remove raises ValueError when the id is absent.
        ids = index.get(node_name, {}).get(param_name)
        if ids is not None and conn_id in ids:
            ids.remove(conn_id)

    # outside -> member  becomes  outside -> proxy
    for conn in group.external_incoming_connections:
        conn_id = id(conn)
        _discard(connections.incoming_index, conn.target_node.name, conn.target_parameter.name, conn_id)
        conn.target_node = proxy_node
        connections.incoming_index.setdefault(proxy_node.name, {}).setdefault(
            conn.target_parameter.name, []
        ).append(conn_id)

    # member -> outside  becomes  proxy -> outside
    for conn in group.external_outgoing_connections:
        conn_id = id(conn)
        _discard(connections.outgoing_index, conn.source_node.name, conn.source_parameter.name, conn_id)
        conn.source_node = proxy_node
        connections.outgoing_index.setdefault(proxy_node.name, {}).setdefault(
            conn.source_parameter.name, []
        ).append(conn_id)
418
+
419
def _remove_grouped_nodes_from_dag(self, group: NodeGroup) -> None:
    """Purge every member of *group* from the DAG builder's bookkeeping.

    For each member name, this:
      - drops its ``node_to_reference`` entry,
      - removes it from every graph in ``self.graphs`` whose ``nodes()`` contains it,
      - discards it from every name set in ``graph_to_nodes``.
    Names missing from any of these structures are skipped silently.

    Args:
        group: NodeGroup whose nodes should be removed from the DAG.
    """
    graphs = self.graphs
    tracked_name_sets = self.graph_to_nodes

    for member_name in group.nodes:
        self.node_to_reference.pop(member_name, None)

        for graph in graphs.values():
            if member_name in graph.nodes():
                graph.remove_node(member_name)

        for name_set in tracked_name_sets.values():
            name_set.discard(member_name)