sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +49 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +2 -8
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +27 -4
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -4
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +10 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/topk.py +5 -13
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/modelopt_quant.py +8 -4
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +53 -6
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/schedule_batch.py +13 -3
- sglang/srt/managers/scheduler.py +13 -25
- sglang/srt/managers/tokenizer_manager.py +28 -25
- sglang/srt/managers/tp_worker.py +2 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +30 -16
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +41 -23
- sglang/srt/models/deepseek_v2.py +1 -2
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +0 -4
- sglang/srt/models/qwen3_moe.py +1 -6
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +76 -55
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +17 -68
- sglang/test/test_activation.py +50 -1
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +5 -5
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +75 -72
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
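Note: the new file above carries pre-tuned Triton launch parameters for the fused MoE kernel on NVIDIA H20, keyed by batch size (M). As a rough illustration of how such per-M config files are typically consumed, here is a sketch that picks the entry whose key is closest to the actual M; the helper names and the nearest-key selection are assumptions for illustration, not the exact sglang lookup code:

    import json

    def load_moe_config(path: str) -> dict[int, dict]:
        # Keys in the JSON are strings ("1", "2", ..., "4096"); convert to ints.
        with open(path) as f:
            return {int(k): v for k, v in json.load(f).items()}

    def pick_moe_config(configs: dict[int, dict], m: int) -> dict:
        # Pick the tuned kernel config whose batch-size key is closest to M.
        best_key = min(configs, key=lambda k: abs(k - m))
        return configs[best_key]

    # Usage sketch:
    # cfg = pick_moe_config(load_moe_config("E=160,N=320,device_name=NVIDIA_H20-3e.json"), m=200)
    # cfg -> {"BLOCK_SIZE_M": ..., "BLOCK_SIZE_N": ..., "num_warps": ..., ...}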
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -53,9 +53,7 @@ elif _is_hip:
         from aiter import moe_sum
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
-
-    from vllm import _custom_ops as vllm_ops
-    from vllm._custom_ops import scaled_fp8_quant
+
 
 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@@ -63,9 +61,6 @@ if _is_cuda or _is_hip:
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
 
 
 @triton.jit
@@ -533,190 +528,6 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-@triton.jit
-def moe_align_block_size_stage1(
-    topk_ids_ptr,
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    start_idx = pid * tokens_per_thread
-
-    off_c = (pid + 1) * num_experts
-
-    for i in range(tokens_per_thread):
-        if start_idx + i < numel:
-            idx = tl.load(topk_ids_ptr + start_idx + i)
-            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
-            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
-
-
-@triton.jit
-def moe_align_block_size_stage2(
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    last_cnt = 0
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
-        last_cnt = last_cnt + token_cnt
-        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
-
-
-@triton.jit
-def moe_align_block_size_stage3(
-    total_tokens_post_pad_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-):
-    last_cumsum = 0
-    off_cnt = num_experts * num_experts
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
-        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
-        tl.store(cumsum_ptr + i, last_cumsum)
-    tl.store(total_tokens_post_pad_ptr, last_cumsum)
-
-
-@triton.jit
-def moe_align_block_size_stage4(
-    topk_ids_ptr,
-    sorted_token_ids_ptr,
-    expert_ids_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    start_idx = tl.load(cumsum_ptr + pid)
-    end_idx = tl.load(cumsum_ptr + pid + 1)
-
-    for i in range(start_idx, end_idx, block_size):
-        tl.store(expert_ids_ptr + i // block_size, pid)
-
-    start_idx = pid * tokens_per_thread
-    off_t = pid * num_experts
-
-    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
-        expert_id = tl.load(topk_ids_ptr + i)
-        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
-        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
-        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
-        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
-
-
-def moe_align_block_size_triton(
-    topk_ids: torch.Tensor,
-    num_experts: int,
-    block_size: int,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_pad: torch.Tensor,
-) -> None:
-    numel = topk_ids.numel()
-    grid = (num_experts,)
-    tokens_cnts = torch.zeros(
-        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
-    tokens_per_thread = ceil_div(numel, num_experts)
-
-    moe_align_block_size_stage1[grid](
-        topk_ids,
-        tokens_cnts,
-        num_experts,
-        numel,
-        tokens_per_thread,
-    )
-    moe_align_block_size_stage2[grid](
-        tokens_cnts,
-        num_experts,
-    )
-    moe_align_block_size_stage3[(1,)](
-        num_tokens_post_pad,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-    )
-    moe_align_block_size_stage4[grid](
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-        numel,
-        tokens_per_thread,
-    )
-
-
-@triton.jit
-def init_sorted_ids_and_cumsum_buffer_kernel(
-    sorted_ids_ptr,
-    cumsum_buffer_ptr,
-    max_num_tokens_padded,
-    topk_ids_numel,
-    num_experts: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-    ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-
-    sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-
-    if pid < sorted_ids_blocks:
-        mask = offsets < max_num_tokens_padded
-        tl.store(
-            sorted_ids_ptr + offsets,
-            tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
-            mask=mask,
-        )
-    elif pid == sorted_ids_blocks:
-        offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
-        mask_e = offset_e < num_experts + 1
-        tl.store(
-            cumsum_buffer_ptr + offset_e,
-            tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
-            mask=mask_e,
-        )
-
-
-def init_sorted_ids_and_cumsum_buffer(
-    max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
-):
-    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
-    cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
-
-    BLOCK_SIZE = 1024
-    sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-    grid = (sorted_ids_blocks + 1,)
-
-    init_sorted_ids_and_cumsum_buffer_kernel[grid](
-        sorted_ids,
-        cumsum_buffer,
-        max_num_tokens_padded,
-        topk_ids_numel,
-        num_experts,
-        BLOCK_SIZE,
-        next_power_of_2(num_experts + 1),
-    )
-
-    return sorted_ids, cumsum_buffer
-
-
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -766,42 +577,32 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    if enable_moe_align_block_size_triton:
-        sorted_ids.fill_(topk_ids.numel())
-        moe_align_block_size_triton(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-        )
-    else:
-        cumsum_buffer = torch.empty(
-            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-        )
-        token_cnts_buffer = torch.empty(
-            (num_experts + 1) * num_experts,
-            dtype=torch.int32,
-            device=topk_ids.device,
-        )
 
-        # Threshold based on benchmark results
-        fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
-        if not fuse_sorted_ids_padding:
-            sorted_ids.fill_(topk_ids.numel())
+    cumsum_buffer = torch.empty(
+        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+    )
+    token_cnts_buffer = torch.empty(
+        (num_experts + 1) * num_experts,
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
 
-        sgl_moe_align_block_size(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-            token_cnts_buffer,
-            cumsum_buffer,
-            fuse_sorted_ids_padding,
-        )
+    # Threshold based on benchmark results
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
+        sorted_ids.fill_(topk_ids.numel())
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        token_cnts_buffer,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
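Note: with the Triton fallback removed, moe_align_block_size now always goes through the sgl_kernel implementation. For orientation, a pure-Python sketch of what the alignment step computes (a reference illustration with hypothetical helper names, not the kernel itself):

    import torch

    def align_block_size_reference(topk_ids: torch.Tensor, block_size: int, num_experts: int):
        # Count how many tokens were routed to each expert, then round each
        # expert's count up to a multiple of block_size so every expert's
        # segment can be processed in whole BLOCK_SIZE_M tiles.
        counts = torch.bincount(topk_ids.flatten(), minlength=num_experts)
        padded = ((counts + block_size - 1) // block_size) * block_size
        num_tokens_post_pad = int(padded.sum())
        return counts, padded, num_tokens_post_pad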
sglang/srt/layers/moe/topk.py
CHANGED
@@ -15,7 +15,7 @@
 from __future__ import annotations
 
 import math
-from typing import
+from typing import Callable, NamedTuple, Optional
 
 import torch
 import torch.nn.functional as F
@@ -39,10 +39,10 @@ from sglang.srt.utils import (
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
-_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-_is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -54,7 +54,6 @@ if _use_aiter:
         from aiter import biased_grouped_topk as aiter_biased_grouped_topk
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
-
 if _is_npu:
     import torch_npu
 
@@ -387,6 +386,7 @@ def grouped_topk_cpu(
     )
 
 
+@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -482,7 +482,6 @@ def biased_grouped_topk_gpu(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    compiled: bool = not _is_npu,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -535,14 +534,7 @@
         )
         return topk_weights, topk_ids
     else:
-        biased_grouped_topk_fn = (
-            torch.compile(
-                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
-            )
-            if compiled
-            else biased_grouped_topk_impl
-        )
-        return biased_grouped_topk_fn(
+        return biased_grouped_topk_impl(
             hidden_states,
             gating_output,
             correction_bias,
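Note: the call-site wrapping of biased_grouped_topk_impl in torch.compile (guarded by a compiled flag) is replaced by a module-level decorator that relies on torch.compile's disable argument. A minimal, standalone sketch of the same pattern (not sglang code; IS_NPU stands in for sglang's _is_npu check):

    import torch

    IS_NPU = False  # assumption: stand-in for the backend check

    # disable=True turns the decorator into a no-op, so the same function object
    # works unchanged on backends where torch.compile is unsupported or unwanted.
    @torch.compile(dynamic=True, disable=IS_NPU)
    def scaled_softmax(x: torch.Tensor, scale: float) -> torch.Tensor:
        return torch.softmax(x * scale, dim=-1)

    print(scaled_softmax(torch.randn(4, 8), 0.5).shape)  # torch.Size([4, 8])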
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
CHANGED
@@ -28,15 +28,6 @@ if TYPE_CHECKING:
         CompressedTensorsConfig,
     )
 
-_is_cuda = is_cuda()
-_is_npu = is_npu()
-_is_cpu_amx_available = cpu_has_amx_support()
-_is_cpu = is_cpu()
-_is_hip = is_hip()
-
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
-    from vllm import _custom_ops as vllm_ops
-    from vllm._custom_ops import scaled_fp8_quant
 
 try:
     import vllm
@@ -568,6 +559,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )
 
+        from vllm import _custom_ops as vllm_ops
+
         marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
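Note: the module-level vllm import (and the hardware-detection guard around it) is dropped; vllm_ops is now imported inside the one method that needs it. A generic sketch of this deferred-import pattern, with a hypothetical wrapper name:

    def repack_with_vllm(*args, **kwargs):
        # Deferred import: resolving the optional dependency only when this code
        # path runs keeps module import cheap and avoids a hard requirement.
        try:
            from vllm import _custom_ops as vllm_ops
        except ImportError as e:
            raise ImportError("vllm is required for Marlin MoE repacking") from e
        return vllm_ops.gptq_marlin_moe_repack(*args, **kwargs)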
sglang/srt/layers/quantization/modelopt_quant.py
CHANGED
@@ -952,7 +952,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
     ) -> torch.Tensor:
-
         assert activation == "silu", "Only SiLU activation is supported."
 
         if self.enable_flashinfer_moe:
@@ -982,13 +981,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_size=tp_size,
                 tp_rank=tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
-            )
-
+            )[0]
+            if routed_scaling_factor is not None:
+                output *= routed_scaling_factor
+            return output
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
         topk_weights, topk_ids, _ = topk_output
-        return cutlass_moe_fp4(
+        output = cutlass_moe_fp4(
             a=x,
             a1_gscale=layer.w13_input_scale_quant,
             w1_fp4=layer.w13_weight,
@@ -1003,3 +1004,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             params=layer.cutlass_moe_params,
             apply_router_weight_on_input=apply_router_weight_on_input,
         ).to(x.dtype)
+        if routed_scaling_factor is not None:
+            output *= routed_scaling_factor
+        return output
sglang/srt/layers/quantization/utils.py
CHANGED
@@ -17,15 +17,6 @@ from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_np
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
-_is_cuda = is_cuda()
-_is_npu = is_npu()
-_is_cpu_amx_available = cpu_has_amx_support()
-_is_cpu = is_cpu()
-_is_hip = is_hip()
-
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
-    from vllm._custom_ops import scaled_fp8_quant
-
 
 def is_layer_skipped(
     prefix: str,
sglang/srt/layers/radix_attention.py
CHANGED
@@ -12,14 +12,16 @@
 # limitations under the License.
 # ==============================================================================
 """Radix attention."""
+from __future__ import annotations
 
 from enum import Enum
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from torch import nn
 
-
-from sglang.srt.
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import QuantizationConfig
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class AttentionType(Enum):