lionagi 0.16.3__py3-none-any.whl → 0.17.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/__init__.py +24 -2
- lionagi/_types.py +47 -3
- lionagi/adapters/_utils.py +10 -9
- lionagi/adapters/async_postgres_adapter.py +83 -79
- lionagi/ln/__init__.py +0 -4
- lionagi/ln/_json_dump.py +0 -6
- lionagi/operations/__init__.py +0 -6
- lionagi/operations/_visualize_graph.py +285 -0
- lionagi/operations/brainstorm/brainstorm.py +14 -12
- lionagi/operations/builder.py +23 -302
- lionagi/operations/flow.py +105 -11
- lionagi/operations/node.py +14 -3
- lionagi/operations/operate/operate.py +5 -11
- lionagi/operations/parse/parse.py +1 -2
- lionagi/operations/types.py +0 -2
- lionagi/operations/utils.py +11 -5
- lionagi/protocols/generic/pile.py +2 -6
- lionagi/protocols/graph/graph.py +23 -6
- lionagi/protocols/graph/node.py +27 -10
- lionagi/protocols/messages/message.py +0 -1
- lionagi/protocols/types.py +0 -15
- lionagi/service/__init__.py +19 -1
- lionagi/service/connections/api_calling.py +13 -4
- lionagi/service/connections/endpoint.py +11 -5
- lionagi/service/types.py +19 -1
- lionagi/session/branch.py +24 -18
- lionagi/session/session.py +44 -18
- lionagi/version.py +1 -1
- {lionagi-0.16.3.dist-info → lionagi-0.17.1.dist-info}/METADATA +5 -2
- {lionagi-0.16.3.dist-info → lionagi-0.17.1.dist-info}/RECORD +32 -32
- lionagi/protocols/graph/_utils.py +0 -22
- {lionagi-0.16.3.dist-info → lionagi-0.17.1.dist-info}/WHEEL +0 -0
- {lionagi-0.16.3.dist-info → lionagi-0.17.1.dist-info}/licenses/LICENSE +0 -0
@@ -3,7 +3,7 @@
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
4
4
|
|
5
5
|
import logging
|
6
|
-
from typing import Any, Literal
|
6
|
+
from typing import TYPE_CHECKING, Any, Literal
|
7
7
|
|
8
8
|
from pydantic import BaseModel
|
9
9
|
|
@@ -12,11 +12,13 @@ from lionagi.fields.instruct import (
|
|
12
12
|
Instruct,
|
13
13
|
InstructResponse,
|
14
14
|
)
|
15
|
-
from lionagi.ln import alcall
|
15
|
+
from lionagi.ln import alcall, to_list
|
16
16
|
from lionagi.protocols.generic.element import ID
|
17
|
-
|
18
|
-
|
19
|
-
from lionagi.
|
17
|
+
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
from lionagi.session.branch import Branch
|
20
|
+
from lionagi.session.session import Session
|
21
|
+
|
20
22
|
|
21
23
|
from ..utils import prepare_instruct, prepare_session
|
22
24
|
from .prompt import PROMPT
|
@@ -58,8 +60,8 @@ def chunked(iterable, n):
|
|
58
60
|
|
59
61
|
async def run_instruct(
|
60
62
|
ins: Instruct,
|
61
|
-
session: Session,
|
62
|
-
branch: Branch,
|
63
|
+
session: "Session",
|
64
|
+
branch: "Branch",
|
63
65
|
auto_run: bool,
|
64
66
|
verbose: bool = True,
|
65
67
|
**kwargs: Any,
|
@@ -113,8 +115,8 @@ async def run_instruct(
|
|
113
115
|
async def brainstorm(
|
114
116
|
instruct: Instruct | dict[str, Any],
|
115
117
|
num_instruct: int = 2,
|
116
|
-
session: Session
|
117
|
-
branch: Branch
|
118
|
+
session: "Session" = None,
|
119
|
+
branch: ID["Branch"].Ref | None = None,
|
118
120
|
auto_run: bool = True,
|
119
121
|
auto_explore: bool = False,
|
120
122
|
explore_kwargs: dict[str, Any] | None = None,
|
@@ -156,8 +158,8 @@ async def brainstorm(
|
|
156
158
|
async def brainstormStream(
|
157
159
|
instruct: Instruct | dict[str, Any],
|
158
160
|
num_instruct: int = 2,
|
159
|
-
session: Session
|
160
|
-
branch: Branch
|
161
|
+
session: "Session" = None,
|
162
|
+
branch: ID["Branch"].Ref | None = None,
|
161
163
|
auto_run: bool = True,
|
162
164
|
auto_explore: bool = False,
|
163
165
|
explore_kwargs: dict[str, Any] | None = None,
|
@@ -377,7 +379,7 @@ async def brainstormStream(
|
|
377
379
|
all_responses = []
|
378
380
|
|
379
381
|
async def explore_concurrent_chunk(
|
380
|
-
sub_instructs: list[Instruct], base_branch: Branch
|
382
|
+
sub_instructs: list[Instruct], base_branch: "Branch"
|
381
383
|
):
|
382
384
|
"""
|
383
385
|
Explore instructions in a single chunk concurrently.
|
lionagi/operations/builder.py
CHANGED
@@ -11,9 +11,8 @@ Build → Execute → Expand → Execute → ...
|
|
11
11
|
from enum import Enum
|
12
12
|
from typing import Any
|
13
13
|
|
14
|
-
from lionagi.operations.node import
|
14
|
+
from lionagi.operations.node import create_operation
|
15
15
|
from lionagi.protocols.graph.edge import Edge
|
16
|
-
from lionagi.protocols.graph.graph import Graph
|
17
16
|
from lionagi.protocols.types import ID
|
18
17
|
|
19
18
|
__all__ = (
|
@@ -62,18 +61,20 @@ class OperationGraphBuilder:
|
|
62
61
|
|
63
62
|
def __init__(self, name: str = "DynamicGraph"):
|
64
63
|
"""Initialize the incremental graph builder."""
|
64
|
+
from lionagi.protocols.graph.graph import Graph
|
65
|
+
|
65
66
|
self.name = name
|
66
67
|
self.graph = Graph()
|
67
68
|
|
68
69
|
# Track state
|
69
|
-
self._operations
|
70
|
+
self._operations = {} # All operations by ID
|
70
71
|
self._executed: set[str] = set() # IDs of executed operations
|
71
72
|
self._current_heads: list[str] = [] # Current head nodes for linking
|
72
73
|
self.last_operation_id: str | None = None
|
73
74
|
|
74
75
|
def add_operation(
|
75
76
|
self,
|
76
|
-
operation:
|
77
|
+
operation: str,
|
77
78
|
node_id: str | None = None,
|
78
79
|
depends_on: list[str] | None = None,
|
79
80
|
inherit_context: bool = False,
|
@@ -95,7 +96,7 @@ class OperationGraphBuilder:
|
|
95
96
|
ID of the created node
|
96
97
|
"""
|
97
98
|
# Create operation node
|
98
|
-
node =
|
99
|
+
node = create_operation(operation=operation, parameters=parameters)
|
99
100
|
|
100
101
|
# Store context inheritance strategy
|
101
102
|
if inherit_context and depends_on:
|
@@ -137,7 +138,7 @@ class OperationGraphBuilder:
|
|
137
138
|
self,
|
138
139
|
items: list[Any],
|
139
140
|
source_node_id: str,
|
140
|
-
operation:
|
141
|
+
operation: str,
|
141
142
|
strategy: ExpansionStrategy = ExpansionStrategy.CONCURRENT,
|
142
143
|
inherit_context: bool = False,
|
143
144
|
chain_context: bool = False,
|
@@ -179,7 +180,7 @@ class OperationGraphBuilder:
|
|
179
180
|
params["expanded_from"] = source_node_id
|
180
181
|
params["expansion_strategy"] = strategy.value
|
181
182
|
|
182
|
-
node =
|
183
|
+
node = create_operation(
|
183
184
|
operation=operation,
|
184
185
|
parameters=params,
|
185
186
|
metadata={
|
@@ -227,7 +228,7 @@ class OperationGraphBuilder:
|
|
227
228
|
|
228
229
|
def add_aggregation(
|
229
230
|
self,
|
230
|
-
operation:
|
231
|
+
operation: str,
|
231
232
|
node_id: str | None = None,
|
232
233
|
source_node_ids: list[str] | None = None,
|
233
234
|
inherit_context: bool = False,
|
@@ -262,7 +263,7 @@ class OperationGraphBuilder:
|
|
262
263
|
**parameters,
|
263
264
|
}
|
264
265
|
|
265
|
-
node =
|
266
|
+
node = create_operation(
|
266
267
|
operation=operation,
|
267
268
|
parameters=agg_params,
|
268
269
|
metadata={"aggregation": True},
|
@@ -308,7 +309,7 @@ class OperationGraphBuilder:
|
|
308
309
|
"""
|
309
310
|
self._executed.update(node_ids)
|
310
311
|
|
311
|
-
def get_unexecuted_nodes(self)
|
312
|
+
def get_unexecuted_nodes(self):
|
312
313
|
"""
|
313
314
|
Get nodes that haven't been executed yet.
|
314
315
|
|
@@ -323,9 +324,9 @@ class OperationGraphBuilder:
|
|
323
324
|
|
324
325
|
def add_conditional_branch(
|
325
326
|
self,
|
326
|
-
condition_check_op:
|
327
|
-
true_op:
|
328
|
-
false_op:
|
327
|
+
condition_check_op: str,
|
328
|
+
true_op: str,
|
329
|
+
false_op: str | None = None,
|
329
330
|
**check_params,
|
330
331
|
) -> dict[str, str]:
|
331
332
|
"""
|
@@ -341,7 +342,7 @@ class OperationGraphBuilder:
|
|
341
342
|
Dict with node IDs: {'check': id, 'true': id, 'false': id}
|
342
343
|
"""
|
343
344
|
# Add condition check node
|
344
|
-
check_node =
|
345
|
+
check_node = create_operation(
|
345
346
|
operation=condition_check_op,
|
346
347
|
parameters={**check_params, "is_condition_check": True},
|
347
348
|
)
|
@@ -358,7 +359,9 @@ class OperationGraphBuilder:
|
|
358
359
|
result = {"check": check_node.id}
|
359
360
|
|
360
361
|
# Add true branch
|
361
|
-
true_node =
|
362
|
+
true_node = create_operation(
|
363
|
+
operation=true_op, parameters={"branch": "true"}
|
364
|
+
)
|
362
365
|
self.graph.add_node(true_node)
|
363
366
|
self._operations[true_node.id] = true_node
|
364
367
|
result["true"] = true_node.id
|
@@ -371,7 +374,7 @@ class OperationGraphBuilder:
|
|
371
374
|
|
372
375
|
# Add false branch if specified
|
373
376
|
if false_op:
|
374
|
-
false_node =
|
377
|
+
false_node = create_operation(
|
375
378
|
operation=false_op, parameters={"branch": "false"}
|
376
379
|
)
|
377
380
|
self.graph.add_node(false_node)
|
@@ -389,7 +392,7 @@ class OperationGraphBuilder:
|
|
389
392
|
|
390
393
|
return result
|
391
394
|
|
392
|
-
def get_graph(self)
|
395
|
+
def get_graph(self):
|
393
396
|
"""
|
394
397
|
Get the current graph for execution.
|
395
398
|
|
@@ -398,7 +401,7 @@ class OperationGraphBuilder:
|
|
398
401
|
"""
|
399
402
|
return self.graph
|
400
403
|
|
401
|
-
def get_node_by_reference(self, reference_id: str)
|
404
|
+
def get_node_by_reference(self, reference_id: str):
|
402
405
|
"""
|
403
406
|
Get a node by its reference ID.
|
404
407
|
|
@@ -440,292 +443,10 @@ class OperationGraphBuilder:
|
|
440
443
|
}
|
441
444
|
|
442
445
|
def visualize(self, title: str = "Operation Graph", figsize=(14, 10)):
|
446
|
+
from ._visualize_graph import visualize_graph
|
447
|
+
|
443
448
|
visualize_graph(
|
444
449
|
self,
|
445
450
|
title=title,
|
446
451
|
figsize=figsize,
|
447
452
|
)
|
448
|
-
|
449
|
-
|
450
|
-
def visualize_graph(
|
451
|
-
builder: OperationGraphBuilder,
|
452
|
-
title: str = "Operation Graph",
|
453
|
-
figsize=(14, 10),
|
454
|
-
):
|
455
|
-
"""Visualization with improved layout for complex graphs."""
|
456
|
-
from lionagi.protocols.graph.graph import (
|
457
|
-
_MATPLIB_AVAILABLE,
|
458
|
-
_NETWORKX_AVAILABLE,
|
459
|
-
)
|
460
|
-
|
461
|
-
if _MATPLIB_AVAILABLE is not True:
|
462
|
-
raise _MATPLIB_AVAILABLE
|
463
|
-
if _NETWORKX_AVAILABLE is not True:
|
464
|
-
raise _NETWORKX_AVAILABLE
|
465
|
-
|
466
|
-
import matplotlib.pyplot as plt
|
467
|
-
import networkx as nx
|
468
|
-
import numpy as np
|
469
|
-
|
470
|
-
graph = builder.get_graph()
|
471
|
-
|
472
|
-
# Convert to networkx
|
473
|
-
G = nx.DiGraph()
|
474
|
-
|
475
|
-
# Track node positions for hierarchical layout
|
476
|
-
node_levels = {}
|
477
|
-
node_labels = {}
|
478
|
-
node_colors = []
|
479
|
-
node_sizes = []
|
480
|
-
|
481
|
-
# First pass: add nodes and determine levels
|
482
|
-
for node in graph.internal_nodes.values():
|
483
|
-
node_id = str(node.id)[:8]
|
484
|
-
G.add_node(node_id)
|
485
|
-
|
486
|
-
# Determine level based on dependencies
|
487
|
-
in_edges = [
|
488
|
-
e
|
489
|
-
for e in graph.internal_edges.values()
|
490
|
-
if str(e.tail)[:8] == node_id
|
491
|
-
]
|
492
|
-
if not in_edges:
|
493
|
-
level = 0 # Root nodes
|
494
|
-
else:
|
495
|
-
# Get max level of predecessors + 1
|
496
|
-
pred_levels = []
|
497
|
-
for edge in in_edges:
|
498
|
-
pred_id = str(edge.head)[:8]
|
499
|
-
if pred_id in node_levels:
|
500
|
-
pred_levels.append(node_levels[pred_id])
|
501
|
-
level = max(pred_levels, default=0) + 1
|
502
|
-
|
503
|
-
node_levels[node_id] = level
|
504
|
-
|
505
|
-
# Create label
|
506
|
-
ref_id = node.metadata.get("reference_id", "")
|
507
|
-
if ref_id:
|
508
|
-
label = f"{node.operation}\n[{ref_id}]"
|
509
|
-
else:
|
510
|
-
label = f"{node.operation}\n{node_id}"
|
511
|
-
node_labels[node_id] = label
|
512
|
-
|
513
|
-
# Color and size based on status and type
|
514
|
-
if node.id in builder._executed:
|
515
|
-
node_colors.append("#90EE90") # Light green
|
516
|
-
node_sizes.append(4000)
|
517
|
-
elif node.metadata.get("expansion_source"):
|
518
|
-
node_colors.append("#87CEEB") # Sky blue
|
519
|
-
node_sizes.append(3500)
|
520
|
-
elif node.metadata.get("aggregation"):
|
521
|
-
node_colors.append("#FFD700") # Gold
|
522
|
-
node_sizes.append(4500)
|
523
|
-
elif node.metadata.get("is_condition_check"):
|
524
|
-
node_colors.append("#DDA0DD") # Plum
|
525
|
-
node_sizes.append(3500)
|
526
|
-
else:
|
527
|
-
node_colors.append("#E0E0E0") # Light gray
|
528
|
-
node_sizes.append(3000)
|
529
|
-
|
530
|
-
# Add edges
|
531
|
-
edge_colors = []
|
532
|
-
edge_styles = []
|
533
|
-
edge_widths = []
|
534
|
-
edge_labels = {}
|
535
|
-
|
536
|
-
for edge in graph.internal_edges.values():
|
537
|
-
head_id = str(edge.head)[:8]
|
538
|
-
tail_id = str(edge.tail)[:8]
|
539
|
-
G.add_edge(head_id, tail_id)
|
540
|
-
|
541
|
-
# Style edges based on type
|
542
|
-
edge_label = edge.label[0] if edge.label else ""
|
543
|
-
edge_labels[(head_id, tail_id)] = edge_label
|
544
|
-
|
545
|
-
if "expansion" in edge_label:
|
546
|
-
edge_colors.append("#4169E1") # Royal blue
|
547
|
-
edge_styles.append("dashed")
|
548
|
-
edge_widths.append(2)
|
549
|
-
elif "aggregate" in edge_label:
|
550
|
-
edge_colors.append("#FF6347") # Tomato
|
551
|
-
edge_styles.append("dotted")
|
552
|
-
edge_widths.append(2.5)
|
553
|
-
else:
|
554
|
-
edge_colors.append("#808080") # Gray
|
555
|
-
edge_styles.append("solid")
|
556
|
-
edge_widths.append(1.5)
|
557
|
-
|
558
|
-
# Create improved hierarchical layout
|
559
|
-
pos = {}
|
560
|
-
nodes_by_level = {}
|
561
|
-
|
562
|
-
for node_id, level in node_levels.items():
|
563
|
-
if level not in nodes_by_level:
|
564
|
-
nodes_by_level[level] = []
|
565
|
-
nodes_by_level[level].append(node_id)
|
566
|
-
|
567
|
-
# Position nodes with better spacing algorithm
|
568
|
-
y_spacing = 2.5
|
569
|
-
max_width = 16 # Maximum horizontal spread
|
570
|
-
|
571
|
-
for level, nodes in nodes_by_level.items():
|
572
|
-
num_nodes = len(nodes)
|
573
|
-
|
574
|
-
if num_nodes <= 6:
|
575
|
-
# Normal spacing for small levels
|
576
|
-
x_spacing = 2.5
|
577
|
-
x_offset = -(num_nodes - 1) * x_spacing / 2
|
578
|
-
for i, node_id in enumerate(nodes):
|
579
|
-
pos[node_id] = (x_offset + i * x_spacing, -level * y_spacing)
|
580
|
-
else:
|
581
|
-
# Multi-row layout for large levels
|
582
|
-
nodes_per_row = min(6, int(np.ceil(np.sqrt(num_nodes * 1.5))))
|
583
|
-
rows = int(np.ceil(num_nodes / nodes_per_row))
|
584
|
-
|
585
|
-
for i, node_id in enumerate(nodes):
|
586
|
-
row = i // nodes_per_row
|
587
|
-
col = i % nodes_per_row
|
588
|
-
|
589
|
-
# Calculate row width
|
590
|
-
nodes_in_row = min(
|
591
|
-
nodes_per_row, num_nodes - row * nodes_per_row
|
592
|
-
)
|
593
|
-
x_spacing = 2.5
|
594
|
-
x_offset = -(nodes_in_row - 1) * x_spacing / 2
|
595
|
-
|
596
|
-
# Add slight y offset for different rows
|
597
|
-
y_offset = row * 0.8
|
598
|
-
|
599
|
-
pos[node_id] = (
|
600
|
-
x_offset + col * x_spacing,
|
601
|
-
-level * y_spacing - y_offset,
|
602
|
-
)
|
603
|
-
|
604
|
-
# Create figure
|
605
|
-
plt.figure(figsize=figsize)
|
606
|
-
|
607
|
-
# Draw nodes
|
608
|
-
nx.draw_networkx_nodes(
|
609
|
-
G,
|
610
|
-
pos,
|
611
|
-
node_color=node_colors,
|
612
|
-
node_size=node_sizes,
|
613
|
-
alpha=0.9,
|
614
|
-
linewidths=2,
|
615
|
-
edgecolors="black",
|
616
|
-
)
|
617
|
-
|
618
|
-
# Draw edges with different styles - use curved edges for better visibility
|
619
|
-
for i, (u, v) in enumerate(G.edges()):
|
620
|
-
# Calculate curve based on node positions
|
621
|
-
u_pos = pos[u]
|
622
|
-
v_pos = pos[v]
|
623
|
-
|
624
|
-
# Determine connection style based on relative positions
|
625
|
-
if abs(u_pos[0] - v_pos[0]) > 5: # Far apart horizontally
|
626
|
-
connectionstyle = "arc3,rad=0.2"
|
627
|
-
else:
|
628
|
-
connectionstyle = "arc3,rad=0.1"
|
629
|
-
|
630
|
-
nx.draw_networkx_edges(
|
631
|
-
G,
|
632
|
-
pos,
|
633
|
-
[(u, v)],
|
634
|
-
edge_color=[edge_colors[i]],
|
635
|
-
style=edge_styles[i],
|
636
|
-
width=edge_widths[i],
|
637
|
-
alpha=0.7,
|
638
|
-
arrows=True,
|
639
|
-
arrowsize=20,
|
640
|
-
arrowstyle="-|>",
|
641
|
-
connectionstyle=connectionstyle,
|
642
|
-
)
|
643
|
-
|
644
|
-
# Draw labels
|
645
|
-
nx.draw_networkx_labels(
|
646
|
-
G,
|
647
|
-
pos,
|
648
|
-
node_labels,
|
649
|
-
font_size=9,
|
650
|
-
font_weight="bold",
|
651
|
-
font_family="monospace",
|
652
|
-
)
|
653
|
-
|
654
|
-
# Draw edge labels (only for smaller graphs)
|
655
|
-
if len(G.edges()) < 20:
|
656
|
-
nx.draw_networkx_edge_labels(
|
657
|
-
G,
|
658
|
-
pos,
|
659
|
-
edge_labels,
|
660
|
-
font_size=7,
|
661
|
-
font_color="darkblue",
|
662
|
-
bbox=dict(
|
663
|
-
boxstyle="round,pad=0.3",
|
664
|
-
facecolor="white",
|
665
|
-
edgecolor="none",
|
666
|
-
alpha=0.7,
|
667
|
-
),
|
668
|
-
)
|
669
|
-
|
670
|
-
plt.title(title, fontsize=18, fontweight="bold", pad=20)
|
671
|
-
plt.axis("off")
|
672
|
-
|
673
|
-
# Enhanced legend
|
674
|
-
from matplotlib.lines import Line2D
|
675
|
-
from matplotlib.patches import Patch, Rectangle
|
676
|
-
|
677
|
-
legend_elements = [
|
678
|
-
Patch(facecolor="#90EE90", edgecolor="black", label="Executed"),
|
679
|
-
Patch(facecolor="#87CEEB", edgecolor="black", label="Expanded"),
|
680
|
-
Patch(facecolor="#FFD700", edgecolor="black", label="Aggregation"),
|
681
|
-
Patch(facecolor="#DDA0DD", edgecolor="black", label="Condition"),
|
682
|
-
Patch(facecolor="#E0E0E0", edgecolor="black", label="Pending"),
|
683
|
-
Line2D([0], [0], color="#808080", linewidth=2, label="Sequential"),
|
684
|
-
Line2D(
|
685
|
-
[0],
|
686
|
-
[0],
|
687
|
-
color="#4169E1",
|
688
|
-
linewidth=2,
|
689
|
-
linestyle="dashed",
|
690
|
-
label="Expansion",
|
691
|
-
),
|
692
|
-
Line2D(
|
693
|
-
[0],
|
694
|
-
[0],
|
695
|
-
color="#FF6347",
|
696
|
-
linewidth=2,
|
697
|
-
linestyle="dotted",
|
698
|
-
label="Aggregate",
|
699
|
-
),
|
700
|
-
]
|
701
|
-
|
702
|
-
plt.legend(
|
703
|
-
handles=legend_elements,
|
704
|
-
loc="upper left",
|
705
|
-
bbox_to_anchor=(0, 1),
|
706
|
-
frameon=True,
|
707
|
-
fancybox=True,
|
708
|
-
shadow=True,
|
709
|
-
ncol=2,
|
710
|
-
)
|
711
|
-
|
712
|
-
# Add statistics box
|
713
|
-
stats_text = f"Nodes: {len(G.nodes())}\nEdges: {len(G.edges())}\nExecuted: {len(builder._executed)}"
|
714
|
-
if nodes_by_level:
|
715
|
-
max_level = max(nodes_by_level.keys())
|
716
|
-
stats_text += f"\nLevels: {max_level + 1}"
|
717
|
-
|
718
|
-
plt.text(
|
719
|
-
0.98,
|
720
|
-
0.02,
|
721
|
-
stats_text,
|
722
|
-
transform=plt.gca().transAxes,
|
723
|
-
bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.8),
|
724
|
-
verticalalignment="bottom",
|
725
|
-
horizontalalignment="right",
|
726
|
-
fontsize=10,
|
727
|
-
fontfamily="monospace",
|
728
|
-
)
|
729
|
-
|
730
|
-
plt.tight_layout()
|
731
|
-
plt.show()
|
lionagi/operations/flow.py
CHANGED
@@ -10,16 +10,19 @@ using Events for synchronization and CapacityLimiter for concurrency control.
|
|
10
10
|
"""
|
11
11
|
|
12
12
|
import os
|
13
|
-
from typing import Any
|
13
|
+
from typing import TYPE_CHECKING, Any
|
14
14
|
|
15
15
|
from lionagi.ln._async_call import AlcallParams
|
16
16
|
from lionagi.ln.concurrency import CapacityLimiter, ConcurrencyEvent
|
17
17
|
from lionagi.operations.node import Operation
|
18
|
-
from lionagi.protocols.types import EventStatus
|
19
|
-
from lionagi.session.branch import Branch
|
20
|
-
from lionagi.session.session import Session
|
18
|
+
from lionagi.protocols.types import EventStatus
|
21
19
|
from lionagi.utils import to_dict
|
22
20
|
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
from lionagi.protocols.graph.graph import Graph
|
23
|
+
from lionagi.session.session import Branch, Session
|
24
|
+
|
25
|
+
|
23
26
|
# Maximum concurrency when None is specified (effectively unlimited)
|
24
27
|
UNLIMITED_CONCURRENCY = int(os.environ.get("LIONAGI_MAX_CONCURRENCY", "10000"))
|
25
28
|
|
@@ -29,12 +32,12 @@ class DependencyAwareExecutor:
|
|
29
32
|
|
30
33
|
def __init__(
|
31
34
|
self,
|
32
|
-
session: Session,
|
33
|
-
graph: Graph,
|
35
|
+
session: "Session",
|
36
|
+
graph: "Graph",
|
34
37
|
context: dict[str, Any] | None = None,
|
35
38
|
max_concurrent: int = 5,
|
36
39
|
verbose: bool = False,
|
37
|
-
default_branch: Branch
|
40
|
+
default_branch: "Branch" = None,
|
38
41
|
alcall_params: AlcallParams | None = None,
|
39
42
|
):
|
40
43
|
"""Initialize the executor.
|
@@ -402,7 +405,7 @@ class DependencyAwareExecutor:
|
|
402
405
|
branch = self._resolve_branch_for_operation(operation)
|
403
406
|
self.operation_branches[operation.id] = branch
|
404
407
|
|
405
|
-
def _resolve_branch_for_operation(self, operation: Operation) -> Branch:
|
408
|
+
def _resolve_branch_for_operation(self, operation: Operation) -> "Branch":
|
406
409
|
"""Resolve which branch an operation should use - all branches are pre-allocated."""
|
407
410
|
# All branches should be pre-allocated
|
408
411
|
if operation.id in self.operation_branches:
|
@@ -501,10 +504,10 @@ class DependencyAwareExecutor:
|
|
501
504
|
|
502
505
|
|
503
506
|
async def flow(
|
504
|
-
session: Session,
|
505
|
-
graph: Graph,
|
507
|
+
session: "Session",
|
508
|
+
graph: "Graph",
|
506
509
|
*,
|
507
|
-
branch: Branch
|
510
|
+
branch: "Branch" = None,
|
508
511
|
context: dict[str, Any] | None = None,
|
509
512
|
parallel: bool = True,
|
510
513
|
max_concurrent: int = None,
|
@@ -546,3 +549,94 @@ async def flow(
|
|
546
549
|
)
|
547
550
|
|
548
551
|
return await executor.execute()
|
552
|
+
|
553
|
+
|
554
|
+
def cleanup_flow_results(
|
555
|
+
result: dict[str, Any], keep_only: list[str] = None
|
556
|
+
) -> dict[str, Any]:
|
557
|
+
"""
|
558
|
+
Clean up flow execution results to reduce memory usage.
|
559
|
+
|
560
|
+
Args:
|
561
|
+
result: Flow execution result dictionary
|
562
|
+
keep_only: List of operation IDs to keep results for (optional)
|
563
|
+
|
564
|
+
Returns:
|
565
|
+
Modified result dictionary with reduced memory footprint
|
566
|
+
"""
|
567
|
+
if not isinstance(result, dict) or "operation_results" not in result:
|
568
|
+
return result
|
569
|
+
|
570
|
+
# If keep_only is specified, only keep those results
|
571
|
+
if keep_only is not None:
|
572
|
+
filtered_results = {
|
573
|
+
op_id: res
|
574
|
+
for op_id, res in result["operation_results"].items()
|
575
|
+
if op_id in keep_only
|
576
|
+
}
|
577
|
+
result["operation_results"] = filtered_results
|
578
|
+
# Update completed_operations to match
|
579
|
+
result["completed_operations"] = [
|
580
|
+
op_id
|
581
|
+
for op_id in result.get("completed_operations", [])
|
582
|
+
if op_id in keep_only
|
583
|
+
]
|
584
|
+
else:
|
585
|
+
# Clear all results to free memory
|
586
|
+
result["operation_results"] = {}
|
587
|
+
result["completed_operations"] = []
|
588
|
+
|
589
|
+
return result
|
590
|
+
|
591
|
+
|
592
|
+
async def flow_with_cleanup(
|
593
|
+
session: "Session",
|
594
|
+
graph: "Graph",
|
595
|
+
context: dict[str, Any] | None = None,
|
596
|
+
parallel: bool = True,
|
597
|
+
max_concurrent: int = 5,
|
598
|
+
verbose: bool = False,
|
599
|
+
branch: "Branch" = None,
|
600
|
+
alcall_params: AlcallParams | None = None,
|
601
|
+
cleanup_results: bool = True,
|
602
|
+
keep_only: list[str] = None,
|
603
|
+
) -> dict[str, Any]:
|
604
|
+
"""
|
605
|
+
Execute flow with automatic cleanup to prevent memory accumulation.
|
606
|
+
|
607
|
+
Args:
|
608
|
+
session: Session instance for branch management
|
609
|
+
graph: Operation graph to execute
|
610
|
+
context: Initial context data
|
611
|
+
parallel: Execute independent operations in parallel
|
612
|
+
max_concurrent: Max concurrent operations (1 if not parallel)
|
613
|
+
verbose: Enable verbose logging
|
614
|
+
branch: Default branch for operations
|
615
|
+
alcall_params: Parameters for async parallel call execution
|
616
|
+
cleanup_results: Whether to clean up operation results after execution
|
617
|
+
keep_only: List of operation IDs to keep results for (if cleanup_results=True)
|
618
|
+
|
619
|
+
Returns:
|
620
|
+
Execution results (potentially with cleaned up memory footprint)
|
621
|
+
"""
|
622
|
+
# Execute the flow normally
|
623
|
+
result = await flow(
|
624
|
+
session=session,
|
625
|
+
graph=graph,
|
626
|
+
context=context,
|
627
|
+
parallel=parallel,
|
628
|
+
max_concurrent=max_concurrent,
|
629
|
+
verbose=verbose,
|
630
|
+
branch=branch,
|
631
|
+
alcall_params=alcall_params,
|
632
|
+
)
|
633
|
+
|
634
|
+
# Clean up session memory
|
635
|
+
if hasattr(session, "cleanup_memory"):
|
636
|
+
session.cleanup_memory()
|
637
|
+
|
638
|
+
# Clean up results if requested
|
639
|
+
if cleanup_results:
|
640
|
+
result = cleanup_flow_results(result, keep_only=keep_only)
|
641
|
+
|
642
|
+
return result
|
lionagi/operations/node.py
CHANGED
@@ -1,13 +1,15 @@
|
|
1
1
|
import asyncio
|
2
2
|
import logging
|
3
|
-
from typing import Any, Literal
|
3
|
+
from typing import TYPE_CHECKING, Any, Literal
|
4
4
|
from uuid import UUID
|
5
5
|
|
6
6
|
from anyio import get_cancelled_exc_class
|
7
7
|
from pydantic import BaseModel, Field
|
8
8
|
|
9
9
|
from lionagi.protocols.types import ID, Event, EventStatus, IDType, Node
|
10
|
-
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from lionagi.session.branch import Branch
|
11
13
|
|
12
14
|
BranchOperations = Literal[
|
13
15
|
"chat",
|
@@ -74,7 +76,7 @@ class Operation(Node, Event):
|
|
74
76
|
"""Get the response from the execution."""
|
75
77
|
return self.execution.response if self.execution else None
|
76
78
|
|
77
|
-
async def invoke(self, branch: Branch):
|
79
|
+
async def invoke(self, branch: "Branch"):
|
78
80
|
meth = branch.get_operation(self.operation)
|
79
81
|
if meth is None:
|
80
82
|
raise ValueError(f"Unsupported operation type: {self.operation}")
|
@@ -108,3 +110,12 @@ class Operation(Node, Event):
|
|
108
110
|
res.append(i)
|
109
111
|
return res
|
110
112
|
return await meth(**self.request)
|
113
|
+
|
114
|
+
|
115
|
+
def create_operation(
|
116
|
+
operation: BranchOperations | str,
|
117
|
+
parameters: dict[str, Any] | BaseModel = None,
|
118
|
+
**kwargs,
|
119
|
+
):
|
120
|
+
"""Create an Operation node."""
|
121
|
+
return Operation(operation=operation, parameters=parameters, **kwargs)
|