sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/lang/chat_template.py +24 -0
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +5 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/conversation.py +29 -4
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +678 -83
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +5 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8.py +3 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +503 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_int8.py +2 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +63 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +60 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +16 -5
- sglang/srt/models/llama4.py +420 -0
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/mllama4.py +154 -0
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

```diff
@@ -13,11 +13,6 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import select_experts
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.layers.quantization.int8_kernel import (
-    per_token_group_quant_int8,
-    per_token_quant_int8,
-)
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -42,9 +37,6 @@ if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
 
     from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
 else:
     from vllm import _custom_ops as vllm_ops
 
@@ -52,6 +44,257 @@ if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
 
+@triton.jit
+def write_zeros_to_output(
+    c_ptr,
+    stride_cm,
+    stride_cn,
+    pid_n,
+    N,
+    offs_token,
+    token_mask,
+    BLOCK_SIZE_M,
+    BLOCK_SIZE_N,
+    compute_type,
+):
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+@triton.jit
+def fused_moe_kernel_gptq_awq(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    b_scale_ptr,
+    b_zp_ptr,
+    topk_weights_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    num_tokens_post_padded_ptr,
+    # Matrix dimensions
+    N: tl.constexpr,
+    K: tl.constexpr,
+    EM,
+    num_valid_tokens,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_be,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_bse,
+    stride_bsk,
+    stride_bsn,
+    stride_bze,
+    stride_bzk,
+    stride_bzn,
+    group_size: tl.constexpr,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    MUL_ROUTED_WEIGHT: tl.constexpr,
+    top_k: tl.constexpr,
+    compute_type: tl.constexpr,
+    has_zp: tl.constexpr,
+    use_int4_w4a16: tl.constexpr,
+    use_int8_w8a16: tl.constexpr,
+    even_Ks: tl.constexpr,
+):
+    """
+    Implements the fused computation for a Mixture of Experts (MOE) using
+    token and expert matrices.
+    Key Parameters:
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+        be any shape representing batches and K is the feature dimension of
+        each token.
+    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+        the number of experts, K is the input feature dimension, and N is
+        the output feature dimension.
+    - C: The output cache tensor with shape (M, topk, N), where M is the
+        total number of tokens post padding, topk is the number of times
+        each token is repeated, and N is the output feature dimension.
+    - sorted_token_ids: A tensor containing the sorted indices of tokens,
+        repeated topk times and arranged by the expert index they are
+        assigned to.
+    - expert_ids: A tensor containing the indices of the expert for each
+        block. It determines which expert matrix from B should be used for
+        each block in A.
+    This kernel performs the multiplication of a token by its corresponding
+    expert matrix as determined by `expert_ids`. The sorting of
+    `sorted_token_ids` by expert index and padding ensures divisibility by
+    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+    multiplication across different blocks processed by the same expert.
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    token_mask = offs_token < num_valid_tokens
+
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
+
+    if use_int4_w4a16:
+        b_ptrs = (
+            b_ptr
+            + off_experts * stride_be
+            + (offs_k[:, None] // 2) * stride_bk
+            + offs_bn[None, :] * stride_bn
+        )
+        b_shifter = (offs_k[:, None] % 2) * 4
+    elif use_int8_w8a16:
+        b_ptrs = (
+            b_ptr
+            + off_experts * stride_be
+            + offs_k[:, None] * stride_bk
+            + offs_bn[None, :] * stride_bn
+        )
+
+    if not has_zp and use_int4_w4a16:
+        b_zp_num = 8
+    if not has_zp and use_int8_w8a16:
+        b_zp_num = 128
+    elif has_zp and use_int4_w4a16:
+        b_zp_shifter = (offs_bn[None, :] % 2) * 4
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the
+        # K dimension.
+
+        if not even_Ks:
+            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
+            k_other = 0.0
+        else:
+            k_mask = None
+            k_other = None
+
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs)
+        if use_int4_w4a16:
+            b = (b >> b_shifter) & 0xF
+
+        b_scale_ptrs = (
+            b_scale_ptr
+            + off_experts * stride_bse
+            + offs_bn[None, :] * stride_bsn
+            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
+        )
+        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
+        b_scale = b_scale.to(tl.float32)
+
+        if has_zp and use_int4_w4a16:
+            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+            b_zp_ptrs = (
+                b_zp_ptr
+                + off_experts * stride_bze
+                + (offs_bn[None, :] // 2) * stride_bzn
+                + offs_k_true * stride_bzk
+            )
+            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+            b_zp = (b_zp >> b_zp_shifter) & 0xF
+            b_zp = b_zp.to(tl.float32)
+        elif has_zp and use_int8_w8a16:
+            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+            b_zp_ptrs = (
+                b_zp_ptr
+                + off_experts * stride_bze
+                + offs_bn[None, :] * stride_bzn
+                + offs_k_true * stride_bzk
+            )
+            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+            b_zp = b_zp.to(tl.float32)
+
+        # We accumulate along the K dimension.
+        if has_zp:
+            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
+        else:
+            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
+        accumulator = tl.dot(a, b, acc=accumulator)
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        if use_int4_w4a16:
+            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
+        else:
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        accumulator = accumulator * moe_weight[:, None]
+
+    accumulator = accumulator.to(compute_type)
+    # -----------------------------------------------------------
+    # Write back the block of the output
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
 @triton.jit
 def fused_moe_kernel(
     # Pointers to matrices
```
```diff
@@ -152,6 +395,7 @@ def fused_moe_kernel(
         return
     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
@@ -495,6 +739,7 @@ def invoke_fused_moe_kernel(
     C: torch.Tensor,
     A_scale: Optional[torch.Tensor],
     B_scale: Optional[torch.Tensor],
+    B_zp: Optional[torch.Tensor],
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     sorted_token_ids: torch.Tensor,
@@ -507,9 +752,20 @@ def invoke_fused_moe_kernel(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
+    from sglang.srt.layers.quantization.int8_kernel import (
+        per_token_group_quant_int8,
+        per_token_quant_int8,
+    )
+
+    if _is_cuda:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
+
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
@@ -547,8 +803,9 @@ def invoke_fused_moe_kernel(
         assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
         assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
         assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-    elif use_int8_w8a16:
+    elif use_int8_w8a16 or use_int4_w4a16:
         assert B_scale is not None
+        assert block_shape is None or block_shape[0] == 0
     else:
         assert A_scale is None
         assert B_scale is None
@@ -564,43 +821,90 @@ def invoke_fused_moe_kernel(
     else:
         even_Ks = False
 
-    fused_moe_kernel[grid](
-        A,
-        B,
-        C,
-        A_scale,
-        B_scale,
-        topk_weights,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        B.shape[1],
-        B.shape[2] - padded_size,
-        sorted_token_ids.shape[0],
-        topk_ids.numel(),
-        A.stride(0),
-        A.stride(1),
-        B.stride(0),
-        B.stride(2),
-        B.stride(1),
-        C.stride(1),
-        C.stride(2),
-        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
-        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
-        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
-        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
-        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
-        0 if block_shape is None else block_shape[0],
-        0 if block_shape is None else block_shape[1],
-        MUL_ROUTED_WEIGHT=mul_routed_weight,
-        top_k=top_k,
-        compute_type=compute_type,
-        use_fp8_w8a8=use_fp8_w8a8,
-        use_int8_w8a8=use_int8_w8a8,
-        use_int8_w8a16=use_int8_w8a16,
-        even_Ks=even_Ks,
-        **config,
-    )
+    if (
+        (use_int8_w8a16 or use_int4_w4a16)
+        and block_shape is not None
+        and block_shape[1] > 0
+    ):
+        assert B_scale is not None and B_scale.ndim == 3
+        assert B_zp is None or B_zp.ndim == 3
+        fused_moe_kernel_gptq_awq[grid](
+            A,
+            B,
+            C,
+            B_scale,
+            B_zp,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            B.shape[1],
+            A.shape[1],
+            sorted_token_ids.shape[0],
+            topk_ids.numel(),
+            A.stride(0),
+            A.stride(1),
+            B.stride(0),
+            B.stride(2),
+            B.stride(1),
+            C.stride(1),
+            C.stride(2),
+            B_scale.stride(0),
+            B_scale.stride(2),
+            B_scale.stride(1),
+            B_zp.stride(0) if B_zp is not None else 0,
+            B_zp.stride(2) if B_zp is not None else 0,
+            B_zp.stride(1) if B_zp is not None else 0,
+            group_size=block_shape[1],
+            MUL_ROUTED_WEIGHT=mul_routed_weight,
+            top_k=top_k,
+            compute_type=compute_type,
+            has_zp=B_zp is not None,
+            use_int4_w4a16=use_int4_w4a16,
+            use_int8_w8a16=use_int8_w8a16,
+            even_Ks=even_Ks,
+            **config,
+        )
+
+    else:
+
+        fused_moe_kernel[grid](
+            A,
+            B,
+            C,
+            A_scale,
+            B_scale,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            B.shape[1],
+            B.shape[2] - padded_size,
+            sorted_token_ids.shape[0],
+            topk_ids.numel(),
+            A.stride(0),
+            A.stride(1),
+            B.stride(0),
+            B.stride(2),
+            B.stride(1),
+            C.stride(1),
+            C.stride(2),
+            A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+            A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+            B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+            B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+            B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+            0 if block_shape is None else block_shape[0],
+            0 if block_shape is None else block_shape[1],
+            MUL_ROUTED_WEIGHT=mul_routed_weight,
+            top_k=top_k,
+            compute_type=compute_type,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+            even_Ks=even_Ks,
+            **config,
+        )
 
 
 def get_config_file_name(
```
```diff
@@ -749,6 +1053,7 @@ def try_get_optimal_moe_config(
 def get_config_dtype_str(
     dtype: torch.dtype,
     use_int8_w8a16: Optional[bool] = False,
+    use_int4_w4a16: Optional[bool] = False,
     use_fp8_w8a8: Optional[bool] = False,
     use_int8_w8a8: Optional[bool] = False,
 ):
@@ -756,6 +1061,8 @@ def get_config_dtype_str(
         return "fp8_w8a8"
     elif use_int8_w8a8:
         return "int8_w8a8"
+    elif use_int4_w4a16:
+        return "int4_w4a16"
     elif use_int8_w8a16:
         return "int8_w8a16"
     elif dtype == torch.float:
@@ -772,11 +1079,15 @@ def inplace_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -789,11 +1100,15 @@ def inplace_fused_experts(
         topk_ids,
         True,
         activation,
+        apply_router_weight_on_input,
         use_fp8_w8a8,
         use_int8_w8a8,
         use_int8_w8a16,
+        use_int4_w4a16,
         w1_scale,
         w2_scale,
+        w1_zp,
+        w2_zp,
         a1_scale,
         a2_scale,
         block_shape,
@@ -807,11 +1122,15 @@ def inplace_fused_experts_fake(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -834,11 +1153,15 @@ def outplace_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -852,11 +1175,15 @@ def outplace_fused_experts(
         topk_ids,
         False,
         activation,
+        apply_router_weight_on_input,
         use_fp8_w8a8,
         use_int8_w8a8,
         use_int8_w8a16,
+        use_int4_w4a16,
         w1_scale,
         w2_scale,
+        w1_zp,
+        w2_zp,
         a1_scale,
         a2_scale,
         block_shape,
@@ -871,11 +1198,15 @@ def outplace_fused_experts_fake(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -900,11 +1231,15 @@ def fused_experts(
     topk_ids: torch.Tensor,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -919,11 +1254,15 @@ def fused_experts(
             topk_weights,
             topk_ids,
             activation,
+            apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
+            use_int4_w4a16,
             w1_scale,
             w2_scale,
+            w1_zp,
+            w2_zp,
             a1_scale,
             a2_scale,
             block_shape,
@@ -937,11 +1276,15 @@ def fused_experts(
             topk_weights,
             topk_ids,
             activation,
+            apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
+            use_int4_w4a16,
             w1_scale,
             w2_scale,
+            w1_zp,
+            w2_zp,
             a1_scale,
             a2_scale,
             block_shape,
@@ -957,11 +1300,15 @@ def fused_experts_impl(
     topk_ids: torch.Tensor,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -976,7 +1323,12 @@ def fused_experts_impl(
     padded_size = 0
 
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
+    if use_int4_w4a16:
+        assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
+    else:
+        assert (
+            hidden_states.shape[1] == w1.shape[2] - padded_size
+        ), "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -993,6 +1345,7 @@ def fused_experts_impl(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
         dtype=hidden_states.dtype,
     )
 
@@ -1074,18 +1427,20 @@ def fused_experts_impl(
             intermediate_cache1,
             a1_scale,
             w1_scale,
+            w1_zp,
             curr_topk_weights,
             curr_topk_ids,
             sorted_token_ids,
             expert_ids,
             num_tokens_post_padded,
-            False,
+            apply_router_weight_on_input,
             topk_ids.shape[1],
             config,
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
             block_shape=block_shape,
         )
         if activation == "silu":
@@ -1111,22 +1466,24 @@ def fused_experts_impl(
             (
                 intermediate_cache3
                 if not no_combine and topk_ids.shape[1] != 1
-                else out_hidden_states[begin_chunk_idx:end_chunk_idx]
+                else out_hidden_states[begin_chunk_idx:end_chunk_idx].unsqueeze(0)
             ),
             a2_scale,
             w2_scale,
+            w2_zp,
             curr_topk_weights,
             curr_topk_ids,
             sorted_token_ids,
             expert_ids,
             num_tokens_post_padded,
-            True,
+            not apply_router_weight_on_input,
             1,
             config,
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
             block_shape=block_shape,
         )
 
@@ -1172,8 +1529,11 @@ def fused_moe(
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -1203,6 +1563,9 @@ def fused_moe(
         products for w1 and w2. Defaults to False.
     - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
+    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
+        activation to compute the inner products for w1 and w2.
+        Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -1242,8 +1605,11 @@ def fused_moe(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
+        w1_zp=w1_zp,
+        w2_zp=w2_zp,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         block_shape=block_shape,
```