sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -8
- sglang/compile_deep_gemm.py +177 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +96 -5
- sglang/srt/disaggregation/mini_lb.py +113 -15
- sglang/srt/disaggregation/mooncake/conn.py +199 -32
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +119 -20
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +11 -9
- sglang/srt/function_call_parser.py +132 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +809 -160
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +42 -5
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +385 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +176 -132
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +17 -4
- sglang/srt/managers/io_struct.py +21 -3
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +42 -12
- sglang/srt/managers/scheduler.py +47 -26
- sglang/srt/managers/tokenizer_manager.py +120 -30
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +118 -13
- sglang/srt/model_executor/cuda_graph_runner.py +16 -10
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +29 -27
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +153 -76
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +22 -7
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +87 -10
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +65 -60
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +48 -6
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +39 -19
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

+from sgl_kernel import merge_state_v2
 from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


@@ -30,7 +31,7 @@ class FlashAttentionMetadata:
     # Sequence lengths for the forward batch
     cache_seqlens_int32: torch.Tensor = None
     # Maximum sequence length for query
-    max_seq_len_q: int = 0
+    max_seq_len_q: int = 1
     # Maximum sequence length for key
     max_seq_len_k: int = 0
     # Cumulative sequence lengths for query
@@ -267,6 +268,12 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)


+# TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue
+@torch._dynamo.disable()
+def merge_state_v2_wrapper(o, s_a, o_exp, s_b):
+    return merge_state_v2(o, s_a, o_exp, s_b)
+
+
 class FlashAttentionBackend(AttentionBackend):
     """FlashAttention backend implementation.

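The new `merge_state_v2_wrapper` combines two partial attention results using their log-sum-exp (LSE) statistics. As a reference for what the fused `sgl_kernel` op computes, here is a minimal plain-PyTorch sketch of that merge (shapes and names are illustrative, not the kernel's actual signature):

```python
import torch

def merge_attention_states(o_a, lse_a, o_b, lse_b):
    # o_a, o_b: partial attention outputs over disjoint key sets, (tokens, heads, dim)
    # lse_a, lse_b: per-query log-sum-exp of the attention logits, (tokens, heads)
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)   # unnormalized softmax mass of part A
    w_b = torch.exp(lse_b - lse_max)   # unnormalized softmax mass of part B
    denom = w_a + w_b
    o = o_a * (w_a / denom).unsqueeze(-1) + o_b * (w_b / denom).unsqueeze(-1)
    return o, lse_max + torch.log(denom)   # merged output and merged LSE
```

Because each partial output is already normalized by its own softmax mass, weighting by `exp(lse)` recovers attention over the union of the two key sets exactly; the `@torch._dynamo.disable()` wrapper simply keeps the custom CUDA op out of `torch.compile` traced graphs.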
@@ -301,6 +308,8 @@ class FlashAttentionBackend(AttentionBackend):
         ), "Sliding window and cross attention are not supported together"

         self.forward_metadata: FlashAttentionMetadata = None
+        # extra metdata for handling speculative decoding topk > 1, extended draft decode and verify
+        self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
@@ -311,8 +320,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
-
-        self.topk = topk
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
         self.speculative_num_steps = speculative_num_steps
         self.speculative_num_draft_tokens = (
             model_runner.server_args.speculative_num_draft_tokens
@@ -336,14 +344,107 @@ class FlashAttentionBackend(AttentionBackend):
         if forward_batch.forward_mode.is_decode_or_idle():
             # Draft Decode
             if forward_batch.spec_info is not None:
+                if self.topk <= 1:
+                    metadata.cache_seqlens_int32 = (
+                        seqlens_in_batch + (self.speculative_step_id + 1)
+                    ).to(torch.int32)
+                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                        self.speculative_step_id + 1
+                    )
+                    metadata.cu_seqlens_q = torch.arange(
+                        0, batch_size + 1, dtype=torch.int32, device=device
+                    )
+                    metadata.cu_seqlens_k = torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                        forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                    ]
+                else:
+                    metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
+                    metadata.max_seq_len_q = self.topk
+                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                    metadata.cu_seqlens_q = torch.arange(
+                        0,
+                        batch_size * self.topk + 1,
+                        step=self.topk,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    metadata.cu_seqlens_k = torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                        forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                    ]
+
+                    metadata_expand = FlashAttentionMetadata()
+                    decode_length = self.speculative_step_id + 1
+                    metadata_expand.cache_seqlens_int32 = torch.full(
+                        (seqlens_in_batch.numel() * self.topk,),
+                        decode_length,
+                        device=device,
+                        dtype=torch.int32,
+                    )
+                    metadata_expand.max_seq_len_q = 1
+                    metadata_expand.max_seq_len_k = self.speculative_step_id + 1
+                    metadata_expand.cu_seqlens_q = torch.arange(
+                        0,
+                        metadata_expand.cache_seqlens_int32.numel() + 1,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    metadata_expand.cu_seqlens_k = torch.arange(
+                        0,
+                        metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
+                        step=decode_length,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    cache_loc = forward_batch.out_cache_loc.view(
+                        self.speculative_num_steps, -1
+                    ).T.contiguous()
+                    metadata_expand.page_table = (
+                        cache_loc[:, :decode_length].contiguous().to(torch.int32)
+                    )
+                    self.forward_metadata_spec_decode_expand = metadata_expand
+            else:
+                # Normal Decode
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+            # TODO: we need to test this part for llama 4 eagle case
+            self._init_local_attn_metadata(metadata, device)
+        elif forward_batch.forward_mode.is_target_verify():
+            if self.topk <= 1:
                 metadata.cache_seqlens_int32 = (
-                    seqlens_in_batch + (self.speculative_step_id + 1)
+                    forward_batch.seq_lens + self.speculative_num_draft_tokens
                 ).to(torch.int32)
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
-                    self.speculative_step_id + 1
+                metadata.max_seq_len_q = self.speculative_num_draft_tokens
+                metadata.max_seq_len_k = (
+                    forward_batch.seq_lens_cpu.max().item()
+                    + self.speculative_num_draft_tokens
                 )
                 metadata.cu_seqlens_q = torch.arange(
-                    0, batch_size + 1, dtype=torch.int32, device=device
+                    0,
+                    batch_size * self.speculative_num_draft_tokens + 1,
+                    self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=device,
                 )
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
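The topk > 1 draft-decode path above splits attention metadata in two: a "normal" piece where each request contributes `topk` queries against its shared prefix KV, and an "expand" piece where each of the `batch_size * topk` draft branches sees only its own `decode_length` freshly written cache slots. A small sketch (hypothetical sizes) of the `cu_seqlens` layout this produces:

```python
import torch

batch_size, topk, decode_length = 2, 4, 3  # illustrative values only

# "Normal" half: topk queries per request share one prefix KV range.
cu_seqlens_q = torch.arange(0, batch_size * topk + 1, step=topk, dtype=torch.int32)
# -> tensor([0, 4, 8]): queries [0:4) belong to request 0, [4:8) to request 1.

# "Expand" half: every draft branch gets its own tiny KV range of decode_length slots.
cu_seqlens_k = torch.arange(
    0, batch_size * topk * decode_length + 1, step=decode_length, dtype=torch.int32
)
# -> tensor([0, 3, 6, ..., 24]): one 3-slot KV segment per branch.
```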
@@ -357,44 +458,101 @@ class FlashAttentionBackend(AttentionBackend):

             self._init_local_attn_metadata(metadata, device)
         else:
-            # Normal Decode
-            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             metadata.cu_seqlens_q = torch.arange(
-                0, batch_size + 1, dtype=torch.int32, device=device
+                0,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                step=self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
             )
             metadata.cu_seqlens_k = torch.nn.functional.pad(
-                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                torch.cumsum(
+                    metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                ),
+                (1, 0),
             )
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]

-
-        elif forward_batch.forward_mode.is_target_verify():
-            metadata.cache_seqlens_int32 = (
-                forward_batch.seq_lens + self.speculative_num_draft_tokens
-            ).to(torch.int32)
-            metadata.max_seq_len_q = self.speculative_num_draft_tokens
-            metadata.max_seq_len_k = (
-                forward_batch.seq_lens_cpu.max().item()
-                + self.speculative_num_draft_tokens
-            )
-            metadata.cu_seqlens_q = torch.arange(
-                0,
-                batch_size * self.speculative_num_draft_tokens + 1,
-                self.speculative_num_draft_tokens,
-                dtype=torch.int32,
-                device=device,
-            )
-            metadata.cu_seqlens_k = torch.nn.functional.pad(
-                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
-                (1, 0),
-            )
-            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
-                forward_batch.req_pool_indices, : metadata.max_seq_len_k
-            ]
+            metadata_expand = FlashAttentionMetadata()

+            metadata_expand.max_seq_len_q = 1
+            metadata_expand.cu_seqlens_q = torch.arange(
+                0,
+                forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
+                + 1,
+                dtype=torch.int32,
+                device=device,
+            )
+
+            # create expand page table
+            offsets = torch.arange(
+                self.speculative_num_draft_tokens, device=device
+            ).unsqueeze(
+                0
+            )  # shape: (1, self.speculative_num_draft_tokens)
+            cols = offsets.expand(
+                forward_batch.seq_lens.numel(), -1
+            ) + forward_batch.seq_lens.unsqueeze(1)
+            cum_len = torch.nn.functional.pad(
+                torch.cumsum(
+                    (
+                        forward_batch.seq_lens + self.speculative_num_draft_tokens
+                    ).repeat_interleave(self.speculative_num_draft_tokens),
+                    dim=0,
+                ),
+                (1, 0),
+            )[:-1]
+            mask_extraction_indices = (
+                cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+                + cum_len[:, None]
+            ).view(1, -1)
+            mask = forward_batch.spec_info.custom_mask[
+                mask_extraction_indices
+            ].view(
+                -1, self.speculative_num_draft_tokens
+            )  # (bsz * draft_num, draft_num)
+
+            # shift table indices to avoid padding
+            # non_masked_page_table [[8, 9, 10],   mask (display with int format) [[1, 0, 0],
+            #                        [8, 9, 10],                                   [1, 1, 0],
+            #                        [8, 9, 10]]                                   [1, 0, 1]]
+            # if masked with padding [[8, 0, 0],   our mask without padding [[8, 9, 10],
+            #                         [8, 9, 0],                             [8, 9, 10],
+            #                         [8, 0, 10]]                            [8, 10, 9]]
+            # note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row
+            col_indices = offsets.expand(
+                mask.shape[0], self.speculative_num_draft_tokens
+            )
+            # Build keys: if an entry is valid (mask==True), keep its original index;
+            # if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
+            keys = torch.where(
+                mask, col_indices, col_indices + self.speculative_num_draft_tokens
+            )
+            _, sort_order = torch.sort(keys, dim=1)
+            non_masked_page_table = (
+                forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, :
+                ]
+                .gather(1, cols)
+                .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+            )  # (bsz, draft_num)
+            metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
+            metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
+            metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(
+                    metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
+                ),
+                (1, 0),
+            )
+            metadata_expand.max_seq_len_k = (
+                metadata_expand.cache_seqlens_int32.max().item()
+            )
+            self.forward_metadata_spec_decode_expand = metadata_expand
         elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
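The `keys`/`sort_order` trick above compacts each row of the page table so that the pages a draft token may attend to come first, with row lengths given by `mask.sum(dim=1)`; this mirrors the ASCII example in the inline comment. A self-contained toy run (values taken from that comment):

```python
import torch

draft_num = 3
mask = torch.tensor([[1, 0, 0],
                     [1, 1, 0],
                     [1, 0, 1]], dtype=torch.bool)
page_table = torch.tensor([[8, 9, 10]]).expand(3, -1)

col_indices = torch.arange(draft_num).expand(mask.shape[0], draft_num)
# Valid entries keep their column index; masked-out ones are pushed past every
# valid entry in the row by adding draft_num, so sorting moves them to the tail.
keys = torch.where(mask, col_indices, col_indices + draft_num)
_, sort_order = torch.sort(keys, dim=1)
compact = page_table.gather(1, sort_order)
print(compact.tolist())          # [[8, 9, 10], [8, 9, 10], [8, 10, 9]]
print(mask.sum(dim=1).tolist())  # [1, 2, 2] -> per-row valid lengths
```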
@@ -465,6 +623,9 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -479,11 +640,11 @@ class FlashAttentionBackend(AttentionBackend):
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
         else:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                 layer,
                 cache_loc,
                 k,
-                v,
+                k_rope,
             )

         # Use precomputed metadata across all layers
@@ -514,6 +675,11 @@ class FlashAttentionBackend(AttentionBackend):
             and (hasattr(layer, "use_irope") and layer.use_irope)
         )

+        # We do cascade attention for Target Verify with topk > 1
+        use_cascade_attn = (
+            forward_batch.forward_mode.is_target_verify() and self.topk > 1
+        )
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata
@@ -548,7 +714,7 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k = metadata.encoder_cu_seqlens_k
                 window_size = (-1, -1)

-            o = flash_attn_with_kvcache(
+            result = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 k_cache=key_cache,
                 v_cache=value_cache,
@@ -558,13 +724,41 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                 max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=causal,
+                causal=False if use_cascade_attn else causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                return_softmax_lse=use_cascade_attn,
             )
-
+
+            if use_cascade_attn:
+                o, softmax_lse, *rest = result
+                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=self.forward_metadata_spec_decode_expand.page_table,
+                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=True,
+                )
+                o, _ = merge_state_v2_wrapper(
+                    o,
+                    softmax_lse.T.contiguous(),
+                    o_expand,
+                    softmax_lse_expand.T.contiguous(),
+                )
+            else:
+                o = result
         else:
             if (
                 not global_server_args_dict["disable_chunked_prefix_cache"]
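In this cascade path the first `flash_attn_with_kvcache` call covers the shared prefix and the second covers the per-branch draft tokens; the two partial results are then merged by LSE. That the merge is exact can be checked with a small plain-PyTorch experiment (single head, unscaled logits, assumed shapes):

```python
import torch

torch.manual_seed(0)
q = torch.randn(5, 8)                              # 5 queries, head_dim 8
k_a, v_a = torch.randn(6, 8), torch.randn(6, 8)    # "prefix" keys/values
k_b, v_b = torch.randn(3, 8), torch.randn(3, 8)    # "draft" keys/values

def attn(q, k, v):
    s = q @ k.T
    return torch.softmax(s, dim=-1) @ v, torch.logsumexp(s, dim=-1)

o_a, lse_a = attn(q, k_a, v_a)
o_b, lse_b = attn(q, k_b, v_b)
w_a = torch.sigmoid(lse_a - lse_b).unsqueeze(-1)   # = e^lse_a / (e^lse_a + e^lse_b)
merged = w_a * o_a + (1 - w_a) * o_b

# Attention over the union of both key sets gives the same answer.
full, _ = attn(q, torch.cat([k_a, k_b]), torch.cat([v_a, v_b]))
assert torch.allclose(merged, full, atol=1e-5)
```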
@@ -624,10 +818,17 @@ class FlashAttentionBackend(AttentionBackend):
                 c_kv_cache = c_kv.view(
                     -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
                 )
-                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-                q_nope = q_all[:, :, : layer.v_head_dim]
-                q_rope = q_all[:, :, layer.v_head_dim :]
-                o = flash_attn_with_kvcache(
+                if q_rope is not None:
+                    q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                    q_rope = q_rope.view(
+                        -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                    )
+                else:
+                    q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                    q_nope = q_all[:, :, : layer.v_head_dim]
+                    q_rope = q_all[:, :, layer.v_head_dim :]
+
+                result = flash_attn_with_kvcache(
                     q=q_rope,
                     k_cache=k_rope_cache,
                     v_cache=c_kv_cache,
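The new `q_rope`/`k_rope` arguments let MLA callers hand in the rotary slice of the query separately instead of packing it into one tensor. The split the fallback branch performs is just a view over the last dimension, illustrated here with made-up DeepSeek-style sizes:

```python
import torch

tp_q_head_num, head_dim, v_head_dim = 16, 576, 512   # hypothetical MLA sizes
q_all = torch.randn(4, tp_q_head_num, head_dim)      # 4 tokens, packed layout
q_nope = q_all[:, :, :v_head_dim]                    # latent ("nope") component
q_rope = q_all[:, :, v_head_dim:]                    # rotary component (64 dims here)
print(q_nope.shape, q_rope.shape)  # torch.Size([4, 16, 512]) torch.Size([4, 16, 64])
```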
@@ -638,13 +839,44 @@ class FlashAttentionBackend(AttentionBackend):
                     cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                     max_seqlen_q=max_seqlen_q,
                     softmax_scale=layer.scaling,
-                    causal=…
+                    causal=False if use_cascade_attn else causal,
                     softcap=layer.logit_cap,
                     k_descale=k_descale,
                     v_descale=v_descale,
+                    return_softmax_lse=use_cascade_attn,
                 )
+                if use_cascade_attn:
+                    o, softmax_lse, *rest = result
+                    o_expand, softmax_lse_expand, *rest_expand = (
+                        flash_attn_with_kvcache(
+                            q=q_rope,
+                            k_cache=k_rope_cache,
+                            v_cache=c_kv_cache,
+                            qv=q_nope,
+                            page_table=self.forward_metadata_spec_decode_expand.page_table,
+                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+                            softmax_scale=layer.scaling,
+                            causal=False,
+                            window_size=window_size,
+                            softcap=layer.logit_cap,
+                            k_descale=k_descale,
+                            v_descale=v_descale,
+                            return_softmax_lse=True,
+                        )
+                    )
+                    o, _ = merge_state_v2_wrapper(
+                        o,
+                        softmax_lse.T.contiguous(),
+                        o_expand,
+                        softmax_lse_expand.T.contiguous(),
+                    )
+                else:
+                    o = result

-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def forward_decode(
         self,
@@ -654,6 +886,9 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -668,11 +903,11 @@ class FlashAttentionBackend(AttentionBackend):
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
         else:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                 layer,
                 cache_loc,
                 k,
-                v,
+                k_rope,
             )

         # Use precomputed metadata across all layers
|
@@ -681,6 +916,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|
681
916
|
use_local_attention = (
|
682
917
|
self.attention_chunk_size is not None and local_attn_metadata is not None
|
683
918
|
)
|
919
|
+
# We do cascade attention for Draft Decode with topk > 1
|
920
|
+
use_cascade_attn = self.topk > 1
|
684
921
|
|
685
922
|
# Calculate window size (can be moved to metadata if layer properties don't change)
|
686
923
|
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
@@ -752,23 +989,61 @@ class FlashAttentionBackend(AttentionBackend):
                     v_descale=v_descale,
                 )
             else:
+                page_table = metadata.page_table
+                cache_seqlens = metadata.cache_seqlens_int32
+                cu_seqlens_k = metadata.cu_seqlens_k
+                max_seqlen_q = metadata.max_seq_len_q
+                q_reshaped = q.contiguous().view(
+                    -1, layer.tp_q_head_num, layer.head_dim
+                )
+
                 # Default: single-token self-attention
-                o = flash_attn_with_kvcache(
-                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                result = flash_attn_with_kvcache(
+                    q=q_reshaped,
                     k_cache=key_cache,
                     v_cache=value_cache,
-                    page_table=metadata.page_table,
-                    cache_seqlens=metadata.cache_seqlens_int32,
+                    page_table=page_table,
+                    cache_seqlens=cache_seqlens,
                     cu_seqlens_q=metadata.cu_seqlens_q,
-                    cu_seqlens_k_new=metadata.cu_seqlens_k,
-                    max_seqlen_q=1,
+                    cu_seqlens_k_new=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_q,
                     softmax_scale=layer.scaling,
-                    causal=…
+                    causal=False if use_cascade_attn else causal,
                     window_size=window_size,
                     softcap=layer.logit_cap,
                     k_descale=k_descale,
                     v_descale=v_descale,
+                    return_softmax_lse=use_cascade_attn,
                 )
+                if use_cascade_attn:
+                    o, softmax_lse, *rest = result
+                    o_expand, softmax_lse_expand, *rest_expand = (
+                        flash_attn_with_kvcache(
+                            q=q_reshaped,
+                            k_cache=key_cache,
+                            v_cache=value_cache,
+                            page_table=self.forward_metadata_spec_decode_expand.page_table,
+                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+                            softmax_scale=layer.scaling,
+                            causal=False,
+                            window_size=window_size,
+                            softcap=layer.logit_cap,
+                            k_descale=k_descale,
+                            v_descale=v_descale,
+                            return_softmax_lse=True,
+                        )
+                    )
+                    o, _ = merge_state_v2(
+                        o,
+                        softmax_lse.T.contiguous(),
+                        o_expand,
+                        softmax_lse_expand.T.contiguous(),
+                    )
+                else:
+                    o = result
         else:
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -784,11 +1059,18 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )

-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
+            max_seqlen_q = metadata.max_seq_len_q

-            o = flash_attn_with_kvcache(
+            result = flash_attn_with_kvcache(
                 q=q_rope,
                 k_cache=k_rope_cache,
                 v_cache=c_kv_cache,
@@ -797,13 +1079,43 @@ class FlashAttentionBackend(AttentionBackend):
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
-                max_seqlen_q=1,
+                max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=…
+                causal=False if use_cascade_attn else causal,
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states
             )
+            if use_cascade_attn:
+                o, softmax_lse, *rest = result
+                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
+                    q=q_rope,
+                    k_cache=k_rope_cache,
+                    v_cache=c_kv_cache,
+                    qv=q_nope,
+                    page_table=self.forward_metadata_spec_decode_expand.page_table,
+                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=True,
+                )
+                o, _ = merge_state_v2(
+                    o,
+                    softmax_lse.T.contiguous(),
+                    o_expand,
+                    softmax_lse_expand.T.contiguous(),
+                )
+            else:
+                o = result

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def init_cuda_graph_state(self, max_bs: int):
@@ -815,6 +1127,8 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
+
+        # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
             "cu_seqlens_q": torch.arange(
@@ -840,24 +1154,136 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }

-        … (18 removed lines: the previous unconditional target_verify_metadata allocation; its exact text is not preserved in this rendering)
+        # This is used by draft decode's first half of metadata when topk > 1
+        if self.topk > 1:
+            self.draft_decode_metadata_topk_normal = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.topk + 1,
+                    step=self.topk,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    self.max_context_len,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
+            # This is used by draft decode's second half of metadata when topk > 1
+            decode_length = self.speculative_step_id + 1
+            self.draft_decode_metadata_topk_expand = {
+                "cache_seqlens": torch.full(
+                    (max_bs * self.topk,),
+                    decode_length,
+                    device=self.device,
+                    dtype=torch.int32,
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.topk + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.arange(
+                    0,
+                    max_bs * self.topk * decode_length + 1,
+                    step=decode_length,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "page_table": torch.zeros(
+                    max_bs * self.topk,
+                    decode_length,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
+        if (
+            self.speculative_num_draft_tokens is not None
+            and self.speculative_num_draft_tokens > 0
+        ):
+            self.target_verify_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    step=self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    (self.max_context_len + self.page_size - 1) // self.page_size,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
+            if self.topk > 1:
+                self.target_verify_metadata_topk_normal = {
+                    "cache_seqlens": torch.zeros(
+                        max_bs, dtype=torch.int32, device=self.device
+                    ),
+                    "cu_seqlens_q": torch.arange(
+                        0,
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        step=self.speculative_num_draft_tokens,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_k": torch.zeros(
+                        max_bs + 1, dtype=torch.int32, device=self.device
+                    ),
+                    "page_table": torch.zeros(
+                        max_bs,
+                        self.max_context_len,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                }
+
+                self.target_verify_metadata_topk_expand = {
+                    "cache_seqlens": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_k": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_q": torch.arange(
+                        0,
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "page_table": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens,
+                        self.speculative_num_draft_tokens,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                }

         self.encoder_metadata = {
             "encoder_page_table": torch.zeros(
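All of these dictionaries follow the usual CUDA-graph discipline: every tensor is allocated once at `max_bs` (or `max_bs * topk`) capacity, and the capture/replay code only ever slices into them and updates them with `copy_`, so no allocation happens inside the captured graph. In miniature (CPU tensors and hypothetical names, just to show the pattern):

```python
import torch

max_bs = 8
buffers = {"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32)}

bs = 3
captured = buffers["cache_seqlens"][:bs]   # view over the fixed buffer, captured once
captured.copy_(torch.tensor([5, 7, 2], dtype=torch.int32))  # replay: in-place only
assert buffers["cache_seqlens"][:bs].tolist() == [5, 7, 2]
```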
@@ -886,28 +1312,78 @@ class FlashAttentionBackend(AttentionBackend):
     ):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
+
+        # metadata_expand is needed for Spec Decoding when top k > 1
+        metadata_expand = FlashAttentionMetadata()
+
         device = seq_lens.device
         if forward_mode.is_decode_or_idle():
             if spec_info is not None:
                 # Draft Decode
-                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
-                    "cache_seqlens"
-                ][:bs]
-                metadata.max_seq_len_k = seq_lens.max().item() + (
-                    self.speculative_step_id + 1
-                )
-                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
-                    "cu_seqlens_q"
-                ][: bs + 1]
-                metadata.cu_seqlens_k = torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
-                metadata.page_table = self.decode_cuda_graph_metadata[
-                    "page_table_draft_decode"
-                ][req_pool_indices, :]
+                if self.topk <= 1:
+                    # When topk = 1, we use the normal decode metadata
+                    metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                        "cache_seqlens"
+                    ][:bs]
+                    metadata.max_seq_len_k = seq_lens.max().item() + (
+                        self.speculative_step_id + 1
+                    )
+                    metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
+                        "cu_seqlens_q"
+                    ][: bs + 1]
+                    metadata.cu_seqlens_k = torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                    metadata.page_table = self.decode_cuda_graph_metadata[
+                        "page_table_draft_decode"
+                    ][req_pool_indices, :]
+                    self.decode_cuda_graph_metadata[bs] = metadata
+                else:
+                    # When top k > 1, we need two specific draft decode metadata, and then merge states
+                    # 1. The first half of metadata for prefix tokens
+                    metadata.cache_seqlens_int32 = (
+                        self.draft_decode_metadata_topk_normal["cache_seqlens"][:bs]
+                    )
+                    metadata.max_seq_len_q = self.topk
+                    metadata.max_seq_len_k = seq_lens.max().item()
+                    metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[
+                        "cu_seqlens_q"
+                    ][: bs + 1]
+                    metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[
+                        "cu_seqlens_k"
+                    ][: bs + 1]
+                    metadata.page_table = self.draft_decode_metadata_topk_normal[
+                        "page_table"
+                    ][req_pool_indices, :]
+
+                    # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+                    metadata_expand.cache_seqlens_int32 = (
+                        self.draft_decode_metadata_topk_expand["cache_seqlens"][
+                            : bs * self.topk
+                        ]
+                    )
+                    metadata_expand.max_seq_len_q = 1
+                    metadata_expand.max_seq_len_k = (
+                        self.speculative_step_id + 1
+                    )  # , do this in replay
+                    metadata_expand.cu_seqlens_q = (
+                        self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
+                            : bs * self.topk + 1
+                        ]
+                    )
+                    metadata_expand.cu_seqlens_k = (
+                        self.draft_decode_metadata_topk_expand["cu_seqlens_k"][
+                            : bs * self.topk + 1
+                        ]
+                    )
+                    metadata_expand.page_table = self.draft_decode_metadata_topk_expand[
+                        "page_table"
+                    ][: bs * self.topk]
+                    self.draft_decode_metadata_topk_normal[bs] = metadata
+                    self.draft_decode_metadata_topk_expand[bs] = metadata_expand
             else:
                 # Normal Decode
                 # Get sequence information
@@ -927,37 +1403,77 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.cu_seqlens_q = torch.arange(
                     0, batch_size + 1, dtype=torch.int32, device=device
                 )
-            self.decode_cuda_graph_metadata[bs] = metadata
+                self.decode_cuda_graph_metadata[bs] = metadata
+
         elif forward_mode.is_target_verify():
-            metadata.cache_seqlens_int32 = self.target_verify_metadata[
-                "cache_seqlens"
-            ][:bs]
-            metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
-            )
+            if self.topk <= 1:
+                metadata.cache_seqlens_int32 = self.target_verify_metadata[
+                    "cache_seqlens"
+                ][:bs]
+                metadata.cache_seqlens_int32.copy_(
+                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                )

-            metadata.max_seq_len_q = self.speculative_num_draft_tokens
-            metadata.max_seq_len_k = (
-                seq_lens.max().item() + self.speculative_num_draft_tokens
-            )
+                metadata.max_seq_len_q = self.speculative_num_draft_tokens
+                metadata.max_seq_len_k = (
+                    seq_lens.max().item() + self.speculative_num_draft_tokens
+                )

-            metadata.cu_seqlens_q = torch.arange(
-                0,
-                bs * self.speculative_num_draft_tokens + 1,
-                self.speculative_num_draft_tokens,
-                dtype=torch.int32,
-                device=device,
-            )
+                metadata.cu_seqlens_q = torch.arange(
+                    0,
+                    bs * self.speculative_num_draft_tokens + 1,
+                    self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=device,
+                )

-            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
-                : (bs + 1)
-            ]
+                metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                    : (bs + 1)
+                ]

-            metadata.page_table = self.target_verify_metadata["page_table"][
-                req_pool_indices, :
-            ]
+                metadata.page_table = self.target_verify_metadata["page_table"][
+                    req_pool_indices, :
+                ]
+
+                self.target_verify_metadata[bs] = metadata
+            else:
+                # When topk > 1, we need two specific target verify metadata, and then merge states
+                # 1. The first half of metadata for prefix tokens
+                metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[
+                    "cache_seqlens"
+                ][:bs]
+                metadata.max_seq_len_q = self.speculative_num_draft_tokens
+                # metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay
+                metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[
+                    "cu_seqlens_q"
+                ][: bs + 1]
+                metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[
+                    "cu_seqlens_k"
+                ][: bs + 1]
+                metadata.page_table = self.target_verify_metadata_topk_normal[
+                    "page_table"
+                ][req_pool_indices, :]
+
+                # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+                metadata_expand.cache_seqlens_int32 = (
+                    self.target_verify_metadata_topk_expand["cache_seqlens"][
+                        : bs * self.speculative_num_draft_tokens
+                    ]
+                )
+                metadata_expand.max_seq_len_q = 1
+                metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[
+                    "cu_seqlens_q"
+                ][: bs * self.speculative_num_draft_tokens + 1]
+                metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[
+                    "cu_seqlens_k"
+                ][: bs * self.speculative_num_draft_tokens + 1]
+
+                metadata_expand.page_table = self.target_verify_metadata_topk_expand[
+                    "page_table"
+                ][: bs * self.speculative_num_draft_tokens]

-            self.target_verify_metadata[bs] = metadata
+                self.target_verify_metadata_topk_normal[bs] = metadata
+                self.target_verify_metadata_topk_expand[bs] = metadata_expand

         if encoder_lens is not None:
             encoder_bs = encoder_lens.numel()
@@ -973,6 +1489,7 @@ class FlashAttentionBackend(AttentionBackend):
             ]

         self.forward_metadata = metadata
+        self.forward_metadata_spec_decode_expand = metadata_expand

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -986,41 +1503,85 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_cpu: Optional[torch.Tensor],
         out_cache_loc: torch.Tensor = None,
     ):
-
+        """Initialize forward metadata for replaying CUDA graph."""
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
         device = seq_lens.device
+        metadata = None
+        metadata_expand = None

         if forward_mode.is_decode_or_idle():
-            metadata = self.decode_cuda_graph_metadata[bs]

             if spec_info is not None:
                 # Draft Decode
-                metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
-                )
-
-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-                    self.speculative_step_id + 1
-                )
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
-                    )
-                )
-
-                page_table = self.req_to_token[
-                    req_pool_indices, : metadata.max_seq_len_k
-                ]
-
-                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+                if self.topk <= 1:
+                    metadata = self.decode_cuda_graph_metadata[bs]
+                    # When topk = 1, we use the normal decode metadata
+                    metadata.cache_seqlens_int32.copy_(
+                        (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
+                    )
+
+                    metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+                        self.speculative_step_id + 1
+                    )
+                    metadata.cu_seqlens_k.copy_(
+                        torch.nn.functional.pad(
+                            torch.cumsum(
+                                metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                            ),
+                            (1, 0),
+                        )
+                    )
+
+                    max_seq_pages = (
+                        metadata.max_seq_len_k + self.page_size - 1
+                    ) // self.page_size
+                    page_indices = self.req_to_token[
+                        req_pool_indices[:, None],
+                        self.decode_cuda_graph_metadata["strided_indices"][
+                            :max_seq_pages
+                        ],
+                    ]
+
+                    page_indices //= self.page_size
+                    metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                else:
+                    # When top k > 1, we need two specific draft decode metadata, and then merge states
+                    # 1. The first half of metadata for prefix tokens
+                    metadata = self.draft_decode_metadata_topk_normal[bs]
+                    metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                    # metadata.max_seq_len_q = self.topk, already set in capture
+                    metadata.max_seq_len_k = seq_lens_cpu.max().item()
+                    # metadata.cu_seqlens_q already set in capture
+                    metadata.cu_seqlens_k.copy_(
+                        torch.nn.functional.pad(
+                            torch.cumsum(
+                                metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                            ),
+                            (1, 0),
+                        )
+                    )

+                    page_table = self.req_to_token[
+                        req_pool_indices, : metadata.max_seq_len_k
+                    ]
+
+                    metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+                    # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+                    metadata_expand = self.draft_decode_metadata_topk_expand[bs]
+                    decode_length = self.speculative_step_id + 1
+                    cache_loc = out_cache_loc.view(
+                        self.speculative_num_steps, -1
+                    ).T.contiguous()
+                    metadata_expand.page_table[: cache_loc.shape[0]].copy_(
+                        cache_loc[:, :decode_length].contiguous().to(torch.int32)
+                    )
+                # TODO: we need to test this part for llama 4 eagle case
                 self._init_local_attn_metadata(metadata, device)
             else:
+                metadata = self.decode_cuda_graph_metadata[bs]
                 # Normal Decode
                 max_len = seq_lens_cpu.max().item()
                 metadata.max_seq_len_k = max_len
@@ -1045,24 +1606,117 @@ class FlashAttentionBackend(AttentionBackend):

             self._init_local_attn_metadata(metadata, device)
         elif forward_mode.is_target_verify():
-            metadata = self.target_verify_metadata[bs]
-            metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
-            )
-
-            metadata.max_seq_len_k = (
-                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
-            )
-            metadata.cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
-            )
-            page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
-            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+            if self.topk <= 1:
+                metadata = self.target_verify_metadata[bs]
+                metadata.cache_seqlens_int32.copy_(
+                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                )
+
+                metadata.max_seq_len_k = (
+                    seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+                )
+                metadata.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                )
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+                page_indices = self.req_to_token[
+                    req_pool_indices[:, None],
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+                ]
+                page_indices //= self.page_size
+                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+            else:
+                # When topk > 1, we need two specific target verify metadata, and then merge states
+                # 1. The first half of metadata for prefix tokens
+                metadata = self.target_verify_metadata_topk_normal[bs]
+                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
+                metadata.max_seq_len_k = seq_lens_cpu.max().item()
+                # metadata.cu_seqlens_q already set in capture
+                metadata.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                )
+                page_table = self.req_to_token[
+                    req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+                # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+                metadata_expand = self.target_verify_metadata_topk_expand[bs]
+                # metadata_expand.max_seq_len_q = 1, already set in capture
+                # metadata_expand.cu_seqlens_q already set in capture
+
+                offsets = torch.arange(
+                    self.speculative_num_draft_tokens, device=device
+                ).unsqueeze(
+                    0
+                )  # shape: (1, self.speculative_num_draft_tokens)
+                cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
+                cum_len = torch.nn.functional.pad(
+                    torch.cumsum(
+                        (
+                            seq_lens + self.speculative_num_draft_tokens
+                        ).repeat_interleave(self.speculative_num_draft_tokens),
+                        dim=0,
+                    ),
+                    (1, 0),
+                )[:-1]
+                mask_extraction_indices = (
+                    cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+                    + cum_len[:, None]
+                ).view(1, -1)
+                # avoid extracting padded seq indices which will be out of boundary
+                mask_extraction_indices[
+                    :, spec_info.positions.numel() * self.speculative_num_draft_tokens :
+                ].fill_(0)
+
+                mask = spec_info.custom_mask[mask_extraction_indices].view(
+                    -1, self.speculative_num_draft_tokens
+                )  # (bsz * draft_num, draft_num)
+                col_indices = offsets.expand(
+                    mask.shape[0], self.speculative_num_draft_tokens
+                )
+                keys = torch.where(
+                    mask, col_indices, col_indices + self.speculative_num_draft_tokens
+                )
+                _, sort_order = torch.sort(keys, dim=1)
+
+                non_masked_page_table = (
+                    self.req_to_token[req_pool_indices, :]
+                    .gather(1, cols)
+                    .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+                )  # (bsz, draft_num)
+                metadata_expand.page_table.copy_(
+                    non_masked_page_table.gather(1, sort_order)
+                )
+                metadata_expand.cache_seqlens_int32.copy_(
+                    mask.sum(dim=1).to(torch.int32)
+                )
+                metadata_expand.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata_expand.cache_seqlens_int32,
+                            dim=0,
+                            dtype=torch.int32,
+                        ),
+                        (1, 0),
+                    )
+                )
+                metadata_expand.max_seq_len_k = (
+                    metadata_expand.cache_seqlens_int32.max().item()
+                )

         if encoder_lens is not None:
             # Only support encoder size 1 for now
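The replay path recomputes `mask_extraction_indices` because `custom_mask` is a flat buffer: each request stores a `draft_num x (seq_len + draft_num)` block end to end, and `cum_len` plus `cols` locate the draft-token columns of every row inside it. A worked toy example with hypothetical sizes:

```python
import torch

draft_num = 2
seq_lens = torch.tensor([3, 5])  # two requests; flat mask blocks of 2x5 and 2x7

offsets = torch.arange(draft_num).unsqueeze(0)               # (1, draft_num)
cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
row_width = (seq_lens + draft_num).repeat_interleave(draft_num)
cum_len = torch.nn.functional.pad(torch.cumsum(row_width, dim=0), (1, 0))[:-1]
idx = (cols.repeat_interleave(draft_num, dim=0) + cum_len[:, None]).view(1, -1)
print(idx.tolist())  # [[3, 4, 8, 9, 15, 16, 22, 23]]
# i.e. rows 0/1 of request 0 read flat positions 3-4 and 8-9 of its 2x5 block,
# and rows 0/1 of request 1 read positions 15-16 and 22-23 of its 2x7 block.
```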
@@ -1089,6 +1743,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

         self.forward_metadata = metadata
+        self.forward_metadata_spec_decode_expand = metadata_expand

     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
@@ -1139,12 +1794,6 @@ class FlashAttentionMultiStepBackend:
         self.model_runner = model_runner
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
-
-        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
-        assert (
-            self.topk == 1
-        ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
-
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(