sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff compares the published contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +1 -11
- sglang/bench_serving.py +149 -1
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +17 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +30 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +14 -2
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +5 -0
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/lora/lora_manager.py +10 -13
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/schedule_batch.py +19 -1
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +28 -13
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +9 -12
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/model_executor/model_runner.py +44 -33
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +55 -20
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/llama4.py +53 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +24 -40
- sglang/srt/openai_api/protocol.py +28 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +30 -6
- sglang/srt/utils.py +35 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -5,16 +5,23 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import is_cuda
 
+logger = logging.getLogger(__name__)
+
 _is_cuda = is_cuda()
 if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
+        sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )
-
+
+try:
+    from deep_gemm import ceil_div
+except ImportError:
+    logger.error(f"Failed to import ceil_div from deep_gemm.")
+
+import triton.language as tl
 
 
 @triton.jit
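A note on the new dependency: `ceil_div` is imported from `deep_gemm` and is plain integer ceiling division, which `get_tma_aligned_size` (added further down in this file) uses to round the scale tensor's M axis up to a TMA-friendly multiple. A pure-Python stand-in for readers without `deep_gemm` installed (my sketch, not part of the wheel):

def ceil_div(a: int, b: int) -> int:
    # Smallest integer n with n * b >= a, i.e. ceiling division.
    return (a + b - 1) // b

# Example: get_tma_aligned_size pads to 16-byte alignment. With float32
# scales (element_size=4), alignment = 16 // 4 = 4 rows, so m=30 pads to 32.
assert ceil_div(30, 4) * 4 == 32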
@@ -654,10 +661,7 @@ def grouped_gemm_triton(
     if block_shape is not None:
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
-        if _is_cuda:
-            a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
-        else:
-            a, scale_a = per_token_group_quant_fp8(a, block_k)
+        a, scale_a = per_token_group_quant_fp8(a, block_k)
 
         assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
@@ -707,3 +711,334 @@ def grouped_gemm_triton(
         **config,
     )
     return c
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_1(
+    num_recv_tokens_per_expert,
+    expert_start_loc,
+    m_indices,
+    num_experts: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+    BLOCK_EXPERT_NUM: tl.constexpr,
+):
+    cur_expert = tl.program_id(0)
+
+    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
+    tokens_per_expert = tl.load(
+        num_recv_tokens_per_expert + offset_cumsum,
+        mask=offset_cumsum < num_experts,
+        other=0,
+    )
+    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
+    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
+
+    cur_expert_start = tl.load(expert_start_loc + cur_expert)
+    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
+
+    m_indices_start_ptr = m_indices + cur_expert_start
+    off_expert = tl.arange(0, BLOCK_E)
+
+    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
+        tl.store(
+            m_indices_start_ptr + start_m + off_expert,
+            cur_expert,
+        )
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_2(
+    total_token_num,
+    expert_start_loc,
+    recv_x,
+    recv_x_stride0,
+    recv_x_stride1,
+    recv_x_scale,
+    recv_x_scale_stride0,
+    recv_x_scale_stride1,
+    recv_topk,
+    recv_topk_stride0,
+    recv_topk_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    output_tensor_scale,
+    output_tensor_scale_stride0,
+    output_tensor_scale_stride1,
+    output_index,
+    output_index_stride0,
+    output_index_stride1,
+    topk_num: tl.constexpr,
+    HIDDEN_SIZE: tl.constexpr,
+    HIDDEN_SIZE_PAD: tl.constexpr,
+    SCALE_HIDDEN_SIZE: tl.constexpr,
+    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
+):
+    start_token_id = tl.program_id(0)
+    grid_num = tl.num_programs(0)
+
+    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
+    mask = offset_in < HIDDEN_SIZE
+
+    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+
+    for token_id in range(start_token_id, total_token_num, grid_num):
+        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
+        to_copy_s = tl.load(
+            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+        )
+
+        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
+            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
+            if expert_id >= 0:
+                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
+                tl.store(
+                    output_index + token_id * output_index_stride0 + topk_index,
+                    dest_token_index,
+                )
+                output_tensor_ptr = (
+                    output_tensor + dest_token_index * output_tensor_stride0
+                )
+                output_tensor_scale_ptr = (
+                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
+                )
+                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
+                tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+
+
+# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
+@torch.no_grad()
+def ep_scatter(
+    recv_x: torch.Tensor,
+    recv_x_scale: torch.Tensor,
+    recv_topk: torch.Tensor,
+    num_recv_tokens_per_expert: torch.Tensor,
+    expert_start_loc: torch.Tensor,
+    output_tensor: torch.Tensor,
+    output_tensor_scale: torch.Tensor,
+    m_indices: torch.Tensor,
+    output_index: torch.Tensor,
+):
+    BLOCK_E = 128  # token num of per expert is aligned to 128
+    BLOCK_D = 128  # block size of quantization
+    num_warps = 8
+    num_experts = num_recv_tokens_per_expert.shape[0]
+    hidden_size = recv_x.shape[1]
+    # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
+    grid = num_experts
+
+    assert m_indices.shape[0] % BLOCK_E == 0
+
+    _fwd_kernel_ep_scatter_1[(grid,)](
+        num_recv_tokens_per_expert,
+        expert_start_loc,
+        m_indices,
+        num_experts=num_experts,
+        num_warps=num_warps,
+        BLOCK_E=BLOCK_E,
+        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
+    )
+
+    grid = min(recv_topk.shape[0], 1024 * 8)
+
+    _fwd_kernel_ep_scatter_2[(grid,)](
+        recv_topk.shape[0],
+        expert_start_loc,
+        recv_x,
+        recv_x.stride(0),
+        recv_x.stride(1),
+        recv_x_scale,
+        recv_x_scale.stride(0),
+        recv_x_scale.stride(1),
+        recv_topk,
+        recv_topk.stride(0),
+        recv_topk.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        output_tensor_scale,
+        output_tensor_scale.stride(0),
+        output_tensor_scale.stride(1),
+        output_index,
+        output_index.stride(0),
+        output_index.stride(1),
+        topk_num=recv_topk.shape[1],
+        num_warps=num_warps,
+        HIDDEN_SIZE=hidden_size,
+        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
+        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+    )
+    return
+
+
+@triton.jit
+def _fwd_kernel_ep_gather(
+    total_token_num,
+    input_tensor,
+    input_tensor_stride0,
+    input_tensor_stride1,
+    recv_topk_ids,
+    recv_topk_ids_stride0,
+    recv_topk_ids_stride1,
+    recv_topk_weight,
+    recv_topk_weight_stride0,
+    recv_topk_weight_stride1,
+    input_index,
+    input_index_stride0,
+    input_index_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    topk_num: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    cur_block = tl.program_id(0)
+    start_cur_token = tl.program_id(1)
+    grid_num = tl.num_programs(1)
+
+    for cur_token in range(start_cur_token, total_token_num, grid_num):
+        off_d = tl.arange(0, BLOCK_D)
+        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
+        for topk_index in range(0, topk_num):
+            expert_id = tl.load(
+                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
+            )
+            if expert_id >= 0:
+                source_token_index = tl.load(
+                    input_index + cur_token * input_index_stride0 + topk_index
+                )
+                acc_weight = tl.load(
+                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
+                )
+                tmp = tl.load(
+                    input_tensor
+                    + source_token_index * input_tensor_stride0
+                    + cur_block * BLOCK_D
+                    + off_d
+                )
+                accumulator += tmp.to(tl.float32) * acc_weight
+
+        tl.store(
+            output_tensor
+            + cur_token * output_tensor_stride0
+            + cur_block * BLOCK_D
+            + off_d,
+            accumulator.to(output_tensor.dtype.element_ty),
+        )
+
+
+@torch.no_grad()
+def ep_gather(
+    input_tensor: torch.Tensor,
+    recv_topk_ids: torch.Tensor,
+    recv_topk_weight: torch.Tensor,
+    input_index: torch.Tensor,
+    output_tensor: torch.Tensor,
+):
+    BLOCK_D = 1024  # block size of quantization
+    num_warps = 2
+    num_tokens = output_tensor.shape[0]
+    hidden_size = input_tensor.shape[1]
+    assert hidden_size % BLOCK_D == 0
+    grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
+    _fwd_kernel_ep_gather[grid](
+        num_tokens,
+        input_tensor,
+        input_tensor.stride(0),
+        input_tensor.stride(1),
+        recv_topk_ids,
+        recv_topk_ids.stride(0),
+        recv_topk_ids.stride(1),
+        recv_topk_weight,
+        recv_topk_weight.stride(0),
+        recv_topk_weight.stride(1),
+        input_index,
+        input_index.stride(0),
+        input_index.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        topk_num=recv_topk_ids.shape[1],
+        num_warps=num_warps,
+        BLOCK_D=BLOCK_D,
+    )
+    return
+
+
+# copy from
+# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
+def get_tma_aligned_size(x: int, element_size: int) -> int:
+    """
+    Global memory address of TMA must be 16-byte aligned.
+    Since we use column-major layout for the LHS scaling tensor,
+    the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
+
+    Arguments:
+        x: original M-axis shape of the LHS scaling tensor.
+        element_size: element size of the LHS scaling tensor.
+
+    Returns:
+        M-axis shape of the LHS scaling tensor after padding.
+    """
+    tma_alignment_bytes = 16
+    assert tma_alignment_bytes % element_size == 0
+    alignment = tma_alignment_bytes // element_size
+    return ceil_div(x, alignment) * alignment
+
+
+@triton.jit
+def _tma_align_input_scale_kernel(
+    input_scale_ptr,
+    output_ptr,
+    m,
+    k_div_block_size,
+    input_scale_stride_m,
+    input_scale_stride_k,
+    output_stride_m,
+    output_stride_k,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    grid_m = tl.num_programs(0)
+    k_offsets = tl.arange(0, BLOCK_SIZE_K)
+
+    for m_base in range(pid_m, m, grid_m):
+        input_offset = (
+            input_scale_ptr
+            + m_base * input_scale_stride_m
+            + k_offsets * input_scale_stride_k
+        )
+        input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size)
+
+        output_offset = (
+            output_ptr + k_offsets * output_stride_k + m_base * output_stride_m
+        )
+        tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)
+
+
+# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
+def tma_align_input_scale(input_scale: torch.Tensor):
+    assert input_scale.dim() == 2
+    m, k_div_block_size = input_scale.shape
+    padd_m = get_tma_aligned_size(m, input_scale.element_size())
+    output = torch.empty(
+        (k_div_block_size, padd_m), dtype=input_scale.dtype, device=input_scale.device
+    )
+
+    grid_m = min(m, 8192)
+    BLOCK_SIZE_K = triton.next_power_of_2(k_div_block_size)
+
+    _tma_align_input_scale_kernel[(grid_m,)](
+        input_scale_ptr=input_scale,
+        output_ptr=output,
+        m=m,
+        k_div_block_size=k_div_block_size,
+        input_scale_stride_m=input_scale.stride(0),
+        input_scale_stride_k=input_scale.stride(1),
+        output_stride_m=output.stride(1),  # Note: these are swapped
+        output_stride_k=output.stride(0),  # for column-major
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )
+    return output.t()[:m]
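To make the calling convention of the new scatter/gather pair concrete, here is a minimal shape sketch (my own illustration, not part of the diff). It assumes a CUDA device with Triton, a torch build with float8_e4m3fn, and random stand-in data instead of real DeepEP output; all sizes (T, K, E, TOPK) are arbitrary, and ep_gather requires the hidden size to be a multiple of its BLOCK_D=1024.

import torch

from sglang.srt.layers.moe.ep_moe.kernels import ep_gather, ep_scatter

T, K, E, TOPK = 16, 2048, 4, 2  # tokens, hidden size, experts, top-k
ALL = 128 * E                   # output rows, padded per ep_scatter's BLOCK_E

recv_x = torch.randn(T, K, device="cuda").to(torch.float8_e4m3fn)
recv_x_scale = torch.rand(T, K // 128, device="cuda", dtype=torch.float32)
recv_topk = torch.randint(0, E, (T, TOPK), device="cuda", dtype=torch.int64)
counts = torch.bincount(recv_topk.flatten(), minlength=E).to(torch.int32)

expert_start_loc = torch.empty(E, device="cuda", dtype=torch.int32)
scattered = torch.empty(ALL, K, device="cuda", dtype=recv_x.dtype)
scattered_scale = torch.empty(ALL, K // 128, device="cuda", dtype=torch.float32)
m_indices = torch.empty(ALL, device="cuda", dtype=torch.int32)
output_index = torch.empty_like(recv_topk)

# Route each token's fp8 payload (and its per-128-block scales) into
# contiguous per-expert segments; output_index records where each copy went.
ep_scatter(recv_x, recv_x_scale, recv_topk, counts, expert_start_loc,
           scattered, scattered_scale, m_indices, output_index)

# After the grouped GEMMs run over the per-expert segments, ep_gather reduces
# the expert outputs back per token, weighted by the top-k routing weights.
expert_out = torch.randn(ALL, K, device="cuda", dtype=torch.bfloat16)
topk_weight = torch.rand(T, TOPK, device="cuda", dtype=torch.float32)
gathered = torch.empty(T, K, device="cuda", dtype=torch.bfloat16)
ep_gather(expert_out, recv_topk, topk_weight, output_index, gathered)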
sglang/srt/layers/moe/ep_moe/layer.py

@@ -4,11 +4,19 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 
+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+
 try:
     from deep_gemm import (
         get_col_major_tma_aligned_tensor,
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
         m_grouped_gemm_fp8_fp8_bf16_nt_masked,
     )
+    from sgl_kernel import silu_and_mul
+
+    from sglang.srt.layers.quantization.fp8_kernel import (
+        sglang_per_token_group_quant_fp8,
+    )
 
     use_deep_gemm = True
 except ImportError:
@@ -20,6 +28,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    ep_gather,
+    ep_scatter,
     gelu_and_mul_triton_kernel,
     grouped_gemm_triton,
     post_reorder_triton_kernel,
@@ -27,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
+    tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
@@ -842,15 +853,23 @@ class DeepEPMoE(EPMoE):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
         masked_m: torch.Tensor,
         expected_m: int,
+        num_recv_tokens_per_expert: List[int],
         forward_mode: ForwardMode,
     ):
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
-            return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+            if _ENABLE_JIT_DEEPGEMM:
+                return self.forward_deepgemm_contiguous(
+                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+                )
+            else:
+                return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
         elif resolved_deepep_mode == DeepEPMode.low_latency:
            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
         else:
@@ -969,6 +988,106 @@ class DeepEPMoE(EPMoE):
         )
         return down_output
 
+    def forward_deepgemm_contiguous(
+        self,
+        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
+        topk_idx,
+        topk_weights,
+        num_recv_tokens_per_expert: List[int],
+    ):
+        hidden_states_fp8, hidden_states_scale = hidden_states_fp8
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+        if num_recv_tokens_per_expert is None:
+            return hidden_states_fp8.bfloat16()
+        all_tokens = sum(num_recv_tokens_per_expert)
+        if all_tokens <= 0:
+            return hidden_states_fp8.bfloat16()
+        M, K = hidden_states_fp8.size()
+        N = self.w13_weight.size(1)
+        scale_block_size = 128
+
+        gather_out = torch.empty_like(
+            hidden_states_fp8,
+            device=hidden_states_fp8.device,
+            dtype=torch.bfloat16,
+        )
+
+        input_tensor = [
+            torch.empty(
+                (all_tokens, K),
+                device=hidden_states_fp8.device,
+                dtype=hidden_states_fp8.dtype,
+            ),
+            torch.empty(
+                (all_tokens, K // 128),
+                device=hidden_states_fp8.device,
+                dtype=torch.float32,
+            ),
+        ]
+        m_indices = torch.empty(
+            all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
+        )
+        output_index = torch.empty_like(topk_idx)
+
+        num_recv_tokens_per_expert_gpu = torch.tensor(
+            num_recv_tokens_per_expert,
+            dtype=torch.int32,
+            pin_memory=True,
+            device="cpu",
+        ).cuda(non_blocking=True)
+        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
+
+        ep_scatter(
+            hidden_states_fp8,
+            hidden_states_scale,
+            topk_idx,
+            num_recv_tokens_per_expert_gpu,
+            expert_start_loc,
+            input_tensor[0],
+            input_tensor[1],
+            m_indices,
+            output_index,
+        )
+
+        gateup_output = torch.empty(
+            (all_tokens, N),
+            device=hidden_states_fp8.device,
+            dtype=torch.bfloat16,
+        )
+        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            input_tensor, self.w13_weight_fp8, gateup_output, m_indices
+        )
+        down_input = torch.empty(
+            (
+                all_tokens,
+                N // 2,
+            ),
+            device=gateup_output.device,
+            dtype=torch.bfloat16,
+        )
+        silu_and_mul(gateup_output.view(-1, N), down_input)
+        down_output = torch.empty(
+            (all_tokens, K),
+            device=hidden_states_fp8.device,
+            dtype=torch.bfloat16,
+        )
+        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
+            down_input, scale_block_size
+        )
+        down_input_scale = tma_align_input_scale(down_input_scale)
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            (down_input_fp8, down_input_scale),
+            self.w2_weight_fp8,
+            down_output,
+            m_indices,
+        )
+
+        ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
+
+        return gather_out
+
     def forward_deepgemm_masked(
         self,
         hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],