model-library 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +98 -0
- model_library/base/delegate_only.py +10 -0
- model_library/base/input.py +10 -7
- model_library/base/output.py +5 -0
- model_library/base/utils.py +21 -7
- model_library/exceptions.py +11 -0
- model_library/logging.py +6 -2
- model_library/providers/ai21labs.py +19 -7
- model_library/providers/amazon.py +70 -48
- model_library/providers/anthropic.py +101 -74
- model_library/providers/google/batch.py +3 -3
- model_library/providers/google/google.py +83 -45
- model_library/providers/minimax.py +19 -0
- model_library/providers/mistral.py +41 -27
- model_library/providers/openai.py +122 -73
- model_library/providers/vals.py +4 -3
- model_library/providers/xai.py +123 -115
- model_library/register_models.py +4 -2
- model_library/utils.py +0 -35
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/METADATA +3 -3
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/RECORD +24 -24
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/WHEEL +0 -0
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.6.dist-info → model_library-0.1.7.dist-info}/top_level.txt +0 -0
model_library/base/base.py
CHANGED
@@ -13,8 +13,10 @@ from typing import (
     TypeVar,
 )
 
+import tiktoken
 from pydantic import model_serializer
 from pydantic.main import BaseModel
+from tiktoken.core import Encoding
 from typing_extensions import override
 
 from model_library.base.batch import (
@@ -35,6 +37,7 @@ from model_library.base.output import (
 )
 from model_library.base.utils import (
     get_pretty_input_types,
+    serialize_for_tokenizing,
 )
 from model_library.exceptions import (
     ImmediateRetryException,
@@ -379,6 +382,20 @@ class LLM(ABC):
         """
         ...
 
+    @abstractmethod
+    async def build_body(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """
+        Builds the body of the request to the model provider
+        Calls parse_input
+        """
+        ...
+
     @abstractmethod
     async def parse_input(
         self,
@@ -421,6 +438,87 @@ class LLM(ABC):
         """Upload a file to the model provider"""
         ...
 
+    async def get_encoding(self) -> Encoding:
+        """Get the appropriate tokenizer"""
+
+        model = self.model_name.lower()
+
+        if any(x in model for x in ["gpt-4o", "o1", "o3", "gpt-4.1", "gpt-5"]):
+            return tiktoken.get_encoding("o200k_base")
+        elif "gpt-4" in model or "gpt-3.5" in model:
+            try:
+                return tiktoken.encoding_for_model(self.model_name)
+            except KeyError:
+                return tiktoken.get_encoding("cl100k_base")
+        elif "claude" in model:
+            return tiktoken.get_encoding("cl100k_base")
+        elif "gemini" in model:
+            return tiktoken.get_encoding("o200k_base")
+        elif "llama" in model or "mistral" in model:
+            return tiktoken.get_encoding("cl100k_base")
+        else:
+            return tiktoken.get_encoding("cl100k_base")
+
+    async def stringify_input(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> str:
+        input = [*history, *input]
+
+        system_prompt = kwargs.pop(
+            "system_prompt", ""
+        )  # TODO: refactor along with system prompt arg change
+
+        # special case if using a delegate
+        # don't inherit method override by default
+        if self.delegate:
+            parsed_input = await self.delegate.parse_input(input, **kwargs)
+            parsed_tools = await self.delegate.parse_tools(tools)
+        else:
+            parsed_input = await self.parse_input(input, **kwargs)
+            parsed_tools = await self.parse_tools(tools)
+
+        serialized_input = serialize_for_tokenizing(parsed_input)
+        serialized_tools = serialize_for_tokenizing(parsed_tools)
+
+        combined = f"{system_prompt}\n{serialized_input}\n{serialized_tools}"
+
+        return combined
+
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens for a query.
+        Combines parsed input and tools, then tokenizes the result.
+        """
+
+        if not input and not history:
+            return 0
+
+        if self.delegate:
+            encoding = await self.delegate.get_encoding()
+        else:
+            encoding = await self.get_encoding()
+        self.logger.debug(f"Token Count Encoding: {encoding}")
+
+        string_input = await self.stringify_input(
+            input, history=history, tools=tools, **kwargs
+        )
+
+        count = len(encoding.encode(string_input, disallowed_special=()))
+        self.logger.debug(f"Combined Token Count Input: {count}")
+        return count
+
     async def query_json(
         self,
         input: Sequence[InputItem],
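The new `count_tokens` path is a local approximation: it picks a tiktoken encoding by substring-matching the model name, flattens parsed input and tools into one string, and encodes it. A minimal standalone sketch of the same heuristic, outside the class (the model name passed in is illustrative; the gpt-4/gpt-3.5 `encoding_for_model` branch is omitted for brevity):

```python
# Sketch of the encoding heuristic added in LLM.get_encoding.
import tiktoken


def pick_encoding(model_name: str) -> tiktoken.Encoding:
    model = model_name.lower()
    if any(x in model for x in ["gpt-4o", "o1", "o3", "gpt-4.1", "gpt-5"]):
        return tiktoken.get_encoding("o200k_base")
    if "gemini" in model:
        # Gemini also maps to o200k_base in the diff's heuristic.
        return tiktoken.get_encoding("o200k_base")
    # Everything else (claude, llama, mistral, unknown) falls back here.
    return tiktoken.get_encoding("cl100k_base")


enc = pick_encoding("claude-sonnet-example")  # hypothetical model name
print(len(enc.encode("system prompt\nuser text", disallowed_special=())))
```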
model_library/base/delegate_only.py
CHANGED
@@ -58,6 +58,16 @@ class DelegateOnly(LLM):
             input, tools=tools, query_logger=query_logger, **kwargs
         )
 
+    @override
+    async def build_body(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: object,
+    ) -> dict[str, Any]:
+        raise DelegateOnlyException()
+
     @override
     async def parse_input(
         self,
model_library/base/input.py
CHANGED
@@ -74,8 +74,6 @@ class ToolCall(BaseModel):
 --- INPUT ---
 """
 
-RawResponse = Any
-
 
 class ToolInput(BaseModel):
     tools: list[ToolDefinition] = []
@@ -90,11 +88,16 @@ class TextInput(BaseModel):
     text: str
 
 
-
-
-
+class RawResponse(BaseModel):
+    # used to store a received response
+    response: Any
+
+
+class RawInput(BaseModel):
+    # used to pass in anything provider specific (e.g. a mock conversation)
+    input: Any
 
 
 InputItem = (
-    TextInput | FileInput | ToolResult
-)  # input item can either be a prompt, a file (image or file), a tool call result,
+    TextInput | FileInput | ToolResult | RawInput | RawResponse
+)  # input item can either be a prompt, a file (image or file), a tool call result, a previous response, or raw input
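`RawResponse` and `RawInput` replace the old `RawResponse = Any` alias with explicit Pydantic wrappers, so provider-native payloads can travel through `InputItem` history without being mistaken for library types. A hedged sketch of how a caller might thread a provider response back in (field names match the diff; the dict payloads are illustrative):

```python
# Sketch: carrying provider-native turns through history using the
# wrappers added in this release.
from typing import Any

from pydantic import BaseModel


class RawResponse(BaseModel):
    # used to store a received response
    response: Any


class RawInput(BaseModel):
    # used to pass in anything provider specific (e.g. a mock conversation)
    input: Any


# A previous assistant turn, as returned by some provider SDK:
provider_message = {"role": "assistant", "content": [{"text": "42"}]}

history = [RawResponse(response=provider_message)]
mock_turn = RawInput(input={"role": "user", "content": [{"text": "why?"}]})
print(history[0].response["content"], mock_turn.input["role"])
```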
model_library/base/output.py
CHANGED
@@ -24,6 +24,11 @@ class Citation(BaseModel):
     index: int | None = None
     container_id: str | None = None
 
+    @override
+    def __repr__(self):
+        attrs = vars(self).copy()
+        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"
+
 
 class QueryResultExtras(BaseModel):
     citations: list[Citation] = Field(default_factory=list)
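The override gives `Citation` a multi-line `repr` via `pprint.pformat`. A quick standalone illustration of the formatting it produces (a plain class is used here to avoid depending on the library):

```python
# Illustration of a pformat-based __repr__, as added to Citation.
from pprint import pformat


class Demo:
    def __init__(self) -> None:
        self.index = 3
        self.container_id = None

    def __repr__(self) -> str:
        attrs = vars(self).copy()
        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"


print(repr(Demo()))
# Demo(
# {'container_id': None, 'index': 3}
# )
```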
model_library/base/utils.py
CHANGED
@@ -1,18 +1,34 @@
-
+import json
+from typing import Any, Sequence, TypeVar
+
+from pydantic import BaseModel
 
 from model_library.base.input import (
     FileBase,
     InputItem,
-
+    RawInput,
+    RawResponse,
     TextInput,
     ToolResult,
 )
 from model_library.utils import truncate_str
-from pydantic import BaseModel
 
 T = TypeVar("T", bound=BaseModel)
 
 
+def serialize_for_tokenizing(content: Any) -> str:
+    """
+    Serialize parsed content into a string for tokenization
+    """
+    parts: list[str] = []
+    if content:
+        if isinstance(content, str):
+            parts.append(content)
+        else:
+            parts.append(json.dumps(content, default=str))
+    return "\n".join(parts)
+
+
 def add_optional(
     a: int | float | T | None, b: int | float | T | None
 ) -> int | float | T | None:
@@ -54,11 +70,9 @@ def get_pretty_input_types(input: Sequence["InputItem"], verbose: bool = False)
                 return repr(item)
             case ToolResult():
                 return repr(item)
-            case
-                item = cast(RawInputItem, item)
+            case RawInput():
                 return repr(item)
-            case
-                # RawResponse
+            case RawResponse():
                 return repr(item)
 
     processed_items = [f"  {process_item(item)}" for item in input]
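`serialize_for_tokenizing` is deliberately lossy: strings pass through, anything else is JSON-dumped with `default=str` so non-serializable objects degrade to their string form rather than raising. A small self-contained check of that behavior (the `datetime` payload is illustrative):

```python
# Behavior check for the serialize_for_tokenizing helper from the diff.
import json
from datetime import datetime
from typing import Any


def serialize_for_tokenizing(content: Any) -> str:
    parts: list[str] = []
    if content:
        if isinstance(content, str):
            parts.append(content)
        else:
            parts.append(json.dumps(content, default=str))
    return "\n".join(parts)


print(serialize_for_tokenizing("plain text"))  # plain text
print(serialize_for_tokenizing({"at": datetime(2024, 1, 1)}))  # {"at": "2024-01-01 00:00:00"}
print(repr(serialize_for_tokenizing(None)))  # '' (falsy input yields an empty string)
```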
model_library/exceptions.py
CHANGED
@@ -146,6 +146,17 @@ class BadInputError(Exception):
         super().__init__(message or BadInputError.DEFAULT_MESSAGE)
 
 
+class NoMatchingToolCallError(Exception):
+    """
+    Raised when a tool call result is provided with no matching tool call
+    """
+
+    DEFAULT_MESSAGE: str = "Tool call result provided with no matching tool call"
+
+    def __init__(self, message: str | None = None):
+        super().__init__(message or NoMatchingToolCallError.DEFAULT_MESSAGE)
+
+
 # Add more retriable exceptions as needed
 # Providers that don't have an explicit rate limit error are handled manually
 # by wrapping errored Http/gRPC requests with a BackoffRetryException
model_library/logging.py
CHANGED
@@ -6,7 +6,11 @@ from rich.logging import RichHandler
 _llm_logger = logging.getLogger("llm")
 
 
-def set_logging(enable: bool = True, handler: logging.Handler | None = None):
+def set_logging(
+    enable: bool = True,
+    level: int = logging.INFO,
+    handler: logging.Handler | None = None,
+):
     """
     Sets up logging for the model library
 
@@ -15,7 +19,7 @@ def set_logging(enable: bool = True, handler: logging.Handler | None = None):
         handler (logging.Handler, optional): A custom logging handler. Defaults to RichHandler.
     """
     if enable:
-        _llm_logger.setLevel(logging.INFO)
+        _llm_logger.setLevel(level)
     else:
         _llm_logger.setLevel(logging.CRITICAL)
 
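With the new `level` parameter, callers can raise verbosity without swapping handlers. A hedged usage sketch (assuming `set_logging` is imported from `model_library.logging`, the module changed in this diff):

```python
# Hypothetical usage of the extended set_logging signature.
import logging

from model_library.logging import set_logging

# Default behavior is unchanged: logging enabled at INFO.
set_logging()

# New in 0.1.7: choose the level explicitly.
set_logging(level=logging.DEBUG)

# Still supported: disable library logging entirely (sets CRITICAL).
set_logging(enable=False)
```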
model_library/providers/ai21labs.py
CHANGED
@@ -22,6 +22,7 @@ from model_library.base import (
     ToolDefinition,
     ToolResult,
 )
+from model_library.base.input import RawResponse
 from model_library.exceptions import (
     BadInputError,
     MaxOutputTokensExceededError,
@@ -65,8 +66,6 @@ class AI21LabsModel(LLM):
             match item:
                 case TextInput():
                     new_input.append(ChatMessage(role="user", content=item.text))
-                case AssistantMessage():
-                    new_input.append(item)
                 case ToolResult():
                     new_input.append(
                         ToolMessage(
@@ -74,7 +73,9 @@ class AI21LabsModel(LLM):
                             content=item.result,
                             tool_call_id=item.tool_call.id,
                         )
-                    )
+                    )  # TODO: tool calling metadata and test
+                case RawResponse():
+                    new_input.append(item.response)
                 case _:
                     raise BadInputError("Unsupported input type")
         return new_input
@@ -133,14 +134,13 @@ class AI21LabsModel(LLM):
         raise NotImplementedError()
 
     @override
-    async def _query_impl(
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
-        query_logger: logging.Logger,
         **kwargs: object,
-    ) -> QueryResult:
+    ) -> dict[str, Any]:
         messages: list[ChatMessage] = []
         if "system_prompt" in kwargs:
             messages.append(
@@ -162,6 +162,18 @@ class AI21LabsModel(LLM):
             body["top_p"] = self.top_p
 
         body.update(kwargs)
+        return body
+
+    @override
+    async def _query_impl(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        query_logger: logging.Logger,
+        **kwargs: object,
+    ) -> QueryResult:
+        body = await self.build_body(input, tools=tools, **kwargs)
 
         response: ChatCompletionResponse = (
             await self.get_client().chat.completions.create(**body, stream=False)  # pyright: ignore[reportAny, reportUnknownMemberType]
@@ -186,7 +198,7 @@ class AI21LabsModel(LLM):
 
         output = QueryResult(
             output_text=choice.message.content,
-            history=[*input, choice.message],
+            history=[*input, RawResponse(response=choice.message)],
             metadata=QueryResultMetadata(
                 in_tokens=response.usage.prompt_tokens,
                 out_tokens=response.usage.completion_tokens,
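The recurring change across providers in this release is splitting request construction out of `_query_impl` into an overridable `build_body`. A schematic of the pattern, reduced to its control flow (the class and return payload here are simplified stand-ins, not the library's real API):

```python
# Schematic of the build_body/_query_impl split applied across providers.
import asyncio
from typing import Any, Sequence


class ProviderModel:
    async def build_body(
        self, input: Sequence[Any], *, tools: list[Any], **kwargs: Any
    ) -> dict[str, Any]:
        # Pure request construction: parse input/tools, apply sampling params.
        body: dict[str, Any] = {"messages": list(input), "tools": tools}
        body.update(kwargs)
        return body

    async def _query_impl(
        self, input: Sequence[Any], *, tools: list[Any], **kwargs: Any
    ) -> dict[str, Any]:
        # The network call now consumes whatever build_body produced, so the
        # same body can also serve token counting or batch submission.
        body = await self.build_body(input, tools=tools, **kwargs)
        return {"sent": body}  # stand-in for the provider API call


print(asyncio.run(ProviderModel()._query_impl(["hi"], tools=[])))
```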
model_library/providers/amazon.py
CHANGED
@@ -13,24 +13,26 @@ from typing_extensions import override
 
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
-    FileWithUrl,
     InputItem,
     LLMConfig,
     QueryResult,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
     ToolDefinition,
     ToolResult,
 )
-from model_library.base.input import FileBase
 from model_library.exceptions import (
     BadInputError,
     MaxOutputTokensExceededError,
+    NoMatchingToolCallError,
 )
 from model_library.model_utils import get_default_budget_tokens
 from model_library.register_models import register_provider
@@ -70,6 +72,20 @@ class AmazonModel(LLM):
 
     cache_control = {"type": "default"}
 
+    async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+        raw_responses = [x for x in input if isinstance(x, RawResponse)]
+        tool_call_ids: list[str] = []
+
+        calls = [
+            y["toolUse"]
+            for x in raw_responses
+            if "content" in x.response
+            for y in x.response["content"]
+            if "toolUse" in y
+        ]
+        tool_call_ids.extend([x["toolUseId"] for x in calls])
+        return tool_call_ids
+
     @override
     async def parse_input(
         self,
@@ -77,58 +93,63 @@ class AmazonModel(LLM):
         **kwargs: Any,
     ) -> list[dict[str, Any]]:
         new_input: list[dict[str, Any] | Any] = []
+
         content_user: list[dict[str, Any]] = []
 
+        def flush_content_user():
+            if content_user:
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
+
+        tool_call_ids = await self.get_tool_call_ids(input)
+
         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case
-
-
-
-
-
-
-
-                case _:
-                    if content_user:
-                        new_input.append({"role": "user", "content": content_user})
-                        content_user = []
-                    match item:
-                        case ToolResult():
-                            if not (
-                                isinstance(x, dict)
-                                and "toolUse" in x
-                                and x["toolUse"].get("toolUseId")
-                                == item.tool_call.call_id
-                                for x in new_input
-                            ):
-                                raise Exception(
-                                    "Tool call result provided with no matching tool call"
-                                )
-                            new_input.append(
+                case ToolResult():
+                    if item.tool_call.id not in tool_call_ids:
+                        raise NoMatchingToolCallError()
+
+                    new_input.append(
+                        {
+                            "role": "user",
+                            "content": [
                                 {
-                                    "
-
-                                    {
-
-                                        "toolUseId": item.tool_call.id,
-                                        "content": [
-                                            {"json": {"result": item.result}}
-                                        ],
-                                    }
-                                }
-                            ],
+                                    "toolResult": {
+                                        "toolUseId": item.tool_call.id,
+                                        "content": [{"json": {"result": item.result}}],
+                                    }
+                                }
+                            ],
+                        }
+                    )
+                case RawResponse():
+                    new_input.append(item.response)
+                case RawInput():
+                    new_input.append(item.input)
 
-        if content_user:
-            if
-
-
-
-
+        if content_user and self.supports_cache:
+            if not isinstance(input[-1], FileBase):
+                # last item cannot be file
+                content_user.append({"cachePoint": self.cache_control})
+
+        flush_content_user()
 
         return new_input
 
@@ -196,6 +217,7 @@ class AmazonModel(LLM):
     ) -> FileWithId:
         raise NotImplementedError()
 
+    @override
     async def build_body(
         self,
         input: Sequence[InputItem],
@@ -383,5 +405,5 @@ class AmazonModel(LLM):
             reasoning=reasoning,
             metadata=metadata,
             tool_calls=tool_calls,
-            history=[*input, messages],
+            history=[*input, RawResponse(response=messages)],
         )
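`get_tool_call_ids` walks prior `RawResponse` turns to collect every `toolUseId` the model actually emitted, which is what lets `parse_input` validate a `ToolResult` before building a Bedrock Converse `toolResult` block. A standalone sketch against a hand-built message (the dict shape follows Converse-style `toolUse` content blocks; the ids and tool name are illustrative):

```python
# Sketch: collecting toolUseIds from Converse-style assistant messages,
# mirroring AmazonModel.get_tool_call_ids. The message below is hand-built.
from typing import Any

assistant_message: dict[str, Any] = {
    "role": "assistant",
    "content": [
        {"text": "Let me check the weather."},
        {"toolUse": {"toolUseId": "tooluse_abc123", "name": "get_weather", "input": {"city": "Oslo"}}},
    ],
}

responses = [assistant_message]  # stand-in for RawResponse.response payloads

calls = [
    y["toolUse"]
    for x in responses
    if "content" in x
    for y in x["content"]
    if "toolUse" in y
]
tool_call_ids = [c["toolUseId"] for c in calls]
print(tool_call_ids)  # ['tooluse_abc123']

# A tool result referencing an id outside this list would now raise
# NoMatchingToolCallError instead of a bare Exception.
```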