sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +87 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +26 -7
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +374 -136
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +13 -13
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +25 -27
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/managers/cache_controller.py +237 -204
- sglang/srt/managers/detokenizer_manager.py +48 -2
- sglang/srt/managers/io_struct.py +57 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +94 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +122 -42
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +51 -23
- sglang/srt/mem_cache/hiradix_cache.py +87 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +77 -14
- sglang/srt/mem_cache/memory_pool_host.py +4 -5
- sglang/srt/mem_cache/radix_cache.py +6 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -5
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +38 -13
- sglang/srt/models/gpt_oss.py +2 -15
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +66 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +122 -56
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +73 -5
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Optional, Union

 import numpy as np
 import torch
+import triton
+import triton.language as tl

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -64,6 +66,9 @@ class FlashAttentionMetadata:

     local_attn_metadata: Optional[LocalAttentionMetadata] = None

+    # For sliding window attention topk>1 spec decoding
+    swa_spec_metadata: Optional[FlashAttentionMetadata] = None
+

     # Copied from:
     # https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
@@ -340,6 +345,13 @@ class FlashAttentionBackend(AttentionBackend):
             else None
         )

+        # For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata.
+        # We use `layer.sliding_window_size` to decide whether to use SWA for each layer.
+        self.sliding_window_size = model_runner.sliding_window_size
+        self.has_swa = (
+            self.sliding_window_size is not None and self.sliding_window_size > -1
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
@@ -556,6 +568,12 @@ class FlashAttentionBackend(AttentionBackend):
                     (1, 0),
                 )
                 self.forward_metadata_spec_decode_expand = metadata_expand
+
+                if self.has_swa:
+                    self._init_sliding_window_attn_spec_metadata(
+                        metadata, metadata_expand
+                    )
+
         elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -657,11 +675,10 @@ class FlashAttentionBackend(AttentionBackend):
            # Calculate window size (can be moved to metadata if layer properties don't change)
            # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
            # here is two side inclusive
-            window_size = (
-                (layer.sliding_window_size, 0)
-                if layer.sliding_window_size is not None and layer.sliding_window_size > -1
-                else (-1, -1)
+            is_swa = (
+                layer.sliding_window_size is not None and layer.sliding_window_size > -1
             )
+            window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
            k_descale, v_descale = None, None
            # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
            # has corresponding quantization method so that layer.k_scale is not None,
@@ -684,8 +701,13 @@ class FlashAttentionBackend(AttentionBackend):
            )

            # We do cascade attention for Target Verify with topk > 1
+            # We don't use cascade attention for Sliding Window Attention:
+            # - Different window sizes should be passed in for each q in the first stage of cascade attention, but FA3 interface doesn't support pass in a list of window sizes.
+            # - The overhead of duplicated computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it.
            use_cascade_attn = (
-                forward_batch.forward_mode.is_target_verify()
+                forward_batch.forward_mode.is_target_verify()
+                and self.topk > 1
+                and not is_swa
            )

            # For fa3 interface version compatibility, we put new fields into conditional keyword args
@@ -700,13 +722,18 @@ class FlashAttentionBackend(AttentionBackend):
                cu_seqlens_q = local_metadata.local_query_start_loc
                cache_seqlens = local_metadata.local_seqused_k
                max_seqlen_q = local_metadata.local_max_query_len
-
+            elif is_swa and metadata.swa_spec_metadata is not None:
+                swa_spec_metadata = metadata.swa_spec_metadata
+                page_table = swa_spec_metadata.page_table
+                cu_seqlens_q = swa_spec_metadata.cu_seqlens_q
+                cache_seqlens = swa_spec_metadata.cache_seqlens_int32
+                max_seqlen_q = swa_spec_metadata.max_seq_len_q
+                cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
            else:
                page_table = metadata.page_table
                cu_seqlens_q = metadata.cu_seqlens_q
                cache_seqlens = metadata.cache_seqlens_int32
                max_seqlen_q = metadata.max_seq_len_q
-                max_seqlen_k = metadata.max_seq_len_k
                cu_seqlens_k = metadata.cu_seqlens_k

            # Use Flash Attention for prefill
@@ -1377,6 +1404,32 @@ class FlashAttentionBackend(AttentionBackend):
                ),
            }

+            if self.has_swa:
+                self.target_verify_metadata_topk_swa = {
+                    "cache_seqlens": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_k": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_q": torch.arange(
+                        0,
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "page_table": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens,
+                        self.max_context_len,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                }
+
        self.encoder_metadata = {
            "encoder_page_table": torch.zeros(
                max_bs,
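The preallocated "cu_seqlens_q" buffer above can be a plain torch.arange because, in the expanded top-k verify layout, every expanded sequence contributes exactly one query token (the SWA spec metadata later sets max_seq_len_q = 1), and the varlen attention interface expects cumulative query lengths. A small illustration of that convention (not from the package):

import torch

# With one query token per expanded row, the prefix sums of the query lengths
# are 0, 1, 2, ..., N, i.e. exactly torch.arange(0, N + 1).
seq_lens_q = torch.ones(8, dtype=torch.int32)
cu_seqlens_q = torch.nn.functional.pad(
    torch.cumsum(seq_lens_q, dim=0, dtype=torch.int32), (1, 0)
)
assert torch.equal(cu_seqlens_q, torch.arange(0, 9, dtype=torch.int32))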
@@ -1564,6 +1617,28 @@ class FlashAttentionBackend(AttentionBackend):

            self.target_verify_metadata_topk_normal[bs] = metadata
            self.target_verify_metadata_topk_expand[bs] = metadata_expand
+
+            if self.has_swa:
+                metadata_swa = FlashAttentionMetadata()
+                metadata_swa.cache_seqlens_int32 = (
+                    self.target_verify_metadata_topk_swa["cache_seqlens"][
+                        : bs * self.speculative_num_draft_tokens
+                    ]
+                )
+                metadata_swa.max_seq_len_q = 1
+                metadata_swa.cu_seqlens_q = self.target_verify_metadata_topk_swa[
+                    "cu_seqlens_q"
+                ][: bs * self.speculative_num_draft_tokens + 1]
+                metadata_swa.cu_seqlens_k = self.target_verify_metadata_topk_swa[
+                    "cu_seqlens_k"
+                ][: bs * self.speculative_num_draft_tokens + 1]
+
+                metadata_swa.page_table = self.target_verify_metadata_topk_swa[
+                    "page_table"
+                ][: bs * self.speculative_num_draft_tokens]
+                self.target_verify_metadata_topk_swa[bs] = metadata_swa
+                metadata.swa_spec_metadata = metadata_swa
+
        elif forward_mode.is_draft_extend():
            metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                :bs
@@ -1804,6 +1879,12 @@ class FlashAttentionBackend(AttentionBackend):
                    )
                )

+            if self.has_swa:
+                metadata_swa = self.target_verify_metadata_topk_swa[bs]
+                self._init_sliding_window_attn_spec_metadata(
+                    metadata, metadata_expand, metadata_swa
+                )
+
        elif forward_mode.is_draft_extend():
            metadata = self.draft_extend_metadata[bs]
            metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -2039,6 +2120,159 @@ class FlashAttentionBackend(AttentionBackend):
        lam.local_max_query_len = int(seqlens_q_local_np.max())
        lam.local_max_seq_len = int(seqlens_k_local_np.max())

+    def _init_sliding_window_attn_spec_metadata(
+        self,
+        metadata: FlashAttentionMetadata,
+        metadata_expand: FlashAttentionMetadata,
+        metadata_swa: Optional[FlashAttentionMetadata] = None,
+    ):
+        # TODO: support page_size > 1 for swa spec
+        assert (
+            self.page_size == 1
+        ), "FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention"
+
+        cache_seqlens_int32 = (
+            metadata.cache_seqlens_int32.repeat_interleave(
+                self.speculative_num_draft_tokens
+            )
+            + metadata_expand.cache_seqlens_int32
+        )
+        cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
+        )
+        bs = cache_seqlens_int32.shape[0]
+        page_table = (
+            metadata.page_table.new_zeros(
+                (bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
+            )
+            if metadata_swa is None
+            else metadata_swa.page_table
+        )
+
+        prepare_swa_spec_page_table_triton(
+            page_table,
+            metadata.page_table,
+            metadata_expand.page_table,
+            metadata.cache_seqlens_int32,
+            metadata_expand.cache_seqlens_int32,
+            self.speculative_num_draft_tokens,
+        )
+
+        if metadata_swa is None:
+            metadata_swa = FlashAttentionMetadata()
+            metadata_swa.max_seq_len_q = 1
+            metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q
+            metadata_swa.cache_seqlens_int32 = cache_seqlens_int32
+            metadata_swa.cu_seqlens_k = cu_seqlens_k
+            metadata_swa.page_table = page_table
+        else:
+            metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)
+            metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)
+
+        metadata.swa_spec_metadata = metadata_swa
+
+
+@triton.jit
+def _prepare_swa_spec_page_table_kernel(
+    dst_ptr,
+    src_a_ptr,
+    src_b_ptr,
+    seq_len_a_ptr,
+    seq_len_b_ptr,
+    dst_stride_m,
+    dst_stride_n,
+    a_stride_m,
+    a_stride_n,
+    b_stride_m,
+    b_stride_n,
+    LEN_A: tl.constexpr,
+    LEN_B: tl.constexpr,
+    REPEAT_STEP: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    idx_a = pid_m // REPEAT_STEP
+    idx_b = pid_m
+    seq_len_a = tl.load(seq_len_a_ptr + idx_a)
+    seq_len_b = tl.load(seq_len_b_ptr + idx_b)
+
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    total_len = seq_len_a + seq_len_b
+
+    if pid_n * BLOCK_N >= total_len:
+        return
+
+    mask = offs_n < total_len
+    dst = dst_ptr + pid_m * dst_stride_m + offs_n * dst_stride_n
+
+    if (pid_n + 1) * BLOCK_N < seq_len_a:
+        a_ptr = src_a_ptr + idx_a * a_stride_m + offs_n * a_stride_n
+        a_mask = mask & (offs_n < LEN_A)
+        val = tl.load(a_ptr, mask=a_mask, other=0)
+        tl.store(dst, val, mask=mask)
+    elif pid_n * BLOCK_N >= seq_len_a:
+        offs_b = offs_n - seq_len_a
+        b_ptr = src_b_ptr + idx_b * b_stride_m + offs_b * b_stride_n
+        b_mask = mask & (offs_b < LEN_B)
+        val = tl.load(b_ptr, mask=b_mask, other=0)
+        tl.store(dst, val, mask=mask)
+    else:
+        # mixed part
+        a_offs = offs_n
+        a_mask = (a_offs < seq_len_a) & (a_offs < LEN_A)
+        a_ptr = src_a_ptr + idx_a * a_stride_m + a_offs * a_stride_n
+        a_val = tl.load(a_ptr, mask=a_mask, other=0)
+
+        b_offs = offs_n - seq_len_a
+        b_mask = (b_offs >= 0) & (b_offs < seq_len_b) & (b_offs < LEN_B)
+        b_ptr = src_b_ptr + idx_b * b_stride_m + b_offs * b_stride_n
+        b_val = tl.load(b_ptr, mask=b_mask, other=0)
+
+        result = tl.where(offs_n < seq_len_a, a_val, b_val)
+        tl.store(dst, result, mask=mask)
+
+
+def prepare_swa_spec_page_table_triton(
+    page_table_dst: torch.Tensor,
+    page_table_a: torch.Tensor,
+    page_table_b: torch.Tensor,  # expand page table
+    seq_len_a: torch.Tensor,
+    seq_len_b: torch.Tensor,  # expand seq lens
+    speculative_num_draft_tokens: int,
+):
+    # concat page_table and expand page_table by kv seq length
+    bs = seq_len_a.numel()
+    bs_expand = seq_len_b.numel()
+    assert bs_expand == bs * speculative_num_draft_tokens
+
+    LEN_A = page_table_a.shape[1]
+    LEN_B = page_table_b.shape[1]
+    LEN_OUT = LEN_A + LEN_B
+    REPEAT_STEP = speculative_num_draft_tokens
+    BLOCK_N = 256
+
+    grid = (bs_expand, triton.cdiv(LEN_OUT, BLOCK_N))
+    _prepare_swa_spec_page_table_kernel[grid](
+        page_table_dst,
+        page_table_a,
+        page_table_b,
+        seq_len_a,
+        seq_len_b,
+        page_table_dst.stride(0),
+        page_table_dst.stride(1),
+        page_table_a.stride(0),
+        page_table_a.stride(1),
+        page_table_b.stride(0),
+        page_table_b.stride(1),
+        LEN_A=LEN_A,
+        LEN_B=LEN_B,
+        REPEAT_STEP=REPEAT_STEP,
+        BLOCK_N=BLOCK_N,
+        num_warps=4,
+    )
+

 class FlashAttentionMultiStepBackend:

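For readers, a plain-PyTorch restatement of what prepare_swa_spec_page_table_triton writes into the destination page table, assuming page_size == 1 as the assert in _init_sliding_window_attn_spec_metadata requires. This is an illustrative reference only (the function name below is made up), not code from the package:

import torch

def prepare_swa_spec_page_table_reference(
    page_table_dst: torch.Tensor,
    page_table_a: torch.Tensor,  # per-request prefix page table
    page_table_b: torch.Tensor,  # expanded (per draft token) page table
    seq_len_a: torch.Tensor,
    seq_len_b: torch.Tensor,
    num_draft_tokens: int,
) -> None:
    # Destination row i holds the prefix pages of request i // num_draft_tokens
    # followed by the pages of expanded row i, which is the concatenation the
    # Triton kernel performs block by block.
    for i in range(page_table_b.shape[0]):
        a = i // num_draft_tokens
        len_a = int(seq_len_a[a])
        len_b = int(seq_len_b[i])
        page_table_dst[i, :len_a] = page_table_a[a, :len_a]
        page_table_dst[i, len_a : len_a + len_b] = page_table_b[i, :len_b]

This mirrors how cache_seqlens_int32 is built above: the prefix lengths are repeat_interleave'd over the draft tokens and added to the expanded lengths.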
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -26,11 +26,14 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    is_flashinfer_available,
+    is_sm100_supported,
+    next_power_of_2,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
sglang/srt/layers/attention/flashinfer_mla_backend.py
CHANGED
@@ -28,11 +28,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    is_flashinfer_available,
+    is_sm100_supported,
+    next_power_of_2,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
sglang/srt/layers/attention/hybrid_attn_backend.py
CHANGED
@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


@@ -12,19 +13,27 @@ class HybridAttnBackend(AttentionBackend):
     """Support different backends for prefill and decode."""

     def __init__(
-        self,
+        self,
+        model_runner: ModelRunner,
+        prefill_backend: AttentionBackend,
+        decode_backend: AttentionBackend,
     ):
+        self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.
+        if forward_batch.forward_mode.is_decode_or_idle():
             self.decode_backend.init_forward_metadata(forward_batch)
         else:
             self.prefill_backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+        if self.model_runner.server_args.speculative_algorithm is not None:
+            # When speculative decoding is enabled, we also need to initialize the
+            # prefill backend's cuda graph state to support target_verify.
+            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -36,15 +45,26 @@ class HybridAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        self.decode_backend.init_forward_metadata_capture_cuda_graph(
-            bs,
-            num_tokens,
-            req_pool_indices,
-            seq_lens,
-            encoder_lens,
-            forward_mode,
-            spec_info,
-        )
+        if forward_mode.is_decode_or_idle():
+            self.decode_backend.init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+        else:
+            self.prefill_backend.init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -57,16 +77,28 @@ class HybridAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        self.decode_backend.init_forward_metadata_replay_cuda_graph(
-            bs,
-            req_pool_indices,
-            seq_lens,
-            seq_lens_sum,
-            encoder_lens,
-            forward_mode,
-            spec_info,
-            seq_lens_cpu,
-        )
+        if forward_mode.is_decode_or_idle():
+            self.decode_backend.init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+        else:
+            self.prefill_backend.init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )

     def get_cuda_graph_seq_len_fill_value(self):
         return self.decode_backend.get_cuda_graph_seq_len_fill_value()
sglang/srt/layers/attention/trtllm_mla_backend.py
CHANGED
@@ -51,6 +51,7 @@ class TRTLLMMLADecodeMetadata:

     workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
+    max_seq_len: Optional[int] = None


 class TRTLLMMLABackend(FlashInferMLAAttnBackend):
@@ -207,8 +208,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
        )

        # Custom fast-path for decode/idle.
-
-
+        # Capture with full width so future longer sequences are safe during replay
+        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_blocks_per_seq]

        create_flashmla_kv_indices_triton[(bs,)](
            self.req_to_token,
|
|
217
219
|
None,
|
218
220
|
block_kv_indices,
|
219
221
|
self.req_to_token.stride(0),
|
220
|
-
|
222
|
+
max_blocks_per_seq,
|
221
223
|
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
222
224
|
PAGED_SIZE=self.page_size,
|
223
225
|
)
|
224
226
|
|
227
|
+
# Record the true maximum sequence length for this capture batch so that
|
228
|
+
# the kernel launch path (which requires an int not a tensor) can reuse
|
229
|
+
# it safely during both capture and replay.
|
230
|
+
max_seq_len_val = int(seq_lens.max().item())
|
231
|
+
|
225
232
|
metadata = TRTLLMMLADecodeMetadata(
|
226
|
-
self.decode_cuda_graph_workspace,
|
233
|
+
self.decode_cuda_graph_workspace,
|
234
|
+
block_kv_indices,
|
235
|
+
max_seq_len_val,
|
227
236
|
)
|
228
237
|
self.decode_cuda_graph_metadata[bs] = metadata
|
229
238
|
self.forward_metadata = metadata
|
@@ -268,6 +277,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
268
277
|
PAGED_SIZE=self.page_size,
|
269
278
|
)
|
270
279
|
|
280
|
+
# Update stored max_seq_len so subsequent kernel calls use the correct value
|
281
|
+
# Prefer CPU tensor to avoid GPU synchronization when available.
|
282
|
+
if seq_lens_cpu is not None:
|
283
|
+
metadata.max_seq_len = int(seq_lens_cpu.max().item())
|
284
|
+
else:
|
285
|
+
metadata.max_seq_len = int(seq_lens.max().item())
|
286
|
+
|
271
287
|
def get_cuda_graph_seq_len_fill_value(self) -> int:
|
272
288
|
"""Get the fill value for sequence lengths in CUDA graph."""
|
273
289
|
return 1
|
@@ -295,8 +311,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
295
311
|
forward_batch.seq_lens.device,
|
296
312
|
)
|
297
313
|
|
314
|
+
max_seq_len_val = int(max_seq)
|
298
315
|
self.forward_metadata = TRTLLMMLADecodeMetadata(
|
299
|
-
self.workspace_buffer, block_kv_indices
|
316
|
+
self.workspace_buffer, block_kv_indices, max_seq_len_val
|
300
317
|
)
|
301
318
|
forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
|
302
319
|
|
@@ -471,14 +488,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
471
488
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
472
489
|
block_tables=metadata.block_kv_indices,
|
473
490
|
seq_lens=forward_batch.seq_lens.to(torch.int32),
|
474
|
-
max_seq_len=
|
491
|
+
max_seq_len=metadata.max_seq_len,
|
475
492
|
bmm1_scale=bmm1_scale,
|
476
493
|
)
|
477
494
|
|
478
|
-
#
|
479
|
-
|
480
|
-
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
481
|
-
|
495
|
+
# Reshape output directly without slicing
|
496
|
+
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
482
497
|
return output
|
483
498
|
|
484
499
|
|
sglang/srt/layers/communicator.py
CHANGED
@@ -40,10 +40,9 @@ from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda, is_flashinfer_available
+from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported

 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
sglang/srt/layers/layernorm.py
CHANGED
@@ -53,7 +53,7 @@ elif _is_hip:

 logger = logging.getLogger(__name__)

-if
+if _is_npu:
     import torch_npu


@@ -266,23 +266,48 @@ class GemmaRMSNorm(CustomOp):
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out

+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x

-class Gemma3RMSNorm(CustomOp):
+        x = x.float()
+        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
+        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
+
+class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.zeros(dim))
+        # Re-dispatch

     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

-    def
+    def forward_native(self, x):
         output = self._norm(x.float())
         # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)

+    def forward_cuda(self, x):
+        return self.forward_native(x)
+
+    def forward_npu(self, x):
+        output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
+        return output
+
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"

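For reference, the NPU path added above computes the Gemma-style RMSNorm x * rsqrt(mean(x^2) + eps) scaled by (1 + weight). A plain-PyTorch restatement (illustration only, not the package's NPU code path; the function name is made up):

import torch

def gemma_rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale by (1 + weight);
    # Gemma stores the weight as an offset from 1, hence the (1.0 + weight) factor.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * (1.0 + weight.float())).to(orig_dtype)

# Example: a weight initialized to zeros (as in Gemma3RMSNorm.__init__) leaves the scale at 1.
out = gemma_rms_norm_reference(torch.randn(2, 4), torch.zeros(4))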
sglang/srt/layers/linear.py
CHANGED
@@ -235,8 +235,9 @@ class ReplicatedLinear(LinearBase):
                 loaded_weight = loaded_weight[:1]
             else:
                 raise ValueError(f"{loaded_weight} are not all equal")
-
-
+        assert (
+            param.size() == loaded_weight.size()
+        ), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
         param.data.copy_(loaded_weight)

     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
sglang/srt/layers/logits_processor.py
CHANGED
@@ -61,7 +61,7 @@ class LogitsProcessorOutput:
     hidden_states: Optional[torch.Tensor] = None

     ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
-    #
+    # The log probs of output tokens, if RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs before applying temperature.
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
sglang/srt/layers/moe/cutlass_moe.py
CHANGED
@@ -1,20 +1,12 @@
 """CUTLASS based Fused MoE kernels."""

-import functools
-import json
-import logging
-import os
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
 import torch

 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()
 if _is_cuda:
-    import sgl_kernel
     from sgl_kernel import (
         apply_shuffle_mul_sum,
         cutlass_fp4_group_mm,