model-library 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. model_library/base/base.py +139 -62
  2. model_library/base/delegate_only.py +77 -10
  3. model_library/base/output.py +43 -0
  4. model_library/base/utils.py +35 -0
  5. model_library/config/alibaba_models.yaml +44 -57
  6. model_library/config/all_models.json +253 -126
  7. model_library/config/kimi_models.yaml +30 -3
  8. model_library/config/openai_models.yaml +15 -23
  9. model_library/config/zai_models.yaml +24 -3
  10. model_library/exceptions.py +3 -77
  11. model_library/providers/ai21labs.py +12 -8
  12. model_library/providers/alibaba.py +17 -8
  13. model_library/providers/amazon.py +49 -16
  14. model_library/providers/anthropic.py +93 -40
  15. model_library/providers/azure.py +22 -10
  16. model_library/providers/cohere.py +7 -7
  17. model_library/providers/deepseek.py +8 -8
  18. model_library/providers/fireworks.py +7 -8
  19. model_library/providers/google/batch.py +14 -10
  20. model_library/providers/google/google.py +48 -29
  21. model_library/providers/inception.py +7 -7
  22. model_library/providers/kimi.py +18 -8
  23. model_library/providers/minimax.py +15 -17
  24. model_library/providers/mistral.py +20 -8
  25. model_library/providers/openai.py +99 -22
  26. model_library/providers/openrouter.py +34 -0
  27. model_library/providers/perplexity.py +7 -7
  28. model_library/providers/together.py +7 -8
  29. model_library/providers/vals.py +12 -6
  30. model_library/providers/xai.py +47 -42
  31. model_library/providers/zai.py +38 -8
  32. model_library/registry_utils.py +39 -15
  33. model_library/retriers/__init__.py +0 -0
  34. model_library/retriers/backoff.py +73 -0
  35. model_library/retriers/base.py +225 -0
  36. model_library/retriers/token.py +427 -0
  37. model_library/retriers/utils.py +11 -0
  38. model_library/settings.py +1 -1
  39. model_library/utils.py +13 -0
  40. {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/METADATA +2 -1
  41. model_library-0.1.8.dist-info/RECORD +70 -0
  42. {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
  43. model_library-0.1.7.dist-info/RECORD +0 -64
  44. {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
  45. {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
--- a/model_library/providers/fireworks.py
+++ b/model_library/providers/fireworks.py
@@ -1,18 +1,18 @@
 from typing import Literal
 
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
     LLMConfig,
     ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.base.delegate_only import DelegateOnly
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 class FireworksConfig(ProviderConfig):
@@ -38,15 +38,14 @@ class FireworksModel(DelegateOnly):
             self.model_name = "accounts/rayan-936e28/deployedModels/" + self.model_name
 
         # https://docs.fireworks.ai/tools-sdks/openai-compatibility
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.FIREWORKS_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.fireworks.ai/inference/v1",
+                api_key=SecretStr(model_library_settings.FIREWORKS_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
 
     @override
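
The Fireworks change above is the template for the Inception, Kimi, and Minimax hunks below: providers no longer construct an OpenAIModel (or AnthropicModel) delegate by hand, they declare an endpoint and key and let the base class build the client. A minimal sketch of the pattern, assuming init_delegate and DelegateConfig accept exactly the arguments shown in these hunks; the class name, endpoint, and key below are placeholders:

    # Sketch of the 0.1.8 delegate wiring; ExampleModel, the endpoint, and the
    # key are hypothetical, and the full init_delegate/DelegateConfig signatures
    # are inferred from this diff rather than from the library's documentation.
    from pydantic import SecretStr

    from model_library.base import DelegateConfig, DelegateOnly, LLMConfig


    class ExampleModel(DelegateOnly):
        def __init__(self, model_name: str, provider: str, config: LLMConfig | None = None):
            super().__init__(model_name, provider, config=config)
            # The provider declares endpoint + credentials; the base class
            # constructs and owns the OpenAI-compatible delegate.
            self.init_delegate(
                config=config,
                delegate_config=DelegateConfig(
                    base_url="https://api.example.com/v1/",
                    api_key=SecretStr("sk-example"),
                ),
                use_completions=True,
                delegate_provider="openai",
            )
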
--- a/model_library/providers/google/batch.py
+++ b/model_library/providers/google/batch.py
@@ -24,16 +24,19 @@ from google.genai.types import (
 )
 
 
-def extract_text_from_json_response(response: dict[str, Any]) -> str:
+def extract_text_from_json_response(response: dict[str, Any]) -> tuple[str, str]:
     """Extract concatenated non-thought text from a JSON response structure."""
     # TODO: fix the typing we always ignore
     text = ""
+    reasoning = ""
     for candidate in response.get("candidates", []) or []:  # type: ignore
         content = (candidate or {}).get("content") or {}  # type: ignore
         for part in content.get("parts", []) or []:  # type: ignore
-            if not part.get("thought", False):  # type: ignore
+            if part.get("thought", False):  # type: ignore
+                reasoning += part.get("text", "")  # type: ignore
+            else:
                 text += part.get("text", "")  # type: ignore
-    return text  # type: ignore
+    return text, reasoning  # type: ignore
 
 
 def parse_predictions_jsonl(jsonl: str) -> list[BatchResult]:
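
The new tuple return above preserves Gemini "thought" parts as a separate reasoning string instead of discarding them. A small illustration against a synthetic response dict whose shape is taken from the function body; the import path is the module shown in the files-changed list:

    from model_library.providers.google.batch import extract_text_from_json_response

    # Synthetic response exercising the new (text, reasoning) return value.
    response = {
        "candidates": [
            {
                "content": {
                    "parts": [
                        {"thought": True, "text": "2 + 2 is 4."},
                        {"text": "The answer is 4."},
                    ]
                }
            }
        ]
    }
    text, reasoning = extract_text_from_json_response(response)
    assert text == "The answer is 4."
    assert reasoning == "2 + 2 is 4."
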
@@ -48,9 +51,10 @@ def parse_predictions_jsonl(jsonl: str) -> list[BatchResult]:
         custom_id = data.get("key", "unknown")
         if "response" in data:
             response = data["response"]
-            text = extract_text_from_json_response(response)
+            text, reasoning = extract_text_from_json_response(response)
             output = QueryResult()
             output.output_text = text
+            output.reasoning = reasoning
             if "usageMetadata" in response:
                 output.metadata.in_tokens = response["usageMetadata"].get(
                     "promptTokenCount", 0
@@ -196,7 +200,7 @@ class GoogleBatchMixin(LLMBatchMixin):
             custom_id = labels.get("qa_pair_id", f"request-{i}")
             jsonl_lines.append(json.dumps({"key": custom_id, "request": request_data}))
 
-        batch_request_file = self._root.client.files.upload(
+        batch_request_file = self._root.get_client().files.upload(
             file=io.StringIO("\n".join(jsonl_lines)),
             config=UploadFileConfig(mime_type="application/jsonl"),
         )
@@ -205,7 +209,7 @@ class GoogleBatchMixin(LLMBatchMixin):
             raise Exception("Failed to upload batch jsonl")
 
         try:
-            job: BatchJob = await self._root.client.aio.batches.create(
+            job: BatchJob = await self._root.get_client().aio.batches.create(
                 model=self._root.model_name,
                 src=batch_request_file.name,
                 config={"display_name": batch_name},
@@ -224,14 +228,14 @@ class GoogleBatchMixin(LLMBatchMixin):
     async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
         self._root.logger.info(f"Retrieving batch results for {batch_id}")
 
-        job = await self._root.client.aio.batches.get(name=batch_id)
+        job = await self._root.get_client().aio.batches.get(name=batch_id)
 
         results: list[BatchResult] = []
 
         if job.state == JobState.JOB_STATE_SUCCEEDED:
             if job.dest and job.dest.file_name:
                 results_file_name = job.dest.file_name
-                file_content = await self._root.client.aio.files.download(
+                file_content = await self._root.get_client().aio.files.download(
                     file=results_file_name
                 )
                 decoded = file_content.decode("utf-8")
@@ -250,7 +254,7 @@ class GoogleBatchMixin(LLMBatchMixin):
     @override
     async def cancel_batch_request(self, batch_id: str):
         self._root.logger.info(f"Cancelling batch {batch_id}")
-        await self._root.client.aio.batches.cancel(name=batch_id)
+        await self._root.get_client().aio.batches.cancel(name=batch_id)
 
     @override
     async def get_batch_progress(self, batch_id: str) -> int:
@@ -262,7 +266,7 @@ class GoogleBatchMixin(LLMBatchMixin):
 
         try:
             self._root.logger.debug(f"Checking batch status for {batch_id}")
-            job: BatchJob = await self._root.client.aio.batches.get(name=batch_id)
+            job: BatchJob = await self._root.get_client().aio.batches.get(name=batch_id)
             state = job.state
 
             if not state:
--- a/model_library/providers/google/google.py
+++ b/model_library/providers/google/google.py
@@ -1,5 +1,6 @@
 import base64
 import io
+import json
 import logging
 import uuid
 from typing import Any, Literal, Sequence, cast
@@ -25,6 +26,7 @@ from google.genai.types import (
     ToolListUnion,
     UploadFileConfig,
 )
+from google.oauth2 import service_account
 from typing_extensions import override
 
 from model_library import model_library_settings
@@ -95,31 +97,50 @@ class GoogleModel(LLM):
         ),
     ]
 
-    @override
-    def get_client(self) -> Client:
-        if self.provider_config.use_vertex:
-            # Preview Gemini releases from September 2025 are only served from the global
-            # Vertex region. The public docs for these SKUs list `global` as the sole
-            # availability region (see September 25, 2025 release notes), so we override
-            # the default `us-central1` when we detect them.
-            # https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini/2-5-flash
-            MODEL_REGION_OVERRIDES: dict[str, str] = {
-                "gemini-2.5-flash-preview-09-2025": "global",
-                "gemini-2.5-flash-lite-preview-09-2025": "global",
+    def _get_default_api_key(self) -> str:
+        if not self.provider_config.use_vertex:
+            return model_library_settings.GOOGLE_API_KEY
+
+        return json.dumps(
+            {
+                "GCP_REGION": model_library_settings.GCP_REGION,
+                "GCP_PROJECT_ID": model_library_settings.GCP_PROJECT_ID,
+                "GCP_CREDS": model_library_settings.GCP_CREDS,
             }
-            region = model_library_settings.GCP_REGION
-            if self.model_name in MODEL_REGION_OVERRIDES:
-                region = MODEL_REGION_OVERRIDES[self.model_name]
-
-            return Client(
-                vertexai=True,
-                project=model_library_settings.GCP_PROJECT_ID,
-                location=region,
-                # Credentials object is not typed, so we have to ignore the error
-                credentials=model_library_settings.GCP_CREDS,
-            )
+        )
 
-        return Client(api_key=model_library_settings.GOOGLE_API_KEY)
+    @override
+    def get_client(self, api_key: str | None = None) -> Client:
+        if not self.has_client():
+            assert api_key
+            if self.provider_config.use_vertex:
+                # Gemini preview releases are only served from the global Vertex region after September 2025.
+                MODEL_REGION_OVERRIDES: dict[str, str] = {
+                    "gemini-2.5-flash-preview-09-2025": "global",
+                    "gemini-2.5-flash-lite-preview-09-2025": "global",
+                    "gemini-3-flash-preview": "global",
+                    "gemini-3-pro-preview": "global",
+                }
+
+                creds = json.loads(api_key)
+
+                region = creds["GCP_REGION"]
+                if self.model_name in MODEL_REGION_OVERRIDES:
+                    region = MODEL_REGION_OVERRIDES[self.model_name]
+
+                client = Client(
+                    vertexai=True,
+                    project=creds["GCP_PROJECT_ID"],
+                    location=region,
+                    credentials=service_account.Credentials.from_service_account_info(  # type: ignore
+                        json.loads(creds["GCP_CREDS"]),
+                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
+                    ),
+                )
+            else:
+                client = Client(api_key=api_key)
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
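
Note how the Vertex branch now receives everything through the single api_key argument: _get_default_api_key packs region, project, and the service-account blob into one JSON string, and get_client unpacks it again. A sketch of that double-encoded payload, with placeholder values:

    import json

    # GCP_CREDS is itself a JSON service-account blob, so it ends up
    # double-encoded inside the packed "api key". All values are placeholders.
    packed = json.dumps(
        {
            "GCP_REGION": "us-central1",
            "GCP_PROJECT_ID": "my-project",
            "GCP_CREDS": json.dumps({"type": "service_account", "project_id": "my-project"}),
        }
    )

    creds = json.loads(packed)
    sa_info = json.loads(creds["GCP_CREDS"])  # what from_service_account_info receives
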
@@ -141,8 +162,6 @@ class GoogleModel(LLM):
             GoogleBatchMixin(self) if self.supports_batch else None
         )
 
-        self.client = self.get_client()
-
     @override
     async def parse_input(
         self,
@@ -260,7 +279,7 @@ class GoogleModel(LLM):
         )
 
         mime = f"image/{mime}" if type == "image" else mime  # TODO:
-        response: File = self.client.files.upload(
+        response: File = self.get_client().files.upload(
            file=bytes, config=UploadFileConfig(mime_type=mime)
        )
        if not response.name:
@@ -338,7 +357,7 @@ class GoogleModel(LLM):
 
         metadata: GenerateContentResponseUsageMetadata | None = None
 
-        stream = await self.client.aio.models.generate_content_stream(**body)
+        stream = await self.get_client().aio.models.generate_content_stream(**body)
         contents: list[Content | None] = []
         finish_reason: FinishReason | None = None
 
@@ -437,7 +456,7 @@ class GoogleModel(LLM):
             tools=parsed_tools,
         )
 
-        response = await self.client.aio.models.count_tokens(
+        response = await self.get_client().aio.models.count_tokens(
             model=self.model_name,
             contents=cast(Any, contents),
             config=config,
@@ -503,7 +522,7 @@ class GoogleModel(LLM):
         # Make the request with retry wrapper
         async def _query():
             try:
-                return await self.client.aio.models.generate_content(**body)
+                return await self.get_client().aio.models.generate_content(**body)
             except (genai_errors.ServerError, genai_errors.UnknownApiResponseError):
                 raise ImmediateRetryException("Failed to connect to Google API")
--- a/model_library/providers/inception.py
+++ b/model_library/providers/inception.py
@@ -1,13 +1,14 @@
 from typing import Literal
 
+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("inception")
@@ -22,13 +23,12 @@ class MercuryModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://docs.inceptionlabs.ai/get-started/get-started#external-libraries-compatibility
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.MERCURY_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.inceptionlabs.ai/v1/",
+                api_key=SecretStr(model_library_settings.MERCURY_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
--- a/model_library/providers/kimi.py
+++ b/model_library/providers/kimi.py
@@ -1,13 +1,16 @@
-from typing import Literal
+from typing import Any, Literal
+
+from typing_extensions import override
+
+from pydantic import SecretStr
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("kimi")
@@ -22,13 +25,20 @@ class KimiModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://platform.moonshot.ai/docs/guide/migrating-from-openai-to-kimi#about-api-compatibility
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.KIMI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.moonshot.ai/v1/",
+                api_key=SecretStr(model_library_settings.KIMI_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
+
+    @override
+    def _get_extra_body(self) -> dict[str, Any]:
+        """
+        Build extra body parameters for Kimi-specific features.
+        see https://platform.moonshot.ai/docs/guide/kimi-k2-5-quickstart#parameters-differences-in-request-body
+        """
+        return {"thinking": {"type": "enabled" if self.reasoning else "disabled"}}
--- a/model_library/providers/minimax.py
+++ b/model_library/providers/minimax.py
@@ -1,16 +1,17 @@
-from typing import Literal
+from typing import Literal, Sequence
+
+from pydantic import SecretStr
+from typing_extensions import override
 
 from model_library import model_library_settings
-from model_library.base import DelegateOnly, LLMConfig
-from model_library.base.input import InputItem, ToolDefinition
-from model_library.providers.anthropic import AnthropicModel
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    InputItem,
+    LLMConfig,
+    ToolDefinition,
+)
 from model_library.register_models import register_provider
-from model_library.utils import default_httpx_client
-
-from anthropic import AsyncAnthropic
-
-from typing import Sequence
-from typing_extensions import override
 
 
 @register_provider("minimax")
  @register_provider("minimax")
@@ -24,16 +25,13 @@ class MinimaxModel(DelegateOnly):
24
25
  ):
25
26
  super().__init__(model_name, provider, config=config)
26
27
 
27
- self.delegate = AnthropicModel(
28
- model_name=self.model_name,
29
- provider=self.provider,
28
+ self.init_delegate(
30
29
  config=config,
31
- custom_client=AsyncAnthropic(
32
- api_key=model_library_settings.MINIMAX_API_KEY,
30
+ delegate_config=DelegateConfig(
33
31
  base_url="https://api.minimax.io/anthropic",
34
- http_client=default_httpx_client(),
35
- max_retries=1,
32
+ api_key=SecretStr(model_library_settings.MINIMAX_API_KEY),
36
33
  ),
34
+ delegate_provider="anthropic",
37
35
  )
38
36
 
39
37
  # minimax client shares anthropic's syntax
--- a/model_library/providers/mistral.py
+++ b/model_library/providers/mistral.py
@@ -3,7 +3,13 @@ import logging
 from collections.abc import Sequence
 from typing import Any, Literal
 
-from mistralai import AssistantMessage, ContentChunk, Mistral, TextChunk, ThinkChunk
+from mistralai import (
+    AssistantMessage,
+    ContentChunk,
+    Mistral,
+    TextChunk,
+    ThinkChunk,
+)
 from mistralai.models.completionevent import CompletionEvent
 from mistralai.models.toolcall import ToolCall as MistralToolCall
 from mistralai.utils.eventstreaming import EventStreamAsync
@@ -40,16 +46,20 @@ from model_library.utils import default_httpx_client
 
 @register_provider("mistralai")
 class MistralModel(LLM):
-    _client: Mistral | None = None
+    @override
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.MISTRAL_API_KEY
 
     @override
-    def get_client(self) -> Mistral:
-        if not MistralModel._client:
-            MistralModel._client = Mistral(
-                api_key=model_library_settings.MISTRAL_API_KEY,
+    def get_client(self, api_key: str | None = None) -> Mistral:
+        if not self.has_client():
+            assert api_key
+            client = Mistral(
+                api_key=api_key,
                 async_client=default_httpx_client(),
             )
-        return MistralModel._client
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
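
MistralModel and GoogleModel now share the same lazy, per-instance client lifecycle: get_client(api_key=...) builds the SDK client on first use, assign_client caches it, and super().get_client() returns the cached instance, with _get_default_api_key supplying the key when the caller passes none. A sketch of the assumed base-class contract; only the four method names appear in this diff, the bodies and the _client attribute are guesses:

    from typing import Any


    class ClientLifecycleSketch:
        _client: Any = None  # attribute name is a guess

        def _get_default_api_key(self) -> str:
            raise NotImplementedError  # each provider returns its settings key

        def has_client(self) -> bool:
            return self._client is not None

        def assign_client(self, client: Any) -> None:
            self._client = client

        def get_client(self, api_key: str | None = None) -> Any:
            # Subclasses build their SDK client when has_client() is False,
            # call assign_client(), then fall through to this base accessor.
            assert self._client is not None, "assign_client must run first"
            return self._client
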
@@ -198,12 +208,14 @@ class MistralModel(LLM):
 
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "messages": messages,
             "prompt_mode": "reasoning" if self.reasoning else None,
             "tools": tools,
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature