inspect-ai 0.3.104__py3-none-any.whl → 0.3.106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_eval/context.py +5 -0
- inspect_ai/_eval/eval.py +113 -1
- inspect_ai/_eval/evalset.py +1 -1
- inspect_ai/_eval/task/run.py +64 -38
- inspect_ai/_util/eval_task_group.py +15 -0
- inspect_ai/_view/server.py +17 -0
- inspect_ai/_view/www/dist/assets/index.css +33 -29
- inspect_ai/_view/www/dist/assets/index.js +559 -247
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.module.css +4 -0
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +17 -0
- inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +26 -0
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +14 -3
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +359 -7
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/language.ts +6 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.tsx +1 -1
- inspect_ai/_view/www/src/client/api/api-browser.ts +25 -0
- inspect_ai/_view/www/src/client/api/api-http.ts +3 -0
- inspect_ai/_view/www/src/client/api/api-vscode.ts +6 -0
- inspect_ai/_view/www/src/client/api/client-api.ts +3 -0
- inspect_ai/_view/www/src/client/api/jsonrpc.ts +1 -0
- inspect_ai/_view/www/src/client/api/types.ts +3 -0
- inspect_ai/_view/www/src/state/samplePolling.ts +17 -1
- inspect_ai/agent/_handoff.py +5 -2
- inspect_ai/agent/_react.py +43 -20
- inspect_ai/dataset/_dataset.py +1 -1
- inspect_ai/log/_samples.py +5 -0
- inspect_ai/model/_call_tools.py +4 -4
- inspect_ai/model/_providers/_openai_web_search.py +1 -1
- inspect_ai/model/_providers/anthropic.py +23 -2
- inspect_ai/model/_providers/google.py +5 -1
- inspect_ai/model/_providers/groq.py +5 -0
- inspect_ai/model/_providers/perplexity.py +27 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/tool/_tools/_web_search/_web_search.py +8 -3
- inspect_ai/util/__init__.py +8 -0
- inspect_ai/util/_background.py +64 -0
- inspect_ai/util/_limit.py +72 -5
- inspect_ai/util/_sandbox/__init__.py +2 -0
- inspect_ai/util/_sandbox/service.py +28 -7
- inspect_ai/util/_subprocess.py +51 -38
- {inspect_ai-0.3.104.dist-info → inspect_ai-0.3.106.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.104.dist-info → inspect_ai-0.3.106.dist-info}/RECORD +46 -44
- {inspect_ai-0.3.104.dist-info → inspect_ai-0.3.106.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.104.dist-info → inspect_ai-0.3.106.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.104.dist-info → inspect_ai-0.3.106.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.104.dist-info → inspect_ai-0.3.106.dist-info}/top_level.txt +0 -0
inspect_ai/_view/www/src/client/api/client-api.ts
CHANGED
@@ -335,6 +335,9 @@ export const clientApi = (api: LogViewAPI, log_file?: string): ClientAPI => {
     ) => {
       return api.download_file(download_file, file_contents);
     },
+    log_message: (log_file: string, message: string) => {
+      return api.log_message(log_file, message);
+    },
     get_log_pending_samples: api.eval_pending_samples
       ? get_log_pending_samples
       : undefined,
inspect_ai/_view/www/src/client/api/jsonrpc.ts
CHANGED
@@ -41,6 +41,7 @@ export const kMethodEvalLogBytes = "eval_log_bytes";
 export const kMethodEvalLogHeaders = "eval_log_headers";
 export const kMethodPendingSamples = "eval_log_pending_samples";
 export const kMethodSampleData = "eval_log_sample_data";
+export const kMethodLogMessage = "log_message";
 
 export const kJsonRpcParseError = -32700;
 export const kJsonRpcInvalidRequest = -32600;
inspect_ai/_view/www/src/client/api/types.ts
CHANGED
@@ -115,6 +115,7 @@ export interface SampleSummary {
   scores: Scores1;
   error?: string;
   limit?: string;
+  metadata?: Record<string, any>;
   completed?: boolean;
   retries?: number;
 }
@@ -149,6 +150,7 @@ export interface LogViewAPI {
     end: number,
   ) => Promise<Uint8Array>;
   eval_log_headers: (log_files: string[]) => Promise<EvalLog[]>;
+  log_message: (log_file: string, message: string) => Promise<void>;
   download_file: (
     filename: string,
     filecontents: string | Blob | ArrayBuffer | ArrayBufferView,
@@ -177,6 +179,7 @@ export interface ClientAPI {
     id: string | number,
     epoch: number,
   ) => Promise<EvalSample | undefined>;
+  log_message?: (log_file: string, message: string) => Promise<void>;
   download_file: (
     file_name: string,
     file_contents: string | Blob | ArrayBuffer | ArrayBufferView,
inspect_ai/_view/www/src/state/samplePolling.ts
CHANGED
@@ -1,6 +1,7 @@
 import { Event } from "../app/types";
 import {
   AttachmentData,
+  ClientAPI,
   EventData,
   SampleData,
   SampleSummary,
@@ -183,6 +184,8 @@ export function createSamplePolling(
       const processedEvents = processEvents(
         sampleDataResponse.sampleData,
         pollingState,
+        api,
+        logFile,
       );
 
       // update max attachment id
@@ -268,7 +271,12 @@ function processAttachments(
   });
 }
 
-function processEvents(sampleData: SampleData, pollingState: PollingState) {
+function processEvents(
+  sampleData: SampleData,
+  pollingState: PollingState,
+  api: ClientAPI,
+  log_file: string,
+) {
   // Go through each event and resolve it, either appending or replacing
   log.debug(`Processing ${sampleData.events.length} events`);
   if (sampleData.events.length === 0) {
@@ -289,6 +297,14 @@ function processEvents(sampleData: SampleData, pollingState: PollingState) {
           attachmentId,
           available_attachments: Object.keys(pollingState.attachments),
         };
+
+        if (api.log_message) {
+          api.log_message(
+            log_file,
+            `Unable to resolve attachment ${attachmentId}\n` +
+              JSON.stringify(snapshot),
+          );
+        }
         console.warn(`Unable to resolve attachment ${attachmentId}`, snapshot);
       },
     );
inspect_ai/agent/_handoff.py
CHANGED
@@ -6,7 +6,7 @@ from inspect_ai._util.registry import (
     registry_unqualified_name,
     set_registry_info,
 )
-from inspect_ai.tool._tool import Tool, ToolResult, ToolSource
+from inspect_ai.tool._tool import TOOL_PARALLEL, Tool, ToolResult, ToolSource
 from inspect_ai.tool._tool_def import ToolDef
 from inspect_ai.tool._tool_description import ToolDescription, set_tool_description
 from inspect_ai.util._limit import Limit
@@ -61,7 +61,10 @@ def handoff(
         agent, tool_info.name, input_filter, output_filter, limits, **agent_kwargs
     )
     tool_name = tool_name or f"transfer_to_{tool_info.name}"
-    set_registry_info(agent_tool, RegistryInfo(type="tool", name=tool_name))
+    set_registry_info(
+        agent_tool,
+        RegistryInfo(type="tool", name=tool_name, metadata={TOOL_PARALLEL: False}),
+    )
     set_tool_description(
         agent_tool,
         ToolDescription(
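With `TOOL_PARALLEL` set to `False`, a handoff tool is never offered to the model as part of a parallel tool-call batch. A minimal sketch of creating a handoff tool (the agent name and prompts are illustrative, not from this diff):

```python
from inspect_ai.agent import handoff, react

# the generated transfer_to_researcher tool is now registered with
# TOOL_PARALLEL False, so the model can't combine the handoff with
# other tool calls in a single assistant turn
researcher = react(
    name="researcher",
    description="Researches questions on the web.",
    prompt="You are a careful research assistant.",
)
tools = [handoff(researcher)]
```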
inspect_ai/agent/_react.py
CHANGED
@@ -82,9 +82,8 @@ def react(
             the submit tool within the message. Alternatively, an async function
             to call to determine whether the loop should continue and what message
             to play back. Note that this function is called on _every_ iteration of
-            the loop
-
-            calls were made.
+            the loop so if you only want to send a message back when the model fails
+            to call tools you need to code that behavior explicitly.
         truncation: Truncate the conversation history in the event of a context
             window overflow. Defaults to "disabled" which does no truncation. Pass
             "auto" to use `trim_messages()` to reduce the context size. Pass a
@@ -246,13 +245,12 @@ def react(
                     )
                 )
             elif isinstance(do_continue, str):
-                # send back user message if there are no tool calls
-                if not state.output.message.tool_calls:
-                    state.messages.append(
-                        ChatMessageUser(
-                            content=do_continue.format(submit=submit_tool.name)
-                        )
+                # send back the user message
+                state.messages.append(
+                    ChatMessageUser(
+                        content=do_continue.format(submit=submit_tool.name)
                     )
+                )
             else:  # do_continue is False
                 break
 
@@ -328,11 +326,14 @@ def react_no_submit(
         if on_continue:
             do_continue = await _call_on_continue(on_continue, state)
             if do_continue is True:
-                do_continue = DEFAULT_CONTINUE_PROMOT_NO_SUBMIT
-            if do_continue:
-                # send back user message if there are no tool calls
                 if not state.output.message.tool_calls:
-                    state.messages.append(ChatMessageUser(content=do_continue))
+                    state.messages.append(
+                        ChatMessageUser(
+                            content=DEFAULT_CONTINUE_PROMOT_NO_SUBMIT
+                        )
+                    )
+            elif isinstance(do_continue, str):
+                state.messages.append(ChatMessageUser(content=do_continue))
             else:
                 break
         elif not state.output.message.tool_calls:
@@ -361,13 +362,13 @@ def _prompt_to_system_message(
             and ("{submit}" not in prompt.assistant_prompt)
             and prompt.submit_prompt
         ):
-            assistant_prompt = f"{prompt.assistant_prompt}\n{prompt.submit_prompt}"
+            assistant_prompt = f"{prompt.assistant_prompt}\n{prompt.submit_prompt.format(submit=submit_tool)}"
         else:
-            assistant_prompt = prompt.assistant_prompt
+            assistant_prompt = prompt.assistant_prompt.format(
+                submit=submit_tool or "submit"
+            )
         prompt_lines.append(assistant_prompt)
-        prompt_content = "\n\n".join(prompt_lines).format(
-            submit=submit_tool or "submit"
-        )
+        prompt_content = "\n\n".join(prompt_lines)
         system_message: ChatMessage | None = ChatMessageSystem(content=prompt_content)
     else:
         system_message = None
@@ -471,12 +472,34 @@ def _remove_submit_tool(
 
         # remove submit tool from assistant messages
         if isinstance(message, ChatMessageAssistant) and message.tool_calls:
-            tool_calls = [
+            new_tools_calls = [
                 tool_call
                 for tool_call in message.tool_calls
                 if tool_call.function != submit_name
             ]
-            message = message.model_copy(update=dict(tool_calls=tool_calls))
+
+            # If a submit tool call was removed, we need to update the message
+            if len(new_tools_calls) < len(message.tool_calls):
+                message = message.model_copy(
+                    update=dict(
+                        tool_calls=new_tools_calls,
+                        # Some models (OpenAI) don't like to see the reasoning
+                        # content item that led to the submit tool call, so we
+                        # have to remove it too.
+                        content=(
+                            [
+                                content
+                                for content in message.content
+                                if (
+                                    isinstance(content, str)
+                                    or content.type != "reasoning"
+                                )
+                            ]
+                            if isinstance(message.content, list)
+                            else message.content
+                        ),
+                    )
+                )
 
         # always append message
         filtered.append(message)
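The docstring fix above is worth making concrete: `on_continue` runs on every iteration of the react loop, so callers that only want to nudge the model when it made no tool calls must check for that themselves. A hedged sketch of such a callback (the message text is illustrative; `{submit}` is substituted with the submit tool name per the `do_continue.format(submit=submit_tool.name)` call shown above):

```python
from inspect_ai.agent import AgentState

async def on_continue(state: AgentState) -> bool | str:
    # called on *every* iteration, not only when tool calls are absent
    if not state.output.message.tool_calls:
        # returning a string continues the loop and plays this message back
        return "Please call a tool, or use {submit} when you are finished."
    # returning True continues the loop with no extra message
    return True
```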
inspect_ai/dataset/_dataset.py
CHANGED
inspect_ai/log/_samples.py
CHANGED
@@ -3,6 +3,7 @@ from contextvars import ContextVar
 from datetime import datetime
 from typing import AsyncGenerator, Iterator, Literal
 
+from anyio.abc import TaskGroup
 from shortuuid import uuid
 
 from inspect_ai.dataset._dataset import Sample
@@ -28,6 +29,7 @@ class ActiveSample:
         fails_on_error: bool,
         transcript: Transcript,
         sandboxes: dict[str, SandboxConnection],
+        tg: TaskGroup,
     ) -> None:
         self.id = uuid()
         self.started: float | None = None
@@ -47,6 +49,7 @@ class ActiveSample:
         self.transcript = transcript
         self.sandboxes = sandboxes
         self._interrupt_action: Literal["score", "error"] | None = None
+        self.tg = tg
 
     @property
     def running_time(self) -> float:
@@ -86,6 +89,7 @@ async def active_sample(
     working_limit: int | None,
     fails_on_error: bool,
     transcript: Transcript,
+    tg: TaskGroup,
 ) -> AsyncGenerator[ActiveSample, None]:
     # create the sample
     active = ActiveSample(
@@ -101,6 +105,7 @@ async def active_sample(
         sandboxes=await sandbox_connections(),
         fails_on_error=fails_on_error,
         transcript=transcript,
+        tg=tg,
     )
 
     _active_samples.append(active)
inspect_ai/model/_call_tools.py
CHANGED
@@ -534,11 +534,11 @@ def prepend_agent_name(
         content = copy(message.content)
         for i in range(0, len(content)):
             if isinstance(content[i], ContentText):
-                content[i] = content[i].model_copy(
-                    update=dict(
-                        text=f"[{agent_name}] {cast(ContentText, content[i]).text}"
+                text = cast(ContentText, content[i]).text
+                if text:
+                    content[i] = content[i].model_copy(
+                        update=dict(text=f"[{agent_name}] {text}")
                     )
-                )
                 break
         return message.model_copy(update=dict(content=content))
 
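To make the guard concrete: the first `ContentText` item still ends the scan, but it now only receives the `[agent]` prefix when its text is non-empty. A standalone mirror of the rule (plain strings stand in for `ContentText` items and `None` for non-text content; names are illustrative):

```python
def prepend_name(contents: list[str | None], agent: str) -> list[str | None]:
    out = list(contents)
    for i, item in enumerate(out):
        if isinstance(item, str):  # first text-like item ends the scan
            if item:  # the fix: leave empty text untouched
                out[i] = f"[{agent}] {item}"
            break
    return out

assert prepend_name(["", "hi"], "researcher") == ["", "hi"]
assert prepend_name([None, "hi"], "researcher") == [None, "[researcher] hi"]
```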
inspect_ai/model/_providers/_openai_web_search.py
CHANGED
@@ -14,7 +14,7 @@ def maybe_web_search_tool(model_name: str, tool: ToolInfo) -> WebSearchToolParam
         tool.name == "web_search"
         and tool.options
         and "openai" in tool.options
-        and model_name in COMPATIBLE_MODELS
+        and any(model_name.startswith(model) for model in COMPATIBLE_MODELS)
     )
     else None
 )
inspect_ai/model/_providers/anthropic.py
CHANGED
@@ -41,6 +41,7 @@ from anthropic.types import (
 from anthropic.types.beta import (
     BetaToolComputerUse20250124Param,
     BetaToolTextEditor20241022Param,
+    BetaToolTextEditor20250429Param,
 )
 from pydantic import JsonValue
 from typing_extensions import override
@@ -397,6 +398,9 @@ class AnthropicAPI(ModelAPI):
     def is_claude_3_7(self) -> bool:
         return "claude-3-7-" in self.service_model_name()
 
+    def is_claude_4(self) -> bool:
+        return re.search(r"claude-4-[a-zA-Z]", self.service_model_name()) is not None
+
     @override
     def connection_key(self) -> str:
         return str(self.api_key)
@@ -627,7 +631,17 @@ class AnthropicAPI(ModelAPI):
 
     def text_editor_tool_param(
         self, tool: ToolInfo
-    ) -> ToolTextEditor20250124Param | BetaToolTextEditor20241022Param | None:
+    ) -> (
+        ToolTextEditor20250124Param
+        | BetaToolTextEditor20241022Param
+        | BetaToolTextEditor20250429Param
+        | None
+    ):
+        # See: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/text-editor-tool#before-using-the-text-editor-tool
+        # TODO: It would be great to enhance our `is_claude_xxx` functions to help here.
+        if self.model_name.startswith(("claude-3-5-haiku", "claude-3-opus")):
+            return None
+
         # check for compatible 'text editor' tool
         if tool.name == "text_editor" and (
             sorted(tool.parameters.properties.keys())
@@ -644,7 +658,11 @@ class AnthropicAPI(ModelAPI):
             )
         ):
             return (
-                BetaToolTextEditor20241022Param(
+                BetaToolTextEditor20250429Param(
+                    type="text_editor_20250429", name="str_replace_based_edit_tool"
+                )
+                if self.is_claude_4()
+                else BetaToolTextEditor20241022Param(
                     type="text_editor_20241022", name="str_replace_editor"
                 )
                 if self.is_claude_3_5()
@@ -706,6 +724,7 @@ ToolParamDef = (
     | BetaToolComputerUse20250124Param
     | ToolTextEditor20250124Param
     | BetaToolTextEditor20241022Param
+    | BetaToolTextEditor20250429Param
     | WebSearchTool20250305Param
 )
 
@@ -716,6 +735,7 @@ def add_cache_control(
     | BetaToolComputerUse20250124Param
     | ToolTextEditor20250124Param
     | BetaToolTextEditor20241022Param
+    | BetaToolTextEditor20250429Param
     | WebSearchTool20250305Param
     | dict[str, Any],
 ) -> None:
@@ -1008,6 +1028,7 @@ def _names_for_tool_call(
     (INTERNAL_COMPUTER_TOOL_NAME, "computer_20250124", "computer"),
     ("str_replace_editor", "text_editor_20241022", "text_editor"),
     ("str_replace_editor", "text_editor_20250124", "text_editor"),
+    ("str_replace_based_edit_tool", "text_editor_20250429", "text_editor"),
     ("bash", "bash_20250124", "bash_session"),
 )
 
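The net effect: Claude 4 models are mapped to the `text_editor_20250429` tool variant (surfaced to the model as `str_replace_based_edit_tool`), Claude 3.5/3.7 keep the earlier variants, and `claude-3-5-haiku` / `claude-3-opus` get no text editor at all. A hedged usage sketch (the task and model wiring are illustrative):

```python
from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.solver import generate, use_tools
from inspect_ai.tool import text_editor

@task
def edit_files() -> Task:
    # on a claude-4-* model the provider now selects the
    # text_editor_20250429 / str_replace_based_edit_tool variant
    return Task(
        dataset=[Sample(input="Fix the typo in /app/README.md")],
        solver=[use_tools(text_editor()), generate()],
        sandbox="docker",
    )
```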
inspect_ai/model/_providers/google.py
CHANGED
@@ -991,6 +991,10 @@ def _combine_text_parts(acc: list[Part], part: Part) -> list[Part]:
     """Combine adjacent text parts into a single part."""
     return (
         acc + [part]
-        if part.text is None
+        if part.text is None
+        or part.thought is True
+        or len(acc) == 0
+        or acc[-1].text is None
+        or acc[-1].thought is True
         else acc[:-1] + [Part(text=acc[-1].text + part.text)]
     )
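In effect, a text part is now merged into its predecessor only when both are ordinary (non-thought) text; thought parts and non-text parts act as merge barriers. A standalone mirror of the rule using plain dicts (illustrative, not the provider's `Part` type):

```python
from functools import reduce

def combine(acc: list[dict], part: dict) -> list[dict]:
    # append rather than merge when either side is a thought part,
    # a non-text part, or there is nothing yet to merge into
    if (
        part.get("text") is None
        or part.get("thought") is True
        or len(acc) == 0
        or acc[-1].get("text") is None
        or acc[-1].get("thought") is True
    ):
        return acc + [part]
    return acc[:-1] + [{"text": acc[-1]["text"] + part["text"]}]

parts = [{"text": "a"}, {"text": "b"}, {"text": "t", "thought": True}, {"text": "c"}]
assert reduce(combine, parts, []) == [
    {"text": "ab"},
    {"text": "t", "thought": True},
    {"text": "c"},
]
```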
inspect_ai/model/_providers/groq.py
CHANGED
@@ -156,6 +156,11 @@ class GroqAPI(ModelAPI):
             "completion_time": completion.usage.completion_time,
             "total_time": completion.usage.total_time,
         }
+        if completion.choices[0].message.executed_tools:
+            metadata["executed_tools"] = [
+                tool.model_dump()
+                for tool in completion.choices[0].message.executed_tools
+            ]
 
         # extract output
         choices = self._chat_choices_from_response(completion, tools)
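Server-side tool executions reported by Groq are now recorded in the generation metadata under `"executed_tools"`. A hedged sketch of reading them back (the model name is illustrative):

```python
from inspect_ai.model import get_model

async def show_executed_tools() -> None:
    model = get_model("groq/compound-beta")  # illustrative model name
    output = await model.generate("What is the latest inspect_ai release?")
    for tool in (output.metadata or {}).get("executed_tools", []):
        print(tool)
```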
inspect_ai/model/_providers/perplexity.py
CHANGED
@@ -49,7 +49,33 @@ class PerplexityAPI(OpenAICompatibleAPI):
         tool_choice: "ToolChoice",
         config: GenerateConfig,
     ) -> tuple[ModelOutput | Exception, "ModelCall"]:
-        result = await super().generate(input, tools, tool_choice, config)
+        search_options: dict[str, Any] | None = None
+        for tool in tools:
+            if (
+                tool.name == "web_search"
+                and tool.options
+                and "perplexity" in tool.options
+            ):
+                maybe_opts = tool.options["perplexity"]
+                if maybe_opts is not None:
+                    if maybe_opts is True:
+                        search_options = {}
+                    elif isinstance(maybe_opts, dict):
+                        search_options = maybe_opts
+                    else:
+                        raise TypeError(
+                            f"Expected a dictionary or True for perplexity_options, got {type(maybe_opts)}"
+                        )
+            else:
+                raise ValueError(
+                    "Perplexity does not support tools other than web_search with perplexity options"
+                )
+
+        if search_options:
+            extra_body = {**(config.extra_body or {}), **search_options}
+            config = config.merge(GenerateConfig(extra_body=extra_body))
+
+        result = await super().generate(input, [], tool_choice, config)
         output, call = cast(tuple[ModelOutput, "ModelCall"], result)
 
         if self._response:
inspect_ai/tool/_tools/_web_search/_web_search.py
CHANGED
@@ -18,7 +18,7 @@ from ._tavily import TavilyOptions, tavily_search_provider
 from ._web_search_provider import SearchProvider
 
 Provider: TypeAlias = Literal[
-    "gemini", "openai", "anthropic", "tavily", "google", "exa"
+    "gemini", "openai", "anthropic", "perplexity", "tavily", "google", "exa"
 ]
 valid_providers = set(get_args(Provider))
 
@@ -35,6 +35,7 @@ class Providers(TypedDict, total=False):
     openai: dict[str, Any] | Literal[True]
     anthropic: dict[str, Any] | Literal[True]
     gemini: dict[str, Any] | Literal[True]
+    perplexity: dict[str, Any] | Literal[True]
     tavily: dict[str, Any] | Literal[True]
     google: dict[str, Any] | Literal[True]
     exa: dict[str, Any] | Literal[True]
@@ -44,6 +45,7 @@ class _NormalizedProviders(TypedDict, total=False):
     openai: dict[str, Any]
     anthropic: dict[str, Any]
     gemini: dict[str, Any]
+    perplexity: dict[str, Any]
     tavily: dict[str, Any]
     google: dict[str, Any]
     exa: dict[str, Any]
@@ -67,7 +69,7 @@ def web_search(
     Web searches are executed using a provider. Providers are split
     into two categories:
 
-    - Internal providers: "openai", "anthropic" - these use the model's built-in
+    - Internal providers: "openai", "anthropic", "gemini", "perplexity" - these use the model's built-in
       search capability and do not require separate API keys. These work only for
      their respective model provider (e.g. the "openai" search provider
      works only for `openai/*` models).
@@ -84,7 +86,7 @@ def web_search(
 
     Args:
       providers: Configuration for the search providers to use. Currently supported
-        providers are "openai", "anthropic", "tavily", "google", and "exa". The
+        providers are "openai", "anthropic", "perplexity", "tavily", "google", and "exa". The
        `providers` parameter supports several formats based on either a `str`
        specifying a provider or a `dict` whose keys are the provider names and
        whose values are the provider-specific options. A single value or a list
@@ -121,6 +123,9 @@ def web_search(
       - anthropic: Supports Anthropic's web search parameters.
        See https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool#tool-definition
 
+      - perplexity: Supports Perplexity's web search parameters.
+        See https://docs.perplexity.ai/api-reference/chat-completions-post
+
      - tavily: Supports options like `max_results`, `search_depth`, etc.
        See https://docs.tavily.com/documentation/api-reference/endpoint/search
 
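A hedged usage sketch of the new internal provider (the nested option names follow Perplexity's chat completions API and are illustrative):

```python
from inspect_ai.tool import web_search

# enable Perplexity's built-in search with default options ...
search = web_search("perplexity")

# ... or pass options through to the underlying chat completions request
search = web_search(
    {"perplexity": {"web_search_options": {"search_context_size": "high"}}}
)
```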
inspect_ai/util/__init__.py
CHANGED
@@ -4,13 +4,16 @@ from inspect_ai.util._limit import (
     Limit,
     LimitExceededError,
     LimitScope,
+    SampleLimits,
     apply_limits,
     message_limit,
+    sample_limits,
     time_limit,
     token_limit,
     working_limit,
 )
 
+from ._background import background
 from ._collect import collect
 from ._concurrency import concurrency
 from ._console import input_screen
@@ -29,6 +32,7 @@ from ._sandbox import (
     SandboxEnvironmentType,
     sandbox,
     sandbox_default,
+    sandbox_service,
     sandbox_with,
     sandboxenv,
 )
@@ -44,6 +48,8 @@ from ._throttle import throttle
 
 __all__ = [
     "apply_limits",
+    "sample_limits",
+    "SampleLimits",
     "ExecResult",
     "concurrency",
     "DisplayType",
@@ -73,6 +79,7 @@ __all__ = [
     "sandbox",
     "sandbox_with",
     "sandbox_default",
+    "sandbox_service",
     "Store",
     "store",
     "StoreModel",
@@ -82,6 +89,7 @@ __all__ = [
     "Subtask",
     "subtask",
     "throttle",
+    "background",
     "token_limit",
     "time_limit",
     "working_limit",
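Among the new exports, `sample_limits()` reads the limits in force for the currently executing sample (per the expanded `_limit.py`). A minimal sketch; only the names exported above are confirmed by this diff, the structure of the returned object is not shown here:

```python
from inspect_ai.util import SampleLimits, sample_limits

# from code running inside a sample (solver, tool, agent):
limits: SampleLimits = sample_limits()
print(limits)  # inspect the active token/message/time/working limits
```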
inspect_ai/util/_background.py
ADDED
@@ -0,0 +1,64 @@
+import sys
+from logging import getLogger
+from typing import Any, Awaitable, Callable
+
+if sys.version_info >= (3, 11):
+    from typing import TypeVarTuple
+else:
+    from typing_extensions import TypeVarTuple
+
+
+from typing_extensions import Unpack
+
+logger = getLogger(__name__)
+
+
+PosArgsT = TypeVarTuple("PosArgsT")
+
+
+def background(
+    func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
+    *args: Unpack[PosArgsT],
+) -> None:
+    """Run an async function in the background of the current sample.
+
+    Background functions must be run from an executing sample.
+    The function will run as long as the current sample is running.
+
+    When the sample terminates, an anyio cancelled error will be
+    raised in the background function. To catch this error and
+    cleanup:
+
+    ```python
+    import anyio
+
+    async def run():
+        try:
+            # background code
+        except anyio.get_cancelled_exc_class():
+            ...
+    ```
+
+    Args:
+        func: Async function to run
+        *args: Optional function arguments.
+    """
+    from inspect_ai.log._samples import sample_active
+
+    # get the active sample
+    sample = sample_active()
+    if sample is None:
+        raise RuntimeError(
+            "background() function must be called from a running sample."
+        )
+
+    # handle and log background exceptions
+    async def run() -> None:
+        try:
+            await func(*args)
+        except Exception as ex:
+            logger.error(f"Background worker error: {ex}")
+            raise
+
+    # kick it off
+    sample.tg.start_soon(run)