model-library 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +139 -62
- model_library/base/delegate_only.py +77 -10
- model_library/base/output.py +43 -0
- model_library/base/utils.py +35 -0
- model_library/config/alibaba_models.yaml +44 -57
- model_library/config/all_models.json +253 -126
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/openai_models.yaml +15 -23
- model_library/config/zai_models.yaml +24 -3
- model_library/exceptions.py +3 -77
- model_library/providers/ai21labs.py +12 -8
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +49 -16
- model_library/providers/anthropic.py +93 -40
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +14 -10
- model_library/providers/google/google.py +48 -29
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +15 -17
- model_library/providers/mistral.py +20 -8
- model_library/providers/openai.py +99 -22
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +12 -6
- model_library/providers/xai.py +47 -42
- model_library/providers/zai.py +38 -8
- model_library/registry_utils.py +39 -15
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +13 -0
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/METADATA +2 -1
- model_library-0.1.8.dist-info/RECORD +70 -0
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
- model_library-0.1.7.dist-info/RECORD +0 -64
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/registry_utils.py
CHANGED
|
@@ -4,8 +4,13 @@ from typing import TypedDict
|
|
|
4
4
|
|
|
5
5
|
import tiktoken
|
|
6
6
|
|
|
7
|
-
from model_library.base import
|
|
8
|
-
|
|
7
|
+
from model_library.base import (
|
|
8
|
+
LLM,
|
|
9
|
+
LLMConfig,
|
|
10
|
+
ProviderConfig,
|
|
11
|
+
QueryResultCost,
|
|
12
|
+
QueryResultMetadata,
|
|
13
|
+
)
|
|
9
14
|
from model_library.register_models import (
|
|
10
15
|
CostProperties,
|
|
11
16
|
ModelConfig,
|
|
@@ -196,19 +201,38 @@ def get_provider_names() -> list[str]:
|
|
|
196
201
|
|
|
197
202
|
|
|
198
203
|
@cache
|
|
199
|
-
def get_model_names(
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
]
|
|
204
|
+
def get_model_names(
    provider: str | None = None,
    include_deprecated: bool = False,
    include_alt_keys: bool = True,
) -> list[str]:
    """
    Return the sorted full keys of the models in the registry.

    Args:
        provider: Keep only models whose provider name equals this value
            (compared case-insensitively). ``None`` or ``""`` disables
            the filter.
        include_deprecated: Keep models whose metadata marks them as
            deprecated.
        include_alt_keys: When False, drop every full key that some model
            lists as an alternative key, provided the part of that key
            before the first "/" equals that model's provider name.

    Returns:
        Sorted list of the matching models' ``full_key`` values.
    """
    # Fetch the registry once so both passes below see the same snapshot.
    registry = get_model_registry()

    same_provider_alt_keys: set[str] = set()
    if not include_alt_keys:
        for model in registry.values():
            for alt_item in model.alternative_keys:
                # An entry is either the key itself or a mapping whose
                # first key is the alternative key.
                if isinstance(alt_item, str):
                    alt_key = alt_item
                else:
                    alt_key = list(alt_item.keys())[0]
                if alt_key.split("/")[0] == model.provider_name:
                    same_provider_alt_keys.add(alt_key)

    # Lower-case the requested provider once instead of once per model.
    wanted_provider = provider.lower() if provider else None

    return sorted(
        model.full_key
        for model in registry.values()
        if (wanted_provider is None or model.provider_name.lower() == wanted_provider)
        and (not model.metadata.deprecated or include_deprecated)
        and model.full_key not in same_provider_alt_keys
    )
|
|
212
236
|
|
|
213
237
|
|
|
214
238
|
@cache
|
|
File without changes
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
from model_library.base.base import QueryResult
|
|
5
|
+
from model_library.exceptions import exception_message
|
|
6
|
+
from model_library.retriers.base import BaseRetrier
|
|
7
|
+
from model_library.retriers.utils import jitter
|
|
8
|
+
|
|
9
|
+
RETRY_MAX_TRIES: int = 20
|
|
10
|
+
RETRY_INITIAL: float = 10.0
|
|
11
|
+
RETRY_EXPO: float = 1.4
|
|
12
|
+
RETRY_MAX_BACKOFF_WAIT: float = 240.0
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ExponentialBackoffRetrier(BaseRetrier):
    """
    Retrier that waits exponentially longer between attempts.

    The wait grows as ``initial * expo**attempt``, is capped at
    ``max_backoff_wait`` and is then passed through ``jitter``.
    """

    def __init__(
        self,
        logger: logging.Logger,
        max_tries: int = RETRY_MAX_TRIES,
        max_time: float | None = None,
        retry_callback: Callable[[int, Exception | None, float, float], None]
        | None = None,
        *,
        initial: float = RETRY_INITIAL,
        expo: float = RETRY_EXPO,
        max_backoff_wait: float = RETRY_MAX_BACKOFF_WAIT,
    ):
        """
        Args:
            logger: Logger that receives the retry warnings
            max_tries: Forwarded to the base retrier
            max_time: Forwarded to the base retrier
            retry_callback: Optional callable invoked on every retry with
                (attempts, exception, elapsed, wait_time)
            initial: Base wait, multiplied by the exponential factor
            expo: Base of the exponential growth
            max_backoff_wait: Upper bound applied before jitter
        """
        super().__init__(
            strategy="backoff",
            logger=logger,
            max_tries=max_tries,
            max_time=max_time,
            retry_callback=retry_callback,
        )

        # backoff curve parameters
        self.initial, self.expo = initial, expo
        self.max_backoff_wait = max_backoff_wait

    async def _calculate_wait_time(
        self, attempt: int, exception: Exception | None = None
    ) -> float:
        """Return the capped exponential wait for ``attempt``, with jitter applied"""

        return jitter(
            min(self.initial * (self.expo**attempt), self.max_backoff_wait)
        )

    async def _on_retry(
        self, exception: Exception | None, elapsed: float, wait_time: float
    ) -> None:
        """Log the upcoming retry and notify the retry callback, if one is set"""

        self.logger.warning(
            f"[Retry] | {self.strategy} | Attempt: {self.attempts} | Elapsed: {elapsed:.1f}s | Next wait: {wait_time:.1f}s | Exception: {exception_message(exception)} "
        )

        if not self.retry_callback:
            return
        self.retry_callback(self.attempts, exception, elapsed, wait_time)

    async def _pre_function(self) -> None:
        """No-op: nothing to prepare before the call"""
        return None

    async def _post_function(self, result: tuple[QueryResult, float]) -> None:
        """No-op: the result is not inspected"""
        return None

    async def validate(self) -> None:
        """No-op: this strategy performs no validation"""
        return None
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import time
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from typing import Any, Awaitable, Callable, Literal, TypeVar
|
|
6
|
+
|
|
7
|
+
from model_library.base.base import QueryResult
|
|
8
|
+
from model_library.exceptions import (
|
|
9
|
+
ImmediateRetryException,
|
|
10
|
+
MaxContextWindowExceededError,
|
|
11
|
+
exception_message,
|
|
12
|
+
is_context_window_error,
|
|
13
|
+
is_retriable_error,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
RetrierType = Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]
|
|
17
|
+
|
|
18
|
+
R = TypeVar("R") # wrapper return type
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BaseRetrier(ABC):
    """
    Base class for retry strategies.
    Implements core retry logic and error handling.
    Subclasses should implement strategy-specific wait time calculations.
    """

    # NOTE: for token retrier, the estimate_tokens stays the same because ImmediateRetryException
    # is raised for network errors, where tokens have not been deducted yet

    @staticmethod
    async def immediate_retry_wrapper(
        func: Callable[[], Awaitable[R]],
        logger: logging.Logger,
    ) -> R:
        """
        Call ``func`` and retry it immediately (no sleep) on ImmediateRetryException.

        Args:
            func: Zero-argument async callable to run
            logger: Logger that receives a warning for every immediate retry

        Returns:
            Whatever ``func`` returns on its first successful call

        Raises:
            Exception: Once ``func`` has raised ImmediateRetryException after
                MAX_IMMEDIATE_RETRIES retries; chained from the last one.
                Any other exception raised by ``func`` propagates unchanged.
        """
        MAX_IMMEDIATE_RETRIES = 10
        retries = 0
        while True:
            try:
                return await func()
            except ImmediateRetryException as e:
                if retries >= MAX_IMMEDIATE_RETRIES:
                    raise Exception(
                        f"[Immediate Retry Max] | {retries}/{MAX_IMMEDIATE_RETRIES} | Exception {exception_message(e)}"
                    ) from e
                retries += 1

                logger.warning(
                    f"[Immediate Retry] | {retries}/{MAX_IMMEDIATE_RETRIES} | Exception {exception_message(e)}"
                )

    def __init__(
        self,
        strategy: Literal["backoff", "token"],
        logger: logging.Logger,
        max_tries: int | None,
        max_time: float | None,
        retry_callback: Callable[[int, Exception | None, float, float], None] | None,
    ):
        """
        Args:
            strategy: Strategy label, used in log messages
            logger: Logger used for give-up errors
            max_tries: Give up once this many attempts have failed; None disables the limit
            max_time: Give up once a failure occurs more than this many seconds
                after execute() started; None disables the limit
            retry_callback: Stored for subclasses; this base class never calls it itself
        """
        self.strategy = strategy
        self.logger = logger
        self.max_tries = max_tries
        self.max_time = max_time
        self.retry_callback = retry_callback

        # per-execute() state, reset at the start of every execute() call
        self.attempts = 0
        self.start_time: float | None = None

    @abstractmethod
    async def _calculate_wait_time(
        self, attempt: int, exception: Exception | None = None
    ) -> float:
        """
        Calculate wait time before retrying

        Args:
            attempt: Number of failed attempts so far. execute() passes
                self.attempts after incrementing it, so the value is 1 on
                the first retry.
            exception: The exception that triggered the retry

        Returns:
            Wait time in seconds
        """
        ...

    @abstractmethod
    async def _on_retry(
        self, exception: Exception | None, elapsed: float, wait_time: float
    ) -> None:
        """
        Hook called before waiting on retry

        Args:
            exception: The exception that triggered the retry
            elapsed: Time elapsed since start
            wait_time: Wait time
        """
        ...

    def _should_retry(self, exception: Exception) -> bool:
        """
        Determine if an exception should trigger a retry.

        Args:
            exception: The exception to evaluate

        Returns:
            True if should retry, False otherwise
        """

        # context window errors are never retried, whatever is_retriable_error says
        if is_context_window_error(exception):
            return False
        return is_retriable_error(exception)

    def _handle_giveup(self, exception: Exception, reason: str) -> None:
        """
        Log the give-up and raise the final exception. Never returns normally.

        Args:
            exception: The final exception
            reason: Reason for giving up

        Raises:
            MaxContextWindowExceededError: If context window error
            Exception: The original exception otherwise
        """

        self.logger.error(
            f"[Give up] | {self.strategy} | {reason} | Exception: {exception_message(exception)}"
        )

        # instead of raising the provider exception, raise the custom MaxContextWindowExceededError
        if is_context_window_error(exception):
            message = exception.args[0] if exception.args else str(exception)
            # chain explicitly so the provider error is kept as __cause__
            raise MaxContextWindowExceededError(message) from exception

        raise exception

    @abstractmethod
    async def _pre_function(self) -> None:
        """Hook called before the actual function call"""
        ...

    @abstractmethod
    async def _post_function(self, result: tuple[QueryResult, float]) -> None:
        """Hook called after the actual function call"""
        ...

    async def execute(
        self,
        func: Callable[..., Any],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """
        Execute function with retry logic.

        Each attempt runs _pre_function, func and _post_function inside the
        same try block, so an exception from either hook counts as a failed
        attempt too.

        Args:
            func: Async function to execute
            *args: Positional arguments for func
            **kwargs: Keyword arguments for func

        Returns:
            Result from func

        Raises:
            Exception: If retries exhausted or non-retriable error occurs
        """

        self.attempts = 0
        self.start_time = time.time()

        await self.validate()

        while True:
            try:
                await self._pre_function()
                result = await func(*args, **kwargs)
                await self._post_function(result)
                return result

            except Exception as e:
                elapsed = time.time() - self.start_time

                self.attempts += 1

                # every _handle_giveup call below raises, which ends the loop

                # check if max_tries exceeded
                if self.max_tries is not None and self.attempts >= self.max_tries:
                    self._handle_giveup(
                        e, f"max_tries exceeded ({self.attempts} >= {self.max_tries})"
                    )

                # check if max_time exceeded
                if self.max_time is not None and elapsed > self.max_time:
                    self._handle_giveup(
                        e, f"max_time exceeded ({elapsed} > {self.max_time}s)"
                    )

                if not self._should_retry(e):
                    self._handle_giveup(e, "not retriable")

                # calculate wait time
                wait_time = await self._calculate_wait_time(self.attempts, e)
                # call pre retry sleep hook
                await self._on_retry(e, elapsed, wait_time)

                await asyncio.sleep(wait_time)

    async def validate(self) -> None:
        """Validate the retrier. Awaited once at the start of execute(); no-op by default"""
        ...
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def retry_decorator(retrier: BaseRetrier) -> RetrierType:
    """
    Create a retry decorator from an initialized retrier

    Args:
        retrier: Retrier whose execute() runs every call of the decorated function

    Returns:
        A decorator for async functions. The decorated function forwards its
        arguments to ``retrier.execute(func, *args, **kwargs)`` and keeps the
        wrapped function's name, docstring and other metadata.
    """
    # local import keeps this block self-contained
    import functools

    def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
        # preserve __name__/__doc__/__wrapped__ of the decorated function
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            return await retrier.execute(func, *args, **kwargs)

        return wrapper

    return decorator
|