kash-shell 0.3.20__py3-none-any.whl → 0.3.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,987 @@
from __future__ import annotations

import asyncio
import inspect
import logging
import threading
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from typing import Any, TypeAlias, TypeVar, cast, overload

from aiolimiter import AsyncLimiter

from kash.utils.api_utils.api_retries import (
    DEFAULT_RETRIES,
    NO_RETRIES,
    RetryExhaustedException,
    RetrySettings,
    calculate_backoff,
)
from kash.utils.api_utils.progress_protocol import Labeler, ProgressTracker, TaskState

T = TypeVar("T")

log = logging.getLogger(__name__)


@dataclass(frozen=True)
class FuncTask:
    """
    A task described as an unevaluated function with args and kwargs.
    This format lets a Labeler inspect each task's args and kwargs.
    """

    func: Callable[..., Any]
    args: tuple[Any, ...] = ()
    kwargs: dict[str, Any] = field(default_factory=dict)
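

# Illustrative only: a FuncTask defers a call so a fresh invocation can be made
# on each retry attempt, while labelers can still read the arguments.
# `fetch_page` and its arguments here are hypothetical:
#
#   task = FuncTask(fetch_page, ("https://example.com",), {"timeout": 10})
#   result = task.func(*task.args, **task.kwargs)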


# Type aliases for coroutine and sync specifications, including unevaluated function specs
CoroSpec: TypeAlias = Callable[[], Coroutine[None, None, T]] | Coroutine[None, None, T] | FuncTask
SyncSpec: TypeAlias = Callable[[], T] | FuncTask

# Specific labeler types using the generic Labeler pattern
CoroLabeler: TypeAlias = Labeler[CoroSpec[T]]
SyncLabeler: TypeAlias = Labeler[SyncSpec[T]]

DEFAULT_MAX_CONCURRENT: int = 5
DEFAULT_MAX_RPS: float = 5.0
DEFAULT_CANCEL_TIMEOUT: float = 1.0
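

# For illustration, the three accepted CoroSpec forms for a hypothetical async
# function `fetch(url)` would be:
#
#   lambda: fetch("https://example.com")       # callable returning a fresh coroutine
#   fetch("https://example.com")               # bare coroutine (retries must be disabled)
#   FuncTask(fetch, ("https://example.com",))  # unevaluated call; args visible to labelers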


class RetryCounter:
    """Coroutine-safe counter (guarded by an asyncio.Lock) for tracking retries across all tasks."""

    def __init__(self, max_total_retries: int | None):
        self.max_total_retries = max_total_retries
        self.count = 0
        self._lock = asyncio.Lock()

    async def try_increment(self) -> bool:
        """
        Try to increment the retry counter.
        Returns True if the increment was successful, False if the limit was reached.
        """
        if self.max_total_retries is None:
            return True

        async with self._lock:
            if self.count < self.max_total_retries:
                self.count += 1
                return True
            return False
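

# Intended use (internal): each retry first reserves a slot on the shared
# counter, so retries across the whole gather never exceed max_total_retries:
#
#   counter = RetryCounter(max_total_retries=25)
#   if await counter.try_increment():
#       ...  # proceed with this retry
#   else:
#       ...  # global retry budget exhausted; give up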


@overload
async def gather_limited_async(
    *coro_specs: CoroSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: bool = False,
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: CoroLabeler[T] | None = None,
) -> list[T]: ...


@overload
async def gather_limited_async(
    *coro_specs: CoroSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: bool = True,
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: CoroLabeler[T] | None = None,
) -> list[T | BaseException]: ...


async def gather_limited_async(
    *coro_specs: CoroSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: bool = False,
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: CoroLabeler[T] | None = None,
) -> list[T] | list[T | BaseException]:
    """
    Rate-limited version of `asyncio.gather()` with retry logic and optional progress display.
    Uses the aiolimiter leaky-bucket algorithm with exponential backoff on failures.

    Supports two levels of retry limits:
    - Per-task retries: up to max_task_retries retries per individual task
    - Global retries: up to max_total_retries retries across all tasks (prevents cascade failures)

    Can optionally display live progress with retry indicators using TaskStatus.

    Accepts:
    - Callables that return coroutines: `lambda: some_async_func(arg)` (recommended for retries)
    - Coroutines directly: `some_async_func(arg)` (only if retries are disabled)
    - FuncTask objects: `FuncTask(some_async_func, (arg1, arg2), {"kwarg": value})` (args accessible to the labeler)

    Examples:
    ```python
    # With progress display and custom labeling:
    from kash.utils.rich_custom.task_status import TaskStatus

    async with TaskStatus() as status:
        await gather_limited_async(
            lambda: fetch_url("http://example.com"),
            lambda: process_data(data),
            status=status,
            labeler=lambda i, spec: f"Task {i+1}",
            retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
        )

    # Without progress display:
    await gather_limited_async(
        lambda: fetch_url("http://example.com"),
        lambda: process_data(data),
        retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
    )
    ```

    Args:
        *coro_specs: Callables, coroutines, or FuncTask objects to execute
        max_concurrent: Maximum number of concurrent executions
        max_rps: Maximum requests per second
        return_exceptions: If True, exceptions are returned as results
        retry_settings: Configuration for retry behavior, or None to disable retries
        status: Optional ProgressTracker instance for progress display
        labeler: Optional function to generate labels: labeler(index, spec) -> str

    Returns:
        List of results in the same order as input specifications

    Raises:
        ValueError: If coroutines are passed when retries are enabled
    """
    log.info(
        "Executing with concurrency %s at %s rps, %s",
        max_concurrent,
        max_rps,
        retry_settings,
    )
    if not coro_specs:
        return []

    retry_settings = retry_settings or NO_RETRIES

    # Validate that coroutines aren't used when retries are enabled
    if retry_settings.max_task_retries > 0:
        for i, spec in enumerate(coro_specs):
            if inspect.iscoroutine(spec):
                raise ValueError(
                    f"Coroutine at position {i} cannot be retried. "
                    f"When retries are enabled (max_task_retries > 0), pass callables that return fresh coroutines: "
                    f"lambda: your_async_func(args) instead of your_async_func(args)"
                )

    semaphore = asyncio.Semaphore(max_concurrent)
    rate_limiter = AsyncLimiter(max_rps, 1.0)

    # Global retry counter (shared across all tasks)
    global_retry_counter = RetryCounter(retry_settings.max_total_retries)

    async def run_task_with_retry(i: int, coro_spec: CoroSpec[T]) -> T:
        # Generate a label for this task
        label = labeler(i, coro_spec) if labeler else f"task:{i}"
        task_id = await status.add(label) if status else None

        async def executor() -> T:
            # Create a fresh coroutine for each attempt
            if isinstance(coro_spec, FuncTask):
                # FuncTask format: FuncTask(func, args, kwargs)
                coro = coro_spec.func(*coro_spec.args, **coro_spec.kwargs)
            elif callable(coro_spec):
                coro = coro_spec()
            else:
                # Direct coroutine - only valid if retries are disabled
                coro = coro_spec
            return await coro

        try:
            result = await _execute_with_retry(
                executor,
                retry_settings,
                semaphore,
                rate_limiter,
                global_retry_counter,
                status,
                task_id,
            )

            # Mark as completed successfully
            if status and task_id is not None:
                await status.finish(task_id, TaskState.COMPLETED)

            return result

        except Exception as e:
            # Mark as failed
            if status and task_id is not None:
                await status.finish(task_id, TaskState.FAILED, str(e))
            raise

    return await _gather_with_interrupt_handling(
        [run_task_with_retry(i, spec) for i, spec in enumerate(coro_specs)],
        return_exceptions,
    )
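

# For illustration (fetch_url here is hypothetical), the FuncTask form lets the
# labeler read each task's arguments:
#
#   results = await gather_limited_async(
#       FuncTask(fetch_url, ("https://example.com",)),
#       FuncTask(fetch_url, ("https://example.org",)),
#       labeler=lambda i, spec: (
#           f"fetch {spec.args[0]}" if isinstance(spec, FuncTask) else f"task:{i}"
#       ),
#       retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25),
#   )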


@overload
async def gather_limited_sync(
    *sync_specs: SyncSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: bool = False,
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: SyncLabeler[T] | None = None,
    cancel_event: threading.Event | None = None,
    cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
) -> list[T]: ...


@overload
async def gather_limited_sync(
    *sync_specs: SyncSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: bool = True,
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: SyncLabeler[T] | None = None,
    cancel_event: threading.Event | None = None,
    cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
) -> list[T | BaseException]: ...


async def gather_limited_sync(
    *sync_specs: SyncSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: bool = False,
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: SyncLabeler[T] | None = None,
    cancel_event: threading.Event | None = None,
    cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
) -> list[T] | list[T | BaseException]:
    """
    Rate-limited version of `asyncio.gather()` for sync functions with retry logic.
    Handles the asyncio.to_thread() boundary correctly for consistent exception propagation.

    Supports two levels of retry limits:
    - Per-task retries: up to max_task_retries retries per individual task
    - Global retries: up to max_total_retries retries across all tasks

    Supports cooperative cancellation and graceful thread termination on interruption.

    Args:
        *sync_specs: Callables that return values (not coroutines) or FuncTask objects
        max_concurrent: Maximum number of concurrent executions
        max_rps: Maximum requests per second
        return_exceptions: If True, exceptions are returned as results
        retry_settings: Configuration for retry behavior, or None to disable retries
        status: Optional ProgressTracker instance for progress display
        labeler: Optional function to generate labels: labeler(index, spec) -> str
        cancel_event: Optional threading.Event that will be set on cancellation
        cancel_timeout: Max seconds to wait for threads to terminate on cancellation

    Returns:
        List of results in the same order as input specifications

    Example:
    ```python
    # Without cooperative cancellation
    results = await gather_limited_sync(
        lambda: some_sync_function(arg1),
        lambda: another_sync_function(arg2),
        max_concurrent=3,
        max_rps=2.0,
        retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
    )

    # With cooperative cancellation
    cancel_event = threading.Event()
    results = await gather_limited_sync(
        lambda: cancellable_sync_function(cancel_event, arg1),
        lambda: another_cancellable_function(cancel_event, arg2),
        cancel_event=cancel_event,
        cancel_timeout=5.0,
    )
    ```
    """
    log.info(
        "Executing with concurrency %s at %s rps, %s",
        max_concurrent,
        max_rps,
        retry_settings,
    )
    if not sync_specs:
        return []

    retry_settings = retry_settings or NO_RETRIES

    semaphore = asyncio.Semaphore(max_concurrent)
    rate_limiter = AsyncLimiter(max_rps, 1.0)

    # Global retry counter (shared across all tasks)
    global_retry_counter = RetryCounter(retry_settings.max_total_retries)

    async def run_task_with_retry(i: int, sync_spec: SyncSpec[T]) -> T:
        # Generate a label for this task
        label = labeler(i, sync_spec) if labeler else f"task:{i}"
        task_id = await status.add(label) if status else None

        async def executor() -> T:
            # Call the sync function via asyncio.to_thread, handling retry at this level
            if isinstance(sync_spec, FuncTask):
                # FuncTask format: FuncTask(func, args, kwargs)
                result = await asyncio.to_thread(
                    sync_spec.func, *sync_spec.args, **sync_spec.kwargs
                )
            else:
                result = await asyncio.to_thread(sync_spec)
            # Check if the callable returned a coroutine (which would be a bug)
            if inspect.iscoroutine(result):
                # Clean up the coroutine we accidentally created
                result.close()
                raise ValueError(
                    "Callable returned a coroutine. "
                    "gather_limited_sync() is for synchronous functions only. "
                    "Use gather_limited_async() for async functions."
                )
            return cast(T, result)

        try:
            result = await _execute_with_retry(
                executor,
                retry_settings,
                semaphore,
                rate_limiter,
                global_retry_counter,
                status,
                task_id,
            )

            # Mark as completed successfully
            if status and task_id is not None:
                await status.finish(task_id, TaskState.COMPLETED)

            return result

        except Exception as e:
            # Mark as failed
            if status and task_id is not None:
                await status.finish(task_id, TaskState.FAILED, str(e))
            raise

    return await _gather_with_interrupt_handling(
        [run_task_with_retry(i, spec) for i, spec in enumerate(sync_specs)],
        return_exceptions,
        cancel_event=cancel_event,
        cancel_timeout=cancel_timeout,
    )
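

# Sketch of a cooperatively cancellable sync function (the same pattern as
# test_gather_limited_sync_cooperative_cancellation below); `process` is a
# hypothetical helper:
#
#   def cancellable_work(cancel_event: threading.Event, chunks: list[bytes]) -> int:
#       done = 0
#       for chunk in chunks:
#           if cancel_event.is_set():
#               break  # Bail out promptly once cancellation is signaled
#           process(chunk)
#           done += 1
#       return done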


async def _gather_with_interrupt_handling(
    tasks: list[Coroutine[None, None, T]],
    return_exceptions: bool = False,
    cancel_event: threading.Event | None = None,
    cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
) -> list[T] | list[T | BaseException]:
    """
    Execute asyncio.gather with graceful KeyboardInterrupt handling.

    Args:
        tasks: Coroutines to schedule as tasks
        return_exceptions: Whether to return exceptions as results
        cancel_event: Optional threading.Event to signal cancellation to sync functions
        cancel_timeout: Max seconds to wait for threads to terminate on cancellation

    Returns:
        Results from asyncio.gather

    Raises:
        KeyboardInterrupt: Re-raised after graceful cancellation
    """
    # Create tasks from the coroutines so we can cancel them properly
    async_tasks = [asyncio.create_task(task) for task in tasks]

    try:
        return await asyncio.gather(*async_tasks, return_exceptions=return_exceptions)
    except (KeyboardInterrupt, asyncio.CancelledError) as e:
        # Handle both KeyboardInterrupt and CancelledError (which is what tasks actually receive)
        log.warning("Interrupt received, cancelling %d tasks...", len(async_tasks))

        # Signal cancellation to sync functions if an event was provided
        if cancel_event is not None:
            cancel_event.set()
            log.debug("Cancellation event set for cooperative sync function termination")

        # Cancel all running tasks
        cancelled_count = 0
        for task in async_tasks:
            if not task.done():
                task.cancel()
                cancelled_count += 1

        # Wait briefly for tasks to cancel
        if cancelled_count > 0:
            try:
                await asyncio.wait_for(
                    asyncio.gather(*async_tasks, return_exceptions=True), timeout=cancel_timeout
                )
            except (TimeoutError, asyncio.CancelledError):
                log.warning("Some tasks did not cancel within timeout")

        # Wait for threads to terminate gracefully
        loop = asyncio.get_running_loop()
        try:
            log.debug("Waiting up to %.1fs for thread pool termination...", cancel_timeout)
            await asyncio.wait_for(
                loop.shutdown_default_executor(),
                timeout=cancel_timeout,
            )
            log.info("Thread pool shutdown completed")
        except TimeoutError:
            log.warning(
                "Thread pool shutdown timed out after %.1fs: some sync functions may still be running",
                cancel_timeout,
            )

        log.info("Task cancellation completed (%d tasks cancelled)", cancelled_count)
        # Always raise KeyboardInterrupt for consistent behavior
        raise KeyboardInterrupt("User cancellation") from e
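

# Design note: CancelledError is deliberately converted to KeyboardInterrupt
# above so that callers see a single, consistent interruption type whether the
# user hit Ctrl-C or a surrounding task was cancelled mid-flight.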


async def _execute_with_retry(
    executor: Callable[[], Coroutine[None, None, T]],
    retry_settings: RetrySettings,
    semaphore: asyncio.Semaphore,
    rate_limiter: AsyncLimiter,
    global_retry_counter: RetryCounter,
    status: ProgressTracker | None = None,
    task_id: Any | None = None,
) -> T:
    import time

    start_time = time.time()
    last_exception: Exception | None = None

    for attempt in range(retry_settings.max_task_retries + 1):
        # Handle backoff before acquiring any resources
        if attempt > 0 and last_exception is not None:
            # Try to increment the global retry counter
            if not await global_retry_counter.try_increment():
                log.error(
                    f"Global retry limit ({global_retry_counter.max_total_retries}) reached. "
                    f"Cannot retry task after: {type(last_exception).__name__}: {last_exception}"
                )
                raise last_exception

            backoff_time = calculate_backoff(
                attempt - 1,  # Previous attempt that failed
                last_exception,
                initial_backoff=retry_settings.initial_backoff,
                max_backoff=retry_settings.max_backoff,
                backoff_factor=retry_settings.backoff_factor,
            )

            # Record the retry in the status display and log appropriately
            if status and task_id is not None:
                # Include retry attempt info and backoff time in the status display
                retry_info = (
                    f"Attempt {attempt}/{retry_settings.max_task_retries} "
                    f"(waiting {backoff_time:.1f}s): {type(last_exception).__name__}: {last_exception}"
                )
                await status.update(task_id, error_msg=retry_info)

                # Use debug level for Rich trackers, warning/info for console trackers
                use_debug_level = status.suppress_logs
            else:
                # No status display: use full logging
                use_debug_level = False

            # Log retry information at the appropriate level
            rate_limit_msg = (
                f"Rate limit hit (attempt {attempt}/{retry_settings.max_task_retries}, "
                f"{global_retry_counter.count}/{global_retry_counter.max_total_retries or '∞'} total), "
                f"backing off for {backoff_time:.2f}s"
            )
            exception_msg = (
                f"Rate limit exception: {type(last_exception).__name__}: {last_exception}"
            )

            if use_debug_level:
                log.debug(rate_limit_msg)
                log.debug(exception_msg)
            else:
                log.warning(rate_limit_msg)
                log.info(exception_msg)
            await asyncio.sleep(backoff_time)

        try:
            # Acquire the semaphore and rate limiter right before making the call
            async with semaphore, rate_limiter:
                # Mark the task as started now that we've passed rate limiting
                if status and task_id is not None and attempt == 0:
                    await status.start(task_id)
                return await executor()
        except Exception as e:
            last_exception = e  # Always store the exception

            if attempt == retry_settings.max_task_retries:
                # Final attempt failed
                if retry_settings.max_task_retries == 0:
                    # No retries configured - raise the original exception directly
                    raise
                else:
                    # Retries were attempted but exhausted - wrap with context
                    total_time = time.time() - start_time
                    log.error(
                        f"Max task retries ({retry_settings.max_task_retries}) exhausted after {total_time:.1f}s. "
                        f"Final attempt failed with: {type(e).__name__}: {e}"
                    )
                    raise RetryExhaustedException(e, retry_settings.max_task_retries, total_time)

            # Check if this is a retriable exception
            if retry_settings.is_retriable(e):
                # Continue to the next retry attempt (global limits are checked at the top of the loop)
                continue
            else:
                # Non-retriable exception, log and re-raise immediately
                log.warning("Non-retriable exception (not retrying): %s", e, exc_info=True)
                raise

    # This should never be reached, but satisfies the type checker
    raise RuntimeError("Unexpected code path in _execute_with_retry")
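

# Rough sketch of the resulting backoff schedule. calculate_backoff lives in
# api_retries; assuming plain exponential growth with no jitter (which may not
# match its exact behavior), something like:
#
#   settings = RetrySettings(
#       max_task_retries=3, initial_backoff=1.0, max_backoff=8.0, backoff_factor=2.0
#   )
#   # failure on attempt 0 -> wait ~1.0s; attempt 1 -> ~2.0s; attempt 2 -> ~4.0s,
#   # i.e. roughly min(initial_backoff * backoff_factor**n, max_backoff)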


## Tests


def test_gather_limited_sync():
    """Test gather_limited_sync with sync functions."""
    import asyncio
    import time

    async def run_test():
        def sync_func(value: int) -> int:
            """Simple sync function for testing."""
            time.sleep(0.1)  # Simulate some work
            return value * 2

        # Test basic functionality
        results = await gather_limited_sync(
            lambda: sync_func(1),
            lambda: sync_func(2),
            lambda: sync_func(3),
            max_concurrent=2,
            max_rps=10.0,
            retry_settings=NO_RETRIES,
        )

        assert results == [2, 4, 6]

    # Run the async test
    asyncio.run(run_test())


def test_gather_limited_sync_with_retries():
    """Test that sync functions can be retried on retriable exceptions."""
    import asyncio

    async def run_test():
        call_count = 0

        def flaky_sync_func() -> str:
            """Sync function that fails the first time, succeeds the second time."""
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("Rate limit exceeded")  # Retriable
            return "success"

        # Should succeed after a retry
        results = await gather_limited_sync(
            lambda: flaky_sync_func(),
            retry_settings=RetrySettings(
                max_task_retries=2,
                initial_backoff=0.1,
                max_backoff=1.0,
                backoff_factor=2.0,
            ),
        )

        assert results == ["success"]
        assert call_count == 2  # Called twice (failed once, succeeded once)

    # Run the async test
    asyncio.run(run_test())


def test_gather_limited_async_basic():
    """Test gather_limited_async with async functions using callables."""
    import asyncio

    async def run_test():
        async def async_func(value: int) -> int:
            """Simple async function for testing."""
            await asyncio.sleep(0.05)  # Simulate async work
            return value * 3

        # Test with callables (the recommended pattern)
        results = await gather_limited_async(
            lambda: async_func(1),
            lambda: async_func(2),
            lambda: async_func(3),
            max_concurrent=2,
            max_rps=10.0,
            retry_settings=NO_RETRIES,
        )

        assert results == [3, 6, 9]

    asyncio.run(run_test())


def test_gather_limited_direct_coroutines():
    """Test gather_limited_async with direct coroutines when retries are disabled."""
    import asyncio

    async def run_test():
        async def async_func(value: int) -> int:
            await asyncio.sleep(0.05)
            return value * 4

        # Test with direct coroutines (only works when retries are disabled)
        results = await gather_limited_async(
            async_func(1),
            async_func(2),
            async_func(3),
            retry_settings=NO_RETRIES,  # Required for direct coroutines
        )

        assert results == [4, 8, 12]

    asyncio.run(run_test())


def test_gather_limited_coroutine_retry_validation():
    """Test that passing coroutines with retries enabled raises ValueError."""
    import asyncio

    async def run_test():
        async def async_func(value: int) -> int:
            return value

        coro = async_func(1)  # Direct coroutine

        # Should raise ValueError when trying to use coroutines with retries
        try:
            await gather_limited_async(
                coro,  # Direct coroutine
                lambda: async_func(2),  # Callable
                retry_settings=RetrySettings(
                    max_task_retries=1,
                    initial_backoff=0.1,
                    max_backoff=1.0,
                    backoff_factor=2.0,
                ),
            )
            raise AssertionError("Expected ValueError")
        except ValueError as e:
            coro.close()  # Close the unused coroutine to prevent a RuntimeWarning
            assert "position 0" in str(e)
            assert "cannot be retried" in str(e)

    asyncio.run(run_test())


def test_gather_limited_async_with_retries():
    """Test that async functions can be retried when using callables."""
    import asyncio

    async def run_test():
        call_count = 0

        async def flaky_async_func() -> str:
            """Async function that fails the first time, succeeds the second time."""
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("Rate limit exceeded")  # Retriable
            return "async_success"

        # Should succeed after a retry using a callable
        results = await gather_limited_async(
            lambda: flaky_async_func(),
            retry_settings=RetrySettings(
                max_task_retries=2,
                initial_backoff=0.1,
                max_backoff=1.0,
                backoff_factor=2.0,
            ),
        )

        assert results == ["async_success"]
        assert call_count == 2  # Called twice (failed once, succeeded once)

    asyncio.run(run_test())


def test_gather_limited_sync_coroutine_validation():
    """Test that passing async function callables to the sync version raises ValueError."""
    import asyncio

    async def run_test():
        async def async_func(value: int) -> int:
            return value

        # Should raise ValueError when trying to use async functions in the sync version
        try:
            await gather_limited_sync(
                lambda: async_func(1),  # Returns a coroutine - should be rejected
                retry_settings=NO_RETRIES,
            )
            raise AssertionError("Expected ValueError")
        except ValueError as e:
            assert "returned a coroutine" in str(e)
            assert "gather_limited_sync() is for synchronous functions only" in str(e)

    asyncio.run(run_test())


def test_gather_limited_retry_exhaustion():
    """Test that retry exhaustion produces clear error messages."""
    import asyncio

    async def run_test():
        call_count = 0

        def always_fails() -> str:
            """Function that always raises retriable exceptions."""
            nonlocal call_count
            call_count += 1
            raise Exception("Rate limit exceeded")  # Always retriable

        # Should exhaust retries and raise RetryExhaustedException
        try:
            await gather_limited_sync(
                lambda: always_fails(),
                retry_settings=RetrySettings(
                    max_task_retries=2,
                    initial_backoff=0.01,
                    max_backoff=0.1,
                    backoff_factor=2.0,
                ),
            )
            raise AssertionError("Expected RetryExhaustedException")
        except RetryExhaustedException as e:
            assert "Max retries (2) exhausted" in str(e)
            assert "Rate limit exceeded" in str(e)
            assert isinstance(e.original_exception, Exception)
            assert call_count == 3  # Initial attempt + 2 retries

    asyncio.run(run_test())


def test_gather_limited_return_exceptions():
    """Test return_exceptions=True behavior for both functions."""
    import asyncio

    async def run_test():
        def failing_sync() -> str:
            raise ValueError("sync error")

        async def failing_async() -> str:
            raise ValueError("async error")

        # Test the sync version with exceptions returned
        sync_results = await gather_limited_sync(
            lambda: "success",
            lambda: failing_sync(),
            return_exceptions=True,
            retry_settings=NO_RETRIES,
        )

        assert len(sync_results) == 2
        assert sync_results[0] == "success"
        assert isinstance(sync_results[1], ValueError)
        assert str(sync_results[1]) == "sync error"

        async def success_async() -> str:
            return "async_success"

        # Test the async version with exceptions returned
        async_results = await gather_limited_async(
            lambda: success_async(),
            lambda: failing_async(),
            return_exceptions=True,
            retry_settings=NO_RETRIES,
        )

        assert len(async_results) == 2
        assert async_results[0] == "async_success"
        assert isinstance(async_results[1], ValueError)
        assert str(async_results[1]) == "async error"

    asyncio.run(run_test())


def test_gather_limited_global_retry_limit():
    """Test that global retry limits are enforced across all tasks."""
    import asyncio

    async def run_test():
        retry_counts = {"task1": 0, "task2": 0}

        def flaky_task(task_name: str) -> str:
            """Tasks that always fail but track retry counts."""
            retry_counts[task_name] += 1
            raise Exception(f"Rate limit exceeded in {task_name}")

        # Test with a very low global retry limit
        try:
            await gather_limited_sync(
                lambda: flaky_task("task1"),
                lambda: flaky_task("task2"),
                retry_settings=RetrySettings(
                    max_task_retries=5,  # Each task could retry up to 5 times
                    max_total_retries=3,  # But only 3 total retries across all tasks
                    initial_backoff=0.01,
                    max_backoff=0.1,
                    backoff_factor=2.0,
                ),
                return_exceptions=True,
            )
        except Exception:
            pass  # With return_exceptions=True failures come back as results, but guard anyway

        # Verify that total retries across both tasks don't exceed the global limit
        total_retries = (retry_counts["task1"] - 1) + (
            retry_counts["task2"] - 1
        )  # -1 for the initial attempts
        assert total_retries <= 3, f"Total retries {total_retries} exceeded global limit of 3"

        # Verify that both tasks were attempted at least once
        assert retry_counts["task1"] >= 1
        assert retry_counts["task2"] >= 1

    asyncio.run(run_test())


def test_gather_limited_funcspec_format():
    """Test gather_limited_async/_sync with the FuncTask format and a custom labeler accessing args."""
    import asyncio

    async def run_test():
        def sync_func(name: str, value: int, multiplier: int = 2) -> str:
            """Sync function that takes args and kwargs."""
            return f"{name}: {value * multiplier}"

        async def async_func(name: str, value: int, multiplier: int = 2) -> str:
            """Async function that takes args and kwargs."""
            await asyncio.sleep(0.01)
            return f"{name}: {value * multiplier}"

        captured_labels = []

        def custom_labeler(i: int, spec: Any) -> str:
            if isinstance(spec, FuncTask):
                # Extract meaningful info from the args for labeling
                if spec.args and len(spec.args) > 0:
                    label = f"Processing {spec.args[0]}"
                else:
                    label = f"Task {i}"
            else:
                label = f"Task {i}"
            captured_labels.append(label)
            return label

        # Test the sync version with the FuncTask format and a custom labeler
        sync_results = await gather_limited_sync(
            FuncTask(sync_func, ("user1", 100), {"multiplier": 3}),  # user1: 300
            FuncTask(sync_func, ("user2", 50)),  # user2: 100 (default multiplier)
            labeler=custom_labeler,
            retry_settings=NO_RETRIES,
        )

        assert sync_results == ["user1: 300", "user2: 100"]
        assert captured_labels == ["Processing user1", "Processing user2"]

        # Reset labels for the async test
        captured_labels.clear()

        # Test the async version with the FuncTask format and a custom labeler
        async_results = await gather_limited_async(
            FuncTask(async_func, ("api_call", 10), {"multiplier": 4}),  # api_call: 40
            FuncTask(async_func, ("data_fetch", 5)),  # data_fetch: 10 (default multiplier)
            labeler=custom_labeler,
            retry_settings=NO_RETRIES,
        )

        assert async_results == ["api_call: 40", "data_fetch: 10"]
        assert captured_labels == ["Processing api_call", "Processing data_fetch"]

    asyncio.run(run_test())


def test_gather_limited_sync_cooperative_cancellation():
    """Test gather_limited_sync with cooperative cancellation via threading.Event."""
    import asyncio
    import time

    async def run_test():
        cancel_event = threading.Event()
        call_counts = {"task1": 0, "task2": 0}

        def cancellable_sync_func(task_name: str, work_duration: float) -> str:
            """Sync function that checks the cancellation event periodically."""
            call_counts[task_name] += 1
            start_time = time.time()

            while time.time() - start_time < work_duration:
                if cancel_event.is_set():
                    return f"{task_name}: cancelled"
                time.sleep(0.01)  # Small sleep to allow the cancellation check

            return f"{task_name}: completed"

        # Test cooperative cancellation - tasks should respect the cancel_event
        results = await gather_limited_sync(
            lambda: cancellable_sync_func("task1", 0.1),  # Short duration
            lambda: cancellable_sync_func("task2", 0.1),  # Short duration
            cancel_event=cancel_event,
            cancel_timeout=1.0,
            retry_settings=NO_RETRIES,
        )

        # Should complete normally since cancel_event wasn't set
        assert results == ["task1: completed", "task2: completed"]
        assert call_counts["task1"] == 1
        assert call_counts["task2"] == 1

        # Test that cancel_event can be used independently
        cancel_event.set()  # Set the cancellation signal

        results2 = await gather_limited_sync(
            lambda: cancellable_sync_func("task1", 1.0),  # Would take long if not cancelled
            lambda: cancellable_sync_func("task2", 1.0),  # Would take long if not cancelled
            cancel_event=cancel_event,
            cancel_timeout=1.0,
            retry_settings=NO_RETRIES,
        )

        # Should be cancelled quickly since cancel_event is already set
        assert results2 == ["task1: cancelled", "task2: cancelled"]
        # Call counts should increment
        assert call_counts["task1"] == 2
        assert call_counts["task2"] == 2

    asyncio.run(run_test())