sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in one of the supported public registries, and is provided for informational purposes only.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
@@ -657,12 +657,16 @@ class FlashAttentionBackend(AttentionBackend):
        )
        k_descale, v_descale = None, None
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-
-
-
-
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
        q = q.to(self.kv_cache_dtype)
+        q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+        k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
        causal = not layer.is_cross_attention

        # Check if we should use local attention
@@ -776,8 +780,8 @@ class FlashAttentionBackend(AttentionBackend):

                output, lse, *rest = flash_attn_varlen_func(
                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                    max_seqlen_q=metadata.max_seq_len_q,
@@ -790,8 +794,8 @@ class FlashAttentionBackend(AttentionBackend):
                # MHA for extend part of sequence without attending prefix kv cache
                output, lse, *rest = flash_attn_varlen_func(
                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k=metadata.cu_seqlens_q,
                    max_seqlen_q=metadata.max_seq_len_q,
@@ -803,7 +807,9 @@ class FlashAttentionBackend(AttentionBackend):
            return output, lse
        else:
            # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).to(q.dtype)
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
@@ -933,14 +939,16 @@ class FlashAttentionBackend(AttentionBackend):

        k_descale, v_descale = None, None
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
            if layer.k_scale is not None:
                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                k_descale = layer.k_scale.expand(descale_shape)
                v_descale = layer.v_scale.expand(descale_shape)
        q = q.to(self.kv_cache_dtype)
-
+        q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+        k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
        if not self.use_mla:
            # Do multi-head attention

@@ -1048,7 +1056,9 @@ class FlashAttentionBackend(AttentionBackend):
            o = result
        else:
            # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+                q.dtype
+            )
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
@@ -1120,7 +1130,7 @@ class FlashAttentionBackend(AttentionBackend):

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Initialize CUDA graph state for the attention backend.

        Args:
@@ -1704,14 +1714,15 @@ class FlashAttentionBackend(AttentionBackend):

            # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
            metadata_expand = self.target_verify_metadata_topk_expand[bs]
+
            # metadata_expand.max_seq_len_q = 1, already set in capture
            # metadata_expand.cu_seqlens_q already set in capture
-
            offsets = torch.arange(
                self.speculative_num_draft_tokens, device=device
            ).unsqueeze(
                0
            ) # shape: (1, self.speculative_num_draft_tokens)
+
            cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
            cum_len = torch.nn.functional.pad(
                torch.cumsum(
@@ -1728,17 +1739,20 @@ class FlashAttentionBackend(AttentionBackend):
            ).view(1, -1)
            # avoid extracting padded seq indices which will be out of boundary
            mask_extraction_indices[
-                :,
+                :,
+                spec_info.positions.numel() * self.speculative_num_draft_tokens :,
            ].fill_(0)
-
            mask = spec_info.custom_mask[mask_extraction_indices].view(
                -1, self.speculative_num_draft_tokens
            ) # (bsz * draft_num, draft_num)
+
            col_indices = offsets.expand(
                mask.shape[0], self.speculative_num_draft_tokens
            )
            keys = torch.where(
-                mask,
+                mask,
+                col_indices,
+                col_indices + self.speculative_num_draft_tokens,
            )
            _, sort_order = torch.sort(keys, dim=1)

@@ -1747,6 +1761,7 @@ class FlashAttentionBackend(AttentionBackend):
                .gather(1, cols)
                .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
            ) # (bsz, draft_num)
+
            metadata_expand.page_table.copy_(
                non_masked_page_table.gather(1, sort_order)
            )
@@ -1758,6 +1773,7 @@ class FlashAttentionBackend(AttentionBackend):
                    dtype=torch.int32,
                )
            )
+
        elif forward_mode.is_draft_extend():
            metadata = self.draft_extend_metadata[bs]
            metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -1767,7 +1783,11 @@ class FlashAttentionBackend(AttentionBackend):
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
            )
            accept_length = spec_info.accept_length[:bs]
-
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
+
            metadata.cu_seqlens_q[1:].copy_(
                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
            )
@@ -1807,7 +1827,7 @@ class FlashAttentionBackend(AttentionBackend):

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for sequence length in CUDA graph."""
-        return
+        return 1

    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
@@ -1999,9 +2019,9 @@ class FlashAttentionMultiStepBackend:
        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_forward_metadata(forward_batch)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs)
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
        )

    def init_cuda_graph_state(
-        self,
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        if kv_indices_buf is None:
            cuda_graph_kv_indices = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len,),
                dtype=torch.int32,
                device="cuda",
            )
@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len),
                dtype=torch.uint8,
                device="cuda",
            )
@@ -440,7 +443,7 @@ class FlashInferAttnBackend(AttentionBackend):
            raise ValueError("Invalid forward mode")

    def get_cuda_graph_seq_len_fill_value(self):
-        return
+        return 1

    def forward_extend(
        self,
@@ -1096,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:

        self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
@@ -1105,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:

        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
sglang/srt/layers/attention/flashinfer_mla_backend.py
CHANGED
@@ -199,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        )

    def init_cuda_graph_state(
-        self,
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        if kv_indices_buf is None:
            cuda_graph_kv_indices = torch.zeros(
@@ -364,7 +367,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
            raise ValueError(f"Invalid forward mode: {forward_mode=}")

    def get_cuda_graph_seq_len_fill_value(self):
-        return
+        return 1

    def forward_extend(
        self,
@@ -852,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend:

        self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
@@ -861,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend:

        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
sglang/srt/layers/attention/flashmla_backend.py
CHANGED
@@ -148,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
    def init_cuda_graph_state(
        self,
        max_bs: int,
+        max_num_tokens: int,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        if block_kv_indices is None:
@@ -502,9 +503,11 @@ class FlashMLAMultiStepDraftBackend:

        self.common_template(forward_batch, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):
sglang/srt/layers/attention/tbo_backend.py
CHANGED
@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
            if forward_batch_child.batch_size > 0:
                child.init_forward_metadata(forward_batch=forward_batch_child)

-    def init_cuda_graph_state(self, max_bs: int):
-        self.primary.init_cuda_graph_state(max_bs=max_bs)
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
        for item in self.children:
            # TODO for children, maybe can provide *smaller* max_bs to optimize
-            item.init_cuda_graph_state(max_bs=max_bs)
+            item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -261,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
+
        elif forward_batch.forward_mode.is_draft_extend():
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
@@ -335,24 +336,27 @@ class TritonAttnBackend(AttentionBackend):
        )

    def init_cuda_graph_state(
-        self,
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        self.cuda_graph_attn_logits = torch.zeros(
-            (
+            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_attn_lse = torch.zeros(
-            (
+            (max_num_tokens, self.num_head, self.max_kv_splits),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_num_kv_splits = torch.full(
-            (
+            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
        )
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
@@ -361,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )
@@ -369,7 +373,7 @@ class TritonAttnBackend(AttentionBackend):
        if self.sliding_window_size is not None and self.sliding_window_size > 0:
            if kv_indices_buf is None:
                self.cuda_graph_window_kv_indices = torch.zeros(
-                    (
+                    (max_num_tokens * self.sliding_window_size),
                    dtype=torch.int32,
                    device=self.device,
                )
@@ -377,7 +381,10 @@ class TritonAttnBackend(AttentionBackend):
                self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)

            self.cuda_graph_window_num_kv_splits = torch.full(
-                (
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(
@@ -458,6 +465,7 @@ class TritonAttnBackend(AttentionBackend):
            )

            custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
@@ -821,15 +829,15 @@ class TritonMultiStepDraftBackend:

        self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
-            (self.speculative_num_steps,
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
            dtype=torch.int32,
            device=self.device,
        )
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
sglang/srt/layers/communicator.py
CHANGED
@@ -28,9 +28,9 @@ from sglang.srt.layers.dp_attention import (
    attn_tp_reduce_scatter,
    dp_gather_partial,
    dp_scatter,
+    get_attention_dp_size,
    get_attention_tp_rank,
    get_attention_tp_size,
-    get_local_attention_dp_size,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -229,7 +229,7 @@ class CommunicateContext:
    process_group_sizes: Dict[ScatterMode, int]
    attn_tp_rank: int
    attn_tp_size: int
-
+    attn_dp_size: int
    tp_size: int

    def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
@@ -239,7 +239,7 @@ class CommunicateContext:
    def init_new(cls):
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()
-
+        attn_dp_size = get_attention_dp_size()
        tp_size = get_tensor_model_parallel_world_size()
        process_group_sizes = {
            ScatterMode.SCATTERED: 1,
@@ -251,7 +251,7 @@ class CommunicateContext:
            process_group_sizes=process_group_sizes,
            attn_tp_rank=attn_tp_rank,
            attn_tp_size=attn_tp_size,
-
+            attn_dp_size=attn_dp_size,
            tp_size=tp_size,
        )

@@ -385,7 +385,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
            attn_tp_all_gather(
                list(residual.tensor_split(context.attn_tp_size)), local_residual
            )
-            if context.
+            if context.attn_dp_size != 1:
                if context.attn_tp_rank == 0:
                    hidden_states += residual
                hidden_states, local_hidden_states = (
sglang/srt/layers/dp_attention.py
CHANGED
@@ -165,7 +165,8 @@ def disable_dp_size():


def get_dp_local_info(forward_batch: ForwardBatch):
-
+    # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
+    dp_rank = get_attention_dp_rank()

    if forward_batch.dp_local_start_pos is None:
        cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
@@ -238,6 +239,10 @@ def _dp_gather(
    assert (
        local_tokens.untyped_storage() is not global_tokens.untyped_storage()
    ), "aliasing between global_tokens and local_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
    memcpy_triton(
        global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
    )
@@ -288,6 +293,10 @@ def dp_scatter(
    assert (
        local_tokens.untyped_storage() is not global_tokens.untyped_storage()
    ), "aliasing between local_tokens and global_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
    memcpy_triton(
        local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
    )
@@ -301,4 +310,4 @@ def attn_tp_reduce_scatter(


def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
-    return get_attention_tp_group().all_gather(input_,
+    return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
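Both dp_attention hunks above add the same draft-extend guard to _dp_gather and dp_scatter: local_num_tokens is clamped to the number of rows actually present in local_tokens. A standalone illustration of the clamp idiom follows; all values are made up and the snippet is not sglang code.

import torch

# new_full((), n) builds a scalar tensor with the same device and dtype as
# local_num_tokens, so the torch.minimum clamp stays on-device.
local_num_tokens = torch.tensor(12)   # recorded token count (made-up value)
local_tokens = torch.zeros(8, 4)      # only 8 rows exist in this buffer
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
print(local_num_tokens)               # tensor(8)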
sglang/srt/layers/layernorm.py
CHANGED
@@ -20,11 +20,21 @@ import torch
import torch.nn as nn

from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)

_is_cuda = is_cuda()
_is_hip = is_hip()
+_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

if _is_cuda:
    from sgl_kernel import (
@@ -42,6 +52,9 @@ elif _is_hip:

logger = logging.getLogger(__name__)

+if is_npu():
+    import torch_npu
+

class RMSNorm(CustomOp):
    def __init__(
@@ -66,6 +79,18 @@ class RMSNorm(CustomOp):
        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
        return out

+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            return out, residual_out
+        return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+
    def forward_aiter(
        self,
        x: torch.Tensor,
@@ -121,6 +146,23 @@ class RMSNorm(CustomOp):
        else:
            return x, residual

+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if _is_cpu_amx_available:
+            if residual is not None:
+                torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
+                    x, residual, self.weight.data, self.variance_epsilon
+                )
+                return x, residual
+            return torch.ops.sgl_kernel.rmsnorm_cpu(
+                x, self.weight.data, self.variance_epsilon
+            )
+        else:
+            return self.forward_native(x, residual)
+

class GemmaRMSNorm(CustomOp):
    def __init__(
@@ -187,7 +229,7 @@ class Gemma3RMSNorm(nn.Module):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


-if not (_is_cuda or _is_hip):
+if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
    logger.info(
        "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
    )
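The new forward_npu and forward_cpu paths above dispatch RMSNorm to torch_npu and sgl_kernel CPU kernels. For reference, the sketch below spells out the fused add + RMSNorm semantics those paths implement, in plain PyTorch; rms_norm_reference is a hypothetical helper, not sglang code, and the real kernels may differ in dtype handling and in-place behavior.

from typing import Optional, Tuple, Union

import torch


def rms_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Optionally fuse the residual add, then normalize by the root mean square
    # over the last dimension and scale by the learned weight.
    if residual is not None:
        x = x + residual
        residual = x
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    return out if residual is None else (out, residual)


x, w = torch.randn(2, 8), torch.ones(8)
print(rms_norm_reference(x, w, eps=1e-6).shape)              # torch.Size([2, 8])
print(rms_norm_reference(x, w, 1e-6, residual=x)[0].shape)   # torch.Size([2, 8])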
sglang/srt/layers/linear.py
CHANGED
@@ -30,7 +30,12 @@ from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)

logger = logging.getLogger(__name__)

@@ -52,6 +57,9 @@ WEIGHT_LOADER_V2_SUPPORTED = [
    "IPEXAWQLinearMethod",
]

+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+

def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -165,6 +173,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["weight"])
+
    def apply(
        self,
        layer: torch.nn.Module,
@@ -172,6 +184,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

+        if getattr(layer, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                x, layer.weight, bias, True # is_vnni
+            )
+
        return F.linear(x, layer.weight, bias)

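The two linear.py hunks above repack the weight at load time on AMX-capable CPUs and then branch on a use_intel_amx_backend flag inside apply. A hedged sketch of that dispatch shape follows; unquantized_linear_apply is a standalone illustration rather than sglang code, and the fast path is stubbed because it needs sgl_kernel on an AMX-capable CPU.

import torch
import torch.nn.functional as F


def unquantized_linear_apply(layer: torch.nn.Linear, x: torch.Tensor) -> torch.Tensor:
    # Mirror the dispatch added to UnquantizedLinearMethod.apply: take the packed
    # AMX path only when weight post-processing marked the layer, otherwise fall
    # back to the standard GEMM.
    if getattr(layer, "use_intel_amx_backend", False):
        # return torch.ops.sgl_kernel.weight_packed_linear(x, layer.weight, layer.bias, True)
        raise NotImplementedError("requires sgl_kernel built with Intel AMX support")
    return F.linear(x, layer.weight, layer.bias)


layer = torch.nn.Linear(16, 4)
print(unquantized_linear_apply(layer, torch.randn(2, 16)).shape)  # torch.Size([2, 4])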
sglang/srt/layers/logits_processor.py
CHANGED
@@ -30,9 +30,9 @@ from sglang.srt.layers.dp_attention import (
    attn_tp_all_gather,
    dp_gather_replicate,
    dp_scatter,
+    get_attention_dp_rank,
    get_attention_dp_size,
    get_attention_tp_size,
-    get_local_attention_dp_rank,
    get_local_attention_dp_size,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -171,7 +171,7 @@ class LogitsMetadata:
            return

        cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank =
+        dp_rank = get_attention_dp_rank()
        if dp_rank == 0:
            dp_local_start_pos = torch.zeros_like(
                self.global_num_tokens_for_logprob_gpu[0]
@@ -442,11 +442,20 @@ class LogitsProcessor(nn.Module):
            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

        if hasattr(lm_head, "weight"):
-
-
-
+            if getattr(lm_head, "use_intel_amx_backend", False):
+                logits = torch.ops.sgl_kernel.weight_packed_linear(
+                    hidden_states.to(lm_head.weight.dtype),
+                    lm_head.weight,
+                    None, # bias
+                    True, # is_vnni
+                )
+            else:
+                logits = torch.matmul(
+                    hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
+                )
        else:
            # GGUF models
+            # TODO: use weight_packed_linear for GGUF models
            logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)

        if self.logit_scale is not None: