sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/lang/chat_template.py +24 -0
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +5 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/conversation.py +29 -4
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +678 -83
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +5 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8.py +3 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +503 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_int8.py +2 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +63 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +60 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +16 -5
- sglang/srt/models/llama4.py +420 -0
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/mllama4.py +154 -0
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
 from __future__ import annotations

+import numpy as np
+
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

 """
@@ -22,29 +24,255 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

-from
+from sgl_kernel.flash_attn import flash_attn_with_kvcache


 @dataclass
 class FlashAttentionMetadata:
-    """Metadata
+    """Metadata to be init once in the model forward pass,
+    each layer's forward pass can reuse the metadata."""

+    # Cumulative sequence lengths for query
     cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
     cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for query
     max_seq_len_q: int = 0
+    # Maximum sequence length for key
     max_seq_len_k: int = 0
+    # Window size (typically used by Gemma)
     window_size: tuple = (-1, -1)
+    # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
+    # Sequence lengths for the forward batch
     cache_seqlens_int32: torch.Tensor = None

+    @dataclass
+    class LocalAttentionMetadata:
+        local_query_start_loc: torch.Tensor = None  # cu_seqlens_q for local attention
+        local_seqused_k: torch.Tensor = None  # sequence lengths for local attention
+        local_block_table: torch.Tensor = None  # block table for local attention
+        local_max_query_len: int = 0  # max query length for local attention
+        local_max_seq_len: int = 0  # max sequence length for local attention
+
+    local_attn_metadata: Optional[LocalAttentionMetadata] = None
+
+
+# Copied from:
+# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
+#
+# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+# local attention blocks, where each block is passed to the attention kernel
+# as an independent local ("virtual") batch item.
+#
+# For example, if are performing a chunked prefill a batch of 3 sequences:
+# q_seqlens = [4, 10, 5]
+# kv_seqlens = [6, 17, 9]
+# Then normally for regular attention we would compute with an attention mask
+# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
+# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
+# k_toks > 0 1 2 3 4 5
+# q_toks v _____________
+# 0 | 1 1 1
+# 1 | 1 1 1 1
+# 2 | 1 1 1 1 1
+# 3 | 1 1 1 1 1 1
+#
+# for local attention (with attn_chunk_size = 4) we would compute with an
+# attention mask like:
+# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
+# k_toks > 0 1 2 3 4 5
+# q_toks v _____________
+# 0 | 1 1 1
+# 1 | 1 1 1 1
+# 2 | 1
+# 3 | 1 1
+#
+# We can simulate this mask using standard flash-attention by breaking the
+# sequences into local ("virtual") batches, where each local batch item is a
+# local attention block, so in this case batch idx 0 would be broken up into:
+#
+# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
+# k_toks > 0 1 2 3
+# q_toks v _____________
+# 0 | 1 1 1
+# 1 | 1 1 1 1
+# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
+# k_toks > 4 5
+# q_toks v _____________
+# 2 | 1
+# 3 | 1 1
+#
+# e.g. if we have:
+# attn_chunk_size = 4
+# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
+# Then this function would return:
+# __b0__ ______b1______ __b2__ < orig batch indices
+# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
+# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
+# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
+# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
+def make_local_attention_virtual_batches(
+    attn_chunk_size: int,
+    query_start_loc_np: np.ndarray,
+    seq_lens_np: np.ndarray,
+    block_table: torch.Tensor,
+    page_size: int = 0,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+    """
+    Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+    local attention blocks, where each block is passed to the attention kernel
+    as an independent local ("virtual") batch item.
+
+    Args:
+        attn_chunk_size: Size of local attention chunks
+        query_start_loc_np: Cumulative sum of query lengths (numpy array)
+        seq_lens_np: Sequence lengths (numpy array)
+        block_table: Block table for KV cache
+        page_size: Size of each page in the KV cache
+
+    Returns:
+        seqlens_q_local: Query sequence lengths for local attention
+        cu_seqlens_q_local: Cumulative sum of query sequence lengths for local attention
+        seqlens_k_local: Key sequence lengths for local attention
+        block_table_local: Block table for local attention
+    """
+    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
+    actual_batch_size = seq_lens_np.shape[0]
+
+    # Handle if we are starting in the middle of a local attention block,
+    # we assume q_seqlens > 0 (for all elements), for each batch idx we compute
+    # the number of tokens that are not in the first local attention block and
+    # then we can simply use a cdiv for the rest.
+    # For example if we have:
+    # attn_chunk_size = 4
+    # q_seqlens = [4, 10, 5]
+    # k_seqlens = [6, 17, 9]
+    # Then we would get:
+    # new_tokens_in_first_block = [2, 1, 4]
+    # local_blocks = [2, 4, 2]
+    q_tokens_in_first_block = np.minimum(
+        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
+    ).astype(np.int32)
+    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
+    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
+
+    # Once we know the number of local blocks we can compute the request spans
+    # for each batch idx, we can figure out the number of "virtual" requests we
+    # have to make,
+    # For the above example we would get:
+    # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
+    #
+    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
+    # (TODO: max a utility to share this code with _prepare_inputs)
+    # arange step 1. [2, 4, 2] -> [2, 6, 8]
+    cu_num_blocks = np.cumsum(local_blocks)
+    virtual_batches = cu_num_blocks[-1]
+    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
+    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
+    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
+    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
+    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
+    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
+    # Then we can compute the seqlens_q_local, handling the fact that the
+    # first and last blocks could be partial
+    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
+    # set the first block since this may be a partial block
+    seqlens_q_local[arange == 0] = q_tokens_in_first_block
+    # set the remaining blocks
+    seqlens_q_local[arange > 0] = np.minimum(
+        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
+    )[arange > 0]
+
+    # convert from q_seqlens to cu_seqlens_q
+    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
+
+    # compute the seqlens_k_local,
+    # basically a full local attention block for all but the last block in each
+    # batch
+    # For our example this will be:
+    # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
+    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
+    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
+
+    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
+        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
+    )
+    # For the example the local attention blocks start at:
+    # _b0_ _____b1_____ _b2_
+    # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
+    block_starts = k_seqstarts_absolute // page_size
+
+    assert attn_chunk_size % page_size == 0, (
+        f"attn_chunk_size {attn_chunk_size} is not "
+        f"divisible by page_size {page_size}"
+    )
+    pages_per_local_batch = attn_chunk_size // page_size
+
+    # Create a block_table for the local attention blocks
+    # For out example if we have a block-table like (assuming page_size=2):
+    # block_table = [
+    #     [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
+    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
+    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
+    # ]
+    # Then for the local batches we would want a block-table like
+    # block_table_local = [
+    #     [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
+    #     [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
+    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
+    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
+    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
+    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
+    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
+    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
+    # ]
+    block_indices = np.broadcast_to(
+        np.arange(pages_per_local_batch, dtype=np.int32),
+        (virtual_batches, pages_per_local_batch),
+    ) + np.expand_dims(block_starts, axis=1)
+    block_indices = block_indices.flatten()
+    batch_indices = np.repeat(
+        np.arange(actual_batch_size, dtype=np.int32),
+        local_blocks * pages_per_local_batch,
+    )
+    block_table_local = block_table[batch_indices, block_indices].view(
+        virtual_batches, -1
+    )
+
+    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
+
+
+def cdiv(a: int, b: int) -> int:
+    """Ceiling division."""
+    return -(a // -b)
+

 class FlashAttentionBackend(AttentionBackend):
-    """FlashAttention backend implementation.
+    """FlashAttention backend implementation.
+
+    Note about the init:
+    - If no spec decoding
+        - FlashAttentionBackend will be init once when the server starts.
+    - If spec decoding
+        - FlashAttentionBackend will be init once for the target worker
+        - FlashAttentionMultiStepBackend will be once for the draft worker
+        - It will spawn num_steps FlashAttentionBackend for the draft worker
+
+    Note about CUDA Graph:
+    - We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.
+    - We don't support CUDA Graph for Extend and Draft Extend.
+    - When server init, init_cuda_graph_state will be called first and then init_cuda_graph_capture will be called.
+    - For each forward batch, init_replay_cuda_graph will be called first and then replay the graph.
+    """

     def __init__(
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        topk=0,
+        speculative_num_steps=0,
+        step_id=0,
     ):
         super().__init__()

@@ -53,56 +281,129 @@ class FlashAttentionBackend(AttentionBackend):
             and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"

-        # Initialize metadata
         self.forward_metadata: FlashAttentionMetadata = None
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
+        self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.page_size = model_runner.page_size
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
         ) and (not global_server_args_dict["disable_mla"])
+        self.skip_prefill = skip_prefill
+
+        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+        assert (
+            topk <= 1
+        ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
+
+        self.topk = 1
+        self.step_id = step_id
+        self.speculative_num_steps = speculative_num_steps
+
+        # Local attention settings
+        self.attention_chunk_size = (
+            model_runner.attention_chunk_size
+            if hasattr(model_runner, "attention_chunk_size")
+            else None
+        )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
-        # Create metadata based on forward mode
         metadata = FlashAttentionMetadata()
-
-        # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
-        # Precompute int32 version of sequence lengths
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-
-
-
-
-
-
-
-
-
+        if forward_batch.forward_mode.is_decode():
+            # Skip Prefill or Draft Decode
+            # Note: Draft Decode will be ran on the Draft Worker
+            if forward_batch.spec_info is not None:
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
+                metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.step_id + 1
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                cache_loc = forward_batch.out_cache_loc.view(
+                    self.speculative_num_steps, -1
+                ).T

-
-
-
-
-
+                for idx, single_seq_len in enumerate(seq_lens_with_decode):
+                    real_bsz_start_idx = idx
+                    real_bsz_end_idx = idx + 1
+                    metadata.page_table[
+                        real_bsz_start_idx:real_bsz_end_idx,
+                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
+                    ] = cache_loc[
+                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
+                    ]
+            else:  # Normal Decode without Spec Decoding
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+        elif forward_batch.forward_mode.is_target_verify():
+            # Note: Target Verify will be ran on the Target Worker
+            draft_token_num = forward_batch.spec_info.draft_token_num
+            metadata.cache_seqlens_int32 = (
+                forward_batch.seq_lens + draft_token_num
+            ).to(torch.int32)
+            metadata.max_seq_len_q = draft_token_num
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item() + draft_token_num
             )
-            metadata.
-
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                batch_size * draft_token_num + 1,
+                draft_token_num,
+                dtype=torch.int32,
+                device=device,
             )
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]

-
-        #
-        metadata.
-
+        elif forward_batch.forward_mode.is_extend_or_draft_extend():
+            # Normal or Draft Extend (Both of them will be ran on the Target Worker)
+            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
             )
-
+            # Precompute maximum sequence length
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            # Precompute page table
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
+
             # Precompute cumulative sequence lengths
-        if
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
@@ -111,6 +412,61 @@ class FlashAttentionBackend(AttentionBackend):
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
                 metadata.max_seq_len_q = metadata.max_seq_len_k
+
+            # Setup local attention if enabled
+            if (
+                self.attention_chunk_size is not None
+                and forward_batch.forward_mode == ForwardMode.EXTEND
+            ):
+                # Convert tensors to numpy for local attention processing
+                cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
+                seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
+
+                # Adjust attention_chunk_size based on the actual sequence length
+                # to avoid index out of bounds errors
+                max_seq_len = seq_lens_np.max()
+                effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
+                # Make sure effective_chunk_size is divisible by page_size
+                effective_chunk_size = (
+                    effective_chunk_size // self.page_size
+                ) * self.page_size
+                if effective_chunk_size < self.page_size:
+                    effective_chunk_size = self.page_size
+
+                # Create local attention metadata
+                (
+                    seqlens_q_local_np,
+                    cu_seqlens_q_local_np,
+                    seqlens_k_local_np,
+                    block_table_local,
+                ) = make_local_attention_virtual_batches(
+                    effective_chunk_size,
+                    cu_seqlens_q_np,
+                    seq_lens_np,
+                    metadata.page_table,
+                    self.page_size,
+                )
+
+                local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+                    local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
+                        device
+                    ),
+                    local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
+                    local_block_table=block_table_local,
+                    local_max_query_len=seqlens_q_local_np.max(),
+                    local_max_seq_len=seqlens_k_local_np.max(),
+                )
+                metadata.local_attn_metadata = local_metadata
+
+        # Precompute strided indices
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
+
         self.forward_metadata = metadata

     def forward_extend(
@@ -122,7 +478,6 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -155,9 +510,30 @@ class FlashAttentionBackend(AttentionBackend):
             else (-1, -1)
         )

-
+        # Check if we should use local attention
+        use_local_attn = (
+            self.attention_chunk_size is not None
+            and metadata.local_attn_metadata is not None
+            and (hasattr(layer, "use_irope") and layer.use_irope)
+        )

-        #
+        # Get the appropriate page table based on whether we're using local attention
+        if use_local_attn:
+            local_metadata = metadata.local_attn_metadata
+            page_table = local_metadata.local_block_table
+            cu_seqlens_q = local_metadata.local_query_start_loc
+            cache_seqlens = local_metadata.local_seqused_k
+            max_seqlen_q = local_metadata.local_max_query_len
+            max_seqlen_k = local_metadata.local_max_seq_len
+        else:
+            page_table = metadata.page_table
+            cu_seqlens_q = metadata.cu_seqlens_q
+            cache_seqlens = metadata.cache_seqlens_int32
+            max_seqlen_q = metadata.max_seq_len_q
+            max_seqlen_k = metadata.max_seq_len_k
+            cu_seqlens_k = metadata.cu_seqlens_k
+
+        # Use Flash Attention for prefill
         if not self.use_mla:
             # Do multi-head attention
             kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
@@ -173,10 +549,10 @@ class FlashAttentionBackend(AttentionBackend):
                 k_cache=key_cache,
                 v_cache=value_cache,
                 page_table=page_table,
-                cache_seqlens=
-                cu_seqlens_q=
-                cu_seqlens_k_new=
-                max_seqlen_q=
+                cache_seqlens=cache_seqlens,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
                 causal=True,
                 window_size=window_size,
@@ -208,10 +584,10 @@ class FlashAttentionBackend(AttentionBackend):
                 v_cache=c_kv_cache,
                 qv=q_nope,
                 page_table=page_table,
-                cache_seqlens=
-                cu_seqlens_q=
-                cu_seqlens_k_new=
-                max_seqlen_q=
+                cache_seqlens=cache_seqlens,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
@@ -263,7 +639,6 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-
         page_table = metadata.page_table

         if not self.use_mla:
@@ -281,8 +656,6 @@ class FlashAttentionBackend(AttentionBackend):

             # Pre-reshape query tensor
             q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
-            # Run attention with precomputed values
             o = flash_attn_with_kvcache(
                 q=q_reshaped,
                 k_cache=key_cache,
@@ -334,7 +707,6 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=layer.k_scale,
                 v_descale=layer.v_scale,
             )
-
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def init_cuda_graph_state(self, max_bs: int):
@@ -346,7 +718,6 @@ class FlashAttentionBackend(AttentionBackend):
        This creates fixed-size tensors that will be reused during CUDA graph replay
        to avoid memory allocations.
        """
-        # Initialize fixed size tensors for decode operations
        self.decode_cuda_graph_metadata = {
            # Page table for token mapping (batch_size, max_context_len)
            "page_table": torch.zeros(
@@ -355,6 +726,39 @@ class FlashAttentionBackend(AttentionBackend):
                 dtype=torch.int32,
                 device=self.device,
             ),
+            "page_table_draft_decode": torch.zeros(
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
+            ),
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+        }
+
+        self.target_verify_metadata = {
+            "page_table": torch.zeros(
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.zeros(
+                max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+            "max_seqlen_q": 0,
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
@@ -372,27 +776,89 @@ class FlashAttentionBackend(AttentionBackend):
     ):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
-        # Get sequence information
-        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-        batch_size = len(seq_lens)
         device = seq_lens.device
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if forward_mode.is_decode():
+            if spec_info is not None:
+                # Draft Decode
+                metadata.cu_seqlens_q = torch.arange(
+                    0, bs + 1, dtype=torch.int32, device=device
+                )
+                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                    "cache_seqlens"
+                ][:bs]
+
+                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
+                    : bs + 1
+                ]
+
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
+                metadata.page_table = self.decode_cuda_graph_metadata[
+                    "page_table_draft_decode"
+                ][req_pool_indices, :]
+            else:
+                # Normal Decode
+                # Get sequence information
+                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+                batch_size = len(seq_lens)
+                device = seq_lens.device
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = seq_lens.max().item()
+                # Precompute page table
+                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+                    req_pool_indices, :
+                ]
+                # Precompute cumulative sequence lengths
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+            self.decode_cuda_graph_metadata[bs] = metadata
+        elif forward_mode.is_target_verify():
+            draft_token_num = spec_info.draft_token_num
+
+            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + draft_token_num).to(torch.int32)
             )
-
-
-
+
+            metadata.max_seq_len_q = draft_token_num
+            metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
+
+            metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
+                torch.arange(
+                    0,
+                    bs * draft_token_num + 1,
+                    draft_token_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+            ]
+            cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
+            cu_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+            )
+            metadata.cu_seqlens_k = cu_k
+            metadata.page_table = self.target_verify_metadata["page_table"][
+                req_pool_indices, :
+            ]
+
+            self.target_verify_metadata[bs] = metadata
+
         self.forward_metadata = metadata

     def init_forward_metadata_replay_cuda_graph(
@@ -405,30 +871,159 @@ class FlashAttentionBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
+        out_cache_loc: torch.Tensor = None,
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-
+        device = seq_lens.device
+        seq_lens = seq_lens[:bs]
+        req_pool_indices = req_pool_indices[:bs]
+        seq_lens_cpu = seq_lens_cpu[:bs]
+        if forward_mode.is_decode():
+            metadata = self.decode_cuda_graph_metadata[bs]

-
-
-
+            if spec_info is not None:
+                # Draft Decode
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + (self.step_id + 1)

-
-
-
-
-
-
+                metadata.cache_seqlens_int32.copy_(
+                    (seq_lens + (self.step_id + 1)).to(torch.int32)
+                )
+
+                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)
+
+                metadata.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                )
+
+                page_table = self.req_to_token[
+                    req_pool_indices, : metadata.max_seq_len_k
+                ]
+
+                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+            else:
+                # Normal Decode
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len
+
+                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+                page_indices = self.req_to_token[
+                    :,
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+                ]
+                page_indices = page_indices[req_pool_indices] // self.page_size
+                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                metadata.page_table[:, max_seq_pages:].fill_(0)
+
+        elif forward_mode.is_target_verify():
+            metadata = self.target_verify_metadata[bs]
+            draft_token_num = spec_info.draft_token_num
+
+            metadata.cu_seqlens_q.copy_(
+                torch.arange(
+                    0,
+                    bs * draft_token_num + 1,
+                    draft_token_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+            )
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + draft_token_num).to(torch.int32)
+            )
+
+            metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
+            metadata.cu_seqlens_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+            )
+            page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
+            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

-        max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
-        page_indices = self.req_to_token[
-            :, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
-        ]
-        page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
-        metadata.page_table[:, :max_seq_pages].copy_(page_indices)
-        metadata.page_table[:, max_seq_pages:].fill_(0)
         self.forward_metadata = metadata

     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
         return 0
+
+
+class FlashAttentionMultiStepBackend:
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        self.model_runner = model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                FlashAttentionBackend(
+                    model_runner,
+                    topk=self.topk,
+                    speculative_num_steps=self.speculative_num_steps,
+                    step_id=i,
+                )
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+                out_cache_loc=forward_batch.out_cache_loc,
+            )