sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
@@ -8,17 +8,44 @@ import pickle
|
|
8
8
|
import subprocess
|
9
9
|
import sys
|
10
10
|
import tempfile
|
11
|
+
from functools import wraps
|
11
12
|
from itertools import product
|
12
|
-
from typing import Dict, List, Optional, Sequence
|
13
|
+
from typing import Callable, Dict, List, Optional, Sequence, TypeVar
|
13
14
|
|
14
15
|
import torch
|
15
16
|
import torch.distributed as dist
|
16
17
|
import torch.multiprocessing as mp
|
18
|
+
from typing_extensions import ParamSpec
|
17
19
|
|
18
20
|
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
21
|
+
from sglang.srt.utils import is_cuda, is_hip
|
19
22
|
|
20
23
|
logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_hip = is_hip()

# Vendor GPU-management libraries are optional. Each one is imported only on
# its own platform, and a missing package is logged instead of raised so that
# importing this module never fails because of it.
if _is_cuda:
    try:
        import pynvml
    except ImportError as import_err:
        logger.warning("Failed to import pynvml with %r", import_err)

if _is_hip:
    try:
        from amdsmi import (
            AmdSmiException,
            amdsmi_get_processor_handles,
            amdsmi_init,
            amdsmi_shut_down,
            amdsmi_topo_get_link_type,
        )
    except ImportError as import_err:
        logger.warning("Failed to import amdsmi with %r", import_err)

# Parameter-spec / return type variables used to keep decorated functions'
# signatures intact (see ``with_nvml_context``).
_P = ParamSpec("_P")
_R = TypeVar("_R")
|
48
|
+
|
22
49
|
|
23
50
|
def update_environment_variables(envs: Dict[str, str]):
|
24
51
|
for k, v in envs.items():
|
@@ -282,6 +309,74 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
|
282
309
|
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
283
310
|
|
284
311
|
|
312
|
+
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so it runs inside a GPU management-library session.

    On ROCm (``_is_hip``) the AMD SMI library is initialized before the call
    and shut down afterwards; otherwise NVML is initialized and shut down the
    same way. In both cases the library is initialized *before* entering the
    ``try``, so a failed initialization propagates as-is and is not followed by
    a shutdown of a library that was never brought up.

    Args:
        fn: the function that issues AMD SMI / NVML queries.

    Returns:
        A wrapper with the same signature as ``fn``.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        if _is_hip:
            amdsmi_init()
            try:
                return fn(*args, **kwargs)
            finally:
                amdsmi_shut_down()

        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper
|
329
|
+
|
330
|
+
|
331
|
+
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
    """Return True if every pair of the given GPUs is directly linked (1 hop).

    On ROCm every pair must report an XGMI link (link type 2) with exactly one
    hop via ``amdsmi_topo_get_link_type``. Otherwise every pair must report
    NVLink P2P status OK via ``pynvml.nvmlDeviceGetP2PStatus``. Any query
    failure is logged and treated as "not fully connected".

    Args:
        physical_device_ids: GPU indices as seen by the management library
            (callers in this package map CUDA_VISIBLE_DEVICES to these).
        world_size: not used by this function; kept for call-site
            compatibility.

    Returns:
        True when all pairs are directly connected, False otherwise.
    """
    if _is_hip:
        # Query the processor list once instead of once per requested device.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            # Each unordered pair is checked once.
            for peer_handle in handles[i + 1 :]:
                try:
                    link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                    # type is 2 for XGMI
                    if link_type["hops"] != 1 or link_type["type"] != 2:
                        return False
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                    return False
        return True

    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
    for i, handle in enumerate(handles):
        for peer_handle in handles[i + 1 :]:
            try:
                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                    handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
                )
                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                    return False
            except pynvml.NVMLError:
                logger.exception(
                    "NVLink detection failed. This is normal if your"
                    " machine has no NVLink equipped."
                )
                return False
    return True
|
371
|
+
|
372
|
+
|
373
|
+
def is_weak_contiguous(inp: torch.Tensor):
    """Return True if ``inp`` is contiguous or exactly covers its storage tail.

    A non-contiguous tensor still qualifies when the bytes from its storage
    offset to the end of its storage equal the tensor's own byte size, so the
    kernel can treat that region as one dense buffer.

    Uses ``untyped_storage()`` rather than the deprecated TypedStorage API
    (``Tensor.storage()``), which gives the same byte count without a
    deprecation warning.
    """
    if inp.is_contiguous():
        return True
    element_size = inp.element_size()
    storage_bytes = inp.untyped_storage().nbytes()
    offset_bytes = inp.storage_offset() * element_size
    return storage_bytes - offset_bytes == inp.numel() * element_size
|
378
|
+
|
379
|
+
|
285
380
|
# Only gpu_p2p_access_check is exported through ``import *``; the other helpers
# in this module (e.g. is_full_nvlink, is_weak_contiguous) are imported by name.
__all__ = ["gpu_p2p_access_check"]
|
286
381
|
|
287
382
|
if __name__ == "__main__":
|
@@ -0,0 +1,273 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
from enum import Enum
|
6
|
+
from typing import Union
|
7
|
+
|
8
|
+
import torch
|
9
|
+
import torch.distributed as dist
|
10
|
+
from torch.distributed import ProcessGroup
|
11
|
+
|
12
|
+
from sglang.srt import _custom_ops as ops
|
13
|
+
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
|
14
|
+
is_full_nvlink,
|
15
|
+
is_weak_contiguous,
|
16
|
+
)
|
17
|
+
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
18
|
+
from sglang.srt.utils import is_cuda, is_hip
|
19
|
+
|
20
|
+
logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_hip = is_hip()

# ``quick_ar`` records whether the compiled quick-reduce ops are usable.
# ``ops.qr_max_size()`` serves as the probe: it is expected to fail on builds
# without the quick-reduce kernels (CPU and CUDA).
try:
    ops.qr_max_size()
except Exception:
    # For CPUs and CUDA
    quick_ar = False
else:
    quick_ar = True
|
32
|
+
|
33
|
+
|
34
|
+
def qr_rocm_arch_available():
    """Report whether the local GPU architecture supports quick allreduce.

    Always False off ROCm. On ROCm, device 0's ``gcnArchName`` must contain
    ``gfx94`` or ``gfx95``. Any error while querying the device is logged and
    treated as "unsupported".
    """
    if _is_hip:
        try:
            arch_name = getattr(torch.cuda.get_device_properties(0), "gcnArchName", "")
            return any(prefix in arch_name for prefix in ("gfx94", "gfx95"))
        except Exception as exc:
            logger.warning("Failed to determine ROCm for quick allreduce: %s", exc)
            return False
    return False
|
45
|
+
|
46
|
+
|
47
|
+
class QuickReduceRegime(Enum):
    """Quantization regime for quick allreduce.

    The integer values are significant: they index the per-regime columns of
    ``QuickAllReduce._QR_MIN_SIZE`` and are passed to ``ops.qr_all_reduce``.
    The member names are the accepted values of
    ``ROCM_QUICK_REDUCE_QUANTIZATION``; ``NONE`` disables quick allreduce.
    """

    FP = 0  # no quantization (float16 / bfloat16)
    INT8 = 1
    INT6 = 2
    INT4 = 3
    NONE = 4  # quick allreduce disabled
|
53
|
+
|
54
|
+
|
55
|
+
# Bytes per mebibyte; quick-allreduce size thresholds are expressed in MB.
MB = 1 << 20
|
56
|
+
|
57
|
+
|
58
|
+
class QuickAllReduce:
    """Quantized custom allreduce for ROCm GPUs.

    Designed as a complement to the regular custom allreduce. A tensor is
    reduced by the quick-reduce kernels only when its byte size lies between
    ``_QR_MIN_SIZE[(dtype, world_size)][regime]`` and ``qr_max_size``. The
    regime (FP, INT8, INT6 or INT4) comes from
    ``ROCM_QUICK_REDUCE_QUANTIZATION``. ``disabled`` stays True whenever any
    precondition checked during construction fails.
    """

    _SUPPORTED_WORLD_SIZES = [2, 4, 8]
    _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
    # Minimum input sizes (bytes) for using quick allreduce; the data is based
    # on kernel tests. Each list is ordered [FP, INT8, INT6, INT4], i.e. it is
    # indexed by QuickReduceRegime.value.
    _QR_MIN_SIZE = {
        (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
        (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],
        (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
        (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
        (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],
        (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
    }

    def __init__(
        self, group: ProcessGroup, device: Union[int, str, torch.device]
    ) -> None:
        """
        Custom allreduce provides non-destructive acceleration and is
        available for CUDA and ROCm MI300 series.
        Custom quick allreduce leverages quantization for further
        acceleration on ROCm. It currently supports Q8, Q6, and Q4
        quantization formats and FP (float16, bfloat16).
        Quick allreduce is designed as a complement to custom allreduce.
        Its initialization requires even stricter conditions: only GPUs
        accepted by ``qr_rocm_arch_available`` (gfx94x / gfx95x) are
        supported at this time.

        Args:
            group: the process group to work on. It must not use the NCCL
                backend (e.g. a CPU/gloo group).
            device: the device to bind to; an int ``i`` means ``cuda:i`` and
                a string is parsed with ``torch.device``.
                It is the caller's responsibility to make sure each
                communicator is bound to a unique device, and all
                communicators in this group are in the same node.
        """
        self.disabled = True
        if not qr_rocm_arch_available():
            logger.debug(
                "Custom quick allreduce is only supported on ROCm MI300 series."
            )
            return

        if not quick_ar:
            # disable because of missing quick reduce library
            # e.g. in a cuda environment
            logger.info(
                "Custom quick allreduce is disabled because "
                "of missing custom quick allreduce library"
            )
            return

        self.group = group
        assert (
            dist.get_backend(group) != dist.Backend.NCCL
        ), "Custom quick allreduce should be attached to a non-NCCL group."
        if not all(in_the_same_node_as(group, source_rank=0)):
            # No need to initialize custom quick allreduce for
            # multi-node case.
            logger.warning(
                "Custom quick allreduce is disabled because this "
                "process group spans across nodes."
            )
            return
        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        self.rank = rank
        self.world_size = world_size
        if world_size == 1:
            # No need to initialize QuickReduce for single GPU case.
            return

        if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "Custom quick allreduce is disabled due to an "
                "unsupported world size: %d. Supported world sizes: %s.",
                world_size,
                str(QuickAllReduce._SUPPORTED_WORLD_SIZES),
            )
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        assert isinstance(device, torch.device)
        self.device = device

        # Map the local device index to a physical id and gather everyone's.
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices:
            device_ids = list(map(int, cuda_visible_devices.split(",")))
        else:
            device_ids = list(range(torch.cuda.device_count()))
        physical_device_id = device_ids[device.index]
        tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
        gather_list = [
            torch.tensor([0], dtype=torch.int, device="cpu")
            for _ in range(self.world_size)
        ]
        dist.all_gather(gather_list, tensor, group=self.group)
        physical_device_ids = [t.item() for t in gather_list]

        # test nvlink first, this will filter out most of the cases
        # where custom quick allreduce is not supported
        # this checks hardware and driver support for NVLink
        if _is_cuda or _is_hip:
            self.fully_connected = is_full_nvlink(physical_device_ids, self.world_size)
            if self.world_size > 2 and not self.fully_connected:
                logger.debug(
                    "Custom quick allreduce is disabled because it's not supported "
                    "on more than two PCIe-only GPUs. "
                )
                return

        self.init_quick_all_reduce()

    def init_quick_all_reduce(self):
        """Read the env configuration, create the native state and enable QR.

        Leaves ``disabled`` True when the quantization level is invalid or
        ``NONE``.
        """
        # On ROCm, bfloat16 kernels are slower than fp16 due to slower math
        # operations. When this is 1 (the default), the fp16 kernels and the
        # fp16 size thresholds are used for bfloat16 input as well.
        self.use_fp16_kernels = int(
            os.environ.get("ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", 1)
        )
        regime_str = os.environ.get("ROCM_QUICK_REDUCE_QUANTIZATION", "NONE")
        if regime_str not in QuickReduceRegime.__members__:
            # Pass the values as lazy %-args; the previous call passed the
            # message itself as an argument to a format string without
            # placeholders, which made logging fail to format the record.
            logger.warning(
                "Custom quick allreduce: Invalid quantization level: %s. "
                "Supported levels: %s",
                regime_str,
                list(QuickReduceRegime.__members__.keys()),
            )
            return

        if regime_str == "NONE":
            logger.debug(
                "Custom quick allreduce is disabled based "
                "on env variable "
                "ROCM_QUICK_REDUCE_QUANTIZATION='NONE'"
            )
            return
        self.qr_quant_level = QuickReduceRegime[regime_str]

        # TODO: If the dtype is not bfloat16 or float16,
        # quickallreduce should not be created.

        # ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is an integer number of MB, so
        # any positive value is already >= 1MB. Non-positive values are passed
        # to ``ops.init_custom_qr`` unchanged, and the effective limit is then
        # read back from ``ops.qr_max_size()``.
        qr_max_size = int(os.environ.get("ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", 0))
        if qr_max_size > 0:
            qr_max_size = qr_max_size * MB
        self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
        self.qr_max_size = qr_max_size if qr_max_size > 0 else ops.qr_max_size()
        self.create_shared_buffer()
        self.disabled = False

    def create_shared_buffer(self):
        """
        Creates a shared buffer for quickreduce.
        Has to be called after init_custom_qr.
        """
        handle = ops.qr_get_handle(self._ptr)
        world_size = dist.get_world_size(group=self.group)
        handles = [None] * world_size
        # Exchange every rank's handle and open the peers' handles locally.
        dist.all_gather_object(handles, handle, group=self.group)
        ops.qr_open_handles(self._ptr, handles)

    def should_quick_allreduce(self, inp: torch.Tensor):
        """Return True if ``inp`` should be reduced with quick allreduce.

        Requires an enabled communicator, a float16/bfloat16 input whose byte
        size is a multiple of 16, weak contiguity, and a byte size within
        ``[_QR_MIN_SIZE[...], qr_max_size]``.
        """
        if self.disabled:
            return False
        if inp.dtype not in self._SUPPORTED_DTYPES:
            return False
        inp_size = inp.numel() * inp.element_size()
        # custom quick allreduce requires input byte size to be
        # multiples of 16
        if inp_size % 16 != 0:
            return False
        if not is_weak_contiguous(inp):
            return False
        dtype = inp.dtype
        if self.use_fp16_kernels:
            # bfloat16 input runs on the fp16 kernels, so use fp16 thresholds.
            dtype = torch.float16
        return (
            inp_size <= self.qr_max_size
            and inp_size
            >= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value]
        )

    def quick_all_reduce(
        self, inp: torch.Tensor, *, out: Union[torch.Tensor, None] = None
    ):
        """Performs an out-of-place custom quick all reduce.

        Args:
            inp: tensor to reduce (see ``should_quick_allreduce``).
            out: optional output buffer; allocated like ``inp`` if omitted.

        Returns:
            The output tensor.
        """
        # quick allreduce doesn't require a separate graph mode,
        # as QR uses static IPC buffer.
        if out is None:
            out = torch.empty_like(inp)
        ops.qr_all_reduce(
            self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels
        )
        return out

    def close(self):
        """Destroy the native quick-reduce state; safe to call repeatedly."""
        if not self.disabled and getattr(self, "_ptr", None):
            # Presumably guards against ``ops`` already being torn down when
            # this runs from ``__del__`` at interpreter exit.
            if ops is not None:
                ops.qr_destroy(self._ptr)
            self._ptr = 0
            self.disabled = True

    def __del__(self):
        self.close()
|
@@ -16,7 +16,12 @@ from torch.distributed import ProcessGroup
|
|
16
16
|
from zmq import IPV6 # type: ignore
|
17
17
|
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
18
18
|
|
19
|
-
from sglang.srt.utils import
|
19
|
+
from sglang.srt.utils import (
|
20
|
+
format_tcp_address,
|
21
|
+
get_ip,
|
22
|
+
get_open_port,
|
23
|
+
is_valid_ipv6_address,
|
24
|
+
)
|
20
25
|
|
21
26
|
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
|
22
27
|
SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
|
@@ -225,9 +230,9 @@ class MessageQueue:
|
|
225
230
|
remote_subscribe_port = get_open_port()
|
226
231
|
if is_valid_ipv6_address(connect_ip):
|
227
232
|
self.remote_socket.setsockopt(IPV6, 1)
|
228
|
-
|
229
|
-
|
230
|
-
|
233
|
+
self.remote_socket.bind(
|
234
|
+
format_tcp_address(connect_ip, remote_subscribe_port)
|
235
|
+
)
|
231
236
|
|
232
237
|
else:
|
233
238
|
remote_subscribe_port = None
|
@@ -288,7 +293,9 @@ class MessageQueue:
|
|
288
293
|
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
289
294
|
if is_valid_ipv6_address(handle.connect_ip):
|
290
295
|
self.remote_socket.setsockopt(IPV6, 1)
|
291
|
-
socket_addr =
|
296
|
+
socket_addr = format_tcp_address(
|
297
|
+
handle.connect_ip, handle.remote_subscribe_port
|
298
|
+
)
|
292
299
|
logger.debug("Connecting to %s", socket_addr)
|
293
300
|
self.remote_socket.connect(socket_addr)
|
294
301
|
|
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
|
|
44
44
|
get_bool_env_var,
|
45
45
|
get_int_env_var,
|
46
46
|
is_cuda_alike,
|
47
|
+
is_hip,
|
47
48
|
is_npu,
|
48
49
|
is_shm_available,
|
49
50
|
supports_custom_op,
|
@@ -126,14 +127,18 @@ if supports_custom_op():
|
|
126
127
|
fake_impl=inplace_all_reduce_fake,
|
127
128
|
)
|
128
129
|
|
129
|
-
def outplace_all_reduce(
|
130
|
+
def outplace_all_reduce(
    tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
) -> torch.Tensor:
    """Out-of-place all-reduce of ``tensor`` on the registered group ``group_name``.

    ``outplace_all_reduce_method`` names the backend chosen by the caller
    ("qr", "ca" or "pymscclpp"). It is forwarded unchanged to the group's
    ``_all_reduce_out_place``.

    Raises:
        ValueError: if the group entry exists but resolves to None.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    group_ref = _groups[group_name]
    group = group_ref()
    if group is not None:
        return group._all_reduce_out_place(tensor, outplace_all_reduce_method)
    raise ValueError(f"Group {group_name} is destroyed.")
|
135
138
|
|
136
|
-
def outplace_all_reduce_fake(
|
139
|
+
def outplace_all_reduce_fake(
    tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
) -> torch.Tensor:
    """Fake counterpart of ``outplace_all_reduce`` with the same signature.

    Performs no communication. It only allocates an uninitialized tensor with
    the same shape, dtype and device as ``tensor``; the group name and the
    method are ignored.
    """
    result = torch.empty_like(tensor)
    return result
|
138
143
|
|
139
144
|
direct_register_custom_op(
|
@@ -264,6 +269,12 @@ class GroupCoordinator:
|
|
264
269
|
PyNcclCommunicator,
|
265
270
|
)
|
266
271
|
|
272
|
+
if is_hip():
|
273
|
+
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
|
274
|
+
QuickAllReduce,
|
275
|
+
qr_rocm_arch_available,
|
276
|
+
)
|
277
|
+
|
267
278
|
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
268
279
|
if use_pynccl and self.world_size > 1:
|
269
280
|
self.pynccl_comm = PyNcclCommunicator(
|
@@ -283,6 +294,7 @@ class GroupCoordinator:
|
|
283
294
|
)
|
284
295
|
|
285
296
|
self.ca_comm: Optional[CustomAllreduce] = None
|
297
|
+
self.qr_comm: Optional[QuickAllReduce] = None
|
286
298
|
if use_custom_allreduce and self.world_size > 1:
|
287
299
|
# Initialize a custom fast all-reduce implementation.
|
288
300
|
try:
|
@@ -295,6 +307,18 @@ class GroupCoordinator:
|
|
295
307
|
f"Setup Custom allreduce failed with {e}. To silence this "
|
296
308
|
"warning, specify --disable-custom-all-reduce explicitly."
|
297
309
|
)
|
310
|
+
if is_hip():
|
311
|
+
try:
|
312
|
+
# Initialize a custom quick all-reduce implementation for AMD
|
313
|
+
# when rocm >= gfx942. Quick reduce is designed as a
|
314
|
+
# complement to custom allreduce.
|
315
|
+
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
|
316
|
+
if qr_rocm_arch_available():
|
317
|
+
self.qr_comm = QuickAllReduce(
|
318
|
+
group=self.cpu_group, device=self.device
|
319
|
+
)
|
320
|
+
except Exception as e:
|
321
|
+
logger.warning(f"Failed to initialize QuickAllReduce: {e}")
|
298
322
|
|
299
323
|
from sglang.srt.distributed.device_communicators.hpu_communicator import (
|
300
324
|
HpuCommunicator,
|
@@ -373,7 +397,8 @@ class GroupCoordinator:
|
|
373
397
|
graph_capture_context = GraphCaptureContext(stream)
|
374
398
|
else:
|
375
399
|
stream = graph_capture_context.stream
|
376
|
-
|
400
|
+
# We don't need the context of custom quick allreduce because the ipc access
|
401
|
+
# is already collected in init() and we can capture the quick allreduce directly.
|
377
402
|
ca_comm = self.ca_comm
|
378
403
|
maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
|
379
404
|
|
@@ -388,23 +413,24 @@ class GroupCoordinator:
|
|
388
413
|
# operations. The current status is:
|
389
414
|
# allreduce \ Mode | Eager | Graph |
|
390
415
|
# --------------------------------------------
|
416
|
+
# quick allreduce | enabled | enabled |
|
391
417
|
# custom allreduce | enabled | enabled |
|
392
418
|
# PyNccl | disabled| enabled |
|
393
419
|
# PyMscclpp | disabled| enabled |
|
394
420
|
# torch.distributed | enabled | disabled|
|
395
421
|
#
|
422
|
+
# Note: When custom quick allreduce is enabled, a runtime check
|
423
|
+
# will be performed. If the tensor size is too small, it will
|
424
|
+
# automatically fall back to the next available option.
|
396
425
|
# Note that custom allreduce will have a runtime check, if the
|
397
426
|
# tensor size is too large, it will fallback to the next
|
398
427
|
# available option.
|
399
428
|
# Note that the PyMsccl needs to register the tensor in ahead,
|
400
429
|
# which will introduce large overhead in the eager case,
|
401
430
|
# therefore it is only supported in the graph case.
|
402
|
-
# In summary:
|
403
|
-
#
|
404
|
-
#
|
405
|
-
# PyTorch NCCL. We always prioritize using custom all-reduce
|
406
|
-
# kernel but fall back to PyTorch or pynccl if it is
|
407
|
-
# disabled or not supported.
|
431
|
+
# In summary: We select the appropriate allreduce method for
|
432
|
+
# each mode based on the algorithm order in the table and
|
433
|
+
# their usage conditions.
|
408
434
|
pynccl_comm = self.pynccl_comm
|
409
435
|
maybe_pynccl_context: Any
|
410
436
|
if not pynccl_comm:
|
@@ -464,27 +490,47 @@ class GroupCoordinator:
|
|
464
490
|
if self.npu_communicator is not None and not self.npu_communicator.disabled:
|
465
491
|
return self.npu_communicator.all_reduce(input_)
|
466
492
|
|
493
|
+
outplace_all_reduce_method = None
|
467
494
|
if (
|
495
|
+
self.qr_comm is not None
|
496
|
+
and not self.qr_comm.disabled
|
497
|
+
and self.qr_comm.should_quick_allreduce(input_)
|
498
|
+
):
|
499
|
+
outplace_all_reduce_method = "qr"
|
500
|
+
elif (
|
468
501
|
self.ca_comm is not None
|
469
502
|
and not self.ca_comm.disabled
|
470
503
|
and self.ca_comm.should_custom_ar(input_)
|
471
|
-
)
|
504
|
+
):
|
505
|
+
outplace_all_reduce_method = "ca"
|
506
|
+
elif (
|
472
507
|
self.pymscclpp_comm is not None
|
473
508
|
and not self.pymscclpp_comm.disabled
|
474
509
|
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
|
475
510
|
):
|
511
|
+
outplace_all_reduce_method = "pymscclpp"
|
512
|
+
if outplace_all_reduce_method is not None:
|
476
513
|
return torch.ops.sglang.outplace_all_reduce(
|
477
|
-
input_,
|
514
|
+
input_,
|
515
|
+
group_name=self.unique_name,
|
516
|
+
outplace_all_reduce_method=outplace_all_reduce_method,
|
478
517
|
)
|
479
518
|
else:
|
480
519
|
torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name)
|
481
520
|
return input_
|
482
521
|
|
483
|
-
def _all_reduce_out_place(
|
522
|
+
def _all_reduce_out_place(
|
523
|
+
self, input_: torch.Tensor, outplace_all_reduce_method: str
|
524
|
+
) -> torch.Tensor:
|
525
|
+
qr_comm = self.qr_comm
|
484
526
|
ca_comm = self.ca_comm
|
485
527
|
pymscclpp_comm = self.pymscclpp_comm
|
486
|
-
assert ca_comm
|
487
|
-
if
|
528
|
+
assert any([qr_comm, ca_comm, pymscclpp_comm])
|
529
|
+
if outplace_all_reduce_method == "qr":
|
530
|
+
assert not qr_comm.disabled
|
531
|
+
out = qr_comm.quick_all_reduce(input_)
|
532
|
+
elif outplace_all_reduce_method == "ca":
|
533
|
+
assert not ca_comm.disabled
|
488
534
|
out = ca_comm.custom_all_reduce(input_)
|
489
535
|
else:
|
490
536
|
assert not pymscclpp_comm.disabled
|
@@ -499,6 +545,15 @@ class GroupCoordinator:
|
|
499
545
|
else:
|
500
546
|
torch.distributed.all_reduce(input_, group=self.device_group)
|
501
547
|
|
548
|
+
def reduce_scatter_tensor(
    self,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Reduce-scatter ``input`` across ``self.device_group`` into ``output``.

    Thin wrapper over ``torch.distributed.reduce_scatter_tensor``. ``output``
    is filled in place and also returned; the return annotation previously
    claimed ``None`` even though the tensor is returned.

    Args:
        output: destination buffer for this rank's reduced shard.
        input: full tensor contributed by this rank.

    Returns:
        ``output``.
    """
    # TODO(ch-wan): support other backends
    torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
    return output
|
556
|
+
|
502
557
|
def reduce_scatter(
|
503
558
|
self,
|
504
559
|
output: torch.Tensor,
|
sglang/srt/entrypoints/engine.py
CHANGED
@@ -71,7 +71,6 @@ from sglang.srt.utils import (
|
|
71
71
|
is_cuda,
|
72
72
|
kill_process_tree,
|
73
73
|
launch_dummy_health_check_server,
|
74
|
-
maybe_set_triton_cache_manager,
|
75
74
|
prepare_model_and_tokenizer,
|
76
75
|
set_prometheus_multiproc_dir,
|
77
76
|
set_ulimit,
|
@@ -637,16 +636,11 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|
637
636
|
# Set ulimit
|
638
637
|
set_ulimit()
|
639
638
|
|
640
|
-
# Fix triton bugs
|
641
|
-
if server_args.tp_size * server_args.dp_size > 1:
|
642
|
-
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
|
643
|
-
maybe_set_triton_cache_manager()
|
644
|
-
|
645
639
|
# Check flashinfer version
|
646
640
|
if server_args.attention_backend == "flashinfer":
|
647
641
|
assert_pkg_version(
|
648
642
|
"flashinfer_python",
|
649
|
-
"0.2.
|
643
|
+
"0.2.9rc2",
|
650
644
|
"Please uninstall the old version and "
|
651
645
|
"reinstall the latest version by following the instructions "
|
652
646
|
"at https://docs.flashinfer.ai/installation.html.",
|
@@ -654,7 +648,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|
654
648
|
if _is_cuda:
|
655
649
|
assert_pkg_version(
|
656
650
|
"sgl-kernel",
|
657
|
-
"0.2.
|
651
|
+
"0.2.7",
|
658
652
|
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
659
653
|
)
|
660
654
|
|
@@ -771,7 +765,9 @@ def _launch_subprocesses(
|
|
771
765
|
# When using `Engine` as a Python API, we don't want to block here.
|
772
766
|
return None, None, None
|
773
767
|
|
774
|
-
launch_dummy_health_check_server(
|
768
|
+
launch_dummy_health_check_server(
|
769
|
+
server_args.host, server_args.port, server_args.enable_metrics
|
770
|
+
)
|
775
771
|
|
776
772
|
for proc in scheduler_procs:
|
777
773
|
proc.join()
|