inspect-ai 0.3.57__py3-none-any.whl → 0.3.59__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_cli/common.py +7 -3
- inspect_ai/_cli/eval.py +17 -2
- inspect_ai/_cli/trace.py +21 -2
- inspect_ai/_display/core/active.py +4 -3
- inspect_ai/_display/core/config.py +3 -3
- inspect_ai/_display/core/panel.py +7 -3
- inspect_ai/_display/plain/__init__.py +0 -0
- inspect_ai/_display/plain/display.py +203 -0
- inspect_ai/_display/rich/display.py +4 -9
- inspect_ai/_display/textual/app.py +4 -1
- inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
- inspect_ai/_display/textual/widgets/samples.py +119 -16
- inspect_ai/_display/textual/widgets/sandbox.py +37 -0
- inspect_ai/_eval/eval.py +32 -20
- inspect_ai/_eval/evalset.py +7 -5
- inspect_ai/_eval/score.py +1 -0
- inspect_ai/_eval/task/__init__.py +2 -2
- inspect_ai/_eval/task/images.py +40 -25
- inspect_ai/_eval/task/results.py +50 -22
- inspect_ai/_eval/task/run.py +180 -124
- inspect_ai/_eval/task/sandbox.py +10 -5
- inspect_ai/_eval/task/task.py +140 -25
- inspect_ai/_util/constants.py +2 -0
- inspect_ai/_util/content.py +23 -1
- inspect_ai/_util/images.py +20 -17
- inspect_ai/_util/kvstore.py +73 -0
- inspect_ai/_util/notgiven.py +18 -0
- inspect_ai/_util/port_names.py +61 -0
- inspect_ai/_util/text.py +23 -0
- inspect_ai/_util/thread.py +5 -0
- inspect_ai/_view/www/App.css +31 -1
- inspect_ai/_view/www/dist/assets/index.css +31 -1
- inspect_ai/_view/www/dist/assets/index.js +25375 -1846
- inspect_ai/_view/www/log-schema.json +129 -15
- inspect_ai/_view/www/package.json +2 -0
- inspect_ai/_view/www/src/App.mjs +8 -10
- inspect_ai/_view/www/src/Types.mjs +0 -1
- inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
- inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
- inspect_ai/_view/www/src/components/MessageContent.mjs +43 -1
- inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
- inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
- inspect_ai/_view/www/src/index.js +75 -2
- inspect_ai/_view/www/src/navbar/Navbar.mjs +3 -0
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +18 -9
- inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
- inspect_ai/_view/www/src/samples/SampleList.mjs +18 -48
- inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +29 -13
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -1
- inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
- inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
- inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
- inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
- inspect_ai/_view/www/src/types/log.d.ts +62 -27
- inspect_ai/_view/www/src/utils/Format.mjs +10 -3
- inspect_ai/_view/www/src/utils/Json.mjs +12 -6
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +10 -4
- inspect_ai/_view/www/vite.config.js +7 -0
- inspect_ai/_view/www/yarn.lock +116 -0
- inspect_ai/approval/_human/__init__.py +0 -0
- inspect_ai/approval/_human/util.py +2 -2
- inspect_ai/approval/_policy.py +12 -6
- inspect_ai/dataset/_sources/csv.py +2 -1
- inspect_ai/dataset/_sources/json.py +2 -1
- inspect_ai/dataset/_sources/util.py +15 -7
- inspect_ai/log/_condense.py +11 -1
- inspect_ai/log/_log.py +3 -6
- inspect_ai/log/_recorders/eval.py +19 -8
- inspect_ai/log/_samples.py +26 -5
- inspect_ai/log/_transcript.py +32 -2
- inspect_ai/model/__init__.py +10 -2
- inspect_ai/model/_call_tools.py +59 -12
- inspect_ai/model/_chat_message.py +2 -4
- inspect_ai/model/_conversation.py +61 -0
- inspect_ai/model/_generate_config.py +10 -4
- inspect_ai/model/_model.py +117 -18
- inspect_ai/model/_model_output.py +7 -2
- inspect_ai/model/_providers/anthropic.py +109 -51
- inspect_ai/model/_providers/azureai.py +26 -24
- inspect_ai/model/_providers/bedrock.py +43 -44
- inspect_ai/model/_providers/google.py +121 -58
- inspect_ai/model/_providers/groq.py +7 -5
- inspect_ai/model/_providers/hf.py +11 -6
- inspect_ai/model/_providers/mistral.py +17 -20
- inspect_ai/model/_providers/openai.py +32 -21
- inspect_ai/model/_providers/openai_o1.py +9 -8
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +8 -8
- inspect_ai/model/_providers/vertex.py +18 -8
- inspect_ai/scorer/__init__.py +13 -2
- inspect_ai/scorer/_metrics/__init__.py +2 -2
- inspect_ai/scorer/_metrics/std.py +3 -3
- inspect_ai/scorer/_reducer/reducer.py +1 -1
- inspect_ai/scorer/_scorer.py +2 -2
- inspect_ai/solver/__init__.py +2 -5
- inspect_ai/solver/_prompt.py +35 -5
- inspect_ai/solver/_task_state.py +80 -38
- inspect_ai/tool/__init__.py +11 -1
- inspect_ai/tool/_tool.py +21 -3
- inspect_ai/tool/_tool_call.py +10 -0
- inspect_ai/tool/_tool_def.py +16 -5
- inspect_ai/tool/_tool_with.py +21 -4
- inspect_ai/tool/beta/__init__.py +5 -0
- inspect_ai/tool/beta/_computer/__init__.py +3 -0
- inspect_ai/tool/beta/_computer/_common.py +133 -0
- inspect_ai/tool/beta/_computer/_computer.py +155 -0
- inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
- inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
- inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
- inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/util/__init__.py +2 -3
- inspect_ai/util/{_trace.py → _conversation.py} +3 -17
- inspect_ai/util/_display.py +14 -4
- inspect_ai/util/_limit.py +26 -0
- inspect_ai/util/_sandbox/context.py +12 -13
- inspect_ai/util/_sandbox/docker/compose.py +24 -11
- inspect_ai/util/_sandbox/docker/docker.py +84 -14
- inspect_ai/util/_sandbox/docker/internal.py +3 -1
- inspect_ai/util/_sandbox/environment.py +27 -1
- inspect_ai/util/_sandbox/local.py +1 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/METADATA +2 -2
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/RECORD +159 -128
- inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
- inspect_ai/model/_trace.py +0 -48
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/anthropic.py

@@ -1,8 +1,14 @@
 import functools
 import os
+import sys
 from copy import copy
 from logging import getLogger
-from typing import Any, Literal, Tuple, cast
+from typing import Any, Literal, Tuple, TypedDict, cast
+
+if sys.version_info >= (3, 11):
+    from typing import NotRequired
+else:
+    from typing_extensions import NotRequired
 
 from anthropic import (
     APIConnectionError,
@@ -27,28 +33,23 @@ from anthropic.types import (
 from pydantic import JsonValue
 from typing_extensions import override
 
-from inspect_ai._util.constants import BASE_64_DATA_REMOVED, DEFAULT_MAX_RETRIES
-
+from inspect_ai._util.constants import (
+    BASE_64_DATA_REMOVED,
+    DEFAULT_MAX_RETRIES,
+    NO_CONTENT,
+)
+from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.error import exception_message
-from inspect_ai._util.images import image_as_data_uri
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.logger import warn_once
-from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64, is_data_uri
+from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
-from .._chat_message import (
-    ChatMessage,
-    ChatMessageAssistant,
-    ChatMessageSystem,
-)
+from .._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageSystem
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-    StopReason,
-)
+from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage, StopReason
 from .util import environment_prerequisite_error, model_base_url
 
 logger = getLogger(__name__)
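A note on the version-gated import above: `NotRequired` (PEP 655) only landed in the stdlib `typing` module in Python 3.11, so older interpreters must fall back to `typing_extensions`. A minimal illustration of the pattern (assumes `typing_extensions` is installed on Python < 3.11):

```python
import sys

if sys.version_info >= (3, 11):
    from typing import NotRequired, TypedDict
else:
    from typing import TypedDict
    from typing_extensions import NotRequired  # backport of PEP 655


class Display(TypedDict):
    name: str                    # required key
    width_px: NotRequired[int]   # key may be omitted entirely


d: Display = {"name": "main"}    # valid: width_px is NotRequired
```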
@@ -124,7 +125,7 @@ class AnthropicAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
         response: dict[str, Any] = {}
@@ -142,7 +143,7 @@
             system_param,
             tools_param,
             messages,
-
+            computer_use,
         ) = await resolve_chat_input(self.model_name, input, tools, config)
 
         # prepare request params (assembed this way so we can log the raw model call)
@@ -158,13 +159,11 @@
         # additional options
         request = request | self.completion_params(config)
 
-        #
-        if
-            request["extra_headers"] = {
-                "anthropic-beta": "prompt-caching-2024-07-31"
-            }
+        # computer use beta
+        if computer_use:
+            request["extra_headers"] = {"anthropic-beta": "computer-use-2024-10-22"}
 
-        #
+        # make request
         message = await self.client.messages.create(**request, stream=False)
 
         # set response for ModelCall
@@ -177,11 +176,7 @@
             return output, model_call()
 
         except BadRequestError as ex:
-
-            if error_output is not None:
-                return error_output, model_call()
-            else:
-                raise ex
+            return self.handle_bad_request(ex), model_call()
 
     def completion_params(self, config: GenerateConfig) -> dict[str, Any]:
         params = dict(model=self.model_name, max_tokens=cast(int, config.max_tokens))
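The `generate` signature here (and in the Azure AI and Bedrock providers below) widens from `tuple[ModelOutput, ModelCall]` to `tuple[ModelOutput | Exception, ModelCall]`: a bad request can now surface as the original exception while the raw `ModelCall` is still recorded. A hypothetical caller sketch (not the package's actual dispatch code) of consuming the widened contract:

```python
# Hypothetical consumer of ModelAPI.generate() under the widened contract.
async def call_provider(api, input, tools, tool_choice, config):
    result = await api.generate(input, tools, tool_choice, config)
    if isinstance(result, tuple):
        output, call = result   # call is recorded even when output is an error
    else:
        output, call = result, None
    if isinstance(output, Exception):
        raise output            # or map it to a sample-level error
    return output, call
```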
@@ -234,7 +229,7 @@ class AnthropicAPI(ModelAPI):
         return True
 
     # convert some common BadRequestError states into 'refusal' model output
-    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput |
+    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
         error = exception_message(ex).lower()
         content: str | None = None
         stop_reason: StopReason | None = None
@@ -256,6 +251,9 @@
         elif "content filtering" in error:
             content = "Sorry, but I am unable to help with that request."
             stop_reason = "content_filter"
+        else:
+            content = error
+            stop_reason = "unknown"
 
         if content and stop_reason:
             return ModelOutput.from_content(
@@ -265,7 +263,21 @@
                 error=error,
             )
         else:
-            return
+            return ex
+
+
+# native anthropic tool definitions for computer use beta
+# https://docs.anthropic.com/en/docs/build-with-claude/computer-use
+class ComputerUseToolParam(TypedDict):
+    type: str
+    name: str
+    display_width_px: NotRequired[int]
+    display_height_px: NotRequired[int]
+    display_number: NotRequired[int]
+
+
+# tools can be either a stock tool param or a special computer use tool param
+ToolParamDef = ToolParam | ComputerUseToolParam
 
 
 async def resolve_chat_input(
@@ -273,7 +285,7 @@ async def resolve_chat_input(
     input: list[ChatMessage],
     tools: list[ToolInfo],
     config: GenerateConfig,
-) -> Tuple[list[TextBlockParam] | None, list[
+) -> Tuple[list[TextBlockParam] | None, list[ToolParamDef], list[MessageParam], bool]:
     # extract system message
     system_messages, messages = split_system_messages(input, config)
 
@@ -286,14 +298,7 @@
     )
 
     # tools
-    tools_params = [
-        ToolParam(
-            name=tool.name,
-            description=tool.description,
-            input_schema=tool.parameters.model_dump(exclude_none=True),
-        )
-        for tool in tools
-    ]
+    tools_params, computer_use = tool_params_for_tools(tools, config)
 
     # system messages
     if len(system_messages) > 0:
@@ -343,10 +348,66 @@
         add_cache_control(cast(dict[str, Any], content[-1]))
 
     # return chat input
-    return system_param, tools_params, message_params,
+    return system_param, tools_params, message_params, computer_use
+
+
+def tool_params_for_tools(
+    tools: list[ToolInfo], config: GenerateConfig
+) -> tuple[list[ToolParamDef], bool]:
+    # tool params and computer_use bit to return
+    tool_params: list[ToolParamDef] = []
+    computer_use = False
+
+    # for each tool, check if it has a native computer use implementation and use that
+    # when available (noting that we need to set the computer use request header)
+    for tool in tools:
+        computer_use_tool = (
+            computer_use_tool_param(tool)
+            if config.internal_tools is not False
+            else None
+        )
+        if computer_use_tool:
+            tool_params.append(computer_use_tool)
+            computer_use = True
+        else:
+            tool_params.append(
+                ToolParam(
+                    name=tool.name,
+                    description=tool.description,
+                    input_schema=tool.parameters.model_dump(exclude_none=True),
+                )
+            )
+
+    return tool_params, computer_use
+
 
+def computer_use_tool_param(tool: ToolInfo) -> ComputerUseToolParam | None:
+    # check for compatible 'computer' tool
+    if tool.name == "computer" and (
+        sorted(tool.parameters.properties.keys())
+        == sorted(["action", "coordinate", "text"])
+    ):
+        return ComputerUseToolParam(
+            type="computer_20241022",
+            name="computer",
+            # Note: The dimensions passed here for display_width_px and display_height_px should
+            # match the dimensions of screenshots returned by the tool.
+            # Those dimensions will always be one of the values in MAX_SCALING_TARGETS
+            # in _x11_client.py.
+            # TODO: enhance this code to calculate the dimensions based on the scaled screen
+            # size used by the container.
+            display_width_px=1366,
+            display_height_px=768,
+            display_number=1,
+        )
+    # not a computer_use tool
+    else:
+        return None
 
-
+
+def add_cache_control(
+    param: TextBlockParam | ToolParam | ComputerUseToolParam | dict[str, Any],
+) -> None:
     cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
 
 
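For reference, a hypothetical tool definition that would satisfy the `computer_use_tool_param` check above: it must be named `computer` and expose exactly the parameters `action`, `coordinate`, and `text` (the packaged implementation lives under `inspect_ai/tool/beta/_computer/`). A sketch using the `@tool` decorator:

```python
from inspect_ai.tool import tool


@tool
def computer():
    async def execute(
        action: str,
        coordinate: list[int] | None = None,
        text: str | None = None,
    ) -> str:
        """Interact with a computer display.

        Args:
            action: The action to perform (e.g. "screenshot", "left_click").
            coordinate: Screen coordinate for pointer actions.
            text: Text to type for keyboard actions.

        Returns:
            The result of the action.
        """
        raise NotImplementedError()  # the real tool runs inside the sandbox

    return execute
```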
@@ -404,11 +465,6 @@ def message_tool_choice(tool_choice: ToolChoice) -> message_create_params.ToolCh
         return {"type": "auto"}
 
 
-# text we insert when there is no content passed
-# (as this will result in an Anthropic API error)
-NO_CONTENT = "(no content)"
-
-
 async def message_param(message: ChatMessage) -> MessageParam:
     # no system role for anthropic (this is more like an assertion,
     # as these should have already been filtered out)
@@ -584,11 +640,9 @@ async def message_param_content(
 ) -> TextBlockParam | ImageBlockParam:
     if isinstance(content, ContentText):
         return TextBlockParam(type="text", text=content.text or NO_CONTENT)
-
+    elif isinstance(content, ContentImage):
         # resolve to url
-        image = content.image
-        if not is_data_uri(image):
-            image = await image_as_data_uri(image)
+        image = await file_as_data_uri(content.image)
 
         # resolve mime type and base64 content
         media_type = data_uri_mime_type(image) or "image/png"
@@ -601,6 +655,10 @@ async def message_param_content(
             type="image",
             source=dict(type="base64", media_type=cast(Any, media_type), data=image),
         )
+    else:
+        raise RuntimeError(
+            "Anthropic models do not currently support audio or video inputs."
+        )
 
 
 def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
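The new `else: raise RuntimeError(...)` branch exists because the `Content` union is now wider than what this provider accepts (note the `inspect_ai/_util/content.py +23 -1` entry in the file list). An illustrative dispatch, assuming only text and images are consumable:

```python
from inspect_ai._util.content import Content, ContentImage, ContentText


def render(content: Content) -> str:
    if isinstance(content, ContentText):
        return content.text
    elif isinstance(content, ContentImage):
        return f"[image: {content.image[:40]}...]"
    else:
        # fail loudly rather than silently dropping unsupported content
        raise RuntimeError("This provider does not support audio or video inputs.")
```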
inspect_ai/model/_providers/azureai.py

@@ -31,8 +31,8 @@ from azure.core.exceptions import AzureError, HttpResponseError
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai._util.content import Content, ContentText
-from inspect_ai._util.images import
+from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_choice import ToolFunction
@@ -130,7 +130,7 @@ class AzureAIAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # emulate tools (auto for llama, opt-in for others)
         if self.emulate_tools is None and self.is_llama():
             handler: ChatAPIHandler | None = Llama31Handler()
@@ -162,6 +162,19 @@
             model_extras=self.model_args,
         )
 
+        def model_call(response: ChatCompletions | None = None) -> ModelCall:
+            return ModelCall.create(
+                request=request
+                | dict(
+                    messages=[message.as_dict() for message in request["messages"]],
+                    tools=[tool.as_dict() for tool in request["tools"]]
+                    if request.get("tools", None) is not None
+                    else None,
+                ),
+                response=response.as_dict() if response else {},
+                filter=image_url_filter,
+            )
+
         # make call
         try:
             response: ChatCompletions = await client.complete(**request)
@@ -173,19 +186,10 @@
                     output_tokens=response.usage.completion_tokens,
                     total_tokens=response.usage.total_tokens,
                 ),
-            ),
-
-                | dict(
-                    messages=[message.as_dict() for message in request["messages"]],
-                    tools=[tool.as_dict() for tool in request["tools"]]
-                    if request.get("tools", None) is not None
-                    else None,
-                ),
-                response=response.as_dict(),
-                filter=image_url_filter,
-            )
+            ), model_call(response)
+
         except AzureError as ex:
-            return self.handle_azure_error(ex)
+            return self.handle_azure_error(ex), model_call()
         finally:
             await client.close()
 
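The `model_call()` closure above is the recurring pattern in this release: build the request once, then record a `ModelCall` from both the success path and the error path (where no response exists yet). A generic sketch of the pattern with a stand-in `send` transport (not the Azure SDK):

```python
from typing import Any


async def send(request: dict[str, Any]) -> dict[str, Any]:
    # stand-in for the provider SDK call
    return {"ok": True}


async def generate(request: dict[str, Any]) -> tuple[Any, dict[str, Any]]:
    def model_call(response: dict[str, Any] | None = None) -> dict[str, Any]:
        # response is None when the provider raised before producing one
        return {"request": request, "response": response or {}}

    try:
        response = await send(request)
        return response, model_call(response)
    except Exception as ex:
        return ex, model_call()
```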
@@ -251,7 +255,7 @@ class AzureAIAPI(ModelAPI):
     def is_mistral(self) -> bool:
         return "mistral" in self.model_name.lower()
 
-    def handle_azure_error(self, ex: AzureError) -> ModelOutput:
+    def handle_azure_error(self, ex: AzureError) -> ModelOutput | Exception:
         if isinstance(ex, HttpResponseError):
             response = str(ex.message)
             if "maximum context length" in response.lower():
@@ -260,12 +264,8 @@
                     content=response,
                     stop_reason="model_length",
                 )
-            elif ex.status_code == 400
-                return ModelOutput.from_content(
-                    model=self.model_name,
-                    content=f"Your request triggered an error: {ex.error}",
-                    stop_reason="content_filter",
-                )
+            elif ex.status_code == 400:
+                return ex
 
         raise ex
 
@@ -312,12 +312,14 @@ async def chat_request_message(
 async def chat_content_item(content: Content) -> ContentItem:
     if isinstance(content, ContentText):
         return TextContentItem(text=content.text)
-
+    elif isinstance(content, ContentImage):
         return ImageContentItem(
             image_url=ImageUrl(
-                url=await
+                url=await file_as_data_uri(content.image), detail=content.detail
             )
         )
+    else:
+        raise RuntimeError("Azure AI models do not support audio or video inputs.")
 
 
 def chat_tool_call(tool_call: ToolCall) -> ChatCompletionsToolCall:
inspect_ai/model/_providers/bedrock.py

@@ -11,7 +11,7 @@ from inspect_ai._util.constants import (
 )
 from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.error import pip_dependency_error
-from inspect_ai._util.images import
+from inspect_ai._util.images import file_as_data
 from inspect_ai._util.version import verify_required_version
 from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
@@ -27,11 +27,7 @@ from .._chat_message import (
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-)
+from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .util import (
     model_base_url,
 )
@@ -307,7 +303,7 @@ class BedrockAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         from botocore.config import Config
         from botocore.exceptions import ClientError
 
@@ -339,25 +335,33 @@
         # Resolve the input messages into converse messages
         system, messages = await converse_messages(input)
 
-
-
-
-
-
-
-
-
-
-
-
+        # Make the request
+        request = ConverseClientConverseRequest(
+            modelId=self.model_name,
+            messages=messages,
+            system=system,
+            inferenceConfig=ConverseInferenceConfig(
+                maxTokens=config.max_tokens,
+                temperature=config.temperature,
+                topP=config.top_p,
+                stopSequences=config.stop_seqs,
+            ),
+            additionalModelRequestFields={
+                "top_k": config.top_k,
+                **config.model_config,
+            },
+            toolConfig=tool_config,
+        )
+
+        def model_call(response: dict[str, Any] | None = None) -> ModelCall:
+            return ModelCall.create(
+                request=replace_bytes_with_placeholder(
+                    request.model_dump(exclude_none=True)
                 ),
-
-                "top_k": config.top_k,
-                **config.model_config,
-            },
-            toolConfig=tool_config,
+                response=response,
             )
 
+        try:
             # Process the reponse
             response = await client.converse(
                 **request.model_dump(exclude_none=True)
@@ -366,32 +370,24 @@ class BedrockAPI(ModelAPI):
 
         except ClientError as ex:
             # Look for an explicit validation exception
-            if (
-                ex.response["Error"]["Code"] == "ValidationException"
-                and "Too many input tokens" in ex.response["Error"]["Message"]
-            ):
+            if ex.response["Error"]["Code"] == "ValidationException":
                 response = ex.response["Error"]["Message"]
-
-
-
-
-
+                if "Too many input tokens" in response:
+                    return ModelOutput.from_content(
+                        model=self.model_name,
+                        content=response,
+                        stop_reason="model_length",
+                    )
+                else:
+                    return ex, model_call(None)
             else:
                 raise ex
 
         # create a model output from the response
         output = model_output_from_response(self.model_name, converse_response, tools)
 
-        # record call
-        call = ModelCall.create(
-            request=replace_bytes_with_placeholder(
-                request.model_dump(exclude_none=True)
-            ),
-            response=response,
-        )
-
         # return
-        return output, call
+        return output, model_call(response)
 
 
 async def converse_messages(
@@ -430,7 +426,9 @@ def model_output_from_response(
             content.append(ContentText(type="text", text=c.text))
         elif c.image is not None:
             base64_image = base64.b64encode(c.image.source.bytes).decode("utf-8")
-            content.append(
+            content.append(
+                ContentImage(image=f"data:image/{c.image.format};base64,{base64_image}")
+            )
         elif c.toolUse is not None:
             tool_calls.append(
                 ToolCall(
@@ -548,6 +546,7 @@ async def converse_chat_message(
                 "Tool call is missing a tool call id, which is required for Converse API"
             )
         if message.function is None:
+            print(message)
             raise ValueError(
                 "Tool call is missing a function, which is required for Converse API"
             )
@@ -565,7 +564,7 @@
             if c.type == "text":
                 tool_result_content.append(ConverseToolResultContent(text=c.text))
             elif c.type == "image":
-                image_data, image_type = await
+                image_data, image_type = await file_as_data(c.image)
                 tool_result_content.append(
                     ConverseToolResultContent(
                         image=ConverseImage(
@@ -604,7 +603,7 @@ async def converse_contents(
     result: list[ConverseMessageContent] = []
     for c in content:
         if c.type == "image":
-            image_data, image_type = await
+            image_data, image_type = await file_as_data(c.image)
             result.append(
                 ConverseMessageContent(
                     image=ConverseImage(