lionagi 0.14.4__py3-none-any.whl → 0.14.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/fields/instruct.py +3 -17
- lionagi/libs/concurrency/__init__.py +25 -1
- lionagi/libs/concurrency/cancel.py +1 -1
- lionagi/libs/concurrency/patterns.py +145 -138
- lionagi/libs/concurrency/primitives.py +145 -97
- lionagi/libs/concurrency/resource_tracker.py +182 -0
- lionagi/libs/concurrency/task.py +4 -2
- lionagi/operations/builder.py +9 -0
- lionagi/operations/flow.py +163 -60
- lionagi/protocols/generic/pile.py +7 -10
- lionagi/protocols/generic/processor.py +53 -26
- lionagi/service/connections/providers/_claude_code/__init__.py +3 -0
- lionagi/service/connections/providers/_claude_code/models.py +235 -0
- lionagi/service/connections/providers/_claude_code/stream_cli.py +350 -0
- lionagi/service/connections/providers/claude_code_.py +13 -223
- lionagi/service/connections/providers/claude_code_cli.py +38 -343
- lionagi/service/rate_limited_processor.py +53 -35
- lionagi/session/branch.py +6 -51
- lionagi/session/session.py +26 -8
- lionagi/utils.py +56 -174
- lionagi/version.py +1 -1
- {lionagi-0.14.4.dist-info → lionagi-0.14.6.dist-info}/METADATA +6 -2
- {lionagi-0.14.4.dist-info → lionagi-0.14.6.dist-info}/RECORD +25 -21
- {lionagi-0.14.4.dist-info → lionagi-0.14.6.dist-info}/WHEEL +0 -0
- {lionagi-0.14.4.dist-info → lionagi-0.14.6.dist-info}/licenses/LICENSE +0 -0
@@ -1,28 +1,36 @@
|
|
1
|
-
"""Resource management primitives for structured concurrency.
|
1
|
+
"""Resource management primitives for structured concurrency.
|
2
2
|
|
3
|
+
Pure async primitives focused on correctness and simplicity.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import math
|
3
7
|
from types import TracebackType
|
4
|
-
from typing import Optional
|
5
8
|
|
6
9
|
import anyio
|
7
10
|
|
11
|
+
from .resource_tracker import track_resource, untrack_resource
|
8
12
|
|
9
|
-
class Lock:
|
10
|
-
"""A mutex lock for controlling access to a shared resource.
|
11
13
|
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
14
|
+
class Lock:
|
15
|
+
"""A mutex lock for controlling access to a shared resource."""
|
15
16
|
|
16
17
|
def __init__(self):
|
17
18
|
"""Initialize a new lock."""
|
18
19
|
self._lock = anyio.Lock()
|
20
|
+
self._acquired = False
|
21
|
+
track_resource(self, f"Lock-{id(self)}", "Lock")
|
19
22
|
|
20
|
-
|
21
|
-
"""
|
23
|
+
def __del__(self):
|
24
|
+
"""Clean up resource tracking when lock is destroyed."""
|
25
|
+
try:
|
26
|
+
untrack_resource(self)
|
27
|
+
except Exception:
|
28
|
+
pass
|
22
29
|
|
23
|
-
|
24
|
-
"""
|
25
|
-
await self.acquire()
|
30
|
+
async def __aenter__(self) -> None:
|
31
|
+
"""Acquire the lock."""
|
32
|
+
await self._lock.acquire()
|
33
|
+
self._acquired = True
|
26
34
|
|
27
35
|
async def __aexit__(
|
28
36
|
self,
|
@@ -31,44 +39,45 @@ class Lock:
|
|
31
39
|
exc_tb: TracebackType | None,
|
32
40
|
) -> None:
|
33
41
|
"""Release the lock."""
|
34
|
-
self.release()
|
35
|
-
|
36
|
-
async def acquire(self) -> bool:
|
37
|
-
"""Acquire the lock.
|
42
|
+
self._lock.release()
|
43
|
+
self._acquired = False
|
38
44
|
|
39
|
-
|
40
|
-
|
41
|
-
"""
|
45
|
+
async def acquire(self) -> None:
|
46
|
+
"""Acquire the lock directly."""
|
42
47
|
await self._lock.acquire()
|
43
|
-
|
48
|
+
self._acquired = True
|
44
49
|
|
45
50
|
def release(self) -> None:
|
46
|
-
"""Release the lock.
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
+
"""Release the lock directly."""
|
52
|
+
if not self._acquired:
|
53
|
+
raise RuntimeError(
|
54
|
+
"Attempted to release lock that was not acquired by this task"
|
55
|
+
)
|
51
56
|
self._lock.release()
|
57
|
+
self._acquired = False
|
52
58
|
|
53
59
|
|
54
60
|
class Semaphore:
|
55
|
-
"""A semaphore
|
61
|
+
"""A semaphore preventing excessive releases."""
|
56
62
|
|
57
63
|
def __init__(self, initial_value: int):
|
58
|
-
"""Initialize a new semaphore.
|
59
|
-
|
60
|
-
Args:
|
61
|
-
initial_value: The initial value of the semaphore (must be >= 0)
|
62
|
-
"""
|
64
|
+
"""Initialize a new semaphore."""
|
63
65
|
if initial_value < 0:
|
64
66
|
raise ValueError("The initial value must be >= 0")
|
67
|
+
self._initial_value = initial_value
|
68
|
+
self._current_acquisitions = 0
|
65
69
|
self._semaphore = anyio.Semaphore(initial_value)
|
70
|
+
track_resource(self, f"Semaphore-{id(self)}", "Semaphore")
|
66
71
|
|
67
|
-
|
68
|
-
"""
|
72
|
+
def __del__(self):
|
73
|
+
"""Clean up resource tracking when semaphore is destroyed."""
|
74
|
+
try:
|
75
|
+
untrack_resource(self)
|
76
|
+
except Exception:
|
77
|
+
pass
|
69
78
|
|
70
|
-
|
71
|
-
"""
|
79
|
+
async def __aenter__(self) -> None:
|
80
|
+
"""Acquire the semaphore."""
|
72
81
|
await self.acquire()
|
73
82
|
|
74
83
|
async def __aexit__(
|
@@ -81,35 +90,60 @@ class Semaphore:
|
|
81
90
|
self.release()
|
82
91
|
|
83
92
|
async def acquire(self) -> None:
|
84
|
-
"""Acquire the semaphore.
|
85
|
-
|
86
|
-
If the semaphore value is zero, this will wait until it's released.
|
87
|
-
"""
|
93
|
+
"""Acquire the semaphore."""
|
88
94
|
await self._semaphore.acquire()
|
95
|
+
self._current_acquisitions += 1
|
89
96
|
|
90
97
|
def release(self) -> None:
|
91
|
-
"""Release the semaphore
|
98
|
+
"""Release the semaphore."""
|
99
|
+
if self._current_acquisitions <= 0:
|
100
|
+
raise RuntimeError(
|
101
|
+
"Cannot release semaphore: no outstanding acquisitions"
|
102
|
+
)
|
92
103
|
self._semaphore.release()
|
104
|
+
self._current_acquisitions -= 1
|
105
|
+
|
106
|
+
@property
|
107
|
+
def current_acquisitions(self) -> int:
|
108
|
+
"""Get the current number of outstanding acquisitions."""
|
109
|
+
return self._current_acquisitions
|
110
|
+
|
111
|
+
@property
|
112
|
+
def initial_value(self) -> int:
|
113
|
+
"""Get the initial semaphore value."""
|
114
|
+
return self._initial_value
|
93
115
|
|
94
116
|
|
95
117
|
class CapacityLimiter:
|
96
118
|
"""A context manager for limiting the number of concurrent operations."""
|
97
119
|
|
98
|
-
def __init__(self, total_tokens: float):
|
99
|
-
"""Initialize a new capacity limiter.
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
120
|
+
def __init__(self, total_tokens: int | float):
|
121
|
+
"""Initialize a new capacity limiter."""
|
122
|
+
if total_tokens == math.inf:
|
123
|
+
processed_tokens = math.inf
|
124
|
+
elif isinstance(total_tokens, (int, float)) and total_tokens >= 1:
|
125
|
+
processed_tokens = (
|
126
|
+
int(total_tokens) if total_tokens != math.inf else math.inf
|
127
|
+
)
|
128
|
+
else:
|
129
|
+
raise ValueError(
|
130
|
+
"The total number of tokens must be >= 1 (int or math.inf)"
|
131
|
+
)
|
132
|
+
|
133
|
+
self._limiter = anyio.CapacityLimiter(processed_tokens)
|
134
|
+
self._borrower_counter = 0
|
135
|
+
self._active_borrowers = {}
|
136
|
+
track_resource(self, f"CapacityLimiter-{id(self)}", "CapacityLimiter")
|
137
|
+
|
138
|
+
def __del__(self):
|
139
|
+
"""Clean up resource tracking when limiter is destroyed."""
|
140
|
+
try:
|
141
|
+
untrack_resource(self)
|
142
|
+
except Exception:
|
143
|
+
pass
|
107
144
|
|
108
145
|
async def __aenter__(self) -> None:
|
109
|
-
"""Acquire a token.
|
110
|
-
|
111
|
-
If no tokens are available, this will wait until one is released.
|
112
|
-
"""
|
146
|
+
"""Acquire a token."""
|
113
147
|
await self.acquire()
|
114
148
|
|
115
149
|
async def __aexit__(
|
@@ -122,35 +156,49 @@ class CapacityLimiter:
|
|
122
156
|
self.release()
|
123
157
|
|
124
158
|
async def acquire(self) -> None:
|
125
|
-
"""Acquire a token.
|
126
|
-
|
127
|
-
|
128
|
-
""
|
129
|
-
await self._limiter.
|
159
|
+
"""Acquire a token."""
|
160
|
+
# Create a unique borrower identity for each acquisition
|
161
|
+
self._borrower_counter += 1
|
162
|
+
borrower = f"borrower-{self._borrower_counter}"
|
163
|
+
await self._limiter.acquire_on_behalf_of(borrower)
|
164
|
+
self._active_borrowers[borrower] = True
|
130
165
|
|
131
166
|
def release(self) -> None:
|
132
|
-
"""Release a token.
|
167
|
+
"""Release a token."""
|
168
|
+
# Find and release the first active borrower
|
169
|
+
if not self._active_borrowers:
|
170
|
+
raise RuntimeError("No tokens to release")
|
133
171
|
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
self._limiter.release()
|
172
|
+
borrower = next(iter(self._active_borrowers))
|
173
|
+
self._limiter.release_on_behalf_of(borrower)
|
174
|
+
del self._active_borrowers[borrower]
|
138
175
|
|
139
176
|
@property
|
140
|
-
def total_tokens(self) -> float:
|
177
|
+
def total_tokens(self) -> int | float:
|
141
178
|
"""The total number of tokens."""
|
142
|
-
return self._limiter.total_tokens
|
179
|
+
return float(self._limiter.total_tokens)
|
143
180
|
|
144
181
|
@total_tokens.setter
|
145
|
-
def total_tokens(self, value: float) -> None:
|
146
|
-
"""Set the total number of tokens.
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
raise ValueError(
|
153
|
-
|
182
|
+
def total_tokens(self, value: int | float) -> None:
|
183
|
+
"""Set the total number of tokens."""
|
184
|
+
if value == math.inf:
|
185
|
+
processed_value = math.inf
|
186
|
+
elif isinstance(value, (int, float)) and value >= 1:
|
187
|
+
processed_value = int(value) if value != math.inf else math.inf
|
188
|
+
else:
|
189
|
+
raise ValueError(
|
190
|
+
"The total number of tokens must be >= 1 (int or math.inf)"
|
191
|
+
)
|
192
|
+
|
193
|
+
current_borrowed = self._limiter.borrowed_tokens
|
194
|
+
if processed_value != math.inf and processed_value < current_borrowed:
|
195
|
+
raise ValueError(
|
196
|
+
f"Cannot set total_tokens to {processed_value}: {current_borrowed} tokens "
|
197
|
+
f"are currently borrowed. Wait for tokens to be released or "
|
198
|
+
f"set total_tokens to at least {current_borrowed}."
|
199
|
+
)
|
200
|
+
|
201
|
+
self._limiter.total_tokens = processed_value
|
154
202
|
|
155
203
|
@property
|
156
204
|
def borrowed_tokens(self) -> int:
|
@@ -158,28 +206,28 @@ class CapacityLimiter:
|
|
158
206
|
return self._limiter.borrowed_tokens
|
159
207
|
|
160
208
|
@property
|
161
|
-
def available_tokens(self) -> float:
|
209
|
+
def available_tokens(self) -> int | float:
|
162
210
|
"""The number of tokens currently available."""
|
163
211
|
return self._limiter.available_tokens
|
164
212
|
|
165
213
|
|
166
214
|
class Event:
|
167
|
-
"""An event object for task synchronization.
|
168
|
-
|
169
|
-
An event can be in one of two states: set or unset. When set, tasks waiting
|
170
|
-
on the event are allowed to proceed.
|
171
|
-
"""
|
215
|
+
"""An event object for task synchronization."""
|
172
216
|
|
173
217
|
def __init__(self):
|
174
218
|
"""Initialize a new event in the unset state."""
|
175
219
|
self._event = anyio.Event()
|
220
|
+
track_resource(self, f"Event-{id(self)}", "Event")
|
176
221
|
|
177
|
-
def
|
178
|
-
"""
|
222
|
+
def __del__(self):
|
223
|
+
"""Clean up resource tracking when event is destroyed."""
|
224
|
+
try:
|
225
|
+
untrack_resource(self)
|
226
|
+
except Exception:
|
227
|
+
pass
|
179
228
|
|
180
|
-
|
181
|
-
|
182
|
-
"""
|
229
|
+
def is_set(self) -> bool:
|
230
|
+
"""Check if the event is set."""
|
183
231
|
return self._event.is_set()
|
184
232
|
|
185
233
|
def set(self) -> None:
|
@@ -195,21 +243,21 @@ class Condition:
|
|
195
243
|
"""A condition variable for task synchronization."""
|
196
244
|
|
197
245
|
def __init__(self, lock: Lock | None = None):
|
198
|
-
"""Initialize a new condition.
|
199
|
-
|
200
|
-
Args:
|
201
|
-
lock: The lock to use, or None to create a new one
|
202
|
-
"""
|
246
|
+
"""Initialize a new condition."""
|
203
247
|
self._lock = lock or Lock()
|
204
248
|
self._condition = anyio.Condition(self._lock._lock)
|
249
|
+
track_resource(self, f"Condition-{id(self)}", "Condition")
|
205
250
|
|
206
|
-
|
207
|
-
"""
|
251
|
+
def __del__(self):
|
252
|
+
"""Clean up resource tracking when condition is destroyed."""
|
253
|
+
try:
|
254
|
+
untrack_resource(self)
|
255
|
+
except Exception:
|
256
|
+
pass
|
208
257
|
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
await self._lock.acquire()
|
258
|
+
async def __aenter__(self) -> "Condition":
|
259
|
+
"""Acquire the underlying lock."""
|
260
|
+
await self._lock.__aenter__()
|
213
261
|
return self
|
214
262
|
|
215
263
|
async def __aexit__(
|
@@ -219,7 +267,7 @@ class Condition:
|
|
219
267
|
exc_tb: TracebackType | None,
|
220
268
|
) -> None:
|
221
269
|
"""Release the underlying lock."""
|
222
|
-
self._lock.
|
270
|
+
await self._lock.__aexit__(exc_type, exc_val, exc_tb)
|
223
271
|
|
224
272
|
async def wait(self) -> None:
|
225
273
|
"""Wait for a notification.
|
@@ -0,0 +1,182 @@
|
|
1
|
+
"""Resource tracking utilities for concurrency primitives.
|
2
|
+
|
3
|
+
This module provides lightweight resource leak detection and lifecycle tracking
|
4
|
+
to address the security vulnerabilities identified in the hardening tests.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import weakref
|
8
|
+
from dataclasses import dataclass
|
9
|
+
from datetime import datetime
|
10
|
+
from typing import Any
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
|
14
|
+
class ResourceInfo:
|
15
|
+
"""Information about a tracked resource."""
|
16
|
+
|
17
|
+
name: str
|
18
|
+
creation_time: datetime
|
19
|
+
resource_type: str
|
20
|
+
|
21
|
+
|
22
|
+
class ResourceTracker:
|
23
|
+
"""Lightweight resource lifecycle tracking for leak detection.
|
24
|
+
|
25
|
+
This addresses the over-engineering concerns by providing simple,
|
26
|
+
practical resource management without complex abstraction layers.
|
27
|
+
"""
|
28
|
+
|
29
|
+
def __init__(self):
|
30
|
+
"""Initialize a new resource tracker."""
|
31
|
+
self._active_resources: dict[int, ResourceInfo] = {}
|
32
|
+
self._weak_refs: weakref.WeakKeyDictionary = (
|
33
|
+
weakref.WeakKeyDictionary()
|
34
|
+
)
|
35
|
+
|
36
|
+
def track(
|
37
|
+
self, resource: Any, name: str, resource_type: str | None = None
|
38
|
+
) -> None:
|
39
|
+
"""Track a resource for leak detection.
|
40
|
+
|
41
|
+
Args:
|
42
|
+
resource: The resource to track
|
43
|
+
name: Human-readable name for the resource
|
44
|
+
resource_type: Optional type classification
|
45
|
+
"""
|
46
|
+
if resource_type is None:
|
47
|
+
resource_type = type(resource).__name__
|
48
|
+
|
49
|
+
resource_info = ResourceInfo(
|
50
|
+
name=name,
|
51
|
+
creation_time=datetime.now(),
|
52
|
+
resource_type=resource_type,
|
53
|
+
)
|
54
|
+
|
55
|
+
# Use weak reference to avoid interfering with garbage collection
|
56
|
+
self._weak_refs[resource] = resource_info
|
57
|
+
self._active_resources[id(resource)] = resource_info
|
58
|
+
|
59
|
+
def untrack(self, resource: Any) -> None:
|
60
|
+
"""Manually untrack a resource.
|
61
|
+
|
62
|
+
Args:
|
63
|
+
resource: The resource to stop tracking
|
64
|
+
"""
|
65
|
+
resource_id = id(resource)
|
66
|
+
self._active_resources.pop(resource_id, None)
|
67
|
+
self._weak_refs.pop(resource, None)
|
68
|
+
|
69
|
+
def cleanup_check(self) -> list[ResourceInfo]:
|
70
|
+
"""Check for potentially leaked resources.
|
71
|
+
|
72
|
+
Returns:
|
73
|
+
List of resource info for resources that may have leaked
|
74
|
+
"""
|
75
|
+
# Clean up references to garbage collected objects
|
76
|
+
current_resources = []
|
77
|
+
for resource, info in list(self._weak_refs.items()):
|
78
|
+
current_resources.append(info)
|
79
|
+
|
80
|
+
return current_resources
|
81
|
+
|
82
|
+
def get_active_count(self) -> int:
|
83
|
+
"""Get the number of currently tracked resources.
|
84
|
+
|
85
|
+
Returns:
|
86
|
+
Number of active tracked resources
|
87
|
+
"""
|
88
|
+
return len(self._weak_refs)
|
89
|
+
|
90
|
+
def get_resource_summary(self) -> dict[str, int]:
|
91
|
+
"""Get a summary of tracked resources by type.
|
92
|
+
|
93
|
+
Returns:
|
94
|
+
Dictionary mapping resource types to counts
|
95
|
+
"""
|
96
|
+
summary = {}
|
97
|
+
for info in self._weak_refs.values():
|
98
|
+
resource_type = info.resource_type
|
99
|
+
summary[resource_type] = summary.get(resource_type, 0) + 1
|
100
|
+
return summary
|
101
|
+
|
102
|
+
|
103
|
+
# Global tracker instance for convenience
|
104
|
+
_global_tracker = ResourceTracker()
|
105
|
+
|
106
|
+
|
107
|
+
def track_resource(
|
108
|
+
resource: Any, name: str, resource_type: str | None = None
|
109
|
+
) -> None:
|
110
|
+
"""Track a resource using the global tracker.
|
111
|
+
|
112
|
+
Args:
|
113
|
+
resource: The resource to track
|
114
|
+
name: Human-readable name for the resource
|
115
|
+
resource_type: Optional type classification
|
116
|
+
"""
|
117
|
+
_global_tracker.track(resource, name, resource_type)
|
118
|
+
|
119
|
+
|
120
|
+
def untrack_resource(resource: Any) -> None:
|
121
|
+
"""Untrack a resource using the global tracker.
|
122
|
+
|
123
|
+
Args:
|
124
|
+
resource: The resource to stop tracking
|
125
|
+
"""
|
126
|
+
_global_tracker.untrack(resource)
|
127
|
+
|
128
|
+
|
129
|
+
def get_global_tracker() -> ResourceTracker:
|
130
|
+
"""Get the global resource tracker instance.
|
131
|
+
|
132
|
+
Returns:
|
133
|
+
The global ResourceTracker instance
|
134
|
+
"""
|
135
|
+
return _global_tracker
|
136
|
+
|
137
|
+
|
138
|
+
def cleanup_check() -> list[ResourceInfo]:
|
139
|
+
"""Check for potentially leaked resources using global tracker.
|
140
|
+
|
141
|
+
Returns:
|
142
|
+
List of resource info for resources that may have leaked
|
143
|
+
"""
|
144
|
+
return _global_tracker.cleanup_check()
|
145
|
+
|
146
|
+
|
147
|
+
class resource_leak_detector:
|
148
|
+
"""Context manager for resource leak detection in tests and production.
|
149
|
+
|
150
|
+
Example:
|
151
|
+
async with resource_leak_detector() as tracker:
|
152
|
+
lock = Lock()
|
153
|
+
tracker.track(lock, "test_lock")
|
154
|
+
# ... use lock
|
155
|
+
# Automatically checks for leaks on exit
|
156
|
+
"""
|
157
|
+
|
158
|
+
def __init__(self, raise_on_leak: bool = False):
|
159
|
+
"""Initialize the leak detector.
|
160
|
+
|
161
|
+
Args:
|
162
|
+
raise_on_leak: Whether to raise an exception if leaks are detected
|
163
|
+
"""
|
164
|
+
self.raise_on_leak = raise_on_leak
|
165
|
+
self.tracker = ResourceTracker()
|
166
|
+
self._initial_count = 0
|
167
|
+
|
168
|
+
async def __aenter__(self) -> ResourceTracker:
|
169
|
+
"""Enter the context and return the tracker."""
|
170
|
+
self._initial_count = self.tracker.get_active_count()
|
171
|
+
return self.tracker
|
172
|
+
|
173
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
174
|
+
"""Exit the context and check for leaks."""
|
175
|
+
leaked_resources = self.tracker.cleanup_check()
|
176
|
+
|
177
|
+
if leaked_resources and self.raise_on_leak:
|
178
|
+
resource_summary = self.tracker.get_resource_summary()
|
179
|
+
raise RuntimeError(
|
180
|
+
f"Resource leak detected: {len(leaked_resources)} resources "
|
181
|
+
f"still active. Summary: {resource_summary}"
|
182
|
+
)
|
lionagi/libs/concurrency/task.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1
1
|
"""Task group implementation for structured concurrency."""
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
from collections.abc import Awaitable, Callable
|
4
6
|
from types import TracebackType
|
5
|
-
from typing import Any,
|
7
|
+
from typing import Any, TypeVar
|
6
8
|
|
7
9
|
import anyio
|
8
10
|
|
@@ -61,7 +63,7 @@ class TaskGroup:
|
|
61
63
|
raise RuntimeError("Task group is not active")
|
62
64
|
return await self._task_group.start(func, *args, name=name)
|
63
65
|
|
64
|
-
async def __aenter__(self) ->
|
66
|
+
async def __aenter__(self) -> TaskGroup:
|
65
67
|
"""Enter the task group context.
|
66
68
|
|
67
69
|
Returns:
|
lionagi/operations/builder.py
CHANGED
@@ -14,6 +14,7 @@ from typing import Any
|
|
14
14
|
from lionagi.operations.node import BranchOperations, Operation
|
15
15
|
from lionagi.protocols.graph.edge import Edge
|
16
16
|
from lionagi.protocols.graph.graph import Graph
|
17
|
+
from lionagi.protocols.types import ID
|
17
18
|
|
18
19
|
__all__ = (
|
19
20
|
"OperationGraphBuilder",
|
@@ -76,6 +77,7 @@ class OperationGraphBuilder:
|
|
76
77
|
node_id: str | None = None,
|
77
78
|
depends_on: list[str] | None = None,
|
78
79
|
inherit_context: bool = False,
|
80
|
+
branch=None,
|
79
81
|
**parameters,
|
80
82
|
) -> str:
|
81
83
|
"""
|
@@ -108,6 +110,9 @@ class OperationGraphBuilder:
|
|
108
110
|
# Add as metadata for easy lookup
|
109
111
|
node.metadata["reference_id"] = node_id
|
110
112
|
|
113
|
+
if branch:
|
114
|
+
node.branch_id = ID.get_id(branch)
|
115
|
+
|
111
116
|
# Handle dependencies
|
112
117
|
if depends_on:
|
113
118
|
for dep_id in depends_on:
|
@@ -227,6 +232,7 @@ class OperationGraphBuilder:
|
|
227
232
|
source_node_ids: list[str] | None = None,
|
228
233
|
inherit_context: bool = False,
|
229
234
|
inherit_from_source: int = 0,
|
235
|
+
branch=None,
|
230
236
|
**parameters,
|
231
237
|
) -> str:
|
232
238
|
"""
|
@@ -264,6 +270,9 @@ class OperationGraphBuilder:
|
|
264
270
|
if node_id:
|
265
271
|
node.metadata["reference_id"] = node_id
|
266
272
|
|
273
|
+
if branch:
|
274
|
+
node.branch_id = ID.get_id(branch)
|
275
|
+
|
267
276
|
# Store context inheritance for aggregations
|
268
277
|
if inherit_context and sources:
|
269
278
|
node.metadata["inherit_context"] = True
|