sglang 0.5.1.post2__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/conversation.py +38 -5
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +27 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +16 -7
- sglang/srt/layers/attention/ascend_backend.py +218 -111
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/ep_moe/layer.py +1 -7
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/mxfp4.py +16 -23
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/managers/cache_controller.py +223 -156
- sglang/srt/managers/detokenizer_manager.py +5 -0
- sglang/srt/managers/io_struct.py +30 -0
- sglang/srt/managers/scheduler.py +58 -7
- sglang/srt/managers/tokenizer_manager.py +36 -3
- sglang/srt/mem_cache/hicache_storage.py +31 -20
- sglang/srt/mem_cache/hiradix_cache.py +12 -3
- sglang/srt/mem_cache/memory_pool.py +73 -14
- sglang/srt/mem_cache/memory_pool_host.py +3 -2
- sglang/srt/mem_cache/radix_cache.py +1 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +12 -3
- sglang/srt/models/gpt_oss.py +2 -1
- sglang/srt/models/qwen2_5_vl.py +1 -0
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/server_args.py +10 -1
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +59 -5
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +4 -3
- {sglang-0.5.1.post2.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +57 -54
- {sglang-0.5.1.post2.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Optional, Union

 import numpy as np
 import torch
+import triton
+import triton.language as tl

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -64,6 +66,9 @@ class FlashAttentionMetadata:

     local_attn_metadata: Optional[LocalAttentionMetadata] = None

+    # For sliding window attention topk>1 spec decoding
+    swa_spec_metadata: Optional[FlashAttentionMetadata] = None
+

     # Copied from:
     # https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
@@ -340,6 +345,13 @@ class FlashAttentionBackend(AttentionBackend):
             else None
         )

+        # For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata.
+        # We use `layer.sliding_window_size` to decide whether to use SWA for each layer.
+        self.sliding_window_size = model_runner.sliding_window_size
+        self.has_swa = (
+            self.sliding_window_size is not None and self.sliding_window_size > -1
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
@@ -556,6 +568,12 @@ class FlashAttentionBackend(AttentionBackend):
                 (1, 0),
             )
             self.forward_metadata_spec_decode_expand = metadata_expand
+
+            if self.has_swa:
+                self._init_sliding_window_attn_spec_metadata(
+                    metadata, metadata_expand
+                )
+
         elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -657,11 +675,10 @@ class FlashAttentionBackend(AttentionBackend):
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
-        window_size = (
-            (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
-            else (-1, -1)
+        is_swa = (
+            layer.sliding_window_size is not None and layer.sliding_window_size > -1
         )
+        window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
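For readers following the hunk above: the change only lifts the existing sliding-window check into an `is_swa` flag and keeps FA3's two-sided `(left, right)` window tuple, where `(-1, -1)` means no window. A minimal standalone sketch of that selection logic (the sample values are hypothetical, not taken from the diff):

    def select_window_size(sliding_window_size):
        # A layer uses SWA only when its sliding_window_size is set and non-negative;
        # otherwise FA3 receives (-1, -1), i.e. unrestricted (full) attention.
        is_swa = sliding_window_size is not None and sliding_window_size > -1
        return (sliding_window_size, 0) if is_swa else (-1, -1)

    print(select_window_size(4096))  # (4096, 0): attend to the previous 4096 tokens only
    print(select_window_size(None))  # (-1, -1): full attention
    print(select_window_size(-1))    # (-1, -1): full attention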
@@ -684,8 +701,13 @@ class FlashAttentionBackend(AttentionBackend):
         )

         # We do cascade attention for Target Verify with topk > 1
+        # We don't use cascade attention for Sliding Window Attention:
+        # - Different window sizes should be passed in for each q in the first stage of cascade attention, but FA3 interface doesn't support pass in a list of window sizes.
+        # - The overhead of duplicated computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it.
         use_cascade_attn = (
-            forward_batch.forward_mode.is_target_verify()
+            forward_batch.forward_mode.is_target_verify()
+            and self.topk > 1
+            and not is_swa
         )

         # For fa3 interface version compatibility, we put new fields into conditional keyword args
@@ -700,13 +722,18 @@ class FlashAttentionBackend(AttentionBackend):
             cu_seqlens_q = local_metadata.local_query_start_loc
             cache_seqlens = local_metadata.local_seqused_k
             max_seqlen_q = local_metadata.local_max_query_len
-            max_seqlen_k = local_metadata.local_max_seq_len
+        elif is_swa and metadata.swa_spec_metadata is not None:
+            swa_spec_metadata = metadata.swa_spec_metadata
+            page_table = swa_spec_metadata.page_table
+            cu_seqlens_q = swa_spec_metadata.cu_seqlens_q
+            cache_seqlens = swa_spec_metadata.cache_seqlens_int32
+            max_seqlen_q = swa_spec_metadata.max_seq_len_q
+            cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
         else:
             page_table = metadata.page_table
             cu_seqlens_q = metadata.cu_seqlens_q
             cache_seqlens = metadata.cache_seqlens_int32
             max_seqlen_q = metadata.max_seq_len_q
-            max_seqlen_k = metadata.max_seq_len_k
             cu_seqlens_k = metadata.cu_seqlens_k

         # Use Flash Attention for prefill
@@ -1377,6 +1404,32 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }

+        if self.has_swa:
+            self.target_verify_metadata_topk_swa = {
+                "cache_seqlens": torch.zeros(
+                    max_bs * self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "page_table": torch.zeros(
+                    max_bs * self.speculative_num_draft_tokens,
+                    self.max_context_len,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
         self.encoder_metadata = {
             "encoder_page_table": torch.zeros(
                 max_bs,
@@ -1564,6 +1617,28 @@ class FlashAttentionBackend(AttentionBackend):

             self.target_verify_metadata_topk_normal[bs] = metadata
             self.target_verify_metadata_topk_expand[bs] = metadata_expand
+
+            if self.has_swa:
+                metadata_swa = FlashAttentionMetadata()
+                metadata_swa.cache_seqlens_int32 = (
+                    self.target_verify_metadata_topk_swa["cache_seqlens"][
+                        : bs * self.speculative_num_draft_tokens
+                    ]
+                )
+                metadata_swa.max_seq_len_q = 1
+                metadata_swa.cu_seqlens_q = self.target_verify_metadata_topk_swa[
+                    "cu_seqlens_q"
+                ][: bs * self.speculative_num_draft_tokens + 1]
+                metadata_swa.cu_seqlens_k = self.target_verify_metadata_topk_swa[
+                    "cu_seqlens_k"
+                ][: bs * self.speculative_num_draft_tokens + 1]
+
+                metadata_swa.page_table = self.target_verify_metadata_topk_swa[
+                    "page_table"
+                ][: bs * self.speculative_num_draft_tokens]
+                self.target_verify_metadata_topk_swa[bs] = metadata_swa
+                metadata.swa_spec_metadata = metadata_swa
+
         elif forward_mode.is_draft_extend():
             metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                 :bs
@@ -1804,6 +1879,12 @@ class FlashAttentionBackend(AttentionBackend):
                 )
             )

+            if self.has_swa:
+                metadata_swa = self.target_verify_metadata_topk_swa[bs]
+                self._init_sliding_window_attn_spec_metadata(
+                    metadata, metadata_expand, metadata_swa
+                )
+
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
             metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -2039,6 +2120,159 @@ class FlashAttentionBackend(AttentionBackend):
         lam.local_max_query_len = int(seqlens_q_local_np.max())
         lam.local_max_seq_len = int(seqlens_k_local_np.max())

+    def _init_sliding_window_attn_spec_metadata(
+        self,
+        metadata: FlashAttentionMetadata,
+        metadata_expand: FlashAttentionMetadata,
+        metadata_swa: Optional[FlashAttentionMetadata] = None,
+    ):
+        # TODO: support page_size > 1 for swa spec
+        assert (
+            self.page_size == 1
+        ), "FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention"
+
+        cache_seqlens_int32 = (
+            metadata.cache_seqlens_int32.repeat_interleave(
+                self.speculative_num_draft_tokens
+            )
+            + metadata_expand.cache_seqlens_int32
+        )
+        cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
+        )
+        bs = cache_seqlens_int32.shape[0]
+        page_table = (
+            metadata.page_table.new_zeros(
+                (bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
+            )
+            if metadata_swa is None
+            else metadata_swa.page_table
+        )
+
+        prepare_swa_spec_page_table_triton(
+            page_table,
+            metadata.page_table,
+            metadata_expand.page_table,
+            metadata.cache_seqlens_int32,
+            metadata_expand.cache_seqlens_int32,
+            self.speculative_num_draft_tokens,
+        )
+
+        if metadata_swa is None:
+            metadata_swa = FlashAttentionMetadata()
+            metadata_swa.max_seq_len_q = 1
+            metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q
+            metadata_swa.cache_seqlens_int32 = cache_seqlens_int32
+            metadata_swa.cu_seqlens_k = cu_seqlens_k
+            metadata_swa.page_table = page_table
+        else:
+            metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)
+            metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)
+
+        metadata.swa_spec_metadata = metadata_swa
+
+
+@triton.jit
+def _prepare_swa_spec_page_table_kernel(
+    dst_ptr,
+    src_a_ptr,
+    src_b_ptr,
+    seq_len_a_ptr,
+    seq_len_b_ptr,
+    dst_stride_m,
+    dst_stride_n,
+    a_stride_m,
+    a_stride_n,
+    b_stride_m,
+    b_stride_n,
+    LEN_A: tl.constexpr,
+    LEN_B: tl.constexpr,
+    REPEAT_STEP: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    idx_a = pid_m // REPEAT_STEP
+    idx_b = pid_m
+    seq_len_a = tl.load(seq_len_a_ptr + idx_a)
+    seq_len_b = tl.load(seq_len_b_ptr + idx_b)
+
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    total_len = seq_len_a + seq_len_b
+
+    if pid_n * BLOCK_N >= total_len:
+        return
+
+    mask = offs_n < total_len
+    dst = dst_ptr + pid_m * dst_stride_m + offs_n * dst_stride_n
+
+    if (pid_n + 1) * BLOCK_N < seq_len_a:
+        a_ptr = src_a_ptr + idx_a * a_stride_m + offs_n * a_stride_n
+        a_mask = mask & (offs_n < LEN_A)
+        val = tl.load(a_ptr, mask=a_mask, other=0)
+        tl.store(dst, val, mask=mask)
+    elif pid_n * BLOCK_N >= seq_len_a:
+        offs_b = offs_n - seq_len_a
+        b_ptr = src_b_ptr + idx_b * b_stride_m + offs_b * b_stride_n
+        b_mask = mask & (offs_b < LEN_B)
+        val = tl.load(b_ptr, mask=b_mask, other=0)
+        tl.store(dst, val, mask=mask)
+    else:
+        # mixed part
+        a_offs = offs_n
+        a_mask = (a_offs < seq_len_a) & (a_offs < LEN_A)
+        a_ptr = src_a_ptr + idx_a * a_stride_m + a_offs * a_stride_n
+        a_val = tl.load(a_ptr, mask=a_mask, other=0)
+
+        b_offs = offs_n - seq_len_a
+        b_mask = (b_offs >= 0) & (b_offs < seq_len_b) & (b_offs < LEN_B)
+        b_ptr = src_b_ptr + idx_b * b_stride_m + b_offs * b_stride_n
+        b_val = tl.load(b_ptr, mask=b_mask, other=0)
+
+        result = tl.where(offs_n < seq_len_a, a_val, b_val)
+        tl.store(dst, result, mask=mask)
+
+
+def prepare_swa_spec_page_table_triton(
+    page_table_dst: torch.Tensor,
+    page_table_a: torch.Tensor,
+    page_table_b: torch.Tensor,  # expand page table
+    seq_len_a: torch.Tensor,
+    seq_len_b: torch.Tensor,  # expand seq lens
+    speculative_num_draft_tokens: int,
+):
+    # concat page_table and expand page_table by kv seq length
+    bs = seq_len_a.numel()
+    bs_expand = seq_len_b.numel()
+    assert bs_expand == bs * speculative_num_draft_tokens
+
+    LEN_A = page_table_a.shape[1]
+    LEN_B = page_table_b.shape[1]
+    LEN_OUT = LEN_A + LEN_B
+    REPEAT_STEP = speculative_num_draft_tokens
+    BLOCK_N = 256
+
+    grid = (bs_expand, triton.cdiv(LEN_OUT, BLOCK_N))
+    _prepare_swa_spec_page_table_kernel[grid](
+        page_table_dst,
+        page_table_a,
+        page_table_b,
+        seq_len_a,
+        seq_len_b,
+        page_table_dst.stride(0),
+        page_table_dst.stride(1),
+        page_table_a.stride(0),
+        page_table_a.stride(1),
+        page_table_b.stride(0),
+        page_table_b.stride(1),
+        LEN_A=LEN_A,
+        LEN_B=LEN_B,
+        REPEAT_STEP=REPEAT_STEP,
+        BLOCK_N=BLOCK_N,
+        num_warps=4,
+    )
+

 class FlashAttentionMultiStepBackend:

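The `_init_sliding_window_attn_spec_metadata` helper and Triton kernel added above build, for every draft-token row, a page table that is the request's prefix page table followed by that row's expanded draft entries. A slow pure-PyTorch equivalent is sketched below purely to make that layout explicit; the function name and the Python loop are illustrative, not part of sglang:

    import torch

    def swa_spec_page_table_reference(
        page_table_dst: torch.Tensor,  # [bs * num_draft, LEN_A + LEN_B]
        page_table_a: torch.Tensor,    # [bs, LEN_A] prefix page table
        page_table_b: torch.Tensor,    # [bs * num_draft, LEN_B] expanded draft page table
        seq_len_a: torch.Tensor,       # [bs] prefix KV lengths
        seq_len_b: torch.Tensor,       # [bs * num_draft] expanded KV lengths
        num_draft: int,
    ) -> None:
        # Row i belongs to request i // num_draft: copy that request's valid prefix
        # pages first, then append row i's valid expanded pages directly after them.
        for i in range(page_table_dst.shape[0]):
            r = i // num_draft
            la = int(seq_len_a[r])
            lb = int(seq_len_b[i])
            page_table_dst[i, :la] = page_table_a[r, :la]
            page_table_dst[i, la : la + lb] = page_table_b[i, :lb]

The combined per-row KV length (`cache_seqlens_int32` in the helper) is likewise `seq_len_a[i // num_draft] + seq_len_b[i]`, which is what the new `cu_seqlens_k` is accumulated from.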
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -26,11 +26,14 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    is_flashinfer_available,
+    is_sm100_supported,
+    next_power_of_2,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention

sglang/srt/layers/attention/flashinfer_mla_backend.py
CHANGED
@@ -28,11 +28,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    is_flashinfer_available,
+    is_sm100_supported,
+    next_power_of_2,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention

sglang/srt/layers/communicator.py
CHANGED
@@ -40,10 +40,9 @@ from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda, is_flashinfer_available
+from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported

 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()

sglang/srt/layers/moe/cutlass_moe.py
CHANGED
@@ -1,20 +1,12 @@
 """CUTLASS based Fused MoE kernels."""

-import functools
-import json
-import logging
-import os
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
 import torch

 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()
 if _is_cuda:
-    import sgl_kernel
     from sgl_kernel import (
         apply_shuffle_mul_sum,
         cutlass_fp4_group_mm,

sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -248,7 +248,6 @@ class EPMoE(FusedMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -304,7 +303,6 @@ class EPMoE(FusedMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
@@ -667,7 +665,6 @@ class DeepEPMoE(EPMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         dispose_tensor(hidden_states_fp8[0])

@@ -708,9 +705,7 @@ class DeepEPMoE(EPMoE):
             (
                 down_input_scale
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.
-                down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
             ),
         )
         down_output = torch.empty(
@@ -722,7 +717,6 @@ class DeepEPMoE(EPMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )

         return down_output
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
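The JSON added above is a Triton tuning table for the fused MoE kernel on an A100-SXM4-80GB (E=257, N=64), keyed by an M bucket corresponding to the token count the kernel sees. This diff does not show how sglang selects an entry at runtime; the snippet below only illustrates the nearest-bucket lookup idea and is an assumption, not sglang's actual loader:

    import json

    def nearest_moe_config(config_path: str, m: int) -> dict:
        # Keys are M buckets; pick the bucket closest to the observed M.
        with open(config_path) as f:
            table = {int(k): v for k, v in json.load(f).items()}
        return table[min(table, key=lambda key: abs(key - m))]

    # Hypothetical usage: m=45 falls nearest to the "48" bucket above.
    # nearest_moe_config("E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json", 45)
    # -> {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, ...}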
sglang/srt/layers/moe/topk.py
CHANGED
@@ -304,7 +304,7 @@ class TopK(CustomOp):
         global_num_experts = router_logits.shape[-1]

         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256 and self.topk_config.renormalize is
+        if global_num_experts == 256 and self.topk_config.renormalize is True:

             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
             router_logits = router_logits.to(torch.float32)