model-library 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. model_library/base/__init__.py +7 -0
  2. model_library/{base.py → base/base.py} +58 -429
  3. model_library/base/batch.py +121 -0
  4. model_library/base/delegate_only.py +94 -0
  5. model_library/base/input.py +100 -0
  6. model_library/base/output.py +229 -0
  7. model_library/base/utils.py +43 -0
  8. model_library/config/ai21labs_models.yaml +1 -0
  9. model_library/config/all_models.json +461 -36
  10. model_library/config/anthropic_models.yaml +30 -3
  11. model_library/config/deepseek_models.yaml +3 -1
  12. model_library/config/google_models.yaml +49 -0
  13. model_library/config/openai_models.yaml +43 -4
  14. model_library/config/together_models.yaml +1 -0
  15. model_library/config/xai_models.yaml +63 -3
  16. model_library/exceptions.py +8 -2
  17. model_library/file_utils.py +1 -1
  18. model_library/providers/__init__.py +0 -0
  19. model_library/providers/ai21labs.py +2 -0
  20. model_library/providers/alibaba.py +16 -78
  21. model_library/providers/amazon.py +3 -0
  22. model_library/providers/anthropic.py +215 -8
  23. model_library/providers/azure.py +2 -0
  24. model_library/providers/cohere.py +14 -80
  25. model_library/providers/deepseek.py +14 -90
  26. model_library/providers/fireworks.py +17 -81
  27. model_library/providers/google/google.py +55 -47
  28. model_library/providers/inception.py +15 -83
  29. model_library/providers/kimi.py +15 -83
  30. model_library/providers/mistral.py +2 -0
  31. model_library/providers/openai.py +10 -2
  32. model_library/providers/perplexity.py +12 -79
  33. model_library/providers/together.py +19 -210
  34. model_library/providers/vals.py +2 -0
  35. model_library/providers/xai.py +2 -0
  36. model_library/providers/zai.py +15 -83
  37. model_library/register_models.py +75 -57
  38. model_library/registry_utils.py +5 -5
  39. model_library/utils.py +3 -28
  40. {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/METADATA +2 -3
  41. model_library-0.1.3.dist-info/RECORD +61 -0
  42. model_library-0.1.1.dist-info/RECORD +0 -54
  43. {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/WHEEL +0 -0
  44. {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/licenses/LICENSE +0 -0
  45. {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/top_level.txt +0 -0

model_library/providers/together.py

@@ -1,49 +1,27 @@
-import io
-from typing import Any, Literal, Sequence, cast
+from typing import Literal
 
-from together import AsyncTogether
-from together.types.chat_completions import (
-    ChatCompletionMessage,
-    ChatCompletionResponse,
-)
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
-    LLM,
-    FileInput,
-    FileWithBase64,
-    FileWithId,
-    FileWithUrl,
-    InputItem,
+    DelegateOnly,
     LLMConfig,
-    QueryResult,
+    ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
-    TextInput,
-    ToolDefinition,
-)
-from model_library.exceptions import (
-    BadInputError,
-    MaxOutputTokensExceededError,
-    ModelNoOutputError,
 )
-from model_library.file_utils import trim_images
-from model_library.model_utils import get_reasoning_in_tag
 from model_library.providers.openai import OpenAIModel
+from model_library.register_models import register_provider
 from model_library.utils import create_openai_client_with_defaults
 
 
-class TogetherModel(LLM):
-    _client: AsyncTogether | None = None
+class TogetherConfig(ProviderConfig):
+    serverless: bool = True
 
-    @override
-    def get_client(self) -> AsyncTogether:
-        if not TogetherModel._client:
-            TogetherModel._client = AsyncTogether(
-                api_key=model_library_settings.TOGETHER_API_KEY,
-            )
-        return TogetherModel._client
+
+@register_provider("together")
+class TogetherModel(DelegateOnly):
+    provider_config = TogetherConfig()
 
     def __init__(
         self,
@@ -53,187 +31,18 @@ class TogetherModel(LLM):
         config: LLMConfig | None = None,
     ):
         super().__init__(model_name, provider, config=config)
-
         # https://docs.together.ai/docs/openai-api-compatibility
-        self.delegate: OpenAIModel | None = (
-            None
-            if self.native
-            else OpenAIModel(
-                model_name=model_name,
-                provider=provider,
-                config=config,
-                custom_client=create_openai_client_with_defaults(
-                    api_key=model_library_settings.TOGETHER_API_KEY,
-                    base_url="https://api.together.xyz/v1",
-                ),
-                use_completions=False,
-            )
-        )
-
-    @override
-    async def parse_input(
-        self,
-        input: Sequence[InputItem],
-        **kwargs: Any,
-    ) -> list[dict[str, Any] | Any]:
-        new_input: list[dict[str, Any] | Any] = []
-        content_user: list[dict[str, Any]] = []
-
-        def flush_content_user():
-            nonlocal content_user
-
-            if content_user:
-                new_input.append({"role": "user", "content": content_user})
-                content_user = []
-
-        for item in input:
-            match item:
-                case TextInput():
-                    content_user.append({"type": "text", "text": item.text})
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    match item.type:
-                        case "image":
-                            content_user.append(await self.parse_image(item))
-                        case "file":
-                            content_user.append(await self.parse_file(item))
-                case ChatCompletionMessage():
-                    flush_content_user()
-                    new_input.append(item)
-                case _:
-                    raise BadInputError("Unsupported input type")
-
-        flush_content_user()
-
-        return new_input
-
-    @override
-    async def parse_image(
-        self,
-        image: FileInput,
-    ) -> dict[str, Any]:
-        match image:
-            case FileWithBase64():
-                return {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/{image.mime};base64,{image.base64}"
-                    },
-                }
-            case _:
-                # docs show that we can pass in s3 location somehow
-                raise BadInputError("Unsupported image type")
-
-    @override
-    async def parse_file(
-        self,
-        file: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_tools(
-        self,
-        tools: list[ToolDefinition],
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def upload_file(
-        self,
-        name: str,
-        mime: str,
-        bytes: io.BytesIO,
-        type: Literal["image", "file"] = "file",
-    ) -> FileWithId:
-        raise NotImplementedError()
-
-    @override
-    async def _query_impl(
-        self,
-        input: Sequence[InputItem],
-        *,
-        tools: list[ToolDefinition],
-        **kwargs: object,
-    ) -> QueryResult:
-        if self.delegate:
-            return await self.delegate_query(input, tools=tools, **kwargs)
-
-        # llama supports max 5 images
-        if "lama-4" in self.model_name:
-            input = trim_images(input, max_images=5)
-
-        messages: list[dict[str, Any]] = []
-
-        if "nemotron-super" in self.model_name:
-            # move system prompt to prompt
-            if "system_prompt" in kwargs:
-                first_text_item = next(
-                    (item for item in input if isinstance(item, TextInput)), None
-                )
-                if not first_text_item:
-                    raise Exception(
-                        "Given system prompt for nemotron-super model, but no text input found"
-                    )
-                system_prompt = kwargs.pop("system_prompt")
-                first_text_item.text = f"SYSTEM PROMPT: {system_prompt}\nUSER PROMPT: {first_text_item.text}"
-
-            # set system prompt to detailed thinking
-            mode = "on" if self.reasoning else "off"
-            kwargs["system_prompt"] = f"detailed thinking {mode}"
-            messages.append(
-                {
-                    "role": "system",
-                    "content": f"detailed thinking {mode}",
-                }
-            )
-
-        if "system_prompt" in kwargs:
-            messages.append({"role": "system", "content": kwargs.pop("system_prompt")})
-
-        messages.extend(await self.parse_input(input))
-
-        body: dict[str, Any] = {
-            "max_tokens": self.max_tokens,
-            "model": self.model_name,
-            "messages": messages,
-        }
-
-        if self.supports_temperature:
-            if self.temperature is not None:
-                body["temperature"] = self.temperature
-            if self.top_p is not None:
-                body["top_p"] = self.top_p
-
-        body.update(kwargs)
-
-        response = await self.get_client().chat.completions.create(**body, stream=False)  # pyright: ignore[reportAny]
-
-        response = cast(ChatCompletionResponse, response)
-
-        if not response or not response.choices or not response.choices[0].message:
-            raise ModelNoOutputError("Model returned no completions")
-
-        text = str(response.choices[0].message.content)
-        reasoning = None
-
-        if response.choices[0].finish_reason == "length" and not text:
-            raise MaxOutputTokensExceededError()
-
-        if self.reasoning:
-            text, reasoning = get_reasoning_in_tag(text)
-
-        output = QueryResult(
-            output_text=text,
-            reasoning=reasoning,
-            history=[*input, response.choices[0].message],
+        self.delegate = OpenAIModel(
+            model_name=self.model_name,
+            provider=self.provider,
+            config=config,
+            custom_client=create_openai_client_with_defaults(
+                api_key=model_library_settings.TOGETHER_API_KEY,
+                base_url="https://api.together.xyz/v1",
+            ),
+            use_completions=True,
         )
 
-        if response.usage:
-            output.metadata.in_tokens = response.usage.prompt_tokens
-            output.metadata.out_tokens = response.usage.completion_tokens
-            # no cache tokens it seems
-        return output
-
     @override
     async def _calculate_cost(
         self,
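
Note: with this rework, every Together request is served through the OpenAI-compatible delegate rather than the together SDK. A minimal sketch of what that means for callers; the model id below is a hypothetical example:

    from model_library.providers.openai import OpenAIModel
    from model_library.providers.together import TogetherModel

    # Keyword construction mirrors how registry_utils instantiates providers.
    model = TogetherModel(
        model_name="meta-llama/Llama-3.3-70B-Instruct-Turbo",  # hypothetical id
        provider="together",
    )
    # The constructor now unconditionally builds an OpenAI-compatible delegate:
    assert isinstance(model.delegate, OpenAIModel)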

model_library/providers/vals.py

@@ -27,6 +27,7 @@ from model_library.base import (
     TextInput,
     ToolDefinition,
 )
+from model_library.register_models import register_provider
 from model_library.utils import truncate_str
 
 FAIL_RATE = 0.1
@@ -145,6 +146,7 @@ class DummyAIBatchMixin(LLMBatchMixin):
         return batch_status == "failed"
 
 
+@register_provider("vals")
 class DummyAIModel(LLM):
     _client: Redis | None = None
 

model_library/providers/xai.py

@@ -39,6 +39,7 @@ from model_library.exceptions import (
     RateLimitException,
 )
 from model_library.providers.openai import OpenAIModel
+from model_library.register_models import register_provider
 from model_library.utils import create_openai_client_with_defaults
 
 Chat = AsyncChat | SyncChat
@@ -48,6 +49,7 @@ class XAIConfig(ProviderConfig):
     sync_client: bool = False
 
 
+@register_provider("grok")
 class XAIModel(LLM):
     provider_config = XAIConfig()
 
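
Note: the registry key is the provider string used in model identifiers, not the class name, so XAIModel registers under "grok". A quick check against the new lazy registry:

    from model_library.register_models import get_provider_registry

    registry = get_provider_registry()  # imports provider modules on first call
    assert registry["grok"].__name__ == "XAIModel"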

model_library/providers/zai.py

@@ -1,27 +1,17 @@
-import io
-from typing import Any, Literal, Sequence
-
-from typing_extensions import override
+from typing import Literal
 
 from model_library import model_library_settings
 from model_library.base import (
-    LLM,
-    FileInput,
-    FileWithId,
-    InputItem,
+    DelegateOnly,
     LLMConfig,
-    QueryResult,
-    ToolDefinition,
 )
 from model_library.providers.openai import OpenAIModel
+from model_library.register_models import register_provider
 from model_library.utils import create_openai_client_with_defaults
 
 
-class ZAIModel(LLM):
-    @override
-    def get_client(self) -> None:
-        raise NotImplementedError("Not implemented")
-
+@register_provider("zai")
+class ZAIModel(DelegateOnly):
     def __init__(
         self,
         model_name: str,
@@ -30,73 +20,15 @@ class ZAIModel(LLM):
         config: LLMConfig | None = None,
     ):
         super().__init__(model_name, provider, config=config)
-        self.model_name: str = model_name
-        self.native: bool = False
 
-        # https://docs.z.ai/
-        self.delegate: OpenAIModel | None = (
-            None
-            if self.native
-            else OpenAIModel(
-                model_name=self.model_name,
-                provider=provider,
-                config=config,
-                custom_client=create_openai_client_with_defaults(
-                    api_key=model_library_settings.ZAI_API_KEY,
-                    base_url="https://open.bigmodel.cn/api/paas/v4/",
-                ),
-                use_completions=True,
-            )
+        # https://docs.z.ai/guides/develop/openai/python
+        self.delegate = OpenAIModel(
+            model_name=self.model_name,
+            provider=self.provider,
+            config=config,
+            custom_client=create_openai_client_with_defaults(
+                api_key=model_library_settings.ZAI_API_KEY,
+                base_url="https://open.bigmodel.cn/api/paas/v4/",
+            ),
+            use_completions=True,
         )
-
-    @override
-    async def parse_input(
-        self,
-        input: Sequence[InputItem],
-        **kwargs: Any,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_image(
-        self,
-        image: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_file(
-        self,
-        file: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_tools(
-        self,
-        tools: list[ToolDefinition],
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def upload_file(
-        self,
-        name: str,
-        mime: str,
-        bytes: io.BytesIO,
-        type: Literal["image", "file"] = "file",
-    ) -> FileWithId:
-        raise NotImplementedError()
-
-    @override
-    async def _query_impl(
-        self,
-        input: Sequence[InputItem],
-        *,
-        tools: list[ToolDefinition],
-        **kwargs: object,
-    ) -> QueryResult:
-        # relies on oAI delegate
-        if self.delegate:
-            return await self.delegate_query(input, tools=tools, **kwargs)
-        raise NotImplementedError()
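
Note: ZAIModel and TogetherModel now share one shape: a DelegateOnly subclass whose only job is wiring up an OpenAI-compatible delegate. A sketch of a new provider in this style; the provider key, base URL, and ACME_API_KEY setting are all hypothetical:

    from model_library import model_library_settings
    from model_library.base import DelegateOnly, LLMConfig
    from model_library.providers.openai import OpenAIModel
    from model_library.register_models import register_provider
    from model_library.utils import create_openai_client_with_defaults


    @register_provider("acme")  # hypothetical provider key
    class AcmeModel(DelegateOnly):
        def __init__(
            self,
            model_name: str,
            provider: str,
            config: LLMConfig | None = None,
        ):
            super().__init__(model_name, provider, config=config)
            self.delegate = OpenAIModel(
                model_name=self.model_name,
                provider=self.provider,
                config=config,
                custom_client=create_openai_client_with_defaults(
                    api_key=model_library_settings.ACME_API_KEY,  # hypothetical setting
                    base_url="https://api.acme.example/v1",  # hypothetical endpoint
                ),
                use_completions=True,
            )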

model_library/register_models.py

@@ -1,61 +1,23 @@
+import importlib
+import pkgutil
 import threading
 from copy import deepcopy
 from datetime import date
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast, get_type_hints
+from typing import Any, Callable, Type, TypeVar, cast, get_type_hints
 
 import yaml
 from pydantic import create_model, model_validator
 from pydantic.fields import Field
 from pydantic.main import BaseModel
 
+from model_library import providers
 from model_library.base import LLM, ProviderConfig
-from model_library.providers.ai21labs import AI21LabsModel
-from model_library.providers.alibaba import AlibabaModel
-from model_library.providers.amazon import AmazonModel
-from model_library.providers.anthropic import AnthropicModel
-from model_library.providers.azure import AzureOpenAIModel
-from model_library.providers.cohere import CohereModel
-from model_library.providers.deepseek import DeepSeekModel
-from model_library.providers.fireworks import FireworksModel
-from model_library.providers.google.google import GoogleModel
-from model_library.providers.inception import MercuryModel
-from model_library.providers.kimi import KimiModel
-from model_library.providers.mistral import MistralModel
-from model_library.providers.openai import OpenAIModel
-from model_library.providers.perplexity import PerplexityModel
-from model_library.providers.together import TogetherModel
-from model_library.providers.vals import DummyAIModel
-from model_library.providers.xai import XAIModel
-from model_library.providers.zai import ZAIModel
 from model_library.utils import get_logger
 
-MAPPING_PROVIDERS: dict[str, type[LLM]] = {
-    "openai": OpenAIModel,
-    "azure": AzureOpenAIModel,
-    "anthropic": AnthropicModel,
-    "together": TogetherModel,
-    "mistralai": MistralModel,
-    "grok": XAIModel,
-    "fireworks": FireworksModel,
-    "ai21labs": AI21LabsModel,
-    "amazon": AmazonModel,
-    "bedrock": AmazonModel,
-    "cohere": CohereModel,
-    "google": GoogleModel,
-    "vals": DummyAIModel,
-    "alibaba": AlibabaModel,
-    "perplexity": PerplexityModel,
-    "deepseek": DeepSeekModel,
-    "zai": ZAIModel,
-    "kimi": KimiModel,
-    "inception": MercuryModel,
-}
-
-logger = get_logger(__name__)
-# Folder containing provider YAMLs
-path_library = Path(__file__).parent / "config"
+T = TypeVar("T", bound=LLM)
 
+logger = get_logger("register_models")
 
 """
 Model Registry structure
@@ -174,14 +136,13 @@ class ClassProperties(BaseModel):
 Each provider can have a set of provider-specific properties, we however want to accept
 any possible property from a provider in the yaml, and validate later. So we join all
 provider-specific properties into a single class.
+This has no effect on runtime use of ProviderConfig, only used to load the yaml
 """
 
 
 class BaseProviderProperties(BaseModel):
     """Static base class for dynamic ProviderProperties."""
 
-    pass
-
 
 def all_subclasses(cls: type) -> list[type]:
     """Recursively find all subclasses of a class."""
@@ -210,14 +171,6 @@ def get_dynamic_provider_properties_model() -> type[BaseProviderProperties]:
     )
 
 
-ProviderProperties = get_dynamic_provider_properties_model()
-
-if TYPE_CHECKING:
-    ProviderPropertiesType = BaseProviderProperties
-else:
-    ProviderPropertiesType = ProviderProperties
-
-
 class DefaultParameters(BaseModel):
     max_output_tokens: int | None = None
     temperature: float | None = None
@@ -234,13 +187,20 @@ class RawModelConfig(BaseModel):
     documentation_url: str | None = None
     properties: Properties = Field(default_factory=Properties)
     class_properties: ClassProperties = Field(default_factory=ClassProperties)
-    provider_properties: ProviderPropertiesType = Field(
-        default_factory=ProviderProperties
-    )
+    provider_properties: BaseProviderProperties | None = None
     costs_per_million_token: CostProperties = Field(default_factory=CostProperties)
     alternative_keys: list[str | dict[str, Any]] = Field(default_factory=list)
     default_parameters: DefaultParameters = Field(default_factory=DefaultParameters)
 
+    def model_dump(self, *args: object, **kwargs: object):
+        data = super().model_dump(*args, **kwargs)
+        if self.provider_properties is not None:
+            # explicitly dump dynamic ProviderProperties instance
+            data["provider_properties"] = self.provider_properties.model_dump(
+                *args, **kwargs
+            )
+        return data
+
 
 class ModelConfig(RawModelConfig):
     # post processing fields
@@ -252,6 +212,9 @@ class ModelConfig(RawModelConfig):
 
 ModelRegistry = dict[str, ModelConfig]
 
+# Folder containing provider YAMLs
+path_library = Path(__file__).parent / "config"
+
 
 def deep_update(
     base: dict[str, Any], updates: dict[str, str | dict[str, Any]]
@@ -270,6 +233,9 @@ def _register_models() -> ModelRegistry:
 
     registry: ModelRegistry = {}
 
+    # generate ProviderProperties class
+    ProviderProperties = get_dynamic_provider_properties_model()
+
     # load each provider YAML
     sections = Path(path_library).glob("*.yaml")
     sections = sorted(sections, key=lambda x: "openai" in x.name.lower())
@@ -325,6 +291,10 @@ def _register_models() -> ModelRegistry:
                 "slug": model_name.replace("/", "_"),
             }
         )
+        # load provider properties separately since the model was generated at runtime
+        model_obj.provider_properties = ProviderProperties.model_validate(
+            current_model_config.get("provider_properties", {})
+        )
 
         registry[model_name] = model_obj
 
@@ -371,6 +341,50 @@ def _register_models() -> ModelRegistry:
     return registry
 
 
+_provider_registry: dict[str, type[LLM]] = {}
+_provider_registry_lock = threading.Lock()
+_imported_providers = False
+
+
+def register_provider(name: str) -> Callable[[Type[T]], Type[T]]:
+    def decorator(cls: Type[T]) -> Type[T]:
+        logger.debug(f"Registering provider {name}")
+
+        if name in _provider_registry:
+            raise ValueError(f"Provider {name} is already registered.")
+        _provider_registry[name] = cls
+        return cls
+
+    return decorator
+
+
+def _import_all_providers():
+    """Import all provider modules. Any class with @register_provider will be automatically registered upon import"""
+
+    package_name = providers.__name__
+
+    # walk all submodules recursively
+    for _, module_name, _ in pkgutil.walk_packages(
+        providers.__path__, package_name + "."
+    ):
+        # skip private modules
+        if module_name.split(".")[-1].startswith("_"):
+            continue
+        importlib.import_module(module_name)
+
+
+def get_provider_registry() -> dict[str, type[LLM]]:
+    """Return the provider registry, lazily loading all modules on first call."""
+    global _imported_providers
+    if not _imported_providers:
+        with _provider_registry_lock:
+            if not _imported_providers:
+                _import_all_providers()
+                _imported_providers = True
+
+    return _provider_registry
+
+
 _model_registry: ModelRegistry | None = None
 _model_registry_lock = threading.Lock()
 
@@ -381,5 +395,9 @@ def get_model_registry() -> ModelRegistry:
     if _model_registry is None:
         with _model_registry_lock:
             if _model_registry is None:
+                # initialize provider registry
+                global get_provider_registry
+                get_provider_registry()
+
                 _model_registry = _register_models()
     return _model_registry
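
Note: registration is now a side effect of class definition, and get_provider_registry() lazily imports everything under model_library.providers behind a double-checked lock, so new providers need no central mapping edit. A sketch of registering an out-of-tree provider; the name and class are hypothetical, and a real subclass must still implement LLM's abstract interface:

    from model_library.base import LLM
    from model_library.register_models import get_provider_registry, register_provider


    @register_provider("my-provider")  # duplicate keys raise ValueError
    class MyModel(LLM):
        ...  # LLM abstract methods elided in this sketch


    # The decorator registered the class at definition time; the first registry
    # call also imports all built-in provider modules.
    assert get_provider_registry()["my-provider"] is MyModel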

model_library/registry_utils.py

@@ -5,10 +5,10 @@ import tiktoken
 
 from model_library.base import LLM, LLMConfig, ProviderConfig
 from model_library.register_models import (
-    MAPPING_PROVIDERS,
     CostProperties,
     ModelConfig,
     get_model_registry,
+    get_provider_registry,
 )
 
 ALL_MODELS_PATH = Path(__file__).parent / "config" / "all_models.json"
@@ -51,7 +51,7 @@ def create_config(
 
     # load provider config with correct type
     if provider_properties:
-        ModelClass: type[LLM] = MAPPING_PROVIDERS[registry_config.provider_name]
+        ModelClass: type[LLM] = get_provider_registry()[registry_config.provider_name]
         if hasattr(ModelClass, "provider_config"):
            ProviderConfigClass: type[ProviderConfig] = type(ModelClass.provider_config)  # type: ignore
            provider_config: ProviderConfig = ProviderConfigClass.model_validate(
@@ -89,7 +89,7 @@ def _get_model_from_registry(
 
     provider_name: str = registry_config.provider_name
     provider_endpoint: str = registry_config.provider_endpoint
-    ModelClass: type[LLM] = MAPPING_PROVIDERS[provider_name]
+    ModelClass: type[LLM] = get_provider_registry()[provider_name]
 
     return ModelClass(
         model_name=provider_endpoint,
@@ -115,7 +115,7 @@ def get_registry_model(model_str: str, override_config: LLMConfig | None = None)
 def get_raw_model(model_str: str, config: LLMConfig | None = None) -> LLM:
     """Get a model exluding default config"""
     provider, model_name = model_str.split("/", 1)
-    ModelClass = MAPPING_PROVIDERS[provider]
+    ModelClass = get_provider_registry()[provider]
     return ModelClass(model_name=model_name, provider=provider, config=config)
 
 
@@ -130,7 +130,7 @@ def get_model_cost(model_str: str) -> CostProperties | None:
 @cache
 def get_provider_names() -> list[str]:
     """Return all provider names in the registry"""
-    return sorted([provider_name for provider_name in MAPPING_PROVIDERS.keys()])
+    return sorted([provider_name for provider_name in get_provider_registry().keys()])
 
 
 @cache
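
Note: every lookup that previously hit MAPPING_PROVIDERS now resolves through get_provider_registry(), so caller-facing behavior is unchanged. The model string below is a hypothetical example:

    from model_library.registry_utils import get_provider_names, get_raw_model

    # Sorted keys of the lazily built provider registry:
    print(get_provider_names())

    # "<provider>/<model_name>": the prefix selects the registered class.
    model = get_raw_model("zai/glm-4.5")  # hypothetical model name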