inspect-ai 0.3.73__py3-none-any.whl → 0.3.75__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +3 -2
- inspect_ai/_cli/cache.py +1 -1
- inspect_ai/_cli/common.py +15 -0
- inspect_ai/_cli/eval.py +4 -5
- inspect_ai/_cli/log.py +1 -1
- inspect_ai/_cli/sandbox.py +1 -1
- inspect_ai/_cli/trace.py +1 -1
- inspect_ai/_cli/view.py +1 -1
- inspect_ai/_display/core/config.py +3 -1
- inspect_ai/_eval/eval.py +55 -61
- inspect_ai/_eval/evalset.py +63 -154
- inspect_ai/_eval/loader.py +27 -54
- inspect_ai/_eval/registry.py +1 -10
- inspect_ai/_eval/run.py +3 -4
- inspect_ai/_eval/task/__init__.py +8 -2
- inspect_ai/_eval/task/log.py +9 -1
- inspect_ai/_eval/task/resolved.py +35 -0
- inspect_ai/_eval/task/task.py +50 -69
- inspect_ai/_eval/task/tasks.py +30 -0
- inspect_ai/_util/constants.py +3 -0
- inspect_ai/_util/dotenv.py +17 -0
- inspect_ai/_util/registry.py +43 -2
- inspect_ai/_view/server.py +28 -10
- inspect_ai/_view/www/dist/assets/index.css +4 -3
- inspect_ai/_view/www/dist/assets/index.js +13030 -25523
- inspect_ai/_view/www/package.json +2 -2
- inspect_ai/_view/www/src/appearance/styles.ts +6 -5
- inspect_ai/_view/www/src/components/AnsiDisplay.tsx +2 -2
- inspect_ai/_view/www/src/constants.ts +3 -0
- inspect_ai/_view/www/src/logfile/remoteZipFile.ts +141 -20
- inspect_ai/_view/www/src/plan/PlanDetailView.tsx +2 -1
- inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +1 -1
- inspect_ai/_view/www/src/samples/chat/tools/tool.ts +7 -5
- inspect_ai/_view/www/src/samples/error/FlatSampleErrorView.module.css +1 -0
- inspect_ai/_view/www/src/samples/error/FlatSampleErrorView.tsx +3 -1
- inspect_ai/_view/www/src/samples/sample-tools/sample-filter/SampleFilter.tsx +5 -2
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +5 -1
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +17 -12
- inspect_ai/_view/www/src/workspace/sidebar/SidebarLogEntry.tsx +2 -1
- inspect_ai/_view/www/yarn.lock +12 -5
- inspect_ai/log/_log.py +10 -1
- inspect_ai/log/_recorders/eval.py +27 -8
- inspect_ai/log/_recorders/json.py +2 -2
- inspect_ai/model/_cache.py +3 -1
- inspect_ai/model/_chat_message.py +12 -1
- inspect_ai/model/_model.py +25 -11
- inspect_ai/model/_providers/anthropic.py +34 -2
- inspect_ai/model/_providers/google.py +6 -2
- inspect_ai/model/_providers/none.py +31 -0
- inspect_ai/model/_providers/providers.py +7 -0
- inspect_ai/solver/_bridge/bridge.py +1 -1
- inspect_ai/solver/_chain.py +7 -6
- inspect_ai/tool/_tools/_computer/_computer.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +1 -1
- inspect_ai/tool/_tools/_web_search.py +2 -2
- inspect_ai/util/_sandbox/context.py +2 -1
- inspect_ai/util/_sandbox/environment.py +17 -2
- {inspect_ai-0.3.73.dist-info → inspect_ai-0.3.75.dist-info}/METADATA +4 -4
- {inspect_ai-0.3.73.dist-info → inspect_ai-0.3.75.dist-info}/RECORD +63 -60
- {inspect_ai-0.3.73.dist-info → inspect_ai-0.3.75.dist-info}/WHEEL +1 -1
- {inspect_ai-0.3.73.dist-info → inspect_ai-0.3.75.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.73.dist-info → inspect_ai-0.3.75.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.73.dist-info → inspect_ai-0.3.75.dist-info}/top_level.txt +0 -0
inspect_ai/_eval/evalset.py
CHANGED
@@ -1,7 +1,7 @@
 import hashlib
 import logging
 from copy import deepcopy
-from typing import Any,
+from typing import Any, Literal, NamedTuple, Set, cast

 import rich
 from pydantic_core import to_json
@@ -17,6 +17,7 @@ from typing_extensions import Unpack

 from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.file import basename, filesystem
+from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven
 from inspect_ai.approval._policy import ApprovalPolicy
 from inspect_ai.log import EvalLog
 from inspect_ai.log._bundle import bundle_log_dir
@@ -34,11 +35,14 @@ from inspect_ai.model import (
 from inspect_ai.model._generate_config import GenerateConfig
 from inspect_ai.solver._solver import Solver, SolverSpec
 from inspect_ai.util import DisplayType, SandboxEnvironmentType
+from inspect_ai.util._display import init_display_type

 from .eval import eval, eval_init
-from .loader import
-from .task import Epochs
-from .task.
+from .loader import resolve_task_args
+from .task import Epochs
+from .task.resolved import ResolvedTask
+from .task.task import PreviousTask
+from .task.tasks import Tasks

 logger = logging.getLogger(__name__)

@@ -56,7 +60,7 @@ def eval_set(
     retry_wait: float | None = None,
     retry_connections: float | None = None,
     retry_cleanup: bool | None = None,
-    model: str | Model | list[str] | list[Model] | None =
+    model: str | Model | list[str] | list[Model] | None | NotGiven = NOT_GIVEN,
     model_base_url: str | None = None,
     model_args: dict[str, Any] | str = dict(),
     task_args: dict[str, Any] | str = dict(),
@@ -107,9 +111,9 @@ def eval_set(
            (defaults to 0.5)
        retry_cleanup: Cleanup failed log files after retries
           (defaults to True)
-       model: Model(s) for
-
-
+       model: Model(s) for evaluation. If not specified use the value of the INSPECT_EVAL_MODEL
+           environment variable. Specify `None` to define no default model(s), which will
+           leave model usage entirely up to tasks.
        model_base_url: Base URL for communicating
           with the model API.
        model_args: Model creation args
@@ -154,7 +158,7 @@ def eval_set(
        max_samples: Maximum number of samples to run in parallel
           (default is max_connections)
        max_tasks: Maximum number of tasks to run in parallel
-          (
+          (defaults to number of models being evaluated)
        max_subprocesses: Maximum number of subprocesses to
           run in parallel (default is os.cpu_count())
        max_sandboxes: Maximum number of sandboxes (per-provider)
@@ -177,13 +181,11 @@ def eval_set(
     """

     # helper function to run a set of evals
-    def run_eval(
-        tasks: list[Task] | list[PreviousTask], models: list[Model]
-    ) -> list[EvalLog]:
+    def run_eval(tasks: list[ResolvedTask] | list[PreviousTask]) -> list[EvalLog]:
         # run evals
         results = eval(
             tasks=tasks,
-            model=
+            model=None,  # ResolvedTask/PreviousTask already carries its model
             model_base_url=model_base_url,
             model_args=model_args,
             task_args=task_args,
@@ -231,30 +233,10 @@ def eval_set(
         # return results
         return results

-    #
-
-
-
-    ) -> list[EvalLog]:
-        logs: list[EvalLog] = []
-        for task_group in task_groups:
-            # alias
-            group_models, group_tasks = task_group
-
-            # info log
-            logger.info(
-                f"eval_set (running task group): {','.join([task.task.name for task in group_tasks])}: {group_models}"
-            )
-
-            # run the evals
-            logs.extend(
-                run_eval(
-                    tasks=run_tasks(group_tasks),
-                    models=group_models.models,
-                )
-            )
-
-        return logs
+    # initialise display (otherwise eval_init will set it to full)
+    display = init_display_type(display)
+    if display == "conversation":
+        raise RuntimeError("eval_set cannot be used with conversation display.")

     # resolve tasks
     models, _, resolved_tasks = eval_init(
@@ -283,6 +265,7 @@ def eval_set(
     retry_connections = retry_connections or 0.5
     retry_cleanup = retry_cleanup is not False
     max_connections = starting_max_connections(models, GenerateConfig(**kwargs))
+    max_tasks = max_tasks if max_tasks is not None else len(models)

     # prepare console/status
     console = rich.get_console()
@@ -331,15 +314,11 @@ def eval_set(
         pending_tasks = [
             task[1] for task in all_tasks if task[0] not in log_task_identifiers
         ]
-        task_groups = schedule_pending_tasks(pending_tasks)

         # we have some pending tasks yet to run, run them
-        if len(
+        if len(pending_tasks) > 0:
             # run the tasks
-            run_logs =
-                task_groups=task_groups,
-                run_tasks=lambda tasks: [task.task for task in tasks],
-            )
+            run_logs = run_eval(pending_tasks)

             # if this was the entire list of resolved tasks, return results
             if len(pending_tasks) == len(all_tasks):
@@ -365,42 +344,10 @@ def eval_set(
                for task in resolved_tasks
                if task_identifier(task) in failed_task_identifiers
            ]
-
-
-
-
-            def task_to_failed_log(task: ResolvedTask) -> Log:
-                resolved_task_identifier = task_identifier(task)
-                return next(
-                    log
-                    for log in failed_logs
-                    if log.task_identifier == resolved_task_identifier
-                )
-
-            previous_tasks: list[PreviousTask] = []
-            for task, log in zip(tasks, map(task_to_failed_log, tasks)):
-                # NOTE: we used to try to recreate registry objects by
-                # by just passing the task name, but that didn't work
-                # when evals were run from another directory. we may
-                # want to bring this back but we'd need to resolve the
-                # directory issues.
-
-                # deepcopy so the same instance is not run twice
-                prev_task = deepcopy(task.task)
-
-                previous_tasks.append(
-                    PreviousTask(
-                        id=log.header.eval.task_id,
-                        task=prev_task,
-                        task_args=resolve_task_args(task.task),
-                        log=read_eval_log(log.info),
-                    )
-                )
-
-            return previous_tasks
-
-            retried_logs = run_task_groups(
-                task_groups=task_groups, run_tasks=run_previous_tasks
+
+            # run previous tasks (no models passed b/c previous task already carries its model)
+            retried_logs = run_eval(
+                tasks=as_previous_tasks(failed_tasks, failed_logs)
            )

            # return success
@@ -443,6 +390,42 @@ def eval_set(
     return success, results


+# convert resolved tasks to previous tasks
+def as_previous_tasks(
+    tasks: list[ResolvedTask], failed_logs: list[Log]
+) -> list[PreviousTask]:
+    def task_to_failed_log(task: ResolvedTask) -> Log:
+        resolved_task_identifier = task_identifier(task)
+        return next(
+            log
+            for log in failed_logs
+            if log.task_identifier == resolved_task_identifier
+        )
+
+    previous_tasks: list[PreviousTask] = []
+    for task, log in zip(tasks, map(task_to_failed_log, tasks)):
+        # NOTE: we used to try to recreate registry objects by
+        # by just passing the task name, but that didn't work
+        # when evals were run from another directory. we may
+        # want to bring this back but we'd need to resolve the
+        # directory issues.
+
+        # deepcopy so the same instance is not run twice
+        prev_task = deepcopy(task.task)
+
+        previous_tasks.append(
+            PreviousTask(
+                id=log.header.eval.task_id,
+                task=prev_task,
+                task_args=resolve_task_args(task.task),
+                model=task.model,
+                log=read_eval_log(log.info),
+            )
+        )
+
+    return previous_tasks
+
+
 # filters to determine when we are done


@@ -574,7 +557,7 @@ def task_identifier(task: ResolvedTask | EvalLog) -> str:
         task_file = task.eval.task_file or ""
         task_name = task.eval.task
         task_args = task.eval.task_args
-        model = task.eval.model
+        model = str(task.eval.model)

     # hash for task args
     task_args_hash = hashlib.sha256(
@@ -617,80 +600,6 @@ class ModelList:
         return ",".join(model_names)


-class TaskGroup(NamedTuple):
-    models: ModelList
-    tasks: list[ResolvedTask]
-
-
-# group into models => tasks for maximum parallelism
-def schedule_pending_tasks(pending_tasks: list[ResolvedTask]) -> list[TaskGroup]:
-    # build a map of task identifiers and the models they target
-    task_id_model_targets: dict[str, ModelList] = {}
-    for pending_task in pending_tasks:
-        task_id = task_identifier_without_model(task_identifier(pending_task))
-        if task_id not in task_id_model_targets:
-            task_id_model_targets[task_id] = ModelList([])
-        if pending_task.model not in task_id_model_targets[task_id].models:
-            task_id_model_targets[task_id].models.append(pending_task.model)
-
-    # build a list of unique model targets
-    unique_model_targets: Set[ModelList] = set(task_id_model_targets.values())
-
-    # create schedule
-    schedule: list[TaskGroup] = [
-        TaskGroup(models=model_target, tasks=[])
-        for model_target in unique_model_targets
-    ]
-
-    for models, tasks in schedule:
-        # which task ids have this set of models
-        task_ids: list[str] = []
-        for task_id, task_models in task_id_model_targets.items():
-            if task_models == models:
-                task_ids.append(task_id)
-
-        # find a task for each of these ids
-        for task_id in task_ids:
-            tasks.append(
-                next(
-                    (
-                        task
-                        for task in pending_tasks
-                        if task_id
-                        == task_identifier_without_model(task_identifier(task))
-                    )
-                )
-            )
-
-    # deterministic return order
-    schedule.sort(key=lambda x: str(x[0]))
-
-    return schedule
-
-
-# group into model => tasks (can't do multiple models b/c these are PreviousTask
-# instances (and therefore model/task pair specific -- we don't want to create
-# multiple instances of these tasks)
-def schedule_retry_tasks(retry_tasks: list[ResolvedTask]) -> list[TaskGroup]:
-    # build a list of unique model targets
-    unique_model_targets: Set[ModelList] = set()
-    for retry_task in retry_tasks:
-        unique_model_targets.add(ModelList([retry_task.model]))
-
-    # create a task group for reach model target
-    schedule: list[TaskGroup] = []
-    for model_target in unique_model_targets:
-        group_tasks = [
-            task for task in retry_tasks if ModelList([task.model]) == model_target
-        ]
-        schedule.append(TaskGroup(model_target, group_tasks))
-
-    # deterministic return order
-    schedule.sort(key=lambda x: str(x[0]))
-
-    return schedule
-
-
 def starting_max_connections(models: list[Model], config: GenerateConfig) -> int:
     # if there is an explicit config use that
     if config.max_connections is not None:
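The evalset.py changes above remove the old per-model task-group scheduling in favour of a single run_eval() helper: model now defaults to a NOT_GIVEN sentinel, and each resolved or previous task carries its own model. A minimal caller-side sketch of the new parameter semantics, not taken from the package; the task, log directory, and model string below are illustrative placeholders:

# Hedged sketch: exercising the new eval_set() model semantics.
from inspect_ai import Task, eval_set, task
from inspect_ai.dataset import Sample
from inspect_ai.scorer import match


@task
def addition() -> Task:
    # hypothetical single-sample task used only for illustration
    return Task(
        dataset=[Sample(input="What is 1 + 1?", target="2")],
        scorer=match(),
    )


# Omitting `model` falls back to the INSPECT_EVAL_MODEL environment variable
# (the NOT_GIVEN default); passing an explicit model string evaluates against
# that model; passing model=None defines no default model at all and leaves
# model usage entirely up to the tasks.
success, logs = eval_set(
    tasks=[addition()],
    log_dir="logs/addition-set",
    model="openai/gpt-4o-mini",
)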
inspect_ai/_eval/loader.py
CHANGED
@@ -2,7 +2,6 @@ import ast
 import contextlib
 import inspect
 import os
-from dataclasses import dataclass, field
 from importlib.machinery import SourceFileLoader
 from importlib.util import module_from_spec, spec_from_loader
 from logging import getLogger
@@ -12,6 +11,7 @@ from typing import Any, Callable, Tuple, cast

 from typing_extensions import overload

+from inspect_ai._eval.task.resolved import ResolvedTask
 from inspect_ai._eval.task.util import task_file, task_run_dir
 from inspect_ai._util._async import configured_async_backend
 from inspect_ai._util.decorator import parse_decorators
@@ -26,44 +26,26 @@ from inspect_ai._util.registry import (
     registry_lookup,
     registry_params,
 )
-from inspect_ai.model import Model
+from inspect_ai.model import Model
 from inspect_ai.scorer._scorer import Scorer, ScorerSpec, scorer_create
 from inspect_ai.solver._bridge import bridge
 from inspect_ai.solver._solver import Solver, SolverSpec
 from inspect_ai.util import SandboxEnvironmentSpec, SandboxEnvironmentType
-from inspect_ai.util._sandbox.environment import
+from inspect_ai.util._sandbox.environment import (
+    resolve_sandbox_environment,
+)
 from inspect_ai.util._sandbox.registry import registry_find_sandboxenv

 from .list import task_files
 from .registry import task_create
-from .task import PreviousTask, Task, TaskInfo
+from .task import PreviousTask, Task, TaskInfo
 from .task.constants import TASK_FILE_ATTR, TASK_RUN_DIR_ATTR
-from .task.run import
+from .task.run import eval_log_sample_source
+from .task.tasks import Tasks

 logger = getLogger(__name__)


-@dataclass(frozen=True)
-class ResolvedTask:
-    task: Task
-    task_args: dict[str, Any]
-    task_file: str | None
-    model: Model
-    sandbox: SandboxEnvironmentSpec | None
-    sequence: int
-    id: str | None = field(default=None)
-    sample_source: EvalSampleSource | None = field(default=None)
-
-    @property
-    def has_sandbox(self) -> bool:
-        if self.sandbox:
-            return True
-        else:
-            return any(
-                [True if sample.sandbox else False for sample in self.task.dataset]
-            )
-
-
 def resolve_tasks(
     tasks: Tasks,
     task_args: dict[str, Any],
@@ -76,16 +58,22 @@ def resolve_tasks(
                task=task,
                task_args=resolve_task_args(task),
                task_file=task_file(task, relative=True),
-                model=model,
+                model=task.model or model,
                sandbox=resolve_task_sandbox(task, sandbox),
                sequence=sequence,
            )
            for sequence, task in enumerate(tasks)
        ]

+    # reflect resolved tasks right back
+    if isinstance(tasks, ResolvedTask):
+        return [tasks]
+    elif isinstance(tasks, list) and isinstance(tasks[0], ResolvedTask):
+        return cast(list[ResolvedTask], tasks)
+
     # take empty lists out of play
     if isinstance(tasks, list) and len(tasks) == 0:
-        return as_resolved_tasks(load_tasks(None,
+        return as_resolved_tasks(load_tasks(None, task_args))

     # simple cases of passing us Task objects
     if isinstance(tasks, Task):
@@ -109,9 +97,7 @@ def resolve_tasks(
                loaded_task = previous_task.task
            else:
                loaded_task_args = previous_task.task_args
-                loaded_task = load_tasks([previous_task.task],
-                    0
-                ]
+                loaded_task = load_tasks([previous_task.task], loaded_task_args)[0]
            loaded_tasks.append(loaded_task)
            loaded_tasks_args.append(loaded_task_args)

@@ -120,7 +106,7 @@ def resolve_tasks(
                task=loaded_task,
                task_args=loaded_task_args,
                task_file=previous_task.log.eval.task_file,
-                model=model,
+                model=previous_task.model or loaded_task.model or model,
                sandbox=previous_task.log.eval.sandbox,
                sequence=sequence,
                id=previous_task.id,
@@ -153,19 +139,14 @@ def resolve_tasks(
        tasks = [tasks]

     # done! let's load the tasks
-    return as_resolved_tasks(
-        load_tasks(cast(list[str] | None, tasks), model, task_args)
-    )
+    return as_resolved_tasks(load_tasks(cast(list[str] | None, tasks), task_args))


 def resolve_task_args(task: Task) -> dict[str, Any]:
     # was the task instantiated via the registry or a decorator?
     # if so then we can get the task_args from the registry.
     try:
-        # filter out model as that's dyanmic and automatically passed
         task_args = dict(registry_params(task))
-        if "model" in task_args:
-            del task_args["model"]
         return task_args

     # if it wasn't instantiated via the registry or a decorator
@@ -217,34 +198,29 @@ def resolve_task_sandbox(


 def load_tasks(
-    task_specs: list[str] | None,
+    task_specs: list[str] | None, task_args: dict[str, Any] = {}
 ) -> list[Task]:
     """Load one more more tasks (if no tasks are specified, load from the current working directory"""
-    # determine ModelName object for task creation parameterized by model
-    model_name = ModelName(model)
     # load tasks
     return [
         spec
         for task_spec in (task_specs if task_specs else [Path.cwd().as_posix()])
-        for spec in load_task_spec(task_spec,
+        for spec in load_task_spec(task_spec, task_args)
     ]


-def load_task_spec(
-    task_spec: str, model: ModelName, task_args: dict[str, Any] = {}
-) -> list[Task]:
+def load_task_spec(task_spec: str, task_args: dict[str, Any] = {}) -> list[Task]:
     # task in a python package
     if registry_lookup("task", task_spec) is not None:
         # create the task from a python package
-        return [task_create(task_spec,
+        return [task_create(task_spec, **task_args)]
     else:
         # load tasks from glob
-        return create_tasks([task_spec],
+        return create_tasks([task_spec], task_args)


 def create_tasks(
     globs: list[str],
-    model: ModelName,
     task_args: dict[str, Any] = {},
     root_dir: Path | None = None,
 ) -> list[Task]:
@@ -261,9 +237,7 @@ def create_tasks(
         if spec_split[1] is not None:
             task_path = Path(spec_split[0])
             load_file_tasks(task_path.absolute())
-            tasks.extend(
-                create_file_tasks(task_path, model, [spec_split[1]], task_args)
-            )
+            tasks.extend(create_file_tasks(task_path, [spec_split[1]], task_args))
         else:
             # if the glob is the root dir then set it to empty (will result in
             # enumeration of the root dir)
@@ -271,7 +245,7 @@ def create_tasks(
             files = task_files(target, root_dir)
             files = sorted(files, key=lambda f: f.as_posix())
             for file in files:
-                tasks.extend(create_file_tasks(file,
+                tasks.extend(create_file_tasks(file, None, task_args))
     return tasks


@@ -282,7 +256,6 @@ def load_file_tasks(file: Path) -> None:

 def create_file_tasks(
     file: Path,
-    model: ModelName,
     task_specs: list[str] | list[RegistryInfo] | None = None,
     task_args: dict[str, Any] = {},
 ) -> list[Task]:
@@ -302,7 +275,7 @@ def create_file_tasks(
         # create the task from the loaded source file and
         # note that it was loaded from this directory
         # (will be used later to ensure it runs in the directory)
-        task = task_create(task_spec,
+        task = task_create(task_spec, **task_args)
         setattr(task, TASK_FILE_ATTR, file.as_posix())
         setattr(task, TASK_RUN_DIR_ATTR, run_dir)
         tasks.append(task)
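With the loader changes above, tasks are created purely from their own arguments (no implicit model parameter), and resolve_task_args() no longer strips a "model" entry from the registry-recorded params. A small hedged sketch of that behaviour; the task below is hypothetical and the printed dict is indicative only, since the exact contents depend on how the registry records creation params:

from inspect_ai import Task, task
from inspect_ai._eval.loader import resolve_task_args


@task
def parity(difficulty: str = "easy") -> Task:
    # hypothetical registry-registered task
    return Task()


# creation args recorded by the registry are returned as-is;
# "model" is no longer filtered out of the result
print(resolve_task_args(parity(difficulty="hard")))  # e.g. {'difficulty': 'hard'}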
inspect_ai/_eval/registry.py
CHANGED
@@ -1,6 +1,5 @@
 import inspect
 import logging
-from copy import deepcopy
 from functools import wraps
 from pathlib import Path
 from typing import Any, Callable, TypeVar, cast, overload
@@ -16,7 +15,6 @@ from inspect_ai._util.registry import (
     registry_name,
     registry_tag,
 )
-from inspect_ai.model import ModelName

 from .task import Task
 from .task.constants import TASK_FILE_ATTR, TASK_RUN_DIR_ATTR
@@ -54,7 +52,7 @@ def task_register(
     return task


-def task_create(name: str,
+def task_create(name: str, **kwargs: Any) -> Task:
     r"""Create a Task based on its registered name.

     Tasks can be a function that returns a Task or a
@@ -62,17 +60,11 @@ def task_create(name: str, model: ModelName, **kwargs: Any) -> Task:

     Args:
        name (str): Name of task (Optional, defaults to object name)
-       model (ModelName): Model name
        **kwargs (dict): Optional creation arguments for the task

     Returns:
        Task with registry info attribute
     """
-    # bring in model arg (first deepcopy as we will mutate it)
-    # add model to task_args
-    kwargs = deepcopy(kwargs)
-    kwargs[MODEL_PARAM] = model
-
     # match kwargs params to signature (warn if param not found)
     # (note that we always pass the 'model' param but tasks aren't
     # required to consume it, so we don't warn for 'model')
@@ -85,7 +77,6 @@ def task_create(name: str, model: ModelName, **kwargs: Any) -> Task:
     for param in kwargs.keys():
         if param in task_params:
             task_args[param] = kwargs[param]
-        elif param != MODEL_PARAM:
            if "kwargs" in task_params:
                task_args[param] = kwargs[param]
            else:
inspect_ai/_eval/run.py
CHANGED
@@ -2,7 +2,7 @@ import functools
 import logging
 import os
 import sys
-from typing import
+from typing import Awaitable, Callable, Set, cast

 from inspect_ai._util.trace import trace_action

@@ -44,11 +44,11 @@ from inspect_ai.util._sandbox.environment import (
 from inspect_ai.util._sandbox.registry import registry_find_sandboxenv

 from .loader import (
-    ResolvedTask,
     as_solver_spec,
     solver_from_spec,
 )
 from .task.log import TaskLogger
+from .task.resolved import ResolvedTask
 from .task.run import TaskRunOptions, task_run
 from .task.rundir import task_run_dir_switching
 from .task.sandbox import TaskSandboxEnvironment, resolve_sandbox_for_task
@@ -64,7 +64,6 @@ async def eval_run(
     eval_config: EvalConfig,
     eval_sandbox: SandboxEnvironmentType | None,
     recorder: Recorder,
-    model_args: dict[str, Any],
     epochs_reducer: list[ScoreReducer] | None = None,
     solver: Solver | SolverSpec | None = None,
     tags: list[str] | None = None,
@@ -200,7 +199,7 @@ async def eval_run(
                sandbox=resolved_task.sandbox,
                task_attribs=task.attribs,
                task_args=resolved_task.task_args,
-                model_args=model_args,
+                model_args=resolved_task.model.model_args,
                eval_config=task_eval_config,
                metadata=task.metadata,
                recorder=recorder,
inspect_ai/_eval/task/__init__.py
CHANGED
@@ -1,4 +1,10 @@
-from .task import Task, TaskInfo, PreviousTask,
+from .task import Task, TaskInfo, PreviousTask, task_with  # noqa: I001, F401
 from .epochs import Epochs

-__all__ = [
+__all__ = [
+    "Epochs",
+    "Task",
+    "TaskInfo",
+    "PreviousTask",
+    "task_with",
+]
inspect_ai/_eval/task/log.py
CHANGED
@@ -1,5 +1,6 @@
 from importlib import metadata as importlib_metadata
-from
+from inspect import isgenerator
+from typing import Any, Iterator, Literal, cast

 from shortuuid import uuid

@@ -83,6 +84,13 @@ class TaskLogger:
            del model_args["api_key"]
        model_args = {k: v for k, v in model_args.items() if not k.startswith("aws_")}

+        # don't try to serialise generators
+        model_args = {
+            k: v
+            for k, v in model_args.items()
+            if not isgenerator(v) and not isinstance(v, Iterator)
+        }
+
        # cwd_relative_path for sandbox config
        if sandbox and isinstance(sandbox.config, str):
            sandbox = SandboxEnvironmentSpec(
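The TaskLogger change above filters generator and iterator values out of model_args before the log header is serialised, since such values cannot be converted to JSON. A standalone illustration of the same idiom; the "stream" key is a made-up example value, not from the package:

import json
from inspect import isgenerator
from typing import Any, Iterator

model_args: dict[str, Any] = {
    "temperature": 0.2,                      # plain values serialise fine
    "stream": (c for c in ["a", "b", "c"]),  # a generator value cannot
}

# same filtering idiom as the diff hunk above
model_args = {
    k: v
    for k, v in model_args.items()
    if not isgenerator(v) and not isinstance(v, Iterator)
}

print(json.dumps(model_args))  # {"temperature": 0.2}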
inspect_ai/_eval/task/resolved.py
ADDED
@@ -0,0 +1,35 @@
+from dataclasses import dataclass, field
+from typing import Any, Set
+
+from inspect_ai._eval.task import Task
+from inspect_ai._eval.task.run import EvalSampleSource
+from inspect_ai.model import Model
+from inspect_ai.util import SandboxEnvironmentSpec
+
+
+@dataclass(frozen=True)
+class ResolvedTask:
+    task: Task
+    task_args: dict[str, Any]
+    task_file: str | None
+    model: Model
+    sandbox: SandboxEnvironmentSpec | None
+    sequence: int
+    id: str | None = field(default=None)
+    sample_source: EvalSampleSource | None = field(default=None)
+
+    @property
+    def has_sandbox(self) -> bool:
+        if self.sandbox:
+            return True
+        else:
+            return any(
+                [True if sample.sandbox else False for sample in self.task.dataset]
+            )
+
+
+def resolved_model_names(tasks: list[ResolvedTask]) -> list[str]:
+    models: Set[str] = set()
+    for task in tasks:
+        models.add(str(task.model))
+    return list(models)
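resolved.py is a new module that hosts the ResolvedTask dataclass previously defined in loader.py, plus a resolved_model_names() helper. These are internal types that eval()/eval_set() construct for you; the hedged sketch below simply builds one by hand to show the fields and helpers. The sample content and the mockllm model string are illustrative, and the printed outputs are indicative only:

from inspect_ai import Task
from inspect_ai.dataset import Sample
from inspect_ai.model import get_model
from inspect_ai._eval.task.resolved import ResolvedTask, resolved_model_names

resolved = ResolvedTask(
    task=Task(dataset=[Sample(input="2 + 2 = ?", target="4")]),
    task_args={},
    task_file=None,
    model=get_model("mockllm/model"),
    sandbox=None,
    sequence=0,
)

# no task-level sandbox and no sample-level sandbox configured
print(resolved.has_sandbox)              # False
# unique model names (via str(model)) across a list of resolved tasks
print(resolved_model_names([resolved]))  # e.g. ['mockllm/model']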