sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (121)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/srt/configs/model_config.py +37 -5
  4. sglang/srt/constrained/base_grammar_backend.py +26 -5
  5. sglang/srt/constrained/llguidance_backend.py +1 -0
  6. sglang/srt/constrained/outlines_backend.py +1 -0
  7. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  8. sglang/srt/constrained/xgrammar_backend.py +1 -0
  9. sglang/srt/disaggregation/base/__init__.py +8 -0
  10. sglang/srt/disaggregation/base/conn.py +113 -0
  11. sglang/srt/disaggregation/decode.py +18 -5
  12. sglang/srt/disaggregation/mini_lb.py +53 -122
  13. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  14. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  16. sglang/srt/disaggregation/prefill.py +43 -19
  17. sglang/srt/disaggregation/utils.py +31 -0
  18. sglang/srt/entrypoints/EngineBase.py +53 -0
  19. sglang/srt/entrypoints/engine.py +36 -8
  20. sglang/srt/entrypoints/http_server.py +37 -8
  21. sglang/srt/entrypoints/http_server_engine.py +142 -0
  22. sglang/srt/entrypoints/verl_engine.py +37 -10
  23. sglang/srt/hf_transformers_utils.py +4 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +330 -200
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  26. sglang/srt/layers/attention/vision.py +1 -1
  27. sglang/srt/layers/dp_attention.py +2 -4
  28. sglang/srt/layers/elementwise.py +15 -2
  29. sglang/srt/layers/linear.py +1 -0
  30. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
  38. sglang/srt/layers/moe/router.py +7 -1
  39. sglang/srt/layers/moe/topk.py +37 -16
  40. sglang/srt/layers/quantization/__init__.py +12 -5
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  42. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  43. sglang/srt/layers/quantization/fp8.py +25 -13
  44. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  45. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  46. sglang/srt/layers/quantization/kv_cache.py +43 -52
  47. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  48. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  49. sglang/srt/layers/quantization/w8a8_int8.py +1 -0
  50. sglang/srt/layers/radix_attention.py +13 -1
  51. sglang/srt/layers/rotary_embedding.py +12 -1
  52. sglang/srt/managers/io_struct.py +254 -97
  53. sglang/srt/managers/mm_utils.py +3 -2
  54. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  55. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  56. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  57. sglang/srt/managers/schedule_batch.py +62 -21
  58. sglang/srt/managers/scheduler.py +71 -14
  59. sglang/srt/managers/tokenizer_manager.py +17 -3
  60. sglang/srt/managers/tp_worker.py +1 -0
  61. sglang/srt/mem_cache/memory_pool.py +14 -1
  62. sglang/srt/metrics/collector.py +9 -0
  63. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  64. sglang/srt/model_executor/forward_batch_info.py +234 -15
  65. sglang/srt/model_executor/model_runner.py +48 -9
  66. sglang/srt/model_loader/loader.py +31 -4
  67. sglang/srt/model_loader/weight_utils.py +4 -2
  68. sglang/srt/models/baichuan.py +2 -0
  69. sglang/srt/models/chatglm.py +1 -0
  70. sglang/srt/models/commandr.py +1 -0
  71. sglang/srt/models/dbrx.py +1 -0
  72. sglang/srt/models/deepseek.py +1 -0
  73. sglang/srt/models/deepseek_v2.py +248 -61
  74. sglang/srt/models/exaone.py +1 -0
  75. sglang/srt/models/gemma.py +1 -0
  76. sglang/srt/models/gemma2.py +1 -0
  77. sglang/srt/models/gemma3_causal.py +1 -0
  78. sglang/srt/models/gpt2.py +1 -0
  79. sglang/srt/models/gpt_bigcode.py +1 -0
  80. sglang/srt/models/granite.py +1 -0
  81. sglang/srt/models/grok.py +1 -0
  82. sglang/srt/models/internlm2.py +1 -0
  83. sglang/srt/models/llama.py +1 -0
  84. sglang/srt/models/llama4.py +101 -34
  85. sglang/srt/models/minicpm.py +1 -0
  86. sglang/srt/models/minicpm3.py +2 -0
  87. sglang/srt/models/mixtral.py +1 -0
  88. sglang/srt/models/mixtral_quant.py +1 -0
  89. sglang/srt/models/mllama.py +51 -8
  90. sglang/srt/models/mllama4.py +102 -29
  91. sglang/srt/models/olmo.py +1 -0
  92. sglang/srt/models/olmo2.py +1 -0
  93. sglang/srt/models/olmoe.py +1 -0
  94. sglang/srt/models/phi3_small.py +1 -0
  95. sglang/srt/models/qwen.py +1 -0
  96. sglang/srt/models/qwen2.py +1 -0
  97. sglang/srt/models/qwen2_5_vl.py +35 -70
  98. sglang/srt/models/qwen2_moe.py +1 -0
  99. sglang/srt/models/qwen2_vl.py +27 -25
  100. sglang/srt/models/stablelm.py +1 -0
  101. sglang/srt/models/xverse.py +1 -0
  102. sglang/srt/models/xverse_moe.py +1 -0
  103. sglang/srt/openai_api/adapter.py +4 -1
  104. sglang/srt/patch_torch.py +11 -0
  105. sglang/srt/server_args.py +34 -0
  106. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  107. sglang/srt/speculative/eagle_utils.py +1 -11
  108. sglang/srt/speculative/eagle_worker.py +6 -2
  109. sglang/srt/utils.py +120 -9
  110. sglang/test/attention/test_flashattn_backend.py +259 -221
  111. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  112. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  113. sglang/test/test_block_fp8.py +57 -0
  114. sglang/test/test_utils.py +19 -8
  115. sglang/version.py +1 -1
  116. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  117. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
  118. sglang/srt/disaggregation/conn.py +0 -81
  119. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  120. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  121. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
@@ -1,51 +1,56 @@
  from __future__ import annotations

- import numpy as np
-
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-
- """
- Support different attention backends.
- Now there are three backends: FlashInfer, Triton and FlashAttention.
- Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
- """
-
  from dataclasses import dataclass
  from typing import TYPE_CHECKING, Optional, Union

+ import numpy as np
  import torch

  from sglang.srt.configs.model_config import AttentionArch
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner

- from sgl_kernel.flash_attn import flash_attn_with_kvcache
+ from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


  @dataclass
  class FlashAttentionMetadata:
  """Metadata to be init once in the model forward pass,
- each layer's forward pass can reuse the metadata."""
+ each layer's forward pass can reuse the metadata.

- # Cumulative sequence lengths for query
- cu_seqlens_q: torch.Tensor = None
- # Cumulative sequence lengths for key
- cu_seqlens_k: torch.Tensor = None
+ For each init metadata function, we will try set up them in below order
+ """
+
+ # Sequence lengths for the forward batch
+ cache_seqlens_int32: torch.Tensor = None
  # Maximum sequence length for query
  max_seq_len_q: int = 0
  # Maximum sequence length for key
  max_seq_len_k: int = 0
+ # Cumulative sequence lengths for query
+ cu_seqlens_q: torch.Tensor = None
+ # Cumulative sequence lengths for key
+ cu_seqlens_k: torch.Tensor = None
  # Window size (typically used by Gemma)
  window_size: tuple = (-1, -1)
  # Page table, the index of KV Cache Tables/Blocks
  page_table: torch.Tensor = None
+
+ # Encoder metadata
+ # Cumulative sequence lengths for encoder key
+ encoder_cu_seqlens_k: torch.Tensor = None
+ # Maximum sequence length for encoder key
+ encoder_max_seq_len_k: int = 0
  # Sequence lengths for the forward batch
- cache_seqlens_int32: torch.Tensor = None
+ encoder_lens_int32: torch.Tensor = None
+ # Page table for the encoder
+ encoder_page_table: torch.Tensor = None

  @dataclass
  class LocalAttentionMetadata:
@@ -231,7 +236,11 @@ def make_local_attention_virtual_batches(
  np.arange(pages_per_local_batch, dtype=np.int32),
  (virtual_batches, pages_per_local_batch),
  ) + np.expand_dims(block_starts, axis=1)
- block_indices = block_indices.flatten()
+ # Ensure block_indices doesn't exceed block_table dimensions
+ # This is a critical safety check that prevents index out of bounds errors
+ # when dealing with large sequences (>8192 tokens) or when the block_table
+ # dimensions are smaller than what would be needed for the full attention chunk size.
+ block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
  batch_indices = np.repeat(
  np.arange(actual_batch_size, dtype=np.int32),
  local_blocks * pages_per_local_batch,
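
For context on the clip() change in the hunk above, here is a minimal standalone sketch (not the sglang code itself; array shapes and values are illustrative only) of how clamping the flattened page indices to the last valid column of the block table prevents an out-of-bounds gather:

import numpy as np

block_table = np.arange(12, dtype=np.int32).reshape(2, 6)    # (batch, pages_per_seq)
block_indices = np.array([[3, 4, 5, 6, 7]], dtype=np.int32)  # 6 and 7 exceed the last column (5)

# Clamp to the last valid column so block_table[..., safe] can never index out of bounds.
safe = block_indices.flatten().clip(max=block_table.shape[1] - 1)
print(safe)  # [3 4 5 5 5]
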
@@ -270,9 +279,9 @@ class FlashAttentionBackend(AttentionBackend):
  self,
  model_runner: ModelRunner,
  skip_prefill: bool = False,
+ speculative_step_id=0,
  topk=0,
  speculative_num_steps=0,
- step_id=0,
  ):
  super().__init__()

@@ -287,20 +296,20 @@ class FlashAttentionBackend(AttentionBackend):
  self.decode_cuda_graph_metadata = {}
  self.target_verify_metadata = {}
  self.req_to_token = model_runner.req_to_token_pool.req_to_token
+ self.kv_cache_dtype = model_runner.kv_cache_dtype
+ self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
  self.page_size = model_runner.page_size
  self.use_mla = (
  model_runner.model_config.attention_arch == AttentionArch.MLA
  ) and (not global_server_args_dict["disable_mla"])
  self.skip_prefill = skip_prefill

- # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
- assert (
- topk <= 1
- ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
-
- self.topk = 1
- self.step_id = step_id
+ self.topk = topk
  self.speculative_num_steps = speculative_num_steps
+ self.speculative_num_draft_tokens = (
+ model_runner.server_args.speculative_num_draft_tokens
+ )
+ self.speculative_step_id = speculative_step_id

  # Local attention settings
  self.attention_chunk_size = (
@@ -310,71 +319,59 @@ class FlashAttentionBackend(AttentionBackend):
  )

  def init_forward_metadata(self, forward_batch: ForwardBatch):
- """Initialize forward metadata to cache repetitive calculations."""
+ """Initialize forward metadata hence all layers in the forward pass can reuse it."""
  metadata = FlashAttentionMetadata()
  seqlens_in_batch = forward_batch.seq_lens
  batch_size = len(seqlens_in_batch)
  device = seqlens_in_batch.device
- if forward_batch.forward_mode.is_decode():
- # Skip Prefill or Draft Decode
- # Note: Draft Decode will be ran on the Draft Worker
+
+ if forward_batch.forward_mode.is_decode_or_idle():
+ # Draft Decode
  if forward_batch.spec_info is not None:
+ metadata.cache_seqlens_int32 = (
+ seqlens_in_batch + (self.speculative_step_id + 1)
+ ).to(torch.int32)
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+ self.speculative_step_id + 1
+ )
  metadata.cu_seqlens_q = torch.arange(
  0, batch_size + 1, dtype=torch.int32, device=device
  )
- seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
- metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
  metadata.cu_seqlens_k = torch.nn.functional.pad(
  torch.cumsum(
  metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
  ),
  (1, 0),
  )
- metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
- self.step_id + 1
- )
  metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
  forward_batch.req_pool_indices, : metadata.max_seq_len_k
  ]
- cache_loc = forward_batch.out_cache_loc.view(
- self.speculative_num_steps, -1
- ).T
-
- for idx, single_seq_len in enumerate(seq_lens_with_decode):
- real_bsz_start_idx = idx
- real_bsz_end_idx = idx + 1
- metadata.page_table[
- real_bsz_start_idx:real_bsz_end_idx,
- (single_seq_len - (self.step_id + 1)) : single_seq_len,
- ] = cache_loc[
- real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
- ]
- else: # Normal Decode without Spec Decoding
+ else:
+ # Normal Decode
  metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+ metadata.cu_seqlens_q = torch.arange(
+ 0, batch_size + 1, dtype=torch.int32, device=device
+ )
  metadata.cu_seqlens_k = torch.nn.functional.pad(
  torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
  )
- metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
  metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
  forward_batch.req_pool_indices, : metadata.max_seq_len_k
  ]
- metadata.cu_seqlens_q = torch.arange(
- 0, batch_size + 1, dtype=torch.int32, device=device
- )
  elif forward_batch.forward_mode.is_target_verify():
- # Note: Target Verify will be ran on the Target Worker
- draft_token_num = forward_batch.spec_info.draft_token_num
  metadata.cache_seqlens_int32 = (
- forward_batch.seq_lens + draft_token_num
+ forward_batch.seq_lens + self.speculative_num_draft_tokens
  ).to(torch.int32)
- metadata.max_seq_len_q = draft_token_num
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
  metadata.max_seq_len_k = (
- forward_batch.seq_lens_cpu.max().item() + draft_token_num
+ forward_batch.seq_lens_cpu.max().item()
+ + self.speculative_num_draft_tokens
  )
  metadata.cu_seqlens_q = torch.arange(
  0,
- batch_size * draft_token_num + 1,
- draft_token_num,
+ batch_size * self.speculative_num_draft_tokens + 1,
+ self.speculative_num_draft_tokens,
  dtype=torch.int32,
  device=device,
  )
@@ -386,32 +383,28 @@ class FlashAttentionBackend(AttentionBackend):
  forward_batch.req_pool_indices, : metadata.max_seq_len_k
  ]

- elif forward_batch.forward_mode.is_extend_or_draft_extend():
- # Normal or Draft Extend (Both of them will be ran on the Target Worker)
+ elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
  metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
  metadata.cu_seqlens_k = torch.nn.functional.pad(
  torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
  )
- # Precompute maximum sequence length
- metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
- # Precompute page table
  metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
  forward_batch.req_pool_indices, : metadata.max_seq_len_k
  ]

- # Precompute cumulative sequence lengths
  if (
  any(forward_batch.extend_prefix_lens_cpu)
  or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
  ):
  extend_seq_lens = forward_batch.extend_seq_lens
+ metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
  metadata.cu_seqlens_q = torch.nn.functional.pad(
  torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
  )
- metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
  else:
- metadata.cu_seqlens_q = metadata.cu_seqlens_k
  metadata.max_seq_len_q = metadata.max_seq_len_k
+ metadata.cu_seqlens_q = metadata.cu_seqlens_k

  # Setup local attention if enabled
  if (
@@ -458,7 +451,31 @@ class FlashAttentionBackend(AttentionBackend):
  )
  metadata.local_attn_metadata = local_metadata

- # Precompute strided indices
+ # Encoder metadata for cross attention
+ if forward_batch.encoder_lens is not None:
+ assert (
+ forward_batch.encoder_lens.numel() == 1
+ ), "Only encoder size 1 is supported for now"
+
+ metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
+ metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+ (1, 0),
+ )
+ metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
+ metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
+ ]
+
+ # Currently only support forward_batch.encoder_lens.numel() == 1
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices,
+ metadata.encoder_max_seq_len_k : (
+ metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+ ),
+ ]
+
+ # Convert the page table to a strided format which is needed by FA3 API
  if self.page_size > 1:
  self.strided_indices = torch.arange(
  0, metadata.page_table.shape[1], self.page_size, device=self.device
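
As an aside on the encoder-metadata hunk above, here is a minimal, self-contained sketch (with made-up tensor names and sizes, not the actual sglang objects) of the slicing scheme it relies on: the first encoder_max_seq_len_k columns of a request's token map act as the encoder page table, and the decoder page table starts immediately after them.

import torch

req_to_token = torch.arange(2 * 16, dtype=torch.int32).view(2, 16)  # (num_reqs, max_context_len)
req_pool_indices = torch.tensor([1])   # requests in the current batch
encoder_max_seq_len_k, max_seq_len_k = 4, 6

# Encoder KV slots come first in the request's token map ...
encoder_page_table = req_to_token[req_pool_indices, :encoder_max_seq_len_k]
# ... and the decoder (self-attention) slots follow right after them.
page_table = req_to_token[
    req_pool_indices,
    encoder_max_seq_len_k : encoder_max_seq_len_k + max_seq_len_k,
]
print(encoder_page_table.shape, page_table.shape)  # torch.Size([1, 4]) torch.Size([1, 6])
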
@@ -498,7 +515,7 @@ class FlashAttentionBackend(AttentionBackend):
  v,
  )

- # Use precomputed metadata
+ # Use precomputed metadata across all layers
  metadata = self.forward_metadata

  # Calculate window size (can be moved to metadata if layer properties don't change)
@@ -506,9 +523,18 @@ class FlashAttentionBackend(AttentionBackend):
  # here is two side inclusive
  window_size = (
  (layer.sliding_window_size, 0)
- if layer.sliding_window_size is not None
+ if layer.sliding_window_size is not None and layer.sliding_window_size > -1
  else (-1, -1)
  )
+ k_descale, v_descale = None, None
+ # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+ # has corresponding quantization method so that layer.k_scale is not None
+ if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
+ descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+ k_descale = layer.k_scale.expand(descale_shape)
+ v_descale = layer.v_scale.expand(descale_shape)
+ q = q.to(self.kv_cache_dtype)
+ causal = not layer.is_cross_attention

  # Check if we should use local attention
  use_local_attn = (
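
One note on the fp8 KV-cache descale logic introduced in the hunk above: the sketch below (scale values and head counts invented for the example) illustrates how a per-layer k_scale/v_scale tensor can be broadcast to the (batch_size, num_kv_heads) shape passed to the kernel; expand() only creates a view, so no data is copied.

import torch

batch_size, tp_k_head_num = 3, 8
k_scale = torch.tensor([0.021])   # stand-in for layer.k_scale
v_scale = torch.tensor([0.017])   # stand-in for layer.v_scale

descale_shape = (batch_size, tp_k_head_num)
k_descale = k_scale.expand(descale_shape)   # broadcast view, shape (3, 8)
v_descale = v_scale.expand(descale_shape)
print(k_descale.shape, round(float(v_descale[0, 0]), 3))  # torch.Size([3, 8]) 0.017
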
@@ -536,14 +562,21 @@ class FlashAttentionBackend(AttentionBackend):
  # Use Flash Attention for prefill
  if not self.use_mla:
  # Do multi-head attention
- kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
- key_cache, value_cache = kv_cache[0], kv_cache[1]
+ key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+ layer.layer_id
+ )
  key_cache = key_cache.view(
  -1, self.page_size, layer.tp_k_head_num, layer.head_dim
  )
  value_cache = value_cache.view(
  -1, self.page_size, layer.tp_v_head_num, layer.head_dim
  )
+ if layer.is_cross_attention:
+ page_table = metadata.encoder_page_table
+ cache_seqlens = metadata.encoder_lens_int32
+ cu_seqlens_k = metadata.encoder_cu_seqlens_k
+ window_size = (-1, -1)
+
  o = flash_attn_with_kvcache(
  q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
  k_cache=key_cache,
@@ -554,48 +587,93 @@ class FlashAttentionBackend(AttentionBackend):
  cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
  max_seqlen_q=max_seqlen_q,
  softmax_scale=layer.scaling,
- causal=True,
+ causal=causal,
  window_size=window_size,
  softcap=layer.logit_cap,
- k_descale=layer.k_scale,
- v_descale=layer.v_scale,
+ k_descale=k_descale,
+ v_descale=v_descale,
  )
+ return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
  else:
- # Do absorbed multi-latent attention
- kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
- k_rope = kv_cache[:, :, layer.v_head_dim :]
- c_kv = kv_cache[:, :, : layer.v_head_dim]
- k_rope_cache = k_rope.view(
- -1,
- self.page_size,
- layer.tp_k_head_num,
- layer.head_dim - layer.v_head_dim,
- )
- c_kv_cache = c_kv.view(
- -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
- )
-
- q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
- q_nope = q_all[:, :, : layer.v_head_dim]
- q_rope = q_all[:, :, layer.v_head_dim :]
- o = flash_attn_with_kvcache(
- q=q_rope,
- k_cache=k_rope_cache,
- v_cache=c_kv_cache,
- qv=q_nope,
- page_table=page_table,
- cache_seqlens=cache_seqlens,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
- max_seqlen_q=max_seqlen_q,
- softmax_scale=layer.scaling,
- causal=True,
- softcap=layer.logit_cap,
- k_descale=layer.k_scale,
- v_descale=layer.v_scale,
- )
+ if (
+ not global_server_args_dict["disable_chunked_prefix_cache"]
+ and forward_batch.attn_attend_prefix_cache is not None
+ and not forward_batch.forward_mode.is_target_verify()
+ and not forward_batch.forward_mode.is_draft_extend()
+ ):
+ # Do multi-head attention with chunked prefix cache
+
+ if forward_batch.attn_attend_prefix_cache:
+ # MHA for chunked prefix kv cache when running model with MLA
+ assert forward_batch.prefix_chunk_idx is not None
+ assert forward_batch.prefix_chunk_cu_seq_lens is not None
+ assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+ chunk_idx = forward_batch.prefix_chunk_idx
+ assert chunk_idx >= 0
+
+ output, lse, *rest = flash_attn_varlen_func(
+ q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+ v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+ cu_seqlens_q=metadata.cu_seqlens_q,
+ cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+ max_seqlen_q=metadata.max_seq_len_q,
+ max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+ softmax_scale=layer.scaling,
+ causal=False,
+ return_softmax_lse=True,
+ )
+ else:
+ # MHA for extend part of sequence without attending prefix kv cache
+ output, lse, *rest = flash_attn_varlen_func(
+ q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+ v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+ cu_seqlens_q=metadata.cu_seqlens_q,
+ cu_seqlens_k=metadata.cu_seqlens_q,
+ max_seqlen_q=metadata.max_seq_len_q,
+ max_seqlen_k=metadata.max_seq_len_q,
+ softmax_scale=layer.scaling,
+ causal=True,
+ return_softmax_lse=True,
+ )
+ return output, lse
+ else:
+ # Do absorbed multi-latent attention
+ kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+ k_rope = kv_cache[:, :, layer.v_head_dim :]
+ c_kv = kv_cache[:, :, : layer.v_head_dim]
+ k_rope_cache = k_rope.view(
+ -1,
+ self.page_size,
+ layer.tp_k_head_num,
+ layer.head_dim - layer.v_head_dim,
+ )
+ c_kv_cache = c_kv.view(
+ -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+ )
+ q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+ q_nope = q_all[:, :, : layer.v_head_dim]
+ q_rope = q_all[:, :, layer.v_head_dim :]
+ o = flash_attn_with_kvcache(
+ q=q_rope,
+ k_cache=k_rope_cache,
+ v_cache=c_kv_cache,
+ qv=q_nope,
+ page_table=page_table,
+ cache_seqlens=cache_seqlens,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+ max_seqlen_q=max_seqlen_q,
+ softmax_scale=layer.scaling,
+ causal=True,
+ softcap=layer.logit_cap,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ )

- return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+ return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

  def forward_decode(
  self,
@@ -606,8 +684,6 @@ class FlashAttentionBackend(AttentionBackend):
  forward_batch: ForwardBatch,
  save_kv_cache=True,
  ) -> torch.Tensor:
- """Forward pass with FlashAttention using precomputed metadata."""
- # Save KV cache if needed
  if k is not None:
  assert v is not None
  if save_kv_cache:
@@ -628,7 +704,7 @@ class FlashAttentionBackend(AttentionBackend):
  v,
  )

- # Use precomputed metadata
+ # Use precomputed metadata across all layers
  metadata = self.forward_metadata

  # Calculate window size (can be moved to metadata if layer properties don't change)
@@ -636,17 +712,27 @@ class FlashAttentionBackend(AttentionBackend):
  # here is two side inclusive
  window_size = (
  (layer.sliding_window_size, 0)
- if layer.sliding_window_size is not None
+ if layer.sliding_window_size is not None and layer.sliding_window_size > -1
  else (-1, -1)
  )
- page_table = metadata.page_table
+ causal = not layer.is_cross_attention
+
+ k_descale, v_descale = None, None
+ # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+ # has corresponding quantization method so that layer.k_scale is not None
+ if self.kv_cache_dtype_str != "auto":
+ if layer.k_scale is not None:
+ descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+ k_descale = layer.k_scale.expand(descale_shape)
+ v_descale = layer.v_scale.expand(descale_shape)
+ q = q.to(self.kv_cache_dtype)

  if not self.use_mla:
  # Do multi-head attention

- # Get KV cache
- kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
- key_cache, value_cache = kv_cache[0], kv_cache[1]
+ key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+ layer.layer_id
+ )
  key_cache = key_cache.view(
  -1, self.page_size, layer.tp_k_head_num, layer.head_dim
  )
@@ -654,23 +740,32 @@ class FlashAttentionBackend(AttentionBackend):
  -1, self.page_size, layer.tp_v_head_num, layer.head_dim
  )

- # Pre-reshape query tensor
  q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+ if layer.is_cross_attention:
+ page_table = metadata.encoder_page_table
+ cache_seqlens = metadata.encoder_lens_int32
+ cu_seqlens_k = metadata.encoder_cu_seqlens_k
+ window_size = (-1, -1)
+ else:
+ page_table = metadata.page_table
+ cache_seqlens = metadata.cache_seqlens_int32
+ cu_seqlens_k = metadata.cu_seqlens_k
+
  o = flash_attn_with_kvcache(
  q=q_reshaped,
  k_cache=key_cache,
  v_cache=value_cache,
  page_table=page_table,
- cache_seqlens=metadata.cache_seqlens_int32,
+ cache_seqlens=cache_seqlens,
  cu_seqlens_q=metadata.cu_seqlens_q,
- cu_seqlens_k_new=metadata.cu_seqlens_k,
+ cu_seqlens_k_new=cu_seqlens_k,
  max_seqlen_q=1,
  softmax_scale=layer.scaling,
- causal=True,
+ causal=causal,
  window_size=window_size,
  softcap=layer.logit_cap,
- k_descale=layer.k_scale,
- v_descale=layer.v_scale,
+ k_descale=k_descale,
+ v_descale=v_descale,
  )
  else:
  # Do absorbed multi-latent attention
@@ -696,7 +791,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_cache=k_rope_cache,
  v_cache=c_kv_cache,
  qv=q_nope,
- page_table=page_table,
+ page_table=metadata.page_table,
  cache_seqlens=metadata.cache_seqlens_int32,
  cu_seqlens_q=metadata.cu_seqlens_q,
  cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -704,8 +799,8 @@ class FlashAttentionBackend(AttentionBackend):
  softmax_scale=layer.scaling,
  causal=True,
  softcap=layer.logit_cap,
- k_descale=layer.k_scale,
- v_descale=layer.v_scale,
+ k_descale=k_descale,
+ v_descale=v_descale,
  )
  return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

@@ -719,7 +814,13 @@ class FlashAttentionBackend(AttentionBackend):
  to avoid memory allocations.
  """
  self.decode_cuda_graph_metadata = {
- # Page table for token mapping (batch_size, max_context_len)
+ "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+ "cu_seqlens_q": torch.arange(
+ 0, max_bs + 1, dtype=torch.int32, device=self.device
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs + 1, dtype=torch.int32, device=self.device
+ ),
  "page_table": torch.zeros(
  max_bs,
  (self.max_context_len + self.page_size - 1) // self.page_size,
@@ -735,35 +836,42 @@ class FlashAttentionBackend(AttentionBackend):
  "strided_indices": torch.arange(
  0, self.max_context_len, self.page_size, device=self.device
  ),
+ }
+
+ self.target_verify_metadata = {
  "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
- "cu_seqlens_q": torch.arange(
- 0, max_bs + 128, dtype=torch.int32, device=self.device
+ "cu_seqlens_q": torch.zeros(
+ max_bs + 1, dtype=torch.int32, device=self.device
  ),
  "cu_seqlens_k": torch.zeros(
- max_bs + 128, dtype=torch.int32, device=self.device
+ max_bs + 1, dtype=torch.int32, device=self.device
  ),
- }
-
- self.target_verify_metadata = {
  "page_table": torch.zeros(
  max_bs,
  (self.max_context_len + self.page_size - 1) // self.page_size,
  dtype=torch.int32,
  device=self.device,
  ),
- "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
- "cu_seqlens_q": torch.zeros(
- max_bs + 128, dtype=torch.int32, device=self.device
- ),
- "cu_seqlens_k": torch.zeros(
- max_bs + 128, dtype=torch.int32, device=self.device
- ),
- "max_seqlen_q": 0,
  "strided_indices": torch.arange(
  0, self.max_context_len, self.page_size, device=self.device
  ),
  }

+ self.encoder_metadata = {
+ "encoder_page_table": torch.zeros(
+ max_bs,
+ self.max_context_len,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "encoder_lens_int32": torch.zeros(
+ max_bs, dtype=torch.int32, device=self.device
+ ),
+ "encoder_cu_seqlens_k": torch.zeros(
+ max_bs + 1, dtype=torch.int32, device=self.device
+ ),
+ }
+
  def init_forward_metadata_capture_cuda_graph(
  self,
  bs: int,
@@ -777,27 +885,24 @@ class FlashAttentionBackend(AttentionBackend):
  """Initialize forward metadata for capturing CUDA graph."""
  metadata = FlashAttentionMetadata()
  device = seq_lens.device
- if forward_mode.is_decode():
+ if forward_mode.is_decode_or_idle():
  if spec_info is not None:
  # Draft Decode
- metadata.cu_seqlens_q = torch.arange(
- 0, bs + 1, dtype=torch.int32, device=device
- )
  metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
  "cache_seqlens"
  ][:bs]
-
+ metadata.max_seq_len_k = seq_lens.max().item() + (
+ self.speculative_step_id + 1
+ )
  metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
  : bs + 1
  ]
-
  metadata.cu_seqlens_k = torch.nn.functional.pad(
  torch.cumsum(
  metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
  ),
  (1, 0),
  )
- metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
  metadata.page_table = self.decode_cuda_graph_metadata[
  "page_table_draft_decode"
  ][req_pool_indices, :]
@@ -822,43 +927,49 @@ class FlashAttentionBackend(AttentionBackend):
  )
  self.decode_cuda_graph_metadata[bs] = metadata
  elif forward_mode.is_target_verify():
- draft_token_num = spec_info.draft_token_num
-
  metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
  :bs
  ]
  metadata.cache_seqlens_int32.copy_(
- (seq_lens + draft_token_num).to(torch.int32)
+ (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
  )

- metadata.max_seq_len_q = draft_token_num
- metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
+ metadata.max_seq_len_k = (
+ seq_lens.max().item() + self.speculative_num_draft_tokens
+ )

- metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
- torch.arange(
- 0,
- bs * draft_token_num + 1,
- draft_token_num,
- dtype=torch.int32,
- device=device,
- )
- ]
- cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
- cu_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
- )
+ metadata.cu_seqlens_q = torch.arange(
+ 0,
+ bs * self.speculative_num_draft_tokens + 1,
+ self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=device,
  )
- metadata.cu_seqlens_k = cu_k
+
+ metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+ : (bs + 1)
+ ]
+
  metadata.page_table = self.target_verify_metadata["page_table"][
  req_pool_indices, :
  ]

  self.target_verify_metadata[bs] = metadata

+ if encoder_lens is not None:
+ encoder_bs = encoder_lens.numel()
+ metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
+ :encoder_bs
+ ]
+ metadata.encoder_cu_seqlens_k = self.encoder_metadata[
+ "encoder_cu_seqlens_k"
+ ][: (encoder_bs + 1)]
+
+ metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
+ req_pool_indices, :
+ ]
+
  self.forward_metadata = metadata

  def init_forward_metadata_replay_cuda_graph(
@@ -874,24 +985,21 @@ class FlashAttentionBackend(AttentionBackend):
  out_cache_loc: torch.Tensor = None,
  ):
  # """Initialize forward metadata for replaying CUDA graph."""
- device = seq_lens.device
  seq_lens = seq_lens[:bs]
- req_pool_indices = req_pool_indices[:bs]
  seq_lens_cpu = seq_lens_cpu[:bs]
- if forward_mode.is_decode():
+ req_pool_indices = req_pool_indices[:bs]
+ if forward_mode.is_decode_or_idle():
  metadata = self.decode_cuda_graph_metadata[bs]

  if spec_info is not None:
  # Draft Decode
- max_len = seq_lens_cpu.max().item()
- metadata.max_seq_len_k = max_len + (self.step_id + 1)
-
  metadata.cache_seqlens_int32.copy_(
- (seq_lens + (self.step_id + 1)).to(torch.int32)
+ (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
  )

- metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)
-
+ metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+ self.speculative_step_id + 1
+ )
  metadata.cu_seqlens_k.copy_(
  torch.nn.functional.pad(
  torch.cumsum(
@@ -920,31 +1028,24 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.max_seq_len_k + self.page_size - 1
  ) // self.page_size
  page_indices = self.req_to_token[
- :,
- self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+ req_pool_indices[:, None],
+ self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+ None, :
+ ],
  ]
- page_indices = page_indices[req_pool_indices] // self.page_size
+ page_indices //= self.page_size
  metadata.page_table[:, :max_seq_pages].copy_(page_indices)
  metadata.page_table[:, max_seq_pages:].fill_(0)

  elif forward_mode.is_target_verify():
  metadata = self.target_verify_metadata[bs]
- draft_token_num = spec_info.draft_token_num
-
- metadata.cu_seqlens_q.copy_(
- torch.arange(
- 0,
- bs * draft_token_num + 1,
- draft_token_num,
- dtype=torch.int32,
- device=device,
- )
- )
  metadata.cache_seqlens_int32.copy_(
- (seq_lens + draft_token_num).to(torch.int32)
+ (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
  )

- metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
+ metadata.max_seq_len_k = (
+ seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+ )
  metadata.cu_seqlens_k.copy_(
  torch.nn.functional.pad(
  torch.cumsum(
@@ -956,6 +1057,30 @@ class FlashAttentionBackend(AttentionBackend):
  page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
  metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

+ if encoder_lens is not None:
+ # Only support encoder size 1 for now
+ metadata.encoder_max_seq_len_k = encoder_lens[0]
+ metadata.encoder_lens_int32.copy_(encoder_lens[:1])
+ metadata.encoder_cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+ (1, 0),
+ )
+ )
+
+ metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
+ self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
+ )
+
+ # Update the regular page table
+ page_table = self.req_to_token[
+ req_pool_indices,
+ metadata.encoder_max_seq_len_k : (
+ metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+ ),
+ ]
+ metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
  self.forward_metadata = metadata

  def get_cuda_graph_seq_len_fill_value(self):
@@ -972,14 +1097,19 @@ class FlashAttentionMultiStepBackend:
  self.topk = topk
  self.speculative_num_steps = speculative_num_steps

+ # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+ assert (
+ self.topk == 1
+ ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
+
  self.attn_backends = []
  for i in range(self.speculative_num_steps):
  self.attn_backends.append(
  FlashAttentionBackend(
  model_runner,
+ speculative_step_id=i,
  topk=self.topk,
  speculative_num_steps=self.speculative_num_steps,
- step_id=i,
  )
  )

@@ -1004,7 +1134,7 @@ class FlashAttentionMultiStepBackend:
  forward_batch.batch_size * self.topk,
  forward_batch.req_pool_indices,
  forward_batch.seq_lens,
- encoder_lens=None,
+ encoder_lens=forward_batch.encoder_lens,
  forward_mode=ForwardMode.DECODE,
  spec_info=forward_batch.spec_info,
  )
@@ -1021,7 +1151,7 @@ class FlashAttentionMultiStepBackend:
  forward_batch.req_pool_indices,
  forward_batch.seq_lens,
  forward_batch.seq_lens_sum,
- encoder_lens=None,
+ encoder_lens=forward_batch.encoder_lens,
  forward_mode=ForwardMode.DECODE,
  spec_info=forward_batch.spec_info,
  seq_lens_cpu=forward_batch.seq_lens_cpu,