sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +4 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  16. sglang/srt/function_call/ebnf_composer.py +10 -3
  17. sglang/srt/function_call/function_call_parser.py +2 -0
  18. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  19. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  20. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  21. sglang/srt/layers/attention/vision.py +56 -8
  22. sglang/srt/layers/layernorm.py +26 -1
  23. sglang/srt/layers/logits_processor.py +14 -3
  24. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  27. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  28. sglang/srt/layers/moe/topk.py +84 -22
  29. sglang/srt/layers/multimodal.py +11 -8
  30. sglang/srt/layers/quantization/fp8.py +25 -247
  31. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  32. sglang/srt/layers/quantization/modelopt_quant.py +25 -10
  33. sglang/srt/layers/quantization/unquant.py +24 -76
  34. sglang/srt/layers/quantization/w4afp8.py +68 -17
  35. sglang/srt/lora/lora_registry.py +93 -29
  36. sglang/srt/managers/cache_controller.py +9 -7
  37. sglang/srt/managers/mm_utils.py +154 -35
  38. sglang/srt/managers/multimodal_processor.py +3 -14
  39. sglang/srt/managers/schedule_batch.py +14 -8
  40. sglang/srt/managers/scheduler.py +35 -1
  41. sglang/srt/managers/tokenizer_manager.py +37 -6
  42. sglang/srt/managers/tp_worker.py +3 -0
  43. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  44. sglang/srt/model_executor/model_runner.py +68 -14
  45. sglang/srt/models/deepseek_v2.py +62 -28
  46. sglang/srt/models/glm4_moe.py +1035 -0
  47. sglang/srt/models/glm4_moe_nextn.py +167 -0
  48. sglang/srt/models/interns1.py +328 -0
  49. sglang/srt/models/internvl.py +143 -47
  50. sglang/srt/models/llava.py +9 -5
  51. sglang/srt/models/minicpmo.py +4 -1
  52. sglang/srt/models/qwen2_moe.py +2 -2
  53. sglang/srt/models/qwen3_moe.py +5 -2
  54. sglang/srt/multimodal/processors/base_processor.py +20 -6
  55. sglang/srt/multimodal/processors/clip.py +2 -2
  56. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  57. sglang/srt/multimodal/processors/gemma3.py +2 -2
  58. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  59. sglang/srt/multimodal/processors/internvl.py +21 -8
  60. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  61. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  62. sglang/srt/multimodal/processors/llava.py +4 -4
  63. sglang/srt/multimodal/processors/minicpm.py +2 -3
  64. sglang/srt/multimodal/processors/mlama.py +2 -2
  65. sglang/srt/multimodal/processors/mllama4.py +18 -111
  66. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  67. sglang/srt/multimodal/processors/pixtral.py +2 -2
  68. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  69. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  70. sglang/srt/multimodal/processors/vila.py +3 -1
  71. sglang/srt/reasoning_parser.py +2 -1
  72. sglang/srt/server_args.py +57 -6
  73. sglang/srt/utils.py +96 -1
  74. sglang/srt/weight_sync/utils.py +119 -0
  75. sglang/test/runners.py +4 -0
  76. sglang/test/test_utils.py +65 -5
  77. sglang/utils.py +19 -0
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +4 -4
  80. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +83 -73
  81. sglang/srt/debug_utils.py +0 -74
  82. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

@@ -172,6 +172,7 @@ class Fp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
@@ -180,6 +181,8 @@ class Fp8Config(QuantizationConfig):
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
+        elif isinstance(layer, EPMoE):
+            return Fp8EPMoEMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -791,11 +794,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             # merged w13 weights and generate a single scaling factor.
             layer.w13_weight_scale = torch.nn.Parameter(
                 torch.ones(
-                    layer.num_experts, dtype=torch.float32, device=w13_weight.device
+                    layer.num_local_experts,
+                    dtype=torch.float32,
+                    device=w13_weight.device,
                 ),
                 requires_grad=False,
             )
-            for expert in range(layer.num_experts):
+            for expert in range(layer.num_local_experts):
                 w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                     scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                 )
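Note: the `num_experts` -> `num_local_experts` renames in this file reflect expert parallelism, where each rank only materializes its local slice of the expert weights (typically the global expert count divided by the EP group size). The following is a minimal, illustrative PyTorch sketch of per-local-expert FP8 requantization; the names and layout are assumptions for illustration, not sglang's implementation.

import torch

def quantize_local_experts(w13_weight: torch.Tensor, fp8_max: float = 448.0):
    # w13_weight: [num_local_experts, 2 * intermediate_size, hidden_size] held on
    # this EP rank only, i.e. the *local* expert count, not the global one.
    num_local_experts = w13_weight.shape[0]
    scales = torch.empty(num_local_experts, dtype=torch.float32)
    quantized = torch.empty_like(w13_weight, dtype=torch.float8_e4m3fn)
    for expert in range(num_local_experts):
        # One per-expert scale: absmax / fp8_max (clamped away from zero).
        scale = (w13_weight[expert].abs().max().float() / fp8_max).clamp_min(1e-12)
        scales[expert] = scale
        quantized[expert] = (
            (w13_weight[expert].float() / scale)
            .clamp(-fp8_max, fp8_max)
            .to(torch.float8_e4m3fn)
        )
    return quantized, scales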
@@ -871,7 +876,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
+            for expert_id in range(layer.num_local_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(
@@ -914,7 +919,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
+            for expert_id in range(layer.num_local_experts):
                 start = 0
                 max_w13_scale_fp8 = max_w13_scales[expert_id]
                 for shard_id in range(2):
@@ -931,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
             # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
            # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
-            for expert_id in range(layer.num_experts):
+            for expert_id in range(layer.num_local_experts):
                 layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
                 layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
 
@@ -979,8 +984,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
+        if isinstance(layer, EPMoE):
+            layer.w13_weight_scale = (
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            )
+            layer.w2_weight_scale = (
+                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+            )
+            return layer.run_moe(
+                hidden_states=x,
+                topk_output=topk_output,
+            )
+
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
@@ -1138,248 +1158,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         return None
 
 
-class Fp8EPMoEMethod(Fp8MoEMethod):
-    """MoE method for FP8.
-    Supports loading FP8 checkpoints with static weight scale and
-    dynamic/static activation scale.
-
-    Args:
-        quant_config: The quantization config.
-    """
-
-    def __init__(self, quant_config: Fp8Config):
-        self.quant_config = quant_config
-        self.block_quant = self.quant_config.weight_block_size is not None
-
-    def create_weights(
-        self,
-        layer: Module,
-        num_experts_per_partition: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        tp_size = get_tensor_model_parallel_world_size()
-        if self.block_quant:
-            block_n, block_k = (
-                self.quant_config.weight_block_size[0],
-                self.quant_config.weight_block_size[1],
-            )
-            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by column parallel or enabling merged weights
-            if intermediate_size % block_n != 0:
-                raise ValueError(
-                    f"The output_size of gate's and up's weight = "
-                    f"{intermediate_size} is not divisible by "
-                    f"weight quantization block_n = {block_n}."
-                )
-            if tp_size > 1:
-                # Required by row parallel
-                if intermediate_size % block_k != 0:
-                    raise ValueError(
-                        f"The input_size of down's weight = "
-                        f"{intermediate_size} is not divisible by "
-                        f"weight quantization block_k = {block_k}."
-                    )
-
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                2 * intermediate_size,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                hidden_size,
-                intermediate_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # WEIGHT_SCALES
-        if self.block_quant:
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts_per_partition,
-                    2 * ((intermediate_size + block_n - 1) // block_n),
-                    (hidden_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts_per_partition,
-                    (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
-            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
-            assert self.quant_config.activation_scheme == "dynamic"
-        else:
-            # WEIGHT_SCALES
-            # Allocate 2 scales for w1 and w3 respectively.
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
-            if self.block_quant
-            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
-        # If loading fp8 checkpoint, pass the weight loaders.
-        # If loading an fp16 checkpoint, do not (we will quantize in
-        # process_weights_after_loading()
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-        # INPUT_SCALES
-        if self.quant_config.activation_scheme == "static":
-            if not self.quant_config.is_checkpoint_fp8_serialized:
-                raise ValueError(
-                    "Found static activation scheme for checkpoint that "
-                    "was not serialized fp8."
-                )
-
-            w13_input_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_input_scale", w13_input_scale)
-            set_weight_attrs(w13_input_scale, extra_weight_attrs)
-
-            w2_input_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_input_scale", w2_input_scale)
-            set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
-        else:
-            layer.w13_input_scale = None
-            layer.w2_input_scale = None
-
-    def process_weights_after_loading(self, layer: Module) -> None:
-
-        # If checkpoint is fp16, quantize in place.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
-            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
-
-            layer.w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    layer.num_experts_per_partition,
-                    dtype=torch.float32,
-                    device=w13_weight.device,
-                ),
-                requires_grad=False,
-            )
-
-            for expert in range(layer.num_experts_per_partition):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-            return
-
-        # If checkpoint is fp8, we need to handle that the
-        # MoE kernels require single activation scale and single weight
-        # scale for w13 per expert.
-        else:
-            if self.quant_config.activation_scheme == "static":
-                if layer.w13_input_scale is None or layer.w2_input_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None."
-                    )
-                layer.w13_weight_scale = torch.nn.Parameter(
-                    torch.max(layer.w13_weight_scale, dim=1).values,
-                    requires_grad=False,
-                )
-            if self.block_quant:
-                # If ROCm, normalize the weights and scales to e4m3fnuz
-                if _is_fp8_fnuz:
-                    # activation_scheme: dynamic
-                    w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=layer.w13_weight,
-                        weight_scale=layer.w13_weight_scale_inv,
-                        input_scale=None,
-                    )
-                    w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=layer.w2_weight,
-                        weight_scale=layer.w2_weight_scale_inv,
-                        input_scale=None,
-                    )
-                    # Reset the parameter
-                    layer.w13_weight = torch.nn.Parameter(
-                        w13_weight, requires_grad=False
-                    )
-                    layer.w13_weight_scale_inv = torch.nn.Parameter(
-                        w13_weight_scale, requires_grad=False
-                    )
-                    layer.w13_input_scale = None
-                    layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-                    layer.w2_weight_scale_inv = torch.nn.Parameter(
-                        w2_weight_scale, requires_grad=False
-                    )
-                    layer.w2_input_scale = None
-                if _use_aiter:
-                    layer.w13_weight = torch.nn.Parameter(
-                        shuffle_weight(layer.w13_weight.data, (16, 16)),
-                        requires_grad=False,
-                    )
-                    layer.w2_weight = torch.nn.Parameter(
-                        shuffle_weight(layer.w2_weight.data, (16, 16)),
-                        requires_grad=False,
-                    )
-                return
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        hidden_states: torch.Tensor,
-        topk_output: TopKOutput,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-
 class Fp8KVCacheMethod(BaseKVCacheMethod):
     """
     Supports loading kv-cache scaling factors from FP8 checkpoints.
sglang/srt/layers/quantization/fp8_kernel.py

@@ -173,6 +173,7 @@ def _per_token_group_quant_fp8_colmajor(
     fp8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
+    SCALE_UE8M0: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group
     quantization on a tensor.
@@ -197,6 +198,8 @@ def _per_token_group_quant_fp8_colmajor(
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
     y_s = _absmax / fp8_max
+    if SCALE_UE8M0:
+        y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
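Note: the new SCALE_UE8M0 path rounds each group scale up to the next power of two so it is exactly representable in a UE8M0-style (exponent-only) scale format; because the scale only grows, the quantized values still fit in the FP8 range. A standalone PyTorch sketch of the same rounding, for illustration only:

import torch

def round_scale_up_to_pow2(y_s: torch.Tensor) -> torch.Tensor:
    # Mirrors y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s)))) from the Triton kernel:
    # round each positive scale up to the nearest power of two.
    return torch.exp2(torch.ceil(torch.log2(y_s.abs())))

scales = torch.tensor([0.0123, 0.5, 0.7, 3.1])
print(round_scale_up_to_pow2(scales))  # tensor([0.0156, 0.5000, 1.0000, 4.0000])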
@@ -209,6 +212,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
 
@@ -229,29 +233,17 @@ def per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    x_s = create_per_token_group_quant_fp8_output_scale(
+        x_shape=x.shape,
+        device=x.device,
+        group_size=group_size,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=False,
+    )
+
     M = x.numel() // group_size
     N = group_size
-    if column_major_scales:
-        if scale_tma_aligned:
-            # aligned to 4 * sizeof(float)
-            aligned_size = (x.shape[-2] + 3) // 4 * 4
-            x_s = torch.empty(
-                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)[: x.shape[-2], :]
-        else:
-            x_s = torch.empty(
-                (x.shape[-1] // group_size,) + x.shape[:-1],
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)
-    else:
-        x_s = torch.empty(
-            x.shape[:-1] + (x.shape[-1] // group_size,),
-            device=x.device,
-            dtype=torch.float32,
-        )
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
@@ -271,8 +263,10 @@ def per_token_group_quant_fp8(
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
+            SCALE_UE8M0=scale_ue8m0,
         )
     else:
+        assert not scale_ue8m0
        _per_token_group_quant_fp8[(M,)](
             x,
             x_q,
@@ -287,57 +281,93 @@ def per_token_group_quant_fp8(
             num_stages=num_stages,
         )
 
+    if scale_ue8m0:
+        from deep_gemm.utils.layout import transform_sf_into_required_layout
+
+        assert group_size == 128
+        x_s = transform_sf_into_required_layout(
+            x_s,
+            num_groups=None,
+            mn=x_q.shape[0],
+            k=x_q.shape[1],
+            recipe=(1, group_size, group_size),
+            is_sfa=True,
+        )
+
     return x_q, x_s
 
 
-def sglang_per_token_group_quant_fp8(
-    x: torch.Tensor,
-    group_size: int,
-    eps: float = 1e-10,
-    column_major_scales: bool = False,
-    scale_tma_aligned: bool = False,
-    scale_ue8m0: bool = False,
+def create_per_token_group_quant_fp8_output_scale(
+    x_shape,
+    device,
+    group_size,
+    column_major_scales: bool,
+    scale_tma_aligned: bool,
+    scale_ue8m0: bool,
 ):
-    assert (
-        x.shape[-1] % group_size == 0
-    ), "the last dimension of `x` cannot be divisible by `group_size`"
-    assert x.is_contiguous(), "`x` is not contiguous"
-
-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     if scale_ue8m0:
         assert column_major_scales and scale_tma_aligned
-        x_q_mn, x_q_k = x.shape
+        x_q_mn, x_q_k = x_shape
         x_s_mn, x_s_k = x_q_mn, x_q_k // 128
         aligned_mn = align(x_s_mn, 4)
         aligned_k = align(x_s_k, 4)
         # TODO(FIXME): Fix cuda kernel and recover here to empty.
-        x_s = torch.zeros(
+        return torch.zeros(
             (aligned_k // 4, aligned_mn),
-            device=x.device,
+            device=device,
             dtype=torch.int,
         ).transpose(0, 1)[:x_s_mn, :]
     elif column_major_scales:
         if scale_tma_aligned:
             # TODO extract "align" function
             # aligned to 4 * sizeof(float)
-            aligned_size = (x.shape[-2] + 3) // 4 * 4
-            x_s = torch.empty(
-                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
-                device=x.device,
+            aligned_size = (x_shape[-2] + 3) // 4 * 4
+            return torch.empty(
+                x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
+                device=device,
                 dtype=torch.float32,
-            ).permute(-1, -2)[: x.shape[-2], :]
+            ).permute(-1, -2)[: x_shape[-2], :]
         else:
-            x_s = torch.empty(
-                (x.shape[-1] // group_size,) + x.shape[:-1],
-                device=x.device,
+            return torch.empty(
+                (x_shape[-1] // group_size,) + x_shape[:-1],
+                device=device,
                 dtype=torch.float32,
             ).permute(-1, -2)
     else:
-        x_s = torch.empty(
-            x.shape[:-1] + (x.shape[-1] // group_size,),
-            device=x.device,
+        return torch.empty(
+            x_shape[:-1] + (x_shape[-1] // group_size,),
+            device=device,
             dtype=torch.float32,
         )
+
+
+def sglang_per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    if scale_ue8m0:
+        # TODO: handle this case by fixing the (token=4, dim=256, group_size=128) UT case
+        assert x.shape[-1] % (group_size * 4) == 0
+
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    x_s = create_per_token_group_quant_fp8_output_scale(
+        x_shape=x.shape,
+        device=x.device,
+        group_size=group_size,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=scale_ue8m0,
+    )
+
     if x.shape[0] > 0:
         sgl_per_token_group_quant_fp8(
             x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
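Note: what the per-token-group quantization above computes, in the simple row-major case, is one float32 scale per group of `group_size` contiguous elements along the last dimension, with scale = max(absmax, eps) / fp8_max. The following is a plain-PyTorch reference sketch for illustration only, not the library's kernel:

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

def per_token_group_quant_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # Reference per-token-group FP8 quantization with row-major scales.
    assert x.shape[-1] % group_size == 0
    g = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    # One scale per group: absmax / fp8_max, clamped away from zero.
    x_s = (g.abs().amax(dim=-1).clamp_min(eps) / FP8_MAX).to(torch.float32)
    x_q = (g / x_s.unsqueeze(-1)).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return x_q.reshape_as(x), x_s

x = torch.randn(4, 256, dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_ref(x, group_size=128)
print(x_q.shape, x_s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])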
sglang/srt/layers/quantization/modelopt_quant.py

@@ -35,10 +35,20 @@ if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
 
 if is_cuda():
-    from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
+    from sgl_kernel import scaled_fp4_quant
+
+try:
+    from flashinfer import mm_fp4 as fp4_gemm
+
+    enable_flashinfer_fp4_gemm = True
+except ImportError:
+    if is_cuda():
+        from sgl_kernel import cutlass_scaled_fp4_mm as fp4_gemm
+    else:
+        fp4_gemm = None
+    enable_flashinfer_fp4_gemm = False
 
 try:
-    from flashinfer import fp4_quantize as fp4_quantize
     from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
 except ImportError:
     flashinfer_cutlass_fused_moe = None
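Note: the import block above uses a capability-probe pattern: try the optional FlashInfer `mm_fp4` GEMM once at import time, fall back to the `sgl_kernel` CUTLASS kernel otherwise, and record the choice in a module-level flag that call sites branch on. A generic, runnable sketch of the pattern with hypothetical module and function names, not the sglang module itself:

# Probe an optional fast backend once at import time; remember which path won
# so call sites can branch on a flag instead of retrying the import.
try:
    from some_optional_backend import fused_gemm as _gemm  # hypothetical dependency

    HAS_FAST_GEMM = True
except ImportError:
    HAS_FAST_GEMM = False

    def _gemm(a, b):
        # Plain fallback; the real code substitutes another kernel here.
        return a @ b

def gemm(a, b):
    # Backends may disagree on operand layout; adapt at the call site (the
    # ModelOptFp4LinearMethod hunk below transposes the weight and its scales
    # when the FlashInfer path is active).
    return _gemm(a, b.T) if HAS_FAST_GEMM else _gemm(a, b)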
@@ -683,11 +693,16 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
         assert layer.alpha.dtype == torch.float32
 
-        out = cutlass_scaled_fp4_mm(
+        w = layer.weight
+        w_scale_interleaved = layer.weight_scale_interleaved
+        if enable_flashinfer_fp4_gemm:
+            w = layer.weight.T
+            w_scale_interleaved = layer.weight_scale_interleaved.T
+        out = fp4_gemm(
             x_fp4,
-            layer.weight,
+            w,
             x_scale_interleaved,
-            layer.weight_scale_interleaved,
+            w_scale_interleaved,
             layer.alpha,
             output_dtype,
         )
@@ -711,7 +726,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.enable_flashinfer_moe = False
+        self.enable_flashinfer_cutlass_moe = False
 
     def create_weights(
         self,
@@ -865,7 +880,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
@@ -894,7 +909,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
 
         # GEMM 2
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         else:
             w2_input_scale = layer.w2_input_scale
@@ -934,7 +949,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     @property
     def load_up_proj_weight_first(self) -> bool:
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
-        return self.enable_flashinfer_moe
+        return self.enable_flashinfer_cutlass_moe
 
     def apply(
         self,
@@ -954,7 +969,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
 
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 not apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"