inspect-ai 0.3.87__py3-none-any.whl → 0.3.89__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. inspect_ai/_cli/eval.py +16 -0
  2. inspect_ai/_cli/score.py +1 -12
  3. inspect_ai/_cli/util.py +4 -2
  4. inspect_ai/_display/core/footer.py +2 -2
  5. inspect_ai/_display/plain/display.py +2 -2
  6. inspect_ai/_eval/context.py +7 -1
  7. inspect_ai/_eval/eval.py +51 -27
  8. inspect_ai/_eval/evalset.py +27 -10
  9. inspect_ai/_eval/loader.py +7 -8
  10. inspect_ai/_eval/run.py +23 -31
  11. inspect_ai/_eval/score.py +18 -1
  12. inspect_ai/_eval/task/log.py +5 -13
  13. inspect_ai/_eval/task/resolved.py +1 -0
  14. inspect_ai/_eval/task/run.py +231 -244
  15. inspect_ai/_eval/task/task.py +25 -2
  16. inspect_ai/_eval/task/util.py +1 -8
  17. inspect_ai/_util/constants.py +1 -0
  18. inspect_ai/_util/json.py +8 -3
  19. inspect_ai/_util/registry.py +30 -13
  20. inspect_ai/_view/www/App.css +5 -0
  21. inspect_ai/_view/www/dist/assets/index.css +55 -18
  22. inspect_ai/_view/www/dist/assets/index.js +550 -458
  23. inspect_ai/_view/www/log-schema.json +84 -1
  24. inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
  25. inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
  26. inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
  27. inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
  28. inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
  29. inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
  30. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +3 -3
  31. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
  32. inspect_ai/_view/www/src/types/log.d.ts +150 -129
  33. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
  34. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
  35. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
  36. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
  37. inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
  38. inspect_ai/agent/_agent.py +12 -0
  39. inspect_ai/agent/_as_tool.py +1 -1
  40. inspect_ai/agent/_bridge/bridge.py +9 -2
  41. inspect_ai/agent/_react.py +142 -74
  42. inspect_ai/agent/_run.py +13 -2
  43. inspect_ai/agent/_types.py +6 -0
  44. inspect_ai/approval/_apply.py +6 -9
  45. inspect_ai/approval/_approver.py +3 -3
  46. inspect_ai/approval/_auto.py +2 -2
  47. inspect_ai/approval/_call.py +20 -4
  48. inspect_ai/approval/_human/approver.py +3 -3
  49. inspect_ai/approval/_human/manager.py +2 -2
  50. inspect_ai/approval/_human/panel.py +3 -3
  51. inspect_ai/approval/_policy.py +3 -3
  52. inspect_ai/log/__init__.py +2 -0
  53. inspect_ai/log/_log.py +23 -2
  54. inspect_ai/log/_model.py +58 -0
  55. inspect_ai/log/_recorders/file.py +14 -3
  56. inspect_ai/log/_transcript.py +3 -0
  57. inspect_ai/model/__init__.py +2 -0
  58. inspect_ai/model/_call_tools.py +15 -2
  59. inspect_ai/model/_model.py +49 -3
  60. inspect_ai/model/_openai.py +151 -21
  61. inspect_ai/model/_providers/anthropic.py +25 -14
  62. inspect_ai/model/_providers/bedrock.py +3 -3
  63. inspect_ai/model/_providers/cloudflare.py +29 -108
  64. inspect_ai/model/_providers/google.py +21 -10
  65. inspect_ai/model/_providers/grok.py +23 -17
  66. inspect_ai/model/_providers/groq.py +61 -37
  67. inspect_ai/model/_providers/llama_cpp_python.py +8 -9
  68. inspect_ai/model/_providers/mistral.py +8 -3
  69. inspect_ai/model/_providers/ollama.py +8 -9
  70. inspect_ai/model/_providers/openai.py +53 -157
  71. inspect_ai/model/_providers/openai_compatible.py +195 -0
  72. inspect_ai/model/_providers/openrouter.py +4 -15
  73. inspect_ai/model/_providers/providers.py +11 -0
  74. inspect_ai/model/_providers/together.py +25 -23
  75. inspect_ai/model/_trim.py +83 -0
  76. inspect_ai/solver/_plan.py +5 -3
  77. inspect_ai/tool/_tool_call.py +3 -0
  78. inspect_ai/tool/_tool_def.py +8 -2
  79. inspect_ai/util/__init__.py +3 -0
  80. inspect_ai/util/_concurrency.py +15 -2
  81. {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/METADATA +1 -1
  82. {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/RECORD +86 -81
  83. inspect_ai/_eval/task/rundir.py +0 -78
  84. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  85. {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/WHEEL +0 -0
  86. {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/entry_points.txt +0 -0
  87. {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/licenses/LICENSE +0 -0
  88. {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/top_level.txt +0 -0
inspect_ai/_eval/score.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Callable, Literal, cast
  import anyio

  from inspect_ai._display import display
+ from inspect_ai._eval.context import init_task_context
  from inspect_ai._eval.loader import scorer_from_spec
  from inspect_ai._util._async import configured_async_backend, run_coroutine, tg_collect
  from inspect_ai._util.platform import platform_init, running_in_notebook
@@ -14,7 +15,9 @@ from inspect_ai.log import (
  EvalLog,
  )
  from inspect_ai.log._log import EvalMetricDefinition
+ from inspect_ai.log._model import model_roles_config_to_model_roles
  from inspect_ai.model import ModelName
+ from inspect_ai.model._model import get_model
  from inspect_ai.scorer import Metric, Scorer, Target
  from inspect_ai.scorer._metric import SampleScore
  from inspect_ai.scorer._reducer import (
@@ -122,7 +125,7 @@ async def score_async(
  scores: list[dict[str, SampleScore]] = await tg_collect(
  [
  functools.partial(
- run_score_task, state, Target(sample.target), scorers, progress
+ run_score_task, log, state, Target(sample.target), scorers, progress
  )
  for (sample, state) in zip(log.samples, states)
  ]
@@ -218,11 +221,25 @@ async def task_score(


  async def run_score_task(
+ log: EvalLog,
  state: TaskState,
  target: Target,
  scorers: list[Scorer],
  progress: Callable[..., None],
  ) -> dict[str, SampleScore]:
+ # get the model then initialize the async context
+ model = get_model(
+ model=log.eval.model,
+ config=log.plan.config.merge(log.eval.model_generate_config),
+ **log.eval.model_args,
+ )
+
+ # get the model roles
+ model_roles = model_roles_config_to_model_roles(log.eval.model_roles)
+
+ # initialize active model
+ init_task_context(model, model_roles)
+
  results: dict[str, SampleScore] = {}
  for scorer in scorers:
  result = await scorer(state, target)
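In practical terms, each re-scoring task now resolves the model (and any model roles) recorded in the eval log and initializes the async task context before invoking scorers. The sketch below restates the added lines as a standalone helper; the function name init_scoring_context is hypothetical, while the imports and calls mirror the internal (underscore-prefixed) modules referenced in the diff above.

```python
from inspect_ai._eval.context import init_task_context
from inspect_ai.log import EvalLog
from inspect_ai.log._model import model_roles_config_to_model_roles
from inspect_ai.model._model import get_model


def init_scoring_context(log: EvalLog) -> None:
    # resolve the model recorded in the eval log, merging the plan's
    # config with the generate config captured at eval time
    model = get_model(
        model=log.eval.model,
        config=log.plan.config.merge(log.eval.model_generate_config),
        **log.eval.model_args,
    )

    # rehydrate any named model roles stored in the log header
    model_roles = model_roles_config_to_model_roles(log.eval.model_roles)

    # make the model (and roles) the active context for scorers
    init_task_context(model, model_roles)
```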
inspect_ai/_eval/task/log.py CHANGED
@@ -1,6 +1,5 @@
  from importlib import metadata as importlib_metadata
- from inspect import isgenerator
- from typing import Any, Iterator, Literal, cast
+ from typing import Any, Literal, cast

  from shortuuid import uuid

@@ -34,6 +33,7 @@ from inspect_ai.log._log import (
  EvalScorer,
  eval_config_defaults,
  )
+ from inspect_ai.log._model import model_args_for_log, model_roles_to_model_roles_config
  from inspect_ai.log._recorders import Recorder
  from inspect_ai.log._recorders.buffer import SampleBufferDatabase
  from inspect_ai.log._recorders.types import SampleEvent, SampleSummary
@@ -63,6 +63,7 @@ class TaskLogger:
  solver: SolverSpec | None,
  tags: list[str] | None,
  model: Model,
+ model_roles: dict[str, Model] | None,
  dataset: Dataset,
  scorer: list[ScorerSpec] | None,
  metrics: list[MetricSpec] | dict[str, list[MetricSpec]] | None,
@@ -84,17 +85,7 @@ class TaskLogger:
  packages = {PKG_NAME: importlib_metadata.version(PKG_NAME)}

  # redact authentication oriented model_args
- model_args = model_args.copy()
- if "api_key" in model_args:
- del model_args["api_key"]
- model_args = {k: v for k, v in model_args.items() if not k.startswith("aws_")}
-
- # don't try to serialise generators
- model_args = {
- k: v
- for k, v in model_args.items()
- if not isgenerator(v) and not isinstance(v, Iterator)
- }
+ model_args = model_args_for_log(model_args)

  # cwd_relative_path for sandbox config
  if sandbox and isinstance(sandbox.config, str):
@@ -141,6 +132,7 @@ class TaskLogger:
  model=str(ModelName(model)),
  model_generate_config=model.config,
  model_base_url=model.api.base_url,
+ model_roles=model_roles_to_model_roles_config(model_roles),
  dataset=EvalDataset(
  name=dataset.name,
  location=cwd_relative_path(dataset.location),
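The inline redaction logic removed here has been consolidated into the new inspect_ai/log/_model.py module (listed in the file summary as +58 lines). The following is a plausible sketch of what model_args_for_log covers, reconstructed from the removed lines rather than taken from the new module's actual source:

```python
from inspect import isgenerator
from typing import Any, Iterator


def model_args_for_log(model_args: dict[str, Any]) -> dict[str, Any]:
    # redact authentication-oriented args (api_key and aws_* credentials)
    redacted = {
        k: v
        for k, v in model_args.items()
        if k != "api_key" and not k.startswith("aws_")
    }
    # don't try to serialise generators or other iterators into the log
    return {
        k: v
        for k, v in redacted.items()
        if not isgenerator(v) and not isinstance(v, Iterator)
    }
```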
inspect_ai/_eval/task/resolved.py CHANGED
@@ -13,6 +13,7 @@ class ResolvedTask:
  task_args: dict[str, Any]
  task_file: str | None
  model: Model
+ model_roles: dict[str, Model] | None
  sandbox: SandboxEnvironmentSpec | None
  sequence: int
  id: str | None = field(default=None)
inspect_ai/_eval/task/run.py CHANGED
@@ -101,7 +101,6 @@ from .images import (
  )
  from .log import TaskLogger, collect_eval_data, log_start
  from .results import eval_results
- from .rundir import set_task_chdir
  from .sandbox import sandboxenv_context
  from .util import sample_messages, slice_dataset

@@ -121,6 +120,7 @@ SAMPLE_TOTAL_PROGRESS_UNITS = 1
  class TaskRunOptions:
  task: Task
  model: Model
+ model_roles: dict[str, Model] | None
  sandbox: SandboxEnvironmentSpec | None
  logger: TaskLogger
  eval_wd: str
@@ -137,6 +137,7 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
  # destructure options
  task = options.task
  model = options.model
+ model_roles = options.model_roles
  sandbox = options.sandbox
  logger = options.logger
  eval_wd = options.eval_wd
@@ -151,156 +152,136 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
  generate_config = task.config.merge(GenerateConfigArgs(**kwargs))

  # init task context
- init_task_context(model, options.task.approval, generate_config)
-
- # establish chdir for duration of execution (if a task has chdir=True)
- with set_task_chdir(task):
- # track stats and error
- results: EvalResults | None = None
- reductions: list[EvalSampleReductions] | None = None
- stats = EvalStats(started_at=iso_now())
-
- # handle sample errors (raise as required)
- sample_error_handler = SampleErrorHandler(
- config.fail_on_error, len(task.dataset)
- )
-
- # resolve some config
- model_name = ModelName(model)
- epochs = config.epochs if config.epochs else DEFAULT_EPOCHS
- sandbox_cleanup = config.sandbox_cleanup is not False
- log_images = config.log_images is not False
- log_samples = config.log_samples is not False
-
- # resolve dataset
- _, samples, states = await resolve_dataset(
- dataset=task.dataset,
- model_name=model_name,
- limit=config.limit,
- sample_id=config.sample_id,
- epochs=epochs,
- log_images=log_images,
- message_limit=config.message_limit,
- token_limit=config.token_limit,
- )
-
- # resolve the plan (unroll chains)
- solver = solver or task.solver
- if isinstance(solver, Plan):
- plan = solver
- elif isinstance(solver, Chain):
- plan = Plan(list(solver), cleanup=task.cleanup, internal=True)
- else:
- plan = Plan(unroll(solver), cleanup=task.cleanup, internal=True)
-
- # add setup solver(s) if specified
- if task.setup:
- plan.steps = unroll(task.setup) + plan.steps
-
- # resolve the scorer
- score = score and task.scorer is not None
- scorers: list[Scorer] | None = task.scorer if (score and task.scorer) else None
- scorer_profiles = (
- [
- registry_log_name(scorer)
- for scorer in scorers
- if is_registry_object(scorer)
- ]
- if scorers is not None
- else ["(none)"]
- )
-
- # compute an eval directory relative log location if we can
- if PurePath(logger.location).is_relative_to(PurePath(eval_wd)):
- log_location = PurePath(logger.location).relative_to(eval_wd).as_posix()
- else:
- log_location = logger.location
-
- # create task profile for display
- profile = TaskProfile(
- name=task.name,
- file=logger.eval.task_file,
- model=model_name,
- dataset=task.dataset.name or "(samples)",
- scorer=", ".join(scorer_profiles),
- samples=len(samples),
- steps=len(samples) * SAMPLE_TOTAL_PROGRESS_UNITS,
- eval_config=config,
- task_args=logger.eval.task_args,
- generate_config=generate_config,
- tags=tags,
- log_location=log_location,
- )
+ init_task_context(model, model_roles, options.task.approval, generate_config)
+
+ # track stats and error
+ results: EvalResults | None = None
+ reductions: list[EvalSampleReductions] | None = None
+ stats = EvalStats(started_at=iso_now())
+
+ # handle sample errors (raise as required)
+ sample_error_handler = SampleErrorHandler(config.fail_on_error, len(task.dataset))
+
+ # resolve some config
+ model_name = ModelName(model)
+ epochs = config.epochs if config.epochs else DEFAULT_EPOCHS
+ sandbox_cleanup = config.sandbox_cleanup is not False
+ log_images = config.log_images is not False
+ log_samples = config.log_samples is not False
+
+ # resolve dataset
+ _, samples, states = await resolve_dataset(
+ dataset=task.dataset,
+ model_name=model_name,
+ limit=config.limit,
+ sample_id=config.sample_id,
+ epochs=epochs,
+ log_images=log_images,
+ message_limit=config.message_limit,
+ token_limit=config.token_limit,
+ )

- with display().task(
- profile,
- ) as td:
- try:
- # start the log
- await log_start(logger, plan, generate_config)
-
- with td.progress() as p:
- # forward progress
- def progress(number: int) -> None:
- p.update(number)
-
- # provide solvers a function that they can use to generate output
- async def generate(
- state: TaskState,
- tool_calls: Literal["loop", "single", "none"] = "loop",
- cache: bool | CachePolicy = False,
- **kwargs: Unpack[GenerateConfigArgs],
- ) -> TaskState:
- return await task_generate(
- model=model,
- state=state,
- tool_calls=tool_calls,
- cache=cache,
- config=generate_config.merge(kwargs),
- )
+ # resolve the plan (unroll chains)
+ solver = solver or task.solver
+ if isinstance(solver, Plan):
+ plan = solver
+ elif isinstance(solver, Chain):
+ plan = Plan(list(solver), cleanup=task.cleanup, internal=True)
+ else:
+ plan = Plan(unroll(solver), cleanup=task.cleanup, internal=True)
+
+ # add setup solver(s) if specified
+ if task.setup:
+ plan.steps = unroll(task.setup) + plan.steps
+
+ # resolve the scorer
+ score = score and task.scorer is not None
+ scorers: list[Scorer] | None = task.scorer if (score and task.scorer) else None
+ scorer_profiles = (
+ [registry_log_name(scorer) for scorer in scorers if is_registry_object(scorer)]
+ if scorers is not None
+ else ["(none)"]
+ )

- # set generate for fork module
- set_task_generate(generate)
+ # compute an eval directory relative log location if we can
+ if PurePath(logger.location).is_relative_to(PurePath(eval_wd)):
+ log_location = PurePath(logger.location).relative_to(eval_wd).as_posix()
+ else:
+ log_location = logger.location
+
+ # create task profile for display
+ profile = TaskProfile(
+ name=task.name,
+ file=logger.eval.task_file,
+ model=model_name,
+ dataset=task.dataset.name or "(samples)",
+ scorer=", ".join(scorer_profiles),
+ samples=len(samples),
+ steps=len(samples) * SAMPLE_TOTAL_PROGRESS_UNITS,
+ eval_config=config,
+ task_args=logger.eval.task_args,
+ generate_config=generate_config,
+ tags=tags,
+ log_location=log_location,
+ )

- # semaphore to limit concurrency
- sample_semaphore = create_sample_semaphore(
- config, generate_config, model.api
+ with display().task(
+ profile,
+ ) as td:
+ try:
+ # start the log
+ await log_start(logger, plan, generate_config)
+
+ with td.progress() as p:
+ # forward progress
+ def progress(number: int) -> None:
+ p.update(number)
+
+ # provide solvers a function that they can use to generate output
+ async def generate(
+ state: TaskState,
+ tool_calls: Literal["loop", "single", "none"] = "loop",
+ cache: bool | CachePolicy = False,
+ **kwargs: Unpack[GenerateConfigArgs],
+ ) -> TaskState:
+ return await task_generate(
+ model=model,
+ state=state,
+ tool_calls=tool_calls,
+ cache=cache,
+ config=generate_config.merge(kwargs),
  )

- # track when samples complete and update progress as we go
- progress_results: list[dict[str, SampleScore]] = []
+ # set generate for fork module
+ set_task_generate(generate)

- def update_metrics(metrics: list[TaskDisplayMetric]) -> None:
- td.update_metrics(metrics)
- logger.update_metrics(metrics)
+ # semaphore to limit concurrency
+ sample_semaphore = create_sample_semaphore(
+ config, generate_config, model.api
+ )

- update_metrics_display = update_metrics_display_fn(
- update_metrics,
- display_metrics=profile.eval_config.score_display is not False,
- )
+ # track when samples complete and update progress as we go
+ progress_results: list[dict[str, SampleScore]] = []

- def sample_complete(sample_score: dict[str, SampleScore]) -> None:
- # Capture the result
- progress_results.append(sample_score)
+ def update_metrics(metrics: list[TaskDisplayMetric]) -> None:
+ td.update_metrics(metrics)
+ logger.update_metrics(metrics)

- # Increment the segment progress
- td.sample_complete(
- complete=len(progress_results), total=len(samples)
- )
+ update_metrics_display = update_metrics_display_fn(
+ update_metrics,
+ display_metrics=profile.eval_config.score_display is not False,
+ )

- # Update metrics
- update_metrics_display(
- len(progress_results),
- progress_results,
- scorers,
- task.epochs_reducer,
- task.metrics,
- )
+ def sample_complete(sample_score: dict[str, SampleScore]) -> None:
+ # Capture the result
+ progress_results.append(sample_score)

- # initial progress
- td.sample_complete(complete=0, total=len(samples))
+ # Increment the segment progress
+ td.sample_complete(
+ complete=len(progress_results), total=len(samples)
+ )

- # Update metrics to empty state
+ # Update metrics
  update_metrics_display(
  len(progress_results),
  progress_results,
@@ -309,127 +290,133 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
  task.metrics,
  )

- sample_results = await tg_collect(
- [
- functools.partial(
- task_run_sample,
- task_name=task.name,
- sample=sample,
- state=state,
- sandbox=sandbox,
- max_sandboxes=config.max_sandboxes,
- sandbox_cleanup=sandbox_cleanup,
- plan=plan,
- scorers=scorers,
- generate=generate,
- progress=progress,
- logger=logger if log_samples else None,
- log_images=log_images,
- sample_source=sample_source,
- sample_error=sample_error_handler,
- sample_complete=sample_complete,
- fails_on_error=(
- config.fail_on_error is None
- or config.fail_on_error is True
- ),
- time_limit=config.time_limit,
- working_limit=config.working_limit,
- semaphore=sample_semaphore,
- )
- for (sample, state) in zip(samples, states)
- ]
- )
+ # initial progress
+ td.sample_complete(complete=0, total=len(samples))

- # compute and record metrics if we have scores
- completed_scores = [
- score_dict
- for score_dict in sample_results
- if isinstance(score_dict, dict)
- ]
-
- if len(completed_scores) > 0:
- results, reductions = eval_results(
- samples=profile.samples,
- scores=completed_scores,
- reducers=task.epochs_reducer,
- scorers=scorers,
- metrics=task.metrics,
- )
+ # Update metrics to empty state
+ update_metrics_display(
+ len(progress_results),
+ progress_results,
+ scorers,
+ task.epochs_reducer,
+ task.metrics,
+ )

- # collect eval data
- collect_eval_data(stats)
+ sample_results = await tg_collect(
+ [
+ functools.partial(
+ task_run_sample,
+ task_name=task.name,
+ sample=sample,
+ state=state,
+ sandbox=sandbox,
+ max_sandboxes=config.max_sandboxes,
+ sandbox_cleanup=sandbox_cleanup,
+ plan=plan,
+ scorers=scorers,
+ generate=generate,
+ progress=progress,
+ logger=logger if log_samples else None,
+ log_images=log_images,
+ sample_source=sample_source,
+ sample_error=sample_error_handler,
+ sample_complete=sample_complete,
+ fails_on_error=(
+ config.fail_on_error is None
+ or config.fail_on_error is True
+ ),
+ time_limit=config.time_limit,
+ working_limit=config.working_limit,
+ semaphore=sample_semaphore,
+ )
+ for (sample, state) in zip(samples, states)
+ ]
+ )

- # finish w/ success status
- eval_log = await logger.log_finish(
- "success", stats, results, reductions
+ # compute and record metrics if we have scores
+ completed_scores = [
+ score_dict
+ for score_dict in sample_results
+ if isinstance(score_dict, dict)
+ ]
+
+ if len(completed_scores) > 0:
+ results, reductions = eval_results(
+ samples=profile.samples,
+ scores=completed_scores,
+ reducers=task.epochs_reducer,
+ scorers=scorers,
+ metrics=task.metrics,
  )

- # display task summary
- td.complete(
- TaskSuccess(
- samples_completed=logger.samples_completed,
- stats=stats,
- results=results or EvalResults(),
- )
+ # collect eval data
+ collect_eval_data(stats)
+
+ # finish w/ success status
+ eval_log = await logger.log_finish("success", stats, results, reductions)
+
+ # display task summary
+ td.complete(
+ TaskSuccess(
+ samples_completed=logger.samples_completed,
+ stats=stats,
+ results=results or EvalResults(),
  )
+ )

- except anyio.get_cancelled_exc_class():
- with anyio.CancelScope(shield=True):
- # collect eval data
- collect_eval_data(stats)
+ except anyio.get_cancelled_exc_class():
+ with anyio.CancelScope(shield=True):
+ # collect eval data
+ collect_eval_data(stats)

- # finish w/ cancelled status
- eval_log = await logger.log_finish(
- "cancelled", stats, results, reductions
- )
+ # finish w/ cancelled status
+ eval_log = await logger.log_finish(
+ "cancelled", stats, results, reductions
+ )

- # display task cancelled
- td.complete(TaskCancelled(logger.samples_completed, stats))
+ # display task cancelled
+ td.complete(TaskCancelled(logger.samples_completed, stats))

- except BaseException as ex:
- if options.debug_errors:
- raise
- else:
- # get exception info
- type, value, traceback = sys.exc_info()
- type = type if type else BaseException
- value = value if value else ex
+ except BaseException as ex:
+ if options.debug_errors:
+ raise
+ else:
+ # get exception info
+ type, value, traceback = sys.exc_info()
+ type = type if type else BaseException
+ value = value if value else ex

- # build eval error
- error = eval_error(ex, type, value, traceback)
+ # build eval error
+ error = eval_error(ex, type, value, traceback)

- # collect eval data
- collect_eval_data(stats)
+ # collect eval data
+ collect_eval_data(stats)

- # finish with error status
- eval_log = await logger.log_finish(
- "error", stats, results, reductions, error
- )
+ # finish with error status
+ eval_log = await logger.log_finish(
+ "error", stats, results, reductions, error
+ )

- # display it
- td.complete(
- TaskError(logger.samples_completed, type, value, traceback)
- )
+ # display it
+ td.complete(TaskError(logger.samples_completed, type, value, traceback))

- # notify the view module that an eval just completed
- # (in case we have a view polling for new evals)
- view_notify_eval(logger.location)
+ # notify the view module that an eval just completed
+ # (in case we have a view polling for new evals)
+ view_notify_eval(logger.location)

- try:
- if (
- await send_telemetry("eval_log_location", eval_log.location)
- == "not_handled"
- ):
- # Converting the eval log to JSON is expensive. Only do so if
- # eval_log_location was not handled.
- await send_telemetry("eval_log", eval_log_json_str(eval_log))
- except Exception as ex:
- py_logger.warning(
- f"Error occurred sending telemetry: {exception_message(ex)}"
- )
+ try:
+ if (
+ await send_telemetry("eval_log_location", eval_log.location)
+ == "not_handled"
+ ):
+ # Converting the eval log to JSON is expensive. Only do so if
+ # eval_log_location was not handled.
+ await send_telemetry("eval_log", eval_log_json_str(eval_log))
+ except Exception as ex:
+ py_logger.warning(f"Error occurred sending telemetry: {exception_message(ex)}")

- # return eval log
- return eval_log
+ # return eval log
+ return eval_log


  def update_metrics_display_fn(
@@ -914,7 +901,7 @@ async def resolve_dataset(
  dataset: Dataset,
  model_name: ModelName,
  limit: int | tuple[int, int] | None,
- sample_id: str | int | list[str | int] | None,
+ sample_id: str | int | list[str] | list[int] | list[str | int] | None,
  epochs: int,
  log_images: bool,
  message_limit: int | None,
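The model_roles plumbing visible throughout this diff (ResolvedTask, TaskRunOptions, TaskLogger, init_task_context, and the new inspect_ai/log/_model.py) threads named model roles from eval setup into the log header and back out when a log is re-scored. A hedged usage sketch follows; the model_roles= argument to eval() and the specific role name "grader" are assumptions inferred from this plumbing, not lines shown in the diff.

```python
from inspect_ai import Task, eval, task
from inspect_ai.dataset import Sample
from inspect_ai.scorer import match
from inspect_ai.solver import generate


@task
def arithmetic() -> Task:
    return Task(
        dataset=[Sample(input="What is 2 + 2?", target="4")],
        solver=generate(),
        scorer=match(),
    )


# Hypothetical invocation: model_roles is assumed to accept a mapping of role
# name to model spec. Per this diff, such a mapping is recorded in the log
# header (eval.model_roles) and restored when the log is re-scored.
eval(
    arithmetic(),
    model="openai/gpt-4o-mini",
    model_roles={"grader": "anthropic/claude-3-7-sonnet-latest"},
)
```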