inspect-ai 0.3.92__py3-none-any.whl → 0.3.93__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. inspect_ai/_cli/eval.py +27 -0
  2. inspect_ai/_eval/eval.py +19 -2
  3. inspect_ai/_eval/evalset.py +4 -1
  4. inspect_ai/_eval/run.py +41 -0
  5. inspect_ai/_eval/task/generate.py +38 -44
  6. inspect_ai/_eval/task/log.py +26 -28
  7. inspect_ai/_eval/task/run.py +13 -20
  8. inspect_ai/_util/local_server.py +368 -0
  9. inspect_ai/_util/working.py +10 -4
  10. inspect_ai/_view/www/dist/assets/index.css +159 -146
  11. inspect_ai/_view/www/dist/assets/index.js +1020 -1061
  12. inspect_ai/_view/www/log-schema.json +4 -3
  13. inspect_ai/_view/www/package.json +1 -1
  14. inspect_ai/_view/www/src/@types/log.d.ts +3 -2
  15. inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +2 -2
  16. inspect_ai/_view/www/src/app/content/MetaDataView.module.css +1 -1
  17. inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +1 -1
  18. inspect_ai/_view/www/src/app/content/RenderedContent.tsx +1 -1
  19. inspect_ai/_view/www/src/app/log-view/LogView.tsx +11 -0
  20. inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +2 -9
  21. inspect_ai/_view/www/src/app/log-view/tabs/ModelsTab.tsx +51 -0
  22. inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.module.css +6 -0
  23. inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.tsx +143 -0
  24. inspect_ai/_view/www/src/app/plan/ModelCard.tsx +1 -2
  25. inspect_ai/_view/www/src/app/plan/PlanCard.tsx +29 -7
  26. inspect_ai/_view/www/src/app/plan/PlanDetailView.module.css +1 -1
  27. inspect_ai/_view/www/src/app/plan/PlanDetailView.tsx +1 -198
  28. inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -1
  29. inspect_ai/_view/www/src/app/usage/ModelUsagePanel.tsx +3 -2
  30. inspect_ai/_view/www/src/app/usage/TokenTable.module.css +4 -1
  31. inspect_ai/_view/www/src/app/usage/TokenTable.tsx +2 -2
  32. inspect_ai/_view/www/src/app/usage/UsageCard.module.css +8 -3
  33. inspect_ai/_view/www/src/app/usage/UsageCard.tsx +1 -35
  34. inspect_ai/_view/www/src/components/Card.css +0 -1
  35. inspect_ai/_view/www/src/constants.ts +2 -0
  36. inspect_ai/_view/www/src/utils/numeric.ts +17 -0
  37. inspect_ai/agent/_agent.py +3 -3
  38. inspect_ai/agent/_as_solver.py +20 -12
  39. inspect_ai/agent/_as_tool.py +15 -3
  40. inspect_ai/agent/_handoff.py +8 -1
  41. inspect_ai/agent/_run.py +11 -3
  42. inspect_ai/log/__init__.py +4 -0
  43. inspect_ai/log/_file.py +56 -0
  44. inspect_ai/log/_log.py +99 -0
  45. inspect_ai/log/_recorders/__init__.py +2 -0
  46. inspect_ai/log/_recorders/buffer/database.py +12 -11
  47. inspect_ai/log/_recorders/buffer/filestore.py +2 -2
  48. inspect_ai/log/_recorders/buffer/types.py +2 -2
  49. inspect_ai/log/_recorders/eval.py +20 -65
  50. inspect_ai/log/_recorders/file.py +28 -6
  51. inspect_ai/log/_recorders/recorder.py +7 -0
  52. inspect_ai/log/_recorders/types.py +1 -23
  53. inspect_ai/log/_samples.py +0 -8
  54. inspect_ai/log/_transcript.py +7 -1
  55. inspect_ai/log/_util.py +52 -0
  56. inspect_ai/model/__init__.py +5 -1
  57. inspect_ai/model/_call_tools.py +32 -12
  58. inspect_ai/model/_generate_config.py +14 -8
  59. inspect_ai/model/_model.py +21 -48
  60. inspect_ai/model/_model_output.py +25 -0
  61. inspect_ai/model/_openai.py +2 -0
  62. inspect_ai/model/_providers/anthropic.py +13 -23
  63. inspect_ai/model/_providers/openai_o1.py +8 -2
  64. inspect_ai/model/_providers/providers.py +18 -4
  65. inspect_ai/model/_providers/sglang.py +241 -0
  66. inspect_ai/model/_providers/vllm.py +207 -400
  67. inspect_ai/solver/__init__.py +7 -2
  68. inspect_ai/solver/_basic_agent.py +3 -10
  69. inspect_ai/solver/_task_state.py +26 -88
  70. inspect_ai/tool/_json_rpc_helpers.py +45 -17
  71. inspect_ai/tool/_mcp/_mcp.py +2 -0
  72. inspect_ai/tool/_mcp/_sandbox.py +8 -2
  73. inspect_ai/tool/_mcp/server.py +3 -1
  74. inspect_ai/tool/_tool_call.py +4 -1
  75. inspect_ai/tool/_tool_support_helpers.py +51 -12
  76. inspect_ai/tool/_tools/_bash_session.py +190 -68
  77. inspect_ai/tool/_tools/_computer/_computer.py +25 -1
  78. inspect_ai/tool/_tools/_text_editor.py +4 -3
  79. inspect_ai/tool/_tools/_web_browser/_web_browser.py +10 -3
  80. inspect_ai/util/__init__.py +12 -0
  81. inspect_ai/util/_limit.py +393 -0
  82. inspect_ai/util/_limited_conversation.py +57 -0
  83. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.93.dist-info}/METADATA +1 -1
  84. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.93.dist-info}/RECORD +89 -108
  85. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.93.dist-info}/WHEEL +1 -1
  86. inspect_ai/solver/_limit.py +0 -39
  87. inspect_ai/tool/_tools/_computer/_resources/Dockerfile +0 -102
  88. inspect_ai/tool/_tools/_computer/_resources/README.md +0 -30
  89. inspect_ai/tool/_tools/_computer/_resources/entrypoint/entrypoint.sh +0 -18
  90. inspect_ai/tool/_tools/_computer/_resources/entrypoint/novnc_startup.sh +0 -20
  91. inspect_ai/tool/_tools/_computer/_resources/entrypoint/x11vnc_startup.sh +0 -48
  92. inspect_ai/tool/_tools/_computer/_resources/entrypoint/xfce_startup.sh +0 -13
  93. inspect_ai/tool/_tools/_computer/_resources/entrypoint/xvfb_startup.sh +0 -48
  94. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/globalStorage/state.vscdb +0 -0
  95. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +0 -9
  96. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-panel.xml +0 -61
  97. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +0 -10
  98. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +0 -91
  99. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +0 -10
  100. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Terminal.desktop +0 -10
  101. inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +0 -10
  102. inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +0 -8
  103. inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +0 -12
  104. inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +0 -78
  105. inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +0 -22
  106. inspect_ai/tool/_tools/_computer/_resources/tool/_logger.py +0 -22
  107. inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +0 -42
  108. inspect_ai/tool/_tools/_computer/_resources/tool/_tool_result.py +0 -33
  109. inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +0 -341
  110. inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +0 -141
  111. inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +0 -65
  112. inspect_ai/tool/_tools/_computer/_resources/tool/requirements.txt +0 -0
  113. inspect_ai/tool/_tools/_computer/test_args.py +0 -151
  114. /inspect_ai/{tool/_tools/_computer/_resources/tool/__init__.py → _view/www/src/app/log-view/tabs/ModelsTab.module.css} +0 -0
  115. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.93.dist-info}/entry_points.txt +0 -0
  116. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.93.dist-info}/licenses/LICENSE +0 -0
  117. {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.93.dist-info}/top_level.txt +0 -0
@@ -1,439 +1,246 @@
-import concurrent.futures
-import functools
-import gc
+import atexit
+import logging
 import os
-import time
-from concurrent.futures import Future
-from dataclasses import dataclass
-from queue import Empty, Queue
-from threading import Thread
-from typing import Any, cast
-
-import anyio
-from typing_extensions import override
-from vllm import LLM, CompletionOutput, RequestOutput, SamplingParams  # type: ignore
-
-from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai.tool import ToolChoice, ToolInfo
-
-from .._chat_message import ChatMessage, ChatMessageAssistant
-from .._generate_config import GenerateConfig
-from .._model import ModelAPI, simple_input_messages
-from .._model_output import (
-    ChatCompletionChoice,
-    Logprob,
-    Logprobs,
-    ModelOutput,
-    ModelUsage,
-    StopReason,
-    TopLogprob,
-)
-from .util import ChatAPIHandler, chat_api_input
-
-DEFAULT_START_TOKEN = "<|im_start|>"
-DEFAULT_END_TOKEN = "<|im_end|>"
-
-HF_TOKEN = "HF_TOKEN"
-
-
-@dataclass
-class GenerateInput:
-    input: str
-    generator: Any
-    batch_size: int
-    num_top_logprobs: int | None = None
-
+from subprocess import Popen
+from typing import Any

-@dataclass
-class GenerateOutput:
-    output: str
-    input_tokens: int
-    output_tokens: int
-    total_tokens: int
-    stop_reason: StopReason
-    logprobs: Logprobs | None
-    time: float
+from openai import APIStatusError
+from typing_extensions import override

+from inspect_ai._util.error import PrerequisiteError, pip_dependency_error
+from inspect_ai._util.local_server import (
+    configure_devices,
+    merge_env_server_args,
+    start_local_server,
+    terminate_process,
+)
+from inspect_ai.model._chat_message import ChatMessage
+from inspect_ai.model._generate_config import GenerateConfig
+from inspect_ai.model._model_call import ModelCall
+from inspect_ai.model._model_output import ModelOutput
+from inspect_ai.tool._tool_choice import ToolChoice
+from inspect_ai.tool._tool_info import ToolInfo
+
+from .openai_compatible import OpenAICompatibleAPI
+
+# Environment variable names
+# VLLM_BASE_URL = "VLLM_BASE_URL"
+# VLLM_API_KEY = "VLLM_API_KEY"
+VLLM_DEFAULT_SERVER_ARGS = "VLLM_DEFAULT_SERVER_ARGS"
+VLLM_CONFIGURE_LOGGING = "VLLM_CONFIGURE_LOGGING"
+
+# Set up logger for this module
+logger = logging.getLogger(__name__)
+
+
+class VLLMAPI(OpenAICompatibleAPI):
+    """
+    Provider for using vLLM models.
+
+    This provider can either:
+    1. Connect to an existing vLLM server (if base_url or port is provided)
+    2. Start a new vLLM server for the specified model
+
+    Additional server_args:
+        timeout (int): Timeout for the server (default: 10 minutes)
+        host (str): Host to bind the server to (default: "0.0.0.0")
+        configure_logging (bool): Enable fine-grained vLLM logging (default: False)
+        device (str): Devices to run the server on. Can be a single device or a list of devices as used in CUDA_VISIBLE_DEVICES. If tensor_parallel_size is not provided, the server will use the number of devices as the tensor parallel size.
+
+    Environment variables:
+        VLLM_BASE_URL: Base URL for an existing vLLM server
+        VLLM_API_KEY: API key for the vLLM server
+        VLLM_DEFAULT_SERVER_ARGS: JSON string of default server args, e.g. '{"tensor_parallel_size": 4, "max_model_len": 8192}'
+        VLLM_CONFIGURE_LOGGING: Enable fine-grained vLLM logging
+    """

-class VLLMAPI(ModelAPI):
     def __init__(
         self,
         model_name: str,
         base_url: str | None = None,
+        port: int | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
-        **model_args: Any,
+        **server_args: Any,
     ) -> None:
-        super().__init__(
-            model_name=model_name,
-            base_url=base_url,
-            api_key=api_key,
-            api_key_vars=[HF_TOKEN],
-            config=config,
+        # Validate inputs
+        if base_url and port:
+            raise ValueError("base_url and port cannot both be provided.")
+        if port:
+            base_url = f"http://localhost:{port}/v1"
+
+        # Initialize server process and port variables
+        self.server_process: Popen[str] | None = None
+        self.port: int | None = port
+        self.server_args = merge_env_server_args(
+            VLLM_DEFAULT_SERVER_ARGS, server_args, logger
         )

-        self.seed = None
-        if config.seed is not None:
-            self.seed = config.seed
-
-        # collect known model_args (then delete them so we can pass the rest on)
-        def collect_model_arg(name: str) -> Any | None:
-            nonlocal model_args
-            value = model_args.get(name, None)
-            if value is not None:
-                model_args.pop(name)
-            return value
-
-        device = collect_model_arg("device")
-        tokenizer = collect_model_arg("tokenizer")
-        model_path = collect_model_arg("model_path")
-        tokenizer_path = collect_model_arg("tokenizer_path")
-        self.batch_size = collect_model_arg("batch_size")
-        self.chat_template = collect_model_arg("chat_template")
-
-        # if user provides model_path, use that instead of model_name
-        if model_path:
-            model_name = model_path
-
-        # load tokenizer
-        if not tokenizer:
-            if tokenizer_path:
-                tokenizer = tokenizer_path
-            else:
-                tokenizer = model_name
-
-        # set which GPUs are available to use
-        if device is not None:
-            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(device)
-
-        # tell vllm how many GPUs to use
-        if "tensor_parallel_size" not in model_args:
-            devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")
-            num_devices = len(devices)
-            if num_devices > 1:
-                model_args["tensor_parallel_size"] = num_devices
-            else:
-                model_args["tensor_parallel_size"] = 1
-
-        # https://github.com/vllm-project/vllm/pull/6051
-        # Gemma 2 models require FlashInfer backend for softcap logits
-        if "google/gemma-2" in model_name:
-            os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
-            try:
-                import importlib
-
-                # check if flashinfer is installed
-                importlib.import_module("flashinfer")
-            except ImportError:
-                raise ImportError(
-                    "To use the 'google/gemma-2' model, you must install the 'flashinfer' package. "
-                    "See https://docs.flashinfer.ai/installation.html"
-                )
+        try:
+            # Try to initialize with existing server
+            super().__init__(
+                model_name=model_name,
+                base_url=base_url,
+                api_key=api_key,
+                config=config,
+                service="vLLM",
+                service_base_url=base_url,
+            )
+            logger.info(f"Using existing vLLM server at {self.base_url}")
+        except PrerequisiteError:
+            # No existing server found, start a new one
+            logger.warning(
+                f"Existing vLLM server not found. Starting new server for {model_name}."
+            )

-        # load model
-        self.model = LLM(model_name, tokenizer=tokenizer, **model_args)
+            # Extract and handle the configure_logging parameter
+            configure_logging = self.server_args.pop("configure_logging", False)
+            os.environ[VLLM_CONFIGURE_LOGGING] = "1" if configure_logging else "0"
+
+            # Start the server
+            base_url, api_key = self._start_server(model_name, api_key=api_key)
+            logger.warning(f"vLLM server started at {base_url}")
+
+            # Initialize with new server
+            super().__init__(
+                model_name=model_name,
+                base_url=base_url,
+                api_key=api_key,
+                config=config,
+                service="vLLM",
+                service_base_url=base_url,
+            )

-        # we get the tokenizer so we can use it to apply the model's chat template later
-        self.tokenizer = self.model.get_tokenizer()
+    def _start_server(
+        self,
+        model_path: str,
+        api_key: str | None = None,
+    ) -> tuple[str, str]:
+        """Start a new vLLM server and return the base URL and API key.
+
+        Args:
+            model_path: Path to the model to use
+            api_key: API key for the server
+        Returns:
+            tuple[str, str]: The base URL for the server and the API key
+        """
+        # Verify vllm package is installed since we're starting a server
+        try:
+            import vllm  # type: ignore  # noqa: F401
+        except ImportError:
+            raise pip_dependency_error("vLLM Server", ["vllm"])

-    @override
-    def close(self) -> None:
-        self.tokenizer = None
-        self.model = None
-        gc.collect()
-
-    def apply_chat_template(
-        self, messages: list[ChatMessage], tools: list[ToolInfo]
-    ) -> str:
-        # handle system message and consecutive user messages
-        messages = simple_input_messages(messages)
-        # convert to chat template input format
-        chat_messages = chat_api_input(messages, tools, ChatAPIHandler(self.model_name))
-        # apply chat template
-        chat = self.tokenizer.apply_chat_template(
-            chat_messages,
-            add_generation_prompt=True,
-            tokenize=False,
-            chat_template=self.chat_template,
+        # Handle device configuration
+        self.server_args = configure_devices(
+            self.server_args, parallel_size_param="tensor_parallel_size"
         )
-        return cast(str, chat)

-    @override
-    def max_connections(self) -> int:
-        """Effectively the batch size."""
-        return 32
-
-    def get_sampling_params(self, config: GenerateConfig, chat: str) -> SamplingParams:
-        kwargs: dict[str, Any] = dict()
-        if config.max_tokens is not None:
-            kwargs["max_tokens"] = config.max_tokens
-        else:
-            kwargs["max_tokens"] = DEFAULT_MAX_TOKENS
-
-        if config.temperature is not None:
-            # for some reason vllm doesn't generate anything for 0 < temperature < 0.02
-            if 0 < config.temperature < 0.02:
-                config.temperature = 0.02
-            kwargs["temperature"] = config.temperature
-        if config.top_p is not None:
-            kwargs["top_p"] = config.top_p
-        if config.top_k is not None:
-            kwargs["top_k"] = config.top_k
-        # if config.min_p is not None:
-        #     kwargs["min_p"] = config.min_p
-        if config.seed is not None:
-            kwargs["seed"] = config.seed
-        elif self.seed is not None:
-            kwargs["seed"] = self.seed
-
-        if config.frequency_penalty is not None:
-            kwargs["frequency_penalty"] = config.frequency_penalty
-        if config.presence_penalty is not None:
-            kwargs["presence_penalty"] = config.presence_penalty
-
-        if config.num_choices is not None:
-            kwargs["n"] = config.num_choices
-        if config.best_of is not None:
-            kwargs["best_of"] = config.best_of
-
-        if config.logprobs is not None:
-            kwargs["logprobs"] = 0
-        if config.top_logprobs is not None:
-            kwargs["logprobs"] = config.top_logprobs
-
-        if config.stop_seqs is not None:
-            kwargs["stop"] = config.stop_seqs
-
-        # some models don't stop at <|im_end|> token
-        # perhaps there is a better solution than this (modify tokenizer?)
-        # TODO: what model needs this?
-        if chat.startswith(DEFAULT_START_TOKEN):
-            if "stop" not in kwargs:
-                kwargs["stop"] = [DEFAULT_END_TOKEN]
-            else:
-                kwargs["stop"].append(DEFAULT_END_TOKEN)
+        if not api_key:
+            api_key = "inspectai"  # Create a default API key if not provided

-        sampling_params = SamplingParams(
-            **kwargs,
-            stop_token_ids=self.tokenizer.all_special_ids,  # We default to stopping at all special tokens
-            include_stop_str_in_output=False,
-        )
-        return sampling_params
+        timeout = self.server_args.pop("timeout", None)
+        host = self.server_args.pop("host", "0.0.0.0")

-    async def generate(
-        self,
-        input: list[ChatMessage],
-        tools: list[ToolInfo],
-        tool_choice: ToolChoice,
-        config: GenerateConfig,
-    ) -> ModelOutput:
-        chat = self.apply_chat_template(input, tools)
+        # Build command as a list
+        cmd = ["vllm", "serve", model_path, "--host", host, "--api-key", api_key]

-        # prepare generator
-        sampling_params = self.get_sampling_params(config, chat)
-        generator = functools.partial(
-            self.model.generate, sampling_params=sampling_params, use_tqdm=False
+        base_url, self.server_process, self.port = start_local_server(
+            cmd,
+            host=host,
+            port=None,  # find a free port
+            api_key=api_key,
+            server_type="vLLM",
+            timeout=timeout,
+            server_args=self.server_args,
         )

-        # generate
-        responses = await batched_generate(
-            GenerateInput(
-                input=chat,
-                generator=generator,
-                batch_size=config.max_connections or self.max_connections(),
-                num_top_logprobs=config.top_logprobs,
-            )
-        )
+        # Register cleanup function to run when Python exits
+        atexit.register(self._cleanup_server)

-        return self.process_responses(responses, tools)
-
-    def process_responses(
-        self, responses: list[GenerateOutput], tools: list[ToolInfo]
-    ) -> ModelOutput:
-        choices = [
-            ChatCompletionChoice(
-                message=ChatMessageAssistant(
-                    content=response.output, model=self.model_name, source="generate"
-                ),
-                stop_reason=response.stop_reason,
-                logprobs=response.logprobs,
-            )
-            for response in responses
-        ]
-
-        # TODO: what's the best way to calculate token usage for num_choices > 1
-        total_time = responses[0].time
-        input_tokens = responses[0].input_tokens
-        output_tokens = sum(response.output_tokens for response in responses)
-        total_tokens = input_tokens + output_tokens
-
-        return ModelOutput(
-            model=self.model_name,
-            choices=choices,
-            usage=ModelUsage(
-                input_tokens=input_tokens,
-                output_tokens=output_tokens,
-                total_tokens=total_tokens,
-            ),
-            time=total_time,
-        )
+        return base_url, api_key

+    @property
+    def server_is_running(self) -> bool:
+        """Check if the server is running."""
+        if self.server_process is None:
+            return False

-@dataclass
-class _QueueItem:
-    input: GenerateInput
-    future: Future[list[GenerateOutput]]
+        # Check if process is still alive
+        return self.server_process.poll() is None

+    @override
+    def collapse_user_messages(self) -> bool:
+        return True

-batch_thread: Thread | None = None
+    @override
+    def collapse_assistant_messages(self) -> bool:
+        return True

-batch_queue: "Queue[_QueueItem]" = Queue()
+    def _cleanup_server(self) -> None:
+        """Cleanup method to terminate server process when Python exits."""
+        if self.server_is_running and self.server_process is not None:
+            logger.info("Cleaning up vLLM server")
+            terminate_process(self.server_process)
+            self.server_process, self.port = None, None

+    async def aclose(self) -> None:
+        """Close the client and terminate the server if we started it."""
+        logger.info("Closing vLLM server")

-async def batched_generate(input: GenerateInput) -> list[GenerateOutput]:
-    # start the background thread if necessary
-    global batch_thread
-    if batch_thread is None:
-        batch_thread = Thread(target=process_batches, daemon=True)
-        batch_thread.start()
+        # Close the OpenAI client
+        await super().aclose()

-    # enqueue the job
-    future = Future[list[GenerateOutput]]()
-    batch_queue.put(_QueueItem(input=input, future=future))
+        self.close()

-    # await the future
-    while True:
-        try:
-            return future.result(timeout=0.01)
-        except concurrent.futures.TimeoutError:
-            pass
-        await anyio.sleep(1)
-
-
-def string_to_bytes(string: str) -> list[int]:
-    return list(map(ord, string))
-
-
-def extract_logprobs(
-    completion: CompletionOutput, num_top_logprobs: int | None
-) -> Logprobs | None:
-    if completion.logprobs is None or not completion.logprobs:
-        return None
-
-    # if config.logprobs = True, we want to get the selected tokens logprob
-    # but if config.top_logprobs is not set, we don't want to return the top logprobs
-    if num_top_logprobs is None:
-        num_top_logprobs = 0
-
-    logprobs = []
-    for token_id, logprob in zip(completion.token_ids, completion.logprobs):
-        top_logprobs = [
-            TopLogprob(
-                token=cast(str, token.decoded_token),
-                logprob=token.logprob,
-                bytes=string_to_bytes(cast(str, token.decoded_token)),
-            )
-            # exclude the chosen token if it's not in the top logprobs
-            for token in logprob.values()
-            if cast(int, token.rank) - 1 < num_top_logprobs
-        ]
-        selected_token = logprob[token_id]
-        logprobs.append(
-            Logprob(
-                token=cast(str, selected_token.decoded_token),
-                logprob=selected_token.logprob,
-                bytes=string_to_bytes(cast(str, selected_token.decoded_token)),
-                top_logprobs=top_logprobs,
-            )
-        )
+    def close(self) -> None:
+        """
+        Terminate the server if we started it.

-    return Logprobs(content=logprobs)
-
-
-def get_stop_reason(finish_reason: str | None) -> StopReason:
-    if finish_reason == "stop":
-        return "stop"
-    elif finish_reason == "length":
-        return "max_tokens"
-    elif finish_reason == "abort":
-        return "unknown"
-    else:
-        return "unknown"
-
-
-def post_process_output(
-    output: RequestOutput, i: int, num_top_logprobs: int | None, total_time: float
-) -> GenerateOutput:
-    completion = output.outputs[i]
-    output_text: str = completion.text
-
-    # # remove end token if it's there (byproduct of default chat template)
-    # TODO: Remove
-    # if output_text.endswith(DEFAULT_END_TOKEN):
-    #     output_text = output_text[:len(DEFAULT_END_TOKEN)]
-
-    input_tokens = len(output.prompt_token_ids)
-    output_tokens = len(completion.token_ids)
-    total_tokens = input_tokens + output_tokens
-
-    return GenerateOutput(
-        output=output_text,
-        input_tokens=input_tokens,
-        output_tokens=output_tokens,
-        total_tokens=total_tokens,
-        stop_reason=get_stop_reason(completion.finish_reason),
-        logprobs=extract_logprobs(completion, num_top_logprobs),
-        time=total_time,
-    )
-
-
-def post_process_outputs(
-    output: RequestOutput, num_top_logprobs: int | None, total_time: float
-) -> list[GenerateOutput]:
-    return [
-        post_process_output(output, i, num_top_logprobs, total_time)
-        for i in range(len(output.outputs))
-    ]
-
-
-def process_batches() -> None:
-    while True:
-        # drain the queue (wait until no new messages have shown up for 2 seconds)
-        inputs: list[tuple[GenerateInput, Future[list[GenerateOutput]]]] = []
-        while True:
-            try:
-                input = batch_queue.get(
-                    timeout=2
-                )  # wait 2 seconds max TODO: what's optimal wait time?
-                inputs.append((input.input, input.future))
-                if len(inputs) >= input.input.batch_size:
-                    # max batch size reached
-                    break
-            except Empty:
-                # we have exhausted the queue
-                break
-
-        # see if we have any work to do
-        if len(inputs) == 0:
-            continue
+        Note that this does not close the OpenAI client as we are not in an async context.
+        """
+        self._cleanup_server()

-        try:
-            start_time = time.monotonic()
-            first_input = inputs[0][0]
-            generator = first_input.generator
-            num_top_logprobs = first_input.num_top_logprobs
+        # Deregister the atexit handler since we've manually cleaned up
+        atexit.unregister(self._cleanup_server)

-            # generate
-            outputs = generator([input[0].input for input in inputs])
+    async def generate(
+        self,
+        input: list[ChatMessage],
+        tools: list[ToolInfo],
+        tool_choice: ToolChoice,
+        config: GenerateConfig,
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+        # check if last message is an assistant message, in this case we want to
+        # continue the final message instead of generating a new one
+        if input[-1].role == "assistant":
+            # Create a copy of the config to avoid modifying the original
+            config = config.model_copy()
+
+            # Set these parameters in extra_body
+            if config.extra_body is None:
+                config.extra_body = {}
+
+            # Only set these values if they're not already present in extra_body
+            if (
+                "add_generation_prompt" not in config.extra_body
+                and "continue_final_message" not in config.extra_body
+            ):
+                config.extra_body["add_generation_prompt"] = False
+                config.extra_body["continue_final_message"] = True
+
+        return await super().generate(input, tools, tool_choice, config)

-            total_time = time.monotonic() - start_time
-            for i, output in enumerate(outputs):
-                future = inputs[i][1]
+    @override
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
+        if ex.status_code == 400:
+            # Extract message safely
+            if isinstance(ex.body, dict) and "message" in ex.body:
+                content = str(ex.body.get("message"))
+            else:
+                content = ex.message

-                future.set_result(
-                    post_process_outputs(output, num_top_logprobs, total_time),
+            if "maximum context length" in content:
+                return ModelOutput.from_content(
+                    self.model_name, content=content, stop_reason="model_length"
                 )
-
-        except Exception as e:
-            for _, future in inputs:
-                future.set_exception(e)
+        return ex
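
The rewritten provider above subclasses OpenAICompatibleAPI: it first tries to attach to an existing vLLM server and otherwise launches `vllm serve` itself, registering an atexit handler to terminate the process it started. A minimal usage sketch (the model name and port are illustrative, and this assumes the standard inspect_ai get_model API with extra keyword arguments passed through to the provider as server_args):

    from inspect_ai.model import GenerateConfig, get_model

    # Attach to a server already listening on localhost:8000
    # (roughly equivalent to exporting VLLM_BASE_URL=http://localhost:8000/v1).
    model = get_model(
        "vllm/meta-llama/Llama-3.1-8B-Instruct",
        port=8000,
        config=GenerateConfig(max_tokens=512),
    )

    # With no base_url/port, the provider starts its own `vllm serve` process;
    # extra keyword args become server args, e.g.:
    # get_model("vllm/meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=2)
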
@@ -6,7 +6,6 @@ from ._chain import chain
 from ._critique import self_critique
 from ._fork import fork
 from ._human_agent import human_agent
-from ._limit import SampleLimitExceededError
 from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
 from ._plan import Plan, plan
 from ._prompt import (
@@ -45,13 +44,13 @@ __all__ = [
     "TaskState",
     "Generate",
     "MultipleChoiceTemplate",
-    "SampleLimitExceededError",
 ]


 _TOOL_MODULE_VERSION_3_18 = "0.3.18"
 _TOOL_MODULE_VERSION_3_19 = "0.3.19"
 _SUBTASKS_MODULE_VERSION = "0.3.26"
+_SAMPLE_LIMIT_VERSION = "0.3.91"
 _REMOVED_IN = "0.4"
 relocated_module_attribute(
     "Tool", "inspect_ai.tool.Tool", _TOOL_MODULE_VERSION_3_18, _REMOVED_IN
@@ -137,3 +136,9 @@ relocated_module_attribute(
     _SUBTASKS_MODULE_VERSION,
     _REMOVED_IN,
 )
+relocated_module_attribute(
+    "SampleLimitExceededError",
+    "inspect_ai.util.LimitExceededError",
+    _SAMPLE_LIMIT_VERSION,
+    _REMOVED_IN,
+)
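
The final hunks register a deprecation shim for the limits rework: `SampleLimitExceededError` (formerly exported from `inspect_ai.solver`) is now `inspect_ai.util.LimitExceededError`, backed by the new `inspect_ai/util/_limit.py` in the file list. A minimal migration sketch, assuming the relocated exception is caught the same way the old one was:

    # Before (deprecated as of 0.3.91, removed in 0.4):
    # from inspect_ai.solver import SampleLimitExceededError

    # After:
    from inspect_ai.util import LimitExceededError

    try:
        ...  # solver / agent code that may exceed a sample limit
    except LimitExceededError as ex:
        print(f"limit exceeded: {ex}")
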