sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +394 -76
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/rotary_embedding.py +0 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +59 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -13,11 +13,6 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.moe.topk import select_experts
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.layers.quantization.int8_kernel import (
-    per_token_group_quant_int8,
-    per_token_quant_int8,
-)
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -42,9 +37,6 @@ if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul

     from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
 else:
     from vllm import _custom_ops as vllm_ops

@@ -52,6 +44,257 @@ if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size


+@triton.jit
+def write_zeros_to_output(
+    c_ptr,
+    stride_cm,
+    stride_cn,
+    pid_n,
+    N,
+    offs_token,
+    token_mask,
+    BLOCK_SIZE_M,
+    BLOCK_SIZE_N,
+    compute_type,
+):
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+@triton.jit
+def fused_moe_kernel_gptq_awq(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    b_scale_ptr,
+    b_zp_ptr,
+    topk_weights_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    num_tokens_post_padded_ptr,
+    # Matrix dimensions
+    N: tl.constexpr,
+    K: tl.constexpr,
+    EM,
+    num_valid_tokens,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_be,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_bse,
+    stride_bsk,
+    stride_bsn,
+    stride_bze,
+    stride_bzk,
+    stride_bzn,
+    group_size: tl.constexpr,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    MUL_ROUTED_WEIGHT: tl.constexpr,
+    top_k: tl.constexpr,
+    compute_type: tl.constexpr,
+    has_zp: tl.constexpr,
+    use_int4_w4a16: tl.constexpr,
+    use_int8_w8a16: tl.constexpr,
+    even_Ks: tl.constexpr,
+):
+    """
+    Implements the fused computation for a Mixture of Experts (MOE) using
+    token and expert matrices.
+    Key Parameters:
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+        be any shape representing batches and K is the feature dimension of
+        each token.
+    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+        the number of experts, K is the input feature dimension, and N is
+        the output feature dimension.
+    - C: The output cache tensor with shape (M, topk, N), where M is the
+        total number of tokens post padding, topk is the number of times
+        each token is repeated, and N is the output feature dimension.
+    - sorted_token_ids: A tensor containing the sorted indices of tokens,
+        repeated topk times and arranged by the expert index they are
+        assigned to.
+    - expert_ids: A tensor containing the indices of the expert for each
+        block. It determines which expert matrix from B should be used for
+        each block in A.
+    This kernel performs the multiplication of a token by its corresponding
+    expert matrix as determined by `expert_ids`. The sorting of
+    `sorted_token_ids` by expert index and padding ensures divisibility by
+    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+    multiplication across different blocks processed by the same expert.
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    token_mask = offs_token < num_valid_tokens
+
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
+
+    if use_int4_w4a16:
+        b_ptrs = (
+            b_ptr
+            + off_experts * stride_be
+            + (offs_k[:, None] // 2) * stride_bk
+            + offs_bn[None, :] * stride_bn
+        )
+        b_shifter = (offs_k[:, None] % 2) * 4
+    elif use_int8_w8a16:
+        b_ptrs = (
+            b_ptr
+            + off_experts * stride_be
+            + offs_k[:, None] * stride_bk
+            + offs_bn[None, :] * stride_bn
+        )
+
+    if not has_zp and use_int4_w4a16:
+        b_zp_num = 8
+    if not has_zp and use_int8_w8a16:
+        b_zp_num = 128
+    elif has_zp and use_int4_w4a16:
+        b_zp_shifter = (offs_bn[None, :] % 2) * 4
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the
+        # K dimension.
+
+        if not even_Ks:
+            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
+            k_other = 0.0
+        else:
+            k_mask = None
+            k_other = None
+
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs)
+        if use_int4_w4a16:
+            b = (b >> b_shifter) & 0xF
+
+        b_scale_ptrs = (
+            b_scale_ptr
+            + off_experts * stride_bse
+            + offs_bn[None, :] * stride_bsn
+            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
+        )
+        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
+        b_scale = b_scale.to(tl.float32)
+
+        if has_zp and use_int4_w4a16:
+            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+            b_zp_ptrs = (
+                b_zp_ptr
+                + off_experts * stride_bze
+                + (offs_bn[None, :] // 2) * stride_bzn
+                + offs_k_true * stride_bzk
+            )
+            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+            b_zp = (b_zp >> b_zp_shifter) & 0xF
+            b_zp = b_zp.to(tl.float32)
+        elif has_zp and use_int8_w8a16:
+            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+            b_zp_ptrs = (
+                b_zp_ptr
+                + off_experts * stride_bze
+                + offs_bn[None, :] * stride_bzn
+                + offs_k_true * stride_bzk
+            )
+            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+            b_zp = b_zp.to(tl.float32)
+
+        # We accumulate along the K dimension.
+        if has_zp:
+            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
+        else:
+            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
+        accumulator = tl.dot(a, b, acc=accumulator)
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        if use_int4_w4a16:
+            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
+        else:
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        accumulator = accumulator * moe_weight[:, None]
+
+    accumulator = accumulator.to(compute_type)
+    # -----------------------------------------------------------
+    # Write back the block of the output
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
 @triton.jit
 def fused_moe_kernel(
     # Pointers to matrices
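For reference, the per-block dequantization the newly added fused_moe_kernel_gptq_awq performs reduces to: unpack two 4-bit values per byte, subtract the zero point (a fixed 8 for int4, 128 for int8, when no zero-point tensor is supplied), then multiply by the per-group scale before feeding the tile into tl.dot. A minimal PyTorch sketch of that arithmetic; the function name, tensor names, and shapes are illustrative and not part of the diff:

from typing import Optional

import torch


def dequant_int4_group(
    packed: torch.Tensor,        # (K // 2, N) uint8, two 4-bit weights per byte along K
    scale: torch.Tensor,         # (K // group_size, N) per-group scales
    zp: Optional[torch.Tensor],  # (K // group_size, N) zero points, or None (symmetric)
    group_size: int,
) -> torch.Tensor:
    K = 2 * packed.shape[0]
    k = torch.arange(K)
    # Mirrors the kernel: b = (b >> ((offs_k % 2) * 4)) & 0xF
    w = (packed.to(torch.int32)[k // 2, :] >> ((k % 2) * 4).unsqueeze(1)) & 0xF
    g = k // group_size
    z = 8 if zp is None else zp.to(torch.int32)[g, :]  # b_zp_num = 8 in the kernel
    # Mirrors: b = (b - zp) * scale, accumulated in fp32 by tl.dot
    return (w.float() - z) * scale.float()[g, :]

The kernel applies the same transformation tile-by-tile inside its K loop, so the dequantized weights never need to be materialized in global memory.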
@@ -152,6 +395,7 @@ def fused_moe_kernel(
         return
     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens

     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
@@ -495,6 +739,7 @@ def invoke_fused_moe_kernel(
     C: torch.Tensor,
     A_scale: Optional[torch.Tensor],
     B_scale: Optional[torch.Tensor],
+    B_zp: Optional[torch.Tensor],
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     sorted_token_ids: torch.Tensor,
@@ -507,9 +752,20 @@ def invoke_fused_moe_kernel(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
+    from sglang.srt.layers.quantization.int8_kernel import (
+        per_token_group_quant_int8,
+        per_token_quant_int8,
+    )
+
+    if _is_cuda:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
+
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

@@ -547,8 +803,9 @@ def invoke_fused_moe_kernel(
             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-    elif use_int8_w8a16:
+    elif use_int8_w8a16 or use_int4_w4a16:
         assert B_scale is not None
+        assert block_shape is None or block_shape[0] == 0
     else:
         assert A_scale is None
         assert B_scale is None
@@ -564,43 +821,90 @@ def invoke_fused_moe_kernel(
     else:
         even_Ks = False

-    fused_moe_kernel[grid](
-        A,
-        B,
-        C,
-        A_scale,
-        B_scale,
-        topk_weights,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        B.shape[1],
-        B.shape[2] - padded_size,
-        sorted_token_ids.shape[0],
-        topk_ids.numel(),
-        A.stride(0),
-        A.stride(1),
-        B.stride(0),
-        B.stride(2),
-        B.stride(1),
-        C.stride(1),
-        C.stride(2),
-        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
-        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
-        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
-        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
-        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
-        0 if block_shape is None else block_shape[0],
-        0 if block_shape is None else block_shape[1],
-        MUL_ROUTED_WEIGHT=mul_routed_weight,
-        top_k=top_k,
-        compute_type=compute_type,
-        use_fp8_w8a8=use_fp8_w8a8,
-        use_int8_w8a8=use_int8_w8a8,
-        use_int8_w8a16=use_int8_w8a16,
-        even_Ks=even_Ks,
-        **config,
-    )
+    if (
+        (use_int8_w8a16 or use_int4_w4a16)
+        and block_shape is not None
+        and block_shape[1] > 0
+    ):
+        assert B_scale is not None and B_scale.ndim == 3
+        assert B_zp is None or B_zp.ndim == 3
+        fused_moe_kernel_gptq_awq[grid](
+            A,
+            B,
+            C,
+            B_scale,
+            B_zp,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            B.shape[1],
+            A.shape[1],
+            sorted_token_ids.shape[0],
+            topk_ids.numel(),
+            A.stride(0),
+            A.stride(1),
+            B.stride(0),
+            B.stride(2),
+            B.stride(1),
+            C.stride(1),
+            C.stride(2),
+            B_scale.stride(0),
+            B_scale.stride(2),
+            B_scale.stride(1),
+            B_zp.stride(0) if B_zp is not None else 0,
+            B_zp.stride(2) if B_zp is not None else 0,
+            B_zp.stride(1) if B_zp is not None else 0,
+            group_size=block_shape[1],
+            MUL_ROUTED_WEIGHT=mul_routed_weight,
+            top_k=top_k,
+            compute_type=compute_type,
+            has_zp=B_zp is not None,
+            use_int4_w4a16=use_int4_w4a16,
+            use_int8_w8a16=use_int8_w8a16,
+            even_Ks=even_Ks,
+            **config,
+        )
+
+    else:
+
+        fused_moe_kernel[grid](
+            A,
+            B,
+            C,
+            A_scale,
+            B_scale,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            B.shape[1],
+            B.shape[2] - padded_size,
+            sorted_token_ids.shape[0],
+            topk_ids.numel(),
+            A.stride(0),
+            A.stride(1),
+            B.stride(0),
+            B.stride(2),
+            B.stride(1),
+            C.stride(1),
+            C.stride(2),
+            A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+            A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+            B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+            B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+            B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+            0 if block_shape is None else block_shape[0],
+            0 if block_shape is None else block_shape[1],
+            MUL_ROUTED_WEIGHT=mul_routed_weight,
+            top_k=top_k,
+            compute_type=compute_type,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+            even_Ks=even_Ks,
+            **config,
+        )


 def get_config_file_name(
@@ -749,6 +1053,7 @@ def try_get_optimal_moe_config(
 def get_config_dtype_str(
     dtype: torch.dtype,
     use_int8_w8a16: Optional[bool] = False,
+    use_int4_w4a16: Optional[bool] = False,
     use_fp8_w8a8: Optional[bool] = False,
     use_int8_w8a8: Optional[bool] = False,
 ):
@@ -756,6 +1061,8 @@ def get_config_dtype_str(
         return "fp8_w8a8"
     elif use_int8_w8a8:
         return "int8_w8a8"
+    elif use_int4_w4a16:
+        return "int4_w4a16"
     elif use_int8_w8a16:
         return "int8_w8a16"
     elif dtype == torch.float:
@@ -775,8 +1082,11 @@ def inplace_fused_experts(
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -792,8 +1102,11 @@ def inplace_fused_experts(
         use_fp8_w8a8,
         use_int8_w8a8,
         use_int8_w8a16,
+        use_int4_w4a16,
         w1_scale,
         w2_scale,
+        w1_zp,
+        w2_zp,
         a1_scale,
         a2_scale,
         block_shape,
@@ -810,8 +1123,11 @@ def inplace_fused_experts_fake(
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -837,8 +1153,11 @@ def outplace_fused_experts(
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -855,8 +1174,11 @@ def outplace_fused_experts(
         use_fp8_w8a8,
         use_int8_w8a8,
         use_int8_w8a16,
+        use_int4_w4a16,
         w1_scale,
         w2_scale,
+        w1_zp,
+        w2_zp,
         a1_scale,
         a2_scale,
         block_shape,
@@ -874,8 +1196,11 @@ def outplace_fused_experts_fake(
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -903,8 +1228,11 @@ def fused_experts(
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -922,8 +1250,11 @@ def fused_experts(
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
+            use_int4_w4a16,
             w1_scale,
             w2_scale,
+            w1_zp,
+            w2_zp,
             a1_scale,
             a2_scale,
             block_shape,
@@ -940,8 +1271,11 @@ def fused_experts(
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
+            use_int4_w4a16,
             w1_scale,
             w2_scale,
+            w1_zp,
+            w2_zp,
             a1_scale,
             a2_scale,
             block_shape,
@@ -960,8 +1294,11 @@ def fused_experts_impl(
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -976,7 +1313,12 @@ def fused_experts_impl(
         padded_size = 0

     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
+    if use_int4_w4a16:
+        assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
+    else:
+        assert (
+            hidden_states.shape[1] == w1.shape[2] - padded_size
+        ), "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -993,6 +1335,7 @@ def fused_experts_impl(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
         dtype=hidden_states.dtype,
     )

@@ -1074,6 +1417,7 @@ def fused_experts_impl(
             intermediate_cache1,
             a1_scale,
             w1_scale,
+            w1_zp,
             curr_topk_weights,
             curr_topk_ids,
             sorted_token_ids,
@@ -1086,6 +1430,7 @@ def fused_experts_impl(
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
             block_shape=block_shape,
         )
         if activation == "silu":
@@ -1115,6 +1460,7 @@ def fused_experts_impl(
             ),
             a2_scale,
             w2_scale,
+            w2_zp,
             curr_topk_weights,
             curr_topk_ids,
             sorted_token_ids,
@@ -1127,6 +1473,7 @@ def fused_experts_impl(
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
             block_shape=block_shape,
         )

@@ -1172,8 +1519,11 @@ def fused_moe(
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -1203,6 +1553,9 @@ def fused_moe(
         products for w1 and w2. Defaults to False.
     - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
+    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
+        activation to compute the inner products for w1 and w2.
+        Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -1242,8 +1595,11 @@ def fused_moe(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
+        w1_zp=w1_zp,
+        w2_zp=w2_zp,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         block_shape=block_shape,