model_library-0.1.6-py3-none-any.whl → model_library-0.1.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +98 -0
- model_library/base/delegate_only.py +10 -0
- model_library/base/input.py +10 -7
- model_library/base/output.py +5 -0
- model_library/base/utils.py +21 -7
- model_library/exceptions.py +11 -0
- model_library/logging.py +6 -2
- model_library/providers/ai21labs.py +19 -7
- model_library/providers/amazon.py +70 -48
- model_library/providers/anthropic.py +101 -74
- model_library/providers/google/batch.py +3 -3
- model_library/providers/google/google.py +83 -45
- model_library/providers/minimax.py +19 -0
- model_library/providers/mistral.py +41 -27
- model_library/providers/openai.py +122 -73
- model_library/providers/vals.py +4 -3
- model_library/providers/xai.py +123 -115
- model_library/register_models.py +4 -2
- model_library/utils.py +0 -35
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/METADATA +3 -3
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/RECORD +24 -24
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/WHEEL +0 -0
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/top_level.txt +0 -0
model_library/providers/anthropic.py

@@ -3,15 +3,15 @@ import logging
 from typing import Any, Literal, Sequence, cast
 
 from anthropic import AsyncAnthropic
-from anthropic.types import TextBlock, ToolUseBlock
 from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
-from anthropic.types.
+from anthropic.types.beta.parsed_beta_message import ParsedBetaMessage
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
     BatchResult,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -22,7 +22,8 @@ from model_library.base import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
-
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -31,6 +32,7 @@ from model_library.base import (
 )
 from model_library.exceptions import (
     MaxOutputTokensExceededError,
+    NoMatchingToolCallError,
 )
 from model_library.model_utils import get_default_budget_tokens
 from model_library.providers.openai import OpenAIModel
@@ -38,8 +40,6 @@ from model_library.register_models import register_provider
 from model_library.utils import (
     create_openai_client_with_defaults,
     default_httpx_client,
-    filter_empty_text_blocks,
-    normalize_tool_result,
 )
 
 
@@ -62,9 +62,9 @@ class AnthropicBatchMixin(LLMBatchMixin):
 
         Format: {"custom_id": str, "params": {...message params...}}
         """
-        # Build the message body using the parent model's
+        # Build the message body using the parent model's build_body method
        tools = cast(list[ToolDefinition], kwargs.pop("tools", []))
-        body = await self._root.
+        body = await self._root.build_body(input, tools=tools, **kwargs)
 
         return {
             "custom_id": custom_id,
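The batch mixin now delegates request construction to the parent model's `build_body`. A minimal sketch of the batch-entry shape stated in the docstring above; the helper name is hypothetical, not part of model_library:

```python
# Sketch of the {"custom_id": str, "params": {...}} batch-entry shape from the
# docstring above. `make_batch_entry` is a hypothetical illustration only.
from typing import Any

def make_batch_entry(custom_id: str, body: dict[str, Any]) -> dict[str, Any]:
    # One entry per batched request: a caller-chosen id plus the ordinary
    # Messages API parameters produced by build_body.
    return {"custom_id": custom_id, "params": body}
```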
@@ -300,6 +300,20 @@ class AnthropicModel(LLM):
             AnthropicBatchMixin(self) if self.supports_batch else None
         )
 
+    async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+        raw_responses = [x for x in input if isinstance(x, RawResponse)]
+        tool_call_ids: list[str] = []
+
+        calls = [
+            y
+            for x in raw_responses
+            if isinstance(x.response, ParsedBetaMessage)
+            for y in x.response.content  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
+            if isinstance(y, BetaToolUseBlock)
+        ]
+        tool_call_ids.extend([x.id for x in calls])
+        return tool_call_ids
+
     @override
     async def parse_input(
         self,
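The new `get_tool_call_ids` helper scans every `RawResponse` in the input and collects the ids of tool-use blocks. The nested comprehension can be easy to misread; expanded as a plain loop (an illustrative rewrite, not part of the diff, with `raw_responses` assumed as built in the method above):

```python
# Loop-form equivalent of the comprehension in get_tool_call_ids, shown only
# to make the iteration order and guards explicit.
from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
from anthropic.types.beta.parsed_beta_message import ParsedBetaMessage

calls: list[BetaToolUseBlock] = []
for x in raw_responses:                      # each RawResponse in the input
    if isinstance(x.response, ParsedBetaMessage):
        for y in x.response.content:         # blocks of the parsed message
            if isinstance(y, BetaToolUseBlock):
                calls.append(y)              # keep only tool-use blocks
tool_call_ids = [call.id for call in calls]
```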
@@ -307,77 +321,61 @@ class AnthropicModel(LLM):
         **kwargs: Any,
     ) -> list[dict[str, Any] | Any]:
         new_input: list[dict[str, Any] | Any] = []
+
         content_user: list[dict[str, Any]] = []
 
-
-
-
-
-
-
-
-                if isinstance(content, (ToolUseBlock, BetaToolUseBlock)):
-                    tool_calls_in_input.add(content.id)
+        def flush_content_user():
+            if content_user:
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
+
+        tool_call_ids = await self.get_tool_call_ids(input)
 
         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"type": "text", "text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case 
-                    if item.
-
-
-
-
-
-
-                    content_user.append(await self.parse_file(item))
-                case _:
-                    if content_user:
-                        filtered = filter_empty_text_blocks(content_user)
-                        if filtered:
-                            new_input.append({"role": "user", "content": filtered})
-                        content_user = []
-                    match item:
-                        case ToolResult():
-                            if item.tool_call.id not in tool_calls_in_input:
-                                raise Exception(
-                                    "Tool call result provided with no matching tool call"
-                                )
-                            result_str = normalize_tool_result(item.result)
-                            new_input.append(
+                case ToolResult():
+                    if item.tool_call.id not in tool_call_ids:
+                        raise NoMatchingToolCallError()
+
+                    new_input.append(
+                        {
+                            "role": "user",
+                            "content": [
                                 {
-                                    "
-                                    "
-
-                                    "type": "tool_result",
-                                    "tool_use_id": item.tool_call.id,
-                                    "content": [
-                                        {"type": "text", "text": result_str}
-                                    ],
-                                }
-                            ],
+                                    "type": "tool_result",
+                                    "tool_use_id": item.tool_call.id,
+                                    "content": [{"type": "text", "text": item.result}],
                                 }
-
-
-
-
-
-
-
-
-
-
-
-                            ]
-                            if filtered_content:
-                                new_input.append(
-                                    {"role": "assistant", "content": filtered_content}
-                                )
-
-        if content_user:
-            filtered = filter_empty_text_blocks(content_user)
-            if filtered:
-                new_input.append({"role": "user", "content": filtered})
+                            ],
+                        }
+                    )
+                case RawResponse():
+                    content = cast(ParsedBetaMessage, item.response).content
+                    new_input.append({"role": "assistant", "content": content})
+                case RawInput():
+                    new_input.append(item.input)
+
+        # in case content user item is the last item
+        flush_content_user()
 
+        # cache control
         if new_input:
             last_msg = new_input[-1]
             if not isinstance(last_msg, dict):
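The rewritten `parse_input` drops the old `filter_empty_text_blocks` flow in favor of a buffer-and-flush pattern: consecutive text/file items accumulate in `content_user`, and the buffer is flushed into a single user message whenever a non-user item (tool result, raw response, raw input) interrupts the run. A standalone sketch of that pattern with plain stand-in data, not model_library code:

```python
# Buffer-and-flush grouping as used by the new parse_input; strings stand in
# for TextInput/FileBase items, dicts stand in for non-user items.
from typing import Any

def group_messages(items: list[Any]) -> list[dict[str, Any]]:
    messages: list[dict[str, Any]] = []
    buffer: list[dict[str, Any]] = []

    def flush() -> None:
        if buffer:
            # copy() matters: the same buffer object is cleared and reused
            messages.append({"role": "user", "content": buffer.copy()})
            buffer.clear()

    for item in items:
        if isinstance(item, str):  # a "user content" item
            buffer.append({"type": "text", "text": item})
            continue
        flush()                    # a non-user item ends the current run
        messages.append(item)
    flush()                        # trailing user content
    return messages

# e.g. ["a", "b", assistant_msg, "c"] yields two user messages around
# assistant_msg, with "a" and "b" merged into the first one.
```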
@@ -495,7 +493,7 @@ class AnthropicModel(LLM):
         bytes: io.BytesIO,
         type: Literal["image", "file"] = "file",
     ) -> FileWithId:
-        file_mime = f"image/{mime}" if type == "image" else mime
+        file_mime = f"image/{mime}" if type == "image" else mime
         response = await self.get_client().beta.files.upload(
             file=(
                 name,
@@ -513,7 +511,8 @@ class AnthropicModel(LLM):
 
         cache_control = {"type": "ephemeral"}  # 5 min cache
 
-
+    @override
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
@@ -573,7 +572,7 @@ class AnthropicModel(LLM):
             input, tools=tools, query_logger=query_logger, **kwargs
         )
 
-        body = await self.
+        body = await self.build_body(input, tools=tools, **kwargs)
 
         client = self.get_client()
 
@@ -630,9 +629,37 @@
                 cache_write_tokens=message.usage.cache_creation_input_tokens,
             ),
             tool_calls=tool_calls,
-            history=[*input, message],
+            history=[*input, RawResponse(response=message)],
         )
 
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens using Anthropic's native token counting API.
+        https://docs.anthropic.com/en/docs/build-with-claude/token-counting
+        """
+        input = [*history, *input]
+        if not input:
+            return 0
+
+        body = await self.build_body(input, tools=tools, **kwargs)
+
+        # Remove fields not supported by count_tokens endpoint
+        body.pop("max_tokens", None)
+        body.pop("temperature", None)
+
+        client = self.get_client()
+        response = await client.messages.count_tokens(**body)
+
+        return response.input_tokens
+
     @override
     async def _calculate_cost(
         self,
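The new `count_tokens` on `AnthropicModel` reuses `build_body`, then strips `max_tokens` and `temperature` since the counting endpoint does not accept them. Underneath it is the Anthropic SDK's native token-counting call; a hedged usage sketch with an illustrative model id:

```python
# Sketch of the underlying Anthropic SDK call made by count_tokens above.
# The model id is illustrative, and ANTHROPIC_API_KEY is assumed to be set.
import asyncio
from anthropic import AsyncAnthropic

async def main() -> None:
    client = AsyncAnthropic()
    response = await client.messages.count_tokens(
        model="claude-sonnet-4-20250514",  # illustrative model id
        messages=[{"role": "user", "content": "Hello, world"}],
    )
    print(response.input_tokens)  # tokens the request would consume as input

asyncio.run(main())
```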
model_library/providers/google/batch.py

@@ -2,8 +2,6 @@ import io
 import json
 from typing import TYPE_CHECKING, Any, Final, Sequence, cast
 
-from typing_extensions import override
-
 from google.genai.types import (
     BatchJob,
     Content,
@@ -11,6 +9,8 @@ from google.genai.types import (
     JobState,
     UploadFileConfig,
 )
+from typing_extensions import override
+
 from model_library.base import BatchResult, InputItem, LLMBatchMixin
 
 if TYPE_CHECKING:
@@ -144,7 +144,7 @@ class GoogleBatchMixin(LLMBatchMixin):
         **kwargs: object,
     ) -> dict[str, Any]:
         self._root.logger.debug(f"Creating batch request for custom_id: {custom_id}")
-        body = await self._root.
+        body = await self._root.build_body(input, tools=[], **kwargs)
 
         contents_any = body["contents"]
         serialized_contents: list[dict[str, Any]] = [
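The Google batch mixin likewise switches to `build_body` and then serializes the resulting `contents`. The serialization itself is truncated in this diff; since google-genai's types are Pydantic models, one plausible route is `model_dump` (an assumption about the shape, not necessarily what the code does):

```python
# Hedged sketch: google-genai's Content/Part are Pydantic models, so dumping
# them to JSON-safe dicts can look like this. The actual serialization in
# batch.py is truncated in the diff above.
from google.genai.types import Content, Part

contents = [Content(role="user", parts=[Part.from_text(text="hi")])]
serialized = [c.model_dump(exclude_none=True) for c in contents]
```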
model_library/providers/google/google.py

@@ -1,13 +1,16 @@
 import base64
 import io
 import logging
+import uuid
 from typing import Any, Literal, Sequence, cast
 
 from google.genai import Client
 from google.genai import errors as genai_errors
 from google.genai.types import (
     Content,
+    CountTokensConfig,
     File,
+    FinishReason,
     FunctionDeclaration,
     GenerateContentConfig,
     GenerateContentResponse,
@@ -21,13 +24,13 @@ from google.genai.types import (
     Tool,
     ToolListUnion,
     UploadFileConfig,
-    FinishReason,
 )
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -40,6 +43,8 @@ from model_library.base import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -54,8 +59,6 @@ from model_library.exceptions import (
 )
 from model_library.providers.google.batch import GoogleBatchMixin
 from model_library.register_models import register_provider
-from model_library.utils import normalize_tool_result
-import uuid
 
 
 def generate_tool_call_id(tool_name: str) -> str:
@@ -146,63 +149,52 @@ class GoogleModel(LLM):
         input: Sequence[InputItem],
         **kwargs: Any,
     ) -> list[Content]:
-
-        parts: list[Part] = []
+        new_input: list[Content] = []
 
-
-            nonlocal parts
+        content_user: list[Part] = []
 
-
-
-            parts = 
+        def flush_content_user():
+            if content_user:
+                new_input.append(Content(parts=content_user, role="user"))
+                content_user.clear()
 
         for item in input:
-
-
-
-
+            if isinstance(item, TextInput):
+                content_user.append(Part.from_text(text=item.text))
+                continue
+
+            if isinstance(item, FileBase):
+                parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
 
-
-
-            parts.append(part)
+            # non content user item
+            flush_content_user()
 
+            match item:
                 case ToolResult():
-
-
-                    parsed_input.append(
+                    # id check
+                    new_input.append(
                         Content(
                             role="function",
                             parts=[
                                 Part.from_function_response(
                                     name=item.tool_call.name,
-                                    response={"result": 
+                                    response={"result": item.result},
                                 )
                             ],
                         )
                     )
 
-                case 
-
-
-
-                        content0 = candidates[0].content
-                        if content0 is not None:
-                            parsed_input.append(content0)
-                        else:
-                            self.logger.debug(
-                                "GenerateContentResponse missing candidates; skipping"
-                            )
-
-                case Content():
-                    flush_parts()
-                    parsed_input.append(item)
+                case RawResponse():
+                    new_input.extend(item.response)
+                case RawInput():
+                    new_input.append(item.input)
 
-
-
+        # in case content user item is the last item
+        flush_content_user()
 
-
-
-        return parsed_input
+        return new_input
 
     @override
     async def parse_file(self, file: FileInput) -> Part:
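As on the Anthropic side, Google's `parse_input` now buffers user parts and flushes them via `flush_content_user`, and a `ToolResult` becomes a `Content` with role `"function"`. A small standalone sketch of that shape using google-genai types; the tool name and result value are illustrative:

```python
# The Content shape the rewritten parse_input emits for a ToolResult; the
# tool name and result value here are illustrative placeholders.
from google.genai.types import Content, Part

tool_result = Content(
    role="function",
    parts=[
        Part.from_function_response(
            name="get_weather",                    # name of the called tool
            response={"result": "72F and sunny"},  # wrapped as {"result": ...}
        )
    ],
)
```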
@@ -284,7 +276,8 @@ class GoogleModel(LLM):
             mime=mime,
         )
 
-
+    @override
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
@@ -337,7 +330,7 @@ class GoogleModel(LLM):
         query_logger: logging.Logger,
         **kwargs: object,
     ) -> QueryResult:
-        body: dict[str, Any] = await self.
+        body: dict[str, Any] = await self.build_body(input, tools=tools, **kwargs)
 
         text: str = ""
         reasoning: str = ""
@@ -395,7 +388,7 @@ class GoogleModel(LLM):
         result = QueryResult(
             output_text=text,
             reasoning=reasoning,
-            history=[*input, 
+            history=[*input, RawResponse(response=contents)],
             tool_calls=tool_calls,
         )
 
@@ -410,6 +403,51 @@
         )
         return result
 
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens using Google's native token counting API.
+        https://ai.google.dev/gemini-api/docs/tokens
+
+        Only Vertex AI supports system_instruction and tools in count_tokens.
+        For Gemini API, fall back to the base implementation.
+        TODO: implement token counting for non-Vertex models.
+        """
+        if not self.provider_config.use_vertex:
+            return await super().count_tokens(
+                input, history=history, tools=tools, **kwargs
+            )
+
+        input = [*history, *input]
+        if not input:
+            return 0
+
+        system_prompt = kwargs.pop("system_prompt", None)
+        contents = await self.parse_input(input, **kwargs)
+        parsed_tools = await self.parse_tools(tools) if tools else None
+        config = CountTokensConfig(
+            system_instruction=str(system_prompt) if system_prompt else None,
+            tools=parsed_tools,
+        )
+
+        response = await self.client.aio.models.count_tokens(
+            model=self.model_name,
+            contents=cast(Any, contents),
+            config=config,
+        )
+
+        if response.total_tokens is None:
+            raise ValueError("count_tokens returned None")
+
+        return response.total_tokens
+
     @override
     async def _calculate_cost(
         self,
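On Vertex, the new `count_tokens` goes through `client.aio.models.count_tokens` with a `CountTokensConfig`; non-Vertex configurations fall back to the base implementation. A hedged usage sketch of the underlying google-genai call (project, location, and model id are illustrative):

```python
# Sketch of the underlying google-genai call; project/location/model id are
# illustrative, and per the docstring above, system_instruction and tools in
# CountTokensConfig are only honored on Vertex AI.
import asyncio
from google.genai import Client
from google.genai.types import CountTokensConfig

async def main() -> None:
    client = Client(vertexai=True, project="my-project", location="us-central1")
    response = await client.aio.models.count_tokens(
        model="gemini-2.0-flash",  # illustrative model id
        contents="How many tokens is this?",
        config=CountTokensConfig(system_instruction="You are terse."),
    )
    print(response.total_tokens)

asyncio.run(main())
```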
@@ -446,7 +484,7 @@ class GoogleModel(LLM):
         **kwargs: object,
     ) -> PydanticT:
         # Create the request body with JSON schema
-        body: dict[str, Any] = await self.
+        body: dict[str, Any] = await self.build_body(input, tools=[], **kwargs)
 
         # Get the JSON schema from the Pydantic model
         json_schema = pydantic_model.model_json_schema()
model_library/providers/minimax.py

@@ -2,12 +2,16 @@ from typing import Literal
 
 from model_library import model_library_settings
 from model_library.base import DelegateOnly, LLMConfig
+from model_library.base.input import InputItem, ToolDefinition
 from model_library.providers.anthropic import AnthropicModel
 from model_library.register_models import register_provider
 from model_library.utils import default_httpx_client
 
 from anthropic import AsyncAnthropic
 
+from typing import Sequence
+from typing_extensions import override
+
 
 @register_provider("minimax")
 class MinimaxModel(DelegateOnly):
@@ -31,3 +35,18 @@ class MinimaxModel(DelegateOnly):
             max_retries=1,
         ),
     )
+
+    # minimax client shares anthropic's syntax
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        assert self.delegate
+        return await self.delegate.count_tokens(
+            input, history=history, tools=tools, **kwargs
+        )
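`MinimaxModel` stays `DelegateOnly`: since the MiniMax endpoint speaks Anthropic's wire format, `count_tokens` simply forwards to the Anthropic delegate. The forwarding pattern in isolation, with hypothetical names rather than model_library's actual classes:

```python
# Generic sketch of the delegate-forwarding pattern used above; TokenCounter
# and DelegatingModel are hypothetical names, not model_library classes.
from typing import Protocol, Sequence

class TokenCounter(Protocol):
    async def count_tokens(self, input: Sequence[str]) -> int: ...

class DelegatingModel:
    def __init__(self, delegate: TokenCounter) -> None:
        self.delegate = delegate

    async def count_tokens(self, input: Sequence[str]) -> int:
        # Mirrors MinimaxModel.count_tokens: assert the delegate exists,
        # then forward the call unchanged.
        assert self.delegate
        return await self.delegate.count_tokens(input)
```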