inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +14 -3
- inspect_ai/_cli/sandbox.py +3 -3
- inspect_ai/_cli/score.py +6 -4
- inspect_ai/_cli/trace.py +53 -6
- inspect_ai/_display/core/config.py +1 -1
- inspect_ai/_display/core/display.py +2 -1
- inspect_ai/_display/core/footer.py +6 -6
- inspect_ai/_display/plain/display.py +11 -6
- inspect_ai/_display/rich/display.py +23 -13
- inspect_ai/_display/textual/app.py +10 -9
- inspect_ai/_display/textual/display.py +2 -2
- inspect_ai/_display/textual/widgets/footer.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +14 -5
- inspect_ai/_eval/context.py +1 -2
- inspect_ai/_eval/eval.py +54 -41
- inspect_ai/_eval/loader.py +9 -2
- inspect_ai/_eval/run.py +148 -81
- inspect_ai/_eval/score.py +13 -8
- inspect_ai/_eval/task/images.py +31 -21
- inspect_ai/_eval/task/run.py +62 -59
- inspect_ai/_eval/task/rundir.py +16 -9
- inspect_ai/_eval/task/sandbox.py +7 -8
- inspect_ai/_eval/task/util.py +7 -0
- inspect_ai/_util/_async.py +118 -10
- inspect_ai/_util/constants.py +0 -2
- inspect_ai/_util/file.py +15 -29
- inspect_ai/_util/future.py +37 -0
- inspect_ai/_util/http.py +3 -99
- inspect_ai/_util/httpx.py +60 -0
- inspect_ai/_util/interrupt.py +2 -2
- inspect_ai/_util/json.py +5 -52
- inspect_ai/_util/logger.py +30 -86
- inspect_ai/_util/retry.py +10 -61
- inspect_ai/_util/trace.py +2 -2
- inspect_ai/_view/server.py +86 -3
- inspect_ai/_view/www/dist/assets/index.js +25837 -13269
- inspect_ai/_view/www/log-schema.json +253 -186
- inspect_ai/_view/www/package.json +2 -2
- inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
- inspect_ai/_view/www/src/types/log.d.ts +122 -94
- inspect_ai/approval/_human/manager.py +6 -10
- inspect_ai/approval/_human/panel.py +2 -2
- inspect_ai/dataset/_sources/util.py +7 -6
- inspect_ai/log/__init__.py +4 -0
- inspect_ai/log/_file.py +35 -61
- inspect_ai/log/_log.py +18 -1
- inspect_ai/log/_recorders/eval.py +14 -23
- inspect_ai/log/_recorders/json.py +3 -18
- inspect_ai/log/_samples.py +27 -2
- inspect_ai/log/_transcript.py +8 -8
- inspect_ai/model/__init__.py +2 -1
- inspect_ai/model/_call_tools.py +60 -40
- inspect_ai/model/_chat_message.py +3 -2
- inspect_ai/model/_generate_config.py +25 -0
- inspect_ai/model/_model.py +74 -36
- inspect_ai/model/_openai.py +9 -1
- inspect_ai/model/_providers/anthropic.py +24 -26
- inspect_ai/model/_providers/azureai.py +11 -9
- inspect_ai/model/_providers/bedrock.py +33 -24
- inspect_ai/model/_providers/cloudflare.py +8 -9
- inspect_ai/model/_providers/goodfire.py +7 -3
- inspect_ai/model/_providers/google.py +47 -13
- inspect_ai/model/_providers/groq.py +15 -15
- inspect_ai/model/_providers/hf.py +24 -17
- inspect_ai/model/_providers/mistral.py +36 -20
- inspect_ai/model/_providers/openai.py +30 -25
- inspect_ai/model/_providers/openai_o1.py +1 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +3 -4
- inspect_ai/model/_providers/util/__init__.py +2 -2
- inspect_ai/model/_providers/util/chatapi.py +6 -19
- inspect_ai/model/_providers/util/hooks.py +165 -0
- inspect_ai/model/_providers/vertex.py +20 -3
- inspect_ai/model/_providers/vllm.py +16 -19
- inspect_ai/scorer/_multi.py +5 -2
- inspect_ai/solver/_bridge/patch.py +31 -1
- inspect_ai/solver/_fork.py +5 -3
- inspect_ai/solver/_human_agent/agent.py +3 -2
- inspect_ai/tool/__init__.py +8 -2
- inspect_ai/tool/_tool_info.py +4 -90
- inspect_ai/tool/_tool_params.py +4 -34
- inspect_ai/tool/_tools/_web_search.py +30 -24
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_concurrency.py +5 -6
- inspect_ai/util/_display.py +6 -0
- inspect_ai/util/_json.py +170 -0
- inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
- inspect_ai/util/_sandbox/docker/docker.py +5 -0
- inspect_ai/util/_sandbox/environment.py +56 -9
- inspect_ai/util/_sandbox/service.py +12 -5
- inspect_ai/util/_subprocess.py +94 -113
- inspect_ai/util/_subtask.py +2 -4
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
- inspect_ai/_util/timeouts.py +0 -160
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
- inspect_ai/model/_providers/util/tracker.py +0 -92
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/google.py

@@ -1,4 +1,3 @@
-import asyncio
 import functools
 import hashlib
 import json
@@ -9,6 +8,7 @@ from logging import getLogger
 from typing import Any

 # SDK Docs: https://googleapis.github.io/python-genai/
+import anyio
 from google.genai import Client  # type: ignore
 from google.genai.errors import APIError, ClientError  # type: ignore
 from google.genai.types import (  # type: ignore
@@ -26,6 +26,7 @@ from google.genai.types import (  # type: ignore
     GenerationConfig,
     HarmBlockThreshold,
     HarmCategory,
+    HttpOptions,
     Part,
     SafetySetting,
     SafetySettingDict,
@@ -49,6 +50,7 @@ from inspect_ai._util.content import (
     ContentVideo,
 )
 from inspect_ai._util.error import PrerequisiteError
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data
 from inspect_ai._util.kvstore import inspect_kvstore
 from inspect_ai._util.trace import trace_message
@@ -69,6 +71,7 @@ from inspect_ai.model import (
 )
 from inspect_ai.model._model_call import ModelCall
 from inspect_ai.model._providers.util import model_base_url
+from inspect_ai.model._providers.util.hooks import HttpHooks, urllib3_hooks
 from inspect_ai.tool import (
     ToolCall,
     ToolChoice,
@@ -199,11 +202,15 @@ class GoogleGenAIAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+        # generate request_id
+        request_id = urllib3_hooks().start_request()
+
         # Create google-genai types.
         gemini_contents = await as_chat_messages(self.client, input)
         gemini_tools = chat_tools(tools) if len(tools) > 0 else None
         gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None
         parameters = GenerateContentConfig(
+            http_options=HttpOptions(headers={HttpHooks.REQUEST_ID_HEADER: request_id}),
             temperature=config.temperature,
             top_p=config.top_p,
             top_k=config.top_k,
@@ -219,6 +226,11 @@ class GoogleGenAIAPI(ModelAPI):
                 self.client, input
             ),
         )
+        if config.response_schema is not None:
+            parameters.response_mime_type = "application/json"
+            parameters.response_schema = schema_from_param(
+                config.response_schema.json_schema, nullable=None
+            )

         response: GenerateContentResponse | None = None

@@ -230,10 +242,9 @@
                 tools=gemini_tools,
                 tool_config=gemini_tool_config,
                 response=response,
+                time=urllib3_hooks().end_request(request_id),
             )

-        # TODO: would need to monkey patch AuthorizedSession.request
-
         try:
             response = await self.client.aio.models.generate_content(
                 model=self.model_name,
@@ -252,11 +263,25 @@
             return output, model_call()

     @override
-    def
-    #
-
-
-    )
+    def should_retry(self, ex: Exception) -> bool:
+        import requests  # type: ignore
+
+        # standard http errors
+        if isinstance(ex, APIError):
+            return is_retryable_http_status(ex.status)
+
+        # low-level requests exceptions
+        elif isinstance(ex, requests.exceptions.RequestException):
+            return isinstance(
+                ex,
+                (
+                    requests.exceptions.ConnectionError
+                    | requests.exceptions.ConnectTimeout
+                    | requests.exceptions.ChunkedEncodingError
+                ),
+            )
+        else:
+            return False

     @override
     def connection_key(self) -> str:
@@ -296,6 +321,7 @@ def build_model_call(
     tools: list[Tool] | None,
     tool_config: ToolConfig | None,
     response: GenerateContentResponse | None,
+    time: float | None,
 ) -> ModelCall:
     return ModelCall.create(
         request=dict(
@@ -307,6 +333,7 @@ def build_model_call(
         ),
         response=response if response is not None else {},
         filter=model_call_filter,
+        time=time,
     )


@@ -464,7 +491,9 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:


 # https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
-def schema_from_param(
+def schema_from_param(
+    param: ToolParam | ToolParams, nullable: bool | None = False
+) -> Schema:
     if isinstance(param, ToolParams):
         param = ToolParam(
             type=param.type, properties=param.properties, required=param.required
@@ -529,10 +558,13 @@ def chat_tool_config(tool_choice: ToolChoice) -> ToolConfig:


 def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
-    # check for completion text
-    content = ""
     # content can be None when the finish_reason is SAFETY
-    if candidate.content is
+    if candidate.content is None:
+        content = ""
+    # content.parts can be None when the finish_reason is MALFORMED_FUNCTION_CALL
+    elif candidate.content.parts is None:
+        content = ""
+    else:
         content = " ".join(
             [
                 part.text
@@ -680,6 +712,8 @@ def finish_reason_to_stop_reason(finish_reason: FinishReason) -> StopReason:
         ):
             return "content_filter"
         case _:
+            # Note: to avoid adding another option to StopReason,
+            # this includes FinishReason.MALFORMED_FUNCTION_CALL
             return "unknown"


@@ -775,7 +809,7 @@ async def file_for_content(
             file=BytesIO(content_bytes), config=dict(mime_type=mime_type)
         )
         while upload.state.name == "PROCESSING":
-            await
+            await anyio.sleep(3)
             upload = client.files.get(name=upload.name)
         if upload.state.name == "FAILED":
             trace(f"Failed to upload file '{upload.name}: {upload.error}")
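The Google provider now times each call by allocating a request id, sending it as a header on the outgoing request (HttpOptions(headers={HttpHooks.REQUEST_ID_HEADER: request_id})), and reading the elapsed time back when the ModelCall is built. A minimal standalone sketch of that start/end bookkeeping follows; the header value and the dictionary-based tracker are illustrative assumptions, not the actual HttpHooks implementation in _providers/util/hooks.py.

    import time
    import uuid

    # Hypothetical sketch of the start_request()/end_request() pairing used above.
    # The real hooks attach to urllib3/httpx; the header value is an assumption.
    REQUEST_ID_HEADER = "x-request-id"

    class RequestTimer:
        def __init__(self) -> None:
            self._starts: dict[str, float] = {}

        def start_request(self) -> str:
            request_id = uuid.uuid4().hex
            self._starts[request_id] = time.monotonic()
            return request_id

        def end_request(self, request_id: str) -> float:
            # elapsed seconds for the request tagged with request_id
            return time.monotonic() - self._starts.pop(request_id)

    timer = RequestTimer()
    rid = timer.start_request()
    # ... issue the HTTP call with {REQUEST_ID_HEADER: rid} in its headers ...
    elapsed = timer.end_request(rid)

The Groq, Mistral, and OpenAI sections below use the same pairing via HttpxHooks attached to each SDK's underlying httpx client.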
inspect_ai/model/_providers/groq.py

@@ -5,8 +5,9 @@ from typing import Any, Dict, Iterable, List, Optional

 import httpx
 from groq import (
+    APIStatusError,
+    APITimeoutError,
     AsyncGroq,
-    RateLimitError,
 )
 from groq.types.chat import (
     ChatCompletion,
@@ -25,10 +26,10 @@ from typing_extensions import override

 from inspect_ai._util.constants import (
     BASE_64_DATA_REMOVED,
-    DEFAULT_MAX_RETRIES,
     DEFAULT_MAX_TOKENS,
 )
 from inspect_ai._util.content import Content, ContentReasoning, ContentText
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
@@ -54,7 +55,7 @@ from .util import (
     environment_prerequisite_error,
     model_base_url,
 )
-from .util.
+from .util.hooks import HttpxHooks

 GROQ_API_KEY = "GROQ_API_KEY"

@@ -84,18 +85,12 @@ class GroqAPI(ModelAPI):
         self.client = AsyncGroq(
             api_key=self.api_key,
             base_url=model_base_url(base_url, "GROQ_BASE_URL"),
-            max_retries=(
-                config.max_retries
-                if config.max_retries is not None
-                else DEFAULT_MAX_RETRIES
-            ),
-            timeout=config.timeout if config.timeout is not None else 60.0,
             **model_args,
             http_client=httpx.AsyncClient(limits=httpx.Limits(max_connections=None)),
         )

         # create time tracker
-        self.
+        self._http_hooks = HttpxHooks(self.client._client)

     @override
     async def close(self) -> None:
@@ -109,7 +104,7 @@ class GroqAPI(ModelAPI):
         config: GenerateConfig,
     ) -> tuple[ModelOutput, ModelCall]:
         # allocate request_id (so we can see it from ModelCall)
-        request_id = self.
+        request_id = self._http_hooks.start_request()

         # setup request and response for ModelCall
         request: dict[str, Any] = {}
@@ -120,7 +115,7 @@
             request=request,
             response=response,
             filter=model_call_filter,
-            time=self.
+            time=self._http_hooks.end_request(request_id),
         )

         messages = await as_groq_chat_messages(input)
@@ -137,7 +132,7 @@
         request = dict(
             messages=messages,
             model=self.model_name,
-            extra_headers={
+            extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
             **params,
         )

@@ -215,8 +210,13 @@
         ]

     @override
-    def
-
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, APIStatusError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, APITimeoutError):
+            return True
+        else:
+            return False

     @override
     def connection_key(self) -> str:
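The Groq changes above are representative of a refactor that runs through most providers in this release: client-level max_retries and timeout arguments are dropped, and each ModelAPI instead implements should_retry(ex), typically by mapping HTTP status codes through the new inspect_ai._util.http.is_retryable_http_status helper and treating timeouts as retryable. A rough sketch of that predicate shape for a generic httpx-based client; the status list here stands in for the real helper and is an assumption.

    import httpx

    # Assumed approximation of inspect_ai._util.http.is_retryable_http_status;
    # the exact status list in the real helper may differ.
    def is_retryable_http_status(status: int) -> bool:
        return status in (408, 409, 429, 500, 502, 503, 504)

    def should_retry(ex: Exception) -> bool:
        # retryable HTTP statuses and timeouts retry; everything else does not,
        # mirroring the Groq/Mistral/OpenAI predicates in this diff
        if isinstance(ex, httpx.HTTPStatusError):
            return is_retryable_http_status(ex.response.status_code)
        elif isinstance(ex, httpx.TimeoutException):
            return True
        else:
            return False

Centralizing the decision in should_retry presumably lets Inspect's own retry handling (see _util/retry.py and model/_model.py in the file list above) own backoff behaviour rather than each SDK client.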
inspect_ai/model/_providers/hf.py

@@ -1,15 +1,19 @@
-import
+import concurrent
+import concurrent.futures
 import copy
 import functools
 import gc
 import json
 import os
 import time
+from concurrent.futures import Future
 from dataclasses import dataclass
+from logging import getLogger
 from queue import Empty, Queue
 from threading import Thread
 from typing import Any, Literal, Protocol, cast

+import anyio
 import numpy as np
 import torch  # type: ignore
 from torch import Tensor  # type: ignore
@@ -23,6 +27,7 @@ from typing_extensions import override

 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import ContentText
+from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import ToolChoice, ToolInfo

 from .._chat_message import ChatMessage, ChatMessageAssistant
@@ -38,6 +43,9 @@ from .._model_output import (
 )
 from .util import ChatAPIHandler, HFHandler

+logger = getLogger(__name__)
+
+
 HF_TOKEN = "HF_TOKEN"


@@ -385,8 +393,7 @@ class GenerateOutput:
 @dataclass
 class _QueueItem:
     input: GenerateInput
-    future:
-    loop: asyncio.AbstractEventLoop
+    future: Future[GenerateOutput]


 batch_thread: Thread | None = None
@@ -402,25 +409,26 @@ async def batched_generate(input: GenerateInput) -> GenerateOutput:
         batch_thread.start()

     # enqueue the job
-
-
-    batch_queue.put(_QueueItem(input=input, future=future, loop=loop))
-
-    # await the job
-    await future
+    future = Future[GenerateOutput]()
+    batch_queue.put(_QueueItem(input=input, future=future))

-    #
-
+    # await the future
+    with trace_action(logger, "HF Batched Generate", "HF Batched Generate"):
+        while True:
+            try:
+                return future.result(timeout=0.01)
+            except concurrent.futures.TimeoutError:
+                pass
+            await anyio.sleep(1)


 def process_batches() -> None:
     while True:
         # drain the queue (wait until no new messages have shown up for 2 seconds)
-        inputs: list[tuple[GenerateInput,
+        inputs: list[tuple[GenerateInput, Future[GenerateOutput]]] = []
         while True:
             try:
                 input = batch_queue.get(timeout=2)
-                loop = input.loop
                 inputs.append((input.input, input.future))
                 if len(inputs) == input.input.batch_size:
                     # max batch size reached
@@ -480,8 +488,7 @@ def process_batches() -> None:
                 # asyncio futures are not thread safe, so we need to pass the event loop
                 # down to this point, so we can mark the future as done in a thread safe manner.
                 # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
-
-                future.set_result,
+                future.set_result(
                     GenerateOutput(
                         output=output,
                         input_tokens=input_tokens,
@@ -489,13 +496,13 @@
                         total_tokens=input_tokens + output_tokens,
                         logprobs=logprobs[i] if logprobs is not None else None,
                         time=total_time,
-                )
+                    )
                 )

         except Exception as ex:
             for inp in inputs:
                 future = inp[1]
-
+                future.set_exception(ex)


 def extract_logprobs(
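In the HF provider, batched_generate no longer hands an asyncio future plus its event loop to the background batching thread; it enqueues a thread-safe concurrent.futures.Future and polls it with a short non-blocking timeout between anyio sleeps, while the worker thread resolves it with set_result/set_exception. A small self-contained sketch of that handoff with illustrative names, assuming anyio is installed:

    import concurrent.futures
    from concurrent.futures import Future
    from threading import Thread

    import anyio

    def worker(future: Future[int]) -> None:
        # a concurrent.futures.Future can be resolved safely from any thread
        future.set_result(42)

    async def await_future(future: Future[int]) -> int:
        # poll without blocking the event loop for long, mirroring the pattern above
        while True:
            try:
                return future.result(timeout=0.01)
            except concurrent.futures.TimeoutError:
                pass
            await anyio.sleep(1)

    async def main() -> None:
        future: Future[int] = Future()
        Thread(target=worker, args=(future,), daemon=True).start()
        print(await await_future(future))

    anyio.run(main)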
inspect_ai/model/_providers/mistral.py

@@ -7,6 +7,7 @@ from httpcore import ReadTimeout
 from httpx import ReadTimeout as AsyncReadTimeout
 from mistralai import (
     ContentChunk,
+    DocumentURLChunk,
     FunctionCall,
     FunctionName,
     ImageURL,
@@ -22,6 +23,12 @@ from mistralai.models import (
     ChatCompletionChoice as MistralChatCompletionChoice,
 )
 from mistralai.models import Function as MistralFunction
+from mistralai.models import (
+    JSONSchema as MistralJSONSchema,
+)
+from mistralai.models import (
+    ResponseFormat as MistralResponseFormat,
+)
 from mistralai.models import SDKError
 from mistralai.models import SystemMessage as MistralSystemMessage
 from mistralai.models import Tool as MistralTool
@@ -38,11 +45,9 @@ from typing_extensions import override

 # TODO: Migration guide:
 # https://github.com/mistralai/client-python/blob/main/MIGRATION.md
-from inspect_ai._util.constants import
-    DEFAULT_TIMEOUT,
-    NO_CONTENT,
-)
+from inspect_ai._util.constants import NO_CONTENT
 from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

@@ -61,7 +66,7 @@ from .._model_output import (
     StopReason,
 )
 from .util import environment_prerequisite_error, model_base_url
-from .util.
+from .util.hooks import HttpxHooks

 AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
 AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"
@@ -127,16 +132,12 @@ class MistralAPI(ModelAPI):
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # create client
-        with Mistral(
-            api_key=self.api_key,
-            timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT) * 1000,
-            **self.model_args,
-        ) as client:
+        with Mistral(api_key=self.api_key, **self.model_args) as client:
             # create time tracker
-
+            http_hooks = HttpxHooks(client.sdk_configuration.async_client)

             # build request
-            request_id =
+            request_id = http_hooks.start_request()
             request: dict[str, Any] = dict(
                 model=self.model_name,
                 messages=await mistral_chat_messages(input),
@@ -144,7 +145,7 @@ class MistralAPI(ModelAPI):
                 tool_choice=(
                     mistral_chat_tool_choice(tool_choice) if len(tools) > 0 else None
                 ),
-                http_headers={
+                http_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
             )
             if config.temperature is not None:
                 request["temperature"] = config.temperature
@@ -154,6 +155,18 @@ class MistralAPI(ModelAPI):
                 request["max_tokens"] = config.max_tokens
             if config.seed is not None:
                 request["random_seed"] = config.seed
+            if config.response_schema is not None:
+                request["response_format"] = MistralResponseFormat(
+                    type="json_schema",
+                    json_schema=MistralJSONSchema(
+                        name=config.response_schema.name,
+                        description=config.response_schema.description,
+                        schema_definition=config.response_schema.json_schema.model_dump(
+                            exclude_none=True
+                        ),
+                        strict=config.response_schema.strict,
+                    ),
+                )

             # prepare response for inclusion in model call
             response: dict[str, Any] = {}
@@ -169,7 +182,7 @@ class MistralAPI(ModelAPI):
                 return ModelCall.create(
                     request=req,
                     response=response,
-                    time=
+                    time=http_hooks.end_request(request_id),
                 )

             # send request
@@ -205,12 +218,13 @@ class MistralAPI(ModelAPI):
             ), model_call()

     @override
-    def
-
-
-
-
-
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, SDKError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, ReadTimeout | AsyncReadTimeout):
+            return True
+        else:
+            return False

     @override
     def connection_key(self) -> str:
@@ -462,6 +476,8 @@ def completion_content_chunk(content: ContentChunk) -> Content:
         raise TypeError("ReferenceChunk content is not supported by Inspect.")
     elif isinstance(content, TextChunk):
         return ContentText(text=content.text)
+    elif isinstance(content, DocumentURLChunk):
+        return ContentText(text=content.document_url)
     else:
         if isinstance(content.image_url, str):
             return ContentImage(image=content.image_url)
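The response_schema handling added to the Mistral provider above (and to the OpenAI and Google providers in this diff) is driven by a new response_schema option on GenerateConfig (_generate_config.py +25) together with the JSON-schema utilities in the new inspect_ai/util/_json.py. A hedged caller-side sketch; the ResponseSchema and json_schema names are inferred from the new exports in model/__init__.py and util/__init__.py and may differ from the released API.

    from pydantic import BaseModel

    from inspect_ai.model import GenerateConfig, ResponseSchema  # ResponseSchema export assumed
    from inspect_ai.util import json_schema  # helper from the new util/_json.py, name assumed

    class Answer(BaseModel):
        answer: str
        confidence: float

    config = GenerateConfig(
        response_schema=ResponseSchema(
            name="answer",
            json_schema=json_schema(Answer),
            strict=True,
        )
    )
    # providers then emit e.g. response_format={"type": "json_schema", ...} (OpenAI)
    # or ResponseFormat(type="json_schema", json_schema=JSONSchema(...)) (Mistral)

Each provider maps the name/description/json_schema/strict fields onto its native structured-output request format, as the hunks above and below show.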
inspect_ai/model/_providers/openai.py

@@ -7,25 +7,22 @@ import httpx
 from openai import (
     DEFAULT_CONNECTION_LIMITS,
     DEFAULT_TIMEOUT,
-
+    APIStatusError,
     APITimeoutError,
     AsyncAzureOpenAI,
     AsyncOpenAI,
     BadRequestError,
-    InternalServerError,
     RateLimitError,
 )
 from openai._types import NOT_GIVEN
-from openai.types.chat import
-    ChatCompletion,
-)
+from openai.types.chat import ChatCompletion
 from typing_extensions import override

-from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.error import PrerequisiteError
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
-from inspect_ai.model._providers.util.
+from inspect_ai.model._providers.util.hooks import HttpxHooks
 from inspect_ai.tool import ToolChoice, ToolInfo

 from .._chat_message import ChatMessage
@@ -130,9 +127,6 @@ class OpenAIAPI(ModelAPI):
                 api_key=self.api_key,
                 azure_endpoint=base_url,
                 azure_deployment=model_name,
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 http_client=http_client,
                 **model_args,
             )
@@ -140,15 +134,12 @@ class OpenAIAPI(ModelAPI):
             self.client = AsyncOpenAI(
                 api_key=self.api_key,
                 base_url=model_base_url(base_url, "OPENAI_BASE_URL"),
-                max_retries=(
-                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-                ),
                 http_client=http_client,
                 **model_args,
             )

         # create time tracker
-        self.
+        self._http_hooks = HttpxHooks(self.client._client)

     def is_azure(self) -> bool:
         return self.service == "azure"
@@ -186,7 +177,7 @@ class OpenAIAPI(ModelAPI):
             )

         # allocate request_id (so we can see it from ModelCall)
-        request_id = self.
+        request_id = self._http_hooks.start_request()

         # setup request and response for ModelCall
         request: dict[str, Any] = {}
@@ -197,7 +188,7 @@ class OpenAIAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=image_url_filter,
-                time=self.
+                time=self._http_hooks.end_request(request_id),
             )

         # unlike text models, vision models require a max_tokens (and set it to a very low
@@ -216,7 +207,7 @@ class OpenAIAPI(ModelAPI):
                 tool_choice=openai_chat_tool_choice(tool_choice)
                 if len(tools) > 0
                 else NOT_GIVEN,
-                extra_headers={
+                extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
                 **self.completion_params(config, len(tools) > 0),
             )

@@ -266,17 +257,21 @@ class OpenAIAPI(ModelAPI):
         return chat_choices_from_openai(response, tools)

     @override
-    def
+    def should_retry(self, ex: Exception) -> bool:
         if isinstance(ex, RateLimitError):
             # Do not retry on these rate limit errors
             # The quota exceeded one is related to monthly account quotas.
-            if "You exceeded your current quota"
+            if "You exceeded your current quota" in ex.message:
+                warn_once(logger, f"OpenAI quota exceeded, not retrying: {ex.message}")
+                return False
+            else:
                 return True
-        elif isinstance(
-
-        ):
+        elif isinstance(ex, APIStatusError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, APITimeoutError):
             return True
-
+        else:
+            return False

     @override
     def connection_key(self) -> str:
@@ -315,8 +310,6 @@ class OpenAIAPI(ModelAPI):
             params["temperature"] = 1
         if config.top_p is not None:
             params["top_p"] = config.top_p
-        if config.timeout is not None:
-            params["timeout"] = float(config.timeout)
         if config.num_choices is not None:
             params["n"] = config.num_choices
         if config.logprobs is not None:
@@ -331,6 +324,18 @@ class OpenAIAPI(ModelAPI):
             and not self.is_o1_mini()
         ):
             params["reasoning_effort"] = config.reasoning_effort
+        if config.response_schema is not None:
+            params["response_format"] = dict(
+                type="json_schema",
+                json_schema=dict(
+                    name=config.response_schema.name,
+                    schema=config.response_schema.json_schema.model_dump(
+                        exclude_none=True
+                    ),
+                    description=config.response_schema.description,
+                    strict=config.response_schema.strict,
+                ),
+            )

         return params

inspect_ai/model/_providers/openai_o1.py

@@ -107,7 +107,7 @@ def chat_messages(
 ) -> list[ChatCompletionMessageParam]:
     # o1 does not allow system messages so convert system -> user
     messages: list[ChatMessage] = [
-        ChatMessageUser(content=message.content)
+        ChatMessageUser(id=message.id, content=message.content)
         if message.role == "system"
         else message
         for message in input
inspect_ai/model/_providers/together.py

@@ -34,8 +34,8 @@ from .util import (
     chat_api_input,
     chat_api_request,
     environment_prerequisite_error,
-    is_chat_api_rate_limit,
     model_base_url,
+    should_retry_chat_api_error,
 )


@@ -186,7 +186,6 @@ class TogetherRESTAPI(ModelAPI):
             url=f"{chat_url}",
             headers={"Authorization": f"Bearer {self.api_key}"},
             json=json,
-            config=config,
         )

         if "error" in response:
@@ -215,8 +214,8 @@
         return ModelOutput(model=model, choices=choices, usage=usage)

     @override
-    def
-    return
+    def should_retry(self, ex: Exception) -> bool:
+        return should_retry_chat_api_error(ex)

     # cloudflare enforces rate limits by model for each account
     @override