sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
```diff
--- a/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -4,9 +4,8 @@ from typing import List, Optional
 import torch
 import triton
 
-from sglang.math_utils import ceil_div
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import dispose_tensor, is_cuda
+from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
 
 logger = logging.getLogger(__name__)
 
```
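The hunk above only relocates `ceil_div` from `sglang.math_utils` into `sglang.srt.utils` (the standalone `sglang/math_utils.py` shrinks by 8 lines in this release). For reference, a minimal sketch of the helper, assuming the conventional ceiling-division definition:

```python
def ceil_div(a: int, b: int) -> int:
    # Smallest integer n such that n * b >= a, for positive b.
    return (a + b - 1) // b

assert ceil_div(56, 4) == 14  # exact multiple
assert ceil_div(57, 4) == 15  # rounds up; used below for ue8m0 scale packing
```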
```diff
--- a/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -147,6 +146,7 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
 
 def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+
     seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
 
@@ -159,9 +159,66 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     compute_src2dst_triton_kernel[grid](
         reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
     )
+
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
+def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
+
+    seg_indptr = torch.zeros(
+        local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
+    )
+    src2dst = torch.empty(
+        local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
+    )
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
+    compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
+    )
+
     return reorder_topk_ids, src2dst, seg_indptr
 
 
+@triton.jit
+def pre_reorder_triton_kernel_for_cutlass_moe(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id != num_experts:
+            if a1_scales_ptr is not None:
+                scale = 1.0 / tl.load(a1_scales_ptr)
+            else:
+                scale = 1.0
+
+            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + tl.arange(0, BLOCK_SIZE)
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
+                out_data = (in_data * scale).to(OutDtype)
+                tl.store(dst_ptr + offset, out_data, mask=mask)
+
+
 @triton.jit
 def pre_reorder_triton_kernel(
     input_ptr,
```
```diff
--- a/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -814,14 +871,17 @@ def _fwd_kernel_ep_scatter_2(
     offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
     mask = offset_in < HIDDEN_SIZE
 
-    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
-    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = index_in_s < SCALE_HIDDEN_SIZE
 
     for token_id_int32 in range(start_token_id, total_token_num, grid_num):
         token_id = token_id_int32.to(tl.int64)
         to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
         to_copy_s = tl.load(
-            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+            recv_x_scale
+            + token_id * recv_x_scale_stride0
+            + index_in_s * recv_x_scale_stride1,
+            mask=mask_s,
         )
 
         for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
@@ -842,7 +902,11 @@ def _fwd_kernel_ep_scatter_2(
                     output_tensor_scale + dest_token_index * output_tensor_scale_stride0
                 )
                 tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
-                tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+                tl.store(
+                    output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
+                    to_copy_s,
+                    mask=mask_s,
+                )
 
 
 # copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@@ -857,6 +921,7 @@ def ep_scatter(
     output_tensor_scale: torch.Tensor,
     m_indices: torch.Tensor,
     output_index: torch.Tensor,
+    scale_ue8m0: bool = False,
 ):
     BLOCK_E = 128  # token num of per expert is aligned to 128
     BLOCK_D = 128  # block size of quantization
@@ -866,7 +931,15 @@ def ep_scatter(
     # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
     grid = num_experts
 
+    scale_hidden_size = hidden_size // BLOCK_D
+    if scale_ue8m0:
+        # ue8m0 scales are packed here (4 scales per int32),
+        # hence the effective size of this dimension is divided by 4.
+        scale_hidden_size = ceil_div(scale_hidden_size, 4)
+
     assert m_indices.shape[0] % BLOCK_E == 0
+    assert recv_x_scale.dtype == output_tensor_scale.dtype
+    assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
 
     _fwd_kernel_ep_scatter_1[(grid,)](
         num_recv_tokens_per_expert,
@@ -905,8 +978,8 @@ def ep_scatter(
         num_warps=num_warps,
         HIDDEN_SIZE=hidden_size,
         HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
-        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
-        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+        SCALE_HIDDEN_SIZE=scale_hidden_size,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
     )
     return
 
```
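The new `scale_ue8m0` path shrinks the scale dimension because four 8-bit ue8m0 exponents are packed into each int32, which is where `ceil_div(scale_hidden_size, 4)` comes from. A quick sanity check of that arithmetic, assuming an example hidden size of 7168 and the 128-wide quantization blocks declared above:

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

hidden_size = 7168                             # assumed example value
BLOCK_D = 128                                  # quantization block width from ep_scatter
scale_hidden_size = hidden_size // BLOCK_D     # 56 per-token fp32 scales
packed_words = ceil_div(scale_hidden_size, 4)  # 14 int32 words once ue8m0-packed
print(scale_hidden_size, packed_words)         # 56 14
```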
```diff
--- a/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/sglang/srt/layers/moe/ep_moe/layer.py
@@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Tuple
 
 import einops
 import torch
-from sgl_kernel import silu_and_mul
 from torch.nn import Module
 
 from sglang.srt.custom_op import CustomOp
@@ -11,6 +10,9 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -19,6 +21,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
+    pre_reorder_triton_kernel_for_cutlass_moe,
+    run_cutlass_moe_ep_preproess,
     run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
@@ -40,22 +44,27 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     sglang_per_token_quant_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
     DeepEPMode,
+    ceil_div,
     dispose_tensor,
     get_bool_env_var,
     is_hip,
+    is_npu,
     set_weight_attrs,
 )
 
 _is_hip = is_hip()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
+if not _is_npu:
+    from sgl_kernel import silu_and_mul
+
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
 
@@ -186,7 +195,7 @@ class EPMoE(torch.nn.Module):
             num_fused_shared_experts == 0
         ), "num_fused_shared_experts is not supported in EP"
         self.num_fused_shared_experts = num_fused_shared_experts
-        self.num_experts_per_partition = self.num_experts // self.tp_size
+        self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
        self.start_expert_id = self.tp_rank * self.num_experts_per_partition
        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
 
@@ -210,6 +219,18 @@
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
+            self.use_w4afp8 = False
+        elif isinstance(quant_config, W4AFp8Config):
+            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
+                quant_config
+            )
+            self.use_w4afp8 = True
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.fp8_dtype = torch.float8_e4m3fn
+            self.w13_weight_scale = None
+            self.w2_weight_scale = None
+            self.activation_scheme = quant_config.moe_activation_scheme
         else:
             self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                 quant_config
@@ -223,6 +244,7 @@
             )
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
+            self.use_w4afp8 = False
 
         self.quant_method.create_weights(
             layer=self,
@@ -248,6 +270,49 @@
             self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
         )
 
+    # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
+    # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
+    def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
+        """
+        Calculates how many experts should be assigned to each rank for EP and
+        creates a mapping from global to local expert index. Experts are
+        distributed evenly across ranks. Any remaining are assigned to the
+        last rank.
+
+        Returns:
+            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
+                - local_num_experts (int): The number of experts assigned
+                    to the current rank.
+                - expert_map (Optional[torch.Tensor]): A tensor of shape
+                    (global_num_experts,) mapping from global to local index.
+                    Contains global_num_experts for experts not assigned to the current rank.
+                    Returns None if ep_size is 1.
+        """
+        ep_size = self.tp_size
+        ep_rank = self.tp_rank
+        global_num_experts = self.num_experts
+
+        assert ep_size > 0
+        if ep_size == 1:
+            return (global_num_experts, None)
+
+        local_num_experts = global_num_experts // ep_size
+
+        expert_map = torch.full(
+            (global_num_experts,), self.num_experts, dtype=torch.int32
+        )
+        if ep_rank < (ep_size - 1):
+            expert_map[
+                ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
+            ] = torch.arange(0, local_num_experts, dtype=torch.int32)
+        else:
+            local_num_experts = global_num_experts - ep_rank * local_num_experts
+
+            expert_map[-local_num_experts:] = torch.arange(
+                0, local_num_experts, dtype=torch.int32
+            )
+        return (local_num_experts, expert_map)
+
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
             return self.forward_deepgemm(hidden_states, router_logits)
```
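To make the mapping concrete, here is a standalone re-implementation of `determine_expert_map` (a free function instead of a method, otherwise the same logic) showing the uneven split for 10 experts across 4 ranks:

```python
from typing import Optional, Tuple

import torch

def determine_expert_map(
    ep_size: int, ep_rank: int, global_num_experts: int
) -> Tuple[int, Optional[torch.Tensor]]:
    assert ep_size > 0
    if ep_size == 1:
        return (global_num_experts, None)

    local_num_experts = global_num_experts // ep_size
    # Out-of-rank experts map to the sentinel global_num_experts, not -1.
    expert_map = torch.full(
        (global_num_experts,), global_num_experts, dtype=torch.int32
    )
    if ep_rank < (ep_size - 1):
        expert_map[
            ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
        ] = torch.arange(0, local_num_experts, dtype=torch.int32)
    else:
        # The last rank absorbs the remainder.
        local_num_experts = global_num_experts - ep_rank * local_num_experts
        expert_map[-local_num_experts:] = torch.arange(
            0, local_num_experts, dtype=torch.int32
        )
    return (local_num_experts, expert_map)

n, m = determine_expert_map(ep_size=4, ep_rank=3, global_num_experts=10)
print(n)  # 4: the last rank gets 2 plus the remainder of 2
print(m)  # tensor([10, 10, 10, 10, 10, 10,  0,  1,  2,  3], dtype=torch.int32)
```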
```diff
--- a/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/sglang/srt/layers/moe/ep_moe/layer.py
@@ -435,6 +500,51 @@
             ),
         )
 
+        if self.use_w4afp8:
+            local_topk_ids = topk_ids
+            if self.expert_map is not None:
+                "Translate info from expert_map to topk_ids"
+                local_topk_ids = torch.where(
+                    self.expert_map[topk_ids] != self.num_experts,
+                    self.expert_map[topk_ids],
+                    self.num_experts,
+                )
+
+            output = cutlass_w4a8_moe(
+                self.start_expert_id,
+                self.end_expert_id,
+                self.num_experts,
+                hidden_states,
+                self.w13_weight,
+                self.w2_weight,
+                self.w13_weight_scale_inv,
+                self.w2_weight_scale_inv,
+                topk_weights,
+                topk_ids,
+                local_topk_ids,
+                self.quant_method.a_strides1,
+                self.quant_method.b_strides1,
+                self.quant_method.c_strides1,
+                self.quant_method.a_strides2,
+                self.quant_method.b_strides2,
+                self.quant_method.c_strides2,
+                self.quant_method.s_strides13,
+                self.quant_method.s_strides2,
+                self.quant_method.expert_offsets,
+                self.quant_method.problem_sizes1,
+                self.quant_method.problem_sizes2,
+                self.w13_input_scale,
+                self.w2_input_scale,
+            )
+            return output
+
+        if self.grouped_gemm_runner is None:
+            self.grouped_gemm_runner = GroupedGemmRunner(
+                hidden_states.device,
+                use_flashinfer=False,  # TODO: use flashinfer
+                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
+            )
+
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
             topk_ids, self.num_experts
         )
@@ -444,7 +554,7 @@
             device=hidden_states.device,
             dtype=(
                 self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
                 else hidden_states.dtype
             ),
         )
@@ -651,6 +761,23 @@
             ]
         ]
 
+    @classmethod
+    def make_expert_input_scale_params_mapping(
+        cls,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+        # (param_name, weight_name, expert_id, shard_id)
+        return [
+            (
+                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
+                f"experts.{expert_id}.{shard_id}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id in ["w1", "w2", "w3"]
+        ]
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
```
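The new classmethod enumerates `(param_name, weight_name, expert_id, shard_id)` tuples that the weight loader matches against checkpoint tensor names. Reproducing the comprehension standalone shows the order it generates:

```python
def make_expert_input_scale_params_mapping(num_experts: int):
    # Same comprehension as the classmethod above, outside the class.
    return [
        (
            "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
            f"experts.{expert_id}.{shard_id}.",
            expert_id,
            shard_id,
        )
        for expert_id in range(num_experts)
        for shard_id in ["w1", "w2", "w3"]
    ]

for entry in make_expert_input_scale_params_mapping(num_experts=2):
    print(entry)
# ('experts.w13_', 'experts.0.w1.', 0, 'w1')
# ('experts.w2_', 'experts.0.w2.', 0, 'w2')
# ('experts.w13_', 'experts.0.w3.', 0, 'w3')
# ('experts.w13_', 'experts.1.w1.', 1, 'w1')  ... and so on for expert 1
```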
```diff
--- a/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/sglang/srt/layers/moe/ep_moe/layer.py
@@ -722,6 +849,15 @@
 
         # Input scales can be loaded directly and should be equal.
         if "input_scale" in weight_name:
+            if self.use_w4afp8:
+                if shard_id == "w1":
+                    param_data[expert_id][0] = loaded_weight
+                elif shard_id == "w3":
+                    param_data[expert_id][1] = loaded_weight
+                else:
+                    param_data[expert_id] = loaded_weight
+                return
+
             if (
                 (shard_id == "w1" or shard_id == "w3")
                 and param_data[expert_id] != 1
@@ -747,6 +883,13 @@
                 ] = loaded_weight
             else:  # w2
                 param_data[expert_id] = loaded_weight
+        elif self.use_w4afp8:
+            if shard_id == "w1":
+                param_data[expert_id][: self.intermediate_size, :] = loaded_weight
+            elif shard_id == "w3":
+                param_data[expert_id][self.intermediate_size :, :] = loaded_weight
+            else:
+                param_data[expert_id] = loaded_weight
         # If we are in merged column case (gate_up_proj)
         else:
             if shard_id in ("w1", "w3"):
```
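In the w4afp8 branch, gate (`w1`) and up (`w3`) projections load into one merged `w13` buffer stacked along the output dimension, which is what the `[: self.intermediate_size]` and `[self.intermediate_size :]` slices implement. A toy sketch of that layout (sizes are illustrative, not taken from any real model):

```python
import torch

intermediate_size, hidden_size = 4, 3
# Merged per-expert gate/up buffer: first half holds w1, second half w3.
w13 = torch.empty(2 * intermediate_size, hidden_size)

w1_loaded = torch.zeros(intermediate_size, hidden_size)  # stand-in gate_proj weight
w3_loaded = torch.ones(intermediate_size, hidden_size)   # stand-in up_proj weight
w13[:intermediate_size, :] = w1_loaded
w13[intermediate_size:, :] = w3_loaded
assert (w13[:intermediate_size] == 0).all() and (w13[intermediate_size:] == 1).all()
```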
```diff
--- a/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/sglang/srt/layers/moe/ep_moe/layer.py
@@ -1173,12 +1316,14 @@ class DeepEPMoE(EPMoE):
         masked_m: torch.Tensor,
         expected_m: int,
         num_recv_tokens_per_expert: List[int],
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
@@ -1370,10 +1515,19 @@ class DeepEPMoE(EPMoE):
                 device=hidden_states_fp8.device,
                 dtype=hidden_states_fp8.dtype,
             ),
-            torch.empty(
-                (all_tokens, K // 128),
-                device=hidden_states_fp8.device,
-                dtype=torch.float32,
+            (
+                # TODO check whether need `zeros`
+                torch.zeros(
+                    (ceil_div(K // 128, 4), all_tokens),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.int,
+                ).transpose(0, 1)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else torch.empty(
+                    (all_tokens, K // 128),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.float32,
+                )
             ),
         ]
         m_indices = torch.empty(
@@ -1399,6 +1553,7 @@ class DeepEPMoE(EPMoE):
             input_tensor[1],
             m_indices,
             output_index,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         dispose_tensor(hidden_states_fp8)
 
@@ -1407,7 +1562,8 @@ class DeepEPMoE(EPMoE):
             device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
-        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
@@ -1428,10 +1584,15 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input, scale_block_size
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        del down_input
-        down_input_scale = tma_align_input_scale(down_input_scale)
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
```
```diff
--- a/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -1,10 +1,8 @@
 import logging
 from dataclasses import dataclass
 
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     DeepEPMode,
@@ -36,7 +34,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 
@@ -246,7 +244,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx = topk_idx.to(torch.int64)
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
+            hidden_states = sglang_per_token_group_quant_fp8(
+                hidden_states,
+                128,
+                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
 
@@ -682,21 +686,21 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-        inner_state = self._get_impl(forward_mode).dispatch_a(
+        inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._dispatch_intermediate_state = forward_mode, inner_state
+        self._dispatch_intermediate_state = forward_batch, inner_state
 
     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-        forward_mode, inner_state = self._dispatch_intermediate_state
+        forward_batch, inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(forward_mode).dispatch_b(*inner_state)
+        return self._get_impl(forward_batch).dispatch_b(*inner_state)
 
     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
@@ -708,24 +712,26 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(forward_mode).combine_a(
+        inner_state = self._get_impl(forward_batch).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._combine_intermediate_state = forward_mode, inner_state
+        self._combine_intermediate_state = forward_batch, inner_state
 
     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-        forward_mode, inner_state = self._combine_intermediate_state
+        forward_batch, inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(forward_mode).combine_b(*inner_state)
+        return self._get_impl(forward_batch).combine_b(*inner_state)
 
-    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+    def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
         elif resolved_deepep_mode == DeepEPMode.low_latency:
```
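All the dispatcher changes thread the `ForwardBatch` through the split-phase API so that the `_b` half resumes with the same normal vs. low-latency implementation the `_a` half selected. A stripped-down sketch of that stash-and-resume pattern (simplified names, not the real classes):

```python
class _Impl:
    def __init__(self, name: str):
        self.name = name

    def dispatch_a(self, hidden_states):
        return (hidden_states,)  # "inner state" handed back to phase B

    def dispatch_b(self, hidden_states):
        return f"{self.name} handled {hidden_states}"


class TwoPhaseDispatcher:
    # Phase A picks an impl from the batch and stashes (batch, state);
    # phase B re-resolves from the stashed batch, so both phases agree.
    def __init__(self):
        self._normal = _Impl("normal")
        self._low_latency = _Impl("low_latency")
        self._stash = None

    def _get_impl(self, forward_batch):
        # Stand-in for deepep_mode.resolve(forward_batch.is_extend_in_batch).
        if forward_batch["is_extend_in_batch"]:
            return self._normal
        return self._low_latency

    def dispatch_a(self, hidden_states, forward_batch):
        inner_state = self._get_impl(forward_batch).dispatch_a(hidden_states)
        self._stash = (forward_batch, inner_state)

    def dispatch_b(self):
        forward_batch, inner_state = self._stash
        self._stash = None
        return self._get_impl(forward_batch).dispatch_b(*inner_state)


d = TwoPhaseDispatcher()
d.dispatch_a("x", {"is_extend_in_batch": True})
print(d.dispatch_b())  # normal handled x
```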
```diff
--- a/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -12,7 +12,6 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.math_utils import ceil_div
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -25,6 +24,7 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    ceil_div,
     cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
@@ -32,7 +32,6 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
-    log_info_on_rank0,
     next_power_of_2,
 )
 
@@ -1738,6 +1737,7 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
@@ -1823,6 +1823,7 @@ def fused_moe(
         topk_ids,
         inplace=inplace,
         activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
```