lionagi 0.14.4__py3-none-any.whl → 0.14.6__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -35,7 +35,6 @@ class Instruct(HashableModel):
         "reason",
         "actions",
         "action_strategy",
-        "batch_size",
         "request_params",
         "response_params",
     ]
@@ -97,16 +96,10 @@ class Instruct(HashableModel):
             "None: Contextual execution."
         ),
     )
-    action_strategy: Literal["batch", "sequential", "concurrent"] | None = (
-        Field(
-            None,
-            description="Action strategy to use for executing actions. Default "
-            "is 'concurrent'. Only provide for if actions are enabled.",
-        )
-    )
-    batch_size: int | None = Field(
+    action_strategy: Literal["sequential", "concurrent"] | None = Field(
         None,
-        description="Batch size for executing actions. Only provide for 'batch' strategy.",
+        description="Action strategy to use for executing actions. Default "
+        "is 'concurrent'. Only provide for if actions are enabled.",
     )

     @field_validator("instruction", "guidance", "context", mode="before")
@@ -123,13 +116,6 @@ class Instruct(HashableModel):
             return "concurrent"
         return v

-    @field_validator("batch_size", mode="before")
-    def _validate_batch_size(cls, v):
-        try:
-            return to_num(v, num_type=int)
-        except Exception:
-            return None
-

 class InstructResponse(HashableModel):
     instruct: Instruct
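
Net effect of the three hunks above: 0.14.6 removes the batch_size field, its to_num validator, and the "batch" option, leaving "sequential" and "concurrent" (the default) as the only action strategies. A minimal before/after sketch; the field values are illustrative and the import of Instruct is omitted because this diff does not show the module path:

    # 0.14.4 accepted a batch strategy plus a batch size:
    # instruct = Instruct(instruction="summarize", actions=True,
    #                     action_strategy="batch", batch_size=4)

    # 0.14.6: drop batch_size and choose one of the remaining strategies
    instruct = Instruct(
        instruction="summarize",
        actions=True,
        action_strategy="concurrent",  # or "sequential"
    )
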
@@ -1,4 +1,4 @@
-"""Structured concurrency primitives.
+"""Structured concurrency primitives for pynector.

 This module provides structured concurrency primitives using AnyIO,
 which allows for consistent behavior across asyncio and trio backends.
@@ -6,7 +6,21 @@ which allows for consistent behavior across asyncio and trio backends.

 from .cancel import CancelScope, fail_after, move_on_after
 from .errors import get_cancelled_exc_class, shield
+from .patterns import (
+    ConnectionPool,
+    WorkerPool,
+    parallel_requests,
+    retry_with_timeout,
+)
 from .primitives import CapacityLimiter, Condition, Event, Lock, Semaphore
+from .resource_tracker import (
+    ResourceTracker,
+    cleanup_check,
+    get_global_tracker,
+    resource_leak_detector,
+    track_resource,
+    untrack_resource,
+)
 from .task import TaskGroup, create_task_group

 __all__ = [
@@ -15,6 +29,10 @@ __all__ = [
     "CancelScope",
     "move_on_after",
     "fail_after",
+    "ConnectionPool",
+    "WorkerPool",
+    "parallel_requests",
+    "retry_with_timeout",
     "Lock",
     "Semaphore",
     "CapacityLimiter",
@@ -22,4 +40,10 @@ __all__ = [
     "Condition",
     "get_cancelled_exc_class",
     "shield",
+    "ResourceTracker",
+    "resource_leak_detector",
+    "track_resource",
+    "untrack_resource",
+    "cleanup_check",
+    "get_global_tracker",
 ]
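
The re-exports above make the pooling, retry, and resource-tracking helpers importable from the package's top-level __init__ instead of the .patterns and .resource_tracker submodules. A sketch, writing the package path as a placeholder since this diff does not show which directory the __init__.py belongs to:

    # "concurrency_pkg" is a stand-in for the actual subpackage path inside lionagi
    from concurrency_pkg import ConnectionPool, WorkerPool, parallel_requests, retry_with_timeout
    from concurrency_pkg import ResourceTracker, get_global_tracker, track_resource, untrack_resource
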
@@ -4,7 +4,7 @@ import time
 from collections.abc import Iterator
 from contextlib import contextmanager
 from types import TracebackType
-from typing import Optional, TypeVar
+from typing import TypeVar

 import anyio

@@ -1,16 +1,21 @@
 """Common concurrency patterns for structured concurrency."""

-import math
+from __future__ import annotations
+
+import logging
 from collections.abc import Awaitable, Callable
 from types import TracebackType
-from typing import Any, Optional, TypeVar
+from typing import Any, TypeVar

 import anyio

 from .cancel import move_on_after
 from .primitives import CapacityLimiter, Lock
+from .resource_tracker import track_resource, untrack_resource
 from .task import create_task_group

+logger = logging.getLogger(__name__)
+
 T = TypeVar("T")
 R = TypeVar("R")
 Response = TypeVar("Response")
@@ -24,46 +29,51 @@ class ConnectionPool:
         max_connections: int,
         connection_factory: Callable[[], Awaitable[T]],
     ):
-        """Initialize a new connection pool.
+        """Initialize a new connection pool."""
+        if max_connections < 1:
+            raise ValueError("max_connections must be >= 1")
+        if not callable(connection_factory):
+            raise ValueError("connection_factory must be callable")

-        Args:
-            max_connections: The maximum number of connections in the pool
-            connection_factory: A factory function that creates new connections
-        """
         self._connection_factory = connection_factory
         self._limiter = CapacityLimiter(max_connections)
         self._connections: list[T] = []
         self._lock = Lock()

+        track_resource(self, f"ConnectionPool-{id(self)}", "ConnectionPool")
+
+    def __del__(self):
+        """Clean up resource tracking."""
+        try:
+            untrack_resource(self)
+        except Exception:
+            pass
+
     async def acquire(self) -> T:
-        """Acquire a connection from the pool.
+        """Acquire a connection from the pool."""
+        await self._limiter.acquire()

-        Returns:
-            A connection from the pool, or a new connection if the pool is empty.
-        """
-        async with self._limiter:
+        try:
             async with self._lock:
                 if self._connections:
                     return self._connections.pop()

-            # No connections available, create a new one
+            # No pooled connection available, create new one
             return await self._connection_factory()
+        except Exception:
+            self._limiter.release()
+            raise

     async def release(self, connection: T) -> None:
-        """Release a connection back to the pool.
-
-        Args:
-            connection: The connection to release
-        """
-        async with self._lock:
-            self._connections.append(connection)
-
-    async def __aenter__(self) -> "ConnectionPool":
-        """Enter the connection pool context.
+        """Release a connection back to the pool."""
+        try:
+            async with self._lock:
+                self._connections.append(connection)
+        finally:
+            self._limiter.release()

-        Returns:
-            The connection pool instance.
-        """
+    async def __aenter__(self) -> ConnectionPool[T]:
+        """Enter the connection pool context."""
         return self

     async def __aexit__(
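
The acquire()/release() rework above replaces the `async with self._limiter:` block with an explicit CapacityLimiter.acquire() that is released on error in acquire() and unconditionally in release(), so every acquired connection must be handed back or a limiter slot stays consumed. A usage sketch under that contract; make_conn, open_backend_connection, and do_work are hypothetical stand-ins, not part of the package:

    async def make_conn():
        return await open_backend_connection()  # hypothetical factory

    pool = ConnectionPool(max_connections=5, connection_factory=make_conn)
    async with pool:
        conn = await pool.acquire()
        try:
            await do_work(conn)
        finally:
            await pool.release(conn)  # always release; this frees the limiter slot

Note also that, per the following hunk, __aexit__ now only clears the internal list and no longer calls close()/disconnect() on leftover connections, so any cleanup beyond releasing is the caller's responsibility.
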
@@ -72,113 +82,95 @@ class ConnectionPool:
         exc_val: BaseException | None,
         exc_tb: TracebackType | None,
     ) -> None:
-        """Exit the connection pool context, closing all connections."""
+        """Exit the connection pool context."""
+        # Clean up any remaining connections
         async with self._lock:
-            for connection in self._connections:
-                if hasattr(connection, "close"):
-                    await connection.close()
-                elif hasattr(connection, "disconnect"):
-                    await connection.disconnect()
             self._connections.clear()


 async def parallel_requests(
-    urls: list[str],
-    fetch_func: Callable[[str], Awaitable[Response]],
+    inputs: list[str],
+    func: Callable[[str], Awaitable[Response]],
     max_concurrency: int = 10,
 ) -> list[Response]:
-    """Fetch multiple URLs in parallel with limited concurrency.
+    """Execute requests in parallel with controlled concurrency.

     Args:
-        urls: The URLs to fetch
-        fetch_func: The function to use for fetching
-        max_concurrency: The maximum number of concurrent requests
+        inputs: List of inputs
+        fetch_func: Async function
+        max_concurrency: Maximum number of concurrent requests

     Returns:
-        A list of responses in the same order as the URLs
+        List of responses in the same order as inputs
     """
-    limiter = CapacityLimiter(max_concurrency)
-    results: list[Response | None] = [None] * len(urls)
-    exceptions: list[Exception | None] = [None] * len(urls)
-
-    async def fetch_with_limit(index: int, url: str) -> None:
-        async with limiter:
-            try:
-                results[index] = await fetch_func(url)
-            except Exception as exc:
-                exceptions[index] = exc
-
-    async with create_task_group() as tg:
-        for i, url in enumerate(urls):
-            await tg.start_soon(fetch_with_limit, i, url)
-
-    # Check for exceptions
-    for i, exc in enumerate(exceptions):
-        if exc is not None:
-            raise exc
+    if not inputs:
+        return []
+
+    results: list[Response | None] = [None] * len(inputs)
+
+    async def bounded_fetch(
+        semaphore: anyio.Semaphore, idx: int, url: str
+    ) -> None:
+        async with semaphore:
+            results[idx] = await func(url)
+
+    try:
+        async with create_task_group() as tg:
+            semaphore = anyio.Semaphore(max_concurrency)
+
+            for i, inp in enumerate(inputs):
+                await tg.start_soon(bounded_fetch, semaphore, i, inp)
+    except BaseException as e:
+        # Re-raise the first exception directly instead of ExceptionGroup
+        if hasattr(e, "exceptions") and e.exceptions:
+            raise e.exceptions[0]
+        else:
+            raise

     return results  # type: ignore


 async def retry_with_timeout(
-    func: Callable[..., Awaitable[T]],
-    *args: Any,
+    func: Callable[[], Awaitable[T]],
     max_retries: int = 3,
-    timeout: float = 5.0,
-    retry_exceptions: list[type[Exception]] | None = None,
-    **kwargs: Any,
+    timeout: float = 30.0,
+    backoff_factor: float = 1.0,
 ) -> T:
-    """Execute a function with retry logic and timeout.
+    """Retry an async function with exponential backoff and timeout.

     Args:
-        func: The function to call
-        *args: Positional arguments to pass to the function
-        max_retries: The maximum number of retry attempts
-        timeout: The timeout for each attempt in seconds
-        retry_exceptions: List of exception types to retry on, or None to retry on any exception
-        **kwargs: Keyword arguments to pass to the function
+        func: The async function to retry
+        max_retries: Maximum number of retries
+        timeout: Timeout for each attempt
+        backoff_factor: Multiplier for exponential backoff

     Returns:
-        The return value of the function
+        The result of the successful function call

     Raises:
-        TimeoutError: If all retry attempts time out
-        Exception: If the function raises an exception after all retry attempts
+        Exception: The last exception raised by the function
     """
-    retry_exceptions = retry_exceptions or [Exception]
     last_exception = None

     for attempt in range(max_retries):
         try:
-            timed_out = False
-            with move_on_after(timeout) as scope:
-                result = await func(*args, **kwargs)
-                if not scope.cancelled_caught:
+            with move_on_after(timeout) as cancel_scope:
+                result = await func()
+                if not cancel_scope.cancelled_caught:
                     return result
-                timed_out = True
-
-            # If we get here, the operation timed out
-            if timed_out:
-                if attempt == max_retries - 1:
-                    raise TimeoutError(
-                        f"Operation timed out after {max_retries} attempts"
-                    )
-
-                # Wait before retrying (exponential backoff)
-                await anyio.sleep(2**attempt)
-
-        except tuple(retry_exceptions) as exc:
-            last_exception = exc
-            if attempt == max_retries - 1:
-                raise
+                else:
+                    raise TimeoutError(f"Function timed out after {timeout}s")
+        except Exception as e:
+            last_exception = e
+            if attempt < max_retries - 1:
+                delay = backoff_factor * (2**attempt)
+                await anyio.sleep(delay)
+                continue

-            # Wait before retrying (exponential backoff)
-            await anyio.sleep(2**attempt)
-
-    # This should never be reached, but makes the type checker happy
     if last_exception:
         raise last_exception
-    raise RuntimeError("Unreachable code")
+    else:
+        raise RuntimeError("Retry failed without capturing exception")


 class WorkerPool:
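
Two signature changes in this hunk affect callers. parallel_requests() now takes a generic inputs list and a single-argument async func (the old urls/fetch_func names survive only in the docstring), and retry_with_timeout() now expects a zero-argument coroutine function plus a backoff_factor instead of *args/**kwargs and retry_exceptions, so arguments have to be bound up front. A sketch, meant to run inside an async function; fetch is a hypothetical coroutine, not part of the package:

    import functools

    async def fetch(url: str) -> bytes:
        ...  # hypothetical HTTP call

    pages = await parallel_requests(
        ["https://a.example", "https://b.example"], fetch, max_concurrency=5
    )

    page = await retry_with_timeout(
        functools.partial(fetch, "https://a.example"),  # bind arguments yourself now
        max_retries=3,
        timeout=30.0,
        backoff_factor=1.0,
    )
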
@@ -187,66 +179,81 @@ class WorkerPool:
     def __init__(
         self, num_workers: int, worker_func: Callable[[Any], Awaitable[None]]
     ):
-        """Initialize a new worker pool.
+        """Initialize a new worker pool."""
+        if num_workers < 1:
+            raise ValueError("num_workers must be >= 1")
+        if not callable(worker_func):
+            raise ValueError("worker_func must be callable")

-        Args:
-            num_workers: The number of worker tasks to create
-            worker_func: The function that each worker will run
-        """
         self._num_workers = num_workers
         self._worker_func = worker_func
-        self._queue = anyio.create_memory_object_stream(math.inf)
+        self._queue = anyio.create_memory_object_stream(1000)
         self._task_group = None

+        track_resource(self, f"WorkerPool-{id(self)}", "WorkerPool")
+
+    def __del__(self):
+        """Clean up resource tracking."""
+        try:
+            untrack_resource(self)
+        except Exception:
+            pass
+
     async def start(self) -> None:
         """Start the worker pool."""
         if self._task_group is not None:
-            raise RuntimeError("Worker pool already started")
+            raise RuntimeError("Worker pool is already started")

         self._task_group = create_task_group()
+        await self._task_group.__aenter__()

-        async with self._task_group as tg:
-            for _ in range(self._num_workers):
-                tg.start_soon(self._worker_loop)
+        # Start worker tasks
+        for i in range(self._num_workers):
+            await self._task_group.start_soon(self._worker_loop)

     async def stop(self) -> None:
         """Stop the worker pool."""
         if self._task_group is None:
             return

-        # Signal workers to stop
-        for _ in range(self._num_workers):
-            await self._queue[0].send(None)
+        # Close the queue to signal workers to stop
+        await self._queue[0].aclose()

-        # Wait for workers to finish
-        await self._task_group.__aexit__(None, None, None)
-        self._task_group = None
+        # Wait for all workers to finish
+        try:
+            await self._task_group.__aexit__(None, None, None)
+        finally:
+            self._task_group = None

     async def submit(self, item: Any) -> None:
-        """Submit an item to be processed by a worker.
-
-        Args:
-            item: The item to process
-        """
+        """Submit an item for processing."""
         if self._task_group is None:
-            raise RuntimeError("Worker pool not started")
-
+            raise RuntimeError("Worker pool is not started")
         await self._queue[0].send(item)

     async def _worker_loop(self) -> None:
-        """The main loop for each worker task."""
-        while True:
-            try:
-                item = await self._queue[1].receive()
-
-                # None is a signal to stop
-                if item is None:
-                    break
-
-                try:
-                    await self._worker_func(item)
-                except Exception as exc:
-                    # Log the exception but keep the worker running
-                    print(f"Worker error: {exc}")
-            except anyio.EndOfStream:
-                break
+        """Main loop for worker tasks."""
+        try:
+            async with self._queue[1]:
+                async for item in self._queue[1]:
+                    try:
+                        await self._worker_func(item)
+                    except Exception as e:
+                        logger.error(f"Worker error processing item: {e}")
+        except anyio.ClosedResourceError:
+            # Queue was closed, worker should exit gracefully
+            pass
+
+    async def __aenter__(self) -> WorkerPool:
+        """Enter the worker pool context."""
+        await self.start()
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        """Exit the worker pool context."""
+        await self.stop()
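
WorkerPool in 0.14.6 signals shutdown by closing the send stream (workers drain the receive stream via `async for` and exit once it is closed) and gains __aenter__/__aexit__, so it can be driven as an async context manager. A usage sketch; handle and jobs are hypothetical placeholders, not part of the package:

    async def handle(item) -> None:
        ...  # hypothetical per-item work

    async with WorkerPool(num_workers=4, worker_func=handle) as pool:
        for job in jobs:  # jobs: any iterable of work items
            await pool.submit(job)
    # leaving the block calls stop(), which closes the queue and waits for the workers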