model-library 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +237 -62
- model_library/base/delegate_only.py +86 -9
- model_library/base/input.py +10 -7
- model_library/base/output.py +48 -0
- model_library/base/utils.py +56 -7
- model_library/config/alibaba_models.yaml +44 -57
- model_library/config/all_models.json +253 -126
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/openai_models.yaml +15 -23
- model_library/config/zai_models.yaml +24 -3
- model_library/exceptions.py +14 -77
- model_library/logging.py +6 -2
- model_library/providers/ai21labs.py +30 -14
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +119 -64
- model_library/providers/anthropic.py +184 -104
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +17 -13
- model_library/providers/google/google.py +130 -73
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +30 -13
- model_library/providers/mistral.py +61 -35
- model_library/providers/openai.py +219 -93
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +16 -9
- model_library/providers/xai.py +157 -144
- model_library/providers/zai.py +38 -8
- model_library/register_models.py +4 -2
- model_library/registry_utils.py +39 -15
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +13 -35
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/METADATA +4 -3
- model_library-0.1.8.dist-info/RECORD +70 -0
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
- model_library-0.1.6.dist-info/RECORD +0 -64
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/exceptions.py
CHANGED
@@ -1,13 +1,9 @@
-import logging
-import random
 import re
-from typing import Any
+from typing import Any
 
-import backoff
 from ai21 import TooManyRequestsError as AI21RateLimitError
 from anthropic import InternalServerError
 from anthropic import RateLimitError as AnthropicRateLimitError
-from backoff._typing import Details
 from httpcore import ReadError as HTTPCoreReadError
 from httpx import ConnectError as HTTPXConnectError
 from httpx import ReadError as HTTPXReadError
@@ -75,12 +71,14 @@ CONTEXT_WINDOW_PATTERN = re.compile(
     r"maximum context length is \d+ tokens|"
     r"context length is \d+ tokens|"
     r"exceed.* context (limit|window|length)|"
+    r"context window exceeds|"
     r"exceeds maximum length|"
     r"too long.*tokens.*maximum|"
     r"too large for model with \d+ maximum context length|"
     r"longer than the model's context length|"
     r"too many tokens.*size limit exceeded|"
     r"prompt is too long|"
+    r"maximum prompt length|"
     r"input length should be|"
     r"sent message larger than max|"
     r"input tokens exceeded|"
@@ -146,6 +144,17 @@ class BadInputError(Exception):
         super().__init__(message or BadInputError.DEFAULT_MESSAGE)
 
 
+class NoMatchingToolCallError(Exception):
+    """
+    Raised when a tool call result is provided with no matching tool call
+    """
+
+    DEFAULT_MESSAGE: str = "Tool call result provided with no matching tool call"
+
+    def __init__(self, message: str | None = None):
+        super().__init__(message or NoMatchingToolCallError.DEFAULT_MESSAGE)
+
+
 # Add more retriable exceptions as needed
 # Providers that don't have an explicit rate limit error are handled manually
 # by wrapping errored Http/gRPC requests with a BackoffRetryException
@@ -211,75 +220,3 @@ def exception_message(exception: Exception | Any) -> str:
         if str(exception)
         else type(exception).__name__
     )
-
-
-RETRY_MAX_TRIES: int = 20
-RETRY_INITIAL: float = 10.0
-RETRY_EXPO: float = 1.4
-RETRY_MAX_BACKOFF_WAIT: float = 240.0  # 4 minutes (more with jitter)
-
-
-def jitter(wait: float) -> float:
-    """
-    Increase or decrease the wait time by up to 20%.
-    """
-    jitter_fraction = 0.2
-    min_wait = wait * (1 - jitter_fraction)
-    max_wait = wait * (1 + jitter_fraction)
-    return random.uniform(min_wait, max_wait)
-
-
-def retry_llm_call(
-    logger: logging.Logger,
-    max_tries: int = RETRY_MAX_TRIES,
-    max_time: float | None = None,
-    backoff_callback: (
-        Callable[[int, Exception | None, float, float], None] | None
-    ) = None,
-):
-    def on_backoff(details: Details):
-        exception = details.get("exception")
-        tries = details.get("tries", 0)
-        elapsed = details.get("elapsed", 0.0)
-        wait = details.get("wait", 0.0)
-
-        logger.warning(
-            f"[Retrying] Exception: {exception_message(exception)} | Attempt: {tries} | "
-            + f"Elapsed: {elapsed:.1f}s | Next wait: {wait:.1f}s"
-        )
-
-        if backoff_callback:
-            backoff_callback(tries, exception, elapsed, wait)
-
-    def giveup(e: Exception) -> bool:
-        return not is_retriable_error(e)
-
-    def on_giveup(details: Details) -> None:
-        exception: Exception | None = details.get("exception", None)
-        if not exception:
-            return
-
-        logger.error(
-            f"Giving up after retries. Final exception: {exception_message(exception)}"
-        )
-
-        if is_context_window_error(exception):
-            message = exception.args[0] if exception.args else str(exception)
-            raise MaxContextWindowExceededError(message)
-
-        raise exception
-
-    return backoff.on_exception(
-        wait_gen=lambda: backoff.expo(
-            base=RETRY_EXPO,
-            factor=RETRY_INITIAL,
-            max_value=RETRY_MAX_BACKOFF_WAIT,
-        ),
-        exception=Exception,
-        max_tries=max_tries,
-        max_time=max_time,
-        giveup=giveup,
-        on_backoff=on_backoff,
-        on_giveup=on_giveup,
-        jitter=jitter,
-    )
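Note on this file: 0.1.8 drops the backoff-based retry helpers (`retry_llm_call`, `jitter`, and the `RETRY_*` constants) from `exceptions.py`; per the file list above, retry behaviour now lives in the new `model_library/retriers/` package (`backoff.py`, `base.py`, `token.py`). The hunk also adds `NoMatchingToolCallError`, which providers raise instead of a bare `Exception` when a tool result arrives without a matching tool call (see `amazon.py` below). A minimal sketch of the new exception, assuming the import path of the file shown here:

```python
from model_library.exceptions import NoMatchingToolCallError  # path per the file above

try:
    raise NoMatchingToolCallError()
except NoMatchingToolCallError as err:
    # With no explicit message, the class falls back to its DEFAULT_MESSAGE.
    print(err)  # Tool call result provided with no matching tool call
```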
model_library/logging.py
CHANGED
@@ -6,7 +6,11 @@ from rich.logging import RichHandler
 _llm_logger = logging.getLogger("llm")
 
 
-def set_logging(enable: bool = True, handler: logging.Handler | None = None):
+def set_logging(
+    enable: bool = True,
+    level: int = logging.INFO,
+    handler: logging.Handler | None = None,
+):
     """
     Sets up logging for the model library
 
@@ -15,7 +19,7 @@ def set_logging(enable: bool = True, handler: logging.Handler | None = None):
         handler (logging.Handler, optional): A custom logging handler. Defaults to RichHandler.
     """
     if enable:
-        _llm_logger.setLevel(
+        _llm_logger.setLevel(level)
     else:
         _llm_logger.setLevel(logging.CRITICAL)
 
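`set_logging` gains a `level` parameter (default `logging.INFO`) instead of hard-coding the level when logging is enabled; disabling still pins the `llm` logger to `CRITICAL`. A short usage sketch, assuming the function is imported from the module shown above:

```python
import logging

from model_library.logging import set_logging  # import path assumed from the file above

# Enable library logging at DEBUG rather than the INFO default (new in 0.1.8).
set_logging(enable=True, level=logging.DEBUG)

# Disabling still forces the "llm" logger to CRITICAL, as before.
set_logging(enable=False)
```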
model_library/providers/ai21labs.py
CHANGED
@@ -16,6 +16,7 @@ from model_library.base import (
     LLMConfig,
     QueryResult,
     QueryResultMetadata,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -33,17 +34,21 @@ from model_library.utils import default_httpx_client
 
 @register_provider("ai21labs")
 class AI21LabsModel(LLM):
-
+    @override
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.AI21LABS_API_KEY
 
     @override
-    def get_client(self) -> AsyncAI21Client:
-        if not
-
-
+    def get_client(self, api_key: str | None = None) -> AsyncAI21Client:
+        if not self.has_client():
+            assert api_key
+            client = AsyncAI21Client(
+                api_key=api_key,
                 http_client=default_httpx_client(),
-                num_retries=
+                num_retries=3,
             )
-
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
@@ -65,8 +70,6 @@ class AI21LabsModel(LLM):
             match item:
                 case TextInput():
                     new_input.append(ChatMessage(role="user", content=item.text))
-                case AssistantMessage():
-                    new_input.append(item)
                 case ToolResult():
                     new_input.append(
                         ToolMessage(
@@ -74,7 +77,9 @@ class AI21LabsModel(LLM):
                             content=item.result,
                             tool_call_id=item.tool_call.id,
                         )
-                    )
+                    )  # TODO: tool calling metadata and test
+                case RawResponse():
+                    new_input.append(item.response)
                 case _:
                     raise BadInputError("Unsupported input type")
         return new_input
@@ -133,14 +138,13 @@ class AI21LabsModel(LLM):
         raise NotImplementedError()
 
     @override
-    async def
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
-        query_logger: logging.Logger,
         **kwargs: object,
-    ) ->
+    ) -> dict[str, Any]:
         messages: list[ChatMessage] = []
         if "system_prompt" in kwargs:
             messages.append(
@@ -162,6 +166,18 @@ class AI21LabsModel(LLM):
             body["top_p"] = self.top_p
 
         body.update(kwargs)
+        return body
+
+    @override
+    async def _query_impl(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        query_logger: logging.Logger,
+        **kwargs: object,
+    ) -> QueryResult:
+        body = await self.build_body(input, tools=tools, **kwargs)
 
         response: ChatCompletionResponse = (
             await self.get_client().chat.completions.create(**body, stream=False)  # pyright: ignore[reportAny, reportUnknownMemberType]
@@ -186,7 +202,7 @@ class AI21LabsModel(LLM):
 
         output = QueryResult(
             output_text=choice.message.content,
-            history=[*input, choice.message],
+            history=[*input, RawResponse(response=choice.message)],
             metadata=QueryResultMetadata(
                 in_tokens=response.usage.prompt_tokens,
                 out_tokens=response.usage.completion_tokens,
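A pattern repeated across providers in this release: conversation history stores `RawResponse` wrappers (and `RawInput` for pre-formatted provider input) instead of raw SDK message objects, and `parse_input` unwraps them on the next turn. A small sketch of the convention, using the `response=` keyword exactly as the hunks above do; the message dict is a stand-in for whatever the SDK returns:

```python
from model_library.base import RawResponse  # import added in the hunk above

provider_message = {"role": "assistant", "content": "..."}  # stand-in for an SDK message object
history = [RawResponse(response=provider_message)]

# On the next turn, parse_input() hands the stored message back to the provider unchanged:
for item in history:
    match item:
        case RawResponse():
            print(item.response)
```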
model_library/providers/alibaba.py
CHANGED
@@ -1,17 +1,17 @@
-from typing import Literal
+from typing import Any, Literal
 
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("alibaba")
@@ -26,17 +26,26 @@ class AlibabaModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://www.alibabacloud.com/help/en/model-studio/first-api-call-to-qwen
-        self.
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-
-            api_key=model_library_settings.DASHSCOPE_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
+                api_key=SecretStr(model_library_settings.DASHSCOPE_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
 
+    @override
+    def _get_extra_body(self) -> dict[str, Any]:
+        """Build extra body parameters for Qwen-specific features."""
+        extra: dict[str, Any] = {}
+        # Enable thinking mode for Qwen3 reasoning models
+        # https://www.alibabacloud.com/help/en/model-studio/use-qwen-by-calling-api
+        if self.reasoning:
+            extra["enable_thinking"] = True
+        return extra
+
     @override
     async def _calculate_cost(
         self,
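With `init_delegate`, Alibaba requests are served by the OpenAI-compatible completions delegate pointed at DashScope, and `_get_extra_body()` injects `enable_thinking` for reasoning models. For orientation only, a rough equivalent of the request the delegate ends up issuing, written directly against the OpenAI SDK; the model name is a placeholder, not taken from this diff:

```python
from openai import AsyncOpenAI


async def example() -> None:
    client = AsyncOpenAI(
        base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
        api_key="sk-...",  # DASHSCOPE_API_KEY in model_library_settings
    )
    response = await client.chat.completions.create(
        model="qwen-plus",  # placeholder model name
        messages=[{"role": "user", "content": "Hello"}],
        extra_body={"enable_thinking": True},  # what _get_extra_body() contributes for reasoning models
    )
    print(response.choices[0].message.content)
```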
model_library/providers/amazon.py
CHANGED
@@ -11,26 +11,29 @@ import botocore
 from botocore.client import BaseClient
 from typing_extensions import override
 
+from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
-    FileWithUrl,
     InputItem,
     LLMConfig,
     QueryResult,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
     ToolDefinition,
     ToolResult,
 )
-from model_library.base.input import FileBase
 from model_library.exceptions import (
     BadInputError,
     MaxOutputTokensExceededError,
+    NoMatchingToolCallError,
 )
 from model_library.model_utils import get_default_budget_tokens
 from model_library.register_models import register_provider
@@ -39,20 +42,46 @@ from model_library.register_models import register_provider
 @register_provider("amazon")
 @register_provider("bedrock")
 class AmazonModel(LLM):
-
+    @override
+    def _get_default_api_key(self) -> str:
+        if getattr(model_library_settings, "AWS_ACCESS_KEY_ID", None):
+            return json.dumps(
+                {
+                    "AWS_ACCESS_KEY_ID": model_library_settings.AWS_ACCESS_KEY_ID,
+                    "AWS_SECRET_ACCESS_KEY": model_library_settings.AWS_SECRET_ACCESS_KEY,
+                    "AWS_DEFAULT_REGION": model_library_settings.AWS_DEFAULT_REGION,
+                }
+            )
+        return "using-environment"
 
     @override
-    def get_client(self) -> BaseClient:
-        if not
-
-
-
-
-
-
-
-
-
+    def get_client(self, api_key: str | None = None) -> BaseClient:
+        if not self.has_client():
+            assert api_key
+            if api_key != "using-environment":
+                creds = json.loads(api_key)
+                client = cast(
+                    BaseClient,
+                    boto3.client(
+                        "bedrock-runtime",
+                        aws_access_key_id=creds["AWS_ACCESS_KEY_ID"],
+                        aws_secret_access_key=creds["AWS_SECRET_ACCESS_KEY"],
+                        region_name=creds["AWS_DEFAULT_REGION"],
+                        config=botocore.config.Config(max_pool_connections=1000),  # pyright: ignore[reportAttributeAccessIssue]
+                    ),
+                )
+            else:
+                client = cast(
+                    BaseClient,
+                    boto3.client(
+                        "bedrock-runtime",
+                        # default connection pool is 10
+                        config=botocore.config.Config(max_pool_connections=1000),  # pyright: ignore[reportAttributeAccessIssue]
+                    ),
+                )
+
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
@@ -68,8 +97,27 @@ class AmazonModel(LLM):
         )  # supported but no access yet
         self.supports_tool_cache = self.supports_cache and "claude" in self.model_name
 
+        if config and config.custom_api_key:
+            raise Exception(
+                "custom_api_key is not currently supported for Amazon models"
+            )
+
     cache_control = {"type": "default"}
 
+    async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+        raw_responses = [x for x in input if isinstance(x, RawResponse)]
+        tool_call_ids: list[str] = []
+
+        calls = [
+            y["toolUse"]
+            for x in raw_responses
+            if "content" in x.response
+            for y in x.response["content"]
+            if "toolUse" in y
+        ]
+        tool_call_ids.extend([x["toolUseId"] for x in calls])
+        return tool_call_ids
+
     @override
     async def parse_input(
         self,
@@ -77,58 +125,63 @@ class AmazonModel(LLM):
         **kwargs: Any,
     ) -> list[dict[str, Any]]:
         new_input: list[dict[str, Any] | Any] = []
+
         content_user: list[dict[str, Any]] = []
 
+        def flush_content_user():
+            if content_user:
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
+
+        tool_call_ids = await self.get_tool_call_ids(input)
+
         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case
-
-
-
-
-
-
-
-                case _:
-                    if content_user:
-                        new_input.append({"role": "user", "content": content_user})
-                        content_user = []
-                    match item:
-                        case ToolResult():
-                            if not (
-                                isinstance(x, dict)
-                                and "toolUse" in x
-                                and x["toolUse"].get("toolUseId")
-                                == item.tool_call.call_id
-                                for x in new_input
-                            ):
-                                raise Exception(
-                                    "Tool call result provided with no matching tool call"
-                                )
-                            new_input.append(
+                case ToolResult():
+                    if item.tool_call.id not in tool_call_ids:
+                        raise NoMatchingToolCallError()
+
+                    new_input.append(
+                        {
+                            "role": "user",
+                            "content": [
                                 {
-                                    "
-
-                                    {
-
-                                    "toolUseId": item.tool_call.id,
-                                    "content": [
-                                        {"json": {"result": item.result}}
-                                    ],
-                                    }
-                                    }
-                                    ],
+                                    "toolResult": {
+                                        "toolUseId": item.tool_call.id,
+                                        "content": [{"json": {"result": item.result}}],
+                                    }
                                 }
-
-
-
+                            ],
+                        }
+                    )
+                case RawResponse():
+                    new_input.append(item.response)
+                case RawInput():
+                    new_input.append(item.input)
 
-        if content_user:
-            if
-
-
-
-
+        if content_user and self.supports_cache:
+            if not isinstance(input[-1], FileBase):
+                # last item cannot be file
+                content_user.append({"cachePoint": self.cache_control})
+
+        flush_content_user()
 
         return new_input
 
@@ -196,6 +249,7 @@ class AmazonModel(LLM):
     ) -> FileWithId:
         raise NotImplementedError()
 
+    @override
     async def build_body(
         self,
         input: Sequence[InputItem],
@@ -216,7 +270,7 @@ class AmazonModel(LLM):
         if self.supports_cache:
             body["system"].append({"cachePoint": self.cache_control})
 
-        if self.reasoning:
+        if self.reasoning and self.max_tokens:
             if self.max_tokens < 1024:
                 self.max_tokens = 2048
             budget_tokens = kwargs.pop(
@@ -229,9 +283,10 @@ class AmazonModel(LLM):
             }
         }
 
-        inference: dict[str, Any] = {
-
-
+        inference: dict[str, Any] = {}
+
+        if self.max_tokens:
+            inference["maxTokens"] = self.max_tokens
 
         # Only set temperature for models where supports_temperature is True.
         # For example, "thinking" models don't support temperature: https://docs.claude.com/en/docs/build-with-claude/extended-thinking#feature-compatibility
@@ -383,5 +438,5 @@ class AmazonModel(LLM):
             reasoning=reasoning,
             metadata=metadata,
             tool_calls=tool_calls,
-            history=[*input, messages],
+            history=[*input, RawResponse(response=messages)],
         )
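AmazonModel now threads AWS credentials through the same `api_key` plumbing as other providers: `_get_default_api_key()` packs the access key, secret, and region into a JSON string, or returns the `"using-environment"` sentinel so boto3 falls back to its normal credential chain, and `get_client()` unpacks that string before constructing the `bedrock-runtime` client. A small sketch of the round trip with placeholder values:

```python
import json

# Placeholder credentials; _get_default_api_key() serializes the real settings the same way.
packed = json.dumps(
    {
        "AWS_ACCESS_KEY_ID": "AKIA...",
        "AWS_SECRET_ACCESS_KEY": "example-secret",
        "AWS_DEFAULT_REGION": "us-east-1",
    }
)

if packed != "using-environment":
    # get_client() does this before calling boto3.client("bedrock-runtime", ...).
    creds = json.loads(packed)
    print(creds["AWS_DEFAULT_REGION"])  # us-east-1
```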