sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/blockwise_int8.py

@@ -3,7 +3,7 @@
  from __future__ import annotations

  import logging
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional

  import torch
  from torch.nn import Module
@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
  from sglang.srt.utils import set_weight_attrs

  if TYPE_CHECKING:
+ from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
  from sglang.srt.layers.moe.topk import TopKOutput

  ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -348,12 +349,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
  layer: torch.nn.Module,
  x: torch.Tensor,
  topk_output: TopKOutput,
- *,
- activation: str = "silu",
- apply_router_weight_on_input: bool = False,
- inplace: bool = True,
- no_combine: bool = False,
- routed_scaling_factor: Optional[float] = None,
+ moe_runner_config: MoeRunnerConfig,
  ) -> torch.Tensor:
  from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -363,15 +359,11 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
  layer.w13_weight,
  layer.w2_weight,
  topk_output=topk_output,
- inplace=inplace,
- activation=activation,
- apply_router_weight_on_input=apply_router_weight_on_input,
+ moe_runner_config=moe_runner_config,
  use_int8_w8a8=True,
  w1_scale=(layer.w13_weight_scale_inv),
  w2_scale=(layer.w2_weight_scale_inv),
  a1_scale=layer.w13_input_scale,
  a2_scale=layer.w2_input_scale,
  block_shape=self.quant_config.weight_block_size,
- no_combine=no_combine,
- routed_scaling_factor=routed_scaling_factor,
  )
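The signature change above is the pattern repeated throughout this release: the per-call keyword arguments (activation, apply_router_weight_on_input, inplace, no_combine, routed_scaling_factor) are bundled into a single moe_runner_config argument. A minimal sketch of what such a bundle and the resulting call could look like is below; the field names simply mirror the arguments removed in this diff, and the MoeRunnerConfig shown here is an illustrative stand-in, not the class defined in sglang/srt/layers/moe/moe_runner.

from dataclasses import dataclass
from typing import Optional

@dataclass
class MoeRunnerConfig:
    # Illustrative stand-in; fields mirror the keyword arguments that the
    # 0.5.1 apply() methods no longer accept individually.
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None

# Callers now build one config object and pass it through unchanged, e.g.:
# method.apply(layer, x, topk_output, moe_runner_config=MoeRunnerConfig(inplace=False))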
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -16,19 +16,33 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_qu
  from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
  from sglang.srt.layers.quantization.utils import (
  all_close_1d,
- cpu_has_amx_support,
  per_tensor_dequantize,
  replace_parameter,
  )
- from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
+ from sglang.srt.utils import (
+ get_bool_env_var,
+ is_cpu,
+ is_cuda,
+ is_hip,
+ is_npu,
+ set_weight_attrs,
+ )

  if TYPE_CHECKING:
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
  from sglang.srt.layers.moe.topk import TopKOutput
  from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
  CompressedTensorsConfig,
  )

+ _is_hip = is_hip()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+ if _use_aiter:
+ from aiter.ops.shuffle import shuffle_weight
+
+ from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1

  try:
  import vllm
@@ -265,37 +279,66 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  max_w13_scales, requires_grad=False
  )

+ if _use_aiter:
+ with torch.no_grad():
+ # Pre-shuffle weights
+ layer.w13_weight = torch.nn.Parameter(
+ shuffle_weight(layer.w13_weight.data, (16, 16)),
+ requires_grad=False,
+ )
+ torch.cuda.empty_cache()
+ layer.w2_weight = torch.nn.Parameter(
+ shuffle_weight(layer.w2_weight.data, (16, 16)),
+ requires_grad=False,
+ )
+ torch.cuda.empty_cache()
+
  def apply(
  self,
  layer: torch.nn.Module,
  x: torch.Tensor,
  topk_output: TopKOutput,
- *,
- activation: str = "silu",
- apply_router_weight_on_input: bool = False,
- inplace: bool = True,
- no_combine: bool = False,
- routed_scaling_factor: Optional[float] = None,
+ moe_runner_config: MoeRunnerConfig,
  ) -> torch.Tensor:
  from sglang.srt.layers.moe.fused_moe_triton import fused_experts

- return fused_experts(
- x,
- layer.w13_weight,
- layer.w2_weight,
- topk_output=topk_output,
- inplace=inplace,
- activation=activation,
- use_fp8_w8a8=True,
- per_channel_quant=self.weight_quant.strategy
- == QuantizationStrategy.CHANNEL,
- w1_scale=layer.w13_weight_scale,
- w2_scale=layer.w2_weight_scale,
- a1_scale=layer.w13_input_scale,
- a2_scale=layer.w2_input_scale,
- apply_router_weight_on_input=apply_router_weight_on_input,
- routed_scaling_factor=routed_scaling_factor,
- )
+ if (
+ _use_aiter
+ and self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+ and moe_runner_config.apply_router_weight_on_input
+ ):
+ topk_weights, topk_ids, _ = topk_output
+ return rocm_fused_experts_tkw1(
+ hidden_states=x,
+ w1=layer.w13_weight,
+ w2=layer.w2_weight,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ activation=moe_runner_config.activation,
+ apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
+ use_fp8_w8a8=True,
+ per_channel_quant=self.weight_quant.strategy
+ == QuantizationStrategy.CHANNEL,
+ w1_scale=layer.w13_weight_scale,
+ w2_scale=layer.w2_weight_scale,
+ a1_scale=layer.w13_input_scale,
+ a2_scale=layer.w2_input_scale,
+ )
+ else:
+ return fused_experts(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_output=topk_output,
+ moe_runner_config=moe_runner_config,
+ use_fp8_w8a8=True,
+ per_channel_quant=self.weight_quant.strategy
+ == QuantizationStrategy.CHANNEL,
+ w1_scale=layer.w13_weight_scale,
+ w2_scale=layer.w2_weight_scale,
+ a1_scale=layer.w13_input_scale,
+ a2_scale=layer.w2_input_scale,
+ )


  class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
@@ -602,12 +645,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
  layer: torch.nn.Module,
  x: torch.Tensor,
  topk_output: TopKOutput,
- *,
- activation: str = "silu",
- **kwargs,
+ moe_runner_config: MoeRunnerConfig,
  ) -> torch.Tensor:

- assert activation == "silu", "Only SiLU activation is supported."
+ assert (
+ moe_runner_config.activation == "silu"
+ ), "Only SiLU activation is supported."

  topk_weights, topk_ids, router_logits = topk_output
sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py

@@ -7,7 +7,8 @@ logger = logging.getLogger(__name__)

  def _compute_enable_deep_gemm():
  sm_version = get_device_sm()
- if sm_version < 90:
+ # TODO fix blackwell fp8
+ if sm_version != 90:
  return False

  try:
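The gate above now requires exactly SM90 rather than SM90 or newer, so Blackwell (SM100) falls back to the non-DeepGEMM path until its FP8 path is fixed. A rough equivalent written directly against torch.cuda, assuming get_device_sm() returns major * 10 + minor for the current device, is:

import torch

def deep_gemm_enabled_sketch() -> bool:
    # Assumes get_device_sm() == major * 10 + minor of the current CUDA device;
    # without a CUDA device DeepGEMM stays disabled.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    # 0.5.1 enables DeepGEMM only on Hopper (SM90); SM80 and SM100 are excluded.
    return major * 10 + minor == 90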
sglang/srt/layers/quantization/fp8.py

@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
  )
  from sglang.srt.layers.quantization.fp8_utils import (
  apply_fp8_linear,
+ can_auto_enable_marlin_fp8,
  cutlass_fp8_supported,
  dispatch_w8a8_block_fp8_linear,
  input_to_float8,
@@ -79,6 +80,7 @@ from sglang.srt.utils import (
  )

  if TYPE_CHECKING:
+ from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
  from sglang.srt.layers.moe.topk import TopKOutput
  from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config

@@ -208,17 +210,13 @@ class Fp8LinearMethod(LinearMethodBase):

  # For GPUs that lack FP8 hardware support, we can leverage the Marlin
  # kernel for fast weight-only FP8 quantization
- self.use_marlin = (
- get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
- )
- # Disable marlin for ROCm
- if _is_hip:
- self.use_marlin = False
+ self.use_marlin = False
+ if _is_cuda and MARLIN_FP8_AVAILABLE:
+ force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+ auto_enable = can_auto_enable_marlin_fp8()
+ self.use_marlin = force_marlin or auto_enable

  self.block_quant = self.quant_config.weight_block_size is not None
- if self.block_quant:
- # Marlin doesn't support block-wise fp8
- self.use_marlin = False

  self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
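With the hunk above, Marlin FP8 is no longer gated solely by the SGLANG_FORCE_FP8_MARLIN environment variable: on CUDA builds where the kernel is available it can also be switched on automatically via can_auto_enable_marlin_fp8(). A condensed sketch of the resulting decision is below; is_cuda, marlin_available, and auto_enable_ok stand in for _is_cuda, MARLIN_FP8_AVAILABLE, and can_auto_enable_marlin_fp8(), and the env-var parsing only approximates sglang's get_bool_env_var.

import os

def use_marlin_fp8_sketch(is_cuda: bool, marlin_available: bool, auto_enable_ok: bool) -> bool:
    # Only CUDA builds with the Marlin FP8 kernel are eligible at all.
    if not (is_cuda and marlin_available):
        return False
    # Either the explicit force flag or the automatic heuristic enables it;
    # block-quantized checkpoints no longer disable it at this point.
    force_marlin = os.environ.get("SGLANG_FORCE_FP8_MARLIN", "false").strip().lower() in ("true", "1")
    return force_marlin or auto_enable_ok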
@@ -331,7 +329,6 @@ class Fp8LinearMethod(LinearMethodBase):
  layer.register_parameter("input_scale", None)

  def process_weights_after_loading(self, layer: Module) -> None:
- # Block quant doesn't need to process weights after loading
  if self.block_quant:
  # If ROCm, normalize the weights and scales to e4m3fnuz
  if _is_fp8_fnuz:
@@ -341,7 +338,6 @@ class Fp8LinearMethod(LinearMethodBase):
  weight_scale=layer.weight_scale_inv,
  input_scale=None,
  )
-
  layer.input_scale = None
  elif _is_cpu:
  assert (
@@ -351,90 +347,94 @@ class Fp8LinearMethod(LinearMethodBase):
  return
  else:
  weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
- layer.weight = torch.nn.Parameter(weight, requires_grad=False)
- layer.weight_scale_inv = torch.nn.Parameter(
- weight_scale, requires_grad=False
- )
- return
+ layer.weight = Parameter(weight, requires_grad=False)
+ layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
+ else:
+ layer.weight = Parameter(layer.weight.data, requires_grad=False)

- layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+ # If checkpoint not serialized fp8, quantize the weights.
+ if not self.quant_config.is_checkpoint_fp8_serialized:
+ if self.cutlass_fp8_supported or self.use_marlin:
+ # apply per-channel quantization default as
+ # cutlass sgl-kernel and marlin only support per-channel scale
+ qweight, weight_scale = per_token_group_quant_fp8(
+ layer.weight, layer.weight.shape[-1]
+ )
+ weight_scale = weight_scale.t().contiguous()
+ else:
+ # per-tensor quantization
+ qweight, weight_scale = input_to_float8(layer.weight)
+
+ # Update the layer with the new values.
+ layer.weight = Parameter(qweight.t(), requires_grad=False)
+ layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+ layer.input_scale = None

- # If checkpoint not serialized fp8, quantize the weights.
- if not self.quant_config.is_checkpoint_fp8_serialized:
- if self.cutlass_fp8_supported or self.use_marlin:
- # apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
- qweight, weight_scale = per_token_group_quant_fp8(
- layer.weight, layer.weight.shape[-1]
- )
- weight_scale = weight_scale.t().contiguous()
+ # If checkpoint is fp8, handle that there are N scales for N
+ # shards in a fused module
  else:
- # per-tensor quantization
- qweight, weight_scale = input_to_float8(layer.weight)
-
- # Update the layer with the new values.
- layer.weight = Parameter(qweight.t(), requires_grad=False)
- layer.weight_scale = Parameter(weight_scale, requires_grad=False)
- layer.input_scale = None
-
- # If checkpoint is fp8, handle that there are N scales for N
- # shards in a fused module
- else:
- layer.weight_scale = torch.nn.Parameter(
- layer.weight_scale.data, requires_grad=False
- )
- if (
- hasattr(self.quant_config, "activation_scheme")
- and self.quant_config.activation_scheme == "static"
- ) or (
- hasattr(self.quant_config, "linear_activation_scheme")
- and self.quant_config.linear_activation_scheme == "static"
- ):
- layer.input_scale = torch.nn.Parameter(
- layer.input_scale.data, requires_grad=False
+ layer.weight_scale = Parameter(
+ layer.weight_scale.data, requires_grad=False
  )
+ if (
+ hasattr(self.quant_config, "activation_scheme")
+ and self.quant_config.activation_scheme == "static"
+ ) or (
+ hasattr(self.quant_config, "linear_activation_scheme")
+ and self.quant_config.linear_activation_scheme == "static"
+ ):
+ layer.input_scale = Parameter(
+ layer.input_scale.data, requires_grad=False
+ )

- # cutlass sgl-kernel and marlin only support per-channel scale
- if self.cutlass_fp8_supported or self.use_marlin:
- weight = layer.weight
- weight_scale = convert_to_channelwise(
- layer.weight_scale, layer.logical_widths
- )
- else:
- # Dequant -> Quant with max scale so we can run per tensor.
- weight = layer.weight
- weight_scale = layer.weight_scale
- # If ROCm, normalize the weights and scales to e4m3fnuz
- if _is_fp8_fnuz:
- weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+ # cutlass sgl-kernel and marlin only support per-channel scale
+ if self.cutlass_fp8_supported or self.use_marlin:
+ weight = layer.weight
+ weight_scale = convert_to_channelwise(
+ layer.weight_scale, layer.logical_widths
+ )
+ else:
+ # Dequant -> Quant with max scale so we can run per tensor.
+ weight = layer.weight
+ weight_scale = layer.weight_scale
+ # If ROCm, normalize the weights and scales to e4m3fnuz
+ if _is_fp8_fnuz:
+ weight, weight_scale, input_scale = (
+ normalize_e4m3fn_to_e4m3fnuz(
+ weight=weight,
+ weight_scale=weight_scale,
+ input_scale=layer.input_scale,
+ )
+ )
+ if input_scale is not None:
+ layer.input_scale = Parameter(
+ input_scale, requires_grad=False
+ )
+
+ weight_scale, weight = requantize_with_max_scale(
  weight=weight,
  weight_scale=weight_scale,
- input_scale=layer.input_scale,
+ logical_widths=layer.logical_widths,
  )
- if input_scale is not None:
- layer.input_scale = Parameter(input_scale, requires_grad=False)
-
- weight_scale, weight = requantize_with_max_scale(
- weight=weight,
- weight_scale=weight_scale,
- logical_widths=layer.logical_widths,
- )

- # Update layer with new values.
- layer.weight = Parameter(weight.t(), requires_grad=False)
- layer.weight_scale = Parameter(weight_scale, requires_grad=False)
- if (
- hasattr(self.quant_config, "activation_scheme")
- and self.quant_config.activation_scheme == "static"
- ) or (
- hasattr(self.quant_config, "linear_activation_scheme")
- and self.quant_config.linear_activation_scheme == "static"
- ):
- layer.input_scale = Parameter(
- layer.input_scale.max(), requires_grad=False
- )
+ # Update layer with new values.
+ layer.weight = Parameter(weight.t(), requires_grad=False)
+ layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+ if (
+ hasattr(self.quant_config, "activation_scheme")
+ and self.quant_config.activation_scheme == "static"
+ ) or (
+ hasattr(self.quant_config, "linear_activation_scheme")
+ and self.quant_config.linear_activation_scheme == "static"
+ ):
+ layer.input_scale = Parameter(
+ layer.input_scale.max(), requires_grad=False
+ )

  if self.use_marlin:
- prepare_fp8_layer_for_marlin(layer)
+ if self.block_quant:
+ layer.weight_block_size = self.quant_config.weight_block_size
+ prepare_fp8_layer_for_marlin(layer, not self.block_quant)
  # Activations not quantized for marlin.
  del layer.input_scale
@@ -444,7 +444,6 @@ class Fp8LinearMethod(LinearMethodBase):
  x: torch.Tensor,
  bias: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
-
  if self.use_marlin:
  return apply_fp8_marlin_linear(
  input=x,
@@ -515,6 +514,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  self.quant_config = quant_config
  self.block_quant = self.quant_config.weight_block_size is not None
  self.cutlass_fp8_supported = cutlass_fp8_supported()
+ self.use_cutlass_fused_experts_fp8 = (
+ get_bool_env_var("SGLANG_CUTLASS_MOE")
+ and self.cutlass_fp8_supported
+ and self.block_quant
+ and (is_sm100_supported() or is_sm90_supported())
+ )

  def create_weights(
  self,
@@ -961,6 +966,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  requires_grad=False,
  )
  torch.cuda.empty_cache()
+
  # ROCm (_use_aiter): using column-wise scaling
  layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
  layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
@@ -982,12 +988,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  layer: torch.nn.Module,
  x: torch.Tensor,
  topk_output: TopKOutput,
- *,
- activation: str = "silu",
- apply_router_weight_on_input: bool = False,
- inplace: bool = True,
- no_combine: bool = False,
- routed_scaling_factor: Optional[float] = None,
+ moe_runner_config: MoeRunnerConfig,
  ) -> torch.Tensor:
  from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -996,7 +997,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):

  topk_weights, topk_ids, _ = topk_output
  x, topk_weights = apply_topk_weights_cpu(
- apply_router_weight_on_input, topk_weights, x
+ moe_runner_config.apply_router_weight_on_input, topk_weights, x
  )

  return torch.ops.sgl_kernel.fused_experts_cpu(
@@ -1021,18 +1022,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  layer,
  x,
  topk_output,
- activation,
- no_combine,
+ moe_runner_config.activation,
+ moe_runner_config.no_combine,
  )
  if ret is not None:
  return ret

- if (
- get_bool_env_var("SGLANG_CUTLASS_MOE")
- and self.cutlass_fp8_supported
- and self.block_quant
- and (is_sm100_supported() or is_sm90_supported())
- ):
+ if self.use_cutlass_fused_experts_fp8:
  from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

  topk_weights, topk_ids, _ = topk_output
@@ -1059,9 +1055,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  self.problem_sizes2,
  use_fp8_blockscale=True,
  )
- # TODO: Fuse into select_experts
- if routed_scaling_factor is not None:
- output *= routed_scaling_factor
+ # Scale by routed_scaling_factor is fused into select_experts.
  return output
  # Expert fusion with FP8 quantization
  return fused_experts(
@@ -1069,9 +1063,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  layer.w13_weight,
  layer.w2_weight,
  topk_output=topk_output,
- inplace=inplace and not no_combine,
- activation=activation,
- apply_router_weight_on_input=apply_router_weight_on_input,
+ moe_runner_config=moe_runner_config,
  use_fp8_w8a8=True,
  w1_scale=(
  layer.w13_weight_scale_inv
@@ -1084,30 +1076,44 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  a1_scale=layer.w13_input_scale,
  a2_scale=layer.w2_input_scale,
  block_shape=self.quant_config.weight_block_size,
- no_combine=no_combine,
- routed_scaling_factor=routed_scaling_factor,
  )

  def apply_with_router_logits(
  self,
  layer: torch.nn.Module,
  x: torch.Tensor,
- router_logits: torch.Tensor,
- *,
- activation: str = "silu",
- routed_scaling_factor: Optional[float] = None,
+ topk_output: TopKOutput,
+ moe_runner_config: MoeRunnerConfig,
  ) -> torch.Tensor:
+ activation = moe_runner_config.activation
+ routed_scaling_factor = moe_runner_config.routed_scaling_factor
+
+ from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+ from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+ assert TopKOutputChecker.format_is_bypassed(topk_output)
+ router_logits = topk_output.router_logits
+ topk_config = topk_output.topk_config
  assert (
  activation == "silu"
  ), "Only silu is supported for flashinfer blockscale fp8 moe"
  a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
  # NOTE: scales of hidden states have to be transposed!
  a_sf_t = a_sf.t().contiguous()
- from flashinfer.fused_moe import trtllm_fp8_block_scale_moe

+ assert (
+ topk_config.num_expert_group is not None
+ and topk_config.topk_group is not None
+ ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"
+
+ if topk_config.correction_bias is None:
+ correction_bias = topk_config.correction_bias.to(x.dtype)
+ else:
+ correction_bias = None
  return trtllm_fp8_block_scale_moe(
  routing_logits=router_logits.to(torch.float32),
- routing_bias=layer.correction_bias.to(x.dtype),
+ routing_bias=correction_bias,
  hidden_states=a_q,
  hidden_states_scale=a_sf_t,
  gemm1_weights=layer.w13_weight,
@@ -1115,15 +1121,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  gemm2_weights=layer.w2_weight,
  gemm2_weights_scale=layer.w2_weight_scale_inv,
  num_experts=layer.num_experts,
- top_k=layer.top_k,
- n_group=layer.num_expert_group,
- topk_group=layer.topk_group,
+ top_k=topk_config.top_k,
+ n_group=topk_config.num_expert_group,
+ topk_group=topk_config.topk_group,
  intermediate_size=layer.w2_weight.shape[2],
  local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
  local_num_experts=layer.num_local_experts,
- routed_scaling_factor=routed_scaling_factor,
+ routed_scaling_factor=(
+ routed_scaling_factor if routed_scaling_factor is not None else 1.0
+ ),
  tile_tokens_dim=get_tile_tokens_dim(
- x.shape[0], layer.top_k, layer.num_experts
+ x.shape[0], topk_config.top_k, layer.num_experts
  ),
  routing_method_type=2, # DeepSeek-styled routing method
  use_shuffled_weight=False,