sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +1 -11
- sglang/bench_serving.py +149 -1
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +17 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +30 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +14 -2
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +5 -0
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/lora/lora_manager.py +10 -13
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/schedule_batch.py +19 -1
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +28 -13
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +9 -12
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/model_executor/model_runner.py +44 -33
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +55 -20
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/llama4.py +53 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +24 -40
- sglang/srt/openai_api/protocol.py +28 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +30 -6
- sglang/srt/utils.py +35 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

@@ -42,6 +42,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    is_fp8_fnuz,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
 )
@@ -64,6 +66,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
     is_hip,
+    log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
 )
@@ -71,6 +74,11 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 
+_is_fp8_fnuz = is_fp8_fnuz()
+
+use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
+use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+
 if _is_hip:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
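
The two ROCm feature flags are now read once at import time and reused as plain booleans at every call site. A minimal sketch of that pattern follows; the `get_bool_env_var` shown here is a simplified stand-in, not the sglang implementation:

    import os

    def get_bool_env_var(name: str, default: str = "false") -> bool:
        # Simplified stand-in: treat "1"/"true"/"yes" (case-insensitive) as True.
        return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

    # Read once at import time; the quantization code below branches on plain bools.
    use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
    use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
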
@@ -97,10 +105,7 @@ class Fp8Config(QuantizationConfig):
     ) -> None:
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:
-            logger.warning(
-                "Detected fp8 checkpoint. Please note that the "
-                "format is experimental and subject to change."
-            )
+            log_info_on_rank0(logger, "Detected fp8 checkpoint.")
         if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
         self.activation_scheme = activation_scheme
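
`log_info_on_rank0` itself (imported from sglang.srt.utils above) is not part of this diff. A rough sketch of what such a helper typically does, assuming a torch.distributed setup, so treat the details as an assumption:

    import torch.distributed as dist

    def log_info_on_rank0(logger, msg: str) -> None:
        # Emit the message only from the first rank so multi-GPU runs do not
        # repeat the same line once per process.
        if (not dist.is_available()) or (not dist.is_initialized()) or dist.get_rank() == 0:
            logger.info(msg)
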
@@ -306,25 +311,21 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
                     weight_scale=layer.weight_scale_inv,
                     input_scale=None,
                 )
-                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = torch.nn.Parameter(
-                    weight_scale, requires_grad=False
-                )
+
                 layer.input_scale = None
             else:
-                layer.weight = torch.nn.Parameter(
-                    layer.weight.data, requires_grad=False
-                )
-                layer.weight_scale_inv = torch.nn.Parameter(
-                    layer.weight_scale_inv.data, requires_grad=False
-                )
+                weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
+            layer.weight = torch.nn.Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = torch.nn.Parameter(
+                weight_scale, requires_grad=False
+            )
             return
 
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
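
The rescaling above is now gated on `_is_fp8_fnuz` (gfx94x hardware) rather than on `_is_hip`. `normalize_e4m3fn_to_e4m3fnuz` is not shown in this diff; a sketch of what such a helper conventionally does, modeled on the vLLM-style utility and therefore an assumption, not sglang's exact code:

    from typing import Optional, Tuple
    import torch

    def normalize_e4m3fn_to_e4m3fnuz(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        # Reinterpret the e4m3fn payload as e4m3fnuz bits. Bit pattern 0x80 is
        # negative zero in e4m3fn but NaN in e4m3fnuz, so zero it out first.
        as_int8 = weight.view(torch.int8)
        as_int8[as_int8 == -128] = 0
        weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
        # For the same bit pattern, an e4m3fnuz value is half the e4m3fn value,
        # so the scales are doubled to keep dequantized values unchanged.
        weight_scale = weight_scale * 2.0
        if input_scale is not None:
            input_scale = input_scale * 2.0
        return weight_fnuz, weight_scale, input_scale
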
@@ -368,7 +369,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -482,11 +483,7 @@ class Fp8MoEMethod:
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
        if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = (
-                torch.uint32
-                if get_bool_env_var("SGLANG_INT4_WEIGHT")
-                else torch.float8_e4m3fn
-            )
+            params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
        tp_size = get_tensor_model_parallel_world_size()
        if self.block_quant:
            block_n, block_k = (
@@ -511,7 +508,7 @@ class Fp8MoEMethod:
        )
 
        # WEIGHTS
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
            # INT4 MoE weight - INT32 packed
            w13_weight = torch.nn.Parameter(
                torch.empty(
@@ -583,9 +580,7 @@ class Fp8MoEMethod:
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
-        if (
-            _is_hip
-        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
+        if _is_hip:  # and use_aiter_moe: TODO: add check back after triton kernel
            # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
            w13_weight_scale1 = torch.nn.Parameter(
                torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -612,7 +607,7 @@ class Fp8MoEMethod:
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
            )
@@ -644,14 +639,14 @@ class Fp8MoEMethod:
        layer.w2_input_scale = None
 
    def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
            self.process_weights_hip_int4(layer)
            return
 
        # Block quant doesn't need to process weights after loading
        if self.block_quant:
            # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                # activation_scheme: dynamic
                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                    weight=layer.w13_weight,
@@ -675,20 +670,19 @@ class Fp8MoEMethod:
                )
                layer.w2_input_scale = None
 
-
-
-
-
-
-
-
-
+            if _is_hip and use_aiter_moe:
+                # Pre-shuffle weights
+                layer.w13_weight.data = shuffle_weight(
+                    layer.w13_weight.contiguous(), (16, 16)
+                )
+                layer.w2_weight.data = shuffle_weight(
+                    layer.w2_weight.contiguous(), (16, 16)
+                )
            return
 
        # If checkpoint is fp16 or bfloat16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If ROCm,
-            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+            # If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW)
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
|
@@ -742,7 +736,7 @@ class Fp8MoEMethod:
|
|
742
736
|
)
|
743
737
|
|
744
738
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
745
|
-
if
|
739
|
+
if _is_fp8_fnuz:
|
746
740
|
# Normalize the weights and scales
|
747
741
|
w13_weight, w13_weight_scale, w13_input_scale = (
|
748
742
|
normalize_e4m3fn_to_e4m3fnuz(
|
@@ -798,7 +792,7 @@ class Fp8MoEMethod:
        return
 
    def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
+        # TODO: and use_aiter_moe: add after triton kernel added
        # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
        # Weight Permutation
        layer.w13_weight = torch.nn.Parameter(
@@ -845,7 +839,7 @@ class Fp8MoEMethod:
            padding_size,  # Avoid circular import
        )
 
-        if get_bool_env_var("SGLANG_AITER_MOE"):
+        if use_aiter_moe:
            layer.w13_weight = torch.nn.Parameter(
                shuffle_weight(layer.w13_weight.data, (16, 16)),
                requires_grad=False,
@@ -856,7 +850,7 @@ class Fp8MoEMethod:
                requires_grad=False,
            )
            torch.cuda.empty_cache()
-            # ROCm (
+            # ROCm (use_aiter_moe): using column-wise scaling
            layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
            layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
        elif get_bool_env_var("SGLANG_MOE_PADDING"):
@@ -908,59 +902,16 @@ class Fp8MoEMethod:
        )
 
        if _is_hip:
-            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
-                assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    QuantType.per_Token,
-                    layer.w13_weight_scale1,
-                    layer.w2_weight_scale1,
-                    activation=(
-                        ActivationType.Silu
-                        if activation == "silu"
-                        else ActivationType.Gelu
-                    ),
-                )
-
-            if get_bool_env_var("SGLANG_AITER_MOE"):
-                assert not no_combine, f"{no_combine=} is not supported."
-                if self.block_quant:
-                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
-                    assert (
-                        activation == "silu"
-                    ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
-                return asm_moe(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-                    block_shape=tuple(self.quant_config.weight_block_size),
-                    expert_mask=None,
-                )
-            else:
-                return ck_moe_2stages(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    QuantType.per_Token,
-                    layer.w13_weight_scale1,
-                    layer.w2_weight_scale1,
-                    activation=(
-                        ActivationType.Silu
-                        if activation == "silu"
-                        else ActivationType.Gelu
-                    ),
-                )
+            ret = self.maybe_apply_hip_fused_experts(
+                layer,
+                x,
+                topk_weights,
+                topk_ids,
+                activation,
+                no_combine,
+            )
+            if ret is not None:
+                return ret
 
        # Expert fusion with FP8 quantization
        return fused_experts(
|
@@ -987,6 +938,68 @@ class Fp8MoEMethod:
|
|
987
938
|
no_combine=no_combine,
|
988
939
|
)
|
989
940
|
|
941
|
+
def maybe_apply_hip_fused_experts(
|
942
|
+
self,
|
943
|
+
layer: torch.nn.Module,
|
944
|
+
x: torch.Tensor,
|
945
|
+
topk_weights: torch.Tensor,
|
946
|
+
topk_ids: torch.Tensor,
|
947
|
+
activation: str = "silu",
|
948
|
+
no_combine: bool = False,
|
949
|
+
) -> Optional[torch.Tensor]:
|
950
|
+
if use_hip_int4:
|
951
|
+
# TODO: add triton kernel and add check use_aiter_moe
|
952
|
+
assert not no_combine, f"{no_combine=} is not supported."
|
953
|
+
return ck_moe_2stages(
|
954
|
+
x,
|
955
|
+
layer.w13_weight,
|
956
|
+
layer.w2_weight,
|
957
|
+
topk_weights,
|
958
|
+
topk_ids,
|
959
|
+
QuantType.per_Token,
|
960
|
+
layer.w13_weight_scale1,
|
961
|
+
layer.w2_weight_scale1,
|
962
|
+
activation=(
|
963
|
+
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
964
|
+
),
|
965
|
+
)
|
966
|
+
|
967
|
+
if use_aiter_moe:
|
968
|
+
assert not no_combine, f"{no_combine=} is not supported."
|
969
|
+
if self.block_quant:
|
970
|
+
# TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
|
971
|
+
assert (
|
972
|
+
activation == "silu"
|
973
|
+
), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
|
974
|
+
return asm_moe(
|
975
|
+
x,
|
976
|
+
layer.w13_weight,
|
977
|
+
layer.w2_weight,
|
978
|
+
topk_weights,
|
979
|
+
topk_ids,
|
980
|
+
layer.w13_weight_scale_inv,
|
981
|
+
layer.w2_weight_scale_inv,
|
982
|
+
block_shape=tuple(self.quant_config.weight_block_size),
|
983
|
+
expert_mask=None,
|
984
|
+
)
|
985
|
+
else:
|
986
|
+
return ck_moe_2stages(
|
987
|
+
x,
|
988
|
+
layer.w13_weight,
|
989
|
+
layer.w2_weight,
|
990
|
+
topk_weights,
|
991
|
+
topk_ids,
|
992
|
+
QuantType.per_Token,
|
993
|
+
layer.w13_weight_scale1,
|
994
|
+
layer.w2_weight_scale1,
|
995
|
+
activation=(
|
996
|
+
ActivationType.Silu
|
997
|
+
if activation == "silu"
|
998
|
+
else ActivationType.Gelu
|
999
|
+
),
|
1000
|
+
)
|
1001
|
+
return None
|
1002
|
+
|
990
1003
|
|
991
1004
|
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
992
1005
|
"""
|
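
The ROCm fast paths are still opt-in through the same environment variables; only the plumbing changed, since they are now read once into `use_hip_int4` / `use_aiter_moe` when the module is imported. A hedged example of enabling the aiter MoE path (the flag names come from this diff; the engine construction and model id below are illustrative assumptions):

    import os

    # Set before sglang's quantization modules are imported, because the flags
    # are read once at module import time rather than per call.
    os.environ["SGLANG_AITER_MOE"] = "1"      # route MoE through asm_moe / ck_moe_2stages
    # os.environ["SGLANG_INT4_WEIGHT"] = "1"  # optional: INT4-weight, FP8-compute MoE path

    import sglang as sgl  # noqa: E402

    engine = sgl.Engine(model_path="deepseek-ai/DeepSeek-V3")  # hypothetical example model
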
sglang/srt/layers/quantization/fp8_kernel.py

@@ -16,6 +16,7 @@ import functools
 import json
 import logging
 import os
+from functools import lru_cache
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -29,17 +30,12 @@ from sglang.srt.utils import (
     get_device_name,
     is_cuda,
     is_hip,
+    log_info_on_rank0,
     supports_custom_op,
 )
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
-_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-if _is_hip:
-    fp8_max = 224.0
-else:
-    fp8_max = torch.finfo(_fp8_type).max
-fp8_min = -fp8_max
 
 if _is_cuda:
     from sgl_kernel import (
@@ -54,6 +50,24 @@ if _is_cuda:
 
 logger = logging.getLogger(__name__)
 
+
+@lru_cache()
+def is_fp8_fnuz() -> bool:
+    if _is_hip:
+        # only device 0 is checked, this assumes MI300 platforms are homogeneous
+        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+    return False
+
+
+if is_fp8_fnuz():
+    fp8_dtype = torch.float8_e4m3fnuz
+    fp8_max = 224.0
+else:
+    fp8_dtype = torch.float8_e4m3fn
+    fp8_max = torch.finfo(fp8_dtype).max
+fp8_min = -fp8_max
+
+
 if supports_custom_op():
 
     def deep_gemm_fp8_fp8_bf16_nt(
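
With `fp8_dtype`, `fp8_min`/`fp8_max`, and `is_fp8_fnuz()` now exported from fp8_kernel.py, callers no longer choose the FP8 dtype per call site. A minimal sketch of the intended usage pattern; the `quantize_per_tensor` helper below is illustrative, not part of sglang:

    from typing import Tuple
    import torch
    from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype, fp8_max, fp8_min, is_fp8_fnuz

    def quantize_per_tensor(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Scale so the largest magnitude maps to fp8_max, clamp to the
        # representable range, and cast to the platform fp8 dtype
        # (float8_e4m3fnuz on gfx94x/MI300, float8_e4m3fn elsewhere).
        scale = x.abs().amax().clamp(min=1e-12) / fp8_max
        x_q = (x / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
        return x_q, scale

    print("fnuz platform:", is_fp8_fnuz())
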
@@ -198,7 +212,7 @@ def per_token_group_quant_fp8(
    ), "the last dimension of `x` cannot be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"
 
-    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
    M = x.numel() // group_size
    N = group_size
    if column_major_scales:
@@ -272,7 +286,7 @@ def sglang_per_token_group_quant_fp8(
    ), "the last dimension of `x` cannot be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"
 
-    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
    if column_major_scales:
        if scale_tma_aligned:
            # aligned to 4 * sizeof(float)
@@ -294,15 +308,15 @@ def sglang_per_token_group_quant_fp8(
            device=x.device,
            dtype=torch.float32,
        )
-
-
+    if x.shape[0] > 0:
+        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
 
    return x_q, x_s
 
 
 def sglang_per_token_quant_fp8(
    x: torch.Tensor,
-    dtype: torch.dtype = _fp8_type,
+    dtype: torch.dtype = fp8_dtype,
 ):
    assert x.is_contiguous(), "`x` is not contiguous"
 
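
The new `if x.shape[0] > 0` guard skips the Triton kernel launch entirely for empty inputs, and the pre-allocated (empty) `x_q` / `x_s` are returned untouched. A hedged usage sketch, assuming a CUDA/ROCm build where this function is importable and that `group_size` is its second parameter as in the sibling `per_token_group_quant_fp8`:

    import torch
    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8

    # Zero-row input, e.g. a batch with no tokens for this rank: with the guard,
    # no kernel is launched and empty tensors come back with the expected shapes.
    x = torch.empty((0, 256), dtype=torch.bfloat16, device="cuda")
    x_q, x_s = sglang_per_token_group_quant_fp8(x, group_size=128)
    print(x_q.shape, x_s.shape)  # torch.Size([0, 256]) and an empty scale tensor
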
@@ -384,7 +398,7 @@ def static_quant_fp8(
    assert x.is_contiguous(), "`x` is not contiguous"
    assert x_s.numel() == 1, "only supports per-tensor scale"
 
-    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    if repeat_scale:
@@ -685,9 +699,9 @@ def get_w8a8_block_fp8_configs(
    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
-            logger.info(
-                "Using configuration from %s for W8A8 Block FP8 kernel.",
-                config_file_path,
+            log_info_on_rank0(
+                logger,
+                f"Using configuration from {config_file_path} for W8A8 Block FP8 kernel.",
            )
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}
@@ -704,6 +718,28 @@ def get_w8a8_block_fp8_configs(
    return None
 
 
+def select_w8a8_block_fp8_matmul_kernel(M, N, META):
+    return _w8a8_block_fp8_matmul
+
+
+if _is_hip:
+
+    def use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
+        # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
+        # Empirical testing shows the sweet spot lies when it's less than the # of
+        # compute units available on the device.
+        num_workgroups = triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(
+            N, META["BLOCK_SIZE_N"]
+        )
+        num_workgroups <= get_device_core_count()
+
+    def select_w8a8_block_fp8_matmul_kernel(M, N, META):
+        if use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
+            return _w8a8_block_fp8_matmul_unrolledx4
+        else:
+            return _w8a8_block_fp8_matmul
+
+
 def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
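
The selector keeps the call site backend-agnostic: `w8a8_block_fp8_matmul` asks for a kernel given the problem shape and the chosen config, and only the ROCm build ever returns the unrolled variant. A minimal sketch of the same dispatch pattern in isolation; the kernel functions and the core count here are placeholders, not the real Triton kernels:

    def _kernel_default(*args, **kwargs):
        return "default"

    def _kernel_unrolledx4(*args, **kwargs):
        return "unrolledx4"

    def cdiv(a: int, b: int) -> int:
        return (a + b - 1) // b

    DEVICE_CORE_COUNT = 304  # e.g. CU count of an MI300X; assumption for this sketch

    def select_kernel(M: int, N: int, meta: dict):
        # Prefer the manually unrolled kernel only when the launch grid is smaller
        # than the number of compute units, mirroring the heuristic above.
        num_workgroups = cdiv(M, meta["BLOCK_SIZE_M"]) * cdiv(N, meta["BLOCK_SIZE_N"])
        return _kernel_unrolledx4 if num_workgroups <= DEVICE_CORE_COUNT else _kernel_default

    kernel = select_kernel(256, 4096, {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128})
    print(kernel())  # "unrolledx4": 4 * 32 = 128 workgroups, below the CU count
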
@@ -744,35 +780,6 @@ def w8a8_block_fp8_matmul(
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)
 
-    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
-    if configs:
-        # If an optimal configuration map has been found, look up the
-        # optimal config
-        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-    else:
-        # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
-        config = {
-            "BLOCK_SIZE_M": 64,
-            "BLOCK_SIZE_N": block_size[0],
-            "BLOCK_SIZE_K": block_size[1],
-            "GROUP_SIZE_M": 32,
-            "num_warps": 4,
-            "num_stages": 3,
-        }
-
-    def grid(META):
-        return (
-            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-        )
-
-    # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
-    # Empirical testing shows the sweet spot lies when it's less than the # of
-    # compute units available on the device.
-    num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
-        N, config["BLOCK_SIZE_N"]
-    )
-
    # deepgemm only support bf16
    if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
        if supports_custom_op():
@@ -780,11 +787,30 @@ def w8a8_block_fp8_matmul(
        else:
            deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
    else:
-
-
-
-
-
+        configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Default config
+            # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 32,
+                "num_warps": 4,
+                "num_stages": 3,
+            }
+
+        def grid(META):
+            return (
+                triton.cdiv(M, META["BLOCK_SIZE_M"])
+                * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+            )
+
+        kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
 
        kernel[grid](
            A,
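
Tuned configs are keyed by the M (row count) they were benchmarked at, and the lookup picks the entry whose key is numerically closest to the runtime M. A small worked example of that selection rule:

    # Keys mimic the benchmarked batch sizes stored in a block-FP8 config JSON.
    configs = {
        1: {"BLOCK_SIZE_M": 16, "num_stages": 2},
        64: {"BLOCK_SIZE_M": 64, "num_stages": 3},
        1024: {"BLOCK_SIZE_M": 128, "num_stages": 4},
    }

    M = 200  # runtime number of rows
    config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
    print(config)  # {'BLOCK_SIZE_M': 64, 'num_stages': 3} -- 64 is closest to 200
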
@@ -879,7 +905,7 @@ def per_tensor_quant_mla_fp8(
        and x_s_out.device == x.device
    )
 
-    x_q = x.new_empty(x.size(), dtype=_fp8_type)
+    x_q = x.new_empty(x.size(), dtype=fp8_dtype)
 
    num_head, num_seq, head_size = x.shape
    BLOCK_SIZE = triton.next_power_of_2(head_size)
@@ -961,11 +987,11 @@ def _per_token_group_quant_mla_deep_gemm_masked_fp8(
    tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
 
 
-def per_tensor_quant_mla_deep_gemm_masked_fp8(
+def per_token_group_quant_mla_deep_gemm_masked_fp8(
    x: torch.Tensor,
    group_size: int = 128,
    eps: float = 1e-12,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: torch.dtype = fp8_dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    This function quantizes input values to float8 values with per-token-group-quantization
@@ -973,12 +999,6 @@ def per_tensor_quant_mla_deep_gemm_masked_fp8(
    """
    assert x.dim() == 3, "`x` is not a 3d-tensor"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-
    b, m, k = x.shape
    aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
    num_tiles_k = k // group_size
@@ -1043,10 +1063,9 @@ def scaled_fp8_quant(
    """
    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
    shape = input.shape
-    out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+    output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
 
    if scale is None:
        # Dynamic scaling
|