sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl
This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in the registry.
- sglang/bench_offline_throughput.py +0 -12
- sglang/bench_one_batch.py +0 -12
- sglang/bench_serving.py +11 -2
- sglang/lang/backend/openai.py +10 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -0
- sglang/srt/layers/attention/flashinfer_backend.py +49 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
- sglang/srt/layers/moe/topk.py +205 -0
- sglang/srt/layers/quantization/__init__.py +3 -3
- sglang/srt/layers/quantization/fp8.py +169 -32
- sglang/srt/layers/quantization/fp8_kernel.py +292 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/torchao_utils.py +11 -15
- sglang/srt/managers/schedule_batch.py +16 -10
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +13 -16
- sglang/srt/managers/tokenizer_manager.py +130 -111
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_loader/loader.py +22 -11
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +19 -0
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +23 -0
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_params.py +9 -2
- sglang/srt/server.py +21 -37
- sglang/srt/utils.py +33 -44
- sglang/test/test_block_fp8.py +341 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
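The most consequential change for downstream imports is the reorganization of the MoE code: `fused_moe_triton` and `ep_moe` now live under `sglang/srt/layers/moe/`, the top-k routing helpers move into the new `sglang/srt/layers/moe/topk.py`, and `fused_moe_patch.py` is removed. A minimal sketch of how an import site changes, assuming only the renamed paths shown in this listing:

```python
# Before (0.4.0.post2): MoE modules sat directly under sglang.srt.layers,
# and expert selection was a staticmethod on FusedMoE.
# from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts, fused_moe

# After (0.4.1.post1): everything is grouped under sglang.srt.layers.moe,
# and expert selection is a standalone helper.
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts, fused_moe
from sglang.srt.layers.moe.topk import select_experts
```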
sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py

@@ -6,14 +6,22 @@ import functools
 import json
 import logging
 import os
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 import triton
 import triton.language as tl
 from vllm import _custom_ops as ops

-from sglang.srt.
+from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+
+not_hip = False
+if not is_hip():
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+    not_hip = True

 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -47,8 +55,14 @@ def fused_moe_kernel(
     stride_bn,
     stride_cm,
     stride_cn,
+    stride_asm,
+    stride_ask,
     stride_bse,
+    stride_bsk,
     stride_bsn,
+    # Block size for block-wise quantization
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -132,8 +146,15 @@ def fused_moe_kernel(
         b_scale = tl.load(b_scale_ptrs)

     if use_fp8_w8a8:
-
-
+        if group_k > 0 and group_n > 0:
+            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            offs_bsn = offs_bn // group_n
+            b_scale_ptrs = (
+                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
+            )
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -164,7 +185,17 @@ def fused_moe_kernel(
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
         elif use_fp8_w8a8:
-
+            if group_k > 0 and group_n > 0:
+                k_start = k * BLOCK_SIZE_K
+                offs_ks = k_start // group_k
+                a_scale = tl.load(
+                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
+                )
+                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+            else:
+                accumulator = tl.dot(a, b, acc=accumulator)
         else:
             accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
@@ -177,7 +208,10 @@ def fused_moe_kernel(
     if use_int8_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
     elif use_fp8_w8a8:
-
+        if group_k > 0 and group_n > 0:
+            accumulator = accumulator.to(compute_type)
+        else:
+            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
     else:
         accumulator = accumulator.to(compute_type)
     # -----------------------------------------------------------
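Taken together, the kernel changes above add an optional block-wise FP8 path: when group_n and group_k are non-zero, each BLOCK_SIZE_K slice of the partial product is rescaled by a per-token, per-K-group activation scale and a per-(group_n × group_k) weight-block scale, instead of a single per-tensor scale applied after the loop. A rough PyTorch reference for that rescaling (shapes and names are illustrative only, not the kernel's API; the expert dimension is omitted):

```python
import torch

def blockwise_fp8_matmul_reference(a_q, a_scale, b_q, b_scale, group_k, group_n):
    """Illustrative reference for the group_n/group_k rescaling in fused_moe_kernel.

    a_q:     [M, K] activations (kept as float here for simplicity)
    a_scale: [M, K // group_k]              per-token, per-K-group scales
    b_q:     [N, K] weights for one expert
    b_scale: [N // group_n, K // group_k]   per-(group_n x group_k)-block scales
    """
    M, K = a_q.shape
    N = b_q.shape[0]
    out = torch.zeros(M, N, dtype=torch.float32, device=a_q.device)
    for ks in range(0, K, group_k):
        # One BLOCK_SIZE_K-sized step of the kernel's inner loop.
        partial = a_q[:, ks : ks + group_k].float() @ b_q[:, ks : ks + group_k].float().t()
        a_s = a_scale[:, ks // group_k]                              # [M]
        b_s = b_scale[:, ks // group_k].repeat_interleave(group_n)   # [N]
        out += partial * a_s[:, None] * b_s[None, :]
    return out
```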
@@ -238,9 +272,33 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-
-
-
+    if not_hip and num_experts >= 224:
+        token_cnts_buffer = torch.empty(
+            (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
+        )
+        cumsum_buffer = torch.empty(
+            num_experts + 1, dtype=torch.int32, device=topk_ids.device
+        )
+
+        sgl_moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+            token_cnts_buffer,
+            cumsum_buffer,
+        )
+    else:
+        ops.moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
     return sorted_ids, expert_ids, num_tokens_post_pad


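moe_align_block_size now dispatches to the sgl_kernel extension on non-HIP builds when the expert count is large (>= 224), passing two extra scratch buffers; otherwise it falls back to the vLLM op as before. A hedged usage sketch — the positional argument order (topk_ids, block_size, num_experts) is assumed from the upstream helper this file is derived from:

```python
import torch
# import path as renamed in this release
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import moe_align_block_size

# 16 tokens, each routed to 8 of 256 experts; values are illustrative.
topk_ids = torch.randint(0, 256, (16, 8), dtype=torch.int32, device="cuda")
sorted_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(topk_ids, 64, 256)
```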
@@ -261,6 +319,7 @@ def invoke_fused_moe_kernel(
     compute_type: tl.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -268,8 +327,16 @@ def invoke_fused_moe_kernel(
     padded_size = 0
     if use_fp8_w8a8:
         padded_size = padding_size
-        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
+        if block_shape is None:
+            A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+        else:
+            assert len(block_shape) == 2
+            block_n, block_k = block_shape[0], block_shape[1]
+            A, A_scale = per_token_group_quant_fp8(A, block_k)
+            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
     elif use_int8_w8a16:
         assert B_scale is not None
     else:
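per_token_group_quant_fp8 comes from the new fp8_kernel.py and quantizes the activations per token in groups of block_k columns, which is exactly the shape contract the asserts above check (A_scale.shape[-1] == cdiv(K, block_k)). A hedged, pure-PyTorch sketch of those semantics (the shipped implementation is a Triton kernel; the clamping details here are assumptions):

```python
import torch

def per_token_group_quant_fp8_reference(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    """Reference semantics only: quantize each contiguous group of `group_size`
    columns of every token to fp8 e4m3, with one scale per group, so the scale
    tensor has shape [..., K // group_size]."""
    assert x.shape[-1] % group_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    grouped = x.view(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    scale = grouped.abs().amax(dim=-1).clamp(min=eps) / fp8_max   # [..., K // group_size]
    x_q = (grouped / scale.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return x_q.view(x.shape), scale
```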
@@ -308,8 +375,13 @@ def invoke_fused_moe_kernel(
         B.stride(1),
         C.stride(1),
         C.stride(2),
-
-
+        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        0 if block_shape is None else block_shape[0],
+        0 if block_shape is None else block_shape[1],
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         top_k=top_k,
         compute_type=compute_type,
@@ -414,8 +486,9 @@ def try_get_optimal_moe_config(
     dtype: Optional[str],
     M: int,
     is_marlin: bool = False,
+    block_shape: Optional[List[int]] = None,
 ):
-    from sglang.srt.layers.fused_moe_triton import get_config
+    from sglang.srt.layers.moe.fused_moe_triton import get_config

     override_config = get_config()
     if override_config:
@@ -432,77 +505,16 @@ def try_get_optimal_moe_config(
     else:
         # Else use the default config
         config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
+    # TODO(HandH1998): Optimize the configs of block-wise quant.
+    # NOTE(HandH1998): For block-wise quant,
+    # BLOCK_K must be divisable by block_shape[1]
+    # BLOCK_N and BLOCK_M has no requirements
+    if block_shape is not None:
+        config["BLOCK_SIZE_N"] = block_shape[0]
+        config["BLOCK_SIZE_K"] = block_shape[1]
     return config


-def fused_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-):
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    M, _ = hidden_states.shape
-
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
-
-    ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
-
-
-# This is used by the Deepseek-V2 model
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-):
-
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
-
-
 def get_config_dtype_str(
     dtype: torch.dtype,
     use_int8_w8a16: Optional[bool] = False,
@@ -531,6 +543,7 @@ def inplace_fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -545,6 +558,7 @@ def inplace_fused_experts(
         w2_scale,
         a1_scale,
         a2_scale,
+        block_shape,
     )


@@ -560,6 +574,7 @@ def inplace_fused_experts_fake(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     pass

@@ -584,6 +599,7 @@ def outplace_fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -598,6 +614,7 @@ def outplace_fused_experts(
         w2_scale,
         a1_scale,
         a2_scale,
+        block_shape,
     )


@@ -613,6 +630,7 @@ def outplace_fused_experts_fake(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)

@@ -638,6 +656,7 @@ def fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ):
     if inplace:
         torch.ops.sglang.inplace_fused_experts(
@@ -652,6 +671,7 @@ def fused_experts(
             w2_scale,
             a1_scale,
             a2_scale,
+            block_shape,
         )
         return hidden_states
     else:
@@ -667,6 +687,7 @@ def fused_experts(
             w2_scale,
             a1_scale,
             a2_scale,
+            block_shape,
         )


@@ -683,6 +704,7 @@ def fused_experts_impl(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ):
     padded_size = padding_size
     if not use_fp8_w8a8:
@@ -714,6 +736,7 @@ def fused_experts_impl(
         (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         config_dtype,
+        block_shape=block_shape,
     )

     config = get_config_func(M)
@@ -786,6 +809,7 @@ def fused_experts_impl(
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            block_shape=block_shape,
         )

         ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -807,6 +831,7 @@ def fused_experts_impl(
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            block_shape=block_shape,
         )

         torch.sum(
@@ -835,6 +860,7 @@ def fused_moe(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -862,6 +888,12 @@ def fused_moe(
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
         w2.
+    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        a1.
+    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        a2.
+    - block_shape: (Optional[List[int]]): Optional block size for block-wise
+        quantization.

     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -869,24 +901,16 @@ def fused_moe(
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

-
-
-
-
-
-
-
-
-
-
-    elif custom_routing_function is None:
-        topk_weights, topk_ids = fused_topk(
-            hidden_states, gating_output, topk, renormalize
-        )
-    else:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states, gating_output, topk, renormalize
-        )
+    topk_weights, topk_ids = select_experts(
+        hidden_states=hidden_states,
+        router_logits=gating_output,
+        use_grouped_topk=use_grouped_topk,
+        top_k=topk,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+    )

     return fused_experts(
         hidden_states,
@@ -901,4 +925,5 @@ def fused_moe(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
+        block_shape=block_shape,
     )
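At the API surface, fused_moe gains a block_shape argument that is threaded down through fused_experts* into invoke_fused_moe_kernel, and routing is delegated to the shared select_experts helper. A hedged call sketch with made-up shapes; the block_shape=[128, 128] value is only an example of a block-quantized configuration and is commented out because it also requires FP8 weights and block scales:

```python
import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

# Illustrative shapes: 4 tokens, hidden 256, 8 experts, intermediate 512, top-2 routing.
x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(8, 2 * 512, 256, dtype=torch.bfloat16, device="cuda")  # gate/up projections
w2 = torch.randn(8, 256, 512, dtype=torch.bfloat16, device="cuda")      # down projection
gating_output = torch.randn(4, 8, dtype=torch.float32, device="cuda")

out = fused_moe(
    x, w1, w2, gating_output,
    topk=2,
    renormalize=True,
    # block_shape=[128, 128],  # new in 0.4.1: block-wise FP8 dequant (needs use_fp8_w8a8 + scales)
)
```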
sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py

@@ -13,6 +13,7 @@ from vllm.distributed import (
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -20,7 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.utils import set_weight_attrs

 if torch.cuda.is_available():
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore

@@ -33,6 +34,7 @@ class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
     GROUP = "group"
+    BLOCK = "block"


 class FusedMoEMethodBase(QuantizeMethodBase):
@@ -106,6 +108,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -117,6 +120,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )

     def forward_cuda(
@@ -130,8 +134,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        topk_weights, topk_ids =
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
@@ -140,6 +145,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )

         return fused_experts(
@@ -197,6 +203,7 @@ class FusedMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ):
         super().__init__()

@@ -208,6 +215,7 @@ class FusedMoE(torch.nn.Module):
         )
         self.top_k = top_k
         self.num_experts = num_experts
+        assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
         self.renormalize = renormalize
@@ -217,6 +225,7 @@ class FusedMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
+        self.correction_bias = correction_bias

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -463,7 +472,10 @@ class FusedMoE(torch.nn.Module):
                 expert_data=expert_data,
                 tp_rank=tp_rank,
             )
-        elif quant_method
+        elif quant_method in [
+            FusedMoeWeightScaleSupported.GROUP.value,
+            FusedMoeWeightScaleSupported.BLOCK.value,
+        ]:
             self._load_model_weight_or_group_weight_scale(
                 shard_id=shard_id,
                 shard_dim=shard_dim,
@@ -503,51 +515,6 @@ class FusedMoE(torch.nn.Module):
             )
             return

-    @staticmethod
-    def select_experts(
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        use_grouped_topk: bool,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-    ):
-        from sglang.srt.layers.fused_moe_triton.fused_moe import (
-            fused_topk,
-            grouped_topk,
-        )
-
-        # DeekSeekv2 uses grouped_top_k
-        if use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-            )
-        elif custom_routing_function is None:
-            topk_weights, topk_ids = fused_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-        else:
-            topk_weights, topk_ids = custom_routing_function(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-
-        return topk_weights, topk_ids
-
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None

@@ -562,6 +529,7 @@ class FusedMoE(torch.nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
         )

         if self.reduce_results and self.tp_size > 1: