sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/fp8.py
@@ -8,15 +8,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
-from sglang.srt.layers.quantization.utils import (
-    all_close_1d,
-    convert_to_channelwise,
-    is_layer_skipped,
-    per_tensor_dequantize,
-    requantize_with_max_scale,
-)
-
 try:
     from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
         apply_fp8_marlin_linear,
@@ -27,11 +18,12 @@ try:
 except ImportError:
     MARLIN_FP8_AVAILABLE = False
 
-    def apply_fp8_marlin_linear(*args, **kwargs):
-        raise ImportError("vllm is not installed")
+    def dummy_func(*args, **kwargs):
+        raise ImportError(
+            "marlin FP8 requires some operators from vllm. Please install vllm."
+        )
 
-    def prepare_fp8_layer_for_marlin(*args, **kwargs):
-        raise ImportError("vllm is not installed")
+    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
 
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
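
Note: the hunk above collapses the two per-function stubs into a single `dummy_func` that both missing marlin symbols alias to, so the error is raised only when the marlin FP8 path is actually exercised and carries a more specific message than "vllm is not installed". A self-contained sketch of the same optional-dependency pattern (the module and function names below are invented for illustration, not sglang's):

# Sketch of the fallback pattern adopted above; names are illustrative only.
try:
    from some_optional_backend import fast_linear, prepare_layer  # hypothetical

    BACKEND_AVAILABLE = True
except ImportError:
    BACKEND_AVAILABLE = False

    def _missing_backend(*args, **kwargs):
        # Fails only if a caller actually reaches the optional code path.
        raise ImportError("this path needs the optional backend; please install it")

    # Every missing symbol aliases the same stub.
    fast_linear = prepare_layer = _missing_backend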
@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    scaled_fp8_quant,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     apply_w8a8_block_fp8_linear,
@@ -57,29 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
+    all_close_1d,
+    convert_to_channelwise,
+    is_layer_skipped,
+    per_tensor_dequantize,
+    requantize_with_max_scale,
+)
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
     is_hip,
-    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
 
-ACTIVATION_SCHEMES = ["static", "dynamic"]
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()
 
 if _is_hip:
-    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight
 
-_is_cuda = is_cuda()
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant
 
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = logging.getLogger(__name__)
 
@@ -242,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
         layer.logical_widths = output_partition_sizes
-
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.orig_dtype = params_dtype
@@ -326,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
                     layer.weight_scale_inv.data, requires_grad=False
                 )
             return
+
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             if self.cutlass_fp8_supported or self.use_marlin:
@@ -390,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 )
 
         if self.use_marlin:
-            try:
-                prepare_fp8_layer_for_marlin(layer)
-                # Activations not quantized for marlin.
-                del layer.input_scale
-            except ImportError:
-                self.use_marlin = False
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale
 
     def apply(
         self,
@@ -405,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
     ) -> torch.Tensor:
 
         if self.use_marlin:
-            try:
-                return apply_fp8_marlin_linear(
-                    input=x,
-                    weight=layer.weight,
-                    weight_scale=layer.weight_scale,
-                    workspace=layer.workspace,
-                    size_n=layer.output_size_per_partition,
-                    size_k=layer.input_size_per_partition,
-                    bias=bias,
-                )
-            except ImportError:
-                self.use_marlin = False
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias,
+            )
 
         if self.block_quant:
             return apply_w8a8_block_fp8_linear(
@@ -487,7 +483,7 @@ class Fp8MoEMethod:
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
-                torch.int32
+                torch.uint32
                 if get_bool_env_var("USE_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
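
Note: the packed INT4 MoE weight is now allocated as torch.uint32 rather than torch.int32; in either case one 32-bit word holds eight 4-bit values (hence the "INT4 MoE weight - INT32 packed" comment below). Purely as an illustration of that packing, and not sglang's actual packing code:

import torch

def pack_int4_reference(vals: torch.Tensor) -> torch.Tensor:
    # Illustrative: vals has shape (..., 8) with integers in [0, 15]; each group
    # of eight 4-bit values becomes one 32-bit word (held in int64 here simply
    # to sidestep signed-overflow concerns in this sketch).
    assert vals.shape[-1] == 8
    shifts = torch.arange(0, 32, 4, dtype=torch.int64)
    return (vals.to(torch.int64) << shifts).sum(dim=-1)

words = pack_int4_reference(torch.randint(0, 16, (4, 8)))
print(words.shape)  # torch.Size([4])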
@@ -515,7 +511,7 @@ class Fp8MoEMethod:
                     )
 
         # WEIGHTS
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -616,7 +612,7 @@ class Fp8MoEMethod:
             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
             set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -648,7 +644,7 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             self.process_weights_hip_int4(layer)
             return
 
@@ -705,20 +701,12 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             for expert in range(layer.num_experts):
-                if _is_cuda:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
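
Note: after this change both platforms call a single `scaled_fp8_quant` (sglang's fp8_kernel implementation, or vllm's op when not on CUDA) to quantize each expert's weight with one scale per tensor, returning a (quantized_weight, scale) pair. As a rough pure-PyTorch reference for what such a per-tensor quantization computes (not the actual kernel; 448 is the finite max of float8_e4m3fn):

import torch

FP8_MAX = 448.0  # largest finite value representable in torch.float8_e4m3fn

def per_tensor_fp8_quant_reference(w: torch.Tensor):
    # Rough reference only: scale so the max magnitude maps to FP8_MAX,
    # then cast; dequantization is q.float() * scale.
    scale = w.abs().max().clamp(min=1e-12).float() / FP8_MAX
    q = (w.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale

w = torch.randn(128, 256, dtype=torch.bfloat16)
q, s = per_tensor_fp8_quant_reference(w)
w_hat = q.float() * s  # approximate reconstruction of w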
@@ -795,18 +783,10 @@ class Fp8MoEMethod:
                         layer.w13_weight[expert_id][start : start + shard_size, :],
                         layer.w13_weight_scale[expert_id][shard_id],
                     )
-                    if _is_cuda:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    else:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = vllm_ops.scaled_fp8_quant(
-                            dq_weight, max_w13_scales[expert_id]
-                        )
+                    (
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        _,
+                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                     start += shard_size
 
             layer.w13_weight_scale = torch.nn.Parameter(
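
Note: this hunk is the step that merges the separate w1/w3 shard scales into a single per-expert scale: each shard is dequantized with its old scale and requantized with the per-expert maximum (here `scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])` quantizes using the caller-provided scale). A small sketch of that idea, using the same q * scale convention as the reference above:

import torch

FP8_MAX = 448.0

def requantize_with_shared_scale(q_shards, scales):
    # Sketch: re-express several FP8 shards under one shared (max) scale.
    shared_scale = torch.stack(list(scales)).max()
    out = []
    for q, s in zip(q_shards, scales):
        dq = q.float() * s  # dequantize with the shard's old scale
        rq = (dq / shared_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        out.append(rq)  # requantized with the shared scale
    return out, shared_scale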
@@ -822,12 +802,14 @@ class Fp8MoEMethod:
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            permute_weight(layer.w13_weight.data),
+            # permute_weight(layer.w13_weight.data),
+            shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            permute_weight(layer.w2_weight.data),
+            # permute_weight(layer.w2_weight.data),
+            shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
@@ -867,12 +849,14 @@ class Fp8MoEMethod:
 
         if get_bool_env_var("CK_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -908,6 +892,7 @@ class Fp8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
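
Note: this hunk only adds `routed_scaling_factor` to the method signature; the following hunk threads it through to `select_experts`, and the diff does not show how that function consumes it. For intuition only (this is not sglang's `select_experts`), a routed scaling factor of this kind is typically used to rescale the routing weights that combine expert outputs:

import torch

def apply_routed_scaling(topk_weights: torch.Tensor, routed_scaling_factor=None):
    # Illustrative assumption: None leaves the weights untouched, otherwise
    # the top-k routing weights are simply rescaled.
    if routed_scaling_factor is not None:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights

w = torch.softmax(torch.randn(4, 8), dim=-1).topk(2, dim=-1).values
print(apply_routed_scaling(w, 2.5).shape)  # torch.Size([4, 2])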
@@ -923,41 +908,14 @@ class Fp8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
-            # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
-            assert not no_combine, f"{no_combine=} is not supported."
-            return asm_moe(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                layer.w13_weight_scale1,
-                layer.w2_weight_scale1,
-                activation=activation,
-            )
-        if _is_hip and get_bool_env_var("CK_MOE"):
-            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
-            assert (
-                activation == "silu"
-            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
-            assert not no_combine, f"{no_combine=} is not supported."
-            if self.block_quant:
-                return asm_moe(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-                    block_shape=tuple(self.quant_config.weight_block_size),
-                    expert_mask=None,
-                )
-            else:
-                return asm_moe(
+        if _is_hip:
+            if get_bool_env_var("USE_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+                assert not no_combine, f"{no_combine=} is not supported."
+                return ck_moe_2stages_win4(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
@@ -965,34 +923,71 @@
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                 )
-        else:
-            # Expert fusion with FP8 quantization
-            return fused_experts(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=inplace and not no_combine,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                use_fp8_w8a8=True,
-                w1_scale=(
-                    layer.w13_weight_scale_inv
-                    if self.block_quant
-                    else layer.w13_weight_scale
-                ),
-                w2_scale=(
-                    layer.w2_weight_scale_inv
-                    if self.block_quant
-                    else layer.w2_weight_scale
-                ),
-                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale,
-                block_shape=self.quant_config.weight_block_size,
-                no_combine=no_combine,
-            )
+
+            if get_bool_env_var("CK_MOE"):
+                assert not no_combine, f"{no_combine=} is not supported."
+                if self.block_quant:
+                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    assert (
+                        activation == "silu"
+                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    return asm_moe(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale_inv,
+                        layer.w2_weight_scale_inv,
+                        block_shape=tuple(self.quant_config.weight_block_size),
+                        expert_mask=None,
+                    )
+                else:
+                    return ck_moe_2stages(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale1,
+                        layer.w2_weight_scale1,
+                        activation=(
+                            ActivationType.Silu
+                            if activation == "silu"
+                            else ActivationType.Gelu
+                        ),
+                    )
+
+        # Expert fusion with FP8 quantization
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace and not no_combine,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            use_fp8_w8a8=True,
+            w1_scale=(
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            ),
+            w2_scale=(
+                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+            ),
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            block_shape=self.quant_config.weight_block_size,
+            no_combine=no_combine,
+        )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
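
Note: taken together, the last two hunks flatten the ROCm dispatch in Fp8MoEMethod.apply: INT4-weight checkpoints go to ck_moe_2stages_win4, CK_MOE with block quantization stays on asm_moe, other CK_MOE cases use ck_moe_2stages, and everything else (including CUDA) falls through to the Triton fused_experts path. A control-flow skeleton of that dispatch with the argument lists elided:

# Skeleton of the dispatch after this diff; returns kernel names instead of results.
def apply_dispatch(is_hip, use_int4_weight, ck_moe, block_quant):
    if is_hip:
        if use_int4_weight:
            return "ck_moe_2stages_win4"  # INT4 weights, FP8 compute
        if ck_moe:
            if block_quant:
                return "asm_moe"  # FP8 block-quant path (silu only)
            return "ck_moe_2stages"  # FP8 per-tensor-scale path
    return "fused_experts"  # default Triton fused MoE (fp8 w8a8)

assert apply_dispatch(True, True, False, False) == "ck_moe_2stages_win4"
assert apply_dispatch(True, False, True, True) == "asm_moe"
assert apply_dispatch(True, False, True, False) == "ck_moe_2stages"
assert apply_dispatch(False, False, False, False) == "fused_experts"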