sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/lang/chat_template.py +24 -0
  4. sglang/srt/configs/model_config.py +40 -4
  5. sglang/srt/constrained/base_grammar_backend.py +26 -5
  6. sglang/srt/constrained/llguidance_backend.py +1 -0
  7. sglang/srt/constrained/outlines_backend.py +1 -0
  8. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  9. sglang/srt/constrained/xgrammar_backend.py +1 -0
  10. sglang/srt/conversation.py +29 -4
  11. sglang/srt/disaggregation/base/__init__.py +8 -0
  12. sglang/srt/disaggregation/base/conn.py +113 -0
  13. sglang/srt/disaggregation/decode.py +18 -5
  14. sglang/srt/disaggregation/mini_lb.py +53 -122
  15. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  16. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  17. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  18. sglang/srt/disaggregation/prefill.py +43 -19
  19. sglang/srt/disaggregation/utils.py +31 -0
  20. sglang/srt/entrypoints/EngineBase.py +53 -0
  21. sglang/srt/entrypoints/engine.py +36 -8
  22. sglang/srt/entrypoints/http_server.py +37 -8
  23. sglang/srt/entrypoints/http_server_engine.py +142 -0
  24. sglang/srt/entrypoints/verl_engine.py +37 -10
  25. sglang/srt/hf_transformers_utils.py +4 -0
  26. sglang/srt/layers/attention/flashattention_backend.py +609 -202
  27. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  28. sglang/srt/layers/attention/vision.py +1 -1
  29. sglang/srt/layers/dp_attention.py +2 -4
  30. sglang/srt/layers/elementwise.py +15 -2
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  33. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  49. sglang/srt/layers/moe/router.py +7 -1
  50. sglang/srt/layers/moe/topk.py +37 -16
  51. sglang/srt/layers/quantization/__init__.py +13 -5
  52. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  53. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  54. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  55. sglang/srt/layers/quantization/fp8.py +28 -14
  56. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  57. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  58. sglang/srt/layers/quantization/kv_cache.py +43 -52
  59. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  62. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  63. sglang/srt/layers/radix_attention.py +14 -0
  64. sglang/srt/layers/rotary_embedding.py +75 -1
  65. sglang/srt/managers/io_struct.py +254 -97
  66. sglang/srt/managers/mm_utils.py +3 -2
  67. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  68. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  69. sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
  70. sglang/srt/managers/schedule_batch.py +62 -21
  71. sglang/srt/managers/scheduler.py +71 -14
  72. sglang/srt/managers/tokenizer_manager.py +17 -3
  73. sglang/srt/managers/tp_worker.py +1 -0
  74. sglang/srt/mem_cache/memory_pool.py +14 -1
  75. sglang/srt/metrics/collector.py +9 -0
  76. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  77. sglang/srt/model_executor/forward_batch_info.py +234 -15
  78. sglang/srt/model_executor/model_runner.py +49 -9
  79. sglang/srt/model_loader/loader.py +31 -4
  80. sglang/srt/model_loader/weight_utils.py +4 -2
  81. sglang/srt/models/baichuan.py +2 -0
  82. sglang/srt/models/chatglm.py +1 -0
  83. sglang/srt/models/commandr.py +1 -0
  84. sglang/srt/models/dbrx.py +1 -0
  85. sglang/srt/models/deepseek.py +1 -0
  86. sglang/srt/models/deepseek_v2.py +248 -61
  87. sglang/srt/models/exaone.py +1 -0
  88. sglang/srt/models/gemma.py +1 -0
  89. sglang/srt/models/gemma2.py +1 -0
  90. sglang/srt/models/gemma3_causal.py +1 -0
  91. sglang/srt/models/gpt2.py +1 -0
  92. sglang/srt/models/gpt_bigcode.py +1 -0
  93. sglang/srt/models/granite.py +1 -0
  94. sglang/srt/models/grok.py +1 -0
  95. sglang/srt/models/internlm2.py +1 -0
  96. sglang/srt/models/llama.py +13 -4
  97. sglang/srt/models/llama4.py +487 -0
  98. sglang/srt/models/minicpm.py +1 -0
  99. sglang/srt/models/minicpm3.py +2 -0
  100. sglang/srt/models/mixtral.py +1 -0
  101. sglang/srt/models/mixtral_quant.py +1 -0
  102. sglang/srt/models/mllama.py +51 -8
  103. sglang/srt/models/mllama4.py +227 -0
  104. sglang/srt/models/olmo.py +1 -0
  105. sglang/srt/models/olmo2.py +1 -0
  106. sglang/srt/models/olmoe.py +1 -0
  107. sglang/srt/models/phi3_small.py +1 -0
  108. sglang/srt/models/qwen.py +1 -0
  109. sglang/srt/models/qwen2.py +1 -0
  110. sglang/srt/models/qwen2_5_vl.py +35 -70
  111. sglang/srt/models/qwen2_moe.py +1 -0
  112. sglang/srt/models/qwen2_vl.py +27 -25
  113. sglang/srt/models/stablelm.py +1 -0
  114. sglang/srt/models/xverse.py +1 -0
  115. sglang/srt/models/xverse_moe.py +1 -0
  116. sglang/srt/openai_api/adapter.py +4 -1
  117. sglang/srt/patch_torch.py +11 -0
  118. sglang/srt/server_args.py +34 -0
  119. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  120. sglang/srt/speculative/eagle_utils.py +1 -11
  121. sglang/srt/speculative/eagle_worker.py +6 -2
  122. sglang/srt/utils.py +120 -9
  123. sglang/test/attention/test_flashattn_backend.py +259 -221
  124. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  125. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  126. sglang/test/test_block_fp8.py +57 -0
  127. sglang/test/test_utils.py +19 -8
  128. sglang/version.py +1 -1
  129. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  130. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
  131. sglang/srt/disaggregation/conn.py +0 -81
  132. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  133. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  134. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
@@ -1,49 +1,260 @@
 from __future__ import annotations

-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-
-"""
-Support different attention backends.
-Now there are three backends: FlashInfer, Triton and FlashAttention.
-Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
-"""
-
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

+import numpy as np
 import torch

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

-from sgl_kernel.flash_attn import flash_attn_with_kvcache
+from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


 @dataclass
 class FlashAttentionMetadata:
     """Metadata to be init once in the model forward pass,
-    each layer's forward pass can reuse the metadata."""
+    each layer's forward pass can reuse the metadata.

-    # Cumulative sequence lengths for query
-    cu_seqlens_q: torch.Tensor = None
-    # Cumulative sequence lengths for key
-    cu_seqlens_k: torch.Tensor = None
+    For each init metadata function, we will try set up them in below order
+    """
+
+    # Sequence lengths for the forward batch
+    cache_seqlens_int32: torch.Tensor = None
     # Maximum sequence length for query
     max_seq_len_q: int = 0
     # Maximum sequence length for key
     max_seq_len_k: int = 0
+    # Cumulative sequence lengths for query
+    cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
+    cu_seqlens_k: torch.Tensor = None
     # Window size (typically used by Gemma)
     window_size: tuple = (-1, -1)
     # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
+
+    # Encoder metadata
+    # Cumulative sequence lengths for encoder key
+    encoder_cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for encoder key
+    encoder_max_seq_len_k: int = 0
     # Sequence lengths for the forward batch
-    cache_seqlens_int32: torch.Tensor = None
+    encoder_lens_int32: torch.Tensor = None
+    # Page table for the encoder
+    encoder_page_table: torch.Tensor = None
+
+    @dataclass
+    class LocalAttentionMetadata:
+        local_query_start_loc: torch.Tensor = None  # cu_seqlens_q for local attention
+        local_seqused_k: torch.Tensor = None  # sequence lengths for local attention
+        local_block_table: torch.Tensor = None  # block table for local attention
+        local_max_query_len: int = 0  # max query length for local attention
+        local_max_seq_len: int = 0  # max sequence length for local attention
+
+    local_attn_metadata: Optional[LocalAttentionMetadata] = None
+
+
+# Copied from:
+# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
+#
+# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+# local attention blocks, where each block is passed to the attention kernel
+# as an independent local ("virtual") batch item.
+#
+# For example, if are performing a chunked prefill a batch of 3 sequences:
+# q_seqlens = [4, 10, 5]
+# kv_seqlens = [6, 17, 9]
+# Then normally for regular attention we would compute with an attention mask
+# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
+# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
+# k_toks > 0 1 2 3 4 5
+# q_toks v _____________
+# 0 | 1 1 1
+# 1 | 1 1 1 1
+# 2 | 1 1 1 1 1
+# 3 | 1 1 1 1 1 1
+#
+# for local attention (with attn_chunk_size = 4) we would compute with an
+# attention mask like:
+# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
+# k_toks > 0 1 2 3 4 5
+# q_toks v _____________
+# 0 | 1 1 1
+# 1 | 1 1 1 1
+# 2 | 1
+# 3 | 1 1
+#
+# We can simulate this mask using standard flash-attention by breaking the
+# sequences into local ("virtual") batches, where each local batch item is a
+# local attention block, so in this case batch idx 0 would be broken up into:
+#
+# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
+# k_toks > 0 1 2 3
+# q_toks v _____________
+# 0 | 1 1 1
+# 1 | 1 1 1 1
+# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
+# k_toks > 4 5
+# q_toks v _____________
+# 2 | 1
+# 3 | 1 1
+#
+# e.g. if we have:
+# attn_chunk_size = 4
+# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
+# Then this function would return:
+# __b0__ ______b1______ __b2__ < orig batch indices
+# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
+# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
+# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
+# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
+def make_local_attention_virtual_batches(
+    attn_chunk_size: int,
+    query_start_loc_np: np.ndarray,
+    seq_lens_np: np.ndarray,
+    block_table: torch.Tensor,
+    page_size: int = 0,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+    """
+    Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+    local attention blocks, where each block is passed to the attention kernel
+    as an independent local ("virtual") batch item.
+
+    Args:
+        attn_chunk_size: Size of local attention chunks
+        query_start_loc_np: Cumulative sum of query lengths (numpy array)
+        seq_lens_np: Sequence lengths (numpy array)
+        block_table: Block table for KV cache
+        page_size: Size of each page in the KV cache
+
+    Returns:
+        seqlens_q_local: Query sequence lengths for local attention
+        cu_seqlens_q_local: Cumulative sum of query sequence lengths for local attention
+        seqlens_k_local: Key sequence lengths for local attention
+        block_table_local: Block table for local attention
+    """
+    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
+    actual_batch_size = seq_lens_np.shape[0]
+
+    # Handle if we are starting in the middle of a local attention block,
+    # we assume q_seqlens > 0 (for all elements), for each batch idx we compute
+    # the number of tokens that are not in the first local attention block and
+    # then we can simply use a cdiv for the rest.
+    # For example if we have:
+    #   attn_chunk_size = 4
+    #   q_seqlens = [4, 10, 5]
+    #   k_seqlens = [6, 17, 9]
+    # Then we would get:
+    #   new_tokens_in_first_block = [2, 1, 4]
+    #   local_blocks = [2, 4, 2]
+    q_tokens_in_first_block = np.minimum(
+        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
+    ).astype(np.int32)
+    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
+    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
+
+    # Once we know the number of local blocks we can compute the request spans
+    # for each batch idx, we can figure out the number of "virtual" requests we
+    # have to make,
+    # For the above example we would get:
+    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
+    #
+    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
+    # (TODO: max a utility to share this code with _prepare_inputs)
+    # arange step 1. [2, 4, 2] -> [2, 6, 8]
+    cu_num_blocks = np.cumsum(local_blocks)
+    virtual_batches = cu_num_blocks[-1]
+    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
+    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
+    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
+    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
+    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
+    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
+    # Then we can compute the seqlens_q_local, handling the fact that the
+    # first and last blocks could be partial
+    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
+    # set the first block since this may be a partial block
+    seqlens_q_local[arange == 0] = q_tokens_in_first_block
+    # set the remaining blocks
+    seqlens_q_local[arange > 0] = np.minimum(
+        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
+    )[arange > 0]
+
+    # convert from q_seqlens to cu_seqlens_q
+    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
+
+    # compute the seqlens_k_local,
+    # basically a full local attention block for all but the last block in each
+    # batch
+    # For our example this will be:
+    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
+    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
+    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
+
+    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
+        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
+    )
+    # For the example the local attention blocks start at:
+    #   _b0_ _____b1_____ _b2_
+    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
+    block_starts = k_seqstarts_absolute // page_size
+
+    assert attn_chunk_size % page_size == 0, (
+        f"attn_chunk_size {attn_chunk_size} is not "
+        f"divisible by page_size {page_size}"
+    )
+    pages_per_local_batch = attn_chunk_size // page_size
+
+    # Create a block_table for the local attention blocks
+    # For out example if we have a block-table like (assuming page_size=2):
+    #   block_table = [
+    #     [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
+    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
+    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
+    #   ]
+    # Then for the local batches we would want a block-table like
+    #   block_table_local = [
+    #     [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
+    #     [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
+    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
+    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
+    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
+    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
+    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
+    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
+    #   ]
+    block_indices = np.broadcast_to(
+        np.arange(pages_per_local_batch, dtype=np.int32),
+        (virtual_batches, pages_per_local_batch),
+    ) + np.expand_dims(block_starts, axis=1)
+    # Ensure block_indices doesn't exceed block_table dimensions
+    # This is a critical safety check that prevents index out of bounds errors
+    # when dealing with large sequences (>8192 tokens) or when the block_table
+    # dimensions are smaller than what would be needed for the full attention chunk size.
+    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
+    batch_indices = np.repeat(
+        np.arange(actual_batch_size, dtype=np.int32),
+        local_blocks * pages_per_local_batch,
+    )
+    block_table_local = block_table[batch_indices, block_indices].view(
+        virtual_batches, -1
+    )
+
+    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
+
+
+def cdiv(a: int, b: int) -> int:
+    """Ceiling division."""
+    return -(a // -b)


 class FlashAttentionBackend(AttentionBackend):
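The cumulative-length fields above (cu_seqlens_q, cu_seqlens_k, encoder_cu_seqlens_k) follow the FlashAttention varlen convention: a running sum of per-request lengths with a leading zero. A minimal standalone sketch, not taken from the package, of the pad(cumsum(...)) idiom the metadata code in the hunks below uses to build them:

```python
# Sketch only: how cumulative sequence lengths are derived from per-request lengths.
# The input lengths here are made up for illustration.
import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # hypothetical per-request KV lengths
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
print(cu_seqlens)  # -> [0, 3, 8, 10]
```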
@@ -68,9 +279,9 @@ class FlashAttentionBackend(AttentionBackend):
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        speculative_step_id=0,
         topk=0,
         speculative_num_steps=0,
-        step_id=0,
     ):
         super().__init__()

@@ -85,87 +296,82 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {}
         self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
         ) and (not global_server_args_dict["disable_mla"])
         self.skip_prefill = skip_prefill

-        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
-        assert (
-            topk <= 1
-        ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
-
-        self.topk = 1
-        self.step_id = step_id
+        self.topk = topk
         self.speculative_num_steps = speculative_num_steps
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.speculative_step_id = speculative_step_id
+
+        # Local attention settings
+        self.attention_chunk_size = (
+            model_runner.attention_chunk_size
+            if hasattr(model_runner, "attention_chunk_size")
+            else None
+        )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        """Initialize forward metadata to cache repetitive calculations."""
+        """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-        if forward_batch.forward_mode.is_decode():
-            # Skip Prefill or Draft Decode
-            # Note: Draft Decode will be ran on the Draft Worker
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            # Draft Decode
             if forward_batch.spec_info is not None:
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = torch.arange(
                     0, batch_size + 1, dtype=torch.int32, device=device
                 )
-                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
-                metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
-                    self.step_id + 1
-                )
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-                cache_loc = forward_batch.out_cache_loc.view(
-                    self.speculative_num_steps, -1
-                ).T
-
-                for idx, single_seq_len in enumerate(seq_lens_with_decode):
-                    real_bsz_start_idx = idx
-                    real_bsz_end_idx = idx + 1
-                    metadata.page_table[
-                        real_bsz_start_idx:real_bsz_end_idx,
-                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
-                    ] = cache_loc[
-                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
-                    ]
-            else:  # Normal Decode without Spec Decoding
+            else:
+                # Normal Decode
                 metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-                metadata.cu_seqlens_q = torch.arange(
-                    0, batch_size + 1, dtype=torch.int32, device=device
-                )
         elif forward_batch.forward_mode.is_target_verify():
-            # Note: Target Verify will be ran on the Target Worker
-            draft_token_num = forward_batch.spec_info.draft_token_num
             metadata.cache_seqlens_int32 = (
-                forward_batch.seq_lens + draft_token_num
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
             ).to(torch.int32)
-            metadata.max_seq_len_q = draft_token_num
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
             metadata.max_seq_len_k = (
-                forward_batch.seq_lens_cpu.max().item() + draft_token_num
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
             )
             metadata.cu_seqlens_q = torch.arange(
                 0,
-                batch_size * draft_token_num + 1,
-                draft_token_num,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
                 dtype=torch.int32,
                 device=device,
             )
@@ -177,33 +383,99 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]

-        elif forward_batch.forward_mode.is_extend_or_draft_extend():
-            # Normal or Draft Extend (Both of them will be ran on the Target Worker)
+        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             metadata.cu_seqlens_k = torch.nn.functional.pad(
                 torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
             )
-            # Precompute maximum sequence length
-            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
-            # Precompute page table
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
-            # Precompute cumulative sequence lengths
+
             if (
                 any(forward_batch.extend_prefix_lens_cpu)
                 or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
             ):
                 extend_seq_lens = forward_batch.extend_seq_lens
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
-                metadata.cu_seqlens_q = metadata.cu_seqlens_k
                 metadata.max_seq_len_q = metadata.max_seq_len_k
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k
+
+            # Setup local attention if enabled
+            if (
+                self.attention_chunk_size is not None
+                and forward_batch.forward_mode == ForwardMode.EXTEND
+            ):
+                # Convert tensors to numpy for local attention processing
+                cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
+                seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
+
+                # Adjust attention_chunk_size based on the actual sequence length
+                # to avoid index out of bounds errors
+                max_seq_len = seq_lens_np.max()
+                effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
+                # Make sure effective_chunk_size is divisible by page_size
+                effective_chunk_size = (
+                    effective_chunk_size // self.page_size
+                ) * self.page_size
+                if effective_chunk_size < self.page_size:
+                    effective_chunk_size = self.page_size
+
+                # Create local attention metadata
+                (
+                    seqlens_q_local_np,
+                    cu_seqlens_q_local_np,
+                    seqlens_k_local_np,
+                    block_table_local,
+                ) = make_local_attention_virtual_batches(
+                    effective_chunk_size,
+                    cu_seqlens_q_np,
+                    seq_lens_np,
+                    metadata.page_table,
+                    self.page_size,
+                )
+
+                local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+                    local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
+                        device
+                    ),
+                    local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
+                    local_block_table=block_table_local,
+                    local_max_query_len=seqlens_q_local_np.max(),
+                    local_max_seq_len=seqlens_k_local_np.max(),
+                )
+                metadata.local_attn_metadata = local_metadata
+
+        # Encoder metadata for cross attention
+        if forward_batch.encoder_lens is not None:
+            assert (
+                forward_batch.encoder_lens.numel() == 1
+            ), "Only encoder size 1 is supported for now"

-        # Precompute strided indices
+            metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
+            metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
+            metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
+            metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
+            ]
+
+            # Currently only support forward_batch.encoder_lens.numel() == 1
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices,
+                metadata.encoder_max_seq_len_k : (
+                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                ),
+            ]
+
+        # Convert the page table to a strided format which is needed by FA3 API
         if self.page_size > 1:
             self.strided_indices = torch.arange(
                 0, metadata.page_table.shape[1], self.page_size, device=self.device
@@ -211,6 +483,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.page_table = (
                 metadata.page_table[:, self.strided_indices] // self.page_size
             )
+
         self.forward_metadata = metadata

     def forward_extend(
@@ -242,7 +515,7 @@ class FlashAttentionBackend(AttentionBackend):
                     v,
                 )

-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
         metadata = self.forward_metadata

         # Calculate window size (can be moved to metadata if layer properties don't change)
@@ -250,75 +523,157 @@ class FlashAttentionBackend(AttentionBackend):
         # here is two side inclusive
         window_size = (
             (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None
+            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
             else (-1, -1)
         )
+        k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
+        causal = not layer.is_cross_attention
+
+        # Check if we should use local attention
+        use_local_attn = (
+            self.attention_chunk_size is not None
+            and metadata.local_attn_metadata is not None
+            and (hasattr(layer, "use_irope") and layer.use_irope)
+        )

-        page_table = metadata.page_table
+        # Get the appropriate page table based on whether we're using local attention
+        if use_local_attn:
+            local_metadata = metadata.local_attn_metadata
+            page_table = local_metadata.local_block_table
+            cu_seqlens_q = local_metadata.local_query_start_loc
+            cache_seqlens = local_metadata.local_seqused_k
+            max_seqlen_q = local_metadata.local_max_query_len
+            max_seqlen_k = local_metadata.local_max_seq_len
+        else:
+            page_table = metadata.page_table
+            cu_seqlens_q = metadata.cu_seqlens_q
+            cache_seqlens = metadata.cache_seqlens_int32
+            max_seqlen_q = metadata.max_seq_len_q
+            max_seqlen_k = metadata.max_seq_len_k
+            cu_seqlens_k = metadata.cu_seqlens_k

         # Use Flash Attention for prefill
         if not self.use_mla:
             # Do multi-head attention
-            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
             value_cache = value_cache.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 k_cache=key_cache,
                 v_cache=value_cache,
                 page_table=page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
-                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
-                max_seqlen_q=metadata.max_seq_len_q,
+                cache_seqlens=cache_seqlens,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
-            # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            k_rope = kv_cache[:, :, layer.v_head_dim :]
-            c_kv = kv_cache[:, :, : layer.v_head_dim]
-            k_rope_cache = k_rope.view(
-                -1,
-                self.page_size,
-                layer.tp_k_head_num,
-                layer.head_dim - layer.v_head_dim,
-            )
-            c_kv_cache = c_kv.view(
-                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
-            )
-
-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
-            o = flash_attn_with_kvcache(
-                q=q_rope,
-                k_cache=k_rope_cache,
-                v_cache=c_kv_cache,
-                qv=q_nope,
-                page_table=page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
-                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
-                max_seqlen_q=metadata.max_seq_len_q,
-                softmax_scale=layer.scaling,
-                causal=True,
-                softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
-            )
+            if (
+                not global_server_args_dict["disable_chunked_prefix_cache"]
+                and forward_batch.attn_attend_prefix_cache is not None
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                # Do multi-head attention with chunked prefix cache
+
+                if forward_batch.attn_attend_prefix_cache:
+                    # MHA for chunked prefix kv cache when running model with MLA
+                    assert forward_batch.prefix_chunk_idx is not None
+                    assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                    assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+                    chunk_idx = forward_batch.prefix_chunk_idx
+                    assert chunk_idx >= 0
+
+                    output, lse, *rest = flash_attn_varlen_func(
+                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                        cu_seqlens_q=metadata.cu_seqlens_q,
+                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                        max_seqlen_q=metadata.max_seq_len_q,
+                        max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                        softmax_scale=layer.scaling,
+                        causal=False,
+                        return_softmax_lse=True,
+                    )
+                else:
+                    # MHA for extend part of sequence without attending prefix kv cache
+                    output, lse, *rest = flash_attn_varlen_func(
+                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                        cu_seqlens_q=metadata.cu_seqlens_q,
+                        cu_seqlens_k=metadata.cu_seqlens_q,
+                        max_seqlen_q=metadata.max_seq_len_q,
+                        max_seqlen_k=metadata.max_seq_len_q,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                        return_softmax_lse=True,
+                    )
+                return output, lse
+            else:
+                # Do absorbed multi-latent attention
+                kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+                k_rope = kv_cache[:, :, layer.v_head_dim :]
+                c_kv = kv_cache[:, :, : layer.v_head_dim]
+                k_rope_cache = k_rope.view(
+                    -1,
+                    self.page_size,
+                    layer.tp_k_head_num,
+                    layer.head_dim - layer.v_head_dim,
+                )
+                c_kv_cache = c_kv.view(
+                    -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+                )
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
+                o = flash_attn_with_kvcache(
+                    q=q_rope,
+                    k_cache=k_rope_cache,
+                    v_cache=c_kv_cache,
+                    qv=q_nope,
+                    page_table=page_table,
+                    cache_seqlens=cache_seqlens,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                    max_seqlen_q=max_seqlen_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )

-        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+                return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def forward_decode(
         self,
@@ -329,8 +684,6 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ) -> torch.Tensor:
-        """Forward pass with FlashAttention using precomputed metadata."""
-        # Save KV cache if needed
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -351,7 +704,7 @@ class FlashAttentionBackend(AttentionBackend):
                     v,
                 )

-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
         metadata = self.forward_metadata

         # Calculate window size (can be moved to metadata if layer properties don't change)
@@ -359,17 +712,27 @@ class FlashAttentionBackend(AttentionBackend):
         # here is two side inclusive
         window_size = (
             (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None
+            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
             else (-1, -1)
         )
-        page_table = metadata.page_table
+        causal = not layer.is_cross_attention
+
+        k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto":
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)

         if not self.use_mla:
             # Do multi-head attention

-            # Get KV cache
-            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
@@ -377,23 +740,32 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )

-            # Pre-reshape query tensor
             q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+            else:
+                page_table = metadata.page_table
+                cache_seqlens = metadata.cache_seqlens_int32
+                cu_seqlens_k = metadata.cu_seqlens_k
+
             o = flash_attn_with_kvcache(
                 q=q_reshaped,
                 k_cache=key_cache,
                 v_cache=value_cache,
                 page_table=page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
+                cache_seqlens=cache_seqlens,
                 cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                cu_seqlens_k_new=cu_seqlens_k,
                 max_seqlen_q=1,
                 softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         else:
             # Do absorbed multi-latent attention
@@ -419,7 +791,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_cache=k_rope_cache,
                 v_cache=c_kv_cache,
                 qv=q_nope,
-                page_table=page_table,
+                page_table=metadata.page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -427,8 +799,8 @@ class FlashAttentionBackend(AttentionBackend):
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

@@ -442,7 +814,13 @@ class FlashAttentionBackend(AttentionBackend):
         to avoid memory allocations.
         """
         self.decode_cuda_graph_metadata = {
-            # Page table for token mapping (batch_size, max_context_len)
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
             "page_table": torch.zeros(
                 max_bs,
                 (self.max_context_len + self.page_size - 1) // self.page_size,
@@ -458,35 +836,42 @@ class FlashAttentionBackend(AttentionBackend):
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
+        }
+
+        self.target_verify_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.arange(
-                0, max_bs + 128, dtype=torch.int32, device=self.device
+            "cu_seqlens_q": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
             ),
             "cu_seqlens_k": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
+                max_bs + 1, dtype=torch.int32, device=self.device
             ),
-        }
-
-        self.target_verify_metadata = {
             "page_table": torch.zeros(
                 max_bs,
                 (self.max_context_len + self.page_size - 1) // self.page_size,
                 dtype=torch.int32,
                 device=self.device,
             ),
-            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "cu_seqlens_k": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "max_seqlen_q": 0,
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
         }

+        self.encoder_metadata = {
+            "encoder_page_table": torch.zeros(
+                max_bs,
+                self.max_context_len,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "encoder_lens_int32": torch.zeros(
+                max_bs, dtype=torch.int32, device=self.device
+            ),
+            "encoder_cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+        }
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -500,27 +885,24 @@ class FlashAttentionBackend(AttentionBackend):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
         device = seq_lens.device
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             if spec_info is not None:
                 # Draft Decode
-                metadata.cu_seqlens_q = torch.arange(
-                    0, bs + 1, dtype=torch.int32, device=device
-                )
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                     "cache_seqlens"
                 ][:bs]
-
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                     : bs + 1
                 ]
-
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
                 metadata.page_table = self.decode_cuda_graph_metadata[
                     "page_table_draft_decode"
                 ][req_pool_indices, :]
@@ -545,43 +927,49 @@ class FlashAttentionBackend(AttentionBackend):
                 )
             self.decode_cuda_graph_metadata[bs] = metadata
         elif forward_mode.is_target_verify():
-            draft_token_num = spec_info.draft_token_num
-
             metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
                 :bs
             ]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + draft_token_num).to(torch.int32)
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
             )

-            metadata.max_seq_len_q = draft_token_num
-            metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )

-            metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
-                torch.arange(
-                    0,
-                    bs * draft_token_num + 1,
-                    draft_token_num,
-                    dtype=torch.int32,
-                    device=device,
-                )
-            ]
-            cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
-            cu_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
             )
-            metadata.cu_seqlens_k = cu_k
+
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+
             metadata.page_table = self.target_verify_metadata["page_table"][
                 req_pool_indices, :
             ]

             self.target_verify_metadata[bs] = metadata

+        if encoder_lens is not None:
+            encoder_bs = encoder_lens.numel()
+            metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
+                :encoder_bs
+            ]
+            metadata.encoder_cu_seqlens_k = self.encoder_metadata[
+                "encoder_cu_seqlens_k"
+            ][: (encoder_bs + 1)]
+
+            metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
+                req_pool_indices, :
+            ]
+
         self.forward_metadata = metadata

     def init_forward_metadata_replay_cuda_graph(
@@ -597,24 +985,21 @@ class FlashAttentionBackend(AttentionBackend):
         out_cache_loc: torch.Tensor = None,
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        device = seq_lens.device
         seq_lens = seq_lens[:bs]
-        req_pool_indices = req_pool_indices[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
-        if forward_mode.is_decode():
+        req_pool_indices = req_pool_indices[:bs]
+        if forward_mode.is_decode_or_idle():
             metadata = self.decode_cuda_graph_metadata[bs]

             if spec_info is not None:
                 # Draft Decode
-                max_len = seq_lens_cpu.max().item()
-                metadata.max_seq_len_k = max_len + (self.step_id + 1)
-
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.step_id + 1)).to(torch.int32)
+                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
                 )

-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)
-
+                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_k.copy_(
                     torch.nn.functional.pad(
                         torch.cumsum(
@@ -643,31 +1028,24 @@ class FlashAttentionBackend(AttentionBackend):
                     metadata.max_seq_len_k + self.page_size - 1
                 ) // self.page_size
                 page_indices = self.req_to_token[
-                    :,
-                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+                    req_pool_indices[:, None],
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                        None, :
+                    ],
                 ]
-                page_indices = page_indices[req_pool_indices] // self.page_size
+                page_indices //= self.page_size
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)

         elif forward_mode.is_target_verify():
             metadata = self.target_verify_metadata[bs]
-            draft_token_num = spec_info.draft_token_num
-
-            metadata.cu_seqlens_q.copy_(
-                torch.arange(
-                    0,
-                    bs * draft_token_num + 1,
-                    draft_token_num,
-                    dtype=torch.int32,
-                    device=device,
-                )
-            )
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + draft_token_num).to(torch.int32)
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
             )

-            metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
             metadata.cu_seqlens_k.copy_(
                 torch.nn.functional.pad(
                     torch.cumsum(
@@ -679,6 +1057,30 @@ class FlashAttentionBackend(AttentionBackend):
             page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
             metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

+        if encoder_lens is not None:
+            # Only support encoder size 1 for now
+            metadata.encoder_max_seq_len_k = encoder_lens[0]
+            metadata.encoder_lens_int32.copy_(encoder_lens[:1])
+            metadata.encoder_cu_seqlens_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                    (1, 0),
+                )
+            )
+
+            metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
+                self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
+            )
+
+            # Update the regular page table
+            page_table = self.req_to_token[
+                req_pool_indices,
+                metadata.encoder_max_seq_len_k : (
+                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                ),
+            ]
+            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
         self.forward_metadata = metadata

     def get_cuda_graph_seq_len_fill_value(self):
@@ -695,14 +1097,19 @@ class FlashAttentionMultiStepBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps

+        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+        assert (
+            self.topk == 1
+        ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
+
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 FlashAttentionBackend(
                     model_runner,
+                    speculative_step_id=i,
                     topk=self.topk,
                     speculative_num_steps=self.speculative_num_steps,
-                    step_id=i,
                 )
             )

@@ -727,7 +1134,7 @@ class FlashAttentionMultiStepBackend:
                 forward_batch.batch_size * self.topk,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
             )
@@ -744,7 +1151,7 @@ class FlashAttentionMultiStepBackend:
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
                 seq_lens_cpu=forward_batch.seq_lens_cpu,
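For readers tracing the new local-attention path, the arithmetic in make_local_attention_virtual_batches can be checked against the worked example embedded in its comments (attn_chunk_size = 4, q_seqlens = [4, 10, 5], kv_seqlens = [6, 17, 9]). The snippet below is a standalone sketch, not part of the package, that re-derives the first two quantities of that example with plain NumPy:

```python
# Standalone sketch (not from the package): re-deriving two quantities from the
# worked example in make_local_attention_virtual_batches() above.
import numpy as np


def cdiv(a, b):
    # Ceiling division, the same trick as the cdiv() helper added in the diff.
    return -(a // -b)


attn_chunk_size = 4
q_seqlens = np.array([4, 10, 5])  # new query tokens per request
seq_lens = np.array([6, 17, 9])   # total KV length per request

# Tokens of each query that fall into its first (possibly partial) local block.
q_tokens_in_first_block = np.minimum(
    attn_chunk_size - ((seq_lens - q_seqlens) % attn_chunk_size), q_seqlens
)
# One block for the partial start, then ceiling-divide the remainder.
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)

print(q_tokens_in_first_block)  # [2 1 4], matching the comment in the diff
print(local_blocks)             # [2 4 2], i.e. 8 local "virtual" batches in total
```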