sglang 0.4.0.post2__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +0 -12
- sglang/bench_one_batch.py +0 -12
- sglang/bench_serving.py +1 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +49 -5
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +110 -98
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +3 -3
- sglang/srt/layers/quantization/fp8.py +169 -32
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/torchao_utils.py +11 -15
- sglang/srt/managers/schedule_batch.py +16 -10
- sglang/srt/managers/scheduler.py +2 -2
- sglang/srt/managers/tokenizer_manager.py +86 -76
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -0
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/server.py +1 -0
- sglang/srt/utils.py +33 -44
- sglang/test/test_block_fp8.py +341 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/METADATA +3 -3
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/RECORD +44 -40
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py

@@ -6,13 +6,16 @@ import functools
 import json
 import logging
 import os
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
+from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 from vllm import _custom_ops as ops
 
+from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name
 
 logger = logging.getLogger(__name__)
@@ -47,8 +50,14 @@ def fused_moe_kernel(
     stride_bn,
     stride_cm,
     stride_cn,
+    stride_asm,
+    stride_ask,
     stride_bse,
+    stride_bsk,
     stride_bsn,
+    # Block size for block-wise quantization
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -132,8 +141,15 @@ def fused_moe_kernel(
         b_scale = tl.load(b_scale_ptrs)
 
     if use_fp8_w8a8:
-        a_scale = tl.load(a_scale_ptr)
-        b_scale = tl.load(b_scale_ptr + off_experts)
+        if group_k > 0 and group_n > 0:
+            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            offs_bsn = offs_bn // group_n
+            b_scale_ptrs = (
+                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
+            )
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -164,7 +180,17 @@ def fused_moe_kernel(
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
         elif use_fp8_w8a8:
-            accumulator = tl.dot(a, b, acc=accumulator)
+            if group_k > 0 and group_n > 0:
+                k_start = k * BLOCK_SIZE_K
+                offs_ks = k_start // group_k
+                a_scale = tl.load(
+                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
+                )
+                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+            else:
+                accumulator = tl.dot(a, b, acc=accumulator)
         else:
             accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
@@ -177,7 +203,10 @@ def fused_moe_kernel(
     if use_int8_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
     elif use_fp8_w8a8:
-        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+        if group_k > 0 and group_n > 0:
+            accumulator = accumulator.to(compute_type)
+        else:
+            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
     else:
         accumulator = accumulator.to(compute_type)
     # -----------------------------------------------------------
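For orientation: the new group_n/group_k branch applies one activation scale per (token, K-group) and one weight scale per (N-group, K-group) tile, scaling each K-block partial product before accumulation. The following PyTorch sketch is a minimal reference of that math with illustrative names and shapes; it is not part of the diff and not the Triton kernel's API.

import torch

def blockwise_scaled_matmul_reference(a_q, a_scale, b_q, b_scale, group_n, group_k):
    """Reference for the group_k/group_n branch of fused_moe_kernel.

    a_q:      (M, K) quantized activations (kept as float here for simplicity)
    a_scale:  (M, K // group_k)               per-token, per-K-group scales
    b_q:      (N, K) quantized weights of one expert
    b_scale:  (N // group_n, K // group_k)    per-tile weight scales
    """
    M, K = a_q.shape
    N = b_q.shape[0]
    out = torch.zeros(M, N, dtype=torch.float32)
    for k0 in range(0, K, group_k):
        ks = k0 // group_k
        # Partial product over one K block, then scale rows by the activation
        # scale of this K group and columns by the matching weight-tile scale.
        partial = a_q[:, k0:k0 + group_k].float() @ b_q[:, k0:k0 + group_k].float().t()
        col_scale = b_scale[:, ks].repeat_interleave(group_n)  # (N,)
        out += partial * a_scale[:, ks:ks + 1] * col_scale[None, :]
    return out

With group_n = group_k = 0 the kernel keeps the previous per-tensor path, which is why the epilogue only multiplies by a_scale * b_scale in that case.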
@@ -238,9 +267,25 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    ops.moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
-    )
+    # FIXME(zhyncs)
+    if num_experts >= 256:
+        sgl_moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
+    else:
+        ops.moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
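The dispatch above keeps the vllm `ops.moe_align_block_size` path for small expert counts and switches to the `sgl_kernel` implementation when num_experts >= 256. As a mental model only (a naive Python sketch with assumed layout, not the CUDA kernels), the alignment groups token→expert assignments by expert and pads each expert's slot list to a multiple of block_size so every Triton block works on a single expert:

import torch

def moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    """Naive sketch of the alignment output: token slots grouped by expert and
    padded so every block of `block_size` slots belongs to one expert."""
    flat = topk_ids.flatten()
    pad_id = flat.numel()  # sentinel written into padding slots
    sorted_ids, expert_ids = [], []
    for e in range(num_experts):
        tok = (flat == e).nonzero(as_tuple=True)[0].tolist()
        padded_len = ((len(tok) + block_size - 1) // block_size) * block_size
        sorted_ids += tok + [pad_id] * (padded_len - len(tok))
        expert_ids += [e] * (padded_len // block_size)
    num_tokens_post_pad = torch.tensor([len(sorted_ids)], dtype=torch.int32)
    return (
        torch.tensor(sorted_ids, dtype=torch.int32),
        torch.tensor(expert_ids, dtype=torch.int32),
        num_tokens_post_pad,
    )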
@@ -261,6 +306,7 @@ def invoke_fused_moe_kernel(
     compute_type: tl.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -268,8 +314,16 @@ def invoke_fused_moe_kernel(
     padded_size = 0
     if use_fp8_w8a8:
         padded_size = padding_size
-        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
+        if block_shape is None:
+            A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+        else:
+            assert len(block_shape) == 2
+            block_n, block_k = block_shape[0], block_shape[1]
+            A, A_scale = per_token_group_quant_fp8(A, block_k)
+            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
     elif use_int8_w8a16:
         assert B_scale is not None
     else:
@@ -308,8 +362,13 @@ def invoke_fused_moe_kernel(
         B.stride(1),
         C.stride(1),
         C.stride(2),
-
-
+        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        0 if block_shape is None else block_shape[0],
+        0 if block_shape is None else block_shape[1],
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         top_k=top_k,
         compute_type=compute_type,
@@ -414,8 +473,9 @@ def try_get_optimal_moe_config(
     dtype: Optional[str],
     M: int,
     is_marlin: bool = False,
+    block_shape: Optional[List[int]] = None,
 ):
-    from sglang.srt.layers.fused_moe_triton import get_config
+    from sglang.srt.layers.moe.fused_moe_triton import get_config
 
     override_config = get_config()
     if override_config:
@@ -432,77 +492,16 @@ def try_get_optimal_moe_config(
     else:
         # Else use the default config
         config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
+        # TODO(HandH1998): Optimize the configs of block-wise quant.
+        # NOTE(HandH1998): For block-wise quant,
+        # BLOCK_K must be divisable by block_shape[1]
+        # BLOCK_N and BLOCK_M has no requirements
+        if block_shape is not None:
+            config["BLOCK_SIZE_N"] = block_shape[0]
+            config["BLOCK_SIZE_K"] = block_shape[1]
     return config
 
 
-def fused_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-):
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    M, _ = hidden_states.shape
-
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
-
-    ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
-
-
-# This is used by the Deepseek-V2 model
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-):
-
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
-
-
 def get_config_dtype_str(
     dtype: torch.dtype,
     use_int8_w8a16: Optional[bool] = False,
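The routing helpers removed here are not dropped from the package: per the file list above, they move into the new sglang/srt/layers/moe/topk.py, and fused_moe now goes through select_experts. A hedged usage sketch of the consolidated entry point; the keyword names match the call sites visible in this diff, while the tensor shapes, dtypes, and a CUDA build of sglang/vllm are assumptions:

import torch
from sglang.srt.layers.moe.topk import select_experts  # new module in 0.4.1

hidden_states = torch.randn(4, 2048, device="cuda", dtype=torch.float16)
router_logits = torch.randn(4, 8, device="cuda", dtype=torch.float32)

topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    use_grouped_topk=False,        # True selects the DeepSeek-style grouped_topk path
    top_k=2,
    renormalize=True,
    topk_group=None,
    num_expert_group=None,
    custom_routing_function=None,
)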
@@ -531,6 +530,7 @@ def inplace_fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -545,6 +545,7 @@ def inplace_fused_experts(
         w2_scale,
         a1_scale,
         a2_scale,
+        block_shape,
     )
 
 
@@ -560,6 +561,7 @@ def inplace_fused_experts_fake(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     pass
 
@@ -584,6 +586,7 @@ def outplace_fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -598,6 +601,7 @@ def outplace_fused_experts(
         w2_scale,
         a1_scale,
         a2_scale,
+        block_shape,
     )
 
 
@@ -613,6 +617,7 @@ def outplace_fused_experts_fake(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -638,6 +643,7 @@ def fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ):
     if inplace:
         torch.ops.sglang.inplace_fused_experts(
@@ -652,6 +658,7 @@ def fused_experts(
             w2_scale,
             a1_scale,
             a2_scale,
+            block_shape,
         )
         return hidden_states
     else:
@@ -667,6 +674,7 @@ def fused_experts(
             w2_scale,
             a1_scale,
             a2_scale,
+            block_shape,
         )
 
 
@@ -683,6 +691,7 @@ def fused_experts_impl(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ):
     padded_size = padding_size
     if not use_fp8_w8a8:
@@ -714,6 +723,7 @@ def fused_experts_impl(
         (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
        config_dtype,
+        block_shape=block_shape,
    )
 
     config = get_config_func(M)
@@ -786,6 +796,7 @@ def fused_experts_impl(
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            block_shape=block_shape,
         )
 
         ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -807,6 +818,7 @@ def fused_experts_impl(
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            block_shape=block_shape,
         )
 
         torch.sum(
@@ -835,6 +847,7 @@ def fused_moe(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -862,6 +875,12 @@ def fused_moe(
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
         w2.
+    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        a1.
+    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        a2.
+    - block_shape: (Optional[List[int]]): Optional block size for block-wise
+        quantization.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -869,24 +888,16 @@ def fused_moe(
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
-
-
-
-
-
-
-
-
-
-
-    elif custom_routing_function is None:
-        topk_weights, topk_ids = fused_topk(
-            hidden_states, gating_output, topk, renormalize
-        )
-    else:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states, gating_output, topk, renormalize
-        )
+    topk_weights, topk_ids = select_experts(
+        hidden_states=hidden_states,
+        router_logits=gating_output,
+        use_grouped_topk=use_grouped_topk,
+        top_k=topk,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+    )
 
     return fused_experts(
         hidden_states,
@@ -901,4 +912,5 @@ def fused_moe(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
+        block_shape=block_shape,
     )
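Putting the fused_moe.py changes together: callers opt into block-wise FP8 by passing block_shape along with tile-shaped weight scales, and activations are then quantized per token group inside invoke_fused_moe_kernel. A hedged end-to-end sketch with assumed shapes, dtypes, and a [128, 128] block size; the positional order of the first arguments follows the pre-existing fused_moe signature and is not shown in this diff:

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, H, I = 8, 2048, 1408          # experts, hidden size, intermediate size (assumed)
block_n, block_k = 128, 128

hidden_states = torch.randn(16, H, device="cuda", dtype=torch.bfloat16)
gating_output = torch.randn(16, E, device="cuda", dtype=torch.float32)
# FP8 weights: w13 fuses the gate/up projections, w2 is the down projection.
w13 = torch.randn(E, 2 * I, H, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
w2 = torch.randn(E, H, I, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
# One scale per (block_n, block_k) weight tile, as checked by the new asserts.
w13_scale = torch.rand(E, (2 * I) // block_n, H // block_k, device="cuda", dtype=torch.float32)
w2_scale = torch.rand(E, H // block_n, I // block_k, device="cuda", dtype=torch.float32)

out = fused_moe(
    hidden_states,
    w13,
    w2,
    gating_output,
    topk=2,
    renormalize=True,
    use_fp8_w8a8=True,
    w1_scale=w13_scale,
    w2_scale=w2_scale,
    block_shape=[block_n, block_k],  # new in 0.4.1: per-block dequant inside the kernel
)

With block_shape=None the call behaves as before, using ops.scaled_fp8_quant and per-tensor scales.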
sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py

@@ -13,6 +13,7 @@ from vllm.distributed import (
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -20,7 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.utils import set_weight_attrs
 
 if torch.cuda.is_available():
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore
 
@@ -33,6 +34,7 @@ class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
     GROUP = "group"
+    BLOCK = "block"
 
 
 class FusedMoEMethodBase(QuantizeMethodBase):
@@ -106,6 +108,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -117,6 +120,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )
 
     def forward_cuda(
@@ -130,8 +134,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        topk_weights, topk_ids = FusedMoE.select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
@@ -140,6 +145,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )
 
         return fused_experts(
@@ -197,6 +203,7 @@ class FusedMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ):
         super().__init__()
 
@@ -208,6 +215,7 @@ class FusedMoE(torch.nn.Module):
         )
         self.top_k = top_k
         self.num_experts = num_experts
+        assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
         self.renormalize = renormalize
@@ -217,6 +225,7 @@ class FusedMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
+        self.correction_bias = correction_bias
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -463,7 +472,10 @@ class FusedMoE(torch.nn.Module):
                 expert_data=expert_data,
                 tp_rank=tp_rank,
            )
-        elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
+        elif quant_method in [
+            FusedMoeWeightScaleSupported.GROUP.value,
+            FusedMoeWeightScaleSupported.BLOCK.value,
+        ]:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
@@ -503,51 +515,6 @@ class FusedMoE(torch.nn.Module):
             )
             return
 
-    @staticmethod
-    def select_experts(
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        use_grouped_topk: bool,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-    ):
-        from sglang.srt.layers.fused_moe_triton.fused_moe import (
-            fused_topk,
-            grouped_topk,
-        )
-
-        # DeekSeekv2 uses grouped_top_k
-        if use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-            )
-        elif custom_routing_function is None:
-            topk_weights, topk_ids = fused_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-        else:
-            topk_weights, topk_ids = custom_routing_function(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-
-        return topk_weights, topk_ids
-
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
 
@@ -562,6 +529,7 @@ class FusedMoE(torch.nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
         )
 
         if self.reduce_results and self.tp_size > 1: