inspect-ai 0.3.71__py3-none-any.whl → 0.3.73__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
Files changed (114)
  1. inspect_ai/_cli/eval.py +14 -3
  2. inspect_ai/_cli/sandbox.py +3 -3
  3. inspect_ai/_cli/score.py +6 -4
  4. inspect_ai/_cli/trace.py +53 -6
  5. inspect_ai/_display/core/config.py +1 -1
  6. inspect_ai/_display/core/display.py +2 -1
  7. inspect_ai/_display/core/footer.py +6 -6
  8. inspect_ai/_display/plain/display.py +11 -6
  9. inspect_ai/_display/rich/display.py +23 -13
  10. inspect_ai/_display/textual/app.py +10 -9
  11. inspect_ai/_display/textual/display.py +2 -2
  12. inspect_ai/_display/textual/widgets/footer.py +4 -0
  13. inspect_ai/_display/textual/widgets/samples.py +14 -5
  14. inspect_ai/_eval/context.py +1 -2
  15. inspect_ai/_eval/eval.py +54 -41
  16. inspect_ai/_eval/loader.py +9 -2
  17. inspect_ai/_eval/run.py +148 -81
  18. inspect_ai/_eval/score.py +13 -8
  19. inspect_ai/_eval/task/images.py +31 -21
  20. inspect_ai/_eval/task/run.py +62 -59
  21. inspect_ai/_eval/task/rundir.py +16 -9
  22. inspect_ai/_eval/task/sandbox.py +7 -8
  23. inspect_ai/_eval/task/util.py +7 -0
  24. inspect_ai/_util/_async.py +118 -10
  25. inspect_ai/_util/constants.py +0 -2
  26. inspect_ai/_util/file.py +15 -29
  27. inspect_ai/_util/future.py +37 -0
  28. inspect_ai/_util/http.py +3 -99
  29. inspect_ai/_util/httpx.py +60 -0
  30. inspect_ai/_util/interrupt.py +2 -2
  31. inspect_ai/_util/json.py +5 -52
  32. inspect_ai/_util/logger.py +30 -86
  33. inspect_ai/_util/retry.py +10 -61
  34. inspect_ai/_util/trace.py +2 -2
  35. inspect_ai/_view/server.py +86 -3
  36. inspect_ai/_view/www/dist/assets/index.js +25837 -13269
  37. inspect_ai/_view/www/log-schema.json +253 -186
  38. inspect_ai/_view/www/package.json +2 -2
  39. inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
  40. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
  41. inspect_ai/_view/www/src/types/log.d.ts +122 -94
  42. inspect_ai/approval/_human/manager.py +6 -10
  43. inspect_ai/approval/_human/panel.py +2 -2
  44. inspect_ai/dataset/_sources/util.py +7 -6
  45. inspect_ai/log/__init__.py +4 -0
  46. inspect_ai/log/_file.py +35 -61
  47. inspect_ai/log/_log.py +18 -1
  48. inspect_ai/log/_recorders/eval.py +14 -23
  49. inspect_ai/log/_recorders/json.py +3 -18
  50. inspect_ai/log/_samples.py +27 -2
  51. inspect_ai/log/_transcript.py +8 -8
  52. inspect_ai/model/__init__.py +2 -1
  53. inspect_ai/model/_call_tools.py +60 -40
  54. inspect_ai/model/_chat_message.py +3 -2
  55. inspect_ai/model/_generate_config.py +25 -0
  56. inspect_ai/model/_model.py +74 -36
  57. inspect_ai/model/_openai.py +9 -1
  58. inspect_ai/model/_providers/anthropic.py +172 -154
  59. inspect_ai/model/_providers/azureai.py +11 -9
  60. inspect_ai/model/_providers/bedrock.py +33 -24
  61. inspect_ai/model/_providers/cloudflare.py +8 -9
  62. inspect_ai/model/_providers/goodfire.py +7 -3
  63. inspect_ai/model/_providers/google.py +47 -13
  64. inspect_ai/model/_providers/groq.py +15 -15
  65. inspect_ai/model/_providers/hf.py +24 -17
  66. inspect_ai/model/_providers/mistral.py +36 -20
  67. inspect_ai/model/_providers/openai.py +30 -25
  68. inspect_ai/model/_providers/openai_o1.py +1 -1
  69. inspect_ai/model/_providers/providers.py +1 -1
  70. inspect_ai/model/_providers/together.py +3 -4
  71. inspect_ai/model/_providers/util/__init__.py +2 -2
  72. inspect_ai/model/_providers/util/chatapi.py +6 -19
  73. inspect_ai/model/_providers/util/hooks.py +165 -0
  74. inspect_ai/model/_providers/vertex.py +20 -3
  75. inspect_ai/model/_providers/vllm.py +16 -19
  76. inspect_ai/scorer/_multi.py +5 -2
  77. inspect_ai/solver/_bridge/patch.py +31 -1
  78. inspect_ai/solver/_fork.py +5 -3
  79. inspect_ai/solver/_human_agent/agent.py +3 -2
  80. inspect_ai/tool/__init__.py +8 -2
  81. inspect_ai/tool/_tool_info.py +4 -90
  82. inspect_ai/tool/_tool_params.py +4 -34
  83. inspect_ai/tool/_tools/_computer/_common.py +117 -58
  84. inspect_ai/tool/_tools/_computer/_computer.py +80 -57
  85. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +7 -1
  86. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +91 -0
  87. inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +8 -0
  88. inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +12 -0
  89. inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +78 -0
  90. inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +20 -0
  91. inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +175 -113
  92. inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +76 -20
  93. inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +65 -0
  94. inspect_ai/tool/_tools/_computer/test_args.py +151 -0
  95. inspect_ai/tool/_tools/_web_search.py +30 -24
  96. inspect_ai/util/__init__.py +4 -0
  97. inspect_ai/util/_concurrency.py +5 -6
  98. inspect_ai/util/_display.py +6 -0
  99. inspect_ai/util/_json.py +170 -0
  100. inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
  101. inspect_ai/util/_sandbox/docker/docker.py +5 -0
  102. inspect_ai/util/_sandbox/environment.py +56 -9
  103. inspect_ai/util/_sandbox/service.py +12 -5
  104. inspect_ai/util/_subprocess.py +94 -113
  105. inspect_ai/util/_subtask.py +2 -4
  106. {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
  107. {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +111 -103
  108. {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
  109. inspect_ai/_util/timeouts.py +0 -160
  110. inspect_ai/model/_providers/util/tracker.py +0 -92
  111. inspect_ai/tool/_tools/_computer/_computer_split.py +0 -198
  112. {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
  113. {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
  114. {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/azureai.py

@@ -27,11 +27,16 @@ from azure.ai.inference.models import (
     UserMessage,
 )
 from azure.core.credentials import AzureKeyCredential
-from azure.core.exceptions import AzureError, HttpResponseError
+from azure.core.exceptions import (
+    AzureError,
+    HttpResponseError,
+    ServiceResponseError,
+)
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
@@ -232,14 +237,11 @@ class AzureAIAPI(ModelAPI):
         return DEFAULT_MAX_TOKENS
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        if isinstance(ex, HttpResponseError):
-            return (
-                ex.status_code == 408
-                or ex.status_code == 409
-                or ex.status_code == 429
-                or ex.status_code == 500
-            )
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, HttpResponseError) and ex.status_code is not None:
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, ServiceResponseError):
+            return True
         else:
             return False
 
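A pattern that repeats across every provider below: the narrow is_rate_limit(ex: BaseException) hook becomes a broader should_retry(ex: Exception) hook, with HTTP status classification delegated to the new shared is_retryable_http_status helper in inspect_ai/_util/http.py. The diff does not show that helper's body, so the following is only a plausible sketch of such a predicate assuming the conventional retryable set; the shipped status list may differ.

# Hypothetical sketch of the shared helper; the real implementation is in
# inspect_ai/_util/http.py and its exact status list may differ.
def is_retryable_http_status(status_code: int) -> bool:
    # 408 (request timeout), 409 (conflict), and 429 (too many requests) are
    # transient; anything >= 500 signals a server-side failure worth retrying
    return status_code in (408, 409, 429) or status_code >= 500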
inspect_ai/model/_providers/bedrock.py

@@ -1,16 +1,14 @@
 import base64
+from logging import getLogger
 from typing import Any, Literal, Tuple, Union, cast
 
 from pydantic import BaseModel, Field
 from typing_extensions import override
 
-from inspect_ai._util.constants import (
-    DEFAULT_MAX_RETRIES,
-    DEFAULT_MAX_TOKENS,
-    DEFAULT_TIMEOUT,
-)
+from inspect_ai._util._async import current_async_backend
+from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.error import pip_dependency_error
+from inspect_ai._util.error import PrerequisiteError, pip_dependency_error
 from inspect_ai._util.images import file_as_data
 from inspect_ai._util.version import verify_required_version
 from inspect_ai.tool import ToolChoice, ToolInfo
@@ -31,7 +29,9 @@ from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .util import (
     model_base_url,
 )
-from .util.tracker import BotoTimeTracker
+from .util.hooks import ConverseHooks
+
+logger = getLogger(__name__)
 
 # Model for Bedrock Converse API (Response)
 # generated from: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html#converse
@@ -245,6 +245,12 @@ class BedrockAPI(ModelAPI):
             config=config,
         )
 
+        # raise if we are using trio
+        if current_async_backend() == "trio":
+            raise PrerequisiteError(
+                "ERROR: The bedrock provider does not work with the trio async backend."
+            )
+
         # save model_args
         self.model_args = model_args
 
@@ -258,7 +264,7 @@
             self.session = aioboto3.Session()
 
             # create time tracker
-            self._time_tracker = BotoTimeTracker(self.session)
+            self._http_hooks = ConverseHooks(self.session)
 
         except ImportError:
             raise pip_dependency_error("Bedrock API", ["aioboto3"])
@@ -288,15 +294,25 @@
         return DEFAULT_MAX_TOKENS
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
+    def should_retry(self, ex: Exception) -> bool:
         from botocore.exceptions import ClientError
 
         # Look for an explicit throttle exception
         if isinstance(ex, ClientError):
-            if ex.response["Error"]["Code"] == "ThrottlingException":
-                return True
-
-        return super().is_rate_limit(ex)
+            error_code = ex.response.get("Error", {}).get("Code", "")
+            return error_code in [
+                "ThrottlingException",
+                "RequestLimitExceeded",
+                "Throttling",
+                "RequestThrottled",
+                "TooManyRequestsException",
+                "ProvisionedThroughputExceededException",
+                "TransactionInProgressException",
+                "RequestTimeout",
+                "ServiceUnavailable",
+            ]
+        else:
+            return False
 
     @override
     def collapse_user_messages(self) -> bool:
@@ -317,20 +333,13 @@
         from botocore.exceptions import ClientError
 
         # The bedrock client
-        request_id = self._time_tracker.start_request()
+        request_id = self._http_hooks.start_request()
         async with self.session.client(  # type: ignore[call-overload]
             service_name="bedrock-runtime",
             endpoint_url=self.base_url,
             config=Config(
-                connect_timeout=config.timeout if config.timeout else DEFAULT_TIMEOUT,
-                read_timeout=config.timeout if config.timeout else DEFAULT_TIMEOUT,
-                retries=dict(
-                    max_attempts=config.max_retries
-                    if config.max_retries
-                    else DEFAULT_MAX_RETRIES,
-                    mode="adaptive",
-                ),
-                user_agent_extra=self._time_tracker.user_agent_extra(request_id),
+                retries=dict(mode="adaptive"),
+                user_agent_extra=self._http_hooks.user_agent_extra(request_id),
             ),
             **self.model_args,
         ) as client:
@@ -370,7 +379,7 @@
                     request.model_dump(exclude_none=True)
                 ),
                 response=response,
-                time=self._time_tracker.end_request(request_id),
+                time=self._http_hooks.end_request(request_id),
             )
 
         try:
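The new constructor guard depends on current_async_backend() from inspect_ai/_util/_async.py (expanded by 118 lines in this release) to reject the trio backend, since aioboto3 is asyncio-only. A minimal sketch of such a probe, assuming it is built on the sniffio library that anyio itself uses for backend detection:

# Hypothetical sketch of a backend probe like current_async_backend();
# the real helper lives in inspect_ai/_util/_async.py and may differ.
import sniffio


def current_async_backend() -> str | None:
    try:
        # returns "asyncio" or "trio" when called inside a running event loop
        return sniffio.current_async_library()
    except sniffio.AsyncLibraryNotFoundError:
        return None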
inspect_ai/model/_providers/cloudflare.py

@@ -16,10 +16,10 @@ from .util import (
     chat_api_input,
     chat_api_request,
     environment_prerequisite_error,
-    is_chat_api_rate_limit,
     model_base_url,
+    should_retry_chat_api_error,
 )
-from .util.tracker import HttpxTimeTracker
+from .util.hooks import HttpxHooks
 
 # https://developers.cloudflare.com/workers-ai/models/#text-generation
 
@@ -51,7 +51,7 @@ class CloudFlareAPI(ModelAPI):
         if not self.api_key:
             raise environment_prerequisite_error("CloudFlare", CLOUDFLARE_API_TOKEN)
         self.client = httpx.AsyncClient()
-        self._time_tracker = HttpxTimeTracker(self.client)
+        self._http_hooks = HttpxHooks(self.client)
         base_url = model_base_url(base_url, "CLOUDFLARE_BASE_URL")
         self.base_url = (
             base_url if base_url else "https://api.cloudflare.com/client/v4/accounts"
@@ -79,7 +79,7 @@
         json["messages"] = chat_api_input(input, tools, self.chat_api_handler())
 
         # request_id
-        request_id = self._time_tracker.start_request()
+        request_id = self._http_hooks.start_request()
 
         # setup response
         response: dict[str, Any] = {}
@@ -88,7 +88,7 @@
             return ModelCall.create(
                 request=json,
                 response=response,
-                time=self._time_tracker.end_request(request_id),
+                time=self._http_hooks.end_request(request_id),
             )
 
         # make the call
@@ -98,10 +98,9 @@
             url=f"{chat_url}/{self.model_name}",
             headers={
                 "Authorization": f"Bearer {self.api_key}",
-                HttpxTimeTracker.REQUEST_ID_HEADER: request_id,
+                HttpxHooks.REQUEST_ID_HEADER: request_id,
             },
             json=json,
-            config=config,
         )
 
         # handle response
@@ -127,8 +126,8 @@
             raise RuntimeError(f"Error calling {self.model_name}: {error}")
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        return is_chat_api_rate_limit(ex)
+    def should_retry(self, ex: Exception) -> bool:
+        return should_retry_chat_api_error(ex)
 
     # cloudflare enforces rate limits by model for each account
     @override
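The HttpxTimeTracker to HttpxHooks rename (here and in the groq provider below) points at the new model/_providers/util/hooks.py module (+165 lines): allocate a request ID, send it through a header, and let httpx event hooks attribute timing to that ID. A minimal sketch of the pattern, with internals assumed rather than copied from the real class:

# Sketch of header-keyed request timing via httpx event hooks; the header
# name and internals of the real HttpxHooks class are assumptions here.
import time
import uuid

import httpx


class HttpxHooks:
    REQUEST_ID_HEADER = "x-inspect-request-id"  # assumed header name

    def __init__(self, client: httpx.AsyncClient) -> None:
        self._start_times: dict[str, float] = {}
        client.event_hooks["request"].append(self.on_request)

    def start_request(self) -> str:
        # allocate an ID the caller will place in the request headers
        return uuid.uuid4().hex

    async def on_request(self, request: httpx.Request) -> None:
        # record the start time keyed by the smuggled request ID
        request_id = request.headers.get(self.REQUEST_ID_HEADER)
        if request_id is not None:
            self._start_times[request_id] = time.monotonic()

    def end_request(self, request_id: str) -> float:
        # elapsed seconds since the request hit the wire (0.0 if never sent)
        start = self._start_times.pop(request_id, None)
        return 0.0 if start is None else time.monotonic() - start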
inspect_ai/model/_providers/goodfire.py

@@ -3,7 +3,11 @@ from typing import Any, List, Literal, get_args
 
 from goodfire import AsyncClient
 from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
-from goodfire.api.exceptions import InvalidRequestException, RateLimitException
+from goodfire.api.exceptions import (
+    InvalidRequestException,
+    RateLimitException,
+    ServerErrorException,
+)
 from goodfire.variants.variants import SUPPORTED_MODELS, Variant
 from typing_extensions import override
 
@@ -163,9 +167,9 @@ class GoodfireAPI(ModelAPI):
         return ex
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
+    def should_retry(self, ex: Exception) -> bool:
         """Check if exception is due to rate limiting."""
-        return isinstance(ex, RateLimitException)
+        return isinstance(ex, RateLimitException | ServerErrorException)
 
     @override
     def connection_key(self) -> str:
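One small note on the rewritten check: isinstance(ex, RateLimitException | ServerErrorException) passes a PEP 604 union directly to isinstance, which Python supports from 3.10 onward and which is equivalent to the older tuple form:

# isinstance accepts X | Y unions on Python 3.10+ (PEP 604)
class RateLimit(Exception): ...
class ServerError(Exception): ...

ex: Exception = RateLimit()
assert isinstance(ex, RateLimit | ServerError)
assert isinstance(ex, (RateLimit, ServerError))  # equivalent tuple spelling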
inspect_ai/model/_providers/google.py

@@ -1,4 +1,3 @@
-import asyncio
 import functools
 import hashlib
 import json
@@ -9,6 +8,7 @@ from logging import getLogger
 from typing import Any
 
 # SDK Docs: https://googleapis.github.io/python-genai/
+import anyio
 from google.genai import Client  # type: ignore
 from google.genai.errors import APIError, ClientError  # type: ignore
 from google.genai.types import (  # type: ignore
@@ -26,6 +26,7 @@ from google.genai.types import (  # type: ignore
     GenerationConfig,
     HarmBlockThreshold,
     HarmCategory,
+    HttpOptions,
     Part,
     SafetySetting,
     SafetySettingDict,
@@ -49,6 +50,7 @@ from inspect_ai._util.content import (
     ContentVideo,
 )
 from inspect_ai._util.error import PrerequisiteError
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data
 from inspect_ai._util.kvstore import inspect_kvstore
 from inspect_ai._util.trace import trace_message
@@ -69,6 +71,7 @@ from inspect_ai.model import (
 )
 from inspect_ai.model._model_call import ModelCall
 from inspect_ai.model._providers.util import model_base_url
+from inspect_ai.model._providers.util.hooks import HttpHooks, urllib3_hooks
 from inspect_ai.tool import (
     ToolCall,
     ToolChoice,
@@ -199,11 +202,15 @@ class GoogleGenAIAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+        # generate request_id
+        request_id = urllib3_hooks().start_request()
+
         # Create google-genai types.
         gemini_contents = await as_chat_messages(self.client, input)
         gemini_tools = chat_tools(tools) if len(tools) > 0 else None
        gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None
         parameters = GenerateContentConfig(
+            http_options=HttpOptions(headers={HttpHooks.REQUEST_ID_HEADER: request_id}),
             temperature=config.temperature,
             top_p=config.top_p,
             top_k=config.top_k,
@@ -219,6 +226,11 @@
                 self.client, input
             ),
         )
+        if config.response_schema is not None:
+            parameters.response_mime_type = "application/json"
+            parameters.response_schema = schema_from_param(
+                config.response_schema.json_schema, nullable=None
+            )
 
         response: GenerateContentResponse | None = None
 
@@ -230,10 +242,9 @@
                 tools=gemini_tools,
                 tool_config=gemini_tool_config,
                 response=response,
+                time=urllib3_hooks().end_request(request_id),
             )
 
-        # TODO: would need to monkey patch AuthorizedSession.request
-
         try:
             response = await self.client.aio.models.generate_content(
                 model=self.model_name,
@@ -252,11 +263,25 @@
         return output, model_call()
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        # see https://cloud.google.com/storage/docs/retry-strategy
-        return isinstance(ex, APIError) and (
-            ex.code in (408, 429, 429) or ex.code >= 500
-        )
+    def should_retry(self, ex: Exception) -> bool:
+        import requests  # type: ignore
+
+        # standard http errors
+        if isinstance(ex, APIError):
+            return is_retryable_http_status(ex.status)
+
+        # low-level requests exceptions
+        elif isinstance(ex, requests.exceptions.RequestException):
+            return isinstance(
+                ex,
+                (
+                    requests.exceptions.ConnectionError
+                    | requests.exceptions.ConnectTimeout
+                    | requests.exceptions.ChunkedEncodingError
+                ),
+            )
+        else:
+            return False
 
     @override
     def connection_key(self) -> str:
@@ -296,6 +321,7 @@ def build_model_call(
     tools: list[Tool] | None,
     tool_config: ToolConfig | None,
     response: GenerateContentResponse | None,
+    time: float | None,
 ) -> ModelCall:
     return ModelCall.create(
         request=dict(
@@ -307,6 +333,7 @@
         ),
         response=response if response is not None else {},
         filter=model_call_filter,
+        time=time,
     )
 
 
@@ -464,7 +491,9 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
 
 
 # https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
-def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
+def schema_from_param(
+    param: ToolParam | ToolParams, nullable: bool | None = False
+) -> Schema:
     if isinstance(param, ToolParams):
         param = ToolParam(
             type=param.type, properties=param.properties, required=param.required
@@ -529,10 +558,13 @@ def chat_tool_config(tool_choice: ToolChoice) -> ToolConfig:
 
 
 def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
-    # check for completion text
-    content = ""
     # content can be None when the finish_reason is SAFETY
-    if candidate.content is not None:
+    if candidate.content is None:
+        content = ""
+    # content.parts can be None when the finish_reason is MALFORMED_FUNCTION_CALL
+    elif candidate.content.parts is None:
+        content = ""
+    else:
         content = " ".join(
             [
                 part.text
@@ -680,6 +712,8 @@ def finish_reason_to_stop_reason(finish_reason: FinishReason) -> StopReason:
         ):
             return "content_filter"
         case _:
+            # Note: to avoid adding another option to StopReason,
+            # this includes FinishReason.MALFORMED_FUNCTION_CALL
             return "unknown"
 
 
@@ -775,7 +809,7 @@
         file=BytesIO(content_bytes), config=dict(mime_type=mime_type)
     )
     while upload.state.name == "PROCESSING":
-        await asyncio.sleep(3)
+        await anyio.sleep(3)
         upload = client.files.get(name=upload.name)
     if upload.state.name == "FAILED":
         trace(f"Failed to upload file '{upload.name}: {upload.error}")
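The response_schema branch above wires Inspect's new structured-output configuration (_generate_config.py +25, util/_json.py +170) into Gemini's native JSON mode. A caller-side usage sketch, assuming the ResponseSchema and json_schema names these changes appear to export (verify against the 0.3.73 API before relying on them):

# Usage sketch for the new structured-output path; ResponseSchema and
# json_schema are inferred from this diff and may not match exactly.
from pydantic import BaseModel

from inspect_ai.model import GenerateConfig, ResponseSchema, get_model
from inspect_ai.util import json_schema


class Color(BaseModel):
    red: int
    green: int
    blue: int


async def generate_color() -> str:
    model = get_model("google/gemini-1.5-flash")
    output = await model.generate(
        "What RGB value is burnt orange?",
        config=GenerateConfig(
            response_schema=ResponseSchema(name="color", json_schema=json_schema(Color))
        ),
    )
    # completion is JSON text conforming to the Color schema
    return output.completion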
inspect_ai/model/_providers/groq.py

@@ -5,8 +5,9 @@ from typing import Any, Dict, Iterable, List, Optional
 
 import httpx
 from groq import (
+    APIStatusError,
+    APITimeoutError,
     AsyncGroq,
-    RateLimitError,
 )
 from groq.types.chat import (
     ChatCompletion,
@@ -25,10 +26,10 @@ from typing_extensions import override
 
 from inspect_ai._util.constants import (
     BASE_64_DATA_REMOVED,
-    DEFAULT_MAX_RETRIES,
     DEFAULT_MAX_TOKENS,
 )
 from inspect_ai._util.content import Content, ContentReasoning, ContentText
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
@@ -54,7 +55,7 @@ from .util import (
     environment_prerequisite_error,
     model_base_url,
 )
-from .util.tracker import HttpxTimeTracker
+from .util.hooks import HttpxHooks
 
 GROQ_API_KEY = "GROQ_API_KEY"
 
@@ -84,18 +85,12 @@ class GroqAPI(ModelAPI):
         self.client = AsyncGroq(
             api_key=self.api_key,
             base_url=model_base_url(base_url, "GROQ_BASE_URL"),
-            max_retries=(
-                config.max_retries
-                if config.max_retries is not None
-                else DEFAULT_MAX_RETRIES
-            ),
-            timeout=config.timeout if config.timeout is not None else 60.0,
             **model_args,
             http_client=httpx.AsyncClient(limits=httpx.Limits(max_connections=None)),
         )
 
         # create time tracker
-        self._time_tracker = HttpxTimeTracker(self.client._client)
+        self._http_hooks = HttpxHooks(self.client._client)
 
     @override
     async def close(self) -> None:
@@ -109,7 +104,7 @@ class GroqAPI(ModelAPI):
         config: GenerateConfig,
     ) -> tuple[ModelOutput, ModelCall]:
         # allocate request_id (so we can see it from ModelCall)
-        request_id = self._time_tracker.start_request()
+        request_id = self._http_hooks.start_request()
 
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
@@ -120,7 +115,7 @@
             request=request,
             response=response,
             filter=model_call_filter,
-            time=self._time_tracker.end_request(request_id),
+            time=self._http_hooks.end_request(request_id),
         )
 
         messages = await as_groq_chat_messages(input)
@@ -137,7 +132,7 @@
         request = dict(
             messages=messages,
             model=self.model_name,
-            extra_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
+            extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
             **params,
         )
 
@@ -215,8 +210,13 @@
     ]
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        return isinstance(ex, RateLimitError)
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, APIStatusError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, APITimeoutError):
+            return True
+        else:
+            return False
 
     @override
     def connection_key(self) -> str:
inspect_ai/model/_providers/hf.py

@@ -1,15 +1,19 @@
-import asyncio
+import concurrent
+import concurrent.futures
 import copy
 import functools
 import gc
 import json
 import os
 import time
+from concurrent.futures import Future
 from dataclasses import dataclass
+from logging import getLogger
 from queue import Empty, Queue
 from threading import Thread
 from typing import Any, Literal, Protocol, cast
 
+import anyio
 import numpy as np
 import torch  # type: ignore
 from torch import Tensor  # type: ignore
@@ -23,6 +27,7 @@ from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import ContentText
+from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import ToolChoice, ToolInfo
 
 from .._chat_message import ChatMessage, ChatMessageAssistant
@@ -38,6 +43,9 @@ from .._model_output import (
 )
 from .util import ChatAPIHandler, HFHandler
 
+logger = getLogger(__name__)
+
+
 HF_TOKEN = "HF_TOKEN"
 
 
@@ -385,8 +393,7 @@ class GenerateOutput:
 @dataclass
 class _QueueItem:
     input: GenerateInput
-    future: asyncio.Future[GenerateOutput]
-    loop: asyncio.AbstractEventLoop
+    future: Future[GenerateOutput]
 
 
 batch_thread: Thread | None = None
@@ -402,25 +409,26 @@ async def batched_generate(input: GenerateInput) -> GenerateOutput:
         batch_thread.start()
 
     # enqueue the job
-    loop = asyncio.get_event_loop()
-    future: asyncio.Future[GenerateOutput] = loop.create_future()
-    batch_queue.put(_QueueItem(input=input, future=future, loop=loop))
-
-    # await the job
-    await future
+    future = Future[GenerateOutput]()
+    batch_queue.put(_QueueItem(input=input, future=future))
 
-    # return it
-    return future.result()
+    # await the future
+    with trace_action(logger, "HF Batched Generate", "HF Batched Generate"):
+        while True:
+            try:
+                return future.result(timeout=0.01)
+            except concurrent.futures.TimeoutError:
+                pass
+            await anyio.sleep(1)
 
 
 def process_batches() -> None:
     while True:
         # drain the queue (wait until no new messages have shown up for 2 seconds)
-        inputs: list[tuple[GenerateInput, asyncio.Future[GenerateOutput]]] = []
+        inputs: list[tuple[GenerateInput, Future[GenerateOutput]]] = []
         while True:
             try:
                 input = batch_queue.get(timeout=2)
-                loop = input.loop
                 inputs.append((input.input, input.future))
                 if len(inputs) == input.input.batch_size:
                     # max batch size reached
@@ -480,8 +488,7 @@ def process_batches() -> None:
                     # asyncio futures are not thread safe, so we need to pass the event loop
                     # down to this point, so we can mark the future as done in a thread safe manner.
                    # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
-                    loop.call_soon_threadsafe(
-                        future.set_result,
+                    future.set_result(
                         GenerateOutput(
                             output=output,
                             input_tokens=input_tokens,
@@ -489,13 +496,13 @@ def process_batches() -> None:
                             total_tokens=input_tokens + output_tokens,
                             logprobs=logprobs[i] if logprobs is not None else None,
                             time=total_time,
-                        ),
+                        )
                     )
 
         except Exception as ex:
             for inp in inputs:
                 future = inp[1]
-                loop.call_soon_threadsafe(future.set_exception, ex)
+                future.set_exception(ex)
 
 
 def extract_logprobs(
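
The hf.py rewrite is the release's async-backend migration in miniature: asyncio.Future is bound to a single event loop and is not thread safe (the worker thread had to complete it via loop.call_soon_threadsafe), while concurrent.futures.Future can be completed from any thread directly, and the async side polls it with short timeouts so the event loop is never blocked. A standalone sketch of the same bridge pattern:

# Standalone sketch of the thread-to-async bridge used in hf.py above:
# a worker thread completes a thread-safe Future; async code polls it.
import concurrent.futures
from threading import Thread

import anyio


def worker(future: "concurrent.futures.Future[int]") -> None:
    # no event loop involved: any thread may call set_result directly
    future.set_result(42)


async def main() -> None:
    future: concurrent.futures.Future[int] = concurrent.futures.Future()
    Thread(target=worker, args=(future,), daemon=True).start()
    while True:
        try:
            # tiny timeout keeps the call effectively non-blocking
            print(future.result(timeout=0.01))
            return
        except concurrent.futures.TimeoutError:
            await anyio.sleep(0.1)  # yield to the event loop between polls


anyio.run(main)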