model-library 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +141 -62
- model_library/base/delegate_only.py +77 -10
- model_library/base/output.py +43 -0
- model_library/base/utils.py +35 -0
- model_library/config/alibaba_models.yaml +49 -57
- model_library/config/all_models.json +353 -120
- model_library/config/anthropic_models.yaml +2 -1
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/mistral_models.yaml +2 -0
- model_library/config/openai_models.yaml +15 -23
- model_library/config/together_models.yaml +2 -0
- model_library/config/xiaomi_models.yaml +43 -0
- model_library/config/zai_models.yaml +27 -3
- model_library/exceptions.py +3 -77
- model_library/providers/ai21labs.py +12 -8
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +49 -16
- model_library/providers/anthropic.py +128 -48
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +14 -10
- model_library/providers/google/google.py +57 -30
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +15 -17
- model_library/providers/mistral.py +20 -8
- model_library/providers/openai.py +99 -22
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +12 -6
- model_library/providers/vercel.py +34 -0
- model_library/providers/xai.py +47 -42
- model_library/providers/xiaomi.py +34 -0
- model_library/providers/zai.py +38 -8
- model_library/register_models.py +5 -0
- model_library/registry_utils.py +48 -17
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +17 -7
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/METADATA +2 -1
- model_library-0.1.9.dist-info/RECORD +73 -0
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/WHEEL +1 -1
- model_library-0.1.7.dist-info/RECORD +0 -64
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/top_level.txt +0 -0
model_library/retriers/base.py
ADDED

@@ -0,0 +1,225 @@
+import asyncio
+import logging
+import time
+from abc import ABC, abstractmethod
+from typing import Any, Awaitable, Callable, Literal, TypeVar
+
+from model_library.base.base import QueryResult
+from model_library.exceptions import (
+    ImmediateRetryException,
+    MaxContextWindowExceededError,
+    exception_message,
+    is_context_window_error,
+    is_retriable_error,
+)
+
+RetrierType = Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]
+
+R = TypeVar("R")  # wrapper return type
+
+
+class BaseRetrier(ABC):
+    """
+    Base class for retry strategies.
+    Implements core retry logic and error handling.
+    Subclasses should implement strategy-specific wait time calculations.
+    """
+
+    # NOTE: for token retrier, the estimate_tokens stays the same because ImmediateRetryException
+    # is raised for network errors, where tokens have not been deducted yet
+
+    @staticmethod
+    async def immediate_retry_wrapper(
+        func: Callable[[], Awaitable[R]],
+        logger: logging.Logger,
+    ) -> R:
+        """
+        Retry the query immediately
+        """
+        MAX_IMMEDIATE_RETRIES = 10
+        retries = 0
+        while True:
+            try:
+                return await func()
+            except ImmediateRetryException as e:
+                if retries >= MAX_IMMEDIATE_RETRIES:
+                    raise Exception(
+                        f"[Immediate Retry Max] | {retries}/{MAX_IMMEDIATE_RETRIES} | Exception {exception_message(e)}"
+                    ) from e
+                retries += 1
+
+                logger.warning(
+                    f"[Immediate Retry] | {retries}/{MAX_IMMEDIATE_RETRIES} | Exception {exception_message(e)}"
+                )
+
+    def __init__(
+        self,
+        strategy: Literal["backoff", "token"],
+        logger: logging.Logger,
+        max_tries: int | None,
+        max_time: float | None,
+        retry_callback: Callable[[int, Exception | None, float, float], None] | None,
+    ):
+        self.strategy = strategy
+        self.logger = logger
+        self.max_tries = max_tries
+        self.max_time = max_time
+        self.retry_callback = retry_callback
+
+        self.attempts = 0
+        self.start_time: float | None = None
+
+    @abstractmethod
+    async def _calculate_wait_time(
+        self, attempt: int, exception: Exception | None = None
+    ) -> float:
+        """
+        Calculate wait time before retrying
+
+        Args:
+            attempt: Current attempt number (0-indexed)
+            exception: The exception that triggered the retry
+
+        Returns:
+            Wait time in seconds
+        """
+        ...
+
+    @abstractmethod
+    async def _on_retry(
+        self, exception: Exception | None, elapsed: float, wait_time: float
+    ) -> None:
+        """
+        Hook called before waiting on retry
+
+        Args:
+            exception: The exception that triggered the retry
+            elapsed: Time elapsed since start
+            wait_time: Wait time
+        """
+        ...
+
+    def _should_retry(self, exception: Exception) -> bool:
+        """
+        Determine if an exception should trigger a retry.
+
+        Args:
+            exception: The exception to evaluate
+
+        Returns:
+            True if should retry, False otherwise
+        """
+
+        if is_context_window_error(exception):
+            return False
+        return is_retriable_error(exception)
+
+    def _handle_giveup(self, exception: Exception, reason: str) -> None:
+        """
+        Handle final exception after all retries exhausted.
+
+        Args:
+            exception: The final exception
+            reason: Reason for giving up
+
+        Raises:
+            MaxContextWindowExceededError: If context window error
+            Exception: The original exception otherwise
+        """
+
+        self.logger.error(
+            f"[Give up] | {self.strategy} | {reason} | Exception: {exception_message(exception)}"
+        )
+
+        # instead of raising the provider exception, raise the custom MaxContextWindowExceededError
+        if is_context_window_error(exception):
+            message = exception.args[0] if exception.args else str(exception)
+            raise MaxContextWindowExceededError(message)
+
+        raise exception
+
+    @abstractmethod
+    async def _pre_function(self) -> None:
+        """Hook called before the actual function call"""
+        ...
+
+    @abstractmethod
+    async def _post_function(self, result: tuple[QueryResult, float]) -> None:
+        """Hook called after the actual function call"""
+        ...
+
+    async def execute(
+        self,
+        func: Callable[..., Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        """
+        Execute function with retry logic.
+
+        Args:
+            func: Async function to execute
+            *args: Positional arguments for func
+            **kwargs: Keyword arguments for func
+
+        Returns:
+            Result from func
+
+        Raises:
+            Exception: If retries exhausted or non-retriable error occurs
+        """
+
+        self.attempts = 0
+        self.start_time = time.time()
+
+        await self.validate()
+
+        while True:
+            try:
+                await self._pre_function()
+                result = await func(*args, **kwargs)
+                await self._post_function(result)
+                return result
+
+            except Exception as e:
+                elapsed = time.time() - self.start_time
+
+                self.attempts += 1
+
+                # check if max_tries exceeded
+                if self.max_tries is not None and self.attempts >= self.max_tries:
+                    self._handle_giveup(
+                        e, f"max_tries exceeded ({self.attempts} >= {self.max_tries})"
+                    )
+
+                # check if max_time exceeded
+                if self.max_time is not None and elapsed > self.max_time:
+                    self._handle_giveup(
+                        e, f"max_time exceeded ({elapsed} > {self.max_time}s)"
+                    )
+
+                if not self._should_retry(e):
+                    self._handle_giveup(e, "not retriable")
+
+                # calculate wait time
+                wait_time = await self._calculate_wait_time(self.attempts, e)
+                # call pre retry sleep hook
+                await self._on_retry(e, elapsed, wait_time)
+
+                await asyncio.sleep(wait_time)
+
+    async def validate(self) -> None:
+        """Validate the retrier"""
+        ...
+
+
+def retry_decorator(retrier: BaseRetrier) -> RetrierType:
+    """Create a retry decorator from an initialized retrier"""
+
+    def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
+        async def wrapper(*args: Any, **kwargs: Any) -> Any:
+            return await retrier.execute(func, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
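As a reading aid (not part of the package), here is a minimal sketch of what the `BaseRetrier` contract asks of a subclass. The `FixedWaitRetrier` name, its flat five-second wait, and `flaky_call` are invented for this illustration; `BaseRetrier` and `retry_decorator` are the symbols defined in the new file above.

```python
import logging
from typing import Any

from model_library.base.base import QueryResult
from model_library.retriers.base import BaseRetrier, retry_decorator


class FixedWaitRetrier(BaseRetrier):
    """Hypothetical subclass: always waits a flat 5 seconds between attempts."""

    def __init__(self, logger: logging.Logger):
        super().__init__(
            strategy="backoff",  # one of the two Literal values accepted above
            logger=logger,
            max_tries=3,
            max_time=None,
            retry_callback=None,
        )

    async def _calculate_wait_time(
        self, attempt: int, exception: Exception | None = None
    ) -> float:
        return 5.0  # flat wait; a real strategy would scale with `attempt`

    async def _on_retry(
        self, exception: Exception | None, elapsed: float, wait_time: float
    ) -> None:
        self.logger.warning(f"retry after {wait_time:.1f}s (elapsed {elapsed:.1f}s)")

    async def _pre_function(self) -> None:
        ...  # no setup needed for this strategy

    async def _post_function(self, result: tuple[QueryResult, float]) -> None:
        ...  # no accounting needed for this strategy


# The decorator form simply routes every call through retrier.execute():
@retry_decorator(FixedWaitRetrier(logging.getLogger(__name__)))
async def flaky_call() -> Any:
    ...  # hypothetical provider call
```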
model_library/retriers/token.py
ADDED

@@ -0,0 +1,427 @@
+import asyncio
+import logging
+import time
+import uuid
+from asyncio.tasks import Task
+from math import ceil, floor
+from typing import Any, Callable, Coroutine
+
+from redis.asyncio import Redis
+
+from model_library.base.base import QueryResult, RateLimit
+from model_library.exceptions import exception_message
+from model_library.retriers.base import BaseRetrier
+from model_library.retriers.utils import jitter
+
+RETRY_WAIT_TIME: float = 20.0
+TOKEN_WAIT_TIME: float = 5.0
+
+MAX_PRIORITY: int = 1
+MIN_PRIORITY: int = 5
+
+MAX_RETRIES: int = 10
+
+redis_client: Redis
+LOCK_TIMEOUT: int = 10  # using 10 in case there is high compute load, don't want to error on lock releases
+
+
+def set_redis_client(client: Redis):
+    global redis_client
+    redis_client = client
+
+
+refill_tasks: dict[str, tuple[dict[str, int | Any], Task[None]]] = {}
+
+
+class TokenRetrier(BaseRetrier):
+    """
+    Token-based retry strategy
+    Predicts the number of tokens required for a query, sends requests to respect the rate limit,
+    then adjusts the estimate based on actual usage.
+    """
+
+    @staticmethod
+    def get_token_key(client_registry_key: tuple[str, str]) -> str:
+        """Get the key which stores remaining tokens"""
+        return f"{client_registry_key[0]}:{client_registry_key[1]}:tokens"
+
+    @staticmethod
+    def get_priority_key(client_registry_key: tuple[str, str], priority: int) -> str:
+        """Get the key which stores the number of tasks waiting at a given priority"""
+        return f"{client_registry_key[0]}:{client_registry_key[1]}:priority:{priority}"
+
+    @staticmethod
+    async def init_remaining_tokens(
+        client_registry_key: tuple[str, str],
+        limit: int,
+        limit_refresh_seconds: int,
+        logger: logging.Logger,
+        get_rate_limit_func: Callable[[], Coroutine[Any, Any, RateLimit | None]],
+    ) -> None:
+        """
+        Initialize remaining tokens in storage and start background refill process
+        """
+
+        async def _header_correction_loop(
+            key: str,
+            limit: int,
+            tokens_per_second: int,
+            get_rate_limit_func: Callable[[], Coroutine[Any, Any, RateLimit | None]],
+            version: str,
+        ) -> None:
+            """
+            Background loop that corrects tokens based on provider headers
+            Every 5 seconds
+            """
+            interval = 5.0
+
+            assert redis_client
+            while True:
+                await asyncio.sleep(interval)
+                current_version = await redis_client.get("version:" + key)
+                if current_version != version:
+                    logger.debug(
+                        f"version changed ({current_version} != {version}), exiting _header_correction_loop for {key}"
+                    )
+                    return
+
+                rate_limit = await get_rate_limit_func()
+                if rate_limit is None:
+                    # kill the task as no headers are provided
+                    logger.debug(
+                        f"no rate limit headers, exiting _header_correction_loop for {key}"
+                    )
+                    return
+
+                tokens_remaining = rate_limit.token_remaining_total
+
+                async with redis_client.lock(key + ":lock", timeout=LOCK_TIMEOUT):
+                    current = int(await redis_client.get(key))
+
+                    # increment
+                    elapsed = time.time() - rate_limit.unix_timestamp
+                    adjusted = floor(tokens_remaining + (tokens_per_second * elapsed))
+
+                    # if the headers show a lower value, correct with that
+                    if adjusted < current:
+                        await redis_client.set(key, adjusted)
+                        logger.info(
+                            f"Corrected {key} from {current} to {adjusted} based on headers ({elapsed:.1f}s old)"
+                        )
+                    else:
+                        logger.debug(
+                            f"Not correcting {key} from {current} to {adjusted} based on headers ({elapsed:.1f}s old) (higher value)"
+                        )
+
+        async def _token_refill_loop(
+            key: str,
+            limit: int,
+            tokens_per_second: int,
+            version: str,
+        ) -> None:
+            """
+            Background loop that refills tokens
+            Every second
+            """
+            interval: float = 1.0
+
+            assert redis_client
+            while True:
+                await asyncio.sleep(interval)
+                current_version = await redis_client.get("version:" + key)
+                if current_version != version:
+                    logger.debug(
+                        f"version changed ({current_version} != {version}), exiting _token_refill_loop for {key}"
+                    )
+                    return
+
+                async with redis_client.lock(key + ":lock", timeout=LOCK_TIMEOUT):
+                    # increment
+                    current = await redis_client.incrby(key, tokens_per_second)
+                    logger.debug(
+                        f"[Token Refill] | {key} | Amount: {tokens_per_second} | Current: {current}"
+                    )
+                    # cap at limit
+                    if current > limit:
+                        logger.debug(f"[Token Cap] | {key} | Limit: {limit}")
+                        await redis_client.set(key, limit)
+
+        key = TokenRetrier.get_token_key(client_registry_key)
+
+        # limit_key is only used to check if the limit has changed
+        limit_key = f"{key}:limit"
+
+        async with redis_client.lock("init:" + key + ":lock", timeout=LOCK_TIMEOUT):
+            old_limit = int(await redis_client.get(limit_key) or 0)
+
+            # keep track of version so we can clean up old tasks
+            # even if the limit has not changed, reset background tasks just in case
+            version = str(uuid.uuid4())
+            await redis_client.set("version:" + key, version)
+
+            if old_limit != limit or not await redis_client.exists(key):
+                # if the new limit is different, set it
+                await redis_client.set(key, limit)
+                await redis_client.set(limit_key, limit)
+
+        tokens_per_second = floor(limit / limit_refresh_seconds)
+
+        refill_task = asyncio.create_task(
+            _token_refill_loop(key, limit, tokens_per_second, version)
+        )
+        correction_task = asyncio.create_task(
+            _header_correction_loop(
+                key, limit, tokens_per_second, get_rate_limit_func, version
+            )
+        )
+
+        refill_tasks["refill:" + key] = (
+            {
+                "limit": limit,
+                "limit_refresh_seconds": limit_refresh_seconds,
+            },
+            refill_task,
+        )
+        refill_tasks["correction:" + key] = (
+            {
+                "limit": limit,
+                "limit_refresh_seconds": limit_refresh_seconds,
+                "get_rate_limit_func": get_rate_limit_func,
+            },
+            correction_task,
+        )
+
+    async def _get_remaining_tokens(self) -> int:
+        """Get remaining tokens"""
+        tokens = await redis_client.get(self.token_key)
+        return int(tokens)
+
+    async def _deduct_remaining_tokens(self) -> None:
+        """Deduct from remaining tokens"""
+        # NOTE: decrby is atomic
+        await redis_client.decrby(self.token_key, self.actual_estimate_total_tokens)
+
+    def __init__(
+        self,
+        logger: logging.Logger,
+        max_tries: int | None = MAX_RETRIES,
+        max_time: float | None = None,
+        retry_callback: Callable[[int, Exception | None, float, float], None]
+        | None = None,
+        *,
+        client_registry_key: tuple[str, str],
+        estimate_input_tokens: int,
+        estimate_output_tokens: int,
+        dynamic_estimate_instance_id: str | None = None,
+        retry_wait_time: float = RETRY_WAIT_TIME,
+        token_wait_time: float = TOKEN_WAIT_TIME,
+    ):
+        super().__init__(
+            strategy="token",
+            logger=logger,
+            max_tries=max_tries,
+            max_time=max_time,
+            retry_callback=retry_callback,
+        )
+
+        self.client_registry_key = client_registry_key
+
+        self.estimate_input_tokens = estimate_input_tokens
+        self.estimate_output_tokens = estimate_output_tokens
+        self.estimate_total_tokens = estimate_input_tokens + estimate_output_tokens
+        self.actual_estimate_total_tokens = (
+            self.estimate_total_tokens
+        )  # when multiplying base estimate_total_tokens by ratio
+
+        self.retry_wait_time = retry_wait_time
+        self.token_wait_time = token_wait_time
+
+        self.priority = MAX_PRIORITY
+
+        self.token_key = TokenRetrier.get_token_key(client_registry_key)
+        self._token_key_lock = self.token_key + ":lock"
+        self._init_key_lock = "init:" + self.token_key + ":lock"
+
+        self.dynamic_estimate_key = (
+            f"{self.token_key}:dynamic_estimate:{dynamic_estimate_instance_id}"
+            if dynamic_estimate_instance_id
+            else None
+        )
+
+    async def _calculate_wait_time(
+        self, attempt: int, exception: Exception | None = None
+    ) -> float:
+        """Wait time between retries"""
+        return jitter(self.retry_wait_time)
+
+    async def _on_retry(
+        self, exception: Exception | None, elapsed: float, wait_time: float
+    ) -> None:
+        """Log retry attempt and update priority/attempts only on actual exceptions"""
+
+        self.priority = min(MIN_PRIORITY, self.priority + 1)
+
+        logger_msg = (
+            f"[Token Retry] | Attempt: {self.attempts}/{self.max_tries} | Elapsed: {elapsed:.1f}s | "
+            f"Next wait: {wait_time:.1f}s | Priority: {self.priority} ({MAX_PRIORITY}-{MIN_PRIORITY}) | "
+            f"Exception: {exception_message(exception)}"
+        )
+
+        self.logger.warning(logger_msg)
+
+        if self.retry_callback:
+            self.retry_callback(self.attempts, exception, elapsed, wait_time)
+
+    async def _has_lower_priority_waiting(self) -> bool:
+        """
+        Check if there are lower priority requests waiting
+        """
+
+        # NOTE: no lock needed, stale counts are fine
+        for priority in range(MAX_PRIORITY, self.priority):
+            key = TokenRetrier.get_priority_key(self.client_registry_key, priority)
+            count = await redis_client.get(key)
+            self.logger.debug(f"priority: {priority}, count: {count}")
+            if count and int(count) > 0:
+                return True
+        return False
+
+    async def _pre_function(self) -> None:
+        """
+        Loop until sufficient tokens are available.
+        Acquires priority semaphore, checks for lower priority requests, deducts tokens from Redis.
+        Logs token waits but does not count as retry attempts.
+        """
+
+        priority_key = TokenRetrier.get_priority_key(
+            self.client_registry_key, self.priority
+        )
+
+        # let storage know we are waiting at this priority
+        await redis_client.incr(priority_key)
+        self.logger.debug(f"priority: {self.priority}, waiting: {priority_key}")
+
+        try:
+            while True:
+                wait_time = jitter(self.token_wait_time)
+
+                # if there is a task with lower priority waiting, go back to waiting
+                if await self._has_lower_priority_waiting():
+                    self.logger.debug(
+                        f"[Token Wait] Lower priority requests exist, waiting {wait_time:.1f}s | "
+                        f"Priority: {self.priority}"
+                    )
+                else:
+                    # dynamically adjust actual estimate tokens based on past requests
+                    if self.dynamic_estimate_key:
+                        # NOTE: ok to not lock, don't need precise ratio
+                        ratio = float(
+                            await redis_client.get(self.dynamic_estimate_key) or 1.0
+                        )
+                        self.actual_estimate_total_tokens = ceil(
+                            self.estimate_total_tokens * ratio
+                        )
+                        self.logger.debug(
+                            f"Adjusted actual estimate tokens to {self.actual_estimate_total_tokens} using ratio {ratio}"
+                        )
+
+                    # TODO: use a Lua script to avoid using locks
+
+                    # NOTE: `async with` releases the lock in all situations
+                    async with redis_client.lock(
+                        self._token_key_lock, timeout=LOCK_TIMEOUT
+                    ):
+                        tokens_remaining = await self._get_remaining_tokens()
+
+                        # if we have enough tokens, deduct estimate tokens and make request
+                        if tokens_remaining >= self.actual_estimate_total_tokens:
+                            self.logger.debug(
+                                f"Enough tokens {self.actual_estimate_total_tokens}/{tokens_remaining}, deducting"
+                            )
+                            await self._deduct_remaining_tokens()
+                            return
+
+                    self.logger.warning(
+                        f"[Token Wait] Insufficient tokens, waiting {wait_time:.1f}s | "
+                        f"estimate_tokens: {self.actual_estimate_total_tokens}/{tokens_remaining} | "
+                        f"Priority: {self.priority}"
+                    )
+
+                # Zzz
+                self.logger.debug(f"Sleeping for {wait_time:.1f}s")
+                await asyncio.sleep(wait_time)
+        finally:
+            # let storage know we are done waiting at this priority
+            await redis_client.decr(priority_key)
+
+    async def _adjust_dynamic_estimate_ratio(self, actual_tokens: int) -> None:
+        if not self.dynamic_estimate_key:
+            return
+
+        observed_ratio = actual_tokens / self.estimate_total_tokens
+
+        alpha = 0.3
+
+        async with redis_client.lock(
+            self.dynamic_estimate_key + ":lock", timeout=LOCK_TIMEOUT
+        ):
+            current_ratio = float(
+                await redis_client.get(self.dynamic_estimate_key) or 1.0
+            )
+
+            new_ratio = (observed_ratio * alpha) + (current_ratio * (1 - alpha))
+
+            # NOTE: for now, will not cap the ratio as estimates will likely be very off
+            # the ratio between the tokenized estimate and the dynamic estimate should not be too far off
+            # new_ratio = max(0.01, min(100.0, new_ratio))
+
+            await redis_client.set(self.dynamic_estimate_key, new_ratio)
+
+            self.logger.info(
+                f"[Token Ratio] {self.token_key} | Observed: {observed_ratio:.5f} | "
+                f"Global Ratio: {current_ratio:.5f} -> {new_ratio:.5f}"
+            )
+
+    async def _post_function(self, result: tuple[QueryResult, float]) -> None:
+        """Adjust token estimate based on actual usage"""
+
+        metadata = result[0].metadata
+
+        countable_input_tokens = metadata.total_input_tokens - (
+            metadata.cache_read_tokens or 0
+        )
+        countable_output_tokens = metadata.total_output_tokens
+        actual_tokens = countable_input_tokens + countable_output_tokens
+
+        difference = self.actual_estimate_total_tokens - actual_tokens
+        self.logger.info(
+            f"Adjusting {self.token_key} by {difference}. Estimated {self.actual_estimate_total_tokens}, actual {actual_tokens}"
+        )
+
+        await self._adjust_dynamic_estimate_ratio(actual_tokens)
+
+        # NOTE: this can generate negative values, which represent `debt`
+        async with redis_client.lock(self._token_key_lock, timeout=LOCK_TIMEOUT):
+            await redis_client.incrby(self.token_key, difference)
+
+        result[0].metadata.extra["token_metadata"] = {
+            "estimated": self.estimate_total_tokens,
+            "estimated_with_dynamic_ratio": self.actual_estimate_total_tokens,
+            "actual": actual_tokens,
+            "difference": difference,
+            "ratio": actual_tokens / self.estimate_total_tokens,
+            "dynamic_ratio_used": self.actual_estimate_total_tokens
+            / self.estimate_total_tokens,
+        }
+
+    async def validate(self) -> None:
+        try:
+            assert redis_client
+        except Exception as e:
+            raise Exception(
+                f"redis client not set, run `TokenRetrier.set_redis_client`. Exception: {e}"
+            )
+        if not await redis_client.exists(self.token_key):
+            raise Exception(
+                "remaining_tokens not initialized, run `model.init_token_retry`"
+            )
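A hedged sketch of how the pieces above might be wired together; the Redis URL, the `("some_provider", "some_model")` registry key, and all token numbers are assumptions for illustration, and the provider call itself is elided because it must return the `(QueryResult, float)` tuple that `_post_function` expects.

```python
import asyncio
import logging

from redis.asyncio import Redis

import model_library.retriers.token as token_module
from model_library.retriers.token import TokenRetrier

logger = logging.getLogger(__name__)


async def no_rate_limit_headers():
    # Hypothetical provider that exposes no rate-limit headers; the
    # header-correction loop above logs this and exits on its first tick.
    return None


async def main() -> None:
    # One shared async Redis client backs the distributed token bucket.
    token_module.set_redis_client(Redis.from_url("redis://localhost:6379"))

    key = ("some_provider", "some_model")  # hypothetical (provider, model) pair

    # Seed the bucket: e.g. a 600k-token budget refreshed over 60s, so the
    # refill loop adds floor(600_000 / 60) = 10_000 tokens per second.
    await TokenRetrier.init_remaining_tokens(
        client_registry_key=key,
        limit=600_000,
        limit_refresh_seconds=60,
        logger=logger,
        get_rate_limit_func=no_rate_limit_headers,
    )

    retrier = TokenRetrier(
        logger=logger,
        client_registry_key=key,
        estimate_input_tokens=1_200,  # rough prompt-size guess
        estimate_output_tokens=800,   # rough completion-size guess
        dynamic_estimate_instance_id="run-1",  # opt in to the EMA correction
    )

    # provider_call must be an async callable returning (QueryResult, float):
    # result = await retrier.execute(provider_call)


asyncio.run(main())
```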
model_library/retriers/utils.py
ADDED

@@ -0,0 +1,11 @@
+import random
+
+
+def jitter(wait: float) -> float:
+    """
+    Increase or decrease the wait time by up to 20%.
+    """
+    jitter_fraction = 0.2
+    min_wait = wait * (1 - jitter_fraction)
+    max_wait = wait * (1 + jitter_fraction)
+    return random.uniform(min_wait, max_wait)
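Concretely, `jitter(20.0)` draws uniformly from [16.0, 24.0], so repeated waits still average the nominal 20 s while desynchronizing concurrent clients. A quick sanity check (illustrative only):

```python
from model_library.retriers.utils import jitter

# Bounds for wait=20.0 are 20 * 0.8 = 16.0 and 20 * 1.2 = 24.0.
assert all(16.0 <= jitter(20.0) <= 24.0 for _ in range(1_000))
```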
model_library/settings.py
CHANGED

@@ -22,7 +22,7 @@ class ModelLibrarySettings:
         except AttributeError:
             return default
 
-    def __getattr__(self, name: str) -> str
+    def __getattr__(self, name: str) -> str:
         # load key from override
         if name in self._key_overrides:
             return self._key_overrides[name]