speedy-utils 1.1.23__py3-none-any.whl → 1.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +12 -8
- llm_utils/chat_format/__init__.py +2 -0
- llm_utils/chat_format/display.py +115 -44
- llm_utils/lm/__init__.py +14 -6
- llm_utils/lm/llm.py +413 -0
- llm_utils/lm/llm_signature.py +35 -0
- llm_utils/lm/mixins.py +379 -0
- llm_utils/lm/openai_memoize.py +18 -7
- llm_utils/lm/signature.py +26 -37
- llm_utils/lm/utils.py +61 -76
- speedy_utils/__init__.py +28 -1
- speedy_utils/all.py +30 -1
- speedy_utils/common/utils_io.py +36 -26
- speedy_utils/common/utils_misc.py +25 -1
- speedy_utils/multi_worker/thread.py +145 -58
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.24.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.24.dist-info}/RECORD +19 -18
- llm_utils/lm/llm_as_a_judge.py +0 -390
- llm_utils/lm/llm_task.py +0 -614
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.24.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.23.dist-info → speedy_utils-1.1.24.dist-info}/entry_points.txt +0 -0
llm_utils/lm/llm_task.py
DELETED
@@ -1,614 +0,0 @@
# type: ignore

"""
Simplified LLM Task module for handling language model interactions with structured input/output.
"""

import os
import subprocess
from typing import Any, Dict, List, Optional, Type, Union, cast

import requests
from loguru import logger
from openai import OpenAI, AuthenticationError, BadRequestError, RateLimitError
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from .utils import (
    _extract_port_from_vllm_cmd,
    _start_vllm_server,
    _kill_vllm_on_port,
    _is_server_running,
    get_base_client,
    _is_lora_path,
    _get_port_from_client,
    _load_lora_adapter,
    _unload_lora_adapter,
    kill_all_vllm_processes,
    stop_vllm_process,
)
from .base_prompt_builder import BasePromptBuilder

# Type aliases for better readability
Messages = List[ChatCompletionMessageParam]


class LLMTask:
    """LLM task with structured input/output handling."""

    def __init__(
        self,
        instruction: Optional[str] = None,
        input_model: Union[Type[BaseModel], type[str]] = str,
        output_model: Type[BaseModel] | Type[str] = None,
        client: Union[OpenAI, int, str, None] = None,
        cache=True,
        is_reasoning_model: bool = False,
        force_lora_unload: bool = False,
        lora_path: Optional[str] = None,
        vllm_cmd: Optional[str] = None,
        vllm_timeout: int = 1200,
        vllm_reuse: bool = True,
        **model_kwargs,
    ):
        """Initialize LLMTask."""
        self.instruction = instruction
        self.input_model = input_model
        self.output_model = output_model
        self.model_kwargs = model_kwargs
        self.is_reasoning_model = is_reasoning_model
        self.force_lora_unload = force_lora_unload
        self.lora_path = lora_path
        self.vllm_cmd = vllm_cmd
        self.vllm_timeout = vllm_timeout
        self.vllm_reuse = vllm_reuse
        self.vllm_process: Optional[subprocess.Popen] = None
        self.last_ai_response = None  # Store raw response from client

        # Handle VLLM server startup if vllm_cmd is provided
        if self.vllm_cmd:
            port = _extract_port_from_vllm_cmd(self.vllm_cmd)
            reuse_existing = False

            if self.vllm_reuse:
                try:
                    reuse_client = get_base_client(port, cache=False)
                    models_response = reuse_client.models.list()
                    if getattr(models_response, "data", None):
                        reuse_existing = True
                        logger.info(
                            f"VLLM server already running on port {port}, "
                            "reusing existing server (vllm_reuse=True)"
                        )
                    else:
                        logger.info(
                            f"No models returned from VLLM server on port {port}; "
                            "starting a new server"
                        )
                except Exception as exc:
                    logger.info(
                        f"Unable to reach VLLM server on port {port} (list_models failed): {exc}. "
                        "Starting a new server."
                    )

            if not self.vllm_reuse:
                if _is_server_running(port):
                    logger.info(
                        f"VLLM server already running on port {port}, killing it first (vllm_reuse=False)"
                    )
                    _kill_vllm_on_port(port)
                logger.info(f"Starting new VLLM server on port {port}")
                self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
            elif not reuse_existing:
                logger.info(f"Starting VLLM server on port {port}")
                self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)

            # Set client to use the VLLM server port if not explicitly provided
            if client is None:
                client = port

        self.client = get_base_client(client, cache=cache, vllm_cmd=self.vllm_cmd, vllm_process=self.vllm_process)
        # check connection of client
        try:
            self.client.models.list()
        except Exception as e:
            logger.error(f"Failed to connect to OpenAI client: {str(e)}, base_url={self.client.base_url}")
            raise e

        if not self.model_kwargs.get("model", ""):
            self.model_kwargs["model"] = self.client.models.list().data[0].id

        # Handle LoRA loading if lora_path is provided
        if self.lora_path:
            self._load_lora_adapter()

    def cleanup_vllm_server(self) -> None:
        """Stop the VLLM server process if it was started by this instance."""
        if self.vllm_process is not None:
            stop_vllm_process(self.vllm_process)
            self.vllm_process = None

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit with cleanup."""
        self.cleanup_vllm_server()

    def _load_lora_adapter(self) -> None:
        """
        Load LoRA adapter from the specified lora_path.

        This method:
        1. Validates that lora_path is a valid LoRA directory
        2. Checks if LoRA is already loaded (unless force_lora_unload is True)
        3. Loads the LoRA adapter and updates the model name
        """
        if not self.lora_path:
            return

        if not _is_lora_path(self.lora_path):
            raise ValueError(
                f"Invalid LoRA path '{self.lora_path}': "
                "Directory must contain 'adapter_config.json'"
            )

        logger.info(f"Loading LoRA adapter from: {self.lora_path}")

        # Get the expected LoRA name (basename of the path)
        lora_name = os.path.basename(self.lora_path.rstrip('/\\'))
        if not lora_name:  # Handle edge case of empty basename
            lora_name = os.path.basename(os.path.dirname(self.lora_path))

        # Get list of available models to check if LoRA is already loaded
        try:
            available_models = [m.id for m in self.client.models.list().data]
        except Exception as e:
            logger.warning(f"Failed to list models, proceeding with LoRA load: {str(e)[:100]}")
            available_models = []

        # Check if LoRA is already loaded
        if lora_name in available_models and not self.force_lora_unload:
            logger.info(f"LoRA adapter '{lora_name}' is already loaded, using existing model")
            self.model_kwargs["model"] = lora_name
            return

        # Force unload if requested
        if self.force_lora_unload and lora_name in available_models:
            logger.info(f"Force unloading LoRA adapter '{lora_name}' before reloading")
            port = _get_port_from_client(self.client)
            if port is not None:
                try:
                    LLMTask.unload_lora(port, lora_name)
                    logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
                except Exception as e:
                    logger.warning(f"Failed to unload LoRA adapter: {str(e)[:100]}")

        # Get port from client for API calls
        port = _get_port_from_client(self.client)
        if port is None:
            raise ValueError(
                f"Cannot load LoRA adapter '{self.lora_path}': "
                "Unable to determine port from client base_url. "
                "LoRA loading requires a client initialized with port number."
            )

        try:
            # Load the LoRA adapter
            loaded_lora_name = _load_lora_adapter(self.lora_path, port)
            logger.info(f"Successfully loaded LoRA adapter: {loaded_lora_name}")

            # Update model name to the loaded LoRA name
            self.model_kwargs["model"] = loaded_lora_name

        except requests.RequestException as e:
            # Check if the error is due to LoRA already being loaded
            error_msg = str(e)
            if "400" in error_msg or "Bad Request" in error_msg:
                logger.info(f"LoRA adapter may already be loaded, attempting to use '{lora_name}'")
                # Refresh the model list to check if it's now available
                try:
                    updated_models = [m.id for m in self.client.models.list().data]
                    if lora_name in updated_models:
                        logger.info(f"Found LoRA adapter '{lora_name}' in updated model list")
                        self.model_kwargs["model"] = lora_name
                        return
                except Exception:
                    pass  # Fall through to original error

            raise ValueError(
                f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}"
            )

    def unload_lora_adapter(self, lora_path: str) -> None:
        """
        Unload a LoRA adapter.

        Args:
            lora_path: Path to the LoRA adapter directory to unload

        Raises:
            ValueError: If unable to determine port from client
        """
        port = _get_port_from_client(self.client)
        if port is None:
            raise ValueError(
                "Cannot unload LoRA adapter: "
                "Unable to determine port from client base_url. "
                "LoRA operations require a client initialized with port number."
            )

        _unload_lora_adapter(lora_path, port)
        lora_name = os.path.basename(lora_path.rstrip('/\\'))
        logger.info(f"Unloaded LoRA adapter: {lora_name}")

    @staticmethod
    def unload_lora(port: int, lora_name: str) -> None:
        """Static method to unload a LoRA adapter by name.

        Args:
            port: Port number for the API endpoint
            lora_name: Name of the LoRA adapter to unload

        Raises:
            requests.RequestException: If the API call fails
        """
        try:
            response = requests.post(
                f'http://localhost:{port}/v1/unload_lora_adapter',
                headers={'accept': 'application/json', 'Content-Type': 'application/json'},
                json={"lora_name": lora_name, "lora_int_id": 0}
            )
            response.raise_for_status()
            logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
        except requests.RequestException as e:
            logger.error(f"Error unloading LoRA adapter '{lora_name}': {str(e)[:100]}")
            raise

    def _prepare_input(self, input_data: Union[str, BaseModel, List[Dict]]) -> Messages:
        """Convert input to messages format."""
        if isinstance(input_data, list):
            assert isinstance(input_data[0], dict) and "role" in input_data[0], (
                "If input_data is a list, it must be a list of messages with 'role' and 'content' keys."
            )
            return cast(Messages, input_data)
        else:
            # Convert input to string format
            if isinstance(input_data, str):
                user_content = input_data
            elif hasattr(input_data, "model_dump_json"):
                user_content = input_data.model_dump_json()
            elif isinstance(input_data, dict):
                user_content = str(input_data)
            else:
                user_content = str(input_data)

            # Build messages
            messages = (
                [
                    {"role": "system", "content": self.instruction},
                ]
                if self.instruction is not None
                else []
            )

            messages.append({"role": "user", "content": user_content})
            return cast(Messages, messages)

    def text_completion(
        self, input_data: Union[str, BaseModel, list[Dict]], **runtime_kwargs
    ) -> List[Dict[str, Any]]:
        """Execute LLM task and return text responses."""
        # Prepare messages
        messages = self._prepare_input(input_data)

        # Merge runtime kwargs with default model kwargs (runtime takes precedence)
        effective_kwargs = {**self.model_kwargs, **runtime_kwargs}
        model_name = effective_kwargs.get("model", self.model_kwargs["model"])

        # Extract model name from kwargs for API call
        api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}

        try:
            completion = self.client.chat.completions.create(
                model=model_name, messages=messages, **api_kwargs
            )
            # Store raw response from client
            self.last_ai_response = completion
        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
            logger.error(error_msg)
            raise
        except Exception as e:
            is_length_error = "Length" in str(e) or "maximum context length" in str(e)
            if is_length_error:
                raise ValueError(
                    f"Input too long for model {model_name}. Error: {str(e)[:100]}..."
                )
            # Re-raise all other exceptions
            raise
        # print(completion)

        results: List[Dict[str, Any]] = []
        for choice in completion.choices:
            choice_messages = cast(
                Messages,
                messages + [{"role": "assistant", "content": choice.message.content}],
            )
            result_dict = {"parsed": choice.message.content, "messages": choice_messages}

            # Add reasoning content if this is a reasoning model
            if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
                result_dict["reasoning_content"] = choice.message.reasoning_content

            results.append(result_dict)
        return results

    def pydantic_parse(
        self,
        input_data: Union[str, BaseModel, list[Dict]],
        response_model: Optional[Type[BaseModel]] | Type[str] = None,
        **runtime_kwargs,
    ) -> List[Dict[str, Any]]:
        """Execute LLM task and return parsed Pydantic model responses."""
        # Prepare messages
        messages = self._prepare_input(input_data)

        # Merge runtime kwargs with default model kwargs (runtime takes precedence)
        effective_kwargs = {**self.model_kwargs, **runtime_kwargs}
        model_name = effective_kwargs.get("model", self.model_kwargs["model"])

        # Extract model name from kwargs for API call
        api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}

        pydantic_model_to_use_opt = response_model or self.output_model
        if pydantic_model_to_use_opt is None:
            raise ValueError(
                "No response model specified. Either set output_model in constructor or pass response_model parameter."
            )
        pydantic_model_to_use: Type[BaseModel] = cast(
            Type[BaseModel], pydantic_model_to_use_opt
        )
        try:
            completion = self.client.chat.completions.parse(
                model=model_name,
                messages=messages,
                response_format=pydantic_model_to_use,
                **api_kwargs,
            )
            # Store raw response from client
            self.last_ai_response = completion
        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
            logger.error(error_msg)
            raise
        except Exception as e:
            is_length_error = "Length" in str(e) or "maximum context length" in str(e)
            if is_length_error:
                raise ValueError(
                    f"Input too long for model {model_name}. Error: {str(e)[:100]}..."
                )
            # Re-raise all other exceptions
            raise

        results: List[Dict[str, Any]] = []
        for choice in completion.choices:  # type: ignore[attr-defined]
            choice_messages = cast(
                Messages,
                messages + [{"role": "assistant", "content": choice.message.content}],
            )

            # Ensure consistent Pydantic model output for both fresh and cached responses
            parsed_content = choice.message.parsed  # type: ignore[attr-defined]
            if isinstance(parsed_content, dict):
                # Cached response: validate dict back to Pydantic model
                parsed_content = pydantic_model_to_use.model_validate(parsed_content)
            elif not isinstance(parsed_content, pydantic_model_to_use):
                # Fallback: ensure it's the correct type
                parsed_content = pydantic_model_to_use.model_validate(parsed_content)

            result_dict = {"parsed": parsed_content, "messages": choice_messages}

            # Add reasoning content if this is a reasoning model
            if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
                result_dict["reasoning_content"] = choice.message.reasoning_content

            results.append(result_dict)
        return results

    def __call__(
        self,
        input_data: Union[str, BaseModel, list[Dict]],
        response_model: Optional[Type[BaseModel] | Type[str]] = None,
        two_step_parse_pydantic=False,
        **runtime_kwargs,
    ) -> List[Dict[str, Any]]:
        """Execute LLM task. Delegates to text() or parse() based on output_model."""
        pydantic_model_to_use = response_model or self.output_model

        if pydantic_model_to_use is str or pydantic_model_to_use is None:
            return self.text_completion(input_data, **runtime_kwargs)
        elif two_step_parse_pydantic:
            # step 1: get text completions
            results = self.text_completion(input_data, **runtime_kwargs)
            parsed_results = []
            for result in results:
                response_text = result["parsed"]
                messages = result["messages"]
                # check if the pydantic_model_to_use is validated
                if "</think>" in response_text:
                    response_text = response_text.split("</think>")[1]
                try:
                    parsed = pydantic_model_to_use.model_validate_json(response_text)
                except Exception:
                    # Failed to parse JSON, falling back to LLM parsing
                    # use model to parse the response_text
                    _parsed_messages = [
                        {
                            "role": "system",
                            "content": "You are a helpful assistant that extracts JSON from text.",
                        },
                        {
                            "role": "user",
                            "content": f"Extract JSON from the following text:\n{response_text}",
                        },
                    ]
                    parsed_result = self.pydantic_parse(
                        _parsed_messages,
                        response_model=pydantic_model_to_use,
                        **runtime_kwargs,
                    )[0]
                    parsed = parsed_result["parsed"]
                # ---
                parsed_results.append({"parsed": parsed, "messages": messages})
            return parsed_results

        else:
            return self.pydantic_parse(
                input_data, response_model=response_model, **runtime_kwargs
            )

    # Backward compatibility aliases
    def text(self, *args, **kwargs) -> List[Dict[str, Any]]:
        """Alias for text_completion() for backward compatibility."""
        return self.text_completion(*args, **kwargs)

    def parse(self, *args, **kwargs) -> List[Dict[str, Any]]:
        """Alias for pydantic_parse() for backward compatibility."""
        return self.pydantic_parse(*args, **kwargs)

    @classmethod
    def from_prompt_builder(
        builder: BasePromptBuilder,
        client: Union[OpenAI, int, str, None] = None,
        cache=True,
        is_reasoning_model: bool = False,
        lora_path: Optional[str] = None,
        vllm_cmd: Optional[str] = None,
        vllm_timeout: int = 120,
        vllm_reuse: bool = True,
        **model_kwargs,
    ) -> "LLMTask":
        """
        Create an LLMTask instance from a BasePromptBuilder instance.

        This method extracts the instruction, input model, and output model
        from the provided builder and initializes an LLMTask accordingly.

        Args:
            builder: BasePromptBuilder instance
            client: OpenAI client, port number, or base_url string
            cache: Whether to use cached responses (default True)
            is_reasoning_model: Whether model is reasoning model (default False)
            lora_path: Optional path to LoRA adapter directory
            vllm_cmd: Optional VLLM command to start server automatically
            vllm_timeout: Timeout in seconds to wait for VLLM server (default 120)
            vllm_reuse: If True (default), reuse existing server on target port
            **model_kwargs: Additional model parameters
        """
        instruction = builder.get_instruction()
        input_model = builder.get_input_model()
        output_model = builder.get_output_model()

        # Extract data from the builder to initialize LLMTask
        return LLMTask(
            instruction=instruction,
            input_model=input_model,
            output_model=output_model,
            client=client,
            cache=cache,
            is_reasoning_model=is_reasoning_model,
            lora_path=lora_path,
            vllm_cmd=vllm_cmd,
            vllm_timeout=vllm_timeout,
            vllm_reuse=vllm_reuse,
            **model_kwargs,
        )

    @staticmethod
    def list_models(client: Union[OpenAI, int, str, None] = None) -> List[str]:
        """
        List available models from the OpenAI client.

        Args:
            client: OpenAI client, port number, or base_url string

        Returns:
            List of available model names.
        """
        client = get_base_client(client, cache=False)
        models = client.models.list().data
        return [m.id for m in models]

    @staticmethod
    def kill_all_vllm() -> int:
        """Kill all tracked VLLM server processes."""
        return kill_all_vllm_processes()

    @staticmethod
    def kill_vllm_on_port(port: int) -> bool:
        """
        Kill VLLM server running on a specific port.

        Args:
            port: Port number to kill server on

        Returns:
            True if a server was killed, False if no server was running
        """
        return _kill_vllm_on_port(port)


# Example usage:
if __name__ == "__main__":
    # Example 1: Using VLLM with reuse (default behavior)
    vllm_command = (
        "vllm serve saves/vng/dpo/01 -tp 4 --port 8001 "
        "--gpu-memory-utilization 0.9 --served-model-name sft --quantization experts_int8"
    )

    print("Example 1: Using VLLM with server reuse (default)")
    # Create LLM instance - will reuse existing server if running on port 8001
    with LLMTask(vllm_cmd=vllm_command) as llm:  # vllm_reuse=True by default
        result = llm.text("Hello, how are you?")
        print("Response:", result[0]["parsed"])

    print("\nExample 2: Force restart server (vllm_reuse=False)")
    # This will kill any existing server on port 8001 and start fresh
    with LLMTask(vllm_cmd=vllm_command, vllm_reuse=False) as llm:
        result = llm.text("Tell me a joke")
        print("Joke:", result[0]["parsed"])

    print("\nExample 3: Multiple instances with reuse")
    # First instance starts the server
    llm1 = LLMTask(vllm_cmd=vllm_command)  # Starts server or reuses existing

    # Second instance reuses the same server
    llm2 = LLMTask(vllm_cmd=vllm_command)  # Reuses server on port 8001

    try:
        result1 = llm1.text("What's the weather like?")
        result2 = llm2.text("How's the traffic?")
        print("Weather response:", result1[0]["parsed"])
        print("Traffic response:", result2[0]["parsed"])
    finally:
        # Only cleanup if we started the process
        llm1.cleanup_vllm_server()
        llm2.cleanup_vllm_server()  # Won't do anything if process not owned

    print("\nExample 4: Different ports")
    # These will start separate servers
    llm_8001 = LLMTask(vllm_cmd="vllm serve model1 --port 8001", vllm_reuse=True)
    llm_8002 = LLMTask(vllm_cmd="vllm serve model2 --port 8002", vllm_reuse=True)

    print("\nExample 5: Kill all VLLM servers")
    # Kill all tracked VLLM processes
    killed_count = LLMTask.kill_all_vllm()
    print(f"Killed {killed_count} VLLM servers")

    print("\nYou can check VLLM server output at: /tmp/vllm.txt")
    print("Server reuse behavior:")
    print("- vllm_reuse=True (default): Reuse existing server on target port")
    print("- vllm_reuse=False: Kill existing server first, then start fresh")
{speedy_utils-1.1.23.dist-info → speedy_utils-1.1.24.dist-info}/WHEEL
File without changes

{speedy_utils-1.1.23.dist-info → speedy_utils-1.1.24.dist-info}/entry_points.txt
File without changes