inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (103)
  1. inspect_ai/_cli/eval.py +14 -3
  2. inspect_ai/_cli/sandbox.py +3 -3
  3. inspect_ai/_cli/score.py +6 -4
  4. inspect_ai/_cli/trace.py +53 -6
  5. inspect_ai/_display/core/config.py +1 -1
  6. inspect_ai/_display/core/display.py +2 -1
  7. inspect_ai/_display/core/footer.py +6 -6
  8. inspect_ai/_display/plain/display.py +11 -6
  9. inspect_ai/_display/rich/display.py +23 -13
  10. inspect_ai/_display/textual/app.py +10 -9
  11. inspect_ai/_display/textual/display.py +2 -2
  12. inspect_ai/_display/textual/widgets/footer.py +4 -0
  13. inspect_ai/_display/textual/widgets/samples.py +14 -5
  14. inspect_ai/_eval/context.py +1 -2
  15. inspect_ai/_eval/eval.py +54 -41
  16. inspect_ai/_eval/loader.py +9 -2
  17. inspect_ai/_eval/run.py +148 -81
  18. inspect_ai/_eval/score.py +13 -8
  19. inspect_ai/_eval/task/images.py +31 -21
  20. inspect_ai/_eval/task/run.py +62 -59
  21. inspect_ai/_eval/task/rundir.py +16 -9
  22. inspect_ai/_eval/task/sandbox.py +7 -8
  23. inspect_ai/_eval/task/util.py +7 -0
  24. inspect_ai/_util/_async.py +118 -10
  25. inspect_ai/_util/constants.py +0 -2
  26. inspect_ai/_util/file.py +15 -29
  27. inspect_ai/_util/future.py +37 -0
  28. inspect_ai/_util/http.py +3 -99
  29. inspect_ai/_util/httpx.py +60 -0
  30. inspect_ai/_util/interrupt.py +2 -2
  31. inspect_ai/_util/json.py +5 -52
  32. inspect_ai/_util/logger.py +30 -86
  33. inspect_ai/_util/retry.py +10 -61
  34. inspect_ai/_util/trace.py +2 -2
  35. inspect_ai/_view/server.py +86 -3
  36. inspect_ai/_view/www/dist/assets/index.js +25837 -13269
  37. inspect_ai/_view/www/log-schema.json +253 -186
  38. inspect_ai/_view/www/package.json +2 -2
  39. inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
  40. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
  41. inspect_ai/_view/www/src/types/log.d.ts +122 -94
  42. inspect_ai/approval/_human/manager.py +6 -10
  43. inspect_ai/approval/_human/panel.py +2 -2
  44. inspect_ai/dataset/_sources/util.py +7 -6
  45. inspect_ai/log/__init__.py +4 -0
  46. inspect_ai/log/_file.py +35 -61
  47. inspect_ai/log/_log.py +18 -1
  48. inspect_ai/log/_recorders/eval.py +14 -23
  49. inspect_ai/log/_recorders/json.py +3 -18
  50. inspect_ai/log/_samples.py +27 -2
  51. inspect_ai/log/_transcript.py +8 -8
  52. inspect_ai/model/__init__.py +2 -1
  53. inspect_ai/model/_call_tools.py +60 -40
  54. inspect_ai/model/_chat_message.py +3 -2
  55. inspect_ai/model/_generate_config.py +25 -0
  56. inspect_ai/model/_model.py +74 -36
  57. inspect_ai/model/_openai.py +9 -1
  58. inspect_ai/model/_providers/anthropic.py +24 -26
  59. inspect_ai/model/_providers/azureai.py +11 -9
  60. inspect_ai/model/_providers/bedrock.py +33 -24
  61. inspect_ai/model/_providers/cloudflare.py +8 -9
  62. inspect_ai/model/_providers/goodfire.py +7 -3
  63. inspect_ai/model/_providers/google.py +47 -13
  64. inspect_ai/model/_providers/groq.py +15 -15
  65. inspect_ai/model/_providers/hf.py +24 -17
  66. inspect_ai/model/_providers/mistral.py +36 -20
  67. inspect_ai/model/_providers/openai.py +30 -25
  68. inspect_ai/model/_providers/openai_o1.py +1 -1
  69. inspect_ai/model/_providers/providers.py +1 -1
  70. inspect_ai/model/_providers/together.py +3 -4
  71. inspect_ai/model/_providers/util/__init__.py +2 -2
  72. inspect_ai/model/_providers/util/chatapi.py +6 -19
  73. inspect_ai/model/_providers/util/hooks.py +165 -0
  74. inspect_ai/model/_providers/vertex.py +20 -3
  75. inspect_ai/model/_providers/vllm.py +16 -19
  76. inspect_ai/scorer/_multi.py +5 -2
  77. inspect_ai/solver/_bridge/patch.py +31 -1
  78. inspect_ai/solver/_fork.py +5 -3
  79. inspect_ai/solver/_human_agent/agent.py +3 -2
  80. inspect_ai/tool/__init__.py +8 -2
  81. inspect_ai/tool/_tool_info.py +4 -90
  82. inspect_ai/tool/_tool_params.py +4 -34
  83. inspect_ai/tool/_tools/_web_search.py +30 -24
  84. inspect_ai/util/__init__.py +4 -0
  85. inspect_ai/util/_concurrency.py +5 -6
  86. inspect_ai/util/_display.py +6 -0
  87. inspect_ai/util/_json.py +170 -0
  88. inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
  89. inspect_ai/util/_sandbox/docker/docker.py +5 -0
  90. inspect_ai/util/_sandbox/environment.py +56 -9
  91. inspect_ai/util/_sandbox/service.py +12 -5
  92. inspect_ai/util/_subprocess.py +94 -113
  93. inspect_ai/util/_subtask.py +2 -4
  94. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
  95. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
  96. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
  97. inspect_ai/_util/timeouts.py +0 -160
  98. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  99. inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
  100. inspect_ai/model/_providers/util/tracker.py +0 -92
  101. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
  102. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
  103. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0

inspect_ai/model/_providers/util/__init__.py

@@ -5,7 +5,7 @@ from .chatapi import (
     ChatAPIMessage,
     chat_api_input,
     chat_api_request,
-    is_chat_api_rate_limit,
+    should_retry_chat_api_error,
 )
 from .hf_handler import HFHandler
 from .llama31 import Llama31Handler
@@ -19,7 +19,7 @@ __all__ = [
     "as_stop_reason",
     "chat_api_request",
     "chat_api_input",
-    "is_chat_api_rate_limit",
+    "should_retry_chat_api_error",
     "model_base_url",
     "parse_tool_call",
     "tool_parse_error_message",

inspect_ai/model/_providers/util/chatapi.py

@@ -7,17 +7,15 @@ from tenacity import (
     retry,
     retry_if_exception,
     stop_after_attempt,
-    stop_after_delay,
     wait_exponential_jitter,
 )
 
-from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
-from inspect_ai._util.retry import httpx_should_retry, log_retry_attempt
+from inspect_ai._util.http import is_retryable_http_status
+from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
 from inspect_ai.model._chat_message import ChatMessageAssistant, ChatMessageTool
 from inspect_ai.tool._tool_info import ToolInfo
 
 from ..._chat_message import ChatMessage
-from ..._generate_config import GenerateConfig
 
 logger = getLogger(__name__)
 
@@ -75,21 +73,13 @@ async def chat_api_request(
     url: str,
     headers: dict[str, Any],
     json: Any,
-    config: GenerateConfig,
 ) -> Any:
-    # provide default max_retries
-    max_retries = config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
-
     # define call w/ retry policy
     @retry(
         wait=wait_exponential_jitter(),
-        stop=(
-            (stop_after_attempt(max_retries) | stop_after_delay(config.timeout))
-            if config.timeout
-            else stop_after_attempt(max_retries)
-        ),
+        stop=(stop_after_attempt(2)),
         retry=retry_if_exception(httpx_should_retry),
-        before_sleep=log_retry_attempt(model_name),
+        before_sleep=log_httpx_retry_attempt(model_name),
     )
     async def call_api() -> Any:
         response = await client.post(url=url, headers=headers, json=json)
@@ -104,14 +94,11 @@
 # checking for rate limit errors needs to punch through the RetryError and
 # look at its `__cause__`. we've observed Cloudflare giving transient 500
 # status as well as a ReadTimeout, so we count these as rate limit errors
-def is_chat_api_rate_limit(ex: BaseException) -> bool:
+def should_retry_chat_api_error(ex: BaseException) -> bool:
     return isinstance(ex, RetryError) and (
         (
             isinstance(ex.__cause__, httpx.HTTPStatusError)
-            and (
-                ex.__cause__.response.status_code == 429
-                or ex.__cause__.response.status_code == 500
-            )
+            and is_retryable_http_status(ex.__cause__.response.status_code)
         )
         or isinstance(ex.__cause__, httpx.ReadTimeout)
     )
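
Taken together, these changes move retry policy out of chat_api_request (which now makes at most 2 attempts itself) and defer the real retry decision to Inspect's outer retry loop via the should_retry override introduced in this release (visible in the vertex.py hunk below). A minimal sketch of how a provider built on chat_api_request might delegate that decision; the provider class here is hypothetical, the helper names are from this diff:

    from typing_extensions import override

    from inspect_ai.model import ModelAPI
    from inspect_ai.model._providers.util import should_retry_chat_api_error


    class ExampleChatAPI(ModelAPI):  # hypothetical provider
        @override
        def should_retry(self, ex: Exception) -> bool:
            # chat_api_request wraps transient failures in tenacity's
            # RetryError; this helper unwraps __cause__ and checks for a
            # retryable HTTP status or a ReadTimeout
            return should_retry_chat_api_error(ex)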

inspect_ai/model/_providers/util/hooks.py

@@ -0,0 +1,165 @@
+import re
+import time
+from logging import getLogger
+from typing import Any, Mapping, NamedTuple, cast
+
+import httpx
+from shortuuid import uuid
+
+from inspect_ai._util.constants import HTTP
+from inspect_ai._util.retry import report_http_retry
+
+logger = getLogger(__name__)
+
+
+class RequestInfo(NamedTuple):
+    attempts: int
+    last_request: float
+
+
+class HttpHooks:
+    """Class which hooks various HTTP clients for improved tracking/logging.
+
+    A special header is injected into requests which is then read from
+    a request event hook -- this creates a record of when the request
+    started. Note that with retries a single request_id could be started
+    several times; our request hook makes sure we always track the time of
+    the last request.
+
+    There is an 'end_request()' method which gets the total request time
+    for a request_id and then purges the request_id from our tracking (so
+    the dict doesn't grow unbounded)
+
+    Additionally, an http response hook is installed and used for logging
+    requests for the 'http' log-level
+    """
+
+    REQUEST_ID_HEADER = "x-irid"
+
+    def __init__(self) -> None:
+        # track request start times
+        self._requests: dict[str, RequestInfo] = {}
+
+    def start_request(self) -> str:
+        request_id = uuid()
+        self._requests[request_id] = RequestInfo(0, time.monotonic())
+        return request_id
+
+    def end_request(self, request_id: str) -> float:
+        # read the request info (if available) and purge from dict
+        request_info = self._requests.pop(request_id, None)
+        if request_info is None:
+            raise RuntimeError(f"request_id not registered: {request_id}")
+
+        # return elapsed time
+        return time.monotonic() - request_info.last_request
+
+    def update_request_time(self, request_id: str) -> None:
+        request_info = self._requests.get(request_id, None)
+        if not request_info:
+            raise RuntimeError(f"No request registered for request_id: {request_id}")
+
+        # update the attempts and last request time
+        request_info = RequestInfo(request_info.attempts + 1, time.monotonic())
+        self._requests[request_id] = request_info
+
+        # trace a retry if this is attempt > 1
+        if request_info.attempts > 1:
+            report_http_retry()
+
+
+class ConverseHooks(HttpHooks):
+    def __init__(self, session: Any) -> None:
+        from aiobotocore.session import AioSession
+
+        super().__init__()
+
+        # register hooks
+        session = cast(AioSession, session._session)
+        session.register(
+            "before-send.bedrock-runtime.Converse", self.converse_before_send
+        )
+        session.register(
+            "after-call.bedrock-runtime.Converse", self.converse_after_call
+        )
+
+    def converse_before_send(self, **kwargs: Any) -> None:
+        user_agent = kwargs["request"].headers["User-Agent"].decode()
+        match = re.search(rf"{self.USER_AGENT_PREFIX}(\w+)", user_agent)
+        if match:
+            request_id = match.group(1)
+            self.update_request_time(request_id)
+
+    def converse_after_call(self, http_response: Any, **kwargs: Any) -> None:
+        from botocore.awsrequest import AWSResponse
+
+        response = cast(AWSResponse, http_response)
+        logger.log(HTTP, f"POST {response.url} - {response.status_code}")
+
+    def user_agent_extra(self, request_id: str) -> str:
+        return f"{self.USER_AGENT_PREFIX}{request_id}"
+
+    USER_AGENT_PREFIX = "ins/rid#"
+
+
+class HttpxHooks(HttpHooks):
+    def __init__(self, client: httpx.AsyncClient):
+        super().__init__()
+
+        # install hooks
+        client.event_hooks["request"].append(self.request_hook)
+        client.event_hooks["response"].append(self.response_hook)
+
+    async def request_hook(self, request: httpx.Request) -> None:
+        # update the last request time for this request id (as there could be retries)
+        request_id = request.headers.get(self.REQUEST_ID_HEADER, None)
+        if request_id:
+            self.update_request_time(request_id)
+
+    async def response_hook(self, response: httpx.Response) -> None:
+        message = f'{response.request.method} {response.request.url} "{response.http_version} {response.status_code} {response.reason_phrase}" '
+        logger.log(HTTP, message)
+
+
+def urllib3_hooks() -> HttpHooks:
+    import urllib3
+    from urllib3.connectionpool import HTTPConnectionPool
+    from urllib3.response import BaseHTTPResponse
+
+    class Urllib3Hooks(HttpHooks):
+        def request_hook(self, headers: Mapping[str, str]) -> None:
+            # update the last request time for this request id (as there could be retries)
+            request_id = headers.get(self.REQUEST_ID_HEADER, None)
+            if request_id:
+                self.update_request_time(request_id)
+
+        def response_hook(
+            self, method: str, url: str, response: BaseHTTPResponse
+        ) -> None:
+            message = f'{method} {url} "{response.version_string} {response.status} {response.reason}" '
+            logger.log(HTTP, message)
+
+    global _urlilb3_hooks
+    if _urlilb3_hooks is None:
+        # one time patch of urlopen
+        urlilb3_hooks = Urllib3Hooks()
+        original_urlopen = urllib3.connectionpool.HTTPConnectionPool.urlopen
+
+        def patched_urlopen(
+            self: HTTPConnectionPool, method: str, url: str, **kwargs: Any
+        ) -> BaseHTTPResponse:
+            headers = kwargs.get("headers", {})
+            urlilb3_hooks.request_hook(headers)
+            response = original_urlopen(self, method, url, **kwargs)
+            urlilb3_hooks.response_hook(method, f"{self.host}{url}", response)
+            return response
+
+        urllib3.connectionpool.HTTPConnectionPool.urlopen = patched_urlopen  # type: ignore[assignment,method-assign]
+
+        # assign to global hooks instance
+        _urlilb3_hooks = urlilb3_hooks
+
+    return _urlilb3_hooks
+
+
+_urlilb3_hooks: HttpHooks | None = None
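
The HttpxHooks variant is the piece most providers interact with: it instruments an httpx.AsyncClient so request timing survives retries. A small usage sketch based only on the class above (the URL and payload are placeholders, not from the package):

    import anyio
    import httpx

    from inspect_ai.model._providers.util.hooks import HttpxHooks


    async def main() -> None:
        client = httpx.AsyncClient()
        hooks = HttpxHooks(client)  # installs request/response event hooks

        # register the request and tag it with the x-irid header so the
        # request hook can re-stamp its start time on each retry attempt
        request_id = hooks.start_request()
        await client.post(
            "https://example.com/v1/chat",  # placeholder URL
            headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
            json={},
        )

        # elapsed time of the last attempt; also purges the id from tracking
        print(f"request took {hooks.end_request(request_id):.2f}s")


    anyio.run(main)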

inspect_ai/model/_providers/vertex.py

@@ -4,7 +4,13 @@ from copy import copy
 from typing import Any, cast
 
 import vertexai  # type: ignore
-from google.api_core.exceptions import TooManyRequests
+from google.api_core.exceptions import (
+    Aborted,
+    ClientError,
+    DeadlineExceeded,
+    ServiceUnavailable,
+)
+from google.api_core.retry import if_transient_error
 from google.protobuf.json_format import MessageToDict
 from pydantic import JsonValue
 from typing_extensions import override
@@ -31,6 +37,7 @@ from inspect_ai._util.content import (
     ContentText,
     ContentVideo,
 )
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo
 
@@ -169,8 +176,18 @@ class VertexAPI(ModelAPI):
         return output, call
 
     @override
-    def is_rate_limit(self, ex: BaseException) -> bool:
-        return isinstance(ex, TooManyRequests)
+    def should_retry(self, ex: Exception) -> bool:
+        # google API-specific errors
+        if isinstance(ex, Aborted | DeadlineExceeded | ServiceUnavailable):
+            return True
+        # standard HTTP errors
+        elif isinstance(ex, ClientError) and ex.code is not None:
+            return is_retryable_http_status(ex.code)
+        # additional errors flagged by google as transient
+        elif isinstance(ex, Exception):
+            return if_transient_error(ex)
+        else:
+            return False
 
     @override
     def connection_key(self) -> str:
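
Note that is_retryable_http_status itself lives in inspect_ai/_util/http.py (+3 -99), whose body is not shown in this diff. Given that the chatapi.py hunk above replaced explicit 429/500 checks with it, a plausible reading is something like the following (an assumption, not the actual source):

    def is_retryable_http_status(status_code: int) -> bool:
        # hypothetical sketch of the helper in inspect_ai/_util/http.py --
        # the real implementation is not part of this diff
        return status_code in (429, 500, 502, 503, 504)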

inspect_ai/model/_providers/vllm.py

@@ -1,13 +1,15 @@
-import asyncio
+import concurrent.futures
 import functools
 import gc
 import os
 import time
+from concurrent.futures import Future
 from dataclasses import dataclass
 from queue import Empty, Queue
 from threading import Thread
 from typing import Any, cast
 
+import anyio
 from typing_extensions import override
 from vllm import LLM, CompletionOutput, RequestOutput, SamplingParams  # type: ignore
 
@@ -280,8 +282,7 @@
 @dataclass
 class _QueueItem:
     input: GenerateInput
-    future: asyncio.Future[list[GenerateOutput]]
-    loop: asyncio.AbstractEventLoop
+    future: Future[list[GenerateOutput]]
 
 
 batch_thread: Thread | None = None
@@ -297,15 +298,16 @@ async def batched_generate(input: GenerateInput) -> list[GenerateOutput]:
         batch_thread.start()
 
     # enqueue the job
-    loop = asyncio.get_event_loop()
-    future: asyncio.Future[list[GenerateOutput]] = loop.create_future()
-    batch_queue.put(_QueueItem(input=input, future=future, loop=loop))
+    future = Future[list[GenerateOutput]]()
+    batch_queue.put(_QueueItem(input=input, future=future))
 
-    # await the job
-    await future
-
-    # return it
-    return future.result()
+    # await the future
+    while True:
+        try:
+            return future.result(timeout=0.01)
+        except concurrent.futures.TimeoutError:
+            pass
+        await anyio.sleep(1)
 
 
 def string_to_bytes(string: str) -> list[int]:
@@ -397,13 +399,12 @@ def post_process_outputs(
 def process_batches() -> None:
     while True:
         # drain the queue (wait until no new messages have shown up for 2 seconds)
-        inputs: list[tuple[GenerateInput, asyncio.Future[list[GenerateOutput]]]] = []
+        inputs: list[tuple[GenerateInput, Future[list[GenerateOutput]]]] = []
         while True:
             try:
                 input = batch_queue.get(
                     timeout=2
                 )  # wait 2 seconds max TODO: what's optimal wait time?
-                loop = input.loop
                 inputs.append((input.input, input.future))
                 if len(inputs) >= input.input.batch_size:
                     # max batch size reached
@@ -429,14 +430,10 @@ def process_batches() -> None:
             for i, output in enumerate(outputs):
                 future = inputs[i][1]
 
-                # asyncio futures are not thread safe, so we need to pass the event loop
-                # down to this point, so we can mark the future as done in a thread safe manner.
-                # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
-                loop.call_soon_threadsafe(
-                    future.set_result,
+                future.set_result(
                     post_process_outputs(output, num_top_logprobs, total_time),
                 )
 
         except Exception as e:
             for _, future in inputs:
-                loop.call_soon_threadsafe(future.set_exception, e)
+                future.set_exception(e)
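
These vllm.py hunks replace asyncio futures (which require loop.call_soon_threadsafe plumbing from the batching thread) with concurrent.futures.Future, which may be completed from any thread. A self-contained illustration of that handoff pattern (names here are illustrative, not from the package):

    import concurrent.futures
    import threading

    import anyio


    async def wait_for(future: concurrent.futures.Future[int]) -> int:
        # concurrent.futures.Future is thread safe, so the worker thread can
        # complete it directly; poll briefly, then yield to the scheduler
        while True:
            try:
                return future.result(timeout=0.01)
            except concurrent.futures.TimeoutError:
                pass
            await anyio.sleep(0.1)


    def worker(future: concurrent.futures.Future[int]) -> None:
        future.set_result(42)  # safe to call from any thread


    async def main() -> None:
        future: concurrent.futures.Future[int] = concurrent.futures.Future()
        threading.Thread(target=worker, args=(future,)).start()
        print(await wait_for(future))


    anyio.run(main)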

inspect_ai/scorer/_multi.py

@@ -1,5 +1,6 @@
-import asyncio
+import functools
 
+from inspect_ai._util._async import tg_collect
 from inspect_ai.scorer._reducer.registry import create_reducers
 from inspect_ai.solver._task_state import TaskState
 
@@ -19,7 +20,9 @@ def multi_scorer(scorers: list[Scorer], reducer: str | ScoreReducer) -> Scorer:
     reducer = create_reducers(reducer)[0]
 
     async def score(state: TaskState, target: Target) -> Score:
-        scores = await asyncio.gather(*[_scorer(state, target) for _scorer in scorers])
+        scores = await tg_collect(
+            [functools.partial(_scorer, state, target) for _scorer in scorers]
+        )
         return reducer(scores)
 
     return score
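
tg_collect comes from inspect_ai/_util/_async.py (+118 -10, not included in the hunks shown here). Its call sites imply it takes a list of zero-argument async callables and returns their results in order; a minimal anyio-based sketch under that assumption:

    from typing import Awaitable, Callable, TypeVar, cast

    import anyio

    T = TypeVar("T")


    async def tg_collect(funcs: list[Callable[[], Awaitable[T]]]) -> list[T]:
        # structured-concurrency gather: run every callable in one task
        # group (a failure cancels the siblings) and keep results in order
        results: list[T | None] = [None] * len(funcs)

        async def run(index: int) -> None:
            results[index] = await funcs[index]()

        async with anyio.create_task_group() as tg:
            for i in range(len(funcs)):
                tg.start_soon(run, i)

        return cast("list[T]", results)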

inspect_ai/solver/_bridge/patch.py

@@ -11,11 +11,12 @@ from openai._types import ResponseT
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessageParam,
+    ChatCompletionToolChoiceOptionParam,
     ChatCompletionToolParam,
 )
 from shortuuid import uuid
 
-from inspect_ai.model._generate_config import GenerateConfig
+from inspect_ai.model._generate_config import GenerateConfig, ResponseSchema
 from inspect_ai.model._model import get_model
 from inspect_ai.model._openai import (
     chat_messages_from_openai,
@@ -23,8 +24,10 @@ from inspect_ai.model._openai import (
     openai_completion_usage,
 )
 from inspect_ai.solver._task_state import sample_state
+from inspect_ai.tool._tool_choice import ToolChoice, ToolFunction
 from inspect_ai.tool._tool_info import ToolInfo
 from inspect_ai.tool._tool_params import ToolParams
+from inspect_ai.util._json import JSONSchema
 
 
 @contextlib.asynccontextmanager
@@ -113,6 +116,20 @@
             )
         )
 
+    # convert openai tool choice to inspect tool_choice
+    inspect_tool_choice: ToolChoice | None = None
+    tool_choice: ChatCompletionToolChoiceOptionParam | None = json_data.get(
+        "tool_choice", None
+    )
+    if tool_choice is not None:
+        match tool_choice:
+            case "auto" | "none":
+                inspect_tool_choice = tool_choice
+            case "required":
+                inspect_tool_choice = "any"
+            case _:
+                inspect_tool_choice = ToolFunction(name=tool_choice["function"]["name"])
+
     # resolve model
     if model_name == "inspect":
         model = get_model()
@@ -122,6 +139,7 @@
     output = await model.generate(
         input=input,
         tools=inspect_tools,
+        tool_choice=inspect_tool_choice,
         config=generate_config_from_openai(options),
     )
 
@@ -165,4 +183,16 @@ def generate_config_from_openai(options: FinalRequestOptions) -> GenerateConfig:
     config.parallel_tool_calls = json_data.get("parallel_tool_calls", None)
     config.reasoning_effort = json_data.get("reasoning_effort", None)
 
+    # response format
+    response_format: dict[str, Any] | None = json_data.get("response_format", None)
+    if response_format is not None:
+        json_schema: dict[str, Any] | None = response_format.get("json_schema", None)
+        if json_schema is not None:
+            config.response_schema = ResponseSchema(
+                name=json_schema.get("name", "schema"),
+                description=json_schema.get("description", None),
+                json_schema=JSONSchema.model_validate(json_schema.get("schema", {})),
+                strict=json_schema.get("strict", None),
+            )
+
     return config
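
For reference, the request fragment that this response_format handling consumes follows OpenAI's standard structured-output shape (standard API knowledge, not taken from this diff):

    # hypothetical request fragment passed through the bridge
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "answer",
            "description": "Final answer with rationale",
            "schema": {
                "type": "object",
                "properties": {"answer": {"type": "string"}},
                "required": ["answer"],
            },
            "strict": True,
        },
    }
    # generate_config_from_openai() lifts the "json_schema" entry into a
    # ResponseSchema, validating its "schema" dict into a JSONSchema model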

inspect_ai/solver/_fork.py

@@ -1,10 +1,11 @@
-import asyncio
+import functools
 from contextvars import ContextVar
 from copy import deepcopy
 from typing import Any, cast
 
 from typing_extensions import overload
 
+from inspect_ai._util._async import tg_collect
 from inspect_ai._util.registry import registry_log_name, registry_params
 from inspect_ai.util._subtask import subtask
 
@@ -44,8 +45,9 @@ async def fork(
     if isinstance(solvers, Solver):
         return await solver_subtask(state, solvers)
     else:
-        subtasks = [solver_subtask(state, solver) for solver in solvers]
-        return await asyncio.gather(*subtasks)
+        return await tg_collect(
+            [functools.partial(solver_subtask, state, solver) for solver in solvers]
+        )
 
 
 async def solver_subtask(state: TaskState, solver: Solver) -> TaskState:

inspect_ai/solver/_human_agent/agent.py

@@ -1,6 +1,7 @@
-import asyncio
 from typing import cast
 
+import anyio
+
 from inspect_ai.util import display_type, input_panel, sandbox
 from inspect_ai.util._sandbox.events import SandboxEnvironmentProxy
 
@@ -42,7 +43,7 @@
         Solver: Human agent solver.
     """
     # we can only run one human agent interaction at a time (use lock to enforce)
-    agent_lock = asyncio.Lock()
+    agent_lock = anyio.Lock()
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
         async with agent_lock:

inspect_ai/tool/__init__.py

@@ -20,7 +20,7 @@ from ._tool_call import (
 from ._tool_choice import ToolChoice, ToolFunction
 from ._tool_def import ToolDef
 from ._tool_info import ToolInfo
-from ._tool_params import JSONType, ToolParam, ToolParams
+from ._tool_params import ToolParam, ToolParams
 from ._tool_with import tool_with
 from ._tools._computer import computer
 from ._tools._execute import bash, python
@@ -56,12 +56,18 @@ __all__ = [
     "ToolInfo",
     "ToolParam",
     "ToolParams",
-    "JSONType",
 ]
 
 _UTIL_MODULE_VERSION = "0.3.19"
+_JSON_MODULE_VERSION = "0.3.73"
 _REMOVED_IN = "0.4"
 
+relocated_module_attribute(
+    "JSONType",
+    "inspect_ai.util.JSONType",
+    _JSON_MODULE_VERSION,
+    _REMOVED_IN,
+)
 
 relocated_module_attribute(
     "ToolEnvironment",

inspect_ai/tool/_tool_info.py

@@ -1,27 +1,19 @@
 import inspect
-import types
-import typing
-from dataclasses import is_dataclass
 from typing import (
     Any,
     Callable,
     Dict,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    Union,
     get_args,
-    get_origin,
     get_type_hints,
-    is_typeddict,
 )
 
 from docstring_parser import Docstring, parse
 from pydantic import BaseModel, Field
 
+from inspect_ai.util._json import JSONType, json_schema
+
 from ._tool_description import tool_description
-from ._tool_params import JSONType, ToolParam, ToolParams
+from ._tool_params import ToolParam, ToolParams
 
 
 class ToolInfo(BaseModel):
@@ -88,7 +80,7 @@ def parse_tool_info(func: Callable[..., Any]) -> ToolInfo:
 
         # Get type information from type annotations
         if param_name in type_hints:
-            tool_param = parse_type(type_hints[param_name])
+            tool_param = json_schema(type_hints[param_name])
         # as a fallback try to parse it from the docstring
         # (this is minimally necessary for backwards compatiblity
         # with tools gen1 type parsing, which only used docstrings)
@@ -129,84 +121,6 @@ def parse_tool_info(func: Callable[..., Any]) -> ToolInfo:
     return info
 
 
-def parse_type(type_hint: Type[Any]) -> ToolParam:
-    origin = get_origin(type_hint)
-    args = get_args(type_hint)
-
-    if origin is None:
-        if type_hint is int:
-            return ToolParam(type="integer")
-        elif type_hint is float:
-            return ToolParam(type="number")
-        elif type_hint is str:
-            return ToolParam(type="string")
-        elif type_hint is bool:
-            return ToolParam(type="boolean")
-        elif type_hint is list:
-            return ToolParam(type="array", items=ToolParam())
-        elif type_hint is dict:
-            return ToolParam(type="object", additionalProperties=ToolParam())
-        elif (
-            is_dataclass(type_hint)
-            or is_typeddict(type_hint)
-            or (isinstance(type_hint, type) and issubclass(type_hint, BaseModel))
-        ):
-            return parse_object(type_hint)
-        elif type_hint is type(None):
-            return ToolParam(type="null")
-        else:
-            return ToolParam()
-    elif origin is list or origin is List or origin is tuple or origin is Tuple:
-        return ToolParam(
-            type="array", items=parse_type(args[0]) if args else ToolParam()
-        )
-    elif origin is dict or origin is Dict:
-        return ToolParam(
-            type="object",
-            additionalProperties=parse_type(args[1]) if len(args) > 1 else ToolParam(),
-        )
-    elif origin is Union or origin is types.UnionType:
-        return ToolParam(anyOf=[parse_type(arg) for arg in args])
-    elif origin is Optional:
-        return ToolParam(
-            anyOf=[parse_type(arg) for arg in args] + [ToolParam(type="null")]
-        )
-    elif origin is typing.Literal:
-        return ToolParam(enum=list(args))
-
-    return ToolParam()  # Default case if we can't determine the type
-
-
-def parse_object(cls: Type[Any]) -> ToolParam:
-    properties: Dict[str, ToolParam] = {}
-    required: List[str] = []
-
-    if is_dataclass(cls):
-        fields = cls.__dataclass_fields__  # type: ignore
-        for name, field in fields.items():
-            properties[name] = parse_type(field.type)  # type: ignore
-            if field.default == field.default_factory:
-                required.append(name)
-    elif isinstance(cls, type) and issubclass(cls, BaseModel):
-        schema = cls.model_json_schema()
-        for name, prop in schema.get("properties", {}).items():
-            properties[name] = ToolParam(**prop)
-        required = schema.get("required", [])
-    elif is_typeddict(cls):
-        annotations = get_type_hints(cls)
-        for name, type_hint in annotations.items():
-            properties[name] = parse_type(type_hint)
-            if name in cls.__required_keys__:
-                required.append(name)
-
-    return ToolParam(
-        type="object",
-        properties=properties,
-        required=required if required else None,
-        additionalProperties=False,
-    )
-
-
 def parse_docstring(docstring: str | None, param_name: str) -> Dict[str, str]:
     if not docstring:
         return {}
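
The removed parse_type/parse_object machinery is superseded by json_schema() from the new inspect_ai/util/_json.py (+170 lines, not shown in these hunks). Assuming its output mirrors the old ToolParam shape (type, items, properties, anyOf, ...) as the drop-in call site above suggests, the mapping would look like:

    # hypothetical illustration -- field names assumed from the old ToolParam
    json_schema(list[str])   # ~ JSONSchema(type="array", items=JSONSchema(type="string"))
    json_schema(int | None)  # ~ JSONSchema(anyOf=[integer schema, null schema])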