model-library 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in the supported public registries; it is provided for informational purposes only.
- model_library/base/__init__.py +7 -0
- model_library/{base.py → base/base.py} +58 -429
- model_library/base/batch.py +121 -0
- model_library/base/delegate_only.py +94 -0
- model_library/base/input.py +100 -0
- model_library/base/output.py +229 -0
- model_library/base/utils.py +43 -0
- model_library/config/ai21labs_models.yaml +1 -0
- model_library/config/all_models.json +461 -36
- model_library/config/anthropic_models.yaml +30 -3
- model_library/config/deepseek_models.yaml +3 -1
- model_library/config/google_models.yaml +49 -0
- model_library/config/openai_models.yaml +43 -4
- model_library/config/together_models.yaml +1 -0
- model_library/config/xai_models.yaml +63 -3
- model_library/exceptions.py +8 -2
- model_library/file_utils.py +1 -1
- model_library/providers/__init__.py +0 -0
- model_library/providers/ai21labs.py +2 -0
- model_library/providers/alibaba.py +16 -78
- model_library/providers/amazon.py +3 -0
- model_library/providers/anthropic.py +215 -8
- model_library/providers/azure.py +2 -0
- model_library/providers/cohere.py +14 -80
- model_library/providers/deepseek.py +14 -90
- model_library/providers/fireworks.py +17 -81
- model_library/providers/google/google.py +55 -47
- model_library/providers/inception.py +15 -83
- model_library/providers/kimi.py +15 -83
- model_library/providers/mistral.py +2 -0
- model_library/providers/openai.py +10 -2
- model_library/providers/perplexity.py +12 -79
- model_library/providers/together.py +19 -210
- model_library/providers/vals.py +2 -0
- model_library/providers/xai.py +2 -0
- model_library/providers/zai.py +15 -83
- model_library/register_models.py +75 -57
- model_library/registry_utils.py +5 -5
- model_library/utils.py +3 -28
- {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/METADATA +2 -3
- model_library-0.1.3.dist-info/RECORD +61 -0
- model_library-0.1.1.dist-info/RECORD +0 -54
- {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/WHEEL +0 -0
- {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/top_level.txt +0 -0
model_library/providers/together.py CHANGED

```diff
@@ -1,49 +1,27 @@
-import io
-from typing import Any, Literal, Sequence, cast
+from typing import Literal
 
-from together import AsyncTogether
-from together.types.chat_completions import (
-    ChatCompletionMessage,
-    ChatCompletionResponse,
-)
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
-
-    FileInput,
-    FileWithBase64,
-    FileWithId,
-    FileWithUrl,
-    InputItem,
+    DelegateOnly,
     LLMConfig,
-
+    ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
-    TextInput,
-    ToolDefinition,
-)
-from model_library.exceptions import (
-    BadInputError,
-    MaxOutputTokensExceededError,
-    ModelNoOutputError,
 )
-from model_library.file_utils import trim_images
-from model_library.model_utils import get_reasoning_in_tag
 from model_library.providers.openai import OpenAIModel
+from model_library.register_models import register_provider
 from model_library.utils import create_openai_client_with_defaults
 
 
-class TogetherModel(LLM):
-
+class TogetherConfig(ProviderConfig):
+    serverless: bool = True
 
-
-
-
-
-                api_key=model_library_settings.TOGETHER_API_KEY,
-            )
-        return TogetherModel._client
+
+@register_provider("together")
+class TogetherModel(DelegateOnly):
+    provider_config = TogetherConfig()
 
     def __init__(
         self,
@@ -53,187 +31,18 @@ class TogetherModel(LLM):
         config: LLMConfig | None = None,
     ):
         super().__init__(model_name, provider, config=config)
-
         # https://docs.together.ai/docs/openai-api-compatibility
-        self.delegate
-
-
-
-
-
-
-
-
-                base_url="https://api.together.xyz/v1",
-            ),
-            use_completions=False,
-        )
-        )
-
-    @override
-    async def parse_input(
-        self,
-        input: Sequence[InputItem],
-        **kwargs: Any,
-    ) -> list[dict[str, Any] | Any]:
-        new_input: list[dict[str, Any] | Any] = []
-        content_user: list[dict[str, Any]] = []
-
-        def flush_content_user():
-            nonlocal content_user
-
-            if content_user:
-                new_input.append({"role": "user", "content": content_user})
-                content_user = []
-
-        for item in input:
-            match item:
-                case TextInput():
-                    content_user.append({"type": "text", "text": item.text})
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    match item.type:
-                        case "image":
-                            content_user.append(await self.parse_image(item))
-                        case "file":
-                            content_user.append(await self.parse_file(item))
-                case ChatCompletionMessage():
-                    flush_content_user()
-                    new_input.append(item)
-                case _:
-                    raise BadInputError("Unsupported input type")
-
-        flush_content_user()
-
-        return new_input
-
-    @override
-    async def parse_image(
-        self,
-        image: FileInput,
-    ) -> dict[str, Any]:
-        match image:
-            case FileWithBase64():
-                return {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/{image.mime};base64,{image.base64}"
-                    },
-                }
-            case _:
-                # docs show that we can pass in s3 location somehow
-                raise BadInputError("Unsupported image type")
-
-    @override
-    async def parse_file(
-        self,
-        file: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_tools(
-        self,
-        tools: list[ToolDefinition],
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def upload_file(
-        self,
-        name: str,
-        mime: str,
-        bytes: io.BytesIO,
-        type: Literal["image", "file"] = "file",
-    ) -> FileWithId:
-        raise NotImplementedError()
-
-    @override
-    async def _query_impl(
-        self,
-        input: Sequence[InputItem],
-        *,
-        tools: list[ToolDefinition],
-        **kwargs: object,
-    ) -> QueryResult:
-        if self.delegate:
-            return await self.delegate_query(input, tools=tools, **kwargs)
-
-        # llama supports max 5 images
-        if "lama-4" in self.model_name:
-            input = trim_images(input, max_images=5)
-
-        messages: list[dict[str, Any]] = []
-
-        if "nemotron-super" in self.model_name:
-            # move system prompt to prompt
-            if "system_prompt" in kwargs:
-                first_text_item = next(
-                    (item for item in input if isinstance(item, TextInput)), None
-                )
-                if not first_text_item:
-                    raise Exception(
-                        "Given system prompt for nemotron-super model, but no text input found"
-                    )
-                system_prompt = kwargs.pop("system_prompt")
-                first_text_item.text = f"SYSTEM PROMPT: {system_prompt}\nUSER PROMPT: {first_text_item.text}"
-
-            # set system prompt to detailed thinking
-            mode = "on" if self.reasoning else "off"
-            kwargs["system_prompt"] = f"detailed thinking {mode}"
-            messages.append(
-                {
-                    "role": "system",
-                    "content": f"detailed thinking {mode}",
-                }
-            )
-
-        if "system_prompt" in kwargs:
-            messages.append({"role": "system", "content": kwargs.pop("system_prompt")})
-
-        messages.extend(await self.parse_input(input))
-
-        body: dict[str, Any] = {
-            "max_tokens": self.max_tokens,
-            "model": self.model_name,
-            "messages": messages,
-        }
-
-        if self.supports_temperature:
-            if self.temperature is not None:
-                body["temperature"] = self.temperature
-            if self.top_p is not None:
-                body["top_p"] = self.top_p
-
-        body.update(kwargs)
-
-        response = await self.get_client().chat.completions.create(**body, stream=False)  # pyright: ignore[reportAny]
-
-        response = cast(ChatCompletionResponse, response)
-
-        if not response or not response.choices or not response.choices[0].message:
-            raise ModelNoOutputError("Model returned no completions")
-
-        text = str(response.choices[0].message.content)
-        reasoning = None
-
-        if response.choices[0].finish_reason == "length" and not text:
-            raise MaxOutputTokensExceededError()
-
-        if self.reasoning:
-            text, reasoning = get_reasoning_in_tag(text)
-
-        output = QueryResult(
-            output_text=text,
-            reasoning=reasoning,
-            history=[*input, response.choices[0].message],
+        self.delegate = OpenAIModel(
+            model_name=self.model_name,
+            provider=self.provider,
+            config=config,
+            custom_client=create_openai_client_with_defaults(
+                api_key=model_library_settings.TOGETHER_API_KEY,
+                base_url="https://api.together.xyz/v1",
+            ),
+            use_completions=True,
         )
 
-        if response.usage:
-            output.metadata.in_tokens = response.usage.prompt_tokens
-            output.metadata.out_tokens = response.usage.completion_tokens
-            # no cache tokens it seems
-        return output
-
     @override
     async def _calculate_cost(
         self,
```
model_library/providers/vals.py CHANGED

```diff
@@ -27,6 +27,7 @@ from model_library.base import (
     TextInput,
     ToolDefinition,
 )
+from model_library.register_models import register_provider
 from model_library.utils import truncate_str
 
 FAIL_RATE = 0.1
@@ -145,6 +146,7 @@ class DummyAIBatchMixin(LLMBatchMixin):
         return batch_status == "failed"
 
 
+@register_provider("vals")
 class DummyAIModel(LLM):
     _client: Redis | None = None
 
```
model_library/providers/xai.py CHANGED

```diff
@@ -39,6 +39,7 @@ from model_library.exceptions import (
     RateLimitException,
 )
 from model_library.providers.openai import OpenAIModel
+from model_library.register_models import register_provider
 from model_library.utils import create_openai_client_with_defaults
 
 Chat = AsyncChat | SyncChat
@@ -48,6 +49,7 @@ class XAIConfig(ProviderConfig):
     sync_client: bool = False
 
 
+@register_provider("grok")
 class XAIModel(LLM):
     provider_config = XAIConfig()
 
```
model_library/providers/zai.py CHANGED

```diff
@@ -1,27 +1,17 @@
-import io
-from typing import Any, Literal, Sequence
-
-from typing_extensions import override
+from typing import Literal
 
 from model_library import model_library_settings
 from model_library.base import (
-
-    FileInput,
-    FileWithId,
-    InputItem,
+    DelegateOnly,
     LLMConfig,
-    QueryResult,
-    ToolDefinition,
 )
 from model_library.providers.openai import OpenAIModel
+from model_library.register_models import register_provider
 from model_library.utils import create_openai_client_with_defaults
 
 
-class ZAIModel(LLM):
-
-    def get_client(self) -> None:
-        raise NotImplementedError("Not implemented")
-
+@register_provider("zai")
+class ZAIModel(DelegateOnly):
     def __init__(
         self,
         model_name: str,
@@ -30,73 +20,15 @@ class ZAIModel(LLM):
         config: LLMConfig | None = None,
     ):
         super().__init__(model_name, provider, config=config)
-        self.model_name: str = model_name
-        self.native: bool = False
 
-        # https://docs.z.ai/
-        self.delegate
-
-
-
-
-
-
-
-
-                base_url="https://open.bigmodel.cn/api/paas/v4/",
-            ),
-            use_completions=True,
-        )
+        # https://docs.z.ai/guides/develop/openai/python
+        self.delegate = OpenAIModel(
+            model_name=self.model_name,
+            provider=self.provider,
+            config=config,
+            custom_client=create_openai_client_with_defaults(
+                api_key=model_library_settings.ZAI_API_KEY,
+                base_url="https://open.bigmodel.cn/api/paas/v4/",
+            ),
+            use_completions=True,
         )
-
-    @override
-    async def parse_input(
-        self,
-        input: Sequence[InputItem],
-        **kwargs: Any,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_image(
-        self,
-        image: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_file(
-        self,
-        file: FileInput,
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def parse_tools(
-        self,
-        tools: list[ToolDefinition],
-    ) -> Any:
-        raise NotImplementedError()
-
-    @override
-    async def upload_file(
-        self,
-        name: str,
-        mime: str,
-        bytes: io.BytesIO,
-        type: Literal["image", "file"] = "file",
-    ) -> FileWithId:
-        raise NotImplementedError()
-
-    @override
-    async def _query_impl(
-        self,
-        input: Sequence[InputItem],
-        *,
-        tools: list[ToolDefinition],
-        **kwargs: object,
-    ) -> QueryResult:
-        # relies on oAI delegate
-        if self.delegate:
-            return await self.delegate_query(input, tools=tools, **kwargs)
-        raise NotImplementedError()
```
model_library/register_models.py CHANGED

```diff
@@ -1,61 +1,23 @@
+import importlib
+import pkgutil
 import threading
 from copy import deepcopy
 from datetime import date
 from pathlib import Path
-from typing import
+from typing import Any, Callable, Type, TypeVar, cast, get_type_hints
 
 import yaml
 from pydantic import create_model, model_validator
 from pydantic.fields import Field
 from pydantic.main import BaseModel
 
+from model_library import providers
 from model_library.base import LLM, ProviderConfig
-from model_library.providers.ai21labs import AI21LabsModel
-from model_library.providers.alibaba import AlibabaModel
-from model_library.providers.amazon import AmazonModel
-from model_library.providers.anthropic import AnthropicModel
-from model_library.providers.azure import AzureOpenAIModel
-from model_library.providers.cohere import CohereModel
-from model_library.providers.deepseek import DeepSeekModel
-from model_library.providers.fireworks import FireworksModel
-from model_library.providers.google.google import GoogleModel
-from model_library.providers.inception import MercuryModel
-from model_library.providers.kimi import KimiModel
-from model_library.providers.mistral import MistralModel
-from model_library.providers.openai import OpenAIModel
-from model_library.providers.perplexity import PerplexityModel
-from model_library.providers.together import TogetherModel
-from model_library.providers.vals import DummyAIModel
-from model_library.providers.xai import XAIModel
-from model_library.providers.zai import ZAIModel
 from model_library.utils import get_logger
 
-MAPPING_PROVIDERS: dict[str, type[LLM]] = {
-    "openai": OpenAIModel,
-    "azure": AzureOpenAIModel,
-    "anthropic": AnthropicModel,
-    "together": TogetherModel,
-    "mistralai": MistralModel,
-    "grok": XAIModel,
-    "fireworks": FireworksModel,
-    "ai21labs": AI21LabsModel,
-    "amazon": AmazonModel,
-    "bedrock": AmazonModel,
-    "cohere": CohereModel,
-    "google": GoogleModel,
-    "vals": DummyAIModel,
-    "alibaba": AlibabaModel,
-    "perplexity": PerplexityModel,
-    "deepseek": DeepSeekModel,
-    "zai": ZAIModel,
-    "kimi": KimiModel,
-    "inception": MercuryModel,
-}
-
-logger = get_logger(__name__)
-# Folder containing provider YAMLs
-path_library = Path(__file__).parent / "config"
+T = TypeVar("T", bound=LLM)
 
+logger = get_logger("register_models")
 
 """
 Model Registry structure
@@ -174,14 +136,13 @@ class ClassProperties(BaseModel):
     Each provider can have a set of provider-specific properties, we however want to accept
     any possible property from a provider in the yaml, and validate later. So we join all
     provider-specific properties into a single class.
+    This has no effect on runtime use of ProviderConfig, only used to load the yaml
     """
 
 
 class BaseProviderProperties(BaseModel):
     """Static base class for dynamic ProviderProperties."""
 
-    pass
-
 
 def all_subclasses(cls: type) -> list[type]:
     """Recursively find all subclasses of a class."""
@@ -210,14 +171,6 @@ def get_dynamic_provider_properties_model() -> type[BaseProviderProperties]:
     )
 
 
-ProviderProperties = get_dynamic_provider_properties_model()
-
-if TYPE_CHECKING:
-    ProviderPropertiesType = BaseProviderProperties
-else:
-    ProviderPropertiesType = ProviderProperties
-
-
 class DefaultParameters(BaseModel):
     max_output_tokens: int | None = None
     temperature: float | None = None
@@ -234,13 +187,20 @@ class RawModelConfig(BaseModel):
     documentation_url: str | None = None
     properties: Properties = Field(default_factory=Properties)
     class_properties: ClassProperties = Field(default_factory=ClassProperties)
-    provider_properties: ProviderPropertiesType = Field(
-        default_factory=ProviderProperties
-    )
+    provider_properties: BaseProviderProperties | None = None
    costs_per_million_token: CostProperties = Field(default_factory=CostProperties)
     alternative_keys: list[str | dict[str, Any]] = Field(default_factory=list)
     default_parameters: DefaultParameters = Field(default_factory=DefaultParameters)
 
+    def model_dump(self, *args: object, **kwargs: object):
+        data = super().model_dump(*args, **kwargs)
+        if self.provider_properties is not None:
+            # explicitly dump dynamic ProviderProperties instance
+            data["provider_properties"] = self.provider_properties.model_dump(
+                *args, **kwargs
+            )
+        return data
+
 
 class ModelConfig(RawModelConfig):
     # post processing fields
@@ -252,6 +212,9 @@ class ModelConfig(RawModelConfig):
 
 ModelRegistry = dict[str, ModelConfig]
 
+# Folder containing provider YAMLs
+path_library = Path(__file__).parent / "config"
+
 
 def deep_update(
     base: dict[str, Any], updates: dict[str, str | dict[str, Any]]
@@ -270,6 +233,9 @@ def _register_models() -> ModelRegistry:
 
     registry: ModelRegistry = {}
 
+    # generate ProviderProperties class
+    ProviderProperties = get_dynamic_provider_properties_model()
+
     # load each provider YAML
     sections = Path(path_library).glob("*.yaml")
     sections = sorted(sections, key=lambda x: "openai" in x.name.lower())
@@ -325,6 +291,10 @@ def _register_models() -> ModelRegistry:
                 "slug": model_name.replace("/", "_"),
             }
         )
+        # load provider properties separately since the model was generated at runtime
+        model_obj.provider_properties = ProviderProperties.model_validate(
+            current_model_config.get("provider_properties", {})
+        )
 
         registry[model_name] = model_obj
 
@@ -371,6 +341,50 @@ def _register_models() -> ModelRegistry:
     return registry
 
 
+_provider_registry: dict[str, type[LLM]] = {}
+_provider_registry_lock = threading.Lock()
+_imported_providers = False
+
+
+def register_provider(name: str) -> Callable[[Type[T]], Type[T]]:
+    def decorator(cls: Type[T]) -> Type[T]:
+        logger.debug(f"Registering provider {name}")
+
+        if name in _provider_registry:
+            raise ValueError(f"Provider {name} is already registered.")
+        _provider_registry[name] = cls
+        return cls
+
+    return decorator
+
+
+def _import_all_providers():
+    """Import all provider modules. Any class with @register_provider will be automatically registered upon import"""
+
+    package_name = providers.__name__
+
+    # walk all submodules recursively
+    for _, module_name, _ in pkgutil.walk_packages(
+        providers.__path__, package_name + "."
+    ):
+        # skip private modules
+        if module_name.split(".")[-1].startswith("_"):
+            continue
+        importlib.import_module(module_name)
+
+
+def get_provider_registry() -> dict[str, type[LLM]]:
+    """Return the provider registry, lazily loading all modules on first call."""
+    global _imported_providers
+    if not _imported_providers:
+        with _provider_registry_lock:
+            if not _imported_providers:
+                _import_all_providers()
+                _imported_providers = True
+
+    return _provider_registry
+
+
 _model_registry: ModelRegistry | None = None
 _model_registry_lock = threading.Lock()
 
@@ -381,5 +395,9 @@ def get_model_registry() -> ModelRegistry:
     if _model_registry is None:
         with _model_registry_lock:
             if _model_registry is None:
+                # initialize provider registry
+                global get_provider_registry
+                get_provider_registry()
+
                 _model_registry = _register_models()
     return _model_registry
```
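0.1.3 replaces the hard-coded `MAPPING_PROVIDERS` dict, which forced `register_models.py` to eagerly import every provider module, with decorator-based self-registration: each provider module applies `@register_provider("name")`, and `get_provider_registry()` populates the registry on first use by walking `model_library.providers` with `pkgutil.walk_packages`. Since `model_library/providers/__init__.py` appears to be a new, empty file (`+0 -0` in the file list above), the top-level `from model_library import providers` pulls in no provider code, so provider modules can themselves import `register_provider` without a cycle. A self-contained sketch of the same pattern, with simplified stand-in names rather than model_library's exact internals:

```python
import importlib
import pkgutil
import threading
from types import ModuleType
from typing import Callable, Dict, Type, TypeVar

T = TypeVar("T")

_registry: Dict[str, type] = {}
_lock = threading.Lock()
_loaded = False


def register_provider(name: str) -> Callable[[Type[T]], Type[T]]:
    """Class decorator: record the class under `name` when its module is imported."""

    def decorator(cls: Type[T]) -> Type[T]:
        if name in _registry:
            raise ValueError(f"Provider {name} is already registered.")
        _registry[name] = cls
        return cls

    return decorator


def _import_all(package: ModuleType) -> None:
    # Importing a module executes its top-level @register_provider decorators.
    for _, module_name, _ in pkgutil.walk_packages(
        package.__path__, package.__name__ + "."
    ):
        if module_name.rsplit(".", 1)[-1].startswith("_"):
            continue  # skip private modules
        importlib.import_module(module_name)


def get_registry(package: ModuleType) -> Dict[str, type]:
    """Populate the registry on first access, using double-checked locking
    so concurrent first calls import the package only once."""
    global _loaded
    if not _loaded:
        with _lock:
            if not _loaded:
                _import_all(package)
                _loaded = True
    return _registry
```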
model_library/registry_utils.py CHANGED

```diff
@@ -5,10 +5,10 @@ import tiktoken
 
 from model_library.base import LLM, LLMConfig, ProviderConfig
 from model_library.register_models import (
-    MAPPING_PROVIDERS,
     CostProperties,
     ModelConfig,
     get_model_registry,
+    get_provider_registry,
 )
 
 ALL_MODELS_PATH = Path(__file__).parent / "config" / "all_models.json"
@@ -51,7 +51,7 @@ def create_config(
 
     # load provider config with correct type
     if provider_properties:
-        ModelClass: type[LLM] = MAPPING_PROVIDERS[registry_config.provider_name]
+        ModelClass: type[LLM] = get_provider_registry()[registry_config.provider_name]
     if hasattr(ModelClass, "provider_config"):
         ProviderConfigClass: type[ProviderConfig] = type(ModelClass.provider_config)  # type: ignore
         provider_config: ProviderConfig = ProviderConfigClass.model_validate(
@@ -89,7 +89,7 @@ def _get_model_from_registry(
 
     provider_name: str = registry_config.provider_name
     provider_endpoint: str = registry_config.provider_endpoint
-    ModelClass: type[LLM] = MAPPING_PROVIDERS[provider_name]
+    ModelClass: type[LLM] = get_provider_registry()[provider_name]
 
     return ModelClass(
         model_name=provider_endpoint,
@@ -115,7 +115,7 @@ def get_registry_model(model_str: str, override_config: LLMConfig | None = None)
 def get_raw_model(model_str: str, config: LLMConfig | None = None) -> LLM:
     """Get a model exluding default config"""
     provider, model_name = model_str.split("/", 1)
-    ModelClass = MAPPING_PROVIDERS[provider]
+    ModelClass = get_provider_registry()[provider]
     return ModelClass(model_name=model_name, provider=provider, config=config)
 
 
@@ -130,7 +130,7 @@ def get_model_cost(model_str: str) -> CostProperties | None:
 @cache
 def get_provider_names() -> list[str]:
     """Return all provider names in the registry"""
-    return sorted([provider_name for provider_name in MAPPING_PROVIDERS.keys()])
+    return sorted([provider_name for provider_name in get_provider_registry().keys()])
 
 
 @cache
```
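Call sites change mechanically: every `MAPPING_PROVIDERS[...]` lookup becomes `get_provider_registry()[...]`, and the first such call is what triggers the provider imports. A hypothetical usage sketch; the model string is illustrative, while the function signatures are taken from the hunks above:

```python
from model_library.registry_utils import get_provider_names, get_raw_model

# First registry access imports every module under model_library.providers
# and collects the @register_provider classes.
print(get_provider_names())  # e.g. ["ai21labs", "alibaba", "anthropic", ...]

# "provider/model" strings still split on the first "/" and now resolve the
# provider class through the registry:
model = get_raw_model("zai/glm-4.5")
```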