inspect-ai 0.3.93__py3-none-any.whl → 0.3.95__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- inspect_ai/_display/textual/widgets/samples.py +3 -3
- inspect_ai/_display/textual/widgets/transcript.py +3 -29
- inspect_ai/_eval/loader.py +1 -1
- inspect_ai/_eval/task/run.py +21 -12
- inspect_ai/_util/answer.py +26 -0
- inspect_ai/_util/constants.py +0 -1
- inspect_ai/_util/exception.py +4 -0
- inspect_ai/_util/hash.py +39 -0
- inspect_ai/_util/local_server.py +51 -21
- inspect_ai/_util/path.py +22 -0
- inspect_ai/_util/trace.py +1 -1
- inspect_ai/_util/working.py +4 -0
- inspect_ai/_view/www/dist/assets/index.css +23 -22
- inspect_ai/_view/www/dist/assets/index.js +517 -204
- inspect_ai/_view/www/log-schema.json +375 -0
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/@types/log.d.ts +90 -12
- inspect_ai/_view/www/src/app/log-view/navbar/SecondaryBar.tsx +2 -2
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +1 -4
- inspect_ai/_view/www/src/app/samples/SamplesTools.tsx +3 -13
- inspect_ai/_view/www/src/app/samples/sample-tools/SelectScorer.tsx +45 -48
- inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +16 -15
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +47 -75
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +9 -9
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
- inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
- inspect_ai/_view/www/src/app/types.ts +12 -2
- inspect_ai/_view/www/src/components/ExpandablePanel.module.css +1 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +5 -5
- inspect_ai/_view/www/src/state/hooks.ts +19 -3
- inspect_ai/_view/www/src/state/logSlice.ts +23 -5
- inspect_ai/_view/www/yarn.lock +9 -9
- inspect_ai/agent/_as_solver.py +3 -1
- inspect_ai/agent/_as_tool.py +6 -4
- inspect_ai/agent/_bridge/patch.py +1 -3
- inspect_ai/agent/_handoff.py +5 -1
- inspect_ai/agent/_react.py +4 -3
- inspect_ai/agent/_run.py +6 -1
- inspect_ai/agent/_types.py +9 -0
- inspect_ai/analysis/__init__.py +0 -0
- inspect_ai/analysis/beta/__init__.py +57 -0
- inspect_ai/analysis/beta/_dataframe/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/columns.py +145 -0
- inspect_ai/analysis/beta/_dataframe/evals/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/evals/columns.py +132 -0
- inspect_ai/analysis/beta/_dataframe/evals/extract.py +23 -0
- inspect_ai/analysis/beta/_dataframe/evals/table.py +140 -0
- inspect_ai/analysis/beta/_dataframe/events/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/events/columns.py +37 -0
- inspect_ai/analysis/beta/_dataframe/events/table.py +14 -0
- inspect_ai/analysis/beta/_dataframe/extract.py +54 -0
- inspect_ai/analysis/beta/_dataframe/messages/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/messages/columns.py +60 -0
- inspect_ai/analysis/beta/_dataframe/messages/extract.py +21 -0
- inspect_ai/analysis/beta/_dataframe/messages/table.py +87 -0
- inspect_ai/analysis/beta/_dataframe/record.py +377 -0
- inspect_ai/analysis/beta/_dataframe/samples/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/samples/columns.py +73 -0
- inspect_ai/analysis/beta/_dataframe/samples/extract.py +82 -0
- inspect_ai/analysis/beta/_dataframe/samples/table.py +329 -0
- inspect_ai/analysis/beta/_dataframe/util.py +157 -0
- inspect_ai/analysis/beta/_dataframe/validate.py +171 -0
- inspect_ai/dataset/_dataset.py +6 -3
- inspect_ai/log/__init__.py +10 -0
- inspect_ai/log/_convert.py +4 -9
- inspect_ai/log/_file.py +1 -1
- inspect_ai/log/_log.py +21 -1
- inspect_ai/log/_samples.py +14 -17
- inspect_ai/log/_transcript.py +77 -35
- inspect_ai/log/_tree.py +118 -0
- inspect_ai/model/_call_tools.py +44 -35
- inspect_ai/model/_model.py +51 -44
- inspect_ai/model/_openai_responses.py +17 -18
- inspect_ai/model/_providers/anthropic.py +30 -5
- inspect_ai/model/_providers/hf.py +27 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/sglang.py +8 -2
- inspect_ai/model/_providers/vllm.py +6 -2
- inspect_ai/scorer/_choice.py +1 -2
- inspect_ai/solver/_chain.py +1 -1
- inspect_ai/solver/_fork.py +1 -1
- inspect_ai/solver/_multiple_choice.py +9 -23
- inspect_ai/solver/_plan.py +2 -2
- inspect_ai/solver/_task_state.py +7 -3
- inspect_ai/solver/_transcript.py +6 -7
- inspect_ai/tool/_mcp/_context.py +3 -5
- inspect_ai/tool/_mcp/_mcp.py +6 -5
- inspect_ai/tool/_mcp/server.py +1 -1
- inspect_ai/tool/_tools/_execute.py +4 -1
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_search/__init__.py +3 -0
- inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py} +56 -103
- inspect_ai/tool/_tools/_web_search/_tavily.py +77 -0
- inspect_ai/tool/_tools/_web_search/_web_search.py +85 -0
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_anyio.py +11 -0
- inspect_ai/util/_collect.py +50 -0
- inspect_ai/util/_sandbox/events.py +3 -2
- inspect_ai/util/_span.py +58 -0
- inspect_ai/util/_subtask.py +27 -42
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/METADATA +8 -1
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/RECORD +114 -82
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/WHEEL +1 -1
- inspect_ai/_display/core/group.py +0 -79
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/hf.py
CHANGED

```diff
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import concurrent
 import concurrent.futures
 import copy
@@ -26,7 +28,12 @@ from transformers import (  # type: ignore
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai._util.content import
+from inspect_ai._util.content import (
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import ToolChoice, ToolInfo
 
@@ -85,6 +92,7 @@ class HuggingFaceAPI(ModelAPI):
         self.batch_size = collect_model_arg("batch_size")
         self.chat_template = collect_model_arg("chat_template")
         self.tokenizer_call_args = collect_model_arg("tokenizer_call_args")
+        self.enable_thinking = collect_model_arg("enable_thinking")
         if self.tokenizer_call_args is None:
             self.tokenizer_call_args = {}
 
@@ -263,6 +271,7 @@ class HuggingFaceAPI(ModelAPI):
         elif "qwen" in self.model_name.lower():
             hf_messages = inspect_tools_to_string(hf_messages)
 
+        hf_messages = message_content_to_string(hf_messages)
         # apply chat template
         if self.tokenizer.chat_template is not None:
             chat = self.tokenizer.apply_chat_template(
@@ -270,6 +279,7 @@ class HuggingFaceAPI(ModelAPI):
                 add_generation_prompt=True,
                 tokenize=False,
                 tools=tools_list if len(tools_list) > 0 else None,
+                enable_thinking=self.enable_thinking,  # not all models use this, check if it is supported
             )
         else:
            chat = ""
```
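The new `enable_thinking` flag is collected as a model arg and passed straight through to `tokenizer.apply_chat_template()`, so it can be toggled per eval. A minimal usage sketch (task file and model are illustrative; only chat templates that accept `enable_thinking`, such as Qwen3-style templates, will honor it):

```python
from inspect_ai import eval

# enable_thinking flows: model_args -> collect_model_arg("enable_thinking")
#   -> tokenizer.apply_chat_template(..., enable_thinking=...)
eval(
    "my_task.py",  # hypothetical task file
    model="hf/Qwen/Qwen3-4B",  # illustrative model
    model_args={"enable_thinking": False},
)
```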
```diff
@@ -279,6 +289,22 @@ class HuggingFaceAPI(ModelAPI):
         return cast(str, chat)
 
 
+def message_content_to_string(messages: list[ChatMessage]) -> list[ChatMessage]:
+    """Convert list of content in `ChatMessageAssistant`, `ChatMessageUser` or `ChatMessageSystem` to a string."""
+    for message in messages:
+        if isinstance(message.content, list):
+            is_multimodal = any(
+                isinstance(item, ContentAudio | ContentImage | ContentVideo)
+                for item in message.content
+            )
+            if is_multimodal:
+                raise NotImplementedError(
+                    "HuggingFace provider does not support multimodal content, please provide text inputs only."
+                )
+            message.content = message.text
+    return messages
+
+
 def shorten_tool_id(messages: list[ChatMessage]) -> list[ChatMessage]:
     """Shorten the tool_call_id in the messages to the last 9 characters for Mistral."""
     for i, message in enumerate(messages):
```
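`message.text` concatenates the textual content items of a message, so text-only list content is flattened to a single string before templating, while audio/image/video content now fails fast. A small sketch of the behavior (imports assume the public re-exports in `inspect_ai.model`):

```python
from inspect_ai.model import ChatMessageUser, ContentText

msg = ChatMessageUser(content=[ContentText(text="What is"), ContentText(text="2+2?")])
# message_content_to_string([msg]) rewrites msg.content from a list of
# ContentText items to msg.text (the joined plain-text string); passing
# ContentImage/ContentAudio/ContentVideo raises NotImplementedError instead
```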
inspect_ai/model/_providers/sglang.py
CHANGED

```diff
@@ -71,6 +71,7 @@ class SGLangAPI(OpenAICompatibleAPI):
             SGLANG_DEFAULT_SERVER_ARGS, server_args, logger
         )
 
+        self.server_found = True
        try:
            # Try to initialize with existing server
            super().__init__(
@@ -83,7 +84,9 @@ class SGLangAPI(OpenAICompatibleAPI):
             )
             logger.info(f"Using existing SGLang server at {self.base_url}")
         except PrerequisiteError:
-
+            self.server_found = False
+
+        if not self.server_found:
             logger.warning(
                 f"Existing SGLang server not found. Starting new server for {model_name}."
             )
@@ -125,7 +128,9 @@ class SGLangAPI(OpenAICompatibleAPI):
         api_key = "inspectai"  # Create a default API key if not provided
 
         # Handle device configuration
-        self.server_args = configure_devices(self.server_args, parallel_size_param="tp")
+        self.server_args, env_vars = configure_devices(
+            self.server_args, parallel_size_param="tp"
+        )
 
         timeout = self.server_args.pop("timeout", None)
         host = self.server_args.pop("host", "0.0.0.0")
@@ -149,6 +154,7 @@ class SGLangAPI(OpenAICompatibleAPI):
             server_type="SGLang",
             timeout=timeout,
             server_args=self.server_args,
+            env=env_vars,
         )
 
         # Register cleanup function to run when Python exits
```
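GPU selection and tensor parallelism for the auto-started server are driven by model args. A usage sketch (the `device` arg name is an assumption based on the "Handle device configuration" comment; `tp` is SGLang's tensor-parallel size as shown above):

```python
from inspect_ai import eval

# assumption: "device" selects GPUs for the spawned server, and "tp" is
# forwarded to SGLang as its tensor-parallel size (parallel_size_param="tp")
eval(
    "my_task.py",  # hypothetical task file
    model="sglang/Qwen/Qwen2.5-7B-Instruct",  # illustrative model
    model_args={"device": "0,1", "tp": 2},
)
```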
inspect_ai/model/_providers/vllm.py
CHANGED

```diff
@@ -76,6 +76,7 @@ class VLLMAPI(OpenAICompatibleAPI):
             VLLM_DEFAULT_SERVER_ARGS, server_args, logger
         )
 
+        self.server_found = True
         try:
             # Try to initialize with existing server
             super().__init__(
@@ -88,7 +89,9 @@ class VLLMAPI(OpenAICompatibleAPI):
             )
             logger.info(f"Using existing vLLM server at {self.base_url}")
         except PrerequisiteError:
-
+            self.server_found = False
+
+        if not self.server_found:
             logger.warning(
                 f"Existing vLLM server not found. Starting new server for {model_name}."
             )
@@ -131,7 +134,7 @@ class VLLMAPI(OpenAICompatibleAPI):
             raise pip_dependency_error("vLLM Server", ["vllm"])
 
         # Handle device configuration
-        self.server_args = configure_devices(
+        self.server_args, env_vars = configure_devices(
             self.server_args, parallel_size_param="tensor_parallel_size"
         )
 
@@ -152,6 +155,7 @@ class VLLMAPI(OpenAICompatibleAPI):
             server_type="vLLM",
             timeout=timeout,
             server_args=self.server_args,
+            env=env_vars,
         )
 
         # Register cleanup function to run when Python exits
```
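Both providers now receive an `env_vars` dict from `configure_devices` and hand it to the server launcher as `env=env_vars`, rather than relying on process-global environment state. A hypothetical sketch of that contract (the real helper presumably lives in `inspect_ai/_util/local_server.py`, +51/-21 above, and may differ in detail):

```python
from typing import Any

def configure_devices(
    server_args: dict[str, Any], parallel_size_param: str
) -> tuple[dict[str, Any], dict[str, str]]:
    """Hypothetical: pop device selection out of server_args and return it
    as env vars for the server subprocess instead of mutating os.environ."""
    env: dict[str, str] = {}
    device = server_args.pop("device", None)
    if device is not None:
        devices = [d.strip() for d in str(device).split(",")]
        env["CUDA_VISIBLE_DEVICES"] = ",".join(devices)
        # default parallelism to the number of visible devices
        server_args.setdefault(parallel_size_param, len(devices))
    return server_args, env
```

Scoping the variables to the subprocess keeps concurrent evals from clobbering each other's `CUDA_VISIBLE_DEVICES`.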
inspect_ai/scorer/_choice.py
CHANGED
inspect_ai/solver/_chain.py
CHANGED
```diff
@@ -82,7 +82,7 @@ class Chain(Sequence[Solver], Solver):
         from ._transcript import solver_transcript
 
         for slv in self._solvers:
-            with solver_transcript(slv, state) as st:
+            async with solver_transcript(slv, state) as st:
                 state = await slv(state, generate)
                 st.complete(state)
             if state.completed:
```
inspect_ai/solver/_fork.py
CHANGED
```diff
@@ -73,7 +73,7 @@ async def solver_subtask(state: TaskState, solver: Solver) -> TaskState:
     @subtask(name=name, store=state.store, type="fork", input=input)  # type: ignore
     async def solve() -> TaskState:
         if not isinstance(solver, Chain):
-            with solver_transcript(solver, state) as st:
+            async with solver_transcript(solver, state) as st:
                 new_state = await solver(state, generate)
                 st.complete(new_state)
                 return new_state
```
inspect_ai/solver/_multiple_choice.py
CHANGED

```diff
@@ -6,6 +6,7 @@ from typing import Match, TypedDict
 
 from typing_extensions import Unpack
 
+from inspect_ai._util.answer import answer_character, answer_index
 from inspect_ai._util.logger import warn_once
 from inspect_ai.util import resource
 
@@ -64,31 +65,13 @@ def answer_options(choices: Choices) -> str:
     indexes = list(range(len(choices)))
 
     return "\n".join(
-        [f"{
+        [f"{answer_character(i)}) {choices[j].value}" for i, j in enumerate(indexes)]
     )
 
 
-def answer_character(index: int) -> str:
-    r"""
-    Helper to go from array index to char, for example:
-
-        0 -> 'A', 1 -> 'B', etc
-    """
-    return chr(ord("A") + index)
-
-
-def answer_index(char: str) -> int:
-    r"""
-    Helper to go from char to array index, for example:
-
-        'A' -> 0, 'B' -> 1, etc
-    """
-    return ord(char.upper()) - ord("A")
-
-
 def prompt(question: str, choices: Choices, template: str) -> str:
     choices_text = answer_options(choices)
-    letters = ",".join(
+    letters = ",".join(answer_character(i) for i in range(len(choices)))
 
     return template.format(
         choices=choices_text,
```
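The removed helpers now live in the new `inspect_ai/_util/answer.py` (+26 lines above) and are imported at the top of this file. Since `parse_answers` below now also accepts digits (`[A-Za-z\d ,]`), the relocated helpers presumably extend labeling beyond 26 choices; a hypothetical sketch of that contract (the real module may differ):

```python
# hypothetical sketch: index<->label helpers that roll over to digit labels
# after 'Z', consistent with the regexes now accepting digits; the actual
# inspect_ai._util.answer implementation may differ
def answer_character(index: int) -> str:
    """0 -> 'A', 1 -> 'B', ... 25 -> 'Z', then digit labels."""
    if index < 26:
        return chr(ord("A") + index)
    return str(index - 26)

def answer_index(char: str) -> int:
    """'A' -> 0, 'B' -> 1, ... 'Z' -> 25, then digit labels."""
    if char.isdigit():
        return int(char) + 26
    return ord(char.upper()) - ord("A")
```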
```diff
@@ -112,7 +95,7 @@ def parse_answers(state: TaskState) -> Match[str] | None:
     # In this case, we're looking for a single line which contains the expected
     # ANSWER: B,C string with only whitespace after it
     match = re.search(
-        r"(?i)^ANSWER\s*:\s*([A-Za-z ,]+)\s*(?:$|\n)",
+        r"(?i)^ANSWER\s*:\s*([A-Za-z\d ,]+)\s*(?:$|\n)",
         state.output.completion,
         flags=re.MULTILINE,
     )
@@ -121,7 +104,7 @@ def parse_answers(state: TaskState) -> Match[str] | None:
     # version for backward compatibility
     if match is None:
         return re.search(
-            r"(?i)ANSWER\s*:\s*([A-Za-z ,]+)(?:[^\w]|\n|$)", state.output.completion
+            r"(?i)ANSWER\s*:\s*([A-Za-z\d ,]+)(?:[^\w]|\n|$)", state.output.completion
         )
     else:
         return match
@@ -217,6 +200,7 @@ def multiple_choice(
     template: str | None = None,
     cot: bool = False,
     multiple_correct: bool = False,
+    max_tokens: int | None = None,
     **kwargs: Unpack[DeprecatedArgs],
 ) -> Solver:
     """Multiple choice question solver. Formats a multiple choice question prompt, then calls `generate()`.
@@ -243,6 +227,8 @@ def multiple_choice(
         squares? A) 3, B) 4, C) 9" has multiple correct answers, B and C. Leave
         as `False` if there's exactly one correct answer from the choices
         available. NOTE: this has no effect if you provide a custom template.
+      max_tokens: Default `None`. Controls the number of tokens generated through the call
+        to generate().
       **kwargs (Any): Deprecated arguments for backward compatibility.
 
     #### Shuffling
@@ -299,7 +285,7 @@ def multiple_choice(
         template=str(template),
     )
 
-    state = await generate(state)
+    state = await generate(state, max_tokens=max_tokens)
 
     answers = parse_answers(state)
     if answers and answers.group(1):
```
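Usage sketch for the new `max_tokens` parameter, which caps generation for the multiple-choice turn without touching the task-wide `GenerateConfig` (the dataset is illustrative):

```python
from inspect_ai import Task, task
from inspect_ai.dataset import example_dataset
from inspect_ai.scorer import choice
from inspect_ai.solver import multiple_choice

@task
def mcq():
    return Task(
        dataset=example_dataset("theory_of_mind"),  # illustrative dataset
        # cot can produce long rationales; cap them for this solver only
        solver=multiple_choice(cot=True, max_tokens=1024),
        scorer=choice(),
    )
```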
inspect_ai/solver/_plan.py
CHANGED
```diff
@@ -102,7 +102,7 @@ class Plan(Solver):
         # execute steps
         for index, solver in enumerate(self.steps):
             # run solver
-            with solver_transcript(solver, state) as st:
+            async with solver_transcript(solver, state) as st:
                 state = await solver(state, generate)
                 st.complete(state)
 
@@ -113,7 +113,7 @@ class Plan(Solver):
 
         # execute finish
         if self.finish:
-            with solver_transcript(self.finish, state) as st:
+            async with solver_transcript(self.finish, state) as st:
                 state = await self.finish(state, generate)
                 st.complete(state)
             check_sample_interrupt()
```
inspect_ai/solver/_task_state.py
CHANGED
```diff
@@ -204,13 +204,17 @@ class TaskState:
         Convenience function for accessing the initial input from the `Sample` as a string.
 
         If the `input` is a `list[ChatMessage]`, this will return the text from
-        the first chat message
+        the last chat message
         """
         if isinstance(self._input, str):
             return self._input
         else:
             input = next(
-                (message.text for message in self._input if message.role == "user"),
+                (
+                    message.text
+                    for message in reversed(self._input)
+                    if message.role == "user"
+                ),
                 None,
             )
             if input:
@@ -231,7 +235,7 @@ class TaskState:
         write access to the user chat prompt. Raises an
         exception if there is no user prompt
         """
-        prompt = next((m for m in self.messages if m.role == "user"), None)
+        prompt = next((m for m in reversed(self.messages) if m.role == "user"), None)
         if prompt:
             return prompt
         else:
```
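Both accessors now scan from the end of the history, so in multi-turn states `user_prompt` resolves to the most recent user message rather than the first:

```python
# sketch: state.messages == [system, user("q1"), assistant("a1"), user("q2")]
# 0.3.93: state.user_prompt.text == "q1"
# 0.3.95: state.user_prompt.text == "q2"
state.user_prompt.content = "revised question"  # edits the latest user turn
```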
inspect_ai/solver/_transcript.py
CHANGED
```diff
@@ -1,8 +1,9 @@
 import contextlib
-from typing import Iterator
+from typing import AsyncIterator
 
 from inspect_ai._util.json import json_changes
 from inspect_ai._util.registry import registry_log_name
+from inspect_ai.util._span import span
 
 from ._solver import Solver
 from ._task_state import TaskState, state_jsonable
@@ -22,12 +23,10 @@ class SolverTranscript:
         transcript()._event(StateEvent(changes=changes))
 
 
-@contextlib.contextmanager
-def solver_transcript(
+@contextlib.asynccontextmanager
+async def solver_transcript(
     solver: Solver, state: TaskState, name: str | None = None
-) -> Iterator[SolverTranscript]:
-    from inspect_ai.log._transcript import transcript
-
+) -> AsyncIterator[SolverTranscript]:
     name = registry_log_name(name or solver)
-    with transcript().step(name=name, type="solver"):
+    async with span(name=name, type="solver"):
         yield SolverTranscript(name, state)
```
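Solver steps are now recorded as spans (see the new `inspect_ai/util/_span.py` and the span event plumbing in `log/_transcript.py` and `log/_tree.py` above). A usage sketch, assuming `span` and the companion `collect` are exported from `inspect_ai.util` as the `util/__init__.py` change suggests:

```python
from inspect_ai.util import collect, span

async def research_step() -> None:
    # group all transcript events raised in this block under a named span
    # (replaces the older transcript().step() context manager)
    async with span(name="research", type="agent"):
        ...

# collect() runs awaitables concurrently (gather-style), added alongside span
# results = await collect(research_step(), research_step())
```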
inspect_ai/tool/_mcp/_context.py
CHANGED
```diff
@@ -2,13 +2,11 @@ from contextlib import _AsyncGeneratorContextManager
 from typing import TypeAlias
 
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
-from mcp.types import (
-    JSONRPCMessage,
-)
+from mcp.shared.message import SessionMessage
 
 MCPServerContext: TypeAlias = _AsyncGeneratorContextManager[
     tuple[
-        MemoryObjectReceiveStream[JSONRPCMessage | Exception],
-        MemoryObjectSendStream[JSONRPCMessage],
+        MemoryObjectReceiveStream[SessionMessage | Exception],
+        MemoryObjectSendStream[SessionMessage],
     ],
 ]
```
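This tracks the MCP Python SDK's transport change: streams now carry `SessionMessage`, a wrapper that bundles the JSON-RPC payload with transport metadata (the METADATA change above presumably bumps the `mcp` version floor to match). A minimal compatibility sketch, assuming an `mcp` release that ships `SessionMessage`:

```python
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage

def unwrap(session_message: SessionMessage) -> JSONRPCMessage:
    # the JSON-RPC payload rides on the .message attribute
    return session_message.message
```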
inspect_ai/tool/_mcp/_mcp.py
CHANGED
```diff
@@ -61,16 +61,17 @@ class MCPServerImpl(MCPServer):
     ) -> list[Tool]:
         return await self._task_session()._list_tools(tools)
 
-    # create a separate MCPServer session per async task
-    _task_sessions: dict[int, "MCPServerSession"] = {}
+    # create a separate MCPServer session per async task / server name
+    _task_sessions: dict[str, "MCPServerSession"] = {}
 
     def _task_session(self) -> "MCPServerSession":
         task_id = anyio.get_current_task().id
-        if task_id not in self._task_sessions:
-            MCPServerImpl._task_sessions[task_id] = MCPServerSession(
+        session_key = f"{task_id}_{self._name}"
+        if session_key not in self._task_sessions:
+            MCPServerImpl._task_sessions[session_key] = MCPServerSession(
                 self._client, name=self._name, events=self._events
             )
-        return MCPServerImpl._task_sessions[task_id]
+        return MCPServerImpl._task_sessions[session_key]
 
 
 class MCPServerSession(MCPServer):
```
inspect_ai/tool/_mcp/server.py
CHANGED

inspect_ai/tool/_tools/_execute.py
CHANGED
```diff
@@ -96,7 +96,10 @@ def python(
         The output of the Python code.
         """
         result = await sandbox_env(sandbox).exec(
-            cmd=["python3"], input=code, timeout=timeout, user=user
+            cmd=["bash", "--login", "-c", "python3 -"],
+            input=code,
+            timeout=timeout,
+            user=user,
         )
         # return output (including stderr if any)
         output = ""
```
inspect_ai/tool/_tools/_think.py
CHANGED
```diff
@@ -41,7 +41,7 @@ def think(
 def think_tool_viewer() -> ToolCallViewer:
     def viewer(tool_call: ToolCall) -> ToolCallView:
         call = ToolCallContent(
-            format="markdown", content=tool_call.arguments
+            format="markdown", content=tool_call.arguments.get("thought", "")
         )
         return ToolCallView(call=call)
```
inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py}
CHANGED

```diff
@@ -1,5 +1,5 @@
 import os
-from typing import
+from typing import Awaitable, Callable
 
 import anyio
 import httpx
@@ -16,8 +16,6 @@ from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
 from inspect_ai.util._concurrency import concurrency
 
-from .._tool import Tool, ToolResult, tool
-
 DEFAULT_RELEVANCE_PROMPT = """I am trying to answer the following question and need to find the most relevant information on the web. Please let me know if the following content is relevant to the question or not. You should just respond with "yes" or "no".
 
 Question: {question}
@@ -31,59 +29,35 @@ class SearchLink:
         self.snippet = snippet
 
 
-
-
-
-
-
-@tool
-def web_search(
-    provider: Literal["google"] = "google",
-    num_results: int = 3,
-    max_provider_calls: int = 3,
-    max_connections: int = 10,
-    model: str | None = None,
-) -> Tool:
-    """Web search tool.
-
-    A tool that can be registered for use by models to search the web. Use
-    the `use_tools()` solver to make the tool available (e.g. `use_tools(web_search())`))
-
-    A web search is conducted using the specified provider, the results are parsed for relevance
-    using the specified model, and the top 'num_results' relevant pages are returned.
-
-    See further documentation at <https://inspect.aisi.org.uk/tools-standard.html#sec-web-search>.
-
-    Args:
-      provider: Search provider (defaults to "google", currently
-         the only provider). Possible future providers include "brave" and "bing".
-      num_results: Number of web search result pages to return to the model.
-      max_provider_calls: Maximum number of search calls to make to the search provider.
-      max_connections: Maximum number of concurrent connections to API
-         endpoint of search provider.
-      model: Model used to parse web pages for relevance.
+def maybe_get_google_api_keys() -> tuple[str, str] | None:
+    """
+    Get Google API keys from environment variables.
 
     Returns:
-
+        tuple: A tuple containing the Google API key and the Google CSE ID.
     """
-
-
+    google_api_key = os.environ.get("GOOGLE_CSE_API_KEY", None)
+    google_cse_id = os.environ.get("GOOGLE_CSE_ID", None)
+    return (google_api_key, google_cse_id) if google_api_key and google_cse_id else None
 
-
-
-
-
-
+
+def google_search_provider(
+    num_results: int,
+    max_provider_calls: int,
+    max_connections: int,
+    model: str | None,
+) -> Callable[[str], Awaitable[str | None]]:
+    keys = maybe_get_google_api_keys()
+    if not keys:
+        raise PrerequisiteError(
+            "GOOGLE_CSE_ID and/or GOOGLE_CSE_API_KEY not set in the environment. Please ensure these variables are defined to use Google Custom Search with the web_search tool.\n\nLearn more about the Google web search provider at https://inspect.aisi.org.uk/tools.html#google-provider"
         )
+    google_api_key, google_cse_id = keys
 
-    #
-
-    """
-    Use the web_search tool to perform keyword searches of the web.
+    # Create the client within the provider
+    client = httpx.AsyncClient()
 
-
-        query (str): Search query.
-    """
+    async def search(query: str) -> str | None:
         # limit number of concurrent searches
         page_contents: list[str] = []
         urls: list[str] = []
@@ -92,8 +66,8 @@ def web_search(
 
         # Paginate through search results until we have successfully extracted num_results pages or we have reached max_provider_calls
         while len(page_contents) < num_results and search_calls < max_provider_calls:
-            async with concurrency(
-                links = await
+            async with concurrency("google_web_search", max_connections):
+                links = await _search(query, start_idx=search_calls * 10)
 
             async with anyio.create_task_group() as tg:
 
@@ -114,19 +88,39 @@ def web_search(
             search_calls += 1
 
         all_page_contents = "\n\n".join(page_contents)
-        if all_page_contents == "":
-            response: ToolResult = (
-                "I'm sorry, I couldn't find any relevant information on the web."
-            )
-        else:
-            response = (
-                "Here are your web search results. Please read them carefully as they may be useful later! "
-                + all_page_contents
-            )
+        return None if all_page_contents == "" else all_page_contents
 
-        return response
+    async def _search(query: str, start_idx: int) -> list[SearchLink]:
+        # List of allowed parameters can be found https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
+        search_params = {
+            "q": query,
+            "key": google_api_key,
+            "cx": google_cse_id,
+            "start": start_idx,
+        }
+        search_url = "https://www.googleapis.com/customsearch/v1?" + "&".join(
+            [f"{key}={value}" for key, value in search_params.items()]
+        )
 
-
+        # retry up to 5 times over a period of up to 1 minute
+        @retry(
+            wait=wait_exponential_jitter(),
+            stop=stop_after_attempt(5) | stop_after_delay(60),
+            retry=retry_if_exception(httpx_should_retry),
+            before_sleep=log_httpx_retry_attempt(search_url),
+        )
+        async def execute_search() -> httpx.Response:
+            return await client.get(search_url)
+
+        result = await execute_search()
+        data = result.json()
+
+        if "items" in data:
+            return [SearchLink(item["link"], item["snippet"]) for item in data["items"]]
+        else:
+            return []
+
+    return search
 
 
 async def page_if_relevant(
@@ -183,44 +177,3 @@ async def page_if_relevant(
         return full_text
     else:
         return None
-
-
-def google_search_provider(client: httpx.AsyncClient) -> SearchProvider:
-    google_api_key = os.environ.get("GOOGLE_CSE_API_KEY", None)
-    google_cse_id = os.environ.get("GOOGLE_CSE_ID", None)
-    if not google_api_key or not google_cse_id:
-        raise PrerequisiteError(
-            "GOOGLE_CSE_ID and/or GOOGLE_CSE_API_KEY not set in the environment. Please ensure these variables are defined to use Google Custom Search with the web_search tool.\n\nLearn more about the Google web search provider at https://inspect.aisi.org.uk/tools.html#google-provider"
-        )
-
-    async def search(query: str, start_idx: int) -> list[SearchLink]:
-        # List of allowed parameters can be found https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
-        search_params = {
-            "q": query,
-            "key": google_api_key,
-            "cx": google_cse_id,
-            "start": start_idx,
-        }
-        search_url = "https://www.googleapis.com/customsearch/v1?" + "&".join(
-            [f"{key}={value}" for key, value in search_params.items()]
-        )
-
-        # retry up to 5 times over a period of up to 1 minute
-        @retry(
-            wait=wait_exponential_jitter(),
-            stop=stop_after_attempt(5) | stop_after_delay(60),
-            retry=retry_if_exception(httpx_should_retry),
-            before_sleep=log_httpx_retry_attempt(search_url),
-        )
-        async def execute_search() -> httpx.Response:
-            return await client.get(search_url)
-
-        result = await execute_search()
-        data = result.json()
-
-        if "items" in data:
-            return [SearchLink(item["link"], item["snippet"]) for item in data["items"]]
-        else:
-            return []
-
-    return search
```
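Per the file list, `_web_search.py` became a package: the Google internals above moved to `_web_search/_google.py` behind a provider-returning factory, with a new `_tavily.py` provider and a dispatching `_web_search/_web_search.py`. A usage sketch (the `provider` argument shape is assumed from the package layout; Tavily presumably needs a `TAVILY_API_KEY` environment variable):

```python
from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.scorer import model_graded_fact
from inspect_ai.solver import generate, use_tools
from inspect_ai.tool import web_search

@task
def research():
    return Task(
        dataset=[Sample(input="What is the tallest building in Europe?")],
        solver=[
            # google remains the default and still requires GOOGLE_CSE_API_KEY /
            # GOOGLE_CSE_ID; "tavily" is assumed here based on the new
            # _tavily.py provider module
            use_tools(web_search(provider="tavily")),
            generate(),
        ],
        scorer=model_graded_fact(),
    )
```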