kash-shell 0.3.22__py3-none-any.whl → 0.3.24__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
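The central API change in 0.3.24 is that `gather_limited_async()` and `gather_limited_sync()` no longer take `max_concurrent`/`max_rps` keywords; they take a global `Limit` plus optional per-bucket limits, and tasks gain a `bucket` field via `FuncTask`. A minimal sketch of the new call style (the `fetch_url` coroutine is a placeholder, and the import is assumed to come from the module shown in this diff, whose path is not named here):

```python
import asyncio

# Assumed import; the module path is not shown in this diff:
# from kash.utils.api_utils... import FuncTask, Limit, gather_limited_async

async def fetch_url(url: str) -> str:  # placeholder for a real fetch
    await asyncio.sleep(0.01)
    return url

async def main() -> None:
    results = await gather_limited_async(
        FuncTask(fetch_url, ("https://example.com/a",), bucket="example"),
        FuncTask(fetch_url, ("https://example.com/b",), bucket="example"),
        limit=Limit(rps=5.0, concurrency=5),  # global cap (replaces max_rps/max_concurrent)
        bucket_limits={"example": Limit(rps=2.0, concurrency=2)},
    )
    print(results)

asyncio.run(main())
```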
@@ -4,9 +4,10 @@ import asyncio
  import inspect
  import logging
  import threading
+ import time
  from collections.abc import Callable, Coroutine
  from dataclasses import dataclass, field
- from typing import Any, TypeAlias, TypeVar, cast, overload
+ from typing import Any, Generic, TypeAlias, TypeVar, cast, overload

  from aiolimiter import AsyncLimiter

@@ -16,6 +17,7 @@ from kash.utils.api_utils.api_retries import (
      RetryExhaustedException,
      RetrySettings,
      calculate_backoff,
+     extract_http_status_code,
  )
  from kash.utils.api_utils.progress_protocol import Labeler, ProgressTracker, TaskState

@@ -25,28 +27,82 @@ log = logging.getLogger(__name__)


  @dataclass(frozen=True)
- class FuncTask:
+ class Limit:
+     """
+     Rate limiting configuration with max RPS and max concurrency.
+     """
+
+     rps: float
+     concurrency: int
+
+
+ DEFAULT_CANCEL_TIMEOUT: float = 1.0
+
+
+ @dataclass(frozen=True)
+ class TaskResult(Generic[T]):
+     """
+     Optional wrapper for task results that can signal rate limiting behavior.
+     Use this to wrap results that should bypass rate limiting (e.g., cache hits).
+     """
+
+     value: T
+     disable_limits: bool = False
+
+
+ @dataclass(frozen=True)
+ class FuncTask(Generic[T]):
      """
      A task described as an unevaluated function with args and kwargs.
      This task format allows you to use args and kwargs in the Labeler.
+     It also allows specifying a bucket for rate limiting.
+
+     For async functions: The function should return a coroutine that yields either `T` or `TaskResult[T]`.
+     For sync functions: The function should return either `T` or `TaskResult[T]` directly.
+
+     Using `TaskResult[T]` allows controlling rate limiting behavior (e.g., cache hits can bypass rate limits).
      """

-     func: Callable[..., Any]
+     func: Callable[..., Any]  # Keep as Any since it can be sync or async
      args: tuple[Any, ...] = ()
      kwargs: dict[str, Any] = field(default_factory=dict)
+     bucket: str = "default"


  # Type aliases for coroutine and sync specifications, including unevaluated function specs
- CoroSpec: TypeAlias = Callable[[], Coroutine[None, None, T]] | Coroutine[None, None, T] | FuncTask
- SyncSpec: TypeAlias = Callable[[], T] | FuncTask
+ CoroSpec: TypeAlias = (
+     Callable[[], Coroutine[None, None, T]] | Coroutine[None, None, T] | FuncTask[T]
+ )
+ SyncSpec: TypeAlias = Callable[[], T] | FuncTask[T]

  # Specific labeler types using the generic Labeler pattern
  CoroLabeler: TypeAlias = Labeler[CoroSpec[T]]
  SyncLabeler: TypeAlias = Labeler[SyncSpec[T]]

- DEFAULT_MAX_CONCURRENT: int = 5
- DEFAULT_MAX_RPS: float = 5.0
- DEFAULT_CANCEL_TIMEOUT: float = 1.0
+
+ def _get_bucket_limits(
+     bucket: str,
+     bucket_semaphores: dict[str, asyncio.Semaphore],
+     bucket_rate_limiters: dict[str, AsyncLimiter],
+ ) -> tuple[asyncio.Semaphore | None, AsyncLimiter | None]:
+     """
+     Get bucket-specific limits with fallback to "*" wildcard.
+
+     Checks for exact bucket match first, then falls back to "*" if available.
+     Returns (None, None) if neither exact match nor "*" fallback exists.
+     """
+     # Try exact bucket match first
+     bucket_semaphore = bucket_semaphores.get(bucket)
+     bucket_rate_limiter = bucket_rate_limiters.get(bucket)
+
+     if bucket_semaphore is not None and bucket_rate_limiter is not None:
+         return bucket_semaphore, bucket_rate_limiter
+
+     # Fall back to "*" wildcard if available
+     bucket_semaphore = bucket_semaphores.get("*")
+     bucket_rate_limiter = bucket_rate_limiters.get("*")
+
+     return bucket_semaphore, bucket_rate_limiter


  class RetryCounter:
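The wildcard fallback implemented by `_get_bucket_limits()` can be illustrated in isolation. A small sketch of the lookup rules, assuming the helper defined above is in scope and using only `asyncio` and `aiolimiter` objects:

```python
import asyncio

from aiolimiter import AsyncLimiter

# An exact bucket match wins; otherwise the "*" entry (if present) is used;
# with no match and no "*", both values come back as None.
semaphores = {"api1": asyncio.Semaphore(10), "*": asyncio.Semaphore(8)}
limiters = {"api1": AsyncLimiter(20, 1.0), "*": AsyncLimiter(15, 1.0)}

assert _get_bucket_limits("api1", semaphores, limiters)[0] is semaphores["api1"]
assert _get_bucket_limits("api2", semaphores, limiters)[0] is semaphores["*"]
assert _get_bucket_limits("api2", {}, {}) == (None, None)
```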
@@ -75,8 +131,8 @@ class RetryCounter:
  @overload
  async def gather_limited_async(
      *coro_specs: CoroSpec[T],
-     max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-     max_rps: float = DEFAULT_MAX_RPS,
+     limit: Limit | None,
+     bucket_limits: dict[str, Limit] | None = None,
      return_exceptions: bool = False,
      retry_settings: RetrySettings | None = DEFAULT_RETRIES,
      status: ProgressTracker | None = None,
@@ -87,8 +143,8 @@ async def gather_limited_async(
  @overload
  async def gather_limited_async(
      *coro_specs: CoroSpec[T],
-     max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-     max_rps: float = DEFAULT_MAX_RPS,
+     limit: Limit | None,
+     bucket_limits: dict[str, Limit] | None = None,
      return_exceptions: bool = True,
      retry_settings: RetrySettings | None = DEFAULT_RETRIES,
      status: ProgressTracker | None = None,
@@ -98,27 +154,42 @@

  async def gather_limited_async(
      *coro_specs: CoroSpec[T],
-     max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-     max_rps: float = DEFAULT_MAX_RPS,
-     return_exceptions: bool = False,
+     limit: Limit | None,
+     bucket_limits: dict[str, Limit] | None = None,
+     return_exceptions: bool = True,  # Default to True for resilient batch operations
      retry_settings: RetrySettings | None = DEFAULT_RETRIES,
      status: ProgressTracker | None = None,
      labeler: CoroLabeler[T] | None = None,
  ) -> list[T] | list[T | BaseException]:
      """
-     Rate-limited version of `asyncio.gather()` with retry logic and optional progress display.
+     Rate-limited version of `asyncio.gather()` with HTTP-aware retry logic and optional progress display.
      Uses the aiolimiter leaky-bucket algorithm with exponential backoff on failures.

      Supports two levels of retry limits:
      - Per-task retries: max_task_retries attempts per individual task
      - Global retries: max_total_retries attempts across all tasks (prevents cascade failures)

+     Features HTTP-aware retry classification:
+     - Automatically detects HTTP status codes (403, 429, 500, etc.) and applies appropriate retry behavior
+     - Configurable handling of conditional status codes like 403 Forbidden
+     - Defaults to return_exceptions=True for resilient batch operations
+
+     Prevents API flooding during rate limit backoffs:
+     - Semaphore slots are held during entire retry cycles, including backoff periods
+     - New tasks cannot start while existing tasks are backing off from rate limits
+     - Rate limiters are only acquired during actual execution attempts
+
      Can optionally display live progress with retry indicators using TaskStatus.

      Accepts:
      - Callables that return coroutines: `lambda: some_async_func(arg)` (recommended for retries)
      - Coroutines directly: `some_async_func(arg)` (only if retries disabled)
-     - FuncSpec objects: `FuncSpec(some_async_func, (arg1, arg2), {"kwarg": value})` (args accessible to labeler)
+     - FuncTask objects: `FuncTask(some_async_func, (arg1, arg2), {"kwarg": value})` (args accessible to labeler)
+
+     Functions can return `TaskResult[T]` to bypass rate limiting for specific results:
+     - `TaskResult(value, disable_limits=True)`: bypass rate limiting (e.g., cache hits)
+     - `TaskResult(value, disable_limits=False)`: apply normal rate limiting
+     - `value` directly: apply normal rate limiting

      Examples:
      ```python
@@ -126,7 +197,7 @@ async def gather_limited_async(
      from kash.utils.rich_custom.task_status import TaskStatus

      async with TaskStatus() as status:
-         await gather_limited(
+         await gather_limited_async(
              lambda: fetch_url("http://example.com"),
              lambda: process_data(data),
              status=status,
@@ -135,18 +206,47 @@ async def gather_limited_async(
          )

      # Without progress display:
-     await gather_limited(
+     await gather_limited_async(
          lambda: fetch_url("http://example.com"),
          lambda: process_data(data),
          retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
      )

+     # With bucket-specific limits and "*" fallback:
+     await gather_limited_async(
+         FuncTask(fetch_api, ("data1",), bucket="api1"),
+         FuncTask(fetch_api, ("data2",), bucket="api2"),
+         FuncTask(fetch_api, ("data3",), bucket="api3"),
+         limit=Limit(rps=100, concurrency=50),
+         bucket_limits={
+             "api1": Limit(rps=20, concurrency=10),  # Specific limit for api1
+             "*": Limit(rps=15, concurrency=8),  # Fallback for api2, api3, and others
+         }
+     )
+
+     # With cache bypass using TaskResult:
+     async def fetch_with_cache(url: str) -> TaskResult[dict] | dict:
+         cached_data = await cache.get(url)
+         if cached_data:
+             return TaskResult(cached_data, disable_limits=True)  # Cache hit: no rate limit
+
+         data = await fetch_api(url)  # Will be rate limited
+         await cache.set(url, data)
+         return data
+
+     await gather_limited_async(
+         lambda: fetch_with_cache("http://api.com/data1"),
+         lambda: fetch_with_cache("http://api.com/data2"),
+         limit=Limit(rps=10, concurrency=5)  # Cache hits bypass these limits
+     )
+
      ```

      Args:
          *coro_specs: Callables or coroutines to execute
-         max_concurrent: Maximum number of concurrent executions
-         max_rps: Maximum requests per second
+         limit: Global limits applied to all tasks regardless of bucket
+         bucket_limits: Optional per-bucket limits. Tasks use their bucket field to determine limits.
+             Use "*" as a fallback limit for buckets without specific limits.
          return_exceptions: If True, exceptions are returned as results
          retry_settings: Configuration for retry behavior, or None to disable retries
          status: Optional ProgressTracker instance for progress display
@@ -159,9 +259,9 @@ async def gather_limited_async(
          ValueError: If coroutines are passed when retries are enabled
      """
      log.info(
-         "Executing with concurrency %s at %s rps, %s",
-         max_concurrent,
-         max_rps,
+         "Executing with global limits: concurrency %s at %s rps, %s",
+         limit.concurrency if limit else "None",
+         limit.rps if limit else "None",
          retry_settings,
      )
      if not coro_specs:
@@ -179,8 +279,18 @@ async def gather_limited_async(
                  f"lambda: your_async_func(args) instead of your_async_func(args)"
              )

-     semaphore = asyncio.Semaphore(max_concurrent)
-     rate_limiter = AsyncLimiter(max_rps, 1.0)
+     # Global limits (apply to all tasks regardless of bucket)
+     global_semaphore = asyncio.Semaphore(limit.concurrency) if limit else None
+     global_rate_limiter = AsyncLimiter(limit.rps, 1.0) if limit else None
+
+     # Per-bucket limits (if bucket_limits provided)
+     bucket_semaphores: dict[str, asyncio.Semaphore] = {}
+     bucket_rate_limiters: dict[str, AsyncLimiter] = {}
+
+     if bucket_limits:
+         for bucket_name, limit in bucket_limits.items():
+             bucket_semaphores[bucket_name] = asyncio.Semaphore(limit.concurrency)
+             bucket_rate_limiters[bucket_name] = AsyncLimiter(limit.rps, 1.0)

      # Global retry counter (shared across all tasks)
      global_retry_counter = RetryCounter(retry_settings.max_total_retries)
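For readers not familiar with aiolimiter: the global and per-bucket objects built here compose simply by acquiring both, each `AsyncLimiter(rps, 1.0)` being a leaky bucket of `rps` acquisitions per second. A self-contained sketch of that layering (simplified: as the docstring above notes, the real code holds semaphores across retries and acquires limiters per attempt):

```python
import asyncio

from aiolimiter import AsyncLimiter

async def demo() -> list[int]:
    global_sem, global_limiter = asyncio.Semaphore(3), AsyncLimiter(5, 1.0)
    bucket_sem, bucket_limiter = asyncio.Semaphore(2), AsyncLimiter(2, 1.0)

    async def one_call(i: int) -> int:
        # Each task must pass both the global limits and its bucket's limits.
        async with global_sem, bucket_sem, global_limiter, bucket_limiter:
            await asyncio.sleep(0.01)  # stand-in for the real request
            return i

    return await asyncio.gather(*(one_call(i) for i in range(6)))

print(asyncio.run(demo()))
```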
@@ -190,6 +300,16 @@ async def gather_limited_async(
          label = labeler(i, coro_spec) if labeler else f"task:{i}"
          task_id = await status.add(label) if status else None

+         # Determine bucket and get appropriate limits
+         bucket = "default"
+         if isinstance(coro_spec, FuncTask):
+             bucket = coro_spec.bucket
+
+         # Get bucket-specific limits if available
+         bucket_semaphore, bucket_rate_limiter = _get_bucket_limits(
+             bucket, bucket_semaphores, bucket_rate_limiters
+         )
+
          async def executor() -> T:
              # Create a fresh coroutine for each attempt
              if isinstance(coro_spec, FuncTask):
@@ -198,7 +318,7 @@ async def gather_limited_async(
              elif callable(coro_spec):
                  coro = coro_spec()
              else:
-                 # Direct coroutine - only valid if retries disabled
+                 # Direct coroutine: only valid if retries disabled
                  coro = coro_spec
              return await coro

@@ -206,8 +326,10 @@ async def gather_limited_async(
          result = await _execute_with_retry(
              executor,
              retry_settings,
-             semaphore,
-             rate_limiter,
+             global_semaphore,
+             global_rate_limiter,
+             bucket_semaphore,
+             bucket_rate_limiter,
              global_retry_counter,
              status,
              task_id,
@@ -234,8 +356,8 @@
  @overload
  async def gather_limited_sync(
      *sync_specs: SyncSpec[T],
-     max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-     max_rps: float = DEFAULT_MAX_RPS,
+     limit: Limit | None,
+     bucket_limits: dict[str, Limit] | None = None,
      return_exceptions: bool = False,
      retry_settings: RetrySettings | None = DEFAULT_RETRIES,
      status: ProgressTracker | None = None,
@@ -248,8 +370,8 @@ async def gather_limited_sync(
  @overload
  async def gather_limited_sync(
      *sync_specs: SyncSpec[T],
-     max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-     max_rps: float = DEFAULT_MAX_RPS,
+     limit: Limit | None,
+     bucket_limits: dict[str, Limit] | None = None,
      return_exceptions: bool = True,
      retry_settings: RetrySettings | None = DEFAULT_RETRIES,
      status: ProgressTracker | None = None,
@@ -261,9 +383,9 @@ async def gather_limited_sync(

  async def gather_limited_sync(
      *sync_specs: SyncSpec[T],
-     max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-     max_rps: float = DEFAULT_MAX_RPS,
-     return_exceptions: bool = False,
+     limit: Limit | None,
+     bucket_limits: dict[str, Limit] | None = None,
+     return_exceptions: bool = True,  # Default to True for resilient batch operations
      retry_settings: RetrySettings | None = DEFAULT_RETRIES,
      status: ProgressTracker | None = None,
      labeler: SyncLabeler[T] | None = None,
@@ -271,54 +393,74 @@ async def gather_limited_sync(
      cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
  ) -> list[T] | list[T | BaseException]:
      """
-     Rate-limited version of `asyncio.gather()` for sync functions with retry logic.
-     Handles the asyncio.to_thread() boundary correctly for consistent exception propagation.
+     Sync version of `gather_limited_async()` that executes synchronous functions with the same
+     rate limiting, retry logic, and progress tracking capabilities.
+     See `gather_limited_async()` for details.

-     Supports two levels of retry limits:
-     - Per-task retries: max_task_retries attempts per individual task
-     - Global retries: max_total_retries attempts across all tasks
+     Sync-specific differences:
+
+     **Function Execution**: Runs sync functions via `asyncio.to_thread()` instead of executing
+     coroutines directly. Validates that callables don't accidentally return coroutines.

-     Supports cooperative cancellation and graceful thread termination on interruption.
+     **Cancellation Support**: Provides cooperative cancellation for long-running sync functions
+     through optional `cancel_event` and graceful thread termination.
+
+     **Input Validation**: Ensures sync functions don't return coroutines, which would indicate
+     incorrect usage (async functions should use `gather_limited_async()`).

      Args:
-         *sync_specs: Callables that return values (not coroutines) or FuncTask objects
-         max_concurrent: Maximum number of concurrent executions
-         max_rps: Maximum requests per second
-         return_exceptions: If True, exceptions are returned as results
-         retry_settings: Configuration for retry behavior, or None to disable retries
-         status: Optional ProgressTracker instance for progress display
-         labeler: Optional function to generate labels: labeler(index, spec) -> str
+         *sync_specs: Callables that return values (not coroutines) or FuncTask objects.
+             Functions can return `TaskResult[T]` to control rate limiting behavior.
          cancel_event: Optional threading.Event that will be set on cancellation
          cancel_timeout: Max seconds to wait for threads to terminate on cancellation
+         (All other args identical to gather_limited_async())

      Returns:
          List of results in the same order as input specifications

      Example:
      ```python
-     # Without cooperative cancellation
+     # Basic usage with sync functions
      results = await gather_limited_sync(
          lambda: some_sync_function(arg1),
          lambda: another_sync_function(arg2),
-         max_concurrent=3,
-         max_rps=2.0,
+         limit=Limit(rps=2.0, concurrency=3),
          retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
      )

-     # With cooperative cancellation
-     cancel_event = threading.Event()
+     # With bucket-specific limits and "*" fallback
      results = await gather_limited_sync(
-         lambda: cancellable_sync_function(cancel_event, arg1),
-         lambda: another_cancellable_function(cancel_event, arg2),
-         cancel_event=cancel_event,
-         cancel_timeout=5.0,
+         FuncTask(fetch_from_api, ("data1",), bucket="api1"),
+         FuncTask(fetch_from_api, ("data2",), bucket="api2"),
+         FuncTask(fetch_from_api, ("data3",), bucket="api3"),
+         limit=Limit(rps=100, concurrency=50),
+         bucket_limits={
+             "api1": Limit(rps=20, concurrency=10),  # Specific limit for api1
+             "*": Limit(rps=15, concurrency=8),  # Fallback for api2, api3, and others
+         }
+     )
+
+     # With cache bypass using TaskResult
+     def sync_fetch_with_cache(url: str) -> TaskResult[dict] | dict:
+         cached_data = cache.get_sync(url)
+         if cached_data:
+             return TaskResult(cached_data, disable_limits=True)  # Cache hit: no rate limit
+
+         data = requests.get(url).json()  # Will be rate limited
+         cache.set_sync(url, data)
+         return data
+
+     results = await gather_limited_sync(
+         lambda: sync_fetch_with_cache("http://api.com/data1"),
+         lambda: sync_fetch_with_cache("http://api.com/data2"),
+         limit=Limit(rps=5, concurrency=3)  # Cache hits bypass these limits
      )
      ```
      """
      log.info(
-         "Executing with concurrency %s at %s rps, %s",
-         max_concurrent,
-         max_rps,
+         "Executing with global limits: concurrency %s at %s rps, %s",
+         limit.concurrency if limit else "None",
+         limit.rps if limit else "None",
          retry_settings,
      )
      if not sync_specs:
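The docstring example of cooperative cancellation was dropped in this version, but the `cancel_event`/`cancel_timeout` parameters remain. A sketch of that usage against the new `Limit` API, assuming `gather_limited_sync` and `Limit` are in scope from this module; `poll_work` is a hypothetical helper:

```python
import asyncio
import threading
import time

def poll_work(cancel_event: threading.Event, name: str) -> str:
    # Long-running sync work that polls the event so it can stop early.
    deadline = time.time() + 2.0
    while time.time() < deadline:
        if cancel_event.is_set():
            return f"{name}: cancelled"
        time.sleep(0.01)
    return f"{name}: completed"

async def main() -> None:
    cancel_event = threading.Event()
    results = await gather_limited_sync(
        lambda: poll_work(cancel_event, "task1"),
        lambda: poll_work(cancel_event, "task2"),
        limit=Limit(rps=5.0, concurrency=2),
        cancel_event=cancel_event,
        cancel_timeout=5.0,
    )
    print(results)

asyncio.run(main())
```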
@@ -326,8 +468,18 @@ async def gather_limited_sync(

      retry_settings = retry_settings or NO_RETRIES

-     semaphore = asyncio.Semaphore(max_concurrent)
-     rate_limiter = AsyncLimiter(max_rps, 1.0)
+     # Global limits (apply to all tasks regardless of bucket)
+     global_semaphore = asyncio.Semaphore(limit.concurrency) if limit else None
+     global_rate_limiter = AsyncLimiter(limit.rps, 1.0) if limit else None
+
+     # Per-bucket limits (if bucket_limits provided)
+     bucket_semaphores: dict[str, asyncio.Semaphore] = {}
+     bucket_rate_limiters: dict[str, AsyncLimiter] = {}
+
+     if bucket_limits:
+         for bucket_name, limit in bucket_limits.items():
+             bucket_semaphores[bucket_name] = asyncio.Semaphore(limit.concurrency)
+             bucket_rate_limiters[bucket_name] = AsyncLimiter(limit.rps, 1.0)

      # Global retry counter (shared across all tasks)
      global_retry_counter = RetryCounter(retry_settings.max_total_retries)
@@ -337,6 +489,16 @@ async def gather_limited_sync(
          label = labeler(i, sync_spec) if labeler else f"task:{i}"
          task_id = await status.add(label) if status else None

+         # Determine bucket and get appropriate limits
+         bucket = "default"
+         if isinstance(sync_spec, FuncTask):
+             bucket = sync_spec.bucket
+
+         # Get bucket-specific limits if available
+         bucket_semaphore, bucket_rate_limiter = _get_bucket_limits(
+             bucket, bucket_semaphores, bucket_rate_limiters
+         )
+
          async def executor() -> T:
              # Call sync function via asyncio.to_thread, handling retry at this level
              if isinstance(sync_spec, FuncTask):
@@ -361,8 +523,10 @@ async def gather_limited_sync(
          result = await _execute_with_retry(
              executor,
              retry_settings,
-             semaphore,
-             rate_limiter,
+             global_semaphore,
+             global_rate_limiter,
+             bucket_semaphore,
+             bucket_rate_limiter,
              global_retry_counter,
              status,
              task_id,
@@ -434,7 +598,8 @@ async def _gather_with_interrupt_handling(
      if cancelled_count > 0:
          try:
              await asyncio.wait_for(
-                 asyncio.gather(*async_tasks, return_exceptions=True), timeout=cancel_timeout
+                 asyncio.gather(*async_tasks, return_exceptions=True),
+                 timeout=cancel_timeout,
              )
          except (TimeoutError, asyncio.CancelledError):
              log.warning("Some tasks did not cancel within timeout")
@@ -462,20 +627,95 @@
  async def _execute_with_retry(
      executor: Callable[[], Coroutine[None, None, T]],
      retry_settings: RetrySettings,
-     semaphore: asyncio.Semaphore,
-     rate_limiter: AsyncLimiter,
+     global_semaphore: asyncio.Semaphore | None,
+     global_rate_limiter: AsyncLimiter | None,
+     bucket_semaphore: asyncio.Semaphore | None,
+     bucket_rate_limiter: AsyncLimiter | None,
      global_retry_counter: RetryCounter,
      status: ProgressTracker | None = None,
      task_id: Any | None = None,
  ) -> T:
-     import time
+     """
+     Execute a task with retry logic, holding semaphores for the entire retry cycle.
+
+     Semaphores are held during backoff periods to prevent API flooding when tasks
+     are already backing off from rate limits. Only rate limiters are acquired/released
+     for each execution attempt.
+     """
+     # Acquire semaphores once for the entire retry cycle (including backoff periods)
+     if global_semaphore is not None:
+         async with global_semaphore:
+             if bucket_semaphore is not None:
+                 async with bucket_semaphore:
+                     return await _execute_with_retry_inner(
+                         executor,
+                         retry_settings,
+                         global_rate_limiter,
+                         bucket_rate_limiter,
+                         global_retry_counter,
+                         status,
+                         task_id,
+                     )
+             else:
+                 return await _execute_with_retry_inner(
+                     executor,
+                     retry_settings,
+                     global_rate_limiter,
+                     None,
+                     global_retry_counter,
+                     status,
+                     task_id,
+                 )
+     else:
+         # No global semaphore, check bucket semaphore
+         if bucket_semaphore is not None:
+             async with bucket_semaphore:
+                 return await _execute_with_retry_inner(
+                     executor,
+                     retry_settings,
+                     global_rate_limiter,
+                     bucket_rate_limiter,
+                     global_retry_counter,
+                     status,
+                     task_id,
+                 )
+         else:
+             # No semaphores at all, go straight to running
+             return await _execute_with_retry_inner(
+                 executor,
+                 retry_settings,
+                 global_rate_limiter,
+                 None,
+                 global_retry_counter,
+                 status,
+                 task_id,
+             )
+
+
+ async def _execute_with_retry_inner(
+     executor: Callable[[], Coroutine[None, None, T]],
+     retry_settings: RetrySettings,
+     global_rate_limiter: AsyncLimiter | None,
+     bucket_rate_limiter: AsyncLimiter | None,
+     global_retry_counter: RetryCounter,
+     status: ProgressTracker | None = None,
+     task_id: Any | None = None,
+ ) -> T:
+     """
+     Inner retry logic that handles rate limiting and backoff.
+
+     This function assumes semaphores are already held by the caller and only
+     manages rate limiters for each execution attempt.
+     """
+     if status and task_id:
+         await status.start(task_id)

      start_time = time.time()
      last_exception: Exception | None = None

      for attempt in range(retry_settings.max_task_retries + 1):
-         # Handle backoff before acquiring any resources
-         if attempt > 0 and last_exception is not None:
+         # Handle backoff before acquiring rate limiters (semaphores remain held)
+         if attempt > 0 and last_exception:
              # Try to increment global retry counter
              if not await global_retry_counter.try_increment():
                  log.error(
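The nested `if`/`async with` ladder in `_execute_with_retry` keeps the control flow explicit. An equivalent way to express "acquire whichever semaphores are non-None" is `contextlib.AsyncExitStack`, sketched here only for comparison; this is not the code the package uses:

```python
import asyncio
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
from typing import TypeVar

T = TypeVar("T")

async def run_with_optional_semaphores(
    run: Callable[[], Awaitable[T]],
    global_semaphore: asyncio.Semaphore | None,
    bucket_semaphore: asyncio.Semaphore | None,
) -> T:
    # Enter only the semaphores that are configured, global first, then bucket,
    # and release them in reverse order once the task (and its retries) finish.
    async with AsyncExitStack() as stack:
        for sem in (global_semaphore, bucket_semaphore):
            if sem is not None:
                await stack.enter_async_context(sem)
        return await run()
```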
@@ -493,13 +733,13 @@ async def _execute_with_retry(
              )

              # Record retry in status display and log appropriately
-             if status and task_id is not None:
+             if status and task_id:
                  # Include retry attempt info and backoff time in the status display
                  retry_info = (
                      f"Attempt {attempt}/{retry_settings.max_task_retries} "
                      f"(waiting {backoff_time:.1f}s): {type(last_exception).__name__}: {last_exception}"
                  )
-                 await status.update(task_id, error_msg=retry_info)
+                 await status.update(task_id, TaskState.WAITING, error_msg=retry_info)

                  # Use debug level for Rich trackers, warning/info for console trackers
                  use_debug_level = status.suppress_logs
@@ -508,8 +748,11 @@ async def _execute_with_retry(
                  use_debug_level = False

              # Log retry information at appropriate level
+             status_code = extract_http_status_code(last_exception)
+             status_info = f" (HTTP {status_code})" if status_code else ""
+
              rate_limit_msg = (
-                 f"Rate limit hit (attempt {attempt}/{retry_settings.max_task_retries} "
+                 f"Rate limit hit{status_info} (attempt {attempt}/{retry_settings.max_task_retries} "
                  f"{global_retry_counter.count}/{global_retry_counter.max_total_retries or '∞'} total) "
                  f"backing off for {backoff_time:.2f}s"
              )
523
766
  else:
524
767
  log.warning(rate_limit_msg)
525
768
  log.info(exception_msg)
769
+
770
+ # Sleep during backoff while holding semaphore slots
526
771
  await asyncio.sleep(backoff_time)
527
772
 
773
+ # Set state back to running before next attempt
774
+ if status and task_id:
775
+ await status.update(task_id, TaskState.RUNNING)
776
+
528
777
  try:
529
- # Acquire semaphore and rate limiter right before making the call
530
- async with semaphore, rate_limiter:
531
- # Mark task as started now that we've passed rate limiting
532
- if status and task_id is not None and attempt == 0:
533
- await status.start(task_id)
534
- return await executor()
778
+ # Execute task to check for potential rate limit bypass
779
+ raw_result = await executor()
780
+
781
+ # Check if result indicates limits should be disabled (e.g., cache hit)
782
+ if isinstance(raw_result, TaskResult):
783
+ if raw_result.disable_limits:
784
+ # Bypass rate limiting and return immediately
785
+ return cast(T, raw_result.value)
786
+ else:
787
+ # Wrapped but limits enabled: extract value for rate limiting
788
+ result_value = cast(T, raw_result.value)
789
+ else:
790
+ # Unwrapped result: apply normal rate limiting
791
+ result_value = cast(T, raw_result)
792
+
793
+ # Apply rate limiting for non-bypassed results
794
+ if global_rate_limiter is not None:
795
+ async with global_rate_limiter:
796
+ if bucket_rate_limiter is not None:
797
+ async with bucket_rate_limiter:
798
+ return result_value
799
+ else:
800
+ return result_value
801
+ else:
802
+ # No global rate limiter, check bucket rate limiter
803
+ if bucket_rate_limiter is not None:
804
+ async with bucket_rate_limiter:
805
+ return result_value
806
+ else:
807
+ # No rate limiting at all
808
+ return result_value
809
+
535
810
  except Exception as e:
536
811
  last_exception = e # Always store the exception
537
812
 
538
813
  if attempt == retry_settings.max_task_retries:
539
814
  # Final attempt failed
540
815
  if retry_settings.max_task_retries == 0:
541
- # No retries configured - raise original exception directly
816
+ # No retries configured: raise original exception directly
542
817
  raise
543
818
  else:
544
- # Retries were attempted but exhausted - wrap with context
819
+ # Retries were attempted but exhausted: wrap with context
545
820
  total_time = time.time() - start_time
546
821
  log.error(
547
822
  f"Max task retries ({retry_settings.max_task_retries}) exhausted after {total_time:.1f}s. "
@@ -549,439 +824,18 @@ async def _execute_with_retry(
                      )
                      raise RetryExhaustedException(e, retry_settings.max_task_retries, total_time)

-             # Check if this is a retriable exception
-             if retry_settings.is_retriable(e):
-                 # Continue to next retry attempt (global limits will be checked at top of loop)
+             # Check if this is a retriable exception using the centralized logic
+             if retry_settings.should_retry(e):
+                 # Continue to next retry attempt (semaphores remain held for backoff)
                  continue
              else:
                  # Non-retriable exception, log and re-raise immediately
-                 log.warning("Non-retriable exception (not retrying): %s", e, exc_info=True)
+                 status_code = extract_http_status_code(e)
+                 status_info = f" (HTTP {status_code})" if status_code else ""
+
+                 log.warning("Non-retriable exception%s (not retrying): %s", status_info, e)
+                 log.debug("Exception traceback:", exc_info=True)
                  raise

      # This should never be reached, but satisfy type checker
-     raise RuntimeError("Unexpected code path in _execute_with_retry")
-
-
- ## Tests
-
-
- def test_gather_limited_sync():
-     """Test gather_limited_sync with sync functions."""
-     import asyncio
-     import time
-
-     async def run_test():
-         def sync_func(value: int) -> int:
-             """Simple sync function for testing."""
-             time.sleep(0.1)  # Simulate some work
-             return value * 2
-
-         # Test basic functionality
-         results = await gather_limited_sync(
-             lambda: sync_func(1),
-             lambda: sync_func(2),
-             lambda: sync_func(3),
-             max_concurrent=2,
-             max_rps=10.0,
-             retry_settings=NO_RETRIES,
-         )
-
-         assert results == [2, 4, 6]
-
-     # Run the async test
-     asyncio.run(run_test())
-
-
- def test_gather_limited_sync_with_retries():
-     """Test that sync functions can be retried on retriable exceptions."""
-     import asyncio
-
-     async def run_test():
-         call_count = 0
-
-         def flaky_sync_func() -> str:
-             """Sync function that fails first time, succeeds second time."""
-             nonlocal call_count
-             call_count += 1
-             if call_count == 1:
-                 raise Exception("Rate limit exceeded")  # Retriable
-             return "success"
-
-         # Should succeed after retry
-         results = await gather_limited_sync(
-             lambda: flaky_sync_func(),
-             retry_settings=RetrySettings(
-                 max_task_retries=2,
-                 initial_backoff=0.1,
-                 max_backoff=1.0,
-                 backoff_factor=2.0,
-             ),
-         )
-
-         assert results == ["success"]
-         assert call_count == 2  # Called twice (failed once, succeeded once)
-
-     # Run the async test
-     asyncio.run(run_test())
-
-
- def test_gather_limited_async_basic():
-     """Test gather_limited with async functions using callables."""
-     import asyncio
-
-     async def run_test():
-         async def async_func(value: int) -> int:
-             """Simple async function for testing."""
-             await asyncio.sleep(0.05)  # Simulate async work
-             return value * 3
-
-         # Test with callables (recommended pattern)
-         results = await gather_limited_async(
-             lambda: async_func(1),
-             lambda: async_func(2),
-             lambda: async_func(3),
-             max_concurrent=2,
-             max_rps=10.0,
-             retry_settings=NO_RETRIES,
-         )
-
-         assert results == [3, 6, 9]
-
-     asyncio.run(run_test())
-
-
- def test_gather_limited_direct_coroutines():
-     """Test gather_limited with direct coroutines when retries disabled."""
-     import asyncio
-
-     async def run_test():
-         async def async_func(value: int) -> int:
-             await asyncio.sleep(0.05)
-             return value * 4
-
-         # Test with direct coroutines (only works when retries disabled)
-         results = await gather_limited_async(
-             async_func(1),
-             async_func(2),
-             async_func(3),
-             retry_settings=NO_RETRIES,  # Required for direct coroutines
-         )
-
-         assert results == [4, 8, 12]
-
-     asyncio.run(run_test())
-
-
- def test_gather_limited_coroutine_retry_validation():
-     """Test that passing coroutines with retries enabled raises ValueError."""
-     import asyncio
-
-     async def run_test():
-         async def async_func(value: int) -> int:
-             return value
-
-         coro = async_func(1)  # Direct coroutine
-
-         # Should raise ValueError when trying to use coroutines with retries
-         try:
-             await gather_limited_async(
-                 coro,  # Direct coroutine
-                 lambda: async_func(2),  # Callable
-                 retry_settings=RetrySettings(
-                     max_task_retries=1,
-                     initial_backoff=0.1,
-                     max_backoff=1.0,
-                     backoff_factor=2.0,
-                 ),
-             )
-             raise AssertionError("Expected ValueError")
-         except ValueError as e:
-             coro.close()  # Close the unused coroutine to prevent RuntimeWarning
-             assert "position 0" in str(e)
-             assert "cannot be retried" in str(e)
-
-     asyncio.run(run_test())
-
-
- def test_gather_limited_async_with_retries():
-     """Test that async functions can be retried when using callables."""
-     import asyncio
-
-     async def run_test():
-         call_count = 0
-
-         async def flaky_async_func() -> str:
-             """Async function that fails first time, succeeds second time."""
-             nonlocal call_count
-             call_count += 1
-             if call_count == 1:
-                 raise Exception("Rate limit exceeded")  # Retriable
-             return "async_success"
-
-         # Should succeed after retry using callable
-         results = await gather_limited_async(
-             lambda: flaky_async_func(),
-             retry_settings=RetrySettings(
-                 max_task_retries=2,
-                 initial_backoff=0.1,
-                 max_backoff=1.0,
-                 backoff_factor=2.0,
-             ),
-         )
-
-         assert results == ["async_success"]
-         assert call_count == 2  # Called twice (failed once, succeeded once)
-
-     asyncio.run(run_test())
-
-
- def test_gather_limited_sync_coroutine_validation():
-     """Test that passing async function callables to sync version raises ValueError."""
-     import asyncio
-
-     async def run_test():
-         async def async_func(value: int) -> int:
-             return value
-
-         # Should raise ValueError when trying to use async functions in sync version
-         try:
-             await gather_limited_sync(
-                 lambda: async_func(1),  # Returns coroutine - should be rejected
-                 retry_settings=NO_RETRIES,
-             )
-             raise AssertionError("Expected ValueError")
-         except ValueError as e:
-             assert "returned a coroutine" in str(e)
-             assert "gather_limited_sync() is for synchronous functions only" in str(e)
-
-     asyncio.run(run_test())
-
-
- def test_gather_limited_retry_exhaustion():
-     """Test that retry exhaustion produces clear error messages."""
-     import asyncio
-
-     async def run_test():
-         call_count = 0
-
-         def always_fails() -> str:
-             """Function that always raises retriable exceptions."""
-             nonlocal call_count
-             call_count += 1
-             raise Exception("Rate limit exceeded")  # Always retriable
-
-         # Should exhaust retries and raise RetryExhaustedException
-         try:
-             await gather_limited_sync(
-                 lambda: always_fails(),
-                 retry_settings=RetrySettings(
-                     max_task_retries=2,
-                     initial_backoff=0.01,
-                     max_backoff=0.1,
-                     backoff_factor=2.0,
-                 ),
-             )
-             raise AssertionError("Expected RetryExhaustedException")
-         except RetryExhaustedException as e:
-             assert "Max retries (2) exhausted" in str(e)
-             assert "Rate limit exceeded" in str(e)
-             assert isinstance(e.original_exception, Exception)
-             assert call_count == 3  # Initial attempt + 2 retries
-
-     asyncio.run(run_test())
-
-
- def test_gather_limited_return_exceptions():
-     """Test return_exceptions=True behavior for both functions."""
-     import asyncio
-
-     async def run_test():
-         def failing_sync() -> str:
-             raise ValueError("sync error")
-
-         async def failing_async() -> str:
-             raise ValueError("async error")
-
-         # Test sync version with exceptions returned
-         sync_results = await gather_limited_sync(
-             lambda: "success",
-             lambda: failing_sync(),
-             return_exceptions=True,
-             retry_settings=NO_RETRIES,
-         )
-
-         assert len(sync_results) == 2
-         assert sync_results[0] == "success"
-         assert isinstance(sync_results[1], ValueError)
-         assert str(sync_results[1]) == "sync error"
-
-         async def success_async() -> str:
-             return "async_success"
-
-         # Test async version with exceptions returned
-         async_results = await gather_limited_async(
-             lambda: success_async(),
-             lambda: failing_async(),
-             return_exceptions=True,
-             retry_settings=NO_RETRIES,
-         )
-
-         assert len(async_results) == 2
-         assert async_results[0] == "async_success"
-         assert isinstance(async_results[1], ValueError)
-         assert str(async_results[1]) == "async error"
-
-     asyncio.run(run_test())
-
-
- def test_gather_limited_global_retry_limit():
-     """Test that global retry limits are enforced across all tasks."""
-     import asyncio
-
-     async def run_test():
-         retry_counts = {"task1": 0, "task2": 0}
-
-         def flaky_task(task_name: str) -> str:
-             """Tasks that always fail but track retry counts."""
-             retry_counts[task_name] += 1
-             raise Exception(f"Rate limit exceeded in {task_name}")
-
-         # Test with very low global retry limit
-         try:
-             await gather_limited_sync(
-                 lambda: flaky_task("task1"),
-                 lambda: flaky_task("task2"),
-                 retry_settings=RetrySettings(
-                     max_task_retries=5,  # Each task could retry up to 5 times
-                     max_total_retries=3,  # But only 3 total retries across all tasks
-                     initial_backoff=0.01,
-                     max_backoff=0.1,
-                     backoff_factor=2.0,
-                 ),
-                 return_exceptions=True,
-             )
-         except Exception:
-             pass  # Expected to fail due to rate limits
-
-         # Verify that total retries across both tasks doesn't exceed global limit
-         total_retries = (retry_counts["task1"] - 1) + (
-             retry_counts["task2"] - 1
-         )  # -1 for initial attempts
-         assert total_retries <= 3, f"Total retries {total_retries} exceeded global limit of 3"
-
-         # Verify that both tasks were attempted at least once
-         assert retry_counts["task1"] >= 1
-         assert retry_counts["task2"] >= 1
-
-     asyncio.run(run_test())
-
-
- def test_gather_limited_funcspec_format():
-     """Test gather_limited with FuncSpec format and custom labeler accessing args."""
-     import asyncio
-
-     async def run_test():
-         def sync_func(name: str, value: int, multiplier: int = 2) -> str:
-             """Sync function that takes args and kwargs."""
-             return f"{name}: {value * multiplier}"
-
-         async def async_func(name: str, value: int, multiplier: int = 2) -> str:
-             """Async function that takes args and kwargs."""
-             await asyncio.sleep(0.01)
-             return f"{name}: {value * multiplier}"
-
-         captured_labels = []
-
-         def custom_labeler(i: int, spec: Any) -> str:
-             if isinstance(spec, FuncTask):
-                 # Extract meaningful info from args for labeling
-                 if spec.args and len(spec.args) > 0:
-                     label = f"Processing {spec.args[0]}"
-                 else:
-                     label = f"Task {i}"
-             else:
-                 label = f"Task {i}"
-             captured_labels.append(label)
-             return label
-
-         # Test sync version with FuncSpec format and custom labeler
-         sync_results = await gather_limited_sync(
-             FuncTask(sync_func, ("user1", 100), {"multiplier": 3}),  # user1: 300
-             FuncTask(sync_func, ("user2", 50)),  # user2: 100 (default multiplier)
-             labeler=custom_labeler,
-             retry_settings=NO_RETRIES,
-         )
-
-         assert sync_results == ["user1: 300", "user2: 100"]
-         assert captured_labels == ["Processing user1", "Processing user2"]
-
-         # Reset labels for async test
-         captured_labels.clear()
-
-         # Test async version with FuncSpec format and custom labeler
-         async_results = await gather_limited_async(
-             FuncTask(async_func, ("api_call", 10), {"multiplier": 4}),  # api_call: 40
-             FuncTask(async_func, ("data_fetch", 5)),  # data_fetch: 10 (default multiplier)
-             labeler=custom_labeler,
-             retry_settings=NO_RETRIES,
-         )
-
-         assert async_results == ["api_call: 40", "data_fetch: 10"]
-         assert captured_labels == ["Processing api_call", "Processing data_fetch"]
-
-     asyncio.run(run_test())
-
-
- def test_gather_limited_sync_cooperative_cancellation():
-     """Test gather_limited_sync with cooperative cancellation via threading.Event."""
-     import asyncio
-     import time
-
-     async def run_test():
-         cancel_event = threading.Event()
-         call_counts = {"task1": 0, "task2": 0}
-
-         def cancellable_sync_func(task_name: str, work_duration: float) -> str:
-             """Sync function that checks cancellation event periodically."""
-             call_counts[task_name] += 1
-             start_time = time.time()
-
-             while time.time() - start_time < work_duration:
-                 if cancel_event.is_set():
-                     return f"{task_name}: cancelled"
-                 time.sleep(0.01)  # Small sleep to allow cancellation check
-
-             return f"{task_name}: completed"
-
-         # Test cooperative cancellation - tasks should respect the cancel_event
-         results = await gather_limited_sync(
-             lambda: cancellable_sync_func("task1", 0.1),  # Short duration
-             lambda: cancellable_sync_func("task2", 0.1),  # Short duration
-             cancel_event=cancel_event,
-             cancel_timeout=1.0,
-             retry_settings=NO_RETRIES,
-         )
-
-         # Should complete normally since cancel_event wasn't set
-         assert results == ["task1: completed", "task2: completed"]
-         assert call_counts["task1"] == 1
-         assert call_counts["task2"] == 1
-
-         # Test that cancel_event can be used independently
-         cancel_event.set()  # Set cancellation signal
-
-         results2 = await gather_limited_sync(
-             lambda: cancellable_sync_func("task1", 1.0),  # Would take long if not cancelled
-             lambda: cancellable_sync_func("task2", 1.0),  # Would take long if not cancelled
-             cancel_event=cancel_event,
-             cancel_timeout=1.0,
-             retry_settings=NO_RETRIES,
-         )
-
-         # Should be cancelled quickly since cancel_event is already set
-         assert results2 == ["task1: cancelled", "task2: cancelled"]
-         # Call counts should increment
-         assert call_counts["task1"] == 2
-         assert call_counts["task2"] == 2
-
-     asyncio.run(run_test())
+     raise RuntimeError("Unexpected code path in _execute_with_retry_inner")
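The inline `## Tests` section removed above exercised the old `max_concurrent`/`max_rps` keywords. Against the 0.3.24 API, the same basic check would be written with `Limit`, roughly as follows; this is a sketch assuming `gather_limited_sync`, `Limit`, and `NO_RETRIES` are imported from this module, and it does not imply where the tests now live in the package:

```python
import asyncio
import time

def test_gather_limited_sync_with_limit():
    """Basic check of gather_limited_sync using the new Limit-based API."""

    def sync_func(value: int) -> int:
        time.sleep(0.05)  # simulate some work
        return value * 2

    async def run_test():
        results = await gather_limited_sync(
            lambda: sync_func(1),
            lambda: sync_func(2),
            lambda: sync_func(3),
            limit=Limit(rps=10.0, concurrency=2),
            retry_settings=NO_RETRIES,
        )
        assert results == [2, 4, 6]

    asyncio.run(run_test())
```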