sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
@@ -8,17 +8,44 @@ import pickle
 import subprocess
 import sys
 import tempfile
+from functools import wraps
 from itertools import product
-from typing import Dict, List, Optional, Sequence
+from typing import Callable, Dict, List, Optional, Sequence, TypeVar

 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from typing_extensions import ParamSpec

 from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
+from sglang.srt.utils import is_cuda, is_hip

 logger = logging.getLogger(__name__)

+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda:
+    try:
+        import pynvml
+    except ImportError as e:
+        logger.warning("Failed to import pynvml with %r", e)
+
+if _is_hip:
+    try:
+        from amdsmi import (
+            AmdSmiException,
+            amdsmi_get_processor_handles,
+            amdsmi_init,
+            amdsmi_shut_down,
+            amdsmi_topo_get_link_type,
+        )
+    except ImportError as e:
+        logger.warning("Failed to import amdsmi with %r", e)
+
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+

 def update_environment_variables(envs: Dict[str, str]):
     for k, v in envs.items():
@@ -282,6 +309,74 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
     return _gpu_p2p_access_cache[f"{src}->{tgt}"]


+def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
+    @wraps(fn)
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
+        if _is_hip:
+            try:
+                amdsmi_init()
+                return fn(*args, **kwargs)
+            finally:
+                amdsmi_shut_down()
+        else:
+            pynvml.nvmlInit()
+            try:
+                return fn(*args, **kwargs)
+            finally:
+                pynvml.nvmlShutdown()
+
+    return wrapper
+
+
+@with_nvml_context
+def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
+    if _is_hip:
+        """
+        query if the set of gpus are fully connected by xgmi (1 hop)
+        """
+        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
+                        # type is 2 for XGMI
+                        if link_type["hops"] != 1 or link_type["type"] != 2:
+                            return False
+                    except AmdSmiException as error:
+                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
+                        return False
+        return True
+    else:
+        """
+        query if the set of gpus are fully connected by nvlink (1 hop)
+        """
+        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                            handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
+                        )
+                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                            return False
+                    except pynvml.NVMLError:
+                        logger.exception(
+                            "NVLink detection failed. This is normal if your"
+                            " machine has no NVLink equipped."
+                        )
+                        return False
+        return True
+
+
+def is_weak_contiguous(inp: torch.Tensor):
+    return inp.is_contiguous() or (
+        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
+        == inp.numel() * inp.element_size()
+    )
+
+
 __all__ = ["gpu_p2p_access_check"]

 if __name__ == "__main__":
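The with_nvml_context decorator above wraps NVML (or amdsmi on ROCm) initialization and shutdown around each topology query. A minimal standalone sketch of the same pattern, using only pynvml and independent of the sglang module (illustrative; the visible_gpu_count function below is not part of the diff):

from functools import wraps

import pynvml


def with_nvml_context(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Initialize NVML for the duration of the call, then always shut it down.
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper


@with_nvml_context
def visible_gpu_count() -> int:
    # Any NVML query inside the decorated function runs with NVML initialized.
    return pynvml.nvmlDeviceGetCount()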
sglang/srt/distributed/device_communicators/quick_all_reduce.py (new file)
@@ -0,0 +1,273 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+from enum import Enum
+from typing import Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt import _custom_ops as ops
+from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
+    is_full_nvlink,
+    is_weak_contiguous,
+)
+from sglang.srt.distributed.parallel_state import in_the_same_node_as
+from sglang.srt.utils import is_cuda, is_hip
+
+logger = logging.getLogger(__name__)
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+
+try:
+    ops.qr_max_size()
+    quick_ar = True
+except Exception:
+    # For CPUs and CUDA
+    quick_ar = False
+
+
+def qr_rocm_arch_available():
+    if not _is_hip:
+        return False
+    try:
+        props = torch.cuda.get_device_properties(0)
+        gcn_arch = getattr(props, "gcnArchName", "")
+        supported_archs = ["gfx94", "gfx95"]
+        return any(gfx in gcn_arch for gfx in supported_archs)
+    except Exception as e:
+        logger.warning("Failed to determine ROCm for quick allreduce: %s", e)
+        return False
+
+
+class QuickReduceRegime(Enum):
+    FP = 0
+    INT8 = 1
+    INT6 = 2
+    INT4 = 3
+    NONE = 4
+
+
+MB = 1024 * 1024
+
+
+class QuickAllReduce:
+
+    _SUPPORTED_WORLD_SIZES = [2, 4, 8]
+    _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
+    # The following data is based on kernel tests.
+    # In this order [FP, INT8, INT6, INT4].
+    _QR_MIN_SIZE = {
+        (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
+        (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],
+        (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
+        (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
+        (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],
+        (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
+    }
+
+    def __init__(
+        self, group: ProcessGroup, device: Union[int, str, torch.device]
+    ) -> None:
+        """
+        Custom allreduce provides non-destructive acceleration and is
+        available for CUDA and ROCm MI300 series.
+        Custom quick allreduce leverages quantization for further
+        acceleration on ROCm. It currently supports Q8, Q6, and Q4
+        quantization formats and FP(float16, bfloat16).
+        Quick allreduce is designed as a complement to custom allreduce.
+        Its initialization requires even stricter conditions.
+        Only the ROCm MI300 series is supported for quick allreduce at
+        this time.
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the CustomAllreduce to. If None,
+                it will be bind to f"cuda:{local_rank}".
+        It is the caller's responsibility to make sure each communicator
+        is bind to a unique device, and all communicators in this group
+        are in the same node.
+        """
+        self.disabled = True
+        if not qr_rocm_arch_available():
+            logger.debug(
+                "Custom quick allreduce is only supported on ROCm MI300 series."
+            )
+            return
+
+        if not quick_ar:
+            # disable because of missing quick reduce library
+            # e.g. in a cuda environment
+            logger.info(
+                "Custom quick allreduce is disabled because "
+                "of missing custom quick allreduce library"
+            )
+            return
+
+        self.group = group
+        assert (
+            dist.get_backend(group) != dist.Backend.NCCL
+        ), "Custom quick allreduce should be attached to a non-NCCL group."
+        if not all(in_the_same_node_as(group, source_rank=0)):
+            # No need to initialize custom quick allreduce for
+            # multi-node case.
+            logger.warning(
+                "Custom quick allreduce is disabled because this "
+                "process group spans across nodes."
+            )
+            return
+        rank = dist.get_rank(group=self.group)
+        world_size = dist.get_world_size(group=self.group)
+        self.rank = rank
+        self.world_size = world_size
+        if world_size == 1:
+            # No need to initialize QuickReduce for single GPU case.
+            return
+
+        if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
+            logger.warning(
+                "Custom quick allreduce is disabled due to an "
+                "unsupported world size: %d. Supported world sizes: %s.",
+                world_size,
+                str(QuickAllReduce._SUPPORTED_WORLD_SIZES),
+            )
+            return
+
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        assert isinstance(device, torch.device)
+        self.device = device
+
+        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        if cuda_visible_devices:
+            device_ids = list(map(int, cuda_visible_devices.split(",")))
+        else:
+            device_ids = list(range(torch.cuda.device_count()))
+        physical_device_id = device_ids[device.index]
+        tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
+        gather_list = [
+            torch.tensor([0], dtype=torch.int, device="cpu")
+            for _ in range(self.world_size)
+        ]
+        dist.all_gather(gather_list, tensor, group=self.group)
+        physical_device_ids = [t.item() for t in gather_list]
+
+        # test nvlink first, this will filter out most of the cases
+        # where custom quick allreduce is not supported
+        # this checks hardware and driver support for NVLink
+        if _is_cuda or _is_hip:
+            self.fully_connected = is_full_nvlink(physical_device_ids, self.world_size)
+        if self.world_size > 2 and not self.fully_connected:
+            logger.debug(
+                "Custom quick allreduce is disabled because it's not supported "
+                "on more than two PCIe-only GPUs. "
+            )
+            return
+
+        self.init_quick_all_reduce()
+
+    def init_quick_all_reduce(self):
+        # On RocM, bfloat16 kernels are slower than fp16
+        # due to slower match operations
+        # If environment variable is set to 1, we convert input to fp16
+        self.use_fp16_kernels = int(
+            os.environ.get("ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", 1)
+        )
+        regime_str = os.environ.get("ROCM_QUICK_REDUCE_QUANTIZATION", "NONE")
+        if regime_str not in QuickReduceRegime.__members__:
+            logger.warning(
+                "Custom quick allreduce:",
+                f"Invalid quantization level: {regime_str}. "
+                "Supported levels: "
+                f"{list(QuickReduceRegime.__members__.keys())}",
+            )
+            return
+
+        if regime_str == "NONE":
+            logger.debug(
+                "Custom quick allreduce is disabled based "
+                "on env variable "
+                "ROCM_QUICK_REDUCE_QUANTIZATION='NONE'"
+            )
+            return
+        self.qr_quant_level = QuickReduceRegime[regime_str]
+
+        # TODO: If the dtype is not bfloat16 or then float16,
+        # quickallreduce should not be created.
+
+        # ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
+        qr_max_size = int(os.environ.get("ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", 0))
+        if qr_max_size > 0:
+            if qr_max_size < 1:
+                logger.info(
+                    "You should not set a max_size smaller than 1MB, which can "
+                    "lead to error or degradation to custom allreduce or rccl."
+                )
+            qr_max_size = qr_max_size * MB
+        # If qr_max_size is None, then 2GB is used by default.
+        self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
+        self.qr_max_size = qr_max_size if qr_max_size > 0 else ops.qr_max_size()
+        self.create_shared_buffer()
+        self.disabled = False
+
+    def create_shared_buffer(self):
+        """
+        Creates a shared buffer for quickreduce.
+        Has to be called after init_custom_qr
+        """
+        handle = ops.qr_get_handle(self._ptr)
+        world_size = dist.get_world_size(group=self.group)
+        handles = [None] * world_size
+        dist.all_gather_object(handles, handle, group=self.group)
+        ops.qr_open_handles(self._ptr, handles)
+
+    def should_quick_allreduce(self, inp: torch.Tensor):
+        """
+        Check if quickreduce is available
+        """
+        if self.disabled:
+            return False
+        if inp.dtype not in self._SUPPORTED_DTYPES:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # custom quick allreduce requires input byte size to be
+        # multiples of 16
+        if inp_size % 16 != 0:
+            return False
+        if not is_weak_contiguous(inp):
+            return False
+        dtype = inp.dtype
+        if self.use_fp16_kernels:
+            dtype = torch.float16
+        return (
+            inp_size <= self.qr_max_size
+            and inp_size
+            >= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value]
+        )
+
+    def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
+        """Performs an out-of-place custom quick all reduce."""
+        # quick allreduce doesn't require a separate graph mode,
+        # as QR uses static IPC buffer.
+        if out is None:
+            out = torch.empty_like(inp)
+        ops.qr_all_reduce(
+            self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels
+        )
+        return out
+
+    def close(self):
+        if not self.disabled and getattr(self, "_ptr", None):
+            if ops is not None:
+                ops.qr_destroy(self._ptr)
+            self._ptr = 0
+            self.disabled = True
+
+    def __del__(self):
+        self.close()
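Quick allreduce is opt-in: init_quick_all_reduce() above reads its configuration from environment variables, and the quantization regime defaults to NONE (disabled). A plausible way to enable it on an MI300-class node, with variable names taken from the code above and values purely illustrative:

import os

# Quantization regime: FP | INT8 | INT6 | INT4 | NONE (NONE keeps it disabled).
os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = "INT8"
# Route bfloat16 inputs through the faster fp16 kernels (the default is already 1).
os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "1"
# Optional per-call size cap in MB; unset means the library default (about 2 GB).
os.environ["ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB"] = "64"

should_quick_allreduce() then gates each call on dtype, 16-byte alignment, contiguity, and the per-dtype/world-size minimum sizes in _QR_MIN_SIZE, falling back to custom allreduce otherwise (see the parallel_state.py hunks below).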
sglang/srt/distributed/device_communicators/shm_broadcast.py
@@ -16,7 +16,12 @@ from torch.distributed import ProcessGroup
 from zmq import IPV6  # type: ignore
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore

-from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address
+from sglang.srt.utils import (
+    format_tcp_address,
+    get_ip,
+    get_open_port,
+    is_valid_ipv6_address,
+)

 # SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
 SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
@@ -225,9 +230,9 @@ class MessageQueue:
             remote_subscribe_port = get_open_port()
             if is_valid_ipv6_address(connect_ip):
                 self.remote_socket.setsockopt(IPV6, 1)
-                connect_ip = f"[{connect_ip}]"
-            socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
-            self.remote_socket.bind(socket_addr)
+            self.remote_socket.bind(
+                format_tcp_address(connect_ip, remote_subscribe_port)
+            )

         else:
             remote_subscribe_port = None
@@ -288,7 +293,9 @@ class MessageQueue:
             self.remote_socket.setsockopt_string(SUBSCRIBE, "")
             if is_valid_ipv6_address(handle.connect_ip):
                 self.remote_socket.setsockopt(IPV6, 1)
-            socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
+            socket_addr = format_tcp_address(
+                handle.connect_ip, handle.remote_subscribe_port
+            )
             logger.debug("Connecting to %s", socket_addr)
             self.remote_socket.connect(socket_addr)
sglang/srt/distributed/parallel_state.py
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_int_env_var,
     is_cuda_alike,
+    is_hip,
     is_npu,
     is_shm_available,
     supports_custom_op,
@@ -126,14 +127,18 @@ if supports_custom_op():
         fake_impl=inplace_all_reduce_fake,
     )

-    def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    def outplace_all_reduce(
+        tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
+    ) -> torch.Tensor:
         assert group_name in _groups, f"Group {group_name} is not found."
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        return group._all_reduce_out_place(tensor)
+        return group._all_reduce_out_place(tensor, outplace_all_reduce_method)

-    def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    def outplace_all_reduce_fake(
+        tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
+    ) -> torch.Tensor:
         return torch.empty_like(tensor)

     direct_register_custom_op(
@@ -264,6 +269,12 @@ class GroupCoordinator:
            PyNcclCommunicator,
        )

+        if is_hip():
+            from sglang.srt.distributed.device_communicators.quick_all_reduce import (
+                QuickAllReduce,
+                qr_rocm_arch_available,
+            )
+
         self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
             self.pynccl_comm = PyNcclCommunicator(
@@ -283,6 +294,7 @@ class GroupCoordinator:
        )

         self.ca_comm: Optional[CustomAllreduce] = None
+        self.qr_comm: Optional[QuickAllReduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             try:
@@ -295,6 +307,18 @@ class GroupCoordinator:
                    f"Setup Custom allreduce failed with {e}. To silence this "
                    "warning, specify --disable-custom-all-reduce explicitly."
                )
+            if is_hip():
+                try:
+                    # Initialize a custom quick all-reduce implementation for AMD
+                    # when rocm >= gfx942. Quick reduce is designed as a
+                    # complement to custom allreduce.
+                    # Based on quickreduce (https://github.com/mk1-project/quickreduce).
+                    if qr_rocm_arch_available():
+                        self.qr_comm = QuickAllReduce(
+                            group=self.cpu_group, device=self.device
+                        )
+                except Exception as e:
+                    logger.warning(f"Failed to initialize QuickAllReduce: {e}")

         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
@@ -373,7 +397,8 @@ class GroupCoordinator:
             graph_capture_context = GraphCaptureContext(stream)
         else:
             stream = graph_capture_context.stream
-
+        # We don't need the context of custom quick allreduce because the ipc access
+        # is already collected in init() and we can capture the quick allreduce directly.
         ca_comm = self.ca_comm
         maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()

@@ -388,23 +413,24 @@ class GroupCoordinator:
         # operations. The current status is:
         # allreduce \ Mode   |  Eager  |  Graph  |
         # --------------------------------------------
+        # quick allreduce    | enabled | enabled |
         # custom allreduce   | enabled | enabled |
         # PyNccl             | disabled| enabled |
         # PyMscclpp          | disabled| enabled |
         # torch.distributed  | enabled | disabled|
         #
+        # Note: When custom quick allreduce is enabled, a runtime check
+        # will be performed. If the tensor size is too small, it will
+        # automatically fall back to the next available option.
         # Note that custom allreduce will have a runtime check, if the
         # tensor size is too large, it will fallback to the next
         # available option.
         # Note that the PyMsccl needs to register the tensor in ahead,
         # which will introduce large overhead in the eager case,
         # therefore it is only supported in the graph case.
-        # In summary: When using CUDA graph, we use
-        # either custom all-reduce kernel or pynccl. When not using
-        # CUDA graph, we use either custom all-reduce kernel or
-        # PyTorch NCCL. We always prioritize using custom all-reduce
-        # kernel but fall back to PyTorch or pynccl if it is
-        # disabled or not supported.
+        # In summary: We select the appropriate allreduce method for
+        # each mode based on the algorithm order in the table and
+        # their usage conditions.
         pynccl_comm = self.pynccl_comm
         maybe_pynccl_context: Any
         if not pynccl_comm:
@@ -464,27 +490,47 @@ class GroupCoordinator:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)

+        outplace_all_reduce_method = None
         if (
+            self.qr_comm is not None
+            and not self.qr_comm.disabled
+            and self.qr_comm.should_quick_allreduce(input_)
+        ):
+            outplace_all_reduce_method = "qr"
+        elif (
             self.ca_comm is not None
             and not self.ca_comm.disabled
             and self.ca_comm.should_custom_ar(input_)
-        ) or (
+        ):
+            outplace_all_reduce_method = "ca"
+        elif (
             self.pymscclpp_comm is not None
             and not self.pymscclpp_comm.disabled
             and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
         ):
+            outplace_all_reduce_method = "pymscclpp"
+        if outplace_all_reduce_method is not None:
             return torch.ops.sglang.outplace_all_reduce(
-                input_, group_name=self.unique_name
+                input_,
+                group_name=self.unique_name,
+                outplace_all_reduce_method=outplace_all_reduce_method,
             )
         else:
             torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name)
             return input_

-    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
+    def _all_reduce_out_place(
+        self, input_: torch.Tensor, outplace_all_reduce_method: str
+    ) -> torch.Tensor:
+        qr_comm = self.qr_comm
         ca_comm = self.ca_comm
         pymscclpp_comm = self.pymscclpp_comm
-        assert ca_comm is not None or pymscclpp_comm is not None
-        if ca_comm is not None and not ca_comm.disabled:
+        assert any([qr_comm, ca_comm, pymscclpp_comm])
+        if outplace_all_reduce_method == "qr":
+            assert not qr_comm.disabled
+            out = qr_comm.quick_all_reduce(input_)
+        elif outplace_all_reduce_method == "ca":
+            assert not ca_comm.disabled
             out = ca_comm.custom_all_reduce(input_)
         else:
             assert not pymscclpp_comm.disabled
@@ -499,6 +545,15 @@ class GroupCoordinator:
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)

+    def reduce_scatter_tensor(
+        self,
+        output: torch.Tensor,
+        input: torch.Tensor,
+    ) -> None:
+        # TODO(ch-wan): support other backends
+        torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
+        return output
+
     def reduce_scatter(
         self,
         output: torch.Tensor,
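The all-reduce changes above boil down to a priority chain: quick allreduce, then custom allreduce, then PyMscclpp, then the in-place path. A condensed, illustrative restatement of that selection (not the actual GroupCoordinator code):

def pick_outplace_method(qr_comm, ca_comm, pymscclpp_comm, tensor):
    # Returns the method tag passed to outplace_all_reduce, or None for the
    # in-place path. Order mirrors the table in the graph_capture comments.
    if qr_comm is not None and not qr_comm.disabled and qr_comm.should_quick_allreduce(tensor):
        return "qr"
    if ca_comm is not None and not ca_comm.disabled and ca_comm.should_custom_ar(tensor):
        return "ca"
    if (
        pymscclpp_comm is not None
        and not pymscclpp_comm.disabled
        and pymscclpp_comm.should_mscclpp_allreduce(tensor)
    ):
        return "pymscclpp"
    return None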
sglang/srt/entrypoints/engine.py
@@ -71,7 +71,6 @@ from sglang.srt.utils import (
     is_cuda,
     kill_process_tree,
     launch_dummy_health_check_server,
-    maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
     set_ulimit,
@@ -637,16 +636,11 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set ulimit
     set_ulimit()

-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
     # Check flashinfer version
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.7.post1",
+            "0.2.9rc2",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -654,7 +648,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.2.6.post1",
+            "0.2.7",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )

@@ -771,7 +765,9 @@ def _launch_subprocesses(
         # When using `Engine` as a Python API, we don't want to block here.
         return None, None, None

-    launch_dummy_health_check_server(server_args.host, server_args.port)
+    launch_dummy_health_check_server(
+        server_args.host, server_args.port, server_args.enable_metrics
+    )

     for proc in scheduler_procs:
         proc.join()
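The version bumps above tighten the minimum flashinfer_python and sgl-kernel requirements. A small convenience snippet (not from the diff) to check what is installed locally before upgrading; the package names are taken from the assert_pkg_version calls above:

from importlib.metadata import PackageNotFoundError, version

for pkg in ("flashinfer_python", "sgl-kernel"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")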