model-library 0.1.7-py3-none-any.whl → 0.1.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +141 -62
- model_library/base/delegate_only.py +77 -10
- model_library/base/output.py +43 -0
- model_library/base/utils.py +35 -0
- model_library/config/alibaba_models.yaml +49 -57
- model_library/config/all_models.json +353 -120
- model_library/config/anthropic_models.yaml +2 -1
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/mistral_models.yaml +2 -0
- model_library/config/openai_models.yaml +15 -23
- model_library/config/together_models.yaml +2 -0
- model_library/config/xiaomi_models.yaml +43 -0
- model_library/config/zai_models.yaml +27 -3
- model_library/exceptions.py +3 -77
- model_library/providers/ai21labs.py +12 -8
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +49 -16
- model_library/providers/anthropic.py +128 -48
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +14 -10
- model_library/providers/google/google.py +57 -30
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +15 -17
- model_library/providers/mistral.py +20 -8
- model_library/providers/openai.py +99 -22
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +12 -6
- model_library/providers/vercel.py +34 -0
- model_library/providers/xai.py +47 -42
- model_library/providers/xiaomi.py +34 -0
- model_library/providers/zai.py +38 -8
- model_library/register_models.py +5 -0
- model_library/registry_utils.py +48 -17
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +17 -7
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/METADATA +2 -1
- model_library-0.1.9.dist-info/RECORD +73 -0
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/WHEEL +1 -1
- model_library-0.1.7.dist-info/RECORD +0 -64
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/top_level.txt +0 -0
model_library/providers/google/google.py
CHANGED

@@ -1,5 +1,6 @@
 import base64
 import io
+import json
 import logging
 import uuid
 from typing import Any, Literal, Sequence, cast
@@ -25,6 +26,7 @@ from google.genai.types import (
     ToolListUnion,
     UploadFileConfig,
 )
+from google.oauth2 import service_account
 from typing_extensions import override

 from model_library import model_library_settings
@@ -95,31 +97,50 @@ class GoogleModel(LLM):
         ),
     ]

-
-
-
-
-
-
-
-
-
-        "gemini-2.5-flash-preview-09-2025": "global",
-        "gemini-2.5-flash-lite-preview-09-2025": "global",
+    def _get_default_api_key(self) -> str:
+        if not self.provider_config.use_vertex:
+            return model_library_settings.GOOGLE_API_KEY
+
+        return json.dumps(
+            {
+                "GCP_REGION": model_library_settings.GCP_REGION,
+                "GCP_PROJECT_ID": model_library_settings.GCP_PROJECT_ID,
+                "GCP_CREDS": model_library_settings.GCP_CREDS,
             }
-
-        if self.model_name in MODEL_REGION_OVERRIDES:
-            region = MODEL_REGION_OVERRIDES[self.model_name]
-
-        return Client(
-            vertexai=True,
-            project=model_library_settings.GCP_PROJECT_ID,
-            location=region,
-            # Credentials object is not typed, so we have to ignore the error
-            credentials=model_library_settings.GCP_CREDS,
-        )
+        )

-
+    @override
+    def get_client(self, api_key: str | None = None) -> Client:
+        if not self.has_client():
+            assert api_key
+            if self.provider_config.use_vertex:
+                # Gemini preview releases are only server from the global Vertex region after September 2025.
+                MODEL_REGION_OVERRIDES: dict[str, str] = {
+                    "gemini-2.5-flash-preview-09-2025": "global",
+                    "gemini-2.5-flash-lite-preview-09-2025": "global",
+                    "gemini-3-flash-preview": "global",
+                    "gemini-3-pro-preview": "global",
+                }
+
+                creds = json.loads(api_key)
+
+                region = creds["GCP_REGION"]
+                if self.model_name in MODEL_REGION_OVERRIDES:
+                    region = MODEL_REGION_OVERRIDES[self.model_name]
+
+                client = Client(
+                    vertexai=True,
+                    project=creds["GCP_PROJECT_ID"],
+                    location=region,
+                    credentials=service_account.Credentials.from_service_account_info(  # type: ignore
+                        json.loads(creds["GCP_CREDS"]),
+                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
+                    ),
+                )
+            else:
+                client = Client(api_key=api_key)
+            self.assign_client(client)
+        return super().get_client()

     def __init__(
         self,
@@ -141,8 +162,6 @@ class GoogleModel(LLM):
             GoogleBatchMixin(self) if self.supports_batch else None
         )

-        self.client = self.get_client()
-
     @override
     async def parse_input(
         self,
@@ -260,7 +279,7 @@ class GoogleModel(LLM):
         )

         mime = f"image/{mime}" if type == "image" else mime  # TODO:
-        response: File = self.
+        response: File = self.get_client().files.upload(
             file=bytes, config=UploadFileConfig(mime_type=mime)
         )
         if not response.name:
@@ -338,20 +357,25 @@ class GoogleModel(LLM):

         metadata: GenerateContentResponseUsageMetadata | None = None

-        stream = await self.
+        stream = await self.get_client().aio.models.generate_content_stream(**body)
         contents: list[Content | None] = []
         finish_reason: FinishReason | None = None

+        chunks: list[GenerateContentResponse] = []
+
         async for chunk in stream:
+            chunks.append(chunk)
             candidates = chunk.candidates
             if not candidates:
                 continue

             content = candidates[0].content

+            meaningful_content = False
             if content and content.parts:
                 for part in content.parts:
                     if part.function_call:
+                        meaningful_content = True
                         if not part.function_call.name:
                             raise Exception(f"Invalid function call: {part}")

@@ -368,13 +392,15 @@ class GoogleModel(LLM):
                     if not part.text:
                         continue
                     if part.thought:
+                        meaningful_content = True
                         reasoning += part.text
                     else:
+                        meaningful_content = True
                         text += part.text

             if chunk.usage_metadata:
                 metadata = chunk.usage_metadata
-            if content:
+            if content and meaningful_content:
                 contents.append(content)
             if candidates[0].finish_reason:
                 finish_reason = candidates[0].finish_reason
@@ -383,6 +409,7 @@ class GoogleModel(LLM):
             self.logger.error(f"Unexpected finish reason: {finish_reason}")

         if not text and not reasoning and not tool_calls:
+            self.logger.error(f"Chunks: {chunks}")
             raise ModelNoOutputError("Model returned empty response")

         result = QueryResult(
@@ -437,7 +464,7 @@ class GoogleModel(LLM):
             tools=parsed_tools,
         )

-        response = await self.
+        response = await self.get_client().aio.models.count_tokens(
             model=self.model_name,
             contents=cast(Any, contents),
             config=config,
@@ -503,7 +530,7 @@ class GoogleModel(LLM):
         # Make the request with retry wrapper
         async def _query():
             try:
-                return await self.
+                return await self.get_client().aio.models.generate_content(**body)
             except (genai_errors.ServerError, genai_errors.UnknownApiResponseError):
                 raise ImmediateRetryException("Failed to connect to Google API")

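Net effect of the Google hunks: the client is no longer built eagerly in __init__; Vertex settings now travel as a JSON-encoded "api key" produced by _get_default_api_key and decoded again inside get_client. A minimal sketch of that round trip outside the library, with invented placeholder values — only the field names and the per-model region override come from the diff:

import json

# Hypothetical stand-ins for model_library_settings.GCP_REGION / GCP_PROJECT_ID / GCP_CREDS.
payload = json.dumps(
    {
        "GCP_REGION": "us-central1",
        "GCP_PROJECT_ID": "example-project",
        "GCP_CREDS": json.dumps({"type": "service_account", "project_id": "example-project"}),
    }
)

# What get_client() does with the payload: decode it, then apply the preview-model overrides.
creds = json.loads(payload)
overrides = {"gemini-3-pro-preview": "global"}

model_name = "gemini-3-pro-preview"
region = overrides.get(model_name, creds["GCP_REGION"])
service_account_info = json.loads(creds["GCP_CREDS"])

print(region)                        # "global"
print(service_account_info["type"])  # "service_account"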
model_library/providers/inception.py
CHANGED

@@ -1,13 +1,14 @@
 from typing import Literal

+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 @register_provider("inception")
@@ -22,13 +23,12 @@ class MercuryModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)

         # https://docs.inceptionlabs.ai/get-started/get-started#external-libraries-compatibility
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.MERCURY_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.inceptionlabs.ai/v1/",
+                api_key=SecretStr(model_library_settings.MERCURY_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
model_library/providers/kimi.py
CHANGED

@@ -1,13 +1,16 @@
-from typing import Literal
+from typing import Any, Literal
+
+from typing_extensions import override
+
+from pydantic import SecretStr

 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 @register_provider("kimi")
@@ -22,13 +25,20 @@ class KimiModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)

         # https://platform.moonshot.ai/docs/guide/migrating-from-openai-to-kimi#about-api-compatibility
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.KIMI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.moonshot.ai/v1/",
+                api_key=SecretStr(model_library_settings.KIMI_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
+
+    @override
+    def _get_extra_body(self) -> dict[str, Any]:
+        """
+        Build extra body parameters for Kimi-specific features.
+        see https://platform.moonshot.ai/docs/guide/kimi-k2-5-quickstart#parameters-differences-in-request-body
+        """
+        return {"thinking": {"type": "enabled" if self.reasoning else "disabled"}}
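Kimi's new _get_extra_body hook maps the model's reasoning flag onto Moonshot's thinking switch. A rough sketch of what the resulting chat-completions payload could look like once that extra body is merged in — the model name and the merge step are assumptions; only the "thinking" field shape comes from the diff:

reasoning = True

# Field shape from _get_extra_body above.
extra_body = {"thinking": {"type": "enabled" if reasoning else "disabled"}}

# Hypothetical request body for an OpenAI-compatible call against https://api.moonshot.ai/v1/.
body = {
    "model": "kimi-k2.5",  # placeholder model name
    "messages": [{"role": "user", "content": "Ping"}],
    **extra_body,
}
print(body["thinking"])  # {'type': 'enabled'}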
model_library/providers/minimax.py
CHANGED

@@ -1,16 +1,17 @@
-from typing import Literal
+from typing import Literal, Sequence
+
+from pydantic import SecretStr
+from typing_extensions import override

 from model_library import model_library_settings
-from model_library.base import
-
-
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    InputItem,
+    LLMConfig,
+    ToolDefinition,
+)
 from model_library.register_models import register_provider
-from model_library.utils import default_httpx_client
-
-from anthropic import AsyncAnthropic
-
-from typing import Sequence
-from typing_extensions import override


 @register_provider("minimax")
@@ -24,16 +25,13 @@ class MinimaxModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)

-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.MINIMAX_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.minimax.io/anthropic",
-
-                max_retries=1,
+                api_key=SecretStr(model_library_settings.MINIMAX_API_KEY),
             ),
+            delegate_provider="anthropic",
         )

         # minimax client shares anthropic's syntax
model_library/providers/mistral.py
CHANGED

@@ -3,7 +3,13 @@ import logging
 from collections.abc import Sequence
 from typing import Any, Literal

-from mistralai import
+from mistralai import (
+    AssistantMessage,
+    ContentChunk,
+    Mistral,
+    TextChunk,
+    ThinkChunk,
+)
 from mistralai.models.completionevent import CompletionEvent
 from mistralai.models.toolcall import ToolCall as MistralToolCall
 from mistralai.utils.eventstreaming import EventStreamAsync
@@ -40,16 +46,20 @@ from model_library.utils import default_httpx_client

 @register_provider("mistralai")
 class MistralModel(LLM):
-
+    @override
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.MISTRAL_API_KEY

     @override
-    def get_client(self) -> Mistral:
-        if not
-
-
+    def get_client(self, api_key: str | None = None) -> Mistral:
+        if not self.has_client():
+            assert api_key
+            client = Mistral(
+                api_key=api_key,
                 async_client=default_httpx_client(),
             )
-
+            self.assign_client(client)
+        return super().get_client()

     def __init__(
         self,
@@ -198,12 +208,14 @@

         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "messages": messages,
             "prompt_mode": "reasoning" if self.reasoning else None,
             "tools": tools,
         }

+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature
model_library/providers/openai.py
CHANGED

@@ -1,8 +1,10 @@
 from __future__ import annotations

+import datetime
 import io
 import json
 import logging
+import time
 from typing import Any, Literal, Sequence, cast

 from openai import APIConnectionError, AsyncOpenAI
@@ -30,6 +32,7 @@ from model_library.base import (
     LLM,
     BatchResult,
     Citation,
+    DelegateConfig,
     FileBase,
     FileInput,
     FileWithBase64,
@@ -44,6 +47,7 @@ from model_library.base import (
     QueryResultCost,
     QueryResultExtras,
     QueryResultMetadata,
+    RateLimit,
     RawInput,
     RawResponse,
     TextInput,
@@ -60,6 +64,7 @@ from model_library.exceptions import (
 )
 from model_library.model_utils import get_reasoning_in_tag
 from model_library.register_models import register_provider
+from model_library.retriers.base import BaseRetrier
 from model_library.utils import create_openai_client_with_defaults


@@ -234,23 +239,31 @@ class OpenAIBatchMixin(LLMBatchMixin):

 class OpenAIConfig(ProviderConfig):
     deep_research: bool = False
+    verbosity: Literal["low", "medium", "high"] | None = None


 @register_provider("openai")
 class OpenAIModel(LLM):
     provider_config = OpenAIConfig()

-
+    @override
+    def _get_default_api_key(self) -> str:
+        if self.delegate_config:
+            return self.delegate_config.api_key.get_secret_value()
+        return model_library_settings.OPENAI_API_KEY

     @override
-    def get_client(self) -> AsyncOpenAI:
-        if self.
-
-
-
-
+    def get_client(self, api_key: str | None = None) -> AsyncOpenAI:
+        if not self.has_client():
+            assert api_key
+            client = create_openai_client_with_defaults(
+                base_url=self.delegate_config.base_url
+                if self.delegate_config
+                else None,
+                api_key=api_key,
             )
-
+            self.assign_client(client)
+        return super().get_client()

     def __init__(
         self,
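Google, Mistral and OpenAI now share the same lazy-client shape: get_client(api_key) builds the SDK client on first use, caches it with assign_client, and later calls return the cached instance through super().get_client(). A self-contained sketch of that caching shape — the has_client/assign_client internals are assumptions, since the package's base LLM class is not shown in this diff:

from typing import Any


class LazyClientBase:
    """Assumed minimal behavior behind has_client/assign_client/get_client."""

    def __init__(self) -> None:
        self._client: Any | None = None

    def has_client(self) -> bool:
        return self._client is not None

    def assign_client(self, client: Any) -> None:
        self._client = client

    def get_client(self, api_key: str | None = None) -> Any:
        assert self._client is not None, "client must be assigned before use"
        return self._client


class ExampleProvider(LazyClientBase):
    def get_client(self, api_key: str | None = None) -> Any:
        if not self.has_client():
            assert api_key
            # Stand-in for building a real SDK client (AsyncOpenAI, Mistral, genai Client, ...).
            self.assign_client({"api_key": api_key})
        return super().get_client()


provider = ExampleProvider()
print(provider.get_client("sk-example"))  # builds and caches the client
print(provider.get_client())              # subsequent calls reuse the cached client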
@@ -258,20 +271,21 @@ class OpenAIModel(LLM):
         provider: str = "openai",
         *,
         config: LLMConfig | None = None,
-        custom_client: AsyncOpenAI | None = None,
         use_completions: bool = False,
+        delegate_config: DelegateConfig | None = None,
     ):
-        super().__init__(model_name, provider, config=config)
         self.use_completions: bool = (
             use_completions  # TODO: do completions in a separate file
         )
-        self.
+        self.delegate_config = delegate_config

-
-
+        super().__init__(model_name, provider, config=config)
+
+        self.deep_research = self.provider_config.deep_research
+        self.verbosity = self.provider_config.verbosity

         # batch client
-        self.supports_batch: bool = self.supports_batch and not
+        self.supports_batch: bool = self.supports_batch and not self.delegate_config
         self.batch: LLMBatchMixin | None = (
             OpenAIBatchMixin(self) if self.supports_batch else None
         )
@@ -361,7 +375,6 @@ class OpenAIModel(LLM):
                 )
             case RawResponse():
                 if self.use_completions:
-                    pass
                     new_input.append(item.response)
                 else:
                     new_input.extend(item.response)
@@ -522,18 +535,20 @@ class OpenAIModel(LLM):

         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "messages": parsed_input,
             # enable usage data in streaming responses
             "stream_options": {"include_usage": True},
         }

+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_tools:
             parsed_tools = await self.parse_tools(tools)
             if parsed_tools:
                 body["tools"] = parsed_tools

-        if self.reasoning:
+        if self.reasoning and self.max_tokens:
             del body["max_tokens"]
             body["max_completion_tokens"] = self.max_tokens

@@ -687,7 +702,7 @@ class OpenAIModel(LLM):
         self, tools: Sequence[ToolDefinition], **kwargs: object
     ) -> None:
         min_tokens = 30_000
-        if self.max_tokens < min_tokens:
+        if not self.max_tokens or self.max_tokens < min_tokens:
             self.logger.warning(
                 f"Recommended to set max_tokens >= {min_tokens} for deep research models"
             )
@@ -745,10 +760,12 @@ class OpenAIModel(LLM):

         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_output_tokens": self.max_tokens,
             "input": parsed_input,
         }

+        if self.max_tokens:
+            body["max_output_tokens"] = self.max_tokens
+
         if parsed_tools:
             body["tools"] = parsed_tools
         else:
@@ -759,6 +776,9 @@ class OpenAIModel(LLM):
         if self.reasoning_effort is not None:
             body["reasoning"]["effort"] = self.reasoning_effort  # type: ignore[reportArgumentType]

+        if self.verbosity is not None:
+            body["text"] = {"format": {"type": "text"}, "verbosity": self.verbosity}
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature
@@ -883,6 +903,61 @@ class OpenAIModel(LLM):

         return result

+    @override
+    async def get_rate_limit(self) -> RateLimit | None:
+        headers = {}
+
+        try:
+            # NOTE: with_streaming_response doesn't seem to always work
+            if self.use_completions:
+                response = (
+                    await self.get_client().chat.completions.with_raw_response.create(
+                        max_completion_tokens=16,
+                        model=self.model_name,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": "Ping",
+                            }
+                        ],
+                        stream=True,
+                    )
+                )
+            else:
+                response = await self.get_client().responses.with_raw_response.create(
+                    max_output_tokens=16,
+                    input="Ping",
+                    model=self.model_name,
+                )
+            headers = response.headers
+
+            server_time_str = headers.get("date")
+            if server_time_str:
+                server_time = datetime.datetime.strptime(
+                    server_time_str, "%a, %d %b %Y %H:%M:%S GMT"
+                ).replace(tzinfo=datetime.timezone.utc)
+                timestamp = server_time.timestamp()
+            else:
+                timestamp = time.time()
+
+            # NOTE: for openai, max_tokens is used to reject requests if the amount of tokens left is less than the max_tokens
+
+            # we calculate estimated_tokens as (character_count / 4) + max_tokens. Note that OpenAI's rate limiter doesn't tokenize the request using the model's specific tokenizer but relies on a character count-based heuristic.
+
+            return RateLimit(
+                raw=headers,
+                unix_timestamp=timestamp,
+                request_limit=headers.get("x-ratelimit-limit-requests", None)
+                or headers.get("x-ratelimit-limit", None),
+                request_remaining=headers.get("x-ratelimit-remaining-requests", None)
+                or headers.get("x-ratelimit-remaining"),
+                token_limit=int(headers["x-ratelimit-limit-tokens"]),
+                token_remaining=int(headers["x-ratelimit-remaining-tokens"]),
+            )
+        except Exception as e:
+            self.logger.warning(f"Failed to get rate limit: {e}")
+            return None
+
     @override
     async def query_json(
         self,
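The second NOTE in get_rate_limit describes the estimate OpenAI's limiter applies to a request: roughly character_count / 4 plus the requested max_tokens, with no real tokenization. A worked example of that heuristic with made-up numbers (nothing below calls the API):

prompt = "Summarize the following report in three bullet points. " * 40
max_tokens = 1_000

# Character-count heuristic from the NOTE above: the limiter does not run the
# model's tokenizer, it just divides the character count by four.
estimated_tokens = len(prompt) / 4 + max_tokens

token_remaining = 1_200  # pretend value read from x-ratelimit-remaining-tokens
print(int(estimated_tokens))                # rough token cost charged against the limit
print(estimated_tokens <= token_remaining)  # False -> this request would be rejected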
@@ -906,7 +981,9 @@ class OpenAIModel(LLM):
         except APIConnectionError:
             raise ImmediateRetryException("Failed to connect to OpenAI")

-        response = await
+        response = await BaseRetrier.immediate_retry_wrapper(
+            func=_query, logger=self.logger
+        )

         parsed: PydanticT | None = response.output_parsed
         if parsed is None:
@@ -937,7 +1014,7 @@ class OpenAIModel(LLM):

             return response.data[0].embedding

-        return await
+        return await BaseRetrier.immediate_retry_wrapper(
             func=_get_embedding, logger=self.logger
         )

@@ -952,7 +1029,7 @@ class OpenAIModel(LLM):
             except Exception as e:
                 raise Exception("Failed to query OpenAI's Moderation endpoint") from e

-        return await
+        return await BaseRetrier.immediate_retry_wrapper(
             func=_moderate_content, logger=self.logger
         )

model_library/providers/openrouter.py
ADDED

@@ -0,0 +1,34 @@
+from typing import Literal
+
+from pydantic import SecretStr
+
+from model_library import model_library_settings
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    LLMConfig,
+)
+from model_library.register_models import register_provider
+
+
+@register_provider("openrouter")
+class OpenRouterModel(DelegateOnly):
+    def __init__(
+        self,
+        model_name: str,
+        provider: Literal["openrouter"] = "openrouter",
+        *,
+        config: LLMConfig | None = None,
+    ):
+        super().__init__(model_name, provider, config=config)
+
+        # https://openrouter.ai/docs/guides/community/openai-sdk
+        self.init_delegate(
+            config=config,
+            delegate_config=DelegateConfig(
+                base_url="https://openrouter.ai/api/v1",
+                api_key=SecretStr(model_library_settings.OPENROUTER_API_KEY),
+            ),
+            use_completions=True,
+            delegate_provider="openai",
+        )
model_library/providers/perplexity.py
CHANGED

@@ -1,13 +1,14 @@
 from typing import Literal

+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 @register_provider("perplexity")

@@ -22,13 +23,12 @@ class PerplexityModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)

         # https://docs.perplexity.ai/guides/chat-completions-guide
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.PERPLEXITY_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.perplexity.ai",
+                api_key=SecretStr(model_library_settings.PERPLEXITY_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
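OpenRouter, Perplexity, Inception, Kimi and MiniMax all converge on the same shape in this release: a DelegateOnly subclass that calls init_delegate with a DelegateConfig instead of wiring an OpenAI or Anthropic client by hand. A hedged sketch of what a hypothetical additional OpenAI-compatible provider would look like under that pattern — the provider name, endpoint and the reuse of OPENAI_API_KEY as a stand-in are invented; only the call shape comes from the diffs above:

from typing import Literal

from pydantic import SecretStr

from model_library import model_library_settings
from model_library.base import DelegateConfig, DelegateOnly, LLMConfig
from model_library.register_models import register_provider


@register_provider("exampleai")  # hypothetical provider name
class ExampleAIModel(DelegateOnly):
    def __init__(
        self,
        model_name: str,
        provider: Literal["exampleai"] = "exampleai",
        *,
        config: LLMConfig | None = None,
    ):
        super().__init__(model_name, provider, config=config)

        # Same delegation shape as the providers above: point the OpenAI-style
        # delegate at a compatible endpoint and hand it a SecretStr-wrapped key.
        self.init_delegate(
            config=config,
            delegate_config=DelegateConfig(
                base_url="https://api.example.invalid/v1/",  # invented endpoint
                api_key=SecretStr(model_library_settings.OPENAI_API_KEY),  # stand-in key
            ),
            use_completions=True,
            delegate_provider="openai",
        )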