inspect-ai 0.3.58__py3-none-any.whl → 0.3.59__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (127)
  1. inspect_ai/_cli/common.py +3 -1
  2. inspect_ai/_cli/eval.py +15 -2
  3. inspect_ai/_display/core/active.py +4 -1
  4. inspect_ai/_display/core/config.py +3 -3
  5. inspect_ai/_display/core/panel.py +7 -3
  6. inspect_ai/_display/plain/__init__.py +0 -0
  7. inspect_ai/_display/plain/display.py +203 -0
  8. inspect_ai/_display/rich/display.py +0 -5
  9. inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
  10. inspect_ai/_display/textual/widgets/samples.py +78 -11
  11. inspect_ai/_display/textual/widgets/sandbox.py +37 -0
  12. inspect_ai/_eval/score.py +1 -0
  13. inspect_ai/_eval/task/results.py +50 -22
  14. inspect_ai/_eval/task/run.py +41 -7
  15. inspect_ai/_eval/task/sandbox.py +10 -5
  16. inspect_ai/_util/constants.py +1 -0
  17. inspect_ai/_util/port_names.py +61 -0
  18. inspect_ai/_util/text.py +23 -0
  19. inspect_ai/_view/www/App.css +31 -1
  20. inspect_ai/_view/www/dist/assets/index.css +31 -1
  21. inspect_ai/_view/www/dist/assets/index.js +25344 -1849
  22. inspect_ai/_view/www/log-schema.json +32 -2
  23. inspect_ai/_view/www/package.json +2 -0
  24. inspect_ai/_view/www/src/App.mjs +8 -10
  25. inspect_ai/_view/www/src/Types.mjs +0 -1
  26. inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
  27. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
  28. inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
  29. inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
  30. inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
  31. inspect_ai/_view/www/src/index.js +75 -2
  32. inspect_ai/_view/www/src/navbar/Navbar.mjs +3 -0
  33. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +18 -9
  34. inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
  35. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
  36. inspect_ai/_view/www/src/samples/SampleList.mjs +18 -48
  37. inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
  38. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +24 -12
  39. inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -1
  40. inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
  41. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
  42. inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
  43. inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
  44. inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
  45. inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
  46. inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
  47. inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
  48. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
  49. inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
  50. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
  51. inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
  52. inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
  53. inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
  54. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
  55. inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
  56. inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
  57. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
  58. inspect_ai/_view/www/src/types/log.d.ts +13 -2
  59. inspect_ai/_view/www/src/utils/Format.mjs +10 -3
  60. inspect_ai/_view/www/src/utils/Json.mjs +12 -6
  61. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +10 -4
  62. inspect_ai/_view/www/vite.config.js +7 -0
  63. inspect_ai/_view/www/yarn.lock +116 -0
  64. inspect_ai/approval/_human/__init__.py +0 -0
  65. inspect_ai/approval/_policy.py +12 -6
  66. inspect_ai/log/_log.py +1 -1
  67. inspect_ai/log/_samples.py +16 -0
  68. inspect_ai/log/_transcript.py +4 -1
  69. inspect_ai/model/_call_tools.py +4 -0
  70. inspect_ai/model/_conversation.py +20 -8
  71. inspect_ai/model/_generate_config.py +10 -4
  72. inspect_ai/model/_model.py +117 -18
  73. inspect_ai/model/_model_output.py +7 -2
  74. inspect_ai/model/_providers/anthropic.py +100 -44
  75. inspect_ai/model/_providers/azureai.py +20 -20
  76. inspect_ai/model/_providers/bedrock.py +37 -40
  77. inspect_ai/model/_providers/google.py +46 -54
  78. inspect_ai/model/_providers/mistral.py +11 -11
  79. inspect_ai/model/_providers/openai.py +15 -16
  80. inspect_ai/model/_providers/openai_o1.py +9 -8
  81. inspect_ai/model/_providers/providers.py +1 -1
  82. inspect_ai/model/_providers/together.py +8 -8
  83. inspect_ai/model/_providers/vertex.py +1 -4
  84. inspect_ai/scorer/_reducer/reducer.py +1 -1
  85. inspect_ai/scorer/_scorer.py +2 -2
  86. inspect_ai/solver/__init__.py +2 -5
  87. inspect_ai/solver/_prompt.py +35 -5
  88. inspect_ai/solver/_task_state.py +80 -38
  89. inspect_ai/tool/__init__.py +2 -0
  90. inspect_ai/tool/_tool.py +12 -1
  91. inspect_ai/tool/_tool_call.py +10 -0
  92. inspect_ai/tool/_tool_def.py +16 -5
  93. inspect_ai/tool/_tool_with.py +21 -4
  94. inspect_ai/tool/beta/__init__.py +5 -0
  95. inspect_ai/tool/beta/_computer/__init__.py +3 -0
  96. inspect_ai/tool/beta/_computer/_common.py +133 -0
  97. inspect_ai/tool/beta/_computer/_computer.py +155 -0
  98. inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
  99. inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
  100. inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
  101. inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
  102. inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
  103. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
  104. inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
  105. inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
  106. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
  107. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
  108. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
  109. inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
  110. inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
  111. inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
  112. inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
  113. inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
  114. inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
  115. inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
  116. inspect_ai/util/__init__.py +2 -0
  117. inspect_ai/util/_limit.py +26 -0
  118. inspect_ai/util/_sandbox/docker/docker.py +64 -1
  119. inspect_ai/util/_sandbox/docker/internal.py +3 -1
  120. inspect_ai/util/_sandbox/environment.py +14 -0
  121. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/METADATA +2 -2
  122. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/RECORD +126 -98
  123. inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
  124. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/LICENSE +0 -0
  125. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/WHEEL +0 -0
  126. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/entry_points.txt +0 -0
  127. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/anthropic.py
@@ -1,8 +1,14 @@
 import functools
 import os
+import sys
 from copy import copy
 from logging import getLogger
-from typing import Any, Literal, Tuple, cast
+from typing import Any, Literal, Tuple, TypedDict, cast
+
+if sys.version_info >= (3, 11):
+    from typing import NotRequired
+else:
+    from typing_extensions import NotRequired
 
 from anthropic import (
     APIConnectionError,
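
The version-gated import above targets NotRequired, which marks individual TypedDict keys as optional; it lives in typing from Python 3.11 onward and in typing_extensions on older interpreters. A minimal sketch of the semantics (the Point type here is illustrative, not part of the package):

    import sys
    from typing import TypedDict

    if sys.version_info >= (3, 11):
        from typing import NotRequired
    else:
        from typing_extensions import NotRequired  # backport of the 3.11 API

    class Point(TypedDict):
        x: int                   # required key
        y: int                   # required key
        label: NotRequired[str]  # key may be omitted entirely

    p1: Point = {"x": 1, "y": 2}                 # ok: label omitted
    p2: Point = {"x": 1, "y": 2, "label": "p2"}  # ok: label present
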
@@ -27,7 +33,11 @@ from anthropic.types import (
 from pydantic import JsonValue
 from typing_extensions import override
 
-from inspect_ai._util.constants import BASE_64_DATA_REMOVED, DEFAULT_MAX_RETRIES
+from inspect_ai._util.constants import (
+    BASE_64_DATA_REMOVED,
+    DEFAULT_MAX_RETRIES,
+    NO_CONTENT,
+)
 from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.error import exception_message
 from inspect_ai._util.images import file_as_data_uri
@@ -35,20 +45,11 @@ from inspect_ai._util.logger import warn_once
 from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
-from .._chat_message import (
-    ChatMessage,
-    ChatMessageAssistant,
-    ChatMessageSystem,
-)
+from .._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageSystem
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-    StopReason,
-)
+from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage, StopReason
 from .util import environment_prerequisite_error, model_base_url
 
 logger = getLogger(__name__)
@@ -124,7 +125,7 @@ class AnthropicAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
         response: dict[str, Any] = {}
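
The widened return type threads an Exception through alongside the ModelCall record, so a rejected request can still be logged with its raw request payload. A caller-side sketch of unpacking the two shapes this signature allows (record_model_call is a hypothetical logging hook, not inspect_ai's actual dispatch):

    result = await api.generate(input, tools, tool_choice, config)

    # normalize: a bare ModelOutput, or an (output-or-exception, call-record) pair
    if isinstance(result, tuple):
        output, call = result
        record_model_call(call)  # hypothetical hook; the request is logged either way
    else:
        output, call = result, None

    # a returned Exception is a loggable failure rather than an immediate raise
    if isinstance(output, Exception):
        raise output
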
@@ -142,7 +143,7 @@ class AnthropicAPI(ModelAPI):
142
143
  system_param,
143
144
  tools_param,
144
145
  messages,
145
- cache_prompt,
146
+ computer_use,
146
147
  ) = await resolve_chat_input(self.model_name, input, tools, config)
147
148
 
148
149
  # prepare request params (assembed this way so we can log the raw model call)
@@ -158,13 +159,11 @@ class AnthropicAPI(ModelAPI):
158
159
  # additional options
159
160
  request = request | self.completion_params(config)
160
161
 
161
- # caching header
162
- if cache_prompt:
163
- request["extra_headers"] = {
164
- "anthropic-beta": "prompt-caching-2024-07-31"
165
- }
162
+ # computer use beta
163
+ if computer_use:
164
+ request["extra_headers"] = {"anthropic-beta": "computer-use-2024-10-22"}
166
165
 
167
- # call model
166
+ # make request
168
167
  message = await self.client.messages.create(**request, stream=False)
169
168
 
170
169
  # set response for ModelCall
@@ -177,11 +176,7 @@ class AnthropicAPI(ModelAPI):
177
176
  return output, model_call()
178
177
 
179
178
  except BadRequestError as ex:
180
- error_output = self.handle_bad_request(ex)
181
- if error_output is not None:
182
- return error_output, model_call()
183
- else:
184
- raise ex
179
+ return self.handle_bad_request(ex), model_call()
185
180
 
186
181
  def completion_params(self, config: GenerateConfig) -> dict[str, Any]:
187
182
  params = dict(model=self.model_name, max_tokens=cast(int, config.max_tokens))
@@ -234,7 +229,7 @@ class AnthropicAPI(ModelAPI):
234
229
  return True
235
230
 
236
231
  # convert some common BadRequestError states into 'refusal' model output
237
- def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | None:
232
+ def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
238
233
  error = exception_message(ex).lower()
239
234
  content: str | None = None
240
235
  stop_reason: StopReason | None = None
@@ -256,6 +251,9 @@ class AnthropicAPI(ModelAPI):
256
251
  elif "content filtering" in error:
257
252
  content = "Sorry, but I am unable to help with that request."
258
253
  stop_reason = "content_filter"
254
+ else:
255
+ content = error
256
+ stop_reason = "unknown"
259
257
 
260
258
  if content and stop_reason:
261
259
  return ModelOutput.from_content(
@@ -265,7 +263,21 @@ class AnthropicAPI(ModelAPI):
265
263
  error=error,
266
264
  )
267
265
  else:
268
- return None
266
+ return ex
267
+
268
+
269
+ # native anthropic tool definitions for computer use beta
270
+ # https://docs.anthropic.com/en/docs/build-with-claude/computer-use
271
+ class ComputerUseToolParam(TypedDict):
272
+ type: str
273
+ name: str
274
+ display_width_px: NotRequired[int]
275
+ display_height_px: NotRequired[int]
276
+ display_number: NotRequired[int]
277
+
278
+
279
+ # tools can be either a stock tool param or a special computer use tool param
280
+ ToolParamDef = ToolParam | ComputerUseToolParam
269
281
 
270
282
 
271
283
  async def resolve_chat_input(
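
Because the display fields are NotRequired, ComputerUseToolParam can carry just type and name or the full display geometry. With the defaults used by computer_use_tool_param further down, the tool definition placed in the request body looks roughly like this (a sketch, assuming the values shown in this diff):

    computer_tool: ComputerUseToolParam = {
        "type": "computer_20241022",  # Anthropic computer use beta tool type
        "name": "computer",
        "display_width_px": 1366,
        "display_height_px": 768,
        "display_number": 1,
    }
    # paired with the beta header set in generate():
    #   extra_headers = {"anthropic-beta": "computer-use-2024-10-22"}
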
@@ -273,7 +285,7 @@ async def resolve_chat_input(
     input: list[ChatMessage],
     tools: list[ToolInfo],
     config: GenerateConfig,
-) -> Tuple[list[TextBlockParam] | None, list[ToolParam], list[MessageParam], bool]:
+) -> Tuple[list[TextBlockParam] | None, list[ToolParamDef], list[MessageParam], bool]:
     # extract system message
     system_messages, messages = split_system_messages(input, config)
 
@@ -286,14 +298,7 @@ async def resolve_chat_input(
     )
 
     # tools
-    tools_params = [
-        ToolParam(
-            name=tool.name,
-            description=tool.description,
-            input_schema=tool.parameters.model_dump(exclude_none=True),
-        )
-        for tool in tools
-    ]
+    tools_params, computer_use = tool_params_for_tools(tools, config)
 
     # system messages
     if len(system_messages) > 0:
@@ -343,10 +348,66 @@ async def resolve_chat_input(
            add_cache_control(cast(dict[str, Any], content[-1]))
 
     # return chat input
-    return system_param, tools_params, message_params, cache_prompt
+    return system_param, tools_params, message_params, computer_use
+
+
+def tool_params_for_tools(
+    tools: list[ToolInfo], config: GenerateConfig
+) -> tuple[list[ToolParamDef], bool]:
+    # tool params and computer_use bit to return
+    tool_params: list[ToolParamDef] = []
+    computer_use = False
+
+    # for each tool, check if it has a native computer use implementation and use that
+    # when available (noting that we need to set the computer use request header)
+    for tool in tools:
+        computer_use_tool = (
+            computer_use_tool_param(tool)
+            if config.internal_tools is not False
+            else None
+        )
+        if computer_use_tool:
+            tool_params.append(computer_use_tool)
+            computer_use = True
+        else:
+            tool_params.append(
+                ToolParam(
+                    name=tool.name,
+                    description=tool.description,
+                    input_schema=tool.parameters.model_dump(exclude_none=True),
+                )
+            )
 
+    return tool_params, computer_use
 
-def add_cache_control(param: TextBlockParam | ToolParam | dict[str, Any]) -> None:
+
+def computer_use_tool_param(tool: ToolInfo) -> ComputerUseToolParam | None:
+    # check for compatible 'computer' tool
+    if tool.name == "computer" and (
+        sorted(tool.parameters.properties.keys())
+        == sorted(["action", "coordinate", "text"])
+    ):
+        return ComputerUseToolParam(
+            type="computer_20241022",
+            name="computer",
+            # Note: The dimensions passed here for display_width_px and display_height_px should
+            # match the dimensions of screenshots returned by the tool.
+            # Those dimensions will always be one of the values in MAX_SCALING_TARGETS
+            # in _x11_client.py.
+            # TODO: enhance this code to calculate the dimensions based on the scaled screen
+            # size used by the container.
+            display_width_px=1366,
+            display_height_px=768,
+            display_number=1,
+        )
+    # not a computer_use tool
+    else:
+        return None
+
+
+def add_cache_control(
+    param: TextBlockParam | ToolParam | ComputerUseToolParam | dict[str, Any],
+) -> None:
     cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
 
 
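The dispatch rule in computer_use_tool_param is purely structural: a tool named computer whose parameters are exactly action, coordinate, and text gets the native beta definition, and everything else falls back to a stock ToolParam; setting internal_tools=False in the generate config disables the native mapping entirely. A standalone sketch of the predicate (simplified signature; the real code consumes ToolInfo objects):

    def is_native_computer_tool(name: str, parameters: set[str]) -> bool:
        # mirrors the name/parameter check in computer_use_tool_param
        return name == "computer" and parameters == {"action", "coordinate", "text"}

    assert is_native_computer_tool("computer", {"action", "coordinate", "text"})
    assert not is_native_computer_tool("bash", {"cmd"})  # maps to a stock ToolParam
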
@@ -404,11 +465,6 @@ def message_tool_choice(tool_choice: ToolChoice) -> message_create_params.ToolCh
         return {"type": "auto"}
 
 
-# text we insert when there is no content passed
-# (as this will result in an Anthropic API error)
-NO_CONTENT = "(no content)"
-
-
 async def message_param(message: ChatMessage) -> MessageParam:
     # no system role for anthropic (this is more like an assertion,
     # as these should have already been filtered out)
inspect_ai/model/_providers/azureai.py
@@ -130,7 +130,7 @@ class AzureAIAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # emulate tools (auto for llama, opt-in for others)
         if self.emulate_tools is None and self.is_llama():
             handler: ChatAPIHandler | None = Llama31Handler()
@@ -162,6 +162,19 @@ class AzureAIAPI(ModelAPI):
             model_extras=self.model_args,
         )
 
+        def model_call(response: ChatCompletions | None = None) -> ModelCall:
+            return ModelCall.create(
+                request=request
+                | dict(
+                    messages=[message.as_dict() for message in request["messages"]],
+                    tools=[tool.as_dict() for tool in request["tools"]]
+                    if request.get("tools", None) is not None
+                    else None,
+                ),
+                response=response.as_dict() if response else {},
+                filter=image_url_filter,
+            )
+
         # make call
         try:
             response: ChatCompletions = await client.complete(**request)
@@ -173,19 +186,10 @@ class AzureAIAPI(ModelAPI):
                     output_tokens=response.usage.completion_tokens,
                     total_tokens=response.usage.total_tokens,
                 ),
-            ), ModelCall.create(
-                request=request
-                | dict(
-                    messages=[message.as_dict() for message in request["messages"]],
-                    tools=[tool.as_dict() for tool in request["tools"]]
-                    if request.get("tools", None) is not None
-                    else None,
-                ),
-                response=response.as_dict(),
-                filter=image_url_filter,
-            )
+            ), model_call(response)
+
         except AzureError as ex:
-            return self.handle_azure_error(ex)
+            return self.handle_azure_error(ex), model_call()
         finally:
             await client.close()
 
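This refactor hoists ModelCall.create into a local model_call closure defined ahead of the try block, so the success path and the AzureError path log the identical request payload (with an empty response when the call never completed). The same shape in isolation (a generic sketch, not the azureai code; do_request is a hypothetical transport standing in for client.complete):

    def do_request(request: dict) -> dict:
        # hypothetical transport; stands in for client.complete(**request)
        return {"ok": True}

    def fetch_with_record(request: dict) -> tuple[object, dict]:
        # build the audit record in one place; response is optional
        def call_record(response: dict | None = None) -> dict:
            return {"request": request, "response": response or {}}

        try:
            response = do_request(request)
            return response, call_record(response)
        except RuntimeError as ex:
            return ex, call_record()  # request is still recorded on failure
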
@@ -251,7 +255,7 @@ class AzureAIAPI(ModelAPI):
     def is_mistral(self) -> bool:
         return "mistral" in self.model_name.lower()
 
-    def handle_azure_error(self, ex: AzureError) -> ModelOutput:
+    def handle_azure_error(self, ex: AzureError) -> ModelOutput | Exception:
         if isinstance(ex, HttpResponseError):
             response = str(ex.message)
             if "maximum context length" in response.lower():
@@ -260,12 +264,8 @@ class AzureAIAPI(ModelAPI):
                     content=response,
                     stop_reason="model_length",
                 )
-            elif ex.status_code == 400 and ex.error:
-                return ModelOutput.from_content(
-                    model=self.model_name,
-                    content=f"Your request triggered an error: {ex.error}",
-                    stop_reason="content_filter",
-                )
+            elif ex.status_code == 400:
+                return ex
 
         raise ex
 
inspect_ai/model/_providers/bedrock.py
@@ -27,11 +27,7 @@ from .._chat_message import (
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-)
+from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .util import (
     model_base_url,
 )
@@ -307,7 +303,7 @@ class BedrockAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         from botocore.config import Config
         from botocore.exceptions import ClientError
 
@@ -339,25 +335,33 @@ class BedrockAPI(ModelAPI):
         # Resolve the input messages into converse messages
         system, messages = await converse_messages(input)
 
-        try:
-            # Make the request
-            request = ConverseClientConverseRequest(
-                modelId=self.model_name,
-                messages=messages,
-                system=system,
-                inferenceConfig=ConverseInferenceConfig(
-                    maxTokens=config.max_tokens,
-                    temperature=config.temperature,
-                    topP=config.top_p,
-                    stopSequences=config.stop_seqs,
+        # Make the request
+        request = ConverseClientConverseRequest(
+            modelId=self.model_name,
+            messages=messages,
+            system=system,
+            inferenceConfig=ConverseInferenceConfig(
+                maxTokens=config.max_tokens,
+                temperature=config.temperature,
+                topP=config.top_p,
+                stopSequences=config.stop_seqs,
+            ),
+            additionalModelRequestFields={
+                "top_k": config.top_k,
+                **config.model_config,
+            },
+            toolConfig=tool_config,
+        )
+
+        def model_call(response: dict[str, Any] | None = None) -> ModelCall:
+            return ModelCall.create(
+                request=replace_bytes_with_placeholder(
+                    request.model_dump(exclude_none=True)
                 ),
-                additionalModelRequestFields={
-                    "top_k": config.top_k,
-                    **config.model_config,
-                },
-                toolConfig=tool_config,
+                response=response,
             )
 
+        try:
             # Process the reponse
             response = await client.converse(
                 **request.model_dump(exclude_none=True)
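
replace_bytes_with_placeholder (defined elsewhere in bedrock.py) keeps raw image bytes out of the logged request. Its exact behavior isn't shown in this diff; a plausible sketch of such a helper, recursing through the request dict:

    from typing import Any

    def replace_bytes_sketch(value: Any, placeholder: str = "<bytes>") -> Any:
        # assumption: swap any bytes payload for a short marker, recursively
        if isinstance(value, bytes):
            return placeholder
        if isinstance(value, dict):
            return {k: replace_bytes_sketch(v, placeholder) for k, v in value.items()}
        if isinstance(value, list):
            return [replace_bytes_sketch(v, placeholder) for v in value]
        return value

    # {"image": {"bytes": b"\x89PNG"}} -> {"image": {"bytes": "<bytes>"}}
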
@@ -366,32 +370,24 @@ class BedrockAPI(ModelAPI):
 
         except ClientError as ex:
             # Look for an explicit validation exception
-            if (
-                ex.response["Error"]["Code"] == "ValidationException"
-                and "Too many input tokens" in ex.response["Error"]["Message"]
-            ):
+            if ex.response["Error"]["Code"] == "ValidationException":
                 response = ex.response["Error"]["Message"]
-                return ModelOutput.from_content(
-                    model=self.model_name,
-                    content=response,
-                    stop_reason="model_length",
-                )
+                if "Too many input tokens" in response:
+                    return ModelOutput.from_content(
+                        model=self.model_name,
+                        content=response,
+                        stop_reason="model_length",
+                    )
+                else:
+                    return ex, model_call(None)
             else:
                 raise ex
 
         # create a model output from the response
         output = model_output_from_response(self.model_name, converse_response, tools)
 
-        # record call
-        call = ModelCall.create(
-            request=replace_bytes_with_placeholder(
-                request.model_dump(exclude_none=True)
-            ),
-            response=response,
-        )
-
         # return
-        return output, call
+        return output, model_call(response)
 
 
 async def converse_messages(
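
botocore's ClientError exposes the service error payload on ex.response["Error"], which is what the branch above inspects. A minimal sketch of the same classification outside the provider (the function name and return labels are illustrative):

    from botocore.exceptions import ClientError

    def classify_converse_error(ex: ClientError) -> str:
        error = ex.response["Error"]  # {"Code": ..., "Message": ...}
        if error["Code"] == "ValidationException":
            if "Too many input tokens" in error["Message"]:
                return "model_length"     # report as a context-window overflow
            return "loggable_failure"     # returned with the ModelCall record
        return "reraise"                  # anything else propagates
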
@@ -550,6 +546,7 @@ async def converse_chat_message(
                 "Tool call is missing a tool call id, which is required for Converse API"
             )
         if message.function is None:
+            print(message)
             raise ValueError(
                 "Tool call is missing a function, which is required for Converse API"
             )