inspect-ai 0.3.94__py3-none-any.whl → 0.3.96__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. inspect_ai/_eval/loader.py +1 -1
  2. inspect_ai/_eval/task/run.py +12 -6
  3. inspect_ai/_util/exception.py +4 -0
  4. inspect_ai/_util/hash.py +39 -0
  5. inspect_ai/_util/local_server.py +16 -0
  6. inspect_ai/_util/path.py +22 -0
  7. inspect_ai/_util/trace.py +1 -1
  8. inspect_ai/_util/working.py +4 -0
  9. inspect_ai/_view/www/dist/assets/index.css +9 -9
  10. inspect_ai/_view/www/dist/assets/index.js +117 -120
  11. inspect_ai/_view/www/package.json +1 -1
  12. inspect_ai/_view/www/src/app/log-view/navbar/SecondaryBar.tsx +2 -2
  13. inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +1 -4
  14. inspect_ai/_view/www/src/app/samples/SamplesTools.tsx +3 -13
  15. inspect_ai/_view/www/src/app/samples/sample-tools/SelectScorer.tsx +45 -48
  16. inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +16 -15
  17. inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +47 -75
  18. inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +9 -9
  19. inspect_ai/_view/www/src/app/types.ts +12 -2
  20. inspect_ai/_view/www/src/components/ExpandablePanel.module.css +1 -1
  21. inspect_ai/_view/www/src/components/ExpandablePanel.tsx +5 -5
  22. inspect_ai/_view/www/src/state/hooks.ts +19 -3
  23. inspect_ai/_view/www/src/state/logSlice.ts +23 -5
  24. inspect_ai/_view/www/yarn.lock +9 -9
  25. inspect_ai/agent/_bridge/patch.py +1 -3
  26. inspect_ai/agent/_types.py +1 -1
  27. inspect_ai/analysis/__init__.py +0 -0
  28. inspect_ai/analysis/beta/__init__.py +67 -0
  29. inspect_ai/analysis/beta/_dataframe/__init__.py +0 -0
  30. inspect_ai/analysis/beta/_dataframe/columns.py +145 -0
  31. inspect_ai/analysis/beta/_dataframe/evals/__init__.py +0 -0
  32. inspect_ai/analysis/beta/_dataframe/evals/columns.py +132 -0
  33. inspect_ai/analysis/beta/_dataframe/evals/extract.py +23 -0
  34. inspect_ai/analysis/beta/_dataframe/evals/table.py +177 -0
  35. inspect_ai/analysis/beta/_dataframe/events/__init__.py +0 -0
  36. inspect_ai/analysis/beta/_dataframe/events/columns.py +87 -0
  37. inspect_ai/analysis/beta/_dataframe/events/extract.py +26 -0
  38. inspect_ai/analysis/beta/_dataframe/events/table.py +100 -0
  39. inspect_ai/analysis/beta/_dataframe/extract.py +73 -0
  40. inspect_ai/analysis/beta/_dataframe/messages/__init__.py +0 -0
  41. inspect_ai/analysis/beta/_dataframe/messages/columns.py +60 -0
  42. inspect_ai/analysis/beta/_dataframe/messages/extract.py +21 -0
  43. inspect_ai/analysis/beta/_dataframe/messages/table.py +79 -0
  44. inspect_ai/analysis/beta/_dataframe/progress.py +26 -0
  45. inspect_ai/analysis/beta/_dataframe/record.py +377 -0
  46. inspect_ai/analysis/beta/_dataframe/samples/__init__.py +0 -0
  47. inspect_ai/analysis/beta/_dataframe/samples/columns.py +77 -0
  48. inspect_ai/analysis/beta/_dataframe/samples/extract.py +54 -0
  49. inspect_ai/analysis/beta/_dataframe/samples/table.py +370 -0
  50. inspect_ai/analysis/beta/_dataframe/util.py +160 -0
  51. inspect_ai/analysis/beta/_dataframe/validate.py +171 -0
  52. inspect_ai/log/_file.py +10 -3
  53. inspect_ai/log/_log.py +21 -1
  54. inspect_ai/model/_call_tools.py +2 -1
  55. inspect_ai/model/_model.py +6 -4
  56. inspect_ai/model/_openai_responses.py +17 -18
  57. inspect_ai/model/_providers/anthropic.py +30 -5
  58. inspect_ai/model/_providers/providers.py +1 -1
  59. inspect_ai/solver/_multiple_choice.py +4 -1
  60. inspect_ai/solver/_task_state.py +8 -4
  61. inspect_ai/tool/_mcp/_context.py +3 -5
  62. inspect_ai/tool/_mcp/_sandbox.py +17 -14
  63. inspect_ai/tool/_mcp/server.py +1 -1
  64. inspect_ai/tool/_tools/_think.py +1 -1
  65. inspect_ai/tool/_tools/_web_search/__init__.py +3 -0
  66. inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py} +56 -103
  67. inspect_ai/tool/_tools/_web_search/_tavily.py +77 -0
  68. inspect_ai/tool/_tools/_web_search/_web_search.py +85 -0
  69. inspect_ai/util/_sandbox/events.py +3 -2
  70. {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/METADATA +9 -2
  71. {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/RECORD +75 -46
  72. {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/WHEEL +1 -1
  73. {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/entry_points.txt +0 -0
  74. {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/licenses/LICENSE +0 -0
  75. {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/top_level.txt +0 -0
inspect_ai/log/_file.py CHANGED
@@ -524,14 +524,21 @@ def manifest_eval_log_name(info: EvalLogInfo, log_dir: str, sep: str) -> str:
524
524
 
525
525
  def log_files_from_ls(
526
526
  ls: list[FileInfo],
527
- formats: list[Literal["eval", "json"]] | None,
527
+ formats: list[Literal["eval", "json"]] | None = None,
528
528
  descending: bool = True,
529
+ sort: bool = True,
529
530
  ) -> list[EvalLogInfo]:
530
531
  extensions = [f".{format}" for format in (formats or ALL_LOG_FORMATS)]
531
532
  return [
532
533
  log_file_info(file)
533
- for file in sorted(
534
- ls, key=lambda file: (file.mtime if file.mtime else 0), reverse=descending
534
+ for file in (
535
+ sorted(
536
+ ls,
537
+ key=lambda file: (file.mtime if file.mtime else 0),
538
+ reverse=descending,
539
+ )
540
+ if sort
541
+ else ls
535
542
  )
536
543
  if file.type == "file" and is_log_file(file.name, extensions)
537
544
  ]
inspect_ai/log/_log.py CHANGED
@@ -17,9 +17,11 @@ from pydantic import (
17
17
  )
18
18
  from rich.console import Console, RenderableType
19
19
  from rich.traceback import Traceback
20
+ from shortuuid import uuid
20
21
 
21
- from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH, PKG_NAME
22
+ from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH, DESERIALIZING, PKG_NAME
22
23
  from inspect_ai._util.error import EvalError, exception_message
24
+ from inspect_ai._util.hash import base57_id_hash
23
25
  from inspect_ai._util.logger import warn_once
24
26
  from inspect_ai.approval._policy import ApprovalPolicyConfig
25
27
  from inspect_ai.dataset._dataset import MT, metadata_as
@@ -677,6 +679,9 @@ class EvalModelConfig(BaseModel):
677
679
  class EvalSpec(BaseModel):
678
680
  """Eval target and configuration."""
679
681
 
682
+ eval_id: str = Field(default_factory=str)
683
+ """Globally unique id for eval."""
684
+
680
685
  run_id: str = Field(default_factory=str)
681
686
  """Unique run id"""
682
687
 
@@ -757,6 +762,21 @@ class EvalSpec(BaseModel):
757
762
  # allow field model_args
758
763
  model_config = ConfigDict(protected_namespaces=())
759
764
 
765
+ def model_post_init(self, __context: Any) -> None:
766
+ # check if deserializing
767
+ is_deserializing = isinstance(__context, dict) and __context.get(
768
+ DESERIALIZING, False
769
+ )
770
+
771
+ # Generate eval_id if needed
772
+ if self.eval_id == "":
773
+ if is_deserializing:
774
+ # we want the eval_id to be stable across reads of the eval log so we compose it
775
+ # as a hash that matches the size/appearance of shortuuid-based uuids
776
+ self.eval_id = base57_id_hash(self.run_id + self.task_id + self.created)
777
+ else:
778
+ self.eval_id = uuid()
779
+
760
780
  @model_validator(mode="before")
761
781
  @classmethod
762
782
  def read_sandbox_spec(
@@ -39,6 +39,7 @@ from inspect_ai._util.content import (
39
39
  ContentText,
40
40
  ContentVideo,
41
41
  )
42
+ from inspect_ai._util.exception import TerminateSampleError
42
43
  from inspect_ai._util.format import format_function_call
43
44
  from inspect_ai._util.logger import warn_once
44
45
  from inspect_ai._util.registry import registry_unqualified_name
@@ -376,7 +377,7 @@ async def call_tool(
376
377
  transcript()._event(
377
378
  SampleLimitEvent(type="operator", limit=1, message=message)
378
379
  )
379
- raise LimitExceededError("operator", value=1, limit=1, message=message)
380
+ raise TerminateSampleError(message)
380
381
  else:
381
382
  raise ToolApprovalError(approval.explanation if approval else None)
382
383
  if approval and approval.modified:
@@ -1237,9 +1237,10 @@ def tool_result_images_as_user_message(
1237
1237
 
1238
1238
  Tool responses will have images replaced with "Image content is included below.", and the new user message will contain the images.
1239
1239
  """
1240
- init_accum: ImagesAccumulator = ([], [], [])
1241
1240
  chat_messages, user_message_content, tool_call_ids = functools.reduce(
1242
- tool_result_images_reducer, messages, init_accum
1241
+ tool_result_images_reducer,
1242
+ messages,
1243
+ (list[ChatMessage](), list[Content](), list[str]()),
1243
1244
  )
1244
1245
  # if the last message was a tool result, we may need to flush the pending stuff here
1245
1246
  return maybe_adding_user_message(chat_messages, user_message_content, tool_call_ids)
@@ -1265,9 +1266,10 @@ def tool_result_images_reducer(
1265
1266
  and isinstance(message.content, list)
1266
1267
  and any([isinstance(c, ContentImage) for c in message.content])
1267
1268
  ):
1268
- init_accum: ImageContentAccumulator = ([], [])
1269
1269
  new_user_message_content, edited_tool_message_content = functools.reduce(
1270
- tool_result_image_content_reducer, message.content, init_accum
1270
+ tool_result_image_content_reducer,
1271
+ message.content,
1272
+ (list[Content](), list[Content]()),
1271
1273
  )
1272
1274
 
1273
1275
  return (
@@ -184,24 +184,23 @@ def openai_responses_chat_choices(
184
184
  # │ │ ┌───────────────────┐ │ │ │ │ ┌───────────────────┐ │ │ │ │ ┌───────────────────┐ │ │
185
185
  # │ │ │ type: "reasoning" │ │ │ │ │ │ ContentText │ │ │ │ │ │ type: "reasoning" │ │ │
186
186
  # │ │ │ id: "rs_bbbbbb" │ │ │ │ │ │ text: "" │ │ │ │ │ │ id: "rs_bbbbbb" │ │ │
187
- # │ │ │ summary: [] │ │ │ │ │ └───────────────────┘ │ │ │ │ │ summary: [] │ │ │
188
- # │ │ └───────────────────┘ │ │ │ │ ┌───────────────────┐ │ │ │ │ ┌───────────────────┐ │ │
189
- # │ │ ┌───────────────────┐ │ │ │ │ │ ContentText │ │ │ │ │ │ type: "message" │ │ │
190
- # │ │ │ type: "message" │ │ │ │ │ text: "text1" │ │ │ │ │ id: "msg_ccccccc" │ │ │
191
- # │ │ │ id: "msg_ccccccc" │ │ │ │ │ └───────────────────┘ │ │ │ │ │ role: "assistant" │ │ │
192
- # │ │ │ role: "assistant" │ │ │--->│┌───────────────────┐│--->│ │ │ ┌───────────────┐ │ │ │
193
- # │ │ │ ┌───────────────┐ │ │ │ │ │ ContentText │ │ │ │ │ │ Content │ │ │ │
194
- # │ │ │ │ Content │ │ │ │ │ text: "text2" │ │ │ │ │ │ │ ┌───────────┐ │ │ │ │
195
- # │ │ │ │ ┌───────────┐ │ │ │ │ │ └───────────────────────┘ │ │ │ │ │ │"text1" │ │ │ │ │
196
- # │ │ │ │ │"text1" │ │ │ │ │ │ ┌───────────────────────┐ │ │ │ │ │ └───────────┘ │ │ │ │
197
- # │ │ │ │ └───────────┘ │ │ │ │ │ │ internal │ │ │ │ │ │ ┌───────────┐ │ │ │ │
198
- # │ │ │ │ ┌───────────┐ │ │ │ │ │ │ ┌───────────────────┐ │ │ │ │ │ │ "text2" │ │ │ │
199
- # │ │ │ "text2" │ │ │ │ │ │ │ │ reasoning_id: │ │ │ │ │ │ └───────────┘ │ │ │
200
- # │ │ └───────────┘ │ │ │ │ │ "rs_bbbbbb" │ │ │ │ └───────────────┘ │ │
201
- # │ └───────────────┘ │ │ │ │ │ └───────────────────┘ │ │ │ └───────────────────┘
202
- # │ │ └───────────────────┘ │ │ ┌───────────────────┐ │ └───────────────────────┘ │
203
- # └───────────────────────┘ │ │ output_msg_id: │ │ │ └───────────────────────────┘
204
- # └───────────────────────────┘ │ │ │ "msg_ccccccc" │ │ │
187
+ # │ │ │ summary: [] │ │ │ │ │ ├───────────────────┤ │ │ │ │ │ summary: [] │ │ │
188
+ # │ │ ├───────────────────┤ │ │ │ │ ContentText │ │ │ │ ├───────────────────┤ │ │
189
+ # │ │ type: "message" │ │ │ │ │ text: "text1" │ │ │ │ │ │ type: "message" │ │ │
190
+ # │ │ │ id: "msg_ccccccc" │ │ │ │ │ ├───────────────────┤ │ │ │ │ │ id: "msg_ccccccc" │ │ │
191
+ # │ │ │ role: "assistant" │ │ │ │ │ ContentText │ │ │ │ │ role: "assistant" │ │ │
192
+ # │ │ │ ┌───────────────┐ │ │ ->text: "text2" │ │ │ -> │ │ │ ┌───────────────┐ │ │ │
193
+ # │ │ │ Content │ │ │ │ │ └───────────────────┘ │ │ │ │ │ │ Content │ │ │ │
194
+ # │ │ │ │ ┌───────────┐ │ │ │ │ │ └───────────────────────┘ │ │ │ │ │ ┌───────────┐ │ │ │ │
195
+ # │ │ │ │ │"text1" │ │ │ │ │ │ ┌───────────────────────┐ │ │ │ │ │ │"text1" │ │ │ │ │
196
+ # │ │ │ │ ├───────────┤ │ │ │ │ │ internal │ │ │ │ ├───────────┤ │ │ │ │
197
+ # │ │ │ │ │"text2" │ │ │ │ │ │ │ ┌───────────────────┐ │ │ │ │ │ │ │"text2" │ │ │ │ │
198
+ # │ │ │ │ └───────────┘ │ │ │ │ │ │ reasoning_id: │ │ │ │ │ │ └───────────┘ │ │ │ │
199
+ # │ │ │ └───────────────┘ │ │ │ │ │ │ "rs_bbbbbb" │ │ │ │ │ │ └───────────────┘ │ │ │
200
+ # │ │ └───────────────────┘ │ │ │ │ └───────────────────┘ │ │ │ │ └───────────────────┘ │ │
201
+ # │ └───────────────────────┘ │ │ │ ┌───────────────────┐ │ │ │ └───────────────────────┘
202
+ # └───────────────────────────┘ │ │ │ output_msg_id: │ │ │ └───────────────────────────┘
203
+ # │ │ │ "msg_ccccccc" │ │ │
205
204
  # │ │ └───────────────────┘ │ │
206
205
  # │ └───────────────────────┘ │
207
206
  # └───────────────────────────┘
@@ -33,7 +33,10 @@ from anthropic.types import (
33
33
  ToolUseBlockParam,
34
34
  message_create_params,
35
35
  )
36
- from anthropic.types.beta import BetaToolComputerUse20250124Param
36
+ from anthropic.types.beta import (
37
+ BetaToolComputerUse20250124Param,
38
+ BetaToolTextEditor20241022Param,
39
+ )
37
40
  from pydantic import JsonValue
38
41
  from typing_extensions import override
39
42
 
@@ -218,6 +221,8 @@ class AnthropicAPI(ModelAPI):
218
221
  # tools are generally available for Claude 3.5 Sonnet (new) as well and
219
222
  # can be used without the computer use beta header.
220
223
  betas.append("computer-use-2025-01-24")
224
+ if any("20241022" in str(tool.get("type", "")) for tool in tools_param):
225
+ betas.append("computer-use-2024-10-22")
221
226
  if len(betas) > 0:
222
227
  extra_headers["anthropic-beta"] = ",".join(betas)
223
228
 
@@ -337,6 +342,15 @@ class AnthropicAPI(ModelAPI):
337
342
  @override
338
343
  def should_retry(self, ex: Exception) -> bool:
339
344
  if isinstance(ex, APIStatusError):
345
+ # for unknown reasons, anthropic does not always set status_code == 529
346
+ # for "overloaded_error" so we check for it explicitly
347
+ if (
348
+ isinstance(ex.body, dict)
349
+ and ex.body.get("error", {}).get("type", "") == "overloaded_error"
350
+ ):
351
+ return True
352
+
353
+ # standard http status code checking
340
354
  return is_retryable_http_status(ex.status_code)
341
355
  elif httpx_should_retry(ex):
342
356
  return True
@@ -545,7 +559,7 @@ class AnthropicAPI(ModelAPI):
545
559
 
546
560
  def text_editor_tool_param(
547
561
  self, tool: ToolInfo
548
- ) -> Optional[ToolTextEditor20250124Param]:
562
+ ) -> ToolTextEditor20250124Param | BetaToolTextEditor20241022Param | None:
549
563
  # check for compatible 'text editor' tool
550
564
  if tool.name == "text_editor" and (
551
565
  sorted(tool.parameters.properties.keys())
@@ -561,8 +575,14 @@ class AnthropicAPI(ModelAPI):
561
575
  ]
562
576
  )
563
577
  ):
564
- return ToolTextEditor20250124Param(
565
- type="text_editor_20250124", name="str_replace_editor"
578
+ return (
579
+ BetaToolTextEditor20241022Param(
580
+ type="text_editor_20241022", name="str_replace_editor"
581
+ )
582
+ if self.is_claude_3_5()
583
+ else ToolTextEditor20250124Param(
584
+ type="text_editor_20250124", name="str_replace_editor"
585
+ )
566
586
  )
567
587
  # not a text_editor tool
568
588
  else:
@@ -571,7 +591,10 @@ class AnthropicAPI(ModelAPI):
571
591
 
572
592
  # tools can be either a stock tool param or a special Anthropic native use tool param
573
593
  ToolParamDef = (
574
- ToolParam | BetaToolComputerUse20250124Param | ToolTextEditor20250124Param
594
+ ToolParam
595
+ | BetaToolComputerUse20250124Param
596
+ | ToolTextEditor20250124Param
597
+ | BetaToolTextEditor20241022Param
575
598
  )
576
599
 
577
600
 
@@ -580,6 +603,7 @@ def add_cache_control(
580
603
  | ToolParam
581
604
  | BetaToolComputerUse20250124Param
582
605
  | ToolTextEditor20250124Param
606
+ | BetaToolTextEditor20241022Param
583
607
  | dict[str, Any],
584
608
  ) -> None:
585
609
  cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
@@ -844,6 +868,7 @@ def _names_for_tool_call(
844
868
  """
845
869
  mappings = (
846
870
  (INTERNAL_COMPUTER_TOOL_NAME, "computer_20250124", "computer"),
871
+ ("str_replace_editor", "text_editor_20241022", "text_editor"),
847
872
  ("str_replace_editor", "text_editor_20250124", "text_editor"),
848
873
  ("bash", "bash_20250124", "bash_session"),
849
874
  )
@@ -281,7 +281,7 @@ def none() -> type[ModelAPI]:
281
281
  def validate_openai_client(feature: str) -> None:
282
282
  FEATURE = feature
283
283
  PACKAGE = "openai"
284
- MIN_VERSION = "1.75.0"
284
+ MIN_VERSION = "1.78.0"
285
285
 
286
286
  # verify we have the package
287
287
  try:
@@ -200,6 +200,7 @@ def multiple_choice(
200
200
  template: str | None = None,
201
201
  cot: bool = False,
202
202
  multiple_correct: bool = False,
203
+ max_tokens: int | None = None,
203
204
  **kwargs: Unpack[DeprecatedArgs],
204
205
  ) -> Solver:
205
206
  """Multiple choice question solver. Formats a multiple choice question prompt, then calls `generate()`.
@@ -226,6 +227,8 @@ def multiple_choice(
226
227
  squares? A) 3, B) 4, C) 9" has multiple correct answers, B and C. Leave
227
228
  as `False` if there's exactly one correct answer from the choices
228
229
  available. NOTE: this has no effect if you provide a custom template.
230
+ max_tokens: Default `None`. Controls the number of tokens generated through the call
231
+ to generate().
229
232
  **kwargs (Any): Deprecated arguments for backward compatibility.
230
233
 
231
234
  #### Shuffling
@@ -282,7 +285,7 @@ def multiple_choice(
282
285
  template=str(template),
283
286
  )
284
287
 
285
- state = await generate(state)
288
+ state = await generate(state, max_tokens=max_tokens)
286
289
 
287
290
  answers = parse_answers(state)
288
291
  if answers and answers.group(1):
@@ -138,7 +138,7 @@ class TaskState:
138
138
  The `TaskState` represents the internal state of the `Task` being run for a single `Sample`.
139
139
 
140
140
  The `TaskState` is passed to and returned from each solver during a sample's
141
- evaluation. It allows us to manipulated the message history, the tools
141
+ evaluation. It allows us to maintain the manipulated message history, the tools
142
142
  available to the model, the final output of the model, and whether the task
143
143
  is completed or has hit a limit.
144
144
  """
@@ -204,13 +204,17 @@ class TaskState:
204
204
  Convenience function for accessing the initial input from the `Sample` as a string.
205
205
 
206
206
  If the `input` is a `list[ChatMessage]`, this will return the text from
207
- the first chat message
207
+ the last chat message
208
208
  """
209
209
  if isinstance(self._input, str):
210
210
  return self._input
211
211
  else:
212
212
  input = next(
213
- (message.text for message in self._input if message.role == "user"),
213
+ (
214
+ message.text
215
+ for message in reversed(self._input)
216
+ if message.role == "user"
217
+ ),
214
218
  None,
215
219
  )
216
220
  if input:
@@ -231,7 +235,7 @@ class TaskState:
231
235
  write access to the user chat prompt. Raises an
232
236
  exception if there is no user prompt
233
237
  """
234
- prompt = next((m for m in self.messages if m.role == "user"), None)
238
+ prompt = next((m for m in reversed(self.messages) if m.role == "user"), None)
235
239
  if prompt:
236
240
  return prompt
237
241
  else:
@@ -2,13 +2,11 @@ from contextlib import _AsyncGeneratorContextManager
2
2
  from typing import TypeAlias
3
3
 
4
4
  from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
5
- from mcp.types import (
6
- JSONRPCMessage,
7
- )
5
+ from mcp.shared.message import SessionMessage
8
6
 
9
7
  MCPServerContext: TypeAlias = _AsyncGeneratorContextManager[
10
8
  tuple[
11
- MemoryObjectReceiveStream[JSONRPCMessage | Exception],
12
- MemoryObjectSendStream[JSONRPCMessage],
9
+ MemoryObjectReceiveStream[SessionMessage | Exception],
10
+ MemoryObjectSendStream[SessionMessage],
13
11
  ],
14
12
  ]
@@ -5,6 +5,7 @@ from typing import TextIO
5
5
  import anyio
6
6
  from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7
7
  from mcp import JSONRPCRequest, StdioServerParameters
8
+ from mcp.shared.message import SessionMessage
8
9
  from mcp.types import JSONRPCMessage, JSONRPCNotification
9
10
 
10
11
  from inspect_ai.tool._tool_support_helpers import (
@@ -36,12 +37,12 @@ async def sandbox_client( # type: ignore
36
37
  )
37
38
 
38
39
  # read_stream is remote process's stdout
39
- read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
40
- read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
40
+ read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
41
+ read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
41
42
 
42
43
  # write_stream is remote process's stdin
43
- write_stream: MemoryObjectSendStream[JSONRPCMessage]
44
- write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
44
+ write_stream: MemoryObjectSendStream[SessionMessage]
45
+ write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
45
46
 
46
47
  read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
47
48
  write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -64,18 +65,20 @@ async def sandbox_client( # type: ignore
64
65
  async with write_stream_reader:
65
66
  # This reads messages until the stream is closed
66
67
  async for message in write_stream_reader:
67
- root = message.root
68
+ root = message.message.root
68
69
  if isinstance(root, JSONRPCRequest):
69
70
  await read_stream_writer.send(
70
- await exec_model_request(
71
- sandbox=sandbox_environment,
72
- method="mcp_send_request",
73
- params={
74
- "session_id": session_id,
75
- "request": root.model_dump(),
76
- },
77
- result_type=JSONRPCMessage,
78
- timeout=timeout,
71
+ SessionMessage(
72
+ message=await exec_model_request(
73
+ sandbox=sandbox_environment,
74
+ method="mcp_send_request",
75
+ params={
76
+ "session_id": session_id,
77
+ "request": root.model_dump(),
78
+ },
79
+ result_type=JSONRPCMessage,
80
+ timeout=timeout,
81
+ )
79
82
  )
80
83
  )
81
84
  elif isinstance(root, JSONRPCNotification):
@@ -102,7 +102,7 @@ def mcp_server_sandbox(
102
102
  def verfify_mcp_package() -> None:
103
103
  FEATURE = "MCP tools"
104
104
  PACKAGE = "mcp"
105
- MIN_VERSION = "1.6.0"
105
+ MIN_VERSION = "1.8.0"
106
106
 
107
107
  # verify we have the package
108
108
  try:
@@ -41,7 +41,7 @@ def think(
41
41
  def think_tool_viewer() -> ToolCallViewer:
42
42
  def viewer(tool_call: ToolCall) -> ToolCallView:
43
43
  call = ToolCallContent(
44
- format="markdown", content=tool_call.arguments["thought"]
44
+ format="markdown", content=tool_call.arguments.get("thought", "")
45
45
  )
46
46
  return ToolCallView(call=call)
47
47
 
@@ -0,0 +1,3 @@
1
+ from ._web_search import web_search
2
+
3
+ __all__ = ["web_search"]
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import Literal, Protocol, runtime_checkable
2
+ from typing import Awaitable, Callable
3
3
 
4
4
  import anyio
5
5
  import httpx
@@ -16,8 +16,6 @@ from inspect_ai._util.error import PrerequisiteError
16
16
  from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
17
17
  from inspect_ai.util._concurrency import concurrency
18
18
 
19
- from .._tool import Tool, ToolResult, tool
20
-
21
19
  DEFAULT_RELEVANCE_PROMPT = """I am trying to answer the following question and need to find the most relevant information on the web. Please let me know if the following content is relevant to the question or not. You should just respond with "yes" or "no".
22
20
 
23
21
  Question: {question}
@@ -31,59 +29,35 @@ class SearchLink:
31
29
  self.snippet = snippet
32
30
 
33
31
 
34
- @runtime_checkable
35
- class SearchProvider(Protocol):
36
- async def __call__(self, query: str, start_idx: int) -> list[SearchLink]: ...
37
-
38
-
39
- @tool
40
- def web_search(
41
- provider: Literal["google"] = "google",
42
- num_results: int = 3,
43
- max_provider_calls: int = 3,
44
- max_connections: int = 10,
45
- model: str | None = None,
46
- ) -> Tool:
47
- """Web search tool.
48
-
49
- A tool that can be registered for use by models to search the web. Use
50
- the `use_tools()` solver to make the tool available (e.g. `use_tools(web_search())`))
51
-
52
- A web search is conducted using the specified provider, the results are parsed for relevance
53
- using the specified model, and the top 'num_results' relevant pages are returned.
54
-
55
- See further documentation at <https://inspect.aisi.org.uk/tools-standard.html#sec-web-search>.
56
-
57
- Args:
58
- provider: Search provider (defaults to "google", currently
59
- the only provider). Possible future providers include "brave" and "bing".
60
- num_results: Number of web search result pages to return to the model.
61
- max_provider_calls: Maximum number of search calls to make to the search provider.
62
- max_connections: Maximum number of concurrent connections to API
63
- endpoint of search provider.
64
- model: Model used to parse web pages for relevance.
32
+ def maybe_get_google_api_keys() -> tuple[str, str] | None:
33
+ """
34
+ Get Google API keys from environment variables.
65
35
 
66
36
  Returns:
67
- A tool that can be registered for use by models to search the web.
37
+ tuple: A tuple containing the Google API key and the Google CSE ID.
68
38
  """
69
- # get search client
70
- client = httpx.AsyncClient()
39
+ google_api_key = os.environ.get("GOOGLE_CSE_API_KEY", None)
40
+ google_cse_id = os.environ.get("GOOGLE_CSE_ID", None)
41
+ return (google_api_key, google_cse_id) if google_api_key and google_cse_id else None
71
42
 
72
- if provider == "google":
73
- search_provider = google_search_provider(client)
74
- else:
75
- raise ValueError(
76
- f"Provider {provider} not supported. Only 'google' is supported."
43
+
44
+ def google_search_provider(
45
+ num_results: int,
46
+ max_provider_calls: int,
47
+ max_connections: int,
48
+ model: str | None,
49
+ ) -> Callable[[str], Awaitable[str | None]]:
50
+ keys = maybe_get_google_api_keys()
51
+ if not keys:
52
+ raise PrerequisiteError(
53
+ "GOOGLE_CSE_ID and/or GOOGLE_CSE_API_KEY not set in the environment. Please ensure these variables are defined to use Google Custom Search with the web_search tool.\n\nLearn more about the Google web search provider at https://inspect.aisi.org.uk/tools.html#google-provider"
77
54
  )
55
+ google_api_key, google_cse_id = keys
78
56
 
79
- # resolve provider (only google for now)
80
- async def execute(query: str) -> ToolResult:
81
- """
82
- Use the web_search tool to perform keyword searches of the web.
57
+ # Create the client within the provider
58
+ client = httpx.AsyncClient()
83
59
 
84
- Args:
85
- query (str): Search query.
86
- """
60
+ async def search(query: str) -> str | None:
87
61
  # limit number of concurrent searches
88
62
  page_contents: list[str] = []
89
63
  urls: list[str] = []
@@ -92,8 +66,8 @@ def web_search(
92
66
 
93
67
  # Paginate through search results until we have successfully extracted num_results pages or we have reached max_provider_calls
94
68
  while len(page_contents) < num_results and search_calls < max_provider_calls:
95
- async with concurrency(f"{provider}_web_search", max_connections):
96
- links = await search_provider(query, start_idx=search_calls * 10)
69
+ async with concurrency("google_web_search", max_connections):
70
+ links = await _search(query, start_idx=search_calls * 10)
97
71
 
98
72
  async with anyio.create_task_group() as tg:
99
73
 
@@ -114,19 +88,39 @@ def web_search(
114
88
  search_calls += 1
115
89
 
116
90
  all_page_contents = "\n\n".join(page_contents)
117
- if all_page_contents == "":
118
- response: ToolResult = (
119
- "I'm sorry, I couldn't find any relevant information on the web."
120
- )
121
- else:
122
- response = (
123
- "Here are your web search results. Please read them carefully as they may be useful later! "
124
- + all_page_contents
125
- )
91
+ return None if all_page_contents == "" else all_page_contents
126
92
 
127
- return response
93
+ async def _search(query: str, start_idx: int) -> list[SearchLink]:
94
+ # List of allowed parameters can be found https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
95
+ search_params = {
96
+ "q": query,
97
+ "key": google_api_key,
98
+ "cx": google_cse_id,
99
+ "start": start_idx,
100
+ }
101
+ search_url = "https://www.googleapis.com/customsearch/v1?" + "&".join(
102
+ [f"{key}={value}" for key, value in search_params.items()]
103
+ )
128
104
 
129
- return execute
105
+ # retry up to 5 times over a period of up to 1 minute
106
+ @retry(
107
+ wait=wait_exponential_jitter(),
108
+ stop=stop_after_attempt(5) | stop_after_delay(60),
109
+ retry=retry_if_exception(httpx_should_retry),
110
+ before_sleep=log_httpx_retry_attempt(search_url),
111
+ )
112
+ async def execute_search() -> httpx.Response:
113
+ return await client.get(search_url)
114
+
115
+ result = await execute_search()
116
+ data = result.json()
117
+
118
+ if "items" in data:
119
+ return [SearchLink(item["link"], item["snippet"]) for item in data["items"]]
120
+ else:
121
+ return []
122
+
123
+ return search
130
124
 
131
125
 
132
126
  async def page_if_relevant(
@@ -183,44 +177,3 @@ async def page_if_relevant(
183
177
  return full_text
184
178
  else:
185
179
  return None
186
-
187
-
188
- def google_search_provider(client: httpx.AsyncClient) -> SearchProvider:
189
- google_api_key = os.environ.get("GOOGLE_CSE_API_KEY", None)
190
- google_cse_id = os.environ.get("GOOGLE_CSE_ID", None)
191
- if not google_api_key or not google_cse_id:
192
- raise PrerequisiteError(
193
- "GOOGLE_CSE_ID and/or GOOGLE_CSE_API_KEY not set in the environment. Please ensure these variables are defined to use Google Custom Search with the web_search tool.\n\nLearn more about the Google web search provider at https://inspect.aisi.org.uk/tools.html#google-provider"
194
- )
195
-
196
- async def search(query: str, start_idx: int) -> list[SearchLink]:
197
- # List of allowed parameters can be found https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
198
- search_params = {
199
- "q": query,
200
- "key": google_api_key,
201
- "cx": google_cse_id,
202
- "start": start_idx,
203
- }
204
- search_url = "https://www.googleapis.com/customsearch/v1?" + "&".join(
205
- [f"{key}={value}" for key, value in search_params.items()]
206
- )
207
-
208
- # retry up to 5 times over a period of up to 1 minute
209
- @retry(
210
- wait=wait_exponential_jitter(),
211
- stop=stop_after_attempt(5) | stop_after_delay(60),
212
- retry=retry_if_exception(httpx_should_retry),
213
- before_sleep=log_httpx_retry_attempt(search_url),
214
- )
215
- async def execute_search() -> httpx.Response:
216
- return await client.get(search_url)
217
-
218
- result = await execute_search()
219
- data = result.json()
220
-
221
- if "items" in data:
222
- return [SearchLink(item["link"], item["snippet"]) for item in data["items"]]
223
- else:
224
- return []
225
-
226
- return search