inspect-ai 0.3.81__py3-none-any.whl → 0.3.83__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_cli/eval.py +35 -2
- inspect_ai/_cli/util.py +44 -1
- inspect_ai/_display/core/config.py +1 -1
- inspect_ai/_display/core/display.py +13 -4
- inspect_ai/_display/core/results.py +1 -1
- inspect_ai/_display/textual/app.py +14 -3
- inspect_ai/_display/textual/display.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +9 -3
- inspect_ai/_display/textual/widgets/task_detail.py +8 -8
- inspect_ai/_display/textual/widgets/tasks.py +17 -1
- inspect_ai/_display/textual/widgets/vscode.py +44 -0
- inspect_ai/_eval/eval.py +74 -25
- inspect_ai/_eval/evalset.py +22 -18
- inspect_ai/_eval/loader.py +34 -11
- inspect_ai/_eval/run.py +13 -15
- inspect_ai/_eval/score.py +13 -3
- inspect_ai/_eval/task/generate.py +8 -9
- inspect_ai/_eval/task/log.py +55 -6
- inspect_ai/_eval/task/run.py +51 -10
- inspect_ai/_eval/task/task.py +23 -9
- inspect_ai/_util/constants.py +2 -0
- inspect_ai/_util/file.py +30 -1
- inspect_ai/_util/json.py +37 -1
- inspect_ai/_util/registry.py +1 -0
- inspect_ai/_util/vscode.py +37 -0
- inspect_ai/_view/server.py +113 -1
- inspect_ai/_view/www/App.css +7 -1
- inspect_ai/_view/www/dist/assets/index.css +813 -415
- inspect_ai/_view/www/dist/assets/index.js +54475 -32003
- inspect_ai/_view/www/eslint.config.mjs +1 -1
- inspect_ai/_view/www/log-schema.json +137 -31
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +149 -0
- inspect_ai/_view/www/package.json +11 -2
- inspect_ai/_view/www/src/App.tsx +161 -853
- inspect_ai/_view/www/src/api/api-browser.ts +176 -5
- inspect_ai/_view/www/src/api/api-vscode.ts +75 -1
- inspect_ai/_view/www/src/api/client-api.ts +66 -10
- inspect_ai/_view/www/src/api/jsonrpc.ts +2 -0
- inspect_ai/_view/www/src/api/types.ts +107 -2
- inspect_ai/_view/www/src/appearance/icons.ts +2 -0
- inspect_ai/_view/www/src/components/AsciinemaPlayer.tsx +3 -3
- inspect_ai/_view/www/src/components/Card.tsx +6 -4
- inspect_ai/_view/www/src/components/DownloadPanel.tsx +2 -2
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +56 -61
- inspect_ai/_view/www/src/components/FindBand.tsx +17 -9
- inspect_ai/_view/www/src/components/HumanBaselineView.tsx +1 -1
- inspect_ai/_view/www/src/components/JsonPanel.tsx +14 -24
- inspect_ai/_view/www/src/components/LargeModal.tsx +2 -35
- inspect_ai/_view/www/src/components/LightboxCarousel.tsx +27 -11
- inspect_ai/_view/www/src/components/LinkButton.module.css +16 -0
- inspect_ai/_view/www/src/components/LinkButton.tsx +33 -0
- inspect_ai/_view/www/src/components/LiveVirtualList.module.css +11 -0
- inspect_ai/_view/www/src/components/LiveVirtualList.tsx +177 -0
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +116 -26
- inspect_ai/_view/www/src/components/MessageBand.tsx +14 -9
- inspect_ai/_view/www/src/components/Modal.module.css +38 -0
- inspect_ai/_view/www/src/components/Modal.tsx +77 -0
- inspect_ai/_view/www/src/components/MorePopOver.tsx +3 -3
- inspect_ai/_view/www/src/components/NavPills.tsx +20 -8
- inspect_ai/_view/www/src/components/NoContentsPanel.module.css +12 -0
- inspect_ai/_view/www/src/components/NoContentsPanel.tsx +20 -0
- inspect_ai/_view/www/src/components/ProgressBar.module.css +5 -4
- inspect_ai/_view/www/src/components/ProgressBar.tsx +3 -2
- inspect_ai/_view/www/src/components/PulsingDots.module.css +81 -0
- inspect_ai/_view/www/src/components/PulsingDots.tsx +45 -0
- inspect_ai/_view/www/src/components/TabSet.tsx +4 -37
- inspect_ai/_view/www/src/components/ToolButton.tsx +3 -4
- inspect_ai/_view/www/src/index.tsx +26 -94
- inspect_ai/_view/www/src/logfile/remoteLogFile.ts +9 -1
- inspect_ai/_view/www/src/logfile/remoteZipFile.ts +30 -4
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +4 -6
- inspect_ai/_view/www/src/plan/DetailStep.module.css +4 -0
- inspect_ai/_view/www/src/plan/DetailStep.tsx +6 -3
- inspect_ai/_view/www/src/plan/ScorerDetailView.tsx +1 -1
- inspect_ai/_view/www/src/plan/SolverDetailView.module.css +2 -1
- inspect_ai/_view/www/src/samples/InlineSampleDisplay.module.css +9 -1
- inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +74 -28
- inspect_ai/_view/www/src/samples/SampleDialog.tsx +58 -22
- inspect_ai/_view/www/src/samples/SampleDisplay.module.css +4 -0
- inspect_ai/_view/www/src/samples/SampleDisplay.tsx +135 -104
- inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +10 -0
- inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +83 -36
- inspect_ai/_view/www/src/samples/SamplesTools.tsx +35 -30
- inspect_ai/_view/www/src/samples/chat/ChatMessage.tsx +2 -1
- inspect_ai/_view/www/src/samples/chat/ChatMessageRenderer.tsx +1 -1
- inspect_ai/_view/www/src/samples/chat/ChatViewVirtualList.tsx +45 -53
- inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +6 -1
- inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +5 -0
- inspect_ai/_view/www/src/samples/chat/messages.ts +36 -0
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.module.css +3 -0
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +11 -1
- inspect_ai/_view/www/src/samples/chat/tools/ToolInput.tsx +22 -46
- inspect_ai/_view/www/src/samples/descriptor/samplesDescriptor.tsx +34 -20
- inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.module.css +3 -3
- inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.tsx +1 -1
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.module.css +4 -4
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +10 -10
- inspect_ai/_view/www/src/samples/descriptor/types.ts +6 -5
- inspect_ai/_view/www/src/samples/list/SampleFooter.module.css +22 -3
- inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +27 -2
- inspect_ai/_view/www/src/samples/list/SampleList.tsx +122 -85
- inspect_ai/_view/www/src/samples/list/SampleRow.module.css +6 -0
- inspect_ai/_view/www/src/samples/list/SampleRow.tsx +28 -15
- inspect_ai/_view/www/src/samples/sample-tools/SelectScorer.tsx +29 -18
- inspect_ai/_view/www/src/samples/sample-tools/SortFilter.tsx +28 -28
- inspect_ai/_view/www/src/samples/sample-tools/sample-filter/SampleFilter.tsx +19 -9
- inspect_ai/_view/www/src/samples/sampleDataAdapter.ts +33 -0
- inspect_ai/_view/www/src/samples/sampleLimit.ts +2 -2
- inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +12 -27
- inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.module.css +38 -0
- inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.tsx +118 -0
- inspect_ai/_view/www/src/samples/scores/{SampleScoreView.module.css → SampleScoresView.module.css} +10 -1
- inspect_ai/_view/www/src/samples/scores/SampleScoresView.tsx +78 -0
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/InputEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.module.css +4 -0
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +10 -24
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +4 -22
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.tsx +15 -24
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +6 -28
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.tsx +24 -34
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.module.css +4 -0
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +33 -17
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +197 -338
- inspect_ai/_view/www/src/samples/transcript/TranscriptVirtualListComponent.module.css +16 -0
- inspect_ai/_view/www/src/samples/transcript/TranscriptVirtualListComponent.tsx +44 -0
- inspect_ai/_view/www/src/samples/transcript/event/EventNav.tsx +7 -4
- inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +81 -60
- inspect_ai/_view/www/src/samples/transcript/event/EventProgressPanel.module.css +23 -0
- inspect_ai/_view/www/src/samples/transcript/event/EventProgressPanel.tsx +27 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +29 -1
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +102 -72
- inspect_ai/_view/www/src/scoring/utils.ts +87 -0
- inspect_ai/_view/www/src/state/appSlice.ts +244 -0
- inspect_ai/_view/www/src/state/hooks.ts +399 -0
- inspect_ai/_view/www/src/state/logPolling.ts +200 -0
- inspect_ai/_view/www/src/state/logSlice.ts +224 -0
- inspect_ai/_view/www/src/state/logsPolling.ts +118 -0
- inspect_ai/_view/www/src/state/logsSlice.ts +181 -0
- inspect_ai/_view/www/src/state/samplePolling.ts +314 -0
- inspect_ai/_view/www/src/state/sampleSlice.ts +140 -0
- inspect_ai/_view/www/src/state/sampleUtils.ts +21 -0
- inspect_ai/_view/www/src/state/scrolling.ts +206 -0
- inspect_ai/_view/www/src/state/store.ts +168 -0
- inspect_ai/_view/www/src/state/store_filter.ts +84 -0
- inspect_ai/_view/www/src/state/utils.ts +23 -0
- inspect_ai/_view/www/src/storage/index.ts +26 -0
- inspect_ai/_view/www/src/types/log.d.ts +36 -26
- inspect_ai/_view/www/src/types/markdown-it-katex.d.ts +21 -0
- inspect_ai/_view/www/src/types.ts +94 -32
- inspect_ai/_view/www/src/utils/attachments.ts +58 -23
- inspect_ai/_view/www/src/utils/json-worker.ts +79 -12
- inspect_ai/_view/www/src/utils/logger.ts +52 -0
- inspect_ai/_view/www/src/utils/polling.ts +100 -0
- inspect_ai/_view/www/src/utils/react.ts +30 -0
- inspect_ai/_view/www/src/utils/vscode.ts +1 -1
- inspect_ai/_view/www/src/workspace/WorkSpace.tsx +184 -217
- inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +11 -53
- inspect_ai/_view/www/src/workspace/navbar/Navbar.tsx +8 -18
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -0
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +40 -22
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.module.css +16 -1
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +159 -103
- inspect_ai/_view/www/src/workspace/navbar/RunningStatusPanel.module.css +32 -0
- inspect_ai/_view/www/src/workspace/navbar/RunningStatusPanel.tsx +32 -0
- inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.module.css +35 -0
- inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.tsx +117 -0
- inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +12 -14
- inspect_ai/_view/www/src/workspace/navbar/StatusPanel.tsx +6 -2
- inspect_ai/_view/www/src/workspace/sidebar/LogDirectoryTitleView.tsx +4 -4
- inspect_ai/_view/www/src/workspace/sidebar/Sidebar.module.css +3 -2
- inspect_ai/_view/www/src/workspace/sidebar/Sidebar.tsx +28 -13
- inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +5 -10
- inspect_ai/_view/www/src/workspace/tabs/JsonTab.tsx +4 -4
- inspect_ai/_view/www/src/workspace/tabs/RunningNoSamples.module.css +22 -0
- inspect_ai/_view/www/src/workspace/tabs/RunningNoSamples.tsx +19 -0
- inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +128 -115
- inspect_ai/_view/www/src/workspace/tabs/grouping.ts +37 -5
- inspect_ai/_view/www/src/workspace/tabs/types.ts +4 -0
- inspect_ai/_view/www/src/workspace/types.ts +4 -3
- inspect_ai/_view/www/src/workspace/utils.ts +4 -4
- inspect_ai/_view/www/vite.config.js +6 -0
- inspect_ai/_view/www/yarn.lock +464 -355
- inspect_ai/agent/__init__.py +36 -0
- inspect_ai/agent/_agent.py +268 -0
- inspect_ai/agent/_as_solver.py +72 -0
- inspect_ai/agent/_as_tool.py +122 -0
- inspect_ai/{solver → agent}/_bridge/bridge.py +23 -37
- inspect_ai/{solver → agent}/_bridge/patch.py +9 -8
- inspect_ai/agent/_filter.py +46 -0
- inspect_ai/agent/_handoff.py +93 -0
- inspect_ai/{solver/_human_agent → agent/_human}/agent.py +11 -12
- inspect_ai/{solver/_human_agent → agent/_human}/commands/__init__.py +2 -3
- inspect_ai/{solver/_human_agent → agent/_human}/commands/clock.py +3 -1
- inspect_ai/{solver/_human_agent → agent/_human}/commands/score.py +5 -5
- inspect_ai/{solver/_human_agent → agent/_human}/install.py +6 -3
- inspect_ai/{solver/_human_agent → agent/_human}/service.py +7 -3
- inspect_ai/{solver/_human_agent → agent/_human}/state.py +5 -5
- inspect_ai/agent/_react.py +241 -0
- inspect_ai/agent/_run.py +36 -0
- inspect_ai/agent/_types.py +81 -0
- inspect_ai/log/_condense.py +26 -0
- inspect_ai/log/_log.py +17 -5
- inspect_ai/log/_recorders/buffer/__init__.py +14 -0
- inspect_ai/log/_recorders/buffer/buffer.py +30 -0
- inspect_ai/log/_recorders/buffer/database.py +685 -0
- inspect_ai/log/_recorders/buffer/filestore.py +259 -0
- inspect_ai/log/_recorders/buffer/types.py +84 -0
- inspect_ai/log/_recorders/eval.py +2 -11
- inspect_ai/log/_recorders/types.py +30 -0
- inspect_ai/log/_transcript.py +32 -2
- inspect_ai/model/__init__.py +7 -1
- inspect_ai/model/_call_tools.py +257 -52
- inspect_ai/model/_chat_message.py +7 -4
- inspect_ai/model/_conversation.py +13 -62
- inspect_ai/model/_display.py +85 -0
- inspect_ai/model/_generate_config.py +2 -2
- inspect_ai/model/_model.py +114 -14
- inspect_ai/model/_model_output.py +14 -9
- inspect_ai/model/_openai.py +16 -4
- inspect_ai/model/_openai_computer_use.py +162 -0
- inspect_ai/model/_openai_responses.py +319 -165
- inspect_ai/model/_providers/anthropic.py +20 -21
- inspect_ai/model/_providers/azureai.py +24 -13
- inspect_ai/model/_providers/bedrock.py +1 -7
- inspect_ai/model/_providers/cloudflare.py +3 -3
- inspect_ai/model/_providers/goodfire.py +2 -6
- inspect_ai/model/_providers/google.py +11 -10
- inspect_ai/model/_providers/groq.py +6 -3
- inspect_ai/model/_providers/hf.py +7 -3
- inspect_ai/model/_providers/mistral.py +7 -10
- inspect_ai/model/_providers/openai.py +47 -17
- inspect_ai/model/_providers/openai_o1.py +11 -4
- inspect_ai/model/_providers/openai_responses.py +12 -14
- inspect_ai/model/_providers/providers.py +2 -2
- inspect_ai/model/_providers/together.py +12 -2
- inspect_ai/model/_providers/util/chatapi.py +7 -2
- inspect_ai/model/_providers/util/hf_handler.py +4 -2
- inspect_ai/model/_providers/util/llama31.py +4 -2
- inspect_ai/model/_providers/vertex.py +11 -9
- inspect_ai/model/_providers/vllm.py +4 -4
- inspect_ai/scorer/__init__.py +2 -0
- inspect_ai/scorer/_metrics/__init__.py +2 -0
- inspect_ai/scorer/_metrics/grouped.py +84 -0
- inspect_ai/scorer/_score.py +26 -6
- inspect_ai/solver/__init__.py +2 -2
- inspect_ai/solver/_basic_agent.py +22 -9
- inspect_ai/solver/_bridge.py +31 -0
- inspect_ai/solver/_chain.py +20 -12
- inspect_ai/solver/_fork.py +5 -1
- inspect_ai/solver/_human_agent.py +52 -0
- inspect_ai/solver/_prompt.py +3 -1
- inspect_ai/solver/_run.py +59 -0
- inspect_ai/solver/_solver.py +14 -4
- inspect_ai/solver/_task_state.py +5 -3
- inspect_ai/tool/_tool_call.py +15 -8
- inspect_ai/tool/_tool_def.py +17 -12
- inspect_ai/tool/_tool_support_helpers.py +4 -4
- inspect_ai/tool/_tool_with.py +14 -11
- inspect_ai/tool/_tools/_bash_session.py +11 -2
- inspect_ai/tool/_tools/_computer/_common.py +18 -2
- inspect_ai/tool/_tools/_computer/_computer.py +18 -2
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +2 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +17 -0
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +103 -62
- inspect_ai/util/__init__.py +2 -0
- inspect_ai/util/_anyio.py +27 -0
- inspect_ai/util/_sandbox/__init__.py +2 -1
- inspect_ai/util/_sandbox/context.py +32 -7
- inspect_ai/util/_sandbox/docker/cleanup.py +4 -0
- inspect_ai/util/_sandbox/docker/compose.py +2 -2
- inspect_ai/util/_sandbox/docker/docker.py +12 -1
- inspect_ai/util/_store_model.py +30 -7
- inspect_ai/util/_subprocess.py +13 -3
- inspect_ai/util/_subtask.py +1 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/RECORD +295 -229
- inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +0 -169
- inspect_ai/_view/www/src/samples/transcript/SampleTranscript.tsx +0 -22
- /inspect_ai/{solver → agent}/_bridge/__init__.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/__init__.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/command.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/instructions.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/note.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/status.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/submit.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/panel.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/view.py +0 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/openai_responses.py
CHANGED
```diff
@@ -1,11 +1,7 @@
 from logging import getLogger
 from typing import Any

-from openai import (
-    AsyncAzureOpenAI,
-    AsyncOpenAI,
-    BadRequestError,
-)
+from openai import AsyncAzureOpenAI, AsyncOpenAI, BadRequestError
 from openai._types import NOT_GIVEN
 from openai.types.responses import Response, ResponseFormatTextJSONSchemaConfigParam

@@ -15,12 +11,10 @@ from inspect_ai.tool import ToolChoice, ToolInfo
 from .._chat_message import ChatMessage
 from .._generate_config import GenerateConfig
 from .._model_call import ModelCall
-from .._model_output import (
-    ModelOutput,
-    ModelUsage,
-)
+from .._model_output import ModelOutput, ModelUsage
 from .._openai import (
     OpenAIResponseError,
+    is_computer_use_preview,
     is_gpt,
     is_o1_mini,
     is_o1_preview,
@@ -65,12 +59,14 @@ async def generate_responses(
     )

     # prepare request (we do this so we can log the ModelCall)
+    tool_params = openai_responses_tools(tools, config) if len(tools) > 0 else NOT_GIVEN
     request = dict(
         input=await openai_responses_inputs(input, model_name),
-        tools=
-        tool_choice=openai_responses_tool_choice(tool_choice)
-        if
+        tools=tool_params,
+        tool_choice=openai_responses_tool_choice(tool_choice, tool_params)
+        if isinstance(tool_params, list) and tool_choice != "auto"
         else NOT_GIVEN,
+        truncation="auto" if is_computer_use_preview(model_name) else NOT_GIVEN,
         extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
         **completion_params_responses(model_name, config, len(tools) > 0),
     )
@@ -89,7 +85,7 @@ async def generate_responses(
     response = model_response.model_dump()

     # parse out choices
-    choices = openai_responses_chat_choices(model_response, tools)
+    choices = openai_responses_chat_choices(model_name, model_response, tools)

     # return output and call
     return ModelOutput(
@@ -124,7 +120,9 @@ def completion_params_responses(
             f"OpenAI Responses API does not support the '{param}' parameter.",
         )

-    params: dict[str, Any] = dict(
+    params: dict[str, Any] = dict(
+        model=model_name, store=is_computer_use_preview(model_name)
+    )
     if config.max_tokens is not None:
         params["max_output_tokens"] = config.max_tokens
     if config.frequency_penalty is not None:
```
inspect_ai/model/_providers/providers.py
CHANGED
```diff
@@ -48,7 +48,7 @@ def openai() -> type[ModelAPI]:
 def anthropic() -> type[ModelAPI]:
     FEATURE = "Anthropic API"
     PACKAGE = "anthropic"
-    MIN_VERSION = "0.
+    MIN_VERSION = "0.49.0"

     # verify we have the package
     try:
@@ -278,7 +278,7 @@ def goodfire() -> type[ModelAPI]:
 def validate_openai_client(feature: str) -> None:
     FEATURE = feature
     PACKAGE = "openai"
-    MIN_VERSION = "1.
+    MIN_VERSION = "1.69.0"

     # verify we have the package
     try:
```
inspect_ai/model/_providers/together.py
CHANGED
```diff
@@ -68,7 +68,9 @@ def chat_choices_from_response_together(
         logprobs_models.append(Logprobs(content=logprobs_sequence))
     return [
         ChatCompletionChoice(
-            message=chat_message_assistant_from_openai(
+            message=chat_message_assistant_from_openai(
+                response.model, choice.message, tools
+            ),
             stop_reason=as_stop_reason(choice.finish_reason),
             logprobs=logprobs,
         )
@@ -116,6 +118,14 @@ class TogetherAIAPI(OpenAIAPI):
         else:
             return ex

+    @override
+    def set_logprobs_params(
+        self, params: dict[str, Any], config: GenerateConfig
+    ) -> dict[str, Any]:
+        if config.logprobs is True:
+            params["logprobs"] = 1
+        return params
+
     # Together has a slightly different logprobs structure to OpenAI, so we need to remap it.
     def _chat_choices_from_response(
         self, response: ChatCompletion, tools: list[ToolInfo]
@@ -228,7 +238,7 @@ class TogetherRESTAPI(ModelAPI):
         return DEFAULT_MAX_TOKENS

     def chat_api_handler(self) -> ChatAPIHandler:
-        return ChatAPIHandler()
+        return ChatAPIHandler(self.model_name)


 def together_choices(
```
inspect_ai/model/_providers/util/chatapi.py
CHANGED
```diff
@@ -23,6 +23,9 @@ ChatAPIMessage = dict[Literal["role", "content"], str]


 class ChatAPIHandler:
+    def __init__(self, model: str) -> None:
+        self.model = model
+
     def input_with_tools(
         self, input: list[ChatMessage], tools: list[ToolInfo]
     ) -> list[ChatMessage]:
@@ -31,7 +34,9 @@ class ChatAPIHandler:
     def parse_assistant_response(
         self, response: str, tools: list[ToolInfo]
     ) -> ChatMessageAssistant:
-        return ChatMessageAssistant(
+        return ChatMessageAssistant(
+            content=response, model=self.model, source="generate"
+        )

     def assistant_message(self, message: ChatMessageAssistant) -> ChatAPIMessage:
         return {"role": "assistant", "content": message.text}
@@ -48,7 +53,7 @@ class ChatAPIHandler:
 def chat_api_input(
     input: list[ChatMessage],
     tools: list[ToolInfo],
-    handler: ChatAPIHandler
+    handler: ChatAPIHandler,
 ) -> list[ChatAPIMessage]:
     # add tools to input
     if len(tools) > 0:
```
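Taken together, the `ChatAPIHandler` changes mean handlers are now constructed with a model name, and the assistant messages they parse carry a `model` field. A minimal sketch of the new call pattern (private-module import and model id shown for illustration only):

```python
from inspect_ai.model._providers.util.chatapi import ChatAPIHandler

# handlers now take the model name at construction time...
handler = ChatAPIHandler("meta-llama/Llama-3.1-70B-Instruct")

# ...so parsed assistant messages record which model produced them
message = handler.parse_assistant_response("Hello!", tools=[])
assert message.model == "meta-llama/Llama-3.1-70B-Instruct"
```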
inspect_ai/model/_providers/util/hf_handler.py
CHANGED
```diff
@@ -50,13 +50,16 @@ class HFHandler(ChatAPIHandler):
             return ChatMessageAssistant(
                 content=content,
                 tool_calls=tool_calls,
+                model=self.model_name,
                 source="generate",
             )

         # otherwise this is just an ordinary assistant message
         else:
             return ChatMessageAssistant(
-                content=filter_assistant_header(response),
+                content=filter_assistant_header(response),
+                model=self.model_name,
+                source="generate",
             )


@@ -106,7 +109,6 @@ def parse_tool_call_content(content: str, tools: list[ToolInfo]) -> ToolCall:
             id="unknown",
             function="unknown",
             arguments={},
-            type="function",
             parse_error=parse_error,
         )

```
inspect_ai/model/_providers/util/llama31.py
CHANGED
```diff
@@ -106,13 +106,16 @@ class Llama31Handler(ChatAPIHandler):
             return ChatMessageAssistant(
                 content=filter_assistant_header(content),
                 tool_calls=tool_calls,
+                model=self.model,
                 source="generate",
             )

         # otherwise this is just an ordinary assistant message
         else:
             return ChatMessageAssistant(
-                content=filter_assistant_header(response),
+                content=filter_assistant_header(response),
+                model=self.model,
+                source="generate",
             )

     @override
@@ -184,7 +187,6 @@ def parse_tool_call_content(content: str, tools: list[ToolInfo]) -> ToolCall:
             id="unknown",
             function="unknown",
             arguments={},
-            type="function",
             parse_error=parse_error,
         )

```
inspect_ai/model/_providers/vertex.py
CHANGED
```diff
@@ -116,11 +116,6 @@ class VertexAPI(ModelAPI):

         self.model = GenerativeModel(model_name)

-    @override
-    async def close(self) -> None:
-        # GenerativeModel uses a cached/shared client so there is no 'close'
-        pass
-
     async def generate(
         self,
         input: list[ChatMessage],
@@ -155,7 +150,9 @@ class VertexAPI(ModelAPI):
         # capture output
         output = ModelOutput(
             model=self.model_name,
-            choices=completion_choices_from_candidates(
+            choices=completion_choices_from_candidates(
+                self.model_name, response.candidates
+            ),
             usage=ModelUsage(
                 input_tokens=response.usage_metadata.prompt_token_count,
                 output_tokens=response.usage_metadata.candidates_token_count,
@@ -377,7 +374,9 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
     return [Tool(function_declarations=declarations)]


-def completion_choice_from_candidate(
+def completion_choice_from_candidate(
+    model: str, candidate: Candidate
+) -> ChatCompletionChoice:
     # check for completion text
     content = " ".join(
         [
@@ -394,7 +393,6 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoi
             function_call = MessageToDict(getattr(part.function_call, "_pb"))
             tool_calls.append(
                 ToolCall(
-                    type="function",
                     id=function_call["name"],
                     function=function_call["name"],
                     arguments=function_call["args"],
@@ -408,6 +406,7 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoi
             message=ChatMessageAssistant(
                 content=content,
                 tool_calls=tool_calls if len(tool_calls) > 0 else None,
+                model=model,
                 source="generate",
             ),
             stop_reason=stop_reason,
@@ -435,11 +434,14 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoi


 def completion_choices_from_candidates(
+    model: str,
     candidates: list[Candidate],
 ) -> list[ChatCompletionChoice]:
     candidates = copy(candidates)
     candidates.sort(key=lambda c: c.index)
-    return [
+    return [
+        completion_choice_from_candidate(model, candidate) for candidate in candidates
+    ]


 def candidate_stop_reason(finish_reason: FinishReason) -> StopReason:
```
inspect_ai/model/_providers/vllm.py
CHANGED
```diff
@@ -28,7 +28,7 @@ from .._model_output import (
     StopReason,
     TopLogprob,
 )
-from .util import chat_api_input
+from .util import ChatAPIHandler, chat_api_input

 DEFAULT_START_TOKEN = "<|im_start|>"
 DEFAULT_END_TOKEN = "<|im_end|>"
@@ -137,7 +137,7 @@ class VLLMAPI(ModelAPI):
         self.tokenizer = self.model.get_tokenizer()

     @override
-
+    def close(self) -> None:
         self.tokenizer = None
         self.model = None
         gc.collect()
@@ -148,7 +148,7 @@ class VLLMAPI(ModelAPI):
         # handle system message and consecutive user messages
         messages = simple_input_messages(messages)
         # convert to chat template input format
-        chat_messages = chat_api_input(messages, tools)
+        chat_messages = chat_api_input(messages, tools, ChatAPIHandler(self.model_name))
         # apply chat template
         chat = self.tokenizer.apply_chat_template(
             chat_messages,
@@ -253,7 +253,7 @@ class VLLMAPI(ModelAPI):
         choices = [
             ChatCompletionChoice(
                 message=ChatMessageAssistant(
-                    content=response.output, source="generate"
+                    content=response.output, model=self.model_name, source="generate"
                 ),
                 stop_reason=response.stop_reason,
                 logprobs=response.logprobs,
```
inspect_ai/scorer/__init__.py
CHANGED
```diff
@@ -19,6 +19,7 @@ from ._metric import (
     value_to_float,
 )
 from ._metrics.accuracy import accuracy
+from ._metrics.grouped import grouped
 from ._metrics.mean import mean
 from ._metrics.std import bootstrap_stderr, std, stderr, var
 from ._model import model_graded_fact, model_graded_qa
@@ -58,6 +59,7 @@ __all__ = [
     "std",
     "stderr",
     "mean",
+    "grouped",
     "var",
     "Metric",
     "MetricProtocol",
```
inspect_ai/scorer/_metrics/grouped.py
ADDED
```diff
@@ -0,0 +1,84 @@
+from typing import Literal, cast
+
+import numpy as np
+
+from inspect_ai.scorer._metric import (
+    Metric,
+    MetricProtocol,
+    SampleScore,
+    Value,
+    ValueToFloat,
+    metric,
+    value_to_float,
+)
+
+
+@metric
+def grouped(
+    metric: Metric,
+    group_key: str,
+    *,
+    all: Literal["samples", "groups"] | Literal[False] = "samples",
+    all_label: str = "all",
+    value_to_float: ValueToFloat = value_to_float(),
+) -> Metric:
+    """
+    Creates a grouped metric that applies the given metric to subgroups of samples.
+
+    Args:
+        metric: The metric to apply to each group of samples.
+        group_key: The metadata key used to group samples. Each sample must have this key in its metadata.
+        all: How to compute the "all" aggregate score:
+            - "samples": Apply the metric to all samples regardless of groups
+            - "groups": Calculate the mean of all group scores
+            - False: Don't calculate an aggregate score
+        all_label: The label for the "all" key in the returned dictionary.
+        value_to_float: Function to convert metric values to floats, used when all="groups".
+
+    Returns:
+        A new metric function that returns a dictionary mapping group names to their scores,
+        with an optional "all" key for the aggregate score.
+    """
+
+    def grouped_metric(scores: list[SampleScore]) -> Value:
+        # Satisfy the type checker that the metric is a MetricProtocol
+        metric_protocol = cast(MetricProtocol, metric)
+
+        # Slice the scores into groups
+        scores_dict: dict[str, list[SampleScore]] = {}
+        for sample_score in scores:
+            if (
+                sample_score.sample_metadata is None
+                or group_key not in sample_score.sample_metadata
+            ):
+                raise ValueError(
+                    f"Sample {sample_score.sample_id} has no {group_key} metadata. To compute a grouped metric each sample metadata must have a value for '{group_key}'"
+                )
+            group_name = str(sample_score.sample_metadata.get(group_key))
+            if group_name not in scores_dict:
+                scores_dict[group_name] = []
+            scores_dict[group_name].append(sample_score)
+
+        # Compute the per group metric
+        grouped_scores = {
+            group_name: metric_protocol(values)
+            for group_name, values in scores_dict.items()
+        }
+
+        if not all:
+            return cast(Value, grouped_scores)
+        else:
+            # Compute the all metric
+            all_group_metric = None
+            if all == "samples":
+                # samples means apply the metric to all samples
+                all_group_metric = metric_protocol(scores)
+            elif all == "groups":
+                # group means the overall score is the mean of all the group scores
+                all_group_metric = np.mean(
+                    [value_to_float(val) for val in grouped_scores.values()]
+                ).item()
+
+            return cast(Value, {**grouped_scores, all_label: all_group_metric})
+
+    return grouped_metric
```
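For orientation, a hedged sketch of using the new `grouped` metric from a task definition (the dataset contents and the task-level `metrics` override are illustrative):

```python
from inspect_ai import Task, task
from inspect_ai.dataset import MemoryDataset, Sample
from inspect_ai.scorer import accuracy, grouped, match

@task
def fruit_colors():
    # each sample carries the metadata key that grouped() buckets by
    dataset = MemoryDataset(
        [
            Sample(input="Color of a banana?", target="yellow", metadata={"category": "fruit"}),
            Sample(input="Color of the sky?", target="blue", metadata={"category": "nature"}),
        ]
    )
    return Task(
        dataset=dataset,
        scorer=match(),
        # per-category accuracy, plus an "all" key computed as the mean of group scores
        metrics=[grouped(accuracy(), "category", all="groups")],
    )
```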
inspect_ai/scorer/_score.py
CHANGED
```diff
@@ -1,30 +1,50 @@
 from contextvars import ContextVar
+from copy import copy

-from inspect_ai.
+from inspect_ai.model._conversation import ModelConversation
+from inspect_ai.solver._task_state import TaskState, sample_state

 from ._metric import Score
 from ._scorer import Scorer
 from ._target import Target


-async def score(
-    """Score a
+async def score(conversation: ModelConversation) -> list[Score]:
+    """Score a model conversation.

-    Score a
+    Score a model conversation (you may pass `TaskState` or `AgentState`
+    as the value for `conversation`)

     Args:
-
+        conversation: Conversation to submit for scoring.
+            Note that both `TaskState` and `AgentState` can be passed
+            as the `conversation` parameter.

     Returns:
         List of scores (one for each task scorer)

     Raises:
-
+        RuntimeError: If called from outside a task or within
             a task that does not have a scorer.

     """
     from inspect_ai.log._transcript import ScoreEvent, transcript

+    # get TaskState (if the `conversation` is a `TaskState` use it directly,
+    # otherwise synthesize one)
+    if isinstance(conversation, TaskState):
+        state = conversation
+    else:
+        current_state = sample_state()
+        if current_state is None:
+            raise RuntimeError(
+                "The score() function can only be called while executing a task"
+            )
+        state = copy(current_state)
+        state.messages = conversation.messages
+        state.output = conversation.output
+
+    # get current scorers and target
     scorers = _scorers.get(None)
     target = _target.get(None)
     if scorers is None or target is None:
```
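The upshot of the `score()` rework: it now accepts any `ModelConversation`, so an `AgentState` (not just a `TaskState`) can be scored mid-run. A hedged sketch, assuming the `inspect_ai.agent` exports listed in the file summary above:

```python
from inspect_ai.agent import Agent, AgentState, agent
from inspect_ai.model import get_model
from inspect_ai.scorer import score

@agent
def self_checking() -> Agent:
    async def execute(state: AgentState) -> AgentState:
        # run a generation turn
        state.output = await get_model().generate(state.messages)
        state.messages.append(state.output.message)

        # score the conversation so far (the running task must define a scorer)
        scores = await score(state)
        print([s.value for s in scores])
        return state

    return execute
```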
inspect_ai/solver/__init__.py
CHANGED
```diff
@@ -1,11 +1,11 @@
 from inspect_ai._util.deprecation import relocated_module_attribute

 from ._basic_agent import basic_agent
-from ._bridge
+from ._bridge import bridge
 from ._chain import chain
 from ._critique import self_critique
 from ._fork import fork
-from ._human_agent
+from ._human_agent import human_agent
 from ._limit import SampleLimitExceededError
 from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
 from ._plan import Plan, plan
```
inspect_ai/solver/_basic_agent.py
CHANGED
```diff
@@ -5,8 +5,8 @@ from typing_extensions import TypedDict, Unpack

 from inspect_ai._util._async import is_callable_coroutine
 from inspect_ai.model._cache import CachePolicy
-from inspect_ai.model._call_tools import
-from inspect_ai.model._chat_message import ChatMessageTool, ChatMessageUser
+from inspect_ai.model._call_tools import execute_tools
+from inspect_ai.model._chat_message import ChatMessage, ChatMessageTool, ChatMessageUser
 from inspect_ai.model._model import get_model
 from inspect_ai.scorer._metric import Score, ValueToFloat, value_to_float
 from inspect_ai.scorer._score import score
@@ -65,6 +65,7 @@ def basic_agent(
     continue_message: str = DEFAULT_CONTINUE_MESSAGE,
     submit_name: str = DEFAULT_SUBMIT_NAME,
     submit_description: str = DEFAULT_SUBMIT_DESCRIPTION,
+    submit_append: bool = False,
     **kwargs: Unpack[BasicAgentDeprecatedArgs],
 ) -> Solver:
     """Basic ReAct agent.
@@ -102,6 +103,9 @@ def basic_agent(
             (defaults to 'submit')
         submit_description: Description of submit tool (defaults to
             'Submit an answer for evaluation')
+        submit_append: Append the submit tool output to the model completion
+            text (defaults to `False`, which means the submission overwrites
+            the model completion).
         **kwargs: Deprecated arguments for backward compatibility.

     Returns:
@@ -149,9 +153,14 @@ def basic_agent(
         return solve

     # helper to extract a submitted answer
-    def submission(tool_results: list[
+    def submission(tool_results: list[ChatMessage]) -> str | None:
         return next(
-            (
+            (
+                result.text
+                for result in tool_results
+                if isinstance(result, ChatMessageTool)
+                and result.function == submit_name
+            ),
             None,
         )

@@ -189,9 +198,9 @@ def basic_agent(

             # resolve tools calls (if any)
             if state.output.message.tool_calls:
-                #
-                tool_results = await
-                    state.output.message,
+                # execute tool functions
+                tool_results, _ = await execute_tools(
+                    [state.output.message],
                     state.tools,
                     max_output=max_tool_output,
                 )
@@ -200,8 +209,12 @@ def basic_agent(
                 # was an answer submitted?
                 answer = submission(tool_results)
                 if answer:
-
-
+                    if submit_append:
+                        state.output.completion = (
+                            f"{state.output.completion}\n\n{answer}".strip()
+                        )
+                    else:
+                        state.output.completion = answer

                 # exit if we are at max_attempts
                 attempts += 1
```
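A brief, hedged illustration of the new `submit_append` option (the tools and system prompt are arbitrary):

```python
from inspect_ai.solver import basic_agent, system_message
from inspect_ai.tool import bash

# with submit_append=True the submitted answer is appended to the model's
# final completion text rather than replacing it
solver = basic_agent(
    init=system_message("Solve the task, then call submit() with your answer."),
    tools=[bash()],
    submit_append=True,
)
```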
inspect_ai/solver/_bridge.py
ADDED
```diff
@@ -0,0 +1,31 @@
+from logging import getLogger
+from typing import Any, Awaitable, Callable
+
+from inspect_ai._util.logger import warn_once
+from inspect_ai.agent._as_solver import as_solver
+
+from ._solver import Solver, solver
+
+logger = getLogger(__name__)
+
+
+@solver
+def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Solver:
+    """Bridge an external agent into an Inspect Solver.
+
+    See documentation at <https://inspect.ai-safety-institute.org.uk/agent-bridge.html>
+
+    Args:
+        agent: Callable which takes a sample `dict` and returns a result `dict`.
+
+    Returns:
+        Standard Inspect solver.
+    """
+    from inspect_ai.agent._bridge.bridge import bridge as agent_bridge
+
+    warn_once(
+        logger,
+        "The bridge solver is deprecated. Please use the bridge agent from the agents module instead.",
+    )
+
+    return as_solver(agent_bridge(agent))
```
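The solver-level `bridge()` is now just a deprecation shim over the relocated agent bridge. A hedged migration sketch (the `inspect_ai.agent` export of `bridge` is assumed from the file list above; the external agent is illustrative):

```python
from typing import Any

async def my_external_agent(sample: dict[str, Any]) -> dict[str, Any]:
    # external agent protocol: take a sample dict, return a result dict
    return {"output": f"echo: {sample['input']}"}

# before (still works, but logs a deprecation warning):
from inspect_ai.solver import bridge as solver_bridge
legacy = solver_bridge(my_external_agent)

# after (assumed public export of the relocated bridge):
from inspect_ai.agent import bridge as agent_bridge
replacement = agent_bridge(my_external_agent)
```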
inspect_ai/solver/_chain.py
CHANGED
```diff
@@ -1,14 +1,19 @@
-from typing import Sequence, overload
+from typing import Sequence, cast, overload

 from typing_extensions import override

+from inspect_ai.agent._agent import Agent, is_agent
+from inspect_ai.agent._as_solver import as_solver
+
 from ._solver import Generate, Solver, solver
 from ._task_state import TaskState


 @solver
-def chain(
-
+def chain(
+    *solvers: Solver | Agent | list[Solver] | list[Solver | Agent],
+) -> Solver:
+    """Compose a solver from multiple other solvers and/or agents.

     Solvers are executed in turn, and a solver step event
     is added to the transcript for each. If a solver returns
@@ -16,10 +21,10 @@ def chain(*solvers: Solver | list[Solver]) -> Solver:
     early.

     Args:
-        *solvers: One or more solvers or
+        *solvers: One or more solvers or agents to chain together.

     Returns:
-        Solver that executes the passed solvers as a chain.
+        Solver that executes the passed solvers and agents as a chain.
     """
     # flatten lists and chains
     all_solvers: list[Solver] = []
@@ -29,17 +34,20 @@ def chain(*solvers: Solver | list[Solver]) -> Solver:
     return Chain(all_solvers)


-def unroll(
-
-
-
-    else:
-        return [solver]
-    else:
+def unroll(
+    solver: Solver | Agent | list[Solver] | list[Solver | Agent],
+) -> list[Solver]:
+    if isinstance(solver, list):
         unrolled: list[Solver] = []
         for s in solver:
             unrolled.extend(unroll(s))
         return unrolled
+    elif is_agent(solver):
+        return [as_solver(solver)]
+    elif isinstance(solver, Chain):
+        return unroll(solver._solvers)
+    else:
+        return [cast(Solver, solver)]


 class Chain(Sequence[Solver], Solver):
```
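Since `unroll()` now converts agents with `as_solver()`, `chain()` compositions can mix the two directly. An illustrative sketch (assumes the new `react` agent exported from `inspect_ai.agent`):

```python
from inspect_ai.agent import react
from inspect_ai.solver import chain, system_message

solver = chain(
    system_message("You are a careful assistant."),
    react(),  # an Agent: unroll() wraps it with as_solver() automatically
)
```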
inspect_ai/solver/_fork.py
CHANGED
```diff
@@ -52,7 +52,7 @@ async def fork(

 async def solver_subtask(state: TaskState, solver: Solver) -> TaskState:
     # get the generate function for the current task
-    generate =
+    generate = task_generate()
     if generate is None:
         raise RuntimeError("Called fork() outside of a running task.")

@@ -88,4 +88,8 @@ def set_task_generate(generate: Generate) -> None:
     _generate.set(generate)


+def task_generate() -> Generate | None:
+    return _generate.get(None)
+
+
 _generate: ContextVar[Generate] = ContextVar("_generate")
```