inspect-ai 0.3.71__py3-none-any.whl → 0.3.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +14 -3
- inspect_ai/_cli/sandbox.py +3 -3
- inspect_ai/_cli/score.py +6 -4
- inspect_ai/_cli/trace.py +53 -6
- inspect_ai/_display/core/config.py +1 -1
- inspect_ai/_display/core/display.py +2 -1
- inspect_ai/_display/core/footer.py +6 -6
- inspect_ai/_display/plain/display.py +11 -6
- inspect_ai/_display/rich/display.py +23 -13
- inspect_ai/_display/textual/app.py +10 -9
- inspect_ai/_display/textual/display.py +2 -2
- inspect_ai/_display/textual/widgets/footer.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +14 -5
- inspect_ai/_eval/context.py +1 -2
- inspect_ai/_eval/eval.py +54 -41
- inspect_ai/_eval/loader.py +9 -2
- inspect_ai/_eval/run.py +148 -81
- inspect_ai/_eval/score.py +13 -8
- inspect_ai/_eval/task/images.py +31 -21
- inspect_ai/_eval/task/run.py +62 -59
- inspect_ai/_eval/task/rundir.py +16 -9
- inspect_ai/_eval/task/sandbox.py +7 -8
- inspect_ai/_eval/task/util.py +7 -0
- inspect_ai/_util/_async.py +118 -10
- inspect_ai/_util/constants.py +0 -2
- inspect_ai/_util/file.py +15 -29
- inspect_ai/_util/future.py +37 -0
- inspect_ai/_util/http.py +3 -99
- inspect_ai/_util/httpx.py +60 -0
- inspect_ai/_util/interrupt.py +2 -2
- inspect_ai/_util/json.py +5 -52
- inspect_ai/_util/logger.py +30 -86
- inspect_ai/_util/retry.py +10 -61
- inspect_ai/_util/trace.py +2 -2
- inspect_ai/_view/server.py +86 -3
- inspect_ai/_view/www/dist/assets/index.js +25837 -13269
- inspect_ai/_view/www/log-schema.json +253 -186
- inspect_ai/_view/www/package.json +2 -2
- inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
- inspect_ai/_view/www/src/types/log.d.ts +122 -94
- inspect_ai/approval/_human/manager.py +6 -10
- inspect_ai/approval/_human/panel.py +2 -2
- inspect_ai/dataset/_sources/util.py +7 -6
- inspect_ai/log/__init__.py +4 -0
- inspect_ai/log/_file.py +35 -61
- inspect_ai/log/_log.py +18 -1
- inspect_ai/log/_recorders/eval.py +14 -23
- inspect_ai/log/_recorders/json.py +3 -18
- inspect_ai/log/_samples.py +27 -2
- inspect_ai/log/_transcript.py +8 -8
- inspect_ai/model/__init__.py +2 -1
- inspect_ai/model/_call_tools.py +60 -40
- inspect_ai/model/_chat_message.py +3 -2
- inspect_ai/model/_generate_config.py +25 -0
- inspect_ai/model/_model.py +74 -36
- inspect_ai/model/_openai.py +9 -1
- inspect_ai/model/_providers/anthropic.py +172 -154
- inspect_ai/model/_providers/azureai.py +11 -9
- inspect_ai/model/_providers/bedrock.py +33 -24
- inspect_ai/model/_providers/cloudflare.py +8 -9
- inspect_ai/model/_providers/goodfire.py +7 -3
- inspect_ai/model/_providers/google.py +47 -13
- inspect_ai/model/_providers/groq.py +15 -15
- inspect_ai/model/_providers/hf.py +24 -17
- inspect_ai/model/_providers/mistral.py +36 -20
- inspect_ai/model/_providers/openai.py +30 -25
- inspect_ai/model/_providers/openai_o1.py +1 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +3 -4
- inspect_ai/model/_providers/util/__init__.py +2 -2
- inspect_ai/model/_providers/util/chatapi.py +6 -19
- inspect_ai/model/_providers/util/hooks.py +165 -0
- inspect_ai/model/_providers/vertex.py +20 -3
- inspect_ai/model/_providers/vllm.py +16 -19
- inspect_ai/scorer/_multi.py +5 -2
- inspect_ai/solver/_bridge/patch.py +31 -1
- inspect_ai/solver/_fork.py +5 -3
- inspect_ai/solver/_human_agent/agent.py +3 -2
- inspect_ai/tool/__init__.py +8 -2
- inspect_ai/tool/_tool_info.py +4 -90
- inspect_ai/tool/_tool_params.py +4 -34
- inspect_ai/tool/_tools/_computer/_common.py +117 -58
- inspect_ai/tool/_tools/_computer/_computer.py +80 -57
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +7 -1
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +91 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +8 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +12 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +78 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +20 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +175 -113
- inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +76 -20
- inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +65 -0
- inspect_ai/tool/_tools/_computer/test_args.py +151 -0
- inspect_ai/tool/_tools/_web_search.py +30 -24
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_concurrency.py +5 -6
- inspect_ai/util/_display.py +6 -0
- inspect_ai/util/_json.py +170 -0
- inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
- inspect_ai/util/_sandbox/docker/docker.py +5 -0
- inspect_ai/util/_sandbox/environment.py +56 -9
- inspect_ai/util/_sandbox/service.py +12 -5
- inspect_ai/util/_subprocess.py +94 -113
- inspect_ai/util/_subtask.py +2 -4
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +111 -103
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
- inspect_ai/_util/timeouts.py +0 -160
- inspect_ai/model/_providers/util/tracker.py +0 -92
- inspect_ai/tool/_tools/_computer/_computer_split.py +0 -198
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/mistral.py
@@ -7,6 +7,7 @@ from httpcore import ReadTimeout
 from httpx import ReadTimeout as AsyncReadTimeout
 from mistralai import (
     ContentChunk,
+    DocumentURLChunk,
     FunctionCall,
     FunctionName,
     ImageURL,
@@ -22,6 +23,12 @@ from mistralai.models import (
     ChatCompletionChoice as MistralChatCompletionChoice,
 )
 from mistralai.models import Function as MistralFunction
+from mistralai.models import (
+    JSONSchema as MistralJSONSchema,
+)
+from mistralai.models import (
+    ResponseFormat as MistralResponseFormat,
+)
 from mistralai.models import SDKError
 from mistralai.models import SystemMessage as MistralSystemMessage
 from mistralai.models import Tool as MistralTool
@@ -38,11 +45,9 @@ from typing_extensions import override

 # TODO: Migration guide:
 # https://github.com/mistralai/client-python/blob/main/MIGRATION.md
-from inspect_ai._util.constants import (
-    DEFAULT_TIMEOUT,
-    NO_CONTENT,
-)
+from inspect_ai._util.constants import NO_CONTENT
 from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

@@ -61,7 +66,7 @@ from .._model_output import (
     StopReason,
 )
 from .util import environment_prerequisite_error, model_base_url
-from .util.
+from .util.hooks import HttpxHooks

 AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
 AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"
@@ -127,16 +132,12 @@ class MistralAPI(ModelAPI):
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # create client
-        with Mistral(
-            api_key=self.api_key,
-            timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT) * 1000,
-            **self.model_args,
-        ) as client:
+        with Mistral(api_key=self.api_key, **self.model_args) as client:
             # create time tracker
-
+            http_hooks = HttpxHooks(client.sdk_configuration.async_client)

             # build request
-            request_id =
+            request_id = http_hooks.start_request()
             request: dict[str, Any] = dict(
                 model=self.model_name,
                 messages=await mistral_chat_messages(input),
@@ -144,7 +145,7 @@ class MistralAPI(ModelAPI):
                 tool_choice=(
                     mistral_chat_tool_choice(tool_choice) if len(tools) > 0 else None
                 ),
-                http_headers={
+                http_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
             )
             if config.temperature is not None:
                 request["temperature"] = config.temperature
@@ -154,6 +155,18 @@ class MistralAPI(ModelAPI):
                 request["max_tokens"] = config.max_tokens
             if config.seed is not None:
                 request["random_seed"] = config.seed
+            if config.response_schema is not None:
+                request["response_format"] = MistralResponseFormat(
+                    type="json_schema",
+                    json_schema=MistralJSONSchema(
+                        name=config.response_schema.name,
+                        description=config.response_schema.description,
+                        schema_definition=config.response_schema.json_schema.model_dump(
+                            exclude_none=True
+                        ),
+                        strict=config.response_schema.strict,
+                    ),
+                )

             # prepare response for inclusion in model call
             response: dict[str, Any] = {}
@@ -169,7 +182,7 @@ class MistralAPI(ModelAPI):
             return ModelCall.create(
                 request=req,
                 response=response,
-                time=
+                time=http_hooks.end_request(request_id),
             )

             # send request
@@ -205,12 +218,13 @@ class MistralAPI(ModelAPI):
             ), model_call()

     @override
-    def
-
-
-
-
-
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, SDKError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, ReadTimeout | AsyncReadTimeout):
+            return True
+        else:
+            return False

     @override
     def connection_key(self) -> str:
@@ -462,6 +476,8 @@ def completion_content_chunk(content: ContentChunk) -> Content:
         raise TypeError("ReferenceChunk content is not supported by Inspect.")
     elif isinstance(content, TextChunk):
         return ContentText(text=content.text)
+    elif isinstance(content, DocumentURLChunk):
+        return ContentText(text=content.document_url)
     else:
         if isinstance(content.image_url, str):
             return ContentImage(image=content.image_url)
inspect_ai/model/_providers/openai.py
@@ -7,25 +7,22 @@ import httpx
 from openai import (
     DEFAULT_CONNECTION_LIMITS,
     DEFAULT_TIMEOUT,
-
+    APIStatusError,
     APITimeoutError,
     AsyncAzureOpenAI,
     AsyncOpenAI,
     BadRequestError,
-    InternalServerError,
     RateLimitError,
 )
 from openai._types import NOT_GIVEN
-from openai.types.chat import (
-    ChatCompletion,
-)
+from openai.types.chat import ChatCompletion
 from typing_extensions import override

-from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.error import PrerequisiteError
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
-from inspect_ai.model._providers.util.
+from inspect_ai.model._providers.util.hooks import HttpxHooks
 from inspect_ai.tool import ToolChoice, ToolInfo

 from .._chat_message import ChatMessage
@@ -130,9 +127,6 @@ class OpenAIAPI(ModelAPI):
                 api_key=self.api_key,
                 azure_endpoint=base_url,
                 azure_deployment=model_name,
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 http_client=http_client,
                 **model_args,
             )
@@ -140,15 +134,12 @@ class OpenAIAPI(ModelAPI):
             self.client = AsyncOpenAI(
                 api_key=self.api_key,
                 base_url=model_base_url(base_url, "OPENAI_BASE_URL"),
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 http_client=http_client,
                 **model_args,
             )

         # create time tracker
-        self.
+        self._http_hooks = HttpxHooks(self.client._client)

     def is_azure(self) -> bool:
         return self.service == "azure"
@@ -186,7 +177,7 @@ class OpenAIAPI(ModelAPI):
             )

         # allocate request_id (so we can see it from ModelCall)
-        request_id = self.
+        request_id = self._http_hooks.start_request()

         # setup request and response for ModelCall
         request: dict[str, Any] = {}
@@ -197,7 +188,7 @@ class OpenAIAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=image_url_filter,
-                time=self.
+                time=self._http_hooks.end_request(request_id),
             )

         # unlike text models, vision models require a max_tokens (and set it to a very low
@@ -216,7 +207,7 @@ class OpenAIAPI(ModelAPI):
                 tool_choice=openai_chat_tool_choice(tool_choice)
                 if len(tools) > 0
                 else NOT_GIVEN,
-                extra_headers={
+                extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
                 **self.completion_params(config, len(tools) > 0),
             )

@@ -266,17 +257,21 @@ class OpenAIAPI(ModelAPI):
         return chat_choices_from_openai(response, tools)

     @override
-    def
+    def should_retry(self, ex: Exception) -> bool:
         if isinstance(ex, RateLimitError):
             # Do not retry on these rate limit errors
             # The quota exceeded one is related to monthly account quotas.
-            if "You exceeded your current quota"
+            if "You exceeded your current quota" in ex.message:
+                warn_once(logger, f"OpenAI quota exceeded, not retrying: {ex.message}")
+                return False
+            else:
                 return True
-        elif isinstance(
-
-        ):
+        elif isinstance(ex, APIStatusError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, APITimeoutError):
             return True
-
+        else:
+            return False

     @override
     def connection_key(self) -> str:
@@ -315,8 +310,6 @@ class OpenAIAPI(ModelAPI):
             params["temperature"] = 1
         if config.top_p is not None:
             params["top_p"] = config.top_p
-        if config.timeout is not None:
-            params["timeout"] = float(config.timeout)
         if config.num_choices is not None:
             params["n"] = config.num_choices
         if config.logprobs is not None:
@@ -331,6 +324,18 @@ class OpenAIAPI(ModelAPI):
             and not self.is_o1_mini()
         ):
             params["reasoning_effort"] = config.reasoning_effort
+        if config.response_schema is not None:
+            params["response_format"] = dict(
+                type="json_schema",
+                json_schema=dict(
+                    name=config.response_schema.name,
+                    schema=config.response_schema.json_schema.model_dump(
+                        exclude_none=True
+                    ),
+                    description=config.response_schema.description,
+                    strict=config.response_schema.strict,
+                ),
+            )

         return params

inspect_ai/model/_providers/openai_o1.py
@@ -107,7 +107,7 @@ def chat_messages(
 ) -> list[ChatCompletionMessageParam]:
     # o1 does not allow system messages so convert system -> user
     messages: list[ChatMessage] = [
-        ChatMessageUser(content=message.content)
+        ChatMessageUser(id=message.id, content=message.content)
         if message.role == "system"
         else message
         for message in input
inspect_ai/model/_providers/together.py
@@ -34,8 +34,8 @@ from .util import (
     chat_api_input,
     chat_api_request,
     environment_prerequisite_error,
-    is_chat_api_rate_limit,
     model_base_url,
+    should_retry_chat_api_error,
 )


@@ -186,7 +186,6 @@ class TogetherRESTAPI(ModelAPI):
             url=f"{chat_url}",
             headers={"Authorization": f"Bearer {self.api_key}"},
             json=json,
-            config=config,
         )

         if "error" in response:
@@ -215,8 +214,8 @@ class TogetherRESTAPI(ModelAPI):
         return ModelOutput(model=model, choices=choices, usage=usage)

     @override
-    def
-        return
+    def should_retry(self, ex: Exception) -> bool:
+        return should_retry_chat_api_error(ex)

     # cloudflare enforces rate limits by model for each account
     @override
inspect_ai/model/_providers/util/__init__.py
@@ -5,7 +5,7 @@ from .chatapi import (
     ChatAPIMessage,
     chat_api_input,
     chat_api_request,
-
+    should_retry_chat_api_error,
 )
 from .hf_handler import HFHandler
 from .llama31 import Llama31Handler
@@ -19,7 +19,7 @@ __all__ = [
     "as_stop_reason",
     "chat_api_request",
     "chat_api_input",
-    "
+    "should_retry_chat_api_error",
     "model_base_url",
     "parse_tool_call",
     "tool_parse_error_message",
inspect_ai/model/_providers/util/chatapi.py
@@ -7,17 +7,15 @@ from tenacity import (
     retry,
     retry_if_exception,
     stop_after_attempt,
-    stop_after_delay,
     wait_exponential_jitter,
 )

-from inspect_ai._util.
-from inspect_ai._util.
+from inspect_ai._util.http import is_retryable_http_status
+from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
 from inspect_ai.model._chat_message import ChatMessageAssistant, ChatMessageTool
 from inspect_ai.tool._tool_info import ToolInfo

 from ..._chat_message import ChatMessage
-from ..._generate_config import GenerateConfig

 logger = getLogger(__name__)

@@ -75,21 +73,13 @@ async def chat_api_request(
     url: str,
     headers: dict[str, Any],
     json: Any,
-    config: GenerateConfig,
 ) -> Any:
-    # provide default max_retries
-    max_retries = config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-
     # define call w/ retry policy
     @retry(
         wait=wait_exponential_jitter(),
-        stop=(
-            (stop_after_attempt(max_retries) | stop_after_delay(config.timeout))
-            if config.timeout
-            else stop_after_attempt(max_retries)
-        ),
+        stop=(stop_after_attempt(2)),
         retry=retry_if_exception(httpx_should_retry),
-        before_sleep=
+        before_sleep=log_httpx_retry_attempt(model_name),
     )
     async def call_api() -> Any:
         response = await client.post(url=url, headers=headers, json=json)
@@ -104,14 +94,11 @@ async def chat_api_request(
 # checking for rate limit errors needs to punch through the RetryError and
 # look at its `__cause__`. we've observed Cloudflare giving transient 500
 # status as well as a ReadTimeout, so we count these as rate limit errors
-def
+def should_retry_chat_api_error(ex: BaseException) -> bool:
     return isinstance(ex, RetryError) and (
         (
             isinstance(ex.__cause__, httpx.HTTPStatusError)
-            and (
-                ex.__cause__.response.status_code == 429
-                or ex.__cause__.response.status_code == 500
-            )
+            and is_retryable_http_status(ex.__cause__.response.status_code)
         )
         or isinstance(ex.__cause__, httpx.ReadTimeout)
     )
inspect_ai/model/_providers/util/hooks.py
@@ -0,0 +1,165 @@
+import re
+import time
+from logging import getLogger
+from typing import Any, Mapping, NamedTuple, cast
+
+import httpx
+from shortuuid import uuid
+
+from inspect_ai._util.constants import HTTP
+from inspect_ai._util.retry import report_http_retry
+
+logger = getLogger(__name__)
+
+
+class RequestInfo(NamedTuple):
+    attempts: int
+    last_request: float
+
+
+class HttpHooks:
+    """Class which hooks various HTTP clients for improved tracking/logging.
+
+    A special header is injected into requests which is then read from
+    a request event hook -- this creates a record of when the request
+    started. Note that with retries a single request_id could be started
+    several times; our request hook makes sure we always track the time of
+    the last request.
+
+    There is an 'end_request()' method which gets the total request time
+    for a request_id and then purges the request_id from our tracking (so
+    the dict doesn't grow unbounded)
+
+    Additionally, an http response hook is installed and used for logging
+    requests for the 'http' log-level
+    """
+
+    REQUEST_ID_HEADER = "x-irid"
+
+    def __init__(self) -> None:
+        # track request start times
+        self._requests: dict[str, RequestInfo] = {}
+
+    def start_request(self) -> str:
+        request_id = uuid()
+        self._requests[request_id] = RequestInfo(0, time.monotonic())
+        return request_id
+
+    def end_request(self, request_id: str) -> float:
+        # read the request info (if available) and purge from dict
+        request_info = self._requests.pop(request_id, None)
+        if request_info is None:
+            raise RuntimeError(f"request_id not registered: {request_id}")
+
+        # return elapsed time
+        return time.monotonic() - request_info.last_request
+
+    def update_request_time(self, request_id: str) -> None:
+        request_info = self._requests.get(request_id, None)
+        if not request_info:
+            raise RuntimeError(f"No request registered for request_id: {request_id}")
+
+        # update the attempts and last request time
+        request_info = RequestInfo(request_info.attempts + 1, time.monotonic())
+        self._requests[request_id] = request_info
+
+        # trace a retry if this is attempt > 1
+        if request_info.attempts > 1:
+            report_http_retry()
+
+
+class ConverseHooks(HttpHooks):
+    def __init__(self, session: Any) -> None:
+        from aiobotocore.session import AioSession
+
+        super().__init__()
+
+        # register hooks
+        session = cast(AioSession, session._session)
+        session.register(
+            "before-send.bedrock-runtime.Converse", self.converse_before_send
+        )
+        session.register(
+            "after-call.bedrock-runtime.Converse", self.converse_after_call
+        )
+
+    def converse_before_send(self, **kwargs: Any) -> None:
+        user_agent = kwargs["request"].headers["User-Agent"].decode()
+        match = re.search(rf"{self.USER_AGENT_PREFIX}(\w+)", user_agent)
+        if match:
+            request_id = match.group(1)
+            self.update_request_time(request_id)
+
+    def converse_after_call(self, http_response: Any, **kwargs: Any) -> None:
+        from botocore.awsrequest import AWSResponse
+
+        response = cast(AWSResponse, http_response)
+        logger.log(HTTP, f"POST {response.url} - {response.status_code}")
+
+    def user_agent_extra(self, request_id: str) -> str:
+        return f"{self.USER_AGENT_PREFIX}{request_id}"
+
+    USER_AGENT_PREFIX = "ins/rid#"
+
+
+class HttpxHooks(HttpHooks):
+    def __init__(self, client: httpx.AsyncClient):
+        super().__init__()
+
+        # install hooks
+        client.event_hooks["request"].append(self.request_hook)
+        client.event_hooks["response"].append(self.response_hook)
+
+    async def request_hook(self, request: httpx.Request) -> None:
+        # update the last request time for this request id (as there could be retries)
+        request_id = request.headers.get(self.REQUEST_ID_HEADER, None)
+        if request_id:
+            self.update_request_time(request_id)
+
+    async def response_hook(self, response: httpx.Response) -> None:
+        message = f'{response.request.method} {response.request.url} "{response.http_version} {response.status_code} {response.reason_phrase}" '
+        logger.log(HTTP, message)
+
+
+def urllib3_hooks() -> HttpHooks:
+    import urllib3
+    from urllib3.connectionpool import HTTPConnectionPool
+    from urllib3.response import BaseHTTPResponse
+
+    class Urllib3Hooks(HttpHooks):
+        def request_hook(self, headers: Mapping[str, str]) -> None:
+            # update the last request time for this request id (as there could be retries)
+            request_id = headers.get(self.REQUEST_ID_HEADER, None)
+            if request_id:
+                self.update_request_time(request_id)
+
+        def response_hook(
+            self, method: str, url: str, response: BaseHTTPResponse
+        ) -> None:
+            message = f'{method} {url} "{response.version_string} {response.status} {response.reason}" '
+            logger.log(HTTP, message)
+
+    global _urlilb3_hooks
+    if _urlilb3_hooks is None:
+        # one time patch of urlopen
+        urlilb3_hooks = Urllib3Hooks()
+        original_urlopen = urllib3.connectionpool.HTTPConnectionPool.urlopen
+
+        def patched_urlopen(
+            self: HTTPConnectionPool, method: str, url: str, **kwargs: Any
+        ) -> BaseHTTPResponse:
+            headers = kwargs.get("headers", {})
+            urlilb3_hooks.request_hook(headers)
+            response = original_urlopen(self, method, url, **kwargs)
+            urlilb3_hooks.response_hook(method, f"{self.host}{url}", response)
+            return response
+
+        urllib3.connectionpool.HTTPConnectionPool.urlopen = patched_urlopen  # type: ignore[assignment,method-assign]
+
+        # assign to global hooks instance
+        _urlilb3_hooks = urlilb3_hooks
+
+    return _urlilb3_hooks
+
+
+_urlilb3_hooks: HttpHooks | None = None
inspect_ai/model/_providers/vertex.py
@@ -4,7 +4,13 @@ from copy import copy
 from typing import Any, cast

 import vertexai  # type: ignore
-from google.api_core.exceptions import
+from google.api_core.exceptions import (
+    Aborted,
+    ClientError,
+    DeadlineExceeded,
+    ServiceUnavailable,
+)
+from google.api_core.retry import if_transient_error
 from google.protobuf.json_format import MessageToDict
 from pydantic import JsonValue
 from typing_extensions import override
@@ -31,6 +37,7 @@ from inspect_ai._util.content import (
     ContentText,
     ContentVideo,
 )
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo

@@ -169,8 +176,18 @@ class VertexAPI(ModelAPI):
         return output, call

     @override
-    def
-
+    def should_retry(self, ex: Exception) -> bool:
+        # google API-specific errors
+        if isinstance(ex, Aborted | DeadlineExceeded | ServiceUnavailable):
+            return True
+        # standard HTTP errors
+        elif isinstance(ex, ClientError) and ex.code is not None:
+            return is_retryable_http_status(ex.code)
+        # additional errors flagged by google as transient
+        elif isinstance(ex, Exception):
+            return if_transient_error(ex)
+        else:
+            return False

     @override
     def connection_key(self) -> str:
inspect_ai/model/_providers/vllm.py
@@ -1,13 +1,15 @@
-import
+import concurrent.futures
 import functools
 import gc
 import os
 import time
+from concurrent.futures import Future
 from dataclasses import dataclass
 from queue import Empty, Queue
 from threading import Thread
 from typing import Any, cast

+import anyio
 from typing_extensions import override
 from vllm import LLM, CompletionOutput, RequestOutput, SamplingParams  # type: ignore

@@ -280,8 +282,7 @@ class VLLMAPI(ModelAPI):
 @dataclass
 class _QueueItem:
     input: GenerateInput
-    future:
-    loop: asyncio.AbstractEventLoop
+    future: Future[list[GenerateOutput]]


 batch_thread: Thread | None = None
@@ -297,15 +298,16 @@ async def batched_generate(input: GenerateInput) -> list[GenerateOutput]:
         batch_thread.start()

     # enqueue the job
-
-
-    batch_queue.put(_QueueItem(input=input, future=future, loop=loop))
+    future = Future[list[GenerateOutput]]()
+    batch_queue.put(_QueueItem(input=input, future=future))

-    # await the
-
-
-
-
+    # await the future
+    while True:
+        try:
+            return future.result(timeout=0.01)
+        except concurrent.futures.TimeoutError:
+            pass
+        await anyio.sleep(1)


 def string_to_bytes(string: str) -> list[int]:
@@ -397,13 +399,12 @@ def post_process_outputs(
 def process_batches() -> None:
     while True:
         # drain the queue (wait until no new messages have shown up for 2 seconds)
-        inputs: list[tuple[GenerateInput,
+        inputs: list[tuple[GenerateInput, Future[list[GenerateOutput]]]] = []
         while True:
            try:
                 input = batch_queue.get(
                     timeout=2
                 )  # wait 2 seconds max TODO: what's optimal wait time?
-                loop = input.loop
                 inputs.append((input.input, input.future))
                 if len(inputs) >= input.input.batch_size:
                     # max batch size reached
@@ -429,14 +430,10 @@ def process_batches() -> None:
             for i, output in enumerate(outputs):
                 future = inputs[i][1]

-
-                # down to this point, so we can mark the future as done in a thread safe manner.
-                # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
-                loop.call_soon_threadsafe(
-                    future.set_result,
+                future.set_result(
                     post_process_outputs(output, num_top_logprobs, total_time),
                 )

         except Exception as e:
             for _, future in inputs:
-
+                future.set_exception(e)