sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +36 -2
- sglang/srt/conversation.py +56 -3
- sglang/srt/disaggregation/ascend/__init__.py +6 -0
- sglang/srt/disaggregation/ascend/conn.py +44 -0
- sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
- sglang/srt/disaggregation/mooncake/conn.py +50 -18
- sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
- sglang/srt/disaggregation/utils.py +25 -3
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +1 -0
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +11 -0
- sglang/srt/entrypoints/openai/serving_chat.py +7 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/kimik2_detector.py +220 -0
- sglang/srt/hf_transformers_utils.py +18 -0
- sglang/srt/jinja_template_utils.py +8 -0
- sglang/srt/layers/communicator.py +20 -5
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/linear.py +12 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
- sglang/srt/layers/moe/ep_moe/layer.py +141 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/topk.py +8 -2
- sglang/srt/layers/parameter.py +19 -3
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -2
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +738 -14
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +35 -3
- sglang/srt/managers/mm_utils.py +59 -96
- sglang/srt/managers/schedule_batch.py +17 -6
- sglang/srt/managers/scheduler.py +38 -6
- sglang/srt/managers/tokenizer_manager.py +16 -0
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +176 -101
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -1
- sglang/srt/model_loader/loader.py +23 -12
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +78 -19
- sglang/srt/models/deepseek_vl2.py +1 -1
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +6 -3
- sglang/srt/models/internvl.py +8 -2
- sglang/srt/models/kimi_vl.py +8 -2
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llava.py +3 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpmo.py +1 -2
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral_quant.py +4 -0
- sglang/srt/models/mllama4.py +372 -82
- sglang/srt/models/phi4mm.py +8 -2
- sglang/srt/models/phimoe.py +553 -0
- sglang/srt/models/qwen2.py +2 -0
- sglang/srt/models/qwen2_5_vl.py +10 -7
- sglang/srt/models/qwen2_vl.py +12 -1
- sglang/srt/models/vila.py +8 -2
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/base_processor.py +197 -137
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
- sglang/srt/multimodal/processors/gemma3.py +4 -2
- sglang/srt/multimodal/processors/gemma3n.py +1 -1
- sglang/srt/multimodal/processors/internvl.py +1 -1
- sglang/srt/multimodal/processors/janus_pro.py +1 -1
- sglang/srt/multimodal/processors/kimi_vl.py +1 -1
- sglang/srt/multimodal/processors/minicpm.py +4 -3
- sglang/srt/multimodal/processors/mllama4.py +63 -61
- sglang/srt/multimodal/processors/phi4mm.py +1 -1
- sglang/srt/multimodal/processors/pixtral.py +1 -1
- sglang/srt/multimodal/processors/qwen_vl.py +203 -80
- sglang/srt/multimodal/processors/vila.py +1 -1
- sglang/srt/server_args.py +26 -4
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +191 -48
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/cutlass_w4a8_moe.py
@@ -0,0 +1,215 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Cutlass W4A8 MoE kernel."""
+from typing import Optional
+
+import torch
+from sgl_kernel import (
+    cutlass_w4a8_moe_mm,
+    get_cutlass_w4a8_moe_mm_data,
+    sgl_per_tensor_quant_fp8,
+    silu_and_mul,
+)
+
+from sglang.srt.layers.moe.ep_moe.kernels import (
+    post_reorder_triton_kernel,
+    pre_reorder_triton_kernel_for_cutlass_moe,
+    run_cutlass_moe_ep_preproess,
+)
+
+
+def cutlass_w4a8_moe(
+    start_expert_id: int,
+    end_expert_id: int,
+    total_num_experts: int,
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    local_topk_ids: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+) -> torch.Tensor:
+    """
+    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [M, K]
+    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
+        Shape: [num_experts, N * 2, K // 2]
+        (the weights are passed transposed and int4-packed)
+    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
+        Shape: [num_experts, K, N // 2]
+        (the weights are passed transposed and int4-packed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts, K // 512, N * 8]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts, N // 512, K * 4]
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
+    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
+    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
+    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [1, K]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [1, N]
+    - apply_router_weight_on_input (bool): When true, the topk weights are
+        applied directly on the inputs. This is only applicable when topk is 1.
+
+    Returns:
+    - torch.Tensor: The fp8 output tensor after applying the MoE layer.
+    """
+    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
+    assert w1_q.dtype == torch.int8
+    assert w2_q.dtype == torch.int8
+    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
+    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+    assert (
+        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
+        and w1_scale.shape[2] == w1_q.shape[1] * 4
+    ), "W1 scale shape mismatch"
+    assert (
+        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
+        and w2_scale.shape[2] == w2_q.shape[1] * 4
+    ), "W2 scale shape mismatch"
+
+    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
+    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
+    num_experts = w1_q.size(0)
+    m = a.size(0)
+    k = w1_q.size(2) * 2  # w1_q is transposed and packed
+    n = w2_q.size(2) * 2  # w2_q is transposed and packed
+    topk = topk_ids_.size(1)
+
+    if apply_router_weight_on_input:
+        assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
+
+    device = a.device
+
+    _, src2dst, _ = run_cutlass_moe_ep_preproess(
+        local_topk_ids,
+        num_experts,
+    )
+
+    gateup_input = torch.empty(
+        (m * topk, k),
+        device=device,
+        dtype=torch.float8_e4m3fn,
+    )
+
+    pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
+        a,
+        gateup_input,
+        src2dst,
+        local_topk_ids,
+        a1_scale,
+        total_num_experts,
+        topk,
+        k,
+        BLOCK_SIZE=512,
+    )
+
+    # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
+    # they are kept to allow for a quick switch of the permutation logic
+    # from the current triton kernel implementation to the cutlass-based one if needed.
+    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    get_cutlass_w4a8_moe_mm_data(
+        local_topk_ids,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a_map,
+        c_map,
+        num_experts,
+        n,
+        k,
+    )
+
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
+
+    cutlass_w4a8_moe_mm(
+        c1,
+        gateup_input,
+        w1_q,
+        a1_scale.float(),
+        w1_scale,
+        expert_offsets[:-1],
+        problem_sizes1,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        s_strides13,
+        128,
+        topk,
+    )
+
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
+    silu_and_mul(c1, intermediate)
+
+    intermediate_q = torch.empty(
+        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
+    )
+    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)
+
+    cutlass_w4a8_moe_mm(
+        c2,
+        intermediate_q,
+        w2_q,
+        a2_scale.float(),
+        w2_scale,
+        expert_offsets[:-1],
+        problem_sizes2,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides2,
+        128,
+        topk,
+    )
+
+    output = torch.empty_like(a)
+    post_reorder_triton_kernel[(m,)](
+        c2,
+        output,
+        src2dst,
+        topk_ids_,
+        topk_weights,
+        start_expert_id,
+        end_expert_id,
+        topk,
+        k,
+        0,
+        BLOCK_SIZE=512,
+    )
+    return output
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -6,6 +6,7 @@ import triton
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
+from sglang.utils import is_in_ci
 
 logger = logging.getLogger(__name__)
 
@@ -146,6 +147,7 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
 
 def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+
     seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
 
@@ -158,9 +160,66 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     compute_src2dst_triton_kernel[grid](
         reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
     )
+
     return reorder_topk_ids, src2dst, seg_indptr
 
 
+def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
+
+    seg_indptr = torch.zeros(
+        local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
+    )
+    src2dst = torch.empty(
+        local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
+    )
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
+    compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
+    )
+
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
+@triton.jit
+def pre_reorder_triton_kernel_for_cutlass_moe(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id != num_experts:
+            if a1_scales_ptr is not None:
+                scale = 1.0 / tl.load(a1_scales_ptr)
+            else:
+                scale = 1.0
+
+            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + tl.arange(0, BLOCK_SIZE)
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
+                out_data = (in_data * scale).to(OutDtype)
+                tl.store(dst_ptr + offset, out_data, mask=mask)
+
+
 @triton.jit
 def pre_reorder_triton_kernel(
     input_ptr,
@@ -1000,7 +1059,7 @@ def ep_gather(
     input_index: torch.Tensor,
     output_tensor: torch.Tensor,
 ):
-    BLOCK_D = 1024  # block size of quantization
+    BLOCK_D = 1024 if not is_in_ci() else 128  # block size of quantization
     num_warps = 2
     num_tokens = output_tensor.shape[0]
     hidden_size = input_tensor.shape[1]
sglang/srt/layers/moe/ep_moe/layer.py
@@ -20,6 +20,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
+    pre_reorder_triton_kernel_for_cutlass_moe,
+    run_cutlass_moe_ep_preproess,
     run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
@@ -41,6 +43,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     sglang_per_token_quant_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
@@ -61,6 +64,8 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if not _is_npu:
     from sgl_kernel import silu_and_mul
 
+    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
 
@@ -191,7 +196,7 @@ class EPMoE(torch.nn.Module):
             num_fused_shared_experts == 0
         ), "num_fused_shared_experts is not supported in EP"
         self.num_fused_shared_experts = num_fused_shared_experts
-        self.num_experts_per_partition
+        self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
 
@@ -215,6 +220,18 @@ class EPMoE(torch.nn.Module):
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
+            self.use_w4afp8 = False
+        elif isinstance(quant_config, W4AFp8Config):
+            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
+                quant_config
+            )
+            self.use_w4afp8 = True
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.fp8_dtype = torch.float8_e4m3fn
+            self.w13_weight_scale = None
+            self.w2_weight_scale = None
+            self.activation_scheme = quant_config.moe_activation_scheme
         else:
             self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                 quant_config
@@ -228,6 +245,7 @@ class EPMoE(torch.nn.Module):
             )
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
+            self.use_w4afp8 = False
 
         self.quant_method.create_weights(
             layer=self,
@@ -253,6 +271,49 @@ class EPMoE(torch.nn.Module):
             self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
         )
 
+    # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
+    # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
+    def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
+        """
+        Calculates how many experts should be assigned to each rank for EP and
+        creates a mapping from global to local expert index. Experts are
+        distributed evenly across ranks. Any remaining are assigned to the
+        last rank.
+
+        Returns:
+            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
+                - local_num_experts (int): The number of experts assigned
+                    to the current rank.
+                - expert_map (Optional[torch.Tensor]): A tensor of shape
+                    (global_num_experts,) mapping from global to local index.
+                    Contains global_num_experts for experts not assigned to the current rank.
+                    Returns None if ep_size is 1.
+        """
+        ep_size = self.tp_size
+        ep_rank = self.tp_rank
+        global_num_experts = self.num_experts
+
+        assert ep_size > 0
+        if ep_size == 1:
+            return (global_num_experts, None)
+
+        local_num_experts = global_num_experts // ep_size
+
+        expert_map = torch.full(
+            (global_num_experts,), self.num_experts, dtype=torch.int32
+        )
+        if ep_rank < (ep_size - 1):
+            expert_map[
+                ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
+            ] = torch.arange(0, local_num_experts, dtype=torch.int32)
+        else:
+            local_num_experts = global_num_experts - ep_rank * local_num_experts
+
+            expert_map[-local_num_experts:] = torch.arange(
+                0, local_num_experts, dtype=torch.int32
+            )
+        return (local_num_experts, expert_map)
+
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
             return self.forward_deepgemm(hidden_states, router_logits)
@@ -440,6 +501,51 @@ class EPMoE(torch.nn.Module):
             ),
         )
 
+        if self.use_w4afp8:
+            local_topk_ids = topk_ids
+            if self.expert_map is not None:
+                "Translate info from expert_map to topk_ids"
+                local_topk_ids = torch.where(
+                    self.expert_map[topk_ids] != self.num_experts,
+                    self.expert_map[topk_ids],
+                    self.num_experts,
+                )
+
+            output = cutlass_w4a8_moe(
+                self.start_expert_id,
+                self.end_expert_id,
+                self.num_experts,
+                hidden_states,
+                self.w13_weight,
+                self.w2_weight,
+                self.w13_weight_scale_inv,
+                self.w2_weight_scale_inv,
+                topk_weights,
+                topk_ids,
+                local_topk_ids,
+                self.quant_method.a_strides1,
+                self.quant_method.b_strides1,
+                self.quant_method.c_strides1,
+                self.quant_method.a_strides2,
+                self.quant_method.b_strides2,
+                self.quant_method.c_strides2,
+                self.quant_method.s_strides13,
+                self.quant_method.s_strides2,
+                self.quant_method.expert_offsets,
+                self.quant_method.problem_sizes1,
+                self.quant_method.problem_sizes2,
+                self.w13_input_scale,
+                self.w2_input_scale,
+            )
+            return output
+
+        if self.grouped_gemm_runner is None:
+            self.grouped_gemm_runner = GroupedGemmRunner(
+                hidden_states.device,
+                use_flashinfer=False,  # TODO: use flashinfer
+                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
+            )
+
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
             topk_ids, self.num_experts
         )
@@ -449,7 +555,7 @@ class EPMoE(torch.nn.Module):
             device=hidden_states.device,
             dtype=(
                 self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
                 else hidden_states.dtype
             ),
         )
@@ -656,6 +762,23 @@ class EPMoE(torch.nn.Module):
             ]
         ]
 
+    @classmethod
+    def make_expert_input_scale_params_mapping(
+        cls,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+        # (param_name, weight_name, expert_id, shard_id)
+        return [
+            (
+                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
+                f"experts.{expert_id}.{shard_id}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id in ["w1", "w2", "w3"]
+        ]
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
@@ -727,6 +850,15 @@ class EPMoE(torch.nn.Module):
 
         # Input scales can be loaded directly and should be equal.
         if "input_scale" in weight_name:
+            if self.use_w4afp8:
+                if shard_id == "w1":
+                    param_data[expert_id][0] = loaded_weight
+                elif shard_id == "w3":
+                    param_data[expert_id][1] = loaded_weight
+                else:
+                    param_data[expert_id] = loaded_weight
+                return
+
             if (
                 (shard_id == "w1" or shard_id == "w3")
                 and param_data[expert_id] != 1
@@ -752,6 +884,13 @@ class EPMoE(torch.nn.Module):
                     ] = loaded_weight
                 else:  # w2
                     param_data[expert_id] = loaded_weight
+            elif self.use_w4afp8:
+                if shard_id == "w1":
+                    param_data[expert_id][: self.intermediate_size, :] = loaded_weight
+                elif shard_id == "w3":
+                    param_data[expert_id][self.intermediate_size :, :] = loaded_weight
+                else:
+                    param_data[expert_id] = loaded_weight
             # If we are in merged column case (gate_up_proj)
             else:
                 if shard_id in ("w1", "w3"):
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -1737,6 +1737,7 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
@@ -1822,6 +1823,7 @@ def fused_moe(
         topk_ids,
         inplace=inplace,
         activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,