inspect-ai 0.3.56__py3-none-any.whl → 0.3.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. inspect_ai/__init__.py +2 -1
  2. inspect_ai/_cli/common.py +4 -2
  3. inspect_ai/_cli/eval.py +2 -0
  4. inspect_ai/_cli/trace.py +21 -2
  5. inspect_ai/_display/core/active.py +0 -2
  6. inspect_ai/_display/core/panel.py +1 -1
  7. inspect_ai/_display/rich/display.py +4 -4
  8. inspect_ai/_display/textual/app.py +4 -1
  9. inspect_ai/_display/textual/widgets/samples.py +41 -5
  10. inspect_ai/_eval/eval.py +32 -20
  11. inspect_ai/_eval/evalset.py +7 -5
  12. inspect_ai/_eval/run.py +16 -11
  13. inspect_ai/_eval/task/__init__.py +2 -2
  14. inspect_ai/_eval/task/images.py +40 -25
  15. inspect_ai/_eval/task/run.py +141 -119
  16. inspect_ai/_eval/task/task.py +140 -25
  17. inspect_ai/_util/constants.py +1 -0
  18. inspect_ai/_util/content.py +23 -1
  19. inspect_ai/_util/datetime.py +1 -1
  20. inspect_ai/_util/deprecation.py +1 -1
  21. inspect_ai/_util/images.py +20 -17
  22. inspect_ai/_util/json.py +11 -1
  23. inspect_ai/_util/kvstore.py +73 -0
  24. inspect_ai/_util/logger.py +2 -1
  25. inspect_ai/_util/notgiven.py +18 -0
  26. inspect_ai/_util/thread.py +5 -0
  27. inspect_ai/_util/trace.py +39 -3
  28. inspect_ai/_util/transcript.py +36 -7
  29. inspect_ai/_view/www/.prettierrc.js +12 -0
  30. inspect_ai/_view/www/dist/assets/index.js +322 -226
  31. inspect_ai/_view/www/log-schema.json +221 -138
  32. inspect_ai/_view/www/src/App.mjs +18 -9
  33. inspect_ai/_view/www/src/Types.mjs +0 -1
  34. inspect_ai/_view/www/src/api/Types.mjs +15 -4
  35. inspect_ai/_view/www/src/api/api-http.mjs +2 -0
  36. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +2 -2
  37. inspect_ai/_view/www/src/components/FindBand.mjs +5 -4
  38. inspect_ai/_view/www/src/components/LargeModal.mjs +1 -1
  39. inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
  40. inspect_ai/_view/www/src/components/MessageContent.mjs +44 -2
  41. inspect_ai/_view/www/src/components/TabSet.mjs +1 -1
  42. inspect_ai/_view/www/src/components/Tools.mjs +18 -3
  43. inspect_ai/_view/www/src/components/VirtualList.mjs +15 -17
  44. inspect_ai/_view/www/src/log/remoteLogFile.mjs +2 -1
  45. inspect_ai/_view/www/src/navbar/Navbar.mjs +44 -32
  46. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -2
  47. inspect_ai/_view/www/src/samples/SampleList.mjs +35 -4
  48. inspect_ai/_view/www/src/samples/SampleScoreView.mjs +13 -2
  49. inspect_ai/_view/www/src/samples/SampleScores.mjs +11 -2
  50. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +242 -178
  51. inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -2
  52. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +5 -5
  53. inspect_ai/_view/www/src/samples/tools/SelectScorer.mjs +7 -0
  54. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +3 -3
  55. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +1 -1
  56. inspect_ai/_view/www/src/types/log.d.ts +53 -35
  57. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
  58. inspect_ai/approval/_human/util.py +2 -2
  59. inspect_ai/dataset/_sources/csv.py +2 -1
  60. inspect_ai/dataset/_sources/json.py +2 -1
  61. inspect_ai/dataset/_sources/util.py +15 -7
  62. inspect_ai/log/_condense.py +11 -1
  63. inspect_ai/log/_log.py +27 -5
  64. inspect_ai/log/_recorders/eval.py +21 -8
  65. inspect_ai/log/_samples.py +10 -5
  66. inspect_ai/log/_transcript.py +28 -1
  67. inspect_ai/model/__init__.py +10 -2
  68. inspect_ai/model/_call_tools.py +82 -17
  69. inspect_ai/model/_chat_message.py +2 -4
  70. inspect_ai/model/{_trace.py → _conversation.py} +9 -8
  71. inspect_ai/model/_model.py +2 -2
  72. inspect_ai/model/_providers/anthropic.py +9 -7
  73. inspect_ai/model/_providers/azureai.py +6 -4
  74. inspect_ai/model/_providers/bedrock.py +6 -4
  75. inspect_ai/model/_providers/google.py +103 -14
  76. inspect_ai/model/_providers/groq.py +7 -5
  77. inspect_ai/model/_providers/hf.py +11 -6
  78. inspect_ai/model/_providers/mistral.py +6 -9
  79. inspect_ai/model/_providers/openai.py +34 -8
  80. inspect_ai/model/_providers/openai_o1.py +10 -12
  81. inspect_ai/model/_providers/vertex.py +17 -4
  82. inspect_ai/scorer/__init__.py +13 -2
  83. inspect_ai/scorer/_metrics/__init__.py +2 -2
  84. inspect_ai/scorer/_metrics/std.py +3 -3
  85. inspect_ai/tool/__init__.py +9 -1
  86. inspect_ai/tool/_tool.py +9 -2
  87. inspect_ai/tool/_tool_info.py +2 -1
  88. inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py +9 -9
  89. inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -3
  90. inspect_ai/util/__init__.py +4 -3
  91. inspect_ai/util/{_trace.py → _conversation.py} +3 -17
  92. inspect_ai/util/_display.py +14 -4
  93. inspect_ai/util/_sandbox/context.py +12 -13
  94. inspect_ai/util/_sandbox/docker/compose.py +24 -13
  95. inspect_ai/util/_sandbox/docker/docker.py +20 -13
  96. inspect_ai/util/_sandbox/docker/util.py +2 -1
  97. inspect_ai/util/_sandbox/environment.py +13 -1
  98. inspect_ai/util/_sandbox/local.py +1 -0
  99. inspect_ai/util/_sandbox/self_check.py +18 -18
  100. inspect_ai/util/_store.py +2 -2
  101. inspect_ai/util/_subprocess.py +3 -3
  102. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/METADATA +3 -3
  103. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/RECORD +107 -103
  104. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/WHEEL +1 -1
  105. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/LICENSE +0 -0
  106. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/entry_points.txt +0 -0
  107. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/top_level.txt +0 -0
inspect_ai/_eval/task/images.py +40 -25
@@ -1,66 +1,69 @@
1
1
  import asyncio
2
2
 
3
3
  from inspect_ai._util.constants import BASE_64_DATA_REMOVED
4
- from inspect_ai._util.images import image_as_data_uri
4
+ from inspect_ai._util.content import Content, ContentAudio, ContentImage, ContentVideo
5
+ from inspect_ai._util.images import file_as_data_uri
5
6
  from inspect_ai._util.url import is_data_uri
6
7
  from inspect_ai.dataset import Sample
7
- from inspect_ai.model import ChatMessage, ChatMessageUser, Content, ContentImage
8
+ from inspect_ai.model import ChatMessage, ChatMessageUser
8
9
  from inspect_ai.solver import TaskState
9
10
 
10
11
 
11
- async def states_with_base64_images(states: list[TaskState]) -> list[TaskState]:
12
- return await asyncio.gather(*[state_with_base64_images(state) for state in states])
12
+ async def states_with_base64_content(states: list[TaskState]) -> list[TaskState]:
13
+ return await asyncio.gather(*[state_with_base64_content(state) for state in states])
13
14
 
14
15
 
15
- async def state_with_base64_images(state: TaskState) -> TaskState:
16
- state.messages = await messages_with_base64_images(state.messages)
16
+ async def state_with_base64_content(state: TaskState) -> TaskState:
17
+ state.messages = await messages_with_base64_content(state.messages)
17
18
  return state
18
19
 
19
20
 
20
- def state_without_base64_images(state: TaskState) -> TaskState:
21
- state.messages = messages_without_base64_images(state.messages)
21
+ def state_without_base64_content(state: TaskState) -> TaskState:
22
+ state.messages = messages_without_base64_content(state.messages)
22
23
  return state
23
24
 
24
25
 
25
- async def samples_with_base64_images(samples: list[Sample]) -> list[Sample]:
26
+ async def samples_with_base64_content(samples: list[Sample]) -> list[Sample]:
26
27
  return await asyncio.gather(
27
- *[sample_with_base64_images(sample) for sample in samples]
28
+ *[sample_with_base64_content(sample) for sample in samples]
28
29
  )
29
30
 
30
31
 
31
- async def sample_with_base64_images(sample: Sample) -> Sample:
32
+ async def sample_with_base64_content(sample: Sample) -> Sample:
32
33
  if isinstance(sample.input, list):
33
34
  return sample.model_copy(
34
- update={"input": await messages_with_base64_images(sample.input)}
35
+ update={"input": await messages_with_base64_content(sample.input)}
35
36
  )
36
37
  else:
37
38
  return sample
38
39
 
39
40
 
40
- def sample_without_base64_images(sample: Sample) -> Sample:
41
+ def sample_without_base64_content(sample: Sample) -> Sample:
41
42
  if isinstance(sample.input, list):
42
43
  return sample.model_copy(
43
- update={"input": messages_without_base64_images(sample.input)}
44
+ update={"input": messages_without_base64_content(sample.input)}
44
45
  )
45
46
  else:
46
47
  return sample
47
48
 
48
49
 
49
- async def messages_with_base64_images(messages: list[ChatMessage]) -> list[ChatMessage]:
50
+ async def messages_with_base64_content(
51
+ messages: list[ChatMessage],
52
+ ) -> list[ChatMessage]:
50
53
  return await asyncio.gather(
51
- *[message_with_base64_image(message) for message in messages]
54
+ *[message_with_base64_content(message) for message in messages]
52
55
  )
53
56
 
54
57
 
55
- def messages_without_base64_images(messages: list[ChatMessage]) -> list[ChatMessage]:
56
- return [message_without_base64_image(message) for message in messages]
58
+ def messages_without_base64_content(messages: list[ChatMessage]) -> list[ChatMessage]:
59
+ return [message_without_base64_content(message) for message in messages]
57
60
 
58
61
 
59
- async def message_with_base64_image(message: ChatMessage) -> ChatMessage:
62
+ async def message_with_base64_content(message: ChatMessage) -> ChatMessage:
60
63
  if isinstance(message, ChatMessageUser) and not isinstance(message.content, str):
61
64
  return ChatMessageUser(
62
65
  content=[
63
- await chat_content_with_base64_image(content)
66
+ await chat_content_with_base64_content(content)
64
67
  for content in message.content
65
68
  ],
66
69
  source=message.source,
@@ -69,11 +72,11 @@ async def message_with_base64_image(message: ChatMessage) -> ChatMessage:
69
72
  return message
70
73
 
71
74
 
72
- def message_without_base64_image(message: ChatMessage) -> ChatMessage:
75
+ def message_without_base64_content(message: ChatMessage) -> ChatMessage:
73
76
  if isinstance(message, ChatMessageUser) and not isinstance(message.content, str):
74
77
  return ChatMessageUser(
75
78
  content=[
76
- chat_content_without_base64_image(content)
79
+ chat_content_without_base64_content(content)
77
80
  for content in message.content
78
81
  ],
79
82
  source=message.source,
@@ -82,18 +85,30 @@ def message_without_base64_image(message: ChatMessage) -> ChatMessage:
82
85
  return message
83
86
 
84
87
 
85
- async def chat_content_with_base64_image(content: Content) -> Content:
88
+ async def chat_content_with_base64_content(content: Content) -> Content:
86
89
  if isinstance(content, ContentImage):
87
90
  return ContentImage(
88
- image=await image_as_data_uri(content.image),
91
+ image=await file_as_data_uri(content.image),
89
92
  detail=content.detail,
90
93
  )
94
+ elif isinstance(content, ContentAudio):
95
+ return ContentAudio(
96
+ audio=await file_as_data_uri(content.audio), format=content.format
97
+ )
98
+ elif isinstance(content, ContentVideo):
99
+ return ContentVideo(
100
+ video=await file_as_data_uri(content.video), format=content.format
101
+ )
91
102
  else:
92
103
  return content
93
104
 
94
105
 
95
- def chat_content_without_base64_image(content: Content) -> Content:
106
+ def chat_content_without_base64_content(content: Content) -> Content:
96
107
  if isinstance(content, ContentImage) and is_data_uri(content.image):
97
108
  return ContentImage(image=BASE_64_DATA_REMOVED, detail=content.detail)
109
+ elif isinstance(content, ContentAudio) and is_data_uri(content.audio):
110
+ return ContentAudio(audio=BASE_64_DATA_REMOVED, format="mp3")
111
+ elif isinstance(content, ContentVideo) and is_data_uri(content.video):
112
+ return ContentVideo(video=BASE_64_DATA_REMOVED, format="mp4")
98
113
  else:
99
114
  return content
inspect_ai/_eval/task/run.py +141 -119
@@ -4,6 +4,7 @@ import sys
4
4
  import time
5
5
  from copy import deepcopy
6
6
  from dataclasses import dataclass, field
7
+ from datetime import datetime
7
8
  from logging import getLogger
8
9
  from pathlib import PurePath
9
10
  from typing import Callable, Literal
@@ -71,6 +72,7 @@ from inspect_ai.solver._chain import Chain, unroll
71
72
  from inspect_ai.solver._fork import set_task_generate
72
73
  from inspect_ai.solver._solver import Solver
73
74
  from inspect_ai.solver._task_state import sample_state, set_sample_state, state_jsonable
75
+ from inspect_ai.util._sandbox.context import sandbox_connections
74
76
  from inspect_ai.util._sandbox.environment import SandboxEnvironmentSpec
75
77
  from inspect_ai.util._subtask import init_subtask
76
78
 
@@ -79,10 +81,10 @@ from ..task import Task
79
81
  from .error import SampleErrorHandler
80
82
  from .generate import task_generate
81
83
  from .images import (
82
- sample_without_base64_images,
83
- samples_with_base64_images,
84
- state_without_base64_images,
85
- states_with_base64_images,
84
+ sample_without_base64_content,
85
+ samples_with_base64_content,
86
+ state_without_base64_content,
87
+ states_with_base64_content,
86
88
  )
87
89
  from .log import TaskLogger, collect_eval_data, log_start
88
90
  from .results import eval_results
@@ -533,11 +535,6 @@ async def task_run_sample(
533
535
  else contextlib.nullcontext()
534
536
  )
535
537
 
536
- # use timeout if provided
537
- timeout_cm = (
538
- timeout(time_limit) if time_limit is not None else contextlib.nullcontext()
539
- )
540
-
541
538
  # helper to handle exceptions (will throw if we've exceeded the limit)
542
539
  def handle_error(ex: BaseException) -> EvalError:
543
540
  err = sample_error(ex)
@@ -547,7 +544,6 @@ async def task_run_sample(
547
544
  # solver loop
548
545
  async with (
549
546
  semaphore_cm,
550
- sandboxenv_cm,
551
547
  active_sample(
552
548
  task=task_name,
553
549
  model=str(state.model),
@@ -561,125 +557,151 @@ async def task_run_sample(
561
557
  ) as active,
562
558
  ):
563
559
  error: EvalError | None = None
560
+ results: dict[str, SampleScore] = {}
564
561
  try:
565
- async with timeout_cm:
566
- # sample init event (remove file bodies as they have content or absolute paths)
567
- event_sample = sample.model_copy(
568
- update=dict(files={k: "" for k in sample.files.keys()})
569
- if sample.files
570
- else None
571
- )
572
- transcript()._event(
573
- SampleInitEvent(sample=event_sample, state=state_jsonable(state))
574
- )
562
+ async with sandboxenv_cm:
563
+ try:
564
+ # update active sample wth sandboxes now that we are initialised
565
+ active.sandboxes = await sandbox_connections()
566
+
567
+ # initialise timeout context manager
568
+ timeout_cm = (
569
+ timeout(time_limit)
570
+ if time_limit is not None
571
+ else contextlib.nullcontext()
572
+ )
575
573
 
576
- # set progress for plan then run it
577
- state = await plan(state, generate)
574
+ # run sample w/ optional timeout
575
+ async with timeout_cm:
576
+ # mark started
577
+ active.started = datetime.now().timestamp()
578
578
 
579
- except TimeoutError:
580
- if time_limit is not None:
581
- transcript()._event(
582
- SampleLimitEvent(
583
- type="time",
584
- message=f"Sample completed: exceeded time limit ({time_limit:,} seconds)",
585
- limit=time_limit,
586
- )
587
- )
588
- else:
589
- py_logger.warning(
590
- "Unexpected timeout error reached top of sample stack. Are you handling TimeoutError when applying timeouts?"
591
- )
579
+ # sample init event (remove file bodies as they have content or absolute paths)
580
+ event_sample = sample.model_copy(
581
+ update=dict(files={k: "" for k in sample.files.keys()})
582
+ if sample.files
583
+ else None
584
+ )
585
+ transcript()._event(
586
+ SampleInitEvent(
587
+ sample=event_sample, state=state_jsonable(state)
588
+ )
589
+ )
592
590
 
593
- # capture most recent state for scoring
594
- state = sample_state() or state
591
+ # set progress for plan then run it
592
+ state = await plan(state, generate)
595
593
 
596
- except asyncio.CancelledError as ex:
597
- if active.interrupt_action:
598
- # record eve t
599
- transcript()._event(
600
- SampleLimitEvent(
601
- type="operator",
602
- message="Sample completed: interrupted by operator",
603
- )
604
- )
594
+ except TimeoutError:
595
+ if time_limit is not None:
596
+ transcript()._event(
597
+ SampleLimitEvent(
598
+ type="time",
599
+ message=f"Sample completed: exceeded time limit ({time_limit:,} seconds)",
600
+ limit=time_limit,
601
+ )
602
+ )
603
+ else:
604
+ py_logger.warning(
605
+ "Unexpected timeout error reached top of sample stack. Are you handling TimeoutError when applying timeouts?"
606
+ )
605
607
 
606
- # handle the action
607
- match active.interrupt_action:
608
- case "score":
609
- # continue to scoring (capture the most recent state)
610
- state = sample_state() or state
611
- case "error":
612
- # default error handling
613
- error = handle_error(ex)
608
+ # capture most recent state for scoring
609
+ state = sample_state() or state
614
610
 
615
- else:
616
- raise
611
+ except asyncio.CancelledError as ex:
612
+ if active.interrupt_action:
613
+ # record eve t
614
+ transcript()._event(
615
+ SampleLimitEvent(
616
+ type="operator",
617
+ message="Sample completed: interrupted by operator",
618
+ )
619
+ )
617
620
 
618
- except BaseException as ex:
619
- error = handle_error(ex)
621
+ # handle the action
622
+ match active.interrupt_action:
623
+ case "score":
624
+ # continue to scoring (capture the most recent state)
625
+ state = sample_state() or state
626
+ case "error":
627
+ # default error handling
628
+ error = handle_error(ex)
629
+
630
+ else:
631
+ raise
632
+
633
+ except BaseException as ex:
634
+ error = handle_error(ex)
635
+
636
+ # set timeout for scoring. if the original timeout was never hit
637
+ # then just create a new timeout_cm targeting the original
638
+ # timeout time. if the original timeout was hit we still want
639
+ # to provide an opportunity for scoring, but we don't necessarily
640
+ # want to wait the full timeout again (especially in the case where
641
+ # the cause of the timeout is a hung container and scoring requires
642
+ # interacting with the container). as a middle ground we use half
643
+ # of the original timeout value for scoring.
644
+ if isinstance(timeout_cm, Timeout):
645
+ if not timeout_cm.expired():
646
+ timeout_cm = timeout_at(timeout_cm.when())
647
+ else:
648
+ assert time_limit
649
+ timeout_cm = timeout(time_limit / 2)
650
+
651
+ # scoring
652
+ try:
653
+ # timeout during scoring will result in an ordinary sample error
654
+ async with timeout_cm:
655
+ if scorers and error is None:
656
+ for scorer in scorers:
657
+ scorer_name = unique_scorer_name(
658
+ scorer, list(results.keys())
659
+ )
660
+ with transcript().step(name=scorer_name, type="scorer"):
661
+ score_result = (
662
+ await scorer(state, Target(sample.target))
663
+ if scorer
664
+ else None
665
+ )
666
+ if score_result is not None:
667
+ sample_score = SampleScore(
668
+ score=score_result,
669
+ sample_id=sample.id,
670
+ )
671
+ transcript()._event(
672
+ ScoreEvent(
673
+ score=score_result, target=sample.target
674
+ )
675
+ )
676
+ results[scorer_name] = sample_score
677
+
678
+ except asyncio.CancelledError:
679
+ if active.interrupt_action:
680
+ transcript()._event(
681
+ SampleLimitEvent(
682
+ type="operator",
683
+ message="Unable to score sample due to operator interruption",
684
+ )
685
+ )
620
686
 
621
- # set timeout for scoring. if the original timeout was never hit
622
- # then just create a new timeout_cm targeting the original
623
- # timeout time. if the original timeout was hit we still want
624
- # to provide an opportunity for scoring, but we don't necessarily
625
- # want to wait the full timeout again (especially in the case where
626
- # the cause of the timeout is a hung container and scoring requires
627
- # interacting with the container). as a middle ground we use half
628
- # of the original timeout value for scoring.
629
- if isinstance(timeout_cm, Timeout):
630
- if not timeout_cm.expired():
631
- timeout_cm = timeout_at(timeout_cm.when())
632
- else:
633
- assert time_limit
634
- timeout_cm = timeout(time_limit / 2)
687
+ raise
635
688
 
636
- # scoring
637
- try:
638
- # timeout during scoring will result in an ordinary sample error
639
- async with timeout_cm:
640
- results: dict[str, SampleScore] = {}
641
- if scorers and error is None:
642
- for scorer in scorers:
643
- scorer_name = unique_scorer_name(scorer, list(results.keys()))
644
- with transcript().step(name=scorer_name, type="scorer"):
645
- score_result = (
646
- await scorer(state, Target(sample.target))
647
- if scorer
648
- else None
689
+ except BaseException as ex:
690
+ # note timeout
691
+ if isinstance(ex, TimeoutError):
692
+ transcript()._event(
693
+ SampleLimitEvent(
694
+ type="time",
695
+ message=f"Unable to score sample due to exceeded time limit ({time_limit:,} seconds)",
696
+ limit=time_limit,
649
697
  )
650
- if score_result is not None:
651
- sample_score = SampleScore(
652
- score=score_result,
653
- sample_id=sample.id,
654
- )
655
- transcript()._event(
656
- ScoreEvent(score=score_result, target=sample.target)
657
- )
658
- results[scorer_name] = sample_score
659
-
660
- except asyncio.CancelledError:
661
- if active.interrupt_action:
662
- transcript()._event(
663
- SampleLimitEvent(
664
- type="operator",
665
- message="Unable to score sample due to operator interruption",
666
- )
667
- )
698
+ )
668
699
 
669
- raise
700
+ # handle error (this will throw if we've exceeded the limit)
701
+ error = handle_error(ex)
670
702
 
703
+ # handle sandboxenv init errors
671
704
  except BaseException as ex:
672
- # note timeout
673
- if isinstance(ex, TimeoutError):
674
- transcript()._event(
675
- SampleLimitEvent(
676
- type="time",
677
- message=f"Unable to score sample due to exceeded time limit ({time_limit:,} seconds)",
678
- limit=time_limit,
679
- )
680
- )
681
-
682
- # handle error (this will throw if we've exceeded the limit)
683
705
  error = handle_error(ex)
684
706
 
685
707
  # complete the sample
@@ -689,12 +711,12 @@ async def task_run_sample(
689
711
  if logger is not None:
690
712
  # if we are logging images then be sure to base64 images injected by solvers
691
713
  if log_images:
692
- state = (await states_with_base64_images([state]))[0]
714
+ state = (await states_with_base64_content([state]))[0]
693
715
 
694
716
  # otherwise ensure there are no base64 images in sample or messages
695
717
  else:
696
- sample = sample_without_base64_images(sample)
697
- state = state_without_base64_images(state)
718
+ sample = sample_without_base64_content(sample)
719
+ state = state_without_base64_content(state)
698
720
 
699
721
  # log the sample
700
722
  await log_sample(
@@ -784,7 +806,7 @@ async def resolve_dataset(
784
806
 
785
807
  # if we are logging images then resolve sample images here
786
808
  if log_images:
787
- samples = await samples_with_base64_images(samples)
809
+ samples = await samples_with_base64_content(samples)
788
810
 
789
811
  # prime the eval tasks (deep copy so they share no state w/ sample)
790
812
  sample_epochs: list[int] = []
inspect_ai/_eval/task/task.py +140 -25
@@ -1,3 +1,4 @@
1
+ from copy import deepcopy
1
2
  from dataclasses import dataclass
2
3
  from logging import getLogger
3
4
  from typing import Any, Callable, Sequence, cast
@@ -6,6 +7,7 @@ from pydantic import BaseModel
6
7
  from typing_extensions import TypedDict, Unpack
7
8
 
8
9
  from inspect_ai._util.logger import warn_once
10
+ from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven
9
11
  from inspect_ai._util.registry import is_registry_object, registry_info
10
12
  from inspect_ai.approval._policy import ApprovalPolicy, approval_policies_from_config
11
13
  from inspect_ai.dataset import Dataset, MemoryDataset, Sample
@@ -115,35 +117,15 @@ class Task:
115
117
  f"DEPRECATED: the '{arg}' parameter is deprecated (please use the '{newarg}' parameter instead)",
116
118
  )
117
119
 
118
- # resolve epochs / epochs_reducer
119
- if isinstance(epochs, int):
120
- epochs = Epochs(epochs)
121
- if epochs is not None and epochs.epochs < 1:
122
- raise ValueError("epochs must be a positive integer.")
123
-
124
- # resolve dataset (provide empty sample to bootstrap tasks w/o samples,
125
- # which could occur for testing or for an interactive mode eval)
126
- dataset = dataset or [Sample(input="prompt")]
127
- self.dataset: Dataset = (
128
- dataset if isinstance(dataset, Dataset) else MemoryDataset(list(dataset))
129
- )
120
+ self.dataset = resolve_dataset(dataset)
130
121
  self.setup = setup
131
- self.solver = chain(solver) if isinstance(solver, list) else solver
132
- self.scorer = (
133
- scorer
134
- if isinstance(scorer, list)
135
- else [scorer]
136
- if scorer is not None
137
- else None
138
- )
122
+ self.solver = resolve_solver(solver)
123
+ self.scorer = resolve_scorer(scorer)
139
124
  self.metrics = metrics
140
125
  self.config = config
141
126
  self.sandbox = resolve_sandbox_environment(sandbox)
142
- self.approval = (
143
- approval_policies_from_config(approval)
144
- if isinstance(approval, str)
145
- else approval
146
- )
127
+ self.approval = resolve_approval(approval)
128
+ epochs = resolve_epochs(epochs)
147
129
  self.epochs = epochs.epochs if epochs else None
148
130
  self.epochs_reducer = epochs.reducer if epochs else None
149
131
  self.fail_on_error = fail_on_error
@@ -171,6 +153,106 @@ class Task:
171
153
  return dict()
172
154
 
173
155
 
156
+ def task_with(
157
+ task: Task,
158
+ *,
159
+ dataset: Dataset | Sequence[Sample] | None | NotGiven = NOT_GIVEN,
160
+ setup: Solver | list[Solver] | None | NotGiven = NOT_GIVEN,
161
+ solver: Solver | list[Solver] | NotGiven = NOT_GIVEN,
162
+ scorer: Scorer | list[Scorer] | None | NotGiven = NOT_GIVEN,
163
+ metrics: list[Metric] | dict[str, list[Metric]] | None | NotGiven = NOT_GIVEN,
164
+ config: GenerateConfig | NotGiven = NOT_GIVEN,
165
+ sandbox: SandboxEnvironmentType | None | NotGiven = NOT_GIVEN,
166
+ approval: str | list[ApprovalPolicy] | None | NotGiven = NOT_GIVEN,
167
+ epochs: int | Epochs | None | NotGiven = NOT_GIVEN,
168
+ fail_on_error: bool | float | None | NotGiven = NOT_GIVEN,
169
+ message_limit: int | None | NotGiven = NOT_GIVEN,
170
+ token_limit: int | None | NotGiven = NOT_GIVEN,
171
+ time_limit: int | None | NotGiven = NOT_GIVEN,
172
+ name: str | None | NotGiven = NOT_GIVEN,
173
+ version: int | NotGiven = NOT_GIVEN,
174
+ metadata: dict[str, Any] | None | NotGiven = NOT_GIVEN,
175
+ ) -> Task:
176
+ """Task adapted with alternate values for one or more options.
177
+
178
+ Args:
179
+ task (Task): Task to adapt (it is deep copied prior to mutating options)
180
+ dataset (Dataset | Sequence[Sample]): Dataset to evaluate
181
+ setup: (Solver | list[Solver] | None): Setup step (always run
182
+ even when the main `solver` is replaced).
183
+ solver: (Solver | list[Solver]): Solver or list of solvers.
184
+ Defaults to generate(), a normal call to the model.
185
+ scorer: (Scorer | list[Scorer] | None): Scorer used to evaluate model output.
186
+ metrics (list[Metric] | dict[str, list[Metric]] | None):
187
+ Alternative metrics (overrides the metrics provided by the specified scorer).
188
+ config (GenerateConfig): Model generation config.
189
+ sandbox (SandboxEnvironmentType | None): Sandbox environment type
190
+ (or optionally a str or tuple with a shorthand spec)
191
+ approval: (str | list[ApprovalPolicy] | None): Tool use approval policies.
192
+ Either a path to an approval policy config file or a list of approval policies.
193
+ Defaults to no approval policy.
194
+ epochs (int | Epochs | None): Epochs to repeat samples for and optional score
195
+ reducer function(s) used to combine sample scores (defaults to "mean")
196
+ fail_on_error (bool | float | None): `True` to fail on first sample error
197
+ (default); `False` to never fail on sample errors; Value between 0 and 1
198
+ to fail if a proportion of total samples fails. Value greater than 1 to fail
199
+ eval if a count of samples fails.
200
+ message_limit (int | None): Limit on total messages used for each sample.
201
+ token_limit (int | None): Limit on total tokens used for each sample.
202
+ time_limit (int | None): Limit on time (in seconds) for execution of each sample.
203
+ name: (str | None): Task name. If not specified is automatically
204
+ determined based on the name of the task directory (or "task")
205
+ if its anonymous task (e.g. created in a notebook and passed to
206
+ eval() directly)
207
+ version: (int): Version of task (to distinguish evolutions
208
+ of the task spec or breaking changes to it)
209
+ metadata: (dict[str, Any] | None): Additional metadata to associate with the task.
210
+
211
+ Returns:
212
+ Task: Task adapted with alternate options.
213
+ """
214
+ # deep copy the task
215
+ task = deepcopy(task)
216
+
217
+ if not isinstance(dataset, NotGiven):
218
+ task.dataset = resolve_dataset(dataset)
219
+ if not isinstance(setup, NotGiven):
220
+ task.setup = setup
221
+ if not isinstance(solver, NotGiven):
222
+ task.solver = resolve_solver(solver)
223
+ if not isinstance(scorer, NotGiven):
224
+ task.scorer = resolve_scorer(scorer)
225
+ if not isinstance(metrics, NotGiven):
226
+ task.metrics = metrics
227
+ if not isinstance(config, NotGiven):
228
+ task.config = config
229
+ if not isinstance(sandbox, NotGiven):
230
+ task.sandbox = resolve_sandbox_environment(sandbox)
231
+ if not isinstance(approval, NotGiven):
232
+ task.approval = resolve_approval(approval)
233
+ if not isinstance(epochs, NotGiven):
234
+ epochs = resolve_epochs(epochs)
235
+ task.epochs = epochs.epochs if epochs else None
236
+ task.epochs_reducer = epochs.reducer if epochs else None
237
+ if not isinstance(fail_on_error, NotGiven):
238
+ task.fail_on_error = fail_on_error
239
+ if not isinstance(message_limit, NotGiven):
240
+ task.message_limit = message_limit
241
+ if not isinstance(token_limit, NotGiven):
242
+ task.token_limit = token_limit
243
+ if not isinstance(time_limit, NotGiven):
244
+ task.time_limit = time_limit
245
+ if not isinstance(version, NotGiven):
246
+ task.version = version
247
+ if not isinstance(name, NotGiven):
248
+ task._name = name
249
+ if not isinstance(metadata, NotGiven):
250
+ task.metadata = metadata
251
+
252
+ # return modified task
253
+ return task
254
+
255
+
174
256
  class TaskInfo(BaseModel):
175
257
  """Task information (file, name, and attributes)."""
176
258
 
@@ -225,3 +307,36 @@ classes, and task instances (a single task or list of tasks
225
307
  can be specified). None is a request to read a task out
226
308
  of the current working directory.
227
309
  """
310
+
311
+
312
+ def resolve_approval(
313
+ approval: str | list[ApprovalPolicy] | None,
314
+ ) -> list[ApprovalPolicy] | None:
315
+ return (
316
+ approval_policies_from_config(approval)
317
+ if isinstance(approval, str)
318
+ else approval
319
+ )
320
+
321
+
322
+ def resolve_epochs(epochs: int | Epochs | None) -> Epochs | None:
323
+ if isinstance(epochs, int):
324
+ epochs = Epochs(epochs)
325
+ if epochs is not None and epochs.epochs < 1:
326
+ raise ValueError("epochs must be a positive integer.")
327
+ return epochs
328
+
329
+
330
+ def resolve_dataset(dataset: Dataset | Sequence[Sample] | None) -> Dataset:
331
+ dataset = dataset or [Sample(input="prompt")]
332
+ return dataset if isinstance(dataset, Dataset) else MemoryDataset(list(dataset))
333
+
334
+
335
+ def resolve_solver(solver: Solver | list[Solver]) -> Solver:
336
+ return chain(solver) if isinstance(solver, list) else solver
337
+
338
+
339
+ def resolve_scorer(scorer: Scorer | list[Scorer] | None) -> list[Scorer] | None:
340
+ return (
341
+ scorer if isinstance(scorer, list) else [scorer] if scorer is not None else None
342
+ )
inspect_ai/_util/constants.py +1 -0
@@ -36,3 +36,4 @@ SCORED_SUFFIX = "-scored"
36
36
  SAMPLE_SUBTASK = "sample"
37
37
  CONSOLE_DISPLAY_WIDTH = 120
38
38
  BASE_64_DATA_REMOVED = "<base64-data-removed>"
39
+ SANDBOX_SETUP_TIMEOUT = 300