inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registries.
Files changed (103)
  1. inspect_ai/_cli/eval.py +14 -3
  2. inspect_ai/_cli/sandbox.py +3 -3
  3. inspect_ai/_cli/score.py +6 -4
  4. inspect_ai/_cli/trace.py +53 -6
  5. inspect_ai/_display/core/config.py +1 -1
  6. inspect_ai/_display/core/display.py +2 -1
  7. inspect_ai/_display/core/footer.py +6 -6
  8. inspect_ai/_display/plain/display.py +11 -6
  9. inspect_ai/_display/rich/display.py +23 -13
  10. inspect_ai/_display/textual/app.py +10 -9
  11. inspect_ai/_display/textual/display.py +2 -2
  12. inspect_ai/_display/textual/widgets/footer.py +4 -0
  13. inspect_ai/_display/textual/widgets/samples.py +14 -5
  14. inspect_ai/_eval/context.py +1 -2
  15. inspect_ai/_eval/eval.py +54 -41
  16. inspect_ai/_eval/loader.py +9 -2
  17. inspect_ai/_eval/run.py +148 -81
  18. inspect_ai/_eval/score.py +13 -8
  19. inspect_ai/_eval/task/images.py +31 -21
  20. inspect_ai/_eval/task/run.py +62 -59
  21. inspect_ai/_eval/task/rundir.py +16 -9
  22. inspect_ai/_eval/task/sandbox.py +7 -8
  23. inspect_ai/_eval/task/util.py +7 -0
  24. inspect_ai/_util/_async.py +118 -10
  25. inspect_ai/_util/constants.py +0 -2
  26. inspect_ai/_util/file.py +15 -29
  27. inspect_ai/_util/future.py +37 -0
  28. inspect_ai/_util/http.py +3 -99
  29. inspect_ai/_util/httpx.py +60 -0
  30. inspect_ai/_util/interrupt.py +2 -2
  31. inspect_ai/_util/json.py +5 -52
  32. inspect_ai/_util/logger.py +30 -86
  33. inspect_ai/_util/retry.py +10 -61
  34. inspect_ai/_util/trace.py +2 -2
  35. inspect_ai/_view/server.py +86 -3
  36. inspect_ai/_view/www/dist/assets/index.js +25837 -13269
  37. inspect_ai/_view/www/log-schema.json +253 -186
  38. inspect_ai/_view/www/package.json +2 -2
  39. inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
  40. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
  41. inspect_ai/_view/www/src/types/log.d.ts +122 -94
  42. inspect_ai/approval/_human/manager.py +6 -10
  43. inspect_ai/approval/_human/panel.py +2 -2
  44. inspect_ai/dataset/_sources/util.py +7 -6
  45. inspect_ai/log/__init__.py +4 -0
  46. inspect_ai/log/_file.py +35 -61
  47. inspect_ai/log/_log.py +18 -1
  48. inspect_ai/log/_recorders/eval.py +14 -23
  49. inspect_ai/log/_recorders/json.py +3 -18
  50. inspect_ai/log/_samples.py +27 -2
  51. inspect_ai/log/_transcript.py +8 -8
  52. inspect_ai/model/__init__.py +2 -1
  53. inspect_ai/model/_call_tools.py +60 -40
  54. inspect_ai/model/_chat_message.py +3 -2
  55. inspect_ai/model/_generate_config.py +25 -0
  56. inspect_ai/model/_model.py +74 -36
  57. inspect_ai/model/_openai.py +9 -1
  58. inspect_ai/model/_providers/anthropic.py +24 -26
  59. inspect_ai/model/_providers/azureai.py +11 -9
  60. inspect_ai/model/_providers/bedrock.py +33 -24
  61. inspect_ai/model/_providers/cloudflare.py +8 -9
  62. inspect_ai/model/_providers/goodfire.py +7 -3
  63. inspect_ai/model/_providers/google.py +47 -13
  64. inspect_ai/model/_providers/groq.py +15 -15
  65. inspect_ai/model/_providers/hf.py +24 -17
  66. inspect_ai/model/_providers/mistral.py +36 -20
  67. inspect_ai/model/_providers/openai.py +30 -25
  68. inspect_ai/model/_providers/openai_o1.py +1 -1
  69. inspect_ai/model/_providers/providers.py +1 -1
  70. inspect_ai/model/_providers/together.py +3 -4
  71. inspect_ai/model/_providers/util/__init__.py +2 -2
  72. inspect_ai/model/_providers/util/chatapi.py +6 -19
  73. inspect_ai/model/_providers/util/hooks.py +165 -0
  74. inspect_ai/model/_providers/vertex.py +20 -3
  75. inspect_ai/model/_providers/vllm.py +16 -19
  76. inspect_ai/scorer/_multi.py +5 -2
  77. inspect_ai/solver/_bridge/patch.py +31 -1
  78. inspect_ai/solver/_fork.py +5 -3
  79. inspect_ai/solver/_human_agent/agent.py +3 -2
  80. inspect_ai/tool/__init__.py +8 -2
  81. inspect_ai/tool/_tool_info.py +4 -90
  82. inspect_ai/tool/_tool_params.py +4 -34
  83. inspect_ai/tool/_tools/_web_search.py +30 -24
  84. inspect_ai/util/__init__.py +4 -0
  85. inspect_ai/util/_concurrency.py +5 -6
  86. inspect_ai/util/_display.py +6 -0
  87. inspect_ai/util/_json.py +170 -0
  88. inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
  89. inspect_ai/util/_sandbox/docker/docker.py +5 -0
  90. inspect_ai/util/_sandbox/environment.py +56 -9
  91. inspect_ai/util/_sandbox/service.py +12 -5
  92. inspect_ai/util/_subprocess.py +94 -113
  93. inspect_ai/util/_subtask.py +2 -4
  94. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
  95. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
  96. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
  97. inspect_ai/_util/timeouts.py +0 -160
  98. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  99. inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
  100. inspect_ai/model/_providers/util/tracker.py +0 -92
  101. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
  102. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
  103. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/model/_model.py
@@ -13,6 +13,7 @@ from typing import Any, AsyncIterator, Callable, Literal, Type, cast
 
 from pydantic_core import to_jsonable_python
 from tenacity import (
+    RetryCallState,
     retry,
     retry_if_exception,
     stop_after_attempt,
@@ -20,8 +21,9 @@ from tenacity import (
     stop_never,
     wait_exponential_jitter,
 )
+from tenacity.stop import StopBaseT
 
-from inspect_ai._util.constants import DEFAULT_MAX_CONNECTIONS
+from inspect_ai._util.constants import DEFAULT_MAX_CONNECTIONS, HTTP
 from inspect_ai._util.content import (
     Content,
     ContentImage,
@@ -30,6 +32,7 @@ from inspect_ai._util.content import (
 )
 from inspect_ai._util.hooks import init_hooks, override_api_key, send_telemetry
 from inspect_ai._util.interrupt import check_sample_interrupt
+from inspect_ai._util.logger import warn_once
 from inspect_ai._util.platform import platform_init
 from inspect_ai._util.registry import (
     RegistryInfo,
@@ -37,7 +40,7 @@ from inspect_ai._util.registry import (
     registry_info,
     registry_unqualified_name,
 )
-from inspect_ai._util.retry import log_rate_limit_retry
+from inspect_ai._util.retry import report_http_retry
 from inspect_ai._util.trace import trace_action
 from inspect_ai._util.working import report_sample_waiting_time, sample_working_time
 from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
@@ -173,11 +176,11 @@ class ModelAPI(abc.ABC):
         """Scope for enforcement of max_connections."""
         return "default"
 
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        """Is this exception a rate limit error.
+    def should_retry(self, ex: Exception) -> bool:
+        """Should this exception be retried?
 
         Args:
-            ex: Exception to check for rate limit.
+            ex: Exception to check for retry
         """
         return False
 
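The hunk above renames the provider retry hook from is_rate_limit() to should_retry(), widening it from rate-limit errors to any transient error. As a rough sketch of what a third-party provider migration might look like (MyProviderAPI and its use of httpx are illustrative, not part of this release):

import httpx
from typing_extensions import override

from inspect_ai.model import ModelAPI


class MyProviderAPI(ModelAPI):  # hypothetical provider (generate() etc. omitted)
    @override
    def should_retry(self, ex: Exception) -> bool:
        # retry transient HTTP statuses, not just 429 rate limits
        if isinstance(ex, httpx.HTTPStatusError):
            return ex.response.status_code in (408, 429, 500, 502, 503, 504)
        # dropped connections are also worth retrying
        return isinstance(ex, (httpx.ConnectError, httpx.RemoteProtocolError))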
@@ -331,14 +334,17 @@ class Model:
         start_time = datetime.now()
         working_start = sample_working_time()
         async with self._connection_concurrency(config):
+            from inspect_ai.log._samples import track_active_sample_retries
+
             # generate
-            output = await self._generate(
-                input=input,
-                tools=tools,
-                tool_choice=tool_choice,
-                config=config,
-                cache=cache,
-            )
+            with track_active_sample_retries():
+                output = await self._generate(
+                    input=input,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    config=config,
+                    cache=cache,
+                )
 
             # update the most recent ModelEvent with the actual start/completed
             # times as well as a computation of working time (events are
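track_active_sample_retries() is added in inspect_ai/log/_samples.py (that hunk is not shown here); together with report_http_retry() it attributes retries to the currently running sample. A guess at the underlying mechanism, sketched with illustrative names rather than the actual implementation (a contextvar routes retry notifications to whatever scope is active):

from contextlib import contextmanager
from contextvars import ContextVar
from typing import Callable, Iterator, Optional

# hypothetical hook: invoked once per retry by the reporting function
_retry_handler: ContextVar[Optional[Callable[[], None]]] = ContextVar(
    "_retry_handler", default=None
)


@contextmanager
def track_retries(on_retry: Callable[[], None]) -> Iterator[None]:
    """Route retry notifications to on_retry for the enclosed scope."""
    token = _retry_handler.set(on_retry)
    try:
        yield
    finally:
        _retry_handler.reset(token)


def report_retry() -> None:
    """Notify the handler installed by the enclosing track_retries()."""
    handler = _retry_handler.get()
    if handler is not None:
        handler()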
@@ -418,27 +424,27 @@ class Model:
         if self.api.collapse_assistant_messages():
             input = collapse_consecutive_assistant_messages(input)
 
-        # retry for rate limit errors (max of 30 minutes)
+        # retry for transient http errors:
+        # - no default timeout or max_retries (try forever)
+        # - exponential backoff starting at 3 seconds (will wait 25 minutes
+        #   on the 10th retry,then will wait no longer than 30 minutes on
+        #   subsequent retries)
+        if config.max_retries is not None and config.timeout is not None:
+            stop: StopBaseT = stop_after_attempt(config.max_retries) | stop_after_delay(
+                config.timeout
+            )
+        elif config.max_retries is not None:
+            stop = stop_after_attempt(config.max_retries)
+        elif config.timeout is not None:
+            stop = stop_after_delay(config.timeout)
+        else:
+            stop = stop_never
+
         @retry(
-            wait=wait_exponential_jitter(max=(30 * 60), jitter=5),
-            retry=retry_if_exception(self.api.is_rate_limit),
-            stop=(
-                (
-                    stop_after_delay(config.timeout)
-                    | stop_after_attempt(config.max_retries)
-                )
-                if config.timeout and config.max_retries
-                else (
-                    stop_after_delay(config.timeout)
-                    if config.timeout
-                    else (
-                        stop_after_attempt(config.max_retries)
-                        if config.max_retries
-                        else stop_never
-                    )
-                )
-            ),
-            before_sleep=functools.partial(log_rate_limit_retry, self.api.model_name),
+            wait=wait_exponential_jitter(initial=3, max=(30 * 60), jitter=3),
+            retry=retry_if_exception(self.should_retry),
+            stop=stop,
+            before_sleep=functools.partial(log_model_retry, self.api.model_name),
         )
         async def generate() -> ModelOutput:
             check_sample_interrupt()
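The rewrite is possible because tenacity stop conditions compose with the | operator, so the nested conditional expression collapses into one precomputed stop value. A minimal standalone demonstration of the same composition (toy function, not inspect code):

from tenacity import (
    retry,
    stop_after_attempt,
    stop_after_delay,
    wait_exponential_jitter,
)

# stop when EITHER bound is reached: 5 attempts or 60 seconds of trying
stop = stop_after_attempt(5) | stop_after_delay(60)


@retry(wait=wait_exponential_jitter(initial=3, max=30), stop=stop)
def flaky_call() -> None:
    # always fails: tenacity retries, then raises RetryError once stop is met
    raise RuntimeError("transient failure")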
@@ -555,6 +561,30 @@ class Model:
         # return results
         return model_output
 
+    def should_retry(self, ex: BaseException) -> bool:
+        if isinstance(ex, Exception):
+            # check standard should_retry() method
+            retry = self.api.should_retry(ex)
+            if retry:
+                report_http_retry()
+                return True
+
+            # see if the API implements legacy is_rate_limit() method
+            is_rate_limit = getattr(self.api, "is_rate_limit", None)
+            if is_rate_limit:
+                warn_once(
+                    logger,
+                    f"provider '{self.name}' implements deprecated is_rate_limit() method, "
+                    + "please change to should_retry()",
+                )
+                retry = cast(bool, is_rate_limit(ex))
+                if retry:
+                    report_http_retry()
+                    return True
+
+        # no retry
+        return False
+
     # function to verify that its okay to call model apis
     def verify_model_apis(self) -> None:
         if (
@@ -1064,6 +1094,7 @@ def tool_result_images_reducer(
             messages
             + [
                 ChatMessageTool(
+                    id=message.id,
                     content=edited_tool_message_content,
                     tool_call_id=message.tool_call_id,
                     function=message.function,
@@ -1170,19 +1201,26 @@ def combine_messages(
     a: ChatMessage, b: ChatMessage, message_type: Type[ChatMessage]
 ) -> ChatMessage:
     if isinstance(a.content, str) and isinstance(b.content, str):
-        return message_type(content=f"{a.content}\n{b.content}")
+        return message_type(id=a.id, content=f"{a.content}\n{b.content}")
     elif isinstance(a.content, list) and isinstance(b.content, list):
-        return message_type(content=a.content + b.content)
+        return message_type(id=a.id, content=a.content + b.content)
     elif isinstance(a.content, str) and isinstance(b.content, list):
-        return message_type(content=[ContentText(text=a.content), *b.content])
+        return message_type(id=a.id, content=[ContentText(text=a.content), *b.content])
     elif isinstance(a.content, list) and isinstance(b.content, str):
-        return message_type(content=a.content + [ContentText(text=b.content)])
+        return message_type(id=a.id, content=a.content + [ContentText(text=b.content)])
     else:
         raise TypeError(
             f"Cannot combine messages with invalid content types: {a.content!r}, {b.content!r}"
         )
 
 
+def log_model_retry(model_name: str, retry_state: RetryCallState) -> None:
+    logger.log(
+        HTTP,
+        f"-> {model_name} retry {retry_state.attempt_number} after waiting for {retry_state.idle_for}",
+    )
+
+
 def init_active_model(model: Model, config: GenerateConfig) -> None:
     active_model_context_var.set(model)
     set_active_generate_config(config)
inspect_ai/model/_openai.py
@@ -52,7 +52,7 @@ from ._model_output import ModelUsage, StopReason, as_stop_reason
 
 
 def is_o_series(name: str) -> bool:
-    return bool(re.match(r"^o\d+", name))
+    return bool(re.match(r"(^|.*\/)o\d+", name))
 
 
 def is_o1_mini(name: str) -> bool:
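The widened pattern also accepts provider-prefixed model names. A quick standalone check of the behavior change (the assertions reflect the regex semantics, not an inspect test suite):

import re


def is_o_series(name: str) -> bool:
    return bool(re.match(r"(^|.*\/)o\d+", name))


assert is_o_series("o1")                 # matched by old and new patterns
assert is_o_series("o3-mini")
assert is_o_series("azure/o1-mini")      # only the new pattern matches this
assert not is_o_series("gpt-4o")         # "o" not at the start of a segment
assert not is_o_series("openai/gpt-4o")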
@@ -396,6 +396,9 @@ def content_from_openai(
     content: ChatCompletionContentPartParam | ChatCompletionContentPartRefusalParam,
     parse_reasoning: bool = False,
 ) -> list[Content]:
+    # Some providers omit the type tag and use "object-with-a-single-field" encoding
+    if "type" not in content and len(content) == 1:
+        content["type"] = list(content.keys())[0]  # type: ignore[arg-type]
     if content["type"] == "text":
         text = content["text"]
         if parse_reasoning:
@@ -413,6 +416,8 @@ def content_from_openai(
             return [ContentText(text=text)]
         else:
             return [ContentText(text=text)]
+    elif content["type"] == "reasoning":  # type: ignore[comparison-overlap]
+        return [ContentReasoning(reasoning=content["reasoning"])]
     elif content["type"] == "image_url":
         return [
             ContentImage(
@@ -428,6 +433,9 @@ def content_from_openai(
         ]
     elif content["type"] == "refusal":
         return [ContentText(text=content["refusal"])]
+    else:
+        content_type = content["type"]
+        raise ValueError(f"Unexpected content type '{content_type}' in message.")
 
 
 def chat_message_assistant_from_openai(
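The first hunk above normalizes content parts from providers that send a single-field object such as {"reasoning": "..."} with no "type" tag. A standalone sketch of that normalization on plain dicts (toy data, not real provider output):

# toy content parts shaped like the payloads the hunk handles
parts = [
    {"type": "text", "text": "Hello"},
    {"reasoning": "chain of thought..."},  # single field, no "type" tag
]

for content in parts:
    # infer the type tag from the sole key when it is omitted
    if "type" not in content and len(content) == 1:
        content["type"] = list(content.keys())[0]

assert parts[1]["type"] == "reasoning"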
inspect_ai/model/_providers/anthropic.py
@@ -6,7 +6,12 @@ from copy import copy
 from logging import getLogger
 from typing import Any, Literal, Optional, Tuple, TypedDict, cast
 
-from .util.tracker import HttpxTimeTracker
+import httpcore
+import httpx
+
+from inspect_ai._util.http import is_retryable_http_status
+
+from .util.hooks import HttpxHooks
 
 if sys.version_info >= (3, 11):
     from typing import NotRequired
@@ -16,13 +21,12 @@ else:
 from anthropic import (
     APIConnectionError,
     APIStatusError,
+    APITimeoutError,
     AsyncAnthropic,
     AsyncAnthropicBedrock,
     AsyncAnthropicVertex,
     BadRequestError,
-    InternalServerError,
     NotGiven,
-    RateLimitError,
 )
 from anthropic._types import Body
 from anthropic.types import (
@@ -46,7 +50,6 @@ from typing_extensions import override
 
 from inspect_ai._util.constants import (
     BASE_64_DATA_REMOVED,
-    DEFAULT_MAX_RETRIES,
     NO_CONTENT,
 )
 from inspect_ai._util.content import (
@@ -125,9 +128,6 @@ class AnthropicAPI(ModelAPI):
                 AsyncAnthropic | AsyncAnthropicBedrock | AsyncAnthropicVertex
             ) = AsyncAnthropicBedrock(
                 base_url=base_url,
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 aws_region=aws_region,
                 **model_args,
             )
@@ -141,9 +141,6 @@ class AnthropicAPI(ModelAPI):
                 region=region,
                 project_id=project_id,
                 base_url=base_url,
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 **model_args,
             )
         else:
@@ -156,14 +153,11 @@ class AnthropicAPI(ModelAPI):
             self.client = AsyncAnthropic(
                 base_url=base_url,
                 api_key=self.api_key,
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 **model_args,
             )
 
         # create time tracker
-        self._time_tracker = HttpxTimeTracker(self.client._client)
+        self._http_hooks = HttpxHooks(self.client._client)
 
     @override
     async def close(self) -> None:
@@ -183,7 +177,7 @@ class AnthropicAPI(ModelAPI):
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # allocate request_id (so we can see it from ModelCall)
-        request_id = self._time_tracker.start_request()
+        request_id = self._http_hooks.start_request()
 
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
@@ -194,7 +188,7 @@ class AnthropicAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=model_call_filter,
-                time=self._time_tracker.end_request(request_id),
+                time=self._http_hooks.end_request(request_id),
             )
 
         # generate
@@ -223,7 +217,7 @@ class AnthropicAPI(ModelAPI):
             request = request | req
 
             # extra headers (for time tracker and computer use)
-            extra_headers = headers | {HttpxTimeTracker.REQUEST_ID_HEADER: request_id}
+            extra_headers = headers | {HttpxHooks.REQUEST_ID_HEADER: request_id}
             if computer_use:
                 betas.append("computer-use-2025-01-24")
             if len(betas) > 0:
@@ -291,8 +285,6 @@ class AnthropicAPI(ModelAPI):
             betas.append("output-128k-2025-02-19")
 
         # config that applies to all models
-        if config.timeout is not None:
-            params["timeout"] = float(config.timeout)
         if config.stop_seqs is not None:
             params["stop_sequences"] = config.stop_seqs
 
@@ -334,13 +326,19 @@ class AnthropicAPI(ModelAPI):
         return str(self.api_key)
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        # We have observed that anthropic will frequently return InternalServerError
-        # seemingly in place of RateLimitError (at the very least the errors seem to
-        # always be transient). Equating this to rate limit errors may occasionally
-        # result in retrying too many times, but much more often will avert a failed
-        # eval that just needed to survive a transient error
-        return isinstance(ex, RateLimitError | InternalServerError | APIConnectionError)
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, APIStatusError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(
+            ex,
+            APIConnectionError
+            | APITimeoutError
+            | httpx.RemoteProtocolError
+            | httpcore.RemoteProtocolError,
+        ):
+            return True
+        else:
+            return False
 
     @override
     def collapse_user_messages(self) -> bool:
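Several providers in this release delegate status-code handling to is_retryable_http_status() from inspect_ai/_util/http.py, whose new body is not included in this diff. A plausible sketch, inferred from the statuses the replaced per-provider checks handled (treat the exact status set as an assumption):

def is_retryable_http_status(status_code: int) -> bool:
    # assumption: timeouts (408), conflicts (409), and rate limits (429)
    # are transient, as are server-side (5xx) errors generally
    return status_code in (408, 409, 429) or status_code >= 500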
inspect_ai/model/_providers/azureai.py
@@ -27,11 +27,16 @@ from azure.ai.inference.models import (
     UserMessage,
 )
 from azure.core.credentials import AzureKeyCredential
-from azure.core.exceptions import AzureError, HttpResponseError
+from azure.core.exceptions import (
+    AzureError,
+    HttpResponseError,
+    ServiceResponseError,
+)
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
@@ -232,14 +237,11 @@ class AzureAIAPI(ModelAPI):
         return DEFAULT_MAX_TOKENS
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        if isinstance(ex, HttpResponseError):
-            return (
-                ex.status_code == 408
-                or ex.status_code == 409
-                or ex.status_code == 429
-                or ex.status_code == 500
-            )
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, HttpResponseError) and ex.status_code is not None:
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, ServiceResponseError):
+            return True
         else:
             return False
 
inspect_ai/model/_providers/bedrock.py
@@ -1,16 +1,14 @@
 import base64
+from logging import getLogger
 from typing import Any, Literal, Tuple, Union, cast
 
 from pydantic import BaseModel, Field
 from typing_extensions import override
 
-from inspect_ai._util.constants import (
-    DEFAULT_MAX_RETRIES,
-    DEFAULT_MAX_TOKENS,
-    DEFAULT_TIMEOUT,
-)
+from inspect_ai._util._async import current_async_backend
+from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.error import pip_dependency_error
+from inspect_ai._util.error import PrerequisiteError, pip_dependency_error
 from inspect_ai._util.images import file_as_data
 from inspect_ai._util.version import verify_required_version
 from inspect_ai.tool import ToolChoice, ToolInfo
@@ -31,7 +29,9 @@ from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .util import (
     model_base_url,
 )
-from .util.tracker import BotoTimeTracker
+from .util.hooks import ConverseHooks
+
+logger = getLogger(__name__)
 
 # Model for Bedrock Converse API (Response)
 # generated from: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html#converse
@@ -245,6 +245,12 @@ class BedrockAPI(ModelAPI):
             config=config,
         )
 
+        # raise if we are using trio
+        if current_async_backend() == "trio":
+            raise PrerequisiteError(
+                "ERROR: The bedrock provider does not work with the trio async backend."
+            )
+
         # save model_args
         self.model_args = model_args
 
@@ -258,7 +264,7 @@ class BedrockAPI(ModelAPI):
             self.session = aioboto3.Session()
 
             # create time tracker
-            self._time_tracker = BotoTimeTracker(self.session)
+            self._http_hooks = ConverseHooks(self.session)
 
         except ImportError:
             raise pip_dependency_error("Bedrock API", ["aioboto3"])
@@ -288,15 +294,25 @@ class BedrockAPI(ModelAPI):
         return DEFAULT_MAX_TOKENS
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
+    def should_retry(self, ex: Exception) -> bool:
         from botocore.exceptions import ClientError
 
         # Look for an explicit throttle exception
         if isinstance(ex, ClientError):
-            if ex.response["Error"]["Code"] == "ThrottlingException":
-                return True
-
-        return super().is_rate_limit(ex)
+            error_code = ex.response.get("Error", {}).get("Code", "")
+            return error_code in [
+                "ThrottlingException",
+                "RequestLimitExceeded",
+                "Throttling",
+                "RequestThrottled",
+                "TooManyRequestsException",
+                "ProvisionedThroughputExceededException",
+                "TransactionInProgressException",
+                "RequestTimeout",
+                "ServiceUnavailable",
+            ]
+        else:
+            return False
 
     @override
     def collapse_user_messages(self) -> bool:
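The defensive ex.response.get(...) chain above works because botocore's ClientError exposes the parsed error payload on .response. A small illustration with a synthetic error (the payload dict is hand-built to match botocore's documented shape):

from botocore.exceptions import ClientError

# hand-built error_response in the shape botocore produces
err = ClientError(
    error_response={"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}},
    operation_name="Converse",
)

assert err.response.get("Error", {}).get("Code", "") == "ThrottlingException"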
@@ -317,20 +333,13 @@ class BedrockAPI(ModelAPI):
         from botocore.exceptions import ClientError
 
         # The bedrock client
-        request_id = self._time_tracker.start_request()
+        request_id = self._http_hooks.start_request()
         async with self.session.client(  # type: ignore[call-overload]
             service_name="bedrock-runtime",
             endpoint_url=self.base_url,
             config=Config(
-                connect_timeout=config.timeout if config.timeout else DEFAULT_TIMEOUT,
-                read_timeout=config.timeout if config.timeout else DEFAULT_TIMEOUT,
-                retries=dict(
-                    max_attempts=config.max_retries
-                    if config.max_retries
-                    else DEFAULT_MAX_RETRIES,
-                    mode="adaptive",
-                ),
-                user_agent_extra=self._time_tracker.user_agent_extra(request_id),
+                retries=dict(mode="adaptive"),
+                user_agent_extra=self._http_hooks.user_agent_extra(request_id),
             ),
             **self.model_args,
         ) as client:
@@ -370,7 +379,7 @@ class BedrockAPI(ModelAPI):
                 request.model_dump(exclude_none=True)
             ),
             response=response,
-            time=self._time_tracker.end_request(request_id),
+            time=self._http_hooks.end_request(request_id),
         )
 
         try:
inspect_ai/model/_providers/cloudflare.py
@@ -16,10 +16,10 @@ from .util import (
     chat_api_input,
     chat_api_request,
     environment_prerequisite_error,
-    is_chat_api_rate_limit,
     model_base_url,
+    should_retry_chat_api_error,
 )
-from .util.tracker import HttpxTimeTracker
+from .util.hooks import HttpxHooks
 
 # https://developers.cloudflare.com/workers-ai/models/#text-generation
 
@@ -51,7 +51,7 @@ class CloudFlareAPI(ModelAPI):
         if not self.api_key:
             raise environment_prerequisite_error("CloudFlare", CLOUDFLARE_API_TOKEN)
         self.client = httpx.AsyncClient()
-        self._time_tracker = HttpxTimeTracker(self.client)
+        self._http_hooks = HttpxHooks(self.client)
         base_url = model_base_url(base_url, "CLOUDFLARE_BASE_URL")
         self.base_url = (
             base_url if base_url else "https://api.cloudflare.com/client/v4/accounts"
@@ -79,7 +79,7 @@ class CloudFlareAPI(ModelAPI):
         json["messages"] = chat_api_input(input, tools, self.chat_api_handler())
 
         # request_id
-        request_id = self._time_tracker.start_request()
+        request_id = self._http_hooks.start_request()
 
         # setup response
         response: dict[str, Any] = {}
@@ -88,7 +88,7 @@ class CloudFlareAPI(ModelAPI):
             return ModelCall.create(
                 request=json,
                 response=response,
-                time=self._time_tracker.end_request(request_id),
+                time=self._http_hooks.end_request(request_id),
             )
 
         # make the call
@@ -98,10 +98,9 @@ class CloudFlareAPI(ModelAPI):
             url=f"{chat_url}/{self.model_name}",
             headers={
                 "Authorization": f"Bearer {self.api_key}",
-                HttpxTimeTracker.REQUEST_ID_HEADER: request_id,
+                HttpxHooks.REQUEST_ID_HEADER: request_id,
             },
             json=json,
-            config=config,
         )
 
         # handle response
@@ -127,8 +126,8 @@ class CloudFlareAPI(ModelAPI):
         raise RuntimeError(f"Error calling {self.model_name}: {error}")
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        return is_chat_api_rate_limit(ex)
+    def should_retry(self, ex: Exception) -> bool:
+        return should_retry_chat_api_error(ex)
 
     # cloudflare enforces rate limits by model for each account
     @override
inspect_ai/model/_providers/goodfire.py
@@ -3,7 +3,11 @@ from typing import Any, List, Literal, get_args
 
 from goodfire import AsyncClient
 from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
-from goodfire.api.exceptions import InvalidRequestException, RateLimitException
+from goodfire.api.exceptions import (
+    InvalidRequestException,
+    RateLimitException,
+    ServerErrorException,
+)
 from goodfire.variants.variants import SUPPORTED_MODELS, Variant
 from typing_extensions import override
 
@@ -163,9 +167,9 @@ class GoodfireAPI(ModelAPI):
         return ex
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
+    def should_retry(self, ex: Exception) -> bool:
         """Check if exception is due to rate limiting."""
-        return isinstance(ex, RateLimitException)
+        return isinstance(ex, RateLimitException | ServerErrorException)
 
     @override
     def connection_key(self) -> str: