lionagi 0.13.3__py3-none-any.whl → 0.13.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,664 @@
1
+ # Copyright (c) 2023 - 2025, HaiyangLi <quantocean.li at gmail dot com>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """
6
+ OperationGraphBuilder: Incremental graph builder for multi-stage operations.
7
+
8
+ Build → Execute → Expand → Execute → ...
9
+ """
10
+
11
+ from enum import Enum
12
+ from typing import Any
13
+
14
+ from lionagi.operations.node import BranchOperations, Operation
15
+ from lionagi.protocols.graph.edge import Edge
16
+ from lionagi.protocols.graph.graph import Graph
17
+
18
# Public API of this module: the builder and its expansion-strategy enum.
__all__ = ("OperationGraphBuilder", "ExpansionStrategy")
22
+
23
+
24
class ExpansionStrategy(Enum):
    """Strategy tags used when expanding an operation into several new ones.

    Each member's value is the lowercase string that the graph builder records
    on expanded nodes and on their incoming expansion edges.
    """

    # Single-mode strategies.
    CONCURRENT = "concurrent"
    SEQUENTIAL = "sequential"

    # Chunked variants (names suggest mixing both modes across chunks).
    SEQUENTIAL_CONCURRENT_CHUNK = "sequential_concurrent_chunk"
    CONCURRENT_SEQUENTIAL_CHUNK = "concurrent_sequential_chunk"
31
+
32
+
33
class OperationGraphBuilder:
    """
    Incremental graph builder that supports build → execute → expand cycles.

    Unlike static builders, this maintains state and allows expanding the graph
    based on execution results.

    State kept between calls:
        - ``graph``: the ``Graph`` that every node and edge is added to.
        - ``_operations``: every ``Operation`` created here, keyed by node id.
        - ``_executed``: ids recorded via :meth:`mark_executed`.
        - ``_current_heads``: ids the next auto-linked node is attached to.
        - ``last_operation_id``: id of the node most recently created by
          :meth:`add_operation` or :meth:`add_aggregation`.

    Examples:
        >>> # Build initial graph
        >>> builder = OperationGraphBuilder()
        >>> builder.add_operation("operate", instruction="Generate ideas", num_instruct=5)
        >>> graph = builder.get_graph()
        >>>
        >>> # Execute with session
        >>> result = await session.flow(graph)
        >>>
        >>> # Expand based on results
        >>> if hasattr(result, 'instruct_models'):
        ...     builder.expand_from_result(
        ...         result.instruct_models,
        ...         source_node_id=builder.last_operation_id,
        ...         operation="instruct"
        ...     )
        >>>
        >>> # Get expanded graph and continue execution
        >>> graph = builder.get_graph()
        >>> final_result = await session.flow(graph)
    """

    def __init__(self, name: str = "DynamicGraph"):
        """Initialize an empty builder.

        Args:
            name: Display name, reported by :meth:`visualize_state`.
        """
        self.name = name
        self.graph = Graph()

        # Track state
        self._operations: dict[str, Operation] = {}  # All operations by ID
        self._executed: set[str] = set()  # IDs of executed operations
        self._current_heads: list[str] = []  # Current head nodes for linking
        self.last_operation_id: str | None = None

    def _resolve_node_id(self, ref: str) -> str:
        """Return the node id for ``ref`` (a node id or a reference id), or raise ValueError."""
        if ref in self._operations:
            return ref
        op = self.get_node_by_reference(ref)
        if op is None:
            raise ValueError(f"Dependency {ref} not found")
        return op.id

    def add_operation(
        self,
        operation: BranchOperations,
        node_id: str | None = None,
        depends_on: list[str] | None = None,
        **parameters,
    ) -> str:
        """
        Add an operation to the graph.

        If ``depends_on`` is non-empty, the new node gets a ``depends_on`` edge
        from each listed node. Otherwise it gets a ``sequential`` edge from every
        current head. Either way the new node becomes the only current head and
        the ``last_operation_id``.

        Args:
            operation: The branch operation
            node_id: Optional reference ID, stored as ``metadata["reference_id"]``
                and retrievable via :meth:`get_node_by_reference`.
            depends_on: Nodes this one depends on. Each entry may be a node ID
                returned by this builder or a reference ID given as ``node_id``.
            **parameters: Operation parameters

        Returns:
            ID of the created node

        Raises:
            ValueError: If an entry of ``depends_on`` matches no known node.
        """
        # Resolve dependencies before touching the graph so that a bad id
        # leaves the builder unchanged (previously unknown ids were silently
        # dropped, producing an unlinked root node).
        dep_ids = [self._resolve_node_id(dep) for dep in depends_on or ()]

        node = Operation(operation=operation, parameters=parameters)
        self.graph.add_node(node)
        self._operations[node.id] = node

        if node_id:
            # Add as metadata for easy lookup
            node.metadata["reference_id"] = node_id

        if depends_on:
            for dep_id in dep_ids:
                self.graph.add_edge(
                    Edge(head=dep_id, tail=node.id, label=["depends_on"])
                )
        else:
            # Auto-link from current heads
            for head_id in self._current_heads:
                self.graph.add_edge(
                    Edge(head=head_id, tail=node.id, label=["sequential"])
                )

        self._current_heads = [node.id]
        self.last_operation_id = node.id
        return node.id

    def expand_from_result(
        self,
        items: list[Any],
        source_node_id: str,
        operation: BranchOperations,
        strategy: ExpansionStrategy = ExpansionStrategy.CONCURRENT,
        **shared_params,
    ) -> list[str]:
        """
        Expand the graph based on execution results.

        One new ``operation`` node is created per item, and each gets an
        ``["expansion", strategy.value]`` edge directly from ``source_node_id``.
        ``strategy`` is recorded on the nodes and edges but does not change
        this wiring. For ``CONCURRENT`` and ``SEQUENTIAL`` the new nodes become
        the current heads. The chunked strategies leave the heads unchanged.
        ``last_operation_id`` is not updated.

        Args:
            items: Items from result to expand (e.g., instruct_models). Objects
                with ``model_dump()`` contribute their dumped fields as
                parameters. Any other item is passed as ``item=str(item)`` along
                with its ``item_index``.
            source_node_id: ID of node that produced these items
            operation: Operation to apply to each item
            strategy: How to organize the expanded operations
            **shared_params: Parameters for every new operation (these override
                same-named fields dumped from an item).

        Returns:
            List of new node IDs (empty if ``items`` is empty).

        Raises:
            ValueError: If ``source_node_id`` is not a node of this builder.
        """
        if source_node_id not in self._operations:
            raise ValueError(f"Source node {source_node_id} not found")

        new_node_ids: list[str] = []

        for i, item in enumerate(items):
            if hasattr(item, "model_dump"):
                params = {**item.model_dump(), **shared_params}
            else:
                params = {**shared_params, "item_index": i, "item": str(item)}

            # NOTE(review): these bookkeeping keys live in the operation's
            # parameters as well as its metadata; confirm the executor
            # tolerates them being passed along.
            params["expanded_from"] = source_node_id
            params["expansion_strategy"] = strategy.value

            node = Operation(
                operation=operation,
                parameters=params,
                metadata={
                    "expansion_index": i,
                    "expansion_source": source_node_id,
                    "expansion_strategy": strategy.value,
                },
            )
            self.graph.add_node(node)
            self._operations[node.id] = node
            new_node_ids.append(node.id)

            self.graph.add_edge(
                Edge(
                    head=source_node_id,
                    tail=node.id,
                    label=["expansion", strategy.value],
                )
            )

        # An empty expansion must not wipe the heads: that would detach later
        # add_operation calls and make add_aggregation() fail.
        if new_node_ids and strategy in (
            ExpansionStrategy.CONCURRENT,
            ExpansionStrategy.SEQUENTIAL,
        ):
            # Copy so callers mutating the returned list can't alter our state.
            self._current_heads = list(new_node_ids)

        return new_node_ids

    def add_aggregation(
        self,
        operation: BranchOperations,
        source_node_ids: list[str] | None = None,
        **parameters,
    ) -> str:
        """
        Add an aggregation operation that collects from multiple sources.

        The node gets an ``aggregate`` edge from every source. Its parameters
        include ``aggregation_sources`` and ``aggregation_count``, and its
        metadata has ``aggregation=True``. It becomes the only current head and
        the ``last_operation_id``.

        Args:
            operation: Aggregation operation
            source_node_ids: Nodes to aggregate from (defaults to current heads)
            **parameters: Operation parameters (may override the two
                aggregation keys above).

        Returns:
            ID of aggregation node

        Raises:
            ValueError: If there are no source nodes.
        """
        # Copy so the stored parameter is not an alias of builder state or of
        # the caller's list.
        sources = list(source_node_ids or self._current_heads)
        if not sources:
            raise ValueError("No source nodes for aggregation")

        agg_params = {
            "aggregation_sources": sources,
            "aggregation_count": len(sources),
            **parameters,
        }

        node = Operation(
            operation=operation,
            parameters=agg_params,
            metadata={"aggregation": True},
        )
        self.graph.add_node(node)
        self._operations[node.id] = node

        for source_id in sources:
            self.graph.add_edge(
                Edge(head=source_id, tail=node.id, label=["aggregate"])
            )

        self._current_heads = [node.id]
        self.last_operation_id = node.id
        return node.id

    def mark_executed(self, node_ids: list[str]):
        """
        Mark nodes as executed.

        This helps track which parts of the graph have been run.

        Args:
            node_ids: IDs of executed nodes
        """
        self._executed.update(node_ids)

    def get_unexecuted_nodes(self) -> list[Operation]:
        """
        Get nodes that haven't been executed yet.

        Returns:
            Operations not recorded via :meth:`mark_executed`, in creation order.
        """
        return [
            op
            for op_id, op in self._operations.items()
            if op_id not in self._executed
        ]

    def add_conditional_branch(
        self,
        condition_check_op: BranchOperations,
        true_op: BranchOperations,
        false_op: BranchOperations | None = None,
        **check_params,
    ) -> dict[str, str]:
        """
        Add a conditional branch structure.

        A check node is linked from every current head. A ``true`` node (and,
        optionally, a ``false`` node) is linked from the check node with an
        ``if_true`` or ``if_false`` edge. The branch nodes become the current
        heads. ``last_operation_id`` is not updated.

        Args:
            condition_check_op: Operation that evaluates condition
            true_op: Operation if condition is true
            false_op: Operation if condition is false
            **check_params: Parameters for condition check

        Returns:
            Dict with node IDs: {'check': id, 'true': id} plus 'false': id when
            ``false_op`` is given.
        """
        # The flag goes into parameters (as before) and also into metadata,
        # which is where the visualizer looks for node types.
        check_node = Operation(
            operation=condition_check_op,
            parameters={**check_params, "is_condition_check": True},
            metadata={"is_condition_check": True},
        )
        self.graph.add_node(check_node)
        self._operations[check_node.id] = check_node

        for head_id in self._current_heads:
            self.graph.add_edge(
                Edge(head=head_id, tail=check_node.id, label=["to_condition"])
            )

        result = {"check": check_node.id}

        true_node = Operation(operation=true_op, parameters={"branch": "true"})
        self.graph.add_node(true_node)
        self._operations[true_node.id] = true_node
        result["true"] = true_node.id
        self.graph.add_edge(
            Edge(head=check_node.id, tail=true_node.id, label=["if_true"])
        )

        if false_op:
            false_node = Operation(
                operation=false_op, parameters={"branch": "false"}
            )
            self.graph.add_node(false_node)
            self._operations[false_node.id] = false_node
            result["false"] = false_node.id
            self.graph.add_edge(
                Edge(head=check_node.id, tail=false_node.id, label=["if_false"])
            )
            self._current_heads = [true_node.id, false_node.id]
        else:
            self._current_heads = [true_node.id]

        return result

    def get_graph(self) -> Graph:
        """
        Get the current graph for execution.

        Returns:
            The live graph object (not a copy), in its current state
        """
        return self.graph

    def get_node_by_reference(self, reference_id: str) -> Operation | None:
        """
        Get a node by its reference ID.

        Args:
            reference_id: The reference ID assigned when creating the node

        Returns:
            The first operation whose ``metadata["reference_id"]`` matches, or None
        """
        for op in self._operations.values():
            if op.metadata.get("reference_id") == reference_id:
                return op
        return None

    def visualize_state(self) -> dict[str, Any]:
        """
        Get visualization of current graph state.

        Returns:
            Dict with keys ``name``, ``total_nodes``, ``executed_nodes``,
            ``unexecuted_nodes``, ``current_heads`` (a copy), ``expansions``
            (expansion source id -> ids of nodes expanded from it) and ``edges``
            (edge count).
        """
        expansions: dict[Any, list] = {}
        for op in self._operations.values():
            source = op.metadata.get("expansion_source")
            if source:
                expansions.setdefault(source, []).append(op.id)

        # Count only this builder's nodes: mark_executed accepts arbitrary ids,
        # and len(self._executed) could otherwise exceed the node count.
        total = len(self._operations)
        executed = sum(1 for op_id in self._operations if op_id in self._executed)

        return {
            "name": self.name,
            "total_nodes": total,
            "executed_nodes": executed,
            "unexecuted_nodes": total - executed,
            "current_heads": list(self._current_heads),
            "expansions": expansions,
            "edges": len(self.graph.internal_edges),
        }

    def visualize(self, title: str = "Operation Graph", figsize=(14, 10)):
        """Draw this builder's graph via the module-level ``visualize_graph``.

        Args:
            title: Figure title.
            figsize: Matplotlib figure size ``(width, height)``.
        """
        visualize_graph(
            self,
            title=title,
            figsize=figsize,
        )
+
392
+
393
def visualize_graph(
    builder: OperationGraphBuilder,
    title: str = "Operation Graph",
    figsize=(14, 10),
):
    """Draw ``builder``'s graph with a layered layout and display it.

    Each node's level is the length of the longest dependency path from a root.
    Levels with more than six nodes wrap onto several rows. Node colour and size
    encode status/type, and edge style encodes the edge's first label. Requires
    ``matplotlib`` and ``networkx``. The figure is shown with ``plt.show()``.

    Args:
        builder: Builder whose current graph is drawn.
        title: Figure title.
        figsize: Matplotlib figure size ``(width, height)``.
    """
    import math

    import matplotlib.pyplot as plt
    import networkx as nx
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch

    graph = builder.get_graph()
    G = nx.DiGraph()

    # Drawing attributes are keyed by the short node id so they can be looked
    # up in whatever order networkx later reports nodes.
    node_labels: dict[str, str] = {}
    node_color: dict[str, str] = {}
    node_size: dict[str, int] = {}

    for node in graph.internal_nodes.values():
        node_id = str(node.id)[:8]
        G.add_node(node_id)

        ref_id = node.metadata.get("reference_id", "")
        if ref_id:
            node_labels[node_id] = f"{node.operation}\n[{ref_id}]"
        else:
            node_labels[node_id] = f"{node.operation}\n{node_id}"

        # The condition flag may be in parameters only (older builders), so
        # check both places.
        params = getattr(node, "parameters", None)
        is_condition = node.metadata.get("is_condition_check") or (
            isinstance(params, dict) and params.get("is_condition_check")
        )

        if node.id in builder._executed:
            color, size = "#90EE90", 4000  # Light green
        elif node.metadata.get("expansion_source"):
            color, size = "#87CEEB", 3500  # Sky blue
        elif node.metadata.get("aggregation"):
            color, size = "#FFD700", 4500  # Gold
        elif is_condition:
            color, size = "#DDA0DD", 3500  # Plum
        else:
            color, size = "#E0E0E0", 3000  # Light gray
        node_color[node_id] = color
        node_size[node_id] = size

    # Edge styles are keyed by (head, tail). G.edges() does not report edges
    # in insertion order, so index-based lists would mis-assign styles.
    edge_labels: dict[tuple[str, str], str] = {}
    edge_style: dict[tuple[str, str], tuple[str, str, float]] = {}

    for edge in graph.internal_edges.values():
        head_id = str(edge.head)[:8]
        tail_id = str(edge.tail)[:8]
        if head_id not in G or tail_id not in G:
            # Endpoint is not a graph node: there is no position or colour to
            # draw it with.
            continue
        G.add_edge(head_id, tail_id)

        edge_label = edge.label[0] if edge.label else ""
        edge_labels[(head_id, tail_id)] = edge_label

        if "expansion" in edge_label:
            edge_style[(head_id, tail_id)] = ("#4169E1", "dashed", 2)  # Royal blue
        elif "aggregate" in edge_label:
            edge_style[(head_id, tail_id)] = ("#FF6347", "dotted", 2.5)  # Tomato
        else:
            edge_style[(head_id, tail_id)] = ("#808080", "solid", 1.5)  # Gray

    # Assign levels in topological order so every predecessor is placed
    # before its successors, whatever order the graph stores its nodes in.
    try:
        visit_order = list(nx.topological_sort(G))
    except nx.NetworkXUnfeasible:
        # Cyclic graph: fall back to insertion order (best effort).
        visit_order = list(G.nodes())

    node_levels: dict[str, int] = {}
    for node_id in visit_order:
        preds = list(G.predecessors(node_id))
        if not preds:
            node_levels[node_id] = 0  # Root nodes
        else:
            node_levels[node_id] = (
                max((node_levels[p] for p in preds if p in node_levels), default=0)
                + 1
            )

    # Group by level, keeping insertion order within each level.
    nodes_by_level: dict[int, list[str]] = {}
    for node_id in G.nodes():
        nodes_by_level.setdefault(node_levels[node_id], []).append(node_id)

    x_spacing = 2.5
    y_spacing = 2.5
    max_per_row = 6
    row_drop = 0.8  # Extra downward offset per wrapped row within a level

    pos: dict[str, tuple[float, float]] = {}
    for level, nodes in nodes_by_level.items():
        num_nodes = len(nodes)
        if num_nodes <= max_per_row:
            per_row = num_nodes
        else:
            per_row = min(max_per_row, math.ceil(math.sqrt(num_nodes * 1.5)))

        for i, node_id in enumerate(nodes):
            row, col = divmod(i, per_row)
            nodes_in_row = min(per_row, num_nodes - row * per_row)
            x_offset = -(nodes_in_row - 1) * x_spacing / 2
            pos[node_id] = (
                x_offset + col * x_spacing,
                -level * y_spacing - row * row_drop,
            )

    plt.figure(figsize=figsize)

    nodelist = list(G.nodes())
    nx.draw_networkx_nodes(
        G,
        pos,
        nodelist=nodelist,
        node_color=[node_color[n] for n in nodelist],
        node_size=[node_size[n] for n in nodelist],
        alpha=0.9,
        linewidths=2,
        edgecolors="black",
    )

    # Draw each edge on its own so it gets its style; edges between nodes that
    # are far apart horizontally are curved more.
    for u, v in G.edges():
        color, style, width = edge_style[(u, v)]
        if abs(pos[u][0] - pos[v][0]) > 5:
            connectionstyle = "arc3,rad=0.2"
        else:
            connectionstyle = "arc3,rad=0.1"

        nx.draw_networkx_edges(
            G,
            pos,
            edgelist=[(u, v)],
            edge_color=[color],
            style=style,
            width=width,
            alpha=0.7,
            arrows=True,
            arrowsize=20,
            arrowstyle="-|>",
            connectionstyle=connectionstyle,
        )

    nx.draw_networkx_labels(
        G,
        pos,
        labels=node_labels,
        font_size=9,
        font_weight="bold",
        font_family="monospace",
    )

    # Edge labels clutter larger graphs, so draw them only for small ones.
    if len(G.edges()) < 20:
        nx.draw_networkx_edge_labels(
            G,
            pos,
            edge_labels=edge_labels,
            font_size=7,
            font_color="darkblue",
            bbox=dict(
                boxstyle="round,pad=0.3",
                facecolor="white",
                edgecolor="none",
                alpha=0.7,
            ),
        )

    plt.title(title, fontsize=18, fontweight="bold", pad=20)
    plt.axis("off")

    legend_elements = [
        Patch(facecolor="#90EE90", edgecolor="black", label="Executed"),
        Patch(facecolor="#87CEEB", edgecolor="black", label="Expanded"),
        Patch(facecolor="#FFD700", edgecolor="black", label="Aggregation"),
        Patch(facecolor="#DDA0DD", edgecolor="black", label="Condition"),
        Patch(facecolor="#E0E0E0", edgecolor="black", label="Pending"),
        Line2D([0], [0], color="#808080", linewidth=2, label="Sequential"),
        Line2D(
            [0],
            [0],
            color="#4169E1",
            linewidth=2,
            linestyle="dashed",
            label="Expansion",
        ),
        Line2D(
            [0],
            [0],
            color="#FF6347",
            linewidth=2,
            linestyle="dotted",
            label="Aggregate",
        ),
    ]

    plt.legend(
        handles=legend_elements,
        loc="upper left",
        bbox_to_anchor=(0, 1),
        frameon=True,
        fancybox=True,
        shadow=True,
        ncol=2,
    )

    stats_text = f"Nodes: {len(G.nodes())}\nEdges: {len(G.edges())}\nExecuted: {len(builder._executed)}"
    if nodes_by_level:
        max_level = max(nodes_by_level.keys())
        stats_text += f"\nLevels: {max_level + 1}"

    plt.text(
        0.98,
        0.02,
        stats_text,
        transform=plt.gca().transAxes,
        bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.8),
        verticalalignment="bottom",
        horizontalalignment="right",
        fontsize=10,
        fontfamily="monospace",
    )

    plt.tight_layout()
    plt.show()