inspect-ai 0.3.88__py3-none-any.whl → 0.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. inspect_ai/_cli/eval.py +16 -0
  2. inspect_ai/_cli/score.py +1 -12
  3. inspect_ai/_cli/util.py +4 -2
  4. inspect_ai/_display/core/footer.py +2 -2
  5. inspect_ai/_display/plain/display.py +2 -2
  6. inspect_ai/_eval/context.py +7 -1
  7. inspect_ai/_eval/eval.py +51 -27
  8. inspect_ai/_eval/evalset.py +27 -10
  9. inspect_ai/_eval/loader.py +7 -8
  10. inspect_ai/_eval/run.py +23 -31
  11. inspect_ai/_eval/score.py +18 -1
  12. inspect_ai/_eval/task/log.py +5 -13
  13. inspect_ai/_eval/task/resolved.py +1 -0
  14. inspect_ai/_eval/task/run.py +231 -256
  15. inspect_ai/_eval/task/task.py +25 -2
  16. inspect_ai/_eval/task/util.py +1 -8
  17. inspect_ai/_util/constants.py +1 -0
  18. inspect_ai/_util/json.py +8 -3
  19. inspect_ai/_util/registry.py +30 -13
  20. inspect_ai/_view/www/App.css +5 -0
  21. inspect_ai/_view/www/dist/assets/index.css +71 -36
  22. inspect_ai/_view/www/dist/assets/index.js +573 -475
  23. inspect_ai/_view/www/log-schema.json +66 -0
  24. inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
  25. inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
  26. inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
  27. inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
  28. inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
  29. inspect_ai/_view/www/src/samples/chat/ChatMessage.tsx +2 -2
  30. inspect_ai/_view/www/src/samples/chat/tools/ToolInput.module.css +2 -2
  31. inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
  32. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +12 -6
  33. inspect_ai/_view/www/src/samples/transcript/TranscriptView.module.css +0 -2
  34. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
  35. inspect_ai/_view/www/src/types/log.d.ts +24 -6
  36. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
  37. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
  38. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
  39. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
  40. inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
  41. inspect_ai/agent/_agent.py +12 -0
  42. inspect_ai/agent/_as_tool.py +1 -1
  43. inspect_ai/agent/_bridge/bridge.py +9 -2
  44. inspect_ai/agent/_react.py +142 -74
  45. inspect_ai/agent/_run.py +13 -2
  46. inspect_ai/agent/_types.py +6 -0
  47. inspect_ai/approval/_apply.py +6 -7
  48. inspect_ai/approval/_approver.py +3 -3
  49. inspect_ai/approval/_auto.py +2 -2
  50. inspect_ai/approval/_call.py +20 -4
  51. inspect_ai/approval/_human/approver.py +3 -3
  52. inspect_ai/approval/_human/manager.py +2 -2
  53. inspect_ai/approval/_human/panel.py +3 -3
  54. inspect_ai/approval/_policy.py +3 -3
  55. inspect_ai/log/__init__.py +2 -0
  56. inspect_ai/log/_log.py +23 -2
  57. inspect_ai/log/_model.py +58 -0
  58. inspect_ai/log/_recorders/file.py +14 -3
  59. inspect_ai/log/_transcript.py +3 -0
  60. inspect_ai/model/__init__.py +2 -0
  61. inspect_ai/model/_call_tools.py +4 -1
  62. inspect_ai/model/_model.py +49 -3
  63. inspect_ai/model/_openai.py +151 -21
  64. inspect_ai/model/_providers/anthropic.py +20 -12
  65. inspect_ai/model/_providers/bedrock.py +3 -3
  66. inspect_ai/model/_providers/cloudflare.py +29 -108
  67. inspect_ai/model/_providers/google.py +21 -10
  68. inspect_ai/model/_providers/grok.py +23 -17
  69. inspect_ai/model/_providers/groq.py +61 -37
  70. inspect_ai/model/_providers/llama_cpp_python.py +8 -9
  71. inspect_ai/model/_providers/mistral.py +8 -3
  72. inspect_ai/model/_providers/ollama.py +8 -9
  73. inspect_ai/model/_providers/openai.py +53 -157
  74. inspect_ai/model/_providers/openai_compatible.py +195 -0
  75. inspect_ai/model/_providers/openrouter.py +4 -15
  76. inspect_ai/model/_providers/providers.py +11 -0
  77. inspect_ai/model/_providers/together.py +25 -23
  78. inspect_ai/model/_trim.py +83 -0
  79. inspect_ai/solver/_plan.py +5 -3
  80. inspect_ai/tool/_tool_def.py +8 -2
  81. inspect_ai/util/__init__.py +3 -0
  82. inspect_ai/util/_concurrency.py +15 -2
  83. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/METADATA +1 -1
  84. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/RECORD +88 -83
  85. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/WHEEL +1 -1
  86. inspect_ai/_eval/task/rundir.py +0 -78
  87. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  88. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/entry_points.txt +0 -0
  89. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/licenses/LICENSE +0 -0
  90. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/top_level.txt +0 -0
inspect_ai/agent/_bridge/bridge.py CHANGED
@@ -6,6 +6,7 @@ from pydantic_core import to_json
 
 from inspect_ai._util._async import is_callable_coroutine
 from inspect_ai.agent._agent import Agent, AgentState, agent
+from inspect_ai.log._samples import sample_active
 from inspect_ai.model._model import get_model
 from inspect_ai.model._model_output import ModelOutput
 from inspect_ai.model._providers.providers import validate_openai_client
@@ -37,6 +38,10 @@ def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Agent
 class BridgeInput(BaseModel):
     messages: list[ChatCompletionMessageParam]
 
+    # here for backward compatibility w/ previous bridge
+    # (we may choose to add this to AgentState at some point)
+    metadata: dict[str, Any]
+
     # temporarily here for backward compatibility w/ previous bridge
     input: list[ChatCompletionMessageParam]
 
@@ -53,8 +58,10 @@ def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Agent
 
     async def execute(state: AgentState) -> AgentState:
         # create input (use standard gpt-4 message encoding -- i.e. no 'developer' messages)
-        messages = await openai_chat_messages(state.messages, model="gpt-4")
-        input = BridgeInput(messages=messages, input=messages)
+        sample = sample_active()
+        metadata = (sample.sample.metadata if sample is not None else None) or {}
+        messages = await openai_chat_messages(state.messages)
+        input = BridgeInput(messages=messages, metadata=metadata, input=messages)
 
         # run target function
         async with openai_request_to_inspect_model():
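Note the new `metadata` field: the bridged agent now receives the active sample's metadata in its input dict. A minimal sketch of a bridged agent that reads it (the agent name and body are hypothetical, and the `{"output": ...}` return shape assumes the previous bridge's dict contract):

from typing import Any

from inspect_ai.agent import bridge


def metadata_echo_agent():
    # hypothetical bridged agent: 'input' carries OpenAI-format 'messages'
    # plus (new in 0.3.90) the active sample's 'metadata'
    async def run(input: dict[str, Any]) -> dict[str, Any]:
        metadata = input.get("metadata", {})
        messages = input["messages"]
        return {"output": f"{len(messages)} messages, metadata keys: {sorted(metadata)}"}

    return bridge(run)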
inspect_ai/agent/_react.py CHANGED
@@ -1,22 +1,27 @@
 from logging import getLogger
+from typing import Literal, cast
 
 from inspect_ai._util._async import is_callable_coroutine
 from inspect_ai.model._call_tools import execute_tools
 from inspect_ai.model._chat_message import (
     ChatMessage,
+    ChatMessageAssistant,
     ChatMessageSystem,
+    ChatMessageTool,
     ChatMessageUser,
 )
 from inspect_ai.model._model import Model, get_model
+from inspect_ai.model._trim import trim_messages
 from inspect_ai.scorer._score import score
 from inspect_ai.tool._tool import Tool, ToolResult, tool
-from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_info import parse_tool_info
 from inspect_ai.tool._tool_with import tool_with
 
 from ._agent import Agent, AgentState, agent, agent_with
+from ._filter import MessageFilter
 from ._handoff import has_handoff
 from ._types import (
+    DEFAULT_CONTINUE_PROMPT,
     AgentAttempts,
     AgentContinue,
     AgentPrompt,
@@ -37,6 +42,7 @@ def react(
     attempts: int | AgentAttempts = 1,
     submit: AgentSubmit = AgentSubmit(),
     on_continue: str | AgentContinue | None = None,
+    truncation: Literal["auto", "disabled"] | MessageFilter = "disabled",
 ) -> Agent:
     """Extensible ReAct agent based on the paper [ReAct: Synergizing Reasoning and Acting in Language Models](https://arxiv.org/abs/2210.03629).
 
@@ -68,9 +74,16 @@ def react(
         attempts: Configure agent to make multiple attempts.
         submit: Configure submit tool used by agent.
         on_continue: Message to play back to the model to urge it to continue.
-            Optionally, can also be an async function to call to determine whether
-            the loop should continue (executed on every turn) and what message
-            to play back.
+            Use the placeholder {submit} to refer to the submit tool within the message.
+            Alternatively, an async function to call to determine whether the loop
+            should continue and what message to play back. Note that this function
+            is called on _every_ iteration of the loop so if you only want to send
+            a message back when the model fails to call tools you need to code
+            that behavior explicitly.
+        truncation: Truncate the conversation history in the event of a context
+            window overflow. Defaults to "disabled" which does no truncation. Pass
+            "auto" to use `trim_messages()` to reduce the context size. Pass a
+            `MessageFilter` function to do custom truncation.
 
     Returns:
         ReAct agent.
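For illustration, here is how the new parameter might be used, both with the built-in policy and with a custom `MessageFilter` (an async callable from `list[ChatMessage]` to `list[ChatMessage]`); the prompts and the 20-message window are hypothetical:

from inspect_ai.agent import react
from inspect_ai.model import ChatMessage


async def keep_recent(messages: list[ChatMessage]) -> list[ChatMessage]:
    # hypothetical overflow policy: keep system messages plus the
    # most recent 20 non-system messages
    system = [m for m in messages if m.role == "system"]
    rest = [m for m in messages if m.role != "system"]
    return system + rest[-20:]


# built-in truncation: trim_messages() runs on context window overflow
auto_agent = react(prompt="You are a helpful assistant.", truncation="auto")

# custom truncation via a MessageFilter
custom_agent = react(prompt="You are a helpful assistant.", truncation=keep_recent)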
@@ -90,24 +103,6 @@ def react(
     else:
         system_message = None
 
-    # resolve on_continue
-    if on_continue is None:
-        on_continue = "If you believe you have completed the task, please call the `submit()` tool with your answer."
-    if isinstance(on_continue, str):
-        no_tools_continue_message = on_continue
-
-        async def no_tools_continue(state: AgentState) -> bool | str:
-            if state.output is None or not state.output.message.tool_calls:
-                return no_tools_continue_message
-            else:
-                return True
-
-        on_continue = no_tools_continue
-
-    # validate that on_continue is async
-    if not is_callable_coroutine(on_continue):
-        raise ValueError("The on_continue function must be async.")
-
     # resolve attempts
     attempts = AgentAttempts(attempts) if isinstance(attempts, int) else attempts
 
@@ -124,12 +119,17 @@ def react(
 
         return execute
 
-    # helper to see if there is a submit tool call
-    def submitted_answer(tool_calls: list[ToolCall] | None) -> str | None:
-        for tool_call in tool_calls or []:
-            if tool_call.function == submit.name and tool_call.parse_error is None:
-                return str(tool_call.arguments["answer"])
-        return None
+    # helper to extract a submitted answer
+    def submission(tool_results: list[ChatMessage]) -> str | None:
+        return next(
+            (
+                result.text
+                for result in tool_results
+                if isinstance(result, ChatMessageTool)
+                and result.function == submit.name
+            ),
+            None,
+        )
 
     # resolve tools
     tools = tools or []
@@ -140,6 +140,14 @@ def react(
        if system_message:
            state.messages.insert(0, system_message)
 
+        # resolve overflow handling
+        if truncation == "auto":
+            overflow = cast(MessageFilter | None, trim_messages)
+        elif truncation == "disabled":
+            overflow = None
+        else:
+            overflow = truncation
+
        # track attempts
        attempt_count = 0
 
@@ -153,59 +161,95 @@
            if state.output.stop_reason == "model_length":
                from inspect_ai.log._transcript import transcript
 
+                if overflow is not None:
+                    previous_messages = state.messages[:-1]
+                    state.messages = await overflow(previous_messages)
+                    if len(state.messages) < len(previous_messages):
+                        transcript().info(
+                            "Agent exceeded model context window, truncating messages and continuing."
+                        )
+                        continue
+
+                # no overflow policy or overflow didn't reduce conversation length
                transcript().info("Agent terminated: model context window exceeded")
                break
 
-            # check for a submission
-            answer = submitted_answer(state.output.message.tool_calls)
-            if answer is not None:
-                # remove the tool call and set the output to the answer for scoring
-                state.output.message.tool_calls = None
-                state.output.completion = (
-                    f"{state.output.completion}\n\n{answer}".strip()
-                )
-
-                # exit if we are at max_attempts
-                attempt_count += 1
-                if attempt_count >= attempts.attempts:
-                    break
-
-                # exit if the submission is successful
-                answer_scores = await score(state)
-                if attempts.score_value(answer_scores[0].value) == 1.0:
-                    break
-
-                # otherwise notify the model that it was incorrect and continue
-                else:
-                    if callable(attempts.incorrect_message):
-                        if not is_callable_coroutine(attempts.incorrect_message):
-                            raise ValueError(
-                                "The incorrect_message function must be async."
+            # resolve tool calls (if any)
+            if state.output.message.tool_calls:
+                # call tool functions
+                messages, output = await execute_tools(state.messages, tools)
+                state.messages.extend(messages)
+                if output:
+                    state.output = output
+
+                # check for a submission
+                answer = submission(messages)
+                if answer is not None:
+                    # set the output to the answer for scoring
+                    state.output.completion = (
+                        f"{state.output.completion}\n\n{answer}".strip()
+                    )
+
+                    # exit if we are at max_attempts
+                    attempt_count += 1
+                    if attempt_count >= attempts.attempts:
+                        break
+
+                    # exit if the submission is successful
+                    answer_scores = await score(state)
+                    if attempts.score_value(answer_scores[0].value) == 1.0:
+                        break
+
+                    # otherwise notify the model that it was incorrect and continue
+                    else:
+                        if callable(attempts.incorrect_message):
+                            if not is_callable_coroutine(attempts.incorrect_message):
+                                raise ValueError(
+                                    "The incorrect_message function must be async."
+                                )
+                            response_message: str = await attempts.incorrect_message(
+                                state, answer_scores
+                            )
+                        else:
+                            response_message = attempts.incorrect_message
+
+                        state.messages.append(ChatMessageUser(content=response_message))
+
+            # call the on_continue hook (if any)
+            if callable(on_continue):
+                if not is_callable_coroutine(on_continue):
+                    raise ValueError("The on_continue function must be async.")
+                do_continue = await cast(AgentContinue, on_continue)(state)
+                if do_continue is True:
+                    # if there were no tool calls we need to send back a user message
+                    if not state.output.message.tool_calls:
+                        state.messages.append(
+                            ChatMessageUser(
+                                content=DEFAULT_CONTINUE_PROMPT.format(
+                                    submit=submit.name
+                                )
                            )
-                        response_message: str = await attempts.incorrect_message(
-                            state, answer_scores
                        )
-                    else:
-                        response_message = attempts.incorrect_message
-
-                    state.messages.append(ChatMessageUser(content=response_message))
-
-            # no submitted answer, call tools and evaluate whether we should continue
-            else:
-                if state.output.message.tool_calls:
-                    # call tool functions
-                    messages, output = await execute_tools(state.messages, tools)
-                    state.messages.extend(messages)
-                    if output:
-                        state.output = output
-
-                # check if we should continue....
-                do_continue = await on_continue(state)
-                if isinstance(do_continue, str):
-                    state.messages.append(ChatMessageUser(content=do_continue))
-                elif do_continue is False:
+                elif isinstance(do_continue, str):
+                    state.messages.append(
+                        ChatMessageUser(content=do_continue.format(submit=submit.name))
+                    )
+                else:  # do_continue is False
                    break
 
+            # if there is no on_continue hook then add a user message if there were no tool calls
+            elif not state.output.message.tool_calls:
+                continue_msg = (
+                    DEFAULT_CONTINUE_PROMPT if on_continue is None else str(on_continue)
+                )
+                state.messages.append(
+                    ChatMessageUser(content=continue_msg.format(submit=submit.name))
+                )
+
+        # once we are complete, remove submit tool calls from the history
+        # (as they will potentially confuse parent agents who also have
+        # their own submit tools that they are 'watching' for)
+        state.messages = _remove_submit_tool(state.messages, submit.name)
        return state
 
    if name is not None or description is not None:
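Under the new control flow, `on_continue` has three-way semantics: returning `True` continues (with the default continue prompt appended only when the model made no tool calls), returning a string continues with that message (`{submit}` is substituted with the submit tool name), and returning `False` ends the loop. A sketch of a custom hook (the 50-message limit and wording are hypothetical):

from inspect_ai.agent import AgentState


async def stop_when_stuck(state: AgentState) -> bool | str:
    # hypothetical hook: end the loop after 50 messages, nudge the model
    # only when it made no tool calls, otherwise continue silently
    if len(state.messages) > 50:
        return False
    if not state.output.message.tool_calls:
        return "Please call a tool, or submit your answer with {submit}()."
    return True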
@@ -239,3 +283,27 @@ def _model_generate(model: str | Model | None) -> Agent:
        return state
 
    return generate
+
+
+def _remove_submit_tool(
+    messages: list[ChatMessage], submit_name: str
+) -> list[ChatMessage]:
+    filtered: list[ChatMessage] = []
+    for message in messages:
+        # skip submit tool messages
+        if isinstance(message, ChatMessageTool) and message.function == submit_name:
+            continue
+
+        # remove submit tool from assistant messages
+        if isinstance(message, ChatMessageAssistant) and message.tool_calls:
+            tools_calls = [
+                tool_call
+                for tool_call in message.tool_calls
+                if tool_call.function != submit_name
+            ]
+            message = message.model_copy(update=dict(tool_calls=tools_calls))
+
+        # always append message
+        filtered.append(message)
+
+    return filtered
inspect_ai/agent/_run.py CHANGED
@@ -27,10 +27,21 @@ async def run(
 
    # resolve str
    if isinstance(input, str):
-        input = [ChatMessageUser(content=input)]
+        input_messages: list[ChatMessage] = [
+            ChatMessageUser(content=input, source="input")
+        ]
+    elif isinstance(input, list):
+        input_messages = [
+            message.model_copy(update=dict(source="input")) for message in input
+        ]
+    else:
+        input_messages = [
+            message.model_copy(update=dict(source="input"))
+            for message in input.messages
+        ]
 
    # create state
-    state = AgentState(messages=input) if isinstance(input, list) else input
+    state = AgentState(messages=input_messages)
 
    # run the agent
    return await agent(state, **agent_kwargs)
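With this change `run()` accepts a string, a list of messages, or an `AgentState`, and always builds a fresh `AgentState` whose messages are tagged with `source="input"`. A usage sketch (the prompt and question are arbitrary, and a configured model is assumed):

import asyncio

from inspect_ai.agent import react, run


async def main() -> None:
    agent = react(prompt="You are a helpful assistant.")
    # a plain string becomes a single ChatMessageUser tagged source="input"
    state = await run(agent, "What is the capital of France?")
    print(state.output.completion)


asyncio.run(main())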
inspect_ai/agent/_types.py CHANGED
@@ -40,6 +40,12 @@ class AgentPrompt(NamedTuple):
    """Prompt for assistant (covers tool use, submit tool, CoT, etc.)."""
 
 
+DEFAULT_CONTINUE_PROMPT = """
+Please proceed to the next step using your best judgement. If you believe you
+have completed the task, please call the `{submit}()` tool.
+"""
+
+
 AgentContinue: TypeAlias = Callable[[AgentState], Awaitable[bool | str]]
 """Function called to determine whether the agent should continue.
 
inspect_ai/approval/_apply.py CHANGED
@@ -2,6 +2,7 @@ from contextvars import ContextVar
 
 from inspect_ai._util.format import format_function_call
 from inspect_ai.approval._approval import Approval
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import (
     ToolCall,
     ToolCallContent,
@@ -14,10 +15,11 @@ from ._policy import ApprovalPolicy, policy_approver
 
 
 async def apply_tool_approval(
-    message: str, call: ToolCall, viewer: ToolCallViewer | None
+    message: str,
+    call: ToolCall,
+    viewer: ToolCallViewer | None,
+    history: list[ChatMessage],
 ) -> tuple[bool, Approval | None]:
-    from inspect_ai.solver._task_state import sample_state
-
    approver = _tool_approver.get(None)
    if approver:
        # resolve view
@@ -28,15 +30,12 @@
        else:
            view = default_tool_call_viewer(call)
 
-        # current sample state
-        state = sample_state()
-
        # call approver
        approval = await approver(
            message=message,
            call=call,
            view=view,
-            state=state,
+            history=history,
        )
 
        # process decision
inspect_ai/approval/_approver.py CHANGED
@@ -1,6 +1,6 @@
 from typing import Protocol
 
-from inspect_ai.solver._task_state import TaskState
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 
 from ._approval import Approval
@@ -14,7 +14,7 @@ class Approver(Protocol):
        message: str,
        call: ToolCall,
        view: ToolCallView,
-        state: TaskState | None = None,
+        history: list[ChatMessage],
    ) -> Approval:
        """
        Approve or reject a tool call.
@@ -23,7 +23,7 @@
            message: Message generated by the model along with the tool call.
            call: The tool call to be approved.
            view: Custom rendering of tool context and call.
-            state: The current task state, if available.
+            history: The current conversation history.
 
        Returns:
            Approval: An Approval object containing the decision and explanation.
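Under the revised protocol an approver receives the conversation history rather than the task state. A sketch of a custom approver using the new signature (the `@approver` registration decorator from `inspect_ai.approval` is assumed; the blocking rule itself is hypothetical):

from inspect_ai.approval import Approval, Approver, approver
from inspect_ai.model import ChatMessage
from inspect_ai.tool import ToolCall, ToolCallView


@approver
def rm_blocker() -> Approver:
    async def approve(
        message: str,
        call: ToolCall,
        view: ToolCallView,
        history: list[ChatMessage],
    ) -> Approval:
        # hypothetical rule: reject bash calls containing 'rm ',
        # approve everything else
        if call.function == "bash" and "rm " in str(call.arguments):
            return Approval(decision="reject", explanation="rm is not allowed")
        return Approval(decision="approve", explanation="allowed")

    return approve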
inspect_ai/approval/_auto.py CHANGED
@@ -1,4 +1,4 @@
-from inspect_ai.solver._task_state import TaskState
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 
 from ._approval import Approval, ApprovalDecision
@@ -21,7 +21,7 @@ def auto_approver(decision: ApprovalDecision = "approve") -> Approver:
        message: str,
        call: ToolCall,
        view: ToolCallView,
-        state: TaskState | None = None,
+        history: list[ChatMessage],
    ) -> Approval:
        return Approval(decision=decision, explanation="Automatic decision.")
 
inspect_ai/approval/_call.py CHANGED
@@ -1,20 +1,36 @@
+import inspect
+from logging import getLogger
+
+from inspect_ai._util.logger import warn_once
 from inspect_ai._util.registry import registry_log_name
-from inspect_ai.solver._task_state import TaskState
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 
 from ._approval import Approval
 from ._approver import Approver
 
+logger = getLogger(__name__)
+
 
 async def call_approver(
    approver: Approver,
    message: str,
    call: ToolCall,
    view: ToolCallView,
-    state: TaskState | None = None,
+    history: list[ChatMessage],
 ) -> Approval:
-    # run approver
-    approval = await approver(message, call, view, state)
+    # run approver (if the approval is still using state then
+    # provide that but issue a warning)
+    signature = inspect.signature(approver)
+    if "state" in signature.parameters.keys():
+        from inspect_ai.solver._task_state import sample_state
+
+        warn_once(
+            logger, "Approver 'state' parameter is deprecated (use 'history' instead)"
+        )
+        approval = await approver(message, call, view, sample_state())  # type: ignore[arg-type]
+    else:
+        approval = await approver(message, call, view, history)
 
    # record
    record_approval(registry_log_name(approver), message, call, view, approval)
inspect_ai/approval/_human/approver.py CHANGED
@@ -1,4 +1,4 @@
-from inspect_ai.solver._task_state import TaskState
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 
 from .._approval import Approval, ApprovalDecision
@@ -25,11 +25,11 @@ def human_approver(
        message: str,
        call: ToolCall,
        view: ToolCallView,
-        state: TaskState | None = None,
+        history: list[ChatMessage],
    ) -> Approval:
        # try to use the panel approval (available in fullscreen display)
        try:
-            return await panel_approval(message, call, view, state, choices)
+            return await panel_approval(message, call, view, history, choices)
 
        # fallback to plain console approval (available in all displays)
        except NotImplementedError:
inspect_ai/approval/_human/manager.py CHANGED
@@ -3,7 +3,7 @@ from contextvars import ContextVar
 from typing import Callable, Literal, NamedTuple
 
 from inspect_ai._util.future import Future
-from inspect_ai.solver._task_state import TaskState
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 
 from .._approval import Approval, ApprovalDecision
@@ -13,7 +13,7 @@ class ApprovalRequest(NamedTuple):
    message: str
    call: ToolCall
    view: ToolCallView
-    state: TaskState | None
+    history: list[ChatMessage]
    choices: list[ApprovalDecision]
 
 
inspect_ai/approval/_human/panel.py CHANGED
@@ -10,7 +10,7 @@ from textual.widgets import Button, Static
 from typing_extensions import override
 
 from inspect_ai._util.registry import registry_unqualified_name
-from inspect_ai.solver._task_state import TaskState
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 from inspect_ai.util._panel import InputPanel, input_panel
 
@@ -29,7 +29,7 @@ async def panel_approval(
    message: str,
    call: ToolCall,
    view: ToolCallView,
-    state: TaskState | None,
+    history: list[ChatMessage],
    choices: list[ApprovalDecision],
 ) -> Approval:
    # ensure the approvals panel is shown
@@ -39,7 +39,7 @@
    approvals = human_approval_manager()
    id = approvals.request_approval(
        ApprovalRequest(
-            message=message, call=call, view=view, state=state, choices=choices
+            message=message, call=call, view=view, history=history, choices=choices
        )
    )
    try:
inspect_ai/approval/_policy.py CHANGED
@@ -9,7 +9,7 @@ from pydantic import BaseModel, Field, model_validator
 from inspect_ai._util.config import read_config_object
 from inspect_ai._util.format import format_function_call
 from inspect_ai._util.registry import registry_create, registry_lookup
-from inspect_ai.solver._task_state import TaskState
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 from inspect_ai.util._resource import resource
 
@@ -59,13 +59,13 @@ def policy_approver(policies: str | list[ApprovalPolicy]) -> Approver:
        message: str,
        call: ToolCall,
        view: ToolCallView,
-        state: TaskState | None = None,
+        history: list[ChatMessage],
    ) -> Approval:
        # process approvers for this tool call (continue loop on "escalate")
        has_approver = False
        for approver in tool_approvers(call):
            has_approver = True
-            approval = await call_approver(approver, message, call, view, state)
+            approval = await call_approver(approver, message, call, view, history)
            if approval.decision != "escalate":
                return approval
 
inspect_ai/log/__init__.py CHANGED
@@ -19,6 +19,7 @@ from ._log import (
    EvalDataset,
    EvalLog,
    EvalMetric,
+    EvalModelConfig,
    EvalPlan,
    EvalPlanStep,
    EvalResults,
@@ -60,6 +61,7 @@ __all__ = [
    "EvalDataset",
    "EvalLog",
    "EvalMetric",
+    "EvalModelConfig",
    "EvalPlan",
    "EvalPlanStep",
    "EvalResults",
inspect_ai/log/_log.py CHANGED
@@ -64,7 +64,9 @@ class EvalConfig(BaseModel):
    limit: int | tuple[int, int] | None = Field(default=None)
    """Sample limit (number of samples or range of samples)."""
 
-    sample_id: str | int | list[str | int] | None = Field(default=None)
+    sample_id: str | int | list[str] | list[int] | list[str | int] | None = Field(
+        default=None
+    )
    """Evaluate specific sample(s)."""
 
    epochs: int | None = Field(default=None)
@@ -507,7 +509,7 @@ class EvalDataset(BaseModel):
    samples: int | None = Field(default=None)
    """Number of samples in the dataset."""
 
-    sample_ids: list[int | str] | None = Field(default=None)
+    sample_ids: list[str] | list[int] | list[str | int] | None = Field(default=None)
    """IDs of samples in the dataset."""
 
    shuffled: bool | None = Field(default=None)
@@ -551,6 +553,22 @@ class EvalRevision(BaseModel):
    """Revision commit."""
 
 
+class EvalModelConfig(BaseModel):
+    """Model config."""
+
+    model: str
+    """Model name."""
+
+    config: GenerateConfig = Field(default_factory=GenerateConfig)
+    """Generate config"""
+
+    base_url: str | None = Field(default=None)
+    """Model base url."""
+
+    args: dict[str, Any] = Field(default_factory=dict)
+    """Model specific arguments."""
+
+
 class EvalSpec(BaseModel):
    """Eval target and configuration."""
 
@@ -608,6 +626,9 @@ class EvalSpec(BaseModel):
    model_args: dict[str, Any] = Field(default_factory=dict)
    """Model specific arguments."""
 
+    model_roles: dict[str, EvalModelConfig] | None = Field(default=None)
+    """Model roles."""
+
    config: EvalConfig
    """Configuration values for eval."""
 
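For illustration, here is how a per-role entry recorded in the new `EvalSpec.model_roles` field might look when constructed directly (the role name and field values are hypothetical):

from inspect_ai.log import EvalModelConfig
from inspect_ai.model import GenerateConfig

# hypothetical "grader" role entry for EvalSpec.model_roles
grader = EvalModelConfig(
    model="openai/gpt-4o-mini",
    config=GenerateConfig(temperature=0.0),
)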