sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +1 -0
- sglang/bench_serving.py +9 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +51 -5
- sglang/srt/layers/attention/triton_backend.py +16 -25
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +5 -50
- sglang/srt/layers/quantization/fp8.py +221 -36
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/radix_attention.py +8 -1
- sglang/srt/layers/sampler.py +27 -5
- sglang/srt/layers/torchao_utils.py +31 -0
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +54 -34
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +171 -136
- sglang/srt/managers/tokenizer_manager.py +184 -133
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -11
- sglang/srt/model_executor/model_runner.py +28 -14
- sglang/srt/model_parallel.py +66 -5
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +73 -9
- sglang/srt/models/llama.py +22 -0
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +8 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/server.py +2 -1
- sglang/srt/server_args.py +19 -9
- sglang/srt/utils.py +40 -54
- sglang/test/test_block_fp8.py +341 -0
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py

@@ -6,16 +6,20 @@ import functools
 import json
 import logging
 import os
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
+from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 from vllm import _custom_ops as ops
 
+from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name
 
 logger = logging.getLogger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
 
 
 @triton.jit
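The new module-level padding_size is read once at import time from the MOE_PADDING environment variable. A minimal sketch of that behaviour, assuming the variable is unset, "0", or "1":

import os

# Sketch: MOE_PADDING=1 pads the expert weight K dimension by 128 elements;
# an unset variable or "0" disables padding.
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
print(padding_size)  # 0 by default, 128 when MOE_PADDING=1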
@@ -46,8 +50,14 @@ def fused_moe_kernel(
     stride_bn,
     stride_cm,
     stride_cn,
+    stride_asm,
+    stride_ask,
     stride_bse,
+    stride_bsk,
     stride_bsn,
+    # Block size for block-wise quantization
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -58,6 +68,7 @@ def fused_moe_kernel(
     compute_type: tl.constexpr,
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    even_Ks: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -130,8 +141,15 @@ def fused_moe_kernel(
         b_scale = tl.load(b_scale_ptrs)
 
     if use_fp8_w8a8:
-
-
+        if group_k > 0 and group_n > 0:
+            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            offs_bsn = offs_bn // group_n
+            b_scale_ptrs = (
+                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
+            )
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -143,17 +161,36 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-
-
-
-
-
-
+        if even_Ks:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
         elif use_fp8_w8a8:
-
+            if group_k > 0 and group_n > 0:
+                k_start = k * BLOCK_SIZE_K
+                offs_ks = k_start // group_k
+                a_scale = tl.load(
+                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
+                )
+                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+            else:
+                accumulator = tl.dot(a, b, acc=accumulator)
         else:
             accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
@@ -166,7 +203,10 @@ def fused_moe_kernel(
     if use_int8_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
     elif use_fp8_w8a8:
-
+        if group_k > 0 and group_n > 0:
+            accumulator = accumulator.to(compute_type)
+        else:
+            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
     else:
         accumulator = accumulator.to(compute_type)
     # -----------------------------------------------------------
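The group_n/group_k path implements block-wise FP8 scaling: the activation carries one scale per token per K-group, the weight carries one scale per (N-group, K-group) tile, and each partial dot product is rescaled before being accumulated. A rough standalone PyTorch reference of that accumulation (a hypothetical sketch for illustration, not code from the package):

import torch

def blockwise_scaled_matmul_ref(a, a_scale, b, b_scale, group_k, group_n):
    """Reference for the group-quantized accumulation in fused_moe_kernel.

    a:        [M, K]                       quantized activations (any float dtype here)
    a_scale:  [M, K // group_k]            one scale per token per K-group
    b:        [K, N]                       quantized weights
    b_scale:  [N // group_n, K // group_k] one scale per (N-group, K-group) tile
    """
    M, K = a.shape
    _, N = b.shape
    out = torch.zeros(M, N, dtype=torch.float32)
    for k0 in range(0, K, group_k):
        ks = k0 // group_k
        partial = a[:, k0 : k0 + group_k].float() @ b[k0 : k0 + group_k, :].float()
        # Expand the per-tile weight scale across its N-group and apply both scales.
        scale_b = b_scale[:, ks].repeat_interleave(group_n)[:N]
        out += partial * a_scale[:, ks : ks + 1] * scale_b[None, :]
    return out

# Sanity check: with all scales set to 1 this reduces to an ordinary matmul.
M, K, N, gk, gn = 4, 8, 6, 4, 3
a, b = torch.randn(M, K), torch.randn(K, N)
ones_a, ones_b = torch.ones(M, K // gk), torch.ones(N // gn, K // gk)
assert torch.allclose(blockwise_scaled_matmul_ref(a, ones_a, b, ones_b, gk, gn), a @ b, atol=1e-4)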
@@ -227,9 +267,25 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-
-
-
+    # FIXME(zhyncs)
+    if num_experts >= 256:
+        sgl_moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
+    else:
+        ops.moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
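moe_align_block_size buckets the flattened (token, expert) assignments by expert and pads each expert's bucket to a multiple of the kernel block size; the only change here is that the kernel from sgl_kernel is used when there are 256 or more experts. A small pure-Python sketch of what the alignment produces (illustrative only, not the package's implementation; the real kernel uses a different pad value):

from collections import defaultdict

def moe_align_block_size_ref(topk_ids, num_experts, block_size, pad_id=-1):
    """Group token slots by expert and pad each group to a multiple of block_size."""
    buckets = defaultdict(list)
    for slot, expert in enumerate(topk_ids):
        buckets[expert].append(slot)
    sorted_ids, expert_ids = [], []
    for expert in range(num_experts):
        slots = buckets.get(expert, [])
        if not slots:
            continue
        while len(slots) % block_size:
            slots.append(pad_id)  # pad so every block belongs to a single expert
        sorted_ids.extend(slots)
        expert_ids.extend([expert] * (len(slots) // block_size))
    return sorted_ids, expert_ids, len(sorted_ids)

# Example: 6 token slots routed to 2 experts, block_size=4
print(moe_align_block_size_ref([0, 1, 0, 1, 1, 0], num_experts=2, block_size=4))
# ([0, 2, 5, -1, 1, 3, 4, -1], [0, 1], 8)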
@@ -250,13 +306,24 @@ def invoke_fused_moe_kernel(
     compute_type: tl.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
+    padded_size = 0
     if use_fp8_w8a8:
-
+        padded_size = padding_size
         assert B_scale is not None
+        if block_shape is None:
+            A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+        else:
+            assert len(block_shape) == 2
+            block_n, block_k = block_shape[0], block_shape[1]
+            A, A_scale = per_token_group_quant_fp8(A, block_k)
+            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
     elif use_int8_w8a16:
         assert B_scale is not None
     else:
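When block_shape is given, activations are quantized with per_token_group_quant_fp8 from the new fp8_kernel module: each row is split into groups of block_k elements and every group gets its own scale, matching the shape assertions above. A rough PyTorch reference of that quantization (a sketch under the assumption that each group's max maps to the FP8-E4M3 maximum of 448):

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int, fp8_max: float = 448.0):
    """Quantize each group of `group_size` elements per row to FP8 with its own scale."""
    M, K = x.shape
    assert K % group_size == 0
    groups = x.reshape(M, K // group_size, group_size).float()
    # One scale per (row, group): max |value| mapped to the FP8 representable max.
    scales = groups.abs().amax(dim=-1).clamp(min=1e-10) / fp8_max
    q = (groups / scales.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
    q = q.reshape(M, K).to(torch.float8_e4m3fn)
    return q, scales  # scales has shape [M, K // group_size]

x = torch.randn(4, 256)
xq, xs = per_token_group_quant_fp8_ref(x, group_size=128)
print(xq.shape, xs.shape)  # torch.Size([4, 256]) torch.Size([4, 2])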
@@ -268,6 +335,12 @@ def invoke_fused_moe_kernel(
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )
 
+    K = B.shape[2] - padded_size
+    if K % config["BLOCK_SIZE_K"] == 0:
+        even_Ks = True
+    else:
+        even_Ks = False
+
     fused_moe_kernel[grid](
         A,
         B,
@@ -279,7 +352,7 @@ def invoke_fused_moe_kernel(
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padded_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -289,13 +362,19 @@ def invoke_fused_moe_kernel(
         B.stride(1),
         C.stride(1),
         C.stride(2),
-
-
+        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        0 if block_shape is None else block_shape[0],
+        0 if block_shape is None else block_shape[1],
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         top_k=top_k,
         compute_type=compute_type,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        even_Ks=even_Ks,
         **config,
     )
 
@@ -351,20 +430,39 @@ def get_default_config(
     dtype: Optional[str],
     is_marlin: bool,
 ) -> Dict[str, int]:
-
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 32,
-        "GROUP_SIZE_M": 8,
-    }
-    # A heuristic: fused marlin works faster with this config for small M
-    if M <= E or (is_marlin and M <= 32):
+    if dtype == "fp8_w8a8":
         config = {
-            "BLOCK_SIZE_M":
-            "BLOCK_SIZE_N":
-            "BLOCK_SIZE_K":
-            "GROUP_SIZE_M":
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
+            "num_warps": 8,
+            "num_stages": 4,
         }
+        if M <= E:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
+                "num_warps": 4,
+                "num_stages": 4,
+            }
+    else:
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        # A heuristic: fused marlin works faster with this config for small M
+        if M <= E or (is_marlin and M <= 32):
+            config = {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
+            }
     return config
 
 
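The defaults now branch on the quantization dtype: fp8_w8a8 gets larger tiles plus explicit num_warps/num_stages, with a smaller tile when the token count M does not exceed the expert count E. A hedged sketch of what that selection returns (sizes are hypothetical; only M, E, and the dtype string influence the branch, and the argument order follows the get_default_config(...) call site shown further below):

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import get_default_config

decode_cfg = get_default_config(16, 64, 14336, 4096, 2, "fp8_w8a8", False)
prefill_cfg = get_default_config(4096, 64, 14336, 4096, 2, "fp8_w8a8", False)
print(decode_cfg["BLOCK_SIZE_M"])   # 64  -> small-M branch (M <= E)
print(prefill_cfg["BLOCK_SIZE_M"])  # 128 -> default fp8_w8a8 tile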
@@ -375,8 +473,9 @@ def try_get_optimal_moe_config(
     dtype: Optional[str],
     M: int,
     is_marlin: bool = False,
+    block_shape: Optional[List[int]] = None,
 ):
-    from sglang.srt.layers.fused_moe_triton import get_config
+    from sglang.srt.layers.moe.fused_moe_triton import get_config
 
     override_config = get_config()
     if override_config:
@@ -393,77 +492,16 @@ def try_get_optimal_moe_config(
     else:
         # Else use the default config
         config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
+    # TODO(HandH1998): Optimize the configs of block-wise quant.
+    # NOTE(HandH1998): For block-wise quant,
+    # BLOCK_K must be divisable by block_shape[1]
+    # BLOCK_N and BLOCK_M has no requirements
+    if block_shape is not None:
+        config["BLOCK_SIZE_N"] = block_shape[0]
+        config["BLOCK_SIZE_K"] = block_shape[1]
     return config
 
 
-def fused_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-):
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    M, _ = hidden_states.shape
-
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
-
-    ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
-
-
-# This is used by the Deepseek-V2 model
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-):
-
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
-
-
 def get_config_dtype_str(
     dtype: torch.dtype,
     use_int8_w8a16: Optional[bool] = False,
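fused_topk and grouped_topk are no longer defined in this module; routing now lives in the new sglang/srt/layers/moe/topk.py (added in this release, +191 lines) and callers go through a single select_experts entry point. A hedged migration sketch for downstream code (keyword names are taken from the select_experts call sites visible in this diff; tensor sizes are illustrative):

import torch
from sglang.srt.layers.moe.topk import select_experts

hidden_states = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
router_logits = torch.randn(8, 64, device="cuda", dtype=torch.float32)

# Before (0.4.0.post1):
#   from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk
#   topk_weights, topk_ids = fused_topk(hidden_states, router_logits, 2, True)
# After (0.4.1):
topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    use_grouped_topk=False,  # True enables the DeepSeek-style grouped routing path
    top_k=2,
    renormalize=True,
    topk_group=None,
    num_expert_group=None,
    custom_routing_function=None,
)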
@@ -492,6 +530,7 @@ def inplace_fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -506,6 +545,7 @@ def inplace_fused_experts(
         w2_scale,
         a1_scale,
         a2_scale,
+        block_shape,
     )
 
 
@@ -521,6 +561,7 @@ def inplace_fused_experts_fake(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     pass
 
@@ -545,6 +586,7 @@ def outplace_fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -559,6 +601,7 @@ def outplace_fused_experts(
         w2_scale,
         a1_scale,
         a2_scale,
+        block_shape,
     )
 
 
@@ -574,6 +617,7 @@ def outplace_fused_experts_fake(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -599,6 +643,7 @@ def fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ):
     if inplace:
         torch.ops.sglang.inplace_fused_experts(
@@ -613,6 +658,7 @@ def fused_experts(
             w2_scale,
             a1_scale,
             a2_scale,
+            block_shape,
         )
         return hidden_states
     else:
@@ -628,6 +674,7 @@ def fused_experts(
             w2_scale,
             a1_scale,
             a2_scale,
+            block_shape,
         )
 
 
@@ -644,9 +691,14 @@ def fused_experts_impl(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ):
+    padded_size = padding_size
+    if not use_fp8_w8a8:
+        padded_size = 0
+
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
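With MOE_PADDING=1 and fp8 weights, the expert weight tensors are expected to carry 128 extra padding elements on their last (K) dimension, so the hidden-size check and the K passed to the kernel subtract padded_size. A small sketch of the expected shape relationship (hypothetical sizes, for illustration only):

# Hypothetical sizes, for illustration only.
num_tokens, num_experts, hidden, inter = 8, 64, 4096, 14336
padding_size = 128  # MOE_PADDING=1 with fp8 weights

hidden_states_shape = (num_tokens, hidden)                   # activations stay unpadded
w1_shape = (num_experts, 2 * inter, hidden + padding_size)   # weight K dimension is padded

# fused_experts_impl now checks the hidden size against the unpadded K:
assert hidden_states_shape[1] == w1_shape[2] - padding_size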
@@ -668,9 +720,10 @@ def fused_experts_impl(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        w2.shape,
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         config_dtype,
+        block_shape=block_shape,
     )
 
     config = get_config_func(M)
@@ -743,6 +796,7 @@ def fused_experts_impl(
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            block_shape=block_shape,
         )
 
         ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -764,6 +818,7 @@ def fused_experts_impl(
             compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            block_shape=block_shape,
         )
 
         torch.sum(
@@ -792,6 +847,7 @@ def fused_moe(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -819,6 +875,12 @@ def fused_moe(
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
         w2.
+    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        a1.
+    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        a2.
+    - block_shape: (Optional[List[int]]): Optional block size for block-wise
+        quantization.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -826,24 +888,16 @@ def fused_moe(
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
-
-
-
-
-
-
-
-
-
-
-    elif custom_routing_function is None:
-        topk_weights, topk_ids = fused_topk(
-            hidden_states, gating_output, topk, renormalize
-        )
-    else:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states, gating_output, topk, renormalize
-        )
+    topk_weights, topk_ids = select_experts(
+        hidden_states=hidden_states,
+        router_logits=gating_output,
+        use_grouped_topk=use_grouped_topk,
+        top_k=topk,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+    )
 
     return fused_experts(
         hidden_states,
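Putting the pieces together, callers can now request block-wise FP8 by passing block_shape through fused_moe; its two elements are the (N, K) group sizes the scale tensors were produced with. A hedged usage sketch (tensor shapes are illustrative, the positional argument order follows the upstream fused_moe convention, and w1_scale is assumed by symmetry with the w2_scale parameter shown above):

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, H, I, topk = 64, 4096, 14336, 2      # illustrative sizes
block_n, block_k = 128, 128             # block-wise quantization group sizes

x = torch.randn(8, H, device="cuda", dtype=torch.bfloat16)
gating = torch.randn(8, E, device="cuda", dtype=torch.float32)
w1 = torch.empty(E, 2 * I, H, device="cuda", dtype=torch.float8_e4m3fn)
w2 = torch.empty(E, H, I, device="cuda", dtype=torch.float8_e4m3fn)
# One scale per (expert, N-group, K-group) tile, matching the assertions in
# invoke_fused_moe_kernel above.
w1_scale = torch.ones(E, (2 * I) // block_n, H // block_k, device="cuda")
w2_scale = torch.ones(E, H // block_n, I // block_k, device="cuda")

out = fused_moe(
    x, w1, w2, gating, topk, renormalize=True,
    use_fp8_w8a8=True,
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    block_shape=[block_n, block_k],
)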
@@ -858,4 +912,5 @@ def fused_moe(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
+        block_shape=block_shape,
     )

sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py

@@ -13,14 +13,15 @@ from vllm.distributed import (
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.utils import set_weight_attrs
 
-if torch.cuda.is_available()
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+if torch.cuda.is_available():
+    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore
 
@@ -33,6 +34,7 @@ class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
     GROUP = "group"
+    BLOCK = "block"
 
 
 class FusedMoEMethodBase(QuantizeMethodBase):
@@ -106,6 +108,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -117,6 +120,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )
 
     def forward_cuda(
@@ -130,8 +134,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        topk_weights, topk_ids =
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
@@ -140,6 +145,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )
 
         return fused_experts(
@@ -197,6 +203,7 @@ class FusedMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ):
         super().__init__()
 
@@ -208,6 +215,7 @@ class FusedMoE(torch.nn.Module):
         )
         self.top_k = top_k
         self.num_experts = num_experts
+        assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
         self.renormalize = renormalize
@@ -217,6 +225,7 @@ class FusedMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
+        self.correction_bias = correction_bias
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -463,7 +472,10 @@ class FusedMoE(torch.nn.Module):
                 expert_data=expert_data,
                 tp_rank=tp_rank,
             )
-        elif quant_method
+        elif quant_method in [
+            FusedMoeWeightScaleSupported.GROUP.value,
+            FusedMoeWeightScaleSupported.BLOCK.value,
+        ]:
             self._load_model_weight_or_group_weight_scale(
                 shard_id=shard_id,
                 shard_dim=shard_dim,
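The weight loader now treats block-wise scales the same way as group scales: a new BLOCK member joins FusedMoeWeightScaleSupported, and the per-shard loading branch accepts either value. A minimal standalone sketch of that dispatch (enum values as shown in this diff; the helper function is illustrative):

from enum import Enum

class FusedMoeWeightScaleSupported(Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"

def uses_weight_style_loader(quant_method: str) -> bool:
    # GROUP and BLOCK scales are sharded along the weight dimensions, so the
    # layer routes them through the same loader as the model weight itself.
    return quant_method in [
        FusedMoeWeightScaleSupported.GROUP.value,
        FusedMoeWeightScaleSupported.BLOCK.value,
    ]

print(uses_weight_style_loader("block"))  # True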
@@ -503,51 +515,6 @@ class FusedMoE(torch.nn.Module):
             )
             return
 
-    @staticmethod
-    def select_experts(
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        use_grouped_topk: bool,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-    ):
-        from sglang.srt.layers.fused_moe_triton.fused_moe import (
-            fused_topk,
-            grouped_topk,
-        )
-
-        # DeekSeekv2 uses grouped_top_k
-        if use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-            )
-        elif custom_routing_function is None:
-            topk_weights, topk_ids = fused_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-        else:
-            topk_weights, topk_ids = custom_routing_function(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-
-        return topk_weights, topk_ids
-
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
 
@@ -562,6 +529,7 @@ class FusedMoE(torch.nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
         )
 
         if self.reduce_results and self.tp_size > 1: