inspect_ai-0.3.51-py3-none-any.whl → inspect_ai-0.3.53-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +44 -2
- inspect_ai/_display/core/config.py +4 -0
- inspect_ai/_display/core/panel.py +1 -1
- inspect_ai/_display/core/progress.py +9 -3
- inspect_ai/_display/core/results.py +8 -4
- inspect_ai/_display/textual/widgets/task_detail.py +45 -13
- inspect_ai/_display/textual/widgets/tasks.py +86 -5
- inspect_ai/_display/textual/widgets/transcript.py +4 -17
- inspect_ai/_eval/eval.py +29 -1
- inspect_ai/_eval/evalset.py +7 -0
- inspect_ai/_eval/registry.py +2 -2
- inspect_ai/_eval/task/log.py +6 -1
- inspect_ai/_eval/task/results.py +22 -4
- inspect_ai/_eval/task/run.py +18 -12
- inspect_ai/_eval/task/sandbox.py +72 -43
- inspect_ai/_eval/task/task.py +4 -0
- inspect_ai/_eval/task/util.py +17 -6
- inspect_ai/_util/logger.py +10 -2
- inspect_ai/_util/samples.py +7 -0
- inspect_ai/_util/transcript.py +8 -0
- inspect_ai/_view/www/App.css +13 -0
- inspect_ai/_view/www/dist/assets/index.css +13 -0
- inspect_ai/_view/www/dist/assets/index.js +105 -55
- inspect_ai/_view/www/src/App.mjs +31 -6
- inspect_ai/_view/www/src/Types.mjs +6 -0
- inspect_ai/_view/www/src/components/JsonPanel.mjs +11 -17
- inspect_ai/_view/www/src/components/MessageContent.mjs +9 -2
- inspect_ai/_view/www/src/components/Tools.mjs +46 -18
- inspect_ai/_view/www/src/navbar/Navbar.mjs +12 -0
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +18 -5
- inspect_ai/_view/www/src/samples/SampleList.mjs +2 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +2 -2
- inspect_ai/log/_log.py +6 -0
- inspect_ai/log/_recorders/eval.py +8 -7
- inspect_ai/model/_call_tools.py +2 -6
- inspect_ai/model/_generate_config.py +6 -0
- inspect_ai/model/_model.py +18 -4
- inspect_ai/model/_providers/azureai.py +22 -2
- inspect_ai/model/_providers/bedrock.py +17 -1
- inspect_ai/model/_providers/hf.py +1 -1
- inspect_ai/model/_providers/openai.py +32 -8
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/vllm.py +1 -1
- inspect_ai/model/_render.py +7 -6
- inspect_ai/model/_trace.py +1 -1
- inspect_ai/solver/_basic_agent.py +8 -1
- inspect_ai/tool/_tool_transcript.py +28 -0
- inspect_ai/util/_sandbox/context.py +1 -2
- inspect_ai/util/_sandbox/docker/config.py +8 -10
- inspect_ai/util/_sandbox/docker/docker.py +9 -5
- inspect_ai/util/_sandbox/docker/util.py +3 -3
- inspect_ai/util/_sandbox/environment.py +7 -2
- inspect_ai/util/_sandbox/limits.py +1 -1
- inspect_ai/util/_sandbox/local.py +8 -9
- {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/METADATA +2 -4
- {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/RECORD +60 -59
- {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/top_level.txt +0 -0
inspect_ai/_eval/eval.py
CHANGED
@@ -61,6 +61,7 @@ def eval(
     log_dir: str | None = None,
     log_format: Literal["eval", "json"] | None = None,
     limit: int | tuple[int, int] | None = None,
+    sample_id: str | int | list[str | int] | None = None,
     epochs: int | Epochs | None = None,
     fail_on_error: bool | float | None = None,
     debug_errors: bool | None = None,
@@ -70,6 +71,7 @@ def eval(
     max_samples: int | None = None,
     max_tasks: int | None = None,
     max_subprocesses: int | None = None,
+    max_sandboxes: int | None = None,
     log_samples: bool | None = None,
     log_images: bool | None = None,
     log_buffer: int | None = None,
@@ -110,6 +112,7 @@ def eval(
            to "eval", the native high-performance format).
        limit (int | tuple[int, int] | None): Limit evaluated samples
            (defaults to all samples).
+       sample_id (str | int | list[str | int] | None): Evaluate specific sample(s) from the dataset.
        epochs (int | Epochs | None): Epochs to repeat samples for and optional score
            reducer function(s) used to combine sample scores (defaults to "mean")
        fail_on_error (bool | float | None): `True` to fail on first sample error
@@ -127,6 +130,8 @@ def eval(
            (default is 1)
        max_subprocesses (int | None): Maximum number of subprocesses to
            run in parallel (default is os.cpu_count())
+       max_sandboxes (int | None): Maximum number of sandboxes (per-provider)
+           to run in parallel.
        log_samples: (bool | None): Log detailed samples and scores (defaults to True)
        log_images: (bool | None): Log base64 encoded version of images,
            even if specified as a filename or URL (defaults to False)
@@ -163,6 +168,7 @@ def eval(
            log_dir=log_dir,
            log_format=log_format,
            limit=limit,
+           sample_id=sample_id,
            epochs=epochs,
            fail_on_error=fail_on_error,
            debug_errors=debug_errors,
@@ -172,6 +178,7 @@ def eval(
            max_samples=max_samples,
            max_tasks=max_tasks,
            max_subprocesses=max_subprocesses,
+           max_sandboxes=max_sandboxes,
            log_samples=log_samples,
            log_images=log_images,
            log_buffer=log_buffer,
@@ -198,6 +205,7 @@ async def eval_async(
     log_dir: str | None = None,
     log_format: Literal["eval", "json"] | None = None,
     limit: int | tuple[int, int] | None = None,
+    sample_id: str | int | list[str | int] | None = None,
     epochs: int | Epochs | None = None,
     fail_on_error: bool | float | None = None,
     debug_errors: bool | None = None,
@@ -207,6 +215,7 @@ async def eval_async(
     max_samples: int | None = None,
     max_tasks: int | None = None,
     max_subprocesses: int | None = None,
+    max_sandboxes: int | None = None,
     log_samples: bool | None = None,
     log_images: bool | None = None,
     log_buffer: int | None = None,
@@ -245,8 +254,9 @@ async def eval_async(
            (defaults to file log in ./logs directory).
        log_format (Literal["eval", "json"] | None): Format for writing log files (defaults
            to "eval", the native high-performance format).
-       limit (int | tuple[int, int] | None): Limit evaluated samples
+       limit (str | int | list[str | int] | None): Limit evaluated samples
            (defaults to all samples).
+       sample_id (str | list[str] | None): Evaluate specific sample(s) from the dataset.
        epochs (int | Epochs | None): Epochs to repeat samples for and optional score
            reducer function(s) used to combine sample scores (defaults to "mean")
        fail_on_error (bool | float | None): `True` to fail on first sample error
@@ -263,6 +273,8 @@ async def eval_async(
            (default is 1)
        max_subprocesses (int | None): Maximum number of subprocesses to
            run in parallel (default is os.cpu_count())
+       max_sandboxes (int | None): Maximum number of sandboxes (per-provider)
+           to run in parallel.
        log_samples: (bool | None): Log detailed samples and scores (defaults to True)
        log_images: (bool | None): Log base64 encoded version of images,
            even if specified as a filename or URL (defaults to False)
@@ -335,6 +347,10 @@ async def eval_async(
     # resolve solver
     solver = chain(solver) if isinstance(solver, list) else solver

+    # ensure consistency of limit and sample_id
+    if sample_id is not None and limit is not None:
+        raise ValueError("You cannot specify both sample_id and limit.")
+
     # resolve epochs
     if isinstance(epochs, int):
         epochs = Epochs(epochs)
@@ -345,6 +361,7 @@ async def eval_async(
     epochs_reducer = epochs.reducer if epochs else None
     eval_config = EvalConfig(
         limit=limit,
+        sample_id=sample_id,
         epochs=epochs.epochs if epochs else None,
         epochs_reducer=reducer_log_names(epochs_reducer)
         if epochs_reducer
@@ -358,6 +375,7 @@ async def eval_async(
         max_samples=max_samples,
         max_tasks=max_tasks,
         max_subprocesses=max_subprocesses,
+        max_sandboxes=max_sandboxes,
         sandbox_cleanup=sandbox_cleanup,
         log_samples=log_samples,
         log_images=log_images,
@@ -440,6 +458,7 @@ def eval_retry(
     max_samples: int | None = None,
     max_tasks: int | None = None,
     max_subprocesses: int | None = None,
+    max_sandboxes: int | None = None,
     sandbox_cleanup: bool | None = None,
     trace: bool | None = None,
     fail_on_error: bool | float | None = None,
@@ -470,6 +489,8 @@ def eval_retry(
            (default is 1)
        max_subprocesses (int | None): Maximum number of subprocesses to
            run in parallel (default is os.cpu_count())
+       max_sandboxes (int | None): Maximum number of sandboxes (per-provider)
+           to run in parallel.
        sandbox_cleanup (bool | None): Cleanup sandbox environments after task completes
            (defaults to True)
        trace (bool | None): Trace message interactions with evaluated model to terminal.
@@ -512,6 +533,7 @@ def eval_retry(
         max_samples=max_samples,
         max_tasks=max_tasks,
         max_subprocesses=max_subprocesses,
+        max_sandboxes=max_sandboxes,
         sandbox_cleanup=sandbox_cleanup,
         fail_on_error=fail_on_error,
         debug_errors=debug_errors,
@@ -535,6 +557,7 @@ async def eval_retry_async(
     max_samples: int | None = None,
     max_tasks: int | None = None,
     max_subprocesses: int | None = None,
+    max_sandboxes: int | None = None,
     sandbox_cleanup: bool | None = None,
     fail_on_error: bool | float | None = None,
     debug_errors: bool | None = None,
@@ -564,6 +587,7 @@ async def eval_retry_async(
            (default is 1)
        max_subprocesses (int): Maximum number of subprocesses to
            run in parallel (default is os.cpu_count())
+       max_sandboxes (int): Maximum number of sandboxes (per-provider) to run in parallel.
        sandbox_cleanup (bool | None): Cleanup sandbox environments after task completes
            (defaults to True)
        fail_on_error (bool | float | None): `True` to fail on first sample error
@@ -642,6 +666,7 @@ async def eval_retry_async(
         task_args = eval_log.eval.task_args
         tags = eval_log.eval.tags
         limit = eval_log.eval.config.limit
+        sample_id = eval_log.eval.config.sample_id
         epochs = (
             Epochs(eval_log.eval.config.epochs, eval_log.eval.config.epochs_reducer)
             if eval_log.eval.config.epochs
@@ -654,6 +679,7 @@ async def eval_retry_async(
         max_samples = max_samples or eval_log.eval.config.max_samples
         max_tasks = max_tasks or eval_log.eval.config.max_tasks
         max_subprocesses = max_subprocesses or eval_log.eval.config.max_subprocesses
+        max_sandboxes = max_sandboxes or eval_log.eval.config.max_sandboxes
         sandbox_cleanup = (
             sandbox_cleanup
             if sandbox_cleanup is not None
@@ -699,6 +725,7 @@ async def eval_retry_async(
            log_dir=log_dir,
            log_format=log_format,
            limit=limit,
+           sample_id=sample_id,
            epochs=epochs,
            fail_on_error=fail_on_error,
            debug_errors=debug_errors,
@@ -708,6 +735,7 @@ async def eval_retry_async(
            max_samples=max_samples,
            max_tasks=max_tasks,
            max_subprocesses=max_subprocesses,
+           max_sandboxes=max_sandboxes,
            log_samples=log_samples,
            log_images=log_images,
            log_buffer=log_buffer,
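Note: taken together, these hunks add `sample_id` (mutually exclusive with `limit`) and `max_sandboxes` keyword arguments to `eval()`, `eval_async()`, and the retry variants; the same two options also flow through `eval_set()` in the next file. A minimal usage sketch against the new signature; the task name and model below are placeholders, not part of this diff:

from inspect_ai import eval

# evaluate just two samples by id, capping concurrent sandboxes per provider at 8
# (passing both sample_id and limit now raises a ValueError)
logs = eval(
    "my_task",              # hypothetical task name
    model="openai/gpt-4o",  # any configured model
    sample_id=["sample-1", 42],
    max_sandboxes=8,
)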
inspect_ai/_eval/evalset.py
CHANGED
@@ -65,6 +65,7 @@ def eval_set(
     log_level_transcript: str | None = None,
     log_format: Literal["eval", "json"] | None = None,
     limit: int | tuple[int, int] | None = None,
+    sample_id: str | int | list[str | int] | None = None,
     epochs: int | Epochs | None = None,
     fail_on_error: bool | float | None = None,
     debug_errors: bool | None = None,
@@ -74,6 +75,7 @@ def eval_set(
     max_samples: int | None = None,
     max_tasks: int | None = None,
     max_subprocesses: int | None = None,
+    max_sandboxes: int | None = None,
     log_samples: bool | None = None,
     log_images: bool | None = None,
     log_buffer: int | None = None,
@@ -125,6 +127,7 @@ def eval_set(
            log files (defaults to "eval", the native high-performance format).
        limit (int | tuple[int, int] | None): Limit evaluated samples
            (defaults to all samples).
+       sample_id (str | int | list[str | int] | None): Evaluate specific sample(s) from the dataset.
        epochs (int | Epochs | None): Epochs to repeat samples for and optional score
            reducer function(s) used to combine sample scores (defaults to "mean")
        fail_on_error (bool | float | None): `True` to fail on first sample error
@@ -142,6 +145,8 @@ def eval_set(
            (default is 1)
        max_subprocesses (int | None): Maximum number of subprocesses to
            run in parallel (default is os.cpu_count())
+       max_sandboxes (int | None): Maximum number of sandboxes (per-provider)
+           to run in parallel.
        log_samples: (bool | None): Log detailed samples and scores (defaults to True)
        log_images: (bool | None): Log base64 encoded version of images,
            even if specified as a filename or URL (defaults to False)
@@ -181,6 +186,7 @@ def eval_set(
            log_dir=log_dir,
            log_format=log_format,
            limit=limit,
+           sample_id=sample_id,
            epochs=epochs,
            fail_on_error=fail_on_error,
            debug_errors=debug_errors,
@@ -190,6 +196,7 @@ def eval_set(
            max_samples=max_samples,
            max_tasks=max_tasks,
            max_subprocesses=max_subprocesses,
+           max_sandboxes=max_sandboxes,
            log_samples=log_samples,
            log_images=log_images,
            log_buffer=log_buffer,
inspect_ai/_eval/registry.py
CHANGED
@@ -146,8 +146,8 @@ def task(*args: Any, name: str | None = None, **attribs: Any) -> Any:
         # module import, so set its task file and run dir
         if get_installed_package_name(task_type) is None:
             module = inspect.getmodule(task_type)
-            if module and module.__file__:
-                file = Path(module.__file__)
+            if module and hasattr(module, "__file__"):
+                file = Path(getattr(module, "__file__"))
                 setattr(task_instance, TASK_FILE_ATTR, file.as_posix())
                 setattr(task_instance, TASK_RUN_DIR_ATTR, file.parent.as_posix())

inspect_ai/_eval/task/log.py
CHANGED
@@ -83,7 +83,12 @@ class TaskLogger:
         # ensure that the dataset has sample ids and record them
         sample_ids = cast(
             list[int | str],
-            [
+            [
+                sample.id
+                for sample in slice_dataset(
+                    dataset, eval_config.limit, eval_config.sample_id
+                )
+            ],
         )

         # create eval spec
inspect_ai/_eval/task/results.py
CHANGED
@@ -267,10 +267,28 @@ def scorers_from_metric_dict(
            value = target_metric(metric_scores)
        else:
            value = float("Nan")
-
-
-
-            )
+
+        # convert the value to a float (either by expanding the dict or array)
+        # or by casting to a float
+        if isinstance(value, dict):
+            for key, val in value.items():
+                name = f"{metric_name}_{key}"
+                result_metrics[name] = EvalMetric(
+                    name=name,
+                    value=cast(float, val),
+                )
+        elif isinstance(value, list):
+            for idx, item in enumerate(value):
+                name = f"{metric_name}_{idx}"
+                result_metrics[name] = EvalMetric(
+                    name=name,
+                    value=cast(float, item),
+                )
+        else:
+            result_metrics[metric_name] = EvalMetric(
+                name=metric_name,
+                value=cast(float, value),
+            )

        # create a scorer result for this metric
        # TODO: What if there is separate simple scorer which has a name collision with
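Note: the new branch flattens dict- and list-valued metric results into one named float per entry ("<metric>_<key>" / "<metric>_<index>"), falling back to a plain cast for scalars. A standalone sketch of the same naming scheme, using plain dicts instead of EvalMetric purely for illustration:

def flatten_metric(metric_name: str, value) -> dict[str, float]:
    # dicts expand to "<metric>_<key>", lists to "<metric>_<index>",
    # scalars keep the metric name (mirrors the branch above)
    if isinstance(value, dict):
        return {f"{metric_name}_{k}": float(v) for k, v in value.items()}
    elif isinstance(value, list):
        return {f"{metric_name}_{i}": float(v) for i, v in enumerate(value)}
    else:
        return {metric_name: float(value)}

assert flatten_metric("accuracy", {"easy": 0.9, "hard": 0.4}) == {
    "accuracy_easy": 0.9,
    "accuracy_hard": 0.4,
}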
inspect_ai/_eval/task/run.py
CHANGED
@@ -162,6 +162,7 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
         dataset=task.dataset,
         model_name=model_name,
         limit=config.limit,
+        sample_id=config.sample_id,
         epochs=epochs,
         log_images=log_images,
         message_limit=config.message_limit,
@@ -177,6 +178,10 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
     else:
         plan = Plan(unroll(solver), internal=True)

+    # add setup solver(s) if specified
+    if task.setup:
+        plan.steps = unroll(task.setup) + plan.steps
+
     # reaolve the scorer
     score = score and task.scorer is not None
     scorers: list[Scorer] | None = task.scorer if (score and task.scorer) else None
@@ -274,6 +279,7 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
                     sample=sample,
                     state=state,
                     sandbox=sandbox,
+                    max_sandboxes=config.max_sandboxes,
                     sandbox_cleanup=sandbox_cleanup,
                     plan=plan,
                     scorers=scorers,
@@ -455,6 +461,7 @@ async def task_run_sample(
     sample: Sample,
     state: TaskState,
     sandbox: SandboxEnvironmentSpec | None,
+    max_sandboxes: int | None,
     sandbox_cleanup: bool,
     plan: Plan,
     scorers: list[Scorer] | None,
@@ -481,8 +488,8 @@
         await logger.log_sample(previous_sample, flush=False)

         # return score
-
-
+        sample_scores = (
+            {
                 key: SampleScore(
                     sample_id=previous_sample.id,
                     value=score.value,
@@ -492,8 +499,11 @@
                 )
                 for key, score in previous_sample.scores.items()
             }
-
-
+            if previous_sample.scores
+            else {}
+        )
+        sample_complete(sample_scores)
+        return sample_scores

     # use semaphore if provided
     semaphore_cm: asyncio.Semaphore | contextlib.AbstractAsyncContextManager[None] = (
@@ -509,7 +519,7 @@

     # use sandbox if provided
     sandboxenv_cm = (
-        sandboxenv_context(task_name, sandbox, sandbox_cleanup, sample)
+        sandboxenv_context(task_name, sandbox, max_sandboxes, sandbox_cleanup, sample)
         if sandbox or sample.sandbox is not None
         else contextlib.nullcontext()
     )
@@ -748,13 +758,14 @@ async def resolve_dataset(
     dataset: Dataset,
     model_name: ModelName,
     limit: int | tuple[int, int] | None,
+    sample_id: str | int | list[str | int] | None,
     epochs: int,
     log_images: bool,
     message_limit: int | None,
     token_limit: int | None,
 ) -> tuple[Dataset, list[Sample], list[TaskState]]:
-    #
-    dataset = slice_dataset(dataset, limit)
+    # slice dataset
+    dataset = slice_dataset(dataset, limit, sample_id)

     # apply epochs (deepcopy the samples so they remain independent)
     samples: list[Sample] = []
@@ -864,10 +875,5 @@ def create_sample_semaphore(
         else DEFAULT_MAX_CONNECTIONS
     )

-    # if max_tasks is specified and max_samples is less
-    # than max_tasks then bump it up
-    if config.max_tasks is not None:
-        max_samples = max(max_samples, config.max_tasks)
-
     # return the semaphore
     return asyncio.Semaphore(max_samples)
inspect_ai/_eval/task/sandbox.py
CHANGED
@@ -1,7 +1,8 @@
 import asyncio
 import base64
 import contextlib
-from
+from random import random
+from typing import AsyncGenerator, Callable, NamedTuple, cast

 from inspect_ai._eval.task.task import Task
 from inspect_ai._eval.task.util import task_run_dir
@@ -9,6 +10,7 @@ from inspect_ai._util.file import file, filesystem
 from inspect_ai._util.registry import registry_unqualified_name
 from inspect_ai._util.url import data_uri_to_base64, is_data_uri
 from inspect_ai.dataset import Sample
+from inspect_ai.util._concurrency import concurrency
 from inspect_ai.util._sandbox.context import (
     cleanup_sandbox_environments_sample,
     init_sandbox_environments_sample,
@@ -18,12 +20,14 @@ from inspect_ai.util._sandbox.environment import (
     SandboxEnvironmentConfigType,
     SandboxEnvironmentSpec,
 )
+from inspect_ai.util._sandbox.registry import registry_find_sandboxenv


 @contextlib.asynccontextmanager
 async def sandboxenv_context(
     task_name: str,
     sandbox: SandboxEnvironmentSpec | None,
+    max_sandboxes: int | None,
     cleanup: bool,
     sample: Sample,
 ) -> AsyncGenerator[None, None]:
@@ -32,52 +36,77 @@ async def sandboxenv_context(
     if not sandbox:
         raise ValueError("sandboxenv_context called with no sandbox specified")

-    # read files from sample
-    files: dict[str, bytes] = {}
-    if sample.files:
-        for path, contents in sample.files.items():
-            files[path] = read_sandboxenv_file(contents)
-
-    # read setup script from sample (add bash shebang if necessary)
-    setup: bytes | None = None
-    if sample.setup:
-        setup = read_sandboxenv_file(sample.setup)
-        setup_str = setup.decode(encoding="utf-8")
-        if not setup_str.strip().startswith("#!"):
-            setup_str = f"#!/usr/bin/env bash\n\n{setup_str}"
-        setup = setup_str.encode(encoding="utf-8")
-
-    interrupted = False
-    environments: dict[str, SandboxEnvironment] | None = None
-    try:
-        # initialize sandbox environment,
-        environments = await init_sandbox_environments_sample(
-            type=sandbox.type,
-            task_name=registry_unqualified_name(task_name),
-            config=sandbox.config,
-            files=files,
-            setup=setup,
-            metadata=sample.metadata if sample.metadata else {},
-        )
-
-        # run sample
-        yield
-
-    except asyncio.CancelledError as ex:
-        interrupted = True
-        raise ex
+    # get sandboxenv_type
+    sandboxenv_type = registry_find_sandboxenv(sandbox.type)

-    finally:
-        # cleanup sandbox environment
-        if environments and cleanup:
-            await cleanup_sandbox_environments_sample(
-                type=sandbox.type,
-                task_name=task_name,
+    # see if there is a max_sandboxes in play (passed or from type)
+    if max_sandboxes is None:
+        default_concurrency_fn = cast(
+            Callable[[], int | None], getattr(sandboxenv_type, "default_concurrency")
+        )
+        max_sandboxes = default_concurrency_fn()
+
+    # if we are enforcing max_sandboxes, then when samples are scheduled they may
+    # not get interleaved properly across tasks (because the first task will come
+    # in and grab all of the sandboxes). Therefore, in this case we wait a random
+    # delay so that all tasks/samples have an equal shot at getting scheduled.
+    if max_sandboxes is not None:
+        await asyncio.sleep(random())
+
+    # enforce concurrency if required
+    sandboxes_cm = (
+        concurrency(sandbox.type, max_sandboxes, f"sandboxes/{sandbox.type}")
+        if max_sandboxes is not None
+        else contextlib.nullcontext()
+    )
+
+    async with sandboxes_cm:
+        # read files from sample
+        files: dict[str, bytes] = {}
+        if sample.files:
+            for path, contents in sample.files.items():
+                files[path] = read_sandboxenv_file(contents)
+
+        # read setup script from sample (add bash shebang if necessary)
+        setup: bytes | None = None
+        if sample.setup:
+            setup = read_sandboxenv_file(sample.setup)
+            setup_str = setup.decode(encoding="utf-8")
+            if not setup_str.strip().startswith("#!"):
+                setup_str = f"#!/usr/bin/env bash\n\n{setup_str}"
+            setup = setup_str.encode(encoding="utf-8")
+
+        interrupted = False
+        environments: dict[str, SandboxEnvironment] | None = None
+        try:
+            # initialize sandbox environment,
+            environments = await init_sandbox_environments_sample(
+                sandboxenv_type=sandboxenv_type,
+                task_name=registry_unqualified_name(task_name),
                 config=sandbox.config,
-                environments=environments,
-                interrupted=interrupted,
+                files=files,
+                setup=setup,
+                metadata=sample.metadata if sample.metadata else {},
             )

+            # run sample
+            yield
+
+        except asyncio.CancelledError as ex:
+            interrupted = True
+            raise ex
+
+        finally:
+            # cleanup sandbox environment
+            if environments and cleanup:
+                await cleanup_sandbox_environments_sample(
+                    type=sandbox.type,
+                    task_name=task_name,
+                    config=sandbox.config,
+                    environments=environments,
+                    interrupted=interrupted,
+                )
+

 def read_sandboxenv_file(contents: str) -> bytes:
     if is_data_uri(contents):
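Note: the reworked context manager resolves the sandbox type up front, derives a default max_sandboxes from the provider's default_concurrency, sleeps a random sub-second delay so one task cannot monopolize every slot, and then runs sample setup and teardown inside a per-provider concurrency gate. The general pattern, sketched with a plain asyncio.Semaphore rather than inspect's internal concurrency helper (all names here are illustrative):

import asyncio
import contextlib
from random import random

_limits: dict[str, asyncio.Semaphore] = {}

@contextlib.asynccontextmanager
async def limited(provider: str, max_concurrent: int | None):
    # no limit configured: run straight through
    if max_concurrent is None:
        yield
        return
    # random jitter so the first task does not grab all slots at once
    await asyncio.sleep(random())
    sem = _limits.setdefault(provider, asyncio.Semaphore(max_concurrent))
    async with sem:
        yield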
inspect_ai/_eval/task/task.py
CHANGED
@@ -39,6 +39,8 @@ class Task:

     Args:
         dataset (Dataset | Sequence[Sample]): Dataset to evaluate
+        setup: (Solver | list[Solver] | None): Setup step (always run
+            even when the main `solver` is replaced).
         solver: (Solver | list[Solver]): Solver or list of solvers.
             Defaults to generate(), a normal call to the model.
         scorer: (Scorer | list[Scorer] | None): Scorer used to evaluate model output.
@@ -68,6 +70,7 @@ class Task:
     def __init__(
         self,
         dataset: Dataset | Sequence[Sample] | None = None,
+        setup: Solver | list[Solver] | None = None,
         solver: Solver | list[Solver] = generate(),
         scorer: Scorer | list[Scorer] | None = None,
         metrics: list[Metric] | dict[str, list[Metric]] | None = None,
@@ -119,6 +122,7 @@ class Task:
         self.dataset: Dataset = (
             dataset if isinstance(dataset, Dataset) else MemoryDataset(list(dataset))
         )
+        self.setup = setup
         self.solver = chain(solver) if isinstance(solver, list) else solver
         self.scorer = (
             scorer
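Note: setup gives a solver (or list of solvers) that always runs ahead of the main solver, even when that solver is replaced at eval time (task/run.py above prepends it to the plan). A minimal sketch of a task using it; the dataset and prompts are illustrative only:

from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.solver import generate, system_message

@task
def my_task():
    return Task(
        dataset=[Sample(input="What is 2 + 2?", target="4")],
        # setup runs before the solver and is kept if the solver is swapped out
        setup=system_message("Answer with a single number."),
        solver=generate(),
    )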
inspect_ai/_eval/task/util.py
CHANGED
@@ -39,10 +39,21 @@ def task_file(task: Task, relative: bool = False) -> str | None:
 def slice_dataset(
     dataset: Dataset,
     limit: int | tuple[int, int] | None,
+    sample_id: str | int | list[str | int] | None,
 ) -> Dataset:
-    dataset_limit = (
-        slice(0, len(dataset))
-        if limit is None
-        else (slice(*limit) if isinstance(limit, tuple) else slice(0, limit))
-    )
-    return dataset[dataset_limit]
+    def normalise(id: str | int | None) -> str:
+        if isinstance(id, str) and id.isdigit():
+            id = int(id)
+        return id if isinstance(id, str) else str(id).zfill(20)
+
+    if sample_id is not None:
+        sample_id = sample_id if isinstance(sample_id, list) else [sample_id]
+        sample_id = [normalise(id) for id in sample_id]
+        return dataset.filter(lambda sample: normalise(sample.id) in sample_id)
+    else:
+        dataset_limit = (
+            slice(0, len(dataset))
+            if limit is None
+            else (slice(*limit) if isinstance(limit, tuple) else slice(0, limit))
+        )
+        return dataset[dataset_limit]
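Note: normalise makes string and integer sample ids comparable by zero-padding numeric ids to a fixed width. A quick illustration of its behaviour (the assertions are illustrative, not from the package's tests):

def normalise(id: str | int | None) -> str:
    if isinstance(id, str) and id.isdigit():
        id = int(id)
    return id if isinstance(id, str) else str(id).zfill(20)

# "7" and 7 normalise to the same key; non-numeric strings pass through unchanged
assert normalise(7) == normalise("7") == "7".zfill(20)
assert normalise("sample-a") == "sample-a"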
inspect_ai/_util/logger.py
CHANGED
@@ -1,5 +1,6 @@
 import os
 from logging import (
+    DEBUG,
     INFO,
     WARNING,
     FileHandler,
@@ -129,7 +130,7 @@ def init_logger(
     # init logging handler on demand
     global _logHandler
     if not _logHandler:
-        _logHandler = LogHandler(min(
+        _logHandler = LogHandler(min(DEBUG, levelno), transcript_levelno)
         getLogger().addHandler(_logHandler)

     # establish default capture level
@@ -139,6 +140,7 @@
     getLogger().setLevel(capture_level)
     getLogger(PKG_NAME).setLevel(capture_level)
     getLogger("httpx").setLevel(capture_level)
+    getLogger("botocore").setLevel(DEBUG)

     # set the levelno on the global handler
     _logHandler.display_level = levelno
@@ -154,7 +156,13 @@ def notify_logger_record(record: LogRecord, write: bool) -> None:
     if write:
         transcript()._event(LoggerEvent(message=LoggingMessage.from_log_record(record)))
     global _rate_limit_count
-    if record.levelno <= INFO and "429" in record.getMessage():
+    if (record.levelno <= INFO and "429" in record.getMessage()) or (
+        record.levelno == DEBUG
+        # See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#validating-retry-attempts
+        # for boto retry logic / log messages (this is tracking standard or adapative retries)
+        and "botocore.retries.standard" in record.name
+        and "Retry needed, retrying request after delay of:" in record.getMessage()
+    ):
         _rate_limit_count = _rate_limit_count + 1

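Note: the broadened check also counts botocore's standard-retry DEBUG messages as rate-limit events, which is why the hunk above raises the botocore logger to DEBUG. A hand-built record of the kind the new condition matches, purely for illustration:

import logging

record = logging.LogRecord(
    name="botocore.retries.standard",
    level=logging.DEBUG,
    pathname="example.py",
    lineno=0,
    msg="Retry needed, retrying request after delay of: 2.0",
    args=(),
    exc_info=None,
)

# same predicate shape as the new branch above
assert record.levelno == logging.DEBUG
assert "botocore.retries.standard" in record.name
assert "Retry needed, retrying request after delay of:" in record.getMessage()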
inspect_ai/_util/samples.py
CHANGED
@@ -7,3 +7,10 @@ def parse_samples_limit(limit: str | None) -> int | tuple[int, int] | None:
         return (limit_split[0] - 1, limit_split[1])
     else:
         return None
+
+
+def parse_sample_id(sample_id: str | None) -> list[str] | None:
+    if sample_id is not None:
+        return [id.strip() for id in sample_id.split(",")]
+    else:
+        return None
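Note: restated standalone with a couple of illustrative checks, the new helper turns a comma-separated value into a list of trimmed ids (or None when nothing was supplied):

def parse_sample_id(sample_id: str | None) -> list[str] | None:
    if sample_id is not None:
        return [id.strip() for id in sample_id.split(",")]
    else:
        return None

assert parse_sample_id("1, 7,sample-a") == ["1", "7", "sample-a"]
assert parse_sample_id(None) is None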
inspect_ai/_util/transcript.py
CHANGED
@@ -1,4 +1,5 @@
 import html
+from typing import Any

 from rich.align import AlignMethod
 from rich.box import ROUNDED, Box
@@ -8,6 +9,8 @@ from rich.panel import Panel
 from rich.rule import Rule
 from rich.text import Text

+from .format import format_function_call
+

 def transcript_code_theme() -> str:
     return "github-dark"
@@ -81,6 +84,11 @@ def transcript_separator(title: str, color: str) -> RenderableType:
     return Rule(title=title, style=f"{color} bold", align="center", end="\n\n")


+def transcript_function(function: str, arguments: dict[str, Any]) -> RenderableType:
+    call = format_function_call(function, arguments)
+    return transcript_markdown("```python\n" + call + "\n```\n")
+
+
 LINE = Box(" ── \n" "    \n" "    \n" "    \n" "    \n" "    \n" "    \n" "    \n")

 DOTTED = Box(" ·· \n" "    \n" "    \n" "    \n" "    \n" "    \n" "    \n" "    \n")