model-library 0.1.7-py3-none-any.whl → 0.1.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +139 -62
- model_library/base/delegate_only.py +77 -10
- model_library/base/output.py +43 -0
- model_library/base/utils.py +35 -0
- model_library/config/alibaba_models.yaml +44 -57
- model_library/config/all_models.json +253 -126
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/openai_models.yaml +15 -23
- model_library/config/zai_models.yaml +24 -3
- model_library/exceptions.py +3 -77
- model_library/providers/ai21labs.py +12 -8
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +49 -16
- model_library/providers/anthropic.py +93 -40
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +14 -10
- model_library/providers/google/google.py +48 -29
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +15 -17
- model_library/providers/mistral.py +20 -8
- model_library/providers/openai.py +99 -22
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +12 -6
- model_library/providers/xai.py +47 -42
- model_library/providers/zai.py +38 -8
- model_library/registry_utils.py +39 -15
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +13 -0
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/METADATA +2 -1
- model_library-0.1.8.dist-info/RECORD +70 -0
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
- model_library-0.1.7.dist-info/RECORD +0 -64
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/exceptions.py
CHANGED
@@ -1,13 +1,9 @@
-import logging
-import random
 import re
-from typing import Any
+from typing import Any
 
-import backoff
 from ai21 import TooManyRequestsError as AI21RateLimitError
 from anthropic import InternalServerError
 from anthropic import RateLimitError as AnthropicRateLimitError
-from backoff._typing import Details
 from httpcore import ReadError as HTTPCoreReadError
 from httpx import ConnectError as HTTPXConnectError
 from httpx import ReadError as HTTPXReadError
@@ -75,12 +71,14 @@ CONTEXT_WINDOW_PATTERN = re.compile(
     r"maximum context length is \d+ tokens|"
     r"context length is \d+ tokens|"
     r"exceed.* context (limit|window|length)|"
+    r"context window exceeds|"
     r"exceeds maximum length|"
     r"too long.*tokens.*maximum|"
     r"too large for model with \d+ maximum context length|"
     r"longer than the model's context length|"
     r"too many tokens.*size limit exceeded|"
     r"prompt is too long|"
+    r"maximum prompt length|"
     r"input length should be|"
     r"sent message larger than max|"
     r"input tokens exceeded|"
@@ -222,75 +220,3 @@ def exception_message(exception: Exception | Any) -> str:
         if str(exception)
         else type(exception).__name__
     )
-
-
-RETRY_MAX_TRIES: int = 20
-RETRY_INITIAL: float = 10.0
-RETRY_EXPO: float = 1.4
-RETRY_MAX_BACKOFF_WAIT: float = 240.0  # 4 minutes (more with jitter)
-
-
-def jitter(wait: float) -> float:
-    """
-    Increase or decrease the wait time by up to 20%.
-    """
-    jitter_fraction = 0.2
-    min_wait = wait * (1 - jitter_fraction)
-    max_wait = wait * (1 + jitter_fraction)
-    return random.uniform(min_wait, max_wait)
-
-
-def retry_llm_call(
-    logger: logging.Logger,
-    max_tries: int = RETRY_MAX_TRIES,
-    max_time: float | None = None,
-    backoff_callback: (
-        Callable[[int, Exception | None, float, float], None] | None
-    ) = None,
-):
-    def on_backoff(details: Details):
-        exception = details.get("exception")
-        tries = details.get("tries", 0)
-        elapsed = details.get("elapsed", 0.0)
-        wait = details.get("wait", 0.0)
-
-        logger.warning(
-            f"[Retrying] Exception: {exception_message(exception)} | Attempt: {tries} | "
-            + f"Elapsed: {elapsed:.1f}s | Next wait: {wait:.1f}s"
-        )
-
-        if backoff_callback:
-            backoff_callback(tries, exception, elapsed, wait)
-
-    def giveup(e: Exception) -> bool:
-        return not is_retriable_error(e)
-
-    def on_giveup(details: Details) -> None:
-        exception: Exception | None = details.get("exception", None)
-        if not exception:
-            return
-
-        logger.error(
-            f"Giving up after retries. Final exception: {exception_message(exception)}"
-        )
-
-        if is_context_window_error(exception):
-            message = exception.args[0] if exception.args else str(exception)
-            raise MaxContextWindowExceededError(message)
-
-        raise exception
-
-    return backoff.on_exception(
-        wait_gen=lambda: backoff.expo(
-            base=RETRY_EXPO,
-            factor=RETRY_INITIAL,
-            max_value=RETRY_MAX_BACKOFF_WAIT,
-        ),
-        exception=Exception,
-        max_tries=max_tries,
-        max_time=max_time,
-        giveup=giveup,
-        on_backoff=on_backoff,
-        on_giveup=on_giveup,
-        jitter=jitter,
-    )
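Note: the constants and the `retry_llm_call` helper removed above were built on the `backoff` package; per the file list, 0.1.8 replaces them with the new `model_library/retriers/` modules (backoff.py, base.py, token.py), whose API is not shown in this diff. For orientation only, the deleted behaviour amounted to exponential backoff with +/-20% jitter, roughly as in this sketch (constants copied from the deleted code; `retry_with_backoff` is an invented name, not the new retriers API):

import asyncio
import random

async def retry_with_backoff(fn, *, max_tries=20, initial=10.0, expo=1.4, max_wait=240.0):
    # Retry an async callable, sleeping a jittered, geometrically growing wait between attempts.
    wait = initial
    for attempt in range(1, max_tries + 1):
        try:
            return await fn()
        except Exception:
            if attempt == max_tries:
                raise
            # jitter the wait by up to +/-20%, then grow it, capped at max_wait
            await asyncio.sleep(random.uniform(wait * 0.8, wait * 1.2))
            wait = min(wait * expo, max_wait)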
model_library/providers/ai21labs.py
CHANGED
@@ -16,13 +16,13 @@ from model_library.base import (
     LLMConfig,
     QueryResult,
     QueryResultMetadata,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
     ToolDefinition,
     ToolResult,
 )
-from model_library.base.input import RawResponse
 from model_library.exceptions import (
     BadInputError,
     MaxOutputTokensExceededError,
@@ -34,17 +34,21 @@ from model_library.utils import default_httpx_client
 
 @register_provider("ai21labs")
 class AI21LabsModel(LLM):
-
+    @override
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.AI21LABS_API_KEY
 
     @override
-    def get_client(self) -> AsyncAI21Client:
-        if not
-
-
+    def get_client(self, api_key: str | None = None) -> AsyncAI21Client:
+        if not self.has_client():
+            assert api_key
+            client = AsyncAI21Client(
+                api_key=api_key,
                 http_client=default_httpx_client(),
-            num_retries=
+                num_retries=3,
             )
-
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
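The `has_client` / `assign_client` / `get_client(api_key=...)` calls above come from the shared LLM base class (base/base.py also changes in this release): build the SDK client once, cache it, and pass the key through the new `api_key` parameter. A stand-in illustration of that caching shape, assuming nothing about the real base-class internals (`_LazyClientHolder` is invented for this sketch):

class _LazyClientHolder:
    def __init__(self) -> None:
        self._client: object | None = None

    def has_client(self) -> bool:
        return self._client is not None

    def assign_client(self, client: object) -> None:
        self._client = client

    def get_client(self, api_key: str | None = None) -> object:
        if not self.has_client():
            assert api_key  # real providers construct their SDK client here
            self.assign_client(object())
        assert self._client is not None
        return self._client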
model_library/providers/alibaba.py
CHANGED
@@ -1,17 +1,17 @@
-from typing import Literal
+from typing import Any, Literal
 
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("alibaba")
@@ -26,17 +26,26 @@ class AlibabaModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://www.alibabacloud.com/help/en/model-studio/first-api-call-to-qwen
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.DASHSCOPE_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
+                api_key=SecretStr(model_library_settings.DASHSCOPE_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
 
+    @override
+    def _get_extra_body(self) -> dict[str, Any]:
+        """Build extra body parameters for Qwen-specific features."""
+        extra: dict[str, Any] = {}
+        # Enable thinking mode for Qwen3 reasoning models
+        # https://www.alibabacloud.com/help/en/model-studio/use-qwen-by-calling-api
+        if self.reasoning:
+            extra["enable_thinking"] = True
+        return extra
+
     @override
     async def _calculate_cost(
         self,
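`_get_extra_body` is the hook AlibabaModel now overrides to switch on Qwen's thinking mode; how the OpenAI-compatible delegate forwards it is not shown here, but the standard OpenAI SDK accepts such provider-specific fields through `extra_body`. A hedged sketch of that forwarding (the model name and key are placeholders, and the real wiring goes through DelegateOnly/OpenAIModel inside model_library, not this direct call):

from openai import AsyncOpenAI

async def ask_qwen(prompt: str) -> str:
    client = AsyncOpenAI(
        base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
        api_key="DASHSCOPE_API_KEY",  # placeholder
    )
    response = await client.chat.completions.create(
        model="qwen-plus",  # placeholder model name
        messages=[{"role": "user", "content": prompt}],
        extra_body={"enable_thinking": True},  # same flag _get_extra_body sets
    )
    return response.choices[0].message.content or ""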
model_library/providers/amazon.py
CHANGED
@@ -11,6 +11,7 @@ import botocore
 from botocore.client import BaseClient
 from typing_extensions import override
 
+from model_library import model_library_settings
 from model_library.base import (
     LLM,
     FileBase,
@@ -41,20 +42,46 @@ from model_library.register_models import register_provider
 @register_provider("amazon")
 @register_provider("bedrock")
 class AmazonModel(LLM):
-
+    @override
+    def _get_default_api_key(self) -> str:
+        if getattr(model_library_settings, "AWS_ACCESS_KEY_ID", None):
+            return json.dumps(
+                {
+                    "AWS_ACCESS_KEY_ID": model_library_settings.AWS_ACCESS_KEY_ID,
+                    "AWS_SECRET_ACCESS_KEY": model_library_settings.AWS_SECRET_ACCESS_KEY,
+                    "AWS_DEFAULT_REGION": model_library_settings.AWS_DEFAULT_REGION,
+                }
+            )
+        return "using-environment"
 
     @override
-    def get_client(self) -> BaseClient:
-        if not
-
-
-
-
-
-
-
-
-
+    def get_client(self, api_key: str | None = None) -> BaseClient:
+        if not self.has_client():
+            assert api_key
+            if api_key != "using-environment":
+                creds = json.loads(api_key)
+                client = cast(
+                    BaseClient,
+                    boto3.client(
+                        "bedrock-runtime",
+                        aws_access_key_id=creds["AWS_ACCESS_KEY_ID"],
+                        aws_secret_access_key=creds["AWS_SECRET_ACCESS_KEY"],
+                        region_name=creds["AWS_DEFAULT_REGION"],
+                        config=botocore.config.Config(max_pool_connections=1000), # pyright: ignore[reportAttributeAccessIssue]
+                    ),
+                )
+            else:
+                client = cast(
+                    BaseClient,
+                    boto3.client(
+                        "bedrock-runtime",
+                        # default connection pool is 10
+                        config=botocore.config.Config(max_pool_connections=1000), # pyright: ignore[reportAttributeAccessIssue]
+                    ),
+                )
+
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
@@ -70,6 +97,11 @@ class AmazonModel(LLM):
         ) # supported but no access yet
         self.supports_tool_cache = self.supports_cache and "claude" in self.model_name
 
+        if config and config.custom_api_key:
+            raise Exception(
+                "custom_api_key is not currently supported for Amazon models"
+            )
+
     cache_control = {"type": "default"}
 
     async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
@@ -238,7 +270,7 @@ class AmazonModel(LLM):
         if self.supports_cache:
             body["system"].append({"cachePoint": self.cache_control})
 
-        if self.reasoning:
+        if self.reasoning and self.max_tokens:
             if self.max_tokens < 1024:
                 self.max_tokens = 2048
             budget_tokens = kwargs.pop(
@@ -251,9 +283,10 @@ class AmazonModel(LLM):
                 }
             }
 
-        inference: dict[str, Any] = {
-
-
+        inference: dict[str, Any] = {}
+
+        if self.max_tokens:
+            inference["maxTokens"] = self.max_tokens
 
         # Only set temperature for models where supports_temperature is True.
         # For example, "thinking" models don't support temperature: https://docs.claude.com/en/docs/build-with-claude/extended-thinking#feature-compatibility
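The Amazon provider now funnels AWS credentials through the generic `api_key` string: `_get_default_api_key` either JSON-encodes explicit credentials or returns the sentinel "using-environment", and `get_client` decodes the same blob before building the boto3 client. The round trip in isolation (values are placeholders):

import json

creds = {
    "AWS_ACCESS_KEY_ID": "AKIA_EXAMPLE",
    "AWS_SECRET_ACCESS_KEY": "example-secret",
    "AWS_DEFAULT_REGION": "us-east-1",
}
api_key = json.dumps(creds)  # what _get_default_api_key returns

if api_key != "using-environment":
    decoded = json.loads(api_key)  # what get_client does before boto3.client(...)
    assert decoded["AWS_DEFAULT_REGION"] == "us-east-1"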
model_library/providers/anthropic.py
CHANGED
@@ -1,16 +1,20 @@
+import datetime
 import io
 import logging
+import time
 from typing import Any, Literal, Sequence, cast
 
-from anthropic import AsyncAnthropic
+from anthropic import APIConnectionError, AsyncAnthropic
 from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
 from anthropic.types.beta.parsed_beta_message import ParsedBetaMessage
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
     BatchResult,
+    DelegateConfig,
     FileBase,
     FileInput,
     FileWithBase64,
@@ -22,6 +26,7 @@ from model_library.base import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
+    RateLimit,
     RawInput,
     RawResponse,
     TextInput,
@@ -31,6 +36,7 @@ from model_library.base import (
     ToolResult,
 )
 from model_library.exceptions import (
+    ImmediateRetryException,
     MaxOutputTokensExceededError,
     NoMatchingToolCallError,
 )
@@ -38,8 +44,7 @@ from model_library.model_utils import get_default_budget_tokens
 from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
 from model_library.utils import (
-
-    default_httpx_client,
+    create_anthropic_client_with_defaults,
 )
 
 
@@ -246,21 +251,25 @@ class AnthropicBatchMixin(LLMBatchMixin):
 
 @register_provider("anthropic")
 class AnthropicModel(LLM):
-
+    def _get_default_api_key(self) -> str:
+        if self.delegate_config:
+            return self.delegate_config.api_key.get_secret_value()
+        return model_library_settings.ANTHROPIC_API_KEY
 
     @override
-    def get_client(self) -> AsyncAnthropic:
-        if self.
-
-        if not AnthropicModel._client:
+    def get_client(self, api_key: str | None = None) -> AsyncAnthropic:
+        if not self.has_client():
+            assert api_key
             headers: dict[str, str] = {}
-
-
-
-
+            client = create_anthropic_client_with_defaults(
+                base_url=self.delegate_config.base_url
+                if self.delegate_config
+                else None,
+                api_key=api_key,
                 default_headers=headers,
             )
-
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
@@ -268,33 +277,32 @@ class AnthropicModel(LLM):
         provider: str = "anthropic",
         *,
         config: LLMConfig | None = None,
-
+        delegate_config: DelegateConfig | None = None,
     ):
-
+        self.delegate_config = delegate_config
 
-
-        self._delegate_client: AsyncAnthropic | None = custom_client
+        super().__init__(model_name, provider, config=config)
 
         # https://docs.anthropic.com/en/api/openai-sdk
         self.delegate = (
             None
-            if self.native or
+            if self.native or self.delegate_config
             else OpenAIModel(
                 model_name=self.model_name,
-                provider=provider,
+                provider=self.provider,
                 config=config,
-
-
+                use_completions=True,
+                delegate_config=DelegateConfig(
                     base_url="https://api.anthropic.com/v1/",
+                    api_key=SecretStr(model_library_settings.ANTHROPIC_API_KEY),
                 ),
-                use_completions=True,
             )
         )
 
         # Initialize batch support if enabled
        # Disable batch when using custom_client (similar to OpenAI)
         self.supports_batch: bool = (
-            self.supports_batch and self.native and not
+            self.supports_batch and self.native and not self.delegate_config
         )
         self.batch: LLMBatchMixin | None = (
             AnthropicBatchMixin(self) if self.supports_batch else None
@@ -520,7 +528,6 @@ class AnthropicModel(LLM):
         **kwargs: object,
     ) -> dict[str, Any]:
         body: dict[str, Any] = {
-            "max_tokens": self.max_tokens,
             "model": self.model_name,
             "messages": await self.parse_input(input),
         }
@@ -534,6 +541,11 @@ class AnthropicModel(LLM):
                 }
             ]
 
+        if not self.max_tokens:
+            raise Exception("Anthropic models require a max_tokens parameter")
+
+        body["max_tokens"] = self.max_tokens
+
         if self.reasoning:
             budget_tokens = kwargs.pop(
                 "budget_tokens", get_default_budget_tokens(self.max_tokens)
@@ -577,7 +589,7 @@ class AnthropicModel(LLM):
         client = self.get_client()
 
         # only send betas for the official Anthropic endpoint
-        is_anthropic_endpoint = self.
+        is_anthropic_endpoint = self.delegate_config is None
         if not is_anthropic_endpoint:
             client_base_url = getattr(client, "_base_url", None) or getattr(
                 client, "base_url", None
@@ -592,11 +604,14 @@ class AnthropicModel(LLM):
            betas.append("context-1m-2025-08-07")
            stream_kwargs["betas"] = betas
 
-
-
-
-
-
+        try:
+            async with client.beta.messages.stream(
+                **stream_kwargs,
+            ) as stream: # pyright: ignore[reportAny]
+                message = await stream.get_final_message()
+                self.logger.info(f"Anthropic Response finished: {message.id}")
+        except APIConnectionError:
+            raise ImmediateRetryException("Failed to connect to Anthropic")
 
         text = ""
         reasoning = ""
@@ -632,6 +647,38 @@ class AnthropicModel(LLM):
             history=[*input, RawResponse(response=message)],
         )
 
+    @override
+    async def get_rate_limit(self) -> RateLimit:
+        response = await self.get_client().messages.with_raw_response.create(
+            max_tokens=1,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Ping",
+                }
+            ],
+            model=self.model_name,
+        )
+        headers = response.headers
+
+        server_time_str = headers.get("date")
+        if server_time_str:
+            server_time = datetime.datetime.strptime(
+                server_time_str, "%a, %d %b %Y %H:%M:%S GMT"
+            ).replace(tzinfo=datetime.timezone.utc)
+            timestamp = server_time.timestamp()
+        else:
+            timestamp = time.time()
+
+        return RateLimit(
+            unix_timestamp=timestamp,
+            raw=headers,
+            request_limit=int(headers["anthropic-ratelimit-requests-limit"]),
+            request_remaining=int(headers["anthropic-ratelimit-requests-remaining"]),
+            token_limit=int(response.headers["anthropic-ratelimit-tokens-limit"]),
+            token_remaining=int(headers["anthropic-ratelimit-tokens-remaining"]),
+        )
+
     @override
     async def count_tokens(
         self,
@@ -645,20 +692,26 @@ class AnthropicModel(LLM):
         Count the number of tokens using Anthropic's native token counting API.
         https://docs.anthropic.com/en/docs/build-with-claude/token-counting
         """
-
-
-
+        try:
+            input = [*history, *input]
+            if not input:
+                return 0
 
-
+            body = await self.build_body(input, tools=tools, **kwargs)
 
-
-
-
+            # Remove fields not supported by count_tokens endpoint
+            body.pop("max_tokens", None)
+            body.pop("temperature", None)
 
-
-
+            client = self.get_client()
+            response = await client.messages.count_tokens(**body)
 
-
+            return response.input_tokens
+        except Exception as e:
+            self.logger.error(f"Error counting tokens: {e}")
+            return await super().count_tokens(
+                input, history=history, tools=tools, **kwargs
+            )
 
     @override
     async def _calculate_cost(
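The new `get_rate_limit` override issues a one-token "Ping" request purely to read Anthropic's rate-limit response headers. A usage sketch against the RateLimit fields the diff shows (how the AnthropicModel instance is constructed is outside this diff, so `log_anthropic_headroom` is illustrative only):

from model_library.providers.anthropic import AnthropicModel

async def log_anthropic_headroom(model: AnthropicModel) -> None:
    limit = await model.get_rate_limit()
    print(
        f"requests {limit.request_remaining}/{limit.request_limit}, "
        f"tokens {limit.token_remaining}/{limit.token_limit} "
        f"(as of {limit.unix_timestamp:.0f})"
    )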
model_library/providers/azure.py
CHANGED
@@ -1,3 +1,4 @@
+import json
 from typing import Literal
 
 from openai.lib.azure import AsyncAzureOpenAI
@@ -14,21 +15,32 @@ from model_library.utils import default_httpx_client
 
 @register_provider("azure")
 class AzureOpenAIModel(OpenAIModel):
-    _azure_client: AsyncAzureOpenAI | None = None
-
     @override
-    def
-
-
-
-
-
+    def _get_default_api_key(self) -> str:
+        return json.dumps(
+            {
+                "AZURE_API_KEY": model_library_settings.AZURE_API_KEY,
+                "AZURE_ENDPOINT": model_library_settings.AZURE_ENDPOINT,
+                "AZURE_API_VERSION": model_library_settings.get(
                     "AZURE_API_VERSION", "2025-04-01-preview"
                 ),
+            }
+        )
+
+    @override
+    def get_client(self, api_key: str | None = None) -> AsyncAzureOpenAI:
+        if not self.has_client():
+            assert api_key
+            creds = json.loads(api_key)
+            client = AsyncAzureOpenAI(
+                api_key=creds["AZURE_API_KEY"],
+                azure_endpoint=creds["AZURE_ENDPOINT"],
+                api_version=creds["AZURE_API_VERSION"],
                 http_client=default_httpx_client(),
-            max_retries=
+                max_retries=3,
             )
-
+            self.assign_client(client)
+        return super(OpenAIModel, self).get_client(api_key)
 
     def __init__(
         self,
model_library/providers/cohere.py
CHANGED
@@ -1,13 +1,14 @@
 from typing import Literal
 
+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("cohere")
@@ -22,13 +23,12 @@ class CohereModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://docs.cohere.com/docs/compatibility-api
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.COHERE_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.cohere.ai/compatibility/v1",
+                api_key=SecretStr(model_library_settings.COHERE_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
model_library/providers/deepseek.py
CHANGED
@@ -5,14 +5,15 @@ https://cdn.deepseek.com/policies/en-US/deepseek-privacy-policy.html
 
 from typing import Literal
 
+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("deepseek")
@@ -27,13 +28,12 @@ class DeepSeekModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://api-docs.deepseek.com/
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-
-
+            delegate_config=DelegateConfig(
+                base_url="https://api.deepseek.com/v1",
+                api_key=SecretStr(model_library_settings.DEEPSEEK_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
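Alibaba, Cohere, and DeepSeek (and, per the file list, the Kimi and Z.ai providers as well) now share the same shape in 0.1.8: a DelegateOnly subclass that calls `init_delegate` with a DelegateConfig holding the base URL and a SecretStr-wrapped key, delegating to the OpenAI completions client. A hypothetical provider following that shape, using only the calls visible in the diffs above (the "example" name, URL, and key are invented for illustration):

from pydantic import SecretStr

from model_library.base import DelegateConfig, DelegateOnly, LLMConfig
from model_library.register_models import register_provider


@register_provider("example")  # hypothetical provider, not part of the package
class ExampleModel(DelegateOnly):
    def __init__(
        self,
        model_name: str,
        provider: str = "example",
        *,
        config: LLMConfig | None = None,
    ):
        super().__init__(model_name, provider, config=config)
        self.init_delegate(
            config=config,
            delegate_config=DelegateConfig(
                base_url="https://api.example.com/v1",  # placeholder endpoint
                api_key=SecretStr("EXAMPLE_API_KEY"),   # placeholder key
            ),
            use_completions=True,
            delegate_provider="openai",
        )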