sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

@@ -11,7 +11,6 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -394,7 +393,6 @@ class FlashAttentionBackend(AttentionBackend):
                 dtype=torch.int32,
             )
             metadata_expand.max_seq_len_q = 1
-            metadata_expand.max_seq_len_k = self.speculative_step_id + 1
             metadata_expand.cu_seqlens_q = torch.arange(
                 0,
                 metadata_expand.cache_seqlens_int32.numel() + 1,
@@ -408,9 +406,10 @@ class FlashAttentionBackend(AttentionBackend):
                 dtype=torch.int32,
                 device=device,
             )
+            # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
             cache_loc = forward_batch.out_cache_loc.view(
-                self.speculative_num_steps
-            )
+                -1, self.speculative_num_steps
+            )
             metadata_expand.page_table = (
                 cache_loc[:, :decode_length].contiguous().to(torch.int32)
             )
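The comment added in this hunk describes a layout change that is easy to misread, so here is a toy illustration (hypothetical sizes, standalone tensors, not sglang code) of how `view(-1, self.speculative_num_steps)` regroups the flat `out_cache_loc` buffer so that each row holds the per-step cache slots of one (request, top-k branch) pair:

```python
import torch

speculative_num_steps, bs_x_topk = 3, 4  # hypothetical sizes
out_cache_loc = torch.arange(bs_x_topk * speculative_num_steps)

# Each row now holds the speculative_num_steps slots of one (request, branch).
cache_loc = out_cache_loc.view(-1, speculative_num_steps)
print(cache_loc.shape)  # torch.Size([4, 3])
print(cache_loc[0])     # tensor([0, 1, 2])

decode_length = 2  # plays the role of speculative_step_id + 1 in the hunk
page_table = cache_loc[:, :decode_length].contiguous().to(torch.int32)
print(page_table.shape)  # torch.Size([4, 2])
```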
@@ -550,9 +549,6 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
                 (1, 0),
             )
-            metadata_expand.max_seq_len_k = (
-                metadata_expand.cache_seqlens_int32.max().item()
-            )
             self.forward_metadata_spec_decode_expand = metadata_expand
         elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
@@ -1421,9 +1417,6 @@ class FlashAttentionBackend(AttentionBackend):
                 ]
             )
             metadata_expand.max_seq_len_q = 1
-            metadata_expand.max_seq_len_k = (
-                self.speculative_step_id + 1
-            )  # , do this in replay
             metadata_expand.cu_seqlens_q = (
                 self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
                     : bs * self.topk + 1
@@ -1469,7 +1462,7 @@ class FlashAttentionBackend(AttentionBackend):
                 "cache_seqlens"
             ][:bs]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens)
+                (seq_lens + self.speculative_num_draft_tokens)
             )
 
             metadata.max_seq_len_q = self.speculative_num_draft_tokens
@@ -1536,7 +1529,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                 :bs
             ]
-            metadata.cache_seqlens_int32.copy_(seq_lens
+            metadata.cache_seqlens_int32.copy_(seq_lens)
 
             num_tokens_per_bs = num_tokens // bs
             metadata.max_seq_len_q = num_tokens_per_bs
@@ -1600,38 +1593,32 @@ class FlashAttentionBackend(AttentionBackend):
         if spec_info is not None:
             # Draft Decode
             if self.topk <= 1:
-                metadata = self.decode_cuda_graph_metadata[bs]
                 # When topk = 1, we use the normal decode metadata
-                metadata.
-
-
-
-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-                    self.speculative_step_id + 1
-                )
-                metadata.cu_seqlens_k[1:].copy_(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    )
-                )
-
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
                 max_seq_pages = (
                     metadata.max_seq_len_k + self.page_size - 1
                 ) // self.page_size
-                page_indices = self.req_to_token[
-                    req_pool_indices[:, None],
-                    self.decode_cuda_graph_metadata["strided_indices"][
-                        :max_seq_pages
-                    ],
-                ]
 
-
-
+                normal_decode_set_medadata(
+                    metadata.cache_seqlens_int32,
+                    metadata.cu_seqlens_k,
+                    metadata.page_table,
+                    self.req_to_token,
+                    req_pool_indices,
+                    self.decode_cuda_graph_metadata["strided_indices"],
+                    max_seq_pages,
+                    seq_lens,
+                    self.speculative_step_id + 1,
+                    self.page_size,
+                )
+
             else:
                 # When top k > 1, we need two specific draft decode metadata, and then merge states
                 # 1. The first half of metadata for prefix tokens
                 metadata = self.draft_decode_metadata_topk_normal[bs]
-                metadata.cache_seqlens_int32.copy_(seq_lens
+                metadata.cache_seqlens_int32.copy_(seq_lens)
                 # metadata.max_seq_len_q = self.topk, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
@@ -1650,11 +1637,10 @@ class FlashAttentionBackend(AttentionBackend):
                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand = self.draft_decode_metadata_topk_expand[bs]
                 decode_length = self.speculative_step_id + 1
-
-
-                ).T.contiguous()
+                # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
+                cache_loc = out_cache_loc.view(-1, self.speculative_num_steps)
                 metadata_expand.page_table[: cache_loc.shape[0]].copy_(
-                    cache_loc[:, :decode_length]
+                    cache_loc[:, :decode_length]
                 )
                 # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
             else:
@@ -1665,12 +1651,15 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = max_len
 
             normal_decode_set_medadata(
-                metadata,
+                metadata.cache_seqlens_int32,
+                metadata.cu_seqlens_k,
+                metadata.page_table,
                 self.req_to_token,
                 req_pool_indices,
                 self.decode_cuda_graph_metadata["strided_indices"],
                 max_seq_pages,
                 seq_lens,
+                0,
                 self.page_size,
             )
 
@@ -1679,7 +1668,7 @@ class FlashAttentionBackend(AttentionBackend):
         if self.topk <= 1:
             metadata = self.target_verify_metadata[bs]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens)
+                (seq_lens + self.speculative_num_draft_tokens)
             )
 
             metadata.max_seq_len_k = (
@@ -1701,7 +1690,7 @@ class FlashAttentionBackend(AttentionBackend):
             # When topk > 1, we need two specific target verify metadata, and then merge states
             # 1. The first half of metadata for prefix tokens
             metadata = self.target_verify_metadata_topk_normal[bs]
-            metadata.cache_seqlens_int32.copy_(seq_lens
+            metadata.cache_seqlens_int32.copy_(seq_lens)
             # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
             metadata.max_seq_len_k = seq_lens_cpu.max().item()
             # metadata.cu_seqlens_q already set in capture
@@ -1761,9 +1750,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata_expand.page_table.copy_(
                 non_masked_page_table.gather(1, sort_order)
             )
-            metadata_expand.cache_seqlens_int32.copy_(
-                mask.sum(dim=1).to(torch.int32)
-            )
+            metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))
             metadata_expand.cu_seqlens_k[1:].copy_(
                 torch.cumsum(
                     metadata_expand.cache_seqlens_int32,
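The simplification above works because `Tensor.copy_` performs dtype conversion on its own, so the explicit `.to(torch.int32)` was redundant. A quick self-contained check:

```python
import torch

mask = torch.tensor([[True, False, True], [True, True, True]])
cache_seqlens_int32 = torch.empty(2, dtype=torch.int32)

# mask.sum(dim=1) yields int64; copy_ casts it into the int32 buffer.
cache_seqlens_int32.copy_(mask.sum(dim=1))
print(cache_seqlens_int32, cache_seqlens_int32.dtype)
# tensor([2, 3], dtype=torch.int32) torch.int32
```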
@@ -1771,19 +1758,16 @@ class FlashAttentionBackend(AttentionBackend):
                     dtype=torch.int32,
                 )
             )
-            metadata_expand.max_seq_len_k = (
-                metadata_expand.cache_seqlens_int32.max().item()
-            )
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
-            metadata.cache_seqlens_int32.copy_(seq_lens
+            metadata.cache_seqlens_int32.copy_(seq_lens)
 
             metadata.max_seq_len_k = seq_lens_cpu.max().item()
             metadata.cu_seqlens_k[1:].copy_(
                 torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
             )
             accept_length = spec_info.accept_length[:bs]
-            metadata.max_seq_len_q =
+            metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
             metadata.cu_seqlens_q[1:].copy_(
                 torch.cumsum(accept_length, dim=0, dtype=torch.int32)
             )
@@ -1795,8 +1779,7 @@ class FlashAttentionBackend(AttentionBackend):
                 req_pool_indices[:, None],
                 self.draft_extend_metadata["strided_indices"][:max_seq_pages],
             ]
-            page_indices
-            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
 
         if encoder_lens is not None:
             # Only support encoder size 1 for now
@@ -2045,6 +2028,8 @@ class FlashAttentionMultiStepBackend:
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
 
         for i in range(self.speculative_num_steps - 1):
+            # TODO: incrementally update the metadata for the later steps,
+            # so that they do not need to recompute everything from scratch.
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
                 forward_batch.req_pool_indices,
@@ -2058,21 +2043,25 @@ class FlashAttentionMultiStepBackend:
     )
 
 
-@torch.compile(dynamic=True, backend=get_compiler_backend())
+# @torch.compile(dynamic=True, backend=get_compiler_backend())
+# TODO: fuse these kernels
+# NOTE: torch.compile makes it slower in speculative decoding
 def normal_decode_set_medadata(
-
-
-
-
-
-
-
+    cache_seqlens_int32: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    page_table: torch.Tensor,
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    strided_indices: torch.Tensor,
+    max_seq_pages: torch.Tensor,
+    seq_lens: torch.Tensor,
+    seq_len_delta: int,
+    page_size: int,
 ):
-
-
+    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
+    cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
     page_indices = req_to_token[
         req_pool_indices[:, None],
         strided_indices[:max_seq_pages][None, :],
     ]
-
-    metadata.page_table[:, max_seq_pages:].fill_(0)
+    page_table[:, :max_seq_pages].copy_(page_indices // page_size)
sglang/srt/layers/attention/flashinfer_backend.py

@@ -1049,14 +1049,13 @@ class FlashInferMultiStepDraftBackend:
             kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
-            num_seqs,
-            self.topk,
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
             next_power_of_2(num_seqs),
             next_power_of_2(self.speculative_num_steps),
             next_power_of_2(bs),
+            self.page_size,
         )
 
         assert forward_batch.spec_info is not None
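The `next_power_of_2(...)` arguments exist because Triton kernels size `tl.arange` blocks with compile-time powers of two, so runtime counts are padded up. A sketch of what such a helper computes (assuming sglang's `next_power_of_2` matches the usual `triton.next_power_of_2` semantics):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n; 1 for n <= 1.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

assert [next_power_of_2(n) for n in (1, 2, 3, 5, 8, 9)] == [1, 2, 4, 8, 8, 16]
```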
sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -15,7 +15,6 @@ from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
-import triton
 
 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
     import logging
@@ -33,7 +32,7 @@ from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_flashinfer_available, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -756,7 +755,7 @@ class FlashInferMLAMultiStepDraftBackend:
 
         if topk > 1:
             raise ValueError(
-
+                "Currently Flashinfer MLA only supports topk=1 for speculative decoding"
             )
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
@@ -790,6 +789,7 @@ class FlashInferMLAMultiStepDraftBackend:
 
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size
 
     def common_template(
         self,
@@ -810,14 +810,13 @@ class FlashInferMLAMultiStepDraftBackend:
             kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
-            num_seqs,
-            self.topk,
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
-
-
-
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
+            next_power_of_2(bs),
+            self.page_size,
         )
 
         assert forward_batch.spec_info is not None
@@ -920,19 +919,18 @@ def fast_mla_decode_plan(
     self._page_size = page_size
     self._sm_scale = sm_scale
 
-
-
-
-    self.
-
-
-
-
-
-
-
-
-
-    raise RuntimeError(f"Error in alternate MLA plan: {e}")
+    try:
+        # Standard version with just the required arguments (no use_profiler)
+        self._cached_module.plan.default(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_cpu,
+            kv_indptr_cpu,
+            kv_len_arr_cpu,
+            num_heads,
+            head_dim_ckv,
+            causal,
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in alternate MLA plan: {e}")
sglang/srt/layers/attention/flashmla_backend.py

@@ -2,9 +2,6 @@ from __future__ import annotations
 
 """
 Support attention backend for FlashMLA.
-
-#TODO
-Enable speculative sampling in FlashMLA
 """
 
 from dataclasses import dataclass
@@ -14,8 +11,6 @@ import torch
 import triton
 from flash_mla import flash_mla_with_kvcache, get_mla_metadata
 
-from sglang.global_config import global_config
-from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -24,7 +19,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpecInfo
 
 
@@ -330,7 +324,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         )
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return
+        return 1
 
     def forward_decode(
         self,
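`get_cuda_graph_seq_len_fill_value` supplies the sequence length written into padded batch slots during CUDA graph capture; with this change it returns 1, a harmless non-zero length for dummy requests (the old return value is truncated in this extraction). A toy sketch of how such a fill value is typically used (hypothetical sizes, not sglang code):

```python
import torch

fill_value = 1  # what get_cuda_graph_seq_len_fill_value() returns after this diff
max_capture_bs, real_bs = 8, 3

# Padded slots keep a benign length so captured kernels never see an empty seq.
seq_lens = torch.full((max_capture_bs,), fill_value, dtype=torch.int32)
seq_lens[:real_bs] = torch.tensor([17, 5, 9], dtype=torch.int32)
print(seq_lens)  # tensor([17,  5,  9,  1,  1,  1,  1,  1], dtype=torch.int32)
```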
@@ -464,11 +458,9 @@ class FlashMLAMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
-
         if topk > 1:
             raise ValueError(
-
+                "Currently FlashMLA only supports topk=1 for speculative decoding"
             )
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
sglang/srt/layers/attention/triton_backend.py

@@ -12,7 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import get_bool_env_var, get_device_core_count
+from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -20,117 +20,6 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
-@triton.jit
-def get_num_kv_splits_triton(
-    num_kv_splits_ptr,
-    seq_lens_ptr,
-    num_seq,
-    num_group,
-    num_head,
-    num_kv_head,
-    max_kv_splits,
-    device_core_count,
-    MAX_NUM_SEQ: tl.constexpr,
-):
-    # TODO: this method is tunable, we need more online serving data to tune it
-    offs_seq = tl.arange(0, MAX_NUM_SEQ)
-    mask_seq = offs_seq < num_seq
-
-    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
-    max_seq_len = tl.max(seq_lens)
-    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
-    min_seq_len = tl.min(seq_lens)
-    if max_seq_len * 8 < min_seq_len * 10:
-        min_seq_len = max_seq_len
-    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
-    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
-
-    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
-    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
-    ext_device_core_count = tl.cast(
-        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
-    )
-    block_h, num_kv_group = 16, num_head // num_kv_head
-    if num_kv_group == 1:
-        token_grid = num_seq * num_group * num_head
-    else:
-        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
-        block_h = tl.minimum(block_h, num_kv_group)
-        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
-    max_kv_splits_2 = tl.minimum(
-        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
-    )
-    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
-
-    num_kv_splits = tl.maximum(
-        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
-    )
-
-    offs_token = offs_seq * num_group
-    mask_token = offs_token < num_seq * num_group
-    for i in range(0, num_group):
-        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
-
-
-def update_sliding_window_buffer(
-    window_kv_indptr,
-    req_to_token,
-    sliding_window_size,
-    seq_lens,
-    req_pool_indices,
-    bs,
-    device,
-):
-    window_kv_lens = torch.minimum(
-        seq_lens,
-        torch.tensor(sliding_window_size + 1),
-    )
-    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
-    window_kv_indptr = window_kv_indptr[: bs + 1]
-    window_kv_indices = torch.empty(
-        window_kv_indptr[-1], dtype=torch.int32, device=device
-    )
-    window_kv_start_idx = seq_lens - window_kv_lens
-    create_flashinfer_kv_indices_triton[(bs,)](
-        req_to_token,
-        req_pool_indices,
-        window_kv_lens,
-        window_kv_indptr,
-        window_kv_start_idx,
-        window_kv_indices,
-        req_to_token.stride(0),
-    )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
-
-
-def update_sliding_window_buffer_cuda_graph(
-    window_kv_indptr,
-    window_kv_indices,
-    req_to_token,
-    sliding_window_size,
-    seq_lens,
-    req_pool_indices,
-    bs,
-):
-    window_kv_lens = torch.minimum(
-        seq_lens,
-        torch.tensor(sliding_window_size + 1),
-    )
-    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
-    window_kv_indptr = window_kv_indptr[: bs + 1]
-    window_kv_start_idx = seq_lens - window_kv_lens
-    create_flashinfer_kv_indices_triton[(bs,)](
-        req_to_token,
-        req_pool_indices,
-        window_kv_lens,
-        window_kv_indptr,
-        window_kv_start_idx,
-        window_kv_indices,
-        req_to_token.stride(0),
-    )
-    return window_kv_indptr, window_kv_lens
-
-
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor
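For intuition about the relocated `get_num_kv_splits_triton` (it reappears at the end of this file; see the last hunk below): the split count per sequence is capped both by the length imbalance within the batch and by how many thread blocks the device can keep busy. A scalar plain-Python rendering of the same heuristic, with hypothetical inputs (a `max(..., 1)` guard is added that the Triton code does not have):

```python
import math

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

def num_kv_splits_sketch(seq_lens, num_head, num_kv_head,
                         max_kv_splits, device_core_count, num_group=1):
    max_len, min_len = max(seq_lens), min(seq_lens)
    if max_len * 8 < min_len * 10:  # lengths within ~25%: treat as uniform
        min_len = max_len
    kv_chunk_1 = cdiv(max_len, min(cdiv(max_len, min_len), max_kv_splits))

    # Let the split count grow with seq_len so long sequences use more CTAs.
    ext_cores = int(device_core_count * max(math.log2(max_len / 64.0), 1.0))
    num_kv_group = num_head // num_kv_head
    if num_kv_group == 1:
        token_grid = len(seq_lens) * num_group * num_head
    else:
        token_grid = len(seq_lens) * num_group * cdiv(num_head, min(16, num_kv_group))
    kv_chunk_2 = cdiv(max_len, min(max(ext_cores // token_grid, 1), max_kv_splits))

    return [max(cdiv(s, kv_chunk_1), cdiv(s, kv_chunk_2)) for s in seq_lens]

print(num_kv_splits_sketch([512, 4096], num_head=32, num_kv_head=8,
                           max_kv_splits=16, device_core_count=132))  # [2, 16]
```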
@@ -165,8 +54,8 @@ class TritonAttnBackend(AttentionBackend):
 
         super().__init__()
 
-        self.decode_attention_fwd = decode_attention_fwd
-        self.extend_attention_fwd = extend_attention_fwd
+        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
 
         self.skip_prefill = skip_prefill
 
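`torch.compiler.disable` marks a callable as a graph break, so `torch.compile` regions trace around it and call it eagerly; wrapping the Triton launchers this way keeps Dynamo from trying to trace kernel launches. A minimal self-contained sketch (stand-in function names):

```python
import torch

@torch.compiler.disable
def opaque_kernel_launcher(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for decode_attention_fwd / extend_attention_fwd:
    # Dynamo does not trace into this function.
    return x * 2

@torch.compile
def model_step(x: torch.Tensor) -> torch.Tensor:
    y = x + 1                         # traced and compiled
    return opaque_kernel_launcher(y)  # graph break; runs eagerly

print(model_step(torch.ones(4)))  # tensor([4., 4., 4., 4.])
```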
@@ -877,6 +766,7 @@ class TritonMultiStepDraftBackend:
         self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size
 
     def common_template(
         self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
@@ -894,14 +784,13 @@ class TritonMultiStepDraftBackend:
             kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
-            num_seqs,
-            self.topk,
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
-
-
-
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
+            next_power_of_2(bs),
+            self.page_size,
         )
 
         for i in range(self.speculative_num_steps):
@@ -973,3 +862,114 @@ class TritonMultiStepDraftBackend:
         )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+
+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    num_seq,
+    num_group,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_NUM_SEQ: tl.constexpr,
+):
+    # TODO: this method is tunable, we need more online serving data to tune it
+    offs_seq = tl.arange(0, MAX_NUM_SEQ)
+    mask_seq = offs_seq < num_seq
+
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+    ext_device_core_count = tl.cast(
+        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        token_grid = num_seq * num_group * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(
+        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+    )
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+
+    offs_token = offs_seq * num_group
+    mask_token = offs_token < num_seq * num_group
+    for i in range(0, num_group):
+        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+def update_sliding_window_buffer(
+    window_kv_indptr,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+    device,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_indices = torch.empty(
+        window_kv_indptr[-1], dtype=torch.int32, device=device
+    )
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
+
+
+def update_sliding_window_buffer_cuda_graph(
+    window_kv_indptr,
+    window_kv_indices,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_lens
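A small worked example of the sliding-window bookkeeping these relocated helpers perform: each request contributes at most `sliding_window_size + 1` KV entries, `window_kv_indptr` is the running sum of those clamped lengths, and `window_kv_start_idx` marks where each request's window begins (plain tensors with hypothetical sizes; the real helpers fill `window_kv_indices` via a Triton kernel):

```python
import torch

sliding_window_size = 4
seq_lens = torch.tensor([3, 9, 6])
bs = seq_lens.numel()

# Clamp each request's KV length to the window (+1 for the current token).
window_kv_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
window_kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
window_kv_indptr[1:] = torch.cumsum(window_kv_lens, dim=0)

# Oldest KV position still inside each request's window.
window_kv_start_idx = seq_lens - window_kv_lens
print(window_kv_lens)       # tensor([3, 5, 5])
print(window_kv_indptr)     # tensor([ 0,  3,  8, 13])
print(window_kv_start_idx)  # tensor([0, 4, 1])
```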