model-library 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. model_library/base/base.py +141 -62
  2. model_library/base/delegate_only.py +77 -10
  3. model_library/base/output.py +43 -0
  4. model_library/base/utils.py +35 -0
  5. model_library/config/alibaba_models.yaml +49 -57
  6. model_library/config/all_models.json +353 -120
  7. model_library/config/anthropic_models.yaml +2 -1
  8. model_library/config/kimi_models.yaml +30 -3
  9. model_library/config/mistral_models.yaml +2 -0
  10. model_library/config/openai_models.yaml +15 -23
  11. model_library/config/together_models.yaml +2 -0
  12. model_library/config/xiaomi_models.yaml +43 -0
  13. model_library/config/zai_models.yaml +27 -3
  14. model_library/exceptions.py +3 -77
  15. model_library/providers/ai21labs.py +12 -8
  16. model_library/providers/alibaba.py +17 -8
  17. model_library/providers/amazon.py +49 -16
  18. model_library/providers/anthropic.py +128 -48
  19. model_library/providers/azure.py +22 -10
  20. model_library/providers/cohere.py +7 -7
  21. model_library/providers/deepseek.py +8 -8
  22. model_library/providers/fireworks.py +7 -8
  23. model_library/providers/google/batch.py +14 -10
  24. model_library/providers/google/google.py +57 -30
  25. model_library/providers/inception.py +7 -7
  26. model_library/providers/kimi.py +18 -8
  27. model_library/providers/minimax.py +15 -17
  28. model_library/providers/mistral.py +20 -8
  29. model_library/providers/openai.py +99 -22
  30. model_library/providers/openrouter.py +34 -0
  31. model_library/providers/perplexity.py +7 -7
  32. model_library/providers/together.py +7 -8
  33. model_library/providers/vals.py +12 -6
  34. model_library/providers/vercel.py +34 -0
  35. model_library/providers/xai.py +47 -42
  36. model_library/providers/xiaomi.py +34 -0
  37. model_library/providers/zai.py +38 -8
  38. model_library/register_models.py +5 -0
  39. model_library/registry_utils.py +48 -17
  40. model_library/retriers/__init__.py +0 -0
  41. model_library/retriers/backoff.py +73 -0
  42. model_library/retriers/base.py +225 -0
  43. model_library/retriers/token.py +427 -0
  44. model_library/retriers/utils.py +11 -0
  45. model_library/settings.py +1 -1
  46. model_library/utils.py +17 -7
  47. {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/METADATA +2 -1
  48. model_library-0.1.9.dist-info/RECORD +73 -0
  49. {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/WHEEL +1 -1
  50. model_library-0.1.7.dist-info/RECORD +0 -64
  51. {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/licenses/LICENSE +0 -0
  52. {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,225 @@
1
+ import asyncio
2
+ import logging
3
+ import time
4
+ from abc import ABC, abstractmethod
5
+ from typing import Any, Awaitable, Callable, Literal, TypeVar
6
+
7
+ from model_library.base.base import QueryResult
8
+ from model_library.exceptions import (
9
+ ImmediateRetryException,
10
+ MaxContextWindowExceededError,
11
+ exception_message,
12
+ is_context_window_error,
13
+ is_retriable_error,
14
+ )
15
+
16
# Signature of a retry decorator: takes an async callable and returns an async
# callable with the same call shape, wrapped in retry logic.
RetrierType = Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]

R = TypeVar("R")  # wrapper return type
19
+
20
+
21
class BaseRetrier(ABC):
    """
    Base class for retry strategies.
    Implements core retry logic and error handling.
    Subclasses should implement strategy-specific wait time calculations.
    """

    # NOTE: for token retrier, the estimate_tokens stays the same because ImmediateRetryException
    # is raised for network errors, where tokens have not been deducted yet

    @staticmethod
    async def immediate_retry_wrapper(
        func: Callable[[], Awaitable[R]],
        logger: logging.Logger,
    ) -> R:
        """
        Retry the query immediately.

        Unlike ``execute``, this performs no sleeping between attempts: the call is
        re-issued as soon as an ImmediateRetryException is observed, up to a hard cap.

        Args:
            func: Zero-argument async callable to invoke.
            logger: Logger used to record each immediate retry.

        Returns:
            The result of ``func`` on the first successful call.

        Raises:
            Exception: When MAX_IMMEDIATE_RETRIES consecutive attempts have failed;
                chained (``from e``) to the last ImmediateRetryException.
        """
        MAX_IMMEDIATE_RETRIES = 10
        retries = 0
        while True:
            try:
                return await func()
            except ImmediateRetryException as e:
                # Cap check happens BEFORE incrementing, so up to
                # MAX_IMMEDIATE_RETRIES retries follow the initial attempt.
                if retries >= MAX_IMMEDIATE_RETRIES:
                    raise Exception(
                        f"[Immediate Retry Max] | {retries}/{MAX_IMMEDIATE_RETRIES} | Exception {exception_message(e)}"
                    ) from e
                retries += 1

                logger.warning(
                    f"[Immediate Retry] | {retries}/{MAX_IMMEDIATE_RETRIES} | Exception {exception_message(e)}"
                )

    def __init__(
        self,
        strategy: Literal["backoff", "token"],
        logger: logging.Logger,
        max_tries: int | None,
        max_time: float | None,
        retry_callback: Callable[[int, Exception | None, float, float], None] | None,
    ):
        """
        Args:
            strategy: Name of the concrete strategy (used in log messages).
            logger: Logger for retry/give-up diagnostics.
            max_tries: Give up after this many failed attempts (None = unlimited).
            max_time: Give up once this many seconds have elapsed (None = unlimited).
            retry_callback: Optional hook invoked on each retry with
                (attempts, exception, elapsed, wait_time).
        """
        self.strategy = strategy
        self.logger = logger
        self.max_tries = max_tries
        self.max_time = max_time
        self.retry_callback = retry_callback

        # Per-execution state; reset at the top of every execute() call.
        self.attempts = 0
        self.start_time: float | None = None

    @abstractmethod
    async def _calculate_wait_time(
        self, attempt: int, exception: Exception | None = None
    ) -> float:
        """
        Calculate wait time before retrying

        Args:
            attempt: Current attempt number (0-indexed)
            exception: The exception that triggered the retry

        Returns:
            Wait time in seconds
        """
        ...

    @abstractmethod
    async def _on_retry(
        self, exception: Exception | None, elapsed: float, wait_time: float
    ) -> None:
        """
        Hook called before waiting on retry

        Args:
            exception: The exception that triggered the retry
            elapsed: Time elapsed since start
            wait_time: Wait time
        """
        ...

    def _should_retry(self, exception: Exception) -> bool:
        """
        Determine if an exception should trigger a retry.

        Args:
            exception: The exception to evaluate

        Returns:
            True if should retry, False otherwise
        """

        # Context-window overflows can never succeed on retry, so they are
        # checked first and short-circuit to False.
        if is_context_window_error(exception):
            return False
        return is_retriable_error(exception)

    def _handle_giveup(self, exception: Exception, reason: str) -> None:
        """
        Handle final exception after all retries exhausted.

        Always raises — control never returns to the caller.

        Args:
            exception: The final exception
            reason: Reason for giving up

        Raises:
            MaxContextWindowExceededError: If context window error
            Exception: The original exception otherwise
        """

        self.logger.error(
            f"[Give up] | {self.strategy} | {reason} | Exception: {exception_message(exception)}"
        )

        # instead of raising the provider exception, raise the custom MaxContextWindowExceededError
        if is_context_window_error(exception):
            message = exception.args[0] if exception.args else str(exception)
            raise MaxContextWindowExceededError(message)

        raise exception

    @abstractmethod
    async def _pre_function(self) -> None:
        """Hook called before the actual function call"""
        ...

    @abstractmethod
    async def _post_function(self, result: tuple[QueryResult, float]) -> None:
        """Hook called after the actual function call"""
        ...

    async def execute(
        self,
        func: Callable[..., Awaitable[Any]],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """
        Execute function with retry logic.

        Args:
            func: Async function to execute
            *args: Positional arguments for func
            **kwargs: Keyword arguments for func

        Returns:
            Result from func

        Raises:
            Exception: If retries exhausted or non-retriable error occurs
        """

        # Fresh state per call: this retrier instance may be reused.
        self.attempts = 0
        self.start_time = time.time()

        await self.validate()

        while True:
            try:
                await self._pre_function()
                result = await func(*args, **kwargs)
                await self._post_function(result)
                return result

            except Exception as e:
                elapsed = time.time() - self.start_time

                self.attempts += 1

                # check if max_tries exceeded
                # (each _handle_giveup call raises, terminating the loop)
                if self.max_tries is not None and self.attempts >= self.max_tries:
                    self._handle_giveup(
                        e, f"max_tries exceeded ({self.attempts} >= {self.max_tries})"
                    )

                # check if max_time exceeded
                if self.max_time is not None and elapsed > self.max_time:
                    self._handle_giveup(
                        e, f"max_time exceeded ({elapsed} > {self.max_time}s)"
                    )

                if not self._should_retry(e):
                    self._handle_giveup(e, "not retriable")

                # calculate wait time
                wait_time = await self._calculate_wait_time(self.attempts, e)
                # call pre retry sleep hook
                await self._on_retry(e, elapsed, wait_time)

                await asyncio.sleep(wait_time)

    async def validate(self) -> None:
        """Validate the retrier"""
        # Default implementation is a no-op; subclasses may raise on misconfiguration.
        ...
214
+
215
+
216
def retry_decorator(retrier: BaseRetrier) -> RetrierType:
    """Create a retry decorator from an initialized retrier.

    Args:
        retrier: Configured retrier whose ``execute`` drives the retry loop.

    Returns:
        A decorator that wraps an async callable so every call goes through
        ``retrier.execute`` with the same positional/keyword arguments.
    """
    # Local import keeps the file's top-level import block untouched.
    from functools import wraps

    def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
        # wraps() preserves __name__/__doc__/__qualname__ of the wrapped
        # coroutine function, so logging and introspection see the real target
        # rather than an anonymous "wrapper".
        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            return await retrier.execute(func, *args, **kwargs)

        return wrapper

    return decorator
@@ -0,0 +1,427 @@
1
+ import asyncio
2
+ import logging
3
+ import time
4
+ import uuid
5
+ from asyncio.tasks import Task
6
+ from math import ceil, floor
7
+ from typing import Any, Callable, Coroutine
8
+
9
+ from redis.asyncio import Redis
10
+
11
+ from model_library.base.base import QueryResult, RateLimit
12
+ from model_library.exceptions import exception_message
13
+ from model_library.retriers.base import BaseRetrier
14
+ from model_library.retriers.utils import jitter
15
+
16
# Base sleep (seconds) between retries after an exception; jittered before use.
RETRY_WAIT_TIME: float = 20.0
# Base sleep (seconds) between polls while waiting for token budget; jittered before use.
TOKEN_WAIT_TIME: float = 5.0

# Priority scale: numerically LOWER means more urgent (1 is highest priority).
MAX_PRIORITY: int = 1
MIN_PRIORITY: int = 5

# Default cap on retry attempts for TokenRetrier.
MAX_RETRIES: int = 10

# Shared async Redis client; must be installed via set_redis_client() before use.
redis_client: Redis
LOCK_TIMEOUT: int = 10  # using 10 in case there is high compute load, don't want to error on lock releases
26
+
27
+
28
def set_redis_client(client: Redis):
    """Install the module-level Redis client shared by all TokenRetrier instances."""
    global redis_client
    redis_client = client
31
+
32
+
33
# Registry of background tasks keyed by "refill:<token_key>" / "correction:<token_key>";
# each value pairs the task's configuration dict with its asyncio Task handle.
refill_tasks: dict[str, tuple[dict[str, int | Any], Task[None]]] = {}
34
+
35
+
36
class TokenRetrier(BaseRetrier):
    """
    Token-based retry strategy.
    Predicts the number of tokens required for a query, sends requests to respect the rate limit,
    then adjusts the estimate based on actual usage.
    """

    @staticmethod
    def get_token_key(client_registry_key: tuple[str, str]) -> str:
        """Get the key which stores remaining tokens"""
        return f"{client_registry_key[0]}:{client_registry_key[1]}:tokens"

    @staticmethod
    def get_priority_key(client_registry_key: tuple[str, str], priority: int) -> str:
        """Get the key which stores the amount of tasks waiting for a given priority"""
        return f"{client_registry_key[0]}:{client_registry_key[1]}:priority:{priority}"

    @staticmethod
    async def init_remaining_tokens(
        client_registry_key: tuple[str, str],
        limit: int,
        limit_refresh_seconds: int,
        logger: logging.Logger,
        get_rate_limit_func: Callable[[], Coroutine[Any, Any, RateLimit | None]],
    ) -> None:
        """
        Initialize remaining tokens in storage and start background refill process.

        Args:
            client_registry_key: (provider, model) pair used to namespace Redis keys.
            limit: Token budget per refresh window.
            limit_refresh_seconds: Length of the refresh window, in seconds.
            logger: Logger for refill/correction diagnostics.
            get_rate_limit_func: Async callable returning provider rate-limit
                header data, or None when the provider exposes none.
        """

        async def _header_correction_loop(
            key: str,
            limit: int,
            tokens_per_second: int,
            get_rate_limit_func: Callable[[], Coroutine[Any, Any, RateLimit | None]],
            version: str,
        ) -> None:
            """
            Background loop that correct tokens based on provider headers
            Every 5 seconds
            """
            interval = 5.0

            assert redis_client
            while True:
                await asyncio.sleep(interval)
                # NOTE(review): this compares the stored version to a str; it
                # assumes the Redis client decodes responses to str
                # (decode_responses=True) — with a bytes client the comparison
                # would always differ and the loop would exit. Confirm.
                current_version = await redis_client.get("version:" + key)
                if current_version != version:
                    logger.debug(
                        f"version changed ({current_version} != {version}), exiting _header_correction_loop for {key}"
                    )
                    return

                rate_limit = await get_rate_limit_func()
                if rate_limit is None:
                    # kill the task as no headers are provided
                    logger.debug(
                        f"no rate limit headers, exiting _header_correction_loop for {key}"
                    )
                    return

                tokens_remaining = rate_limit.token_remaining_total

                async with redis_client.lock(key + ":lock", timeout=LOCK_TIMEOUT):
                    # NOTE(review): int(None) raises if the key is missing;
                    # presumably init has always set it by this point — confirm.
                    current = int(await redis_client.get(key))

                    # increment
                    # Age the header snapshot forward by the refill rate so a
                    # stale header is not treated as the current balance.
                    elapsed = time.time() - rate_limit.unix_timestamp
                    adjusted = floor(tokens_remaining + (tokens_per_second * elapsed))

                    # if the headers show a lower value, correct with that
                    if adjusted < current:
                        await redis_client.set(key, adjusted)
                        logger.info(
                            f"Corrected {key} from {current} to {adjusted} based on headers ({elapsed:.1f}s old)"
                        )
                    else:
                        logger.debug(
                            f"Not correcting {key} from {current} to {adjusted} based on headers ({elapsed:.1f}s old) (higher value)"
                        )

        async def _token_refill_loop(
            key: str,
            limit: int,
            tokens_per_second: int,
            version: str,
        ) -> None:
            """
            Background loop that refills tokens
            Every second
            """
            interval: float = 1.0

            assert redis_client
            while True:
                await asyncio.sleep(interval)
                # Same str-vs-bytes assumption as in _header_correction_loop.
                current_version = await redis_client.get("version:" + key)
                if current_version != version:
                    logger.debug(
                        f"version changed ({current_version} != {version}), exiting _token_refill_loop for {key}"
                    )
                    return

                async with redis_client.lock(key + ":lock", timeout=LOCK_TIMEOUT):
                    # increment
                    current = await redis_client.incrby(key, tokens_per_second)
                    logger.debug(
                        f"[Token Refill] | {key} | Amount: {tokens_per_second} | Current: {current}"
                    )
                    # cap at limit
                    if current > limit:
                        logger.debug(f"[Token Cap] | {key} | Limit: {limit}")
                        await redis_client.set(key, limit)

        key = TokenRetrier.get_token_key(client_registry_key)

        # limit_key is only used to check if the limit has changed
        limit_key = f"{key}:limit"

        async with redis_client.lock("init:" + key + ":lock", timeout=LOCK_TIMEOUT):
            old_limit = int(await redis_client.get(limit_key) or 0)

            # keep track of version so we can clean up old tasks
            # even if the limit has not changed, reset background tasks just in case
            version = str(uuid.uuid4())
            await redis_client.set("version:" + key, version)

            if old_limit != limit or not await redis_client.exists(key):
                # if new limit if different, set it
                await redis_client.set(key, limit)
                await redis_client.set(limit_key, limit)

            tokens_per_second = floor(limit / limit_refresh_seconds)

            refill_task = asyncio.create_task(
                _token_refill_loop(key, limit, tokens_per_second, version)
            )
            correction_task = asyncio.create_task(
                _header_correction_loop(
                    key, limit, tokens_per_second, get_rate_limit_func, version
                )
            )

            # Hold the Task handles (and their config) so the tasks are not
            # garbage-collected and can be inspected later.
            refill_tasks["refill:" + key] = (
                {
                    "limit": limit,
                    "limit_refresh_seconds": limit_refresh_seconds,
                },
                refill_task,
            )
            refill_tasks["correction:" + key] = (
                {
                    "limit": limit,
                    "limit_refresh_seconds": limit_refresh_seconds,
                    "get_rate_limit_func": get_rate_limit_func,
                },
                correction_task,
            )

    async def _get_remaining_tokens(self) -> int:
        """Get remaining tokens"""
        # NOTE(review): int(None) raises if the key is absent; validate()
        # checks existence up front, so presumably the key is always set here.
        tokens = await redis_client.get(self.token_key)
        return int(tokens)

    async def _deduct_remaining_tokens(self) -> None:
        """Deduct from remaining tokens"""
        # NOTE: decrby is atomic
        await redis_client.decrby(self.token_key, self.actual_estimate_total_tokens)

    def __init__(
        self,
        logger: logging.Logger,
        max_tries: int | None = MAX_RETRIES,
        max_time: float | None = None,
        retry_callback: Callable[[int, Exception | None, float, float], None]
        | None = None,
        *,
        client_registry_key: tuple[str, str],
        estimate_input_tokens: int,
        estimate_output_tokens: int,
        dynamic_estimate_instance_id: str | None = None,
        retry_wait_time: float = RETRY_WAIT_TIME,
        token_wait_time: float = TOKEN_WAIT_TIME,
    ):
        """
        Args:
            logger: Logger for retry diagnostics.
            max_tries: Give up after this many failed attempts (None = unlimited).
            max_time: Give up once this many seconds have elapsed (None = unlimited).
            retry_callback: Optional hook invoked on each retry with
                (attempts, exception, elapsed, wait_time).
            client_registry_key: (provider, model) pair namespacing the Redis keys.
            estimate_input_tokens: Predicted prompt tokens for the query.
            estimate_output_tokens: Predicted completion tokens for the query.
            dynamic_estimate_instance_id: When set, enables the shared
                EMA-ratio correction of estimates under this id.
            retry_wait_time: Base wait between exception retries (pre-jitter).
            token_wait_time: Base wait between token-availability polls (pre-jitter).
        """
        super().__init__(
            strategy="token",
            logger=logger,
            max_tries=max_tries,
            max_time=max_time,
            retry_callback=retry_callback,
        )

        self.client_registry_key = client_registry_key

        self.estimate_input_tokens = estimate_input_tokens
        self.estimate_output_tokens = estimate_output_tokens
        self.estimate_total_tokens = estimate_input_tokens + estimate_output_tokens
        self.actual_estimate_total_tokens = (
            self.estimate_total_tokens
        )  # when multiplying base estimate_total_tokens by ratio

        self.retry_wait_time = retry_wait_time
        self.token_wait_time = token_wait_time

        # Starts at highest priority; demoted by one on each retry (see _on_retry).
        self.priority = MAX_PRIORITY

        self.token_key = TokenRetrier.get_token_key(client_registry_key)
        self._token_key_lock = self.token_key + ":lock"
        self._init_key_lock = "init:" + self.token_key + ":lock"

        self.dynamic_estimate_key = (
            f"{self.token_key}:dynamic_estimate:{dynamic_estimate_instance_id}"
            if dynamic_estimate_instance_id
            else None
        )

    async def _calculate_wait_time(
        self, attempt: int, exception: Exception | None = None
    ) -> float:
        """Wait time between retries"""
        return jitter(self.retry_wait_time)

    async def _on_retry(
        self, exception: Exception | None, elapsed: float, wait_time: float
    ) -> None:
        """Log retry attempt and update priority/attempts only on actual exceptions"""

        # Demote by one priority level per retry, clamped to MIN_PRIORITY
        # (numerically larger = less urgent).
        self.priority = min(MIN_PRIORITY, self.priority + 1)

        logger_msg = (
            f"[Token Retry] | Attempt: {self.attempts}/{self.max_tries} | Elapsed: {elapsed:.1f}s | "
            f"Next wait: {wait_time:.1f}s | Priority: {self.priority} ({MAX_PRIORITY}-{MIN_PRIORITY}) | "
            f"Exception: {exception_message(exception)}"
        )

        self.logger.warning(logger_msg)

        if self.retry_callback:
            self.retry_callback(self.attempts, exception, elapsed, wait_time)

    async def _has_lower_priority_waiting(self) -> bool:
        """
        Check if there are lower priority requests waiting

        Returns:
            True when any waiter exists at a numerically-smaller (more urgent)
            priority than this retrier's current one.
        """

        # NOTE: no lock needed, stale counts are fine
        for priority in range(MAX_PRIORITY, self.priority):
            key = TokenRetrier.get_priority_key(self.client_registry_key, priority)
            count = await redis_client.get(key)
            self.logger.debug(f"priority: {priority}, count: {count}")
            if count and int(count) > 0:
                return True
        return False

    async def _pre_function(self) -> None:
        """
        Loop until sufficient tokens are available.
        Acquires priority semaphore, checks for lower priority requests, deducts tokens from Redis.
        Logs token waits but does not count as retry attempts.
        """

        priority_key = TokenRetrier.get_priority_key(
            self.client_registry_key, self.priority
        )

        # let storage know we are waiting at this priority
        await redis_client.incr(priority_key)
        self.logger.debug(f"priority: {self.priority}, waiting: {priority_key}")

        try:
            while True:
                wait_time = jitter(self.token_wait_time)

                # if there is a task with lower priority waiting, go back to waiting
                if await self._has_lower_priority_waiting():
                    self.logger.debug(
                        f"[Token Wait] Lower priority requests exist, waiting {wait_time:.1f}s | "
                        f"Priority: {self.priority}"
                    )
                else:
                    # dynamically adjust actual estimate tokens based on past requests
                    if self.dynamic_estimate_key:
                        # NOTE: ok to not lock, don't need precise ratio
                        ratio = float(
                            await redis_client.get(self.dynamic_estimate_key) or 1.0
                        )
                        self.actual_estimate_total_tokens = ceil(
                            self.estimate_total_tokens * ratio
                        )
                        self.logger.debug(
                            f"Adjusted actual estimate tokens to {self.actual_estimate_total_tokens} using ratio {ratio}"
                        )

                    # TODO: use luascript to avoid using locks

                    # NOTE: `async with` releases lock in all situations
                    async with redis_client.lock(
                        self._token_key_lock, timeout=LOCK_TIMEOUT
                    ):
                        tokens_remaining = await self._get_remaining_tokens()

                        # if we have enough tokens, deduct estimate tokens and make request
                        if tokens_remaining >= self.actual_estimate_total_tokens:
                            self.logger.debug(
                                f"Enough tokens {self.actual_estimate_total_tokens}/{tokens_remaining}, deducting"
                            )
                            await self._deduct_remaining_tokens()
                            return

                    self.logger.warning(
                        f"[Token Wait] Insufficient tokens, waiting {wait_time:.1f}s | "
                        f"estimate_tokens: {self.actual_estimate_total_tokens}/{tokens_remaining} | "
                        f"Priority: {self.priority}"
                    )

                # Zzz
                self.logger.debug(f"Sleeping for {wait_time:.1f}s")
                await asyncio.sleep(wait_time)
        finally:
            # let storage know we are done waiting at this priority
            await redis_client.decr(priority_key)

    async def _adjust_dynamic_estimate_ratio(self, actual_tokens: int) -> None:
        """Fold the observed actual/estimate ratio into the shared EMA ratio."""
        if not self.dynamic_estimate_key:
            return

        observed_ratio = actual_tokens / self.estimate_total_tokens

        # EMA smoothing factor: 30% weight on the newest observation.
        alpha = 0.3

        async with redis_client.lock(
            self.dynamic_estimate_key + ":lock", timeout=LOCK_TIMEOUT
        ):
            current_ratio = float(
                await redis_client.get(self.dynamic_estimate_key) or 1.0
            )

            new_ratio = (observed_ratio * alpha) + (current_ratio * (1 - alpha))

            # NOTE: for now, will not cap the ratio as estimates will likely be very off
            # the ratio between the tokenized estimate and the dynamic estimate should not be too far off
            # new_ratio = max(0.01, min(100.0, new_ratio))

            await redis_client.set(self.dynamic_estimate_key, new_ratio)

        self.logger.info(
            f"[Token Ratio] {self.token_key} | Observed: {observed_ratio:.5f} | "
            f"Global Ratio: {current_ratio:.5f} -> {new_ratio:.5f}"
        )

    async def _post_function(self, result: tuple[QueryResult, float]) -> None:
        """Adjust token estimate based on actual usage"""

        metadata = result[0].metadata

        # Cache-read tokens do not count against the provider rate limit,
        # so they are excluded from the reconciliation.
        countable_input_tokens = metadata.total_input_tokens - (
            metadata.cache_read_tokens or 0
        )
        countable_output_tokens = metadata.total_output_tokens
        actual_tokens = countable_input_tokens + countable_output_tokens

        # Positive difference = we over-deducted and refund tokens;
        # negative = we under-deducted and charge the shortfall.
        difference = self.actual_estimate_total_tokens - actual_tokens
        self.logger.info(
            f"Adjusting {self.token_key} by {difference}. Estimated {self.actual_estimate_total_tokens}, actual {actual_tokens}"
        )

        await self._adjust_dynamic_estimate_ratio(actual_tokens)

        # NOTE: this can generate negative values, which represent `debt`
        async with redis_client.lock(self._token_key_lock, timeout=LOCK_TIMEOUT):
            await redis_client.incrby(self.token_key, difference)

        result[0].metadata.extra["token_metadata"] = {
            "estimated": self.estimate_total_tokens,
            "estimated_with_dynamic_ratio": self.actual_estimate_total_tokens,
            "actual": actual_tokens,
            "difference": difference,
            "ratio": actual_tokens / self.estimate_total_tokens,
            "dynamic_ratio_used": self.actual_estimate_total_tokens
            / self.estimate_total_tokens,
        }

    async def validate(self) -> None:
        """Ensure the Redis client is installed and the token bucket is initialized."""
        try:
            # Also covers the unset case: referencing the module global before
            # set_redis_client() raises NameError, caught below.
            assert redis_client
        except Exception as e:
            raise Exception(
                f"redis client not set, run `TokenRetrier.set_redis_client`. Exception: {e}"
            )
        if not await redis_client.exists(self.token_key):
            # NOTE(review): "intialized" typo in this user-facing message.
            raise Exception(
                "remaining_tokens not intialized, run `model.init_token_retry`"
            )
@@ -0,0 +1,11 @@
1
+ import random
2
+
3
+
4
def jitter(wait: float) -> float:
    """Return *wait* randomly perturbed by up to ±20%."""
    spread = 0.2
    low, high = wait * (1 - spread), wait * (1 + spread)
    return random.uniform(low, high)
model_library/settings.py CHANGED
@@ -22,7 +22,7 @@ class ModelLibrarySettings:
22
22
  except AttributeError:
23
23
  return default
24
24
 
25
- def __getattr__(self, name: str) -> str | Any:
25
+ def __getattr__(self, name: str) -> str:
26
26
  # load key from override
27
27
  if name in self._key_overrides:
28
28
  return self._key_overrides[name]