py-adtools 0.3.2 (py3-none-any.whl)

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,118 @@
+ """
+ Copyright (c) 2025 Rui Zhang <rzhang.cs@gmail.com>
+
+ NOTICE: This code is under the MIT license. This code is intended for academic/research purposes only.
+ Commercial use of this software or its derivatives requires prior written permission.
+ """
+
+ import logging
+ import os
+ from typing import List, Optional, Dict, Any
+
+ import openai.types.chat
+
+ from adtools.lm.lm_base import LanguageModel
+
+ logging.getLogger("httpx").setLevel(logging.WARNING)
+
+
+ class OpenAIAPI(LanguageModel):
+     def __init__(
+         self,
+         model: str,
+         base_url: Optional[str] = None,
+         api_key: Optional[str] = None,
+         **openai_init_kwargs,
+     ):
+         super().__init__()
+         # If base_url is None, fall back to the 'OPENAI_BASE_URL' environment variable
+         if base_url is None:
+             if "OPENAI_BASE_URL" not in os.environ:
+                 raise RuntimeError(
+                     'If "base_url" is None, the environment variable OPENAI_BASE_URL must be set.'
+                 )
+             else:
+                 base_url = os.environ["OPENAI_BASE_URL"]
+
+         # If api_key is None, fall back to the 'OPENAI_API_KEY' environment variable
+         if api_key is None:
+             if "OPENAI_API_KEY" not in os.environ:
+                 raise RuntimeError(
+                     'If "api_key" is None, the environment variable OPENAI_API_KEY must be set.'
+                 )
+             else:
+                 api_key = os.environ["OPENAI_API_KEY"]
+
+         self._model = model
+         self._client = openai.OpenAI(
+             api_key=api_key, base_url=base_url, **openai_init_kwargs
+         )
+
+     def chat_completion(
+         self,
+         message: str | List[openai.types.chat.ChatCompletionMessageParam],
+         max_tokens: Optional[int] = None,
+         timeout_seconds: Optional[float] = None,
+         *args,
+         **kwargs,
+     ):
+         """Send a chat completion query in OpenAI format to an OpenAI-compatible server.
+         Return the response content.
+
+         Args:
+             message: The message as a str or in OpenAI message format.
+             max_tokens: The maximum number of tokens to generate.
+             timeout_seconds: The timeout in seconds.
+         """
+         if isinstance(message, str):
+             message = [{"role": "user", "content": message.strip()}]
+
+         response = self._client.chat.completions.create(
+             model=self._model,
+             messages=message,
+             stream=False,
+             max_tokens=max_tokens,
+             timeout=timeout_seconds,
+             *args,
+             **kwargs,
+         )
+         return response.choices[0].message.content
+
+     def embedding(
+         self,
+         text: str | List[str],
+         dimensions: Optional[int] = None,
+         timeout_seconds: Optional[float] = None,
+         **kwargs,
+     ) -> List[float] | List[List[float]]:
+         """Generate embeddings for the given text(s) using the model specified during initialization.
+
+         Args:
+             text: The text or a list of texts to embed.
+             dimensions: The number of dimensions for the output embeddings.
+             timeout_seconds: The timeout in seconds.
+
+         Returns:
+             The embedding for the text, or a list of embeddings for the list of texts.
+         """
+         is_str_input = isinstance(text, str)
+         if is_str_input:
+             text = [text]
+
+         # Prepare arguments for the OpenAI API call
+         api_kwargs: Dict[str, Any] = {
+             "input": text,
+             "model": self._model,
+         }
+
+         if dimensions is not None:
+             api_kwargs["dimensions"] = dimensions
+         if timeout_seconds is not None:
+             api_kwargs["timeout"] = timeout_seconds
+
+         api_kwargs.update(kwargs)
+         response = self._client.embeddings.create(**api_kwargs)
+         embeddings = [item.embedding for item in response.data]
+
+         if is_str_input:
+             return embeddings[0]
+         return embeddings
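
Below is a minimal usage sketch of the OpenAIAPI class added above, assuming the environment variables it documents are set. The module path and model names are illustrative assumptions; the diff does not show file names.

# Hypothetical usage of the OpenAIAPI wrapper defined above. The module
# path "adtools.lm.openai_api" is an assumption based on the package layout.
from adtools.lm.openai_api import OpenAIAPI

# base_url and api_key fall back to OPENAI_BASE_URL / OPENAI_API_KEY.
lm = OpenAIAPI(model="gpt-4o-mini")  # placeholder model name
print(lm.chat_completion("Say hello in one sentence.", max_tokens=64))

# The construction-time model name is also used for embeddings, so point a
# separate instance at an embedding model when embeddings are needed.
embedder = OpenAIAPI(model="text-embedding-3-small")  # placeholder
vec = embedder.embedding("hello world", dimensions=256)
print(len(vec))  # 256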
@@ -0,0 +1,423 @@
+ """
+ Copyright (c) 2025 Rui Zhang <rzhang.cs@gmail.com>
+
+ NOTICE: This code is under the MIT license. This code is intended for academic/research purposes only.
+ Commercial use of this software or its derivatives requires prior written permission.
+ """
+
+ from typing import Optional, List, Literal, Dict, Any
+ import os
+ import subprocess
+ import sys
+ import psutil
+ import time
+
+ import openai.types.chat
+ import requests
+
+ from adtools.lm.lm_base import LanguageModel
+
+
+ def _print_cmd_list(cmd_list, gpus, host, port):
+     print("\n" + "=" * 80)
+     print(f"[SGLang] Launching SGLang on GPU:{gpus}; URL: http://{host}:{port}")
+     print("=" * 80)
+     # Pretty-print the command with one argument per line
+     print(" \\\n    ".join(cmd_list))
+     print("=" * 80 + "\n", flush=True)
+
+
+ class SGLangServer(LanguageModel):
+     def __init__(
+         self,
+         model_path: str,
+         port: int,
+         gpus: int | List[int],
+         tokenizer_path: Optional[str] = None,
+         context_length: int = 16384,
+         max_lora_rank: Optional[int] = None,
+         host: str = "0.0.0.0",
+         mem_fraction_static: float = 0.85,
+         deploy_timeout_seconds: int = 600,
+         *,
+         launch_sglang_in_init: bool = True,
+         sglang_log_level: Literal["debug", "info", "warning", "error"] = "info",
+         silent_mode: bool = False,
+         env_variable_dict: Optional[Dict[str, str]] = None,
+         sglang_serve_args: Optional[List[str]] = None,
+         sglang_serve_kwargs: Optional[Dict[str, str]] = None,
+         chat_template_kwargs: Optional[Dict[str, Any]] = None,
+     ):
+ """Deploy an SGLang server on specified GPUs.
55
+
56
+ Args:
57
+ model_path: Path to the model to deploy.
58
+ port: Port to deploy.
59
+ gpus: List of GPUs to deploy.
60
+ tokenizer_path: Path to the tokenizer. Defaults to model_path.
61
+ context_length: The context length (mapped to --context-length).
62
+ max_lora_rank: Max rank of LoRA adapter. Defaults to `None`.
63
+ Set this to enable LoRA support (mapped to --max-lora-rank).
64
+ host: Host address for SGLang server.
65
+ mem_fraction_static: The memory fraction for static allocation (mapped to --mem-fraction-static).
66
+ deploy_timeout_seconds: Timeout to deploy (in seconds).
67
+ launch_sglang_in_init: Launch SGLang during initialization of this class.
68
+ sglang_log_level: Log level.
69
+ silent_mode: Silent mode.
70
+ env_variable_dict: Environment variables to use.
71
+ sglang_serve_args: Additional arguments to pass to sglang server, e.g., ['--enable-flashinfer'],
72
+ or ['--attention-backend', 'triton']
73
+ sglang_serve_kwargs: Keyword arguments to pass to sglang server.
74
+ chat_template_kwargs: Keyword arguments for chat template (passed during request).
75
+
76
+ Example:
77
+ # deploy a model on GPU 0 and 1 with LoRA support
78
+ llm = SGLangServer(
79
+ model_path='meta-llama/Meta-Llama-3-8B-Instruct',
80
+ port=30000,
81
+ gpus=[0, 1],
82
+ max_lora_rank=16, # Enable LoRA
83
+ sglang_serve_args=['--attention-backend', 'triton']
84
+ )
85
+
86
+ # Load an adapter
87
+ llm.load_lora_adapter("my_adapter", "/path/to/adapter")
88
+
89
+ # Use the adapter
90
+ llm.chat_completion("Hello", lora_name="my_adapter")
91
+ """
92
+         self._model_path = model_path
+         self._port = port
+         self._gpus = gpus
+         self._tokenizer_path = (
+             tokenizer_path if tokenizer_path is not None else model_path
+         )
+         self._context_length = context_length
+         self._max_lora_rank = max_lora_rank
+         self._host = host
+         self._mem_fraction_static = mem_fraction_static
+         self._deploy_timeout_seconds = deploy_timeout_seconds
+         self._sglang_log_level = sglang_log_level
+         self._silent_mode = silent_mode
+         self._env_variable_dict = env_variable_dict
+         self._sglang_serve_args = sglang_serve_args
+         self._sglang_serve_kwargs = sglang_serve_kwargs
+         self._chat_template_kwargs = chat_template_kwargs
+         self._server_process = None
+
+         # Deploy SGLang
+         if launch_sglang_in_init:
+             self.launch_sglang_server()
+
+     def launch_sglang_server(self, detach: bool = False, skip_if_running: bool = False):
+         try:
+             import sglang  # noqa: F401  # only to verify that sglang is installed
+         except ImportError as e:
+             raise ImportError(
+                 "The 'sglang' package is required to launch an SGLang server."
+             ) from e
+
+         if skip_if_running and self._is_server_running():
+             print(
+                 f"[SGLang] Server already running on http://{self._host}:{self._port}. "
+                 f"Skipping launch."
+             )
+             return
+
+         self._detached = detach
+         self._server_process = self._launch_sglang(detach=detach)
+         self._wait_for_server()
+
+     def _launch_sglang(self, detach: bool = False):
+         """Launch an SGLang server and return the subprocess."""
+         if isinstance(self._gpus, int):
+             gpus = str(self._gpus)
+             tp_size = 1
+         else:
+             gpus = ",".join([str(g) for g in self._gpus])
+             tp_size = len(self._gpus)
+
+         executable_path = sys.executable
+
+         # SGLang launch command structure
+         cmd = [
+             executable_path,
+             "-m",
+             "sglang.launch_server",
+             "--model-path",
+             self._model_path,
+             "--tokenizer-path",
+             self._tokenizer_path,
+             "--port",
+             str(self._port),
+             "--host",
+             self._host,
+             "--context-length",
+             str(self._context_length),
+             "--mem-fraction-static",
+             str(self._mem_fraction_static),
+             "--tp",
+             str(tp_size),
+             "--trust-remote-code",
+         ]
+
+         # Enable LoRA support if a rank is specified
+         if self._max_lora_rank is not None:
+             cmd.extend(["--max-lora-rank", str(self._max_lora_rank)])
+             # Older SGLang versions sometimes require disabling the radix cache for
+             # LoRA; newer versions support it. If you face issues, consider adding
+             # "--disable-radix-cache".
+
+         # Other args for sglang serve
+         if self._sglang_serve_args is not None:
+             for arg in self._sglang_serve_args:
+                 cmd.append(arg)
+
+         # Other kwargs for sglang serve
+         if self._sglang_serve_kwargs is not None:
+             for kwarg, value in self._sglang_serve_kwargs.items():
+                 cmd.extend([kwarg, value])
+
+         # Environment variables
+         env = os.environ.copy()
+         env["CUDA_VISIBLE_DEVICES"] = gpus
+         env["LOG_LEVEL"] = self._sglang_log_level.upper()
+
+         # Work around NCCL peer-to-peer issues when using multiple GPUs
+         if tp_size > 1:
+             env["NCCL_P2P_DISABLE"] = "1"
+
+         if self._env_variable_dict is not None:
+             for k, v in self._env_variable_dict.items():
+                 env[k] = v
+
+         _print_cmd_list(cmd, gpus=self._gpus, host=self._host, port=self._port)
+
+         # Launch using subprocess; discard output in silent mode
+         stdout = subprocess.DEVNULL if self._silent_mode else None
+         preexec_fn = os.setsid if detach and sys.platform != "win32" else None
+         proc = subprocess.Popen(
+             cmd, env=env, stdout=stdout, stderr=subprocess.STDOUT, preexec_fn=preexec_fn
+         )
+         return proc
+
+     def _kill_process(self):
+         if getattr(self, "_detached", False):
+             print(
+                 f"[SGLang] Server on port {self._port} is detached. Not killing process."
+             )
+             return
+
+         try:
+             # Get child processes before terminating the parent
+             try:
+                 parent = psutil.Process(self._server_process.pid)
+                 children = parent.children(recursive=True)
+             except psutil.NoSuchProcess:
+                 children = []
+
+             # Terminate the parent process
+             self._server_process.terminate()
+             self._server_process.wait(timeout=5)
+             print(f"[SGLang] terminated process: {self._server_process.pid}")
+
+             # Kill any remaining children
+             for child in children:
+                 try:
+                     child.terminate()
+                     child.wait(timeout=2)
+                 except (psutil.NoSuchProcess, psutil.TimeoutExpired):
+                     try:
+                         child.kill()
+                     except psutil.NoSuchProcess:
+                         pass
+         except subprocess.TimeoutExpired:
+             self._server_process.kill()
+             print(f"[SGLang] killed process: {self._server_process.pid}")
+
+     def _is_server_running(self):
+         """Check if an SGLang server is already running on the given host and port."""
+         health = f"http://{self._host}:{self._port}/health"
+         try:
+             if requests.get(health, timeout=1).status_code == 200:
+                 return True
+         except requests.exceptions.RequestException:
+             pass
+         return False
+
+     def _wait_for_server(self):
+         """Poll the /health endpoint until the server is up, crashes, or the timeout expires."""
+         for _ in range(self._deploy_timeout_seconds):
+             if self._server_process.poll() is not None:
+                 sys.exit(f"[SGLang] crashed (exit {self._server_process.returncode})")
+
+             if self._is_server_running():
+                 return
+             time.sleep(1)
+
+         print("[SGLang] failed to start within timeout")
+         self._kill_process()
+         sys.exit("[SGLang] failed to start within timeout")
+
+     def unload_lora_adapter(self, lora_name: str):
+         """Unload a LoRA adapter by name via the native SGLang endpoint.
+
+         Args:
+             lora_name: LoRA adapter name.
+         """
+         # Note: SGLang native endpoints are often at the root, not under /v1
+         url = f"http://{self._host}:{self._port}/unload_lora_adapter"
+         headers = {"Content-Type": "application/json"}
+         try:
+             payload = {"lora_name": lora_name}
+             response = requests.post(url, json=payload, headers=headers, timeout=10)
+             if response.status_code == 200:
+                 print(f"[SGLang] Unloaded LoRA adapter: {lora_name}")
+             else:
+                 print(f"[SGLang] Failed to unload LoRA: {response.text}")
+         except requests.exceptions.RequestException as e:
+             print(f"[SGLang] Error unloading LoRA: {e}")
+
+     def load_lora_adapter(
+         self, lora_name: str, new_adapter_path: str, num_trails: int = 5
+     ):
+         """Dynamically load a LoRA adapter via the native SGLang endpoint.
+
+         Args:
+             lora_name: LoRA adapter name.
+             new_adapter_path: Path to the new LoRA adapter weights.
+             num_trails: Maximum number of attempts.
+
+         Returns:
+             True if the adapter was loaded successfully, False otherwise.
+         """
+         if self._max_lora_rank is None:
+             raise ValueError(
+                 'LoRA is not enabled. Please set "max_lora_rank" in __init__.'
+             )
+
+         # Unload first to ensure a clean state (optional but safer for updates)
+         self.unload_lora_adapter(lora_name)
+
+         url = f"http://{self._host}:{self._port}/load_lora_adapter"
+         headers = {"Content-Type": "application/json"}
+         payload = {"lora_name": lora_name, "lora_path": new_adapter_path}
+
+         for _ in range(num_trails):
+             try:
+                 response = requests.post(url, json=payload, headers=headers, timeout=60)
+                 if response.status_code == 200:
+                     print(
+                         f"[SGLang] Successfully loaded LoRA adapter: {lora_name} from {new_adapter_path}"
+                     )
+                     return True
+                 else:
+                     print(
+                         f"[SGLang] Failed to load LoRA adapter. "
+                         f"Status code: {response.status_code}, Response: {response.text}"
+                     )
+                     # Don't retry if it's a client error (4xx)
+                     if 400 <= response.status_code < 500:
+                         return False
+             except requests.exceptions.RequestException:
+                 time.sleep(1)
+                 continue
+
+         print(f"[SGLang] Failed to load LoRA adapter after {num_trails} attempts.")
+         return False
+
+     def close(self):
+         """Shut down the SGLang server and kill all of its processes."""
+         if self._server_process is not None:
+             self._kill_process()
+
+     def reload(self):
+         """Relaunch the SGLang server."""
+         self.launch_sglang_server()
+
+     def chat_completion(
+         self,
+         message: str | List[openai.types.chat.ChatCompletionMessageParam],
+         max_tokens: Optional[int] = None,
+         timeout_seconds: Optional[float] = None,
+         lora_name: Optional[str] = None,
+         temperature: float = 0.9,
+         top_p: float = 0.9,
+         chat_template_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> str:
+         """Send a chat completion query to the SGLang server.
+
+         Args:
+             message: The message as a str or in OpenAI message format.
+             max_tokens: The maximum number of tokens to generate.
+             timeout_seconds: The timeout in seconds.
+             lora_name: Name of the LoRA adapter to use for this request.
+             temperature: The temperature parameter.
+             top_p: The top-p parameter.
+             chat_template_kwargs: Chat template kwargs for this request. Ignored if
+                 chat_template_kwargs was already set in __init__.
+         """
+         # Wrap a plain string as a single user message; a list is assumed to
+         # already be a list of messages in OpenAI format.
+         if isinstance(message, str):
+             message = [{"role": "user", "content": message.strip()}]
+
+         data: Dict[str, Any] = {
+             "messages": message,
+             "temperature": temperature,
+             "top_p": top_p,
+             "max_tokens": max_tokens,
+         }
+
+         # In SGLang's OpenAI-compatible API, the 'model' parameter routes to the adapter
+         if lora_name is not None:
+             data["model"] = lora_name
+         else:
+             data["model"] = self._model_path
+
+         if self._chat_template_kwargs is not None:
+             data["chat_template_kwargs"] = self._chat_template_kwargs
+         elif chat_template_kwargs is not None:
+             data["chat_template_kwargs"] = chat_template_kwargs
+
+         url = f"http://{self._host}:{self._port}/v1/chat/completions"
+         headers = {"Content-Type": "application/json"}
+
+         response = requests.post(
+             url, headers=headers, json=data, timeout=timeout_seconds
+         )
+         response.raise_for_status()
+         return response.json()["choices"][0]["message"]["content"]
+
+     def embedding(
+         self,
+         text: str | List[str],
+         dimensions: Optional[int] = None,
+         timeout_seconds: Optional[float] = None,
+         lora_name: Optional[str] = None,
+         **kwargs,
+     ) -> List[float] | List[List[float]]:
+         """Generate embeddings for the given text(s)."""
+         is_str_input = isinstance(text, str)
+         if is_str_input:
+             text = [text]
+
+         data: Dict[str, Any] = {
+             "input": text,
+             "model": lora_name if lora_name else self._model_path,
+         }
+
+         if dimensions is not None:
+             data["dimensions"] = dimensions
+
+         data.update(kwargs)
+
+         url = f"http://{self._host}:{self._port}/v1/embeddings"
+         headers = {"Content-Type": "application/json"}
+
+         response = requests.post(
+             url, headers=headers, json=data, timeout=timeout_seconds
+         )
+         response.raise_for_status()
+
+         response_json = response.json()
+         embeddings = [item["embedding"] for item in response_json["data"]]
+
+         if is_str_input:
+             return embeddings[0]
+         return embeddings
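
Below is a minimal end-to-end sketch of the SGLangServer class added above. The module path, model name, GPU indices, and adapter path are illustrative assumptions; the diff does not show file names.

# Hypothetical usage of the SGLangServer wrapper defined above. The module
# path "adtools.lm.sglang_server" is an assumption based on the package layout.
from adtools.lm.sglang_server import SGLangServer

llm = SGLangServer(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    port=30000,
    gpus=[0, 1],       # tensor parallelism across two GPUs (--tp 2)
    max_lora_rank=16,  # enables LoRA support (--max-lora-rank 16)
)

try:
    # Plain request served by the base model
    print(llm.chat_completion("Hello!", max_tokens=64))

    # Load an adapter at runtime, then route a request to it by name
    if llm.load_lora_adapter("my_adapter", "/path/to/adapter"):
        print(llm.chat_completion("Hello!", lora_name="my_adapter"))
finally:
    llm.close()  # terminate the server process and its children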