inspect-ai 0.3.71__py3-none-any.whl → 0.3.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +14 -3
- inspect_ai/_cli/sandbox.py +3 -3
- inspect_ai/_cli/score.py +6 -4
- inspect_ai/_cli/trace.py +53 -6
- inspect_ai/_display/core/config.py +1 -1
- inspect_ai/_display/core/display.py +2 -1
- inspect_ai/_display/core/footer.py +6 -6
- inspect_ai/_display/plain/display.py +11 -6
- inspect_ai/_display/rich/display.py +23 -13
- inspect_ai/_display/textual/app.py +10 -9
- inspect_ai/_display/textual/display.py +2 -2
- inspect_ai/_display/textual/widgets/footer.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +14 -5
- inspect_ai/_eval/context.py +1 -2
- inspect_ai/_eval/eval.py +54 -41
- inspect_ai/_eval/loader.py +9 -2
- inspect_ai/_eval/run.py +148 -81
- inspect_ai/_eval/score.py +13 -8
- inspect_ai/_eval/task/images.py +31 -21
- inspect_ai/_eval/task/run.py +62 -59
- inspect_ai/_eval/task/rundir.py +16 -9
- inspect_ai/_eval/task/sandbox.py +7 -8
- inspect_ai/_eval/task/util.py +7 -0
- inspect_ai/_util/_async.py +118 -10
- inspect_ai/_util/constants.py +0 -2
- inspect_ai/_util/file.py +15 -29
- inspect_ai/_util/future.py +37 -0
- inspect_ai/_util/http.py +3 -99
- inspect_ai/_util/httpx.py +60 -0
- inspect_ai/_util/interrupt.py +2 -2
- inspect_ai/_util/json.py +5 -52
- inspect_ai/_util/logger.py +30 -86
- inspect_ai/_util/retry.py +10 -61
- inspect_ai/_util/trace.py +2 -2
- inspect_ai/_view/server.py +86 -3
- inspect_ai/_view/www/dist/assets/index.js +25837 -13269
- inspect_ai/_view/www/log-schema.json +253 -186
- inspect_ai/_view/www/package.json +2 -2
- inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
- inspect_ai/_view/www/src/types/log.d.ts +122 -94
- inspect_ai/approval/_human/manager.py +6 -10
- inspect_ai/approval/_human/panel.py +2 -2
- inspect_ai/dataset/_sources/util.py +7 -6
- inspect_ai/log/__init__.py +4 -0
- inspect_ai/log/_file.py +35 -61
- inspect_ai/log/_log.py +18 -1
- inspect_ai/log/_recorders/eval.py +14 -23
- inspect_ai/log/_recorders/json.py +3 -18
- inspect_ai/log/_samples.py +27 -2
- inspect_ai/log/_transcript.py +8 -8
- inspect_ai/model/__init__.py +2 -1
- inspect_ai/model/_call_tools.py +60 -40
- inspect_ai/model/_chat_message.py +3 -2
- inspect_ai/model/_generate_config.py +25 -0
- inspect_ai/model/_model.py +74 -36
- inspect_ai/model/_openai.py +9 -1
- inspect_ai/model/_providers/anthropic.py +172 -154
- inspect_ai/model/_providers/azureai.py +11 -9
- inspect_ai/model/_providers/bedrock.py +33 -24
- inspect_ai/model/_providers/cloudflare.py +8 -9
- inspect_ai/model/_providers/goodfire.py +7 -3
- inspect_ai/model/_providers/google.py +47 -13
- inspect_ai/model/_providers/groq.py +15 -15
- inspect_ai/model/_providers/hf.py +24 -17
- inspect_ai/model/_providers/mistral.py +36 -20
- inspect_ai/model/_providers/openai.py +30 -25
- inspect_ai/model/_providers/openai_o1.py +1 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +3 -4
- inspect_ai/model/_providers/util/__init__.py +2 -2
- inspect_ai/model/_providers/util/chatapi.py +6 -19
- inspect_ai/model/_providers/util/hooks.py +165 -0
- inspect_ai/model/_providers/vertex.py +20 -3
- inspect_ai/model/_providers/vllm.py +16 -19
- inspect_ai/scorer/_multi.py +5 -2
- inspect_ai/solver/_bridge/patch.py +31 -1
- inspect_ai/solver/_fork.py +5 -3
- inspect_ai/solver/_human_agent/agent.py +3 -2
- inspect_ai/tool/__init__.py +8 -2
- inspect_ai/tool/_tool_info.py +4 -90
- inspect_ai/tool/_tool_params.py +4 -34
- inspect_ai/tool/_tools/_computer/_common.py +117 -58
- inspect_ai/tool/_tools/_computer/_computer.py +80 -57
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +7 -1
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +91 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +8 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +12 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +78 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +20 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +175 -113
- inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +76 -20
- inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +65 -0
- inspect_ai/tool/_tools/_computer/test_args.py +151 -0
- inspect_ai/tool/_tools/_web_search.py +30 -24
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_concurrency.py +5 -6
- inspect_ai/util/_display.py +6 -0
- inspect_ai/util/_json.py +170 -0
- inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
- inspect_ai/util/_sandbox/docker/docker.py +5 -0
- inspect_ai/util/_sandbox/environment.py +56 -9
- inspect_ai/util/_sandbox/service.py +12 -5
- inspect_ai/util/_subprocess.py +94 -113
- inspect_ai/util/_subtask.py +2 -4
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +111 -103
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
- inspect_ai/_util/timeouts.py +0 -160
- inspect_ai/model/_providers/util/tracker.py +0 -92
- inspect_ai/tool/_tools/_computer/_computer_split.py +0 -198
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.71.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/azureai.py (deleted lines truncated by the page capture are marked `…`):

```diff
@@ -27,11 +27,16 @@ from azure.ai.inference.models import (
     UserMessage,
 )
 from azure.core.credentials import AzureKeyCredential
-from azure.core.exceptions import …
+from azure.core.exceptions import (
+    AzureError,
+    HttpResponseError,
+    ServiceResponseError,
+)
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
@@ -232,14 +237,11 @@ class AzureAIAPI(ModelAPI):
         return DEFAULT_MAX_TOKENS
 
     @override
-    def …
-        if isinstance(ex, HttpResponseError):
-            return (
-                …
-                …
-                or ex.status_code == 429
-                or ex.status_code == 500
-            )
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, HttpResponseError) and ex.status_code is not None:
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, ServiceResponseError):
+            return True
         else:
             return False
 
```
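Both hunks show the pattern repeated across the providers below: ad hoc inline status-code checks are replaced by a `should_retry` override that delegates to a shared `is_retryable_http_status` predicate in `inspect_ai/_util/http.py`. A minimal sketch of what such a predicate looks like (the status set here is an assumption; the real helper may differ):

```python
# Hypothetical sketch of a shared retry predicate; the actual
# is_retryable_http_status in inspect_ai/_util/http.py may use a
# different set of status codes.
RETRYABLE_HTTP_STATUSES = {408, 429, 500, 502, 503, 504}


def is_retryable_http_status(status_code: int) -> bool:
    """Return True for transient HTTP errors that are worth retrying."""
    return status_code in RETRYABLE_HTTP_STATUSES
```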
inspect_ai/model/_providers/bedrock.py:

```diff
@@ -1,16 +1,14 @@
 import base64
+from logging import getLogger
 from typing import Any, Literal, Tuple, Union, cast
 
 from pydantic import BaseModel, Field
 from typing_extensions import override
 
-from inspect_ai._util.constants import (
-    …
-    DEFAULT_MAX_TOKENS,
-    DEFAULT_TIMEOUT,
-)
+from inspect_ai._util._async import current_async_backend
+from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.error import pip_dependency_error
+from inspect_ai._util.error import PrerequisiteError, pip_dependency_error
 from inspect_ai._util.images import file_as_data
 from inspect_ai._util.version import verify_required_version
 from inspect_ai.tool import ToolChoice, ToolInfo
@@ -31,7 +29,9 @@ from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .util import (
     model_base_url,
 )
-from .util.…
+from .util.hooks import ConverseHooks
+
+logger = getLogger(__name__)
 
 # Model for Bedrock Converse API (Response)
 # generated from: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html#converse
@@ -245,6 +245,12 @@ class BedrockAPI(ModelAPI):
             config=config,
         )
 
+        # raise if we are using trio
+        if current_async_backend() == "trio":
+            raise PrerequisiteError(
+                "ERROR: The bedrock provider does not work with the trio async backend."
+            )
+
         # save model_args
         self.model_args = model_args
 
@@ -258,7 +264,7 @@ class BedrockAPI(ModelAPI):
             self.session = aioboto3.Session()
 
             # create time tracker
-            self.…
+            self._http_hooks = ConverseHooks(self.session)
 
         except ImportError:
             raise pip_dependency_error("Bedrock API", ["aioboto3"])
@@ -288,15 +294,25 @@ class BedrockAPI(ModelAPI):
         return DEFAULT_MAX_TOKENS
 
     @override
-    def …
+    def should_retry(self, ex: Exception) -> bool:
         from botocore.exceptions import ClientError
 
         # Look for an explicit throttle exception
         if isinstance(ex, ClientError):
-            …
-            …
-            …
-            …
+            error_code = ex.response.get("Error", {}).get("Code", "")
+            return error_code in [
+                "ThrottlingException",
+                "RequestLimitExceeded",
+                "Throttling",
+                "RequestThrottled",
+                "TooManyRequestsException",
+                "ProvisionedThroughputExceededException",
+                "TransactionInProgressException",
+                "RequestTimeout",
+                "ServiceUnavailable",
+            ]
+        else:
+            return False
 
     @override
     def collapse_user_messages(self) -> bool:
@@ -317,20 +333,13 @@ class BedrockAPI(ModelAPI):
         from botocore.exceptions import ClientError
 
         # The bedrock client
-        request_id = self.…
+        request_id = self._http_hooks.start_request()
         async with self.session.client(  # type: ignore[call-overload]
             service_name="bedrock-runtime",
             endpoint_url=self.base_url,
             config=Config(
-                …
-                …
-                retries=dict(
-                    max_attempts=config.max_retries
-                    if config.max_retries
-                    else DEFAULT_MAX_RETRIES,
-                    mode="adaptive",
-                ),
-                user_agent_extra=self._time_tracker.user_agent_extra(request_id),
+                retries=dict(mode="adaptive"),
+                user_agent_extra=self._http_hooks.user_agent_extra(request_id),
             ),
             **self.model_args,
         ) as client:
@@ -370,7 +379,7 @@ class BedrockAPI(ModelAPI):
                 request.model_dump(exclude_none=True)
             ),
             response=response,
-            time=self.…
+            time=self._http_hooks.end_request(request_id),
        )
 
         try:
```
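Because boto3 manages its own connection layer, `ConverseHooks` cannot use httpx hooks like the other providers; the diff instead threads a request id through `user_agent_extra` so the request can be recognized when it actually goes on the wire. A rough sketch of that mechanism, with the class internals and marker string assumed (only `session.events.register` and the `before-send` event are documented botocore APIs):

```python
# Sketch (assumed internals): time Bedrock requests by carrying a request
# id in the User-Agent suffix and matching it in botocore's event system.
import time
import uuid

import boto3


class ConverseHooksSketch:
    USER_AGENT_MARKER = "x-irid-"  # hypothetical marker string

    def __init__(self, session: boto3.Session) -> None:
        self._start: dict[str, float] = {}
        # "before-send" is a documented botocore event; the wildcard
        # matches every bedrock-runtime operation
        session.events.register("before-send.bedrock-runtime.*", self._before_send)

    def start_request(self) -> str:
        request_id = uuid.uuid4().hex
        self._start[request_id] = time.monotonic()
        return request_id

    def user_agent_extra(self, request_id: str) -> str:
        # passed by the caller to botocore Config(user_agent_extra=...)
        return f"{self.USER_AGENT_MARKER}{request_id}"

    def _before_send(self, request, **kwargs):
        # restart the clock when the request is actually sent, so local
        # queueing and botocore backoff are not counted as model time
        user_agent = str(request.headers.get("User-Agent", ""))
        for request_id in self._start:
            if self.USER_AGENT_MARKER + request_id in user_agent:
                self._start[request_id] = time.monotonic()

    def end_request(self, request_id: str) -> float:
        return time.monotonic() - self._start.pop(request_id)
```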
inspect_ai/model/_providers/cloudflare.py:

```diff
@@ -16,10 +16,10 @@ from .util import (
     chat_api_input,
     chat_api_request,
     environment_prerequisite_error,
-    is_chat_api_rate_limit,
     model_base_url,
+    should_retry_chat_api_error,
 )
-from .util.…
+from .util.hooks import HttpxHooks
 
 # https://developers.cloudflare.com/workers-ai/models/#text-generation
 
@@ -51,7 +51,7 @@ class CloudFlareAPI(ModelAPI):
         if not self.api_key:
             raise environment_prerequisite_error("CloudFlare", CLOUDFLARE_API_TOKEN)
         self.client = httpx.AsyncClient()
-        self.…
+        self._http_hooks = HttpxHooks(self.client)
         base_url = model_base_url(base_url, "CLOUDFLARE_BASE_URL")
         self.base_url = (
             base_url if base_url else "https://api.cloudflare.com/client/v4/accounts"
@@ -79,7 +79,7 @@ class CloudFlareAPI(ModelAPI):
         json["messages"] = chat_api_input(input, tools, self.chat_api_handler())
 
         # request_id
-        request_id = self.…
+        request_id = self._http_hooks.start_request()
 
         # setup response
         response: dict[str, Any] = {}
@@ -88,7 +88,7 @@ class CloudFlareAPI(ModelAPI):
             return ModelCall.create(
                 request=json,
                 response=response,
-                time=self.…
+                time=self._http_hooks.end_request(request_id),
             )
 
         # make the call
@@ -98,10 +98,9 @@ class CloudFlareAPI(ModelAPI):
             url=f"{chat_url}/{self.model_name}",
             headers={
                 "Authorization": f"Bearer {self.api_key}",
-                …
+                HttpxHooks.REQUEST_ID_HEADER: request_id,
             },
             json=json,
-            config=config,
         )
 
         # handle response
@@ -127,8 +126,8 @@ class CloudFlareAPI(ModelAPI):
             raise RuntimeError(f"Error calling {self.model_name}: {error}")
 
     @override
-    def …
-        return …
+    def should_retry(self, ex: Exception) -> bool:
+        return should_retry_chat_api_error(ex)
 
     # cloudflare enforces rate limits by model for each account
     @override
```
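For httpx-based providers (CloudFlare here, Groq below), `HttpxHooks` can lean on httpx's documented `event_hooks` extension point: the caller stamps each request with an id header, and a request hook restarts the timer when the request is actually sent. A standalone sketch, with the header name and internals assumed:

```python
import time
import uuid

import httpx


class HttpxHooksSketch:
    REQUEST_ID_HEADER = "x-request-id"  # hypothetical header name

    def __init__(self, client: httpx.AsyncClient) -> None:
        self._start: dict[str, float] = {}
        # event_hooks is httpx's documented extension point; hooks
        # registered on an AsyncClient must be async callables
        client.event_hooks["request"].append(self._on_request)

    def start_request(self) -> str:
        request_id = uuid.uuid4().hex
        self._start[request_id] = time.monotonic()
        return request_id

    async def _on_request(self, request: httpx.Request) -> None:
        # restart the clock at actual send time (excludes local queueing)
        request_id = request.headers.get(self.REQUEST_ID_HEADER)
        if request_id in self._start:
            self._start[request_id] = time.monotonic()

    def end_request(self, request_id: str) -> float:
        return time.monotonic() - self._start.pop(request_id)
```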
inspect_ai/model/_providers/goodfire.py:

```diff
@@ -3,7 +3,11 @@ from typing import Any, List, Literal, get_args
 
 from goodfire import AsyncClient
 from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
-from goodfire.api.exceptions import …
+from goodfire.api.exceptions import (
+    InvalidRequestException,
+    RateLimitException,
+    ServerErrorException,
+)
 from goodfire.variants.variants import SUPPORTED_MODELS, Variant
 from typing_extensions import override
 
@@ -163,9 +167,9 @@ class GoodfireAPI(ModelAPI):
         return ex
 
     @override
-    def …
+    def should_retry(self, ex: Exception) -> bool:
         """Check if exception is due to rate limiting."""
-        return isinstance(ex, RateLimitException)
+        return isinstance(ex, RateLimitException | ServerErrorException)
 
     @override
     def connection_key(self) -> str:
```
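The broadened check relies on Python 3.10+ accepting a PEP 604 union directly as the second argument to `isinstance`; it behaves exactly like the tuple form:

```python
# Stand-ins for the goodfire exception classes, for illustration only.
class RateLimitException(Exception): ...
class ServerErrorException(Exception): ...

ex: Exception = ServerErrorException()
assert isinstance(ex, RateLimitException | ServerErrorException)   # union form
assert isinstance(ex, (RateLimitException, ServerErrorException))  # tuple form
```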
inspect_ai/model/_providers/google.py:

```diff
@@ -1,4 +1,3 @@
-import asyncio
 import functools
 import hashlib
 import json
@@ -9,6 +8,7 @@ from logging import getLogger
 from typing import Any
 
 # SDK Docs: https://googleapis.github.io/python-genai/
+import anyio
 from google.genai import Client  # type: ignore
 from google.genai.errors import APIError, ClientError  # type: ignore
 from google.genai.types import (  # type: ignore
@@ -26,6 +26,7 @@ from google.genai.types import (  # type: ignore
     GenerationConfig,
     HarmBlockThreshold,
     HarmCategory,
+    HttpOptions,
     Part,
     SafetySetting,
     SafetySettingDict,
@@ -49,6 +50,7 @@ from inspect_ai._util.content import (
     ContentVideo,
 )
 from inspect_ai._util.error import PrerequisiteError
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data
 from inspect_ai._util.kvstore import inspect_kvstore
 from inspect_ai._util.trace import trace_message
@@ -69,6 +71,7 @@ from inspect_ai.model import (
 )
 from inspect_ai.model._model_call import ModelCall
 from inspect_ai.model._providers.util import model_base_url
+from inspect_ai.model._providers.util.hooks import HttpHooks, urllib3_hooks
 from inspect_ai.tool import (
     ToolCall,
     ToolChoice,
@@ -199,11 +202,15 @@ class GoogleGenAIAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+        # generate request_id
+        request_id = urllib3_hooks().start_request()
+
         # Create google-genai types.
         gemini_contents = await as_chat_messages(self.client, input)
         gemini_tools = chat_tools(tools) if len(tools) > 0 else None
         gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None
         parameters = GenerateContentConfig(
+            http_options=HttpOptions(headers={HttpHooks.REQUEST_ID_HEADER: request_id}),
             temperature=config.temperature,
             top_p=config.top_p,
             top_k=config.top_k,
@@ -219,6 +226,11 @@ class GoogleGenAIAPI(ModelAPI):
                 self.client, input
             ),
         )
+        if config.response_schema is not None:
+            parameters.response_mime_type = "application/json"
+            parameters.response_schema = schema_from_param(
+                config.response_schema.json_schema, nullable=None
+            )
 
         response: GenerateContentResponse | None = None
 
@@ -230,10 +242,9 @@ class GoogleGenAIAPI(ModelAPI):
                 tools=gemini_tools,
                 tool_config=gemini_tool_config,
                 response=response,
+                time=urllib3_hooks().end_request(request_id),
             )
 
-        # TODO: would need to monkey patch AuthorizedSession.request
-
         try:
             response = await self.client.aio.models.generate_content(
                 model=self.model_name,
@@ -252,11 +263,25 @@ class GoogleGenAIAPI(ModelAPI):
             return output, model_call()
 
     @override
-    def …
-        # …
-        …
-        …
-        )
+    def should_retry(self, ex: Exception) -> bool:
+        import requests  # type: ignore
+
+        # standard http errors
+        if isinstance(ex, APIError):
+            return is_retryable_http_status(ex.status)
+
+        # low-level requests exceptions
+        elif isinstance(ex, requests.exceptions.RequestException):
+            return isinstance(
+                ex,
+                (
+                    requests.exceptions.ConnectionError
+                    | requests.exceptions.ConnectTimeout
+                    | requests.exceptions.ChunkedEncodingError
+                ),
+            )
+        else:
+            return False
 
     @override
     def connection_key(self) -> str:
@@ -296,6 +321,7 @@ def build_model_call(
     tools: list[Tool] | None,
     tool_config: ToolConfig | None,
     response: GenerateContentResponse | None,
+    time: float | None,
 ) -> ModelCall:
     return ModelCall.create(
         request=dict(
@@ -307,6 +333,7 @@ def build_model_call(
         ),
         response=response if response is not None else {},
         filter=model_call_filter,
+        time=time,
     )
 
 
@@ -464,7 +491,9 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
 
 
 # https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
-def schema_from_param(…
+def schema_from_param(
+    param: ToolParam | ToolParams, nullable: bool | None = False
+) -> Schema:
     if isinstance(param, ToolParams):
         param = ToolParam(
             type=param.type, properties=param.properties, required=param.required
@@ -529,10 +558,13 @@ def chat_tool_config(tool_choice: ToolChoice) -> ToolConfig:
 
 
 def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
-    # check for completion text
-    content = ""
     # content can be None when the finish_reason is SAFETY
-    if candidate.content is …
+    if candidate.content is None:
+        content = ""
+    # content.parts can be None when the finish_reason is MALFORMED_FUNCTION_CALL
+    elif candidate.content.parts is None:
+        content = ""
+    else:
         content = " ".join(
             [
                 part.text
@@ -680,6 +712,8 @@ def finish_reason_to_stop_reason(finish_reason: FinishReason) -> StopReason:
         ):
             return "content_filter"
         case _:
+            # Note: to avoid adding another option to StopReason,
+            # this includes FinishReason.MALFORMED_FUNCTION_CALL
            return "unknown"
 
 
@@ -775,7 +809,7 @@ async def file_for_content(
             file=BytesIO(content_bytes), config=dict(mime_type=mime_type)
         )
         while upload.state.name == "PROCESSING":
-            await asyncio.sleep(3)
+            await anyio.sleep(3)
             upload = client.files.get(name=upload.name)
         if upload.state.name == "FAILED":
             trace(f"Failed to upload file '{upload.name}: {upload.error}")
```
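The `asyncio.sleep` → `anyio.sleep` swap (like the `current_async_backend()` guard added to bedrock.py) lines up with this release supporting trio as well as asyncio: anyio primitives suspend correctly under whichever backend is driving the event loop. A small illustration:

```python
import anyio


async def poll() -> str:
    # anyio.sleep works under asyncio and trio alike; asyncio.sleep
    # would fail with no asyncio event loop running (e.g. under trio)
    await anyio.sleep(0.1)
    return "done"


print(anyio.run(poll))                  # default asyncio backend
print(anyio.run(poll, backend="trio"))  # same coroutine on trio (requires trio)
```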
inspect_ai/model/_providers/groq.py:

```diff
@@ -5,8 +5,9 @@ from typing import Any, Dict, Iterable, List, Optional
 
 import httpx
 from groq import (
+    APIStatusError,
+    APITimeoutError,
     AsyncGroq,
-    RateLimitError,
 )
 from groq.types.chat import (
     ChatCompletion,
@@ -25,10 +26,10 @@ from typing_extensions import override
 
 from inspect_ai._util.constants import (
     BASE_64_DATA_REMOVED,
-    DEFAULT_MAX_RETRIES,
     DEFAULT_MAX_TOKENS,
 )
 from inspect_ai._util.content import Content, ContentReasoning, ContentText
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
@@ -54,7 +55,7 @@ from .util import (
     environment_prerequisite_error,
     model_base_url,
 )
-from .util.…
+from .util.hooks import HttpxHooks
 
 GROQ_API_KEY = "GROQ_API_KEY"
 
@@ -84,18 +85,12 @@ class GroqAPI(ModelAPI):
         self.client = AsyncGroq(
             api_key=self.api_key,
             base_url=model_base_url(base_url, "GROQ_BASE_URL"),
-            max_retries=(
-                config.max_retries
-                if config.max_retries is not None
-                else DEFAULT_MAX_RETRIES
-            ),
-            timeout=config.timeout if config.timeout is not None else 60.0,
             **model_args,
             http_client=httpx.AsyncClient(limits=httpx.Limits(max_connections=None)),
         )
 
         # create time tracker
-        self.…
+        self._http_hooks = HttpxHooks(self.client._client)
 
     @override
     async def close(self) -> None:
@@ -109,7 +104,7 @@ class GroqAPI(ModelAPI):
         config: GenerateConfig,
     ) -> tuple[ModelOutput, ModelCall]:
         # allocate request_id (so we can see it from ModelCall)
-        request_id = self.…
+        request_id = self._http_hooks.start_request()
 
         # setup request and response for ModelCall
         request: dict[str, Any] = {}
@@ -120,7 +115,7 @@ class GroqAPI(ModelAPI):
                 request=request,
                 response=response,
                 filter=model_call_filter,
-                time=self.…
+                time=self._http_hooks.end_request(request_id),
             )
 
         messages = await as_groq_chat_messages(input)
@@ -137,7 +132,7 @@ class GroqAPI(ModelAPI):
         request = dict(
             messages=messages,
             model=self.model_name,
-            extra_headers={…
+            extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
             **params,
         )
 
@@ -215,8 +210,13 @@ class GroqAPI(ModelAPI):
         ]
 
     @override
-    def …
-        …
+    def should_retry(self, ex: Exception) -> bool:
+        if isinstance(ex, APIStatusError):
+            return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, APITimeoutError):
+            return True
+        else:
+            return False
 
     @override
     def connection_key(self) -> str:
```
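Dropping `max_retries` and `timeout` from the SDK client is consistent with retry policy moving up into inspect_ai itself, where each provider's `should_retry` now just classifies exceptions. A schematic of that division of labor (illustrative only; inspect_ai's real retry loop also handles backoff configuration, rate-limit accounting, and logging):

```python
import anyio


async def generate_with_retry(api, *args, max_attempts: int = 5, **kwargs):
    """Illustrative outer loop driven by a provider's should_retry."""
    for attempt in range(1, max_attempts + 1):
        try:
            return await api.generate(*args, **kwargs)
        except Exception as ex:
            if attempt == max_attempts or not api.should_retry(ex):
                raise
            # exponential backoff (a real loop would also add jitter)
            await anyio.sleep(min(2**attempt, 30))
```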
inspect_ai/model/_providers/hf.py:

```diff
@@ -1,15 +1,19 @@
-import asyncio
+import concurrent
+import concurrent.futures
 import copy
 import functools
 import gc
 import json
 import os
 import time
+from concurrent.futures import Future
 from dataclasses import dataclass
+from logging import getLogger
 from queue import Empty, Queue
 from threading import Thread
 from typing import Any, Literal, Protocol, cast
 
+import anyio
 import numpy as np
 import torch  # type: ignore
 from torch import Tensor  # type: ignore
@@ -23,6 +27,7 @@ from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import ContentText
+from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import ToolChoice, ToolInfo
 
 from .._chat_message import ChatMessage, ChatMessageAssistant
@@ -38,6 +43,9 @@ from .._model_output import (
 )
 from .util import ChatAPIHandler, HFHandler
 
+logger = getLogger(__name__)
+
+
 HF_TOKEN = "HF_TOKEN"
 
 
@@ -385,8 +393,7 @@ class GenerateOutput:
 @dataclass
 class _QueueItem:
     input: GenerateInput
-    future: …
-    loop: asyncio.AbstractEventLoop
+    future: Future[GenerateOutput]
 
 
 batch_thread: Thread | None = None
@@ -402,25 +409,26 @@ async def batched_generate(input: GenerateInput) -> GenerateOutput:
         batch_thread.start()
 
     # enqueue the job
-    …
-    …
-    batch_queue.put(_QueueItem(input=input, future=future, loop=loop))
-
-    # await the job
-    await future
+    future = Future[GenerateOutput]()
+    batch_queue.put(_QueueItem(input=input, future=future))
 
-    # …
-    …
+    # await the future
+    with trace_action(logger, "HF Batched Generate", "HF Batched Generate"):
+        while True:
+            try:
+                return future.result(timeout=0.01)
+            except concurrent.futures.TimeoutError:
+                pass
+            await anyio.sleep(1)
 
 
 def process_batches() -> None:
     while True:
         # drain the queue (wait until no new messages have shown up for 2 seconds)
-        inputs: list[tuple[GenerateInput, …
+        inputs: list[tuple[GenerateInput, Future[GenerateOutput]]] = []
         while True:
             try:
                 input = batch_queue.get(timeout=2)
-                loop = input.loop
                 inputs.append((input.input, input.future))
                 if len(inputs) == input.input.batch_size:
                     # max batch size reached
@@ -480,8 +488,7 @@ def process_batches() -> None:
             # asyncio futures are not thread safe, so we need to pass the event loop
             # down to this point, so we can mark the future as done in a thread safe manner.
             # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
-            …
-                future.set_result,
+            future.set_result(
                 GenerateOutput(
                     output=output,
                     input_tokens=input_tokens,
@@ -489,13 +496,13 @@ def process_batches() -> None:
                     total_tokens=input_tokens + output_tokens,
                     logprobs=logprobs[i] if logprobs is not None else None,
                     time=total_time,
-                ),
+                )
             )
 
         except Exception as ex:
             for inp in inputs:
                 future = inp[1]
-                …
+                future.set_exception(ex)
 
 
 def extract_logprobs(
```
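The hf.py rewrite replaces asyncio futures, which must be resolved on their owning event loop, with `concurrent.futures.Future`, which the batching thread can complete directly; the async caller then polls `result(timeout=...)` and yields with `anyio.sleep` in between, which also makes the wait backend-agnostic. The same pattern reduced to a runnable sketch:

```python
import concurrent.futures
from concurrent.futures import Future
from queue import Queue
from threading import Thread

import anyio

jobs: "Queue[tuple[int, Future[int]]]" = Queue()


def worker() -> None:
    while True:
        value, future = jobs.get()
        future.set_result(value * 2)  # thread-safe, unlike asyncio futures


async def submit(value: int) -> int:
    future: Future[int] = Future()
    jobs.put((value, future))
    while True:
        try:
            # brief blocking check, then yield control back to the loop
            return future.result(timeout=0.01)
        except concurrent.futures.TimeoutError:
            await anyio.sleep(0.1)


Thread(target=worker, daemon=True).start()
print(anyio.run(submit, 21))  # prints 42
```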