inspect-ai 0.3.75__py3-none-any.whl → 0.3.77__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. inspect_ai/_cli/eval.py +16 -0
  2. inspect_ai/_display/core/results.py +6 -1
  3. inspect_ai/_eval/eval.py +8 -1
  4. inspect_ai/_eval/evalset.py +6 -2
  5. inspect_ai/_eval/registry.py +3 -5
  6. inspect_ai/_eval/run.py +7 -2
  7. inspect_ai/_eval/task/run.py +4 -0
  8. inspect_ai/_util/content.py +3 -0
  9. inspect_ai/_util/logger.py +3 -0
  10. inspect_ai/_view/www/dist/assets/index.css +28 -16
  11. inspect_ai/_view/www/dist/assets/index.js +4811 -4609
  12. inspect_ai/_view/www/log-schema.json +79 -9
  13. inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +22 -4
  14. inspect_ai/_view/www/src/samples/chat/tools/ToolInput.tsx +1 -1
  15. inspect_ai/_view/www/src/samples/descriptor/score/CategoricalScoreDescriptor.tsx +1 -1
  16. inspect_ai/_view/www/src/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -2
  17. inspect_ai/_view/www/src/samples/sample-tools/SortFilter.tsx +1 -1
  18. inspect_ai/_view/www/src/samples/transcript/ModelEventView.module.css +2 -2
  19. inspect_ai/_view/www/src/types/log.d.ts +11 -5
  20. inspect_ai/log/_recorders/json.py +8 -0
  21. inspect_ai/log/_transcript.py +13 -4
  22. inspect_ai/model/_call_tools.py +13 -4
  23. inspect_ai/model/_chat_message.py +3 -0
  24. inspect_ai/model/_model.py +5 -1
  25. inspect_ai/model/_model_output.py +6 -1
  26. inspect_ai/model/_openai.py +78 -10
  27. inspect_ai/model/_openai_responses.py +277 -0
  28. inspect_ai/model/_providers/anthropic.py +134 -75
  29. inspect_ai/model/_providers/azureai.py +2 -2
  30. inspect_ai/model/_providers/mistral.py +29 -13
  31. inspect_ai/model/_providers/openai.py +64 -57
  32. inspect_ai/model/_providers/openai_responses.py +177 -0
  33. inspect_ai/model/_providers/openrouter.py +52 -2
  34. inspect_ai/model/_providers/providers.py +1 -1
  35. inspect_ai/model/_providers/vertex.py +5 -2
  36. inspect_ai/tool/__init__.py +6 -0
  37. inspect_ai/tool/_tool.py +23 -3
  38. inspect_ai/tool/_tool_call.py +5 -2
  39. inspect_ai/tool/_tool_support_helpers.py +200 -0
  40. inspect_ai/tool/_tools/_bash_session.py +119 -0
  41. inspect_ai/tool/_tools/_computer/_computer.py +1 -1
  42. inspect_ai/tool/_tools/_text_editor.py +121 -0
  43. inspect_ai/tool/_tools/_think.py +48 -0
  44. inspect_ai/tool/_tools/_web_browser/_back_compat.py +150 -0
  45. inspect_ai/tool/_tools/_web_browser/_web_browser.py +75 -130
  46. inspect_ai/tool/_tools/_web_search.py +1 -1
  47. inspect_ai/util/_json.py +28 -0
  48. inspect_ai/util/_sandbox/context.py +16 -7
  49. inspect_ai/util/_sandbox/docker/config.py +1 -1
  50. inspect_ai/util/_sandbox/docker/internal.py +3 -3
  51. {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/METADATA +5 -2
  52. {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/RECORD +56 -80
  53. {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/WHEEL +1 -1
  54. inspect_ai/model/_image.py +0 -15
  55. inspect_ai/tool/_tools/_web_browser/_resources/.pylintrc +0 -8
  56. inspect_ai/tool/_tools/_web_browser/_resources/.vscode/launch.json +0 -24
  57. inspect_ai/tool/_tools/_web_browser/_resources/.vscode/settings.json +0 -25
  58. inspect_ai/tool/_tools/_web_browser/_resources/Dockerfile +0 -22
  59. inspect_ai/tool/_tools/_web_browser/_resources/README.md +0 -63
  60. inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree.py +0 -71
  61. inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree_node.py +0 -323
  62. inspect_ai/tool/_tools/_web_browser/_resources/cdp/__init__.py +0 -5
  63. inspect_ai/tool/_tools/_web_browser/_resources/cdp/a11y.py +0 -279
  64. inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom.py +0 -9
  65. inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom_snapshot.py +0 -293
  66. inspect_ai/tool/_tools/_web_browser/_resources/cdp/page.py +0 -94
  67. inspect_ai/tool/_tools/_web_browser/_resources/constants.py +0 -2
  68. inspect_ai/tool/_tools/_web_browser/_resources/images/usage_diagram.svg +0 -2
  69. inspect_ai/tool/_tools/_web_browser/_resources/mock_environment.py +0 -45
  70. inspect_ai/tool/_tools/_web_browser/_resources/playwright_browser.py +0 -50
  71. inspect_ai/tool/_tools/_web_browser/_resources/playwright_crawler.py +0 -48
  72. inspect_ai/tool/_tools/_web_browser/_resources/playwright_page_crawler.py +0 -280
  73. inspect_ai/tool/_tools/_web_browser/_resources/pyproject.toml +0 -65
  74. inspect_ai/tool/_tools/_web_browser/_resources/rectangle.py +0 -64
  75. inspect_ai/tool/_tools/_web_browser/_resources/rpc_client_helpers.py +0 -146
  76. inspect_ai/tool/_tools/_web_browser/_resources/scale_factor.py +0 -64
  77. inspect_ai/tool/_tools/_web_browser/_resources/test_accessibility_tree_node.py +0 -180
  78. inspect_ai/tool/_tools/_web_browser/_resources/test_playwright_crawler.py +0 -99
  79. inspect_ai/tool/_tools/_web_browser/_resources/test_rectangle.py +0 -15
  80. inspect_ai/tool/_tools/_web_browser/_resources/test_web_client.py +0 -44
  81. inspect_ai/tool/_tools/_web_browser/_resources/web_browser_rpc_types.py +0 -39
  82. inspect_ai/tool/_tools/_web_browser/_resources/web_client.py +0 -214
  83. inspect_ai/tool/_tools/_web_browser/_resources/web_client_new_session.py +0 -35
  84. inspect_ai/tool/_tools/_web_browser/_resources/web_server.py +0 -192
  85. {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/entry_points.txt +0 -0
  86. {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info/licenses}/LICENSE +0 -0
  87. {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/top_level.txt +0 -0
@@ -1,23 +1,12 @@
1
1
  import functools
2
2
  import os
3
3
  import re
4
- import sys
5
4
  from copy import copy
6
5
  from logging import getLogger
7
- from typing import Any, Literal, Optional, Tuple, TypedDict, cast
6
+ from typing import Any, Literal, NamedTuple, Optional, Tuple, cast
8
7
 
9
8
  import httpcore
10
9
  import httpx
11
-
12
- from inspect_ai._util.http import is_retryable_http_status
13
-
14
- from .util.hooks import HttpxHooks
15
-
16
- if sys.version_info >= (3, 11):
17
- from typing import NotRequired
18
- else:
19
- from typing_extensions import NotRequired
20
-
21
10
  from anthropic import (
22
11
  APIConnectionError,
23
12
  APIStatusError,
@@ -39,19 +28,19 @@ from anthropic.types import (
39
28
  TextBlockParam,
40
29
  ThinkingBlock,
41
30
  ThinkingBlockParam,
31
+ ToolBash20250124Param,
42
32
  ToolParam,
43
33
  ToolResultBlockParam,
34
+ ToolTextEditor20250124Param,
44
35
  ToolUseBlock,
45
36
  ToolUseBlockParam,
46
37
  message_create_params,
47
38
  )
39
+ from anthropic.types.beta import BetaToolComputerUse20250124Param
48
40
  from pydantic import JsonValue
49
41
  from typing_extensions import override
50
42
 
51
- from inspect_ai._util.constants import (
52
- BASE_64_DATA_REMOVED,
53
- NO_CONTENT,
54
- )
43
+ from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
55
44
  from inspect_ai._util.content import (
56
45
  Content,
57
46
  ContentImage,
@@ -59,6 +48,7 @@ from inspect_ai._util.content import (
59
48
  ContentText,
60
49
  )
61
50
  from inspect_ai._util.error import exception_message
51
+ from inspect_ai._util.http import is_retryable_http_status
62
52
  from inspect_ai._util.images import file_as_data_uri
63
53
  from inspect_ai._util.logger import warn_once
64
54
  from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64
@@ -70,11 +60,14 @@ from .._model import ModelAPI
70
60
  from .._model_call import ModelCall
71
61
  from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage, StopReason
72
62
  from .util import environment_prerequisite_error, model_base_url
63
+ from .util.hooks import HttpxHooks
73
64
 
74
65
  logger = getLogger(__name__)
75
66
 
76
67
  ANTHROPIC_API_KEY = "ANTHROPIC_API_KEY"
77
68
 
69
+ INTERNAL_COMPUTER_TOOL_NAME = "computer"
70
+
78
71
 
79
72
  class AnthropicAPI(ModelAPI):
80
73
  def __init__(
@@ -93,7 +86,7 @@ class AnthropicAPI(ModelAPI):
93
86
  else:
94
87
  self.service = None
95
88
 
96
- # collect gemerate model_args (then delete them so we can pass the rest on)
89
+ # collect generate model_args (then delete them so we can pass the rest on)
97
90
  def collect_model_arg(name: str) -> Any | None:
98
91
  nonlocal model_args
99
92
  value = model_args.get(name, None)
@@ -193,14 +186,11 @@ class AnthropicAPI(ModelAPI):
193
186
 
194
187
  # generate
195
188
  try:
196
- (
197
- system_param,
198
- tools_param,
199
- messages,
200
- computer_use,
201
- ) = await self.resolve_chat_input(input, tools, config)
189
+ system_param, tools_param, messages = await self.resolve_chat_input(
190
+ input, tools, config
191
+ )
202
192
 
203
- # prepare request params (assembed this way so we can log the raw model call)
193
+ # prepare request params (assembled this way so we can log the raw model call)
204
194
  request = dict(messages=messages)
205
195
 
206
196
  # system messages and tools
@@ -218,7 +208,13 @@ class AnthropicAPI(ModelAPI):
218
208
 
219
209
  # extra headers (for time tracker and computer use)
220
210
  extra_headers = headers | {HttpxHooks.REQUEST_ID_HEADER: request_id}
221
- if computer_use:
211
+ if any(
212
+ tool.get("type", None) == "computer_20250124" for tool in tools_param
213
+ ):
214
+ # From: https://docs.anthropic.com/en/docs/agents-and-tools/computer-use#claude-3-7-sonnet-beta-flag
215
+ # Note: The Bash (bash_20250124) and Text Editor (text_editor_20250124)
216
+ # tools are generally available for Claude 3.5 Sonnet (new) as well and
217
+ # can be used without the computer use beta header.
222
218
  betas.append("computer-use-2025-01-24")
223
219
  if len(betas) > 0:
224
220
  extra_headers["anthropic-beta"] = ",".join(betas)
@@ -405,9 +401,7 @@ class AnthropicAPI(ModelAPI):
405
401
  input: list[ChatMessage],
406
402
  tools: list[ToolInfo],
407
403
  config: GenerateConfig,
408
- ) -> Tuple[
409
- list[TextBlockParam] | None, list["ToolParamDef"], list[MessageParam], bool
410
- ]:
404
+ ) -> Tuple[list[TextBlockParam] | None, list["ToolParamDef"], list[MessageParam]]:
411
405
  # extract system message
412
406
  system_messages, messages = split_system_messages(input, config)
413
407
 
@@ -420,7 +414,7 @@ class AnthropicAPI(ModelAPI):
420
414
  )
421
415
 
422
416
  # tools
423
- tools_params, computer_use = self.tool_params_for_tools(tools, config)
417
+ tools_params = [self.tool_param_for_tool_info(tool, config) for tool in tools]
424
418
 
425
419
  # system messages
426
420
  if len(system_messages) > 0:
@@ -470,40 +464,35 @@ class AnthropicAPI(ModelAPI):
470
464
  add_cache_control(cast(dict[str, Any], content[-1]))
471
465
 
472
466
  # return chat input
473
- return system_param, tools_params, message_params, computer_use
474
-
475
- def tool_params_for_tools(
476
- self, tools: list[ToolInfo], config: GenerateConfig
477
- ) -> tuple[list["ToolParamDef"], bool]:
478
- # tool params and computer_use bit to return
479
- tool_params: list["ToolParamDef"] = []
480
- computer_use = False
481
-
482
- # for each tool, check if it has a native computer use implementation and use that
483
- # when available (noting that we need to set the computer use request header)
484
- for tool in tools:
485
- computer_use_tool = (
467
+ return system_param, tools_params, message_params
468
+
469
+ def tool_param_for_tool_info(
470
+ self, tool: ToolInfo, config: GenerateConfig
471
+ ) -> "ToolParamDef":
472
+ # Use a native tool implementation when available. Otherwise, use the
473
+ # standard tool implementation
474
+ return self.maybe_native_tool_param(tool, config) or ToolParam(
475
+ name=tool.name,
476
+ description=tool.description,
477
+ input_schema=tool.parameters.model_dump(exclude_none=True),
478
+ )
479
+
480
+ def maybe_native_tool_param(
481
+ self, tool: ToolInfo, config: GenerateConfig
482
+ ) -> Optional["ToolParamDef"]:
483
+ return (
484
+ (
486
485
  self.computer_use_tool_param(tool)
487
- if config.internal_tools is not False
488
- else None
486
+ or self.text_editor_tool_param(tool)
487
+ or self.bash_tool_param(tool)
489
488
  )
490
- if computer_use_tool:
491
- tool_params.append(computer_use_tool)
492
- computer_use = True
493
- else:
494
- tool_params.append(
495
- ToolParam(
496
- name=tool.name,
497
- description=tool.description,
498
- input_schema=tool.parameters.model_dump(exclude_none=True),
499
- )
500
- )
501
-
502
- return tool_params, computer_use
489
+ if config.internal_tools is not False
490
+ else None
491
+ )
503
492
 
504
493
  def computer_use_tool_param(
505
494
  self, tool: ToolInfo
506
- ) -> Optional["ComputerUseToolParam"]:
495
+ ) -> Optional[BetaToolComputerUse20250124Param]:
507
496
  # check for compatible 'computer' tool
508
497
  if tool.name == "computer" and (
509
498
  sorted(tool.parameters.properties.keys())
@@ -525,7 +514,7 @@ class AnthropicAPI(ModelAPI):
525
514
  "Use of Anthropic's native computer use support is not enabled in Claude 3.5. Please use 3.7 or later to leverage the native support.",
526
515
  )
527
516
  return None
528
- return ComputerUseToolParam(
517
+ return BetaToolComputerUse20250124Param(
529
518
  type="computer_20250124",
530
519
  name="computer",
531
520
  # Note: The dimensions passed here for display_width_px and display_height_px should
@@ -542,23 +531,58 @@ class AnthropicAPI(ModelAPI):
542
531
  else:
543
532
  return None
544
533
 
534
+ def text_editor_tool_param(
535
+ self, tool: ToolInfo
536
+ ) -> Optional[ToolTextEditor20250124Param]:
537
+ # check for compatible 'text editor' tool
538
+ if tool.name == "text_editor" and (
539
+ sorted(tool.parameters.properties.keys())
540
+ == sorted(
541
+ [
542
+ "command",
543
+ "file_text",
544
+ "insert_line",
545
+ "new_str",
546
+ "old_str",
547
+ "path",
548
+ "view_range",
549
+ ]
550
+ )
551
+ ):
552
+ return ToolTextEditor20250124Param(
553
+ type="text_editor_20250124", name="str_replace_editor"
554
+ )
555
+ # not a text_editor tool
556
+ else:
557
+ return None
545
558
 
546
- # native anthropic tool definitions for computer use beta
547
- # https://docs.anthropic.com/en/docs/build-with-claude/computer-use
548
- class ComputerUseToolParam(TypedDict):
549
- type: str
550
- name: str
551
- display_width_px: NotRequired[int]
552
- display_height_px: NotRequired[int]
553
- display_number: NotRequired[int]
559
+ def bash_tool_param(self, tool: ToolInfo) -> Optional[ToolBash20250124Param]:
560
+ # check for compatible 'bash' tool
561
+ if tool.name == "bash_session" and (
562
+ sorted(tool.parameters.properties.keys()) == sorted(["command", "restart"])
563
+ ):
564
+ return ToolBash20250124Param(type="bash_20250124", name="bash")
565
+ # not a bash tool
566
+ else:
567
+ return None
554
568
 
555
569
 
556
- # tools can be either a stock tool param or a special computer use tool param
557
- ToolParamDef = ToolParam | ComputerUseToolParam
570
+ # tools can be either a stock tool param or a special Anthropic native use tool param
571
+ ToolParamDef = (
572
+ ToolParam
573
+ | BetaToolComputerUse20250124Param
574
+ | ToolTextEditor20250124Param
575
+ | ToolBash20250124Param
576
+ )
558
577
 
559
578
 
560
579
  def add_cache_control(
561
- param: TextBlockParam | ToolParam | ComputerUseToolParam | dict[str, Any],
580
+ param: TextBlockParam
581
+ | ToolParam
582
+ | BetaToolComputerUse20250124Param
583
+ | ToolTextEditor20250124Param
584
+ | ToolBash20250124Param
585
+ | dict[str, Any],
562
586
  ) -> None:
563
587
  cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
564
588
 
@@ -567,10 +591,10 @@ def consecutive_user_message_reducer(
567
591
  messages: list[MessageParam],
568
592
  message: MessageParam,
569
593
  ) -> list[MessageParam]:
570
- return consective_message_reducer(messages, message, "user")
594
+ return consecutive_message_reducer(messages, message, "user")
571
595
 
572
596
 
573
- def consective_message_reducer(
597
+ def consecutive_message_reducer(
574
598
  messages: list[MessageParam],
575
599
  message: MessageParam,
576
600
  role: Literal["user", "assistant"],
@@ -583,6 +607,7 @@ def consective_message_reducer(
583
607
 
584
608
 
585
609
  def combine_messages(a: MessageParam, b: MessageParam) -> MessageParam:
610
+ # TODO: Fix this code as it currently drops interesting properties when combining
586
611
  role = a["role"]
587
612
  a_content = a["content"]
588
613
  b_content = b["content"]
@@ -702,7 +727,7 @@ async def message_param(message: ChatMessage) -> MessageParam:
702
727
  ToolUseBlockParam(
703
728
  type="tool_use",
704
729
  id=tool_call.id,
705
- name=tool_call.function,
730
+ name=tool_call.internal_name or tool_call.function,
706
731
  input=tool_call.arguments,
707
732
  )
708
733
  )
@@ -749,11 +774,13 @@ async def model_output_from_message(
749
774
  content.append(ContentText(type="text", text=content_text))
750
775
  elif isinstance(content_block, ToolUseBlock):
751
776
  tool_calls = tool_calls or []
777
+ info = maybe_mapped_call_info(content_block.name, tools)
752
778
  tool_calls.append(
753
779
  ToolCall(
754
- type="function",
780
+ type=info.internal_type,
755
781
  id=content_block.id,
756
- function=content_block.name,
782
+ function=info.inspect_name,
783
+ internal_name=info.internal_name,
757
784
  arguments=content_block.model_dump().get("input", {}),
758
785
  )
759
786
  )
@@ -788,6 +815,7 @@ async def model_output_from_message(
788
815
  + (input_tokens_cache_write or 0)
789
816
  + (input_tokens_cache_read or 0)
790
817
  + message.usage.output_tokens
818
+ + reasoning_tokens
791
819
  )
792
820
  return ModelOutput(
793
821
  model=message.model,
@@ -803,6 +831,37 @@ async def model_output_from_message(
803
831
  )
804
832
 
805
833
 
834
+ class CallInfo(NamedTuple):
835
+ internal_name: str | None
836
+ internal_type: str
837
+ inspect_name: str
838
+
839
+
840
+ def maybe_mapped_call_info(tool_called: str, tools: list[ToolInfo]) -> CallInfo:
841
+ """
842
+ Return call info - potentially transformed by native tool mappings.
843
+
844
+ Anthropic prescribes names for their native tools - `computer`, `bash`, and
845
+ `str_replace_editor`. For a variety of reasons, Inspect's tool names to not
846
+ necessarily conform to internal names. Anthropic also provides specific tool
847
+ types for these built-in tools.
848
+ """
849
+ mappings = (
850
+ (INTERNAL_COMPUTER_TOOL_NAME, "computer_20250124", "computer"),
851
+ ("str_replace_editor", "text_editor_20250124", "text_editor"),
852
+ ("bash", "bash_20250124", "bash_session"),
853
+ )
854
+
855
+ return next(
856
+ (
857
+ CallInfo(entry[0], entry[1], entry[2])
858
+ for entry in mappings
859
+ if entry[0] == tool_called and any(tool.name == entry[2] for tool in tools)
860
+ ),
861
+ CallInfo(None, "function", tool_called),
862
+ )
863
+
864
+
806
865
  def message_stop_reason(message: Message) -> StopReason:
807
866
  match message.stop_reason:
808
867
  case "end_turn" | "stop_sequence":
@@ -51,7 +51,6 @@ from .._chat_message import (
51
51
  ChatMessageUser,
52
52
  )
53
53
  from .._generate_config import GenerateConfig
54
- from .._image import image_url_filter
55
54
  from .._model import ModelAPI
56
55
  from .._model_call import ModelCall
57
56
  from .._model_output import (
@@ -60,6 +59,7 @@ from .._model_output import (
60
59
  ModelUsage,
61
60
  StopReason,
62
61
  )
62
+ from .._openai import openai_media_filter
63
63
  from .util import (
64
64
  environment_prerequisite_error,
65
65
  model_base_url,
@@ -182,7 +182,7 @@ class AzureAIAPI(ModelAPI):
182
182
  else None,
183
183
  ),
184
184
  response=response.as_dict() if response else {},
185
- filter=image_url_filter,
185
+ filter=openai_media_filter,
186
186
  )
187
187
 
188
188
  # make call
@@ -82,6 +82,14 @@ class MistralAPI(ModelAPI):
82
82
  config: GenerateConfig = GenerateConfig(),
83
83
  **model_args: Any,
84
84
  ):
85
+ # extract any service prefix from model name
86
+ parts = model_name.split("/")
87
+ if len(parts) > 1:
88
+ self.service: str | None = parts[0]
89
+ model_name = "/".join(parts[1:])
90
+ else:
91
+ self.service = None
92
+
85
93
  super().__init__(
86
94
  model_name=model_name,
87
95
  base_url=base_url,
@@ -94,31 +102,39 @@ class MistralAPI(ModelAPI):
94
102
  config=config,
95
103
  )
96
104
 
97
- # resolve api_key -- look for mistral then azure
105
+ # resolve api_key
98
106
  if not self.api_key:
99
- self.api_key = os.environ.get(MISTRAL_API_KEY, None)
100
- if self.api_key:
101
- base_url = model_base_url(base_url, "MISTRAL_BASE_URL")
102
- else:
107
+ if self.is_azure():
103
108
  self.api_key = os.environ.get(
104
109
  AZUREAI_MISTRAL_API_KEY, os.environ.get(AZURE_MISTRAL_API_KEY, None)
105
110
  )
106
- if not self.api_key:
107
- raise environment_prerequisite_error(
108
- "Mistral", [MISTRAL_API_KEY, AZUREAI_MISTRAL_API_KEY]
109
- )
110
- base_url = model_base_url(base_url, "AZUREAI_MISTRAL_BASE_URL")
111
- if not base_url:
111
+ else:
112
+ self.api_key = os.environ.get(MISTRAL_API_KEY, None)
113
+
114
+ if not self.api_key:
115
+ raise environment_prerequisite_error(
116
+ "Mistral", [MISTRAL_API_KEY, AZUREAI_MISTRAL_API_KEY]
117
+ )
118
+
119
+ if not self.base_url:
120
+ if self.is_azure():
121
+ self.base_url = model_base_url(base_url, "AZUREAI_MISTRAL_BASE_URL")
122
+ if not self.base_url:
112
123
  raise ValueError(
113
124
  "You must provide a base URL when using Mistral on Azure. Use the AZUREAI_MISTRAL_BASE_URL "
114
125
  + " environment variable or the --model-base-url CLI flag to set the base URL."
115
126
  )
127
+ else:
128
+ self.base_url = model_base_url(base_url, "MISTRAL_BASE_URL")
116
129
 
117
- if base_url:
118
- model_args["server_url"] = base_url
130
+ if self.base_url:
131
+ model_args["server_url"] = self.base_url
119
132
 
120
133
  self.model_args = model_args
121
134
 
135
+ def is_azure(self) -> bool:
136
+ return self.service == "azure"
137
+
122
138
  @override
123
139
  async def close(self) -> None:
124
140
  # client is created and destroyed in generate
@@ -22,28 +22,27 @@ from inspect_ai._util.error import PrerequisiteError
22
22
  from inspect_ai._util.http import is_retryable_http_status
23
23
  from inspect_ai._util.logger import warn_once
24
24
  from inspect_ai.model._openai import chat_choices_from_openai
25
+ from inspect_ai.model._providers.openai_responses import generate_responses
25
26
  from inspect_ai.model._providers.util.hooks import HttpxHooks
26
27
  from inspect_ai.tool import ToolChoice, ToolInfo
27
28
 
28
29
  from .._chat_message import ChatMessage
29
30
  from .._generate_config import GenerateConfig
30
- from .._image import image_url_filter
31
31
  from .._model import ModelAPI
32
32
  from .._model_call import ModelCall
33
- from .._model_output import (
34
- ChatCompletionChoice,
35
- ModelOutput,
36
- ModelUsage,
37
- StopReason,
38
- )
33
+ from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
39
34
  from .._openai import (
35
+ OpenAIResponseError,
40
36
  is_gpt,
41
37
  is_o1_mini,
42
38
  is_o1_preview,
39
+ is_o1_pro,
43
40
  is_o_series,
44
41
  openai_chat_messages,
45
42
  openai_chat_tool_choice,
46
43
  openai_chat_tools,
44
+ openai_handle_bad_request,
45
+ openai_media_filter,
47
46
  )
48
47
  from .openai_o1 import generate_o1
49
48
  from .util import (
@@ -65,8 +64,22 @@ class OpenAIAPI(ModelAPI):
65
64
  base_url: str | None = None,
66
65
  api_key: str | None = None,
67
66
  config: GenerateConfig = GenerateConfig(),
67
+ responses_api: bool | None = None,
68
68
  **model_args: Any,
69
69
  ) -> None:
70
+ # extract azure service prefix from model name (other providers
71
+ # that subclass from us like together expect to have the qualifier
72
+ # in the model name e.g. google/gemma-2b-it)
73
+ parts = model_name.split("/")
74
+ if parts[0] == "azure" and len(parts) > 1:
75
+ self.service: str | None = parts[0]
76
+ model_name = "/".join(parts[1:])
77
+ else:
78
+ self.service = None
79
+
80
+ # note whether we are forcing the responses_api
81
+ self.responses_api = True if responses_api else False
82
+
70
83
  # call super
71
84
  super().__init__(
72
85
  model_name=model_name,
@@ -76,32 +89,23 @@ class OpenAIAPI(ModelAPI):
76
89
  config=config,
77
90
  )
78
91
 
79
- # extract any service prefix from model name
80
- parts = model_name.split("/")
81
- if len(parts) > 1:
82
- self.service: str | None = parts[0]
83
- model_name = "/".join(parts[1:])
84
- else:
85
- self.service = None
86
-
87
92
  # resolve api_key
88
93
  if not self.api_key:
89
- self.api_key = os.environ.get(
90
- AZUREAI_OPENAI_API_KEY, os.environ.get(AZURE_OPENAI_API_KEY, None)
91
- )
92
- # backward compatibility for when env vars determined service
93
- if self.api_key and (os.environ.get(OPENAI_API_KEY, None) is None):
94
- self.service = "azure"
94
+ if self.service == "azure":
95
+ self.api_key = os.environ.get(
96
+ AZUREAI_OPENAI_API_KEY, os.environ.get(AZURE_OPENAI_API_KEY, None)
97
+ )
95
98
  else:
96
99
  self.api_key = os.environ.get(OPENAI_API_KEY, None)
97
- if not self.api_key:
98
- raise environment_prerequisite_error(
99
- "OpenAI",
100
- [
101
- OPENAI_API_KEY,
102
- AZUREAI_OPENAI_API_KEY,
103
- ],
104
- )
100
+
101
+ if not self.api_key:
102
+ raise environment_prerequisite_error(
103
+ "OpenAI",
104
+ [
105
+ OPENAI_API_KEY,
106
+ AZUREAI_OPENAI_API_KEY,
107
+ ],
108
+ )
105
109
 
106
110
  # create async http client
107
111
  http_client = OpenAIAsyncHttpxClient()
@@ -123,10 +127,16 @@ class OpenAIAPI(ModelAPI):
123
127
  + "environment variable or the --model-base-url CLI flag to set the base URL."
124
128
  )
125
129
 
130
+ # resolve version
131
+ api_version = os.environ.get(
132
+ "AZUREAI_OPENAI_API_VERSION",
133
+ os.environ.get("OPENAI_API_VERSION", "2025-02-01-preview"),
134
+ )
135
+
126
136
  self.client: AsyncAzureOpenAI | AsyncOpenAI = AsyncAzureOpenAI(
127
137
  api_key=self.api_key,
138
+ api_version=api_version,
128
139
  azure_endpoint=base_url,
129
- azure_deployment=model_name,
130
140
  http_client=http_client,
131
141
  **model_args,
132
142
  )
@@ -147,6 +157,9 @@ class OpenAIAPI(ModelAPI):
147
157
  def is_o_series(self) -> bool:
148
158
  return is_o_series(self.model_name)
149
159
 
160
+ def is_o1_pro(self) -> bool:
161
+ return is_o1_pro(self.model_name)
162
+
150
163
  def is_o1_mini(self) -> bool:
151
164
  return is_o1_mini(self.model_name)
152
165
 
@@ -175,6 +188,16 @@ class OpenAIAPI(ModelAPI):
175
188
  tools=tools,
176
189
  **self.completion_params(config, False),
177
190
  )
191
+ elif self.is_o1_pro() or self.responses_api:
192
+ return await generate_responses(
193
+ client=self.client,
194
+ http_hooks=self._http_hooks,
195
+ model_name=self.model_name,
196
+ input=input,
197
+ tools=tools,
198
+ tool_choice=tool_choice,
199
+ config=config,
200
+ )
178
201
 
179
202
  # allocate request_id (so we can see it from ModelCall)
180
203
  request_id = self._http_hooks.start_request()
@@ -187,7 +210,7 @@ class OpenAIAPI(ModelAPI):
187
210
  return ModelCall.create(
188
211
  request=request,
189
212
  response=response,
190
- filter=image_url_filter,
213
+ filter=openai_media_filter,
191
214
  time=self._http_hooks.end_request(request_id),
192
215
  )
193
216
 
@@ -219,6 +242,7 @@ class OpenAIAPI(ModelAPI):
219
242
 
220
243
  # save response for model_call
221
244
  response = completion.model_dump()
245
+ self.on_response(response)
222
246
 
223
247
  # parse out choices
224
248
  choices = self._chat_choices_from_response(completion, tools)
@@ -250,6 +274,12 @@ class OpenAIAPI(ModelAPI):
250
274
  except BadRequestError as e:
251
275
  return self.handle_bad_request(e), model_call()
252
276
 
277
+ def on_response(self, response: dict[str, Any]) -> None:
278
+ pass
279
+
280
+ def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
281
+ return openai_handle_bad_request(self.model_name, ex)
282
+
253
283
  def _chat_choices_from_response(
254
284
  self, response: ChatCompletion, tools: list[ToolInfo]
255
285
  ) -> list[ChatCompletionChoice]:
@@ -268,6 +298,8 @@ class OpenAIAPI(ModelAPI):
268
298
  return True
269
299
  elif isinstance(ex, APIStatusError):
270
300
  return is_retryable_http_status(ex.status_code)
301
+ elif isinstance(ex, OpenAIResponseError):
302
+ return ex.code in ["rate_limit_exceeded", "server_error"]
271
303
  elif isinstance(ex, APITimeoutError):
272
304
  return True
273
305
  else:
@@ -322,6 +354,7 @@ class OpenAIAPI(ModelAPI):
322
354
  config.reasoning_effort is not None
323
355
  and not self.is_gpt()
324
356
  and not self.is_o1_mini()
357
+ and not self.is_o1_preview()
325
358
  ):
326
359
  params["reasoning_effort"] = config.reasoning_effort
327
360
  if config.response_schema is not None:
@@ -339,32 +372,6 @@ class OpenAIAPI(ModelAPI):
339
372
 
340
373
  return params
341
374
 
342
- # convert some well known bad request errors into ModelOutput
343
- def handle_bad_request(self, e: BadRequestError) -> ModelOutput | Exception:
344
- # extract message
345
- if isinstance(e.body, dict) and "message" in e.body.keys():
346
- content = str(e.body.get("message"))
347
- else:
348
- content = e.message
349
-
350
- # narrow stop_reason
351
- stop_reason: StopReason | None = None
352
- if e.code == "context_length_exceeded":
353
- stop_reason = "model_length"
354
- elif (
355
- e.code == "invalid_prompt" # seems to happen for o1/o3
356
- or e.code == "content_policy_violation" # seems to happen for vision
357
- or e.code == "content_filter" # seems to happen on azure
358
- ):
359
- stop_reason = "content_filter"
360
-
361
- if stop_reason:
362
- return ModelOutput.from_content(
363
- model=self.model_name, content=content, stop_reason=stop_reason
364
- )
365
- else:
366
- return e
367
-
368
375
 
369
376
  class OpenAIAsyncHttpxClient(httpx.AsyncClient):
370
377
  """Custom async client that deals better with long running Async requests.