sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

@@ -1,49 +1,260 @@
 from __future__ import annotations
 
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-
-"""
-Support different attention backends.
-Now there are three backends: FlashInfer, Triton and FlashAttention.
-Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
-"""
-
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
 
+import numpy as np
 import torch
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
-from sgl_kernel.flash_attn import flash_attn_with_kvcache
+from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 
 
 @dataclass
 class FlashAttentionMetadata:
     """Metadata to be init once in the model forward pass,
-    each layer's forward pass can reuse the metadata.
+    each layer's forward pass can reuse the metadata.
 
-
-
-
-
+    For each init metadata function, we will try set up them in below order
+    """
+
+    # Sequence lengths for the forward batch
+    cache_seqlens_int32: torch.Tensor = None
     # Maximum sequence length for query
     max_seq_len_q: int = 0
     # Maximum sequence length for key
     max_seq_len_k: int = 0
+    # Cumulative sequence lengths for query
+    cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
+    cu_seqlens_k: torch.Tensor = None
     # Window size (typically used by Gemma)
     window_size: tuple = (-1, -1)
     # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
+
+    # Encoder metadata
+    # Cumulative sequence lengths for encoder key
+    encoder_cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for encoder key
+    encoder_max_seq_len_k: int = 0
     # Sequence lengths for the forward batch
-
+    encoder_lens_int32: torch.Tensor = None
+    # Page table for the encoder
+    encoder_page_table: torch.Tensor = None
+
+    @dataclass
+    class LocalAttentionMetadata:
+        local_query_start_loc: torch.Tensor = None  # cu_seqlens_q for local attention
+        local_seqused_k: torch.Tensor = None  # sequence lengths for local attention
+        local_block_table: torch.Tensor = None  # block table for local attention
+        local_max_query_len: int = 0  # max query length for local attention
+        local_max_seq_len: int = 0  # max sequence length for local attention
+
+    local_attn_metadata: Optional[LocalAttentionMetadata] = None
+
+
+# Copied from:
+# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
+#
+# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+# local attention blocks, where each block is passed to the attention kernel
+# as an independent local ("virtual") batch item.
+#
+# For example, if are performing a chunked prefill a batch of 3 sequences:
+#   q_seqlens  = [4, 10, 5]
+#   kv_seqlens = [6, 17, 9]
+# Then normally for regular attention we would compute with an attention mask
+#  for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
+#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
+#        k_toks >   0 1 2 3 4 5
+#        q_toks v  _____________
+#               0 | 1 1 1
+#               1 | 1 1 1 1
+#               2 | 1 1 1 1 1
+#               3 | 1 1 1 1 1 1
+#
+# for local attention (with attn_chunk_size = 4) we would compute with an
+# attention mask like:
+#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
+#        k_toks >   0 1 2 3 4 5
+#        q_toks v  _____________
+#               0 | 1 1 1
+#               1 | 1 1 1 1
+#               2 |         1
+#               3 |         1 1
+#
+# We can simulate this mask using standard flash-attention by breaking the
+# sequences into local ("virtual") batches, where each local batch item is a
+# local attention block, so in this case batch idx 0 would be broken up into:
+#
+#  local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)
+#        k_toks >   0 1 2 3
+#        q_toks v  _____________
+#               0 | 1 1 1
+#               1 | 1 1 1 1
+#  local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
+#        k_toks >   4 5
+#        q_toks v  _____________
+#               2 | 1
+#               3 | 1 1
+#
+# e.g. if we have:
+#   attn_chunk_size = 4
+#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
+# Then this function would return:
+#                           __b0__  ______b1______  __b2__ < orig batch indices
+#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
+#   cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
+#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
+#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
+def make_local_attention_virtual_batches(
+    attn_chunk_size: int,
+    query_start_loc_np: np.ndarray,
+    seq_lens_np: np.ndarray,
+    block_table: torch.Tensor,
+    page_size: int = 0,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+    """
+    Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+    local attention blocks, where each block is passed to the attention kernel
+    as an independent local ("virtual") batch item.
+
+    Args:
+        attn_chunk_size: Size of local attention chunks
+        query_start_loc_np: Cumulative sum of query lengths (numpy array)
+        seq_lens_np: Sequence lengths (numpy array)
+        block_table: Block table for KV cache
+        page_size: Size of each page in the KV cache
+
+    Returns:
+        seqlens_q_local: Query sequence lengths for local attention
+        cu_seqlens_q_local: Cumulative sum of query sequence lengths for local attention
+        seqlens_k_local: Key sequence lengths for local attention
+        block_table_local: Block table for local attention
+    """
+    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
+    actual_batch_size = seq_lens_np.shape[0]
+
+    # Handle if we are starting in the middle of a local attention block,
+    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
+    #  the number of tokens that are not in the first local attention block and
+    #  then we can simply use a cdiv for the rest.
+    # For example if we have:
+    #   attn_chunk_size = 4
+    #   q_seqlens = [4, 10, 5]
+    #   k_seqlens = [6, 17, 9]
+    # Then we would get:
+    #   new_tokens_in_first_block = [2, 1, 4]
+    #   local_blocks = [2, 4, 2]
+    q_tokens_in_first_block = np.minimum(
+        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
+    ).astype(np.int32)
+    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
+    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
+
+    # Once we know the number of local blocks we can compute the request spans
+    #  for each batch idx, we can figure out the number of "virtual" requests we
+    #  have to make,
+    # For the above example we would get:
+    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
+    #
+    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
+    #   (TODO: max a utility to share this code with _prepare_inputs)
+    # arange step 1. [2, 4, 2] -> [2, 6, 8]
+    cu_num_blocks = np.cumsum(local_blocks)
+    virtual_batches = cu_num_blocks[-1]
+    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
+    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
+    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
+    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
+    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
+    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
+    # Then we can compute the seqlens_q_local, handling the fact that the
+    #  first and last blocks could be partial
+    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
+    # set the first block since this may be a partial block
+    seqlens_q_local[arange == 0] = q_tokens_in_first_block
+    # set the remaining blocks
+    seqlens_q_local[arange > 0] = np.minimum(
+        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
+    )[arange > 0]
+
+    # convert from q_seqlens to cu_seqlens_q
+    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
+
+    # compute the seqlens_k_local,
+    #  basically a full local attention block for all but the last block in each
+    #  batch
+    # For our example this will be:
+    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
+    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
+    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
+
+    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
+        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
+    )
+    # For the example the local attention blocks start at:
+    #                           _b0_  _____b1_____  _b2_
+    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
+    block_starts = k_seqstarts_absolute // page_size
+
+    assert attn_chunk_size % page_size == 0, (
+        f"attn_chunk_size {attn_chunk_size} is not "
+        f"divisible by page_size {page_size}"
+    )
+    pages_per_local_batch = attn_chunk_size // page_size
+
+    # Create a block_table for the local attention blocks
+    # For out example if we have a block-table like (assuming page_size=2):
+    #   block_table = [
+    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9], < batch 0
+    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
+    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
+    #   ]
+    # Then for the local batches we would want a block-table like
+    #   block_table_local = [
+    #     [ 0,  1 ],  < local-batch 0, (batch 0, starting from k[0])
+    #     [ 2,  3 ],  < local-batch 1, (batch 0, starting from k[4])
+    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
+    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
+    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
+    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
+    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
+    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
+    #   ]
+    block_indices = np.broadcast_to(
+        np.arange(pages_per_local_batch, dtype=np.int32),
+        (virtual_batches, pages_per_local_batch),
+    ) + np.expand_dims(block_starts, axis=1)
+    # Ensure block_indices doesn't exceed block_table dimensions
+    # This is a critical safety check that prevents index out of bounds errors
+    # when dealing with large sequences (>8192 tokens) or when the block_table
+    # dimensions are smaller than what would be needed for the full attention chunk size.
+    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
+    batch_indices = np.repeat(
+        np.arange(actual_batch_size, dtype=np.int32),
+        local_blocks * pages_per_local_batch,
+    )
+    block_table_local = block_table[batch_indices, block_indices].view(
+        virtual_batches, -1
+    )
+
+    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
+
+
+def cdiv(a: int, b: int) -> int:
+    """Ceiling division."""
+    return -(a // -b)
 
 
 class FlashAttentionBackend(AttentionBackend):

@@ -68,9 +279,9 @@ class FlashAttentionBackend(AttentionBackend):
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        speculative_step_id=0,
         topk=0,
         speculative_num_steps=0,
-        step_id=0,
     ):
         super().__init__()
 

@@ -85,87 +296,82 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {}
         self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
         ) and (not global_server_args_dict["disable_mla"])
         self.skip_prefill = skip_prefill
 
-
-        assert (
-            topk <= 1
-        ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
-
-        self.topk = 1
-        self.step_id = step_id
+        self.topk = topk
         self.speculative_num_steps = speculative_num_steps
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.speculative_step_id = speculative_step_id
+
+        # Local attention settings
+        self.attention_chunk_size = (
+            model_runner.attention_chunk_size
+            if hasattr(model_runner, "attention_chunk_size")
+            else None
+        )
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        """Initialize forward metadata
+        """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-
-
-        #
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            # Draft Decode
             if forward_batch.spec_info is not None:
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = torch.arange(
                     0, batch_size + 1, dtype=torch.int32, device=device
                 )
-                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
-                metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
-                    self.step_id + 1
-                )
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-
-
-                ).T
-
-                for idx, single_seq_len in enumerate(seq_lens_with_decode):
-                    real_bsz_start_idx = idx
-                    real_bsz_end_idx = idx + 1
-                    metadata.page_table[
-                        real_bsz_start_idx:real_bsz_end_idx,
-                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
-                    ] = cache_loc[
-                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
-                    ]
-            else:  # Normal Decode without Spec Decoding
+            else:
+                # Normal Decode
                 metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-                metadata.cu_seqlens_q = torch.arange(
-                    0, batch_size + 1, dtype=torch.int32, device=device
-                )
         elif forward_batch.forward_mode.is_target_verify():
-            # Note: Target Verify will be ran on the Target Worker
-            draft_token_num = forward_batch.spec_info.draft_token_num
             metadata.cache_seqlens_int32 = (
-                forward_batch.seq_lens +
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
            ).to(torch.int32)
-            metadata.max_seq_len_q =
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
             metadata.max_seq_len_k = (
-                forward_batch.seq_lens_cpu.max().item()
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
             )
             metadata.cu_seqlens_q = torch.arange(
                 0,
-                batch_size *
-
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
                 dtype=torch.int32,
                 device=device,
             )

@@ -177,33 +383,99 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
 
-        elif forward_batch.forward_mode.
-            # Normal or Draft Extend (Both of them will be ran on the Target Worker)
+        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             metadata.cu_seqlens_k = torch.nn.functional.pad(
                 torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
             )
-            # Precompute maximum sequence length
-            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
-            # Precompute page table
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
-
+
            if (
                any(forward_batch.extend_prefix_lens_cpu)
                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
            ):
                extend_seq_lens = forward_batch.extend_seq_lens
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                metadata.cu_seqlens_q = torch.nn.functional.pad(
                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )
-                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
            else:
-                metadata.cu_seqlens_q = metadata.cu_seqlens_k
                metadata.max_seq_len_q = metadata.max_seq_len_k
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k
+
+            # Setup local attention if enabled
+            if (
+                self.attention_chunk_size is not None
+                and forward_batch.forward_mode == ForwardMode.EXTEND
+            ):
+                # Convert tensors to numpy for local attention processing
+                cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
+                seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
+
+                # Adjust attention_chunk_size based on the actual sequence length
+                # to avoid index out of bounds errors
+                max_seq_len = seq_lens_np.max()
+                effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
+                # Make sure effective_chunk_size is divisible by page_size
+                effective_chunk_size = (
+                    effective_chunk_size // self.page_size
+                ) * self.page_size
+                if effective_chunk_size < self.page_size:
+                    effective_chunk_size = self.page_size
+
+                # Create local attention metadata
+                (
+                    seqlens_q_local_np,
+                    cu_seqlens_q_local_np,
+                    seqlens_k_local_np,
+                    block_table_local,
+                ) = make_local_attention_virtual_batches(
+                    effective_chunk_size,
+                    cu_seqlens_q_np,
+                    seq_lens_np,
+                    metadata.page_table,
+                    self.page_size,
+                )
+
+                local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+                    local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
+                        device
+                    ),
+                    local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
+                    local_block_table=block_table_local,
+                    local_max_query_len=seqlens_q_local_np.max(),
+                    local_max_seq_len=seqlens_k_local_np.max(),
+                )
+                metadata.local_attn_metadata = local_metadata
+
+        # Encoder metadata for cross attention
+        if forward_batch.encoder_lens is not None:
+            assert (
+                forward_batch.encoder_lens.numel() == 1
+            ), "Only encoder size 1 is supported for now"
 
-
+            metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
+            metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
+            metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
+            metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
+            ]
+
+            # Currently only support forward_batch.encoder_lens.numel() == 1
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices,
+                metadata.encoder_max_seq_len_k : (
+                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                ),
+            ]
+
+        # Convert the page table to a strided format which is needed by FA3 API
         if self.page_size > 1:
             self.strided_indices = torch.arange(
                 0, metadata.page_table.shape[1], self.page_size, device=self.device

@@ -211,6 +483,7 @@ class FlashAttentionBackend(AttentionBackend):
            metadata.page_table = (
                metadata.page_table[:, self.strided_indices] // self.page_size
            )
+
        self.forward_metadata = metadata
 
    def forward_extend(

@@ -242,7 +515,7 @@ class FlashAttentionBackend(AttentionBackend):
                v,
            )
 
-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
        metadata = self.forward_metadata
 
        # Calculate window size (can be moved to metadata if layer properties don't change)

@@ -250,75 +523,157 @@ class FlashAttentionBackend(AttentionBackend):
        # here is two side inclusive
        window_size = (
            (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None
+            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
            else (-1, -1)
        )
+        k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
+        causal = not layer.is_cross_attention
+
+        # Check if we should use local attention
+        use_local_attn = (
+            self.attention_chunk_size is not None
+            and metadata.local_attn_metadata is not None
+            and (hasattr(layer, "use_irope") and layer.use_irope)
+        )
 
-
+        # Get the appropriate page table based on whether we're using local attention
+        if use_local_attn:
+            local_metadata = metadata.local_attn_metadata
+            page_table = local_metadata.local_block_table
+            cu_seqlens_q = local_metadata.local_query_start_loc
+            cache_seqlens = local_metadata.local_seqused_k
+            max_seqlen_q = local_metadata.local_max_query_len
+            max_seqlen_k = local_metadata.local_max_seq_len
+        else:
+            page_table = metadata.page_table
+            cu_seqlens_q = metadata.cu_seqlens_q
+            cache_seqlens = metadata.cache_seqlens_int32
+            max_seqlen_q = metadata.max_seq_len_q
+            max_seqlen_k = metadata.max_seq_len_k
+            cu_seqlens_k = metadata.cu_seqlens_k
 
        # Use Flash Attention for prefill
        if not self.use_mla:
            # Do multi-head attention
-
-
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+
            o = flash_attn_with_kvcache(
                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k_cache=key_cache,
                v_cache=value_cache,
                page_table=page_table,
-                cache_seqlens=
-                cu_seqlens_q=
-                cu_seqlens_k_new=
-                max_seqlen_q=
+                cache_seqlens=cache_seqlens,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                max_seqlen_q=max_seqlen_q,
                softmax_scale=layer.scaling,
-                causal=
+                causal=causal,
                window_size=window_size,
                softcap=layer.logit_cap,
-                k_descale=
-                v_descale=
+                k_descale=k_descale,
+                v_descale=v_descale,
            )
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
        else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if (
+                not global_server_args_dict["disable_chunked_prefix_cache"]
+                and forward_batch.attn_attend_prefix_cache is not None
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                # Do multi-head attention with chunked prefix cache
+
+                if forward_batch.attn_attend_prefix_cache:
+                    # MHA for chunked prefix kv cache when running model with MLA
+                    assert forward_batch.prefix_chunk_idx is not None
+                    assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                    assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+                    chunk_idx = forward_batch.prefix_chunk_idx
+                    assert chunk_idx >= 0
+
+                    output, lse, *rest = flash_attn_varlen_func(
+                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                        cu_seqlens_q=metadata.cu_seqlens_q,
+                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                        max_seqlen_q=metadata.max_seq_len_q,
+                        max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                        softmax_scale=layer.scaling,
+                        causal=False,
+                        return_softmax_lse=True,
+                    )
+                else:
+                    # MHA for extend part of sequence without attending prefix kv cache
+                    output, lse, *rest = flash_attn_varlen_func(
+                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                        cu_seqlens_q=metadata.cu_seqlens_q,
+                        cu_seqlens_k=metadata.cu_seqlens_q,
+                        max_seqlen_q=metadata.max_seq_len_q,
+                        max_seqlen_k=metadata.max_seq_len_q,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                        return_softmax_lse=True,
+                    )
+                return output, lse
+            else:
+                # Do absorbed multi-latent attention
+                kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+                k_rope = kv_cache[:, :, layer.v_head_dim :]
+                c_kv = kv_cache[:, :, : layer.v_head_dim]
+                k_rope_cache = k_rope.view(
+                    -1,
+                    self.page_size,
+                    layer.tp_k_head_num,
+                    layer.head_dim - layer.v_head_dim,
+                )
+                c_kv_cache = c_kv.view(
+                    -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+                )
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
+                o = flash_attn_with_kvcache(
+                    q=q_rope,
+                    k_cache=k_rope_cache,
+                    v_cache=c_kv_cache,
+                    qv=q_nope,
+                    page_table=page_table,
+                    cache_seqlens=cache_seqlens,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                    max_seqlen_q=max_seqlen_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
 
-
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
    def forward_decode(
        self,

@@ -329,8 +684,6 @@ class FlashAttentionBackend(AttentionBackend):
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ) -> torch.Tensor:
-        """Forward pass with FlashAttention using precomputed metadata."""
-        # Save KV cache if needed
        if k is not None:
            assert v is not None
            if save_kv_cache:

@@ -351,7 +704,7 @@ class FlashAttentionBackend(AttentionBackend):
                v,
            )
 
-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
        metadata = self.forward_metadata
 
        # Calculate window size (can be moved to metadata if layer properties don't change)

@@ -359,17 +712,27 @@ class FlashAttentionBackend(AttentionBackend):
        # here is two side inclusive
        window_size = (
            (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None
+            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
            else (-1, -1)
        )
-
+        causal = not layer.is_cross_attention
+
+        k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto":
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
 
        if not self.use_mla:
            # Do multi-head attention
 
-
-
-
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )

@@ -377,23 +740,32 @@ class FlashAttentionBackend(AttentionBackend):
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )
 
-            # Pre-reshape query tensor
            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+            else:
+                page_table = metadata.page_table
+                cache_seqlens = metadata.cache_seqlens_int32
+                cu_seqlens_k = metadata.cu_seqlens_k
+
            o = flash_attn_with_kvcache(
                q=q_reshaped,
                k_cache=key_cache,
                v_cache=value_cache,
                page_table=page_table,
-                cache_seqlens=
+                cache_seqlens=cache_seqlens,
                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=
+                cu_seqlens_k_new=cu_seqlens_k,
                max_seqlen_q=1,
                softmax_scale=layer.scaling,
-                causal=
+                causal=causal,
                window_size=window_size,
                softcap=layer.logit_cap,
-                k_descale=
-                v_descale=
+                k_descale=k_descale,
+                v_descale=v_descale,
            )
        else:
            # Do absorbed multi-latent attention

@@ -419,7 +791,7 @@ class FlashAttentionBackend(AttentionBackend):
                k_cache=k_rope_cache,
                v_cache=c_kv_cache,
                qv=q_nope,
-                page_table=page_table,
+                page_table=metadata.page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,

@@ -427,8 +799,8 @@ class FlashAttentionBackend(AttentionBackend):
                softmax_scale=layer.scaling,
                causal=True,
                softcap=layer.logit_cap,
-                k_descale=
-                v_descale=
+                k_descale=k_descale,
+                v_descale=v_descale,
            )
        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 

@@ -442,7 +814,13 @@ class FlashAttentionBackend(AttentionBackend):
        to avoid memory allocations.
        """
        self.decode_cuda_graph_metadata = {
-
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
            "page_table": torch.zeros(
                max_bs,
                (self.max_context_len + self.page_size - 1) // self.page_size,

@@ -458,35 +836,42 @@ class FlashAttentionBackend(AttentionBackend):
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
+        }
+
+        self.target_verify_metadata = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.
-
+            "cu_seqlens_q": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
            ),
            "cu_seqlens_k": torch.zeros(
-                max_bs +
+                max_bs + 1, dtype=torch.int32, device=self.device
            ),
-        }
-
-        self.target_verify_metadata = {
            "page_table": torch.zeros(
                max_bs,
                (self.max_context_len + self.page_size - 1) // self.page_size,
                dtype=torch.int32,
                device=self.device,
            ),
-            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "cu_seqlens_k": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "max_seqlen_q": 0,
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
        }
 
+        self.encoder_metadata = {
+            "encoder_page_table": torch.zeros(
+                max_bs,
+                self.max_context_len,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "encoder_lens_int32": torch.zeros(
+                max_bs, dtype=torch.int32, device=self.device
+            ),
+            "encoder_cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+        }
+
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,

@@ -500,27 +885,24 @@ class FlashAttentionBackend(AttentionBackend):
        """Initialize forward metadata for capturing CUDA graph."""
        metadata = FlashAttentionMetadata()
        device = seq_lens.device
-        if forward_mode.
+        if forward_mode.is_decode_or_idle():
            if spec_info is not None:
                # Draft Decode
-                metadata.cu_seqlens_q = torch.arange(
-                    0, bs + 1, dtype=torch.int32, device=device
-                )
                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                    "cache_seqlens"
                ][:bs]
-
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                    : bs + 1
                ]
-
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
-                metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
                metadata.page_table = self.decode_cuda_graph_metadata[
                    "page_table_draft_decode"
                ][req_pool_indices, :]

@@ -545,43 +927,49 @@ class FlashAttentionBackend(AttentionBackend):
            )
            self.decode_cuda_graph_metadata[bs] = metadata
        elif forward_mode.is_target_verify():
-            draft_token_num = spec_info.draft_token_num
-
            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
                :bs
            ]
            metadata.cache_seqlens_int32.copy_(
-                (seq_lens +
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
            )
 
-            metadata.max_seq_len_q =
-            metadata.max_seq_len_k =
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )
 
-            metadata.cu_seqlens_q =
-
-
-
-
-
-                device=device,
-            )
-            ]
-            cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
-            cu_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
            )
-
+
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+
            metadata.page_table = self.target_verify_metadata["page_table"][
                req_pool_indices, :
            ]
 
            self.target_verify_metadata[bs] = metadata
 
+        if encoder_lens is not None:
+            encoder_bs = encoder_lens.numel()
+            metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
+                :encoder_bs
+            ]
+            metadata.encoder_cu_seqlens_k = self.encoder_metadata[
+                "encoder_cu_seqlens_k"
+            ][: (encoder_bs + 1)]
+
+            metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
+                req_pool_indices, :
+            ]
+
        self.forward_metadata = metadata
 
    def init_forward_metadata_replay_cuda_graph(

@@ -597,24 +985,21 @@ class FlashAttentionBackend(AttentionBackend):
        out_cache_loc: torch.Tensor = None,
    ):
        # """Initialize forward metadata for replaying CUDA graph."""
-        device = seq_lens.device
        seq_lens = seq_lens[:bs]
-        req_pool_indices = req_pool_indices[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
-
+        req_pool_indices = req_pool_indices[:bs]
+        if forward_mode.is_decode_or_idle():
            metadata = self.decode_cuda_graph_metadata[bs]
 
            if spec_info is not None:
                # Draft Decode
-                max_len = seq_lens_cpu.max().item()
-                metadata.max_seq_len_k = max_len + (self.step_id + 1)
-
                metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.
+                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
                )
 
-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-
+                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                metadata.cu_seqlens_k.copy_(
                    torch.nn.functional.pad(
                        torch.cumsum(

@@ -643,31 +1028,24 @@ class FlashAttentionBackend(AttentionBackend):
                    metadata.max_seq_len_k + self.page_size - 1
                ) // self.page_size
                page_indices = self.req_to_token[
-                    :,
-                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
+                    req_pool_indices[:, None],
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                        None, :
+                    ],
                ]
-                page_indices
+                page_indices //= self.page_size
                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                metadata.page_table[:, max_seq_pages:].fill_(0)
 
        elif forward_mode.is_target_verify():
            metadata = self.target_verify_metadata[bs]
-            draft_token_num = spec_info.draft_token_num
-
-            metadata.cu_seqlens_q.copy_(
-                torch.arange(
-                    0,
-                    bs * draft_token_num + 1,
-                    draft_token_num,
-                    dtype=torch.int32,
-                    device=device,
-                )
-            )
            metadata.cache_seqlens_int32.copy_(
-                (seq_lens +
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
            )
 
-            metadata.max_seq_len_k =
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
            metadata.cu_seqlens_k.copy_(
                torch.nn.functional.pad(
                    torch.cumsum(

@@ -679,6 +1057,30 @@ class FlashAttentionBackend(AttentionBackend):
            page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
 
+        if encoder_lens is not None:
+            # Only support encoder size 1 for now
+            metadata.encoder_max_seq_len_k = encoder_lens[0]
+            metadata.encoder_lens_int32.copy_(encoder_lens[:1])
+            metadata.encoder_cu_seqlens_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                    (1, 0),
+                )
+            )
+
+            metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
+                self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
+            )
+
+            # Update the regular page table
+            page_table = self.req_to_token[
+                req_pool_indices,
+                metadata.encoder_max_seq_len_k : (
+                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                ),
+            ]
+            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
        self.forward_metadata = metadata
 
    def get_cuda_graph_seq_len_fill_value(self):

@@ -695,14 +1097,19 @@ class FlashAttentionMultiStepBackend:
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
 
+        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+        assert (
+            self.topk == 1
+        ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
+
        self.attn_backends = []
        for i in range(self.speculative_num_steps):
            self.attn_backends.append(
                FlashAttentionBackend(
                    model_runner,
+                    speculative_step_id=i,
                    topk=self.topk,
                    speculative_num_steps=self.speculative_num_steps,
-                    step_id=i,
                )
            )
 

@@ -727,7 +1134,7 @@ class FlashAttentionMultiStepBackend:
            forward_batch.batch_size * self.topk,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
-            encoder_lens=
+            encoder_lens=forward_batch.encoder_lens,
            forward_mode=ForwardMode.DECODE,
            spec_info=forward_batch.spec_info,
        )

@@ -744,7 +1151,7 @@ class FlashAttentionMultiStepBackend:
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.seq_lens_sum,
-            encoder_lens=
+            encoder_lens=forward_batch.encoder_lens,
            forward_mode=ForwardMode.DECODE,
            spec_info=forward_batch.spec_info,
            seq_lens_cpu=forward_batch.seq_lens_cpu,