py-adtools 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,452 @@
+ """
+ Copyright (c) 2025 Rui Zhang <rzhang.cs@gmail.com>
+
+ NOTICE: This code is under MIT license. This code is intended for academic/research purposes only.
+ Commercial use of this software or its derivatives requires prior written permission.
+ """
+
+ from typing import Optional, List, Literal, Dict, Any
+ import os
+ import subprocess
+ import sys
+ from pathlib import Path
+ import psutil
+ import time
+
+ import openai.types.chat
+ import requests
+
+ from adtools.lm.lm_base import LanguageModel
+
+
+ def _print_cmd_list(cmd_list, gpus, host, port):
+     print("\n" + "=" * 80)
+     print(f"[vLLM] Launching vLLM on GPUs: {gpus}; URL: http://{host}:{port}")
+     print("=" * 80)
+     cmd = cmd_list[0] + " \\\n"
+     for c in cmd_list[1:]:
+         cmd += " " + c + " \\\n"
+     print(cmd.strip())
+     print("=" * 80 + "\n", flush=True)
+
+
+ class VLLMServer(LanguageModel):
+     def __init__(
+         self,
+         model_path: str,
+         port: int,
+         gpus: int | list[int],
+         tokenizer_path: Optional[str] = None,
+         max_model_len: int = 16384,
+         max_lora_rank: Optional[int] = None,
+         host: str = "0.0.0.0",
+         mem_util: float = 0.85,
+         deploy_timeout_seconds: int = 600,
+         *,
+         launch_vllm_in_init: bool = True,
+         enforce_eager: bool = False,
+         vllm_log_level: Literal[
+             "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
+         ] = "INFO",
+         silent_mode: bool = False,
+         env_variable_dict: Optional[Dict[str, str]] = None,
+         vllm_serve_args: Optional[List[str]] = None,
+         vllm_serve_kwargs: Optional[Dict[str, str]] = None,
+         chat_template_kwargs: Optional[Dict[str, Any]] = None,
+     ):
+         """Deploy an LLM on the specified GPUs.
+
+         Args:
+             model_path: Path to the model to deploy.
+             port: Port on which the vLLM server listens.
+             gpus: GPU index, or list of GPU indices, to deploy on.
+             tokenizer_path: Path to the tokenizer to use. Defaults to `model_path`.
+             max_model_len: Maximum model context length.
+             max_lora_rank: Max rank of LoRA adapters. Defaults to `None`, which disables LoRA support.
+             host: Host address for the vLLM server.
+             mem_util: GPU memory utilization fraction for the vLLM deployment.
+             deploy_timeout_seconds: Timeout for deployment (in seconds).
+             launch_vllm_in_init: Launch the vLLM server during initialization of this class.
+             enforce_eager: Enforce eager mode.
+             vllm_log_level: Log level of the vLLM server.
+             silent_mode: Suppress vLLM server output.
+             env_variable_dict: Environment variables to set for the vLLM server, e.g., {'KEY': 'VALUE'}.
+             vllm_serve_args: Arguments to pass to the vLLM server, e.g., ['--enable-reasoning'].
+             vllm_serve_kwargs: Keyword arguments to pass to the vLLM server, e.g., {'--reasoning-parser': 'deepseek-r1'}.
+             chat_template_kwargs: Keyword arguments to pass to the chat template, e.g., {'enable_thinking': False}.
+
+         Example:
+             # deploy a model on GPUs 0 and 1
+             llm = VLLMServer(
+                 model_path='path/to/model',
+                 tokenizer_path='path/to/tokenizer',
+                 gpus=[0, 1],  # set gpus=0 or gpus=[0] if you only use one GPU
+                 port=12001,
+                 mem_util=0.8
+             )
+             # draw sample using the base model
+             llm.draw_sample('hello')
+
+             # load an adapter and draw a sample
+             llm.load_lora_adapter('adapter_1', '/path/to/adapter')
+             llm.draw_sample('hello', lora_name='adapter_1')
+
+             # unload the adapter
+             llm.unload_lora_adapter('adapter_1')
+
+             # release resources
+             llm.close()
+         """
+         self._model_path = model_path
+         self._port = port
+         self._gpus = gpus
+         self._tokenizer_path = (
+             tokenizer_path if tokenizer_path is not None else model_path
+         )
+         self._max_model_len = max_model_len
+         self._max_lora_rank = max_lora_rank
+         self._host = host
+         self._mem_util = mem_util
+         self._deploy_timeout_seconds = deploy_timeout_seconds
+         self._enforce_eager = enforce_eager
+         self._vllm_log_level = vllm_log_level
+         self._silent_mode = silent_mode
+         self._env_variable_dict = env_variable_dict
+         self._vllm_serve_args = vllm_serve_args
+         self._vllm_serve_kwargs = vllm_serve_kwargs
+         self._chat_template_kwargs = chat_template_kwargs
+         self._vllm_server_process = None
+
+         # Deploy vLLM
+         if launch_vllm_in_init:
+             self.launch_vllm_server()
+
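As a rough illustration of the optional keyword-only knobs documented above, a constructor call might look like the sketch below. The model path, cache directory, and reasoning-parser value are placeholders, not values shipped with the package.

    llm = VLLMServer(
        model_path="/models/my-model",                    # placeholder path
        port=12001,
        gpus=[0, 1],
        max_lora_rank=16,                                 # enables dynamic LoRA loading
        silent_mode=True,                                 # discard vLLM server output
        env_variable_dict={"HF_HOME": "/data/hf-cache"},  # placeholder env var
        vllm_serve_kwargs={"--reasoning-parser": "deepseek-r1"},
        chat_template_kwargs={"enable_thinking": False},
    )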
+     def launch_vllm_server(self, detach: bool = False, skip_if_running: bool = False):
+         """Launch the vLLM server, optionally detached, and wait until it is healthy."""
+         try:
+             import vllm  # noqa: F401
+         except ImportError as e:
+             raise ImportError(
+                 "vLLM is required to launch the server but is not installed."
+             ) from e
+
+         if skip_if_running and self._is_server_running():
+             print(
+                 f"[vLLM] Server already running on http://{self._host}:{self._port}. "
+                 f"Skipping launch."
+             )
+             return
+
+         self._detached = detach
+         self._vllm_server_process = self._launch_vllm(detach=detach)
+         self._wait_for_vllm()
+
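When the server should outlive the Python process, or should only be started if nothing is already listening, a deferred launch along the following lines fits the parameters above. A sketch; the model path is a placeholder.

    llm = VLLMServer(
        model_path="/models/my-model",   # placeholder path
        port=12001,
        gpus=0,
        launch_vllm_in_init=False,       # do not launch during __init__
    )
    # Reuse an already-running server if one answers /health; otherwise start a
    # detached server that close() will deliberately leave running.
    llm.launch_vllm_server(detach=True, skip_if_running=True)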
+     def _launch_vllm(self, detach: bool = False):
+         """Launch a vLLM server and return the subprocess."""
+         if isinstance(self._gpus, int):
+             gpus = str(self._gpus)
+         else:
+             gpus = ",".join([str(g) for g in self._gpus])
+
+         executable_path = sys.executable
+         cmd = [
+             executable_path,
+             "-m",
+             "vllm.entrypoints.openai.api_server",
+             "--model",
+             self._model_path,
+             "--tokenizer",
+             self._tokenizer_path,
+             "--max-model-len",
+             str(self._max_model_len),
+             "--host",
+             self._host,
+             "--port",
+             str(self._port),
+             "--gpu-memory-utilization",
+             str(self._mem_util),
+             "--tensor-parallel-size",
+             str(len(self._gpus)) if isinstance(self._gpus, list) else "1",
+             "--trust-remote-code",
+             "--chat-template-content-format",
+             "string",
+         ]
+
+         if self._enforce_eager:
+             cmd.append("--enforce-eager")
+
+         # Other args for vllm serve
+         if self._vllm_serve_args is not None:
+             for arg in self._vllm_serve_args:
+                 cmd.append(arg)
+
+         # Other kwargs for vllm serve
+         if self._vllm_serve_kwargs is not None:
+             for kwarg, value in self._vllm_serve_kwargs.items():
+                 cmd.extend([kwarg, value])
+
+         # Environment variables
+         env = os.environ.copy()
+         env["CUDA_VISIBLE_DEVICES"] = gpus
+         env["VLLM_LOGGING_LEVEL"] = self._vllm_log_level
+
+         # FIXME: This code is required for my machine :(
+         # FIXME: This may be due to a bad NCCL configuration :(
+         if isinstance(self._gpus, list) and len(self._gpus) > 1:
+             # set NCCL environment variable
+             env["NCCL_P2P_DISABLE"] = "1"
+             # disable custom all-reduce
+             cmd.append("--disable-custom-all-reduce")
+
+         # Enable dynamic LoRA loading
+         if self._max_lora_rank is not None:
+             cmd.extend(
+                 [
+                     "--enable-lora",
+                     "--max-lora-rank",
+                     str(self._max_lora_rank),
+                 ]
+             )
+             env["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
+
+         # Other env variables
+         if self._env_variable_dict is not None:
+             for k, v in self._env_variable_dict.items():
+                 env[k] = v
+
+         _print_cmd_list(cmd, gpus=self._gpus, host=self._host, port=self._port)
+
+         # Launch vLLM using subprocess
+         stdout = Path(os.devnull).open("w") if self._silent_mode else None
+         preexec_fn = os.setsid if detach and sys.platform != "win32" else None
+         proc = subprocess.Popen(
+             cmd, env=env, stdout=stdout, stderr=subprocess.STDOUT, preexec_fn=preexec_fn
+         )
+         return proc
+
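For reference, with an assumed two-GPU configuration (the model path and port are placeholders) and the defaults above, the subprocess assembled by _launch_vllm is roughly equivalent to running:

    # CUDA_VISIBLE_DEVICES=0,1 VLLM_LOGGING_LEVEL=INFO NCCL_P2P_DISABLE=1 \
    # python -m vllm.entrypoints.openai.api_server \
    #     --model /models/my-model --tokenizer /models/my-model \
    #     --max-model-len 16384 --host 0.0.0.0 --port 12001 \
    #     --gpu-memory-utilization 0.85 --tensor-parallel-size 2 \
    #     --trust-remote-code --chat-template-content-format string \
    #     --disable-custom-all-reduce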
+     def _kill_vllm_process(self):
+         """Terminate the vLLM server process and any of its child processes."""
+         if getattr(self, "_detached", False):
+             print(
+                 f"[vLLM] Server on port {self._port} is detached. Not killing process."
+             )
+             return
+
+         # Get child processes before terminating the parent
+         try:
+             parent = psutil.Process(self._vllm_server_process.pid)
+             children = parent.children(recursive=True)
+         except psutil.NoSuchProcess:
+             children = []
+
+         # Terminate the parent process
+         try:
+             self._vllm_server_process.terminate()
+             self._vllm_server_process.wait(timeout=5)
+             print(f"[vLLM] terminated process: {self._vllm_server_process.pid}")
+         except subprocess.TimeoutExpired:
+             self._vllm_server_process.kill()
+             print(f"[vLLM] killed process: {self._vllm_server_process.pid}")
+
+         # Kill any remaining children
+         for child in children:
+             try:
+                 child.terminate()
+                 child.wait(timeout=2)
+             except (psutil.NoSuchProcess, psutil.TimeoutExpired):
+                 try:
+                     child.kill()
+                 except psutil.NoSuchProcess:
+                     pass
+
+     def _is_server_running(self):
+         """Check if a vLLM server is already running on the given host and port."""
+         health = f"http://{self._host}:{self._port}/health"
+         try:
+             if requests.get(health, timeout=1).status_code == 200:
+                 return True
+         except requests.exceptions.RequestException:
+             pass
+         return False
+
+     def _wait_for_vllm(self):
+         """Poll the vLLM server process and its /health endpoint. Kill the server process on timeout."""
+         for _ in range(self._deploy_timeout_seconds):
+             # check process status
+             if self._vllm_server_process.poll() is not None:
+                 sys.exit(
+                     f"[vLLM] crashed (exit {self._vllm_server_process.returncode})"
+                 )
+
+             # check server status
+             if self._is_server_running():
+                 return
+             time.sleep(1)
+
+         # Server failed to initialize within the timeout
+         print("[vLLM] failed to start within timeout")
+         self._kill_vllm_process()
+         sys.exit("[vLLM] failed to start within timeout")
+
+     def unload_lora_adapter(self, lora_name: str):
+         """Unload a LoRA adapter given its name.
+
+         Args:
+             lora_name: LoRA adapter name.
+         """
+         lora_api_url = f"http://{self._host}:{self._port}/v1/unload_lora_adapter"
+         headers = {"Content-Type": "application/json"}
+         try:
+             payload = {"lora_name": lora_name}
+             requests.post(lora_api_url, json=payload, headers=headers, timeout=10)
+         except requests.exceptions.RequestException:
+             pass
+
+     def load_lora_adapter(
+         self, lora_name: str, new_adapter_path: str, num_trials: int = 5
+     ):
+         """Dynamically load a LoRA adapter.
+
+         Args:
+             lora_name: LoRA adapter name.
+             new_adapter_path: Path to the new LoRA adapter weights.
+             num_trials: Number of attempts before giving up.
+
+         Returns:
+             True if the adapter was loaded successfully, False otherwise.
+         """
+         if self._max_lora_rank is None:
+             raise ValueError(
+                 'LoRA is not enabled for this VLLMServer instance, since "max_lora_rank" is not set.'
+             )
+
+         # First unload any adapter registered under this name
+         self.unload_lora_adapter(lora_name)
+
+         # Prepare the payload for the LoRA update
+         payload = {"lora_name": lora_name, "lora_path": new_adapter_path}
+         headers = {"Content-Type": "application/json"}
+         lora_api_url = f"http://{self._host}:{self._port}/v1/load_lora_adapter"
+
+         # Repeatedly try to load the LoRA adapter
+         for _ in range(num_trials):
+             try:
+                 response = requests.post(
+                     lora_api_url, json=payload, headers=headers, timeout=60
+                 )
+                 if response.status_code == 200:
+                     print(
+                         f"[vLLM] Successfully loaded LoRA adapter: {lora_name} from {new_adapter_path}"
+                     )
+                     return True
+                 print(
+                     f"[vLLM] Failed to load LoRA adapter. "
+                     f"Status code: {response.status_code}, Response: {response.text}"
+                 )
+             except requests.exceptions.RequestException:
+                 continue
+
+         print(f"[vLLM] Failed to load LoRA adapter {lora_name} after {num_trials} attempts.")
+         return False
+
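Putting the two adapter endpoints together, a typical loop might swap adapters as sketched below. The adapter name and checkpoint path are placeholders, and the server must have been created with max_lora_rank set.

    if llm.load_lora_adapter("adapter_1", "/checkpoints/step_100"):   # placeholder path
        reply = llm.chat_completion("hello", lora_name="adapter_1")
        print(reply)
        llm.unload_lora_adapter("adapter_1")
    else:
        print("Adapter could not be loaded; falling back to the base model.")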
+     def close(self):
+         """Shut down the vLLM server and kill its processes."""
+         if self._vllm_server_process is not None:
+             self._kill_vllm_process()
+
+     def reload(self):
+         """Relaunch the vLLM server."""
+         self.launch_vllm_server()
+
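Because close() is what stops the launched server process (unless it was detached), wrapping usage in try/finally is a reasonable pattern. A sketch, with the same placeholder path as above:

    llm = VLLMServer(model_path="/models/my-model", port=12001, gpus=0)
    try:
        print(llm.chat_completion("hello"))
    finally:
        llm.close()   # always release the GPU-backed server process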
+     def chat_completion(
+         self,
+         message: str | List[openai.types.chat.ChatCompletionMessageParam],
+         max_tokens: Optional[int] = None,
+         timeout_seconds: Optional[int] = None,
+         lora_name: Optional[str] = None,
+         temperature: float = 0.9,
+         top_p: float = 0.9,
+         chat_template_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> str:
+         """Send a chat completion query in OpenAI format to the vLLM server.
+         Return the response content.
+
+         Args:
+             message: The message as a str, or a list of messages in OpenAI format.
+             max_tokens: The maximum number of tokens to generate.
+             timeout_seconds: Request timeout in seconds.
+             lora_name: LoRA adapter name. Defaults to None, which uses the base model.
+             temperature: The temperature parameter.
+             top_p: The top-p parameter.
+             chat_template_kwargs: The chat template kwargs, e.g., {'enable_thinking': False}.
+         """
+         data: Dict[str, Any] = {
+             "messages": (
+                 [{"role": "user", "content": message.strip()}]
+                 if isinstance(message, str)
+                 else message
+             ),
+             "temperature": temperature,
+             "top_p": top_p,
+             "max_tokens": max_tokens,
+         }
+
+         # Use the specified LoRA adapter
+         if lora_name is not None:
+             data["model"] = lora_name
+         # Chat template keyword args
+         if self._chat_template_kwargs is not None:
+             data["chat_template_kwargs"] = self._chat_template_kwargs
+         elif chat_template_kwargs is not None:
+             data["chat_template_kwargs"] = chat_template_kwargs
+         # Request
+         url = f"http://{self._host}:{self._port}/v1/chat/completions"
+         headers = {"Content-Type": "application/json"}
+         response = requests.post(
+             url, headers=headers, json=data, timeout=timeout_seconds
+         )
+         response.raise_for_status()
+         return response.json()["choices"][0]["message"]["content"]
+
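Beyond the plain-string form, the message argument can also be a list of OpenAI-style messages. A minimal sketch, given an llm instance as above (the system prompt is arbitrary):

    messages = [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Summarize what vLLM does in one sentence."},
    ]
    answer = llm.chat_completion(
        messages,
        max_tokens=128,
        temperature=0.2,
        timeout_seconds=60,
    )
    print(answer)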
+     def embedding(
+         self,
+         text: str | List[str],
+         dimensions: Optional[int] = None,
+         timeout_seconds: Optional[float] = None,
+         lora_name: Optional[str] = None,
+         **kwargs,
+     ) -> List[float] | List[List[float]]:
+         """Generate embeddings for the given text(s).
+
+         Args:
+             text: The text or a list of texts to embed.
+             dimensions: The number of dimensions for the output embeddings.
+             timeout_seconds: Request timeout in seconds.
+             lora_name: LoRA adapter name. Defaults to None, which uses the base model.
+
+         Returns:
+             The embedding for the text, or a list of embeddings for the list of texts.
+         """
+         is_str_input = isinstance(text, str)
+         if is_str_input:
+             text = [text]
+
+         # Prepare arguments for the API call
+         data: Dict[str, Any] = {"input": text}
+
+         if dimensions is not None:
+             data["dimensions"] = dimensions
+
+         if lora_name is not None:
+             data["model"] = lora_name
+
+         data.update(kwargs)
+
+         url = f"http://{self._host}:{self._port}/v1/embeddings"
+         headers = {"Content-Type": "application/json"}
+
+         response = requests.post(
+             url, headers=headers, json=data, timeout=timeout_seconds
+         )
+         response.raise_for_status()
+
+         response_json = response.json()
+         embeddings = [item["embedding"] for item in response_json["data"]]
+
+         if is_str_input:
+             return embeddings[0]
+         return embeddings
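A single string returns one vector and a list returns one vector per item, so both call styles below should work against a server hosting an embedding model. A sketch; the cosine helper is illustrative and not part of the package.

    single = llm.embedding("hello world")               # List[float]
    batch = llm.embedding(["hello world", "goodbye"])   # List[List[float]]

    # Illustrative cosine similarity between the two batch items.
    import math

    def cosine(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b)))

    print(cosine(batch[0], batch[1]))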