model-library 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. model_library/base/base.py +141 -62
  2. model_library/base/delegate_only.py +77 -10
  3. model_library/base/output.py +43 -0
  4. model_library/base/utils.py +35 -0
  5. model_library/config/alibaba_models.yaml +49 -57
  6. model_library/config/all_models.json +353 -120
  7. model_library/config/anthropic_models.yaml +2 -1
  8. model_library/config/kimi_models.yaml +30 -3
  9. model_library/config/mistral_models.yaml +2 -0
  10. model_library/config/openai_models.yaml +15 -23
  11. model_library/config/together_models.yaml +2 -0
  12. model_library/config/xiaomi_models.yaml +43 -0
  13. model_library/config/zai_models.yaml +27 -3
  14. model_library/exceptions.py +3 -77
  15. model_library/providers/ai21labs.py +12 -8
  16. model_library/providers/alibaba.py +17 -8
  17. model_library/providers/amazon.py +49 -16
  18. model_library/providers/anthropic.py +128 -48
  19. model_library/providers/azure.py +22 -10
  20. model_library/providers/cohere.py +7 -7
  21. model_library/providers/deepseek.py +8 -8
  22. model_library/providers/fireworks.py +7 -8
  23. model_library/providers/google/batch.py +14 -10
  24. model_library/providers/google/google.py +57 -30
  25. model_library/providers/inception.py +7 -7
  26. model_library/providers/kimi.py +18 -8
  27. model_library/providers/minimax.py +15 -17
  28. model_library/providers/mistral.py +20 -8
  29. model_library/providers/openai.py +99 -22
  30. model_library/providers/openrouter.py +34 -0
  31. model_library/providers/perplexity.py +7 -7
  32. model_library/providers/together.py +7 -8
  33. model_library/providers/vals.py +12 -6
  34. model_library/providers/vercel.py +34 -0
  35. model_library/providers/xai.py +47 -42
  36. model_library/providers/xiaomi.py +34 -0
  37. model_library/providers/zai.py +38 -8
  38. model_library/register_models.py +5 -0
  39. model_library/registry_utils.py +48 -17
  40. model_library/retriers/__init__.py +0 -0
  41. model_library/retriers/backoff.py +73 -0
  42. model_library/retriers/base.py +225 -0
  43. model_library/retriers/token.py +427 -0
  44. model_library/retriers/utils.py +11 -0
  45. model_library/settings.py +1 -1
  46. model_library/utils.py +17 -7
  47. {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/METADATA +2 -1
  48. model_library-0.1.9.dist-info/RECORD +73 -0
  49. {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/WHEEL +1 -1
  50. model_library-0.1.7.dist-info/RECORD +0 -64
  51. {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/licenses/LICENSE +0 -0
  52. {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,18 @@
 from typing import Literal
 
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
     ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 class TogetherConfig(ProviderConfig):
@@ -32,15 +32,14 @@ class TogetherModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)
         # https://docs.together.ai/docs/openai-api-compatibility
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.TOGETHER_API_KEY,
-                base_url="https://api.together.xyz/v1",
+            delegate_config=DelegateConfig(
+                base_url="https://api.together.xyz/v1/",
+                api_key=SecretStr(model_library_settings.TOGETHER_API_KEY),
            ),
             use_completions=True,
+            delegate_provider="openai",
         )
 
     @override
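Note: providers now hand credentials to their delegate through a DelegateConfig whose api_key is a pydantic SecretStr rather than a raw string passed to a custom OpenAI client. A minimal sketch of the masking behaviour that motivates this (standard pydantic, independent of this package; the key shown is a placeholder):

    from pydantic import SecretStr

    key = SecretStr("sk-example-not-a-real-key")
    print(key)                     # prints '**********', safe to log or repr
    print(key.get_secret_value())  # the raw value, only when explicitly requested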
@@ -151,13 +151,17 @@ class DummyAIBatchMixin(LLMBatchMixin):
 class DummyAIModel(LLM):
     _client: Redis | None = None
 
-    @override
-    def get_client(self) -> Redis:
-        if not DummyAIModel._client:
-            DummyAIModel._client = redis.from_url(  # pyright: ignore[reportUnknownMemberType]
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.REDIS_URL
+
+    def get_client(self, api_key: str | None = None) -> Redis:
+        if not self.has_client():
+            assert api_key
+            client = redis.from_url(  # pyright: ignore[reportUnknownMemberType]
                 model_library_settings.REDIS_URL, decode_responses=True
             )
-        return DummyAIModel._client
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
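Note: this get_client rewrite (and the XAIModel one further down) replaces a class-level singleton with has_client/assign_client helpers defined on the base LLM class. Those helpers are not shown in this diff, so the sketch below only illustrates the same lazy-initialization idea in a self-contained form; it is not the library's actual implementation:

    class LazyClientHolder:
        """Illustrative only: cache an expensive client on first use."""

        def __init__(self) -> None:
            self._client = None

        def has_client(self) -> bool:
            return self._client is not None

        def assign_client(self, client) -> None:
            self._client = client

        def get_client(self):
            if not self.has_client():
                # stand-in for redis.from_url(...) or AsyncClient(...)
                self.assign_client(object())
            return self._client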
@@ -238,12 +242,14 @@ class DummyAIModel(LLM):
         messages = await self.parse_input(input)
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "seed": 0,
             "messages": messages,
             "tools": await self.parse_tools(tools),
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature
@@ -0,0 +1,34 @@
+from typing import Literal
+
+from pydantic import SecretStr
+
+from model_library import model_library_settings
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    LLMConfig,
+)
+from model_library.register_models import register_provider
+
+
+@register_provider("vercel")
+class VercelModel(DelegateOnly):
+    def __init__(
+        self,
+        model_name: str,
+        provider: Literal["vercel"] = "vercel",
+        *,
+        config: LLMConfig | None = None,
+    ):
+        super().__init__(model_name, provider, config=config)
+
+        # https://vercel.com/docs/ai-gateway/sdks-and-apis#quick-start
+        self.init_delegate(
+            config=config,
+            delegate_config=DelegateConfig(
+                base_url="https://ai-gateway.vercel.sh/v1",
+                api_key=SecretStr(model_library_settings.VERCEL_API_KEY),
+            ),
+            use_completions=True,
+            delegate_provider="openai",
+        )
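Note: the two new gateway providers (vercel here, xiaomi further down) are thin DelegateOnly wrappers over an OpenAI-compatible chat-completions endpoint, so constructing one only takes a model name. A hedged usage sketch (the model key is a placeholder, and the corresponding *_API_KEY setting must be configured):

    from model_library.providers.vercel import VercelModel

    model = VercelModel("some-gateway-model")  # placeholder model name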
@@ -2,7 +2,7 @@ import io
 import logging
 from typing import Any, Literal, Sequence
 
-import grpc
+from pydantic import SecretStr
 from typing_extensions import override
 from xai_sdk import AsyncClient
 from xai_sdk.aio.chat import Chat
@@ -14,6 +14,7 @@ from xai_sdk.proto.v6.chat_pb2 import Message, Tool
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    DelegateConfig,
     FileBase,
     FileInput,
     FileWithBase64,
@@ -36,24 +37,26 @@ from model_library.exceptions import (
     MaxOutputTokensExceededError,
     ModelNoOutputError,
     NoMatchingToolCallError,
-    RateLimitException,
 )
 from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("grok")
 class XAIModel(LLM):
-    _client: AsyncClient | None = None
+    @override
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.XAI_API_KEY
 
     @override
-    def get_client(self) -> AsyncClient:
-        if not XAIModel._client:
-            XAIModel._client = AsyncClient(
-                api_key=model_library_settings.XAI_API_KEY,
+    def get_client(self, api_key: str | None = None) -> AsyncClient:
+        if not self.has_client():
+            assert api_key
+            client = AsyncClient(
+                api_key=api_key,
             )
-        return XAIModel._client
+            self.assign_client(client)
+        return super().get_client()
 
     @override
     def __init__(
@@ -73,13 +76,13 @@ class XAIModel(LLM):
             model_name=self.model_name,
             provider=provider,
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.XAI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url=(
                     "https://us-west-1.api.x.ai/v1"
                     if "grok-3-mini-reasoning" in self.model_name
                     else "https://api.x.ai/v1"
                 ),
+                api_key=SecretStr(model_library_settings.XAI_API_KEY),
             ),
             use_completions=True,
         )
@@ -210,12 +213,14 @@ class XAIModel(LLM):
             messages.append(system(str(kwargs.pop("system_prompt"))))
 
         body: dict[str, Any] = {
-            "max_tokens": self.max_tokens,
             "model": self.model_name,
             "tools": await self.parse_tools(tools),
             "messages": messages,
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature
@@ -253,38 +258,35 @@ class XAIModel(LLM):
 
         body = await self.build_body(input, tools=tools, **kwargs)
 
-        try:
-            chat: Chat = self.get_client().chat.create(**body)
-
-            latest_response: Response | None = None
-            async for response, _ in chat.stream():
-                latest_response = response
-
-            if not latest_response:
-                raise ModelNoOutputError("Model failed to produce a response")
-
-            tool_calls: list[ToolCall] = []
-            if (
-                latest_response.finish_reason == "REASON_TOOL_CALLS"
-                and latest_response.tool_calls
-            ):
-                for tool_call in latest_response.tool_calls:
-                    tool_calls.append(
-                        ToolCall(
-                            id=tool_call.id,
-                            name=tool_call.function.name,
-                            args=tool_call.function.arguments,
-                        )
+        chat: Chat = self.get_client().chat.create(**body)
+
+        latest_response: Response | None = None
+        async for response, _ in chat.stream():
+            latest_response = response
+
+        if not latest_response:
+            raise ModelNoOutputError("Model failed to produce a response")
+
+        tool_calls: list[ToolCall] = []
+        if (
+            latest_response.finish_reason == "REASON_TOOL_CALLS"
+            and latest_response.tool_calls
+        ):
+            for tool_call in latest_response.tool_calls:
+                tool_calls.append(
+                    ToolCall(
+                        id=tool_call.id,
+                        name=tool_call.function.name,
+                        args=tool_call.function.arguments,
                     )
+                )
 
-            if (
-                latest_response.finish_reason == "REASON_MAX_LEN"
-                and not latest_response.content
-                and not latest_response.reasoning_content
-            ):
-                raise MaxOutputTokensExceededError()
-        except grpc.RpcError as e:
-            raise RateLimitException(e.details())
+        if (
+            latest_response.finish_reason == "REASON_MAX_LEN"
+            and not latest_response.content
+            and not latest_response.reasoning_content
+        ):
+            raise MaxOutputTokensExceededError()
 
         return QueryResult(
             output_text=latest_response.content,
@@ -310,6 +312,9 @@ class XAIModel(LLM):
         tools: list[ToolDefinition] = [],
         **kwargs: object,
     ) -> int:
+        if not input and not history:
+            return 0
+
         string_input = await self.stringify_input(input, history=history, tools=tools)
         self.logger.debug(string_input)
 
@@ -0,0 +1,34 @@
+from typing import Literal
+
+from pydantic import SecretStr
+
+from model_library import model_library_settings
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    LLMConfig,
+)
+from model_library.register_models import register_provider
+
+
+@register_provider("xiaomi")
+class XiaomiModel(DelegateOnly):
+    def __init__(
+        self,
+        model_name: str,
+        provider: Literal["xiaomi"] = "xiaomi",
+        *,
+        config: LLMConfig | None = None,
+    ):
+        super().__init__(model_name, provider, config=config)
+
+        # https://platform.xiaomimimo.com/#/docs/quick-start/first-api-call
+        self.init_delegate(
+            config=config,
+            delegate_config=DelegateConfig(
+                base_url="https://api.xiaomimimo.com/v1",
+                api_key=SecretStr(model_library_settings.XIAOMI_API_KEY),
+            ),
+            use_completions=True,
+            delegate_provider="openai",
+        )
@@ -1,17 +1,36 @@
-from typing import Literal
+from typing import Any, Literal
+
+from pydantic import SecretStr
+from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
+    ProviderConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
+
+
+class ZAIConfig(ProviderConfig):
+    """Configuration for ZAI (GLM) models.
+
+    Attributes:
+        clear_thinking: When disabled, reasoning content from previous turns is
+            preserved in context. This is useful for multi-turn conversations where
+            you want the model to maintain coherent reasoning across turns.
+            Enabled by default on the standard API endpoint.
+            See: https://docs.z.ai/guides/capabilities/thinking-mode
+    """
+
+    clear_thinking: bool = True
 
 
 @register_provider("zai")
 class ZAIModel(DelegateOnly):
+    provider_config = ZAIConfig()
+
     def __init__(
         self,
         model_name: str,
@@ -21,14 +40,25 @@ class ZAIModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)
 
+        self.clear_thinking = self.provider_config.clear_thinking
+
         # https://docs.z.ai/guides/develop/openai/python
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.ZAI_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://open.bigmodel.cn/api/paas/v4/",
+                api_key=SecretStr(model_library_settings.ZAI_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
+
+    @override
+    def _get_extra_body(self) -> dict[str, Any]:
+        """Build extra body parameters for GLM-specific features."""
+        return {
+            "thinking": {
+                "type": "enabled" if self.reasoning else "disabled",
+                "clear_thinking": self.clear_thinking,
+            }
+        }
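Note: per the _get_extra_body override above, the OpenAI-compatible request to the GLM endpoint carries a "thinking" block in its extra body. A small sketch of the resulting payload for both reasoning settings (values follow directly from the code in the hunk; the helper name is illustrative only):

    def glm_extra_body(reasoning: bool, clear_thinking: bool = True) -> dict:
        # Mirrors ZAIModel._get_extra_body from the hunk above
        return {
            "thinking": {
                "type": "enabled" if reasoning else "disabled",
                "clear_thinking": clear_thinking,
            }
        }

    assert glm_extra_body(True) == {"thinking": {"type": "enabled", "clear_thinking": True}}
    assert glm_extra_body(False)["thinking"]["type"] == "disabled"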
@@ -170,6 +170,7 @@ class DefaultParameters(BaseModel):
     top_p: float | None = None
     top_k: int | None = None
     reasoning_effort: str | bool | None = None
+    compute_effort: str | bool | None = None
 
 
 class RawModelConfig(BaseModel):
@@ -338,6 +339,10 @@ def _register_models() -> ModelRegistry:
         copy.slug = key.replace("/", "_")
         copy.full_key = key
         copy.alternative_keys = []
+        copy.provider_properties = ProviderProperties.model_validate(
+            provider_properties
+        )
+
         registry[key] = copy
 
     return registry
@@ -1,11 +1,15 @@
+import tiktoken
 from functools import cache
 from pathlib import Path
 from typing import TypedDict
 
-import tiktoken
-
-from model_library.base import LLM, LLMConfig, ProviderConfig
-from model_library.base.output import QueryResultCost, QueryResultMetadata
+from model_library.base import (
+    LLM,
+    LLMConfig,
+    ProviderConfig,
+    QueryResultCost,
+    QueryResultMetadata,
+)
 from model_library.register_models import (
     CostProperties,
     ModelConfig,
@@ -196,19 +200,46 @@ def get_provider_names() -> list[str]:
 
 
 @cache
-def get_model_names() -> list[str]:
-    """Return all model names in the registry"""
-    return sorted([model_name for model_name in get_model_registry().keys()])
+def get_model_names(
+    provider: str | None = None,
+    include_deprecated: bool = False,
+    include_alt_keys: bool = True,
+) -> list[str]:
+    """
+    Return model names in the registry
+    - provider: Filter by provider name
+    - include_deprecated: Include deprecated models
+    - include_alt_keys: Include alternative keys from the same provider
+    """
+    registry = get_model_registry()
+    alternative_keys_set: set[str] = set()
+
+    if not include_alt_keys:
+        for model in registry.values():
+            for alt_item in model.alternative_keys:
+                alt_key = (
+                    alt_item if isinstance(alt_item, str) else list(alt_item.keys())[0]
+                )
+                if alt_key.split("/")[0] == model.provider_name:
+                    alternative_keys_set.add(alt_key)
+
+    return sorted(
+        [
+            model.full_key
+            for model in get_model_registry().values()
+            if (not provider or model.provider_name.lower() == provider.lower())
+            and (not model.metadata.deprecated or include_deprecated)
+            and model.full_key not in alternative_keys_set
+        ]
+    )
 
 
-@cache
-def get_model_names_by_provider(provider_name: str) -> list[str]:
-    """Return all models in the registry from a provider"""
-    return [
-        model.full_key
-        for model in get_model_registry().values()
-        if model.provider_name.lower() == provider_name.lower()
-    ]
+"""
+everything below this comment is included for legacy support of caselaw/corpfin custom models.
+@orestes please remove this as part of the migration to a standard CorpFin harness.
+"""
+
+DEFAULT_CONTEXT_WINDOW = 128_000
 
 
 @cache
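Note: the old get_model_names_by_provider helper is folded into get_model_names, which now also filters deprecated models and provider-local alternative keys. A hedged usage sketch (import path inferred from registry_utils.py in the file list above; returned keys depend on the registry contents):

    from model_library.registry_utils import get_model_names

    all_keys = get_model_names()                      # non-deprecated models, alternative keys included
    openai_keys = get_model_names(provider="openai")  # replaces get_model_names_by_provider("openai")
    everything = get_model_names(include_deprecated=True, include_alt_keys=True)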
@@ -233,7 +264,7 @@ def auto_trim_document(
         Trimmed document, or original document if trimming isn't needed
     """
 
-    max_tokens = get_max_document_tokens(model_name)
+    max_tokens = get_max_document_tokens(model_name) or DEFAULT_CONTEXT_WINDOW
     encoding = _get_tiktoken_encoder()
     tokens = encoding.encode(document)
  tokens = encoding.encode(document)
@@ -260,5 +291,5 @@ def get_max_document_tokens(model_name: str, output_buffer: int = 10000) -> int:
     # Import here to avoid circular imports
     from model_library.utils import get_context_window_for_model
 
-    context_window = get_context_window_for_model(model_name)
+    context_window = get_context_window_for_model(model_name) or DEFAULT_CONTEXT_WINDOW
     return context_window - output_buffer
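Note: with the new DEFAULT_CONTEXT_WINDOW fallback, a model whose context window is unknown now trims against 128,000 tokens instead of failing; with the default output_buffer of 10,000 that leaves 118,000 tokens for the document itself. A one-line check of that arithmetic:

    DEFAULT_CONTEXT_WINDOW = 128_000
    output_buffer = 10_000
    assert DEFAULT_CONTEXT_WINDOW - output_buffer == 118_000  # max document tokens when the window is unknown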
File without changes
@@ -0,0 +1,73 @@
+import logging
+from typing import Callable
+
+from model_library.base.base import QueryResult
+from model_library.exceptions import exception_message
+from model_library.retriers.base import BaseRetrier
+from model_library.retriers.utils import jitter
+
+RETRY_MAX_TRIES: int = 20
+RETRY_INITIAL: float = 10.0
+RETRY_EXPO: float = 1.4
+RETRY_MAX_BACKOFF_WAIT: float = 240.0
+
+
+class ExponentialBackoffRetrier(BaseRetrier):
+    """
+    Exponential backoff retry strategy.
+    Uses exponential backoff with jitter for wait times.
+    """
+
+    def __init__(
+        self,
+        logger: logging.Logger,
+        max_tries: int = RETRY_MAX_TRIES,
+        max_time: float | None = None,
+        retry_callback: Callable[[int, Exception | None, float, float], None]
+        | None = None,
+        *,
+        initial: float = RETRY_INITIAL,
+        expo: float = RETRY_EXPO,
+        max_backoff_wait: float = RETRY_MAX_BACKOFF_WAIT,
+    ):
+        super().__init__(
+            strategy="backoff",
+            logger=logger,
+            max_tries=max_tries,
+            max_time=max_time,
+            retry_callback=retry_callback,
+        )
+
+        self.initial = initial
+        self.expo = expo
+        self.max_backoff_wait = max_backoff_wait
+
+    async def _calculate_wait_time(
+        self, attempt: int, exception: Exception | None = None
+    ) -> float:
+        """Calculate exponential backoff wait time with jitter"""
+
+        exponential_wait = self.initial * (self.expo**attempt)
+        capped_wait = min(exponential_wait, self.max_backoff_wait)
+        return jitter(capped_wait)
+
+    async def _on_retry(
+        self, exception: Exception | None, elapsed: float, wait_time: float
+    ) -> None:
+        """Increment attempt counter and log retry attempt"""
+
+        logger_msg = f"[Retry] | {self.strategy} | Attempt: {self.attempts} | Elapsed: {elapsed:.1f}s | Next wait: {wait_time:.1f}s | Exception: {exception_message(exception)} "
+
+        self.logger.warning(logger_msg)
+
+        if self.retry_callback:
+            self.retry_callback(self.attempts, exception, elapsed, wait_time)
+
+    async def _pre_function(self) -> None:
+        return
+
+    async def _post_function(self, result: tuple[QueryResult, float]) -> None:
+        return
+
+    async def validate(self) -> None:
+        return
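Note: for reference, the deterministic part of the schedule implied by the defaults above (initial 10 s, factor 1.4, cap 240 s) before jitter is applied; the helper below simply restates the _calculate_wait_time formula:

    RETRY_INITIAL, RETRY_EXPO, RETRY_MAX_BACKOFF_WAIT = 10.0, 1.4, 240.0

    def wait_before_jitter(attempt: int) -> float:
        # Same formula as ExponentialBackoffRetrier._calculate_wait_time, minus the jitter step
        return min(RETRY_INITIAL * (RETRY_EXPO ** attempt), RETRY_MAX_BACKOFF_WAIT)

    print([round(wait_before_jitter(a), 1) for a in range(11)])
    # [10.0, 14.0, 19.6, 27.4, 38.4, 53.8, 75.3, 105.4, 147.6, 206.6, 240.0]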