sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +14 -1
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +27 -15
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +60 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/hf_transformers_utils.py +10 -0
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +240 -109
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +14 -13
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +9 -4
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/managers/cache_controller.py +62 -96
- sglang/srt/managers/detokenizer_manager.py +9 -2
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
- sglang/srt/managers/scheduler.py +39 -2
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +86 -39
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +20 -3
- sglang/srt/mem_cache/hiradix_cache.py +94 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +4 -0
- sglang/srt/mem_cache/memory_pool_host.py +4 -4
- sglang/srt/mem_cache/radix_cache.py +5 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -4
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +31 -10
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +65 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +112 -55
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +4 -0
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
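The largest single change in this release is `sglang/srt/layers/moe/fused_moe_triton/fused_moe.py` (+5 -1048): its Triton kernels, block-alignment helper, and config lookup were moved into the three new sibling modules listed above, and the inline hunks below appear to come from that file. As a rough orientation sketch only (the symbol-to-module mapping is read off the import hunk below; the package's `__init__` re-export surface is not shown in this diff), the split looks like:

# Orientation sketch (Python), inferred from the import hunk below --
# not an authoritative listing of the new modules' contents.
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
    get_config_dtype_str,        # quantization flags -> config dtype string
    try_get_optimal_moe_config,  # tuned JSON config lookup with default fallback
)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import (
    invoke_fused_moe_kernel,     # launches fused_moe_kernel / fused_moe_kernel_gptq_awq
    moe_sum_reduce_triton,       # Triton top-k sum reduction
)
from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
    moe_align_block_size,        # pads/sorts token-expert assignments to the block size
)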
@@ -5,39 +5,27 @@
 from __future__ import annotations

 import functools
-import json
-import logging
 import os
-from typing import
+from typing import List, Optional

 import torch
-import triton
 import triton.language as tl

 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import StandardTopKOutput
-from sglang.srt.layers.quantization.fp8_kernel import (
-    per_token_group_quant_fp8,
-    scaled_fp8_quant,
-    sglang_per_token_group_quant_fp8,
-)
-from sglang.srt.layers.quantization.int8_kernel import (
-    per_token_group_quant_int8,
-    per_token_quant_int8,
-    sglang_per_token_group_quant_int8,
-)
 from sglang.srt.utils import (
-    ceil_div,
     cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
-    get_device_name,
     is_cpu,
     is_cuda,
     is_hip,
-    next_power_of_2,
 )

+from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_config
+from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton
+from .moe_align_block_size import moe_align_block_size
+
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -59,954 +47,9 @@ elif _is_hip:
 else:
     from vllm import _custom_ops as vllm_ops

-
-if _is_cuda or _is_hip:
-    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
-
-
-logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0


-@triton.jit
-def write_zeros_to_output(
-    c_ptr,
-    stride_cm,
-    stride_cn,
-    pid_n,
-    N,
-    offs_token,
-    token_mask,
-    BLOCK_SIZE_M,
-    BLOCK_SIZE_N,
-    compute_type,
-):
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
-    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
-    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
-    tl.store(c_ptrs, accumulator, mask=c_mask)
-
-
-@triton.jit
-def fused_moe_kernel_gptq_awq(
-    # Pointers to matrices
-    a_ptr,
-    b_ptr,
-    c_ptr,
-    b_scale_ptr,
-    b_zp_ptr,
-    topk_weights_ptr,
-    sorted_token_ids_ptr,
-    expert_ids_ptr,
-    num_tokens_post_padded_ptr,
-    # Matrix dimensions
-    N: tl.constexpr,
-    K: tl.constexpr,
-    EM,
-    num_valid_tokens,
-    # The stride variables represent how much to increase the ptr by when
-    # moving by 1 element in a particular dimension. E.g. `stride_am` is
-    # how much to increase `a_ptr` by to get the element one row down
-    # (A has M rows).
-    stride_am,
-    stride_ak,
-    stride_be,
-    stride_bk,
-    stride_bn,
-    stride_cm,
-    stride_cn,
-    stride_bse,
-    stride_bsk,
-    stride_bsn,
-    stride_bze,
-    stride_bzk,
-    stride_bzn,
-    group_size: tl.constexpr,
-    # Meta-parameters
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
-    GROUP_SIZE_M: tl.constexpr,
-    MUL_ROUTED_WEIGHT: tl.constexpr,
-    top_k: tl.constexpr,
-    compute_type: tl.constexpr,
-    has_zp: tl.constexpr,
-    use_int4_w4a16: tl.constexpr,
-    use_int8_w8a16: tl.constexpr,
-    even_Ks: tl.constexpr,
-):
-    """
-    Implements the fused computation for a Mixture of Experts (MOE) using
-    token and expert matrices.
-    Key Parameters:
-    - A: The input tensor representing tokens with shape (*, K), where '*' can
-        be any shape representing batches and K is the feature dimension of
-        each token.
-    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
-        the number of experts, K is the input feature dimension, and N is
-        the output feature dimension.
-    - C: The output cache tensor with shape (M, topk, N), where M is the
-        total number of tokens post padding, topk is the number of times
-        each token is repeated, and N is the output feature dimension.
-    - sorted_token_ids: A tensor containing the sorted indices of tokens,
-        repeated topk times and arranged by the expert index they are
-        assigned to.
-    - expert_ids: A tensor containing the indices of the expert for each
-        block. It determines which expert matrix from B should be used for
-        each block in A.
-    This kernel performs the multiplication of a token by its corresponding
-    expert matrix as determined by `expert_ids`. The sorting of
-    `sorted_token_ids` by expert index and padding ensures divisibility by
-    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
-    multiplication across different blocks processed by the same expert.
-    """
-    # -----------------------------------------------------------
-    # Map program ids `pid` to the block of C it should compute.
-    # This is done in a grouped ordering to promote L2 data reuse.
-    pid = tl.program_id(axis=0)
-    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
-    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-    num_pid_in_group = GROUP_SIZE_M * num_pid_n
-    group_id = pid // num_pid_in_group
-    first_pid_m = group_id * GROUP_SIZE_M
-    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
-    pid_n = (pid % num_pid_in_group) // group_size_m
-
-    # ----------------------------------------------------------
-    # Create pointers for the first blocks of A and B.
-    # We will advance this pointer as we move in the K direction
-    # and accumulate
-    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
-    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
-    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
-    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
-        return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
-    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
-    token_mask = offs_token < num_valid_tokens
-
-    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
-    if off_experts == -1:
-        # -----------------------------------------------------------
-        # Write back zeros to the output when the expert is not
-        # in the current expert parallel rank.
-        write_zeros_to_output(
-            c_ptr,
-            stride_cm,
-            stride_cn,
-            pid_n,
-            N,
-            offs_token,
-            token_mask,
-            BLOCK_SIZE_M,
-            BLOCK_SIZE_N,
-            compute_type,
-        )
-        return
-
-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (
-        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
-    )
-
-    if use_int4_w4a16:
-        b_ptrs = (
-            b_ptr
-            + off_experts * stride_be
-            + (offs_k[:, None] // 2) * stride_bk
-            + offs_bn[None, :] * stride_bn
-        )
-        b_shifter = (offs_k[:, None] % 2) * 4
-    elif use_int8_w8a16:
-        b_ptrs = (
-            b_ptr
-            + off_experts * stride_be
-            + offs_k[:, None] * stride_bk
-            + offs_bn[None, :] * stride_bn
-        )
-
-    if not has_zp and use_int4_w4a16:
-        b_zp_num = 8
-    if not has_zp and use_int8_w8a16:
-        b_zp_num = 128
-    elif has_zp and use_int4_w4a16:
-        b_zp_shifter = (offs_bn[None, :] % 2) * 4
-
-    # -----------------------------------------------------------
-    # Iterate to compute a block of the C matrix.
-    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
-    # of fp32 values for higher accuracy.
-    # `accumulator` will be converted back to fp16 after the loop.
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        # Load the next block of A and B, generate a mask by checking the
-        # K dimension.
-
-        if not even_Ks:
-            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
-            k_other = 0.0
-        else:
-            k_mask = None
-            k_other = None
-
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs)
-        if use_int4_w4a16:
-            b = (b >> b_shifter) & 0xF
-
-        b_scale_ptrs = (
-            b_scale_ptr
-            + off_experts * stride_bse
-            + offs_bn[None, :] * stride_bsn
-            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
-        )
-        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
-        b_scale = b_scale.to(tl.float32)
-
-        if has_zp and use_int4_w4a16:
-            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
-            b_zp_ptrs = (
-                b_zp_ptr
-                + off_experts * stride_bze
-                + (offs_bn[None, :] // 2) * stride_bzn
-                + offs_k_true * stride_bzk
-            )
-            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
-            b_zp = (b_zp >> b_zp_shifter) & 0xF
-            b_zp = b_zp.to(tl.float32)
-        elif has_zp and use_int8_w8a16:
-            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
-            b_zp_ptrs = (
-                b_zp_ptr
-                + off_experts * stride_bze
-                + offs_bn[None, :] * stride_bzn
-                + offs_k_true * stride_bzk
-            )
-            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
-            b_zp = b_zp.to(tl.float32)
-
-        # We accumulate along the K dimension.
-        if has_zp:
-            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
-        else:
-            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
-        accumulator = tl.dot(a, b, acc=accumulator)
-
-        # Advance the ptrs to the next K block.
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        if use_int4_w4a16:
-            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
-        else:
-            b_ptrs += BLOCK_SIZE_K * stride_bk
-
-    if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-        accumulator = accumulator * moe_weight[:, None]
-
-    accumulator = accumulator.to(compute_type)
-    # -----------------------------------------------------------
-    # Write back the block of the output
-    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
-    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
-    tl.store(c_ptrs, accumulator, mask=c_mask)
-
-
-@triton.jit
-def fused_moe_kernel(
-    # Pointers to matrices
-    a_ptr,
-    b_ptr,
-    bias_ptr,
-    c_ptr,
-    a_scale_ptr,
-    b_scale_ptr,
-    topk_weights_ptr,
-    sorted_token_ids_ptr,
-    expert_ids_ptr,
-    num_tokens_post_padded_ptr,
-    # Matrix dimensions
-    N,
-    K,
-    EM,
-    num_valid_tokens,
-    # The stride variables represent how much to increase the ptr by when
-    # moving by 1 element in a particular dimension. E.g. `stride_am` is
-    # how much to increase `a_ptr` by to get the element one row down
-    # (A has M rows).
-    stride_am,
-    stride_ak,
-    stride_be,
-    stride_bk,
-    stride_bn,
-    stride_bias_e,
-    stride_bias_n,
-    stride_cm,
-    stride_cn,
-    stride_asm,
-    stride_ask,
-    stride_bse,
-    stride_bsk,
-    stride_bsn,
-    # Block size for block-wise quantization
-    group_n: tl.constexpr,
-    group_k: tl.constexpr,
-    # Meta-parameters
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
-    GROUP_SIZE_M: tl.constexpr,
-    MUL_ROUTED_WEIGHT: tl.constexpr,
-    top_k: tl.constexpr,
-    compute_type: tl.constexpr,
-    use_fp8_w8a8: tl.constexpr,
-    use_int8_w8a8: tl.constexpr,
-    use_int8_w8a16: tl.constexpr,
-    per_channel_quant: tl.constexpr,
-    even_Ks: tl.constexpr,
-):
-    """
-    Implements the fused computation for a Mixture of Experts (MOE) using
-    token and expert matrices.
-
-    Key Parameters:
-    - A: The input tensor representing tokens with shape (*, K), where '*' can
-        be any shape representing batches and K is the feature dimension of
-        each token.
-    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
-        the number of experts, K is the input feature dimension, and N is
-        the output feature dimension.
-    - C: The output cache tensor with shape (M, topk, N), where M is the
-        total number of tokens post padding, topk is the number of times
-        each token is repeated, and N is the output feature dimension.
-    - sorted_token_ids: A tensor containing the sorted indices of tokens,
-        repeated topk times and arranged by the expert index they are
-        assigned to.
-    - expert_ids: A tensor containing the indices of the expert for each
-        block. It determines which expert matrix from B should be used for
-        each block in A.
-
-    This kernel performs the multiplication of a token by its corresponding
-    expert matrix as determined by `expert_ids`. The sorting of
-    `sorted_token_ids` by expert index and padding ensures divisibility by
-    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
-    multiplication across different blocks processed by the same expert.
-    """
-    # -----------------------------------------------------------
-    # Map program ids `pid` to the block of C it should compute.
-    # This is done in a grouped ordering to promote L2 data reuse.
-    pid = tl.program_id(axis=0)
-    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
-    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-    num_pid_in_group = GROUP_SIZE_M * num_pid_n
-    group_id = pid // num_pid_in_group
-    first_pid_m = group_id * GROUP_SIZE_M
-    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
-    pid_n = (pid % num_pid_in_group) // group_size_m
-
-    # ----------------------------------------------------------
-    # Create pointers for the first blocks of A and B.
-    # We will advance this pointer as we move in the K direction
-    # and accumulate
-    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
-    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
-    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
-    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
-        return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
-    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
-    offs_token = offs_token.to(tl.int64)
-    token_mask = offs_token < num_valid_tokens
-
-    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
-
-    if off_experts == -1:
-        # -----------------------------------------------------------
-        # Write back zeros to the output when the expert is not
-        # in the current expert parallel rank.
-        write_zeros_to_output(
-            c_ptr,
-            stride_cm,
-            stride_cn,
-            pid_n,
-            N,
-            offs_token,
-            token_mask,
-            BLOCK_SIZE_M,
-            BLOCK_SIZE_N,
-            compute_type,
-        )
-        return
-
-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (
-        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
-    )
-
-    b_ptrs = (
-        b_ptr
-        + off_experts * stride_be
-        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-    )
-    if bias_ptr is not None:
-        bias = tl.load(
-            bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
-        )
-    if use_int8_w8a16:
-        b_scale_ptrs = (
-            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
-        )
-        b_scale = tl.load(b_scale_ptrs)
-
-    if use_fp8_w8a8 or use_int8_w8a8:
-        # block-wise
-        if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-            offs_bsn = offs_bn // group_n
-            b_scale_ptrs = (
-                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
-            )
-        # channel-wise
-        elif per_channel_quant:
-            b_scale_ptrs = (
-                b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
-            )
-            b_scale = tl.load(b_scale_ptrs)
-            # Load per-token scale for activations
-            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
-        # tensor-wise
-        else:
-            a_scale = tl.load(a_scale_ptr)
-            b_scale = tl.load(b_scale_ptr + off_experts)
-
-    # -----------------------------------------------------------
-    # Iterate to compute a block of the C matrix.
-    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
-    # of fp32 values for higher accuracy.
-    # `accumulator` will be converted back to fp16 after the loop.
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        # Load the next block of A and B, generate a mask by checking the
-        # K dimension.
-        if even_Ks:
-            a = tl.load(
-                a_ptrs,
-                mask=token_mask[:, None],
-                other=0.0,
-            )
-            b = tl.load(b_ptrs)
-        else:
-            a = tl.load(
-                a_ptrs,
-                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-                other=0.0,
-            )
-            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
-
-        # We accumulate along the K dimension.
-        if use_int8_w8a16:
-            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
-        elif use_fp8_w8a8 or use_int8_w8a8:
-            if group_k > 0 and group_n > 0:
-                k_start = k * BLOCK_SIZE_K
-                offs_ks = k_start // group_k
-                a_scale = tl.load(
-                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
-                )
-                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
-
-                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
-            else:
-                if use_fp8_w8a8:
-                    accumulator = tl.dot(a, b, acc=accumulator)
-                else:
-                    accumulator += tl.dot(a, b)
-        else:
-            accumulator += tl.dot(a, b)
-        # Advance the ptrs to the next K block.
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
-
-    if use_int8_w8a16:
-        accumulator *= b_scale
-    elif use_fp8_w8a8 or use_int8_w8a8:
-        if group_k == 0 or group_n == 0:
-            accumulator *= a_scale * b_scale
-
-    if bias_ptr is not None:
-        accumulator += bias
-
-    if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-        accumulator *= moe_weight[:, None]
-
-    accumulator = accumulator.to(compute_type)
-    # -----------------------------------------------------------
-    # Write back the block of the output
-    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
-    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
-    tl.store(c_ptrs, accumulator, mask=c_mask)
-
-
-def moe_align_block_size(
-    topk_ids: torch.Tensor, block_size: int, num_experts: int
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Aligns the token distribution across experts to be compatible with block
-    size for matrix multiplication.
-
-    Parameters:
-    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
-        top-k expert indices for each token.
-    - block_size: The block size used in block matrix multiplication.
-    - num_experts: The total number of experts.
-
-    Returns:
-    - sorted_token_ids: A tensor containing the sorted token indices according
-        to their allocated expert.
-    - expert_ids: A tensor indicating the assigned expert index for each block.
-    - num_tokens_post_padded: The total number of tokens after padding,
-        ensuring divisibility by block_size.
-
-    This function pads the number of tokens that each expert needs to process
-    so that it is divisible by block_size.
-    Padding ensures that during block matrix multiplication, the dimensions
-    align correctly.
-
-    Example:
-    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
-    block_size = 4, and num_experts = 4:
-    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
-        with each expert needing to process 3 tokens.
-    - As block_size is 4, we pad 1 token for each expert.
-    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
-    - Then append padding tokens [12, 12, 12, 12] for each block.
-    - After sorting by expert index, we obtain token_ids
-        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
-        Tokens 12 are non-existent (padding) and are ignored in
-        the subsequent matrix multiplication.
-    - The padding ensures that the total number of tokens is now divisible
-        by block_size for proper block matrix operations.
-    """
-    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
-    sorted_ids = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
-    )
-    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    expert_ids = torch.empty(
-        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
-    )
-    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-
-    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
-    cumsum_buffer = torch.empty(
-        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
-    )
-
-    # Threshold based on benchmark results
-    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
-    if not fuse_sorted_ids_padding:
-        sorted_ids.fill_(topk_ids.numel())
-
-    sgl_moe_align_block_size(
-        topk_ids,
-        num_experts + 1,
-        block_size,
-        sorted_ids,
-        expert_ids,
-        num_tokens_post_pad,
-        cumsum_buffer,
-        fuse_sorted_ids_padding,
-    )
-    return sorted_ids, expert_ids, num_tokens_post_pad
-
-
-def invoke_fused_moe_kernel(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    C: torch.Tensor,
-    A_scale: Optional[torch.Tensor],
-    B_scale: Optional[torch.Tensor],
-    B_zp: Optional[torch.Tensor],
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_padded: torch.Tensor,
-    mul_routed_weight: bool,
-    top_k: int,
-    config: Dict[str, Any],
-    compute_type: tl.dtype,
-    use_fp8_w8a8: bool,
-    use_int8_w8a8: bool,
-    use_int8_w8a16: bool,
-    use_int4_w4a16: bool,
-    per_channel_quant: bool,
-    block_shape: Optional[List[int]] = None,
-    no_combine: bool = False,
-) -> None:
-    assert topk_weights.stride(1) == 1
-    assert sorted_token_ids.stride(0) == 1
-
-    padded_size = 0
-    if use_fp8_w8a8:
-        assert B_scale is not None
-        if block_shape is None:
-            # activation tensor-wise fp8 quantization, dynamic or static
-            padded_size = padding_size
-            # activations apply per-token quantization when weights apply per-channel quantization by default
-            A, A_scale = scaled_fp8_quant(
-                A, A_scale, use_per_token_if_dynamic=per_channel_quant
-            )
-        else:
-            # activation block-wise fp8 quantization
-            assert len(block_shape) == 2
-            block_n, block_k = block_shape[0], block_shape[1]
-            if _is_cuda:
-                A, A_scale = sglang_per_token_group_quant_fp8(A, block_k)
-            else:
-                A, A_scale = per_token_group_quant_fp8(A, block_k)
-            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
-            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
-            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-    elif use_int8_w8a8:
-        assert B_scale is not None
-        if block_shape is None:
-            # activation channel-wise int8 quantization
-            assert (
-                per_channel_quant
-            ), "int8 quantization only supports channel-wise quantization except for block-wise quantization"
-            A, A_scale = per_token_quant_int8(A)
-        else:
-            # activation block-wise int8 quantization
-            assert len(block_shape) == 2
-            block_n, block_k = block_shape[0], block_shape[1]
-            if _is_cuda:
-                A, A_scale = sglang_per_token_group_quant_int8(A, block_k)
-            else:
-                A, A_scale = per_token_group_quant_int8(A, block_k)
-            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
-            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
-            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-    elif use_int8_w8a16 or use_int4_w4a16:
-        assert B_scale is not None
-        assert block_shape is None or block_shape[0] == 0
-    else:
-        assert A_scale is None
-        assert B_scale is None
-
-    grid = lambda META: (
-        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
-        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
-    )
-
-    K = B.shape[2] - padded_size
-    if K % config["BLOCK_SIZE_K"] == 0:
-        even_Ks = True
-    else:
-        even_Ks = False
-
-    if (
-        (use_int8_w8a16 or use_int4_w4a16)
-        and block_shape is not None
-        and block_shape[1] > 0
-    ):
-        assert B_scale is not None and B_scale.ndim == 3
-        assert B_zp is None or B_zp.ndim == 3
-        assert bias is None
-        fused_moe_kernel_gptq_awq[grid](
-            A,
-            B,
-            C,
-            B_scale,
-            B_zp,
-            topk_weights,
-            sorted_token_ids,
-            expert_ids,
-            num_tokens_post_padded,
-            B.shape[1],
-            A.shape[1],
-            sorted_token_ids.shape[0],
-            topk_ids.numel(),
-            A.stride(0),
-            A.stride(1),
-            B.stride(0),
-            B.stride(2),
-            B.stride(1),
-            C.stride(1),
-            C.stride(2),
-            B_scale.stride(0),
-            B_scale.stride(2),
-            B_scale.stride(1),
-            B_zp.stride(0) if B_zp is not None else 0,
-            B_zp.stride(2) if B_zp is not None else 0,
-            B_zp.stride(1) if B_zp is not None else 0,
-            group_size=block_shape[1],
-            MUL_ROUTED_WEIGHT=mul_routed_weight,
-            top_k=top_k,
-            compute_type=compute_type,
-            has_zp=B_zp is not None,
-            use_int4_w4a16=use_int4_w4a16,
-            use_int8_w8a16=use_int8_w8a16,
-            even_Ks=even_Ks,
-            **config,
-        )
-
-    else:
-
-        fused_moe_kernel[grid](
-            A,
-            B,
-            bias,
-            C,
-            A_scale,
-            B_scale,
-            topk_weights,
-            sorted_token_ids,
-            expert_ids,
-            num_tokens_post_padded,
-            B.shape[1],
-            B.shape[2] - padded_size,
-            sorted_token_ids.shape[0],
-            topk_ids.numel(),
-            A.stride(0),
-            A.stride(1),
-            B.stride(0),
-            B.stride(2),
-            B.stride(1),
-            bias.stride(0) if bias is not None else 0,
-            bias.stride(1) if bias is not None else 0,
-            C.stride(1),
-            C.stride(2),
-            A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
-            A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
-            B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
-            B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
-            B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
-            0 if block_shape is None else block_shape[0],
-            0 if block_shape is None else block_shape[1],
-            MUL_ROUTED_WEIGHT=mul_routed_weight,
-            top_k=top_k,
-            compute_type=compute_type,
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a8=use_int8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            per_channel_quant=per_channel_quant,
-            even_Ks=even_Ks,
-            **config,
-        )
-
-
-def get_config_file_name(
-    E: int, N: int, dtype: Optional[str], block_shape: Optional[int] = None
-) -> str:
-    device_name = get_device_name().replace(" ", "_")
-    dtype_selector = "" if not dtype else f",dtype={dtype}"
-    block_shape_selector = (
-        "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
-    )
-    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"
-
-
-@functools.lru_cache
-def get_moe_configs(
-    E: int,
-    N: int,
-    dtype: Optional[str],
-    block_n: Optional[int] = 0,
-    block_k: Optional[int] = 0,
-) -> Optional[Dict[int, Any]]:
-    """
-    Return optimized configurations for the fused MoE kernel.
-
-    The return value will be a dictionary that maps an irregular grid of
-    batch sizes to configurations of the fused_moe kernel. To evaluate the
-    kernel on a given batch size bs, the closest batch size in the grid should
-    be picked and the associated configuration chosen to invoke the kernel.
-    """
-    # Supported Triton versions, should be sorted from the newest to the oldest
-    supported_triton_versions = ["3.3.1", "3.2.0", "3.1.0"]
-
-    # First look up if an optimized configuration is available in the configs
-    # directory
-    json_file_name = get_config_file_name(E, N, dtype, [block_n, block_k])
-
-    # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
-    # so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
-    triton_version = triton.__version__
-    version_dir = f"triton_{triton_version.replace('.', '_')}"
-    config_file_path = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)),
-        "configs",
-        version_dir,
-        json_file_name,
-    )
-    if os.path.exists(config_file_path):
-        with open(config_file_path) as f:
-            # Please note that although we find the config files, performance might still be suboptimal.
-            # This is because the tuning environment might differ from your current environment.
-            # For example, updating the Triton version might cause all old configs to become suboptimal.
-            # To achieve the best performance, consider re-tuning the Triton fused MOE kernel in your environment.
-            # For the tuning method, refer to: https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
-            logger.info(f"Using MoE kernel config from {config_file_path}.")
-            # If a configuration has been found, return it
-            return {int(key): val for key, val in json.load(f).items()}
-
-    # Searching for other triton versions that supports the same config
-    for try_triton_version in supported_triton_versions:
-        if try_triton_version == triton_version:
-            continue
-        try_config_file_path = os.path.join(
-            os.path.dirname(os.path.realpath(__file__)),
-            "configs",
-            f"triton_{try_triton_version.replace('.', '_')}",
-            json_file_name,
-        )
-        if os.path.exists(try_config_file_path):
-            with open(try_config_file_path) as f:
-                logger.warning(
-                    f"Config file not found at {config_file_path}. Fallback to triton version {try_triton_version} and use MoE kernel config from {try_config_file_path}. Performance might be sub-optimal!",
-                )
-                # If a configuration has been found, return it
-                return {int(key): val for key, val in json.load(f).items()}
-
-    # If no optimized configuration is available, we will use the default
-    # configuration
-    logger.warning(
-        (
-            "Using default MoE kernel config. Performance might be sub-optimal! "
-            "Config file not found at %s, you can create them with https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton"
-        ),
-        config_file_path,
-    )
-    return None
-
-
-def get_default_config(
-    M: int,
-    E: int,
-    N: int,
-    K: int,
-    topk: int,
-    dtype: Optional[str],
-    is_marlin: bool,
-    block_shape: Optional[List[int]] = None,
-) -> Dict[str, int]:
-    if dtype == "fp8_w8a8":
-        if block_shape is None:
-            config = {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 256,
-                "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 32,
-                "num_warps": 8,
-                "num_stages": 2 if _is_hip else 4,
-            }
-            if M <= E:
-                config = {
-                    "BLOCK_SIZE_M": 64,
-                    "BLOCK_SIZE_N": 128,
-                    "BLOCK_SIZE_K": 128,
-                    "GROUP_SIZE_M": 1,
-                    "num_warps": 4,
-                    "num_stages": 2 if _is_hip else 4,
-                }
-        else:
-            # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
-            config = {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": block_shape[0],
-                "BLOCK_SIZE_K": block_shape[1],
-                "GROUP_SIZE_M": 32,
-                "num_warps": 4,
-                "num_stages": 2 if _is_hip else 3,
-            }
-    else:
-        config = {
-            "BLOCK_SIZE_M": 64,
-            "BLOCK_SIZE_N": 64,
-            "BLOCK_SIZE_K": 32,
-            "GROUP_SIZE_M": 8,
-        }
-        # A heuristic: fused marlin works faster with this config for small M
-        if M <= E or (is_marlin and M <= 32):
-            config = {
-                "BLOCK_SIZE_M": 16,
-                "BLOCK_SIZE_N": 32,
-                "BLOCK_SIZE_K": 64,
-                "GROUP_SIZE_M": 1,
-            }
-    return config
-
-
-def try_get_optimal_moe_config(
-    w1_shape: Tuple[int, ...],
-    w2_shape: Tuple[int, ...],
-    top_k: int,
-    dtype: Optional[str],
-    M: int,
-    is_marlin: bool = False,
-    block_shape: Optional[List[int]] = None,
-):
-    from sglang.srt.layers.moe.fused_moe_triton import get_config
-
-    override_config = get_config()
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        E, _, N = w2_shape
-        block_n = block_shape[0] if block_shape else 0
-        block_k = block_shape[1] if block_shape else 0
-        configs = get_moe_configs(E, N, dtype, block_n, block_k)
-
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = get_default_config(
-                M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
-            )
-    return config
-
-
-def get_config_dtype_str(
-    dtype: torch.dtype,
-    use_int8_w8a16: Optional[bool] = False,
-    use_int4_w4a16: Optional[bool] = False,
-    use_fp8_w8a8: Optional[bool] = False,
-    use_int8_w8a8: Optional[bool] = False,
-):
-    if use_fp8_w8a8:
-        return "fp8_w8a8"
-    elif use_int8_w8a8:
-        return "int8_w8a8"
-    elif use_int4_w4a16:
-        return "int4_w4a16"
-    elif use_int8_w8a16:
-        return "int8_w8a16"
-    elif dtype == torch.float:
-        # avoiding cases where kernel fails when float32 MoE
-        # use fp16/bfloat16 configs
-        return "float32"
-    return None
-
-
 def inplace_fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -1276,92 +319,6 @@ def fused_experts(
     )


-# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
-@triton.jit
-def _moe_sum_reduce_kernel(
-    input_ptr,
-    input_stride_0,
-    input_stride_1,
-    input_stride_2,
-    output_ptr,
-    output_stride_0,
-    output_stride_1,
-    token_num: int,
-    topk_num: int,
-    hidden_dim: int,
-    routed_scaling_factor: tl.constexpr,
-    BLOCK_M: tl.constexpr,
-    BLOCK_DIM: tl.constexpr,
-    NUM_STAGE: tl.constexpr,
-):
-    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
-    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
-    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
-
-    token_block_id = tl.program_id(0)
-    dim_block_id = tl.program_id(1)
-
-    token_start = token_block_id * BLOCK_M
-    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
-
-    dim_start = dim_block_id * BLOCK_DIM
-    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
-
-    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
-
-    for token_index in range(token_start, token_end):
-        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-            tmp = tl.load(
-                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-            )
-            accumulator += tmp
-        accumulator = accumulator * routed_scaling_factor
-        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-        tl.store(
-            store_t_ptr,
-            accumulator.to(input_ptr.dtype.element_ty),
-            mask=offs_dim < dim_end,
-        )
-
-
-def moe_sum_reduce_triton(
-    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
-):
-    assert input.is_contiguous()
-    assert output.is_contiguous()
-
-    token_num, topk_num, hidden_dim = input.shape
-    assert output.shape[0] == token_num and output.shape[1] == hidden_dim
-
-    BLOCK_M = 1
-    BLOCK_DIM = 2048
-    NUM_STAGE = 1
-    num_warps = 8
-
-    grid = (
-        triton.cdiv(token_num, BLOCK_M),
-        triton.cdiv(hidden_dim, BLOCK_DIM),
-    )
-
-    _moe_sum_reduce_kernel[grid](
-        input,
-        *input.stride(),
-        output,
-        *output.stride(),
-        token_num=token_num,
-        topk_num=topk_num,
-        hidden_dim=hidden_dim,
-        routed_scaling_factor=routed_scaling_factor,
-        BLOCK_M=BLOCK_M,
-        BLOCK_DIM=BLOCK_DIM,
-        NUM_STAGE=NUM_STAGE,
-        num_warps=num_warps,
-    )
-    return
-
-
 @torch.compile
 def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
     torch.sum(x, dim=1, out=out)