lionagi 0.16.2__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63):
  1. lionagi/adapters/_utils.py +10 -23
  2. lionagi/adapters/async_postgres_adapter.py +83 -79
  3. lionagi/ln/__init__.py +4 -4
  4. lionagi/ln/_json_dump.py +0 -6
  5. lionagi/ln/fuzzy/__init__.py +4 -1
  6. lionagi/ln/fuzzy/_fuzzy_validate.py +109 -0
  7. lionagi/ln/fuzzy/_to_dict.py +388 -0
  8. lionagi/models/__init__.py +0 -2
  9. lionagi/operations/__init__.py +0 -6
  10. lionagi/operations/_visualize_graph.py +285 -0
  11. lionagi/operations/brainstorm/brainstorm.py +14 -12
  12. lionagi/operations/builder.py +23 -302
  13. lionagi/operations/communicate/communicate.py +1 -1
  14. lionagi/operations/flow.py +14 -11
  15. lionagi/operations/node.py +14 -3
  16. lionagi/operations/operate/operate.py +5 -11
  17. lionagi/operations/parse/parse.py +2 -3
  18. lionagi/operations/types.py +0 -2
  19. lionagi/operations/utils.py +11 -5
  20. lionagi/protocols/generic/pile.py +3 -7
  21. lionagi/protocols/graph/graph.py +23 -6
  22. lionagi/protocols/graph/node.py +0 -2
  23. lionagi/protocols/messages/message.py +0 -1
  24. lionagi/protocols/operatives/operative.py +2 -2
  25. lionagi/protocols/types.py +0 -15
  26. lionagi/service/connections/endpoint.py +11 -5
  27. lionagi/service/connections/match_endpoint.py +2 -10
  28. lionagi/service/connections/providers/types.py +1 -3
  29. lionagi/service/hooks/hook_event.py +1 -1
  30. lionagi/service/hooks/hook_registry.py +1 -1
  31. lionagi/service/rate_limited_processor.py +1 -1
  32. lionagi/session/branch.py +24 -18
  33. lionagi/session/session.py +2 -18
  34. lionagi/utils.py +3 -335
  35. lionagi/version.py +1 -1
  36. {lionagi-0.16.2.dist-info → lionagi-0.17.0.dist-info}/METADATA +4 -13
  37. {lionagi-0.16.2.dist-info → lionagi-0.17.0.dist-info}/RECORD +39 -61
  38. lionagi/adapters/postgres_model_adapter.py +0 -131
  39. lionagi/libs/concurrency.py +0 -1
  40. lionagi/libs/nested/__init__.py +0 -3
  41. lionagi/libs/nested/flatten.py +0 -172
  42. lionagi/libs/nested/nfilter.py +0 -59
  43. lionagi/libs/nested/nget.py +0 -45
  44. lionagi/libs/nested/ninsert.py +0 -104
  45. lionagi/libs/nested/nmerge.py +0 -158
  46. lionagi/libs/nested/npop.py +0 -69
  47. lionagi/libs/nested/nset.py +0 -94
  48. lionagi/libs/nested/unflatten.py +0 -83
  49. lionagi/libs/nested/utils.py +0 -189
  50. lionagi/libs/parse.py +0 -31
  51. lionagi/libs/schema/json_schema.py +0 -231
  52. lionagi/libs/unstructured/__init__.py +0 -0
  53. lionagi/libs/unstructured/pdf_to_image.py +0 -45
  54. lionagi/libs/unstructured/read_image_to_base64.py +0 -33
  55. lionagi/libs/validate/fuzzy_match_keys.py +0 -7
  56. lionagi/libs/validate/fuzzy_validate_mapping.py +0 -144
  57. lionagi/libs/validate/string_similarity.py +0 -7
  58. lionagi/libs/validate/xml_parser.py +0 -203
  59. lionagi/models/note.py +0 -387
  60. lionagi/protocols/graph/_utils.py +0 -22
  61. lionagi/service/connections/providers/claude_code_.py +0 -299
  62. {lionagi-0.16.2.dist-info → lionagi-0.17.0.dist-info}/WHEEL +0 -0
  63. {lionagi-0.16.2.dist-info → lionagi-0.17.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,388 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import dataclasses
5
+ import json
6
+ from collections.abc import Callable, Iterable, Mapping, Sequence
7
+ from enum import Enum as _Enum
8
+ from typing import Any, Literal
9
+
10
+ from ._fuzzy_json import fuzzy_json
11
+
12
+ # ----------------------------
13
+ # Helpers (small, tight, local)
14
+ # ----------------------------
15
+
16
+
17
+ def _is_na(obj: Any) -> bool:
18
+ """None / Pydantic undefined sentinels -> treat as NA."""
19
+ if obj is None:
20
+ return True
21
+ # Avoid importing pydantic types; match by typename to stay lightweight
22
+ tname = type(obj).__name__
23
+ return tname in {
24
+ "Undefined",
25
+ "UndefinedType",
26
+ "PydanticUndefined",
27
+ "PydanticUndefinedType",
28
+ }
29
+
30
+
31
+ def _enum_class_to_dict(
32
+ enum_cls: type[_Enum], use_enum_values: bool
33
+ ) -> dict[str, Any]:
34
+ members = dict(enum_cls.__members__) # cheap, stable
35
+ if use_enum_values:
36
+ return {k: v.value for k, v in members.items()}
37
+ return {k: v for k, v in members.items()}
38
+
39
+
40
+ def _parse_str(
41
+ s: str,
42
+ *,
43
+ fuzzy_parse: bool,
44
+ str_type: Literal["json", "xml"] | None,
45
+ parser: Callable[[str], Any] | None,
46
+ **kwargs: Any,
47
+ ) -> Any:
48
+ """Parse str -> Python object. Keep imports local to avoid cold start overhead."""
49
+ if parser is not None:
50
+ return parser(s, **kwargs)
51
+
52
+ if str_type == "xml":
53
+ # xmltodict is optional; import only if needed
54
+ import xmltodict
55
+
56
+ return xmltodict.parse(s, **kwargs)
57
+
58
+ # JSON path
59
+ if fuzzy_parse:
60
+ # If the caller supplied a fuzzy parser in scope, use it; otherwise fallback.
61
+ # We intentionally do not import anything heavy here.
62
+ with contextlib.suppress(NameError):
63
+ return fuzzy_json(s, **kwargs) # type: ignore[name-defined]
64
+ return json.loads(s, **kwargs)
65
+
66
+
67
+ def _object_to_mapping_like(
68
+ obj: Any,
69
+ *,
70
+ use_model_dump: bool,
71
+ **kwargs: Any,
72
+ ) -> Mapping | dict | Any:
73
+ """
74
+ Convert 'custom' objects to mapping-like, if possible.
75
+ Order:
76
+ 1) Pydantic v2 'model_dump' (duck-typed)
77
+ 2) Common methods: to_dict, dict, to_json/json (parsed if string)
78
+ 3) Dataclass
79
+ 4) __dict__
80
+ 5) dict(obj)
81
+ """
82
+ # 1) Pydantic v2
83
+ if use_model_dump and hasattr(obj, "model_dump"):
84
+ return obj.model_dump(**kwargs)
85
+
86
+ # 2) Common methods
87
+ for name in ("to_dict", "dict", "to_json", "json"):
88
+ if hasattr(obj, name):
89
+ res = getattr(obj, name)(**kwargs)
90
+ return json.loads(res) if isinstance(res, str) else res
91
+
92
+ # 3) Dataclass
93
+ if dataclasses.is_dataclass(obj):
94
+ # asdict is already recursive; keep it (fast enough & simple)
95
+ return dataclasses.asdict(obj)
96
+
97
+ # 4) __dict__
98
+ if hasattr(obj, "__dict__"):
99
+ return obj.__dict__
100
+
101
+ # 5) Try dict() fallback
102
+ return dict(obj) # may raise -> handled by caller
103
+
104
+
105
+ def _enumerate_iterable(it: Iterable) -> dict[int, Any]:
106
+ return {i: v for i, v in enumerate(it)}
107
+
108
+
109
+ # ---------------------------------------
110
+ # Recursive pre-processing (single pass)
111
+ # ---------------------------------------
112
+
113
+
114
def _preprocess_recursive(
    obj: Any,
    *,
    depth: int,
    max_depth: int,
    recursive_custom_types: bool,
    str_parse_opts: dict[str, Any],
    use_model_dump: bool,
) -> Any:
    """Recursively pre-process nested structures.

    - Parse strings (JSON/XML/custom parser); unparseable strings are kept.
    - Recurse into dict/list/tuple/set/frozenset values.
    - If ``recursive_custom_types`` is True, convert custom objects to
      mapping-like form and keep recursing.
    Container types are preserved (dict stays dict, set stays set, etc.).
    """
    if depth >= max_depth:
        return obj

    def _recurse(value: Any) -> Any:
        # One level deeper with identical options.
        return _preprocess_recursive(
            value,
            depth=depth + 1,
            max_depth=max_depth,
            recursive_custom_types=recursive_custom_types,
            str_parse_opts=str_parse_opts,
            use_model_dump=use_model_dump,
        )

    t = type(obj)

    # Strings: try to parse; on failure, keep as-is.
    if t is str:
        # BUG FIX: 'use_enum_values' is threaded through str_parse_opts for
        # the enum-class branch only. Forwarding it into _parse_str made
        # json.loads / xmltodict / the custom parser raise TypeError, which
        # the except below swallowed — so nested JSON strings were silently
        # never parsed. Strip it before parsing.
        parse_opts = {
            k: v for k, v in str_parse_opts.items() if k != "use_enum_values"
        }
        try:
            parsed = _parse_str(obj, **parse_opts)
        except Exception:
            return obj
        return _recurse(parsed)

    # Dict-like: recurse only into values (keys kept as-is).
    if isinstance(obj, Mapping):
        return {k: _recurse(v) for k, v in obj.items()}

    # Sequence/Set-like (but not str). Subclasses (e.g. namedtuple) fall
    # through untouched, matching the original behavior.
    if isinstance(obj, (list, tuple, set, frozenset)):
        items = [_recurse(v) for v in obj]
        if t is list:
            return items
        if t is tuple:
            return tuple(items)
        if t is set:
            return set(items)
        if t is frozenset:
            return frozenset(items)

    # Enum *class* (rare in values, but preserved from the original).
    if isinstance(obj, type) and issubclass(obj, _Enum):
        try:
            enum_map = _enum_class_to_dict(
                obj,
                use_enum_values=str_parse_opts.get("use_enum_values", True),
            )
            return _recurse(enum_map)
        except Exception:
            return obj

    # Custom objects: best-effort mapping conversion, then keep recursing.
    if recursive_custom_types:
        with contextlib.suppress(Exception):
            mapped = _object_to_mapping_like(
                obj, use_model_dump=use_model_dump
            )
            return _recurse(mapped)

    return obj
222
+
223
+
224
+ # ---------------------------------------
225
+ # Top-level conversion (non-recursive)
226
+ # ---------------------------------------
227
+
228
+
229
def _convert_top_level_to_dict(
    obj: Any,
    *,
    fuzzy_parse: bool,
    str_type: Literal["json", "xml"] | None,
    parser: Callable[[str], Any] | None,
    use_model_dump: bool,
    use_enum_values: bool,
    **kwargs: Any,
) -> dict[str, Any]:
    """Convert a *single* (already pre-processed) object to a dict.

    The 'brute force' precedence is deliberate and preserved exactly:
    set -> enum class -> mapping -> NA -> string -> custom object ->
    generic iterable -> dataclass -> dict() fallback.
    """
    # Sets map each element onto itself.
    if isinstance(obj, set):
        return {item: item for item in obj}

    # Enum classes expand to their member mapping.
    if isinstance(obj, type) and issubclass(obj, _Enum):
        return _enum_class_to_dict(obj, use_enum_values)

    # Mappings are shallow-copied into a plain dict.
    if isinstance(obj, Mapping):
        return dict(obj)

    # None / pydantic-undefined sentinels collapse to an empty dict.
    if _is_na(obj):
        return {}

    # Strings are parsed; the parse result is returned verbatim and may
    # legitimately be a list, scalar, etc. (matching original semantics).
    if isinstance(obj, str):
        return _parse_str(
            obj,
            fuzzy_parse=fuzzy_parse,
            str_type=str_type,
            parser=parser,
            **kwargs,
        )

    # Custom-object conversion: covers BaseModel via model_dump,
    # dataclasses, __dict__, json-string results, etc.
    try:
        # Non-Sequence values (numbers, arbitrary objects) take the
        # object-conversion path first, as in the original implementation.
        if not isinstance(obj, Sequence):
            converted = _object_to_mapping_like(
                obj, use_model_dump=use_model_dump, **kwargs
            )
            if isinstance(converted, str):
                # e.g. a to_json() result: decode to a mapping.
                return _parse_str(
                    converted,
                    fuzzy_parse=fuzzy_parse,
                    str_type="json",
                    parser=None,
                )
            if isinstance(converted, Mapping):
                return dict(converted)
            if isinstance(converted, Iterable) and not isinstance(
                converted, (str, bytes, bytearray)
            ):
                # list/tuple/... results get enumerated into index keys.
                return _enumerate_iterable(converted)
            # Best-effort final cast.
            return dict(converted)
    except Exception:
        # Conversion failed -> fall through to the generic strategies.
        pass

    # Generic iterables (list/tuple/namedtuple/frozenset/...): enumerate.
    if isinstance(obj, Iterable) and not isinstance(
        obj, (str, bytes, bytearray)
    ):
        return _enumerate_iterable(obj)

    # Dataclass fallback (reachable only if it wasn't converted above).
    with contextlib.suppress(Exception):
        if dataclasses.is_dataclass(obj):
            return dataclasses.asdict(obj)

    # Last resort; may raise, which the public to_dict() handles.
    return dict(obj)
313
+
314
+
315
+ # ---------------
316
+ # Public function
317
+ # ---------------
318
+
319
+
320
def to_dict(
    input_: Any,
    /,
    *,
    use_model_dump: bool = True,
    fuzzy_parse: bool = False,
    suppress: bool = False,
    str_type: Literal["json", "xml"] | None = "json",
    parser: Callable[[str], Any] | None = None,
    recursive: bool = False,
    max_recursive_depth: int | None = None,
    recursive_python_only: bool = True,
    use_enum_values: bool = False,
    **kwargs: Any,
) -> dict[str, Any]:
    """Convert various input types to a dictionary.

    Args:
        input_: The object to convert (mapping, string, pydantic model,
            dataclass, iterable, enum class, ...).
        use_model_dump: Prefer pydantic v2 ``model_dump`` when available.
        fuzzy_parse: Use fuzzy JSON parsing for strings.
        suppress: Return ``{}`` instead of raising on failure.
        str_type: How to parse string input ("json" or "xml").
        parser: Custom string parser; overrides str_type handling.
        recursive: Recursively process nested structures first.
        max_recursive_depth: Recursion depth (default 5, must be 0..10).
        recursive_python_only: When False, also recurse into custom objects.
        use_enum_values: Map enum classes to member values, not members.
        **kwargs: Forwarded to the underlying parser / converter.

    Raises:
        ValueError: If ``max_recursive_depth`` is out of range.
        Exception: Whatever the underlying conversion raises, unless
            ``suppress`` is True or the input is the empty string.
    """
    try:
        # Clamp recursion depth: default 5 when unset, hard cap 10.
        if not isinstance(max_recursive_depth, int):
            max_depth = 5
        else:
            if max_recursive_depth < 0:
                raise ValueError(
                    "max_recursive_depth must be a non-negative integer"
                )
            if max_recursive_depth > 10:
                raise ValueError(
                    "max_recursive_depth must be less than or equal to 10"
                )
            max_depth = max_recursive_depth

        # Bundle string-parsing options once to avoid repeated arg passing.
        str_parse_opts = {
            "fuzzy_parse": fuzzy_parse,
            "str_type": str_type,
            "parser": parser,
            "use_enum_values": use_enum_values,  # for enum classes in recursion
            **kwargs,
        }

        obj = input_
        if recursive:
            obj = _preprocess_recursive(
                obj,
                depth=0,
                max_depth=max_depth,
                recursive_custom_types=not recursive_python_only,
                str_parse_opts=str_parse_opts,
                use_model_dump=use_model_dump,
            )

        # Final top-level conversion.
        return _convert_top_level_to_dict(
            obj,
            fuzzy_parse=fuzzy_parse,
            str_type=str_type,
            parser=parser,
            use_model_dump=use_model_dump,
            use_enum_values=use_enum_values,
            **kwargs,
        )

    except Exception:
        # BUG FIX: the empty-string check is now guarded by isinstance —
        # evaluating `input_ == ""` on arbitrary objects (e.g. array-likes
        # with elementwise __eq__) could itself raise inside this handler
        # and mask the original error.
        if suppress or (isinstance(input_, str) and input_ == ""):
            return {}
        raise  # bare re-raise preserves the original traceback
@@ -5,7 +5,6 @@
5
5
  from .field_model import FieldModel
6
6
  from .hashable_model import HashableModel
7
7
  from .model_params import ModelParams
8
- from .note import Note
9
8
  from .operable_model import OperableModel
10
9
  from .schema_model import SchemaModel
11
10
 
@@ -13,7 +12,6 @@ __all__ = (
13
12
  "FieldModel",
14
13
  "ModelParams",
15
14
  "OperableModel",
16
- "Note",
17
15
  "SchemaModel",
18
16
  "HashableModel",
19
17
  )
@@ -2,11 +2,9 @@
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
- from .brainstorm.brainstorm import BrainstormOperation, brainstorm
6
5
  from .builder import ExpansionStrategy, OperationGraphBuilder
7
6
  from .flow import flow
8
7
  from .node import BranchOperations, Operation
9
- from .plan.plan import PlanOperation, plan
10
8
 
11
9
  Builder = OperationGraphBuilder
12
10
 
@@ -17,9 +15,5 @@ __all__ = (
17
15
  "flow",
18
16
  "BranchOperations",
19
17
  "Operation",
20
- "plan",
21
- "PlanOperation",
22
- "brainstorm",
23
- "BrainstormOperation",
24
18
  "Builder",
25
19
  )
@@ -0,0 +1,285 @@
1
def visualize_graph(
    builder,
    title: str = "Operation Graph",
    figsize=(14, 10),
):
    """Render the builder's operation graph with a hierarchical layout.

    Nodes are placed on levels derived from their dependency edges and
    colored by status/type; edges are styled by their label. Opens a
    matplotlib window via ``plt.show()``.

    Args:
        builder: Object exposing ``get_graph()`` (returning a graph with
            ``internal_nodes`` / ``internal_edges``) and an ``_executed``
            collection of node ids — e.g. OperationGraphBuilder.
        title: Figure title.
        figsize: Matplotlib figure size in inches.

    Raises:
        ImportError: If matplotlib or networkx is not installed.
    """
    # BUG FIX: the original imported numpy unconditionally (without an
    # is_import_installed guard, unlike matplotlib/networkx) just for
    # ceil/sqrt; stdlib math covers that with no extra dependency.
    import math

    from lionagi.utils import is_import_installed

    if not is_import_installed("matplotlib"):
        raise ImportError(
            "matplotlib is required for visualization. "
            "Please install it using `pip install matplotlib`."
        )
    if not is_import_installed("networkx"):
        raise ImportError(
            "networkx is required for visualization. "
            "Please install it using `pip install networkx`."
        )

    import matplotlib.pyplot as plt
    import networkx as nx

    graph = builder.get_graph()

    # Convert to networkx.
    G = nx.DiGraph()

    # Track node positions for the hierarchical layout.
    node_levels = {}
    node_labels = {}
    node_colors = []
    node_sizes = []

    # First pass: add nodes and determine levels.
    # NOTE(review): a predecessor not yet visited is skipped by the
    # `in node_levels` guard, so levels depend on iteration order and are
    # best-effort — preserved from the original implementation.
    for node in graph.internal_nodes.values():
        node_id = str(node.id)[:8]
        G.add_node(node_id)

        # Incoming edges are those whose tail is this node.
        in_edges = [
            e
            for e in graph.internal_edges.values()
            if str(e.tail)[:8] == node_id
        ]
        if not in_edges:
            level = 0  # root nodes
        else:
            # Max level of (already-seen) predecessors + 1.
            pred_levels = [
                node_levels[str(e.head)[:8]]
                for e in in_edges
                if str(e.head)[:8] in node_levels
            ]
            level = max(pred_levels, default=0) + 1

        node_levels[node_id] = level

        # Label: prefer a human-friendly reference id when present.
        ref_id = node.metadata.get("reference_id", "")
        if ref_id:
            node_labels[node_id] = f"{node.operation}\n[{ref_id}]"
        else:
            node_labels[node_id] = f"{node.operation}\n{node_id}"

        # Color and size encode status and node type.
        if node.id in builder._executed:
            node_colors.append("#90EE90")  # light green
            node_sizes.append(4000)
        elif node.metadata.get("expansion_source"):
            node_colors.append("#87CEEB")  # sky blue
            node_sizes.append(3500)
        elif node.metadata.get("aggregation"):
            node_colors.append("#FFD700")  # gold
            node_sizes.append(4500)
        elif node.metadata.get("is_condition_check"):
            node_colors.append("#DDA0DD")  # plum
            node_sizes.append(3500)
        else:
            node_colors.append("#E0E0E0")  # light gray
            node_sizes.append(3000)

    # Add edges, styled by their label.
    edge_colors = []
    edge_styles = []
    edge_widths = []
    edge_labels = {}

    for edge in graph.internal_edges.values():
        head_id = str(edge.head)[:8]
        tail_id = str(edge.tail)[:8]
        G.add_edge(head_id, tail_id)

        edge_label = edge.label[0] if edge.label else ""
        edge_labels[(head_id, tail_id)] = edge_label

        if "expansion" in edge_label:
            edge_colors.append("#4169E1")  # royal blue
            edge_styles.append("dashed")
            edge_widths.append(2)
        elif "aggregate" in edge_label:
            edge_colors.append("#FF6347")  # tomato
            edge_styles.append("dotted")
            edge_widths.append(2.5)
        else:
            edge_colors.append("#808080")  # gray
            edge_styles.append("solid")
            edge_widths.append(1.5)

    # Hierarchical layout: group nodes by level.
    pos = {}
    nodes_by_level = {}
    for node_id, level in node_levels.items():
        nodes_by_level.setdefault(level, []).append(node_id)

    y_spacing = 2.5

    for level, nodes in nodes_by_level.items():
        num_nodes = len(nodes)

        if num_nodes <= 6:
            # Single centered row for small levels.
            x_spacing = 2.5
            x_offset = -(num_nodes - 1) * x_spacing / 2
            for i, node_id in enumerate(nodes):
                pos[node_id] = (x_offset + i * x_spacing, -level * y_spacing)
        else:
            # Fold large levels into multiple slightly-staggered rows.
            nodes_per_row = min(6, math.ceil(math.sqrt(num_nodes * 1.5)))

            for i, node_id in enumerate(nodes):
                row = i // nodes_per_row
                col = i % nodes_per_row

                # Center each row on the x axis.
                nodes_in_row = min(
                    nodes_per_row, num_nodes - row * nodes_per_row
                )
                x_spacing = 2.5
                x_offset = -(nodes_in_row - 1) * x_spacing / 2

                # Slight vertical stagger per row.
                y_offset = row * 0.8

                pos[node_id] = (
                    x_offset + col * x_spacing,
                    -level * y_spacing - y_offset,
                )

    plt.figure(figsize=figsize)

    nx.draw_networkx_nodes(
        G,
        pos,
        node_color=node_colors,
        node_size=node_sizes,
        alpha=0.9,
        linewidths=2,
        edgecolors="black",
    )

    # Draw edges one at a time so each keeps its own style; curve edges
    # for visibility, more strongly when endpoints are far apart.
    # NOTE(review): styles are index-aligned with G.edges(); a DiGraph
    # collapses parallel edges, so duplicates between the same node pair
    # would misalign — assumed not to occur in builder graphs.
    for i, (u, v) in enumerate(G.edges()):
        u_pos = pos[u]
        v_pos = pos[v]

        if abs(u_pos[0] - v_pos[0]) > 5:  # far apart horizontally
            connectionstyle = "arc3,rad=0.2"
        else:
            connectionstyle = "arc3,rad=0.1"

        nx.draw_networkx_edges(
            G,
            pos,
            [(u, v)],
            edge_color=[edge_colors[i]],
            style=edge_styles[i],
            width=edge_widths[i],
            alpha=0.7,
            arrows=True,
            arrowsize=20,
            arrowstyle="-|>",
            connectionstyle=connectionstyle,
        )

    nx.draw_networkx_labels(
        G,
        pos,
        node_labels,
        font_size=9,
        font_weight="bold",
        font_family="monospace",
    )

    # Edge labels clutter big graphs; only draw them for small ones.
    if len(G.edges()) < 20:
        nx.draw_networkx_edge_labels(
            G,
            pos,
            edge_labels,
            font_size=7,
            font_color="darkblue",
            bbox=dict(
                boxstyle="round,pad=0.3",
                facecolor="white",
                edgecolor="none",
                alpha=0.7,
            ),
        )

    plt.title(title, fontsize=18, fontweight="bold", pad=20)
    plt.axis("off")

    # Legend (unused Rectangle import from the original dropped).
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch

    legend_elements = [
        Patch(facecolor="#90EE90", edgecolor="black", label="Executed"),
        Patch(facecolor="#87CEEB", edgecolor="black", label="Expanded"),
        Patch(facecolor="#FFD700", edgecolor="black", label="Aggregation"),
        Patch(facecolor="#DDA0DD", edgecolor="black", label="Condition"),
        Patch(facecolor="#E0E0E0", edgecolor="black", label="Pending"),
        Line2D([0], [0], color="#808080", linewidth=2, label="Sequential"),
        Line2D(
            [0],
            [0],
            color="#4169E1",
            linewidth=2,
            linestyle="dashed",
            label="Expansion",
        ),
        Line2D(
            [0],
            [0],
            color="#FF6347",
            linewidth=2,
            linestyle="dotted",
            label="Aggregate",
        ),
    ]

    plt.legend(
        handles=legend_elements,
        loc="upper left",
        bbox_to_anchor=(0, 1),
        frameon=True,
        fancybox=True,
        shadow=True,
        ncol=2,
    )

    # Summary statistics box in the bottom-right corner.
    stats_text = f"Nodes: {len(G.nodes())}\nEdges: {len(G.edges())}\nExecuted: {len(builder._executed)}"
    if nodes_by_level:
        max_level = max(nodes_by_level.keys())
        stats_text += f"\nLevels: {max_level + 1}"

    plt.text(
        0.98,
        0.02,
        stats_text,
        transform=plt.gca().transAxes,
        bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.8),
        verticalalignment="bottom",
        horizontalalignment="right",
        fontsize=10,
        fontfamily="monospace",
    )

    plt.tight_layout()
    plt.show()