sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +7 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +13 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +323 -242
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +90 -24
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +27 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/data_parallel_controller.py +4 -0
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +64 -1
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/tokenizer_manager.py +80 -15
- sglang/srt/managers/tp_worker.py +8 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +83 -27
- sglang/srt/models/deepseek_v2.py +75 -84
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +17 -71
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +65 -6
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +118 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

```diff
@@ -172,6 +172,7 @@ class Fp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
@@ -180,6 +181,8 @@ class Fp8Config(QuantizationConfig):
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
+        elif isinstance(layer, EPMoE):
+            return Fp8EPMoEMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -791,11 +794,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             # merged w13 weights and generate a single scaling factor.
             layer.w13_weight_scale = torch.nn.Parameter(
                 torch.ones(
-                    layer.
+                    layer.num_local_experts,
+                    dtype=torch.float32,
+                    device=w13_weight.device,
                 ),
                 requires_grad=False,
             )
-            for expert in range(layer.
+            for expert in range(layer.num_local_experts):
                 w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                     scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                 )
@@ -871,7 +876,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.
+            for expert_id in range(layer.num_local_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(
@@ -914,7 +919,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.
+            for expert_id in range(layer.num_local_experts):
                 start = 0
                 max_w13_scale_fp8 = max_w13_scales[expert_id]
                 for shard_id in range(2):
@@ -931,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
             # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
             # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
-            for expert_id in range(layer.
+            for expert_id in range(layer.num_local_experts):
                 layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
                 layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
 
@@ -979,8 +984,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
+        if isinstance(layer, EPMoE):
+            layer.w13_weight_scale = (
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            )
+            layer.w2_weight_scale = (
+                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+            )
+            return layer.run_moe(
+                hidden_states=x,
+                topk_output=topk_output,
+            )
+
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
@@ -1138,248 +1158,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         return None
 
 
-class Fp8EPMoEMethod(Fp8MoEMethod):
-    """MoE method for FP8.
-    Supports loading FP8 checkpoints with static weight scale and
-    dynamic/static activation scale.
-
-    Args:
-        quant_config: The quantization config.
-    """
-
-    def __init__(self, quant_config: Fp8Config):
-        self.quant_config = quant_config
-        self.block_quant = self.quant_config.weight_block_size is not None
-
-    def create_weights(
-        self,
-        layer: Module,
-        num_experts_per_partition: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        tp_size = get_tensor_model_parallel_world_size()
-        if self.block_quant:
-            block_n, block_k = (
-                self.quant_config.weight_block_size[0],
-                self.quant_config.weight_block_size[1],
-            )
-            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by column parallel or enabling merged weights
-            if intermediate_size % block_n != 0:
-                raise ValueError(
-                    f"The output_size of gate's and up's weight = "
-                    f"{intermediate_size} is not divisible by "
-                    f"weight quantization block_n = {block_n}."
-                )
-            if tp_size > 1:
-                # Required by row parallel
-                if intermediate_size % block_k != 0:
-                    raise ValueError(
-                        f"The input_size of down's weight = "
-                        f"{intermediate_size} is not divisible by "
-                        f"weight quantization block_k = {block_k}."
-                    )
-
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                2 * intermediate_size,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                hidden_size,
-                intermediate_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # WEIGHT_SCALES
-        if self.block_quant:
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts_per_partition,
-                    2 * ((intermediate_size + block_n - 1) // block_n),
-                    (hidden_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts_per_partition,
-                    (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
-            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
-            assert self.quant_config.activation_scheme == "dynamic"
-        else:
-            # WEIGHT_SCALES
-            # Allocate 2 scales for w1 and w3 respectively.
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
-            if self.block_quant
-            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
-        # If loading fp8 checkpoint, pass the weight loaders.
-        # If loading an fp16 checkpoint, do not (we will quantize in
-        #   process_weights_after_loading()
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-        # INPUT_SCALES
-        if self.quant_config.activation_scheme == "static":
-            if not self.quant_config.is_checkpoint_fp8_serialized:
-                raise ValueError(
-                    "Found static activation scheme for checkpoint that "
-                    "was not serialized fp8."
-                )
-
-            w13_input_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_input_scale", w13_input_scale)
-            set_weight_attrs(w13_input_scale, extra_weight_attrs)
-
-            w2_input_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_input_scale", w2_input_scale)
-            set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
-        else:
-            layer.w13_input_scale = None
-            layer.w2_input_scale = None
-
-    def process_weights_after_loading(self, layer: Module) -> None:
-
-        # If checkpoint is fp16, quantize in place.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
-            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
-
-            layer.w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    layer.num_experts_per_partition,
-                    dtype=torch.float32,
-                    device=w13_weight.device,
-                ),
-                requires_grad=False,
-            )
-
-            for expert in range(layer.num_experts_per_partition):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-            return
-
-        # If checkpoint is fp8, we need to handle that the
-        # MoE kernels require single activation scale and single weight
-        # scale for w13 per expert.
-        else:
-            if self.quant_config.activation_scheme == "static":
-                if layer.w13_input_scale is None or layer.w2_input_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None."
-                    )
-            layer.w13_weight_scale = torch.nn.Parameter(
-                torch.max(layer.w13_weight_scale, dim=1).values,
-                requires_grad=False,
-            )
-        if self.block_quant:
-            # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_fp8_fnuz:
-                # activation_scheme: dynamic
-                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=layer.w13_weight,
-                    weight_scale=layer.w13_weight_scale_inv,
-                    input_scale=None,
-                )
-                w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=layer.w2_weight,
-                    weight_scale=layer.w2_weight_scale_inv,
-                    input_scale=None,
-                )
-                # Reset the parameter
-                layer.w13_weight = torch.nn.Parameter(
-                    w13_weight, requires_grad=False
-                )
-                layer.w13_weight_scale_inv = torch.nn.Parameter(
-                    w13_weight_scale, requires_grad=False
-                )
-                layer.w13_input_scale = None
-                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-                layer.w2_weight_scale_inv = torch.nn.Parameter(
-                    w2_weight_scale, requires_grad=False
-                )
-                layer.w2_input_scale = None
-            if _use_aiter:
-                layer.w13_weight = torch.nn.Parameter(
-                    shuffle_weight(layer.w13_weight.data, (16, 16)),
-                    requires_grad=False,
-                )
-                layer.w2_weight = torch.nn.Parameter(
-                    shuffle_weight(layer.w2_weight.data, (16, 16)),
-                    requires_grad=False,
-                )
-            return
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        hidden_states: torch.Tensor,
-        topk_output: TopKOutput,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-
 class Fp8KVCacheMethod(BaseKVCacheMethod):
     """
     Supports loading kv-cache scaling factors from FP8 checkpoints.
```
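For illustration only (not part of the released diff): the per-expert loops above derive one static FP8 scale per local expert. Below is a minimal plain-PyTorch sketch of that idea; it only approximates what sglang's `scaled_fp8_quant` kernel does, and the helper name `quantize_experts_fp8` is invented for this sketch.

```python
# Hypothetical sketch, not sglang code: per-expert static FP8 quantization,
# mirroring the `for expert in range(layer.num_local_experts)` loops above.
import torch


def quantize_experts_fp8(w: torch.Tensor):
    """w: [num_local_experts, out_dim, in_dim] in bf16/fp16."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    num_local_experts = w.shape[0]
    w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn)
    scales = torch.ones(num_local_experts, dtype=torch.float32, device=w.device)
    for expert in range(num_local_experts):
        # One scale per expert, derived from that expert's absolute maximum.
        scale = w[expert].abs().amax().float() / fp8_max
        scales[expert] = scale
        w_fp8[expert] = (
            (w[expert].float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        )
    return w_fp8, scales


# Example: 4 local experts with a merged gate/up (w13) projection.
w13 = torch.randn(4, 2 * 256, 128, dtype=torch.bfloat16)
w13_fp8, w13_scale = quantize_experts_fp8(w13)
```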
sglang/srt/layers/quantization/fp8_kernel.py

```diff
@@ -173,6 +173,7 @@ def _per_token_group_quant_fp8_colmajor(
     fp8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
+    SCALE_UE8M0: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group
     quantization on a tensor.
@@ -197,6 +198,8 @@ def _per_token_group_quant_fp8_colmajor(
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
     y_s = _absmax / fp8_max
+    if SCALE_UE8M0:
+        y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
@@ -209,6 +212,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
 
@@ -229,29 +233,17 @@ def per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    x_s = create_per_token_group_quant_fp8_output_scale(
+        x_shape=x.shape,
+        device=x.device,
+        group_size=group_size,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=False,
+    )
+
     M = x.numel() // group_size
     N = group_size
-    if column_major_scales:
-        if scale_tma_aligned:
-            # aligned to 4 * sizeof(float)
-            aligned_size = (x.shape[-2] + 3) // 4 * 4
-            x_s = torch.empty(
-                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)[: x.shape[-2], :]
-        else:
-            x_s = torch.empty(
-                (x.shape[-1] // group_size,) + x.shape[:-1],
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)
-    else:
-        x_s = torch.empty(
-            x.shape[:-1] + (x.shape[-1] // group_size,),
-            device=x.device,
-            dtype=torch.float32,
-        )
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
@@ -271,8 +263,10 @@ def per_token_group_quant_fp8(
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
+            SCALE_UE8M0=scale_ue8m0,
         )
     else:
+        assert not scale_ue8m0
        _per_token_group_quant_fp8[(M,)](
             x,
             x_q,
@@ -287,57 +281,93 @@ def per_token_group_quant_fp8(
             num_stages=num_stages,
         )
 
+    if scale_ue8m0:
+        from deep_gemm.utils.layout import transform_sf_into_required_layout
+
+        assert group_size == 128
+        x_s = transform_sf_into_required_layout(
+            x_s,
+            num_groups=None,
+            mn=x_q.shape[0],
+            k=x_q.shape[1],
+            recipe=(1, group_size, group_size),
+            is_sfa=True,
+        )
+
     return x_q, x_s
 
 
-def
-
-
-
-    column_major_scales: bool
-    scale_tma_aligned: bool
-    scale_ue8m0: bool
+def create_per_token_group_quant_fp8_output_scale(
+    x_shape,
+    device,
+    group_size,
+    column_major_scales: bool,
+    scale_tma_aligned: bool,
+    scale_ue8m0: bool,
 ):
-    assert (
-        x.shape[-1] % group_size == 0
-    ), "the last dimension of `x` cannot be divisible by `group_size`"
-    assert x.is_contiguous(), "`x` is not contiguous"
-
-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     if scale_ue8m0:
         assert column_major_scales and scale_tma_aligned
-        x_q_mn, x_q_k =
+        x_q_mn, x_q_k = x_shape
         x_s_mn, x_s_k = x_q_mn, x_q_k // 128
         aligned_mn = align(x_s_mn, 4)
         aligned_k = align(x_s_k, 4)
         # TODO(FIXME): Fix cuda kernel and recover here to empty.
-
+        return torch.zeros(
             (aligned_k // 4, aligned_mn),
-            device=
+            device=device,
             dtype=torch.int,
         ).transpose(0, 1)[:x_s_mn, :]
     elif column_major_scales:
         if scale_tma_aligned:
             # TODO extract "align" function
             # aligned to 4 * sizeof(float)
-            aligned_size = (
-
-
-            device=
+            aligned_size = (x_shape[-2] + 3) // 4 * 4
+            return torch.empty(
+                x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
+                device=device,
                 dtype=torch.float32,
-            ).permute(-1, -2)[:
+            ).permute(-1, -2)[: x_shape[-2], :]
         else:
-
-            (
-            device=
+            return torch.empty(
+                (x_shape[-1] // group_size,) + x_shape[:-1],
+                device=device,
                 dtype=torch.float32,
             ).permute(-1, -2)
     else:
-
-
-            device=
+        return torch.empty(
+            x_shape[:-1] + (x_shape[-1] // group_size,),
+            device=device,
             dtype=torch.float32,
         )
+
+
+def sglang_per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    if scale_ue8m0:
+        # TODO: handle this case by fixing the (token=4, dim=256, group_size=128) UT case
+        assert x.shape[-1] % (group_size * 4) == 0
+
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    x_s = create_per_token_group_quant_fp8_output_scale(
+        x_shape=x.shape,
+        device=x.device,
+        group_size=group_size,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=scale_ue8m0,
+    )
+
     if x.shape[0] > 0:
         sgl_per_token_group_quant_fp8(
             x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
```
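For illustration only (not part of the released diff): the new `SCALE_UE8M0` path rounds each group's scale up to a power of two, mirroring `tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))` in the Triton kernel above. A small plain-PyTorch sketch of that rounding for a single group follows (the helper name `group_scale` is invented here).

```python
# Hypothetical sketch, not sglang code: per-group FP8 scale with optional
# UE8M0 (power-of-two) rounding, as enabled by SCALE_UE8M0 above.
import torch


def group_scale(y: torch.Tensor, scale_ue8m0: bool, eps: float = 1e-10) -> torch.Tensor:
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    s = torch.clamp(y.abs().amax(), min=eps) / fp8_max
    if scale_ue8m0:
        # Round the scale up to the nearest power of two so it can be stored
        # as a bare E8M0 exponent (no mantissa bits).
        s = torch.exp2(torch.ceil(torch.log2(s)))
    return s


y = torch.randn(128) * 3.7  # one quantization group
print(group_scale(y, scale_ue8m0=False), group_scale(y, scale_ue8m0=True))
```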
sglang/srt/layers/quantization/modelopt_quant.py

```diff
@@ -35,10 +35,20 @@ if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
 
 if is_cuda():
-    from sgl_kernel import
+    from sgl_kernel import scaled_fp4_quant
+
+try:
+    from flashinfer import mm_fp4 as fp4_gemm
+
+    enable_flashinfer_fp4_gemm = True
+except ImportError:
+    if is_cuda():
+        from sgl_kernel import cutlass_scaled_fp4_mm as fp4_gemm
+    else:
+        fp4_gemm = None
+    enable_flashinfer_fp4_gemm = False
 
 try:
-    from flashinfer import fp4_quantize as fp4_quantize
     from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
 except ImportError:
     flashinfer_cutlass_fused_moe = None
@@ -683,11 +693,16 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
         assert layer.alpha.dtype == torch.float32
 
-
+        w = layer.weight
+        w_scale_interleaved = layer.weight_scale_interleaved
+        if enable_flashinfer_fp4_gemm:
+            w = layer.weight.T
+            w_scale_interleaved = layer.weight_scale_interleaved.T
+        out = fp4_gemm(
             x_fp4,
-
+            w,
             x_scale_interleaved,
-
+            w_scale_interleaved,
             layer.alpha,
             output_dtype,
         )
@@ -711,7 +726,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.
+        self.enable_flashinfer_cutlass_moe = False
 
     def create_weights(
         self,
@@ -865,7 +880,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
-        if self.
+        if self.enable_flashinfer_cutlass_moe:
            w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
@@ -885,6 +900,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w13_blockscale_swizzled = Parameter(
             w13_blockscale_swizzled, requires_grad=False
         )
+        del layer.w13_weight_scale
 
         # This is for quantization, so we need to invert it.
         layer.w13_input_scale_quant = Parameter(
@@ -894,7 +910,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
 
         # GEMM 2
-        if self.
+        if self.enable_flashinfer_cutlass_moe:
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         else:
             w2_input_scale = layer.w2_input_scale
@@ -920,6 +936,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w2_blockscale_swizzled = Parameter(
             w2_blockscale_swizzled, requires_grad=False
         )
+        del layer.w2_weight_scale
         layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
         device = layer.w13_weight.device
@@ -934,7 +951,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     @property
     def load_up_proj_weight_first(self) -> bool:
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
-        return self.
+        return self.enable_flashinfer_cutlass_moe
 
     def apply(
         self,
@@ -954,7 +971,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
 
-        if self.
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 not apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
```
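For illustration only (not part of the released diff): the import block above prefers flashinfer's `mm_fp4` and falls back to sgl_kernel's `cutlass_scaled_fp4_mm`. A minimal sketch of that optional-dependency pattern, restructured here as nested try/except rather than the `is_cuda()` check used in the diff (the helper `_select_fp4_gemm` is invented for this sketch):

```python
# Hypothetical sketch, not sglang code: choose an FP4 GEMM backend at import
# time, preferring flashinfer and falling back to sgl_kernel when available.
def _select_fp4_gemm():
    try:
        from flashinfer import mm_fp4 as fp4_gemm  # preferred backend
        return fp4_gemm, True
    except ImportError:
        try:
            from sgl_kernel import cutlass_scaled_fp4_mm as fp4_gemm  # CUDA fallback
            return fp4_gemm, False
        except ImportError:
            return None, False


fp4_gemm, enable_flashinfer_fp4_gemm = _select_fp4_gemm()
# Callers transpose the weight and its interleaved scale only when the
# flashinfer backend is active, matching the `if enable_flashinfer_fp4_gemm:`
# branch in ModelOptFp4LinearMethod above.
```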
|