inspect-ai 0.3.75__py3-none-any.whl → 0.3.77__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +16 -0
- inspect_ai/_display/core/results.py +6 -1
- inspect_ai/_eval/eval.py +8 -1
- inspect_ai/_eval/evalset.py +6 -2
- inspect_ai/_eval/registry.py +3 -5
- inspect_ai/_eval/run.py +7 -2
- inspect_ai/_eval/task/run.py +4 -0
- inspect_ai/_util/content.py +3 -0
- inspect_ai/_util/logger.py +3 -0
- inspect_ai/_view/www/dist/assets/index.css +28 -16
- inspect_ai/_view/www/dist/assets/index.js +4811 -4609
- inspect_ai/_view/www/log-schema.json +79 -9
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +22 -4
- inspect_ai/_view/www/src/samples/chat/tools/ToolInput.tsx +1 -1
- inspect_ai/_view/www/src/samples/descriptor/score/CategoricalScoreDescriptor.tsx +1 -1
- inspect_ai/_view/www/src/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -2
- inspect_ai/_view/www/src/samples/sample-tools/SortFilter.tsx +1 -1
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.module.css +2 -2
- inspect_ai/_view/www/src/types/log.d.ts +11 -5
- inspect_ai/log/_recorders/json.py +8 -0
- inspect_ai/log/_transcript.py +13 -4
- inspect_ai/model/_call_tools.py +13 -4
- inspect_ai/model/_chat_message.py +3 -0
- inspect_ai/model/_model.py +5 -1
- inspect_ai/model/_model_output.py +6 -1
- inspect_ai/model/_openai.py +78 -10
- inspect_ai/model/_openai_responses.py +277 -0
- inspect_ai/model/_providers/anthropic.py +134 -75
- inspect_ai/model/_providers/azureai.py +2 -2
- inspect_ai/model/_providers/mistral.py +29 -13
- inspect_ai/model/_providers/openai.py +64 -57
- inspect_ai/model/_providers/openai_responses.py +177 -0
- inspect_ai/model/_providers/openrouter.py +52 -2
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/vertex.py +5 -2
- inspect_ai/tool/__init__.py +6 -0
- inspect_ai/tool/_tool.py +23 -3
- inspect_ai/tool/_tool_call.py +5 -2
- inspect_ai/tool/_tool_support_helpers.py +200 -0
- inspect_ai/tool/_tools/_bash_session.py +119 -0
- inspect_ai/tool/_tools/_computer/_computer.py +1 -1
- inspect_ai/tool/_tools/_text_editor.py +121 -0
- inspect_ai/tool/_tools/_think.py +48 -0
- inspect_ai/tool/_tools/_web_browser/_back_compat.py +150 -0
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +75 -130
- inspect_ai/tool/_tools/_web_search.py +1 -1
- inspect_ai/util/_json.py +28 -0
- inspect_ai/util/_sandbox/context.py +16 -7
- inspect_ai/util/_sandbox/docker/config.py +1 -1
- inspect_ai/util/_sandbox/docker/internal.py +3 -3
- {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/METADATA +5 -2
- {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/RECORD +56 -80
- {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/WHEEL +1 -1
- inspect_ai/model/_image.py +0 -15
- inspect_ai/tool/_tools/_web_browser/_resources/.pylintrc +0 -8
- inspect_ai/tool/_tools/_web_browser/_resources/.vscode/launch.json +0 -24
- inspect_ai/tool/_tools/_web_browser/_resources/.vscode/settings.json +0 -25
- inspect_ai/tool/_tools/_web_browser/_resources/Dockerfile +0 -22
- inspect_ai/tool/_tools/_web_browser/_resources/README.md +0 -63
- inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree.py +0 -71
- inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree_node.py +0 -323
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/__init__.py +0 -5
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/a11y.py +0 -279
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom.py +0 -9
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom_snapshot.py +0 -293
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/page.py +0 -94
- inspect_ai/tool/_tools/_web_browser/_resources/constants.py +0 -2
- inspect_ai/tool/_tools/_web_browser/_resources/images/usage_diagram.svg +0 -2
- inspect_ai/tool/_tools/_web_browser/_resources/mock_environment.py +0 -45
- inspect_ai/tool/_tools/_web_browser/_resources/playwright_browser.py +0 -50
- inspect_ai/tool/_tools/_web_browser/_resources/playwright_crawler.py +0 -48
- inspect_ai/tool/_tools/_web_browser/_resources/playwright_page_crawler.py +0 -280
- inspect_ai/tool/_tools/_web_browser/_resources/pyproject.toml +0 -65
- inspect_ai/tool/_tools/_web_browser/_resources/rectangle.py +0 -64
- inspect_ai/tool/_tools/_web_browser/_resources/rpc_client_helpers.py +0 -146
- inspect_ai/tool/_tools/_web_browser/_resources/scale_factor.py +0 -64
- inspect_ai/tool/_tools/_web_browser/_resources/test_accessibility_tree_node.py +0 -180
- inspect_ai/tool/_tools/_web_browser/_resources/test_playwright_crawler.py +0 -99
- inspect_ai/tool/_tools/_web_browser/_resources/test_rectangle.py +0 -15
- inspect_ai/tool/_tools/_web_browser/_resources/test_web_client.py +0 -44
- inspect_ai/tool/_tools/_web_browser/_resources/web_browser_rpc_types.py +0 -39
- inspect_ai/tool/_tools/_web_browser/_resources/web_client.py +0 -214
- inspect_ai/tool/_tools/_web_browser/_resources/web_client_new_session.py +0 -35
- inspect_ai/tool/_tools/_web_browser/_resources/web_server.py +0 -192
- {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info/licenses}/LICENSE +0 -0
- {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/anthropic.py

@@ -1,23 +1,12 @@
 import functools
 import os
 import re
-import sys
 from copy import copy
 from logging import getLogger
-from typing import Any, Literal, Optional, Tuple, cast
+from typing import Any, Literal, NamedTuple, Optional, Tuple, cast
 
 import httpcore
 import httpx
-
-from inspect_ai._util.http import is_retryable_http_status
-
-from .util.hooks import HttpxHooks
-
-if sys.version_info >= (3, 11):
-    from typing import NotRequired
-else:
-    from typing_extensions import NotRequired
-
 from anthropic import (
     APIConnectionError,
     APIStatusError,
@@ -39,19 +28,19 @@ from anthropic.types import (
     TextBlockParam,
     ThinkingBlock,
     ThinkingBlockParam,
+    ToolBash20250124Param,
     ToolParam,
     ToolResultBlockParam,
+    ToolTextEditor20250124Param,
     ToolUseBlock,
     ToolUseBlockParam,
     message_create_params,
 )
+from anthropic.types.beta import BetaToolComputerUse20250124Param
 from pydantic import JsonValue
 from typing_extensions import override
 
-from inspect_ai._util.constants import (
-    BASE_64_DATA_REMOVED,
-    NO_CONTENT,
-)
+from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
 from inspect_ai._util.content import (
     Content,
     ContentImage,
@@ -59,6 +48,7 @@ from inspect_ai._util.content import (
     ContentText,
 )
 from inspect_ai._util.error import exception_message
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.logger import warn_once
 from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64
@@ -70,11 +60,14 @@ from .._model import ModelAPI
 from .._model_call import ModelCall
 from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage, StopReason
 from .util import environment_prerequisite_error, model_base_url
+from .util.hooks import HttpxHooks
 
 logger = getLogger(__name__)
 
 ANTHROPIC_API_KEY = "ANTHROPIC_API_KEY"
 
+INTERNAL_COMPUTER_TOOL_NAME = "computer"
+
 
 class AnthropicAPI(ModelAPI):
     def __init__(
@@ -93,7 +86,7 @@ class AnthropicAPI(ModelAPI):
         else:
             self.service = None
 
-        # collect
+        # collect generate model_args (then delete them so we can pass the rest on)
         def collect_model_arg(name: str) -> Any | None:
             nonlocal model_args
             value = model_args.get(name, None)
@@ -193,14 +186,11 @@ class AnthropicAPI(ModelAPI):
 
         # generate
         try:
-            (
-
-
-                messages,
-                computer_use,
-            ) = await self.resolve_chat_input(input, tools, config)
+            system_param, tools_param, messages = await self.resolve_chat_input(
+                input, tools, config
+            )
 
-            # prepare request params (
+            # prepare request params (assembled this way so we can log the raw model call)
             request = dict(messages=messages)
 
             # system messages and tools
@@ -218,7 +208,13 @@ class AnthropicAPI(ModelAPI):
 
             # extra headers (for time tracker and computer use)
             extra_headers = headers | {HttpxHooks.REQUEST_ID_HEADER: request_id}
-            if computer_use:
+            if any(
+                tool.get("type", None) == "computer_20250124" for tool in tools_param
+            ):
+                # From: https://docs.anthropic.com/en/docs/agents-and-tools/computer-use#claude-3-7-sonnet-beta-flag
+                # Note: The Bash (bash_20250124) and Text Editor (text_editor_20250124)
+                # tools are generally available for Claude 3.5 Sonnet (new) as well and
+                # can be used without the computer use beta header.
                 betas.append("computer-use-2025-01-24")
             if len(betas) > 0:
                 extra_headers["anthropic-beta"] = ",".join(betas)
@@ -405,9 +401,7 @@ class AnthropicAPI(ModelAPI):
         input: list[ChatMessage],
         tools: list[ToolInfo],
         config: GenerateConfig,
-    ) -> Tuple[
-        list[TextBlockParam] | None, list["ToolParamDef"], list[MessageParam], bool
-    ]:
+    ) -> Tuple[list[TextBlockParam] | None, list["ToolParamDef"], list[MessageParam]]:
         # extract system message
         system_messages, messages = split_system_messages(input, config)
 
@@ -420,7 +414,7 @@ class AnthropicAPI(ModelAPI):
         )
 
         # tools
-        tools_params
+        tools_params = [self.tool_param_for_tool_info(tool, config) for tool in tools]
 
         # system messages
         if len(system_messages) > 0:
@@ -470,40 +464,35 @@ class AnthropicAPI(ModelAPI):
             add_cache_control(cast(dict[str, Any], content[-1]))
 
         # return chat input
-        return system_param, tools_params, message_params, computer_use
-
-    def
-        self,
-    ) ->
-        # tool
-
-
-
-
-
-
-
+        return system_param, tools_params, message_params
+
+    def tool_param_for_tool_info(
+        self, tool: ToolInfo, config: GenerateConfig
+    ) -> "ToolParamDef":
+        # Use a native tool implementation when available. Otherwise, use the
+        # standard tool implementation
+        return self.maybe_native_tool_param(tool, config) or ToolParam(
+            name=tool.name,
+            description=tool.description,
+            input_schema=tool.parameters.model_dump(exclude_none=True),
+        )
+
+    def maybe_native_tool_param(
+        self, tool: ToolInfo, config: GenerateConfig
+    ) -> Optional["ToolParamDef"]:
+        return (
+            (
                 self.computer_use_tool_param(tool)
-
-
+                or self.text_editor_tool_param(tool)
+                or self.bash_tool_param(tool)
             )
-            if
-
-
-            else:
-                tool_params.append(
-                    ToolParam(
-                        name=tool.name,
-                        description=tool.description,
-                        input_schema=tool.parameters.model_dump(exclude_none=True),
-                    )
-                )
-
-        return tool_params, computer_use
+            if config.internal_tools is not False
+            else None
+        )
 
     def computer_use_tool_param(
         self, tool: ToolInfo
-    ) -> Optional[
+    ) -> Optional[BetaToolComputerUse20250124Param]:
         # check for compatible 'computer' tool
         if tool.name == "computer" and (
             sorted(tool.parameters.properties.keys())
@@ -525,7 +514,7 @@ class AnthropicAPI(ModelAPI):
                 "Use of Anthropic's native computer use support is not enabled in Claude 3.5. Please use 3.7 or later to leverage the native support.",
             )
             return None
-        return
+        return BetaToolComputerUse20250124Param(
             type="computer_20250124",
             name="computer",
             # Note: The dimensions passed here for display_width_px and display_height_px should
@@ -542,23 +531,58 @@ class AnthropicAPI(ModelAPI):
         else:
             return None
 
+    def text_editor_tool_param(
+        self, tool: ToolInfo
+    ) -> Optional[ToolTextEditor20250124Param]:
+        # check for compatible 'text editor' tool
+        if tool.name == "text_editor" and (
+            sorted(tool.parameters.properties.keys())
+            == sorted(
+                [
+                    "command",
+                    "file_text",
+                    "insert_line",
+                    "new_str",
+                    "old_str",
+                    "path",
+                    "view_range",
+                ]
+            )
+        ):
+            return ToolTextEditor20250124Param(
+                type="text_editor_20250124", name="str_replace_editor"
+            )
+        # not a text_editor tool
+        else:
+            return None
 
-
-    #
-
-
-
-
-
-
+    def bash_tool_param(self, tool: ToolInfo) -> Optional[ToolBash20250124Param]:
+        # check for compatible 'bash' tool
+        if tool.name == "bash_session" and (
+            sorted(tool.parameters.properties.keys()) == sorted(["command", "restart"])
+        ):
+            return ToolBash20250124Param(type="bash_20250124", name="bash")
+        # not a bash tool
+        else:
+            return None
 
 
-# tools can be either a stock tool param or a special
-ToolParamDef =
+# tools can be either a stock tool param or a special Anthropic native use tool param
+ToolParamDef = (
+    ToolParam
+    | BetaToolComputerUse20250124Param
+    | ToolTextEditor20250124Param
+    | ToolBash20250124Param
+)
 
 
 def add_cache_control(
-    param: TextBlockParam
+    param: TextBlockParam
+    | ToolParam
+    | BetaToolComputerUse20250124Param
+    | ToolTextEditor20250124Param
+    | ToolBash20250124Param
+    | dict[str, Any],
 ) -> None:
     cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
 
@@ -567,10 +591,10 @@ def consecutive_user_message_reducer(
     messages: list[MessageParam],
     message: MessageParam,
 ) -> list[MessageParam]:
-    return consective_message_reducer(messages, message, "user")
+    return consecutive_message_reducer(messages, message, "user")
 
 
-def consective_message_reducer(
+def consecutive_message_reducer(
     messages: list[MessageParam],
     message: MessageParam,
     role: Literal["user", "assistant"],
@@ -583,6 +607,7 @@ def consective_message_reducer(
 
 
 def combine_messages(a: MessageParam, b: MessageParam) -> MessageParam:
+    # TODO: Fix this code as it currently drops interesting properties when combining
    role = a["role"]
    a_content = a["content"]
    b_content = b["content"]
@@ -702,7 +727,7 @@ async def message_param(message: ChatMessage) -> MessageParam:
                 ToolUseBlockParam(
                     type="tool_use",
                     id=tool_call.id,
-                    name=tool_call.function,
+                    name=tool_call.internal_name or tool_call.function,
                     input=tool_call.arguments,
                 )
             )
@@ -749,11 +774,13 @@ async def model_output_from_message(
             content.append(ContentText(type="text", text=content_text))
         elif isinstance(content_block, ToolUseBlock):
             tool_calls = tool_calls or []
+            info = maybe_mapped_call_info(content_block.name, tools)
             tool_calls.append(
                 ToolCall(
-                    type="function",
+                    type=info.internal_type,
                     id=content_block.id,
-                    function=content_block.name,
+                    function=info.inspect_name,
+                    internal_name=info.internal_name,
                     arguments=content_block.model_dump().get("input", {}),
                 )
             )
@@ -788,6 +815,7 @@ async def model_output_from_message(
         + (input_tokens_cache_write or 0)
         + (input_tokens_cache_read or 0)
         + message.usage.output_tokens
+        + reasoning_tokens
     )
     return ModelOutput(
         model=message.model,
@@ -803,6 +831,37 @@ async def model_output_from_message(
     )
 
 
+class CallInfo(NamedTuple):
+    internal_name: str | None
+    internal_type: str
+    inspect_name: str
+
+
+def maybe_mapped_call_info(tool_called: str, tools: list[ToolInfo]) -> CallInfo:
+    """
+    Return call info - potentially transformed by native tool mappings.
+
+    Anthropic prescribes names for their native tools - `computer`, `bash`, and
+    `str_replace_editor`. For a variety of reasons, Inspect's tool names to not
+    necessarily conform to internal names. Anthropic also provides specific tool
+    types for these built-in tools.
+    """
+    mappings = (
+        (INTERNAL_COMPUTER_TOOL_NAME, "computer_20250124", "computer"),
+        ("str_replace_editor", "text_editor_20250124", "text_editor"),
+        ("bash", "bash_20250124", "bash_session"),
+    )
+
+    return next(
+        (
+            CallInfo(entry[0], entry[1], entry[2])
+            for entry in mappings
+            if entry[0] == tool_called and any(tool.name == entry[2] for tool in tools)
+        ),
+        CallInfo(None, "function", tool_called),
+    )
+
+
 def message_stop_reason(message: Message) -> StopReason:
     match message.stop_reason:
         case "end_turn" | "stop_sequence":
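The CallInfo/maybe_mapped_call_info addition above is what lets tool-use responses round-trip: Anthropic prescribes fixed names and types for its native tools, while Inspect registers the same tools under its own names. A minimal standalone sketch of that mapping logic, simplified to plain strings (the tuples mirror the mappings in the diff; the simplified signature and helper shape here are illustrative, not the library's API):

from typing import NamedTuple

class CallInfo(NamedTuple):
    internal_name: str | None
    internal_type: str
    inspect_name: str

# (anthropic tool name, anthropic tool type, inspect tool name), as in the diff above
MAPPINGS = (
    ("computer", "computer_20250124", "computer"),
    ("str_replace_editor", "text_editor_20250124", "text_editor"),
    ("bash", "bash_20250124", "bash_session"),
)

def maybe_mapped_call_info(tool_called: str, tool_names: list[str]) -> CallInfo:
    # map a native Anthropic tool name back to the Inspect tool it stands for,
    # falling back to an ordinary function call when no mapping applies
    return next(
        (
            CallInfo(name, type_, inspect_name)
            for (name, type_, inspect_name) in MAPPINGS
            if name == tool_called and inspect_name in tool_names
        ),
        CallInfo(None, "function", tool_called),
    )

# Anthropic's native "bash" tool call maps back to Inspect's "bash_session" tool:
assert maybe_mapped_call_info("bash", ["bash_session"]) == ("bash", "bash_20250124", "bash_session")
# an unrecognized name stays a plain function call:
assert maybe_mapped_call_info("ls", ["bash_session"]).internal_type == "function"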
inspect_ai/model/_providers/azureai.py

@@ -51,7 +51,6 @@ from .._chat_message import (
     ChatMessageUser,
 )
 from .._generate_config import GenerateConfig
-from .._image import image_url_filter
 from .._model import ModelAPI
 from .._model_call import ModelCall
 from .._model_output import (
@@ -60,6 +59,7 @@ from .._model_output import (
     ModelUsage,
     StopReason,
 )
+from .._openai import openai_media_filter
 from .util import (
     environment_prerequisite_error,
     model_base_url,
@@ -182,7 +182,7 @@ class AzureAIAPI(ModelAPI):
                 else None,
             ),
            response=response.as_dict() if response else {},
-            filter=image_url_filter,
+            filter=openai_media_filter,
        )
 
        # make call
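The only functional change here is the ModelCall.create filter: the removed image_url_filter gives way to the shared openai_media_filter, so logged raw model calls redact media payloads consistently across providers. The real filter lives in inspect_ai/model/_openai.py and is not shown in this diff; a hypothetical minimal version of the idea (the function name and placeholder value are assumptions, not the library's actual implementation):

from typing import Any

BASE_64_DATA_REMOVED = "<base64 data removed>"  # assumed placeholder value

def media_filter(value: Any) -> Any:
    # recursively replace inline data: URLs so logged model calls stay small
    if isinstance(value, dict):
        return {k: media_filter(v) for k, v in value.items()}
    if isinstance(value, list):
        return [media_filter(v) for v in value]
    if isinstance(value, str) and value.startswith("data:"):
        return BASE_64_DATA_REMOVED
    return value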
inspect_ai/model/_providers/mistral.py

@@ -82,6 +82,14 @@ class MistralAPI(ModelAPI):
         config: GenerateConfig = GenerateConfig(),
         **model_args: Any,
     ):
+        # extract any service prefix from model name
+        parts = model_name.split("/")
+        if len(parts) > 1:
+            self.service: str | None = parts[0]
+            model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
         super().__init__(
             model_name=model_name,
             base_url=base_url,
@@ -94,31 +102,39 @@ class MistralAPI(ModelAPI):
             config=config,
         )
 
-        # resolve api_key
+        # resolve api_key
         if not self.api_key:
-            self.api_key = os.environ.get(MISTRAL_API_KEY, None)
-            if self.api_key:
-                base_url = model_base_url(base_url, "MISTRAL_BASE_URL")
-            else:
+            if self.is_azure():
                 self.api_key = os.environ.get(
                     AZUREAI_MISTRAL_API_KEY, os.environ.get(AZURE_MISTRAL_API_KEY, None)
                 )
-
-
-
-
-
-
+            else:
+                self.api_key = os.environ.get(MISTRAL_API_KEY, None)
+
+        if not self.api_key:
+            raise environment_prerequisite_error(
+                "Mistral", [MISTRAL_API_KEY, AZUREAI_MISTRAL_API_KEY]
+            )
+
+        if not self.base_url:
+            if self.is_azure():
+                self.base_url = model_base_url(base_url, "AZUREAI_MISTRAL_BASE_URL")
+                if not self.base_url:
                    raise ValueError(
                        "You must provide a base URL when using Mistral on Azure. Use the AZUREAI_MISTRAL_BASE_URL "
                        + " environment variable or the --model-base-url CLI flag to set the base URL."
                    )
+            else:
+                self.base_url = model_base_url(base_url, "MISTRAL_BASE_URL")
 
-        if base_url:
-            model_args["server_url"] = base_url
+        if self.base_url:
+            model_args["server_url"] = self.base_url
 
         self.model_args = model_args
 
+    def is_azure(self) -> bool:
+        return self.service == "azure"
+
     @override
     async def close(self) -> None:
         # client is created and destroyed in generate
inspect_ai/model/_providers/openai.py

@@ -22,28 +22,27 @@ from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.model._providers.openai_responses import generate_responses
 from inspect_ai.model._providers.util.hooks import HttpxHooks
 from inspect_ai.tool import ToolChoice, ToolInfo
 
 from .._chat_message import ChatMessage
 from .._generate_config import GenerateConfig
-from .._image import image_url_filter
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-    StopReason,
-)
+from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .._openai import (
+    OpenAIResponseError,
     is_gpt,
     is_o1_mini,
     is_o1_preview,
+    is_o1_pro,
     is_o_series,
     openai_chat_messages,
     openai_chat_tool_choice,
     openai_chat_tools,
+    openai_handle_bad_request,
+    openai_media_filter,
 )
 from .openai_o1 import generate_o1
 from .util import (
@@ -65,8 +64,22 @@ class OpenAIAPI(ModelAPI):
         base_url: str | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
+        responses_api: bool | None = None,
         **model_args: Any,
     ) -> None:
+        # extract azure service prefix from model name (other providers
+        # that subclass from us like together expect to have the qualifier
+        # in the model name e.g. google/gemma-2b-it)
+        parts = model_name.split("/")
+        if parts[0] == "azure" and len(parts) > 1:
+            self.service: str | None = parts[0]
+            model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
+        # note whether we are forcing the responses_api
+        self.responses_api = True if responses_api else False
+
         # call super
         super().__init__(
             model_name=model_name,
@@ -76,32 +89,23 @@ class OpenAIAPI(ModelAPI):
             config=config,
         )
 
-        # extract any service prefix from model name
-        parts = model_name.split("/")
-        if len(parts) > 1:
-            self.service: str | None = parts[0]
-            model_name = "/".join(parts[1:])
-        else:
-            self.service = None
-
         # resolve api_key
         if not self.api_key:
-            self.
-
-
-
-            if self.api_key and (os.environ.get(OPENAI_API_KEY, None) is None):
-                self.service = "azure"
+            if self.service == "azure":
+                self.api_key = os.environ.get(
+                    AZUREAI_OPENAI_API_KEY, os.environ.get(AZURE_OPENAI_API_KEY, None)
+                )
             else:
                 self.api_key = os.environ.get(OPENAI_API_KEY, None)
-
-
-
-
-
-
-
-
+
+            if not self.api_key:
+                raise environment_prerequisite_error(
+                    "OpenAI",
+                    [
+                        OPENAI_API_KEY,
+                        AZUREAI_OPENAI_API_KEY,
+                    ],
+                )
 
         # create async http client
         http_client = OpenAIAsyncHttpxClient()
@@ -123,10 +127,16 @@ class OpenAIAPI(ModelAPI):
                + "environment variable or the --model-base-url CLI flag to set the base URL."
            )
 
+            # resolve version
+            api_version = os.environ.get(
+                "AZUREAI_OPENAI_API_VERSION",
+                os.environ.get("OPENAI_API_VERSION", "2025-02-01-preview"),
+            )
+
            self.client: AsyncAzureOpenAI | AsyncOpenAI = AsyncAzureOpenAI(
                api_key=self.api_key,
+                api_version=api_version,
                azure_endpoint=base_url,
-                azure_deployment=model_name,
                http_client=http_client,
                **model_args,
            )
@@ -147,6 +157,9 @@ class OpenAIAPI(ModelAPI):
     def is_o_series(self) -> bool:
         return is_o_series(self.model_name)
 
+    def is_o1_pro(self) -> bool:
+        return is_o1_pro(self.model_name)
+
     def is_o1_mini(self) -> bool:
         return is_o1_mini(self.model_name)
 
@@ -175,6 +188,16 @@ class OpenAIAPI(ModelAPI):
                 tools=tools,
                 **self.completion_params(config, False),
             )
+        elif self.is_o1_pro() or self.responses_api:
+            return await generate_responses(
+                client=self.client,
+                http_hooks=self._http_hooks,
+                model_name=self.model_name,
+                input=input,
+                tools=tools,
+                tool_choice=tool_choice,
+                config=config,
+            )
 
         # allocate request_id (so we can see it from ModelCall)
         request_id = self._http_hooks.start_request()
@@ -187,7 +210,7 @@ class OpenAIAPI(ModelAPI):
            return ModelCall.create(
                request=request,
                response=response,
-                filter=image_url_filter,
+                filter=openai_media_filter,
                time=self._http_hooks.end_request(request_id),
            )
 
@@ -219,6 +242,7 @@ class OpenAIAPI(ModelAPI):
 
         # save response for model_call
         response = completion.model_dump()
+        self.on_response(response)
 
         # parse out choices
         choices = self._chat_choices_from_response(completion, tools)
@@ -250,6 +274,12 @@ class OpenAIAPI(ModelAPI):
         except BadRequestError as e:
             return self.handle_bad_request(e), model_call()
 
+    def on_response(self, response: dict[str, Any]) -> None:
+        pass
+
+    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
+        return openai_handle_bad_request(self.model_name, ex)
+
     def _chat_choices_from_response(
         self, response: ChatCompletion, tools: list[ToolInfo]
     ) -> list[ChatCompletionChoice]:
@@ -268,6 +298,8 @@ class OpenAIAPI(ModelAPI):
             return True
         elif isinstance(ex, APIStatusError):
             return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, OpenAIResponseError):
+            return ex.code in ["rate_limit_exceeded", "server_error"]
         elif isinstance(ex, APITimeoutError):
             return True
         else:
@@ -322,6 +354,7 @@ class OpenAIAPI(ModelAPI):
             config.reasoning_effort is not None
             and not self.is_gpt()
             and not self.is_o1_mini()
+            and not self.is_o1_preview()
         ):
             params["reasoning_effort"] = config.reasoning_effort
         if config.response_schema is not None:
@@ -339,32 +372,6 @@ class OpenAIAPI(ModelAPI):
 
         return params
 
-    # convert some well known bad request errors into ModelOutput
-    def handle_bad_request(self, e: BadRequestError) -> ModelOutput | Exception:
-        # extract message
-        if isinstance(e.body, dict) and "message" in e.body.keys():
-            content = str(e.body.get("message"))
-        else:
-            content = e.message
-
-        # narrow stop_reason
-        stop_reason: StopReason | None = None
-        if e.code == "context_length_exceeded":
-            stop_reason = "model_length"
-        elif (
-            e.code == "invalid_prompt"  # seems to happen for o1/o3
-            or e.code == "content_policy_violation"  # seems to happen for vision
-            or e.code == "content_filter"  # seems to happen on azure
-        ):
-            stop_reason = "content_filter"
-
-        if stop_reason:
-            return ModelOutput.from_content(
-                model=self.model_name, content=content, stop_reason=stop_reason
-            )
-        else:
-            return e
-
 
 class OpenAIAsyncHttpxClient(httpx.AsyncClient):
     """Custom async client that deals better with long running Async requests.