model_library-0.1.7-py3-none-any.whl → model_library-0.1.8-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (45)
  1. model_library/base/base.py +139 -62
  2. model_library/base/delegate_only.py +77 -10
  3. model_library/base/output.py +43 -0
  4. model_library/base/utils.py +35 -0
  5. model_library/config/alibaba_models.yaml +44 -57
  6. model_library/config/all_models.json +253 -126
  7. model_library/config/kimi_models.yaml +30 -3
  8. model_library/config/openai_models.yaml +15 -23
  9. model_library/config/zai_models.yaml +24 -3
  10. model_library/exceptions.py +3 -77
  11. model_library/providers/ai21labs.py +12 -8
  12. model_library/providers/alibaba.py +17 -8
  13. model_library/providers/amazon.py +49 -16
  14. model_library/providers/anthropic.py +93 -40
  15. model_library/providers/azure.py +22 -10
  16. model_library/providers/cohere.py +7 -7
  17. model_library/providers/deepseek.py +8 -8
  18. model_library/providers/fireworks.py +7 -8
  19. model_library/providers/google/batch.py +14 -10
  20. model_library/providers/google/google.py +48 -29
  21. model_library/providers/inception.py +7 -7
  22. model_library/providers/kimi.py +18 -8
  23. model_library/providers/minimax.py +15 -17
  24. model_library/providers/mistral.py +20 -8
  25. model_library/providers/openai.py +99 -22
  26. model_library/providers/openrouter.py +34 -0
  27. model_library/providers/perplexity.py +7 -7
  28. model_library/providers/together.py +7 -8
  29. model_library/providers/vals.py +12 -6
  30. model_library/providers/xai.py +47 -42
  31. model_library/providers/zai.py +38 -8
  32. model_library/registry_utils.py +39 -15
  33. model_library/retriers/__init__.py +0 -0
  34. model_library/retriers/backoff.py +73 -0
  35. model_library/retriers/base.py +225 -0
  36. model_library/retriers/token.py +427 -0
  37. model_library/retriers/utils.py +11 -0
  38. model_library/settings.py +1 -1
  39. model_library/utils.py +13 -0
  40. {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/METADATA +2 -1
  41. model_library-0.1.8.dist-info/RECORD +70 -0
  42. {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
  43. model_library-0.1.7.dist-info/RECORD +0 -64
  44. {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
  45. {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+import datetime
 import io
 import json
 import logging
+import time
 from typing import Any, Literal, Sequence, cast
 
 from openai import APIConnectionError, AsyncOpenAI
@@ -30,6 +32,7 @@ from model_library.base import (
     LLM,
     BatchResult,
     Citation,
+    DelegateConfig,
     FileBase,
     FileInput,
     FileWithBase64,
@@ -44,6 +47,7 @@ from model_library.base import (
     QueryResultCost,
     QueryResultExtras,
     QueryResultMetadata,
+    RateLimit,
     RawInput,
     RawResponse,
     TextInput,
@@ -60,6 +64,7 @@ from model_library.exceptions import (
 )
 from model_library.model_utils import get_reasoning_in_tag
 from model_library.register_models import register_provider
+from model_library.retriers.base import BaseRetrier
 from model_library.utils import create_openai_client_with_defaults
 
 
@@ -234,23 +239,31 @@ class OpenAIBatchMixin(LLMBatchMixin):
 
 class OpenAIConfig(ProviderConfig):
     deep_research: bool = False
+    verbosity: Literal["low", "medium", "high"] | None = None
 
 
 @register_provider("openai")
 class OpenAIModel(LLM):
     provider_config = OpenAIConfig()
 
-    _client: AsyncOpenAI | None = None
+    @override
+    def _get_default_api_key(self) -> str:
+        if self.delegate_config:
+            return self.delegate_config.api_key.get_secret_value()
+        return model_library_settings.OPENAI_API_KEY
 
     @override
-    def get_client(self) -> AsyncOpenAI:
-        if self._delegate_client:
-            return self._delegate_client
-        if not OpenAIModel._client:
-            OpenAIModel._client = create_openai_client_with_defaults(
-                api_key=model_library_settings.OPENAI_API_KEY
+    def get_client(self, api_key: str | None = None) -> AsyncOpenAI:
+        if not self.has_client():
+            assert api_key
+            client = create_openai_client_with_defaults(
+                base_url=self.delegate_config.base_url
+                if self.delegate_config
+                else None,
+                api_key=api_key,
             )
-        return OpenAIModel._client
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
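Not part of the diff: a minimal sketch of the client-caching contract the reworked `get_client` appears to rely on. The `has_client`/`assign_client` helpers and the base-class `get_client` fallback are assumptions about `model_library.base.LLM`; only their call sites are visible in this hunk.

```python
# Illustration only -- a simplified stand-in for the assumed base-class helpers.
from typing import Any


class ClientCachingBase:
    """Per-instance client cache; subclasses build the client on first use."""

    def __init__(self) -> None:
        self._client: Any | None = None

    def has_client(self) -> bool:
        return self._client is not None

    def assign_client(self, client: Any) -> None:
        self._client = client

    def get_client(self, api_key: str | None = None) -> Any:
        # Subclasses construct and assign a client before deferring here.
        assert self._client is not None, "assign_client() must be called first"
        return self._client
```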
@@ -258,20 +271,21 @@ class OpenAIModel(LLM):
         provider: str = "openai",
         *,
         config: LLMConfig | None = None,
-        custom_client: AsyncOpenAI | None = None,
         use_completions: bool = False,
+        delegate_config: DelegateConfig | None = None,
     ):
-        super().__init__(model_name, provider, config=config)
         self.use_completions: bool = (
             use_completions  # TODO: do completions in a separate file
         )
-        self.deep_research = self.provider_config.deep_research
+        self.delegate_config = delegate_config
 
-        # allow custom client to act as delegate (native)
-        self._delegate_client: AsyncOpenAI | None = custom_client
+        super().__init__(model_name, provider, config=config)
+
+        self.deep_research = self.provider_config.deep_research
+        self.verbosity = self.provider_config.verbosity
 
         # batch client
-        self.supports_batch: bool = self.supports_batch and not custom_client
+        self.supports_batch: bool = self.supports_batch and not self.delegate_config
         self.batch: LLMBatchMixin | None = (
             OpenAIBatchMixin(self) if self.supports_batch else None
         )
@@ -361,7 +375,6 @@ class OpenAIModel(LLM):
                     )
                 case RawResponse():
                     if self.use_completions:
-                        pass
                         new_input.append(item.response)
                     else:
                         new_input.extend(item.response)
@@ -522,18 +535,20 @@ class OpenAIModel(LLM):
 
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "messages": parsed_input,
             # enable usage data in streaming responses
             "stream_options": {"include_usage": True},
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_tools:
             parsed_tools = await self.parse_tools(tools)
             if parsed_tools:
                 body["tools"] = parsed_tools
 
-        if self.reasoning:
+        if self.reasoning and self.max_tokens:
             del body["max_tokens"]
             body["max_completion_tokens"] = self.max_tokens
 
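For clarity (not from the package): the token-limit fields the new logic produces. `max_tokens` is only sent when configured, and for reasoning models it is re-keyed to `max_completion_tokens`.

```python
# Hedged sketch of the body fields implied by the hunk above.
def token_fields(max_tokens: int | None, reasoning: bool) -> dict[str, int]:
    body: dict[str, int] = {}
    if max_tokens:
        body["max_tokens"] = max_tokens
    if reasoning and max_tokens:
        # the Chat Completions API expects max_completion_tokens for reasoning models
        body["max_completion_tokens"] = body.pop("max_tokens")
    return body


assert token_fields(None, reasoning=True) == {}
assert token_fields(1024, reasoning=False) == {"max_tokens": 1024}
assert token_fields(1024, reasoning=True) == {"max_completion_tokens": 1024}
```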
@@ -687,7 +702,7 @@ class OpenAIModel(LLM):
         self, tools: Sequence[ToolDefinition], **kwargs: object
     ) -> None:
         min_tokens = 30_000
-        if self.max_tokens < min_tokens:
+        if not self.max_tokens or self.max_tokens < min_tokens:
             self.logger.warning(
                 f"Recommended to set max_tokens >= {min_tokens} for deep research models"
             )
@@ -745,10 +760,12 @@ class OpenAIModel(LLM):
 
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_output_tokens": self.max_tokens,
             "input": parsed_input,
         }
 
+        if self.max_tokens:
+            body["max_output_tokens"] = self.max_tokens
+
         if parsed_tools:
             body["tools"] = parsed_tools
         else:
@@ -759,6 +776,9 @@ class OpenAIModel(LLM):
         if self.reasoning_effort is not None:
             body["reasoning"]["effort"] = self.reasoning_effort  # type: ignore[reportArgumentType]
 
+        if self.verbosity is not None:
+            body["text"] = {"format": {"type": "text"}, "verbosity": self.verbosity}
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature
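Roughly, a Responses API body built with both settings enabled would come out like this (illustrative values only; the model name and token budget are placeholders, not from the diff):

```python
# Illustrative request body assembled by the code above.
body = {
    "model": "<model_name>",
    "input": "<parsed input>",
    "max_output_tokens": 4096,
    "reasoning": {"effort": "medium"},
    "text": {"format": {"type": "text"}, "verbosity": "low"},
}
```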
@@ -883,6 +903,61 @@ class OpenAIModel(LLM):
 
         return result
 
+    @override
+    async def get_rate_limit(self) -> RateLimit | None:
+        headers = {}
+
+        try:
+            # NOTE: with_streaming_response doesn't seem to always work
+            if self.use_completions:
+                response = (
+                    await self.get_client().chat.completions.with_raw_response.create(
+                        max_completion_tokens=16,
+                        model=self.model_name,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": "Ping",
+                            }
+                        ],
+                        stream=True,
+                    )
+                )
+            else:
+                response = await self.get_client().responses.with_raw_response.create(
+                    max_output_tokens=16,
+                    input="Ping",
+                    model=self.model_name,
+                )
+            headers = response.headers
+
+            server_time_str = headers.get("date")
+            if server_time_str:
+                server_time = datetime.datetime.strptime(
+                    server_time_str, "%a, %d %b %Y %H:%M:%S GMT"
+                ).replace(tzinfo=datetime.timezone.utc)
+                timestamp = server_time.timestamp()
+            else:
+                timestamp = time.time()
+
+            # NOTE: for openai, max_tokens is used to reject requests if the amount of tokens left is less than the max_tokens
+
+            # we calculate estimated_tokens as (character_count / 4) + max_tokens. Note that OpenAI's rate limiter doesn't tokenize the request using the model's specific tokenizer but relies on a character count-based heuristic.
+
+            return RateLimit(
+                raw=headers,
+                unix_timestamp=timestamp,
+                request_limit=headers.get("x-ratelimit-limit-requests", None)
+                or headers.get("x-ratelimit-limit", None),
+                request_remaining=headers.get("x-ratelimit-remaining-requests", None)
+                or headers.get("x-ratelimit-remaining"),
+                token_limit=int(headers["x-ratelimit-limit-tokens"]),
+                token_remaining=int(headers["x-ratelimit-remaining-tokens"]),
+            )
+        except Exception as e:
+            self.logger.warning(f"Failed to get rate limit: {e}")
+            return None
+
     @override
     async def query_json(
         self,
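A hedged usage sketch for the new probe (field names are taken from the RateLimit(...) call above; exposing them as attributes is an assumption):

```python
# Illustration only: issue the cheap "Ping" probe and report remaining quota.
async def report_quota(model) -> None:
    rate_limit = await model.get_rate_limit()
    if rate_limit is None:
        return  # probe failed; a warning was already logged by the model
    print(f"{rate_limit.token_remaining}/{rate_limit.token_limit} tokens remaining")
    print(f"{rate_limit.request_remaining}/{rate_limit.request_limit} requests remaining")
```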
@@ -906,7 +981,9 @@ class OpenAIModel(LLM):
             except APIConnectionError:
                 raise ImmediateRetryException("Failed to connect to OpenAI")
 
-        response = await LLM.immediate_retry_wrapper(func=_query, logger=self.logger)
+        response = await BaseRetrier.immediate_retry_wrapper(
+            func=_query, logger=self.logger
+        )
 
         parsed: PydanticT | None = response.output_parsed
         if parsed is None:
@@ -937,7 +1014,7 @@
 
             return response.data[0].embedding
 
-        return await LLM.immediate_retry_wrapper(
+        return await BaseRetrier.immediate_retry_wrapper(
             func=_get_embedding, logger=self.logger
         )
 
@@ -952,7 +1029,7 @@
             except Exception as e:
                 raise Exception("Failed to query OpenAI's Moderation endpoint") from e
 
-        return await LLM.immediate_retry_wrapper(
+        return await BaseRetrier.immediate_retry_wrapper(
            func=_moderate_content, logger=self.logger
         )
 
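All three call sites above switch from `LLM.immediate_retry_wrapper` to the new retriers package; the call shape is unchanged. A minimal sketch, assuming the wrapper simply awaits `func` and re-invokes it when an `ImmediateRetryException` is raised:

```python
# Hedged sketch of a call site; BaseRetrier.immediate_retry_wrapper(func=..., logger=...)
# is the signature visible in the diff, the retry semantics are assumed.
import logging

from model_library.retriers.base import BaseRetrier

logger = logging.getLogger(__name__)


async def fetch_with_retry(make_request) -> object:
    async def _call() -> object:
        return await make_request()

    return await BaseRetrier.immediate_retry_wrapper(func=_call, logger=logger)
```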
@@ -0,0 +1,34 @@
+from typing import Literal
+
+from pydantic import SecretStr
+
+from model_library import model_library_settings
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    LLMConfig,
+)
+from model_library.register_models import register_provider
+
+
+@register_provider("openrouter")
+class OpenRouterModel(DelegateOnly):
+    def __init__(
+        self,
+        model_name: str,
+        provider: Literal["openrouter"] = "openrouter",
+        *,
+        config: LLMConfig | None = None,
+    ):
+        super().__init__(model_name, provider, config=config)
+
+        # https://openrouter.ai/docs/guides/community/openai-sdk
+        self.init_delegate(
+            config=config,
+            delegate_config=DelegateConfig(
+                base_url="https://openrouter.ai/api/v1",
+                api_key=SecretStr(model_library_settings.OPENROUTER_API_KEY),
+            ),
+            use_completions=True,
+            delegate_provider="openai",
+        )
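Illustrative usage of the new provider (not in the diff): OpenRouter is a DelegateOnly provider, so it proxies every call through an OpenAI-compatible completions client pointed at openrouter.ai. The model id and query call below are placeholders assuming the library's usual async interface.

```python
# Hypothetical usage sketch; requires OPENROUTER_API_KEY in model_library_settings.
from model_library.providers.openrouter import OpenRouterModel

model = OpenRouterModel("anthropic/claude-sonnet-4")  # any OpenRouter model id
# result = await model.query("Hello")  # exact query signature not shown in this diff
```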
@@ -1,13 +1,14 @@
 from typing import Literal
 
+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("perplexity")
@@ -22,13 +23,12 @@ class PerplexityModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://docs.perplexity.ai/guides/chat-completions-guide
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.PERPLEXITY_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.perplexity.ai",
+                api_key=SecretStr(model_library_settings.PERPLEXITY_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
@@ -1,18 +1,18 @@
 from typing import Literal
 
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
     ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 class TogetherConfig(ProviderConfig):
@@ -32,15 +32,14 @@ class TogetherModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)
         # https://docs.together.ai/docs/openai-api-compatibility
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.TOGETHER_API_KEY,
-                base_url="https://api.together.xyz/v1",
+            delegate_config=DelegateConfig(
+                base_url="https://api.together.xyz/v1/",
+                api_key=SecretStr(model_library_settings.TOGETHER_API_KEY),
            ),
            use_completions=True,
+            delegate_provider="openai",
        )
 
     @override
@@ -151,13 +151,17 @@ class DummyAIBatchMixin(LLMBatchMixin):
 class DummyAIModel(LLM):
     _client: Redis | None = None
 
-    @override
-    def get_client(self) -> Redis:
-        if not DummyAIModel._client:
-            DummyAIModel._client = redis.from_url(  # pyright: ignore[reportUnknownMemberType]
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.REDIS_URL
+
+    def get_client(self, api_key: str | None = None) -> Redis:
+        if not self.has_client():
+            assert api_key
+            client = redis.from_url(  # pyright: ignore[reportUnknownMemberType]
                 model_library_settings.REDIS_URL, decode_responses=True
             )
-        return DummyAIModel._client
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
@@ -238,12 +242,14 @@ class DummyAIModel(LLM):
         messages = await self.parse_input(input)
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "seed": 0,
             "messages": messages,
             "tools": await self.parse_tools(tools),
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature
@@ -2,7 +2,7 @@ import io
 import logging
 from typing import Any, Literal, Sequence
 
-import grpc
+from pydantic import SecretStr
 from typing_extensions import override
 from xai_sdk import AsyncClient
 from xai_sdk.aio.chat import Chat
@@ -14,6 +14,7 @@ from xai_sdk.proto.v6.chat_pb2 import Message, Tool
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    DelegateConfig,
     FileBase,
     FileInput,
     FileWithBase64,
@@ -36,24 +37,26 @@ from model_library.exceptions import (
     MaxOutputTokensExceededError,
     ModelNoOutputError,
     NoMatchingToolCallError,
-    RateLimitException,
 )
 from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("grok")
 class XAIModel(LLM):
-    _client: AsyncClient | None = None
+    @override
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.XAI_API_KEY
 
     @override
-    def get_client(self) -> AsyncClient:
-        if not XAIModel._client:
-            XAIModel._client = AsyncClient(
-                api_key=model_library_settings.XAI_API_KEY,
+    def get_client(self, api_key: str | None = None) -> AsyncClient:
+        if not self.has_client():
+            assert api_key
+            client = AsyncClient(
+                api_key=api_key,
             )
-        return XAIModel._client
+            self.assign_client(client)
+        return super().get_client()
 
     @override
     def __init__(
@@ -73,13 +76,13 @@ class XAIModel(LLM):
             model_name=self.model_name,
             provider=provider,
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.XAI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url=(
                     "https://us-west-1.api.x.ai/v1"
                     if "grok-3-mini-reasoning" in self.model_name
                     else "https://api.x.ai/v1"
                 ),
+                api_key=SecretStr(model_library_settings.XAI_API_KEY),
             ),
             use_completions=True,
         )
@@ -210,12 +213,14 @@ class XAIModel(LLM):
             messages.append(system(str(kwargs.pop("system_prompt"))))
 
         body: dict[str, Any] = {
-            "max_tokens": self.max_tokens,
             "model": self.model_name,
             "tools": await self.parse_tools(tools),
             "messages": messages,
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature
@@ -253,38 +258,35 @@ class XAIModel(LLM):
 
         body = await self.build_body(input, tools=tools, **kwargs)
 
-        try:
-            chat: Chat = self.get_client().chat.create(**body)
-
-            latest_response: Response | None = None
-            async for response, _ in chat.stream():
-                latest_response = response
-
-            if not latest_response:
-                raise ModelNoOutputError("Model failed to produce a response")
-
-            tool_calls: list[ToolCall] = []
-            if (
-                latest_response.finish_reason == "REASON_TOOL_CALLS"
-                and latest_response.tool_calls
-            ):
-                for tool_call in latest_response.tool_calls:
-                    tool_calls.append(
-                        ToolCall(
-                            id=tool_call.id,
-                            name=tool_call.function.name,
-                            args=tool_call.function.arguments,
-                        )
+        chat: Chat = self.get_client().chat.create(**body)
+
+        latest_response: Response | None = None
+        async for response, _ in chat.stream():
+            latest_response = response
+
+        if not latest_response:
+            raise ModelNoOutputError("Model failed to produce a response")
+
+        tool_calls: list[ToolCall] = []
+        if (
+            latest_response.finish_reason == "REASON_TOOL_CALLS"
+            and latest_response.tool_calls
+        ):
+            for tool_call in latest_response.tool_calls:
+                tool_calls.append(
+                    ToolCall(
+                        id=tool_call.id,
+                        name=tool_call.function.name,
+                        args=tool_call.function.arguments,
                     )
+                )
 
-            if (
-                latest_response.finish_reason == "REASON_MAX_LEN"
-                and not latest_response.content
-                and not latest_response.reasoning_content
-            ):
-                raise MaxOutputTokensExceededError()
-        except grpc.RpcError as e:
-            raise RateLimitException(e.details())
+        if (
+            latest_response.finish_reason == "REASON_MAX_LEN"
+            and not latest_response.content
+            and not latest_response.reasoning_content
+        ):
+            raise MaxOutputTokensExceededError()
 
         return QueryResult(
             output_text=latest_response.content,
@@ -310,6 +312,9 @@ class XAIModel(LLM):
         tools: list[ToolDefinition] = [],
         **kwargs: object,
     ) -> int:
+        if not input and not history:
+            return 0
+
         string_input = await self.stringify_input(input, history=history, tools=tools)
         self.logger.debug(string_input)
 
@@ -1,17 +1,36 @@
-from typing import Literal
+from typing import Any, Literal
+
+from pydantic import SecretStr
+from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
+    ProviderConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
+
+
+class ZAIConfig(ProviderConfig):
+    """Configuration for ZAI (GLM) models.
+
+    Attributes:
+        clear_thinking: When disabled, reasoning content from previous turns is
+            preserved in context. This is useful for multi-turn conversations where
+            you want the model to maintain coherent reasoning across turns.
+            Enabled by default on the standard API endpoint.
+            See: https://docs.z.ai/guides/capabilities/thinking-mode
+    """
+
+    clear_thinking: bool = True
 
 
 @register_provider("zai")
 class ZAIModel(DelegateOnly):
+    provider_config = ZAIConfig()
+
     def __init__(
         self,
         model_name: str,
@@ -21,14 +40,25 @@ class ZAIModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)
 
+        self.clear_thinking = self.provider_config.clear_thinking
+
         # https://docs.z.ai/guides/develop/openai/python
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.ZAI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://open.bigmodel.cn/api/paas/v4/",
+                api_key=SecretStr(model_library_settings.ZAI_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
+
+    @override
+    def _get_extra_body(self) -> dict[str, Any]:
+        """Build extra body parameters for GLM-specific features."""
+        return {
+            "thinking": {
+                "type": "enabled" if self.reasoning else "disabled",
+                "clear_thinking": self.clear_thinking,
+            }
+        }
+ }