sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

@@ -51,6 +51,10 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 is_hip_ = is_hip()
 
+if is_hip_:
+    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter.ops.shuffle import shuffle_weight
+
 logger = logging.getLogger(__name__)
 
 
@@ -533,6 +537,20 @@ class Fp8MoEMethod:
             )
             layer.register_parameter("w13_weight_scale", w13_weight_scale)
             layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            if is_hip_ and get_bool_env_var("CK_MOE"):
+                # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
+                w13_weight_scale1 = torch.nn.Parameter(
+                    torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                w2_weight_scale1 = torch.nn.Parameter(
+                    torch.ones(num_experts, hidden_size, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                layer.register_parameter("w13_weight_scale1", w13_weight_scale1)
+                layer.register_parameter("w2_weight_scale1", w2_weight_scale1)
+
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
         extra_weight_attrs.update(
@@ -602,6 +620,15 @@ class Fp8MoEMethod:
                     w2_weight_scale, requires_grad=False
                 )
                 layer.w2_input_scale = None
+
+                if get_bool_env_var("CK_MOE"):
+                    # Pre-shuffle weights
+                    layer.w13_weight.data = shuffle_weight(
+                        layer.w13_weight.contiguous(), (16, 16)
+                    )
+                    layer.w2_weight.data = shuffle_weight(
+                        layer.w2_weight.contiguous(), (16, 16)
+                    )
             return
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -640,6 +667,9 @@ class Fp8MoEMethod:
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
+                # ROCm (CK_MOE): using column-wise scaling
+                layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+                layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
            elif get_bool_env_var("MOE_PADDING"):
                # If ROCm, apply weight padding (min. Mem channel contention) only if set
                layer.w13_weight = torch.nn.Parameter(
@@ -744,6 +774,9 @@ class Fp8MoEMethod:
                    requires_grad=False,
                )
                torch.cuda.empty_cache()
+                # ROCm (CK_MOE): using column-wise scaling
+                layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+                layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
            elif get_bool_env_var("MOE_PADDING"):
                # If ROCm, apply weight padding (min. Mem channel contention) only if set
                layer.w13_weight = torch.nn.Parameter(
@@ -771,6 +804,8 @@ class Fp8MoEMethod:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -788,33 +823,38 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        if is_hip_ and get_bool_env_var("CK_MOE"):
-            import ater
-            from ater.fused_moe import fused_experts_ck
-
-            assert activation == "silu", f"{activation=} is not supported."
-
-            return fused_experts_ck(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                use_fp8_w8a8=True,
-                w1_scale=(
-                    layer.w13_weight_scale_inv
-                    if self.block_quant
-                    else layer.w13_weight_scale
-                ),
-                w2_scale=(
-                    layer.w2_weight_scale_inv
-                    if self.block_quant
-                    else layer.w2_weight_scale
-                ),
-                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale,
-            )
-
+        if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu":
+            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
+            assert not no_combine, f"{no_combine=} is not supported."
+            if self.block_quant:
+                return asm_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    layer.w13_weight_scale_inv,
+                    layer.w2_weight_scale_inv,
+                    None,
+                    None,
+                    False,
+                    None,
+                    block_shape=tuple(self.quant_config.weight_block_size),
+                    expert_mask=None,
+                )
+            else:
+                return asm_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    layer.w13_weight_scale1,
+                    layer.w2_weight_scale1,
+                    None,
+                    None,
+                    False,
+                )
         else:
             # Expert fusion with FP8 quantization
             return fused_experts(
@@ -823,7 +863,7 @@ class Fp8MoEMethod:
                 layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
-                inplace=True,
+                inplace=inplace and not no_combine,
                 activation=activation,
                 use_fp8_w8a8=True,
                 w1_scale=(
@@ -839,6 +879,7 @@ class Fp8MoEMethod:
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
                 block_shape=self.quant_config.weight_block_size,
+                no_combine=no_combine,
             )
 
 
sglang/srt/layers/quantization/fp8_utils.py

@@ -1,3 +1,4 @@
+import os
 from typing import List, Optional, Tuple
 
 import torch
@@ -7,9 +8,12 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 is_hip_ = is_hip()
+if is_hip_ and get_bool_env_var("CK_MOE"):
+    from aiter import gemm_a8w8_blockscale
+
 _is_cuda = torch.cuda.is_available() and torch.version.cuda
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm
@@ -40,6 +44,8 @@ def normalize_e4m3fn_to_e4m3fnuz(
 
 
 def cutlass_block_fp8_supported() -> bool:
+    if os.environ.get("SUPPORT_CUTLASS_BLOCK_FP8") is None:
+        return False
     if _is_cuda:
         major, minor = torch.cuda.get_device_capability()
         sm_version = major * 10 + minor
@@ -75,6 +81,16 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
+    elif is_hip_ and get_bool_env_var("CK_MOE"):
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=False
+        )
+        output = torch.zeros(
+            [q_input.shape[0], weight.shape[0]],
+            dtype=input.dtype,
+            device=q_input.device,
+        )
+        gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
         q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=False
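Taken together, these hunks change how apply_w8a8_block_fp8_linear picks a backend: the CUTLASS path is now opt-in via the SUPPORT_CUTLASS_BLOCK_FP8 environment variable, and ROCm builds with CK_MOE set route to aiter's gemm_a8w8_blockscale. A rough sketch of the resulting dispatch order, with placeholder helper names that are not sglang APIs:

def pick_w8a8_block_fp8_backend(
    is_cuda: bool, cutlass_ok: bool, is_hip: bool, ck_moe: bool
) -> str:
    # cutlass_ok stands in for cutlass_block_fp8_supported(), which after this
    # change returns False unless SUPPORT_CUTLASS_BLOCK_FP8 is set.
    if is_cuda and cutlass_ok:
        return "sgl_kernel.fp8_blockwise_scaled_mm (CUTLASS)"
    # New branch from this diff: ROCm with CK_MOE uses the aiter block-scale GEMM.
    if is_hip and ck_moe:
        return "aiter.gemm_a8w8_blockscale"
    # Fallback: Triton blockwise matmul.
    return "w8a8_block_fp8_matmul (Triton)"


print(pick_w8a8_block_fp8_backend(is_cuda=False, cutlass_ok=False, is_hip=True, ck_moe=True))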
sglang/srt/layers/quantization/gptq.py (new file)

@@ -0,0 +1,416 @@
+import logging
+from fractions import Fraction
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from vllm.scalar_type import scalar_types
+
+from sglang.srt.layers.linear import LinearBase
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+logger = logging.getLogger(__name__)
+
+
+class GPTQConfig(QuantizationConfig):
+    """Config class for GPTQ.
+
+    Reference: https://arxiv.org/abs/2210.17323
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        desc_act: bool,
+        lm_head_quantized: bool,
+        dynamic: Dict[str, Dict[str, Union[int, bool]]],
+    ) -> None:
+        # GPTQModel use `dynamic` config property to allow per module
+        # quantization config so each module can be individually optimized.
+        # Format is Dict[str, Dict] where key is a regex string that can
+        # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
+        # matching of a module.
+        # Default to positive match, override base quant config mode, if no
+        # prefix is used. Value is in dict format of field key and override
+        # value.
+        # Negative matching will skip quantization init for this module
+        # entirely:
+        # non-quantized inference. More details and quantization examples can be
+        # found at: https://github.com/ModelCloud/GPTQModel
+        # Example:
+        # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
+        # # last 1/4 of the layers 16-21 has 8bit and group_size 64
+        # dynamic = {
+        # #`.*\.` matches the layers_node prefix
+        # # positive match layer 10-15
+        # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+        # # positive match layer 16-21
+        # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+        # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
+        # }
+        super().__init__()
+        self.dynamic = dynamic
+
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.desc_act = desc_act
+        self.lm_head_quantized = lm_head_quantized
+        self.pack_factor = Fraction(32, self.weight_bits)
+        if self.weight_bits not in [2, 3, 4, 8]:
+            raise ValueError(
+                "Currently, only 2/3/4/8-bit weight quantization is "
+                f"supported for GPTQ, but got {self.weight_bits} bits."
+            )
+
+    def __repr__(self) -> str:
+        return (
+            f"GPTQConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, "
+            f"desc_act={self.desc_act}),"
+            f"lm_head_quantized={self.lm_head_quantized}), "
+            f"dynamic={self.dynamic}"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "gptq"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    # Need to figure it out
+    def get_min_capability(cls) -> int:
+        return 60
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
+        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+        dynamic = {} if dynamic is None else dynamic
+
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        desc_act = cls.get_from_keys(config, ["desc_act"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        return cls(weight_bits, group_size, desc_act, lm_head_quantized, dynamic)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["GPTQLinearMethod"]:
+        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+
+        from sglang.srt.layers.quantization import get_linear_quant_method
+
+        return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
+
+
+class GPTQMarlinConfig(QuantizationConfig):
+    """Config class for GPTQ Marlin"""
+
+    # (num_bits, is_sym) -> quant_type
+    TYPE_MAP = {
+        (4, True): scalar_types.uint4b8,
+        (8, True): scalar_types.uint8b128,
+    }
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        desc_act: bool,
+        is_sym: bool,
+        lm_head_quantized: bool,
+        dynamic: Dict[str, Dict[str, Union[int, bool]]],
+        full_config: Dict[str, Any],
+    ) -> None:
+        super().__init__()
+        if desc_act and group_size == -1:
+            # In this case, act_order == True is the same as act_order == False
+            # (since we have only one group per output channel)
+            desc_act = False
+
+        # GPTQModel use `dynamic` config property to allow per module
+        # quantization config so each module can be individually optimized.
+        # Format is Dict[str, Dict] where key is a regex string that can
+        # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
+        # matching of a module.
+        # Default to positive match, override base quant config mode, if no
+        # prefix is used. Value is in dict format of field key and override
+        # value.
+        # Negative matching will skip quantization init for this module
+        # entirely:
+        # non-quantized inference. More details and quantization examples can be
+        # found at: https://github.com/ModelCloud/GPTQModel
+        # Example:
+        # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
+        # # last 1/4 of the layers 16-21 has 8bit and group_size 64
+        # dynamic = {
+        # #`.*\.` matches the layers_node prefix
+        # # positive match layer 10-15
+        # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+        # # positive match layer 16-21
+        # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+        # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
+        # }
+        self.dynamic = dynamic
+
+        self.weight_bits = weight_bits
+        self.is_sym = is_sym
+
+        self.pack_factor = 32 // weight_bits  # packed into int32
+        self.group_size = group_size
+        self.desc_act = desc_act
+        self.lm_head_quantized = lm_head_quantized
+        self.full_config = full_config
+
+        if (weight_bits, is_sym) not in self.TYPE_MAP:
+            raise ValueError(
+                "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
+            )
+
+        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
+
+    def __repr__(self) -> str:
+        return (
+            f"GPTQMarlinConfig(quant_type={self.quant_type}, "
+            f"group_size={self.group_size}, "
+            f"desc_act={self.desc_act}, "
+            f"lm_head_quantized={self.lm_head_quantized}), "
+            f"dynamic={self.dynamic}"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "gptq_marlin"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
+        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+        dynamic = {} if dynamic is None else dynamic
+
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        desc_act = cls.get_from_keys(config, ["desc_act"])
+        is_sym = cls.get_from_keys(config, ["sym"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        return cls(
+            weight_bits,
+            group_size,
+            desc_act,
+            is_sym,
+            lm_head_quantized,
+            dynamic,
+            config,
+        )
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
+
+        is_valid_user_quant = (
+            user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
+        )
+
+        if can_convert and is_valid_user_quant:
+            msg = (
+                "The model is convertible to {} during runtime."
+                " Using {} kernel.".format(cls.get_name(), cls.get_name())
+            )
+            logger.info(msg)
+            return cls.get_name()
+
+        if can_convert and user_quant == "gptq":
+            logger.info(
+                "Detected that the model can run with gptq_marlin"
+                ", however you specified quantization=gptq explicitly,"
+                " so forcing gptq. Use quantization=gptq_marlin for"
+                " faster inference"
+            )
+        return None
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.model_executor.layers.quantization.gptq_marlin import (
+            GPTQMarlinLinearMethod,
+            GPTQMarlinMoEMethod,
+        )
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.quantization import get_linear_quant_method
+
+        if isinstance(layer, FusedMoE):
+            return GPTQMarlinMoEMethod(self)
+        # TODO: re-enable after SGLang syncs with vllm >= 0.7.3
+        # if layer.num_experts > 32:
+        #     # For MoEs with many experts the moe_wna16 kernel is faster
+        #     return MoeWNA16Config.from_config(self.full_config).get_quant_method(
+        #         layer, prefix
+        #     )
+        # else:
+        #     return GPTQMarlinMoEMethod(self)
+        return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
+
+    @classmethod
+    def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
+        quant_method = quant_config.get("quant_method", "").lower()
+        num_bits = quant_config.get("bits")
+        group_size = quant_config.get("group_size")
+        sym = quant_config.get("sym")
+        desc_act = quant_config.get("desc_act")
+
+        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+            check_marlin_supported,
+        )
+        from vllm.platforms import current_platform
+
+        if not current_platform.is_cuda():
+            return False
+
+        if quant_method != "gptq":
+            return False
+
+        # Marlin conversion is only valid if required properties are found
+        if num_bits is None or group_size is None or sym is None or desc_act is None:
+            return False
+
+        if (num_bits, sym) not in cls.TYPE_MAP:
+            return False
+
+        return check_marlin_supported(
+            quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
+        )
+
+
+class MarlinConfig(QuantizationConfig):
+    """Config class for Marlin.
+
+    Reference: https://github.com/IST-DASLab/marlin/tree/master
+    """
+
+    def __init__(
+        self,
+        group_size: int,
+        lm_head_quantized: bool,
+    ) -> None:
+        # Group size for the quantization.
+        self.group_size = group_size
+        self.lm_head_quantized = lm_head_quantized
+        if self.group_size != 128 and self.group_size != -1:
+            raise ValueError(
+                "Currently, only group size 128 and -1 (channelwise) "
+                "is supported for Marlin, but got group_size of "
+                f"{self.group_size}"
+            )
+
+        # 4 Bits packed into 32 bit datatype.
+        self.pack_factor = 32 // 4
+
+        # Tile size used by marlin kernels.
+        self.tile_size = 16
+
+        # Min out_features dim
+        self.min_n_threads = 64
+
+        # Min in_features dim
+        self.min_k_threads = 128
+
+        # Max parallel problems to solve at once (improves large
+        # batch performance)
+        self.max_parallel = 16
+
+        # Permutation length used by the marlin kernels.
+        self.perm_len = 1024
+
+    def __repr__(self) -> str:
+        return (
+            f"MarlinConfig(group_size={self.group_size}, "
+            f"lm_head_quantized={self.lm_head_quantized})"
+        )
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "marlin"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    # Need to figure it out
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
+        group_size = cls.get_from_keys(config, ["group_size"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        return cls(group_size, lm_head_quantized)
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        # compat: autogptq >=0.8.0 use checkpoint_format: str
+        # compat: autogptq <=0.7.1 is_marlin_format: bool
+        is_marlin_format = hf_quant_cfg.get(
+            "checkpoint_format"
+        ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
+
+        is_valid_user_quant = (
+            user_quant is None or user_quant == "gptq" or user_quant == "marlin"
+        )
+
+        if is_marlin_format and is_valid_user_quant:
+            msg = "The model is serialized in {} format. Using {} kernel.".format(
+                cls.get_name(), cls.get_name()
+            )
+            logger.info(msg)
+            return cls.get_name()
+
+        return None
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["MarlinLinearMethod"]:
+        from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+
+        if isinstance(layer, LinearBase) or (
+            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+        ):
+            return MarlinLinearMethod(self)
+        return None
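For reference, a hedged usage sketch of the new GPTQConfig: a quantize_config.json-style payload using the per-module `dynamic` overrides documented in the comment block above. The bit widths and regexes below are example values only, not a recommended configuration.

from sglang.srt.layers.quantization.gptq import GPTQConfig

hf_quant_config = {
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
    "lm_head": False,
    "dynamic": {
        # positive match: run layers 16-21 at 8 bit with group_size 64
        r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64},
        # negative match: skip quantization for all `moe` modules
        r"-:.*\.moe\..*": {},
    },
}

cfg = GPTQConfig.from_config(hf_quant_config)
print(cfg)  # GPTQConfig(weight_bits=4, group_size=128, ...)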