sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,273 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+from enum import Enum
+from typing import Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt import _custom_ops as ops
+from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
+    is_full_nvlink,
+    is_weak_contiguous,
+)
+from sglang.srt.distributed.parallel_state import in_the_same_node_as
+from sglang.srt.utils import is_cuda, is_hip
+
+logger = logging.getLogger(__name__)
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+
+try:
+    ops.qr_max_size()
+    quick_ar = True
+except Exception:
+    # For CPUs and CUDA
+    quick_ar = False
+
+
+def qr_rocm_arch_available():
+    if not _is_hip:
+        return False
+    try:
+        props = torch.cuda.get_device_properties(0)
+        gcn_arch = getattr(props, "gcnArchName", "")
+        supported_archs = ["gfx94", "gfx95"]
+        return any(gfx in gcn_arch for gfx in supported_archs)
+    except Exception as e:
+        logger.warning("Failed to determine ROCm for quick allreduce: %s", e)
+        return False
+
+
+class QuickReduceRegime(Enum):
+    FP = 0
+    INT8 = 1
+    INT6 = 2
+    INT4 = 3
+    NONE = 4
+
+
+MB = 1024 * 1024
+
+
+class QuickAllReduce:
+
+    _SUPPORTED_WORLD_SIZES = [2, 4, 8]
+    _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
+    # The following data is based on kernel tests.
+    # In this order [FP, INT8, INT6, INT4].
+    _QR_MIN_SIZE = {
+        (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
+        (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],
+        (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
+        (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
+        (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],
+        (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
+    }
+
+    def __init__(
+        self, group: ProcessGroup, device: Union[int, str, torch.device]
+    ) -> None:
+        """
+        Custom allreduce provides non-destructive acceleration and is
+        available for CUDA and ROCm MI300 series.
+        Custom quick allreduce leverages quantization for further
+        acceleration on ROCm. It currently supports Q8, Q6, and Q4
+        quantization formats and FP(float16, bfloat16).
+        Quick allreduce is designed as a complement to custom allreduce.
+        Its initialization requires even stricter conditions.
+        Only the ROCm MI300 series is supported for quick allreduce at
+        this time.
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the CustomAllreduce to. If None,
+                it will be bind to f"cuda:{local_rank}".
+        It is the caller's responsibility to make sure each communicator
+        is bind to a unique device, and all communicators in this group
+        are in the same node.
+        """
+        self.disabled = True
+        if not qr_rocm_arch_available():
+            logger.debug(
+                "Custom quick allreduce is only supported on ROCm MI300 series."
+            )
+            return
+
+        if not quick_ar:
+            # disable because of missing quick reduce library
+            # e.g. in a cuda environment
+            logger.info(
+                "Custom quick allreduce is disabled because "
+                "of missing custom quick allreduce library"
+            )
+            return
+
+        self.group = group
+        assert (
+            dist.get_backend(group) != dist.Backend.NCCL
+        ), "Custom quick allreduce should be attached to a non-NCCL group."
+        if not all(in_the_same_node_as(group, source_rank=0)):
+            # No need to initialize custom quick allreduce for
+            # multi-node case.
+            logger.warning(
+                "Custom quick allreduce is disabled because this "
+                "process group spans across nodes."
+            )
+            return
+        rank = dist.get_rank(group=self.group)
+        world_size = dist.get_world_size(group=self.group)
+        self.rank = rank
+        self.world_size = world_size
+        if world_size == 1:
+            # No need to initialize QuickReduce for single GPU case.
+            return
+
+        if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
+            logger.warning(
+                "Custom quick allreduce is disabled due to an "
+                "unsupported world size: %d. Supported world sizes: %s.",
+                world_size,
+                str(QuickAllReduce._SUPPORTED_WORLD_SIZES),
+            )
+            return
+
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        assert isinstance(device, torch.device)
+        self.device = device
+
+        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        if cuda_visible_devices:
+            device_ids = list(map(int, cuda_visible_devices.split(",")))
+        else:
+            device_ids = list(range(torch.cuda.device_count()))
+        physical_device_id = device_ids[device.index]
+        tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
+        gather_list = [
+            torch.tensor([0], dtype=torch.int, device="cpu")
+            for _ in range(self.world_size)
+        ]
+        dist.all_gather(gather_list, tensor, group=self.group)
+        physical_device_ids = [t.item() for t in gather_list]
+
+        # test nvlink first, this will filter out most of the cases
+        # where custom quick allreduce is not supported
+        # this checks hardware and driver support for NVLink
+        if _is_cuda or _is_hip:
+            self.fully_connected = is_full_nvlink(physical_device_ids, self.world_size)
+        if self.world_size > 2 and not self.fully_connected:
+            logger.debug(
+                "Custom quick allreduce is disabled because it's not supported "
+                "on more than two PCIe-only GPUs. "
+            )
+            return
+
+        self.init_quick_all_reduce()
+
+    def init_quick_all_reduce(self):
+        # On RocM, bfloat16 kernels are slower than fp16
+        # due to slower match operations
+        # If environment variable is set to 1, we convert input to fp16
+        self.use_fp16_kernels = int(
+            os.environ.get("ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", 1)
+        )
+        regime_str = os.environ.get("ROCM_QUICK_REDUCE_QUANTIZATION", "NONE")
+        if regime_str not in QuickReduceRegime.__members__:
+            logger.warning(
+                "Custom quick allreduce:",
+                f"Invalid quantization level: {regime_str}. "
+                "Supported levels: "
+                f"{list(QuickReduceRegime.__members__.keys())}",
+            )
+            return
+
+        if regime_str == "NONE":
+            logger.debug(
+                "Custom quick allreduce is disabled based "
+                "on env variable "
+                "ROCM_QUICK_REDUCE_QUANTIZATION='NONE'"
+            )
+            return
+        self.qr_quant_level = QuickReduceRegime[regime_str]
+
+        # TODO: If the dtype is not bfloat16 or then float16,
+        # quickallreduce should not be created.
+
+        # ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
+        qr_max_size = int(os.environ.get("ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", 0))
+        if qr_max_size > 0:
+            if qr_max_size < 1:
+                logger.info(
+                    "You should not set a max_size smaller than 1MB, which can "
+                    "lead to error or degradation to custom allreduce or rccl."
+                )
+            qr_max_size = qr_max_size * MB
+        # If qr_max_size is None, then 2GB is used by default.
+        self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
+        self.qr_max_size = qr_max_size if qr_max_size > 0 else ops.qr_max_size()
+        self.create_shared_buffer()
+        self.disabled = False
+
+    def create_shared_buffer(self):
+        """
+        Creates a shared buffer for quickreduce.
+        Has to be called after init_custom_qr
+        """
+        handle = ops.qr_get_handle(self._ptr)
+        world_size = dist.get_world_size(group=self.group)
+        handles = [None] * world_size
+        dist.all_gather_object(handles, handle, group=self.group)
+        ops.qr_open_handles(self._ptr, handles)
+
+    def should_quick_allreduce(self, inp: torch.Tensor):
+        """
+        Check if quickreduce is available
+        """
+        if self.disabled:
+            return False
+        if inp.dtype not in self._SUPPORTED_DTYPES:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # custom quick allreduce requires input byte size to be
+        # multiples of 16
+        if inp_size % 16 != 0:
+            return False
+        if not is_weak_contiguous(inp):
+            return False
+        dtype = inp.dtype
+        if self.use_fp16_kernels:
+            dtype = torch.float16
+        return (
+            inp_size <= self.qr_max_size
+            and inp_size
+            >= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value]
+        )
+
+    def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
+        """Performs an out-of-place custom quick all reduce."""
+        # quick allreduce doesn't require a separate graph mode,
+        # as QR uses static IPC buffer.
+        if out is None:
+            out = torch.empty_like(inp)
+        ops.qr_all_reduce(
+            self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels
+        )
+        return out
+
+    def close(self):
+        if not self.disabled and getattr(self, "_ptr", None):
+            if ops is not None:
+                ops.qr_destroy(self._ptr)
+            self._ptr = 0
+            self.disabled = True
+
+    def __del__(self):
+        self.close()
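
Not part of the diff: the QuickAllReduce communicator above is configured entirely through environment variables, so a minimal sketch of turning it on for a ROCm MI300-class node might look like the following. The variable names and accepted values are taken from the file above; everything else (where the snippet runs, the chosen regime and size cap) is illustrative only.

import os

# Quantization regime; per QuickReduceRegime the accepted values are
# FP, INT8, INT6, INT4, NONE. The default "NONE" keeps quick allreduce disabled.
os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = "INT8"

# Optional: keep routing bfloat16 inputs through the faster fp16 kernels (default is 1).
os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "1"

# Optional: cap the IPC buffer size, in MB; unset or 0 falls back to ops.qr_max_size().
os.environ["ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB"] = "64"

# These must be set before the process groups are created, because
# GroupCoordinator (see the parallel_state.py hunks later in this diff)
# constructs QuickAllReduce during initialization.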
@@ -16,7 +16,12 @@ from torch.distributed import ProcessGroup
 from zmq import IPV6 # type: ignore
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
 
-from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address
+from sglang.srt.utils import (
+    format_tcp_address,
+    get_ip,
+    get_open_port,
+    is_valid_ipv6_address,
+)
 
 # SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
 SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
@@ -225,9 +230,9 @@ class MessageQueue:
             remote_subscribe_port = get_open_port()
             if is_valid_ipv6_address(connect_ip):
                 self.remote_socket.setsockopt(IPV6, 1)
-                connect_ip = f"[{connect_ip}]"
-            socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
-            self.remote_socket.bind(socket_addr)
+            self.remote_socket.bind(
+                format_tcp_address(connect_ip, remote_subscribe_port)
+            )
 
         else:
             remote_subscribe_port = None
@@ -288,7 +293,9 @@ class MessageQueue:
             self.remote_socket.setsockopt_string(SUBSCRIBE, "")
             if is_valid_ipv6_address(handle.connect_ip):
                 self.remote_socket.setsockopt(IPV6, 1)
-            socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
+            socket_addr = format_tcp_address(
+                handle.connect_ip, handle.remote_subscribe_port
+            )
             logger.debug("Connecting to %s", socket_addr)
             self.remote_socket.connect(socket_addr)
 
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_int_env_var,
     is_cuda_alike,
+    is_hip,
     is_npu,
     is_shm_available,
     supports_custom_op,
@@ -126,14 +127,18 @@ if supports_custom_op():
         fake_impl=inplace_all_reduce_fake,
     )
 
-    def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    def outplace_all_reduce(
+        tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
+    ) -> torch.Tensor:
         assert group_name in _groups, f"Group {group_name} is not found."
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        return group._all_reduce_out_place(tensor)
+        return group._all_reduce_out_place(tensor, outplace_all_reduce_method)
 
-    def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    def outplace_all_reduce_fake(
+        tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str
+    ) -> torch.Tensor:
         return torch.empty_like(tensor)
 
     direct_register_custom_op(
@@ -264,6 +269,12 @@ class GroupCoordinator:
             PyNcclCommunicator,
         )
 
+        if is_hip():
+            from sglang.srt.distributed.device_communicators.quick_all_reduce import (
+                QuickAllReduce,
+                qr_rocm_arch_available,
+            )
+
         self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
             self.pynccl_comm = PyNcclCommunicator(
@@ -283,6 +294,7 @@ class GroupCoordinator:
             )
 
         self.ca_comm: Optional[CustomAllreduce] = None
+        self.qr_comm: Optional[QuickAllReduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             try:
@@ -295,6 +307,18 @@ class GroupCoordinator:
                     f"Setup Custom allreduce failed with {e}. To silence this "
                     "warning, specify --disable-custom-all-reduce explicitly."
                 )
+            if is_hip():
+                try:
+                    # Initialize a custom quick all-reduce implementation for AMD
+                    # when rocm >= gfx942. Quick reduce is designed as a
+                    # complement to custom allreduce.
+                    # Based on quickreduce (https://github.com/mk1-project/quickreduce).
+                    if qr_rocm_arch_available():
+                        self.qr_comm = QuickAllReduce(
+                            group=self.cpu_group, device=self.device
+                        )
+                except Exception as e:
+                    logger.warning(f"Failed to initialize QuickAllReduce: {e}")
 
         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
@@ -373,7 +397,8 @@ class GroupCoordinator:
             graph_capture_context = GraphCaptureContext(stream)
         else:
             stream = graph_capture_context.stream
-
+        # We don't need the context of custom quick allreduce because the ipc access
+        # is already collected in init() and we can capture the quick allreduce directly.
         ca_comm = self.ca_comm
         maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
 
@@ -388,23 +413,24 @@ class GroupCoordinator:
             # operations. The current status is:
             # allreduce \ Mode | Eager | Graph |
             # --------------------------------------------
+            # quick allreduce | enabled | enabled |
             # custom allreduce | enabled | enabled |
             # PyNccl | disabled| enabled |
             # PyMscclpp | disabled| enabled |
             # torch.distributed | enabled | disabled|
             #
+            # Note: When custom quick allreduce is enabled, a runtime check
+            # will be performed. If the tensor size is too small, it will
+            # automatically fall back to the next available option.
            # Note that custom allreduce will have a runtime check, if the
             # tensor size is too large, it will fallback to the next
             # available option.
             # Note that the PyMsccl needs to register the tensor in ahead,
             # which will introduce large overhead in the eager case,
             # therefore it is only supported in the graph case.
-            # In summary: When using CUDA graph, we use
-            # either custom all-reduce kernel or pynccl. When not using
-            # CUDA graph, we use either custom all-reduce kernel or
-            # PyTorch NCCL. We always prioritize using custom all-reduce
-            # kernel but fall back to PyTorch or pynccl if it is
-            # disabled or not supported.
+            # In summary: We select the appropriate allreduce method for
+            # each mode based on the algorithm order in the table and
+            # their usage conditions.
             pynccl_comm = self.pynccl_comm
             maybe_pynccl_context: Any
             if not pynccl_comm:
@@ -464,27 +490,47 @@ class GroupCoordinator:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)
 
+        outplace_all_reduce_method = None
         if (
+            self.qr_comm is not None
+            and not self.qr_comm.disabled
+            and self.qr_comm.should_quick_allreduce(input_)
+        ):
+            outplace_all_reduce_method = "qr"
+        elif (
             self.ca_comm is not None
             and not self.ca_comm.disabled
             and self.ca_comm.should_custom_ar(input_)
-        ) or (
+        ):
+            outplace_all_reduce_method = "ca"
+        elif (
             self.pymscclpp_comm is not None
             and not self.pymscclpp_comm.disabled
             and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
         ):
+            outplace_all_reduce_method = "pymscclpp"
+        if outplace_all_reduce_method is not None:
             return torch.ops.sglang.outplace_all_reduce(
-                input_, group_name=self.unique_name
+                input_,
+                group_name=self.unique_name,
+                outplace_all_reduce_method=outplace_all_reduce_method,
             )
         else:
             torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name)
             return input_
 
-    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
+    def _all_reduce_out_place(
+        self, input_: torch.Tensor, outplace_all_reduce_method: str
+    ) -> torch.Tensor:
+        qr_comm = self.qr_comm
         ca_comm = self.ca_comm
         pymscclpp_comm = self.pymscclpp_comm
-        assert ca_comm is not None or pymscclpp_comm is not None
-        if ca_comm is not None and not ca_comm.disabled:
+        assert any([qr_comm, ca_comm, pymscclpp_comm])
+        if outplace_all_reduce_method == "qr":
+            assert not qr_comm.disabled
+            out = qr_comm.quick_all_reduce(input_)
+        elif outplace_all_reduce_method == "ca":
+            assert not ca_comm.disabled
             out = ca_comm.custom_all_reduce(input_)
         else:
             assert not pymscclpp_comm.disabled
@@ -499,6 +545,15 @@ class GroupCoordinator:
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
 
+    def reduce_scatter_tensor(
+        self,
+        output: torch.Tensor,
+        input: torch.Tensor,
+    ) -> None:
+        # TODO(ch-wan): support other backends
+        torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
+        return output
+
     def reduce_scatter(
         self,
         output: torch.Tensor,
@@ -1065,8 +1120,23 @@ def init_model_parallel_group(
 
 _TP: Optional[GroupCoordinator] = None
 
+# duplicate GroupCoordinator for prefill in PD-Multiplexing
+_PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None
+
+_ENABLE_PDMUX_P_TP: bool = False
+
+
+def set_pdmux_status(enable_prefill_multiplexing: bool):
+    global _ENABLE_PDMUX_P_TP
+    _ENABLE_PDMUX_P_TP = enable_prefill_multiplexing
+
 
 def get_tp_group() -> GroupCoordinator:
+    if _ENABLE_PDMUX_P_TP:
+        assert (
+            _PDMUX_PREFILL_TP_GROUP is not None
+        ), "tensor model parallel group for PD-Multiplexing Prefill is not initialized"
+        return _PDMUX_PREFILL_TP_GROUP
     assert _TP is not None, "tensor model parallel group is not initialized"
     return _TP
 
@@ -1182,6 +1252,7 @@
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     backend: Optional[str] = None,
+    duplicate_tp_group: bool = False,
 ) -> None:
     """
     Initialize model parallel groups.
@@ -1239,6 +1310,23 @@
         group_name="tp",
     )
 
+    if duplicate_tp_group:
+        global _PDMUX_PREFILL_TP_GROUP
+        assert (
+            _PDMUX_PREFILL_TP_GROUP is None
+        ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized"
+        _PDMUX_PREFILL_TP_GROUP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            use_message_queue_broadcaster=get_bool_env_var(
+                "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
+            ),
+            group_name="pdmux_prefill_tp",
+        )
+        _TP.pynccl_comm.disabled = False
+        _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
+
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
     global _PP
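
Not part of the diff: the PD-Multiplexing hooks added above (duplicate_tp_group, set_pdmux_status, get_tp_group) let callers switch between the regular TP group and a duplicated prefill TP group at runtime. A rough sketch of that flow, assuming torch.distributed and the world group have already been initialized elsewhere; only the function and flag names come from the hunks above.

from sglang.srt.distributed import parallel_state as ps

# Creates both _TP and the duplicated _PDMUX_PREFILL_TP_GROUP ("pdmux_prefill_tp").
ps.initialize_model_parallel(tensor_model_parallel_size=8, duplicate_tp_group=True)

ps.set_pdmux_status(True)         # prefill phase of PD-Multiplexing
prefill_tp = ps.get_tp_group()    # returns the duplicated prefill coordinator

ps.set_pdmux_status(False)        # decode / normal phase
decode_tp = ps.get_tp_group()     # returns the regular "tp" coordinator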
@@ -46,9 +46,9 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
-    ImageDataItem,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
+    MultimodalDataInputFormat,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
@@ -71,7 +71,6 @@ from sglang.srt.utils import (
     is_cuda,
     kill_process_tree,
     launch_dummy_health_check_server,
-    maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
     set_ulimit,
@@ -148,13 +147,9 @@ class Engine(EngineBase):
         # - List of images (one per request in a batch)
         # - List of lists of images (multiple images per request)
         # See also python/sglang/srt/utils.py:load_image for more details.
-        image_data: Optional[
-            Union[
-                List[List[ImageDataItem]],
-                List[ImageDataItem],
-                ImageDataItem,
-            ]
-        ] = None,
+        image_data: Optional[MultimodalDataInputFormat] = None,
+        audio_data: Optional[MultimodalDataInputFormat] = None,
+        video_data: Optional[MultimodalDataInputFormat] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -187,6 +182,8 @@ class Engine(EngineBase):
             input_ids=input_ids,
             sampling_params=sampling_params,
             image_data=image_data,
+            audio_data=audio_data,
+            video_data=video_data,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
@@ -231,13 +228,9 @@ class Engine(EngineBase):
         # - List of images (one per request in a batch)
         # - List of lists of images (multiple images per request)
         # See also python/sglang/srt/utils.py:load_image for more details.
-        image_data: Optional[
-            Union[
-                List[List[ImageDataItem]],
-                List[ImageDataItem],
-                ImageDataItem,
-            ]
-        ] = None,
+        image_data: Optional[MultimodalDataInputFormat] = None,
+        audio_data: Optional[MultimodalDataInputFormat] = None,
+        video_data: Optional[MultimodalDataInputFormat] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -272,6 +265,8 @@ class Engine(EngineBase):
             input_ids=input_ids,
             sampling_params=sampling_params,
             image_data=image_data,
+            audio_data=audio_data,
+            video_data=video_data,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
@@ -295,19 +290,20 @@ class Engine(EngineBase):
     def encode(
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-        image_data: Optional[
-            Union[
-                List[List[Union[Image, str]]],
-                List[Union[Image, str]],
-                Union[Image, str],
-            ]
-        ] = None,
+        image_data: Optional[MultimodalDataInputFormat] = None,
+        audio_data: Optional[MultimodalDataInputFormat] = None,
+        video_data: Optional[MultimodalDataInputFormat] = None,
     ) -> Dict:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
         Please refer to `EmbeddingReqInput` for the documentation.
         """
-        obj = EmbeddingReqInput(text=prompt, image_data=image_data)
+        obj = EmbeddingReqInput(
+            text=prompt,
+            image_data=image_data,
+            audio_data=audio_data,
+            video_data=video_data,
+        )
         loop = asyncio.get_event_loop()
         generator = self.tokenizer_manager.generate_request(obj, None)
         ret = loop.run_until_complete(generator.__anext__())
@@ -316,7 +312,9 @@
     async def async_encode(
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-        image_data: Optional[Union[List[str], str]] = None,
+        image_data: Optional[MultimodalDataInputFormat] = None,
+        audio_data: Optional[MultimodalDataInputFormat] = None,
+        video_data: Optional[MultimodalDataInputFormat] = None,
     ) -> Dict:
         """
         Asynchronous version of encode method.
@@ -324,7 +322,12 @@
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
         Please refer to `EmbeddingReqInput` for the documentation.
         """
-        obj = EmbeddingReqInput(text=prompt, image_data=image_data)
+        obj = EmbeddingReqInput(
+            text=prompt,
+            image_data=image_data,
+            audio_data=audio_data,
+            video_data=video_data,
+        )
         generator = self.tokenizer_manager.generate_request(obj, None)
         return await generator.__anext__()
 
@@ -633,16 +636,11 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set ulimit
     set_ulimit()
 
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
     # Check flashinfer version
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.7.post1",
+            "0.2.9rc1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
@@ -650,7 +648,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.2.5",
+            "0.2.7",
            "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
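
Not part of the diff: with the Engine signature changes above, audio and video inputs are now accepted alongside images. A hypothetical call might look like the following; the model path and file name are placeholders, and only the parameter names (image_data, audio_data, video_data) come from the diff.

import sglang as sgl

llm = sgl.Engine(model_path="some/audio-capable-model")  # placeholder model path
out = llm.generate(
    prompt="Describe this clip.",
    audio_data="sample.wav",  # new in this release; accepts MultimodalDataInputFormat
    sampling_params={"max_new_tokens": 64},
)
print(out["text"])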