inspect-ai 0.3.88__py3-none-any.whl → 0.3.89__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- inspect_ai/_cli/eval.py +16 -0
- inspect_ai/_cli/score.py +1 -12
- inspect_ai/_cli/util.py +4 -2
- inspect_ai/_display/core/footer.py +2 -2
- inspect_ai/_display/plain/display.py +2 -2
- inspect_ai/_eval/context.py +7 -1
- inspect_ai/_eval/eval.py +51 -27
- inspect_ai/_eval/evalset.py +27 -10
- inspect_ai/_eval/loader.py +7 -8
- inspect_ai/_eval/run.py +23 -31
- inspect_ai/_eval/score.py +18 -1
- inspect_ai/_eval/task/log.py +5 -13
- inspect_ai/_eval/task/resolved.py +1 -0
- inspect_ai/_eval/task/run.py +231 -244
- inspect_ai/_eval/task/task.py +25 -2
- inspect_ai/_eval/task/util.py +1 -8
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/json.py +8 -3
- inspect_ai/_util/registry.py +30 -13
- inspect_ai/_view/www/App.css +5 -0
- inspect_ai/_view/www/dist/assets/index.css +55 -18
- inspect_ai/_view/www/dist/assets/index.js +550 -458
- inspect_ai/_view/www/log-schema.json +66 -0
- inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
- inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
- inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
- inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
- inspect_ai/_view/www/src/types/log.d.ts +24 -6
- inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
- inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
- inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
- inspect_ai/agent/_agent.py +12 -0
- inspect_ai/agent/_as_tool.py +1 -1
- inspect_ai/agent/_bridge/bridge.py +9 -2
- inspect_ai/agent/_react.py +142 -74
- inspect_ai/agent/_run.py +13 -2
- inspect_ai/agent/_types.py +6 -0
- inspect_ai/approval/_apply.py +6 -7
- inspect_ai/approval/_approver.py +3 -3
- inspect_ai/approval/_auto.py +2 -2
- inspect_ai/approval/_call.py +20 -4
- inspect_ai/approval/_human/approver.py +3 -3
- inspect_ai/approval/_human/manager.py +2 -2
- inspect_ai/approval/_human/panel.py +3 -3
- inspect_ai/approval/_policy.py +3 -3
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_log.py +23 -2
- inspect_ai/log/_model.py +58 -0
- inspect_ai/log/_recorders/file.py +14 -3
- inspect_ai/log/_transcript.py +3 -0
- inspect_ai/model/__init__.py +2 -0
- inspect_ai/model/_call_tools.py +4 -1
- inspect_ai/model/_model.py +49 -3
- inspect_ai/model/_openai.py +151 -21
- inspect_ai/model/_providers/anthropic.py +20 -12
- inspect_ai/model/_providers/bedrock.py +3 -3
- inspect_ai/model/_providers/cloudflare.py +29 -108
- inspect_ai/model/_providers/google.py +21 -10
- inspect_ai/model/_providers/grok.py +23 -17
- inspect_ai/model/_providers/groq.py +61 -37
- inspect_ai/model/_providers/llama_cpp_python.py +8 -9
- inspect_ai/model/_providers/mistral.py +8 -3
- inspect_ai/model/_providers/ollama.py +8 -9
- inspect_ai/model/_providers/openai.py +53 -157
- inspect_ai/model/_providers/openai_compatible.py +195 -0
- inspect_ai/model/_providers/openrouter.py +4 -15
- inspect_ai/model/_providers/providers.py +11 -0
- inspect_ai/model/_providers/together.py +25 -23
- inspect_ai/model/_trim.py +83 -0
- inspect_ai/solver/_plan.py +5 -3
- inspect_ai/tool/_tool_def.py +8 -2
- inspect_ai/util/__init__.py +3 -0
- inspect_ai/util/_concurrency.py +15 -2
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/RECORD +84 -79
- inspect_ai/_eval/task/rundir.py +0 -78
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/top_level.txt +0 -0
inspect_ai/_eval/score.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Callable, Literal, cast
 import anyio
 
 from inspect_ai._display import display
+from inspect_ai._eval.context import init_task_context
 from inspect_ai._eval.loader import scorer_from_spec
 from inspect_ai._util._async import configured_async_backend, run_coroutine, tg_collect
 from inspect_ai._util.platform import platform_init, running_in_notebook
@@ -14,7 +15,9 @@ from inspect_ai.log import (
     EvalLog,
 )
 from inspect_ai.log._log import EvalMetricDefinition
+from inspect_ai.log._model import model_roles_config_to_model_roles
 from inspect_ai.model import ModelName
+from inspect_ai.model._model import get_model
 from inspect_ai.scorer import Metric, Scorer, Target
 from inspect_ai.scorer._metric import SampleScore
 from inspect_ai.scorer._reducer import (
@@ -122,7 +125,7 @@ async def score_async(
     scores: list[dict[str, SampleScore]] = await tg_collect(
         [
             functools.partial(
-                run_score_task, state, Target(sample.target), scorers, progress
+                run_score_task, log, state, Target(sample.target), scorers, progress
             )
             for (sample, state) in zip(log.samples, states)
         ]
@@ -218,11 +221,25 @@ async def task_score(
 
 
 async def run_score_task(
+    log: EvalLog,
     state: TaskState,
     target: Target,
     scorers: list[Scorer],
     progress: Callable[..., None],
 ) -> dict[str, SampleScore]:
+    # get the model then initialize the async context
+    model = get_model(
+        model=log.eval.model,
+        config=log.plan.config.merge(log.eval.model_generate_config),
+        **log.eval.model_args,
+    )
+
+    # get the model roles
+    model_roles = model_roles_config_to_model_roles(log.eval.model_roles)
+
+    # initialize active model
+    init_task_context(model, model_roles)
+
     results: dict[str, SampleScore] = {}
     for scorer in scorers:
         result = await scorer(state, target)
inspect_ai/_eval/task/log.py
CHANGED
@@ -1,6 +1,5 @@
 from importlib import metadata as importlib_metadata
-from inspect import isgenerator
-from typing import Any, Iterator, Literal, cast
+from typing import Any, Literal, cast
 
 from shortuuid import uuid
 
@@ -34,6 +33,7 @@ from inspect_ai.log._log import (
     EvalScorer,
     eval_config_defaults,
 )
+from inspect_ai.log._model import model_args_for_log, model_roles_to_model_roles_config
 from inspect_ai.log._recorders import Recorder
 from inspect_ai.log._recorders.buffer import SampleBufferDatabase
 from inspect_ai.log._recorders.types import SampleEvent, SampleSummary
@@ -63,6 +63,7 @@ class TaskLogger:
         solver: SolverSpec | None,
         tags: list[str] | None,
         model: Model,
+        model_roles: dict[str, Model] | None,
         dataset: Dataset,
         scorer: list[ScorerSpec] | None,
         metrics: list[MetricSpec] | dict[str, list[MetricSpec]] | None,
@@ -84,17 +85,7 @@ class TaskLogger:
         packages = {PKG_NAME: importlib_metadata.version(PKG_NAME)}
 
         # redact authentication oriented model_args
-        model_args = model_args
-        if "api_key" in model_args:
-            del model_args["api_key"]
-        model_args = {k: v for k, v in model_args.items() if not k.startswith("aws_")}
-
-        # don't try to serialise generators
-        model_args = {
-            k: v
-            for k, v in model_args.items()
-            if not isgenerator(v) and not isinstance(v, Iterator)
-        }
+        model_args = model_args_for_log(model_args)
 
         # cwd_relative_path for sandbox config
         if sandbox and isinstance(sandbox.config, str):
@@ -141,6 +132,7 @@ class TaskLogger:
             model=str(ModelName(model)),
             model_generate_config=model.config,
             model_base_url=model.api.base_url,
+            model_roles=model_roles_to_model_roles_config(model_roles),
             dataset=EvalDataset(
                 name=dataset.name,
                 location=cwd_relative_path(dataset.location),
inspect_ai/_eval/task/run.py
CHANGED
@@ -101,7 +101,6 @@ from .images import (
 )
 from .log import TaskLogger, collect_eval_data, log_start
 from .results import eval_results
-from .rundir import set_task_chdir
 from .sandbox import sandboxenv_context
 from .util import sample_messages, slice_dataset
 
@@ -121,6 +120,7 @@ SAMPLE_TOTAL_PROGRESS_UNITS = 1
 class TaskRunOptions:
     task: Task
     model: Model
+    model_roles: dict[str, Model] | None
     sandbox: SandboxEnvironmentSpec | None
     logger: TaskLogger
     eval_wd: str
@@ -137,6 +137,7 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
     # destructure options
     task = options.task
     model = options.model
+    model_roles = options.model_roles
     sandbox = options.sandbox
     logger = options.logger
     eval_wd = options.eval_wd
@@ -151,156 +152,136 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
     generate_config = task.config.merge(GenerateConfigArgs(**kwargs))
 
     # init task context
-    init_task_context(model, options.task.approval, generate_config)
-
-    #
-[… 25 removed lines not recoverable from this rendering …]
-        log_images=log_images,
-        message_limit=config.message_limit,
-        token_limit=config.token_limit,
-    )
-[… 48 removed lines not recoverable from this rendering …]
+    init_task_context(model, model_roles, options.task.approval, generate_config)
+
+    # track stats and error
+    results: EvalResults | None = None
+    reductions: list[EvalSampleReductions] | None = None
+    stats = EvalStats(started_at=iso_now())
+
+    # handle sample errors (raise as required)
+    sample_error_handler = SampleErrorHandler(config.fail_on_error, len(task.dataset))
+
+    # resolve some config
+    model_name = ModelName(model)
+    epochs = config.epochs if config.epochs else DEFAULT_EPOCHS
+    sandbox_cleanup = config.sandbox_cleanup is not False
+    log_images = config.log_images is not False
+    log_samples = config.log_samples is not False
+
+    # resolve dataset
+    _, samples, states = await resolve_dataset(
+        dataset=task.dataset,
+        model_name=model_name,
+        limit=config.limit,
+        sample_id=config.sample_id,
+        epochs=epochs,
+        log_images=log_images,
+        message_limit=config.message_limit,
+        token_limit=config.token_limit,
+    )
 
-[… 21 removed lines not recoverable from this rendering …]
-                        state=state,
-                        tool_calls=tool_calls,
-                        cache=cache,
-                        config=generate_config.merge(kwargs),
-                    )
+    # resolve the plan (unroll chains)
+    solver = solver or task.solver
+    if isinstance(solver, Plan):
+        plan = solver
+    elif isinstance(solver, Chain):
+        plan = Plan(list(solver), cleanup=task.cleanup, internal=True)
+    else:
+        plan = Plan(unroll(solver), cleanup=task.cleanup, internal=True)
+
+    # add setup solver(s) if specified
+    if task.setup:
+        plan.steps = unroll(task.setup) + plan.steps
+
+    # resolve the scorer
+    score = score and task.scorer is not None
+    scorers: list[Scorer] | None = task.scorer if (score and task.scorer) else None
+    scorer_profiles = (
+        [registry_log_name(scorer) for scorer in scorers if is_registry_object(scorer)]
+        if scorers is not None
+        else ["(none)"]
+    )
 
-[… 2 removed lines not recoverable from this rendering …]
+    # compute an eval directory relative log location if we can
+    if PurePath(logger.location).is_relative_to(PurePath(eval_wd)):
+        log_location = PurePath(logger.location).relative_to(eval_wd).as_posix()
+    else:
+        log_location = logger.location
+
+    # create task profile for display
+    profile = TaskProfile(
+        name=task.name,
+        file=logger.eval.task_file,
+        model=model_name,
+        dataset=task.dataset.name or "(samples)",
+        scorer=", ".join(scorer_profiles),
+        samples=len(samples),
+        steps=len(samples) * SAMPLE_TOTAL_PROGRESS_UNITS,
+        eval_config=config,
+        task_args=logger.eval.task_args,
+        generate_config=generate_config,
+        tags=tags,
+        log_location=log_location,
+    )
 
-[… 3 removed lines not recoverable from this rendering …]
+    with display().task(
+        profile,
+    ) as td:
+        try:
+            # start the log
+            await log_start(logger, plan, generate_config)
+
+            with td.progress() as p:
+                # forward progress
+                def progress(number: int) -> None:
+                    p.update(number)
+
+                # provide solvers a function that they can use to generate output
+                async def generate(
+                    state: TaskState,
+                    tool_calls: Literal["loop", "single", "none"] = "loop",
+                    cache: bool | CachePolicy = False,
+                    **kwargs: Unpack[GenerateConfigArgs],
+                ) -> TaskState:
+                    return await task_generate(
+                        model=model,
+                        state=state,
+                        tool_calls=tool_calls,
+                        cache=cache,
+                        config=generate_config.merge(kwargs),
                     )
 
-[… 2 removed lines not recoverable from this rendering …]
+                # set generate for fork module
+                set_task_generate(generate)
 
-[… 3 removed lines not recoverable from this rendering …]
+                # semaphore to limit concurrency
+                sample_semaphore = create_sample_semaphore(
+                    config, generate_config, model.api
+                )
 
-[… 2 removed lines not recoverable from this rendering …]
-                    display_metrics=profile.eval_config.score_display is not False,
-                )
+                # track when samples complete and update progress as we go
+                progress_results: list[dict[str, SampleScore]] = []
 
-[… 3 removed lines not recoverable from this rendering …]
+                def update_metrics(metrics: list[TaskDisplayMetric]) -> None:
+                    td.update_metrics(metrics)
+                    logger.update_metrics(metrics)
 
-[… 4 removed lines not recoverable from this rendering …]
+                update_metrics_display = update_metrics_display_fn(
+                    update_metrics,
+                    display_metrics=profile.eval_config.score_display is not False,
+                )
 
-[… 3 removed lines not recoverable from this rendering …]
-                        progress_results,
-                        scorers,
-                        task.epochs_reducer,
-                        task.metrics,
-                    )
+                def sample_complete(sample_score: dict[str, SampleScore]) -> None:
+                    # Capture the result
+                    progress_results.append(sample_score)
 
-                    #
-                    td.sample_complete(
+                    # Increment the segment progress
+                    td.sample_complete(
+                        complete=len(progress_results), total=len(samples)
+                    )
 
-                    # Update metrics
+                    # Update metrics
                     update_metrics_display(
                         len(progress_results),
                         progress_results,
@@ -309,127 +290,133 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
                         task.metrics,
                     )
 
-[… 2 removed lines not recoverable from this rendering …]
-                    functools.partial(
-                        task_run_sample,
-                        task_name=task.name,
-                        sample=sample,
-                        state=state,
-                        sandbox=sandbox,
-                        max_sandboxes=config.max_sandboxes,
-                        sandbox_cleanup=sandbox_cleanup,
-                        plan=plan,
-                        scorers=scorers,
-                        generate=generate,
-                        progress=progress,
-                        logger=logger if log_samples else None,
-                        log_images=log_images,
-                        sample_source=sample_source,
-                        sample_error=sample_error_handler,
-                        sample_complete=sample_complete,
-                        fails_on_error=(
-                            config.fail_on_error is None
-                            or config.fail_on_error is True
-                        ),
-                        time_limit=config.time_limit,
-                        working_limit=config.working_limit,
-                        semaphore=sample_semaphore,
-                    )
-                    for (sample, state) in zip(samples, states)
-                ]
-            )
+                # initial progress
+                td.sample_complete(complete=0, total=len(samples))
 
-                #
-[… 7 removed lines not recoverable from this rendering …]
-                results, reductions = eval_results(
-                    samples=profile.samples,
-                    scores=completed_scores,
-                    reducers=task.epochs_reducer,
-                    scorers=scorers,
-                    metrics=task.metrics,
-                )
+                # Update metrics to empty state
+                update_metrics_display(
+                    len(progress_results),
+                    progress_results,
+                    scorers,
+                    task.epochs_reducer,
+                    task.metrics,
+                )
 
-[… 2 removed lines not recoverable from this rendering …]
+                sample_results = await tg_collect(
+                    [
+                        functools.partial(
+                            task_run_sample,
+                            task_name=task.name,
+                            sample=sample,
+                            state=state,
+                            sandbox=sandbox,
+                            max_sandboxes=config.max_sandboxes,
+                            sandbox_cleanup=sandbox_cleanup,
+                            plan=plan,
+                            scorers=scorers,
+                            generate=generate,
+                            progress=progress,
+                            logger=logger if log_samples else None,
+                            log_images=log_images,
+                            sample_source=sample_source,
+                            sample_error=sample_error_handler,
+                            sample_complete=sample_complete,
+                            fails_on_error=(
+                                config.fail_on_error is None
+                                or config.fail_on_error is True
+                            ),
+                            time_limit=config.time_limit,
+                            working_limit=config.working_limit,
+                            semaphore=sample_semaphore,
+                        )
+                        for (sample, state) in zip(samples, states)
+                    ]
+                )
 
-[… 3 removed lines not recoverable from this rendering …]
+                # compute and record metrics if we have scores
+                completed_scores = [
+                    score_dict
+                    for score_dict in sample_results
+                    if isinstance(score_dict, dict)
+                ]
+
+                if len(completed_scores) > 0:
+                    results, reductions = eval_results(
+                        samples=profile.samples,
+                        scores=completed_scores,
+                        reducers=task.epochs_reducer,
+                        scorers=scorers,
+                        metrics=task.metrics,
                     )
 
-[… 7 removed lines not recoverable from this rendering …]
+                # collect eval data
+                collect_eval_data(stats)
+
+                # finish w/ success status
+                eval_log = await logger.log_finish("success", stats, results, reductions)
+
+                # display task summary
+                td.complete(
+                    TaskSuccess(
+                        samples_completed=logger.samples_completed,
+                        stats=stats,
+                        results=results or EvalResults(),
                     )
+                )
 
-[… 4 removed lines not recoverable from this rendering …]
+        except anyio.get_cancelled_exc_class():
+            with anyio.CancelScope(shield=True):
+                # collect eval data
+                collect_eval_data(stats)
 
-[… 4 removed lines not recoverable from this rendering …]
+                # finish w/ cancelled status
+                eval_log = await logger.log_finish(
+                    "cancelled", stats, results, reductions
+                )
 
-[… 2 removed lines not recoverable from this rendering …]
+                # display task cancelled
+                td.complete(TaskCancelled(logger.samples_completed, stats))
 
-[… 8 removed lines not recoverable from this rendering …]
+        except BaseException as ex:
+            if options.debug_errors:
+                raise
+            else:
+                # get exception info
+                type, value, traceback = sys.exc_info()
+                type = type if type else BaseException
+                value = value if value else ex
 
-[… 2 removed lines not recoverable from this rendering …]
+                # build eval error
+                error = eval_error(ex, type, value, traceback)
 
-[… 2 removed lines not recoverable from this rendering …]
+                # collect eval data
+                collect_eval_data(stats)
 
-[… 4 removed lines not recoverable from this rendering …]
+                # finish with error status
+                eval_log = await logger.log_finish(
+                    "error", stats, results, reductions, error
+                )
 
-[… 2 removed lines not recoverable from this rendering …]
-                TaskError(logger.samples_completed, type, value, traceback)
-            )
+                # display it
+                td.complete(TaskError(logger.samples_completed, type, value, traceback))
 
-[… 3 removed lines not recoverable from this rendering …]
+    # notify the view module that an eval just completed
+    # (in case we have a view polling for new evals)
+    view_notify_eval(logger.location)
 
-[… 10 removed lines not recoverable from this rendering …]
-            f"Error occurred sending telemetry: {exception_message(ex)}"
-        )
+    try:
+        if (
+            await send_telemetry("eval_log_location", eval_log.location)
+            == "not_handled"
+        ):
+            # Converting the eval log to JSON is expensive. Only do so if
+            # eval_log_location was not handled.
+            await send_telemetry("eval_log", eval_log_json_str(eval_log))
+    except Exception as ex:
+        py_logger.warning(f"Error occurred sending telemetry: {exception_message(ex)}")
 
-[… 2 removed lines not recoverable from this rendering …]
+    # return eval log
+    return eval_log
 
 
 def update_metrics_display_fn(
@@ -914,7 +901,7 @@ async def resolve_dataset(
     dataset: Dataset,
     model_name: ModelName,
     limit: int | tuple[int, int] | None,
-    sample_id: str | int | list[str | int] | None,
+    sample_id: str | int | list[str] | list[int] | list[str | int] | None,
     epochs: int,
     log_images: bool,
     message_limit: int | None,