inspect-ai 0.3.55__py3-none-any.whl → 0.3.56__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (95)
  1. inspect_ai/__init__.py +1 -0
  2. inspect_ai/_cli/common.py +1 -1
  3. inspect_ai/_cli/trace.py +33 -20
  4. inspect_ai/_display/core/active.py +1 -1
  5. inspect_ai/_display/core/display.py +1 -1
  6. inspect_ai/_display/core/footer.py +1 -1
  7. inspect_ai/_display/core/progress.py +0 -6
  8. inspect_ai/_display/core/rich.py +1 -1
  9. inspect_ai/_display/rich/display.py +2 -2
  10. inspect_ai/_display/textual/app.py +15 -17
  11. inspect_ai/_display/textual/widgets/clock.py +3 -3
  12. inspect_ai/_display/textual/widgets/samples.py +6 -13
  13. inspect_ai/_eval/context.py +9 -1
  14. inspect_ai/_eval/score.py +4 -10
  15. inspect_ai/_eval/task/results.py +5 -4
  16. inspect_ai/_eval/task/run.py +6 -12
  17. inspect_ai/_eval/task/task.py +10 -0
  18. inspect_ai/_util/ansi.py +31 -0
  19. inspect_ai/_util/format.py +7 -0
  20. inspect_ai/_util/logger.py +12 -12
  21. inspect_ai/_util/throttle.py +10 -1
  22. inspect_ai/_util/trace.py +43 -47
  23. inspect_ai/_util/transcript.py +4 -0
  24. inspect_ai/_util/vscode.py +51 -0
  25. inspect_ai/_view/notify.py +2 -1
  26. inspect_ai/_view/www/App.css +22 -1
  27. inspect_ai/_view/www/dist/assets/index.css +2374 -2
  28. inspect_ai/_view/www/dist/assets/index.js +29622 -24424
  29. inspect_ai/_view/www/log-schema.json +138 -90
  30. inspect_ai/_view/www/package.json +1 -0
  31. inspect_ai/_view/www/src/App.mjs +1 -0
  32. inspect_ai/_view/www/src/appearance/Icons.mjs +2 -0
  33. inspect_ai/_view/www/src/components/AsciiCinemaPlayer.mjs +74 -0
  34. inspect_ai/_view/www/src/components/CopyButton.mjs +0 -1
  35. inspect_ai/_view/www/src/components/HumanBaselineView.mjs +168 -0
  36. inspect_ai/_view/www/src/components/LightboxCarousel.mjs +217 -0
  37. inspect_ai/_view/www/src/components/Tools.mjs +11 -3
  38. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +3 -2
  39. inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +1 -0
  40. inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.mjs +56 -0
  41. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +17 -5
  42. inspect_ai/_view/www/src/types/asciicinema-player.d.ts +26 -0
  43. inspect_ai/_view/www/src/types/log.d.ts +26 -12
  44. inspect_ai/_view/www/yarn.lock +44 -0
  45. inspect_ai/approval/_apply.py +4 -0
  46. inspect_ai/approval/_human/panel.py +5 -8
  47. inspect_ai/dataset/_dataset.py +51 -10
  48. inspect_ai/dataset/_util.py +31 -3
  49. inspect_ai/log/__init__.py +2 -0
  50. inspect_ai/log/_log.py +5 -2
  51. inspect_ai/model/_call_tools.py +4 -2
  52. inspect_ai/model/_chat_message.py +3 -0
  53. inspect_ai/model/_model.py +42 -1
  54. inspect_ai/model/_providers/anthropic.py +4 -0
  55. inspect_ai/model/_render.py +9 -2
  56. inspect_ai/scorer/_metric.py +12 -1
  57. inspect_ai/solver/__init__.py +2 -0
  58. inspect_ai/solver/_human_agent/agent.py +83 -0
  59. inspect_ai/solver/_human_agent/commands/__init__.py +36 -0
  60. inspect_ai/solver/_human_agent/commands/clock.py +70 -0
  61. inspect_ai/solver/_human_agent/commands/command.py +59 -0
  62. inspect_ai/solver/_human_agent/commands/instructions.py +74 -0
  63. inspect_ai/solver/_human_agent/commands/note.py +42 -0
  64. inspect_ai/solver/_human_agent/commands/score.py +80 -0
  65. inspect_ai/solver/_human_agent/commands/status.py +62 -0
  66. inspect_ai/solver/_human_agent/commands/submit.py +151 -0
  67. inspect_ai/solver/_human_agent/install.py +222 -0
  68. inspect_ai/solver/_human_agent/panel.py +252 -0
  69. inspect_ai/solver/_human_agent/service.py +45 -0
  70. inspect_ai/solver/_human_agent/state.py +55 -0
  71. inspect_ai/solver/_human_agent/view.py +24 -0
  72. inspect_ai/solver/_task_state.py +28 -2
  73. inspect_ai/tool/_tool.py +10 -2
  74. inspect_ai/tool/_tools/_web_browser/_web_browser.py +13 -10
  75. inspect_ai/util/__init__.py +8 -4
  76. inspect_ai/{_util/display.py → util/_display.py} +6 -0
  77. inspect_ai/util/_panel.py +31 -9
  78. inspect_ai/util/_sandbox/__init__.py +0 -3
  79. inspect_ai/util/_sandbox/context.py +5 -1
  80. inspect_ai/util/_sandbox/docker/compose.py +16 -10
  81. inspect_ai/util/_sandbox/docker/docker.py +9 -6
  82. inspect_ai/util/_sandbox/docker/internal.py +1 -1
  83. inspect_ai/util/_sandbox/docker/util.py +2 -2
  84. inspect_ai/util/_sandbox/environment.py +6 -5
  85. inspect_ai/util/_sandbox/local.py +1 -1
  86. inspect_ai/util/_sandbox/service.py +22 -7
  87. inspect_ai/util/_store.py +5 -6
  88. inspect_ai/util/_store_model.py +110 -0
  89. inspect_ai/util/_throttle.py +32 -0
  90. {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.56.dist-info}/METADATA +1 -1
  91. {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.56.dist-info}/RECORD +95 -73
  92. {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.56.dist-info}/LICENSE +0 -0
  93. {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.56.dist-info}/WHEEL +0 -0
  94. {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.56.dist-info}/entry_points.txt +0 -0
  95. {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.56.dist-info}/top_level.txt +0 -0
inspect_ai/__init__.py CHANGED
@@ -9,6 +9,7 @@ from inspect_ai._eval.registry import task
 from inspect_ai._eval.score import score, score_async
 from inspect_ai._eval.task import Epochs, Task, TaskInfo, Tasks
 from inspect_ai._util.constants import PKG_NAME
+from inspect_ai.solver._human_agent.agent import human_agent

 __version__ = importlib_version(PKG_NAME)

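With `human_agent` now exported from the package root, a task can be driven by a human operator working in a sandbox. A minimal sketch, assuming `human_agent()` accepts no required arguments (its signature is not shown in this diff):

```python
# Hypothetical usage; human_agent's parameters are an assumption here.
from inspect_ai import Task, human_agent
from inspect_ai.dataset import Sample

task = Task(
    dataset=[Sample(input="Find the flag in /challenge", target="flag{...}")],
    solver=human_agent(),  # a human completes the task via the sandbox
    sandbox="docker",
)
```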
inspect_ai/_cli/common.py CHANGED
@@ -10,7 +10,7 @@ from inspect_ai._util.constants import (
     DEFAULT_LOG_LEVEL,
     DEFAULT_LOG_LEVEL_TRANSCRIPT,
 )
-from inspect_ai._util.display import init_display_type
+from inspect_ai.util._display import init_display_type


 class CommonOptions(TypedDict):
inspect_ai/_cli/trace.py CHANGED
@@ -4,7 +4,7 @@ import time
 from datetime import datetime
 from json import dumps
 from pathlib import Path
-from typing import Callable, cast
+from typing import Callable

 import click
 from pydantic_core import to_json
@@ -13,8 +13,12 @@ from rich.console import Console, RenderableType
 from rich.table import Column, Table

 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai._util.logger import TRACE_FILE_NAME
-from inspect_ai._util.trace import ActionTraceRecord, inspect_trace_dir, read_trace_file
+from inspect_ai._util.trace import (
+    ActionTraceRecord,
+    inspect_trace_dir,
+    list_trace_files,
+    read_trace_file,
+)


 @click.group("trace")
@@ -36,32 +40,31 @@ def trace_command() -> None:
 )
 def list_command(json: bool) -> None:
     """List all trace files."""
-    trace_dir = inspect_trace_dir()
-    trace_files: list[dict[str, float | str]] = [
-        {"mtime": f.lstat().st_mtime, "file": f.absolute().as_posix()}
-        for f in trace_dir.iterdir()
-        if f.is_file()
-    ]
-    trace_files.sort(key=lambda f: cast(float, f["mtime"]), reverse=True)
+    trace_files = list_trace_files()
     if json:
-        print(dumps(trace_files, indent=2))
+        print(
+            dumps(
+                [dict(file=str(file.file), mtime=file.mtime) for file in trace_files],
+                indent=2,
+            )
+        )
     else:
         table = Table(box=None, show_header=True, pad_edge=False)
         table.add_column("Time")
         table.add_column("Trace File")
         for file in trace_files:
-            mtime = datetime.fromtimestamp(cast(float, file["mtime"])).astimezone()
+            mtime = datetime.fromtimestamp(file.mtime).astimezone()
             table.add_row(
-                mtime.strftime("%d-%b %H:%M:%S %Z"), shlex.quote(str(file["file"]))
+                mtime.strftime("%d-%b %H:%M:%S %Z"), shlex.quote(str(file.file))
            )
         r_print(table)


 @trace_command.command("dump")
-@click.argument("trace-file", type=str, required=False, default=TRACE_FILE_NAME)
-def read_command(trace_file: str) -> None:
+@click.argument("trace-file", type=str, required=False)
+def dump_command(trace_file: str | None) -> None:
     """Dump a trace file to stdout (as a JSON array of log records)."""
-    trace_file_path = resolve_trace_file_path(trace_file)
+    trace_file_path = _resolve_trace_file_path(trace_file)

     traces = read_trace_file(trace_file_path)
     print(
@@ -70,16 +73,16 @@ def read_command(trace_file: str) -> None:


 @trace_command.command("anomalies")
-@click.argument("trace-file", type=str, required=False, default=TRACE_FILE_NAME)
+@click.argument("trace-file", type=str, required=False)
 @click.option(
     "--all",
     is_flag=True,
     default=False,
     help="Show all anomolies including errors and timeouts (by default only still running and cancelled actions are shown).",
 )
-def anomolies_command(trace_file: str, all: bool) -> None:
+def anomolies_command(trace_file: str | None, all: bool) -> None:
     """Look for anomalies in a trace file (never completed or cancelled actions)."""
-    trace_file_path = resolve_trace_file_path(trace_file)
+    trace_file_path = _resolve_trace_file_path(trace_file)
     traces = read_trace_file(trace_file_path)

     # Track started actions
@@ -226,7 +229,17 @@ def _print_bucket(
     print_fn(table)


-def resolve_trace_file_path(trace_file: str) -> Path:
+def _resolve_trace_file(trace_file: str | None) -> str:
+    if trace_file is None:
+        trace_files = list_trace_files()
+        if len(trace_files) == 0:
+            raise PrerequisiteError("No trace files currently availalble.")
+        trace_file = str(trace_files[0].file)
+    return trace_file
+
+
+def _resolve_trace_file_path(trace_file: str | None) -> Path:
+    trace_file = _resolve_trace_file(trace_file)
     trace_file_path = Path(trace_file)
     if not trace_file_path.is_absolute():
         trace_file_path = inspect_trace_dir() / trace_file_path
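The practical effect of these changes: `inspect trace dump` and `inspect trace anomalies` no longer default to a fixed `trace.log`; when no argument is given they fall back to the most recent trace file. A sketch of the same resolution done programmatically, assuming `list_trace_files()` returns records newest-first (as the `list` command's output implies):

```python
from inspect_ai._util.trace import list_trace_files, read_trace_file

trace_files = list_trace_files()  # assumed newest-first
if trace_files:
    # each record exposes .file and .mtime, as used in list_command above
    records = read_trace_file(trace_files[0].file)
    print(f"{len(records)} records in {trace_files[0].file}")
```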
inspect_ai/_display/core/active.py CHANGED
@@ -3,7 +3,7 @@ from contextvars import ContextVar

 import rich

-from inspect_ai._util.display import display_type
+from inspect_ai.util._display import display_type
 from inspect_ai.util._trace import trace_enabled

 from ..rich.display import RichDisplay
inspect_ai/_display/core/display.py CHANGED
@@ -99,7 +99,7 @@ class TaskScreen(contextlib.AbstractContextManager["TaskScreen"]):
     ) -> Iterator[Console]:
         yield rich.get_console()

-    async def input_panel(self, title: str, panel: type[TP]) -> TP:
+    async def input_panel(self, panel_type: type[TP]) -> TP:
         raise NotImplementedError("input_panel not implemented by current display")

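`input_panel` is now keyed by the panel class rather than a title string (the title evidently moves onto the panel itself; note the `panel.title` usage in the textual app below). A hedged migration sketch, where `MyPanel` is a hypothetical `InputPanel` subclass and `screen` a `TaskScreen`:

```python
# 0.3.55: panel identified by a free-form title
#   panel = await screen.input_panel("My Panel", MyPanel)

# 0.3.56: panel identified by its type (one instance per class)
panel = await screen.input_panel(MyPanel)
```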
inspect_ai/_display/core/footer.py CHANGED
@@ -2,8 +2,8 @@ from rich.console import RenderableType
 from rich.text import Text

 from inspect_ai._util.logger import http_rate_limit_count
-from inspect_ai._util.throttle import throttle
 from inspect_ai.util._concurrency import concurrency_status
+from inspect_ai.util._throttle import throttle

 from .config import task_dict

inspect_ai/_display/core/progress.py CHANGED
@@ -124,12 +124,6 @@ def progress_status_icon(result: TaskResult | None) -> str:
     return f"[{theme.meta}]⠿[{theme.meta}]"


-def progress_time(time: float) -> str:
-    minutes, seconds = divmod(time, 60)
-    hours, minutes = divmod(minutes, 60)
-    return f"{hours:2.0f}:{minutes:02.0f}:{seconds:02.0f}"
-
-
 def progress_count(complete: int, total: int, width: int | None = None) -> str:
     # Pad the display to keep it stable as the
     # complete metrics
inspect_ai/_display/core/rich.py CHANGED
@@ -9,9 +9,9 @@ from rich.segment import Segment
 from rich.syntax import Syntax
 from typing_extensions import override

-from inspect_ai._util.display import display_type
 from inspect_ai._util.platform import is_running_in_jupyterlab, is_running_in_vscode
 from inspect_ai._util.transcript import transcript_code_theme
+from inspect_ai.util._display import display_type


 def is_vscode_notebook(console: Console) -> bool:
inspect_ai/_display/rich/display.py CHANGED
@@ -12,9 +12,9 @@ from rich.table import Table
 from typing_extensions import override

 from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH
-from inspect_ai._util.display import display_type
-from inspect_ai._util.throttle import throttle
 from inspect_ai.log._transcript import InputEvent, transcript
+from inspect_ai.util._display import display_type
+from inspect_ai.util._throttle import throttle
 from inspect_ai.util._trace import trace_enabled

 from ..core.config import task_config
inspect_ai/_display/textual/app.py CHANGED
@@ -10,6 +10,7 @@ from textual.app import App, ComposeResult
 from textual.binding import Binding, BindingType
 from textual.css.query import NoMatches
 from textual.events import Print
+from textual.widget import Widget
 from textual.widgets import TabbedContent, TabPane
 from textual.widgets.tabbed_content import ContentTabs
 from textual.worker import Worker, WorkerState
@@ -345,13 +346,15 @@ class TaskScreenApp(App[TR]):
         self.update_title()

     # dynamic input panels
-    async def add_input_panel(self, title: str, panel: InputPanel) -> None:
+    async def add_input_panel(self, panel: InputPanel) -> None:
         tabs = self.query_one(TabbedContent)
-        await tabs.add_pane(TabPane(title, panel, id=as_input_panel_id(title)))
+        await tabs.add_pane(
+            TabPane(panel.title, panel, id=as_input_panel_id(type(panel)))
+        )

-    def get_input_panel(self, title: str) -> InputPanel | None:
+    def get_input_panel(self, panel_type: type) -> InputPanel | None:
         try:
-            tab_pane = self.query_one(f"#{as_input_panel_id(title)}")
+            tab_pane = self.query_one(f"#{as_input_panel_id(panel_type)}")
             if len(tab_pane.children) > 0:
                 return cast(InputPanel, tab_pane.children[0])
             else:
@@ -359,10 +362,6 @@ class TaskScreenApp(App[TR]):
         except NoMatches:
             return None

-    async def remove_input_panel(self, title: str) -> None:
-        tabs = self.query_one(TabbedContent)
-        await tabs.remove_pane(as_html_id(as_input_panel_id(title), title))
-
     class InputPanelHost(InputPanel.Host):
         def __init__(self, app: "TaskScreenApp[TR]", tab_id: str) -> None:
             self.app = app
@@ -383,7 +382,7 @@ class TaskScreenApp(App[TR]):
             # the tabs control so the user can switch back w/ the keyboard
             tab_pane = self.app.query_one(f"#{self.tab_id}")
             panel = cast(InputPanel, tab_pane.children[0])
-            for child in panel.children:
+            for child in panel.walk_children(Widget):
                 if child.focusable:
                     child.focus()
             self.app.query_one(ContentTabs).focus()
@@ -455,19 +454,18 @@ class TextualTaskScreen(TaskScreen, Generic[TR]):
            console.width = old_width

     @override
-    async def input_panel(self, title: str, panel: type[TP]) -> TP:
+    async def input_panel(self, panel_type: type[TP]) -> TP:
         async with self.lock:
-            panel_widget = self.app.get_input_panel(title)
+            panel_widget = self.app.get_input_panel(panel_type)
             if panel_widget is None:
-                panel_widget = panel(
-                    title,
+                panel_widget = panel_type(
                     TaskScreenApp[TR].InputPanelHost(
-                        self.app, as_input_panel_id(title)
+                        self.app, as_input_panel_id(panel_type)
                     ),
                 )
-                await self.app.add_input_panel(title, panel_widget)
+                await self.app.add_input_panel(panel_widget)
             return cast(TP, panel_widget)


-def as_input_panel_id(title: str) -> str:
-    return as_html_id("id-input-panel", title)
+def as_input_panel_id(panel_type: type) -> str:
+    return as_html_id("id-input-panel", panel_type.__name__)
inspect_ai/_display/textual/widgets/clock.py CHANGED
@@ -4,7 +4,7 @@ from textual.reactive import reactive
 from textual.timer import Timer
 from textual.widgets import Static

-from inspect_ai._display.core.progress import progress_time
+from inspect_ai._util.format import format_progress_time


 class Clock(Static):
@@ -43,7 +43,7 @@ class Clock(Static):
         if start_time is not None:
             if self.timer is None:
                 self.timer = self.set_interval(self.interval, self.update_time)
-            self.update(progress_time(start_time))
+            self.update(format_progress_time(start_time))
         else:
             self.stop()

@@ -52,4 +52,4 @@ class Clock(Static):
         self.time = datetime.now().timestamp() - self.start_time

     def watch_time(self, time: float) -> None:
-        self.update(progress_time(time))
+        self.update(format_progress_time(time))
inspect_ai/_display/textual/widgets/samples.py CHANGED
@@ -22,10 +22,10 @@ from textual.widgets import (
 )
 from textual.widgets.option_list import Option, Separator

+from inspect_ai._util.format import format_progress_time
 from inspect_ai._util.registry import registry_unqualified_name
 from inspect_ai.log._samples import ActiveSample

-from ...core.progress import progress_time
 from .clock import Clock
 from .transcript import TranscriptView

@@ -147,7 +147,9 @@ class SamplesList(OptionList):
         table.add_column(width=1)
         task_name = Text.from_markup(f"{registry_unqualified_name(sample.task)}")
         task_name.truncate(18, overflow="ellipsis", pad=True)
-        task_time = Text.from_markup(f"{progress_time(sample.execution_time)}")
+        task_time = Text.from_markup(
+            f"{format_progress_time(sample.execution_time)}"
+        )
         table.add_row(task_name, task_time, " ")
         sample_id = Text.from_markup(f"id: {sample.sample.id}")
         sample_id.truncate(18, overflow="ellipsis", pad=True)
@@ -308,12 +310,7 @@ class SandboxesView(Vertical):
         yield Vertical(id="sandboxes-list")

     async def sync_sample(self, sample: ActiveSample) -> None:
-        sandboxes = sample.sandboxes
-        show_sandboxes = (
-            len([sandbox for sandbox in sandboxes.values() if sandbox.container]) > 0
-        )
-
-        if show_sandboxes:
+        if len(sample.sandboxes) > 0:
             self.display = True
             sandboxes_caption = cast(Static, self.query_one("#sandboxes-caption"))
             sandboxes_caption.update("[bold]sandbox containers:[/bold]")
@@ -321,11 +318,7 @@ class SandboxesView(Vertical):
             sandboxes_list = self.query_one("#sandboxes-list")
             await sandboxes_list.remove_children()
             await sandboxes_list.mount_all(
-                [
-                    Static(sandbox.container)
-                    for sandbox in sandboxes.values()
-                    if sandbox.container
-                ]
+                [Static(sandbox.command) for sandbox in sample.sandboxes.values()]
             )
             sandboxes_list.mount(
                 Static(
inspect_ai/_eval/context.py CHANGED
@@ -1,7 +1,9 @@
 from inspect_ai._util.dotenv import init_dotenv
 from inspect_ai._util.hooks import init_hooks
 from inspect_ai._util.logger import init_http_rate_limit_count, init_logger
+from inspect_ai.approval._apply import have_tool_approval, init_tool_approval
 from inspect_ai.approval._human.manager import init_human_approval_manager
+from inspect_ai.approval._policy import ApprovalPolicy
 from inspect_ai.log._samples import init_active_samples
 from inspect_ai.model import GenerateConfig, Model
 from inspect_ai.model._model import init_active_model, init_model_usage
@@ -24,6 +26,12 @@ def init_eval_context(
     init_human_approval_manager()


-def init_task_context(model: Model, config: GenerateConfig = GenerateConfig()) -> None:
+def init_task_context(
+    model: Model,
+    approval: list[ApprovalPolicy] | None = None,
+    config: GenerateConfig = GenerateConfig(),
+) -> None:
     init_active_model(model, config)
     init_model_usage()
+    if not have_tool_approval():
+        init_tool_approval(approval)
inspect_ai/_eval/score.py CHANGED
@@ -11,7 +11,7 @@ from inspect_ai.log import (
     EvalMetric,
 )
 from inspect_ai.model import ModelName
-from inspect_ai.scorer import Metric, Score, Scorer, Target
+from inspect_ai.scorer import Metric, Scorer, Target
 from inspect_ai.scorer._metric import SampleScore
 from inspect_ai.scorer._reducer import (
     ScoreReducer,
@@ -108,7 +108,7 @@ async def score_async(

     # write them back (gather ensures that they come back in the same order)
     for index, score in enumerate(scores):
-        log.samples[index].scores = cast(dict[str, Score], score)
+        log.samples[index].scores = {k: v.score for k, v in score.items()}

     # collect metrics from EvalLog (they may overlap w/ the scorer metrics,
     # that will be taken care of in eval_results)
@@ -151,11 +151,8 @@ async def task_score(task: Task, log: EvalLog) -> EvalLog:
     sample_scores = [
         {
             score_key: SampleScore(
+                score=score,
                 sample_id=sample.id,
-                value=score.value,
-                answer=score.answer,
-                explanation=score.explanation,
-                metadata=score.metadata,
             )
             for score_key, score in sample.scores.items()
         }
@@ -185,11 +182,8 @@ async def run_score_task(
     scorer_name = unique_scorer_name(scorer, list(results.keys()))

     results[scorer_name] = SampleScore(
+        score=result,
         sample_id=state.sample_id,
-        value=result.value,
-        answer=result.answer,
-        explanation=result.explanation,
-        metadata=result.metadata,
     )

     progress()
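The repeated field-by-field copies removed above make the shape change clear: `SampleScore` now wraps a complete `Score` in a single `score` field instead of duplicating `value`, `answer`, `explanation`, and `metadata`. A minimal sketch, assuming `Score`'s constructor is unchanged:

```python
from inspect_ai.scorer._metric import SampleScore, Score

sample_score = SampleScore(
    score=Score(value=1.0, answer="42", explanation="exact match"),
    sample_id="sample-1",
)
# fields previously flattened onto SampleScore now live on .score
assert sample_score.score.value == 1.0
```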
inspect_ai/_eval/task/results.py CHANGED
@@ -13,6 +13,7 @@ from inspect_ai._util.registry import (
 from inspect_ai.log import (
     EvalMetric,
     EvalResults,
+    EvalSampleScore,
     EvalScore,
 )
 from inspect_ai.log._log import EvalSampleReductions
@@ -345,7 +346,7 @@ def resolve_glob_metric_keys(

 def reduce_scores(
     scores: list[SampleScore], reducer: ScoreReducer
-) -> list[SampleScore]:
+) -> list[EvalSampleScore]:
     # Group the scores by sample_id
     grouped_scores: dict[str, list[SampleScore]] = defaultdict(list)
     for sample_score in scores:
@@ -353,11 +354,11 @@ def reduce_scores(
         grouped_scores[str(sample_score.sample_id)].append(sample_score)

     # reduce the scores
-    reduced_scores: list[SampleScore] = []
+    reduced_scores: list[EvalSampleScore] = []
     for scores in grouped_scores.values():
-        reduced = reducer(cast(list[Score], scores))
+        reduced = reducer([score.score for score in scores])
         reduced_scores.append(
-            SampleScore(
+            EvalSampleScore(
                 sample_id=scores[0].sample_id,
                 value=reduced.value,
                 answer=reduced.answer,
inspect_ai/_eval/task/run.py CHANGED
@@ -6,7 +6,7 @@ from copy import deepcopy
 from dataclasses import dataclass, field
 from logging import getLogger
 from pathlib import PurePath
-from typing import Callable, Literal, cast
+from typing import Callable, Literal

 from typing_extensions import Unpack

@@ -62,7 +62,7 @@ from inspect_ai.model import (
 )
 from inspect_ai.model._model import init_sample_model_usage, sample_model_usage
 from inspect_ai.scorer import Scorer, Target
-from inspect_ai.scorer._metric import Metric, SampleScore, Score
+from inspect_ai.scorer._metric import Metric, SampleScore
 from inspect_ai.scorer._reducer.types import ScoreReducer
 from inspect_ai.scorer._score import init_scoring_context
 from inspect_ai.scorer._scorer import unique_scorer_name
@@ -136,7 +136,7 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
     generate_config = task.config.merge(GenerateConfigArgs(**kwargs))

     # init task context
-    init_task_context(model, generate_config)
+    init_task_context(model, options.task.approval, generate_config)

     # establish run_dir for duration of execution
     with set_task_run_dir(task_run_dir(task)):
@@ -503,11 +503,8 @@ async def task_run_sample(
     sample_scores = (
         {
             key: SampleScore(
+                score=score,
                 sample_id=previous_sample.id,
-                value=score.value,
-                answer=score.answer,
-                explanation=score.explanation,
-                metadata=score.metadata,
             )
             for key, score in previous_sample.scores.items()
         }
@@ -652,11 +649,8 @@ async def task_run_sample(
     )
     if score_result is not None:
         sample_score = SampleScore(
+            score=score_result,
             sample_id=sample.id,
-            value=score_result.value,
-            answer=score_result.answer,
-            explanation=score_result.explanation,
-            metadata=score_result.metadata,
         )
         transcript()._event(
             ScoreEvent(score=score_result, target=sample.target)
@@ -759,7 +753,7 @@ async def log_sample(
         setup=sample.setup,
         messages=state.messages,
         output=state.output,
-        scores=cast(dict[str, Score], scores),
+        scores={k: v.score for k, v in scores.items()},
         store=dict(state.store.items()),
         events=list(transcript().events),
         model_usage=sample_model_usage(),
inspect_ai/_eval/task/task.py CHANGED
@@ -7,6 +7,7 @@ from typing_extensions import TypedDict, Unpack

 from inspect_ai._util.logger import warn_once
 from inspect_ai._util.registry import is_registry_object, registry_info
+from inspect_ai.approval._policy import ApprovalPolicy, approval_policies_from_config
 from inspect_ai.dataset import Dataset, MemoryDataset, Sample
 from inspect_ai.log import EvalLog
 from inspect_ai.model import GenerateConfig
@@ -49,6 +50,9 @@ class Task:
         config (GenerateConfig): Model generation config.
         sandbox (SandboxEnvironmentType | None): Sandbox environment type
           (or optionally a str or tuple with a shorthand spec)
+        approval: (str | list[ApprovalPolicy] | None): Tool use approval policies.
+          Either a path to an approval policy config file or a list of approval policies.
+          Defaults to no approval policy.
         epochs (int | Epochs | None): Epochs to repeat samples for and optional score
           reducer function(s) used to combine sample scores (defaults to "mean")
         fail_on_error (bool | float | None): `True` to fail on first sample error
@@ -76,6 +80,7 @@ class Task:
         metrics: list[Metric] | dict[str, list[Metric]] | None = None,
         config: GenerateConfig = GenerateConfig(),
         sandbox: SandboxEnvironmentType | None = None,
+        approval: str | list[ApprovalPolicy] | None = None,
         epochs: int | Epochs | None = None,
         fail_on_error: bool | float | None = None,
         message_limit: int | None = None,
@@ -134,6 +139,11 @@ class Task:
         self.metrics = metrics
         self.config = config
         self.sandbox = resolve_sandbox_environment(sandbox)
+        self.approval = (
+            approval_policies_from_config(approval)
+            if isinstance(approval, str)
+            else approval
+        )
         self.epochs = epochs.epochs if epochs else None
         self.epochs_reducer = epochs.reducer if epochs else None
         self.fail_on_error = fail_on_error
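A sketch of the new `approval` parameter in use; a string is treated as a path to a policy config file (parsed by `approval_policies_from_config`), while a `list[ApprovalPolicy]` is stored as-is. The file name here is illustrative:

```python
from inspect_ai import Task
from inspect_ai.dataset import Sample

task = Task(
    dataset=[Sample(input="2+2", target="4")],
    approval="approval.yaml",  # hypothetical path; or a list[ApprovalPolicy]
)
```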
inspect_ai/_util/ansi.py ADDED
@@ -0,0 +1,31 @@
+import os
+from typing import Any
+
+from rich.console import Console, RenderableType
+
+
+def render_text(
+    text: RenderableType | list[RenderableType], styles: bool = True, **options: Any
+) -> str:
+    """Render text from Rich renderables.
+
+    Args:
+        text (RenderableType | list[RenderableType]): Renderables.
+        styles (bool): If True, ansi escape codes will be included. False for plain text.
+            Defaults to True.
+        **options (Any): Additonal keyword arguments to pass to `Console` constructor.
+
+    Returns:
+        str: Rendered text (with ansi codes if `styles=True`)
+    """
+    # resolve to text
+    text = text if isinstance(text, list) else [text]
+
+    # print to console attached to /dev/null
+    with open(os.devnull, "w") as f:
+        console = Console(file=f, record=True, force_terminal=True, **options)
+        for t in text:
+            console.print(t)
+
+    # export (optionally w/ ansi styles)
+    return console.export_text(styles=styles).strip()
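The trick in `render_text` is printing to a recording `Console` attached to `/dev/null` with `force_terminal=True`, then exporting the capture. A small usage sketch:

```python
from rich.text import Text

from inspect_ai._util.ansi import render_text

styled = render_text(Text("done", style="bold green"))  # includes ANSI codes
plain = render_text(Text("done", style="bold green"), styles=False)
assert plain == "done"
```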
inspect_ai/_util/format.py CHANGED
@@ -26,3 +26,10 @@ def format_value(value: object, width: int) -> str:
     elif isinstance(value, list | tuple | dict):
         return pprint.pformat(value, width=width)
     return str(value)
+
+
+def format_progress_time(time: float, pad_hours: bool = True) -> str:
+    minutes, seconds = divmod(time, 60)
+    hours, minutes = divmod(minutes, 60)
+    hours_fmt = f"{hours:2.0f}" if pad_hours else f"{hours:.0f}"
+    return f"{hours_fmt}:{minutes:02.0f}:{seconds:02.0f}"
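For reference, the new helper reproduces the removed `progress_time` output when `pad_hours=True` (hours padded to width 2):

```python
from inspect_ai._util.format import format_progress_time

# 3725 seconds = 1h 2m 5s
assert format_progress_time(3725) == " 1:02:05"  # default pad_hours=True
assert format_progress_time(3725, pad_hours=False) == "1:02:05"
```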
inspect_ai/_util/logger.py CHANGED
@@ -1,3 +1,4 @@
+import atexit
 import os
 from logging import (
     DEBUG,
@@ -30,7 +31,12 @@ from .constants import (
     TRACE_LOG_LEVEL,
 )
 from .error import PrerequisiteError
-from .trace import TraceFileHandler, TraceFormatter, inspect_trace_dir
+from .trace import (
+    TraceFormatter,
+    compress_trace_log,
+    inspect_trace_file,
+    rotate_trace_files,
+)

 TRACE_FILE_NAME = "trace.log"

@@ -56,19 +62,13 @@ class LogHandler(RichHandler):
         else:
             self.file_logger_level = 0

-        # add a trace handler
-        default_trace_file = inspect_trace_dir() / TRACE_FILE_NAME
-        have_existing_trace_file = default_trace_file.exists()
+        # add a trace file handler
+        rotate_trace_files()  # remove oldest if > 10 trace files
         env_trace_file = os.environ.get("INSPECT_TRACE_FILE", None)
-        trace_file = Path(env_trace_file) if env_trace_file else default_trace_file
-        trace_total_files = 10
-        self.trace_logger = TraceFileHandler(
-            trace_file.as_posix(),
-            backupCount=trace_total_files - 1,  # exclude the current file (10 total)
-        )
+        trace_file = Path(env_trace_file) if env_trace_file else inspect_trace_file()
+        self.trace_logger = FileHandler(trace_file)
         self.trace_logger.setFormatter(TraceFormatter())
-        if have_existing_trace_file:
-            self.trace_logger.doRollover()
+        atexit.register(compress_trace_log(self.trace_logger))

         # set trace level
         trace_level = os.environ.get("INSPECT_TRACE_LEVEL", TRACE_LOG_LEVEL)
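The rotation strategy changes here: rotate old files at startup, write the current log through a plain `FileHandler`, and compress at interpreter exit. Since `atexit.register` calls its argument with no arguments, `compress_trace_log(handler)` must be a factory returning a zero-argument callable. A hypothetical sketch of that factory shape (the gzip behavior is an assumption, not the library's actual implementation):

```python
import atexit
import gzip
from logging import FileHandler
from typing import Callable


def compress_log_on_exit(handler: FileHandler) -> Callable[[], None]:
    """Hypothetical stand-in for compress_trace_log."""

    def compress() -> None:
        handler.close()
        with open(handler.baseFilename, "rb") as src:
            with gzip.open(handler.baseFilename + ".gz", "wb") as dst:
                dst.write(src.read())

    return compress


handler = FileHandler("trace.log")
atexit.register(compress_log_on_exit(handler))  # the closure runs at exit
```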
inspect_ai/_util/throttle.py CHANGED
@@ -3,7 +3,16 @@ from functools import wraps
 from typing import Any, Callable


-def throttle(seconds: int) -> Callable[..., Any]:
+def throttle(seconds: float) -> Callable[..., Any]:
+    """Throttle a function to ensure it is called no more than every n seconds.
+
+    Args:
+        seconds (float): Throttle time.
+
+    Returns:
+        Callable: Throttled function.
+    """
+
     def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
         last_called: float = 0
         last_result: Any = None
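Inside the throttle window the wrapper presumably returns the cached `last_result` rather than calling the function again, which is what the two closure variables suggest. A usage sketch under that assumption:

```python
import time

from inspect_ai._util.throttle import throttle


@throttle(1.0)
def footer_status() -> str:
    # recomputed at most once per second; repeated calls inside
    # the window return the cached result
    return time.strftime("%H:%M:%S")


first = footer_status()
second = footer_status()  # expected: cached value within the window
assert first == second
```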