sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py
CHANGED
@@ -1,58 +1,81 @@
+"""Common utilities."""
+
 import base64
+import fcntl
+import logging
+import multiprocessing
 import os
 import random
 import socket
-import
+import struct
 import time
-import
+from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import List, Optional
 
 import numpy as np
+import psutil
 import requests
+import rpyc
 import torch
-import
+import triton
+from fastapi.responses import JSONResponse
+from packaging import version as pkg_version
+from rpyc.utils.server import ThreadedServer
+from starlette.middleware.base import BaseHTTPMiddleware
 
-
+logger = logging.getLogger(__name__)
 
 
-
-
-def time_func(*args, **kwargs):
-    if dist.get_rank() in [0, 1] and is_show_cost_time:
-        torch.cuda.synchronize()
-        start_time = time.time()
-        ans = func(*args, **kwargs)
-        torch.cuda.synchronize()
-        print(func_name, "cost time:", (time.time() - start_time) * 1000)
-        return ans
-    else:
-        torch.cuda.synchronize()
-        ans = func(*args, **kwargs)
-        torch.cuda.synchronize()
-        return ans
+show_time_cost = False
+time_infos = {}
 
-    return time_func
 
-
+def enable_show_time_cost():
+    global show_time_cost
+    show_time_cost = True
 
 
-
+class TimeInfo:
+    def __init__(self, name, interval=0.1, color=0, indent=0):
+        self.name = name
+        self.interval = interval
+        self.color = color
+        self.indent = indent
+
+        self.acc_time = 0
+        self.last_acc_time = 0
+
+    def check(self):
+        if self.acc_time - self.last_acc_time > self.interval:
+            self.last_acc_time = self.acc_time
+            return True
+        return False
 
+    def pretty_print(self):
+        print(f"\x1b[{self.color}m", end="")
+        print("-" * self.indent * 2, end="")
+        print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m")
 
-
+
+def mark_start(name, interval=0.1, color=0, indent=0):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-
-
-
+    if time_infos.get(name, None) is None:
+        time_infos[name] = TimeInfo(name, interval, color, indent)
+    time_infos[name].acc_time -= time.time()
 
 
-def mark_end(
+def mark_end(name):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-
-
-
-    print(f"cost {key}:", cost_time)
+    time_infos[name].acc_time += time.time()
+    if time_infos[name].check():
+        time_infos[name].pretty_print()
 
 
 def calculate_time(show=False, min_cost_ms=0.0):
@@ -74,83 +97,86 @@ def calculate_time(show=False, min_cost_ms=0.0):
    return wrapper
 
 
-def
-
+def get_available_gpu_memory(gpu_id, distributed=False):
+    """
+    Get available memory for cuda:gpu_id device.
+    When distributed is True, the available memory is the minimum available memory of all GPUs.
+    """
+    num_gpus = torch.cuda.device_count()
+    assert gpu_id < num_gpus
 
-    torch.
-
-
+    if torch.cuda.current_device() != gpu_id:
+        print(
+            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+            "which may cause useless memory allocation for torch CUDA context.",
+        )
 
+    torch.cuda.empty_cache()
+    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
 
-
-
-
-
-
+    if distributed:
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
+            torch.device("cuda", gpu_id)
+        )
+        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
+        free_gpu_memory = tensor.item()
 
-
-    try:
-        s.bind(("", port))
-        port_list.append(port)
-    except socket.error:
-        pass
+    return free_gpu_memory / (1 << 30)
 
-
-
-
+
+def set_random_seed(seed: int) -> None:
+    """Set the random seed for all libraries."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
 
 
-def
+def is_port_available(port):
+    """Return whether a port is available."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("", port))
+            s.listen(1)
            return True
        except socket.error:
            return False
 
 
-def
+def allocate_init_ports(
    port: Optional[int] = None,
    additional_ports: Optional[List[int]] = None,
    tp_size: int = 1,
+    dp_size: int = 1,
 ):
-
-
-
-
-
-    # first check on server port
-    if not check_port(port):
-        new_port = alloc_usable_network_port(1, used_list=[port])[0]
-        print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
-        port = new_port
-
-    # then we check on additional ports
-    additional_unique_ports = set(additional_ports) - {port}
-    # filter out ports that are already in use
-    can_use_ports = [port for port in additional_unique_ports if check_port(port)]
-
-    num_specified_ports = len(can_use_ports)
-    if num_specified_ports < 4 + tp_size:
-        addtional_can_use_ports = alloc_usable_network_port(
-            num=4 + tp_size - num_specified_ports, used_list=can_use_ports + [port]
-        )
-        can_use_ports.extend(addtional_can_use_ports)
+    """Allocate ports for all connections."""
+    if additional_ports:
+        ret_ports = [port] + additional_ports
+    else:
+        ret_ports = [port]
 
-
-
+    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
+    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
 
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
+    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    while len(ret_ports) < num_ports_needed:
+        if cur_port not in ret_ports and is_port_available(cur_port):
+            ret_ports.append(cur_port)
+        cur_port += 1
 
-
-
-
-
+    if port is not None and ret_ports[0] != port:
+        logger.warn(
+            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
+        )
 
+    return ret_ports[0], ret_ports[1:num_ports_needed]
 
-def get_int_token_logit_bias(tokenizer, vocab_size):
-    from transformers import LlamaTokenizer, LlamaTokenizerFast
 
+def get_int_token_logit_bias(tokenizer, vocab_size):
+    """Get the logit bias for integer-only tokens."""
    # a bug when model's vocab size > tokenizer.vocab_size
    vocab_size = tokenizer.vocab_size
    logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -164,14 +190,11 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
 
 
 def wrap_kernel_launcher(kernel):
    """A faster launcher for triton kernels."""
-
-
-    if dist.is_initialized():
-        rank = dist.get_rank()
-    else:
-        rank = 0
+    if int(triton.__version__.split(".")[0]) >= 3:
+        return None
 
-
+    gpu_id = torch.cuda.current_device()
+    kernels = kernel.cache[gpu_id].values()
    kernel = next(iter(kernels))
 
    # Different trition versions use different low-level names
@@ -231,20 +254,104 @@ def wrap_kernel_launcher(kernel):
 
 
 def is_multimodal_model(model):
-    if isinstance(model, str):
-        return "llava" in model or "yi-vl" in model
    from sglang.srt.model_config import ModelConfig
 
+    if isinstance(model, str):
+        model = model.lower()
+        return "llava" in model or "yi-vl" in model or "llava-next" in model
+
    if isinstance(model, ModelConfig):
        model_path = model.path.lower()
-        return
-
+        return (
+            "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+        )
+
+    raise ValueError("unrecognized type")
+
+
+def decode_video_base64(video_base64):
+    from PIL import Image
+
+    # Decode the base64 string
+    video_bytes = base64.b64decode(video_base64)
+
+    # Placeholder for the start indices of each PNG image
+    img_starts = []
+
+    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+    assert frame_format in [
+        "PNG",
+        "JPEG",
+    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
+
+    if frame_format == "PNG":
+        # Find each PNG start signature to isolate images
+        i = 0
+        while i < len(video_bytes) - 7:  # Adjusted for the length of the PNG signature
+            # Check if we found the start of a PNG file
+            if (
+                video_bytes[i] == 0x89
+                and video_bytes[i + 1] == 0x50
+                and video_bytes[i + 2] == 0x4E
+                and video_bytes[i + 3] == 0x47
+                and video_bytes[i + 4] == 0x0D
+                and video_bytes[i + 5] == 0x0A
+                and video_bytes[i + 6] == 0x1A
+                and video_bytes[i + 7] == 0x0A
+            ):
+                img_starts.append(i)
+                i += 8  # Skip the PNG signature
+            else:
+                i += 1
+    else:
+        # Find each JPEG start (0xFFD8) to isolate images
+        i = 0
+        while (
+            i < len(video_bytes) - 1
+        ):  # Adjusted for the length of the JPEG SOI signature
+            # Check if we found the start of a JPEG file
+            if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
+                img_starts.append(i)
+                # Move to the next byte to continue searching for the next image start
+                i += 2
+            else:
+                i += 1
+
+    frames = []
+    for start_idx in img_starts:
+        # Assuming each image is back-to-back, the end of one image is the start of another
+        # The last image goes until the end of the byte string
+        end_idx = (
+            img_starts[img_starts.index(start_idx) + 1]
+            if img_starts.index(start_idx) + 1 < len(img_starts)
+            else len(video_bytes)
+        )
+        img_bytes = video_bytes[start_idx:end_idx]
+
+        # Convert bytes to a PIL Image
+        img = Image.open(BytesIO(img_bytes))
+
+        # Convert PIL Image to a NumPy array
+        frame = np.array(img)
+
+        # Append the frame to the list of frames
+        frames.append(frame)
+
+    # Ensure there's at least one frame to avoid errors with np.stack
+    if frames:
+        return np.stack(frames, axis=0), img.size
+    else:
+        return np.array([]), (
+            0,
+            0,
+        )  # Return an empty array and size tuple if no frames were found
 
 
 def load_image(image_file):
    from PIL import Image
 
-    image = None
+    image = image_size = None
 
    if image_file.startswith("http://") or image_file.startswith("https://"):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
@@ -255,7 +362,265 @@ def load_image(image_file):
    elif image_file.startswith("data:"):
        image_file = image_file.split(",")[1]
        image = Image.open(BytesIO(base64.b64decode(image_file)))
+    elif image_file.startswith("video:"):
+        image_file = image_file.replace("video:", "")
+        image, image_size = decode_video_base64(image_file)
    else:
        image = Image.open(BytesIO(base64.b64decode(image_file)))
 
-    return image
+    return image, image_size
+
+
+def connect_rpyc_service(host, port):
+    repeat_count = 0
+    while repeat_count < 20:
+        try:
+            con = rpyc.connect(
+                host,
+                port,
+                config={
+                    "allow_public_attrs": True,
+                    "allow_pickle": True,
+                    "sync_request_timeout": 3600,
+                },
+            )
+            break
+        except ConnectionRefusedError as e:
+            time.sleep(1)
+        repeat_count += 1
+    if repeat_count == 20:
+        raise RuntimeError(f"Connect rpyc error: {e}")
+
+    return con.root
+
+
+def start_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600,
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def start_rpyc_service_process(service: rpyc.Service, port: int):
+    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
+    proc.start()
+    return proc
+
+
+def suppress_other_loggers():
+    from vllm.logger import logger as vllm_default_logger
+
+    vllm_default_logger.setLevel(logging.WARN)
+    logging.getLogger("vllm.config").setLevel(logging.ERROR)
+    logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
+        logging.WARN
+    )
+    logging.getLogger("vllm.selector").setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.WARN)
+
+
+def assert_pkg_version(pkg: str, min_version: str, message: str):
+    try:
+        installed_version = version(pkg)
+        if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
+            raise Exception(
+                f"{pkg} is installed with version {installed_version}, which "
+                f"is less than the minimum required version {min_version}. " + message
+            )
+    except PackageNotFoundError:
+        raise Exception(
+            f"{pkg} with minimum required version {min_version} is not installed. "
+            + message
+        )
+
+
+def kill_parent_process():
+    """Kill the parent process and all children of the parent process."""
+    current_process = psutil.Process()
+    parent_process = current_process.parent()
+    children = current_process.children(recursive=True)
+    for child in children:
+        if child.pid != current_process.pid:
+            os.kill(child.pid, 9)
+    os.kill(parent_process.pid, 9)
+
+
+def monkey_patch_vllm_p2p_access_check(gpu_id: int):
+    """
+    Monkey patch the slow p2p access check in vllm.
+    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
+    """
+
+    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+
+    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
+
+
+def monkey_patch_vllm_dummy_weight_loader():
+    """
+    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
+    """
+
+    from vllm.model_executor.model_loader.loader import (
+        CacheConfig,
+        DeviceConfig,
+        DummyModelLoader,
+        LoRAConfig,
+        ModelConfig,
+        MultiModalConfig,
+        ParallelConfig,
+        SchedulerConfig,
+        _initialize_model,
+        initialize_dummy_weights,
+        nn,
+        set_default_torch_dtype,
+    )
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+        multimodal_config: Optional[MultiModalConfig],
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        cache_config: CacheConfig,
+    ) -> nn.Module:
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                    lora_config,
+                    multimodal_config,
+                    cache_config,
+                )
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    quant_method.process_weights_after_loading(module)
+                # FIXME: Remove this after Mixtral is updated
+                # to use quant_method.
+                if hasattr(module, "process_weights_after_loading"):
+                    module.process_weights_after_loading()
+
+        # NOTE(woosuk): For accurate performance evaluation, we assign
+        # random values to the weights.
+        initialize_dummy_weights(model)
+        return model.eval()
+
+    setattr(DummyModelLoader, "load_model", load_model)
+
+
+API_KEY_HEADER_NAME = "X-API-Key"
+
+
+class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, api_key: str):
+        super().__init__(app)
+        self.api_key = api_key
+
+    async def dispatch(self, request, call_next):
+        # extract API key from the request headers
+        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
+        if not api_key_header or api_key_header != self.api_key:
+            return JSONResponse(
+                status_code=403,
+                content={"detail": "Invalid API Key"},
+            )
+        response = await call_next(request)
+        return response
+
+
+def get_ip_address(ifname):
+    """
+    Get the IP address of a network interface.
+
+    :param ifname: Name of the network interface (e.g., 'eth0')
+    :return: IP address of the network interface
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    ip_address = fcntl.ioctl(
+        s.fileno(),
+        0x8915,  # SIOCGIFADDR
+        struct.pack("256s", bytes(ifname[:15], "utf-8")),
+    )[20:24]
+    return socket.inet_ntoa(ip_address)
+
+
+def send_addrs_to_rank_0(model_port_args, server_args):
+    assert server_args.node_rank != 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get(
+        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
+    )
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+    ip_addr = [int(x) for x in ip_addr.split(".")]
+    addrs_tensor = torch.tensor(
+        ip_addr + model_port_args.model_tp_ports, dtype=torch.int
+    )
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(
+        backend="gloo",
+        init_method=init_method,
+        rank=server_args.node_rank,
+        world_size=server_args.nnodes,
+    )
+    dist.send(addrs_tensor, dst=0)
+    print(
+        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
+    )
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+def receive_addrs(model_port_args, server_args):
+    assert server_args.node_rank == 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get(
+        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
+    )
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(
+        backend="gloo",
+        init_method=init_method,
+        rank=server_args.node_rank,
+        world_size=server_args.nnodes,
+    )
+
+    for src_rank in range(1, server_args.nnodes):
+        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
+        dist.recv(tensor, src=src_rank)
+        ip = ".".join([str(x) for x in tensor[:4].tolist()])
+        ports = tensor[4:].tolist()
+        model_port_args.model_tp_ips[
+            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
+        ] = [ip] * num_tp_ports
+        model_port_args.model_tp_ports[
+            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
+        ] = ports
+        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
+
+    dist.barrier()
+    dist.destroy_process_group()