model-library 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. model_library/base/base.py +237 -62
  2. model_library/base/delegate_only.py +86 -9
  3. model_library/base/input.py +10 -7
  4. model_library/base/output.py +48 -0
  5. model_library/base/utils.py +56 -7
  6. model_library/config/alibaba_models.yaml +44 -57
  7. model_library/config/all_models.json +253 -126
  8. model_library/config/kimi_models.yaml +30 -3
  9. model_library/config/openai_models.yaml +15 -23
  10. model_library/config/zai_models.yaml +24 -3
  11. model_library/exceptions.py +14 -77
  12. model_library/logging.py +6 -2
  13. model_library/providers/ai21labs.py +30 -14
  14. model_library/providers/alibaba.py +17 -8
  15. model_library/providers/amazon.py +119 -64
  16. model_library/providers/anthropic.py +184 -104
  17. model_library/providers/azure.py +22 -10
  18. model_library/providers/cohere.py +7 -7
  19. model_library/providers/deepseek.py +8 -8
  20. model_library/providers/fireworks.py +7 -8
  21. model_library/providers/google/batch.py +17 -13
  22. model_library/providers/google/google.py +130 -73
  23. model_library/providers/inception.py +7 -7
  24. model_library/providers/kimi.py +18 -8
  25. model_library/providers/minimax.py +30 -13
  26. model_library/providers/mistral.py +61 -35
  27. model_library/providers/openai.py +219 -93
  28. model_library/providers/openrouter.py +34 -0
  29. model_library/providers/perplexity.py +7 -7
  30. model_library/providers/together.py +7 -8
  31. model_library/providers/vals.py +16 -9
  32. model_library/providers/xai.py +157 -144
  33. model_library/providers/zai.py +38 -8
  34. model_library/register_models.py +4 -2
  35. model_library/registry_utils.py +39 -15
  36. model_library/retriers/__init__.py +0 -0
  37. model_library/retriers/backoff.py +73 -0
  38. model_library/retriers/base.py +225 -0
  39. model_library/retriers/token.py +427 -0
  40. model_library/retriers/utils.py +11 -0
  41. model_library/settings.py +1 -1
  42. model_library/utils.py +13 -35
  43. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/METADATA +4 -3
  44. model_library-0.1.8.dist-info/RECORD +70 -0
  45. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
  46. model_library-0.1.6.dist-info/RECORD +0 -64
  47. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
  48. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/providers/google/batch.py

@@ -2,8 +2,6 @@ import io
 import json
 from typing import TYPE_CHECKING, Any, Final, Sequence, cast

-from typing_extensions import override
-
 from google.genai.types import (
     BatchJob,
     Content,
@@ -11,6 +9,8 @@ from google.genai.types import (
     JobState,
     UploadFileConfig,
 )
+from typing_extensions import override
+
 from model_library.base import BatchResult, InputItem, LLMBatchMixin

 if TYPE_CHECKING:
@@ -24,16 +24,19 @@ from google.genai.types import (
 )


-def extract_text_from_json_response(response: dict[str, Any]) -> str:
+def extract_text_from_json_response(response: dict[str, Any]) -> tuple[str, str]:
     """Extract concatenated non-thought text from a JSON response structure."""
     # TODO: fix the typing we always ignore
     text = ""
+    reasoning = ""
     for candidate in response.get("candidates", []) or []:  # type: ignore
         content = (candidate or {}).get("content") or {}  # type: ignore
         for part in content.get("parts", []) or []:  # type: ignore
-            if not part.get("thought", False):  # type: ignore
+            if part.get("thought", False):  # type: ignore
+                reasoning += part.get("text", "")  # type: ignore
+            else:
                 text += part.get("text", "")  # type: ignore
-    return text  # type: ignore
+    return text, reasoning  # type: ignore


 def parse_predictions_jsonl(jsonl: str) -> list[BatchResult]:
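
The behavior change above is easiest to see in isolation: `thought` parts, previously discarded, now accumulate into a separate reasoning string. A minimal sketch against a hand-built dict that mimics Gemini's JSON response shape (not taken from the package's tests):

# Sketch: the new tuple return of extract_text_from_json_response.
response = {
    "candidates": [
        {
            "content": {
                "parts": [
                    {"text": "Let me think...", "thought": True},
                    {"text": "The answer is 42."},
                ]
            }
        }
    ]
}

text, reasoning = extract_text_from_json_response(response)
assert text == "The answer is 42."
assert reasoning == "Let me think..."
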
@@ -48,9 +51,10 @@ def parse_predictions_jsonl(jsonl: str) -> list[BatchResult]:
         custom_id = data.get("key", "unknown")
         if "response" in data:
             response = data["response"]
-            text = extract_text_from_json_response(response)
+            text, reasoning = extract_text_from_json_response(response)
             output = QueryResult()
             output.output_text = text
+            output.reasoning = reasoning
             if "usageMetadata" in response:
                 output.metadata.in_tokens = response["usageMetadata"].get(
                     "promptTokenCount", 0
@@ -144,7 +148,7 @@ class GoogleBatchMixin(LLMBatchMixin):
         **kwargs: object,
     ) -> dict[str, Any]:
         self._root.logger.debug(f"Creating batch request for custom_id: {custom_id}")
-        body = await self._root.create_body(input, tools=[], **kwargs)
+        body = await self._root.build_body(input, tools=[], **kwargs)

         contents_any = body["contents"]
         serialized_contents: list[dict[str, Any]] = [
@@ -196,7 +200,7 @@ class GoogleBatchMixin(LLMBatchMixin):
             custom_id = labels.get("qa_pair_id", f"request-{i}")
             jsonl_lines.append(json.dumps({"key": custom_id, "request": request_data}))

-        batch_request_file = self._root.client.files.upload(
+        batch_request_file = self._root.get_client().files.upload(
             file=io.StringIO("\n".join(jsonl_lines)),
             config=UploadFileConfig(mime_type="application/jsonl"),
         )
@@ -205,7 +209,7 @@ class GoogleBatchMixin(LLMBatchMixin):
             raise Exception("Failed to upload batch jsonl")

         try:
-            job: BatchJob = await self._root.client.aio.batches.create(
+            job: BatchJob = await self._root.get_client().aio.batches.create(
                 model=self._root.model_name,
                 src=batch_request_file.name,
                 config={"display_name": batch_name},
@@ -224,14 +228,14 @@ class GoogleBatchMixin(LLMBatchMixin):
     async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
         self._root.logger.info(f"Retrieving batch results for {batch_id}")

-        job = await self._root.client.aio.batches.get(name=batch_id)
+        job = await self._root.get_client().aio.batches.get(name=batch_id)

         results: list[BatchResult] = []

         if job.state == JobState.JOB_STATE_SUCCEEDED:
             if job.dest and job.dest.file_name:
                 results_file_name = job.dest.file_name
-                file_content = await self._root.client.aio.files.download(
+                file_content = await self._root.get_client().aio.files.download(
                     file=results_file_name
                 )
                 decoded = file_content.decode("utf-8")
@@ -250,7 +254,7 @@ class GoogleBatchMixin(LLMBatchMixin):
     @override
     async def cancel_batch_request(self, batch_id: str):
         self._root.logger.info(f"Cancelling batch {batch_id}")
-        await self._root.client.aio.batches.cancel(name=batch_id)
+        await self._root.get_client().aio.batches.cancel(name=batch_id)

     @override
     async def get_batch_progress(self, batch_id: str) -> int:
@@ -262,7 +266,7 @@ class GoogleBatchMixin(LLMBatchMixin):

         try:
             self._root.logger.debug(f"Checking batch status for {batch_id}")
-            job: BatchJob = await self._root.client.aio.batches.get(name=batch_id)
+            job: BatchJob = await self._root.get_client().aio.batches.get(name=batch_id)
             state = job.state

             if not state:
model_library/providers/google/google.py

@@ -1,13 +1,17 @@
 import base64
 import io
+import json
 import logging
+import uuid
 from typing import Any, Literal, Sequence, cast

 from google.genai import Client
 from google.genai import errors as genai_errors
 from google.genai.types import (
     Content,
+    CountTokensConfig,
     File,
+    FinishReason,
     FunctionDeclaration,
     GenerateContentConfig,
     GenerateContentResponse,
@@ -21,13 +25,14 @@ from google.genai.types import (
     Tool,
     ToolListUnion,
     UploadFileConfig,
-    FinishReason,
 )
+from google.oauth2 import service_account
 from typing_extensions import override

 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -40,6 +45,8 @@ from model_library.base import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -54,8 +61,6 @@ from model_library.exceptions import (
 )
 from model_library.providers.google.batch import GoogleBatchMixin
 from model_library.register_models import register_provider
-from model_library.utils import normalize_tool_result
-import uuid


 def generate_tool_call_id(tool_name: str) -> str:
@@ -92,31 +97,50 @@ class GoogleModel(LLM):
         ),
     ]

-    @override
-    def get_client(self) -> Client:
-        if self.provider_config.use_vertex:
-            # Preview Gemini releases from September 2025 are only served from the global
-            # Vertex region. The public docs for these SKUs list `global` as the sole
-            # availability region (see September 25, 2025 release notes), so we override
-            # the default `us-central1` when we detect them.
-            # https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini/2-5-flash
-            MODEL_REGION_OVERRIDES: dict[str, str] = {
-                "gemini-2.5-flash-preview-09-2025": "global",
-                "gemini-2.5-flash-lite-preview-09-2025": "global",
+    def _get_default_api_key(self) -> str:
+        if not self.provider_config.use_vertex:
+            return model_library_settings.GOOGLE_API_KEY
+
+        return json.dumps(
+            {
+                "GCP_REGION": model_library_settings.GCP_REGION,
+                "GCP_PROJECT_ID": model_library_settings.GCP_PROJECT_ID,
+                "GCP_CREDS": model_library_settings.GCP_CREDS,
             }
-            region = model_library_settings.GCP_REGION
-            if self.model_name in MODEL_REGION_OVERRIDES:
-                region = MODEL_REGION_OVERRIDES[self.model_name]
-
-            return Client(
-                vertexai=True,
-                project=model_library_settings.GCP_PROJECT_ID,
-                location=region,
-                # Credentials object is not typed, so we have to ignore the error
-                credentials=model_library_settings.GCP_CREDS,
-            )
+        )

-        return Client(api_key=model_library_settings.GOOGLE_API_KEY)
+    @override
+    def get_client(self, api_key: str | None = None) -> Client:
+        if not self.has_client():
+            assert api_key
+            if self.provider_config.use_vertex:
+                # Gemini preview releases are only served from the global Vertex region after September 2025.
+                MODEL_REGION_OVERRIDES: dict[str, str] = {
+                    "gemini-2.5-flash-preview-09-2025": "global",
+                    "gemini-2.5-flash-lite-preview-09-2025": "global",
+                    "gemini-3-flash-preview": "global",
+                    "gemini-3-pro-preview": "global",
+                }
+
+                creds = json.loads(api_key)
+
+                region = creds["GCP_REGION"]
+                if self.model_name in MODEL_REGION_OVERRIDES:
+                    region = MODEL_REGION_OVERRIDES[self.model_name]
+
+                client = Client(
+                    vertexai=True,
+                    project=creds["GCP_PROJECT_ID"],
+                    location=region,
+                    credentials=service_account.Credentials.from_service_account_info(  # type: ignore
+                        json.loads(creds["GCP_CREDS"]),
+                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
+                    ),
+                )
+            else:
+                client = Client(api_key=api_key)
+            self.assign_client(client)
+        return super().get_client()

     def __init__(
         self,
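
Worth noting in 0.1.8: when use_vertex is set, the `api_key` passed to get_client is itself a JSON envelope produced by _get_default_api_key. A rough sketch of both sides of that round trip, with placeholder values standing in for the real settings:

import json

# Sketch: packing (what _get_default_api_key does) and unpacking (what
# get_client does). All values below are placeholders.
packed = json.dumps(
    {
        "GCP_REGION": "us-central1",
        "GCP_PROJECT_ID": "my-project",
        "GCP_CREDS": '{"type": "service_account"}',  # service-account JSON as a string
    }
)

creds = json.loads(packed)
sa_info = json.loads(creds["GCP_CREDS"])  # note the nested parse
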
@@ -138,71 +162,58 @@ class GoogleModel(LLM):
             GoogleBatchMixin(self) if self.supports_batch else None
         )

-        self.client = self.get_client()
-
     @override
     async def parse_input(
         self,
         input: Sequence[InputItem],
         **kwargs: Any,
     ) -> list[Content]:
-        parsed_input: list[Content] = []
-        parts: list[Part] = []
+        new_input: list[Content] = []

-        def flush_parts():
-            nonlocal parts
+        content_user: list[Part] = []

-            if parts:
-                parsed_input.append(Content(parts=parts, role="user"))
-                parts = []
+        def flush_content_user():
+            if content_user:
+                new_input.append(Content(parts=content_user, role="user"))
+                content_user.clear()

         for item in input:
-            match item:
-                case TextInput():
-                    if item.text.strip():
-                        parts.append(Part.from_text(text=item.text))
+            if isinstance(item, TextInput):
+                content_user.append(Part.from_text(text=item.text))
+                continue
+
+            if isinstance(item, FileBase):
+                parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue

-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    part = await self.parse_file(item)
-                    parts.append(part)
+            # non content user item
+            flush_content_user()

+            match item:
                 case ToolResult():
-                    flush_parts()
-                    result_str = normalize_tool_result(item.result)
-                    parsed_input.append(
+                    # id check
+                    new_input.append(
                         Content(
                             role="function",
                             parts=[
                                 Part.from_function_response(
                                     name=item.tool_call.name,
-                                    response={"result": result_str},
+                                    response={"result": item.result},
                                 )
                             ],
                         )
                     )

-                case GenerateContentResponse():
-                    flush_parts()
-                    candidates = item.candidates
-                    if candidates and candidates[0]:
-                        content0 = candidates[0].content
-                        if content0 is not None:
-                            parsed_input.append(content0)
-                    else:
-                        self.logger.debug(
-                            "GenerateContentResponse missing candidates; skipping"
-                        )
-
-                case Content():
-                    flush_parts()
-                    parsed_input.append(item)
+                case RawResponse():
+                    new_input.extend(item.response)
+                case RawInput():
+                    new_input.append(item.input)

-                case _:
-                    raise BadInputError(f"Unsupported input type: {type(item)}")
+        # in case content user item is the last item
+        flush_content_user()

-        flush_parts()
-
-        return parsed_input
+        return new_input

     @override
     async def parse_file(self, file: FileInput) -> Part:
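
The rewritten parse_input buffers consecutive user-content items (text and files) into one user Content, and flushes the buffer whenever a non-user item (tool result, raw input/response) appears. The grouping logic in isolation, a sketch with plain strings standing in for Part objects:

# Sketch of the buffer-and-flush grouping used by parse_input above.
# Strings stand in for Part objects; "TOOL:" marks a non-user item.
def group_user_items(items: list[str]) -> list[list[str] | str]:
    grouped: list[list[str] | str] = []
    buffer: list[str] = []

    def flush() -> None:
        if buffer:
            grouped.append(list(buffer))  # one user "Content" per run
            buffer.clear()

    for item in items:
        if item.startswith("TOOL:"):
            flush()  # a non-user item closes the current run
            grouped.append(item)
        else:
            buffer.append(item)
    flush()  # trailing user run, if any
    return grouped

assert group_user_items(["a", "b", "TOOL:x", "c"]) == [["a", "b"], "TOOL:x", ["c"]]
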
@@ -268,7 +279,7 @@ class GoogleModel(LLM):
             )

         mime = f"image/{mime}" if type == "image" else mime  # TODO:
-        response: File = self.client.files.upload(
+        response: File = self.get_client().files.upload(
             file=bytes, config=UploadFileConfig(mime_type=mime)
         )
         if not response.name:
@@ -284,7 +295,8 @@ class GoogleModel(LLM):
             mime=mime,
         )

-    async def create_body(
+    @override
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
@@ -337,7 +349,7 @@ class GoogleModel(LLM):
         query_logger: logging.Logger,
         **kwargs: object,
     ) -> QueryResult:
-        body: dict[str, Any] = await self.create_body(input, tools=tools, **kwargs)
+        body: dict[str, Any] = await self.build_body(input, tools=tools, **kwargs)

         text: str = ""
         reasoning: str = ""
@@ -345,7 +357,7 @@ class GoogleModel(LLM):

         metadata: GenerateContentResponseUsageMetadata | None = None

-        stream = await self.client.aio.models.generate_content_stream(**body)
+        stream = await self.get_client().aio.models.generate_content_stream(**body)
         contents: list[Content | None] = []
         finish_reason: FinishReason | None = None

@@ -395,7 +407,7 @@ class GoogleModel(LLM):
         result = QueryResult(
             output_text=text,
             reasoning=reasoning,
-            history=[*input, *contents],
+            history=[*input, RawResponse(response=contents)],
             tool_calls=tool_calls,
         )

@@ -410,6 +422,51 @@ class GoogleModel(LLM):
         )
         return result

+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens using Google's native token counting API.
+        https://ai.google.dev/gemini-api/docs/tokens
+
+        Only Vertex AI supports system_instruction and tools in count_tokens.
+        For the Gemini API, fall back to the base implementation.
+        TODO: implement token counting for non-Vertex models.
+        """
+        if not self.provider_config.use_vertex:
+            return await super().count_tokens(
+                input, history=history, tools=tools, **kwargs
+            )
+
+        input = [*history, *input]
+        if not input:
+            return 0
+
+        system_prompt = kwargs.pop("system_prompt", None)
+        contents = await self.parse_input(input, **kwargs)
+        parsed_tools = await self.parse_tools(tools) if tools else None
+        config = CountTokensConfig(
+            system_instruction=str(system_prompt) if system_prompt else None,
+            tools=parsed_tools,
+        )
+
+        response = await self.get_client().aio.models.count_tokens(
+            model=self.model_name,
+            contents=cast(Any, contents),
+            config=config,
+        )
+
+        if response.total_tokens is None:
+            raise ValueError("count_tokens returned None")
+
+        return response.total_tokens
+
     @override
     async def _calculate_cost(
         self,
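
For orientation, the underlying google-genai call that the new count_tokens wraps looks roughly like this (placeholder model name and key; per the docstring above, system_instruction and tools via CountTokensConfig are honored only on Vertex):

# Sketch: direct google-genai token counting, outside model_library.
from google import genai

client = genai.Client(api_key="...")  # placeholder key
resp = client.models.count_tokens(
    model="gemini-2.5-flash",
    contents="How many tokens is this?",
)
print(resp.total_tokens)
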
@@ -446,7 +503,7 @@ class GoogleModel(LLM):
         **kwargs: object,
     ) -> PydanticT:
         # Create the request body with JSON schema
-        body: dict[str, Any] = await self.create_body(input, tools=[], **kwargs)
+        body: dict[str, Any] = await self.build_body(input, tools=[], **kwargs)

         # Get the JSON schema from the Pydantic model
         json_schema = pydantic_model.model_json_schema()
@@ -465,7 +522,7 @@ class GoogleModel(LLM):
         # Make the request with retry wrapper
         async def _query():
             try:
-                return await self.client.aio.models.generate_content(**body)
+                return await self.get_client().aio.models.generate_content(**body)
             except (genai_errors.ServerError, genai_errors.UnknownApiResponseError):
                 raise ImmediateRetryException("Failed to connect to Google API")

model_library/providers/inception.py

@@ -1,13 +1,14 @@
 from typing import Literal

+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 @register_provider("inception")
@@ -22,13 +23,12 @@ class MercuryModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)

         # https://docs.inceptionlabs.ai/get-started/get-started#external-libraries-compatibility
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.MERCURY_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.inceptionlabs.ai/v1/",
+                api_key=SecretStr(model_library_settings.MERCURY_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
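
The switch from a hand-built OpenAI client to DelegateConfig also moves keys into pydantic's SecretStr, which masks the value in reprs and logs. Standard pydantic behavior, shown with a throwaway key:

from pydantic import SecretStr

key = SecretStr("sk-example")   # throwaway value
print(key)                      # prints **********
print(key.get_secret_value())   # prints sk-example
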
model_library/providers/kimi.py

@@ -1,13 +1,16 @@
-from typing import Literal
+from typing import Any, Literal
+
+from typing_extensions import override
+
+from pydantic import SecretStr

 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 @register_provider("kimi")
@@ -22,13 +25,20 @@ class KimiModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)

         # https://platform.moonshot.ai/docs/guide/migrating-from-openai-to-kimi#about-api-compatibility
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.KIMI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.moonshot.ai/v1/",
+                api_key=SecretStr(model_library_settings.KIMI_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
+
+    @override
+    def _get_extra_body(self) -> dict[str, Any]:
+        """
+        Build extra body parameters for Kimi-specific features.
+        see https://platform.moonshot.ai/docs/guide/kimi-k2-5-quickstart#parameters-differences-in-request-body
+        """
+        return {"thinking": {"type": "enabled" if self.reasoning else "disabled"}}
model_library/providers/minimax.py

@@ -1,12 +1,17 @@
-from typing import Literal
+from typing import Literal, Sequence
+
+from pydantic import SecretStr
+from typing_extensions import override

 from model_library import model_library_settings
-from model_library.base import DelegateOnly, LLMConfig
-from model_library.providers.anthropic import AnthropicModel
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    InputItem,
+    LLMConfig,
+    ToolDefinition,
+)
 from model_library.register_models import register_provider
-from model_library.utils import default_httpx_client
-
-from anthropic import AsyncAnthropic


 @register_provider("minimax")
@@ -20,14 +25,26 @@ class MinimaxModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)

-        self.delegate = AnthropicModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=AsyncAnthropic(
-                api_key=model_library_settings.MINIMAX_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.minimax.io/anthropic",
-                http_client=default_httpx_client(),
-                max_retries=1,
+                api_key=SecretStr(model_library_settings.MINIMAX_API_KEY),
             ),
+            delegate_provider="anthropic",
+        )
+
+    # minimax client shares anthropic's syntax
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        assert self.delegate
+        return await self.delegate.count_tokens(
+            input, history=history, tools=tools, **kwargs
         )
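
MinimaxModel overrides count_tokens only to forward it to the Anthropic-compatible delegate. Reduced to its essentials, the forwarding pattern looks like this (names are illustrative, not model_library's actual API):

import asyncio

# Sketch of the DelegateOnly forwarding pattern; illustrative names only.
class Delegate:
    async def count_tokens(self, text: str) -> int:
        return len(text.split())  # stand-in for a real token counter

class Wrapper:
    def __init__(self) -> None:
        self.delegate: Delegate | None = Delegate()

    async def count_tokens(self, text: str) -> int:
        assert self.delegate
        return await self.delegate.count_tokens(text)

print(asyncio.run(Wrapper().count_tokens("four words right here")))  # 4
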