inspect-ai 0.3.59__py3-none-any.whl → 0.3.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +0 -7
- inspect_ai/_display/textual/widgets/samples.py +1 -1
- inspect_ai/_eval/eval.py +10 -1
- inspect_ai/_eval/loader.py +79 -19
- inspect_ai/_eval/registry.py +6 -0
- inspect_ai/_eval/score.py +2 -1
- inspect_ai/_eval/task/results.py +6 -5
- inspect_ai/_eval/task/run.py +11 -11
- inspect_ai/_view/www/dist/assets/index.js +262 -303
- inspect_ai/_view/www/src/App.mjs +6 -6
- inspect_ai/_view/www/src/Types.mjs +1 -1
- inspect_ai/_view/www/src/api/Types.ts +133 -0
- inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
- inspect_ai/_view/www/src/api/api-http.ts +219 -0
- inspect_ai/_view/www/src/api/api-shared.ts +47 -0
- inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
- inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
- inspect_ai/_view/www/src/api/index.ts +51 -0
- inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
- inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
- inspect_ai/_view/www/src/index.js +2 -2
- inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
- inspect_ai/_view/www/src/navbar/Navbar.mjs +1 -1
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleList.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +14 -14
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +10 -10
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
- inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +1 -3
- inspect_ai/_view/www/src/utils/vscode.ts +36 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
- inspect_ai/approval/_human/manager.py +1 -1
- inspect_ai/model/_call_tools.py +55 -0
- inspect_ai/model/_conversation.py +1 -4
- inspect_ai/model/_generate_config.py +2 -8
- inspect_ai/model/_model_output.py +15 -0
- inspect_ai/model/_openai.py +383 -0
- inspect_ai/model/_providers/anthropic.py +52 -11
- inspect_ai/model/_providers/azureai.py +1 -1
- inspect_ai/model/_providers/goodfire.py +248 -0
- inspect_ai/model/_providers/groq.py +7 -3
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +2 -1
- inspect_ai/model/_providers/openai.py +36 -202
- inspect_ai/model/_providers/openai_o1.py +2 -4
- inspect_ai/model/_providers/providers.py +22 -0
- inspect_ai/model/_providers/together.py +4 -4
- inspect_ai/model/_providers/util/__init__.py +2 -3
- inspect_ai/model/_providers/util/hf_handler.py +1 -1
- inspect_ai/model/_providers/util/llama31.py +1 -1
- inspect_ai/model/_providers/util/util.py +0 -76
- inspect_ai/scorer/_metric.py +3 -0
- inspect_ai/scorer/_scorer.py +2 -1
- inspect_ai/solver/__init__.py +2 -0
- inspect_ai/solver/_basic_agent.py +1 -1
- inspect_ai/solver/_bridge/__init__.py +3 -0
- inspect_ai/solver/_bridge/bridge.py +100 -0
- inspect_ai/solver/_bridge/patch.py +170 -0
- inspect_ai/solver/_solver.py +6 -0
- inspect_ai/util/_display.py +5 -0
- inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +68 -63
- inspect_ai/_view/www/src/api/Types.mjs +0 -117
- inspect_ai/_view/www/src/api/api-http.mjs +0 -300
- inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
- inspect_ai/_view/www/src/api/index.mjs +0 -49
- inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
- inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
@@ -24,13 +24,13 @@ from .._model_output import (
|
|
24
24
|
ModelOutput,
|
25
25
|
ModelUsage,
|
26
26
|
StopReason,
|
27
|
+
as_stop_reason,
|
27
28
|
)
|
29
|
+
from .._openai import chat_message_assistant_from_openai
|
28
30
|
from .openai import (
|
29
31
|
OpenAIAPI,
|
30
|
-
chat_message_assistant,
|
31
32
|
)
|
32
33
|
from .util import (
|
33
|
-
as_stop_reason,
|
34
34
|
chat_api_input,
|
35
35
|
chat_api_request,
|
36
36
|
environment_prerequisite_error,
|
@@ -68,7 +68,7 @@ def chat_choices_from_response_together(
|
|
68
68
|
logprobs_models.append(Logprobs(content=logprobs_sequence))
|
69
69
|
return [
|
70
70
|
ChatCompletionChoice(
|
71
|
-
message=chat_message_assistant(choice.message, tools),
|
71
|
+
message=chat_message_assistant_from_openai(choice.message, tools),
|
72
72
|
stop_reason=as_stop_reason(choice.finish_reason),
|
73
73
|
logprobs=logprobs,
|
74
74
|
)
|
@@ -99,7 +99,7 @@ class TogetherAIAPI(OpenAIAPI):
|
|
99
99
|
|
100
100
|
# Together uses a default of 512 so we bump it up
|
101
101
|
@override
|
102
|
-
def max_tokens(self) -> int:
|
102
|
+
def max_tokens(self) -> int | None:
|
103
103
|
return DEFAULT_MAX_TOKENS
|
104
104
|
|
105
105
|
@override
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from ..._call_tools import parse_tool_call, tool_parse_error_message
|
2
|
+
from ..._model_output import as_stop_reason
|
1
3
|
from .chatapi import (
|
2
4
|
ChatAPIHandler,
|
3
5
|
ChatAPIMessage,
|
@@ -8,11 +10,8 @@ from .chatapi import (
|
|
8
10
|
from .hf_handler import HFHandler
|
9
11
|
from .llama31 import Llama31Handler
|
10
12
|
from .util import (
|
11
|
-
as_stop_reason,
|
12
13
|
environment_prerequisite_error,
|
13
14
|
model_base_url,
|
14
|
-
parse_tool_call,
|
15
|
-
tool_parse_error_message,
|
16
15
|
)
|
17
16
|
|
18
17
|
__all__ = [
|
@@ -8,9 +8,9 @@ from typing_extensions import override
|
|
8
8
|
from inspect_ai.tool._tool_call import ToolCall
|
9
9
|
from inspect_ai.tool._tool_info import ToolInfo
|
10
10
|
|
11
|
+
from ..._call_tools import parse_tool_call, tool_parse_error_message
|
11
12
|
from ..._chat_message import ChatMessageAssistant
|
12
13
|
from .chatapi import ChatAPIHandler
|
13
|
-
from .util import parse_tool_call, tool_parse_error_message
|
14
14
|
|
15
15
|
logger = getLogger(__name__)
|
16
16
|
|
@@ -9,6 +9,7 @@ from typing_extensions import override
|
|
9
9
|
from inspect_ai.tool._tool_call import ToolCall
|
10
10
|
from inspect_ai.tool._tool_info import ToolInfo
|
11
11
|
|
12
|
+
from ..._call_tools import parse_tool_call, tool_parse_error_message
|
12
13
|
from ..._chat_message import (
|
13
14
|
ChatMessage,
|
14
15
|
ChatMessageAssistant,
|
@@ -16,7 +17,6 @@ from ..._chat_message import (
|
|
16
17
|
ChatMessageTool,
|
17
18
|
)
|
18
19
|
from .chatapi import ChatAPIHandler, ChatAPIMessage
|
19
|
-
from .util import parse_tool_call, tool_parse_error_message
|
20
20
|
|
21
21
|
logger = getLogger(__name__)
|
22
22
|
|
@@ -1,34 +1,11 @@
|
|
1
|
-
import json
|
2
1
|
import os
|
3
2
|
from logging import getLogger
|
4
|
-
from typing import Any
|
5
|
-
|
6
|
-
import yaml
|
7
3
|
|
8
4
|
from inspect_ai._util.error import PrerequisiteError
|
9
|
-
from inspect_ai.tool._tool_call import ToolCall
|
10
|
-
from inspect_ai.tool._tool_info import ToolInfo
|
11
|
-
|
12
|
-
from ..._model_output import StopReason
|
13
5
|
|
14
6
|
logger = getLogger(__name__)
|
15
7
|
|
16
8
|
|
17
|
-
def as_stop_reason(reason: str | None) -> StopReason:
|
18
|
-
"""Encode common reason strings into standard StopReason."""
|
19
|
-
match reason:
|
20
|
-
case "stop" | "eos":
|
21
|
-
return "stop"
|
22
|
-
case "length":
|
23
|
-
return "max_tokens"
|
24
|
-
case "tool_calls" | "function_call":
|
25
|
-
return "tool_calls"
|
26
|
-
case "content_filter" | "model_length" | "max_tokens":
|
27
|
-
return reason
|
28
|
-
case _:
|
29
|
-
return "unknown"
|
30
|
-
|
31
|
-
|
32
9
|
def model_base_url(base_url: str | None, env_vars: str | list[str]) -> str | None:
|
33
10
|
if base_url:
|
34
11
|
return base_url
|
@@ -44,59 +21,6 @@ def model_base_url(base_url: str | None, env_vars: str | list[str]) -> str | Non
|
|
44
21
|
return os.getenv("INSPECT_EVAL_MODEL_BASE_URL", None)
|
45
22
|
|
46
23
|
|
47
|
-
def tool_parse_error_message(arguments: str, ex: Exception) -> str:
|
48
|
-
return f"Error parsing the following tool call arguments:\n\n{arguments}\n\nError details: {ex}"
|
49
|
-
|
50
|
-
|
51
|
-
def parse_tool_call(
|
52
|
-
id: str, function: str, arguments: str, tools: list[ToolInfo]
|
53
|
-
) -> ToolCall:
|
54
|
-
error: str | None = None
|
55
|
-
arguments_dict: dict[str, Any] = {}
|
56
|
-
|
57
|
-
def report_parse_error(ex: Exception) -> None:
|
58
|
-
nonlocal error
|
59
|
-
error = tool_parse_error_message(arguments, ex)
|
60
|
-
logger.info(error)
|
61
|
-
|
62
|
-
# if the arguments is a dict, then handle it with a plain json.loads
|
63
|
-
arguments = arguments.strip()
|
64
|
-
if arguments.startswith("{"):
|
65
|
-
try:
|
66
|
-
arguments_dict = json.loads(arguments)
|
67
|
-
except json.JSONDecodeError as ex:
|
68
|
-
report_parse_error(ex)
|
69
|
-
|
70
|
-
# otherwise parse it as yaml (which will pickup unquoted strings, numbers, and true/false)
|
71
|
-
# and then create a dict that maps it to the first function argument
|
72
|
-
else:
|
73
|
-
tool_info = next(
|
74
|
-
(
|
75
|
-
tool
|
76
|
-
for tool in tools
|
77
|
-
if tool.name == function and len(tool.parameters.properties) > 0
|
78
|
-
),
|
79
|
-
None,
|
80
|
-
)
|
81
|
-
if tool_info:
|
82
|
-
param_names = list(tool_info.parameters.properties.keys())
|
83
|
-
try:
|
84
|
-
value = yaml.safe_load(arguments)
|
85
|
-
arguments_dict[param_names[0]] = value
|
86
|
-
except yaml.error.YAMLError:
|
87
|
-
# If the yaml parser fails, we treat it as a string argument.
|
88
|
-
arguments_dict[param_names[0]] = arguments
|
89
|
-
|
90
|
-
# return ToolCall with error payload
|
91
|
-
return ToolCall(
|
92
|
-
id=id,
|
93
|
-
function=function,
|
94
|
-
arguments=arguments_dict,
|
95
|
-
type="function",
|
96
|
-
parse_error=error,
|
97
|
-
)
|
98
|
-
|
99
|
-
|
100
24
|
def environment_prerequisite_error(
|
101
25
|
client: str, env_vars: str | list[str]
|
102
26
|
) -> PrerequisiteError:
|
inspect_ai/scorer/_metric.py
CHANGED
@@ -125,6 +125,9 @@ class SampleScore(BaseModel):
|
|
125
125
|
sample_id: str | int | None = Field(default=None)
|
126
126
|
"""A sample id"""
|
127
127
|
|
128
|
+
scorer: str | None = Field(default=None)
|
129
|
+
"""Registry name of scorer that created this score."""
|
130
|
+
|
128
131
|
|
129
132
|
ValueToFloat = Callable[[Value], float]
|
130
133
|
"""Function used by metrics to translate from a Score value to a float value."""
|
inspect_ai/scorer/_scorer.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
from functools import wraps
|
1
2
|
from typing import (
|
2
3
|
Any,
|
3
4
|
Callable,
|
@@ -100,7 +101,6 @@ def scorer(
|
|
100
101
|
|
101
102
|
Returns:
|
102
103
|
Scorer with registry attributes.
|
103
|
-
|
104
104
|
"""
|
105
105
|
|
106
106
|
def wrapper(scorer_type: Callable[P, Scorer]) -> Callable[P, Scorer]:
|
@@ -110,6 +110,7 @@ def scorer(
|
|
110
110
|
)
|
111
111
|
|
112
112
|
# wrap instantiations of scorer so they carry registry info and metrics
|
113
|
+
@wraps(scorer_type)
|
113
114
|
def scorer_wrapper(*args: P.args, **kwargs: P.kwargs) -> Scorer:
|
114
115
|
scorer = scorer_type(*args, **kwargs)
|
115
116
|
|
inspect_ai/solver/__init__.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from inspect_ai._util.deprecation import relocated_module_attribute
|
2
2
|
|
3
3
|
from ._basic_agent import basic_agent
|
4
|
+
from ._bridge import bridge
|
4
5
|
from ._chain import chain
|
5
6
|
from ._critique import self_critique
|
6
7
|
from ._fork import fork
|
@@ -14,6 +15,7 @@ from ._use_tools import use_tools
|
|
14
15
|
|
15
16
|
__all__ = [
|
16
17
|
"basic_agent",
|
18
|
+
"bridge",
|
17
19
|
"human_agent",
|
18
20
|
"chain",
|
19
21
|
"fork",
|
@@ -119,7 +119,7 @@ def basic_agent(
|
|
119
119
|
# resolve tools
|
120
120
|
if tools is None:
|
121
121
|
tools = []
|
122
|
-
tools = tools if isinstance(tools, Solver) else use_tools(tools)
|
122
|
+
tools = tools if isinstance(tools, Solver) else use_tools(tools, append=True)
|
123
123
|
|
124
124
|
# resolve score_value function
|
125
125
|
score_value_fn = score_value or value_to_float()
|
@@ -0,0 +1,100 @@
|
|
1
|
+
from typing import Any, Awaitable, Callable
|
2
|
+
|
3
|
+
from jsonschema import Draft7Validator
|
4
|
+
from pydantic import BaseModel, Field, ValidationError
|
5
|
+
from pydantic_core import to_json
|
6
|
+
|
7
|
+
from inspect_ai._util._async import is_callable_coroutine
|
8
|
+
from inspect_ai.model._chat_message import ChatMessage, ChatMessageUser
|
9
|
+
from inspect_ai.model._providers.providers import validate_openai_client
|
10
|
+
from inspect_ai.scorer._metric import Score
|
11
|
+
|
12
|
+
from .._solver import Generate, Solver, solver
|
13
|
+
from .._task_state import TaskState
|
14
|
+
|
15
|
+
|
16
|
+
@solver
|
17
|
+
def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Solver:
|
18
|
+
"""Bridge an external agent into an Inspect Solver.
|
19
|
+
|
20
|
+
See documentation at https://inspect.ai-safety-institute.org.uk/agent-bridge.html
|
21
|
+
|
22
|
+
Args:
|
23
|
+
agent: Callable which takes a sample `dict` and returns a result `dict`.
|
24
|
+
|
25
|
+
Returns:
|
26
|
+
Standard Inspect solver.
|
27
|
+
"""
|
28
|
+
validate_openai_client("Solver bridge()")
|
29
|
+
|
30
|
+
from openai.types.chat import ChatCompletionMessageParam
|
31
|
+
|
32
|
+
from inspect_ai.model._openai import (
|
33
|
+
chat_messages_from_openai,
|
34
|
+
openai_chat_messages,
|
35
|
+
)
|
36
|
+
|
37
|
+
from .patch import openai_request_to_inspect_model
|
38
|
+
|
39
|
+
class BridgeSample(BaseModel):
|
40
|
+
sample_id: str
|
41
|
+
epoch: int
|
42
|
+
input: list[ChatCompletionMessageParam]
|
43
|
+
metadata: dict[str, Any]
|
44
|
+
target: list[str]
|
45
|
+
|
46
|
+
class BridgeResult(BaseModel):
|
47
|
+
output: str
|
48
|
+
messages: list[ChatCompletionMessageParam] | None = Field(default=None)
|
49
|
+
scores: dict[str, Score] | None = Field(default=None)
|
50
|
+
|
51
|
+
result_schema = BridgeResult.model_json_schema()
|
52
|
+
result_validator = Draft7Validator(result_schema)
|
53
|
+
|
54
|
+
# validate that the agent is an async function
|
55
|
+
if not is_callable_coroutine(agent):
|
56
|
+
raise TypeError(f"'{agent.__name__}' is not declared as an async callable.")
|
57
|
+
|
58
|
+
async def solve(state: TaskState, generate: Generate) -> TaskState:
|
59
|
+
# resolve input to array
|
60
|
+
input: list[ChatMessage] = (
|
61
|
+
[ChatMessageUser(content=state.input)]
|
62
|
+
if isinstance(state.input, str)
|
63
|
+
else state.input
|
64
|
+
)
|
65
|
+
|
66
|
+
# create sample
|
67
|
+
sample = BridgeSample(
|
68
|
+
sample_id=str(state.sample_id),
|
69
|
+
epoch=state.epoch,
|
70
|
+
input=await openai_chat_messages(input, state.model.name),
|
71
|
+
metadata=state.metadata,
|
72
|
+
target=list(state.target),
|
73
|
+
)
|
74
|
+
|
75
|
+
# run target function
|
76
|
+
async with openai_request_to_inspect_model():
|
77
|
+
# call the function
|
78
|
+
result_dict = await agent(sample.model_dump())
|
79
|
+
try:
|
80
|
+
result = BridgeResult.model_validate(result_dict)
|
81
|
+
except ValidationError:
|
82
|
+
# if we fail to validate provide a better human readable error
|
83
|
+
errors = list(result_validator.iter_errors(result_dict))
|
84
|
+
message = "\n".join(
|
85
|
+
["Result returned from bridged solver is not valid:"]
|
86
|
+
+ [f" - {error.message}" for error in errors]
|
87
|
+
+ ["", to_json(result_dict, indent=2).decode()]
|
88
|
+
)
|
89
|
+
raise ValueError(message)
|
90
|
+
|
91
|
+
# update and return state
|
92
|
+
state.output.completion = result.output
|
93
|
+
if result.messages is not None:
|
94
|
+
state.messages = chat_messages_from_openai(result.messages)
|
95
|
+
if result.scores is not None:
|
96
|
+
state.scores = result.scores
|
97
|
+
|
98
|
+
return state
|
99
|
+
|
100
|
+
return solve
|
@@ -0,0 +1,170 @@
|
|
1
|
+
import contextlib
|
2
|
+
import re
|
3
|
+
from contextvars import ContextVar
|
4
|
+
from functools import wraps
|
5
|
+
from time import time
|
6
|
+
from typing import Any, AsyncGenerator, Optional, Type, cast
|
7
|
+
|
8
|
+
from openai._base_client import AsyncAPIClient, _AsyncStreamT
|
9
|
+
from openai._models import FinalRequestOptions
|
10
|
+
from openai._types import ResponseT
|
11
|
+
from openai.types.chat import (
|
12
|
+
ChatCompletion,
|
13
|
+
ChatCompletionMessageParam,
|
14
|
+
ChatCompletionToolParam,
|
15
|
+
)
|
16
|
+
from shortuuid import uuid
|
17
|
+
|
18
|
+
from inspect_ai.model._generate_config import GenerateConfig
|
19
|
+
from inspect_ai.model._model import get_model
|
20
|
+
from inspect_ai.model._openai import (
|
21
|
+
chat_messages_from_openai,
|
22
|
+
openai_chat_choices,
|
23
|
+
openai_completion_usage,
|
24
|
+
)
|
25
|
+
from inspect_ai.solver._task_state import sample_state
|
26
|
+
from inspect_ai.tool._tool_info import ToolInfo
|
27
|
+
from inspect_ai.tool._tool_params import ToolParams
|
28
|
+
|
29
|
+
|
30
|
+
@contextlib.asynccontextmanager
|
31
|
+
async def openai_request_to_inspect_model() -> AsyncGenerator[None, None]:
|
32
|
+
# ensure one time init
|
33
|
+
init_openai_request_patch()
|
34
|
+
|
35
|
+
# set the patch enabled for this context and child coroutines
|
36
|
+
token = _patch_enabled.set(True)
|
37
|
+
try:
|
38
|
+
yield
|
39
|
+
finally:
|
40
|
+
_patch_enabled.reset(token)
|
41
|
+
|
42
|
+
|
43
|
+
_patch_initialised: bool = False
|
44
|
+
|
45
|
+
_patch_enabled: ContextVar[bool] = ContextVar(
|
46
|
+
"openai_request_patch_enabled", default=False
|
47
|
+
)
|
48
|
+
|
49
|
+
|
50
|
+
def init_openai_request_patch() -> None:
|
51
|
+
global _patch_initialised
|
52
|
+
if not _patch_initialised:
|
53
|
+
# get reference to original method
|
54
|
+
original_request = getattr(AsyncAPIClient, "request")
|
55
|
+
if original_request is None:
|
56
|
+
raise RuntimeError("Couldn't find 'request' method on AsyncAPIClient")
|
57
|
+
|
58
|
+
@wraps(original_request)
|
59
|
+
async def patched_request(
|
60
|
+
self: AsyncAPIClient,
|
61
|
+
cast_to: Type[ResponseT],
|
62
|
+
options: FinalRequestOptions,
|
63
|
+
*,
|
64
|
+
stream: bool = False,
|
65
|
+
stream_cls: type[_AsyncStreamT] | None = None,
|
66
|
+
remaining_retries: Optional[int] = None,
|
67
|
+
) -> Any:
|
68
|
+
# we have patched the underlying request method so now need to figure out when to
|
69
|
+
# patch and when to stand down
|
70
|
+
if (
|
71
|
+
# enabled for this coroutine
|
72
|
+
_patch_enabled.get()
|
73
|
+
# completions request
|
74
|
+
and options.url == "/chat/completions"
|
75
|
+
# call to openai not another service (e.g. TogetherAI)
|
76
|
+
and self.base_url == "https://api.openai.com/v1/"
|
77
|
+
):
|
78
|
+
# must also be an explicit request for an inspect model
|
79
|
+
json_data = cast(dict[str, Any], options.json_data)
|
80
|
+
model_name = str(json_data["model"])
|
81
|
+
if re.match(r"^inspect/?", model_name):
|
82
|
+
return await inspect_model_request(model_name, options)
|
83
|
+
|
84
|
+
# otherwise just delegate
|
85
|
+
return await original_request(
|
86
|
+
self,
|
87
|
+
cast_to,
|
88
|
+
options,
|
89
|
+
stream=stream,
|
90
|
+
stream_cls=stream_cls,
|
91
|
+
remaining_retries=remaining_retries,
|
92
|
+
)
|
93
|
+
|
94
|
+
setattr(AsyncAPIClient, "request", patched_request)
|
95
|
+
|
96
|
+
|
97
|
+
async def inspect_model_request(
|
98
|
+
model_name: str, options: FinalRequestOptions
|
99
|
+
) -> ChatCompletion:
|
100
|
+
# convert openai messages to inspect messages
|
101
|
+
json_data = cast(dict[str, Any], options.json_data)
|
102
|
+
messages: list[ChatCompletionMessageParam] = json_data["messages"]
|
103
|
+
input = chat_messages_from_openai(messages)
|
104
|
+
|
105
|
+
# convert openai tools to inspect tools
|
106
|
+
tools: list[ChatCompletionToolParam] = json_data.get("tools", [])
|
107
|
+
inspect_tools: list[ToolInfo] = []
|
108
|
+
for tool in tools:
|
109
|
+
function = tool["function"].copy()
|
110
|
+
inspect_tools.append(
|
111
|
+
ToolInfo(
|
112
|
+
name=function["name"],
|
113
|
+
description=function["description"],
|
114
|
+
parameters=ToolParams.model_validate(function["parameters"]),
|
115
|
+
)
|
116
|
+
)
|
117
|
+
|
118
|
+
# resolve model
|
119
|
+
if model_name == "inspect":
|
120
|
+
model = get_model()
|
121
|
+
else:
|
122
|
+
model = get_model(model_name.removeprefix("inspect/"))
|
123
|
+
|
124
|
+
output = await model.generate(
|
125
|
+
input=input,
|
126
|
+
tools=inspect_tools,
|
127
|
+
config=generate_config_from_openai(options),
|
128
|
+
)
|
129
|
+
|
130
|
+
# if we are using the "default" inspect model for the task, update state.messages
|
131
|
+
if model_name == "inspect":
|
132
|
+
state = sample_state()
|
133
|
+
if state:
|
134
|
+
state.messages = input + [output.choices[0].message]
|
135
|
+
|
136
|
+
# inspect completion to openai completion
|
137
|
+
return ChatCompletion(
|
138
|
+
id=uuid(),
|
139
|
+
created=int(time()),
|
140
|
+
object="chat.completion",
|
141
|
+
choices=openai_chat_choices(output.choices),
|
142
|
+
model=model_name,
|
143
|
+
usage=openai_completion_usage(output.usage) if output.usage else None,
|
144
|
+
)
|
145
|
+
|
146
|
+
|
147
|
+
def generate_config_from_openai(options: FinalRequestOptions) -> GenerateConfig:
|
148
|
+
# get options dict
|
149
|
+
json_data = cast(dict[str, Any], options.json_data)
|
150
|
+
|
151
|
+
config = GenerateConfig()
|
152
|
+
config.max_tokens = json_data.get(
|
153
|
+
"max_completion_tokens", json_data.get("max_tokens", None)
|
154
|
+
)
|
155
|
+
config.top_p = json_data.get("top_p", None)
|
156
|
+
config.temperature = json_data.get("temperature", None)
|
157
|
+
stop = json_data.get("stop", None)
|
158
|
+
if stop:
|
159
|
+
config.stop_seqs = [stop] if isinstance(stop, str) else stop
|
160
|
+
config.frequency_penalty = json_data.get("frequency_penalty", None)
|
161
|
+
config.presence_penalty = json_data.get("presence_penalty", None)
|
162
|
+
config.seed = json_data.get("seed", None)
|
163
|
+
config.num_choices = json_data.get("n", None)
|
164
|
+
config.logprobs = json_data.get("logprobs", None)
|
165
|
+
config.top_logprobs = json_data.get("top_logprobs", None)
|
166
|
+
config.logit_bias = json_data.get("logit_bias", None)
|
167
|
+
config.parallel_tool_calls = json_data.get("parallel_tool_calls", None)
|
168
|
+
config.reasoning_effort = json_data.get("reasoning_effort", None)
|
169
|
+
|
170
|
+
return config
|
inspect_ai/solver/_solver.py
CHANGED
@@ -180,6 +180,7 @@ def solver(
|
|
180
180
|
solver_type, name if name else getattr(solver_type, "__name__")
|
181
181
|
)
|
182
182
|
|
183
|
+
@wraps(solver_type)
|
183
184
|
def solver_wrapper(*args: P.args, **kwargs: P.kwargs) -> Solver:
|
184
185
|
solver = solver_type(*args, **kwargs)
|
185
186
|
|
@@ -193,6 +194,7 @@ def solver(
|
|
193
194
|
if inspect.isclass(type(solver)):
|
194
195
|
original_call = solver.__call__
|
195
196
|
|
197
|
+
@wraps(original_call)
|
196
198
|
async def call_with_state(
|
197
199
|
state: TaskState, generate: Generate
|
198
200
|
) -> TaskState:
|
@@ -225,6 +227,10 @@ def solver(
|
|
225
227
|
|
226
228
|
return registered_solver
|
227
229
|
|
230
|
+
# functools.wraps overrides the return type annotation of the inner function, so
|
231
|
+
# we explicitly set it again
|
232
|
+
solver_wrapper.__annotations__["return"] = Solver
|
233
|
+
|
228
234
|
return solver_register(cast(Callable[P, Solver], solver_wrapper), solver_name)
|
229
235
|
|
230
236
|
# for decorators with an explicit name, one more wrapper for the name
|
inspect_ai/util/_display.py
CHANGED
@@ -57,7 +57,7 @@ async def validate_docker_compose(
|
|
57
57
|
version: str = DOCKER_COMPOSE_REQUIRED_VERSION,
|
58
58
|
) -> None:
|
59
59
|
def parse_version(stdout: str) -> semver.Version:
|
60
|
-
version = json.loads(stdout)["version"].removeprefix("v")
|
60
|
+
version = json.loads(stdout)["version"].removeprefix("v").split("+")[0]
|
61
61
|
return semver.Version.parse(version)
|
62
62
|
|
63
63
|
await validate_version(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: inspect_ai
|
3
|
-
Version: 0.3.59
|
3
|
+
Version: 0.3.60
|
4
4
|
Summary: Framework for large language model evaluations
|
5
5
|
Author: UK AI Safety Institute
|
6
6
|
License: MIT License
|
@@ -54,6 +54,7 @@ Requires-Dist: aioboto3; extra == "dev"
|
|
54
54
|
Requires-Dist: azure-ai-inference; extra == "dev"
|
55
55
|
Requires-Dist: google-cloud-aiplatform; extra == "dev"
|
56
56
|
Requires-Dist: google-generativeai; extra == "dev"
|
57
|
+
Requires-Dist: goodfire; extra == "dev"
|
57
58
|
Requires-Dist: groq; extra == "dev"
|
58
59
|
Requires-Dist: ipython; extra == "dev"
|
59
60
|
Requires-Dist: mistralai; extra == "dev"
|
@@ -67,7 +68,7 @@ Requires-Dist: pytest-asyncio; extra == "dev"
|
|
67
68
|
Requires-Dist: pytest-cov; extra == "dev"
|
68
69
|
Requires-Dist: pytest-dotenv; extra == "dev"
|
69
70
|
Requires-Dist: pytest-xdist; extra == "dev"
|
70
|
-
Requires-Dist: ruff==0.9.
|
71
|
+
Requires-Dist: ruff==0.9.3; extra == "dev"
|
71
72
|
Requires-Dist: textual-dev>=0.86.2; extra == "dev"
|
72
73
|
Requires-Dist: types-PyYAML; extra == "dev"
|
73
74
|
Requires-Dist: types-beautifulsoup4; extra == "dev"
|