model-library 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +114 -12
- model_library/base/delegate_only.py +15 -1
- model_library/base/input.py +10 -7
- model_library/base/output.py +5 -0
- model_library/base/utils.py +21 -7
- model_library/config/all_models.json +92 -1
- model_library/config/fireworks_models.yaml +2 -0
- model_library/config/minimax_models.yaml +18 -0
- model_library/config/zai_models.yaml +14 -0
- model_library/exceptions.py +11 -0
- model_library/logging.py +6 -2
- model_library/providers/ai21labs.py +20 -6
- model_library/providers/amazon.py +72 -48
- model_library/providers/anthropic.py +138 -85
- model_library/providers/google/batch.py +3 -3
- model_library/providers/google/google.py +92 -46
- model_library/providers/minimax.py +29 -10
- model_library/providers/mistral.py +42 -26
- model_library/providers/openai.py +131 -77
- model_library/providers/vals.py +6 -3
- model_library/providers/xai.py +125 -113
- model_library/register_models.py +5 -3
- model_library/utils.py +0 -35
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/METADATA +3 -3
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/RECORD +28 -28
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/WHEEL +0 -0
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.5.dist-info → model_library-0.1.7.dist-info}/top_level.txt +0 -0
--- a/model_library/providers/amazon.py
+++ b/model_library/providers/amazon.py
@@ -3,6 +3,7 @@ import asyncio
 import base64
 import io
 import json
+import logging
 from typing import Any, Literal, Sequence, cast
 
 import boto3
@@ -12,24 +13,26 @@ from typing_extensions import override
 
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
-    FileWithUrl,
     InputItem,
     LLMConfig,
     QueryResult,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
     ToolDefinition,
     ToolResult,
 )
-from model_library.base.input import FileBase
from model_library.exceptions import (
     BadInputError,
     MaxOutputTokensExceededError,
+    NoMatchingToolCallError,
 )
 from model_library.model_utils import get_default_budget_tokens
 from model_library.register_models import register_provider
@@ -69,6 +72,20 @@ class AmazonModel(LLM):
 
     cache_control = {"type": "default"}
 
+    async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+        raw_responses = [x for x in input if isinstance(x, RawResponse)]
+        tool_call_ids: list[str] = []
+
+        calls = [
+            y["toolUse"]
+            for x in raw_responses
+            if "content" in x.response
+            for y in x.response["content"]
+            if "toolUse" in y
+        ]
+        tool_call_ids.extend([x["toolUseId"] for x in calls])
+        return tool_call_ids
+
     @override
     async def parse_input(
         self,
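The new `get_tool_call_ids` helper collects every `toolUseId` found in previously returned raw responses so that a later `ToolResult` can be validated against them. A minimal sketch of the same extraction on a Bedrock Converse-style message (the payload below is illustrative, not taken from a real response):

```python
# Converse-style assistant message: content blocks may carry a "toolUse"
# entry whose "toolUseId" is what ToolResult items must refer back to.
message = {
    "role": "assistant",
    "content": [
        {"text": "Let me look that up."},
        {
            "toolUse": {
                "toolUseId": "tooluse_abc123",  # illustrative id
                "name": "get_weather",
                "input": {"city": "Paris"},
            }
        },
    ],
}

tool_call_ids = [
    block["toolUse"]["toolUseId"]
    for block in message.get("content", [])
    if "toolUse" in block
]
assert tool_call_ids == ["tooluse_abc123"]
```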
@@ -76,58 +93,63 @@ class AmazonModel(LLM):
         **kwargs: Any,
     ) -> list[dict[str, Any]]:
         new_input: list[dict[str, Any] | Any] = []
+
         content_user: list[dict[str, Any]] = []
 
+        def flush_content_user():
+            if content_user:
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
+
+        tool_call_ids = await self.get_tool_call_ids(input)
+
         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case
-
-
-
-
-
-
-
-                case _:
-                    if content_user:
-                        new_input.append({"role": "user", "content": content_user})
-                        content_user = []
-                    match item:
-                        case ToolResult():
-                            if not (
-                                isinstance(x, dict)
-                                and "toolUse" in x
-                                and x["toolUse"].get("toolUseId")
-                                == item.tool_call.call_id
-                                for x in new_input
-                            ):
-                                raise Exception(
-                                    "Tool call result provided with no matching tool call"
-                                )
-                            new_input.append(
+                case ToolResult():
+                    if item.tool_call.id not in tool_call_ids:
+                        raise NoMatchingToolCallError()
+
+                    new_input.append(
+                        {
+                            "role": "user",
+                            "content": [
                                 {
-                                    "
-
-                                    {
-
-                                        "toolUseId": item.tool_call.id,
-                                        "content": [
-                                            {"json": {"result": item.result}}
-                                        ],
-                                    }
-                                }
-                            ],
+                                    "toolResult": {
+                                        "toolUseId": item.tool_call.id,
+                                        "content": [{"json": {"result": item.result}}],
+                                    }
                                 }
-
-
-
+                            ],
+                        }
+                    )
+                case RawResponse():
+                    new_input.append(item.response)
+                case RawInput():
+                    new_input.append(item.input)
 
-        if content_user:
-            if
-
-
-
+        if content_user and self.supports_cache:
+            if not isinstance(input[-1], FileBase):
+                # last item cannot be file
+                content_user.append({"cachePoint": self.cache_control})
+
+        flush_content_user()
 
         return new_input
 
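The `flush_content_user` closure's NOTE is the crux of this rewrite: `new_input` must receive a snapshot, because the same `content_user` list is cleared and reused for the next run of user content. A standalone sketch of the aliasing bug the `.copy()` avoids:

```python
# Without .copy(), every flushed message aliases the same list object,
# and the later clear() empties all previously appended messages too.
new_input: list[dict] = []
content_user: list[dict] = [{"text": "first turn"}]

aliased = {"role": "user", "content": content_user}          # shares the list
snapshot = {"role": "user", "content": content_user.copy()}  # independent copy

content_user.clear()
assert aliased["content"] == []                          # content silently lost
assert snapshot["content"] == [{"text": "first turn"}]   # preserved
```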
@@ -195,6 +217,7 @@ class AmazonModel(LLM):
     ) -> FileWithId:
         raise NotImplementedError()
 
+    @override
     async def build_body(
         self,
         input: Sequence[InputItem],
@@ -337,6 +360,7 @@ class AmazonModel(LLM):
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
+        query_logger: logging.Logger,
         **kwargs: object,
     ) -> QueryResult:
         body = await self.build_body(input, tools=tools, **kwargs)
@@ -381,5 +405,5 @@ class AmazonModel(LLM):
             reasoning=reasoning,
             metadata=metadata,
             tool_calls=tool_calls,
-            history=[*input, messages],
+            history=[*input, RawResponse(response=messages)],
         )
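Instead of appending the provider's message object to `history` directly, the result now carries it wrapped in `RawResponse`; on the next turn `parse_input` unwraps it in the `case RawResponse():` arm and `get_tool_call_ids` scans it to validate `ToolResult` items. A hedged sketch of the round trip — the entry-point name `query`, plus `model`, `my_tool`, and the `ToolResult` arguments, are assumptions based on the types visible in this diff:

```python
# Turn 1: the model requests a tool call.
result = await model.query(
    [TextInput(text="What's the weather in Paris?")], tools=[my_tool]
)

# Turn 2: feed the history back with the tool's output attached.
follow_up = [
    *result.history,  # ends with RawResponse(response=<provider message>)
    ToolResult(tool_call=result.tool_calls[0], result='{"temp_c": 21}'),
]
# parse_input checks the ToolResult id against ids found in RawResponse
# items and raises NoMatchingToolCallError if none match.
result2 = await model.query(follow_up, tools=[my_tool])
```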
--- a/model_library/providers/anthropic.py
+++ b/model_library/providers/anthropic.py
@@ -1,16 +1,17 @@
 import io
+import logging
 from typing import Any, Literal, Sequence, cast
 
 from anthropic import AsyncAnthropic
-from anthropic.types import TextBlock, ToolUseBlock
 from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
-from anthropic.types.
+from anthropic.types.beta.parsed_beta_message import ParsedBetaMessage
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
     BatchResult,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -21,7 +22,8 @@ from model_library.base import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
-
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -30,6 +32,7 @@ from model_library.base import (
 )
 from model_library.exceptions import (
     MaxOutputTokensExceededError,
+    NoMatchingToolCallError,
 )
 from model_library.model_utils import get_default_budget_tokens
 from model_library.providers.openai import OpenAIModel
@@ -37,8 +40,6 @@ from model_library.register_models import register_provider
 from model_library.utils import (
     create_openai_client_with_defaults,
     default_httpx_client,
-    filter_empty_text_blocks,
-    normalize_tool_result,
 )
 
 
@@ -61,9 +62,9 @@ class AnthropicBatchMixin(LLMBatchMixin):
 
         Format: {"custom_id": str, "params": {...message params...}}
         """
-        # Build the message body using the parent model's
+        # Build the message body using the parent model's build_body method
         tools = cast(list[ToolDefinition], kwargs.pop("tools", []))
-        body = await self._root.
+        body = await self._root.build_body(input, tools=tools, **kwargs)
 
         return {
             "custom_id": custom_id,
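The docstring pins the batch entry format, and the rebuilt lines show the body now comes from the parent model's `build_body`. An illustrative entry (the `params` contents depend on what `build_body` produces for the given input):

```python
# One Anthropic Message Batches entry, matching the documented
# {"custom_id": str, "params": {...message params...}} shape.
batch_entry = {
    "custom_id": "request-0001",
    "params": {
        "model": "claude-sonnet-4-5",  # example model name
        "max_tokens": 1024,
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": "Hello"}]},
        ],
    },
}
```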
@@ -249,6 +250,8 @@ class AnthropicModel(LLM):
 
     @override
     def get_client(self) -> AsyncAnthropic:
+        if self._delegate_client:
+            return self._delegate_client
         if not AnthropicModel._client:
             headers: dict[str, str] = {}
             AnthropicModel._client = AsyncAnthropic(
@@ -262,16 +265,20 @@ class AnthropicModel(LLM):
     def __init__(
         self,
         model_name: str,
-        provider:
+        provider: str = "anthropic",
         *,
         config: LLMConfig | None = None,
+        custom_client: AsyncAnthropic | None = None,
     ):
         super().__init__(model_name, provider, config=config)
 
+        # allow custom client to act as delegate (native)
+        self._delegate_client: AsyncAnthropic | None = custom_client
+
         # https://docs.anthropic.com/en/api/openai-sdk
-        self.delegate
+        self.delegate = (
             None
-            if self.native
+            if self.native or custom_client
             else OpenAIModel(
                 model_name=self.model_name,
                 provider=provider,
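`custom_client` lets a caller inject their own `AsyncAnthropic` instance (for example one pointed at a proxy or gateway). It short-circuits `get_client`, forces the native path, and, as the next hunk shows, disables batch support. A usage sketch assuming only this constructor signature; the endpoint and key are hypothetical:

```python
from anthropic import AsyncAnthropic

client = AsyncAnthropic(
    base_url="https://llm-gateway.example.internal",  # hypothetical gateway
    api_key="gateway-key",
)
model = AnthropicModel("claude-sonnet-4-5", custom_client=client)
# model.delegate is None (native path) and model.supports_batch is False.
```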
@@ -285,11 +292,28 @@ class AnthropicModel(LLM):
         )
 
         # Initialize batch support if enabled
-
+        # Disable batch when using custom_client (similar to OpenAI)
+        self.supports_batch: bool = (
+            self.supports_batch and self.native and not custom_client
+        )
         self.batch: LLMBatchMixin | None = (
             AnthropicBatchMixin(self) if self.supports_batch else None
         )
 
+    async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+        raw_responses = [x for x in input if isinstance(x, RawResponse)]
+        tool_call_ids: list[str] = []
+
+        calls = [
+            y
+            for x in raw_responses
+            if isinstance(x.response, ParsedBetaMessage)
+            for y in x.response.content  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
+            if isinstance(y, BetaToolUseBlock)
+        ]
+        tool_call_ids.extend([x.id for x in calls])
+        return tool_call_ids
+
     @override
     async def parse_input(
         self,
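This is the typed twin of the Amazon helper: the same scan, but over Anthropic SDK objects instead of dicts. A small sketch of the isinstance-based extraction; constructing a `BetaToolUseBlock` by hand is purely illustrative, as real ones arrive inside a `ParsedBetaMessage`:

```python
from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock

content = [
    BetaToolUseBlock(
        type="tool_use",
        id="toolu_01A",  # illustrative id
        name="get_weather",
        input={"city": "Paris"},
    ),
]
ids = [block.id for block in content if isinstance(block, BetaToolUseBlock)]
assert ids == ["toolu_01A"]
```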
@@ -297,77 +321,61 @@ class AnthropicModel(LLM):
         **kwargs: Any,
     ) -> list[dict[str, Any] | Any]:
         new_input: list[dict[str, Any] | Any] = []
+
         content_user: list[dict[str, Any]] = []
 
-
-
-
-
-
-
-
-                    if isinstance(content, (ToolUseBlock, BetaToolUseBlock)):
-                        tool_calls_in_input.add(content.id)
+        def flush_content_user():
+            if content_user:
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
+
+        tool_call_ids = await self.get_tool_call_ids(input)
 
         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"type": "text", "text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case
-                    if item.
-
-
-
-
-
-
-                        content_user.append(await self.parse_file(item))
-                case _:
-                    if content_user:
-                        filtered = filter_empty_text_blocks(content_user)
-                        if filtered:
-                            new_input.append({"role": "user", "content": filtered})
-                        content_user = []
-                    match item:
-                        case ToolResult():
-                            if item.tool_call.id not in tool_calls_in_input:
-                                raise Exception(
-                                    "Tool call result provided with no matching tool call"
-                                )
-                            result_str = normalize_tool_result(item.result)
-                            new_input.append(
+                case ToolResult():
+                    if item.tool_call.id not in tool_call_ids:
+                        raise NoMatchingToolCallError()
+
+                    new_input.append(
+                        {
+                            "role": "user",
+                            "content": [
                                 {
-                                    "
-                                    "
-
-                                    "type": "tool_result",
-                                    "tool_use_id": item.tool_call.id,
-                                    "content": [
-                                        {"type": "text", "text": result_str}
-                                    ],
-                                    }
-                                ],
+                                    "type": "tool_result",
+                                    "tool_use_id": item.tool_call.id,
+                                    "content": [{"type": "text", "text": item.result}],
                                 }
-
-
-
-
-
-
-
-
-
-
-
-                            ]
-                            if filtered_content:
-                                new_input.append(
-                                    {"role": "assistant", "content": filtered_content}
-                                )
-
-        if content_user:
-            filtered = filter_empty_text_blocks(content_user)
-            if filtered:
-                new_input.append({"role": "user", "content": filtered})
+                            ],
+                        }
+                    )
+                case RawResponse():
+                    content = cast(ParsedBetaMessage, item.response).content
+                    new_input.append({"role": "assistant", "content": content})
+                case RawInput():
+                    new_input.append(item.input)
+
+        # in case content user item is the last item
+        flush_content_user()
 
+        # cache control
         if new_input:
             last_msg = new_input[-1]
             if not isinstance(last_msg, dict):
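Note the behavioral change in the `ToolResult` arm: `normalize_tool_result` is gone (dropped from `model_library/utils.py` along with `filter_empty_text_blocks`, per the import hunk and the `+0 -35` entry above), so `item.result` is forwarded verbatim as the text block and callers are expected to serialize it themselves. A sketch of the message this arm now produces; the id and payload are illustrative:

```python
import json

call_id = "toolu_01A"
result = json.dumps({"temp_c": 21})  # caller-serialized; no normalization step

message = {
    "role": "user",
    "content": [
        {
            "type": "tool_result",
            "tool_use_id": call_id,
            "content": [{"type": "text", "text": result}],
        }
    ],
}
```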
@@ -485,7 +493,7 @@ class AnthropicModel(LLM):
         bytes: io.BytesIO,
         type: Literal["image", "file"] = "file",
     ) -> FileWithId:
-        file_mime = f"image/{mime}" if type == "image" else mime
+        file_mime = f"image/{mime}" if type == "image" else mime
         response = await self.get_client().beta.files.upload(
             file=(
                 name,
@@ -503,7 +511,8 @@ class AnthropicModel(LLM):
 
     cache_control = {"type": "ephemeral"}  # 5 min cache
 
-
+    @override
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
@@ -555,20 +564,36 @@ class AnthropicModel(LLM):
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
+        query_logger: logging.Logger,
         **kwargs: object,
     ) -> QueryResult:
         if self.delegate:
-            return await self.delegate_query(
+            return await self.delegate_query(
+                input, tools=tools, query_logger=query_logger, **kwargs
+            )
 
-        body = await self.
+        body = await self.build_body(input, tools=tools, **kwargs)
 
-
-        if "sonnet-4-5" in self.model_name:
-            betas.append("context-1m-2025-08-07")
+        client = self.get_client()
 
-
-
-
+        # only send betas for the official Anthropic endpoint
+        is_anthropic_endpoint = self._delegate_client is None
+        if not is_anthropic_endpoint:
+            client_base_url = getattr(client, "_base_url", None) or getattr(
+                client, "base_url", None
+            )
+            if client_base_url:
+                is_anthropic_endpoint = "api.anthropic.com" in str(client_base_url)
+
+        stream_kwargs = {**body}
+        if is_anthropic_endpoint:
+            betas = ["files-api-2025-04-14", "interleaved-thinking-2025-05-14"]
+            if "sonnet-4-5" in self.model_name:
+                betas.append("context-1m-2025-08-07")
+            stream_kwargs["betas"] = betas
+
+        async with client.beta.messages.stream(
+            **stream_kwargs,
         ) as stream:  # pyright: ignore[reportAny]
             message = await stream.get_final_message()
             self.logger.info(f"Anthropic Response finished: {message.id}")
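Beta feature flags are only meaningful against the official API, so the stream call now attaches them conditionally: a default client is trusted, while an injected client must have a base URL containing `api.anthropic.com`. The same check as a standalone function (a sketch mirroring the attribute lookups in the diff):

```python
from anthropic import AsyncAnthropic

def is_official_anthropic_endpoint(client: AsyncAnthropic, injected: bool) -> bool:
    # Default (non-injected) clients always target the official endpoint.
    if not injected:
        return True
    # Injected clients qualify only if their base URL is api.anthropic.com.
    base_url = getattr(client, "_base_url", None) or getattr(client, "base_url", None)
    return bool(base_url) and "api.anthropic.com" in str(base_url)
```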
@@ -604,9 +629,37 @@ class AnthropicModel(LLM):
                 cache_write_tokens=message.usage.cache_creation_input_tokens,
             ),
             tool_calls=tool_calls,
-            history=[*input, message],
+            history=[*input, RawResponse(response=message)],
         )
 
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens using Anthropic's native token counting API.
+        https://docs.anthropic.com/en/docs/build-with-claude/token-counting
+        """
+        input = [*history, *input]
+        if not input:
+            return 0
+
+        body = await self.build_body(input, tools=tools, **kwargs)
+
+        # Remove fields not supported by count_tokens endpoint
+        body.pop("max_tokens", None)
+        body.pop("temperature", None)
+
+        client = self.get_client()
+        response = await client.messages.count_tokens(**body)
+
+        return response.input_tokens
+
     @override
     async def _calculate_cost(
         self,
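A quick usage sketch of the new `count_tokens` override, assuming an async context and reusing the hypothetical `model`, `my_tool`, and `result` names from the earlier sketches:

```python
# Estimate prompt size before committing to a full query.
tokens = await model.count_tokens(
    [TextInput(text="Summarize the attached report.")],
    history=result.history,  # optional prior turns
    tools=[my_tool],         # tool definitions consume tokens too
)
print(f"prompt will use ~{tokens} input tokens")
```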
--- a/model_library/providers/google/batch.py
+++ b/model_library/providers/google/batch.py
@@ -2,8 +2,6 @@ import io
 import json
 from typing import TYPE_CHECKING, Any, Final, Sequence, cast
 
-from typing_extensions import override
-
 from google.genai.types import (
     BatchJob,
     Content,
@@ -11,6 +9,8 @@ from google.genai.types import (
     JobState,
     UploadFileConfig,
 )
+from typing_extensions import override
+
 from model_library.base import BatchResult, InputItem, LLMBatchMixin
 
 if TYPE_CHECKING:
@@ -144,7 +144,7 @@ class GoogleBatchMixin(LLMBatchMixin):
         **kwargs: object,
     ) -> dict[str, Any]:
         self._root.logger.debug(f"Creating batch request for custom_id: {custom_id}")
-        body = await self._root.
+        body = await self._root.build_body(input, tools=[], **kwargs)
 
         contents_any = body["contents"]
         serialized_contents: list[dict[str, Any]] = [