inspect-ai 0.3.58__py3-none-any.whl → 0.3.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. inspect_ai/_cli/common.py +3 -1
  2. inspect_ai/_cli/eval.py +15 -9
  3. inspect_ai/_display/core/active.py +4 -1
  4. inspect_ai/_display/core/config.py +3 -3
  5. inspect_ai/_display/core/panel.py +7 -3
  6. inspect_ai/_display/plain/__init__.py +0 -0
  7. inspect_ai/_display/plain/display.py +203 -0
  8. inspect_ai/_display/rich/display.py +0 -5
  9. inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
  10. inspect_ai/_display/textual/widgets/samples.py +79 -12
  11. inspect_ai/_display/textual/widgets/sandbox.py +37 -0
  12. inspect_ai/_eval/eval.py +10 -1
  13. inspect_ai/_eval/loader.py +79 -19
  14. inspect_ai/_eval/registry.py +6 -0
  15. inspect_ai/_eval/score.py +3 -1
  16. inspect_ai/_eval/task/results.py +51 -22
  17. inspect_ai/_eval/task/run.py +47 -13
  18. inspect_ai/_eval/task/sandbox.py +10 -5
  19. inspect_ai/_util/constants.py +1 -0
  20. inspect_ai/_util/port_names.py +61 -0
  21. inspect_ai/_util/text.py +23 -0
  22. inspect_ai/_view/www/App.css +31 -1
  23. inspect_ai/_view/www/dist/assets/index.css +31 -1
  24. inspect_ai/_view/www/dist/assets/index.js +25498 -2044
  25. inspect_ai/_view/www/log-schema.json +32 -2
  26. inspect_ai/_view/www/package.json +2 -0
  27. inspect_ai/_view/www/src/App.mjs +14 -16
  28. inspect_ai/_view/www/src/Types.mjs +1 -2
  29. inspect_ai/_view/www/src/api/Types.ts +133 -0
  30. inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
  31. inspect_ai/_view/www/src/api/api-http.ts +219 -0
  32. inspect_ai/_view/www/src/api/api-shared.ts +47 -0
  33. inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
  34. inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
  35. inspect_ai/_view/www/src/api/index.ts +51 -0
  36. inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
  37. inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
  38. inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
  39. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
  40. inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
  41. inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
  42. inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
  43. inspect_ai/_view/www/src/index.js +77 -4
  44. inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
  45. inspect_ai/_view/www/src/navbar/Navbar.mjs +4 -1
  46. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +19 -10
  47. inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
  48. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
  49. inspect_ai/_view/www/src/samples/SampleList.mjs +19 -49
  50. inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
  51. inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
  52. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -26
  53. inspect_ai/_view/www/src/samples/SamplesTab.mjs +14 -11
  54. inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
  55. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
  56. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
  57. inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
  58. inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
  59. inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
  60. inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
  61. inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
  62. inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
  63. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
  64. inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
  65. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
  66. inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
  67. inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
  68. inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
  69. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
  70. inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
  71. inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
  72. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
  73. inspect_ai/_view/www/src/types/log.d.ts +13 -2
  74. inspect_ai/_view/www/src/utils/Format.mjs +10 -3
  75. inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +13 -9
  76. inspect_ai/_view/www/src/utils/vscode.ts +36 -0
  77. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +11 -5
  78. inspect_ai/_view/www/vite.config.js +7 -0
  79. inspect_ai/_view/www/yarn.lock +116 -0
  80. inspect_ai/approval/_human/__init__.py +0 -0
  81. inspect_ai/approval/_human/manager.py +1 -1
  82. inspect_ai/approval/_policy.py +12 -6
  83. inspect_ai/log/_log.py +1 -1
  84. inspect_ai/log/_samples.py +16 -0
  85. inspect_ai/log/_transcript.py +4 -1
  86. inspect_ai/model/_call_tools.py +59 -0
  87. inspect_ai/model/_conversation.py +16 -7
  88. inspect_ai/model/_generate_config.py +12 -12
  89. inspect_ai/model/_model.py +117 -18
  90. inspect_ai/model/_model_output.py +22 -2
  91. inspect_ai/model/_openai.py +383 -0
  92. inspect_ai/model/_providers/anthropic.py +152 -55
  93. inspect_ai/model/_providers/azureai.py +21 -21
  94. inspect_ai/model/_providers/bedrock.py +37 -40
  95. inspect_ai/model/_providers/goodfire.py +248 -0
  96. inspect_ai/model/_providers/google.py +46 -54
  97. inspect_ai/model/_providers/groq.py +7 -3
  98. inspect_ai/model/_providers/hf.py +6 -0
  99. inspect_ai/model/_providers/mistral.py +13 -12
  100. inspect_ai/model/_providers/openai.py +51 -218
  101. inspect_ai/model/_providers/openai_o1.py +11 -12
  102. inspect_ai/model/_providers/providers.py +23 -1
  103. inspect_ai/model/_providers/together.py +12 -12
  104. inspect_ai/model/_providers/util/__init__.py +2 -3
  105. inspect_ai/model/_providers/util/hf_handler.py +1 -1
  106. inspect_ai/model/_providers/util/llama31.py +1 -1
  107. inspect_ai/model/_providers/util/util.py +0 -76
  108. inspect_ai/model/_providers/vertex.py +1 -4
  109. inspect_ai/scorer/_metric.py +3 -0
  110. inspect_ai/scorer/_reducer/reducer.py +1 -1
  111. inspect_ai/scorer/_scorer.py +4 -3
  112. inspect_ai/solver/__init__.py +4 -5
  113. inspect_ai/solver/_basic_agent.py +1 -1
  114. inspect_ai/solver/_bridge/__init__.py +3 -0
  115. inspect_ai/solver/_bridge/bridge.py +100 -0
  116. inspect_ai/solver/_bridge/patch.py +170 -0
  117. inspect_ai/solver/_prompt.py +35 -5
  118. inspect_ai/solver/_solver.py +6 -0
  119. inspect_ai/solver/_task_state.py +80 -38
  120. inspect_ai/tool/__init__.py +2 -0
  121. inspect_ai/tool/_tool.py +12 -1
  122. inspect_ai/tool/_tool_call.py +10 -0
  123. inspect_ai/tool/_tool_def.py +16 -5
  124. inspect_ai/tool/_tool_with.py +21 -4
  125. inspect_ai/tool/beta/__init__.py +5 -0
  126. inspect_ai/tool/beta/_computer/__init__.py +3 -0
  127. inspect_ai/tool/beta/_computer/_common.py +133 -0
  128. inspect_ai/tool/beta/_computer/_computer.py +155 -0
  129. inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
  130. inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
  131. inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
  132. inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
  133. inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
  134. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
  135. inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
  136. inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
  137. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
  138. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
  139. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
  140. inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
  141. inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
  142. inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
  143. inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
  144. inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
  145. inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
  146. inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
  147. inspect_ai/util/__init__.py +2 -0
  148. inspect_ai/util/_display.py +5 -0
  149. inspect_ai/util/_limit.py +26 -0
  150. inspect_ai/util/_sandbox/docker/docker.py +64 -1
  151. inspect_ai/util/_sandbox/docker/internal.py +3 -1
  152. inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
  153. inspect_ai/util/_sandbox/environment.py +14 -0
  154. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
  155. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +159 -126
  156. inspect_ai/_view/www/src/api/Types.mjs +0 -117
  157. inspect_ai/_view/www/src/api/api-http.mjs +0 -300
  158. inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
  159. inspect_ai/_view/www/src/api/index.mjs +0 -49
  160. inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
  161. inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
  162. inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
  163. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
  164. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
  165. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
  166. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
inspect_ai/scorer/_reducer/reducer.py

@@ -111,7 +111,7 @@ def pass_at(
         if total - correct < k:
             return 1.0
         else:
-            return 1.0 - cast(
+            return 1.0 - cast(  # type: ignore[redundant-cast]
                 float,
                 np.prod(1.0 - k / np.arange(total - correct + 1, total + 1)).item(),
             )
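For context, `pass_at` implements the unbiased pass@k estimator, 1 - C(n-c, k)/C(n, k) for n total samples and c correct samples, computed in the numerically stable product form 1 - prod_{j=n-c+1..n}(1 - k/j) that the hunk above touches. A quick sanity check that the two forms agree (values are illustrative, not from the diff):

    import numpy as np
    from math import comb

    n, c, k = 10, 4, 5
    # stable product form used by pass_at
    stable = 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)).item()
    # direct binomial form of the same estimator
    direct = 1.0 - comb(n - c, k) / comb(n, k)
    assert abs(stable - direct) < 1e-12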
inspect_ai/scorer/_scorer.py

@@ -1,3 +1,4 @@
+from functools import wraps
 from typing import (
     Any,
     Callable,
@@ -100,7 +101,6 @@ def scorer(
 
     Returns:
         Scorer with registry attributes.
-
     """
 
     def wrapper(scorer_type: Callable[P, Scorer]) -> Callable[P, Scorer]:
@@ -110,6 +110,7 @@ def scorer(
         )
 
         # wrap instantiations of scorer so they carry registry info and metrics
+        @wraps(scorer_type)
         def scorer_wrapper(*args: P.args, **kwargs: P.kwargs) -> Scorer:
             scorer = scorer_type(*args, **kwargs)
 
@@ -151,8 +152,8 @@ def scorer_metrics(
     return cast(list[Metric | dict[str, list[Metric]]], metrics_raw)
 
 
-def unique_scorer_name(scorer: Scorer, already_used_names: list[str]) -> str:
-    base_name = registry_unqualified_name(scorer)
+def unique_scorer_name(scorer: Scorer | str, already_used_names: list[str]) -> str:
+    base_name = scorer if isinstance(scorer, str) else registry_unqualified_name(scorer)
     scorer_name = base_name
     count = 1
     while scorer_name in already_used_names:
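The widened signature lets callers pass a plain string (for example a scorer name with no registry entry) as well as a `Scorer`. A hedged sketch of the two call shapes (`my_scorer` is hypothetical; the suffixing loop body is elided above):

    unique_scorer_name(my_scorer, ["accuracy"])   # base name from the registry
    unique_scorer_name("accuracy", ["accuracy"])  # base name is the string itself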
inspect_ai/solver/__init__.py

@@ -1,23 +1,21 @@
 from inspect_ai._util.deprecation import relocated_module_attribute
 
 from ._basic_agent import basic_agent
+from ._bridge import bridge
 from ._chain import chain
 from ._critique import self_critique
 from ._fork import fork
 from ._human_agent.agent import human_agent
 from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
 from ._plan import Plan, plan
-from ._prompt import (
-    chain_of_thought,
-    prompt_template,
-    system_message,
-)
+from ._prompt import chain_of_thought, prompt_template, system_message, user_message
 from ._solver import Generate, Solver, SolverSpec, generate, solver
 from ._task_state import Choice, Choices, TaskState
 from ._use_tools import use_tools
 
 __all__ = [
     "basic_agent",
+    "bridge",
     "human_agent",
     "chain",
     "fork",
@@ -26,6 +24,7 @@ __all__ = [
     "chain_of_thought",
     "multiple_choice",
     "system_message",
+    "user_message",
     "self_critique",
     "use_tools",
     "plan",
inspect_ai/solver/_basic_agent.py

@@ -119,7 +119,7 @@ def basic_agent(
     # resolve tools
     if tools is None:
         tools = []
-    tools = tools if isinstance(tools, Solver) else use_tools(tools)
+    tools = tools if isinstance(tools, Solver) else use_tools(tools, append=True)
 
     # resolve score_value function
     score_value_fn = score_value or value_to_float()
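The behavioral change: passing `append=True` asks `use_tools` to add the agent's tools to whatever tools earlier solvers installed rather than replacing them (the replace-vs-append semantics of the `append` parameter are inferred from its name and this call site). Illustrative sketch with the built-in `web_search` tool:

    use_tools([web_search()])               # replace the current tool list
    use_tools([web_search()], append=True)  # keep current tools, add these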
inspect_ai/solver/_bridge/__init__.py (new file)

@@ -0,0 +1,3 @@
+from .bridge import bridge
+
+__all__ = ["bridge"]
inspect_ai/solver/_bridge/bridge.py (new file)

@@ -0,0 +1,100 @@
+from typing import Any, Awaitable, Callable
+
+from jsonschema import Draft7Validator
+from pydantic import BaseModel, Field, ValidationError
+from pydantic_core import to_json
+
+from inspect_ai._util._async import is_callable_coroutine
+from inspect_ai.model._chat_message import ChatMessage, ChatMessageUser
+from inspect_ai.model._providers.providers import validate_openai_client
+from inspect_ai.scorer._metric import Score
+
+from .._solver import Generate, Solver, solver
+from .._task_state import TaskState
+
+
+@solver
+def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Solver:
+    """Bridge an external agent into an Inspect Solver.
+
+    See documentation at https://inspect.ai-safety-institute.org.uk/agent-bridge.html
+
+    Args:
+        agent: Callable which takes a sample `dict` and returns a result `dict`.
+
+    Returns:
+        Standard Inspect solver.
+    """
+    validate_openai_client("Solver bridge()")
+
+    from openai.types.chat import ChatCompletionMessageParam
+
+    from inspect_ai.model._openai import (
+        chat_messages_from_openai,
+        openai_chat_messages,
+    )
+
+    from .patch import openai_request_to_inspect_model
+
+    class BridgeSample(BaseModel):
+        sample_id: str
+        epoch: int
+        input: list[ChatCompletionMessageParam]
+        metadata: dict[str, Any]
+        target: list[str]
+
+    class BridgeResult(BaseModel):
+        output: str
+        messages: list[ChatCompletionMessageParam] | None = Field(default=None)
+        scores: dict[str, Score] | None = Field(default=None)
+
+    result_schema = BridgeResult.model_json_schema()
+    result_validator = Draft7Validator(result_schema)
+
+    # validate that the agent is an async function
+    if not is_callable_coroutine(agent):
+        raise TypeError(f"'{agent.__name__}' is not declared as an async callable.")
+
+    async def solve(state: TaskState, generate: Generate) -> TaskState:
+        # resolve input to array
+        input: list[ChatMessage] = (
+            [ChatMessageUser(content=state.input)]
+            if isinstance(state.input, str)
+            else state.input
+        )
+
+        # create sample
+        sample = BridgeSample(
+            sample_id=str(state.sample_id),
+            epoch=state.epoch,
+            input=await openai_chat_messages(input, state.model.name),
+            metadata=state.metadata,
+            target=list(state.target),
+        )
+
+        # run target function
+        async with openai_request_to_inspect_model():
+            # call the function
+            result_dict = await agent(sample.model_dump())
+            try:
+                result = BridgeResult.model_validate(result_dict)
+            except ValidationError:
+                # if we fail to validate provide a better human readable error
+                errors = list(result_validator.iter_errors(result_dict))
+                message = "\n".join(
+                    ["Result returned from bridged solver is not valid:"]
+                    + [f" - {error.message}" for error in errors]
+                    + ["", to_json(result_dict, indent=2).decode()]
+                )
+                raise ValueError(message)
+
+        # update and return state
+        state.output.completion = result.output
+        if result.messages is not None:
+            state.messages = chat_messages_from_openai(result.messages)
+        if result.scores is not None:
+            state.scores = result.scores
+
+        return state
+
+    return solve
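A minimal sketch of a bridged agent built against the contract above (sample `dict` in, result `dict` out); the task wiring is illustrative rather than taken from the diff. Inside `bridge()`, OpenAI client calls with `model="inspect"` are redirected to the active Inspect model by the patch module below:

    from openai import AsyncOpenAI

    from inspect_ai import Task, task
    from inspect_ai.dataset import Sample
    from inspect_ai.solver import bridge

    async def my_agent(sample: dict) -> dict:
        # sample carries: sample_id, epoch, input (OpenAI-format
        # messages), metadata, target
        client = AsyncOpenAI()
        completion = await client.chat.completions.create(
            model="inspect",  # routed to the current Inspect model
            messages=sample["input"],
        )
        return {"output": completion.choices[0].message.content or ""}

    @task
    def my_task():
        return Task(
            dataset=[Sample(input="Say hello.")],
            solver=bridge(my_agent),
        )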
inspect_ai/solver/_bridge/patch.py (new file)

@@ -0,0 +1,170 @@
+import contextlib
+import re
+from contextvars import ContextVar
+from functools import wraps
+from time import time
+from typing import Any, AsyncGenerator, Optional, Type, cast
+
+from openai._base_client import AsyncAPIClient, _AsyncStreamT
+from openai._models import FinalRequestOptions
+from openai._types import ResponseT
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionMessageParam,
+    ChatCompletionToolParam,
+)
+from shortuuid import uuid
+
+from inspect_ai.model._generate_config import GenerateConfig
+from inspect_ai.model._model import get_model
+from inspect_ai.model._openai import (
+    chat_messages_from_openai,
+    openai_chat_choices,
+    openai_completion_usage,
+)
+from inspect_ai.solver._task_state import sample_state
+from inspect_ai.tool._tool_info import ToolInfo
+from inspect_ai.tool._tool_params import ToolParams
+
+
+@contextlib.asynccontextmanager
+async def openai_request_to_inspect_model() -> AsyncGenerator[None, None]:
+    # ensure one time init
+    init_openai_request_patch()
+
+    # set the patch enabled for this context and child coroutines
+    token = _patch_enabled.set(True)
+    try:
+        yield
+    finally:
+        _patch_enabled.reset(token)
+
+
+_patch_initialised: bool = False
+
+_patch_enabled: ContextVar[bool] = ContextVar(
+    "openai_request_patch_enabled", default=False
+)
+
+
+def init_openai_request_patch() -> None:
+    global _patch_initialised
+    if not _patch_initialised:
+        # get reference to original method
+        original_request = getattr(AsyncAPIClient, "request")
+        if original_request is None:
+            raise RuntimeError("Couldn't find 'request' method on AsyncAPIClient")
+
+        @wraps(original_request)
+        async def patched_request(
+            self: AsyncAPIClient,
+            cast_to: Type[ResponseT],
+            options: FinalRequestOptions,
+            *,
+            stream: bool = False,
+            stream_cls: type[_AsyncStreamT] | None = None,
+            remaining_retries: Optional[int] = None,
+        ) -> Any:
+            # we have patched the underlying request method so now need to figure out when to
+            # patch and when to stand down
+            if (
+                # enabled for this coroutine
+                _patch_enabled.get()
+                # completions request
+                and options.url == "/chat/completions"
+                # call to openai not another service (e.g. TogetherAI)
+                and self.base_url == "https://api.openai.com/v1/"
+            ):
+                # must also be an explicit request for an inspect model
+                json_data = cast(dict[str, Any], options.json_data)
+                model_name = str(json_data["model"])
+                if re.match(r"^inspect/?", model_name):
+                    return await inspect_model_request(model_name, options)
+
+            # otherwise just delegate
+            return await original_request(
+                self,
+                cast_to,
+                options,
+                stream=stream,
+                stream_cls=stream_cls,
+                remaining_retries=remaining_retries,
+            )
+
+        setattr(AsyncAPIClient, "request", patched_request)
+
+
+async def inspect_model_request(
+    model_name: str, options: FinalRequestOptions
+) -> ChatCompletion:
+    # convert openai messages to inspect messages
+    json_data = cast(dict[str, Any], options.json_data)
+    messages: list[ChatCompletionMessageParam] = json_data["messages"]
+    input = chat_messages_from_openai(messages)
+
+    # convert openai tools to inspect tools
+    tools: list[ChatCompletionToolParam] = json_data.get("tools", [])
+    inspect_tools: list[ToolInfo] = []
+    for tool in tools:
+        function = tool["function"].copy()
+        inspect_tools.append(
+            ToolInfo(
+                name=function["name"],
+                description=function["description"],
+                parameters=ToolParams.model_validate(function["parameters"]),
+            )
+        )
+
+    # resolve model
+    if model_name == "inspect":
+        model = get_model()
+    else:
+        model = get_model(model_name.removeprefix("inspect/"))
+
+    output = await model.generate(
+        input=input,
+        tools=inspect_tools,
+        config=generate_config_from_openai(options),
+    )
+
+    # if we are using the "default" inspect model for the task, update state.messages
+    if model_name == "inspect":
+        state = sample_state()
+        if state:
+            state.messages = input + [output.choices[0].message]
+
+    # inspect completion to openai completion
+    return ChatCompletion(
+        id=uuid(),
+        created=int(time()),
+        object="chat.completion",
+        choices=openai_chat_choices(output.choices),
+        model=model_name,
+        usage=openai_completion_usage(output.usage) if output.usage else None,
+    )
+
+
+def generate_config_from_openai(options: FinalRequestOptions) -> GenerateConfig:
+    # get options dict
+    json_data = cast(dict[str, Any], options.json_data)
+
+    config = GenerateConfig()
+    config.max_tokens = json_data.get(
+        "max_completion_tokens", json_data.get("max_tokens", None)
+    )
+    config.top_p = json_data.get("top_p", None)
+    config.temperature = json_data.get("temperature", None)
+    stop = json_data.get("stop", None)
+    if stop:
+        config.stop_seqs = [stop] if isinstance(stop, str) else stop
+    config.frequency_penalty = json_data.get("frequency_penalty", None)
+    config.presence_penalty = json_data.get("presence_penalty", None)
+    config.seed = json_data.get("seed", None)
+    config.num_choices = json_data.get("n", None)
+    config.logprobs = json_data.get("logprobs", None)
+    config.top_logprobs = json_data.get("top_logprobs", None)
+    config.logit_bias = json_data.get("logit_bias", None)
+    config.parallel_tool_calls = json_data.get("parallel_tool_calls", None)
+    config.reasoning_effort = json_data.get("reasoning_effort", None)
+
+    return config
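To summarize the routing logic: while `openai_request_to_inspect_model()` is active, `/chat/completions` requests against `api.openai.com` whose model name matches `inspect` or `inspect/<provider>/<model>` are answered by Inspect's `get_model()`; everything else falls through to the original `request` method. Roughly (`client` and `msgs` assumed):

    async with openai_request_to_inspect_model():
        # served by get_model() -- the task's default model
        await client.chat.completions.create(model="inspect", messages=msgs)
        # served by get_model("anthropic/claude-3-5-sonnet-latest")
        await client.chat.completions.create(
            model="inspect/anthropic/claude-3-5-sonnet-latest", messages=msgs
        )
        # passes through to the real OpenAI API
        await client.chat.completions.create(model="gpt-4o", messages=msgs)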
inspect_ai/solver/_prompt.py

@@ -2,6 +2,7 @@ from typing import Any
 
 from inspect_ai._util.dict import omit
 from inspect_ai.model import ChatMessageSystem
+from inspect_ai.model._chat_message import ChatMessageUser
 from inspect_ai.util import resource
 
 from ._solver import Generate, Solver, solver
@@ -15,7 +16,8 @@ def prompt_template(template: str, **params: Any) -> Solver:
 
     Prompt template containing a `{prompt}` placeholder and any
     number of additional `params`. All values contained in sample
-    `metadata` are also automatically included in the `params`.
+    `metadata` and `store` are also automatically included in the
+    `params`.
 
     Args:
        template: (str): Template for prompt.
@@ -29,7 +31,7 @@ def prompt_template(template: str, **params: Any) -> Solver:
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
         prompt = state.user_prompt
-        kwargs = omit(state.metadata, ["prompt"]) | params
+        kwargs = omit(state.metadata | state.store._data, ["prompt"]) | params
         prompt.text = prompt_template.format(prompt=prompt.text, **kwargs)
         return state
 
@@ -41,8 +43,9 @@ def system_message(template: str, **params: Any) -> Solver:
     """Solver which inserts a system message into the conversation.
 
     System message template containing any number of optional `params`.
-    for substitution. All values contained in sample `metadata` are also
-    automatically included in the `params`.
+    for substitution using the `str.format()` method. All values
+    contained in sample `metadata` and `store` are also automatically
+    included in the `params`.
 
     The new message will go after other system messages (if there
     are none it will be inserted at the beginning of the conversation).
@@ -58,7 +61,7 @@ def system_message(template: str, **params: Any) -> Solver:
     content = resource(template)
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
-        kwargs = state.metadata | params
+        kwargs = state.metadata | state.store._data | params
         append_system_message(
             state.messages, ChatMessageSystem(content=content.format(**kwargs))
         )
@@ -67,6 +70,33 @@ def system_message(template: str, **params: Any) -> Solver:
     return solve
 
 
+@solver
+def user_message(template: str, **params: Any) -> Solver:
+    """Solver which inserts a user message into the conversation.
+
+    User message template containing any number of optional `params`.
+    for substitution using the `str.format()` method. All values
+    contained in sample `metadata` and `store` are also automatically
+    included in the `params`.
+
+    Args:
+        template (str): Template for user message.
+        **params (dict[str,Any]): Parameters to fill into the template.
+
+    Returns:
+        A solver that inserts the parameterised user message.
+    """
+    # read template
+    content = resource(template)
+
+    async def solve(state: TaskState, generate: Generate) -> TaskState:
+        kwargs = state.metadata | state.store._data | params
+        state.messages.append(ChatMessageUser(content=content.format(**kwargs)))
+        return state
+
+    return solve
+
+
 DEFAULT_COT_TEMPLATE = r"""
 {prompt}
 
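A short usage sketch of the new solver alongside the widened template params (sample `metadata` and `store` values now both resolve into templates); the `{difficulty}` and `{hint}` keys are hypothetical:

    from inspect_ai.solver import generate, system_message, user_message

    solver = [
        system_message("You are a {difficulty}-level tutor."),
        user_message("Here is a hint: {hint}"),
        generate(),
    ]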
inspect_ai/solver/_solver.py

@@ -180,6 +180,7 @@ def solver(
         solver_type, name if name else getattr(solver_type, "__name__")
     )
 
+    @wraps(solver_type)
     def solver_wrapper(*args: P.args, **kwargs: P.kwargs) -> Solver:
         solver = solver_type(*args, **kwargs)
 
@@ -193,6 +194,7 @@ def solver(
         if inspect.isclass(type(solver)):
             original_call = solver.__call__
 
+            @wraps(original_call)
             async def call_with_state(
                 state: TaskState, generate: Generate
             ) -> TaskState:
@@ -225,6 +227,10 @@ def solver(
 
         return registered_solver
 
+    # functools.wraps overrides the return type annotation of the inner function, so
+    # we explicitly set it again
+    solver_wrapper.__annotations__["return"] = Solver
+
     return solver_register(cast(Callable[P, Solver], solver_wrapper), solver_name)
 
     # for decorators with an explicit name, one more wrapper for the name
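The practical upshot of the `@wraps` additions: `@solver`-decorated factories now keep their own `__name__` and `__doc__` instead of reporting the wrapper's, and the explicit `__annotations__["return"]` assignment restores the `Solver` return annotation that `functools.wraps` copies over from the wrapped function. For example (hypothetical solver):

    @solver
    def my_solver() -> Solver:
        """Does something useful."""
        ...

    my_solver.__name__                    # "my_solver", not "solver_wrapper"
    my_solver.__doc__                     # "Does something useful."
    my_solver.__annotations__["return"]   # Solver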
inspect_ai/solver/_task_state.py

@@ -2,8 +2,9 @@ from collections.abc import Sequence
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass
+from itertools import tee
 from random import Random
-from typing import Any, Type, Union, cast, overload
+from typing import Any, Iterable, SupportsIndex, Type, Union, cast, overload
 
 from pydantic_core import to_jsonable_python
 
@@ -15,9 +16,13 @@ from inspect_ai.model import (
     ModelOutput,
 )
 from inspect_ai.model._call_tools import tools_info
+from inspect_ai.model._chat_message import ChatMessageBase
 from inspect_ai.model._model import sample_total_tokens
+from inspect_ai.scorer._metric import Score
+from inspect_ai.scorer._target import Target
 from inspect_ai.tool import Tool, ToolChoice
 from inspect_ai.tool._tool_def import ToolDef
+from inspect_ai.util._limit import SampleLimitExceededError
 from inspect_ai.util._store import Store, store_jsonable
 from inspect_ai.util._store_model import SMT
 
@@ -136,6 +141,7 @@ class TaskState:
         epoch: int,
         input: str | list[ChatMessage],
         messages: list[ChatMessage],
+        target: Target = Target(""),
         choices: list[str] | None = [],
         output: ModelOutput | None = None,
         message_limit: int | None = None,
@@ -161,10 +167,13 @@ class TaskState:
         or `input_text` only
         """
 
+        self.target = target
+        """The scoring target for this `Sample`."""
+
         self.metadata = metadata
         """Metadata from the `Sample` for this `TaskState`"""
 
-        self.messages = messages
+        self._messages: list[ChatMessage] = ChatMessageList(messages)
         """
         Chat conversation history for sample.
 
@@ -189,9 +198,7 @@ class TaskState:
         """
 
         self._message_limit = message_limit
-        self._message_limit_exceeded = False
         self._token_limit = token_limit
-        self._token_limit_exceeded = False
         self._completed = completed
 
         """Store for shared data"""
@@ -202,6 +209,9 @@ class TaskState:
         else:
             self.choices = Choices([])
 
+        self.scores: dict[str, Score] | None = None
+        """Scores yielded by running task."""
+
     @property
     def model(self) -> ModelName:
         """Name of model being evaluated."""
@@ -254,6 +264,16 @@ class TaskState:
         else:
             raise ValueError("user_prompt requested from TaskState but none available")
 
+    @property
+    def messages(self) -> list[ChatMessage]:
+        """Messages in chat history"""
+        return self._messages
+
+    @messages.setter
+    def messages(self, messages: list[ChatMessage]) -> None:
+        """Set messages in chat history."""
+        self._messages = ChatMessageList(messages)
+
     @property
     def max_messages(self) -> int | None:
         """Deprecated (use message_limit)."""
@@ -300,40 +320,7 @@ class TaskState:
     @property
     def completed(self) -> bool:
         """Is the task completed."""
-        # update messages
-        from inspect_ai.log._samples import set_active_sample_total_messages
-        from inspect_ai.log._transcript import SampleLimitEvent, transcript
-
-        set_active_sample_total_messages(len(self.messages))
-
-        if self._completed:
-            return True
-        elif self.message_limit and len(self.messages) >= self.message_limit:
-            # log if this is the first time we hit this
-            if not self._message_limit_exceeded:
-                self._message_limit_exceeded = True
-                transcript()._event(
-                    SampleLimitEvent(
-                        type="message",
-                        message=f"Sample completed: exceeded message limit ({self.message_limit})",
-                        limit=self.message_limit,
-                    )
-                )
-            return True
-        elif self.token_limit and self.token_usage >= self.token_limit:
-            # log if this is the first time we hit this
-            if not self._token_limit_exceeded:
-                self._token_limit_exceeded = True
-                transcript()._event(
-                    SampleLimitEvent(
-                        type="token",
-                        message=f"Sample completed: exceeded token limit ({self.token_limit:,})",
-                        limit=self.token_limit,
-                    )
-                )
-            return True
-        else:
-            return False
+        return self._completed
 
     @completed.setter
     def completed(self, completed: bool) -> None:
@@ -413,3 +400,58 @@ def state_jsonable(state: TaskState | None = None) -> dict[str, Any]:
 def sample_jsonable(sample: Sample) -> dict[str, Any]:
     jsonable = to_jsonable_python(sample, exclude_none=True, fallback=lambda _x: None)
     return cast(dict[str, Any], deepcopy(jsonable))
+
+
+class ChatMessageList(list[ChatMessage]):
+    def __init__(self, iterable: Iterable[ChatMessage]):
+        items, length = self._iterable_length(iterable)
+        self._check_size(length)
+        super().__init__(items)
+
+    def _check_size(self, additional_items: int = 1) -> None:
+        from inspect_ai.log._samples import active_sample_message_limit
+
+        messages_limit = active_sample_message_limit()
+        if messages_limit is not None:
+            messages = len(self) + additional_items
+            if messages > messages_limit:
+                raise SampleLimitExceededError(
+                    "message", value=messages, limit=messages_limit
+                )
+
+    def append(self, item: ChatMessage) -> None:
+        self._check_size()
+        super().append(item)
+
+    def extend(self, items: Iterable[ChatMessage]) -> None:
+        items, length = self._iterable_length(items)
+        self._check_size(length)
+        super().extend(items)
+
+    def insert(self, index: SupportsIndex, item: ChatMessage) -> None:
+        self._check_size()
+        super().insert(index, item)
+
+    @overload
+    def __setitem__(self, index: SupportsIndex, item: ChatMessage) -> None: ...
+
+    @overload
+    def __setitem__(self, index: slice, item: Iterable[ChatMessage]) -> None: ...
+
+    def __setitem__(
+        self, index: SupportsIndex | slice, item: ChatMessage | Iterable[ChatMessage]
+    ) -> None:
+        if isinstance(index, slice) and not isinstance(item, ChatMessageBase):
+            item, length = self._iterable_length(item)
+            size_change = length - len(self[index])
+            if size_change > 0:
+                self._check_size(size_change)
+
+        super().__setitem__(index, item)  # type: ignore[assignment,index]
+
+    def _iterable_length(
+        self, items: Iterable[ChatMessage]
+    ) -> tuple[Iterable[ChatMessage], int]:
+        items, counter = tee(items)
+        length = sum(1 for _ in counter)
+        return items, length
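Taken together, these hunks move limit enforcement out of the `completed` poll and into the message list itself: `state.messages` is now a `ChatMessageList` that raises `SampleLimitExceededError` the moment a mutation would exceed the active sample's message limit. Sketch of the new failure mode (hypothetical limit of 2):

    state.messages.append(ChatMessageUser(content="one"))
    state.messages.append(ChatMessageAssistant(content="two"))
    # with an active message limit of 2, the next append raises
    # SampleLimitExceededError("message", value=3, limit=2) immediately,
    # instead of completed flipping to True on a later check
    state.messages.append(ChatMessageUser(content="three"))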
inspect_ai/tool/__init__.py

@@ -12,6 +12,7 @@ from ._tool_call import (
     ToolCall,
     ToolCallContent,
     ToolCallError,
+    ToolCallModelInput,
     ToolCallView,
     ToolCallViewer,
 )
@@ -42,6 +43,7 @@ __all__ = [
     "ContentVideo",
     "ToolCall",
     "ToolCallContent",
+    "ToolCallModelInput",
     "ToolCallView",
     "ToolCallViewer",
     "ToolChoice",