inspect-ai 0.3.71__py3-none-any.whl → 0.3.73__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. inspect_ai/_cli/eval.py +14 -3
  2. inspect_ai/_cli/sandbox.py +3 -3
  3. inspect_ai/_cli/score.py +6 -4
  4. inspect_ai/_cli/trace.py +53 -6
  5. inspect_ai/_display/core/config.py +1 -1
  6. inspect_ai/_display/core/display.py +2 -1
  7. inspect_ai/_display/core/footer.py +6 -6
  8. inspect_ai/_display/plain/display.py +11 -6
  9. inspect_ai/_display/rich/display.py +23 -13
  10. inspect_ai/_display/textual/app.py +10 -9
  11. inspect_ai/_display/textual/display.py +2 -2
  12. inspect_ai/_display/textual/widgets/footer.py +4 -0
  13. inspect_ai/_display/textual/widgets/samples.py +14 -5
  14. inspect_ai/_eval/context.py +1 -2
  15. inspect_ai/_eval/eval.py +54 -41
  16. inspect_ai/_eval/loader.py +9 -2
  17. inspect_ai/_eval/run.py +148 -81
  18. inspect_ai/_eval/score.py +13 -8
  19. inspect_ai/_eval/task/images.py +31 -21
  20. inspect_ai/_eval/task/run.py +62 -59
  21. inspect_ai/_eval/task/rundir.py +16 -9
  22. inspect_ai/_eval/task/sandbox.py +7 -8
  23. inspect_ai/_eval/task/util.py +7 -0
  24. inspect_ai/_util/_async.py +118 -10
  25. inspect_ai/_util/constants.py +0 -2
  26. inspect_ai/_util/file.py +15 -29
  27. inspect_ai/_util/future.py +37 -0
  28. inspect_ai/_util/http.py +3 -99
  29. inspect_ai/_util/httpx.py +60 -0
  30. inspect_ai/_util/interrupt.py +2 -2
  31. inspect_ai/_util/json.py +5 -52
  32. inspect_ai/_util/logger.py +30 -86
  33. inspect_ai/_util/retry.py +10 -61
  34. inspect_ai/_util/trace.py +2 -2
  35. inspect_ai/_view/server.py +86 -3
  36. inspect_ai/_view/www/dist/assets/index.js +25837 -13269
  37. inspect_ai/_view/www/log-schema.json +253 -186
  38. inspect_ai/_view/www/package.json +2 -2
  39. inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
  40. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
  41. inspect_ai/_view/www/src/types/log.d.ts +122 -94
  42. inspect_ai/approval/_human/manager.py +6 -10
  43. inspect_ai/approval/_human/panel.py +2 -2
  44. inspect_ai/dataset/_sources/util.py +7 -6
  45. inspect_ai/log/__init__.py +4 -0
  46. inspect_ai/log/_file.py +35 -61
  47. inspect_ai/log/_log.py +18 -1
  48. inspect_ai/log/_recorders/eval.py +14 -23
  49. inspect_ai/log/_recorders/json.py +3 -18
  50. inspect_ai/log/_samples.py +27 -2
  51. inspect_ai/log/_transcript.py +8 -8
  52. inspect_ai/model/__init__.py +2 -1
  53. inspect_ai/model/_call_tools.py +60 -40
  54. inspect_ai/model/_chat_message.py +3 -2
  55. inspect_ai/model/_generate_config.py +25 -0
  56. inspect_ai/model/_model.py +74 -36
  57. inspect_ai/model/_openai.py +9 -1
  58. inspect_ai/model/_providers/anthropic.py +172 -154
  59. inspect_ai/model/_providers/azureai.py +11 -9
  60. inspect_ai/model/_providers/bedrock.py +33 -24
  61. inspect_ai/model/_providers/cloudflare.py +8 -9
  62. inspect_ai/model/_providers/goodfire.py +7 -3
  63. inspect_ai/model/_providers/google.py +47 -13
  64. inspect_ai/model/_providers/groq.py +15 -15
  65. inspect_ai/model/_providers/hf.py +24 -17
  66. inspect_ai/model/_providers/mistral.py +36 -20
  67. inspect_ai/model/_providers/openai.py +30 -25
  68. inspect_ai/model/_providers/openai_o1.py +1 -1
  69. inspect_ai/model/_providers/providers.py +1 -1
  70. inspect_ai/model/_providers/together.py +3 -4
  71. inspect_ai/model/_providers/util/__init__.py +2 -2
  72. inspect_ai/model/_providers/util/chatapi.py +6 -19
  73. inspect_ai/model/_providers/util/hooks.py +165 -0
  74. inspect_ai/model/_providers/vertex.py +20 -3
  75. inspect_ai/model/_providers/vllm.py +16 -19
  76. inspect_ai/scorer/_multi.py +5 -2
  77. inspect_ai/solver/_bridge/patch.py +31 -1
  78. inspect_ai/solver/_fork.py +5 -3
  79. inspect_ai/solver/_human_agent/agent.py +3 -2
  80. inspect_ai/tool/__init__.py +8 -2
  81. inspect_ai/tool/_tool_info.py +4 -90
  82. inspect_ai/tool/_tool_params.py +4 -34
  83. inspect_ai/tool/_tools/_computer/_common.py +117 -58
  84. inspect_ai/tool/_tools/_computer/_computer.py +80 -57
  85. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +7 -1
  86. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +91 -0
  87. inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +8 -0
  88. inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +12 -0
  89. inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +78 -0
  90. inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +20 -0
  91. inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +175 -113
  92. inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +76 -20
  93. inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +65 -0
  94. inspect_ai/tool/_tools/_computer/test_args.py +151 -0
  95. inspect_ai/tool/_tools/_web_search.py +30 -24
  96. inspect_ai/util/__init__.py +4 -0
  97. inspect_ai/util/_concurrency.py +5 -6
  98. inspect_ai/util/_display.py +6 -0
  99. inspect_ai/util/_json.py +170 -0
  100. inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
  101. inspect_ai/util/_sandbox/docker/docker.py +5 -0
  102. inspect_ai/util/_sandbox/environment.py +56 -9
  103. inspect_ai/util/_sandbox/service.py +12 -5
  104. inspect_ai/util/_subprocess.py +94 -113
  105. inspect_ai/util/_subtask.py +2 -4
  106. {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
  107. {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +111 -103
  108. {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
  109. inspect_ai/_util/timeouts.py +0 -160
  110. inspect_ai/model/_providers/util/tracker.py +0 -92
  111. inspect_ai/tool/_tools/_computer/_computer_split.py +0 -198
  112. {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
  113. {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
  114. {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/mistral.py
@@ -7,6 +7,7 @@ from httpcore import ReadTimeout
 from httpx import ReadTimeout as AsyncReadTimeout
 from mistralai import (
     ContentChunk,
+    DocumentURLChunk,
     FunctionCall,
     FunctionName,
     ImageURL,
@@ -22,6 +23,12 @@ from mistralai.models import (
     ChatCompletionChoice as MistralChatCompletionChoice,
 )
 from mistralai.models import Function as MistralFunction
+from mistralai.models import (
+    JSONSchema as MistralJSONSchema,
+)
+from mistralai.models import (
+    ResponseFormat as MistralResponseFormat,
+)
 from mistralai.models import SDKError
 from mistralai.models import SystemMessage as MistralSystemMessage
 from mistralai.models import Tool as MistralTool
@@ -38,11 +45,9 @@ from typing_extensions import override

 # TODO: Migration guide:
 # https://github.com/mistralai/client-python/blob/main/MIGRATION.md
-from inspect_ai._util.constants import (
-    DEFAULT_TIMEOUT,
-    NO_CONTENT,
-)
+from inspect_ai._util.constants import NO_CONTENT
 from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

@@ -61,7 +66,7 @@ from .._model_output import (
     StopReason,
 )
 from .util import environment_prerequisite_error, model_base_url
-from .util.tracker import HttpxTimeTracker
+from .util.hooks import HttpxHooks

 AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
 AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"
@@ -127,16 +132,12 @@ class MistralAPI(ModelAPI):
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # create client
-        with Mistral(
-            api_key=self.api_key,
-            timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT) * 1000,
-            **self.model_args,
-        ) as client:
+        with Mistral(api_key=self.api_key, **self.model_args) as client:
             # create time tracker
-            time_tracker = HttpxTimeTracker(client.sdk_configuration.async_client)
+            http_hooks = HttpxHooks(client.sdk_configuration.async_client)

             # build request
-            request_id = time_tracker.start_request()
+            request_id = http_hooks.start_request()
             request: dict[str, Any] = dict(
                 model=self.model_name,
                 messages=await mistral_chat_messages(input),
@@ -144,7 +145,7 @@ class MistralAPI(ModelAPI):
                 tool_choice=(
                     mistral_chat_tool_choice(tool_choice) if len(tools) > 0 else None
                 ),
-                http_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
+                http_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
             )
             if config.temperature is not None:
                 request["temperature"] = config.temperature
@@ -154,6 +155,18 @@ class MistralAPI(ModelAPI):
                 request["max_tokens"] = config.max_tokens
             if config.seed is not None:
                 request["random_seed"] = config.seed
+            if config.response_schema is not None:
+                request["response_format"] = MistralResponseFormat(
+                    type="json_schema",
+                    json_schema=MistralJSONSchema(
+                        name=config.response_schema.name,
+                        description=config.response_schema.description,
+                        schema_definition=config.response_schema.json_schema.model_dump(
+                            exclude_none=True
+                        ),
+                        strict=config.response_schema.strict,
+                    ),
+                )

             # prepare response for inclusion in model call
             response: dict[str, Any] = {}
@@ -169,7 +182,7 @@ class MistralAPI(ModelAPI):
                 return ModelCall.create(
                     request=req,
                     response=response,
-                    time=time_tracker.end_request(request_id),
+                    time=http_hooks.end_request(request_id),
                 )

             # send request
@@ -205,12 +218,13 @@ class MistralAPI(ModelAPI):
             ), model_call()

     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        return (
-            isinstance(ex, SDKError)
-            and ex.status_code == 429
-            or isinstance(ex, ReadTimeout | AsyncReadTimeout)
-        )
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, SDKError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, ReadTimeout | AsyncReadTimeout):
+            return True
+        else:
+            return False

     @override
     def connection_key(self) -> str:
@@ -462,6 +476,8 @@ def completion_content_chunk(content: ContentChunk) -> Content:
         raise TypeError("ReferenceChunk content is not supported by Inspect.")
     elif isinstance(content, TextChunk):
         return ContentText(text=content.text)
+    elif isinstance(content, DocumentURLChunk):
+        return ContentText(text=content.document_url)
     else:
         if isinstance(content.image_url, str):
             return ContentImage(image=content.image_url)
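
The hunks above replace MistralAPI.is_rate_limit() (which only retried HTTP 429 raised as SDKError) with should_retry(), delegating status-code classification to the new inspect_ai._util.http.is_retryable_http_status helper. The body of that helper is not shown in this diff; the sketch below is only an assumption about what such a classifier typically checks, not the actual implementation.

# Hypothetical sketch -- the real helper lives in inspect_ai/_util/http.py,
# whose body is not shown in this diff; the exact status set is assumed.
def is_retryable_http_status(status_code: int) -> bool:
    # 429 (rate limited) plus transient server-side failures
    return status_code == 429 or status_code in (500, 502, 503, 504)
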
inspect_ai/model/_providers/openai.py
@@ -7,25 +7,22 @@ import httpx
 from openai import (
     DEFAULT_CONNECTION_LIMITS,
     DEFAULT_TIMEOUT,
-    APIConnectionError,
+    APIStatusError,
     APITimeoutError,
     AsyncAzureOpenAI,
     AsyncOpenAI,
     BadRequestError,
-    InternalServerError,
     RateLimitError,
 )
 from openai._types import NOT_GIVEN
-from openai.types.chat import (
-    ChatCompletion,
-)
+from openai.types.chat import ChatCompletion
 from typing_extensions import override

-from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.error import PrerequisiteError
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
-from inspect_ai.model._providers.util.tracker import HttpxTimeTracker
+from inspect_ai.model._providers.util.hooks import HttpxHooks
 from inspect_ai.tool import ToolChoice, ToolInfo

 from .._chat_message import ChatMessage
@@ -130,9 +127,6 @@ class OpenAIAPI(ModelAPI):
                 api_key=self.api_key,
                 azure_endpoint=base_url,
                 azure_deployment=model_name,
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 http_client=http_client,
                 **model_args,
             )
@@ -140,15 +134,12 @@ class OpenAIAPI(ModelAPI):
             self.client = AsyncOpenAI(
                 api_key=self.api_key,
                 base_url=model_base_url(base_url, "OPENAI_BASE_URL"),
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 http_client=http_client,
                 **model_args,
             )

         # create time tracker
-        self._time_tracker = HttpxTimeTracker(self.client._client)
+        self._http_hooks = HttpxHooks(self.client._client)

     def is_azure(self) -> bool:
         return self.service == "azure"
@@ -186,7 +177,7 @@ class OpenAIAPI(ModelAPI):
             )

         # allocate request_id (so we can see it from ModelCall)
-        request_id = self._time_tracker.start_request()
+        request_id = self._http_hooks.start_request()

         # setup request and response for ModelCall
         request: dict[str, Any] = {}
@@ -197,7 +188,7 @@ class OpenAIAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=image_url_filter,
-                time=self._time_tracker.end_request(request_id),
+                time=self._http_hooks.end_request(request_id),
             )

         # unlike text models, vision models require a max_tokens (and set it to a very low
@@ -216,7 +207,7 @@ class OpenAIAPI(ModelAPI):
                 tool_choice=openai_chat_tool_choice(tool_choice)
                 if len(tools) > 0
                 else NOT_GIVEN,
-                extra_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
+                extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
                 **self.completion_params(config, len(tools) > 0),
             )

@@ -266,17 +257,21 @@ class OpenAIAPI(ModelAPI):
             return chat_choices_from_openai(response, tools)

     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
+    def should_retry(self, ex: Exception) -> bool:
         if isinstance(ex, RateLimitError):
             # Do not retry on these rate limit errors
             # The quota exceeded one is related to monthly account quotas.
-            if "You exceeded your current quota" not in ex.message:
+            if "You exceeded your current quota" in ex.message:
+                warn_once(logger, f"OpenAI quota exceeded, not retrying: {ex.message}")
+                return False
+            else:
                 return True
-        elif isinstance(
-            ex, (APIConnectionError | APITimeoutError | InternalServerError)
-        ):
+        elif isinstance(ex, APIStatusError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, APITimeoutError):
             return True
-        return False
+        else:
+            return False

     @override
     def connection_key(self) -> str:
@@ -315,8 +310,6 @@ class OpenAIAPI(ModelAPI):
             params["temperature"] = 1
         if config.top_p is not None:
             params["top_p"] = config.top_p
-        if config.timeout is not None:
-            params["timeout"] = float(config.timeout)
         if config.num_choices is not None:
             params["n"] = config.num_choices
         if config.logprobs is not None:
@@ -331,6 +324,18 @@ class OpenAIAPI(ModelAPI):
             and not self.is_o1_mini()
         ):
             params["reasoning_effort"] = config.reasoning_effort
+        if config.response_schema is not None:
+            params["response_format"] = dict(
+                type="json_schema",
+                json_schema=dict(
+                    name=config.response_schema.name,
+                    schema=config.response_schema.json_schema.model_dump(
+                        exclude_none=True
+                    ),
+                    description=config.response_schema.description,
+                    strict=config.response_schema.strict,
+                ),
+            )

         return params
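
Both the Mistral hunk earlier and the OpenAI hunk above now map config.response_schema onto the provider's native json_schema response format, using the name, description, json_schema, and strict fields visible in the diff. A hedged sketch of how a caller might populate that config; the import locations for ResponseSchema and json_schema() are assumptions based on the file list (inspect_ai/util/_json.py and the updated inspect_ai/model/__init__.py are new in this release):

# Sketch of exercising the new structured-output support; import paths for
# ResponseSchema and json_schema() are assumptions, field names come from the diff.
from pydantic import BaseModel

from inspect_ai.model import GenerateConfig, ResponseSchema
from inspect_ai.util import json_schema


class Answer(BaseModel):
    answer: str
    confidence: float


config = GenerateConfig(
    response_schema=ResponseSchema(
        name="answer",
        description="Final answer with a confidence score",
        json_schema=json_schema(Answer),  # dumped via model_dump(exclude_none=True) by the providers
        strict=True,
    )
)
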
 
inspect_ai/model/_providers/openai_o1.py
@@ -107,7 +107,7 @@ def chat_messages(
 ) -> list[ChatCompletionMessageParam]:
     # o1 does not allow system messages so convert system -> user
     messages: list[ChatMessage] = [
-        ChatMessageUser(content=message.content)
+        ChatMessageUser(id=message.id, content=message.content)
         if message.role == "system"
         else message
         for message in input
inspect_ai/model/_providers/providers.py
@@ -148,7 +148,7 @@ def cf() -> type[ModelAPI]:
 def mistral() -> type[ModelAPI]:
     FEATURE = "Mistral API"
     PACKAGE = "mistralai"
-    MIN_VERSION = "1.5.0"
+    MIN_VERSION = "1.5.1"

     # verify we have the package
     try:
inspect_ai/model/_providers/together.py
@@ -34,8 +34,8 @@ from .util import (
     chat_api_input,
     chat_api_request,
     environment_prerequisite_error,
-    is_chat_api_rate_limit,
     model_base_url,
+    should_retry_chat_api_error,
 )


@@ -186,7 +186,6 @@ class TogetherRESTAPI(ModelAPI):
             url=f"{chat_url}",
             headers={"Authorization": f"Bearer {self.api_key}"},
             json=json,
-            config=config,
         )

         if "error" in response:
@@ -215,8 +214,8 @@ class TogetherRESTAPI(ModelAPI):
         return ModelOutput(model=model, choices=choices, usage=usage)

     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        return is_chat_api_rate_limit(ex)
+    def should_retry(self, ex: Exception) -> bool:
+        return should_retry_chat_api_error(ex)

     # cloudflare enforces rate limits by model for each account
     @override
inspect_ai/model/_providers/util/__init__.py
@@ -5,7 +5,7 @@ from .chatapi import (
     ChatAPIMessage,
     chat_api_input,
     chat_api_request,
-    is_chat_api_rate_limit,
+    should_retry_chat_api_error,
 )
 from .hf_handler import HFHandler
 from .llama31 import Llama31Handler
@@ -19,7 +19,7 @@ __all__ = [
     "as_stop_reason",
     "chat_api_request",
     "chat_api_input",
-    "is_chat_api_rate_limit",
+    "should_retry_chat_api_error",
     "model_base_url",
     "parse_tool_call",
     "tool_parse_error_message",
inspect_ai/model/_providers/util/chatapi.py
@@ -7,17 +7,15 @@ from tenacity import (
     retry,
     retry_if_exception,
     stop_after_attempt,
-    stop_after_delay,
     wait_exponential_jitter,
 )

-from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
-from inspect_ai._util.retry import httpx_should_retry, log_retry_attempt
+from inspect_ai._util.http import is_retryable_http_status
+from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
 from inspect_ai.model._chat_message import ChatMessageAssistant, ChatMessageTool
 from inspect_ai.tool._tool_info import ToolInfo

 from ..._chat_message import ChatMessage
-from ..._generate_config import GenerateConfig

 logger = getLogger(__name__)

@@ -75,21 +73,13 @@ async def chat_api_request(
     url: str,
     headers: dict[str, Any],
     json: Any,
-    config: GenerateConfig,
 ) -> Any:
-    # provide default max_retries
-    max_retries = config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-
     # define call w/ retry policy
     @retry(
         wait=wait_exponential_jitter(),
-        stop=(
-            (stop_after_attempt(max_retries) | stop_after_delay(config.timeout))
-            if config.timeout
-            else stop_after_attempt(max_retries)
-        ),
+        stop=(stop_after_attempt(2)),
         retry=retry_if_exception(httpx_should_retry),
-        before_sleep=log_retry_attempt(model_name),
+        before_sleep=log_httpx_retry_attempt(model_name),
     )
     async def call_api() -> Any:
         response = await client.post(url=url, headers=headers, json=json)
@@ -104,14 +94,11 @@
 # checking for rate limit errors needs to punch through the RetryError and
 # look at its `__cause__`. we've observed Cloudflare giving transient 500
 # status as well as a ReadTimeout, so we count these as rate limit errors
-def is_chat_api_rate_limit(ex: BaseException) -> bool:
+def should_retry_chat_api_error(ex: BaseException) -> bool:
     return isinstance(ex, RetryError) and (
         (
            isinstance(ex.__cause__, httpx.HTTPStatusError)
-            and (
-                ex.__cause__.response.status_code == 429
-                or ex.__cause__.response.status_code == 500
-            )
+            and is_retryable_http_status(ex.__cause__.response.status_code)
         )
         or isinstance(ex.__cause__, httpx.ReadTimeout)
     )
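
chat_api_request() above now uses a fixed two-attempt tenacity policy (the config-driven stop/timeout logic was removed along with the GenerateConfig parameter), and failures that exhaust the retry are classified afterwards by should_retry_chat_api_error(). A minimal standalone sketch of the same tenacity shape; fetch() and its arguments are placeholders rather than inspect_ai APIs:

# Standalone sketch of the retry pattern used above; fetch() is a placeholder.
from typing import Any

import httpx
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential_jitter

from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt


@retry(
    wait=wait_exponential_jitter(),                 # jittered exponential backoff
    stop=stop_after_attempt(2),                     # at most one retry
    retry=retry_if_exception(httpx_should_retry),
    before_sleep=log_httpx_retry_attempt("example-model"),
)
async def fetch(client: httpx.AsyncClient, url: str) -> Any:
    response = await client.post(url, json={})
    response.raise_for_status()                     # HTTP errors propagate to httpx_should_retry
    return response.json()

When both attempts fail, tenacity raises RetryError wrapping the last exception, which is why should_retry_chat_api_error() inspects ex.__cause__ rather than the exception itself.
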
inspect_ai/model/_providers/util/hooks.py
@@ -0,0 +1,165 @@
+import re
+import time
+from logging import getLogger
+from typing import Any, Mapping, NamedTuple, cast
+
+import httpx
+from shortuuid import uuid
+
+from inspect_ai._util.constants import HTTP
+from inspect_ai._util.retry import report_http_retry
+
+logger = getLogger(__name__)
+
+
+class RequestInfo(NamedTuple):
+    attempts: int
+    last_request: float
+
+
+class HttpHooks:
+    """Class which hooks various HTTP clients for improved tracking/logging.
+
+    A special header is injected into requests which is then read from
+    a request event hook -- this creates a record of when the request
+    started. Note that with retries a single request_id could be started
+    several times; our request hook makes sure we always track the time of
+    the last request.
+
+    There is an 'end_request()' method which gets the total request time
+    for a request_id and then purges the request_id from our tracking (so
+    the dict doesn't grow unbounded)
+
+    Additionally, an http response hook is installed and used for logging
+    requests for the 'http' log-level
+    """
+
+    REQUEST_ID_HEADER = "x-irid"
+
+    def __init__(self) -> None:
+        # track request start times
+        self._requests: dict[str, RequestInfo] = {}
+
+    def start_request(self) -> str:
+        request_id = uuid()
+        self._requests[request_id] = RequestInfo(0, time.monotonic())
+        return request_id
+
+    def end_request(self, request_id: str) -> float:
+        # read the request info (if available) and purge from dict
+        request_info = self._requests.pop(request_id, None)
+        if request_info is None:
+            raise RuntimeError(f"request_id not registered: {request_id}")
+
+        # return elapsed time
+        return time.monotonic() - request_info.last_request
+
+    def update_request_time(self, request_id: str) -> None:
+        request_info = self._requests.get(request_id, None)
+        if not request_info:
+            raise RuntimeError(f"No request registered for request_id: {request_id}")
+
+        # update the attempts and last request time
+        request_info = RequestInfo(request_info.attempts + 1, time.monotonic())
+        self._requests[request_id] = request_info
+
+        # trace a retry if this is attempt > 1
+        if request_info.attempts > 1:
+            report_http_retry()
+
+
+class ConverseHooks(HttpHooks):
+    def __init__(self, session: Any) -> None:
+        from aiobotocore.session import AioSession
+
+        super().__init__()
+
+        # register hooks
+        session = cast(AioSession, session._session)
+        session.register(
+            "before-send.bedrock-runtime.Converse", self.converse_before_send
+        )
+        session.register(
+            "after-call.bedrock-runtime.Converse", self.converse_after_call
+        )
+
+    def converse_before_send(self, **kwargs: Any) -> None:
+        user_agent = kwargs["request"].headers["User-Agent"].decode()
+        match = re.search(rf"{self.USER_AGENT_PREFIX}(\w+)", user_agent)
+        if match:
+            request_id = match.group(1)
+            self.update_request_time(request_id)
+
+    def converse_after_call(self, http_response: Any, **kwargs: Any) -> None:
+        from botocore.awsrequest import AWSResponse
+
+        response = cast(AWSResponse, http_response)
+        logger.log(HTTP, f"POST {response.url} - {response.status_code}")
+
+    def user_agent_extra(self, request_id: str) -> str:
+        return f"{self.USER_AGENT_PREFIX}{request_id}"
+
+    USER_AGENT_PREFIX = "ins/rid#"
+
+
+class HttpxHooks(HttpHooks):
+    def __init__(self, client: httpx.AsyncClient):
+        super().__init__()
+
+        # install hooks
+        client.event_hooks["request"].append(self.request_hook)
+        client.event_hooks["response"].append(self.response_hook)
+
+    async def request_hook(self, request: httpx.Request) -> None:
+        # update the last request time for this request id (as there could be retries)
+        request_id = request.headers.get(self.REQUEST_ID_HEADER, None)
+        if request_id:
+            self.update_request_time(request_id)
+
+    async def response_hook(self, response: httpx.Response) -> None:
+        message = f'{response.request.method} {response.request.url} "{response.http_version} {response.status_code} {response.reason_phrase}" '
+        logger.log(HTTP, message)
+
+
+def urllib3_hooks() -> HttpHooks:
+    import urllib3
+    from urllib3.connectionpool import HTTPConnectionPool
+    from urllib3.response import BaseHTTPResponse
+
+    class Urllib3Hooks(HttpHooks):
+        def request_hook(self, headers: Mapping[str, str]) -> None:
+            # update the last request time for this request id (as there could be retries)
+            request_id = headers.get(self.REQUEST_ID_HEADER, None)
+            if request_id:
+                self.update_request_time(request_id)
+
+        def response_hook(
+            self, method: str, url: str, response: BaseHTTPResponse
+        ) -> None:
+            message = f'{method} {url} "{response.version_string} {response.status} {response.reason}" '
+            logger.log(HTTP, message)
+
+    global _urlilb3_hooks
+    if _urlilb3_hooks is None:
+        # one time patch of urlopen
+        urlilb3_hooks = Urllib3Hooks()
+        original_urlopen = urllib3.connectionpool.HTTPConnectionPool.urlopen
+
+        def patched_urlopen(
+            self: HTTPConnectionPool, method: str, url: str, **kwargs: Any
+        ) -> BaseHTTPResponse:
+            headers = kwargs.get("headers", {})
+            urlilb3_hooks.request_hook(headers)
+            response = original_urlopen(self, method, url, **kwargs)
+            urlilb3_hooks.response_hook(method, f"{self.host}{url}", response)
+            return response
+
+        urllib3.connectionpool.HTTPConnectionPool.urlopen = patched_urlopen  # type: ignore[assignment,method-assign]
+
+        # assign to global hooks instance
+        _urlilb3_hooks = urlilb3_hooks
+
+    return _urlilb3_hooks
+
+
+_urlilb3_hooks: HttpHooks | None = None
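
The new hooks.py above replaces util/tracker.py's HttpxTimeTracker. The wiring used by the Mistral and OpenAI providers earlier in this diff is: install the hooks on the provider's httpx.AsyncClient once, allocate a request id, send it in the x-irid header so the request event hook can stamp (and re-stamp on retries) the start time, then read the elapsed time with end_request(). A condensed sketch of that flow:

# Condensed from the provider changes shown in this diff; the hooks are
# installed once per httpx.AsyncClient.
import httpx

from inspect_ai.model._providers.util.hooks import HttpxHooks

client = httpx.AsyncClient()
hooks = HttpxHooks(client)  # appends request/response event hooks to the client


async def timed_post(url: str, json: dict) -> tuple[httpx.Response, float]:
    request_id = hooks.start_request()  # record a start time keyed by id
    response = await client.post(
        url,
        json=json,
        # the request event hook reads this header and updates the start time
        # on every attempt, so retries don't inflate the measurement
        headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
    )
    return response, hooks.end_request(request_id)  # elapsed seconds; id purged
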
inspect_ai/model/_providers/vertex.py
@@ -4,7 +4,13 @@ from copy import copy
 from typing import Any, cast

 import vertexai  # type: ignore
-from google.api_core.exceptions import TooManyRequests
+from google.api_core.exceptions import (
+    Aborted,
+    ClientError,
+    DeadlineExceeded,
+    ServiceUnavailable,
+)
+from google.api_core.retry import if_transient_error
 from google.protobuf.json_format import MessageToDict
 from pydantic import JsonValue
 from typing_extensions import override
@@ -31,6 +37,7 @@ from inspect_ai._util.content import (
     ContentText,
     ContentVideo,
 )
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo

@@ -169,8 +176,18 @@ class VertexAPI(ModelAPI):
         return output, call

     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        return isinstance(ex, TooManyRequests)
+    def should_retry(self, ex: Exception) -> bool:
+        # google API-specific errors
+        if isinstance(ex, Aborted | DeadlineExceeded | ServiceUnavailable):
+            return True
+        # standard HTTP errors
+        elif isinstance(ex, ClientError) and ex.code is not None:
+            return is_retryable_http_status(ex.code)
+        # additional errors flagged by google as transient
+        elif isinstance(ex, Exception):
+            return if_transient_error(ex)
+        else:
+            return False

     @override
     def connection_key(self) -> str:
inspect_ai/model/_providers/vllm.py
@@ -1,13 +1,15 @@
-import asyncio
+import concurrent.futures
 import functools
 import gc
 import os
 import time
+from concurrent.futures import Future
 from dataclasses import dataclass
 from queue import Empty, Queue
 from threading import Thread
 from typing import Any, cast

+import anyio
 from typing_extensions import override
 from vllm import LLM, CompletionOutput, RequestOutput, SamplingParams  # type: ignore

@@ -280,8 +282,7 @@ class VLLMAPI(ModelAPI):
 @dataclass
 class _QueueItem:
     input: GenerateInput
-    future: asyncio.Future[list[GenerateOutput]]
-    loop: asyncio.AbstractEventLoop
+    future: Future[list[GenerateOutput]]


 batch_thread: Thread | None = None
@@ -297,15 +298,16 @@ async def batched_generate(input: GenerateInput) -> list[GenerateOutput]:
         batch_thread.start()

     # enqueue the job
-    loop = asyncio.get_event_loop()
-    future: asyncio.Future[list[GenerateOutput]] = loop.create_future()
-    batch_queue.put(_QueueItem(input=input, future=future, loop=loop))
+    future = Future[list[GenerateOutput]]()
+    batch_queue.put(_QueueItem(input=input, future=future))

-    # await the job
-    await future
-
-    # return it
-    return future.result()
+    # await the future
+    while True:
+        try:
+            return future.result(timeout=0.01)
+        except concurrent.futures.TimeoutError:
+            pass
+        await anyio.sleep(1)


 def string_to_bytes(string: str) -> list[int]:
@@ -397,13 +399,12 @@ def post_process_outputs(
 def process_batches() -> None:
     while True:
         # drain the queue (wait until no new messages have shown up for 2 seconds)
-        inputs: list[tuple[GenerateInput, asyncio.Future[list[GenerateOutput]]]] = []
+        inputs: list[tuple[GenerateInput, Future[list[GenerateOutput]]]] = []
         while True:
             try:
                 input = batch_queue.get(
                     timeout=2
                 )  # wait 2 seconds max TODO: what's optimal wait time?
-                loop = input.loop
                 inputs.append((input.input, input.future))
                 if len(inputs) >= input.input.batch_size:
                     # max batch size reached
@@ -429,14 +430,10 @@ def process_batches() -> None:
             for i, output in enumerate(outputs):
                 future = inputs[i][1]

-                # asyncio futures are not thread safe, so we need to pass the event loop
-                # down to this point, so we can mark the future as done in a thread safe manner.
-                # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
-                loop.call_soon_threadsafe(
-                    future.set_result,
+                future.set_result(
                     post_process_outputs(output, num_top_logprobs, total_time),
                 )

         except Exception as e:
             for _, future in inputs:
-                loop.call_soon_threadsafe(future.set_exception, e)
+                future.set_exception(e)
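
The vllm.py changes above swap the asyncio.Future plus call_soon_threadsafe handoff for a plain concurrent.futures.Future, which is thread-safe by design: the batching thread completes the future directly, and the async caller polls it with a short blocking timeout between anyio.sleep() calls. A minimal standalone sketch of that thread-to-async handoff; worker() and the sleep interval are illustrative, not inspect_ai APIs:

# Minimal sketch of the producer/consumer handoff adopted above.
import concurrent.futures
from concurrent.futures import Future
from threading import Thread

import anyio


def worker(future: Future) -> None:
    # concurrent.futures.Future is thread-safe, so the worker thread can
    # complete it directly (no loop.call_soon_threadsafe required)
    future.set_result(42)


async def wait_for_result() -> int:
    future: Future = Future()
    Thread(target=worker, args=(future,), daemon=True).start()
    while True:
        try:
            return future.result(timeout=0.01)  # brief blocking check
        except concurrent.futures.TimeoutError:
            await anyio.sleep(0.1)  # yield to the event loop between checks


if __name__ == "__main__":
    print(anyio.run(wait_for_result))
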