sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +330 -200
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +12 -5
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +25 -13
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +1 -0
- sglang/srt/layers/radix_attention.py +13 -1
- sglang/srt/layers/rotary_embedding.py +12 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +48 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +1 -0
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
@@ -1,51 +1,56 @@
 from __future__ import annotations

-import numpy as np
-
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-
-"""
-Support different attention backends.
-Now there are three backends: FlashInfer, Triton and FlashAttention.
-Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
-"""
-
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

+import numpy as np
 import torch

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

-from sgl_kernel.flash_attn import flash_attn_with_kvcache
+from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


 @dataclass
 class FlashAttentionMetadata:
     """Metadata to be init once in the model forward pass,
-    each layer's forward pass can reuse the metadata.
+    each layer's forward pass can reuse the metadata.

-
-
-
-
+    For each init metadata function, we will try set up them in below order
+    """
+
+    # Sequence lengths for the forward batch
+    cache_seqlens_int32: torch.Tensor = None
     # Maximum sequence length for query
     max_seq_len_q: int = 0
     # Maximum sequence length for key
     max_seq_len_k: int = 0
+    # Cumulative sequence lengths for query
+    cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
+    cu_seqlens_k: torch.Tensor = None
     # Window size (typically used by Gemma)
     window_size: tuple = (-1, -1)
     # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
+
+    # Encoder metadata
+    # Cumulative sequence lengths for encoder key
+    encoder_cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for encoder key
+    encoder_max_seq_len_k: int = 0
     # Sequence lengths for the forward batch
-
+    encoder_lens_int32: torch.Tensor = None
+    # Page table for the encoder
+    encoder_page_table: torch.Tensor = None

 @dataclass
 class LocalAttentionMetadata:
@@ -231,7 +236,11 @@ def make_local_attention_virtual_batches(
         np.arange(pages_per_local_batch, dtype=np.int32),
         (virtual_batches, pages_per_local_batch),
     ) + np.expand_dims(block_starts, axis=1)
-    block_indices
+    # Ensure block_indices doesn't exceed block_table dimensions
+    # This is a critical safety check that prevents index out of bounds errors
+    # when dealing with large sequences (>8192 tokens) or when the block_table
+    # dimensions are smaller than what would be needed for the full attention chunk size.
+    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
     batch_indices = np.repeat(
         np.arange(actual_batch_size, dtype=np.int32),
         local_blocks * pages_per_local_batch,
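As a standalone illustration of the new clipping step (toy shapes, not taken from the actual tests), the sketch below shows how flattened local-attention block indices that would overrun the page table are clamped to its last valid column:

```python
import numpy as np

# Hypothetical sizes: the block table has 4 pages per request, but the
# virtual local-attention batches would otherwise index up to page 5.
block_table_width = 4
block_indices = np.array([[0, 1, 2, 3], [2, 3, 4, 5]], dtype=np.int32)

# Same idea as the patched line: flatten, then clamp so no index can
# exceed the last valid column of the block table.
clipped = block_indices.flatten().clip(max=block_table_width - 1)
print(clipped)  # [0 1 2 3 2 3 3 3]
```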
@@ -270,9 +279,9 @@ class FlashAttentionBackend(AttentionBackend):
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        speculative_step_id=0,
         topk=0,
         speculative_num_steps=0,
-        step_id=0,
     ):
         super().__init__()

@@ -287,20 +296,20 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {}
         self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
         ) and (not global_server_args_dict["disable_mla"])
         self.skip_prefill = skip_prefill

-
-        assert (
-            topk <= 1
-        ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
-
-        self.topk = 1
-        self.step_id = step_id
+        self.topk = topk
         self.speculative_num_steps = speculative_num_steps
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.speculative_step_id = speculative_step_id

         # Local attention settings
         self.attention_chunk_size = (
@@ -310,71 +319,59 @@ class FlashAttentionBackend(AttentionBackend):
         )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        """Initialize forward metadata
+        """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-
-
-        #
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            # Draft Decode
             if forward_batch.spec_info is not None:
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = torch.arange(
                     0, batch_size + 1, dtype=torch.int32, device=device
                 )
-                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
-                metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
-                    self.step_id + 1
-                )
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-
-
-                ).T
-
-                for idx, single_seq_len in enumerate(seq_lens_with_decode):
-                    real_bsz_start_idx = idx
-                    real_bsz_end_idx = idx + 1
-                    metadata.page_table[
-                        real_bsz_start_idx:real_bsz_end_idx,
-                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
-                    ] = cache_loc[
-                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
-                    ]
-            else: # Normal Decode without Spec Decoding
+            else:
+                # Normal Decode
                 metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-                metadata.cu_seqlens_q = torch.arange(
-                    0, batch_size + 1, dtype=torch.int32, device=device
-                )
         elif forward_batch.forward_mode.is_target_verify():
-            # Note: Target Verify will be ran on the Target Worker
-            draft_token_num = forward_batch.spec_info.draft_token_num
             metadata.cache_seqlens_int32 = (
-                forward_batch.seq_lens +
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
             ).to(torch.int32)
-            metadata.max_seq_len_q =
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
             metadata.max_seq_len_k = (
-                forward_batch.seq_lens_cpu.max().item()
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
             )
             metadata.cu_seqlens_q = torch.arange(
                 0,
-                batch_size *
-
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
                 dtype=torch.int32,
                 device=device,
             )
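For readers unfamiliar with the metadata layout, here is a minimal sketch (toy tensors, not the real sglang objects) of how the cumulative key offsets are produced from per-request cache lengths with cumsum plus a left pad, and how the decode query offsets reduce to a simple arange:

```python
import torch

# Toy per-request KV lengths for a batch of 3 decode requests.
cache_seqlens_int32 = torch.tensor([5, 8, 3], dtype=torch.int32)
batch_size = cache_seqlens_int32.numel()

# cu_seqlens_k: prefix sums of the key lengths, left-padded with a 0,
# giving [0, 5, 13, 16].
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
)

# cu_seqlens_q for normal decode: one new query token per request, so the
# offsets are just 0..batch_size, i.e. [0, 1, 2, 3].
cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32)

print(cu_seqlens_k.tolist(), cu_seqlens_q.tolist())
```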
@@ -386,32 +383,28 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]

-        elif forward_batch.forward_mode.
-            # Normal or Draft Extend (Both of them will be ran on the Target Worker)
+        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             metadata.cu_seqlens_k = torch.nn.functional.pad(
                 torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
             )
-            # Precompute maximum sequence length
-            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
-            # Precompute page table
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]

-            # Precompute cumulative sequence lengths
             if (
                 any(forward_batch.extend_prefix_lens_cpu)
                 or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
             ):
                 extend_seq_lens = forward_batch.extend_seq_lens
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
-                metadata.cu_seqlens_q = metadata.cu_seqlens_k
                 metadata.max_seq_len_q = metadata.max_seq_len_k
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k

         # Setup local attention if enabled
         if (
@@ -458,7 +451,31 @@ class FlashAttentionBackend(AttentionBackend):
             )
             metadata.local_attn_metadata = local_metadata

-        #
+        # Encoder metadata for cross attention
+        if forward_batch.encoder_lens is not None:
+            assert (
+                forward_batch.encoder_lens.numel() == 1
+            ), "Only encoder size 1 is supported for now"
+
+            metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
+            metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
+            metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
+            metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
+            ]
+
+            # Currently only support forward_batch.encoder_lens.numel() == 1
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices,
+                metadata.encoder_max_seq_len_k : (
+                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                ),
+            ]
+
+        # Convert the page table to a strided format which is needed by FA3 API
         if self.page_size > 1:
             self.strided_indices = torch.arange(
                 0, metadata.page_table.shape[1], self.page_size, device=self.device
@@ -498,7 +515,7 @@ class FlashAttentionBackend(AttentionBackend):
                 v,
             )

-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
         metadata = self.forward_metadata

         # Calculate window size (can be moved to metadata if layer properties don't change)
@@ -506,9 +523,18 @@ class FlashAttentionBackend(AttentionBackend):
         # here is two side inclusive
         window_size = (
             (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None
+            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
             else (-1, -1)
         )
+        k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
+        causal = not layer.is_cross_attention

         # Check if we should use local attention
         use_local_attn = (
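The k/v descale tensors handed to FlashAttention are just the layer's scalar KV-cache scales broadcast to a (batch_size, num_kv_heads) view; a rough sketch of that expansion with made-up values (not the actual RadixAttention layer) is:

```python
import torch

batch_size, tp_k_head_num = 4, 8

# In the real layer these come from the checkpoint's kv-cache quantization
# scales; the scalars here are placeholders.
k_scale = torch.tensor(0.02)
v_scale = torch.tensor(0.03)

descale_shape = (batch_size, tp_k_head_num)
# expand() creates a broadcast view without copying data.
k_descale = k_scale.expand(descale_shape)
v_descale = v_scale.expand(descale_shape)
print(k_descale.shape, v_descale.shape)  # torch.Size([4, 8]) twice
```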
@@ -536,14 +562,21 @@ class FlashAttentionBackend(AttentionBackend):
         # Use Flash Attention for prefill
         if not self.use_mla:
             # Do multi-head attention
-
-
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
             value_cache = value_cache.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 k_cache=key_cache,
@@ -554,48 +587,93 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                 max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=
+                causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=
-                v_descale=
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
-            [... 33 removed lines truncated in the source diff ...]
+            if (
+                not global_server_args_dict["disable_chunked_prefix_cache"]
+                and forward_batch.attn_attend_prefix_cache is not None
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                # Do multi-head attention with chunked prefix cache
+
+                if forward_batch.attn_attend_prefix_cache:
+                    # MHA for chunked prefix kv cache when running model with MLA
+                    assert forward_batch.prefix_chunk_idx is not None
+                    assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                    assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+                    chunk_idx = forward_batch.prefix_chunk_idx
+                    assert chunk_idx >= 0
+
+                    output, lse, *rest = flash_attn_varlen_func(
+                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                        cu_seqlens_q=metadata.cu_seqlens_q,
+                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                        max_seqlen_q=metadata.max_seq_len_q,
+                        max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                        softmax_scale=layer.scaling,
+                        causal=False,
+                        return_softmax_lse=True,
+                    )
+                else:
+                    # MHA for extend part of sequence without attending prefix kv cache
+                    output, lse, *rest = flash_attn_varlen_func(
+                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                        cu_seqlens_q=metadata.cu_seqlens_q,
+                        cu_seqlens_k=metadata.cu_seqlens_q,
+                        max_seqlen_q=metadata.max_seq_len_q,
+                        max_seqlen_k=metadata.max_seq_len_q,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                        return_softmax_lse=True,
+                    )
+                return output, lse
+            else:
+                # Do absorbed multi-latent attention
+                kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+                k_rope = kv_cache[:, :, layer.v_head_dim :]
+                c_kv = kv_cache[:, :, : layer.v_head_dim]
+                k_rope_cache = k_rope.view(
+                    -1,
+                    self.page_size,
+                    layer.tp_k_head_num,
+                    layer.head_dim - layer.v_head_dim,
+                )
+                c_kv_cache = c_kv.view(
+                    -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+                )
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
+                o = flash_attn_with_kvcache(
+                    q=q_rope,
+                    k_cache=k_rope_cache,
+                    v_cache=c_kv_cache,
+                    qv=q_nope,
+                    page_table=page_table,
+                    cache_seqlens=cache_seqlens,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                    max_seqlen_q=max_seqlen_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )

-
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def forward_decode(
         self,
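To make the absorbed-MLA call above easier to follow, here is a small sketch (made-up head sizes, independent of the real model config) of how the packed query is split into the latent part passed as qv and the rotary part passed as q:

```python
import torch

# Made-up MLA sizes: a 576-dim packed head = 512 latent ("nope") + 64 rotary.
num_tokens, tp_q_head_num, head_dim, v_head_dim = 2, 16, 576, 512

q = torch.randn(num_tokens, tp_q_head_num * head_dim)
q_all = q.contiguous().view(-1, tp_q_head_num, head_dim)

# Same slicing as in the diff: the first v_head_dim channels form the
# latent part (qv), the remaining channels carry the rotary embedding (q).
q_nope = q_all[:, :, :v_head_dim]
q_rope = q_all[:, :, v_head_dim:]
print(q_nope.shape, q_rope.shape)  # (2, 16, 512) and (2, 16, 64)
```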
@@ -606,8 +684,6 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ) -> torch.Tensor:
-        """Forward pass with FlashAttention using precomputed metadata."""
-        # Save KV cache if needed
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -628,7 +704,7 @@ class FlashAttentionBackend(AttentionBackend):
                 v,
             )

-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
         metadata = self.forward_metadata

         # Calculate window size (can be moved to metadata if layer properties don't change)
@@ -636,17 +712,27 @@ class FlashAttentionBackend(AttentionBackend):
         # here is two side inclusive
         window_size = (
             (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None
+            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
             else (-1, -1)
         )
-
+        causal = not layer.is_cross_attention
+
+        k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto":
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)

         if not self.use_mla:
             # Do multi-head attention

-
-
-
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
@@ -654,23 +740,32 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )

-            # Pre-reshape query tensor
             q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+            else:
+                page_table = metadata.page_table
+                cache_seqlens = metadata.cache_seqlens_int32
+                cu_seqlens_k = metadata.cu_seqlens_k
+
             o = flash_attn_with_kvcache(
                 q=q_reshaped,
                 k_cache=key_cache,
                 v_cache=value_cache,
                 page_table=page_table,
-                cache_seqlens=
+                cache_seqlens=cache_seqlens,
                 cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=
+                cu_seqlens_k_new=cu_seqlens_k,
                 max_seqlen_q=1,
                 softmax_scale=layer.scaling,
-                causal=
+                causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=
-                v_descale=
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         else:
             # Do absorbed multi-latent attention
@@ -696,7 +791,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_cache=k_rope_cache,
                 v_cache=c_kv_cache,
                 qv=q_nope,
-                page_table=page_table,
+                page_table=metadata.page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -704,8 +799,8 @@ class FlashAttentionBackend(AttentionBackend):
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=
-                v_descale=
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

@@ -719,7 +814,13 @@ class FlashAttentionBackend(AttentionBackend):
         to avoid memory allocations.
         """
         self.decode_cuda_graph_metadata = {
-
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
             "page_table": torch.zeros(
                 max_bs,
                 (self.max_context_len + self.page_size - 1) // self.page_size,
@@ -735,35 +836,42 @@ class FlashAttentionBackend(AttentionBackend):
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
+        }
+
+        self.target_verify_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.
-
+            "cu_seqlens_q": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
             ),
             "cu_seqlens_k": torch.zeros(
-                max_bs +
+                max_bs + 1, dtype=torch.int32, device=self.device
             ),
-        }
-
-        self.target_verify_metadata = {
             "page_table": torch.zeros(
                 max_bs,
                 (self.max_context_len + self.page_size - 1) // self.page_size,
                 dtype=torch.int32,
                 device=self.device,
             ),
-            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "cu_seqlens_k": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "max_seqlen_q": 0,
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
         }

+        self.encoder_metadata = {
+            "encoder_page_table": torch.zeros(
+                max_bs,
+                self.max_context_len,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "encoder_lens_int32": torch.zeros(
+                max_bs, dtype=torch.int32, device=self.device
+            ),
+            "encoder_cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+        }
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
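The capture-time buffers above follow the usual CUDA-graph pattern: allocate everything once at the maximum batch size, then take per-batch slices and copy_ into them during capture and replay so tensor addresses stay stable. A condensed sketch of that pattern (toy sizes, unrelated to the real ModelRunner):

```python
import torch

max_bs = 8
# Preallocate once at the maximum batch size.
buffers = {
    "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32),
    "cu_seqlens_k": torch.zeros(max_bs + 1, dtype=torch.int32),
}

def bind(bs: int, seq_lens: torch.Tensor):
    # Per-batch views into the preallocated buffers; copy_ keeps the
    # underlying storage (and thus the graph-captured addresses) stable.
    cache_seqlens = buffers["cache_seqlens"][:bs]
    cache_seqlens.copy_(seq_lens.to(torch.int32))
    cu_seqlens_k = buffers["cu_seqlens_k"][: bs + 1]
    cu_seqlens_k.copy_(
        torch.nn.functional.pad(
            torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
        )
    )
    return cache_seqlens, cu_seqlens_k

print(bind(3, torch.tensor([4, 2, 6])))
```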
@@ -777,27 +885,24 @@ class FlashAttentionBackend(AttentionBackend):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
         device = seq_lens.device
-        if forward_mode.
+        if forward_mode.is_decode_or_idle():
             if spec_info is not None:
                 # Draft Decode
-                metadata.cu_seqlens_q = torch.arange(
-                    0, bs + 1, dtype=torch.int32, device=device
-                )
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                     "cache_seqlens"
                 ][:bs]
-
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                     : bs + 1
                 ]
-
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
                 metadata.page_table = self.decode_cuda_graph_metadata[
                     "page_table_draft_decode"
                 ][req_pool_indices, :]
@@ -822,43 +927,49 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 self.decode_cuda_graph_metadata[bs] = metadata
         elif forward_mode.is_target_verify():
-            draft_token_num = spec_info.draft_token_num
-
             metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
                 :bs
             ]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens +
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
             )

-            metadata.max_seq_len_q =
-            metadata.max_seq_len_k =
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )

-            metadata.cu_seqlens_q =
-
-
-
-
-
-                    device=device,
-                )
-            ]
-            cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
-            cu_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
             )
-
+
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+
             metadata.page_table = self.target_verify_metadata["page_table"][
                 req_pool_indices, :
             ]

             self.target_verify_metadata[bs] = metadata

+        if encoder_lens is not None:
+            encoder_bs = encoder_lens.numel()
+            metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
+                :encoder_bs
+            ]
+            metadata.encoder_cu_seqlens_k = self.encoder_metadata[
+                "encoder_cu_seqlens_k"
+            ][: (encoder_bs + 1)]
+
+            metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
+                req_pool_indices, :
+            ]
+
         self.forward_metadata = metadata

     def init_forward_metadata_replay_cuda_graph(
@@ -874,24 +985,21 @@ class FlashAttentionBackend(AttentionBackend):
         out_cache_loc: torch.Tensor = None,
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        device = seq_lens.device
         seq_lens = seq_lens[:bs]
-        req_pool_indices = req_pool_indices[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
-
+        req_pool_indices = req_pool_indices[:bs]
+        if forward_mode.is_decode_or_idle():
             metadata = self.decode_cuda_graph_metadata[bs]

             if spec_info is not None:
                 # Draft Decode
-                max_len = seq_lens_cpu.max().item()
-                metadata.max_seq_len_k = max_len + (self.step_id + 1)
-
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.
+                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
                 )

-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-
+                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_k.copy_(
                     torch.nn.functional.pad(
                         torch.cumsum(
@@ -920,31 +1028,24 @@ class FlashAttentionBackend(AttentionBackend):
                     metadata.max_seq_len_k + self.page_size - 1
                 ) // self.page_size
                 page_indices = self.req_to_token[
-                    :,
-                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
+                    req_pool_indices[:, None],
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                        None, :
+                    ],
                 ]
-                page_indices
+                page_indices //= self.page_size
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)

         elif forward_mode.is_target_verify():
             metadata = self.target_verify_metadata[bs]
-            draft_token_num = spec_info.draft_token_num
-
-            metadata.cu_seqlens_q.copy_(
-                torch.arange(
-                    0,
-                    bs * draft_token_num + 1,
-                    draft_token_num,
-                    dtype=torch.int32,
-                    device=device,
-                )
-            )
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens +
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
             )

-            metadata.max_seq_len_k =
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
             metadata.cu_seqlens_k.copy_(
                 torch.nn.functional.pad(
                     torch.cumsum(
@@ -956,6 +1057,30 @@ class FlashAttentionBackend(AttentionBackend):
             page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
             metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

+        if encoder_lens is not None:
+            # Only support encoder size 1 for now
+            metadata.encoder_max_seq_len_k = encoder_lens[0]
+            metadata.encoder_lens_int32.copy_(encoder_lens[:1])
+            metadata.encoder_cu_seqlens_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                    (1, 0),
+                )
+            )
+
+            metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
+                self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
+            )
+
+            # Update the regular page table
+            page_table = self.req_to_token[
+                req_pool_indices,
+                metadata.encoder_max_seq_len_k : (
+                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                ),
+            ]
+            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
         self.forward_metadata = metadata

     def get_cuda_graph_seq_len_fill_value(self):
@@ -972,14 +1097,19 @@ class FlashAttentionMultiStepBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps

+        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+        assert (
+            self.topk == 1
+        ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
+
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 FlashAttentionBackend(
                     model_runner,
+                    speculative_step_id=i,
                     topk=self.topk,
                     speculative_num_steps=self.speculative_num_steps,
-                    step_id=i,
                 )
             )

@@ -1004,7 +1134,7 @@ class FlashAttentionMultiStepBackend:
             forward_batch.batch_size * self.topk,
             forward_batch.req_pool_indices,
             forward_batch.seq_lens,
-            encoder_lens=
+            encoder_lens=forward_batch.encoder_lens,
             forward_mode=ForwardMode.DECODE,
             spec_info=forward_batch.spec_info,
         )
@@ -1021,7 +1151,7 @@ class FlashAttentionMultiStepBackend:
             forward_batch.req_pool_indices,
             forward_batch.seq_lens,
             forward_batch.seq_lens_sum,
-            encoder_lens=
+            encoder_lens=forward_batch.encoder_lens,
             forward_mode=ForwardMode.DECODE,
             spec_info=forward_batch.spec_info,
             seq_lens_cpu=forward_batch.seq_lens_cpu,