sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

@@ -308,7 +308,7 @@ class FlashAttentionBackend(AttentionBackend):
             ), "Sliding window and cross attention are not supported together"

         self.forward_metadata: FlashAttentionMetadata = None
-        # extra
+        # extra metadata for handling speculative decoding topk > 1, extended draft decode and verify
         self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device

@@ -338,7 +338,7 @@ class FlashAttentionBackend(AttentionBackend):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
-        batch_size =
+        batch_size = forward_batch.batch_size
         device = seqlens_in_batch.device

         if forward_batch.forward_mode.is_decode_or_idle():

@@ -913,8 +913,10 @@ class FlashAttentionBackend(AttentionBackend):
         # Use precomputed metadata across all layers
         metadata = self.forward_metadata
         local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
-
-            self.attention_chunk_size is not None
+        use_local_attn = (
+            self.attention_chunk_size is not None
+            and local_attn_metadata is not None
+            and (hasattr(layer, "use_irope") and layer.use_irope)
         )
         # We do cascade attention for Draft Decode with topk > 1
         use_cascade_attn = self.topk > 1

@@ -970,7 +972,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
             )
-        elif
+        elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),

@@ -979,7 +981,7 @@ class FlashAttentionBackend(AttentionBackend):
                 page_table=local_attn_metadata.local_block_table,
                 cache_seqlens=local_attn_metadata.local_seqused_k,
                 cu_seqlens_q=local_attn_metadata.local_query_start_loc,
-                cu_seqlens_k_new=
+                cu_seqlens_k_new=None,
                 max_seqlen_q=local_attn_metadata.local_max_query_len,
                 softmax_scale=layer.scaling,
                 causal=True,

@@ -1127,7 +1129,6 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
-
         # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),

@@ -1154,6 +1155,34 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }

+        # Only allocate local attention buffers if local attention is enabled
+        # This prevents OOM errors when local attention is not being used
+        if self.attention_chunk_size is not None:
+            # Estimate maximum sizes for local attention metadata
+            max_seq_len = self.max_context_len
+            page_size = self.page_size or 1
+            attn_chunk_size = self.attention_chunk_size
+            max_virtual_batches = max_bs * (
+                (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            )
+            max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
+
+            self.decode_cuda_graph_local_attn_metadata = {
+                "local_query_start_loc": torch.zeros(
+                    max_virtual_batches + 1, dtype=torch.int32, device=self.device
+                ),
+                "local_seqused_k": torch.zeros(
+                    max_virtual_batches, dtype=torch.int32, device=self.device
+                ),
+                "local_block_table": torch.zeros(
+                    max_virtual_batches,
+                    max_blocks_per_seq * max_pages_per_block,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
         # This is used by draft decode's first half of metadata when topk > 1
         if self.topk > 1:
             self.draft_decode_metadata_topk_normal = {
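The sizing above is plain ceiling arithmetic: one virtual batch per attention chunk per request, and at most `ceil(chunk_size / page_size)` page entries per chunk. A standalone sketch of that math with made-up numbers (none of these values are taken from the diff):

```python
# Illustrative sizing for the preallocated local-attention buffers.
# The concrete numbers are assumptions for the example; only the formulas
# mirror the logic added to init_cuda_graph_state above.
max_bs = 160             # maximum CUDA-graph batch size
max_seq_len = 8192       # model context length
attn_chunk_size = 2048   # local (chunked) attention window
page_size = 16           # KV-cache page size

chunks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size  # 4
max_virtual_batches = max_bs * chunks_per_seq                            # 640
max_blocks_per_seq = chunks_per_seq                                      # 4
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size     # 128

# Resulting buffer shapes:
#   local_query_start_loc -> (max_virtual_batches + 1,)                   = (641,)
#   local_seqused_k       -> (max_virtual_batches,)                       = (640,)
#   local_block_table     -> (max_virtual_batches,
#                             max_blocks_per_seq * max_pages_per_block)   = (640, 512)
print(max_virtual_batches, max_blocks_per_seq * max_pages_per_block)
```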
@@ -1405,6 +1434,21 @@ class FlashAttentionBackend(AttentionBackend):
             )
             self.decode_cuda_graph_metadata[bs] = metadata

+            if self.attention_chunk_size is not None:
+                metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+                    local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
+                        "local_query_start_loc"
+                    ],
+                    local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
+                        "local_seqused_k"
+                    ],
+                    local_block_table=self.decode_cuda_graph_local_attn_metadata[
+                        "local_block_table"
+                    ],
+                    local_max_query_len=1,
+                    local_max_seq_len=1,
+                )
+
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata.cache_seqlens_int32 = self.target_verify_metadata[

@@ -1525,12 +1569,9 @@ class FlashAttentionBackend(AttentionBackend):
                     metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
                         self.speculative_step_id + 1
                     )
-                    metadata.cu_seqlens_k.copy_(
-                        torch.nn.functional.pad(
-                            torch.cumsum(
-                                metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                            ),
-                            (1, 0),
+                    metadata.cu_seqlens_k[1:].copy_(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                         )
                     )

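Several hunks in this backend replace the `pad(cumsum(...), (1, 0))` construction with a write into `cu_seqlens_k[1:]`. Both yield the usual cumulative-length vector with a leading zero; the in-place form simply reuses the zero-initialized CUDA-graph buffer and skips the extra padded temporary. A minimal check of the equivalence (buffer sizes here are arbitrary):

```python
import torch

cache_seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Old pattern: build a padded cumulative sum, then copy the whole vector.
cu_old = torch.zeros(cache_seqlens.numel() + 1, dtype=torch.int32)
cu_old.copy_(
    torch.nn.functional.pad(
        torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
)

# New pattern: element 0 stays 0, write the cumulative sum into the tail slice.
cu_new = torch.zeros(cache_seqlens.numel() + 1, dtype=torch.int32)
cu_new[1:].copy_(torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32))

assert torch.equal(cu_old, cu_new)  # both are [0, 5, 8, 15]
```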
@@ -1554,12 +1595,9 @@ class FlashAttentionBackend(AttentionBackend):
                     # metadata.max_seq_len_q = self.topk, already set in capture
                     metadata.max_seq_len_k = seq_lens_cpu.max().item()
                     # metadata.cu_seqlens_q already set in capture
-                    metadata.cu_seqlens_k.copy_(
-                        torch.nn.functional.pad(
-                            torch.cumsum(
-                                metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                            ),
-                            (1, 0),
+                    metadata.cu_seqlens_k[1:].copy_(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                         )
                     )

@@ -1578,8 +1616,7 @@ class FlashAttentionBackend(AttentionBackend):
                     metadata_expand.page_table[: cache_loc.shape[0]].copy_(
                         cache_loc[:, :decode_length].contiguous().to(torch.int32)
                     )
-                    # TODO:
-                    self._init_local_attn_metadata(metadata, device)
+                    # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
             else:
                 metadata = self.decode_cuda_graph_metadata[bs]
                 # Normal Decode

@@ -1587,8 +1624,9 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k = max_len

                 metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-
-
+                # Optimize cumulative sequence length calculation
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
                 )

                 max_seq_pages = (

@@ -1604,7 +1642,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)

-                self.
+                self._update_local_attn_metadata_for_replay(metadata, bs)
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]

@@ -1615,13 +1653,8 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k = (
                     seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
                 )
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
-                    )
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
                 )
                 max_seq_pages = (
                     metadata.max_seq_len_k + self.page_size - 1

@@ -1640,13 +1673,8 @@ class FlashAttentionBackend(AttentionBackend):
                 # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
-                    )
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
                 )
                 page_table = self.req_to_token[
                     req_pool_indices, : metadata.max_seq_len_k

@@ -1704,14 +1732,11 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.cache_seqlens_int32.copy_(
                     mask.sum(dim=1).to(torch.int32)
                 )
-                metadata_expand.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata_expand.cache_seqlens_int32,
-                            dim=0,
-                            dtype=torch.int32,
-                        ),
-                        (1, 0),
+                metadata_expand.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(
+                        metadata_expand.cache_seqlens_int32,
+                        dim=0,
+                        dtype=torch.int32,
                     )
                 )
                 metadata_expand.max_seq_len_k = (

@@ -1722,11 +1747,8 @@ class FlashAttentionBackend(AttentionBackend):
             # Only support encoder size 1 for now
             metadata.encoder_max_seq_len_k = encoder_lens[0]
             metadata.encoder_lens_int32.copy_(encoder_lens[:1])
-            metadata.encoder_cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
-                    (1, 0),
-                )
+            metadata.encoder_cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)
             )

             metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(

@@ -1776,6 +1798,7 @@ class FlashAttentionBackend(AttentionBackend):
             page_table,
             self.page_size,
         )
+
         local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
             local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
             local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),

@@ -1785,6 +1808,79 @@ class FlashAttentionBackend(AttentionBackend):
         )
         metadata.local_attn_metadata = local_metadata

+    def _update_local_attn_metadata_for_replay(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update preallocated local attention metadata in-place before CUDA graph replay."""
+        if self.attention_chunk_size is None:
+            return
+
+        # Access preallocated buffers
+        local_q_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ]
+        local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
+        local_block_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ]
+        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
+
+        # Create a modified version for local attention that only processes the last token
+        # This mimics the normal decode pattern
+        cu_seqlens_q = torch.arange(
+            bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
+        )
+        seqlens = metadata.cache_seqlens_int32[:bs]
+        # Slice the page_table to match the batch size and actual sequence length
+        # This serves three important purposes:
+        # 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
+        # 2. Limits the sequence length to prevent processing padding tokens or garbage values
+        # 3. Prevents zeros in the block table which can cause garbage output during replay
+        #
+        # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
+        # beyond the actual sequence length, leading to incorrect attention calculations
+        max_seq_len = int(seqlens.max().item())
+        sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seqlens_np = seqlens.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            sliced_page_table,
+            self.page_size,
+        )
+
+        # Convert back to tensors
+        device = local_q_buf.device
+        cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
+        seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
+        block_table_local = block_table_local.to(device)
+        # Get sizes
+        q_len = cu_seqlens_q_local.shape[0]
+        k_len = seqlens_k_local.shape[0]
+        b0, b1 = block_table_local.shape
+
+        # In-place updates into preallocated tensors and zero out the unused space
+        local_q_buf[:q_len].copy_(cu_seqlens_q_local)
+        local_q_buf[q_len:].fill_(0)
+        local_k_buf[:k_len].copy_(seqlens_k_local)
+        local_k_buf[k_len:].fill_(0)
+        local_block_buf[:b0, :b1].copy_(block_table_local)
+        local_block_buf[b0:, :].fill_(0)
+        local_block_buf[:b0, b1:].fill_(0)
+
+        if metadata.local_attn_metadata is not None:
+            lam = metadata.local_attn_metadata
+            lam.local_max_query_len = int(seqlens_q_local_np.max())
+            lam.local_max_seq_len = int(seqlens_k_local_np.max())
+

 class FlashAttentionMultiStepBackend:

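A note on why `_update_local_attn_metadata_for_replay` mutates buffers instead of assigning new tensors: CUDA graph replay re-runs the captured kernels against fixed device addresses, so only in-place writes into the tensors captured by the graph are visible at replay time, and any unused tail must be cleared so a smaller batch does not see stale entries from a larger one. A stripped-down sketch of that pattern (not sglang code, just the constraint it follows):

```python
import torch

def refresh(buf: torch.Tensor, new_values: torch.Tensor) -> None:
    """Graph-safe update: keep the tensor object (and its address) fixed."""
    n = new_values.shape[0]
    buf[:n].copy_(new_values)  # overwrite the live region in place
    buf[n:].fill_(0)           # zero the tail so stale metadata cannot leak

# Buffer allocated once at capture time; replay always reads this storage.
local_seqused_k = torch.zeros(8, dtype=torch.int32)
refresh(local_seqused_k, torch.tensor([3, 1, 4], dtype=torch.int32))
# local_seqused_k == [3, 1, 4, 0, 0, 0, 0, 0]
# Rebinding the name (local_seqused_k = new_tensor) would leave the captured
# graph pointing at the old, now-unused storage.
```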
sglang/srt/layers/attention/flashinfer_backend.py

@@ -16,8 +16,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 import torch

 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import
+    import logging

+    torch._logging.set_logs(dynamo=logging.ERROR)
     torch._dynamo.config.suppress_errors = True

 from sglang.global_config import global_config

@@ -107,6 +108,7 @@ class FlashInferAttnBackend(AttentionBackend):
         if (
             "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
             or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+            or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
         ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

@@ -416,6 +418,7 @@ class FlashInferAttnBackend(AttentionBackend):

         logits_soft_cap = layer.logit_cap

+        q = q.contiguous()
         if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None

@@ -425,7 +428,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 )

             o = prefill_wrapper_paged.forward(
-                q.
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,

@@ -435,20 +438,27 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
             if self.forward_metadata.extend_no_prefix:
-                o =
+                o = self.prefill_wrapper_ragged.forward(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
+
             else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                     causal=False,
                     sm_scale=layer.scaling,
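The reshuffled branch above reflects a property of softmax attention: partial results over disjoint key sets can be combined exactly using their log-sum-exp (LSE) statistics, which is why `forward_return_lse` is only needed when a cached prefix exists and two partial outputs must be merged; with `extend_no_prefix` there is nothing to merge and the plain `forward` output is already final. A small self-contained check of that identity in plain PyTorch (not the FlashInfer wrappers):

```python
import math
import torch

torch.manual_seed(0)
d = 64
q = torch.randn(1, d)
k = torch.randn(10, d)
v = torch.randn(10, d)

def attn_with_lse(q, k, v):
    s = (q @ k.T) / math.sqrt(d)          # attention scores, [1, n_keys]
    return torch.softmax(s, dim=-1) @ v, torch.logsumexp(s, dim=-1)

o_full, _ = attn_with_lse(q, k, v)        # attention over all keys at once
o1, s1 = attn_with_lse(q, k[:6], v[:6])   # "prefix" partial result
o2, s2 = attn_with_lse(q, k[6:], v[6:])   # "extend" partial result

s = torch.logaddexp(s1, s2)
o_merged = torch.exp(s1 - s)[:, None] * o1 + torch.exp(s2 - s)[:, None] * o2
assert torch.allclose(o_full, o_merged, atol=1e-5)
```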
sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -18,8 +18,9 @@ import torch
 import triton

 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import
+    import logging

+    torch._logging.set_logs(dynamo=logging.ERROR)
     torch._dynamo.config.suppress_errors = True

 from sglang.global_config import global_config

@@ -338,23 +339,39 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):

         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

         # Save kv cache
         if save_kv_cache and k is not None:
             assert v is not None
             if save_kv_cache:
-
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer, cache_loc, k, k_rope
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if q_rope is not None:
+            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )

         if self.forward_metadata.use_ragged:
             # ragged prefill
-
+            if q_rope is not None:
+                q = torch.cat([q, q_rope], dim=-1)
+            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            if k_rope is not None:
+                k = torch.cat([k, k_rope], dim=-1)
+            o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
@@ -364,11 +381,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             )
         else:
             # mla paged prefill
+            if q_rope is None:
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                q, q_rope = (
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                )
+            o = q.new_empty(q.shape)
             o = prefill_wrapper_paged.run(
-
-
+                q,
+                q_rope,
                 k_buf[:, :, : layer.v_head_dim],
                 k_buf[:, :, layer.v_head_dim :],
+                out=o,
             )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

@@ -381,6 +406,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         decode_wrapper = self.forward_metadata.decode_wrapper
         cache_loc = forward_batch.out_cache_loc

@@ -388,23 +416,42 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-
-
-
-
-
-
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )

         # Reshape inputs
-
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
+
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

+        o = q_nope.new_empty(q_nope.shape)
         # Direct call to run without the wrapper
         o = decode_wrapper.run(
-
-
+            q_nope,
+            q_rope,
             k_buffer[:, :, : layer.v_head_dim],
             k_buffer[:, :, layer.v_head_dim :],
+            out=o,
         )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
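The decode path goes the other way: when only a packed `q` is passed, it is split at `layer.v_head_dim` into the no-RoPE and RoPE components that `decode_wrapper.run` consumes separately, and the output buffer is preallocated with the no-RoPE shape. A minimal sketch of that split, again with assumed sizes:

```python
import torch

tp_q_head_num, v_head_dim, rope_dim = 16, 512, 64   # assumed sizes
num_tokens = 4

q = torch.randn(num_tokens, tp_q_head_num * (v_head_dim + rope_dim))  # packed layout

reshaped_q = q.view(-1, tp_q_head_num, v_head_dim + rope_dim)
q_nope = reshaped_q[:, :, :v_head_dim]   # latent ("nope") component
q_rope = reshaped_q[:, :, v_head_dim:]   # rotary component
o = q_nope.new_empty(q_nope.shape)       # output head size is v_head_dim

assert q_nope.shape == (num_tokens, tp_q_head_num, v_head_dim)
assert q_rope.shape == (num_tokens, tp_q_head_num, rope_dim)
```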
sglang/srt/layers/attention/merge_state.py (new file)

@@ -0,0 +1,46 @@
+from typing import Optional, Tuple
+
+import torch
+from sgl_kernel import merge_state_v2
+
+from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+
+# Automatically fallback to the Triton kernel in some cases
+# (e.g., for AMD GPUs, when the head dimension is not a multiple
+# of 4 or 8, and in FP8 precision)
+def _supported_dtypes(o: torch.Tensor) -> bool:
+    return o.dtype in [torch.float32, torch.half, torch.bfloat16]
+
+
+def _supported_headdim(o: torch.Tensor) -> bool:
+    headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    if o.dtype == torch.float32:
+        return headdim % 4 == 0
+    return headdim % 8 == 0
+
+
+def merge_state(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if (
+        _is_cuda
+        and _supported_dtypes(prefix_output)
+        and _supported_headdim(prefix_output)
+    ):
+        return merge_state_v2(
+            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+        )
+    else:
+        # Fallback to Triton kernel
+        return merge_state_triton(
+            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+        )
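A note on the dispatch in the new module: `merge_state` prefers the fused `merge_state_v2` CUDA kernel and falls back to the Triton implementation when the inputs are outside what the CUDA kernel handles (non-CUDA platforms such as AMD, FP8 outputs, or head sizes that are not a multiple of 4 for fp32 / 8 for fp16 and bf16, per the helpers above). The gating can be exercised in isolation; this mirrors the two predicates without importing sgl_kernel:

```python
import torch

def would_use_cuda_kernel(o: torch.Tensor, on_cuda: bool = True) -> bool:
    # Re-statement of _supported_dtypes/_supported_headdim for illustration.
    if not on_cuda or o.dtype not in (torch.float32, torch.half, torch.bfloat16):
        return False
    headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    return headdim % (4 if o.dtype == torch.float32 else 8) == 0

print(would_use_cuda_kernel(torch.empty(16, 32, 128, dtype=torch.bfloat16)))  # True
print(would_use_cuda_kernel(torch.empty(16, 32, 96, dtype=torch.float32)))    # True
print(would_use_cuda_kernel(torch.empty(16, 32, 100, dtype=torch.bfloat16)))  # False -> Triton fallback
```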
sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py

@@ -919,7 +919,7 @@ def _fwd_kernel(

     e_max = n_e_max

-    # stage 2: compute the
+    # stage 2: compute the triangle part

     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):