sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- sglang/bench_one_batch.py +3 -11
- sglang/bench_serving.py +149 -1
- sglang/check_env.py +3 -3
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +32 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +151 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +58 -24
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +22 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +129 -94
- sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +6 -1
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +81 -35
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +44 -16
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +291 -72
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +60 -28
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +159 -90
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +2 -277
- sglang/srt/models/deepseek_v2.py +132 -37
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +93 -31
- sglang/srt/models/llama4.py +54 -7
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +4 -16
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +58 -62
- sglang/srt/openai_api/protocol.py +38 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +93 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +123 -10
- sglang/test/runners.py +4 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +32 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/flashattention_backend.py

@@ -338,7 +338,7 @@ class FlashAttentionBackend(AttentionBackend):
        """Initialize forward metadata hence all layers in the forward pass can reuse it."""
        metadata = FlashAttentionMetadata()
        seqlens_in_batch = forward_batch.seq_lens
-       batch_size =
+       batch_size = forward_batch.batch_size
        device = seqlens_in_batch.device

        if forward_batch.forward_mode.is_decode_or_idle():
@@ -913,8 +913,10 @@ class FlashAttentionBackend(AttentionBackend):
        # Use precomputed metadata across all layers
        metadata = self.forward_metadata
        local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
-
-           self.attention_chunk_size is not None
+       use_local_attn = (
+           self.attention_chunk_size is not None
+           and local_attn_metadata is not None
+           and (hasattr(layer, "use_irope") and layer.use_irope)
        )
        # We do cascade attention for Draft Decode with topk > 1
        use_cascade_attn = self.topk > 1
@@ -970,7 +972,7 @@ class FlashAttentionBackend(AttentionBackend):
                k_descale=k_descale,
                v_descale=v_descale,
            )
-       elif
+       elif use_local_attn:
            # Use chunked (local) attention batching for self-attention
            o = flash_attn_with_kvcache(
                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -979,7 +981,7 @@ class FlashAttentionBackend(AttentionBackend):
                page_table=local_attn_metadata.local_block_table,
                cache_seqlens=local_attn_metadata.local_seqused_k,
                cu_seqlens_q=local_attn_metadata.local_query_start_loc,
-               cu_seqlens_k_new=
+               cu_seqlens_k_new=None,
                max_seqlen_q=local_attn_metadata.local_max_query_len,
                softmax_scale=layer.scaling,
                causal=True,
@@ -1127,7 +1129,6 @@ class FlashAttentionBackend(AttentionBackend):
        This creates fixed-size tensors that will be reused during CUDA graph replay
        to avoid memory allocations.
        """
-
        # This is being used by normal decode and draft decode when topk == 1
        self.decode_cuda_graph_metadata = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1154,6 +1155,34 @@ class FlashAttentionBackend(AttentionBackend):
            ),
        }

+       # Only allocate local attention buffers if local attention is enabled
+       # This prevents OOM errors when local attention is not being used
+       if self.attention_chunk_size is not None:
+           # Estimate maximum sizes for local attention metadata
+           max_seq_len = self.max_context_len
+           page_size = self.page_size or 1
+           attn_chunk_size = self.attention_chunk_size
+           max_virtual_batches = max_bs * (
+               (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+           )
+           max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+           max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
+
+           self.decode_cuda_graph_local_attn_metadata = {
+               "local_query_start_loc": torch.zeros(
+                   max_virtual_batches + 1, dtype=torch.int32, device=self.device
+               ),
+               "local_seqused_k": torch.zeros(
+                   max_virtual_batches, dtype=torch.int32, device=self.device
+               ),
+               "local_block_table": torch.zeros(
+                   max_virtual_batches,
+                   max_blocks_per_seq * max_pages_per_block,
+                   dtype=torch.int32,
+                   device=self.device,
+               ),
+           }
+
        # This is used by draft decode's first half of metadata when topk > 1
        if self.topk > 1:
            self.draft_decode_metadata_topk_normal = {
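
To make the buffer sizing in the hunk above concrete, here is a small standalone sketch of the same ceiling-division arithmetic. The helper name and the example numbers are illustrative only; they are not values taken from sglang.

    def local_attn_buffer_shapes(max_bs, max_context_len, attention_chunk_size, page_size):
        # Ceiling division, matching the (a + b - 1) // b pattern in the diff.
        def ceil_div(a, b):
            return (a + b - 1) // b

        max_blocks_per_seq = ceil_div(max_context_len, attention_chunk_size)
        # One virtual batch per (request, attention chunk) pair.
        max_virtual_batches = max_bs * max_blocks_per_seq
        max_pages_per_block = ceil_div(attention_chunk_size, page_size)
        return {
            "local_query_start_loc": (max_virtual_batches + 1,),
            "local_seqused_k": (max_virtual_batches,),
            "local_block_table": (
                max_virtual_batches,
                max_blocks_per_seq * max_pages_per_block,
            ),
        }

    # Example: batch 4, 8192-token context, 2048-token chunks, page size 16
    # -> local_query_start_loc (17,), local_seqused_k (16,), local_block_table (16, 512)
    print(local_attn_buffer_shapes(4, 8192, 2048, 16))
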
@@ -1405,6 +1434,21 @@ class FlashAttentionBackend(AttentionBackend):
            )
            self.decode_cuda_graph_metadata[bs] = metadata

+           if self.attention_chunk_size is not None:
+               metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+                   local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
+                       "local_query_start_loc"
+                   ],
+                   local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
+                       "local_seqused_k"
+                   ],
+                   local_block_table=self.decode_cuda_graph_local_attn_metadata[
+                       "local_block_table"
+                   ],
+                   local_max_query_len=1,
+                   local_max_seq_len=1,
+               )
+
        elif forward_mode.is_target_verify():
            if self.topk <= 1:
                metadata.cache_seqlens_int32 = self.target_verify_metadata[
@@ -1525,12 +1569,9 @@ class FlashAttentionBackend(AttentionBackend):
                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
                    self.speculative_step_id + 1
                )
-               metadata.cu_seqlens_k.copy_(
-                   torch.
-                   torch.
-                       metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                   ),
-                   (1, 0),
+               metadata.cu_seqlens_k[1:].copy_(
+                   torch.cumsum(
+                       metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    )
                )

@@ -1554,12 +1595,9 @@ class FlashAttentionBackend(AttentionBackend):
                # metadata.max_seq_len_q = self.topk, already set in capture
                metadata.max_seq_len_k = seq_lens_cpu.max().item()
                # metadata.cu_seqlens_q already set in capture
-               metadata.cu_seqlens_k.copy_(
-                   torch.
-                   torch.
-                       metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                   ),
-                   (1, 0),
+               metadata.cu_seqlens_k[1:].copy_(
+                   torch.cumsum(
+                       metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    )
                )

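
The recurring change in these hunks rewrites how `cu_seqlens_k` is built during CUDA graph replay: instead of constructing a fresh tensor (the removed lines are truncated here, but the `(1, 0)` argument suggests a `torch.nn.functional.pad` around the cumulative sum), the cumulative sum is written into the tail of a preallocated buffer whose first element stays zero, so replay does not allocate. A minimal sketch of why the two forms agree, under that assumption:

    import torch

    # Hypothetical per-request KV lengths; any int vector illustrates the point.
    cache_seqlens = torch.tensor([3, 5, 2, 7], dtype=torch.int32)

    # Assumed old form: allocate a new [0, cumsum...] tensor by left-padding.
    cu_old = torch.nn.functional.pad(
        torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
    )

    # New form: keep a preallocated buffer (leading zero already in place) and
    # overwrite only its tail, which is CUDA-graph friendly.
    cu_new = torch.zeros(cache_seqlens.numel() + 1, dtype=torch.int32)
    cu_new[1:].copy_(torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32))

    assert torch.equal(cu_old, cu_new)  # both are [0, 3, 8, 10, 17]
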
@@ -1578,8 +1616,7 @@ class FlashAttentionBackend(AttentionBackend):
                metadata_expand.page_table[: cache_loc.shape[0]].copy_(
                    cache_loc[:, :decode_length].contiguous().to(torch.int32)
                )
-               # TODO:
-               self._init_local_attn_metadata(metadata, device)
+               # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
            else:
                metadata = self.decode_cuda_graph_metadata[bs]
                # Normal Decode
@@ -1587,8 +1624,9 @@ class FlashAttentionBackend(AttentionBackend):
                metadata.max_seq_len_k = max_len

                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-
-
+               # Optimize cumulative sequence length calculation
+               metadata.cu_seqlens_k[1:].copy_(
+                   torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
                )

                max_seq_pages = (
@@ -1604,7 +1642,7 @@ class FlashAttentionBackend(AttentionBackend):
                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                metadata.page_table[:, max_seq_pages:].fill_(0)

-               self.
+               self._update_local_attn_metadata_for_replay(metadata, bs)
        elif forward_mode.is_target_verify():
            if self.topk <= 1:
                metadata = self.target_verify_metadata[bs]
@@ -1615,13 +1653,8 @@ class FlashAttentionBackend(AttentionBackend):
                metadata.max_seq_len_k = (
                    seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
                )
-               metadata.cu_seqlens_k.copy_(
-                   torch.
-                   torch.cumsum(
-                       metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                   ),
-                   (1, 0),
-               )
+               metadata.cu_seqlens_k[1:].copy_(
+                   torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
                )
                max_seq_pages = (
                    metadata.max_seq_len_k + self.page_size - 1
@@ -1640,13 +1673,8 @@ class FlashAttentionBackend(AttentionBackend):
                # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
                metadata.max_seq_len_k = seq_lens_cpu.max().item()
                # metadata.cu_seqlens_q already set in capture
-               metadata.cu_seqlens_k.copy_(
-                   torch.
-                   torch.cumsum(
-                       metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                   ),
-                   (1, 0),
-               )
+               metadata.cu_seqlens_k[1:].copy_(
+                   torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
                )
                page_table = self.req_to_token[
                    req_pool_indices, : metadata.max_seq_len_k
@@ -1704,14 +1732,11 @@ class FlashAttentionBackend(AttentionBackend):
                metadata_expand.cache_seqlens_int32.copy_(
                    mask.sum(dim=1).to(torch.int32)
                )
-               metadata_expand.cu_seqlens_k.copy_(
-                   torch.
-
-
-
-                       dtype=torch.int32,
-                   ),
-                   (1, 0),
+               metadata_expand.cu_seqlens_k[1:].copy_(
+                   torch.cumsum(
+                       metadata_expand.cache_seqlens_int32,
+                       dim=0,
+                       dtype=torch.int32,
                    )
                )
                metadata_expand.max_seq_len_k = (
@@ -1722,11 +1747,8 @@ class FlashAttentionBackend(AttentionBackend):
            # Only support encoder size 1 for now
            metadata.encoder_max_seq_len_k = encoder_lens[0]
            metadata.encoder_lens_int32.copy_(encoder_lens[:1])
-           metadata.encoder_cu_seqlens_k.copy_(
-               torch.
-               torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
-               (1, 0),
-           )
+           metadata.encoder_cu_seqlens_k[1:].copy_(
+               torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)
            )

            metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
@@ -1776,6 +1798,7 @@ class FlashAttentionBackend(AttentionBackend):
            page_table,
            self.page_size,
        )
+
        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
@@ -1785,6 +1808,79 @@ class FlashAttentionBackend(AttentionBackend):
        )
        metadata.local_attn_metadata = local_metadata

+   def _update_local_attn_metadata_for_replay(
+       self, metadata: FlashAttentionMetadata, bs: int
+   ):
+       """Update preallocated local attention metadata in-place before CUDA graph replay."""
+       if self.attention_chunk_size is None:
+           return
+
+       # Access preallocated buffers
+       local_q_buf = self.decode_cuda_graph_local_attn_metadata[
+           "local_query_start_loc"
+       ]
+       local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
+       local_block_buf = self.decode_cuda_graph_local_attn_metadata[
+           "local_block_table"
+       ]
+       cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
+
+       # Create a modified version for local attention that only processes the last token
+       # This mimics the normal decode pattern
+       cu_seqlens_q = torch.arange(
+           bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
+       )
+       seqlens = metadata.cache_seqlens_int32[:bs]
+       # Slice the page_table to match the batch size and actual sequence length
+       # This serves three important purposes:
+       # 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
+       # 2. Limits the sequence length to prevent processing padding tokens or garbage values
+       # 3. Prevents zeros in the block table which can cause garbage output during replay
+       #
+       # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
+       # beyond the actual sequence length, leading to incorrect attention calculations
+       max_seq_len = int(seqlens.max().item())
+       sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+
+       cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+       seqlens_np = seqlens.cpu().numpy()
+       (
+           seqlens_q_local_np,
+           cu_seqlens_q_local_np,
+           seqlens_k_local_np,
+           block_table_local,
+       ) = make_local_attention_virtual_batches(
+           self.attention_chunk_size,
+           cu_seqlens_q_np,
+           seqlens_np,
+           sliced_page_table,
+           self.page_size,
+       )
+
+       # Convert back to tensors
+       device = local_q_buf.device
+       cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
+       seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
+       block_table_local = block_table_local.to(device)
+       # Get sizes
+       q_len = cu_seqlens_q_local.shape[0]
+       k_len = seqlens_k_local.shape[0]
+       b0, b1 = block_table_local.shape
+
+       # In-place updates into preallocated tensors and zero out the unused space
+       local_q_buf[:q_len].copy_(cu_seqlens_q_local)
+       local_q_buf[q_len:].fill_(0)
+       local_k_buf[:k_len].copy_(seqlens_k_local)
+       local_k_buf[k_len:].fill_(0)
+       local_block_buf[:b0, :b1].copy_(block_table_local)
+       local_block_buf[b0:, :].fill_(0)
+       local_block_buf[:b0, b1:].fill_(0)
+
+       if metadata.local_attn_metadata is not None:
+           lam = metadata.local_attn_metadata
+           lam.local_max_query_len = int(seqlens_q_local_np.max())
+           lam.local_max_seq_len = int(seqlens_k_local_np.max())
+

class FlashAttentionMultiStepBackend:

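
The new _update_local_attn_metadata_for_replay follows the usual CUDA-graph constraint: tensors captured by the graph must keep their storage, so each replay copies the step's values into fixed-size preallocated buffers and zeroes the unused tail instead of allocating new tensors. A self-contained sketch of that update pattern (buffer names and sizes here are illustrative only):

    import torch

    MAX_VIRTUAL_BATCHES = 16  # worst-case size chosen at capture time (illustrative)
    seqused_buf = torch.zeros(MAX_VIRTUAL_BATCHES, dtype=torch.int32)

    def refresh_for_replay(seqused_k: torch.Tensor) -> None:
        """Write this step's values into the captured buffer in place."""
        n = seqused_k.numel()
        seqused_buf[:n].copy_(seqused_k)  # overwrite the active prefix
        seqused_buf[n:].fill_(0)          # clear stale entries so replay never reads garbage

    refresh_for_replay(torch.tensor([4, 9, 1], dtype=torch.int32))
    print(seqused_buf[:5])  # tensor([4, 9, 1, 0, 0], dtype=torch.int32)
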

sglang/srt/layers/attention/flashinfer_backend.py

@@ -15,6 +15,12 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union

import torch

+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import logging
+
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
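
A note on the import-time gate added above: os.environ["SGLANG_ENABLE_TORCH_COMPILE"] raises KeyError when the variable is unset, so the block only runs cleanly in environments that define it. A tolerant variant of the same gate (an illustration of the intent, not the package's code) would read the variable with a default:

    import logging
    import os

    import torch

    # Same effect as the added block, but harmless when the variable is unset.
    if os.environ.get("SGLANG_ENABLE_TORCH_COMPILE", "0") == "1":
        # Quiet torch.compile/dynamo logging and fall back to eager on compile errors.
        torch._logging.set_logs(dynamo=logging.ERROR)
        torch._dynamo.config.suppress_errors = True
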
@@ -82,8 +88,6 @@ class FlashInferAttnBackend(AttentionBackend):
        self.max_context_len = model_runner.model_config.context_len
        self.skip_prefill = skip_prefill
        self.is_multimodal = model_runner.model_config.is_multimodal
-       self.kv_cache_dtype = model_runner.kv_cache_dtype
-       self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype

        assert not (
            model_runner.sliding_window_size is not None
@@ -104,6 +108,7 @@ class FlashInferAttnBackend(AttentionBackend):
        if (
            "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
            or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+           or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
        ):
            global_config.flashinfer_workspace_size = 512 * 1024 * 1024

@@ -268,6 +273,12 @@ class FlashInferAttnBackend(AttentionBackend):
            cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
        ]

+       # Ensure tensors are properly allocated
+       for i in range(self.num_wrappers):
+           # Force allocation by performing a small operation
+           if len(self.cuda_graph_kv_indices[i]) > 0:
+               self.cuda_graph_kv_indices[i][0] = 0
+
        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
                (max_bs * self.max_context_len),
@@ -396,8 +407,6 @@ class FlashInferAttnBackend(AttentionBackend):
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
-       k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-       v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
            self._get_wrapper_idx(layer)
        ]
@@ -409,39 +418,47 @@ class FlashInferAttnBackend(AttentionBackend):

        logits_soft_cap = layer.logit_cap

+       q = q.contiguous()
        if not self.forward_metadata.use_ragged:
            if k is not None:
                assert v is not None
                if save_kv_cache:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                       layer, cache_loc, k, v, k_scale, v_scale
+                       layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                    )

            o = prefill_wrapper_paged.forward(
-               q.
+               q.view(-1, layer.tp_q_head_num, layer.head_dim),
                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                causal=not layer.is_cross_attention,
                sm_scale=layer.scaling,
                window_left=layer.sliding_window_size,
                logits_soft_cap=logits_soft_cap,
-               k_scale=k_scale,
-               v_scale=v_scale,
+               k_scale=layer.k_scale,
+               v_scale=layer.v_scale,
            )
        else:
-           o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-               q.view(-1, layer.tp_q_head_num, layer.head_dim),
-               k.view(-1, layer.tp_k_head_num, layer.head_dim),
-               v.view(-1, layer.tp_v_head_num, layer.head_dim),
-               causal=True,
-               sm_scale=layer.scaling,
-               logits_soft_cap=logits_soft_cap,
-           )
-
            if self.forward_metadata.extend_no_prefix:
-               o =
+               o = self.prefill_wrapper_ragged.forward(
+                   q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                   k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                   v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                   causal=True,
+                   sm_scale=layer.scaling,
+                   logits_soft_cap=logits_soft_cap,
+               )
+
            else:
+               o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                   q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                   k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                   v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                   causal=True,
+                   sm_scale=layer.scaling,
+                   logits_soft_cap=logits_soft_cap,
+               )
                o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                   q.
+                   q.view(-1, layer.tp_q_head_num, layer.head_dim),
                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                    causal=False,
                    sm_scale=layer.scaling,
@@ -452,7 +469,7 @@ class FlashInferAttnBackend(AttentionBackend):

            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
-                   layer, cache_loc, k, v, k_scale, v_scale
+                   layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -466,8 +483,6 @@ class FlashInferAttnBackend(AttentionBackend):
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
-       k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-       v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
        decode_wrapper = self.forward_metadata.decode_wrappers[
            self._get_wrapper_idx(layer)
        ]
@@ -481,16 +496,17 @@ class FlashInferAttnBackend(AttentionBackend):
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
-                   layer, cache_loc, k, v, k_scale, v_scale
+                   layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                )

+       # Call the wrapped function
        o = decode_wrapper.forward(
            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
            sm_scale=layer.scaling,
            logits_soft_cap=layer.logit_cap,
-           k_scale=k_scale,
-           v_scale=v_scale,
+           k_scale=layer.k_scale,
+           v_scale=layer.v_scale,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -1146,8 +1162,9 @@ def fast_decode_plan(
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
-   data_type: Union[str, torch.dtype] = "float16",
    q_data_type: Optional[Union[str, torch.dtype]] = None,
+   kv_data_type: Optional[Union[str, torch.dtype]] = None,
+   data_type: Optional[Union[str, torch.dtype]] = None,
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
@@ -1163,6 +1180,18 @@ def fast_decode_plan(
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

+   # Handle data types consistently
+   if data_type is not None:
+       if q_data_type is None:
+           q_data_type = data_type
+       if kv_data_type is None:
+           kv_data_type = data_type
+   elif q_data_type is None:
+       q_data_type = "float16"
+
+   if kv_data_type is None:
+       kv_data_type = q_data_type
+
    if self.use_tensor_cores:
        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

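
The dtype handling added to fast_decode_plan resolves the query and KV-cache dtypes with a simple precedence: an explicitly passed per-tensor dtype wins, then the legacy data_type argument fills the gaps, then queries default to "float16", and finally the KV dtype falls back to the query dtype. A small standalone restatement of that precedence (simplified to strings; not the package's code):

    from typing import Optional, Tuple

    def resolve_dtypes(
        data_type: Optional[str] = None,
        q_data_type: Optional[str] = None,
        kv_data_type: Optional[str] = None,
    ) -> Tuple[str, str]:
        # Legacy `data_type` fills in whichever per-tensor dtype was not given.
        if data_type is not None:
            if q_data_type is None:
                q_data_type = data_type
            if kv_data_type is None:
                kv_data_type = data_type
        elif q_data_type is None:
            q_data_type = "float16"
        # KV cache falls back to the query dtype.
        if kv_data_type is None:
            kv_data_type = q_data_type
        return q_data_type, kv_data_type

    assert resolve_dtypes() == ("float16", "float16")
    assert resolve_dtypes(data_type="bfloat16") == ("bfloat16", "bfloat16")
    assert resolve_dtypes(q_data_type="float16", kv_data_type="float8_e4m3fn") == (
        "float16",
        "float8_e4m3fn",
    )
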
@@ -1178,36 +1207,33 @@ def fast_decode_plan(
            raise ValueError(
                "The size of indices should be less than or equal to the allocated buffer"
            )
-       # Skip these copies because we directly write to them during prepartion
-       # self._paged_kv_indptr_buf.copy_(indptr)
-       # self._paged_kv_indices_buf[: len(indices)] = indices
-       # self._paged_kv_last_page_len_buf.copy_(last_page_len)
    else:
        self._paged_kv_indptr_buf = indptr
        self._paged_kv_indices_buf = indices
        self._paged_kv_last_page_len_buf = last_page_len
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-           )
-       self.
+   if self.use_tensor_cores:
+       self._qo_indptr_buf = qo_indptr_host.to(
+           self.device, non_blocking=non_blocking
+       )
+
+   # Create empty tensors for dtype info if needed
+   empty_q_data = torch.empty(
+       0,
+       dtype=(
+           getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
+       ),
+       device=self.device,
+   )
+
+   empty_kv_cache = torch.empty(
+       0,
+       dtype=(
+           getattr(torch, kv_data_type)
+           if isinstance(kv_data_type, str)
+           else kv_data_type
+       ),
+       device=self.device,
+   )

    indptr_host = (
        global_override_indptr_cpu
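
The zero-element empty_q_data and empty_kv_cache tensors created above act purely as dtype carriers for the planner call: getattr(torch, name) turns a dtype name into the torch.dtype object, and a 0-sized tensor conveys that dtype without allocating data. A minimal sketch of the same idiom (an illustrative helper, not sglang's API):

    import torch

    def dtype_tag(data_type):
        """Zero-element tensor whose only job is to carry a dtype into a plan() call."""
        dt = getattr(torch, data_type) if isinstance(data_type, str) else data_type
        return torch.empty(0, dtype=dt)

    assert dtype_tag("bfloat16").dtype is torch.bfloat16
    assert dtype_tag(torch.float16).dtype is torch.float16
    assert dtype_tag("float16").numel() == 0
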
@@ -1215,48 +1241,57 @@ def fast_decode_plan(
        else indptr.cpu()
    )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+   with torch.cuda.device(self.device):
+
+       if self.use_tensor_cores:
+           # ALSO convert last_page_len to CPU
+           last_page_len_host = last_page_len.cpu()
+
+           kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
+
+           try:
+               # Make sure we pass exactly 15 arguments for tensor core version
+               self._plan_info = self._cached_module.plan(
+                   self._float_workspace_buffer,
+                   self._int_workspace_buffer,
+                   self._pin_memory_int_workspace_buffer,
+                   qo_indptr_host,
+                   indptr_host,
+                   kv_lens_arr_host,
+                   batch_size,  # total_num_rows
+                   batch_size,
+                   num_qo_heads,
+                   num_kv_heads,
+                   page_size,
+                   self.is_cuda_graph_enabled,
+                   head_dim,
+                   head_dim,
+                   False,  # causal
+               )
+           except Exception as e:
+               raise RuntimeError(f"Error in standard plan: {e}")
+       else:
+           try:
+               # Make sure we pass exactly 15 arguments for standard version
+               self._plan_info = self._cached_module.plan(
+                   self._float_workspace_buffer,
+                   self._int_workspace_buffer,
+                   self._pin_memory_int_workspace_buffer,
+                   indptr_host,
+                   batch_size,
+                   num_qo_heads,
+                   num_kv_heads,
+                   page_size,
+                   self.is_cuda_graph_enabled,
+                   window_left,
+                   logits_soft_cap,
+                   head_dim,
+                   head_dim,
+                   empty_q_data,
+                   empty_kv_cache,
+               )
+           except Exception as e:
+               raise RuntimeError(f"Error in standard plan: {e}")

    self._pos_encoding_mode = pos_encoding_mode
    self._window_left = window_left
|