model-library 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +139 -62
- model_library/base/delegate_only.py +77 -10
- model_library/base/output.py +43 -0
- model_library/base/utils.py +35 -0
- model_library/config/alibaba_models.yaml +44 -57
- model_library/config/all_models.json +253 -126
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/openai_models.yaml +15 -23
- model_library/config/zai_models.yaml +24 -3
- model_library/exceptions.py +3 -77
- model_library/providers/ai21labs.py +12 -8
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +49 -16
- model_library/providers/anthropic.py +93 -40
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +14 -10
- model_library/providers/google/google.py +48 -29
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +15 -17
- model_library/providers/mistral.py +20 -8
- model_library/providers/openai.py +99 -22
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +12 -6
- model_library/providers/xai.py +47 -42
- model_library/providers/zai.py +38 -8
- model_library/registry_utils.py +39 -15
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +13 -0
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/METADATA +2 -1
- model_library-0.1.8.dist-info/RECORD +70 -0
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
- model_library-0.1.7.dist-info/RECORD +0 -64
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0

model_library/providers/openai.py
CHANGED

@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+import datetime
 import io
 import json
 import logging
+import time
 from typing import Any, Literal, Sequence, cast
 
 from openai import APIConnectionError, AsyncOpenAI

@@ -30,6 +32,7 @@ from model_library.base import (
     LLM,
     BatchResult,
     Citation,
+    DelegateConfig,
     FileBase,
     FileInput,
     FileWithBase64,

@@ -44,6 +47,7 @@ from model_library.base import (
     QueryResultCost,
     QueryResultExtras,
     QueryResultMetadata,
+    RateLimit,
     RawInput,
     RawResponse,
     TextInput,

@@ -60,6 +64,7 @@ from model_library.exceptions import (
 )
 from model_library.model_utils import get_reasoning_in_tag
 from model_library.register_models import register_provider
+from model_library.retriers.base import BaseRetrier
 from model_library.utils import create_openai_client_with_defaults
 
 
@@ -234,23 +239,31 @@ class OpenAIBatchMixin(LLMBatchMixin):
 
 class OpenAIConfig(ProviderConfig):
     deep_research: bool = False
+    verbosity: Literal["low", "medium", "high"] | None = None
 
 
 @register_provider("openai")
 class OpenAIModel(LLM):
     provider_config = OpenAIConfig()
 
-
+    @override
+    def _get_default_api_key(self) -> str:
+        if self.delegate_config:
+            return self.delegate_config.api_key.get_secret_value()
+        return model_library_settings.OPENAI_API_KEY
 
     @override
-    def get_client(self) -> AsyncOpenAI:
-        if self.
-
-
-
-
+    def get_client(self, api_key: str | None = None) -> AsyncOpenAI:
+        if not self.has_client():
+            assert api_key
+            client = create_openai_client_with_defaults(
+                base_url=self.delegate_config.base_url
+                if self.delegate_config
+                else None,
+                api_key=api_key,
             )
-
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
@@ -258,20 +271,21 @@ class OpenAIModel(LLM):
         provider: str = "openai",
         *,
         config: LLMConfig | None = None,
-        custom_client: AsyncOpenAI | None = None,
         use_completions: bool = False,
+        delegate_config: DelegateConfig | None = None,
     ):
-        super().__init__(model_name, provider, config=config)
         self.use_completions: bool = (
             use_completions  # TODO: do completions in a separate file
         )
-        self.
+        self.delegate_config = delegate_config
 
-
-
+        super().__init__(model_name, provider, config=config)
+
+        self.deep_research = self.provider_config.deep_research
+        self.verbosity = self.provider_config.verbosity
 
         # batch client
-        self.supports_batch: bool = self.supports_batch and not
+        self.supports_batch: bool = self.supports_batch and not self.delegate_config
         self.batch: LLMBatchMixin | None = (
            OpenAIBatchMixin(self) if self.supports_batch else None
        )
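
The constructor now takes a delegate_config in place of a custom_client, and get_client builds the AsyncOpenAI client lazily from it. A rough usage sketch of the new signature; the endpoint URL, key, and model name below are placeholders, and constructing OpenAIModel directly like this is only an illustration, not a documented entry point:

    from pydantic import SecretStr

    from model_library.base import DelegateConfig
    from model_library.providers.openai import OpenAIModel

    # Illustrative only: reuse the OpenAI code path against a compatible endpoint.
    model = OpenAIModel(
        "example-model",  # placeholder model name
        use_completions=True,
        delegate_config=DelegateConfig(
            base_url="https://example.com/v1",  # placeholder OpenAI-compatible endpoint
            api_key=SecretStr("sk-..."),        # keys are wrapped in SecretStr in 0.1.8
        ),
    )
    # Batch support is switched off whenever a delegate_config is present
    # (supports_batch and not self.delegate_config).
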
@@ -361,7 +375,6 @@ class OpenAIModel(LLM):
                 )
             case RawResponse():
                 if self.use_completions:
-                    pass
                     new_input.append(item.response)
                 else:
                     new_input.extend(item.response)

@@ -522,18 +535,20 @@ class OpenAIModel(LLM):
 
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "messages": parsed_input,
             # enable usage data in streaming responses
             "stream_options": {"include_usage": True},
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_tools:
             parsed_tools = await self.parse_tools(tools)
             if parsed_tools:
                 body["tools"] = parsed_tools
 
-        if self.reasoning:
+        if self.reasoning and self.max_tokens:
             del body["max_tokens"]
             body["max_completion_tokens"] = self.max_tokens
 

@@ -687,7 +702,7 @@ class OpenAIModel(LLM):
         self, tools: Sequence[ToolDefinition], **kwargs: object
     ) -> None:
         min_tokens = 30_000
-        if self.max_tokens < min_tokens:
+        if not self.max_tokens or self.max_tokens < min_tokens:
             self.logger.warning(
                 f"Recommended to set max_tokens >= {min_tokens} for deep research models"
             )

@@ -745,10 +760,12 @@ class OpenAIModel(LLM):
 
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_output_tokens": self.max_tokens,
             "input": parsed_input,
         }
 
+        if self.max_tokens:
+            body["max_output_tokens"] = self.max_tokens
+
         if parsed_tools:
             body["tools"] = parsed_tools
         else:

@@ -759,6 +776,9 @@ class OpenAIModel(LLM):
         if self.reasoning_effort is not None:
             body["reasoning"]["effort"] = self.reasoning_effort  # type: ignore[reportArgumentType]
 
+        if self.verbosity is not None:
+            body["text"] = {"format": {"type": "text"}, "verbosity": self.verbosity}
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature
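
For reference, with the new verbosity option set on OpenAIConfig, a Responses API body assembled by the code above would carry a fragment along these lines (all values are made up for illustration):

    body = {
        "model": "<model name>",
        "input": "Ping",
        "max_output_tokens": 32_000,                               # only when max_tokens is set
        "reasoning": {"effort": "high"},                           # only when reasoning_effort is set
        "text": {"format": {"type": "text"}, "verbosity": "low"},  # new in 0.1.8
    }
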
@@ -883,6 +903,61 @@ class OpenAIModel(LLM):
 
         return result
 
+    @override
+    async def get_rate_limit(self) -> RateLimit | None:
+        headers = {}
+
+        try:
+            # NOTE: with_streaming_response doesn't seem to always work
+            if self.use_completions:
+                response = (
+                    await self.get_client().chat.completions.with_raw_response.create(
+                        max_completion_tokens=16,
+                        model=self.model_name,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": "Ping",
+                            }
+                        ],
+                        stream=True,
+                    )
+                )
+            else:
+                response = await self.get_client().responses.with_raw_response.create(
+                    max_output_tokens=16,
+                    input="Ping",
+                    model=self.model_name,
+                )
+            headers = response.headers
+
+            server_time_str = headers.get("date")
+            if server_time_str:
+                server_time = datetime.datetime.strptime(
+                    server_time_str, "%a, %d %b %Y %H:%M:%S GMT"
+                ).replace(tzinfo=datetime.timezone.utc)
+                timestamp = server_time.timestamp()
+            else:
+                timestamp = time.time()
+
+            # NOTE: for openai, max_tokens is used to reject requests if the amount of tokens left is less than the max_tokens
+
+            # we calculate estimated_tokens as (character_count / 4) + max_tokens. Note that OpenAI's rate limiter doesn't tokenize the request using the model's specific tokenizer but relies on a character count-based heuristic.
+
+            return RateLimit(
+                raw=headers,
+                unix_timestamp=timestamp,
+                request_limit=headers.get("x-ratelimit-limit-requests", None)
+                or headers.get("x-ratelimit-limit", None),
+                request_remaining=headers.get("x-ratelimit-remaining-requests", None)
+                or headers.get("x-ratelimit-remaining"),
+                token_limit=int(headers["x-ratelimit-limit-tokens"]),
+                token_remaining=int(headers["x-ratelimit-remaining-tokens"]),
+            )
+        except Exception as e:
+            self.logger.warning(f"Failed to get rate limit: {e}")
+            return None
+
     @override
     async def query_json(
         self,
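
The new get_rate_limit sends a minimal "Ping" request purely to read OpenAI's rate-limit headers from the raw response. Isolated from the client plumbing, the header handling above boils down to roughly the following standalone sketch (RateLimit is replaced by a plain dict here):

    import datetime
    import time
    from typing import Any, Mapping


    def parse_openai_rate_limit_headers(headers: Mapping[str, str]) -> dict[str, Any]:
        """Standalone restatement of the header parsing in get_rate_limit above."""
        # Prefer the server's own Date header so the timestamp matches its clock.
        server_time_str = headers.get("date")
        if server_time_str:
            timestamp = (
                datetime.datetime.strptime(server_time_str, "%a, %d %b %Y %H:%M:%S GMT")
                .replace(tzinfo=datetime.timezone.utc)
                .timestamp()
            )
        else:
            timestamp = time.time()

        return {
            "unix_timestamp": timestamp,
            "request_limit": headers.get("x-ratelimit-limit-requests")
            or headers.get("x-ratelimit-limit"),
            "request_remaining": headers.get("x-ratelimit-remaining-requests")
            or headers.get("x-ratelimit-remaining"),
            # Token headers are required here; a missing value falls through to
            # the except-branch in get_rate_limit, which returns None.
            "token_limit": int(headers["x-ratelimit-limit-tokens"]),
            "token_remaining": int(headers["x-ratelimit-remaining-tokens"]),
        }
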
@@ -906,7 +981,9 @@ class OpenAIModel(LLM):
             except APIConnectionError:
                 raise ImmediateRetryException("Failed to connect to OpenAI")
 
-        response = await
+        response = await BaseRetrier.immediate_retry_wrapper(
+            func=_query, logger=self.logger
+        )
 
         parsed: PydanticT | None = response.output_parsed
         if parsed is None:

@@ -937,7 +1014,7 @@ class OpenAIModel(LLM):
 
             return response.data[0].embedding
 
-        return await
+        return await BaseRetrier.immediate_retry_wrapper(
             func=_get_embedding, logger=self.logger
         )
 

@@ -952,7 +1029,7 @@ class OpenAIModel(LLM):
         except Exception as e:
             raise Exception("Failed to query OpenAI's Moderation endpoint") from e
 
-        return await
+        return await BaseRetrier.immediate_retry_wrapper(
             func=_moderate_content, logger=self.logger
         )
 
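
query_json, embeddings, and moderation now all go through BaseRetrier.immediate_retry_wrapper from the new model_library/retriers package. The retrier itself is not shown in this diff; as a rough mental model only (the names, retry count, and behaviour below are assumptions, not the library's implementation), such a wrapper re-invokes the coroutine factory whenever it raises ImmediateRetryException:

    # Assumed sketch -- the real implementation lives in model_library/retriers/base.py
    # and is not part of this excerpt.
    import logging
    from typing import Awaitable, Callable, TypeVar

    from model_library.exceptions import ImmediateRetryException  # assumed import path

    T = TypeVar("T")


    async def immediate_retry_wrapper(
        func: Callable[[], Awaitable[T]],
        logger: logging.Logger,
        max_attempts: int = 3,  # assumed limit
    ) -> T:
        for attempt in range(1, max_attempts):
            try:
                return await func()
            except ImmediateRetryException as e:
                logger.warning(f"Immediate retry {attempt}/{max_attempts}: {e}")
        # Final attempt: let any exception propagate to the caller.
        return await func()
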
model_library/providers/openrouter.py
ADDED

@@ -0,0 +1,34 @@
+from typing import Literal
+
+from pydantic import SecretStr
+
+from model_library import model_library_settings
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    LLMConfig,
+)
+from model_library.register_models import register_provider
+
+
+@register_provider("openrouter")
+class OpenRouterModel(DelegateOnly):
+    def __init__(
+        self,
+        model_name: str,
+        provider: Literal["openrouter"] = "openrouter",
+        *,
+        config: LLMConfig | None = None,
+    ):
+        super().__init__(model_name, provider, config=config)
+
+        # https://openrouter.ai/docs/guides/community/openai-sdk
+        self.init_delegate(
+            config=config,
+            delegate_config=DelegateConfig(
+                base_url="https://openrouter.ai/api/v1",
+                api_key=SecretStr(model_library_settings.OPENROUTER_API_KEY),
+            ),
+            use_completions=True,
+            delegate_provider="openai",
+        )
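
Usage then mirrors the other delegate-only providers; a brief sketch (the model slug is illustrative, and OPENROUTER_API_KEY must be available via model_library_settings):

    from model_library.providers.openrouter import OpenRouterModel

    # Illustrative OpenRouter model slug.
    model = OpenRouterModel("anthropic/claude-3.5-sonnet")
    # Requests are delegated to the OpenAI chat-completions client pointed at
    # https://openrouter.ai/api/v1 and authenticated with OPENROUTER_API_KEY.
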
model_library/providers/perplexity.py
CHANGED

@@ -1,13 +1,14 @@
 from typing import Literal
 
+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("perplexity")

@@ -22,13 +23,12 @@ class PerplexityModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://docs.perplexity.ai/guides/chat-completions-guide
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.PERPLEXITY_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.perplexity.ai",
+                api_key=SecretStr(model_library_settings.PERPLEXITY_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
model_library/providers/together.py
CHANGED

@@ -1,18 +1,18 @@
 from typing import Literal
 
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
     ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 class TogetherConfig(ProviderConfig):

@@ -32,15 +32,14 @@ class TogetherModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)
         # https://docs.together.ai/docs/openai-api-compatibility
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-
-
+            delegate_config=DelegateConfig(
+                base_url="https://api.together.xyz/v1/",
+                api_key=SecretStr(model_library_settings.TOGETHER_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
 
     @override

model_library/providers/vals.py
CHANGED

@@ -151,13 +151,17 @@ class DummyAIBatchMixin(LLMBatchMixin):
 class DummyAIModel(LLM):
     _client: Redis | None = None
 
-
-
-
-
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.REDIS_URL
+
+    def get_client(self, api_key: str | None = None) -> Redis:
+        if not self.has_client():
+            assert api_key
+            client = redis.from_url(  # pyright: ignore[reportUnknownMemberType]
                 model_library_settings.REDIS_URL, decode_responses=True
             )
-
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,

@@ -238,12 +242,14 @@ class DummyAIModel(LLM):
         messages = await self.parse_input(input)
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "seed": 0,
             "messages": messages,
             "tools": await self.parse_tools(tools),
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature
model_library/providers/xai.py
CHANGED

@@ -2,7 +2,7 @@ import io
 import logging
 from typing import Any, Literal, Sequence
 
-import
+from pydantic import SecretStr
 from typing_extensions import override
 from xai_sdk import AsyncClient
 from xai_sdk.aio.chat import Chat

@@ -14,6 +14,7 @@ from xai_sdk.proto.v6.chat_pb2 import Message, Tool
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    DelegateConfig,
     FileBase,
     FileInput,
     FileWithBase64,

@@ -36,24 +37,26 @@ from model_library.exceptions import (
     MaxOutputTokensExceededError,
     ModelNoOutputError,
     NoMatchingToolCallError,
-    RateLimitException,
 )
 from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("grok")
 class XAIModel(LLM):
-
+    @override
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.XAI_API_KEY
 
     @override
-    def get_client(self) -> AsyncClient:
-        if not
-
-
+    def get_client(self, api_key: str | None = None) -> AsyncClient:
+        if not self.has_client():
+            assert api_key
+            client = AsyncClient(
+                api_key=api_key,
             )
-
+            self.assign_client(client)
+        return super().get_client()
 
     @override
     def __init__(

@@ -73,13 +76,13 @@ class XAIModel(LLM):
             model_name=self.model_name,
             provider=provider,
             config=config,
-
-            api_key=model_library_settings.XAI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url=(
                     "https://us-west-1.api.x.ai/v1"
                     if "grok-3-mini-reasoning" in self.model_name
                     else "https://api.x.ai/v1"
                 ),
+                api_key=SecretStr(model_library_settings.XAI_API_KEY),
             ),
             use_completions=True,
         )

@@ -210,12 +213,14 @@ class XAIModel(LLM):
             messages.append(system(str(kwargs.pop("system_prompt"))))
 
         body: dict[str, Any] = {
-            "max_tokens": self.max_tokens,
             "model": self.model_name,
             "tools": await self.parse_tools(tools),
             "messages": messages,
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature

@@ -253,38 +258,35 @@ class XAIModel(LLM):
 
         body = await self.build_body(input, tools=tools, **kwargs)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                        args=tool_call.function.arguments,
-                    )
+        chat: Chat = self.get_client().chat.create(**body)
+
+        latest_response: Response | None = None
+        async for response, _ in chat.stream():
+            latest_response = response
+
+        if not latest_response:
+            raise ModelNoOutputError("Model failed to produce a response")
+
+        tool_calls: list[ToolCall] = []
+        if (
+            latest_response.finish_reason == "REASON_TOOL_CALLS"
+            and latest_response.tool_calls
+        ):
+            for tool_call in latest_response.tool_calls:
+                tool_calls.append(
+                    ToolCall(
+                        id=tool_call.id,
+                        name=tool_call.function.name,
+                        args=tool_call.function.arguments,
                     )
+                )
 
-
-
-
-
-
-
-        except grpc.RpcError as e:
-            raise RateLimitException(e.details())
+        if (
+            latest_response.finish_reason == "REASON_MAX_LEN"
+            and not latest_response.content
+            and not latest_response.reasoning_content
+        ):
+            raise MaxOutputTokensExceededError()
 
         return QueryResult(
             output_text=latest_response.content,

@@ -310,6 +312,9 @@ class XAIModel(LLM):
         tools: list[ToolDefinition] = [],
         **kwargs: object,
     ) -> int:
+        if not input and not history:
+            return 0
+
         string_input = await self.stringify_input(input, history=history, tools=tools)
         self.logger.debug(string_input)
 
model_library/providers/zai.py
CHANGED

@@ -1,17 +1,36 @@
-from typing import Literal
+from typing import Any, Literal
+
+from pydantic import SecretStr
+from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
+    ProviderConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-
+
+
+class ZAIConfig(ProviderConfig):
+    """Configuration for ZAI (GLM) models.
+
+    Attributes:
+        clear_thinking: When disabled, reasoning content from previous turns is
+            preserved in context. This is useful for multi-turn conversations where
+            you want the model to maintain coherent reasoning across turns.
+            Enabled by default on the standard API endpoint.
+        See: https://docs.z.ai/guides/capabilities/thinking-mode
+    """
+
+    clear_thinking: bool = True
 
 
 @register_provider("zai")
 class ZAIModel(DelegateOnly):
+    provider_config = ZAIConfig()
+
     def __init__(
         self,
         model_name: str,
@@ -21,14 +40,25 @@ class ZAIModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)
 
+        self.clear_thinking = self.provider_config.clear_thinking
+
         # https://docs.z.ai/guides/develop/openai/python
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.ZAI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://open.bigmodel.cn/api/paas/v4/",
+                api_key=SecretStr(model_library_settings.ZAI_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
+
+    @override
+    def _get_extra_body(self) -> dict[str, Any]:
+        """Build extra body parameters for GLM-specific features."""
+        return {
+            "thinking": {
+                "type": "enabled" if self.reasoning else "disabled",
+                "clear_thinking": self.clear_thinking,
+            }
+        }