sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,6 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
- from sglang.srt.utils import get_compiler_backend
 
  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -394,7 +393,6 @@ class FlashAttentionBackend(AttentionBackend):
  dtype=torch.int32,
  )
  metadata_expand.max_seq_len_q = 1
- metadata_expand.max_seq_len_k = self.speculative_step_id + 1
  metadata_expand.cu_seqlens_q = torch.arange(
  0,
  metadata_expand.cache_seqlens_int32.numel() + 1,
@@ -408,9 +406,10 @@ class FlashAttentionBackend(AttentionBackend):
  dtype=torch.int32,
  device=device,
  )
+ # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
  cache_loc = forward_batch.out_cache_loc.view(
- self.speculative_num_steps, -1
- ).T.contiguous()
+ -1, self.speculative_num_steps
+ )
  metadata_expand.page_table = (
  cache_loc[:, :decode_length].contiguous().to(torch.int32)
  )
@@ -550,9 +549,6 @@ class FlashAttentionBackend(AttentionBackend):
  ),
  (1, 0),
  )
- metadata_expand.max_seq_len_k = (
- metadata_expand.cache_seqlens_int32.max().item()
- )
  self.forward_metadata_spec_decode_expand = metadata_expand
  elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
  metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
@@ -1421,9 +1417,6 @@ class FlashAttentionBackend(AttentionBackend):
  ]
  )
  metadata_expand.max_seq_len_q = 1
- metadata_expand.max_seq_len_k = (
- self.speculative_step_id + 1
- ) # , do this in replay
  metadata_expand.cu_seqlens_q = (
  self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
  : bs * self.topk + 1
@@ -1469,7 +1462,7 @@ class FlashAttentionBackend(AttentionBackend):
  "cache_seqlens"
  ][:bs]
  metadata.cache_seqlens_int32.copy_(
- (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+ (seq_lens + self.speculative_num_draft_tokens)
  )
 
  metadata.max_seq_len_q = self.speculative_num_draft_tokens
@@ -1536,7 +1529,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
  :bs
  ]
- metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+ metadata.cache_seqlens_int32.copy_(seq_lens)
 
  num_tokens_per_bs = num_tokens // bs
  metadata.max_seq_len_q = num_tokens_per_bs
@@ -1600,38 +1593,32 @@ class FlashAttentionBackend(AttentionBackend):
  if spec_info is not None:
  # Draft Decode
  if self.topk <= 1:
- metadata = self.decode_cuda_graph_metadata[bs]
  # When topk = 1, we use the normal decode metadata
- metadata.cache_seqlens_int32.copy_(
- (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
- )
-
- metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
- self.speculative_step_id + 1
- )
- metadata.cu_seqlens_k[1:].copy_(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- )
- )
-
+ metadata = self.decode_cuda_graph_metadata[bs]
+ max_len = seq_lens_cpu.max().item()
+ metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
  max_seq_pages = (
  metadata.max_seq_len_k + self.page_size - 1
  ) // self.page_size
- page_indices = self.req_to_token[
- req_pool_indices[:, None],
- self.decode_cuda_graph_metadata["strided_indices"][
- :max_seq_pages
- ],
- ]
 
- page_indices //= self.page_size
- metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+ normal_decode_set_medadata(
+ metadata.cache_seqlens_int32,
+ metadata.cu_seqlens_k,
+ metadata.page_table,
+ self.req_to_token,
+ req_pool_indices,
+ self.decode_cuda_graph_metadata["strided_indices"],
+ max_seq_pages,
+ seq_lens,
+ self.speculative_step_id + 1,
+ self.page_size,
+ )
+
  else:
  # When top k > 1, we need two specific draft decode metadata, and then merge states
  # 1. The first half of metadata for prefix tokens
  metadata = self.draft_decode_metadata_topk_normal[bs]
- metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+ metadata.cache_seqlens_int32.copy_(seq_lens)
  # metadata.max_seq_len_q = self.topk, already set in capture
  metadata.max_seq_len_k = seq_lens_cpu.max().item()
  # metadata.cu_seqlens_q already set in capture
@@ -1650,11 +1637,10 @@ class FlashAttentionBackend(AttentionBackend):
  # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
  metadata_expand = self.draft_decode_metadata_topk_expand[bs]
  decode_length = self.speculative_step_id + 1
- cache_loc = out_cache_loc.view(
- self.speculative_num_steps, -1
- ).T.contiguous()
+ # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
+ cache_loc = out_cache_loc.view(-1, self.speculative_num_steps)
  metadata_expand.page_table[: cache_loc.shape[0]].copy_(
- cache_loc[:, :decode_length].contiguous().to(torch.int32)
+ cache_loc[:, :decode_length]
  )
  # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
  else:
@@ -1665,12 +1651,15 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.max_seq_len_k = max_len
 
  normal_decode_set_medadata(
- metadata,
+ metadata.cache_seqlens_int32,
+ metadata.cu_seqlens_k,
+ metadata.page_table,
  self.req_to_token,
  req_pool_indices,
  self.decode_cuda_graph_metadata["strided_indices"],
  max_seq_pages,
  seq_lens,
+ 0,
  self.page_size,
  )
 
@@ -1679,7 +1668,7 @@ class FlashAttentionBackend(AttentionBackend):
  if self.topk <= 1:
  metadata = self.target_verify_metadata[bs]
  metadata.cache_seqlens_int32.copy_(
- (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+ (seq_lens + self.speculative_num_draft_tokens)
  )
 
  metadata.max_seq_len_k = (
@@ -1701,7 +1690,7 @@ class FlashAttentionBackend(AttentionBackend):
  # When topk > 1, we need two specific target verify metadata, and then merge states
  # 1. The first half of metadata for prefix tokens
  metadata = self.target_verify_metadata_topk_normal[bs]
- metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+ metadata.cache_seqlens_int32.copy_(seq_lens)
  # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
  metadata.max_seq_len_k = seq_lens_cpu.max().item()
  # metadata.cu_seqlens_q already set in capture
@@ -1761,9 +1750,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata_expand.page_table.copy_(
  non_masked_page_table.gather(1, sort_order)
  )
- metadata_expand.cache_seqlens_int32.copy_(
- mask.sum(dim=1).to(torch.int32)
- )
+ metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))
  metadata_expand.cu_seqlens_k[1:].copy_(
  torch.cumsum(
  metadata_expand.cache_seqlens_int32,
@@ -1771,19 +1758,16 @@ class FlashAttentionBackend(AttentionBackend):
  dtype=torch.int32,
  )
  )
- metadata_expand.max_seq_len_k = (
- metadata_expand.cache_seqlens_int32.max().item()
- )
  elif forward_mode.is_draft_extend():
  metadata = self.draft_extend_metadata[bs]
- metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+ metadata.cache_seqlens_int32.copy_(seq_lens)
 
  metadata.max_seq_len_k = seq_lens_cpu.max().item()
  metadata.cu_seqlens_k[1:].copy_(
  torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
  )
  accept_length = spec_info.accept_length[:bs]
- metadata.max_seq_len_q = accept_length.max().item()
+ metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
  metadata.cu_seqlens_q[1:].copy_(
  torch.cumsum(accept_length, dim=0, dtype=torch.int32)
  )
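Several of the hunks above rebuild `cu_seqlens_k` in place as an exclusive prefix sum over the int32 per-request KV lengths. A toy illustration of that buffer layout (values are made up, not taken from the code):

```python
import torch

cache_seqlens_int32 = torch.tensor([6, 4, 5], dtype=torch.int32)
cu_seqlens_k = torch.zeros(4, dtype=torch.int32)  # leading 0 stays fixed
cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
print(cu_seqlens_k)  # tensor([ 0,  6, 10, 15], dtype=torch.int32)
```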
@@ -1795,8 +1779,7 @@ class FlashAttentionBackend(AttentionBackend):
  req_pool_indices[:, None],
  self.draft_extend_metadata["strided_indices"][:max_seq_pages],
  ]
- page_indices //= self.page_size
- metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+ metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
 
  if encoder_lens is not None:
  # Only support encoder size 1 for now
@@ -2045,6 +2028,8 @@ class FlashAttentionMultiStepBackend:
  assert isinstance(forward_batch.spec_info, EagleDraftInput)
 
  for i in range(self.speculative_num_steps - 1):
+ # TODO: incrementally update the metadata for the later steps,
+ # so that they do not need to recompute everything from scratch.
  self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
  bs,
  forward_batch.req_pool_indices,
@@ -2058,21 +2043,25 @@ class FlashAttentionBackend(AttentionBackend):
  )
 
 
- @torch.compile(dynamic=True, backend=get_compiler_backend())
+ # @torch.compile(dynamic=True, backend=get_compiler_backend())
+ # TODO: fuse these kernels
+ # NOTE: torch.compile makes it slower in speculative decoding
  def normal_decode_set_medadata(
- metadata,
- req_to_token,
- req_pool_indices,
- strided_indices,
- max_seq_pages,
- seq_lens,
- page_size,
+ cache_seqlens_int32: torch.Tensor,
+ cu_seqlens_k: torch.Tensor,
+ page_table: torch.Tensor,
+ req_to_token: torch.Tensor,
+ req_pool_indices: torch.Tensor,
+ strided_indices: torch.Tensor,
+ max_seq_pages: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_len_delta: int,
+ page_size: int,
  ):
- metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
- metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
+ cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
+ cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
  page_indices = req_to_token[
  req_pool_indices[:, None],
  strided_indices[:max_seq_pages][None, :],
  ]
- metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
- metadata.page_table[:, max_seq_pages:].fill_(0)
+ page_table[:, :max_seq_pages].copy_(page_indices // page_size)
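`normal_decode_set_medadata` now takes the metadata buffers explicitly plus a `seq_len_delta`, so the draft-decode path (delta = `speculative_step_id + 1`) and the normal decode path (delta = `0`) can share one helper. Below is a minimal standalone sketch of the new helper on toy CPU tensors; the sizes and values are made up for illustration, and the function body is copied from the hunk above so the snippet is self-contained:

```python
import torch

# Copied from the hunk above (new signature) so this sketch runs on its own.
def normal_decode_set_medadata(
    cache_seqlens_int32, cu_seqlens_k, page_table,
    req_to_token, req_pool_indices, strided_indices,
    max_seq_pages, seq_lens, seq_len_delta, page_size,
):
    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
    cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
    page_indices = req_to_token[
        req_pool_indices[:, None],
        strided_indices[:max_seq_pages][None, :],
    ]
    page_table[:, :max_seq_pages].copy_(page_indices // page_size)

# Toy setup (illustrative sizes): 2 requests, page_size 2, 16 token slots per request.
bs, page_size, pool_len = 2, 2, 16
seq_lens = torch.tensor([5, 3])
seq_len_delta = 1  # e.g. speculative_step_id + 1 on the draft-decode path; 0 for normal decode
req_to_token = torch.arange(bs * pool_len).view(bs, pool_len)
req_pool_indices = torch.tensor([0, 1])
strided_indices = torch.arange(0, pool_len, page_size)  # first token slot of each page

max_seq_len_k = int((seq_lens + seq_len_delta).max())
max_seq_pages = (max_seq_len_k + page_size - 1) // page_size

# Pre-allocated buffers, as on the CUDA-graph replay path.
cache_seqlens_int32 = torch.zeros(bs, dtype=torch.int32)
cu_seqlens_k = torch.zeros(bs + 1, dtype=torch.int32)
page_table = torch.zeros(bs, pool_len // page_size, dtype=torch.int32)

normal_decode_set_medadata(
    cache_seqlens_int32, cu_seqlens_k, page_table,
    req_to_token, req_pool_indices, strided_indices,
    max_seq_pages, seq_lens, seq_len_delta, page_size,
)
print(cache_seqlens_int32)            # tensor([6, 4], dtype=torch.int32)
print(cu_seqlens_k)                   # tensor([ 0,  6, 10], dtype=torch.int32)
print(page_table[:, :max_seq_pages])  # tensor([[ 0,  1,  2], [ 8,  9, 10]], dtype=torch.int32)
```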
@@ -1049,14 +1049,13 @@ class FlashInferMultiStepDraftBackend:
  kv_indices_buffer,
  self.kv_indptr,
  forward_batch.positions,
- num_seqs,
- self.topk,
  self.pool_len,
  kv_indices_buffer.shape[1],
  self.kv_indptr.shape[1],
  next_power_of_2(num_seqs),
  next_power_of_2(self.speculative_num_steps),
  next_power_of_2(bs),
+ self.page_size,
  )
 
  assert forward_batch.spec_info is not None
@@ -15,7 +15,6 @@ from functools import partial
  from typing import TYPE_CHECKING, Callable, Optional, Union
 
  import torch
- import triton
 
  if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
  import logging
@@ -33,7 +32,7 @@ from sglang.srt.layers.utils import is_sm100_supported
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
- from sglang.srt.utils import is_flashinfer_available
+ from sglang.srt.utils import is_flashinfer_available, next_power_of_2
 
  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -756,7 +755,7 @@ class FlashInferMLAMultiStepDraftBackend:
 
  if topk > 1:
  raise ValueError(
- f"Currently Flashinfer MLA only supports topk=1 for speculative decoding"
+ "Currently Flashinfer MLA only supports topk=1 for speculative decoding"
  )
  self.topk = topk
  self.speculative_num_steps = speculative_num_steps
@@ -790,6 +789,7 @@ class FlashInferMLAMultiStepDraftBackend:
 
  # Cached variables for generate_draft_decode_kv_indices
  self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+ self.page_size = model_runner.server_args.page_size
 
  def common_template(
  self,
@@ -810,14 +810,13 @@ class FlashInferMLAMultiStepDraftBackend:
  kv_indices_buffer,
  self.kv_indptr,
  forward_batch.positions,
- num_seqs,
- self.topk,
  self.pool_len,
  kv_indices_buffer.shape[1],
  self.kv_indptr.shape[1],
- triton.next_power_of_2(num_seqs),
- triton.next_power_of_2(self.speculative_num_steps),
- triton.next_power_of_2(bs),
+ next_power_of_2(num_seqs),
+ next_power_of_2(self.speculative_num_steps),
+ next_power_of_2(bs),
+ self.page_size,
  )
 
  assert forward_batch.spec_info is not None
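These call sites now use `next_power_of_2` from `sglang.srt.utils` instead of `triton.next_power_of_2`, and the call gains a trailing `self.page_size` argument. For reference, a minimal sketch of the semantics such a rounding helper is expected to have (an illustration, not the sglang implementation):

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n (illustrative semantics)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

assert [next_power_of_2(x) for x in (1, 2, 3, 5, 8, 9)] == [1, 2, 4, 8, 8, 16]
```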
@@ -920,19 +919,18 @@ def fast_mla_decode_plan(
  self._page_size = page_size
  self._sm_scale = sm_scale
 
- with self.device as device:
- try:
- # Standard version with just the required arguments (no use_profiler)
- self._cached_module.plan.default(
- self._float_workspace_buffer,
- self._int_workspace_buffer,
- self._pin_memory_int_workspace_buffer,
- qo_indptr_cpu,
- kv_indptr_cpu,
- kv_len_arr_cpu,
- num_heads,
- head_dim_ckv,
- causal,
- )
- except Exception as e:
- raise RuntimeError(f"Error in alternate MLA plan: {e}")
+ try:
+ # Standard version with just the required arguments (no use_profiler)
+ self._cached_module.plan.default(
+ self._float_workspace_buffer,
+ self._int_workspace_buffer,
+ self._pin_memory_int_workspace_buffer,
+ qo_indptr_cpu,
+ kv_indptr_cpu,
+ kv_len_arr_cpu,
+ num_heads,
+ head_dim_ckv,
+ causal,
+ )
+ except Exception as e:
+ raise RuntimeError(f"Error in alternate MLA plan: {e}")
@@ -2,9 +2,6 @@ from __future__ import annotations
 
  """
  Support attention backend for FlashMLA.
-
- #TODO
- Enable speculative sampling in FlashMLA
  """
 
  from dataclasses import dataclass
@@ -14,8 +11,6 @@ import torch
  import triton
  from flash_mla import flash_mla_with_kvcache, get_mla_metadata
 
- from sglang.global_config import global_config
- from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
  from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
  from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -24,7 +19,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
  from sglang.srt.speculative.spec_info import SpecInfo
 
 
@@ -330,7 +324,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
  )
 
  def get_cuda_graph_seq_len_fill_value(self):
- return 1024
+ return 1
 
  def forward_decode(
  self,
@@ -464,11 +458,9 @@ class FlashMLAMultiStepDraftBackend:
  topk: int,
  speculative_num_steps: int,
  ):
- from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
-
  if topk > 1:
  raise ValueError(
- f"Currently FlashMLA only supports topk=1 for speculative decoding"
+ "Currently FlashMLA only supports topk=1 for speculative decoding"
  )
  self.topk = topk
  self.speculative_num_steps = speculative_num_steps
@@ -12,7 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
  from sglang.srt.layers.dp_attention import get_attention_tp_size
  from sglang.srt.layers.radix_attention import AttentionType
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
- from sglang.srt.utils import get_bool_env_var, get_device_core_count
+ from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
 
  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -20,117 +20,6 @@ if TYPE_CHECKING:
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
- @triton.jit
- def get_num_kv_splits_triton(
- num_kv_splits_ptr,
- seq_lens_ptr,
- num_seq,
- num_group,
- num_head,
- num_kv_head,
- max_kv_splits,
- device_core_count,
- MAX_NUM_SEQ: tl.constexpr,
- ):
- # TODO: this method is tunable, we need more online serving data to tune it
- offs_seq = tl.arange(0, MAX_NUM_SEQ)
- mask_seq = offs_seq < num_seq
-
- seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
- max_seq_len = tl.max(seq_lens)
- seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
- min_seq_len = tl.min(seq_lens)
- if max_seq_len * 8 < min_seq_len * 10:
- min_seq_len = max_seq_len
- max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
- kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
-
- # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
- ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
- ext_device_core_count = tl.cast(
- device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
- )
- block_h, num_kv_group = 16, num_head // num_kv_head
- if num_kv_group == 1:
- token_grid = num_seq * num_group * num_head
- else:
- # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
- block_h = tl.minimum(block_h, num_kv_group)
- token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
- max_kv_splits_2 = tl.minimum(
- tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
- )
- kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
-
- num_kv_splits = tl.maximum(
- tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
- )
-
- offs_token = offs_seq * num_group
- mask_token = offs_token < num_seq * num_group
- for i in range(0, num_group):
- tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
-
-
- def update_sliding_window_buffer(
- window_kv_indptr,
- req_to_token,
- sliding_window_size,
- seq_lens,
- req_pool_indices,
- bs,
- device,
- ):
- window_kv_lens = torch.minimum(
- seq_lens,
- torch.tensor(sliding_window_size + 1),
- )
- window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
- window_kv_indptr = window_kv_indptr[: bs + 1]
- window_kv_indices = torch.empty(
- window_kv_indptr[-1], dtype=torch.int32, device=device
- )
- window_kv_start_idx = seq_lens - window_kv_lens
- create_flashinfer_kv_indices_triton[(bs,)](
- req_to_token,
- req_pool_indices,
- window_kv_lens,
- window_kv_indptr,
- window_kv_start_idx,
- window_kv_indices,
- req_to_token.stride(0),
- )
- return window_kv_indptr, window_kv_indices, window_kv_lens
-
-
- def update_sliding_window_buffer_cuda_graph(
- window_kv_indptr,
- window_kv_indices,
- req_to_token,
- sliding_window_size,
- seq_lens,
- req_pool_indices,
- bs,
- ):
- window_kv_lens = torch.minimum(
- seq_lens,
- torch.tensor(sliding_window_size + 1),
- )
- window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
- window_kv_indptr = window_kv_indptr[: bs + 1]
- window_kv_start_idx = seq_lens - window_kv_lens
- create_flashinfer_kv_indices_triton[(bs,)](
- req_to_token,
- req_pool_indices,
- window_kv_lens,
- window_kv_indptr,
- window_kv_start_idx,
- window_kv_indices,
- req_to_token.stride(0),
- )
- return window_kv_indptr, window_kv_lens
-
-
  @dataclass
  class ForwardMetadata:
  attn_logits: torch.Tensor
@@ -165,8 +54,8 @@ class TritonAttnBackend(AttentionBackend):
 
  super().__init__()
 
- self.decode_attention_fwd = decode_attention_fwd
- self.extend_attention_fwd = extend_attention_fwd
+ self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
+ self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
 
  self.skip_prefill = skip_prefill
 
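Wrapping the Triton attention entry points in `torch.compiler.disable` keeps them out of any surrounding `torch.compile` region: calls fall back to eager execution (with a graph break) instead of being traced. A minimal sketch of the pattern with a stand-in function, not the actual kernel launcher:

```python
import torch

def decode_attention_fwd_stub(q: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real Triton kernel launcher.
    return q * 2

# The wrapped callable is excluded from torch.compile tracing.
decode_attention_fwd = torch.compiler.disable(decode_attention_fwd_stub)

@torch.compile
def forward(q: torch.Tensor) -> torch.Tensor:
    # torch.compile traces this function but breaks the graph at the disabled call.
    return decode_attention_fwd(q) + 1

print(forward(torch.ones(4)))  # tensor([3., 3., 3., 3.])
```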
@@ -877,6 +766,7 @@ class TritonMultiStepDraftBackend:
  self.device = model_runner.device
  # Cached variables for generate_draft_decode_kv_indices
  self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+ self.page_size = model_runner.server_args.page_size
 
  def common_template(
  self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
@@ -894,14 +784,13 @@ class TritonMultiStepDraftBackend:
  kv_indices_buffer,
  self.kv_indptr,
  forward_batch.positions,
- num_seqs,
- self.topk,
  self.pool_len,
  kv_indices_buffer.shape[1],
  self.kv_indptr.shape[1],
- triton.next_power_of_2(num_seqs),
- triton.next_power_of_2(self.speculative_num_steps),
- triton.next_power_of_2(bs),
+ next_power_of_2(num_seqs),
+ next_power_of_2(self.speculative_num_steps),
+ next_power_of_2(bs),
+ self.page_size,
  )
 
  for i in range(self.speculative_num_steps):
@@ -973,3 +862,114 @@ class TritonMultiStepDraftBackend:
  )
 
  self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+
+ @triton.jit
+ def get_num_kv_splits_triton(
+ num_kv_splits_ptr,
+ seq_lens_ptr,
+ num_seq,
+ num_group,
+ num_head,
+ num_kv_head,
+ max_kv_splits,
+ device_core_count,
+ MAX_NUM_SEQ: tl.constexpr,
+ ):
+ # TODO: this method is tunable, we need more online serving data to tune it
+ offs_seq = tl.arange(0, MAX_NUM_SEQ)
+ mask_seq = offs_seq < num_seq
+
+ seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
+ max_seq_len = tl.max(seq_lens)
+ seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
+ min_seq_len = tl.min(seq_lens)
+ if max_seq_len * 8 < min_seq_len * 10:
+ min_seq_len = max_seq_len
+ max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+ kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+ # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+ ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+ ext_device_core_count = tl.cast(
+ device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
+ )
+ block_h, num_kv_group = 16, num_head // num_kv_head
+ if num_kv_group == 1:
+ token_grid = num_seq * num_group * num_head
+ else:
+ # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+ block_h = tl.minimum(block_h, num_kv_group)
+ token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+ max_kv_splits_2 = tl.minimum(
+ tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+ )
+ kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+ num_kv_splits = tl.maximum(
+ tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+ )
+
+ offs_token = offs_seq * num_group
+ mask_token = offs_token < num_seq * num_group
+ for i in range(0, num_group):
+ tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+ def update_sliding_window_buffer(
+ window_kv_indptr,
+ req_to_token,
+ sliding_window_size,
+ seq_lens,
+ req_pool_indices,
+ bs,
+ device,
+ ):
+ window_kv_lens = torch.minimum(
+ seq_lens,
+ torch.tensor(sliding_window_size + 1),
+ )
+ window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+ window_kv_indptr = window_kv_indptr[: bs + 1]
+ window_kv_indices = torch.empty(
+ window_kv_indptr[-1], dtype=torch.int32, device=device
+ )
+ window_kv_start_idx = seq_lens - window_kv_lens
+ create_flashinfer_kv_indices_triton[(bs,)](
+ req_to_token,
+ req_pool_indices,
+ window_kv_lens,
+ window_kv_indptr,
+ window_kv_start_idx,
+ window_kv_indices,
+ req_to_token.stride(0),
+ )
+ return window_kv_indptr, window_kv_indices, window_kv_lens
+
+
+ def update_sliding_window_buffer_cuda_graph(
+ window_kv_indptr,
+ window_kv_indices,
+ req_to_token,
+ sliding_window_size,
+ seq_lens,
+ req_pool_indices,
+ bs,
+ ):
+ window_kv_lens = torch.minimum(
+ seq_lens,
+ torch.tensor(sliding_window_size + 1),
+ )
+ window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+ window_kv_indptr = window_kv_indptr[: bs + 1]
+ window_kv_start_idx = seq_lens - window_kv_lens
+ create_flashinfer_kv_indices_triton[(bs,)](
+ req_to_token,
+ req_pool_indices,
+ window_kv_lens,
+ window_kv_indptr,
+ window_kv_start_idx,
+ window_kv_indices,
+ req_to_token.stride(0),
+ )
+ return window_kv_indptr, window_kv_lens
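The relocated sliding-window helpers above clamp each request's KV length to `sliding_window_size + 1` and shift the window start accordingly before building the KV indices. A small toy illustration of that clamping step (the window size and lengths here are arbitrary):

```python
import torch

sliding_window_size = 4
seq_lens = torch.tensor([2, 6, 10])

# Each request attends to at most the last (window + 1) tokens.
window_kv_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
window_kv_start_idx = seq_lens - window_kv_lens

print(window_kv_lens)       # tensor([2, 5, 5])
print(window_kv_start_idx)  # tensor([0, 1, 5])
```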