inspect-ai 0.3.87__py3-none-any.whl → 0.3.89__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registries.
- inspect_ai/_cli/eval.py +16 -0
- inspect_ai/_cli/score.py +1 -12
- inspect_ai/_cli/util.py +4 -2
- inspect_ai/_display/core/footer.py +2 -2
- inspect_ai/_display/plain/display.py +2 -2
- inspect_ai/_eval/context.py +7 -1
- inspect_ai/_eval/eval.py +51 -27
- inspect_ai/_eval/evalset.py +27 -10
- inspect_ai/_eval/loader.py +7 -8
- inspect_ai/_eval/run.py +23 -31
- inspect_ai/_eval/score.py +18 -1
- inspect_ai/_eval/task/log.py +5 -13
- inspect_ai/_eval/task/resolved.py +1 -0
- inspect_ai/_eval/task/run.py +231 -244
- inspect_ai/_eval/task/task.py +25 -2
- inspect_ai/_eval/task/util.py +1 -8
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/json.py +8 -3
- inspect_ai/_util/registry.py +30 -13
- inspect_ai/_view/www/App.css +5 -0
- inspect_ai/_view/www/dist/assets/index.css +55 -18
- inspect_ai/_view/www/dist/assets/index.js +550 -458
- inspect_ai/_view/www/log-schema.json +84 -1
- inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
- inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
- inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
- inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +3 -3
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
- inspect_ai/_view/www/src/types/log.d.ts +150 -129
- inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
- inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
- inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
- inspect_ai/agent/_agent.py +12 -0
- inspect_ai/agent/_as_tool.py +1 -1
- inspect_ai/agent/_bridge/bridge.py +9 -2
- inspect_ai/agent/_react.py +142 -74
- inspect_ai/agent/_run.py +13 -2
- inspect_ai/agent/_types.py +6 -0
- inspect_ai/approval/_apply.py +6 -9
- inspect_ai/approval/_approver.py +3 -3
- inspect_ai/approval/_auto.py +2 -2
- inspect_ai/approval/_call.py +20 -4
- inspect_ai/approval/_human/approver.py +3 -3
- inspect_ai/approval/_human/manager.py +2 -2
- inspect_ai/approval/_human/panel.py +3 -3
- inspect_ai/approval/_policy.py +3 -3
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_log.py +23 -2
- inspect_ai/log/_model.py +58 -0
- inspect_ai/log/_recorders/file.py +14 -3
- inspect_ai/log/_transcript.py +3 -0
- inspect_ai/model/__init__.py +2 -0
- inspect_ai/model/_call_tools.py +15 -2
- inspect_ai/model/_model.py +49 -3
- inspect_ai/model/_openai.py +151 -21
- inspect_ai/model/_providers/anthropic.py +25 -14
- inspect_ai/model/_providers/bedrock.py +3 -3
- inspect_ai/model/_providers/cloudflare.py +29 -108
- inspect_ai/model/_providers/google.py +21 -10
- inspect_ai/model/_providers/grok.py +23 -17
- inspect_ai/model/_providers/groq.py +61 -37
- inspect_ai/model/_providers/llama_cpp_python.py +8 -9
- inspect_ai/model/_providers/mistral.py +8 -3
- inspect_ai/model/_providers/ollama.py +8 -9
- inspect_ai/model/_providers/openai.py +53 -157
- inspect_ai/model/_providers/openai_compatible.py +195 -0
- inspect_ai/model/_providers/openrouter.py +4 -15
- inspect_ai/model/_providers/providers.py +11 -0
- inspect_ai/model/_providers/together.py +25 -23
- inspect_ai/model/_trim.py +83 -0
- inspect_ai/solver/_plan.py +5 -3
- inspect_ai/tool/_tool_call.py +3 -0
- inspect_ai/tool/_tool_def.py +8 -2
- inspect_ai/util/__init__.py +3 -0
- inspect_ai/util/_concurrency.py +15 -2
- {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/RECORD +86 -81
- inspect_ai/_eval/task/rundir.py +0 -78
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/top_level.txt +0 -0
inspect_ai/agent/_bridge/bridge.py
CHANGED
```diff
@@ -6,6 +6,7 @@ from pydantic_core import to_json
 
 from inspect_ai._util._async import is_callable_coroutine
 from inspect_ai.agent._agent import Agent, AgentState, agent
+from inspect_ai.log._samples import sample_active
 from inspect_ai.model._model import get_model
 from inspect_ai.model._model_output import ModelOutput
 from inspect_ai.model._providers.providers import validate_openai_client
@@ -37,6 +38,10 @@ def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Agent:
 class BridgeInput(BaseModel):
     messages: list[ChatCompletionMessageParam]
 
+    # here for backward compatibilty w/ previous bridge
+    # (we may choose to add this to AgentState at some point)
+    metadata: dict[str, Any]
+
     # temporarily here for backward compatibility w/ previous bridge
     input: list[ChatCompletionMessageParam]
 
@@ -53,8 +58,10 @@ def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Agent:
 
     async def execute(state: AgentState) -> AgentState:
         # create input (use standard gpt-4 message encoding -- i.e. no 'developer' messages)
-
-
+        sample = sample_active()
+        metadata = (sample.sample.metadata if sample is not None else None) or {}
+        messages = await openai_chat_messages(state.messages)
+        input = BridgeInput(messages=messages, metadata=metadata, input=messages)
 
         # run target function
         async with openai_request_to_inspect_model():
```
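The net effect of these changes is that bridged agents now receive the active sample's metadata alongside the conversation. A minimal sketch of a bridged agent against this interface (the metadata key and result shape are illustrative assumptions, not taken from the diff):

```python
from typing import Any

from inspect_ai.agent import bridge


async def my_agent(input: dict[str, Any]) -> dict[str, Any]:
    # input["messages"] is the OpenAI-format conversation; as of this release
    # input["metadata"] also carries the active sample's metadata
    difficulty = input["metadata"].get("difficulty")  # hypothetical metadata key
    question = input["messages"][-1]["content"]
    return {"output": f"answering ({difficulty=}): {question}"}  # assumed result shape


inspect_agent = bridge(my_agent)
```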
inspect_ai/agent/_react.py
CHANGED
```diff
@@ -1,22 +1,27 @@
 from logging import getLogger
+from typing import Literal, cast
 
 from inspect_ai._util._async import is_callable_coroutine
 from inspect_ai.model._call_tools import execute_tools
 from inspect_ai.model._chat_message import (
     ChatMessage,
+    ChatMessageAssistant,
     ChatMessageSystem,
+    ChatMessageTool,
     ChatMessageUser,
 )
 from inspect_ai.model._model import Model, get_model
+from inspect_ai.model._trim import trim_messages
 from inspect_ai.scorer._score import score
 from inspect_ai.tool._tool import Tool, ToolResult, tool
-from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_info import parse_tool_info
 from inspect_ai.tool._tool_with import tool_with
 
 from ._agent import Agent, AgentState, agent, agent_with
+from ._filter import MessageFilter
 from ._handoff import has_handoff
 from ._types import (
+    DEFAULT_CONTINUE_PROMPT,
     AgentAttempts,
     AgentContinue,
     AgentPrompt,
@@ -37,6 +42,7 @@ def react(
     attempts: int | AgentAttempts = 1,
     submit: AgentSubmit = AgentSubmit(),
     on_continue: str | AgentContinue | None = None,
+    truncation: Literal["auto", "disabled"] | MessageFilter = "disabled",
 ) -> Agent:
     """Extensible ReAct agent based on the paper [ReAct: Synergizing Reasoning and Acting in Language Models](https://arxiv.org/abs/2210.03629).
 
@@ -68,9 +74,16 @@ def react(
         attempts: Configure agent to make multiple attempts.
         submit: Configure submit tool used by agent.
         on_continue: Message to play back to the model to urge it to continue.
-
-
-            to play back.
+            Use the placeholder {submit} to refer to the submit tool within the message.
+            Alternatively, an async function to call to determine whether the loop
+            should continue and what message to play back. Note that this function
+            is called on _every_ iteration of the loop so if you only want to send
+            a message back when the model fails to call tools you need to code
+            that behavior explicitly.
+        truncation: Truncate the conversation history in the event of a context
+            window overflow. Defaults to "disabled" which does no truncation. Pass
+            "auto" to use `trim_messages()` to reduce the context size. Pass a
+            `MessageFilter` function to do custom truncation.
 
     Returns:
         ReAct agent.
@@ -90,24 +103,6 @@ def react(
     else:
         system_message = None
 
-    # resolve on_continue
-    if on_continue is None:
-        on_continue = "If you believe you have completed the task, please call the `submit()` tool with your answer."
-    if isinstance(on_continue, str):
-        no_tools_continue_message = on_continue
-
-        async def no_tools_continue(state: AgentState) -> bool | str:
-            if state.output is None or not state.output.message.tool_calls:
-                return no_tools_continue_message
-            else:
-                return True
-
-        on_continue = no_tools_continue
-
-    # validate that on_continue is async
-    if not is_callable_coroutine(on_continue):
-        raise ValueError("The on_continue function must be async.")
-
     # resolve attempts
     attempts = AgentAttempts(attempts) if isinstance(attempts, int) else attempts
 
@@ -124,12 +119,17 @@ def react(
 
         return execute
 
-    # helper to
-    def
-
-
-
-
+    # helper to extract a submitted answer
+    def submission(tool_results: list[ChatMessage]) -> str | None:
+        return next(
+            (
+                result.text
+                for result in tool_results
+                if isinstance(result, ChatMessageTool)
+                and result.function == submit.name
+            ),
+            None,
+        )
 
     # resolve tools
     tools = tools or []
@@ -140,6 +140,14 @@ def react(
         if system_message:
             state.messages.insert(0, system_message)
 
+        # resolve overflow handling
+        if truncation == "auto":
+            overflow = cast(MessageFilter | None, trim_messages)
+        elif truncation == "disabled":
+            overflow = None
+        else:
+            overflow = truncation
+
         # track attempts
         attempt_count = 0
 
@@ -153,59 +161,95 @@ def react(
             if state.output.stop_reason == "model_length":
                 from inspect_ai.log._transcript import transcript
 
+                if overflow is not None:
+                    previous_messages = state.messages[:-1]
+                    state.messages = await overflow(previous_messages)
+                    if len(state.messages) < len(previous_messages):
+                        transcript().info(
+                            "Agent exceeded model context window, truncating messages and continuing."
+                        )
+                        continue
+
+                # no overflow policy or overflow didn't reduce conversation length
                 transcript().info("Agent terminated: model context window exceeded")
                 break
 
-            #
-
-
-
-            state.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            if
-
-
-
+            # resolve tool calls (if any)
+            if state.output.message.tool_calls:
+                # call tool functions
+                messages, output = await execute_tools(state.messages, tools)
+                state.messages.extend(messages)
+                if output:
+                    state.output = output
+
+                # check for a submission
+                answer = submission(messages)
+                if answer is not None:
+                    # set the output to the answer for scoring
+                    state.output.completion = (
+                        f"{state.output.completion}\n\n{answer}".strip()
+                    )
+
+                    # exit if we are at max_attempts
+                    attempt_count += 1
+                    if attempt_count >= attempts.attempts:
+                        break
+
+                    # exit if the submission is successful
+                    answer_scores = await score(state)
+                    if attempts.score_value(answer_scores[0].value) == 1.0:
+                        break
+
+                    # otherwise notify the model that it was incorrect and continue
+                    else:
+                        if callable(attempts.incorrect_message):
+                            if not is_callable_coroutine(attempts.incorrect_message):
+                                raise ValueError(
+                                    "The incorrect_message function must be async."
+                                )
+                            response_message: str = await attempts.incorrect_message(
+                                state, answer_scores
+                            )
+                        else:
+                            response_message = attempts.incorrect_message
+
+                        state.messages.append(ChatMessageUser(content=response_message))
+
+            # call the on_continue hook (if any)
+            if callable(on_continue):
+                if not is_callable_coroutine(on_continue):
+                    raise ValueError("The on_continue function must be async.")
+                do_continue = await cast(AgentContinue, on_continue)(state)
+                if do_continue is True:
+                    # if there were no tool calls we need to send back a user message
+                    if not state.output.message.tool_calls:
+                        state.messages.append(
+                            ChatMessageUser(
+                                content=DEFAULT_CONTINUE_PROMPT.format(
+                                    submit=submit.name
+                                )
                             )
-            response_message: str = await attempts.incorrect_message(
-                state, answer_scores
                         )
-
-
-
-
-
-            # no submitted answer, call tools and evaluate whether we should continue
-            else:
-                if state.output.message.tool_calls:
-                    # call tool functions
-                    messages, output = await execute_tools(state.messages, tools)
-                    state.messages.extend(messages)
-                    if output:
-                        state.output = output
-
-                # check if we should continue....
-                do_continue = await on_continue(state)
-                if isinstance(do_continue, str):
-                    state.messages.append(ChatMessageUser(content=do_continue))
-                elif do_continue is False:
+                elif isinstance(do_continue, str):
+                    state.messages.append(
+                        ChatMessageUser(content=do_continue.format(submit=submit.name))
+                    )
+                else:  # do_continue is False
                     break
 
+            # if there is no on_continue hook then add a user message if there were no tool calls
+            elif not state.output.message.tool_calls:
+                continue_msg = (
+                    DEFAULT_CONTINUE_PROMPT if on_continue is None else str(on_continue)
+                )
+                state.messages.append(
+                    ChatMessageUser(content=continue_msg.format(submit=submit.name))
+                )
+
+        # once we are complete, remove submit tool calls from the history
+        # (as they will potentially confuse parent agents who also have
+        # their own submit tools that they are 'watching' for)
+        state.messages = _remove_submit_tool(state.messages, submit.name)
         return state
 
     if name is not None or description is not None:
@@ -239,3 +283,27 @@ def _model_generate(model: str | Model | None) -> Agent:
         return state
 
     return generate
+
+
+def _remove_submit_tool(
+    messages: list[ChatMessage], submit_name: str
+) -> list[ChatMessage]:
+    filtered: list[ChatMessage] = []
+    for message in messages:
+        # skip submit tool messages
+        if isinstance(message, ChatMessageTool) and message.function == submit_name:
+            continue
+
+        # remove submit tool from assistant messages
+        if isinstance(message, ChatMessageAssistant) and message.tool_calls:
+            tools_calls = [
+                tool_call
+                for tool_call in message.tool_calls
+                if tool_call.function != submit_name
+            ]
+            message = message.model_copy(update=dict(tool_calls=tools_calls))
+
+        # always append message
+        filtered.append(message)
+
+    return filtered
```
inspect_ai/agent/_run.py
CHANGED
```diff
@@ -27,10 +27,21 @@ async def run(
 
     # resolve str
     if isinstance(input, str):
-
+        input_messages: list[ChatMessage] = [
+            ChatMessageUser(content=input, source="input")
+        ]
+    elif isinstance(input, list):
+        input_messages = [
+            message.model_copy(update=dict(source="input")) for message in input
+        ]
+    else:
+        input_messages = [
+            message.model_copy(update=dict(source="input"))
+            for message in input.messages
+        ]
 
     # create state
-    state = AgentState(messages=
+    state = AgentState(messages=input_messages)
 
     # run the agent
     return await agent(state, **agent_kwargs)
```
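With this change, `run()` normalizes all three input forms (string, message list, or `AgentState`), copying each message with `source="input"`. A sketch, to be executed inside an async context (`shell_agent` reuses the react example above):

```python
from inspect_ai.agent import run
from inspect_ai.model import ChatMessageSystem, ChatMessageUser

# a string becomes a single user message tagged source="input"
state = await run(shell_agent, "List the files that changed in the last commit.")

# message lists (and AgentState inputs) are copied with source="input" as well
state = await run(
    shell_agent,
    [
        ChatMessageSystem(content="Be terse."),
        ChatMessageUser(content="List the files that changed in the last commit."),
    ],
)
```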
inspect_ai/agent/_types.py
CHANGED
```diff
@@ -40,6 +40,12 @@ class AgentPrompt(NamedTuple):
     """Prompt for assistant (covers tool use, submit tool, CoT, etc.)."""
 
 
+DEFAULT_CONTINUE_PROMPT = """
+Please proceed to the next step using your best judgement. If you believe you
+have completed the task, please call the `{submit}()` tool.
+"""
+
+
 AgentContinue: TypeAlias = Callable[[AgentState], Awaitable[bool | str]]
 """Function called to determine whether the agent should continue.
 
```
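Because string `on_continue` messages (and `DEFAULT_CONTINUE_PROMPT`) are now formatted with the submit tool's name, custom continue messages can use the `{submit}` placeholder. A sketch:

```python
from inspect_ai.agent import react

agent = react(
    prompt="Solve the task.",
    on_continue=(
        "Please keep working. Call the `{submit}()` tool once you have "
        "your final answer."
    ),
)
```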
inspect_ai/approval/_apply.py
CHANGED
```diff
@@ -2,6 +2,7 @@ from contextvars import ContextVar
 
 from inspect_ai._util.format import format_function_call
 from inspect_ai.approval._approval import Approval
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import (
     ToolCall,
     ToolCallContent,
@@ -14,10 +15,11 @@ from ._policy import ApprovalPolicy, policy_approver
 
 
 async def apply_tool_approval(
-    message: str,
+    message: str,
+    call: ToolCall,
+    viewer: ToolCallViewer | None,
+    history: list[ChatMessage],
 ) -> tuple[bool, Approval | None]:
-    from inspect_ai.solver._task_state import sample_state
-
     approver = _tool_approver.get(None)
     if approver:
         # resolve view
@@ -28,15 +30,12 @@ async def apply_tool_approval(
         else:
             view = default_tool_call_viewer(call)
 
-        # current sample state
-        state = sample_state()
-
         # call approver
         approval = await approver(
             message=message,
             call=call,
             view=view,
-
+            history=history,
         )
 
         # process decision
@@ -46,8 +45,6 @@
             case "reject":
                 return False, approval
             case "terminate":
-                if state:
-                    state.completed = True
                 return False, approval
             case "escalate":
                 raise RuntimeError("Unexpected 'escalate' from policy approver.")
```
inspect_ai/approval/_approver.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 from typing import Protocol
 
-from inspect_ai.
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 
 from ._approval import Approval
@@ -14,7 +14,7 @@ class Approver(Protocol):
         message: str,
         call: ToolCall,
         view: ToolCallView,
-
+        history: list[ChatMessage],
     ) -> Approval:
         """
         Approve or reject a tool call.
@@ -23,7 +23,7 @@ class Approver(Protocol):
             message: Message genreated by the model along with the tool call.
             call: The tool call to be approved.
             view: Custom rendering of tool context and call.
-
+            history: The current conversation history.
 
         Returns:
             Approval: An Approval object containing the decision and explanation.
```
inspect_ai/approval/_auto.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-from inspect_ai.
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 
 from ._approval import Approval, ApprovalDecision
@@ -21,7 +21,7 @@ def auto_approver(decision: ApprovalDecision = "approve") -> Approver:
         message: str,
         call: ToolCall,
         view: ToolCallView,
-
+        history: list[ChatMessage],
     ) -> Approval:
         return Approval(decision=decision, explanation="Automatic decision.")
 
```
inspect_ai/approval/_call.py
CHANGED
```diff
@@ -1,20 +1,36 @@
+import inspect
+from logging import getLogger
+
+from inspect_ai._util.logger import warn_once
 from inspect_ai._util.registry import registry_log_name
-from inspect_ai.
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 
 from ._approval import Approval
 from ._approver import Approver
 
+logger = getLogger(__name__)
+
 
 async def call_approver(
     approver: Approver,
     message: str,
     call: ToolCall,
     view: ToolCallView,
-
+    history: list[ChatMessage],
 ) -> Approval:
-    # run approver
-
+    # run approver (if the approval is still using state then
+    # provide that but issue a warning)
+    signature = inspect.signature(approver)
+    if "state" in signature.parameters.keys():
+        from inspect_ai.solver._task_state import sample_state
+
+        warn_once(
+            logger, "Approver 'state' parameter is deprecated (use 'history' instead)"
+        )
+        approval = await approver(message, call, view, sample_state())  # type: ignore[arg-type]
+    else:
+        approval = await approver(message, call, view, history)
 
     # record
     record_approval(registry_log_name(approver), message, call, view, approval)
```
inspect_ai/approval/_human/approver.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-from inspect_ai.
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 
 from .._approval import Approval, ApprovalDecision
@@ -25,11 +25,11 @@ def human_approver(
         message: str,
         call: ToolCall,
         view: ToolCallView,
-
+        history: list[ChatMessage],
     ) -> Approval:
         # try to use the panel approval (available in fullscreen display)
         try:
-            return await panel_approval(message, call, view,
+            return await panel_approval(message, call, view, history, choices)
 
         # fallback to plain console approval (available in all displays)
         except NotImplementedError:
```
inspect_ai/approval/_human/manager.py
CHANGED
```diff
@@ -3,7 +3,7 @@ from contextvars import ContextVar
 from typing import Callable, Literal, NamedTuple
 
 from inspect_ai._util.future import Future
-from inspect_ai.
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 
 from .._approval import Approval, ApprovalDecision
@@ -13,7 +13,7 @@ class ApprovalRequest(NamedTuple):
     message: str
     call: ToolCall
     view: ToolCallView
-
+    history: list[ChatMessage]
     choices: list[ApprovalDecision]
 
 
```
inspect_ai/approval/_human/panel.py
CHANGED
```diff
@@ -10,7 +10,7 @@ from textual.widgets import Button, Static
 from typing_extensions import override
 
 from inspect_ai._util.registry import registry_unqualified_name
-from inspect_ai.
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 from inspect_ai.util._panel import InputPanel, input_panel
 
@@ -29,7 +29,7 @@ async def panel_approval(
     message: str,
     call: ToolCall,
     view: ToolCallView,
-
+    history: list[ChatMessage],
     choices: list[ApprovalDecision],
 ) -> Approval:
     # ensure the approvals panel is shown
@@ -39,7 +39,7 @@ async def panel_approval(
     approvals = human_approval_manager()
     id = approvals.request_approval(
         ApprovalRequest(
-            message=message, call=call, view=view,
+            message=message, call=call, view=view, history=history, choices=choices
         )
     )
     try:
```
inspect_ai/approval/_policy.py
CHANGED
```diff
@@ -9,7 +9,7 @@ from pydantic import BaseModel, Field, model_validator
 from inspect_ai._util.config import read_config_object
 from inspect_ai._util.format import format_function_call
 from inspect_ai._util.registry import registry_create, registry_lookup
-from inspect_ai.
+from inspect_ai.model._chat_message import ChatMessage
 from inspect_ai.tool._tool_call import ToolCall, ToolCallView
 from inspect_ai.util._resource import resource
 
@@ -59,13 +59,13 @@ def policy_approver(policies: str | list[ApprovalPolicy]) -> Approver:
         message: str,
         call: ToolCall,
         view: ToolCallView,
-
+        history: list[ChatMessage],
     ) -> Approval:
         # process approvers for this tool call (continue loop on "escalate")
         has_approver = False
         for approver in tool_approvers(call):
             has_approver = True
-            approval = await call_approver(approver, message, call, view,
+            approval = await call_approver(approver, message, call, view, history)
             if approval.decision != "escalate":
                 return approval
 
```
inspect_ai/log/__init__.py
CHANGED
```diff
@@ -19,6 +19,7 @@ from ._log import (
     EvalDataset,
     EvalLog,
     EvalMetric,
+    EvalModelConfig,
     EvalPlan,
     EvalPlanStep,
     EvalResults,
@@ -60,6 +61,7 @@ __all__ = [
     "EvalDataset",
     "EvalLog",
     "EvalMetric",
+    "EvalModelConfig",
     "EvalPlan",
     "EvalPlanStep",
     "EvalResults",
```
inspect_ai/log/_log.py
CHANGED
```diff
@@ -64,7 +64,9 @@ class EvalConfig(BaseModel):
     limit: int | tuple[int, int] | None = Field(default=None)
     """Sample limit (number of samples or range of samples)."""
 
-    sample_id: str | int | list[str | int] | None = Field(default=None)
+    sample_id: str | int | list[str] | list[int] | list[str | int] | None = Field(
+        default=None
+    )
     """Evaluate specific sample(s)."""
 
     epochs: int | None = Field(default=None)
@@ -507,7 +509,7 @@ class EvalDataset(BaseModel):
     samples: int | None = Field(default=None)
     """Number of samples in the dataset."""
 
-    sample_ids: list[int | str] | None = Field(default=None)
+    sample_ids: list[str] | list[int] | list[str | int] | None = Field(default=None)
     """IDs of samples in the dataset."""
 
     shuffled: bool | None = Field(default=None)
@@ -551,6 +553,22 @@ class EvalRevision(BaseModel):
     """Revision commit."""
 
 
+class EvalModelConfig(BaseModel):
+    """Model config."""
+
+    model: str
+    """Model name."""
+
+    config: GenerateConfig = Field(default_factory=GenerateConfig)
+    """Generate config"""
+
+    base_url: str | None = Field(default=None)
+    """Model base url."""
+
+    args: dict[str, Any] = Field(default_factory=dict)
+    """Model specific arguments."""
+
+
 class EvalSpec(BaseModel):
     """Eval target and configuration."""
 
@@ -608,6 +626,9 @@ class EvalSpec(BaseModel):
     model_args: dict[str, Any] = Field(default_factory=dict)
     """Model specific arguments."""
 
+    model_roles: dict[str, EvalModelConfig] | None = Field(default=None)
+    """Model roles."""
+
     config: EvalConfig
     """Configuration values for eval."""
 
```
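The new `EvalModelConfig` class and `EvalSpec.model_roles` field record per-role model configuration in the log header. A sketch of reading them back (the log path is hypothetical):

```python
from inspect_ai.log import read_eval_log

log = read_eval_log("./logs/example.eval")  # hypothetical log file
if log.eval.model_roles:
    for role, model_config in log.eval.model_roles.items():
        print(f"{role}: {model_config.model} (base_url={model_config.base_url})")
```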