inspect-ai 0.3.73__py3-none-any.whl → 0.3.75__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the registry.
Files changed (63)
  1. inspect_ai/__init__.py +3 -2
  2. inspect_ai/_cli/cache.py +1 -1
  3. inspect_ai/_cli/common.py +15 -0
  4. inspect_ai/_cli/eval.py +4 -5
  5. inspect_ai/_cli/log.py +1 -1
  6. inspect_ai/_cli/sandbox.py +1 -1
  7. inspect_ai/_cli/trace.py +1 -1
  8. inspect_ai/_cli/view.py +1 -1
  9. inspect_ai/_display/core/config.py +3 -1
  10. inspect_ai/_eval/eval.py +55 -61
  11. inspect_ai/_eval/evalset.py +63 -154
  12. inspect_ai/_eval/loader.py +27 -54
  13. inspect_ai/_eval/registry.py +1 -10
  14. inspect_ai/_eval/run.py +3 -4
  15. inspect_ai/_eval/task/__init__.py +8 -2
  16. inspect_ai/_eval/task/log.py +9 -1
  17. inspect_ai/_eval/task/resolved.py +35 -0
  18. inspect_ai/_eval/task/task.py +50 -69
  19. inspect_ai/_eval/task/tasks.py +30 -0
  20. inspect_ai/_util/constants.py +3 -0
  21. inspect_ai/_util/dotenv.py +17 -0
  22. inspect_ai/_util/registry.py +43 -2
  23. inspect_ai/_view/server.py +28 -10
  24. inspect_ai/_view/www/dist/assets/index.css +4 -3
  25. inspect_ai/_view/www/dist/assets/index.js +13030 -25523
  26. inspect_ai/_view/www/package.json +2 -2
  27. inspect_ai/_view/www/src/appearance/styles.ts +6 -5
  28. inspect_ai/_view/www/src/components/AnsiDisplay.tsx +2 -2
  29. inspect_ai/_view/www/src/constants.ts +3 -0
  30. inspect_ai/_view/www/src/logfile/remoteZipFile.ts +141 -20
  31. inspect_ai/_view/www/src/plan/PlanDetailView.tsx +2 -1
  32. inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +1 -1
  33. inspect_ai/_view/www/src/samples/chat/tools/tool.ts +7 -5
  34. inspect_ai/_view/www/src/samples/error/FlatSampleErrorView.module.css +1 -0
  35. inspect_ai/_view/www/src/samples/error/FlatSampleErrorView.tsx +3 -1
  36. inspect_ai/_view/www/src/samples/sample-tools/sample-filter/SampleFilter.tsx +5 -2
  37. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +5 -1
  38. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +17 -12
  39. inspect_ai/_view/www/src/workspace/sidebar/SidebarLogEntry.tsx +2 -1
  40. inspect_ai/_view/www/yarn.lock +12 -5
  41. inspect_ai/log/_log.py +10 -1
  42. inspect_ai/log/_recorders/eval.py +27 -8
  43. inspect_ai/log/_recorders/json.py +2 -2
  44. inspect_ai/model/_cache.py +3 -1
  45. inspect_ai/model/_chat_message.py +12 -1
  46. inspect_ai/model/_model.py +25 -11
  47. inspect_ai/model/_providers/anthropic.py +34 -2
  48. inspect_ai/model/_providers/google.py +6 -2
  49. inspect_ai/model/_providers/none.py +31 -0
  50. inspect_ai/model/_providers/providers.py +7 -0
  51. inspect_ai/solver/_bridge/bridge.py +1 -1
  52. inspect_ai/solver/_chain.py +7 -6
  53. inspect_ai/tool/_tools/_computer/_computer.py +1 -1
  54. inspect_ai/tool/_tools/_web_browser/_web_browser.py +1 -1
  55. inspect_ai/tool/_tools/_web_search.py +2 -2
  56. inspect_ai/util/_sandbox/context.py +2 -1
  57. inspect_ai/util/_sandbox/environment.py +17 -2
  58. {inspect_ai-0.3.73.dist-info → inspect_ai-0.3.75.dist-info}/METADATA +4 -4
  59. {inspect_ai-0.3.73.dist-info → inspect_ai-0.3.75.dist-info}/RECORD +63 -60
  60. {inspect_ai-0.3.73.dist-info → inspect_ai-0.3.75.dist-info}/WHEEL +1 -1
  61. {inspect_ai-0.3.73.dist-info → inspect_ai-0.3.75.dist-info}/LICENSE +0 -0
  62. {inspect_ai-0.3.73.dist-info → inspect_ai-0.3.75.dist-info}/entry_points.txt +0 -0
  63. {inspect_ai-0.3.73.dist-info → inspect_ai-0.3.75.dist-info}/top_level.txt +0 -0
inspect_ai/_eval/evalset.py CHANGED
@@ -1,7 +1,7 @@
  import hashlib
  import logging
  from copy import deepcopy
- from typing import Any, Callable, Literal, NamedTuple, Set, cast
+ from typing import Any, Literal, NamedTuple, Set, cast

  import rich
  from pydantic_core import to_json
@@ -17,6 +17,7 @@ from typing_extensions import Unpack

  from inspect_ai._util.error import PrerequisiteError
  from inspect_ai._util.file import basename, filesystem
+ from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven
  from inspect_ai.approval._policy import ApprovalPolicy
  from inspect_ai.log import EvalLog
  from inspect_ai.log._bundle import bundle_log_dir
@@ -34,11 +35,14 @@ from inspect_ai.model import (
  from inspect_ai.model._generate_config import GenerateConfig
  from inspect_ai.solver._solver import Solver, SolverSpec
  from inspect_ai.util import DisplayType, SandboxEnvironmentType
+ from inspect_ai.util._display import init_display_type

  from .eval import eval, eval_init
- from .loader import ResolvedTask, resolve_task_args
- from .task import Epochs, Tasks
- from .task.task import PreviousTask, Task
+ from .loader import resolve_task_args
+ from .task import Epochs
+ from .task.resolved import ResolvedTask
+ from .task.task import PreviousTask
+ from .task.tasks import Tasks

  logger = logging.getLogger(__name__)

@@ -56,7 +60,7 @@ def eval_set(
      retry_wait: float | None = None,
      retry_connections: float | None = None,
      retry_cleanup: bool | None = None,
-     model: str | Model | list[str] | list[Model] | None = None,
+     model: str | Model | list[str] | list[Model] | None | NotGiven = NOT_GIVEN,
      model_base_url: str | None = None,
      model_args: dict[str, Any] | str = dict(),
      task_args: dict[str, Any] | str = dict(),
@@ -107,9 +111,9 @@ def eval_set(
            (defaults to 0.5)
          retry_cleanup: Cleanup failed log files after retries
            (defaults to True)
-         model: Model(s) for
-           evaluation. If not specified use the value of the INSPECT_EVAL_MODEL
-           environment variable.
+         model: Model(s) for evaluation. If not specified use the value of the INSPECT_EVAL_MODEL
+           environment variable. Specify `None` to define no default model(s), which will
+           leave model usage entirely up to tasks.
          model_base_url: Base URL for communicating
            with the model API.
          model_args: Model creation args
@@ -154,7 +158,7 @@ def eval_set(
          max_samples: Maximum number of samples to run in parallel
            (default is max_connections)
          max_tasks: Maximum number of tasks to run in parallel
-           (default is 1)
+           (defaults to number of models being evaluated)
          max_subprocesses: Maximum number of subprocesses to
            run in parallel (default is os.cpu_count())
          max_sandboxes: Maximum number of sandboxes (per-provider)
@@ -177,13 +181,11 @@ def eval_set(
      """

      # helper function to run a set of evals
-     def run_eval(
-         tasks: list[Task] | list[PreviousTask], models: list[Model]
-     ) -> list[EvalLog]:
+     def run_eval(tasks: list[ResolvedTask] | list[PreviousTask]) -> list[EvalLog]:
          # run evals
          results = eval(
              tasks=tasks,
-             model=models,
+             model=None, # ResolvedTask/PreviousTask already carries its model
              model_base_url=model_base_url,
              model_args=model_args,
              task_args=task_args,
@@ -231,30 +233,10 @@ def eval_set(
          # return results
          return results

-     # helper function to run a list of task groups
-     def run_task_groups(
-         task_groups: list[TaskGroup],
-         run_tasks: Callable[[list[ResolvedTask]], list[Task] | list[PreviousTask]],
-     ) -> list[EvalLog]:
-         logs: list[EvalLog] = []
-         for task_group in task_groups:
-             # alias
-             group_models, group_tasks = task_group
-
-             # info log
-             logger.info(
-                 f"eval_set (running task group): {','.join([task.task.name for task in group_tasks])}: {group_models}"
-             )
-
-             # run the evals
-             logs.extend(
-                 run_eval(
-                     tasks=run_tasks(group_tasks),
-                     models=group_models.models,
-                 )
-             )
-
-         return logs
+     # initialise display (otherwise eval_init will set it to full)
+     display = init_display_type(display)
+     if display == "conversation":
+         raise RuntimeError("eval_set cannot be used with conversation display.")

      # resolve tasks
      models, _, resolved_tasks = eval_init(
@@ -283,6 +265,7 @@ def eval_set(
      retry_connections = retry_connections or 0.5
      retry_cleanup = retry_cleanup is not False
      max_connections = starting_max_connections(models, GenerateConfig(**kwargs))
+     max_tasks = max_tasks if max_tasks is not None else len(models)

      # prepare console/status
      console = rich.get_console()
@@ -331,15 +314,11 @@ def eval_set(
          pending_tasks = [
              task[1] for task in all_tasks if task[0] not in log_task_identifiers
          ]
-         task_groups = schedule_pending_tasks(pending_tasks)

          # we have some pending tasks yet to run, run them
-         if len(task_groups) > 0:
+         if len(pending_tasks) > 0:
              # run the tasks
-             run_logs = run_task_groups(
-                 task_groups=task_groups,
-                 run_tasks=lambda tasks: [task.task for task in tasks],
-             )
+             run_logs = run_eval(pending_tasks)

              # if this was the entire list of resolved tasks, return results
              if len(pending_tasks) == len(all_tasks):
@@ -365,42 +344,10 @@ def eval_set(
                  for task in resolved_tasks
                  if task_identifier(task) in failed_task_identifiers
              ]
-             task_groups = schedule_retry_tasks(failed_tasks)
-
-             # execute task groups (run previous task so we get the samples from the log)
-             def run_previous_tasks(tasks: list[ResolvedTask]) -> list[PreviousTask]:
-                 def task_to_failed_log(task: ResolvedTask) -> Log:
-                     resolved_task_identifier = task_identifier(task)
-                     return next(
-                         log
-                         for log in failed_logs
-                         if log.task_identifier == resolved_task_identifier
-                     )
-
-                 previous_tasks: list[PreviousTask] = []
-                 for task, log in zip(tasks, map(task_to_failed_log, tasks)):
-                     # NOTE: we used to try to recreate registry objects by
-                     # by just passing the task name, but that didn't work
-                     # when evals were run from another directory. we may
-                     # want to bring this back but we'd need to resolve the
-                     # directory issues.
-
-                     # deepcopy so the same instance is not run twice
-                     prev_task = deepcopy(task.task)
-
-                     previous_tasks.append(
-                         PreviousTask(
-                             id=log.header.eval.task_id,
-                             task=prev_task,
-                             task_args=resolve_task_args(task.task),
-                             log=read_eval_log(log.info),
-                         )
-                     )
-
-                 return previous_tasks
-
-             retried_logs = run_task_groups(
-                 task_groups=task_groups, run_tasks=run_previous_tasks
+
+             # run previous tasks (no models passed b/c previous task already carries its model)
+             retried_logs = run_eval(
+                 tasks=as_previous_tasks(failed_tasks, failed_logs)
              )

              # return success
@@ -443,6 +390,42 @@ def eval_set(
          return success, results


+     # convert resolved tasks to previous tasks
+     def as_previous_tasks(
+         tasks: list[ResolvedTask], failed_logs: list[Log]
+     ) -> list[PreviousTask]:
+         def task_to_failed_log(task: ResolvedTask) -> Log:
+             resolved_task_identifier = task_identifier(task)
+             return next(
+                 log
+                 for log in failed_logs
+                 if log.task_identifier == resolved_task_identifier
+             )
+
+         previous_tasks: list[PreviousTask] = []
+         for task, log in zip(tasks, map(task_to_failed_log, tasks)):
+             # NOTE: we used to try to recreate registry objects by
+             # by just passing the task name, but that didn't work
+             # when evals were run from another directory. we may
+             # want to bring this back but we'd need to resolve the
+             # directory issues.
+
+             # deepcopy so the same instance is not run twice
+             prev_task = deepcopy(task.task)
+
+             previous_tasks.append(
+                 PreviousTask(
+                     id=log.header.eval.task_id,
+                     task=prev_task,
+                     task_args=resolve_task_args(task.task),
+                     model=task.model,
+                     log=read_eval_log(log.info),
+                 )
+             )
+
+         return previous_tasks
+
+
      # filters to determine when we are done

@@ -574,7 +557,7 @@ def task_identifier(task: ResolvedTask | EvalLog) -> str:
          task_file = task.eval.task_file or ""
          task_name = task.eval.task
          task_args = task.eval.task_args
-         model = task.eval.model
+         model = str(task.eval.model)

      # hash for task args
      task_args_hash = hashlib.sha256(
@@ -617,80 +600,6 @@ class ModelList:
          return ",".join(model_names)


- class TaskGroup(NamedTuple):
-     models: ModelList
-     tasks: list[ResolvedTask]
-
-
- # group into models => tasks for maximum parallelism
- def schedule_pending_tasks(pending_tasks: list[ResolvedTask]) -> list[TaskGroup]:
-     # build a map of task identifiers and the models they target
-     task_id_model_targets: dict[str, ModelList] = {}
-     for pending_task in pending_tasks:
-         task_id = task_identifier_without_model(task_identifier(pending_task))
-         if task_id not in task_id_model_targets:
-             task_id_model_targets[task_id] = ModelList([])
-         if pending_task.model not in task_id_model_targets[task_id].models:
-             task_id_model_targets[task_id].models.append(pending_task.model)
-
-     # build a list of unique model targets
-     unique_model_targets: Set[ModelList] = set(task_id_model_targets.values())
-
-     # create schedule
-     schedule: list[TaskGroup] = [
-         TaskGroup(models=model_target, tasks=[])
-         for model_target in unique_model_targets
-     ]
-
-     for models, tasks in schedule:
-         # which task ids have this set of models
-         task_ids: list[str] = []
-         for task_id, task_models in task_id_model_targets.items():
-             if task_models == models:
-                 task_ids.append(task_id)
-
-         # find a task for each of these ids
-         for task_id in task_ids:
-             tasks.append(
-                 next(
-                     (
-                         task
-                         for task in pending_tasks
-                         if task_id
-                         == task_identifier_without_model(task_identifier(task))
-                     )
-                 )
-             )
-
-     # deterministic return order
-     schedule.sort(key=lambda x: str(x[0]))
-
-     return schedule
-
-
- # group into model => tasks (can't do multiple models b/c these are PreviousTask
- # instances (and therefore model/task pair specific -- we don't want to create
- # multiple instances of these tasks)
- def schedule_retry_tasks(retry_tasks: list[ResolvedTask]) -> list[TaskGroup]:
-     # build a list of unique model targets
-     unique_model_targets: Set[ModelList] = set()
-     for retry_task in retry_tasks:
-         unique_model_targets.add(ModelList([retry_task.model]))
-
-     # create a task group for reach model target
-     schedule: list[TaskGroup] = []
-     for model_target in unique_model_targets:
-         group_tasks = [
-             task for task in retry_tasks if ModelList([task.model]) == model_target
-         ]
-         schedule.append(TaskGroup(model_target, group_tasks))
-
-     # deterministic return order
-     schedule.sort(key=lambda x: str(x[0]))
-
-     return schedule
-
-
  def starting_max_connections(models: list[Model], config: GenerateConfig) -> int:
      # if there is an explicit config use that
      if config.max_connections is not None:
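The eval_set() signature change above swaps the old `model` default of None for a NOT_GIVEN sentinel, so omitting the argument still falls back to the INSPECT_EVAL_MODEL environment variable, while an explicit None now means "define no default model; model usage is left to the tasks". An illustrative sketch, not part of the diff (the task file and log directory names are made up):

    from inspect_ai import eval_set

    # Passing model=None explicitly: no default model is defined, so each task
    # is expected to determine its own model. Omitting `model` entirely
    # (NOT_GIVEN) keeps the previous behaviour of reading INSPECT_EVAL_MODEL.
    success, logs = eval_set(
        tasks=["security_guide.py"],  # hypothetical task file
        log_dir="logs/eval-set-01",   # hypothetical log directory
        model=None,
    )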
inspect_ai/_eval/loader.py CHANGED
@@ -2,7 +2,6 @@ import ast
  import contextlib
  import inspect
  import os
- from dataclasses import dataclass, field
  from importlib.machinery import SourceFileLoader
  from importlib.util import module_from_spec, spec_from_loader
  from logging import getLogger
@@ -12,6 +11,7 @@ from typing import Any, Callable, Tuple, cast

  from typing_extensions import overload

+ from inspect_ai._eval.task.resolved import ResolvedTask
  from inspect_ai._eval.task.util import task_file, task_run_dir
  from inspect_ai._util._async import configured_async_backend
  from inspect_ai._util.decorator import parse_decorators
@@ -26,44 +26,26 @@ from inspect_ai._util.registry import (
      registry_lookup,
      registry_params,
  )
- from inspect_ai.model import Model, ModelName
+ from inspect_ai.model import Model
  from inspect_ai.scorer._scorer import Scorer, ScorerSpec, scorer_create
  from inspect_ai.solver._bridge import bridge
  from inspect_ai.solver._solver import Solver, SolverSpec
  from inspect_ai.util import SandboxEnvironmentSpec, SandboxEnvironmentType
- from inspect_ai.util._sandbox.environment import resolve_sandbox_environment
+ from inspect_ai.util._sandbox.environment import (
+     resolve_sandbox_environment,
+ )
  from inspect_ai.util._sandbox.registry import registry_find_sandboxenv

  from .list import task_files
  from .registry import task_create
- from .task import PreviousTask, Task, TaskInfo, Tasks
+ from .task import PreviousTask, Task, TaskInfo
  from .task.constants import TASK_FILE_ATTR, TASK_RUN_DIR_ATTR
- from .task.run import EvalSampleSource, eval_log_sample_source
+ from .task.run import eval_log_sample_source
+ from .task.tasks import Tasks

  logger = getLogger(__name__)


- @dataclass(frozen=True)
- class ResolvedTask:
-     task: Task
-     task_args: dict[str, Any]
-     task_file: str | None
-     model: Model
-     sandbox: SandboxEnvironmentSpec | None
-     sequence: int
-     id: str | None = field(default=None)
-     sample_source: EvalSampleSource | None = field(default=None)
-
-     @property
-     def has_sandbox(self) -> bool:
-         if self.sandbox:
-             return True
-         else:
-             return any(
-                 [True if sample.sandbox else False for sample in self.task.dataset]
-             )
-
-
  def resolve_tasks(
      tasks: Tasks,
      task_args: dict[str, Any],
@@ -76,16 +58,22 @@ def resolve_tasks(
                  task=task,
                  task_args=resolve_task_args(task),
                  task_file=task_file(task, relative=True),
-                 model=model,
+                 model=task.model or model,
                  sandbox=resolve_task_sandbox(task, sandbox),
                  sequence=sequence,
              )
              for sequence, task in enumerate(tasks)
          ]

+     # reflect resolved tasks right back
+     if isinstance(tasks, ResolvedTask):
+         return [tasks]
+     elif isinstance(tasks, list) and isinstance(tasks[0], ResolvedTask):
+         return cast(list[ResolvedTask], tasks)
+
      # take empty lists out of play
      if isinstance(tasks, list) and len(tasks) == 0:
-         return as_resolved_tasks(load_tasks(None, model, task_args))
+         return as_resolved_tasks(load_tasks(None, task_args))

      # simple cases of passing us Task objects
      if isinstance(tasks, Task):
@@ -109,9 +97,7 @@ def resolve_tasks(
              loaded_task = previous_task.task
          else:
              loaded_task_args = previous_task.task_args
-             loaded_task = load_tasks([previous_task.task], model, loaded_task_args)[
-                 0
-             ]
+             loaded_task = load_tasks([previous_task.task], loaded_task_args)[0]
          loaded_tasks.append(loaded_task)
          loaded_tasks_args.append(loaded_task_args)

@@ -120,7 +106,7 @@ def resolve_tasks(
                  task=loaded_task,
                  task_args=loaded_task_args,
                  task_file=previous_task.log.eval.task_file,
-                 model=model,
+                 model=previous_task.model or loaded_task.model or model,
                  sandbox=previous_task.log.eval.sandbox,
                  sequence=sequence,
                  id=previous_task.id,
@@ -153,19 +139,14 @@ def resolve_tasks(
          tasks = [tasks]

      # done! let's load the tasks
-     return as_resolved_tasks(
-         load_tasks(cast(list[str] | None, tasks), model, task_args)
-     )
+     return as_resolved_tasks(load_tasks(cast(list[str] | None, tasks), task_args))


  def resolve_task_args(task: Task) -> dict[str, Any]:
      # was the task instantiated via the registry or a decorator?
      # if so then we can get the task_args from the registry.
      try:
-         # filter out model as that's dyanmic and automatically passed
          task_args = dict(registry_params(task))
-         if "model" in task_args:
-             del task_args["model"]
          return task_args

      # if it wasn't instantiated via the registry or a decorator
@@ -217,34 +198,29 @@ def resolve_task_sandbox(


  def load_tasks(
-     task_specs: list[str] | None, model: Model, task_args: dict[str, Any] = {}
+     task_specs: list[str] | None, task_args: dict[str, Any] = {}
  ) -> list[Task]:
      """Load one more more tasks (if no tasks are specified, load from the current working directory"""
-     # determine ModelName object for task creation parameterized by model
-     model_name = ModelName(model)
      # load tasks
      return [
          spec
          for task_spec in (task_specs if task_specs else [Path.cwd().as_posix()])
-         for spec in load_task_spec(task_spec, model_name, task_args)
+         for spec in load_task_spec(task_spec, task_args)
      ]


- def load_task_spec(
-     task_spec: str, model: ModelName, task_args: dict[str, Any] = {}
- ) -> list[Task]:
+ def load_task_spec(task_spec: str, task_args: dict[str, Any] = {}) -> list[Task]:
      # task in a python package
      if registry_lookup("task", task_spec) is not None:
          # create the task from a python package
-         return [task_create(task_spec, model, **task_args)]
+         return [task_create(task_spec, **task_args)]
      else:
          # load tasks from glob
-         return create_tasks([task_spec], model, task_args)
+         return create_tasks([task_spec], task_args)


  def create_tasks(
      globs: list[str],
-     model: ModelName,
      task_args: dict[str, Any] = {},
      root_dir: Path | None = None,
  ) -> list[Task]:
@@ -261,9 +237,7 @@ def create_tasks(
          if spec_split[1] is not None:
              task_path = Path(spec_split[0])
              load_file_tasks(task_path.absolute())
-             tasks.extend(
-                 create_file_tasks(task_path, model, [spec_split[1]], task_args)
-             )
+             tasks.extend(create_file_tasks(task_path, [spec_split[1]], task_args))
          else:
              # if the glob is the root dir then set it to empty (will result in
              # enumeration of the root dir)
@@ -271,7 +245,7 @@ def create_tasks(
              files = task_files(target, root_dir)
              files = sorted(files, key=lambda f: f.as_posix())
              for file in files:
-                 tasks.extend(create_file_tasks(file, model, None, task_args))
+                 tasks.extend(create_file_tasks(file, None, task_args))
      return tasks


@@ -282,7 +256,6 @@ def load_file_tasks(file: Path) -> None:

  def create_file_tasks(
      file: Path,
-     model: ModelName,
      task_specs: list[str] | list[RegistryInfo] | None = None,
      task_args: dict[str, Any] = {},
  ) -> list[Task]:
@@ -302,7 +275,7 @@ def create_file_tasks(
          # create the task from the loaded source file and
          # note that it was loaded from this directory
          # (will be used later to ensure it runs in the directory)
-         task = task_create(task_spec, model, **task_args)
+         task = task_create(task_spec, **task_args)
          setattr(task, TASK_FILE_ATTR, file.as_posix())
          setattr(task, TASK_RUN_DIR_ATTR, run_dir)
          tasks.append(task)
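With the loader changes above, task resolution now prefers a model carried by the task itself (model=task.model or model, and previous_task.model or loaded_task.model or model for retried tasks) instead of always applying the eval-level model. A hedged sketch of what that enables, assuming Task accepts a task-level model in 0.3.75 as the loader change suggests (the dataset and model name are illustrative):

    from inspect_ai import Task, task
    from inspect_ai.dataset import Sample

    @task
    def pinned_model_task() -> Task:
        return Task(
            dataset=[Sample(input="Say hello.", target="hello")],
            # assumption: a task-level model that the loader resolves ahead of
            # any eval-level model
            model="openai/gpt-4o-mini",
        )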
inspect_ai/_eval/registry.py CHANGED
@@ -1,6 +1,5 @@
  import inspect
  import logging
- from copy import deepcopy
  from functools import wraps
  from pathlib import Path
  from typing import Any, Callable, TypeVar, cast, overload
@@ -16,7 +15,6 @@ from inspect_ai._util.registry import (
      registry_name,
      registry_tag,
  )
- from inspect_ai.model import ModelName

  from .task import Task
  from .task.constants import TASK_FILE_ATTR, TASK_RUN_DIR_ATTR
@@ -54,7 +52,7 @@ def task_register(
      return task


- def task_create(name: str, model: ModelName, **kwargs: Any) -> Task:
+ def task_create(name: str, **kwargs: Any) -> Task:
      r"""Create a Task based on its registered name.

      Tasks can be a function that returns a Task or a
@@ -62,17 +60,11 @@ def task_create(name: str, model: ModelName, **kwargs: Any) -> Task:

      Args:
          name (str): Name of task (Optional, defaults to object name)
-         model (ModelName): Model name
          **kwargs (dict): Optional creation arguments for the task

      Returns:
          Task with registry info attribute
      """
-     # bring in model arg (first deepcopy as we will mutate it)
-     # add model to task_args
-     kwargs = deepcopy(kwargs)
-     kwargs[MODEL_PARAM] = model
-
      # match kwargs params to signature (warn if param not found)
      # (note that we always pass the 'model' param but tasks aren't
      # required to consume it, so we don't warn for 'model')
@@ -85,7 +77,6 @@ def task_create(name: str, model: ModelName, **kwargs: Any) -> Task:
      for param in kwargs.keys():
          if param in task_params:
              task_args[param] = kwargs[param]
-         elif param != MODEL_PARAM:
              if "kwargs" in task_params:
                  task_args[param] = kwargs[param]
              else:
inspect_ai/_eval/run.py CHANGED
@@ -2,7 +2,7 @@ import functools
  import logging
  import os
  import sys
- from typing import Any, Awaitable, Callable, Set, cast
+ from typing import Awaitable, Callable, Set, cast

  from inspect_ai._util.trace import trace_action

@@ -44,11 +44,11 @@ from inspect_ai.util._sandbox.environment import (
  from inspect_ai.util._sandbox.registry import registry_find_sandboxenv

  from .loader import (
-     ResolvedTask,
      as_solver_spec,
      solver_from_spec,
  )
  from .task.log import TaskLogger
+ from .task.resolved import ResolvedTask
  from .task.run import TaskRunOptions, task_run
  from .task.rundir import task_run_dir_switching
  from .task.sandbox import TaskSandboxEnvironment, resolve_sandbox_for_task
@@ -64,7 +64,6 @@ async def eval_run(
      eval_config: EvalConfig,
      eval_sandbox: SandboxEnvironmentType | None,
      recorder: Recorder,
-     model_args: dict[str, Any],
      epochs_reducer: list[ScoreReducer] | None = None,
      solver: Solver | SolverSpec | None = None,
      tags: list[str] | None = None,
@@ -200,7 +199,7 @@ async def eval_run(
              sandbox=resolved_task.sandbox,
              task_attribs=task.attribs,
              task_args=resolved_task.task_args,
-             model_args=model_args,
+             model_args=resolved_task.model.model_args,
              eval_config=task_eval_config,
              metadata=task.metadata,
              recorder=recorder,
inspect_ai/_eval/task/__init__.py CHANGED
@@ -1,4 +1,10 @@
- from .task import Task, TaskInfo, PreviousTask, Tasks, task_with # noqa: I001, F401
+ from .task import Task, TaskInfo, PreviousTask, task_with # noqa: I001, F401
  from .epochs import Epochs

- __all__ = ["Epochs", "Task", "TaskInfo", "PreviousTask", "Tasks", "task_with"]
+ __all__ = [
+     "Epochs",
+     "Task",
+     "TaskInfo",
+     "PreviousTask",
+     "task_with",
+ ]
inspect_ai/_eval/task/log.py CHANGED
@@ -1,5 +1,6 @@
  from importlib import metadata as importlib_metadata
- from typing import Any, Literal, cast
+ from inspect import isgenerator
+ from typing import Any, Iterator, Literal, cast

  from shortuuid import uuid

@@ -83,6 +84,13 @@ class TaskLogger:
              del model_args["api_key"]
          model_args = {k: v for k, v in model_args.items() if not k.startswith("aws_")}

+         # don't try to serialise generators
+         model_args = {
+             k: v
+             for k, v in model_args.items()
+             if not isgenerator(v) and not isinstance(v, Iterator)
+         }
+
          # cwd_relative_path for sandbox config
          if sandbox and isinstance(sandbox.config, str):
              sandbox = SandboxEnvironmentSpec(
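The TaskLogger change above filters generator and iterator values out of model_args before the eval header is written, since such values cannot be serialised into the log. A standalone illustration of the same filter, not taken from the package (the function name and sample args are illustrative):

    from inspect import isgenerator
    from typing import Any, Iterator

    def serialisable_model_args(model_args: dict[str, Any]) -> dict[str, Any]:
        # keep only values that can be written to the log header
        return {
            k: v
            for k, v in model_args.items()
            if not isgenerator(v) and not isinstance(v, Iterator)
        }

    args = {"temperature": 0.7, "stream": (chunk for chunk in range(3))}
    print(serialisable_model_args(args))  # {'temperature': 0.7}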
inspect_ai/_eval/task/resolved.py ADDED
@@ -0,0 +1,35 @@
+ from dataclasses import dataclass, field
+ from typing import Any, Set
+
+ from inspect_ai._eval.task import Task
+ from inspect_ai._eval.task.run import EvalSampleSource
+ from inspect_ai.model import Model
+ from inspect_ai.util import SandboxEnvironmentSpec
+
+
+ @dataclass(frozen=True)
+ class ResolvedTask:
+     task: Task
+     task_args: dict[str, Any]
+     task_file: str | None
+     model: Model
+     sandbox: SandboxEnvironmentSpec | None
+     sequence: int
+     id: str | None = field(default=None)
+     sample_source: EvalSampleSource | None = field(default=None)
+
+     @property
+     def has_sandbox(self) -> bool:
+         if self.sandbox:
+             return True
+         else:
+             return any(
+                 [True if sample.sandbox else False for sample in self.task.dataset]
+             )
+
+
+ def resolved_model_names(tasks: list[ResolvedTask]) -> list[str]:
+     models: Set[str] = set()
+     for task in tasks:
+         models.add(str(task.model))
+     return list(models)
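The new resolved_model_names() de-duplicates by adding str(task.model) to a set, so a model shared by several resolved tasks is reported once (and the order of the returned list is not guaranteed). A minimal standalone sketch of that behaviour, using a stand-in class rather than inspect_ai.model.Model:

    from typing import Set

    class StubModel:
        # stand-in for inspect_ai.model.Model (illustrative only)
        def __init__(self, name: str) -> None:
            self.name = name

        def __str__(self) -> str:
            return self.name

    def names_of(models: list[StubModel]) -> list[str]:
        # mirrors resolved_model_names: stringify into a set, then return a list
        names: Set[str] = set()
        for model in models:
            names.add(str(model))
        return list(names)

    print(names_of([StubModel("openai/gpt-4o"), StubModel("openai/gpt-4o")]))
    # ['openai/gpt-4o']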