inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (103)
  1. inspect_ai/_cli/eval.py +14 -3
  2. inspect_ai/_cli/sandbox.py +3 -3
  3. inspect_ai/_cli/score.py +6 -4
  4. inspect_ai/_cli/trace.py +53 -6
  5. inspect_ai/_display/core/config.py +1 -1
  6. inspect_ai/_display/core/display.py +2 -1
  7. inspect_ai/_display/core/footer.py +6 -6
  8. inspect_ai/_display/plain/display.py +11 -6
  9. inspect_ai/_display/rich/display.py +23 -13
  10. inspect_ai/_display/textual/app.py +10 -9
  11. inspect_ai/_display/textual/display.py +2 -2
  12. inspect_ai/_display/textual/widgets/footer.py +4 -0
  13. inspect_ai/_display/textual/widgets/samples.py +14 -5
  14. inspect_ai/_eval/context.py +1 -2
  15. inspect_ai/_eval/eval.py +54 -41
  16. inspect_ai/_eval/loader.py +9 -2
  17. inspect_ai/_eval/run.py +148 -81
  18. inspect_ai/_eval/score.py +13 -8
  19. inspect_ai/_eval/task/images.py +31 -21
  20. inspect_ai/_eval/task/run.py +62 -59
  21. inspect_ai/_eval/task/rundir.py +16 -9
  22. inspect_ai/_eval/task/sandbox.py +7 -8
  23. inspect_ai/_eval/task/util.py +7 -0
  24. inspect_ai/_util/_async.py +118 -10
  25. inspect_ai/_util/constants.py +0 -2
  26. inspect_ai/_util/file.py +15 -29
  27. inspect_ai/_util/future.py +37 -0
  28. inspect_ai/_util/http.py +3 -99
  29. inspect_ai/_util/httpx.py +60 -0
  30. inspect_ai/_util/interrupt.py +2 -2
  31. inspect_ai/_util/json.py +5 -52
  32. inspect_ai/_util/logger.py +30 -86
  33. inspect_ai/_util/retry.py +10 -61
  34. inspect_ai/_util/trace.py +2 -2
  35. inspect_ai/_view/server.py +86 -3
  36. inspect_ai/_view/www/dist/assets/index.js +25837 -13269
  37. inspect_ai/_view/www/log-schema.json +253 -186
  38. inspect_ai/_view/www/package.json +2 -2
  39. inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
  40. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
  41. inspect_ai/_view/www/src/types/log.d.ts +122 -94
  42. inspect_ai/approval/_human/manager.py +6 -10
  43. inspect_ai/approval/_human/panel.py +2 -2
  44. inspect_ai/dataset/_sources/util.py +7 -6
  45. inspect_ai/log/__init__.py +4 -0
  46. inspect_ai/log/_file.py +35 -61
  47. inspect_ai/log/_log.py +18 -1
  48. inspect_ai/log/_recorders/eval.py +14 -23
  49. inspect_ai/log/_recorders/json.py +3 -18
  50. inspect_ai/log/_samples.py +27 -2
  51. inspect_ai/log/_transcript.py +8 -8
  52. inspect_ai/model/__init__.py +2 -1
  53. inspect_ai/model/_call_tools.py +60 -40
  54. inspect_ai/model/_chat_message.py +3 -2
  55. inspect_ai/model/_generate_config.py +25 -0
  56. inspect_ai/model/_model.py +74 -36
  57. inspect_ai/model/_openai.py +9 -1
  58. inspect_ai/model/_providers/anthropic.py +24 -26
  59. inspect_ai/model/_providers/azureai.py +11 -9
  60. inspect_ai/model/_providers/bedrock.py +33 -24
  61. inspect_ai/model/_providers/cloudflare.py +8 -9
  62. inspect_ai/model/_providers/goodfire.py +7 -3
  63. inspect_ai/model/_providers/google.py +47 -13
  64. inspect_ai/model/_providers/groq.py +15 -15
  65. inspect_ai/model/_providers/hf.py +24 -17
  66. inspect_ai/model/_providers/mistral.py +36 -20
  67. inspect_ai/model/_providers/openai.py +30 -25
  68. inspect_ai/model/_providers/openai_o1.py +1 -1
  69. inspect_ai/model/_providers/providers.py +1 -1
  70. inspect_ai/model/_providers/together.py +3 -4
  71. inspect_ai/model/_providers/util/__init__.py +2 -2
  72. inspect_ai/model/_providers/util/chatapi.py +6 -19
  73. inspect_ai/model/_providers/util/hooks.py +165 -0
  74. inspect_ai/model/_providers/vertex.py +20 -3
  75. inspect_ai/model/_providers/vllm.py +16 -19
  76. inspect_ai/scorer/_multi.py +5 -2
  77. inspect_ai/solver/_bridge/patch.py +31 -1
  78. inspect_ai/solver/_fork.py +5 -3
  79. inspect_ai/solver/_human_agent/agent.py +3 -2
  80. inspect_ai/tool/__init__.py +8 -2
  81. inspect_ai/tool/_tool_info.py +4 -90
  82. inspect_ai/tool/_tool_params.py +4 -34
  83. inspect_ai/tool/_tools/_web_search.py +30 -24
  84. inspect_ai/util/__init__.py +4 -0
  85. inspect_ai/util/_concurrency.py +5 -6
  86. inspect_ai/util/_display.py +6 -0
  87. inspect_ai/util/_json.py +170 -0
  88. inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
  89. inspect_ai/util/_sandbox/docker/docker.py +5 -0
  90. inspect_ai/util/_sandbox/environment.py +56 -9
  91. inspect_ai/util/_sandbox/service.py +12 -5
  92. inspect_ai/util/_subprocess.py +94 -113
  93. inspect_ai/util/_subtask.py +2 -4
  94. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
  95. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
  96. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
  97. inspect_ai/_util/timeouts.py +0 -160
  98. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  99. inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
  100. inspect_ai/model/_providers/util/tracker.py +0 -92
  101. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
  102. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
  103. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,3 @@
-import asyncio
 import functools
 import hashlib
 import json
@@ -9,6 +8,7 @@ from logging import getLogger
 from typing import Any

 # SDK Docs: https://googleapis.github.io/python-genai/
+import anyio
 from google.genai import Client  # type: ignore
 from google.genai.errors import APIError, ClientError  # type: ignore
 from google.genai.types import (  # type: ignore
@@ -26,6 +26,7 @@ from google.genai.types import (  # type: ignore
     GenerationConfig,
     HarmBlockThreshold,
     HarmCategory,
+    HttpOptions,
     Part,
     SafetySetting,
     SafetySettingDict,
@@ -49,6 +50,7 @@ from inspect_ai._util.content import (
     ContentVideo,
 )
 from inspect_ai._util.error import PrerequisiteError
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data
 from inspect_ai._util.kvstore import inspect_kvstore
 from inspect_ai._util.trace import trace_message
@@ -69,6 +71,7 @@ from inspect_ai.model import (
 )
 from inspect_ai.model._model_call import ModelCall
 from inspect_ai.model._providers.util import model_base_url
+from inspect_ai.model._providers.util.hooks import HttpHooks, urllib3_hooks
 from inspect_ai.tool import (
     ToolCall,
     ToolChoice,
@@ -199,11 +202,15 @@ class GoogleGenAIAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+        # generate request_id
+        request_id = urllib3_hooks().start_request()
+
         # Create google-genai types.
         gemini_contents = await as_chat_messages(self.client, input)
         gemini_tools = chat_tools(tools) if len(tools) > 0 else None
         gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None
         parameters = GenerateContentConfig(
+            http_options=HttpOptions(headers={HttpHooks.REQUEST_ID_HEADER: request_id}),
             temperature=config.temperature,
             top_p=config.top_p,
             top_k=config.top_k,
@@ -219,6 +226,11 @@ class GoogleGenAIAPI(ModelAPI):
                 self.client, input
             ),
         )
+        if config.response_schema is not None:
+            parameters.response_mime_type = "application/json"
+            parameters.response_schema = schema_from_param(
+                config.response_schema.json_schema, nullable=None
+            )

         response: GenerateContentResponse | None = None

@@ -230,10 +242,9 @@ class GoogleGenAIAPI(ModelAPI):
                 tools=gemini_tools,
                 tool_config=gemini_tool_config,
                 response=response,
+                time=urllib3_hooks().end_request(request_id),
             )

-        # TODO: would need to monkey patch AuthorizedSession.request
-
         try:
             response = await self.client.aio.models.generate_content(
                 model=self.model_name,
@@ -252,11 +263,25 @@ class GoogleGenAIAPI(ModelAPI):
             return output, model_call()

     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        # see https://cloud.google.com/storage/docs/retry-strategy
-        return isinstance(ex, APIError) and (
-            ex.code in (408, 429, 429) or ex.code >= 500
-        )
+    def should_retry(self, ex: Exception) -> bool:
+        import requests  # type: ignore
+
+        # standard http errors
+        if isinstance(ex, APIError):
+            return is_retryable_http_status(ex.status)
+
+        # low-level requests exceptions
+        elif isinstance(ex, requests.exceptions.RequestException):
+            return isinstance(
+                ex,
+                (
+                    requests.exceptions.ConnectionError
+                    | requests.exceptions.ConnectTimeout
+                    | requests.exceptions.ChunkedEncodingError
+                ),
+            )
+        else:
+            return False

     @override
     def connection_key(self) -> str:
@@ -296,6 +321,7 @@ def build_model_call(
     tools: list[Tool] | None,
     tool_config: ToolConfig | None,
     response: GenerateContentResponse | None,
+    time: float | None,
 ) -> ModelCall:
     return ModelCall.create(
         request=dict(
@@ -307,6 +333,7 @@ def build_model_call(
         ),
         response=response if response is not None else {},
         filter=model_call_filter,
+        time=time,
     )


@@ -464,7 +491,9 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:


 # https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
-def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
+def schema_from_param(
+    param: ToolParam | ToolParams, nullable: bool | None = False
+) -> Schema:
     if isinstance(param, ToolParams):
         param = ToolParam(
             type=param.type, properties=param.properties, required=param.required
@@ -529,10 +558,13 @@ def chat_tool_config(tool_choice: ToolChoice) -> ToolConfig:


 def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
-    # check for completion text
-    content = ""
     # content can be None when the finish_reason is SAFETY
-    if candidate.content is not None:
+    if candidate.content is None:
+        content = ""
+    # content.parts can be None when the finish_reason is MALFORMED_FUNCTION_CALL
+    elif candidate.content.parts is None:
+        content = ""
+    else:
         content = " ".join(
             [
                 part.text
@@ -680,6 +712,8 @@ def finish_reason_to_stop_reason(finish_reason: FinishReason) -> StopReason:
         ):
             return "content_filter"
         case _:
+            # Note: to avoid adding another option to StopReason,
+            # this includes FinishReason.MALFORMED_FUNCTION_CALL
             return "unknown"


@@ -775,7 +809,7 @@ async def file_for_content(
             file=BytesIO(content_bytes), config=dict(mime_type=mime_type)
         )
         while upload.state.name == "PROCESSING":
-            await asyncio.sleep(3)
+            await anyio.sleep(3)
             upload = client.files.get(name=upload.name)
         if upload.state.name == "FAILED":
             trace(f"Failed to upload file '{upload.name}: {upload.error}")
@@ -5,8 +5,9 @@ from typing import Any, Dict, Iterable, List, Optional

 import httpx
 from groq import (
+    APIStatusError,
+    APITimeoutError,
     AsyncGroq,
-    RateLimitError,
 )
 from groq.types.chat import (
     ChatCompletion,
@@ -25,10 +26,10 @@ from typing_extensions import override

 from inspect_ai._util.constants import (
     BASE_64_DATA_REMOVED,
-    DEFAULT_MAX_RETRIES,
     DEFAULT_MAX_TOKENS,
 )
 from inspect_ai._util.content import Content, ContentReasoning, ContentText
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
@@ -54,7 +55,7 @@ from .util import (
     environment_prerequisite_error,
     model_base_url,
 )
-from .util.tracker import HttpxTimeTracker
+from .util.hooks import HttpxHooks

 GROQ_API_KEY = "GROQ_API_KEY"

@@ -84,18 +85,12 @@ class GroqAPI(ModelAPI):
         self.client = AsyncGroq(
             api_key=self.api_key,
             base_url=model_base_url(base_url, "GROQ_BASE_URL"),
-            max_retries=(
-                config.max_retries
-                if config.max_retries is not None
-                else DEFAULT_MAX_RETRIES
-            ),
-            timeout=config.timeout if config.timeout is not None else 60.0,
             **model_args,
             http_client=httpx.AsyncClient(limits=httpx.Limits(max_connections=None)),
         )

         # create time tracker
-        self._time_tracker = HttpxTimeTracker(self.client._client)
+        self._http_hooks = HttpxHooks(self.client._client)

     @override
     async def close(self) -> None:
@@ -109,7 +104,7 @@ class GroqAPI(ModelAPI):
         config: GenerateConfig,
     ) -> tuple[ModelOutput, ModelCall]:
         # allocate request_id (so we can see it from ModelCall)
-        request_id = self._time_tracker.start_request()
+        request_id = self._http_hooks.start_request()

         # setup request and response for ModelCall
         request: dict[str, Any] = {}
@@ -120,7 +115,7 @@ class GroqAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=model_call_filter,
-                time=self._time_tracker.end_request(request_id),
+                time=self._http_hooks.end_request(request_id),
             )

         messages = await as_groq_chat_messages(input)
@@ -137,7 +132,7 @@ class GroqAPI(ModelAPI):
         request = dict(
             messages=messages,
             model=self.model_name,
-            extra_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
+            extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
             **params,
         )

@@ -215,8 +210,13 @@ class GroqAPI(ModelAPI):
         ]

     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        return isinstance(ex, RateLimitError)
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, APIStatusError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, APITimeoutError):
+            return True
+        else:
+            return False

     @override
     def connection_key(self) -> str:
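
Note: request timing in this provider (and in the Mistral and OpenAI providers below) now runs through `HttpxHooks` from the new `util/hooks.py`, replacing `HttpxTimeTracker`: a request id is allocated, sent as a header on the outgoing request, and used to look up elapsed time for the `ModelCall`. The shipped `hooks.py` is not part of this diff; the following is a minimal sketch of the pattern using httpx event hooks, with the header name and internals assumed rather than taken from the package:

    # Rough sketch of the httpx event-hook timing pattern (not the shipped hooks.py).
    import time
    import uuid

    import httpx


    class HttpxHooksSketch:
        REQUEST_ID_HEADER = "x-irid"  # assumed header name

        def __init__(self, client: httpx.AsyncClient) -> None:
            self._start_times: dict[str, float] = {}
            client.event_hooks["request"].append(self.on_request)

        def start_request(self) -> str:
            # allocate an id the provider can also record in its ModelCall
            request_id = uuid.uuid4().hex
            self._start_times[request_id] = time.monotonic()
            return request_id

        async def on_request(self, request: httpx.Request) -> None:
            # restart the clock when the request actually goes on the wire
            request_id = request.headers.get(self.REQUEST_ID_HEADER)
            if request_id in self._start_times:
                self._start_times[request_id] = time.monotonic()

        def end_request(self, request_id: str) -> float:
            # elapsed seconds for the request with this id
            return time.monotonic() - self._start_times.pop(request_id)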
@@ -1,15 +1,19 @@
-import asyncio
+import concurrent
+import concurrent.futures
 import copy
 import functools
 import gc
 import json
 import os
 import time
+from concurrent.futures import Future
 from dataclasses import dataclass
+from logging import getLogger
 from queue import Empty, Queue
 from threading import Thread
 from typing import Any, Literal, Protocol, cast

+import anyio
 import numpy as np
 import torch  # type: ignore
 from torch import Tensor  # type: ignore
@@ -23,6 +27,7 @@ from typing_extensions import override

 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import ContentText
+from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import ToolChoice, ToolInfo

 from .._chat_message import ChatMessage, ChatMessageAssistant
@@ -38,6 +43,9 @@ from .._model_output import (
 )
 from .util import ChatAPIHandler, HFHandler

+logger = getLogger(__name__)
+
+
 HF_TOKEN = "HF_TOKEN"


@@ -385,8 +393,7 @@ class GenerateOutput:
 @dataclass
 class _QueueItem:
     input: GenerateInput
-    future: asyncio.Future[GenerateOutput]
-    loop: asyncio.AbstractEventLoop
+    future: Future[GenerateOutput]


 batch_thread: Thread | None = None
@@ -402,25 +409,26 @@ async def batched_generate(input: GenerateInput) -> GenerateOutput:
         batch_thread.start()

     # enqueue the job
-    loop = asyncio.get_event_loop()
-    future: asyncio.Future[GenerateOutput] = loop.create_future()
-    batch_queue.put(_QueueItem(input=input, future=future, loop=loop))
-
-    # await the job
-    await future
+    future = Future[GenerateOutput]()
+    batch_queue.put(_QueueItem(input=input, future=future))

-    # return it
-    return future.result()
+    # await the future
+    with trace_action(logger, "HF Batched Generate", "HF Batched Generate"):
+        while True:
+            try:
+                return future.result(timeout=0.01)
+            except concurrent.futures.TimeoutError:
+                pass
+            await anyio.sleep(1)


 def process_batches() -> None:
     while True:
         # drain the queue (wait until no new messages have shown up for 2 seconds)
-        inputs: list[tuple[GenerateInput, asyncio.Future[GenerateOutput]]] = []
+        inputs: list[tuple[GenerateInput, Future[GenerateOutput]]] = []
         while True:
             try:
                 input = batch_queue.get(timeout=2)
-                loop = input.loop
                 inputs.append((input.input, input.future))
                 if len(inputs) == input.input.batch_size:
                     # max batch size reached
@@ -480,8 +488,7 @@ def process_batches() -> None:
                 # asyncio futures are not thread safe, so we need to pass the event loop
                 # down to this point, so we can mark the future as done in a thread safe manner.
                 # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
-                loop.call_soon_threadsafe(
-                    future.set_result,
+                future.set_result(
                     GenerateOutput(
                         output=output,
                         input_tokens=input_tokens,
@@ -489,13 +496,13 @@ def process_batches() -> None:
                         total_tokens=input_tokens + output_tokens,
                         logprobs=logprobs[i] if logprobs is not None else None,
                         time=total_time,
-                    ),
+                    )
                 )

         except Exception as ex:
             for inp in inputs:
                 future = inp[1]
-                loop.call_soon_threadsafe(future.set_exception, ex)
+                future.set_exception(ex)


 def extract_logprobs(
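
Note: the batching change above swaps `asyncio.Future` (which must be resolved via `loop.call_soon_threadsafe`) for `concurrent.futures.Future`, which the worker thread can resolve directly while the async caller polls it and yields with `anyio.sleep`. A self-contained sketch of that handoff pattern, illustrative only and not inspect code:

    # Minimal sketch of the thread -> async handoff (illustrative only).
    import concurrent.futures
    import threading
    import time

    import anyio


    def worker(future: "concurrent.futures.Future[str]") -> None:
        time.sleep(0.5)            # simulate batched inference work
        future.set_result("done")  # thread-safe, unlike asyncio.Future


    async def wait_for_result() -> str:
        future: concurrent.futures.Future[str] = concurrent.futures.Future()
        threading.Thread(target=worker, args=(future,), daemon=True).start()
        while True:
            try:
                return future.result(timeout=0.01)  # brief blocking check
            except concurrent.futures.TimeoutError:
                await anyio.sleep(0.1)  # yield to the event loop between polls


    print(anyio.run(wait_for_result))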
@@ -7,6 +7,7 @@ from httpcore import ReadTimeout
 from httpx import ReadTimeout as AsyncReadTimeout
 from mistralai import (
     ContentChunk,
+    DocumentURLChunk,
     FunctionCall,
     FunctionName,
     ImageURL,
@@ -22,6 +23,12 @@ from mistralai.models import (
     ChatCompletionChoice as MistralChatCompletionChoice,
 )
 from mistralai.models import Function as MistralFunction
+from mistralai.models import (
+    JSONSchema as MistralJSONSchema,
+)
+from mistralai.models import (
+    ResponseFormat as MistralResponseFormat,
+)
 from mistralai.models import SDKError
 from mistralai.models import SystemMessage as MistralSystemMessage
 from mistralai.models import Tool as MistralTool
@@ -38,11 +45,9 @@ from typing_extensions import override

 # TODO: Migration guide:
 # https://github.com/mistralai/client-python/blob/main/MIGRATION.md
-from inspect_ai._util.constants import (
-    DEFAULT_TIMEOUT,
-    NO_CONTENT,
-)
+from inspect_ai._util.constants import NO_CONTENT
 from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

@@ -61,7 +66,7 @@ from .._model_output import (
     StopReason,
 )
 from .util import environment_prerequisite_error, model_base_url
-from .util.tracker import HttpxTimeTracker
+from .util.hooks import HttpxHooks

 AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
 AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"
@@ -127,16 +132,12 @@ class MistralAPI(ModelAPI):
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # create client
-        with Mistral(
-            api_key=self.api_key,
-            timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT) * 1000,
-            **self.model_args,
-        ) as client:
+        with Mistral(api_key=self.api_key, **self.model_args) as client:
             # create time tracker
-            time_tracker = HttpxTimeTracker(client.sdk_configuration.async_client)
+            http_hooks = HttpxHooks(client.sdk_configuration.async_client)

             # build request
-            request_id = time_tracker.start_request()
+            request_id = http_hooks.start_request()
             request: dict[str, Any] = dict(
                 model=self.model_name,
                 messages=await mistral_chat_messages(input),
@@ -144,7 +145,7 @@ class MistralAPI(ModelAPI):
                 tool_choice=(
                     mistral_chat_tool_choice(tool_choice) if len(tools) > 0 else None
                 ),
-                http_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
+                http_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
             )
             if config.temperature is not None:
                 request["temperature"] = config.temperature
@@ -154,6 +155,18 @@ class MistralAPI(ModelAPI):
                 request["max_tokens"] = config.max_tokens
             if config.seed is not None:
                 request["random_seed"] = config.seed
+            if config.response_schema is not None:
+                request["response_format"] = MistralResponseFormat(
+                    type="json_schema",
+                    json_schema=MistralJSONSchema(
+                        name=config.response_schema.name,
+                        description=config.response_schema.description,
+                        schema_definition=config.response_schema.json_schema.model_dump(
+                            exclude_none=True
+                        ),
+                        strict=config.response_schema.strict,
+                    ),
+                )

             # prepare response for inclusion in model call
             response: dict[str, Any] = {}
@@ -169,7 +182,7 @@ class MistralAPI(ModelAPI):
                 return ModelCall.create(
                     request=req,
                     response=response,
-                    time=time_tracker.end_request(request_id),
+                    time=http_hooks.end_request(request_id),
                 )

             # send request
@@ -205,12 +218,13 @@ class MistralAPI(ModelAPI):
             ), model_call()

     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        return (
-            isinstance(ex, SDKError)
-            and ex.status_code == 429
-            or isinstance(ex, ReadTimeout | AsyncReadTimeout)
-        )
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, SDKError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, ReadTimeout | AsyncReadTimeout):
+            return True
+        else:
+            return False

     @override
     def connection_key(self) -> str:
@@ -462,6 +476,8 @@ def completion_content_chunk(content: ContentChunk) -> Content:
         raise TypeError("ReferenceChunk content is not supported by Inspect.")
     elif isinstance(content, TextChunk):
         return ContentText(text=content.text)
+    elif isinstance(content, DocumentURLChunk):
+        return ContentText(text=content.document_url)
     else:
         if isinstance(content.image_url, str):
            return ContentImage(image=content.image_url)
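
Note: the `response_schema` handling added here (and in the Google provider above and the OpenAI provider below) is driven by the new `response_schema` field on `GenerateConfig` (see `_generate_config.py` and the new `inspect_ai/util/_json.py`). A hedged usage sketch follows, assuming a `ResponseSchema` type exposing the `name`/`description`/`json_schema`/`strict` fields referenced in these hunks and a `json_schema()` helper; the exact import locations are assumptions, not confirmed by this diff:

    # Hedged sketch of structured output via GenerateConfig.response_schema.
    # ResponseSchema / json_schema import paths are assumptions inferred from
    # the fields this diff references (name, description, json_schema, strict).
    from pydantic import BaseModel

    from inspect_ai.model import GenerateConfig, ResponseSchema, get_model
    from inspect_ai.util import json_schema


    class Answer(BaseModel):
        answer: str
        confidence: float


    async def structured_answer(question: str) -> Answer:
        model = get_model("openai/gpt-4o")
        output = await model.generate(
            question,
            config=GenerateConfig(
                response_schema=ResponseSchema(
                    name="answer",
                    description="Answer with a confidence score",
                    json_schema=json_schema(Answer),
                    strict=True,
                )
            ),
        )
        # the provider returns JSON text conforming to the schema
        return Answer.model_validate_json(output.completion)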
@@ -7,25 +7,22 @@ import httpx
 from openai import (
     DEFAULT_CONNECTION_LIMITS,
     DEFAULT_TIMEOUT,
-    APIConnectionError,
+    APIStatusError,
     APITimeoutError,
     AsyncAzureOpenAI,
     AsyncOpenAI,
     BadRequestError,
-    InternalServerError,
     RateLimitError,
 )
 from openai._types import NOT_GIVEN
-from openai.types.chat import (
-    ChatCompletion,
-)
+from openai.types.chat import ChatCompletion
 from typing_extensions import override

-from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.error import PrerequisiteError
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
-from inspect_ai.model._providers.util.tracker import HttpxTimeTracker
+from inspect_ai.model._providers.util.hooks import HttpxHooks
 from inspect_ai.tool import ToolChoice, ToolInfo

 from .._chat_message import ChatMessage
@@ -130,9 +127,6 @@ class OpenAIAPI(ModelAPI):
                 api_key=self.api_key,
                 azure_endpoint=base_url,
                 azure_deployment=model_name,
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 http_client=http_client,
                 **model_args,
             )
@@ -140,15 +134,12 @@ class OpenAIAPI(ModelAPI):
             self.client = AsyncOpenAI(
                 api_key=self.api_key,
                 base_url=model_base_url(base_url, "OPENAI_BASE_URL"),
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 http_client=http_client,
                 **model_args,
             )

         # create time tracker
-        self._time_tracker = HttpxTimeTracker(self.client._client)
+        self._http_hooks = HttpxHooks(self.client._client)

     def is_azure(self) -> bool:
         return self.service == "azure"
@@ -186,7 +177,7 @@ class OpenAIAPI(ModelAPI):
             )

         # allocate request_id (so we can see it from ModelCall)
-        request_id = self._time_tracker.start_request()
+        request_id = self._http_hooks.start_request()

         # setup request and response for ModelCall
         request: dict[str, Any] = {}
@@ -197,7 +188,7 @@ class OpenAIAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=image_url_filter,
-                time=self._time_tracker.end_request(request_id),
+                time=self._http_hooks.end_request(request_id),
             )

         # unlike text models, vision models require a max_tokens (and set it to a very low
@@ -216,7 +207,7 @@ class OpenAIAPI(ModelAPI):
             tool_choice=openai_chat_tool_choice(tool_choice)
             if len(tools) > 0
             else NOT_GIVEN,
-            extra_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id},
+            extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
             **self.completion_params(config, len(tools) > 0),
         )

@@ -266,17 +257,21 @@ class OpenAIAPI(ModelAPI):
         return chat_choices_from_openai(response, tools)

     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
+    def should_retry(self, ex: Exception) -> bool:
         if isinstance(ex, RateLimitError):
             # Do not retry on these rate limit errors
             # The quota exceeded one is related to monthly account quotas.
-            if "You exceeded your current quota" not in ex.message:
+            if "You exceeded your current quota" in ex.message:
+                warn_once(logger, f"OpenAI quota exceeded, not retrying: {ex.message}")
+                return False
+            else:
                 return True
-        elif isinstance(
-            ex, (APIConnectionError | APITimeoutError | InternalServerError)
-        ):
+        elif isinstance(ex, APIStatusError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, APITimeoutError):
             return True
-        return False
+        else:
+            return False

     @override
     def connection_key(self) -> str:
@@ -315,8 +310,6 @@ class OpenAIAPI(ModelAPI):
             params["temperature"] = 1
         if config.top_p is not None:
             params["top_p"] = config.top_p
-        if config.timeout is not None:
-            params["timeout"] = float(config.timeout)
         if config.num_choices is not None:
             params["n"] = config.num_choices
         if config.logprobs is not None:
@@ -331,6 +324,18 @@ class OpenAIAPI(ModelAPI):
             and not self.is_o1_mini()
         ):
             params["reasoning_effort"] = config.reasoning_effort
+        if config.response_schema is not None:
+            params["response_format"] = dict(
+                type="json_schema",
+                json_schema=dict(
+                    name=config.response_schema.name,
+                    schema=config.response_schema.json_schema.model_dump(
+                        exclude_none=True
+                    ),
+                    description=config.response_schema.description,
+                    strict=config.response_schema.strict,
+                ),
+            )

         return params

@@ -107,7 +107,7 @@ def chat_messages(
 ) -> list[ChatCompletionMessageParam]:
     # o1 does not allow system messages so convert system -> user
     messages: list[ChatMessage] = [
-        ChatMessageUser(content=message.content)
+        ChatMessageUser(id=message.id, content=message.content)
         if message.role == "system"
         else message
         for message in input
@@ -148,7 +148,7 @@ def cf() -> type[ModelAPI]:
 def mistral() -> type[ModelAPI]:
     FEATURE = "Mistral API"
     PACKAGE = "mistralai"
-    MIN_VERSION = "1.5.0"
+    MIN_VERSION = "1.5.1"

     # verify we have the package
     try:
@@ -34,8 +34,8 @@ from .util import (
     chat_api_input,
     chat_api_request,
     environment_prerequisite_error,
-    is_chat_api_rate_limit,
     model_base_url,
+    should_retry_chat_api_error,
 )


@@ -186,7 +186,6 @@ class TogetherRESTAPI(ModelAPI):
            url=f"{chat_url}",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json=json,
-           config=config,
        )

        if "error" in response:
@@ -215,8 +214,8 @@
         return ModelOutput(model=model, choices=choices, usage=usage)

     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        return is_chat_api_rate_limit(ex)
+    def should_retry(self, ex: Exception) -> bool:
+        return should_retry_chat_api_error(ex)

     # cloudflare enforces rate limits by model for each account
     @override