sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +72 -10
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +6 -16
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +582 -125
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +79 -6
- sglang/srt/layers/quantization/__init__.py +137 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +44 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -127
- sglang/srt/managers/scheduler.py +29 -23
- sglang/srt/managers/tokenizer_manager.py +1 -2
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +16 -13
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +64 -59
- sglang/srt/model_loader/loader.py +19 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +568 -0
- sglang/srt/models/deepseek_janus_pro.py +12 -17
- sglang/srt/models/deepseek_v2.py +339 -123
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +20 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +106 -93
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +120 -25
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +94 -25
- sglang/srt/utils.py +137 -51
- sglang/test/runners.py +27 -2
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +14 -27
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

@@ -13,36 +13,64 @@ from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 
+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
-from
+from sgl_kernel.flash_attn import flash_attn_with_kvcache
 
 
 @dataclass
 class FlashAttentionMetadata:
-    """Metadata
+    """Metadata to be init once in the model forward pass,
+    each layer's forward pass can reuse the metadata."""
 
+    # Cumulative sequence lengths for query
     cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
     cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for query
+    max_seq_len_q: int = 0
+    # Maximum sequence length for key
     max_seq_len_k: int = 0
+    # Window size (typically used by Gemma)
     window_size: tuple = (-1, -1)
+    # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
+    # Sequence lengths for the forward batch
     cache_seqlens_int32: torch.Tensor = None
-    max_seq_len_q: int = 0
 
 
 class FlashAttentionBackend(AttentionBackend):
-    """FlashAttention backend implementation.
+    """FlashAttention backend implementation.
+
+    Note about the init:
+    - If no spec decoding
+        - FlashAttentionBackend will be init once when the server starts.
+    - If spec decoding
+        - FlashAttentionBackend will be init once for the target worker
+        - FlashAttentionMultiStepBackend will be once for the draft worker
+            - It will spawn num_steps FlashAttentionBackend for the draft worker
+
+    Note about CUDA Graph:
+    - We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.
+    - We don't support CUDA Graph for Extend and Draft Extend.
+    - When server init, init_cuda_graph_state will be called first and then init_cuda_graph_capture will be called.
+    - For each forward batch, init_replay_cuda_graph will be called first and then replay the graph.
+    """
 
     def __init__(
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        topk=0,
+        speculative_num_steps=0,
+        step_id=0,
     ):
         super().__init__()
 
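
The cu_seqlens_q / cu_seqlens_k fields above follow the packed convention of FlashAttention-style kernels: entry i is the total length of the first i sequences, so entry 0 is always 0. A minimal sketch of the pad(cumsum(...)) idiom this file uses to build them (illustrative only, not part of the package):

    import torch

    seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
    # pad(..., (1, 0)) prepends a single zero to the cumulative sums
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
    )
    print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
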
@@ -51,49 +79,138 @@ class FlashAttentionBackend(AttentionBackend):
             and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"
 
-        # Initialize metadata
         self.forward_metadata: FlashAttentionMetadata = None
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
+        self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.page_size = model_runner.page_size
+        self.use_mla = (
+            model_runner.model_config.attention_arch == AttentionArch.MLA
+        ) and (not global_server_args_dict["disable_mla"])
+        self.skip_prefill = skip_prefill
+
+        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+        assert (
+            topk <= 1
+        ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
+
+        self.topk = 1
+        self.step_id = step_id
+        self.speculative_num_steps = speculative_num_steps
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
-        # Create metadata based on forward mode
         metadata = FlashAttentionMetadata()
-
-        extend_seq_lens = forward_batch.extend_seq_lens
-        # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
-        # Precompute int32 version of sequence lengths
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-
-
-
-
-
-
-
-
-
-
-
+        if forward_batch.forward_mode.is_decode():
+            # Skip Prefill or Draft Decode
+            # Note: Draft Decode will be ran on the Draft Worker
+            if forward_batch.spec_info is not None:
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
+                metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.step_id + 1
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                cache_loc = forward_batch.out_cache_loc.view(
+                    self.speculative_num_steps, -1
+                ).T
+
+                for idx, single_seq_len in enumerate(seq_lens_with_decode):
+                    real_bsz_start_idx = idx
+                    real_bsz_end_idx = idx + 1
+                    metadata.page_table[
+                        real_bsz_start_idx:real_bsz_end_idx,
+                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
+                    ] = cache_loc[
+                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
+                    ]
+            else:  # Normal Decode without Spec Decoding
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+        elif forward_batch.forward_mode.is_target_verify():
+            # Note: Target Verify will be ran on the Target Worker
+            draft_token_num = forward_batch.spec_info.draft_token_num
+            metadata.cache_seqlens_int32 = (
+                forward_batch.seq_lens + draft_token_num
+            ).to(torch.int32)
+            metadata.max_seq_len_q = draft_token_num
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item() + draft_token_num
+            )
             metadata.cu_seqlens_q = torch.arange(
-            0,
+                0,
+                batch_size * draft_token_num + 1,
+                draft_token_num,
+                dtype=torch.int32,
+                device=device,
             )
-
-
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
+
+        elif forward_batch.forward_mode.is_extend_or_draft_extend():
+            # Normal or Draft Extend (Both of them will be ran on the Target Worker)
+            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            )
+            # Precompute maximum sequence length
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            # Precompute page table
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
             # Precompute cumulative sequence lengths
-        if
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
+                extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
-
+                metadata.max_seq_len_q = metadata.max_seq_len_k
+
+        # Precompute strided indices
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
         self.forward_metadata = metadata
 
     def forward_extend(
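
When page_size > 1, the tail of init_forward_metadata converts the token-indexed page table into page indices: it keeps every page_size-th column and integer-divides by page_size. A self-contained sketch of that transform (illustrative only; req_to_token here is a stand-in for the pool's mapping):

    import torch

    page_size = 4
    req_to_token = torch.arange(16).unsqueeze(0)  # one request, 16 token slots
    strided_indices = torch.arange(0, req_to_token.shape[1], page_size)
    page_table = req_to_token[:, strided_indices] // page_size
    print(page_table)  # tensor([[0, 1, 2, 3]])
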
@@ -105,23 +222,29 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        cache_loc = (
-            forward_batch.out_cache_loc
-            if not layer.is_cross_attention
-            else forward_batch.encoder_out_cache_loc
-        )
-
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
                 )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
 
         # Use precomputed metadata
         metadata = self.forward_metadata
 
-        # # Use Flash Attention for prefill
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -130,26 +253,72 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-        key_cache, value_cache = kv_cache[0], kv_cache[1]
-        o = flash_attn_with_kvcache(
-            q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            k_cache=key_cache.unsqueeze(1),
-            v_cache=value_cache.unsqueeze(1),
-            page_table=metadata.page_table,
-            cache_seqlens=metadata.cache_seqlens_int32,
-            cu_seqlens_q=metadata.cu_seqlens_q,
-            cu_seqlens_k_new=metadata.cu_seqlens_k,
-            max_seqlen_q=metadata.max_seq_len_q,
-            softmax_scale=layer.scaling,
-            causal=True,
-            window_size=window_size,
-            softcap=layer.logit_cap,
-            k_descale=layer.k_scale,
-            v_descale=layer.v_scale,
-        )
 
-
+        page_table = metadata.page_table
+
+        # Use Flash Attention for prefill
+        if not self.use_mla:
+            # Do multi-head attention
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
+            o = flash_attn_with_kvcache(
+                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def forward_decode(
         self,
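
In the MLA ("absorbed multi-latent attention") branch above, a single cache buffer stores the compressed KV latent (width v_head_dim) next to the rotary key slice (width head_dim - v_head_dim), and the query is split along the same boundary so q_rope attends against k_rope while q_nope is passed as qv. A sketch of the slicing, with assumed DeepSeek-style sizes (576/512 are illustrative, not taken from the diff):

    import torch

    head_dim, v_head_dim, num_heads = 576, 512, 16
    q_all = torch.randn(2, num_heads, head_dim)
    q_nope = q_all[:, :, :v_head_dim]   # latent part, matches c_kv width
    q_rope = q_all[:, :, v_head_dim:]   # rotary part, matches k_rope width
    assert q_rope.shape[-1] == head_dim - v_head_dim == 64
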
@@ -162,26 +331,29 @@ class FlashAttentionBackend(AttentionBackend):
     ) -> torch.Tensor:
         """Forward pass with FlashAttention using precomputed metadata."""
         # Save KV cache if needed
-        if k is not None
-
-
-
-
-
-
-
-
-
-
-
-
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
 
         # Use precomputed metadata
         metadata = self.forward_metadata
 
-        # Pre-reshape query tensor
-        q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -190,25 +362,75 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            k_descale=layer.k_scale,
-            v_descale=layer.v_scale,
-        )
+        page_table = metadata.page_table
+
+        if not self.use_mla:
+            # Do multi-head attention
+
+            # Get KV cache
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
 
-
+            # Pre-reshape query tensor
+            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            o = flash_attn_with_kvcache(
+                q=q_reshaped,
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def init_cuda_graph_state(self, max_bs: int):
         """Initialize CUDA graph state for the attention backend.
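
The decode path always passes max_seqlen_q=1 because each running request contributes exactly one new query token per step; cu_seqlens_q is then just an arange over the batch. A minimal illustration (not from the package):

    import torch

    batch_size = 4
    cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32)
    print(cu_seqlens_q)  # tensor([0, 1, 2, 3, 4], dtype=torch.int32)
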
@@ -219,11 +441,49 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
-        # Initialize fixed size tensors for decode operations
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
-                max_bs,
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "page_table_draft_decode": torch.zeros(
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
+            ),
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+        }
+
+        self.target_verify_metadata = {
+            "page_table": torch.zeros(
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.zeros(
+                max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+            "max_seqlen_q": 0,
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
             ),
         }
 
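
The page-table buffers above are sized in pages rather than tokens: (self.max_context_len + self.page_size - 1) // self.page_size is the usual integer-arithmetic form of ceil(max_context_len / page_size). For example:

    max_context_len, page_size = 8192, 16
    assert (max_context_len + page_size - 1) // page_size == 512
    assert (8193 + page_size - 1) // page_size == 513  # rounds up on a partial page
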
@@ -239,27 +499,89 @@ class FlashAttentionBackend(AttentionBackend):
     ):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
-        # Get sequence information
-        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-        batch_size = len(seq_lens)
         device = seq_lens.device
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if forward_mode.is_decode():
+            if spec_info is not None:
+                # Draft Decode
+                metadata.cu_seqlens_q = torch.arange(
+                    0, bs + 1, dtype=torch.int32, device=device
+                )
+                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                    "cache_seqlens"
+                ][:bs]
+
+                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
+                    : bs + 1
+                ]
+
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
+                metadata.page_table = self.decode_cuda_graph_metadata[
+                    "page_table_draft_decode"
+                ][req_pool_indices, :]
+            else:
+                # Normal Decode
+                # Get sequence information
+                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+                batch_size = len(seq_lens)
+                device = seq_lens.device
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = seq_lens.max().item()
+                # Precompute page table
+                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+                    req_pool_indices, :
+                ]
+                # Precompute cumulative sequence lengths
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+            self.decode_cuda_graph_metadata[bs] = metadata
+        elif forward_mode.is_target_verify():
+            draft_token_num = spec_info.draft_token_num
+
+            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + draft_token_num).to(torch.int32)
             )
-
-
-
+
+            metadata.max_seq_len_q = draft_token_num
+            metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
+
+            metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
+                torch.arange(
+                    0,
+                    bs * draft_token_num + 1,
+                    draft_token_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+            ]
+            cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
+            cu_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+            )
+            metadata.cu_seqlens_k = cu_k
+            metadata.page_table = self.target_verify_metadata["page_table"][
+                req_pool_indices, :
+            ]
+
+            self.target_verify_metadata[bs] = metadata
+
         self.forward_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
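
Capture-time metadata deliberately slices the preallocated buffers instead of allocating new tensors: a slice is a view, so the replay path can refresh the captured tensors in place with copy_() and the recorded graph sees the new values. A minimal sketch of the aliasing (illustrative only):

    import torch

    cache_seqlens = torch.zeros(8, dtype=torch.int32)   # allocated once at init
    captured = cache_seqlens[:4]                        # view stored in the metadata
    cache_seqlens[:4].copy_(torch.tensor([7, 2, 5, 1], dtype=torch.int32))
    print(captured)  # tensor([7, 2, 5, 1], dtype=torch.int32)
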
@@ -272,24 +594,159 @@ class FlashAttentionBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
+        out_cache_loc: torch.Tensor = None,
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        device = seq_lens.device
+        seq_lens = seq_lens[:bs]
+        req_pool_indices = req_pool_indices[:bs]
+        seq_lens_cpu = seq_lens_cpu[:bs]
+        if forward_mode.is_decode():
+            metadata = self.decode_cuda_graph_metadata[bs]
+
+            if spec_info is not None:
+                # Draft Decode
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + (self.step_id + 1)
+
+                metadata.cache_seqlens_int32.copy_(
+                    (seq_lens + (self.step_id + 1)).to(torch.int32)
+                )
+
+                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)
+
+                metadata.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                )
+
+                page_table = self.req_to_token[
+                    req_pool_indices, : metadata.max_seq_len_k
+                ]
+
+                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+            else:
+                # Normal Decode
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len
+
+                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+                page_indices = self.req_to_token[
+                    :,
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+                ]
+                page_indices = page_indices[req_pool_indices] // self.page_size
+                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                metadata.page_table[:, max_seq_pages:].fill_(0)
+
+        elif forward_mode.is_target_verify():
+            metadata = self.target_verify_metadata[bs]
+            draft_token_num = spec_info.draft_token_num
+
+            metadata.cu_seqlens_q.copy_(
+                torch.arange(
+                    0,
+                    bs * draft_token_num + 1,
+                    draft_token_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+            )
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + draft_token_num).to(torch.int32)
+            )
+
+            metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
+            metadata.cu_seqlens_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+            )
+            page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
+            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+        self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
         return 0
+
+
+class FlashAttentionMultiStepBackend:
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        self.model_runner = model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                FlashAttentionBackend(
+                    model_runner,
+                    topk=self.topk,
+                    speculative_num_steps=self.speculative_num_steps,
+                    step_id=i,
+                )
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+                out_cache_loc=forward_batch.out_cache_loc,
+            )
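
FlashAttentionMultiStepBackend wraps one FlashAttentionBackend per draft step, and each step's metadata extends every sequence by (step_id + 1) positions, matching the seq_lens_with_decode arithmetic in the draft-decode paths above. A sketch of the offsets for one request (illustrative only):

    seq_len = 10                      # tokens already in the KV cache
    speculative_num_steps = 3
    for step_id in range(speculative_num_steps):
        print(step_id, seq_len + (step_id + 1))  # 0 -> 11, 1 -> 12, 2 -> 13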