inspect-ai 0.3.88__py3-none-any.whl → 0.3.90__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- inspect_ai/_cli/eval.py +16 -0
- inspect_ai/_cli/score.py +1 -12
- inspect_ai/_cli/util.py +4 -2
- inspect_ai/_display/core/footer.py +2 -2
- inspect_ai/_display/plain/display.py +2 -2
- inspect_ai/_eval/context.py +7 -1
- inspect_ai/_eval/eval.py +51 -27
- inspect_ai/_eval/evalset.py +27 -10
- inspect_ai/_eval/loader.py +7 -8
- inspect_ai/_eval/run.py +23 -31
- inspect_ai/_eval/score.py +18 -1
- inspect_ai/_eval/task/log.py +5 -13
- inspect_ai/_eval/task/resolved.py +1 -0
- inspect_ai/_eval/task/run.py +231 -256
- inspect_ai/_eval/task/task.py +25 -2
- inspect_ai/_eval/task/util.py +1 -8
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/json.py +8 -3
- inspect_ai/_util/registry.py +30 -13
- inspect_ai/_view/www/App.css +5 -0
- inspect_ai/_view/www/dist/assets/index.css +71 -36
- inspect_ai/_view/www/dist/assets/index.js +573 -475
- inspect_ai/_view/www/log-schema.json +66 -0
- inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
- inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
- inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
- inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
- inspect_ai/_view/www/src/samples/chat/ChatMessage.tsx +2 -2
- inspect_ai/_view/www/src/samples/chat/tools/ToolInput.module.css +2 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +12 -6
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.module.css +0 -2
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
- inspect_ai/_view/www/src/types/log.d.ts +24 -6
- inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
- inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
- inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
- inspect_ai/agent/_agent.py +12 -0
- inspect_ai/agent/_as_tool.py +1 -1
- inspect_ai/agent/_bridge/bridge.py +9 -2
- inspect_ai/agent/_react.py +142 -74
- inspect_ai/agent/_run.py +13 -2
- inspect_ai/agent/_types.py +6 -0
- inspect_ai/approval/_apply.py +6 -7
- inspect_ai/approval/_approver.py +3 -3
- inspect_ai/approval/_auto.py +2 -2
- inspect_ai/approval/_call.py +20 -4
- inspect_ai/approval/_human/approver.py +3 -3
- inspect_ai/approval/_human/manager.py +2 -2
- inspect_ai/approval/_human/panel.py +3 -3
- inspect_ai/approval/_policy.py +3 -3
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_log.py +23 -2
- inspect_ai/log/_model.py +58 -0
- inspect_ai/log/_recorders/file.py +14 -3
- inspect_ai/log/_transcript.py +3 -0
- inspect_ai/model/__init__.py +2 -0
- inspect_ai/model/_call_tools.py +4 -1
- inspect_ai/model/_model.py +49 -3
- inspect_ai/model/_openai.py +151 -21
- inspect_ai/model/_providers/anthropic.py +20 -12
- inspect_ai/model/_providers/bedrock.py +3 -3
- inspect_ai/model/_providers/cloudflare.py +29 -108
- inspect_ai/model/_providers/google.py +21 -10
- inspect_ai/model/_providers/grok.py +23 -17
- inspect_ai/model/_providers/groq.py +61 -37
- inspect_ai/model/_providers/llama_cpp_python.py +8 -9
- inspect_ai/model/_providers/mistral.py +8 -3
- inspect_ai/model/_providers/ollama.py +8 -9
- inspect_ai/model/_providers/openai.py +53 -157
- inspect_ai/model/_providers/openai_compatible.py +195 -0
- inspect_ai/model/_providers/openrouter.py +4 -15
- inspect_ai/model/_providers/providers.py +11 -0
- inspect_ai/model/_providers/together.py +25 -23
- inspect_ai/model/_trim.py +83 -0
- inspect_ai/solver/_plan.py +5 -3
- inspect_ai/tool/_tool_def.py +8 -2
- inspect_ai/util/__init__.py +3 -0
- inspect_ai/util/_concurrency.py +15 -2
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/RECORD +88 -83
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/WHEEL +1 -1
- inspect_ai/_eval/task/rundir.py +0 -78
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/top_level.txt +0 -0
inspect_ai/_eval/task/run.py
CHANGED
@@ -101,7 +101,6 @@ from .images import (
 )
 from .log import TaskLogger, collect_eval_data, log_start
 from .results import eval_results
-from .rundir import set_task_chdir
 from .sandbox import sandboxenv_context
 from .util import sample_messages, slice_dataset
 
@@ -121,6 +120,7 @@ SAMPLE_TOTAL_PROGRESS_UNITS = 1
 class TaskRunOptions:
     task: Task
     model: Model
+    model_roles: dict[str, Model] | None
     sandbox: SandboxEnvironmentSpec | None
     logger: TaskLogger
     eval_wd: str
@@ -137,6 +137,7 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
     # destructure options
     task = options.task
     model = options.model
+    model_roles = options.model_roles
     sandbox = options.sandbox
     logger = options.logger
     eval_wd = options.eval_wd
@@ -151,156 +152,136 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
     generate_config = task.config.merge(GenerateConfigArgs(**kwargs))
 
     # init task context
-    init_task_context(model, options.task.approval, generate_config)
-    …
-        log_images=log_images,
-        message_limit=config.message_limit,
-        token_limit=config.token_limit,
-    )
-
-    # resolve the plan (unroll chains)
-    solver = solver or task.solver
-    if isinstance(solver, Plan):
-        plan = solver
-    elif isinstance(solver, Chain):
-        plan = Plan(list(solver), cleanup=task.cleanup, internal=True)
-    else:
-        plan = Plan(unroll(solver), cleanup=task.cleanup, internal=True)
-
-    # add setup solver(s) if specified
-    if task.setup:
-        plan.steps = unroll(task.setup) + plan.steps
-
-    # resolve the scorer
-    score = score and task.scorer is not None
-    scorers: list[Scorer] | None = task.scorer if (score and task.scorer) else None
-    scorer_profiles = (
-        [
-            registry_log_name(scorer)
-            for scorer in scorers
-            if is_registry_object(scorer)
-        ]
-        if scorers is not None
-        else ["(none)"]
-    )
-
-    # compute an eval directory relative log location if we can
-    if PurePath(logger.location).is_relative_to(PurePath(eval_wd)):
-        log_location = PurePath(logger.location).relative_to(eval_wd).as_posix()
-    else:
-        log_location = logger.location
-
-    # create task profile for display
-    profile = TaskProfile(
-        name=task.name,
-        file=logger.eval.task_file,
-        model=model_name,
-        dataset=task.dataset.name or "(samples)",
-        scorer=", ".join(scorer_profiles),
-        samples=len(samples),
-        steps=len(samples) * SAMPLE_TOTAL_PROGRESS_UNITS,
-        eval_config=config,
-        task_args=logger.eval.task_args,
-        generate_config=generate_config,
-        tags=tags,
-        log_location=log_location,
-    )
+    init_task_context(model, model_roles, options.task.approval, generate_config)
+
+    # track stats and error
+    results: EvalResults | None = None
+    reductions: list[EvalSampleReductions] | None = None
+    stats = EvalStats(started_at=iso_now())
+
+    # handle sample errors (raise as required)
+    sample_error_handler = SampleErrorHandler(config.fail_on_error, len(task.dataset))
+
+    # resolve some config
+    model_name = ModelName(model)
+    epochs = config.epochs if config.epochs else DEFAULT_EPOCHS
+    sandbox_cleanup = config.sandbox_cleanup is not False
+    log_images = config.log_images is not False
+    log_samples = config.log_samples is not False
+
+    # resolve dataset
+    _, samples, states = await resolve_dataset(
+        dataset=task.dataset,
+        model_name=model_name,
+        limit=config.limit,
+        sample_id=config.sample_id,
+        epochs=epochs,
+        log_images=log_images,
+        message_limit=config.message_limit,
+        token_limit=config.token_limit,
+    )
 
-    …
-                        state=state,
-                        tool_calls=tool_calls,
-                        cache=cache,
-                        config=generate_config.merge(kwargs),
-                    )
+    # resolve the plan (unroll chains)
+    solver = solver or task.solver
+    if isinstance(solver, Plan):
+        plan = solver
+    elif isinstance(solver, Chain):
+        plan = Plan(list(solver), cleanup=task.cleanup, internal=True)
+    else:
+        plan = Plan(unroll(solver), cleanup=task.cleanup, internal=True)
+
+    # add setup solver(s) if specified
+    if task.setup:
+        plan.steps = unroll(task.setup) + plan.steps
+
+    # resolve the scorer
+    score = score and task.scorer is not None
+    scorers: list[Scorer] | None = task.scorer if (score and task.scorer) else None
+    scorer_profiles = (
+        [registry_log_name(scorer) for scorer in scorers if is_registry_object(scorer)]
+        if scorers is not None
+        else ["(none)"]
+    )
 
-    …
+    # compute an eval directory relative log location if we can
+    if PurePath(logger.location).is_relative_to(PurePath(eval_wd)):
+        log_location = PurePath(logger.location).relative_to(eval_wd).as_posix()
+    else:
+        log_location = logger.location
+
+    # create task profile for display
+    profile = TaskProfile(
+        name=task.name,
+        file=logger.eval.task_file,
+        model=model_name,
+        dataset=task.dataset.name or "(samples)",
+        scorer=", ".join(scorer_profiles),
+        samples=len(samples),
+        steps=len(samples) * SAMPLE_TOTAL_PROGRESS_UNITS,
+        eval_config=config,
+        task_args=logger.eval.task_args,
+        generate_config=generate_config,
+        tags=tags,
+        log_location=log_location,
+    )
 
-    …
+    with display().task(
+        profile,
+    ) as td:
+        try:
+            # start the log
+            await log_start(logger, plan, generate_config)
+
+            with td.progress() as p:
+                # forward progress
+                def progress(number: int) -> None:
+                    p.update(number)
+
+                # provide solvers a function that they can use to generate output
+                async def generate(
+                    state: TaskState,
+                    tool_calls: Literal["loop", "single", "none"] = "loop",
+                    cache: bool | CachePolicy = False,
+                    **kwargs: Unpack[GenerateConfigArgs],
+                ) -> TaskState:
+                    return await task_generate(
+                        model=model,
+                        state=state,
+                        tool_calls=tool_calls,
+                        cache=cache,
+                        config=generate_config.merge(kwargs),
                     )
 
-    …
+                # set generate for fork module
+                set_task_generate(generate)
 
-    …
+                # semaphore to limit concurrency
+                sample_semaphore = create_sample_semaphore(
+                    config, generate_config, model.api
+                )
 
-    …
-        display_metrics=profile.eval_config.score_display is not False,
-    )
+                # track when samples complete and update progress as we go
+                progress_results: list[dict[str, SampleScore]] = []
 
-    …
+                def update_metrics(metrics: list[TaskDisplayMetric]) -> None:
+                    td.update_metrics(metrics)
+                    logger.update_metrics(metrics)
 
-    …
+                update_metrics_display = update_metrics_display_fn(
+                    update_metrics,
+                    display_metrics=profile.eval_config.score_display is not False,
+                )
 
-    …
-        progress_results,
-        scorers,
-        task.epochs_reducer,
-        task.metrics,
-    )
+                def sample_complete(sample_score: dict[str, SampleScore]) -> None:
+                    # Capture the result
+                    progress_results.append(sample_score)
 
-    #
-    td.sample_complete(
+                    # Increment the segment progress
+                    td.sample_complete(
+                        complete=len(progress_results), total=len(samples)
+                    )
 
-    # Update metrics
+                    # Update metrics
                     update_metrics_display(
                         len(progress_results),
                         progress_results,
@@ -309,127 +290,133 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
                         task.metrics,
                     )
 
-    …
-            functools.partial(
-                task_run_sample,
-                task_name=task.name,
-                sample=sample,
-                state=state,
-                sandbox=sandbox,
-                max_sandboxes=config.max_sandboxes,
-                sandbox_cleanup=sandbox_cleanup,
-                plan=plan,
-                scorers=scorers,
-                generate=generate,
-                progress=progress,
-                logger=logger if log_samples else None,
-                log_images=log_images,
-                sample_source=sample_source,
-                sample_error=sample_error_handler,
-                sample_complete=sample_complete,
-                fails_on_error=(
-                    config.fail_on_error is None
-                    or config.fail_on_error is True
-                ),
-                time_limit=config.time_limit,
-                working_limit=config.working_limit,
-                semaphore=sample_semaphore,
-            )
-            for (sample, state) in zip(samples, states)
-        ]
-    )
+                # initial progress
+                td.sample_complete(complete=0, total=len(samples))
 
-    #
-    …
-    results, reductions = eval_results(
-        samples=profile.samples,
-        scores=completed_scores,
-        reducers=task.epochs_reducer,
-        scorers=scorers,
-        metrics=task.metrics,
-    )
+                # Update metrics to empty state
+                update_metrics_display(
+                    len(progress_results),
+                    progress_results,
+                    scorers,
+                    task.epochs_reducer,
+                    task.metrics,
+                )
 
-    …
+                sample_results = await tg_collect(
+                    [
+                        functools.partial(
+                            task_run_sample,
+                            task_name=task.name,
+                            sample=sample,
+                            state=state,
+                            sandbox=sandbox,
+                            max_sandboxes=config.max_sandboxes,
+                            sandbox_cleanup=sandbox_cleanup,
+                            plan=plan,
+                            scorers=scorers,
+                            generate=generate,
+                            progress=progress,
+                            logger=logger if log_samples else None,
+                            log_images=log_images,
+                            sample_source=sample_source,
+                            sample_error=sample_error_handler,
+                            sample_complete=sample_complete,
+                            fails_on_error=(
+                                config.fail_on_error is None
+                                or config.fail_on_error is True
+                            ),
+                            time_limit=config.time_limit,
+                            working_limit=config.working_limit,
+                            semaphore=sample_semaphore,
+                        )
+                        for (sample, state) in zip(samples, states)
+                    ]
+                )
 
-    …
+                # compute and record metrics if we have scores
+                completed_scores = [
+                    score_dict
+                    for score_dict in sample_results
+                    if isinstance(score_dict, dict)
+                ]
+
+                if len(completed_scores) > 0:
+                    results, reductions = eval_results(
+                        samples=profile.samples,
+                        scores=completed_scores,
+                        reducers=task.epochs_reducer,
+                        scorers=scorers,
+                        metrics=task.metrics,
                     )
 
-    …
+            # collect eval data
+            collect_eval_data(stats)
+
+            # finish w/ success status
+            eval_log = await logger.log_finish("success", stats, results, reductions)
+
+            # display task summary
+            td.complete(
+                TaskSuccess(
+                    samples_completed=logger.samples_completed,
+                    stats=stats,
+                    results=results or EvalResults(),
                 )
+            )
 
-    …
+        except anyio.get_cancelled_exc_class():
+            with anyio.CancelScope(shield=True):
+                # collect eval data
+                collect_eval_data(stats)
 
-    …
+                # finish w/ cancelled status
+                eval_log = await logger.log_finish(
+                    "cancelled", stats, results, reductions
+                )
 
-    …
+                # display task cancelled
+                td.complete(TaskCancelled(logger.samples_completed, stats))
 
-    …
+        except BaseException as ex:
+            if options.debug_errors:
+                raise
+            else:
+                # get exception info
+                type, value, traceback = sys.exc_info()
+                type = type if type else BaseException
+                value = value if value else ex
 
-    …
+                # build eval error
+                error = eval_error(ex, type, value, traceback)
 
-    …
+                # collect eval data
+                collect_eval_data(stats)
 
-    …
+                # finish with error status
+                eval_log = await logger.log_finish(
+                    "error", stats, results, reductions, error
+                )
 
-    …
-        TaskError(logger.samples_completed, type, value, traceback)
-    )
+                # display it
+                td.complete(TaskError(logger.samples_completed, type, value, traceback))
 
-    …
+    # notify the view module that an eval just completed
+    # (in case we have a view polling for new evals)
+    view_notify_eval(logger.location)
 
-    …
-        f"Error occurred sending telemetry: {exception_message(ex)}"
-    )
+    try:
+        if (
+            await send_telemetry("eval_log_location", eval_log.location)
+            == "not_handled"
+        ):
+            # Converting the eval log to JSON is expensive. Only do so if
+            # eval_log_location was not handled.
+            await send_telemetry("eval_log", eval_log_json_str(eval_log))
+    except Exception as ex:
+        py_logger.warning(f"Error occurred sending telemetry: {exception_message(ex)}")
 
-    …
+    # return eval log
+    return eval_log
 
 
 def update_metrics_display_fn(
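The telemetry block added at the end of task_run() defers expensive work: it reports the log's location first and only serializes the full log to JSON when no handler consumed the location event. A minimal sketch of that pattern; send_telemetry here is a stand-in with an assumed signature modeled on the diff, not the real hook:

```python
import json
from typing import Any, Literal


async def send_telemetry(kind: str, payload: str) -> Literal["handled", "not_handled"]:
    # stand-in recipient; the real hook lives inside inspect_ai
    return "not_handled"


async def notify_eval_complete(log_location: str, log: dict[str, Any]) -> None:
    # cheap path first: just the log location
    if await send_telemetry("eval_log_location", log_location) == "not_handled":
        # expensive path only on demand: render the full log as JSON
        await send_telemetry("eval_log", json.dumps(log))
```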
@@ -655,18 +642,6 @@ async def task_run_sample(
             )
         )
 
-        # sample init event (remove file bodies as they have content or absolute paths)
-        event_sample = sample.model_copy(
-            update=dict(files={k: "" for k in sample.files.keys()})
-            if sample.files
-            else None
-        )
-        transcript()._event(
-            SampleInitEvent(
-                sample=event_sample, state=state_jsonable(state)
-            )
-        )
-
         # set progress for plan then run it
         state = await plan(state, generate)
 
@@ -914,7 +889,7 @@ async def resolve_dataset(
     dataset: Dataset,
     model_name: ModelName,
     limit: int | tuple[int, int] | None,
-    sample_id: str | int | list[str | int] | None,
+    sample_id: str | int | list[str] | list[int] | list[str | int] | None,
     epochs: int,
     log_images: bool,
     message_limit: int | None,
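Taken together, the run.py changes remove the chdir-based run directory machinery (rundir.py is deleted outright) and rebuild the sample loop around structured concurrency: one coroutine per sample is deferred with functools.partial, collected under a task group via tg_collect, and gated by a sample semaphore, with progress and metrics callbacks firing as samples complete. A minimal sketch of that pattern using anyio (which inspect_ai builds on); the function names, semaphore size, and print-based progress are illustrative, not taken from the diff:

```python
import functools

import anyio


async def run_sample(idx: int, limiter: anyio.Semaphore, results: list[int]) -> None:
    async with limiter:  # cf. create_sample_semaphore(...) gating each sample
        await anyio.sleep(0.01)  # stand-in for solving and scoring one sample
        results.append(idx)
        print(f"completed {len(results)} samples")  # cf. the sample_complete callback


async def main() -> None:
    samples = list(range(10))
    limiter = anyio.Semaphore(4)  # illustrative concurrency limit
    results: list[int] = []
    async with anyio.create_task_group() as tg:
        for idx in samples:
            # defer each call with functools.partial, mirroring the tg_collect usage
            tg.start_soon(functools.partial(run_sample, idx, limiter, results))
    # the task group exits only after every sample coroutine has finished
    print(f"all {len(results)} samples complete")


anyio.run(main)
```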
inspect_ai/_eval/task/task.py
CHANGED
@@ -54,6 +54,7 @@ class Task:
         metrics: list[Metric] | dict[str, list[Metric]] | None = None,
         model: str | Model | None = None,
         config: GenerateConfig = GenerateConfig(),
+        model_roles: dict[str, str | Model] | None = None,
         sandbox: SandboxEnvironmentType | None = None,
         approval: str | list[ApprovalPolicy] | None = None,
         epochs: int | Epochs | None = None,
@@ -79,7 +80,8 @@ class Task:
             scorer: Scorer used to evaluate model output.
             metrics: Alternative metrics (overrides the metrics provided by the specified scorer).
             model: Default model for task (Optional, defaults to eval model).
-            config: Model generation config
+            config: Model generation config for default model (does not apply to model roles)
+            model_roles: Named roles for use in `get_model()`.
             sandbox: Sandbox environment type (or optionally a str or tuple with a shorthand spec)
             approval: Tool use approval policies.
                 Either a path to an approval policy config file or a list of approval policies. Defaults to no approval policy.
@@ -136,6 +138,7 @@ class Task:
         self.metrics = metrics
         self.model = resolve_model(model)
         self.config = config
+        self.model_roles = resolve_model_roles(model_roles)
         self.sandbox = resolve_sandbox_environment(sandbox)
         self.approval = resolve_approval(approval)
         epochs = resolve_epochs(epochs)
@@ -185,6 +188,7 @@ def task_with(
     metrics: list[Metric] | dict[str, list[Metric]] | None | NotGiven = NOT_GIVEN,
     model: str | Model | NotGiven = NOT_GIVEN,
     config: GenerateConfig | NotGiven = NOT_GIVEN,
+    model_roles: dict[str, str | Model] | NotGiven = NOT_GIVEN,
     sandbox: SandboxEnvironmentType | None | NotGiven = NOT_GIVEN,
     approval: str | list[ApprovalPolicy] | None | NotGiven = NOT_GIVEN,
     epochs: int | Epochs | None | NotGiven = NOT_GIVEN,
@@ -214,7 +218,8 @@ def task_with(
         scorer: Scorer used to evaluate model output.
         metrics: Alternative metrics (overrides the metrics provided by the specified scorer).
         model: Default model for task (Optional, defaults to eval model).
-        config: Model generation config
+        config: Model generation config for default model (does not apply to model roles)
+        model_roles: Named roles for use in `get_model()`.
         sandbox: Sandbox environment type (or optionally a str or tuple with a shorthand spec)
         approval: Tool use approval policies.
             Either a path to an approval policy config file or a list of approval policies. Defaults to no approval policy.
@@ -257,6 +262,8 @@ def task_with(
         task.model = resolve_model(model)
     if not isinstance(config, NotGiven):
         task.config = config
+    if not isinstance(model_roles, NotGiven):
+        task.model_roles = resolve_model_roles(model_roles)
     if not isinstance(sandbox, NotGiven):
         task.sandbox = resolve_sandbox_environment(sandbox)
     if not isinstance(approval, NotGiven):
@@ -315,6 +322,7 @@ class PreviousTask:
     task: str | Task
     task_args: dict[str, Any]
     model: Model | None
+    model_roles: dict[str, Model] | None
     log: EvalLog
 
 
@@ -365,6 +373,21 @@ def resolve_model(model: str | Model | None) -> Model | None:
     return model
 
 
+def resolve_model_roles(
+    model_roles: dict[str, str | Model] | None,
+) -> dict[str, Model] | None:
+    if model_roles is not None:
+        resolved_model_roles = {
+            k: get_model(v, memoize=False) if isinstance(v, str) else v
+            for k, v in model_roles.items()
+        }
+        for k, v in resolved_model_roles.items():
+            v._set_role(k)
+        return resolved_model_roles
+    else:
+        return None
+
+
 def resolve_scorer(scorer: Scorer | list[Scorer] | None) -> list[Scorer] | None:
     return (
         scorer if isinstance(scorer, list) else [scorer] if scorer is not None else None
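The task.py changes thread the new model_roles parameter through Task, task_with(), and PreviousTask, resolving string values to Model instances via the new resolve_model_roles() helper (which also tags each model with its role via _set_role()). A hypothetical usage sketch; the dataset, solver, scorer, and grader model name are placeholders, not taken from the diff:

```python
from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.scorer import model_graded_qa
from inspect_ai.solver import generate


@task
def hello() -> Task:
    return Task(
        dataset=[Sample(input="Say hello.", target="hello")],
        solver=generate(),
        scorer=model_graded_qa(),
        # string values are resolved to Model instances by resolve_model_roles();
        # per the updated docstring, the role can then be looked up with get_model()
        model_roles={"grader": "openai/gpt-4o-mini"},
    )
```

Note from the updated docstring that the task's config applies only to the default model, not to role models.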
inspect_ai/_eval/task/util.py
CHANGED
@@ -25,13 +25,6 @@ def task_run_dir(task: Task) -> str:
     return getattr(task, TASK_RUN_DIR_ATTR, os.getcwd())
 
 
-def task_chdir(task: Task) -> str | None:
-    if task.attribs.get("chdir", False) is True:
-        return task_run_dir(task)
-    else:
-        return None
-
-
 def task_file(task: Task, relative: bool = False) -> str | None:
     file = cast(str | None, getattr(task, TASK_FILE_ATTR, None))
     if file:
@@ -46,7 +39,7 @@ def task_file(task: Task, relative: bool = False) -> str | None:
 def slice_dataset(
     dataset: Dataset,
     limit: int | tuple[int, int] | None,
-    sample_id: str | int | list[str | int] | None,
+    sample_id: str | int | list[str] | list[int] | list[str | int] | None,
 ) -> Dataset:
     def normalise(id: str | int | None) -> str:
         if isinstance(id, str) and id.isdigit():