sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
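The file list above includes a one-line bump to sglang/version.py and updated wheel METADATA, so the package version itself moves from 0.4.6.post4 to 0.4.6.post5. A quick way to confirm which version is installed after upgrading (a minimal sketch, not part of the diff; it assumes pip and that sglang re-exports __version__ from sglang/version.py as usual):

# pip install --upgrade "sglang==0.4.6.post5"
import sglang

print(sglang.__version__)                      # expected: 0.4.6.post5
assert sglang.__version__ == "0.4.6.post5"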
sglang/srt/layers/attention/aiter_backend.py
@@ -0,0 +1,513 @@
+from __future__ import annotations
+
+"""
+end to end attention solution with aiter kernels
+"""
+
+import math
+import os
+from dataclasses import dataclass
+from enum import Enum, auto
+from functools import partial
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.global_config import global_config
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+try:
+    from aiter import mha_batch_prefill_func, paged_attention_ragged
+except ImportError:
+    print(
+        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
+    )
+
+
+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
+@dataclass
+class ForwardMetadata:
+    kv_indptr: torch.Tensor
+    kv_indices: torch.Tensor
+    max_q_len: int
+    max_kv_len: int
+
+
+global_workspace_buffer = None
+
+_AITER_PARTITION_SIZE_ROCM = 256
+
+
+class AiterAttnBackend(AttentionBackend):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__()
+
+        self.device = model_runner.device
+        self.is_multimodal = model_runner.model_config.is_multimodal
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+        # Parse constants
+        self.max_context_len = model_runner.model_config.context_len
+        self.skip_prefill = skip_prefill
+
+        max_bs = model_runner.req_to_token_pool.size
+
+        if kv_indptr_buf is None:
+            self.kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            self.kv_indptr = kv_indptr_buf
+
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
+        self.qo_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
+
+        # Create prefill indices updater
+        if not skip_prefill:
+            self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
+                model_runner, self
+            )
+
+        # aiter kernel related initialization
+        self.max_num_partitions = (
+            self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
+        ) // _AITER_PARTITION_SIZE_ROCM
+
+        nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
+
+        self.workspace_buffer = torch.empty(
+            (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
+            * nbyes_per_qo_elem
+            + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
+            dtype=torch.uint8,
+            device=self.device,
+        )
+
+        self.scale = float(1.0 / (self.head_dim**0.5))
+        self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
+            self.device
+        )
+        self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
+            self.device
+        )
+
+        self.logits_soft_cap = 0.0
+
+        self.forward_metadata: ForwardMetadata = None
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode_or_idle():
+            # update for aiter
+            # create kv_indices and kv_inptr
+            bs = forward_batch.batch_size
+            kv_indptr = self.kv_indptr
+            spec_info = forward_batch.spec_info
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.zeros(
+                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+                bs = kv_indptr.shape[0] - 1
+
+            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
+
+        elif forward_batch.forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = ForwardMetadata(
+                self.indices_updater_prefill.kv_indptr,
+                self.indices_updater_prefill.kv_indices,
+                self.indices_updater_prefill.max_q_len,
+                self.indices_updater_prefill.max_kv_len,
+            )
+        elif forward_batch.forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = ForwardMetadata(
+                self.indices_updater_prefill.kv_indptr,
+                self.indices_updater_prefill.kv_indices,
+                self.indices_updater_prefill.max_q_len,
+                self.indices_updater_prefill.max_kv_len,
+            )
+        else:
+            prefix_lens = forward_batch.extend_prefix_lens
+
+            if self.is_multimodal:
+                extend_no_prefix = False
+            else:
+                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
+
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=None,
+            )
+            self.forward_metadata = ForwardMetadata(
+                self.indices_updater_prefill.kv_indptr,
+                self.indices_updater_prefill.kv_indices,
+                self.indices_updater_prefill.max_q_len,
+                self.indices_updater_prefill.max_kv_len,
+            )
+
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
+        if kv_indices_buf is None:
+            self.cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_kv_indices = kv_indices_buf
+
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device=self.device,
+            )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = self.cuda_graph_kv_indices
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
+
+        elif forward_mode.is_target_verify():
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.forward_metadata = ForwardMetadata(
+                self.indices_updater_prefill.kv_indptr,
+                self.indices_updater_prefill.kv_indices,
+                self.indices_updater_prefill.max_q_len,
+                self.indices_updater_prefill.max_kv_len,
+            )
+
+        else:
+            raise ValueError(f"Invalid mode: {forward_mode=}")
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        if forward_mode.is_decode_or_idle():
+            kv_indptr = self.kv_indptr
+            kv_indices = self.cuda_graph_kv_indices
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices[:bs],
+                    seq_lens[:bs],
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
+                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+
+        elif forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        else:
+            raise ValueError("Invalid forward mode")
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )
+
+        self.logits_soft_cap = layer.logit_cap
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+
+        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+
+        bs0 = forward_batch.batch_size + 1
+
+        o = mha_batch_prefill_func(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            k_cache,
+            v_cache,
+            self.qo_indptr[:bs0],
+            self.forward_metadata.kv_indptr[:bs0],
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.max_q_len,
+            self.forward_metadata.max_kv_len,
+            causal=True,
+            logits_soft_cap=self.logits_soft_cap,
+            alibi_slopes=None,
+            return_lse=False,
+            return_attn_probs=False,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        self.logits_soft_cap = layer.logit_cap
+        paged_attention_ragged(
+            o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            self.workspace_buffer,
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
+                -1, 1, layer.tp_k_head_num, layer.qk_head_dim
+            ),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
+                -1, 1, layer.tp_v_head_num, layer.v_head_dim
+            ),
+            self.scale,
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.kv_last_page_lens,
+            1,
+            self.max_num_partitions,
+            None,
+            "auto",
+            "NHD",
+            self.logits_soft_cap,
+            self.k_scale,
+            self.v_scale,
+            None,
+            _AITER_PARTITION_SIZE_ROCM,
+        )
+
+        return o
+
+
+class AiterIndicesUpdaterPrefill:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Parse Constants
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.sliding_window_size = model_runner.sliding_window_size
+        self.attn_backend = attn_backend
+
+        # Buffers and wrappers
+        self.kv_indptr = attn_backend.kv_indptr
+        self.kv_last_page_len = attn_backend.kv_last_page_len
+        self.qo_indptr = attn_backend.qo_indptr
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.update = self.update_single_wrapper
+
+        self.kv_indices = None
+        self.max_q_len = 0
+        self.max_kv_len = 0
+
+    def update(
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
+    ):
+        # Keep the signature for type checking. It will be assigned during runtime.
+        raise NotImplementedError()
+
+    def update_single_wrapper(
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
+    ):
+
+        kv_start_idx = None
+        kv_indptr = self.kv_indptr
+        qo_indptr = self.qo_indptr
+        paged_kernel_lens = seq_lens
+        paged_kernel_lens_sum = seq_lens_sum
+
+        bs = len(req_pool_indices)
+        if spec_info is None:
+            # Normal extend
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum + 256,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
+
+            self.max_kv_len = torch.max(paged_kernel_lens).item()
+
+            extend_lens = seq_lens - prefix_lens
+            self.max_q_len = torch.max(extend_lens).item()
+
+            qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+        else:
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    self.req_to_token,
+                )
+            )
+
+        self.kv_indices = kv_indices
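The metadata-building code above repeatedly pairs torch.cumsum with create_flashinfer_kv_indices_triton to produce kv_indptr and kv_indices. The layout is CSR-style: request i owns the KV-cache slots kv_indices[kv_indptr[i] : kv_indptr[i + 1]]. A standalone illustration of that layout (sketch only, not code from the diff; the slot ids here are dummies):

import torch

# Stand-in for forward_batch.seq_lens: number of cached tokens per request.
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)          # -> [0, 3, 8, 10]

# kv_indices holds one token-pool slot id per cached token, request by request.
kv_indices = torch.arange(int(kv_indptr[-1]), dtype=torch.int32)

# The 5 slots belonging to request 1:
print(kv_indices[int(kv_indptr[1]) : int(kv_indptr[2])])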
sglang/srt/layers/attention/flashattention_backend.py
@@ -918,8 +918,11 @@ class FlashAttentionBackend(AttentionBackend):
             and local_attn_metadata is not None
             and (hasattr(layer, "use_irope") and layer.use_irope)
         )
-
-
+
+        # When Spec Decode enabled, forward_decode would be called with two mode:
+        # 1. DRAFT_DECODE: we enable cascade attention when top_k > 1
+        # 2. IDLE: we don’t need cascade attention, spec_info will be none in this case
+        use_cascade_attn = forward_batch.spec_info is not None and self.topk > 1
 
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
@@ -1165,7 +1168,6 @@ class FlashAttentionBackend(AttentionBackend):
             max_virtual_batches = max_bs * (
                 (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
             )
-            max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
             max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
 
             self.decode_cuda_graph_local_attn_metadata = {
@@ -1177,7 +1179,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
                 "local_block_table": torch.zeros(
                     max_virtual_batches,
-
+                    max_pages_per_block,
                     dtype=torch.int32,
                     device=self.device,
                 ),
@@ -1435,19 +1437,7 @@ class FlashAttentionBackend(AttentionBackend):
             self.decode_cuda_graph_metadata[bs] = metadata
 
             if self.attention_chunk_size is not None:
-                metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
-                        "local_query_start_loc"
-                    ],
-                    local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
-                        "local_seqused_k"
-                    ],
-                    local_block_table=self.decode_cuda_graph_local_attn_metadata[
-                        "local_block_table"
-                    ],
-                    local_max_query_len=1,
-                    local_max_seq_len=1,
-                )
+                self._update_local_attn_metadata_for_capture(metadata, batch_size)
 
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
@@ -1808,6 +1798,62 @@ class FlashAttentionBackend(AttentionBackend):
         )
         metadata.local_attn_metadata = local_metadata
 
+    def _update_local_attn_metadata_for_capture(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update local attention metadata during CUDA graph capture phase.
+
+        This method calculates the exact buffer sizes needed for local attention metadata
+        during the CUDA graph capture phase, optimizing memory usage by creating views of
+        pre-allocated buffers with exactly the sizes needed.
+        """
+        seq_lens_capture = metadata.cache_seqlens_int32
+        max_seq_len = int(seq_lens_capture.max().item())
+        page_table_capture = metadata.page_table
+
+        cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
+        seqlens_np = seq_lens_capture.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local_np,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            page_table_capture,
+            self.page_size,
+        )
+
+        # Get exact dimensions from the calculation
+        q_len = len(cu_seqlens_q_local_np)
+        k_len = len(seqlens_k_local_np)
+        b0 = block_table_local_np.shape[0] if block_table_local_np.shape[0] > 0 else bs
+        b1 = block_table_local_np.shape[1] if block_table_local_np.shape[1] > 0 else 1
+
+        # Create views of the pre-allocated buffers with exactly these sizes
+        # This is the key optimization - we only use the memory we actually need
+        local_query_start_loc = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ][:q_len]
+
+        local_seqused_k = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"][
+            :k_len
+        ]
+
+        local_block_table = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ][:b0, :b1]
+
+        metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+            local_query_start_loc=local_query_start_loc,
+            local_seqused_k=local_seqused_k,
+            local_block_table=local_block_table,
+            local_max_query_len=1,
+            local_max_seq_len=max_seq_len,
+        )
+
     def _update_local_attn_metadata_for_replay(
         self, metadata: FlashAttentionMetadata, bs: int
     ):
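The _update_local_attn_metadata_for_capture method added above replaces the earlier full-size metadata with exact-size views of the pre-allocated buffers. The underlying mechanism is ordinary tensor slicing, which returns a view that shares storage with the buffer captured in the CUDA graph; a standalone sketch (illustration only, not code from the diff):

import torch

# Allocated once at maximum size, as init_cuda_graph_state does for the local-attn buffers.
buf = torch.zeros(1024, dtype=torch.int32)

# At capture time, take a view sized to what this batch actually needs.
needed = 37
view = buf[:needed]

# The view aliases the same storage, so no new memory is allocated and
# writes through the view land in the buffer the CUDA graph replays from.
assert view.data_ptr() == buf.data_ptr()
view[:] = 1
assert int(buf[:needed].sum()) == needed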
sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -346,7 +346,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
         # Save kv cache
         if save_kv_cache and k is not None:
@@ -381,6 +380,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             )
         else:
             # mla paged prefill
+            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+                q.dtype
+            )
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
@@ -442,7 +444,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_nope = reshaped_q[:, :, : layer.v_head_dim]
         q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
-        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+            q.dtype
+        )
 
         o = q_nope.new_empty(q_nope.shape)
         # Direct call to run without the wrapper
@@ -467,7 +471,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
         self.scaling = model_runner.model_config.scaling
-        self.data_type = model_runner.
+        self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
 
         # Buffers and wrappers
@@ -577,7 +581,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
         self.v_head_dim = model_runner.model_config.v_head_dim
         self.scaling = model_runner.model_config.scaling
-        self.data_type = model_runner.
+        self.data_type = model_runner.dtype
         self.q_data_type = model_runner.dtype
         self.attn_backend = attn_backend
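The four FlashInferMLA hunks above add .to(q.dtype) on the key buffer and switch data_type to model_runner.dtype, presumably so the MLA prefill/decode kernels see a key buffer whose dtype matches the query even when the KV cache is stored in a different precision. A minimal sketch of what the cast does (illustration only; the float16/bfloat16 pairing is an assumed stand-in for whatever dtypes are actually configured):

import torch

q = torch.randn(4, 8).to(torch.bfloat16)        # query / activation dtype
k_buf = torch.randn(16, 8).to(torch.float16)    # stand-in KV-cache buffer in another dtype

# Same pattern as get_key_buffer(layer.layer_id).to(q.dtype) in the hunks above:
k_buf = k_buf.to(q.dtype)
assert k_buf.dtype == q.dtype                   # the kernel now sees matching dtypes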