inspect-ai 0.3.81__py3-none-any.whl → 0.3.83__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_cli/eval.py +35 -2
- inspect_ai/_cli/util.py +44 -1
- inspect_ai/_display/core/config.py +1 -1
- inspect_ai/_display/core/display.py +13 -4
- inspect_ai/_display/core/results.py +1 -1
- inspect_ai/_display/textual/app.py +14 -3
- inspect_ai/_display/textual/display.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +9 -3
- inspect_ai/_display/textual/widgets/task_detail.py +8 -8
- inspect_ai/_display/textual/widgets/tasks.py +17 -1
- inspect_ai/_display/textual/widgets/vscode.py +44 -0
- inspect_ai/_eval/eval.py +74 -25
- inspect_ai/_eval/evalset.py +22 -18
- inspect_ai/_eval/loader.py +34 -11
- inspect_ai/_eval/run.py +13 -15
- inspect_ai/_eval/score.py +13 -3
- inspect_ai/_eval/task/generate.py +8 -9
- inspect_ai/_eval/task/log.py +55 -6
- inspect_ai/_eval/task/run.py +51 -10
- inspect_ai/_eval/task/task.py +23 -9
- inspect_ai/_util/constants.py +2 -0
- inspect_ai/_util/file.py +30 -1
- inspect_ai/_util/json.py +37 -1
- inspect_ai/_util/registry.py +1 -0
- inspect_ai/_util/vscode.py +37 -0
- inspect_ai/_view/server.py +113 -1
- inspect_ai/_view/www/App.css +7 -1
- inspect_ai/_view/www/dist/assets/index.css +813 -415
- inspect_ai/_view/www/dist/assets/index.js +54475 -32003
- inspect_ai/_view/www/eslint.config.mjs +1 -1
- inspect_ai/_view/www/log-schema.json +137 -31
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +149 -0
- inspect_ai/_view/www/package.json +11 -2
- inspect_ai/_view/www/src/App.tsx +161 -853
- inspect_ai/_view/www/src/api/api-browser.ts +176 -5
- inspect_ai/_view/www/src/api/api-vscode.ts +75 -1
- inspect_ai/_view/www/src/api/client-api.ts +66 -10
- inspect_ai/_view/www/src/api/jsonrpc.ts +2 -0
- inspect_ai/_view/www/src/api/types.ts +107 -2
- inspect_ai/_view/www/src/appearance/icons.ts +2 -0
- inspect_ai/_view/www/src/components/AsciinemaPlayer.tsx +3 -3
- inspect_ai/_view/www/src/components/Card.tsx +6 -4
- inspect_ai/_view/www/src/components/DownloadPanel.tsx +2 -2
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +56 -61
- inspect_ai/_view/www/src/components/FindBand.tsx +17 -9
- inspect_ai/_view/www/src/components/HumanBaselineView.tsx +1 -1
- inspect_ai/_view/www/src/components/JsonPanel.tsx +14 -24
- inspect_ai/_view/www/src/components/LargeModal.tsx +2 -35
- inspect_ai/_view/www/src/components/LightboxCarousel.tsx +27 -11
- inspect_ai/_view/www/src/components/LinkButton.module.css +16 -0
- inspect_ai/_view/www/src/components/LinkButton.tsx +33 -0
- inspect_ai/_view/www/src/components/LiveVirtualList.module.css +11 -0
- inspect_ai/_view/www/src/components/LiveVirtualList.tsx +177 -0
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +116 -26
- inspect_ai/_view/www/src/components/MessageBand.tsx +14 -9
- inspect_ai/_view/www/src/components/Modal.module.css +38 -0
- inspect_ai/_view/www/src/components/Modal.tsx +77 -0
- inspect_ai/_view/www/src/components/MorePopOver.tsx +3 -3
- inspect_ai/_view/www/src/components/NavPills.tsx +20 -8
- inspect_ai/_view/www/src/components/NoContentsPanel.module.css +12 -0
- inspect_ai/_view/www/src/components/NoContentsPanel.tsx +20 -0
- inspect_ai/_view/www/src/components/ProgressBar.module.css +5 -4
- inspect_ai/_view/www/src/components/ProgressBar.tsx +3 -2
- inspect_ai/_view/www/src/components/PulsingDots.module.css +81 -0
- inspect_ai/_view/www/src/components/PulsingDots.tsx +45 -0
- inspect_ai/_view/www/src/components/TabSet.tsx +4 -37
- inspect_ai/_view/www/src/components/ToolButton.tsx +3 -4
- inspect_ai/_view/www/src/index.tsx +26 -94
- inspect_ai/_view/www/src/logfile/remoteLogFile.ts +9 -1
- inspect_ai/_view/www/src/logfile/remoteZipFile.ts +30 -4
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +4 -6
- inspect_ai/_view/www/src/plan/DetailStep.module.css +4 -0
- inspect_ai/_view/www/src/plan/DetailStep.tsx +6 -3
- inspect_ai/_view/www/src/plan/ScorerDetailView.tsx +1 -1
- inspect_ai/_view/www/src/plan/SolverDetailView.module.css +2 -1
- inspect_ai/_view/www/src/samples/InlineSampleDisplay.module.css +9 -1
- inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +74 -28
- inspect_ai/_view/www/src/samples/SampleDialog.tsx +58 -22
- inspect_ai/_view/www/src/samples/SampleDisplay.module.css +4 -0
- inspect_ai/_view/www/src/samples/SampleDisplay.tsx +135 -104
- inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +10 -0
- inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +83 -36
- inspect_ai/_view/www/src/samples/SamplesTools.tsx +35 -30
- inspect_ai/_view/www/src/samples/chat/ChatMessage.tsx +2 -1
- inspect_ai/_view/www/src/samples/chat/ChatMessageRenderer.tsx +1 -1
- inspect_ai/_view/www/src/samples/chat/ChatViewVirtualList.tsx +45 -53
- inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +6 -1
- inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +5 -0
- inspect_ai/_view/www/src/samples/chat/messages.ts +36 -0
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.module.css +3 -0
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +11 -1
- inspect_ai/_view/www/src/samples/chat/tools/ToolInput.tsx +22 -46
- inspect_ai/_view/www/src/samples/descriptor/samplesDescriptor.tsx +34 -20
- inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.module.css +3 -3
- inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.tsx +1 -1
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.module.css +4 -4
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +10 -10
- inspect_ai/_view/www/src/samples/descriptor/types.ts +6 -5
- inspect_ai/_view/www/src/samples/list/SampleFooter.module.css +22 -3
- inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +27 -2
- inspect_ai/_view/www/src/samples/list/SampleList.tsx +122 -85
- inspect_ai/_view/www/src/samples/list/SampleRow.module.css +6 -0
- inspect_ai/_view/www/src/samples/list/SampleRow.tsx +28 -15
- inspect_ai/_view/www/src/samples/sample-tools/SelectScorer.tsx +29 -18
- inspect_ai/_view/www/src/samples/sample-tools/SortFilter.tsx +28 -28
- inspect_ai/_view/www/src/samples/sample-tools/sample-filter/SampleFilter.tsx +19 -9
- inspect_ai/_view/www/src/samples/sampleDataAdapter.ts +33 -0
- inspect_ai/_view/www/src/samples/sampleLimit.ts +2 -2
- inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +12 -27
- inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.module.css +38 -0
- inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.tsx +118 -0
- inspect_ai/_view/www/src/samples/scores/{SampleScoreView.module.css → SampleScoresView.module.css} +10 -1
- inspect_ai/_view/www/src/samples/scores/SampleScoresView.tsx +78 -0
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/InputEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.module.css +4 -0
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +10 -24
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +4 -22
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.tsx +15 -24
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +6 -28
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.tsx +24 -34
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.module.css +4 -0
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +33 -17
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +197 -338
- inspect_ai/_view/www/src/samples/transcript/TranscriptVirtualListComponent.module.css +16 -0
- inspect_ai/_view/www/src/samples/transcript/TranscriptVirtualListComponent.tsx +44 -0
- inspect_ai/_view/www/src/samples/transcript/event/EventNav.tsx +7 -4
- inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +81 -60
- inspect_ai/_view/www/src/samples/transcript/event/EventProgressPanel.module.css +23 -0
- inspect_ai/_view/www/src/samples/transcript/event/EventProgressPanel.tsx +27 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +29 -1
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +102 -72
- inspect_ai/_view/www/src/scoring/utils.ts +87 -0
- inspect_ai/_view/www/src/state/appSlice.ts +244 -0
- inspect_ai/_view/www/src/state/hooks.ts +399 -0
- inspect_ai/_view/www/src/state/logPolling.ts +200 -0
- inspect_ai/_view/www/src/state/logSlice.ts +224 -0
- inspect_ai/_view/www/src/state/logsPolling.ts +118 -0
- inspect_ai/_view/www/src/state/logsSlice.ts +181 -0
- inspect_ai/_view/www/src/state/samplePolling.ts +314 -0
- inspect_ai/_view/www/src/state/sampleSlice.ts +140 -0
- inspect_ai/_view/www/src/state/sampleUtils.ts +21 -0
- inspect_ai/_view/www/src/state/scrolling.ts +206 -0
- inspect_ai/_view/www/src/state/store.ts +168 -0
- inspect_ai/_view/www/src/state/store_filter.ts +84 -0
- inspect_ai/_view/www/src/state/utils.ts +23 -0
- inspect_ai/_view/www/src/storage/index.ts +26 -0
- inspect_ai/_view/www/src/types/log.d.ts +36 -26
- inspect_ai/_view/www/src/types/markdown-it-katex.d.ts +21 -0
- inspect_ai/_view/www/src/types.ts +94 -32
- inspect_ai/_view/www/src/utils/attachments.ts +58 -23
- inspect_ai/_view/www/src/utils/json-worker.ts +79 -12
- inspect_ai/_view/www/src/utils/logger.ts +52 -0
- inspect_ai/_view/www/src/utils/polling.ts +100 -0
- inspect_ai/_view/www/src/utils/react.ts +30 -0
- inspect_ai/_view/www/src/utils/vscode.ts +1 -1
- inspect_ai/_view/www/src/workspace/WorkSpace.tsx +184 -217
- inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +11 -53
- inspect_ai/_view/www/src/workspace/navbar/Navbar.tsx +8 -18
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -0
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +40 -22
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.module.css +16 -1
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +159 -103
- inspect_ai/_view/www/src/workspace/navbar/RunningStatusPanel.module.css +32 -0
- inspect_ai/_view/www/src/workspace/navbar/RunningStatusPanel.tsx +32 -0
- inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.module.css +35 -0
- inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.tsx +117 -0
- inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +12 -14
- inspect_ai/_view/www/src/workspace/navbar/StatusPanel.tsx +6 -2
- inspect_ai/_view/www/src/workspace/sidebar/LogDirectoryTitleView.tsx +4 -4
- inspect_ai/_view/www/src/workspace/sidebar/Sidebar.module.css +3 -2
- inspect_ai/_view/www/src/workspace/sidebar/Sidebar.tsx +28 -13
- inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +5 -10
- inspect_ai/_view/www/src/workspace/tabs/JsonTab.tsx +4 -4
- inspect_ai/_view/www/src/workspace/tabs/RunningNoSamples.module.css +22 -0
- inspect_ai/_view/www/src/workspace/tabs/RunningNoSamples.tsx +19 -0
- inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +128 -115
- inspect_ai/_view/www/src/workspace/tabs/grouping.ts +37 -5
- inspect_ai/_view/www/src/workspace/tabs/types.ts +4 -0
- inspect_ai/_view/www/src/workspace/types.ts +4 -3
- inspect_ai/_view/www/src/workspace/utils.ts +4 -4
- inspect_ai/_view/www/vite.config.js +6 -0
- inspect_ai/_view/www/yarn.lock +464 -355
- inspect_ai/agent/__init__.py +36 -0
- inspect_ai/agent/_agent.py +268 -0
- inspect_ai/agent/_as_solver.py +72 -0
- inspect_ai/agent/_as_tool.py +122 -0
- inspect_ai/{solver → agent}/_bridge/bridge.py +23 -37
- inspect_ai/{solver → agent}/_bridge/patch.py +9 -8
- inspect_ai/agent/_filter.py +46 -0
- inspect_ai/agent/_handoff.py +93 -0
- inspect_ai/{solver/_human_agent → agent/_human}/agent.py +11 -12
- inspect_ai/{solver/_human_agent → agent/_human}/commands/__init__.py +2 -3
- inspect_ai/{solver/_human_agent → agent/_human}/commands/clock.py +3 -1
- inspect_ai/{solver/_human_agent → agent/_human}/commands/score.py +5 -5
- inspect_ai/{solver/_human_agent → agent/_human}/install.py +6 -3
- inspect_ai/{solver/_human_agent → agent/_human}/service.py +7 -3
- inspect_ai/{solver/_human_agent → agent/_human}/state.py +5 -5
- inspect_ai/agent/_react.py +241 -0
- inspect_ai/agent/_run.py +36 -0
- inspect_ai/agent/_types.py +81 -0
- inspect_ai/log/_condense.py +26 -0
- inspect_ai/log/_log.py +17 -5
- inspect_ai/log/_recorders/buffer/__init__.py +14 -0
- inspect_ai/log/_recorders/buffer/buffer.py +30 -0
- inspect_ai/log/_recorders/buffer/database.py +685 -0
- inspect_ai/log/_recorders/buffer/filestore.py +259 -0
- inspect_ai/log/_recorders/buffer/types.py +84 -0
- inspect_ai/log/_recorders/eval.py +2 -11
- inspect_ai/log/_recorders/types.py +30 -0
- inspect_ai/log/_transcript.py +32 -2
- inspect_ai/model/__init__.py +7 -1
- inspect_ai/model/_call_tools.py +257 -52
- inspect_ai/model/_chat_message.py +7 -4
- inspect_ai/model/_conversation.py +13 -62
- inspect_ai/model/_display.py +85 -0
- inspect_ai/model/_generate_config.py +2 -2
- inspect_ai/model/_model.py +114 -14
- inspect_ai/model/_model_output.py +14 -9
- inspect_ai/model/_openai.py +16 -4
- inspect_ai/model/_openai_computer_use.py +162 -0
- inspect_ai/model/_openai_responses.py +319 -165
- inspect_ai/model/_providers/anthropic.py +20 -21
- inspect_ai/model/_providers/azureai.py +24 -13
- inspect_ai/model/_providers/bedrock.py +1 -7
- inspect_ai/model/_providers/cloudflare.py +3 -3
- inspect_ai/model/_providers/goodfire.py +2 -6
- inspect_ai/model/_providers/google.py +11 -10
- inspect_ai/model/_providers/groq.py +6 -3
- inspect_ai/model/_providers/hf.py +7 -3
- inspect_ai/model/_providers/mistral.py +7 -10
- inspect_ai/model/_providers/openai.py +47 -17
- inspect_ai/model/_providers/openai_o1.py +11 -4
- inspect_ai/model/_providers/openai_responses.py +12 -14
- inspect_ai/model/_providers/providers.py +2 -2
- inspect_ai/model/_providers/together.py +12 -2
- inspect_ai/model/_providers/util/chatapi.py +7 -2
- inspect_ai/model/_providers/util/hf_handler.py +4 -2
- inspect_ai/model/_providers/util/llama31.py +4 -2
- inspect_ai/model/_providers/vertex.py +11 -9
- inspect_ai/model/_providers/vllm.py +4 -4
- inspect_ai/scorer/__init__.py +2 -0
- inspect_ai/scorer/_metrics/__init__.py +2 -0
- inspect_ai/scorer/_metrics/grouped.py +84 -0
- inspect_ai/scorer/_score.py +26 -6
- inspect_ai/solver/__init__.py +2 -2
- inspect_ai/solver/_basic_agent.py +22 -9
- inspect_ai/solver/_bridge.py +31 -0
- inspect_ai/solver/_chain.py +20 -12
- inspect_ai/solver/_fork.py +5 -1
- inspect_ai/solver/_human_agent.py +52 -0
- inspect_ai/solver/_prompt.py +3 -1
- inspect_ai/solver/_run.py +59 -0
- inspect_ai/solver/_solver.py +14 -4
- inspect_ai/solver/_task_state.py +5 -3
- inspect_ai/tool/_tool_call.py +15 -8
- inspect_ai/tool/_tool_def.py +17 -12
- inspect_ai/tool/_tool_support_helpers.py +4 -4
- inspect_ai/tool/_tool_with.py +14 -11
- inspect_ai/tool/_tools/_bash_session.py +11 -2
- inspect_ai/tool/_tools/_computer/_common.py +18 -2
- inspect_ai/tool/_tools/_computer/_computer.py +18 -2
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +2 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +17 -0
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +103 -62
- inspect_ai/util/__init__.py +2 -0
- inspect_ai/util/_anyio.py +27 -0
- inspect_ai/util/_sandbox/__init__.py +2 -1
- inspect_ai/util/_sandbox/context.py +32 -7
- inspect_ai/util/_sandbox/docker/cleanup.py +4 -0
- inspect_ai/util/_sandbox/docker/compose.py +2 -2
- inspect_ai/util/_sandbox/docker/docker.py +12 -1
- inspect_ai/util/_store_model.py +30 -7
- inspect_ai/util/_subprocess.py +13 -3
- inspect_ai/util/_subtask.py +1 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/RECORD +295 -229
- inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +0 -169
- inspect_ai/_view/www/src/samples/transcript/SampleTranscript.tsx +0 -22
- /inspect_ai/{solver → agent}/_bridge/__init__.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/__init__.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/command.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/instructions.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/note.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/status.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/submit.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/panel.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/view.py +0 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,36 @@
|
|
1
|
+
from ._agent import Agent, AgentState, agent, agent_with
|
2
|
+
from ._as_solver import as_solver
|
3
|
+
from ._as_tool import as_tool
|
4
|
+
from ._bridge.bridge import bridge
|
5
|
+
from ._filter import MessageFilter, last_message, remove_tools
|
6
|
+
from ._handoff import handoff
|
7
|
+
from ._human.agent import human_cli
|
8
|
+
from ._react import react
|
9
|
+
from ._run import run
|
10
|
+
from ._types import (
|
11
|
+
AgentAttempts,
|
12
|
+
AgentContinue,
|
13
|
+
AgentPrompt,
|
14
|
+
AgentSubmit,
|
15
|
+
)
|
16
|
+
|
17
|
+
__all__ = [
|
18
|
+
"react",
|
19
|
+
"bridge",
|
20
|
+
"human_cli",
|
21
|
+
"run",
|
22
|
+
"handoff",
|
23
|
+
"as_tool",
|
24
|
+
"as_solver",
|
25
|
+
"last_message",
|
26
|
+
"remove_tools",
|
27
|
+
"MessageFilter",
|
28
|
+
"Agent",
|
29
|
+
"AgentState",
|
30
|
+
"agent",
|
31
|
+
"agent_with",
|
32
|
+
"AgentPrompt",
|
33
|
+
"AgentAttempts",
|
34
|
+
"AgentContinue",
|
35
|
+
"AgentSubmit",
|
36
|
+
]
|
@@ -0,0 +1,268 @@
|
|
1
|
+
from copy import copy, deepcopy
|
2
|
+
from functools import wraps
|
3
|
+
from typing import (
|
4
|
+
Any,
|
5
|
+
Callable,
|
6
|
+
ParamSpec,
|
7
|
+
Protocol,
|
8
|
+
TypeGuard,
|
9
|
+
cast,
|
10
|
+
overload,
|
11
|
+
runtime_checkable,
|
12
|
+
)
|
13
|
+
|
14
|
+
from inspect_ai._util.registry import (
|
15
|
+
RegistryInfo,
|
16
|
+
is_registry_object,
|
17
|
+
registry_add,
|
18
|
+
registry_info,
|
19
|
+
registry_name,
|
20
|
+
registry_tag,
|
21
|
+
set_registry_info,
|
22
|
+
)
|
23
|
+
from inspect_ai.model._chat_message import (
|
24
|
+
ChatMessage,
|
25
|
+
ChatMessageAssistant,
|
26
|
+
)
|
27
|
+
from inspect_ai.model._model_output import ChatCompletionChoice, ModelOutput
|
28
|
+
|
29
|
+
|
30
|
+
class AgentState:
    """Agent state.

    Holds the conversation history for an agent together with the most
    recent model output. When no output has been assigned explicitly, the
    `output` property synthesizes one on first access from the most recent
    assistant message in `messages` (or an empty `ModelOutput()` when there
    is no assistant message) and caches it.
    """

    def __init__(self, *, messages: list[ChatMessage]) -> None:
        self._messages = messages
        self._output: ModelOutput | None = None

    @property
    def messages(self) -> list[ChatMessage]:
        """Conversation history."""
        return self._messages

    @messages.setter
    def messages(self, messages: list[ChatMessage]) -> None:
        """Set the conversation history."""
        self._messages = messages

    @property
    def output(self) -> ModelOutput:
        """Model output."""
        # if there is no output yet then synthesize it from the last assistant message
        if self._output is None:
            # walk the history newest-first and stop at the first assistant
            # message found; without the `break` the loop keeps overwriting
            # the result and ends up using the *oldest* assistant message
            for message in reversed(self.messages):
                if isinstance(message, ChatMessageAssistant):
                    self._output = ModelOutput(
                        model=message.model or "",
                        choices=[
                            ChatCompletionChoice(
                                message=message.model_copy(),
                                stop_reason="stop",
                            )
                        ],
                    )
                    break
            else:
                # no assistant message, so generate an empty model output
                self._output = ModelOutput()

        return self._output

    @output.setter
    def output(self, output: ModelOutput) -> None:
        """Set the model output."""
        self._output = output

    def __copy__(self) -> "AgentState":
        # shallow copy: the list is new but the message objects are shared.
        # Reading `self.output` may synthesize (and cache) the output here.
        state = AgentState(messages=copy(self.messages))
        state.output = self.output.model_copy()
        return state

    def __deepcopy__(self, memo: dict[int, Any]) -> "AgentState":
        # deep copy of both the messages and the (possibly synthesized) output
        state = AgentState(messages=deepcopy(self.messages, memo))
        state.output = self.output.model_copy(deep=True)
        return state
|
85
|
+
|
86
|
+
|
87
|
+
@runtime_checkable
class Agent(Protocol):
    async def __call__(
        self,
        state: AgentState,
        *args: Any,
        **kwargs: Any,
    ) -> AgentState:
        """Perform a task as a participant in a conversation.

        An agent resembles a tool, but unlike a tool it takes part in
        the conversation history: it may append messages and model
        output to the conversation state it is given.

        Use `handoff()` to give a model a tool that hands the
        conversation off to an agent.

        Use `as_tool()` to turn an agent into a simple tool that
        receives a string as its input.

        Args:
            state: Agent state (conversation history and last model output)
            *args: Arguments for the agent.
            **kwargs: Keyword arguments for the agent.

        Returns:
            AgentState: Updated agent state.
        """
        ...
|
116
|
+
|
117
|
+
|
118
|
+
P = ParamSpec("P")
|
119
|
+
|
120
|
+
|
121
|
+
@overload
def agent(func: Callable[P, Agent]) -> Callable[P, Agent]: ...


@overload
def agent() -> Callable[[Callable[P, Agent]], Callable[P, Agent]]: ...


@overload
def agent(
    *,
    name: str | None = None,
    description: str | None = None,
) -> Callable[[Callable[P, Agent]], Callable[P, Agent]]: ...


def agent(
    func: Callable[P, Agent] | None = None,
    *,
    name: str | None = None,
    description: str | None = None,
) -> Callable[P, Agent] | Callable[[Callable[P, Agent]], Callable[P, Agent]]:
    r"""Decorator for registering agents.

    Usable bare (`@agent`), called with no arguments (`@agent()`), or
    called with keyword options (`@agent(name=..., description=...)`).

    Args:
        func: Agent creation function (supplied implicitly when the
            decorator is applied without parentheses).
        name: Optional name for the agent. When omitted, the name of
            the agent creation function is used.
        description: Description for the agent when used as
            an ordinary tool or handoff tool.

    Returns:
        The wrapped agent creation function (when `func` is given), or
        a decorator that produces it.
    """

    def create_agent_wrapper(agent_type: Callable[P, Agent]) -> Callable[P, Agent]:
        # explicit name wins, otherwise fall back to the function's own name
        agent_name = registry_name(agent_type, name or agent_type.__name__)

        @wraps(agent_type)
        def agent_wrapper(*args: P.args, **kwargs: P.kwargs) -> Agent:
            # build the underlying agent instance
            instance = agent_type(*args, **kwargs)

            # an instance that already carries registry info contributes
            # its name/description as the preferred values
            prior_name = None
            prior_description = None
            if is_registry_object(instance):
                prior = registry_info(instance)
                prior_name = prior.name
                prior_description = prior.metadata.get(AGENT_DESCRIPTION)

            # tag the instance with its registry info and creation arguments
            info = RegistryInfo(
                type="agent",
                name=prior_name or agent_name,
                metadata={AGENT_DESCRIPTION: prior_description or description},
            )
            registry_tag(agent_type, instance, info, *args, **kwargs)
            return instance

        # register the wrapper under the resolved name
        return agent_register(cast(Callable[P, Agent], agent_wrapper), agent_name)

    # bare usage decorates immediately; otherwise hand back the decorator
    return create_agent_wrapper if func is None else create_agent_wrapper(func)
|
199
|
+
|
200
|
+
|
201
|
+
def agent_with(
    agent: Agent,
    *,
    name: str | None = None,
    description: str | None = None,
) -> Agent:
    """Return the agent with its name and/or description modified.

    The passed agent is modified in place and then returned. To obtain
    several variations of one agent via `agent_with()`, create the
    underlying agent once per variation.

    Args:
        agent: Agent instance to modify.
        name: Agent name (optional).
        description: Agent description (optional).

    Returns:
        The passed agent with the requested modifications.

    Raises:
        ValueError: If no name was passed and none could be taken from
            the agent's existing registry info.
    """
    # fill in anything not supplied from the agent's existing registry info
    if is_registry_object(agent):
        existing = registry_info(agent)
        if not name:
            name = existing.name
        if not description:
            description = existing.metadata.get(AGENT_DESCRIPTION, None)

    # a name is mandatory
    if name is None:
        raise ValueError("You must provide a name to agent_with")

    # only record a description entry when there is one
    metadata = {} if description is None else {AGENT_DESCRIPTION: description}
    set_registry_info(
        agent,
        RegistryInfo(type="agent", name=name, metadata=metadata),
    )

    return agent
|
245
|
+
|
246
|
+
|
247
|
+
def agent_register(agent: Callable[P, Agent], name: str) -> Callable[P, Agent]:
    r"""Register a function or class as an agent.

    Args:
        agent: Agent function or a class derived from Agent.
        name: Name to register the agent under.

    Returns:
        The same callable that was passed in, after it has been added
        to the registry.
    """
    info = RegistryInfo(type="agent", name=name)
    registry_add(agent, info)
    return agent
|
262
|
+
|
263
|
+
|
264
|
+
def is_agent(obj: Any) -> TypeGuard[Agent]:
    """Check whether `obj` is a registry object of type "agent"."""
    registered_as_agent = is_registry_object(obj, type="agent")
    return registered_as_agent
|
266
|
+
|
267
|
+
|
268
|
+
AGENT_DESCRIPTION = "description"
|
@@ -0,0 +1,72 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Any
|
4
|
+
|
5
|
+
if TYPE_CHECKING:
|
6
|
+
from inspect_ai.solver._solver import Solver
|
7
|
+
|
8
|
+
from inspect_ai._util.registry import (
|
9
|
+
is_registry_object,
|
10
|
+
registry_unqualified_name,
|
11
|
+
)
|
12
|
+
from inspect_ai.tool._tool_info import parse_tool_info
|
13
|
+
|
14
|
+
from ._agent import Agent, AgentState
|
15
|
+
|
16
|
+
|
17
|
+
def as_solver(agent: Agent, **agent_kwargs: Any) -> Solver:
    """Convert an agent to a solver.

    Note that agents used as solvers will only receive their first parameter
    (`state`). Any other parameters must provide appropriate defaults
    or be explicitly specified in `agent_kwargs`

    Args:
        agent: Agent to convert.
        **agent_kwargs: Arguments to curry to Agent function (required
            if the agent has parameters without default values).

    Returns:
        Solver from agent.

    Raises:
        RuntimeError: If `agent` is not a registry object (i.e. was not
            created by an `@agent` decorated function).
        ValueError: If the agent has a parameter (other than the first)
            whose default is `None` and no value for it was passed in
            `agent_kwargs`.
    """
    # imported here rather than at module level (the module only imports
    # `Solver` under TYPE_CHECKING)
    from inspect_ai.solver._solver import Generate, solver
    from inspect_ai.solver._task_state import TaskState

    # agent must be registered (so we can get its name)
    if not is_registry_object(agent):
        raise RuntimeError(
            "Agent passed to as_solver was not created by an @agent decorated function"
        )
    agent_name = registry_unqualified_name(agent)

    # check to make sure we have all the parameters we need to run the agent
    # ([1:] skips the first parameter, which is the agent `state`)
    agent_info = parse_tool_info(agent)
    for name, param in list(agent_info.parameters.properties.items())[1:]:
        if param.default is None and name not in agent_kwargs:
            raise ValueError(
                f"To use the {agent_name} agent as a solver "
                + f"you must pass a value for the agent's required '{name}' "
                + "parameter to the as_solver() function."
            )

    @solver(name=agent_name)
    def agent_to_solver() -> Solver:
        async def solve(state: TaskState, generate: Generate) -> TaskState:
            # run agent
            agent_state = await agent(
                AgentState(messages=state.messages), **agent_kwargs
            )

            # update messages
            state.messages = agent_state.messages

            # update output if its not empty. `AgentState.output` never
            # returns None, so emptiness is tested explicitly (as as_tool
            # does) rather than via truthiness.
            # NOTE(review): assumes ModelOutput defines no custom __bool__ -- confirm
            if not agent_state.output.empty:
                state.output = agent_state.output

            return state

        # return solver
        return solve

    return agent_to_solver()
|
@@ -0,0 +1,122 @@
|
|
1
|
+
from typing import Any
|
2
|
+
|
3
|
+
from inspect_ai._util.registry import (
|
4
|
+
is_registry_object,
|
5
|
+
registry_info,
|
6
|
+
registry_unqualified_name,
|
7
|
+
)
|
8
|
+
from inspect_ai.model._chat_message import ChatMessageAssistant, ChatMessageUser
|
9
|
+
from inspect_ai.tool._tool import Tool, ToolResult, tool
|
10
|
+
from inspect_ai.tool._tool_def import ToolDef, validate_tool_parameters
|
11
|
+
from inspect_ai.tool._tool_info import ToolInfo, parse_tool_info
|
12
|
+
from inspect_ai.tool._tool_params import ToolParam
|
13
|
+
|
14
|
+
from ._agent import AGENT_DESCRIPTION, Agent, AgentState
|
15
|
+
|
16
|
+
|
17
|
+
@tool
def as_tool(agent: Agent, description: str | None = None, **agent_kwargs: Any) -> Tool:
    """Convert an agent to a tool.

    The model is shown every agent argument as a tool argument, with the
    exception of `state`, which is replaced by an `input` argument of
    type `str`. Pass `agent_kwargs` to mask out agent parameters with
    default values: anything supplied there is curried into the agent
    call and is not presented to the model as part of the tool interface.

    Args:
        agent: Agent to convert.
        description: Tool description (defaults to agent description)
        **agent_kwargs: Arguments to curry to Agent function (arguments
            provided here will not be presented to the model as part
            of the tool interface).

    Returns:
        Tool from agent.

    Raises:
        RuntimeError: If `agent` is not a registry object (i.e. was not
            created by an `@agent` decorated function).
    """
    # the agent's name comes from the registry, so it must be registered
    if not is_registry_object(agent):
        raise RuntimeError(
            "Agent passed to as_tool was not created by an @agent decorated function"
        )

    async def execute(input: str, *args: Any, **kwargs: Any) -> ToolResult:
        # seed a fresh conversation with the input and run the agent;
        # call-time kwargs take precedence over the curried ones
        initial = AgentState(messages=[ChatMessageUser(content=input)])
        result = await agent(initial, *args, **(agent_kwargs | kwargs))

        # prefer the content of the output message when there is one
        if not result.output.empty:
            return result.output.message.content

        # otherwise fall back to a trailing assistant message, if any
        final_message = result.messages[-1] if result.messages else None
        if isinstance(final_message, ChatMessageAssistant):
            return final_message.content

        return ""

    # describe the agent as a tool
    tool_info = agent_tool_info(agent, description, **agent_kwargs)

    # put the required "input" parameter ahead of the agent's own parameters
    params = tool_info.parameters
    params.properties = {
        "input": ToolParam(type="string", description="Input message.")
    } | params.properties
    params.required.append("input")

    # build the tool from the definition
    return ToolDef(
        execute,
        name=tool_info.name,
        description=tool_info.description,
        parameters=params,
    ).as_tool()
|
75
|
+
|
76
|
+
|
77
|
+
def agent_tool_info(
    agent: Agent, description: str | None, **agent_kwargs: Any
) -> ToolInfo:
    """Build the `ToolInfo` used to present an agent as a tool.

    Starts from the info parsed from the agent, names it with the agent's
    unqualified registry name, removes the `state` parameter and every
    curried parameter, resolves the description and validates the
    remaining parameters.

    Args:
        agent: Agent to describe.
        description: Explicit description; when falsy, the description
            stored in the agent's registry metadata is used, and after
            that the description parsed from the agent itself.
        **agent_kwargs: Curried agent arguments; each name must match a
            parameter of the agent and is removed from the interface.

    Returns:
        Tool info for the agent.

    Raises:
        ValueError: If a curried argument does not match an agent
            parameter, or if no description could be resolved.
    """
    info = parse_tool_info(agent)
    info.name = registry_unqualified_name(agent)

    def discard(name: str) -> None:
        # drop the parameter from both the properties and the required list
        info.parameters.properties.pop(name, None)
        if name in info.parameters.required:
            info.parameters.required.remove(name)

    # "state" is never exposed as a tool parameter
    discard("state")

    # curried params must exist on the agent; they are hidden from the model
    for curried in agent_kwargs:
        if curried not in info.parameters.properties:
            raise ValueError(
                f"Agent {info.name} does not have a '{curried}' parameter."
            )
        discard(curried)

    # description precedence: explicit argument, then @agent(description=...),
    # then the doc comment parsed from the agent's execute function
    registered = registry_info(agent).metadata.get(AGENT_DESCRIPTION, None)
    info.description = description or registered or info.description
    if len(info.description) == 0:
        raise ValueError(
            f"Description not provided for agent function '{info.name}'. Provide a "
            "description either via @agent(description='<description>'), the description "
            "argument to as_tool() or handoff(), or via a doc comment on the agent's "
            "execute function."
        )

    # check descriptions and types of the parameters that remain
    validate_tool_parameters(info.name, info.parameters.properties)

    return info
|
@@ -5,17 +5,15 @@ from pydantic import BaseModel, Field, ValidationError
|
|
5
5
|
from pydantic_core import to_json
|
6
6
|
|
7
7
|
from inspect_ai._util._async import is_callable_coroutine
|
8
|
-
from inspect_ai.
|
8
|
+
from inspect_ai.agent._agent import Agent, AgentState, agent
|
9
|
+
from inspect_ai.model._model import get_model
|
10
|
+
from inspect_ai.model._model_output import ModelOutput
|
9
11
|
from inspect_ai.model._providers.providers import validate_openai_client
|
10
|
-
from inspect_ai.scorer._metric import Score
|
11
12
|
|
12
|
-
from .._solver import Generate, Solver, solver
|
13
|
-
from .._task_state import TaskState
|
14
13
|
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
"""Bridge an external agent into an Inspect Solver.
|
14
|
+
@agent
|
15
|
+
def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Agent:
|
16
|
+
"""Bridge an external agent into an Inspect Agent.
|
19
17
|
|
20
18
|
See documentation at <https://inspect.aisi.org.uk/agent-bridge.html>
|
21
19
|
|
@@ -25,7 +23,7 @@ def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Solv
|
|
25
23
|
Returns:
|
26
24
|
Standard Inspect agent.
|
27
25
|
"""
|
28
|
-
validate_openai_client("
|
26
|
+
validate_openai_client("Agent bridge()")
|
29
27
|
|
30
28
|
from openai.types.chat import ChatCompletionMessageParam
|
31
29
|
|
@@ -36,17 +34,15 @@ def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Solv
|
|
36
34
|
|
37
35
|
from .patch import openai_request_to_inspect_model
|
38
36
|
|
39
|
-
class
|
40
|
-
|
41
|
-
|
37
|
+
class BridgeInput(BaseModel):
|
38
|
+
messages: list[ChatCompletionMessageParam]
|
39
|
+
|
40
|
+
# temporarily here for backward compatibility w/ previous bridge
|
42
41
|
input: list[ChatCompletionMessageParam]
|
43
|
-
metadata: dict[str, Any]
|
44
|
-
target: list[str]
|
45
42
|
|
46
43
|
class BridgeResult(BaseModel):
|
47
44
|
output: str
|
48
45
|
messages: list[ChatCompletionMessageParam] | None = Field(default=None)
|
49
|
-
scores: dict[str, Score] | None = Field(default=None)
|
50
46
|
|
51
47
|
result_schema = BridgeResult.model_json_schema()
|
52
48
|
result_validator = Draft7Validator(result_schema)
|
@@ -55,27 +51,15 @@ def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Solv
|
|
55
51
|
if not is_callable_coroutine(agent):
|
56
52
|
raise TypeError(f"'{agent.__name__}' is not declared as an async callable.")
|
57
53
|
|
58
|
-
async def
|
59
|
-
#
|
60
|
-
|
61
|
-
|
62
|
-
if isinstance(state.input, str)
|
63
|
-
else state.input
|
64
|
-
)
|
65
|
-
|
66
|
-
# create sample (use standard gpt-4 message encoding -- i.e. no 'developer' messages)
|
67
|
-
sample = BridgeSample(
|
68
|
-
sample_id=str(state.sample_id),
|
69
|
-
epoch=state.epoch,
|
70
|
-
input=await openai_chat_messages(input, model="gpt-4"),
|
71
|
-
metadata=state.metadata,
|
72
|
-
target=list(state.target),
|
73
|
-
)
|
54
|
+
async def execute(state: AgentState) -> AgentState:
|
55
|
+
# create input (use standard gpt-4 message encoding -- i.e. no 'developer' messages)
|
56
|
+
messages = await openai_chat_messages(state.messages, model="gpt-4")
|
57
|
+
input = BridgeInput(messages=messages, input=messages)
|
74
58
|
|
75
59
|
# run target function
|
76
60
|
async with openai_request_to_inspect_model():
|
77
61
|
# call the function
|
78
|
-
result_dict = await agent(
|
62
|
+
result_dict = await agent(input.model_dump())
|
79
63
|
try:
|
80
64
|
result = BridgeResult.model_validate(result_dict)
|
81
65
|
except ValidationError:
|
@@ -89,12 +73,14 @@ def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Solv
|
|
89
73
|
raise ValueError(message)
|
90
74
|
|
91
75
|
# update and return state
|
92
|
-
state.output
|
76
|
+
state.output = ModelOutput.from_content(
|
77
|
+
model=get_model().name, content=result.output
|
78
|
+
)
|
93
79
|
if result.messages is not None:
|
94
|
-
state.messages = chat_messages_from_openai(
|
95
|
-
|
96
|
-
|
80
|
+
state.messages = chat_messages_from_openai(
|
81
|
+
state.output.model, result.messages
|
82
|
+
)
|
97
83
|
|
98
84
|
return state
|
99
85
|
|
100
|
-
return
|
86
|
+
return execute
|
@@ -23,7 +23,6 @@ from inspect_ai.model._openai import (
|
|
23
23
|
openai_chat_choices,
|
24
24
|
openai_completion_usage,
|
25
25
|
)
|
26
|
-
from inspect_ai.solver._task_state import sample_state
|
27
26
|
from inspect_ai.tool._tool_choice import ToolChoice, ToolFunction
|
28
27
|
from inspect_ai.tool._tool_info import ToolInfo
|
29
28
|
from inspect_ai.tool._tool_params import ToolParams
|
@@ -98,10 +97,18 @@ def init_openai_request_patch() -> None:
|
|
98
97
|
async def inspect_model_request(
|
99
98
|
model_name: str, options: FinalRequestOptions
|
100
99
|
) -> ChatCompletion:
|
100
|
+
from inspect_ai.solver._task_state import sample_state
|
101
|
+
|
102
|
+
# resolve model
|
103
|
+
if model_name == "inspect":
|
104
|
+
model = get_model()
|
105
|
+
else:
|
106
|
+
model = get_model(model_name.removeprefix("inspect/"))
|
107
|
+
|
101
108
|
# convert openai messages to inspect messages
|
102
109
|
json_data = cast(dict[str, Any], options.json_data)
|
103
110
|
messages: list[ChatCompletionMessageParam] = json_data["messages"]
|
104
|
-
input = chat_messages_from_openai(messages)
|
111
|
+
input = chat_messages_from_openai(model.api.model_name, messages)
|
105
112
|
|
106
113
|
# convert openai tools to inspect tools
|
107
114
|
tools: list[ChatCompletionToolParam] = json_data.get("tools", [])
|
@@ -130,12 +137,6 @@ async def inspect_model_request(
|
|
130
137
|
case _:
|
131
138
|
inspect_tool_choice = ToolFunction(name=tool_choice["function"]["name"])
|
132
139
|
|
133
|
-
# resolve model
|
134
|
-
if model_name == "inspect":
|
135
|
-
model = get_model()
|
136
|
-
else:
|
137
|
-
model = get_model(model_name.removeprefix("inspect/"))
|
138
|
-
|
139
140
|
output = await model.generate(
|
140
141
|
input=input,
|
141
142
|
tools=inspect_tools,
|
@@ -0,0 +1,46 @@
|
|
1
|
+
from typing import Awaitable, Callable
|
2
|
+
|
3
|
+
from inspect_ai.model._chat_message import (
|
4
|
+
ChatMessage,
|
5
|
+
ChatMessageAssistant,
|
6
|
+
ChatMessageTool,
|
7
|
+
)
|
8
|
+
|
9
|
+
MessageFilter = Callable[[list[ChatMessage]], Awaitable[list[ChatMessage]]]
|
10
|
+
"""Filter messages sent to or received from agent handoffs."""
|
11
|
+
|
12
|
+
|
13
|
+
async def remove_tools(messages: list[ChatMessage]) -> list[ChatMessage]:
    """Strip tool activity from a list of messages.

    Every `ChatMessageTool` is dropped, and each `ChatMessageAssistant`
    is replaced by a copy whose `tool_calls` field is set to `None`.
    All other messages are passed through unchanged and in order.

    Args:
        messages: Messages to remove tool calls from.

    Returns:
        Messages without tool calls.
    """

    def without_calls(msg: ChatMessage) -> ChatMessage:
        # assistant messages are copied rather than modified in place
        if isinstance(msg, ChatMessageAssistant):
            return msg.model_copy(update={"tool_calls": None})
        return msg

    return [
        without_calls(msg) for msg in messages if not isinstance(msg, ChatMessageTool)
    ]
|
34
|
+
|
35
|
+
|
36
|
+
async def last_message(messages: list[ChatMessage]) -> list[ChatMessage]:
    """Keep only the final message.

    Args:
        messages: Target messages.

    Returns:
        List containing only the last message from the input list
        (an empty list when the input is empty).
    """
    # slicing (rather than indexing) keeps an empty input from raising
    tail = messages[-1:]
    return tail
|