inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. inspect_ai/_cli/eval.py +14 -3
  2. inspect_ai/_cli/sandbox.py +3 -3
  3. inspect_ai/_cli/score.py +6 -4
  4. inspect_ai/_cli/trace.py +53 -6
  5. inspect_ai/_display/core/config.py +1 -1
  6. inspect_ai/_display/core/display.py +2 -1
  7. inspect_ai/_display/core/footer.py +6 -6
  8. inspect_ai/_display/plain/display.py +11 -6
  9. inspect_ai/_display/rich/display.py +23 -13
  10. inspect_ai/_display/textual/app.py +10 -9
  11. inspect_ai/_display/textual/display.py +2 -2
  12. inspect_ai/_display/textual/widgets/footer.py +4 -0
  13. inspect_ai/_display/textual/widgets/samples.py +14 -5
  14. inspect_ai/_eval/context.py +1 -2
  15. inspect_ai/_eval/eval.py +54 -41
  16. inspect_ai/_eval/loader.py +9 -2
  17. inspect_ai/_eval/run.py +148 -81
  18. inspect_ai/_eval/score.py +13 -8
  19. inspect_ai/_eval/task/images.py +31 -21
  20. inspect_ai/_eval/task/run.py +62 -59
  21. inspect_ai/_eval/task/rundir.py +16 -9
  22. inspect_ai/_eval/task/sandbox.py +7 -8
  23. inspect_ai/_eval/task/util.py +7 -0
  24. inspect_ai/_util/_async.py +118 -10
  25. inspect_ai/_util/constants.py +0 -2
  26. inspect_ai/_util/file.py +15 -29
  27. inspect_ai/_util/future.py +37 -0
  28. inspect_ai/_util/http.py +3 -99
  29. inspect_ai/_util/httpx.py +60 -0
  30. inspect_ai/_util/interrupt.py +2 -2
  31. inspect_ai/_util/json.py +5 -52
  32. inspect_ai/_util/logger.py +30 -86
  33. inspect_ai/_util/retry.py +10 -61
  34. inspect_ai/_util/trace.py +2 -2
  35. inspect_ai/_view/server.py +86 -3
  36. inspect_ai/_view/www/dist/assets/index.js +25837 -13269
  37. inspect_ai/_view/www/log-schema.json +253 -186
  38. inspect_ai/_view/www/package.json +2 -2
  39. inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
  40. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
  41. inspect_ai/_view/www/src/types/log.d.ts +122 -94
  42. inspect_ai/approval/_human/manager.py +6 -10
  43. inspect_ai/approval/_human/panel.py +2 -2
  44. inspect_ai/dataset/_sources/util.py +7 -6
  45. inspect_ai/log/__init__.py +4 -0
  46. inspect_ai/log/_file.py +35 -61
  47. inspect_ai/log/_log.py +18 -1
  48. inspect_ai/log/_recorders/eval.py +14 -23
  49. inspect_ai/log/_recorders/json.py +3 -18
  50. inspect_ai/log/_samples.py +27 -2
  51. inspect_ai/log/_transcript.py +8 -8
  52. inspect_ai/model/__init__.py +2 -1
  53. inspect_ai/model/_call_tools.py +60 -40
  54. inspect_ai/model/_chat_message.py +3 -2
  55. inspect_ai/model/_generate_config.py +25 -0
  56. inspect_ai/model/_model.py +74 -36
  57. inspect_ai/model/_openai.py +9 -1
  58. inspect_ai/model/_providers/anthropic.py +24 -26
  59. inspect_ai/model/_providers/azureai.py +11 -9
  60. inspect_ai/model/_providers/bedrock.py +33 -24
  61. inspect_ai/model/_providers/cloudflare.py +8 -9
  62. inspect_ai/model/_providers/goodfire.py +7 -3
  63. inspect_ai/model/_providers/google.py +47 -13
  64. inspect_ai/model/_providers/groq.py +15 -15
  65. inspect_ai/model/_providers/hf.py +24 -17
  66. inspect_ai/model/_providers/mistral.py +36 -20
  67. inspect_ai/model/_providers/openai.py +30 -25
  68. inspect_ai/model/_providers/openai_o1.py +1 -1
  69. inspect_ai/model/_providers/providers.py +1 -1
  70. inspect_ai/model/_providers/together.py +3 -4
  71. inspect_ai/model/_providers/util/__init__.py +2 -2
  72. inspect_ai/model/_providers/util/chatapi.py +6 -19
  73. inspect_ai/model/_providers/util/hooks.py +165 -0
  74. inspect_ai/model/_providers/vertex.py +20 -3
  75. inspect_ai/model/_providers/vllm.py +16 -19
  76. inspect_ai/scorer/_multi.py +5 -2
  77. inspect_ai/solver/_bridge/patch.py +31 -1
  78. inspect_ai/solver/_fork.py +5 -3
  79. inspect_ai/solver/_human_agent/agent.py +3 -2
  80. inspect_ai/tool/__init__.py +8 -2
  81. inspect_ai/tool/_tool_info.py +4 -90
  82. inspect_ai/tool/_tool_params.py +4 -34
  83. inspect_ai/tool/_tools/_web_search.py +30 -24
  84. inspect_ai/util/__init__.py +4 -0
  85. inspect_ai/util/_concurrency.py +5 -6
  86. inspect_ai/util/_display.py +6 -0
  87. inspect_ai/util/_json.py +170 -0
  88. inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
  89. inspect_ai/util/_sandbox/docker/docker.py +5 -0
  90. inspect_ai/util/_sandbox/environment.py +56 -9
  91. inspect_ai/util/_sandbox/service.py +12 -5
  92. inspect_ai/util/_subprocess.py +94 -113
  93. inspect_ai/util/_subtask.py +2 -4
  94. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
  95. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
  96. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
  97. inspect_ai/_util/timeouts.py +0 -160
  98. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  99. inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
  100. inspect_ai/model/_providers/util/tracker.py +0 -92
  101. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
  102. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
  103. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/log/_file.py CHANGED
@@ -1,16 +1,15 @@
 import os
 import re
 from logging import getLogger
-from typing import Any, Callable, Generator, Literal, cast
+from typing import Any, Callable, Generator, Literal

 from pydantic import BaseModel
 from pydantic_core import to_json

-from inspect_ai._util._async import run_coroutine
+from inspect_ai._util._async import current_async_backend, run_coroutine
 from inspect_ai._util.constants import ALL_LOG_FORMATS, EVAL_LOG_FORMAT
 from inspect_ai._util.file import (
     FileInfo,
-    async_fileystem,
     file,
     filesystem,
 )
@@ -96,62 +95,6 @@ def list_eval_logs(
     return eval_logs


-async def list_eval_logs_async(
-    log_dir: str = os.environ.get("INSPECT_LOG_DIR", "./logs"),
-    formats: list[Literal["eval", "json"]] | None = None,
-    recursive: bool = True,
-    descending: bool = True,
-    fs_options: dict[str, Any] = {},
-) -> list[EvalLogInfo]:
-    """List all eval logs in a directory.
-
-    Will be async for filesystem providers that support async (e.g. s3, gcs, etc.)
-    otherwise will fallback to sync implementation.
-
-    Args:
-      log_dir (str): Log directory (defaults to INSPECT_LOG_DIR)
-      formats (Literal["eval", "json"]): Formats to list (default
-        to listing all formats)
-      recursive (bool): List log files recursively (defaults to True).
-      descending (bool): List in descending order.
-      fs_options (dict[str, Any]): Optional. Additional arguments to pass through
-        to the filesystem provider (e.g. `S3FileSystem`).
-
-    Returns:
-      List of EvalLog Info.
-    """
-    # async filesystem if we can
-    fs = filesystem(log_dir, fs_options)
-    if fs.is_async():
-        async with async_fileystem(log_dir, fs_options=fs_options) as async_fs:
-            if await async_fs._exists(log_dir):
-                # prevent caching of listings
-                async_fs.invalidate_cache(log_dir)
-                # list logs
-                if recursive:
-                    files: list[dict[str, Any]] = []
-                    async for _, _, filenames in async_fs._walk(log_dir, detail=True):
-                        files.extend(filenames.values())
-                else:
-                    files = cast(
-                        list[dict[str, Any]],
-                        await async_fs._ls(log_dir, detail=True),
-                    )
-                logs = [fs._file_info(file) for file in files]
-                # resolve to eval logs
-                return log_files_from_ls(logs, formats, descending)
-            else:
-                return []
-    else:
-        return list_eval_logs(
-            log_dir=log_dir,
-            formats=formats,
-            recursive=recursive,
-            descending=descending,
-            fs_options=fs_options,
-        )
-
-
 def write_eval_log(
     log: EvalLog,
     location: str | FileInfo | None = None,
@@ -165,6 +108,14 @@ def write_eval_log(
         format (Literal["eval", "json", "auto"]): Write to format
           (defaults to 'auto' based on `log_file` extension)
     """
+    # don't mix trio and asyncio
+    if current_async_backend() == "trio":
+        raise RuntimeError(
+            "write_eval_log cannot be called from a trio async context (please use write_eval_log_async instead)"
+        )
+
+    # will use s3fs and is not called from main inspect solver/scorer/tool/sandbox
+    # flow, so force the use of asyncio
     run_coroutine(write_eval_log_async(log, location, format))


@@ -265,8 +216,21 @@ def read_eval_log(
     Returns:
       EvalLog object read from file.
     """
+    # don't mix trio and asyncio
+    if current_async_backend() == "trio":
+        raise RuntimeError(
+            "read_eval_log cannot be called from a trio async context (please use read_eval_log_async instead)"
+        )
+
+    # will use s3fs and is not called from main inspect solver/scorer/tool/sandbox
+    # flow, so force the use of asyncio
     return run_coroutine(
-        read_eval_log_async(log_file, header_only, resolve_attachments, format)
+        read_eval_log_async(
+            log_file,
+            header_only,
+            resolve_attachments,
+            format,
+        )
     )


@@ -281,7 +245,7 @@ async def read_eval_log_async(
     Args:
       log_file (str | FileInfo): Log file to read.
       header_only (bool): Read only the header (i.e. exclude
-          the "samples" and "logging" fields). Defaults to False.
+        the "samples" and "logging" fields). Defaults to False.
       resolve_attachments (bool): Resolve attachments (e.g. images)
         to their full content.
       format (Literal["eval", "json", "auto"]): Read from format
@@ -321,6 +285,8 @@
 def read_eval_log_headers(
     log_files: list[str] | list[EvalLogInfo],
 ) -> list[EvalLog]:
+    # will use s3fs and is not called from main inspect solver/scorer/tool/sandbox
+    # flow, so force the use of asyncio
     return run_coroutine(read_eval_log_headers_async(log_files))


@@ -356,6 +322,14 @@ def read_eval_log_sample(
     Raises:
       IndexError: If the passed id and epoch are not found.
     """
+    # don't mix trio and asyncio
+    if current_async_backend() == "trio":
+        raise RuntimeError(
+            "read_eval_log_sample cannot be called from a trio async context (please use read_eval_log_sample_async instead)"
+        )
+
+    # will use s3fs and is not called from main inspect solver/scorer/tool/sandbox
+    # flow, so force the use of asyncio
     return run_coroutine(
         read_eval_log_sample_async(log_file, id, epoch, resolve_attachments, format)
    )
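The net effect of the `_file.py` changes: the synchronous log APIs now refuse to run on the trio backend and direct callers to their `*_async` variants. A minimal sketch of the new behavior, assuming trio is installed (the log path is hypothetical):

```python
# Under trio the sync readers raise; under asyncio (or no event loop)
# they still work via run_coroutine().
import anyio

from inspect_ai.log import read_eval_log, read_eval_log_async

LOG = "./logs/example.eval"  # hypothetical log file


async def main() -> None:
    try:
        read_eval_log(LOG)  # raises RuntimeError on the trio backend
    except RuntimeError:
        log = await read_eval_log_async(LOG)  # use the async variant instead
        print(log.status)


anyio.run(main, backend="trio")
```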
inspect_ai/log/_log.py CHANGED
@@ -295,7 +295,7 @@ class EvalSample(BaseModel):
             # warning will handle this)
             del values["transcript"]

-        return values
+        return migrate_sandbox_spec(values)

     # allow field model_usage
     model_config = ConfigDict(protected_namespaces=())
@@ -607,6 +607,23 @@ class EvalSpec(BaseModel):
     # allow field model_args
     model_config = ConfigDict(protected_namespaces=())

+    @model_validator(mode="before")
+    @classmethod
+    def read_sandbox_spec(
+        cls: Type["EvalSpec"], values: dict[str, Any]
+    ) -> dict[str, Any]:
+        return migrate_sandbox_spec(values)
+
+
+def migrate_sandbox_spec(values: dict[str, Any]) -> dict[str, Any]:
+    if "sandbox" in values:
+        sandbox = values.get("sandbox")
+        if isinstance(sandbox, list):
+            values["sandbox"] = SandboxEnvironmentSpec(
+                type=sandbox[0], config=sandbox[1]
+            )
+    return values
+

 def eval_error(
     exception: BaseException,
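A sketch of the migration the new `migrate_sandbox_spec` validator performs: older logs serialized `sandbox` as a two-element `["type", "config"]` list, which is now upgraded to a `SandboxEnvironmentSpec` during validation (the compose file name below is illustrative):

```python
from inspect_ai.util import SandboxEnvironmentSpec

# legacy on-disk form of the sandbox field
values = {"sandbox": ["docker", "compose.yaml"]}

# what the before-validator does with it
sandbox = values.get("sandbox")
if isinstance(sandbox, list):
    values["sandbox"] = SandboxEnvironmentSpec(type=sandbox[0], config=sandbox[1])

print(values["sandbox"])  # SandboxEnvironmentSpec(type='docker', config='compose.yaml')
```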
inspect_ai/log/_recorders/eval.py CHANGED
@@ -1,13 +1,11 @@
-import asyncio
 import json
 import os
 import tempfile
-from contextlib import _AsyncGeneratorContextManager
 from logging import getLogger
 from typing import Any, BinaryIO, Literal, cast
 from zipfile import ZIP_DEFLATED, ZipFile

-from fsspec.asyn import AsyncFileSystem  # type: ignore
+import anyio
 from pydantic import BaseModel, Field
 from pydantic_core import to_json
 from typing_extensions import override
@@ -21,7 +19,7 @@ from inspect_ai._util.content import (
     ContentVideo,
 )
 from inspect_ai._util.error import EvalError
-from inspect_ai._util.file import FileSystem, async_fileystem, dirname, file, filesystem
+from inspect_ai._util.file import FileSystem, dirname, file, filesystem
 from inspect_ai._util.json import jsonable_python
 from inspect_ai._util.trace import trace_action
 from inspect_ai.model._chat_message import ChatMessage
@@ -277,16 +275,14 @@ def text_inputs(inputs: str | list[ChatMessage]) -> str | list[ChatMessage]:


 class ZipLogFile:
-    _zip: ZipFile
+    _zip: ZipFile | None
     _temp_file: BinaryIO
     _fs: FileSystem
-    _async_fs_context: _AsyncGeneratorContextManager[AsyncFileSystem] | None = None
-    _async_fs: AsyncFileSystem | None = None

     def __init__(self, file: str) -> None:
         self._file = file
         self._fs = filesystem(file)
-        self._lock = asyncio.Lock()
+        self._lock = anyio.Lock()
         self._temp_file = tempfile.TemporaryFile()
         self._samples: list[EvalSample] = []
         self._summary_counter = 0
@@ -300,11 +296,6 @@ class ZipLogFile:
         summaries: list[SampleSummary],
     ) -> None:
         async with self._lock:
-            # connect to async filesystem if we can
-            if self._fs.is_async():
-                self._async_fs_context = async_fileystem(self._file)
-                self._async_fs = await self._async_fs_context.__aenter__()
-
             self._open()
             self._summary_counter = summary_counter
             self._summaries = summaries
@@ -364,7 +355,8 @@ class ZipLogFile:
     async def flush(self) -> None:
         async with self._lock:
             # close the zip file so it is flushed
-            self._zip.close()
+            if self._zip:
+                self._zip.close()

             # read the temp_file (leaves pointer at end for subsequent appends)
             self._temp_file.seek(0)
@@ -380,21 +372,19 @@ class ZipLogFile:

     async def close(self) -> EvalLog:
         async with self._lock:
-            # close the async context if we have one
-            try:
-                if self._async_fs_context:
-                    await self._async_fs_context.__aexit__(None, None, None)
-            except Exception as ex:
-                logger.warning(
-                    f"Error occurred while closing async fs for {self._file}: {ex}"
-                )
-
             # read the log from the temp file then close it
             try:
                 self._temp_file.seek(0)
                 return _read_log(self._temp_file, self._file)
             finally:
                 self._temp_file.close()
+                if self._zip:
+                    self._zip.close()
+
+    # cleanup zip file if we didn't in normal course
+    def __del__(self) -> None:
+        if self._zip:
+            self._zip.close()

     def _open(self) -> None:
         self._zip = ZipFile(
@@ -406,6 +396,7 @@

     # raw unsynchronized version of write
     def _zip_writestr(self, filename: str, data: Any) -> None:
+        assert self._zip
         self._zip.writestr(
             filename,
             to_json(
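The recorder's lock is now `anyio.Lock` rather than `asyncio.Lock`, so the same code runs on either async backend. A minimal sketch of the pattern (not the recorder itself):

```python
import anyio


class Recorder:
    def __init__(self) -> None:
        # anyio.Lock works under both asyncio and trio
        self._lock = anyio.Lock()

    async def flush(self) -> None:
        async with self._lock:
            ...  # serialize zip writes across concurrent samples


async def main() -> None:
    await Recorder().flush()


anyio.run(main)
```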
inspect_ai/log/_recorders/json.py CHANGED
@@ -9,7 +9,7 @@ from typing_extensions import override

 from inspect_ai._util.constants import LOG_SCHEMA_VERSION
 from inspect_ai._util.error import EvalError
-from inspect_ai._util.file import absolute_file_path, async_fileystem, file, filesystem
+from inspect_ai._util.file import absolute_file_path, file
 from inspect_ai._util.trace import trace_action

 from .._log import (
@@ -178,23 +178,8 @@ class JSONRecorder(FileRecorder):
         log_bytes = eval_log_json(log)

         with trace_action(logger, "Log Write", location):
-            # try to write async for async filesystems
-            written = False
-            try:
-                fs = filesystem(location)
-                if fs.is_async():
-                    async with async_fileystem(location) as async_fs:
-                        await async_fs._pipe_file(location, log_bytes)
-                        written = True
-            except Exception as ex:
-                logger.warning(
-                    f"Error occurred during async write to {location}: {ex}. Falling back to sync write."
-                )
-
-            # otherwise use sync
-            if not written:
-                with file(location, "wb") as f:
-                    f.write(log_bytes)
+            with file(location, "wb") as f:
+                f.write(log_bytes)


 def _validate_version(ver: int) -> None:
inspect_ai/log/_samples.py CHANGED
@@ -1,15 +1,16 @@
 import contextlib
 from contextvars import ContextVar
 from datetime import datetime
-from typing import AsyncGenerator, Literal
+from typing import AsyncGenerator, Iterator, Literal

 from shortuuid import uuid

+from inspect_ai._util.constants import SAMPLE_SUBTASK
 from inspect_ai.dataset._dataset import Sample
 from inspect_ai.util._sandbox import SandboxConnection
 from inspect_ai.util._sandbox.context import sandbox_connections

-from ._transcript import Transcript
+from ._transcript import Transcript, transcript


 class ActiveSample:
@@ -44,6 +45,7 @@ class ActiveSample:
         self.total_tokens = 0
         self.transcript = transcript
         self.sandboxes = sandboxes
+        self.retry_count = 0
         self._interrupt_action: Literal["score", "error"] | None = None

     @property
@@ -153,6 +155,29 @@ def set_active_sample_total_messages(total_messages: int) -> None:
         active.total_messages = total_messages


+@contextlib.contextmanager
+def track_active_sample_retries() -> Iterator[None]:
+    reset_active_sample_retries()
+    try:
+        yield
+    finally:
+        reset_active_sample_retries()
+
+
+def reset_active_sample_retries() -> None:
+    active = sample_active()
+    if active:
+        active.retry_count = 0
+
+
+def report_active_sample_retry() -> None:
+    active = sample_active()
+    if active:
+        # only do this for the top level subtask
+        if transcript().name == SAMPLE_SUBTASK:
+            active.retry_count = active.retry_count + 1
+
+
 _sample_active: ContextVar[ActiveSample | None] = ContextVar(
     "_sample_active", default=None
 )
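How the new retry plumbing fits together (these helpers live in the internal module `inspect_ai.log._samples`, so this is illustrative rather than public API): the model layer wraps a generate call in `track_active_sample_retries()` and calls `report_active_sample_retry()` each time a provider request is retried; both are no-ops when no sample is active.

```python
from inspect_ai.log._samples import (
    report_active_sample_retry,
    track_active_sample_retries,
)

with track_active_sample_retries():  # resets retry_count on entry and exit
    report_active_sample_retry()     # increments ActiveSample.retry_count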
inspect_ai/log/_transcript.py CHANGED
@@ -1,10 +1,10 @@
-import asyncio
 import contextlib
 from contextvars import ContextVar
 from datetime import datetime
 from logging import getLogger
 from typing import (
     Any,
+    Callable,
     Iterator,
     Literal,
     Sequence,
@@ -210,15 +210,15 @@ class ToolEvent(BaseEvent):

     # mechanism for operator to cancel the tool call

-    def _set_task(self, task: asyncio.Task[Any]) -> None:
+    def _set_cancel_fn(self, cancel_fn: Callable[[], None]) -> None:
         """Set the tool task (for possible cancellation)"""
-        self._task = task
+        self._cancel_fn = cancel_fn

     def _cancel(self) -> None:
         """Cancel the tool task."""
-        if self._task:
+        if self._cancel_fn and not self.cancelled:
             self._cancelled = True
-            self._task.cancel()
+            self._cancel_fn()

     @property
     def cancelled(self) -> bool:
@@ -228,11 +228,11 @@ class ToolEvent(BaseEvent):
     _cancelled: bool | None = None
     """Was this tool call cancelled?"""

-    _task: asyncio.Task[Any] | None = None
-    """Handle to task (used for cancellation)"""
+    _cancel_fn: Callable[[], None] | None = None
+    """Function which can be used to cancel the tool call."""

     model_config = ConfigDict(arbitrary_types_allowed=True)
-    """Required so that we can include '_task' as a member."""
+    """Required so that we can include '_cancel_fn' as a member."""

     @field_serializer("completed")
     def serialize_completed(self, dt: datetime) -> str:
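A minimal sketch of the new cancellation contract: `ToolEvent` now stores a plain callable rather than an `asyncio.Task`, so an anyio cancel scope's `cancel()` method works on both the asyncio and trio backends.

```python
import anyio


async def main() -> None:
    async with anyio.create_task_group() as tg:
        tg.start_soon(anyio.sleep, 60)
        cancel_fn = tg.cancel_scope.cancel  # what _set_cancel_fn() receives
        cancel_fn()  # cancelling the scope ends the tool task immediately


anyio.run(main)
```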
inspect_ai/model/__init__.py CHANGED
@@ -27,7 +27,7 @@ from ._chat_message import (
     ChatMessageTool,
     ChatMessageUser,
 )
-from ._generate_config import GenerateConfig, GenerateConfigArgs
+from ._generate_config import GenerateConfig, GenerateConfigArgs, ResponseSchema
 from ._model import (
     Model,
     ModelAPI,
@@ -49,6 +49,7 @@ from ._registry import modelapi
 __all__ = [
     "GenerateConfig",
     "GenerateConfigArgs",
+    "ResponseSchema",
     "CachePolicy",
     "ContentAudio",
     "ContentImage",
inspect_ai/model/_call_tools.py CHANGED
@@ -1,6 +1,6 @@
-import asyncio
 import inspect
 import json
+import sys
 import types
 from dataclasses import is_dataclass
 from logging import getLogger
@@ -22,7 +22,13 @@ from typing import (
     is_typeddict,
 )

+if sys.version_info < (3, 11):
+    from exceptiongroup import ExceptionGroup
+
+
+import anyio
 import yaml
+from anyio.streams.memory import MemoryObjectSendStream
 from jsonschema import Draft7Validator
 from pydantic import BaseModel
@@ -80,7 +86,10 @@ async def call_tools(

     tdefs = tool_defs(tools)

-    async def call_tool_task(call: ToolCall) -> tuple[ChatMessageTool, ToolEvent]:
+    async def call_tool_task(
+        call: ToolCall,
+        send_stream: MemoryObjectSendStream[tuple[ChatMessageTool, ToolEvent]],
+    ) -> None:
         # create a transript for this call
         init_transcript(Transcript(name=call.function))
@@ -166,20 +175,23 @@ async def call_tools(
             events=list(transcript().events),
         )

-        # return message and event
-        return ChatMessageTool(
-            content=content,
-            tool_call_id=call.id,
-            function=call.function,
-            error=tool_error,
-        ), event
+        # yield message and event
+        async with send_stream:
+            await send_stream.send(
+                (
+                    ChatMessageTool(
+                        content=content,
+                        tool_call_id=call.id,
+                        function=call.function,
+                        error=tool_error,
+                    ),
+                    event,
+                )
+            )

     # call tools
     tool_messages: list[ChatMessageTool] = []
     for call in message.tool_calls:
-        # create the task
-        task = asyncio.create_task(call_tool_task(call))
-
         # create pending tool event and add it to the transcript
         # (record the waiting time for the sample so we can compare
         # it at the end to deduce total waiting time inside the tool
@@ -192,38 +204,46 @@ async def call_tools(
             view=call.view,
             pending=True,
         )
-        event._set_task(task)
         transcript()._event(event)

-        # execute the tool call. if the operator cancelled the
+        # execute the tool call. if the operator cancels the
         # tool call then synthesize the appropriate message/event
+        send_stream, receive_stream = anyio.create_memory_object_stream[
+            tuple[ChatMessageTool, ToolEvent]
+        ]()
         try:
-            tool_message, result_event = await task
-        except asyncio.CancelledError:
-            if event.cancelled:
-                tool_message = ChatMessageTool(
-                    content="",
-                    function=call.function,
-                    tool_call_id=call.id,
-                    error=ToolCallError(
-                        "timeout", "Command timed out before completing."
-                    ),
-                )
-                result_event = ToolEvent(
-                    id=call.id,
-                    function=call.function,
-                    arguments=call.arguments,
-                    result=tool_message.content,
-                    truncated=None,
-                    view=call.view,
-                    error=tool_message.error,
-                    events=[],
-                )
-                transcript().info(
-                    f"Tool call '{call.function}' was cancelled by operator."
-                )
-            else:
-                raise
+            async with anyio.create_task_group() as tg:
+                tg.start_soon(call_tool_task, call, send_stream)
+                event._set_cancel_fn(tg.cancel_scope.cancel)
+                async with receive_stream:
+                    async for result in receive_stream:
+                        tool_message, result_event = result
+                        break
+        except ExceptionGroup as ex:
+            raise ex.exceptions[0]
+
+        if event.cancelled:
+            tool_message = ChatMessageTool(
+                content="",
+                function=call.function,
+                tool_call_id=call.id,
+                error=ToolCallError(
+                    "timeout", "Command timed out before completing."
+                ),
+            )
+            result_event = ToolEvent(
+                id=call.id,
+                function=call.function,
+                arguments=call.arguments,
+                result=tool_message.content,
+                truncated=None,
+                view=call.view,
+                error=tool_message.error,
+                events=[],
+            )
+            transcript().info(
+                f"Tool call '{call.function}' was cancelled by operator."
+            )

         # update return messages
         tool_messages.append(tool_message)
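A standalone sketch of the structured-concurrency pattern `call_tools()` now uses: run the worker under a task group and pass its single result back over a memory object stream (the names here are illustrative, not part of inspect-ai):

```python
import anyio
from anyio.streams.memory import MemoryObjectSendStream


async def worker(x: int, send_stream: MemoryObjectSendStream[int]) -> None:
    async with send_stream:
        await send_stream.send(x * 2)


async def main() -> None:
    send_stream, receive_stream = anyio.create_memory_object_stream[int]()
    async with anyio.create_task_group() as tg:
        tg.start_soon(worker, 21, send_stream)
        async with receive_stream:
            async for result in receive_stream:
                print(result)  # 42
                break


anyio.run(main)
```

Cancelling the task group's cancel scope (as `_set_cancel_fn` arranges) tears down the worker without the `asyncio.CancelledError` special-casing the old code needed.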
inspect_ai/model/_chat_message.py CHANGED
@@ -2,6 +2,7 @@ from logging import getLogger
 from typing import Any, Literal, Type, Union

 from pydantic import BaseModel, Field, model_validator
+from shortuuid import uuid

 from inspect_ai._util.content import Content, ContentReasoning, ContentText
 from inspect_ai.tool import ToolCall
@@ -15,8 +16,8 @@ logger = getLogger(__name__)
 class ChatMessageBase(BaseModel):
     """Base class for chat messages."""

-    role: Literal["system", "user", "assistant", "tool"]
-    """Conversation role"""
+    id: str = Field(default_factory=uuid)
+    """Unique identifer for message."""

     content: str | list[Content]
     """Content (simple string or list of content objects)"""
inspect_ai/model/_generate_config.py CHANGED
@@ -5,6 +5,25 @@ from typing import Any, Literal, Union
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import TypedDict

+from inspect_ai.util._json import JSONSchema
+
+
+class ResponseSchema(BaseModel):
+    """Schema for model response when using Structured Output."""
+
+    name: str
+    """The name of the response schema. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64."""
+
+    json_schema: JSONSchema
+    """The schema for the response format, described as a JSON Schema object."""
+
+    description: str | None = Field(default=None)
+    """A description of what the response format is for, used by the model to determine how to respond in the format."""
+
+    strict: bool | None = Field(default=None)
+    """Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the schema field.
+    OpenAI and Mistral only."""
+

 class GenerateConfigArgs(TypedDict, total=False):
     """Type for kwargs that selectively override GenerateConfig."""
@@ -81,6 +100,9 @@ class GenerateConfigArgs(TypedDict, total=False):
     reasoning_history: Literal["none", "all", "last", "auto"] | None
     """Include reasoning in chat message history sent to generate."""

+    response_schema: ResponseSchema | None
+    """Request a response format as JSONSchema (output should still be validated). OpenAI, Google, and Mistral only."""
+

 class GenerateConfig(BaseModel):
     """Model generation options."""
@@ -159,6 +181,9 @@ class GenerateConfig(BaseModel):
     )
     """Include reasoning in chat message history sent to generate."""

+    response_schema: ResponseSchema | None = Field(default=None)
+    """Request a response format as JSONSchema (output should still be validated). OpenAI, Google, and Mistral only."""
+
     # migrate reasoning_history as a bool
     @model_validator(mode="before")
     @classmethod
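A hedged example of the new structured-output option: assuming `JSONSchema` is exported from `inspect_ai.util` (this release adds `inspect_ai/util/_json.py` and extends `util/__init__.py`), a config requesting a JSON object response might look like:

```python
from inspect_ai.model import GenerateConfig, ResponseSchema
from inspect_ai.util import JSONSchema

config = GenerateConfig(
    response_schema=ResponseSchema(
        name="color_answer",
        json_schema=JSONSchema(
            type="object",
            properties={"color": JSONSchema(type="string")},
            required=["color"],
        ),
    )
)
# Per the field docstrings, response_schema is honored by the OpenAI, Google,
# and Mistral providers, and model output should still be validated against
# the schema.
```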