lionagi 0.13.3__py3-none-any.whl → 0.13.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/__init__.py +8 -0
- lionagi/_types.py +6 -0
- lionagi/config.py +1 -1
- lionagi/fields/__init__.py +20 -3
- lionagi/operations/__init__.py +18 -1
- lionagi/operations/builder.py +664 -0
- lionagi/operations/flow.py +436 -0
- lionagi/operations/node.py +107 -0
- lionagi/protocols/graph/__init__.py +6 -0
- lionagi/service/connections/providers/claude_code_cli.py +1 -4
- lionagi/session/__init__.py +5 -0
- lionagi/session/session.py +47 -0
- lionagi/settings.py +2 -2
- lionagi/version.py +1 -1
- {lionagi-0.13.3.dist-info → lionagi-0.13.5.dist-info}/METADATA +11 -10
- {lionagi-0.13.3.dist-info → lionagi-0.13.5.dist-info}/RECORD +18 -15
- {lionagi-0.13.3.dist-info → lionagi-0.13.5.dist-info}/WHEEL +0 -0
- {lionagi-0.13.3.dist-info → lionagi-0.13.5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,664 @@
|
|
1
|
+
# Copyright (c) 2023 - 2025, HaiyangLi <quantocean.li at gmail dot com>
|
2
|
+
#
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
4
|
+
|
5
|
+
"""
|
6
|
+
OperationGraphBuilder: Incremental graph builder for multi-stage operations.
|
7
|
+
|
8
|
+
Build → Execute → Expand → Execute → ...
|
9
|
+
"""
|
10
|
+
|
11
|
+
from enum import Enum
|
12
|
+
from typing import Any
|
13
|
+
|
14
|
+
from lionagi.operations.node import BranchOperations, Operation
|
15
|
+
from lionagi.protocols.graph.edge import Edge
|
16
|
+
from lionagi.protocols.graph.graph import Graph
|
17
|
+
|
18
|
+
# Names exported by ``from lionagi.operations.builder import *``.
__all__ = ("OperationGraphBuilder", "ExpansionStrategy")
|
22
|
+
|
23
|
+
|
24
|
+
class ExpansionStrategy(Enum):
|
25
|
+
"""Strategies for expanding operations."""
|
26
|
+
|
27
|
+
CONCURRENT = "concurrent"
|
28
|
+
SEQUENTIAL = "sequential"
|
29
|
+
SEQUENTIAL_CONCURRENT_CHUNK = "sequential_concurrent_chunk"
|
30
|
+
CONCURRENT_SEQUENTIAL_CHUNK = "concurrent_sequential_chunk"
|
31
|
+
|
32
|
+
|
33
|
+
class OperationGraphBuilder:
    """
    Incremental graph builder that supports build → execute → expand cycles.

    Unlike static builders, this maintains state and allows expanding the graph
    based on execution results.

    Examples:
        >>> # Build initial graph
        >>> builder = OperationGraphBuilder()
        >>> builder.add_operation("operate", instruction="Generate ideas", num_instruct=5)
        >>> graph = builder.get_graph()
        >>>
        >>> # Execute with session
        >>> result = await session.flow(graph)
        >>>
        >>> # Expand based on results
        >>> if hasattr(result, 'instruct_models'):
        ...     builder.expand_from_result(
        ...         result.instruct_models,
        ...         source_node_id=builder.last_operation_id,
        ...         operation="instruct"
        ...     )
        >>>
        >>> # Get expanded graph and continue execution
        >>> graph = builder.get_graph()
        >>> final_result = await session.flow(graph)
    """

    def __init__(self, name: str = "DynamicGraph"):
        """Initialize the incremental graph builder.

        Args:
            name: Name reported by ``visualize_state``.
        """
        self.name = name
        self.graph = Graph()

        # Track state
        self._operations: dict[str, Operation] = {}  # All operations by ID
        self._executed: set[str] = set()  # IDs of executed operations
        # Frontier nodes that the next add_operation auto-links from
        self._current_heads: list[str] = []
        self.last_operation_id: str | None = None

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _register(self, node: Operation) -> None:
        """Add ``node`` to the graph and index it by its ID."""
        self.graph.add_node(node)
        self._operations[node.id] = node

    def _link(self, head_id, tail_id, *labels: str) -> None:
        """Add a directed ``head_id -> tail_id`` edge carrying ``labels``."""
        self.graph.add_edge(Edge(head=head_id, tail=tail_id, label=list(labels)))

    def _lookup_id(self, ref):
        """Return the node ID for ``ref`` (a node ID or a reference ID), or None."""
        if ref is None:
            # A None reference would otherwise match every node without a
            # reference_id in get_node_by_reference.
            return None
        if ref in self._operations:
            return ref
        node = self.get_node_by_reference(ref)
        return node.id if node is not None else None

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def add_operation(
        self,
        operation: BranchOperations,
        node_id: str | None = None,
        depends_on: list[str] | None = None,
        **parameters,
    ) -> str:
        """
        Add an operation to the graph.

        Args:
            operation: The branch operation
            node_id: Optional reference ID for this node. It is stored in
                ``metadata["reference_id"]`` and can be used with
                ``get_node_by_reference`` or in a later ``depends_on``.
            depends_on: Node IDs (or reference IDs) this operation depends on.
                When empty or omitted, the node is linked from the current heads.
            **parameters: Operation parameters

        Returns:
            ID of the created node

        Raises:
            ValueError: If an entry of ``depends_on`` matches no known node.
        """
        # Resolve every dependency before touching the graph, so a bad ID
        # leaves no half-linked node behind. Unknown IDs used to be dropped
        # silently, which changed execution order without any warning.
        dependency_ids = []
        for dep in depends_on or []:
            resolved = self._lookup_id(dep)
            if resolved is None:
                raise ValueError(f"Dependency {dep} not found")
            dependency_ids.append(resolved)

        # Create operation node
        node = Operation(operation=operation, parameters=parameters)
        self._register(node)

        # Store reference if provided
        if node_id:
            node.metadata["reference_id"] = node_id

        if dependency_ids:
            for dep_id in dependency_ids:
                self._link(dep_id, node.id, "depends_on")
        else:
            # Auto-link from current heads (no-op when there are none)
            for head_id in self._current_heads:
                self._link(head_id, node.id, "sequential")

        # Update state
        self._current_heads = [node.id]
        self.last_operation_id = node.id

        return node.id

    def expand_from_result(
        self,
        items: list[Any],
        source_node_id: str,
        operation: BranchOperations,
        strategy: ExpansionStrategy = ExpansionStrategy.CONCURRENT,
        **shared_params,
    ) -> list[str]:
        """
        Expand the graph based on execution results.

        This is called after executing the graph to add new operations
        based on results. Every new node is linked directly from the source
        node. The strategy value is recorded in the new nodes' parameters,
        their metadata and the edge labels.

        Args:
            items: Items from result to expand (e.g., instruct_models). Items
                with ``model_dump`` contribute their dumped fields as parameters.
                Other items are passed as ``item``/``item_index``.
            source_node_id: ID (or reference ID) of the node that produced these items
            operation: Operation to apply to each item
            strategy: How to organize the expanded operations
            **shared_params: Shared parameters for all operations

        Returns:
            List of new node IDs

        Raises:
            ValueError: If ``source_node_id`` matches no known node.
        """
        source_id = self._lookup_id(source_node_id)
        if source_id is None:
            raise ValueError(f"Source node {source_node_id} not found")

        new_node_ids = []

        # Create operation for each item
        for i, item in enumerate(items):
            # Extract parameters from item if it's a model
            if hasattr(item, "model_dump"):
                params = {**item.model_dump(), **shared_params}
            else:
                params = {**shared_params, "item_index": i, "item": str(item)}

            # Add metadata about expansion
            params["expanded_from"] = source_id
            params["expansion_strategy"] = strategy.value

            node = Operation(
                operation=operation,
                parameters=params,
                metadata={
                    "expansion_index": i,
                    "expansion_source": source_id,
                    "expansion_strategy": strategy.value,
                },
            )
            self._register(node)
            new_node_ids.append(node.id)

            # Link from source
            self._link(source_id, node.id, "expansion", strategy.value)

        # Update current heads based on strategy. An empty expansion keeps the
        # existing frontier. Replacing it with [] would leave the next
        # add_operation unlinked and make add_aggregation fail.
        if new_node_ids and strategy in (
            ExpansionStrategy.CONCURRENT,
            ExpansionStrategy.SEQUENTIAL,
        ):
            # Copy so callers mutating the returned list cannot alter the heads
            self._current_heads = list(new_node_ids)

        return new_node_ids

    def add_aggregation(
        self,
        operation: BranchOperations,
        source_node_ids: list[str] | None = None,
        **parameters,
    ) -> str:
        """
        Add an aggregation operation that collects from multiple sources.

        Args:
            operation: Aggregation operation
            source_node_ids: Nodes to aggregate from (defaults to current heads)
            **parameters: Operation parameters

        Returns:
            ID of aggregation node

        Raises:
            ValueError: If there are no source nodes to aggregate.
        """
        # Copy, so later changes to the caller's list (or to the internal
        # heads list) cannot alter the sources recorded on the node.
        sources = list(source_node_ids or self._current_heads)
        if not sources:
            raise ValueError("No source nodes for aggregation")

        # Add aggregation metadata
        agg_params = {
            "aggregation_sources": sources,
            "aggregation_count": len(sources),
            **parameters,
        }

        node = Operation(
            operation=operation,
            parameters=agg_params,
            metadata={"aggregation": True},
        )
        self._register(node)

        # Connect all sources
        for source_id in sources:
            self._link(source_id, node.id, "aggregate")

        # Update state
        self._current_heads = [node.id]
        self.last_operation_id = node.id

        return node.id

    def mark_executed(self, node_ids: list[str]):
        """
        Mark nodes as executed.

        This helps track which parts of the graph have been run.

        Args:
            node_ids: IDs of executed nodes
        """
        self._executed.update(node_ids)

    def get_unexecuted_nodes(self) -> list[Operation]:
        """
        Get nodes that haven't been executed yet.

        Returns:
            List of unexecuted operations, in insertion order
        """
        return [
            op
            for op_id, op in self._operations.items()
            if op_id not in self._executed
        ]

    def add_conditional_branch(
        self,
        condition_check_op: BranchOperations,
        true_op: BranchOperations,
        false_op: BranchOperations | None = None,
        **check_params,
    ) -> dict[str, str]:
        """
        Add a conditional branch structure.

        The check node is linked from the current heads. The true and false
        nodes are linked from the check node with ``if_true``/``if_false``
        edges and become the new heads.

        Args:
            condition_check_op: Operation that evaluates condition
            true_op: Operation if condition is true
            false_op: Operation if condition is false
            **check_params: Parameters for condition check

        Returns:
            Dict with node IDs: {'check': id, 'true': id, 'false': id}
            ('false' only when ``false_op`` is given)
        """
        # Add condition check node. The flag is kept in parameters for
        # existing consumers and mirrored into metadata, which is where
        # visualize_graph looks for it.
        check_node = Operation(
            operation=condition_check_op,
            parameters={**check_params, "is_condition_check": True},
            metadata={"is_condition_check": True},
        )
        self._register(check_node)

        # Link from current heads
        for head_id in self._current_heads:
            self._link(head_id, check_node.id, "to_condition")

        result = {"check": check_node.id}

        # Add true branch
        true_node = Operation(operation=true_op, parameters={"branch": "true"})
        self._register(true_node)
        result["true"] = true_node.id
        self._link(check_node.id, true_node.id, "if_true")

        # Add false branch if specified
        if false_op:
            false_node = Operation(
                operation=false_op, parameters={"branch": "false"}
            )
            self._register(false_node)
            result["false"] = false_node.id
            self._link(check_node.id, false_node.id, "if_false")

            self._current_heads = [true_node.id, false_node.id]
        else:
            self._current_heads = [true_node.id]

        return result

    def get_graph(self) -> Graph:
        """
        Get the current graph for execution.

        Returns:
            The graph in its current state (the builder's live object, not a copy)
        """
        return self.graph

    def get_node_by_reference(self, reference_id: str) -> Operation | None:
        """
        Get a node by its reference ID.

        Args:
            reference_id: The reference ID assigned when creating the node

        Returns:
            The first operation node with that reference ID, or None
        """
        for op in self._operations.values():
            if op.metadata.get("reference_id") == reference_id:
                return op
        return None

    def visualize_state(self) -> dict[str, Any]:
        """
        Get visualization of current graph state.

        Returns:
            Dict with graph statistics and state. ``executed_nodes`` and
            ``unexecuted_nodes`` count only nodes built by this builder,
            so they always sum to ``total_nodes``.
        """
        # Group nodes by expansion source
        expansions: dict[Any, list[Any]] = {}
        for op in self._operations.values():
            source = op.metadata.get("expansion_source")
            if source:
                expansions.setdefault(source, []).append(op.id)

        unexecuted = len(self.get_unexecuted_nodes())

        return {
            "name": self.name,
            "total_nodes": len(self._operations),
            "executed_nodes": len(self._operations) - unexecuted,
            "unexecuted_nodes": unexecuted,
            # Copy so callers cannot mutate the builder's frontier
            "current_heads": list(self._current_heads),
            "expansions": expansions,
            "edges": len(self.graph.internal_edges),
        }

    def visualize(self, title: str = "Operation Graph", figsize=(14, 10)):
        """Render this builder's graph with ``visualize_graph`` (matplotlib/networkx)."""
        visualize_graph(
            self,
            title=title,
            figsize=figsize,
        )
|
391
|
+
|
392
|
+
|
393
|
+
def _compute_node_levels(
    node_keys: list[str], edge_pairs: list[tuple[str, str]]
) -> dict[str, int]:
    """Return the longest-path depth of every node in a directed graph.

    Roots (no incoming edges) are level 0. Every other node sits one level
    below its deepest predecessor, whatever order the nodes were inserted in.
    Edges whose endpoints are not in ``node_keys`` are ignored. Nodes inside
    a cycle never reach zero pending in-edges, so they keep the deepest level
    propagated to them so far.
    """
    from collections import deque

    levels = dict.fromkeys(node_keys, 0)
    pending = dict.fromkeys(node_keys, 0)  # unprocessed in-edges per node
    successors: dict[str, list[str]] = {key: [] for key in node_keys}
    for head, tail in edge_pairs:
        if head in successors and tail in pending:
            successors[head].append(tail)
            pending[tail] += 1

    # Kahn's algorithm: each node is processed after all its predecessors.
    queue = deque(key for key in node_keys if pending[key] == 0)
    while queue:
        current = queue.popleft()
        for nxt in successors[current]:
            levels[nxt] = max(levels[nxt], levels[current] + 1)
            pending[nxt] -= 1
            if pending[nxt] == 0:
                queue.append(nxt)
    return levels


def _layered_positions(
    nodes_by_level: dict[int, list[str]],
    x_spacing: float = 2.5,
    y_spacing: float = 2.5,
    max_per_row: int = 6,
    row_gap: float = 0.8,
) -> dict[str, tuple[float, float]]:
    """Place nodes level by level, top to bottom, centred horizontally.

    Levels with more than ``max_per_row`` nodes are wrapped into several
    rows. Each extra row is shifted down by ``row_gap``.
    """
    import math

    positions: dict[str, tuple[float, float]] = {}
    for level, members in nodes_by_level.items():
        count = len(members)
        if count <= max_per_row:
            per_row = count
        else:
            per_row = min(max_per_row, math.ceil(math.sqrt(count * 1.5)))

        for index, key in enumerate(members):
            row, col = divmod(index, per_row)
            # The last row may be shorter; centre it on its own width
            in_row = min(per_row, count - row * per_row)
            x_start = -(in_row - 1) * x_spacing / 2
            positions[key] = (
                x_start + col * x_spacing,
                -level * y_spacing - row * row_gap,
            )
    return positions


def visualize_graph(
    builder: OperationGraphBuilder,
    title: str = "Operation Graph",
    figsize=(14, 10),
):
    """Draw the builder's graph with a hierarchical (level-by-level) layout.

    Nodes are coloured by state/type: executed, expanded, aggregation,
    condition check, or pending. Edges are styled by their first label:
    expansion, aggregate, or other. Requires ``matplotlib`` and ``networkx``,
    and calls ``plt.show()``.

    Args:
        builder: The builder whose graph and execution state are drawn.
        title: Figure title.
        figsize: Matplotlib figure size.
    """
    import matplotlib.pyplot as plt
    import networkx as nx
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch

    graph = builder.get_graph()
    nodes = list(graph.internal_nodes.values())

    # Full ID strings are the networkx keys, so two nodes can never collide.
    # Only the label shows the short 8-character form.
    node_keys = [str(node.id) for node in nodes]
    known_keys = set(node_keys)

    # ---- node labels / colours / sizes (aligned with node_keys) ----
    node_labels = {}
    node_colors = []
    node_sizes = []
    for key, node in zip(node_keys, nodes):
        ref_id = node.metadata.get("reference_id", "")
        if ref_id:
            node_labels[key] = f"{node.operation}\n[{ref_id}]"
        else:
            node_labels[key] = f"{node.operation}\n{key[:8]}"

        # The builder flags condition checks in parameters; accept metadata too
        params = getattr(node, "parameters", None)
        is_condition = node.metadata.get("is_condition_check") or (
            isinstance(params, dict) and params.get("is_condition_check")
        )

        if node.id in builder._executed:
            node_colors.append("#90EE90")  # Light green
            node_sizes.append(4000)
        elif node.metadata.get("expansion_source"):
            node_colors.append("#87CEEB")  # Sky blue
            node_sizes.append(3500)
        elif node.metadata.get("aggregation"):
            node_colors.append("#FFD700")  # Gold
            node_sizes.append(4500)
        elif is_condition:
            node_colors.append("#DDA0DD")  # Plum
            node_sizes.append(3500)
        else:
            node_colors.append("#E0E0E0")  # Light gray
            node_sizes.append(3000)

    # ---- edge styles, keyed by (head, tail) ----
    # Keying by endpoints (not by list position) keeps each style attached
    # to its own edge. networkx iterates edges in adjacency order, not in
    # insertion order.
    edge_styles: dict[tuple[str, str], tuple[str, str, float]] = {}
    edge_labels: dict[tuple[str, str], str] = {}
    for edge in graph.internal_edges.values():
        pair = (str(edge.head), str(edge.tail))
        if pair[0] not in known_keys or pair[1] not in known_keys:
            continue  # dangling edge: no position to draw it at

        label = edge.label[0] if edge.label else ""
        edge_labels[pair] = label
        if "expansion" in label:
            edge_styles[pair] = ("#4169E1", "dashed", 2)  # Royal blue
        elif "aggregate" in label:
            edge_styles[pair] = ("#FF6347", "dotted", 2.5)  # Tomato
        else:
            edge_styles[pair] = ("#808080", "solid", 1.5)  # Gray

    G = nx.DiGraph()
    G.add_nodes_from(node_keys)
    G.add_edges_from(edge_styles.keys())

    # ---- hierarchical layout ----
    levels = _compute_node_levels(node_keys, list(edge_styles.keys()))
    nodes_by_level: dict[int, list[str]] = {}
    for key in node_keys:
        nodes_by_level.setdefault(levels[key], []).append(key)
    pos = _layered_positions(nodes_by_level)

    # Create figure
    plt.figure(figsize=figsize)

    # Draw nodes; explicit nodelist keeps colours/sizes aligned with nodes
    nx.draw_networkx_nodes(
        G,
        pos,
        nodelist=node_keys,
        node_color=node_colors,
        node_size=node_sizes,
        alpha=0.9,
        linewidths=2,
        edgecolors="black",
    )

    # Draw edges one by one so each can get its own curvature and style
    for u, v in G.edges():
        color, style, width = edge_styles[(u, v)]
        # Curve more when the endpoints are far apart horizontally
        if abs(pos[u][0] - pos[v][0]) > 5:
            connectionstyle = "arc3,rad=0.2"
        else:
            connectionstyle = "arc3,rad=0.1"

        nx.draw_networkx_edges(
            G,
            pos,
            [(u, v)],
            edge_color=[color],
            style=style,
            width=width,
            alpha=0.7,
            arrows=True,
            arrowsize=20,
            arrowstyle="-|>",
            connectionstyle=connectionstyle,
        )

    # Draw labels
    nx.draw_networkx_labels(
        G,
        pos,
        labels=node_labels,
        font_size=9,
        font_weight="bold",
        font_family="monospace",
    )

    # Draw edge labels (only for smaller graphs)
    if len(G.edges()) < 20:
        nx.draw_networkx_edge_labels(
            G,
            pos,
            edge_labels=edge_labels,
            font_size=7,
            font_color="darkblue",
            bbox=dict(
                boxstyle="round,pad=0.3",
                facecolor="white",
                edgecolor="none",
                alpha=0.7,
            ),
        )

    plt.title(title, fontsize=18, fontweight="bold", pad=20)
    plt.axis("off")

    # Legend
    legend_elements = [
        Patch(facecolor="#90EE90", edgecolor="black", label="Executed"),
        Patch(facecolor="#87CEEB", edgecolor="black", label="Expanded"),
        Patch(facecolor="#FFD700", edgecolor="black", label="Aggregation"),
        Patch(facecolor="#DDA0DD", edgecolor="black", label="Condition"),
        Patch(facecolor="#E0E0E0", edgecolor="black", label="Pending"),
        Line2D([0], [0], color="#808080", linewidth=2, label="Sequential"),
        Line2D(
            [0],
            [0],
            color="#4169E1",
            linewidth=2,
            linestyle="dashed",
            label="Expansion",
        ),
        Line2D(
            [0],
            [0],
            color="#FF6347",
            linewidth=2,
            linestyle="dotted",
            label="Aggregate",
        ),
    ]

    plt.legend(
        handles=legend_elements,
        loc="upper left",
        bbox_to_anchor=(0, 1),
        frameon=True,
        fancybox=True,
        shadow=True,
        ncol=2,
    )

    # Statistics box
    stats_text = f"Nodes: {len(G.nodes())}\nEdges: {len(G.edges())}\nExecuted: {len(builder._executed)}"
    if nodes_by_level:
        max_level = max(nodes_by_level.keys())
        stats_text += f"\nLevels: {max_level + 1}"

    plt.text(
        0.98,
        0.02,
        stats_text,
        transform=plt.gca().transAxes,
        bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.8),
        verticalalignment="bottom",
        horizontalalignment="right",
        fontsize=10,
        fontfamily="monospace",
    )

    plt.tight_layout()
    plt.show()
|