sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/constrained/outlines_backend.py +9 -1
- sglang/srt/custom_op.py +40 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/activation.py +10 -5
- sglang/srt/layers/attention/flashinfer_backend.py +284 -39
- sglang/srt/layers/attention/triton_backend.py +71 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
- sglang/srt/layers/layernorm.py +1 -5
- sglang/srt/layers/moe/ep_moe/layer.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
- sglang/srt/layers/moe/topk.py +4 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8_kernel.py +140 -2
- sglang/srt/layers/rotary_embedding.py +1 -3
- sglang/srt/layers/sampler.py +4 -4
- sglang/srt/lora/backend/__init__.py +8 -0
- sglang/srt/lora/backend/base_backend.py +95 -0
- sglang/srt/lora/backend/flashinfer_backend.py +91 -0
- sglang/srt/lora/backend/triton_backend.py +61 -0
- sglang/srt/lora/lora.py +127 -112
- sglang/srt/lora/lora_manager.py +50 -18
- sglang/srt/lora/triton_ops/__init__.py +5 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
- sglang/srt/model_executor/cuda_graph_runner.py +77 -80
- sglang/srt/model_executor/forward_batch_info.py +58 -59
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/server_args.py +13 -2
- sglang/srt/speculative/build_eagle_tree.py +4 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
- sglang/srt/speculative/eagle_utils.py +361 -372
- sglang/srt/speculative/eagle_worker.py +177 -45
- sglang/srt/utils.py +7 -0
- sglang/test/runners.py +2 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +15 -6
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +72 -33
- sglang/srt/layers/custom_op_util.py +0 -25
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/triton_ops/decode_attention.py
CHANGED
@@ -49,11 +49,9 @@ def _fwd_kernel_stage1(
     K_Buffer,
     V_Buffer,
     sm_scale,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     Att_Out,
-    stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -82,8 +80,9 @@ def _fwd_kernel_stage1(
     offs_dv = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lk
     mask_dv = offs_dv < Lv
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
     q = tl.load(Q + off_q, mask=mask_d, other=0.0)
@@ -100,7 +99,7 @@ def _fwd_kernel_stage1(
     for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
         offs_n = start_n + tl.arange(0, BLOCK_N)
         kv_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            kv_indices + cur_batch_kv_start_idx + offs_n,
             mask=offs_n < split_kv_end,
             other=0,
         )
@@ -173,19 +172,21 @@ def _decode_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     num_kv_splits,
     sm_scale,
     logit_cap,
 ):
     BLOCK = 64
+    # [TODO] work around SGPR limit on MI3xx
+    if is_hip_:
+        BLOCK = 8
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
-    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
 
     grid = (batch, head_num, NUM_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -194,6 +195,8 @@ def _decode_att_m_fwd(
         num_warps = 4
     else:
         num_warps = 2
+    if is_hip_:
+        num_warps = 1
 
     BLOCK_DMODEL = triton.next_power_of_2(Lk)
     BLOCK_DV = triton.next_power_of_2(Lv)
@@ -203,11 +206,9 @@ def _decode_att_m_fwd(
         k_buffer,
         v_buffer,
         sm_scale,
-        Req_to_tokens,
-        B_req_idx,
-        B_Seqlen,
+        kv_indptr,
+        kv_indices,
         att_out,
-        Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -236,11 +237,9 @@ def _fwd_grouped_kernel_stage1(
     K_Buffer,
     V_Buffer,
     sm_scale,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     Att_Out,
-    stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -279,8 +278,9 @@ def _fwd_grouped_kernel_stage1(
     offs_dv = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lk
     mask_dv = offs_dv < Lv
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
 
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
     q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
@@ -307,7 +307,7 @@ def _fwd_grouped_kernel_stage1(
     for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
         offs_n = start_n + tl.arange(0, BLOCK_N)
         kv_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            kv_indices + cur_batch_kv_start_idx + offs_n,
             mask=offs_n < split_kv_end,
             other=0,
         )
@@ -395,9 +395,8 @@ def _decode_grouped_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     num_kv_splits,
     sm_scale,
     logit_cap,
@@ -421,7 +420,7 @@ def _decode_grouped_att_m_fwd(
     BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = 16
@@ -433,21 +432,21 @@ def _decode_grouped_att_m_fwd(
     )
 
     extra_kargs = {}
+    num_stages = 2
     if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
-        extra_kargs = {"waves_per_eu":
+        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
+        num_stages = 1
 
     _fwd_grouped_kernel_stage1[grid](
         q,
         k_buffer,
         v_buffer,
         sm_scale,
-        Req_to_tokens,
-        B_req_idx,
-        B_Seqlen,
+        kv_indptr,
+        kv_indices,
         att_out,
-        Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -467,7 +466,7 @@ def _decode_grouped_att_m_fwd(
         NUM_KV_SPLITS=NUM_KV_SPLITS,
         logit_cap=logit_cap,
         num_warps=4,
-        num_stages=
+        num_stages=num_stages,
         Lk=Lk,
         Lv=Lv,
         **extra_kargs,
@@ -478,7 +477,7 @@ def _decode_grouped_att_m_fwd(
 def _fwd_kernel_stage2(
     Mid_O,
     O,
-    B_Seqlen,
+    kv_indptr,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -491,7 +490,9 @@ def _fwd_kernel_stage2(
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
 
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
+        kv_indptr + cur_batch
+    )
 
     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
@@ -535,7 +536,7 @@ def _decode_softmax_reducev_fwd(
     q,
     o,
     v_buffer,
-    b_seq_len,
+    kv_indptr,
     num_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
@@ -554,7 +555,7 @@ def _decode_softmax_reducev_fwd(
     _fwd_kernel_stage2[grid](
         logits,
         o,
-        b_seq_len,
+        kv_indptr,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -574,9 +575,8 @@ def decode_attention_fwd_normal(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -587,14 +587,13 @@ def decode_attention_fwd_normal(
         k_buffer,
         v_buffer,
         attn_logits,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         num_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
 
 
 def decode_attention_fwd_grouped(
@@ -602,9 +601,8 @@ def decode_attention_fwd_grouped(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -615,14 +613,13 @@ def decode_attention_fwd_grouped(
         k_buffer,
         v_buffer,
         attn_logits,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         num_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
 
 
 def decode_attention_fwd(
@@ -630,9 +627,8 @@ def decode_attention_fwd(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -648,9 +644,8 @@ def decode_attention_fwd(
             k_buffer,
             v_buffer,
             o,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             num_kv_splits,
             sm_scale,
@@ -663,9 +658,8 @@ def decode_attention_fwd(
             k_buffer,
            v_buffer,
             o,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             num_kv_splits,
             sm_scale,
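
The hunks above replace the per-request `Req_to_tokens` / `B_req_idx` / `B_Seqlen` lookup with a flat, CSR-style pair: `kv_indptr` holds one start offset per request plus a trailing sentinel, and `kv_indices` holds the KV-cache slot indices of every request laid out back to back, which is why the kernels can now derive both the start offset and the sequence length from `kv_indptr` alone. A minimal NumPy sketch of that layout (the array contents are invented for illustration; only the `kv_indptr`/`kv_indices` names and the indptr arithmetic come from the diff):

```python
import numpy as np

# Three requests with 3, 1, and 4 cached tokens respectively.
seq_lens = np.array([3, 1, 4])
kv_indptr = np.concatenate(([0], np.cumsum(seq_lens)))  # [0, 3, 4, 8]
kv_indices = np.array([10, 11, 12, 40, 7, 8, 9, 13])    # flat KV-cache slot ids

def kv_slots_for_request(cur_batch: int) -> np.ndarray:
    """Mirror of the kernel's tl.load(kv_indptr + cur_batch) bookkeeping."""
    start = kv_indptr[cur_batch]                 # cur_batch_kv_start_idx
    seq_len = kv_indptr[cur_batch + 1] - start   # cur_batch_seq_len
    return kv_indices[start : start + seq_len]

assert list(kv_slots_for_request(0)) == [10, 11, 12]
assert list(kv_slots_for_request(1)) == [40]
assert kv_indptr[-1] == len(kv_indices)
```
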
sglang/srt/layers/layernorm.py
CHANGED
@@ -29,14 +29,11 @@ if is_cuda_available():
         rmsnorm,
     )
 
-from vllm.model_executor.custom_op import CustomOp
-
-from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.custom_op import CustomOp
 
 logger = logging.getLogger(__name__)
 
 
-@register_custom_op("sglang_rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -79,7 +76,6 @@ class RMSNorm(CustomOp):
         return x, residual
 
 
-@register_custom_op("sglang_gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,

sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
@@ -407,7 +406,6 @@ class EPMoE(torch.nn.Module):
         param_data[expert_id] = loaded_weight
 
 
-@register_custom_op("sglang_unquantized_ep_moe")
 class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
     def create_weights(
         self,
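
Both layernorm.py and ep_moe/layer.py swap vLLM's `CustomOp` and the `register_custom_op` decorator for the new in-tree `sglang.srt.custom_op.CustomOp` (a 40-line module added in this release; see the file list, which also shows `custom_op_util.py` being deleted). The sketch below illustrates the backend-dispatch pattern such a base class typically implements; it is an illustration only, not the actual contents of `sglang/srt/custom_op.py`, and the `forward_native` / `forward_cuda` / `forward_hip` method names are assumed from the vLLM-style convention the old import followed:

```python
import torch
from torch import nn


class CustomOp(nn.Module):
    """Illustrative base class: choose a backend-specific forward once, at init."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        # Pure-PyTorch fallback, used when no GPU kernel is available.
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # Many ops share one GPU path, so default ROCm to the CUDA implementation.
        return self.forward_cuda(*args, **kwargs)

    def dispatch_forward(self):
        if not torch.cuda.is_available():
            return self.forward_native
        if torch.version.hip is not None:
            return self.forward_hip
        return self.forward_cuda
```

Subclasses such as `RMSNorm` or `UnquantizedEPMoEMethod` then only supply the per-backend methods, which is presumably why the string-keyed `register_custom_op("sglang_rmsnorm")` decorators are no longer needed.
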
sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
ADDED
@@ -0,0 +1,164 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "8": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "32": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "64": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    }
+}

sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
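
Both JSON files follow the fused-MoE Triton tuning format: each top-level key is a token count (M) the kernel was benchmarked at, and the value is the winning tile/launch configuration for that M (the AMD file additionally carries the ROCm-specific `waves_per_eu` field). At run time the launcher loads the file whose name matches the expert count, intermediate size, device, dtype, and block shape, then uses the entry whose key is closest to the actual token count. A rough sketch of that lookup; `load_moe_config` and `pick_kernel_config` are hypothetical helper names, and the nearest-key selection is an assumption based on how Triton fused-MoE launchers conventionally choose a config:

```python
import json
from pathlib import Path
from typing import Any, Dict


def load_moe_config(path: str) -> Dict[int, Dict[str, Any]]:
    """Parse one tuned-config JSON (keys are token counts stored as strings)."""
    raw = json.loads(Path(path).read_text())
    return {int(num_tokens): cfg for num_tokens, cfg in raw.items()}


def pick_kernel_config(configs: Dict[int, Dict[str, Any]], m: int) -> Dict[str, Any]:
    """Return the entry tuned for the token count closest to m."""
    return configs[min(configs, key=lambda num_tokens: abs(num_tokens - m))]


# Toy example shaped like the files above (values abbreviated).
configs = {
    1: {"BLOCK_SIZE_M": 16, "num_warps": 4},
    64: {"BLOCK_SIZE_M": 16, "num_warps": 4},
    1024: {"BLOCK_SIZE_M": 16, "num_warps": 4},
}
assert pick_kernel_config(configs, 100) == configs[64]
```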
|