inspect-ai 0.3.58__py3-none-any.whl → 0.3.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/common.py +3 -1
- inspect_ai/_cli/eval.py +15 -9
- inspect_ai/_display/core/active.py +4 -1
- inspect_ai/_display/core/config.py +3 -3
- inspect_ai/_display/core/panel.py +7 -3
- inspect_ai/_display/plain/__init__.py +0 -0
- inspect_ai/_display/plain/display.py +203 -0
- inspect_ai/_display/rich/display.py +0 -5
- inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
- inspect_ai/_display/textual/widgets/samples.py +79 -12
- inspect_ai/_display/textual/widgets/sandbox.py +37 -0
- inspect_ai/_eval/eval.py +10 -1
- inspect_ai/_eval/loader.py +79 -19
- inspect_ai/_eval/registry.py +6 -0
- inspect_ai/_eval/score.py +3 -1
- inspect_ai/_eval/task/results.py +51 -22
- inspect_ai/_eval/task/run.py +47 -13
- inspect_ai/_eval/task/sandbox.py +10 -5
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/port_names.py +61 -0
- inspect_ai/_util/text.py +23 -0
- inspect_ai/_view/www/App.css +31 -1
- inspect_ai/_view/www/dist/assets/index.css +31 -1
- inspect_ai/_view/www/dist/assets/index.js +25498 -2044
- inspect_ai/_view/www/log-schema.json +32 -2
- inspect_ai/_view/www/package.json +2 -0
- inspect_ai/_view/www/src/App.mjs +14 -16
- inspect_ai/_view/www/src/Types.mjs +1 -2
- inspect_ai/_view/www/src/api/Types.ts +133 -0
- inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
- inspect_ai/_view/www/src/api/api-http.ts +219 -0
- inspect_ai/_view/www/src/api/api-shared.ts +47 -0
- inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
- inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
- inspect_ai/_view/www/src/api/index.ts +51 -0
- inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
- inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
- inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
- inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
- inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
- inspect_ai/_view/www/src/index.js +77 -4
- inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
- inspect_ai/_view/www/src/navbar/Navbar.mjs +4 -1
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +19 -10
- inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
- inspect_ai/_view/www/src/samples/SampleList.mjs +19 -49
- inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -26
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +14 -11
- inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
- inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
- inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
- inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
- inspect_ai/_view/www/src/types/log.d.ts +13 -2
- inspect_ai/_view/www/src/utils/Format.mjs +10 -3
- inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +13 -9
- inspect_ai/_view/www/src/utils/vscode.ts +36 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +11 -5
- inspect_ai/_view/www/vite.config.js +7 -0
- inspect_ai/_view/www/yarn.lock +116 -0
- inspect_ai/approval/_human/__init__.py +0 -0
- inspect_ai/approval/_human/manager.py +1 -1
- inspect_ai/approval/_policy.py +12 -6
- inspect_ai/log/_log.py +1 -1
- inspect_ai/log/_samples.py +16 -0
- inspect_ai/log/_transcript.py +4 -1
- inspect_ai/model/_call_tools.py +59 -0
- inspect_ai/model/_conversation.py +16 -7
- inspect_ai/model/_generate_config.py +12 -12
- inspect_ai/model/_model.py +117 -18
- inspect_ai/model/_model_output.py +22 -2
- inspect_ai/model/_openai.py +383 -0
- inspect_ai/model/_providers/anthropic.py +152 -55
- inspect_ai/model/_providers/azureai.py +21 -21
- inspect_ai/model/_providers/bedrock.py +37 -40
- inspect_ai/model/_providers/goodfire.py +248 -0
- inspect_ai/model/_providers/google.py +46 -54
- inspect_ai/model/_providers/groq.py +7 -3
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +13 -12
- inspect_ai/model/_providers/openai.py +51 -218
- inspect_ai/model/_providers/openai_o1.py +11 -12
- inspect_ai/model/_providers/providers.py +23 -1
- inspect_ai/model/_providers/together.py +12 -12
- inspect_ai/model/_providers/util/__init__.py +2 -3
- inspect_ai/model/_providers/util/hf_handler.py +1 -1
- inspect_ai/model/_providers/util/llama31.py +1 -1
- inspect_ai/model/_providers/util/util.py +0 -76
- inspect_ai/model/_providers/vertex.py +1 -4
- inspect_ai/scorer/_metric.py +3 -0
- inspect_ai/scorer/_reducer/reducer.py +1 -1
- inspect_ai/scorer/_scorer.py +4 -3
- inspect_ai/solver/__init__.py +4 -5
- inspect_ai/solver/_basic_agent.py +1 -1
- inspect_ai/solver/_bridge/__init__.py +3 -0
- inspect_ai/solver/_bridge/bridge.py +100 -0
- inspect_ai/solver/_bridge/patch.py +170 -0
- inspect_ai/solver/_prompt.py +35 -5
- inspect_ai/solver/_solver.py +6 -0
- inspect_ai/solver/_task_state.py +80 -38
- inspect_ai/tool/__init__.py +2 -0
- inspect_ai/tool/_tool.py +12 -1
- inspect_ai/tool/_tool_call.py +10 -0
- inspect_ai/tool/_tool_def.py +16 -5
- inspect_ai/tool/_tool_with.py +21 -4
- inspect_ai/tool/beta/__init__.py +5 -0
- inspect_ai/tool/beta/_computer/__init__.py +3 -0
- inspect_ai/tool/beta/_computer/_common.py +133 -0
- inspect_ai/tool/beta/_computer/_computer.py +155 -0
- inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
- inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
- inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
- inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/util/__init__.py +2 -0
- inspect_ai/util/_display.py +5 -0
- inspect_ai/util/_limit.py +26 -0
- inspect_ai/util/_sandbox/docker/docker.py +64 -1
- inspect_ai/util/_sandbox/docker/internal.py +3 -1
- inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
- inspect_ai/util/_sandbox/environment.py +14 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +159 -126
- inspect_ai/_view/www/src/api/Types.mjs +0 -117
- inspect_ai/_view/www/src/api/api-http.mjs +0 -300
- inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
- inspect_ai/_view/www/src/api/index.mjs +0 -49
- inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
- inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
- inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/anthropic.py

```diff
@@ -1,17 +1,26 @@
 import functools
 import os
+import sys
 from copy import copy
 from logging import getLogger
-from typing import Any, Literal, Tuple, cast
+from typing import Any, Literal, Tuple, TypedDict, cast
+
+if sys.version_info >= (3, 11):
+    from typing import NotRequired
+else:
+    from typing_extensions import NotRequired
 
 from anthropic import (
     APIConnectionError,
     AsyncAnthropic,
     AsyncAnthropicBedrock,
+    AsyncAnthropicVertex,
     BadRequestError,
     InternalServerError,
+    NotGiven,
     RateLimitError,
 )
+from anthropic._types import Body
 from anthropic.types import (
     ImageBlockParam,
     Message,
```
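The version-gated import in this hunk is the standard pattern for `NotRequired`, which moved into `typing` in Python 3.11 and comes from `typing_extensions` on earlier versions. A minimal illustration of what the pattern buys (the class and keys here are illustrative, not from the diff):

```python
import sys
from typing import TypedDict

if sys.version_info >= (3, 11):
    from typing import NotRequired
else:
    from typing_extensions import NotRequired


class DisplayParams(TypedDict):
    name: str                          # required key
    display_number: NotRequired[int]   # key may be omitted entirely


params: DisplayParams = {"name": "computer"}  # valid without display_number
```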
```diff
@@ -27,7 +36,11 @@ from anthropic.types import (
 from pydantic import JsonValue
 from typing_extensions import override
 
-from inspect_ai._util.constants import BASE_64_DATA_REMOVED, DEFAULT_MAX_RETRIES
+from inspect_ai._util.constants import (
+    BASE_64_DATA_REMOVED,
+    DEFAULT_MAX_RETRIES,
+    NO_CONTENT,
+)
 from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.error import exception_message
 from inspect_ai._util.images import file_as_data_uri
@@ -35,20 +48,11 @@ from inspect_ai._util.logger import warn_once
 from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
-from .._chat_message import (
-    ChatMessage,
-    ChatMessageAssistant,
-    ChatMessageSystem,
-)
+from .._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageSystem
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-    StopReason,
-)
+from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage, StopReason
 from .util import environment_prerequisite_error, model_base_url
 
 logger = getLogger(__name__)
@@ -63,15 +67,25 @@ class AnthropicAPI(ModelAPI):
         base_url: str | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
-        bedrock: bool = False,
         **model_args: Any,
     ):
         # extract any service prefix from model name
         parts = model_name.split("/")
         if len(parts) > 1:
-            service = parts[0]
-            bedrock = service == "bedrock"
+            self.service: str | None = parts[0]
             model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
+        # collect generate model_args (then delete them so we can pass the rest on)
+        def collect_model_arg(name: str) -> Any | None:
+            nonlocal model_args
+            value = model_args.get(name, None)
+            if value is not None:
+                model_args.pop(name)
+            return value
+
+        self.extra_body: Body | None = collect_model_arg("extra_body")
 
         # call super
         super().__init__(
```
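This hunk replaces the `bedrock: bool` constructor flag with a parsed service prefix, and introduces a small helper that pulls provider-specific entries (here `extra_body`) out of `**model_args` before the remainder is forwarded to the Anthropic client. A standalone sketch of the same two moves, with illustrative names:

```python
from typing import Any


def split_service(model_name: str) -> tuple[str | None, str]:
    """Split an optional service prefix: 'bedrock/model-id' -> ('bedrock', 'model-id')."""
    parts = model_name.split("/")
    if len(parts) > 1:
        return parts[0], "/".join(parts[1:])
    return None, model_name


def pop_model_arg(model_args: dict[str, Any], name: str) -> Any | None:
    """Remove and return a named model arg so the rest can be passed through untouched."""
    value = model_args.get(name, None)
    if value is not None:
        model_args.pop(name)
    return value


service, model = split_service("bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
assert service == "bedrock"
```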
```diff
@@ -83,7 +97,7 @@ class AnthropicAPI(ModelAPI):
         )
 
         # create client
-        if bedrock:
+        if self.is_bedrock():
             base_url = model_base_url(
                 base_url, ["ANTHROPIC_BEDROCK_BASE_URL", "BEDROCK_ANTHROPIC_BASE_URL"]
             )
@@ -94,7 +108,9 @@ class AnthropicAPI(ModelAPI):
             if base_region is None:
                 aws_region = os.environ.get("AWS_DEFAULT_REGION", None)
 
-            self.client: AsyncAnthropic | AsyncAnthropicBedrock = AsyncAnthropicBedrock(
+            self.client: (
+                AsyncAnthropic | AsyncAnthropicBedrock | AsyncAnthropicVertex
+            ) = AsyncAnthropicBedrock(
                 base_url=base_url,
                 max_retries=(
                     config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
@@ -102,6 +118,21 @@ class AnthropicAPI(ModelAPI):
                 aws_region=aws_region,
                 **model_args,
             )
+        elif self.is_vertex():
+            base_url = model_base_url(
+                base_url, ["ANTHROPIC_VERTEX_BASE_URL", "VERTEX_ANTHROPIC_BASE_URL"]
+            )
+            region = os.environ.get("ANTHROPIC_VERTEX_REGION", NotGiven())
+            project_id = os.environ.get("ANTHROPIC_VERTEX_PROJECT_ID", NotGiven())
+            self.client = AsyncAnthropicVertex(
+                region=region,
+                project_id=project_id,
+                base_url=base_url,
+                max_retries=(
+                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
+                ),
+                **model_args,
+            )
         else:
             # resolve api_key
             if not self.api_key:
```
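With these branches, the hosting service is selected purely by model-name prefix. A hedged usage sketch (the model ids are placeholders; the environment variables are the ones read in this hunk):

```python
from inspect_ai.model import get_model

# Direct Anthropic API (no service prefix after the provider name)
model = get_model("anthropic/claude-3-5-sonnet-20241022")

# Bedrock-hosted Claude: "bedrock" service prefix
model = get_model("anthropic/bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

# Vertex-hosted Claude: "vertex" service prefix; region and project are taken
# from ANTHROPIC_VERTEX_REGION / ANTHROPIC_VERTEX_PROJECT_ID when set
model = get_model("anthropic/vertex/claude-3-5-sonnet-v2@20241022")
```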
```diff
@@ -118,13 +149,19 @@ class AnthropicAPI(ModelAPI):
                 **model_args,
             )
 
+    def is_bedrock(self) -> bool:
+        return self.service == "bedrock"
+
+    def is_vertex(self) -> bool:
+        return self.service == "vertex"
+
     async def generate(
         self,
         input: list[ChatMessage],
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
         response: dict[str, Any] = {}
```
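The widened return type (repeated below for the AzureAI and Bedrock providers) lets `generate` report a recoverable provider error as a value paired with the `ModelCall` that produced it, instead of raising and losing the call record. A minimal sketch of how a caller might consume this shape (the logging hook is hypothetical):

```python
from typing import Any


def log_model_call(call: Any) -> None:
    """Hypothetical logging hook; stands in for whatever the caller records."""
    print(call)


async def call_provider(api, input, tools, tool_choice, config) -> Any:
    result = await api.generate(input, tools, tool_choice, config)
    if isinstance(result, tuple):
        output, call = result      # output is ModelOutput | Exception
        log_model_call(call)       # the ModelCall survives even when output is an error
        if isinstance(output, Exception):
            raise output           # surface the provider error after recording the call
        return output
    return result                  # bare ModelOutput (no call record)
```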
```diff
@@ -142,7 +179,7 @@ class AnthropicAPI(ModelAPI):
             system_param,
             tools_param,
             messages,
-            cache_prompt,
+            computer_use,
         ) = await resolve_chat_input(self.model_name, input, tools, config)
 
         # prepare request params (assembled this way so we can log the raw model call)
@@ -158,13 +195,15 @@
         # additional options
         request = request | self.completion_params(config)
 
-        # prompt caching
-        if cache_prompt:
-            request["extra_headers"] = {
-                "anthropic-beta": "prompt-caching-2024-07-31"
-            }
+        # computer use beta
+        if computer_use:
+            request["extra_headers"] = {"anthropic-beta": "computer-use-2024-10-22"}
 
-        # call model
+        # extra_body
+        if self.extra_body is not None:
+            request["extra_body"] = self.extra_body
+
+        # make request
         message = await self.client.messages.create(**request, stream=False)
 
         # set response for ModelCall
@@ -177,11 +216,7 @@
             return output, model_call()
 
         except BadRequestError as ex:
-            error_output = self.handle_bad_request(ex)
-            if error_output is not None:
-                return error_output, model_call()
-            else:
-                raise ex
+            return self.handle_bad_request(ex), model_call()
 
     def completion_params(self, config: GenerateConfig) -> dict[str, Any]:
         params = dict(model=self.model_name, max_tokens=cast(int, config.max_tokens))
@@ -234,7 +269,7 @@
         return True
 
     # convert some common BadRequestError states into 'refusal' model output
-    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | None:
+    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
         error = exception_message(ex).lower()
         content: str | None = None
         stop_reason: StopReason | None = None
@@ -256,6 +291,9 @@
         elif "content filtering" in error:
             content = "Sorry, but I am unable to help with that request."
             stop_reason = "content_filter"
+        else:
+            content = error
+            stop_reason = "unknown"
 
         if content and stop_reason:
             return ModelOutput.from_content(
@@ -265,7 +303,21 @@
                 error=error,
             )
         else:
-            return None
+            return ex
+
+
+# native anthropic tool definitions for computer use beta
+# https://docs.anthropic.com/en/docs/build-with-claude/computer-use
+class ComputerUseToolParam(TypedDict):
+    type: str
+    name: str
+    display_width_px: NotRequired[int]
+    display_height_px: NotRequired[int]
+    display_number: NotRequired[int]
+
+
+# tools can be either a stock tool param or a special computer use tool param
+ToolParamDef = ToolParam | ComputerUseToolParam
 
 
 async def resolve_chat_input(
```
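`ComputerUseToolParam` mirrors Anthropic's native computer-use tool schema, where only `type` and `name` are mandatory. A quick illustration of constructing one and what it serializes to (values follow the defaults used later in this diff):

```python
import json

# assumes ComputerUseToolParam from the hunk above is in scope
tool_param = ComputerUseToolParam(
    type="computer_20241022",
    name="computer",
    display_width_px=1366,
    display_height_px=768,
    display_number=1,
)
# a TypedDict is a plain dict at runtime, so it serializes directly
print(json.dumps(tool_param))
# -> {"type": "computer_20241022", "name": "computer", "display_width_px": 1366, ...}
```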
```diff
@@ -273,7 +325,7 @@ async def resolve_chat_input(
     input: list[ChatMessage],
     tools: list[ToolInfo],
     config: GenerateConfig,
-) -> Tuple[list[TextBlockParam] | None, list[ToolParam], list[MessageParam], bool]:
+) -> Tuple[list[TextBlockParam] | None, list[ToolParamDef], list[MessageParam], bool]:
     # extract system message
     system_messages, messages = split_system_messages(input, config)
 
@@ -286,14 +338,7 @@
     )
 
     # tools
-    tools_params = [
-        ToolParam(
-            name=tool.name,
-            description=tool.description,
-            input_schema=tool.parameters.model_dump(exclude_none=True),
-        )
-        for tool in tools
-    ]
+    tools_params, computer_use = tool_params_for_tools(tools, config)
 
     # system messages
     if len(system_messages) > 0:
@@ -343,10 +388,66 @@
             add_cache_control(cast(dict[str, Any], content[-1]))
 
     # return chat input
-    return system_param, tools_params, message_params, cache_prompt
+    return system_param, tools_params, message_params, computer_use
+
+
+def tool_params_for_tools(
+    tools: list[ToolInfo], config: GenerateConfig
+) -> tuple[list[ToolParamDef], bool]:
+    # tool params and computer_use bit to return
+    tool_params: list[ToolParamDef] = []
+    computer_use = False
+
+    # for each tool, check if it has a native computer use implementation and use that
+    # when available (noting that we need to set the computer use request header)
+    for tool in tools:
+        computer_use_tool = (
+            computer_use_tool_param(tool)
+            if config.internal_tools is not False
+            else None
+        )
+        if computer_use_tool:
+            tool_params.append(computer_use_tool)
+            computer_use = True
+        else:
+            tool_params.append(
+                ToolParam(
+                    name=tool.name,
+                    description=tool.description,
+                    input_schema=tool.parameters.model_dump(exclude_none=True),
+                )
+            )
+
+    return tool_params, computer_use
 
 
-def add_cache_control(param: TextBlockParam | ToolParam) -> None:
+def computer_use_tool_param(tool: ToolInfo) -> ComputerUseToolParam | None:
+    # check for compatible 'computer' tool
+    if tool.name == "computer" and (
+        sorted(tool.parameters.properties.keys())
+        == sorted(["action", "coordinate", "text"])
+    ):
+        return ComputerUseToolParam(
+            type="computer_20241022",
+            name="computer",
+            # Note: The dimensions passed here for display_width_px and display_height_px should
+            # match the dimensions of screenshots returned by the tool.
+            # Those dimensions will always be one of the values in MAX_SCALING_TARGETS
+            # in _x11_client.py.
+            # TODO: enhance this code to calculate the dimensions based on the scaled screen
+            # size used by the container.
+            display_width_px=1366,
+            display_height_px=768,
+            display_number=1,
+        )
+    # not a computer_use tool
+    else:
+        return None
+
+
+def add_cache_control(
+    param: TextBlockParam | ToolParam | ComputerUseToolParam | dict[str, Any],
+) -> None:
     cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
 
 
```
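`computer_use_tool_param` upgrades a tool to Anthropic's native implementation purely by shape: the tool must be named `computer` and expose exactly the `action`/`coordinate`/`text` parameters. The decision logic restated in isolation (an illustrative distillation, not the provider's code):

```python
def is_native_computer_tool(name: str, param_names: list[str]) -> bool:
    """Shape check distilled from computer_use_tool_param."""
    return name == "computer" and sorted(param_names) == sorted(
        ["action", "coordinate", "text"]
    )


assert is_native_computer_tool("computer", ["action", "coordinate", "text"])
assert not is_native_computer_tool("computer", ["action", "text"])          # wrong shape
assert not is_native_computer_tool("bash", ["action", "coordinate", "text"])  # wrong name
```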
```diff
@@ -404,12 +505,13 @@ def message_tool_choice(tool_choice: ToolChoice) -> message_create_params.ToolCh
     return {"type": "auto"}
 
 
-# text we insert when there is no content passed
-# (as this will result in an Anthropic API error)
-NO_CONTENT = "(no content)"
-
-
 async def message_param(message: ChatMessage) -> MessageParam:
+    # if content is empty that is going to result in an error when we replay
+    # this message to claude, so in that case insert a NO_CONTENT message
+    if isinstance(message.content, list) and len(message.content) == 0:
+        message = message.model_copy()
+        message.content = [ContentText(text=NO_CONTENT)]
+
     # no system role for anthropic (this is more like an assertion,
     # as these should have already been filtered out)
     if message.role == "system":
@@ -451,7 +553,7 @@
     elif message.role == "assistant" and message.tool_calls:
         # first include content (claude <thinking>)
         tools_content: list[TextBlockParam | ImageBlockParam | ToolUseBlockParam] = (
-            [TextBlockParam(type="text", text=message.content)]
+            [TextBlockParam(type="text", text=message.content or NO_CONTENT)]
             if isinstance(message.content, str)
             else (
                 [(await message_param_content(content)) for content in message.content]
@@ -520,11 +622,6 @@ def model_output_from_message(message: Message, tools: list[ToolInfo]) -> ModelO
         )
     )
 
-    # if content is empty that is going to result in an error when we replay
-    # this message to claude, so in that case insert a NO_CONTENT message
-    if len(content) == 0:
-        content = [ContentText(text=NO_CONTENT)]
-
     # resolve choice
     choice = ChatCompletionChoice(
         message=ChatMessageAssistant(
```
inspect_ai/model/_providers/azureai.py

```diff
@@ -37,6 +37,7 @@ from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_choice import ToolFunction
 
+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -60,7 +61,6 @@ from .util import (
 )
 from .util.chatapi import ChatAPIHandler
 from .util.llama31 import Llama31Handler
-from .util.util import parse_tool_call
 
 AZUREAI_API_KEY = "AZUREAI_API_KEY"
 AZUREAI_ENDPOINT_KEY = "AZUREAI_ENDPOINT_KEY"
@@ -130,7 +130,7 @@
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # emulate tools (auto for llama, opt-in for others)
         if self.emulate_tools is None and self.is_llama():
             handler: ChatAPIHandler | None = Llama31Handler()
@@ -162,6 +162,19 @@
             model_extras=self.model_args,
         )
 
+        def model_call(response: ChatCompletions | None = None) -> ModelCall:
+            return ModelCall.create(
+                request=request
+                | dict(
+                    messages=[message.as_dict() for message in request["messages"]],
+                    tools=[tool.as_dict() for tool in request["tools"]]
+                    if request.get("tools", None) is not None
+                    else None,
+                ),
+                response=response.as_dict() if response else {},
+                filter=image_url_filter,
+            )
+
         # make call
         try:
             response: ChatCompletions = await client.complete(**request)
```
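The `model_call` closure (added here and, in the same shape, to the Bedrock provider below) captures the request once so that both the success path and the exception path can emit a complete `ModelCall`, with the response optional. The pattern in isolation, with illustrative names:

```python
from typing import Any, Callable


def make_call_recorder(request: dict[str, Any]) -> Callable[..., dict[str, Any]]:
    """Capture the request; the response can be attached later (or never, on error)."""
    def record(response: dict[str, Any] | None = None) -> dict[str, Any]:
        return {"request": request, "response": response or {}}
    return record


record = make_call_recorder({"model": "example", "messages": []})
ok = record({"choices": []})   # success path: request plus response
err = record()                 # error path: request only, empty response
```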
```diff
@@ -173,19 +186,10 @@
                     output_tokens=response.usage.completion_tokens,
                     total_tokens=response.usage.total_tokens,
                 ),
-            ), ModelCall.create(
-                request=request
-                | dict(
-                    messages=[message.as_dict() for message in request["messages"]],
-                    tools=[tool.as_dict() for tool in request["tools"]]
-                    if request.get("tools", None) is not None
-                    else None,
-                ),
-                response=response.as_dict(),
-                filter=image_url_filter,
-            )
+            ), model_call(response)
+
         except AzureError as ex:
-            return self.handle_azure_error(ex)
+            return self.handle_azure_error(ex), model_call()
         finally:
             await client.close()
 
@@ -251,7 +255,7 @@
     def is_mistral(self) -> bool:
         return "mistral" in self.model_name.lower()
 
-    def handle_azure_error(self, ex: AzureError) -> ModelOutput:
+    def handle_azure_error(self, ex: AzureError) -> ModelOutput | Exception:
         if isinstance(ex, HttpResponseError):
             response = str(ex.message)
             if "maximum context length" in response.lower():
@@ -260,12 +264,8 @@
                     content=response,
                     stop_reason="model_length",
                 )
-            elif ex.status_code == 400 and ex.error:
-                return ModelOutput.from_content(
-                    model=self.model_name,
-                    content=f"Your request triggered an error: {ex.error}",
-                    stop_reason="content_filter",
-                )
+            elif ex.status_code == 400:
+                return ex
 
         raise ex
 
```
inspect_ai/model/_providers/bedrock.py

```diff
@@ -27,11 +27,7 @@ from .._chat_message import (
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-)
+from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .util import (
     model_base_url,
 )
@@ -307,7 +303,7 @@
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         from botocore.config import Config
         from botocore.exceptions import ClientError
 
@@ -339,25 +335,33 @@
         # Resolve the input messages into converse messages
         system, messages = await converse_messages(input)
 
-        # Make the request
-        request = ConverseClientConverseRequest(
-            modelId=self.model_name,
-            messages=messages,
-            system=system,
-            inferenceConfig=ConverseInferenceConfig(
-                maxTokens=config.max_tokens,
-                temperature=config.temperature,
-                topP=config.top_p,
-                stopSequences=config.stop_seqs,
-            ),
+        # Make the request
+        request = ConverseClientConverseRequest(
+            modelId=self.model_name,
+            messages=messages,
+            system=system,
+            inferenceConfig=ConverseInferenceConfig(
+                maxTokens=config.max_tokens,
+                temperature=config.temperature,
+                topP=config.top_p,
+                stopSequences=config.stop_seqs,
+            ),
+            additionalModelRequestFields={
+                "top_k": config.top_k,
+                **config.model_config,
+            },
+            toolConfig=tool_config,
+        )
+
+        def model_call(response: dict[str, Any] | None = None) -> ModelCall:
+            return ModelCall.create(
+                request=replace_bytes_with_placeholder(
+                    request.model_dump(exclude_none=True)
                 ),
-            additionalModelRequestFields={
-                "top_k": config.top_k,
-                **config.model_config,
-            },
-            toolConfig=tool_config,
+                response=response,
             )
 
+        try:
             # Process the response
             response = await client.converse(
                 **request.model_dump(exclude_none=True)
```
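The request construction above maps `GenerateConfig` fields onto the Bedrock Converse API's `inferenceConfig`, with provider-specific extras routed through `additionalModelRequestFields`. A hedged sketch of the same mapping against the plain synchronous boto3 client (the provider itself uses an async client; the model id and values here are placeholders):

```python
import boto3

client = boto3.client("bedrock-runtime")
response = client.converse(
    modelId="anthropic.claude-3-5-sonnet-20241022-v2:0",  # placeholder model id
    messages=[{"role": "user", "content": [{"text": "Hello"}]}],
    inferenceConfig={
        # GenerateConfig.max_tokens / temperature / top_p / stop_seqs land here
        "maxTokens": 1024,
        "temperature": 0.7,
        "topP": 0.9,
        "stopSequences": ["\n\n"],
    },
    additionalModelRequestFields={"top_k": 50},  # provider-specific passthrough
)
print(response["output"]["message"])
```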
```diff
@@ -366,32 +370,24 @@
 
         except ClientError as ex:
             # Look for an explicit validation exception
-            if (
-                ex.response["Error"]["Code"] == "ValidationException"
-                and "Too many input tokens" in ex.response["Error"]["Message"]
-            ):
+            if ex.response["Error"]["Code"] == "ValidationException":
                 response = ex.response["Error"]["Message"]
-                return ModelOutput.from_content(
-                    model=self.model_name,
-                    content=response,
-                    stop_reason="model_length",
-                )
+                if "Too many input tokens" in response:
+                    return ModelOutput.from_content(
+                        model=self.model_name,
+                        content=response,
+                        stop_reason="model_length",
+                    )
+                else:
+                    return ex, model_call(None)
             else:
                 raise ex
 
         # create a model output from the response
         output = model_output_from_response(self.model_name, converse_response, tools)
 
-        # record call
-        call = ModelCall.create(
-            request=replace_bytes_with_placeholder(
-                request.model_dump(exclude_none=True)
-            ),
-            response=response,
-        )
-
         # return
-        return output, call
+        return output, model_call(response)
 
 
 async def converse_messages(
@@ -550,6 +546,7 @@ async def converse_chat_message(
             "Tool call is missing a tool call id, which is required for Converse API"
         )
     if message.function is None:
+        print(message)
         raise ValueError(
             "Tool call is missing a function, which is required for Converse API"
         )
```