kash-shell 0.3.21__py3-none-any.whl → 0.3.23__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
@@ -16,6 +16,8 @@ from kash.utils.api_utils.api_retries import (
     RetryExhaustedException,
     RetrySettings,
     calculate_backoff,
+    extract_http_status_code,
+    is_http_status_retriable,
 )
 from kash.utils.api_utils.progress_protocol import Labeler, ProgressTracker, TaskState
 
@@ -23,17 +25,33 @@ T = TypeVar("T")
 
 log = logging.getLogger(__name__)
 
+DEFAULT_MAX_CONCURRENT: int = 5
+DEFAULT_MAX_RPS: float = 5.0
+DEFAULT_CANCEL_TIMEOUT: float = 1.0
+
+
+@dataclass(frozen=True)
+class Limit:
+    """
+    Rate limiting configuration with max RPS and max concurrency.
+    """
+
+    rps: float = DEFAULT_MAX_RPS
+    concurrency: int = DEFAULT_MAX_CONCURRENT
+
 
 @dataclass(frozen=True)
 class FuncTask:
     """
     A task described as an unevaluated function with args and kwargs.
     This task format allows you to use args and kwargs in the Labeler.
+    It also allows specifying a bucket for rate limiting.
     """
 
     func: Callable[..., Any]
     args: tuple[Any, ...] = ()
     kwargs: dict[str, Any] = field(default_factory=dict)
+    bucket: str = "default"
 
 
 # Type aliases for coroutine and sync specifications, including unevaluated function specs
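For orientation, here is a minimal sketch of the two new types in use. The import path is assumed from the imports at the top of the diff, and `fetch` is a placeholder function, not part of the package:

```python
# Sketch only: the module path is assumed; `fetch` stands in for real work.
from kash.utils.api_utils.gather_limited import FuncTask, Limit  # assumed path

def fetch(url: str) -> str:
    return url  # placeholder

# FuncTask defers the call and tags it with a rate-limit bucket.
task = FuncTask(fetch, args=("https://example.com",), bucket="api1")

# Limit pairs a leaky-bucket RPS cap with a concurrency cap.
limit = Limit(rps=10.0, concurrency=4)  # defaults: rps=5.0, concurrency=5
```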
@@ -44,9 +62,30 @@ SyncSpec: TypeAlias = Callable[[], T] | FuncTask
 CoroLabeler: TypeAlias = Labeler[CoroSpec[T]]
 SyncLabeler: TypeAlias = Labeler[SyncSpec[T]]
 
-DEFAULT_MAX_CONCURRENT: int = 5
-DEFAULT_MAX_RPS: float = 5.0
-DEFAULT_CANCEL_TIMEOUT: float = 1.0
+
+def _get_bucket_limits(
+    bucket: str,
+    bucket_semaphores: dict[str, asyncio.Semaphore],
+    bucket_rate_limiters: dict[str, AsyncLimiter],
+) -> tuple[asyncio.Semaphore | None, AsyncLimiter | None]:
+    """
+    Get bucket-specific limits with fallback to "*" wildcard.
+
+    Checks for an exact bucket match first, then falls back to "*" if available.
+    Returns (None, None) if neither an exact match nor a "*" fallback exists.
+    """
+    # Try exact bucket match first
+    bucket_semaphore = bucket_semaphores.get(bucket)
+    bucket_rate_limiter = bucket_rate_limiters.get(bucket)
+
+    if bucket_semaphore is not None and bucket_rate_limiter is not None:
+        return bucket_semaphore, bucket_rate_limiter
+
+    # Fall back to "*" wildcard if available
+    bucket_semaphore = bucket_semaphores.get("*")
+    bucket_rate_limiter = bucket_rate_limiters.get("*")
+
+    return bucket_semaphore, bucket_rate_limiter
 
 
 class RetryCounter:
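The wildcard lookup is easiest to see with concrete dictionaries. A standalone sketch that mirrors the helper's semantics rather than importing the private function:

```python
import asyncio
from aiolimiter import AsyncLimiter

bucket_semaphores = {"api1": asyncio.Semaphore(10), "*": asyncio.Semaphore(8)}
bucket_rate_limiters = {"api1": AsyncLimiter(20, 1.0), "*": AsyncLimiter(15, 1.0)}

def lookup(bucket: str):  # mirrors _get_bucket_limits above
    sem, rl = bucket_semaphores.get(bucket), bucket_rate_limiters.get(bucket)
    if sem is not None and rl is not None:
        return sem, rl  # exact match wins
    return bucket_semaphores.get("*"), bucket_rate_limiters.get("*")

assert lookup("api1")[0] is bucket_semaphores["api1"]  # exact match
assert lookup("api2")[0] is bucket_semaphores["*"]     # wildcard fallback
```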
@@ -75,8 +114,8 @@ class RetryCounter:
 @overload
 async def gather_limited_async(
     *coro_specs: CoroSpec[T],
-    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-    max_rps: float = DEFAULT_MAX_RPS,
+    global_limit: Limit = Limit(),
+    bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = False,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
@@ -87,8 +126,8 @@ async def gather_limited_async(
 @overload
 async def gather_limited_async(
     *coro_specs: CoroSpec[T],
-    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-    max_rps: float = DEFAULT_MAX_RPS,
+    global_limit: Limit = Limit(),
+    bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = True,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
@@ -98,21 +137,26 @@ async def gather_limited_async(
 
 async def gather_limited_async(
     *coro_specs: CoroSpec[T],
-    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-    max_rps: float = DEFAULT_MAX_RPS,
-    return_exceptions: bool = False,
+    global_limit: Limit = Limit(),
+    bucket_limits: dict[str, Limit] | None = None,
+    return_exceptions: bool = True,  # Default to True for resilient batch operations
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
     labeler: CoroLabeler[T] | None = None,
 ) -> list[T] | list[T | BaseException]:
     """
-    Rate-limited version of `asyncio.gather()` with retry logic and optional progress display.
+    Rate-limited version of `asyncio.gather()` with HTTP-aware retry logic and optional progress display.
     Uses the aiolimiter leaky-bucket algorithm with exponential backoff on failures.
 
     Supports two levels of retry limits:
     - Per-task retries: max_task_retries attempts per individual task
     - Global retries: max_total_retries attempts across all tasks (prevents cascade failures)
 
+    Features HTTP-aware retry classification:
+    - Automatically detects HTTP status codes (403, 429, 500, etc.) and applies appropriate retry behavior
+    - Configurable handling of conditional status codes like 403 Forbidden
+    - Defaults to return_exceptions=True for resilient batch operations
+
     Can optionally display live progress with retry indicators using TaskStatus.
 
     Accepts:
@@ -126,7 +170,7 @@ async def gather_limited_async(
     from kash.utils.rich_custom.task_status import TaskStatus
 
     async with TaskStatus() as status:
-        await gather_limited(
+        await gather_limited_async(
             lambda: fetch_url("http://example.com"),
             lambda: process_data(data),
             status=status,
@@ -135,18 +179,31 @@ async def gather_limited_async(
         )
 
     # Without progress display:
-    await gather_limited(
+    await gather_limited_async(
         lambda: fetch_url("http://example.com"),
         lambda: process_data(data),
        retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
     )
 
+    # With bucket-specific limits and "*" fallback:
+    await gather_limited_async(
+        FuncTask(fetch_api, ("data1",), bucket="api1"),
+        FuncTask(fetch_api, ("data2",), bucket="api2"),
+        FuncTask(fetch_api, ("data3",), bucket="api3"),
+        global_limit=Limit(rps=100, concurrency=50),
+        bucket_limits={
+            "api1": Limit(rps=20, concurrency=10),  # Specific limit for api1
+            "*": Limit(rps=15, concurrency=8),  # Fallback for api2, api3, and others
+        }
+    )
+
     ```
 
     Args:
         *coro_specs: Callables or coroutines to execute
-        max_concurrent: Maximum number of concurrent executions
-        max_rps: Maximum requests per second
+        global_limit: Global limits applied to all tasks regardless of bucket
+        bucket_limits: Optional per-bucket limits. Tasks use their bucket field to determine limits.
+            Use "*" as a fallback limit for buckets without specific limits.
         return_exceptions: If True, exceptions are returned as results
         retry_settings: Configuration for retry behavior, or None to disable retries
         status: Optional ProgressTracker instance for progress display
@@ -159,9 +216,9 @@ async def gather_limited_async(
         ValueError: If coroutines are passed when retries are enabled
     """
     log.info(
-        "Executing with concurrency %s at %s rps, %s",
-        max_concurrent,
-        max_rps,
+        "Executing with global limits: concurrency %s at %s rps, %s",
+        global_limit.concurrency,
+        global_limit.rps,
         retry_settings,
     )
     if not coro_specs:
@@ -179,8 +236,18 @@ async def gather_limited_async(
             f"lambda: your_async_func(args) instead of your_async_func(args)"
         )
 
-    semaphore = asyncio.Semaphore(max_concurrent)
-    rate_limiter = AsyncLimiter(max_rps, 1.0)
+    # Global limits (apply to all tasks regardless of bucket)
+    global_semaphore = asyncio.Semaphore(global_limit.concurrency)
+    global_rate_limiter = AsyncLimiter(global_limit.rps, 1.0)
+
+    # Per-bucket limits (if bucket_limits provided)
+    bucket_semaphores: dict[str, asyncio.Semaphore] = {}
+    bucket_rate_limiters: dict[str, AsyncLimiter] = {}
+
+    if bucket_limits:
+        for bucket_name, limit in bucket_limits.items():
+            bucket_semaphores[bucket_name] = asyncio.Semaphore(limit.concurrency)
+            bucket_rate_limiters[bucket_name] = AsyncLimiter(limit.rps, 1.0)
 
     # Global retry counter (shared across all tasks)
     global_retry_counter = RetryCounter(retry_settings.max_total_retries)
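Note the `AsyncLimiter(rate, 1.0)` construction: aiolimiter's second argument is the time period in seconds (it defaults to 60), so passing 1.0 makes the first argument a true requests-per-second cap. A minimal sketch of the resulting pacing:

```python
import asyncio
import time
from aiolimiter import AsyncLimiter

async def main() -> None:
    limiter = AsyncLimiter(5.0, 1.0)  # at most ~5 acquisitions per second
    start = time.monotonic()
    for _ in range(10):
        async with limiter:  # waits when the leaky bucket is full
            pass
    # The first ~5 pass immediately; the rest drain at 5/s, so roughly 1s total.
    print(f"10 acquisitions took {time.monotonic() - start:.2f}s")

asyncio.run(main())
```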
@@ -190,6 +257,16 @@ async def gather_limited_async(
         label = labeler(i, coro_spec) if labeler else f"task:{i}"
         task_id = await status.add(label) if status else None
 
+        # Determine bucket and get appropriate limits
+        bucket = "default"
+        if isinstance(coro_spec, FuncTask):
+            bucket = coro_spec.bucket
+
+        # Get bucket-specific limits if available
+        bucket_semaphore, bucket_rate_limiter = _get_bucket_limits(
+            bucket, bucket_semaphores, bucket_rate_limiters
+        )
+
         async def executor() -> T:
             # Create a fresh coroutine for each attempt
             if isinstance(coro_spec, FuncTask):
@@ -206,8 +283,10 @@ async def gather_limited_async(
             result = await _execute_with_retry(
                 executor,
                 retry_settings,
-                semaphore,
-                rate_limiter,
+                global_semaphore,
+                global_rate_limiter,
+                bucket_semaphore,
+                bucket_rate_limiter,
                 global_retry_counter,
                 status,
                 task_id,
@@ -234,8 +313,8 @@ async def gather_limited_async(
 @overload
 async def gather_limited_sync(
     *sync_specs: SyncSpec[T],
-    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-    max_rps: float = DEFAULT_MAX_RPS,
+    global_limit: Limit = Limit(),
+    bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = False,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
@@ -248,8 +327,8 @@ async def gather_limited_sync(
 @overload
 async def gather_limited_sync(
     *sync_specs: SyncSpec[T],
-    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-    max_rps: float = DEFAULT_MAX_RPS,
+    global_limit: Limit = Limit(),
+    bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = True,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
@@ -261,9 +340,9 @@ async def gather_limited_sync(
 
 async def gather_limited_sync(
     *sync_specs: SyncSpec[T],
-    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
-    max_rps: float = DEFAULT_MAX_RPS,
-    return_exceptions: bool = False,
+    global_limit: Limit = Limit(),
+    bucket_limits: dict[str, Limit] | None = None,
+    return_exceptions: bool = True,  # Default to True for resilient batch operations
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
     status: ProgressTracker | None = None,
     labeler: SyncLabeler[T] | None = None,
@@ -271,19 +350,25 @@ async def gather_limited_sync(
     cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
 ) -> list[T] | list[T | BaseException]:
     """
-    Rate-limited version of `asyncio.gather()` for sync functions with retry logic.
+    Rate-limited version of `asyncio.gather()` for sync functions with HTTP-aware retry logic.
     Handles the asyncio.to_thread() boundary correctly for consistent exception propagation.
 
     Supports two levels of retry limits:
     - Per-task retries: max_task_retries attempts per individual task
     - Global retries: max_total_retries attempts across all tasks
 
+    Features HTTP-aware retry classification:
+    - Automatically detects HTTP status codes (403, 429, 500, etc.) and applies appropriate retry behavior
+    - Configurable handling of conditional status codes like 403 Forbidden
+    - Defaults to return_exceptions=True for resilient batch operations
+
     Supports cooperative cancellation and graceful thread termination on interruption.
 
     Args:
         *sync_specs: Callables that return values (not coroutines) or FuncTask objects
-        max_concurrent: Maximum number of concurrent executions
-        max_rps: Maximum requests per second
+        global_limit: Global limits applied to all tasks regardless of bucket
+        bucket_limits: Optional per-bucket limits. Tasks use their bucket field to determine limits.
+            Use "*" as a fallback limit for buckets without specific limits.
         return_exceptions: If True, exceptions are returned as results
         retry_settings: Configuration for retry behavior, or None to disable retries
         status: Optional ProgressTracker instance for progress display
@@ -296,29 +381,31 @@ async def gather_limited_sync(
 
     Example:
     ```python
-    # Without cooperative cancellation
+    # Without bucket limits (backward compatible)
     results = await gather_limited_sync(
         lambda: some_sync_function(arg1),
         lambda: another_sync_function(arg2),
-        max_concurrent=3,
-        max_rps=2.0,
+        global_limit=Limit(rps=2.0, concurrency=3),
         retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
     )
 
-    # With cooperative cancellation
-    cancel_event = threading.Event()
+    # With bucket-specific limits and "*" fallback
     results = await gather_limited_sync(
-        lambda: cancellable_sync_function(cancel_event, arg1),
-        lambda: another_cancellable_function(cancel_event, arg2),
-        cancel_event=cancel_event,
-        cancel_timeout=5.0,
+        FuncTask(fetch_from_api, ("data1",), bucket="api1"),
+        FuncTask(fetch_from_api, ("data2",), bucket="api2"),
+        FuncTask(fetch_from_api, ("data3",), bucket="api3"),
+        global_limit=Limit(rps=100, concurrency=50),
+        bucket_limits={
+            "api1": Limit(rps=20, concurrency=10),  # Specific limit for api1
+            "*": Limit(rps=15, concurrency=8),  # Fallback for api2, api3, and others
+        }
     )
     ```
     """
     log.info(
-        "Executing with concurrency %s at %s rps, %s",
-        max_concurrent,
-        max_rps,
+        "Executing with global limits: concurrency %s at %s rps, %s",
+        global_limit.concurrency,
+        global_limit.rps,
         retry_settings,
     )
     if not sync_specs:
@@ -326,8 +413,18 @@ async def gather_limited_sync(
 
     retry_settings = retry_settings or NO_RETRIES
 
-    semaphore = asyncio.Semaphore(max_concurrent)
-    rate_limiter = AsyncLimiter(max_rps, 1.0)
+    # Global limits (apply to all tasks regardless of bucket)
+    global_semaphore = asyncio.Semaphore(global_limit.concurrency)
+    global_rate_limiter = AsyncLimiter(global_limit.rps, 1.0)
+
+    # Per-bucket limits (if bucket_limits provided)
+    bucket_semaphores: dict[str, asyncio.Semaphore] = {}
+    bucket_rate_limiters: dict[str, AsyncLimiter] = {}
+
+    if bucket_limits:
+        for bucket_name, limit in bucket_limits.items():
+            bucket_semaphores[bucket_name] = asyncio.Semaphore(limit.concurrency)
+            bucket_rate_limiters[bucket_name] = AsyncLimiter(limit.rps, 1.0)
 
     # Global retry counter (shared across all tasks)
     global_retry_counter = RetryCounter(retry_settings.max_total_retries)
@@ -337,6 +434,16 @@ async def gather_limited_sync(
         label = labeler(i, sync_spec) if labeler else f"task:{i}"
         task_id = await status.add(label) if status else None
 
+        # Determine bucket and get appropriate limits
+        bucket = "default"
+        if isinstance(sync_spec, FuncTask):
+            bucket = sync_spec.bucket
+
+        # Get bucket-specific limits if available
+        bucket_semaphore, bucket_rate_limiter = _get_bucket_limits(
+            bucket, bucket_semaphores, bucket_rate_limiters
+        )
+
         async def executor() -> T:
             # Call sync function via asyncio.to_thread, handling retry at this level
             if isinstance(sync_spec, FuncTask):
@@ -361,8 +468,10 @@ async def gather_limited_sync(
             result = await _execute_with_retry(
                 executor,
                 retry_settings,
-                semaphore,
-                rate_limiter,
+                global_semaphore,
+                global_rate_limiter,
+                bucket_semaphore,
+                bucket_rate_limiter,
                 global_retry_counter,
                 status,
                 task_id,
@@ -434,7 +543,8 @@ async def _gather_with_interrupt_handling(
     if cancelled_count > 0:
         try:
             await asyncio.wait_for(
-                asyncio.gather(*async_tasks, return_exceptions=True), timeout=cancel_timeout
+                asyncio.gather(*async_tasks, return_exceptions=True),
+                timeout=cancel_timeout,
             )
         except (TimeoutError, asyncio.CancelledError):
             log.warning("Some tasks did not cancel within timeout")
@@ -462,8 +572,10 @@ async def _gather_with_interrupt_handling(
 async def _execute_with_retry(
     executor: Callable[[], Coroutine[None, None, T]],
     retry_settings: RetrySettings,
-    semaphore: asyncio.Semaphore,
-    rate_limiter: AsyncLimiter,
+    global_semaphore: asyncio.Semaphore,
+    global_rate_limiter: AsyncLimiter,
+    bucket_semaphore: asyncio.Semaphore | None,
+    bucket_rate_limiter: AsyncLimiter | None,
     global_retry_counter: RetryCounter,
     status: ProgressTracker | None = None,
     task_id: Any | None = None,
@@ -473,9 +585,22 @@ async def _execute_with_retry(
     start_time = time.time()
     last_exception: Exception | None = None
 
+    # Create HTTP-aware is_retriable function based on settings
+    def is_retriable_for_task(exception: Exception) -> bool:
+        # First check the main is_retriable function
+        if not retry_settings.is_retriable(exception):
+            return False
+
+        status_code = extract_http_status_code(exception)
+        if status_code:
+            return is_http_status_retriable(status_code, retry_settings.http_retry_map)
+
+        # Not an HTTP error, use the main is_retriable result
+        return True
+
     for attempt in range(retry_settings.max_task_retries + 1):
         # Handle backoff before acquiring any resources
-        if attempt > 0 and last_exception is not None:
+        if attempt > 0 and last_exception:
            # Try to increment global retry counter
            if not await global_retry_counter.try_increment():
                log.error(
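The order of checks in that closure matters: the generic `is_retriable` gate runs first, and only then does the HTTP status map get a say. A standalone sketch of that flow follows; the real `extract_http_status_code` and `is_http_status_retriable` live in `api_retries`, so the stand-ins below only mirror the control flow and are not the library's implementations:

```python
import re

def extract_status(exc: Exception) -> int | None:
    # Stand-in: pull a 4xx/5xx code out of the message if present.
    m = re.search(r"\b([45]\d{2})\b", str(exc))
    return int(m.group(1)) if m else None

def status_retriable(code: int, retry_map: dict[int, bool]) -> bool:
    # Stand-in: explicit map entries win; otherwise retry 429 and 5xx.
    return retry_map.get(code, code == 429 or 500 <= code < 600)

def is_retriable_for_task(exc: Exception, generic_ok: bool, retry_map: dict[int, bool]) -> bool:
    if not generic_ok:  # generic gate first
        return False
    code = extract_status(exc)
    if code:
        return status_retriable(code, retry_map)  # HTTP-aware refinement
    return True  # non-HTTP error: the generic verdict stands

print(is_retriable_for_task(Exception("HTTP 429"), True, {403: False}))  # True
print(is_retriable_for_task(Exception("HTTP 403"), True, {403: False}))  # False
```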
@@ -493,7 +618,7 @@ async def _execute_with_retry(
                 )
 
             # Record retry in status display and log appropriately
-            if status and task_id is not None:
+            if status and task_id:
                 # Include retry attempt info and backoff time in the status display
                 retry_info = (
                     f"Attempt {attempt}/{retry_settings.max_task_retries} "
@@ -508,8 +633,11 @@ async def _execute_with_retry(
                 use_debug_level = False
 
             # Log retry information at appropriate level
+            status_code = extract_http_status_code(last_exception)
+            status_info = f" (HTTP {status_code})" if status_code else ""
+
             rate_limit_msg = (
-                f"Rate limit hit (attempt {attempt}/{retry_settings.max_task_retries} "
+                f"Rate limit hit{status_info} (attempt {attempt}/{retry_settings.max_task_retries} "
                 f"{global_retry_counter.count}/{global_retry_counter.max_total_retries or '∞'} total) "
                 f"backing off for {backoff_time:.2f}s"
             )
@@ -526,12 +654,20 @@ async def _execute_with_retry(
             await asyncio.sleep(backoff_time)
 
         try:
-            # Acquire semaphore and rate limiter right before making the call
-            async with semaphore, rate_limiter:
-                # Mark task as started now that we've passed rate limiting
-                if status and task_id is not None and attempt == 0:
-                    await status.start(task_id)
-                return await executor()
+            # Acquire global limits first, then bucket-specific limits if present
+            async with global_semaphore, global_rate_limiter:
+                if bucket_semaphore is not None and bucket_rate_limiter is not None:
+                    async with bucket_semaphore, bucket_rate_limiter:
+                        # Mark task as started now that we've passed rate limiting
+                        if status and task_id is not None and attempt == 0:
+                            await status.start(task_id)
+                        return await executor()
+                else:
+                    # Mark task as started now that we've passed rate limiting
+                    if status and task_id is not None and attempt == 0:
+                        await status.start(task_id)
+                    return await executor()
+
         except Exception as e:
             last_exception = e  # Always store the exception
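A side note on the acquisition order here: the global semaphore and limiter are always entered before the bucket pair, so a task queued behind a slow bucket still occupies a global slot, which keeps the global caps authoritative. A minimal runnable sketch of the same nesting (names are illustrative, not the package's API):

```python
import asyncio
from aiolimiter import AsyncLimiter

async def limited_call(
    global_sem: asyncio.Semaphore,
    global_rl: AsyncLimiter,
    bucket_sem: asyncio.Semaphore | None = None,
    bucket_rl: AsyncLimiter | None = None,
) -> str:
    # Global limits gate everything; bucket limits only apply when configured.
    async with global_sem, global_rl:
        if bucket_sem is not None and bucket_rl is not None:
            async with bucket_sem, bucket_rl:
                return "global + bucket"
        return "global only"

async def main() -> None:
    g_sem, g_rl = asyncio.Semaphore(50), AsyncLimiter(100, 1.0)
    b_sem, b_rl = asyncio.Semaphore(10), AsyncLimiter(20, 1.0)
    print(await limited_call(g_sem, g_rl))               # global only
    print(await limited_call(g_sem, g_rl, b_sem, b_rl))  # global + bucket

asyncio.run(main())
```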
 
@@ -549,439 +685,19 @@ async def _execute_with_retry(
                 )
                 raise RetryExhaustedException(e, retry_settings.max_task_retries, total_time)
 
-            # Check if this is a retriable exception
-            if retry_settings.is_retriable(e):
+            # Check if this is a retriable exception using our HTTP-aware function
+            if is_retriable_for_task(e):
                 # Continue to next retry attempt (global limits will be checked at top of loop)
                 continue
             else:
                 # Non-retriable exception, log and re-raise immediately
-                log.warning("Non-retriable exception (not retrying): %s", e, exc_info=True)
+                status_code = extract_http_status_code(e)
+                status_info = f" (HTTP {status_code})" if status_code else ""
+
+                log.warning(
+                    "Non-retriable exception%s (not retrying): %s", status_info, e, exc_info=True
+                )
                 raise
 
     # This should never be reached, but satisfy type checker
     raise RuntimeError("Unexpected code path in _execute_with_retry")
-
-
-## Tests
-
-
-def test_gather_limited_sync():
-    """Test gather_limited_sync with sync functions."""
-    import asyncio
-    import time
-
-    async def run_test():
-        def sync_func(value: int) -> int:
-            """Simple sync function for testing."""
-            time.sleep(0.1)  # Simulate some work
-            return value * 2
-
-        # Test basic functionality
-        results = await gather_limited_sync(
-            lambda: sync_func(1),
-            lambda: sync_func(2),
-            lambda: sync_func(3),
-            max_concurrent=2,
-            max_rps=10.0,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert results == [2, 4, 6]
-
-    # Run the async test
-    asyncio.run(run_test())
-
-
-def test_gather_limited_sync_with_retries():
-    """Test that sync functions can be retried on retriable exceptions."""
-    import asyncio
-
-    async def run_test():
-        call_count = 0
-
-        def flaky_sync_func() -> str:
-            """Sync function that fails first time, succeeds second time."""
-            nonlocal call_count
-            call_count += 1
-            if call_count == 1:
-                raise Exception("Rate limit exceeded")  # Retriable
-            return "success"
-
-        # Should succeed after retry
-        results = await gather_limited_sync(
-            lambda: flaky_sync_func(),
-            retry_settings=RetrySettings(
-                max_task_retries=2,
-                initial_backoff=0.1,
-                max_backoff=1.0,
-                backoff_factor=2.0,
-            ),
-        )
-
-        assert results == ["success"]
-        assert call_count == 2  # Called twice (failed once, succeeded once)
-
-    # Run the async test
-    asyncio.run(run_test())
-
-
-def test_gather_limited_async_basic():
-    """Test gather_limited with async functions using callables."""
-    import asyncio
-
-    async def run_test():
-        async def async_func(value: int) -> int:
-            """Simple async function for testing."""
-            await asyncio.sleep(0.05)  # Simulate async work
-            return value * 3
-
-        # Test with callables (recommended pattern)
-        results = await gather_limited_async(
-            lambda: async_func(1),
-            lambda: async_func(2),
-            lambda: async_func(3),
-            max_concurrent=2,
-            max_rps=10.0,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert results == [3, 6, 9]
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_direct_coroutines():
-    """Test gather_limited with direct coroutines when retries disabled."""
-    import asyncio
-
-    async def run_test():
-        async def async_func(value: int) -> int:
-            await asyncio.sleep(0.05)
-            return value * 4
-
-        # Test with direct coroutines (only works when retries disabled)
-        results = await gather_limited_async(
-            async_func(1),
-            async_func(2),
-            async_func(3),
-            retry_settings=NO_RETRIES,  # Required for direct coroutines
-        )
-
-        assert results == [4, 8, 12]
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_coroutine_retry_validation():
-    """Test that passing coroutines with retries enabled raises ValueError."""
-    import asyncio
-
-    async def run_test():
-        async def async_func(value: int) -> int:
-            return value
-
-        coro = async_func(1)  # Direct coroutine
-
-        # Should raise ValueError when trying to use coroutines with retries
-        try:
-            await gather_limited_async(
-                coro,  # Direct coroutine
-                lambda: async_func(2),  # Callable
-                retry_settings=RetrySettings(
-                    max_task_retries=1,
-                    initial_backoff=0.1,
-                    max_backoff=1.0,
-                    backoff_factor=2.0,
-                ),
-            )
-            raise AssertionError("Expected ValueError")
-        except ValueError as e:
-            coro.close()  # Close the unused coroutine to prevent RuntimeWarning
-            assert "position 0" in str(e)
-            assert "cannot be retried" in str(e)
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_async_with_retries():
-    """Test that async functions can be retried when using callables."""
-    import asyncio
-
-    async def run_test():
-        call_count = 0
-
-        async def flaky_async_func() -> str:
-            """Async function that fails first time, succeeds second time."""
-            nonlocal call_count
-            call_count += 1
-            if call_count == 1:
-                raise Exception("Rate limit exceeded")  # Retriable
-            return "async_success"
-
-        # Should succeed after retry using callable
-        results = await gather_limited_async(
-            lambda: flaky_async_func(),
-            retry_settings=RetrySettings(
-                max_task_retries=2,
-                initial_backoff=0.1,
-                max_backoff=1.0,
-                backoff_factor=2.0,
-            ),
-        )
-
-        assert results == ["async_success"]
-        assert call_count == 2  # Called twice (failed once, succeeded once)
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_sync_coroutine_validation():
-    """Test that passing async function callables to sync version raises ValueError."""
-    import asyncio
-
-    async def run_test():
-        async def async_func(value: int) -> int:
-            return value
-
-        # Should raise ValueError when trying to use async functions in sync version
-        try:
-            await gather_limited_sync(
-                lambda: async_func(1),  # Returns coroutine - should be rejected
-                retry_settings=NO_RETRIES,
-            )
-            raise AssertionError("Expected ValueError")
-        except ValueError as e:
-            assert "returned a coroutine" in str(e)
-            assert "gather_limited_sync() is for synchronous functions only" in str(e)
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_retry_exhaustion():
-    """Test that retry exhaustion produces clear error messages."""
-    import asyncio
-
-    async def run_test():
-        call_count = 0
-
-        def always_fails() -> str:
-            """Function that always raises retriable exceptions."""
-            nonlocal call_count
-            call_count += 1
-            raise Exception("Rate limit exceeded")  # Always retriable
-
-        # Should exhaust retries and raise RetryExhaustedException
-        try:
-            await gather_limited_sync(
-                lambda: always_fails(),
-                retry_settings=RetrySettings(
-                    max_task_retries=2,
-                    initial_backoff=0.01,
-                    max_backoff=0.1,
-                    backoff_factor=2.0,
-                ),
-            )
-            raise AssertionError("Expected RetryExhaustedException")
-        except RetryExhaustedException as e:
-            assert "Max retries (2) exhausted" in str(e)
-            assert "Rate limit exceeded" in str(e)
-            assert isinstance(e.original_exception, Exception)
-            assert call_count == 3  # Initial attempt + 2 retries
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_return_exceptions():
-    """Test return_exceptions=True behavior for both functions."""
-    import asyncio
-
-    async def run_test():
-        def failing_sync() -> str:
-            raise ValueError("sync error")
-
-        async def failing_async() -> str:
-            raise ValueError("async error")
-
-        # Test sync version with exceptions returned
-        sync_results = await gather_limited_sync(
-            lambda: "success",
-            lambda: failing_sync(),
-            return_exceptions=True,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert len(sync_results) == 2
-        assert sync_results[0] == "success"
-        assert isinstance(sync_results[1], ValueError)
-        assert str(sync_results[1]) == "sync error"
-
-        async def success_async() -> str:
-            return "async_success"
-
-        # Test async version with exceptions returned
-        async_results = await gather_limited_async(
-            lambda: success_async(),
-            lambda: failing_async(),
-            return_exceptions=True,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert len(async_results) == 2
-        assert async_results[0] == "async_success"
-        assert isinstance(async_results[1], ValueError)
-        assert str(async_results[1]) == "async error"
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_global_retry_limit():
-    """Test that global retry limits are enforced across all tasks."""
-    import asyncio
-
-    async def run_test():
-        retry_counts = {"task1": 0, "task2": 0}
-
-        def flaky_task(task_name: str) -> str:
-            """Tasks that always fail but track retry counts."""
-            retry_counts[task_name] += 1
-            raise Exception(f"Rate limit exceeded in {task_name}")
-
-        # Test with very low global retry limit
-        try:
-            await gather_limited_sync(
-                lambda: flaky_task("task1"),
-                lambda: flaky_task("task2"),
-                retry_settings=RetrySettings(
-                    max_task_retries=5,  # Each task could retry up to 5 times
-                    max_total_retries=3,  # But only 3 total retries across all tasks
-                    initial_backoff=0.01,
-                    max_backoff=0.1,
-                    backoff_factor=2.0,
-                ),
-                return_exceptions=True,
-            )
-        except Exception:
-            pass  # Expected to fail due to rate limits
-
-        # Verify that total retries across both tasks doesn't exceed global limit
-        total_retries = (retry_counts["task1"] - 1) + (
-            retry_counts["task2"] - 1
-        )  # -1 for initial attempts
-        assert total_retries <= 3, f"Total retries {total_retries} exceeded global limit of 3"
-
-        # Verify that both tasks were attempted at least once
-        assert retry_counts["task1"] >= 1
-        assert retry_counts["task2"] >= 1
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_funcspec_format():
-    """Test gather_limited with FuncSpec format and custom labeler accessing args."""
-    import asyncio
-
-    async def run_test():
-        def sync_func(name: str, value: int, multiplier: int = 2) -> str:
-            """Sync function that takes args and kwargs."""
-            return f"{name}: {value * multiplier}"
-
-        async def async_func(name: str, value: int, multiplier: int = 2) -> str:
-            """Async function that takes args and kwargs."""
-            await asyncio.sleep(0.01)
-            return f"{name}: {value * multiplier}"
-
-        captured_labels = []
-
-        def custom_labeler(i: int, spec: Any) -> str:
-            if isinstance(spec, FuncTask):
-                # Extract meaningful info from args for labeling
-                if spec.args and len(spec.args) > 0:
-                    label = f"Processing {spec.args[0]}"
-                else:
-                    label = f"Task {i}"
-            else:
-                label = f"Task {i}"
-            captured_labels.append(label)
-            return label
-
-        # Test sync version with FuncSpec format and custom labeler
-        sync_results = await gather_limited_sync(
-            FuncTask(sync_func, ("user1", 100), {"multiplier": 3}),  # user1: 300
-            FuncTask(sync_func, ("user2", 50)),  # user2: 100 (default multiplier)
-            labeler=custom_labeler,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert sync_results == ["user1: 300", "user2: 100"]
-        assert captured_labels == ["Processing user1", "Processing user2"]
-
-        # Reset labels for async test
-        captured_labels.clear()
-
-        # Test async version with FuncSpec format and custom labeler
-        async_results = await gather_limited_async(
-            FuncTask(async_func, ("api_call", 10), {"multiplier": 4}),  # api_call: 40
-            FuncTask(async_func, ("data_fetch", 5)),  # data_fetch: 10 (default multiplier)
-            labeler=custom_labeler,
-            retry_settings=NO_RETRIES,
-        )
-
-        assert async_results == ["api_call: 40", "data_fetch: 10"]
-        assert captured_labels == ["Processing api_call", "Processing data_fetch"]
-
-    asyncio.run(run_test())
-
-
-def test_gather_limited_sync_cooperative_cancellation():
-    """Test gather_limited_sync with cooperative cancellation via threading.Event."""
-    import asyncio
-    import time
-
-    async def run_test():
-        cancel_event = threading.Event()
-        call_counts = {"task1": 0, "task2": 0}
-
-        def cancellable_sync_func(task_name: str, work_duration: float) -> str:
-            """Sync function that checks cancellation event periodically."""
-            call_counts[task_name] += 1
-            start_time = time.time()
-
-            while time.time() - start_time < work_duration:
-                if cancel_event.is_set():
-                    return f"{task_name}: cancelled"
-                time.sleep(0.01)  # Small sleep to allow cancellation check
-
-            return f"{task_name}: completed"
-
-        # Test cooperative cancellation - tasks should respect the cancel_event
-        results = await gather_limited_sync(
-            lambda: cancellable_sync_func("task1", 0.1),  # Short duration
-            lambda: cancellable_sync_func("task2", 0.1),  # Short duration
-            cancel_event=cancel_event,
-            cancel_timeout=1.0,
-            retry_settings=NO_RETRIES,
-        )
-
-        # Should complete normally since cancel_event wasn't set
-        assert results == ["task1: completed", "task2: completed"]
-        assert call_counts["task1"] == 1
-        assert call_counts["task2"] == 1
-
-        # Test that cancel_event can be used independently
-        cancel_event.set()  # Set cancellation signal
-
-        results2 = await gather_limited_sync(
-            lambda: cancellable_sync_func("task1", 1.0),  # Would take long if not cancelled
-            lambda: cancellable_sync_func("task2", 1.0),  # Would take long if not cancelled
-            cancel_event=cancel_event,
-            cancel_timeout=1.0,
-            retry_settings=NO_RETRIES,
-        )
-
-        # Should be cancelled quickly since cancel_event is already set
-        assert results2 == ["task1: cancelled", "task2: cancelled"]
-        # Call counts should increment
-        assert call_counts["task1"] == 2
-        assert call_counts["task2"] == 2
-
-    asyncio.run(run_test())
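Taken together, the API change means callers migrating from 0.3.21 replace the removed `max_concurrent`/`max_rps` keywords with a `Limit`. A minimal sketch of the before/after call sites (the import path is assumed; `work` is a placeholder):

```python
import asyncio
from kash.utils.api_utils.gather_limited import Limit, gather_limited_sync  # assumed path

def work() -> str:
    return "ok"  # placeholder

async def main() -> None:
    # 0.3.21 call site:
    #   await gather_limited_sync(lambda: work(), max_concurrent=3, max_rps=2.0)
    # 0.3.23 equivalent:
    results = await gather_limited_sync(
        lambda: work(),
        global_limit=Limit(rps=2.0, concurrency=3),
    )
    print(results)  # expected: ["ok"]

asyncio.run(main())
```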