inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +14 -3
- inspect_ai/_cli/sandbox.py +3 -3
- inspect_ai/_cli/score.py +6 -4
- inspect_ai/_cli/trace.py +53 -6
- inspect_ai/_display/core/config.py +1 -1
- inspect_ai/_display/core/display.py +2 -1
- inspect_ai/_display/core/footer.py +6 -6
- inspect_ai/_display/plain/display.py +11 -6
- inspect_ai/_display/rich/display.py +23 -13
- inspect_ai/_display/textual/app.py +10 -9
- inspect_ai/_display/textual/display.py +2 -2
- inspect_ai/_display/textual/widgets/footer.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +14 -5
- inspect_ai/_eval/context.py +1 -2
- inspect_ai/_eval/eval.py +54 -41
- inspect_ai/_eval/loader.py +9 -2
- inspect_ai/_eval/run.py +148 -81
- inspect_ai/_eval/score.py +13 -8
- inspect_ai/_eval/task/images.py +31 -21
- inspect_ai/_eval/task/run.py +62 -59
- inspect_ai/_eval/task/rundir.py +16 -9
- inspect_ai/_eval/task/sandbox.py +7 -8
- inspect_ai/_eval/task/util.py +7 -0
- inspect_ai/_util/_async.py +118 -10
- inspect_ai/_util/constants.py +0 -2
- inspect_ai/_util/file.py +15 -29
- inspect_ai/_util/future.py +37 -0
- inspect_ai/_util/http.py +3 -99
- inspect_ai/_util/httpx.py +60 -0
- inspect_ai/_util/interrupt.py +2 -2
- inspect_ai/_util/json.py +5 -52
- inspect_ai/_util/logger.py +30 -86
- inspect_ai/_util/retry.py +10 -61
- inspect_ai/_util/trace.py +2 -2
- inspect_ai/_view/server.py +86 -3
- inspect_ai/_view/www/dist/assets/index.js +25837 -13269
- inspect_ai/_view/www/log-schema.json +253 -186
- inspect_ai/_view/www/package.json +2 -2
- inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
- inspect_ai/_view/www/src/types/log.d.ts +122 -94
- inspect_ai/approval/_human/manager.py +6 -10
- inspect_ai/approval/_human/panel.py +2 -2
- inspect_ai/dataset/_sources/util.py +7 -6
- inspect_ai/log/__init__.py +4 -0
- inspect_ai/log/_file.py +35 -61
- inspect_ai/log/_log.py +18 -1
- inspect_ai/log/_recorders/eval.py +14 -23
- inspect_ai/log/_recorders/json.py +3 -18
- inspect_ai/log/_samples.py +27 -2
- inspect_ai/log/_transcript.py +8 -8
- inspect_ai/model/__init__.py +2 -1
- inspect_ai/model/_call_tools.py +60 -40
- inspect_ai/model/_chat_message.py +3 -2
- inspect_ai/model/_generate_config.py +25 -0
- inspect_ai/model/_model.py +74 -36
- inspect_ai/model/_openai.py +9 -1
- inspect_ai/model/_providers/anthropic.py +24 -26
- inspect_ai/model/_providers/azureai.py +11 -9
- inspect_ai/model/_providers/bedrock.py +33 -24
- inspect_ai/model/_providers/cloudflare.py +8 -9
- inspect_ai/model/_providers/goodfire.py +7 -3
- inspect_ai/model/_providers/google.py +47 -13
- inspect_ai/model/_providers/groq.py +15 -15
- inspect_ai/model/_providers/hf.py +24 -17
- inspect_ai/model/_providers/mistral.py +36 -20
- inspect_ai/model/_providers/openai.py +30 -25
- inspect_ai/model/_providers/openai_o1.py +1 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +3 -4
- inspect_ai/model/_providers/util/__init__.py +2 -2
- inspect_ai/model/_providers/util/chatapi.py +6 -19
- inspect_ai/model/_providers/util/hooks.py +165 -0
- inspect_ai/model/_providers/vertex.py +20 -3
- inspect_ai/model/_providers/vllm.py +16 -19
- inspect_ai/scorer/_multi.py +5 -2
- inspect_ai/solver/_bridge/patch.py +31 -1
- inspect_ai/solver/_fork.py +5 -3
- inspect_ai/solver/_human_agent/agent.py +3 -2
- inspect_ai/tool/__init__.py +8 -2
- inspect_ai/tool/_tool_info.py +4 -90
- inspect_ai/tool/_tool_params.py +4 -34
- inspect_ai/tool/_tools/_web_search.py +30 -24
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_concurrency.py +5 -6
- inspect_ai/util/_display.py +6 -0
- inspect_ai/util/_json.py +170 -0
- inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
- inspect_ai/util/_sandbox/docker/docker.py +5 -0
- inspect_ai/util/_sandbox/environment.py +56 -9
- inspect_ai/util/_sandbox/service.py +12 -5
- inspect_ai/util/_subprocess.py +94 -113
- inspect_ai/util/_subtask.py +2 -4
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
- inspect_ai/_util/timeouts.py +0 -160
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
- inspect_ai/model/_providers/util/tracker.py +0 -92
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/util/__init__.py
CHANGED
@@ -5,7 +5,7 @@ from .chatapi import (
     ChatAPIMessage,
     chat_api_input,
     chat_api_request,
-    is_chat_api_rate_limit,
+    should_retry_chat_api_error,
 )
 from .hf_handler import HFHandler
 from .llama31 import Llama31Handler
@@ -19,7 +19,7 @@ __all__ = [
    "as_stop_reason",
    "chat_api_request",
    "chat_api_input",
-    "is_chat_api_rate_limit",
+    "should_retry_chat_api_error",
    "model_base_url",
    "parse_tool_call",
    "tool_parse_error_message",
inspect_ai/model/_providers/util/chatapi.py
CHANGED
@@ -7,17 +7,15 @@ from tenacity import (
     retry,
     retry_if_exception,
     stop_after_attempt,
-    stop_after_delay,
     wait_exponential_jitter,
 )
 
-from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
-from inspect_ai._util.retry import httpx_should_retry, log_retry_attempt
+from inspect_ai._util.http import is_retryable_http_status
+from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
 from inspect_ai.model._chat_message import ChatMessageAssistant, ChatMessageTool
 from inspect_ai.tool._tool_info import ToolInfo
 
 from ..._chat_message import ChatMessage
-from ..._generate_config import GenerateConfig
 
 logger = getLogger(__name__)
 
@@ -75,21 +73,13 @@ async def chat_api_request(
     url: str,
     headers: dict[str, Any],
     json: Any,
-    config: GenerateConfig,
 ) -> Any:
-    # provide default max_retries
-    max_retries = config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-
     # define call w/ retry policy
     @retry(
         wait=wait_exponential_jitter(),
-        stop=(
-            (stop_after_attempt(max_retries) | stop_after_delay(config.timeout))
-            if config.timeout
-            else stop_after_attempt(max_retries)
-        ),
+        stop=(stop_after_attempt(2)),
         retry=retry_if_exception(httpx_should_retry),
-        before_sleep=log_retry_attempt(model_name),
+        before_sleep=log_httpx_retry_attempt(model_name),
     )
     async def call_api() -> Any:
         response = await client.post(url=url, headers=headers, json=json)
@@ -104,14 +94,11 @@ async def chat_api_request(
 # checking for rate limit errors needs to punch through the RetryError and
 # look at its `__cause__`. we've observed Cloudflare giving transient 500
 # status as well as a ReadTimeout, so we count these as rate limit errors
-def is_chat_api_rate_limit(ex: BaseException) -> bool:
+def should_retry_chat_api_error(ex: BaseException) -> bool:
     return isinstance(ex, RetryError) and (
         (
             isinstance(ex.__cause__, httpx.HTTPStatusError)
-            and (
-                ex.__cause__.response.status_code == 429
-                or ex.__cause__.response.status_code == 500
-            )
+            and is_retryable_http_status(ex.__cause__.response.status_code)
         )
         or isinstance(ex.__cause__, httpx.ReadTimeout)
     )
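For reference, a minimal sketch of the simplified retry policy that chat_api_request now uses: tenacity caps attempts at 2 and delegates transient-error detection to a predicate. The transient() function below is a stand-in for inspect_ai._util.httpx.httpx_should_retry, and the URL parameter is illustrative:

import httpx
from tenacity import (
    retry,
    retry_if_exception,
    stop_after_attempt,
    wait_exponential_jitter,
)

def transient(ex: BaseException) -> bool:
    # stand-in for the real httpx_should_retry predicate
    return isinstance(ex, (httpx.TimeoutException, httpx.ConnectError))

@retry(
    wait=wait_exponential_jitter(),
    stop=stop_after_attempt(2),
    retry=retry_if_exception(transient),
)
async def call_api(client: httpx.AsyncClient, url: str) -> httpx.Response:
    # raise_for_status() converts 4xx/5xx into exceptions the predicate can see
    response = await client.post(url=url)
    response.raise_for_status()
    return response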
inspect_ai/model/_providers/util/hooks.py
ADDED
@@ -0,0 +1,165 @@
+import re
+import time
+from logging import getLogger
+from typing import Any, Mapping, NamedTuple, cast
+
+import httpx
+from shortuuid import uuid
+
+from inspect_ai._util.constants import HTTP
+from inspect_ai._util.retry import report_http_retry
+
+logger = getLogger(__name__)
+
+
+class RequestInfo(NamedTuple):
+    attempts: int
+    last_request: float
+
+
+class HttpHooks:
+    """Class which hooks various HTTP clients for improved tracking/logging.
+
+    A special header is injected into requests which is then read from
+    a request event hook -- this creates a record of when the request
+    started. Note that with retries a single request_id could be started
+    several times; our request hook makes sure we always track the time of
+    the last request.
+
+    There is an 'end_request()' method which gets the total request time
+    for a request_id and then purges the request_id from our tracking (so
+    the dict doesn't grow unbounded)
+
+    Additionally, an http response hook is installed and used for logging
+    requests for the 'http' log-level
+    """
+
+    REQUEST_ID_HEADER = "x-irid"
+
+    def __init__(self) -> None:
+        # track request start times
+        self._requests: dict[str, RequestInfo] = {}
+
+    def start_request(self) -> str:
+        request_id = uuid()
+        self._requests[request_id] = RequestInfo(0, time.monotonic())
+        return request_id
+
+    def end_request(self, request_id: str) -> float:
+        # read the request info (if available) and purge from dict
+        request_info = self._requests.pop(request_id, None)
+        if request_info is None:
+            raise RuntimeError(f"request_id not registered: {request_id}")
+
+        # return elapsed time
+        return time.monotonic() - request_info.last_request
+
+    def update_request_time(self, request_id: str) -> None:
+        request_info = self._requests.get(request_id, None)
+        if not request_info:
+            raise RuntimeError(f"No request registered for request_id: {request_id}")
+
+        # update the attempts and last request time
+        request_info = RequestInfo(request_info.attempts + 1, time.monotonic())
+        self._requests[request_id] = request_info
+
+        # trace a retry if this is attempt > 1
+        if request_info.attempts > 1:
+            report_http_retry()
+
+
+class ConverseHooks(HttpHooks):
+    def __init__(self, session: Any) -> None:
+        from aiobotocore.session import AioSession
+
+        super().__init__()
+
+        # register hooks
+        session = cast(AioSession, session._session)
+        session.register(
+            "before-send.bedrock-runtime.Converse", self.converse_before_send
+        )
+        session.register(
+            "after-call.bedrock-runtime.Converse", self.converse_after_call
+        )
+
+    def converse_before_send(self, **kwargs: Any) -> None:
+        user_agent = kwargs["request"].headers["User-Agent"].decode()
+        match = re.search(rf"{self.USER_AGENT_PREFIX}(\w+)", user_agent)
+        if match:
+            request_id = match.group(1)
+            self.update_request_time(request_id)
+
+    def converse_after_call(self, http_response: Any, **kwargs: Any) -> None:
+        from botocore.awsrequest import AWSResponse
+
+        response = cast(AWSResponse, http_response)
+        logger.log(HTTP, f"POST {response.url} - {response.status_code}")
+
+    def user_agent_extra(self, request_id: str) -> str:
+        return f"{self.USER_AGENT_PREFIX}{request_id}"
+
+    USER_AGENT_PREFIX = "ins/rid#"
+
+
+class HttpxHooks(HttpHooks):
+    def __init__(self, client: httpx.AsyncClient):
+        super().__init__()
+
+        # install hooks
+        client.event_hooks["request"].append(self.request_hook)
+        client.event_hooks["response"].append(self.response_hook)
+
+    async def request_hook(self, request: httpx.Request) -> None:
+        # update the last request time for this request id (as there could be retries)
+        request_id = request.headers.get(self.REQUEST_ID_HEADER, None)
+        if request_id:
+            self.update_request_time(request_id)
+
+    async def response_hook(self, response: httpx.Response) -> None:
+        message = f'{response.request.method} {response.request.url} "{response.http_version} {response.status_code} {response.reason_phrase}" '
+        logger.log(HTTP, message)
+
+
+def urllib3_hooks() -> HttpHooks:
+    import urllib3
+    from urllib3.connectionpool import HTTPConnectionPool
+    from urllib3.response import BaseHTTPResponse
+
+    class Urllib3Hooks(HttpHooks):
+        def request_hook(self, headers: Mapping[str, str]) -> None:
+            # update the last request time for this request id (as there could be retries)
+            request_id = headers.get(self.REQUEST_ID_HEADER, None)
+            if request_id:
+                self.update_request_time(request_id)
+
+        def response_hook(
+            self, method: str, url: str, response: BaseHTTPResponse
+        ) -> None:
+            message = f'{method} {url} "{response.version_string} {response.status} {response.reason}" '
+            logger.log(HTTP, message)
+
+    global _urlilb3_hooks
+    if _urlilb3_hooks is None:
+        # one time patch of urlopen
+        urlilb3_hooks = Urllib3Hooks()
+        original_urlopen = urllib3.connectionpool.HTTPConnectionPool.urlopen
+
+        def patched_urlopen(
+            self: HTTPConnectionPool, method: str, url: str, **kwargs: Any
+        ) -> BaseHTTPResponse:
+            headers = kwargs.get("headers", {})
+            urlilb3_hooks.request_hook(headers)
+            response = original_urlopen(self, method, url, **kwargs)
+            urlilb3_hooks.response_hook(method, f"{self.host}{url}", response)
+            return response
+
+        urllib3.connectionpool.HTTPConnectionPool.urlopen = patched_urlopen  # type: ignore[assignment,method-assign]
+
+        # assign to global hooks instance
+        _urlilb3_hooks = urlilb3_hooks
+
+    return _urlilb3_hooks
+
+
+_urlilb3_hooks: HttpHooks | None = None
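A sketch of how a provider can drive the HttpxHooks class added above: register the hooks on an AsyncClient, tag each request with the tracking header, and read back the elapsed time. The endpoint is illustrative, and MockTransport keeps the demo offline:

import asyncio

import httpx

from inspect_ai.model._providers.util.hooks import HttpxHooks

async def main() -> None:
    # MockTransport avoids real network traffic for the demo
    transport = httpx.MockTransport(lambda request: httpx.Response(200))
    client = httpx.AsyncClient(transport=transport)
    hooks = HttpxHooks(client)

    # start tracking and tag the request with the tracking header
    request_id = hooks.start_request()
    headers = {HttpxHooks.REQUEST_ID_HEADER: request_id}
    await client.post("https://example.com/v1/chat", headers=headers)

    # total elapsed time of the (possibly retried) request
    print(f"request took {hooks.end_request(request_id):.3f}s")

asyncio.run(main())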
inspect_ai/model/_providers/vertex.py
CHANGED
@@ -4,7 +4,13 @@ from copy import copy
 from typing import Any, cast
 
 import vertexai  # type: ignore
-from google.api_core.exceptions import TooManyRequests
+from google.api_core.exceptions import (
+    Aborted,
+    ClientError,
+    DeadlineExceeded,
+    ServiceUnavailable,
+)
+from google.api_core.retry import if_transient_error
 from google.protobuf.json_format import MessageToDict
 from pydantic import JsonValue
 from typing_extensions import override
@@ -31,6 +37,7 @@ from inspect_ai._util.content import (
     ContentText,
     ContentVideo,
 )
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo
 
@@ -169,8 +176,18 @@ class VertexAPI(ModelAPI):
         return output, call
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        return isinstance(ex, TooManyRequests)
+    def should_retry(self, ex: Exception) -> bool:
+        # google API-specific errors
+        if isinstance(ex, Aborted | DeadlineExceeded | ServiceUnavailable):
+            return True
+        # standard HTTP errors
+        elif isinstance(ex, ClientError) and ex.code is not None:
+            return is_retryable_http_status(ex.code)
+        # additional errors flagged by google as transient
+        elif isinstance(ex, Exception):
+            return if_transient_error(ex)
+        else:
+            return False
 
     @override
     def connection_key(self) -> str:
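A standalone sketch mirroring the layered should_retry() classification above, so it can be exercised without a VertexAPI instance. The 429/5xx check approximates is_retryable_http_status, which is not shown in this diff:

from google.api_core.exceptions import (
    Aborted,
    ClientError,
    DeadlineExceeded,
    NotFound,
    ServiceUnavailable,
    TooManyRequests,
)
from google.api_core.retry import if_transient_error

def should_retry(ex: Exception) -> bool:
    # explicit google transient errors first
    if isinstance(ex, Aborted | DeadlineExceeded | ServiceUnavailable):
        return True
    # then any 4xx error with a retryable status code
    elif isinstance(ex, ClientError) and ex.code is not None:
        return ex.code == 429 or ex.code >= 500  # approximation
    # finally defer to google's own transient-error predicate
    elif isinstance(ex, Exception):
        return if_transient_error(ex)
    return False

assert should_retry(TooManyRequests("rate limited"))  # 429 via ClientError branch
assert not should_retry(NotFound("no such model"))    # 404 is not retryable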
inspect_ai/model/_providers/vllm.py
CHANGED
@@ -1,13 +1,15 @@
-import asyncio
+import concurrent.futures
 import functools
 import gc
 import os
 import time
+from concurrent.futures import Future
 from dataclasses import dataclass
 from queue import Empty, Queue
 from threading import Thread
 from typing import Any, cast
 
+import anyio
 from typing_extensions import override
 from vllm import LLM, CompletionOutput, RequestOutput, SamplingParams  # type: ignore
 
@@ -280,8 +282,7 @@ class VLLMAPI(ModelAPI):
 @dataclass
 class _QueueItem:
     input: GenerateInput
-    future: asyncio.Future[list[GenerateOutput]]
-    loop: asyncio.AbstractEventLoop
+    future: Future[list[GenerateOutput]]
 
 
 batch_thread: Thread | None = None
@@ -297,15 +298,16 @@ async def batched_generate(input: GenerateInput) -> list[GenerateOutput]:
         batch_thread.start()
 
     # enqueue the job
-    loop = asyncio.get_event_loop()
-    future: asyncio.Future[list[GenerateOutput]] = loop.create_future()
-    batch_queue.put(_QueueItem(input=input, future=future, loop=loop))
+    future = Future[list[GenerateOutput]]()
+    batch_queue.put(_QueueItem(input=input, future=future))
 
-    # await the future
-    await future
-
-    # return it
-    return future.result()
+    # await the future
+    while True:
+        try:
+            return future.result(timeout=0.01)
+        except concurrent.futures.TimeoutError:
+            pass
+        await anyio.sleep(1)
 
 
 def string_to_bytes(string: str) -> list[int]:
@@ -397,13 +399,12 @@ def post_process_outputs(
 def process_batches() -> None:
     while True:
         # drain the queue (wait until no new messages have shown up for 2 seconds)
-        inputs: list[tuple[GenerateInput, asyncio.Future[list[GenerateOutput]]]] = []
+        inputs: list[tuple[GenerateInput, Future[list[GenerateOutput]]]] = []
         while True:
            try:
                 input = batch_queue.get(
                     timeout=2
                 )  # wait 2 seconds max TODO: what's optimal wait time?
-                loop = input.loop
                 inputs.append((input.input, input.future))
                 if len(inputs) >= input.input.batch_size:
                     # max batch size reached
@@ -429,14 +430,10 @@ def process_batches() -> None:
             for i, output in enumerate(outputs):
                 future = inputs[i][1]
 
-                # …
-                # down to this point, so we can mark the future as done in a thread safe manner.
-                # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
-                loop.call_soon_threadsafe(
-                    future.set_result,
+                future.set_result(
                     post_process_outputs(output, num_top_logprobs, total_time),
                 )
 
         except Exception as e:
             for _, future in inputs:
-                loop.call_soon_threadsafe(future.set_exception, e)
+                future.set_exception(e)
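The handoff pattern above, reduced to a self-contained sketch: the worker thread resolves a thread-safe concurrent.futures.Future directly (no call_soon_threadsafe, no captured event loop), and the async side polls it with short timeouts between anyio sleeps:

import concurrent.futures
from concurrent.futures import Future
from threading import Thread

import anyio

def worker(future: Future[str]) -> None:
    # simulate a batch computation on the worker thread
    future.set_result("done")

async def main() -> None:
    future = Future[str]()
    Thread(target=worker, args=(future,), daemon=True).start()
    while True:
        try:
            # non-blocking-ish probe of the future
            print(future.result(timeout=0.01))
            return
        except concurrent.futures.TimeoutError:
            pass
        await anyio.sleep(0.1)  # yield to the event loop between probes

anyio.run(main)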
inspect_ai/scorer/_multi.py
CHANGED
@@ -1,5 +1,6 @@
-import asyncio
+import functools
 
+from inspect_ai._util._async import tg_collect
 from inspect_ai.scorer._reducer.registry import create_reducers
 from inspect_ai.solver._task_state import TaskState
 
@@ -19,7 +20,9 @@ def multi_scorer(scorers: list[Scorer], reducer: str | ScoreReducer) -> Scorer:
     reducer = create_reducers(reducer)[0]
 
     async def score(state: TaskState, target: Target) -> Score:
-        scores = await asyncio.gather(*[_scorer(state, target) for _scorer in scorers])
+        scores = await tg_collect(
+            [functools.partial(_scorer, state, target) for _scorer in scorers]
+        )
         return reducer(scores)
 
     return score
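tg_collect itself is not shown in this diff (it lives in inspect_ai/_util/_async.py, also changed in this release). Based on how it is called here, an illustrative reimplementation of its contract, run a list of zero-argument async callables under one anyio task group and return their results in order, might look like:

from typing import Awaitable, Callable, TypeVar

import anyio

T = TypeVar("T")

async def tg_collect(funcs: list[Callable[[], Awaitable[T]]]) -> list[T]:
    # pre-size the results list so each task can write to its own slot
    results: list[T | None] = [None] * len(funcs)

    async def run(i: int) -> None:
        results[i] = await funcs[i]()

    # the task group waits for all children (and cancels siblings on error)
    async with anyio.create_task_group() as tg:
        for i in range(len(funcs)):
            tg.start_soon(run, i)
    return results  # type: ignore[return-value]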
inspect_ai/solver/_bridge/patch.py
CHANGED
@@ -11,11 +11,12 @@ from openai._types import ResponseT
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessageParam,
+    ChatCompletionToolChoiceOptionParam,
     ChatCompletionToolParam,
 )
 from shortuuid import uuid
 
-from inspect_ai.model._generate_config import GenerateConfig
+from inspect_ai.model._generate_config import GenerateConfig, ResponseSchema
 from inspect_ai.model._model import get_model
 from inspect_ai.model._openai import (
     chat_messages_from_openai,
@@ -23,8 +24,10 @@ from inspect_ai.model._openai import (
     openai_completion_usage,
 )
 from inspect_ai.solver._task_state import sample_state
+from inspect_ai.tool._tool_choice import ToolChoice, ToolFunction
 from inspect_ai.tool._tool_info import ToolInfo
 from inspect_ai.tool._tool_params import ToolParams
+from inspect_ai.util._json import JSONSchema
 
 
 @contextlib.asynccontextmanager
@@ -113,6 +116,20 @@ async def inspect_model_request(
         )
     )
 
+    # convert openai tool choice to inspect tool_choice
+    inspect_tool_choice: ToolChoice | None = None
+    tool_choice: ChatCompletionToolChoiceOptionParam | None = json_data.get(
+        "tool_choice", None
+    )
+    if tool_choice is not None:
+        match tool_choice:
+            case "auto" | "none":
+                inspect_tool_choice = tool_choice
+            case "required":
+                inspect_tool_choice = "any"
+            case _:
+                inspect_tool_choice = ToolFunction(name=tool_choice["function"]["name"])
+
     # resolve model
     if model_name == "inspect":
         model = get_model()
@@ -122,6 +139,7 @@ async def inspect_model_request(
     output = await model.generate(
         input=input,
         tools=inspect_tools,
+        tool_choice=inspect_tool_choice,
         config=generate_config_from_openai(options),
     )
 
@@ -165,4 +183,16 @@ def generate_config_from_openai(options: FinalRequestOptions) -> GenerateConfig:
     config.parallel_tool_calls = json_data.get("parallel_tool_calls", None)
     config.reasoning_effort = json_data.get("reasoning_effort", None)
 
+    # response format
+    response_format: dict[str, Any] | None = json_data.get("response_format", None)
+    if response_format is not None:
+        json_schema: dict[str, Any] | None = response_format.get("json_schema", None)
+        if json_schema is not None:
+            config.response_schema = ResponseSchema(
+                name=json_schema.get("name", "schema"),
+                description=json_schema.get("description", None),
+                json_schema=JSONSchema.model_validate(json_schema.get("schema", {})),
+                strict=json_schema.get("strict", None),
+            )
+
     return config
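The tool_choice conversion above, extracted into a standalone sketch with hypothetical sample payloads: OpenAI's "required" maps to Inspect's "any", and a named function choice maps to a ToolFunction:

from inspect_ai.tool._tool_choice import ToolChoice, ToolFunction

def to_inspect_tool_choice(tool_choice) -> ToolChoice:
    match tool_choice:
        case "auto" | "none":
            return tool_choice          # pass through unchanged
        case "required":
            return "any"                # OpenAI "required" == Inspect "any"
        case _:
            return ToolFunction(name=tool_choice["function"]["name"])

assert to_inspect_tool_choice("required") == "any"
assert to_inspect_tool_choice(
    {"type": "function", "function": {"name": "get_weather"}}
) == ToolFunction(name="get_weather")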
inspect_ai/solver/_fork.py
CHANGED
@@ -1,10 +1,11 @@
-import asyncio
+import functools
 from contextvars import ContextVar
 from copy import deepcopy
 from typing import Any, cast
 
 from typing_extensions import overload
 
+from inspect_ai._util._async import tg_collect
 from inspect_ai._util.registry import registry_log_name, registry_params
 from inspect_ai.util._subtask import subtask
 
@@ -44,8 +45,9 @@ async def fork(
     if isinstance(solvers, Solver):
         return await solver_subtask(state, solvers)
     else:
-        subtasks = [solver_subtask(state, solver) for solver in solvers]
-        return await asyncio.gather(*subtasks)
+        return await tg_collect(
+            [functools.partial(solver_subtask, state, solver) for solver in solvers]
+        )
 
 
 async def solver_subtask(state: TaskState, solver: Solver) -> TaskState:
inspect_ai/solver/_human_agent/agent.py
CHANGED
@@ -1,6 +1,7 @@
-import asyncio
 from typing import cast
 
+import anyio
+
 from inspect_ai.util import display_type, input_panel, sandbox
 from inspect_ai.util._sandbox.events import SandboxEnvironmentProxy
 
@@ -42,7 +43,7 @@ def human_agent(
         Solver: Human agent solver.
     """
     # we can only run one human agent interaction at a time (use lock to enforce)
-    agent_lock = asyncio.Lock()
+    agent_lock = anyio.Lock()
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
         async with agent_lock:
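This release moves from asyncio primitives to anyio throughout, and the lock swap above is representative: the async-with protocol is identical. A minimal sketch of the anyio equivalent:

import anyio

async def main() -> None:
    lock = anyio.Lock()
    async with lock:  # same async-with protocol as asyncio.Lock
        await anyio.sleep(0.1)

anyio.run(main)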
inspect_ai/tool/__init__.py
CHANGED
@@ -20,7 +20,7 @@ from ._tool_call import (
 from ._tool_choice import ToolChoice, ToolFunction
 from ._tool_def import ToolDef
 from ._tool_info import ToolInfo
-from ._tool_params import JSONType, ToolParam, ToolParams
+from ._tool_params import ToolParam, ToolParams
 from ._tool_with import tool_with
 from ._tools._computer import computer
 from ._tools._execute import bash, python
@@ -56,12 +56,18 @@ __all__ = [
    "ToolInfo",
    "ToolParam",
    "ToolParams",
-    "JSONType",
 ]
 
 _UTIL_MODULE_VERSION = "0.3.19"
+_JSON_MODULE_VERSION = "0.3.73"
 _REMOVED_IN = "0.4"
 
+relocated_module_attribute(
+    "JSONType",
+    "inspect_ai.util.JSONType",
+    _JSON_MODULE_VERSION,
+    _REMOVED_IN,
+)
+
 relocated_module_attribute(
     "ToolEnvironment",
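What the relocation shim above does in practice, as an illustrative sketch. This assumes relocated_module_attribute installs a deprecating module-level __getattr__ that forwards to the new location named in the relocation string ("inspect_ai.util.JSONType"):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # the old import path still works until 0.4, with a deprecation warning
    from inspect_ai.tool import JSONType

from inspect_ai.util import JSONType as NewJSONType

assert JSONType is NewJSONType  # same object, new home
assert len(caught) >= 1         # a deprecation warning was emitted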
inspect_ai/tool/_tool_info.py
CHANGED
@@ -1,27 +1,19 @@
 import inspect
-import types
-import typing
-from dataclasses import is_dataclass
 from typing import (
     Any,
     Callable,
     Dict,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    Union,
     get_args,
-    get_origin,
     get_type_hints,
-    is_typeddict,
 )
 
 from docstring_parser import Docstring, parse
 from pydantic import BaseModel, Field
 
+from inspect_ai.util._json import JSONType, json_schema
+
 from ._tool_description import tool_description
-from ._tool_params import JSONType, ToolParam, ToolParams
+from ._tool_params import ToolParam, ToolParams
 
 
 class ToolInfo(BaseModel):
@@ -88,7 +80,7 @@ def parse_tool_info(func: Callable[..., Any]) -> ToolInfo:
 
         # Get type information from type annotations
         if param_name in type_hints:
-            tool_param = parse_type(type_hints[param_name])
+            tool_param = json_schema(type_hints[param_name])
         # as a fallback try to parse it from the docstring
         # (this is minimally necessary for backwards compatiblity
         # with tools gen1 type parsing, which only used docstrings)
@@ -129,84 +121,6 @@ def parse_tool_info(func: Callable[..., Any]) -> ToolInfo:
     return info
 
 
-def parse_type(type_hint: Type[Any]) -> ToolParam:
-    origin = get_origin(type_hint)
-    args = get_args(type_hint)
-
-    if origin is None:
-        if type_hint is int:
-            return ToolParam(type="integer")
-        elif type_hint is float:
-            return ToolParam(type="number")
-        elif type_hint is str:
-            return ToolParam(type="string")
-        elif type_hint is bool:
-            return ToolParam(type="boolean")
-        elif type_hint is list:
-            return ToolParam(type="array", items=ToolParam())
-        elif type_hint is dict:
-            return ToolParam(type="object", additionalProperties=ToolParam())
-        elif (
-            is_dataclass(type_hint)
-            or is_typeddict(type_hint)
-            or (isinstance(type_hint, type) and issubclass(type_hint, BaseModel))
-        ):
-            return parse_object(type_hint)
-        elif type_hint is type(None):
-            return ToolParam(type="null")
-        else:
-            return ToolParam()
-    elif origin is list or origin is List or origin is tuple or origin is Tuple:
-        return ToolParam(
-            type="array", items=parse_type(args[0]) if args else ToolParam()
-        )
-    elif origin is dict or origin is Dict:
-        return ToolParam(
-            type="object",
-            additionalProperties=parse_type(args[1]) if len(args) > 1 else ToolParam(),
-        )
-    elif origin is Union or origin is types.UnionType:
-        return ToolParam(anyOf=[parse_type(arg) for arg in args])
-    elif origin is Optional:
-        return ToolParam(
-            anyOf=[parse_type(arg) for arg in args] + [ToolParam(type="null")]
-        )
-    elif origin is typing.Literal:
-        return ToolParam(enum=list(args))
-
-    return ToolParam()  # Default case if we can't determine the type
-
-
-def parse_object(cls: Type[Any]) -> ToolParam:
-    properties: Dict[str, ToolParam] = {}
-    required: List[str] = []
-
-    if is_dataclass(cls):
-        fields = cls.__dataclass_fields__  # type: ignore
-        for name, field in fields.items():
-            properties[name] = parse_type(field.type)  # type: ignore
-            if field.default == field.default_factory:
-                required.append(name)
-    elif isinstance(cls, type) and issubclass(cls, BaseModel):
-        schema = cls.model_json_schema()
-        for name, prop in schema.get("properties", {}).items():
-            properties[name] = ToolParam(**prop)
-        required = schema.get("required", [])
-    elif is_typeddict(cls):
-        annotations = get_type_hints(cls)
-        for name, type_hint in annotations.items():
-            properties[name] = parse_type(type_hint)
-            if name in cls.__required_keys__:
-                required.append(name)
-
-    return ToolParam(
-        type="object",
-        properties=properties,
-        required=required if required else None,
-        additionalProperties=False,
-    )
-
-
 def parse_docstring(docstring: str | None, param_name: str) -> Dict[str, str]:
     if not docstring:
         return {}
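The bespoke parse_type()/parse_object() machinery removed above is replaced by the json_schema() helper in the new inspect_ai/util/_json.py added in this release. A hypothetical usage sketch based on how the diff calls it (the printed shapes are indicative only, not exact reprs):

from inspect_ai.util._json import json_schema

print(json_schema(int))               # integer schema
print(json_schema(list[str]))         # array schema with string items
print(json_schema(dict[str, float]))  # object schema with number values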