sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/cutlass_mla_backend.py
@@ -11,8 +11,6 @@ from typing import TYPE_CHECKING, Optional, Union
  import torch
  import triton

- from sglang.global_config import global_config
- from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
  from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
  from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -22,7 +20,6 @@ from sglang.srt.utils import is_cuda
  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
  from sglang.srt.speculative.spec_info import SpecInfo

  _is_cuda = is_cuda()
@@ -108,7 +105,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
  PAGE_SIZE,
  )
  workspace_size = cutlass_mla_get_workspace_size(
- max_seqlen_pad * PAGE_SIZE, bs
+ max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
  )
  workspace = torch.empty(
  workspace_size, device="cuda", dtype=torch.uint8
@@ -125,6 +122,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
  def init_cuda_graph_state(
  self,
  max_bs: int,
+ max_num_tokens: int,
  block_kv_indices: Optional[torch.Tensor] = None,
  ):
  if block_kv_indices is None:
@@ -138,7 +136,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
  cuda_graph_kv_indices = block_kv_indices

  workspace_size = cutlass_mla_get_workspace_size(
- cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
+ cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs, num_kv_splits=1
  )
  self.cuda_graph_mla_workspace = torch.empty(
  workspace_size, device="cuda", dtype=torch.uint8
@@ -233,29 +231,55 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
  layer: RadixAttention,
  forward_batch: ForwardBatch,
  save_kv_cache: bool = True,
+ # For multi-head latent attention
+ q_rope: Optional[torch.Tensor] = None,
+ k_rope: Optional[torch.Tensor] = None,
  ):
  cache_loc = forward_batch.out_cache_loc

  if k is not None:
  assert v is not None
  if save_kv_cache:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer,
- cache_loc,
- k,
- v,
- )
- bs = forward_batch.batch_size
- k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+ if k_rope is not None:
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+ layer,
+ cache_loc,
+ k,
+ k_rope,
+ )
+ else:
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer,
+ cache_loc,
+ k,
+ v,
+ )
+
+ # Reshape inputs
+ if q_rope is not None:
+ q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+ q_rope = q_rope.view(
+ -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+ )
+ else:
+ reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+ q_nope = reshaped_q[:, :, : layer.v_head_dim]
+ q_rope = reshaped_q[:, :, layer.v_head_dim :]

- reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+ q_nope = q_nope.to(self.q_data_type)
+ q_rope = q_rope.to(self.q_data_type)
+
+ k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

  o = cutlass_mla_decode(
- q_nope_and_q_pe=reshape_q.to(self.q_data_type),
+ q_nope=q_nope,
+ q_pe=q_rope,
  kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
  seq_lens=forward_batch.seq_lens.to(torch.int32),
  page_table=self.forward_metadata.block_kv_indices,
  workspace=self.forward_metadata.workspace,
+ sm_scale=layer.scaling,
+ num_kv_splits=1,
  )

  return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
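Note: the hunk above changes the query layout that the CUTLASS MLA kernel receives: 0.4.7 passed a single packed q_nope_and_q_pe tensor, while 0.4.8 passes separate q_nope and q_pe tensors. The sketch below is illustrative only; the helper name and the 16-head, 512 + 64 dimension split are assumptions, not part of the diff. It mirrors the reshape block in forward_decode as a standalone function.

import torch

def split_q_for_cutlass_mla(q, q_rope, tp_q_head_num, head_dim, v_head_dim, q_dtype):
    # Hypothetical standalone version of the reshape block shown in the hunk above.
    if q_rope is not None:
        # Caller already provides the rope part as a separate tensor.
        q_nope = q.view(-1, tp_q_head_num, v_head_dim)
        q_rope = q_rope.view(-1, tp_q_head_num, head_dim - v_head_dim)
    else:
        # Single packed tensor: split it into the "nope" and rope components.
        reshaped_q = q.view(-1, tp_q_head_num, head_dim)
        q_nope = reshaped_q[:, :, :v_head_dim]
        q_rope = reshaped_q[:, :, v_head_dim:]
    return q_nope.to(q_dtype), q_rope.to(q_dtype)

# Illustrative sizes: 4 tokens, 16 heads, 576-dim MLA head split into 512 + 64.
q = torch.randn(4, 16 * 576)
q_nope, q_pe = split_q_for_cutlass_mla(
    q, None, tp_q_head_num=16, head_dim=576, v_head_dim=512, q_dtype=torch.bfloat16
)
assert q_nope.shape == (4, 16, 512) and q_pe.shape == (4, 16, 64)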
sglang/srt/layers/attention/flashattention_backend.py
@@ -11,7 +11,6 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
- from sglang.srt.utils import get_compiler_backend

  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -394,7 +393,6 @@ class FlashAttentionBackend(AttentionBackend):
  dtype=torch.int32,
  )
  metadata_expand.max_seq_len_q = 1
- metadata_expand.max_seq_len_k = self.speculative_step_id + 1
  metadata_expand.cu_seqlens_q = torch.arange(
  0,
  metadata_expand.cache_seqlens_int32.numel() + 1,
@@ -408,9 +406,10 @@ class FlashAttentionBackend(AttentionBackend):
  dtype=torch.int32,
  device=device,
  )
+ # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
  cache_loc = forward_batch.out_cache_loc.view(
- self.speculative_num_steps, -1
- ).T.contiguous()
+ -1, self.speculative_num_steps
+ )
  metadata_expand.page_table = (
  cache_loc[:, :decode_length].contiguous().to(torch.int32)
  )
@@ -550,9 +549,6 @@ class FlashAttentionBackend(AttentionBackend):
  ),
  (1, 0),
  )
- metadata_expand.max_seq_len_k = (
- metadata_expand.cache_seqlens_int32.max().item()
- )
  self.forward_metadata_spec_decode_expand = metadata_expand
  elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
  metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
@@ -1124,7 +1120,7 @@ class FlashAttentionBackend(AttentionBackend):

  return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

- def init_cuda_graph_state(self, max_bs: int):
+ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
  """Initialize CUDA graph state for the attention backend.

  Args:
@@ -1421,9 +1417,6 @@ class FlashAttentionBackend(AttentionBackend):
  ]
  )
  metadata_expand.max_seq_len_q = 1
- metadata_expand.max_seq_len_k = (
- self.speculative_step_id + 1
- ) # , do this in replay
  metadata_expand.cu_seqlens_q = (
  self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
  : bs * self.topk + 1
@@ -1469,7 +1462,7 @@ class FlashAttentionBackend(AttentionBackend):
  "cache_seqlens"
  ][:bs]
  metadata.cache_seqlens_int32.copy_(
- (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+ (seq_lens + self.speculative_num_draft_tokens)
  )

  metadata.max_seq_len_q = self.speculative_num_draft_tokens
@@ -1536,7 +1529,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
  :bs
  ]
- metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+ metadata.cache_seqlens_int32.copy_(seq_lens)

  num_tokens_per_bs = num_tokens // bs
  metadata.max_seq_len_q = num_tokens_per_bs
@@ -1600,38 +1593,32 @@ class FlashAttentionBackend(AttentionBackend):
  if spec_info is not None:
  # Draft Decode
  if self.topk <= 1:
- metadata = self.decode_cuda_graph_metadata[bs]
  # When topk = 1, we use the normal decode metadata
- metadata.cache_seqlens_int32.copy_(
- (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
- )
-
- metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
- self.speculative_step_id + 1
- )
- metadata.cu_seqlens_k[1:].copy_(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- )
- )
-
+ metadata = self.decode_cuda_graph_metadata[bs]
+ max_len = seq_lens_cpu.max().item()
+ metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
  max_seq_pages = (
  metadata.max_seq_len_k + self.page_size - 1
  ) // self.page_size
- page_indices = self.req_to_token[
- req_pool_indices[:, None],
- self.decode_cuda_graph_metadata["strided_indices"][
- :max_seq_pages
- ],
- ]

- page_indices //= self.page_size
- metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+ normal_decode_set_medadata(
+ metadata.cache_seqlens_int32,
+ metadata.cu_seqlens_k,
+ metadata.page_table,
+ self.req_to_token,
+ req_pool_indices,
+ self.decode_cuda_graph_metadata["strided_indices"],
+ max_seq_pages,
+ seq_lens,
+ self.speculative_step_id + 1,
+ self.page_size,
+ )
+
  else:
  # When top k > 1, we need two specific draft decode metadata, and then merge states
  # 1. The first half of metadata for prefix tokens
  metadata = self.draft_decode_metadata_topk_normal[bs]
- metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+ metadata.cache_seqlens_int32.copy_(seq_lens)
  # metadata.max_seq_len_q = self.topk, already set in capture
  metadata.max_seq_len_k = seq_lens_cpu.max().item()
  # metadata.cu_seqlens_q already set in capture
@@ -1650,11 +1637,10 @@ class FlashAttentionBackend(AttentionBackend):
  # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
  metadata_expand = self.draft_decode_metadata_topk_expand[bs]
  decode_length = self.speculative_step_id + 1
- cache_loc = out_cache_loc.view(
- self.speculative_num_steps, -1
- ).T.contiguous()
+ # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
+ cache_loc = out_cache_loc.view(-1, self.speculative_num_steps)
  metadata_expand.page_table[: cache_loc.shape[0]].copy_(
- cache_loc[:, :decode_length].contiguous().to(torch.int32)
+ cache_loc[:, :decode_length]
  )
  # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
  else:
@@ -1665,12 +1651,15 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.max_seq_len_k = max_len

  normal_decode_set_medadata(
- metadata,
+ metadata.cache_seqlens_int32,
+ metadata.cu_seqlens_k,
+ metadata.page_table,
  self.req_to_token,
  req_pool_indices,
  self.decode_cuda_graph_metadata["strided_indices"],
  max_seq_pages,
  seq_lens,
+ 0,
  self.page_size,
  )

@@ -1679,7 +1668,7 @@ class FlashAttentionBackend(AttentionBackend):
  if self.topk <= 1:
  metadata = self.target_verify_metadata[bs]
  metadata.cache_seqlens_int32.copy_(
- (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+ (seq_lens + self.speculative_num_draft_tokens)
  )

  metadata.max_seq_len_k = (
@@ -1701,7 +1690,7 @@ class FlashAttentionBackend(AttentionBackend):
  # When topk > 1, we need two specific target verify metadata, and then merge states
  # 1. The first half of metadata for prefix tokens
  metadata = self.target_verify_metadata_topk_normal[bs]
- metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+ metadata.cache_seqlens_int32.copy_(seq_lens)
  # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
  metadata.max_seq_len_k = seq_lens_cpu.max().item()
  # metadata.cu_seqlens_q already set in capture
@@ -1715,14 +1704,15 @@ class FlashAttentionBackend(AttentionBackend):

  # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
  metadata_expand = self.target_verify_metadata_topk_expand[bs]
+
  # metadata_expand.max_seq_len_q = 1, already set in capture
  # metadata_expand.cu_seqlens_q already set in capture
-
  offsets = torch.arange(
  self.speculative_num_draft_tokens, device=device
  ).unsqueeze(
  0
  ) # shape: (1, self.speculative_num_draft_tokens)
+
  cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
  cum_len = torch.nn.functional.pad(
  torch.cumsum(
@@ -1739,17 +1729,20 @@ class FlashAttentionBackend(AttentionBackend):
  ).view(1, -1)
  # avoid extracting padded seq indices which will be out of boundary
  mask_extraction_indices[
- :, spec_info.positions.numel() * self.speculative_num_draft_tokens :
+ :,
+ spec_info.positions.numel() * self.speculative_num_draft_tokens :,
  ].fill_(0)
-
  mask = spec_info.custom_mask[mask_extraction_indices].view(
  -1, self.speculative_num_draft_tokens
  ) # (bsz * draft_num, draft_num)
+
  col_indices = offsets.expand(
  mask.shape[0], self.speculative_num_draft_tokens
  )
  keys = torch.where(
- mask, col_indices, col_indices + self.speculative_num_draft_tokens
+ mask,
+ col_indices,
+ col_indices + self.speculative_num_draft_tokens,
  )
  _, sort_order = torch.sort(keys, dim=1)

@@ -1758,12 +1751,11 @@ class FlashAttentionBackend(AttentionBackend):
  .gather(1, cols)
  .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
  ) # (bsz, draft_num)
+
  metadata_expand.page_table.copy_(
  non_masked_page_table.gather(1, sort_order)
  )
- metadata_expand.cache_seqlens_int32.copy_(
- mask.sum(dim=1).to(torch.int32)
- )
+ metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))
  metadata_expand.cu_seqlens_k[1:].copy_(
  torch.cumsum(
  metadata_expand.cache_seqlens_int32,
@@ -1771,19 +1763,21 @@ class FlashAttentionBackend(AttentionBackend):
  dtype=torch.int32,
  )
  )
- metadata_expand.max_seq_len_k = (
- metadata_expand.cache_seqlens_int32.max().item()
- )
+
  elif forward_mode.is_draft_extend():
  metadata = self.draft_extend_metadata[bs]
- metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+ metadata.cache_seqlens_int32.copy_(seq_lens)

  metadata.max_seq_len_k = seq_lens_cpu.max().item()
  metadata.cu_seqlens_k[1:].copy_(
  torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
  )
  accept_length = spec_info.accept_length[:bs]
- metadata.max_seq_len_q = accept_length.max().item()
+ if spec_info.accept_length_cpu:
+ metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+ else:
+ metadata.max_seq_len_q = 1
+
  metadata.cu_seqlens_q[1:].copy_(
  torch.cumsum(accept_length, dim=0, dtype=torch.int32)
  )
@@ -1795,8 +1789,7 @@ class FlashAttentionBackend(AttentionBackend):
  req_pool_indices[:, None],
  self.draft_extend_metadata["strided_indices"][:max_seq_pages],
  ]
- page_indices //= self.page_size
- metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+ metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)

  if encoder_lens is not None:
  # Only support encoder size 1 for now
@@ -1824,7 +1817,7 @@ class FlashAttentionBackend(AttentionBackend):

  def get_cuda_graph_seq_len_fill_value(self):
  """Get the fill value for sequence length in CUDA graph."""
- return 0
+ return 1

  def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
  """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
@@ -2016,9 +2009,9 @@ class FlashAttentionMultiStepBackend:
  for i in range(self.speculative_num_steps - 1):
  self.attn_backends[i].init_forward_metadata(forward_batch)

- def init_cuda_graph_state(self, max_bs: int):
+ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
  for i in range(self.speculative_num_steps):
- self.attn_backends[i].init_cuda_graph_state(max_bs)
+ self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

  def init_forward_metadata_capture_cuda_graph(
  self,
@@ -2045,6 +2038,8 @@ class FlashAttentionMultiStepBackend:
  assert isinstance(forward_batch.spec_info, EagleDraftInput)

  for i in range(self.speculative_num_steps - 1):
+ # TODO: incrementally update the metadata for the later steps,
+ # so that they do not need to recompute everything from scratch.
  self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
  bs,
  forward_batch.req_pool_indices,
@@ -2058,21 +2053,25 @@ class FlashAttentionMultiStepBackend:
  )


- @torch.compile(dynamic=True, backend=get_compiler_backend())
+ # @torch.compile(dynamic=True, backend=get_compiler_backend())
+ # TODO: fuse these kernels
+ # NOTE: torch.compile makes it slower in speculative decoding
  def normal_decode_set_medadata(
- metadata,
- req_to_token,
- req_pool_indices,
- strided_indices,
- max_seq_pages,
- seq_lens,
- page_size,
+ cache_seqlens_int32: torch.Tensor,
+ cu_seqlens_k: torch.Tensor,
+ page_table: torch.Tensor,
+ req_to_token: torch.Tensor,
+ req_pool_indices: torch.Tensor,
+ strided_indices: torch.Tensor,
+ max_seq_pages: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_len_delta: int,
+ page_size: int,
  ):
- metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
- metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
+ cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
+ cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
  page_indices = req_to_token[
  req_pool_indices[:, None],
  strided_indices[:max_seq_pages][None, :],
  ]
- metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
- metadata.page_table[:, max_seq_pages:].fill_(0)
+ page_table[:, :max_seq_pages].copy_(page_indices // page_size)
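Note: the refactored normal_decode_set_medadata above now writes into caller-provided tensors and takes an explicit seq_len_delta (the draft-decode path passes speculative_step_id + 1, the normal decode path passes 0). The CPU-only sketch below copies the new helper body and calls it with toy tensors; all shapes, values, and variable names in the example are made up for illustration.

import torch

def normal_decode_set_medadata(
    cache_seqlens_int32, cu_seqlens_k, page_table,
    req_to_token, req_pool_indices, strided_indices,
    max_seq_pages, seq_lens, seq_len_delta, page_size,
):
    # Body copied from the 0.4.8 helper shown above.
    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
    cu_seqlens_k[1:].copy_(
        torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32)
    )
    page_indices = req_to_token[
        req_pool_indices[:, None],
        strided_indices[:max_seq_pages][None, :],
    ]
    page_table[:, :max_seq_pages].copy_(page_indices // page_size)

# Toy inputs (illustrative): 2 running requests, page_size 1, pool of 4 x 8 slots.
bs, page_size, max_pages = 2, 1, 8
req_to_token = torch.arange(4 * 8, dtype=torch.int32).view(4, 8)
req_pool_indices = torch.tensor([1, 3])
strided_indices = torch.arange(0, 8, page_size)
seq_lens = torch.tensor([3, 5], dtype=torch.int32)

cache_seqlens = torch.zeros(bs, dtype=torch.int32)
cu_seqlens_k = torch.zeros(bs + 1, dtype=torch.int32)
page_table = torch.zeros(bs, max_pages, dtype=torch.int32)

delta = 1  # e.g. speculative_step_id + 1 in the draft-decode path
max_seq_pages = (int(seq_lens.max()) + delta + page_size - 1) // page_size
normal_decode_set_medadata(
    cache_seqlens, cu_seqlens_k, page_table,
    req_to_token, req_pool_indices, strided_indices,
    max_seq_pages, seq_lens, delta, page_size,
)
assert cache_seqlens.tolist() == [4, 6] and cu_seqlens_k.tolist() == [0, 4, 10]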
sglang/srt/layers/attention/flashinfer_backend.py
@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
  )

  def init_cuda_graph_state(
- self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+ self,
+ max_bs: int,
+ max_num_tokens: int,
+ kv_indices_buf: Optional[torch.Tensor] = None,
  ):
  if kv_indices_buf is None:
  cuda_graph_kv_indices = torch.zeros(
- (max_bs * self.max_context_len,),
+ (max_num_tokens * self.max_context_len,),
  dtype=torch.int32,
  device="cuda",
  )
@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):

  if not self.skip_prefill:
  self.cuda_graph_custom_mask = torch.zeros(
- (max_bs * self.max_context_len),
+ (max_num_tokens * self.max_context_len),
  dtype=torch.uint8,
  device="cuda",
  )
@@ -440,7 +443,7 @@ class FlashInferAttnBackend(AttentionBackend):
  raise ValueError("Invalid forward mode")

  def get_cuda_graph_seq_len_fill_value(self):
- return 0
+ return 1

  def forward_extend(
  self,
@@ -1049,14 +1052,13 @@ class FlashInferMultiStepDraftBackend:
  kv_indices_buffer,
  self.kv_indptr,
  forward_batch.positions,
- num_seqs,
- self.topk,
  self.pool_len,
  kv_indices_buffer.shape[1],
  self.kv_indptr.shape[1],
  next_power_of_2(num_seqs),
  next_power_of_2(self.speculative_num_steps),
  next_power_of_2(bs),
+ self.page_size,
  )

  assert forward_batch.spec_info is not None
@@ -1097,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:

  self.common_template(forward_batch, kv_indices, call_fn)

- def init_cuda_graph_state(self, max_bs: int):
+ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
  self.cuda_graph_kv_indices = torch.zeros(
  (self.speculative_num_steps, max_bs * self.max_context_len),
  dtype=torch.int32,
@@ -1106,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:

  for i in range(self.speculative_num_steps):
  self.attn_backends[i].init_cuda_graph_state(
- max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+ max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
  )

  def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
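Note: across the attention backends in this diff, init_cuda_graph_state gains a max_num_tokens argument, and the FlashInfer buffers above are now sized by tokens rather than by batch size. The GPU-free sketch below only demonstrates the new calling convention; DummyAttnBackend is a hypothetical class, not part of sglang.

import torch
from typing import Optional

class DummyAttnBackend:
    # Hypothetical backend that mirrors the 0.4.8 signature:
    # init_cuda_graph_state(max_bs, max_num_tokens, ...).
    def __init__(self, max_context_len: int):
        self.max_context_len = max_context_len

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        # Size the indices buffer by tokens, as the FlashInfer backend now does
        # (CPU tensor here so the sketch runs without CUDA).
        if kv_indices_buf is None:
            kv_indices_buf = torch.zeros(
                max_num_tokens * self.max_context_len, dtype=torch.int32
            )
        self.cuda_graph_kv_indices = kv_indices_buf

backend = DummyAttnBackend(max_context_len=16)
backend.init_cuda_graph_state(max_bs=2, max_num_tokens=8)
assert backend.cuda_graph_kv_indices.numel() == 8 * 16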
sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -15,7 +15,6 @@ from functools import partial
  from typing import TYPE_CHECKING, Callable, Optional, Union

  import torch
- import triton

  if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
  import logging
@@ -33,7 +32,7 @@ from sglang.srt.layers.utils import is_sm100_supported
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
- from sglang.srt.utils import is_flashinfer_available
+ from sglang.srt.utils import is_flashinfer_available, next_power_of_2

  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -200,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  )

  def init_cuda_graph_state(
- self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+ self,
+ max_bs: int,
+ max_num_tokens: int,
+ kv_indices_buf: Optional[torch.Tensor] = None,
  ):
  if kv_indices_buf is None:
  cuda_graph_kv_indices = torch.zeros(
@@ -365,7 +367,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  raise ValueError(f"Invalid forward mode: {forward_mode=}")

  def get_cuda_graph_seq_len_fill_value(self):
- return 0
+ return 1

  def forward_extend(
  self,
@@ -756,7 +758,7 @@ class FlashInferMLAMultiStepDraftBackend:

  if topk > 1:
  raise ValueError(
- f"Currently Flashinfer MLA only supports topk=1 for speculative decoding"
+ "Currently Flashinfer MLA only supports topk=1 for speculative decoding"
  )
  self.topk = topk
  self.speculative_num_steps = speculative_num_steps
@@ -790,6 +792,7 @@ class FlashInferMLAMultiStepDraftBackend:

  # Cached variables for generate_draft_decode_kv_indices
  self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+ self.page_size = model_runner.server_args.page_size

  def common_template(
  self,
@@ -810,14 +813,13 @@ class FlashInferMLAMultiStepDraftBackend:
  kv_indices_buffer,
  self.kv_indptr,
  forward_batch.positions,
- num_seqs,
- self.topk,
  self.pool_len,
  kv_indices_buffer.shape[1],
  self.kv_indptr.shape[1],
- triton.next_power_of_2(num_seqs),
- triton.next_power_of_2(self.speculative_num_steps),
- triton.next_power_of_2(bs),
+ next_power_of_2(num_seqs),
+ next_power_of_2(self.speculative_num_steps),
+ next_power_of_2(bs),
+ self.page_size,
  )

  assert forward_batch.spec_info is not None
@@ -853,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend:

  self.common_template(forward_batch, kv_indices, call_fn)

- def init_cuda_graph_state(self, max_bs: int):
+ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
  self.cuda_graph_kv_indices = torch.zeros(
  (self.speculative_num_steps, max_bs * self.max_context_len),
  dtype=torch.int32,
@@ -862,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend:

  for i in range(self.speculative_num_steps):
  self.attn_backends[i].init_cuda_graph_state(
- max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+ max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
  )

  def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -920,19 +922,18 @@ def fast_mla_decode_plan(
  self._page_size = page_size
  self._sm_scale = sm_scale

- with self.device as device:
- try:
- # Standard version with just the required arguments (no use_profiler)
- self._cached_module.plan.default(
- self._float_workspace_buffer,
- self._int_workspace_buffer,
- self._pin_memory_int_workspace_buffer,
- qo_indptr_cpu,
- kv_indptr_cpu,
- kv_len_arr_cpu,
- num_heads,
- head_dim_ckv,
- causal,
- )
- except Exception as e:
- raise RuntimeError(f"Error in alternate MLA plan: {e}")
+ try:
+ # Standard version with just the required arguments (no use_profiler)
+ self._cached_module.plan.default(
+ self._float_workspace_buffer,
+ self._int_workspace_buffer,
+ self._pin_memory_int_workspace_buffer,
+ qo_indptr_cpu,
+ kv_indptr_cpu,
+ kv_len_arr_cpu,
+ num_heads,
+ head_dim_ckv,
+ causal,
+ )
+ except Exception as e:
+ raise RuntimeError(f"Error in alternate MLA plan: {e}")
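Note: the MLA draft backend above also switches from triton.next_power_of_2 to next_power_of_2 imported from sglang.srt.utils, which lets the module-level triton import go away. Assuming both helpers return the smallest power of two greater than or equal to a positive integer, a pure-Python equivalent looks like this (sketch only, not the sglang implementation):

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n for positive n (assumed behavior of both helpers).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

assert [next_power_of_2(n) for n in (1, 2, 3, 5, 8, 9)] == [1, 2, 4, 8, 8, 16]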