lionagi 0.13.7__py3-none-any.whl → 0.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/libs/concurrency/__init__.py +25 -0
- lionagi/libs/concurrency/cancel.py +134 -0
- lionagi/libs/concurrency/errors.py +35 -0
- lionagi/libs/concurrency/patterns.py +252 -0
- lionagi/libs/concurrency/primitives.py +242 -0
- lionagi/libs/concurrency/task.py +109 -0
- lionagi/operations/builder.py +46 -0
- lionagi/operations/flow.py +292 -383
- lionagi/operations/node.py +2 -1
- lionagi/protocols/generic/pile.py +41 -156
- lionagi/protocols/graph/edge.py +1 -1
- lionagi/protocols/graph/node.py +27 -55
- lionagi/protocols/types.py +1 -2
- lionagi/service/connections/providers/claude_code_.py +31 -8
- lionagi/service/connections/providers/claude_code_cli.py +2 -3
- lionagi/session/session.py +8 -8
- lionagi/version.py +1 -1
- {lionagi-0.13.7.dist-info → lionagi-0.14.1.dist-info}/METADATA +2 -2
- {lionagi-0.13.7.dist-info → lionagi-0.14.1.dist-info}/RECORD +21 -15
- {lionagi-0.13.7.dist-info → lionagi-0.14.1.dist-info}/WHEEL +0 -0
- {lionagi-0.13.7.dist-info → lionagi-0.14.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,25 @@
|
|
1
|
+
"""Structured concurrency primitives.
|
2
|
+
|
3
|
+
This module provides structured concurrency primitives using AnyIO,
|
4
|
+
which allows for consistent behavior across asyncio and trio backends.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from .cancel import CancelScope, fail_after, move_on_after
|
8
|
+
from .errors import get_cancelled_exc_class, shield
|
9
|
+
from .primitives import CapacityLimiter, Condition, Event, Lock, Semaphore
|
10
|
+
from .task import TaskGroup, create_task_group
|
11
|
+
|
12
|
+
# Public API of the structured-concurrency package, re-exported from the
# .task, .cancel, .primitives and .errors submodules imported above.
__all__ = [
    # Task groups (.task)
    "TaskGroup",
    "create_task_group",
    # Cancellation scopes and timeouts (.cancel)
    "CancelScope",
    "move_on_after",
    "fail_after",
    # Synchronization primitives (.primitives)
    "Lock",
    "Semaphore",
    "CapacityLimiter",
    "Event",
    "Condition",
    # Cancellation helpers (.errors)
    "get_cancelled_exc_class",
    "shield",
]
|
@@ -0,0 +1,134 @@
|
|
1
|
+
"""Cancellation scope implementation for structured concurrency."""
|
2
|
+
|
3
|
+
import time
|
4
|
+
from collections.abc import Iterator
|
5
|
+
from contextlib import contextmanager
|
6
|
+
from types import TracebackType
|
7
|
+
from typing import Optional, TypeVar
|
8
|
+
|
9
|
+
import anyio
|
10
|
+
|
11
|
+
# Generic type variable; not referenced elsewhere in this module.
T = TypeVar("T")
|
12
|
+
|
13
|
+
|
14
|
+
class CancelScope:
|
15
|
+
"""A context manager for controlling cancellation of tasks."""
|
16
|
+
|
17
|
+
def __init__(self, deadline: float | None = None, shield: bool = False):
|
18
|
+
"""Initialize a new cancel scope.
|
19
|
+
|
20
|
+
Args:
|
21
|
+
deadline: The time (in seconds since the epoch) when this scope should be cancelled
|
22
|
+
shield: If True, this scope shields its contents from external cancellation
|
23
|
+
"""
|
24
|
+
self._scope = None
|
25
|
+
self._deadline = deadline
|
26
|
+
self._shield = shield
|
27
|
+
self.cancel_called = False
|
28
|
+
self.cancelled_caught = False
|
29
|
+
|
30
|
+
def cancel(self) -> None:
|
31
|
+
"""Cancel this scope.
|
32
|
+
|
33
|
+
This will cause all tasks within this scope to be cancelled.
|
34
|
+
"""
|
35
|
+
self.cancel_called = True
|
36
|
+
if self._scope is not None:
|
37
|
+
self._scope.cancel()
|
38
|
+
|
39
|
+
def __enter__(self) -> "CancelScope":
|
40
|
+
"""Enter the cancel scope context.
|
41
|
+
|
42
|
+
Returns:
|
43
|
+
The cancel scope instance.
|
44
|
+
"""
|
45
|
+
# Use math.inf as the default deadline (no timeout)
|
46
|
+
import math
|
47
|
+
|
48
|
+
deadline = self._deadline if self._deadline is not None else math.inf
|
49
|
+
self._scope = anyio.CancelScope(deadline=deadline, shield=self._shield)
|
50
|
+
if self.cancel_called:
|
51
|
+
self._scope.cancel()
|
52
|
+
self._scope.__enter__()
|
53
|
+
return self
|
54
|
+
|
55
|
+
def __exit__(
|
56
|
+
self,
|
57
|
+
exc_type: type[BaseException] | None,
|
58
|
+
exc_val: BaseException | None,
|
59
|
+
exc_tb: TracebackType | None,
|
60
|
+
) -> bool:
|
61
|
+
"""Exit the cancel scope context.
|
62
|
+
|
63
|
+
Returns:
|
64
|
+
True if the exception was handled, False otherwise.
|
65
|
+
"""
|
66
|
+
if self._scope is None:
|
67
|
+
return False
|
68
|
+
|
69
|
+
try:
|
70
|
+
result = self._scope.__exit__(exc_type, exc_val, exc_tb)
|
71
|
+
self.cancelled_caught = self._scope.cancelled_caught
|
72
|
+
return result
|
73
|
+
finally:
|
74
|
+
self._scope = None
|
75
|
+
|
76
|
+
|
77
|
+
@contextmanager
def move_on_after(seconds: float | None) -> Iterator[CancelScope]:
    """Return a context manager that cancels its contents after the given number of seconds.

    Reaching the deadline does not raise. Check ``scope.cancelled_caught``
    afterwards to find out whether the body was cut short.

    Args:
        seconds: The number of seconds to wait before cancelling, or None to
            disable the timeout.

    Returns:
        A cancel scope that will be cancelled after the specified time.

    Example:
        with move_on_after(5) as scope:
            await long_running_operation()
        if scope.cancelled_caught:
            print("Operation timed out")
    """
    # anyio compares deadlines against its own event-loop clock
    # (anyio.current_time()). Adding to time.time() (epoch wall clock) would
    # typically put the deadline far in the future, so it would never fire.
    deadline = None if seconds is None else anyio.current_time() + seconds
    with CancelScope(deadline=deadline) as scope:
        yield scope
|
97
|
+
|
98
|
+
|
99
|
+
@contextmanager
def fail_after(seconds: float | None) -> Iterator[CancelScope]:
    """Return a context manager that raises TimeoutError if its contents take longer than the given time.

    Args:
        seconds: The number of seconds to wait before raising TimeoutError,
            or None to disable the timeout.

    Returns:
        A cancel scope that will raise TimeoutError after the specified time.

    Raises:
        TimeoutError: If the operation takes longer than the specified time.
            An explicit ``scope.cancel()`` before the deadline is not
            reported as a timeout.

    Example:
        try:
            with fail_after(5):
                await long_running_operation()
        except TimeoutError:
            print("Operation timed out")
    """
    if seconds is None:
        # No timeout. Use a plain, unshielded scope so that cancellation of
        # an enclosing scope still reaches the body.
        with CancelScope() as scope:
            yield scope
        return

    # Deadlines live on anyio's event-loop clock, not time.time().
    deadline = anyio.current_time() + seconds
    with CancelScope(deadline=deadline) as scope:
        yield scope

    # Only a cancellation caught by this scope *after* the deadline counts as
    # a timeout. If the body raised, we never reach this point.
    if scope.cancelled_caught and anyio.current_time() >= deadline:
        raise TimeoutError(f"Operation took longer than {seconds} seconds")
|
@@ -0,0 +1,35 @@
|
|
1
|
+
"""Error handling utilities for structured concurrency."""
|
2
|
+
|
3
|
+
from collections.abc import Awaitable, Callable
|
4
|
+
from typing import Any, TypeVar
|
5
|
+
|
6
|
+
import anyio
|
7
|
+
|
8
|
+
# Return type of the coroutine function passed to shield().
T = TypeVar("T")
|
9
|
+
|
10
|
+
|
11
|
+
def get_cancelled_exc_class() -> type[BaseException]:
    """Look up the cancellation exception type of the running async backend.

    The lookup is delegated to AnyIO. Under asyncio the answer is
    ``CancelledError``; under trio it is ``Cancelled``.

    Returns:
        The backend-specific cancellation exception class.
    """
    backend_cancelled = anyio.get_cancelled_exc_class()
    return backend_cancelled
|
19
|
+
|
20
|
+
|
21
|
+
async def shield(
    func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any
) -> T:
    """Await ``func(*args, **kwargs)`` inside a cancellation shield.

    Cancellation coming from enclosing scopes does not interrupt the call
    while it runs.

    Args:
        func: Coroutine function to invoke.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.
    """
    protected = anyio.CancelScope(shield=True)
    with protected:
        outcome = await func(*args, **kwargs)
    return outcome
|
@@ -0,0 +1,252 @@
|
|
1
|
+
"""Common concurrency patterns for structured concurrency."""
|
2
|
+
|
3
|
+
import math
|
4
|
+
from collections.abc import Awaitable, Callable
|
5
|
+
from types import TracebackType
|
6
|
+
from typing import Any, Optional, TypeVar
|
7
|
+
|
8
|
+
import anyio
|
9
|
+
|
10
|
+
from .cancel import move_on_after
|
11
|
+
from .primitives import CapacityLimiter, Lock
|
12
|
+
from .task import create_task_group
|
13
|
+
|
14
|
+
# Connection type managed by ConnectionPool, and the result type of the
# function passed to retry_with_timeout.
T = TypeVar("T")
# Not referenced elsewhere in this module.
R = TypeVar("R")
# Result type returned by fetch_func in parallel_requests.
Response = TypeVar("Response")
|
17
|
+
|
18
|
+
|
19
|
+
class ConnectionPool:
|
20
|
+
"""A pool of reusable connections."""
|
21
|
+
|
22
|
+
def __init__(
|
23
|
+
self,
|
24
|
+
max_connections: int,
|
25
|
+
connection_factory: Callable[[], Awaitable[T]],
|
26
|
+
):
|
27
|
+
"""Initialize a new connection pool.
|
28
|
+
|
29
|
+
Args:
|
30
|
+
max_connections: The maximum number of connections in the pool
|
31
|
+
connection_factory: A factory function that creates new connections
|
32
|
+
"""
|
33
|
+
self._connection_factory = connection_factory
|
34
|
+
self._limiter = CapacityLimiter(max_connections)
|
35
|
+
self._connections: list[T] = []
|
36
|
+
self._lock = Lock()
|
37
|
+
|
38
|
+
async def acquire(self) -> T:
|
39
|
+
"""Acquire a connection from the pool.
|
40
|
+
|
41
|
+
Returns:
|
42
|
+
A connection from the pool, or a new connection if the pool is empty.
|
43
|
+
"""
|
44
|
+
async with self._limiter:
|
45
|
+
async with self._lock:
|
46
|
+
if self._connections:
|
47
|
+
return self._connections.pop()
|
48
|
+
|
49
|
+
# No connections available, create a new one
|
50
|
+
return await self._connection_factory()
|
51
|
+
|
52
|
+
async def release(self, connection: T) -> None:
|
53
|
+
"""Release a connection back to the pool.
|
54
|
+
|
55
|
+
Args:
|
56
|
+
connection: The connection to release
|
57
|
+
"""
|
58
|
+
async with self._lock:
|
59
|
+
self._connections.append(connection)
|
60
|
+
|
61
|
+
async def __aenter__(self) -> "ConnectionPool":
|
62
|
+
"""Enter the connection pool context.
|
63
|
+
|
64
|
+
Returns:
|
65
|
+
The connection pool instance.
|
66
|
+
"""
|
67
|
+
return self
|
68
|
+
|
69
|
+
async def __aexit__(
|
70
|
+
self,
|
71
|
+
exc_type: type[BaseException] | None,
|
72
|
+
exc_val: BaseException | None,
|
73
|
+
exc_tb: TracebackType | None,
|
74
|
+
) -> None:
|
75
|
+
"""Exit the connection pool context, closing all connections."""
|
76
|
+
async with self._lock:
|
77
|
+
for connection in self._connections:
|
78
|
+
if hasattr(connection, "close"):
|
79
|
+
await connection.close()
|
80
|
+
elif hasattr(connection, "disconnect"):
|
81
|
+
await connection.disconnect()
|
82
|
+
self._connections.clear()
|
83
|
+
|
84
|
+
|
85
|
+
async def parallel_requests(
    urls: list[str],
    fetch_func: Callable[[str], Awaitable[Response]],
    max_concurrency: int = 10,
) -> list[Response]:
    """Fetch all ``urls`` concurrently with at most ``max_concurrency`` in flight.

    Every fetch runs to completion even when some fail. Afterwards, the
    exception from the lowest-index failing URL (if any) is re-raised.

    Args:
        urls: The URLs to fetch.
        fetch_func: Coroutine function used to fetch a single URL.
        max_concurrency: Upper bound on simultaneous ``fetch_func`` calls.

    Returns:
        The responses, positioned to match ``urls``.
    """
    gate = CapacityLimiter(max_concurrency)
    count = len(urls)
    responses: list[Response | None] = [None] * count
    failures: list[Exception | None] = [None] * count

    async def _fetch_one(position: int, target: str) -> None:
        # Record the outcome per slot instead of letting it escape, so one
        # failure does not tear down the whole task group.
        async with gate:
            try:
                responses[position] = await fetch_func(target)
            except Exception as err:
                failures[position] = err

    async with create_task_group() as group:
        for position, target in enumerate(urls):
            await group.start_soon(_fetch_one, position, target)

    first_failure = next((err for err in failures if err is not None), None)
    if first_failure is not None:
        raise first_failure

    return responses  # type: ignore
|
121
|
+
|
122
|
+
|
123
|
+
async def retry_with_timeout(
|
124
|
+
func: Callable[..., Awaitable[T]],
|
125
|
+
*args: Any,
|
126
|
+
max_retries: int = 3,
|
127
|
+
timeout: float = 5.0,
|
128
|
+
retry_exceptions: list[type[Exception]] | None = None,
|
129
|
+
**kwargs: Any,
|
130
|
+
) -> T:
|
131
|
+
"""Execute a function with retry logic and timeout.
|
132
|
+
|
133
|
+
Args:
|
134
|
+
func: The function to call
|
135
|
+
*args: Positional arguments to pass to the function
|
136
|
+
max_retries: The maximum number of retry attempts
|
137
|
+
timeout: The timeout for each attempt in seconds
|
138
|
+
retry_exceptions: List of exception types to retry on, or None to retry on any exception
|
139
|
+
**kwargs: Keyword arguments to pass to the function
|
140
|
+
|
141
|
+
Returns:
|
142
|
+
The return value of the function
|
143
|
+
|
144
|
+
Raises:
|
145
|
+
TimeoutError: If all retry attempts time out
|
146
|
+
Exception: If the function raises an exception after all retry attempts
|
147
|
+
"""
|
148
|
+
retry_exceptions = retry_exceptions or [Exception]
|
149
|
+
last_exception = None
|
150
|
+
|
151
|
+
for attempt in range(max_retries):
|
152
|
+
try:
|
153
|
+
timed_out = False
|
154
|
+
with move_on_after(timeout) as scope:
|
155
|
+
result = await func(*args, **kwargs)
|
156
|
+
if not scope.cancelled_caught:
|
157
|
+
return result
|
158
|
+
timed_out = True
|
159
|
+
|
160
|
+
# If we get here, the operation timed out
|
161
|
+
if timed_out:
|
162
|
+
if attempt == max_retries - 1:
|
163
|
+
raise TimeoutError(
|
164
|
+
f"Operation timed out after {max_retries} attempts"
|
165
|
+
)
|
166
|
+
|
167
|
+
# Wait before retrying (exponential backoff)
|
168
|
+
await anyio.sleep(2**attempt)
|
169
|
+
|
170
|
+
except tuple(retry_exceptions) as exc:
|
171
|
+
last_exception = exc
|
172
|
+
if attempt == max_retries - 1:
|
173
|
+
raise
|
174
|
+
|
175
|
+
# Wait before retrying (exponential backoff)
|
176
|
+
await anyio.sleep(2**attempt)
|
177
|
+
|
178
|
+
# This should never be reached, but makes the type checker happy
|
179
|
+
if last_exception:
|
180
|
+
raise last_exception
|
181
|
+
raise RuntimeError("Unreachable code")
|
182
|
+
|
183
|
+
|
184
|
+
class WorkerPool:
    """A pool of worker tasks that process items from a queue.

    :meth:`start` enters a task group and launches ``num_workers`` workers,
    then returns. :meth:`submit` enqueues items. :meth:`stop` tells every
    worker to finish and waits for them by exiting the task group. The task
    group is entered in ``start()`` and exited in ``stop()``, so both should
    be awaited from the same task.
    """

    # Private stop marker. Unlike None, it cannot collide with a submitted item.
    _STOP = object()

    def __init__(
        self, num_workers: int, worker_func: Callable[[Any], Awaitable[None]]
    ):
        """Initialize a new worker pool.

        Args:
            num_workers: The number of worker tasks to create.
            worker_func: The coroutine function each worker calls per item.
        """
        self._num_workers = num_workers
        self._worker_func = worker_func
        # (send_stream, receive_stream) pair. The infinite buffer means
        # submit() never waits for space.
        self._queue = anyio.create_memory_object_stream(math.inf)
        self._task_group = None

    async def start(self) -> None:
        """Start the worker pool and return once the workers are scheduled.

        Raises:
            RuntimeError: If the pool is already started.
        """
        import inspect

        if self._task_group is not None:
            raise RuntimeError("Worker pool already started")

        task_group = create_task_group()
        tg = await task_group.__aenter__()
        try:
            for _ in range(self._num_workers):
                pending = tg.start_soon(self._worker_loop)
                # The TaskGroup wrapper's start_soon may be a coroutine
                # function (parallel_requests awaits it). Await it if so, so
                # the worker is actually scheduled.
                if inspect.isawaitable(pending):
                    await pending
        except BaseException as exc:
            # Tear the group down (cancelling any started workers) on failure.
            await task_group.__aexit__(type(exc), exc, exc.__traceback__)
            raise
        self._task_group = task_group

    async def stop(self) -> None:
        """Stop the worker pool, waiting for every worker to finish.

        Items submitted before this call are processed first. Calling
        ``stop()`` on a pool that is not running does nothing.
        """
        if self._task_group is None:
            return

        task_group, self._task_group = self._task_group, None
        send_stream = self._queue[0]
        try:
            # One stop marker per worker; each worker exits after taking one.
            for _ in range(self._num_workers):
                await send_stream.send(self._STOP)
        except BaseException as exc:
            # Could not signal the workers: exit with the error, which
            # cancels them instead of waiting forever.
            await task_group.__aexit__(type(exc), exc, exc.__traceback__)
            raise

        # Wait for workers to finish.
        await task_group.__aexit__(None, None, None)

    async def submit(self, item: Any) -> None:
        """Submit an item to be processed by a worker.

        Args:
            item: The item to process.

        Raises:
            RuntimeError: If the pool is not started.
        """
        if self._task_group is None:
            raise RuntimeError("Worker pool not started")

        await self._queue[0].send(item)

    async def _worker_loop(self) -> None:
        """The main loop for each worker task."""
        import logging

        receive_stream = self._queue[1]
        while True:
            try:
                item = await receive_stream.receive()
            except anyio.EndOfStream:
                break

            if item is self._STOP:
                break

            try:
                await self._worker_func(item)
            except Exception:
                # Log the failure but keep the worker running.
                logging.getLogger(__name__).exception("Worker error")
|
@@ -0,0 +1,242 @@
|
|
1
|
+
"""Resource management primitives for structured concurrency."""
|
2
|
+
|
3
|
+
from types import TracebackType
|
4
|
+
from typing import Optional
|
5
|
+
|
6
|
+
import anyio
|
7
|
+
|
8
|
+
|
9
|
+
class Lock:
|
10
|
+
"""A mutex lock for controlling access to a shared resource.
|
11
|
+
|
12
|
+
This lock is reentrant, meaning the same task can acquire it multiple times
|
13
|
+
without deadlocking.
|
14
|
+
"""
|
15
|
+
|
16
|
+
def __init__(self):
|
17
|
+
"""Initialize a new lock."""
|
18
|
+
self._lock = anyio.Lock()
|
19
|
+
|
20
|
+
async def __aenter__(self) -> None:
|
21
|
+
"""Acquire the lock.
|
22
|
+
|
23
|
+
If the lock is already held by another task, this will wait until it's released.
|
24
|
+
"""
|
25
|
+
await self.acquire()
|
26
|
+
|
27
|
+
async def __aexit__(
|
28
|
+
self,
|
29
|
+
exc_type: type[BaseException] | None,
|
30
|
+
exc_val: BaseException | None,
|
31
|
+
exc_tb: TracebackType | None,
|
32
|
+
) -> None:
|
33
|
+
"""Release the lock."""
|
34
|
+
self.release()
|
35
|
+
|
36
|
+
async def acquire(self) -> bool:
|
37
|
+
"""Acquire the lock.
|
38
|
+
|
39
|
+
Returns:
|
40
|
+
True if the lock was acquired, False otherwise.
|
41
|
+
"""
|
42
|
+
await self._lock.acquire()
|
43
|
+
return True
|
44
|
+
|
45
|
+
def release(self) -> None:
|
46
|
+
"""Release the lock.
|
47
|
+
|
48
|
+
Raises:
|
49
|
+
RuntimeError: If the lock is not currently held by this task.
|
50
|
+
"""
|
51
|
+
self._lock.release()
|
52
|
+
|
53
|
+
|
54
|
+
class Semaphore:
|
55
|
+
"""A semaphore for limiting concurrent access to a resource."""
|
56
|
+
|
57
|
+
def __init__(self, initial_value: int):
|
58
|
+
"""Initialize a new semaphore.
|
59
|
+
|
60
|
+
Args:
|
61
|
+
initial_value: The initial value of the semaphore (must be >= 0)
|
62
|
+
"""
|
63
|
+
if initial_value < 0:
|
64
|
+
raise ValueError("The initial value must be >= 0")
|
65
|
+
self._semaphore = anyio.Semaphore(initial_value)
|
66
|
+
|
67
|
+
async def __aenter__(self) -> None:
|
68
|
+
"""Acquire the semaphore.
|
69
|
+
|
70
|
+
If the semaphore value is zero, this will wait until it's released.
|
71
|
+
"""
|
72
|
+
await self.acquire()
|
73
|
+
|
74
|
+
async def __aexit__(
|
75
|
+
self,
|
76
|
+
exc_type: type[BaseException] | None,
|
77
|
+
exc_val: BaseException | None,
|
78
|
+
exc_tb: TracebackType | None,
|
79
|
+
) -> None:
|
80
|
+
"""Release the semaphore."""
|
81
|
+
self.release()
|
82
|
+
|
83
|
+
async def acquire(self) -> None:
|
84
|
+
"""Acquire the semaphore.
|
85
|
+
|
86
|
+
If the semaphore value is zero, this will wait until it's released.
|
87
|
+
"""
|
88
|
+
await self._semaphore.acquire()
|
89
|
+
|
90
|
+
def release(self) -> None:
|
91
|
+
"""Release the semaphore, incrementing its value."""
|
92
|
+
self._semaphore.release()
|
93
|
+
|
94
|
+
|
95
|
+
class CapacityLimiter:
|
96
|
+
"""A context manager for limiting the number of concurrent operations."""
|
97
|
+
|
98
|
+
def __init__(self, total_tokens: float):
|
99
|
+
"""Initialize a new capacity limiter.
|
100
|
+
|
101
|
+
Args:
|
102
|
+
total_tokens: The maximum number of tokens (>= 1)
|
103
|
+
"""
|
104
|
+
if total_tokens < 1:
|
105
|
+
raise ValueError("The total number of tokens must be >= 1")
|
106
|
+
self._limiter = anyio.CapacityLimiter(total_tokens)
|
107
|
+
|
108
|
+
async def __aenter__(self) -> None:
|
109
|
+
"""Acquire a token.
|
110
|
+
|
111
|
+
If no tokens are available, this will wait until one is released.
|
112
|
+
"""
|
113
|
+
await self.acquire()
|
114
|
+
|
115
|
+
async def __aexit__(
|
116
|
+
self,
|
117
|
+
exc_type: type[BaseException] | None,
|
118
|
+
exc_val: BaseException | None,
|
119
|
+
exc_tb: TracebackType | None,
|
120
|
+
) -> None:
|
121
|
+
"""Release the token."""
|
122
|
+
self.release()
|
123
|
+
|
124
|
+
async def acquire(self) -> None:
|
125
|
+
"""Acquire a token.
|
126
|
+
|
127
|
+
If no tokens are available, this will wait until one is released.
|
128
|
+
"""
|
129
|
+
await self._limiter.acquire()
|
130
|
+
|
131
|
+
def release(self) -> None:
|
132
|
+
"""Release a token.
|
133
|
+
|
134
|
+
Raises:
|
135
|
+
RuntimeError: If the current task doesn't hold any tokens.
|
136
|
+
"""
|
137
|
+
self._limiter.release()
|
138
|
+
|
139
|
+
@property
|
140
|
+
def total_tokens(self) -> float:
|
141
|
+
"""The total number of tokens."""
|
142
|
+
return self._limiter.total_tokens
|
143
|
+
|
144
|
+
@total_tokens.setter
|
145
|
+
def total_tokens(self, value: float) -> None:
|
146
|
+
"""Set the total number of tokens.
|
147
|
+
|
148
|
+
Args:
|
149
|
+
value: The new total number of tokens (>= 1)
|
150
|
+
"""
|
151
|
+
if value < 1:
|
152
|
+
raise ValueError("The total number of tokens must be >= 1")
|
153
|
+
self._limiter.total_tokens = value
|
154
|
+
|
155
|
+
@property
|
156
|
+
def borrowed_tokens(self) -> int:
|
157
|
+
"""The number of tokens currently borrowed."""
|
158
|
+
return self._limiter.borrowed_tokens
|
159
|
+
|
160
|
+
@property
|
161
|
+
def available_tokens(self) -> float:
|
162
|
+
"""The number of tokens currently available."""
|
163
|
+
return self._limiter.available_tokens
|
164
|
+
|
165
|
+
|
166
|
+
class Event:
    """Flag that tasks can wait on until another task sets it.

    The event starts unset. Calling :meth:`set` releases every task blocked
    in :meth:`wait`. This wrapper offers no way to clear the flag again.
    Every call is delegated to :class:`anyio.Event`.
    """

    def __init__(self):
        """Create the event in the unset state."""
        self._event = anyio.Event()

    async def wait(self) -> None:
        """Suspend until the event has been set."""
        await self._event.wait()

    def set(self) -> None:
        """Mark the event as set, waking all waiters."""
        self._event.set()

    def is_set(self) -> bool:
        """Report whether :meth:`set` has been called.

        Returns:
            True once the event is set, otherwise False.
        """
        return self._event.is_set()
|
192
|
+
|
193
|
+
|
194
|
+
class Condition:
|
195
|
+
"""A condition variable for task synchronization."""
|
196
|
+
|
197
|
+
def __init__(self, lock: Lock | None = None):
|
198
|
+
"""Initialize a new condition.
|
199
|
+
|
200
|
+
Args:
|
201
|
+
lock: The lock to use, or None to create a new one
|
202
|
+
"""
|
203
|
+
self._lock = lock or Lock()
|
204
|
+
self._condition = anyio.Condition(self._lock._lock)
|
205
|
+
|
206
|
+
async def __aenter__(self) -> "Condition":
|
207
|
+
"""Acquire the underlying lock.
|
208
|
+
|
209
|
+
Returns:
|
210
|
+
The condition instance.
|
211
|
+
"""
|
212
|
+
await self._lock.acquire()
|
213
|
+
return self
|
214
|
+
|
215
|
+
async def __aexit__(
|
216
|
+
self,
|
217
|
+
exc_type: type[BaseException] | None,
|
218
|
+
exc_val: BaseException | None,
|
219
|
+
exc_tb: TracebackType | None,
|
220
|
+
) -> None:
|
221
|
+
"""Release the underlying lock."""
|
222
|
+
self._lock.release()
|
223
|
+
|
224
|
+
async def wait(self) -> None:
|
225
|
+
"""Wait for a notification.
|
226
|
+
|
227
|
+
This releases the underlying lock, waits for a notification, and then
|
228
|
+
reacquires the lock.
|
229
|
+
"""
|
230
|
+
await self._condition.wait()
|
231
|
+
|
232
|
+
async def notify(self, n: int = 1) -> None:
|
233
|
+
"""Notify waiting tasks.
|
234
|
+
|
235
|
+
Args:
|
236
|
+
n: The number of tasks to notify
|
237
|
+
"""
|
238
|
+
await self._condition.notify(n)
|
239
|
+
|
240
|
+
async def notify_all(self) -> None:
|
241
|
+
"""Notify all waiting tasks."""
|
242
|
+
await self._condition.notify_all()
|