sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
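
Only a portion of the full diff is reproduced in the hunks below. The release itself is the version bump in sglang/version.py plus the file-level changes listed above. A minimal post-upgrade sanity check, assuming the package still re-exports __version__ from sglang/version.py at the top level (as in earlier releases):

    import sglang

    # The wheel name and sglang/version.py both point at 0.4.4.post1.
    assert sglang.__version__ == "0.4.4.post1", sglang.__version__
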
@@ -11,9 +11,10 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html
 
  from dataclasses import dataclass
  from functools import partial
- from typing import TYPE_CHECKING, Optional, Union
+ from typing import TYPE_CHECKING, Callable, Optional, Union
 
  import torch
+ import triton
 
  from sglang.global_config import global_config
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -23,6 +24,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
  from sglang.srt.layers.dp_attention import get_attention_tp_size
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
  from sglang.srt.utils import is_flashinfer_available
 
  if TYPE_CHECKING:
@@ -58,12 +60,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  def __init__(
  self,
  model_runner: ModelRunner,
+ skip_prefill: bool = False,
+ kv_indptr_buf: Optional[torch.Tensor] = None,
+ q_indptr_decode_buf: Optional[torch.Tensor] = None,
  ):
  super().__init__()
 
  # Parse constants
  self.max_context_len = model_runner.model_config.context_len
  self.device = model_runner.device
+ self.skip_prefill = skip_prefill
 
  global_config.enable_flashinfer_mla = True
 
@@ -78,35 +84,51 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  self.workspace_buffer = global_workspace_buffer
 
  max_bs = model_runner.req_to_token_pool.size
- self.kv_indptr = torch.zeros(
- (max_bs + 1,), dtype=torch.int32, device=model_runner.device
- )
+ if kv_indptr_buf is None:
+ self.kv_indptr = torch.zeros(
+ (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+ )
+ else:
+ self.kv_indptr = kv_indptr_buf
 
- self.qo_indptr = torch.zeros(
- (max_bs + 1,), dtype=torch.int32, device=model_runner.device
- )
+ if not self.skip_prefill:
+ self.qo_indptr = torch.zeros(
+ (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+ )
 
- self.q_indptr_decode = torch.arange(
- 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
- )
+ if q_indptr_decode_buf is None:
+ self.q_indptr_decode = torch.arange(
+ 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
+ )
+ else:
+ self.q_indptr_decode = q_indptr_decode_buf
 
  self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
  self.workspace_buffer, "NHD"
  )
 
- self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
- self.workspace_buffer,
- backend="auto",
- )
+ if not self.skip_prefill:
+ self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
+ self.workspace_buffer,
+ backend="auto",
+ )
+
+ # FlashinferMLA backend uses mla wrapper for target verify
+ self.prefill_wrapper_verify = BatchMLAPagedAttentionWrapper(
+ self.workspace_buffer,
+ backend="auto",
+ )
 
  self.decode_wrapper = BatchMLAPagedAttentionWrapper(
  self.workspace_buffer, backend="auto"
  )
 
  # Create indices updater
- self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
- model_runner, self
- )
+ if not skip_prefill:
+ self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
+ model_runner, self
+ )
+
  self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
  model_runner, self
  )
@@ -114,7 +136,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  # Other metadata
  self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
  self.decode_cuda_graph_metadata = {}
- self.prefill_cuda_graph_metadata = {}
+ self.prefill_cuda_graph_metadata = {} # For verify
 
  def init_forward_metadata(self, forward_batch: ForwardBatch):
  if forward_batch.forward_mode.is_decode_or_idle():
@@ -126,6 +148,28 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  init_metadata_replay=False,
  )
  self.forward_metadata = DecodeMetadata(self.decode_wrapper)
+ elif forward_batch.forward_mode.is_draft_extend():
+ self.indices_updater_prefill.update(
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
+ prefix_lens=None,
+ prefill_wrapper_paged=self.prefill_wrapper_paged,
+ use_ragged=False,
+ spec_info=forward_batch.spec_info,
+ )
+ self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, False)
+ elif forward_batch.forward_mode.is_target_verify():
+ self.indices_updater_prefill.update(
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
+ prefix_lens=None,
+ prefill_wrapper_paged=self.prefill_wrapper_verify,
+ use_ragged=False,
+ spec_info=forward_batch.spec_info,
+ )
+ self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False)
  else:
  prefix_lens = forward_batch.extend_prefix_lens
  extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
@@ -202,10 +246,33 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  seq_lens_sum,
  decode_wrapper=decode_wrapper,
  init_metadata_replay=False,
+ spec_info=spec_info,
  )
  self.decode_cuda_graph_metadata[bs] = decode_wrapper
  self.forward_metadata = DecodeMetadata(decode_wrapper)
  decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
+ elif forward_mode.is_target_verify():
+ verify_wrapper = BatchMLAPagedAttentionWrapper(
+ self.workspace_buffer,
+ use_cuda_graph=True,
+ qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
+ kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
+ kv_indices=self.cuda_graph_kv_indices,
+ kv_len_arr=self.cuda_graph_kv_lens[:bs],
+ backend="auto",
+ )
+ seq_lens_sum = seq_lens.sum().item()
+ self.indices_updater_prefill.update(
+ req_pool_indices,
+ seq_lens,
+ seq_lens_sum,
+ prefix_lens=None,
+ prefill_wrapper_paged=verify_wrapper,
+ use_ragged=False,
+ spec_info=spec_info,
+ )
+ self.prefill_cuda_graph_metadata[bs] = verify_wrapper
+ self.forward_metadata = PrefillMetadata(verify_wrapper, False)
  else:
  raise ValueError(f"Invalid mode: {forward_mode=}")
 
@@ -221,6 +288,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  seq_lens_cpu: Optional[torch.Tensor],
  ):
  if forward_mode.is_decode_or_idle():
+ assert seq_lens_cpu is not None
  kv_len_arr_cpu = seq_lens_cpu[:bs]
  self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
  kv_len_arr_cpu, dim=0
@@ -239,8 +307,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  seq_lens_sum,
  decode_wrapper=self.decode_cuda_graph_metadata[bs],
  init_metadata_replay=True,
+ spec_info=spec_info,
  **self.fast_decode_kwargs,
  )
+ elif forward_mode.is_target_verify():
+ self.indices_updater_prefill.update(
+ req_pool_indices[:bs],
+ seq_lens[:bs],
+ seq_lens_sum,
+ prefix_lens=None,
+ prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
+ use_ragged=False,
+ spec_info=spec_info,
+ )
  else:
  raise ValueError(f"Invalid forward mode: {forward_mode=}")
 
@@ -254,7 +333,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  v: torch.Tensor,
  layer: RadixAttention,
  forward_batch: ForwardBatch,
- save_kv_cache=True,
+ save_kv_cache: bool = True,
  ):
 
  cache_loc = forward_batch.out_cache_loc
@@ -297,7 +376,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  v: torch.Tensor,
  layer: RadixAttention,
  forward_batch: ForwardBatch,
- save_kv_cache=True,
+ save_kv_cache: bool = True,
  ):
  decode_wrapper = self.forward_metadata.decode_wrapper
  cache_loc = forward_batch.out_cache_loc
@@ -349,6 +428,7 @@ class FlashInferMLAIndicesUpdaterDecode:
  seq_lens_sum: int,
  decode_wrapper: BatchMLAPagedAttentionWrapper,
  init_metadata_replay: bool = False,
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
  **fast_decode_kwargs,
  ):
  decode_wrapper = decode_wrapper or self.decode_wrapper
@@ -360,6 +440,7 @@ class FlashInferMLAIndicesUpdaterDecode:
  self.q_indptr,
  self.kv_indptr,
  init_metadata_replay,
+ spec_info,
  **fast_decode_kwargs,
  )
 
@@ -372,30 +453,33 @@ class FlashInferMLAIndicesUpdaterDecode:
  q_indptr: torch.Tensor,
  kv_indptr: torch.Tensor,
  init_metadata_replay: bool = False,
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
  **fast_decode_kwargs,
  ):
  bs = len(req_pool_indices)
  q_indptr = q_indptr[: bs + 1]
- kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
- kv_indptr = kv_indptr[: bs + 1]
- kv_indices = (
- torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
- if not init_metadata_replay
- else fast_decode_kwargs["kv_indices"]
- )
-
  kv_lens = paged_kernel_lens.to(torch.int32)
  sm_scale = self.scaling
+ if spec_info is None:
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+ kv_indptr = kv_indptr[: bs + 1]
+ kv_indices = (
+ torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
+ if not init_metadata_replay
+ else fast_decode_kwargs["kv_indices"]
+ )
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices,
+ paged_kernel_lens,
+ kv_indptr,
+ None,
+ kv_indices,
+ self.req_to_token.shape[1],
+ )
+ else:
+ kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
- create_flashinfer_kv_indices_triton[(bs,)](
- self.req_to_token,
- req_pool_indices,
- paged_kernel_lens,
- kv_indptr,
- None,
- kv_indices,
- self.req_to_token.shape[1],
- )
  if not init_metadata_replay:
  wrapper.plan(
  q_indptr,
@@ -457,6 +541,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
  prefix_lens: torch.Tensor,
  prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
  use_ragged: bool,
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
  ):
  if use_ragged:
  paged_kernel_lens = prefix_lens
@@ -476,6 +561,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
  self.kv_indptr,
  self.qo_indptr,
  use_ragged,
+ spec_info,
  )
 
  def call_begin_forward(
@@ -490,29 +576,46 @@ class FlashInferMLAIndicesUpdaterPrefill:
  kv_indptr: torch.Tensor,
  qo_indptr: torch.Tensor,
  use_ragged: bool,
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
  ):
- bs = len(req_pool_indices)
- kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
- kv_indptr = kv_indptr[: bs + 1]
- kv_indices = torch.empty(
- paged_kernel_lens_sum,
- dtype=torch.int32,
- device=req_pool_indices.device,
- )
- create_flashinfer_kv_indices_triton[(bs,)](
- self.req_to_token,
- req_pool_indices,
- paged_kernel_lens,
- kv_indptr,
- None,
- kv_indices,
- self.req_to_token.shape[1],
- )
-
- qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
- qo_indptr = qo_indptr[: bs + 1]
+ bs = len(seq_lens)
  sm_scale = self.scaling
 
+ if spec_info is None:
+ assert len(seq_lens) == len(req_pool_indices)
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+ kv_indptr = kv_indptr[: bs + 1]
+ kv_indices = torch.empty(
+ paged_kernel_lens_sum,
+ dtype=torch.int32,
+ device=req_pool_indices.device,
+ )
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices,
+ paged_kernel_lens,
+ kv_indptr,
+ None,
+ kv_indices,
+ self.req_to_token.shape[1],
+ )
+ qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+ qo_indptr = qo_indptr[: bs + 1]
+ custom_mask = None
+ else:
+ assert isinstance(spec_info, EagleDraftInput) or isinstance(
+ spec_info, EagleVerifyInput
+ )
+ # TODO: Support topk > 1 with custom mask
+ kv_indices, kv_indptr, qo_indptr, custom_mask = (
+ spec_info.generate_attn_arg_prefill(
+ req_pool_indices,
+ paged_kernel_lens,
+ paged_kernel_lens_sum,
+ self.req_to_token,
+ )
+ )
+
  if use_ragged:
  # ragged prefill
  wrapper_ragged.begin_forward(
@@ -543,6 +646,163 @@ class FlashInferMLAIndicesUpdaterPrefill:
  )
 
 
+ class FlashInferMLAMultiStepDraftBackend:
+ """
+ Wrap multiple flashinfer mla attention backends as one for multiple consecutive
+ draft decoding steps.
+ """
+
+ def __init__(
+ self,
+ model_runner: ModelRunner,
+ topk: int,
+ speculative_num_steps: int,
+ ):
+ from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+ if topk > 1:
+ raise ValueError(
+ f"Currently Flashinfer MLA only supports topk=1 for speculative decoding"
+ )
+ self.topk = topk
+ self.speculative_num_steps = speculative_num_steps
+ self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+
+ max_bs = model_runner.req_to_token_pool.size * self.topk
+ self.kv_indptr = torch.zeros(
+ (
+ self.speculative_num_steps,
+ max_bs + 1,
+ ),
+ dtype=torch.int32,
+ device=model_runner.device,
+ )
+ self.q_indptr_decode = torch.arange(
+ 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
+ )
+
+ self.attn_backends = []
+ for i in range(self.speculative_num_steps):
+ self.attn_backends.append(
+ FlashInferMLAAttnBackend(
+ model_runner,
+ skip_prefill=True,
+ kv_indptr_buf=self.kv_indptr[i],
+ q_indptr_decode_buf=self.q_indptr_decode,
+ )
+ )
+
+ self.max_context_len = self.attn_backends[0].max_context_len
+
+ # Cached variables for generate_draft_decode_kv_indices
+ self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+
+ def common_template(
+ self,
+ forward_batch: ForwardBatch,
+ kv_indices_buffer: torch.Tensor,
+ call_fn: Callable,
+ ):
+ num_seqs = forward_batch.batch_size
+ bs = self.topk * num_seqs
+ seq_lens_sum = forward_batch.seq_lens_sum
+
+ self.generate_draft_decode_kv_indices[
+ (self.speculative_num_steps, num_seqs, self.topk)
+ ](
+ forward_batch.req_pool_indices,
+ forward_batch.req_to_token_pool.req_to_token,
+ forward_batch.seq_lens,
+ kv_indices_buffer,
+ self.kv_indptr,
+ forward_batch.positions,
+ num_seqs,
+ self.topk,
+ self.pool_len,
+ kv_indices_buffer.shape[1],
+ self.kv_indptr.shape[1],
+ triton.next_power_of_2(num_seqs),
+ triton.next_power_of_2(self.speculative_num_steps),
+ triton.next_power_of_2(bs),
+ )
+
+ assert forward_batch.spec_info is not None
+ assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+ for i in range(self.speculative_num_steps - 1):
+ forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+ forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+ : seq_lens_sum * self.topk + bs * (i + 1)
+ ]
+ call_fn(i, forward_batch)
+
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
+ kv_indices = torch.zeros(
+ (
+ self.speculative_num_steps,
+ forward_batch.batch_size * self.topk * self.max_context_len,
+ ),
+ dtype=torch.int32,
+ device="cuda",
+ )
+
+ def call_fn(i, forward_batch):
+ assert forward_batch.spec_info is not None
+ assert isinstance(forward_batch.spec_info, EagleDraftInput)
+ forward_batch.spec_info.kv_indptr = (
+ forward_batch.spec_info.kv_indptr.clone()
+ )
+ forward_batch.spec_info.kv_indices = (
+ forward_batch.spec_info.kv_indices.clone()
+ )
+ self.attn_backends[i].init_forward_metadata(forward_batch)
+
+ self.common_template(forward_batch, kv_indices, call_fn)
+
+ def init_cuda_graph_state(self, max_bs: int):
+ self.cuda_graph_kv_indices = torch.zeros(
+ (self.speculative_num_steps, max_bs * self.max_context_len),
+ dtype=torch.int32,
+ device="cuda",
+ )
+
+ for i in range(self.speculative_num_steps):
+ self.attn_backends[i].init_cuda_graph_state(
+ max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+ )
+
+ def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+ def call_fn(i, forward_batch):
+ self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+ forward_batch.batch_size,
+ forward_batch.batch_size * self.topk,
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ encoder_lens=None,
+ forward_mode=ForwardMode.DECODE,
+ spec_info=forward_batch.spec_info,
+ )
+
+ self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+ def init_forward_metadata_replay_cuda_graph(
+ self, forward_batch: ForwardBatch, bs: int
+ ):
+ def call_fn(i, forward_batch):
+ self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+ bs,
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ seq_lens_sum=-1,
+ encoder_lens=None,
+ forward_mode=ForwardMode.DECODE,
+ spec_info=forward_batch.spec_info,
+ seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
+ )
+
+ self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+
  def fast_mla_decode_plan(
  self,
  qo_indptr_cpu: torch.Tensor,
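
The hunks above add speculative-decoding support to the FlashInfer MLA backend: the constructor now accepts skip_prefill plus externally owned kv_indptr_buf/q_indptr_decode_buf buffers, and the new FlashInferMLAMultiStepDraftBackend builds one decode-only backend per draft step on top of shared buffers. A rough construction sketch, assuming these hunks belong to sglang/srt/layers/attention/flashinfer_mla_backend.py (file 16 above) and that model_runner is an already-initialized ModelRunner (illustrative, not taken from the diff):

    from sglang.srt.layers.attention.flashinfer_mla_backend import (
        FlashInferMLAMultiStepDraftBackend,
    )

    # One FlashInferMLAAttnBackend per draft step, each built with
    # skip_prefill=True and a slice of the shared kv_indptr buffer.
    draft_backend = FlashInferMLAMultiStepDraftBackend(
        model_runner,             # assumed: an initialized ModelRunner
        topk=1,                   # the diff rejects topk > 1 with a ValueError
        speculative_num_steps=3,  # illustrative number of draft steps
    )

    # Preallocate per-step kv_indices buffers before CUDA graph capture.
    draft_backend.init_cuda_graph_state(max_bs=model_runner.req_to_token_pool.size)
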
@@ -6,9 +6,7 @@ import torch
  import triton
 
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
- from sglang.srt.layers.attention.flashinfer_backend import (
- create_flashinfer_kv_indices_triton,
- )
+ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
  from sglang.srt.layers.dp_attention import get_attention_tp_size
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
@@ -27,7 +27,7 @@ import triton.language as tl
 
  from sglang.srt.utils import is_hip
 
- is_hip_ = is_hip()
+ _is_hip = is_hip()
 
  logger = logging.getLogger(__name__)
 
@@ -180,7 +180,7 @@ def _decode_att_m_fwd(
  ):
  BLOCK = 64
  # [TODO] work around SGPR limit on MI3xx
- if is_hip_:
+ if _is_hip:
  BLOCK = 8
  NUM_KV_SPLITS = num_kv_splits
  Lk = k_buffer.shape[-1]
@@ -195,7 +195,7 @@ def _decode_att_m_fwd(
  num_warps = 4
  else:
  num_warps = 2
- if is_hip_:
+ if _is_hip:
  num_warps = 1
 
  BLOCK_DMODEL = triton.next_power_of_2(Lk)
@@ -406,7 +406,7 @@ def _decode_grouped_att_m_fwd(
  Lv = v_buffer.shape[-1]
 
  # [TODO] work around shmem limit on MI3xx
- if is_hip_ and Lk >= 576:
+ if _is_hip and Lk >= 576:
  BLOCK = 16
 
  if Lk == 576:
@@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd(
 
  extra_kargs = {}
  num_stages = 2
- if is_hip_:
+ if _is_hip:
  # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
  # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
  extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
@@ -546,7 +546,7 @@ def _decode_softmax_reducev_fwd(
  NUM_KV_SPLITS = num_kv_splits
 
  extra_kargs = {}
- if is_hip_:
+ if _is_hip:
  # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
  # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
  extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
@@ -9,7 +9,7 @@ is_cuda_available = torch.cuda.is_available()
  if is_cuda_available:
  CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
- is_hip_ = is_hip()
+ _is_hip = is_hip()
 
  if global_server_args_dict.get("attention_reduce_in_fp32", False):
  REDUCE_TRITON_TYPE = tl.float32
@@ -1032,7 +1032,7 @@ def extend_attention_fwd(
  BLOCK_DPE = 0
  BLOCK_DV = triton.next_power_of_2(Lv)
 
- if is_hip_:
+ if _is_hip:
  BLOCK_M, BLOCK_N = (64, 64)
  num_warps = 4
 
@@ -1062,7 +1062,7 @@ def extend_attention_fwd(
  num_stages = 1
 
  extra_kargs = {}
- if is_hip_:
+ if _is_hip:
  extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
 
  _fwd_kernel[grid](
@@ -29,7 +29,7 @@ is_cuda_available = torch.cuda.is_available()
  if is_cuda_available:
  CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
- is_hip_ = is_hip()
+ _is_hip = is_hip()
 
 
  @triton.jit
@@ -330,7 +330,7 @@ def extend_attention_fwd(
  BLOCK_DPE = 0
  BLOCK_DV = triton.next_power_of_2(Lv)
 
- if is_hip_:
+ if _is_hip:
  BLOCK_M, BLOCK_N = (64, 64)
  num_warps = 4
 
@@ -364,7 +364,7 @@ def extend_attention_fwd(
  num_stages = 1
 
  extra_kargs = {}
- if is_hip_:
+ if _is_hip:
  extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
 
  _fwd_kernel[grid](
@@ -403,7 +403,7 @@ def extend_attention_fwd(
  Lv=Lv,
  USE_CUSTOM_MASK=USE_CUSTOM_MASK,
  SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
- STORE_TRANSPOSE=is_hip_,
+ STORE_TRANSPOSE=_is_hip,
  num_warps=num_warps,
  num_stages=num_stages,
  **extra_kargs,
@@ -32,7 +32,7 @@ def is_hip():
  return triton.runtime.driver.active.get_current_target().backend == "hip"
 
 
- is_hip_ = is_hip()
+ _is_hip = is_hip()
 
 
  @triton.jit
@@ -333,7 +333,7 @@ def _decode_grouped_att_m_fwd_rope(
  BLOCK = 32
 
  # # [TODO] work around shmem limit on MI3xx
- # if is_hip_ and kv_lora_rank >= 576:
+ # if _is_hip and kv_lora_rank >= 576:
  # BLOCK = 16
 
  qk_rope_head_dim = k_buffer.shape[-1] - kv_lora_rank
@@ -353,7 +353,7 @@ def _decode_grouped_att_m_fwd_rope(
 
  extra_kargs = {}
  num_stages = 2
- if is_hip_:
+ if _is_hip:
  # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
  # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
  extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}