inspect-ai 0.3.92__py3-none-any.whl → 0.3.94__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +27 -0
- inspect_ai/_display/textual/widgets/samples.py +3 -3
- inspect_ai/_display/textual/widgets/transcript.py +3 -29
- inspect_ai/_eval/eval.py +19 -2
- inspect_ai/_eval/evalset.py +4 -1
- inspect_ai/_eval/run.py +41 -0
- inspect_ai/_eval/task/generate.py +38 -44
- inspect_ai/_eval/task/log.py +26 -28
- inspect_ai/_eval/task/run.py +23 -27
- inspect_ai/_util/answer.py +26 -0
- inspect_ai/_util/constants.py +0 -1
- inspect_ai/_util/local_server.py +398 -0
- inspect_ai/_util/working.py +10 -4
- inspect_ai/_view/www/dist/assets/index.css +173 -159
- inspect_ai/_view/www/dist/assets/index.js +1417 -1142
- inspect_ai/_view/www/log-schema.json +379 -3
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/@types/log.d.ts +93 -14
- inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +2 -2
- inspect_ai/_view/www/src/app/content/MetaDataView.module.css +1 -1
- inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +1 -1
- inspect_ai/_view/www/src/app/content/RenderedContent.tsx +1 -1
- inspect_ai/_view/www/src/app/log-view/LogView.tsx +11 -0
- inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +2 -9
- inspect_ai/_view/www/src/app/log-view/tabs/ModelsTab.tsx +51 -0
- inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.module.css +6 -0
- inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.tsx +143 -0
- inspect_ai/_view/www/src/app/plan/ModelCard.tsx +1 -2
- inspect_ai/_view/www/src/app/plan/PlanCard.tsx +29 -7
- inspect_ai/_view/www/src/app/plan/PlanDetailView.module.css +1 -1
- inspect_ai/_view/www/src/app/plan/PlanDetailView.tsx +1 -198
- inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
- inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
- inspect_ai/_view/www/src/app/usage/ModelUsagePanel.tsx +3 -2
- inspect_ai/_view/www/src/app/usage/TokenTable.module.css +4 -1
- inspect_ai/_view/www/src/app/usage/TokenTable.tsx +2 -2
- inspect_ai/_view/www/src/app/usage/UsageCard.module.css +8 -3
- inspect_ai/_view/www/src/app/usage/UsageCard.tsx +1 -35
- inspect_ai/_view/www/src/components/Card.css +0 -1
- inspect_ai/_view/www/src/constants.ts +2 -0
- inspect_ai/_view/www/src/utils/numeric.ts +17 -0
- inspect_ai/agent/_agent.py +3 -3
- inspect_ai/agent/_as_solver.py +22 -12
- inspect_ai/agent/_as_tool.py +20 -6
- inspect_ai/agent/_handoff.py +12 -1
- inspect_ai/agent/_react.py +4 -3
- inspect_ai/agent/_run.py +16 -3
- inspect_ai/agent/_types.py +9 -0
- inspect_ai/dataset/_dataset.py +6 -3
- inspect_ai/log/__init__.py +14 -0
- inspect_ai/log/_convert.py +4 -9
- inspect_ai/log/_file.py +56 -0
- inspect_ai/log/_log.py +99 -0
- inspect_ai/log/_recorders/__init__.py +2 -0
- inspect_ai/log/_recorders/buffer/database.py +12 -11
- inspect_ai/log/_recorders/buffer/filestore.py +2 -2
- inspect_ai/log/_recorders/buffer/types.py +2 -2
- inspect_ai/log/_recorders/eval.py +20 -65
- inspect_ai/log/_recorders/file.py +28 -6
- inspect_ai/log/_recorders/recorder.py +7 -0
- inspect_ai/log/_recorders/types.py +1 -23
- inspect_ai/log/_samples.py +14 -25
- inspect_ai/log/_transcript.py +84 -36
- inspect_ai/log/_tree.py +118 -0
- inspect_ai/log/_util.py +52 -0
- inspect_ai/model/__init__.py +5 -1
- inspect_ai/model/_call_tools.py +72 -44
- inspect_ai/model/_generate_config.py +14 -8
- inspect_ai/model/_model.py +66 -88
- inspect_ai/model/_model_output.py +25 -0
- inspect_ai/model/_openai.py +2 -0
- inspect_ai/model/_providers/anthropic.py +13 -23
- inspect_ai/model/_providers/hf.py +27 -1
- inspect_ai/model/_providers/openai_o1.py +8 -2
- inspect_ai/model/_providers/providers.py +18 -4
- inspect_ai/model/_providers/sglang.py +247 -0
- inspect_ai/model/_providers/vllm.py +211 -400
- inspect_ai/scorer/_choice.py +1 -2
- inspect_ai/solver/__init__.py +7 -2
- inspect_ai/solver/_basic_agent.py +3 -10
- inspect_ai/solver/_chain.py +1 -1
- inspect_ai/solver/_fork.py +1 -1
- inspect_ai/solver/_multiple_choice.py +5 -22
- inspect_ai/solver/_plan.py +2 -2
- inspect_ai/solver/_task_state.py +26 -88
- inspect_ai/solver/_transcript.py +6 -7
- inspect_ai/tool/_json_rpc_helpers.py +45 -17
- inspect_ai/tool/_mcp/_mcp.py +8 -5
- inspect_ai/tool/_mcp/_sandbox.py +8 -2
- inspect_ai/tool/_mcp/server.py +3 -1
- inspect_ai/tool/_tool_call.py +4 -1
- inspect_ai/tool/_tool_support_helpers.py +51 -12
- inspect_ai/tool/_tools/_bash_session.py +190 -68
- inspect_ai/tool/_tools/_computer/_computer.py +25 -1
- inspect_ai/tool/_tools/_execute.py +4 -1
- inspect_ai/tool/_tools/_text_editor.py +4 -3
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +10 -3
- inspect_ai/util/__init__.py +16 -0
- inspect_ai/util/_anyio.py +11 -0
- inspect_ai/util/_collect.py +50 -0
- inspect_ai/util/_limit.py +393 -0
- inspect_ai/util/_limited_conversation.py +57 -0
- inspect_ai/util/_span.py +58 -0
- inspect_ai/util/_subtask.py +27 -42
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/RECORD +120 -134
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/WHEEL +1 -1
- inspect_ai/_display/core/group.py +0 -79
- inspect_ai/solver/_limit.py +0 -39
- inspect_ai/tool/_tools/_computer/_resources/Dockerfile +0 -102
- inspect_ai/tool/_tools/_computer/_resources/README.md +0 -30
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/entrypoint.sh +0 -18
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/novnc_startup.sh +0 -20
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/x11vnc_startup.sh +0 -48
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/xfce_startup.sh +0 -13
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/xvfb_startup.sh +0 -48
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/globalStorage/state.vscdb +0 -0
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +0 -9
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-panel.xml +0 -61
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +0 -91
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Terminal.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +0 -8
- inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +0 -12
- inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +0 -78
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +0 -22
- inspect_ai/tool/_tools/_computer/_resources/tool/_logger.py +0 -22
- inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +0 -42
- inspect_ai/tool/_tools/_computer/_resources/tool/_tool_result.py +0 -33
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +0 -341
- inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +0 -141
- inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +0 -65
- inspect_ai/tool/_tools/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/tool/_tools/_computer/test_args.py +0 -151
- /inspect_ai/{tool/_tools/_computer/_resources/tool/__init__.py → _view/www/src/app/log-view/tabs/ModelsTab.module.css} +0 -0
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/vllm.py
CHANGED
@@ -1,439 +1,250 @@
-import
-import
-import gc
+import atexit
+import logging
 import os
-import
-from
-from dataclasses import dataclass
-from queue import Empty, Queue
-from threading import Thread
-from typing import Any, cast
-
-import anyio
-from typing_extensions import override
-from vllm import LLM, CompletionOutput, RequestOutput, SamplingParams  # type: ignore
-
-from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai.tool import ToolChoice, ToolInfo
-
-from .._chat_message import ChatMessage, ChatMessageAssistant
-from .._generate_config import GenerateConfig
-from .._model import ModelAPI, simple_input_messages
-from .._model_output import (
-    ChatCompletionChoice,
-    Logprob,
-    Logprobs,
-    ModelOutput,
-    ModelUsage,
-    StopReason,
-    TopLogprob,
-)
-from .util import ChatAPIHandler, chat_api_input
-
-DEFAULT_START_TOKEN = "<|im_start|>"
-DEFAULT_END_TOKEN = "<|im_end|>"
-
-HF_TOKEN = "HF_TOKEN"
-
-
-@dataclass
-class GenerateInput:
-    input: str
-    generator: Any
-    batch_size: int
-    num_top_logprobs: int | None = None
-
+from subprocess import Popen
+from typing import Any

-
-
-    output: str
-    input_tokens: int
-    output_tokens: int
-    total_tokens: int
-    stop_reason: StopReason
-    logprobs: Logprobs | None
-    time: float
+from openai import APIStatusError
+from typing_extensions import override

+from inspect_ai._util.error import PrerequisiteError, pip_dependency_error
+from inspect_ai._util.local_server import (
+    configure_devices,
+    merge_env_server_args,
+    start_local_server,
+    terminate_process,
+)
+from inspect_ai.model._chat_message import ChatMessage
+from inspect_ai.model._generate_config import GenerateConfig
+from inspect_ai.model._model_call import ModelCall
+from inspect_ai.model._model_output import ModelOutput
+from inspect_ai.tool._tool_choice import ToolChoice
+from inspect_ai.tool._tool_info import ToolInfo
+
+from .openai_compatible import OpenAICompatibleAPI
+
+# Environment variable names
+# VLLM_BASE_URL = "VLLM_BASE_URL"
+# VLLM_API_KEY = "VLLM_API_KEY"
+VLLM_DEFAULT_SERVER_ARGS = "VLLM_DEFAULT_SERVER_ARGS"
+VLLM_CONFIGURE_LOGGING = "VLLM_CONFIGURE_LOGGING"
+
+# Set up logger for this module
+logger = logging.getLogger(__name__)
+
+
+class VLLMAPI(OpenAICompatibleAPI):
+    """
+    Provider for using vLLM models.
+
+    This provider can either:
+    1. Connect to an existing vLLM server (if base_url or port is provided)
+    2. Start a new vLLM server for the specified model
+
+    Additional server_args:
+        timeout (int): Timeout for the server (default: 10 minutes)
+        host (str): Host to bind the server to (default: "0.0.0.0")
+        configure_logging (bool): Enable fine-grained vLLM logging (default: False)
+        device (str): Devices to run the server on. Can be a single device or a list of devices as used in CUDA_VISIBLE_DEVICES. If tensor_parallel_size is not provided, the server will use the number of devices as the tensor parallel size.
+
+    Environment variables:
+        VLLM_BASE_URL: Base URL for an existing vLLM server
+        VLLM_API_KEY: API key for the vLLM server
+        VLLM_DEFAULT_SERVER_ARGS: JSON string of default server args, e.g. '{"tensor_parallel_size": 4, "max_model_len": 8192}'
+        VLLM_CONFIGURE_LOGGING: Enable fine-grained vLLM logging
+    """

-class VLLMAPI(ModelAPI):
     def __init__(
         self,
         model_name: str,
         base_url: str | None = None,
+        port: int | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
-**
+        **server_args: Any,
     ) -> None:
-
-
-base_url
-
-
-
+        # Validate inputs
+        if base_url and port:
+            raise ValueError("base_url and port cannot both be provided.")
+        if port:
+            base_url = f"http://localhost:{port}/v1"
+
+        # Initialize server process and port variables
+        self.server_process: Popen[str] | None = None
+        self.port: int | None = port
+        self.server_args = merge_env_server_args(
+            VLLM_DEFAULT_SERVER_ARGS, server_args, logger
         )

-self.
-
-
-
-
-
-
-
-
-
-
-
-
-        model_path = collect_model_arg("model_path")
-        tokenizer_path = collect_model_arg("tokenizer_path")
-        self.batch_size = collect_model_arg("batch_size")
-        self.chat_template = collect_model_arg("chat_template")
-
-        # if user provides model_path, use that instead of model_name
-        if model_path:
-            model_name = model_path
-
-        # load tokenizer
-        if not tokenizer:
-            if tokenizer_path:
-                tokenizer = tokenizer_path
-            else:
-                tokenizer = model_name
-
-        # set which GPUs are available to use
-        if device is not None:
-            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(device)
-
-        # tell vllm how many GPUs to use
-        if "tensor_parallel_size" not in model_args:
-            devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")
-            num_devices = len(devices)
-            if num_devices > 1:
-                model_args["tensor_parallel_size"] = num_devices
-            else:
-                model_args["tensor_parallel_size"] = 1
-
-        # https://github.com/vllm-project/vllm/pull/6051
-        # Gemma 2 models require FlashInfer backend for softcap logits
-        if "google/gemma-2" in model_name:
-            os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
-            try:
-                import importlib
-
-                # check if flashinfer is installed
-                importlib.import_module("flashinfer")
-            except ImportError:
-                raise ImportError(
-                    "To use the 'google/gemma-2' model, you must install the 'flashinfer' package. "
-                    "See https://docs.flashinfer.ai/installation.html"
-                )
+        self.server_found = True
+        try:
+            # Try to initialize with existing server
+            super().__init__(
+                model_name=model_name,
+                base_url=base_url,
+                api_key=api_key,
+                config=config,
+                service="vLLM",
+                service_base_url=base_url,
+            )
+            logger.info(f"Using existing vLLM server at {self.base_url}")
+        except PrerequisiteError:
+            self.server_found = False

-
-
+        if not self.server_found:
+            logger.warning(
+                f"Existing vLLM server not found. Starting new server for {model_name}."
+            )

-
-
+            # Extract and handle the configure_logging parameter
+            configure_logging = self.server_args.pop("configure_logging", False)
+            os.environ[VLLM_CONFIGURE_LOGGING] = "1" if configure_logging else "0"
+
+            # Start the server
+            base_url, api_key = self._start_server(model_name, api_key=api_key)
+            logger.warning(f"vLLM server started at {base_url}")
+
+            # Initialize with new server
+            super().__init__(
+                model_name=model_name,
+                base_url=base_url,
+                api_key=api_key,
+                config=config,
+                service="vLLM",
+                service_base_url=base_url,
+            )

-
-
-
-
-
-
-
-
-
-
-
-
-
-#
-
-
-
-
-
+    def _start_server(
+        self,
+        model_path: str,
+        api_key: str | None = None,
+    ) -> tuple[str, str]:
+        """Start a new vLLM server and return the base URL and API key.
+
+        Args:
+            model_path: Path to the model to use
+            api_key: API key for the server
+        Returns:
+            tuple[str, str]: The base URL for the server and the API key
+        """
+        # Verify vllm package is installed since we're starting a server
+        try:
+            import vllm  # type: ignore  # noqa: F401
+        except ImportError:
+            raise pip_dependency_error("vLLM Server", ["vllm"])
+
+        # Handle device configuration
+        self.server_args, env_vars = configure_devices(
+            self.server_args, parallel_size_param="tensor_parallel_size"
         )
-        return cast(str, chat)

-
-
-        """Effectively the batch size."""
-        return 32
-
-    def get_sampling_params(self, config: GenerateConfig, chat: str) -> SamplingParams:
-        kwargs: dict[str, Any] = dict()
-        if config.max_tokens is not None:
-            kwargs["max_tokens"] = config.max_tokens
-        else:
-            kwargs["max_tokens"] = DEFAULT_MAX_TOKENS
-
-        if config.temperature is not None:
-            # for some reason vllm doesn't generate anything for 0 < temperature < 0.02
-            if 0 < config.temperature < 0.02:
-                config.temperature = 0.02
-            kwargs["temperature"] = config.temperature
-        if config.top_p is not None:
-            kwargs["top_p"] = config.top_p
-        if config.top_k is not None:
-            kwargs["top_k"] = config.top_k
-        # if config.min_p is not None:
-        #     kwargs["min_p"] = config.min_p
-        if config.seed is not None:
-            kwargs["seed"] = config.seed
-        elif self.seed is not None:
-            kwargs["seed"] = self.seed
-
-        if config.frequency_penalty is not None:
-            kwargs["frequency_penalty"] = config.frequency_penalty
-        if config.presence_penalty is not None:
-            kwargs["presence_penalty"] = config.presence_penalty
-
-        if config.num_choices is not None:
-            kwargs["n"] = config.num_choices
-        if config.best_of is not None:
-            kwargs["best_of"] = config.best_of
-
-        if config.logprobs is not None:
-            kwargs["logprobs"] = 0
-        if config.top_logprobs is not None:
-            kwargs["logprobs"] = config.top_logprobs
-
-        if config.stop_seqs is not None:
-            kwargs["stop"] = config.stop_seqs
-
-        # some models don't stop at <|im_end|> token
-        # perhaps there is a better solution than this (modify tokenizer?)
-        # TODO: what model needs this?
-        if chat.startswith(DEFAULT_START_TOKEN):
-            if "stop" not in kwargs:
-                kwargs["stop"] = [DEFAULT_END_TOKEN]
-            else:
-                kwargs["stop"].append(DEFAULT_END_TOKEN)
+        if not api_key:
+            api_key = "inspectai"  # Create a default API key if not provided

-
-
-            stop_token_ids=self.tokenizer.all_special_ids,  # We default to stopping at all special tokens
-            include_stop_str_in_output=False,
-        )
-        return sampling_params
+        timeout = self.server_args.pop("timeout", None)
+        host = self.server_args.pop("host", "0.0.0.0")

-
-
-        input: list[ChatMessage],
-        tools: list[ToolInfo],
-        tool_choice: ToolChoice,
-        config: GenerateConfig,
-    ) -> ModelOutput:
-        chat = self.apply_chat_template(input, tools)
+        # Build command as a list
+        cmd = ["vllm", "serve", model_path, "--host", host, "--api-key", api_key]

-
-
-
-
+        base_url, self.server_process, self.port = start_local_server(
+            cmd,
+            host=host,
+            port=None,  # find a free port
+            api_key=api_key,
+            server_type="vLLM",
+            timeout=timeout,
+            server_args=self.server_args,
+            env=env_vars,
         )

-#
-
-            GenerateInput(
-                input=chat,
-                generator=generator,
-                batch_size=config.max_connections or self.max_connections(),
-                num_top_logprobs=config.top_logprobs,
-            )
-        )
+        # Register cleanup function to run when Python exits
+        atexit.register(self._cleanup_server)

-return
-
-    def process_responses(
-        self, responses: list[GenerateOutput], tools: list[ToolInfo]
-    ) -> ModelOutput:
-        choices = [
-            ChatCompletionChoice(
-                message=ChatMessageAssistant(
-                    content=response.output, model=self.model_name, source="generate"
-                ),
-                stop_reason=response.stop_reason,
-                logprobs=response.logprobs,
-            )
-            for response in responses
-        ]
-
-        # TODO: what's the best way to calculate token usage for num_choices > 1
-        total_time = responses[0].time
-        input_tokens = responses[0].input_tokens
-        output_tokens = sum(response.output_tokens for response in responses)
-        total_tokens = input_tokens + output_tokens
-
-        return ModelOutput(
-            model=self.model_name,
-            choices=choices,
-            usage=ModelUsage(
-                input_tokens=input_tokens,
-                output_tokens=output_tokens,
-                total_tokens=total_tokens,
-            ),
-            time=total_time,
-        )
+        return base_url, api_key

+    @property
+    def server_is_running(self) -> bool:
+        """Check if the server is running."""
+        if self.server_process is None:
+            return False

-
-
-    input: GenerateInput
-    future: Future[list[GenerateOutput]]
+        # Check if process is still alive
+        return self.server_process.poll() is None

+    @override
+    def collapse_user_messages(self) -> bool:
+        return True

-
+    @override
+    def collapse_assistant_messages(self) -> bool:
+        return True

-
+    def _cleanup_server(self) -> None:
+        """Cleanup method to terminate server process when Python exits."""
+        if self.server_is_running and self.server_process is not None:
+            logger.info("Cleaning up vLLM server")
+            terminate_process(self.server_process)
+            self.server_process, self.port = None, None

+    async def aclose(self) -> None:
+        """Close the client and terminate the server if we started it."""
+        logger.info("Closing vLLM server")

-
-
-    global batch_thread
-    if batch_thread is None:
-        batch_thread = Thread(target=process_batches, daemon=True)
-        batch_thread.start()
+        # Close the OpenAI client
+        await super().aclose()

-
-    future = Future[list[GenerateOutput]]()
-    batch_queue.put(_QueueItem(input=input, future=future))
+        self.close()

-
-
-
-            return future.result(timeout=0.01)
-        except concurrent.futures.TimeoutError:
-            pass
-        await anyio.sleep(1)
-
-
-def string_to_bytes(string: str) -> list[int]:
-    return list(map(ord, string))
-
-
-def extract_logprobs(
-    completion: CompletionOutput, num_top_logprobs: int | None
-) -> Logprobs | None:
-    if completion.logprobs is None or not completion.logprobs:
-        return None
-
-    # if config.logprobs = True, we want to get the selected tokens logprob
-    # but if config.top_logprobs is not set, we don't want to return the top logprobs
-    if num_top_logprobs is None:
-        num_top_logprobs = 0
-
-    logprobs = []
-    for token_id, logprob in zip(completion.token_ids, completion.logprobs):
-        top_logprobs = [
-            TopLogprob(
-                token=cast(str, token.decoded_token),
-                logprob=token.logprob,
-                bytes=string_to_bytes(cast(str, token.decoded_token)),
-            )
-            # exclude the chosen token if it's not in the top logprobs
-            for token in logprob.values()
-            if cast(int, token.rank) - 1 < num_top_logprobs
-        ]
-        selected_token = logprob[token_id]
-        logprobs.append(
-            Logprob(
-                token=cast(str, selected_token.decoded_token),
-                logprob=selected_token.logprob,
-                bytes=string_to_bytes(cast(str, selected_token.decoded_token)),
-                top_logprobs=top_logprobs,
-            )
-        )
+    def close(self) -> None:
+        """
+        Terminate the server if we started it.

-
-
-
-def get_stop_reason(finish_reason: str | None) -> StopReason:
-    if finish_reason == "stop":
-        return "stop"
-    elif finish_reason == "length":
-        return "max_tokens"
-    elif finish_reason == "abort":
-        return "unknown"
-    else:
-        return "unknown"
-
-
-def post_process_output(
-    output: RequestOutput, i: int, num_top_logprobs: int | None, total_time: float
-) -> GenerateOutput:
-    completion = output.outputs[i]
-    output_text: str = completion.text
-
-    # # remove end token if it's there (byproduct of default chat template)
-    # TODO: Remove
-    # if output_text.endswith(DEFAULT_END_TOKEN):
-    #     output_text = output_text[:len(DEFAULT_END_TOKEN)]
-
-    input_tokens = len(output.prompt_token_ids)
-    output_tokens = len(completion.token_ids)
-    total_tokens = input_tokens + output_tokens
-
-    return GenerateOutput(
-        output=output_text,
-        input_tokens=input_tokens,
-        output_tokens=output_tokens,
-        total_tokens=total_tokens,
-        stop_reason=get_stop_reason(completion.finish_reason),
-        logprobs=extract_logprobs(completion, num_top_logprobs),
-        time=total_time,
-    )
-
-
-def post_process_outputs(
-    output: RequestOutput, num_top_logprobs: int | None, total_time: float
-) -> list[GenerateOutput]:
-    return [
-        post_process_output(output, i, num_top_logprobs, total_time)
-        for i in range(len(output.outputs))
-    ]
-
-
-def process_batches() -> None:
-    while True:
-        # drain the queue (wait until no new messages have shown up for 2 seconds)
-        inputs: list[tuple[GenerateInput, Future[list[GenerateOutput]]]] = []
-        while True:
-            try:
-                input = batch_queue.get(
-                    timeout=2
-                )  # wait 2 seconds max TODO: what's optimal wait time?
-                inputs.append((input.input, input.future))
-                if len(inputs) >= input.input.batch_size:
-                    # max batch size reached
-                    break
-            except Empty:
-                # we have exhausted the queue
-                break
-
-        # see if we have any work to do
-        if len(inputs) == 0:
-            continue
+        Note that this does not close the OpenAI client as we are not in an async context.
+        """
+        self._cleanup_server()

-
-
-        first_input = inputs[0][0]
-        generator = first_input.generator
-        num_top_logprobs = first_input.num_top_logprobs
+        # Deregister the atexit handler since we've manually cleaned up
+        atexit.unregister(self._cleanup_server)

-
-
+    async def generate(
+        self,
+        input: list[ChatMessage],
+        tools: list[ToolInfo],
+        tool_choice: ToolChoice,
+        config: GenerateConfig,
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+        # check if last message is an assistant message, in this case we want to
+        # continue the final message instead of generating a new one
+        if input[-1].role == "assistant":
+            # Create a copy of the config to avoid modifying the original
+            config = config.model_copy()
+
+            # Set these parameters in extra_body
+            if config.extra_body is None:
+                config.extra_body = {}
+
+            # Only set these values if they're not already present in extra_body
+            if (
+                "add_generation_prompt" not in config.extra_body
+                and "continue_final_message" not in config.extra_body
+            ):
+                config.extra_body["add_generation_prompt"] = False
+                config.extra_body["continue_final_message"] = True
+
+        return await super().generate(input, tools, tool_choice, config)

-
-
-
+    @override
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
+        if ex.status_code == 400:
+            # Extract message safely
+            if isinstance(ex.body, dict) and "message" in ex.body:
+                content = str(ex.body.get("message"))
+            else:
+                content = ex.message

-
-
+            if "maximum context length" in content:
+                return ModelOutput.from_content(
+                    self.model_name, content=content, stop_reason="model_length"
                 )
-
-        except Exception as e:
-            for _, future in inputs:
-                future.set_exception(e)
+        return ex
inspect_ai/scorer/_choice.py
CHANGED
inspect_ai/solver/__init__.py
CHANGED
@@ -6,7 +6,6 @@ from ._chain import chain
 from ._critique import self_critique
 from ._fork import fork
 from ._human_agent import human_agent
-from ._limit import SampleLimitExceededError
 from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
 from ._plan import Plan, plan
 from ._prompt import (
@@ -45,13 +44,13 @@ __all__ = [
     "TaskState",
     "Generate",
     "MultipleChoiceTemplate",
-    "SampleLimitExceededError",
 ]


 _TOOL_MODULE_VERSION_3_18 = "0.3.18"
 _TOOL_MODULE_VERSION_3_19 = "0.3.19"
 _SUBTASKS_MODULE_VERSION = "0.3.26"
+_SAMPLE_LIMIT_VERSION = "0.3.91"
 _REMOVED_IN = "0.4"
 relocated_module_attribute(
     "Tool", "inspect_ai.tool.Tool", _TOOL_MODULE_VERSION_3_18, _REMOVED_IN
@@ -137,3 +136,9 @@ relocated_module_attribute(
     _SUBTASKS_MODULE_VERSION,
     _REMOVED_IN,
 )
+relocated_module_attribute(
+    "SampleLimitExceededError",
+    "inspect_ai.util.LimitExceededError",
+    _SAMPLE_LIMIT_VERSION,
+    _REMOVED_IN,
+)
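The relocation above means code that previously imported `SampleLimitExceededError` from `inspect_ai.solver` should now catch `inspect_ai.util.LimitExceededError` (the old name keeps working via the deprecation shim until 0.4). A minimal sketch of the new import path follows; the handler function and its message formatting are illustrative assumptions, not taken from the package.

```python
# Minimal sketch of the new exception location shown in the diff above;
# the wrapper function and message handling are illustrative assumptions.
from inspect_ai.util import LimitExceededError


async def run_step(do_work) -> str:
    try:
        return await do_work()
    except LimitExceededError as ex:
        # Sample/agent limits now surface here rather than as
        # inspect_ai.solver.SampleLimitExceededError.
        return f"stopped early: {ex}"
```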