py-adtools 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adtools/__init__.py +1 -0
- adtools/cli.py +61 -0
- adtools/evaluator/__init__.py +2 -0
- adtools/evaluator/auto_server.py +258 -0
- adtools/evaluator/py_evaluator.py +170 -0
- adtools/evaluator/py_evaluator_ray.py +110 -0
- adtools/lm/__init__.py +4 -0
- adtools/lm/lm_base.py +63 -0
- adtools/lm/openai_api.py +118 -0
- adtools/lm/sglang_server.py +423 -0
- adtools/lm/vllm_server.py +452 -0
- adtools/py_code.py +577 -0
- adtools/sandbox/__init__.py +2 -0
- adtools/sandbox/sandbox_executor.py +244 -0
- adtools/sandbox/sandbox_executor_ray.py +194 -0
- adtools/sandbox/utils.py +32 -0
- py_adtools-0.3.2.dist-info/METADATA +567 -0
- py_adtools-0.3.2.dist-info/RECORD +22 -0
- py_adtools-0.3.2.dist-info/WHEEL +5 -0
- py_adtools-0.3.2.dist-info/entry_points.txt +2 -0
- py_adtools-0.3.2.dist-info/licenses/LICENSE +21 -0
- py_adtools-0.3.2.dist-info/top_level.txt +1 -0
adtools/lm/openai_api.py
ADDED
@@ -0,0 +1,118 @@
"""
Copyright (c) 2025 Rui Zhang <rzhang.cs@gmail.com>

NOTICE: This code is under MIT license. This code is intended for academic/research purposes only.
Commercial use of this software or its derivatives requires prior written permission.
"""

import logging
import os
from typing import List, Optional, Dict, Any

import openai.types.chat

from adtools.lm.lm_base import LanguageModel

logging.getLogger("httpx").setLevel(logging.WARNING)


class OpenAIAPI(LanguageModel):
    def __init__(
        self,
        model: str,
        base_url: str = None,
        api_key: str = None,
        **openai_init_kwargs,
    ):
        super().__init__()
        # If base_url is set to None, find 'OPENAI_BASE_URL' in environment variables
        if base_url is None:
            if "OPENAI_BASE_URL" not in os.environ:
                raise RuntimeError(
                    'If "base_url" is None, the environment variable OPENAI_BASE_URL must be set.'
                )
            else:
                base_url = os.environ["OPENAI_BASE_URL"]

        # If api_key is set to None, find 'OPENAI_API_KEY' in environment variables
        if api_key is None:
            if "OPENAI_API_KEY" not in os.environ:
                raise RuntimeError('If "api_key" is None, OPENAI_API_KEY must be set.')
            else:
                api_key = os.environ["OPENAI_API_KEY"]

        self._model = model
        self._client = openai.OpenAI(
            api_key=api_key, base_url=base_url, **openai_init_kwargs
        )

    def chat_completion(
        self,
        message: str | List[openai.types.chat.ChatCompletionMessageParam],
        max_tokens: Optional[int] = None,
        timeout_seconds: Optional[float] = None,
        *args,
        **kwargs,
    ):
        """Send a chat completion query with OpenAI format to the vLLM server.
        Return the response content.

        Args:
            message: The message in str or openai format.
            max_tokens: The maximum number of tokens to generate.
            timeout_seconds: The timeout seconds.
        """
        if isinstance(message, str):
            message = [{"role": "user", "content": message.strip()}]

        response = self._client.chat.completions.create(
            model=self._model,
            messages=message,
            stream=False,
            max_tokens=max_tokens,
            timeout=timeout_seconds,
            *args,
            **kwargs,
        )
        return response.choices[0].message.content

    def embedding(
        self,
        text: str | List[str],
        dimensions: Optional[int] = None,
        timeout_seconds: Optional[float] = None,
        **kwargs,
    ) -> List[float] | List[List[float]]:
        """Generate embeddings for the given text(s) using the model specified during initialization.

        Args:
            text: The text or a list of texts to embed.
            dimensions: The number of dimensions for the output embeddings.
            timeout_seconds: The timeout seconds.

        Returns:
            The embedding for the text, or a list of embeddings for the list of texts.
        """
        is_str_input = isinstance(text, str)
        if is_str_input:
            text = [text]

        # Prepare arguments for the OpenAI API call
        api_kwargs = {
            "input": text,
            "model": self._model,
        }
        api_kwargs: Dict[str, Any]

        if dimensions is not None:
            api_kwargs["dimensions"] = dimensions
        if timeout_seconds is not None:
            api_kwargs["timeout"] = timeout_seconds

        api_kwargs.update(kwargs)
        response = self._client.embeddings.create(**api_kwargs)
        embeddings = [item.embedding for item in response.data]

        if is_str_input:
            return embeddings[0]
        return embeddings
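A minimal usage sketch of OpenAIAPI, assuming the package is installed and that OPENAI_BASE_URL and OPENAI_API_KEY point at a reachable OpenAI-compatible endpoint; the model name below is a placeholder:

from adtools.lm.openai_api import OpenAIAPI

# base_url/api_key fall back to OPENAI_BASE_URL / OPENAI_API_KEY when not passed explicitly.
lm = OpenAIAPI(model="gpt-4o-mini")  # placeholder model name

# A plain string is wrapped into a single user message before the request is sent.
print(lm.chat_completion("Say hello in one short sentence.", max_tokens=32))

# A single string returns one embedding vector; a list of strings returns a list of vectors.
vector = lm.embedding("hello world")
print(len(vector))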
adtools/lm/sglang_server.py
ADDED
@@ -0,0 +1,423 @@
"""
Copyright (c) 2025 Rui Zhang <rzhang.cs@gmail.com>

NOTICE: This code is under MIT license. This code is intended for academic/research purposes only.
Commercial use of this software or its derivatives requires prior written permission.
"""

from typing import Optional, List, Literal, Dict, Any
import os
import subprocess
import sys
from pathlib import Path
import psutil
import time

import openai.types.chat
import requests

from adtools.lm.lm_base import LanguageModel


def _print_cmd_list(cmd_list, gpus, host, port):
    print("\n" + "=" * 80)
    print(f"[SGLang] Launching SGLang on GPU:{gpus}; URL: http://{host}:{port}")
    print("=" * 80)
    cmd = cmd_list[0] + " \\\n"
    for c in cmd_list[1:]:
        cmd += " " + c + " \\\n"
    print(cmd.strip())
    print("=" * 80 + "\n", flush=True)


class SGLangServer(LanguageModel):
    def __init__(
        self,
        model_path: str,
        port: int,
        gpus: int | list[int],
        tokenizer_path: Optional[str] = None,
        context_length: int = 16384,
        max_lora_rank: Optional[int] = None,
        host: str = "0.0.0.0",
        mem_fraction_static: float = 0.85,
        deploy_timeout_seconds: int = 600,
        *,
        launch_sglang_in_init=True,
        sglang_log_level: Literal["debug", "info", "warning", "error"] = "info",
        silent_mode: bool = False,
        env_variable_dict: Optional[Dict[str, str]] = None,
        sglang_serve_args: Optional[List[str]] = None,
        sglang_serve_kwargs: Optional[Dict[str, str]] = None,
        chat_template_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Deploy an SGLang server on specified GPUs.

        Args:
            model_path: Path to the model to deploy.
            port: Port to deploy.
            gpus: List of GPUs to deploy.
            tokenizer_path: Path to the tokenizer. Defaults to model_path.
            context_length: The context length (mapped to --context-length).
            max_lora_rank: Max rank of LoRA adapter. Defaults to `None`.
                Set this to enable LoRA support (mapped to --max-lora-rank).
            host: Host address for SGLang server.
            mem_fraction_static: The memory fraction for static allocation (mapped to --mem-fraction-static).
            deploy_timeout_seconds: Timeout to deploy (in seconds).
            launch_sglang_in_init: Launch SGLang during initialization of this class.
            sglang_log_level: Log level.
            silent_mode: Silent mode.
            env_variable_dict: Environment variables to use.
            sglang_serve_args: Additional arguments to pass to sglang server, e.g., ['--enable-flashinfer'],
                or ['--attention-backend', 'triton']
            sglang_serve_kwargs: Keyword arguments to pass to sglang server.
            chat_template_kwargs: Keyword arguments for chat template (passed during request).

        Example:
            # deploy a model on GPU 0 and 1 with LoRA support
            llm = SGLangServer(
                model_path='meta-llama/Meta-Llama-3-8B-Instruct',
                port=30000,
                gpus=[0, 1],
                max_lora_rank=16,  # Enable LoRA
                sglang_serve_args=['--attention-backend', 'triton']
            )

            # Load an adapter
            llm.load_lora_adapter("my_adapter", "/path/to/adapter")

            # Use the adapter
            llm.chat_completion("Hello", lora_name="my_adapter")
        """
        self._model_path = model_path
        self._port = port
        self._gpus = gpus
        self._tokenizer_path = (
            tokenizer_path if tokenizer_path is not None else model_path
        )
        self._context_length = context_length
        self._max_lora_rank = max_lora_rank
        self._host = host
        self._mem_fraction_static = mem_fraction_static
        self._deploy_timeout_seconds = deploy_timeout_seconds
        self._sglang_log_level = sglang_log_level
        self._silent_mode = silent_mode
        self._env_variable_dict = env_variable_dict
        self._sglang_serve_args = sglang_serve_args
        self._sglang_serve_kwargs = sglang_serve_kwargs
        self._chat_template_kwargs = chat_template_kwargs
        self._server_process = None

        # Deploy SGLang
        if launch_sglang_in_init:
            self.launch_sglang_server()

    def launch_sglang_server(self, detach: bool = False, skip_if_running: bool = False):
        try:
            import sglang
        except ImportError:
            raise

        if skip_if_running and self._is_server_running():
            print(
                f"[SGLang] Server already running on http://{self._host}:{self._port}. "
                f"Skipping launch."
            )
            return

        self._detached = detach
        self._server_process = self._launch_sglang(detach=detach)
        self._wait_for_server()

    def _launch_sglang(self, detach: bool = False):
        """Launch an SGLang server and return the subprocess."""
        if isinstance(self._gpus, int):
            gpus = str(self._gpus)
            tp_size = 1
        else:
            gpus = ",".join([str(g) for g in self._gpus])
            tp_size = len(self._gpus)

        executable_path = sys.executable

        # SGLang launch command structure
        cmd = [
            executable_path,
            "-m",
            "sglang.launch_server",
            "--model-path",
            self._model_path,
            "--tokenizer-path",
            self._tokenizer_path,
            "--port",
            str(self._port),
            "--host",
            self._host,
            "--context-length",
            str(self._context_length),
            "--mem-fraction-static",
            str(self._mem_fraction_static),
            "--tp",
            str(tp_size),
            "--trust-remote-code",
        ]

        # Enable LoRA support if rank is specified
        if self._max_lora_rank is not None:
            cmd.extend(["--max-lora-rank", str(self._max_lora_rank)])
            # SGLang sometimes requires disabling radix cache for LoRA in older versions,
            # but newer versions support it. If you face issues, consider adding "--disable-radix-cache".

        # Other args for sglang serve
        if self._sglang_serve_args is not None:
            for arg in self._sglang_serve_args:
                cmd.append(arg)

        # Other kwargs for sglang serve
        if self._sglang_serve_kwargs is not None:
            for kwarg, value in self._sglang_serve_kwargs.items():
                cmd.extend([kwarg, value])

        # Environmental variables
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = gpus
        env["LOG_LEVEL"] = self._sglang_log_level.upper()

        # Handle NCCL issues if using multiple GPUs
        if tp_size > 1:
            env["NCCL_P2P_DISABLE"] = "1"

        if self._env_variable_dict is not None:
            for k, v in self._env_variable_dict.items():
                env[k] = v

        _print_cmd_list(cmd, gpus=self._gpus, host=self._host, port=self._port)

        # Launch using subprocess
        stdout = Path(os.devnull).open("w") if self._silent_mode else None
        preexec_fn = os.setsid if detach and sys.platform != "win32" else None
        proc = subprocess.Popen(
            cmd, env=env, stdout=stdout, stderr=subprocess.STDOUT, preexec_fn=preexec_fn
        )
        return proc

    def _kill_process(self):
        if getattr(self, "_detached", False):
            print(
                f"[SGLang] Server on port {self._port} is detached. Not killing process."
            )
            return

        try:
            # Get child processes before terminating parent
            try:
                parent = psutil.Process(self._server_process.pid)
                children = parent.children(recursive=True)
            except psutil.NoSuchProcess:
                children = []

            # Terminate parent process
            self._server_process.terminate()
            self._server_process.wait(timeout=5)
            print(f"[SGLang] terminated process: {self._server_process.pid}")

            # Kill any remaining children
            for child in children:
                try:
                    child.terminate()
                    child.wait(timeout=2)
                except (psutil.NoSuchProcess, psutil.TimeoutExpired):
                    try:
                        child.kill()
                    except psutil.NoSuchProcess:
                        pass
        except subprocess.TimeoutExpired:
            self._server_process.kill()
            print(f"[SGLang] killed process: {self._server_process.pid}")

    def _is_server_running(self):
        """Check if an SGLang server is already running on the given host and port."""
        health = f"http://{self._host}:{self._port}/health"
        try:
            if requests.get(health, timeout=1).status_code == 200:
                return True
        except requests.exceptions.RequestException:
            pass
        return False

    def _wait_for_server(self):
        """Check server state and /health endpoint."""
        for _ in range(self._deploy_timeout_seconds):
            if self._server_process.poll() is not None:
                sys.exit(f"[SGLang] crashed (exit {self._server_process.returncode})")

            if self._is_server_running():
                return
            time.sleep(1)

        print("[SGLang] failed to start within timeout")
        self._kill_process()
        sys.exit("[SGLang] failed to start within timeout")

    def unload_lora_adapter(self, lora_name: str):
        """Unload lora adapter given the lora name via native SGLang endpoint.

        Args:
            lora_name: Lora adapter name.
        """
        # Note: SGLang native endpoints are often at root, not /v1
        url = f"http://{self._host}:{self._port}/unload_lora_adapter"
        headers = {"Content-Type": "application/json"}
        try:
            payload = {"lora_name": lora_name}
            response = requests.post(url, json=payload, headers=headers, timeout=10)
            if response.status_code == 200:
                print(f"[SGLang] Unloaded LoRA adapter: {lora_name}")
            else:
                print(f"[SGLang] Failed to unload LoRA: {response.text}")
        except requests.exceptions.RequestException as e:
            print(f"[SGLang] Error unloading LoRA: {e}")

    def load_lora_adapter(
        self, lora_name: str, new_adapter_path: str, num_trails: int = 5
    ):
        """Dynamically load a LoRA adapter via native SGLang endpoint.

        Args:
            lora_name: LoRA adapter name.
            new_adapter_path: Path to the new LoRA adapter weights.
        """
        if self._max_lora_rank is None:
            raise ValueError(
                'LoRA is not enabled. Please set "max_lora_rank" in __init__.'
            )

        # Unload first to ensure clean state (optional but safer for updates)
        self.unload_lora_adapter(lora_name)

        url = f"http://{self._host}:{self._port}/load_lora_adapter"
        headers = {"Content-Type": "application/json"}
        payload = {"lora_name": lora_name, "lora_path": new_adapter_path}

        for i in range(num_trails):
            try:
                response = requests.post(url, json=payload, headers=headers, timeout=60)
                if response.status_code == 200:
                    print(
                        f"[SGLang] Successfully loaded LoRA adapter: {lora_name} from {new_adapter_path}"
                    )
                    return True
                else:
                    print(
                        f"[SGLang] Failed to load LoRA adapter. "
                        f"Status code: {response.status_code}, Response: {response.text}"
                    )
                    # Don't retry immediately if it's a client error (4xx)
                    if 400 <= response.status_code < 500:
                        return False
            except requests.exceptions.RequestException:
                time.sleep(1)
                continue

        print(f"[SGLang] Error loading LoRA adapter after {num_trails} trails.")
        return False

    def close(self):
        """Shut down SGLang server and kill all processes."""
        if self._server_process is not None:
            self._kill_process()

    def reload(self):
        """Reload the SGLang server."""
        self.launch_sglang_server()

    def chat_completion(
        self,
        message: str | List[openai.types.chat.ChatCompletionMessageParam],
        max_tokens: Optional[int] = None,
        timeout_seconds: Optional[int] = None,
        lora_name: Optional[str] = None,
        temperature: float = 0.9,
        top_p: float = 0.9,
        chat_template_kwargs: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Send a chat completion query to the SGLang server.

        Args:
            message: The message in str or openai format.
            max_tokens: The maximum number of tokens to generate.
            timeout_seconds: The timeout seconds.
            lora_name: Name of the LoRA adapter to use for this request.
            temperature: The temperature parameter.
            top_p: The top p parameter.
            chat_template_kwargs: Chat template kwargs.
        """
        data = {
            "messages": [
                (
                    {"role": "user", "content": message.strip()}
                    if isinstance(message, str)
                    else message
                )
            ],
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
        }
        data: Dict[str, Any]

        # In SGLang OpenAI API, the 'model' parameter routes to the adapter
        if lora_name is not None:
            data["model"] = lora_name
        else:
            data["model"] = self._model_path

        if self._chat_template_kwargs is not None:
            data["chat_template_kwargs"] = self._chat_template_kwargs
        elif chat_template_kwargs is not None:
            data["chat_template_kwargs"] = chat_template_kwargs

        url = f"http://{self._host}:{self._port}/v1/chat/completions"
        headers = {"Content-Type": "application/json"}

        response = requests.post(
            url, headers=headers, json=data, timeout=timeout_seconds
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def embedding(
        self,
        text: str | List[str],
        dimensions: Optional[int] = None,
        timeout_seconds: Optional[float] = None,
        lora_name: Optional[str] = None,
        **kwargs,
    ) -> List[float] | List[List[float]]:
        """Generate embeddings for the given text(s)."""
        is_str_input = isinstance(text, str)
        if is_str_input:
            text = [text]

        data = {"input": text, "model": lora_name if lora_name else self._model_path}
        data: Dict[str, Any]

        if dimensions is not None:
            data["dimensions"] = dimensions

        data.update(kwargs)

        url = f"http://{self._host}:{self._port}/v1/embeddings"
        headers = {"Content-Type": "application/json"}

        response = requests.post(
            url, headers=headers, json=data, timeout=timeout_seconds
        )
        response.raise_for_status()

        response_json = response.json()
        embeddings = [item["embedding"] for item in response_json["data"]]

        if is_str_input:
            return embeddings[0]
        return embeddings
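A minimal lifecycle sketch for SGLangServer, assuming sglang is installed, GPU 0 is free, and port 30000 is available; the model path below is a placeholder:

from adtools.lm.sglang_server import SGLangServer

# Spawns `python -m sglang.launch_server ...` as a subprocess and blocks until /health responds.
llm = SGLangServer(
    model_path="Qwen/Qwen2.5-7B-Instruct",  # placeholder model path
    port=30000,
    gpus=0,
)

try:
    # Requests go to the server's OpenAI-compatible /v1/chat/completions endpoint.
    print(llm.chat_completion("Hello!", max_tokens=64, temperature=0.7))
finally:
    # Terminates the launched subprocess (skipped automatically for detached servers).
    llm.close()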