sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/dp_attention.py
@@ -4,7 +4,7 @@ import functools
 import logging
 from contextlib import contextmanager
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 import torch
 import triton
@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )

+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.server_args import ServerArgs
+
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch

-_ATTN_TP_GROUP = None
-_ATTN_TP_RANK = None
-_ATTN_TP_SIZE = None
-_ATTN_DP_RANK = None
-_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_RANK = None
+_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
+_ATTN_TP_RANK: Optional[int] = None
+_ATTN_TP_SIZE: Optional[int] = None
+_ATTN_DP_RANK: Optional[int] = None
+_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_RANK: Optional[int] = None
+_ENABLE_DP_ATTENTION_FLAG: bool = False


-class DPPaddingMode(IntEnum):
+class DpPaddingMode(IntEnum):

     # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
     MAX_LEN = auto()
@@ -40,13 +45,13 @@ class DPPaddingMode(IntEnum):
     SUM_LEN = auto()

     def is_max_len(self):
-        return self == DPPaddingMode.MAX_LEN
+        return self == DpPaddingMode.MAX_LEN

     def is_sum_len(self):
-        return self == DPPaddingMode.SUM_LEN
+        return self == DpPaddingMode.SUM_LEN

     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)
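Note: the hunk above ends before the actual comparison, so the exact threshold sglang uses is not visible here. A small worked illustration of the two quantities the comment refers to (the token counts are made up):

    global_num_tokens = [5, 7, 3]          # hypothetical per-rank token counts
    max_len = max(global_num_tokens)       # 7: per-rank size in MAX_LEN mode
    sum_len = sum(global_num_tokens)       # 15: total size in SUM_LEN mode
    # MAX_LEN pads every rank to 7 and gathers 7 * len(global_num_tokens) = 21 slots
    # with a dense all_gather_into_tensor; SUM_LEN gathers exactly 15 tokens.
    # The comparison get_dp_padding_mode() performs sits outside the lines shown above.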
@@ -56,10 +61,76 @@ class DPPaddingMode(IntEnum):
             return cls.SUM_LEN

     @classmethod
-    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+    def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
         return cls.MAX_LEN


+class _DpGatheredBufferWrapper:
+
+    _hidden_size: int
+    _dtype: torch.dtype
+    _device: torch.device
+    _global_dp_buffer_len: int
+    _local_dp_buffer_len: int
+
+    @classmethod
+    def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
+        cls._hidden_size = hidden_size
+        cls._dtype = dtype
+        cls._device = device
+
+    @classmethod
+    def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
+        cls._global_dp_buffer_len = global_dp_buffer_len
+        cls._local_dp_buffer_len = local_dp_buffer_len
+
+    @classmethod
+    def get_global_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._global_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_local_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._local_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_global_dp_buffer_len(cls) -> int:
+        return cls._global_dp_buffer_len
+
+    @classmethod
+    def get_local_dp_buffer_len(cls) -> int:
+        return cls._local_dp_buffer_len
+
+
+def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
+    _DpGatheredBufferWrapper.set_dp_buffer_len(
+        global_dp_buffer_len, local_dp_buffer_len
+    )
+
+
+def get_global_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer()
+
+
+def get_local_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer()
+
+
+def get_global_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
+
+
+def get_local_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
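Note: a minimal usage sketch of the buffer helpers added above, assuming set_metadata() has already been called during initialization (the token counts below are made up):

    from sglang.srt.layers.dp_attention import (
        get_global_dp_buffer,
        get_local_dp_buffer,
        set_dp_buffer_len,
    )

    # Instead of carrying a preallocated `gathered_buffer` tensor around (see the
    # logits_processor.py hunks below), callers now record only the lengths...
    set_dp_buffer_len(global_dp_buffer_len=4096, local_dp_buffer_len=512)

    # ...and materialize buffers on demand using the metadata set at init time.
    global_buf = get_global_dp_buffer()   # shape (4096, hidden_size)
    local_buf = get_local_dp_buffer()     # shape (512, hidden_size)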
@@ -89,18 +160,24 @@ def compute_dp_attention_local_info(


 def initialize_dp_attention(
-    enable_dp_attention: bool,
-    tp_rank: int,
-    tp_size: int,
-    dp_size: int,
-    moe_dense_tp_size: int,
-    pp_size: int,
+    server_args: ServerArgs,
+    model_config: ModelConfig,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
-    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG

     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

+    enable_dp_attention = server_args.enable_dp_attention
+    tp_size = server_args.tp_size
+    dp_size = server_args.dp_size
+    moe_dense_tp_size = server_args.moe_dense_tp_size
+    pp_size = server_args.pp_size
+
+    tp_rank = get_tensor_model_parallel_rank()
+
+    _ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
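Note: a hedged sketch of the call-site change implied by the new signature; the actual caller (likely model_runner.py, whose hunk is not shown here) may differ in detail:

    # Before: every field was passed explicitly.
    # initialize_dp_attention(
    #     enable_dp_attention=server_args.enable_dp_attention,
    #     tp_rank=tp_rank,
    #     tp_size=server_args.tp_size,
    #     dp_size=server_args.dp_size,
    #     moe_dense_tp_size=server_args.moe_dense_tp_size,
    #     pp_size=server_args.pp_size,
    # )

    # After: the function unpacks ServerArgs itself and derives tp_rank from
    # get_tensor_model_parallel_rank(); ModelConfig supplies the buffer metadata.
    initialize_dp_attention(server_args=server_args, model_config=model_config)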
@@ -135,38 +212,48 @@ def initialize_dp_attention(
         group_name="attention_tp",
     )

+    _DpGatheredBufferWrapper.set_metadata(
+        hidden_size=model_config.hidden_size,
+        dtype=model_config.dtype,
+        device=torch.device("cuda"),
+    )

-def get_attention_tp_group():
+
+def is_dp_attention_enabled() -> bool:
+    return _ENABLE_DP_ATTENTION_FLAG
+
+
+def get_attention_tp_group() -> GroupCoordinator:
     assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
     return _ATTN_TP_GROUP


-def get_attention_tp_rank():
+def get_attention_tp_rank() -> int:
     assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
     return _ATTN_TP_RANK


-def get_attention_tp_size():
+def get_attention_tp_size() -> int:
     assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_TP_SIZE


-def get_attention_dp_rank():
+def get_attention_dp_rank() -> int:
     assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _ATTN_DP_RANK


-def get_attention_dp_size():
+def get_attention_dp_size() -> int:
     assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_DP_SIZE


-def get_local_attention_dp_rank():
+def get_local_attention_dp_rank() -> int:
     assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_RANK


-def get_local_attention_dp_size():
+def get_local_attention_dp_size() -> int:
     assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_SIZE

@@ -292,6 +379,10 @@ def _dp_gather_via_all_gather(
     forward_batch: ForwardBatch,
     is_partial: bool,
 ):
+    if get_attention_tp_size() == 1:
+        get_tp_group().all_gather_into_tensor(global_tokens, local_tokens)
+        return
+
     if not is_partial:
         if get_attention_tp_rank() != 0:
             local_tokens.fill_(0)
sglang/srt/layers/flashinfer_comm_fusion.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()


 def ensure_workspace_initialized(
-    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+    max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm(
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int = 128,
-    use_oneshot: bool = True,
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
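Note: for orientation, an unfused reference of what flashinfer_allreduce_residual_rmsnorm fuses (all-reduce, residual add, RMSNorm). This is a hedged sketch; the accumulation dtype and return ordering are assumptions, not taken from this diff:

    import torch
    import torch.distributed as dist

    def allreduce_residual_rmsnorm_reference(x, residual, weight, eps=1e-6):
        # All-reduce the tensor-parallel partial sums, fold in the residual,
        # then apply RMSNorm with the given weight.
        dist.all_reduce(x)
        residual = residual + x
        variance = residual.float().pow(2).mean(-1, keepdim=True)
        out = (residual.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
        return out, residual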
sglang/srt/layers/linear.py
@@ -1294,6 +1294,7 @@ class RowParallelLinear(LinearBase):
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
             sm.tag(output_parallel)
+
         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
sglang/srt/layers/logits_processor.py
@@ -27,7 +27,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     attn_tp_all_gather,
     attn_tp_all_gather_into_tensor,
     dp_gather_replicate,
@@ -35,7 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_global_dp_buffer,
     get_local_attention_dp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -108,14 +110,12 @@ class LogitsMetadata:
     # The start position of local hidden states.
     dp_local_start_pos: Optional[torch.Tensor] = None
     dp_local_num_tokens: Optional[torch.Tensor] = None
-    gathered_buffer: Optional[torch.Tensor] = None
-    # Buffer to gather logits from all ranks.
-    forward_batch_gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     # Number of tokens to sample per DP rank
     global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The gather mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for padding
     padded_static_len: int = -1

@@ -164,11 +164,10 @@ class LogitsMetadata:
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
             dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
-            gathered_buffer=forward_batch.gathered_buffer,
-            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
+            global_dp_buffer_len=forward_batch.global_dp_buffer_len,
             global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
             global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.SUM_LEN,
+            dp_padding_mode=DpPaddingMode.SUM_LEN,
         )

     def compute_dp_attention_metadata(self):
@@ -188,16 +187,11 @@ class LogitsMetadata:

         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
-            self.gathered_buffer = torch.empty(
-                (
-                    sum(self.global_num_tokens_for_logprob_cpu),
-                    self.gathered_buffer.shape[1],
-                ),
-                dtype=self.gathered_buffer.dtype,
-                device=self.gathered_buffer.device,
-            )
+            self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
         else:
-            self.gathered_buffer = torch.empty_like(self.gathered_buffer)
+            self.global_dp_buffer_len = self.global_dp_buffer_len

+        set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)


 class LogitsProcessor(nn.Module):
@@ -443,7 +437,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                logits_metadata.gathered_buffer,
+                get_global_dp_buffer(),
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
sglang/srt/layers/moe/cutlass_moe.py
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch

 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()
@@ -124,6 +124,7 @@ def cutlass_fused_experts_fp8(

     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
+            per_group_transpose,
             per_token_group_quant_fp8_hopper_moe_mn_major,
             sglang_per_token_group_quant_fp8,
         )
@@ -152,15 +153,12 @@ def cutlass_fused_experts_fp8(
         k,
     )

-    if is_sm100_supported():
-        a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
-        rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
-        rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
-    else:
-        rep_a = shuffle_rows(a, a_map, (m * topk, k))
-        rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
-            rep_a, expert_offsets, problem_sizes1, 128
-        )
+    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+    rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+    rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+
+    if not is_sm100_supported():
+        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
     w1_scale = w1_scale.contiguous()

     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
@@ -193,12 +191,9 @@ def cutlass_fused_experts_fp8(
     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
     silu_and_mul(c1, intermediate)

-    if is_sm100_supported():
-        intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
-    else:
-        intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
-            intermediate, expert_offsets, problem_sizes2, 128
-        )
+    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    if not is_sm100_supported():
+        a2_scale = per_group_transpose(a2_scale, expert_offsets)
     w2_scale = w2_scale.contiguous()

     fp8_blockwise_scaled_grouped_mm(
sglang/srt/layers/moe/cutlass_w4a8_moe.py
@@ -11,7 +11,7 @@ from sgl_kernel import (
 )

 from sglang.srt.layers.moe.ep_moe.kernels import (
-    post_reorder_triton_kernel,
+    post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
     run_cutlass_moe_ep_preproess,
 )
@@ -199,14 +199,13 @@ def cutlass_w4a8_moe(
     )

     output = torch.empty_like(a)
-    post_reorder_triton_kernel[(m,)](
+    post_reorder_triton_kernel_for_cutlass_moe[(m,)](
         c2,
         output,
         src2dst,
-        topk_ids_,
+        local_topk_ids,
         topk_weights,
-        start_expert_id,
-        end_expert_id,
+        num_experts,
         topk,
         k,
         0,
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -581,6 +581,49 @@ def post_reorder_triton_kernel(
     )


+@triton.jit
+def post_reorder_triton_kernel_for_cutlass_moe(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    dst_start,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    store_ptr = output_ptr + src_idx * hidden_size
+
+    vec = tl.arange(0, BLOCK_SIZE)
+
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + vec
+        mask = offset < hidden_size
+
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            expert_id = tl.load(topk_ids_ptr + idx)
+            if expert_id != num_experts:
+                dst_idx_int32 = tl.load(src2dst_ptr + idx)
+                dst_idx = dst_idx_int32.to(tl.int64)
+                dst_idx = dst_idx - dst_start
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+
 @triton.jit
 def compute_m_range(
     pid,
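Note: the kernel added above (post_reorder_triton_kernel_for_cutlass_moe) scatters each token's top-k expert outputs back to source order via src2dst, accumulates them weighted by topk_weights, and skips slots whose expert id equals num_experts (the sentinel for an invalid slot). A hedged, unoptimized PyTorch reference of the same reduction, for readability only:

    import torch

    def post_reorder_reference(down_output, src2dst, topk_ids, topk_weights,
                               num_experts, dst_start=0):
        # down_output: (num_dispatched, hidden_size) expert outputs (c2 in the caller)
        # src2dst / topk_ids / topk_weights: (num_tokens, topk)
        num_tokens, topk = topk_ids.shape
        out = torch.zeros(num_tokens, down_output.shape[1],
                          dtype=down_output.dtype, device=down_output.device)
        for t in range(num_tokens):
            for j in range(topk):
                if int(topk_ids[t, j]) != num_experts:    # skip invalid expert slots
                    row = int(src2dst[t, j]) - dst_start  # local row in down_output
                    out[t] += down_output[row] * topk_weights[t, j].to(down_output.dtype)
        return out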
sglang/srt/layers/moe/ep_moe/layer.py
@@ -34,6 +34,7 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
+        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -387,7 +388,8 @@ class DeepEPMoE(EPMoE):
             return_recv_hook=True,
         )

-        if self.deepep_mode.enable_low_latency():
+        if self.deepep_mode.enable_low_latency() and not _is_npu:
+            # NPU supports low_latency deepep without deepgemm
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
@@ -404,7 +406,7 @@ class DeepEPMoE(EPMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        else:
+        elif not _is_npu:
             self.w13_weight_fp8 = (
                 self.w13_weight,
                 (
@@ -459,6 +461,8 @@ class DeepEPMoE(EPMoE):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
+        if _is_npu:
+            return self.forward_npu(dispatch_output)
         if dispatch_output.format.is_deepep_normal():
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
@@ -723,6 +727,60 @@ class DeepEPMoE(EPMoE):

         return down_output

+    def forward_npu(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        if TYPE_CHECKING:
+            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
+        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+
+        # NOTE: Ascend's Dispatch & Combine does not support FP16
+        output_dtype = torch.bfloat16
+
+        pertoken_scale = hidden_states[1]
+        hidden_states = hidden_states[0]
+
+        group_list_type = 1
+        seg_indptr = seg_indptr.to(torch.int64)
+
+        import torch_npu
+
+        # gmm1: gate_up_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w13_weight],
+            scale=[self.w13_weight_scale.to(output_dtype)],
+            per_token_scale=[pertoken_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        # act_fn: swiglu
+        hidden_states = torch_npu.npu_swiglu(hidden_states)
+
+        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+        # gmm2: down_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w2_weight],
+            scale=[self.w2_weight_scale.to(output_dtype)],
+            per_token_scale=[swiglu_out_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        return hidden_states
+

 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():