inspect-ai 0.3.58__py3-none-any.whl → 0.3.60__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (166)
  1. inspect_ai/_cli/common.py +3 -1
  2. inspect_ai/_cli/eval.py +15 -9
  3. inspect_ai/_display/core/active.py +4 -1
  4. inspect_ai/_display/core/config.py +3 -3
  5. inspect_ai/_display/core/panel.py +7 -3
  6. inspect_ai/_display/plain/__init__.py +0 -0
  7. inspect_ai/_display/plain/display.py +203 -0
  8. inspect_ai/_display/rich/display.py +0 -5
  9. inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
  10. inspect_ai/_display/textual/widgets/samples.py +79 -12
  11. inspect_ai/_display/textual/widgets/sandbox.py +37 -0
  12. inspect_ai/_eval/eval.py +10 -1
  13. inspect_ai/_eval/loader.py +79 -19
  14. inspect_ai/_eval/registry.py +6 -0
  15. inspect_ai/_eval/score.py +3 -1
  16. inspect_ai/_eval/task/results.py +51 -22
  17. inspect_ai/_eval/task/run.py +47 -13
  18. inspect_ai/_eval/task/sandbox.py +10 -5
  19. inspect_ai/_util/constants.py +1 -0
  20. inspect_ai/_util/port_names.py +61 -0
  21. inspect_ai/_util/text.py +23 -0
  22. inspect_ai/_view/www/App.css +31 -1
  23. inspect_ai/_view/www/dist/assets/index.css +31 -1
  24. inspect_ai/_view/www/dist/assets/index.js +25498 -2044
  25. inspect_ai/_view/www/log-schema.json +32 -2
  26. inspect_ai/_view/www/package.json +2 -0
  27. inspect_ai/_view/www/src/App.mjs +14 -16
  28. inspect_ai/_view/www/src/Types.mjs +1 -2
  29. inspect_ai/_view/www/src/api/Types.ts +133 -0
  30. inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
  31. inspect_ai/_view/www/src/api/api-http.ts +219 -0
  32. inspect_ai/_view/www/src/api/api-shared.ts +47 -0
  33. inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
  34. inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
  35. inspect_ai/_view/www/src/api/index.ts +51 -0
  36. inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
  37. inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
  38. inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
  39. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
  40. inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
  41. inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
  42. inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
  43. inspect_ai/_view/www/src/index.js +77 -4
  44. inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
  45. inspect_ai/_view/www/src/navbar/Navbar.mjs +4 -1
  46. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +19 -10
  47. inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
  48. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
  49. inspect_ai/_view/www/src/samples/SampleList.mjs +19 -49
  50. inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
  51. inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
  52. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -26
  53. inspect_ai/_view/www/src/samples/SamplesTab.mjs +14 -11
  54. inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
  55. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
  56. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
  57. inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
  58. inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
  59. inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
  60. inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
  61. inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
  62. inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
  63. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
  64. inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
  65. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
  66. inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
  67. inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
  68. inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
  69. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
  70. inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
  71. inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
  72. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
  73. inspect_ai/_view/www/src/types/log.d.ts +13 -2
  74. inspect_ai/_view/www/src/utils/Format.mjs +10 -3
  75. inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +13 -9
  76. inspect_ai/_view/www/src/utils/vscode.ts +36 -0
  77. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +11 -5
  78. inspect_ai/_view/www/vite.config.js +7 -0
  79. inspect_ai/_view/www/yarn.lock +116 -0
  80. inspect_ai/approval/_human/__init__.py +0 -0
  81. inspect_ai/approval/_human/manager.py +1 -1
  82. inspect_ai/approval/_policy.py +12 -6
  83. inspect_ai/log/_log.py +1 -1
  84. inspect_ai/log/_samples.py +16 -0
  85. inspect_ai/log/_transcript.py +4 -1
  86. inspect_ai/model/_call_tools.py +59 -0
  87. inspect_ai/model/_conversation.py +16 -7
  88. inspect_ai/model/_generate_config.py +12 -12
  89. inspect_ai/model/_model.py +117 -18
  90. inspect_ai/model/_model_output.py +22 -2
  91. inspect_ai/model/_openai.py +383 -0
  92. inspect_ai/model/_providers/anthropic.py +152 -55
  93. inspect_ai/model/_providers/azureai.py +21 -21
  94. inspect_ai/model/_providers/bedrock.py +37 -40
  95. inspect_ai/model/_providers/goodfire.py +248 -0
  96. inspect_ai/model/_providers/google.py +46 -54
  97. inspect_ai/model/_providers/groq.py +7 -3
  98. inspect_ai/model/_providers/hf.py +6 -0
  99. inspect_ai/model/_providers/mistral.py +13 -12
  100. inspect_ai/model/_providers/openai.py +51 -218
  101. inspect_ai/model/_providers/openai_o1.py +11 -12
  102. inspect_ai/model/_providers/providers.py +23 -1
  103. inspect_ai/model/_providers/together.py +12 -12
  104. inspect_ai/model/_providers/util/__init__.py +2 -3
  105. inspect_ai/model/_providers/util/hf_handler.py +1 -1
  106. inspect_ai/model/_providers/util/llama31.py +1 -1
  107. inspect_ai/model/_providers/util/util.py +0 -76
  108. inspect_ai/model/_providers/vertex.py +1 -4
  109. inspect_ai/scorer/_metric.py +3 -0
  110. inspect_ai/scorer/_reducer/reducer.py +1 -1
  111. inspect_ai/scorer/_scorer.py +4 -3
  112. inspect_ai/solver/__init__.py +4 -5
  113. inspect_ai/solver/_basic_agent.py +1 -1
  114. inspect_ai/solver/_bridge/__init__.py +3 -0
  115. inspect_ai/solver/_bridge/bridge.py +100 -0
  116. inspect_ai/solver/_bridge/patch.py +170 -0
  117. inspect_ai/solver/_prompt.py +35 -5
  118. inspect_ai/solver/_solver.py +6 -0
  119. inspect_ai/solver/_task_state.py +80 -38
  120. inspect_ai/tool/__init__.py +2 -0
  121. inspect_ai/tool/_tool.py +12 -1
  122. inspect_ai/tool/_tool_call.py +10 -0
  123. inspect_ai/tool/_tool_def.py +16 -5
  124. inspect_ai/tool/_tool_with.py +21 -4
  125. inspect_ai/tool/beta/__init__.py +5 -0
  126. inspect_ai/tool/beta/_computer/__init__.py +3 -0
  127. inspect_ai/tool/beta/_computer/_common.py +133 -0
  128. inspect_ai/tool/beta/_computer/_computer.py +155 -0
  129. inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
  130. inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
  131. inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
  132. inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
  133. inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
  134. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
  135. inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
  136. inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
  137. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
  138. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
  139. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
  140. inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
  141. inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
  142. inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
  143. inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
  144. inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
  145. inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
  146. inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
  147. inspect_ai/util/__init__.py +2 -0
  148. inspect_ai/util/_display.py +5 -0
  149. inspect_ai/util/_limit.py +26 -0
  150. inspect_ai/util/_sandbox/docker/docker.py +64 -1
  151. inspect_ai/util/_sandbox/docker/internal.py +3 -1
  152. inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
  153. inspect_ai/util/_sandbox/environment.py +14 -0
  154. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
  155. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +159 -126
  156. inspect_ai/_view/www/src/api/Types.mjs +0 -117
  157. inspect_ai/_view/www/src/api/api-http.mjs +0 -300
  158. inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
  159. inspect_ai/_view/www/src/api/index.mjs +0 -49
  160. inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
  161. inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
  162. inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
  163. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
  164. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
  165. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
  166. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/anthropic.py

@@ -1,17 +1,26 @@
 import functools
 import os
+import sys
 from copy import copy
 from logging import getLogger
-from typing import Any, Literal, Tuple, cast
+from typing import Any, Literal, Tuple, TypedDict, cast
+
+if sys.version_info >= (3, 11):
+    from typing import NotRequired
+else:
+    from typing_extensions import NotRequired
 
 from anthropic import (
     APIConnectionError,
     AsyncAnthropic,
     AsyncAnthropicBedrock,
+    AsyncAnthropicVertex,
     BadRequestError,
     InternalServerError,
+    NotGiven,
     RateLimitError,
 )
+from anthropic._types import Body
 from anthropic.types import (
     ImageBlockParam,
     Message,
@@ -27,7 +36,11 @@ from anthropic.types import (
 from pydantic import JsonValue
 from typing_extensions import override
 
-from inspect_ai._util.constants import BASE_64_DATA_REMOVED, DEFAULT_MAX_RETRIES
+from inspect_ai._util.constants import (
+    BASE_64_DATA_REMOVED,
+    DEFAULT_MAX_RETRIES,
+    NO_CONTENT,
+)
 from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.error import exception_message
 from inspect_ai._util.images import file_as_data_uri
@@ -35,20 +48,11 @@ from inspect_ai._util.logger import warn_once
 from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
-from .._chat_message import (
-    ChatMessage,
-    ChatMessageAssistant,
-    ChatMessageSystem,
-)
+from .._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageSystem
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-    StopReason,
-)
+from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage, StopReason
 from .util import environment_prerequisite_error, model_base_url
 
 logger = getLogger(__name__)
@@ -63,15 +67,25 @@ class AnthropicAPI(ModelAPI):
         base_url: str | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
-        bedrock: bool = False,
         **model_args: Any,
     ):
         # extract any service prefix from model name
         parts = model_name.split("/")
         if len(parts) > 1:
-            service = parts[0]
-            bedrock = service == "bedrock"
+            self.service: str | None = parts[0]
             model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
+        # collect generate model_args (then delete them so we can pass the rest on)
+        def collect_model_arg(name: str) -> Any | None:
+            nonlocal model_args
+            value = model_args.get(name, None)
+            if value is not None:
+                model_args.pop(name)
+            return value
+
+        self.extra_body: Body | None = collect_model_arg("extra_body")
 
         # call super
         super().__init__(
@@ -83,7 +97,7 @@ class AnthropicAPI(ModelAPI):
         )
 
         # create client
-        if bedrock:
+        if self.is_bedrock():
             base_url = model_base_url(
                 base_url, ["ANTHROPIC_BEDROCK_BASE_URL", "BEDROCK_ANTHROPIC_BASE_URL"]
             )
@@ -94,7 +108,9 @@ class AnthropicAPI(ModelAPI):
             if base_region is None:
                 aws_region = os.environ.get("AWS_DEFAULT_REGION", None)
 
-            self.client: AsyncAnthropic | AsyncAnthropicBedrock = AsyncAnthropicBedrock(
+            self.client: (
+                AsyncAnthropic | AsyncAnthropicBedrock | AsyncAnthropicVertex
+            ) = AsyncAnthropicBedrock(
                 base_url=base_url,
                 max_retries=(
                     config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
@@ -102,6 +118,21 @@ class AnthropicAPI(ModelAPI):
                 aws_region=aws_region,
                 **model_args,
             )
+        elif self.is_vertex():
+            base_url = model_base_url(
+                base_url, ["ANTHROPIC_VERTEX_BASE_URL", "VERTEX_ANTHROPIC_BASE_URL"]
+            )
+            region = os.environ.get("ANTHROPIC_VERTEX_REGION", NotGiven())
+            project_id = os.environ.get("ANTHROPIC_VERTEX_PROJECT_ID", NotGiven())
+            self.client = AsyncAnthropicVertex(
+                region=region,
+                project_id=project_id,
+                base_url=base_url,
+                max_retries=(
+                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
+                ),
+                **model_args,
+            )
         else:
             # resolve api_key
             if not self.api_key:
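With the hardcoded bedrock flag replaced by a parsed service prefix, the one provider now routes to three backends. A minimal sketch of selecting each backend through get_model (the specific model names here are illustrative, not taken from the diff):

    from inspect_ai.model import get_model

    # direct Anthropic API
    model = get_model("anthropic/claude-3-5-sonnet-20241022")

    # AWS Bedrock via the "bedrock" service prefix
    bedrock = get_model("anthropic/bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Google Vertex via the new "vertex" service prefix (region and project
    # are read from ANTHROPIC_VERTEX_REGION / ANTHROPIC_VERTEX_PROJECT_ID)
    vertex = get_model("anthropic/vertex/claude-3-5-sonnet-v2@20241022")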
@@ -118,13 +149,19 @@ class AnthropicAPI(ModelAPI):
                 **model_args,
             )
 
+    def is_bedrock(self) -> bool:
+        return self.service == "bedrock"
+
+    def is_vertex(self) -> bool:
+        return self.service == "vertex"
+
     async def generate(
         self,
         input: list[ChatMessage],
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
         response: dict[str, Any] = {}
@@ -142,7 +179,7 @@ class AnthropicAPI(ModelAPI):
                 system_param,
                 tools_param,
                 messages,
-                cache_prompt,
+                computer_use,
             ) = await resolve_chat_input(self.model_name, input, tools, config)
 
             # prepare request params (assembled this way so we can log the raw model call)
@@ -158,13 +195,15 @@ class AnthropicAPI(ModelAPI):
             # additional options
             request = request | self.completion_params(config)
 
-            # caching header
-            if cache_prompt:
-                request["extra_headers"] = {
-                    "anthropic-beta": "prompt-caching-2024-07-31"
-                }
+            # computer use beta
+            if computer_use:
+                request["extra_headers"] = {"anthropic-beta": "computer-use-2024-10-22"}
 
-            # call model
+            # extra_body
+            if self.extra_body is not None:
+                request["extra_body"] = self.extra_body
+
+            # make request
             message = await self.client.messages.create(**request, stream=False)
 
             # set response for ModelCall
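Because extra_body is popped from model_args in the constructor and attached to every request here, provider-specific body fields pass straight through to client.messages.create(). A hedged sketch of supplying it (the payload shown is purely illustrative):

    from inspect_ai.model import get_model

    # extra_body rides along on each generate call; any other unrecognized
    # model_args are still forwarded to the client constructor
    model = get_model(
        "anthropic/claude-3-5-sonnet-20241022",
        extra_body={"metadata": {"user_id": "example-user"}},
    )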
@@ -177,11 +216,7 @@ class AnthropicAPI(ModelAPI):
             return output, model_call()
 
         except BadRequestError as ex:
-            error_output = self.handle_bad_request(ex)
-            if error_output is not None:
-                return error_output, model_call()
-            else:
-                raise ex
+            return self.handle_bad_request(ex), model_call()
 
     def completion_params(self, config: GenerateConfig) -> dict[str, Any]:
         params = dict(model=self.model_name, max_tokens=cast(int, config.max_tokens))
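A recurring change across the providers in this release: generate() may now return an Exception (paired with the ModelCall for logging) instead of raising or returning None on handled request errors. A hedged sketch of consuming that contract directly against a ModelAPI instance (names and the surrounding scaffolding are assumed, not part of the diff):

    async def generate_or_raise(api, input, tools, tool_choice, config):
        # ModelAPI.generate may return a bare ModelOutput or a
        # (ModelOutput | Exception, ModelCall) tuple as of this release
        result = await api.generate(input, tools, tool_choice, config)
        if isinstance(result, tuple):
            output, call = result
            if isinstance(output, Exception):
                raise output  # the ModelCall is still available for logging
            return output, call
        return result, None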
@@ -234,7 +269,7 @@ class AnthropicAPI(ModelAPI):
         return True
 
     # convert some common BadRequestError states into 'refusal' model output
-    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | None:
+    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
         error = exception_message(ex).lower()
         content: str | None = None
         stop_reason: StopReason | None = None
@@ -256,6 +291,9 @@ class AnthropicAPI(ModelAPI):
         elif "content filtering" in error:
             content = "Sorry, but I am unable to help with that request."
             stop_reason = "content_filter"
+        else:
+            content = error
+            stop_reason = "unknown"
 
         if content and stop_reason:
             return ModelOutput.from_content(
@@ -265,7 +303,21 @@ class AnthropicAPI(ModelAPI):
                 error=error,
             )
         else:
-            return None
+            return ex
+
+
+# native anthropic tool definitions for computer use beta
+# https://docs.anthropic.com/en/docs/build-with-claude/computer-use
+class ComputerUseToolParam(TypedDict):
+    type: str
+    name: str
+    display_width_px: NotRequired[int]
+    display_height_px: NotRequired[int]
+    display_number: NotRequired[int]
+
+
+# tools can be either a stock tool param or a special computer use tool param
+ToolParamDef = ToolParam | ComputerUseToolParam
 
 
 async def resolve_chat_input(
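Since the display fields are NotRequired, both a fully specified computer-use param and a bare {type, name} dict satisfy the TypedDict. A small illustrative sketch:

    # fully specified, as computer_use_tool_param() constructs it below
    full: ComputerUseToolParam = {
        "type": "computer_20241022",
        "name": "computer",
        "display_width_px": 1366,
        "display_height_px": 768,
        "display_number": 1,
    }

    # the display fields may be omitted and the dict still type-checks
    minimal: ComputerUseToolParam = {"type": "computer_20241022", "name": "computer"}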
@@ -273,7 +325,7 @@ async def resolve_chat_input(
     input: list[ChatMessage],
     tools: list[ToolInfo],
     config: GenerateConfig,
-) -> Tuple[list[TextBlockParam] | None, list[ToolParam], list[MessageParam], bool]:
+) -> Tuple[list[TextBlockParam] | None, list[ToolParamDef], list[MessageParam], bool]:
     # extract system message
     system_messages, messages = split_system_messages(input, config)
 
@@ -286,14 +338,7 @@ async def resolve_chat_input(
     )
 
     # tools
-    tools_params = [
-        ToolParam(
-            name=tool.name,
-            description=tool.description,
-            input_schema=tool.parameters.model_dump(exclude_none=True),
-        )
-        for tool in tools
-    ]
+    tools_params, computer_use = tool_params_for_tools(tools, config)
 
     # system messages
     if len(system_messages) > 0:
@@ -343,10 +388,66 @@ async def resolve_chat_input(
         add_cache_control(cast(dict[str, Any], content[-1]))
 
     # return chat input
-    return system_param, tools_params, message_params, cache_prompt
+    return system_param, tools_params, message_params, computer_use
+
+
+def tool_params_for_tools(
+    tools: list[ToolInfo], config: GenerateConfig
+) -> tuple[list[ToolParamDef], bool]:
+    # tool params and computer_use bit to return
+    tool_params: list[ToolParamDef] = []
+    computer_use = False
+
+    # for each tool, check if it has a native computer use implementation and use that
+    # when available (noting that we need to set the computer use request header)
+    for tool in tools:
+        computer_use_tool = (
+            computer_use_tool_param(tool)
+            if config.internal_tools is not False
+            else None
+        )
+        if computer_use_tool:
+            tool_params.append(computer_use_tool)
+            computer_use = True
+        else:
+            tool_params.append(
+                ToolParam(
+                    name=tool.name,
+                    description=tool.description,
+                    input_schema=tool.parameters.model_dump(exclude_none=True),
+                )
+            )
+
+    return tool_params, computer_use
 
 
-def add_cache_control(param: TextBlockParam | ToolParam | dict[str, Any]) -> None:
+def computer_use_tool_param(tool: ToolInfo) -> ComputerUseToolParam | None:
+    # check for compatible 'computer' tool
+    if tool.name == "computer" and (
+        sorted(tool.parameters.properties.keys())
+        == sorted(["action", "coordinate", "text"])
+    ):
+        return ComputerUseToolParam(
+            type="computer_20241022",
+            name="computer",
+            # Note: The dimensions passed here for display_width_px and display_height_px should
+            # match the dimensions of screenshots returned by the tool.
+            # Those dimensions will always be one of the values in MAX_SCALING_TARGETS
+            # in _x11_client.py.
+            # TODO: enhance this code to calculate the dimensions based on the scaled screen
+            # size used by the container.
+            display_width_px=1366,
+            display_height_px=768,
+            display_number=1,
+        )
+    # not a computer_use tool
+    else:
+        return None
+
+
+def add_cache_control(
+    param: TextBlockParam | ToolParam | ComputerUseToolParam | dict[str, Any],
+) -> None:
     cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
 
 
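The config.internal_tools check above gives callers an opt-out: when it is False, the native computer-use mapping is skipped and every tool is sent as a stock ToolParam (so no beta header is added). A hedged sketch of setting it (the model name is illustrative):

    from inspect_ai.model import GenerateConfig, get_model

    model = get_model(
        "anthropic/claude-3-5-sonnet-20241022",
        config=GenerateConfig(internal_tools=False),
    )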
@@ -404,12 +505,13 @@ def message_tool_choice(tool_choice: ToolChoice) -> message_create_params.ToolCh
         return {"type": "auto"}
 
 
-# text we insert when there is no content passed
-# (as this will result in an Anthropic API error)
-NO_CONTENT = "(no content)"
-
-
 async def message_param(message: ChatMessage) -> MessageParam:
+    # if content is empty that is going to result in an error when we replay
+    # this message to claude, so in that case insert a NO_CONTENT message
+    if isinstance(message.content, list) and len(message.content) == 0:
+        message = message.model_copy()
+        message.content = [ContentText(text=NO_CONTENT)]
+
     # no system role for anthropic (this is more like an assertion,
     # as these should have already been filtered out)
     if message.role == "system":
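The empty-content guard moved here from model_output_from_message (removed below), so the placeholder is now injected when replaying any message with an empty content list, not only assistant output. A standalone sketch of the same logic, using the imports this file confirms (constructing the message this way is illustrative):

    from inspect_ai._util.constants import NO_CONTENT
    from inspect_ai._util.content import ContentText
    from inspect_ai.model import ChatMessageAssistant

    # a message replayed with an empty content list...
    message = ChatMessageAssistant(content=[])
    if isinstance(message.content, list) and len(message.content) == 0:
        message = message.model_copy()
        # ...now carries the "(no content)" placeholder instead
        message.content = [ContentText(text=NO_CONTENT)]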
@@ -451,7 +553,7 @@ async def message_param(message: ChatMessage) -> MessageParam:
     elif message.role == "assistant" and message.tool_calls:
         # first include content (claude <thinking>)
         tools_content: list[TextBlockParam | ImageBlockParam | ToolUseBlockParam] = (
-            [TextBlockParam(type="text", text=message.content)]
+            [TextBlockParam(type="text", text=message.content or NO_CONTENT)]
             if isinstance(message.content, str)
             else (
                 [(await message_param_content(content)) for content in message.content]
@@ -520,11 +622,6 @@ def model_output_from_message(message: Message, tools: list[ToolInfo]) -> ModelO
             )
         )
 
-    # if content is empty that is going to result in an error when we replay
-    # this message to claude, so in that case insert a NO_CONTENT message
-    if len(content) == 0:
-        content = [ContentText(text=NO_CONTENT)]
-
     # resolve choice
     choice = ChatCompletionChoice(
         message=ChatMessageAssistant(
inspect_ai/model/_providers/azureai.py

@@ -37,6 +37,7 @@ from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_choice import ToolFunction
 
+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -60,7 +61,6 @@ from .util import (
 )
 from .util.chatapi import ChatAPIHandler
 from .util.llama31 import Llama31Handler
-from .util.util import parse_tool_call
 
 AZUREAI_API_KEY = "AZUREAI_API_KEY"
 AZUREAI_ENDPOINT_KEY = "AZUREAI_ENDPOINT_KEY"
@@ -130,7 +130,7 @@ class AzureAIAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # emulate tools (auto for llama, opt-in for others)
         if self.emulate_tools is None and self.is_llama():
             handler: ChatAPIHandler | None = Llama31Handler()
@@ -162,6 +162,19 @@ class AzureAIAPI(ModelAPI):
             model_extras=self.model_args,
         )
 
+        def model_call(response: ChatCompletions | None = None) -> ModelCall:
+            return ModelCall.create(
+                request=request
+                | dict(
+                    messages=[message.as_dict() for message in request["messages"]],
+                    tools=[tool.as_dict() for tool in request["tools"]]
+                    if request.get("tools", None) is not None
+                    else None,
+                ),
+                response=response.as_dict() if response else {},
+                filter=image_url_filter,
+            )
+
         # make call
         try:
             response: ChatCompletions = await client.complete(**request)
@@ -173,19 +186,10 @@ class AzureAIAPI(ModelAPI):
                     output_tokens=response.usage.completion_tokens,
                     total_tokens=response.usage.total_tokens,
                 ),
-            ), ModelCall.create(
-                request=request
-                | dict(
-                    messages=[message.as_dict() for message in request["messages"]],
-                    tools=[tool.as_dict() for tool in request["tools"]]
-                    if request.get("tools", None) is not None
-                    else None,
-                ),
-                response=response.as_dict(),
-                filter=image_url_filter,
-            )
+            ), model_call(response)
+
         except AzureError as ex:
-            return self.handle_azure_error(ex)
+            return self.handle_azure_error(ex), model_call()
         finally:
             await client.close()
 
@@ -251,7 +255,7 @@ class AzureAIAPI(ModelAPI):
     def is_mistral(self) -> bool:
         return "mistral" in self.model_name.lower()
 
-    def handle_azure_error(self, ex: AzureError) -> ModelOutput:
+    def handle_azure_error(self, ex: AzureError) -> ModelOutput | Exception:
         if isinstance(ex, HttpResponseError):
             response = str(ex.message)
             if "maximum context length" in response.lower():
@@ -260,12 +264,8 @@ class AzureAIAPI(ModelAPI):
                     content=response,
                     stop_reason="model_length",
                 )
-            elif ex.status_code == 400 and ex.error:
-                return ModelOutput.from_content(
-                    model=self.model_name,
-                    content=f"Your request triggered an error: {ex.error}",
-                    stop_reason="content_filter",
-                )
+            elif ex.status_code == 400:
+                return ex
 
         raise ex
 
inspect_ai/model/_providers/bedrock.py

@@ -27,11 +27,7 @@ from .._chat_message import (
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-)
+from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .util import (
     model_base_url,
 )
@@ -307,7 +303,7 @@ class BedrockAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         from botocore.config import Config
         from botocore.exceptions import ClientError
 
@@ -339,25 +335,33 @@ class BedrockAPI(ModelAPI):
         # Resolve the input messages into converse messages
         system, messages = await converse_messages(input)
 
-        try:
-            # Make the request
-            request = ConverseClientConverseRequest(
-                modelId=self.model_name,
-                messages=messages,
-                system=system,
-                inferenceConfig=ConverseInferenceConfig(
-                    maxTokens=config.max_tokens,
-                    temperature=config.temperature,
-                    topP=config.top_p,
-                    stopSequences=config.stop_seqs,
+        # Make the request
+        request = ConverseClientConverseRequest(
+            modelId=self.model_name,
+            messages=messages,
+            system=system,
+            inferenceConfig=ConverseInferenceConfig(
+                maxTokens=config.max_tokens,
+                temperature=config.temperature,
+                topP=config.top_p,
+                stopSequences=config.stop_seqs,
+            ),
+            additionalModelRequestFields={
+                "top_k": config.top_k,
+                **config.model_config,
+            },
+            toolConfig=tool_config,
+        )
+
+        def model_call(response: dict[str, Any] | None = None) -> ModelCall:
+            return ModelCall.create(
+                request=replace_bytes_with_placeholder(
+                    request.model_dump(exclude_none=True)
                 ),
-                additionalModelRequestFields={
-                    "top_k": config.top_k,
-                    **config.model_config,
-                },
-                toolConfig=tool_config,
+                response=response,
             )
 
+        try:
             # Process the response
             response = await client.converse(
                 **request.model_dump(exclude_none=True)
@@ -366,32 +370,24 @@ class BedrockAPI(ModelAPI):
             )
 
         except ClientError as ex:
             # Look for an explicit validation exception
-            if (
-                ex.response["Error"]["Code"] == "ValidationException"
-                and "Too many input tokens" in ex.response["Error"]["Message"]
-            ):
+            if ex.response["Error"]["Code"] == "ValidationException":
                 response = ex.response["Error"]["Message"]
-                return ModelOutput.from_content(
-                    model=self.model_name,
-                    content=response,
-                    stop_reason="model_length",
-                )
+                if "Too many input tokens" in response:
+                    return ModelOutput.from_content(
+                        model=self.model_name,
+                        content=response,
+                        stop_reason="model_length",
+                    )
+                else:
+                    return ex, model_call(None)
             else:
                 raise ex
 
         # create a model output from the response
         output = model_output_from_response(self.model_name, converse_response, tools)
 
-        # record call
-        call = ModelCall.create(
-            request=replace_bytes_with_placeholder(
-                request.model_dump(exclude_none=True)
-            ),
-            response=response,
-        )
-
         # return
-        return output, call
+        return output, model_call(response)
 
 
 async def converse_messages(
@@ -550,6 +546,7 @@ async def converse_chat_message(
                 "Tool call is missing a tool call id, which is required for Converse API"
             )
         if message.function is None:
+            print(message)
             raise ValueError(
                 "Tool call is missing a function, which is required for Converse API"
             )