sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+from sgl_kernel import merge_state_v2
 from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 
 
@@ -30,7 +31,7 @@ class FlashAttentionMetadata:
     # Sequence lengths for the forward batch
    cache_seqlens_int32: torch.Tensor = None
     # Maximum sequence length for query
-    max_seq_len_q: int = 0
+    max_seq_len_q: int = 1
     # Maximum sequence length for key
    max_seq_len_k: int = 0
     # Cumulative sequence lengths for query
@@ -142,6 +143,16 @@ def make_local_attention_virtual_batches(
         seqlens_k_local: Key sequence lengths for local attention
         block_table_local: Block table for local attention
     """
+    # Adjust attention_chunk_size based on the actual sequence length
+    # to avoid index out of bounds errors
+    max_seq_len = seq_lens_np.max()
+    effective_chunk_size = min(attn_chunk_size, max_seq_len)
+    # Make sure effective_chunk_size is divisible by page_size
+    effective_chunk_size = (effective_chunk_size // page_size) * page_size
+    if effective_chunk_size < page_size:
+        effective_chunk_size = page_size
+    attn_chunk_size = effective_chunk_size
+
     q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
     actual_batch_size = seq_lens_np.shape[0]
 
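Note: the chunk-size clamping added above is self-contained arithmetic. A minimal standalone sketch of the same logic, with hypothetical input values (not part of the package):

import numpy as np

def effective_chunk(attn_chunk_size: int, seq_lens_np: np.ndarray, page_size: int) -> int:
    # Clamp the chunk to the longest sequence, round down to a page multiple,
    # and never go below a single page.
    size = min(attn_chunk_size, int(seq_lens_np.max()))
    size = (size // page_size) * page_size
    return max(size, page_size)

print(effective_chunk(8192, np.array([300, 117]), 64))  # -> 256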
@@ -257,6 +268,12 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)
 
 
+# TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue
+@torch._dynamo.disable()
+def merge_state_v2_wrapper(o, s_a, o_exp, s_b):
+    return merge_state_v2(o, s_a, o_exp, s_b)
+
+
 class FlashAttentionBackend(AttentionBackend):
     """FlashAttention backend implementation.
 
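Note: merge_state_v2 (imported above from sgl_kernel) combines two partial attention results using their log-sum-exp (LSE) values; the @torch._dynamo.disable() wrapper keeps torch.compile from tracing into the custom kernel, per the TODO. A pure-PyTorch sketch of the underlying math follows; shapes and argument order are illustrative assumptions, not the sgl_kernel API contract:

import torch

def merge_states_reference(o_a, lse_a, o_b, lse_b):
    # o_*: (tokens, heads, head_dim); lse_*: (tokens, heads)
    m = torch.maximum(lse_a, lse_b)        # stabilize the exponentials
    w_a = torch.exp(lse_a - m)
    w_b = torch.exp(lse_b - m)
    o = (o_a * w_a.unsqueeze(-1) + o_b * w_b.unsqueeze(-1)) / (w_a + w_b).unsqueeze(-1)
    lse = m + torch.log(w_a + w_b)         # log-sum-exp over the union of keys
    return o, lse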
@@ -291,6 +308,8 @@ class FlashAttentionBackend(AttentionBackend):
         ), "Sliding window and cross attention are not supported together"
 
         self.forward_metadata: FlashAttentionMetadata = None
+        # extra metadata for handling speculative decoding topk > 1, extended draft decode and verify
+        self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
@@ -299,12 +318,9 @@ class FlashAttentionBackend(AttentionBackend):
         self.kv_cache_dtype = model_runner.kv_cache_dtype
         self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
-        self.use_mla = (
-            model_runner.model_config.attention_arch == AttentionArch.MLA
-        ) and (not global_server_args_dict["disable_mla"])
+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
-
-        self.topk = topk
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
         self.speculative_num_steps = speculative_num_steps
         self.speculative_num_draft_tokens = (
             model_runner.server_args.speculative_num_draft_tokens
@@ -328,14 +344,107 @@ class FlashAttentionBackend(AttentionBackend):
         if forward_batch.forward_mode.is_decode_or_idle():
             # Draft Decode
             if forward_batch.spec_info is not None:
+                if self.topk <= 1:
+                    metadata.cache_seqlens_int32 = (
+                        seqlens_in_batch + (self.speculative_step_id + 1)
+                    ).to(torch.int32)
+                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                        self.speculative_step_id + 1
+                    )
+                    metadata.cu_seqlens_q = torch.arange(
+                        0, batch_size + 1, dtype=torch.int32, device=device
+                    )
+                    metadata.cu_seqlens_k = torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                        forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                    ]
+                else:
+                    metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
+                    metadata.max_seq_len_q = self.topk
+                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                    metadata.cu_seqlens_q = torch.arange(
+                        0,
+                        batch_size * self.topk + 1,
+                        step=self.topk,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    metadata.cu_seqlens_k = torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                        forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                    ]
+
+                    metadata_expand = FlashAttentionMetadata()
+                    decode_length = self.speculative_step_id + 1
+                    metadata_expand.cache_seqlens_int32 = torch.full(
+                        (seqlens_in_batch.numel() * self.topk,),
+                        decode_length,
+                        device=device,
+                        dtype=torch.int32,
+                    )
+                    metadata_expand.max_seq_len_q = 1
+                    metadata_expand.max_seq_len_k = self.speculative_step_id + 1
+                    metadata_expand.cu_seqlens_q = torch.arange(
+                        0,
+                        metadata_expand.cache_seqlens_int32.numel() + 1,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    metadata_expand.cu_seqlens_k = torch.arange(
+                        0,
+                        metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
+                        step=decode_length,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    cache_loc = forward_batch.out_cache_loc.view(
+                        self.speculative_num_steps, -1
+                    ).T.contiguous()
+                    metadata_expand.page_table = (
+                        cache_loc[:, :decode_length].contiguous().to(torch.int32)
+                    )
+                    self.forward_metadata_spec_decode_expand = metadata_expand
+            else:
+                # Normal Decode
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                # TODO: we need to test this part for llama 4 eagle case
+                self._init_local_attn_metadata(metadata, device)
+        elif forward_batch.forward_mode.is_target_verify():
+            if self.topk <= 1:
                 metadata.cache_seqlens_int32 = (
-                    seqlens_in_batch + (self.speculative_step_id + 1)
+                    forward_batch.seq_lens + self.speculative_num_draft_tokens
                 ).to(torch.int32)
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
-                    self.speculative_step_id + 1
+                metadata.max_seq_len_q = self.speculative_num_draft_tokens
+                metadata.max_seq_len_k = (
+                    forward_batch.seq_lens_cpu.max().item()
+                    + self.speculative_num_draft_tokens
                 )
                 metadata.cu_seqlens_q = torch.arange(
-                    0, batch_size + 1, dtype=torch.int32, device=device
+                    0,
+                    batch_size * self.speculative_num_draft_tokens + 1,
+                    self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=device,
                 )
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
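Note: in the topk > 1 draft-decode metadata above, each request contributes topk query rows that all share one KV prefix, so cu_seqlens_q is a strided arange. A quick check with hypothetical sizes:

import torch

batch_size, topk = 2, 3
cu_seqlens_q = torch.arange(0, batch_size * topk + 1, step=topk, dtype=torch.int32)
print(cu_seqlens_q)  # tensor([0, 3, 6], dtype=torch.int32)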
@@ -346,43 +455,104 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
+
+                self._init_local_attn_metadata(metadata, device)
             else:
-                # Normal Decode
-                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
+                metadata.max_seq_len_q = self.speculative_num_draft_tokens
                 metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                 metadata.cu_seqlens_q = torch.arange(
-                    0, batch_size + 1, dtype=torch.int32, device=device
+                    0,
+                    batch_size * self.speculative_num_draft_tokens + 1,
+                    step=self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=device,
                 )
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
-                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
                 )
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-        elif forward_batch.forward_mode.is_target_verify():
-            metadata.cache_seqlens_int32 = (
-                forward_batch.seq_lens + self.speculative_num_draft_tokens
-            ).to(torch.int32)
-            metadata.max_seq_len_q = self.speculative_num_draft_tokens
-            metadata.max_seq_len_k = (
-                forward_batch.seq_lens_cpu.max().item()
-                + self.speculative_num_draft_tokens
-            )
-            metadata.cu_seqlens_q = torch.arange(
-                0,
-                batch_size * self.speculative_num_draft_tokens + 1,
-                self.speculative_num_draft_tokens,
-                dtype=torch.int32,
-                device=device,
-            )
-            metadata.cu_seqlens_k = torch.nn.functional.pad(
-                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
-                (1, 0),
-            )
-            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
-                forward_batch.req_pool_indices, : metadata.max_seq_len_k
-            ]
 
+                metadata_expand = FlashAttentionMetadata()
+
+                metadata_expand.max_seq_len_q = 1
+                metadata_expand.cu_seqlens_q = torch.arange(
+                    0,
+                    forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
+                    + 1,
+                    dtype=torch.int32,
+                    device=device,
+                )
+
+                # create expand page table
+                offsets = torch.arange(
+                    self.speculative_num_draft_tokens, device=device
+                ).unsqueeze(
+                    0
+                )  # shape: (1, self.speculative_num_draft_tokens)
+                cols = offsets.expand(
+                    forward_batch.seq_lens.numel(), -1
+                ) + forward_batch.seq_lens.unsqueeze(1)
+                cum_len = torch.nn.functional.pad(
+                    torch.cumsum(
+                        (
+                            forward_batch.seq_lens + self.speculative_num_draft_tokens
+                        ).repeat_interleave(self.speculative_num_draft_tokens),
+                        dim=0,
+                    ),
+                    (1, 0),
+                )[:-1]
+                mask_extraction_indices = (
+                    cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+                    + cum_len[:, None]
+                ).view(1, -1)
+                mask = forward_batch.spec_info.custom_mask[
+                    mask_extraction_indices
+                ].view(
+                    -1, self.speculative_num_draft_tokens
+                )  # (bsz * draft_num, draft_num)
+
+                # shift table indices to avoid padding
+                # non_masked_page_table [[8, 9, 10], mask (display with int format) [[1, 0, 0],
+                #                        [8, 9, 10],                                 [1, 1, 0],
+                #                        [8, 9, 10]]                                 [1, 0, 1]]
+                # if masked with padding [[8, 0, 0], our mask without padding [[8, 9, 10],
+                #                         [8, 9, 0],                           [8, 9, 10],
+                #                         [8, 0, 10]]                          [8, 10, 9]]
+                # note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row
+                col_indices = offsets.expand(
+                    mask.shape[0], self.speculative_num_draft_tokens
+                )
+                # Build keys: if an entry is valid (mask==True), keep its original index;
+                # if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
+                keys = torch.where(
+                    mask, col_indices, col_indices + self.speculative_num_draft_tokens
+                )
+                _, sort_order = torch.sort(keys, dim=1)
+                non_masked_page_table = (
+                    forward_batch.req_to_token_pool.req_to_token[
+                        forward_batch.req_pool_indices, :
+                    ]
+                    .gather(1, cols)
+                    .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+                )  # (bsz, draft_num)
+                metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
+                metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
+                metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata_expand.max_seq_len_k = (
+                    metadata_expand.cache_seqlens_int32.max().item()
+                )
+                self.forward_metadata_spec_decode_expand = metadata_expand
         elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
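Note: the "shift table indices to avoid padding" comment above describes a sort-based compaction: invalid entries get their column index bumped past all valid ones, so a row-wise sort moves valid pages to the front without reordering them. A standalone sketch reproducing the commented example:

import torch

draft_num = 3
mask = torch.tensor([[True, False, False],
                     [True, True, False],
                     [True, False, True]])
page_rows = torch.tensor([[8, 9, 10]] * 3)

col = torch.arange(draft_num).expand_as(mask)
keys = torch.where(mask, col, col + draft_num)  # invalid entries sort last
_, order = torch.sort(keys, dim=1)
print(page_rows.gather(1, order))  # [[8, 9, 10], [8, 9, 10], [8, 10, 9]]
print(mask.sum(dim=1))             # lengths [1, 2, 2]; trailing pages are ignored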
@@ -407,49 +577,8 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cu_seqlens_q = metadata.cu_seqlens_k
 
             # Setup local attention if enabled
-            if (
-                self.attention_chunk_size is not None
-                and forward_batch.forward_mode == ForwardMode.EXTEND
-            ):
-                # Convert tensors to numpy for local attention processing
-                cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
-                seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
-
-                # Adjust attention_chunk_size based on the actual sequence length
-                # to avoid index out of bounds errors
-                max_seq_len = seq_lens_np.max()
-                effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
-                # Make sure effective_chunk_size is divisible by page_size
-                effective_chunk_size = (
-                    effective_chunk_size // self.page_size
-                ) * self.page_size
-                if effective_chunk_size < self.page_size:
-                    effective_chunk_size = self.page_size
-
-                # Create local attention metadata
-                (
-                    seqlens_q_local_np,
-                    cu_seqlens_q_local_np,
-                    seqlens_k_local_np,
-                    block_table_local,
-                ) = make_local_attention_virtual_batches(
-                    effective_chunk_size,
-                    cu_seqlens_q_np,
-                    seq_lens_np,
-                    metadata.page_table,
-                    self.page_size,
-                )
-
-                local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
-                        device
-                    ),
-                    local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
-                    local_block_table=block_table_local,
-                    local_max_query_len=seqlens_q_local_np.max(),
-                    local_max_seq_len=seqlens_k_local_np.max(),
-                )
-                metadata.local_attn_metadata = local_metadata
+            if forward_batch.forward_mode == ForwardMode.EXTEND:
+                self._init_local_attn_metadata(metadata, device)
 
         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
@@ -543,6 +672,11 @@ class FlashAttentionBackend(AttentionBackend):
             and (hasattr(layer, "use_irope") and layer.use_irope)
         )
 
+        # We do cascade attention for Target Verify with topk > 1
+        use_cascade_attn = (
+            forward_batch.forward_mode.is_target_verify() and self.topk > 1
+        )
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata
@@ -577,7 +711,7 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k = metadata.encoder_cu_seqlens_k
                 window_size = (-1, -1)
 
-            o = flash_attn_with_kvcache(
+            result = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 k_cache=key_cache,
                 v_cache=value_cache,
@@ -587,13 +721,41 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                 max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=causal,
+                causal=False if use_cascade_attn else causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                return_softmax_lse=use_cascade_attn,
             )
 
+            if use_cascade_attn:
+                o, softmax_lse, *rest = result
+                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=self.forward_metadata_spec_decode_expand.page_table,
+                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=True,
+                )
+                o, _ = merge_state_v2_wrapper(
+                    o,
+                    softmax_lse.T.contiguous(),
+                    o_expand,
+                    softmax_lse_expand.T.contiguous(),
+                )
+            else:
+                o = result
         else:
             if (
                 not global_server_args_dict["disable_chunked_prefix_cache"]
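Note: the cascade path above runs two non-causal passes (one over the shared prefix KV, one over the expanded draft KV) and merges them with merge_state_v2_wrapper. The identity it relies on — attention over a key set equals the LSE-weighted merge of attention over any partition of that set — can be checked in plain PyTorch (toy shapes, single query; everything here is an illustrative assumption):

import torch

def attend(q, k, v):
    s = q @ k.T                        # (1, n) scores for one query
    return torch.softmax(s, -1) @ v, torch.logsumexp(s, -1)

torch.manual_seed(0)
q, k, v = torch.randn(1, 8), torch.randn(6, 8), torch.randn(6, 8)
o_full, _ = attend(q, k, v)
o_a, lse_a = attend(q, k[:4], v[:4])   # "prefix" keys
o_b, lse_b = attend(q, k[4:], v[4:])   # "draft" keys
w_a, w_b = torch.exp(lse_a), torch.exp(lse_b)
print(torch.allclose(o_full, (o_a * w_a + o_b * w_b) / (w_a + w_b), atol=1e-6))  # True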
@@ -656,7 +818,8 @@ class FlashAttentionBackend(AttentionBackend):
                 q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
                 q_nope = q_all[:, :, : layer.v_head_dim]
                 q_rope = q_all[:, :, layer.v_head_dim :]
-                o = flash_attn_with_kvcache(
+
+                result = flash_attn_with_kvcache(
                     q=q_rope,
                     k_cache=k_rope_cache,
                     v_cache=c_kv_cache,
@@ -667,13 +830,44 @@ class FlashAttentionBackend(AttentionBackend):
                     cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                     max_seqlen_q=max_seqlen_q,
                     softmax_scale=layer.scaling,
-                    causal=True,
+                    causal=False if use_cascade_attn else causal,
                     softcap=layer.logit_cap,
                     k_descale=k_descale,
                     v_descale=v_descale,
+                    return_softmax_lse=use_cascade_attn,
                 )
+                if use_cascade_attn:
+                    o, softmax_lse, *rest = result
+                    o_expand, softmax_lse_expand, *rest_expand = (
+                        flash_attn_with_kvcache(
+                            q=q_rope,
+                            k_cache=k_rope_cache,
+                            v_cache=c_kv_cache,
+                            qv=q_nope,
+                            page_table=self.forward_metadata_spec_decode_expand.page_table,
+                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+                            softmax_scale=layer.scaling,
+                            causal=False,
+                            window_size=window_size,
+                            softcap=layer.logit_cap,
+                            k_descale=k_descale,
+                            v_descale=v_descale,
+                            return_softmax_lse=True,
+                        )
+                    )
+                    o, _ = merge_state_v2_wrapper(
+                        o,
+                        softmax_lse.T.contiguous(),
+                        o_expand,
+                        softmax_lse_expand.T.contiguous(),
+                    )
+                else:
+                    o = result
 
-            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def forward_decode(
         self,
@@ -706,6 +900,12 @@ class FlashAttentionBackend(AttentionBackend):
 
         # Use precomputed metadata across all layers
         metadata = self.forward_metadata
+        local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
+        use_local_attention = (
+            self.attention_chunk_size is not None and local_attn_metadata is not None
+        )
+        # We do cascade attention for Draft Decode with topk > 1
+        use_cascade_attn = self.topk > 1
 
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
@@ -740,33 +940,98 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
 
-            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             if layer.is_cross_attention:
-                page_table = metadata.encoder_page_table
-                cache_seqlens = metadata.encoder_lens_int32
-                cu_seqlens_k = metadata.encoder_cu_seqlens_k
-                window_size = (-1, -1)
+                # Always use non-chunked logic for cross-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=metadata.encoder_page_table,
+                    cache_seqlens=metadata.encoder_lens_int32,
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
+                    max_seqlen_q=1,
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    window_size=(-1, -1),
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
+            elif use_local_attention:
+                # Use chunked (local) attention batching for self-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=local_attn_metadata.local_block_table,
+                    cache_seqlens=local_attn_metadata.local_seqused_k,
+                    cu_seqlens_q=local_attn_metadata.local_query_start_loc,
+                    cu_seqlens_k_new=metadata.cu_seqlens_k,
+                    max_seqlen_q=local_attn_metadata.local_max_query_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=(-1, -1),
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
             else:
                 page_table = metadata.page_table
                 cache_seqlens = metadata.cache_seqlens_int32
                 cu_seqlens_k = metadata.cu_seqlens_k
+                max_seqlen_q = metadata.max_seq_len_q
+                q_reshaped = q.contiguous().view(
+                    -1, layer.tp_q_head_num, layer.head_dim
+                )
 
-            o = flash_attn_with_kvcache(
-                q=q_reshaped,
-                k_cache=key_cache,
-                v_cache=value_cache,
-                page_table=page_table,
-                cache_seqlens=cache_seqlens,
-                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=cu_seqlens_k,
-                max_seqlen_q=1,
-                softmax_scale=layer.scaling,
-                causal=True,
-                window_size=window_size,
-                softcap=layer.logit_cap,
-                k_descale=k_descale,
-                v_descale=v_descale,
-            )
+                # Default: single-token self-attention
+                result = flash_attn_with_kvcache(
+                    q=q_reshaped,
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=page_table,
+                    cache_seqlens=cache_seqlens,
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k_new=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_q,
+                    softmax_scale=layer.scaling,
+                    causal=False if use_cascade_attn else causal,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=use_cascade_attn,
+                )
+                if use_cascade_attn:
+                    o, softmax_lse, *rest = result
+                    o_expand, softmax_lse_expand, *rest_expand = (
+                        flash_attn_with_kvcache(
+                            q=q_reshaped,
+                            k_cache=key_cache,
+                            v_cache=value_cache,
+                            page_table=self.forward_metadata_spec_decode_expand.page_table,
+                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+                            softmax_scale=layer.scaling,
+                            causal=False,
+                            window_size=window_size,
+                            softcap=layer.logit_cap,
+                            k_descale=k_descale,
+                            v_descale=v_descale,
+                            return_softmax_lse=True,
+                        )
+                    )
+                    o, _ = merge_state_v2(
+                        o,
+                        softmax_lse.T.contiguous(),
+                        o_expand,
+                        softmax_lse_expand.T.contiguous(),
+                    )
+                else:
+                    o = result
         else:
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -785,8 +1050,9 @@ class FlashAttentionBackend(AttentionBackend):
             q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             q_nope = q_all[:, :, : layer.v_head_dim]
             q_rope = q_all[:, :, layer.v_head_dim :]
+            max_seqlen_q = metadata.max_seq_len_q
 
-            o = flash_attn_with_kvcache(
+            result = flash_attn_with_kvcache(
                 q=q_rope,
                 k_cache=k_rope_cache,
                 v_cache=c_kv_cache,
@@ -795,13 +1061,43 @@ class FlashAttentionBackend(AttentionBackend):
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
-                max_seqlen_q=1,
+                max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=True,
+                causal=False if use_cascade_attn else causal,
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states
             )
+            if use_cascade_attn:
+                o, softmax_lse, *rest = result
+                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
+                    q=q_rope,
+                    k_cache=k_rope_cache,
+                    v_cache=c_kv_cache,
+                    qv=q_nope,
+                    page_table=self.forward_metadata_spec_decode_expand.page_table,
+                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=True,
+                )
+                o, _ = merge_state_v2(
+                    o,
+                    softmax_lse.T.contiguous(),
+                    o_expand,
+                    softmax_lse_expand.T.contiguous(),
+                )
+            else:
+                o = result
+
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def init_cuda_graph_state(self, max_bs: int):
@@ -813,6 +1109,8 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
+
+        # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
             "cu_seqlens_q": torch.arange(
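Note: every buffer below follows the same CUDA-graph convention: allocate once at max_bs, capture the graph against fixed slices, and update values in place at replay with copy_ so tensor addresses never change. A minimal sketch of the pattern (hypothetical sizes, not the package's API):

import torch

max_bs = 8
cache_seqlens = torch.zeros(max_bs, dtype=torch.int32)   # allocated once

bs = 3
view = cache_seqlens[:bs]                                # captured graph reads this slice
view.copy_(torch.tensor([5, 7, 2], dtype=torch.int32))   # replay updates in place
print(cache_seqlens[:bs])                                # tensor([5, 7, 2], dtype=torch.int32)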
@@ -838,24 +1136,136 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }
 
-        self.target_verify_metadata = {
-            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.zeros(
-                max_bs + 1, dtype=torch.int32, device=self.device
-            ),
-            "cu_seqlens_k": torch.zeros(
-                max_bs + 1, dtype=torch.int32, device=self.device
-            ),
-            "page_table": torch.zeros(
-                max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
-                dtype=torch.int32,
-                device=self.device,
-            ),
-            "strided_indices": torch.arange(
-                0, self.max_context_len, self.page_size, device=self.device
-            ),
-        }
+        # This is used by draft decode's first half of metadata when topk > 1
+        if self.topk > 1:
+            self.draft_decode_metadata_topk_normal = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.topk + 1,
+                    step=self.topk,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    self.max_context_len,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
+            # This is used by draft decode's second half of metadata when topk > 1
+            decode_length = self.speculative_step_id + 1
+            self.draft_decode_metadata_topk_expand = {
+                "cache_seqlens": torch.full(
+                    (max_bs * self.topk,),
+                    decode_length,
+                    device=self.device,
+                    dtype=torch.int32,
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.topk + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.arange(
+                    0,
+                    max_bs * self.topk * decode_length + 1,
+                    step=decode_length,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "page_table": torch.zeros(
+                    max_bs * self.topk,
+                    decode_length,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
+        if (
+            self.speculative_num_draft_tokens is not None
+            and self.speculative_num_draft_tokens > 0
+        ):
+            self.target_verify_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    step=self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    (self.max_context_len + self.page_size - 1) // self.page_size,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
+            if self.topk > 1:
+                self.target_verify_metadata_topk_normal = {
+                    "cache_seqlens": torch.zeros(
+                        max_bs, dtype=torch.int32, device=self.device
+                    ),
+                    "cu_seqlens_q": torch.arange(
+                        0,
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        step=self.speculative_num_draft_tokens,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_k": torch.zeros(
+                        max_bs + 1, dtype=torch.int32, device=self.device
+                    ),
+                    "page_table": torch.zeros(
+                        max_bs,
+                        self.max_context_len,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                }
+
+                self.target_verify_metadata_topk_expand = {
+                    "cache_seqlens": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_k": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_q": torch.arange(
+                        0,
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "page_table": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens,
+                        self.speculative_num_draft_tokens,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                }
 
         self.encoder_metadata = {
             "encoder_page_table": torch.zeros(
@@ -884,28 +1294,78 @@ class FlashAttentionBackend(AttentionBackend):
     ):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
+
+        # metadata_expand is needed for Spec Decoding when top k > 1
+        metadata_expand = FlashAttentionMetadata()
+
         device = seq_lens.device
         if forward_mode.is_decode_or_idle():
             if spec_info is not None:
                 # Draft Decode
-                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
-                    "cache_seqlens"
-                ][:bs]
-                metadata.max_seq_len_k = seq_lens.max().item() + (
-                    self.speculative_step_id + 1
-                )
-                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
-                    "cu_seqlens_q"
-                ][: bs + 1]
-                metadata.cu_seqlens_k = torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
-                metadata.page_table = self.decode_cuda_graph_metadata[
-                    "page_table_draft_decode"
-                ][req_pool_indices, :]
+                if self.topk <= 1:
+                    # When topk = 1, we use the normal decode metadata
+                    metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                        "cache_seqlens"
+                    ][:bs]
+                    metadata.max_seq_len_k = seq_lens.max().item() + (
+                        self.speculative_step_id + 1
+                    )
+                    metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
+                        "cu_seqlens_q"
+                    ][: bs + 1]
+                    metadata.cu_seqlens_k = torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                    metadata.page_table = self.decode_cuda_graph_metadata[
+                        "page_table_draft_decode"
+                    ][req_pool_indices, :]
+                    self.decode_cuda_graph_metadata[bs] = metadata
+                else:
+                    # When top k > 1, we need two specific draft decode metadata, and then merge states
+                    # 1. The first half of metadata for prefix tokens
+                    metadata.cache_seqlens_int32 = (
+                        self.draft_decode_metadata_topk_normal["cache_seqlens"][:bs]
+                    )
+                    metadata.max_seq_len_q = self.topk
+                    metadata.max_seq_len_k = seq_lens.max().item()
+                    metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[
+                        "cu_seqlens_q"
+                    ][: bs + 1]
+                    metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[
+                        "cu_seqlens_k"
+                    ][: bs + 1]
+                    metadata.page_table = self.draft_decode_metadata_topk_normal[
+                        "page_table"
+                    ][req_pool_indices, :]
+
+                    # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+                    metadata_expand.cache_seqlens_int32 = (
+                        self.draft_decode_metadata_topk_expand["cache_seqlens"][
+                            : bs * self.topk
+                        ]
+                    )
+                    metadata_expand.max_seq_len_q = 1
+                    metadata_expand.max_seq_len_k = (
+                        self.speculative_step_id + 1
+                    )  # , do this in replay
+                    metadata_expand.cu_seqlens_q = (
+                        self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
+                            : bs * self.topk + 1
+                        ]
+                    )
+                    metadata_expand.cu_seqlens_k = (
+                        self.draft_decode_metadata_topk_expand["cu_seqlens_k"][
+                            : bs * self.topk + 1
+                        ]
+                    )
+                    metadata_expand.page_table = self.draft_decode_metadata_topk_expand[
+                        "page_table"
+                    ][: bs * self.topk]
+                    self.draft_decode_metadata_topk_normal[bs] = metadata
+                    self.draft_decode_metadata_topk_expand[bs] = metadata_expand
             else:
                 # Normal Decode
                 # Get sequence information
@@ -925,37 +1385,77 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.cu_seqlens_q = torch.arange(
                     0, batch_size + 1, dtype=torch.int32, device=device
                 )
-            self.decode_cuda_graph_metadata[bs] = metadata
+                self.decode_cuda_graph_metadata[bs] = metadata
+
         elif forward_mode.is_target_verify():
-            metadata.cache_seqlens_int32 = self.target_verify_metadata[
-                "cache_seqlens"
-            ][:bs]
-            metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
-            )
+            if self.topk <= 1:
+                metadata.cache_seqlens_int32 = self.target_verify_metadata[
+                    "cache_seqlens"
+                ][:bs]
+                metadata.cache_seqlens_int32.copy_(
+                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                )
 
-            metadata.max_seq_len_q = self.speculative_num_draft_tokens
-            metadata.max_seq_len_k = (
-                seq_lens.max().item() + self.speculative_num_draft_tokens
-            )
+                metadata.max_seq_len_q = self.speculative_num_draft_tokens
+                metadata.max_seq_len_k = (
+                    seq_lens.max().item() + self.speculative_num_draft_tokens
+                )
 
-            metadata.cu_seqlens_q = torch.arange(
-                0,
-                bs * self.speculative_num_draft_tokens + 1,
-                self.speculative_num_draft_tokens,
-                dtype=torch.int32,
-                device=device,
-            )
+                metadata.cu_seqlens_q = torch.arange(
+                    0,
+                    bs * self.speculative_num_draft_tokens + 1,
+                    self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=device,
+                )
 
-            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
-                : (bs + 1)
-            ]
+                metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                    : (bs + 1)
+                ]
 
-            metadata.page_table = self.target_verify_metadata["page_table"][
-                req_pool_indices, :
-            ]
+                metadata.page_table = self.target_verify_metadata["page_table"][
+                    req_pool_indices, :
+                ]
+
+                self.target_verify_metadata[bs] = metadata
+            else:
+                # When topk > 1, we need two specific target verify metadata, and then merge states
+                # 1. The first half of metadata for prefix tokens
+                metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[
+                    "cache_seqlens"
+                ][:bs]
+                metadata.max_seq_len_q = self.speculative_num_draft_tokens
+                # metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay
+                metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[
+                    "cu_seqlens_q"
+                ][: bs + 1]
+                metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[
+                    "cu_seqlens_k"
+                ][: bs + 1]
+                metadata.page_table = self.target_verify_metadata_topk_normal[
+                    "page_table"
+                ][req_pool_indices, :]
+
+                # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+                metadata_expand.cache_seqlens_int32 = (
+                    self.target_verify_metadata_topk_expand["cache_seqlens"][
+                        : bs * self.speculative_num_draft_tokens
+                    ]
+                )
+                metadata_expand.max_seq_len_q = 1
+                metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[
+                    "cu_seqlens_q"
+                ][: bs * self.speculative_num_draft_tokens + 1]
+                metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[
+                    "cu_seqlens_k"
+                ][: bs * self.speculative_num_draft_tokens + 1]
+
+                metadata_expand.page_table = self.target_verify_metadata_topk_expand[
+                    "page_table"
+                ][: bs * self.speculative_num_draft_tokens]
 
-            self.target_verify_metadata[bs] = metadata
+                self.target_verify_metadata_topk_normal[bs] = metadata
+                self.target_verify_metadata_topk_expand[bs] = metadata_expand
 
         if encoder_lens is not None:
             encoder_bs = encoder_lens.numel()
@@ -984,37 +1485,85 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_cpu: Optional[torch.Tensor],
         out_cache_loc: torch.Tensor = None,
     ):
-
+        """Initialize forward metadata for replaying CUDA graph."""
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
+        device = seq_lens.device
+        metadata = None
+        metadata_expand = None
+
         if forward_mode.is_decode_or_idle():
-            metadata = self.decode_cuda_graph_metadata[bs]
 
             if spec_info is not None:
                 # Draft Decode
-                metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
-                )
+                if self.topk <= 1:
+                    metadata = self.decode_cuda_graph_metadata[bs]
+                    # When topk = 1, we use the normal decode metadata
+                    metadata.cache_seqlens_int32.copy_(
+                        (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
+                    )
 
-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-                    self.speculative_step_id + 1
-                )
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
+                    metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+                        self.speculative_step_id + 1
+                    )
+                    metadata.cu_seqlens_k.copy_(
+                        torch.nn.functional.pad(
+                            torch.cumsum(
+                                metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                            ),
+                            (1, 0),
+                        )
                     )
-                )
 
-                page_table = self.req_to_token[
-                    req_pool_indices, : metadata.max_seq_len_k
-                ]
+                    max_seq_pages = (
+                        metadata.max_seq_len_k + self.page_size - 1
+                    ) // self.page_size
+                    page_indices = self.req_to_token[
+                        req_pool_indices[:, None],
+                        self.decode_cuda_graph_metadata["strided_indices"][
+                            :max_seq_pages
+                        ],
+                    ]
+
+                    page_indices //= self.page_size
+                    metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                else:
+                    # When top k > 1, we need two specific draft decode metadata, and then merge states
+                    # 1. The first half of metadata for prefix tokens
+                    metadata = self.draft_decode_metadata_topk_normal[bs]
+                    metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                    # metadata.max_seq_len_q = self.topk, already set in capture
+                    metadata.max_seq_len_k = seq_lens_cpu.max().item()
+                    # metadata.cu_seqlens_q already set in capture
+                    metadata.cu_seqlens_k.copy_(
+                        torch.nn.functional.pad(
+                            torch.cumsum(
+                                metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                            ),
+                            (1, 0),
+                        )
+                    )
 
-                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+                    page_table = self.req_to_token[
+                        req_pool_indices, : metadata.max_seq_len_k
+                    ]
+
+                    metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+                    # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+                    metadata_expand = self.draft_decode_metadata_topk_expand[bs]
+                    decode_length = self.speculative_step_id + 1
+                    cache_loc = out_cache_loc.view(
+                        self.speculative_num_steps, -1
+                    ).T.contiguous()
+                    metadata_expand.page_table[: cache_loc.shape[0]].copy_(
+                        cache_loc[:, :decode_length].contiguous().to(torch.int32)
+                    )
+                    # TODO: we need to test this part for llama 4 eagle case
+                    self._init_local_attn_metadata(metadata, device)
             else:
+                metadata = self.decode_cuda_graph_metadata[bs]
                 # Normal Decode
                 max_len = seq_lens_cpu.max().item()
                 metadata.max_seq_len_k = max_len
@@ -1037,25 +1586,119 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)
 
+                self._init_local_attn_metadata(metadata, device)
         elif forward_mode.is_target_verify():
-            metadata = self.target_verify_metadata[bs]
-            metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
-            )
+            if self.topk <= 1:
+                metadata = self.target_verify_metadata[bs]
+                metadata.cache_seqlens_int32.copy_(
+                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                )
 
-            metadata.max_seq_len_k = (
-                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
-            )
-            metadata.cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
+                metadata.max_seq_len_k = (
+                    seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+                )
+                metadata.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                )
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+                page_indices = self.req_to_token[
+                    req_pool_indices[:, None],
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+                ]
+                page_indices //= self.page_size
+                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+            else:
+                # When topk > 1, we need two specific target verify metadata, and then merge states
+                # 1. The first half of metadata for prefix tokens
+                metadata = self.target_verify_metadata_topk_normal[bs]
+                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
+                metadata.max_seq_len_k = seq_lens_cpu.max().item()
+                # metadata.cu_seqlens_q already set in capture
+                metadata.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                )
+                page_table = self.req_to_token[
+                    req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+                # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+                metadata_expand = self.target_verify_metadata_topk_expand[bs]
+                # metadata_expand.max_seq_len_q = 1, already set in capture
+                # metadata_expand.cu_seqlens_q already set in capture
+
+                offsets = torch.arange(
+                    self.speculative_num_draft_tokens, device=device
+                ).unsqueeze(
+                    0
+                )  # shape: (1, self.speculative_num_draft_tokens)
+                cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
+                cum_len = torch.nn.functional.pad(
                     torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        (
+                            seq_lens + self.speculative_num_draft_tokens
+                        ).repeat_interleave(self.speculative_num_draft_tokens),
+                        dim=0,
                     ),
                     (1, 0),
+                )[:-1]
+                mask_extraction_indices = (
+                    cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+                    + cum_len[:, None]
+                ).view(1, -1)
+                # avoid extracting padded seq indices which will be out of boundary
+                mask_extraction_indices[
+                    :, spec_info.positions.numel() * self.speculative_num_draft_tokens :
+                ].fill_(0)
+
+                mask = spec_info.custom_mask[mask_extraction_indices].view(
+                    -1, self.speculative_num_draft_tokens
+                )  # (bsz * draft_num, draft_num)
+                col_indices = offsets.expand(
+                    mask.shape[0], self.speculative_num_draft_tokens
+                )
+                keys = torch.where(
+                    mask, col_indices, col_indices + self.speculative_num_draft_tokens
+                )
+                _, sort_order = torch.sort(keys, dim=1)
+
+                non_masked_page_table = (
+                    self.req_to_token[req_pool_indices, :]
+                    .gather(1, cols)
+                    .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+                )  # (bsz, draft_num)
+                metadata_expand.page_table.copy_(
+                    non_masked_page_table.gather(1, sort_order)
+                )
+                metadata_expand.cache_seqlens_int32.copy_(
+                    mask.sum(dim=1).to(torch.int32)
+                )
+                metadata_expand.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata_expand.cache_seqlens_int32,
+                            dim=0,
+                            dtype=torch.int32,
+                        ),
+                        (1, 0),
+                    )
+                )
+                metadata_expand.max_seq_len_k = (
+                    metadata_expand.cache_seqlens_int32.max().item()
                 )
-            )
-            page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
-            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
 
         if encoder_lens is not None:
             # Only support encoder size 1 for now
@@ -1082,11 +1725,48 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
 
         self.forward_metadata = metadata
+        self.forward_metadata_spec_decode_expand = metadata_expand
 
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
         return 0
 
+    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
+        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
+        if self.attention_chunk_size is None:
+            metadata.local_attn_metadata = None
+            return
+
+        cu_seqlens_q = metadata.cu_seqlens_q
+        cache_seqlens_int32 = metadata.cache_seqlens_int32
+        page_table = metadata.page_table
+        if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
+            metadata.local_attn_metadata = None
+            return
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seq_lens_np = cache_seqlens_int32.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seq_lens_np,
+            page_table,
+            self.page_size,
+        )
+        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
+            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
+            local_block_table=block_table_local.to(device),
+            local_max_query_len=int(seqlens_q_local_np.max()),
+            local_max_seq_len=int(seqlens_k_local_np.max()),
+        )
+        metadata.local_attn_metadata = local_metadata
+
 
 class FlashAttentionMultiStepBackend:
 
@@ -1096,12 +1776,6 @@ class FlashAttentionMultiStepBackend:
         self.model_runner = model_runner
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
-
-        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
-        assert (
-            self.topk == 1
-        ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
-
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(