inspect-ai 0.3.55__py3-none-any.whl → 0.3.56__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +1 -0
- inspect_ai/_cli/common.py +1 -1
- inspect_ai/_cli/trace.py +33 -20
- inspect_ai/_display/core/active.py +1 -1
- inspect_ai/_display/core/display.py +1 -1
- inspect_ai/_display/core/footer.py +1 -1
- inspect_ai/_display/core/progress.py +0 -6
- inspect_ai/_display/core/rich.py +1 -1
- inspect_ai/_display/rich/display.py +2 -2
- inspect_ai/_display/textual/app.py +15 -17
- inspect_ai/_display/textual/widgets/clock.py +3 -3
- inspect_ai/_display/textual/widgets/samples.py +6 -13
- inspect_ai/_eval/context.py +9 -1
- inspect_ai/_eval/score.py +4 -10
- inspect_ai/_eval/task/results.py +5 -4
- inspect_ai/_eval/task/run.py +6 -12
- inspect_ai/_eval/task/task.py +10 -0
- inspect_ai/_util/ansi.py +31 -0
- inspect_ai/_util/format.py +7 -0
- inspect_ai/_util/logger.py +12 -12
- inspect_ai/_util/throttle.py +10 -1
- inspect_ai/_util/trace.py +43 -47
- inspect_ai/_util/transcript.py +4 -0
- inspect_ai/_util/vscode.py +51 -0
- inspect_ai/_view/notify.py +2 -1
- inspect_ai/_view/www/App.css +22 -1
- inspect_ai/_view/www/dist/assets/index.css +2374 -2
- inspect_ai/_view/www/dist/assets/index.js +29622 -24424
- inspect_ai/_view/www/log-schema.json +138 -90
- inspect_ai/_view/www/package.json +1 -0
- inspect_ai/_view/www/src/App.mjs +1 -0
- inspect_ai/_view/www/src/appearance/Icons.mjs +2 -0
- inspect_ai/_view/www/src/components/AsciiCinemaPlayer.mjs +74 -0
- inspect_ai/_view/www/src/components/CopyButton.mjs +0 -1
- inspect_ai/_view/www/src/components/HumanBaselineView.mjs +168 -0
- inspect_ai/_view/www/src/components/LightboxCarousel.mjs +217 -0
- inspect_ai/_view/www/src/components/Tools.mjs +11 -3
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +3 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +1 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.mjs +56 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +17 -5
- inspect_ai/_view/www/src/types/asciicinema-player.d.ts +26 -0
- inspect_ai/_view/www/src/types/log.d.ts +26 -12
- inspect_ai/_view/www/yarn.lock +44 -0
- inspect_ai/approval/_apply.py +4 -0
- inspect_ai/approval/_human/panel.py +5 -8
- inspect_ai/dataset/_dataset.py +51 -10
- inspect_ai/dataset/_util.py +31 -3
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_log.py +5 -2
- inspect_ai/model/_call_tools.py +4 -2
- inspect_ai/model/_chat_message.py +3 -0
- inspect_ai/model/_model.py +42 -1
- inspect_ai/model/_providers/anthropic.py +4 -0
- inspect_ai/model/_render.py +9 -2
- inspect_ai/scorer/_metric.py +12 -1
- inspect_ai/solver/__init__.py +2 -0
- inspect_ai/solver/_human_agent/agent.py +83 -0
- inspect_ai/solver/_human_agent/commands/__init__.py +36 -0
- inspect_ai/solver/_human_agent/commands/clock.py +70 -0
- inspect_ai/solver/_human_agent/commands/command.py +59 -0
- inspect_ai/solver/_human_agent/commands/instructions.py +74 -0
- inspect_ai/solver/_human_agent/commands/note.py +42 -0
- inspect_ai/solver/_human_agent/commands/score.py +80 -0
- inspect_ai/solver/_human_agent/commands/status.py +62 -0
- inspect_ai/solver/_human_agent/commands/submit.py +151 -0
- inspect_ai/solver/_human_agent/install.py +222 -0
- inspect_ai/solver/_human_agent/panel.py +252 -0
- inspect_ai/solver/_human_agent/service.py +45 -0
- inspect_ai/solver/_human_agent/state.py +55 -0
- inspect_ai/solver/_human_agent/view.py +24 -0
- inspect_ai/solver/_task_state.py +28 -2
- inspect_ai/tool/_tool.py +10 -2
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +13 -10
- inspect_ai/util/__init__.py +8 -4
- inspect_ai/{_util/display.py → util/_display.py} +6 -0
- inspect_ai/util/_panel.py +31 -9
- inspect_ai/util/_sandbox/__init__.py +0 -3
- inspect_ai/util/_sandbox/context.py +5 -1
- inspect_ai/util/_sandbox/docker/compose.py +16 -10
- inspect_ai/util/_sandbox/docker/docker.py +9 -6
- inspect_ai/util/_sandbox/docker/internal.py +1 -1
- inspect_ai/util/_sandbox/docker/util.py +2 -2
- inspect_ai/util/_sandbox/environment.py +6 -5
- inspect_ai/util/_sandbox/local.py +1 -1
- inspect_ai/util/_sandbox/service.py +22 -7
- inspect_ai/util/_store.py +5 -6
- inspect_ai/util/_store_model.py +110 -0
- inspect_ai/util/_throttle.py +32 -0
- {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.56.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.56.dist-info}/RECORD +95 -73
- {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.56.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.56.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.56.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.56.dist-info}/top_level.txt +0 -0
inspect_ai/__init__.py
CHANGED
@@ -9,6 +9,7 @@ from inspect_ai._eval.registry import task
 from inspect_ai._eval.score import score, score_async
 from inspect_ai._eval.task import Epochs, Task, TaskInfo, Tasks
 from inspect_ai._util.constants import PKG_NAME
+from inspect_ai.solver._human_agent.agent import human_agent

 __version__ = importlib_version(PKG_NAME)

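Note: this makes the new human agent solver importable from the package root. A minimal sketch of what the export enables (not from the diff; human_agent is assumed to be callable with defaults, and the Sample is a placeholder):

from inspect_ai import Task, human_agent
from inspect_ai.dataset import Sample

task = Task(
    dataset=[Sample(input="Install the package and run its tests.")],
    solver=human_agent(),  # newly importable from the package root in 0.3.56
)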
inspect_ai/_cli/common.py
CHANGED
inspect_ai/_cli/trace.py
CHANGED
@@ -4,7 +4,7 @@ import time
 from datetime import datetime
 from json import dumps
 from pathlib import Path
-from typing import Callable
+from typing import Callable

 import click
 from pydantic_core import to_json
@@ -13,8 +13,12 @@ from rich.console import Console, RenderableType
 from rich.table import Column, Table

 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai._util.
-
+from inspect_ai._util.trace import (
+    ActionTraceRecord,
+    inspect_trace_dir,
+    list_trace_files,
+    read_trace_file,
+)


 @click.group("trace")
@@ -36,32 +40,31 @@ def trace_command() -> None:
 )
 def list_command(json: bool) -> None:
     """List all trace files."""
-
-    trace_files: list[dict[str, float | str]] = [
-        {"mtime": f.lstat().st_mtime, "file": f.absolute().as_posix()}
-        for f in trace_dir.iterdir()
-        if f.is_file()
-    ]
-    trace_files.sort(key=lambda f: cast(float, f["mtime"]), reverse=True)
+    trace_files = list_trace_files()
     if json:
-        print(
+        print(
+            dumps(
+                [dict(file=str(file.file), mtime=file.mtime) for file in trace_files],
+                indent=2,
+            )
+        )
     else:
         table = Table(box=None, show_header=True, pad_edge=False)
         table.add_column("Time")
         table.add_column("Trace File")
         for file in trace_files:
-            mtime = datetime.fromtimestamp(
+            mtime = datetime.fromtimestamp(file.mtime).astimezone()
             table.add_row(
-                mtime.strftime("%d-%b %H:%M:%S %Z"), shlex.quote(str(file
+                mtime.strftime("%d-%b %H:%M:%S %Z"), shlex.quote(str(file.file))
             )
         r_print(table)


 @trace_command.command("dump")
-@click.argument("trace-file", type=str, required=False
-def
+@click.argument("trace-file", type=str, required=False)
+def dump_command(trace_file: str | None) -> None:
     """Dump a trace file to stdout (as a JSON array of log records)."""
-    trace_file_path =
+    trace_file_path = _resolve_trace_file_path(trace_file)

     traces = read_trace_file(trace_file_path)
     print(
@@ -70,16 +73,16 @@ def read_command(trace_file: str) -> None:


 @trace_command.command("anomalies")
-@click.argument("trace-file", type=str, required=False
+@click.argument("trace-file", type=str, required=False)
 @click.option(
     "--all",
     is_flag=True,
     default=False,
     help="Show all anomolies including errors and timeouts (by default only still running and cancelled actions are shown).",
 )
-def anomolies_command(trace_file: str, all: bool) -> None:
+def anomolies_command(trace_file: str | None, all: bool) -> None:
     """Look for anomalies in a trace file (never completed or cancelled actions)."""
-    trace_file_path =
+    trace_file_path = _resolve_trace_file_path(trace_file)
    traces = read_trace_file(trace_file_path)

    # Track started actions
@@ -226,7 +229,17 @@ def _print_bucket(
     print_fn(table)


-def
+def _resolve_trace_file(trace_file: str | None) -> str:
+    if trace_file is None:
+        trace_files = list_trace_files()
+        if len(trace_files) == 0:
+            raise PrerequisiteError("No trace files currently availalble.")
+        trace_file = str(trace_files[0].file)
+    return trace_file
+
+
+def _resolve_trace_file_path(trace_file: str | None) -> Path:
+    trace_file = _resolve_trace_file(trace_file)
     trace_file_path = Path(trace_file)
     if not trace_file_path.is_absolute():
         trace_file_path = inspect_trace_dir() / trace_file_path
inspect_ai/_display/core/display.py
CHANGED
@@ -99,7 +99,7 @@ class TaskScreen(contextlib.AbstractContextManager["TaskScreen"]):
     ) -> Iterator[Console]:
         yield rich.get_console()

-    async def input_panel(self,
+    async def input_panel(self, panel_type: type[TP]) -> TP:
         raise NotImplementedError("input_panel not implemented by current display")

inspect_ai/_display/core/footer.py
CHANGED
@@ -2,8 +2,8 @@ from rich.console import RenderableType
 from rich.text import Text

 from inspect_ai._util.logger import http_rate_limit_count
-from inspect_ai._util.throttle import throttle
 from inspect_ai.util._concurrency import concurrency_status
+from inspect_ai.util._throttle import throttle

 from .config import task_dict

inspect_ai/_display/core/progress.py
CHANGED
@@ -124,12 +124,6 @@ def progress_status_icon(result: TaskResult | None) -> str:
     return f"[{theme.meta}]⠿[{theme.meta}]"


-def progress_time(time: float) -> str:
-    minutes, seconds = divmod(time, 60)
-    hours, minutes = divmod(minutes, 60)
-    return f"{hours:2.0f}:{minutes:02.0f}:{seconds:02.0f}"
-
-
 def progress_count(complete: int, total: int, width: int | None = None) -> str:
     # Pad the display to keep it stable as the
     # complete metrics
inspect_ai/_display/core/rich.py
CHANGED
@@ -9,9 +9,9 @@ from rich.segment import Segment
 from rich.syntax import Syntax
 from typing_extensions import override

-from inspect_ai._util.display import display_type
 from inspect_ai._util.platform import is_running_in_jupyterlab, is_running_in_vscode
 from inspect_ai._util.transcript import transcript_code_theme
+from inspect_ai.util._display import display_type


 def is_vscode_notebook(console: Console) -> bool:
inspect_ai/_display/rich/display.py
CHANGED
@@ -12,9 +12,9 @@ from rich.table import Table
 from typing_extensions import override

 from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH
-from inspect_ai._util.display import display_type
-from inspect_ai._util.throttle import throttle
 from inspect_ai.log._transcript import InputEvent, transcript
+from inspect_ai.util._display import display_type
+from inspect_ai.util._throttle import throttle
 from inspect_ai.util._trace import trace_enabled

 from ..core.config import task_config
inspect_ai/_display/textual/app.py
CHANGED
@@ -10,6 +10,7 @@ from textual.app import App, ComposeResult
 from textual.binding import Binding, BindingType
 from textual.css.query import NoMatches
 from textual.events import Print
+from textual.widget import Widget
 from textual.widgets import TabbedContent, TabPane
 from textual.widgets.tabbed_content import ContentTabs
 from textual.worker import Worker, WorkerState
@@ -345,13 +346,15 @@ class TaskScreenApp(App[TR]):
         self.update_title()

     # dynamic input panels
-    async def add_input_panel(self,
+    async def add_input_panel(self, panel: InputPanel) -> None:
         tabs = self.query_one(TabbedContent)
-        await tabs.add_pane(
+        await tabs.add_pane(
+            TabPane(panel.title, panel, id=as_input_panel_id(type(panel)))
+        )

-    def get_input_panel(self,
+    def get_input_panel(self, panel_type: type) -> InputPanel | None:
         try:
-            tab_pane = self.query_one(f"#{as_input_panel_id(
+            tab_pane = self.query_one(f"#{as_input_panel_id(panel_type)}")
             if len(tab_pane.children) > 0:
                 return cast(InputPanel, tab_pane.children[0])
             else:
@@ -359,10 +362,6 @@ class TaskScreenApp(App[TR]):
         except NoMatches:
             return None

-    async def remove_input_panel(self, title: str) -> None:
-        tabs = self.query_one(TabbedContent)
-        await tabs.remove_pane(as_html_id(as_input_panel_id(title), title))
-
     class InputPanelHost(InputPanel.Host):
         def __init__(self, app: "TaskScreenApp[TR]", tab_id: str) -> None:
             self.app = app
@@ -383,7 +382,7 @@ class TaskScreenApp(App[TR]):
         # the tabs control so the user can switch back w/ the keyboard
         tab_pane = self.app.query_one(f"#{self.tab_id}")
         panel = cast(InputPanel, tab_pane.children[0])
-        for child in panel.
+        for child in panel.walk_children(Widget):
             if child.focusable:
                 child.focus()
         self.app.query_one(ContentTabs).focus()
@@ -455,19 +454,18 @@ class TextualTaskScreen(TaskScreen, Generic[TR]):
         console.width = old_width

     @override
-    async def input_panel(self,
+    async def input_panel(self, panel_type: type[TP]) -> TP:
         async with self.lock:
-            panel_widget = self.app.get_input_panel(
+            panel_widget = self.app.get_input_panel(panel_type)
             if panel_widget is None:
-                panel_widget =
-                    title,
+                panel_widget = panel_type(
                     TaskScreenApp[TR].InputPanelHost(
-                        self.app, as_input_panel_id(
+                        self.app, as_input_panel_id(panel_type)
                     ),
                 )
-                await self.app.add_input_panel(
+                await self.app.add_input_panel(panel_widget)
             return cast(TP, panel_widget)


-def as_input_panel_id(
-    return as_html_id("id-input-panel",
+def as_input_panel_id(panel_type: type) -> str:
+    return as_html_id("id-input-panel", panel_type.__name__)
inspect_ai/_display/textual/widgets/clock.py
CHANGED
@@ -4,7 +4,7 @@ from textual.reactive import reactive
 from textual.timer import Timer
 from textual.widgets import Static

-from inspect_ai.
+from inspect_ai._util.format import format_progress_time


 class Clock(Static):
@@ -43,7 +43,7 @@ class Clock(Static):
         if start_time is not None:
             if self.timer is None:
                 self.timer = self.set_interval(self.interval, self.update_time)
-            self.update(
+            self.update(format_progress_time(start_time))
         else:
             self.stop()

@@ -52,4 +52,4 @@ class Clock(Static):
         self.time = datetime.now().timestamp() - self.start_time

     def watch_time(self, time: float) -> None:
-        self.update(
+        self.update(format_progress_time(time))
inspect_ai/_display/textual/widgets/samples.py
CHANGED
@@ -22,10 +22,10 @@ from textual.widgets import (
 )
 from textual.widgets.option_list import Option, Separator

+from inspect_ai._util.format import format_progress_time
 from inspect_ai._util.registry import registry_unqualified_name
 from inspect_ai.log._samples import ActiveSample

-from ...core.progress import progress_time
 from .clock import Clock
 from .transcript import TranscriptView

@@ -147,7 +147,9 @@ class SamplesList(OptionList):
         table.add_column(width=1)
         task_name = Text.from_markup(f"{registry_unqualified_name(sample.task)}")
         task_name.truncate(18, overflow="ellipsis", pad=True)
-        task_time = Text.from_markup(
+        task_time = Text.from_markup(
+            f"{format_progress_time(sample.execution_time)}"
+        )
         table.add_row(task_name, task_time, " ")
         sample_id = Text.from_markup(f"id: {sample.sample.id}")
         sample_id.truncate(18, overflow="ellipsis", pad=True)
@@ -308,12 +310,7 @@ class SandboxesView(Vertical):
         yield Vertical(id="sandboxes-list")

     async def sync_sample(self, sample: ActiveSample) -> None:
-
-        show_sandboxes = (
-            len([sandbox for sandbox in sandboxes.values() if sandbox.container]) > 0
-        )
-
-        if show_sandboxes:
+        if len(sample.sandboxes) > 0:
             self.display = True
             sandboxes_caption = cast(Static, self.query_one("#sandboxes-caption"))
             sandboxes_caption.update("[bold]sandbox containers:[/bold]")
@@ -321,11 +318,7 @@ class SandboxesView(Vertical):
             sandboxes_list = self.query_one("#sandboxes-list")
             await sandboxes_list.remove_children()
             await sandboxes_list.mount_all(
-                [
-                    Static(sandbox.container)
-                    for sandbox in sandboxes.values()
-                    if sandbox.container
-                ]
+                [Static(sandbox.command) for sandbox in sample.sandboxes.values()]
             )
             sandboxes_list.mount(
                 Static(
inspect_ai/_eval/context.py
CHANGED
@@ -1,7 +1,9 @@
 from inspect_ai._util.dotenv import init_dotenv
 from inspect_ai._util.hooks import init_hooks
 from inspect_ai._util.logger import init_http_rate_limit_count, init_logger
+from inspect_ai.approval._apply import have_tool_approval, init_tool_approval
 from inspect_ai.approval._human.manager import init_human_approval_manager
+from inspect_ai.approval._policy import ApprovalPolicy
 from inspect_ai.log._samples import init_active_samples
 from inspect_ai.model import GenerateConfig, Model
 from inspect_ai.model._model import init_active_model, init_model_usage
@@ -24,6 +26,12 @@ def init_eval_context(
     init_human_approval_manager()


-def init_task_context(
+def init_task_context(
+    model: Model,
+    approval: list[ApprovalPolicy] | None = None,
+    config: GenerateConfig = GenerateConfig(),
+) -> None:
     init_active_model(model, config)
     init_model_usage()
+    if not have_tool_approval():
+        init_tool_approval(approval)
inspect_ai/_eval/score.py
CHANGED
@@ -11,7 +11,7 @@ from inspect_ai.log import (
     EvalMetric,
 )
 from inspect_ai.model import ModelName
-from inspect_ai.scorer import Metric,
+from inspect_ai.scorer import Metric, Scorer, Target
 from inspect_ai.scorer._metric import SampleScore
 from inspect_ai.scorer._reducer import (
     ScoreReducer,
@@ -108,7 +108,7 @@ async def score_async(

     # write them back (gather ensures that they come back in the same order)
     for index, score in enumerate(scores):
-        log.samples[index].scores =
+        log.samples[index].scores = {k: v.score for k, v in score.items()}

     # collect metrics from EvalLog (they may overlap w/ the scorer metrics,
     # that will be taken care of in eval_results)
@@ -151,11 +151,8 @@ async def task_score(task: Task, log: EvalLog) -> EvalLog:
     sample_scores = [
         {
             score_key: SampleScore(
+                score=score,
                 sample_id=sample.id,
-                value=score.value,
-                answer=score.answer,
-                explanation=score.explanation,
-                metadata=score.metadata,
             )
             for score_key, score in sample.scores.items()
         }
@@ -185,11 +182,8 @@ async def run_score_task(
     scorer_name = unique_scorer_name(scorer, list(results.keys()))

     results[scorer_name] = SampleScore(
+        score=result,
         sample_id=state.sample_id,
-        value=result.value,
-        answer=result.answer,
-        explanation=result.explanation,
-        metadata=result.metadata,
     )

     progress()
inspect_ai/_eval/task/results.py
CHANGED
@@ -13,6 +13,7 @@ from inspect_ai._util.registry import (
 from inspect_ai.log import (
     EvalMetric,
     EvalResults,
+    EvalSampleScore,
     EvalScore,
 )
 from inspect_ai.log._log import EvalSampleReductions
@@ -345,7 +346,7 @@ def resolve_glob_metric_keys(

 def reduce_scores(
     scores: list[SampleScore], reducer: ScoreReducer
-) -> list[
+) -> list[EvalSampleScore]:
     # Group the scores by sample_id
     grouped_scores: dict[str, list[SampleScore]] = defaultdict(list)
     for sample_score in scores:
@@ -353,11 +354,11 @@ def reduce_scores(
         grouped_scores[str(sample_score.sample_id)].append(sample_score)

     # reduce the scores
-    reduced_scores: list[
+    reduced_scores: list[EvalSampleScore] = []
     for scores in grouped_scores.values():
-        reduced = reducer(
+        reduced = reducer([score.score for score in scores])
         reduced_scores.append(
-
+            EvalSampleScore(
                 sample_id=scores[0].sample_id,
                 value=reduced.value,
                 answer=reduced.answer,
inspect_ai/_eval/task/run.py
CHANGED
@@ -6,7 +6,7 @@ from copy import deepcopy
 from dataclasses import dataclass, field
 from logging import getLogger
 from pathlib import PurePath
-from typing import Callable, Literal
+from typing import Callable, Literal

 from typing_extensions import Unpack

@@ -62,7 +62,7 @@ from inspect_ai.model import (
 )
 from inspect_ai.model._model import init_sample_model_usage, sample_model_usage
 from inspect_ai.scorer import Scorer, Target
-from inspect_ai.scorer._metric import Metric, SampleScore
+from inspect_ai.scorer._metric import Metric, SampleScore
 from inspect_ai.scorer._reducer.types import ScoreReducer
 from inspect_ai.scorer._score import init_scoring_context
 from inspect_ai.scorer._scorer import unique_scorer_name
@@ -136,7 +136,7 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
     generate_config = task.config.merge(GenerateConfigArgs(**kwargs))

     # init task context
-    init_task_context(model, generate_config)
+    init_task_context(model, options.task.approval, generate_config)

     # establish run_dir for duration of execution
     with set_task_run_dir(task_run_dir(task)):
@@ -503,11 +503,8 @@ async def task_run_sample(
     sample_scores = (
         {
             key: SampleScore(
+                score=score,
                 sample_id=previous_sample.id,
-                value=score.value,
-                answer=score.answer,
-                explanation=score.explanation,
-                metadata=score.metadata,
             )
             for key, score in previous_sample.scores.items()
         }
@@ -652,11 +649,8 @@ async def task_run_sample(
     )
     if score_result is not None:
         sample_score = SampleScore(
+            score=score_result,
             sample_id=sample.id,
-            value=score_result.value,
-            answer=score_result.answer,
-            explanation=score_result.explanation,
-            metadata=score_result.metadata,
         )
         transcript()._event(
             ScoreEvent(score=score_result, target=sample.target)
@@ -759,7 +753,7 @@ async def log_sample(
         setup=sample.setup,
         messages=state.messages,
         output=state.output,
-        scores=
+        scores={k: v.score for k, v in scores.items()},
         store=dict(state.store.items()),
         events=list(transcript().events),
         model_usage=sample_model_usage(),
inspect_ai/_eval/task/task.py
CHANGED
@@ -7,6 +7,7 @@ from typing_extensions import TypedDict, Unpack

 from inspect_ai._util.logger import warn_once
 from inspect_ai._util.registry import is_registry_object, registry_info
+from inspect_ai.approval._policy import ApprovalPolicy, approval_policies_from_config
 from inspect_ai.dataset import Dataset, MemoryDataset, Sample
 from inspect_ai.log import EvalLog
 from inspect_ai.model import GenerateConfig
@@ -49,6 +50,9 @@ class Task:
         config (GenerateConfig): Model generation config.
         sandbox (SandboxEnvironmentType | None): Sandbox environment type
           (or optionally a str or tuple with a shorthand spec)
+        approval: (str | list[ApprovalPolicy] | None): Tool use approval policies.
+          Either a path to an approval policy config file or a list of approval policies.
+          Defaults to no approval policy.
         epochs (int | Epochs | None): Epochs to repeat samples for and optional score
           reducer function(s) used to combine sample scores (defaults to "mean")
         fail_on_error (bool | float | None): `True` to fail on first sample error
@@ -76,6 +80,7 @@ class Task:
         metrics: list[Metric] | dict[str, list[Metric]] | None = None,
         config: GenerateConfig = GenerateConfig(),
         sandbox: SandboxEnvironmentType | None = None,
+        approval: str | list[ApprovalPolicy] | None = None,
         epochs: int | Epochs | None = None,
         fail_on_error: bool | float | None = None,
         message_limit: int | None = None,
@@ -134,6 +139,11 @@ class Task:
         self.metrics = metrics
         self.config = config
         self.sandbox = resolve_sandbox_environment(sandbox)
+        self.approval = (
+            approval_policies_from_config(approval)
+            if isinstance(approval, str)
+            else approval
+        )
         self.epochs = epochs.epochs if epochs else None
         self.epochs_reducer = epochs.reducer if epochs else None
         self.fail_on_error = fail_on_error
inspect_ai/_util/ansi.py
ADDED
@@ -0,0 +1,31 @@
+import os
+from typing import Any
+
+from rich.console import Console, RenderableType
+
+
+def render_text(
+    text: RenderableType | list[RenderableType], styles: bool = True, **options: Any
+) -> str:
+    """Render text from Rich renderables.
+
+    Args:
+       text (RenderableType | list[RenderableType]): Renderables.
+       styles (bool): If True, ansi escape codes will be included. False for plain text.
+          Defaults to True.
+       **options (Any): Additonal keyword arguments to pass to `Console` constructor.
+
+    Returns:
+       str: Rendered text (with ansi codes if `styles=True`)
+    """
+    # resolve to text
+    text = text if isinstance(text, list) else [text]
+
+    # print to console attached to /dev/null
+    with open(os.devnull, "w") as f:
+        console = Console(file=f, record=True, force_terminal=True, **options)
+        for t in text:
+            console.print(t)
+
+    # export (optionally w/ ansi styles)
+    return console.export_text(styles=styles).strip()
inspect_ai/_util/format.py
CHANGED
@@ -26,3 +26,10 @@ def format_value(value: object, width: int) -> str:
     elif isinstance(value, list | tuple | dict):
         return pprint.pformat(value, width=width)
     return str(value)
+
+
+def format_progress_time(time: float, pad_hours: bool = True) -> str:
+    minutes, seconds = divmod(time, 60)
+    hours, minutes = divmod(minutes, 60)
+    hours_fmt = f"{hours:2.0f}" if pad_hours else f"{hours:.0f}"
+    return f"{hours_fmt}:{minutes:02.0f}:{seconds:02.0f}"
inspect_ai/_util/logger.py
CHANGED
@@ -1,3 +1,4 @@
+import atexit
 import os
 from logging import (
     DEBUG,
@@ -30,7 +31,12 @@ from .constants import (
     TRACE_LOG_LEVEL,
 )
 from .error import PrerequisiteError
-from .trace import
+from .trace import (
+    TraceFormatter,
+    compress_trace_log,
+    inspect_trace_file,
+    rotate_trace_files,
+)

 TRACE_FILE_NAME = "trace.log"

@@ -56,19 +62,13 @@ class LogHandler(RichHandler):
         else:
             self.file_logger_level = 0

-        # add a trace handler
-
-        have_existing_trace_file = default_trace_file.exists()
+        # add a trace file handler
+        rotate_trace_files()  # remove oldest if > 10 trace files
         env_trace_file = os.environ.get("INSPECT_TRACE_FILE", None)
-        trace_file = Path(env_trace_file) if env_trace_file else
-
-        self.trace_logger = TraceFileHandler(
-            trace_file.as_posix(),
-            backupCount=trace_total_files - 1,  # exclude the current file (10 total)
-        )
+        trace_file = Path(env_trace_file) if env_trace_file else inspect_trace_file()
+        self.trace_logger = FileHandler(trace_file)
         self.trace_logger.setFormatter(TraceFormatter())
-
-        self.trace_logger.doRollover()
+        atexit.register(compress_trace_log(self.trace_logger))

         # set trace level
         trace_level = os.environ.get("INSPECT_TRACE_LEVEL", TRACE_LOG_LEVEL)
inspect_ai/_util/throttle.py
CHANGED
@@ -3,7 +3,16 @@ from functools import wraps
 from typing import Any, Callable


-def throttle(seconds:
+def throttle(seconds: float) -> Callable[..., Any]:
+    """Throttle a function to ensure it is called no more than every n seconds.
+
+    Args:
+       seconds (float): Throttle time.
+
+    Returns:
+       Callable: Throttled function.
+    """
+
     def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
         last_called: float = 0
         last_result: Any = None
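Note: a usage sketch for the documented decorator; per the last_called/last_result state above, calls inside the throttle window appear to return the cached result rather than re-invoking the function:

@throttle(1)
def refresh_footer() -> str:
    return render_footer_text()  # hypothetical expensive render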
|