inspect-ai 0.3.88__py3-none-any.whl → 0.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. inspect_ai/_cli/eval.py +16 -0
  2. inspect_ai/_cli/score.py +1 -12
  3. inspect_ai/_cli/util.py +4 -2
  4. inspect_ai/_display/core/footer.py +2 -2
  5. inspect_ai/_display/plain/display.py +2 -2
  6. inspect_ai/_eval/context.py +7 -1
  7. inspect_ai/_eval/eval.py +51 -27
  8. inspect_ai/_eval/evalset.py +27 -10
  9. inspect_ai/_eval/loader.py +7 -8
  10. inspect_ai/_eval/run.py +23 -31
  11. inspect_ai/_eval/score.py +18 -1
  12. inspect_ai/_eval/task/log.py +5 -13
  13. inspect_ai/_eval/task/resolved.py +1 -0
  14. inspect_ai/_eval/task/run.py +231 -256
  15. inspect_ai/_eval/task/task.py +25 -2
  16. inspect_ai/_eval/task/util.py +1 -8
  17. inspect_ai/_util/constants.py +1 -0
  18. inspect_ai/_util/json.py +8 -3
  19. inspect_ai/_util/registry.py +30 -13
  20. inspect_ai/_view/www/App.css +5 -0
  21. inspect_ai/_view/www/dist/assets/index.css +71 -36
  22. inspect_ai/_view/www/dist/assets/index.js +573 -475
  23. inspect_ai/_view/www/log-schema.json +66 -0
  24. inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
  25. inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
  26. inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
  27. inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
  28. inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
  29. inspect_ai/_view/www/src/samples/chat/ChatMessage.tsx +2 -2
  30. inspect_ai/_view/www/src/samples/chat/tools/ToolInput.module.css +2 -2
  31. inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
  32. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +12 -6
  33. inspect_ai/_view/www/src/samples/transcript/TranscriptView.module.css +0 -2
  34. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
  35. inspect_ai/_view/www/src/types/log.d.ts +24 -6
  36. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
  37. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
  38. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
  39. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
  40. inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
  41. inspect_ai/agent/_agent.py +12 -0
  42. inspect_ai/agent/_as_tool.py +1 -1
  43. inspect_ai/agent/_bridge/bridge.py +9 -2
  44. inspect_ai/agent/_react.py +142 -74
  45. inspect_ai/agent/_run.py +13 -2
  46. inspect_ai/agent/_types.py +6 -0
  47. inspect_ai/approval/_apply.py +6 -7
  48. inspect_ai/approval/_approver.py +3 -3
  49. inspect_ai/approval/_auto.py +2 -2
  50. inspect_ai/approval/_call.py +20 -4
  51. inspect_ai/approval/_human/approver.py +3 -3
  52. inspect_ai/approval/_human/manager.py +2 -2
  53. inspect_ai/approval/_human/panel.py +3 -3
  54. inspect_ai/approval/_policy.py +3 -3
  55. inspect_ai/log/__init__.py +2 -0
  56. inspect_ai/log/_log.py +23 -2
  57. inspect_ai/log/_model.py +58 -0
  58. inspect_ai/log/_recorders/file.py +14 -3
  59. inspect_ai/log/_transcript.py +3 -0
  60. inspect_ai/model/__init__.py +2 -0
  61. inspect_ai/model/_call_tools.py +4 -1
  62. inspect_ai/model/_model.py +49 -3
  63. inspect_ai/model/_openai.py +151 -21
  64. inspect_ai/model/_providers/anthropic.py +20 -12
  65. inspect_ai/model/_providers/bedrock.py +3 -3
  66. inspect_ai/model/_providers/cloudflare.py +29 -108
  67. inspect_ai/model/_providers/google.py +21 -10
  68. inspect_ai/model/_providers/grok.py +23 -17
  69. inspect_ai/model/_providers/groq.py +61 -37
  70. inspect_ai/model/_providers/llama_cpp_python.py +8 -9
  71. inspect_ai/model/_providers/mistral.py +8 -3
  72. inspect_ai/model/_providers/ollama.py +8 -9
  73. inspect_ai/model/_providers/openai.py +53 -157
  74. inspect_ai/model/_providers/openai_compatible.py +195 -0
  75. inspect_ai/model/_providers/openrouter.py +4 -15
  76. inspect_ai/model/_providers/providers.py +11 -0
  77. inspect_ai/model/_providers/together.py +25 -23
  78. inspect_ai/model/_trim.py +83 -0
  79. inspect_ai/solver/_plan.py +5 -3
  80. inspect_ai/tool/_tool_def.py +8 -2
  81. inspect_ai/util/__init__.py +3 -0
  82. inspect_ai/util/_concurrency.py +15 -2
  83. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/METADATA +1 -1
  84. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/RECORD +88 -83
  85. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/WHEEL +1 -1
  86. inspect_ai/_eval/task/rundir.py +0 -78
  87. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  88. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/entry_points.txt +0 -0
  89. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/licenses/LICENSE +0 -0
  90. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/top_level.txt +0 -0
@@ -101,7 +101,6 @@ from .images import (
  )
  from .log import TaskLogger, collect_eval_data, log_start
  from .results import eval_results
- from .rundir import set_task_chdir
  from .sandbox import sandboxenv_context
  from .util import sample_messages, slice_dataset
 
@@ -121,6 +120,7 @@ SAMPLE_TOTAL_PROGRESS_UNITS = 1
  class TaskRunOptions:
  task: Task
  model: Model
+ model_roles: dict[str, Model] | None
  sandbox: SandboxEnvironmentSpec | None
  logger: TaskLogger
  eval_wd: str
@@ -137,6 +137,7 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
  # destructure options
  task = options.task
  model = options.model
+ model_roles = options.model_roles
  sandbox = options.sandbox
  logger = options.logger
  eval_wd = options.eval_wd
@@ -151,156 +152,136 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
  generate_config = task.config.merge(GenerateConfigArgs(**kwargs))
 
  # init task context
- init_task_context(model, options.task.approval, generate_config)
-
- # establish chdir for duration of execution (if a task has chdir=True)
- with set_task_chdir(task):
- # track stats and error
- results: EvalResults | None = None
- reductions: list[EvalSampleReductions] | None = None
- stats = EvalStats(started_at=iso_now())
-
- # handle sample errors (raise as required)
- sample_error_handler = SampleErrorHandler(
- config.fail_on_error, len(task.dataset)
- )
-
- # resolve some config
- model_name = ModelName(model)
- epochs = config.epochs if config.epochs else DEFAULT_EPOCHS
- sandbox_cleanup = config.sandbox_cleanup is not False
- log_images = config.log_images is not False
- log_samples = config.log_samples is not False
-
- # resolve dataset
- _, samples, states = await resolve_dataset(
- dataset=task.dataset,
- model_name=model_name,
- limit=config.limit,
- sample_id=config.sample_id,
- epochs=epochs,
- log_images=log_images,
- message_limit=config.message_limit,
- token_limit=config.token_limit,
- )
-
- # resolve the plan (unroll chains)
- solver = solver or task.solver
- if isinstance(solver, Plan):
- plan = solver
- elif isinstance(solver, Chain):
- plan = Plan(list(solver), cleanup=task.cleanup, internal=True)
- else:
- plan = Plan(unroll(solver), cleanup=task.cleanup, internal=True)
-
- # add setup solver(s) if specified
- if task.setup:
- plan.steps = unroll(task.setup) + plan.steps
-
- # resolve the scorer
- score = score and task.scorer is not None
- scorers: list[Scorer] | None = task.scorer if (score and task.scorer) else None
- scorer_profiles = (
- [
- registry_log_name(scorer)
- for scorer in scorers
- if is_registry_object(scorer)
- ]
- if scorers is not None
- else ["(none)"]
- )
-
- # compute an eval directory relative log location if we can
- if PurePath(logger.location).is_relative_to(PurePath(eval_wd)):
- log_location = PurePath(logger.location).relative_to(eval_wd).as_posix()
- else:
- log_location = logger.location
-
- # create task profile for display
- profile = TaskProfile(
- name=task.name,
- file=logger.eval.task_file,
- model=model_name,
- dataset=task.dataset.name or "(samples)",
- scorer=", ".join(scorer_profiles),
- samples=len(samples),
- steps=len(samples) * SAMPLE_TOTAL_PROGRESS_UNITS,
- eval_config=config,
- task_args=logger.eval.task_args,
- generate_config=generate_config,
- tags=tags,
- log_location=log_location,
- )
+ init_task_context(model, model_roles, options.task.approval, generate_config)
+
+ # track stats and error
+ results: EvalResults | None = None
+ reductions: list[EvalSampleReductions] | None = None
+ stats = EvalStats(started_at=iso_now())
+
+ # handle sample errors (raise as required)
+ sample_error_handler = SampleErrorHandler(config.fail_on_error, len(task.dataset))
+
+ # resolve some config
+ model_name = ModelName(model)
+ epochs = config.epochs if config.epochs else DEFAULT_EPOCHS
+ sandbox_cleanup = config.sandbox_cleanup is not False
+ log_images = config.log_images is not False
+ log_samples = config.log_samples is not False
+
+ # resolve dataset
+ _, samples, states = await resolve_dataset(
+ dataset=task.dataset,
+ model_name=model_name,
+ limit=config.limit,
+ sample_id=config.sample_id,
+ epochs=epochs,
+ log_images=log_images,
+ message_limit=config.message_limit,
+ token_limit=config.token_limit,
+ )
 
- with display().task(
- profile,
- ) as td:
- try:
- # start the log
- await log_start(logger, plan, generate_config)
-
- with td.progress() as p:
- # forward progress
- def progress(number: int) -> None:
- p.update(number)
-
- # provide solvers a function that they can use to generate output
- async def generate(
- state: TaskState,
- tool_calls: Literal["loop", "single", "none"] = "loop",
- cache: bool | CachePolicy = False,
- **kwargs: Unpack[GenerateConfigArgs],
- ) -> TaskState:
- return await task_generate(
- model=model,
- state=state,
- tool_calls=tool_calls,
- cache=cache,
- config=generate_config.merge(kwargs),
- )
+ # resolve the plan (unroll chains)
+ solver = solver or task.solver
+ if isinstance(solver, Plan):
+ plan = solver
+ elif isinstance(solver, Chain):
+ plan = Plan(list(solver), cleanup=task.cleanup, internal=True)
+ else:
+ plan = Plan(unroll(solver), cleanup=task.cleanup, internal=True)
+
+ # add setup solver(s) if specified
+ if task.setup:
+ plan.steps = unroll(task.setup) + plan.steps
+
+ # resolve the scorer
+ score = score and task.scorer is not None
+ scorers: list[Scorer] | None = task.scorer if (score and task.scorer) else None
+ scorer_profiles = (
+ [registry_log_name(scorer) for scorer in scorers if is_registry_object(scorer)]
+ if scorers is not None
+ else ["(none)"]
+ )
 
- # set generate for fork module
- set_task_generate(generate)
+ # compute an eval directory relative log location if we can
+ if PurePath(logger.location).is_relative_to(PurePath(eval_wd)):
+ log_location = PurePath(logger.location).relative_to(eval_wd).as_posix()
+ else:
+ log_location = logger.location
+
+ # create task profile for display
+ profile = TaskProfile(
+ name=task.name,
+ file=logger.eval.task_file,
+ model=model_name,
+ dataset=task.dataset.name or "(samples)",
+ scorer=", ".join(scorer_profiles),
+ samples=len(samples),
+ steps=len(samples) * SAMPLE_TOTAL_PROGRESS_UNITS,
+ eval_config=config,
+ task_args=logger.eval.task_args,
+ generate_config=generate_config,
+ tags=tags,
+ log_location=log_location,
+ )
 
- # semaphore to limit concurrency
- sample_semaphore = create_sample_semaphore(
- config, generate_config, model.api
+ with display().task(
+ profile,
+ ) as td:
+ try:
+ # start the log
+ await log_start(logger, plan, generate_config)
+
+ with td.progress() as p:
+ # forward progress
+ def progress(number: int) -> None:
+ p.update(number)
+
+ # provide solvers a function that they can use to generate output
+ async def generate(
+ state: TaskState,
+ tool_calls: Literal["loop", "single", "none"] = "loop",
+ cache: bool | CachePolicy = False,
+ **kwargs: Unpack[GenerateConfigArgs],
+ ) -> TaskState:
+ return await task_generate(
+ model=model,
+ state=state,
+ tool_calls=tool_calls,
+ cache=cache,
+ config=generate_config.merge(kwargs),
  )
 
- # track when samples complete and update progress as we go
- progress_results: list[dict[str, SampleScore]] = []
+ # set generate for fork module
+ set_task_generate(generate)
 
- def update_metrics(metrics: list[TaskDisplayMetric]) -> None:
- td.update_metrics(metrics)
- logger.update_metrics(metrics)
+ # semaphore to limit concurrency
+ sample_semaphore = create_sample_semaphore(
+ config, generate_config, model.api
+ )
 
- update_metrics_display = update_metrics_display_fn(
- update_metrics,
- display_metrics=profile.eval_config.score_display is not False,
- )
+ # track when samples complete and update progress as we go
+ progress_results: list[dict[str, SampleScore]] = []
 
- def sample_complete(sample_score: dict[str, SampleScore]) -> None:
- # Capture the result
- progress_results.append(sample_score)
+ def update_metrics(metrics: list[TaskDisplayMetric]) -> None:
+ td.update_metrics(metrics)
+ logger.update_metrics(metrics)
 
- # Increment the segment progress
- td.sample_complete(
- complete=len(progress_results), total=len(samples)
- )
+ update_metrics_display = update_metrics_display_fn(
+ update_metrics,
+ display_metrics=profile.eval_config.score_display is not False,
+ )
 
- # Update metrics
- update_metrics_display(
- len(progress_results),
- progress_results,
- scorers,
- task.epochs_reducer,
- task.metrics,
- )
+ def sample_complete(sample_score: dict[str, SampleScore]) -> None:
+ # Capture the result
+ progress_results.append(sample_score)
 
- # initial progress
- td.sample_complete(complete=0, total=len(samples))
+ # Increment the segment progress
+ td.sample_complete(
+ complete=len(progress_results), total=len(samples)
+ )
 
- # Update metrics to empty state
+ # Update metrics
  update_metrics_display(
  len(progress_results),
  progress_results,
@@ -309,127 +290,133 @@ async def task_run(options: TaskRunOptions) -> EvalLog:
  task.metrics,
  )
 
- sample_results = await tg_collect(
- [
- functools.partial(
- task_run_sample,
- task_name=task.name,
- sample=sample,
- state=state,
- sandbox=sandbox,
- max_sandboxes=config.max_sandboxes,
- sandbox_cleanup=sandbox_cleanup,
- plan=plan,
- scorers=scorers,
- generate=generate,
- progress=progress,
- logger=logger if log_samples else None,
- log_images=log_images,
- sample_source=sample_source,
- sample_error=sample_error_handler,
- sample_complete=sample_complete,
- fails_on_error=(
- config.fail_on_error is None
- or config.fail_on_error is True
- ),
- time_limit=config.time_limit,
- working_limit=config.working_limit,
- semaphore=sample_semaphore,
- )
- for (sample, state) in zip(samples, states)
- ]
- )
+ # initial progress
+ td.sample_complete(complete=0, total=len(samples))
 
- # compute and record metrics if we have scores
- completed_scores = [
- score_dict
- for score_dict in sample_results
- if isinstance(score_dict, dict)
- ]
-
- if len(completed_scores) > 0:
- results, reductions = eval_results(
- samples=profile.samples,
- scores=completed_scores,
- reducers=task.epochs_reducer,
- scorers=scorers,
- metrics=task.metrics,
- )
+ # Update metrics to empty state
+ update_metrics_display(
+ len(progress_results),
+ progress_results,
+ scorers,
+ task.epochs_reducer,
+ task.metrics,
+ )
 
- # collect eval data
- collect_eval_data(stats)
+ sample_results = await tg_collect(
+ [
+ functools.partial(
+ task_run_sample,
+ task_name=task.name,
+ sample=sample,
+ state=state,
+ sandbox=sandbox,
+ max_sandboxes=config.max_sandboxes,
+ sandbox_cleanup=sandbox_cleanup,
+ plan=plan,
+ scorers=scorers,
+ generate=generate,
+ progress=progress,
+ logger=logger if log_samples else None,
+ log_images=log_images,
+ sample_source=sample_source,
+ sample_error=sample_error_handler,
+ sample_complete=sample_complete,
+ fails_on_error=(
+ config.fail_on_error is None
+ or config.fail_on_error is True
+ ),
+ time_limit=config.time_limit,
+ working_limit=config.working_limit,
+ semaphore=sample_semaphore,
+ )
+ for (sample, state) in zip(samples, states)
+ ]
+ )
 
- # finish w/ success status
- eval_log = await logger.log_finish(
- "success", stats, results, reductions
+ # compute and record metrics if we have scores
+ completed_scores = [
+ score_dict
+ for score_dict in sample_results
+ if isinstance(score_dict, dict)
+ ]
+
+ if len(completed_scores) > 0:
+ results, reductions = eval_results(
+ samples=profile.samples,
+ scores=completed_scores,
+ reducers=task.epochs_reducer,
+ scorers=scorers,
+ metrics=task.metrics,
  )
 
- # display task summary
- td.complete(
- TaskSuccess(
- samples_completed=logger.samples_completed,
- stats=stats,
- results=results or EvalResults(),
- )
+ # collect eval data
+ collect_eval_data(stats)
+
+ # finish w/ success status
+ eval_log = await logger.log_finish("success", stats, results, reductions)
+
+ # display task summary
+ td.complete(
+ TaskSuccess(
+ samples_completed=logger.samples_completed,
+ stats=stats,
+ results=results or EvalResults(),
  )
+ )
 
- except anyio.get_cancelled_exc_class():
- with anyio.CancelScope(shield=True):
- # collect eval data
- collect_eval_data(stats)
+ except anyio.get_cancelled_exc_class():
+ with anyio.CancelScope(shield=True):
+ # collect eval data
+ collect_eval_data(stats)
 
- # finish w/ cancelled status
- eval_log = await logger.log_finish(
- "cancelled", stats, results, reductions
- )
+ # finish w/ cancelled status
+ eval_log = await logger.log_finish(
+ "cancelled", stats, results, reductions
+ )
 
- # display task cancelled
- td.complete(TaskCancelled(logger.samples_completed, stats))
+ # display task cancelled
+ td.complete(TaskCancelled(logger.samples_completed, stats))
 
- except BaseException as ex:
- if options.debug_errors:
- raise
- else:
- # get exception info
- type, value, traceback = sys.exc_info()
- type = type if type else BaseException
- value = value if value else ex
+ except BaseException as ex:
+ if options.debug_errors:
+ raise
+ else:
+ # get exception info
+ type, value, traceback = sys.exc_info()
+ type = type if type else BaseException
+ value = value if value else ex
 
- # build eval error
- error = eval_error(ex, type, value, traceback)
+ # build eval error
+ error = eval_error(ex, type, value, traceback)
 
- # collect eval data
- collect_eval_data(stats)
+ # collect eval data
+ collect_eval_data(stats)
 
- # finish with error status
- eval_log = await logger.log_finish(
- "error", stats, results, reductions, error
- )
+ # finish with error status
+ eval_log = await logger.log_finish(
+ "error", stats, results, reductions, error
+ )
 
- # display it
- td.complete(
- TaskError(logger.samples_completed, type, value, traceback)
- )
+ # display it
+ td.complete(TaskError(logger.samples_completed, type, value, traceback))
 
- # notify the view module that an eval just completed
- # (in case we have a view polling for new evals)
- view_notify_eval(logger.location)
+ # notify the view module that an eval just completed
+ # (in case we have a view polling for new evals)
+ view_notify_eval(logger.location)
 
- try:
- if (
- await send_telemetry("eval_log_location", eval_log.location)
- == "not_handled"
- ):
- # Converting the eval log to JSON is expensive. Only do so if
- # eval_log_location was not handled.
- await send_telemetry("eval_log", eval_log_json_str(eval_log))
- except Exception as ex:
- py_logger.warning(
- f"Error occurred sending telemetry: {exception_message(ex)}"
- )
+ try:
+ if (
+ await send_telemetry("eval_log_location", eval_log.location)
+ == "not_handled"
+ ):
+ # Converting the eval log to JSON is expensive. Only do so if
+ # eval_log_location was not handled.
+ await send_telemetry("eval_log", eval_log_json_str(eval_log))
+ except Exception as ex:
+ py_logger.warning(f"Error occurred sending telemetry: {exception_message(ex)}")
 
- # return eval log
- return eval_log
+ # return eval log
+ return eval_log
 
 
  def update_metrics_display_fn(
@@ -655,18 +642,6 @@ async def task_run_sample(
  )
  )
 
- # sample init event (remove file bodies as they have content or absolute paths)
- event_sample = sample.model_copy(
- update=dict(files={k: "" for k in sample.files.keys()})
- if sample.files
- else None
- )
- transcript()._event(
- SampleInitEvent(
- sample=event_sample, state=state_jsonable(state)
- )
- )
-
  # set progress for plan then run it
  state = await plan(state, generate)
 
@@ -914,7 +889,7 @@ async def resolve_dataset(
  dataset: Dataset,
  model_name: ModelName,
  limit: int | tuple[int, int] | None,
- sample_id: str | int | list[str | int] | None,
+ sample_id: str | int | list[str] | list[int] | list[str | int] | None,
  epochs: int,
  log_images: bool,
  message_limit: int | None,
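As an aside on the hunk above: the widened `sample_id` annotation (the same change appears in `slice_dataset()` at the end of this diff) lets callers pass homogeneous lists without widening their own types. A brief illustrative sketch, assuming the public `eval()` entry point forwards `sample_id` unchanged; the task path and ids here are hypothetical:

    from inspect_ai import eval

    # a homogeneous list[int] or list[str] now matches the annotation directly,
    # without needing to be typed as list[str | int]
    eval("example_task.py", sample_id=[1, 2, 3])
    eval("example_task.py", sample_id=["case-a", "case-b"])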
@@ -54,6 +54,7 @@ class Task:
  metrics: list[Metric] | dict[str, list[Metric]] | None = None,
  model: str | Model | None = None,
  config: GenerateConfig = GenerateConfig(),
+ model_roles: dict[str, str | Model] | None = None,
  sandbox: SandboxEnvironmentType | None = None,
  approval: str | list[ApprovalPolicy] | None = None,
  epochs: int | Epochs | None = None,
@@ -79,7 +80,8 @@ class Task:
  scorer: Scorer used to evaluate model output.
  metrics: Alternative metrics (overrides the metrics provided by the specified scorer).
  model: Default model for task (Optional, defaults to eval model).
- config: Model generation config.
+ config: Model generation config for default model (does not apply to model roles)
+ model_roles: Named roles for use in `get_model()`.
  sandbox: Sandbox environment type (or optionally a str or tuple with a shorthand spec)
  approval: Tool use approval policies.
  Either a path to an approval policy config file or a list of approval policies. Defaults to no approval policy.
@@ -136,6 +138,7 @@ class Task:
  self.metrics = metrics
  self.model = resolve_model(model)
  self.config = config
+ self.model_roles = resolve_model_roles(model_roles)
  self.sandbox = resolve_sandbox_environment(sandbox)
  self.approval = resolve_approval(approval)
  epochs = resolve_epochs(epochs)
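A usage illustration for the new `model_roles` parameter (not part of the diff itself): string values are resolved to `Model` instances at construction time by the `resolve_model_roles()` helper added further below, and the docstring above indicates the resolved roles are then retrieved via `get_model()`. The task, dataset, and role names here are hypothetical, and the keyword form of the role lookup is an assumption rather than something this diff shows:

    from inspect_ai import Task, task
    from inspect_ai.dataset import Sample
    from inspect_ai.scorer import match
    from inspect_ai.solver import generate

    @task
    def example_task():
        return Task(
            dataset=[Sample(input="What is 2 + 2?", target="4")],
            solver=generate(),
            scorer=match(),
            # "grader" is a hypothetical role name; the string value is resolved
            # to a Model via get_model(..., memoize=False) when the Task is created
            model_roles={"grader": "openai/gpt-4o"},
        )

    # assumed lookup pattern, per the docstring's reference to get_model();
    # the exact signature is not shown in this diff:
    # grader = get_model(role="grader")

Note also from the docstring change that `config` now applies only to the default model; models supplied through `model_roles` are not affected by it.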
@@ -185,6 +188,7 @@ def task_with(
  metrics: list[Metric] | dict[str, list[Metric]] | None | NotGiven = NOT_GIVEN,
  model: str | Model | NotGiven = NOT_GIVEN,
  config: GenerateConfig | NotGiven = NOT_GIVEN,
+ model_roles: dict[str, str | Model] | NotGiven = NOT_GIVEN,
  sandbox: SandboxEnvironmentType | None | NotGiven = NOT_GIVEN,
  approval: str | list[ApprovalPolicy] | None | NotGiven = NOT_GIVEN,
  epochs: int | Epochs | None | NotGiven = NOT_GIVEN,
@@ -214,7 +218,8 @@ def task_with(
  scorer: Scorer used to evaluate model output.
  metrics: Alternative metrics (overrides the metrics provided by the specified scorer).
  model: Default model for task (Optional, defaults to eval model).
- config: Model generation config.
+ config: Model generation config for default model (does not apply to model roles)
+ model_roles: Named roles for use in `get_model()`.
  sandbox: Sandbox environment type (or optionally a str or tuple with a shorthand spec)
  approval: Tool use approval policies.
  Either a path to an approval policy config file or a list of approval policies. Defaults to no approval policy.
@@ -257,6 +262,8 @@ def task_with(
  task.model = resolve_model(model)
  if not isinstance(config, NotGiven):
  task.config = config
+ if not isinstance(model_roles, NotGiven):
+ task.model_roles = resolve_model_roles(model_roles)
  if not isinstance(sandbox, NotGiven):
  task.sandbox = resolve_sandbox_environment(sandbox)
  if not isinstance(approval, NotGiven):
@@ -315,6 +322,7 @@ class PreviousTask:
  task: str | Task
  task_args: dict[str, Any]
  model: Model | None
+ model_roles: dict[str, Model] | None
  log: EvalLog
 
 
@@ -365,6 +373,21 @@ def resolve_model(model: str | Model | None) -> Model | None:
  return model
 
 
+ def resolve_model_roles(
+ model_roles: dict[str, str | Model] | None,
+ ) -> dict[str, Model] | None:
+ if model_roles is not None:
+ resolved_model_roles = {
+ k: get_model(v, memoize=False) if isinstance(v, str) else v
+ for k, v in model_roles.items()
+ }
+ for k, v in resolved_model_roles.items():
+ v._set_role(k)
+ return resolved_model_roles
+ else:
+ return None
+
+
  def resolve_scorer(scorer: Scorer | list[Scorer] | None) -> list[Scorer] | None:
  return (
  scorer if isinstance(scorer, list) else [scorer] if scorer is not None else None
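Because `resolve_model_roles()` above passes `Model` values through unchanged (only tagging them with `_set_role()`), role-specific generation settings can be attached by constructing the model explicitly, given that the task-level `config` applies only to the default model. A minimal sketch with hypothetical role and model names:

    from inspect_ai.model import GenerateConfig, get_model

    # an explicitly constructed Model carries its own generation config;
    # resolve_model_roles() leaves it as-is and records the role name on it
    grader = get_model(
        "anthropic/claude-3-7-sonnet-latest",
        config=GenerateConfig(temperature=0.0),
    )
    model_roles = {"grader": grader}  # pass to Task(model_roles=...) or task_with(...)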
@@ -25,13 +25,6 @@ def task_run_dir(task: Task) -> str:
  return getattr(task, TASK_RUN_DIR_ATTR, os.getcwd())
 
 
- def task_chdir(task: Task) -> str | None:
- if task.attribs.get("chdir", False) is True:
- return task_run_dir(task)
- else:
- return None
-
-
  def task_file(task: Task, relative: bool = False) -> str | None:
  file = cast(str | None, getattr(task, TASK_FILE_ATTR, None))
  if file:
@@ -46,7 +39,7 @@ def task_file(task: Task, relative: bool = False) -> str | None:
  def slice_dataset(
  dataset: Dataset,
  limit: int | tuple[int, int] | None,
- sample_id: str | int | list[str | int] | None,
+ sample_id: str | int | list[str] | list[int] | list[str | int] | None,
  ) -> Dataset:
  def normalise(id: str | int | None) -> str:
  if isinstance(id, str) and id.isdigit():