inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
- inspect_ai/_cli/eval.py +14 -3
- inspect_ai/_cli/sandbox.py +3 -3
- inspect_ai/_cli/score.py +6 -4
- inspect_ai/_cli/trace.py +53 -6
- inspect_ai/_display/core/config.py +1 -1
- inspect_ai/_display/core/display.py +2 -1
- inspect_ai/_display/core/footer.py +6 -6
- inspect_ai/_display/plain/display.py +11 -6
- inspect_ai/_display/rich/display.py +23 -13
- inspect_ai/_display/textual/app.py +10 -9
- inspect_ai/_display/textual/display.py +2 -2
- inspect_ai/_display/textual/widgets/footer.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +14 -5
- inspect_ai/_eval/context.py +1 -2
- inspect_ai/_eval/eval.py +54 -41
- inspect_ai/_eval/loader.py +9 -2
- inspect_ai/_eval/run.py +148 -81
- inspect_ai/_eval/score.py +13 -8
- inspect_ai/_eval/task/images.py +31 -21
- inspect_ai/_eval/task/run.py +62 -59
- inspect_ai/_eval/task/rundir.py +16 -9
- inspect_ai/_eval/task/sandbox.py +7 -8
- inspect_ai/_eval/task/util.py +7 -0
- inspect_ai/_util/_async.py +118 -10
- inspect_ai/_util/constants.py +0 -2
- inspect_ai/_util/file.py +15 -29
- inspect_ai/_util/future.py +37 -0
- inspect_ai/_util/http.py +3 -99
- inspect_ai/_util/httpx.py +60 -0
- inspect_ai/_util/interrupt.py +2 -2
- inspect_ai/_util/json.py +5 -52
- inspect_ai/_util/logger.py +30 -86
- inspect_ai/_util/retry.py +10 -61
- inspect_ai/_util/trace.py +2 -2
- inspect_ai/_view/server.py +86 -3
- inspect_ai/_view/www/dist/assets/index.js +25837 -13269
- inspect_ai/_view/www/log-schema.json +253 -186
- inspect_ai/_view/www/package.json +2 -2
- inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
- inspect_ai/_view/www/src/types/log.d.ts +122 -94
- inspect_ai/approval/_human/manager.py +6 -10
- inspect_ai/approval/_human/panel.py +2 -2
- inspect_ai/dataset/_sources/util.py +7 -6
- inspect_ai/log/__init__.py +4 -0
- inspect_ai/log/_file.py +35 -61
- inspect_ai/log/_log.py +18 -1
- inspect_ai/log/_recorders/eval.py +14 -23
- inspect_ai/log/_recorders/json.py +3 -18
- inspect_ai/log/_samples.py +27 -2
- inspect_ai/log/_transcript.py +8 -8
- inspect_ai/model/__init__.py +2 -1
- inspect_ai/model/_call_tools.py +60 -40
- inspect_ai/model/_chat_message.py +3 -2
- inspect_ai/model/_generate_config.py +25 -0
- inspect_ai/model/_model.py +74 -36
- inspect_ai/model/_openai.py +9 -1
- inspect_ai/model/_providers/anthropic.py +24 -26
- inspect_ai/model/_providers/azureai.py +11 -9
- inspect_ai/model/_providers/bedrock.py +33 -24
- inspect_ai/model/_providers/cloudflare.py +8 -9
- inspect_ai/model/_providers/goodfire.py +7 -3
- inspect_ai/model/_providers/google.py +47 -13
- inspect_ai/model/_providers/groq.py +15 -15
- inspect_ai/model/_providers/hf.py +24 -17
- inspect_ai/model/_providers/mistral.py +36 -20
- inspect_ai/model/_providers/openai.py +30 -25
- inspect_ai/model/_providers/openai_o1.py +1 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +3 -4
- inspect_ai/model/_providers/util/__init__.py +2 -2
- inspect_ai/model/_providers/util/chatapi.py +6 -19
- inspect_ai/model/_providers/util/hooks.py +165 -0
- inspect_ai/model/_providers/vertex.py +20 -3
- inspect_ai/model/_providers/vllm.py +16 -19
- inspect_ai/scorer/_multi.py +5 -2
- inspect_ai/solver/_bridge/patch.py +31 -1
- inspect_ai/solver/_fork.py +5 -3
- inspect_ai/solver/_human_agent/agent.py +3 -2
- inspect_ai/tool/__init__.py +8 -2
- inspect_ai/tool/_tool_info.py +4 -90
- inspect_ai/tool/_tool_params.py +4 -34
- inspect_ai/tool/_tools/_web_search.py +30 -24
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_concurrency.py +5 -6
- inspect_ai/util/_display.py +6 -0
- inspect_ai/util/_json.py +170 -0
- inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
- inspect_ai/util/_sandbox/docker/docker.py +5 -0
- inspect_ai/util/_sandbox/environment.py +56 -9
- inspect_ai/util/_sandbox/service.py +12 -5
- inspect_ai/util/_subprocess.py +94 -113
- inspect_ai/util/_subtask.py +2 -4
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
- inspect_ai/_util/timeouts.py +0 -160
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
- inspect_ai/model/_providers/util/tracker.py +0 -92
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/log/_file.py
CHANGED
@@ -1,16 +1,15 @@
 import os
 import re
 from logging import getLogger
-from typing import Any, Callable, Generator, Literal, cast
+from typing import Any, Callable, Generator, Literal
 
 from pydantic import BaseModel
 from pydantic_core import to_json
 
-from inspect_ai._util._async import run_coroutine
+from inspect_ai._util._async import current_async_backend, run_coroutine
 from inspect_ai._util.constants import ALL_LOG_FORMATS, EVAL_LOG_FORMAT
 from inspect_ai._util.file import (
     FileInfo,
-    async_fileystem,
     file,
     filesystem,
 )
@@ -96,62 +95,6 @@ def list_eval_logs(
     return eval_logs
 
 
-async def list_eval_logs_async(
-    log_dir: str = os.environ.get("INSPECT_LOG_DIR", "./logs"),
-    formats: list[Literal["eval", "json"]] | None = None,
-    recursive: bool = True,
-    descending: bool = True,
-    fs_options: dict[str, Any] = {},
-) -> list[EvalLogInfo]:
-    """List all eval logs in a directory.
-
-    Will be async for filesystem providers that support async (e.g. s3, gcs, etc.)
-    otherwise will fallback to sync implementation.
-
-    Args:
-      log_dir (str): Log directory (defaults to INSPECT_LOG_DIR)
-      formats (Literal["eval", "json"]): Formats to list (default
-        to listing all formats)
-      recursive (bool): List log files recursively (defaults to True).
-      descending (bool): List in descending order.
-      fs_options (dict[str, Any]): Optional. Additional arguments to pass through
-        to the filesystem provider (e.g. `S3FileSystem`).
-
-    Returns:
-      List of EvalLog Info.
-    """
-    # async filesystem if we can
-    fs = filesystem(log_dir, fs_options)
-    if fs.is_async():
-        async with async_fileystem(log_dir, fs_options=fs_options) as async_fs:
-            if await async_fs._exists(log_dir):
-                # prevent caching of listings
-                async_fs.invalidate_cache(log_dir)
-                # list logs
-                if recursive:
-                    files: list[dict[str, Any]] = []
-                    async for _, _, filenames in async_fs._walk(log_dir, detail=True):
-                        files.extend(filenames.values())
-                else:
-                    files = cast(
-                        list[dict[str, Any]],
-                        await async_fs._ls(log_dir, detail=True),
-                    )
-                logs = [fs._file_info(file) for file in files]
-                # resolve to eval logs
-                return log_files_from_ls(logs, formats, descending)
-            else:
-                return []
-    else:
-        return list_eval_logs(
-            log_dir=log_dir,
-            formats=formats,
-            recursive=recursive,
-            descending=descending,
-            fs_options=fs_options,
-        )
-
-
 def write_eval_log(
     log: EvalLog,
     location: str | FileInfo | None = None,
@@ -165,6 +108,14 @@ def write_eval_log(
       format (Literal["eval", "json", "auto"]): Write to format
        (defaults to 'auto' based on `log_file` extension)
     """
+    # don't mix trio and asyncio
+    if current_async_backend() == "trio":
+        raise RuntimeError(
+            "write_eval_log cannot be called from a trio async context (please use write_eval_log_async instead)"
+        )
+
+    # will use s3fs and is not called from main inspect solver/scorer/tool/sandbox
+    # flow, so force the use of asyncio
     run_coroutine(write_eval_log_async(log, location, format))
 
 
@@ -265,8 +216,21 @@ def read_eval_log(
     Returns:
       EvalLog object read from file.
    """
+    # don't mix trio and asyncio
+    if current_async_backend() == "trio":
+        raise RuntimeError(
+            "read_eval_log cannot be called from a trio async context (please use read_eval_log_async instead)"
+        )
+
+    # will use s3fs and is not called from main inspect solver/scorer/tool/sandbox
+    # flow, so force the use of asyncio
     return run_coroutine(
-        read_eval_log_async(log_file, header_only, resolve_attachments, format)
+        read_eval_log_async(
+            log_file,
+            header_only,
+            resolve_attachments,
+            format,
+        )
     )
 
 
@@ -281,7 +245,7 @@ async def read_eval_log_async(
     Args:
      log_file (str | FileInfo): Log file to read.
      header_only (bool): Read only the header (i.e. exclude
-
+        the "samples" and "logging" fields). Defaults to False.
      resolve_attachments (bool): Resolve attachments (e.g. images)
        to their full content.
      format (Literal["eval", "json", "auto"]): Read from format
@@ -321,6 +285,8 @@ async def read_eval_log_async(
 def read_eval_log_headers(
     log_files: list[str] | list[EvalLogInfo],
 ) -> list[EvalLog]:
+    # will use s3fs and is not called from main inspect solver/scorer/tool/sandbox
+    # flow, so force the use of asyncio
     return run_coroutine(read_eval_log_headers_async(log_files))
 
 
@@ -356,6 +322,14 @@ def read_eval_log_sample(
     Raises:
      IndexError: If the passed id and epoch are not found.
    """
+    # don't mix trio and asyncio
+    if current_async_backend() == "trio":
+        raise RuntimeError(
+            "read_eval_log_sample cannot be called from a trio async context (please use read_eval_log_sample_async instead)"
+        )
+
+    # will use s3fs and is not called from main inspect solver/scorer/tool/sandbox
+    # flow, so force the use of asyncio
     return run_coroutine(
         read_eval_log_sample_async(log_file, id, epoch, resolve_attachments, format)
    )
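The practical effect of the new guard: the synchronous wrappers now refuse to run inside a trio event loop instead of silently starting a second (asyncio) loop. A minimal sketch of the caller-side behavior, assuming `read_eval_log` and `read_eval_log_async` are exported from `inspect_ai.log` (log path hypothetical):

    import trio
    from inspect_ai.log import read_eval_log, read_eval_log_async

    async def main() -> None:
        try:
            # the sync wrapper raises under trio rather than mixing event loops
            log = read_eval_log("./logs/example.eval")
        except RuntimeError:
            # the async variant is the supported path from a trio context
            log = await read_eval_log_async("./logs/example.eval")
        print(log.status)

    trio.run(main)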
inspect_ai/log/_log.py
CHANGED
@@ -295,7 +295,7 @@ class EvalSample(BaseModel):
             # warning will handle this)
             del values["transcript"]
 
-        return values
+        return migrate_sandbox_spec(values)
 
     # allow field model_usage
     model_config = ConfigDict(protected_namespaces=())
@@ -607,6 +607,23 @@ class EvalSpec(BaseModel):
     # allow field model_args
     model_config = ConfigDict(protected_namespaces=())
 
+    @model_validator(mode="before")
+    @classmethod
+    def read_sandbox_spec(
+        cls: Type["EvalSpec"], values: dict[str, Any]
+    ) -> dict[str, Any]:
+        return migrate_sandbox_spec(values)
+
+
+def migrate_sandbox_spec(values: dict[str, Any]) -> dict[str, Any]:
+    if "sandbox" in values:
+        sandbox = values.get("sandbox")
+        if isinstance(sandbox, list):
+            values["sandbox"] = SandboxEnvironmentSpec(
+                type=sandbox[0], config=sandbox[1]
+            )
+    return values
+
 
 def eval_error(
     exception: BaseException,
inspect_ai/log/_recorders/eval.py
CHANGED
@@ -1,13 +1,11 @@
-import asyncio
 import json
 import os
 import tempfile
-from contextlib import _AsyncGeneratorContextManager
 from logging import getLogger
 from typing import Any, BinaryIO, Literal, cast
 from zipfile import ZIP_DEFLATED, ZipFile
 
-
+import anyio
 from pydantic import BaseModel, Field
 from pydantic_core import to_json
 from typing_extensions import override
@@ -21,7 +19,7 @@ from inspect_ai._util.content import (
     ContentVideo,
 )
 from inspect_ai._util.error import EvalError
-from inspect_ai._util.file import FileSystem,
+from inspect_ai._util.file import FileSystem, dirname, file, filesystem
 from inspect_ai._util.json import jsonable_python
 from inspect_ai._util.trace import trace_action
 from inspect_ai.model._chat_message import ChatMessage
@@ -277,16 +275,14 @@ def text_inputs(inputs: str | list[ChatMessage]) -> str | list[ChatMessage]:
 
 
 class ZipLogFile:
-    _zip: ZipFile
+    _zip: ZipFile | None
     _temp_file: BinaryIO
     _fs: FileSystem
-    _async_fs_context: _AsyncGeneratorContextManager[AsyncFileSystem] | None = None
-    _async_fs: AsyncFileSystem | None = None
 
     def __init__(self, file: str) -> None:
         self._file = file
         self._fs = filesystem(file)
-        self._lock = asyncio.Lock()
+        self._lock = anyio.Lock()
         self._temp_file = tempfile.TemporaryFile()
         self._samples: list[EvalSample] = []
         self._summary_counter = 0
@@ -300,11 +296,6 @@ class ZipLogFile:
         summaries: list[SampleSummary],
     ) -> None:
         async with self._lock:
-            # connect to async filesystem if we can
-            if self._fs.is_async():
-                self._async_fs_context = async_fileystem(self._file)
-                self._async_fs = await self._async_fs_context.__aenter__()
-
             self._open()
             self._summary_counter = summary_counter
             self._summaries = summaries
@@ -364,7 +355,8 @@ class ZipLogFile:
     async def flush(self) -> None:
         async with self._lock:
             # close the zip file so it is flushed
-            self._zip.close()
+            if self._zip:
+                self._zip.close()
 
             # read the temp_file (leaves pointer at end for subsequent appends)
             self._temp_file.seek(0)
@@ -380,21 +372,19 @@ class ZipLogFile:
 
     async def close(self) -> EvalLog:
         async with self._lock:
-            # close the async context if we have one
-            try:
-                if self._async_fs_context:
-                    await self._async_fs_context.__aexit__(None, None, None)
-            except Exception as ex:
-                logger.warning(
-                    f"Error occurred while closing async fs for {self._file}: {ex}"
-                )
-
             # read the log from the temp file then close it
             try:
                 self._temp_file.seek(0)
                 return _read_log(self._temp_file, self._file)
             finally:
                 self._temp_file.close()
+                if self._zip:
+                    self._zip.close()
+
+    # cleanup zip file if we didn't in normal course
+    def __del__(self) -> None:
+        if self._zip:
+            self._zip.close()
 
     def _open(self) -> None:
         self._zip = ZipFile(
@@ -406,6 +396,7 @@ class ZipLogFile:
 
     # raw unsynchronized version of write
     def _zip_writestr(self, filename: str, data: Any) -> None:
+        assert self._zip
         self._zip.writestr(
             filename,
             to_json(
inspect_ai/log/_recorders/json.py
CHANGED
@@ -9,7 +9,7 @@ from typing_extensions import override
 
 from inspect_ai._util.constants import LOG_SCHEMA_VERSION
 from inspect_ai._util.error import EvalError
-from inspect_ai._util.file import absolute_file_path,
+from inspect_ai._util.file import absolute_file_path, file
 from inspect_ai._util.trace import trace_action
 
 from .._log import (
@@ -178,23 +178,8 @@ class JSONRecorder(FileRecorder):
         log_bytes = eval_log_json(log)
 
         with trace_action(logger, "Log Write", location):
-            written = False
-
-            try:
-                fs = filesystem(location)
-                if fs.is_async():
-                    async with async_fileystem(location) as async_fs:
-                        await async_fs._pipe_file(location, log_bytes)
-                        written = True
-            except Exception as ex:
-                logger.warning(
-                    f"Error occurred during async write to {location}: {ex}. Falling back to sync write."
-                )
-
-            # otherwise use sync
-            if not written:
-                with file(location, "wb") as f:
-                    f.write(log_bytes)
+            with file(location, "wb") as f:
+                f.write(log_bytes)
 
 
 def _validate_version(ver: int) -> None:
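The sandbox-spec migration is easiest to see with a legacy record. A sketch of the before/after shapes, calling the module-level helper directly (a private module path, shown only for illustration) and assuming `SandboxEnvironmentSpec` takes `type` and `config` as in the validator above:

    from inspect_ai.log._log import migrate_sandbox_spec  # internal helper
    from inspect_ai.util import SandboxEnvironmentSpec

    # older logs serialized the sandbox as a two-element list
    legacy_values = {"sandbox": ["docker", "compose.yaml"]}

    # the validator rewrites it into the typed spec on read
    migrated = migrate_sandbox_spec(dict(legacy_values))
    assert migrated["sandbox"] == SandboxEnvironmentSpec(
        type="docker", config="compose.yaml"
    )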
inspect_ai/log/_samples.py
CHANGED
@@ -1,15 +1,16 @@
 import contextlib
 from contextvars import ContextVar
 from datetime import datetime
-from typing import AsyncGenerator, Literal
+from typing import AsyncGenerator, Iterator, Literal
 
 from shortuuid import uuid
 
+from inspect_ai._util.constants import SAMPLE_SUBTASK
 from inspect_ai.dataset._dataset import Sample
 from inspect_ai.util._sandbox import SandboxConnection
 from inspect_ai.util._sandbox.context import sandbox_connections
 
-from ._transcript import Transcript
+from ._transcript import Transcript, transcript
 
 
 class ActiveSample:
@@ -44,6 +45,7 @@ class ActiveSample:
         self.total_tokens = 0
         self.transcript = transcript
         self.sandboxes = sandboxes
+        self.retry_count = 0
         self._interrupt_action: Literal["score", "error"] | None = None
 
     @property
@@ -153,6 +155,29 @@ def set_active_sample_total_messages(total_messages: int) -> None:
         active.total_messages = total_messages
 
 
+@contextlib.contextmanager
+def track_active_sample_retries() -> Iterator[None]:
+    reset_active_sample_retries()
+    try:
+        yield
+    finally:
+        reset_active_sample_retries()
+
+
+def reset_active_sample_retries() -> None:
+    active = sample_active()
+    if active:
+        active.retry_count = 0
+
+
+def report_active_sample_retry() -> None:
+    active = sample_active()
+    if active:
+        # only do this for the top level subtask
+        if transcript().name == SAMPLE_SUBTASK:
+            active.retry_count = active.retry_count + 1
+
+
 _sample_active: ContextVar[ActiveSample | None] = ContextVar(
     "_sample_active", default=None
 )
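How the new retry helpers are meant to compose: `track_active_sample_retries()` zeroes the active sample's `retry_count` on entry and exit, and `report_active_sample_retry()` bumps it on each retry. A hypothetical sketch (these are internal helpers, and `generate()` stands in for a flaky model call):

    from inspect_ai.log._samples import (  # internal module path
        report_active_sample_retry,
        track_active_sample_retries,
    )

    with track_active_sample_retries():
        for attempt in range(3):
            try:
                output = generate()  # hypothetical call that may time out
                break
            except TimeoutError:
                report_active_sample_retry()  # increments ActiveSample.retry_count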
inspect_ai/log/_transcript.py
CHANGED
@@ -1,10 +1,10 @@
-import asyncio
 import contextlib
 from contextvars import ContextVar
 from datetime import datetime
 from logging import getLogger
 from typing import (
     Any,
+    Callable,
     Iterator,
     Literal,
     Sequence,
@@ -210,15 +210,15 @@ class ToolEvent(BaseEvent):
 
     # mechanism for operator to cancel the tool call
 
-    def _set_task(self, task: asyncio.Task) -> None:
+    def _set_cancel_fn(self, cancel_fn: Callable[[], None]) -> None:
         """Set the tool task (for possible cancellation)"""
-        self._task = task
+        self._cancel_fn = cancel_fn
 
     def _cancel(self) -> None:
         """Cancel the tool task."""
-        if self._task and not self.cancelled:
+        if self._cancel_fn and not self.cancelled:
             self._cancelled = True
-            self._task.cancel()
+            self._cancel_fn()
 
     @property
     def cancelled(self) -> bool:
@@ -228,11 +228,11 @@ class ToolEvent(BaseEvent):
     _cancelled: bool | None = None
     """Was this tool call cancelled?"""
 
-    _task: asyncio.Task | None = None
-    """Task for the tool call."""
+    _cancel_fn: Callable[[], None] | None = None
+    """Function which can be used to cancel the tool call."""
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
-    """Required so that we can include '_task' as a member."""
+    """Required so that we can include '_cancel_fn' as a member."""
 
     @field_serializer("completed")
     def serialize_completed(self, dt: datetime) -> str:
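The switch from a stored task to a stored callback decouples `ToolEvent` from asyncio: any zero-argument callable can now cancel the tool call, which is exactly what anyio's cancel scopes provide. A minimal sketch of the pattern:

    import anyio

    async def demo() -> None:
        async with anyio.create_task_group() as tg:
            tg.start_soon(anyio.sleep, 60)       # stands in for the tool body
            cancel_fn = tg.cancel_scope.cancel   # what _set_cancel_fn receives
            cancel_fn()  # operator-initiated cancellation ends the group

    anyio.run(demo)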
inspect_ai/model/__init__.py
CHANGED
@@ -27,7 +27,7 @@ from ._chat_message import (
     ChatMessageTool,
     ChatMessageUser,
 )
-from ._generate_config import GenerateConfig, GenerateConfigArgs
+from ._generate_config import GenerateConfig, GenerateConfigArgs, ResponseSchema
 from ._model import (
     Model,
     ModelAPI,
@@ -49,6 +49,7 @@ from ._registry import modelapi
 __all__ = [
     "GenerateConfig",
     "GenerateConfigArgs",
+    "ResponseSchema",
     "CachePolicy",
     "ContentAudio",
     "ContentImage",
inspect_ai/model/_call_tools.py
CHANGED
@@ -1,6 +1,6 @@
-import asyncio
 import inspect
 import json
+import sys
 import types
 from dataclasses import is_dataclass
 from logging import getLogger
@@ -22,7 +22,13 @@ from typing import (
     is_typeddict,
 )
 
+if sys.version_info < (3, 11):
+    from exceptiongroup import ExceptionGroup
+
+
+import anyio
 import yaml
+from anyio.streams.memory import MemoryObjectSendStream
 from jsonschema import Draft7Validator
 from pydantic import BaseModel
 
@@ -80,7 +86,10 @@ async def call_tools(
 
     tdefs = tool_defs(tools)
 
-    async def call_tool_task(call: ToolCall) -> tuple[ChatMessageTool, ToolEvent]:
+    async def call_tool_task(
+        call: ToolCall,
+        send_stream: MemoryObjectSendStream[tuple[ChatMessageTool, ToolEvent]],
+    ) -> None:
         # create a transript for this call
         init_transcript(Transcript(name=call.function))
 
@@ -166,20 +175,23 @@ async def call_tools(
             events=list(transcript().events),
         )
 
-        # return message and event
-        return ChatMessageTool(
-            content=content,
-            tool_call_id=call.id,
-            function=call.function,
-            error=tool_error,
-        ), event
+        # yield message and event
+        async with send_stream:
+            await send_stream.send(
+                (
+                    ChatMessageTool(
+                        content=content,
+                        tool_call_id=call.id,
+                        function=call.function,
+                        error=tool_error,
+                    ),
+                    event,
+                )
+            )
 
     # call tools
     tool_messages: list[ChatMessageTool] = []
     for call in message.tool_calls:
-        # create the task
-        task = asyncio.create_task(call_tool_task(call))
-
         # create pending tool event and add it to the transcript
         # (record the waiting time for the sample so we can compare
         # it at the end to deduce total waiting time inside the tool
@@ -192,38 +204,46 @@ async def call_tools(
             view=call.view,
             pending=True,
         )
-        event._set_task(task)
         transcript()._event(event)
 
-        # execute the tool call. if the operator
+        # execute the tool call. if the operator cancels the
         # tool call then synthesize the appropriate message/event
+        send_stream, receive_stream = anyio.create_memory_object_stream[
+            tuple[ChatMessageTool, ToolEvent]
+        ]()
        try:
-            …
+            async with anyio.create_task_group() as tg:
+                tg.start_soon(call_tool_task, call, send_stream)
+                event._set_cancel_fn(tg.cancel_scope.cancel)
+                async with receive_stream:
+                    async for result in receive_stream:
+                        tool_message, result_event = result
+                        break
+        except ExceptionGroup as ex:
+            raise ex.exceptions[0]
+
+        if event.cancelled:
+            tool_message = ChatMessageTool(
+                content="",
+                function=call.function,
+                tool_call_id=call.id,
+                error=ToolCallError(
+                    "timeout", "Command timed out before completing."
+                ),
+            )
+            result_event = ToolEvent(
+                id=call.id,
+                function=call.function,
+                arguments=call.arguments,
+                result=tool_message.content,
+                truncated=None,
+                view=call.view,
+                error=tool_message.error,
+                events=[],
+            )
+            transcript().info(
+                f"Tool call '{call.function}' was cancelled by operator."
+            )
 
         # update return messages
         tool_messages.append(tool_message)
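The control flow above reduces to a standard anyio pattern: run the worker in a task group, hand it the send side of a memory object stream, and read a single result (or observe cancellation instead). A self-contained sketch of just that pattern:

    import anyio
    from anyio.streams.memory import MemoryObjectSendStream

    async def worker(send_stream: MemoryObjectSendStream[str]) -> None:
        async with send_stream:
            await send_stream.send("tool result")

    async def main() -> None:
        send_stream, receive_stream = anyio.create_memory_object_stream[str]()
        async with anyio.create_task_group() as tg:
            tg.start_soon(worker, send_stream)
            async with receive_stream:
                async for item in receive_stream:
                    print(item)  # "tool result"
                    break

    anyio.run(main)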
inspect_ai/model/_chat_message.py
CHANGED
@@ -2,6 +2,7 @@ from logging import getLogger
 from typing import Any, Literal, Type, Union
 
 from pydantic import BaseModel, Field, model_validator
+from shortuuid import uuid
 
 from inspect_ai._util.content import Content, ContentReasoning, ContentText
 from inspect_ai.tool import ToolCall
@@ -15,8 +16,8 @@ logger = getLogger(__name__)
 class ChatMessageBase(BaseModel):
     """Base class for chat messages."""
 
-    id: str | None = Field(default=None)
-    """Unique identifer for message."""
+    id: str = Field(default_factory=uuid)
+    """Unique identifer for message."""
 
     content: str | list[Content]
     """Content (simple string or list of content objects)"""
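With `default_factory=uuid`, every message now gets a unique id at construction time; nothing needs to be passed explicitly:

    from inspect_ai.model import ChatMessageUser

    msg = ChatMessageUser(content="hello")
    print(msg.id)  # shortuuid-generated, e.g. 'oFpoiMpbRc...'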
inspect_ai/model/_generate_config.py
CHANGED
@@ -5,6 +5,25 @@ from typing import Any, Literal, Union
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import TypedDict
 
+from inspect_ai.util._json import JSONSchema
+
+
+class ResponseSchema(BaseModel):
+    """Schema for model response when using Structured Output."""
+
+    name: str
+    """The name of the response schema. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64."""
+
+    json_schema: JSONSchema
+    """The schema for the response format, described as a JSON Schema object."""
+
+    description: str | None = Field(default=None)
+    """A description of what the response format is for, used by the model to determine how to respond in the format."""
+
+    strict: bool | None = Field(default=None)
+    """Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the schema field.
+    OpenAI and Mistral only."""
+
 
 class GenerateConfigArgs(TypedDict, total=False):
     """Type for kwargs that selectively override GenerateConfig."""
@@ -81,6 +100,9 @@ class GenerateConfigArgs(TypedDict, total=False):
     reasoning_history: Literal["none", "all", "last", "auto"] | None
     """Include reasoning in chat message history sent to generate."""
 
+    response_schema: ResponseSchema | None
+    """Request a response format as JSONSchema (output should still be validated). OpenAI, Google, and Mistral only."""
+
 
 class GenerateConfig(BaseModel):
     """Model generation options."""
@@ -159,6 +181,9 @@ class GenerateConfig(BaseModel):
     )
     """Include reasoning in chat message history sent to generate."""
 
+    response_schema: ResponseSchema | None = Field(default=None)
+    """Request a response format as JSONSchema (output should still be validated). OpenAI, Google, and Mistral only."""
+
     # migrate reasoning_history as a bool
     @model_validator(mode="before")
     @classmethod
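A sketch of requesting structured output with the new fields, assuming `JSONSchema` (added in `inspect_ai/util/_json.py` in this release) is exported from `inspect_ai.util` and accepts standard JSON Schema keywords; per the docstrings above, the model's output should still be validated by the caller:

    from inspect_ai.model import GenerateConfig, ResponseSchema
    from inspect_ai.util import JSONSchema

    config = GenerateConfig(
        response_schema=ResponseSchema(
            name="color_answer",
            json_schema=JSONSchema(
                type="object",
                properties={"color": JSONSchema(type="string")},
                required=["color"],
            ),
            strict=True,  # honored by OpenAI and Mistral only
        )
    )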