inspect-ai 0.3.51__py3-none-any.whl → 0.3.53__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. inspect_ai/_cli/eval.py +44 -2
  2. inspect_ai/_display/core/config.py +4 -0
  3. inspect_ai/_display/core/panel.py +1 -1
  4. inspect_ai/_display/core/progress.py +9 -3
  5. inspect_ai/_display/core/results.py +8 -4
  6. inspect_ai/_display/textual/widgets/task_detail.py +45 -13
  7. inspect_ai/_display/textual/widgets/tasks.py +86 -5
  8. inspect_ai/_display/textual/widgets/transcript.py +4 -17
  9. inspect_ai/_eval/eval.py +29 -1
  10. inspect_ai/_eval/evalset.py +7 -0
  11. inspect_ai/_eval/registry.py +2 -2
  12. inspect_ai/_eval/task/log.py +6 -1
  13. inspect_ai/_eval/task/results.py +22 -4
  14. inspect_ai/_eval/task/run.py +18 -12
  15. inspect_ai/_eval/task/sandbox.py +72 -43
  16. inspect_ai/_eval/task/task.py +4 -0
  17. inspect_ai/_eval/task/util.py +17 -6
  18. inspect_ai/_util/logger.py +10 -2
  19. inspect_ai/_util/samples.py +7 -0
  20. inspect_ai/_util/transcript.py +8 -0
  21. inspect_ai/_view/www/App.css +13 -0
  22. inspect_ai/_view/www/dist/assets/index.css +13 -0
  23. inspect_ai/_view/www/dist/assets/index.js +105 -55
  24. inspect_ai/_view/www/src/App.mjs +31 -6
  25. inspect_ai/_view/www/src/Types.mjs +6 -0
  26. inspect_ai/_view/www/src/components/JsonPanel.mjs +11 -17
  27. inspect_ai/_view/www/src/components/MessageContent.mjs +9 -2
  28. inspect_ai/_view/www/src/components/Tools.mjs +46 -18
  29. inspect_ai/_view/www/src/navbar/Navbar.mjs +12 -0
  30. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +18 -5
  31. inspect_ai/_view/www/src/samples/SampleList.mjs +2 -2
  32. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +2 -2
  33. inspect_ai/log/_log.py +6 -0
  34. inspect_ai/log/_recorders/eval.py +8 -7
  35. inspect_ai/model/_call_tools.py +2 -6
  36. inspect_ai/model/_generate_config.py +6 -0
  37. inspect_ai/model/_model.py +18 -4
  38. inspect_ai/model/_providers/azureai.py +22 -2
  39. inspect_ai/model/_providers/bedrock.py +17 -1
  40. inspect_ai/model/_providers/hf.py +1 -1
  41. inspect_ai/model/_providers/openai.py +32 -8
  42. inspect_ai/model/_providers/providers.py +1 -1
  43. inspect_ai/model/_providers/vllm.py +1 -1
  44. inspect_ai/model/_render.py +7 -6
  45. inspect_ai/model/_trace.py +1 -1
  46. inspect_ai/solver/_basic_agent.py +8 -1
  47. inspect_ai/tool/_tool_transcript.py +28 -0
  48. inspect_ai/util/_sandbox/context.py +1 -2
  49. inspect_ai/util/_sandbox/docker/config.py +8 -10
  50. inspect_ai/util/_sandbox/docker/docker.py +9 -5
  51. inspect_ai/util/_sandbox/docker/util.py +3 -3
  52. inspect_ai/util/_sandbox/environment.py +7 -2
  53. inspect_ai/util/_sandbox/limits.py +1 -1
  54. inspect_ai/util/_sandbox/local.py +8 -9
  55. {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/METADATA +2 -4
  56. {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/RECORD +60 -59
  57. {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/LICENSE +0 -0
  58. {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/WHEEL +0 -0
  59. {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/entry_points.txt +0 -0
  60. {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/top_level.txt +0 -0
inspect_ai/_eval/eval.py CHANGED
@@ -61,6 +61,7 @@ def eval(
  log_dir: str | None = None,
  log_format: Literal["eval", "json"] | None = None,
  limit: int | tuple[int, int] | None = None,
+ sample_id: str | int | list[str | int] | None = None,
  epochs: int | Epochs | None = None,
  fail_on_error: bool | float | None = None,
  debug_errors: bool | None = None,
@@ -70,6 +71,7 @@ def eval(
  max_samples: int | None = None,
  max_tasks: int | None = None,
  max_subprocesses: int | None = None,
+ max_sandboxes: int | None = None,
  log_samples: bool | None = None,
  log_images: bool | None = None,
  log_buffer: int | None = None,
@@ -110,6 +112,7 @@ def eval(
  to "eval", the native high-performance format).
  limit (int | tuple[int, int] | None): Limit evaluated samples
  (defaults to all samples).
+ sample_id (str | int | list[str | int] | None): Evaluate specific sample(s) from the dataset.
  epochs (int | Epochs | None): Epochs to repeat samples for and optional score
  reducer function(s) used to combine sample scores (defaults to "mean")
  fail_on_error (bool | float | None): `True` to fail on first sample error
@@ -127,6 +130,8 @@ def eval(
  (default is 1)
  max_subprocesses (int | None): Maximum number of subprocesses to
  run in parallel (default is os.cpu_count())
+ max_sandboxes (int | None): Maximum number of sandboxes (per-provider)
+ to run in parallel.
  log_samples: (bool | None): Log detailed samples and scores (defaults to True)
  log_images: (bool | None): Log base64 encoded version of images,
  even if specified as a filename or URL (defaults to False)
@@ -163,6 +168,7 @@ def eval(
  log_dir=log_dir,
  log_format=log_format,
  limit=limit,
+ sample_id=sample_id,
  epochs=epochs,
  fail_on_error=fail_on_error,
  debug_errors=debug_errors,
@@ -172,6 +178,7 @@ def eval(
  max_samples=max_samples,
  max_tasks=max_tasks,
  max_subprocesses=max_subprocesses,
+ max_sandboxes=max_sandboxes,
  log_samples=log_samples,
  log_images=log_images,
  log_buffer=log_buffer,
@@ -198,6 +205,7 @@ async def eval_async(
  log_dir: str | None = None,
  log_format: Literal["eval", "json"] | None = None,
  limit: int | tuple[int, int] | None = None,
+ sample_id: str | int | list[str | int] | None = None,
  epochs: int | Epochs | None = None,
  fail_on_error: bool | float | None = None,
  debug_errors: bool | None = None,
@@ -207,6 +215,7 @@ async def eval_async(
  max_samples: int | None = None,
  max_tasks: int | None = None,
  max_subprocesses: int | None = None,
+ max_sandboxes: int | None = None,
  log_samples: bool | None = None,
  log_images: bool | None = None,
  log_buffer: int | None = None,
@@ -245,8 +254,9 @@ async def eval_async(
  (defaults to file log in ./logs directory).
  log_format (Literal["eval", "json"] | None): Format for writing log files (defaults
  to "eval", the native high-performance format).
- limit (int | tuple[int, int] | None): Limit evaluated samples
+ limit (str | int | list[str | int] | None): Limit evaluated samples
  (defaults to all samples).
+ sample_id (str | list[str] | None): Evaluate specific sample(s) from the dataset.
  epochs (int | Epochs | None): Epochs to repeat samples for and optional score
  reducer function(s) used to combine sample scores (defaults to "mean")
  fail_on_error (bool | float | None): `True` to fail on first sample error
@@ -263,6 +273,8 @@ async def eval_async(
  (default is 1)
  max_subprocesses (int | None): Maximum number of subprocesses to
  run in parallel (default is os.cpu_count())
+ max_sandboxes (int | None): Maximum number of sandboxes (per-provider)
+ to run in parallel.
  log_samples: (bool | None): Log detailed samples and scores (defaults to True)
  log_images: (bool | None): Log base64 encoded version of images,
  even if specified as a filename or URL (defaults to False)
@@ -335,6 +347,10 @@ async def eval_async(
  # resolve solver
  solver = chain(solver) if isinstance(solver, list) else solver

+ # ensure consistency of limit and sample_id
+ if sample_id is not None and limit is not None:
+ raise ValueError("You cannot specify both sample_id and limit.")
+
  # resolve epochs
  if isinstance(epochs, int):
  epochs = Epochs(epochs)
@@ -345,6 +361,7 @@ async def eval_async(
  epochs_reducer = epochs.reducer if epochs else None
  eval_config = EvalConfig(
  limit=limit,
+ sample_id=sample_id,
  epochs=epochs.epochs if epochs else None,
  epochs_reducer=reducer_log_names(epochs_reducer)
  if epochs_reducer
@@ -358,6 +375,7 @@ async def eval_async(
  max_samples=max_samples,
  max_tasks=max_tasks,
  max_subprocesses=max_subprocesses,
+ max_sandboxes=max_sandboxes,
  sandbox_cleanup=sandbox_cleanup,
  log_samples=log_samples,
  log_images=log_images,
@@ -440,6 +458,7 @@ def eval_retry(
  max_samples: int | None = None,
  max_tasks: int | None = None,
  max_subprocesses: int | None = None,
+ max_sandboxes: int | None = None,
  sandbox_cleanup: bool | None = None,
  trace: bool | None = None,
  fail_on_error: bool | float | None = None,
@@ -470,6 +489,8 @@ def eval_retry(
  (default is 1)
  max_subprocesses (int | None): Maximum number of subprocesses to
  run in parallel (default is os.cpu_count())
+ max_sandboxes (int | None): Maximum number of sandboxes (per-provider)
+ to run in parallel.
  sandbox_cleanup (bool | None): Cleanup sandbox environments after task completes
  (defaults to True)
  trace (bool | None): Trace message interactions with evaluated model to terminal.
@@ -512,6 +533,7 @@ def eval_retry(
  max_samples=max_samples,
  max_tasks=max_tasks,
  max_subprocesses=max_subprocesses,
+ max_sandboxes=max_sandboxes,
  sandbox_cleanup=sandbox_cleanup,
  fail_on_error=fail_on_error,
  debug_errors=debug_errors,
@@ -535,6 +557,7 @@ async def eval_retry_async(
  max_samples: int | None = None,
  max_tasks: int | None = None,
  max_subprocesses: int | None = None,
+ max_sandboxes: int | None = None,
  sandbox_cleanup: bool | None = None,
  fail_on_error: bool | float | None = None,
  debug_errors: bool | None = None,
@@ -564,6 +587,7 @@ async def eval_retry_async(
  (default is 1)
  max_subprocesses (int): Maximum number of subprocesses to
  run in parallel (default is os.cpu_count())
+ max_sandboxes (int): Maximum number of sandboxes (per-provider) to run in parallel.
  sandbox_cleanup (bool | None): Cleanup sandbox environments after task completes
  (defaults to True)
  fail_on_error (bool | float | None): `True` to fail on first sample error
@@ -642,6 +666,7 @@ async def eval_retry_async(
  task_args = eval_log.eval.task_args
  tags = eval_log.eval.tags
  limit = eval_log.eval.config.limit
+ sample_id = eval_log.eval.config.sample_id
  epochs = (
  Epochs(eval_log.eval.config.epochs, eval_log.eval.config.epochs_reducer)
  if eval_log.eval.config.epochs
@@ -654,6 +679,7 @@ async def eval_retry_async(
  max_samples = max_samples or eval_log.eval.config.max_samples
  max_tasks = max_tasks or eval_log.eval.config.max_tasks
  max_subprocesses = max_subprocesses or eval_log.eval.config.max_subprocesses
+ max_sandboxes = max_sandboxes or eval_log.eval.config.max_sandboxes
  sandbox_cleanup = (
  sandbox_cleanup
  if sandbox_cleanup is not None
@@ -699,6 +725,7 @@ async def eval_retry_async(
  log_dir=log_dir,
  log_format=log_format,
  limit=limit,
+ sample_id=sample_id,
  epochs=epochs,
  fail_on_error=fail_on_error,
  debug_errors=debug_errors,
@@ -708,6 +735,7 @@ async def eval_retry_async(
  max_samples=max_samples,
  max_tasks=max_tasks,
  max_subprocesses=max_subprocesses,
+ max_sandboxes=max_sandboxes,
  log_samples=log_samples,
  log_images=log_images,
  log_buffer=log_buffer,
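For orientation, a minimal usage sketch of the two arguments added above; the task path and sample ids are illustrative, not taken from this diff:

    from inspect_ai import eval

    # run only two specific samples from the task's dataset
    eval("my_task.py", sample_id=["sample-1", 42])

    # cap concurrently running sandboxes for the task's sandbox provider at 8
    eval("my_task.py", max_sandboxes=8)

    # per the new consistency check in eval_async(), combining the two selectors
    # raises ValueError("You cannot specify both sample_id and limit.")
    # eval("my_task.py", sample_id="sample-1", limit=10)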
inspect_ai/_eval/evalset.py CHANGED
@@ -65,6 +65,7 @@ def eval_set(
  log_level_transcript: str | None = None,
  log_format: Literal["eval", "json"] | None = None,
  limit: int | tuple[int, int] | None = None,
+ sample_id: str | int | list[str | int] | None = None,
  epochs: int | Epochs | None = None,
  fail_on_error: bool | float | None = None,
  debug_errors: bool | None = None,
@@ -74,6 +75,7 @@ def eval_set(
  max_samples: int | None = None,
  max_tasks: int | None = None,
  max_subprocesses: int | None = None,
+ max_sandboxes: int | None = None,
  log_samples: bool | None = None,
  log_images: bool | None = None,
  log_buffer: int | None = None,
@@ -125,6 +127,7 @@ def eval_set(
  log files (defaults to "eval", the native high-performance format).
  limit (int | tuple[int, int] | None): Limit evaluated samples
  (defaults to all samples).
+ sample_id (str | int | list[str | int] | None): Evaluate specific sample(s) from the dataset.
  epochs (int | Epochs | None): Epochs to repeat samples for and optional score
  reducer function(s) used to combine sample scores (defaults to "mean")
  fail_on_error (bool | float | None): `True` to fail on first sample error
@@ -142,6 +145,8 @@ def eval_set(
  (default is 1)
  max_subprocesses (int | None): Maximum number of subprocesses to
  run in parallel (default is os.cpu_count())
+ max_sandboxes (int | None): Maximum number of sandboxes (per-provider)
+ to run in parallel.
  log_samples: (bool | None): Log detailed samples and scores (defaults to True)
  log_images: (bool | None): Log base64 encoded version of images,
  even if specified as a filename or URL (defaults to False)
@@ -181,6 +186,7 @@ def eval_set(
  log_dir=log_dir,
  log_format=log_format,
  limit=limit,
+ sample_id=sample_id,
  epochs=epochs,
  fail_on_error=fail_on_error,
  debug_errors=debug_errors,
@@ -190,6 +196,7 @@ def eval_set(
  max_samples=max_samples,
  max_tasks=max_tasks,
  max_subprocesses=max_subprocesses,
+ max_sandboxes=max_sandboxes,
  log_samples=log_samples,
  log_images=log_images,
  log_buffer=log_buffer,
inspect_ai/_eval/registry.py CHANGED
@@ -146,8 +146,8 @@ def task(*args: Any, name: str | None = None, **attribs: Any) -> Any:
  # module import, so set its task file and run dir
  if get_installed_package_name(task_type) is None:
  module = inspect.getmodule(task_type)
- if module and module.__file__:
- file = Path(module.__file__)
+ if module and hasattr(module, "__file__"):
+ file = Path(getattr(module, "__file__"))
  setattr(task_instance, TASK_FILE_ATTR, file.as_posix())
  setattr(task_instance, TASK_RUN_DIR_ATTR, file.parent.as_posix())

inspect_ai/_eval/task/log.py CHANGED
@@ -83,7 +83,12 @@ class TaskLogger:
  # ensure that the dataset has sample ids and record them
  sample_ids = cast(
  list[int | str],
- [sample.id for sample in slice_dataset(dataset, eval_config.limit)],
+ [
+ sample.id
+ for sample in slice_dataset(
+ dataset, eval_config.limit, eval_config.sample_id
+ )
+ ],
  )

  # create eval spec
inspect_ai/_eval/task/results.py CHANGED
@@ -267,10 +267,28 @@ def scorers_from_metric_dict(
  value = target_metric(metric_scores)
  else:
  value = float("Nan")
- result_metrics[metric_name] = EvalMetric(
- name=metric_name,
- value=cast(float, value),
- )
+
+ # convert the value to a float (either by expanding the dict or array)
+ # or by casting to a float
+ if isinstance(value, dict):
+ for key, val in value.items():
+ name = f"{metric_name}_{key}"
+ result_metrics[name] = EvalMetric(
+ name=name,
+ value=cast(float, val),
+ )
+ elif isinstance(value, list):
+ for idx, item in enumerate(value):
+ name = f"{metric_name}_{idx}"
+ result_metrics[name] = EvalMetric(
+ name=name,
+ value=cast(float, item),
+ )
+ else:
+ result_metrics[metric_name] = EvalMetric(
+ name=metric_name,
+ value=cast(float, value),
+ )

  # create a scorer result for this metric
  # TODO: What if there is separate simple scorer which has a name collision with
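The effect of the expanded branch above is easiest to see in isolation. The following standalone sketch (flatten_metric is a hypothetical helper, not part of the package) mirrors how dict- and list-valued metric results now become one EvalMetric per element, with the key or index appended to the metric name:

    def flatten_metric(metric_name: str, value) -> dict[str, float]:
        # mirrors the new expansion logic in scorers_from_metric_dict()
        if isinstance(value, dict):
            return {f"{metric_name}_{key}": float(val) for key, val in value.items()}
        elif isinstance(value, list):
            return {f"{metric_name}_{idx}": float(item) for idx, item in enumerate(value)}
        else:
            return {metric_name: float(value)}

    # flatten_metric("accuracy", {"easy": 0.9, "hard": 0.4})
    #   -> {"accuracy_easy": 0.9, "accuracy_hard": 0.4}
    # flatten_metric("pass_at", [0.3, 0.5, 0.8])
    #   -> {"pass_at_0": 0.3, "pass_at_1": 0.5, "pass_at_2": 0.8}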
inspect_ai/_eval/task/run.py CHANGED
@@ -162,6 +162,7 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
  dataset=task.dataset,
  model_name=model_name,
  limit=config.limit,
+ sample_id=config.sample_id,
  epochs=epochs,
  log_images=log_images,
  message_limit=config.message_limit,
@@ -177,6 +178,10 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
  else:
  plan = Plan(unroll(solver), internal=True)

+ # add setup solver(s) if specified
+ if task.setup:
+ plan.steps = unroll(task.setup) + plan.steps
+
  # reaolve the scorer
  score = score and task.scorer is not None
  scorers: list[Scorer] | None = task.scorer if (score and task.scorer) else None
@@ -274,6 +279,7 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
  sample=sample,
  state=state,
  sandbox=sandbox,
+ max_sandboxes=config.max_sandboxes,
  sandbox_cleanup=sandbox_cleanup,
  plan=plan,
  scorers=scorers,
@@ -455,6 +461,7 @@ async def task_run_sample(
  sample: Sample,
  state: TaskState,
  sandbox: SandboxEnvironmentSpec | None,
+ max_sandboxes: int | None,
  sandbox_cleanup: bool,
  plan: Plan,
  scorers: list[Scorer] | None,
@@ -481,8 +488,8 @@ async def task_run_sample(
  await logger.log_sample(previous_sample, flush=False)

  # return score
- if previous_sample.scores:
- return {
+ sample_scores = (
+ {
  key: SampleScore(
  sample_id=previous_sample.id,
  value=score.value,
@@ -492,8 +499,11 @@ async def task_run_sample(
  )
  for key, score in previous_sample.scores.items()
  }
- else:
- return {}
+ if previous_sample.scores
+ else {}
+ )
+ sample_complete(sample_scores)
+ return sample_scores

  # use semaphore if provided
  semaphore_cm: asyncio.Semaphore | contextlib.AbstractAsyncContextManager[None] = (
@@ -509,7 +519,7 @@ async def task_run_sample(

  # use sandbox if provided
  sandboxenv_cm = (
- sandboxenv_context(task_name, sandbox, sandbox_cleanup, sample)
+ sandboxenv_context(task_name, sandbox, max_sandboxes, sandbox_cleanup, sample)
  if sandbox or sample.sandbox is not None
  else contextlib.nullcontext()
  )
@@ -748,13 +758,14 @@ async def resolve_dataset(
  dataset: Dataset,
  model_name: ModelName,
  limit: int | tuple[int, int] | None,
+ sample_id: str | int | list[str | int] | None,
  epochs: int,
  log_images: bool,
  message_limit: int | None,
  token_limit: int | None,
  ) -> tuple[Dataset, list[Sample], list[TaskState]]:
- # apply limit to dataset
- dataset = slice_dataset(dataset, limit)
+ # slice dataset
+ dataset = slice_dataset(dataset, limit, sample_id)

  # apply epochs (deepcopy the samples so they remain independent)
  samples: list[Sample] = []
@@ -864,10 +875,5 @@ def create_sample_semaphore(
  else DEFAULT_MAX_CONNECTIONS
  )

- # if max_tasks is specified and max_samples is less
- # than max_tasks then bump it up
- if config.max_tasks is not None:
- max_samples = max(max_samples, config.max_tasks)
-
  # return the semaphore
  return asyncio.Semaphore(max_samples)
inspect_ai/_eval/task/sandbox.py CHANGED
@@ -1,7 +1,8 @@
  import asyncio
  import base64
  import contextlib
- from typing import AsyncGenerator, NamedTuple
+ from random import random
+ from typing import AsyncGenerator, Callable, NamedTuple, cast

  from inspect_ai._eval.task.task import Task
  from inspect_ai._eval.task.util import task_run_dir
@@ -9,6 +10,7 @@ from inspect_ai._util.file import file, filesystem
  from inspect_ai._util.registry import registry_unqualified_name
  from inspect_ai._util.url import data_uri_to_base64, is_data_uri
  from inspect_ai.dataset import Sample
+ from inspect_ai.util._concurrency import concurrency
  from inspect_ai.util._sandbox.context import (
  cleanup_sandbox_environments_sample,
  init_sandbox_environments_sample,
@@ -18,12 +20,14 @@ from inspect_ai.util._sandbox.environment import (
  SandboxEnvironmentConfigType,
  SandboxEnvironmentSpec,
  )
+ from inspect_ai.util._sandbox.registry import registry_find_sandboxenv


  @contextlib.asynccontextmanager
  async def sandboxenv_context(
  task_name: str,
  sandbox: SandboxEnvironmentSpec | None,
+ max_sandboxes: int | None,
  cleanup: bool,
  sample: Sample,
  ) -> AsyncGenerator[None, None]:
@@ -32,52 +36,77 @@ async def sandboxenv_context(
  if not sandbox:
  raise ValueError("sandboxenv_context called with no sandbox specified")

- # read files from sample
- files: dict[str, bytes] = {}
- if sample.files:
- for path, contents in sample.files.items():
- files[path] = read_sandboxenv_file(contents)
-
- # read setup script from sample (add bash shebang if necessary)
- setup: bytes | None = None
- if sample.setup:
- setup = read_sandboxenv_file(sample.setup)
- setup_str = setup.decode(encoding="utf-8")
- if not setup_str.strip().startswith("#!"):
- setup_str = f"#!/usr/bin/env bash\n\n{setup_str}"
- setup = setup_str.encode(encoding="utf-8")
-
- interrupted = False
- environments: dict[str, SandboxEnvironment] | None = None
- try:
- # initialize sandbox environment,
- environments = await init_sandbox_environments_sample(
- type=sandbox.type,
- task_name=registry_unqualified_name(task_name),
- config=sandbox.config,
- files=files,
- setup=setup,
- metadata=sample.metadata if sample.metadata else {},
- )
-
- # run sample
- yield
-
- except asyncio.CancelledError as ex:
- interrupted = True
- raise ex
+ # get sandboxenv_type
+ sandboxenv_type = registry_find_sandboxenv(sandbox.type)

- finally:
- # cleanup sandbox environment
- if environments and cleanup:
- await cleanup_sandbox_environments_sample(
- type=sandbox.type,
- task_name=task_name,
+ # see if there is a max_sandboxes in play (passed or from type)
+ if max_sandboxes is None:
+ default_concurrency_fn = cast(
+ Callable[[], int | None], getattr(sandboxenv_type, "default_concurrency")
+ )
+ max_sandboxes = default_concurrency_fn()
+
+ # if we are enforcing max_sandboxes, then when samples are scheduled they may
+ # not get interleaved properly across tasks (because the first task will come
+ # in and grab all of the sandboxes). Therefore, in this case we wait a random
+ # delay so that all tasks/samples have an equal shot at getting scheduled.
+ if max_sandboxes is not None:
+ await asyncio.sleep(random())
+
+ # enforce concurrency if required
+ sandboxes_cm = (
+ concurrency(sandbox.type, max_sandboxes, f"sandboxes/{sandbox.type}")
+ if max_sandboxes is not None
+ else contextlib.nullcontext()
+ )
+
+ async with sandboxes_cm:
+ # read files from sample
+ files: dict[str, bytes] = {}
+ if sample.files:
+ for path, contents in sample.files.items():
+ files[path] = read_sandboxenv_file(contents)
+
+ # read setup script from sample (add bash shebang if necessary)
+ setup: bytes | None = None
+ if sample.setup:
+ setup = read_sandboxenv_file(sample.setup)
+ setup_str = setup.decode(encoding="utf-8")
+ if not setup_str.strip().startswith("#!"):
+ setup_str = f"#!/usr/bin/env bash\n\n{setup_str}"
+ setup = setup_str.encode(encoding="utf-8")
+
+ interrupted = False
+ environments: dict[str, SandboxEnvironment] | None = None
+ try:
+ # initialize sandbox environment,
+ environments = await init_sandbox_environments_sample(
+ sandboxenv_type=sandboxenv_type,
+ task_name=registry_unqualified_name(task_name),
  config=sandbox.config,
- environments=environments,
- interrupted=interrupted,
+ files=files,
+ setup=setup,
+ metadata=sample.metadata if sample.metadata else {},
  )

+ # run sample
+ yield
+
+ except asyncio.CancelledError as ex:
+ interrupted = True
+ raise ex
+
+ finally:
+ # cleanup sandbox environment
+ if environments and cleanup:
+ await cleanup_sandbox_environments_sample(
+ type=sandbox.type,
+ task_name=task_name,
+ config=sandbox.config,
+ environments=environments,
+ interrupted=interrupted,
+ )
+

  def read_sandboxenv_file(contents: str) -> bytes:
  if is_data_uri(contents):
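The gating pattern introduced above (a short random delay so tasks interleave fairly, then a named per-provider concurrency limit) can be sketched on its own as follows. with_sandbox_slot is an illustrative wrapper, and inspect_ai.util.concurrency is assumed to be the public counterpart of the private import used in the diff:

    import asyncio
    import contextlib
    from random import random

    from inspect_ai.util import concurrency


    async def with_sandbox_slot(provider: str, max_sandboxes: int | None):
        if max_sandboxes is not None:
            # random delay so one task cannot grab every sandbox slot up front
            await asyncio.sleep(random())
            cm = concurrency(provider, max_sandboxes, f"sandboxes/{provider}")
        else:
            cm = contextlib.nullcontext()
        async with cm:
            ...  # create the sandbox environments and run the sample here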
inspect_ai/_eval/task/task.py CHANGED
@@ -39,6 +39,8 @@ class Task:

  Args:
  dataset (Dataset | Sequence[Sample]): Dataset to evaluate
+ setup: (Solver | list[Solver] | None): Setup step (always run
+ even when the main `solver` is replaced).
  solver: (Solver | list[Solver]): Solver or list of solvers.
  Defaults to generate(), a normal call to the model.
  scorer: (Scorer | list[Scorer] | None): Scorer used to evaluate model output.
@@ -68,6 +70,7 @@ class Task:
  def __init__(
  self,
  dataset: Dataset | Sequence[Sample] | None = None,
+ setup: Solver | list[Solver] | None = None,
  solver: Solver | list[Solver] = generate(),
  scorer: Scorer | list[Scorer] | None = None,
  metrics: list[Metric] | dict[str, list[Metric]] | None = None,
@@ -119,6 +122,7 @@ class Task:
  self.dataset: Dataset = (
  dataset if isinstance(dataset, Dataset) else MemoryDataset(list(dataset))
  )
+ self.setup = setup
  self.solver = chain(solver) if isinstance(solver, list) else solver
  self.scorer = (
  scorer
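A minimal sketch of the new setup parameter (the task and dataset below are illustrative): per the run.py change above, setup solver(s) are prepended to the plan and therefore run even when the main solver is replaced at eval time.

    from inspect_ai import Task, task
    from inspect_ai.dataset import Sample
    from inspect_ai.solver import generate, system_message


    @task
    def demo():
        return Task(
            dataset=[Sample(input="What is 2 + 2?", target="4")],
            # always runs first, even if eval(..., solver=...) swaps the solver
            setup=system_message("Answer with a single number."),
            solver=generate(),
        )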
inspect_ai/_eval/task/util.py CHANGED
@@ -39,10 +39,21 @@ def task_file(task: Task, relative: bool = False) -> str | None:
  def slice_dataset(
  dataset: Dataset,
  limit: int | tuple[int, int] | None,
+ sample_id: str | int | list[str | int] | None,
  ) -> Dataset:
- dataset_limit = (
- slice(0, len(dataset))
- if limit is None
- else (slice(*limit) if isinstance(limit, tuple) else slice(0, limit))
- )
- return dataset[dataset_limit]
+ def normalise(id: str | int | None) -> str:
+ if isinstance(id, str) and id.isdigit():
+ id = int(id)
+ return id if isinstance(id, str) else str(id).zfill(20)
+
+ if sample_id is not None:
+ sample_id = sample_id if isinstance(sample_id, list) else [sample_id]
+ sample_id = [normalise(id) for id in sample_id]
+ return dataset.filter(lambda sample: normalise(sample.id) in sample_id)
+ else:
+ dataset_limit = (
+ slice(0, len(dataset))
+ if limit is None
+ else (slice(*limit) if isinstance(limit, tuple) else slice(0, limit))
+ )
+ return dataset[dataset_limit]
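The nested normalise() helper above is what makes id matching tolerant of numeric strings: numeric-looking ids are converted to a zero-padded string so that 5 and "5" compare equal, while genuinely textual ids compare as plain strings. A small standalone illustration of the same logic:

    def normalise(id):
        # same logic as the nested helper in slice_dataset()
        if isinstance(id, str) and id.isdigit():
            id = int(id)
        return id if isinstance(id, str) else str(id).zfill(20)


    assert normalise(5) == normalise("5") == "00000000000000000005"
    assert normalise("sample-a") == "sample-a"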
inspect_ai/_util/logger.py CHANGED
@@ -1,5 +1,6 @@
  import os
  from logging import (
+ DEBUG,
  INFO,
  WARNING,
  FileHandler,
@@ -129,7 +130,7 @@ def init_logger(
  # init logging handler on demand
  global _logHandler
  if not _logHandler:
- _logHandler = LogHandler(min(HTTP, levelno), transcript_levelno)
+ _logHandler = LogHandler(min(DEBUG, levelno), transcript_levelno)
  getLogger().addHandler(_logHandler)

  # establish default capture level
@@ -139,6 +140,7 @@ def init_logger(
  getLogger().setLevel(capture_level)
  getLogger(PKG_NAME).setLevel(capture_level)
  getLogger("httpx").setLevel(capture_level)
+ getLogger("botocore").setLevel(DEBUG)

  # set the levelno on the global handler
  _logHandler.display_level = levelno
@@ -154,7 +156,13 @@ def notify_logger_record(record: LogRecord, write: bool) -> None:
  if write:
  transcript()._event(LoggerEvent(message=LoggingMessage.from_log_record(record)))
  global _rate_limit_count
- if record.levelno <= INFO and "429" in record.getMessage():
+ if (record.levelno <= INFO and "429" in record.getMessage()) or (
+ record.levelno == DEBUG
+ # See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#validating-retry-attempts
+ # for boto retry logic / log messages (this is tracking standard or adapative retries)
+ and "botocore.retries.standard" in record.name
+ and "Retry needed, retrying request after delay of:" in record.getMessage()
+ ):
  _rate_limit_count = _rate_limit_count + 1

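The broadened condition above means a DEBUG record from botocore's standard retry handler now counts toward the rate-limit tally alongside the existing "429" heuristic. An illustrative check of the new match (the record fields are made up for the example):

    import logging

    record = logging.LogRecord(
        name="botocore.retries.standard",
        level=logging.DEBUG,
        pathname="example.py",
        lineno=0,
        msg="Retry needed, retrying request after delay of: 2.5",
        args=(),
        exc_info=None,
    )

    assert (
        record.levelno == logging.DEBUG
        and "botocore.retries.standard" in record.name
        and "Retry needed, retrying request after delay of:" in record.getMessage()
    )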
inspect_ai/_util/samples.py CHANGED
@@ -7,3 +7,10 @@ def parse_samples_limit(limit: str | None) -> int | tuple[int, int] | None:
  return (limit_split[0] - 1, limit_split[1])
  else:
  return None
+
+
+ def parse_sample_id(sample_id: str | None) -> list[str] | None:
+ if sample_id is not None:
+ return [id.strip() for id in sample_id.split(",")]
+ else:
+ return None
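parse_sample_id() complements parse_samples_limit() for CLI handling: it splits a comma-separated value into a list of id strings (whitespace stripped) and passes None through unchanged. For example, importing the private helper shown above:

    from inspect_ai._util.samples import parse_sample_id

    assert parse_sample_id("1, 7,sample-a") == ["1", "7", "sample-a"]
    assert parse_sample_id(None) is None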
inspect_ai/_util/transcript.py CHANGED
@@ -1,4 +1,5 @@
  import html
+ from typing import Any

  from rich.align import AlignMethod
  from rich.box import ROUNDED, Box
@@ -8,6 +9,8 @@ from rich.panel import Panel
  from rich.rule import Rule
  from rich.text import Text

+ from .format import format_function_call
+

  def transcript_code_theme() -> str:
  return "github-dark"
@@ -81,6 +84,11 @@ def transcript_separator(title: str, color: str) -> RenderableType:
  return Rule(title=title, style=f"{color} bold", align="center", end="\n\n")


+ def transcript_function(function: str, arguments: dict[str, Any]) -> RenderableType:
+ call = format_function_call(function, arguments)
+ return transcript_markdown("```python\n" + call + "\n```\n")
+
+
  LINE = Box(" ── \n" " \n" " \n" " \n" " \n" " \n" " \n" " \n")
  DOTTED = Box(" ·· \n" " \n" " \n" " \n" " \n" " \n" " \n" " \n")