sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +49 -7
  2. sglang/lang/chat_template.py +24 -0
  3. sglang/srt/_custom_ops.py +59 -92
  4. sglang/srt/configs/model_config.py +5 -0
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/conversation.py +29 -4
  7. sglang/srt/custom_op.py +5 -0
  8. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  9. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/layers/attention/flashattention_backend.py +678 -83
  12. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  14. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  16. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  17. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  18. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
  30. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  31. sglang/srt/layers/moe/topk.py +49 -3
  32. sglang/srt/layers/quantization/__init__.py +5 -1
  33. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  35. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  36. sglang/srt/layers/quantization/fp8.py +3 -1
  37. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  38. sglang/srt/layers/quantization/moe_wna16.py +503 -0
  39. sglang/srt/layers/quantization/utils.py +1 -1
  40. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  41. sglang/srt/layers/radix_attention.py +2 -0
  42. sglang/srt/layers/rotary_embedding.py +63 -12
  43. sglang/srt/managers/cache_controller.py +34 -11
  44. sglang/srt/managers/mm_utils.py +202 -156
  45. sglang/srt/managers/multimodal_processor.py +0 -2
  46. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  47. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  48. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  49. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  50. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  51. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  52. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  53. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  54. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  55. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  56. sglang/srt/managers/schedule_batch.py +185 -128
  57. sglang/srt/managers/scheduler.py +4 -4
  58. sglang/srt/managers/tokenizer_manager.py +1 -1
  59. sglang/srt/managers/utils.py +1 -6
  60. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  61. sglang/srt/mem_cache/memory_pool.py +72 -6
  62. sglang/srt/mem_cache/paged_allocator.py +39 -0
  63. sglang/srt/metrics/collector.py +23 -53
  64. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  65. sglang/srt/model_executor/forward_batch_info.py +10 -10
  66. sglang/srt/model_executor/model_runner.py +60 -57
  67. sglang/srt/model_loader/loader.py +8 -0
  68. sglang/srt/models/clip.py +12 -7
  69. sglang/srt/models/deepseek_janus_pro.py +10 -15
  70. sglang/srt/models/deepseek_v2.py +212 -121
  71. sglang/srt/models/deepseek_vl2.py +105 -104
  72. sglang/srt/models/gemma3_mm.py +14 -80
  73. sglang/srt/models/llama.py +16 -5
  74. sglang/srt/models/llama4.py +420 -0
  75. sglang/srt/models/llava.py +31 -19
  76. sglang/srt/models/llavavid.py +16 -7
  77. sglang/srt/models/minicpmo.py +63 -147
  78. sglang/srt/models/minicpmv.py +17 -27
  79. sglang/srt/models/mllama.py +29 -14
  80. sglang/srt/models/mllama4.py +154 -0
  81. sglang/srt/models/qwen2.py +9 -6
  82. sglang/srt/models/qwen2_5_vl.py +21 -31
  83. sglang/srt/models/qwen2_vl.py +20 -21
  84. sglang/srt/openai_api/adapter.py +18 -6
  85. sglang/srt/platforms/interface.py +371 -0
  86. sglang/srt/server_args.py +99 -14
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  88. sglang/srt/speculative/eagle_utils.py +140 -28
  89. sglang/srt/speculative/eagle_worker.py +93 -24
  90. sglang/srt/utils.py +104 -51
  91. sglang/test/test_custom_ops.py +55 -0
  92. sglang/test/test_utils.py +13 -26
  93. sglang/utils.py +2 -2
  94. sglang/version.py +1 -1
  95. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
  96. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
  97. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
  from __future__ import annotations

+ import numpy as np
+
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

  """
@@ -22,29 +24,255 @@ if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner

- from flash_attn_interface import flash_attn_with_kvcache
+ from sgl_kernel.flash_attn import flash_attn_with_kvcache


  @dataclass
  class FlashAttentionMetadata:
- """Metadata for decode operations to avoid redundant computations."""
+ """Metadata to be init once in the model forward pass,
+ each layer's forward pass can reuse the metadata."""

+ # Cumulative sequence lengths for query
  cu_seqlens_q: torch.Tensor = None
+ # Cumulative sequence lengths for key
  cu_seqlens_k: torch.Tensor = None
+ # Maximum sequence length for query
  max_seq_len_q: int = 0
+ # Maximum sequence length for key
  max_seq_len_k: int = 0
+ # Window size (typically used by Gemma)
  window_size: tuple = (-1, -1)
+ # Page table, the index of KV Cache Tables/Blocks
  page_table: torch.Tensor = None
+ # Sequence lengths for the forward batch
  cache_seqlens_int32: torch.Tensor = None

+ @dataclass
+ class LocalAttentionMetadata:
+ local_query_start_loc: torch.Tensor = None # cu_seqlens_q for local attention
+ local_seqused_k: torch.Tensor = None # sequence lengths for local attention
+ local_block_table: torch.Tensor = None # block table for local attention
+ local_max_query_len: int = 0 # max query length for local attention
+ local_max_seq_len: int = 0 # max sequence length for local attention
+
+ local_attn_metadata: Optional[LocalAttentionMetadata] = None
+
+
+ # Copied from:
+ # https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
+ #
+ # Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+ # local attention blocks, where each block is passed to the attention kernel
+ # as an independent local ("virtual") batch item.
+ #
+ # For example, if are performing a chunked prefill a batch of 3 sequences:
+ # q_seqlens = [4, 10, 5]
+ # kv_seqlens = [6, 17, 9]
+ # Then normally for regular attention we would compute with an attention mask
+ # for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
+ # batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
+ # k_toks > 0 1 2 3 4 5
+ # q_toks v _____________
+ # 0 | 1 1 1
+ # 1 | 1 1 1 1
+ # 2 | 1 1 1 1 1
+ # 3 | 1 1 1 1 1 1
+ #
+ # for local attention (with attn_chunk_size = 4) we would compute with an
+ # attention mask like:
+ # batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
+ # k_toks > 0 1 2 3 4 5
+ # q_toks v _____________
+ # 0 | 1 1 1
+ # 1 | 1 1 1 1
+ # 2 | 1
+ # 3 | 1 1
+ #
+ # We can simulate this mask using standard flash-attention by breaking the
+ # sequences into local ("virtual") batches, where each local batch item is a
+ # local attention block, so in this case batch idx 0 would be broken up into:
+ #
+ # local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
+ # k_toks > 0 1 2 3
+ # q_toks v _____________
+ # 0 | 1 1 1
+ # 1 | 1 1 1 1
+ # local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
+ # k_toks > 4 5
+ # q_toks v _____________
+ # 2 | 1
+ # 3 | 1 1
+ #
+ # e.g. if we have:
+ # attn_chunk_size = 4
+ # query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
+ # Then this function would return:
+ # __b0__ ______b1______ __b2__ < orig batch indices
+ # q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
+ # cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
+ # seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
+ # block_table_local : shape[local_virtual_batches, pages_per_local_batch]
+ def make_local_attention_virtual_batches(
+ attn_chunk_size: int,
+ query_start_loc_np: np.ndarray,
+ seq_lens_np: np.ndarray,
+ block_table: torch.Tensor,
+ page_size: int = 0,
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+ """
+ Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+ local attention blocks, where each block is passed to the attention kernel
+ as an independent local ("virtual") batch item.
+
+ Args:
+ attn_chunk_size: Size of local attention chunks
+ query_start_loc_np: Cumulative sum of query lengths (numpy array)
+ seq_lens_np: Sequence lengths (numpy array)
+ block_table: Block table for KV cache
+ page_size: Size of each page in the KV cache
+
+ Returns:
+ seqlens_q_local: Query sequence lengths for local attention
+ cu_seqlens_q_local: Cumulative sum of query sequence lengths for local attention
+ seqlens_k_local: Key sequence lengths for local attention
+ block_table_local: Block table for local attention
+ """
+ q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
+ actual_batch_size = seq_lens_np.shape[0]
+
+ # Handle if we are starting in the middle of a local attention block,
+ # we assume q_seqlens > 0 (for all elements), for each batch idx we compute
+ # the number of tokens that are not in the first local attention block and
+ # then we can simply use a cdiv for the rest.
+ # For example if we have:
+ # attn_chunk_size = 4
+ # q_seqlens = [4, 10, 5]
+ # k_seqlens = [6, 17, 9]
+ # Then we would get:
+ # new_tokens_in_first_block = [2, 1, 4]
+ # local_blocks = [2, 4, 2]
+ q_tokens_in_first_block = np.minimum(
+ attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
+ ).astype(np.int32)
+ tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
+ local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
+
+ # Once we know the number of local blocks we can compute the request spans
+ # for each batch idx, we can figure out the number of "virtual" requests we
+ # have to make,
+ # For the above example we would get:
+ # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
+ #
+ # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
+ # (TODO: max a utility to share this code with _prepare_inputs)
+ # arange step 1. [2, 4, 2] -> [2, 6, 8]
+ cu_num_blocks = np.cumsum(local_blocks)
+ virtual_batches = cu_num_blocks[-1]
+ # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
+ block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
+ # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
+ arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
+ # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
+ rarange = np.repeat(local_blocks, local_blocks) - arange - 1
+ # Then we can compute the seqlens_q_local, handling the fact that the
+ # first and last blocks could be partial
+ seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
+ # set the first block since this may be a partial block
+ seqlens_q_local[arange == 0] = q_tokens_in_first_block
+ # set the remaining blocks
+ seqlens_q_local[arange > 0] = np.minimum(
+ seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
+ )[arange > 0]
+
+ # convert from q_seqlens to cu_seqlens_q
+ cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
+
+ # compute the seqlens_k_local,
+ # basically a full local attention block for all but the last block in each
+ # batch
+ # For our example this will be:
+ # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
+ seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
+ seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
+
+ k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
+ rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
+ )
+ # For the example the local attention blocks start at:
+ # _b0_ _____b1_____ _b2_
+ # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
+ block_starts = k_seqstarts_absolute // page_size
+
+ assert attn_chunk_size % page_size == 0, (
+ f"attn_chunk_size {attn_chunk_size} is not "
+ f"divisible by page_size {page_size}"
+ )
+ pages_per_local_batch = attn_chunk_size // page_size
+
+ # Create a block_table for the local attention blocks
+ # For out example if we have a block-table like (assuming page_size=2):
+ # block_table = [
+ # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
+ # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
+ # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
+ # ]
+ # Then for the local batches we would want a block-table like
+ # block_table_local = [
+ # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
+ # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
+ # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
+ # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
+ # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
+ # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
+ # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
+ # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
+ # ]
+ block_indices = np.broadcast_to(
+ np.arange(pages_per_local_batch, dtype=np.int32),
+ (virtual_batches, pages_per_local_batch),
+ ) + np.expand_dims(block_starts, axis=1)
+ block_indices = block_indices.flatten()
+ batch_indices = np.repeat(
+ np.arange(actual_batch_size, dtype=np.int32),
+ local_blocks * pages_per_local_batch,
+ )
+ block_table_local = block_table[batch_indices, block_indices].view(
+ virtual_batches, -1
+ )
+
+ return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
+
+
+ def cdiv(a: int, b: int) -> int:
+ """Ceiling division."""
+ return -(a // -b)
+

  class FlashAttentionBackend(AttentionBackend):
- """FlashAttention backend implementation."""
+ """FlashAttention backend implementation.
+
+ Note about the init:
+ - If no spec decoding
+ - FlashAttentionBackend will be init once when the server starts.
+ - If spec decoding
+ - FlashAttentionBackend will be init once for the target worker
+ - FlashAttentionMultiStepBackend will be once for the draft worker
+ - It will spawn num_steps FlashAttentionBackend for the draft worker
+
+ Note about CUDA Graph:
+ - We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.
+ - We don't support CUDA Graph for Extend and Draft Extend.
+ - When server init, init_cuda_graph_state will be called first and then init_cuda_graph_capture will be called.
+ - For each forward batch, init_replay_cuda_graph will be called first and then replay the graph.
+ """

  def __init__(
  self,
  model_runner: ModelRunner,
  skip_prefill: bool = False,
+ topk=0,
+ speculative_num_steps=0,
+ step_id=0,
  ):
  super().__init__()

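To make the worked example in the comments of make_local_attention_virtual_batches (added in the hunk above) easier to follow, here is a standalone numpy sketch, not part of the package, that re-derives the per-chunk query and key lengths for q_seqlens = [4, 10, 5], kv_seqlens = [6, 17, 9] and attn_chunk_size = 4; all names below are illustrative:

import numpy as np

attn_chunk_size = 4
q_seqlens = np.array([4, 10, 5])
kv_seqlens = np.array([6, 17, 9])

# Query tokens that still fall into each request's first (possibly partial) chunk.
q_first = np.minimum(
    attn_chunk_size - ((kv_seqlens - q_seqlens) % attn_chunk_size), q_seqlens
)  # -> [2, 1, 4]
# Number of local ("virtual") batches each request is split into (1 + ceil-div).
local_blocks = 1 + -((q_seqlens - q_first) // -attn_chunk_size)  # -> [2, 4, 2]

# Position of every virtual batch inside its request: [0, 1, 0, 1, 2, 3, 0, 1].
arange = np.concatenate([np.arange(n) for n in local_blocks])
seqlens_q_local = np.repeat(q_seqlens - q_first, local_blocks)
seqlens_q_local[arange == 0] = q_first
seqlens_q_local[arange > 0] = np.minimum(
    seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
)[arange > 0]
print(seqlens_q_local.tolist())  # [2, 2, 1, 4, 4, 1, 4, 1]

# Keys per virtual batch: a full chunk everywhere except the last chunk of
# each request, which keeps only the leftover keys.
seqlens_k_local = np.full(local_blocks.sum(), attn_chunk_size)
seqlens_k_local[np.cumsum(local_blocks) - 1] = (
    attn_chunk_size + kv_seqlens % -attn_chunk_size
)
print(seqlens_k_local.tolist())  # [4, 2, 4, 4, 4, 1, 4, 1]

Each (q, k) pair printed above becomes one independent "virtual" batch item for the flash-attention kernel, which is how the chunked mask is simulated without a custom kernel.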
@@ -53,56 +281,129 @@ class FlashAttentionBackend(AttentionBackend):
  and model_runner.model_config.is_encoder_decoder
  ), "Sliding window and cross attention are not supported together"

- # Initialize metadata
  self.forward_metadata: FlashAttentionMetadata = None
  self.max_context_len = model_runner.model_config.context_len
  self.device = model_runner.device
  self.decode_cuda_graph_metadata = {}
+ self.target_verify_metadata = {}
  self.req_to_token = model_runner.req_to_token_pool.req_to_token
  self.page_size = model_runner.page_size
  self.use_mla = (
  model_runner.model_config.attention_arch == AttentionArch.MLA
  ) and (not global_server_args_dict["disable_mla"])
+ self.skip_prefill = skip_prefill
+
+ # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+ assert (
+ topk <= 1
+ ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
+
+ self.topk = 1
+ self.step_id = step_id
+ self.speculative_num_steps = speculative_num_steps
+
+ # Local attention settings
+ self.attention_chunk_size = (
+ model_runner.attention_chunk_size
+ if hasattr(model_runner, "attention_chunk_size")
+ else None
+ )

  def init_forward_metadata(self, forward_batch: ForwardBatch):
  """Initialize forward metadata to cache repetitive calculations."""
- # Create metadata based on forward mode
  metadata = FlashAttentionMetadata()
-
- # Get sequence information
  seqlens_in_batch = forward_batch.seq_lens
- # Precompute int32 version of sequence lengths
- metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
  batch_size = len(seqlens_in_batch)
  device = seqlens_in_batch.device
- metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
- )
- # Precompute maximum sequence length
- metadata.max_seq_len_k = seqlens_in_batch.max().item()
- # Precompute page table
- metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
- forward_batch.req_pool_indices, : metadata.max_seq_len_k
- ]
+ if forward_batch.forward_mode.is_decode():
+ # Skip Prefill or Draft Decode
+ # Note: Draft Decode will be ran on the Draft Worker
+ if forward_batch.spec_info is not None:
+ metadata.cu_seqlens_q = torch.arange(
+ 0, batch_size + 1, dtype=torch.int32, device=device
+ )
+ seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
+ metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+ self.step_id + 1
+ )
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+ cache_loc = forward_batch.out_cache_loc.view(
+ self.speculative_num_steps, -1
+ ).T

- # Precompute strided indices
- # [0, page_size, 2 * page_size, ...]
- if self.page_size > 1:
- self.strided_indices = torch.arange(
- 0, metadata.page_table.shape[1], self.page_size, device=self.device
+ for idx, single_seq_len in enumerate(seq_lens_with_decode):
+ real_bsz_start_idx = idx
+ real_bsz_end_idx = idx + 1
+ metadata.page_table[
+ real_bsz_start_idx:real_bsz_end_idx,
+ (single_seq_len - (self.step_id + 1)) : single_seq_len,
+ ] = cache_loc[
+ real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
+ ]
+ else: # Normal Decode without Spec Decoding
+ metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+ )
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+ metadata.cu_seqlens_q = torch.arange(
+ 0, batch_size + 1, dtype=torch.int32, device=device
+ )
+ elif forward_batch.forward_mode.is_target_verify():
+ # Note: Target Verify will be ran on the Target Worker
+ draft_token_num = forward_batch.spec_info.draft_token_num
+ metadata.cache_seqlens_int32 = (
+ forward_batch.seq_lens + draft_token_num
+ ).to(torch.int32)
+ metadata.max_seq_len_q = draft_token_num
+ metadata.max_seq_len_k = (
+ forward_batch.seq_lens_cpu.max().item() + draft_token_num
  )
- metadata.page_table = (
- metadata.page_table[:, self.strided_indices] // self.page_size
+ metadata.cu_seqlens_q = torch.arange(
+ 0,
+ batch_size * draft_token_num + 1,
+ draft_token_num,
+ dtype=torch.int32,
+ device=device,
  )
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+ (1, 0),
+ )
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]

- if forward_batch.forward_mode == ForwardMode.DECODE:
- # Precompute cumulative sequence lengths
- metadata.cu_seqlens_q = torch.arange(
- 0, batch_size + 1, dtype=torch.int32, device=device
+ elif forward_batch.forward_mode.is_extend_or_draft_extend():
+ # Normal or Draft Extend (Both of them will be ran on the Target Worker)
+ metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
  )
- else:
+ # Precompute maximum sequence length
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+ # Precompute page table
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+
  # Precompute cumulative sequence lengths
- if any(forward_batch.extend_prefix_lens_cpu):
+ if (
+ any(forward_batch.extend_prefix_lens_cpu)
+ or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+ ):
  extend_seq_lens = forward_batch.extend_seq_lens
  metadata.cu_seqlens_q = torch.nn.functional.pad(
  torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
@@ -111,6 +412,61 @@ class FlashAttentionBackend(AttentionBackend):
  else:
  metadata.cu_seqlens_q = metadata.cu_seqlens_k
  metadata.max_seq_len_q = metadata.max_seq_len_k
+
+ # Setup local attention if enabled
+ if (
+ self.attention_chunk_size is not None
+ and forward_batch.forward_mode == ForwardMode.EXTEND
+ ):
+ # Convert tensors to numpy for local attention processing
+ cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
+ seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
+
+ # Adjust attention_chunk_size based on the actual sequence length
+ # to avoid index out of bounds errors
+ max_seq_len = seq_lens_np.max()
+ effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
+ # Make sure effective_chunk_size is divisible by page_size
+ effective_chunk_size = (
+ effective_chunk_size // self.page_size
+ ) * self.page_size
+ if effective_chunk_size < self.page_size:
+ effective_chunk_size = self.page_size
+
+ # Create local attention metadata
+ (
+ seqlens_q_local_np,
+ cu_seqlens_q_local_np,
+ seqlens_k_local_np,
+ block_table_local,
+ ) = make_local_attention_virtual_batches(
+ effective_chunk_size,
+ cu_seqlens_q_np,
+ seq_lens_np,
+ metadata.page_table,
+ self.page_size,
+ )
+
+ local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+ local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
+ device
+ ),
+ local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
+ local_block_table=block_table_local,
+ local_max_query_len=seqlens_q_local_np.max(),
+ local_max_seq_len=seqlens_k_local_np.max(),
+ )
+ metadata.local_attn_metadata = local_metadata
+
+ # Precompute strided indices
+ if self.page_size > 1:
+ self.strided_indices = torch.arange(
+ 0, metadata.page_table.shape[1], self.page_size, device=self.device
+ )
+ metadata.page_table = (
+ metadata.page_table[:, self.strided_indices] // self.page_size
+ )
+
  self.forward_metadata = metadata

  def forward_extend(
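The branches of init_forward_metadata above all build cu_seqlens_k the same way: an exclusive prefix sum of the per-request lengths with a leading zero, produced by cumsum plus a one-element left pad. A minimal standalone sketch of that idiom (the values are illustrative, not taken from the diff):

import torch

# cu_seqlens holds cumulative offsets with a leading 0, so request i occupies
# the slice [cu_seqlens[i], cu_seqlens[i + 1]) of the packed key/value layout.
seq_lens = torch.tensor([6, 17, 9], dtype=torch.int32)
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
print(cu_seqlens_k)  # tensor([ 0,  6, 23, 32], dtype=torch.int32)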
@@ -122,7 +478,6 @@ class FlashAttentionBackend(AttentionBackend):
  forward_batch: ForwardBatch,
  save_kv_cache=True,
  ):
-
  if k is not None:
  assert v is not None
  if save_kv_cache:
@@ -155,9 +510,30 @@ class FlashAttentionBackend(AttentionBackend):
  else (-1, -1)
  )

- page_table = metadata.page_table
+ # Check if we should use local attention
+ use_local_attn = (
+ self.attention_chunk_size is not None
+ and metadata.local_attn_metadata is not None
+ and (hasattr(layer, "use_irope") and layer.use_irope)
+ )

- # # Use Flash Attention for prefill
+ # Get the appropriate page table based on whether we're using local attention
+ if use_local_attn:
+ local_metadata = metadata.local_attn_metadata
+ page_table = local_metadata.local_block_table
+ cu_seqlens_q = local_metadata.local_query_start_loc
+ cache_seqlens = local_metadata.local_seqused_k
+ max_seqlen_q = local_metadata.local_max_query_len
+ max_seqlen_k = local_metadata.local_max_seq_len
+ else:
+ page_table = metadata.page_table
+ cu_seqlens_q = metadata.cu_seqlens_q
+ cache_seqlens = metadata.cache_seqlens_int32
+ max_seqlen_q = metadata.max_seq_len_q
+ max_seqlen_k = metadata.max_seq_len_k
+ cu_seqlens_k = metadata.cu_seqlens_k
+
+ # Use Flash Attention for prefill
  if not self.use_mla:
  # Do multi-head attention
  kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
@@ -173,10 +549,10 @@ class FlashAttentionBackend(AttentionBackend):
  k_cache=key_cache,
  v_cache=value_cache,
  page_table=page_table,
- cache_seqlens=metadata.cache_seqlens_int32,
- cu_seqlens_q=metadata.cu_seqlens_q,
- cu_seqlens_k_new=metadata.cu_seqlens_k,
- max_seqlen_q=metadata.max_seq_len_q,
+ cache_seqlens=cache_seqlens,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+ max_seqlen_q=max_seqlen_q,
  softmax_scale=layer.scaling,
  causal=True,
  window_size=window_size,
@@ -208,10 +584,10 @@ class FlashAttentionBackend(AttentionBackend):
  v_cache=c_kv_cache,
  qv=q_nope,
  page_table=page_table,
- cache_seqlens=metadata.cache_seqlens_int32,
- cu_seqlens_q=metadata.cu_seqlens_q,
- cu_seqlens_k_new=metadata.cu_seqlens_k,
- max_seqlen_q=metadata.max_seq_len_q,
+ cache_seqlens=cache_seqlens,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+ max_seqlen_q=max_seqlen_q,
  softmax_scale=layer.scaling,
  causal=True,
  softcap=layer.logit_cap,
@@ -263,7 +639,6 @@ class FlashAttentionBackend(AttentionBackend):
  if layer.sliding_window_size is not None
  else (-1, -1)
  )
-
  page_table = metadata.page_table

  if not self.use_mla:
@@ -281,8 +656,6 @@ class FlashAttentionBackend(AttentionBackend):

  # Pre-reshape query tensor
  q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
- # Run attention with precomputed values
  o = flash_attn_with_kvcache(
  q=q_reshaped,
  k_cache=key_cache,
@@ -334,7 +707,6 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=layer.k_scale,
  v_descale=layer.v_scale,
  )
-
  return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

  def init_cuda_graph_state(self, max_bs: int):
@@ -346,7 +718,6 @@ class FlashAttentionBackend(AttentionBackend):
  This creates fixed-size tensors that will be reused during CUDA graph replay
  to avoid memory allocations.
  """
- # Initialize fixed size tensors for decode operations
  self.decode_cuda_graph_metadata = {
  # Page table for token mapping (batch_size, max_context_len)
  "page_table": torch.zeros(
@@ -355,6 +726,39 @@ class FlashAttentionBackend(AttentionBackend):
  dtype=torch.int32,
  device=self.device,
  ),
+ "page_table_draft_decode": torch.zeros(
+ max_bs,
+ (self.max_context_len + self.page_size - 1) // self.page_size,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "strided_indices": torch.arange(
+ 0, self.max_context_len, self.page_size, device=self.device
+ ),
+ "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+ "cu_seqlens_q": torch.arange(
+ 0, max_bs + 128, dtype=torch.int32, device=self.device
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs + 128, dtype=torch.int32, device=self.device
+ ),
+ }
+
+ self.target_verify_metadata = {
+ "page_table": torch.zeros(
+ max_bs,
+ (self.max_context_len + self.page_size - 1) // self.page_size,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+ "cu_seqlens_q": torch.zeros(
+ max_bs + 128, dtype=torch.int32, device=self.device
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs + 128, dtype=torch.int32, device=self.device
+ ),
+ "max_seqlen_q": 0,
  "strided_indices": torch.arange(
  0, self.max_context_len, self.page_size, device=self.device
  ),
@@ -372,27 +776,89 @@ class FlashAttentionBackend(AttentionBackend):
  ):
  """Initialize forward metadata for capturing CUDA graph."""
  metadata = FlashAttentionMetadata()
- # Get sequence information
- metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
- batch_size = len(seq_lens)
  device = seq_lens.device
- metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
- )
- # Precompute maximum sequence length
- metadata.max_seq_len_k = seq_lens.max().item()
- # Precompute page table
- metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
- req_pool_indices, :
- ]
- if forward_mode == ForwardMode.DECODE:
- # Precompute cumulative sequence lengths
- metadata.cu_seqlens_q = torch.arange(
- 0, batch_size + 1, dtype=torch.int32, device=device
+ if forward_mode.is_decode():
+ if spec_info is not None:
+ # Draft Decode
+ metadata.cu_seqlens_q = torch.arange(
+ 0, bs + 1, dtype=torch.int32, device=device
+ )
+ metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+ "cache_seqlens"
+ ][:bs]
+
+ metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
+ : bs + 1
+ ]
+
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
+ metadata.page_table = self.decode_cuda_graph_metadata[
+ "page_table_draft_decode"
+ ][req_pool_indices, :]
+ else:
+ # Normal Decode
+ # Get sequence information
+ metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+ batch_size = len(seq_lens)
+ device = seq_lens.device
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+ )
+ # Precompute maximum sequence length
+ metadata.max_seq_len_k = seq_lens.max().item()
+ # Precompute page table
+ metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+ req_pool_indices, :
+ ]
+ # Precompute cumulative sequence lengths
+ metadata.cu_seqlens_q = torch.arange(
+ 0, batch_size + 1, dtype=torch.int32, device=device
+ )
+ self.decode_cuda_graph_metadata[bs] = metadata
+ elif forward_mode.is_target_verify():
+ draft_token_num = spec_info.draft_token_num
+
+ metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
+ :bs
+ ]
+ metadata.cache_seqlens_int32.copy_(
+ (seq_lens + draft_token_num).to(torch.int32)
  )
- else:
- raise ValueError("Do not support Prefill Mode cuda graph")
- self.decode_cuda_graph_metadata[bs] = metadata
+
+ metadata.max_seq_len_q = draft_token_num
+ metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
+
+ metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
+ torch.arange(
+ 0,
+ bs * draft_token_num + 1,
+ draft_token_num,
+ dtype=torch.int32,
+ device=device,
+ )
+ ]
+ cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
+ cu_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ )
+ metadata.cu_seqlens_k = cu_k
+ metadata.page_table = self.target_verify_metadata["page_table"][
+ req_pool_indices, :
+ ]
+
+ self.target_verify_metadata[bs] = metadata
+
  self.forward_metadata = metadata

  def init_forward_metadata_replay_cuda_graph(
@@ -405,30 +871,159 @@ class FlashAttentionBackend(AttentionBackend):
  forward_mode: ForwardMode,
  spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  seq_lens_cpu: Optional[torch.Tensor],
+ out_cache_loc: torch.Tensor = None,
  ):
  # """Initialize forward metadata for replaying CUDA graph."""
- metadata = self.decode_cuda_graph_metadata[bs]
+ device = seq_lens.device
+ seq_lens = seq_lens[:bs]
+ req_pool_indices = req_pool_indices[:bs]
+ seq_lens_cpu = seq_lens_cpu[:bs]
+ if forward_mode.is_decode():
+ metadata = self.decode_cuda_graph_metadata[bs]

- # For CPU operations
- max_len = seq_lens_cpu[:bs].max().item()
- metadata.max_seq_len_k = max_len
+ if spec_info is not None:
+ # Draft Decode
+ max_len = seq_lens_cpu.max().item()
+ metadata.max_seq_len_k = max_len + (self.step_id + 1)

- # For GPU operations
- seq_lens_in_batch = seq_lens[:bs]
- metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
- metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
- )
+ metadata.cache_seqlens_int32.copy_(
+ (seq_lens + (self.step_id + 1)).to(torch.int32)
+ )
+
+ metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)
+
+ metadata.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ )
+
+ page_table = self.req_to_token[
+ req_pool_indices, : metadata.max_seq_len_k
+ ]
+
+ metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+ else:
+ # Normal Decode
+ max_len = seq_lens_cpu.max().item()
+ metadata.max_seq_len_k = max_len
+
+ metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+ )
+
+ max_seq_pages = (
+ metadata.max_seq_len_k + self.page_size - 1
+ ) // self.page_size
+ page_indices = self.req_to_token[
+ :,
+ self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+ ]
+ page_indices = page_indices[req_pool_indices] // self.page_size
+ metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+ metadata.page_table[:, max_seq_pages:].fill_(0)
+
+ elif forward_mode.is_target_verify():
+ metadata = self.target_verify_metadata[bs]
+ draft_token_num = spec_info.draft_token_num
+
+ metadata.cu_seqlens_q.copy_(
+ torch.arange(
+ 0,
+ bs * draft_token_num + 1,
+ draft_token_num,
+ dtype=torch.int32,
+ device=device,
+ )
+ )
+ metadata.cache_seqlens_int32.copy_(
+ (seq_lens + draft_token_num).to(torch.int32)
+ )
+
+ metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
+ metadata.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ )
+ page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
+ metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

- max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
- page_indices = self.req_to_token[
- :, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
- ]
- page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
- metadata.page_table[:, :max_seq_pages].copy_(page_indices)
- metadata.page_table[:, max_seq_pages:].fill_(0)
  self.forward_metadata = metadata

  def get_cuda_graph_seq_len_fill_value(self):
  """Get the fill value for sequence length in CUDA graph."""
  return 0
+
+
+ class FlashAttentionMultiStepBackend:
+
+ def __init__(
+ self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+ ):
+ self.model_runner = model_runner
+ self.topk = topk
+ self.speculative_num_steps = speculative_num_steps
+
+ self.attn_backends = []
+ for i in range(self.speculative_num_steps):
+ self.attn_backends.append(
+ FlashAttentionBackend(
+ model_runner,
+ topk=self.topk,
+ speculative_num_steps=self.speculative_num_steps,
+ step_id=i,
+ )
+ )
+
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
+ for i in range(self.speculative_num_steps - 1):
+ self.attn_backends[i].init_forward_metadata(forward_batch)
+
+ def init_cuda_graph_state(self, max_bs: int):
+ for i in range(self.speculative_num_steps):
+ self.attn_backends[i].init_cuda_graph_state(max_bs)
+
+ def init_forward_metadata_capture_cuda_graph(
+ self,
+ forward_batch: ForwardBatch,
+ ):
+ assert forward_batch.spec_info is not None
+ assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+ for i in range(self.speculative_num_steps - 1):
+ self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+ forward_batch.batch_size,
+ forward_batch.batch_size * self.topk,
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ encoder_lens=None,
+ forward_mode=ForwardMode.DECODE,
+ spec_info=forward_batch.spec_info,
+ )
+
+ def init_forward_metadata_replay_cuda_graph(
+ self, forward_batch: ForwardBatch, bs: int
+ ):
+ assert forward_batch.spec_info is not None
+ assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+ for i in range(self.speculative_num_steps - 1):
+ self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+ bs,
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
+ encoder_lens=None,
+ forward_mode=ForwardMode.DECODE,
+ spec_info=forward_batch.spec_info,
+ seq_lens_cpu=forward_batch.seq_lens_cpu,
+ out_cache_loc=forward_batch.out_cache_loc,
+ )
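The capture/replay methods above avoid allocating fresh tensors at replay time; they refresh the buffers created in init_cuda_graph_state in place with copy_() and fill_() so the captured graph keeps reading from the same device addresses. A minimal, generic PyTorch sketch of that pattern (not SGLang's API; requires a CUDA device, and the buffer shapes are illustrative):

import torch

static_seq_lens = torch.zeros(8, dtype=torch.int32, device="cuda")
static_cu_seqlens = torch.zeros(9, dtype=torch.int32, device="cuda")

def build_metadata():
    # The computation whose kernels will be baked into the graph; it only
    # reads and writes the pre-allocated "static" buffers above.
    static_cu_seqlens[1:].copy_(
        torch.cumsum(static_seq_lens, dim=0, dtype=torch.int32)
    )

# Warm up on a side stream (recommended before capture), then capture once.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    build_metadata()
torch.cuda.current_stream().wait_stream(side_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    build_metadata()

# Per forward batch: write the new inputs in place, then replay the kernels.
static_seq_lens.copy_(torch.arange(1, 9, dtype=torch.int32, device="cuda"))
graph.replay()
print(static_cu_seqlens)  # tensor([ 0,  1,  3,  6, 10, 15, 21, 28, 36], ...)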