kash-shell 0.3.23__py3-none-any.whl → 0.3.24__py3-none-any.whl
This diff shows the changes between publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- kash/actions/core/combine_docs.py +52 -0
- kash/actions/core/concat_docs.py +47 -0
- kash/commands/workspace/workspace_commands.py +2 -2
- kash/config/logger.py +3 -2
- kash/config/settings.py +8 -0
- kash/docs/markdown/topics/a2_installation.md +2 -2
- kash/embeddings/embeddings.py +1 -1
- kash/exec/action_exec.py +1 -1
- kash/exec/fetch_url_items.py +36 -8
- kash/llm_utils/llm_completion.py +1 -1
- kash/mcp/mcp_cli.py +2 -2
- kash/utils/api_utils/api_retries.py +84 -76
- kash/utils/api_utils/gather_limited.py +227 -89
- kash/utils/api_utils/http_utils.py +46 -0
- kash/utils/api_utils/progress_protocol.py +49 -56
- kash/utils/rich_custom/multitask_status.py +70 -21
- kash/utils/text_handling/markdown_utils.py +14 -3
- kash/web_content/web_extract.py +12 -8
- kash/web_content/web_fetch.py +289 -60
- kash/web_content/web_page_model.py +5 -0
- {kash_shell-0.3.23.dist-info → kash_shell-0.3.24.dist-info}/METADATA +5 -3
- {kash_shell-0.3.23.dist-info → kash_shell-0.3.24.dist-info}/RECORD +25 -22
- {kash_shell-0.3.23.dist-info → kash_shell-0.3.24.dist-info}/WHEEL +0 -0
- {kash_shell-0.3.23.dist-info → kash_shell-0.3.24.dist-info}/entry_points.txt +0 -0
- {kash_shell-0.3.23.dist-info → kash_shell-0.3.24.dist-info}/licenses/LICENSE +0 -0
@@ -4,9 +4,10 @@ import asyncio
 import inspect
 import logging
 import threading
+import time
 from collections.abc import Callable, Coroutine
 from dataclasses import dataclass, field
-from typing import Any, TypeAlias, TypeVar, cast, overload
+from typing import Any, Generic, TypeAlias, TypeVar, cast, overload
 
 from aiolimiter import AsyncLimiter
 
@@ -17,7 +18,6 @@ from kash.utils.api_utils.api_retries import (
     RetrySettings,
     calculate_backoff,
     extract_http_status_code,
-    is_http_status_retriable,
 )
 from kash.utils.api_utils.progress_protocol import Labeler, ProgressTracker, TaskState
 
@@ -25,10 +25,6 @@ T = TypeVar("T")
 
 log = logging.getLogger(__name__)
 
-DEFAULT_MAX_CONCURRENT: int = 5
-DEFAULT_MAX_RPS: float = 5.0
-DEFAULT_CANCEL_TIMEOUT: float = 1.0
-
 
 @dataclass(frozen=True)
 class Limit:
@@ -36,27 +32,48 @@ class Limit:
     Rate limiting configuration with max RPS and max concurrency.
     """
 
-    rps: float
-    concurrency: int
+    rps: float
+    concurrency: int
+
+
+DEFAULT_CANCEL_TIMEOUT: float = 1.0
 
 
 @dataclass(frozen=True)
-class FuncTask:
+class TaskResult(Generic[T]):
+    """
+    Optional wrapper for task results that can signal rate limiting behavior.
+    Use this to wrap results that should bypass rate limiting (e.g., cache hits).
+    """
+
+    value: T
+    disable_limits: bool = False
+
+
+@dataclass(frozen=True)
+class FuncTask(Generic[T]):
     """
     A task described as an unevaluated function with args and kwargs.
     This task format allows you to use args and kwargs in the Labeler.
     It also allows specifying a bucket for rate limiting.
+
+    For async functions: The function should return a coroutine that yields either `T` or `TaskResult[T]`.
+    For sync functions: The function should return either `T` or `TaskResult[T]` directly.
+
+    Using `TaskResult[T]` allows controlling rate limiting behavior (e.g., cache hits can bypass rate limits).
     """
 
-    func: Callable[..., Any]
+    func: Callable[..., Any]  # Keep as Any since it can be sync or async
     args: tuple[Any, ...] = ()
     kwargs: dict[str, Any] = field(default_factory=dict)
     bucket: str = "default"
 
 
 # Type aliases for coroutine and sync specifications, including unevaluated function specs
-CoroSpec: TypeAlias =
-
+CoroSpec: TypeAlias = (
+    Callable[[], Coroutine[None, None, T]] | Coroutine[None, None, T] | FuncTask[T]
+)
+SyncSpec: TypeAlias = Callable[[], T] | FuncTask[T]
 
 # Specific labeler types using the generic Labeler pattern
 CoroLabeler: TypeAlias = Labeler[CoroSpec[T]]
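Note on the `FuncTask` docstring above: it says args and kwargs are available to the labeler. As a rough, hypothetical sketch (not part of this diff), a labeler following the `labeler(index, spec) -> str` convention documented later in this module could read them like this:

```python
# Hypothetical example only: `url_labeler` is not in the package.
def url_labeler(index: int, spec) -> str:
    if isinstance(spec, FuncTask):
        return f"fetch {spec.args[0]}"  # first positional arg, e.g. a URL
    return f"task {index}"
```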
@@ -114,7 +131,7 @@ class RetryCounter:
 @overload
 async def gather_limited_async(
     *coro_specs: CoroSpec[T],
-
+    limit: Limit | None,
     bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = False,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
@@ -126,7 +143,7 @@ async def gather_limited_async(
 @overload
 async def gather_limited_async(
     *coro_specs: CoroSpec[T],
-
+    limit: Limit | None,
     bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = True,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
@@ -137,7 +154,7 @@ async def gather_limited_async(
 
 async def gather_limited_async(
     *coro_specs: CoroSpec[T],
-
+    limit: Limit | None,
     bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = True,  # Default to True for resilient batch operations
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
@@ -157,12 +174,22 @@ async def gather_limited_async(
     - Configurable handling of conditional status codes like 403 Forbidden
     - Defaults to return_exceptions=True for resilient batch operations
 
+    Prevents API flooding during rate limit backoffs:
+    - Semaphore slots are held during entire retry cycles, including backoff periods
+    - New tasks cannot start while existing tasks are backing off from rate limits
+    - Rate limiters are only acquired during actual execution attempts
+
     Can optionally display live progress with retry indicators using TaskStatus.
 
     Accepts:
     - Callables that return coroutines: `lambda: some_async_func(arg)` (recommended for retries)
     - Coroutines directly: `some_async_func(arg)` (only if retries disabled)
-    -
+    - FuncTask objects: `FuncTask(some_async_func, (arg1, arg2), {"kwarg": value})` (args accessible to labeler)
+
+    Functions can return `TaskResult[T]` to bypass rate limiting for specific results:
+    - `TaskResult(value, disable_limits=True)`: bypass rate limiting (e.g., cache hits)
+    - `TaskResult(value, disable_limits=False)`: apply normal rate limiting
+    - `value` directly: apply normal rate limiting
 
     Examples:
     ```python
@@ -190,18 +217,34 @@ async def gather_limited_async(
         FuncTask(fetch_api, ("data1",), bucket="api1"),
         FuncTask(fetch_api, ("data2",), bucket="api2"),
         FuncTask(fetch_api, ("data3",), bucket="api3"),
-
+        limit=Limit(rps=100, concurrency=50),
         bucket_limits={
             "api1": Limit(rps=20, concurrency=10),  # Specific limit for api1
             "*": Limit(rps=15, concurrency=8),  # Fallback for api2, api3, and others
         }
     )
 
+    # With cache bypass using TaskResult:
+    async def fetch_with_cache(url: str) -> TaskResult[dict] | dict:
+        cached_data = await cache.get(url)
+        if cached_data:
+            return TaskResult(cached_data, disable_limits=True)  # Cache hit: no rate limit
+
+        data = await fetch_api(url)  # Will be rate limited
+        await cache.set(url, data)
+        return data
+
+    await gather_limited_async(
+        lambda: fetch_with_cache("http://api.com/data1"),
+        lambda: fetch_with_cache("http://api.com/data2"),
+        limit=Limit(rps=10, concurrency=5)  # Cache hits bypass these limits
+    )
+
     ```
 
     Args:
         *coro_specs: Callables or coroutines to execute
-
+        limit: Global limits applied to all tasks regardless of bucket
         bucket_limits: Optional per-bucket limits. Tasks use their bucket field to determine limits.
             Use "*" as a fallback limit for buckets without specific limits.
         return_exceptions: If True, exceptions are returned as results
@@ -217,8 +260,8 @@ async def gather_limited_async(
     """
     log.info(
         "Executing with global limits: concurrency %s at %s rps, %s",
-
-
+        limit.concurrency if limit else "None",
+        limit.rps if limit else "None",
         retry_settings,
     )
     if not coro_specs:
@@ -237,8 +280,8 @@ async def gather_limited_async(
     )
 
     # Global limits (apply to all tasks regardless of bucket)
-    global_semaphore = asyncio.Semaphore(
-    global_rate_limiter = AsyncLimiter(
+    global_semaphore = asyncio.Semaphore(limit.concurrency) if limit else None
+    global_rate_limiter = AsyncLimiter(limit.rps, 1.0) if limit else None
 
     # Per-bucket limits (if bucket_limits provided)
     bucket_semaphores: dict[str, asyncio.Semaphore] = {}
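For context on the construction above: `AsyncLimiter(limit.rps, 1.0)` uses aiolimiter's `AsyncLimiter(max_rate, time_period)` signature, so it allows roughly `rps` acquisitions per one-second window. A standalone sketch of that behavior (illustrative only, not code from the package):

```python
import asyncio

from aiolimiter import AsyncLimiter


async def demo() -> None:
    limiter = AsyncLimiter(5.0, 1.0)  # at most ~5 acquisitions per second
    for i in range(12):
        async with limiter:  # blocks once the 1-second window is exhausted
            print("request", i)


asyncio.run(demo())
```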
@@ -275,7 +318,7 @@ async def gather_limited_async(
         elif callable(coro_spec):
             coro = coro_spec()
         else:
-            # Direct coroutine
+            # Direct coroutine: only valid if retries disabled
             coro = coro_spec
         return await coro
 
@@ -313,7 +356,7 @@ async def gather_limited_async(
 @overload
 async def gather_limited_sync(
     *sync_specs: SyncSpec[T],
-
+    limit: Limit | None,
     bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = False,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
@@ -327,7 +370,7 @@ async def gather_limited_sync(
 @overload
 async def gather_limited_sync(
     *sync_specs: SyncSpec[T],
-
+    limit: Limit | None,
     bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = True,
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
@@ -340,7 +383,7 @@ async def gather_limited_sync(
 
 async def gather_limited_sync(
     *sync_specs: SyncSpec[T],
-
+    limit: Limit | None,
     bucket_limits: dict[str, Limit] | None = None,
     return_exceptions: bool = True,  # Default to True for resilient batch operations
     retry_settings: RetrySettings | None = DEFAULT_RETRIES,
@@ -350,42 +393,38 @@ async def gather_limited_sync(
     cancel_timeout: float = DEFAULT_CANCEL_TIMEOUT,
 ) -> list[T] | list[T | BaseException]:
     """
-
-
+    Sync version of `gather_limited_async()` that executes synchronous functions with the same
+    rate limiting, retry logic, and progress tracking capabilities.
+    See `gather_limited_async()` for details.
 
-
-    - Per-task retries: max_task_retries attempts per individual task
-    - Global retries: max_total_retries attempts across all tasks
+    Sync-specific differences:
 
-
-
-
-
+    **Function Execution**: Runs sync functions via `asyncio.to_thread()` instead of executing
+    coroutines directly. Validates that callables don't accidentally return coroutines.
+
+    **Cancellation Support**: Provides cooperative cancellation for long-running sync functions
+    through optional `cancel_event` and graceful thread termination.
 
-
+    **Input Validation**: Ensures sync functions don't return coroutines, which would indicate
+    incorrect usage (async functions should use `gather_limited_async()`).
 
     Args:
-        *sync_specs: Callables that return values (not coroutines) or FuncTask objects
-
-        bucket_limits: Optional per-bucket limits. Tasks use their bucket field to determine limits.
-            Use "*" as a fallback limit for buckets without specific limits.
-        return_exceptions: If True, exceptions are returned as results
-        retry_settings: Configuration for retry behavior, or None to disable retries
-        status: Optional ProgressTracker instance for progress display
-        labeler: Optional function to generate labels: labeler(index, spec) -> str
+        *sync_specs: Callables that return values (not coroutines) or FuncTask objects.
+            Functions can return `TaskResult[T]` to control rate limiting behavior.
         cancel_event: Optional threading.Event that will be set on cancellation
        cancel_timeout: Max seconds to wait for threads to terminate on cancellation
+        (All other args identical to gather_limited_async())
 
     Returns:
         List of results in the same order as input specifications
 
     Example:
     ```python
-    #
+    # Basic usage with sync functions
     results = await gather_limited_sync(
         lambda: some_sync_function(arg1),
         lambda: another_sync_function(arg2),
-
+        limit=Limit(rps=2.0, concurrency=3),
         retry_settings=RetrySettings(max_task_retries=3, max_total_retries=25)
     )
 
@@ -394,18 +433,34 @@ async def gather_limited_sync(
         FuncTask(fetch_from_api, ("data1",), bucket="api1"),
         FuncTask(fetch_from_api, ("data2",), bucket="api2"),
         FuncTask(fetch_from_api, ("data3",), bucket="api3"),
-
+        limit=Limit(rps=100, concurrency=50),
         bucket_limits={
             "api1": Limit(rps=20, concurrency=10),  # Specific limit for api1
             "*": Limit(rps=15, concurrency=8),  # Fallback for api2, api3, and others
         }
     )
+
+    # With cache bypass using TaskResult
+    def sync_fetch_with_cache(url: str) -> TaskResult[dict] | dict:
+        cached_data = cache.get_sync(url)
+        if cached_data:
+            return TaskResult(cached_data, disable_limits=True)  # Cache hit: no rate limit
+
+        data = requests.get(url).json()  # Will be rate limited
+        cache.set_sync(url, data)
+        return data
+
+    results = await gather_limited_sync(
+        lambda: sync_fetch_with_cache("http://api.com/data1"),
+        lambda: sync_fetch_with_cache("http://api.com/data2"),
+        limit=Limit(rps=5, concurrency=3)  # Cache hits bypass these limits
+    )
     ```
     """
     log.info(
         "Executing with global limits: concurrency %s at %s rps, %s",
-
-
+        limit.concurrency if limit else "None",
+        limit.rps if limit else "None",
        retry_settings,
     )
     if not sync_specs:
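The `gather_limited_sync()` docstring above describes running each sync callable via `asyncio.to_thread()` with cooperative cancellation through an optional `cancel_event`. A minimal sketch of that execution model, with hypothetical function names (this is not the package's internal code):

```python
import asyncio
import threading


def slow_task(cancel_event: threading.Event) -> str:
    # Blocking work that periodically checks whether it should stop.
    for _ in range(10):
        if cancel_event.is_set():
            return "cancelled"
        # ... one slice of blocking work here ...
    return "done"


async def run_one() -> str:
    cancel_event = threading.Event()
    # The worker thread keeps the event loop free; setting the event asks it to stop.
    return await asyncio.to_thread(slow_task, cancel_event)
```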
@@ -414,8 +469,8 @@ async def gather_limited_sync(
     retry_settings = retry_settings or NO_RETRIES
 
     # Global limits (apply to all tasks regardless of bucket)
-    global_semaphore = asyncio.Semaphore(
-    global_rate_limiter = AsyncLimiter(
+    global_semaphore = asyncio.Semaphore(limit.concurrency) if limit else None
+    global_rate_limiter = AsyncLimiter(limit.rps, 1.0) if limit else None
 
     # Per-bucket limits (if bucket_limits provided)
     bucket_semaphores: dict[str, asyncio.Semaphore] = {}
@@ -572,34 +627,94 @@ async def _gather_with_interrupt_handling(
 async def _execute_with_retry(
     executor: Callable[[], Coroutine[None, None, T]],
     retry_settings: RetrySettings,
-    global_semaphore: asyncio.Semaphore,
-    global_rate_limiter: AsyncLimiter,
+    global_semaphore: asyncio.Semaphore | None,
+    global_rate_limiter: AsyncLimiter | None,
     bucket_semaphore: asyncio.Semaphore | None,
     bucket_rate_limiter: AsyncLimiter | None,
     global_retry_counter: RetryCounter,
     status: ProgressTracker | None = None,
     task_id: Any | None = None,
 ) -> T:
-
+    """
+    Execute a task with retry logic, holding semaphores for the entire retry cycle.
 
-
-
+    Semaphores are held during backoff periods to prevent API flooding when tasks
+    are already backing off from rate limits. Only rate limiters are acquired/released
+    for each execution attempt.
+    """
+    # Acquire semaphores once for the entire retry cycle (including backoff periods)
+    if global_semaphore is not None:
+        async with global_semaphore:
+            if bucket_semaphore is not None:
+                async with bucket_semaphore:
+                    return await _execute_with_retry_inner(
+                        executor,
+                        retry_settings,
+                        global_rate_limiter,
+                        bucket_rate_limiter,
+                        global_retry_counter,
+                        status,
+                        task_id,
+                    )
+            else:
+                return await _execute_with_retry_inner(
+                    executor,
+                    retry_settings,
+                    global_rate_limiter,
+                    None,
+                    global_retry_counter,
+                    status,
+                    task_id,
+                )
+    else:
+        # No global semaphore, check bucket semaphore
+        if bucket_semaphore is not None:
+            async with bucket_semaphore:
+                return await _execute_with_retry_inner(
+                    executor,
+                    retry_settings,
+                    global_rate_limiter,
+                    bucket_rate_limiter,
+                    global_retry_counter,
+                    status,
+                    task_id,
+                )
+        else:
+            # No semaphores at all, go straight to running
+            return await _execute_with_retry_inner(
+                executor,
+                retry_settings,
+                global_rate_limiter,
+                None,
+                global_retry_counter,
+                status,
+                task_id,
+            )
 
-    # Create HTTP-aware is_retriable function based on settings
-    def is_retriable_for_task(exception: Exception) -> bool:
-        # First check the main is_retriable function
-        if not retry_settings.is_retriable(exception):
-            return False
 
-
-
-
+async def _execute_with_retry_inner(
+    executor: Callable[[], Coroutine[None, None, T]],
+    retry_settings: RetrySettings,
+    global_rate_limiter: AsyncLimiter | None,
+    bucket_rate_limiter: AsyncLimiter | None,
+    global_retry_counter: RetryCounter,
+    status: ProgressTracker | None = None,
+    task_id: Any | None = None,
+) -> T:
+    """
+    Inner retry logic that handles rate limiting and backoff.
+
+    This function assumes semaphores are already held by the caller and only
+    manages rate limiters for each execution attempt.
+    """
+    if status and task_id:
+        await status.start(task_id)
 
-
-
+    start_time = time.time()
+    last_exception: Exception | None = None
 
     for attempt in range(retry_settings.max_task_retries + 1):
-        # Handle backoff before acquiring
+        # Handle backoff before acquiring rate limiters (semaphores remain held)
         if attempt > 0 and last_exception:
             # Try to increment global retry counter
             if not await global_retry_counter.try_increment():
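The nested `async with` blocks added above simply acquire whichever of the two semaphores exist, global first and then bucket, and hold both while the inner retry loop runs. An equivalent way to express that acquisition order, sketched here with `contextlib.AsyncExitStack` (not how the package actually writes it):

```python
import contextlib


async def _acquire_and_run(global_semaphore, bucket_semaphore, run_inner):
    # Acquire only the semaphores that are configured, in global-then-bucket
    # order, and hold them for the entire retry cycle (including backoffs).
    async with contextlib.AsyncExitStack() as stack:
        if global_semaphore is not None:
            await stack.enter_async_context(global_semaphore)
        if bucket_semaphore is not None:
            await stack.enter_async_context(bucket_semaphore)
        return await run_inner()
```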
@@ -624,7 +739,7 @@ async def _execute_with_retry(
                     f"Attempt {attempt}/{retry_settings.max_task_retries} "
                     f"(waiting {backoff_time:.1f}s): {type(last_exception).__name__}: {last_exception}"
                 )
-                await status.update(task_id, error_msg=retry_info)
+                await status.update(task_id, TaskState.WAITING, error_msg=retry_info)
 
             # Use debug level for Rich trackers, warning/info for console trackers
             use_debug_level = status.suppress_logs
@@ -651,22 +766,46 @@ async def _execute_with_retry(
             else:
                 log.warning(rate_limit_msg)
                 log.info(exception_msg)
+
+            # Sleep during backoff while holding semaphore slots
             await asyncio.sleep(backoff_time)
 
+        # Set state back to running before next attempt
+        if status and task_id:
+            await status.update(task_id, TaskState.RUNNING)
+
         try:
-            #
-
-
-
-
-
-
-
+            # Execute task to check for potential rate limit bypass
+            raw_result = await executor()
+
+            # Check if result indicates limits should be disabled (e.g., cache hit)
+            if isinstance(raw_result, TaskResult):
+                if raw_result.disable_limits:
+                    # Bypass rate limiting and return immediately
+                    return cast(T, raw_result.value)
                 else:
-                    #
-
-
-
+                    # Wrapped but limits enabled: extract value for rate limiting
+                    result_value = cast(T, raw_result.value)
+            else:
+                # Unwrapped result: apply normal rate limiting
+                result_value = cast(T, raw_result)
+
+            # Apply rate limiting for non-bypassed results
+            if global_rate_limiter is not None:
+                async with global_rate_limiter:
+                    if bucket_rate_limiter is not None:
+                        async with bucket_rate_limiter:
+                            return result_value
+                    else:
+                        return result_value
+            else:
+                # No global rate limiter, check bucket rate limiter
+                if bucket_rate_limiter is not None:
+                    async with bucket_rate_limiter:
+                        return result_value
+                else:
+                    # No rate limiting at all
+                    return result_value
 
         except Exception as e:
             last_exception = e  # Always store the exception
@@ -674,10 +813,10 @@ async def _execute_with_retry(
             if attempt == retry_settings.max_task_retries:
                 # Final attempt failed
                 if retry_settings.max_task_retries == 0:
-                    # No retries configured
+                    # No retries configured: raise original exception directly
                     raise
                 else:
-                    # Retries were attempted but exhausted
+                    # Retries were attempted but exhausted: wrap with context
                     total_time = time.time() - start_time
                     log.error(
                         f"Max task retries ({retry_settings.max_task_retries}) exhausted after {total_time:.1f}s. "
@@ -685,19 +824,18 @@ async def _execute_with_retry(
                     )
                     raise RetryExhaustedException(e, retry_settings.max_task_retries, total_time)
 
-            # Check if this is a retriable exception using
-            if
-                # Continue to next retry attempt (
+            # Check if this is a retriable exception using the centralized logic
+            if retry_settings.should_retry(e):
+                # Continue to next retry attempt (semaphores remain held for backoff)
                 continue
             else:
                 # Non-retriable exception, log and re-raise immediately
                 status_code = extract_http_status_code(e)
                 status_info = f" (HTTP {status_code})" if status_code else ""
 
-                log.warning(
-
-                )
+                log.warning("Non-retriable exception%s (not retrying): %s", status_info, e)
+                log.debug("Exception traceback:", exc_info=True)
                 raise
 
     # This should never be reached, but satisfy type checker
-    raise RuntimeError("Unexpected code path in
+    raise RuntimeError("Unexpected code path in _execute_with_retry_inner")
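Taken together, the changes in this file drop the module-level rate defaults and make `limit` an explicit keyword-only parameter of type `Limit | None` on both gather functions; passing `None` creates no global semaphore or rate limiter. A hedged calling sketch (`fetch_one` is a placeholder async callable, not part of kash):

```python
# Explicit global limit:
results = await gather_limited_async(
    lambda: fetch_one("a"),
    lambda: fetch_one("b"),
    limit=Limit(rps=5.0, concurrency=5),
)

# No global limit; per-bucket limits (if any) still apply:
more = await gather_limited_async(
    lambda: fetch_one("c"),
    limit=None,
)
```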
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+
+def extract_http_status_code(exception: Exception) -> int | None:
+    """
+    Extract HTTP status code from various exception types.
+
+    Args:
+        exception: The exception to extract status code from
+
+    Returns:
+        HTTP status code or None if not found
+    """
+    # Check for httpx.HTTPStatusError and requests.HTTPError
+    if hasattr(exception, "response"):
+        response = getattr(exception, "response", None)
+        if response and hasattr(response, "status_code"):
+            return getattr(response, "status_code", None)
+
+    # Check for aiohttp errors
+    if hasattr(exception, "status"):
+        return getattr(exception, "status", None)
+
+    # Parse from exception message as fallback
+    exception_str = str(exception)
+
+    # Try to find status code patterns in the message
+    import re
+
+    # Pattern for "403 Forbidden", "HTTP 429", etc.
+    status_patterns = [
+        r"\b(\d{3})\s+(?:Forbidden|Unauthorized|Not Found|Too Many Requests|Internal Server Error|Bad Gateway|Service Unavailable|Gateway Timeout)\b",
+        r"\bHTTP\s+(\d{3})\b",
+        r"\b(\d{3})\s+error\b",
+        r"status\s*(?:code)?:\s*(\d{3})\b",
+    ]
+
+    for pattern in status_patterns:
+        match = re.search(pattern, exception_str, re.IGNORECASE)
+        if match:
+            try:
+                return int(match.group(1))
+            except (ValueError, IndexError):
+                continue
+
+    return None