kash-shell 0.3.21__py3-none-any.whl → 0.3.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kash/actions/core/markdownify_html.py +11 -0
- kash/actions/core/tabbed_webpage_generate.py +2 -2
- kash/commands/help/assistant_commands.py +2 -4
- kash/commands/help/logo.py +12 -17
- kash/commands/help/welcome.py +5 -4
- kash/docs/markdown/warning.md +3 -3
- kash/docs/markdown/welcome.md +2 -1
- kash/exec/fetch_url_items.py +23 -13
- kash/exec/preconditions.py +7 -2
- kash/file_storage/file_store.py +3 -3
- kash/model/items_model.py +14 -11
- kash/shell/output/shell_output.py +8 -4
- kash/utils/api_utils/api_retries.py +335 -9
- kash/utils/api_utils/gather_limited.py +204 -488
- kash/utils/text_handling/markdown_utils.py +158 -1
- kash/web_content/web_extract.py +1 -1
- kash/web_gen/tabbed_webpage.py +2 -2
- kash/xonsh_custom/load_into_xonsh.py +0 -3
- {kash_shell-0.3.21.dist-info → kash_shell-0.3.23.dist-info}/METADATA +1 -1
- {kash_shell-0.3.21.dist-info → kash_shell-0.3.23.dist-info}/RECORD +23 -23
- {kash_shell-0.3.21.dist-info → kash_shell-0.3.23.dist-info}/WHEEL +0 -0
- {kash_shell-0.3.21.dist-info → kash_shell-0.3.23.dist-info}/entry_points.txt +0 -0
- {kash_shell-0.3.21.dist-info → kash_shell-0.3.23.dist-info}/licenses/LICENSE +0 -0
kash/utils/api_utils/gather_limited.py

@@ -16,6 +16,8 @@ from kash.utils.api_utils.api_retries import (
     RetryExhaustedException,
     RetrySettings,
     calculate_backoff,
+    extract_http_status_code,
+    is_http_status_retriable,
 )
 from kash.utils.api_utils.progress_protocol import Labeler, ProgressTracker, TaskState
 
@@ -23,17 +25,33 @@ T = TypeVar("T")
 
 log = logging.getLogger(__name__)
 
+DEFAULT_MAX_CONCURRENT: int = 5
+DEFAULT_MAX_RPS: float = 5.0
+DEFAULT_CANCEL_TIMEOUT: float = 1.0
+
+
+@dataclass(frozen=True)
+class Limit:
+    """
+    Rate limiting configuration with max RPS and max concurrency.
+    """
+
+    rps: float = DEFAULT_MAX_RPS
+    concurrency: int = DEFAULT_MAX_CONCURRENT
+
 
 @dataclass(frozen=True)
 class FuncTask:
     """
     A task described as an unevaluated function with args and kwargs.
     This task format allows you to use args and kwargs in the Labeler.
+    It also allows specifying a bucket for rate limiting.
     """
 
     func: Callable[..., Any]
     args: tuple[Any, ...] = ()
     kwargs: dict[str, Any] = field(default_factory=dict)
+    bucket: str = "default"
 
 
 # Type aliases for coroutine and sync specifications, including unevaluated function specs
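The hunk above introduces the two building blocks of the new rate-limiting API: a frozen `Limit` dataclass (defaulting to 5 rps and 5 concurrent) and a `bucket` tag on `FuncTask`. A minimal sketch of how the two compose, independent of kash (the `fetch` function, URLs, and bucket names are illustrative, not from the package):

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass(frozen=True)
class Limit:
    rps: float = 5.0
    concurrency: int = 5


@dataclass(frozen=True)
class FuncTask:
    func: Callable[..., Any]
    args: tuple[Any, ...] = ()
    kwargs: dict[str, Any] = field(default_factory=dict)
    bucket: str = "default"  # names the rate bucket this task draws from


async def fetch(url: str) -> str:  # illustrative stand-in for real work
    await asyncio.sleep(0.01)
    return url


# Tasks aimed at different services get different buckets; limits are
# looked up per bucket when each task runs.
tasks = [
    FuncTask(fetch, ("https://api1.example.com/a",), bucket="api1"),
    FuncTask(fetch, ("https://api2.example.com/b",), bucket="api2"),
]
bucket_limits = {"api1": Limit(rps=20, concurrency=10), "*": Limit(rps=15, concurrency=8)}
```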
@@ -44,9 +62,30 @@ SyncSpec: TypeAlias = Callable[[], T] | FuncTask
 CoroLabeler: TypeAlias = Labeler[CoroSpec[T]]
 SyncLabeler: TypeAlias = Labeler[SyncSpec[T]]
 
-
-
-
+
+def _get_bucket_limits(
+    bucket: str,
+    bucket_semaphores: dict[str, asyncio.Semaphore],
+    bucket_rate_limiters: dict[str, AsyncLimiter],
+) -> tuple[asyncio.Semaphore | None, AsyncLimiter | None]:
+    """
+    Get bucket-specific limits with fallback to "*" wildcard.
+
+    Checks for exact bucket match first, then falls back to "*" if available.
+    Returns (None, None) if neither exact match nor "*" fallback exists.
+    """
+    # Try exact bucket match first
+    bucket_semaphore = bucket_semaphores.get(bucket)
+    bucket_rate_limiter = bucket_rate_limiters.get(bucket)
+
+    if bucket_semaphore is not None and bucket_rate_limiter is not None:
+        return bucket_semaphore, bucket_rate_limiter
+
+    # Fall back to "*" wildcard if available
+    bucket_semaphore = bucket_semaphores.get("*")
+    bucket_rate_limiter = bucket_rate_limiters.get("*")
+
+    return bucket_semaphore, bucket_rate_limiter
 
 
 class RetryCounter:
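The `"*"` fallback above means a bucket only needs an explicit entry when it should differ from the wildcard. A standalone sketch of the same lookup pattern, assuming `aiolimiter` is installed (the dict contents are hypothetical):

```python
import asyncio

from aiolimiter import AsyncLimiter

semaphores = {"api1": asyncio.Semaphore(10), "*": asyncio.Semaphore(8)}
limiters = {"api1": AsyncLimiter(20, 1.0), "*": AsyncLimiter(15, 1.0)}


def lookup(bucket: str) -> tuple[asyncio.Semaphore | None, AsyncLimiter | None]:
    # An exact bucket match wins; otherwise fall back to the "*" wildcard,
    # which may itself be absent (yielding (None, None)).
    sem, lim = semaphores.get(bucket), limiters.get(bucket)
    if sem is not None and lim is not None:
        return sem, lim
    return semaphores.get("*"), limiters.get("*")


assert lookup("api1")[0] is semaphores["api1"]  # exact match
assert lookup("api2")[0] is semaphores["*"]  # wildcard fallback
```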
@@ -75,8 +114,8 @@ class RetryCounter:
 @overload
 async def gather_limited_async(
     *coro_specs: CoroSpec[T],
-
-
+    global_limit: Limit = Limit(),
+    bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = False,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
@@ -87,8 +126,8 @@ async def gather_limited_async(
 @overload
 async def gather_limited_async(
     *coro_specs: CoroSpec[T],
-
-
+    global_limit: Limit = Limit(),
+    bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = True,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
@@ -98,21 +137,26 @@ async def gather_limited_async(
 
 async def gather_limited_async(
     *coro_specs: CoroSpec[T],
-
-
-    return_exceptions: bool =
+    global_limit: Limit = Limit(),
+    bucket_limits: dict[str, Limit] | None = None,
+    return_exceptions: bool = True,  # Default to True for resilient batch operations
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
     labeler: CoroLabeler[T] | None = None,
 ) -> list[T] | list[T | BaseException]:
     """
-    Rate-limited version of `asyncio.gather()` with retry logic and optional progress display.
+    Rate-limited version of `asyncio.gather()` with HTTP-aware retry logic and optional progress display.
     Uses the aiolimiter leaky-bucket algorithm with exponential backoff on failures.
 
     Supports two levels of retry limits:
     - Per-task retries: max_task_retries attempts per individual task
     - Global retries: max_total_retries attempts across all tasks (prevents cascade failures)
 
+    Features HTTP-aware retry classification:
+    - Automatically detects HTTP status codes (403, 429, 500, etc.) and applies appropriate retry behavior
+    - Configurable handling of conditional status codes like 403 Forbidden
+    - Defaults to return_exceptions=True for resilient batch operations
+
     Can optionally display live progress with retry indicators using TaskStatus.
 
     Accepts:
@@ -126,7 +170,7 @@ async def gather_limited_async(
     from kash.utils.rich_custom.task_status import TaskStatus
 
     async with TaskStatus() as status:
-        await
+        await gather_limited_async(
             lambda: fetch_url("http://example.com"),
             lambda: process_data(data),
             status=status,
@@ -135,18 +179,31 @@
         )
 
     # Without progress display:
-    await
+    await gather_limited_async(
         lambda: fetch_url("http://example.com"),
         lambda: process_data(data),
         retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
     )
 
+    # With bucket-specific limits and "*" fallback:
+    await gather_limited_async(
+        FuncTask(fetch_api, ("data1",), bucket="api1"),
+        FuncTask(fetch_api, ("data2",), bucket="api2"),
+        FuncTask(fetch_api, ("data3",), bucket="api3"),
+        global_limit=Limit(rps=100, concurrency=50),
+        bucket_limits={
+            "api1": Limit(rps=20, concurrency=10),  # Specific limit for api1
+            "*": Limit(rps=15, concurrency=8),  # Fallback for api2, api3, and others
+        }
+    )
+
     ```
 
     Args:
         *coro_specs: Callables or coroutines to execute
-
-
+        global_limit: Global limits applied to all tasks regardless of bucket
+        bucket_limits: Optional per-bucket limits. Tasks use their bucket field to determine limits.
+            Use "*" as a fallback limit for buckets without specific limits.
         return_exceptions: If True, exceptions are returned as results
         retry_settings: Configuration for retry behavior, or None to disable retries
         status: Optional ProgressTracker instance for progress display
@@ -159,9 +216,9 @@ async def gather_limited_async(
         ValueError: If coroutines are passed when retries are enabled
     """
     log.info(
-        "Executing with concurrency %s at %s rps, %s",
-
-
+        "Executing with global limits: concurrency %s at %s rps, %s",
+        global_limit.concurrency,
+        global_limit.rps,
         retry_settings,
     )
     if not coro_specs:
@@ -179,8 +236,18 @@ async def gather_limited_async(
             f"lambda: your_async_func(args) instead of your_async_func(args)"
         )
 
-
-
+    # Global limits (apply to all tasks regardless of bucket)
+    global_semaphore = asyncio.Semaphore(global_limit.concurrency)
+    global_rate_limiter = AsyncLimiter(global_limit.rps, 1.0)
+
+    # Per-bucket limits (if bucket_limits provided)
+    bucket_semaphores: dict[str, asyncio.Semaphore] = {}
+    bucket_rate_limiters: dict[str, AsyncLimiter] = {}
+
+    if bucket_limits:
+        for bucket_name, limit in bucket_limits.items():
+            bucket_semaphores[bucket_name] = asyncio.Semaphore(limit.concurrency)
+            bucket_rate_limiters[bucket_name] = AsyncLimiter(limit.rps, 1.0)
 
     # Global retry counter (shared across all tasks)
     global_retry_counter = RetryCounter(retry_settings.max_total_retries)
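Note the division of labor in the setup above: the `asyncio.Semaphore` caps how many tasks are in flight at once, while the `AsyncLimiter` (aiolimiter's leaky bucket) caps how many task starts happen per second. A runnable sketch of that pairing in isolation, assuming `aiolimiter` is installed:

```python
import asyncio

from aiolimiter import AsyncLimiter


async def main() -> None:
    semaphore = asyncio.Semaphore(5)  # at most 5 tasks in flight
    rate_limiter = AsyncLimiter(5.0, 1.0)  # at most 5 task starts per second

    async def task(i: int) -> int:
        # Both limits must be satisfied before the work begins.
        async with semaphore, rate_limiter:
            await asyncio.sleep(0.1)
            return i

    results = await asyncio.gather(*(task(i) for i in range(20)))
    print(results)  # 20 tasks, paced at roughly 5 per second


asyncio.run(main())
```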
@@ -190,6 +257,16 @@ async def gather_limited_async(
         label = labeler(i, coro_spec) if labeler else f"task:{i}"
         task_id = await status.add(label) if status else None
 
+        # Determine bucket and get appropriate limits
+        bucket = "default"
+        if isinstance(coro_spec, FuncTask):
+            bucket = coro_spec.bucket
+
+        # Get bucket-specific limits if available
+        bucket_semaphore, bucket_rate_limiter = _get_bucket_limits(
+            bucket, bucket_semaphores, bucket_rate_limiters
+        )
+
         async def executor() -> T:
             # Create a fresh coroutine for each attempt
             if isinstance(coro_spec, FuncTask):
@@ -206,8 +283,10 @@ async def gather_limited_async(
             result = await _execute_with_retry(
                 executor,
                 retry_settings,
-
-
+                global_semaphore,
+                global_rate_limiter,
+                bucket_semaphore,
+                bucket_rate_limiter,
                 global_retry_counter,
                 status,
                 task_id,
@@ -234,8 +313,8 @@
 @overload
 async def gather_limited_sync(
     *sync_specs: SyncSpec[T],
-
-
+    global_limit: Limit = Limit(),
+    bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = False,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
@@ -248,8 +327,8 @@ async def gather_limited_sync(
 @overload
 async def gather_limited_sync(
     *sync_specs: SyncSpec[T],
-
-
+    global_limit: Limit = Limit(),
+    bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = True,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
@@ -261,9 +340,9 @@ async def gather_limited_sync(
 
 async def gather_limited_sync(
     *sync_specs: SyncSpec[T],
-
-
-    return_exceptions: bool =
+    global_limit: Limit = Limit(),
+    bucket_limits: dict[str, Limit] | None = None,
+    return_exceptions: bool = True,  # Default to True for resilient batch operations
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
     labeler: SyncLabeler[T] | None = None,
@@ -271,19 +350,25 @@ async def gather_limited_sync(
     cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
 ) -> list[T] | list[T | BaseException]:
     """
-    Rate-limited version of `asyncio.gather()` for sync functions with retry logic.
+    Rate-limited version of `asyncio.gather()` for sync functions with HTTP-aware retry logic.
     Handles the asyncio.to_thread() boundary correctly for consistent exception propagation.
 
     Supports two levels of retry limits:
     - Per-task retries: max_task_retries attempts per individual task
     - Global retries: max_total_retries attempts across all tasks
 
+    Features HTTP-aware retry classification:
+    - Automatically detects HTTP status codes (403, 429, 500, etc.) and applies appropriate retry behavior
+    - Configurable handling of conditional status codes like 403 Forbidden
+    - Defaults to return_exceptions=True for resilient batch operations
+
     Supports cooperative cancellation and graceful thread termination on interruption.
 
     Args:
         *sync_specs: Callables that return values (not coroutines) or FuncTask objects
-
-
+        global_limit: Global limits applied to all tasks regardless of bucket
+        bucket_limits: Optional per-bucket limits. Tasks use their bucket field to determine limits.
+            Use "*" as a fallback limit for buckets without specific limits.
         return_exceptions: If True, exceptions are returned as results
         retry_settings: Configuration for retry behavior, or None to disable retries
         status: Optional ProgressTracker instance for progress display
@@ -296,29 +381,31 @@ async def gather_limited_sync(
 
     Example:
     ```python
-    # Without
+    # Without bucket limits (backward compatible)
     results = await gather_limited_sync(
         lambda: some_sync_function(arg1),
         lambda: another_sync_function(arg2),
-
-        max_rps=2.0,
+        global_limit=Limit(rps=2.0, concurrency=3),
         retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
     )
 
-    # With
-    cancel_event = threading.Event()
+    # With bucket-specific limits and "*" fallback
     results = await gather_limited_sync(
-
-
-
-
+        FuncTask(fetch_from_api, ("data1",), bucket="api1"),
+        FuncTask(fetch_from_api, ("data2",), bucket="api2"),
+        FuncTask(fetch_from_api, ("data3",), bucket="api3"),
+        global_limit=Limit(rps=100, concurrency=50),
+        bucket_limits={
+            "api1": Limit(rps=20, concurrency=10),  # Specific limit for api1
+            "*": Limit(rps=15, concurrency=8),  # Fallback for api2, api3, and others
+        }
     )
     ```
     """
     log.info(
-        "Executing with concurrency %s at %s rps, %s",
-
-
+        "Executing with global limits: concurrency %s at %s rps, %s",
+        global_limit.concurrency,
+        global_limit.rps,
         retry_settings,
     )
     if not sync_specs:
@@ -326,8 +413,18 @@ async def gather_limited_sync(
 
     retry_settings = retry_settings or NO_RETRIES
 
-
-
+    # Global limits (apply to all tasks regardless of bucket)
+    global_semaphore = asyncio.Semaphore(global_limit.concurrency)
+    global_rate_limiter = AsyncLimiter(global_limit.rps, 1.0)
+
+    # Per-bucket limits (if bucket_limits provided)
+    bucket_semaphores: dict[str, asyncio.Semaphore] = {}
+    bucket_rate_limiters: dict[str, AsyncLimiter] = {}
+
+    if bucket_limits:
+        for bucket_name, limit in bucket_limits.items():
+            bucket_semaphores[bucket_name] = asyncio.Semaphore(limit.concurrency)
+            bucket_rate_limiters[bucket_name] = AsyncLimiter(limit.rps, 1.0)
 
     # Global retry counter (shared across all tasks)
     global_retry_counter = RetryCounter(retry_settings.max_total_retries)
@@ -337,6 +434,16 @@ async def gather_limited_sync(
         label = labeler(i, sync_spec) if labeler else f"task:{i}"
         task_id = await status.add(label) if status else None
 
+        # Determine bucket and get appropriate limits
+        bucket = "default"
+        if isinstance(sync_spec, FuncTask):
+            bucket = sync_spec.bucket
+
+        # Get bucket-specific limits if available
+        bucket_semaphore, bucket_rate_limiter = _get_bucket_limits(
+            bucket, bucket_semaphores, bucket_rate_limiters
+        )
+
         async def executor() -> T:
             # Call sync function via asyncio.to_thread, handling retry at this level
             if isinstance(sync_spec, FuncTask):
@@ -361,8 +468,10 @@ async def gather_limited_sync(
             result = await _execute_with_retry(
                 executor,
                 retry_settings,
-
-
+                global_semaphore,
+                global_rate_limiter,
+                bucket_semaphore,
+                bucket_rate_limiter,
                 global_retry_counter,
                 status,
                 task_id,
@@ -434,7 +543,8 @@ async def _gather_with_interrupt_handling(
     if cancelled_count > 0:
         try:
             await asyncio.wait_for(
-                asyncio.gather(*async_tasks, return_exceptions=True),
+                asyncio.gather(*async_tasks, return_exceptions=True),
+                timeout=cancel_timeout,
             )
         except (TimeoutError, asyncio.CancelledError):
            log.warning("Some tasks did not cancel within timeout")
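The `timeout=cancel_timeout` argument bounds how long cancelled tasks may take to unwind before the caller gives up on them. A self-contained sketch of the pattern (Python 3.11+, where `asyncio.TimeoutError` is an alias of `TimeoutError`; the sleep durations are illustrative):

```python
import asyncio


async def slow_to_cancel() -> None:
    try:
        await asyncio.sleep(10)
    except asyncio.CancelledError:
        await asyncio.sleep(5)  # simulate cleanup that outlasts the timeout
        raise


async def main() -> None:
    task = asyncio.create_task(slow_to_cancel())
    await asyncio.sleep(0.1)
    task.cancel()
    try:
        # Give cancelled tasks a bounded window to finish unwinding.
        await asyncio.wait_for(asyncio.gather(task, return_exceptions=True), timeout=1.0)
    except (TimeoutError, asyncio.CancelledError):
        print("Some tasks did not cancel within timeout")


asyncio.run(main())
```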
@@ -462,8 +572,10 @@ async def _gather_with_interrupt_handling(
 async def _execute_with_retry(
     executor: Callable[[], Coroutine[None, None, T]],
     retry_settings: RetrySettings,
-
-
+    global_semaphore: asyncio.Semaphore,
+    global_rate_limiter: AsyncLimiter,
+    bucket_semaphore: asyncio.Semaphore | None,
+    bucket_rate_limiter: AsyncLimiter | None,
     global_retry_counter: RetryCounter,
     status: ProgressTracker | None = None,
     task_id: Any | None = None,
@@ -473,9 +585,22 @@ async def _execute_with_retry(
     start_time = time.time()
     last_exception: Exception | None = None
 
+    # Create HTTP-aware is_retriable function based on settings
+    def is_retriable_for_task(exception: Exception) -> bool:
+        # First check the main is_retriable function
+        if not retry_settings.is_retriable(exception):
+            return False
+
+        status_code = extract_http_status_code(exception)
+        if status_code:
+            return is_http_status_retriable(status_code, retry_settings.http_retry_map)
+
+        # Not an HTTP error, use the main is_retriable result
+        return True
+
     for attempt in range(retry_settings.max_task_retries + 1):
         # Handle backoff before acquiring any resources
-        if attempt > 0 and last_exception
+        if attempt > 0 and last_exception:
             # Try to increment global retry counter
             if not await global_retry_counter.try_increment():
                 log.error(
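`extract_http_status_code` and `is_http_status_retriable` come from the expanded `api_retries` module (+335 lines in this release) and are not shown in this diff. A hypothetical sketch of the layering the closure above implements; the calls match the signatures used in the diff, but the internals here are stand-ins, not the package's actual logic:

```python
import re


def extract_http_status_code(exc: Exception) -> int | None:
    """Hypothetical stand-in: pull a 4xx/5xx code out of an exception message."""
    match = re.search(r"\b([45]\d{2})\b", str(exc))
    return int(match.group(1)) if match else None


def is_http_status_retriable(code: int, retry_map: dict[int, bool]) -> bool:
    """Hypothetical stand-in: explicit map entries win, then retry 429 and 5xx."""
    if code in retry_map:
        return retry_map[code]
    return code == 429 or code >= 500


http_retry_map = {403: True}  # e.g. treat 403 as retriable for a flaky CDN

assert is_http_status_retriable(429, http_retry_map)  # rate limited: retry
assert is_http_status_retriable(403, http_retry_map)  # opted in via the map
assert not is_http_status_retriable(404, http_retry_map)  # permanent: don't retry
```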
@@ -493,7 +618,7 @@ async def _execute_with_retry(
             )
 
             # Record retry in status display and log appropriately
-            if status and task_id
+            if status and task_id:
                 # Include retry attempt info and backoff time in the status display
                 retry_info = (
                     f"Attempt {attempt}/{retry_settings.max_task_retries} "
@@ -508,8 +633,11 @@ async def _execute_with_retry(
                 use_debug_level = False
 
             # Log retry information at appropriate level
+            status_code = extract_http_status_code(last_exception)
+            status_info = f" (HTTP {status_code})" if status_code else ""
+
             rate_limit_msg = (
-                f"Rate limit hit (attempt {attempt}/{retry_settings.max_task_retries} "
+                f"Rate limit hit{status_info} (attempt {attempt}/{retry_settings.max_task_retries} "
                 f"{global_retry_counter.count}/{global_retry_counter.max_total_retries or '∞'} total) "
                 f"backing off for {backoff_time:.2f}s"
             )
@@ -526,12 +654,20 @@ async def _execute_with_retry(
             await asyncio.sleep(backoff_time)
 
         try:
-            # Acquire
-            async with
-
-
-
-
+            # Acquire global limits first, then bucket-specific limits if present
+            async with global_semaphore, global_rate_limiter:
+                if bucket_semaphore is not None and bucket_rate_limiter is not None:
+                    async with bucket_semaphore, bucket_rate_limiter:
+                        # Mark task as started now that we've passed rate limiting
+                        if status and task_id is not None and attempt == 0:
+                            await status.start(task_id)
+                        return await executor()
+                else:
+                    # Mark task as started now that we've passed rate limiting
+                    if status and task_id is not None and attempt == 0:
+                        await status.start(task_id)
+                    return await executor()
+
         except Exception as e:
             last_exception = e  # Always store the exception
 
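Acquisition order matters in the nested `async with` above: every task takes the global limits first and the bucket limits second, so all tasks acquire in one consistent order and bucket limits simply subdivide the global budget. A minimal sketch of that ordering (names are illustrative):

```python
import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


async def run_limited(
    work: Callable[[], Awaitable[T]],
    global_sem: asyncio.Semaphore,
    bucket_sem: asyncio.Semaphore | None,
) -> T:
    # Global first, bucket second: with a single consistent order, no two
    # tasks can hold each other's semaphores in opposite orders (no deadlock).
    async with global_sem:
        if bucket_sem is not None:
            async with bucket_sem:
                return await work()
        return await work()


async def main() -> None:
    global_sem = asyncio.Semaphore(2)
    bucket_sem = asyncio.Semaphore(1)

    async def job(i: int) -> int:
        await asyncio.sleep(0.05)
        return i

    results = await asyncio.gather(
        *(run_limited(lambda i=i: job(i), global_sem, bucket_sem) for i in range(4))
    )
    print(results)  # [0, 1, 2, 3]


asyncio.run(main())
```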
@@ -549,439 +685,19 @@ async def _execute_with_retry(
                 )
                 raise RetryExhaustedException(e, retry_settings.max_task_retries, total_time)
 
-            # Check if this is a retriable exception
-            if
+            # Check if this is a retriable exception using our HTTP-aware function
+            if is_retriable_for_task(e):
                 # Continue to next retry attempt (global limits will be checked at top of loop)
                 continue
             else:
                 # Non-retriable exception, log and re-raise immediately
-
+                status_code = extract_http_status_code(e)
+                status_info = f" (HTTP {status_code})" if status_code else ""
+
+                log.warning(
+                    "Non-retriable exception%s (not retrying): %s", status_info, e, exc_info=True
+                )
                 raise
 
     # This should never be reached, but satisfy type checker
     raise RuntimeError("Unexpected code path in _execute_with_retry")
-
-
-## Tests
-
-
-def test_gather_limited_sync():
-    """Test gather_limited_sync with sync functions."""
-    import asyncio
-    import time
-
-    async def run_test():
-        def sync_func(value: int) -> int:
-            """Simple sync function for testing."""
-            time.sleep(0.1)  # Simulate some work
-            return value * 2
-
-        # Test basic functionality
-        results = await gather_limited_sync(
-            lambda: sync_func(1),
-            lambda: sync_func(2),
-            lambda: sync_func(3),
-            max_concurrent=2,
-            max_rps=10.0,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert results == [2, 4, 6]
-
-    # Run the async test
-    asyncio.run(run_test())
-
-
-def test_gather_limited_sync_with_retries():
-    """Test that sync functions can be retried on retriable exceptions."""
-    import asyncio
-
-    async def run_test():
-        call_count = 0
-
-        def flaky_sync_func() -> str:
-            """Sync function that fails first time, succeeds second time."""
-            nonlocal call_count
-            call_count += 1
-            if call_count == 1:
-                raise Exception("Rate limit exceeded")  # Retriable
-            return "success"
-
-        # Should succeed after retry
-        results = await gather_limited_sync(
-            lambda: flaky_sync_func(),
-            retry_settings=RetrySettings(
-                max_task_retries=2,
-                initial_backoff=0.1,
-                max_backoff=1.0,
-                backoff_factor=2.0,
-            ),
-        )
-
-        assert results == ["success"]
-        assert call_count == 2  # Called twice (failed once, succeeded once)
-
-    # Run the async test
-    asyncio.run(run_test())
-
-
-def test_gather_limited_async_basic():
-    """Test gather_limited with async functions using callables."""
-    import asyncio
-
-    async def run_test():
-        async def async_func(value: int) -> int:
-            """Simple async function for testing."""
-            await asyncio.sleep(0.05)  # Simulate async work
-            return value * 3
-
-        # Test with callables (recommended pattern)
-        results = await gather_limited_async(
-            lambda: async_func(1),
-            lambda: async_func(2),
-            lambda: async_func(3),
-            max_concurrent=2,
-            max_rps=10.0,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert results == [3, 6, 9]
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_direct_coroutines():
-    """Test gather_limited with direct coroutines when retries disabled."""
-    import asyncio
-
-    async def run_test():
-        async def async_func(value: int) -> int:
-            await asyncio.sleep(0.05)
-            return value * 4
-
-        # Test with direct coroutines (only works when retries disabled)
-        results = await gather_limited_async(
-            async_func(1),
-            async_func(2),
-            async_func(3),
-            retry_settings=NO_RETRIES,  # Required for direct coroutines
-        )
-
-        assert results == [4, 8, 12]
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_coroutine_retry_validation():
-    """Test that passing coroutines with retries enabled raises ValueError."""
-    import asyncio
-
-    async def run_test():
-        async def async_func(value: int) -> int:
-            return value
-
-        coro = async_func(1)  # Direct coroutine
-
-        # Should raise ValueError when trying to use coroutines with retries
-        try:
-            await gather_limited_async(
-                coro,  # Direct coroutine
-                lambda: async_func(2),  # Callable
-                retry_settings=RetrySettings(
-                    max_task_retries=1,
-                    initial_backoff=0.1,
-                    max_backoff=1.0,
-                    backoff_factor=2.0,
-                ),
-            )
-            raise AssertionError("Expected ValueError")
-        except ValueError as e:
-            coro.close()  # Close the unused coroutine to prevent RuntimeWarning
-            assert "position 0" in str(e)
-            assert "cannot be retried" in str(e)
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_async_with_retries():
-    """Test that async functions can be retried when using callables."""
-    import asyncio
-
-    async def run_test():
-        call_count = 0
-
-        async def flaky_async_func() -> str:
-            """Async function that fails first time, succeeds second time."""
-            nonlocal call_count
-            call_count += 1
-            if call_count == 1:
-                raise Exception("Rate limit exceeded")  # Retriable
-            return "async_success"
-
-        # Should succeed after retry using callable
-        results = await gather_limited_async(
-            lambda: flaky_async_func(),
-            retry_settings=RetrySettings(
-                max_task_retries=2,
-                initial_backoff=0.1,
-                max_backoff=1.0,
-                backoff_factor=2.0,
-            ),
-        )
-
-        assert results == ["async_success"]
-        assert call_count == 2  # Called twice (failed once, succeeded once)
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_sync_coroutine_validation():
-    """Test that passing async function callables to sync version raises ValueError."""
-    import asyncio
-
-    async def run_test():
-        async def async_func(value: int) -> int:
-            return value
-
-        # Should raise ValueError when trying to use async functions in sync version
-        try:
-            await gather_limited_sync(
-                lambda: async_func(1),  # Returns coroutine - should be rejected
-                retry_settings=NO_RETRIES,
-            )
-            raise AssertionError("Expected ValueError")
-        except ValueError as e:
-            assert "returned a coroutine" in str(e)
-            assert "gather_limited_sync() is for synchronous functions only" in str(e)
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_retry_exhaustion():
-    """Test that retry exhaustion produces clear error messages."""
-    import asyncio
-
-    async def run_test():
-        call_count = 0
-
-        def always_fails() -> str:
-            """Function that always raises retriable exceptions."""
-            nonlocal call_count
-            call_count += 1
-            raise Exception("Rate limit exceeded")  # Always retriable
-
-        # Should exhaust retries and raise RetryExhaustedException
-        try:
-            await gather_limited_sync(
-                lambda: always_fails(),
-                retry_settings=RetrySettings(
-                    max_task_retries=2,
-                    initial_backoff=0.01,
-                    max_backoff=0.1,
-                    backoff_factor=2.0,
-                ),
-            )
-            raise AssertionError("Expected RetryExhaustedException")
-        except RetryExhaustedException as e:
-            assert "Max retries (2) exhausted" in str(e)
-            assert "Rate limit exceeded" in str(e)
-            assert isinstance(e.original_exception, Exception)
-            assert call_count == 3  # Initial attempt + 2 retries
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_return_exceptions():
-    """Test return_exceptions=True behavior for both functions."""
-    import asyncio
-
-    async def run_test():
-        def failing_sync() -> str:
-            raise ValueError("sync error")
-
-        async def failing_async() -> str:
-            raise ValueError("async error")
-
-        # Test sync version with exceptions returned
-        sync_results = await gather_limited_sync(
-            lambda: "success",
-            lambda: failing_sync(),
-            return_exceptions=True,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert len(sync_results) == 2
-        assert sync_results[0] == "success"
-        assert isinstance(sync_results[1], ValueError)
-        assert str(sync_results[1]) == "sync error"
-
-        async def success_async() -> str:
-            return "async_success"
-
-        # Test async version with exceptions returned
-        async_results = await gather_limited_async(
-            lambda: success_async(),
-            lambda: failing_async(),
-            return_exceptions=True,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert len(async_results) == 2
-        assert async_results[0] == "async_success"
-        assert isinstance(async_results[1], ValueError)
-        assert str(async_results[1]) == "async error"
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_global_retry_limit():
-    """Test that global retry limits are enforced across all tasks."""
-    import asyncio
-
-    async def run_test():
-        retry_counts = {"task1": 0, "task2": 0}
-
-        def flaky_task(task_name: str) -> str:
-            """Tasks that always fail but track retry counts."""
-            retry_counts[task_name] += 1
-            raise Exception(f"Rate limit exceeded in {task_name}")
-
-        # Test with very low global retry limit
-        try:
-            await gather_limited_sync(
-                lambda: flaky_task("task1"),
-                lambda: flaky_task("task2"),
-                retry_settings=RetrySettings(
-                    max_task_retries=5,  # Each task could retry up to 5 times
-                    max_total_retries=3,  # But only 3 total retries across all tasks
-                    initial_backoff=0.01,
-                    max_backoff=0.1,
-                    backoff_factor=2.0,
-                ),
-                return_exceptions=True,
-            )
-        except Exception:
-            pass  # Expected to fail due to rate limits
-
-        # Verify that total retries across both tasks doesn't exceed global limit
-        total_retries = (retry_counts["task1"] - 1) + (
-            retry_counts["task2"] - 1
-        )  # -1 for initial attempts
-        assert total_retries <= 3, f"Total retries {total_retries} exceeded global limit of 3"
-
-        # Verify that both tasks were attempted at least once
-        assert retry_counts["task1"] >= 1
-        assert retry_counts["task2"] >= 1
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_funcspec_format():
-    """Test gather_limited with FuncSpec format and custom labeler accessing args."""
-    import asyncio
-
-    async def run_test():
-        def sync_func(name: str, value: int, multiplier: int = 2) -> str:
-            """Sync function that takes args and kwargs."""
-            return f"{name}: {value * multiplier}"
-
-        async def async_func(name: str, value: int, multiplier: int = 2) -> str:
-            """Async function that takes args and kwargs."""
-            await asyncio.sleep(0.01)
-            return f"{name}: {value * multiplier}"
-
-        captured_labels = []
-
-        def custom_labeler(i: int, spec: Any) -> str:
-            if isinstance(spec, FuncTask):
-                # Extract meaningful info from args for labeling
-                if spec.args and len(spec.args) > 0:
-                    label = f"Processing {spec.args[0]}"
-                else:
-                    label = f"Task {i}"
-            else:
-                label = f"Task {i}"
-            captured_labels.append(label)
-            return label
-
-        # Test sync version with FuncSpec format and custom labeler
-        sync_results = await gather_limited_sync(
-            FuncTask(sync_func, ("user1", 100), {"multiplier": 3}),  # user1: 300
-            FuncTask(sync_func, ("user2", 50)),  # user2: 100 (default multiplier)
-            labeler=custom_labeler,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert sync_results == ["user1: 300", "user2: 100"]
-        assert captured_labels == ["Processing user1", "Processing user2"]
-
-        # Reset labels for async test
-        captured_labels.clear()
-
-        # Test async version with FuncSpec format and custom labeler
-        async_results = await gather_limited_async(
-            FuncTask(async_func, ("api_call", 10), {"multiplier": 4}),  # api_call: 40
-            FuncTask(async_func, ("data_fetch", 5)),  # data_fetch: 10 (default multiplier)
-            labeler=custom_labeler,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert async_results == ["api_call: 40", "data_fetch: 10"]
-        assert captured_labels == ["Processing api_call", "Processing data_fetch"]
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_sync_cooperative_cancellation():
-    """Test gather_limited_sync with cooperative cancellation via threading.Event."""
-    import asyncio
-    import time
-
-    async def run_test():
-        cancel_event = threading.Event()
-        call_counts = {"task1": 0, "task2": 0}
-
-        def cancellable_sync_func(task_name: str, work_duration: float) -> str:
-            """Sync function that checks cancellation event periodically."""
-            call_counts[task_name] += 1
-            start_time = time.time()
-
-            while time.time() - start_time < work_duration:
-                if cancel_event.is_set():
-                    return f"{task_name}: cancelled"
-                time.sleep(0.01)  # Small sleep to allow cancellation check
-
-            return f"{task_name}: completed"
-
-        # Test cooperative cancellation - tasks should respect the cancel_event
-        results = await gather_limited_sync(
-            lambda: cancellable_sync_func("task1", 0.1),  # Short duration
-            lambda: cancellable_sync_func("task2", 0.1),  # Short duration
-            cancel_event=cancel_event,
-            cancel_timeout=1.0,
-            retry_settings=NO_RETRIES,
-        )
-
-        # Should complete normally since cancel_event wasn't set
-        assert results == ["task1: completed", "task2: completed"]
-        assert call_counts["task1"] == 1
-        assert call_counts["task2"] == 1
-
-        # Test that cancel_event can be used independently
-        cancel_event.set()  # Set cancellation signal
-
-        results2 = await gather_limited_sync(
-            lambda: cancellable_sync_func("task1", 1.0),  # Would take long if not cancelled
-            lambda: cancellable_sync_func("task2", 1.0),  # Would take long if not cancelled
-            cancel_event=cancel_event,
-            cancel_timeout=1.0,
-            retry_settings=NO_RETRIES,
-        )
-
-        # Should be cancelled quickly since cancel_event is already set
-        assert results2 == ["task1: cancelled", "task2: cancelled"]
-        # Call counts should increment
-        assert call_counts["task1"] == 2
-        assert call_counts["task2"] == 2
-
-    asyncio.run(run_test())