sglang-0.4.4.post4-py3-none-any.whl → sglang-0.4.5.post1-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (134)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/lang/chat_template.py +24 -0
  4. sglang/srt/configs/model_config.py +40 -4
  5. sglang/srt/constrained/base_grammar_backend.py +26 -5
  6. sglang/srt/constrained/llguidance_backend.py +1 -0
  7. sglang/srt/constrained/outlines_backend.py +1 -0
  8. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  9. sglang/srt/constrained/xgrammar_backend.py +1 -0
  10. sglang/srt/conversation.py +29 -4
  11. sglang/srt/disaggregation/base/__init__.py +8 -0
  12. sglang/srt/disaggregation/base/conn.py +113 -0
  13. sglang/srt/disaggregation/decode.py +18 -5
  14. sglang/srt/disaggregation/mini_lb.py +53 -122
  15. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  16. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  17. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  18. sglang/srt/disaggregation/prefill.py +43 -19
  19. sglang/srt/disaggregation/utils.py +31 -0
  20. sglang/srt/entrypoints/EngineBase.py +53 -0
  21. sglang/srt/entrypoints/engine.py +36 -8
  22. sglang/srt/entrypoints/http_server.py +37 -8
  23. sglang/srt/entrypoints/http_server_engine.py +142 -0
  24. sglang/srt/entrypoints/verl_engine.py +37 -10
  25. sglang/srt/hf_transformers_utils.py +4 -0
  26. sglang/srt/layers/attention/flashattention_backend.py +609 -202
  27. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  28. sglang/srt/layers/attention/vision.py +1 -1
  29. sglang/srt/layers/dp_attention.py +2 -4
  30. sglang/srt/layers/elementwise.py +15 -2
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  33. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  49. sglang/srt/layers/moe/router.py +7 -1
  50. sglang/srt/layers/moe/topk.py +37 -16
  51. sglang/srt/layers/quantization/__init__.py +13 -5
  52. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  53. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  54. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  55. sglang/srt/layers/quantization/fp8.py +28 -14
  56. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  57. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  58. sglang/srt/layers/quantization/kv_cache.py +43 -52
  59. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  62. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  63. sglang/srt/layers/radix_attention.py +14 -0
  64. sglang/srt/layers/rotary_embedding.py +75 -1
  65. sglang/srt/managers/io_struct.py +254 -97
  66. sglang/srt/managers/mm_utils.py +3 -2
  67. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  68. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  69. sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
  70. sglang/srt/managers/schedule_batch.py +62 -21
  71. sglang/srt/managers/scheduler.py +71 -14
  72. sglang/srt/managers/tokenizer_manager.py +17 -3
  73. sglang/srt/managers/tp_worker.py +1 -0
  74. sglang/srt/mem_cache/memory_pool.py +14 -1
  75. sglang/srt/metrics/collector.py +9 -0
  76. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  77. sglang/srt/model_executor/forward_batch_info.py +234 -15
  78. sglang/srt/model_executor/model_runner.py +49 -9
  79. sglang/srt/model_loader/loader.py +31 -4
  80. sglang/srt/model_loader/weight_utils.py +4 -2
  81. sglang/srt/models/baichuan.py +2 -0
  82. sglang/srt/models/chatglm.py +1 -0
  83. sglang/srt/models/commandr.py +1 -0
  84. sglang/srt/models/dbrx.py +1 -0
  85. sglang/srt/models/deepseek.py +1 -0
  86. sglang/srt/models/deepseek_v2.py +248 -61
  87. sglang/srt/models/exaone.py +1 -0
  88. sglang/srt/models/gemma.py +1 -0
  89. sglang/srt/models/gemma2.py +1 -0
  90. sglang/srt/models/gemma3_causal.py +1 -0
  91. sglang/srt/models/gpt2.py +1 -0
  92. sglang/srt/models/gpt_bigcode.py +1 -0
  93. sglang/srt/models/granite.py +1 -0
  94. sglang/srt/models/grok.py +1 -0
  95. sglang/srt/models/internlm2.py +1 -0
  96. sglang/srt/models/llama.py +13 -4
  97. sglang/srt/models/llama4.py +487 -0
  98. sglang/srt/models/minicpm.py +1 -0
  99. sglang/srt/models/minicpm3.py +2 -0
  100. sglang/srt/models/mixtral.py +1 -0
  101. sglang/srt/models/mixtral_quant.py +1 -0
  102. sglang/srt/models/mllama.py +51 -8
  103. sglang/srt/models/mllama4.py +227 -0
  104. sglang/srt/models/olmo.py +1 -0
  105. sglang/srt/models/olmo2.py +1 -0
  106. sglang/srt/models/olmoe.py +1 -0
  107. sglang/srt/models/phi3_small.py +1 -0
  108. sglang/srt/models/qwen.py +1 -0
  109. sglang/srt/models/qwen2.py +1 -0
  110. sglang/srt/models/qwen2_5_vl.py +35 -70
  111. sglang/srt/models/qwen2_moe.py +1 -0
  112. sglang/srt/models/qwen2_vl.py +27 -25
  113. sglang/srt/models/stablelm.py +1 -0
  114. sglang/srt/models/xverse.py +1 -0
  115. sglang/srt/models/xverse_moe.py +1 -0
  116. sglang/srt/openai_api/adapter.py +4 -1
  117. sglang/srt/patch_torch.py +11 -0
  118. sglang/srt/server_args.py +34 -0
  119. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  120. sglang/srt/speculative/eagle_utils.py +1 -11
  121. sglang/srt/speculative/eagle_worker.py +6 -2
  122. sglang/srt/utils.py +120 -9
  123. sglang/test/attention/test_flashattn_backend.py +259 -221
  124. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  125. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  126. sglang/test/test_block_fp8.py +57 -0
  127. sglang/test/test_utils.py +19 -8
  128. sglang/version.py +1 -1
  129. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  130. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
  131. sglang/srt/disaggregation/conn.py +0 -81
  132. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  133. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  134. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -77,6 +77,7 @@ class CompressedTensorsConfig(QuantizationConfig):
  sparsity_ignore_list: List[str],
  kv_cache_scheme: Optional[Dict[str, Any]] = None,
  config: Optional[Dict[str, Any]] = None,
+ packed_modules_mapping: Dict[str, List[str]] = {},
  ):
  super().__init__()
  self.ignore = ignore
@@ -87,6 +88,7 @@ class CompressedTensorsConfig(QuantizationConfig):
  self.sparsity_scheme_map = sparsity_scheme_map
  self.sparsity_ignore_list = sparsity_ignore_list
  self.config = config
+ self.packed_modules_mapping = packed_modules_mapping

  def get_linear_method(self) -> "CompressedTensorsLinearMethod":
  return CompressedTensorsLinearMethod(self)
@@ -136,6 +138,7 @@ class CompressedTensorsConfig(QuantizationConfig):
  sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
  config=config
  )
+ packed_modules_mapping = config.get("packed_modules_mapping", {})

  return cls(
  target_scheme_map=target_scheme_map,
@@ -144,6 +147,7 @@ class CompressedTensorsConfig(QuantizationConfig):
  sparsity_scheme_map=sparsity_scheme_map,
  sparsity_ignore_list=sparsity_ignore_list,
  config=config,
+ packed_modules_mapping=packed_modules_mapping,
  )

  @classmethod
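
For context, a minimal sketch of how the new packed_modules_mapping field flows through the config. Only the config.get("packed_modules_mapping", {}) lookup is taken from the hunk above; the mapping contents and module names below are hypothetical.

    # Hypothetical compressed-tensors config snippet; the module names are
    # illustrative and not taken from this diff.
    hf_quant_config = {
        "packed_modules_mapping": {
            "qkv_proj": ["q_proj", "k_proj", "v_proj"],
            "gate_up_proj": ["gate_proj", "up_proj"],
        },
        # ... remaining compressed-tensors fields ...
    }

    # Mirrors the new from_config() behavior: a missing key falls back to {}.
    packed_modules_mapping = hf_quant_config.get("packed_modules_mapping", {})
    print(packed_modules_mapping["qkv_proj"])  # ['q_proj', 'k_proj', 'v_proj']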

sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -103,16 +103,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  "input_activations"
  )

- if not (
- self.weight_quant.strategy == QuantizationStrategy.TENSOR
- and self.input_quant.strategy == QuantizationStrategy.TENSOR
- ):
- raise ValueError(
- "For FP8 Fused MoE layers, only per-tensor scales "
- "for weights and activations are supported. Found "
- f"{self.weight_quant}, {self.input_quant}"
- )
-
  self.static_input_scales = not self.input_quant.dynamic

  def create_weights(
@@ -154,27 +144,50 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  set_weight_attrs(w2_weight, extra_weight_attrs)

  # WEIGHT_SCALES
- # Allocate 2 scales for w1 and w3 respectively.
- # They will be combined to a single scale after weight loading.
- w13_weight_scale = torch.nn.Parameter(
- torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
- )
- layer.register_parameter("w13_weight_scale", w13_weight_scale)
+ # per-tensor quantization
+ if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+ # Allocate 2 scales for w1 and w3 respectively.
+ # They will be combined to a single scale after weight loading.
+ w13_weight_scale = torch.nn.Parameter(
+ torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+ )
+ w2_weight_scale = torch.nn.Parameter(
+ torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+ )
+ weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value
+ elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+ w13_weight_scale = torch.nn.Parameter(
+ torch.ones(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ 1,
+ dtype=torch.float32,
+ ),
+ requires_grad=False,
+ )
+ w2_weight_scale = torch.nn.Parameter(
+ torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+ requires_grad=False,
+ )
+ weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
+ else:
+ raise ValueError(
+ f"Unsupported weight quantization strategy: {self.weight_quant.strategy}"
+ )

- w2_weight_scale = torch.nn.Parameter(
- torch.ones(num_experts, dtype=torch.float32), requires_grad=False
- )
+ layer.register_parameter("w13_weight_scale", w13_weight_scale)
  layer.register_parameter("w2_weight_scale", w2_weight_scale)
  # Add the quantization method used (per tensor/grouped/channel)
  # to ensure the weight scales are loaded in properly
- extra_weight_attrs.update(
- {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
- )
+ extra_weight_attrs.update({"quant_method": weight_quant_method})
  set_weight_attrs(w13_weight_scale, extra_weight_attrs)
  set_weight_attrs(w2_weight_scale, extra_weight_attrs)

  # INPUT_SCALES
  if self.static_input_scales:
+ assert (
+ self.input_quant.strategy == QuantizationStrategy.TENSOR
+ ), "Only per-tensor quantization is supported for static input scales"
  w13_input_scale = torch.nn.Parameter(
  torch.ones(num_experts, dtype=torch.float32), requires_grad=False
  )
@@ -241,31 +254,37 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  layer.w2_input_scale = torch.nn.Parameter(
  w2_input_scale, requires_grad=False
  )
-
- # Fp8 moe kernel needs single weight scale for w13 per expert.
- # We take the max then dequant and requant each expert.
- assert layer.w13_weight_scale is not None
- shard_size = layer.intermediate_size_per_partition
- max_w13_scales = layer.w13_weight_scale.max(dim=1).values
- for expert_id in range(layer.local_num_experts):
- start = 0
- for shard_id in range(2):
- dq_weight = per_tensor_dequantize(
- layer.w13_weight[expert_id][start : start + shard_size, :],
- layer.w13_weight_scale[expert_id][shard_id],
- )
-
- if _is_cuda:
- layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
- sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
- )
- else:
- layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
- vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+ if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+ # Fp8 moe kernel needs single weight scale for w13 per expert.
+ # We take the max then dequant and requant each expert.
+ assert layer.w13_weight_scale is not None
+ shard_size = layer.intermediate_size_per_partition
+ max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+ for expert_id in range(layer.local_num_experts):
+ start = 0
+ for shard_id in range(2):
+ dq_weight = per_tensor_dequantize(
+ layer.w13_weight[expert_id][start : start + shard_size, :],
+ layer.w13_weight_scale[expert_id][shard_id],
  )
- start += shard_size

- layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+ if _is_cuda:
+ (
+ layer.w13_weight[expert_id][start : start + shard_size, :],
+ _,
+ ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+ else:
+ (
+ layer.w13_weight[expert_id][start : start + shard_size, :],
+ _,
+ ) = vllm_ops.scaled_fp8_quant(
+ dq_weight, max_w13_scales[expert_id]
+ )
+ start += shard_size
+
+ layer.w13_weight_scale = torch.nn.Parameter(
+ max_w13_scales, requires_grad=False
+ )

  def apply(
  self,
@@ -285,6 +304,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  activation: str = "silu",
  inplace: bool = True,
  no_combine: bool = False,
+ apply_router_weight_on_input: bool = False,
  ) -> torch.Tensor:
  from sglang.srt.layers.moe.fused_moe_triton import fused_experts
  from sglang.srt.layers.moe.topk import select_experts
@@ -310,10 +330,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  inplace=inplace,
  activation=activation,
  use_fp8_w8a8=True,
+ per_channel_quant=self.weight_quant.strategy
+ == QuantizationStrategy.CHANNEL,
  w1_scale=layer.w13_weight_scale,
  w2_scale=layer.w2_weight_scale,
  a1_scale=layer.w13_input_scale,
  a2_scale=layer.w2_input_scale,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  )

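The net effect of the CompressedTensorsW8A8Fp8MoEMethod changes above is that per-channel FP8 weight scales are now accepted alongside per-tensor ones. A small shape sketch, with made-up sizes and names following the hunks above:

    import torch

    num_experts, intermediate_size_per_partition, hidden_size = 8, 1024, 2048

    # Per-tensor strategy: w13 temporarily keeps 2 scales (w1 and w3) per
    # expert, merged into one after weight loading; w2 keeps 1 per expert.
    w13_scale_tensor = torch.ones(num_experts, 2, dtype=torch.float32)
    w2_scale_tensor = torch.ones(num_experts, dtype=torch.float32)

    # Per-channel strategy: one scale per output row of each expert weight.
    w13_scale_channel = torch.ones(
        num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
    )
    w2_scale_channel = torch.ones(num_experts, hidden_size, 1, dtype=torch.float32)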
 

sglang/srt/layers/quantization/fp8.py

@@ -71,7 +71,8 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
  _is_hip = is_hip()

  if _is_hip:
- from aiter.fused_moe_bf16_asm import asm_moe
+ from aiter import ActivationType
+ from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
  from aiter.ops.shuffle import shuffle_weight

  _is_cuda = is_cuda()
@@ -487,7 +488,7 @@ class Fp8MoEMethod:

  if self.quant_config.is_checkpoint_fp8_serialized:
  params_dtype = (
- torch.int32
+ torch.uint32
  if get_bool_env_var("USE_INT4_WEIGHT")
  else torch.float8_e4m3fn
  )
@@ -822,12 +823,14 @@ class Fp8MoEMethod:
  # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
  # Weight Permutation
  layer.w13_weight = torch.nn.Parameter(
- permute_weight(layer.w13_weight.data),
+ # permute_weight(layer.w13_weight.data),
+ shuffle_weight(layer.w13_weight.data, (16, 16)),
  requires_grad=False,
  )
  torch.cuda.empty_cache()
  layer.w2_weight = torch.nn.Parameter(
- permute_weight(layer.w2_weight.data),
+ # permute_weight(layer.w2_weight.data),
+ shuffle_weight(layer.w2_weight.data, (16, 16)),
  requires_grad=False,
  )
  torch.cuda.empty_cache()
@@ -860,19 +863,21 @@ class Fp8MoEMethod:
  layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
  layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]

- def process_weights_hip_scale_padding(self, layer: Module, padding_size: int):
+ def process_weights_hip_scale_padding(self, layer: Module):
  from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
  padding_size, # Avoid circular import
  )

  if get_bool_env_var("CK_MOE"):
  layer.w13_weight = torch.nn.Parameter(
- permute_weight(layer.w13_weight.data),
+ # permute_weight(layer.w13_weight.data),
+ shuffle_weight(layer.w13_weight.data, (16, 16)),
  requires_grad=False,
  )
  torch.cuda.empty_cache()
  layer.w2_weight = torch.nn.Parameter(
- permute_weight(layer.w2_weight.data),
+ # permute_weight(layer.w2_weight.data),
+ shuffle_weight(layer.w2_weight.data, (16, 16)),
  requires_grad=False,
  )
  torch.cuda.empty_cache()
@@ -905,6 +910,7 @@ class Fp8MoEMethod:
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -927,7 +933,7 @@ class Fp8MoEMethod:
  if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
  # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
  assert not no_combine, f"{no_combine=} is not supported."
- return asm_moe(
+ return ck_moe_2stages_win4(
  x,
  layer.w13_weight,
  layer.w2_weight,
@@ -935,15 +941,17 @@ class Fp8MoEMethod:
  topk_ids,
  layer.w13_weight_scale1,
  layer.w2_weight_scale1,
- activation=activation,
+ activation=(
+ ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+ ),
  )
  if _is_hip and get_bool_env_var("CK_MOE"):
- # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
- assert (
- activation == "silu"
- ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
  assert not no_combine, f"{no_combine=} is not supported."
  if self.block_quant:
+ # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+ assert (
+ activation == "silu"
+ ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
  return asm_moe(
  x,
  layer.w13_weight,
@@ -956,7 +964,7 @@ class Fp8MoEMethod:
  expert_mask=None,
  )
  else:
- return asm_moe(
+ return ck_moe_2stages(
  x,
  layer.w13_weight,
  layer.w2_weight,
@@ -964,6 +972,11 @@ class Fp8MoEMethod:
  topk_ids,
  layer.w13_weight_scale1,
  layer.w2_weight_scale1,
+ activation=(
+ ActivationType.Silu
+ if activation == "silu"
+ else ActivationType.Gelu
+ ),
  )
  else:
  # Expert fusion with FP8 quantization
@@ -975,6 +988,7 @@ class Fp8MoEMethod:
  topk_ids=topk_ids,
  inplace=inplace and not no_combine,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  use_fp8_w8a8=True,
  w1_scale=(
  layer.w13_weight_scale_inv
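
Taken together, the fp8.py hunks reroute the ROCm (aiter) MoE paths. A rough dispatch sketch of the resulting behavior; the kernel names come from the diff, while the wrapper function and the simplified env-var parsing are assumptions:

    import os

    def _env_true(name: str) -> bool:
        # Simplified stand-in for sglang's get_bool_env_var()
        return os.environ.get(name, "").lower() in ("1", "true")

    def pick_hip_moe_kernel(block_quant: bool) -> str:
        if _env_true("USE_INT4_WEIGHT"):
            return "ck_moe_2stages_win4"   # INT4 weights, FP8 compute
        if _env_true("CK_MOE"):
            # block-quant FP8 stays on asm_moe (silu only for now);
            # plain FP8 now goes through ck_moe_2stages
            return "asm_moe" if block_quant else "ck_moe_2stages"
        return "fused_experts"             # default Triton path

    print(pick_hip_moe_kernel(block_quant=False))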

sglang/srt/layers/quantization/fp8_kernel.py

@@ -16,6 +16,7 @@ import functools
  import json
  import logging
  import os
+ from contextlib import contextmanager
  from typing import Any, Dict, List, Optional, Tuple

  import torch
@@ -40,11 +41,13 @@ fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

  _is_cuda = is_cuda()
  if _is_cuda:
- import deep_gemm # `pip install "sgl-kernel>=0.0.4.post3"`
+ import deep_gemm
  from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8

  sm_version = get_device_sm()
- if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
+ if sm_version == 90 and get_bool_env_var(
+ "SGL_ENABLE_JIT_DEEPGEMM", default="false"
+ ):
  _enable_jit_deepgemm = True


@@ -59,7 +62,10 @@ if supports_custom_op():
  Bs: torch.Tensor,
  C: torch.Tensor,
  ) -> None:
- deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+ M, K = A.shape
+ N, _ = B.shape
+ with _log_jit_build(M, N, K):
+ deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)

  def deep_gemm_fp8_fp8_bf16_nt_fake(
  A: torch.Tensor,
@@ -708,6 +714,25 @@ def get_w8a8_block_fp8_configs(
  return None


+ @contextmanager
+ def _log_jit_build(M: int, N: int, K: int):
+ from deep_gemm.jit.runtime import RuntimeCache
+
+ origin_func = RuntimeCache.__getitem__
+
+ def __patched_func(self, *args, **kwargs):
+ ret = origin_func(self, *args, **kwargs)
+ if ret is None:
+ logger.warning(
+ f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
+ )
+ return ret
+
+ RuntimeCache.__getitem__ = __patched_func
+ yield
+ RuntimeCache.__getitem__ = origin_func
+
+
  def w8a8_block_fp8_matmul(
  A: torch.Tensor,
  B: torch.Tensor,
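
The _log_jit_build helper above temporarily monkey-patches DeepGEMM's RuntimeCache.__getitem__ so that a cache miss (which triggers a slow JIT build) emits a warning. A standalone sketch of the same pattern against a stand-in cache class so it runs without DeepGEMM; the try/finally restore is an added safety measure, not in the original:

    import logging
    from contextlib import contextmanager

    logger = logging.getLogger("jit_build_demo")

    class FakeRuntimeCache:
        _store: dict = {}
        def __getitem__(self, key):
            return self._store.get(key)  # None means "not built yet"

    @contextmanager
    def log_cache_miss(cache_cls, message: str):
        original = cache_cls.__getitem__
        def patched(self, *args, **kwargs):
            ret = original(self, *args, **kwargs)
            if ret is None:
                logger.warning(message)
            return ret
        cache_cls.__getitem__ = patched
        try:
            yield
        finally:
            cache_cls.__getitem__ = original

    with log_cache_miss(FakeRuntimeCache, "JIT build pending, please wait."):
        FakeRuntimeCache()[("gemm_fp8_fp8_bf16_nt", 128, 256, 512)]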

@@ -782,7 +807,8 @@ w8a8_block_fp8_matmul(
  if supports_custom_op():
  torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
  else:
- deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+ with _log_jit_build(M, N, K):
+ deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
  else:
  kernel = (
  _w8a8_block_fp8_matmul_unrolledx4
@@ -815,3 +841,103 @@ def w8a8_block_fp8_matmul(
  )

  return C
+
+
+ @triton.jit
+ def _per_tensor_quant_mla_fp8_stage1(
+ x_ptr,
+ x_s_ptr,
+ head_size,
+ x_stride_h,
+ x_stride_s,
+ eps,
+ fp8_max,
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ seq_id = tl.program_id(0)
+ head_id = tl.program_id(1)
+ offset = tl.arange(0, BLOCK_SIZE)
+ mask = offset < head_size
+
+ x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+ x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+ _absmax = tl.maximum(tl.max(tl.abs(x)), eps)
+
+ tl.atomic_max(x_s_ptr, _absmax / fp8_max)
+
+
+ @triton.jit
+ def _per_tensor_quant_mla_fp8_stage2(
+ x_ptr,
+ x_s_ptr,
+ x_q_ptr,
+ num_seq,
+ head_size,
+ x_stride_h,
+ x_stride_s,
+ fp8_min,
+ fp8_max,
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ seq_id = tl.program_id(0)
+ head_id = tl.program_id(1)
+ offset = tl.arange(0, BLOCK_SIZE)
+ mask = offset < head_size
+
+ x_s = tl.load(x_s_ptr)
+ x_s_inv = 1.0 / x_s
+
+ x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+ x_q_ptr += head_id * num_seq * head_size + seq_id * head_size
+
+ x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+ x_q = tl.clamp(x * x_s_inv, fp8_min, fp8_max).to(x_q_ptr.dtype.element_ty)
+ tl.store(x_q_ptr + offset, x_q, mask=mask)
+
+
+ def per_tensor_quant_mla_fp8(
+ x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ This function quantizes input values to float8 values with tensor-wise quantization
+ and specialized for mla absorbed case.
+ """
+ assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+ finfo = torch.finfo(dtype)
+ fp8_max = finfo.max
+ if _is_hip:
+ dtype = torch.float8_e4m3fnuz
+ fp8_max = 224.0
+
+ x_q = x.new_empty(x.size(), dtype=dtype)
+ x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
+
+ num_head, num_seq, head_size = x.shape
+ BLOCK_SIZE = triton.next_power_of_2(head_size)
+ grid = (num_seq, num_head)
+
+ _per_tensor_quant_mla_fp8_stage1[grid](
+ x,
+ x_s,
+ head_size,
+ x.stride(0),
+ x.stride(1),
+ eps,
+ fp8_max,
+ BLOCK_SIZE,
+ )
+ _per_tensor_quant_mla_fp8_stage2[grid](
+ x,
+ x_s,
+ x_q,
+ num_seq,
+ head_size,
+ x.stride(0),
+ x.stride(1),
+ -fp8_max,
+ fp8_max,
+ BLOCK_SIZE,
+ )
+
+ return x_q, x_s
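
A possible call pattern for the new per_tensor_quant_mla_fp8 helper, assuming the module path from the file list above and a CUDA device with Triton available; the (num_heads, seq_len, head_size) layout follows the kernel's grid and strides, and the concrete sizes are illustrative:

    import torch
    from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8

    # 3-D input: (num_heads, num_tokens, head_size), as the kernel grid expects.
    x = torch.randn(16, 128, 512, device="cuda", dtype=torch.bfloat16)

    x_q, x_s = per_tensor_quant_mla_fp8(x)
    print(x_q.shape, x_q.dtype)  # same shape as x, float8_e4m3fn
    print(x_s.shape)             # one tensor-wide scale, shape (1,)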

sglang/srt/layers/quantization/fp8_utils.py

@@ -168,12 +168,13 @@ def input_to_float8(
  """This function quantizes input values to float8 values with tensor-wise quantization."""
  finfo = torch.finfo(dtype)
  min_val, max_val = x.aminmax()
- amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+ amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
  fp8_max = finfo.max
  if _is_hip:
+ dtype = torch.float8_e4m3fnuz
  fp8_max = 224.0
  scale = fp8_max / amax
- x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
+ x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
  return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


@@ -212,7 +213,24 @@ block_quant_to_tensor_quant(
  for j in range(n_tiles):
  x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]

- x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+ x_q_tensor, scale = (
+ sgl_scaled_fp8_quant(x_dq_block)
+ if _is_cuda
+ else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+ )
+ return x_q_tensor, scale
+
+
+ def channel_quant_to_tensor_quant(
+ x_q_channel: torch.Tensor,
+ x_s: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ x_dq_channel = x_q_channel.to(torch.float32) * x_s
+ x_q_tensor, scale = (
+ sgl_scaled_fp8_quant(x_dq_channel)
+ if _is_cuda
+ else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
+ )
  return x_q_tensor, scale


@@ -242,9 +260,19 @@ def apply_fp8_linear(
  if _is_cuda:
  qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
  else:
- qinput, x_scale = per_token_group_quant_fp8(
- input_2d, group_size=input_2d.shape[1]
- )
+ # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
+ # final solution should be: 1. add support to per-tensor activation scaling.
+ # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
+ if _is_hip and weight_scale.numel() == 1:
+ qinput, x_scale = ops.scaled_fp8_quant(
+ input_2d,
+ input_scale,
+ use_per_token_if_dynamic=use_per_token_if_dynamic,
+ )
+ else:
+ qinput, x_scale = per_token_group_quant_fp8(
+ input_2d, group_size=input_2d.shape[1]
+ )

  if cutlass_fp8_supported:
  try:
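
Conceptually, the new channel_quant_to_tensor_quant dequantizes per-output-channel FP8 weights and requantizes them with a single tensor-wide scale. A pure-PyTorch sketch of that round trip, mirroring the non-CUDA fallback path; the shapes are made up and 448 is simply the e4m3fn maximum:

    import torch

    out_features, in_features = 4, 8
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max  # 448.0 for e4m3fn

    # Per-channel quantized weight and its per-row scales.
    w_q = torch.randn(out_features, in_features).to(fp8)
    w_s = torch.rand(out_features, 1) + 0.5

    # Dequantize with per-channel scales, then requantize with one scale.
    w_dq = w_q.to(torch.float32) * w_s
    amax = w_dq.abs().amax().clamp(min=1e-12)
    tensor_scale = amax / fp8_max
    w_q_tensor = (w_dq / tensor_scale).clamp(-fp8_max, fp8_max).to(fp8)

    print(w_q_tensor.dtype, tensor_scale.item())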