inspect-ai 0.3.94__py3-none-any.whl → 0.3.96__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
- inspect_ai/_eval/loader.py +1 -1
- inspect_ai/_eval/task/run.py +12 -6
- inspect_ai/_util/exception.py +4 -0
- inspect_ai/_util/hash.py +39 -0
- inspect_ai/_util/local_server.py +16 -0
- inspect_ai/_util/path.py +22 -0
- inspect_ai/_util/trace.py +1 -1
- inspect_ai/_util/working.py +4 -0
- inspect_ai/_view/www/dist/assets/index.css +9 -9
- inspect_ai/_view/www/dist/assets/index.js +117 -120
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/app/log-view/navbar/SecondaryBar.tsx +2 -2
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +1 -4
- inspect_ai/_view/www/src/app/samples/SamplesTools.tsx +3 -13
- inspect_ai/_view/www/src/app/samples/sample-tools/SelectScorer.tsx +45 -48
- inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +16 -15
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +47 -75
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +9 -9
- inspect_ai/_view/www/src/app/types.ts +12 -2
- inspect_ai/_view/www/src/components/ExpandablePanel.module.css +1 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +5 -5
- inspect_ai/_view/www/src/state/hooks.ts +19 -3
- inspect_ai/_view/www/src/state/logSlice.ts +23 -5
- inspect_ai/_view/www/yarn.lock +9 -9
- inspect_ai/agent/_bridge/patch.py +1 -3
- inspect_ai/agent/_types.py +1 -1
- inspect_ai/analysis/__init__.py +0 -0
- inspect_ai/analysis/beta/__init__.py +67 -0
- inspect_ai/analysis/beta/_dataframe/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/columns.py +145 -0
- inspect_ai/analysis/beta/_dataframe/evals/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/evals/columns.py +132 -0
- inspect_ai/analysis/beta/_dataframe/evals/extract.py +23 -0
- inspect_ai/analysis/beta/_dataframe/evals/table.py +177 -0
- inspect_ai/analysis/beta/_dataframe/events/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/events/columns.py +87 -0
- inspect_ai/analysis/beta/_dataframe/events/extract.py +26 -0
- inspect_ai/analysis/beta/_dataframe/events/table.py +100 -0
- inspect_ai/analysis/beta/_dataframe/extract.py +73 -0
- inspect_ai/analysis/beta/_dataframe/messages/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/messages/columns.py +60 -0
- inspect_ai/analysis/beta/_dataframe/messages/extract.py +21 -0
- inspect_ai/analysis/beta/_dataframe/messages/table.py +79 -0
- inspect_ai/analysis/beta/_dataframe/progress.py +26 -0
- inspect_ai/analysis/beta/_dataframe/record.py +377 -0
- inspect_ai/analysis/beta/_dataframe/samples/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/samples/columns.py +77 -0
- inspect_ai/analysis/beta/_dataframe/samples/extract.py +54 -0
- inspect_ai/analysis/beta/_dataframe/samples/table.py +370 -0
- inspect_ai/analysis/beta/_dataframe/util.py +160 -0
- inspect_ai/analysis/beta/_dataframe/validate.py +171 -0
- inspect_ai/log/_file.py +10 -3
- inspect_ai/log/_log.py +21 -1
- inspect_ai/model/_call_tools.py +2 -1
- inspect_ai/model/_model.py +6 -4
- inspect_ai/model/_openai_responses.py +17 -18
- inspect_ai/model/_providers/anthropic.py +30 -5
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/solver/_multiple_choice.py +4 -1
- inspect_ai/solver/_task_state.py +8 -4
- inspect_ai/tool/_mcp/_context.py +3 -5
- inspect_ai/tool/_mcp/_sandbox.py +17 -14
- inspect_ai/tool/_mcp/server.py +1 -1
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_search/__init__.py +3 -0
- inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py} +56 -103
- inspect_ai/tool/_tools/_web_search/_tavily.py +77 -0
- inspect_ai/tool/_tools/_web_search/_web_search.py +85 -0
- inspect_ai/util/_sandbox/events.py +3 -2
- {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/METADATA +9 -2
- {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/RECORD +75 -46
- {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/WHEEL +1 -1
- {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/top_level.txt +0 -0
inspect_ai/log/_file.py
CHANGED
@@ -524,14 +524,21 @@ def manifest_eval_log_name(info: EvalLogInfo, log_dir: str, sep: str) -> str:
|
|
524
524
|
|
525
525
|
def log_files_from_ls(
|
526
526
|
ls: list[FileInfo],
|
527
|
-
formats: list[Literal["eval", "json"]] | None,
|
527
|
+
formats: list[Literal["eval", "json"]] | None = None,
|
528
528
|
descending: bool = True,
|
529
|
+
sort: bool = True,
|
529
530
|
) -> list[EvalLogInfo]:
|
530
531
|
extensions = [f".{format}" for format in (formats or ALL_LOG_FORMATS)]
|
531
532
|
return [
|
532
533
|
log_file_info(file)
|
533
|
-
for file in
|
534
|
-
|
534
|
+
for file in (
|
535
|
+
sorted(
|
536
|
+
ls,
|
537
|
+
key=lambda file: (file.mtime if file.mtime else 0),
|
538
|
+
reverse=descending,
|
539
|
+
)
|
540
|
+
if sort
|
541
|
+
else ls
|
535
542
|
)
|
536
543
|
if file.type == "file" and is_log_file(file.name, extensions)
|
537
544
|
]
|
inspect_ai/log/_log.py
CHANGED
@@ -17,9 +17,11 @@ from pydantic import (
|
|
17
17
|
)
|
18
18
|
from rich.console import Console, RenderableType
|
19
19
|
from rich.traceback import Traceback
|
20
|
+
from shortuuid import uuid
|
20
21
|
|
21
|
-
from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH, PKG_NAME
|
22
|
+
from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH, DESERIALIZING, PKG_NAME
|
22
23
|
from inspect_ai._util.error import EvalError, exception_message
|
24
|
+
from inspect_ai._util.hash import base57_id_hash
|
23
25
|
from inspect_ai._util.logger import warn_once
|
24
26
|
from inspect_ai.approval._policy import ApprovalPolicyConfig
|
25
27
|
from inspect_ai.dataset._dataset import MT, metadata_as
|
@@ -677,6 +679,9 @@ class EvalModelConfig(BaseModel):
|
|
677
679
|
class EvalSpec(BaseModel):
|
678
680
|
"""Eval target and configuration."""
|
679
681
|
|
682
|
+
eval_id: str = Field(default_factory=str)
|
683
|
+
"""Globally unique id for eval."""
|
684
|
+
|
680
685
|
run_id: str = Field(default_factory=str)
|
681
686
|
"""Unique run id"""
|
682
687
|
|
@@ -757,6 +762,21 @@ class EvalSpec(BaseModel):
|
|
757
762
|
# allow field model_args
|
758
763
|
model_config = ConfigDict(protected_namespaces=())
|
759
764
|
|
765
|
+
def model_post_init(self, __context: Any) -> None:
|
766
|
+
# check if deserializing
|
767
|
+
is_deserializing = isinstance(__context, dict) and __context.get(
|
768
|
+
DESERIALIZING, False
|
769
|
+
)
|
770
|
+
|
771
|
+
# Generate eval_id if needed
|
772
|
+
if self.eval_id == "":
|
773
|
+
if is_deserializing:
|
774
|
+
# we want the eval_id to be stable across reads of the eval log so we compose it
|
775
|
+
# as a hash that matches the size/apperance of shortuuid-based uuids
|
776
|
+
self.eval_id = base57_id_hash(self.run_id + self.task_id + self.created)
|
777
|
+
else:
|
778
|
+
self.eval_id = uuid()
|
779
|
+
|
760
780
|
@model_validator(mode="before")
|
761
781
|
@classmethod
|
762
782
|
def read_sandbox_spec(
|
inspect_ai/model/_call_tools.py
CHANGED
@@ -39,6 +39,7 @@ from inspect_ai._util.content import (
|
|
39
39
|
ContentText,
|
40
40
|
ContentVideo,
|
41
41
|
)
|
42
|
+
from inspect_ai._util.exception import TerminateSampleError
|
42
43
|
from inspect_ai._util.format import format_function_call
|
43
44
|
from inspect_ai._util.logger import warn_once
|
44
45
|
from inspect_ai._util.registry import registry_unqualified_name
|
@@ -376,7 +377,7 @@ async def call_tool(
|
|
376
377
|
transcript()._event(
|
377
378
|
SampleLimitEvent(type="operator", limit=1, message=message)
|
378
379
|
)
|
379
|
-
raise
|
380
|
+
raise TerminateSampleError(message)
|
380
381
|
else:
|
381
382
|
raise ToolApprovalError(approval.explanation if approval else None)
|
382
383
|
if approval and approval.modified:
|
inspect_ai/model/_model.py
CHANGED
@@ -1237,9 +1237,10 @@ def tool_result_images_as_user_message(
|
|
1237
1237
|
|
1238
1238
|
Tool responses will have images replaced with "Image content is included below.", and the new user message will contain the images.
|
1239
1239
|
"""
|
1240
|
-
init_accum: ImagesAccumulator = ([], [], [])
|
1241
1240
|
chat_messages, user_message_content, tool_call_ids = functools.reduce(
|
1242
|
-
tool_result_images_reducer,
|
1241
|
+
tool_result_images_reducer,
|
1242
|
+
messages,
|
1243
|
+
(list[ChatMessage](), list[Content](), list[str]()),
|
1243
1244
|
)
|
1244
1245
|
# if the last message was a tool result, we may need to flush the pending stuff here
|
1245
1246
|
return maybe_adding_user_message(chat_messages, user_message_content, tool_call_ids)
|
@@ -1265,9 +1266,10 @@ def tool_result_images_reducer(
|
|
1265
1266
|
and isinstance(message.content, list)
|
1266
1267
|
and any([isinstance(c, ContentImage) for c in message.content])
|
1267
1268
|
):
|
1268
|
-
init_accum: ImageContentAccumulator = ([], [])
|
1269
1269
|
new_user_message_content, edited_tool_message_content = functools.reduce(
|
1270
|
-
tool_result_image_content_reducer,
|
1270
|
+
tool_result_image_content_reducer,
|
1271
|
+
message.content,
|
1272
|
+
(list[Content](), list[Content]()),
|
1271
1273
|
)
|
1272
1274
|
|
1273
1275
|
return (
|
@@ -184,24 +184,23 @@ def openai_responses_chat_choices(
|
|
184
184
|
# │ │ ┌───────────────────┐ │ │ │ │ ┌───────────────────┐ │ │ │ │ ┌───────────────────┐ │ │
|
185
185
|
# │ │ │ type: "reasoning" │ │ │ │ │ │ ContentText │ │ │ │ │ │ type: "reasoning" │ │ │
|
186
186
|
# │ │ │ id: "rs_bbbbbb" │ │ │ │ │ │ text: "" │ │ │ │ │ │ id: "rs_bbbbbb" │ │ │
|
187
|
-
# │ │ │ summary: [] │ │ │ │ │
|
188
|
-
# │ │
|
189
|
-
# │ │
|
190
|
-
# │ │ │
|
191
|
-
# │ │ │
|
192
|
-
# │ │ │
|
193
|
-
# │ │ │
|
194
|
-
# │ │ │ │
|
195
|
-
# │ │ │ │
|
196
|
-
# │ │ │ │
|
197
|
-
# │ │ │ │
|
198
|
-
# │ │ │ │
|
199
|
-
# │ │ │
|
200
|
-
# │ │
|
201
|
-
# │
|
202
|
-
# │ │
|
203
|
-
#
|
204
|
-
# └───────────────────────────┘ │ │ │ "msg_ccccccc" │ │ │
|
187
|
+
# │ │ │ summary: [] │ │ │ │ │ ├───────────────────┤ │ │ │ │ │ summary: [] │ │ │
|
188
|
+
# │ │ ├───────────────────┤ │ │ │ │ │ ContentText │ │ │ │ │ ├───────────────────┤ │ │
|
189
|
+
# │ │ │ type: "message" │ │ │ │ │ │ text: "text1" │ │ │ │ │ │ type: "message" │ │ │
|
190
|
+
# │ │ │ id: "msg_ccccccc" │ │ │ │ │ ├───────────────────┤ │ │ │ │ │ id: "msg_ccccccc" │ │ │
|
191
|
+
# │ │ │ role: "assistant" │ │ │ │ │ │ ContentText │ │ │ │ │ │ role: "assistant" │ │ │
|
192
|
+
# │ │ │ ┌───────────────┐ │ │ │ -> │ │ │ text: "text2" │ │ │ -> │ │ │ ┌───────────────┐ │ │ │
|
193
|
+
# │ │ │ │ Content │ │ │ │ │ │ └───────────────────┘ │ │ │ │ │ │ Content │ │ │ │
|
194
|
+
# │ │ │ │ ┌───────────┐ │ │ │ │ │ └───────────────────────┘ │ │ │ │ │ ┌───────────┐ │ │ │ │
|
195
|
+
# │ │ │ │ │"text1" │ │ │ │ │ │ ┌───────────────────────┐ │ │ │ │ │ │"text1" │ │ │ │ │
|
196
|
+
# │ │ │ │ ├───────────┤ │ │ │ │ │ │ internal │ │ │ │ │ │ ├───────────┤ │ │ │ │
|
197
|
+
# │ │ │ │ │"text2" │ │ │ │ │ │ │ ┌───────────────────┐ │ │ │ │ │ │ │"text2" │ │ │ │ │
|
198
|
+
# │ │ │ │ └───────────┘ │ │ │ │ │ │ │ reasoning_id: │ │ │ │ │ │ │ └───────────┘ │ │ │ │
|
199
|
+
# │ │ │ └───────────────┘ │ │ │ │ │ │ "rs_bbbbbb" │ │ │ │ │ │ └───────────────┘ │ │ │
|
200
|
+
# │ │ └───────────────────┘ │ │ │ │ └───────────────────┘ │ │ │ │ └───────────────────┘ │ │
|
201
|
+
# │ └───────────────────────┘ │ │ │ ┌───────────────────┐ │ │ │ └───────────────────────┘ │
|
202
|
+
# └───────────────────────────┘ │ │ │ output_msg_id: │ │ │ └───────────────────────────┘
|
203
|
+
# │ │ │ "msg_ccccccc" │ │ │
|
205
204
|
# │ │ └───────────────────┘ │ │
|
206
205
|
# │ └───────────────────────┘ │
|
207
206
|
# └───────────────────────────┘
|
@@ -33,7 +33,10 @@ from anthropic.types import (
|
|
33
33
|
ToolUseBlockParam,
|
34
34
|
message_create_params,
|
35
35
|
)
|
36
|
-
from anthropic.types.beta import
|
36
|
+
from anthropic.types.beta import (
|
37
|
+
BetaToolComputerUse20250124Param,
|
38
|
+
BetaToolTextEditor20241022Param,
|
39
|
+
)
|
37
40
|
from pydantic import JsonValue
|
38
41
|
from typing_extensions import override
|
39
42
|
|
@@ -218,6 +221,8 @@ class AnthropicAPI(ModelAPI):
|
|
218
221
|
# tools are generally available for Claude 3.5 Sonnet (new) as well and
|
219
222
|
# can be used without the computer use beta header.
|
220
223
|
betas.append("computer-use-2025-01-24")
|
224
|
+
if any("20241022" in str(tool.get("type", "")) for tool in tools_param):
|
225
|
+
betas.append("computer-use-2024-10-22")
|
221
226
|
if len(betas) > 0:
|
222
227
|
extra_headers["anthropic-beta"] = ",".join(betas)
|
223
228
|
|
@@ -337,6 +342,15 @@ class AnthropicAPI(ModelAPI):
|
|
337
342
|
@override
|
338
343
|
def should_retry(self, ex: Exception) -> bool:
|
339
344
|
if isinstance(ex, APIStatusError):
|
345
|
+
# for unknown reasons, anthropic does not always set status_code == 529
|
346
|
+
# for "overloaded_error" so we check for it explicitly
|
347
|
+
if (
|
348
|
+
isinstance(ex.body, dict)
|
349
|
+
and ex.body.get("error", {}).get("type", "") == "overloaded_error"
|
350
|
+
):
|
351
|
+
return True
|
352
|
+
|
353
|
+
# standard http status code checking
|
340
354
|
return is_retryable_http_status(ex.status_code)
|
341
355
|
elif httpx_should_retry(ex):
|
342
356
|
return True
|
@@ -545,7 +559,7 @@ class AnthropicAPI(ModelAPI):
|
|
545
559
|
|
546
560
|
def text_editor_tool_param(
|
547
561
|
self, tool: ToolInfo
|
548
|
-
) ->
|
562
|
+
) -> ToolTextEditor20250124Param | BetaToolTextEditor20241022Param | None:
|
549
563
|
# check for compatible 'text editor' tool
|
550
564
|
if tool.name == "text_editor" and (
|
551
565
|
sorted(tool.parameters.properties.keys())
|
@@ -561,8 +575,14 @@ class AnthropicAPI(ModelAPI):
|
|
561
575
|
]
|
562
576
|
)
|
563
577
|
):
|
564
|
-
return
|
565
|
-
|
578
|
+
return (
|
579
|
+
BetaToolTextEditor20241022Param(
|
580
|
+
type="text_editor_20241022", name="str_replace_editor"
|
581
|
+
)
|
582
|
+
if self.is_claude_3_5()
|
583
|
+
else ToolTextEditor20250124Param(
|
584
|
+
type="text_editor_20250124", name="str_replace_editor"
|
585
|
+
)
|
566
586
|
)
|
567
587
|
# not a text_editor tool
|
568
588
|
else:
|
@@ -571,7 +591,10 @@ class AnthropicAPI(ModelAPI):
|
|
571
591
|
|
572
592
|
# tools can be either a stock tool param or a special Anthropic native use tool param
|
573
593
|
ToolParamDef = (
|
574
|
-
ToolParam
|
594
|
+
ToolParam
|
595
|
+
| BetaToolComputerUse20250124Param
|
596
|
+
| ToolTextEditor20250124Param
|
597
|
+
| BetaToolTextEditor20241022Param
|
575
598
|
)
|
576
599
|
|
577
600
|
|
@@ -580,6 +603,7 @@ def add_cache_control(
|
|
580
603
|
| ToolParam
|
581
604
|
| BetaToolComputerUse20250124Param
|
582
605
|
| ToolTextEditor20250124Param
|
606
|
+
| BetaToolTextEditor20241022Param
|
583
607
|
| dict[str, Any],
|
584
608
|
) -> None:
|
585
609
|
cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
|
@@ -844,6 +868,7 @@ def _names_for_tool_call(
|
|
844
868
|
"""
|
845
869
|
mappings = (
|
846
870
|
(INTERNAL_COMPUTER_TOOL_NAME, "computer_20250124", "computer"),
|
871
|
+
("str_replace_editor", "text_editor_20241022", "text_editor"),
|
847
872
|
("str_replace_editor", "text_editor_20250124", "text_editor"),
|
848
873
|
("bash", "bash_20250124", "bash_session"),
|
849
874
|
)
|
@@ -200,6 +200,7 @@ def multiple_choice(
|
|
200
200
|
template: str | None = None,
|
201
201
|
cot: bool = False,
|
202
202
|
multiple_correct: bool = False,
|
203
|
+
max_tokens: int | None = None,
|
203
204
|
**kwargs: Unpack[DeprecatedArgs],
|
204
205
|
) -> Solver:
|
205
206
|
"""Multiple choice question solver. Formats a multiple choice question prompt, then calls `generate()`.
|
@@ -226,6 +227,8 @@ def multiple_choice(
|
|
226
227
|
squares? A) 3, B) 4, C) 9" has multiple correct answers, B and C. Leave
|
227
228
|
as `False` if there's exactly one correct answer from the choices
|
228
229
|
available. NOTE: this has no effect if you provide a custom template.
|
230
|
+
max_tokens: Default `None`. Controls the number of tokens generated through the call
|
231
|
+
to generate().
|
229
232
|
**kwargs (Any): Deprecated arguments for backward compatibility.
|
230
233
|
|
231
234
|
#### Shuffling
|
@@ -282,7 +285,7 @@ def multiple_choice(
|
|
282
285
|
template=str(template),
|
283
286
|
)
|
284
287
|
|
285
|
-
state = await generate(state)
|
288
|
+
state = await generate(state, max_tokens=max_tokens)
|
286
289
|
|
287
290
|
answers = parse_answers(state)
|
288
291
|
if answers and answers.group(1):
|
inspect_ai/solver/_task_state.py
CHANGED
@@ -138,7 +138,7 @@ class TaskState:
|
|
138
138
|
The `TaskState` represents the internal state of the `Task` being run for a single `Sample`.
|
139
139
|
|
140
140
|
The `TaskState` is passed to and returned from each solver during a sample's
|
141
|
-
evaluation. It allows us to
|
141
|
+
evaluation. It allows us to maintain the manipulated message history, the tools
|
142
142
|
available to the model, the final output of the model, and whether the task
|
143
143
|
is completed or has hit a limit.
|
144
144
|
"""
|
@@ -204,13 +204,17 @@ class TaskState:
|
|
204
204
|
Convenience function for accessing the initial input from the `Sample` as a string.
|
205
205
|
|
206
206
|
If the `input` is a `list[ChatMessage]`, this will return the text from
|
207
|
-
the
|
207
|
+
the last chat message
|
208
208
|
"""
|
209
209
|
if isinstance(self._input, str):
|
210
210
|
return self._input
|
211
211
|
else:
|
212
212
|
input = next(
|
213
|
-
(
|
213
|
+
(
|
214
|
+
message.text
|
215
|
+
for message in reversed(self._input)
|
216
|
+
if message.role == "user"
|
217
|
+
),
|
214
218
|
None,
|
215
219
|
)
|
216
220
|
if input:
|
@@ -231,7 +235,7 @@ class TaskState:
|
|
231
235
|
write access to the user chat prompt. Raises an
|
232
236
|
exception if there is no user prompt
|
233
237
|
"""
|
234
|
-
prompt = next((m for m in self.messages if m.role == "user"), None)
|
238
|
+
prompt = next((m for m in reversed(self.messages) if m.role == "user"), None)
|
235
239
|
if prompt:
|
236
240
|
return prompt
|
237
241
|
else:
|
inspect_ai/tool/_mcp/_context.py
CHANGED
@@ -2,13 +2,11 @@ from contextlib import _AsyncGeneratorContextManager
|
|
2
2
|
from typing import TypeAlias
|
3
3
|
|
4
4
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
5
|
-
from mcp.
|
6
|
-
JSONRPCMessage,
|
7
|
-
)
|
5
|
+
from mcp.shared.message import SessionMessage
|
8
6
|
|
9
7
|
MCPServerContext: TypeAlias = _AsyncGeneratorContextManager[
|
10
8
|
tuple[
|
11
|
-
MemoryObjectReceiveStream[
|
12
|
-
MemoryObjectSendStream[
|
9
|
+
MemoryObjectReceiveStream[SessionMessage | Exception],
|
10
|
+
MemoryObjectSendStream[SessionMessage],
|
13
11
|
],
|
14
12
|
]
|
inspect_ai/tool/_mcp/_sandbox.py
CHANGED
@@ -5,6 +5,7 @@ from typing import TextIO
|
|
5
5
|
import anyio
|
6
6
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
7
7
|
from mcp import JSONRPCRequest, StdioServerParameters
|
8
|
+
from mcp.shared.message import SessionMessage
|
8
9
|
from mcp.types import JSONRPCMessage, JSONRPCNotification
|
9
10
|
|
10
11
|
from inspect_ai.tool._tool_support_helpers import (
|
@@ -36,12 +37,12 @@ async def sandbox_client( # type: ignore
|
|
36
37
|
)
|
37
38
|
|
38
39
|
# read_stream is remote process's stdout
|
39
|
-
read_stream: MemoryObjectReceiveStream[
|
40
|
-
read_stream_writer: MemoryObjectSendStream[
|
40
|
+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
41
|
+
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
41
42
|
|
42
43
|
# write_stream is remote process's stdin
|
43
|
-
write_stream: MemoryObjectSendStream[
|
44
|
-
write_stream_reader: MemoryObjectReceiveStream[
|
44
|
+
write_stream: MemoryObjectSendStream[SessionMessage]
|
45
|
+
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
45
46
|
|
46
47
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
47
48
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
@@ -64,18 +65,20 @@ async def sandbox_client( # type: ignore
|
|
64
65
|
async with write_stream_reader:
|
65
66
|
# This reads messages until the stream is closed
|
66
67
|
async for message in write_stream_reader:
|
67
|
-
root = message.root
|
68
|
+
root = message.message.root
|
68
69
|
if isinstance(root, JSONRPCRequest):
|
69
70
|
await read_stream_writer.send(
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
71
|
+
SessionMessage(
|
72
|
+
message=await exec_model_request(
|
73
|
+
sandbox=sandbox_environment,
|
74
|
+
method="mcp_send_request",
|
75
|
+
params={
|
76
|
+
"session_id": session_id,
|
77
|
+
"request": root.model_dump(),
|
78
|
+
},
|
79
|
+
result_type=JSONRPCMessage,
|
80
|
+
timeout=timeout,
|
81
|
+
)
|
79
82
|
)
|
80
83
|
)
|
81
84
|
elif isinstance(root, JSONRPCNotification):
|
inspect_ai/tool/_mcp/server.py
CHANGED
inspect_ai/tool/_tools/_think.py
CHANGED
@@ -41,7 +41,7 @@ def think(
|
|
41
41
|
def think_tool_viewer() -> ToolCallViewer:
|
42
42
|
def viewer(tool_call: ToolCall) -> ToolCallView:
|
43
43
|
call = ToolCallContent(
|
44
|
-
format="markdown", content=tool_call.arguments
|
44
|
+
format="markdown", content=tool_call.arguments.get("thought", "")
|
45
45
|
)
|
46
46
|
return ToolCallView(call=call)
|
47
47
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import
|
2
|
+
from typing import Awaitable, Callable
|
3
3
|
|
4
4
|
import anyio
|
5
5
|
import httpx
|
@@ -16,8 +16,6 @@ from inspect_ai._util.error import PrerequisiteError
|
|
16
16
|
from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
|
17
17
|
from inspect_ai.util._concurrency import concurrency
|
18
18
|
|
19
|
-
from .._tool import Tool, ToolResult, tool
|
20
|
-
|
21
19
|
DEFAULT_RELEVANCE_PROMPT = """I am trying to answer the following question and need to find the most relevant information on the web. Please let me know if the following content is relevant to the question or not. You should just respond with "yes" or "no".
|
22
20
|
|
23
21
|
Question: {question}
|
@@ -31,59 +29,35 @@ class SearchLink:
|
|
31
29
|
self.snippet = snippet
|
32
30
|
|
33
31
|
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
@tool
|
40
|
-
def web_search(
|
41
|
-
provider: Literal["google"] = "google",
|
42
|
-
num_results: int = 3,
|
43
|
-
max_provider_calls: int = 3,
|
44
|
-
max_connections: int = 10,
|
45
|
-
model: str | None = None,
|
46
|
-
) -> Tool:
|
47
|
-
"""Web search tool.
|
48
|
-
|
49
|
-
A tool that can be registered for use by models to search the web. Use
|
50
|
-
the `use_tools()` solver to make the tool available (e.g. `use_tools(web_search())`))
|
51
|
-
|
52
|
-
A web search is conducted using the specified provider, the results are parsed for relevance
|
53
|
-
using the specified model, and the top 'num_results' relevant pages are returned.
|
54
|
-
|
55
|
-
See further documentation at <https://inspect.aisi.org.uk/tools-standard.html#sec-web-search>.
|
56
|
-
|
57
|
-
Args:
|
58
|
-
provider: Search provider (defaults to "google", currently
|
59
|
-
the only provider). Possible future providers include "brave" and "bing".
|
60
|
-
num_results: Number of web search result pages to return to the model.
|
61
|
-
max_provider_calls: Maximum number of search calls to make to the search provider.
|
62
|
-
max_connections: Maximum number of concurrent connections to API
|
63
|
-
endpoint of search provider.
|
64
|
-
model: Model used to parse web pages for relevance.
|
32
|
+
def maybe_get_google_api_keys() -> tuple[str, str] | None:
|
33
|
+
"""
|
34
|
+
Get Google API keys from environment variables.
|
65
35
|
|
66
36
|
Returns:
|
67
|
-
|
37
|
+
tuple: A tuple containing the Google API key and the Google CSE ID.
|
68
38
|
"""
|
69
|
-
|
70
|
-
|
39
|
+
google_api_key = os.environ.get("GOOGLE_CSE_API_KEY", None)
|
40
|
+
google_cse_id = os.environ.get("GOOGLE_CSE_ID", None)
|
41
|
+
return (google_api_key, google_cse_id) if google_api_key and google_cse_id else None
|
71
42
|
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
43
|
+
|
44
|
+
def google_search_provider(
|
45
|
+
num_results: int,
|
46
|
+
max_provider_calls: int,
|
47
|
+
max_connections: int,
|
48
|
+
model: str | None,
|
49
|
+
) -> Callable[[str], Awaitable[str | None]]:
|
50
|
+
keys = maybe_get_google_api_keys()
|
51
|
+
if not keys:
|
52
|
+
raise PrerequisiteError(
|
53
|
+
"GOOGLE_CSE_ID and/or GOOGLE_CSE_API_KEY not set in the environment. Please ensure these variables are defined to use Google Custom Search with the web_search tool.\n\nLearn more about the Google web search provider at https://inspect.aisi.org.uk/tools.html#google-provider"
|
77
54
|
)
|
55
|
+
google_api_key, google_cse_id = keys
|
78
56
|
|
79
|
-
#
|
80
|
-
|
81
|
-
"""
|
82
|
-
Use the web_search tool to perform keyword searches of the web.
|
57
|
+
# Create the client within the provider
|
58
|
+
client = httpx.AsyncClient()
|
83
59
|
|
84
|
-
|
85
|
-
query (str): Search query.
|
86
|
-
"""
|
60
|
+
async def search(query: str) -> str | None:
|
87
61
|
# limit number of concurrent searches
|
88
62
|
page_contents: list[str] = []
|
89
63
|
urls: list[str] = []
|
@@ -92,8 +66,8 @@ def web_search(
|
|
92
66
|
|
93
67
|
# Paginate through search results until we have successfully extracted num_results pages or we have reached max_provider_calls
|
94
68
|
while len(page_contents) < num_results and search_calls < max_provider_calls:
|
95
|
-
async with concurrency(
|
96
|
-
links = await
|
69
|
+
async with concurrency("google_web_search", max_connections):
|
70
|
+
links = await _search(query, start_idx=search_calls * 10)
|
97
71
|
|
98
72
|
async with anyio.create_task_group() as tg:
|
99
73
|
|
@@ -114,19 +88,39 @@ def web_search(
|
|
114
88
|
search_calls += 1
|
115
89
|
|
116
90
|
all_page_contents = "\n\n".join(page_contents)
|
117
|
-
if all_page_contents == ""
|
118
|
-
response: ToolResult = (
|
119
|
-
"I'm sorry, I couldn't find any relevant information on the web."
|
120
|
-
)
|
121
|
-
else:
|
122
|
-
response = (
|
123
|
-
"Here are your web search results. Please read them carefully as they may be useful later! "
|
124
|
-
+ all_page_contents
|
125
|
-
)
|
91
|
+
return None if all_page_contents == "" else all_page_contents
|
126
92
|
|
127
|
-
|
93
|
+
async def _search(query: str, start_idx: int) -> list[SearchLink]:
|
94
|
+
# List of allowed parameters can be found https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
|
95
|
+
search_params = {
|
96
|
+
"q": query,
|
97
|
+
"key": google_api_key,
|
98
|
+
"cx": google_cse_id,
|
99
|
+
"start": start_idx,
|
100
|
+
}
|
101
|
+
search_url = "https://www.googleapis.com/customsearch/v1?" + "&".join(
|
102
|
+
[f"{key}={value}" for key, value in search_params.items()]
|
103
|
+
)
|
128
104
|
|
129
|
-
|
105
|
+
# retry up to 5 times over a period of up to 1 minute
|
106
|
+
@retry(
|
107
|
+
wait=wait_exponential_jitter(),
|
108
|
+
stop=stop_after_attempt(5) | stop_after_delay(60),
|
109
|
+
retry=retry_if_exception(httpx_should_retry),
|
110
|
+
before_sleep=log_httpx_retry_attempt(search_url),
|
111
|
+
)
|
112
|
+
async def execute_search() -> httpx.Response:
|
113
|
+
return await client.get(search_url)
|
114
|
+
|
115
|
+
result = await execute_search()
|
116
|
+
data = result.json()
|
117
|
+
|
118
|
+
if "items" in data:
|
119
|
+
return [SearchLink(item["link"], item["snippet"]) for item in data["items"]]
|
120
|
+
else:
|
121
|
+
return []
|
122
|
+
|
123
|
+
return search
|
130
124
|
|
131
125
|
|
132
126
|
async def page_if_relevant(
|
@@ -183,44 +177,3 @@ async def page_if_relevant(
|
|
183
177
|
return full_text
|
184
178
|
else:
|
185
179
|
return None
|
186
|
-
|
187
|
-
|
188
|
-
def google_search_provider(client: httpx.AsyncClient) -> SearchProvider:
|
189
|
-
google_api_key = os.environ.get("GOOGLE_CSE_API_KEY", None)
|
190
|
-
google_cse_id = os.environ.get("GOOGLE_CSE_ID", None)
|
191
|
-
if not google_api_key or not google_cse_id:
|
192
|
-
raise PrerequisiteError(
|
193
|
-
"GOOGLE_CSE_ID and/or GOOGLE_CSE_API_KEY not set in the environment. Please ensure these variables are defined to use Google Custom Search with the web_search tool.\n\nLearn more about the Google web search provider at https://inspect.aisi.org.uk/tools.html#google-provider"
|
194
|
-
)
|
195
|
-
|
196
|
-
async def search(query: str, start_idx: int) -> list[SearchLink]:
|
197
|
-
# List of allowed parameters can be found https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
|
198
|
-
search_params = {
|
199
|
-
"q": query,
|
200
|
-
"key": google_api_key,
|
201
|
-
"cx": google_cse_id,
|
202
|
-
"start": start_idx,
|
203
|
-
}
|
204
|
-
search_url = "https://www.googleapis.com/customsearch/v1?" + "&".join(
|
205
|
-
[f"{key}={value}" for key, value in search_params.items()]
|
206
|
-
)
|
207
|
-
|
208
|
-
# retry up to 5 times over a period of up to 1 minute
|
209
|
-
@retry(
|
210
|
-
wait=wait_exponential_jitter(),
|
211
|
-
stop=stop_after_attempt(5) | stop_after_delay(60),
|
212
|
-
retry=retry_if_exception(httpx_should_retry),
|
213
|
-
before_sleep=log_httpx_retry_attempt(search_url),
|
214
|
-
)
|
215
|
-
async def execute_search() -> httpx.Response:
|
216
|
-
return await client.get(search_url)
|
217
|
-
|
218
|
-
result = await execute_search()
|
219
|
-
data = result.json()
|
220
|
-
|
221
|
-
if "items" in data:
|
222
|
-
return [SearchLink(item["link"], item["snippet"]) for item in data["items"]]
|
223
|
-
else:
|
224
|
-
return []
|
225
|
-
|
226
|
-
return search
|