lionagi 0.13.7__py3-none-any.whl → 0.14.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,25 @@
1
+ """Structured concurrency primitives.
2
+
3
+ This module provides structured concurrency primitives using AnyIO,
4
+ which allows for consistent behavior across asyncio and trio backends.
5
+ """
6
+
7
+ from .cancel import CancelScope, fail_after, move_on_after
8
+ from .errors import get_cancelled_exc_class, shield
9
+ from .primitives import CapacityLimiter, Condition, Event, Lock, Semaphore
10
+ from .task import TaskGroup, create_task_group
11
+
12
# Public API of the package, grouped by the submodule each name is imported
# from above. The order of the entries is unchanged.
__all__ = [
    # .task
    "TaskGroup",
    "create_task_group",
    # .cancel
    "CancelScope",
    "move_on_after",
    "fail_after",
    # .primitives
    "Lock",
    "Semaphore",
    "CapacityLimiter",
    "Event",
    "Condition",
    # .errors
    "get_cancelled_exc_class",
    "shield",
]
@@ -0,0 +1,134 @@
1
+ """Cancellation scope implementation for structured concurrency."""
2
+
3
+ import time
4
+ from collections.abc import Iterator
5
+ from contextlib import contextmanager
6
+ from types import TracebackType
7
+ from typing import Optional, TypeVar
8
+
9
+ import anyio
10
+
11
+ T = TypeVar("T")
12
+
13
+
14
class CancelScope:
    """A context manager for controlling cancellation of tasks.

    Wraps ``anyio.CancelScope`` and additionally remembers a ``cancel()``
    issued before the scope is entered, applying it on entry.
    """

    def __init__(self, deadline: float | None = None, shield: bool = False):
        """Initialize a new cancel scope.

        Args:
            deadline: Absolute time at which the scope is cancelled, expressed
                on the clock AnyIO uses for deadlines (the value returned by
                ``anyio.current_time()``), NOT wall-clock seconds since the
                epoch. ``None`` means no deadline.
            shield: If True, this scope shields its contents from external
                cancellation.
        """
        self._scope = None
        self._deadline = deadline
        self._shield = shield
        # True once cancel() has been requested (before or after entering).
        self.cancel_called = False
        # Copied from the wrapped scope when the context exits.
        self.cancelled_caught = False

    def cancel(self) -> None:
        """Cancel this scope.

        If the scope has not been entered yet the request is remembered and
        applied by ``__enter__``.
        """
        self.cancel_called = True
        if self._scope is not None:
            self._scope.cancel()

    def __enter__(self) -> "CancelScope":
        """Enter the cancel scope context.

        Returns:
            The cancel scope instance.
        """
        import math

        # math.inf is the "no timeout" deadline.
        deadline = math.inf if self._deadline is None else self._deadline
        scope = anyio.CancelScope(deadline=deadline, shield=self._shield)
        if self.cancel_called:
            # Honour a cancel() that was requested before entering.
            scope.cancel()
        self._scope = scope
        try:
            scope.__enter__()
        except BaseException:
            # Do not keep a reference to a scope that was never entered.
            self._scope = None
            raise
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """Exit the cancel scope context.

        Returns:
            True if the exception was handled (swallowed) by the scope,
            False otherwise. Also False when the scope was never entered.
        """
        scope = self._scope
        if scope is None:
            return False

        try:
            return bool(scope.__exit__(exc_type, exc_val, exc_tb))
        finally:
            # Record the outcome and drop the scope even if the wrapped
            # __exit__ raised, so the wrapper never keeps stale state.
            self.cancelled_caught = scope.cancelled_caught
            self._scope = None
75
+
76
+
77
@contextmanager
def move_on_after(seconds: float | None) -> Iterator[CancelScope]:
    """Return a context manager that cancels its contents after the given number of seconds.

    Must be used from within a running AnyIO-compatible event loop, because
    the deadline is computed from the event loop clock.

    Args:
        seconds: The number of seconds to wait before cancelling, or None to disable the timeout

    Returns:
        A cancel scope that will be cancelled after the specified time

    Example:
        with move_on_after(5) as scope:
            await long_running_operation()
        if scope.cancelled_caught:
            print("Operation timed out")
    """
    # AnyIO deadlines are absolute values on the event loop clock
    # (anyio.current_time()). Wall-clock time.time() is a different clock, so
    # a deadline built from it would effectively never expire.
    deadline = None if seconds is None else anyio.current_time() + seconds
    with CancelScope(deadline=deadline) as scope:
        yield scope
97
+
98
+
99
@contextmanager
def fail_after(seconds: float | None) -> Iterator[CancelScope]:
    """Return a context manager that raises TimeoutError if its contents take longer than the given time.

    Must be used from within a running AnyIO-compatible event loop, because
    the deadline is computed from the event loop clock.

    Args:
        seconds: The number of seconds to wait before raising TimeoutError, or None to disable the timeout

    Returns:
        A cancel scope that will raise TimeoutError after the specified time

    Raises:
        TimeoutError: If the operation takes longer than the specified time

    Example:
        try:
            with fail_after(5):
                await long_running_operation()
        except TimeoutError:
            print("Operation timed out")
    """
    # "No timeout" only means "no deadline". The scope is deliberately NOT
    # shielded, so cancellation coming from an enclosing scope still reaches
    # the body.
    # The deadline is an absolute value on the event loop clock
    # (anyio.current_time()), not wall-clock time.
    deadline = None if seconds is None else anyio.current_time() + seconds
    scope = CancelScope(deadline=deadline)
    with scope:
        yield scope

    # Reaching this point with cancelled_caught set means the scope swallowed
    # its own cancellation; with a timeout configured that is reported as an
    # error instead of silently moving on.
    if seconds is not None and scope.cancelled_caught:
        raise TimeoutError(f"Operation took longer than {seconds} seconds")
@@ -0,0 +1,35 @@
1
+ """Error handling utilities for structured concurrency."""
2
+
3
+ from collections.abc import Awaitable, Callable
4
+ from typing import Any, TypeVar
5
+
6
+ import anyio
7
+
8
+ T = TypeVar("T")
9
+
10
+
11
def get_cancelled_exc_class() -> type[BaseException]:
    """Return the exception class that signals task cancellation.

    Delegates to ``anyio.get_cancelled_exc_class()``.

    Returns:
        The cancellation exception class reported by AnyIO for the backend
        in use.
    """
    cancelled_cls = anyio.get_cancelled_exc_class()
    return cancelled_cls
19
+
20
+
21
async def shield(
    func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any
) -> T:
    """Await ``func(*args, **kwargs)`` inside a shielded cancel scope.

    Args:
        func: The coroutine function to call
        *args: Positional arguments forwarded to ``func``
        **kwargs: Keyword arguments forwarded to ``func``

    Returns:
        Whatever ``func`` returns
    """
    protected = anyio.CancelScope(shield=True)
    with protected:
        outcome = await func(*args, **kwargs)
    return outcome
@@ -0,0 +1,252 @@
1
+ """Common concurrency patterns for structured concurrency."""
2
+
3
+ import math
4
+ from collections.abc import Awaitable, Callable
5
+ from types import TracebackType
6
+ from typing import Any, Optional, TypeVar
7
+
8
+ import anyio
9
+
10
+ from .cancel import move_on_after
11
+ from .primitives import CapacityLimiter, Lock
12
+ from .task import create_task_group
13
+
14
+ T = TypeVar("T")
15
+ R = TypeVar("R")
16
+ Response = TypeVar("Response")
17
+
18
+
19
class ConnectionPool:
    """A pool of reusable connections with a bounded number checked out.

    At most ``max_connections`` connections can be checked out at the same
    time: ``acquire()`` waits once that many are in use and resumes when one
    is handed back with ``release()``. Every ``acquire()`` must therefore be
    paired with a ``release()``.
    """

    def __init__(
        self,
        max_connections: int,
        connection_factory: Callable[[], Awaitable[T]],
    ):
        """Initialize a new connection pool.

        Args:
            max_connections: The maximum number of connections in the pool
            connection_factory: A factory function that creates new connections

        Raises:
            ValueError: If ``max_connections`` is smaller than 1.
        """
        if max_connections < 1:
            raise ValueError("max_connections must be >= 1")
        self._connection_factory = connection_factory
        # One slot per connection that may be checked out. A semaphore is
        # used because a slot is taken in acquire() and given back in
        # release(), possibly by a different task than the one that took it.
        self._slots = anyio.Semaphore(max_connections)
        # Number of slots currently held by checked-out connections.
        self._checked_out = 0
        self._connections: list[T] = []
        self._lock = Lock()

    async def acquire(self) -> T:
        """Acquire a connection from the pool.

        Waits while ``max_connections`` connections are already checked out.

        Returns:
            An idle connection from the pool, or a new connection from the
            factory if no idle connection is available.
        """
        await self._slots.acquire()
        try:
            async with self._lock:
                reuse = bool(self._connections)
                connection = self._connections.pop() if reuse else None
            if not reuse:
                # No idle connection available, create a new one.
                connection = await self._connection_factory()
        except BaseException:
            # Nothing was handed out, so give the slot back.
            self._slots.release()
            raise
        self._checked_out += 1
        return connection

    async def release(self, connection: T) -> None:
        """Release a connection back to the pool.

        Args:
            connection: The connection to release
        """
        try:
            async with self._lock:
                self._connections.append(connection)
        finally:
            # Only free a slot that acquire() actually took, so returning a
            # connection that did not come from acquire() cannot raise the
            # pool's capacity.
            if self._checked_out > 0:
                self._checked_out -= 1
                self._slots.release()

    async def __aenter__(self) -> "ConnectionPool":
        """Enter the connection pool context.

        Returns:
            The connection pool instance.
        """
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the connection pool context, closing all idle connections.

        Each idle connection is closed through its ``close`` attribute, or
        ``disconnect`` when there is no ``close``. Both synchronous and
        awaitable closers are supported. Connections that are still checked
        out are not touched.
        """
        import inspect

        async with self._lock:
            # Pop before closing so a connection is never left in the pool
            # after a close attempt, even when that attempt raises.
            while self._connections:
                connection = self._connections.pop()
                if hasattr(connection, "close"):
                    outcome = connection.close()
                elif hasattr(connection, "disconnect"):
                    outcome = connection.disconnect()
                else:
                    continue
                if inspect.isawaitable(outcome):
                    await outcome
83
+
84
+
85
async def parallel_requests(
    urls: list[str],
    fetch_func: Callable[[str], Awaitable[Response]],
    max_concurrency: int = 10,
) -> list[Response]:
    """Fetch every URL concurrently, with a cap on simultaneous fetches.

    All fetches are allowed to finish; afterwards the exception recorded for
    the lowest URL index (if any) is raised.

    Args:
        urls: The URLs to fetch
        fetch_func: Coroutine function called once per URL
        max_concurrency: Upper bound on fetches running at the same time

    Returns:
        The responses, positioned to match ``urls``
    """
    gate = CapacityLimiter(max_concurrency)
    total = len(urls)
    responses: list[Response | None] = [None] * total
    failures: list[Exception | None] = [None] * total

    async def _fetch_one(slot: int, target: str) -> None:
        # Store either the response or the failure in this URL's own slot.
        async with gate:
            try:
                responses[slot] = await fetch_func(target)
            except Exception as err:
                failures[slot] = err

    async with create_task_group() as group:
        for slot, target in enumerate(urls):
            # start_soon is awaited, i.e. the TaskGroup wrapper is assumed to
            # expose it as a coroutine function - TODO confirm against .task
            await group.start_soon(_fetch_one, slot, target)

    # Surface the first recorded failure in URL order.
    first_failure = next((err for err in failures if err is not None), None)
    if first_failure is not None:
        raise first_failure

    return responses  # type: ignore
121
+
122
+
123
async def retry_with_timeout(
    func: Callable[..., Awaitable[T]],
    *args: Any,
    max_retries: int = 3,
    timeout: float = 5.0,
    retry_exceptions: list[type[Exception]] | None = None,
    **kwargs: Any,
) -> T:
    """Execute a function with retry logic and timeout.

    Each attempt is limited to ``timeout`` seconds. After a failed or timed
    out attempt (other than the last) the function waits ``2**attempt``
    seconds before trying again.

    Args:
        func: The function to call
        *args: Positional arguments to pass to the function
        max_retries: The maximum number of attempts (must be >= 1)
        timeout: The timeout for each attempt in seconds
        retry_exceptions: List of exception types to retry on, or None to retry on any exception
        **kwargs: Keyword arguments to pass to the function

    Returns:
        The return value of the function

    Raises:
        ValueError: If ``max_retries`` is smaller than 1
        TimeoutError: If the last attempt times out
        Exception: If the function raises an exception after all retry attempts
    """
    if max_retries < 1:
        raise ValueError("max_retries must be >= 1")

    retryable = tuple(retry_exceptions or [Exception])
    final_attempt = max_retries - 1

    for attempt in range(max_retries):
        try:
            with move_on_after(timeout):
                return await func(*args, **kwargs)

            # Only reachable when the scope swallowed its own cancellation,
            # i.e. the attempt timed out: a normal completion returns above
            # and any other failure propagates out of the ``with`` block.
            if attempt == final_attempt:
                raise TimeoutError(
                    f"Operation timed out after {max_retries} attempts"
                )
        except retryable:
            if attempt == final_attempt:
                raise

        # Wait before retrying (exponential backoff)
        await anyio.sleep(2**attempt)

    # The final attempt always returns or raises; this only satisfies type
    # checkers.
    raise RuntimeError("Unreachable code")
182
+
183
+
184
class WorkerPool:
    """A pool of worker tasks that process items from a queue.

    ``start()`` opens a task group and launches the workers; ``stop()`` sends
    one stop signal per worker and closes the task group. Because the task
    group is entered in ``start()`` and exited in ``stop()``, both should be
    called from the same task.

    ``None`` is reserved as the stop signal: submitting ``None`` makes one
    worker exit instead of being passed to ``worker_func``.
    """

    def __init__(
        self, num_workers: int, worker_func: Callable[[Any], Awaitable[None]]
    ):
        """Initialize a new worker pool.

        Args:
            num_workers: The number of worker tasks to create
            worker_func: The function that each worker will run
        """
        self._num_workers = num_workers
        self._worker_func = worker_func
        # (send_stream, receive_stream) pair with an unbounded buffer.
        self._queue = anyio.create_memory_object_stream(math.inf)
        self._task_group = None

    async def start(self) -> None:
        """Start the worker pool.

        Returns as soon as the workers have been scheduled; the task group
        stays open until ``stop()`` is called.

        Raises:
            RuntimeError: If the pool has already been started.
        """
        import inspect

        if self._task_group is not None:
            raise RuntimeError("Worker pool already started")

        self._task_group = create_task_group()

        # Enter the group without an ``async with`` block: leaving such a
        # block waits for every worker to finish, which only happens after
        # stop() has sent the stop signals.
        tg = await self._task_group.__aenter__()
        for _ in range(self._num_workers):
            scheduled = tg.start_soon(self._worker_loop)
            # NOTE(review): the TaskGroup wrapper is not visible here; its
            # start_soon appears to be a coroutine function (it is awaited
            # elsewhere), so await the result when it is awaitable.
            if inspect.isawaitable(scheduled):
                await scheduled

    async def stop(self) -> None:
        """Stop the worker pool.

        Does nothing if the pool is not running. Items submitted before the
        stop signals are still processed first.
        """
        if self._task_group is None:
            return

        # Signal workers to stop
        for _ in range(self._num_workers):
            await self._queue[0].send(None)

        # Wait for workers to finish; always forget the group afterwards so
        # the pool is not left half-stopped if closing it raises.
        try:
            await self._task_group.__aexit__(None, None, None)
        finally:
            self._task_group = None

    async def submit(self, item: Any) -> None:
        """Submit an item to be processed by a worker.

        Args:
            item: The item to process

        Raises:
            RuntimeError: If the pool has not been started.
        """
        if self._task_group is None:
            raise RuntimeError("Worker pool not started")

        await self._queue[0].send(item)

    async def _worker_loop(self) -> None:
        """The main loop for each worker task."""
        receive_stream = self._queue[1]
        while True:
            try:
                item = await receive_stream.receive()
            except anyio.EndOfStream:
                break

            # None is a signal to stop
            if item is None:
                break

            try:
                await self._worker_func(item)
            except Exception as exc:
                # Report the exception but keep the worker running
                print(f"Worker error: {exc}")
@@ -0,0 +1,242 @@
1
+ """Resource management primitives for structured concurrency."""
2
+
3
+ from types import TracebackType
4
+ from typing import Optional
5
+
6
+ import anyio
7
+
8
+
9
class Lock:
    """A mutex lock for controlling access to a shared resource.

    Thin wrapper around ``anyio.Lock``. The lock is NOT reentrant: a task
    that already holds it must release it before acquiring it again.
    """

    def __init__(self):
        """Initialize a new, unlocked lock."""
        self._lock = anyio.Lock()

    async def __aenter__(self) -> None:
        """Acquire the lock.

        If the lock is already held by another task, this will wait until it's released.
        """
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Release the lock."""
        self.release()

    async def acquire(self) -> bool:
        """Acquire the lock, waiting until it becomes available.

        Returns:
            Always True; the call only returns once the lock is held.
        """
        await self._lock.acquire()
        return True

    def release(self) -> None:
        """Release the lock.

        Raises:
            RuntimeError: If the lock is not currently held by this task.
        """
        self._lock.release()
52
+
53
+
54
class Semaphore:
    """Counting semaphore that bounds concurrent access to a resource.

    Thin wrapper around ``anyio.Semaphore``.
    """

    def __init__(self, initial_value: int):
        """Create the semaphore.

        Args:
            initial_value: Starting value of the counter; negative values
                are rejected.

        Raises:
            ValueError: If ``initial_value`` is negative.
        """
        if not initial_value >= 0:
            raise ValueError("The initial value must be >= 0")
        self._semaphore = anyio.Semaphore(initial_value)

    async def acquire(self) -> None:
        """Take the semaphore, delegating the wait to the wrapped semaphore."""
        await self._semaphore.acquire()

    def release(self) -> None:
        """Give the semaphore back via the wrapped semaphore."""
        self._semaphore.release()

    async def __aenter__(self) -> None:
        """Acquire the semaphore on entering an ``async with`` block."""
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Release the semaphore on leaving an ``async with`` block."""
        self.release()
93
+
94
+
95
class CapacityLimiter:
    """Limits how many operations may run at once by handing out tokens.

    Thin wrapper around ``anyio.CapacityLimiter``.
    """

    @staticmethod
    def _require_at_least_one(tokens: float) -> None:
        """Reject token totals below 1."""
        if tokens < 1:
            raise ValueError("The total number of tokens must be >= 1")

    def __init__(self, total_tokens: float):
        """Create the limiter.

        Args:
            total_tokens: Maximum number of tokens; must be at least 1.

        Raises:
            ValueError: If ``total_tokens`` is below 1.
        """
        self._require_at_least_one(total_tokens)
        self._limiter = anyio.CapacityLimiter(total_tokens)

    async def acquire(self) -> None:
        """Take a token, delegating the wait to the wrapped limiter."""
        await self._limiter.acquire()

    def release(self) -> None:
        """Give a token back via the wrapped limiter.

        Raises:
            RuntimeError: If the current task doesn't hold any tokens.
        """
        self._limiter.release()

    async def __aenter__(self) -> None:
        """Acquire a token on entering an ``async with`` block."""
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Release the token on leaving an ``async with`` block."""
        self.release()

    @property
    def total_tokens(self) -> float:
        """Total number of tokens, as reported by the wrapped limiter."""
        return self._limiter.total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """Change the total number of tokens.

        Args:
            value: New total; must be at least 1.

        Raises:
            ValueError: If ``value`` is below 1.
        """
        self._require_at_least_one(value)
        self._limiter.total_tokens = value

    @property
    def borrowed_tokens(self) -> int:
        """Number of tokens currently borrowed."""
        return self._limiter.borrowed_tokens

    @property
    def available_tokens(self) -> float:
        """Number of tokens currently available."""
        return self._limiter.available_tokens
164
+
165
+
166
class Event:
    """A one-way flag that tasks can wait on.

    Thin wrapper around ``anyio.Event``: the event starts unset, and tasks
    blocked in ``wait()`` proceed once ``set()`` has been called.
    """

    def __init__(self):
        """Create the event in the unset state."""
        self._event = anyio.Event()

    def set(self) -> None:
        """Mark the event as set, letting waiting tasks proceed."""
        self._event.set()

    def is_set(self) -> bool:
        """Report the current state.

        Returns:
            True if the event is set, False otherwise.
        """
        flag = self._event.is_set()
        return flag

    async def wait(self) -> None:
        """Block until the event is set."""
        await self._event.wait()
192
+
193
+
194
class Condition:
    """A condition variable for task synchronization.

    Thin wrapper around ``anyio.Condition``. Use it as an async context
    manager to hold the underlying lock while calling ``wait()``,
    ``notify()`` or ``notify_all()``.
    """

    def __init__(self, lock: Lock | None = None):
        """Initialize a new condition.

        Args:
            lock: The lock to use, or None to create a new one
        """
        self._lock = lock if lock is not None else Lock()
        self._condition = anyio.Condition(self._lock._lock)

    async def __aenter__(self) -> "Condition":
        """Acquire the underlying lock.

        Returns:
            The condition instance.
        """
        # Acquire through the condition rather than the bare lock so the
        # condition's own record of which task holds the lock is updated;
        # notify()/notify_all() depend on that record.
        await self._condition.acquire()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Release the underlying lock."""
        self._condition.release()

    async def wait(self) -> None:
        """Wait for a notification.

        This releases the underlying lock, waits for a notification, and then
        reacquires the lock.
        """
        await self._condition.wait()

    async def notify(self, n: int = 1) -> None:
        """Notify waiting tasks.

        Kept as a coroutine function so existing ``await cond.notify()``
        callers keep working.

        Args:
            n: The number of tasks to notify
        """
        # The wrapped notify() is a plain synchronous call; awaiting its
        # result would fail because it returns None.
        self._condition.notify(n)

    async def notify_all(self) -> None:
        """Notify all waiting tasks.

        Kept as a coroutine function so existing ``await cond.notify_all()``
        callers keep working.
        """
        # Synchronous in the wrapped condition, see notify().
        self._condition.notify_all()