inspect-ai 0.3.49__py3-none-any.whl → 0.3.51__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
Files changed (93)
  1. inspect_ai/_cli/info.py +2 -2
  2. inspect_ai/_cli/log.py +2 -2
  3. inspect_ai/_cli/score.py +2 -2
  4. inspect_ai/_display/core/display.py +19 -0
  5. inspect_ai/_display/core/panel.py +37 -7
  6. inspect_ai/_display/core/progress.py +29 -2
  7. inspect_ai/_display/core/results.py +79 -40
  8. inspect_ai/_display/core/textual.py +21 -0
  9. inspect_ai/_display/rich/display.py +28 -8
  10. inspect_ai/_display/textual/app.py +107 -1
  11. inspect_ai/_display/textual/display.py +1 -1
  12. inspect_ai/_display/textual/widgets/samples.py +132 -91
  13. inspect_ai/_display/textual/widgets/task_detail.py +236 -0
  14. inspect_ai/_display/textual/widgets/tasks.py +74 -6
  15. inspect_ai/_display/textual/widgets/toggle.py +32 -0
  16. inspect_ai/_eval/context.py +2 -0
  17. inspect_ai/_eval/eval.py +4 -3
  18. inspect_ai/_eval/loader.py +1 -1
  19. inspect_ai/_eval/run.py +35 -2
  20. inspect_ai/_eval/task/log.py +13 -11
  21. inspect_ai/_eval/task/results.py +12 -3
  22. inspect_ai/_eval/task/run.py +139 -36
  23. inspect_ai/_eval/task/sandbox.py +2 -1
  24. inspect_ai/_util/_async.py +30 -1
  25. inspect_ai/_util/file.py +31 -4
  26. inspect_ai/_util/html.py +3 -0
  27. inspect_ai/_util/logger.py +6 -5
  28. inspect_ai/_util/platform.py +5 -6
  29. inspect_ai/_util/registry.py +1 -1
  30. inspect_ai/_view/server.py +9 -9
  31. inspect_ai/_view/www/App.css +2 -2
  32. inspect_ai/_view/www/dist/assets/index.css +2 -2
  33. inspect_ai/_view/www/dist/assets/index.js +352 -294
  34. inspect_ai/_view/www/log-schema.json +13 -0
  35. inspect_ai/_view/www/package.json +1 -0
  36. inspect_ai/_view/www/src/components/MessageBand.mjs +1 -1
  37. inspect_ai/_view/www/src/components/Tools.mjs +16 -13
  38. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -3
  39. inspect_ai/_view/www/src/samples/SampleScoreView.mjs +52 -77
  40. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -13
  41. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +15 -2
  42. inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.mjs +4 -2
  43. inspect_ai/_view/www/src/types/log.d.ts +2 -0
  44. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +2 -0
  45. inspect_ai/_view/www/yarn.lock +9 -4
  46. inspect_ai/approval/__init__.py +1 -1
  47. inspect_ai/approval/_human/approver.py +35 -0
  48. inspect_ai/approval/_human/console.py +62 -0
  49. inspect_ai/approval/_human/manager.py +108 -0
  50. inspect_ai/approval/_human/panel.py +233 -0
  51. inspect_ai/approval/_human/util.py +51 -0
  52. inspect_ai/dataset/_sources/hf.py +2 -2
  53. inspect_ai/dataset/_sources/util.py +1 -1
  54. inspect_ai/log/_file.py +106 -36
  55. inspect_ai/log/_recorders/eval.py +226 -158
  56. inspect_ai/log/_recorders/file.py +9 -6
  57. inspect_ai/log/_recorders/json.py +35 -12
  58. inspect_ai/log/_recorders/recorder.py +15 -15
  59. inspect_ai/log/_samples.py +52 -0
  60. inspect_ai/model/_model.py +14 -0
  61. inspect_ai/model/_model_output.py +4 -0
  62. inspect_ai/model/_providers/azureai.py +1 -1
  63. inspect_ai/model/_providers/hf.py +106 -4
  64. inspect_ai/model/_providers/util/__init__.py +2 -0
  65. inspect_ai/model/_providers/util/hf_handler.py +200 -0
  66. inspect_ai/scorer/_common.py +1 -1
  67. inspect_ai/solver/_plan.py +0 -8
  68. inspect_ai/solver/_task_state.py +18 -1
  69. inspect_ai/solver/_use_tools.py +9 -1
  70. inspect_ai/tool/_tool_def.py +2 -2
  71. inspect_ai/tool/_tool_info.py +14 -2
  72. inspect_ai/tool/_tool_params.py +2 -1
  73. inspect_ai/tool/_tools/_execute.py +1 -1
  74. inspect_ai/tool/_tools/_web_browser/_web_browser.py +6 -0
  75. inspect_ai/util/__init__.py +5 -6
  76. inspect_ai/util/_panel.py +91 -0
  77. inspect_ai/util/_sandbox/__init__.py +2 -6
  78. inspect_ai/util/_sandbox/context.py +4 -3
  79. inspect_ai/util/_sandbox/docker/compose.py +12 -2
  80. inspect_ai/util/_sandbox/docker/docker.py +19 -9
  81. inspect_ai/util/_sandbox/docker/util.py +10 -2
  82. inspect_ai/util/_sandbox/environment.py +47 -41
  83. inspect_ai/util/_sandbox/local.py +15 -10
  84. inspect_ai/util/_subprocess.py +43 -3
  85. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/METADATA +2 -2
  86. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/RECORD +90 -82
  87. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  88. inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
  89. inspect_ai/approval/_human.py +0 -123
  90. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/LICENSE +0 -0
  91. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/WHEEL +0 -0
  92. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/entry_points.txt +0 -0
  93. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/top_level.txt +0 -0
inspect_ai/_display/textual/widgets/task_detail.py ADDED
@@ -0,0 +1,236 @@
+ import re
+ from dataclasses import dataclass
+
+ import numpy as np
+ from textual.app import ComposeResult
+ from textual.containers import Center, Grid, Horizontal
+ from textual.reactive import Reactive, reactive
+ from textual.widget import Widget
+ from textual.widgets import Static
+
+ from inspect_ai._display.core.display import TaskDisplayMetric
+
+
+ @dataclass
+ class TaskMetric:
+     name: str
+     value: float
+
+
+ class TaskDetail(Widget):
+     hidden = reactive(False)
+     DEFAULT_CSS = """
+     TaskDetail {
+         background: $boost;
+         width: 100%;
+         height: auto;
+         padding: 1 0 1 0;
+     }
+     TaskDetail Grid {
+         width: 100%;
+         height: auto;
+         grid-gutter: 1 3;
+     }
+     """
+
+     def __init__(
+         self,
+         *,
+         hidden: bool = True,
+         id: str | None = None,
+         classes: str | None = None,
+     ) -> None:
+         super().__init__(id=id, classes=classes)
+         self.hidden = hidden
+         self.existing_metrics: dict[str, TaskMetrics] = {}
+         self.grid = Grid()
+         self.by_reducer: dict[str | None, dict[str, list[TaskMetric]]] = {}
+         self.metrics: list[TaskDisplayMetric] = []
+
+     def watch_hidden(self, hidden: bool) -> None:
+         """React to changes in the `visible` property."""
+         if hidden:
+             self.add_class("hidden")
+         else:
+             self.remove_class("hidden")
+
+     def compose(self) -> ComposeResult:
+         yield self.grid
+
+     def on_mount(self) -> None:
+         self.refresh_grid()
+
+     def update_metrics(self, metrics: list[TaskDisplayMetric]) -> None:
+         # Group by reducer then scorer within reducers
+         self.metrics = metrics
+         for metric in metrics:
+             reducer_group = (
+                 self.by_reducer[metric.reducer]
+                 if metric.reducer in self.by_reducer
+                 else {}
+             )
+
+             by_scorer_metrics = (
+                 reducer_group[metric.scorer] if metric.scorer in reducer_group else []
+             )
+             by_scorer_metrics.append(TaskMetric(name=metric.name, value=metric.value))
+             reducer_group[metric.scorer] = by_scorer_metrics
+             self.by_reducer[metric.reducer] = reducer_group
+
+         self.refresh_grid()
+
+     def refresh_grid(self) -> None:
+         # Don't refresh the grid if not attached
+         # since we may explicitly mount new widgets
+         if not self.grid.is_attached:
+             return
+
+         # don't refresh the grid if there are no scores
+         if len(self.by_reducer) == 0:
+             return
+
+         # Compute the row and column count
+         row_count = len(self.by_reducer)
+         col_count = len(next(iter(self.by_reducer.values())))
+
+         # If this can fit in a single row, make it fit
+         # otherwise place each reducer on their own row
+         self.grid.styles.grid_columns = "auto"
+         if row_count * col_count < 4:
+             self.grid.styles.grid_size_columns = row_count * col_count
+             self.grid.styles.grid_size_rows = 1
+         else:
+             self.grid.styles.grid_size_columns = col_count
+             self.grid.styles.grid_size_rows = row_count
+
+         # In order to reduce flashing the below tracks use of widgets
+         # and updates them when possible (removing and adding them as needed)
+         # Makes keys for tracking Task Metric widgets
+         def metric_key(reducer: str | None, scorer: str) -> str:
+             reducer = reducer or "none"
+             return valid_id(f"task-{reducer}-{scorer}-tbl")
+
+         # Remove keys that are no longer present
+         existing_keys = set(self.existing_metrics.keys())
+         new_keys = set(metric_key(m.reducer, m.scorer) for m in self.metrics)
+         to_remove = existing_keys - new_keys
+         for remove in to_remove:
+             task_metric = self.existing_metrics[remove]
+             task_metric.remove()
+
+         # add or update widgets with metrics
+         for reducer, scorers in self.by_reducer.items():
+             for scorer, scores in scorers.items():
+                 key = metric_key(reducer=reducer, scorer=scorer)
+                 if key in self.existing_metrics:
+                     task_metrics = self.existing_metrics[key]
+                     task_metrics.update(scores)
+                 else:
+                     task_metrics = TaskMetrics(
+                         id=key, scorer=scorer, reducer=reducer, metrics=scores
+                     )
+                     self.grid.mount(task_metrics)
+                     self.existing_metrics[key] = task_metrics
+
+
+ class TaskMetrics(Widget):
+     DEFAULT_CSS = """
+     TaskMetrics {
+         width: auto;
+         height: auto;
+     }
+     TaskMetrics Grid {
+         width: auto;
+         grid-size: 2;
+         grid-columns: auto;
+         grid-gutter: 0 3;
+         padding: 0 2 0 2;
+     }
+     TaskMetric Center {
+         width: auto;
+     }
+     TaskMetrics Center Static {
+         width: auto;
+     }
+     TaskMetrics Center Horizontal {
+         width: auto;
+         height: auto;
+     }
+     TaskMetrics Center Horizontal Static {
+         width: auto;
+         height: auto;
+     }
+     TaskMetrics .scorer {
+         padding: 0 1 0 0;
+         text-style: bold;
+     }
+     TaskMetrics .reducer {
+         color: $foreground-darken-3;
+     }
+     """
+
+     metrics: Reactive[list[TaskMetric]] = reactive([])
+
+     def __init__(
+         self,
+         *,
+         scorer: str | None,
+         reducer: str | None,
+         metrics: list[TaskMetric],
+         id: str | None = None,
+         classes: str | None = None,
+     ) -> None:
+         super().__init__(id=id, classes=classes)
+         self.scorer = scorer
+         self.reducer = reducer
+         self.metrics = metrics
+         self.grid: Grid = Grid()
+         self.value_widgets: dict[str, Static] = {}
+
+     def compose(self) -> ComposeResult:
+         # Just yield a single DataTable widget
+         yield Center(self._title())
+         with Grid():
+             for metric in self.metrics:
+                 # Add the value static but keep it around
+                 # for future updates
+                 self.value_widgets[metric.name] = Static(
+                     self._metric_value(metric.value)
+                 )
+
+                 yield Static(metric.name)
+                 yield self.value_widgets[metric.name]
+
+     def update(self, metrics: list[TaskMetric]) -> None:
+         for metric in metrics:
+             widget = self.value_widgets[metric.name]
+             widget.update(content=f"{metric.value:,.3f}")
+
+     def _title(self) -> Widget:
+         if self.scorer is None:
+             return Static("")
+         elif self.reducer is None:
+             return Static(self.scorer)
+         else:
+             return Horizontal(
+                 Static(self.scorer, classes="scorer"),
+                 Static(f"({self.reducer})", classes="reducer"),
+             )
+
+     def _metric_value(self, val: float) -> str:
+         if np.isnan(val):
+             return " n/a "
+         else:
+             return f"{val:.3f}"
+
+
+ def valid_id(identifier: str) -> str:
+     # Remove invalid characters
+     valid_part = re.sub(r"[^a-zA-Z0-9_-]", "_", identifier)
+
+     # Ensure it doesn't start with a number
+     if valid_part and valid_part[0].isdigit():
+         valid_part = "_" + valid_part
+
+     # If the string is empty return a default valid identifier
+     return valid_part or "default_identifier"
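Note: the grouping done by `update_metrics` above is easier to see outside of Textual. The sketch below uses a stand-in dataclass in place of `TaskDisplayMetric` (same `name`/`value`/`scorer`/`reducer` fields as used above) and reproduces only the bucketing, not the widget handling; each resulting (reducer, scorer) pair becomes one `TaskMetrics` grid cell keyed by `valid_id(...)`.

from dataclasses import dataclass


@dataclass
class Metric:
    # stand-in for TaskDisplayMetric; field names taken from the usage above
    name: str
    value: float
    scorer: str
    reducer: str | None


def group_metrics(metrics: list[Metric]) -> dict[str | None, dict[str, list[Metric]]]:
    # bucket by reducer, then by scorer within each reducer, as update_metrics does
    by_reducer: dict[str | None, dict[str, list[Metric]]] = {}
    for m in metrics:
        by_reducer.setdefault(m.reducer, {}).setdefault(m.scorer, []).append(m)
    return by_reducer


grouped = group_metrics(
    [
        Metric("accuracy", 0.82, scorer="match", reducer=None),
        Metric("stderr", 0.03, scorer="match", reducer=None),
        Metric("accuracy", 0.85, scorer="match", reducer="mean"),
    ]
)
# one TaskMetrics cell per (reducer, scorer) pair
print({reducer: list(scorers) for reducer, scorers in grouped.items()})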
inspect_ai/_display/textual/widgets/tasks.py CHANGED
@@ -4,19 +4,25 @@ from typing import Iterator, cast

  from rich.console import RenderableType
  from rich.text import Text
+ from textual import on
  from textual.app import ComposeResult
  from textual.containers import Container, ScrollableContainer
+ from textual.css.query import NoMatches
  from textual.reactive import reactive
  from textual.widget import Widget
  from textual.widgets import ProgressBar, Static
  from typing_extensions import override

+ from inspect_ai._display.core.results import task_metric
  from inspect_ai._display.textual.widgets.clock import Clock
+ from inspect_ai._display.textual.widgets.task_detail import TaskDetail
+ from inspect_ai._display.textual.widgets.toggle import Toggle

  from ...core.display import (
      Progress,
      TaskCancelled,
      TaskDisplay,
+     TaskDisplayMetric,
      TaskError,
      TaskResult,
      TaskSpec,
@@ -25,6 +31,7 @@ from ...core.display import (
  from ...core.progress import (
      MAX_DESCRIPTION_WIDTH,
      MAX_MODEL_NAME_WIDTH,
+     progress_count,
      progress_description,
      progress_model_name,
  )
@@ -106,9 +113,10 @@ class TaskProgressView(Widget):
        height: auto;
        width: 1fr;
        layout: grid;
-       grid-size: 5 1;
-       grid-columns: auto auto auto 1fr auto;
-       grid-gutter: 1;
+       grid-size: 8 2;
+       grid-columns: auto auto auto auto 1fr auto auto auto;
+       grid-rows: auto auto;
+       grid-gutter: 0 1;
    }
    TaskProgressView Bar {
        width: 1fr;
@@ -119,6 +127,15 @@ class TaskProgressView(Widget):
            color: $success;
        }
    }
+   #task-metrics {
+       color:$text-secondary;
+   }
+   #task-detail {
+       column-span: 8;
+   }
+   .hidden {
+       display: none;
+   }
    """

    def __init__(
@@ -126,12 +143,19 @@ class TaskProgressView(Widget):
    ) -> None:
        super().__init__()
        self.t = task
+
        self.description_width = description_width
        self.model_name_width = model_name_width
        self.progress_bar = ProgressBar(total=task.profile.steps, show_eta=False)
+       self.count_display = Static()
+       self.metrics_display = Static(id="task-metrics")
        self.task_progress = TaskProgress(self.progress_bar)

+       self.toggle = Toggle()
+       self.task_detail = TaskDetail(id="task-detail", classes="hidden")
+
    def compose(self) -> ComposeResult:
+       yield self.toggle
        yield TaskStatusIcon()
        yield Static(
            progress_description(self.t.profile, self.description_width, pad=True)
@@ -140,7 +164,15 @@ class TaskProgressView(Widget):
            progress_model_name(self.t.profile.model, self.model_name_width, pad=True)
        )
        yield self.progress_bar
+       yield self.count_display
+       yield self.metrics_display
        yield Clock()
+       yield self.task_detail
+
+   @on(Toggle.Toggled)
+   def handle_title_toggle(self, event: Toggle.Toggled) -> None:
+       self.task_detail.hidden = not self.toggle.toggled
+       event.stop()

    def on_mount(self) -> None:
        self.query_one(Clock).start(datetime.now().timestamp())
@@ -151,10 +183,21 @@ class TaskProgressView(Widget):

    def complete(self, result: TaskResult) -> None:
        self.t.result = result
-       self.query_one(TaskStatusIcon).result = result
-       self.query_one(Clock).stop()
+       try:
+           self.query_one(TaskStatusIcon).result = result
+           self.query_one(Clock).stop()
+       except NoMatches:
+           pass
        self.task_progress.complete()

+   def sample_complete(self, complete: int, total: int) -> None:
+       self.count_display.update(progress_count(complete, total))
+
+   def update_metrics(self, metrics: list[TaskDisplayMetric]) -> None:
+       if len(metrics) > 0:
+           self.metrics_display.update(task_metric(metrics))
+           self.task_detail.update_metrics(metrics)
+

  class TaskStatusIcon(Static):
      result: reactive[TaskResult | None] = reactive(None)
@@ -181,13 +224,38 @@ class TaskStatusIcon(Static):
            return Text("⠿", style=running)


+ MAX_PROGRESS_PERCENT = 0.02
+ MIN_PROGRESS_PERCENT = 0.98
+
+
  class TaskProgress(Progress):
      def __init__(self, progress_bar: ProgressBar) -> None:
          self.progress_bar = progress_bar
+         self.current_progress = 0
+
+         # always show a minimum amount of progress
+         minimum_steps = (
+             MAX_PROGRESS_PERCENT * progress_bar.total
+             if progress_bar.total is not None
+             else 0
+         )
+         self.progress_bar.update(progress=minimum_steps)

      @override
      def update(self, n: int = 1) -> None:
-         self.progress_bar.update(advance=n)
+         self.current_progress = self.current_progress + n
+
+         # enforce a maximum cap on task progress
+         max_progress = (
+             MIN_PROGRESS_PERCENT * self.progress_bar.total
+             if self.progress_bar.total is not None
+             else 0
+         )
+         if (
+             self.current_progress > self.progress_bar.progress
+             and self.current_progress < max_progress
+         ):
+             self.progress_bar.update(progress=self.current_progress)

      @override
      def complete(self) -> None:
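Note: the new `TaskProgress` above keeps the bar visually alive — it starts with a small sliver and is capped just below 100% until `complete()` runs. A standalone sketch of that clamping, using a hypothetical `ClampedProgress` class in place of Textual's `ProgressBar`:

FLOOR_FRACTION = 0.02    # always show at least this much of the bar
CEILING_FRACTION = 0.98  # never show more than this until complete()


class ClampedProgress:
    """Hypothetical stand-in illustrating the clamping logic above (no Textual)."""

    def __init__(self, total: int) -> None:
        self.total = total
        self.current = 0
        self.shown = FLOOR_FRACTION * total  # start with a visible sliver

    def update(self, n: int = 1) -> None:
        self.current += n
        ceiling = CEILING_FRACTION * self.total
        # only ever move forward, and never past the ceiling
        if self.shown < self.current < ceiling:
            self.shown = self.current

    def complete(self) -> None:
        self.shown = self.total


bar = ClampedProgress(total=100)
for _ in range(99):
    bar.update()
print(bar.shown)  # 97 -- held just under the ceiling while work finishes
bar.complete()
print(bar.shown)  # 100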
inspect_ai/_display/textual/widgets/toggle.py ADDED
@@ -0,0 +1,32 @@
+ from textual.events import Click
+ from textual.message import Message
+ from textual.reactive import reactive
+ from textual.widgets import Static
+
+
+ class Toggle(Static, can_focus=True):
+     toggled = reactive(True)
+
+     def __init__(
+         self, on_symbol: str = "▼", off_symbol: str = "▶", toggled: bool = False
+     ) -> None:
+         super().__init__()
+
+         self.on_symbol = on_symbol
+         self.off_symbol = off_symbol
+         self.toggled = toggled
+
+     class Toggled(Message):
+         """Request toggle."""
+
+     async def _on_click(self, event: Click) -> None:
+         """Inform ancestor we want to toggle."""
+         event.stop()
+         self.toggled = not self.toggled
+         self.post_message(self.Toggled())
+
+     def _watch_toggled(self, toggled: bool) -> None:
+         if toggled:
+             self.update(self.on_symbol)
+         else:
+             self.update(self.off_symbol)
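Note: a minimal sketch of wiring this `Toggle` into a Textual app, mirroring how `TaskProgressView` uses it above. The `CollapsibleNote` app and its CSS are made up for illustration; only `Toggle` and its `Toggled` message come from the package.

from textual import on
from textual.app import App, ComposeResult
from textual.widgets import Static

from inspect_ai._display.textual.widgets.toggle import Toggle


class CollapsibleNote(App[None]):
    CSS = ".hidden { display: none; }"

    def compose(self) -> ComposeResult:
        self.toggle = Toggle()  # clickable expand/collapse arrow
        self.detail = Static("detail text", classes="hidden")
        yield self.toggle
        yield self.detail

    @on(Toggle.Toggled)
    def show_or_hide(self, event: Toggle.Toggled) -> None:
        # reveal the detail line when toggled open, hide it when closed
        self.detail.set_class(not self.toggle.toggled, "hidden")
        event.stop()


if __name__ == "__main__":
    CollapsibleNote().run()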
inspect_ai/_eval/context.py CHANGED
@@ -1,6 +1,7 @@
  from inspect_ai._util.dotenv import init_dotenv
  from inspect_ai._util.hooks import init_hooks
  from inspect_ai._util.logger import init_http_rate_limit_count, init_logger
+ from inspect_ai.approval._human.manager import init_human_approval_manager
  from inspect_ai.log._samples import init_active_samples
  from inspect_ai.model import GenerateConfig, Model
  from inspect_ai.model._model import init_active_model, init_model_usage
@@ -20,6 +21,7 @@ def init_eval_context(
      init_http_rate_limit_count()
      init_hooks()
      init_active_samples()
+     init_human_approval_manager()


  def init_task_context(model: Model, config: GenerateConfig = GenerateConfig()) -> None:
inspect_ai/_eval/eval.py CHANGED
@@ -21,7 +21,8 @@ from inspect_ai.approval._policy import (
      approval_policies_from_config,
      config_from_approval_policies,
  )
- from inspect_ai.log import EvalConfig, EvalLog, EvalLogInfo, read_eval_log
+ from inspect_ai.log import EvalConfig, EvalLog, EvalLogInfo
+ from inspect_ai.log._file import read_eval_log_async
  from inspect_ai.log._recorders import create_recorder_for_format
  from inspect_ai.model import (
      GenerateConfig,
@@ -600,9 +601,9 @@ async def eval_retry_async(
              task
              if isinstance(task, EvalLog)
              else (
-                 read_eval_log(task.name)
+                 await read_eval_log_async(task.name)
                  if isinstance(task, EvalLogInfo)
-                 else read_eval_log(task)
+                 else await read_eval_log_async(task)
              )
          )
          for task in tasks
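Note: retry now reads prior logs with the async reader rather than the blocking `read_eval_log`. A small sketch of calling it directly; the log path is a placeholder, not a real file.

import asyncio

from inspect_ai.log._file import read_eval_log_async


async def main() -> None:
    # placeholder path; point this at a real eval log file
    log = await read_eval_log_async("logs/last-run.eval")
    print(log.status, len(log.samples or []))


asyncio.run(main())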
inspect_ai/_eval/loader.py CHANGED
@@ -198,7 +198,7 @@ def resolve_task_sandbox(
              break

      # resolve relative paths
-     if resolved_sandbox.config is not None:
+     if isinstance(resolved_sandbox.config, str):
          file_path = Path(resolved_sandbox.config)
          if not file_path.is_absolute():
              file_path = Path(task_run_dir(task)) / file_path
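Note: the guard is narrowed because a sandbox `config` may now be a structured config object rather than a file path, so only string configs are resolved relative to the task directory. A standalone sketch of that behaviour; the `resolve_config_path` helper is hypothetical.

from pathlib import Path


def resolve_config_path(config: object, task_dir: str) -> object:
    # only treat the config as a file path when it is a string;
    # structured config objects pass through untouched
    if isinstance(config, str):
        file_path = Path(config)
        if not file_path.is_absolute():
            return str(Path(task_dir) / file_path)
    return config


print(resolve_config_path("compose.yaml", "/evals/mytask"))            # /evals/mytask/compose.yaml
print(resolve_config_path({"image": "python:3.12"}, "/evals/mytask"))  # returned unchanged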
inspect_ai/_eval/run.py CHANGED
@@ -12,9 +12,10 @@ from inspect_ai._display.core.active import (
      init_task_screen,
  )
  from inspect_ai._display.core.display import TaskSpec
- from inspect_ai._util.error import exception_message
+ from inspect_ai._util.error import PrerequisiteError, exception_message
  from inspect_ai._util.path import chdir
  from inspect_ai._util.registry import registry_unqualified_name
+ from inspect_ai.dataset._dataset import Dataset
  from inspect_ai.log import EvalConfig, EvalLog
  from inspect_ai.log._recorders import Recorder
  from inspect_ai.model import GenerateConfigArgs
@@ -23,6 +24,7 @@ from inspect_ai.scorer._reducer import ScoreReducer, reducer_log_names
  from inspect_ai.scorer._reducer.registry import validate_reducer
  from inspect_ai.solver._solver import Solver, SolverSpec
  from inspect_ai.util._sandbox.environment import (
+     SandboxEnvironmentConfigType,
      SandboxEnvironmentSpec,
      SandboxEnvironmentType,
      TaskCleanup,
@@ -149,6 +151,9 @@ async def eval_run(
              if sample.id is None:
                  sample.id = id + 1

+         # Ensure sample ids are unique
+         ensure_unique_ids(task.dataset)
+
          # create and track the logger
          logger = TaskLogger(
              task_name=task.name,
@@ -168,6 +173,7 @@ async def eval_run(
              metadata=task.metadata,
              recorder=recorder,
          )
+         await logger.init()

          # append task
          task_run_options.append(
@@ -287,6 +293,12 @@ async def run_multiple(tasks: list[TaskRunOptions], parallel: int) -> list[EvalL
              await task
              result = task.result()
              results.append(result)
+         except Exception as ex:
+             # errors generally don't escape from tasks (the exception being if an error
+             # occurs during the final write of the log)
+             log.error(
+                 f"Task '{task_options.task.name}' encountered an error during finalisation: {ex}"
+             )

          # tracking
          tasks_completed += 1
@@ -340,7 +352,7 @@ async def startup_sandbox_environments(
          sandboxenvs.add(sandbox)

      # initialiase sandboxenvs (track cleanups)
-     cleanups: list[tuple[TaskCleanup, str | None, str]] = []
+     cleanups: list[tuple[TaskCleanup, SandboxEnvironmentConfigType | None, str]] = []
      with display().suspend_task_app():
          for sandboxenv in sandboxenvs:
              # find type
@@ -377,3 +389,24 @@ def task_specs(tasks: list[TaskRunOptions]) -> list[TaskSpec]:
          TaskSpec(registry_unqualified_name(task.task.name), ModelName(task.model))
          for task in tasks
      ]
+
+
+ def ensure_unique_ids(dataset: Dataset) -> None:
+     """
+     Validates that all samples in the dataset have unique IDs.
+
+     Raises a error if duplicates are found.
+
+     Args:
+         dataset (Datatset): The dataset
+
+     Raises:
+         PrerequisiteError: If duplicate IDs are found in the dataset.
+     """
+     seen_ids = set()
+     for sample in dataset:
+         if sample.id in seen_ids:
+             raise PrerequisiteError(
+                 f"The dataset contains duplicate sample ids (duplicate id: {sample.id}). Please ensure each sample has a unique id."
+             )
+         seen_ids.add(sample.id)
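Note: `ensure_unique_ids` is a pre-flight check that stops a run before any logging happens if two samples share an id. The same check in isolation, using a plain `ValueError` instead of the package's `PrerequisiteError`:

def check_unique_ids(ids: list[int | str]) -> None:
    # mirrors ensure_unique_ids above, without the inspect_ai error type
    seen: set[int | str] = set()
    for sample_id in ids:
        if sample_id in seen:
            raise ValueError(f"duplicate sample id: {sample_id}")
        seen.add(sample_id)


check_unique_ids([1, 2, 3])      # fine
try:
    check_unique_ids([1, 2, 2])  # raises on the second 2
except ValueError as ex:
    print(ex)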
inspect_ai/_eval/task/log.py CHANGED
@@ -75,7 +75,7 @@ class TaskLogger:
              del model_args["api_key"]

          # cwd_relative_path for sandbox config
-         if sandbox and sandbox.config:
+         if sandbox and isinstance(sandbox.config, str):
              sandbox = SandboxEnvironmentSpec(
                  sandbox.type, cwd_relative_path(sandbox.config)
              )
@@ -118,7 +118,6 @@ class TaskLogger:

          # stack recorder and location
          self.recorder = recorder
-         self._location = self.recorder.log_init(self.eval)

          # number of samples logged
          self._samples_completed = 0
@@ -127,6 +126,9 @@ class TaskLogger:
          self.flush_buffer = eval_config.log_buffer or recorder.default_log_buffer()
          self.flush_pending = 0

+     async def init(self) -> None:
+         self._location = await self.recorder.log_init(self.eval)
+
      @property
      def location(self) -> str:
          return self._location
@@ -135,25 +137,25 @@ class TaskLogger:
      def samples_completed(self) -> int:
          return self._samples_completed

-     def log_start(self, plan: EvalPlan) -> None:
-         self.recorder.log_start(self.eval, plan)
+     async def log_start(self, plan: EvalPlan) -> None:
+         await self.recorder.log_start(self.eval, plan)

-     def log_sample(self, sample: EvalSample, *, flush: bool) -> None:
+     async def log_sample(self, sample: EvalSample, *, flush: bool) -> None:
          # log the sample
-         self.recorder.log_sample(self.eval, sample)
+         await self.recorder.log_sample(self.eval, sample)

          # flush if requested
          if flush:
              self.flush_pending += 1
              if self.flush_pending >= self.flush_buffer:
-                 self.recorder.flush(self.eval)
+                 await self.recorder.flush(self.eval)
                  self.flush_pending = 0

          # track sucessful samples logged
          if sample.error is None:
              self._samples_completed += 1

-     def log_finish(
+     async def log_finish(
          self,
          status: Literal["success", "cancelled", "error"],
          stats: EvalStats,
@@ -161,12 +163,12 @@ class TaskLogger:
          reductions: list[EvalSampleReductions] | None = None,
          error: EvalError | None = None,
      ) -> EvalLog:
-         return self.recorder.log_finish(
+         return await self.recorder.log_finish(
              self.eval, status, stats, results, reductions, error
          )


- def log_start(
+ async def log_start(
      logger: TaskLogger,
      plan: Plan,
      config: GenerateConfig,
@@ -185,7 +187,7 @@ def log_start(
      if plan.finish:
          eval_plan.steps.append(eval_plan_step(plan.finish))

-     logger.log_start(eval_plan)
+     await logger.log_start(eval_plan)


  def collect_eval_data(stats: EvalStats) -> None:
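Note: the recorder interface is now async, so `TaskLogger` no longer calls `log_init` in its constructor; callers construct it and then await `init()` (see the `await logger.init()` added in run.py above). A toy sketch of that construct-then-init pattern with hypothetical stand-in classes:

import asyncio


class ToyRecorder:
    # stand-in for the now-async recorder interface
    async def log_init(self, eval_name: str) -> str:
        return f"./logs/{eval_name}.eval"


class ToyLogger:
    def __init__(self, eval_name: str, recorder: ToyRecorder) -> None:
        # constructor stays synchronous: no recorder calls here any more
        self.eval_name = eval_name
        self.recorder = recorder
        self.location = ""

    async def init(self) -> None:
        # the async work moved out of __init__ into an explicit init step
        self.location = await self.recorder.log_init(self.eval_name)


async def main() -> None:
    logger = ToyLogger("popularity", ToyRecorder())
    await logger.init()
    print(logger.location)


asyncio.run(main())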
inspect_ai/_eval/task/results.py CHANGED
@@ -175,7 +175,10 @@ def scorer_for_metrics(
          )

          # process metric values
-         metric_value = metric(scores)
+         if len(scores) > 0:
+             metric_value = metric(scores)
+         else:
+             metric_value = float("Nan")
          base_metric_name = registry_log_name(metric)

          # If the metric value is a dictionary, turn each of the entries
@@ -233,7 +236,9 @@ def scorers_from_metric_dict(
      results: list[EvalScore] = []

      # Expand any metric keys
-     resolved_metrics = resolve_glob_metric_keys(metrics, scores[0])
+     resolved_metrics = (
+         resolve_glob_metric_keys(metrics, scores[0]) if len(scores) > 0 else metrics
+     )

      for metric_key, metric_list in resolved_metrics.items():
          # filter scores to a list of scalars with the value of the metric name
@@ -258,9 +263,13 @@ def scorers_from_metric_dict(
          for target_metric in metric_list:
              # compute the metric value
              metric_name = registry_log_name(target_metric)
+             if len(metric_scores) > 0:
+                 value = target_metric(metric_scores)
+             else:
+                 value = float("Nan")
              result_metrics[metric_name] = EvalMetric(
                  name=metric_name,
-                 value=cast(float, target_metric(metric_scores)),
+                 value=cast(float, value),
              )

          # create a scorer result for this metric
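Note: both guards fall back to NaN when a scorer produced no scores (`float("Nan")` parses the same as `float("nan")`), and the TUI's `TaskMetrics._metric_value` above renders NaN as " n/a ". A quick standalone illustration, with a plain `mean` function standing in for a registered metric:

import math


def mean(values: list[float]) -> float:
    return sum(values) / len(values)  # would raise ZeroDivisionError on []


def safe_metric(values: list[float]) -> float:
    # same shape as the new guards in scorer_for_metrics / scorers_from_metric_dict
    return mean(values) if len(values) > 0 else float("Nan")


print(safe_metric([0.0, 1.0, 1.0]))  # 0.666...
print(math.isnan(safe_metric([])))   # True -- shown as "n/a" in the task display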