kash-shell 0.3.18__py3-none-any.whl → 0.3.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kash/actions/core/{markdownify.py → markdownify_html.py} +3 -6
- kash/commands/workspace/workspace_commands.py +10 -88
- kash/config/colors.py +8 -6
- kash/config/text_styles.py +2 -0
- kash/docs/markdown/topics/a1_what_is_kash.md +1 -1
- kash/docs/markdown/topics/b1_kash_overview.md +34 -45
- kash/exec/__init__.py +3 -0
- kash/exec/action_decorators.py +20 -5
- kash/exec/action_exec.py +2 -2
- kash/exec/{fetch_url_metadata.py → fetch_url_items.py} +42 -14
- kash/exec/llm_transforms.py +1 -1
- kash/exec/shell_callable_action.py +1 -1
- kash/file_storage/file_store.py +7 -1
- kash/file_storage/store_filenames.py +4 -0
- kash/help/function_param_info.py +1 -1
- kash/help/help_pages.py +1 -1
- kash/help/help_printing.py +1 -1
- kash/llm_utils/llm_completion.py +1 -1
- kash/model/actions_model.py +6 -0
- kash/model/items_model.py +18 -3
- kash/shell/output/shell_output.py +15 -0
- kash/utils/api_utils/api_retries.py +305 -0
- kash/utils/api_utils/cache_requests_limited.py +84 -0
- kash/utils/api_utils/gather_limited.py +987 -0
- kash/utils/api_utils/progress_protocol.py +299 -0
- kash/utils/common/function_inspect.py +66 -1
- kash/utils/common/parse_docstring.py +347 -0
- kash/utils/common/testing.py +10 -7
- kash/utils/rich_custom/multitask_status.py +631 -0
- kash/utils/text_handling/escape_html_tags.py +16 -11
- kash/utils/text_handling/markdown_render.py +1 -0
- kash/web_content/web_extract.py +34 -15
- kash/web_content/web_page_model.py +10 -1
- kash/web_gen/templates/base_styles.css.jinja +26 -20
- kash/web_gen/templates/components/toc_styles.css.jinja +1 -1
- kash/web_gen/templates/components/tooltip_scripts.js.jinja +171 -19
- kash/web_gen/templates/components/tooltip_styles.css.jinja +23 -8
- {kash_shell-0.3.18.dist-info → kash_shell-0.3.21.dist-info}/METADATA +4 -2
- {kash_shell-0.3.18.dist-info → kash_shell-0.3.21.dist-info}/RECORD +42 -37
- kash/help/docstring_utils.py +0 -111
- {kash_shell-0.3.18.dist-info → kash_shell-0.3.21.dist-info}/WHEEL +0 -0
- {kash_shell-0.3.18.dist-info → kash_shell-0.3.21.dist-info}/entry_points.txt +0 -0
- {kash_shell-0.3.18.dist-info → kash_shell-0.3.21.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,987 @@
|
|
|
1
|
+
from __future__ import annotations

import asyncio
import inspect
import logging
import threading
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from typing import Any, Literal, TypeAlias, TypeVar, cast, overload

from aiolimiter import AsyncLimiter

from kash.utils.api_utils.api_retries import (
    DEFAULT_RETRIES,
    NO_RETRIES,
    RetryExhaustedException,
    RetrySettings,
    calculate_backoff,
)
from kash.utils.api_utils.progress_protocol import Labeler, ProgressTracker, TaskState
|
|
21
|
+
|
|
22
|
+
# Module-level logger used for execution, retry, and cancellation messages.
log = logging.getLogger(__name__)

# Generic result type flowing through task specs, labelers, and gather results.
T = TypeVar("T")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True)
class FuncTask:
    """
    An unevaluated function call: the function plus the positional and keyword
    arguments to call it with.

    Unlike an opaque lambda, the arguments remain inspectable, so a labeler can
    build a task label from them. The gather functions in this module invoke it
    as `func(*args, **kwargs)`, freshly on every attempt.
    """

    # The function to invoke.
    func: Callable[..., Any]
    # Positional arguments, in call order (none by default).
    args: tuple[Any, ...] = field(default=())
    # Keyword arguments; each instance gets its own new dict by default.
    kwargs: dict[str, Any] = field(default_factory=dict)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# What gather_limited_async() accepts for each task:
#   - a zero-argument callable returning a fresh coroutine (safe to retry),
#   - an already-created coroutine (awaitable only once, so retries must be off),
#   - a FuncTask whose function returns a coroutine when called.
CoroSpec: TypeAlias = Callable[[], Coroutine[None, None, T]] | Coroutine[None, None, T] | FuncTask

# What gather_limited_sync() accepts for each task: a zero-argument callable
# returning a value, or a FuncTask wrapping a synchronous function.
SyncSpec: TypeAlias = Callable[[], T] | FuncTask

# Labelers are called as labeler(index, spec) to produce each task's display label.
CoroLabeler: TypeAlias = Labeler[CoroSpec[T]]
SyncLabeler: TypeAlias = Labeler[SyncSpec[T]]

# Shared defaults for both gather functions.
DEFAULT_MAX_CONCURRENT: int = 5  # size of the concurrency semaphore
DEFAULT_MAX_RPS: float = 5.0  # AsyncLimiter rate, per 1-second window
DEFAULT_CANCEL_TIMEOUT: float = 1.0  # seconds to wait for tasks/threads on cancellation
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class RetryCounter:
|
|
53
|
+
"""Thread-safe counter for tracking retries across all tasks."""
|
|
54
|
+
|
|
55
|
+
def __init__(self, max_total_retries: int | None):
|
|
56
|
+
self.max_total_retries = max_total_retries
|
|
57
|
+
self.count = 0
|
|
58
|
+
self._lock = asyncio.Lock()
|
|
59
|
+
|
|
60
|
+
async def try_increment(self) -> bool:
|
|
61
|
+
"""
|
|
62
|
+
Try to increment the retry counter.
|
|
63
|
+
Returns True if increment was successful, False if limit reached.
|
|
64
|
+
"""
|
|
65
|
+
if self.max_total_retries is None:
|
|
66
|
+
return True
|
|
67
|
+
|
|
68
|
+
async with self._lock:
|
|
69
|
+
if self.count < self.max_total_retries:
|
|
70
|
+
self.count += 1
|
|
71
|
+
return True
|
|
72
|
+
return False
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Typing overloads: the result element type depends on `return_exceptions`.
# Literal types are required here; with plain `bool` both signatures overlap
# fully and type checkers always pick the first (list[T]) overload.


@overload
async def gather_limited_async(
    *coro_specs: CoroSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: Literal[False] = False,
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: CoroLabeler[T] | None = None,
) -> list[T]: ...


@overload
async def gather_limited_async(
    *coro_specs: CoroSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: Literal[True],
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: CoroLabeler[T] | None = None,
) -> list[T | BaseException]: ...


# Fallback for a `return_exceptions` value only known at runtime.
@overload
async def gather_limited_async(
    *coro_specs: CoroSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: bool,
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: CoroLabeler[T] | None = None,
) -> list[T] | list[T | BaseException]: ...
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
async def gather_limited_async(
    *coro_specs: CoroSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: bool = False,
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: CoroLabeler[T] | None = None,
) -> list[T] | list[T | BaseException]:
    """
    Rate-limited version of `asyncio.gather()` with retry logic and optional progress display.
    Uses the aiolimiter leaky-bucket algorithm with exponential backoff on failures.

    Supports two levels of retry limits:
    - Per-task retries: max_task_retries attempts per individual task
    - Global retries: max_total_retries attempts across all tasks (prevents cascade failures)

    Can optionally display live progress with retry indicators using TaskStatus
    (or any other ProgressTracker).

    Accepts:
    - Callables that return coroutines: `lambda: some_async_func(arg)` (recommended for retries)
    - Coroutines directly: `some_async_func(arg)` (only if retries disabled)
    - FuncTask objects: `FuncTask(some_async_func, (arg1, arg2), {"kwarg": value})`
      (args accessible to labeler)

    Examples:
    ```python
    # With progress display and custom labeling:
    from kash.utils.rich_custom.task_status import TaskStatus

    async with TaskStatus() as status:
        await gather_limited_async(
            lambda: fetch_url("http://example.com"),
            lambda: process_data(data),
            status=status,
            labeler=lambda i, spec: f"Task {i+1}",
            retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25),
        )

    # Without progress display:
    await gather_limited_async(
        lambda: fetch_url("http://example.com"),
        lambda: process_data(data),
        retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25),
    )
    ```

    Args:
        *coro_specs: Callables returning coroutines, coroutines, or FuncTask objects
        max_concurrent: Maximum number of concurrent executions
        max_rps: Maximum requests per second
        return_exceptions: If True, exceptions are returned as results
        retry_settings: Configuration for retry behavior, or None to disable retries
        status: Optional ProgressTracker instance for progress display
        labeler: Optional function to generate labels: labeler(index, spec) -> str

    Returns:
        List of results in the same order as input specifications

    Raises:
        ValueError: If coroutines are passed when retries are enabled. All directly
            passed coroutines are closed first so none triggers a "never awaited" warning.
        RetryExhaustedException: If a task keeps failing with retriable errors
            (when return_exceptions is False).
        KeyboardInterrupt: If execution is interrupted or cancelled.
    """
    log.info(
        "Executing with concurrency %s at %s rps, %s",
        max_concurrent,
        max_rps,
        retry_settings,
    )
    if not coro_specs:
        return []

    retry_settings = retry_settings or NO_RETRIES

    # A coroutine object can only be awaited once, so it can never be retried.
    if retry_settings.max_task_retries > 0:
        bare_coros = [(i, spec) for i, spec in enumerate(coro_specs) if inspect.iscoroutine(spec)]
        if bare_coros:
            # Close them all before bailing out: none will ever be awaited, and an
            # unclosed coroutine emits a RuntimeWarning when garbage-collected.
            for _, bare_coro in bare_coros:
                bare_coro.close()
            raise ValueError(
                f"Coroutine at position {bare_coros[0][0]} cannot be retried. "
                "When retries are enabled (max_task_retries > 0), pass callables that return fresh coroutines: "
                "lambda: your_async_func(args) instead of your_async_func(args)"
            )

    semaphore = asyncio.Semaphore(max_concurrent)
    rate_limiter = AsyncLimiter(max_rps, 1.0)

    # Global retry counter (shared across all tasks)
    global_retry_counter = RetryCounter(retry_settings.max_total_retries)

    async def run_task_with_retry(i: int, coro_spec: CoroSpec[T]) -> T:
        # Register the task with the progress tracker, if any.
        label = labeler(i, coro_spec) if labeler else f"task:{i}"
        task_id = await status.add(label) if status else None

        async def executor() -> T:
            # Create a fresh coroutine for each attempt.
            if isinstance(coro_spec, FuncTask):
                coro = coro_spec.func(*coro_spec.args, **coro_spec.kwargs)
            elif callable(coro_spec):
                coro = coro_spec()
            else:
                # Direct coroutine: only reachable with retries disabled (validated above).
                coro = coro_spec
            return await coro

        try:
            result = await _execute_with_retry(
                executor,
                retry_settings,
                semaphore,
                rate_limiter,
                global_retry_counter,
                status,
                task_id,
            )

            if status and task_id is not None:
                await status.finish(task_id, TaskState.COMPLETED)

            return result

        except Exception as e:
            if status and task_id is not None:
                await status.finish(task_id, TaskState.FAILED, str(e))
            raise

    return await _gather_with_interrupt_handling(
        [run_task_with_retry(i, spec) for i, spec in enumerate(coro_specs)],
        return_exceptions,
    )
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
# Typing overloads: the result element type depends on `return_exceptions`.
# Literal types are required here; with plain `bool` both signatures overlap
# fully and type checkers always pick the first (list[T]) overload.


@overload
async def gather_limited_sync(
    *sync_specs: SyncSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: Literal[False] = False,
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: SyncLabeler[T] | None = None,
    cancel_event: threading.Event | None = None,
    cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
) -> list[T]: ...


@overload
async def gather_limited_sync(
    *sync_specs: SyncSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: Literal[True],
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: SyncLabeler[T] | None = None,
    cancel_event: threading.Event | None = None,
    cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
) -> list[T | BaseException]: ...


# Fallback for a `return_exceptions` value only known at runtime.
@overload
async def gather_limited_sync(
    *sync_specs: SyncSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: bool,
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: SyncLabeler[T] | None = None,
    cancel_event: threading.Event | None = None,
    cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
) -> list[T] | list[T | BaseException]: ...
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
async def gather_limited_sync(
    *sync_specs: SyncSpec[T],
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    max_rps: float = DEFAULT_MAX_RPS,
    return_exceptions: bool = False,
    retry_settings: RetrySettings | None = DEFAULT_RETRIES,
    status: ProgressTracker | None = None,
    labeler: SyncLabeler[T] | None = None,
    cancel_event: threading.Event | None = None,
    cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
) -> list[T] | list[T | BaseException]:
    """
    Rate-limited version of `asyncio.gather()` for sync functions with retry logic.
    Handles the asyncio.to_thread() boundary correctly for consistent exception propagation.

    Supports two levels of retry limits:
    - Per-task retries: max_task_retries attempts per individual task
    - Global retries: max_total_retries attempts across all tasks

    Supports cooperative cancellation and graceful thread termination on interruption.

    Args:
        *sync_specs: Callables that return values (not coroutines) or FuncTask objects
        max_concurrent: Maximum number of concurrent executions
        max_rps: Maximum requests per second
        return_exceptions: If True, exceptions are returned as results
        retry_settings: Configuration for retry behavior, or None to disable retries
        status: Optional ProgressTracker instance for progress display
        labeler: Optional function to generate labels: labeler(index, spec) -> str
        cancel_event: Optional threading.Event that will be set on cancellation
        cancel_timeout: Max seconds to wait for threads to terminate on cancellation

    Returns:
        List of results in the same order as input specifications

    Raises:
        ValueError: If a spec (plain callable or FuncTask) returns a coroutine, i.e.
            an async function was passed to this sync-only variant.
        KeyboardInterrupt: If execution is interrupted or cancelled.

    Example:
    ```python
    # Without cooperative cancellation
    results = await gather_limited_sync(
        lambda: some_sync_function(arg1),
        lambda: another_sync_function(arg2),
        max_concurrent=3,
        max_rps=2.0,
        retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
    )

    # With cooperative cancellation
    cancel_event = threading.Event()
    results = await gather_limited_sync(
        lambda: cancellable_sync_function(cancel_event, arg1),
        lambda: another_cancellable_function(cancel_event, arg2),
        cancel_event=cancel_event,
        cancel_timeout=5.0,
    )
    ```
    """
    log.info(
        "Executing with concurrency %s at %s rps, %s",
        max_concurrent,
        max_rps,
        retry_settings,
    )
    if not sync_specs:
        return []

    retry_settings = retry_settings or NO_RETRIES

    semaphore = asyncio.Semaphore(max_concurrent)
    rate_limiter = AsyncLimiter(max_rps, 1.0)

    # Global retry counter (shared across all tasks)
    global_retry_counter = RetryCounter(retry_settings.max_total_retries)

    async def run_task_with_retry(i: int, sync_spec: SyncSpec[T]) -> T:
        # Register the task with the progress tracker, if any.
        label = labeler(i, sync_spec) if labeler else f"task:{i}"
        task_id = await status.add(label) if status else None

        async def executor() -> T:
            # Run the sync call in a worker thread so it doesn't block the event loop.
            if isinstance(sync_spec, FuncTask):
                result = await asyncio.to_thread(
                    sync_spec.func, *sync_spec.args, **sync_spec.kwargs
                )
            else:
                result = await asyncio.to_thread(sync_spec)
            # Reject async functions in either spec form; otherwise the un-awaited
            # coroutine object would be returned as if it were the result.
            if inspect.iscoroutine(result):
                # Clean up the coroutine we accidentally created
                result.close()
                raise ValueError(
                    "Callable returned a coroutine. "
                    "gather_limited_sync() is for synchronous functions only. "
                    "Use gather_limited_async() for async functions."
                )
            return cast(T, result)

        try:
            result = await _execute_with_retry(
                executor,
                retry_settings,
                semaphore,
                rate_limiter,
                global_retry_counter,
                status,
                task_id,
            )

            if status and task_id is not None:
                await status.finish(task_id, TaskState.COMPLETED)

            return result

        except Exception as e:
            if status and task_id is not None:
                await status.finish(task_id, TaskState.FAILED, str(e))
            raise

    return await _gather_with_interrupt_handling(
        [run_task_with_retry(i, spec) for i, spec in enumerate(sync_specs)],
        return_exceptions,
        cancel_event=cancel_event,
        cancel_timeout=cancel_timeout,
    )
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
async def _gather_with_interrupt_handling(
    tasks: list[Coroutine[None, None, T]],
    return_exceptions: bool = False,
    cancel_event: threading.Event | None = None,
    cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
) -> list[T] | list[T | BaseException]:
    """
    Execute asyncio.gather with graceful KeyboardInterrupt handling.

    On KeyboardInterrupt or CancelledError this:
    1. sets `cancel_event` (if given) so cooperative sync functions can stop,
    2. cancels every unfinished task and waits up to `cancel_timeout` seconds,
    3. waits up to `cancel_timeout` seconds for the running loop's default
       executor to shut down. Per the asyncio docs, using the default executor
       on this loop afterwards raises RuntimeError.
    4. raises KeyboardInterrupt, chained from the original exception.

    Args:
        tasks: Coroutines to wrap in tasks and gather
        return_exceptions: Whether to return exceptions as results
        cancel_event: Optional threading.Event to signal cancellation to sync functions
        cancel_timeout: Max seconds to wait for tasks, and then threads, on cancellation

    Returns:
        Results from asyncio.gather, in input order

    Raises:
        KeyboardInterrupt: Re-raised after graceful cancellation
    """
    # Create tasks from coroutines so we can cancel them properly
    async_tasks = [asyncio.create_task(coro) for coro in tasks]

    try:
        return await asyncio.gather(*async_tasks, return_exceptions=return_exceptions)
    except (KeyboardInterrupt, asyncio.CancelledError) as e:
        # Handle both KeyboardInterrupt and CancelledError (which is what tasks actually receive)
        log.warning("Interrupt received, cancelling %d tasks...", len(async_tasks))

        if cancel_event is not None:
            cancel_event.set()
            log.debug("Cancellation event set for cooperative sync function termination")

        cancelled_count = 0
        for task in async_tasks:
            if not task.done():
                task.cancel()
                cancelled_count += 1

        if cancelled_count > 0:
            try:
                await asyncio.wait_for(
                    asyncio.gather(*async_tasks, return_exceptions=True), timeout=cancel_timeout
                )
            # asyncio.TimeoutError is what wait_for raises; it is only an alias
            # of the builtin TimeoutError from Python 3.11 on.
            except (asyncio.TimeoutError, asyncio.CancelledError):
                log.warning("Some tasks did not cancel within timeout")

        # Wait for worker threads (e.g. from asyncio.to_thread) to finish.
        loop = asyncio.get_running_loop()
        try:
            log.debug("Waiting up to %.1fs for thread pool termination...", cancel_timeout)
            await asyncio.wait_for(
                loop.shutdown_default_executor(),
                timeout=cancel_timeout,
            )
            log.info("Thread pool shutdown completed")
        except asyncio.TimeoutError:
            log.warning(
                "Thread pool shutdown timed out after %.1fs: some sync functions may still be running",
                cancel_timeout,
            )

        log.info("Task cancellation completed (%d tasks cancelled)", cancelled_count)
        # Always raise KeyboardInterrupt for consistent behavior
        raise KeyboardInterrupt("User cancellation") from e
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
async def _execute_with_retry(
    executor: Callable[[], Coroutine[None, None, T]],
    retry_settings: RetrySettings,
    semaphore: asyncio.Semaphore,
    rate_limiter: AsyncLimiter,
    global_retry_counter: RetryCounter,
    status: ProgressTracker | None = None,
    task_id: Any | None = None,
) -> T:
    """
    Run `executor()` under the concurrency and rate limits, retrying failures.

    The semaphore and rate limiter are held only around the call itself; the
    backoff sleep between attempts happens outside them. Every retry first
    claims one unit from `global_retry_counter`.

    Args:
        executor: Zero-argument factory producing a fresh coroutine per attempt
        retry_settings: Per-task retry limit, backoff parameters, and retriability check
        semaphore: Limits concurrent executions
        rate_limiter: Limits execution rate
        global_retry_counter: Retry budget shared across all tasks
        status: Optional progress tracker to report start and retries to
        task_id: The task's id in `status`

    Returns:
        The result of the first successful attempt.

    Raises:
        The original exception, unwrapped, if retries are disabled
        (max_task_retries == 0), if it is not retriable, or if the global
        retry budget is exhausted.
        RetryExhaustedException: If a retriable failure persists through all
            max_task_retries retries.
    """
    import time

    # Monotonic clock: elapsed-time reporting must not jump with wall-clock changes.
    start_time = time.monotonic()
    last_exception: Exception | None = None
    max_retries = retry_settings.max_task_retries

    for attempt in range(max_retries + 1):
        # Handle backoff before acquiring any resources
        if attempt > 0 and last_exception is not None:
            if not await global_retry_counter.try_increment():
                log.error(
                    "Global retry limit (%s) reached. Cannot retry task after: %s: %s",
                    global_retry_counter.max_total_retries,
                    type(last_exception).__name__,
                    last_exception,
                )
                raise last_exception

            backoff_time = calculate_backoff(
                attempt - 1,  # Previous attempt that failed
                last_exception,
                initial_backoff=retry_settings.initial_backoff,
                max_backoff=retry_settings.max_backoff,
                backoff_factor=retry_settings.backoff_factor,
            )

            if status and task_id is not None:
                # Include retry attempt info and backoff time in the status display
                retry_info = (
                    f"Attempt {attempt}/{max_retries} "
                    f"(waiting {backoff_time:.1f}s): {type(last_exception).__name__}: {last_exception}"
                )
                await status.update(task_id, error_msg=retry_info)
                # Trackers that render their own display (e.g. Rich) set
                # suppress_logs, so retries are logged at debug level only.
                use_debug_level = status.suppress_logs
            else:
                # No status display: use full logging
                use_debug_level = False

            total_limit = (
                "∞"
                if global_retry_counter.max_total_retries is None
                else global_retry_counter.max_total_retries
            )
            rate_limit_msg = (
                f"Rate limit hit (attempt {attempt}/{max_retries}, "
                f"{global_retry_counter.count}/{total_limit} total), "
                f"backing off for {backoff_time:.2f}s"
            )
            exception_msg = (
                f"Rate limit exception: {type(last_exception).__name__}: {last_exception}"
            )

            if use_debug_level:
                log.debug(rate_limit_msg)
                log.debug(exception_msg)
            else:
                log.warning(rate_limit_msg)
                log.info(exception_msg)
            await asyncio.sleep(backoff_time)

        try:
            # Acquire semaphore and rate limiter right before making the call
            async with semaphore, rate_limiter:
                # Mark task as started now that we've passed rate limiting
                if status and task_id is not None and attempt == 0:
                    await status.start(task_id)
                return await executor()
        except Exception as e:
            last_exception = e

            if max_retries == 0:
                # No retries configured - raise original exception directly
                raise

            if not retry_settings.is_retriable(e):
                # Non-retriable errors propagate unwrapped, even on the last attempt.
                log.warning("Non-retriable exception (not retrying): %s", e, exc_info=True)
                raise

            if attempt == max_retries:
                # Retriable failure on the final attempt: wrap with retry context.
                total_time = time.monotonic() - start_time
                log.error(
                    "Max task retries (%s) exhausted after %.1fs. "
                    "Final attempt failed with: %s: %s",
                    max_retries,
                    total_time,
                    type(e).__name__,
                    e,
                )
                raise RetryExhaustedException(e, max_retries, total_time) from e

            # Otherwise loop again; the global budget and backoff are handled at the top.

    # This should never be reached, but satisfy type checker
    raise RuntimeError("Unexpected code path in _execute_with_retry")
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
## Tests
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
def test_gather_limited_sync():
    """Sync callables run through gather_limited_sync and keep their input order."""
    import time

    def doubled_after_pause(n: int) -> int:
        """Sleep briefly to mimic blocking work, then return twice the input."""
        time.sleep(0.1)
        return n * 2

    async def scenario():
        outcome = await gather_limited_sync(
            *[lambda n=n: doubled_after_pause(n) for n in (1, 2, 3)],
            max_concurrent=2,
            max_rps=10.0,
            retry_settings=NO_RETRIES,
        )
        assert outcome == [2, 4, 6]

    asyncio.run(scenario())
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def test_gather_limited_sync_with_retries():
    """A sync callable failing once with a retriable error succeeds on the retry."""
    attempts = {"n": 0}

    def fails_first_time() -> str:
        """Raise a rate-limit error on the first call only."""
        attempts["n"] += 1
        if attempts["n"] == 1:
            raise Exception("Rate limit exceeded")  # Retriable
        return "success"

    quick_retries = RetrySettings(
        max_task_retries=2,
        initial_backoff=0.1,
        max_backoff=1.0,
        backoff_factor=2.0,
    )

    async def scenario():
        return await gather_limited_sync(
            lambda: fails_first_time(),
            retry_settings=quick_retries,
        )

    assert asyncio.run(scenario()) == ["success"]
    assert attempts["n"] == 2  # one failure, then one success
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
def test_gather_limited_async_basic():
    """Callables returning coroutines run through gather_limited_async in order."""

    async def tripled_after_pause(n: int) -> int:
        """Yield to the loop briefly, then return three times the input."""
        await asyncio.sleep(0.05)
        return n * 3

    async def scenario():
        # Callables (not bare coroutines) are the recommended spec form.
        return await gather_limited_async(
            *[lambda n=n: tripled_after_pause(n) for n in (1, 2, 3)],
            max_concurrent=2,
            max_rps=10.0,
            retry_settings=NO_RETRIES,
        )

    assert asyncio.run(scenario()) == [3, 6, 9]
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
def test_gather_limited_direct_coroutines():
    """Bare coroutine objects are accepted as long as retries are disabled."""

    async def quadrupled_after_pause(n: int) -> int:
        await asyncio.sleep(0.05)
        return n * 4

    async def scenario():
        # Bare coroutines can only be awaited once, hence NO_RETRIES.
        return await gather_limited_async(
            *[quadrupled_after_pause(n) for n in (1, 2, 3)],
            retry_settings=NO_RETRIES,
        )

    assert asyncio.run(scenario()) == [4, 8, 12]
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
def test_gather_limited_coroutine_retry_validation():
    """Test that passing coroutines with retries enabled raises ValueError."""

    async def run_test():
        async def async_func(value: int) -> int:
            return value

        coro = async_func(1)  # Direct coroutine
        caught: ValueError | None = None

        try:
            await gather_limited_async(
                coro,  # Direct coroutine
                lambda: async_func(2),  # Callable
                retry_settings=RetrySettings(
                    max_task_retries=1,
                    initial_backoff=0.1,
                    max_backoff=1.0,
                    backoff_factor=2.0,
                ),
            )
        except ValueError as e:
            caught = e
        finally:
            # Close on every path (closing twice is a no-op) so the coroutine never
            # triggers a "never awaited" RuntimeWarning, even if the test fails.
            coro.close()

        if caught is None:
            raise AssertionError("Expected ValueError")
        assert "position 0" in str(caught)
        assert "cannot be retried" in str(caught)

    asyncio.run(run_test())
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
def test_gather_limited_async_with_retries():
    """An async callable failing once with a retriable error succeeds on the retry."""
    attempts = {"n": 0}

    async def fails_first_time() -> str:
        """Raise a rate-limit error on the first call only."""
        attempts["n"] += 1
        if attempts["n"] == 1:
            raise Exception("Rate limit exceeded")  # Retriable
        return "async_success"

    quick_retries = RetrySettings(
        max_task_retries=2,
        initial_backoff=0.1,
        max_backoff=1.0,
        backoff_factor=2.0,
    )

    async def scenario():
        # A callable yields a fresh coroutine per attempt, which makes retry possible.
        return await gather_limited_async(
            lambda: fails_first_time(),
            retry_settings=quick_retries,
        )

    assert asyncio.run(scenario()) == ["async_success"]
    assert attempts["n"] == 2  # one failure, then one success
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
def test_gather_limited_sync_coroutine_validation():
    """Ensure gather_limited_sync refuses a callable whose result is a coroutine."""
    import asyncio

    async def run_test():
        async def async_func(value: int) -> int:
            return value

        try:
            # The lambda itself is synchronous but hands back a coroutine object,
            # which the sync variant must reject.
            await gather_limited_sync(
                lambda: async_func(1),
                retry_settings=NO_RETRIES,
            )
        except ValueError as err:
            message = str(err)
            assert "returned a coroutine" in message
            assert "gather_limited_sync() is for synchronous functions only" in message
        else:
            raise AssertionError("Expected ValueError")

    asyncio.run(run_test())
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
def test_gather_limited_retry_exhaustion():
    """Verify that running out of per-task retries raises RetryExhaustedException with context."""
    import asyncio

    async def run_test():
        invocations = 0

        def always_fails() -> str:
            """Raise the same retriable error on every call."""
            nonlocal invocations
            invocations += 1
            raise Exception("Rate limit exceeded")

        settings = RetrySettings(
            max_task_retries=2,
            initial_backoff=0.01,
            max_backoff=0.1,
            backoff_factor=2.0,
        )
        try:
            await gather_limited_sync(lambda: always_fails(), retry_settings=settings)
        except RetryExhaustedException as exc:
            text = str(exc)
            assert "Max retries (2) exhausted" in text
            assert "Rate limit exceeded" in text
            assert isinstance(exc.original_exception, Exception)
        else:
            raise AssertionError("Expected RetryExhaustedException")

        # The first attempt plus the two permitted retries.
        assert invocations == 3

    asyncio.run(run_test())
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
def test_gather_limited_return_exceptions():
    """With return_exceptions=True, failures come back as result values in both variants."""
    import asyncio

    async def run_test():
        def failing_sync() -> str:
            raise ValueError("sync error")

        async def failing_async() -> str:
            raise ValueError("async error")

        async def success_async() -> str:
            return "async_success"

        # Sync variant: one success and one failure, each captured at its own position.
        sync_results = await gather_limited_sync(
            lambda: "success",
            lambda: failing_sync(),
            return_exceptions=True,
            retry_settings=NO_RETRIES,
        )
        assert len(sync_results) == 2
        ok, err = sync_results
        assert ok == "success"
        assert isinstance(err, ValueError)
        assert str(err) == "sync error"

        # Async variant: same expectations, with coroutine-producing callables.
        async_results = await gather_limited_async(
            lambda: success_async(),
            lambda: failing_async(),
            return_exceptions=True,
            retry_settings=NO_RETRIES,
        )
        assert len(async_results) == 2
        ok, err = async_results
        assert ok == "async_success"
        assert isinstance(err, ValueError)
        assert str(err) == "async error"

    asyncio.run(run_test())
|
|
835
|
+
|
|
836
|
+
|
|
837
|
+
def test_gather_limited_global_retry_limit():
    """Test that global retry limits are enforced across all tasks.

    Two tasks that always fail are each allowed up to 5 retries, but the whole
    run is capped at 3 retries in total. The test checks three things:
    - both tasks were attempted;
    - the global cap was respected;
    - when results are returned, every entry is an exception.
    """
    import asyncio

    async def run_test():
        attempt_counts = {"task1": 0, "task2": 0}

        def flaky_task(task_name: str) -> str:
            """Always fail, recording how many times each task was invoked."""
            attempt_counts[task_name] += 1
            raise Exception(f"Rate limit exceeded in {task_name}")

        try:
            results = await gather_limited_sync(
                lambda: flaky_task("task1"),
                lambda: flaky_task("task2"),
                retry_settings=RetrySettings(
                    max_task_retries=5,  # Each task could retry up to 5 times
                    max_total_retries=3,  # But only 3 total retries across all tasks
                    initial_backoff=0.01,
                    max_backoff=0.1,
                    backoff_factor=2.0,
                ),
                return_exceptions=True,
            )
        except Exception:
            # NOTE(review): with return_exceptions=True, failures are expected to be
            # returned rather than raised. Raising is still tolerated because it is not
            # visible here whether exhausting the global budget is surfaced by raising.
            # Confirm this and tighten the check.
            results = None

        if results is not None:
            # Every task always raises, so no entry can be a successful value.
            assert len(results) == 2
            assert all(isinstance(r, Exception) for r in results), results

        # Both tasks must have been attempted at least once.
        assert attempt_counts["task1"] >= 1
        assert attempt_counts["task2"] >= 1

        # Retries are the invocations beyond each task's first attempt.
        total_retries = sum(attempt_counts.values()) - len(attempt_counts)
        assert total_retries <= 3, f"Total retries {total_retries} exceeded global limit of 3"

    asyncio.run(run_test())
|
|
877
|
+
|
|
878
|
+
|
|
879
|
+
def test_gather_limited_funcspec_format():
    """Test gather_limited with FuncTask specs and a custom labeler that reads task args.

    Covers both the sync and async variants. It checks that positional args and
    kwargs from each FuncTask are forwarded to the function. It also checks that
    the labeler sees each FuncTask and can label it from its first positional arg.
    """
    import asyncio

    async def run_test():
        def sync_func(name: str, value: int, multiplier: int = 2) -> str:
            """Format the scaled value synchronously."""
            return f"{name}: {value * multiplier}"

        async def async_func(name: str, value: int, multiplier: int = 2) -> str:
            """Format the scaled value after a brief async pause."""
            await asyncio.sleep(0.01)
            return f"{name}: {value * multiplier}"

        captured_labels = []

        def custom_labeler(i: int, spec: Any) -> str:
            # Label a FuncTask by its first positional arg; otherwise fall back to the index.
            has_first_arg = isinstance(spec, FuncTask) and spec.args
            label = f"Processing {spec.args[0]}" if has_first_arg else f"Task {i}"
            captured_labels.append(label)
            return label

        # Sync variant: args and kwargs from each FuncTask reach sync_func.
        sync_results = await gather_limited_sync(
            FuncTask(sync_func, ("user1", 100), {"multiplier": 3}),  # -> "user1: 300"
            FuncTask(sync_func, ("user2", 50)),  # -> "user2: 100" (default multiplier)
            labeler=custom_labeler,
            retry_settings=NO_RETRIES,
        )
        assert sync_results == ["user1: 300", "user2: 100"]
        assert captured_labels == ["Processing user1", "Processing user2"]

        # Start the async run with an empty label log.
        captured_labels.clear()

        # Async variant: same forwarding and labeling, with a coroutine function.
        async_results = await gather_limited_async(
            FuncTask(async_func, ("api_call", 10), {"multiplier": 4}),  # -> "api_call: 40"
            FuncTask(async_func, ("data_fetch", 5)),  # -> "data_fetch: 10" (default multiplier)
            labeler=custom_labeler,
            retry_settings=NO_RETRIES,
        )
        assert async_results == ["api_call: 40", "data_fetch: 10"]
        assert captured_labels == ["Processing api_call", "Processing data_fetch"]

    asyncio.run(run_test())
|
|
933
|
+
|
|
934
|
+
|
|
935
|
+
def test_gather_limited_sync_cooperative_cancellation():
    """Test gather_limited_sync with cooperative cancellation via threading.Event.

    Runs the same pair of cancellation-aware sync tasks twice:
    - first with the event clear, so the tasks run to completion;
    - then with the event already set, so the tasks return on their first check.
    """
    import asyncio
    import threading
    import time

    async def run_test():
        cancel_event = threading.Event()
        call_counts = {"task1": 0, "task2": 0}

        def cancellable_sync_func(task_name: str, work_duration: float) -> str:
            """Simulate work_duration seconds of work, polling cancel_event as it goes."""
            call_counts[task_name] += 1
            # monotonic() is unaffected by wall-clock adjustments, unlike time.time(),
            # so the simulated duration cannot be stretched or cut short by clock jumps.
            deadline = time.monotonic() + work_duration

            while time.monotonic() < deadline:
                if cancel_event.is_set():
                    return f"{task_name}: cancelled"
                time.sleep(0.01)  # Small sleep between cancellation checks

            return f"{task_name}: completed"

        # First run: the event is never set, so both tasks finish their short work.
        results = await gather_limited_sync(
            lambda: cancellable_sync_func("task1", 0.1),  # Short duration
            lambda: cancellable_sync_func("task2", 0.1),  # Short duration
            cancel_event=cancel_event,
            cancel_timeout=1.0,
            retry_settings=NO_RETRIES,
        )

        assert results == ["task1: completed", "task2: completed"]
        assert call_counts["task1"] == 1
        assert call_counts["task2"] == 1

        # Second run: set the event up front. Each task should return on its first
        # poll instead of working for about a second.
        cancel_event.set()

        results2 = await gather_limited_sync(
            lambda: cancellable_sync_func("task1", 1.0),  # Would take long if not cancelled
            lambda: cancellable_sync_func("task2", 1.0),  # Would take long if not cancelled
            cancel_event=cancel_event,
            cancel_timeout=1.0,
            retry_settings=NO_RETRIES,
        )

        assert results2 == ["task1: cancelled", "task2: cancelled"]
        # Each task was invoked exactly once more in the second run.
        assert call_counts["task1"] == 2
        assert call_counts["task2"] == 2

    asyncio.run(run_test())
|