sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

```diff
@@ -172,6 +172,7 @@ class Fp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
@@ -180,6 +181,8 @@ class Fp8Config(QuantizationConfig):
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
+        elif isinstance(layer, EPMoE):
+            return Fp8EPMoEMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -791,11 +794,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             # merged w13 weights and generate a single scaling factor.
             layer.w13_weight_scale = torch.nn.Parameter(
                 torch.ones(
-                    layer.
+                    layer.num_local_experts,
+                    dtype=torch.float32,
+                    device=w13_weight.device,
                 ),
                 requires_grad=False,
             )
-            for expert in range(layer.
+            for expert in range(layer.num_local_experts):
                 w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                     scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                 )
@@ -871,7 +876,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.
+            for expert_id in range(layer.num_local_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(
@@ -914,7 +919,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.
+            for expert_id in range(layer.num_local_experts):
                 start = 0
                 max_w13_scale_fp8 = max_w13_scales[expert_id]
                 for shard_id in range(2):
@@ -931,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
             # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
             # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
-            for expert_id in range(layer.
+            for expert_id in range(layer.num_local_experts):
                 layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
                 layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
 
@@ -979,8 +984,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
+        if isinstance(layer, EPMoE):
+            layer.w13_weight_scale = (
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            )
+            layer.w2_weight_scale = (
+                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+            )
+            return layer.run_moe(
+                hidden_states=x,
+                topk_output=topk_output,
+            )
+
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
@@ -1138,248 +1158,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         return None
 
 
-class Fp8EPMoEMethod(Fp8MoEMethod):
-    """MoE method for FP8.
-    Supports loading FP8 checkpoints with static weight scale and
-    dynamic/static activation scale.
-
-    Args:
-        quant_config: The quantization config.
-    """
-
-    def __init__(self, quant_config: Fp8Config):
-        self.quant_config = quant_config
-        self.block_quant = self.quant_config.weight_block_size is not None
-
-    def create_weights(
-        self,
-        layer: Module,
-        num_experts_per_partition: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        tp_size = get_tensor_model_parallel_world_size()
-        if self.block_quant:
-            block_n, block_k = (
-                self.quant_config.weight_block_size[0],
-                self.quant_config.weight_block_size[1],
-            )
-            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by column parallel or enabling merged weights
-            if intermediate_size % block_n != 0:
-                raise ValueError(
-                    f"The output_size of gate's and up's weight = "
-                    f"{intermediate_size} is not divisible by "
-                    f"weight quantization block_n = {block_n}."
-                )
-            if tp_size > 1:
-                # Required by row parallel
-                if intermediate_size % block_k != 0:
-                    raise ValueError(
-                        f"The input_size of down's weight = "
-                        f"{intermediate_size} is not divisible by "
-                        f"weight quantization block_k = {block_k}."
-                    )
-
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                2 * intermediate_size,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                hidden_size,
-                intermediate_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # WEIGHT_SCALES
-        if self.block_quant:
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts_per_partition,
-                    2 * ((intermediate_size + block_n - 1) // block_n),
-                    (hidden_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts_per_partition,
-                    (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
-            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
-            assert self.quant_config.activation_scheme == "dynamic"
-        else:
-            # WEIGHT_SCALES
-            # Allocate 2 scales for w1 and w3 respectively.
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
-            if self.block_quant
-            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
-        # If loading fp8 checkpoint, pass the weight loaders.
-        # If loading an fp16 checkpoint, do not (we will quantize in
-        # process_weights_after_loading()
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-        # INPUT_SCALES
-        if self.quant_config.activation_scheme == "static":
-            if not self.quant_config.is_checkpoint_fp8_serialized:
-                raise ValueError(
-                    "Found static activation scheme for checkpoint that "
-                    "was not serialized fp8."
-                )
-
-            w13_input_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_input_scale", w13_input_scale)
-            set_weight_attrs(w13_input_scale, extra_weight_attrs)
-
-            w2_input_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_input_scale", w2_input_scale)
-            set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
-        else:
-            layer.w13_input_scale = None
-            layer.w2_input_scale = None
-
-    def process_weights_after_loading(self, layer: Module) -> None:
-
-        # If checkpoint is fp16, quantize in place.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
-            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
-
-            layer.w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    layer.num_experts_per_partition,
-                    dtype=torch.float32,
-                    device=w13_weight.device,
-                ),
-                requires_grad=False,
-            )
-
-            for expert in range(layer.num_experts_per_partition):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-            return
-
-        # If checkpoint is fp8, we need to handle that the
-        # MoE kernels require single activation scale and single weight
-        # scale for w13 per expert.
-        else:
-            if self.quant_config.activation_scheme == "static":
-                if layer.w13_input_scale is None or layer.w2_input_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None."
-                    )
-                layer.w13_weight_scale = torch.nn.Parameter(
-                    torch.max(layer.w13_weight_scale, dim=1).values,
-                    requires_grad=False,
-                )
-            if self.block_quant:
-                # If ROCm, normalize the weights and scales to e4m3fnuz
-                if _is_fp8_fnuz:
-                    # activation_scheme: dynamic
-                    w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=layer.w13_weight,
-                        weight_scale=layer.w13_weight_scale_inv,
-                        input_scale=None,
-                    )
-                    w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=layer.w2_weight,
-                        weight_scale=layer.w2_weight_scale_inv,
-                        input_scale=None,
-                    )
-                    # Reset the parameter
-                    layer.w13_weight = torch.nn.Parameter(
-                        w13_weight, requires_grad=False
-                    )
-                    layer.w13_weight_scale_inv = torch.nn.Parameter(
-                        w13_weight_scale, requires_grad=False
-                    )
-                    layer.w13_input_scale = None
-                    layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-                    layer.w2_weight_scale_inv = torch.nn.Parameter(
-                        w2_weight_scale, requires_grad=False
-                    )
-                    layer.w2_input_scale = None
-                if _use_aiter:
-                    layer.w13_weight = torch.nn.Parameter(
-                        shuffle_weight(layer.w13_weight.data, (16, 16)),
-                        requires_grad=False,
-                    )
-                    layer.w2_weight = torch.nn.Parameter(
-                        shuffle_weight(layer.w2_weight.data, (16, 16)),
-                        requires_grad=False,
-                    )
-            return
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        hidden_states: torch.Tensor,
-        topk_output: TopKOutput,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-
 class Fp8KVCacheMethod(BaseKVCacheMethod):
     """
     Supports loading kv-cache scaling factors from FP8 checkpoints.
```
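The `num_local_experts` hunks above all touch the same pattern: each locally hosted expert has its merged w13 weight quantized with a single per-expert scale. Below is a rough, self-contained sketch of that loop; `per_tensor_fp8_quant` is a toy stand-in for sglang's `scaled_fp8_quant`, not its actual API, and the shapes are illustrative.

```python
import torch

def per_tensor_fp8_quant(w: torch.Tensor):
    """Toy stand-in: quantize one tensor with a single absmax scale."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = w.abs().max().clamp(min=1e-12) / fp8_max
    w_q = (w / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_q, scale

# Shapes follow the diff: w13 is [num_local_experts, 2 * intermediate, hidden].
num_local_experts, intermediate, hidden = 4, 8, 16
w13 = torch.randn(num_local_experts, 2 * intermediate, hidden)

w13_q = torch.empty_like(w13, dtype=torch.float8_e4m3fn)
w13_scale = torch.ones(num_local_experts, dtype=torch.float32)
for expert in range(num_local_experts):  # one scale per locally hosted expert
    w13_q[expert], w13_scale[expert] = per_tensor_fp8_quant(w13[expert])
```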
sglang/srt/layers/quantization/fp8_kernel.py

```diff
@@ -173,6 +173,7 @@ def _per_token_group_quant_fp8_colmajor(
     fp8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
+    SCALE_UE8M0: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group
     quantization on a tensor.
@@ -197,6 +198,8 @@ def _per_token_group_quant_fp8_colmajor(
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
     y_s = _absmax / fp8_max
+    if SCALE_UE8M0:
+        y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
@@ -209,6 +212,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
 
@@ -229,29 +233,17 @@ def per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    x_s = create_per_token_group_quant_fp8_output_scale(
+        x_shape=x.shape,
+        device=x.device,
+        group_size=group_size,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=False,
+    )
+
     M = x.numel() // group_size
     N = group_size
-    if column_major_scales:
-        if scale_tma_aligned:
-            # aligned to 4 * sizeof(float)
-            aligned_size = (x.shape[-2] + 3) // 4 * 4
-            x_s = torch.empty(
-                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)[: x.shape[-2], :]
-        else:
-            x_s = torch.empty(
-                (x.shape[-1] // group_size,) + x.shape[:-1],
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)
-    else:
-        x_s = torch.empty(
-            x.shape[:-1] + (x.shape[-1] // group_size,),
-            device=x.device,
-            dtype=torch.float32,
-        )
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
@@ -271,8 +263,10 @@ def per_token_group_quant_fp8(
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
+            SCALE_UE8M0=scale_ue8m0,
         )
     else:
+        assert not scale_ue8m0
        _per_token_group_quant_fp8[(M,)](
            x,
            x_q,
@@ -287,57 +281,93 @@ def per_token_group_quant_fp8(
            num_stages=num_stages,
        )
 
+    if scale_ue8m0:
+        from deep_gemm.utils.layout import transform_sf_into_required_layout
+
+        assert group_size == 128
+        x_s = transform_sf_into_required_layout(
+            x_s,
+            num_groups=None,
+            mn=x_q.shape[0],
+            k=x_q.shape[1],
+            recipe=(1, group_size, group_size),
+            is_sfa=True,
+        )
+
     return x_q, x_s
 
 
-def
-
-
-
-    column_major_scales: bool
-    scale_tma_aligned: bool
-    scale_ue8m0: bool
+def create_per_token_group_quant_fp8_output_scale(
+    x_shape,
+    device,
+    group_size,
+    column_major_scales: bool,
+    scale_tma_aligned: bool,
+    scale_ue8m0: bool,
 ):
-    assert (
-        x.shape[-1] % group_size == 0
-    ), "the last dimension of `x` cannot be divisible by `group_size`"
-    assert x.is_contiguous(), "`x` is not contiguous"
-
-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     if scale_ue8m0:
         assert column_major_scales and scale_tma_aligned
-        x_q_mn, x_q_k =
+        x_q_mn, x_q_k = x_shape
         x_s_mn, x_s_k = x_q_mn, x_q_k // 128
         aligned_mn = align(x_s_mn, 4)
         aligned_k = align(x_s_k, 4)
         # TODO(FIXME): Fix cuda kernel and recover here to empty.
-
+        return torch.zeros(
             (aligned_k // 4, aligned_mn),
-            device=
+            device=device,
             dtype=torch.int,
         ).transpose(0, 1)[:x_s_mn, :]
     elif column_major_scales:
         if scale_tma_aligned:
             # TODO extract "align" function
             # aligned to 4 * sizeof(float)
-            aligned_size = (
-
-
-            device=
+            aligned_size = (x_shape[-2] + 3) // 4 * 4
+            return torch.empty(
+                x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
+                device=device,
                 dtype=torch.float32,
-            ).permute(-1, -2)[:
+            ).permute(-1, -2)[: x_shape[-2], :]
         else:
-
-            (
-            device=
+            return torch.empty(
+                (x_shape[-1] // group_size,) + x_shape[:-1],
+                device=device,
                 dtype=torch.float32,
             ).permute(-1, -2)
     else:
-
-
-        device=
+        return torch.empty(
+            x_shape[:-1] + (x_shape[-1] // group_size,),
+            device=device,
            dtype=torch.float32,
        )
+
+
+def sglang_per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    if scale_ue8m0:
+        # TODO: handle this case by fixing the (token=4, dim=256, group_size=128) UT case
+        assert x.shape[-1] % (group_size * 4) == 0
+
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    x_s = create_per_token_group_quant_fp8_output_scale(
+        x_shape=x.shape,
+        device=x.device,
+        group_size=group_size,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=scale_ue8m0,
+    )
+
     if x.shape[0] > 0:
         sgl_per_token_group_quant_fp8(
             x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
```
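The new `SCALE_UE8M0` branch above rounds each group's scale up to the nearest power of two before quantizing, via `tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))`. A minimal PyTorch sketch of the same arithmetic outside Triton, assuming the FP8 e4m3 max of 448 and an illustrative helper name:

```python
import torch

def ue8m0_round_scale(absmax: torch.Tensor, fp8_max: float = 448.0) -> torch.Tensor:
    """Round a per-group scale up to the nearest power of two (UE8M0-style),
    mirroring tl.exp2(tl.ceil(tl.log2(tl.abs(y_s)))) in the Triton kernel."""
    y_s = absmax / fp8_max
    return torch.exp2(torch.ceil(torch.log2(y_s.abs())))

# Example: a group whose absmax is 100 gets scale 100/448 ≈ 0.223,
# which rounds up to the power of two 0.25.
x = torch.randn(4, 128) * 100
absmax = x.abs().amax(dim=-1)
print(ue8m0_round_scale(absmax))
```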
sglang/srt/layers/quantization/modelopt_quant.py

```diff
@@ -35,10 +35,20 @@ if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
 
 if is_cuda():
-    from sgl_kernel import
+    from sgl_kernel import scaled_fp4_quant
+
+try:
+    from flashinfer import mm_fp4 as fp4_gemm
+
+    enable_flashinfer_fp4_gemm = True
+except ImportError:
+    if is_cuda():
+        from sgl_kernel import cutlass_scaled_fp4_mm as fp4_gemm
+    else:
+        fp4_gemm = None
+    enable_flashinfer_fp4_gemm = False
 
 try:
-    from flashinfer import fp4_quantize as fp4_quantize
     from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
 except ImportError:
     flashinfer_cutlass_fused_moe = None
@@ -683,11 +693,16 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
         assert layer.alpha.dtype == torch.float32
 
-
+        w = layer.weight
+        w_scale_interleaved = layer.weight_scale_interleaved
+        if enable_flashinfer_fp4_gemm:
+            w = layer.weight.T
+            w_scale_interleaved = layer.weight_scale_interleaved.T
+        out = fp4_gemm(
             x_fp4,
-
+            w,
             x_scale_interleaved,
-
+            w_scale_interleaved,
             layer.alpha,
             output_dtype,
         )
@@ -711,7 +726,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.
+        self.enable_flashinfer_cutlass_moe = False
 
     def create_weights(
         self,
@@ -865,7 +880,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
-        if self.
+        if self.enable_flashinfer_cutlass_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
@@ -894,7 +909,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
 
         # GEMM 2
-        if self.
+        if self.enable_flashinfer_cutlass_moe:
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         else:
             w2_input_scale = layer.w2_input_scale
@@ -934,7 +949,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     @property
     def load_up_proj_weight_first(self) -> bool:
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
-        return self.
+        return self.enable_flashinfer_cutlass_moe
 
     def apply(
         self,
@@ -952,10 +967,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
     ) -> torch.Tensor:
-
         assert activation == "silu", "Only SiLU activation is supported."
 
-        if self.
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 not apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
@@ -982,13 +996,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_size=tp_size,
                 tp_rank=tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
-            )
-
+            )[0]
+            if routed_scaling_factor is not None:
+                output *= routed_scaling_factor
+            return output
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
         topk_weights, topk_ids, _ = topk_output
-
+        output = cutlass_moe_fp4(
             a=x,
             a1_gscale=layer.w13_input_scale_quant,
             w1_fp4=layer.w13_weight,
@@ -1003,3 +1019,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             params=layer.cutlass_moe_params,
             apply_router_weight_on_input=apply_router_weight_on_input,
         ).to(x.dtype)
+        if routed_scaling_factor is not None:
+            output *= routed_scaling_factor
+        return output
```
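The import hunk above picks an FP4 GEMM backend once at import time and keeps a single `fp4_gemm` call site; the linear-method hunk then only transposes the weight and scale when the FlashInfer path is active. A condensed, illustrative sketch of that pattern follows; `run_fp4_linear` is a hypothetical wrapper rather than sglang's API, the nested try stands in for the `is_cuda()` check used in the diff, and at least one of the two kernels is assumed to be installed for the call to work.

```python
# Backend selection mirroring the @@ -35 hunk: prefer flashinfer.mm_fp4,
# fall back to sgl_kernel's CUTLASS GEMM, and remember which one is active.
try:
    from flashinfer import mm_fp4 as fp4_gemm
    enable_flashinfer_fp4_gemm = True
except ImportError:
    try:
        from sgl_kernel import cutlass_scaled_fp4_mm as fp4_gemm
    except ImportError:
        fp4_gemm = None  # no FP4 GEMM backend available on this platform
    enable_flashinfer_fp4_gemm = False


def run_fp4_linear(x_fp4, x_scale_interleaved, layer, output_dtype):
    # The diff transposes weight and scale only for the FlashInfer backend,
    # so the single call site stays backend-agnostic.
    w, w_scale = layer.weight, layer.weight_scale_interleaved
    if enable_flashinfer_fp4_gemm:
        w, w_scale = w.T, w_scale.T
    return fp4_gemm(x_fp4, w, x_scale_interleaved, w_scale, layer.alpha, output_dtype)
```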