inspect-ai 0.3.71__py3-none-any.whl → 0.3.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +14 -3
- inspect_ai/_cli/sandbox.py +3 -3
- inspect_ai/_cli/score.py +6 -4
- inspect_ai/_cli/trace.py +53 -6
- inspect_ai/_display/core/config.py +1 -1
- inspect_ai/_display/core/display.py +2 -1
- inspect_ai/_display/core/footer.py +6 -6
- inspect_ai/_display/plain/display.py +11 -6
- inspect_ai/_display/rich/display.py +23 -13
- inspect_ai/_display/textual/app.py +10 -9
- inspect_ai/_display/textual/display.py +2 -2
- inspect_ai/_display/textual/widgets/footer.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +14 -5
- inspect_ai/_eval/context.py +1 -2
- inspect_ai/_eval/eval.py +54 -41
- inspect_ai/_eval/loader.py +9 -2
- inspect_ai/_eval/run.py +148 -81
- inspect_ai/_eval/score.py +13 -8
- inspect_ai/_eval/task/images.py +31 -21
- inspect_ai/_eval/task/run.py +62 -59
- inspect_ai/_eval/task/rundir.py +16 -9
- inspect_ai/_eval/task/sandbox.py +7 -8
- inspect_ai/_eval/task/util.py +7 -0
- inspect_ai/_util/_async.py +118 -10
- inspect_ai/_util/constants.py +0 -2
- inspect_ai/_util/file.py +15 -29
- inspect_ai/_util/future.py +37 -0
- inspect_ai/_util/http.py +3 -99
- inspect_ai/_util/httpx.py +60 -0
- inspect_ai/_util/interrupt.py +2 -2
- inspect_ai/_util/json.py +5 -52
- inspect_ai/_util/logger.py +30 -86
- inspect_ai/_util/retry.py +10 -61
- inspect_ai/_util/trace.py +2 -2
- inspect_ai/_view/server.py +86 -3
- inspect_ai/_view/www/dist/assets/index.js +25837 -13269
- inspect_ai/_view/www/log-schema.json +253 -186
- inspect_ai/_view/www/package.json +2 -2
- inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
- inspect_ai/_view/www/src/types/log.d.ts +122 -94
- inspect_ai/approval/_human/manager.py +6 -10
- inspect_ai/approval/_human/panel.py +2 -2
- inspect_ai/dataset/_sources/util.py +7 -6
- inspect_ai/log/__init__.py +4 -0
- inspect_ai/log/_file.py +35 -61
- inspect_ai/log/_log.py +18 -1
- inspect_ai/log/_recorders/eval.py +14 -23
- inspect_ai/log/_recorders/json.py +3 -18
- inspect_ai/log/_samples.py +27 -2
- inspect_ai/log/_transcript.py +8 -8
- inspect_ai/model/__init__.py +2 -1
- inspect_ai/model/_call_tools.py +60 -40
- inspect_ai/model/_chat_message.py +3 -2
- inspect_ai/model/_generate_config.py +25 -0
- inspect_ai/model/_model.py +74 -36
- inspect_ai/model/_openai.py +9 -1
- inspect_ai/model/_providers/anthropic.py +172 -154
- inspect_ai/model/_providers/azureai.py +11 -9
- inspect_ai/model/_providers/bedrock.py +33 -24
- inspect_ai/model/_providers/cloudflare.py +8 -9
- inspect_ai/model/_providers/goodfire.py +7 -3
- inspect_ai/model/_providers/google.py +47 -13
- inspect_ai/model/_providers/groq.py +15 -15
- inspect_ai/model/_providers/hf.py +24 -17
- inspect_ai/model/_providers/mistral.py +36 -20
- inspect_ai/model/_providers/openai.py +30 -25
- inspect_ai/model/_providers/openai_o1.py +1 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +3 -4
- inspect_ai/model/_providers/util/__init__.py +2 -2
- inspect_ai/model/_providers/util/chatapi.py +6 -19
- inspect_ai/model/_providers/util/hooks.py +165 -0
- inspect_ai/model/_providers/vertex.py +20 -3
- inspect_ai/model/_providers/vllm.py +16 -19
- inspect_ai/scorer/_multi.py +5 -2
- inspect_ai/solver/_bridge/patch.py +31 -1
- inspect_ai/solver/_fork.py +5 -3
- inspect_ai/solver/_human_agent/agent.py +3 -2
- inspect_ai/tool/__init__.py +8 -2
- inspect_ai/tool/_tool_info.py +4 -90
- inspect_ai/tool/_tool_params.py +4 -34
- inspect_ai/tool/_tools/_computer/_common.py +117 -58
- inspect_ai/tool/_tools/_computer/_computer.py +80 -57
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +7 -1
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +91 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +8 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +12 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +78 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +20 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +175 -113
- inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +76 -20
- inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +65 -0
- inspect_ai/tool/_tools/_computer/test_args.py +151 -0
- inspect_ai/tool/_tools/_web_search.py +30 -24
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_concurrency.py +5 -6
- inspect_ai/util/_display.py +6 -0
- inspect_ai/util/_json.py +170 -0
- inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
- inspect_ai/util/_sandbox/docker/docker.py +5 -0
- inspect_ai/util/_sandbox/environment.py +56 -9
- inspect_ai/util/_sandbox/service.py +12 -5
- inspect_ai/util/_subprocess.py +94 -113
- inspect_ai/util/_subtask.py +2 -4
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +111 -103
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
- inspect_ai/_util/timeouts.py +0 -160
- inspect_ai/model/_providers/util/tracker.py +0 -92
- inspect_ai/tool/_tools/_computer/_computer_split.py +0 -198
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/model/_model.py
CHANGED
```diff
@@ -13,6 +13,7 @@ from typing import Any, AsyncIterator, Callable, Literal, Type, cast
 
 from pydantic_core import to_jsonable_python
 from tenacity import (
+    RetryCallState,
     retry,
     retry_if_exception,
     stop_after_attempt,
@@ -20,8 +21,9 @@ from tenacity import (
     stop_never,
     wait_exponential_jitter,
 )
+from tenacity.stop import StopBaseT
 
-from inspect_ai._util.constants import DEFAULT_MAX_CONNECTIONS
+from inspect_ai._util.constants import DEFAULT_MAX_CONNECTIONS, HTTP
 from inspect_ai._util.content import (
     Content,
     ContentImage,
@@ -30,6 +32,7 @@ from inspect_ai._util.content import (
 )
 from inspect_ai._util.hooks import init_hooks, override_api_key, send_telemetry
 from inspect_ai._util.interrupt import check_sample_interrupt
+from inspect_ai._util.logger import warn_once
 from inspect_ai._util.platform import platform_init
 from inspect_ai._util.registry import (
     RegistryInfo,
@@ -37,7 +40,7 @@ from inspect_ai._util.registry import (
     registry_info,
     registry_unqualified_name,
 )
-from inspect_ai._util.retry import log_rate_limit_retry
+from inspect_ai._util.retry import report_http_retry
 from inspect_ai._util.trace import trace_action
 from inspect_ai._util.working import report_sample_waiting_time, sample_working_time
 from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
@@ -173,11 +176,11 @@ class ModelAPI(abc.ABC):
         """Scope for enforcement of max_connections."""
         return "default"
 
-    def is_rate_limit(self, ex: Exception) -> bool:
-        """Is this exception a rate limit error?
+    def should_retry(self, ex: Exception) -> bool:
+        """Should this exception be retried?
 
         Args:
-           ex: Exception to check for
+           ex: Exception to check for retry
         """
         return False
 
```
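The `should_retry()` hook above replaces the deprecated `is_rate_limit()`. A minimal sketch of how a third-party provider might override it — `MyProviderError`, its `status_code` attribute, and `MyProviderAPI` are illustrative stand-ins, not part of inspect_ai:

```python
import httpx

from inspect_ai.model import ModelAPI


class MyProviderError(Exception):
    """Hypothetical provider error carrying an HTTP status code."""

    def __init__(self, status_code: int) -> None:
        super().__init__(f"provider error (status {status_code})")
        self.status_code = status_code


class MyProviderAPI(ModelAPI):
    # generate() and other required members omitted for brevity

    def should_retry(self, ex: Exception) -> bool:
        # transient transport failures are always worth retrying
        if isinstance(ex, (httpx.ConnectError, httpx.ReadTimeout)):
            return True
        # retry rate limits and typical transient server statuses
        if isinstance(ex, MyProviderError):
            return ex.status_code in (408, 429, 500, 502, 503, 504)
        return False
```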
```diff
@@ -331,14 +334,17 @@ class Model:
         start_time = datetime.now()
         working_start = sample_working_time()
         async with self._connection_concurrency(config):
+            from inspect_ai.log._samples import track_active_sample_retries
+
             # generate
-            output = await self._generate(
-                input=input,
-                tools=tools,
-                tool_choice=tool_choice,
-                config=config,
-                cache=cache,
-            )
+            with track_active_sample_retries():
+                output = await self._generate(
+                    input=input,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    config=config,
+                    cache=cache,
+                )
 
             # update the most recent ModelEvent with the actual start/completed
             # times as well as a computation of working time (events are
@@ -418,27 +424,27 @@ class Model:
         if self.api.collapse_assistant_messages():
             input = collapse_consecutive_assistant_messages(input)
 
-        # retry for
+        # retry for transient http errors:
+        # - no default timeout or max_retries (try forever)
+        # - exponential backoff starting at 3 seconds (will wait 25 minutes
+        #   on the 10th retry,then will wait no longer than 30 minutes on
+        #   subsequent retries)
+        if config.max_retries is not None and config.timeout is not None:
+            stop: StopBaseT = stop_after_attempt(config.max_retries) | stop_after_delay(
+                config.timeout
+            )
+        elif config.max_retries is not None:
+            stop = stop_after_attempt(config.max_retries)
+        elif config.timeout is not None:
+            stop = stop_after_delay(config.timeout)
+        else:
+            stop = stop_never
+
         @retry(
-            wait=wait_exponential_jitter(max=(30 * 60), jitter=
-            retry=retry_if_exception(self.
-            stop=(
-                (
-                    stop_after_delay(config.timeout)
-                    | stop_after_attempt(config.max_retries)
-                )
-                if config.timeout and config.max_retries
-                else (
-                    stop_after_delay(config.timeout)
-                    if config.timeout
-                    else (
-                        stop_after_attempt(config.max_retries)
-                        if config.max_retries
-                        else stop_never
-                    )
-                )
-            ),
-            before_sleep=functools.partial(log_rate_limit_retry, self.api.model_name),
+            wait=wait_exponential_jitter(initial=3, max=(30 * 60), jitter=3),
+            retry=retry_if_exception(self.should_retry),
+            stop=stop,
+            before_sleep=functools.partial(log_model_retry, self.api.model_name),
         )
         async def generate() -> ModelOutput:
             check_sample_interrupt()
@@ -555,6 +561,30 @@ class Model:
         # return results
         return model_output
 
+    def should_retry(self, ex: BaseException) -> bool:
+        if isinstance(ex, Exception):
+            # check standard should_retry() method
+            retry = self.api.should_retry(ex)
+            if retry:
+                report_http_retry()
+                return True
+
+            # see if the API implements legacy is_rate_limit() method
+            is_rate_limit = getattr(self.api, "is_rate_limit", None)
+            if is_rate_limit:
+                warn_once(
+                    logger,
+                    f"provider '{self.name}' implements deprecated is_rate_limit() method, "
+                    + "please change to should_retry()",
+                )
+                retry = cast(bool, is_rate_limit(ex))
+                if retry:
+                    report_http_retry()
+                    return True
+
+        # no retry
+        return False
+
     # function to verify that its okay to call model apis
     def verify_model_apis(self) -> None:
         if (
@@ -1064,6 +1094,7 @@ def tool_result_images_reducer(
         messages
         + [
             ChatMessageTool(
+                id=message.id,
                 content=edited_tool_message_content,
                 tool_call_id=message.tool_call_id,
                 function=message.function,
@@ -1170,19 +1201,26 @@ def combine_messages(
    a: ChatMessage, b: ChatMessage, message_type: Type[ChatMessage]
 ) -> ChatMessage:
     if isinstance(a.content, str) and isinstance(b.content, str):
-        return message_type(content=f"{a.content}\n{b.content}")
+        return message_type(id=a.id, content=f"{a.content}\n{b.content}")
     elif isinstance(a.content, list) and isinstance(b.content, list):
-        return message_type(content=a.content + b.content)
+        return message_type(id=a.id, content=a.content + b.content)
     elif isinstance(a.content, str) and isinstance(b.content, list):
-        return message_type(content=[ContentText(text=a.content), *b.content])
+        return message_type(id=a.id, content=[ContentText(text=a.content), *b.content])
     elif isinstance(a.content, list) and isinstance(b.content, str):
-        return message_type(content=a.content + [ContentText(text=b.content)])
+        return message_type(id=a.id, content=a.content + [ContentText(text=b.content)])
     else:
         raise TypeError(
             f"Cannot combine messages with invalid content types: {a.content!r}, {b.content!r}"
         )
 
 
+def log_model_retry(model_name: str, retry_state: RetryCallState) -> None:
+    logger.log(
+        HTTP,
+        f"-> {model_name} retry {retry_state.attempt_number} after waiting for {retry_state.idle_for}",
+    )
+
+
 def init_active_model(model: Model, config: GenerateConfig) -> None:
     active_model_context_var.set(model)
     set_active_generate_config(config)
```
inspect_ai/model/_openai.py
CHANGED
```diff
@@ -52,7 +52,7 @@ from ._model_output import ModelUsage, StopReason, as_stop_reason
 
 
 def is_o_series(name: str) -> bool:
-    return bool(re.match(r"
+    return bool(re.match(r"(^|.*\/)o\d+", name))
 
 
 def is_o1_mini(name: str) -> bool:
@@ -396,6 +396,9 @@ def content_from_openai(
     content: ChatCompletionContentPartParam | ChatCompletionContentPartRefusalParam,
     parse_reasoning: bool = False,
 ) -> list[Content]:
+    # Some providers omit the type tag and use "object-with-a-single-field" encoding
+    if "type" not in content and len(content) == 1:
+        content["type"] = list(content.keys())[0]  # type: ignore[arg-type]
     if content["type"] == "text":
         text = content["text"]
         if parse_reasoning:
@@ -413,6 +416,8 @@
             return [ContentText(text=text)]
         else:
             return [ContentText(text=text)]
+    elif content["type"] == "reasoning":  # type: ignore[comparison-overlap]
+        return [ContentReasoning(reasoning=content["reasoning"])]
     elif content["type"] == "image_url":
         return [
             ContentImage(
@@ -428,6 +433,9 @@
         ]
     elif content["type"] == "refusal":
         return [ContentText(text=content["refusal"])]
+    else:
+        content_type = content["type"]
+        raise ValueError(f"Unexpected content type '{content_type}' in message.")
 
 
 def chat_message_assistant_from_openai(
```
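The new guard at the top of `content_from_openai()` tolerates providers that send a content part without a `"type"` tag by inferring the tag from the part's single key. A standalone sketch of that normalization (`normalize_content_part` is an illustrative helper, not an inspect_ai function):

```python
from typing import Any


def normalize_content_part(content: dict[str, Any]) -> dict[str, Any]:
    # if the type tag is missing and the part has exactly one field,
    # treat that field's name as the type
    if "type" not in content and len(content) == 1:
        content["type"] = list(content.keys())[0]
    return content


assert normalize_content_part({"reasoning": "step 1 ..."})["type"] == "reasoning"
assert normalize_content_part({"type": "text", "text": "hi"})["type"] == "text"
```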
inspect_ai/model/_providers/anthropic.py
CHANGED
```diff
@@ -4,9 +4,14 @@ import re
 import sys
 from copy import copy
 from logging import getLogger
-from typing import Any, Literal, Tuple, TypedDict, cast
+from typing import Any, Literal, Optional, Tuple, TypedDict, cast
 
-
+import httpcore
+import httpx
+
+from inspect_ai._util.http import is_retryable_http_status
+
+from .util.hooks import HttpxHooks
 
 if sys.version_info >= (3, 11):
     from typing import NotRequired
@@ -16,13 +21,12 @@ else:
 from anthropic import (
     APIConnectionError,
     APIStatusError,
+    APITimeoutError,
     AsyncAnthropic,
     AsyncAnthropicBedrock,
     AsyncAnthropicVertex,
     BadRequestError,
-    InternalServerError,
     NotGiven,
-    RateLimitError,
 )
 from anthropic._types import Body
 from anthropic.types import (
@@ -46,7 +50,6 @@ from typing_extensions import override
 
 from inspect_ai._util.constants import (
     BASE_64_DATA_REMOVED,
-    DEFAULT_MAX_RETRIES,
     NO_CONTENT,
 )
 from inspect_ai._util.content import (
@@ -125,9 +128,6 @@ class AnthropicAPI(ModelAPI):
             AsyncAnthropic | AsyncAnthropicBedrock | AsyncAnthropicVertex
         ) = AsyncAnthropicBedrock(
             base_url=base_url,
-            max_retries=(
-                config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-            ),
             aws_region=aws_region,
             **model_args,
         )
@@ -141,9 +141,6 @@ class AnthropicAPI(ModelAPI):
                 region=region,
                 project_id=project_id,
                 base_url=base_url,
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 **model_args,
             )
         else:
@@ -156,14 +153,11 @@ class AnthropicAPI(ModelAPI):
             self.client = AsyncAnthropic(
                 base_url=base_url,
                 api_key=self.api_key,
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 **model_args,
             )
 
         # create time tracker
-        self.
+        self._http_hooks = HttpxHooks(self.client._client)
 
     @override
     async def close(self) -> None:
@@ -183,7 +177,7 @@ class AnthropicAPI(ModelAPI):
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # allocate request_id (so we can see it from ModelCall)
-        request_id = self.
+        request_id = self._http_hooks.start_request()
 
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
@@ -194,7 +188,7 @@ class AnthropicAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=model_call_filter,
-                time=self.
+                time=self._http_hooks.end_request(request_id),
             )
 
         # generate
@@ -204,7 +198,7 @@ class AnthropicAPI(ModelAPI):
                 tools_param,
                 messages,
                 computer_use,
-            ) = await resolve_chat_input(
+            ) = await self.resolve_chat_input(input, tools, config)
 
             # prepare request params (assembed this way so we can log the raw model call)
             request = dict(messages=messages)
@@ -223,9 +217,9 @@ class AnthropicAPI(ModelAPI):
             request = request | req
 
             # extra headers (for time tracker and computer use)
-            extra_headers = headers | {
+            extra_headers = headers | {HttpxHooks.REQUEST_ID_HEADER: request_id}
             if computer_use:
-                betas.append("computer-use-2024-10-22")
+                betas.append("computer-use-2025-01-24")
             if len(betas) > 0:
                 extra_headers["anthropic-beta"] = ",".join(betas)
 
@@ -291,8 +285,6 @@ class AnthropicAPI(ModelAPI):
             betas.append("output-128k-2025-02-19")
 
         # config that applies to all models
-        if config.timeout is not None:
-            params["timeout"] = float(config.timeout)
         if config.stop_seqs is not None:
             params["stop_sequences"] = config.stop_seqs
 
@@ -326,18 +318,27 @@ class AnthropicAPI(ModelAPI):
     def is_claude_3_5(self) -> bool:
         return "claude-3-5-" in self.model_name
 
+    def is_claude_3_7(self) -> bool:
+        return "claude-3-7-" in self.model_name
+
     @override
     def connection_key(self) -> str:
         return str(self.api_key)
 
     @override
-    def is_rate_limit(self, ex: Exception) -> bool:
-
-
-
-
-
-
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, APIStatusError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(
+            ex,
+            APIConnectionError
+            | APITimeoutError
+            | httpx.RemoteProtocolError
+            | httpcore.RemoteProtocolError,
+        ):
+            return True
+        else:
+            return False
 
     @override
     def collapse_user_messages(self) -> bool:
```
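`should_retry()` now classifies `APIStatusError` via `is_retryable_http_status()` from `inspect_ai._util.http`, whose exact rule is not shown in this diff. A plausible approximation, offered as an assumption only:

```python
def is_retryable_http_status(status_code: int) -> bool:
    # ASSUMPTION: approximates inspect_ai._util.http.is_retryable_http_status;
    # 408 (request timeout), 429 (rate limit), and most 5xx responses are
    # transient, while 501 (not implemented) is permanent
    return status_code in (408, 429) or (status_code >= 500 and status_code != 501)
```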
```diff
@@ -397,6 +398,148 @@ class AnthropicAPI(ModelAPI):
         else:
             return ex
 
+    async def resolve_chat_input(
+        self,
+        input: list[ChatMessage],
+        tools: list[ToolInfo],
+        config: GenerateConfig,
+    ) -> Tuple[
+        list[TextBlockParam] | None, list["ToolParamDef"], list[MessageParam], bool
+    ]:
+        # extract system message
+        system_messages, messages = split_system_messages(input, config)
+
+        # messages
+        message_params = [(await message_param(message)) for message in messages]
+
+        # collapse user messages (as Inspect 'tool' messages become Claude 'user' messages)
+        message_params = functools.reduce(
+            consecutive_user_message_reducer, message_params, []
+        )
+
+        # tools
+        tools_params, computer_use = self.tool_params_for_tools(tools, config)
+
+        # system messages
+        if len(system_messages) > 0:
+            system_param: list[TextBlockParam] | None = [
+                TextBlockParam(type="text", text=message.text)
+                for message in system_messages
+            ]
+        else:
+            system_param = None
+
+        # add caching directives if necessary
+        cache_prompt = (
+            config.cache_prompt
+            if isinstance(config.cache_prompt, bool)
+            else True
+            if len(tools_params)
+            else False
+        )
+
+        # only certain claude models qualify
+        if cache_prompt:
+            if (
+                "claude-3-sonnet" in self.model_name
+                or "claude-2" in self.model_name
+                or "claude-instant" in self.model_name
+            ):
+                cache_prompt = False
+
+        if cache_prompt:
+            # system
+            if system_param:
+                add_cache_control(system_param[-1])
+            # tools
+            if tools_params:
+                add_cache_control(tools_params[-1])
+            # last 2 user messages
+            user_message_params = list(
+                filter(lambda m: m["role"] == "user", reversed(message_params))
+            )
+            for message in user_message_params[:2]:
+                if isinstance(message["content"], str):
+                    text_param = TextBlockParam(type="text", text=message["content"])
+                    add_cache_control(text_param)
+                    message["content"] = [text_param]
+                else:
+                    content = list(message["content"])
+                    add_cache_control(cast(dict[str, Any], content[-1]))
+
+        # return chat input
+        return system_param, tools_params, message_params, computer_use
+
+    def tool_params_for_tools(
+        self, tools: list[ToolInfo], config: GenerateConfig
+    ) -> tuple[list["ToolParamDef"], bool]:
+        # tool params and computer_use bit to return
+        tool_params: list["ToolParamDef"] = []
+        computer_use = False
+
+        # for each tool, check if it has a native computer use implementation and use that
+        # when available (noting that we need to set the computer use request header)
+        for tool in tools:
+            computer_use_tool = (
+                self.computer_use_tool_param(tool)
+                if config.internal_tools is not False
+                else None
+            )
+            if computer_use_tool:
+                tool_params.append(computer_use_tool)
+                computer_use = True
+            else:
+                tool_params.append(
+                    ToolParam(
+                        name=tool.name,
+                        description=tool.description,
+                        input_schema=tool.parameters.model_dump(exclude_none=True),
+                    )
+                )
+
+        return tool_params, computer_use
+
+    def computer_use_tool_param(
+        self, tool: ToolInfo
+    ) -> Optional["ComputerUseToolParam"]:
+        # check for compatible 'computer' tool
+        if tool.name == "computer" and (
+            sorted(tool.parameters.properties.keys())
+            == sorted(
+                [
+                    "action",
+                    "coordinate",
+                    "duration",
+                    "scroll_amount",
+                    "scroll_direction",
+                    "start_coordinate",
+                    "text",
+                ]
+            )
+        ):
+            if self.is_claude_3_5():
+                warn_once(
+                    logger,
+                    "Use of Anthropic's native computer use support is not enabled in Claude 3.5. Please use 3.7 or later to leverage the native support.",
+                )
+                return None
+            return ComputerUseToolParam(
+                type="computer_20250124",
+                name="computer",
+                # Note: The dimensions passed here for display_width_px and display_height_px should
+                # match the dimensions of screenshots returned by the tool.
+                # Those dimensions will always be one of the values in MAX_SCALING_TARGETS
+                # in _x11_client.py.
+                # TODO: enhance this code to calculate the dimensions based on the scaled screen
+                # size used by the container.
+                display_width_px=1366,
+                display_height_px=768,
+                display_number=1,
+            )
+        # not a computer_use tool
+        else:
+            return None
+
 
 # native anthropic tool definitions for computer use beta
 # https://docs.anthropic.com/en/docs/build-with-claude/computer-use
@@ -412,131 +555,6 @@ class ComputerUseToolParam(TypedDict):
 ToolParamDef = ToolParam | ComputerUseToolParam
 
 
-async def resolve_chat_input(
-    model: str,
-    input: list[ChatMessage],
-    tools: list[ToolInfo],
-    config: GenerateConfig,
-) -> Tuple[list[TextBlockParam] | None, list[ToolParamDef], list[MessageParam], bool]:
-    # extract system message
-    system_messages, messages = split_system_messages(input, config)
-
-    # messages
-    message_params = [(await message_param(message)) for message in messages]
-
-    # collapse user messages (as Inspect 'tool' messages become Claude 'user' messages)
-    message_params = functools.reduce(
-        consecutive_user_message_reducer, message_params, []
-    )
-
-    # tools
-    tools_params, computer_use = tool_params_for_tools(tools, config)
-
-    # system messages
-    if len(system_messages) > 0:
-        system_param: list[TextBlockParam] | None = [
-            TextBlockParam(type="text", text=message.text)
-            for message in system_messages
-        ]
-    else:
-        system_param = None
-
-    # add caching directives if necessary
-    cache_prompt = (
-        config.cache_prompt
-        if isinstance(config.cache_prompt, bool)
-        else True
-        if len(tools_params)
-        else False
-    )
-
-    # only certain claude models qualify
-    if cache_prompt:
-        if (
-            "claude-3-sonnet" in model
-            or "claude-2" in model
-            or "claude-instant" in model
-        ):
-            cache_prompt = False
-
-    if cache_prompt:
-        # system
-        if system_param:
-            add_cache_control(system_param[-1])
-        # tools
-        if tools_params:
-            add_cache_control(tools_params[-1])
-        # last 2 user messages
-        user_message_params = list(
-            filter(lambda m: m["role"] == "user", reversed(message_params))
-        )
-        for message in user_message_params[:2]:
-            if isinstance(message["content"], str):
-                text_param = TextBlockParam(type="text", text=message["content"])
-                add_cache_control(text_param)
-                message["content"] = [text_param]
-            else:
-                content = list(message["content"])
-                add_cache_control(cast(dict[str, Any], content[-1]))
-
-    # return chat input
-    return system_param, tools_params, message_params, computer_use
-
-
-def tool_params_for_tools(
-    tools: list[ToolInfo], config: GenerateConfig
-) -> tuple[list[ToolParamDef], bool]:
-    # tool params and computer_use bit to return
-    tool_params: list[ToolParamDef] = []
-    computer_use = False
-
-    # for each tool, check if it has a native computer use implementation and use that
-    # when available (noting that we need to set the computer use request header)
-    for tool in tools:
-        computer_use_tool = (
-            computer_use_tool_param(tool)
-            if config.internal_tools is not False
-            else None
-        )
-        if computer_use_tool:
-            tool_params.append(computer_use_tool)
-            computer_use = True
-        else:
-            tool_params.append(
-                ToolParam(
-                    name=tool.name,
-                    description=tool.description,
-                    input_schema=tool.parameters.model_dump(exclude_none=True),
-                )
-            )
-
-    return tool_params, computer_use
-
-
-def computer_use_tool_param(tool: ToolInfo) -> ComputerUseToolParam | None:
-    # check for compatible 'computer' tool
-    if tool.name == "computer" and (
-        sorted(tool.parameters.properties.keys())
-        == sorted(["action", "coordinate", "text"])
-    ):
-        return ComputerUseToolParam(
-            type="computer_20241022",
-            name="computer",
-            # Note: The dimensions passed here for display_width_px and display_height_px should
-            # match the dimensions of screenshots returned by the tool.
-            # Those dimensions will always be one of the values in MAX_SCALING_TARGETS
-            # in _x11_client.py.
-            # TODO: enhance this code to calculate the dimensions based on the scaled screen
-            # size used by the container.
-            display_width_px=1366,
-            display_height_px=768,
-            display_number=1,
-        )
-    # not a computer_use tool
-    else:
-        return None
-
-
 def add_cache_control(
     param: TextBlockParam | ToolParam | ComputerUseToolParam | dict[str, Any],
 ) -> None:
```