inspect-ai 0.3.58__py3-none-any.whl → 0.3.60__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
- inspect_ai/_cli/common.py +3 -1
- inspect_ai/_cli/eval.py +15 -9
- inspect_ai/_display/core/active.py +4 -1
- inspect_ai/_display/core/config.py +3 -3
- inspect_ai/_display/core/panel.py +7 -3
- inspect_ai/_display/plain/__init__.py +0 -0
- inspect_ai/_display/plain/display.py +203 -0
- inspect_ai/_display/rich/display.py +0 -5
- inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
- inspect_ai/_display/textual/widgets/samples.py +79 -12
- inspect_ai/_display/textual/widgets/sandbox.py +37 -0
- inspect_ai/_eval/eval.py +10 -1
- inspect_ai/_eval/loader.py +79 -19
- inspect_ai/_eval/registry.py +6 -0
- inspect_ai/_eval/score.py +3 -1
- inspect_ai/_eval/task/results.py +51 -22
- inspect_ai/_eval/task/run.py +47 -13
- inspect_ai/_eval/task/sandbox.py +10 -5
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/port_names.py +61 -0
- inspect_ai/_util/text.py +23 -0
- inspect_ai/_view/www/App.css +31 -1
- inspect_ai/_view/www/dist/assets/index.css +31 -1
- inspect_ai/_view/www/dist/assets/index.js +25498 -2044
- inspect_ai/_view/www/log-schema.json +32 -2
- inspect_ai/_view/www/package.json +2 -0
- inspect_ai/_view/www/src/App.mjs +14 -16
- inspect_ai/_view/www/src/Types.mjs +1 -2
- inspect_ai/_view/www/src/api/Types.ts +133 -0
- inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
- inspect_ai/_view/www/src/api/api-http.ts +219 -0
- inspect_ai/_view/www/src/api/api-shared.ts +47 -0
- inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
- inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
- inspect_ai/_view/www/src/api/index.ts +51 -0
- inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
- inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
- inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
- inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
- inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
- inspect_ai/_view/www/src/index.js +77 -4
- inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
- inspect_ai/_view/www/src/navbar/Navbar.mjs +4 -1
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +19 -10
- inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
- inspect_ai/_view/www/src/samples/SampleList.mjs +19 -49
- inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -26
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +14 -11
- inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
- inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
- inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
- inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
- inspect_ai/_view/www/src/types/log.d.ts +13 -2
- inspect_ai/_view/www/src/utils/Format.mjs +10 -3
- inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +13 -9
- inspect_ai/_view/www/src/utils/vscode.ts +36 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +11 -5
- inspect_ai/_view/www/vite.config.js +7 -0
- inspect_ai/_view/www/yarn.lock +116 -0
- inspect_ai/approval/_human/__init__.py +0 -0
- inspect_ai/approval/_human/manager.py +1 -1
- inspect_ai/approval/_policy.py +12 -6
- inspect_ai/log/_log.py +1 -1
- inspect_ai/log/_samples.py +16 -0
- inspect_ai/log/_transcript.py +4 -1
- inspect_ai/model/_call_tools.py +59 -0
- inspect_ai/model/_conversation.py +16 -7
- inspect_ai/model/_generate_config.py +12 -12
- inspect_ai/model/_model.py +117 -18
- inspect_ai/model/_model_output.py +22 -2
- inspect_ai/model/_openai.py +383 -0
- inspect_ai/model/_providers/anthropic.py +152 -55
- inspect_ai/model/_providers/azureai.py +21 -21
- inspect_ai/model/_providers/bedrock.py +37 -40
- inspect_ai/model/_providers/goodfire.py +248 -0
- inspect_ai/model/_providers/google.py +46 -54
- inspect_ai/model/_providers/groq.py +7 -3
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +13 -12
- inspect_ai/model/_providers/openai.py +51 -218
- inspect_ai/model/_providers/openai_o1.py +11 -12
- inspect_ai/model/_providers/providers.py +23 -1
- inspect_ai/model/_providers/together.py +12 -12
- inspect_ai/model/_providers/util/__init__.py +2 -3
- inspect_ai/model/_providers/util/hf_handler.py +1 -1
- inspect_ai/model/_providers/util/llama31.py +1 -1
- inspect_ai/model/_providers/util/util.py +0 -76
- inspect_ai/model/_providers/vertex.py +1 -4
- inspect_ai/scorer/_metric.py +3 -0
- inspect_ai/scorer/_reducer/reducer.py +1 -1
- inspect_ai/scorer/_scorer.py +4 -3
- inspect_ai/solver/__init__.py +4 -5
- inspect_ai/solver/_basic_agent.py +1 -1
- inspect_ai/solver/_bridge/__init__.py +3 -0
- inspect_ai/solver/_bridge/bridge.py +100 -0
- inspect_ai/solver/_bridge/patch.py +170 -0
- inspect_ai/solver/_prompt.py +35 -5
- inspect_ai/solver/_solver.py +6 -0
- inspect_ai/solver/_task_state.py +80 -38
- inspect_ai/tool/__init__.py +2 -0
- inspect_ai/tool/_tool.py +12 -1
- inspect_ai/tool/_tool_call.py +10 -0
- inspect_ai/tool/_tool_def.py +16 -5
- inspect_ai/tool/_tool_with.py +21 -4
- inspect_ai/tool/beta/__init__.py +5 -0
- inspect_ai/tool/beta/_computer/__init__.py +3 -0
- inspect_ai/tool/beta/_computer/_common.py +133 -0
- inspect_ai/tool/beta/_computer/_computer.py +155 -0
- inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
- inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
- inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
- inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/util/__init__.py +2 -0
- inspect_ai/util/_display.py +5 -0
- inspect_ai/util/_limit.py +26 -0
- inspect_ai/util/_sandbox/docker/docker.py +64 -1
- inspect_ai/util/_sandbox/docker/internal.py +3 -1
- inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
- inspect_ai/util/_sandbox/environment.py +14 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +159 -126
- inspect_ai/_view/www/src/api/Types.mjs +0 -117
- inspect_ai/_view/www/src/api/api-http.mjs +0 -300
- inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
- inspect_ai/_view/www/src/api/index.mjs +0 -49
- inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
- inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
- inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
inspect_ai/scorer/_scorer.py
CHANGED
```diff
@@ -1,3 +1,4 @@
+from functools import wraps
 from typing import (
     Any,
     Callable,
@@ -100,7 +101,6 @@ def scorer(
 
     Returns:
         Scorer with registry attributes.
-
     """
 
     def wrapper(scorer_type: Callable[P, Scorer]) -> Callable[P, Scorer]:
@@ -110,6 +110,7 @@ def scorer(
         )
 
         # wrap instantiations of scorer so they carry registry info and metrics
+        @wraps(scorer_type)
         def scorer_wrapper(*args: P.args, **kwargs: P.kwargs) -> Scorer:
             scorer = scorer_type(*args, **kwargs)
 
@@ -151,8 +152,8 @@ def scorer_metrics(
     return cast(list[Metric | dict[str, list[Metric]]], metrics_raw)
 
 
-def unique_scorer_name(scorer: Scorer, already_used_names: list[str]) -> str:
-    base_name = registry_unqualified_name(scorer)
+def unique_scorer_name(scorer: Scorer | str, already_used_names: list[str]) -> str:
+    base_name = scorer if isinstance(scorer, str) else registry_unqualified_name(scorer)
     scorer_name = base_name
     count = 1
     while scorer_name in already_used_names:
```
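The `@wraps(scorer_type)` addition means wrapped scorer factories now keep the original function's name and docstring. A minimal standalone sketch of the difference (standard library only; the names are illustrative, not from inspect_ai):

```python
from functools import wraps


def make_wrapper(fn):
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def make_wraps_wrapper(fn):
    @wraps(fn)  # copies __name__, __qualname__, __doc__, __module__, ...
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def match():
    """Score exact matches."""


print(make_wrapper(match).__name__)        # "wrapper" -- identity lost
print(make_wraps_wrapper(match).__name__)  # "match" -- identity preserved
```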
inspect_ai/solver/__init__.py
CHANGED
```diff
@@ -1,23 +1,21 @@
 from inspect_ai._util.deprecation import relocated_module_attribute
 
 from ._basic_agent import basic_agent
+from ._bridge import bridge
 from ._chain import chain
 from ._critique import self_critique
 from ._fork import fork
 from ._human_agent.agent import human_agent
 from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
 from ._plan import Plan, plan
-from ._prompt import (
-    chain_of_thought,
-    prompt_template,
-    system_message,
-)
+from ._prompt import chain_of_thought, prompt_template, system_message, user_message
 from ._solver import Generate, Solver, SolverSpec, generate, solver
 from ._task_state import Choice, Choices, TaskState
 from ._use_tools import use_tools
 
 __all__ = [
     "basic_agent",
+    "bridge",
     "human_agent",
     "chain",
     "fork",
@@ -26,6 +24,7 @@ __all__ = [
     "chain_of_thought",
     "multiple_choice",
     "system_message",
+    "user_message",
     "self_critique",
     "use_tools",
     "plan",
```
inspect_ai/solver/_basic_agent.py
CHANGED
```diff
@@ -119,7 +119,7 @@ def basic_agent(
     # resolve tools
     if tools is None:
         tools = []
-    tools = tools if isinstance(tools, Solver) else use_tools(tools)
+    tools = tools if isinstance(tools, Solver) else use_tools(tools, append=True)
 
     # resolve score_value function
     score_value_fn = score_value or value_to_float()
```
inspect_ai/solver/_bridge/bridge.py
ADDED
```diff
@@ -0,0 +1,100 @@
+from typing import Any, Awaitable, Callable
+
+from jsonschema import Draft7Validator
+from pydantic import BaseModel, Field, ValidationError
+from pydantic_core import to_json
+
+from inspect_ai._util._async import is_callable_coroutine
+from inspect_ai.model._chat_message import ChatMessage, ChatMessageUser
+from inspect_ai.model._providers.providers import validate_openai_client
+from inspect_ai.scorer._metric import Score
+
+from .._solver import Generate, Solver, solver
+from .._task_state import TaskState
+
+
+@solver
+def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Solver:
+    """Bridge an external agent into an Inspect Solver.
+
+    See documentation at https://inspect.ai-safety-institute.org.uk/agent-bridge.html
+
+    Args:
+        agent: Callable which takes a sample `dict` and returns a result `dict`.
+
+    Returns:
+        Standard Inspect solver.
+    """
+    validate_openai_client("Solver bridge()")
+
+    from openai.types.chat import ChatCompletionMessageParam
+
+    from inspect_ai.model._openai import (
+        chat_messages_from_openai,
+        openai_chat_messages,
+    )
+
+    from .patch import openai_request_to_inspect_model
+
+    class BridgeSample(BaseModel):
+        sample_id: str
+        epoch: int
+        input: list[ChatCompletionMessageParam]
+        metadata: dict[str, Any]
+        target: list[str]
+
+    class BridgeResult(BaseModel):
+        output: str
+        messages: list[ChatCompletionMessageParam] | None = Field(default=None)
+        scores: dict[str, Score] | None = Field(default=None)
+
+    result_schema = BridgeResult.model_json_schema()
+    result_validator = Draft7Validator(result_schema)
+
+    # validate that the agent is an async function
+    if not is_callable_coroutine(agent):
+        raise TypeError(f"'{agent.__name__}' is not declared as an async callable.")
+
+    async def solve(state: TaskState, generate: Generate) -> TaskState:
+        # resolve input to array
+        input: list[ChatMessage] = (
+            [ChatMessageUser(content=state.input)]
+            if isinstance(state.input, str)
+            else state.input
+        )
+
+        # create sample
+        sample = BridgeSample(
+            sample_id=str(state.sample_id),
+            epoch=state.epoch,
+            input=await openai_chat_messages(input, state.model.name),
+            metadata=state.metadata,
+            target=list(state.target),
+        )
+
+        # run target function
+        async with openai_request_to_inspect_model():
+            # call the function
+            result_dict = await agent(sample.model_dump())
+            try:
+                result = BridgeResult.model_validate(result_dict)
+            except ValidationError:
+                # if we fail to validate provide a better human readable error
+                errors = list(result_validator.iter_errors(result_dict))
+                message = "\n".join(
+                    ["Result returned from bridged solver is not valid:"]
+                    + [f" - {error.message}" for error in errors]
+                    + ["", to_json(result_dict, indent=2).decode()]
+                )
+                raise ValueError(message)
+
+        # update and return state
+        state.output.completion = result.output
+        if result.messages is not None:
+            state.messages = chat_messages_from_openai(result.messages)
+        if result.scores is not None:
+            state.scores = result.scores
+
+        return state
+
+    return solve
```
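A minimal usage sketch for `bridge()`, assuming the `Task(dataset=..., solver=...)` form: the agent is a plain async function that receives the `BridgeSample` fields as a dict and returns a dict matching `BridgeResult` (only `output` is required). OpenAI client calls with `model="inspect"` are routed to the active Inspect model by the request patch shown below.

```python
from openai import AsyncOpenAI

from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.solver import bridge


async def my_agent(sample: dict) -> dict:
    # sample["input"] is a list of OpenAI-format chat messages; the client
    # needs OPENAI_API_KEY set, but "inspect" requests are intercepted by
    # the patch before any call reaches the OpenAI API
    client = AsyncOpenAI()
    completion = await client.chat.completions.create(
        model="inspect",  # "inspect" routes to the model being evaluated
        messages=sample["input"],
    )
    return {"output": completion.choices[0].message.content or ""}


@task
def bridged_task() -> Task:
    return Task(dataset=[Sample(input="Say hello.")], solver=bridge(my_agent))
```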
inspect_ai/solver/_bridge/patch.py
ADDED
```diff
@@ -0,0 +1,170 @@
+import contextlib
+import re
+from contextvars import ContextVar
+from functools import wraps
+from time import time
+from typing import Any, AsyncGenerator, Optional, Type, cast
+
+from openai._base_client import AsyncAPIClient, _AsyncStreamT
+from openai._models import FinalRequestOptions
+from openai._types import ResponseT
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionMessageParam,
+    ChatCompletionToolParam,
+)
+from shortuuid import uuid
+
+from inspect_ai.model._generate_config import GenerateConfig
+from inspect_ai.model._model import get_model
+from inspect_ai.model._openai import (
+    chat_messages_from_openai,
+    openai_chat_choices,
+    openai_completion_usage,
+)
+from inspect_ai.solver._task_state import sample_state
+from inspect_ai.tool._tool_info import ToolInfo
+from inspect_ai.tool._tool_params import ToolParams
+
+
+@contextlib.asynccontextmanager
+async def openai_request_to_inspect_model() -> AsyncGenerator[None, None]:
+    # ensure one time init
+    init_openai_request_patch()
+
+    # set the patch enabled for this context and child coroutines
+    token = _patch_enabled.set(True)
+    try:
+        yield
+    finally:
+        _patch_enabled.reset(token)
+
+
+_patch_initialised: bool = False
+
+_patch_enabled: ContextVar[bool] = ContextVar(
+    "openai_request_patch_enabled", default=False
+)
+
+
+def init_openai_request_patch() -> None:
+    global _patch_initialised
+    if not _patch_initialised:
+        # get reference to original method
+        original_request = getattr(AsyncAPIClient, "request")
+        if original_request is None:
+            raise RuntimeError("Couldn't find 'request' method on AsyncAPIClient")
+
+        @wraps(original_request)
+        async def patched_request(
+            self: AsyncAPIClient,
+            cast_to: Type[ResponseT],
+            options: FinalRequestOptions,
+            *,
+            stream: bool = False,
+            stream_cls: type[_AsyncStreamT] | None = None,
+            remaining_retries: Optional[int] = None,
+        ) -> Any:
+            # we have patched the underlying request method so now need to figure out when to
+            # patch and when to stand down
+            if (
+                # enabled for this coroutine
+                _patch_enabled.get()
+                # completions request
+                and options.url == "/chat/completions"
+                # call to openai not another service (e.g. TogetherAI)
+                and self.base_url == "https://api.openai.com/v1/"
+            ):
+                # must also be an explicit request for an inspect model
+                json_data = cast(dict[str, Any], options.json_data)
+                model_name = str(json_data["model"])
+                if re.match(r"^inspect/?", model_name):
+                    return await inspect_model_request(model_name, options)
+
+            # otherwise just delegate
+            return await original_request(
+                self,
+                cast_to,
+                options,
+                stream=stream,
+                stream_cls=stream_cls,
+                remaining_retries=remaining_retries,
+            )
+
+        setattr(AsyncAPIClient, "request", patched_request)
+
+
+async def inspect_model_request(
+    model_name: str, options: FinalRequestOptions
+) -> ChatCompletion:
+    # convert openai messages to inspect messages
+    json_data = cast(dict[str, Any], options.json_data)
+    messages: list[ChatCompletionMessageParam] = json_data["messages"]
+    input = chat_messages_from_openai(messages)
+
+    # convert openai tools to inspect tools
+    tools: list[ChatCompletionToolParam] = json_data.get("tools", [])
+    inspect_tools: list[ToolInfo] = []
+    for tool in tools:
+        function = tool["function"].copy()
+        inspect_tools.append(
+            ToolInfo(
+                name=function["name"],
+                description=function["description"],
+                parameters=ToolParams.model_validate(function["parameters"]),
+            )
+        )
+
+    # resolve model
+    if model_name == "inspect":
+        model = get_model()
+    else:
+        model = get_model(model_name.removeprefix("inspect/"))
+
+    output = await model.generate(
+        input=input,
+        tools=inspect_tools,
+        config=generate_config_from_openai(options),
+    )
+
+    # if we are using the "default" inspect model for the task, update state.messages
+    if model_name == "inspect":
+        state = sample_state()
+        if state:
+            state.messages = input + [output.choices[0].message]
+
+    # inspect completion to openai completion
+    return ChatCompletion(
+        id=uuid(),
+        created=int(time()),
+        object="chat.completion",
+        choices=openai_chat_choices(output.choices),
+        model=model_name,
+        usage=openai_completion_usage(output.usage) if output.usage else None,
+    )
+
+
+def generate_config_from_openai(options: FinalRequestOptions) -> GenerateConfig:
+    # get options dict
+    json_data = cast(dict[str, Any], options.json_data)
+
+    config = GenerateConfig()
+    config.max_tokens = json_data.get(
+        "max_completion_tokens", json_data.get("max_tokens", None)
+    )
+    config.top_p = json_data.get("top_p", None)
+    config.temperature = json_data.get("temperature", None)
+    stop = json_data.get("stop", None)
+    if stop:
+        config.stop_seqs = [stop] if isinstance(stop, str) else stop
+    config.frequency_penalty = json_data.get("frequency_penalty", None)
+    config.presence_penalty = json_data.get("presence_penalty", None)
+    config.seed = json_data.get("seed", None)
+    config.num_choices = json_data.get("n", None)
+    config.logprobs = json_data.get("logprobs", None)
+    config.top_logprobs = json_data.get("top_logprobs", None)
+    config.logit_bias = json_data.get("logit_bias", None)
+    config.parallel_tool_calls = json_data.get("parallel_tool_calls", None)
+    config.reasoning_effort = json_data.get("reasoning_effort", None)
+
+    return config
```
inspect_ai/solver/_prompt.py
CHANGED
```diff
@@ -2,6 +2,7 @@ from typing import Any
 
 from inspect_ai._util.dict import omit
 from inspect_ai.model import ChatMessageSystem
+from inspect_ai.model._chat_message import ChatMessageUser
 from inspect_ai.util import resource
 
 from ._solver import Generate, Solver, solver
@@ -15,7 +16,8 @@ def prompt_template(template: str, **params: Any) -> Solver:
 
     Prompt template containing a `{prompt}` placeholder and any
     number of additional `params`. All values contained in sample
-    `metadata` are also automatically included in the
+    `metadata` and `store` are also automatically included in the
+    `params`.
 
     Args:
         template: (str): Template for prompt.
@@ -29,7 +31,7 @@ def prompt_template(template: str, **params: Any) -> Solver:
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
         prompt = state.user_prompt
-        kwargs = omit(state.metadata, ["prompt"]) | params
+        kwargs = omit(state.metadata | state.store._data, ["prompt"]) | params
         prompt.text = prompt_template.format(prompt=prompt.text, **kwargs)
         return state
 
@@ -41,8 +43,9 @@ def system_message(template: str, **params: Any) -> Solver:
     """Solver which inserts a system message into the conversation.
 
     System message template containing any number of optional `params`.
-    for substitution
-
+    for substitution using the `str.format()` method. All values
+    contained in sample `metadata` and `store` are also automatically
+    included in the `params`.
 
     The new message will go after other system messages (if there
     are none it will be inserted at the beginning of the conversation).
@@ -58,7 +61,7 @@ def system_message(template: str, **params: Any) -> Solver:
     content = resource(template)
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
-        kwargs = state.metadata | params
+        kwargs = state.metadata | state.store._data | params
         append_system_message(
             state.messages, ChatMessageSystem(content=content.format(**kwargs))
         )
@@ -67,6 +70,33 @@ def system_message(template: str, **params: Any) -> Solver:
     return solve
 
 
+@solver
+def user_message(template: str, **params: Any) -> Solver:
+    """Solver which inserts a user message into the conversation.
+
+    User message template containing any number of optional `params`.
+    for substitution using the `str.format()` method. All values
+    contained in sample `metadata` and `store` are also automatically
+    included in the `params`.
+
+    Args:
+        template (str): Template for user message.
+        **params (dict[str,Any]): Parameters to fill into the template.
+
+    Returns:
+        A solver that inserts the parameterised user message.
+    """
+    # read template
+    content = resource(template)
+
+    async def solve(state: TaskState, generate: Generate) -> TaskState:
+        kwargs = state.metadata | state.store._data | params
+        state.messages.append(ChatMessageUser(content=content.format(**kwargs)))
+        return state
+
+    return solve
+
+
 DEFAULT_COT_TEMPLATE = r"""
 {prompt}
 
```
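A short sketch of the new `user_message()` solver alongside `system_message()`, assuming the `Task(solver=[...])` form; `{audience}` is an illustrative metadata key that would be filled from the sample's `metadata` (or, with this change, the `store`):

```python
from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.solver import generate, system_message, user_message


@task
def styled_task() -> Task:
    return Task(
        dataset=[
            Sample(
                input="Explain list comprehensions.",
                metadata={"audience": "beginners"},
            )
        ],
        solver=[
            # both templates are filled from sample metadata/store via str.format()
            system_message("You are a tutor writing for {audience}."),
            user_message("Keep the explanation suitable for {audience}."),
            generate(),
        ],
    )
```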
inspect_ai/solver/_solver.py
CHANGED
```diff
@@ -180,6 +180,7 @@ def solver(
         solver_type, name if name else getattr(solver_type, "__name__")
     )
 
+    @wraps(solver_type)
     def solver_wrapper(*args: P.args, **kwargs: P.kwargs) -> Solver:
         solver = solver_type(*args, **kwargs)
 
@@ -193,6 +194,7 @@ def solver(
         if inspect.isclass(type(solver)):
             original_call = solver.__call__
 
+            @wraps(original_call)
             async def call_with_state(
                 state: TaskState, generate: Generate
             ) -> TaskState:
@@ -225,6 +227,10 @@ def solver(
 
         return registered_solver
 
+    # functools.wraps overrides the return type annotation of the inner function, so
+    # we explicitly set it again
+    solver_wrapper.__annotations__["return"] = Solver
+
     return solver_register(cast(Callable[P, Solver], solver_wrapper), solver_name)
 
     # for decorators with an explicit name, one more wrapper for the name
```
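Why the explicit `__annotations__["return"]` restore is needed: `functools.wraps` copies the wrapped function's `__annotations__` onto the wrapper, clobbering the wrapper's own return annotation. A minimal demonstration (standard library only):

```python
from functools import wraps


def fn() -> str:
    return "hello"


@wraps(fn)
def wrapper() -> int:
    return len(fn())


print(wrapper.__annotations__["return"])  # <class 'str'> -- copied from fn

wrapper.__annotations__["return"] = int   # restore, as the diff above does
print(wrapper.__annotations__["return"])  # <class 'int'>
```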
inspect_ai/solver/_task_state.py
CHANGED
```diff
@@ -2,8 +2,9 @@ from collections.abc import Sequence
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass
+from itertools import tee
 from random import Random
-from typing import Any, Type, Union, cast, overload
+from typing import Any, Iterable, SupportsIndex, Type, Union, cast, overload
 
 from pydantic_core import to_jsonable_python
 
@@ -15,9 +16,13 @@ from inspect_ai.model import (
     ModelOutput,
 )
 from inspect_ai.model._call_tools import tools_info
+from inspect_ai.model._chat_message import ChatMessageBase
 from inspect_ai.model._model import sample_total_tokens
+from inspect_ai.scorer._metric import Score
+from inspect_ai.scorer._target import Target
 from inspect_ai.tool import Tool, ToolChoice
 from inspect_ai.tool._tool_def import ToolDef
+from inspect_ai.util._limit import SampleLimitExceededError
 from inspect_ai.util._store import Store, store_jsonable
 from inspect_ai.util._store_model import SMT
 
@@ -136,6 +141,7 @@ class TaskState:
         epoch: int,
         input: str | list[ChatMessage],
         messages: list[ChatMessage],
+        target: Target = Target(""),
         choices: list[str] | None = [],
         output: ModelOutput | None = None,
         message_limit: int | None = None,
@@ -161,10 +167,13 @@ class TaskState:
             or `input_text` only
         """
 
+        self.target = target
+        """The scoring target for this `Sample`."""
+
         self.metadata = metadata
         """Metadata from the `Sample` for this `TaskState`"""
 
-        self.messages = messages
+        self._messages: list[ChatMessage] = ChatMessageList(messages)
         """
         Chat conversation history for sample.
 
@@ -189,9 +198,7 @@ class TaskState:
         """
 
         self._message_limit = message_limit
-        self._message_limit_exceeded = False
         self._token_limit = token_limit
-        self._token_limit_exceeded = False
         self._completed = completed
 
         """Store for shared data"""
@@ -202,6 +209,9 @@ class TaskState:
         else:
             self.choices = Choices([])
 
+        self.scores: dict[str, Score] | None = None
+        """Scores yielded by running task."""
+
     @property
     def model(self) -> ModelName:
         """Name of model being evaluated."""
@@ -254,6 +264,16 @@ class TaskState:
         else:
             raise ValueError("user_prompt requested from TaskState but none available")
 
+    @property
+    def messages(self) -> list[ChatMessage]:
+        """Messages in chat history"""
+        return self._messages
+
+    @messages.setter
+    def messages(self, messages: list[ChatMessage]) -> None:
+        """Set messages in chat history."""
+        self._messages = ChatMessageList(messages)
+
     @property
     def max_messages(self) -> int | None:
         """Deprecated (use message_limit)."""
@@ -300,40 +320,7 @@ class TaskState:
     @property
     def completed(self) -> bool:
         """Is the task completed."""
-
-        from inspect_ai.log._samples import set_active_sample_total_messages
-        from inspect_ai.log._transcript import SampleLimitEvent, transcript
-
-        set_active_sample_total_messages(len(self.messages))
-
-        if self._completed:
-            return True
-        elif self.message_limit and len(self.messages) >= self.message_limit:
-            # log if this is the first time we hit this
-            if not self._message_limit_exceeded:
-                self._message_limit_exceeded = True
-                transcript()._event(
-                    SampleLimitEvent(
-                        type="message",
-                        message=f"Sample completed: exceeded message limit ({self.message_limit})",
-                        limit=self.message_limit,
-                    )
-                )
-            return True
-        elif self.token_limit and self.token_usage >= self.token_limit:
-            # log if this is the first time we hit this
-            if not self._token_limit_exceeded:
-                self._token_limit_exceeded = True
-                transcript()._event(
-                    SampleLimitEvent(
-                        type="token",
-                        message=f"Sample completed: exceeded token limit ({self.token_limit:,})",
-                        limit=self.token_limit,
-                    )
-                )
-            return True
-        else:
-            return False
+        return self._completed
 
     @completed.setter
     def completed(self, completed: bool) -> None:
@@ -413,3 +400,58 @@ def state_jsonable(state: TaskState | None = None) -> dict[str, Any]:
 def sample_jsonable(sample: Sample) -> dict[str, Any]:
     jsonable = to_jsonable_python(sample, exclude_none=True, fallback=lambda _x: None)
     return cast(dict[str, Any], deepcopy(jsonable))
+
+
+class ChatMessageList(list[ChatMessage]):
+    def __init__(self, iterable: Iterable[ChatMessage]):
+        items, length = self._iterable_length(iterable)
+        self._check_size(length)
+        super().__init__(items)
+
+    def _check_size(self, additional_items: int = 1) -> None:
+        from inspect_ai.log._samples import active_sample_message_limit
+
+        messages_limit = active_sample_message_limit()
+        if messages_limit is not None:
+            messages = len(self) + additional_items
+            if messages > messages_limit:
+                raise SampleLimitExceededError(
+                    "message", value=messages, limit=messages_limit
+                )
+
+    def append(self, item: ChatMessage) -> None:
+        self._check_size()
+        super().append(item)
+
+    def extend(self, items: Iterable[ChatMessage]) -> None:
+        items, length = self._iterable_length(items)
+        self._check_size(length)
+        super().extend(items)
+
+    def insert(self, index: SupportsIndex, item: ChatMessage) -> None:
+        self._check_size()
+        super().insert(index, item)
+
+    @overload
+    def __setitem__(self, index: SupportsIndex, item: ChatMessage) -> None: ...
+
+    @overload
+    def __setitem__(self, index: slice, item: Iterable[ChatMessage]) -> None: ...
+
+    def __setitem__(
+        self, index: SupportsIndex | slice, item: ChatMessage | Iterable[ChatMessage]
+    ) -> None:
+        if isinstance(index, slice) and not isinstance(item, ChatMessageBase):
+            item, length = self._iterable_length(item)
+            size_change = length - len(self[index])
+            if size_change > 0:
+                self._check_size(size_change)
+
+        super().__setitem__(index, item)  # type: ignore[assignment,index]
+
+    def _iterable_length(
+        self, items: Iterable[ChatMessage]
+    ) -> tuple[Iterable[ChatMessage], int]:
+        items, counter = tee(items)
+        length = sum(1 for _ in counter)
+        return items, length
```
inspect_ai/tool/__init__.py
CHANGED
```diff
@@ -12,6 +12,7 @@ from ._tool_call import (
     ToolCall,
     ToolCallContent,
     ToolCallError,
+    ToolCallModelInput,
     ToolCallView,
     ToolCallViewer,
 )
@@ -42,6 +43,7 @@ __all__ = [
     "ContentVideo",
     "ToolCall",
     "ToolCallContent",
+    "ToolCallModelInput",
     "ToolCallView",
     "ToolCallViewer",
     "ToolChoice",
```