inspect-ai 0.3.103__py3-none-any.whl → 0.3.104__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. inspect_ai/_cli/common.py +2 -1
  2. inspect_ai/_cli/eval.py +2 -2
  3. inspect_ai/_display/core/active.py +3 -0
  4. inspect_ai/_display/core/config.py +1 -0
  5. inspect_ai/_display/core/panel.py +21 -13
  6. inspect_ai/_display/core/results.py +3 -7
  7. inspect_ai/_display/core/rich.py +3 -5
  8. inspect_ai/_display/log/__init__.py +0 -0
  9. inspect_ai/_display/log/display.py +173 -0
  10. inspect_ai/_display/plain/display.py +2 -2
  11. inspect_ai/_display/rich/display.py +2 -4
  12. inspect_ai/_display/textual/app.py +1 -6
  13. inspect_ai/_display/textual/widgets/task_detail.py +3 -14
  14. inspect_ai/_display/textual/widgets/tasks.py +1 -1
  15. inspect_ai/_eval/eval.py +1 -1
  16. inspect_ai/_eval/evalset.py +2 -2
  17. inspect_ai/_eval/registry.py +6 -1
  18. inspect_ai/_eval/run.py +5 -1
  19. inspect_ai/_eval/task/constants.py +1 -0
  20. inspect_ai/_eval/task/log.py +2 -0
  21. inspect_ai/_eval/task/run.py +1 -1
  22. inspect_ai/_util/citation.py +88 -0
  23. inspect_ai/_util/content.py +24 -2
  24. inspect_ai/_util/json.py +17 -2
  25. inspect_ai/_util/registry.py +19 -4
  26. inspect_ai/_view/schema.py +0 -6
  27. inspect_ai/_view/www/dist/assets/index.css +82 -24
  28. inspect_ai/_view/www/dist/assets/index.js +10124 -9808
  29. inspect_ai/_view/www/log-schema.json +418 -1
  30. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +149 -0
  31. inspect_ai/_view/www/node_modules/katex/src/fonts/generate_fonts.py +58 -0
  32. inspect_ai/_view/www/node_modules/katex/src/metrics/extract_tfms.py +114 -0
  33. inspect_ai/_view/www/node_modules/katex/src/metrics/extract_ttfs.py +122 -0
  34. inspect_ai/_view/www/node_modules/katex/src/metrics/format_json.py +28 -0
  35. inspect_ai/_view/www/node_modules/katex/src/metrics/parse_tfm.py +211 -0
  36. inspect_ai/_view/www/package.json +2 -2
  37. inspect_ai/_view/www/src/@types/log.d.ts +140 -39
  38. inspect_ai/_view/www/src/app/content/RecordTree.tsx +13 -0
  39. inspect_ai/_view/www/src/app/log-view/LogView.tsx +1 -1
  40. inspect_ai/_view/www/src/app/routing/logNavigation.ts +31 -0
  41. inspect_ai/_view/www/src/app/routing/{navigationHooks.ts → sampleNavigation.ts} +39 -86
  42. inspect_ai/_view/www/src/app/samples/SampleDialog.tsx +1 -1
  43. inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +1 -1
  44. inspect_ai/_view/www/src/app/samples/chat/MessageCitations.module.css +16 -0
  45. inspect_ai/_view/www/src/app/samples/chat/MessageCitations.tsx +63 -0
  46. inspect_ai/_view/www/src/app/samples/chat/MessageContent.module.css +6 -0
  47. inspect_ai/_view/www/src/app/samples/chat/MessageContent.tsx +174 -25
  48. inspect_ai/_view/www/src/app/samples/chat/MessageContents.tsx +21 -3
  49. inspect_ai/_view/www/src/app/samples/chat/content-data/ContentDataView.module.css +7 -0
  50. inspect_ai/_view/www/src/app/samples/chat/content-data/ContentDataView.tsx +111 -0
  51. inspect_ai/_view/www/src/app/samples/chat/content-data/WebSearch.module.css +10 -0
  52. inspect_ai/_view/www/src/app/samples/chat/content-data/WebSearch.tsx +14 -0
  53. inspect_ai/_view/www/src/app/samples/chat/content-data/WebSearchResults.module.css +19 -0
  54. inspect_ai/_view/www/src/app/samples/chat/content-data/WebSearchResults.tsx +49 -0
  55. inspect_ai/_view/www/src/app/samples/chat/messages.ts +7 -1
  56. inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +12 -2
  57. inspect_ai/_view/www/src/app/samples/chat/types.ts +4 -0
  58. inspect_ai/_view/www/src/app/samples/list/SampleList.tsx +1 -1
  59. inspect_ai/_view/www/src/app/samples/sampleLimit.ts +2 -2
  60. inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +1 -1
  61. inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +4 -4
  62. inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.tsx +1 -1
  63. inspect_ai/_view/www/src/components/MarkdownDiv.tsx +15 -2
  64. inspect_ai/_view/www/src/tests/README.md +2 -2
  65. inspect_ai/_view/www/src/utils/git.ts +3 -1
  66. inspect_ai/_view/www/src/utils/html.ts +6 -0
  67. inspect_ai/agent/_handoff.py +3 -3
  68. inspect_ai/log/_condense.py +5 -0
  69. inspect_ai/log/_file.py +4 -1
  70. inspect_ai/log/_log.py +9 -4
  71. inspect_ai/log/_recorders/json.py +4 -2
  72. inspect_ai/log/_util.py +2 -0
  73. inspect_ai/model/__init__.py +14 -0
  74. inspect_ai/model/_call_tools.py +13 -4
  75. inspect_ai/model/_chat_message.py +3 -0
  76. inspect_ai/model/_openai_responses.py +80 -34
  77. inspect_ai/model/_providers/_anthropic_citations.py +158 -0
  78. inspect_ai/model/_providers/_google_citations.py +100 -0
  79. inspect_ai/model/_providers/anthropic.py +196 -34
  80. inspect_ai/model/_providers/google.py +94 -22
  81. inspect_ai/model/_providers/mistral.py +20 -7
  82. inspect_ai/model/_providers/openai.py +11 -10
  83. inspect_ai/model/_providers/openai_compatible.py +3 -2
  84. inspect_ai/model/_providers/openai_responses.py +2 -5
  85. inspect_ai/model/_providers/perplexity.py +123 -0
  86. inspect_ai/model/_providers/providers.py +13 -2
  87. inspect_ai/model/_providers/vertex.py +3 -0
  88. inspect_ai/model/_trim.py +5 -0
  89. inspect_ai/tool/__init__.py +14 -0
  90. inspect_ai/tool/_mcp/_mcp.py +5 -2
  91. inspect_ai/tool/_mcp/sampling.py +19 -3
  92. inspect_ai/tool/_mcp/server.py +1 -1
  93. inspect_ai/tool/_tool.py +10 -1
  94. inspect_ai/tool/_tools/_web_search/_base_http_provider.py +104 -0
  95. inspect_ai/tool/_tools/_web_search/_exa.py +78 -0
  96. inspect_ai/tool/_tools/_web_search/_google.py +22 -25
  97. inspect_ai/tool/_tools/_web_search/_tavily.py +47 -65
  98. inspect_ai/tool/_tools/_web_search/_web_search.py +83 -36
  99. inspect_ai/tool/_tools/_web_search/_web_search_provider.py +7 -0
  100. inspect_ai/util/_display.py +11 -2
  101. inspect_ai/util/_sandbox/docker/compose.py +2 -2
  102. inspect_ai/util/_span.py +12 -1
  103. {inspect_ai-0.3.103.dist-info → inspect_ai-0.3.104.dist-info}/METADATA +2 -2
  104. {inspect_ai-0.3.103.dist-info → inspect_ai-0.3.104.dist-info}/RECORD +110 -86
  105. /inspect_ai/model/{_openai_computer_use.py → _providers/_openai_computer_use.py} +0 -0
  106. /inspect_ai/model/{_openai_web_search.py → _providers/_openai_web_search.py} +0 -0
  107. {inspect_ai-0.3.103.dist-info → inspect_ai-0.3.104.dist-info}/WHEEL +0 -0
  108. {inspect_ai-0.3.103.dist-info → inspect_ai-0.3.104.dist-info}/entry_points.txt +0 -0
  109. {inspect_ai-0.3.103.dist-info → inspect_ai-0.3.104.dist-info}/licenses/LICENSE +0 -0
  110. {inspect_ai-0.3.103.dist-info → inspect_ai-0.3.104.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/openai_compatible.py CHANGED
@@ -61,7 +61,8 @@ class OpenAICompatibleAPI(ModelAPI):
         self.service = service

         # compute api key
-        api_key_var = f"{self.service.upper()}_API_KEY"
+        service_env_name = self.service.upper().replace("-", "_")
+        api_key_var = f"{service_env_name}_API_KEY"

         super().__init__(
             model_name=model_name,
@@ -82,7 +83,7 @@ class OpenAICompatibleAPI(ModelAPI):

         # use service prefix to lookup base_url
         if not self.base_url:
-            base_url_var = f"{self.service.upper()}_BASE_URL"
+            base_url_var = f"{service_env_name}_BASE_URL"
             self.base_url = model_base_url(base_url, [base_url_var]) or service_base_url
             if not self.base_url:
                 raise environment_prerequisite_error(
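
Editorial sketch of the change above (the "together-ai" service name is invented for illustration): hyphenated service names now map to well-formed environment variable names.

    service = "together-ai"
    service_env_name = service.upper().replace("-", "_")
    assert f"{service_env_name}_API_KEY" == "TOGETHER_AI_API_KEY"
    assert f"{service_env_name}_BASE_URL" == "TOGETHER_AI_BASE_URL"
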
inspect_ai/model/_openai_responses.py CHANGED
@@ -40,7 +40,6 @@ async def generate_responses(
     tool_choice: ToolChoice,
     config: GenerateConfig,
     service_tier: str | None,
-    store: bool,
 ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
     # allocate request_id (so we can see it from ModelCall)
     request_id = http_hooks.start_request()
@@ -65,7 +64,7 @@ async def generate_responses(
         else NOT_GIVEN
     )
     request = dict(
-        input=await openai_responses_inputs(input, model_name, store),
+        input=await openai_responses_inputs(input, model_name),
         tools=tool_params,
         tool_choice=openai_responses_tool_choice(tool_choice, tool_params)
         if isinstance(tool_params, list) and tool_choice != "auto"
@@ -77,7 +76,6 @@ async def generate_responses(
             config=config,
             service_tier=service_tier,
             tools=len(tools) > 0,
-            store=store,
         ),
     )

@@ -125,7 +123,6 @@ def completion_params_responses(
     config: GenerateConfig,
     service_tier: str | None,
     tools: bool,
-    store: bool,
 ) -> dict[str, Any]:
     # TODO: we'll need a computer_use_preview bool for the 'include'
     # and 'reasoning' parameters
@@ -135,7 +132,7 @@ def completion_params_responses(
             f"OpenAI Responses API does not support the '{param}' parameter.",
         )

-    params: dict[str, Any] = dict(model=model_name, store=store)
+    params: dict[str, Any] = dict(model=model_name)
     if service_tier is not None:
         params["service_tier"] = service_tier
     if config.max_tokens is not None:
inspect_ai/model/_providers/perplexity.py ADDED
@@ -0,0 +1,123 @@
+from typing import Any, cast
+
+from openai.types.chat import ChatCompletion
+
+from inspect_ai._util.citation import UrlCitation
+from inspect_ai._util.content import ContentText
+from inspect_ai.model._generate_config import GenerateConfig
+from inspect_ai.model._model_output import ModelOutput, ModelUsage
+from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
+from inspect_ai.tool import ToolChoice, ToolInfo
+
+from .._chat_message import ChatMessage
+from .._model_call import ModelCall
+from .._model_output import ChatCompletionChoice
+
+
+class PerplexityAPI(OpenAICompatibleAPI):
+    """Model provider for Perplexity AI."""
+
+    def __init__(
+        self,
+        model_name: str,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        config: GenerateConfig = GenerateConfig(),
+        **model_args: Any,
+    ) -> None:
+        super().__init__(
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key,
+            config=config,
+            service="Perplexity",
+            service_base_url="https://api.perplexity.ai",
+            **model_args,
+        )
+
+        self._response: dict[str, Any] | None = None
+
+    def on_response(self, response: dict[str, Any]) -> None:
+        """Capture the raw response for post-processing."""
+        self._response = response
+
+    async def generate(
+        self,
+        input: list["ChatMessage"],
+        tools: list["ToolInfo"],
+        tool_choice: "ToolChoice",
+        config: GenerateConfig,
+    ) -> tuple[ModelOutput | Exception, "ModelCall"]:
+        result = await super().generate(input, tools, tool_choice, config)
+        output, call = cast(tuple[ModelOutput, "ModelCall"], result)
+
+        if self._response:
+            response = self._response
+
+            # attach citations if search results are returned
+            search_results = response.get("search_results")
+            if isinstance(search_results, list):
+                citations = [
+                    UrlCitation(title=sr.get("title"), url=sr.get("url", ""))
+                    for sr in search_results
+                    if isinstance(sr, dict) and sr.get("url") is not None
+                ]
+                if citations:
+                    for choice in output.choices:
+                        msg = choice.message
+                        if isinstance(msg.content, str):
+                            msg.content = [
+                                ContentText(text=msg.content, citations=citations)
+                            ]
+                        else:
+                            added = False
+                            for content in msg.content:
+                                if (
+                                    isinstance(content, ContentText)
+                                    and getattr(content, "citations", None) is None
+                                ):
+                                    content.citations = citations
+                                    added = True
+                                    break
+                            if not added:
+                                msg.content.append(
+                                    ContentText(text="", citations=citations)
+                                )
+
+            # update usage with additional metrics
+            usage_data = response.get("usage")
+            if isinstance(usage_data, dict):
+                extra_usage = {
+                    k: usage_data.get(k)
+                    for k in [
+                        "search_context_size",
+                        "citation_tokens",
+                        "num_search_queries",
+                    ]
+                    if k in usage_data
+                }
+                if output.usage:
+                    output.usage.reasoning_tokens = usage_data.get("reasoning_tokens")
+                else:
+                    output.usage = ModelUsage(
+                        input_tokens=usage_data.get("prompt_tokens", 0),
+                        output_tokens=usage_data.get("completion_tokens", 0),
+                        total_tokens=usage_data.get("total_tokens", 0),
+                        reasoning_tokens=usage_data.get("reasoning_tokens"),
+                    )
+                if extra_usage:
+                    output.metadata = output.metadata or {}
+                    output.metadata.update(extra_usage)
+
+            # keep search_results for reference
+            if search_results:
+                output.metadata = output.metadata or {}
+                output.metadata["search_results"] = search_results
+
+        return output, call
+
+    def chat_choices_from_completion(
+        self, completion: ChatCompletion, tools: list[ToolInfo]
+    ) -> list[ChatCompletionChoice]:
+        return chat_choices_from_openai(completion, tools)
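
A hedged usage sketch for the new provider: "sonar" is an assumed Perplexity model name (not taken from this diff), and the PERPLEXITY_API_KEY variable name follows from service="Perplexity" combined with the env-var derivation in openai_compatible.py above.

    import asyncio

    from inspect_ai.model import get_model

    async def main() -> None:
        # requires PERPLEXITY_API_KEY in the environment
        model = get_model("perplexity/sonar")
        output = await model.generate("Who won the 2022 World Cup?")
        print(output.completion)
        # citations are attached to the assistant message content;
        # search_results and extra usage metrics land in output.metadata
        print(output.metadata)

    asyncio.run(main())
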
inspect_ai/model/_providers/providers.py CHANGED
@@ -59,7 +59,7 @@ def openai_api() -> type[ModelAPI]:
 def anthropic() -> type[ModelAPI]:
     FEATURE = "Anthropic API"
     PACKAGE = "anthropic"
-    MIN_VERSION = "0.49.0"
+    MIN_VERSION = "0.52.0"

     # verify we have the package
     try:
@@ -157,7 +157,7 @@ def cf() -> type[ModelAPI]:
 def mistral() -> type[ModelAPI]:
     FEATURE = "Mistral API"
     PACKAGE = "mistralai"
-    MIN_VERSION = "1.6.0"
+    MIN_VERSION = "1.8.2"

     # verify we have the package
     try:
@@ -218,6 +218,17 @@ def openrouter() -> type[ModelAPI]:
     return OpenRouterAPI


+@modelapi(name="perplexity")
+def perplexity() -> type[ModelAPI]:
+    # validate
+    validate_openai_client("Perplexity API")
+
+    # in the clear
+    from .perplexity import PerplexityAPI
+
+    return PerplexityAPI
+
+
 @modelapi(name="llama-cpp-python")
 def llama_cpp_python() -> type[ModelAPI]:
     # validate
inspect_ai/model/_providers/vertex.py CHANGED
@@ -33,6 +33,7 @@ from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
 from inspect_ai._util.content import (
     Content,
     ContentAudio,
+    ContentData,
     ContentImage,
     ContentReasoning,
     ContentText,
@@ -338,6 +339,8 @@ async def content_part(content: Content | str) -> Part:
     else:
         if isinstance(content, ContentAudio):
             file = content.audio
+        elif isinstance(content, ContentData):
+            assert False, "Vertex provider should never encounter ContentData"
         else:
             # it's ContentVideo
             file = content.video
inspect_ai/model/_trim.py CHANGED
@@ -13,6 +13,7 @@ def trim_messages(
     - Retaining the 'input' messages from the sample.
     - Preserving a proportion of the remaining messages (`preserve=0.7` by default).
     - Ensuring that all assistant tool calls have corresponding tool messages.
+    - Ensuring that the sequence of messages doesn't end with an assistant message.

     Args:
         messages: List of messages to trim.
@@ -49,6 +50,10 @@ def trim_messages(
             active_tool_ids = set()
         conversation_messages.append(message)

+    # it's possible that we end with an assistant message -- if so, remove it
+    if len(conversation_messages) and conversation_messages[-1].role == "assistant":
+        conversation_messages.pop()
+
     # return trimmed messages
     return partitioned.system + partitioned.input + conversation_messages

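A minimal sketch of the new trailing-assistant rule, assuming trim_messages and the message classes are importable from inspect_ai.model (the conversation itself is invented):

    from inspect_ai.model import (
        ChatMessageAssistant,
        ChatMessageUser,
        trim_messages,
    )

    messages = [
        ChatMessageUser(content="What is 2 + 2?"),
        ChatMessageAssistant(content="4"),
        ChatMessageUser(content="And 3 + 3?"),
        ChatMessageAssistant(content="6"),
    ]
    trimmed = trim_messages(messages, preserve=0.7)
    # per the new behavior, the trimmed sequence never ends with
    # an assistant message
    assert not trimmed or trimmed[-1].role != "assistant"
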
inspect_ai/tool/__init__.py CHANGED
@@ -1,6 +1,14 @@
+from inspect_ai._util.citation import (
+    Citation,
+    CitationBase,
+    ContentCitation,
+    DocumentCitation,
+    UrlCitation,
+)
 from inspect_ai._util.content import (
     Content,
     ContentAudio,
+    ContentData,
     ContentImage,
     ContentReasoning,
     ContentText,
@@ -62,6 +70,7 @@ __all__ = [
     "MCPServer",
     "Content",
     "ContentAudio",
+    "ContentData",
     "ContentImage",
     "ContentReasoning",
     "ContentText",
@@ -77,6 +86,11 @@ __all__ = [
     "ToolInfo",
     "ToolParam",
     "ToolParams",
+    "Citation",
+    "CitationBase",
+    "DocumentCitation",
+    "ContentCitation",
+    "UrlCitation",
 ]

 _UTIL_MODULE_VERSION = "0.3.19"
inspect_ai/tool/_mcp/_mcp.py CHANGED
@@ -12,6 +12,7 @@ from mcp.client.session import ClientSession, SamplingFnT
 from mcp.client.sse import sse_client
 from mcp.client.stdio import StdioServerParameters, stdio_client
 from mcp.types import (
+    AudioContent,
     EmbeddedResource,
     ImageContent,
     TextContent,
@@ -282,14 +283,16 @@ def create_server_sandbox(


 def tool_result_as_text(
-    content: list[TextContent | ImageContent | EmbeddedResource],
+    content: list[TextContent | ImageContent | AudioContent | EmbeddedResource],
 ) -> str:
     content_list: list[str] = []
     for c in content:
         if isinstance(c, TextContent):
             content_list.append(c.text)
         elif isinstance(c, ImageContent):
-            content_list.append("(base64 encoded image ommitted)")
+            content_list.append("(base64 encoded image omitted)")
+        elif isinstance(c, AudioContent):
+            content_list.append("(base64 encoded audio omitted)")
         elif isinstance(c.resource, TextResourceContents):
             content_list.append(c.resource.text)

inspect_ai/tool/_mcp/sampling.py CHANGED
@@ -1,9 +1,10 @@
-from typing import Any
+from typing import Any, Literal

 from mcp.client.session import ClientSession
 from mcp.shared.context import RequestContext
 from mcp.types import (
     INTERNAL_ERROR,
+    AudioContent,
     CreateMessageRequestParams,
     CreateMessageResult,
     EmbeddedResource,
@@ -16,7 +17,7 @@ from mcp.types import (
     StopReason as MCPStopReason,
 )

-from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.content import Content, ContentAudio, ContentImage, ContentText
 from inspect_ai._util.error import exception_message
 from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64

@@ -93,7 +94,7 @@ async def sampling_fn(


 def as_inspect_content(
-    content: TextContent | ImageContent | EmbeddedResource,
+    content: TextContent | ImageContent | AudioContent | EmbeddedResource,
 ) -> Content:
     if isinstance(content, TextContent):
         return ContentText(text=content.text)
@@ -101,6 +102,11 @@ def as_inspect_content(
         return ContentImage(
             image=f"data:image/{content.mimeType};base64,{content.data}"
         )
+    elif isinstance(content, AudioContent):
+        return ContentAudio(
+            audio=f"data:audio/{content.mimeType};base64,{content.data}",
+            format=_get_audio_format(content.mimeType),
+        )
     elif isinstance(content.resource, TextResourceContents):
         return ContentText(text=content.resource.text)
     else:
@@ -116,3 +122,13 @@ def as_mcp_content(content: ContentText | ContentImage) -> TextContent | ImageCo
         mimeType=data_uri_mime_type(content.image) or "image/png",
         data=data_uri_to_base64(content.image),
     )
+
+
+def _get_audio_format(mime_type: str) -> Literal["wav", "mp3"]:
+    """Helper function to determine audio format from MIME type."""
+    if mime_type in ("audio/wav", "audio/x-wav"):
+        return "wav"
+    elif mime_type == "audio/mpeg":
+        return "mp3"
+    else:
+        raise ValueError(f"Unsupported audio mime type: {mime_type}")
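
Grounded directly in the helper above, the MIME-to-format mapping behaves as follows (only the wav and mp3 MIME types are accepted; anything else raises ValueError):

    assert _get_audio_format("audio/x-wav") == "wav"
    assert _get_audio_format("audio/mpeg") == "mp3"
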
inspect_ai/tool/_mcp/server.py CHANGED
@@ -102,7 +102,7 @@ def mcp_server_sandbox(
 def verfify_mcp_package() -> None:
     FEATURE = "MCP tools"
     PACKAGE = "mcp"
-    MIN_VERSION = "1.8.0"
+    MIN_VERSION = "1.9.4"

     # verify we have the package
     try:
inspect_ai/tool/_tool.py CHANGED
@@ -13,6 +13,7 @@ from typing import (

 from inspect_ai._util.content import (
     ContentAudio,
+    ContentData,
     ContentImage,
     ContentReasoning,
     ContentText,
@@ -41,7 +42,15 @@ ToolResult = (
     | ContentImage
     | ContentAudio
     | ContentVideo
-    | list[ContentText | ContentReasoning | ContentImage | ContentAudio | ContentVideo]
+    | ContentData
+    | list[
+        ContentText
+        | ContentReasoning
+        | ContentImage
+        | ContentAudio
+        | ContentVideo
+        | ContentData
+    ]
 )
 """Valid types for results from tool calls."""

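A hedged sketch of a tool returning the newly allowed ContentData result type. The tool itself is invented, and the `data` field name is an assumption about ContentData's shape (it is not shown in this diff):

    from inspect_ai.tool import ContentData, Tool, ToolResult, tool

    @tool
    def weather() -> Tool:
        async def execute(city: str) -> ToolResult:
            """Report (canned) weather for a city.

            Args:
                city: City to report on.
            """
            # ContentData lets tools return structured data; `data` is
            # an assumed field name, not confirmed by this diff
            return ContentData(data={"city": city, "temp_c": 21, "sky": "clear"})

        return execute
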
inspect_ai/tool/_tools/_web_search/_base_http_provider.py ADDED
@@ -0,0 +1,104 @@
+import os
+from abc import ABC, abstractmethod
+from typing import Any
+
+import httpx
+from tenacity import (
+    retry,
+    retry_if_exception,
+    stop_after_attempt,
+    stop_after_delay,
+    wait_exponential_jitter,
+)
+
+from inspect_ai._util.content import ContentText
+from inspect_ai._util.error import PrerequisiteError
+from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
+from inspect_ai.util._concurrency import concurrency
+
+
+class BaseHttpProvider(ABC):
+    """Base class for HTTP-based web search providers (Exa, Tavily, etc.)."""
+
+    def __init__(
+        self,
+        env_key_name: str,
+        api_endpoint: str,
+        provider_name: str,
+        concurrency_key: str,
+        options: dict[str, Any] | None = None,
+    ):
+        self.env_key_name = env_key_name
+        self.api_endpoint = api_endpoint
+        self.provider_name = provider_name
+        self.concurrency_key = concurrency_key
+
+        self.max_connections = self._extract_max_connections(options)
+        self.api_options = self._prepare_api_options(options)
+        self.api_key = self._validate_api_key()
+        self.client = httpx.AsyncClient(timeout=30)
+
+    @abstractmethod
+    def prepare_headers(self, api_key: str) -> dict[str, str]:
+        """Prepare HTTP headers for the request."""
+        pass
+
+    @abstractmethod
+    def parse_response(self, response_data: dict[str, Any]) -> ContentText | None:
+        """Parse the API response and extract content with citations."""
+        pass
+
+    @abstractmethod
+    def set_default_options(self, options: dict[str, Any]) -> dict[str, Any]:
+        """Set provider-specific default options."""
+        pass
+
+    def _extract_max_connections(self, options: dict[str, Any] | None) -> int:
+        """Extract max_connections from options, defaulting to 10."""
+        if not options:
+            return 10
+        max_conn = options.get("max_connections", 10)
+        return int(max_conn) if max_conn is not None else 10
+
+    def _prepare_api_options(self, options: dict[str, Any] | None) -> dict[str, Any]:
+        """Prepare API options by removing max_connections and setting defaults."""
+        if not options:
+            api_options = {}
+        else:
+            # Remove max_connections as it's not an API option
+            api_options = {k: v for k, v in options.items() if k != "max_connections"}
+
+        # Apply provider-specific defaults
+        return self.set_default_options(api_options)
+
+    def _validate_api_key(self) -> str:
+        """Validate that the required API key is set in environment."""
+        api_key = os.environ.get(self.env_key_name)
+        if not api_key:
+            raise PrerequisiteError(
+                f"{self.env_key_name} not set in the environment. Please ensure this variable is defined to use {self.provider_name} with the web_search tool.\n\nLearn more about the {self.provider_name} web search provider at https://inspect.aisi.org.uk/tools.html#{self.provider_name.lower()}-provider"
+            )
+        return api_key
+
+    async def search(self, query: str) -> ContentText | None:
+        """Execute a search query using the provider's API."""
+
+        # Common retry logic for all HTTP providers
+        @retry(
+            wait=wait_exponential_jitter(),
+            stop=stop_after_attempt(5) | stop_after_delay(60),
+            retry=retry_if_exception(httpx_should_retry),
+            before_sleep=log_httpx_retry_attempt(self.api_endpoint),
+        )
+        async def _search() -> httpx.Response:
+            response = await self.client.post(
+                self.api_endpoint,
+                headers=self.prepare_headers(self.api_key),
+                json={"query": query, **self.api_options},
+            )
+            response.raise_for_status()
+            return response
+
+        async with concurrency(self.concurrency_key, self.max_connections):
+            response_data = (await _search()).json()
+            return self.parse_response(response_data)
inspect_ai/tool/_tools/_web_search/_exa.py ADDED
@@ -0,0 +1,78 @@
+from typing import Any, Literal
+
+from pydantic import BaseModel
+
+from inspect_ai._util.citation import UrlCitation
+from inspect_ai._util.content import ContentText
+
+from ._base_http_provider import BaseHttpProvider
+from ._web_search_provider import SearchProvider
+
+
+class ExaOptions(BaseModel):
+    # See https://docs.exa.ai/reference/answer
+    text: bool | None = None
+    """Whether to include text content in citations"""
+    model: Literal["exa", "exa-pro"] | None = None
+    """LLM model to use for generating the answer"""
+    max_connections: int | None = None
+    """max_connections is not an Exa API option, but an inspect option"""
+
+
+class ExaCitation(BaseModel):
+    id: str
+    url: str
+    title: str
+    author: str | None = None
+    publishedDate: str | None = None
+    text: str
+
+
+class ExaSearchResponse(BaseModel):
+    answer: str
+    citations: list[ExaCitation]
+
+
+class ExaSearchProvider(BaseHttpProvider):
+    def __init__(self, options: dict[str, Any] | None = None):
+        super().__init__(
+            env_key_name="EXA_API_KEY",
+            api_endpoint="https://api.exa.ai/answer",
+            provider_name="Exa",
+            concurrency_key="exa_web_search",
+            options=options,
+        )
+
+    def prepare_headers(self, api_key: str) -> dict[str, str]:
+        return {
+            "x-api-key": api_key,
+            "Content-Type": "application/json",
+        }
+
+    def set_default_options(self, options: dict[str, Any]) -> dict[str, Any]:
+        return options
+
+    def parse_response(self, response_data: dict[str, Any]) -> ContentText | None:
+        exa_search_response = ExaSearchResponse.model_validate(response_data)
+
+        if not exa_search_response.answer and not exa_search_response.citations:
+            return None
+
+        return ContentText(
+            text=exa_search_response.answer,
+            citations=[
+                UrlCitation(
+                    cited_text=citation.text, title=citation.title, url=citation.url
+                )
+                for citation in exa_search_response.citations
+            ],
+        )
+
+
+def exa_search_provider(
+    in_options: dict[str, object] | None = None,
+) -> SearchProvider:
+    options = ExaOptions.model_validate(in_options) if in_options else None
+    return ExaSearchProvider(
+        options.model_dump(exclude_none=True) if options else None
+    ).search
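
A hedged wiring sketch for the new provider: the {"exa": {...}} provider-spec shape passed to web_search is an assumption based on the options that exa_search_provider validates above, and EXA_API_KEY must be set in the environment.

    from inspect_ai.tool import web_search

    # options mirror ExaOptions above; max_connections is consumed by
    # inspect rather than forwarded to the Exa API
    search_tool = web_search({"exa": {"model": "exa-pro", "max_connections": 5}})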