inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +14 -3
- inspect_ai/_cli/sandbox.py +3 -3
- inspect_ai/_cli/score.py +6 -4
- inspect_ai/_cli/trace.py +53 -6
- inspect_ai/_display/core/config.py +1 -1
- inspect_ai/_display/core/display.py +2 -1
- inspect_ai/_display/core/footer.py +6 -6
- inspect_ai/_display/plain/display.py +11 -6
- inspect_ai/_display/rich/display.py +23 -13
- inspect_ai/_display/textual/app.py +10 -9
- inspect_ai/_display/textual/display.py +2 -2
- inspect_ai/_display/textual/widgets/footer.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +14 -5
- inspect_ai/_eval/context.py +1 -2
- inspect_ai/_eval/eval.py +54 -41
- inspect_ai/_eval/loader.py +9 -2
- inspect_ai/_eval/run.py +148 -81
- inspect_ai/_eval/score.py +13 -8
- inspect_ai/_eval/task/images.py +31 -21
- inspect_ai/_eval/task/run.py +62 -59
- inspect_ai/_eval/task/rundir.py +16 -9
- inspect_ai/_eval/task/sandbox.py +7 -8
- inspect_ai/_eval/task/util.py +7 -0
- inspect_ai/_util/_async.py +118 -10
- inspect_ai/_util/constants.py +0 -2
- inspect_ai/_util/file.py +15 -29
- inspect_ai/_util/future.py +37 -0
- inspect_ai/_util/http.py +3 -99
- inspect_ai/_util/httpx.py +60 -0
- inspect_ai/_util/interrupt.py +2 -2
- inspect_ai/_util/json.py +5 -52
- inspect_ai/_util/logger.py +30 -86
- inspect_ai/_util/retry.py +10 -61
- inspect_ai/_util/trace.py +2 -2
- inspect_ai/_view/server.py +86 -3
- inspect_ai/_view/www/dist/assets/index.js +25837 -13269
- inspect_ai/_view/www/log-schema.json +253 -186
- inspect_ai/_view/www/package.json +2 -2
- inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
- inspect_ai/_view/www/src/types/log.d.ts +122 -94
- inspect_ai/approval/_human/manager.py +6 -10
- inspect_ai/approval/_human/panel.py +2 -2
- inspect_ai/dataset/_sources/util.py +7 -6
- inspect_ai/log/__init__.py +4 -0
- inspect_ai/log/_file.py +35 -61
- inspect_ai/log/_log.py +18 -1
- inspect_ai/log/_recorders/eval.py +14 -23
- inspect_ai/log/_recorders/json.py +3 -18
- inspect_ai/log/_samples.py +27 -2
- inspect_ai/log/_transcript.py +8 -8
- inspect_ai/model/__init__.py +2 -1
- inspect_ai/model/_call_tools.py +60 -40
- inspect_ai/model/_chat_message.py +3 -2
- inspect_ai/model/_generate_config.py +25 -0
- inspect_ai/model/_model.py +74 -36
- inspect_ai/model/_openai.py +9 -1
- inspect_ai/model/_providers/anthropic.py +24 -26
- inspect_ai/model/_providers/azureai.py +11 -9
- inspect_ai/model/_providers/bedrock.py +33 -24
- inspect_ai/model/_providers/cloudflare.py +8 -9
- inspect_ai/model/_providers/goodfire.py +7 -3
- inspect_ai/model/_providers/google.py +47 -13
- inspect_ai/model/_providers/groq.py +15 -15
- inspect_ai/model/_providers/hf.py +24 -17
- inspect_ai/model/_providers/mistral.py +36 -20
- inspect_ai/model/_providers/openai.py +30 -25
- inspect_ai/model/_providers/openai_o1.py +1 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +3 -4
- inspect_ai/model/_providers/util/__init__.py +2 -2
- inspect_ai/model/_providers/util/chatapi.py +6 -19
- inspect_ai/model/_providers/util/hooks.py +165 -0
- inspect_ai/model/_providers/vertex.py +20 -3
- inspect_ai/model/_providers/vllm.py +16 -19
- inspect_ai/scorer/_multi.py +5 -2
- inspect_ai/solver/_bridge/patch.py +31 -1
- inspect_ai/solver/_fork.py +5 -3
- inspect_ai/solver/_human_agent/agent.py +3 -2
- inspect_ai/tool/__init__.py +8 -2
- inspect_ai/tool/_tool_info.py +4 -90
- inspect_ai/tool/_tool_params.py +4 -34
- inspect_ai/tool/_tools/_web_search.py +30 -24
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_concurrency.py +5 -6
- inspect_ai/util/_display.py +6 -0
- inspect_ai/util/_json.py +170 -0
- inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
- inspect_ai/util/_sandbox/docker/docker.py +5 -0
- inspect_ai/util/_sandbox/environment.py +56 -9
- inspect_ai/util/_sandbox/service.py +12 -5
- inspect_ai/util/_subprocess.py +94 -113
- inspect_ai/util/_subtask.py +2 -4
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
- inspect_ai/_util/timeouts.py +0 -160
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
- inspect_ai/model/_providers/util/tracker.py +0 -92
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/model/_model.py
CHANGED
@@ -13,6 +13,7 @@ from typing import Any, AsyncIterator, Callable, Literal, Type, cast
 
 from pydantic_core import to_jsonable_python
 from tenacity import (
+    RetryCallState,
    retry,
    retry_if_exception,
    stop_after_attempt,
@@ -20,8 +21,9 @@ from tenacity import (
    stop_never,
    wait_exponential_jitter,
 )
+from tenacity.stop import StopBaseT
 
-from inspect_ai._util.constants import DEFAULT_MAX_CONNECTIONS
+from inspect_ai._util.constants import DEFAULT_MAX_CONNECTIONS, HTTP
 from inspect_ai._util.content import (
     Content,
     ContentImage,
@@ -30,6 +32,7 @@ from inspect_ai._util.content import (
 )
 from inspect_ai._util.hooks import init_hooks, override_api_key, send_telemetry
 from inspect_ai._util.interrupt import check_sample_interrupt
+from inspect_ai._util.logger import warn_once
 from inspect_ai._util.platform import platform_init
 from inspect_ai._util.registry import (
     RegistryInfo,
@@ -37,7 +40,7 @@ from inspect_ai._util.registry import (
     registry_info,
     registry_unqualified_name,
 )
-from inspect_ai._util.retry import log_rate_limit_retry
+from inspect_ai._util.retry import report_http_retry
 from inspect_ai._util.trace import trace_action
 from inspect_ai._util.working import report_sample_waiting_time, sample_working_time
 from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
@@ -173,11 +176,11 @@ class ModelAPI(abc.ABC):
        """Scope for enforcement of max_connections."""
        return "default"
 
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        """Is this exception a rate limit error?
+    def should_retry(self, ex: Exception) -> bool:
+        """Should this exception be retried?
 
        Args:
-           ex: Exception to check for rate limit
+           ex: Exception to check for retry
        """
        return False
 
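should_retry() is now the provider-facing extension point: subclasses classify their SDK's exceptions and the shared retry loop handles the backoff. A minimal sketch of an override (the MyAPI class and its use of builtin exception types are illustrative, not taken from this diff):

    from inspect_ai.model import ModelAPI


    class MyAPI(ModelAPI):
        # hypothetical provider: treat dropped connections and timeouts
        # as transient so the shared retry loop will re-attempt them
        def should_retry(self, ex: Exception) -> bool:
            return isinstance(ex, ConnectionError | TimeoutError)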
@@ -331,14 +334,17 @@ class Model:
        start_time = datetime.now()
        working_start = sample_working_time()
        async with self._connection_concurrency(config):
+            from inspect_ai.log._samples import track_active_sample_retries
+
            # generate
-            output = await self._generate(
-                input=input,
-                tools=tools,
-                tool_choice=tool_choice,
-                config=config,
-                cache=cache,
-            )
+            with track_active_sample_retries():
+                output = await self._generate(
+                    input=input,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    config=config,
+                    cache=cache,
+                )
 
            # update the most recent ModelEvent with the actual start/completed
            # times as well as a computation of working time (events are
@@ -418,27 +424,27 @@ class Model:
        if self.api.collapse_assistant_messages():
            input = collapse_consecutive_assistant_messages(input)
 
-        # retry for rate limit errors
+        # retry for transient http errors:
+        # - no default timeout or max_retries (try forever)
+        # - exponential backoff starting at 3 seconds (will wait 25 minutes
+        #   on the 10th retry,then will wait no longer than 30 minutes on
+        #   subsequent retries)
+        if config.max_retries is not None and config.timeout is not None:
+            stop: StopBaseT = stop_after_attempt(config.max_retries) | stop_after_delay(
+                config.timeout
+            )
+        elif config.max_retries is not None:
+            stop = stop_after_attempt(config.max_retries)
+        elif config.timeout is not None:
+            stop = stop_after_delay(config.timeout)
+        else:
+            stop = stop_never
+
        @retry(
-            wait=wait_exponential_jitter(max=(30 * 60), jitter=5),
-            retry=retry_if_exception(self.is_rate_limit),
-            stop=(
-                (
-                    stop_after_delay(config.timeout)
-                    | stop_after_attempt(config.max_retries)
-                )
-                if config.timeout and config.max_retries
-                else (
-                    stop_after_delay(config.timeout)
-                    if config.timeout
-                    else (
-                        stop_after_attempt(config.max_retries)
-                        if config.max_retries
-                        else stop_never
-                    )
-                )
-            ),
-            before_sleep=functools.partial(log_rate_limit_retry, self.api.model_name),
+            wait=wait_exponential_jitter(initial=3, max=(30 * 60), jitter=3),
+            retry=retry_if_exception(self.should_retry),
+            stop=stop,
+            before_sleep=functools.partial(log_model_retry, self.api.model_name),
        )
        async def generate() -> ModelOutput:
            check_sample_interrupt()
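tenacity stop conditions compose with |, halting as soon as either side fires, and the comment's arithmetic follows from initial=3 doubling each attempt (3 * 2**9 = 1536s, about 25 minutes, capped by max=30 * 60). A standalone sketch of the same combination (the function and the limits are illustrative):

    from tenacity import (
        retry,
        retry_if_exception_type,
        stop_after_attempt,
        stop_after_delay,
        wait_exponential_jitter,
    )


    # stop at 5 attempts or 120 seconds of retrying, whichever comes first
    @retry(
        wait=wait_exponential_jitter(initial=3, max=30 * 60, jitter=3),
        retry=retry_if_exception_type(ConnectionError),
        stop=stop_after_attempt(5) | stop_after_delay(120),
    )
    def flaky_call() -> None:
        raise ConnectionError("transient")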
@@ -555,6 +561,30 @@ class Model:
        # return results
        return model_output
 
+    def should_retry(self, ex: BaseException) -> bool:
+        if isinstance(ex, Exception):
+            # check standard should_retry() method
+            retry = self.api.should_retry(ex)
+            if retry:
+                report_http_retry()
+                return True
+
+            # see if the API implements legacy is_rate_limit() method
+            is_rate_limit = getattr(self.api, "is_rate_limit", None)
+            if is_rate_limit:
+                warn_once(
+                    logger,
+                    f"provider '{self.name}' implements deprecated is_rate_limit() method, "
+                    + "please change to should_retry()",
+                )
+                retry = cast(bool, is_rate_limit(ex))
+                if retry:
+                    report_http_retry()
+                    return True
+
+        # no retry
+        return False
+
    # function to verify that its okay to call model apis
    def verify_model_apis(self) -> None:
        if (
@@ -1064,6 +1094,7 @@ def tool_result_images_reducer(
            messages
            + [
                ChatMessageTool(
+                    id=message.id,
                    content=edited_tool_message_content,
                    tool_call_id=message.tool_call_id,
                    function=message.function,
@@ -1170,19 +1201,26 @@ def combine_messages(
    a: ChatMessage, b: ChatMessage, message_type: Type[ChatMessage]
 ) -> ChatMessage:
    if isinstance(a.content, str) and isinstance(b.content, str):
-        return message_type(content=f"{a.content}\n{b.content}")
+        return message_type(id=a.id, content=f"{a.content}\n{b.content}")
    elif isinstance(a.content, list) and isinstance(b.content, list):
-        return message_type(content=a.content + b.content)
+        return message_type(id=a.id, content=a.content + b.content)
    elif isinstance(a.content, str) and isinstance(b.content, list):
-        return message_type(content=[ContentText(text=a.content), *b.content])
+        return message_type(id=a.id, content=[ContentText(text=a.content), *b.content])
    elif isinstance(a.content, list) and isinstance(b.content, str):
-        return message_type(content=a.content + [ContentText(text=b.content)])
+        return message_type(id=a.id, content=a.content + [ContentText(text=b.content)])
    else:
        raise TypeError(
            f"Cannot combine messages with invalid content types: {a.content!r}, {b.content!r}"
        )
 
 
+def log_model_retry(model_name: str, retry_state: RetryCallState) -> None:
+    logger.log(
+        HTTP,
+        f"-> {model_name} retry {retry_state.attempt_number} after waiting for {retry_state.idle_for}",
+    )
+
+
 def init_active_model(model: Model, config: GenerateConfig) -> None:
    active_model_context_var.set(model)
    set_active_generate_config(config)
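log_model_retry() writes to inspect's custom HTTP log level using tenacity's RetryCallState, where attempt_number counts attempts and idle_for is the cumulative seconds slept. An illustrative log line (model name and timing invented, not captured output):

    -> openai/gpt-4o retry 3 after waiting for 21.7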
inspect_ai/model/_openai.py
CHANGED
@@ -52,7 +52,7 @@ from ._model_output import ModelUsage, StopReason, as_stop_reason
 
 
 def is_o_series(name: str) -> bool:
-    return bool(re.match(r"^o\d+", name))
+    return bool(re.match(r"(^|.*\/)o\d+", name))
 
 
 def is_o1_mini(name: str) -> bool:
@@ -396,6 +396,9 @@ def content_from_openai(
    content: ChatCompletionContentPartParam | ChatCompletionContentPartRefusalParam,
    parse_reasoning: bool = False,
 ) -> list[Content]:
+    # Some providers omit the type tag and use "object-with-a-single-field" encoding
+    if "type" not in content and len(content) == 1:
+        content["type"] = list(content.keys())[0]  # type: ignore[arg-type]
    if content["type"] == "text":
        text = content["text"]
        if parse_reasoning:
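The added normalization promotes a lone key to the missing type tag, so an untagged part dispatches through the same branches below. A runnable sketch of the transform on a plain dict (the payload is illustrative):

    # untagged single-field part, as some OpenAI-compatible providers emit
    part: dict[str, str] = {"reasoning": "chain of thought..."}

    # promote the lone key to the "type" tag, mirroring the check above
    if "type" not in part and len(part) == 1:
        part["type"] = list(part.keys())[0]

    assert part["type"] == "reasoning"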
@@ -413,6 +416,8 @@
            return [ContentText(text=text)]
        else:
            return [ContentText(text=text)]
+    elif content["type"] == "reasoning":  # type: ignore[comparison-overlap]
+        return [ContentReasoning(reasoning=content["reasoning"])]
    elif content["type"] == "image_url":
        return [
            ContentImage(
@@ -428,6 +433,9 @@
        ]
    elif content["type"] == "refusal":
        return [ContentText(text=content["refusal"])]
+    else:
+        content_type = content["type"]
+        raise ValueError(f"Unexpected content type '{content_type}' in message.")
 
 
 def chat_message_assistant_from_openai(
inspect_ai/model/_providers/anthropic.py
CHANGED
@@ -6,7 +6,12 @@ from copy import copy
 from logging import getLogger
 from typing import Any, Literal, Optional, Tuple, TypedDict, cast
 
-from .util.tracker import HttpTimeTracker
+import httpcore
+import httpx
+
+from inspect_ai._util.http import is_retryable_http_status
+
+from .util.hooks import HttpxHooks
 
 if sys.version_info >= (3, 11):
    from typing import NotRequired
@@ -16,13 +21,12 @@ else:
 from anthropic import (
    APIConnectionError,
    APIStatusError,
+    APITimeoutError,
    AsyncAnthropic,
    AsyncAnthropicBedrock,
    AsyncAnthropicVertex,
    BadRequestError,
-    InternalServerError,
    NotGiven,
-    RateLimitError,
 )
 from anthropic._types import Body
 from anthropic.types import (
@@ -46,7 +50,6 @@ from typing_extensions import override
 
 from inspect_ai._util.constants import (
    BASE_64_DATA_REMOVED,
-    DEFAULT_MAX_RETRIES,
    NO_CONTENT,
 )
 from inspect_ai._util.content import (
@@ -125,9 +128,6 @@ class AnthropicAPI(ModelAPI):
                AsyncAnthropic | AsyncAnthropicBedrock | AsyncAnthropicVertex
            ) = AsyncAnthropicBedrock(
                base_url=base_url,
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                aws_region=aws_region,
                **model_args,
            )
@@ -141,9 +141,6 @@ class AnthropicAPI(ModelAPI):
                region=region,
                project_id=project_id,
                base_url=base_url,
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                **model_args,
            )
        else:
@@ -156,14 +153,11 @@ class AnthropicAPI(ModelAPI):
            self.client = AsyncAnthropic(
                base_url=base_url,
                api_key=self.api_key,
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                **model_args,
            )
 
        # create time tracker
-        self._time_tracker = HttpTimeTracker(self.client._client)
+        self._http_hooks = HttpxHooks(self.client._client)
 
    @override
    async def close(self) -> None:
@@ -183,7 +177,7 @@ class AnthropicAPI(ModelAPI):
        config: GenerateConfig,
    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
        # allocate request_id (so we can see it from ModelCall)
-        request_id = self._time_tracker.start_request()
+        request_id = self._http_hooks.start_request()
 
        # setup request and response for ModelCall
        request: dict[str, Any] = {}
@@ -194,7 +188,7 @@ class AnthropicAPI(ModelAPI):
                request=request,
                response=response,
                filter=model_call_filter,
-                time=self._time_tracker.end_request(request_id),
+                time=self._http_hooks.end_request(request_id),
            )
 
        # generate
@@ -223,7 +217,7 @@ class AnthropicAPI(ModelAPI):
            request = request | req
 
            # extra headers (for time tracker and computer use)
-            extra_headers = headers | {HttpTimeTracker.REQUEST_ID_HEADER: request_id}
+            extra_headers = headers | {HttpxHooks.REQUEST_ID_HEADER: request_id}
            if computer_use:
                betas.append("computer-use-2025-01-24")
            if len(betas) > 0:
@@ -291,8 +285,6 @@ class AnthropicAPI(ModelAPI):
                betas.append("output-128k-2025-02-19")
 
        # config that applies to all models
-        if config.timeout is not None:
-            params["timeout"] = float(config.timeout)
        if config.stop_seqs is not None:
            params["stop_sequences"] = config.stop_seqs
 
@@ -334,13 +326,19 @@
        return str(self.api_key)
 
    @override
-    def is_rate_limit(self, ex: Exception) -> bool:
-        return isinstance(
-            ex,
-            APIConnectionError
-            | InternalServerError
-            | RateLimitError,
-        )
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, APIStatusError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(
+            ex,
+            APIConnectionError
+            | APITimeoutError
+            | httpx.RemoteProtocolError
+            | httpcore.RemoteProtocolError,
+        ):
+            return True
+        else:
+            return False
 
    @override
    def collapse_user_messages(self) -> bool:
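Status errors now defer to the shared is_retryable_http_status() helper, while connection-level failures, including protocol errors raised by httpx/httpcore underneath the Anthropic SDK, are always retried. An illustrative check, assuming api is a constructed AnthropicAPI instance:

    import httpx

    # a dropped connection is now classified as retryable
    assert api.should_retry(httpx.RemoteProtocolError("Server disconnected"))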
inspect_ai/model/_providers/azureai.py
CHANGED
@@ -27,11 +27,16 @@ from azure.ai.inference.models import (
    UserMessage,
 )
 from azure.core.credentials import AzureKeyCredential
-from azure.core.exceptions import AzureError, HttpResponseError
+from azure.core.exceptions import (
+    AzureError,
+    HttpResponseError,
+    ServiceResponseError,
+)
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
@@ -232,14 +237,11 @@ class AzureAIAPI(ModelAPI):
        return DEFAULT_MAX_TOKENS
 
    @override
-    def is_rate_limit(self, ex: Exception) -> bool:
-        if isinstance(ex, HttpResponseError):
-            return (
-                ex.status_code == 408
-                or ex.status_code == 409
-                or ex.status_code == 429
-                or ex.status_code == 500
-            )
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, HttpResponseError) and ex.status_code is not None:
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, ServiceResponseError):
+            return True
        else:
            return False
 
inspect_ai/model/_providers/bedrock.py
CHANGED
@@ -1,16 +1,14 @@
 import base64
+from logging import getLogger
 from typing import Any, Literal, Tuple, Union, cast
 
 from pydantic import BaseModel, Field
 from typing_extensions import override
 
-from inspect_ai._util.constants import (
-    DEFAULT_MAX_RETRIES,
-    DEFAULT_MAX_TOKENS,
-    DEFAULT_TIMEOUT,
-)
+from inspect_ai._util._async import current_async_backend
+from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.error import pip_dependency_error
+from inspect_ai._util.error import PrerequisiteError, pip_dependency_error
 from inspect_ai._util.images import file_as_data
 from inspect_ai._util.version import verify_required_version
 from inspect_ai.tool import ToolChoice, ToolInfo
@@ -31,7 +29,9 @@ from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .util import (
    model_base_url,
 )
-from .util.tracker import ConverseTimeTracker
+from .util.hooks import ConverseHooks
+
+logger = getLogger(__name__)
 
 # Model for Bedrock Converse API (Response)
 # generated from: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html#converse
@@ -245,6 +245,12 @@ class BedrockAPI(ModelAPI):
            config=config,
        )
 
+        # raise if we are using trio
+        if current_async_backend() == "trio":
+            raise PrerequisiteError(
+                "ERROR: The bedrock provider does not work with the trio async backend."
+            )
+
        # save model_args
        self.model_args = model_args
 
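The guard exists because aioboto3 only runs on asyncio; current_async_backend() (added to _util/_async.py in this release) reports the event loop the caller is under. A sketch of the assumed behavior via anyio (return values inferred from the comparison above):

    import anyio

    from inspect_ai._util._async import current_async_backend


    async def report() -> None:
        # assumed to return "asyncio" or "trio" based on the running loop
        print(current_async_backend())


    anyio.run(report)                  # prints "asyncio"
    anyio.run(report, backend="trio")  # prints "trio" (bedrock refuses this)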
@@ -258,7 +264,7 @@
            self.session = aioboto3.Session()
 
            # create time tracker
-            self._time_tracker = ConverseTimeTracker(self.session)
+            self._http_hooks = ConverseHooks(self.session)
 
        except ImportError:
            raise pip_dependency_error("Bedrock API", ["aioboto3"])
@@ -288,15 +294,25 @@
        return DEFAULT_MAX_TOKENS
 
    @override
-    def is_rate_limit(self, ex: Exception) -> bool:
+    def should_retry(self, ex: Exception) -> bool:
        from botocore.exceptions import ClientError
 
        # Look for an explicit throttle exception
        if isinstance(ex, ClientError):
-            error_code = ex.response.get("Error", {}).get("Code", "")
-            return error_code == "ThrottlingException"
-
-        return False
+            error_code = ex.response.get("Error", {}).get("Code", "")
+            return error_code in [
+                "ThrottlingException",
+                "RequestLimitExceeded",
+                "Throttling",
+                "RequestThrottled",
+                "TooManyRequestsException",
+                "ProvisionedThroughputExceededException",
+                "TransactionInProgressException",
+                "RequestTimeout",
+                "ServiceUnavailable",
+            ]
+        else:
+            return False
 
    @override
    def collapse_user_messages(self) -> bool:
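The throttle check reads the service error code out of botocore's ClientError payload. A runnable sketch of the shape being matched (operation name and message are illustrative):

    from botocore.exceptions import ClientError

    err = ClientError(
        error_response={"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}},
        operation_name="Converse",
    )
    # the provider matches this code against its retryable list
    assert err.response.get("Error", {}).get("Code", "") == "ThrottlingException"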
@@ -317,20 +333,13 @@
        from botocore.exceptions import ClientError
 
        # The bedrock client
-        request_id = self._time_tracker.start_request()
+        request_id = self._http_hooks.start_request()
        async with self.session.client(  # type: ignore[call-overload]
            service_name="bedrock-runtime",
            endpoint_url=self.base_url,
            config=Config(
-                connect_timeout=config.timeout if config.timeout else DEFAULT_TIMEOUT,
-                read_timeout=config.timeout if config.timeout else DEFAULT_TIMEOUT,
-                retries=dict(
-                    max_attempts=config.max_retries
-                    if config.max_retries
-                    else DEFAULT_MAX_RETRIES,
-                    mode="adaptive",
-                ),
-                user_agent_extra=self._time_tracker.user_agent_extra(request_id),
+                retries=dict(mode="adaptive"),
+                user_agent_extra=self._http_hooks.user_agent_extra(request_id),
            ),
            **self.model_args,
        ) as client:
@@ -370,7 +379,7 @@
                    request.model_dump(exclude_none=True)
                ),
                response=response,
-                time=self._time_tracker.end_request(request_id),
+                time=self._http_hooks.end_request(request_id),
            )
 
            try:
inspect_ai/model/_providers/cloudflare.py
CHANGED
@@ -16,10 +16,10 @@ from .util import (
    chat_api_input,
    chat_api_request,
    environment_prerequisite_error,
-    is_chat_api_rate_limit,
    model_base_url,
+    should_retry_chat_api_error,
 )
-from .util.tracker import HttpTimeTracker
+from .util.hooks import HttpxHooks
 
 # https://developers.cloudflare.com/workers-ai/models/#text-generation
 
@@ -51,7 +51,7 @@ class CloudFlareAPI(ModelAPI):
        if not self.api_key:
            raise environment_prerequisite_error("CloudFlare", CLOUDFLARE_API_TOKEN)
        self.client = httpx.AsyncClient()
-        self._time_tracker = HttpTimeTracker(self.client)
+        self._http_hooks = HttpxHooks(self.client)
        base_url = model_base_url(base_url, "CLOUDFLARE_BASE_URL")
        self.base_url = (
            base_url if base_url else "https://api.cloudflare.com/client/v4/accounts"
@@ -79,7 +79,7 @@
        json["messages"] = chat_api_input(input, tools, self.chat_api_handler())
 
        # request_id
-        request_id = self._time_tracker.start_request()
+        request_id = self._http_hooks.start_request()
 
        # setup response
        response: dict[str, Any] = {}
@@ -88,7 +88,7 @@
            return ModelCall.create(
                request=json,
                response=response,
-                time=self._time_tracker.end_request(request_id),
+                time=self._http_hooks.end_request(request_id),
            )
 
        # make the call
@@ -98,10 +98,9 @@
            url=f"{chat_url}/{self.model_name}",
            headers={
                "Authorization": f"Bearer {self.api_key}",
-                HttpTimeTracker.REQUEST_ID_HEADER: request_id,
+                HttpxHooks.REQUEST_ID_HEADER: request_id,
            },
            json=json,
-            config=config,
        )
 
        # handle response
@@ -127,8 +126,8 @@
            raise RuntimeError(f"Error calling {self.model_name}: {error}")
 
    @override
-    def is_rate_limit(self, ex: Exception) -> bool:
-        return is_chat_api_rate_limit(ex)
+    def should_retry(self, ex: Exception) -> bool:
+        return should_retry_chat_api_error(ex)
 
    # cloudflare enforces rate limits by model for each account
    @override
inspect_ai/model/_providers/goodfire.py
CHANGED
@@ -3,7 +3,11 @@ from typing import Any, List, Literal, get_args
 
 from goodfire import AsyncClient
 from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
-from goodfire.api.exceptions import InvalidRequestException, RateLimitException
+from goodfire.api.exceptions import (
+    InvalidRequestException,
+    RateLimitException,
+    ServerErrorException,
+)
 from goodfire.variants.variants import SUPPORTED_MODELS, Variant
 from typing_extensions import override
 
@@ -163,9 +167,9 @@ class GoodfireAPI(ModelAPI):
        return ex
 
    @override
-    def is_rate_limit(self, ex: Exception) -> bool:
+    def should_retry(self, ex: Exception) -> bool:
        """Check if exception is due to rate limiting."""
-        return isinstance(ex, RateLimitException)
+        return isinstance(ex, RateLimitException | ServerErrorException)
 
    @override
    def connection_key(self) -> str:
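Across providers the pattern is a mechanical rename: is_rate_limit() is still honored via the getattr() fallback in Model.should_retry() (with a deprecation warning), and overrides typically widen from rate limits alone to all transient errors. A migration sketch for a third-party ModelAPI subclass (MyRateLimitError is hypothetical):

    # before: deprecated hook, still detected at runtime with a warning
    def is_rate_limit(self, ex: BaseException) -> bool:
        return isinstance(ex, MyRateLimitError)

    # after: same predicate under the new hook name
    def should_retry(self, ex: Exception) -> bool:
        return isinstance(ex, MyRateLimitError)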