sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py
@@ -1,51 +1,56 @@
 from __future__ import annotations
 
-import numpy as np
-
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-
-"""
-Support different attention backends.
-Now there are three backends: FlashInfer, Triton and FlashAttention.
-Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
-"""
-
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
 
+import numpy as np
 import torch
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
-from sgl_kernel.flash_attn import flash_attn_with_kvcache
+from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 
 
 @dataclass
 class FlashAttentionMetadata:
     """Metadata to be init once in the model forward pass,
-    each layer's forward pass can reuse the metadata.
+    each layer's forward pass can reuse the metadata.
 
-
-
-
-
+    For each init metadata function, we will try set up them in below order
+    """
+
+    # Sequence lengths for the forward batch
+    cache_seqlens_int32: torch.Tensor = None
     # Maximum sequence length for query
     max_seq_len_q: int = 0
     # Maximum sequence length for key
     max_seq_len_k: int = 0
+    # Cumulative sequence lengths for query
+    cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
+    cu_seqlens_k: torch.Tensor = None
     # Window size (typically used by Gemma)
     window_size: tuple = (-1, -1)
     # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
+
+    # Encoder metadata
+    # Cumulative sequence lengths for encoder key
+    encoder_cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for encoder key
+    encoder_max_seq_len_k: int = 0
     # Sequence lengths for the forward batch
-
+    encoder_lens_int32: torch.Tensor = None
+    # Page table for the encoder
+    encoder_page_table: torch.Tensor = None
 
 @dataclass
 class LocalAttentionMetadata:
@@ -137,6 +142,16 @@ def make_local_attention_virtual_batches(
         seqlens_k_local: Key sequence lengths for local attention
         block_table_local: Block table for local attention
     """
+    # Adjust attention_chunk_size based on the actual sequence length
+    # to avoid index out of bounds errors
+    max_seq_len = seq_lens_np.max()
+    effective_chunk_size = min(attn_chunk_size, max_seq_len)
+    # Make sure effective_chunk_size is divisible by page_size
+    effective_chunk_size = (effective_chunk_size // page_size) * page_size
+    if effective_chunk_size < page_size:
+        effective_chunk_size = page_size
+    attn_chunk_size = effective_chunk_size
+
     q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
     actual_batch_size = seq_lens_np.shape[0]
 
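For orientation, a standalone sketch (not part of the package) of the chunk-size clamp introduced in the hunk above, using hypothetical values for `page_size`, `attn_chunk_size`, and `seq_lens_np`:

```python
import numpy as np

# Hypothetical inputs: clamp the local-attention chunk to the longest sequence
# in the batch, then round down to a multiple of page_size (minimum one page).
page_size = 128
attn_chunk_size = 8192
seq_lens_np = np.array([300, 1000, 640], dtype=np.int32)

max_seq_len = seq_lens_np.max()                           # 1000
effective_chunk_size = min(attn_chunk_size, max_seq_len)  # 1000
effective_chunk_size = (effective_chunk_size // page_size) * page_size  # 896
if effective_chunk_size < page_size:
    effective_chunk_size = page_size
attn_chunk_size = effective_chunk_size                    # 896
```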
@@ -231,7 +246,11 @@ def make_local_attention_virtual_batches(
         np.arange(pages_per_local_batch, dtype=np.int32),
         (virtual_batches, pages_per_local_batch),
     ) + np.expand_dims(block_starts, axis=1)
-    block_indices
+    # Ensure block_indices doesn't exceed block_table dimensions
+    # This is a critical safety check that prevents index out of bounds errors
+    # when dealing with large sequences (>8192 tokens) or when the block_table
+    # dimensions are smaller than what would be needed for the full attention chunk size.
+    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
     batch_indices = np.repeat(
         np.arange(actual_batch_size, dtype=np.int32),
         local_blocks * pages_per_local_batch,
@@ -270,9 +289,9 @@ class FlashAttentionBackend(AttentionBackend):
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        speculative_step_id=0,
         topk=0,
         speculative_num_steps=0,
-        step_id=0,
     ):
         super().__init__()
 
@@ -287,20 +306,18 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {}
         self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
-        self.use_mla = (
-            model_runner.model_config.attention_arch == AttentionArch.MLA
-        ) and (not global_server_args_dict["disable_mla"])
+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
 
-
-        assert (
-            topk <= 1
-        ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
-
-        self.topk = 1
-        self.step_id = step_id
+        self.topk = topk
         self.speculative_num_steps = speculative_num_steps
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.speculative_step_id = speculative_step_id
 
         # Local attention settings
         self.attention_chunk_size = (
@@ -310,71 +327,63 @@ class FlashAttentionBackend(AttentionBackend):
         )
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        """Initialize forward metadata
+        """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-
-
-        #
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            # Draft Decode
             if forward_batch.spec_info is not None:
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = torch.arange(
                     0, batch_size + 1, dtype=torch.int32, device=device
                 )
-                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
-                metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
-                    self.step_id + 1
-                )
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-
-
-
-
-                for idx, single_seq_len in enumerate(seq_lens_with_decode):
-                    real_bsz_start_idx = idx
-                    real_bsz_end_idx = idx + 1
-                    metadata.page_table[
-                        real_bsz_start_idx:real_bsz_end_idx,
-                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
-                    ] = cache_loc[
-                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
-                    ]
-            else:  # Normal Decode without Spec Decoding
+
+                self._init_local_attn_metadata(metadata, device)
+            else:
+                # Normal Decode
                 metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-
-
-            )
+
+                self._init_local_attn_metadata(metadata, device)
         elif forward_batch.forward_mode.is_target_verify():
-            # Note: Target Verify will be ran on the Target Worker
-            draft_token_num = forward_batch.spec_info.draft_token_num
             metadata.cache_seqlens_int32 = (
-                forward_batch.seq_lens +
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
             ).to(torch.int32)
-            metadata.max_seq_len_q =
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
             metadata.max_seq_len_k = (
-                forward_batch.seq_lens_cpu.max().item()
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
             )
             metadata.cu_seqlens_q = torch.arange(
                 0,
-                batch_size *
-
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
                 dtype=torch.int32,
                 device=device,
             )
@@ -386,79 +395,58 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
 
-        elif forward_batch.forward_mode.
-            # Normal or Draft Extend (Both of them will be ran on the Target Worker)
+        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             metadata.cu_seqlens_k = torch.nn.functional.pad(
                 torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
             )
-            # Precompute maximum sequence length
-            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
-            # Precompute page table
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
 
-            # Precompute cumulative sequence lengths
             if (
                 any(forward_batch.extend_prefix_lens_cpu)
                 or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
             ):
                 extend_seq_lens = forward_batch.extend_seq_lens
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
-                metadata.cu_seqlens_q = metadata.cu_seqlens_k
                 metadata.max_seq_len_q = metadata.max_seq_len_k
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k
 
             # Setup local attention if enabled
-            if
-                self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                # Create local attention metadata
-                (
-                    seqlens_q_local_np,
-                    cu_seqlens_q_local_np,
-                    seqlens_k_local_np,
-                    block_table_local,
-                ) = make_local_attention_virtual_batches(
-                    effective_chunk_size,
-                    cu_seqlens_q_np,
-                    seq_lens_np,
-                    metadata.page_table,
-                    self.page_size,
-                )
+            if forward_batch.forward_mode == ForwardMode.EXTEND:
+                self._init_local_attn_metadata(metadata, device)
+
+        # Encoder metadata for cross attention
+        if forward_batch.encoder_lens is not None:
+            assert (
+                forward_batch.encoder_lens.numel() == 1
+            ), "Only encoder size 1 is supported for now"
+
+            metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
+            metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
+            metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
+            metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
+            ]
 
-
-
-
-
-
-
-
-                    local_max_seq_len=seqlens_k_local_np.max(),
-                )
-                metadata.local_attn_metadata = local_metadata
+            # Currently only support forward_batch.encoder_lens.numel() == 1
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices,
+                metadata.encoder_max_seq_len_k : (
+                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                ),
+            ]
 
-        #
+        # Convert the page table to a strided format which is needed by FA3 API
         if self.page_size > 1:
             self.strided_indices = torch.arange(
                 0, metadata.page_table.shape[1], self.page_size, device=self.device
@@ -498,7 +486,7 @@ class FlashAttentionBackend(AttentionBackend):
                 v,
             )
 
-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
         metadata = self.forward_metadata
 
         # Calculate window size (can be moved to metadata if layer properties don't change)
@@ -506,9 +494,18 @@ class FlashAttentionBackend(AttentionBackend):
         # here is two side inclusive
         window_size = (
             (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None
+            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
            else (-1, -1)
         )
+        k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
+        causal = not layer.is_cross_attention
 
         # Check if we should use local attention
         use_local_attn = (
@@ -536,14 +533,21 @@ class FlashAttentionBackend(AttentionBackend):
         # Use Flash Attention for prefill
         if not self.use_mla:
             # Do multi-head attention
-
-
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
             value_cache = value_cache.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 k_cache=key_cache,
@@ -554,48 +558,93 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                 max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=
+                causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=
-                v_descale=
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if (
+                not global_server_args_dict["disable_chunked_prefix_cache"]
+                and forward_batch.attn_attend_prefix_cache is not None
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                # Do multi-head attention with chunked prefix cache
+
+                if forward_batch.attn_attend_prefix_cache:
+                    # MHA for chunked prefix kv cache when running model with MLA
+                    assert forward_batch.prefix_chunk_idx is not None
+                    assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                    assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+                    chunk_idx = forward_batch.prefix_chunk_idx
+                    assert chunk_idx >= 0
+
+                    output, lse, *rest = flash_attn_varlen_func(
+                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                        cu_seqlens_q=metadata.cu_seqlens_q,
+                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                        max_seqlen_q=metadata.max_seq_len_q,
+                        max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                        softmax_scale=layer.scaling,
+                        causal=False,
+                        return_softmax_lse=True,
+                    )
+                else:
+                    # MHA for extend part of sequence without attending prefix kv cache
+                    output, lse, *rest = flash_attn_varlen_func(
+                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                        cu_seqlens_q=metadata.cu_seqlens_q,
+                        cu_seqlens_k=metadata.cu_seqlens_q,
+                        max_seqlen_q=metadata.max_seq_len_q,
+                        max_seqlen_k=metadata.max_seq_len_q,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                        return_softmax_lse=True,
+                    )
+                return output, lse
+            else:
+                # Do absorbed multi-latent attention
+                kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+                k_rope = kv_cache[:, :, layer.v_head_dim :]
+                c_kv = kv_cache[:, :, : layer.v_head_dim]
+                k_rope_cache = k_rope.view(
+                    -1,
+                    self.page_size,
+                    layer.tp_k_head_num,
+                    layer.head_dim - layer.v_head_dim,
+                )
+                c_kv_cache = c_kv.view(
+                    -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+                )
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
+                o = flash_attn_with_kvcache(
+                    q=q_rope,
+                    k_cache=k_rope_cache,
+                    v_cache=c_kv_cache,
+                    qv=q_nope,
+                    page_table=page_table,
+                    cache_seqlens=cache_seqlens,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                    max_seqlen_q=max_seqlen_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
 
-
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def forward_decode(
         self,
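The chunked-prefix branch above returns `(output, lse)` pairs (via `return_softmax_lse=True`) instead of a final result; the combination of those partial results is done elsewhere in the package. Purely as an illustration of why the log-sum-exp is carried along, partial attention outputs over disjoint key sets can be merged like this (the helper name and shapes are assumptions, not sglang's API):

```python
import torch

def merge_partial_attention(o_a, lse_a, o_b, lse_b):
    # Illustrative only: o_* are partial attention outputs over disjoint key
    # sets, lse_* their softmax log-sum-exp. Assumed shapes:
    #   o_*  : [num_tokens, num_heads, head_dim]
    #   lse_*: [num_heads, num_tokens]
    lse_a = lse_a.transpose(0, 1).unsqueeze(-1)  # -> [num_tokens, num_heads, 1]
    lse_b = lse_b.transpose(0, 1).unsqueeze(-1)
    max_lse = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - max_lse)
    w_b = torch.exp(lse_b - max_lse)
    return (o_a * w_a + o_b * w_b) / (w_a + w_b)
```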
@@ -606,8 +655,6 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ) -> torch.Tensor:
-        """Forward pass with FlashAttention using precomputed metadata."""
-        # Save KV cache if needed
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -628,25 +675,39 @@ class FlashAttentionBackend(AttentionBackend):
                 v,
             )
 
-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
         metadata = self.forward_metadata
+        local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
+        use_local_attention = (
+            self.attention_chunk_size is not None and local_attn_metadata is not None
+        )
 
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
         window_size = (
             (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None
+            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
             else (-1, -1)
         )
-
+        causal = not layer.is_cross_attention
+
+        k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto":
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
 
         if not self.use_mla:
             # Do multi-head attention
 
-
-
-
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
@@ -654,24 +715,60 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if layer.is_cross_attention:
+                # Always use non-chunked logic for cross-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=metadata.encoder_page_table,
+                    cache_seqlens=metadata.encoder_lens_int32,
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
+                    max_seqlen_q=1,
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    window_size=(-1, -1),
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
+            elif use_local_attention:
+                # Use chunked (local) attention batching for self-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=local_attn_metadata.local_block_table,
+                    cache_seqlens=local_attn_metadata.local_seqused_k,
+                    cu_seqlens_q=local_attn_metadata.local_query_start_loc,
+                    cu_seqlens_k_new=metadata.cu_seqlens_k,
+                    max_seqlen_q=local_attn_metadata.local_max_query_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=(-1, -1),
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
+            else:
+                # Default: single-token self-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=metadata.page_table,
+                    cache_seqlens=metadata.cache_seqlens_int32,
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k_new=metadata.cu_seqlens_k,
+                    max_seqlen_q=1,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
         else:
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -696,7 +793,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_cache=k_rope_cache,
                 v_cache=c_kv_cache,
                 qv=q_nope,
-                page_table=page_table,
+                page_table=metadata.page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -704,8 +801,8 @@ class FlashAttentionBackend(AttentionBackend):
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=
-                v_descale=
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
@@ -719,7 +816,13 @@ class FlashAttentionBackend(AttentionBackend):
         to avoid memory allocations.
         """
         self.decode_cuda_graph_metadata = {
-
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
             "page_table": torch.zeros(
                 max_bs,
                 (self.max_context_len + self.page_size - 1) // self.page_size,
@@ -735,35 +838,42 @@ class FlashAttentionBackend(AttentionBackend):
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
+        }
+
+        self.target_verify_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.
-
+            "cu_seqlens_q": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
            ),
             "cu_seqlens_k": torch.zeros(
-                max_bs +
+                max_bs + 1, dtype=torch.int32, device=self.device
             ),
-        }
-
-        self.target_verify_metadata = {
             "page_table": torch.zeros(
                 max_bs,
                 (self.max_context_len + self.page_size - 1) // self.page_size,
                 dtype=torch.int32,
                 device=self.device,
             ),
-            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "cu_seqlens_k": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "max_seqlen_q": 0,
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
         }
 
+        self.encoder_metadata = {
+            "encoder_page_table": torch.zeros(
+                max_bs,
+                self.max_context_len,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "encoder_lens_int32": torch.zeros(
+                max_bs, dtype=torch.int32, device=self.device
+            ),
+            "encoder_cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+        }
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
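The buffers above follow a preallocate-then-slice pattern so that CUDA graph capture and replay always see tensors at fixed addresses. A minimal standalone sketch of the idea (hypothetical sizes, independent of the class in this diff):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
max_bs = 8

# Allocate once at the maximum batch size...
cache_seqlens = torch.zeros(max_bs, dtype=torch.int32, device=device)
cu_seqlens_q = torch.arange(0, max_bs + 1, dtype=torch.int32, device=device)

# ...then reuse fixed-address views for the batch being captured or replayed,
# updating them in place instead of allocating new tensors.
bs = 4
seq_lens = torch.tensor([5, 9, 3, 7], dtype=torch.int32, device=device)
cache_seqlens[:bs].copy_(seq_lens)
```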
@@ -777,27 +887,24 @@ class FlashAttentionBackend(AttentionBackend):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
         device = seq_lens.device
-        if forward_mode.
+        if forward_mode.is_decode_or_idle():
             if spec_info is not None:
                 # Draft Decode
-                metadata.cu_seqlens_q = torch.arange(
-                    0, bs + 1, dtype=torch.int32, device=device
-                )
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                     "cache_seqlens"
                 ][:bs]
-
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                     : bs + 1
                 ]
-
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
                 metadata.page_table = self.decode_cuda_graph_metadata[
                     "page_table_draft_decode"
                 ][req_pool_indices, :]
@@ -822,43 +929,49 @@ class FlashAttentionBackend(AttentionBackend):
                 )
             self.decode_cuda_graph_metadata[bs] = metadata
         elif forward_mode.is_target_verify():
-            draft_token_num = spec_info.draft_token_num
-
             metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
                 :bs
             ]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens +
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
             )
 
-            metadata.max_seq_len_q =
-            metadata.max_seq_len_k =
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )
 
-            metadata.cu_seqlens_q =
-
-
-
-
-
-                    device=device,
-                )
-            ]
-            cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
-            cu_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
             )
-
+
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+
             metadata.page_table = self.target_verify_metadata["page_table"][
                 req_pool_indices, :
             ]
 
             self.target_verify_metadata[bs] = metadata
 
+        if encoder_lens is not None:
+            encoder_bs = encoder_lens.numel()
+            metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
+                :encoder_bs
+            ]
+            metadata.encoder_cu_seqlens_k = self.encoder_metadata[
+                "encoder_cu_seqlens_k"
+            ][: (encoder_bs + 1)]
+
+            metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
+                req_pool_indices, :
+            ]
+
         self.forward_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
@@ -874,24 +987,23 @@ class FlashAttentionBackend(AttentionBackend):
         out_cache_loc: torch.Tensor = None,
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        device = seq_lens.device
         seq_lens = seq_lens[:bs]
-        req_pool_indices = req_pool_indices[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
-
+        req_pool_indices = req_pool_indices[:bs]
+        device = seq_lens.device
+
+        if forward_mode.is_decode_or_idle():
             metadata = self.decode_cuda_graph_metadata[bs]
 
             if spec_info is not None:
                 # Draft Decode
-                max_len = seq_lens_cpu.max().item()
-                metadata.max_seq_len_k = max_len + (self.step_id + 1)
-
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.
+                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
                 )
 
-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-
+                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_k.copy_(
                     torch.nn.functional.pad(
                         torch.cumsum(
@@ -906,6 +1018,8 @@ class FlashAttentionBackend(AttentionBackend):
                 ]
 
                 metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+                self._init_local_attn_metadata(metadata, device)
             else:
                 # Normal Decode
                 max_len = seq_lens_cpu.max().item()
@@ -920,31 +1034,25 @@ class FlashAttentionBackend(AttentionBackend):
                     metadata.max_seq_len_k + self.page_size - 1
                 ) // self.page_size
                 page_indices = self.req_to_token[
-                    :,
-                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
+                    req_pool_indices[:, None],
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                        None, :
+                    ],
                 ]
-                page_indices
+                page_indices //= self.page_size
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)
 
+                self._init_local_attn_metadata(metadata, device)
         elif forward_mode.is_target_verify():
             metadata = self.target_verify_metadata[bs]
-            draft_token_num = spec_info.draft_token_num
-
-            metadata.cu_seqlens_q.copy_(
-                torch.arange(
-                    0,
-                    bs * draft_token_num + 1,
-                    draft_token_num,
-                    dtype=torch.int32,
-                    device=device,
-                )
-            )
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens +
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
             )
 
-            metadata.max_seq_len_k =
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
             metadata.cu_seqlens_k.copy_(
                 torch.nn.functional.pad(
                     torch.cumsum(
@@ -956,12 +1064,72 @@ class FlashAttentionBackend(AttentionBackend):
             page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
             metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
 
+            if encoder_lens is not None:
+                # Only support encoder size 1 for now
+                metadata.encoder_max_seq_len_k = encoder_lens[0]
+                metadata.encoder_lens_int32.copy_(encoder_lens[:1])
+                metadata.encoder_cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                        (1, 0),
+                    )
+                )
+
+                metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
+                    self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
+                )
+
+                # Update the regular page table
+                page_table = self.req_to_token[
+                    req_pool_indices,
+                    metadata.encoder_max_seq_len_k : (
+                        metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                    ),
+                ]
+                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
         self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
         return 0
 
+    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
+        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
+        if self.attention_chunk_size is None:
+            metadata.local_attn_metadata = None
+            return
+
+        cu_seqlens_q = metadata.cu_seqlens_q
+        cache_seqlens_int32 = metadata.cache_seqlens_int32
+        page_table = metadata.page_table
+        if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
+            metadata.local_attn_metadata = None
+            return
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seq_lens_np = cache_seqlens_int32.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seq_lens_np,
+            page_table,
+            self.page_size,
+        )
+        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
+            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
+            local_block_table=block_table_local.to(device),
+            local_max_query_len=int(seqlens_q_local_np.max()),
+            local_max_seq_len=int(seqlens_k_local_np.max()),
+        )
+        metadata.local_attn_metadata = local_metadata
+
 
 class FlashAttentionMultiStepBackend:
 
@@ -972,14 +1140,19 @@ class FlashAttentionMultiStepBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
 
+        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+        assert (
+            self.topk == 1
+        ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
+
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 FlashAttentionBackend(
                     model_runner,
+                    speculative_step_id=i,
                     topk=self.topk,
                     speculative_num_steps=self.speculative_num_steps,
-                    step_id=i,
                 )
             )
 
@@ -1004,7 +1177,7 @@ class FlashAttentionMultiStepBackend:
                 forward_batch.batch_size * self.topk,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
-                encoder_lens=
+                encoder_lens=forward_batch.encoder_lens,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
             )
@@ -1021,7 +1194,7 @@ class FlashAttentionMultiStepBackend:
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
-                encoder_lens=
+                encoder_lens=forward_batch.encoder_lens,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
                 seq_lens_cpu=forward_batch.seq_lens_cpu,
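To make the renamed `speculative_step_id` parameter concrete, a small standalone sketch (hypothetical sequence lengths) of how the draft-decode key lengths in the hunks above grow by one token per speculative step:

```python
import torch

# Each draft-decode step i attends to the committed tokens plus (i + 1) draft
# tokens, mirroring `seq_lens + (speculative_step_id + 1)` in the diff above.
seq_lens = torch.tensor([17, 42], dtype=torch.int32)
speculative_num_steps = 3

for speculative_step_id in range(speculative_num_steps):
    cache_seqlens_int32 = (seq_lens + (speculative_step_id + 1)).to(torch.int32)
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
    )
    print(speculative_step_id, cache_seqlens_int32.tolist(), cu_seqlens_k.tolist())
```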