sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/sglang/srt/layers/moe/ep_moe/layer.py
@@ -4,11 +4,19 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 
+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+
 try:
     from deep_gemm import (
         get_col_major_tma_aligned_tensor,
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
         m_grouped_gemm_fp8_fp8_bf16_nt_masked,
     )
+    from sgl_kernel import silu_and_mul
+
+    from sglang.srt.layers.quantization.fp8_kernel import (
+        sglang_per_token_group_quant_fp8,
+    )
 
     use_deep_gemm = True
 except ImportError:
@@ -20,6 +28,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    ep_gather,
+    ep_scatter,
     gelu_and_mul_triton_kernel,
     grouped_gemm_triton,
     post_reorder_triton_kernel,
@@ -27,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
+    tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
@@ -600,7 +611,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 self.quant_config.weight_block_size[1],
             )
             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by
+            # Required by column parallel or enabling merged weights
             if intermediate_size % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "
@@ -842,15 +853,23 @@ class DeepEPMoE(EPMoE):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
         masked_m: torch.Tensor,
         expected_m: int,
+        num_recv_tokens_per_expert: List[int],
         forward_mode: ForwardMode,
     ):
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
-            return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+            if _ENABLE_JIT_DEEPGEMM:
+                return self.forward_deepgemm_contiguous(
+                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+                )
+            else:
+                return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
         elif resolved_deepep_mode == DeepEPMode.low_latency:
             return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
         else:
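The forward() change above adds a second decision inside normal mode: when JIT DeepGEMM is available, tokens take the new contiguous grouped-GEMM path, otherwise the pre-existing Triton path is kept. A minimal, self-contained sketch of that routing (the names below are illustrative stand-ins, not the sglang API):

```python
# Sketch of the mode routing added above. `run_contiguous`, `run_normal`,
# and `run_masked` are hypothetical stand-ins for the real
# forward_deepgemm_contiguous / forward_normal / forward_deepgemm_masked.
from enum import Enum, auto


class Mode(Enum):
    NORMAL = auto()
    LOW_LATENCY = auto()


ENABLE_JIT_DEEPGEMM = True  # mirrors the imported _ENABLE_JIT_DEEPGEMM flag


def route(mode: Mode) -> str:
    if mode == Mode.NORMAL:
        # Contiguous grouped GEMM only when JIT DeepGEMM is enabled.
        return "run_contiguous" if ENABLE_JIT_DEEPGEMM else "run_normal"
    elif mode == Mode.LOW_LATENCY:
        return "run_masked"
    raise ValueError(f"unsupported mode: {mode}")


assert route(Mode.NORMAL) == "run_contiguous"
assert route(Mode.LOW_LATENCY) == "run_masked"
```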
@@ -969,6 +988,106 @@ class DeepEPMoE(EPMoE):
         )
         return down_output
 
+    def forward_deepgemm_contiguous(
+        self,
+        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
+        topk_idx,
+        topk_weights,
+        num_recv_tokens_per_expert: List[int],
+    ):
+        hidden_states_fp8, hidden_states_scale = hidden_states_fp8
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+        if num_recv_tokens_per_expert is None:
+            return hidden_states_fp8.bfloat16()
+        all_tokens = sum(num_recv_tokens_per_expert)
+        if all_tokens <= 0:
+            return hidden_states_fp8.bfloat16()
+        M, K = hidden_states_fp8.size()
+        N = self.w13_weight.size(1)
+        scale_block_size = 128
+
+        gather_out = torch.empty_like(
+            hidden_states_fp8,
+            device=hidden_states_fp8.device,
+            dtype=torch.bfloat16,
+        )
+
+        input_tensor = [
+            torch.empty(
+                (all_tokens, K),
+                device=hidden_states_fp8.device,
+                dtype=hidden_states_fp8.dtype,
+            ),
+            torch.empty(
+                (all_tokens, K // 128),
+                device=hidden_states_fp8.device,
+                dtype=torch.float32,
+            ),
+        ]
+        m_indices = torch.empty(
+            all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
+        )
+        output_index = torch.empty_like(topk_idx)
+
+        num_recv_tokens_per_expert_gpu = torch.tensor(
+            num_recv_tokens_per_expert,
+            dtype=torch.int32,
+            pin_memory=True,
+            device="cpu",
+        ).cuda(non_blocking=True)
+        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
+
+        ep_scatter(
+            hidden_states_fp8,
+            hidden_states_scale,
+            topk_idx,
+            num_recv_tokens_per_expert_gpu,
+            expert_start_loc,
+            input_tensor[0],
+            input_tensor[1],
+            m_indices,
+            output_index,
+        )
+
+        gateup_output = torch.empty(
+            (all_tokens, N),
+            device=hidden_states_fp8.device,
+            dtype=torch.bfloat16,
+        )
+        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            input_tensor, self.w13_weight_fp8, gateup_output, m_indices
+        )
+        down_input = torch.empty(
+            (
+                all_tokens,
+                N // 2,
+            ),
+            device=gateup_output.device,
+            dtype=torch.bfloat16,
+        )
+        silu_and_mul(gateup_output.view(-1, N), down_input)
+        down_output = torch.empty(
+            (all_tokens, K),
+            device=hidden_states_fp8.device,
+            dtype=torch.bfloat16,
+        )
+        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
+            down_input, scale_block_size
+        )
+        down_input_scale = tma_align_input_scale(down_input_scale)
+        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+            (down_input_fp8, down_input_scale),
+            self.w2_weight_fp8,
+            down_output,
+            m_indices,
+        )
+
+        ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
+
+        return gather_out
+
     def forward_deepgemm_masked(
         self,
         hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
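forward_deepgemm_contiguous packs every routed (token, expert) pair into one contiguous buffer ordered by expert (ep_scatter), runs one grouped GEMM per projection over those per-expert row blocks, and finally folds each token's top-k expert outputs back together with its routing weights (ep_gather). A toy CPU sketch of that data flow, with plain float32 matmuls standing in for the FP8 DeepGEMM kernels and the fused kernels replaced by ordinary PyTorch ops:

```python
# CPU sketch of the contiguous path: scatter tokens into per-expert
# contiguous rows, run one GEMM per expert over its row slice (standing
# in for m_grouped_gemm_fp8_fp8_bf16_nt_contiguous), then gather back
# with the top-k routing weights (as ep_gather does). Toy shapes only.
import torch

num_experts, K, N, T, topk = 4, 8, 16, 5, 2
x = torch.randn(T, K)
topk_idx = torch.randint(num_experts, (T, topk))
topk_weights = torch.rand(T, topk)
w13 = torch.randn(num_experts, N, K)  # per-expert weight, "nt" layout

# Scatter: pack rows grouped by expert and record each row's expert id.
order = topk_idx.flatten().argsort(stable=True)
m_indices = topk_idx.flatten()[order]            # expert id per packed row
packed = x.repeat_interleave(topk, dim=0)[order]

# Grouped GEMM: one matmul per expert over its contiguous slice.
out = torch.empty(packed.shape[0], N)
for e in range(num_experts):
    rows = m_indices == e
    out[rows] = packed[rows] @ w13[e].T

# Gather: weighted sum of each token's top-k expert outputs.
gathered = torch.zeros(T, N)
flat_w = topk_weights.flatten()[order]
token_of_row = torch.arange(T).repeat_interleave(topk)[order]
gathered.index_add_(0, token_of_row, out * flat_w[:, None])
print(gathered.shape)  # torch.Size([5, 16])
```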
--- a/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -1,14 +1,19 @@
+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.utils import DeepEPMode
 
 try:
     from deep_ep import Buffer
 
+    from sglang.srt.layers.quantization.fp8_kernel import (
+        sglang_per_token_group_quant_fp8,
+    )
+
     use_deepep = True
 except ImportError:
     use_deepep = False
 
 from enum import IntEnum, auto
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -78,7 +83,6 @@ class DeepEPBuffer:
                 ),
                 num_rdma_bytes,
             )
-
         cls._buffer = Buffer(
             group,
             num_nvl_bytes,
@@ -181,44 +185,74 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_weights: torch.Tensor,
     ):
         topk_idx = topk_idx.to(torch.int64)
+        if _ENABLE_JIT_DEEPGEMM:
+            # TODO hard code 128 block quant,use fp8 communication
+            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
 
     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-        (
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            event,
-        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
-        event.current_stream_wait() if self.async_finish else ()
-        if hidden_states.shape[0] > 0:
-            reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
-                hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
+        if _ENABLE_JIT_DEEPGEMM:
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                num_recv_tokens_per_expert_list,
+                event,
+            ) = self._dispatch_core(
+                hidden_states, topk_idx, topk_weights, previous_event
             )
-        else:
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
+            event.current_stream_wait() if self.async_finish else ()
+            return (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                None,
+                num_recv_tokens_per_expert_list,
+                None,
+                None,
+                None,
             )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+        else:
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                num_recv_tokens_per_expert_list,
+                event,
+            ) = self._dispatch_core(
+                hidden_states, topk_idx, topk_weights, previous_event
             )
+            event.current_stream_wait() if self.async_finish else ()
+            if hidden_states.shape[0] > 0:
+                reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
+                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
+                )
+            else:
+                reorder_topk_ids = torch.empty(
+                    (0,), device=hidden_states.device, dtype=torch.int64
+                )
+                seg_indptr = torch.zeros(
+                    (self.num_experts + 1,),
+                    device=hidden_states.device,
+                    dtype=torch.int64,
+                )
 
-        masked_m = None
-        expected_m = None
-        return (
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            reorder_topk_ids,
-            seg_indptr,
-            masked_m,
-            expected_m,
-        )
+        masked_m = expected_m = None
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            reorder_topk_ids,
+            None,
+            seg_indptr,
+            masked_m,
+            expected_m,
+        )
 
     def _dispatch_core(
         self,
-        x: torch.Tensor,
+        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         previous_event,
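When _ENABLE_JIT_DEEPGEMM is set, dispatch_a quantizes activations to FP8 in groups of 128 channels before communication, so the payload becomes a (values, scales) pair; that is why _dispatch_core's `x` annotation widens to Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]. A plain-PyTorch reference sketch of per-token-group FP8 quantization (a stand-in for the sglang_per_token_group_quant_fp8 kernel, not its implementation):

```python
# Reference sketch of per-token-group FP8 quantization with group size
# 128: each contiguous group of 128 channels in a row shares one
# float32 scale. Illustrative stand-in only, not the sglang kernel.
import torch

def per_token_group_quant_fp8(x: torch.Tensor, group_size: int = 128):
    M, K = x.shape
    assert K % group_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    g = x.view(M, K // group_size, group_size)
    # One scale per (row, group), sized so the group max maps to fp8_max.
    scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / fp8_max
    q = (g / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view(M, K), scale.squeeze(-1)  # (M, K) fp8, (M, K//128) fp32

x = torch.randn(4, 256)
q, s = per_token_group_quant_fp8(x)
print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
```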
@@ -246,7 +280,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            _,
+            num_recv_tokens_per_expert_list,
             self.handle,
             event,
         ) = buffer.dispatch(
@@ -260,12 +294,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
+            expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
         )
 
         return (
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
+            num_recv_tokens_per_expert_list,
             event,
         )
 
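The new expert_alignment=128 argument asks DeepEP to pad each expert's received token count up to a multiple of 128, so every expert's contiguous row block matches the grouped GEMM's block granularity. The padding arithmetic this implies (illustrative only):

```python
# Sketch of what expert_alignment=128 implies: each expert's token
# count is rounded up to a multiple of the alignment.
def aligned(count: int, alignment: int = 128) -> int:
    return (count + alignment - 1) // alignment * alignment

assert [aligned(c) for c in (0, 1, 128, 129)] == [0, 128, 128, 256]
```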
@@ -314,29 +350,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if hidden_states.shape[0] > 0:
-            num_tokens = self.src2dst.shape[0] // self.router_topk
-            output = torch.empty(
-                (num_tokens, hidden_states.shape[1]),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-            deepep_post_reorder_triton_kernel[(num_tokens,)](
-                hidden_states,
-                output,
-                self.src2dst,
-                topk_idx,
-                topk_weights,
-                self.router_topk,
-                hidden_states.shape[1],
-                BLOCK_SIZE=512,
-            )
+        if _ENABLE_JIT_DEEPGEMM:
+            output = hidden_states
         else:
-            output = torch.zeros(
-                (0, hidden_states.shape[1]),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
+            if hidden_states.shape[0] > 0:
+                num_tokens = self.src2dst.shape[0] // self.router_topk
+                output = torch.empty(
+                    (num_tokens, hidden_states.shape[1]),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
+                deepep_post_reorder_triton_kernel[(num_tokens,)](
+                    hidden_states,
+                    output,
+                    self.src2dst,
+                    topk_idx,
+                    topk_weights,
+                    self.router_topk,
+                    hidden_states.shape[1],
+                    BLOCK_SIZE=512,
+                )
+            else:
+                output = torch.zeros(
+                    (0, hidden_states.shape[1]),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
         previous_event = Buffer.capture() if self.async_finish else None
         return output, previous_event
 
@@ -360,6 +399,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
 
     def _get_buffer(self):
         DeepEPBuffer.set_dispatch_mode_as_normal()
+
        return DeepEPBuffer.get_deepep_buffer(
             self.group,
             self.hidden_size,
@@ -426,6 +466,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             topk_idx,
             topk_weights,
             reorder_topk_ids,
+            None,
             seg_indptr,
             masked_m,
             expected_m,
@@ -570,7 +611,8 @@ class DeepEPDispatcher:
 
     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
-        return self.dispatch_b()
+        ret = self.dispatch_b()
+        return ret
 
     def dispatch_a(
         self,
@@ -593,7 +635,8 @@ class DeepEPDispatcher:
 
     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
-        return self.combine_b()
+        ret = self.combine_b()
+        return ret
 
     def combine_a(
         self,
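DeepEPDispatcher keeps its two-phase split: dispatch_a/combine_a launch communication, dispatch_b/combine_b wait and return, and the public dispatch/combine wrappers now bind the phase-B result to a local before returning it. A minimal sketch of the two-phase pattern (hypothetical names, not the sglang surface):

```python
# Two-phase pattern used by DeepEPDispatcher: phase A launches async
# communication, phase B waits and returns, so a caller can insert
# compute between the phases. Illustrative stand-in only.
class TwoPhaseOp:
    def op_a(self, payload):
        self._pending = payload   # real code launches an async send here

    def op_b(self):
        return self._pending      # real code waits on the event here

    def op(self, payload):
        self.op_a(payload)
        ret = self.op_b()
        return ret

print(TwoPhaseOp().op("tokens"))
```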
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
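This JSON is one of the two fused MoE Triton tuning tables added in this release (see the E=264/B200 and E=272/H100 config entries in the file list above; the rendering does not say which of the two this hunk belongs to). Each top-level key is a token count M mapped to a Triton launch config. As I read the fused_moe_triton convention, the runtime picks the config whose key is nearest to the actual M; the helper below is illustrative, not the real lookup API:

```python
# Hedged sketch of how such a tuning table is consumed: keys are token
# counts M, and the nearest key is chosen for the actual batch size.
import json

table = json.loads('{"1": {"BLOCK_SIZE_M": 16}, "1024": {"BLOCK_SIZE_M": 32}}')

def pick_config(m: int, configs: dict) -> dict:
    best_key = min(configs, key=lambda k: abs(int(k) - m))
    return configs[best_key]

print(pick_config(700, table))  # {'BLOCK_SIZE_M': 32}
```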