kash-shell 0.3.22__py3-none-any.whl → 0.3.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kash/actions/core/combine_docs.py +52 -0
- kash/actions/core/concat_docs.py +47 -0
- kash/commands/workspace/workspace_commands.py +2 -2
- kash/config/logger.py +3 -2
- kash/config/settings.py +8 -0
- kash/docs/markdown/topics/a2_installation.md +2 -2
- kash/embeddings/embeddings.py +1 -1
- kash/exec/action_exec.py +1 -1
- kash/exec/fetch_url_items.py +52 -16
- kash/file_storage/file_store.py +3 -3
- kash/llm_utils/llm_completion.py +1 -1
- kash/mcp/mcp_cli.py +2 -2
- kash/utils/api_utils/api_retries.py +348 -14
- kash/utils/api_utils/gather_limited.py +366 -512
- kash/utils/api_utils/http_utils.py +46 -0
- kash/utils/api_utils/progress_protocol.py +49 -56
- kash/utils/rich_custom/multitask_status.py +70 -21
- kash/utils/text_handling/markdown_utils.py +14 -3
- kash/web_content/web_extract.py +13 -9
- kash/web_content/web_fetch.py +289 -60
- kash/web_content/web_page_model.py +5 -0
- {kash_shell-0.3.22.dist-info → kash_shell-0.3.24.dist-info}/METADATA +5 -3
- {kash_shell-0.3.22.dist-info → kash_shell-0.3.24.dist-info}/RECORD +26 -23
- {kash_shell-0.3.22.dist-info → kash_shell-0.3.24.dist-info}/WHEEL +0 -0
- {kash_shell-0.3.22.dist-info → kash_shell-0.3.24.dist-info}/entry_points.txt +0 -0
- {kash_shell-0.3.22.dist-info → kash_shell-0.3.24.dist-info}/licenses/LICENSE +0 -0
Diff of `kash/utils/api_utils/gather_limited.py`:

````diff
@@ -4,9 +4,10 @@ import asyncio
 import inspect
 import logging
 import threading
+import time
 from collections.abc import Callable, Coroutine
 from dataclasses import dataclass, field
-from typing import Any, TypeAlias, TypeVar, cast, overload
+from typing import Any, Generic, TypeAlias, TypeVar, cast, overload
 
 from aiolimiter import AsyncLimiter
 
@@ -16,6 +17,7 @@ from kash.utils.api_utils.api_retries import (
     RetryExhaustedException,
     RetrySettings,
     calculate_backoff,
+    extract_http_status_code,
 )
 from kash.utils.api_utils.progress_protocol import Labeler, ProgressTracker, TaskState
 
@@ -25,28 +27,82 @@ log = logging.getLogger(__name__)
 
 
 @dataclass(frozen=True)
-class FuncTask:
+class Limit:
+    """
+    Rate limiting configuration with max RPS and max concurrency.
+    """
+
+    rps: float
+    concurrency: int
+
+
+DEFAULT_CANCEL_TIMEOUT: float = 1.0
+
+
+@dataclass(frozen=True)
+class TaskResult(Generic[T]):
+    """
+    Optional wrapper for task results that can signal rate limiting behavior.
+    Use this to wrap results that should bypass rate limiting (e.g., cache hits).
+    """
+
+    value: T
+    disable_limits: bool = False
+
+
+@dataclass(frozen=True)
+class FuncTask(Generic[T]):
     """
     A task described as an unevaluated function with args and kwargs.
     This task format allows you to use args and kwargs in the Labeler.
+    It also allows specifying a bucket for rate limiting.
+
+    For async functions: The function should return a coroutine that yields either `T` or `TaskResult[T]`.
+    For sync functions: The function should return either `T` or `TaskResult[T]` directly.
+
+    Using `TaskResult[T]` allows controlling rate limiting behavior (e.g., cache hits can bypass rate limits).
     """
 
-    func: Callable[..., Any]
+    func: Callable[..., Any]  # Keep as Any since it can be sync or async
     args: tuple[Any, ...] = ()
     kwargs: dict[str, Any] = field(default_factory=dict)
+    bucket: str = "default"
 
 
 # Type aliases for coroutine and sync specifications, including unevaluated function specs
-CoroSpec: TypeAlias =
-
+CoroSpec: TypeAlias = (
+    Callable[[], Coroutine[None, None, T]] | Coroutine[None, None, T] | FuncTask[T]
+)
+SyncSpec: TypeAlias = Callable[[], T] | FuncTask[T]
 
 # Specific labeler types using the generic Labeler pattern
 CoroLabeler: TypeAlias = Labeler[CoroSpec[T]]
SyncLabeler: TypeAlias = Labeler[SyncSpec[T]]
 
-
-
-
+
+def _get_bucket_limits(
+    bucket: str,
+    bucket_semaphores: dict[str, asyncio.Semaphore],
+    bucket_rate_limiters: dict[str, AsyncLimiter],
+) -> tuple[asyncio.Semaphore | None, AsyncLimiter | None]:
+    """
+    Get bucket-specific limits with fallback to "*" wildcard.
+
+    Checks for exact bucket match first, then falls back to "*" if available.
+    Returns (None, None) if neither exact match nor "*" fallback exists.
+    """
+    # Try exact bucket match first
+    bucket_semaphore = bucket_semaphores.get(bucket)
+    bucket_rate_limiter = bucket_rate_limiters.get(bucket)
+
+    if bucket_semaphore is not None and bucket_rate_limiter is not None:
+        return bucket_semaphore, bucket_rate_limiter
+
+    # Fall back to "*" wildcard if available
+    bucket_semaphore = bucket_semaphores.get("*")
+    bucket_rate_limiter = bucket_rate_limiters.get("*")
+
+    return bucket_semaphore, bucket_rate_limiter
 
 
 class RetryCounter:
````
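The new `_get_bucket_limits` helper above resolves a task's bucket to its semaphore and rate limiter, with `"*"` acting as a wildcard fallback. A minimal illustration of that resolution, calling the private helper directly purely for demonstration (assumes kash-shell 0.3.24 is installed):

```python
import asyncio
from aiolimiter import AsyncLimiter
from kash.utils.api_utils.gather_limited import _get_bucket_limits

# Per-bucket limits in the shape gather_limited_async() builds from bucket_limits.
semaphores = {"api1": asyncio.Semaphore(10), "*": asyncio.Semaphore(8)}
limiters = {"api1": AsyncLimiter(20, 1.0), "*": AsyncLimiter(15, 1.0)}

_get_bucket_limits("api1", semaphores, limiters)  # exact match: api1's pair
_get_bucket_limits("api2", semaphores, limiters)  # no exact match: falls back to "*"
_get_bucket_limits("api3", {}, {})                # no match and no "*": (None, None)
```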
````diff
@@ -75,8 +131,8 @@ class RetryCounter:
 @overload
 async def gather_limited_async(
     *coro_specs: CoroSpec[T],
-
-
+    limit: Limit | None,
+    bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = False,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
@@ -87,8 +143,8 @@ async def gather_limited_async(
 @overload
 async def gather_limited_async(
     *coro_specs: CoroSpec[T],
-
-
+    limit: Limit | None,
+    bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = True,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
@@ -98,27 +154,42 @@ async def gather_limited_async(
 
 async def gather_limited_async(
     *coro_specs: CoroSpec[T],
-
-
-    return_exceptions: bool =
+    limit: Limit | None,
+    bucket_limits: dict[str, Limit] | None = None,
+    return_exceptions: bool = True,  # Default to True for resilient batch operations
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
     labeler: CoroLabeler[T] | None = None,
 ) -> list[T] | list[T | BaseException]:
     """
-    Rate-limited version of `asyncio.gather()` with retry logic and optional progress display.
+    Rate-limited version of `asyncio.gather()` with HTTP-aware retry logic and optional progress display.
     Uses the aiolimiter leaky-bucket algorithm with exponential backoff on failures.
 
     Supports two levels of retry limits:
     - Per-task retries: max_task_retries attempts per individual task
     - Global retries: max_total_retries attempts across all tasks (prevents cascade failures)
 
+    Features HTTP-aware retry classification:
+    - Automatically detects HTTP status codes (403, 429, 500, etc.) and applies appropriate retry behavior
+    - Configurable handling of conditional status codes like 403 Forbidden
+    - Defaults to return_exceptions=True for resilient batch operations
+
+    Prevents API flooding during rate limit backoffs:
+    - Semaphore slots are held during entire retry cycles, including backoff periods
+    - New tasks cannot start while existing tasks are backing off from rate limits
+    - Rate limiters are only acquired during actual execution attempts
+
     Can optionally display live progress with retry indicators using TaskStatus.
 
     Accepts:
     - Callables that return coroutines: `lambda: some_async_func(arg)` (recommended for retries)
     - Coroutines directly: `some_async_func(arg)` (only if retries disabled)
-    -
+    - FuncTask objects: `FuncTask(some_async_func, (arg1, arg2), {"kwarg": value})` (args accessible to labeler)
+
+    Functions can return `TaskResult[T]` to bypass rate limiting for specific results:
+    - `TaskResult(value, disable_limits=True)`: bypass rate limiting (e.g., cache hits)
+    - `TaskResult(value, disable_limits=False)`: apply normal rate limiting
+    - `value` directly: apply normal rate limiting
 
     Examples:
     ```python
@@ -126,7 +197,7 @@ async def gather_limited_async(
     from kash.utils.rich_custom.task_status import TaskStatus
 
     async with TaskStatus() as status:
-        await
+        await gather_limited_async(
             lambda: fetch_url("http://example.com"),
             lambda: process_data(data),
             status=status,
@@ -135,18 +206,47 @@ async def gather_limited_async(
         )
 
     # Without progress display:
-    await
+    await gather_limited_async(
         lambda: fetch_url("http://example.com"),
         lambda: process_data(data),
         retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
     )
 
+    # With bucket-specific limits and "*" fallback:
+    await gather_limited_async(
+        FuncTask(fetch_api, ("data1",), bucket="api1"),
+        FuncTask(fetch_api, ("data2",), bucket="api2"),
+        FuncTask(fetch_api, ("data3",), bucket="api3"),
+        limit=Limit(rps=100, concurrency=50),
+        bucket_limits={
+            "api1": Limit(rps=20, concurrency=10),  # Specific limit for api1
+            "*": Limit(rps=15, concurrency=8),  # Fallback for api2, api3, and others
+        }
+    )
+
+    # With cache bypass using TaskResult:
+    async def fetch_with_cache(url: str) -> TaskResult[dict] | dict:
+        cached_data = await cache.get(url)
+        if cached_data:
+            return TaskResult(cached_data, disable_limits=True)  # Cache hit: no rate limit
+
+        data = await fetch_api(url)  # Will be rate limited
+        await cache.set(url, data)
+        return data
+
+    await gather_limited_async(
+        lambda: fetch_with_cache("http://api.com/data1"),
+        lambda: fetch_with_cache("http://api.com/data2"),
+        limit=Limit(rps=10, concurrency=5)  # Cache hits bypass these limits
+    )
+
     ```
 
     Args:
         *coro_specs: Callables or coroutines to execute
-
-
+        limit: Global limits applied to all tasks regardless of bucket
+        bucket_limits: Optional per-bucket limits. Tasks use their bucket field to determine limits.
+            Use "*" as a fallback limit for buckets without specific limits.
         return_exceptions: If True, exceptions are returned as results
         retry_settings: Configuration for retry behavior, or None to disable retries
         status: Optional ProgressTracker instance for progress display
@@ -159,9 +259,9 @@ async def gather_limited_async(
         ValueError: If coroutines are passed when retries are enabled
     """
     log.info(
-        "Executing with concurrency %s at %s rps, %s",
-
-
+        "Executing with global limits: concurrency %s at %s rps, %s",
+        limit.concurrency if limit else "None",
+        limit.rps if limit else "None",
         retry_settings,
     )
     if not coro_specs:
@@ -179,8 +279,18 @@ async def gather_limited_async(
             f"lambda: your_async_func(args) instead of your_async_func(args)"
         )
 
-
-
+    # Global limits (apply to all tasks regardless of bucket)
+    global_semaphore = asyncio.Semaphore(limit.concurrency) if limit else None
+    global_rate_limiter = AsyncLimiter(limit.rps, 1.0) if limit else None
+
+    # Per-bucket limits (if bucket_limits provided)
+    bucket_semaphores: dict[str, asyncio.Semaphore] = {}
+    bucket_rate_limiters: dict[str, AsyncLimiter] = {}
+
+    if bucket_limits:
+        for bucket_name, limit in bucket_limits.items():
+            bucket_semaphores[bucket_name] = asyncio.Semaphore(limit.concurrency)
+            bucket_rate_limiters[bucket_name] = AsyncLimiter(limit.rps, 1.0)
 
     # Global retry counter (shared across all tasks)
     global_retry_counter = RetryCounter(retry_settings.max_total_retries)
@@ -190,6 +300,16 @@ async def gather_limited_async(
         label = labeler(i, coro_spec) if labeler else f"task:{i}"
         task_id = await status.add(label) if status else None
 
+        # Determine bucket and get appropriate limits
+        bucket = "default"
+        if isinstance(coro_spec, FuncTask):
+            bucket = coro_spec.bucket
+
+        # Get bucket-specific limits if available
+        bucket_semaphore, bucket_rate_limiter = _get_bucket_limits(
+            bucket, bucket_semaphores, bucket_rate_limiters
+        )
+
         async def executor() -> T:
             # Create a fresh coroutine for each attempt
             if isinstance(coro_spec, FuncTask):
@@ -198,7 +318,7 @@ async def gather_limited_async(
             elif callable(coro_spec):
                 coro = coro_spec()
             else:
-                # Direct coroutine
+                # Direct coroutine: only valid if retries disabled
                 coro = coro_spec
             return await coro
 
@@ -206,8 +326,10 @@ async def gather_limited_async(
             result = await _execute_with_retry(
                 executor,
                 retry_settings,
-
-
+                global_semaphore,
+                global_rate_limiter,
+                bucket_semaphore,
+                bucket_rate_limiter,
                 global_retry_counter,
                 status,
                 task_id,
````
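In the hunks above, the old per-call throttling keywords are replaced by the keyword-only `limit: Limit | None` plus an optional `bucket_limits` mapping (the old `max_concurrent=` / `max_rps=` keywords are still visible in the removed tests at the end of this diff). A hedged sketch of what a caller update looks like under the 0.3.24 API; `fetch_url` is a stand-in, not part of the package:

```python
import asyncio
from kash.utils.api_utils.gather_limited import FuncTask, Limit, gather_limited_async

async def fetch_url(url: str) -> str:
    # Stand-in for a real fetcher; any coroutine function works here.
    await asyncio.sleep(0.01)
    return url

async def main() -> None:
    # 0.3.22-style keywords (max_concurrent=2, max_rps=10.0) become a single Limit;
    # per-bucket limits are opted into via FuncTask(..., bucket=...) plus bucket_limits.
    await gather_limited_async(
        lambda: fetch_url("http://example.com"),
        FuncTask(fetch_url, ("http://example.org",), bucket="web"),
        limit=Limit(rps=10.0, concurrency=2),
        bucket_limits={"web": Limit(rps=5.0, concurrency=1)},
    )

asyncio.run(main())
```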
````diff
@@ -234,8 +356,8 @@
 @overload
 async def gather_limited_sync(
     *sync_specs: SyncSpec[T],
-
-
+    limit: Limit | None,
+    bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = False,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
@@ -248,8 +370,8 @@
 @overload
 async def gather_limited_sync(
     *sync_specs: SyncSpec[T],
-
-
+    limit: Limit | None,
+    bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = True,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
@@ -261,9 +383,9 @@
 
 async def gather_limited_sync(
     *sync_specs: SyncSpec[T],
-
-
-    return_exceptions: bool =
+    limit: Limit | None,
+    bucket_limits: dict[str, Limit] | None = None,
+    return_exceptions: bool = True,  # Default to True for resilient batch operations
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
     labeler: SyncLabeler[T] | None = None,
@@ -271,54 +393,74 @@
     cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
 ) -> list[T] | list[T | BaseException]:
     """
-
-
+    Sync version of `gather_limited_async()` that executes synchronous functions with the same
+    rate limiting, retry logic, and progress tracking capabilities.
+    See `gather_limited_async()` for details.
 
-
-
-
+    Sync-specific differences:
+
+    **Function Execution**: Runs sync functions via `asyncio.to_thread()` instead of executing
+    coroutines directly. Validates that callables don't accidentally return coroutines.
 
-
+    **Cancellation Support**: Provides cooperative cancellation for long-running sync functions
+    through optional `cancel_event` and graceful thread termination.
+
+    **Input Validation**: Ensures sync functions don't return coroutines, which would indicate
+    incorrect usage (async functions should use `gather_limited_async()`).
 
     Args:
-        *sync_specs: Callables that return values (not coroutines) or FuncTask objects
-
-        max_rps: Maximum requests per second
-        return_exceptions: If True, exceptions are returned as results
-        retry_settings: Configuration for retry behavior, or None to disable retries
-        status: Optional ProgressTracker instance for progress display
-        labeler: Optional function to generate labels: labeler(index, spec) -> str
+        *sync_specs: Callables that return values (not coroutines) or FuncTask objects.
+            Functions can return `TaskResult[T]` to control rate limiting behavior.
         cancel_event: Optional threading.Event that will be set on cancellation
         cancel_timeout: Max seconds to wait for threads to terminate on cancellation
+        (All other args identical to gather_limited_async())
 
     Returns:
         List of results in the same order as input specifications
 
     Example:
     ```python
-    #
+    # Basic usage with sync functions
     results = await gather_limited_sync(
         lambda: some_sync_function(arg1),
         lambda: another_sync_function(arg2),
-
-        max_rps=2.0,
+        limit=Limit(rps=2.0, concurrency=3),
         retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
     )
 
-    # With
-    cancel_event = threading.Event()
+    # With bucket-specific limits and "*" fallback
     results = await gather_limited_sync(
-
-
-
-
+        FuncTask(fetch_from_api, ("data1",), bucket="api1"),
+        FuncTask(fetch_from_api, ("data2",), bucket="api2"),
+        FuncTask(fetch_from_api, ("data3",), bucket="api3"),
+        limit=Limit(rps=100, concurrency=50),
+        bucket_limits={
+            "api1": Limit(rps=20, concurrency=10),  # Specific limit for api1
+            "*": Limit(rps=15, concurrency=8),  # Fallback for api2, api3, and others
+        }
+    )
+
+    # With cache bypass using TaskResult
+    def sync_fetch_with_cache(url: str) -> TaskResult[dict] | dict:
+        cached_data = cache.get_sync(url)
+        if cached_data:
+            return TaskResult(cached_data, disable_limits=True)  # Cache hit: no rate limit
+
+        data = requests.get(url).json()  # Will be rate limited
+        cache.set_sync(url, data)
+        return data
+
+    results = await gather_limited_sync(
+        lambda: sync_fetch_with_cache("http://api.com/data1"),
+        lambda: sync_fetch_with_cache("http://api.com/data2"),
+        limit=Limit(rps=5, concurrency=3)  # Cache hits bypass these limits
    )
     ```
     """
     log.info(
-        "Executing with concurrency %s at %s rps, %s",
-
-
+        "Executing with global limits: concurrency %s at %s rps, %s",
+        limit.concurrency if limit else "None",
+        limit.rps if limit else "None",
         retry_settings,
     )
     if not sync_specs:
@@ -326,8 +468,18 @@
 
     retry_settings = retry_settings or NO_RETRIES
 
-
-
+    # Global limits (apply to all tasks regardless of bucket)
+    global_semaphore = asyncio.Semaphore(limit.concurrency) if limit else None
+    global_rate_limiter = AsyncLimiter(limit.rps, 1.0) if limit else None
+
+    # Per-bucket limits (if bucket_limits provided)
+    bucket_semaphores: dict[str, asyncio.Semaphore] = {}
+    bucket_rate_limiters: dict[str, AsyncLimiter] = {}
+
+    if bucket_limits:
+        for bucket_name, limit in bucket_limits.items():
+            bucket_semaphores[bucket_name] = asyncio.Semaphore(limit.concurrency)
+            bucket_rate_limiters[bucket_name] = AsyncLimiter(limit.rps, 1.0)
 
     # Global retry counter (shared across all tasks)
     global_retry_counter = RetryCounter(retry_settings.max_total_retries)
@@ -337,6 +489,16 @@
         label = labeler(i, sync_spec) if labeler else f"task:{i}"
         task_id = await status.add(label) if status else None
 
+        # Determine bucket and get appropriate limits
+        bucket = "default"
+        if isinstance(sync_spec, FuncTask):
+            bucket = sync_spec.bucket
+
+        # Get bucket-specific limits if available
+        bucket_semaphore, bucket_rate_limiter = _get_bucket_limits(
+            bucket, bucket_semaphores, bucket_rate_limiters
+        )
+
         async def executor() -> T:
             # Call sync function via asyncio.to_thread, handling retry at this level
             if isinstance(sync_spec, FuncTask):
@@ -361,8 +523,10 @@
             result = await _execute_with_retry(
                 executor,
                 retry_settings,
-
-
+                global_semaphore,
+                global_rate_limiter,
+                bucket_semaphore,
+                bucket_rate_limiter,
                 global_retry_counter,
                 status,
                 task_id,
@@ -434,7 +598,8 @@ async def _gather_with_interrupt_handling(
         if cancelled_count > 0:
             try:
                 await asyncio.wait_for(
-                    asyncio.gather(*async_tasks, return_exceptions=True),
+                    asyncio.gather(*async_tasks, return_exceptions=True),
+                    timeout=cancel_timeout,
                 )
             except (TimeoutError, asyncio.CancelledError):
                 log.warning("Some tasks did not cancel within timeout")
````
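`gather_limited_sync()` keeps its cooperative-cancellation parameters (`cancel_event`, `cancel_timeout`), and the hunk above adds `timeout=cancel_timeout` to the cancellation wait. A small sketch of a sync task written to honor the event, following the pattern the package's own (now removed) tests used; `cancellable_work` is illustrative, not part of the package:

```python
import threading
import time

cancel_event = threading.Event()

def cancellable_work(task_name: str, duration: float) -> str:
    # Poll the event between small units of work so the thread can stop early.
    deadline = time.time() + duration
    while time.time() < deadline:
        if cancel_event.is_set():
            return f"{task_name}: cancelled"
        time.sleep(0.01)
    return f"{task_name}: completed"

# Inside an async context:
#   results = await gather_limited_sync(
#       lambda: cancellable_work("task1", 0.5),
#       limit=Limit(rps=10, concurrency=2),
#       cancel_event=cancel_event,
#       cancel_timeout=1.0,
#   )
```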
````diff
@@ -462,20 +627,95 @@
 async def _execute_with_retry(
     executor: Callable[[], Coroutine[None, None, T]],
     retry_settings: RetrySettings,
-
-
+    global_semaphore: asyncio.Semaphore | None,
+    global_rate_limiter: AsyncLimiter | None,
+    bucket_semaphore: asyncio.Semaphore | None,
+    bucket_rate_limiter: AsyncLimiter | None,
     global_retry_counter: RetryCounter,
     status: ProgressTracker | None = None,
     task_id: Any | None = None,
 ) -> T:
-
+    """
+    Execute a task with retry logic, holding semaphores for the entire retry cycle.
+
+    Semaphores are held during backoff periods to prevent API flooding when tasks
+    are already backing off from rate limits. Only rate limiters are acquired/released
+    for each execution attempt.
+    """
+    # Acquire semaphores once for the entire retry cycle (including backoff periods)
+    if global_semaphore is not None:
+        async with global_semaphore:
+            if bucket_semaphore is not None:
+                async with bucket_semaphore:
+                    return await _execute_with_retry_inner(
+                        executor,
+                        retry_settings,
+                        global_rate_limiter,
+                        bucket_rate_limiter,
+                        global_retry_counter,
+                        status,
+                        task_id,
+                    )
+            else:
+                return await _execute_with_retry_inner(
+                    executor,
+                    retry_settings,
+                    global_rate_limiter,
+                    None,
+                    global_retry_counter,
+                    status,
+                    task_id,
+                )
+    else:
+        # No global semaphore, check bucket semaphore
+        if bucket_semaphore is not None:
+            async with bucket_semaphore:
+                return await _execute_with_retry_inner(
+                    executor,
+                    retry_settings,
+                    global_rate_limiter,
+                    bucket_rate_limiter,
+                    global_retry_counter,
+                    status,
+                    task_id,
+                )
+        else:
+            # No semaphores at all, go straight to running
+            return await _execute_with_retry_inner(
+                executor,
+                retry_settings,
+                global_rate_limiter,
+                None,
+                global_retry_counter,
+                status,
+                task_id,
+            )
+
+
+async def _execute_with_retry_inner(
+    executor: Callable[[], Coroutine[None, None, T]],
+    retry_settings: RetrySettings,
+    global_rate_limiter: AsyncLimiter | None,
+    bucket_rate_limiter: AsyncLimiter | None,
+    global_retry_counter: RetryCounter,
+    status: ProgressTracker | None = None,
+    task_id: Any | None = None,
+) -> T:
+    """
+    Inner retry logic that handles rate limiting and backoff.
+
+    This function assumes semaphores are already held by the caller and only
+    manages rate limiters for each execution attempt.
+    """
+    if status and task_id:
+        await status.start(task_id)
 
     start_time = time.time()
     last_exception: Exception | None = None
 
     for attempt in range(retry_settings.max_task_retries + 1):
-        # Handle backoff before acquiring
-        if attempt > 0 and last_exception
+        # Handle backoff before acquiring rate limiters (semaphores remain held)
+        if attempt > 0 and last_exception:
             # Try to increment global retry counter
             if not await global_retry_counter.try_increment():
                 log.error(
````
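The split above moves semaphore acquisition into `_execute_with_retry`, which holds the global slot (and then the bucket slot) for the whole retry cycle and delegates the per-attempt work to `_execute_with_retry_inner`. The nested `if`/`async with` branches all express the same acquisition order; purely as an illustration (not the package's code), that order can be written flat with `contextlib.AsyncExitStack`:

```python
import asyncio
from contextlib import AsyncExitStack
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")

async def with_held_semaphores(
    run_retry_loop: Callable[[], Awaitable[T]],
    global_semaphore: asyncio.Semaphore | None,
    bucket_semaphore: asyncio.Semaphore | None,
) -> T:
    # Global slot first, bucket slot second; both stay held across backoff sleeps.
    async with AsyncExitStack() as stack:
        if global_semaphore is not None:
            await stack.enter_async_context(global_semaphore)
        if bucket_semaphore is not None:
            await stack.enter_async_context(bucket_semaphore)
        return await run_retry_loop()
```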
````diff
@@ -493,13 +733,13 @@ async def _execute_with_retry(
                 )
 
             # Record retry in status display and log appropriately
-            if status and task_id
+            if status and task_id:
                 # Include retry attempt info and backoff time in the status display
                 retry_info = (
                     f"Attempt {attempt}/{retry_settings.max_task_retries} "
                     f"(waiting {backoff_time:.1f}s): {type(last_exception).__name__}: {last_exception}"
                 )
-                await status.update(task_id, error_msg=retry_info)
+                await status.update(task_id, TaskState.WAITING, error_msg=retry_info)
 
             # Use debug level for Rich trackers, warning/info for console trackers
             use_debug_level = status.suppress_logs
@@ -508,8 +748,11 @@ async def _execute_with_retry(
                 use_debug_level = False
 
             # Log retry information at appropriate level
+            status_code = extract_http_status_code(last_exception)
+            status_info = f" (HTTP {status_code})" if status_code else ""
+
             rate_limit_msg = (
-                f"Rate limit hit (attempt {attempt}/{retry_settings.max_task_retries} "
+                f"Rate limit hit{status_info} (attempt {attempt}/{retry_settings.max_task_retries} "
                 f"{global_retry_counter.count}/{global_retry_counter.max_total_retries or '∞'} total) "
                 f"backing off for {backoff_time:.2f}s"
             )
@@ -523,25 +766,57 @@ async def _execute_with_retry(
             else:
                 log.warning(rate_limit_msg)
                 log.info(exception_msg)
+
+            # Sleep during backoff while holding semaphore slots
             await asyncio.sleep(backoff_time)
 
+            # Set state back to running before next attempt
+            if status and task_id:
+                await status.update(task_id, TaskState.RUNNING)
+
         try:
-            #
-
-
-
-
-
+            # Execute task to check for potential rate limit bypass
+            raw_result = await executor()
+
+            # Check if result indicates limits should be disabled (e.g., cache hit)
+            if isinstance(raw_result, TaskResult):
+                if raw_result.disable_limits:
+                    # Bypass rate limiting and return immediately
+                    return cast(T, raw_result.value)
+                else:
+                    # Wrapped but limits enabled: extract value for rate limiting
+                    result_value = cast(T, raw_result.value)
+            else:
+                # Unwrapped result: apply normal rate limiting
+                result_value = cast(T, raw_result)
+
+            # Apply rate limiting for non-bypassed results
+            if global_rate_limiter is not None:
+                async with global_rate_limiter:
+                    if bucket_rate_limiter is not None:
+                        async with bucket_rate_limiter:
+                            return result_value
+                    else:
+                        return result_value
+            else:
+                # No global rate limiter, check bucket rate limiter
+                if bucket_rate_limiter is not None:
+                    async with bucket_rate_limiter:
+                        return result_value
+                else:
+                    # No rate limiting at all
+                    return result_value
+
         except Exception as e:
             last_exception = e  # Always store the exception
 
             if attempt == retry_settings.max_task_retries:
                 # Final attempt failed
                 if retry_settings.max_task_retries == 0:
-                    # No retries configured
+                    # No retries configured: raise original exception directly
                     raise
                 else:
-                    # Retries were attempted but exhausted
+                    # Retries were attempted but exhausted: wrap with context
                     total_time = time.time() - start_time
                     log.error(
                         f"Max task retries ({retry_settings.max_task_retries}) exhausted after {total_time:.1f}s. "
@@ -549,439 +824,18 @@ async def _execute_with_retry(
                     )
                     raise RetryExhaustedException(e, retry_settings.max_task_retries, total_time)
 
-            # Check if this is a retriable exception
-            if retry_settings.
-                # Continue to next retry attempt (
+            # Check if this is a retriable exception using the centralized logic
+            if retry_settings.should_retry(e):
+                # Continue to next retry attempt (semaphores remain held for backoff)
                 continue
             else:
                 # Non-retriable exception, log and re-raise immediately
-
+                status_code = extract_http_status_code(e)
+                status_info = f" (HTTP {status_code})" if status_code else ""
+
+                log.warning("Non-retriable exception%s (not retrying): %s", status_info, e)
+                log.debug("Exception traceback:", exc_info=True)
                 raise
 
     # This should never be reached, but satisfy type checker
-    raise RuntimeError("Unexpected code path in
-
-
-## Tests
-
-
-def test_gather_limited_sync():
-    """Test gather_limited_sync with sync functions."""
-    import asyncio
-    import time
-
-    async def run_test():
-        def sync_func(value: int) -> int:
-            """Simple sync function for testing."""
-            time.sleep(0.1)  # Simulate some work
-            return value * 2
-
-        # Test basic functionality
-        results = await gather_limited_sync(
-            lambda: sync_func(1),
-            lambda: sync_func(2),
-            lambda: sync_func(3),
-            max_concurrent=2,
-            max_rps=10.0,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert results == [2, 4, 6]
-
-    # Run the async test
-    asyncio.run(run_test())
-
-
-def test_gather_limited_sync_with_retries():
-    """Test that sync functions can be retried on retriable exceptions."""
-    import asyncio
-
-    async def run_test():
-        call_count = 0
-
-        def flaky_sync_func() -> str:
-            """Sync function that fails first time, succeeds second time."""
-            nonlocal call_count
-            call_count += 1
-            if call_count == 1:
-                raise Exception("Rate limit exceeded")  # Retriable
-            return "success"
-
-        # Should succeed after retry
-        results = await gather_limited_sync(
-            lambda: flaky_sync_func(),
-            retry_settings=RetrySettings(
-                max_task_retries=2,
-                initial_backoff=0.1,
-                max_backoff=1.0,
-                backoff_factor=2.0,
-            ),
-        )
-
-        assert results == ["success"]
-        assert call_count == 2  # Called twice (failed once, succeeded once)
-
-    # Run the async test
-    asyncio.run(run_test())
-
-
-def test_gather_limited_async_basic():
-    """Test gather_limited with async functions using callables."""
-    import asyncio
-
-    async def run_test():
-        async def async_func(value: int) -> int:
-            """Simple async function for testing."""
-            await asyncio.sleep(0.05)  # Simulate async work
-            return value * 3
-
-        # Test with callables (recommended pattern)
-        results = await gather_limited_async(
-            lambda: async_func(1),
-            lambda: async_func(2),
-            lambda: async_func(3),
-            max_concurrent=2,
-            max_rps=10.0,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert results == [3, 6, 9]
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_direct_coroutines():
-    """Test gather_limited with direct coroutines when retries disabled."""
-    import asyncio
-
-    async def run_test():
-        async def async_func(value: int) -> int:
-            await asyncio.sleep(0.05)
-            return value * 4
-
-        # Test with direct coroutines (only works when retries disabled)
-        results = await gather_limited_async(
-            async_func(1),
-            async_func(2),
-            async_func(3),
-            retry_settings=NO_RETRIES,  # Required for direct coroutines
-        )
-
-        assert results == [4, 8, 12]
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_coroutine_retry_validation():
-    """Test that passing coroutines with retries enabled raises ValueError."""
-    import asyncio
-
-    async def run_test():
-        async def async_func(value: int) -> int:
-            return value
-
-        coro = async_func(1)  # Direct coroutine
-
-        # Should raise ValueError when trying to use coroutines with retries
-        try:
-            await gather_limited_async(
-                coro,  # Direct coroutine
-                lambda: async_func(2),  # Callable
-                retry_settings=RetrySettings(
-                    max_task_retries=1,
-                    initial_backoff=0.1,
-                    max_backoff=1.0,
-                    backoff_factor=2.0,
-                ),
-            )
-            raise AssertionError("Expected ValueError")
-        except ValueError as e:
-            coro.close()  # Close the unused coroutine to prevent RuntimeWarning
-            assert "position 0" in str(e)
-            assert "cannot be retried" in str(e)
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_async_with_retries():
-    """Test that async functions can be retried when using callables."""
-    import asyncio
-
-    async def run_test():
-        call_count = 0
-
-        async def flaky_async_func() -> str:
-            """Async function that fails first time, succeeds second time."""
-            nonlocal call_count
-            call_count += 1
-            if call_count == 1:
-                raise Exception("Rate limit exceeded")  # Retriable
-            return "async_success"
-
-        # Should succeed after retry using callable
-        results = await gather_limited_async(
-            lambda: flaky_async_func(),
-            retry_settings=RetrySettings(
-                max_task_retries=2,
-                initial_backoff=0.1,
-                max_backoff=1.0,
-                backoff_factor=2.0,
-            ),
-        )
-
-        assert results == ["async_success"]
-        assert call_count == 2  # Called twice (failed once, succeeded once)
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_sync_coroutine_validation():
-    """Test that passing async function callables to sync version raises ValueError."""
-    import asyncio
-
-    async def run_test():
-        async def async_func(value: int) -> int:
-            return value
-
-        # Should raise ValueError when trying to use async functions in sync version
-        try:
-            await gather_limited_sync(
-                lambda: async_func(1),  # Returns coroutine - should be rejected
-                retry_settings=NO_RETRIES,
-            )
-            raise AssertionError("Expected ValueError")
-        except ValueError as e:
-            assert "returned a coroutine" in str(e)
-            assert "gather_limited_sync() is for synchronous functions only" in str(e)
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_retry_exhaustion():
-    """Test that retry exhaustion produces clear error messages."""
-    import asyncio
-
-    async def run_test():
-        call_count = 0
-
-        def always_fails() -> str:
-            """Function that always raises retriable exceptions."""
-            nonlocal call_count
-            call_count += 1
-            raise Exception("Rate limit exceeded")  # Always retriable
-
-        # Should exhaust retries and raise RetryExhaustedException
-        try:
-            await gather_limited_sync(
-                lambda: always_fails(),
-                retry_settings=RetrySettings(
-                    max_task_retries=2,
-                    initial_backoff=0.01,
-                    max_backoff=0.1,
-                    backoff_factor=2.0,
-                ),
-            )
-            raise AssertionError("Expected RetryExhaustedException")
-        except RetryExhaustedException as e:
-            assert "Max retries (2) exhausted" in str(e)
-            assert "Rate limit exceeded" in str(e)
-            assert isinstance(e.original_exception, Exception)
-            assert call_count == 3  # Initial attempt + 2 retries
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_return_exceptions():
-    """Test return_exceptions=True behavior for both functions."""
-    import asyncio
-
-    async def run_test():
-        def failing_sync() -> str:
-            raise ValueError("sync error")
-
-        async def failing_async() -> str:
-            raise ValueError("async error")
-
-        # Test sync version with exceptions returned
-        sync_results = await gather_limited_sync(
-            lambda: "success",
-            lambda: failing_sync(),
-            return_exceptions=True,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert len(sync_results) == 2
-        assert sync_results[0] == "success"
-        assert isinstance(sync_results[1], ValueError)
-        assert str(sync_results[1]) == "sync error"
-
-        async def success_async() -> str:
-            return "async_success"
-
-        # Test async version with exceptions returned
-        async_results = await gather_limited_async(
-            lambda: success_async(),
-            lambda: failing_async(),
-            return_exceptions=True,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert len(async_results) == 2
-        assert async_results[0] == "async_success"
-        assert isinstance(async_results[1], ValueError)
-        assert str(async_results[1]) == "async error"
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_global_retry_limit():
-    """Test that global retry limits are enforced across all tasks."""
-    import asyncio
-
-    async def run_test():
-        retry_counts = {"task1": 0, "task2": 0}
-
-        def flaky_task(task_name: str) -> str:
-            """Tasks that always fail but track retry counts."""
-            retry_counts[task_name] += 1
-            raise Exception(f"Rate limit exceeded in {task_name}")
-
-        # Test with very low global retry limit
-        try:
-            await gather_limited_sync(
-                lambda: flaky_task("task1"),
-                lambda: flaky_task("task2"),
-                retry_settings=RetrySettings(
-                    max_task_retries=5,  # Each task could retry up to 5 times
-                    max_total_retries=3,  # But only 3 total retries across all tasks
-                    initial_backoff=0.01,
-                    max_backoff=0.1,
-                    backoff_factor=2.0,
-                ),
-                return_exceptions=True,
-            )
-        except Exception:
-            pass  # Expected to fail due to rate limits
-
-        # Verify that total retries across both tasks doesn't exceed global limit
-        total_retries = (retry_counts["task1"] - 1) + (
-            retry_counts["task2"] - 1
-        )  # -1 for initial attempts
-        assert total_retries <= 3, f"Total retries {total_retries} exceeded global limit of 3"
-
-        # Verify that both tasks were attempted at least once
-        assert retry_counts["task1"] >= 1
-        assert retry_counts["task2"] >= 1
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_funcspec_format():
-    """Test gather_limited with FuncSpec format and custom labeler accessing args."""
-    import asyncio
-
-    async def run_test():
-        def sync_func(name: str, value: int, multiplier: int = 2) -> str:
-            """Sync function that takes args and kwargs."""
-            return f"{name}: {value * multiplier}"
-
-        async def async_func(name: str, value: int, multiplier: int = 2) -> str:
-            """Async function that takes args and kwargs."""
-            await asyncio.sleep(0.01)
-            return f"{name}: {value * multiplier}"
-
-        captured_labels = []
-
-        def custom_labeler(i: int, spec: Any) -> str:
-            if isinstance(spec, FuncTask):
-                # Extract meaningful info from args for labeling
-                if spec.args and len(spec.args) > 0:
-                    label = f"Processing {spec.args[0]}"
-                else:
-                    label = f"Task {i}"
-            else:
-                label = f"Task {i}"
-            captured_labels.append(label)
-            return label
-
-        # Test sync version with FuncSpec format and custom labeler
-        sync_results = await gather_limited_sync(
-            FuncTask(sync_func, ("user1", 100), {"multiplier": 3}),  # user1: 300
-            FuncTask(sync_func, ("user2", 50)),  # user2: 100 (default multiplier)
-            labeler=custom_labeler,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert sync_results == ["user1: 300", "user2: 100"]
-        assert captured_labels == ["Processing user1", "Processing user2"]
-
-        # Reset labels for async test
-        captured_labels.clear()
-
-        # Test async version with FuncSpec format and custom labeler
-        async_results = await gather_limited_async(
-            FuncTask(async_func, ("api_call", 10), {"multiplier": 4}),  # api_call: 40
-            FuncTask(async_func, ("data_fetch", 5)),  # data_fetch: 10 (default multiplier)
-            labeler=custom_labeler,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert async_results == ["api_call: 40", "data_fetch: 10"]
-        assert captured_labels == ["Processing api_call", "Processing data_fetch"]
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_sync_cooperative_cancellation():
-    """Test gather_limited_sync with cooperative cancellation via threading.Event."""
-    import asyncio
-    import time
-
-    async def run_test():
-        cancel_event = threading.Event()
-        call_counts = {"task1": 0, "task2": 0}
-
-        def cancellable_sync_func(task_name: str, work_duration: float) -> str:
-            """Sync function that checks cancellation event periodically."""
-            call_counts[task_name] += 1
-            start_time = time.time()
-
-            while time.time() - start_time < work_duration:
-                if cancel_event.is_set():
-                    return f"{task_name}: cancelled"
-                time.sleep(0.01)  # Small sleep to allow cancellation check
-
-            return f"{task_name}: completed"
-
-        # Test cooperative cancellation - tasks should respect the cancel_event
-        results = await gather_limited_sync(
-            lambda: cancellable_sync_func("task1", 0.1),  # Short duration
-            lambda: cancellable_sync_func("task2", 0.1),  # Short duration
-            cancel_event=cancel_event,
-            cancel_timeout=1.0,
-            retry_settings=NO_RETRIES,
-        )
-
-        # Should complete normally since cancel_event wasn't set
-        assert results == ["task1: completed", "task2: completed"]
-        assert call_counts["task1"] == 1
-        assert call_counts["task2"] == 1
-
-        # Test that cancel_event can be used independently
-        cancel_event.set()  # Set cancellation signal
-
-        results2 = await gather_limited_sync(
-            lambda: cancellable_sync_func("task1", 1.0),  # Would take long if not cancelled
-            lambda: cancellable_sync_func("task2", 1.0),  # Would take long if not cancelled
-            cancel_event=cancel_event,
-            cancel_timeout=1.0,
-            retry_settings=NO_RETRIES,
-        )
-
-        # Should be cancelled quickly since cancel_event is already set
-        assert results2 == ["task1: cancelled", "task2: cancelled"]
-        # Call counts should increment
-        assert call_counts["task1"] == 2
-        assert call_counts["task2"] == 2
-
-    asyncio.run(run_test())
+    raise RuntimeError("Unexpected code path in _execute_with_retry_inner")
````