sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +4 -3
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/launch_server.py +3 -2
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -4
- sglang/srt/managers/image_processor.py +1 -1
- sglang/srt/managers/io_struct.py +48 -12
- sglang/srt/managers/schedule_batch.py +42 -36
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +111 -46
- sglang/srt/managers/session_controller.py +0 -3
- sglang/srt/managers/tokenizer_manager.py +169 -100
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +14 -51
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +10 -12
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +391 -0
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +12 -9
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +10 -6
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +303 -204
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +253 -48
- sglang/test/test_utils.py +27 -7
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- sglang-0.3.6.post2.dist-info/RECORD +0 -164
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -20,6 +20,7 @@ import random
 import tempfile
 from typing import List, Optional
 
+from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_nvgpu_memory_capacity,
@@ -49,6 +50,7 @@ class ServerArgs:
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False
+    revision: Optional[str] = None
 
     # Port
     host: str = "127.0.0.1"
@@ -58,7 +60,7 @@ class ServerArgs:
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
     max_total_tokens: Optional[int] = None
-    chunked_prefill_size: int = 8192
+    chunked_prefill_size: Optional[int] = None
     max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
@@ -120,7 +122,7 @@ class ServerArgs:
     disable_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
-    disable_disk_cache: bool = False
+    disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
     disable_overlap_schedule: bool = False
@@ -128,7 +130,7 @@ class ServerArgs:
     enable_dp_attention: bool = False
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
-    cuda_graph_max_bs: int = 160
+    cuda_graph_max_bs: Optional[int] = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -144,19 +146,20 @@ class ServerArgs:
         if self.served_model_name is None:
             self.served_model_name = self.model_path
 
-        if self.chunked_prefill_size <= 0:
-            # Disable chunked prefill
-            self.chunked_prefill_size = None
-
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
 
-
+        if is_hip():
+            gpu_mem = get_amdgpu_memory_capacity()
+        else:
+            gpu_mem = get_nvgpu_memory_capacity()
+
+        # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
             elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.
+                self.mem_fraction_static = 0.81
             elif self.tp_size >= 4:
                 self.mem_fraction_static = 0.85
             elif self.tp_size >= 2:
@@ -164,25 +167,35 @@ class ServerArgs:
             else:
                 self.mem_fraction_static = 0.88
 
-        #
-        if
-            gpu_mem
-
-
-
-            self.chunked_prefill_size //= 4 # make it 2048
-            self.cuda_graph_max_bs = 4
-            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
+        # Set chunked prefill size, which depends on the gpu memory capacity
+        if self.chunked_prefill_size is None:
+            if gpu_mem < 25_000:
+                self.chunked_prefill_size = 2048
+            else:
+                self.chunked_prefill_size = 8192
 
-        #
-        if
-
-
+        # Set cuda graph max batch size
+        if self.cuda_graph_max_bs is None:
+            if gpu_mem < 25_000:
+                self.cuda_graph_max_bs = 8
+            else:
+                self.cuda_graph_max_bs = 160
 
+        # Choose kernel backends
         if self.attention_backend is None:
-            self.attention_backend = "flashinfer" if is_flashinfer_available() else "triton"
+            self.attention_backend = (
+                "flashinfer" if is_flashinfer_available() else "triton"
+            )
         if self.sampling_backend is None:
-            self.sampling_backend = "flashinfer" if is_flashinfer_available() else "pytorch"
+            self.sampling_backend = (
+                "flashinfer" if is_flashinfer_available() else "pytorch"
+            )
+
+        if self.attention_backend == "torch_native":
+            logger.warning(
+                "Cuda graph is disabled because of using torch native attention backend"
+            )
+            self.disable_cuda_graph = True
 
         # Others
         if self.enable_dp_attention:
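The hunk above replaces the old small-GPU special case with defaults derived from the detected GPU memory and moves kernel-backend selection behind `is_flashinfer_available()`; the new `torch_native` backend additionally forces `disable_cuda_graph`. The following is a condensed, standalone restatement of that decision logic for readability only (not code from the package; `gpu_mem` is the MiB figure the capacity helpers report, and the thresholds come straight from the diff):

```python
# Standalone restatement of the new ServerArgs.__post_init__ auto-tuning branches.
def derive_defaults(gpu_mem: float, flashinfer_available: bool):
    chunked_prefill_size = 2048 if gpu_mem < 25_000 else 8192
    cuda_graph_max_bs = 8 if gpu_mem < 25_000 else 160
    attention_backend = "flashinfer" if flashinfer_available else "triton"
    sampling_backend = "flashinfer" if flashinfer_available else "pytorch"
    return chunked_prefill_size, cuda_graph_max_bs, attention_backend, sampling_backend


# A ~24 GB card (roughly 24_000 MiB) now gets the small-GPU settings:
print(derive_defaults(24_000, flashinfer_available=True))
# (2048, 8, 'flashinfer', 'flashinfer')
```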
@@ -191,14 +204,20 @@ class ServerArgs:
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             self.disable_overlap_schedule = True
-            logger.
+            logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
-                "Overlap
+                "Overlap scheduler is disabled."
             )
 
+        # GGUF
+        if (
+            self.load_format == "auto" or self.load_format == "gguf"
+        ) and check_gguf_file(self.model_path):
+            self.quantization = self.load_format = "gguf"
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
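With the block above, pointing `--model-path` at a GGUF file flips both `load_format` and `quantization` to `"gguf"` automatically. A minimal sketch of that detection, where a file-suffix check stands in for `check_gguf_file` (which inspects the file itself) and the model path is hypothetical:

```python
from pathlib import Path
from typing import Optional, Tuple


def resolve_gguf(
    model_path: str, load_format: str, quantization: Optional[str]
) -> Tuple[str, Optional[str]]:
    # check_gguf_file() reads the file header; the suffix test here is only a
    # stand-in for illustration.
    looks_like_gguf = Path(model_path).suffix == ".gguf"
    if load_format in ("auto", "gguf") and looks_like_gguf:
        load_format = quantization = "gguf"
    return load_format, quantization


print(resolve_gguf("/models/llama-3.2-1b-q4_k_m.gguf", "auto", None))
# ('gguf', 'gguf')
```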
@@ -238,7 +257,7 @@ class ServerArgs:
             "--load-format",
             type=str,
             default=ServerArgs.load_format,
-            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
+            choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
             "and fall back to the pytorch bin format if safetensors format "
@@ -248,7 +267,8 @@ class ServerArgs:
             '"npcache" will load the weights in pytorch format and store '
             "a numpy cache to speed up the loading. "
             '"dummy" will initialize the weights with random values, '
-            "which is mainly for profiling.",
+            "which is mainly for profiling."
+            '"gguf" will load the weights in the gguf format. ',
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -288,6 +308,7 @@ class ServerArgs:
                 "gptq_marlin",
                 "awq_marlin",
                 "bitsandbytes",
+                "gguf",
             ],
             help="The quantization method.",
         )
@@ -321,6 +342,14 @@ class ServerArgs:
             action="store_true",
             help="Whether to use a CausalLM as an embedding model.",
         )
+        parser.add_argument(
+            "--revision",
+            type=str,
+            default=None,
+            help="The specific model version to use. It can be a branch "
+            "name, a tag name, or a commit id. If unspecified, will use "
+            "the default version.",
+        )
 
         # Memory and scheduling
         parser.add_argument(
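The new `--revision` flag maps onto the `revision` field added to `ServerArgs` above and is meant for pinning a Hugging Face branch, tag, or commit. A hypothetical Python-side launch, assuming `sglang.Runtime` still forwards its keyword arguments into `ServerArgs` as in earlier releases (the model name and revision values are examples only):

```python
import sglang as sgl

runtime = sgl.Runtime(
    model_path="Qwen/Qwen2.5-0.5B-Instruct",  # example model
    revision="main",                          # branch, tag, or commit id
)
runtime.shutdown()
```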
@@ -572,7 +601,7 @@ class ServerArgs:
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=["flashinfer", "triton"],
+            choices=["flashinfer", "triton", "torch_native"],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
@@ -613,9 +642,9 @@ class ServerArgs:
             help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
         )
         parser.add_argument(
-            "--disable-disk-cache",
+            "--disable-outlines-disk-cache",
             action="store_true",
-            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
+            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
         )
         parser.add_argument(
             "--disable-custom-all-reduce",
@@ -716,6 +745,11 @@ class ServerArgs:
             action=DeprecatedAction,
             help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
         )
+        parser.add_argument(
+            "--disable-disk-cache",
+            action=DeprecatedAction,
+            help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
sglang/srt/utils.py
CHANGED
@@ -30,6 +30,7 @@ import subprocess
 import tempfile
 import time
 import warnings
+from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
@@ -38,6 +39,7 @@ import numpy as np
 import psutil
 import requests
 import torch
+import torch.distributed
 import torch.distributed as dist
 import triton
 import zmq
@@ -67,6 +69,22 @@ def is_hip() -> bool:
     return torch.version.hip is not None
 
 
+def is_cuda():
+    return hasattr(torch, "cuda") and torch.cuda.is_available()
+
+
+def is_cuda_alike():
+    return is_cuda() or is_hip()
+
+
+def is_hpu() -> bool:
+    return hasattr(torch, "hpu") and torch.hpu.is_available()
+
+
+def is_xpu() -> bool:
+    return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
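The new `is_cuda`/`is_cuda_alike`/`is_hpu`/`is_xpu` helpers mirror the existing `is_hip` check. An illustrative dispatch built on top of them (the mapping to device strings is an example, not code from the package):

```python
from sglang.srt.utils import is_cuda_alike, is_hpu, is_xpu


def pick_device() -> str:
    if is_cuda_alike():  # CUDA or ROCm
        return "cuda"
    if is_xpu():
        return "xpu"
    if is_hpu():
        return "hpu"
    return "cpu"


print(pick_device())
```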
@@ -412,16 +430,12 @@ def suppress_other_loggers():
     from vllm.logger import logger as vllm_default_logger
 
     vllm_default_logger.setLevel(logging.WARN)
-    logging.getLogger("vllm.config").setLevel(logging.ERROR)
     logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
         logging.WARN
     )
     logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel(
         logging.WARN
     )
-    logging.getLogger("vllm.selector").setLevel(logging.WARN)
-    logging.getLogger("vllm.utils").setLevel(logging.ERROR)
-    logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
 
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -443,26 +457,14 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
         )
 
 
-def
-    """Kill the
-
-
-
-        parent_process.pid, include_self=True, skip_pid=current_process.pid
-    )
-    try:
-        current_process.kill()
-    except psutil.NoSuchProcess:
-        pass
-
-
-def kill_child_process(pid=None, include_self=False, skip_pid=None):
-    """Kill the process and all its children process."""
-    if pid is None:
-        pid = os.getpid()
+def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
+    """Kill the process and all its child processes."""
+    if parent_pid is None:
+        parent_pid = os.getpid()
+        include_parent = False
 
     try:
-        itself = psutil.Process(pid)
+        itself = psutil.Process(parent_pid)
     except psutil.NoSuchProcess:
         return
 
@@ -475,38 +477,17 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
         except psutil.NoSuchProcess:
             pass
 
-    if include_self:
+    if include_parent:
         try:
             itself.kill()
 
             # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
            # so we send an additional signal to kill them.
-            itself.send_signal(signal.
+            itself.send_signal(signal.SIGQUIT)
         except psutil.NoSuchProcess:
             pass
 
 
-def monkey_patch_vllm_model_config():
-    from vllm.config import ModelConfig
-
-    if not hasattr(ModelConfig, "_resolve_task"):
-        return
-
-    def _resolve_task(
-        self,
-        task_option,
-        hf_config,
-    ):
-        supported_tasks = {
-            "generate": True,
-            "embedding": False,
-        }
-        selected_task = "generate"
-        return supported_tasks, selected_task
-
-    setattr(ModelConfig, "_resolve_task", _resolve_task)
-
-
 def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     Monkey patch the slow p2p access check in vllm.
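Here `kill_child_process(pid=None, include_self=False, skip_pid=None)` becomes `kill_process_tree(parent_pid, include_parent=True, skip_pid=None)`, so the default behaviour flips from sparing the target process to killing it along with its children. A migration sketch (`server_pid` is a placeholder value):

```python
import os

from sglang.srt.utils import kill_process_tree

server_pid = os.getpid()  # placeholder for a real server process id

# 0.3.6.post2: kill_child_process(server_pid, include_self=True)
# 0.4.0:
# kill_process_tree(server_pid, include_parent=True)

# Kill only the children of the current process, keeping it alive:
kill_process_tree(os.getpid(), include_parent=False)
```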
@@ -569,6 +550,29 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
     setattr(GroupCoordinator, "all_gather", all_gather)
 
 
+def monkey_patch_vllm_gguf_config():
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.gguf import (
+        GGUFConfig,
+        GGUFEmbeddingMethod,
+        GGUFLinearMethod,
+    )
+
+    from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+
+    def get_quant_method_with_embedding_replaced(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return GGUFLinearMethod(self)
+        elif isinstance(layer, VocabParallelEmbedding):
+            # patch to own VocabParallelEmbedding
+            return GGUFEmbeddingMethod(self)
+        return None
+
+    setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)
+
+
 def maybe_set_triton_cache_manager() -> None:
     """Set environment variable to tell Triton to use a
     custom cache manager"""
@@ -874,7 +878,9 @@ def get_amdgpu_memory_capacity():
     try:
         # Run rocm-smi and capture the output
         result = subprocess.run(
-            [
+            [
+                "rocminfo | grep 'gfx' -A 100 | grep 'Pool 1' -A 5 | grep 'Size:' | awk '{print $2}'"
+            ],
             stdout=subprocess.PIPE,
             stderr=subprocess.PIPE,
             shell=True,
@@ -885,9 +891,8 @@ def get_amdgpu_memory_capacity():
 
         # Parse the output to extract memory values in MiB
         memory_values = [
-            float(mem) / 1024
+            float(mem.split("(")[0].strip()) / 1024
            for mem in result.stdout.strip().split("\n")
-            if re.match(r"^\d+(\.\d+)?$", mem.strip())
         ]
 
         if not memory_values:
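The AMD memory probe now reads the `rocminfo` "Pool 1" size field, which comes back as `<kilobytes>(<hex>)` rather than the bare numbers the old `re.match` filter expected, hence the `split("(")` above. A worked example on a made-up sample value:

```python
sample = "201326592(0xC000000)"  # hypothetical rocminfo "Size:" field, in KB
mib = float(sample.split("(")[0].strip()) / 1024
print(mib)  # 196608.0 MiB
```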
@@ -934,11 +939,88 @@ def get_nvgpu_memory_capacity():
         )
 
 
+# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
+# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
+# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
+def init_custom_process_group(
+    backend=None,
+    init_method=None,
+    timeout=None,
+    world_size=-1,
+    rank=-1,
+    store=None,
+    group_name=None,
+    pg_options=None,
+):
+    from torch.distributed.distributed_c10d import (
+        Backend,
+        PrefixStore,
+        _new_process_group_helper,
+        _world,
+        default_pg_timeout,
+        rendezvous,
+    )
+
+    assert (store is None) or (
+        init_method is None
+    ), "Cannot specify both init_method and store."
+
+    if store is not None:
+        assert world_size > 0, "world_size must be positive if using store"
+        assert rank >= 0, "rank must be non-negative if using store"
+    elif init_method is None:
+        init_method = "env://"
+
+    if backend:
+        backend = Backend(backend)
+    else:
+        backend = Backend("undefined")
+
+    if timeout is None:
+        timeout = default_pg_timeout
+
+    # backward compatible API
+    if store is None:
+        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
+        store, rank, world_size = next(rendezvous_iterator)
+        store.set_timeout(timeout)
+
+        # Use a PrefixStore to avoid accidental overrides of keys used by
+        # different systems (e.g. RPC) in case the store is multi-tenant.
+        store = PrefixStore(group_name, store)
+
+    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
+    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
+    # We need to determine the appropriate parameter name based on PyTorch version
+    pg_options_param_name = (
+        "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
+    )
+    pg, _ = _new_process_group_helper(
+        world_size,
+        rank,
+        [],
+        backend,
+        store,
+        group_name=group_name,
+        **{pg_options_param_name: pg_options},
+        timeout=timeout,
+    )
+
+    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
+
+    return pg
+
+
 def crash_on_warnings():
     # Crash on warning if we are running CI tests
     return get_bool_env_var("SGLANG_IS_IN_CI")
 
 
+def print_warning_once(msg: str) -> None:
+    # Set the stacklevel to 2 to print the caller's line info
+    logger.warning(msg, stacklevel=2)
+
+
 def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_name(device_id)
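`init_custom_process_group` lets a caller build an extra process group without disturbing the default one, which is the OpenRLHF-style pattern for syncing weights between a trainer and an inference server. A hypothetical two-rank sketch (addresses, ranks, and the gloo backend are illustrative choices; both sides must execute the same call with their own rank):

```python
import torch
import torch.distributed as dist

from sglang.srt.utils import init_custom_process_group

# Run once with rank=0 (trainer) and once with rank=1 (inference server).
pg = init_custom_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29603",
    world_size=2,
    rank=0,
    group_name="weight_sync",
)
t = torch.zeros(1)
dist.broadcast(t, src=0, group=pg)
```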
@@ -953,9 +1035,42 @@ def get_device_name(device_id: int = 0) -> str:
         return torch.hpu.get_device_name(device_id)
 
 
+def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+    major, minor = None, None
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability(device_id)
+
+    if hasattr(torch, "hip") and torch.hip.is_available():
+        major, minor = torch.cuda.get_device_capability(device_id)
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
+            "."
+        )
+        major, minor = int(major), int(minor)
+
+    # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
+    # Update this once the support is available.
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        try:
+            major, minor = torch.hpu.get_device_capability(device_id)
+        except Exception as e:
+            raise RuntimeError(
+                f"An error occurred while getting device capability of hpu: {e}."
+            ) from e
+
+    return major, minor
+
+
 sglang_lib = Library("sglang", "FRAGMENT")  # noqa
 
 
+# Some backends use pytorch version < 2.4.0 which doesn't
+# support `torch.library.custom_op`.
+def supports_custom_op() -> bool:
+    return hasattr(torch.library, "custom_op")
+
+
 def direct_register_custom_op(
     op_name: str,
     op_func: Callable,
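`get_device_capability` returns `(None, None)` when no recognized accelerator is present, so callers should guard on that, while `supports_custom_op` gates the `torch.library.custom_op` path for older PyTorch builds. A small usage sketch (the `(8, 9)` FP8 threshold is an illustrative choice, not something the diff prescribes):

```python
from sglang.srt.utils import get_device_capability, supports_custom_op

major, minor = get_device_capability()
if major is not None and (major, minor) >= (8, 9):
    print("compute capability >= 8.9: FP8 KV cache is an option")

if not supports_custom_op():
    print("torch.library.custom_op is unavailable on this PyTorch build")
```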
@@ -1032,3 +1147,93 @@ def set_gpu_proc_affinity(
 def get_bool_env_var(name: str, default: str = "false") -> bool:
     value = os.getenv(name, default)
     return value.lower() in ("true", "1")
+
+
+@lru_cache(maxsize=8)
+def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
+    # Note: cuda_visible_devices is not used, but we keep it as an argument for
+    # LRU Cache purposes.
+
+    # Code below is based on
+    # https://github.com/pytorch/pytorch/blob/
+    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
+    # torch/cuda/__init__.py#L831C1-L831C17
+    import torch.cuda
+    import torch.version
+
+    if not torch.cuda._is_compiled():
+        return 0
+    if is_hip():
+        # ROCm uses amdsmi instead of nvml for stateless device count
+        # This requires a sufficiently modern version of Torch 2.4.0
+        raw_count = (
+            torch.cuda._device_count_amdsmi()
+            if (hasattr(torch.cuda, "_device_count_amdsmi"))
+            else -1
+        )
+    else:
+        raw_count = torch.cuda._device_count_nvml()
+    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
+    return r
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/utils.py
+def cuda_device_count_stateless() -> int:
+    """Get number of CUDA devices, caching based on the value of
+    CUDA_VISIBLE_DEVICES at the time of call.
+
+    This should be used instead of torch.cuda.device_count()
+    unless CUDA_VISIBLE_DEVICES has already been set to the desired
+    value."""
+
+    # This can be removed and simply replaced with torch.cuda.get_device_count
+    # after https://github.com/pytorch/pytorch/pull/122815 is released.
+    return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False