model-library 0.1.5-py3-none-any.whl → 0.1.7-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- model_library/base/base.py +114 -12
- model_library/base/delegate_only.py +15 -1
- model_library/base/input.py +10 -7
- model_library/base/output.py +5 -0
- model_library/base/utils.py +21 -7
- model_library/config/all_models.json +92 -1
- model_library/config/fireworks_models.yaml +2 -0
- model_library/config/minimax_models.yaml +18 -0
- model_library/config/zai_models.yaml +14 -0
- model_library/exceptions.py +11 -0
- model_library/logging.py +6 -2
- model_library/providers/ai21labs.py +20 -6
- model_library/providers/amazon.py +72 -48
- model_library/providers/anthropic.py +138 -85
- model_library/providers/google/batch.py +3 -3
- model_library/providers/google/google.py +92 -46
- model_library/providers/minimax.py +29 -10
- model_library/providers/mistral.py +42 -26
- model_library/providers/openai.py +131 -77
- model_library/providers/vals.py +6 -3
- model_library/providers/xai.py +125 -113
- model_library/register_models.py +5 -3
- model_library/utils.py +0 -35
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/METADATA +3 -3
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/RECORD +28 -28
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/WHEEL +0 -0
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/top_level.txt +0 -0
model_library/providers/google/google.py

```diff
@@ -1,12 +1,16 @@
 import base64
 import io
+import logging
+import uuid
 from typing import Any, Literal, Sequence, cast
 
 from google.genai import Client
 from google.genai import errors as genai_errors
 from google.genai.types import (
     Content,
+    CountTokensConfig,
     File,
+    FinishReason,
     FunctionDeclaration,
     GenerateContentConfig,
     GenerateContentResponse,
@@ -20,13 +24,13 @@ from google.genai.types import (
     Tool,
     ToolListUnion,
     UploadFileConfig,
-    FinishReason,
 )
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -39,6 +43,8 @@ from model_library.base import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -53,7 +59,10 @@ from model_library.exceptions import (
 )
 from model_library.providers.google.batch import GoogleBatchMixin
 from model_library.register_models import register_provider
-
+
+
+def generate_tool_call_id(tool_name: str) -> str:
+    return str(tool_name + "_" + str(uuid.uuid4()))
 
 
 class GoogleConfig(ProviderConfig):
```
```diff
@@ -140,63 +149,52 @@ class GoogleModel(LLM):
         input: Sequence[InputItem],
         **kwargs: Any,
     ) -> list[Content]:
-
-        parts: list[Part] = []
+        new_input: list[Content] = []
 
-
-            nonlocal parts
+        content_user: list[Part] = []
 
-
-
-            parts =
+        def flush_content_user():
+            if content_user:
+                new_input.append(Content(parts=content_user, role="user"))
+                content_user.clear()
 
         for item in input:
-
-
-
-
+            if isinstance(item, TextInput):
+                content_user.append(Part.from_text(text=item.text))
+                continue
+
+            if isinstance(item, FileBase):
+                parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
 
-
-
-                    parts.append(part)
+            # non content user item
+            flush_content_user()
 
+            match item:
                 case ToolResult():
-
-
-                    parsed_input.append(
+                    # id check
+                    new_input.append(
                         Content(
                             role="function",
                             parts=[
                                 Part.from_function_response(
                                     name=item.tool_call.name,
-                                    response={"result":
+                                    response={"result": item.result},
                                 )
                             ],
                         )
                     )
 
-                case
-
-
-
-                    content0 = candidates[0].content
-                    if content0 is not None:
-                        parsed_input.append(content0)
-                    else:
-                        self.logger.debug(
-                            "GenerateContentResponse missing candidates; skipping"
-                        )
-
-                case Content():
-                    flush_parts()
-                    parsed_input.append(item)
+                case RawResponse():
+                    new_input.extend(item.response)
+                case RawInput():
+                    new_input.append(item.input)
 
-
-
+        # in case content user item is the last item
+        flush_content_user()
 
-
-
-        return parsed_input
+        return new_input
 
     @override
     async def parse_file(self, file: FileInput) -> Part:
```
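The rewritten `parse_input` replaces ad-hoc part handling with a buffer-and-flush scheme: consecutive user-visible items (text, parsed files) accumulate in `content_user`, and any other item first flushes that buffer as a single user `Content`. A runnable toy model of just the control flow, with plain strings standing in for the library's input types:

```python
# Toy model of the buffer-and-flush grouping above; strings stand in for
# TextInput/FileBase items and dicts for non-user items. Not library API.
def group_input(items: list) -> list[dict]:
    grouped: list[dict] = []
    buffer: list[str] = []

    def flush() -> None:
        # Emit one "user" block per run of consecutive user parts.
        if buffer:
            grouped.append({"role": "user", "parts": buffer.copy()})
            buffer.clear()

    for item in items:
        if isinstance(item, str):  # user-visible content: keep buffering
            buffer.append(item)
            continue
        flush()                    # any other item breaks the run
        grouped.append(item)
    flush()                        # trailing user parts
    return grouped

print(group_input(["hi", "see the file", {"role": "function"}, "thanks"]))
# [{'role': 'user', 'parts': ['hi', 'see the file']},
#  {'role': 'function'},
#  {'role': 'user', 'parts': ['thanks']}]
```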
```diff
@@ -278,7 +276,8 @@ class GoogleModel(LLM):
             mime=mime,
         )
 
-
+    @override
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
@@ -328,9 +327,10 @@ class GoogleModel(LLM):
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
+        query_logger: logging.Logger,
         **kwargs: object,
     ) -> QueryResult:
-        body: dict[str, Any] = await self.
+        body: dict[str, Any] = await self.build_body(input, tools=tools, **kwargs)
 
         text: str = ""
         reasoning: str = ""
@@ -357,9 +357,10 @@ class GoogleModel(LLM):
 
                 call_args = part.function_call.args or {}
                 tool_calls.append(
-                    #
+                    # Weirdly, id is not required. If not provided, we generate one.
                     ToolCall(
-                        id=part.function_call.id
+                        id=part.function_call.id
+                        or generate_tool_call_id(part.function_call.name),
                         name=part.function_call.name,
                         args=call_args,
                     )
```
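The comment in that last hunk is the reason `generate_tool_call_id` was added at the top of the file: Google may return a function call whose `id` is `None`, and the `or` expression backfills it. The helper in isolation:

```python
# Standalone copy of the helper from the first hunk, plus the fallback
# pattern from the hunk above; the printed id is illustrative (uuid4 is random).
import uuid

def generate_tool_call_id(tool_name: str) -> str:
    return str(tool_name + "_" + str(uuid.uuid4()))

returned_id = None  # Google omitted the id on this function call
tool_call_id = returned_id or generate_tool_call_id("get_weather")
print(tool_call_id)  # e.g. get_weather_1b9d6bcd-bbfd-4b2d-9b5d-ab8dfbbd4bed
```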
```diff
@@ -387,7 +388,7 @@ class GoogleModel(LLM):
         result = QueryResult(
             output_text=text,
             reasoning=reasoning,
-            history=[*input,
+            history=[*input, RawResponse(response=contents)],
             tool_calls=tool_calls,
         )
 
@@ -402,6 +403,51 @@ class GoogleModel(LLM):
         )
         return result
 
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens using Google's native token counting API.
+        https://ai.google.dev/gemini-api/docs/tokens
+
+        Only Vertex AI supports system_instruction and tools in count_tokens.
+        For Gemini API, fall back to the base implementation.
+        TODO: implement token counting for non-Vertex models.
+        """
+        if not self.provider_config.use_vertex:
+            return await super().count_tokens(
+                input, history=history, tools=tools, **kwargs
+            )
+
+        input = [*history, *input]
+        if not input:
+            return 0
+
+        system_prompt = kwargs.pop("system_prompt", None)
+        contents = await self.parse_input(input, **kwargs)
+        parsed_tools = await self.parse_tools(tools) if tools else None
+        config = CountTokensConfig(
+            system_instruction=str(system_prompt) if system_prompt else None,
+            tools=parsed_tools,
+        )
+
+        response = await self.client.aio.models.count_tokens(
+            model=self.model_name,
+            contents=cast(Any, contents),
+            config=config,
+        )
+
+        if response.total_tokens is None:
+            raise ValueError("count_tokens returned None")
+
+        return response.total_tokens
+
     @override
     async def _calculate_cost(
         self,
```
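With that in place, a Vertex-backed model can report an exact server-side token count instead of a local estimate. A hypothetical caller (the `model` instance and the `TextInput(text=...)` constructor signature are assumptions for illustration):

```python
# Hypothetical usage of the new count_tokens; `model` is a constructed
# Vertex-backed GoogleModel and the TextInput signature is assumed.
from model_library.base import TextInput

async def prompt_size(model) -> int:
    return await model.count_tokens(
        [TextInput(text="Summarize the attached report.")],
        history=[],
        tools=[],
    )
```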
```diff
@@ -438,7 +484,7 @@ class GoogleModel(LLM):
         **kwargs: object,
     ) -> PydanticT:
         # Create the request body with JSON schema
-        body: dict[str, Any] = await self.
+        body: dict[str, Any] = await self.build_body(input, tools=[], **kwargs)
 
         # Get the JSON schema from the Pydantic model
         json_schema = pydantic_model.model_json_schema()
```
model_library/providers/minimax.py

```diff
@@ -1,13 +1,16 @@
 from typing import Literal
 
 from model_library import model_library_settings
-from model_library.base import
-
-
-)
-from model_library.providers.openai import OpenAIModel
+from model_library.base import DelegateOnly, LLMConfig
+from model_library.base.input import InputItem, ToolDefinition
+from model_library.providers.anthropic import AnthropicModel
 from model_library.register_models import register_provider
-from model_library.utils import
+from model_library.utils import default_httpx_client
+
+from anthropic import AsyncAnthropic
+
+from typing import Sequence
+from typing_extensions import override
 
 
 @register_provider("minimax")
@@ -21,13 +24,29 @@ class MinimaxModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)
 
-        self.delegate =
+        self.delegate = AnthropicModel(
             model_name=self.model_name,
             provider=self.provider,
             config=config,
-            custom_client=
+            custom_client=AsyncAnthropic(
                 api_key=model_library_settings.MINIMAX_API_KEY,
-                base_url="https://api.minimax.io/
+                base_url="https://api.minimax.io/anthropic",
+                http_client=default_httpx_client(),
+                max_retries=1,
             ),
-
+        )
+
+    # minimax client shares anthropic's syntax
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        assert self.delegate
+        return await self.delegate.count_tokens(
+            input, history=history, tools=tools, **kwargs
         )
```
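The constructor change is the whole trick: Minimax exposes an Anthropic-compatible API, so the provider now reuses `AnthropicModel` wholesale and only swaps in a re-pointed client. The same client recreated outside the library (the `api_key` value and the httpx client are placeholders; `base_url` and `max_retries` come from the hunk):

```python
# Sketch: the stock Anthropic SDK aimed at Minimax's Anthropic-compatible
# endpoint. Placeholder key and httpx client; not the library's own wiring.
import httpx
from anthropic import AsyncAnthropic

client = AsyncAnthropic(
    api_key="<MINIMAX_API_KEY>",                  # placeholder
    base_url="https://api.minimax.io/anthropic",
    http_client=httpx.AsyncClient(),              # stand-in for default_httpx_client()
    max_retries=1,
)
```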
model_library/providers/mistral.py

```diff
@@ -1,5 +1,5 @@
 import io
-import
+import logging
 from collections.abc import Sequence
 from typing import Any, Literal
 
@@ -12,14 +12,16 @@ from typing_extensions import override
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
-    FileWithUrl,
     InputItem,
     LLMConfig,
     QueryResult,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -68,27 +70,30 @@ class MistralModel(LLM):
         content_user: list[dict[str, Any]] = []
 
         def flush_content_user():
-            nonlocal content_user
-
             if content_user:
-
-
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
 
         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"type": "text", "text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case TextInput():
-                    content_user.append({"type": "text", "text": item.text})
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    match item.type:
-                        case "image":
-                            content_user.append(await self.parse_image(item))
-                        case "file":
-                            content_user.append(await self.parse_file(item))
-                case AssistantMessage():
-                    flush_content_user()
-                    new_input.append(item)
                 case ToolResult():
-                    flush_content_user()
                     new_input.append(
                         {
                             "role": "tool",
```
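The `NOTE` added to `flush_content_user` is load-bearing: `new_input` must receive a copy of the buffer because the buffer is cleared immediately afterwards, and an aliased list would empty the message that was just appended. The pitfall in miniature:

```python
# Why content_user.copy() matters before content_user.clear().
buf = [{"type": "text", "text": "hi"}]
aliased = {"role": "user", "content": buf}          # shares the same list
snapshot = {"role": "user", "content": buf.copy()}  # independent copy

buf.clear()
print(aliased["content"])   # [] -- the aliased message lost its content
print(snapshot["content"])  # [{'type': 'text', 'text': 'hi'}]
```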
```diff
@@ -97,9 +102,12 @@ class MistralModel(LLM):
                             "tool_call_id": item.tool_call.id,
                         }
                     )
-                case
-
+                case RawResponse():
+                    new_input.append(item.response)
+                case RawInput():
+                    new_input.append(item.input)
 
+        # in case content user item is the last item
         flush_content_user()
 
         return new_input
@@ -166,13 +174,13 @@ class MistralModel(LLM):
         raise NotImplementedError()
 
     @override
-    async def
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
         **kwargs: object,
-    ) ->
+    ) -> dict[str, Any]:
         # mistral supports max 8 images, merge extra images into the 8th image
         input = trim_images(input, max_images=8)
 
@@ -203,8 +211,18 @@ class MistralModel(LLM):
         body["top_p"] = self.top_p
 
         body.update(kwargs)
+        return body
 
-
+    @override
+    async def _query_impl(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        query_logger: logging.Logger,
+        **kwargs: object,
+    ) -> QueryResult:
+        body = await self.build_body(input, tools=tools, **kwargs)
 
         response: EventStreamAsync[
             CompletionEvent
```
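Factoring `build_body` out of `_query_impl` mirrors the same split made in the Google provider above: request construction becomes a separable, network-free step. A hypothetical test-style caller:

```python
# Hypothetical usage; `model` is assumed to be a constructed MistralModel.
# build_body follows the signature added above and performs no I/O itself.
async def preview_request(model) -> dict:
    body = await model.build_body([], tools=[])
    return body  # inspect/assert on the payload before any network call
```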
```diff
@@ -245,8 +263,6 @@ class MistralModel(LLM):
             in_tokens += data.usage.prompt_tokens or 0
             out_tokens += data.usage.completion_tokens or 0
 
-            self.logger.info(f"Finished in: {time.time() - start}")
-
         except Exception as e:
             self.logger.error(f"Error: {e}", exc_info=True)
             raise e
@@ -300,7 +316,7 @@ class MistralModel(LLM):
         return QueryResult(
             output_text=text,
             reasoning=reasoning or None,
-            history=[*input, message],
+            history=[*input, RawResponse(response=message)],
             tool_calls=tool_calls,
             metadata=QueryResultMetadata(
                 in_tokens=in_tokens,
```