inspect-ai 0.3.87__py3-none-any.whl → 0.3.89__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. inspect_ai/_cli/eval.py +16 -0
  2. inspect_ai/_cli/score.py +1 -12
  3. inspect_ai/_cli/util.py +4 -2
  4. inspect_ai/_display/core/footer.py +2 -2
  5. inspect_ai/_display/plain/display.py +2 -2
  6. inspect_ai/_eval/context.py +7 -1
  7. inspect_ai/_eval/eval.py +51 -27
  8. inspect_ai/_eval/evalset.py +27 -10
  9. inspect_ai/_eval/loader.py +7 -8
  10. inspect_ai/_eval/run.py +23 -31
  11. inspect_ai/_eval/score.py +18 -1
  12. inspect_ai/_eval/task/log.py +5 -13
  13. inspect_ai/_eval/task/resolved.py +1 -0
  14. inspect_ai/_eval/task/run.py +231 -244
  15. inspect_ai/_eval/task/task.py +25 -2
  16. inspect_ai/_eval/task/util.py +1 -8
  17. inspect_ai/_util/constants.py +1 -0
  18. inspect_ai/_util/json.py +8 -3
  19. inspect_ai/_util/registry.py +30 -13
  20. inspect_ai/_view/www/App.css +5 -0
  21. inspect_ai/_view/www/dist/assets/index.css +55 -18
  22. inspect_ai/_view/www/dist/assets/index.js +550 -458
  23. inspect_ai/_view/www/log-schema.json +84 -1
  24. inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
  25. inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
  26. inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
  27. inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
  28. inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
  29. inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
  30. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +3 -3
  31. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
  32. inspect_ai/_view/www/src/types/log.d.ts +150 -129
  33. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
  34. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
  35. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
  36. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
  37. inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
  38. inspect_ai/agent/_agent.py +12 -0
  39. inspect_ai/agent/_as_tool.py +1 -1
  40. inspect_ai/agent/_bridge/bridge.py +9 -2
  41. inspect_ai/agent/_react.py +142 -74
  42. inspect_ai/agent/_run.py +13 -2
  43. inspect_ai/agent/_types.py +6 -0
  44. inspect_ai/approval/_apply.py +6 -9
  45. inspect_ai/approval/_approver.py +3 -3
  46. inspect_ai/approval/_auto.py +2 -2
  47. inspect_ai/approval/_call.py +20 -4
  48. inspect_ai/approval/_human/approver.py +3 -3
  49. inspect_ai/approval/_human/manager.py +2 -2
  50. inspect_ai/approval/_human/panel.py +3 -3
  51. inspect_ai/approval/_policy.py +3 -3
  52. inspect_ai/log/__init__.py +2 -0
  53. inspect_ai/log/_log.py +23 -2
  54. inspect_ai/log/_model.py +58 -0
  55. inspect_ai/log/_recorders/file.py +14 -3
  56. inspect_ai/log/_transcript.py +3 -0
  57. inspect_ai/model/__init__.py +2 -0
  58. inspect_ai/model/_call_tools.py +15 -2
  59. inspect_ai/model/_model.py +49 -3
  60. inspect_ai/model/_openai.py +151 -21
  61. inspect_ai/model/_providers/anthropic.py +25 -14
  62. inspect_ai/model/_providers/bedrock.py +3 -3
  63. inspect_ai/model/_providers/cloudflare.py +29 -108
  64. inspect_ai/model/_providers/google.py +21 -10
  65. inspect_ai/model/_providers/grok.py +23 -17
  66. inspect_ai/model/_providers/groq.py +61 -37
  67. inspect_ai/model/_providers/llama_cpp_python.py +8 -9
  68. inspect_ai/model/_providers/mistral.py +8 -3
  69. inspect_ai/model/_providers/ollama.py +8 -9
  70. inspect_ai/model/_providers/openai.py +53 -157
  71. inspect_ai/model/_providers/openai_compatible.py +195 -0
  72. inspect_ai/model/_providers/openrouter.py +4 -15
  73. inspect_ai/model/_providers/providers.py +11 -0
  74. inspect_ai/model/_providers/together.py +25 -23
  75. inspect_ai/model/_trim.py +83 -0
  76. inspect_ai/solver/_plan.py +5 -3
  77. inspect_ai/tool/_tool_call.py +3 -0
  78. inspect_ai/tool/_tool_def.py +8 -2
  79. inspect_ai/util/__init__.py +3 -0
  80. inspect_ai/util/_concurrency.py +15 -2
  81. {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/METADATA +1 -1
  82. {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/RECORD +86 -81
  83. inspect_ai/_eval/task/rundir.py +0 -78
  84. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  85. {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/WHEEL +0 -0
  86. {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/entry_points.txt +0 -0
  87. {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/licenses/LICENSE +0 -0
  88. {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/top_level.txt +0 -0
inspect_ai/_cli/eval.py CHANGED
@@ -84,6 +84,13 @@ def eval_options(func: Callable[..., Any]) -> Callable[..., click.Context]:
         envvar="INSPECT_EVAL_MODEL_CONFIG",
         help="YAML or JSON config file with model arguments.",
     )
+    @click.option(
+        "--model-role",
+        multiple=True,
+        type=str,
+        envvar="INSPECT_EVAL_MODEL_ROLE",
+        help="Named model role, e.g. --model-role critic=openai/gpt-4o",
+    )
     @click.option(
         "-T",
         multiple=True,
@@ -467,6 +474,7 @@ def eval_command(
    model_base_url: str | None,
    m: tuple[str] | None,
    model_config: str | None,
+   model_role: tuple[str] | None,
    t: tuple[str] | None,
    task_config: str | None,
    s: tuple[str] | None,
@@ -545,6 +553,7 @@ def eval_command(
        model_base_url=model_base_url,
        m=m,
        model_config=model_config,
+       model_role=model_role,
        t=t,
        task_config=task_config,
        s=s,
@@ -638,6 +647,7 @@ def eval_set_command(
    model_base_url: str | None,
    m: tuple[str] | None,
    model_config: str | None,
+   model_role: tuple[str] | None,
    t: tuple[str] | None,
    task_config: str | None,
    s: tuple[str] | None,
@@ -719,6 +729,7 @@ def eval_set_command(
        model_base_url=model_base_url,
        m=m,
        model_config=model_config,
+       model_role=model_role,
        t=t,
        task_config=task_config,
        s=s,
@@ -775,6 +786,7 @@ def eval_exec(
    model_base_url: str | None,
    m: tuple[str] | None,
    model_config: str | None,
+   model_role: tuple[str] | None,
    t: tuple[str] | None,
    task_config: str | None,
    s: tuple[str] | None,
@@ -820,6 +832,9 @@ def eval_exec(
    solver_args = parse_cli_config(s, solver_config)
    model_args = parse_cli_config(m, model_config)
 
+   # parse model roles
+   eval_model_roles = parse_cli_args(model_role, force_str=True)
+
    # parse tags
    eval_tags = parse_comma_separated(tags)
 
@@ -858,6 +873,7 @@ def eval_exec(
        model=model,
        model_base_url=model_base_url,
        model_args=model_args,
+       model_roles=eval_model_roles,
        task_args=task_args,
        solver=SolverSpec(solver, solver_args) if solver else None,
        tags=eval_tags,
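Note: the new `--model-role` flag is repeatable (multiple=True), and each `role=model` value is coerced to a string so it survives value parsing intact before being forwarded as `model_roles`. A rough standalone sketch of that parsing step — it mirrors `parse_cli_args(model_role, force_str=True)` above but is not the package source, and the YAML value-parsing step is an assumption:

    from typing import Any
    import yaml  # assumption: values pass through yaml.safe_load before the shown logic

    def parse_role_args(args: tuple[str, ...] | None, force_str: bool = True) -> dict[str, Any]:
        # illustrative approximation of inspect_ai._cli.util.parse_cli_args
        params: dict[str, Any] = {}
        for arg in args or ():
            key, _, raw = arg.partition("=")
            value: Any = yaml.safe_load(raw)
            if isinstance(value, str):
                parts = value.split(",")
                value = parts if len(parts) > 1 else parts[0]
            # force_str keeps role specs such as critic=openai/gpt-4o as plain strings
            params[key.replace("-", "_")] = str(value) if force_str else value
        return params

    print(parse_role_args(("critic=openai/gpt-4o", "grader=openai/gpt-4o-mini")))
    # {'critic': 'openai/gpt-4o', 'grader': 'openai/gpt-4o-mini'}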
inspect_ai/_cli/score.py CHANGED
@@ -11,13 +11,12 @@ from typing_extensions import Unpack
 from inspect_ai._cli.util import parse_cli_config
 from inspect_ai._display import display
 from inspect_ai._display.core.rich import rich_theme
-from inspect_ai._eval.context import init_eval_context, init_task_context
+from inspect_ai._eval.context import init_eval_context
 from inspect_ai._eval.score import ScoreAction, task_score
 from inspect_ai._util._async import configured_async_backend
 from inspect_ai._util.file import basename, dirname, exists
 from inspect_ai.log._log import EvalLog
 from inspect_ai.log._recorders import create_recorder_for_location
-from inspect_ai.model import get_model
 
 from .common import CommonOptions, common_options, process_common_options
 
@@ -109,16 +108,6 @@ async def score(
    if eval_log.samples is None or len(eval_log.samples) == 0:
        raise ValueError(f"{log_file} does not include samples to score")
 
-   # get the model then initialize the async context
-   model = get_model(
-       model=eval_log.eval.model,
-       config=eval_log.plan.config,
-       **eval_log.eval.model_args,
-   )
-
-   # initialize active model
-   init_task_context(model)
-
    # re-score the task
    eval_log = await task_score(
        log=eval_log, scorer=scorer, scorer_args=scorer_args, action=action
inspect_ai/_cli/util.py CHANGED
@@ -63,7 +63,9 @@ def parse_cli_config(
     return cli_config
 
 
-def parse_cli_args(args: tuple[str] | list[str] | None) -> dict[str, Any]:
+def parse_cli_args(
+    args: tuple[str] | list[str] | None, force_str: bool = False
+) -> dict[str, Any]:
     params: dict[str, Any] = dict()
     if args:
         for arg in list(args):
@@ -74,7 +76,7 @@ def parse_cli_args(args: tuple[str] | list[str] | None) -> dict[str, Any]:
                 if isinstance(value, str):
                     value = value.split(",")
                     value = value if len(value) > 1 else value[0]
-                params[key] = value
+                params[key] = str(value) if force_str else value
     return params
 
 
inspect_ai/_display/core/footer.py CHANGED
@@ -2,7 +2,7 @@ from rich.console import RenderableType
 from rich.text import Text
 
 from inspect_ai._util.retry import http_retries_count
-from inspect_ai.util._concurrency import concurrency_status
+from inspect_ai.util._concurrency import concurrency_status_display
 from inspect_ai.util._throttle import throttle
 
 from .config import task_dict
@@ -20,7 +20,7 @@ def task_footer(
 
 def task_resources() -> str:
     resources: dict[str, str] = {}
-    for model, resource in concurrency_status().items():
+    for model, resource in concurrency_status_display().items():
         resources[model] = f"{resource[0]}/{resource[1]}"
     return task_dict(resources)
 
inspect_ai/_display/plain/display.py CHANGED
@@ -10,7 +10,7 @@ from inspect_ai._util.platform import running_in_notebook
 from inspect_ai._util.text import truncate
 from inspect_ai._util.throttle import throttle
 
-from ...util._concurrency import concurrency_status
+from ...util._concurrency import concurrency_status_display
 from ..core.config import task_config
 from ..core.display import (
     TR,
@@ -179,7 +179,7 @@ class PlainTaskDisplay(TaskDisplay):
         # Very similar to ``inspect_ai._display.core.footer.task_resources``, but without
         # the rich formatting added in the ``task_dict`` call
         resources_dict: dict[str, str] = {}
-        for model, resource in concurrency_status().items():
+        for model, resource in concurrency_status_display().items():
            resources_dict[model] = f"{resource[0]:2d}/{resource[1]:2d}"
         resources = ", ".join(
            [f"{key}: {value}" for key, value in resources_dict.items()]
inspect_ai/_eval/context.py CHANGED
@@ -6,7 +6,11 @@ from inspect_ai.approval._human.manager import init_human_approval_manager
 from inspect_ai.approval._policy import ApprovalPolicy
 from inspect_ai.log._samples import init_active_samples
 from inspect_ai.model import GenerateConfig, Model
-from inspect_ai.model._model import init_active_model, init_model_usage
+from inspect_ai.model._model import (
+    init_active_model,
+    init_model_roles,
+    init_model_usage,
+)
 from inspect_ai.util._concurrency import init_concurrency
 from inspect_ai.util._subprocess import init_max_subprocesses
 
@@ -27,10 +31,12 @@ def init_eval_context(
 
 def init_task_context(
     model: Model,
+    model_roles: dict[str, Model] | None = None,
     approval: list[ApprovalPolicy] | None = None,
     config: GenerateConfig = GenerateConfig(),
 ) -> None:
     init_active_model(model, config)
+    init_model_roles(model_roles or {})
     init_model_usage()
     if not have_tool_approval():
         init_tool_approval(approval)
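Note: `init_task_context()` now registers per-task model roles alongside the active model. The docstrings later in this diff describe the roles as "for use in `get_model()`"; a minimal sketch of how that might look from inside a solver, assuming `get_model()` accepts a `role` keyword (that change itself is not shown in this diff):

    from inspect_ai.model import get_model
    from inspect_ai.solver import Generate, TaskState, solver

    @solver
    def critique():
        async def solve(state: TaskState, generate: Generate) -> TaskState:
            # hypothetical: look up the "critic" role registered for this task
            critic = get_model(role="critic")
            review = await critic.generate(
                f"Critique this answer:\n\n{state.output.completion}"
            )
            state.metadata["critique"] = review.completion
            return state

        return solve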
inspect_ai/_eval/eval.py CHANGED
@@ -4,9 +4,11 @@ import sys
 from pathlib import Path
 from typing import Any, Literal, cast
 
+from inspect_ai._eval.task.task import resolve_model_roles
 from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven
 from inspect_ai.agent._agent import Agent, is_agent
 from inspect_ai.agent._as_solver import as_solver
+from inspect_ai.log._model import model_roles_config_to_model_roles
 
 if sys.version_info < (3, 11):
     from exceptiongroup import ExceptionGroup
@@ -70,6 +72,7 @@ def eval(
     model: str | Model | list[str] | list[Model] | None | NotGiven = NOT_GIVEN,
     model_base_url: str | None = None,
     model_args: dict[str, Any] | str = dict(),
+    model_roles: dict[str, str | Model] | None = None,
     task_args: dict[str, Any] | str = dict(),
     sandbox: SandboxEnvironmentType | None = None,
     sandbox_cleanup: bool | None = None,
@@ -84,7 +87,7 @@ def eval(
     log_dir: str | None = None,
     log_format: Literal["eval", "json"] | None = None,
     limit: int | tuple[int, int] | None = None,
-    sample_id: str | int | list[str | int] | None = None,
+    sample_id: str | int | list[str] | list[int] | list[str | int] | None = None,
     epochs: int | Epochs | None = None,
     fail_on_error: bool | float | None = None,
     debug_errors: bool | None = None,
@@ -116,6 +119,7 @@ def eval(
            with the model API.
        model_args: Model creation args
            (as a dictionary or as a path to a JSON or YAML config file)
+       model_roles: Named roles for use in `get_model()`.
        task_args: Task creation arguments
            (as a dictionary or as a path to a JSON or YAML config file)
        sandbox: Sandbox environment type
@@ -194,6 +198,7 @@ def eval(
            model=model,
            model_base_url=model_base_url,
            model_args=model_args,
+           model_roles=model_roles,
            task_args=task_args,
            sandbox=sandbox,
            sandbox_cleanup=sandbox_cleanup,
@@ -245,6 +250,7 @@ async def eval_async(
     model: str | Model | list[str] | list[Model] | None | NotGiven = NOT_GIVEN,
     model_base_url: str | None = None,
     model_args: dict[str, Any] | str = dict(),
+    model_roles: dict[str, str | Model] | None = None,
     task_args: dict[str, Any] | str = dict(),
     sandbox: SandboxEnvironmentType | None = None,
     sandbox_cleanup: bool | None = None,
@@ -257,7 +263,7 @@ async def eval_async(
     log_dir: str | None = None,
     log_format: Literal["eval", "json"] | None = None,
     limit: int | tuple[int, int] | None = None,
-    sample_id: str | int | list[str | int] | None = None,
+    sample_id: str | int | list[str] | list[int] | list[str | int] | None = None,
     epochs: int | Epochs | None = None,
     fail_on_error: bool | float | None = None,
     debug_errors: bool | None = None,
@@ -286,7 +292,8 @@ async def eval_async(
            environment variable. Specify `None` to define no default model(s), which will
            leave model usage entirely up to tasks.
        model_base_url: Base URL for communicating with the model API.
-       model_args: Model creation args (as a dictionary or as a path to a JSON or YAML config file)
+       model_args: Model creation args (as a dictionary or as a path to a JSON or YAML config file
+       model_roles: Named roles for use in `get_model()`.
        task_args: Task creation arguments (as a dictionary or as a path to a JSON or YAML config file)
        sandbox: Sandbox environment type (or optionally a str or tuple with a shorthand spec)
        sandbox_cleanup: Cleanup sandbox environments after task completes (defaults to True)
@@ -333,12 +340,11 @@ async def eval_async(
     Returns:
        List of EvalLog (one for each task)
     """
-    # only a single call to eval_async can be active at a time, this is
-    # because when running a task a chdir to the task's directory (and
-    # similar mutation of the Python sys.path) occurs. since this is a
-    # change to global process state it cannot occur in parallel. for
-    # task parallelism, pass multiple tasks to eval or eval_async (which
-    # will enforce the appropriate constraints on task parallelism)
+    # only a single call to eval_async can be active at a time, this used
+    # to be due to running tasks switching to the task's directory, however
+    # that feature no longer exists so we may be able to revisit this
+    # restriction (probably just need to examine if there is *global* state
+    # that could have conflicts in the case of multiple eval_async calls)
     global _eval_async_running
     if _eval_async_running:
         raise RuntimeError("Multiple concurrent calls to eval_async are not allowed.")
@@ -355,11 +361,10 @@ async def eval_async(
 
     try:
         # intialise eval
-        model, approval = eval_init(
+        model = eval_init(
             model=model,
             model_base_url=model_base_url,
             model_args=model_args,
-            approval=approval,
             max_subprocesses=max_subprocesses,
             log_level=log_level,
             log_level_transcript=log_level_transcript,
@@ -367,8 +372,14 @@ async def eval_async(
         )
 
         # resolve tasks
-        resolved_tasks = eval_resolve_tasks(
-            tasks, task_args, model, GenerateConfig(**kwargs), sandbox
+        resolved_tasks, approval = eval_resolve_tasks(
+            tasks,
+            task_args,
+            model,
+            model_roles,
+            GenerateConfig(**kwargs),
+            approval,
+            sandbox,
         )
 
         # warn and return empty string if we resolved no tasks
@@ -759,6 +770,9 @@ async def eval_retry_async(
            **eval_log.eval.model_args,
        )
 
+       # resolve model roles
+       model_roles = model_roles_config_to_model_roles(eval_log.eval.model_roles)
+
        # collect the rest of the params we need for the eval
        task_args = eval_log.eval.task_args
        tags = eval_log.eval.tags
@@ -815,9 +829,15 @@ async def eval_retry_async(
        log = (
            await eval_async(
                tasks=PreviousTask(
-                   id=task_id, task=task, task_args=task_args, model=None, log=eval_log
+                   id=task_id,
+                   task=task,
+                   task_args=task_args,
+                   model=None,
+                   model_roles=None,
+                   log=eval_log,
                ),
                model=model,
+               model_roles=cast(dict[str, str | Model], model_roles),
                task_args=task_args,
                sandbox=eval_log.eval.sandbox,
                sandbox_cleanup=sandbox_cleanup,
@@ -861,12 +881,11 @@ def eval_init(
     model: str | Model | list[str] | list[Model] | None | NotGiven = NOT_GIVEN,
     model_base_url: str | None = None,
     model_args: dict[str, Any] | str = dict(),
-    approval: str | list[ApprovalPolicy] | ApprovalPolicyConfig | None = None,
     max_subprocesses: int | None = None,
     log_level: str | None = None,
     log_level_transcript: str | None = None,
     **kwargs: Unpack[GenerateConfigArgs],
-) -> tuple[list[Model], list[ApprovalPolicy] | None]:
+) -> list[Model]:
     # init eval context
     init_eval_context(log_level, log_level_transcript, max_subprocesses)
 
@@ -880,32 +899,37 @@ def eval_init(
         args = [arg.strip() for arg in env_model_args.split(" ")]
         model_args = parse_cli_args(args)
 
-    # resolve models
+    # resolve and return models
     generate_config = GenerateConfig(**kwargs)
     models = resolve_models(model, model_base_url, model_args, generate_config)
-
-    # resolve approval
-    if isinstance(approval, str | ApprovalPolicyConfig):
-        approval = approval_policies_from_config(approval)
-    init_tool_approval(approval)
-
-    return models, approval
+    return models
 
 
 def eval_resolve_tasks(
     tasks: Tasks,
     task_args: dict[str, Any] | str,
     models: list[Model],
+    model_roles: dict[str, str | Model] | None,
     config: GenerateConfig,
+    approval: str | list[ApprovalPolicy] | ApprovalPolicyConfig | None,
     sandbox: SandboxEnvironmentType | None,
-) -> list[ResolvedTask]:
+) -> tuple[list[ResolvedTask], list[ApprovalPolicy] | None]:
+    resolved_model_roles = resolve_model_roles(model_roles)
     task_args = resolve_args(task_args)
     with task_display().suspend_task_app():
         resolved_tasks: list[ResolvedTask] = []
         for m in models:
             init_active_model(m, config)
-            resolved_tasks.extend(resolve_tasks(tasks, task_args, m, sandbox))
-    return resolved_tasks
+            resolved_tasks.extend(
+                resolve_tasks(tasks, task_args, m, resolved_model_roles, sandbox)
+            )
+
+    if isinstance(approval, str | ApprovalPolicyConfig):
+        approval = approval_policies_from_config(approval)
+    init_tool_approval(approval)
+
+    # return tasks and approval
+    return resolved_tasks, approval
 
 
 def init_eval_display(
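Note: with the signature changes above, named roles can be supplied directly to `eval()` (and are re-resolved from the log on retry). A hypothetical invocation — the task file and model names are illustrative, not taken from this diff:

    from inspect_ai import eval

    logs = eval(
        "security_guide.py",  # hypothetical task file
        model="openai/gpt-4o",  # default model used by solvers
        model_roles={
            # named models resolvable via get_model() inside solvers/scorers
            "critic": "anthropic/claude-3-5-sonnet-latest",
            "grader": "openai/gpt-4o-mini",
        },
    )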
inspect_ai/_eval/evalset.py CHANGED
@@ -28,6 +28,7 @@ from inspect_ai.log._file import (
     read_eval_log_headers,
     write_log_dir_manifest,
 )
+from inspect_ai.log._model import model_roles_to_model_roles_config
 from inspect_ai.model import (
     GenerateConfigArgs,
     Model,
@@ -63,6 +64,7 @@ def eval_set(
     model: str | Model | list[str] | list[Model] | None | NotGiven = NOT_GIVEN,
     model_base_url: str | None = None,
     model_args: dict[str, Any] | str = dict(),
+    model_roles: dict[str, str | Model] | None = None,
     task_args: dict[str, Any] | str = dict(),
     sandbox: SandboxEnvironmentType | None = None,
     sandbox_cleanup: bool | None = None,
@@ -77,7 +79,7 @@ def eval_set(
     log_level_transcript: str | None = None,
     log_format: Literal["eval", "json"] | None = None,
     limit: int | tuple[int, int] | None = None,
-    sample_id: str | int | list[str | int] | None = None,
+    sample_id: str | int | list[str] | list[int] | list[str | int] | None = None,
     epochs: int | Epochs | None = None,
     fail_on_error: bool | float | None = None,
     debug_errors: bool | None = None,
@@ -120,6 +122,7 @@ def eval_set(
            with the model API.
        model_args: Model creation args
            (as a dictionary or as a path to a JSON or YAML config file)
+       model_roles: Named roles for use in `get_model()`.
        task_args: Task creation arguments
            (as a dictionary or as a path to a JSON or YAML config file)
        sandbox: Sandbox environment type
@@ -194,6 +197,7 @@ def eval_set(
            model=None,  # ResolvedTask/PreviousTask already carries its model
            model_base_url=model_base_url,
            model_args=model_args,
+           model_roles=model_roles,
            task_args=task_args,
            sandbox=sandbox,
            sandbox_cleanup=sandbox_cleanup,
@@ -248,7 +252,7 @@ def eval_set(
        raise RuntimeError("eval_set cannot be used with conversation display.")
 
    # initialize eval
-   models, _ = eval_init(
+   models = eval_init(
        model=model,
        model_base_url=model_base_url,
        model_args=model_args,
@@ -303,8 +307,14 @@ def eval_set(
    #  - tasks with failed logs (they'll be retried)
    def try_eval() -> list[EvalLog]:
        # resolve tasks
-       resolved_tasks = eval_resolve_tasks(
-           tasks, task_args, models, GenerateConfig(**kwargs), sandbox
+       resolved_tasks, _ = eval_resolve_tasks(
+           tasks,
+           task_args,
+           models,
+           model_roles,
+           GenerateConfig(**kwargs),
+           approval,
+           sandbox,
        )
 
        # list all logs currently in the log directory (update manifest if there are some)
@@ -415,18 +425,13 @@ def as_previous_tasks(
 
    previous_tasks: list[PreviousTask] = []
    for task, log in zip(tasks, map(task_to_failed_log, tasks)):
-       # NOTE: we used to try to recreate registry objects by
-       # by just passing the task name, but that didn't work
-       # when evals were run from another directory. we may
-       # want to bring this back but we'd need to resolve the
-       # directory issues.
-
        previous_tasks.append(
            PreviousTask(
                id=log.header.eval.task_id,
                task=task.task,
                task_args=resolve_task_args(task.task),
                model=task.model,
+               model_roles=task.model_roles,
                log=read_eval_log(log.info),
            )
        )
@@ -561,17 +566,29 @@ def task_identifier(task: ResolvedTask | EvalLog) -> str:
        task_name = task.task.name
        task_args = task.task_args
        model = str(task.model)
+       model_roles = model_roles_to_model_roles_config(task.model_roles) or {}
    else:
        task_file = task.eval.task_file or ""
        task_name = task.eval.task
        task_args = task.eval.task_args
        model = str(task.eval.model)
+       model_roles = task.eval.model_roles or {}
 
    # hash for task args
    task_args_hash = hashlib.sha256(
        to_json(task_args, exclude_none=True, fallback=lambda _x: None)
    ).hexdigest()
 
+   # hash for model roles
+   if len(model_roles):
+       model = (
+           model
+           + "/"
+           + hashlib.sha256(
+               to_json(model_roles, exclude_none=True, fallback=lambda _x: None)
+           ).hexdigest()
+       )
+
    if task_file:
        return f"{task_file}@{task_name}#{task_args_hash}/{model}"
inspect_ai/_eval/loader.py CHANGED
@@ -13,7 +13,6 @@ from typing_extensions import overload
 
 from inspect_ai._eval.task.resolved import ResolvedTask
 from inspect_ai._eval.task.util import task_file, task_run_dir
-from inspect_ai._util._async import configured_async_backend
 from inspect_ai._util.decorator import parse_decorators
 from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.logger import warn_once
@@ -52,6 +51,7 @@ def resolve_tasks(
     tasks: Tasks,
     task_args: dict[str, Any],
     model: Model,
+    model_roles: dict[str, Model] | None,
     sandbox: SandboxEnvironmentType | None,
 ) -> list[ResolvedTask]:
     def as_resolved_tasks(tasks: list[Task]) -> list[ResolvedTask]:
@@ -61,6 +61,7 @@ def resolve_tasks(
                task_args=resolve_task_args(task),
                task_file=task_file(task, relative=True),
                model=task.model or model,
+               model_roles=task.model_roles or model_roles,
                sandbox=resolve_task_sandbox(task, sandbox),
                sequence=sequence,
            )
@@ -109,6 +110,9 @@ def resolve_tasks(
                task_args=loaded_task_args,
                task_file=previous_task.log.eval.task_file,
                model=previous_task.model or loaded_task.model or model,
+               model_roles=(
+                   previous_task.model_roles or loaded_task.model_roles or model_roles
+               ),
                sandbox=previous_task.log.eval.sandbox,
                sequence=sequence,
                id=previous_task.id,
@@ -282,16 +286,11 @@ def create_file_tasks(
        setattr(task, TASK_RUN_DIR_ATTR, run_dir)
        tasks.append(task)
 
-       # warn that chdir is deprecated
+       # warn that chdir has been removed
        if "chdir" in task.attribs:
-           if configured_async_backend() == "trio":
-               raise RuntimeError(
-                   "The task 'chdir' attribute is not compatible with the trio async backend."
-               )
-
            warn_once(
               logger,
-              "The 'chdir' task attribute is deprecated and will be removed in a future release "
+              "The 'chdir' task attribute is no longer supported "
              + "(you should write your tasks to not depend on their runtime working directory)",
            )
 
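Note: per `resolve_tasks()` above, roles attached to a task take precedence over roles passed to `eval()`/`eval_set()` (`model_roles=task.model_roles or model_roles`). A hypothetical illustration, assuming `Task` accepts a `model_roles` argument in this release (suggested by `task.model_roles` above and the task.py changes, but not shown directly in this diff):

    from inspect_ai import Task, eval
    from inspect_ai.dataset import Sample

    # hypothetical task-level role; Task(model_roles=...) is assumed, not shown here
    task = Task(
        dataset=[Sample(input="2 + 2 = ?", target="4")],
        model_roles={"grader": "openai/gpt-4o-mini"},
    )

    # the task-level grader wins over the role supplied to eval() below
    eval(task, model="openai/gpt-4o", model_roles={"grader": "anthropic/claude-3-5-haiku-latest"})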
inspect_ai/_eval/run.py CHANGED
@@ -49,9 +49,8 @@ from .loader import (
 from .task.log import TaskLogger
 from .task.resolved import ResolvedTask
 from .task.run import TaskRunOptions, task_run
-from .task.rundir import task_run_dir_switching
 from .task.sandbox import TaskSandboxEnvironment, resolve_sandbox_for_task
-from .task.util import slice_dataset, task_chdir, task_run_dir
+from .task.util import slice_dataset, task_run_dir
 
 log = logging.getLogger(__name__)
 
@@ -71,13 +70,10 @@ async def eval_run(
     score: bool = True,
     **kwargs: Unpack[GenerateConfigArgs],
 ) -> list[EvalLog]:
-    # see if we need to use run_dir switching
-    run_dir = task_run_dir(tasks[0].task)
-    multiple_run_dirs = any([task_run_dir(task.task) != run_dir for task in tasks])
-    tasks_chdir = any([task_chdir(task.task) is not None for task in tasks])
+    # are sandboxes in play?
     has_sandbox = next((task.has_sandbox for task in tasks), None)
 
-    # get cwd before switching to task dir
+    # get cwd before any switching
     eval_wd = os.getcwd()
 
     # ensure sample ids
@@ -199,6 +195,7 @@ async def eval_run(
                solver=eval_solver_spec,
                tags=tags,
                model=resolved_task.model,
+               model_roles=resolved_task.model_roles,
                dataset=task.dataset,
                scorer=eval_scorer_specs,
                metrics=eval_metrics,
@@ -217,6 +214,7 @@ async def eval_run(
                TaskRunOptions(
                    task=task,
                    model=resolved_task.model,
+                   model_roles=resolved_task.model_roles,
                    sandbox=resolved_task.sandbox,
                    logger=logger,
                    eval_wd=eval_wd,
@@ -233,25 +231,10 @@ async def eval_run(
        # multiple mode is for running/displaying multiple
        # task definitions, which requires some smart scheduling
        # to ensure that we spread work among models
-       if tasks_chdir:
-           if parallel > 1:
-               if multiple_run_dirs:
-                   with task_run_dir_switching():
-                       return await run_multiple(task_run_options, parallel)
-               else:
-                   with chdir(run_dir):
-                       return await run_multiple(task_run_options, parallel)
-
-           # single mode is for a single task definitions (which
-           # could in turn be executed for multiple models)
-           else:
-               with chdir(run_dir):
-                   return await run_single(task_run_options, debug_errors)
+       if parallel > 1:
+           return await run_multiple(task_run_options, parallel)
        else:
-           if parallel > 1:
-               return await run_multiple(task_run_options, parallel)
-           else:
-               return await run_single(task_run_options, debug_errors)
+           return await run_single(task_run_options, debug_errors)
 
    finally:
        # shutdown sandbox environments
@@ -359,12 +342,21 @@ async def run_multiple(tasks: list[TaskRunOptions], parallel: int) -> list[EvalLog]:
                    f"task: {task_options.task.name} ({task_options.model})",
                ):
                    async with anyio.create_task_group() as tg:
-
-                       async def run_task() -> None:
-                           result = await task_run(task_options)
-                           results.append(result)
-
-                       tg.start_soon(run_task)
+                       # Create a factory function that captures the current
+                       # task_options. Otherwise, we suffer from Python's
+                       # late/by reference binding behavior.
+                       # see: https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result
+                       def create_task_runner(
+                           options: TaskRunOptions = task_options,
+                       ) -> Callable[[], Awaitable[None]]:
+                           async def run_task() -> None:
+                               nonlocal result
+                               result = await task_run(options)
+                               results.append(result)
+
+                           return run_task
+
+                       tg.start_soon(create_task_runner())
 
    except Exception as ex:
        # errors generally don't escape from tasks (the exception being if an error
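Note: the late-binding pitfall that the new `create_task_runner` factory guards against is the one described in the linked Python FAQ entry, and it is easy to reproduce outside inspect_ai (plain asyncio used here instead of anyio; illustrative only):

    import asyncio

    async def demo() -> None:
        out: list[str] = []

        async def run(name: str) -> None:
            out.append(name)

        # Late binding: each closure reads `opt` only when it is finally called,
        # so every task sees the last loop value.
        buggy = []
        for opt in ["a", "b", "c"]:
            buggy.append(lambda: run(opt))  # captures the variable, not its current value
        await asyncio.gather(*(f() for f in buggy))
        print(out)  # ['c', 'c', 'c']

        # Fix (same idea as the factory above): bind the current value explicitly,
        # here via a default argument.
        out.clear()
        fixed = [lambda opt=opt: run(opt) for opt in ["a", "b", "c"]]
        await asyncio.gather(*(f() for f in fixed))
        print(out)  # ['a', 'b', 'c']

    asyncio.run(demo())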