sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py
@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
+    can_auto_enable_marlin_fp8,
     cutlass_fp8_supported,
     dispatch_w8a8_block_fp8_linear,
     input_to_float8,
@@ -79,6 +80,7 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 
@@ -208,17 +210,13 @@ class Fp8LinearMethod(LinearMethodBase):
 
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (
-            get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
-        )
-        # Disable marlin for ROCm
-        if _is_hip:
-            self.use_marlin = False
+        self.use_marlin = False
+        if _is_cuda and MARLIN_FP8_AVAILABLE:
+            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+            auto_enable = can_auto_enable_marlin_fp8()
+            self.use_marlin = force_marlin or auto_enable
 
         self.block_quant = self.quant_config.weight_block_size is not None
-        if self.block_quant:
-            # Marlin doesn't support block-wise fp8
-            self.use_marlin = False
 
         self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
 
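With the hunk above, Marlin FP8 is no longer gated solely on SGLANG_FORCE_FP8_MARLIN: on CUDA builds where the kernel is available it is enabled either by that env var or by can_auto_enable_marlin_fp8(), and block-quantized checkpoints no longer force it off (they are handled later via prepare_fp8_layer_for_marlin). A minimal standalone sketch of the new decision, using hypothetical helper names that mirror the logic shown above rather than sglang's actual module layout:

import os


def env_flag(name: str) -> bool:
    # Rough stand-in for sglang's get_bool_env_var, for this sketch only.
    return os.environ.get(name, "false").lower() in ("1", "true", "yes")


def should_use_marlin_fp8(is_cuda: bool, marlin_available: bool, can_auto_enable: bool) -> bool:
    # Marlin FP8 is only considered on CUDA when the kernel is importable.
    if not (is_cuda and marlin_available):
        return False
    # Forced via the env var, or auto-enabled (e.g. GPUs without native FP8 support).
    return env_flag("SGLANG_FORCE_FP8_MARLIN") or can_auto_enable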
@@ -331,7 +329,6 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if _is_fp8_fnuz:
@@ -341,7 +338,6 @@ class Fp8LinearMethod(LinearMethodBase):
                     weight_scale=layer.weight_scale_inv,
                     input_scale=None,
                 )
-
                 layer.input_scale = None
             elif _is_cpu:
                 assert (
@@ -351,90 +347,94 @@ class Fp8LinearMethod(LinearMethodBase):
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
-                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = torch.nn.Parameter(
-                    weight_scale, requires_grad=False
-                )
-            return
+            layer.weight = Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
-        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+            # If checkpoint not serialized fp8, quantize the weights.
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                if self.cutlass_fp8_supported or self.use_marlin:
+                    # apply per-channel quantization default as
+                    # cutlass sgl-kernel and marlin only support per-channel scale
+                    qweight, weight_scale = per_token_group_quant_fp8(
+                        layer.weight, layer.weight.shape[-1]
+                    )
+                    weight_scale = weight_scale.t().contiguous()
+                else:
+                    # per-tensor quantization
+                    qweight, weight_scale = input_to_float8(layer.weight)
+
+                # Update the layer with the new values.
+                layer.weight = Parameter(qweight.t(), requires_grad=False)
+                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+                layer.input_scale = None
 
-        # If checkpoint not serialized fp8, quantize the weights.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            if self.cutlass_fp8_supported or self.use_marlin:
-                # apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
-                qweight, weight_scale = per_token_group_quant_fp8(
-                    layer.weight, layer.weight.shape[-1]
-                )
-                weight_scale = weight_scale.t().contiguous()
+            # If checkpoint is fp8, handle that there are N scales for N
+            # shards in a fused module
             else:
-                # per-tensor quantization
-                qweight, weight_scale = input_to_float8(layer.weight)
-
-            # Update the layer with the new values.
-            layer.weight = Parameter(qweight.t(), requires_grad=False)
-            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-            layer.input_scale = None
-
-        # If checkpoint is fp8, handle that there are N scales for N
-        # shards in a fused module
-        else:
-            layer.weight_scale = torch.nn.Parameter(
-                layer.weight_scale.data, requires_grad=False
-            )
-            if (
-                hasattr(self.quant_config, "activation_scheme")
-                and self.quant_config.activation_scheme == "static"
-            ) or (
-                hasattr(self.quant_config, "linear_activation_scheme")
-                and self.quant_config.linear_activation_scheme == "static"
-            ):
-                layer.input_scale = torch.nn.Parameter(
-                    layer.input_scale.data, requires_grad=False
+                layer.weight_scale = Parameter(
+                    layer.weight_scale.data, requires_grad=False
                 )
+                if (
+                    hasattr(self.quant_config, "activation_scheme")
+                    and self.quant_config.activation_scheme == "static"
+                ) or (
+                    hasattr(self.quant_config, "linear_activation_scheme")
+                    and self.quant_config.linear_activation_scheme == "static"
+                ):
+                    layer.input_scale = Parameter(
+                        layer.input_scale.data, requires_grad=False
+                    )
 
-            # cutlass sgl-kernel and marlin only support per-channel scale
-            if self.cutlass_fp8_supported or self.use_marlin:
-                weight = layer.weight
-                weight_scale = convert_to_channelwise(
-                    layer.weight_scale, layer.logical_widths
-                )
-            else:
-                # Dequant -> Quant with max scale so we can run per tensor.
-                weight = layer.weight
-                weight_scale = layer.weight_scale
-                # If ROCm, normalize the weights and scales to e4m3fnuz
-                if _is_fp8_fnuz:
-                    weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                # cutlass sgl-kernel and marlin only support per-channel scale
+                if self.cutlass_fp8_supported or self.use_marlin:
+                    weight = layer.weight
+                    weight_scale = convert_to_channelwise(
+                        layer.weight_scale, layer.logical_widths
+                    )
+                else:
+                    # Dequant -> Quant with max scale so we can run per tensor.
+                    weight = layer.weight
+                    weight_scale = layer.weight_scale
+                    # If ROCm, normalize the weights and scales to e4m3fnuz
+                    if _is_fp8_fnuz:
+                        weight, weight_scale, input_scale = (
+                            normalize_e4m3fn_to_e4m3fnuz(
+                                weight=weight,
+                                weight_scale=weight_scale,
+                                input_scale=layer.input_scale,
+                            )
+                        )
+                        if input_scale is not None:
+                            layer.input_scale = Parameter(
+                                input_scale, requires_grad=False
+                            )
+
+                    weight_scale, weight = requantize_with_max_scale(
                         weight=weight,
                         weight_scale=weight_scale,
-                        input_scale=layer.input_scale,
+                        logical_widths=layer.logical_widths,
                     )
-                    if input_scale is not None:
-                        layer.input_scale = Parameter(input_scale, requires_grad=False)
-
-                weight_scale, weight = requantize_with_max_scale(
-                    weight=weight,
-                    weight_scale=weight_scale,
-                    logical_widths=layer.logical_widths,
-                )
 
-            # Update layer with new values.
-            layer.weight = Parameter(weight.t(), requires_grad=False)
-            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-            if (
-                hasattr(self.quant_config, "activation_scheme")
-                and self.quant_config.activation_scheme == "static"
-            ) or (
-                hasattr(self.quant_config, "linear_activation_scheme")
-                and self.quant_config.linear_activation_scheme == "static"
-            ):
-                layer.input_scale = Parameter(
-                    layer.input_scale.max(), requires_grad=False
-                )
+                # Update layer with new values.
+                layer.weight = Parameter(weight.t(), requires_grad=False)
+                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+                if (
+                    hasattr(self.quant_config, "activation_scheme")
+                    and self.quant_config.activation_scheme == "static"
+                ) or (
+                    hasattr(self.quant_config, "linear_activation_scheme")
+                    and self.quant_config.linear_activation_scheme == "static"
+                ):
+                    layer.input_scale = Parameter(
+                        layer.input_scale.max(), requires_grad=False
+                    )
 
         if self.use_marlin:
-            prepare_fp8_layer_for_marlin(layer)
+            if self.block_quant:
+                layer.weight_block_size = self.quant_config.weight_block_size
+            prepare_fp8_layer_for_marlin(layer, not self.block_quant)
             # Activations not quantized for marlin.
             del layer.input_scale
 
@@ -444,7 +444,6 @@ class Fp8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
@@ -515,6 +514,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.use_cutlass_fused_experts_fp8 = (
+            get_bool_env_var("SGLANG_CUTLASS_MOE")
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and (is_sm100_supported() or is_sm90_supported())
+        )
 
     def create_weights(
         self,
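The eligibility check for the cutlass fused-experts FP8 path (SGLANG_CUTLASS_MOE, cutlass support, block quantization, SM90/SM100) is now computed once in __init__ and cached as use_cutlass_fused_experts_fp8; the apply hunk further below replaces the inline check with this flag. A small sketch of the pattern, with assumed capability probes (illustrative names, not sglang APIs):

import os


class MoEMethodSketch:
    def __init__(self, cutlass_supported: bool, block_quant: bool, sm90_or_newer: bool):
        # All inputs are fixed for the process lifetime, so evaluate once here
        # instead of on every forward call.
        self.use_cutlass_fused_experts_fp8 = (
            os.environ.get("SGLANG_CUTLASS_MOE", "false").lower() in ("1", "true")
            and cutlass_supported
            and block_quant
            and sm90_or_newer
        )

    def apply(self, x):
        # The hot path now tests a cached boolean rather than re-reading the
        # environment and re-probing the device on every call.
        if self.use_cutlass_fused_experts_fp8:
            return "cutlass_fused_experts_fp8 path"
        return "triton fused_experts path"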
@@ -961,6 +966,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
+
             # ROCm (_use_aiter): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
@@ -982,12 +988,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
@@ -996,7 +997,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                apply_router_weight_on_input, topk_weights, x
+                moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
 
             return torch.ops.sgl_kernel.fused_experts_cpu(
@@ -1021,18 +1022,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer,
             x,
             topk_output,
-            activation,
-            no_combine,
+            moe_runner_config.activation,
+            moe_runner_config.no_combine,
         )
         if ret is not None:
             return ret
 
-        if (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        ):
+        if self.use_cutlass_fused_experts_fp8:
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 
             topk_weights, topk_ids, _ = topk_output
@@ -1059,9 +1055,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.problem_sizes2,
                 use_fp8_blockscale=True,
             )
-            # TODO: Fuse into select_experts
-            if routed_scaling_factor is not None:
-                output *= routed_scaling_factor
+            # Scale by routed_scaling_factor is fused into select_experts.
             return output
         # Expert fusion with FP8 quantization
         return fused_experts(
@@ -1069,9 +1063,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace and not no_combine,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             w1_scale=(
                 layer.w13_weight_scale_inv
@@ -1084,30 +1076,44 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        *,
-        activation: str = "silu",
-        routed_scaling_factor: Optional[float] = None,
+        topk_output: TopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
+        activation = moe_runner_config.activation
+        routed_scaling_factor = moe_runner_config.routed_scaling_factor
+
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+        assert TopKOutputChecker.format_is_bypassed(topk_output)
+        router_logits = topk_output.router_logits
+        topk_config = topk_output.topk_config
         assert (
             activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
         a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
         # NOTE: scales of hidden states have to be transposed!
         a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
 
+        assert (
+            topk_config.num_expert_group is not None
+            and topk_config.topk_group is not None
+        ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"
+
+        if topk_config.correction_bias is None:
+            correction_bias = topk_config.correction_bias.to(x.dtype)
+        else:
+            correction_bias = None
         return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
-            routing_bias=layer.correction_bias.to(x.dtype),
+            routing_bias=correction_bias,
             hidden_states=a_q,
             hidden_states_scale=a_sf_t,
             gemm1_weights=layer.w13_weight,
@@ -1115,15 +1121,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             gemm2_weights=layer.w2_weight,
             gemm2_weights_scale=layer.w2_weight_scale_inv,
             num_experts=layer.num_experts,
-            top_k=layer.top_k,
-            n_group=layer.num_expert_group,
-            topk_group=layer.topk_group,
+            top_k=topk_config.top_k,
+            n_group=topk_config.num_expert_group,
+            topk_group=topk_config.topk_group,
             intermediate_size=layer.w2_weight.shape[2],
             local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
             local_num_experts=layer.num_local_experts,
-            routed_scaling_factor=routed_scaling_factor,
+            routed_scaling_factor=(
+                routed_scaling_factor if routed_scaling_factor is not None else 1.0
+            ),
             tile_tokens_dim=get_tile_tokens_dim(
-                x.shape[0], layer.top_k, layer.num_experts
+                x.shape[0], topk_config.top_k, layer.num_experts
            ),
            routing_method_type=2,  # DeepSeek-styled routing method
            use_shuffled_weight=False,
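Taken together, the MoE hunks fold the per-call keyword arguments of Fp8MoEMethod.apply (activation, apply_router_weight_on_input, inplace, no_combine, routed_scaling_factor) into a single moe_runner_config object, and apply_with_router_logits now receives a bypassed-format TopKOutput instead of raw router logits. A sketch of what such a config could look like, inferred only from the attribute accesses visible in this diff; the real MoeRunnerConfig lives in sglang.srt.layers.moe.moe_runner and may carry more state:

from dataclasses import dataclass
from typing import Optional


@dataclass
class MoeRunnerConfigSketch:
    # Hypothetical illustration: field names come from the accesses in the hunks
    # above (moe_runner_config.activation, .apply_router_weight_on_input,
    # .inplace, .no_combine, .routed_scaling_factor); defaults mirror the
    # removed keyword arguments. Not the actual sglang definition.
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None


# Built once by the caller and threaded through every quantization method,
# instead of repeating the same keyword arguments at each apply() call site.
config = MoeRunnerConfigSketch(routed_scaling_factor=2.5)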