speedy-utils 1.1.20__py3-none-any.whl → 1.1.22__py3-none-any.whl
This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- llm_utils/__init__.py +12 -1
- llm_utils/lm/llm_task.py +172 -251
- llm_utils/lm/utils.py +332 -110
- speedy_utils/multi_worker/process.py +128 -27
- speedy_utils/multi_worker/thread.py +341 -226
- {speedy_utils-1.1.20.dist-info → speedy_utils-1.1.22.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.20.dist-info → speedy_utils-1.1.22.dist-info}/RECORD +9 -9
- {speedy_utils-1.1.20.dist-info → speedy_utils-1.1.22.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.20.dist-info → speedy_utils-1.1.22.dist-info}/entry_points.txt +0 -0
llm_utils/__init__.py
CHANGED
@@ -6,6 +6,15 @@ from llm_utils.lm.base_prompt_builder import BasePromptBuilder
 
 LLM = LLMTask
 
+# Convenience functions for killing VLLM servers
+def kill_all_vllm() -> int:
+    """Kill all tracked VLLM server processes. Returns number of processes killed."""
+    return LLMTask.kill_all_vllm()
+
+def kill_vllm_on_port(port: int) -> bool:
+    """Kill VLLM server on specific port. Returns True if server was killed."""
+    return LLMTask.kill_vllm_on_port(port)
+
 from .chat_format import (
     build_chatml_input,
     display_chat_messages_as_html,
@@ -35,5 +44,7 @@ __all__ = [
     "get_model_name",
     "VectorCache",
     "BasePromptBuilder",
-    "LLM"
+    "LLM",
+    "kill_all_vllm",
+    "kill_vllm_on_port"
 ]
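The two new module-level helpers are thin wrappers around the `LLMTask` static methods added in this release, so callers can shut down vLLM servers without importing the class. A minimal usage sketch (the port number is illustrative):

```python
# Sketch: tearing down vLLM servers with the new 1.1.22 convenience wrappers.
# Port 8001 is only an example value.
from llm_utils import kill_all_vllm, kill_vllm_on_port

if kill_vllm_on_port(8001):
    print("Stopped the vLLM server on port 8001")

killed = kill_all_vllm()  # returns the number of tracked processes killed
print(f"Stopped {killed} tracked vLLM server(s)")
```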
llm_utils/lm/llm_task.py
CHANGED
@@ -5,6 +5,7 @@ Simplified LLM Task module for handling language model interactions with structu
 """
 
 import os
+import subprocess
 from typing import Any, Dict, List, Optional, Type, Union, cast
 
 import requests
@@ -13,177 +14,27 @@ from openai import OpenAI, AuthenticationError, BadRequestError, RateLimitError
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
+from .utils import (
+    _extract_port_from_vllm_cmd,
+    _start_vllm_server,
+    _kill_vllm_on_port,
+    _is_server_running,
+    get_base_client,
+    _is_lora_path,
+    _get_port_from_client,
+    _load_lora_adapter,
+    _unload_lora_adapter,
+    kill_all_vllm_processes,
+    stop_vllm_process,
+)
 from .base_prompt_builder import BasePromptBuilder
 
 # Type aliases for better readability
 Messages = List[ChatCompletionMessageParam]
 
 
-def get_base_client(
-    client: Union[OpenAI, int, str, None] = None, cache: bool = True, api_key="abc"
-) -> OpenAI:
-    """Get OpenAI client from port number, base_url string, or existing client."""
-    from llm_utils import MOpenAI
-
-    open_ai_class = OpenAI if not cache else MOpenAI
-    if client is None:
-        return open_ai_class()
-    elif isinstance(client, int):
-        return open_ai_class(base_url=f"http://localhost:{client}/v1", api_key=api_key)
-    elif isinstance(client, str):
-        return open_ai_class(base_url=client, api_key=api_key)
-    elif isinstance(client, OpenAI):
-        return client
-    else:
-        raise ValueError(
-            "Invalid client type. Must be OpenAI instance, port number (int), base_url (str), or None."
-        )
-
-
-def _is_lora_path(path: str) -> bool:
-    """Check if the given path is a LoRA adapter directory.
-
-    Args:
-        path: Path to check
-
-    Returns:
-        True if the path contains adapter_config.json, False otherwise
-    """
-    if not os.path.isdir(path):
-        return False
-    adapter_config_path = os.path.join(path, 'adapter_config.json')
-    return os.path.isfile(adapter_config_path)
-
-
-def _get_port_from_client(client: OpenAI) -> Optional[int]:
-    """Extract port number from OpenAI client base_url.
-
-    Args:
-        client: OpenAI client instance
-
-    Returns:
-        Port number if found, None otherwise
-    """
-    if hasattr(client, 'base_url') and client.base_url:
-        base_url = str(client.base_url)
-        if 'localhost:' in base_url:
-            try:
-                # Extract port from localhost:PORT/v1 format
-                port_part = base_url.split('localhost:')[1].split('/')[0]
-                return int(port_part)
-            except (IndexError, ValueError):
-                pass
-    return None
-
-
-def _load_lora_adapter(lora_path: str, port: int) -> str:
-    """Load a LoRA adapter from the specified path.
-
-    Args:
-        lora_path: Path to the LoRA adapter directory
-        port: Port number for the API endpoint
-
-    Returns:
-        Name of the loaded LoRA adapter
-
-    Raises:
-        requests.RequestException: If the API call fails
-    """
-    lora_name = os.path.basename(lora_path.rstrip('/\\'))
-    if not lora_name: # Handle edge case of empty basename
-        lora_name = os.path.basename(os.path.dirname(lora_path))
-
-    response = requests.post(
-        f'http://localhost:{port}/v1/load_lora_adapter',
-        headers={'accept': 'application/json', 'Content-Type': 'application/json'},
-        json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)}
-    )
-    response.raise_for_status()
-    return lora_name
-
-
-def _unload_lora_adapter(lora_path: str, port: int) -> None:
-    """Unload the current LoRA adapter.
-
-    Args:
-        lora_path: Path to the LoRA adapter directory
-        port: Port number for the API endpoint
-    """
-    try:
-        lora_name = os.path.basename(lora_path.rstrip('/\\'))
-        if not lora_name: # Handle edge case of empty basename
-            lora_name = os.path.basename(os.path.dirname(lora_path))
-
-        response = requests.post(
-            f'http://localhost:{port}/v1/unload_lora_adapter',
-            headers={'accept': 'application/json', 'Content-Type': 'application/json'},
-            json={"lora_name": lora_name, "lora_int_id": 0}
-        )
-        response.raise_for_status()
-    except requests.RequestException as e:
-        logger.warning(f"Error unloading LoRA adapter: {str(e)[:100]}")
-
-
 class LLMTask:
-    """
-    Language model task with structured input/output and optional system instruction.
-
-    Supports str or Pydantic models for both input and output. Automatically handles
-    message formatting and response parsing.
-
-    Two main APIs:
-    - text(): Returns raw text responses as list of dicts (alias for text_completion)
-    - parse(): Returns parsed Pydantic model responses as list of dicts (alias for pydantic_parse)
-    - __call__(): Backward compatibility method that delegates based on output_model
-
-    Example:
-    ```python
-    from pydantic import BaseModel
-    from llm_utils.lm.llm_task import LLMTask
-
-    class EmailOutput(BaseModel):
-        content: str
-        estimated_read_time: int
-
-    # Set up task with Pydantic output model
-    task = LLMTask(
-        instruction="Generate professional email content.",
-        output_model=EmailOutput,
-        client=OpenAI(),
-        temperature=0.7
-    )
-
-    # Use parse() for structured output
-    results = task.parse("Write a meeting follow-up email")
-    result = results[0]
-    print(result["parsed"].content, result["parsed"].estimated_read_time)
-
-    # Use text() for plain text output
-    results = task.text("Write a meeting follow-up email")
-    text_result = results[0]
-    print(text_result["parsed"])
-
-    # Multiple responses
-    results = task.parse("Write a meeting follow-up email", n=3)
-    for result in results:
-        print(f"Content: {result['parsed'].content}")
-
-    # Override parameters at runtime
-    results = task.text(
-        "Write a meeting follow-up email",
-        temperature=0.9,
-        n=2,
-        max_tokens=500
-    )
-    for result in results:
-        print(result["parsed"])
-
-    # Backward compatibility (uses output_model to choose method)
-    results = task("Write a meeting follow-up email") # Calls parse()
-    result = results[0]
-    print(result["parsed"].content)
-    ```
-    """
+    """LLM task with structured input/output handling."""
 
     def __init__(
         self,
@@ -195,29 +46,12 @@
         is_reasoning_model: bool = False,
         force_lora_unload: bool = False,
         lora_path: Optional[str] = None,
+        vllm_cmd: Optional[str] = None,
+        vllm_timeout: int = 1200,
+        vllm_reuse: bool = True,
         **model_kwargs,
     ):
-        """
-        Initialize the LLMTask.
-
-        Args:
-            instruction: Optional system instruction for the task
-            input_model: Input type (str or BaseModel subclass)
-            output_model: Output BaseModel type
-            client: OpenAI client, port number, or base_url string
-            cache: Whether to use cached responses (default True)
-            is_reasoning_model: Whether the model is a reasoning model (o1-preview, o1-mini, etc.)
-                that outputs reasoning_content separately from content (default False)
-            force_lora_unload: If True, forces unloading of any existing LoRA adapter before loading
-                a new one when lora_path is provided (default False)
-            lora_path: Optional path to LoRA adapter directory. If provided, will load the LoRA
-                and use it as the model. Takes precedence over model parameter.
-            **model_kwargs: Additional model parameters including:
-                - temperature: Controls randomness (0.0 to 2.0)
-                - n: Number of responses to generate (when n > 1, returns list)
-                - max_tokens: Maximum tokens in response
-                - model: Model name (auto-detected if not provided)
-        """
+        """Initialize LLMTask."""
         self.instruction = instruction
         self.input_model = input_model
         self.output_model = output_model
@@ -225,15 +59,55 @@
         self.is_reasoning_model = is_reasoning_model
         self.force_lora_unload = force_lora_unload
         self.lora_path = lora_path
+        self.vllm_cmd = vllm_cmd
+        self.vllm_timeout = vllm_timeout
+        self.vllm_reuse = vllm_reuse
+        self.vllm_process: Optional[subprocess.Popen] = None
         self.last_ai_response = None # Store raw response from client
 
-        # if
-
+        # Handle VLLM server startup if vllm_cmd is provided
+        if self.vllm_cmd:
+            port = _extract_port_from_vllm_cmd(self.vllm_cmd)
+            reuse_existing = False
 
-
-
-
-
+            if self.vllm_reuse:
+                try:
+                    reuse_client = get_base_client(port, cache=False)
+                    models_response = reuse_client.models.list()
+                    if getattr(models_response, "data", None):
+                        reuse_existing = True
+                        logger.info(
+                            f"VLLM server already running on port {port}, "
+                            "reusing existing server (vllm_reuse=True)"
+                        )
+                    else:
+                        logger.info(
+                            f"No models returned from VLLM server on port {port}; "
+                            "starting a new server"
+                        )
+                except Exception as exc:
+                    logger.info(
+                        f"Unable to reach VLLM server on port {port} (list_models failed): {exc}. "
+                        "Starting a new server."
+                    )
+
+            if not self.vllm_reuse:
+                if _is_server_running(port):
+                    logger.info(
+                        f"VLLM server already running on port {port}, killing it first (vllm_reuse=False)"
+                    )
+                    _kill_vllm_on_port(port)
+                logger.info(f"Starting new VLLM server on port {port}")
+                self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
+            elif not reuse_existing:
+                logger.info(f"Starting VLLM server on port {port}")
+                self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
+
+            # Set client to use the VLLM server port if not explicitly provided
+            if client is None:
+                client = port
+
+        self.client = get_base_client(client, cache=cache, vllm_cmd=self.vllm_cmd, vllm_process=self.vllm_process)
         # check connection of client
         try:
             self.client.models.list()
@@ -247,8 +121,20 @@
         # Handle LoRA loading if lora_path is provided
         if self.lora_path:
             self._load_lora_adapter()
-
-
+
+    def cleanup_vllm_server(self) -> None:
+        """Stop the VLLM server process if it was started by this instance."""
+        if self.vllm_process is not None:
+            stop_vllm_process(self.vllm_process)
+            self.vllm_process = None
+
+    def __enter__(self):
+        """Context manager entry."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit with cleanup."""
+        self.cleanup_vllm_server()
 
     def _load_lora_adapter(self) -> None:
         """
@@ -413,23 +299,7 @@
     def text_completion(
         self, input_data: Union[str, BaseModel, list[Dict]], **runtime_kwargs
    ) -> List[Dict[str, Any]]:
-        """
-        Execute the LLM task and return text responses.
-
-        Args:
-            input_data: Input as string or BaseModel
-            **runtime_kwargs: Runtime model parameters that override defaults
-                - temperature: Controls randomness (0.0 to 2.0)
-                - n: Number of responses to generate
-                - max_tokens: Maximum tokens in response
-                - model: Model name override
-                - Any other model parameters supported by OpenAI API
-
-        Returns:
-            List of dicts [{'parsed': text_response, 'messages': messages}, ...]
-            When n=1: List contains one dict
-            When n>1: List contains multiple dicts
-        """
+        """Execute LLM task and return text responses."""
         # Prepare messages
         messages = self._prepare_input(input_data)
 
@@ -481,29 +351,7 @@
         response_model: Optional[Type[BaseModel]] | Type[str] = None,
         **runtime_kwargs,
     ) -> List[Dict[str, Any]]:
-        """
-        Execute the LLM task and return parsed Pydantic model responses.
-
-        Args:
-            input_data: Input as string or BaseModel
-            response_model: Pydantic model for response parsing (overrides default)
-            **runtime_kwargs: Runtime model parameters that override defaults
-                - temperature: Controls randomness (0.0 to 2.0)
-                - n: Number of responses to generate
-                - max_tokens: Maximum tokens in response
-                - model: Model name override
-                - Any other model parameters supported by OpenAI API
-
-        Returns:
-            List of dicts [{'parsed': parsed_model, 'messages': messages}, ...]
-            When n=1: List contains one dict
-            When n>1: List contains multiple dicts
-
-        Note:
-            This method ensures consistent Pydantic model output for both fresh and cached responses.
-            When responses are cached and loaded back, the parsed content is re-validated to maintain
-            type consistency between first-time and subsequent calls.
-        """
+        """Execute LLM task and return parsed Pydantic model responses."""
         # Prepare messages
         messages = self._prepare_input(input_data)
 
@@ -576,20 +424,7 @@
         two_step_parse_pydantic=False,
         **runtime_kwargs,
     ) -> List[Dict[str, Any]]:
-        """
-        Execute the LLM task. Delegates to text() or parse() based on output_model.
-
-        This method maintains backward compatibility by automatically choosing
-        between text and parse methods based on the output_model configuration.
-
-        Args:
-            input_data: Input as string or BaseModel
-            response_model: Optional override for output model
-            **runtime_kwargs: Runtime model parameters
-
-        Returns:
-            List of dicts [{'parsed': response, 'messages': messages}, ...]
-        """
+        """Execute LLM task. Delegates to text() or parse() based on output_model."""
         pydantic_model_to_use = response_model or self.output_model
 
         if pydantic_model_to_use is str or pydantic_model_to_use is None:
@@ -606,10 +441,8 @@
                 response_text = response_text.split("</think>")[1]
             try:
                 parsed = pydantic_model_to_use.model_validate_json(response_text)
-            except Exception
-                #
-                #     f"Warning: Failed to parsed JSON, Falling back to LLM parsing. Error: {str(e)[:100]}..."
-                # )
+            except Exception:
+                # Failed to parse JSON, falling back to LLM parsing
                 # use model to parse the response_text
                 _parsed_messages = [
                     {
@@ -652,6 +485,9 @@
         cache=True,
         is_reasoning_model: bool = False,
         lora_path: Optional[str] = None,
+        vllm_cmd: Optional[str] = None,
+        vllm_timeout: int = 120,
+        vllm_reuse: bool = True,
         **model_kwargs,
     ) -> "LLMTask":
         """
@@ -659,6 +495,17 @@
 
         This method extracts the instruction, input model, and output model
         from the provided builder and initializes an LLMTask accordingly.
+
+        Args:
+            builder: BasePromptBuilder instance
+            client: OpenAI client, port number, or base_url string
+            cache: Whether to use cached responses (default True)
+            is_reasoning_model: Whether model is reasoning model (default False)
+            lora_path: Optional path to LoRA adapter directory
+            vllm_cmd: Optional VLLM command to start server automatically
+            vllm_timeout: Timeout in seconds to wait for VLLM server (default 120)
+            vllm_reuse: If True (default), reuse existing server on target port
+            **model_kwargs: Additional model parameters
         """
         instruction = builder.get_instruction()
         input_model = builder.get_input_model()
@@ -673,6 +520,9 @@
             cache=cache,
             is_reasoning_model=is_reasoning_model,
             lora_path=lora_path,
+            vllm_cmd=vllm_cmd,
+            vllm_timeout=vllm_timeout,
+            vllm_reuse=vllm_reuse,
             **model_kwargs,
         )
 
@@ -691,3 +541,74 @@
         models = client.models.list().data
         return [m.id for m in models]
 
+    @staticmethod
+    def kill_all_vllm() -> int:
+        """Kill all tracked VLLM server processes."""
+        return kill_all_vllm_processes()
+
+    @staticmethod
+    def kill_vllm_on_port(port: int) -> bool:
+        """
+        Kill VLLM server running on a specific port.
+
+        Args:
+            port: Port number to kill server on
+
+        Returns:
+            True if a server was killed, False if no server was running
+        """
+        return _kill_vllm_on_port(port)
+
+
+# Example usage:
+if __name__ == "__main__":
+    # Example 1: Using VLLM with reuse (default behavior)
+    vllm_command = (
+        "vllm serve saves/vng/dpo/01 -tp 4 --port 8001 "
+        "--gpu-memory-utilization 0.9 --served-model-name sft --quantization experts_int8"
+    )
+
+    print("Example 1: Using VLLM with server reuse (default)")
+    # Create LLM instance - will reuse existing server if running on port 8001
+    with LLMTask(vllm_cmd=vllm_command) as llm: # vllm_reuse=True by default
+        result = llm.text("Hello, how are you?")
+        print("Response:", result[0]["parsed"])
+
+    print("\nExample 2: Force restart server (vllm_reuse=False)")
+    # This will kill any existing server on port 8001 and start fresh
+    with LLMTask(vllm_cmd=vllm_command, vllm_reuse=False) as llm:
+        result = llm.text("Tell me a joke")
+        print("Joke:", result[0]["parsed"])
+
+    print("\nExample 3: Multiple instances with reuse")
+    # First instance starts the server
+    llm1 = LLMTask(vllm_cmd=vllm_command) # Starts server or reuses existing
+
+    # Second instance reuses the same server
+    llm2 = LLMTask(vllm_cmd=vllm_command) # Reuses server on port 8001
+
+    try:
+        result1 = llm1.text("What's the weather like?")
+        result2 = llm2.text("How's the traffic?")
+        print("Weather response:", result1[0]["parsed"])
+        print("Traffic response:", result2[0]["parsed"])
+    finally:
+        # Only cleanup if we started the process
+        llm1.cleanup_vllm_server()
+        llm2.cleanup_vllm_server() # Won't do anything if process not owned
+
+    print("\nExample 4: Different ports")
+    # These will start separate servers
+    llm_8001 = LLMTask(vllm_cmd="vllm serve model1 --port 8001", vllm_reuse=True)
+    llm_8002 = LLMTask(vllm_cmd="vllm serve model2 --port 8002", vllm_reuse=True)
+
+    print("\nExample 5: Kill all VLLM servers")
+    # Kill all tracked VLLM processes
+    killed_count = LLMTask.kill_all_vllm()
+    print(f"Killed {killed_count} VLLM servers")
+
+    print("\nYou can check VLLM server output at: /tmp/vllm.txt")
+    print("Server reuse behavior:")
+    print("- vllm_reuse=True (default): Reuse existing server on target port")
+    print("- vllm_reuse=False: Kill existing server first, then start fresh")
+
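Combining the new `vllm_cmd`/`vllm_reuse` parameters with the existing structured-output API gives a task that boots (or reuses) its own vLLM server and still returns parsed Pydantic objects. The sketch below assumes `parse()` keeps its documented return shape of `[{'parsed': ..., 'messages': ...}, ...]`; the `ReviewSummary` schema, model path, and port are illustrative only.

```python
# Sketch: structured output against an auto-managed vLLM server.
# The serve command, port, and ReviewSummary schema are made-up placeholders.
from pydantic import BaseModel
from llm_utils.lm.llm_task import LLMTask

class ReviewSummary(BaseModel):
    sentiment: str
    key_points: list[str]

with LLMTask(
    instruction="Summarize the customer review.",
    output_model=ReviewSummary,
    vllm_cmd="vllm serve my-model --port 8001",  # placeholder serve command
    vllm_reuse=True,  # reuse a server already listening on 8001, if any
) as task:
    results = task.parse("Great battery life, but the screen scratches easily.")
    summary = results[0]["parsed"]
    print(summary.sentiment, summary.key_points)
# __exit__ calls cleanup_vllm_server(), which only stops a server this instance started
```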