inspect-ai 0.3.49__py3-none-any.whl → 0.3.51__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/info.py +2 -2
- inspect_ai/_cli/log.py +2 -2
- inspect_ai/_cli/score.py +2 -2
- inspect_ai/_display/core/display.py +19 -0
- inspect_ai/_display/core/panel.py +37 -7
- inspect_ai/_display/core/progress.py +29 -2
- inspect_ai/_display/core/results.py +79 -40
- inspect_ai/_display/core/textual.py +21 -0
- inspect_ai/_display/rich/display.py +28 -8
- inspect_ai/_display/textual/app.py +107 -1
- inspect_ai/_display/textual/display.py +1 -1
- inspect_ai/_display/textual/widgets/samples.py +132 -91
- inspect_ai/_display/textual/widgets/task_detail.py +236 -0
- inspect_ai/_display/textual/widgets/tasks.py +74 -6
- inspect_ai/_display/textual/widgets/toggle.py +32 -0
- inspect_ai/_eval/context.py +2 -0
- inspect_ai/_eval/eval.py +4 -3
- inspect_ai/_eval/loader.py +1 -1
- inspect_ai/_eval/run.py +35 -2
- inspect_ai/_eval/task/log.py +13 -11
- inspect_ai/_eval/task/results.py +12 -3
- inspect_ai/_eval/task/run.py +139 -36
- inspect_ai/_eval/task/sandbox.py +2 -1
- inspect_ai/_util/_async.py +30 -1
- inspect_ai/_util/file.py +31 -4
- inspect_ai/_util/html.py +3 -0
- inspect_ai/_util/logger.py +6 -5
- inspect_ai/_util/platform.py +5 -6
- inspect_ai/_util/registry.py +1 -1
- inspect_ai/_view/server.py +9 -9
- inspect_ai/_view/www/App.css +2 -2
- inspect_ai/_view/www/dist/assets/index.css +2 -2
- inspect_ai/_view/www/dist/assets/index.js +352 -294
- inspect_ai/_view/www/log-schema.json +13 -0
- inspect_ai/_view/www/package.json +1 -0
- inspect_ai/_view/www/src/components/MessageBand.mjs +1 -1
- inspect_ai/_view/www/src/components/Tools.mjs +16 -13
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -3
- inspect_ai/_view/www/src/samples/SampleScoreView.mjs +52 -77
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -13
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +15 -2
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.mjs +4 -2
- inspect_ai/_view/www/src/types/log.d.ts +2 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +2 -0
- inspect_ai/_view/www/yarn.lock +9 -4
- inspect_ai/approval/__init__.py +1 -1
- inspect_ai/approval/_human/approver.py +35 -0
- inspect_ai/approval/_human/console.py +62 -0
- inspect_ai/approval/_human/manager.py +108 -0
- inspect_ai/approval/_human/panel.py +233 -0
- inspect_ai/approval/_human/util.py +51 -0
- inspect_ai/dataset/_sources/hf.py +2 -2
- inspect_ai/dataset/_sources/util.py +1 -1
- inspect_ai/log/_file.py +106 -36
- inspect_ai/log/_recorders/eval.py +226 -158
- inspect_ai/log/_recorders/file.py +9 -6
- inspect_ai/log/_recorders/json.py +35 -12
- inspect_ai/log/_recorders/recorder.py +15 -15
- inspect_ai/log/_samples.py +52 -0
- inspect_ai/model/_model.py +14 -0
- inspect_ai/model/_model_output.py +4 -0
- inspect_ai/model/_providers/azureai.py +1 -1
- inspect_ai/model/_providers/hf.py +106 -4
- inspect_ai/model/_providers/util/__init__.py +2 -0
- inspect_ai/model/_providers/util/hf_handler.py +200 -0
- inspect_ai/scorer/_common.py +1 -1
- inspect_ai/solver/_plan.py +0 -8
- inspect_ai/solver/_task_state.py +18 -1
- inspect_ai/solver/_use_tools.py +9 -1
- inspect_ai/tool/_tool_def.py +2 -2
- inspect_ai/tool/_tool_info.py +14 -2
- inspect_ai/tool/_tool_params.py +2 -1
- inspect_ai/tool/_tools/_execute.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +6 -0
- inspect_ai/util/__init__.py +5 -6
- inspect_ai/util/_panel.py +91 -0
- inspect_ai/util/_sandbox/__init__.py +2 -6
- inspect_ai/util/_sandbox/context.py +4 -3
- inspect_ai/util/_sandbox/docker/compose.py +12 -2
- inspect_ai/util/_sandbox/docker/docker.py +19 -9
- inspect_ai/util/_sandbox/docker/util.py +10 -2
- inspect_ai/util/_sandbox/environment.py +47 -41
- inspect_ai/util/_sandbox/local.py +15 -10
- inspect_ai/util/_subprocess.py +43 -3
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/METADATA +2 -2
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/RECORD +90 -82
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
- inspect_ai/approval/_human.py +0 -123
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.51.dist-info}/top_level.txt +0 -0
inspect_ai/_display/textual/widgets/task_detail.py
ADDED
@@ -0,0 +1,236 @@
+import re
+from dataclasses import dataclass
+
+import numpy as np
+from textual.app import ComposeResult
+from textual.containers import Center, Grid, Horizontal
+from textual.reactive import Reactive, reactive
+from textual.widget import Widget
+from textual.widgets import Static
+
+from inspect_ai._display.core.display import TaskDisplayMetric
+
+
+@dataclass
+class TaskMetric:
+    name: str
+    value: float
+
+
+class TaskDetail(Widget):
+    hidden = reactive(False)
+    DEFAULT_CSS = """
+    TaskDetail {
+        background: $boost;
+        width: 100%;
+        height: auto;
+        padding: 1 0 1 0;
+    }
+    TaskDetail Grid {
+        width: 100%;
+        height: auto;
+        grid-gutter: 1 3;
+    }
+    """
+
+    def __init__(
+        self,
+        *,
+        hidden: bool = True,
+        id: str | None = None,
+        classes: str | None = None,
+    ) -> None:
+        super().__init__(id=id, classes=classes)
+        self.hidden = hidden
+        self.existing_metrics: dict[str, TaskMetrics] = {}
+        self.grid = Grid()
+        self.by_reducer: dict[str | None, dict[str, list[TaskMetric]]] = {}
+        self.metrics: list[TaskDisplayMetric] = []
+
+    def watch_hidden(self, hidden: bool) -> None:
+        """React to changes in the `visible` property."""
+        if hidden:
+            self.add_class("hidden")
+        else:
+            self.remove_class("hidden")
+
+    def compose(self) -> ComposeResult:
+        yield self.grid
+
+    def on_mount(self) -> None:
+        self.refresh_grid()
+
+    def update_metrics(self, metrics: list[TaskDisplayMetric]) -> None:
+        # Group by reducer then scorer within reducers
+        self.metrics = metrics
+        for metric in metrics:
+            reducer_group = (
+                self.by_reducer[metric.reducer]
+                if metric.reducer in self.by_reducer
+                else {}
+            )
+
+            by_scorer_metrics = (
+                reducer_group[metric.scorer] if metric.scorer in reducer_group else []
+            )
+            by_scorer_metrics.append(TaskMetric(name=metric.name, value=metric.value))
+            reducer_group[metric.scorer] = by_scorer_metrics
+            self.by_reducer[metric.reducer] = reducer_group
+
+        self.refresh_grid()
+
+    def refresh_grid(self) -> None:
+        # Don't refresh the grid if not attached
+        # since we may explicitly mount new widgets
+        if not self.grid.is_attached:
+            return
+
+        # don't refresh the grid if there are no scores
+        if len(self.by_reducer) == 0:
+            return
+
+        # Compute the row and column count
+        row_count = len(self.by_reducer)
+        col_count = len(next(iter(self.by_reducer.values())))
+
+        # If this can fit in a single row, make it fit
+        # otherwise place each reducer on their own row
+        self.grid.styles.grid_columns = "auto"
+        if row_count * col_count < 4:
+            self.grid.styles.grid_size_columns = row_count * col_count
+            self.grid.styles.grid_size_rows = 1
+        else:
+            self.grid.styles.grid_size_columns = col_count
+            self.grid.styles.grid_size_rows = row_count
+
+        # In order to reduce flashing the below tracks use of widgets
+        # and updates them when possible (removing and adding them as needed)
+        # Makes keys for tracking Task Metric widgets
+        def metric_key(reducer: str | None, scorer: str) -> str:
+            reducer = reducer or "none"
+            return valid_id(f"task-{reducer}-{scorer}-tbl")
+
+        # Remove keys that are no longer present
+        existing_keys = set(self.existing_metrics.keys())
+        new_keys = set(metric_key(m.reducer, m.scorer) for m in self.metrics)
+        to_remove = existing_keys - new_keys
+        for remove in to_remove:
+            task_metric = self.existing_metrics[remove]
+            task_metric.remove()
+
+        # add or update widgets with metrics
+        for reducer, scorers in self.by_reducer.items():
+            for scorer, scores in scorers.items():
+                key = metric_key(reducer=reducer, scorer=scorer)
+                if key in self.existing_metrics:
+                    task_metrics = self.existing_metrics[key]
+                    task_metrics.update(scores)
+                else:
+                    task_metrics = TaskMetrics(
+                        id=key, scorer=scorer, reducer=reducer, metrics=scores
+                    )
+                    self.grid.mount(task_metrics)
+                    self.existing_metrics[key] = task_metrics
+
+
+class TaskMetrics(Widget):
+    DEFAULT_CSS = """
+    TaskMetrics {
+        width: auto;
+        height: auto;
+    }
+    TaskMetrics Grid {
+        width: auto;
+        grid-size: 2;
+        grid-columns: auto;
+        grid-gutter: 0 3;
+        padding: 0 2 0 2;
+    }
+    TaskMetric Center {
+        width: auto;
+    }
+    TaskMetrics Center Static {
+        width: auto;
+    }
+    TaskMetrics Center Horizontal {
+        width: auto;
+        height: auto;
+    }
+    TaskMetrics Center Horizontal Static {
+        width: auto;
+        height: auto;
+    }
+    TaskMetrics .scorer {
+        padding: 0 1 0 0;
+        text-style: bold;
+    }
+    TaskMetrics .reducer {
+        color: $foreground-darken-3;
+    }
+    """
+
+    metrics: Reactive[list[TaskMetric]] = reactive([])
+
+    def __init__(
+        self,
+        *,
+        scorer: str | None,
+        reducer: str | None,
+        metrics: list[TaskMetric],
+        id: str | None = None,
+        classes: str | None = None,
+    ) -> None:
+        super().__init__(id=id, classes=classes)
+        self.scorer = scorer
+        self.reducer = reducer
+        self.metrics = metrics
+        self.grid: Grid = Grid()
+        self.value_widgets: dict[str, Static] = {}
+
+    def compose(self) -> ComposeResult:
+        # Just yield a single DataTable widget
+        yield Center(self._title())
+        with Grid():
+            for metric in self.metrics:
+                # Add the value static but keep it around
+                # for future updates
+                self.value_widgets[metric.name] = Static(
+                    self._metric_value(metric.value)
+                )
+
+                yield Static(metric.name)
+                yield self.value_widgets[metric.name]
+
+    def update(self, metrics: list[TaskMetric]) -> None:
+        for metric in metrics:
+            widget = self.value_widgets[metric.name]
+            widget.update(content=f"{metric.value:,.3f}")
+
+    def _title(self) -> Widget:
+        if self.scorer is None:
+            return Static("")
+        elif self.reducer is None:
+            return Static(self.scorer)
+        else:
+            return Horizontal(
+                Static(self.scorer, classes="scorer"),
+                Static(f"({self.reducer})", classes="reducer"),
+            )
+
+    def _metric_value(self, val: float) -> str:
+        if np.isnan(val):
+            return " n/a "
+        else:
+            return f"{val:.3f}"
+
+
+def valid_id(identifier: str) -> str:
+    # Remove invalid characters
+    valid_part = re.sub(r"[^a-zA-Z0-9_-]", "_", identifier)
+
+    # Ensure it doesn't start with a number
+    if valid_part and valid_part[0].isdigit():
+        valid_part = "_" + valid_part
+
+    # If the string is empty return a default valid identifier
+    return valid_part or "default_identifier"
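The heart of `TaskDetail.update_metrics` above is a two-level grouping: metrics are bucketed by reducer, then by scorer within each reducer, and `refresh_grid` lays the grid out from that structure. A minimal standalone sketch of the same grouping, using a hypothetical `Metric` record in place of the real `TaskDisplayMetric` (illustration only, not inspect-ai code):

```python
from dataclasses import dataclass


@dataclass
class Metric:
    name: str
    value: float
    scorer: str
    reducer: str | None


def group_metrics(
    metrics: list[Metric],
) -> dict[str | None, dict[str, list[tuple[str, float]]]]:
    # reducer -> scorer -> [(metric name, value), ...]
    by_reducer: dict[str | None, dict[str, list[tuple[str, float]]]] = {}
    for m in metrics:
        by_reducer.setdefault(m.reducer, {}).setdefault(m.scorer, []).append(
            (m.name, m.value)
        )
    return by_reducer


print(group_metrics([
    Metric("accuracy", 0.82, scorer="choice", reducer="mean"),
    Metric("stderr", 0.03, scorer="choice", reducer="mean"),
]))
```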
inspect_ai/_display/textual/widgets/tasks.py
CHANGED
@@ -4,19 +4,25 @@ from typing import Iterator, cast
 
 from rich.console import RenderableType
 from rich.text import Text
+from textual import on
 from textual.app import ComposeResult
 from textual.containers import Container, ScrollableContainer
+from textual.css.query import NoMatches
 from textual.reactive import reactive
 from textual.widget import Widget
 from textual.widgets import ProgressBar, Static
 from typing_extensions import override
 
+from inspect_ai._display.core.results import task_metric
 from inspect_ai._display.textual.widgets.clock import Clock
+from inspect_ai._display.textual.widgets.task_detail import TaskDetail
+from inspect_ai._display.textual.widgets.toggle import Toggle
 
 from ...core.display import (
     Progress,
     TaskCancelled,
     TaskDisplay,
+    TaskDisplayMetric,
     TaskError,
     TaskResult,
     TaskSpec,
@@ -25,6 +31,7 @@ from ...core.display import (
 from ...core.progress import (
     MAX_DESCRIPTION_WIDTH,
     MAX_MODEL_NAME_WIDTH,
+    progress_count,
     progress_description,
     progress_model_name,
 )
@@ -106,9 +113,10 @@ class TaskProgressView(Widget):
         height: auto;
         width: 1fr;
         layout: grid;
-        grid-size:
-        grid-columns: auto auto auto 1fr auto;
-        grid-
+        grid-size: 8 2;
+        grid-columns: auto auto auto auto 1fr auto auto auto;
+        grid-rows: auto auto;
+        grid-gutter: 0 1;
     }
     TaskProgressView Bar {
         width: 1fr;
@@ -119,6 +127,15 @@ class TaskProgressView(Widget):
            color: $success;
        }
    }
+    #task-metrics {
+        color:$text-secondary;
+    }
+    #task-detail {
+        column-span: 8;
+    }
+    .hidden {
+        display: none;
+    }
    """
 
    def __init__(
@@ -126,12 +143,19 @@ class TaskProgressView(Widget):
    ) -> None:
        super().__init__()
        self.t = task
+
        self.description_width = description_width
        self.model_name_width = model_name_width
        self.progress_bar = ProgressBar(total=task.profile.steps, show_eta=False)
+        self.count_display = Static()
+        self.metrics_display = Static(id="task-metrics")
        self.task_progress = TaskProgress(self.progress_bar)
 
+        self.toggle = Toggle()
+        self.task_detail = TaskDetail(id="task-detail", classes="hidden")
+
    def compose(self) -> ComposeResult:
+        yield self.toggle
        yield TaskStatusIcon()
        yield Static(
            progress_description(self.t.profile, self.description_width, pad=True)
@@ -140,7 +164,15 @@ class TaskProgressView(Widget):
            progress_model_name(self.t.profile.model, self.model_name_width, pad=True)
        )
        yield self.progress_bar
+        yield self.count_display
+        yield self.metrics_display
        yield Clock()
+        yield self.task_detail
+
+    @on(Toggle.Toggled)
+    def handle_title_toggle(self, event: Toggle.Toggled) -> None:
+        self.task_detail.hidden = not self.toggle.toggled
+        event.stop()
 
    def on_mount(self) -> None:
        self.query_one(Clock).start(datetime.now().timestamp())
@@ -151,10 +183,21 @@ class TaskProgressView(Widget):
 
    def complete(self, result: TaskResult) -> None:
        self.t.result = result
-        self.query_one(TaskStatusIcon).result = result
-        self.query_one(Clock).stop()
+        try:
+            self.query_one(TaskStatusIcon).result = result
+            self.query_one(Clock).stop()
+        except NoMatches:
+            pass
        self.task_progress.complete()
 
+    def sample_complete(self, complete: int, total: int) -> None:
+        self.count_display.update(progress_count(complete, total))
+
+    def update_metrics(self, metrics: list[TaskDisplayMetric]) -> None:
+        if len(metrics) > 0:
+            self.metrics_display.update(task_metric(metrics))
+            self.task_detail.update_metrics(metrics)
+
 
 class TaskStatusIcon(Static):
    result: reactive[TaskResult | None] = reactive(None)
@@ -181,13 +224,38 @@ class TaskStatusIcon(Static):
            return Text("⠿", style=running)
 
 
+MAX_PROGRESS_PERCENT = 0.02
+MIN_PROGRESS_PERCENT = 0.98
+
+
 class TaskProgress(Progress):
    def __init__(self, progress_bar: ProgressBar) -> None:
        self.progress_bar = progress_bar
+        self.current_progress = 0
+
+        # always show a minimum amount of progress
+        minimum_steps = (
+            MAX_PROGRESS_PERCENT * progress_bar.total
+            if progress_bar.total is not None
+            else 0
+        )
+        self.progress_bar.update(progress=minimum_steps)
 
    @override
    def update(self, n: int = 1) -> None:
-        self.
+        self.current_progress = self.current_progress + n
+
+        # enforce a maximum cap on task progress
+        max_progress = (
+            MIN_PROGRESS_PERCENT * self.progress_bar.total
+            if self.progress_bar.total is not None
+            else 0
+        )
+        if (
+            self.current_progress > self.progress_bar.progress
+            and self.current_progress < max_progress
+        ):
+            self.progress_bar.update(progress=self.current_progress)
 
    @override
    def complete(self) -> None:
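The rewritten `TaskProgress` keeps the visible bar from sitting at zero before any work completes and from reaching the end before `complete()` is called. A simplified, self-contained sketch of the effective clamping (illustrative only; the constant names mirror the diff, and `complete()` on the real widget advances the bar to its total):

```python
# names copied from the diff above; note the released code uses MAX_* for the
# floor and MIN_* for the cap
MAX_PROGRESS_PERCENT = 0.02  # minimum progress shown immediately
MIN_PROGRESS_PERCENT = 0.98  # progress is capped here until complete()


def visible_progress(current_steps: int, total_steps: int) -> float:
    """Approximate the progress value TaskProgress lets the bar display."""
    floor = MAX_PROGRESS_PERCENT * total_steps
    cap = MIN_PROGRESS_PERCENT * total_steps
    return max(floor, min(current_steps, cap))


assert visible_progress(0, 100) == 2.0    # never an empty bar
assert visible_progress(50, 100) == 50    # normal range passes through
assert visible_progress(99, 100) == 98.0  # held back until complete()
```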
inspect_ai/_display/textual/widgets/toggle.py
ADDED
@@ -0,0 +1,32 @@
+from textual.events import Click
+from textual.message import Message
+from textual.reactive import reactive
+from textual.widgets import Static
+
+
+class Toggle(Static, can_focus=True):
+    toggled = reactive(True)
+
+    def __init__(
+        self, on_symbol: str = "▼", off_symbol: str = "▶", toggled: bool = False
+    ) -> None:
+        super().__init__()
+
+        self.on_symbol = on_symbol
+        self.off_symbol = off_symbol
+        self.toggled = toggled
+
+    class Toggled(Message):
+        """Request toggle."""
+
+    async def _on_click(self, event: Click) -> None:
+        """Inform ancestor we want to toggle."""
+        event.stop()
+        self.toggled = not self.toggled
+        self.post_message(self.Toggled())
+
+    def _watch_toggled(self, toggled: bool) -> None:
+        if toggled:
+            self.update(self.on_symbol)
+        else:
+            self.update(self.off_symbol)
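`Toggle` is a focusable `Static` that flips a reactive flag on click and posts a `Toggled` message for an ancestor to act on, which is exactly what `TaskProgressView.handle_title_toggle` does in the tasks.py hunk above. A minimal sketch of driving it from a throwaway Textual app (assumes textual and inspect-ai 0.3.51 are installed; the import path is an internal module and may change between releases):

```python
from textual import on
from textual.app import App, ComposeResult
from textual.widgets import Static

# internal inspect-ai module added in this release; not a stable public API
from inspect_ai._display.textual.widgets.toggle import Toggle


class ToggleDemo(App[None]):
    def compose(self) -> ComposeResult:
        self.toggle = Toggle()
        self.detail = Static("details hidden")
        yield self.toggle
        yield self.detail

    @on(Toggle.Toggled)
    def handle_toggled(self, event: Toggle.Toggled) -> None:
        # mirror TaskProgressView: show/hide detail based on the toggle state
        self.detail.update(
            "details shown" if self.toggle.toggled else "details hidden"
        )
        event.stop()


if __name__ == "__main__":
    ToggleDemo().run()
```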
inspect_ai/_eval/context.py
CHANGED
@@ -1,6 +1,7 @@
 from inspect_ai._util.dotenv import init_dotenv
 from inspect_ai._util.hooks import init_hooks
 from inspect_ai._util.logger import init_http_rate_limit_count, init_logger
+from inspect_ai.approval._human.manager import init_human_approval_manager
 from inspect_ai.log._samples import init_active_samples
 from inspect_ai.model import GenerateConfig, Model
 from inspect_ai.model._model import init_active_model, init_model_usage
@@ -20,6 +21,7 @@ def init_eval_context(
     init_http_rate_limit_count()
     init_hooks()
     init_active_samples()
+    init_human_approval_manager()
 
 
 def init_task_context(model: Model, config: GenerateConfig = GenerateConfig()) -> None:
inspect_ai/_eval/eval.py
CHANGED
@@ -21,7 +21,8 @@ from inspect_ai.approval._policy import (
     approval_policies_from_config,
     config_from_approval_policies,
 )
-from inspect_ai.log import EvalConfig, EvalLog, EvalLogInfo
+from inspect_ai.log import EvalConfig, EvalLog, EvalLogInfo
+from inspect_ai.log._file import read_eval_log_async
 from inspect_ai.log._recorders import create_recorder_for_format
 from inspect_ai.model import (
     GenerateConfig,
@@ -600,9 +601,9 @@ async def eval_retry_async(
                 task
                 if isinstance(task, EvalLog)
                 else (
-
+                    await read_eval_log_async(task.name)
                     if isinstance(task, EvalLogInfo)
-                    else
+                    else await read_eval_log_async(task)
                 )
             )
             for task in tasks
inspect_ai/_eval/loader.py
CHANGED
@@ -198,7 +198,7 @@ def resolve_task_sandbox(
                 break
 
     # resolve relative paths
-    if resolved_sandbox.config
+    if isinstance(resolved_sandbox.config, str):
         file_path = Path(resolved_sandbox.config)
         if not file_path.is_absolute():
             file_path = Path(task_run_dir(task)) / file_path
inspect_ai/_eval/run.py
CHANGED
@@ -12,9 +12,10 @@ from inspect_ai._display.core.active import (
     init_task_screen,
 )
 from inspect_ai._display.core.display import TaskSpec
-from inspect_ai._util.error import exception_message
+from inspect_ai._util.error import PrerequisiteError, exception_message
 from inspect_ai._util.path import chdir
 from inspect_ai._util.registry import registry_unqualified_name
+from inspect_ai.dataset._dataset import Dataset
 from inspect_ai.log import EvalConfig, EvalLog
 from inspect_ai.log._recorders import Recorder
 from inspect_ai.model import GenerateConfigArgs
@@ -23,6 +24,7 @@ from inspect_ai.scorer._reducer import ScoreReducer, reducer_log_names
 from inspect_ai.scorer._reducer.registry import validate_reducer
 from inspect_ai.solver._solver import Solver, SolverSpec
 from inspect_ai.util._sandbox.environment import (
+    SandboxEnvironmentConfigType,
     SandboxEnvironmentSpec,
     SandboxEnvironmentType,
     TaskCleanup,
@@ -149,6 +151,9 @@ async def eval_run(
             if sample.id is None:
                 sample.id = id + 1
 
+        # Ensure sample ids are unique
+        ensure_unique_ids(task.dataset)
+
         # create and track the logger
         logger = TaskLogger(
             task_name=task.name,
@@ -168,6 +173,7 @@ async def eval_run(
             metadata=task.metadata,
             recorder=recorder,
         )
+        await logger.init()
 
         # append task
         task_run_options.append(
@@ -287,6 +293,12 @@ async def run_multiple(tasks: list[TaskRunOptions], parallel: int) -> list[EvalL
                 await task
                 result = task.result()
                 results.append(result)
+            except Exception as ex:
+                # errors generally don't escape from tasks (the exception being if an error
+                # occurs during the final write of the log)
+                log.error(
+                    f"Task '{task_options.task.name}' encountered an error during finalisation: {ex}"
+                )
 
             # tracking
             tasks_completed += 1
@@ -340,7 +352,7 @@ async def startup_sandbox_environments(
             sandboxenvs.add(sandbox)
 
     # initialiase sandboxenvs (track cleanups)
-    cleanups: list[tuple[TaskCleanup,
+    cleanups: list[tuple[TaskCleanup, SandboxEnvironmentConfigType | None, str]] = []
     with display().suspend_task_app():
         for sandboxenv in sandboxenvs:
             # find type
@@ -377,3 +389,24 @@ def task_specs(tasks: list[TaskRunOptions]) -> list[TaskSpec]:
         TaskSpec(registry_unqualified_name(task.task.name), ModelName(task.model))
         for task in tasks
     ]
+
+
+def ensure_unique_ids(dataset: Dataset) -> None:
+    """
+    Validates that all samples in the dataset have unique IDs.
+
+    Raises a error if duplicates are found.
+
+    Args:
+        dataset (Datatset): The dataset
+
+    Raises:
+        PrerequisiteError: If duplicate IDs are found in the dataset.
+    """
+    seen_ids = set()
+    for sample in dataset:
+        if sample.id in seen_ids:
+            raise PrerequisiteError(
+                f"The dataset contains duplicate sample ids (duplicate id: {sample.id}). Please ensure each sample has a unique id."
+            )
+        seen_ids.add(sample.id)
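`ensure_unique_ids` runs once per task before the logger is created, so duplicate sample ids now fail fast with a `PrerequisiteError` instead of producing colliding log entries. A hedged illustration of the kind of dataset that would now be rejected at eval startup, built with the public `MemoryDataset`/`Sample` API (the ids and prompts are invented for the example):

```python
from inspect_ai.dataset import MemoryDataset, Sample

# two samples sharing id=1; running eval() on a task that uses this dataset
# would now raise PrerequisiteError during eval_run (per the diff above)
dataset = MemoryDataset(
    [
        Sample(id=1, input="What is 2 + 2?", target="4"),
        Sample(id=1, input="What is 3 + 3?", target="6"),  # duplicate id
    ]
)
```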
inspect_ai/_eval/task/log.py
CHANGED
@@ -75,7 +75,7 @@ class TaskLogger:
                 del model_args["api_key"]
 
         # cwd_relative_path for sandbox config
-        if sandbox and sandbox.config:
+        if sandbox and isinstance(sandbox.config, str):
             sandbox = SandboxEnvironmentSpec(
                 sandbox.type, cwd_relative_path(sandbox.config)
             )
@@ -118,7 +118,6 @@ class TaskLogger:
 
         # stack recorder and location
         self.recorder = recorder
-        self._location = self.recorder.log_init(self.eval)
 
         # number of samples logged
         self._samples_completed = 0
@@ -127,6 +126,9 @@ class TaskLogger:
         self.flush_buffer = eval_config.log_buffer or recorder.default_log_buffer()
         self.flush_pending = 0
 
+    async def init(self) -> None:
+        self._location = await self.recorder.log_init(self.eval)
+
     @property
     def location(self) -> str:
         return self._location
@@ -135,25 +137,25 @@ class TaskLogger:
     def samples_completed(self) -> int:
         return self._samples_completed
 
-    def log_start(self, plan: EvalPlan) -> None:
-        self.recorder.log_start(self.eval, plan)
+    async def log_start(self, plan: EvalPlan) -> None:
+        await self.recorder.log_start(self.eval, plan)
 
-    def log_sample(self, sample: EvalSample, *, flush: bool) -> None:
+    async def log_sample(self, sample: EvalSample, *, flush: bool) -> None:
         # log the sample
-        self.recorder.log_sample(self.eval, sample)
+        await self.recorder.log_sample(self.eval, sample)
 
         # flush if requested
         if flush:
             self.flush_pending += 1
             if self.flush_pending >= self.flush_buffer:
-                self.recorder.flush(self.eval)
+                await self.recorder.flush(self.eval)
                 self.flush_pending = 0
 
         # track sucessful samples logged
         if sample.error is None:
             self._samples_completed += 1
 
-    def log_finish(
+    async def log_finish(
         self,
         status: Literal["success", "cancelled", "error"],
         stats: EvalStats,
@@ -161,12 +163,12 @@ class TaskLogger:
         reductions: list[EvalSampleReductions] | None = None,
         error: EvalError | None = None,
     ) -> EvalLog:
-        return self.recorder.log_finish(
+        return await self.recorder.log_finish(
             self.eval, status, stats, results, reductions, error
         )
 
 
-def log_start(
+async def log_start(
     logger: TaskLogger,
     plan: Plan,
     config: GenerateConfig,
@@ -185,7 +187,7 @@ def log_start(
     if plan.finish:
         eval_plan.steps.append(eval_plan_step(plan.finish))
 
-    logger.log_start(eval_plan)
+    await logger.log_start(eval_plan)
 
 
 def collect_eval_data(stats: EvalStats) -> None:
inspect_ai/_eval/task/results.py
CHANGED
@@ -175,7 +175,10 @@ def scorer_for_metrics(
         )
 
         # process metric values
-
+        if len(scores) > 0:
+            metric_value = metric(scores)
+        else:
+            metric_value = float("Nan")
         base_metric_name = registry_log_name(metric)
 
         # If the metric value is a dictionary, turn each of the entries
@@ -233,7 +236,9 @@ def scorers_from_metric_dict(
     results: list[EvalScore] = []
 
     # Expand any metric keys
-    resolved_metrics =
+    resolved_metrics = (
+        resolve_glob_metric_keys(metrics, scores[0]) if len(scores) > 0 else metrics
+    )
 
     for metric_key, metric_list in resolved_metrics.items():
         # filter scores to a list of scalars with the value of the metric name
@@ -258,9 +263,13 @@ def scorers_from_metric_dict(
         for target_metric in metric_list:
             # compute the metric value
             metric_name = registry_log_name(target_metric)
+            if len(metric_scores) > 0:
+                value = target_metric(metric_scores)
+            else:
+                value = float("Nan")
             result_metrics[metric_name] = EvalMetric(
                 name=metric_name,
-                value=cast(float,
+                value=cast(float, value),
             )
 
             # create a scorer result for this metric
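The results.py change returns `float("Nan")` when a metric has no scores to aggregate instead of failing outright; the new `TaskMetrics._metric_value` shown earlier then renders such values as " n/a ". A trivial check of that round trip (illustration only):

```python
import math

value = float("Nan")  # what scorer_for_metrics now produces when scores is empty
print(math.isnan(value))                                 # True
print(" n/a " if math.isnan(value) else f"{value:.3f}")  # displayed like TaskMetrics does
```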