sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import torch
 
@@ -13,29 +13,23 @@ from sglang.srt.layers.moe import (
     get_moe_runner_backend,
     should_use_flashinfer_trtllm_moe,
 )
-from sglang.srt.layers.moe.ep_moe.kernels import (
-    ep_gather,
-    ep_scatter,
-    silu_and_mul_masked_post_quant_fwd,
-    tma_align_input_scale,
-)
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
+from sglang.srt.layers.moe.token_dispatcher.deepep import (
+    DeepEPLLCombineInput,
+    DeepEPNormalCombineInput,
+)
 from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.fp8_kernel import (
-    is_fp8_fnuz,
-    sglang_per_token_group_quant_fp8,
-)
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
-from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
-from sglang.srt.utils.offloader import get_offloader
+from sglang.srt.utils import get_bool_env_var, is_hip, is_npu
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
-        DeepEPLLOutput,
-        DeepEPNormalOutput,
+        DeepEPLLDispatchOutput,
+        DeepEPNormalDispatchOutput,
         DispatchOutput,
     )
 
@@ -45,7 +39,7 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if not (_is_npu or _is_hip):
-    from sgl_kernel import silu_and_mul
+    pass
 
 if _use_aiter:
     from aiter import ActivationType, QuantType
@@ -90,6 +84,18 @@ class DeepEPMoE(FusedMoE):
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if _use_aiter or _is_npu:
+            self.deprecate_flag = False
+        elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and isinstance(
+            quant_config, Fp8Config
+        ):
+            self.deprecate_flag = True
+        else:
+            self.deprecate_flag = False
+
+        if self.deprecate_flag:
+            return
+
         if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.use_fp8_w8a8 = True
@@ -100,6 +106,7 @@ class DeepEPMoE(FusedMoE):
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
         else:
+            self.use_w4afp8 = False
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.use_w4afp8 = False
@@ -124,23 +131,6 @@ class DeepEPMoE(FusedMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        elif not _is_npu:
-            self.w13_weight_fp8 = (
-                self.w13_weight,
-                (
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant or self.use_w4afp8
-                    else self.w13_weight_scale
-                ),
-            )
-            self.w2_weight_fp8 = (
-                self.w2_weight,
-                (
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant or self.use_w4afp8
-                    else self.w2_weight_scale
-                ),
-            )
 
     def forward(
         self,
@@ -151,6 +141,14 @@ class DeepEPMoE(FusedMoE):
         disable_sbo=False,
     ):
 
+        if self.deprecate_flag:
+            assert forward_shared_experts is None
+            assert alt_stream is None
+            return super().forward(
+                hidden_states,
+                topk_output,
+            )
+
         # We have to call SBO inside MoE to be compatible with hooks used in offloading
         return single_batch_overlap.execute_sbo(
             hidden_states=hidden_states,
@@ -177,35 +175,50 @@ class DeepEPMoE(FusedMoE):
         dispatch_output: DispatchOutput,
         down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
     ):
+
+        if self.deprecate_flag:
+            assert down_gemm_overlap_args is None
+            return super().run_moe_core(
+                dispatch_output,
+            )
+
         from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
 
         if _use_aiter:
             assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
-            return self.forward_aiter(dispatch_output)
-        if _is_npu:
+            output = self.forward_aiter(dispatch_output)
+        elif _is_npu:
             assert DispatchOutputChecker.format_is_deepep(dispatch_output)
-            return self.forward_npu(dispatch_output)
-        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+            output = self.forward_npu(dispatch_output)
+        elif DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
             if self.use_w4afp8:
-                return self.forward_cutlass_w4afp8(dispatch_output)
-            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
-            return self.forward_deepgemm_contiguous(dispatch_output)
+                output = self.forward_cutlass_w4afp8(dispatch_output)
+            else:
+                assert False, "forward_deepgemm_contiguous is deprecated"
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             if (
                 get_moe_runner_backend().is_flashinfer_cutedsl()
                 and self.quant_config.get_name() == "modelopt_fp4"
             ):
-                return self.forward_flashinfer_cutedsl(
+                output = self.forward_flashinfer_cutedsl(
                     dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
                 )
-            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
-            assert down_gemm_overlap_args is None
-            return self.forward_deepgemm_masked(dispatch_output)
-        else:
-            raise ValueError(
-                f"Dispatch output format {dispatch_output.format} is not supported"
-            )
+            elif self.use_w4afp8:
+                output = self.forward_cutlass_w4afp8_masked(dispatch_output)
+            else:
+                assert False, "forward_deepgemm_masked is deprecated"
+
+        combine_input_wrapper = (
+            DeepEPNormalCombineInput
+            if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
+            else DeepEPLLCombineInput
+        )
+        return combine_input_wrapper(
+            hidden_states=output,
+            topk_ids=dispatch_output.topk_ids,
+            topk_weights=dispatch_output.topk_weights,
+        )
 
     def combine(
         self,
@@ -223,7 +236,7 @@ class DeepEPMoE(FusedMoE):
 
     def forward_aiter(
         self,
-        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
+        dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
     ):
         hidden_states, topk_ids, topk_weights = (
             dispatch_output.hidden_states,
@@ -255,158 +268,9 @@ class DeepEPMoE(FusedMoE):
             expert_mask=self.expert_mask,
         )
 
-    def forward_deepgemm_contiguous(
-        self,
-        dispatch_output: DeepEPNormalOutput,
-    ):
-        (
-            hidden_states,
-            hidden_states_scale,
-            topk_ids,
-            topk_weights,
-            num_recv_tokens_per_expert,
-        ) = dispatch_output
-        assert self.quant_method is not None
-        assert self.moe_runner_config.activation == "silu"
-        if num_recv_tokens_per_expert is None:
-            return hidden_states.bfloat16()
-        all_tokens = sum(num_recv_tokens_per_expert)
-        if all_tokens <= 0:
-            return hidden_states.bfloat16()
-        M, K = hidden_states.size()
-        N = self.w13_weight.size(1)
-        scale_block_size = 128
-
-        w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        w2_weight_fp8 = (
-            self.w2_weight,
-            (
-                self.w2_weight_scale_inv
-                if self.use_block_quant
-                else self.w2_weight_scale
-            ),
-        )
-
-        hidden_states_shape = hidden_states.shape
-        hidden_states_device = hidden_states.device
-        hidden_states_dtype = hidden_states.dtype
-
-        input_tensor = [
-            torch.empty(
-                (all_tokens, K),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            ),
-            (
-                # TODO check whether need `zeros`
-                torch.zeros(
-                    (ceil_div(K // 128, 4), all_tokens),
-                    device=hidden_states.device,
-                    dtype=torch.int,
-                ).transpose(0, 1)
-                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else torch.empty(
-                    (all_tokens, K // 128),
-                    device=hidden_states.device,
-                    dtype=torch.float32,
-                )
-            ),
-        ]
-        m_indices = torch.empty(
-            all_tokens, device=hidden_states.device, dtype=torch.int32
-        )
-        output_index = torch.empty_like(topk_ids)
-
-        if get_offloader().forbid_copy_engine_usage:
-            num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
-                num_recv_tokens_per_expert
-            )
-        else:
-            num_recv_tokens_per_expert_gpu = torch.tensor(
-                num_recv_tokens_per_expert,
-                dtype=torch.int32,
-                pin_memory=True,
-                device="cpu",
-            ).cuda(non_blocking=True)
-        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
-
-        ep_scatter(
-            hidden_states,
-            hidden_states_scale,
-            topk_ids,
-            num_recv_tokens_per_expert_gpu,
-            expert_start_loc,
-            input_tensor[0],
-            input_tensor[1],
-            m_indices,
-            output_index,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
-        dispose_tensor(hidden_states)
-
-        gateup_output = torch.empty(
-            (all_tokens, N),
-            device=hidden_states_device,
-            dtype=torch.bfloat16,
-        )
-        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            input_tensor[1] = tma_align_input_scale(input_tensor[1])
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
-            input_tensor, w13_weight_fp8, gateup_output, m_indices
-        )
-        del input_tensor
-        down_input = torch.empty(
-            (
-                all_tokens,
-                N // 2,
-            ),
-            device=gateup_output.device,
-            dtype=torch.bfloat16,
-        )
-        silu_and_mul(gateup_output.view(-1, N), down_input)
-        del gateup_output
-        down_output = torch.empty(
-            (all_tokens, K),
-            device=hidden_states_device,
-            dtype=torch.bfloat16,
-        )
-        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input,
-            scale_block_size,
-            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
-        del down_input
-        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            down_input_scale = tma_align_input_scale(down_input_scale)
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
-            (down_input_fp8, down_input_scale),
-            w2_weight_fp8,
-            down_output,
-            m_indices,
-        )
-        del down_input_fp8, down_input_scale
-
-        gather_out = torch.empty(
-            hidden_states_shape,
-            device=hidden_states_device,
-            dtype=torch.bfloat16,
-        )
-        ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)
-
-        return gather_out
-
     def forward_flashinfer_cutedsl(
         self,
-        dispatch_output: DeepEPLLOutput,
+        dispatch_output: DeepEPLLDispatchOutput,
         down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
     ):
         hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output
@@ -424,7 +288,7 @@ class DeepEPMoE(FusedMoE):
 
     def forward_cutlass_w4afp8(
         self,
-        dispatch_output: DeepEPNormalOutput,
+        dispatch_output: DeepEPNormalDispatchOutput,
     ):
         assert self.moe_runner_config.activation == "silu"
         assert isinstance(self.quant_method, W4AFp8MoEMethod)
@@ -433,89 +297,23 @@ class DeepEPMoE(FusedMoE):
             dispatch_output=dispatch_output,
         )
 
-    def forward_deepgemm_masked(
+    def forward_cutlass_w4afp8_masked(
         self,
-        dispatch_output: DeepEPLLOutput,
+        dispatch_output: DeepEPLLDispatchOutput,
     ):
-        hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
-        assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
-        assert (
-            hidden_states_scale.dtype == torch.float32
-        ), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}"
-
-        # GroupGemm-0
-        num_groups, m, k = hidden_states.size()
-        n = self.w13_weight.size(1)
-        expected_m = min(expected_m, m)
-        gateup_output = torch.empty(
-            (num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
-        )
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            (hidden_states, hidden_states_scale),
-            self.w13_weight_fp8,
-            gateup_output,
-            masked_m,
-            expected_m,
-        )
-        dispose_tensor(hidden_states)
-
-        # Act
-        down_input = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2,
-            ),
-            device=gateup_output.device,
-            dtype=self.fp8_dtype,
-        )
-        scale_block_size = 128
-        down_input_scale = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2 // scale_block_size,
-            ),
-            device=gateup_output.device,
-            dtype=torch.float32,
-        )
-        silu_and_mul_masked_post_quant_fwd(
-            gateup_output,
-            down_input,
-            down_input_scale,
-            scale_block_size,
-            masked_m,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
-        del gateup_output
-
-        # GroupGemm-1
-        n = self.w2_weight.size(1)
-        down_input_fp8 = (
-            down_input,
-            (
-                down_input_scale
-                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
-            ),
-        )
-        down_output = torch.empty(
-            (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
-        )
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8,
-            self.w2_weight_fp8,
-            down_output,
-            masked_m,
-            expected_m,
+        assert isinstance(self.quant_method, W4AFp8MoEMethod)
+        assert get_bool_env_var(
+            "SGLANG_DEEPEP_BF16_DISPATCH"
+        ), "W4AFP8 does not support FP8 dispatch; please set SGLANG_DEEPEP_BF16_DISPATCH=1."
+        return self.quant_method.apply_deepep_ll(
+            layer=self,
+            dispatch_output=dispatch_output,
         )
 
-        return down_output
-
     def forward_npu(
         self,
-        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
+        dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
     ):
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
@@ -528,9 +326,9 @@ class DeepEPMoE(FusedMoE):
         output_dtype = torch.bfloat16
         group_list_type = 1
 
-        def _forward_normal(dispatch_output: DeepEPNormalOutput):
+        def _forward_normal(dispatch_output: DeepEPNormalDispatchOutput):
             if TYPE_CHECKING:
-                assert isinstance(dispatch_output, DeepEPNormalOutput)
+                assert isinstance(dispatch_output, DeepEPNormalDispatchOutput)
             hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
                 dispatch_output
             )
@@ -600,9 +398,9 @@ class DeepEPMoE(FusedMoE):
 
             return hidden_states
 
-        def _forward_ll(dispatch_output: DeepEPLLOutput):
+        def _forward_ll(dispatch_output: DeepEPLLDispatchOutput):
             if TYPE_CHECKING:
-                assert isinstance(dispatch_output, DeepEPLLOutput)
+                assert isinstance(dispatch_output, DeepEPLLDispatchOutput)
             (
                 hidden_states,
                 hidden_states_scale,
@@ -713,12 +511,3 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
     if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
     return FusedMoE
-
-
-def copy_list_to_gpu_no_ce(arr: List[int]):
-    from sgl_kernel.elementwise import copy_to_gpu_no_ce
-
-    tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
-    tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
-    copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
-    return tensor_gpu
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
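
The JSON above is a Triton kernel tuning table for the fused-MoE path: each top-level key is a token-batch size M and maps to the block sizes, warp count, and pipeline stages to use at that size. As a minimal sketch of how such a table is typically consumed (assuming the usual nearest-key lookup convention; the helper name and the literal path are hypothetical, not sglang's actual API):

import json
from pathlib import Path

def load_moe_config(path: str, num_tokens: int) -> dict:
    # Illustrative sketch: read a tuning table like the one above and pick the
    # entry whose stringified token-count key is closest to num_tokens.
    table = json.loads(Path(path).read_text())
    nearest = min(table.keys(), key=lambda k: abs(int(k) - num_tokens))
    return table[nearest]

# Example: for 300 tokens this selects the "256" entry
# (BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, BLOCK_SIZE_K=128, GROUP_SIZE_M=1, ...).
# config = load_moe_config("E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json", 300)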