inspect-ai 0.3.103__py3-none-any.whl → 0.3.105__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/common.py +2 -1
- inspect_ai/_cli/eval.py +2 -2
- inspect_ai/_display/core/active.py +3 -0
- inspect_ai/_display/core/config.py +1 -0
- inspect_ai/_display/core/panel.py +21 -13
- inspect_ai/_display/core/results.py +3 -7
- inspect_ai/_display/core/rich.py +3 -5
- inspect_ai/_display/log/__init__.py +0 -0
- inspect_ai/_display/log/display.py +173 -0
- inspect_ai/_display/plain/display.py +2 -2
- inspect_ai/_display/rich/display.py +2 -4
- inspect_ai/_display/textual/app.py +1 -6
- inspect_ai/_display/textual/widgets/task_detail.py +3 -14
- inspect_ai/_display/textual/widgets/tasks.py +1 -1
- inspect_ai/_eval/eval.py +1 -1
- inspect_ai/_eval/evalset.py +3 -3
- inspect_ai/_eval/registry.py +6 -1
- inspect_ai/_eval/run.py +5 -1
- inspect_ai/_eval/task/constants.py +1 -0
- inspect_ai/_eval/task/log.py +2 -0
- inspect_ai/_eval/task/run.py +65 -39
- inspect_ai/_util/citation.py +88 -0
- inspect_ai/_util/content.py +24 -2
- inspect_ai/_util/json.py +17 -2
- inspect_ai/_util/registry.py +19 -4
- inspect_ai/_view/schema.py +0 -6
- inspect_ai/_view/server.py +17 -0
- inspect_ai/_view/www/dist/assets/index.css +93 -31
- inspect_ai/_view/www/dist/assets/index.js +10639 -10011
- inspect_ai/_view/www/log-schema.json +418 -1
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +149 -0
- inspect_ai/_view/www/node_modules/katex/src/fonts/generate_fonts.py +58 -0
- inspect_ai/_view/www/node_modules/katex/src/metrics/extract_tfms.py +114 -0
- inspect_ai/_view/www/node_modules/katex/src/metrics/extract_ttfs.py +122 -0
- inspect_ai/_view/www/node_modules/katex/src/metrics/format_json.py +28 -0
- inspect_ai/_view/www/node_modules/katex/src/metrics/parse_tfm.py +211 -0
- inspect_ai/_view/www/package.json +2 -2
- inspect_ai/_view/www/src/@types/log.d.ts +140 -39
- inspect_ai/_view/www/src/app/content/RecordTree.tsx +13 -0
- inspect_ai/_view/www/src/app/log-view/LogView.tsx +1 -1
- inspect_ai/_view/www/src/app/routing/logNavigation.ts +31 -0
- inspect_ai/_view/www/src/app/routing/{navigationHooks.ts → sampleNavigation.ts} +39 -86
- inspect_ai/_view/www/src/app/samples/SampleDialog.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.module.css +4 -0
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +17 -0
- inspect_ai/_view/www/src/app/samples/chat/MessageCitations.module.css +16 -0
- inspect_ai/_view/www/src/app/samples/chat/MessageCitations.tsx +63 -0
- inspect_ai/_view/www/src/app/samples/chat/MessageContent.module.css +6 -0
- inspect_ai/_view/www/src/app/samples/chat/MessageContent.tsx +174 -25
- inspect_ai/_view/www/src/app/samples/chat/MessageContents.tsx +21 -3
- inspect_ai/_view/www/src/app/samples/chat/content-data/ContentDataView.module.css +7 -0
- inspect_ai/_view/www/src/app/samples/chat/content-data/ContentDataView.tsx +111 -0
- inspect_ai/_view/www/src/app/samples/chat/content-data/WebSearch.module.css +10 -0
- inspect_ai/_view/www/src/app/samples/chat/content-data/WebSearch.tsx +14 -0
- inspect_ai/_view/www/src/app/samples/chat/content-data/WebSearchResults.module.css +19 -0
- inspect_ai/_view/www/src/app/samples/chat/content-data/WebSearchResults.tsx +49 -0
- inspect_ai/_view/www/src/app/samples/chat/messages.ts +7 -1
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/chat/types.ts +4 -0
- inspect_ai/_view/www/src/app/samples/list/SampleList.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +26 -0
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +14 -3
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +359 -7
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/language.ts +6 -0
- inspect_ai/_view/www/src/app/samples/sampleLimit.ts +2 -2
- inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +4 -4
- inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.tsx +1 -1
- inspect_ai/_view/www/src/client/api/api-browser.ts +25 -0
- inspect_ai/_view/www/src/client/api/api-http.ts +3 -0
- inspect_ai/_view/www/src/client/api/api-vscode.ts +6 -0
- inspect_ai/_view/www/src/client/api/client-api.ts +3 -0
- inspect_ai/_view/www/src/client/api/jsonrpc.ts +1 -0
- inspect_ai/_view/www/src/client/api/types.ts +3 -0
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +15 -2
- inspect_ai/_view/www/src/state/samplePolling.ts +17 -1
- inspect_ai/_view/www/src/tests/README.md +2 -2
- inspect_ai/_view/www/src/utils/git.ts +3 -1
- inspect_ai/_view/www/src/utils/html.ts +6 -0
- inspect_ai/agent/_handoff.py +8 -5
- inspect_ai/agent/_react.py +5 -5
- inspect_ai/dataset/_dataset.py +1 -1
- inspect_ai/log/_condense.py +5 -0
- inspect_ai/log/_file.py +4 -1
- inspect_ai/log/_log.py +9 -4
- inspect_ai/log/_recorders/json.py +4 -2
- inspect_ai/log/_samples.py +5 -0
- inspect_ai/log/_util.py +2 -0
- inspect_ai/model/__init__.py +14 -0
- inspect_ai/model/_call_tools.py +17 -8
- inspect_ai/model/_chat_message.py +3 -0
- inspect_ai/model/_openai_responses.py +80 -34
- inspect_ai/model/_providers/_anthropic_citations.py +158 -0
- inspect_ai/model/_providers/_google_citations.py +100 -0
- inspect_ai/model/_providers/anthropic.py +219 -36
- inspect_ai/model/_providers/google.py +98 -22
- inspect_ai/model/_providers/mistral.py +20 -7
- inspect_ai/model/_providers/openai.py +11 -10
- inspect_ai/model/_providers/openai_compatible.py +3 -2
- inspect_ai/model/_providers/openai_responses.py +2 -5
- inspect_ai/model/_providers/perplexity.py +123 -0
- inspect_ai/model/_providers/providers.py +13 -2
- inspect_ai/model/_providers/vertex.py +3 -0
- inspect_ai/model/_trim.py +5 -0
- inspect_ai/tool/__init__.py +14 -0
- inspect_ai/tool/_mcp/_mcp.py +5 -2
- inspect_ai/tool/_mcp/sampling.py +19 -3
- inspect_ai/tool/_mcp/server.py +1 -1
- inspect_ai/tool/_tool.py +10 -1
- inspect_ai/tool/_tools/_web_search/_base_http_provider.py +104 -0
- inspect_ai/tool/_tools/_web_search/_exa.py +78 -0
- inspect_ai/tool/_tools/_web_search/_google.py +22 -25
- inspect_ai/tool/_tools/_web_search/_tavily.py +47 -65
- inspect_ai/tool/_tools/_web_search/_web_search.py +83 -36
- inspect_ai/tool/_tools/_web_search/_web_search_provider.py +7 -0
- inspect_ai/util/__init__.py +8 -0
- inspect_ai/util/_background.py +64 -0
- inspect_ai/util/_display.py +11 -2
- inspect_ai/util/_limit.py +72 -5
- inspect_ai/util/_sandbox/__init__.py +2 -0
- inspect_ai/util/_sandbox/docker/compose.py +2 -2
- inspect_ai/util/_sandbox/service.py +28 -7
- inspect_ai/util/_span.py +12 -1
- inspect_ai/util/_subprocess.py +51 -38
- {inspect_ai-0.3.103.dist-info → inspect_ai-0.3.105.dist-info}/METADATA +2 -2
- {inspect_ai-0.3.103.dist-info → inspect_ai-0.3.105.dist-info}/RECORD +134 -109
- /inspect_ai/model/{_openai_computer_use.py → _providers/_openai_computer_use.py} +0 -0
- /inspect_ai/model/{_openai_web_search.py → _providers/_openai_web_search.py} +0 -0
- {inspect_ai-0.3.103.dist-info → inspect_ai-0.3.105.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.103.dist-info → inspect_ai-0.3.105.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.103.dist-info → inspect_ai-0.3.105.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.103.dist-info → inspect_ai-0.3.105.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/openai.py
CHANGED
@@ -13,6 +13,7 @@ from openai._types import NOT_GIVEN
 from openai.types.chat import ChatCompletion
 from typing_extensions import override
 
+from inspect_ai._util.deprecation import deprecation_warning
 from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
@@ -64,6 +65,8 @@ class OpenAIAPI(ModelAPI):
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
         responses_api: bool | None = None,
+        # Can't use the XxxDeprecatedArgs approach since this already has a **param
+        # but responses_store is deprecated and should not be used.
         responses_store: Literal["auto"] | bool = "auto",
         service_tier: str | None = None,
         client_timeout: float | None = None,
@@ -88,19 +91,18 @@ class OpenAIAPI(ModelAPI):
         )
 
         # is this a model we use responses api by default for?
-            or self.is_codex()
-        )
+        responses_preferred = (
+            self.is_o_series() and not self.is_o1_early()
+        ) or self.is_codex()
 
         # resolve whether we are forcing the responses api
-        self.responses_api =
+        self.responses_api = self.is_computer_use_preview() or (
+            responses_api if responses_api is not None else responses_preferred
+        )
 
         # resolve whether we are using the responses store
-            responses_store
-        )
+        if isinstance(responses_store, bool):
+            deprecation_warning("`responses_store` is no longer supported.")
 
         # set service tier if specified
         self.service_tier = service_tier
@@ -260,7 +262,6 @@ class OpenAIAPI(ModelAPI):
             tool_choice=tool_choice,
             config=config,
             service_tier=self.service_tier,
-            store=self.responses_store,
         )
 
         # allocate request_id (so we can see it from ModelCall)
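With this change the Responses API is selected by default for o-series models (excluding early o1) and codex models, is always used for computer-use-preview, and passing a bool for `responses_store` now only emits a deprecation warning. A minimal sketch of forcing the Responses API from caller code, assuming `get_model` forwards extra keyword args to the provider (the model name is illustrative):

    from inspect_ai.model import get_model

    # extra keyword args are passed through to OpenAIAPI.__init__,
    # so this forces the Responses API for this model instance
    model = get_model("openai/gpt-4o", responses_api=True)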
inspect_ai/model/_providers/openai_compatible.py
CHANGED
@@ -61,7 +61,8 @@ class OpenAICompatibleAPI(ModelAPI):
         self.service = service
 
         # compute api key
+        service_env_name = self.service.upper().replace("-", "_")
+        api_key_var = f"{service_env_name}_API_KEY"
 
         super().__init__(
             model_name=model_name,
@@ -82,7 +83,7 @@ class OpenAICompatibleAPI(ModelAPI):
 
         # use service prefix to lookup base_url
         if not self.base_url:
-            base_url_var = f"{
+            base_url_var = f"{service_env_name}_BASE_URL"
            self.base_url = model_base_url(base_url, [base_url_var]) or service_base_url
            if not self.base_url:
                raise environment_prerequisite_error(
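The key change here: environment variable names are now derived by upper-casing the service name and mapping hyphens to underscores, so hyphenated service names resolve to valid variable names. A small sketch of the derivation (the "together-ai" service name is illustrative):

    def service_env_vars(service: str) -> tuple[str, str]:
        # mirrors the derivation in OpenAICompatibleAPI
        name = service.upper().replace("-", "_")
        return f"{name}_API_KEY", f"{name}_BASE_URL"

    assert service_env_vars("together-ai") == (
        "TOGETHER_AI_API_KEY",
        "TOGETHER_AI_BASE_URL",
    )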
inspect_ai/model/_providers/openai_responses.py
CHANGED
@@ -40,7 +40,6 @@ async def generate_responses(
     tool_choice: ToolChoice,
     config: GenerateConfig,
     service_tier: str | None,
-    store: bool,
 ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
     # allocate request_id (so we can see it from ModelCall)
     request_id = http_hooks.start_request()
@@ -65,7 +64,7 @@ async def generate_responses(
         else NOT_GIVEN
     )
     request = dict(
-        input=await openai_responses_inputs(input, model_name
+        input=await openai_responses_inputs(input, model_name),
         tools=tool_params,
         tool_choice=openai_responses_tool_choice(tool_choice, tool_params)
         if isinstance(tool_params, list) and tool_choice != "auto"
@@ -77,7 +76,6 @@ async def generate_responses(
             config=config,
             service_tier=service_tier,
             tools=len(tools) > 0,
-            store=store,
         ),
     )
 
@@ -125,7 +123,6 @@ def completion_params_responses(
     config: GenerateConfig,
     service_tier: str | None,
     tools: bool,
-    store: bool,
 ) -> dict[str, Any]:
     # TODO: we'll need a computer_use_preview bool for the 'include'
     # and 'reasoning' parameters
@@ -135,7 +132,7 @@ def completion_params_responses(
             f"OpenAI Responses API does not support the '{param}' parameter.",
         )
 
-    params: dict[str, Any] = dict(model=model_name
+    params: dict[str, Any] = dict(model=model_name)
     if service_tier is not None:
         params["service_tier"] = service_tier
     if config.max_tokens is not None:
inspect_ai/model/_providers/perplexity.py
ADDED
@@ -0,0 +1,123 @@
+from typing import Any, cast
+
+from openai.types.chat import ChatCompletion
+
+from inspect_ai._util.citation import UrlCitation
+from inspect_ai._util.content import ContentText
+from inspect_ai.model._generate_config import GenerateConfig
+from inspect_ai.model._model_output import ModelOutput, ModelUsage
+from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
+from inspect_ai.tool import ToolChoice, ToolInfo
+
+from .._chat_message import ChatMessage
+from .._model_call import ModelCall
+from .._model_output import ChatCompletionChoice
+
+
+class PerplexityAPI(OpenAICompatibleAPI):
+    """Model provider for Perplexity AI."""
+
+    def __init__(
+        self,
+        model_name: str,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        config: GenerateConfig = GenerateConfig(),
+        **model_args: Any,
+    ) -> None:
+        super().__init__(
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key,
+            config=config,
+            service="Perplexity",
+            service_base_url="https://api.perplexity.ai",
+            **model_args,
+        )
+
+        self._response: dict[str, Any] | None = None
+
+    def on_response(self, response: dict[str, Any]) -> None:
+        """Capture the raw response for post-processing."""
+        self._response = response
+
+    async def generate(
+        self,
+        input: list["ChatMessage"],
+        tools: list["ToolInfo"],
+        tool_choice: "ToolChoice",
+        config: GenerateConfig,
+    ) -> tuple[ModelOutput | Exception, "ModelCall"]:
+        result = await super().generate(input, tools, tool_choice, config)
+        output, call = cast(tuple[ModelOutput, "ModelCall"], result)
+
+        if self._response:
+            response = self._response
+
+            # attach citations if search results are returned
+            search_results = response.get("search_results")
+            if isinstance(search_results, list):
+                citations = [
+                    UrlCitation(title=sr.get("title"), url=sr.get("url", ""))
+                    for sr in search_results
+                    if isinstance(sr, dict) and sr.get("url") is not None
+                ]
+                if citations:
+                    for choice in output.choices:
+                        msg = choice.message
+                        if isinstance(msg.content, str):
+                            msg.content = [
+                                ContentText(text=msg.content, citations=citations)
+                            ]
+                        else:
+                            added = False
+                            for content in msg.content:
+                                if (
+                                    isinstance(content, ContentText)
+                                    and getattr(content, "citations", None) is None
+                                ):
+                                    content.citations = citations
+                                    added = True
+                                    break
+                            if not added:
+                                msg.content.append(
+                                    ContentText(text="", citations=citations)
+                                )
+
+            # update usage with additional metrics
+            usage_data = response.get("usage")
+            if isinstance(usage_data, dict):
+                extra_usage = {
+                    k: usage_data.get(k)
+                    for k in [
+                        "search_context_size",
+                        "citation_tokens",
+                        "num_search_queries",
+                    ]
+                    if k in usage_data
+                }
+                if output.usage:
+                    output.usage.reasoning_tokens = usage_data.get("reasoning_tokens")
+                else:
+                    output.usage = ModelUsage(
+                        input_tokens=usage_data.get("prompt_tokens", 0),
+                        output_tokens=usage_data.get("completion_tokens", 0),
+                        total_tokens=usage_data.get("total_tokens", 0),
+                        reasoning_tokens=usage_data.get("reasoning_tokens"),
+                    )
+                if extra_usage:
+                    output.metadata = output.metadata or {}
+                    output.metadata.update(extra_usage)
+
+            # keep search_results for reference
+            if search_results:
+                output.metadata = output.metadata or {}
+                output.metadata["search_results"] = search_results
+
+        return output, call
+
+    def chat_choices_from_completion(
+        self, completion: ChatCompletion, tools: list[ToolInfo]
+    ) -> list[ChatCompletionChoice]:
+        return chat_choices_from_openai(completion, tools)
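The new provider rides on OpenAICompatibleAPI: it attaches UrlCitations built from Perplexity's `search_results`, and surfaces `search_context_size`, `citation_tokens`, and `num_search_queries` via `output.metadata`. A usage sketch (the model name is illustrative; PERPLEXITY_API_KEY must be set in the environment):

    from inspect_ai.model import get_model

    # "sonar" is an illustrative Perplexity model name
    model = get_model("perplexity/sonar")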
inspect_ai/model/_providers/providers.py
CHANGED
@@ -59,7 +59,7 @@ def openai_api() -> type[ModelAPI]:
 def anthropic() -> type[ModelAPI]:
     FEATURE = "Anthropic API"
     PACKAGE = "anthropic"
-    MIN_VERSION = "0.
+    MIN_VERSION = "0.52.0"
 
     # verify we have the package
     try:
@@ -157,7 +157,7 @@ def cf() -> type[ModelAPI]:
 def mistral() -> type[ModelAPI]:
     FEATURE = "Mistral API"
     PACKAGE = "mistralai"
-    MIN_VERSION = "1.
+    MIN_VERSION = "1.8.2"
 
     # verify we have the package
     try:
@@ -218,6 +218,17 @@ def openrouter() -> type[ModelAPI]:
     return OpenRouterAPI
 
 
+@modelapi(name="perplexity")
+def perplexity() -> type[ModelAPI]:
+    # validate
+    validate_openai_client("Perplexity API")
+
+    # in the clear
+    from .perplexity import PerplexityAPI
+
+    return PerplexityAPI
+
+
 @modelapi(name="llama-cpp-python")
 def llama_cpp_python() -> type[ModelAPI]:
     # validate
inspect_ai/model/_providers/vertex.py
CHANGED
@@ -33,6 +33,7 @@ from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
 from inspect_ai._util.content import (
     Content,
     ContentAudio,
+    ContentData,
     ContentImage,
     ContentReasoning,
     ContentText,
@@ -338,6 +339,8 @@ async def content_part(content: Content | str) -> Part:
     else:
         if isinstance(content, ContentAudio):
             file = content.audio
+        elif isinstance(content, ContentData):
+            assert False, "Vertex provider should never encounter ContentData"
         else:
             # it's ContentVideo
             file = content.video
inspect_ai/model/_trim.py
CHANGED
@@ -13,6 +13,7 @@ def trim_messages(
     - Retaining the 'input' messages from the sample.
     - Preserving a proportion of the remaining messages (`preserve=0.7` by default).
     - Ensuring that all assistant tool calls have corresponding tool messages.
+    - Ensuring that the sequence of messages doesn't end with an assistant message.
 
     Args:
       messages: List of messages to trim.
@@ -49,6 +50,10 @@ def trim_messages(
             active_tool_ids = set()
         conversation_messages.append(message)
 
+    # it's possible that we end with an assistant message; if so, remove it
+    if len(conversation_messages) and conversation_messages[-1].role == "assistant":
+        conversation_messages.pop()
+
     # return trimmed messages
     return partitioned.system + partitioned.input + conversation_messages
 
inspect_ai/tool/__init__.py
CHANGED
@@ -1,6 +1,14 @@
+from inspect_ai._util.citation import (
+    Citation,
+    CitationBase,
+    ContentCitation,
+    DocumentCitation,
+    UrlCitation,
+)
 from inspect_ai._util.content import (
     Content,
     ContentAudio,
+    ContentData,
     ContentImage,
     ContentReasoning,
     ContentText,
@@ -62,6 +70,7 @@ __all__ = [
     "MCPServer",
     "Content",
     "ContentAudio",
+    "ContentData",
     "ContentImage",
     "ContentReasoning",
     "ContentText",
@@ -77,6 +86,11 @@ __all__ = [
     "ToolInfo",
     "ToolParam",
     "ToolParams",
+    "Citation",
+    "CitationBase",
+    "DocumentCitation",
+    "ContentCitation",
+    "UrlCitation",
 ]
 
 _UTIL_MODULE_VERSION = "0.3.19"
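With citations now exported from inspect_ai.tool, tool output can carry source attributions on text content. A minimal sketch using the fields visible in this diff (the text and URL are illustrative):

    from inspect_ai.tool import ContentText, UrlCitation

    content = ContentText(
        text="Inspect is an open-source framework for LLM evals.",
        citations=[
            UrlCitation(
                title="Inspect",
                url="https://inspect.aisi.org.uk/",
                cited_text="open-source framework",
            )
        ],
    )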
inspect_ai/tool/_mcp/_mcp.py
CHANGED
@@ -12,6 +12,7 @@ from mcp.client.session import ClientSession, SamplingFnT
 from mcp.client.sse import sse_client
 from mcp.client.stdio import StdioServerParameters, stdio_client
 from mcp.types import (
+    AudioContent,
     EmbeddedResource,
     ImageContent,
     TextContent,
@@ -282,14 +283,16 @@ def create_server_sandbox(
 
 
 def tool_result_as_text(
-    content: list[TextContent | ImageContent | EmbeddedResource],
+    content: list[TextContent | ImageContent | AudioContent | EmbeddedResource],
 ) -> str:
     content_list: list[str] = []
     for c in content:
         if isinstance(c, TextContent):
             content_list.append(c.text)
         elif isinstance(c, ImageContent):
-            content_list.append("(base64 encoded image
+            content_list.append("(base64 encoded image omitted)")
+        elif isinstance(c, AudioContent):
+            content_list.append("(base64 encoded audio omitted)")
         elif isinstance(c.resource, TextResourceContents):
             content_list.append(c.resource.text)
 
inspect_ai/tool/_mcp/sampling.py
CHANGED
@@ -1,9 +1,10 @@
-from typing import Any
+from typing import Any, Literal
 
 from mcp.client.session import ClientSession
 from mcp.shared.context import RequestContext
 from mcp.types import (
     INTERNAL_ERROR,
+    AudioContent,
     CreateMessageRequestParams,
     CreateMessageResult,
     EmbeddedResource,
@@ -16,7 +17,7 @@ from mcp.types import (
     StopReason as MCPStopReason,
 )
 
-from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.content import Content, ContentAudio, ContentImage, ContentText
 from inspect_ai._util.error import exception_message
 from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64
 
@@ -93,7 +94,7 @@ async def sampling_fn(
 
 
 def as_inspect_content(
-    content: TextContent | ImageContent | EmbeddedResource,
+    content: TextContent | ImageContent | AudioContent | EmbeddedResource,
 ) -> Content:
     if isinstance(content, TextContent):
         return ContentText(text=content.text)
@@ -101,6 +102,11 @@ def as_inspect_content(
         return ContentImage(
             image=f"data:image/{content.mimeType};base64,{content.data}"
         )
+    elif isinstance(content, AudioContent):
+        return ContentAudio(
+            audio=f"data:audio/{content.mimeType};base64,{content.data}",
+            format=_get_audio_format(content.mimeType),
+        )
     elif isinstance(content.resource, TextResourceContents):
         return ContentText(text=content.resource.text)
     else:
@@ -116,3 +122,13 @@ def as_mcp_content(content: ContentText | ContentImage) -> TextContent | ImageContent:
         mimeType=data_uri_mime_type(content.image) or "image/png",
         data=data_uri_to_base64(content.image),
     )
+
+
+def _get_audio_format(mime_type: str) -> Literal["wav", "mp3"]:
+    """Helper function to determine audio format from MIME type."""
+    if mime_type in ("audio/wav", "audio/x-wav"):
+        return "wav"
+    elif mime_type == "audio/mpeg":
+        return "mp3"
+    else:
+        raise ValueError(f"Unsupported audio mime type: {mime_type}")
inspect_ai/tool/_mcp/server.py
CHANGED
inspect_ai/tool/_tool.py
CHANGED
@@ -13,6 +13,7 @@ from typing import (
 
 from inspect_ai._util.content import (
     ContentAudio,
+    ContentData,
     ContentImage,
     ContentReasoning,
     ContentText,
@@ -41,7 +42,15 @@ ToolResult = (
     | ContentImage
     | ContentAudio
    | ContentVideo
-    | list[ContentText | ContentReasoning | ContentImage | ContentAudio | ContentVideo]
+    | ContentData
+    | list[
+        ContentText
+        | ContentReasoning
+        | ContentImage
+        | ContentAudio
+        | ContentVideo
+        | ContentData
+    ]
 )
 """Valid types for results from tool calls."""
 
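ToolResult now admits ContentData, both standalone and inside lists. A hedged sketch of a tool returning structured data, assuming ContentData wraps a `data` dict (the weather tool itself is hypothetical):

    from inspect_ai.tool import ContentData, ToolResult, tool

    @tool
    def weather():
        async def execute(city: str) -> ToolResult:
            """Look up the weather for a city.

            Args:
                city: City to look up.
            """
            # assumption: ContentData carries a JSON-serializable `data` dict
            return ContentData(data={"city": city, "temp_c": 21})

        return execute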
inspect_ai/tool/_tools/_web_search/_base_http_provider.py
ADDED
@@ -0,0 +1,104 @@
+import os
+from abc import ABC, abstractmethod
+from typing import Any
+
+import httpx
+from tenacity import (
+    retry,
+    retry_if_exception,
+    stop_after_attempt,
+    stop_after_delay,
+    wait_exponential_jitter,
+)
+
+from inspect_ai._util.content import ContentText
+from inspect_ai._util.error import PrerequisiteError
+from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
+from inspect_ai.util._concurrency import concurrency
+
+
+class BaseHttpProvider(ABC):
+    """Base class for HTTP-based web search providers (Exa, Tavily, etc.)."""
+
+    def __init__(
+        self,
+        env_key_name: str,
+        api_endpoint: str,
+        provider_name: str,
+        concurrency_key: str,
+        options: dict[str, Any] | None = None,
+    ):
+        self.env_key_name = env_key_name
+        self.api_endpoint = api_endpoint
+        self.provider_name = provider_name
+        self.concurrency_key = concurrency_key
+
+        self.max_connections = self._extract_max_connections(options)
+        self.api_options = self._prepare_api_options(options)
+        self.api_key = self._validate_api_key()
+        self.client = httpx.AsyncClient(timeout=30)
+
+    @abstractmethod
+    def prepare_headers(self, api_key: str) -> dict[str, str]:
+        """Prepare HTTP headers for the request."""
+        pass
+
+    @abstractmethod
+    def parse_response(self, response_data: dict[str, Any]) -> ContentText | None:
+        """Parse the API response and extract content with citations."""
+        pass
+
+    @abstractmethod
+    def set_default_options(self, options: dict[str, Any]) -> dict[str, Any]:
+        """Set provider-specific default options."""
+        pass
+
+    def _extract_max_connections(self, options: dict[str, Any] | None) -> int:
+        """Extract max_connections from options, defaulting to 10."""
+        if not options:
+            return 10
+        max_conn = options.get("max_connections", 10)
+        return int(max_conn) if max_conn is not None else 10
+
+    def _prepare_api_options(self, options: dict[str, Any] | None) -> dict[str, Any]:
+        """Prepare API options by removing max_connections and setting defaults."""
+        if not options:
+            api_options = {}
+        else:
+            # Remove max_connections as it's not an API option
+            api_options = {k: v for k, v in options.items() if k != "max_connections"}
+
+        # Apply provider-specific defaults
+        return self.set_default_options(api_options)
+
+    def _validate_api_key(self) -> str:
+        """Validate that the required API key is set in environment."""
+        api_key = os.environ.get(self.env_key_name)
+        if not api_key:
+            raise PrerequisiteError(
+                f"{self.env_key_name} not set in the environment. Please ensure this variable is defined to use {self.provider_name} with the web_search tool.\n\nLearn more about the {self.provider_name} web search provider at https://inspect.aisi.org.uk/tools.html#{self.provider_name.lower()}-provider"
+            )
+        return api_key
+
+    async def search(self, query: str) -> ContentText | None:
+        """Execute a search query using the provider's API."""
+
+        # Common retry logic for all HTTP providers
+        @retry(
+            wait=wait_exponential_jitter(),
+            stop=stop_after_attempt(5) | stop_after_delay(60),
+            retry=retry_if_exception(httpx_should_retry),
+            before_sleep=log_httpx_retry_attempt(self.api_endpoint),
+        )
+        async def _search() -> httpx.Response:
+            response = await self.client.post(
+                self.api_endpoint,
+                headers=self.prepare_headers(self.api_key),
+                json={"query": query, **self.api_options},
+            )
+            response.raise_for_status()
+            return response
+
+        async with concurrency(self.concurrency_key, self.max_connections):
+            response_data = (await _search()).json()
+            return self.parse_response(response_data)
inspect_ai/tool/_tools/_web_search/_exa.py
ADDED
@@ -0,0 +1,78 @@
+from typing import Any, Literal
+
+from pydantic import BaseModel
+
+from inspect_ai._util.citation import UrlCitation
+from inspect_ai._util.content import ContentText
+
+from ._base_http_provider import BaseHttpProvider
+from ._web_search_provider import SearchProvider
+
+
+class ExaOptions(BaseModel):
+    # See https://docs.exa.ai/reference/answer
+    text: bool | None = None
+    """Whether to include text content in citations"""
+    model: Literal["exa", "exa-pro"] | None = None
+    """LLM model to use for generating the answer"""
+    max_connections: int | None = None
+    """max_connections is not an Exa API option, but an inspect option"""
+
+
+class ExaCitation(BaseModel):
+    id: str
+    url: str
+    title: str
+    author: str | None = None
+    publishedDate: str | None = None
+    text: str
+
+
+class ExaSearchResponse(BaseModel):
+    answer: str
+    citations: list[ExaCitation]
+
+
+class ExaSearchProvider(BaseHttpProvider):
+    def __init__(self, options: dict[str, Any] | None = None):
+        super().__init__(
+            env_key_name="EXA_API_KEY",
+            api_endpoint="https://api.exa.ai/answer",
+            provider_name="Exa",
+            concurrency_key="exa_web_search",
+            options=options,
+        )
+
+    def prepare_headers(self, api_key: str) -> dict[str, str]:
+        return {
+            "x-api-key": api_key,
+            "Content-Type": "application/json",
+        }
+
+    def set_default_options(self, options: dict[str, Any]) -> dict[str, Any]:
+        return options
+
+    def parse_response(self, response_data: dict[str, Any]) -> ContentText | None:
+        exa_search_response = ExaSearchResponse.model_validate(response_data)
+
+        if not exa_search_response.answer and not exa_search_response.citations:
+            return None
+
+        return ContentText(
+            text=exa_search_response.answer,
+            citations=[
+                UrlCitation(
+                    cited_text=citation.text, title=citation.title, url=citation.url
+                )
+                for citation in exa_search_response.citations
+            ],
+        )
+
+
+def exa_search_provider(
+    in_options: dict[str, object] | None = None,
+) -> SearchProvider:
+    options = ExaOptions.model_validate(in_options) if in_options else None
+    return ExaSearchProvider(
+        options.model_dump(exclude_none=True) if options else None
+    ).search
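Once registered, the Exa provider can be selected via the web_search tool. A usage sketch, assuming the updated web_search accepts a provider/options mapping (the options shown map to ExaOptions above; EXA_API_KEY must be set):

    from inspect_ai.tool import web_search

    # select the Exa provider with illustrative options
    search = web_search({"exa": {"model": "exa", "text": True}})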