inspect-ai 0.3.93__py3-none-any.whl → 0.3.95__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (115)
  1. inspect_ai/_display/textual/widgets/samples.py +3 -3
  2. inspect_ai/_display/textual/widgets/transcript.py +3 -29
  3. inspect_ai/_eval/loader.py +1 -1
  4. inspect_ai/_eval/task/run.py +21 -12
  5. inspect_ai/_util/answer.py +26 -0
  6. inspect_ai/_util/constants.py +0 -1
  7. inspect_ai/_util/exception.py +4 -0
  8. inspect_ai/_util/hash.py +39 -0
  9. inspect_ai/_util/local_server.py +51 -21
  10. inspect_ai/_util/path.py +22 -0
  11. inspect_ai/_util/trace.py +1 -1
  12. inspect_ai/_util/working.py +4 -0
  13. inspect_ai/_view/www/dist/assets/index.css +23 -22
  14. inspect_ai/_view/www/dist/assets/index.js +517 -204
  15. inspect_ai/_view/www/log-schema.json +375 -0
  16. inspect_ai/_view/www/package.json +1 -1
  17. inspect_ai/_view/www/src/@types/log.d.ts +90 -12
  18. inspect_ai/_view/www/src/app/log-view/navbar/SecondaryBar.tsx +2 -2
  19. inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +1 -4
  20. inspect_ai/_view/www/src/app/samples/SamplesTools.tsx +3 -13
  21. inspect_ai/_view/www/src/app/samples/sample-tools/SelectScorer.tsx +45 -48
  22. inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +16 -15
  23. inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +47 -75
  24. inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +9 -9
  25. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
  26. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
  27. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
  28. inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
  29. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
  30. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
  31. inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
  32. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
  33. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
  34. inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
  35. inspect_ai/_view/www/src/app/types.ts +12 -2
  36. inspect_ai/_view/www/src/components/ExpandablePanel.module.css +1 -1
  37. inspect_ai/_view/www/src/components/ExpandablePanel.tsx +5 -5
  38. inspect_ai/_view/www/src/state/hooks.ts +19 -3
  39. inspect_ai/_view/www/src/state/logSlice.ts +23 -5
  40. inspect_ai/_view/www/yarn.lock +9 -9
  41. inspect_ai/agent/_as_solver.py +3 -1
  42. inspect_ai/agent/_as_tool.py +6 -4
  43. inspect_ai/agent/_bridge/patch.py +1 -3
  44. inspect_ai/agent/_handoff.py +5 -1
  45. inspect_ai/agent/_react.py +4 -3
  46. inspect_ai/agent/_run.py +6 -1
  47. inspect_ai/agent/_types.py +9 -0
  48. inspect_ai/analysis/__init__.py +0 -0
  49. inspect_ai/analysis/beta/__init__.py +57 -0
  50. inspect_ai/analysis/beta/_dataframe/__init__.py +0 -0
  51. inspect_ai/analysis/beta/_dataframe/columns.py +145 -0
  52. inspect_ai/analysis/beta/_dataframe/evals/__init__.py +0 -0
  53. inspect_ai/analysis/beta/_dataframe/evals/columns.py +132 -0
  54. inspect_ai/analysis/beta/_dataframe/evals/extract.py +23 -0
  55. inspect_ai/analysis/beta/_dataframe/evals/table.py +140 -0
  56. inspect_ai/analysis/beta/_dataframe/events/__init__.py +0 -0
  57. inspect_ai/analysis/beta/_dataframe/events/columns.py +37 -0
  58. inspect_ai/analysis/beta/_dataframe/events/table.py +14 -0
  59. inspect_ai/analysis/beta/_dataframe/extract.py +54 -0
  60. inspect_ai/analysis/beta/_dataframe/messages/__init__.py +0 -0
  61. inspect_ai/analysis/beta/_dataframe/messages/columns.py +60 -0
  62. inspect_ai/analysis/beta/_dataframe/messages/extract.py +21 -0
  63. inspect_ai/analysis/beta/_dataframe/messages/table.py +87 -0
  64. inspect_ai/analysis/beta/_dataframe/record.py +377 -0
  65. inspect_ai/analysis/beta/_dataframe/samples/__init__.py +0 -0
  66. inspect_ai/analysis/beta/_dataframe/samples/columns.py +73 -0
  67. inspect_ai/analysis/beta/_dataframe/samples/extract.py +82 -0
  68. inspect_ai/analysis/beta/_dataframe/samples/table.py +329 -0
  69. inspect_ai/analysis/beta/_dataframe/util.py +157 -0
  70. inspect_ai/analysis/beta/_dataframe/validate.py +171 -0
  71. inspect_ai/dataset/_dataset.py +6 -3
  72. inspect_ai/log/__init__.py +10 -0
  73. inspect_ai/log/_convert.py +4 -9
  74. inspect_ai/log/_file.py +1 -1
  75. inspect_ai/log/_log.py +21 -1
  76. inspect_ai/log/_samples.py +14 -17
  77. inspect_ai/log/_transcript.py +77 -35
  78. inspect_ai/log/_tree.py +118 -0
  79. inspect_ai/model/_call_tools.py +44 -35
  80. inspect_ai/model/_model.py +51 -44
  81. inspect_ai/model/_openai_responses.py +17 -18
  82. inspect_ai/model/_providers/anthropic.py +30 -5
  83. inspect_ai/model/_providers/hf.py +27 -1
  84. inspect_ai/model/_providers/providers.py +1 -1
  85. inspect_ai/model/_providers/sglang.py +8 -2
  86. inspect_ai/model/_providers/vllm.py +6 -2
  87. inspect_ai/scorer/_choice.py +1 -2
  88. inspect_ai/solver/_chain.py +1 -1
  89. inspect_ai/solver/_fork.py +1 -1
  90. inspect_ai/solver/_multiple_choice.py +9 -23
  91. inspect_ai/solver/_plan.py +2 -2
  92. inspect_ai/solver/_task_state.py +7 -3
  93. inspect_ai/solver/_transcript.py +6 -7
  94. inspect_ai/tool/_mcp/_context.py +3 -5
  95. inspect_ai/tool/_mcp/_mcp.py +6 -5
  96. inspect_ai/tool/_mcp/server.py +1 -1
  97. inspect_ai/tool/_tools/_execute.py +4 -1
  98. inspect_ai/tool/_tools/_think.py +1 -1
  99. inspect_ai/tool/_tools/_web_search/__init__.py +3 -0
  100. inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py} +56 -103
  101. inspect_ai/tool/_tools/_web_search/_tavily.py +77 -0
  102. inspect_ai/tool/_tools/_web_search/_web_search.py +85 -0
  103. inspect_ai/util/__init__.py +4 -0
  104. inspect_ai/util/_anyio.py +11 -0
  105. inspect_ai/util/_collect.py +50 -0
  106. inspect_ai/util/_sandbox/events.py +3 -2
  107. inspect_ai/util/_span.py +58 -0
  108. inspect_ai/util/_subtask.py +27 -42
  109. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/METADATA +8 -1
  110. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/RECORD +114 -82
  111. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/WHEEL +1 -1
  112. inspect_ai/_display/core/group.py +0 -79
  113. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/entry_points.txt +0 -0
  114. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/licenses/LICENSE +0 -0
  115. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/hf.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import concurrent
 import concurrent.futures
 import copy
@@ -26,7 +28,12 @@ from transformers import ( # type: ignore
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai._util.content import ContentText
+from inspect_ai._util.content import (
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import ToolChoice, ToolInfo
 
@@ -85,6 +92,7 @@ class HuggingFaceAPI(ModelAPI):
         self.batch_size = collect_model_arg("batch_size")
         self.chat_template = collect_model_arg("chat_template")
         self.tokenizer_call_args = collect_model_arg("tokenizer_call_args")
+        self.enable_thinking = collect_model_arg("enable_thinking")
         if self.tokenizer_call_args is None:
             self.tokenizer_call_args = {}
 
@@ -263,6 +271,7 @@
         elif "qwen" in self.model_name.lower():
             hf_messages = inspect_tools_to_string(hf_messages)
 
+        hf_messages = message_content_to_string(hf_messages)
         # apply chat template
         if self.tokenizer.chat_template is not None:
             chat = self.tokenizer.apply_chat_template(
@@ -270,6 +279,7 @@
                 add_generation_prompt=True,
                 tokenize=False,
                 tools=tools_list if len(tools_list) > 0 else None,
+                enable_thinking=self.enable_thinking,  # not all models use this, check if it is supported
             )
         else:
             chat = ""
@@ -279,6 +289,22 @@
         return cast(str, chat)
 
 
+def message_content_to_string(messages: list[ChatMessage]) -> list[ChatMessage]:
+    """Convert list of content in `ChatMessageAssistant`, `ChatMessageUser` or `ChatMessageSystem` to a string."""
+    for message in messages:
+        if isinstance(message.content, list):
+            is_multimodal = any(
+                isinstance(item, ContentAudio | ContentImage | ContentVideo)
+                for item in message.content
+            )
+            if is_multimodal:
+                raise NotImplementedError(
+                    "HuggingFace provider does not support multimodal content, please provide text inputs only."
+                )
+            message.content = message.text
+    return messages
+
+
 def shorten_tool_id(messages: list[ChatMessage]) -> list[ChatMessage]:
     """Shorten the tool_call_id in the messages to the last 9 characters for Mistral."""
     for i, message in enumerate(messages):
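The new `enable_thinking` model arg is collected like the other HuggingFace model args and forwarded to `apply_chat_template`, which only some chat templates honor (e.g. Qwen3-style ones, per the inline comment). A usage sketch; the model name is an illustrative example:

    from inspect_ai.model import get_model

    # model args pass through get_model() kwargs (equivalent to
    # -M enable_thinking=false on the CLI); hypothetical model name
    model = get_model("hf/Qwen/Qwen3-0.6B", enable_thinking=False)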
inspect_ai/model/_providers/providers.py
@@ -281,7 +281,7 @@ def none() -> type[ModelAPI]:
 def validate_openai_client(feature: str) -> None:
     FEATURE = feature
     PACKAGE = "openai"
-    MIN_VERSION = "1.75.0"
+    MIN_VERSION = "1.78.0"
 
     # verify we have the package
     try:
inspect_ai/model/_providers/sglang.py
@@ -71,6 +71,7 @@ class SGLangAPI(OpenAICompatibleAPI):
             SGLANG_DEFAULT_SERVER_ARGS, server_args, logger
         )
 
+        self.server_found = True
         try:
             # Try to initialize with existing server
             super().__init__(
@@ -83,7 +84,9 @@
             )
             logger.info(f"Using existing SGLang server at {self.base_url}")
         except PrerequisiteError:
-            # No existing server found, start a new one
+            self.server_found = False
+
+        if not self.server_found:
             logger.warning(
                 f"Existing SGLang server not found. Starting new server for {model_name}."
             )
@@ -125,7 +128,9 @@
             api_key = "inspectai"  # Create a default API key if not provided
 
         # Handle device configuration
-        self.server_args = configure_devices(self.server_args, parallel_size_param="tp")
+        self.server_args, env_vars = configure_devices(
+            self.server_args, parallel_size_param="tp"
+        )
 
         timeout = self.server_args.pop("timeout", None)
         host = self.server_args.pop("host", "0.0.0.0")
@@ -149,6 +154,7 @@
             server_type="SGLang",
             timeout=timeout,
             server_args=self.server_args,
+            env=env_vars,
         )
 
         # Register cleanup function to run when Python exits
inspect_ai/model/_providers/vllm.py
@@ -76,6 +76,7 @@ class VLLMAPI(OpenAICompatibleAPI):
             VLLM_DEFAULT_SERVER_ARGS, server_args, logger
         )
 
+        self.server_found = True
         try:
             # Try to initialize with existing server
             super().__init__(
@@ -88,7 +89,9 @@
             )
             logger.info(f"Using existing vLLM server at {self.base_url}")
         except PrerequisiteError:
-            # No existing server found, start a new one
+            self.server_found = False
+
+        if not self.server_found:
             logger.warning(
                 f"Existing vLLM server not found. Starting new server for {model_name}."
             )
@@ -131,7 +134,7 @@
             raise pip_dependency_error("vLLM Server", ["vllm"])
 
         # Handle device configuration
-        self.server_args = configure_devices(
+        self.server_args, env_vars = configure_devices(
             self.server_args, parallel_size_param="tensor_parallel_size"
         )
 
@@ -152,6 +155,7 @@
             server_type="vLLM",
             timeout=timeout,
             server_args=self.server_args,
+            env=env_vars,
         )
 
         # Register cleanup function to run when Python exits
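Both providers now unpack a `(server_args, env_vars)` tuple from `configure_devices()` and forward the environment to the launched server. A sketch of the implied contract; the import path and the `CUDA_VISIBLE_DEVICES` example are assumptions, since `configure_devices()` itself is not shown in this diff:

    # assumed location, alongside the launcher in inspect_ai/_util/local_server.py
    from inspect_ai._util.local_server import configure_devices

    server_args = {"device": "0,1", "tensor_parallel_size": 2}
    server_args, env_vars = configure_devices(
        server_args, parallel_size_param="tensor_parallel_size"
    )
    # plausibly env_vars == {"CUDA_VISIBLE_DEVICES": "0,1"}: device selection
    # moves out of the server args and into the subprocess environment, which
    # is why the launcher gains the new env= parameter above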
inspect_ai/scorer/_choice.py
@@ -1,6 +1,5 @@
+from inspect_ai._util.answer import answer_character, answer_index
 from inspect_ai.solver._multiple_choice import (
-    answer_character,
-    answer_index,
     answer_options,
     unshuffle_choices,
 )
inspect_ai/solver/_chain.py
@@ -82,7 +82,7 @@ class Chain(Sequence[Solver], Solver):
         from ._transcript import solver_transcript
 
         for slv in self._solvers:
-            with solver_transcript(slv, state) as st:
+            async with solver_transcript(slv, state) as st:
                 state = await slv(state, generate)
                 st.complete(state)
             if state.completed:
inspect_ai/solver/_fork.py
@@ -73,7 +73,7 @@ async def solver_subtask(state: TaskState, solver: Solver) -> TaskState:
     @subtask(name=name, store=state.store, type="fork", input=input)  # type: ignore
     async def solve() -> TaskState:
         if not isinstance(solver, Chain):
-            with solver_transcript(solver, state) as st:
+            async with solver_transcript(solver, state) as st:
                 new_state = await solver(state, generate)
                 st.complete(new_state)
                 return new_state
inspect_ai/solver/_multiple_choice.py
@@ -6,6 +6,7 @@ from typing import Match, TypedDict
 
 from typing_extensions import Unpack
 
+from inspect_ai._util.answer import answer_character, answer_index
 from inspect_ai._util.logger import warn_once
 from inspect_ai.util import resource
 
@@ -64,31 +65,13 @@ def answer_options(choices: Choices) -> str:
     indexes = list(range(len(choices)))
 
     return "\n".join(
-        [f"{chr(65 + i)}) {choices[j].value}" for i, j in enumerate(indexes)]
+        [f"{answer_character(i)}) {choices[j].value}" for i, j in enumerate(indexes)]
     )
 
 
-def answer_character(index: int) -> str:
-    r"""
-    Helper to go from array index to char, for example:
-
-        0 -> 'A', 1 -> 'B', etc
-    """
-    return chr(ord("A") + index)
-
-
-def answer_index(char: str) -> int:
-    r"""
-    Helper to go from char to array index, for example:
-
-        'A' -> 0, 'B' -> 1, etc
-    """
-    return ord(char.upper()) - ord("A")
-
-
 def prompt(question: str, choices: Choices, template: str) -> str:
     choices_text = answer_options(choices)
-    letters = ",".join(chr(65 + i) for i in range(len(choices)))
+    letters = ",".join(answer_character(i) for i in range(len(choices)))
 
     return template.format(
         choices=choices_text,
@@ -112,7 +95,7 @@ def parse_answers(state: TaskState) -> Match[str] | None:
     # In this case, we're looking for a single line which contains the expected
     # ANSWER: B,C string with only whitespace after it
     match = re.search(
-        r"(?i)^ANSWER\s*:\s*([A-Za-z ,]+)\s*(?:$|\n)",
+        r"(?i)^ANSWER\s*:\s*([A-Za-z\d ,]+)\s*(?:$|\n)",
        state.output.completion,
        flags=re.MULTILINE,
    )
@@ -121,7 +104,7 @@
     # version for backward compatibility
     if match is None:
         return re.search(
-            r"(?i)ANSWER\s*:\s*([A-Za-z ,]+)(?:[^\w]|\n|$)", state.output.completion
+            r"(?i)ANSWER\s*:\s*([A-Za-z\d ,]+)(?:[^\w]|\n|$)", state.output.completion
         )
     else:
         return match
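Adding `\d` to both regexes lines up with the relocation of `answer_character`/`answer_index` into the new `inspect_ai/_util/answer.py` (+26 lines in the file list): answers apparently may now run past 'Z' into digits when there are more than 26 choices. Only the letter branch is confirmed by the code removed above; a sketch of the relocated helpers under that assumption:

    def answer_character(index: int) -> str:
        """0 -> 'A', 1 -> 'B', ...; continuation past 'Z' as digits is assumed."""
        if index < 26:
            return chr(ord("A") + index)
        return str(index - 26)  # assumed: 26 -> '0', 27 -> '1', ...

    def answer_index(char: str) -> int:
        """'A' -> 0, 'B' -> 1, ...; inverse of the assumed digit branch."""
        if char.isalpha():
            return ord(char.upper()) - ord("A")
        return int(char) + 26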
inspect_ai/solver/_multiple_choice.py (continued)
@@ -217,6 +200,7 @@ def multiple_choice(
     template: str | None = None,
     cot: bool = False,
     multiple_correct: bool = False,
+    max_tokens: int | None = None,
     **kwargs: Unpack[DeprecatedArgs],
 ) -> Solver:
     """Multiple choice question solver. Formats a multiple choice question prompt, then calls `generate()`.
@@ -243,6 +227,8 @@
            squares? A) 3, B) 4, C) 9" has multiple correct answers, B and C. Leave
            as `False` if there's exactly one correct answer from the choices
            available. NOTE: this has no effect if you provide a custom template.
+        max_tokens: Default `None`. Controls the number of tokens generated through the call
+            to generate().
         **kwargs (Any): Deprecated arguments for backward compatibility.
 
     #### Shuffling
@@ -299,7 +285,7 @@
             template=str(template),
         )
 
-        state = await generate(state)
+        state = await generate(state, max_tokens=max_tokens)
 
         answers = parse_answers(state)
         if answers and answers.group(1):
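A usage sketch for the new parameter (the task and dataset are illustrative):

    from inspect_ai import Task, task
    from inspect_ai.dataset import MemoryDataset, Sample
    from inspect_ai.scorer import choice
    from inspect_ai.solver import multiple_choice

    @task
    def mc_demo() -> Task:
        return Task(
            dataset=MemoryDataset(
                [
                    Sample(
                        input="Which of these are perfect squares?",
                        choices=["3", "4", "9"],
                        target=["B", "C"],
                    )
                ]
            ),
            # cap generation so chain-of-thought answers can't run unbounded
            solver=multiple_choice(multiple_correct=True, max_tokens=1024),
            scorer=choice(),
        )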
inspect_ai/solver/_plan.py
@@ -102,7 +102,7 @@ class Plan(Solver):
         # execute steps
         for index, solver in enumerate(self.steps):
             # run solver
-            with solver_transcript(solver, state) as st:
+            async with solver_transcript(solver, state) as st:
                 state = await solver(state, generate)
                 st.complete(state)
 
@@ -113,7 +113,7 @@
 
         # execute finish
         if self.finish:
-            with solver_transcript(self.finish, state) as st:
+            async with solver_transcript(self.finish, state) as st:
                 state = await self.finish(state, generate)
                 st.complete(state)
         check_sample_interrupt()
inspect_ai/solver/_task_state.py
@@ -204,13 +204,17 @@ class TaskState:
         Convenience function for accessing the initial input from the `Sample` as a string.
 
         If the `input` is a `list[ChatMessage]`, this will return the text from
-        the first chat message
+        the last chat message
         """
         if isinstance(self._input, str):
             return self._input
         else:
             input = next(
-                (message.text for message in self._input if message.role == "user"),
+                (
+                    message.text
+                    for message in reversed(self._input)
+                    if message.role == "user"
+                ),
                 None,
             )
             if input:
@@ -231,7 +235,7 @@
         write access to the user chat prompt. Raises an
         exception if there is no user prompt
         """
-        prompt = next((m for m in self.messages if m.role == "user"), None)
+        prompt = next((m for m in reversed(self.messages) if m.role == "user"), None)
         if prompt:
             return prompt
         else:
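Both accessors now scan the message list from the end rather than the beginning, which changes behavior for multi-turn inputs. A small illustration of the new semantics:

    from inspect_ai.model import ChatMessageAssistant, ChatMessageUser

    messages = [
        ChatMessageUser(content="first question"),
        ChatMessageAssistant(content="an answer"),
        ChatMessageUser(content="follow-up question"),
    ]
    # input_text and user_prompt previously resolved to the first user
    # message; with this change they resolve to the last one
    last_user = next(m for m in reversed(messages) if m.role == "user")
    assert last_user.text == "follow-up question"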
inspect_ai/solver/_transcript.py
@@ -1,8 +1,9 @@
 import contextlib
-from typing import Iterator
+from typing import AsyncIterator
 
 from inspect_ai._util.json import json_changes
 from inspect_ai._util.registry import registry_log_name
+from inspect_ai.util._span import span
 
 from ._solver import Solver
 from ._task_state import TaskState, state_jsonable
@@ -22,12 +23,10 @@ class SolverTranscript:
         transcript()._event(StateEvent(changes=changes))
 
 
-@contextlib.contextmanager
-def solver_transcript(
+@contextlib.asynccontextmanager
+async def solver_transcript(
     solver: Solver, state: TaskState, name: str | None = None
-) -> Iterator[SolverTranscript]:
-    from inspect_ai.log._transcript import transcript
-
+) -> AsyncIterator[SolverTranscript]:
     name = registry_log_name(name or solver)
-    with transcript().step(name=name, type="solver"):
+    async with span(name=name, type="solver"):
         yield SolverTranscript(name, state)
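`span()` comes from the new `inspect_ai/util/_span.py` module added in this release. Assuming it is also re-exported from `inspect_ai.util` (that package's `__init__.py` gains four lines), code that previously grouped transcript events with `transcript().step(...)` would migrate roughly like this:

    from inspect_ai.util import span  # assumed public export

    async def traced_step() -> None:
        # before: with transcript().step(name="my-step", type="solver"): ...
        async with span(name="my-step", type="solver"):
            ...  # events recorded here are grouped under the span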
inspect_ai/tool/_mcp/_context.py
@@ -2,13 +2,11 @@ from contextlib import _AsyncGeneratorContextManager
 from typing import TypeAlias
 
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
-from mcp.types import (
-    JSONRPCMessage,
-)
+from mcp.shared.message import SessionMessage
 
 MCPServerContext: TypeAlias = _AsyncGeneratorContextManager[
     tuple[
-        MemoryObjectReceiveStream[JSONRPCMessage | Exception],
-        MemoryObjectSendStream[JSONRPCMessage],
+        MemoryObjectReceiveStream[SessionMessage | Exception],
+        MemoryObjectSendStream[SessionMessage],
     ],
 ]
inspect_ai/tool/_mcp/_mcp.py
@@ -61,16 +61,17 @@ class MCPServerImpl(MCPServer):
     ) -> list[Tool]:
         return await self._task_session()._list_tools(tools)
 
-    # create a separate MCPServer session per async task
-    _task_sessions: dict[int, "MCPServerSession"] = {}
+    # create a separate MCPServer session per async task / server name
+    _task_sessions: dict[str, "MCPServerSession"] = {}
 
     def _task_session(self) -> "MCPServerSession":
         task_id = anyio.get_current_task().id
-        if task_id not in self._task_sessions:
-            MCPServerImpl._task_sessions[task_id] = MCPServerSession(
+        session_key = f"{task_id}_{self._name}"
+        if session_key not in self._task_sessions:
+            MCPServerImpl._task_sessions[session_key] = MCPServerSession(
                 self._client, name=self._name, events=self._events
             )
-        return MCPServerImpl._task_sessions[task_id]
+        return MCPServerImpl._task_sessions[session_key]
 
 
 class MCPServerSession(MCPServer):
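Keying the session cache by task id alone meant that two different MCP servers used within the same async task collided on the cache and shared one session; including the server name gives each server its own. A minimal illustration (server names are hypothetical):

    import anyio

    async def demo() -> None:
        # old cache key: the async task id alone, identical for both servers
        task_id = anyio.get_current_task().id
        # new cache key: task id plus server name, distinct per server
        assert f"{task_id}_git-server" != f"{task_id}_filesystem-server"

    anyio.run(demo)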
inspect_ai/tool/_mcp/server.py
@@ -102,7 +102,7 @@ def mcp_server_sandbox(
 def verfify_mcp_package() -> None:
     FEATURE = "MCP tools"
     PACKAGE = "mcp"
-    MIN_VERSION = "1.6.0"
+    MIN_VERSION = "1.8.0"
 
     # verify we have the package
     try:
inspect_ai/tool/_tools/_execute.py
@@ -96,7 +96,10 @@ def python(
          The output of the Python code.
        """
        result = await sandbox_env(sandbox).exec(
-            cmd=["python3"], input=code, timeout=timeout, user=user
+            cmd=["bash", "--login", "-c", "python3 -"],
+            input=code,
+            timeout=timeout,
+            user=user,
        )
        # return output (including stderr if any)
        output = ""
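Routing the interpreter through a login shell sources `/etc/profile` and the user's profile first, so `python3` resolves against the sandbox user's configured PATH (virtualenvs, conda, and the like) instead of a bare default; `python3 -` then reads the program from stdin, exactly as the sandbox passes the tool's code via `input=`. A rough local illustration of the command now executed:

    import subprocess

    code = "import sys; print(sys.executable)"
    result = subprocess.run(
        ["bash", "--login", "-c", "python3 -"],
        input=code,  # delivered to python3 on stdin
        capture_output=True,
        text=True,
    )
    print(result.stdout)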
inspect_ai/tool/_tools/_think.py
@@ -41,7 +41,7 @@ def think(
 def think_tool_viewer() -> ToolCallViewer:
     def viewer(tool_call: ToolCall) -> ToolCallView:
         call = ToolCallContent(
-            format="markdown", content=tool_call.arguments["thought"]
+            format="markdown", content=tool_call.arguments.get("thought", "")
         )
         return ToolCallView(call=call)
 
inspect_ai/tool/_tools/_web_search/__init__.py (new file)
@@ -0,0 +1,3 @@
+from ._web_search import web_search
+
+__all__ = ["web_search"]
inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py}
@@ -1,5 +1,5 @@
 import os
-from typing import Literal, Protocol, runtime_checkable
+from typing import Awaitable, Callable
 
 import anyio
 import httpx
@@ -16,8 +16,6 @@ from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
 from inspect_ai.util._concurrency import concurrency
 
-from .._tool import Tool, ToolResult, tool
-
 DEFAULT_RELEVANCE_PROMPT = """I am trying to answer the following question and need to find the most relevant information on the web. Please let me know if the following content is relevant to the question or not. You should just respond with "yes" or "no".
 
 Question: {question}
@@ -31,59 +29,35 @@ class SearchLink:
         self.snippet = snippet
 
 
-@runtime_checkable
-class SearchProvider(Protocol):
-    async def __call__(self, query: str, start_idx: int) -> list[SearchLink]: ...
-
-
-@tool
-def web_search(
-    provider: Literal["google"] = "google",
-    num_results: int = 3,
-    max_provider_calls: int = 3,
-    max_connections: int = 10,
-    model: str | None = None,
-) -> Tool:
-    """Web search tool.
-
-    A tool that can be registered for use by models to search the web. Use
-    the `use_tools()` solver to make the tool available (e.g. `use_tools(web_search())`))
-
-    A web search is conducted using the specified provider, the results are parsed for relevance
-    using the specified model, and the top 'num_results' relevant pages are returned.
-
-    See further documentation at <https://inspect.aisi.org.uk/tools-standard.html#sec-web-search>.
-
-    Args:
-        provider: Search provider (defaults to "google", currently
-            the only provider). Possible future providers include "brave" and "bing".
-        num_results: Number of web search result pages to return to the model.
-        max_provider_calls: Maximum number of search calls to make to the search provider.
-        max_connections: Maximum number of concurrent connections to API
-            endpoint of search provider.
-        model: Model used to parse web pages for relevance.
+def maybe_get_google_api_keys() -> tuple[str, str] | None:
+    """
+    Get Google API keys from environment variables.
 
     Returns:
-        A tool that can be registered for use by models to search the web.
+        tuple: A tuple containing the Google API key and the Google CSE ID.
     """
-    # get search client
-    client = httpx.AsyncClient()
+    google_api_key = os.environ.get("GOOGLE_CSE_API_KEY", None)
+    google_cse_id = os.environ.get("GOOGLE_CSE_ID", None)
+    return (google_api_key, google_cse_id) if google_api_key and google_cse_id else None
 
-    if provider == "google":
-        search_provider = google_search_provider(client)
-    else:
-        raise ValueError(
-            f"Provider {provider} not supported. Only 'google' is supported."
+
+def google_search_provider(
+    num_results: int,
+    max_provider_calls: int,
+    max_connections: int,
+    model: str | None,
+) -> Callable[[str], Awaitable[str | None]]:
+    keys = maybe_get_google_api_keys()
+    if not keys:
+        raise PrerequisiteError(
+            "GOOGLE_CSE_ID and/or GOOGLE_CSE_API_KEY not set in the environment. Please ensure these variables are defined to use Google Custom Search with the web_search tool.\n\nLearn more about the Google web search provider at https://inspect.aisi.org.uk/tools.html#google-provider"
         )
+    google_api_key, google_cse_id = keys
 
-    # resolve provider (only google for now)
-    async def execute(query: str) -> ToolResult:
-        """
-        Use the web_search tool to perform keyword searches of the web.
+    # Create the client within the provider
+    client = httpx.AsyncClient()
 
-        Args:
-            query (str): Search query.
-        """
+    async def search(query: str) -> str | None:
         # limit number of concurrent searches
         page_contents: list[str] = []
         urls: list[str] = []
@@ -92,8 +66,8 @@ def web_search(
 
         # Paginate through search results until we have successfully extracted num_results pages or we have reached max_provider_calls
         while len(page_contents) < num_results and search_calls < max_provider_calls:
-            async with concurrency(f"{provider}_web_search", max_connections):
-                links = await search_provider(query, start_idx=search_calls * 10)
+            async with concurrency("google_web_search", max_connections):
+                links = await _search(query, start_idx=search_calls * 10)
 
             async with anyio.create_task_group() as tg:
@@ -114,19 +88,39 @@
             search_calls += 1
 
         all_page_contents = "\n\n".join(page_contents)
-        if all_page_contents == "":
-            response: ToolResult = (
-                "I'm sorry, I couldn't find any relevant information on the web."
-            )
-        else:
-            response = (
-                "Here are your web search results. Please read them carefully as they may be useful later! "
-                + all_page_contents
-            )
+        return None if all_page_contents == "" else all_page_contents
 
-        return response
+    async def _search(query: str, start_idx: int) -> list[SearchLink]:
+        # List of allowed parameters can be found https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
+        search_params = {
+            "q": query,
+            "key": google_api_key,
+            "cx": google_cse_id,
+            "start": start_idx,
+        }
+        search_url = "https://www.googleapis.com/customsearch/v1?" + "&".join(
+            [f"{key}={value}" for key, value in search_params.items()]
+        )
 
-    return execute
+        # retry up to 5 times over a period of up to 1 minute
+        @retry(
+            wait=wait_exponential_jitter(),
+            stop=stop_after_attempt(5) | stop_after_delay(60),
+            retry=retry_if_exception(httpx_should_retry),
+            before_sleep=log_httpx_retry_attempt(search_url),
+        )
+        async def execute_search() -> httpx.Response:
+            return await client.get(search_url)
+
+        result = await execute_search()
+        data = result.json()
+
+        if "items" in data:
+            return [SearchLink(item["link"], item["snippet"]) for item in data["items"]]
+        else:
+            return []
+
+    return search
 
 
 async def page_if_relevant(
@@ -183,44 +177,3 @@ async def page_if_relevant(
         return full_text
     else:
         return None
-
-
-def google_search_provider(client: httpx.AsyncClient) -> SearchProvider:
-    google_api_key = os.environ.get("GOOGLE_CSE_API_KEY", None)
-    google_cse_id = os.environ.get("GOOGLE_CSE_ID", None)
-    if not google_api_key or not google_cse_id:
-        raise PrerequisiteError(
-            "GOOGLE_CSE_ID and/or GOOGLE_CSE_API_KEY not set in the environment. Please ensure these variables are defined to use Google Custom Search with the web_search tool.\n\nLearn more about the Google web search provider at https://inspect.aisi.org.uk/tools.html#google-provider"
-        )
-
-    async def search(query: str, start_idx: int) -> list[SearchLink]:
-        # List of allowed parameters can be found https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
-        search_params = {
-            "q": query,
-            "key": google_api_key,
-            "cx": google_cse_id,
-            "start": start_idx,
-        }
-        search_url = "https://www.googleapis.com/customsearch/v1?" + "&".join(
-            [f"{key}={value}" for key, value in search_params.items()]
-        )
-
-        # retry up to 5 times over a period of up to 1 minute
-        @retry(
-            wait=wait_exponential_jitter(),
-            stop=stop_after_attempt(5) | stop_after_delay(60),
-            retry=retry_if_exception(httpx_should_retry),
-            before_sleep=log_httpx_retry_attempt(search_url),
-        )
-        async def execute_search() -> httpx.Response:
-            return await client.get(search_url)
-
-        result = await execute_search()
-        data = result.json()
-
-        if "items" in data:
-            return [SearchLink(item["link"], item["snippet"]) for item in data["items"]]
-        else:
-            return []
-
-    return search
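With the Google logic moved into `_google.py` and a new `_tavily.py` provider added, the rebuilt `_web_search.py` (85 new lines, not shown here) presumably dispatches between the two by name. A usage sketch under that assumption; the `provider` argument value for Tavily is inferred from the file layout, not from this diff:

    from inspect_ai.solver import generate, use_tools
    from inspect_ai.tool import web_search

    # Google provider: still configured via GOOGLE_CSE_API_KEY / GOOGLE_CSE_ID
    solver = [use_tools(web_search()), generate()]

    # assumed: the new Tavily provider selected by name (API key via Tavily's
    # own environment variable)
    solver_tavily = [use_tools(web_search(provider="tavily")), generate()]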