sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +330 -200
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +12 -5
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +25 -13
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +1 -0
- sglang/srt/layers/radix_attention.py +13 -1
- sglang/srt/layers/rotary_embedding.py +12 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +48 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +1 -0
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
@@ -103,16 +103,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             "input_activations"
         )

-        if not (
-            self.weight_quant.strategy == QuantizationStrategy.TENSOR
-            and self.input_quant.strategy == QuantizationStrategy.TENSOR
-        ):
-            raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
-                f"{self.weight_quant}, {self.input_quant}"
-            )
-
         self.static_input_scales = not self.input_quant.dynamic

     def create_weights(
@@ -154,27 +144,50 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)

         # WEIGHT_SCALES
-        #
-
-
-
-
-
+        # per-tensor quantization
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    1,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+                requires_grad=False,
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
+        else:
+            raise ValueError(
+                f"Unsupported weight quantization strategy: {self.weight_quant.strategy}"
+            )

-
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
+        extra_weight_attrs.update({"quant_method": weight_quant_method})
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

         # INPUT_SCALES
         if self.static_input_scales:
+            assert (
+                self.input_quant.strategy == QuantizationStrategy.TENSOR
+            ), "Only per-tensor quantization is supported for static input scales"
             w13_input_scale = torch.nn.Parameter(
                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
             )
@@ -241,31 +254,37 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(
                 w2_input_scale, requires_grad=False
             )
-
-
-
-
-
-
-
-
-
-
-
-
-                )
-
-                if _is_cuda:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                else:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id],
                     )
-                start += shard_size

-
+                    if _is_cuda:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    else:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = vllm_ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id]
+                        )
+                    start += shard_size
+
+            layer.w13_weight_scale = torch.nn.Parameter(
+                max_w13_scales, requires_grad=False
+            )

     def apply(
         self,
@@ -285,6 +304,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         activation: str = "silu",
         inplace: bool = True,
         no_combine: bool = False,
+        apply_router_weight_on_input: bool = False,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -310,10 +330,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy
+            == QuantizationStrategy.CHANNEL,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
         )

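Illustrative note (not from the package): the per-tensor and per-channel branches added in `create_weights` above differ mainly in the shape of the scale tensors they allocate. A minimal sketch with made-up sizes, reusing the hunk's variable names and assuming a PyTorch build (>= 2.1) that ships `float8_e4m3fn`:

```python
import torch

# Made-up sizes; the real values come from the model config and TP partitioning.
num_experts, intermediate_size_per_partition, hidden_size = 8, 128, 256

# Per-tensor branch: one scale per (expert, shard); they are combined into a
# single scale per expert after weight loading (the requant loop shown above).
w13_scale_tensor = torch.ones(num_experts, 2, dtype=torch.float32)

# Per-channel branch: one scale per output row, with a trailing singleton dim
# so it broadcasts against a (rows, hidden_size) weight matrix.
w13_scale_channel = torch.ones(
    num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
)

w13_weight = torch.randn(
    num_experts, 2 * intermediate_size_per_partition, hidden_size
).to(torch.float8_e4m3fn)

# Dequantize one expert's weight with its per-channel scales (broadcast over columns).
dq = w13_weight[0].to(torch.float32) * w13_scale_channel[0]
print(dq.shape)  # torch.Size([256, 256])
```

The trailing singleton dimension on the per-channel scales is what lets them broadcast row-wise against each expert's weight matrix.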
@@ -71,7 +71,8 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 _is_hip = is_hip()

 if _is_hip:
-    from aiter
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight

 _is_cuda = is_cuda()
@@ -487,7 +488,7 @@ class Fp8MoEMethod:

         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
-                torch.
+                torch.uint32
                 if get_bool_env_var("USE_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
@@ -822,12 +823,14 @@ class Fp8MoEMethod:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
             # Weight Permutation
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -867,12 +870,14 @@ class Fp8MoEMethod:

         if get_bool_env_var("CK_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -928,7 +933,7 @@ class Fp8MoEMethod:
         if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
             assert not no_combine, f"{no_combine=} is not supported."
-            return
+            return ck_moe_2stages_win4(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -936,15 +941,17 @@ class Fp8MoEMethod:
                 topk_ids,
                 layer.w13_weight_scale1,
                 layer.w2_weight_scale1,
-                activation=
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         if _is_hip and get_bool_env_var("CK_MOE"):
-            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
-            assert (
-                activation == "silu"
-            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
+                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
                 return asm_moe(
                     x,
                     layer.w13_weight,
@@ -957,7 +964,7 @@ class Fp8MoEMethod:
                     expert_mask=None,
                 )
             else:
-                return
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
@@ -965,6 +972,11 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                 )
         else:
             # Expert fusion with FP8 quantization
@@ -16,6 +16,7 @@ import functools
 import json
 import logging
 import os
+from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple

 import torch
@@ -40,11 +41,13 @@ fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

 _is_cuda = is_cuda()
 if _is_cuda:
-    import deep_gemm
+    import deep_gemm
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8

     sm_version = get_device_sm()
-    if sm_version
+    if sm_version == 90 and get_bool_env_var(
+        "SGL_ENABLE_JIT_DEEPGEMM", default="false"
+    ):
         _enable_jit_deepgemm = True


@@ -59,7 +62,10 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-
+        M, K = A.shape
+        N, _ = B.shape
+        with _log_jit_build(M, N, K):
+            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)

     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -708,6 +714,25 @@ def get_w8a8_block_fp8_configs(
     return None


+@contextmanager
+def _log_jit_build(M: int, N: int, K: int):
+    from deep_gemm.jit.runtime import RuntimeCache
+
+    origin_func = RuntimeCache.__getitem__
+
+    def __patched_func(self, *args, **kwargs):
+        ret = origin_func(self, *args, **kwargs)
+        if ret is None:
+            logger.warning(
+                f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
+            )
+        return ret
+
+    RuntimeCache.__getitem__ = __patched_func
+    yield
+    RuntimeCache.__getitem__ = origin_func
+
+
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -782,7 +807,8 @@ def w8a8_block_fp8_matmul(
         if supports_custom_op():
             torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
         else:
-
+            with _log_jit_build(M, N, K):
+                deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
     else:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4
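Editor's sketch (not part of the diff): `_log_jit_build` above works by temporarily swapping out `RuntimeCache.__getitem__` so that a cache miss, which is what triggers a DeepGEMM JIT build, emits a warning before the slow compile starts. The same pattern in self-contained form; `SimpleCache` and `log_cache_misses` are illustrative names, not sglang or deep_gemm APIs:

```python
import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)


class SimpleCache(dict):
    """Stand-in for a build cache whose lookups return None on a miss."""

    def __getitem__(self, key):
        return self.get(key)


@contextmanager
def log_cache_misses(cache_cls, what: str):
    """Temporarily wrap cache_cls.__getitem__ so that misses are logged,
    then restore the original method on exit (same idea as _log_jit_build)."""
    origin = cache_cls.__getitem__

    def patched(self, *args, **kwargs):
        ret = origin(self, *args, **kwargs)
        if ret is None:  # a miss: expensive (re)build work is about to happen
            logger.warning("%s not cached yet, expect a slow first call", what)
        return ret

    cache_cls.__getitem__ = patched
    try:
        yield
    finally:
        cache_cls.__getitem__ = origin


# Usage sketch:
cache = SimpleCache()
with log_cache_misses(SimpleCache, "fp8 GEMM kernel"):
    cache["m128_n256_k512"]  # miss -> warning
```

The restore step here sits in a `finally` block; the hunk above restores the original method unconditionally after `yield`.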
@@ -815,3 +841,103 @@ def w8a8_block_fp8_matmul(
        )

    return C
+
+
+@triton.jit
+def _per_tensor_quant_mla_fp8_stage1(
+    x_ptr,
+    x_s_ptr,
+    head_size,
+    x_stride_h,
+    x_stride_s,
+    eps,
+    fp8_max,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    mask = offset < head_size
+
+    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+    _absmax = tl.maximum(tl.max(tl.abs(x)), eps)
+
+    tl.atomic_max(x_s_ptr, _absmax / fp8_max)
+
+
+@triton.jit
+def _per_tensor_quant_mla_fp8_stage2(
+    x_ptr,
+    x_s_ptr,
+    x_q_ptr,
+    num_seq,
+    head_size,
+    x_stride_h,
+    x_stride_s,
+    fp8_min,
+    fp8_max,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    mask = offset < head_size
+
+    x_s = tl.load(x_s_ptr)
+    x_s_inv = 1.0 / x_s
+
+    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+    x_q_ptr += head_id * num_seq * head_size + seq_id * head_size
+
+    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+    x_q = tl.clamp(x * x_s_inv, fp8_min, fp8_max).to(x_q_ptr.dtype.element_ty)
+    tl.store(x_q_ptr + offset, x_q, mask=mask)
+
+
+def per_tensor_quant_mla_fp8(
+    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with tensor-wise quantization
+    and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    x_q = x.new_empty(x.size(), dtype=dtype)
+    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
+
+    num_head, num_seq, head_size = x.shape
+    BLOCK_SIZE = triton.next_power_of_2(head_size)
+    grid = (num_seq, num_head)
+
+    _per_tensor_quant_mla_fp8_stage1[grid](
+        x,
+        x_s,
+        head_size,
+        x.stride(0),
+        x.stride(1),
+        eps,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+    _per_tensor_quant_mla_fp8_stage2[grid](
+        x,
+        x_s,
+        x_q,
+        num_seq,
+        head_size,
+        x.stride(0),
+        x.stride(1),
+        -fp8_max,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s
@@ -168,12 +168,13 @@ def input_to_float8(
     """This function quantizes input values to float8 values with tensor-wise quantization."""
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
     fp8_max = finfo.max
     if _is_hip:
+        dtype = torch.float8_e4m3fnuz
         fp8_max = 224.0
     scale = fp8_max / amax
-    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
+    x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


@@ -212,7 +213,24 @@ def block_quant_to_tensor_quant(
         for j in range(n_tiles):
             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]

-    x_q_tensor, scale =
+    x_q_tensor, scale = (
+        sgl_scaled_fp8_quant(x_dq_block)
+        if _is_cuda
+        else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    )
+    return x_q_tensor, scale
+
+
+def channel_quant_to_tensor_quant(
+    x_q_channel: torch.Tensor,
+    x_s: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    x_dq_channel = x_q_channel.to(torch.float32) * x_s
+    x_q_tensor, scale = (
+        sgl_scaled_fp8_quant(x_dq_channel)
+        if _is_cuda
+        else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
+    )
     return x_q_tensor, scale

@@ -242,9 +260,19 @@ def apply_fp8_linear(
         if _is_cuda:
             qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
         else:
-
-
-            )
+            # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
+            # final solution should be: 1. add support to per-tensor activation scaling.
+            # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
+            if _is_hip and weight_scale.numel() == 1:
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+            else:
+                qinput, x_scale = per_token_group_quant_fp8(
+                    input_2d, group_size=input_2d.shape[1]
+                )

         if cutlass_fp8_supported:
             try:
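Reference sketch (not part of the diff): the fallback branch above calls `per_token_group_quant_fp8` with the group size set to the full hidden dimension, i.e. dynamic per-token quantization with one scale per row of `input_2d`. A hedged eager-mode equivalent, assuming torch >= 2.1:

```python
import torch


def per_token_quant_fp8_reference(x_2d: torch.Tensor, eps: float = 1e-12):
    """One FP8 scale per row (token), the effect of group_size == hidden size."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = x_2d.abs().amax(dim=-1, keepdim=True).float().clamp(min=eps) / fp8_max
    x_q = (x_2d.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q, scale  # scale has shape [num_tokens, 1]
```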
@@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import is_hip

 _is_hip = is_hip()
@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)

 class BaseKVCacheMethod(QuantizeMethodBase):
     """
-    Quant method that adds `
+    Quant method that adds `k_scale` and `v_scale` attributes to the
     Attention layer to support loading those scaling factors from checkpoints.
     The k/v_scale will be used to:
     - quantize k/v_cache entries before saving them to the cache
@@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         # Initialize the KV cache scales to -1.0, which is an invalid value.
         # If the k/v_scale appears in the checkpoint, it will be
         # overwritten when loading weights.
-        layer.k_scale = torch.nn.Parameter(
-
+        layer.k_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )
+        layer.v_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )

     @classmethod
     def is_fp8_fnuz(cls) -> bool:
@@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")

-    def process_weights_after_loading(self, layer:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # These are used in the final Attention.forward()
-        layer._k_scale.copy_(k_scale)
-        layer._v_scale.copy_(v_scale)
-        layer._k_scale_float = k_scale
-        layer._v_scale_float = v_scale
-        if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
-            logger.warning(
-                "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                "may cause accuracy issues. Please make sure k/v_scale "
-                "scaling factors are available in the fp8 checkpoint."
-            )
-
-        del layer.k_scale
-        del layer.v_scale
+    def process_weights_after_loading(self, layer: RadixAttention) -> None:
+        if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+            # We prefer to use separate k_scale and v_scale if present
+            k_scale = layer.k_scale.to("cpu").tolist()
+            v_scale = layer.v_scale.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+        elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+            # If no scales were loaded (both scales are invalid negative
+            # values), use the default value of 1.0
+            k_scale = 1.0
+            v_scale = 1.0
+        else:
+            # If we find a single kv_scale in the checkpoint, we remap
+            # kv_scale to k_scale during weight loading, and duplicate
+            # k_scale to v_scale here
+            assert layer.k_scale > 0.0
+            scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+            k_scale = scale_to_duplicate.to("cpu").tolist()
+            v_scale = scale_to_duplicate.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+
+        if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+            raise ValueError(
+                "Only support per-tensor scaling factor " "for fp8 KV cache"
+            )
+
+        # These are used in the final Attention.forward()
+        layer.k_scale.copy_(k_scale)
+        layer.v_scale.copy_(v_scale)
+        layer.k_scale_float = k_scale
+        layer.v_scale_float = v_scale
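The rewritten `process_weights_after_loading` above reduces to three cases for resolving checkpoint scales, with -1.0 meaning "not loaded". A plain-Python sketch of just that decision logic; `resolve_kv_scales` and the `fnuz` flag are illustrative stand-ins for the tensor-valued attributes and the `_is_hip and self.is_fp8_fnuz()` check:

```python
def resolve_kv_scales(k_scale: float, v_scale: float, fnuz: bool = False):
    """Sketch of the branch logic above; -1.0 is the 'not loaded' sentinel."""
    if k_scale > 0.0 and v_scale > 0.0:
        # Both scales came from the checkpoint: use them separately.
        if fnuz:  # float8_e4m3fnuz (ROCm) covers a smaller range, hence the 2x
            k_scale, v_scale = k_scale * 2, v_scale * 2
    elif k_scale < 0.0 and v_scale < 0.0:
        # Neither was loaded: fall back to 1.0.
        k_scale, v_scale = 1.0, 1.0
    else:
        # A single kv_scale was remapped to k_scale at load time: duplicate it.
        k_scale = v_scale = max(k_scale, v_scale)
        if fnuz:
            k_scale, v_scale = k_scale * 2, v_scale * 2
    return k_scale, v_scale


assert resolve_kv_scales(-1.0, -1.0) == (1.0, 1.0)
assert resolve_kv_scales(0.02, -1.0) == (0.02, 0.02)
```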