sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
@@ -1,51 +1,56 @@
 from __future__ import annotations

-import numpy as np
-
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-
-"""
-Support different attention backends.
-Now there are three backends: FlashInfer, Triton and FlashAttention.
-Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
-"""
-
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

+import numpy as np
 import torch

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

-from sgl_kernel.flash_attn import flash_attn_with_kvcache
+from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


 @dataclass
 class FlashAttentionMetadata:
     """Metadata to be init once in the model forward pass,
-    each layer's forward pass can reuse the metadata."""
+    each layer's forward pass can reuse the metadata.

-    # Cumulative sequence lengths for query
-    cu_seqlens_q: torch.Tensor = None
-    # Cumulative sequence lengths for key
-    cu_seqlens_k: torch.Tensor = None
+    For each init metadata function, we will try set up them in below order
+    """
+
+    # Sequence lengths for the forward batch
+    cache_seqlens_int32: torch.Tensor = None
     # Maximum sequence length for query
     max_seq_len_q: int = 0
     # Maximum sequence length for key
     max_seq_len_k: int = 0
+    # Cumulative sequence lengths for query
+    cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
+    cu_seqlens_k: torch.Tensor = None
     # Window size (typically used by Gemma)
     window_size: tuple = (-1, -1)
     # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
+
+    # Encoder metadata
+    # Cumulative sequence lengths for encoder key
+    encoder_cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for encoder key
+    encoder_max_seq_len_k: int = 0
     # Sequence lengths for the forward batch
-    cache_seqlens_int32: torch.Tensor = None
+    encoder_lens_int32: torch.Tensor = None
+    # Page table for the encoder
+    encoder_page_table: torch.Tensor = None

     @dataclass
     class LocalAttentionMetadata:
@@ -137,6 +142,16 @@ def make_local_attention_virtual_batches(
         seqlens_k_local: Key sequence lengths for local attention
         block_table_local: Block table for local attention
     """
+    # Adjust attention_chunk_size based on the actual sequence length
+    # to avoid index out of bounds errors
+    max_seq_len = seq_lens_np.max()
+    effective_chunk_size = min(attn_chunk_size, max_seq_len)
+    # Make sure effective_chunk_size is divisible by page_size
+    effective_chunk_size = (effective_chunk_size // page_size) * page_size
+    if effective_chunk_size < page_size:
+        effective_chunk_size = page_size
+    attn_chunk_size = effective_chunk_size
+
     q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
     actual_batch_size = seq_lens_np.shape[0]

@@ -231,7 +246,11 @@ def make_local_attention_virtual_batches(
         np.arange(pages_per_local_batch, dtype=np.int32),
         (virtual_batches, pages_per_local_batch),
     ) + np.expand_dims(block_starts, axis=1)
-    block_indices = block_indices.flatten()
+    # Ensure block_indices doesn't exceed block_table dimensions
+    # This is a critical safety check that prevents index out of bounds errors
+    # when dealing with large sequences (>8192 tokens) or when the block_table
+    # dimensions are smaller than what would be needed for the full attention chunk size.
+    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
     batch_indices = np.repeat(
         np.arange(actual_batch_size, dtype=np.int32),
         local_blocks * pages_per_local_batch,
@@ -270,9 +289,9 @@ class FlashAttentionBackend(AttentionBackend):
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        speculative_step_id=0,
         topk=0,
         speculative_num_steps=0,
-        step_id=0,
     ):
         super().__init__()

@@ -287,20 +306,18 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {}
         self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
-        self.use_mla = (
-            model_runner.model_config.attention_arch == AttentionArch.MLA
-        ) and (not global_server_args_dict["disable_mla"])
+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill

-        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
-        assert (
-            topk <= 1
-        ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
-
-        self.topk = 1
-        self.step_id = step_id
+        self.topk = topk
         self.speculative_num_steps = speculative_num_steps
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.speculative_step_id = speculative_step_id

         # Local attention settings
         self.attention_chunk_size = (
@@ -310,71 +327,63 @@ class FlashAttentionBackend(AttentionBackend):
         )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        """Initialize forward metadata to cache repetitive calculations."""
+        """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-        if forward_batch.forward_mode.is_decode():
-            # Skip Prefill or Draft Decode
-            # Note: Draft Decode will be ran on the Draft Worker
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            # Draft Decode
             if forward_batch.spec_info is not None:
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = torch.arange(
                     0, batch_size + 1, dtype=torch.int32, device=device
                 )
-                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
-                metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
-                    self.step_id + 1
-                )
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-                cache_loc = forward_batch.out_cache_loc.view(
-                    self.speculative_num_steps, -1
-                ).T
-
-                for idx, single_seq_len in enumerate(seq_lens_with_decode):
-                    real_bsz_start_idx = idx
-                    real_bsz_end_idx = idx + 1
-                    metadata.page_table[
-                        real_bsz_start_idx:real_bsz_end_idx,
-                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
-                    ] = cache_loc[
-                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
-                    ]
-            else:  # Normal Decode without Spec Decoding
+
+                self._init_local_attn_metadata(metadata, device)
+            else:
+                # Normal Decode
                 metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
-                metadata.cu_seqlens_q = torch.arange(
-                    0, batch_size + 1, dtype=torch.int32, device=device
-                )
+
+                self._init_local_attn_metadata(metadata, device)
         elif forward_batch.forward_mode.is_target_verify():
-            # Note: Target Verify will be ran on the Target Worker
-            draft_token_num = forward_batch.spec_info.draft_token_num
             metadata.cache_seqlens_int32 = (
-                forward_batch.seq_lens + draft_token_num
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
             ).to(torch.int32)
-            metadata.max_seq_len_q = draft_token_num
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
             metadata.max_seq_len_k = (
-                forward_batch.seq_lens_cpu.max().item() + draft_token_num
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
             )
             metadata.cu_seqlens_q = torch.arange(
                 0,
-                batch_size * draft_token_num + 1,
-                draft_token_num,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
                 dtype=torch.int32,
                 device=device,
             )
@@ -386,79 +395,58 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]

-        elif forward_batch.forward_mode.is_extend_or_draft_extend():
-            # Normal or Draft Extend (Both of them will be ran on the Target Worker)
+        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             metadata.cu_seqlens_k = torch.nn.functional.pad(
                 torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
             )
-            # Precompute maximum sequence length
-            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
-            # Precompute page table
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]

-            # Precompute cumulative sequence lengths
             if (
                 any(forward_batch.extend_prefix_lens_cpu)
                 or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
             ):
                 extend_seq_lens = forward_batch.extend_seq_lens
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
-                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
-                metadata.cu_seqlens_q = metadata.cu_seqlens_k
                 metadata.max_seq_len_q = metadata.max_seq_len_k
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k

             # Setup local attention if enabled
-            if (
-                self.attention_chunk_size is not None
-                and forward_batch.forward_mode == ForwardMode.EXTEND
-            ):
-                # Convert tensors to numpy for local attention processing
-                cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
-                seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
-
-                # Adjust attention_chunk_size based on the actual sequence length
-                # to avoid index out of bounds errors
-                max_seq_len = seq_lens_np.max()
-                effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
-                # Make sure effective_chunk_size is divisible by page_size
-                effective_chunk_size = (
-                    effective_chunk_size // self.page_size
-                ) * self.page_size
-                if effective_chunk_size < self.page_size:
-                    effective_chunk_size = self.page_size
-
-                # Create local attention metadata
-                (
-                    seqlens_q_local_np,
-                    cu_seqlens_q_local_np,
-                    seqlens_k_local_np,
-                    block_table_local,
-                ) = make_local_attention_virtual_batches(
-                    effective_chunk_size,
-                    cu_seqlens_q_np,
-                    seq_lens_np,
-                    metadata.page_table,
-                    self.page_size,
-                )
+            if forward_batch.forward_mode == ForwardMode.EXTEND:
+                self._init_local_attn_metadata(metadata, device)
+
+        # Encoder metadata for cross attention
+        if forward_batch.encoder_lens is not None:
+            assert (
+                forward_batch.encoder_lens.numel() == 1
+            ), "Only encoder size 1 is supported for now"
+
+            metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
+            metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
+            metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
+            metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
+            ]

-                local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
-                        device
-                    ),
-                    local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
-                    local_block_table=block_table_local,
-                    local_max_query_len=seqlens_q_local_np.max(),
-                    local_max_seq_len=seqlens_k_local_np.max(),
-                )
-                metadata.local_attn_metadata = local_metadata
+            # Currently only support forward_batch.encoder_lens.numel() == 1
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices,
+                metadata.encoder_max_seq_len_k : (
+                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                ),
+            ]

-        # Precompute strided indices
+        # Convert the page table to a strided format which is needed by FA3 API
         if self.page_size > 1:
             self.strided_indices = torch.arange(
                 0, metadata.page_table.shape[1], self.page_size, device=self.device
@@ -498,7 +486,7 @@ class FlashAttentionBackend(AttentionBackend):
                         v,
                     )

-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
         metadata = self.forward_metadata

         # Calculate window size (can be moved to metadata if layer properties don't change)
@@ -506,9 +494,18 @@ class FlashAttentionBackend(AttentionBackend):
         # here is two side inclusive
         window_size = (
             (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None
+            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
             else (-1, -1)
         )
+        k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
+        causal = not layer.is_cross_attention

         # Check if we should use local attention
         use_local_attn = (
@@ -536,14 +533,21 @@ class FlashAttentionBackend(AttentionBackend):
         # Use Flash Attention for prefill
         if not self.use_mla:
             # Do multi-head attention
-            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
             value_cache = value_cache.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 k_cache=key_cache,
@@ -554,48 +558,93 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                 max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
-            # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            k_rope = kv_cache[:, :, layer.v_head_dim :]
-            c_kv = kv_cache[:, :, : layer.v_head_dim]
-            k_rope_cache = k_rope.view(
-                -1,
-                self.page_size,
-                layer.tp_k_head_num,
-                layer.head_dim - layer.v_head_dim,
-            )
-            c_kv_cache = c_kv.view(
-                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
-            )
-
-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
-            o = flash_attn_with_kvcache(
-                q=q_rope,
-                k_cache=k_rope_cache,
-                v_cache=c_kv_cache,
-                qv=q_nope,
-                page_table=page_table,
-                cache_seqlens=cache_seqlens,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
-                max_seqlen_q=max_seqlen_q,
-                softmax_scale=layer.scaling,
-                causal=True,
-                softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
-            )
+            if (
+                not global_server_args_dict["disable_chunked_prefix_cache"]
+                and forward_batch.attn_attend_prefix_cache is not None
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                # Do multi-head attention with chunked prefix cache
+
+                if forward_batch.attn_attend_prefix_cache:
+                    # MHA for chunked prefix kv cache when running model with MLA
+                    assert forward_batch.prefix_chunk_idx is not None
+                    assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                    assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+                    chunk_idx = forward_batch.prefix_chunk_idx
+                    assert chunk_idx >= 0
+
+                    output, lse, *rest = flash_attn_varlen_func(
+                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                        cu_seqlens_q=metadata.cu_seqlens_q,
+                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                        max_seqlen_q=metadata.max_seq_len_q,
+                        max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                        softmax_scale=layer.scaling,
+                        causal=False,
+                        return_softmax_lse=True,
+                    )
+                else:
+                    # MHA for extend part of sequence without attending prefix kv cache
+                    output, lse, *rest = flash_attn_varlen_func(
+                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                        cu_seqlens_q=metadata.cu_seqlens_q,
+                        cu_seqlens_k=metadata.cu_seqlens_q,
+                        max_seqlen_q=metadata.max_seq_len_q,
+                        max_seqlen_k=metadata.max_seq_len_q,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                        return_softmax_lse=True,
+                    )
+                return output, lse
+            else:
+                # Do absorbed multi-latent attention
+                kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+                k_rope = kv_cache[:, :, layer.v_head_dim :]
+                c_kv = kv_cache[:, :, : layer.v_head_dim]
+                k_rope_cache = k_rope.view(
+                    -1,
+                    self.page_size,
+                    layer.tp_k_head_num,
+                    layer.head_dim - layer.v_head_dim,
+                )
+                c_kv_cache = c_kv.view(
+                    -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+                )
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]
+                o = flash_attn_with_kvcache(
+                    q=q_rope,
+                    k_cache=k_rope_cache,
+                    v_cache=c_kv_cache,
+                    qv=q_nope,
+                    page_table=page_table,
+                    cache_seqlens=cache_seqlens,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                    max_seqlen_q=max_seqlen_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )

-        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def forward_decode(
         self,
@@ -606,8 +655,6 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ) -> torch.Tensor:
-        """Forward pass with FlashAttention using precomputed metadata."""
-        # Save KV cache if needed
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -628,25 +675,39 @@ class FlashAttentionBackend(AttentionBackend):
                         v,
                     )

-        # Use precomputed metadata
+        # Use precomputed metadata across all layers
         metadata = self.forward_metadata
+        local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
+        use_local_attention = (
+            self.attention_chunk_size is not None and local_attn_metadata is not None
+        )

         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
         window_size = (
             (layer.sliding_window_size, 0)
-            if layer.sliding_window_size is not None
+            if layer.sliding_window_size is not None and layer.sliding_window_size > -1
             else (-1, -1)
         )
-        page_table = metadata.page_table
+        causal = not layer.is_cross_attention
+
+        k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto":
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)

         if not self.use_mla:
             # Do multi-head attention

-            # Get KV cache
-            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
             key_cache = key_cache.view(
                 -1, self.page_size, layer.tp_k_head_num, layer.head_dim
             )
@@ -654,24 +715,60 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )

-            # Pre-reshape query tensor
-            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            o = flash_attn_with_kvcache(
-                q=q_reshaped,
-                k_cache=key_cache,
-                v_cache=value_cache,
-                page_table=page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
-                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
-                max_seqlen_q=1,
-                softmax_scale=layer.scaling,
-                causal=True,
-                window_size=window_size,
-                softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
-            )
+            if layer.is_cross_attention:
+                # Always use non-chunked logic for cross-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=metadata.encoder_page_table,
+                    cache_seqlens=metadata.encoder_lens_int32,
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
+                    max_seqlen_q=1,
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    window_size=(-1, -1),
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
+            elif use_local_attention:
+                # Use chunked (local) attention batching for self-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=local_attn_metadata.local_block_table,
+                    cache_seqlens=local_attn_metadata.local_seqused_k,
+                    cu_seqlens_q=local_attn_metadata.local_query_start_loc,
+                    cu_seqlens_k_new=metadata.cu_seqlens_k,
+                    max_seqlen_q=local_attn_metadata.local_max_query_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=(-1, -1),
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
+            else:
+                # Default: single-token self-attention
+                o = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=metadata.page_table,
+                    cache_seqlens=metadata.cache_seqlens_int32,
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k_new=metadata.cu_seqlens_k,
+                    max_seqlen_q=1,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                )
         else:
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -696,7 +793,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_cache=k_rope_cache,
                 v_cache=c_kv_cache,
                 qv=q_nope,
-                page_table=page_table,
+                page_table=metadata.page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -704,8 +801,8 @@ class FlashAttentionBackend(AttentionBackend):
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

@@ -719,7 +816,13 @@ class FlashAttentionBackend(AttentionBackend):
        to avoid memory allocations.
        """
        self.decode_cuda_graph_metadata = {
-            # Page table for token mapping (batch_size, max_context_len)
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
            "page_table": torch.zeros(
                max_bs,
                (self.max_context_len + self.page_size - 1) // self.page_size,
@@ -735,35 +838,42 @@ class FlashAttentionBackend(AttentionBackend):
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
+        }
+
+        self.target_verify_metadata = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.arange(
-                0, max_bs + 128, dtype=torch.int32, device=self.device
+            "cu_seqlens_q": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
            ),
            "cu_seqlens_k": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
+                max_bs + 1, dtype=torch.int32, device=self.device
            ),
-        }
-
-        self.target_verify_metadata = {
            "page_table": torch.zeros(
                max_bs,
                (self.max_context_len + self.page_size - 1) // self.page_size,
                dtype=torch.int32,
                device=self.device,
            ),
-            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "cu_seqlens_k": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "max_seqlen_q": 0,
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
        }

+        self.encoder_metadata = {
+            "encoder_page_table": torch.zeros(
+                max_bs,
+                self.max_context_len,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "encoder_lens_int32": torch.zeros(
+                max_bs, dtype=torch.int32, device=self.device
+            ),
+            "encoder_cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+        }
+
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
@@ -777,27 +887,24 @@ class FlashAttentionBackend(AttentionBackend):
        """Initialize forward metadata for capturing CUDA graph."""
        metadata = FlashAttentionMetadata()
        device = seq_lens.device
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
            if spec_info is not None:
                # Draft Decode
-                metadata.cu_seqlens_q = torch.arange(
-                    0, bs + 1, dtype=torch.int32, device=device
-                )
                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                    "cache_seqlens"
                ][:bs]
-
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                    : bs + 1
                ]
-
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
-                metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
                metadata.page_table = self.decode_cuda_graph_metadata[
                    "page_table_draft_decode"
                ][req_pool_indices, :]
@@ -822,43 +929,49 @@ class FlashAttentionBackend(AttentionBackend):
                )
            self.decode_cuda_graph_metadata[bs] = metadata
        elif forward_mode.is_target_verify():
-            draft_token_num = spec_info.draft_token_num
-
            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
                :bs
            ]
            metadata.cache_seqlens_int32.copy_(
-                (seq_lens + draft_token_num).to(torch.int32)
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
            )

-            metadata.max_seq_len_q = draft_token_num
-            metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )

-            metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
-                torch.arange(
-                    0,
-                    bs * draft_token_num + 1,
-                    draft_token_num,
-                    dtype=torch.int32,
-                    device=device,
-                )
-            ]
-            cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
-            cu_k.copy_(
-                torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
            )
-            metadata.cu_seqlens_k = cu_k
+
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+
            metadata.page_table = self.target_verify_metadata["page_table"][
                req_pool_indices, :
            ]

            self.target_verify_metadata[bs] = metadata

+        if encoder_lens is not None:
+            encoder_bs = encoder_lens.numel()
+            metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
+                :encoder_bs
+            ]
+            metadata.encoder_cu_seqlens_k = self.encoder_metadata[
+                "encoder_cu_seqlens_k"
+            ][: (encoder_bs + 1)]
+
+            metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
+                req_pool_indices, :
+            ]
+
        self.forward_metadata = metadata

    def init_forward_metadata_replay_cuda_graph(
@@ -874,24 +987,23 @@ class FlashAttentionBackend(AttentionBackend):
        out_cache_loc: torch.Tensor = None,
    ):
        # """Initialize forward metadata for replaying CUDA graph."""
-        device = seq_lens.device
        seq_lens = seq_lens[:bs]
-        req_pool_indices = req_pool_indices[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
-        if forward_mode.is_decode():
+        req_pool_indices = req_pool_indices[:bs]
+        device = seq_lens.device
+
+        if forward_mode.is_decode_or_idle():
            metadata = self.decode_cuda_graph_metadata[bs]

            if spec_info is not None:
                # Draft Decode
-                max_len = seq_lens_cpu.max().item()
-                metadata.max_seq_len_k = max_len + (self.step_id + 1)
-
                metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.step_id + 1)).to(torch.int32)
+                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
                )

-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)
-
+                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
                metadata.cu_seqlens_k.copy_(
                    torch.nn.functional.pad(
                        torch.cumsum(
@@ -906,6 +1018,8 @@ class FlashAttentionBackend(AttentionBackend):
                ]

                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+                self._init_local_attn_metadata(metadata, device)
            else:
                # Normal Decode
                max_len = seq_lens_cpu.max().item()
@@ -920,31 +1034,25 @@ class FlashAttentionBackend(AttentionBackend):
                    metadata.max_seq_len_k + self.page_size - 1
                ) // self.page_size
                page_indices = self.req_to_token[
-                    :,
-                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+                    req_pool_indices[:, None],
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                        None, :
+                    ],
                ]
-                page_indices = page_indices[req_pool_indices] // self.page_size
+                page_indices //= self.page_size
                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                metadata.page_table[:, max_seq_pages:].fill_(0)

+                self._init_local_attn_metadata(metadata, device)
        elif forward_mode.is_target_verify():
            metadata = self.target_verify_metadata[bs]
-            draft_token_num = spec_info.draft_token_num
-
-            metadata.cu_seqlens_q.copy_(
-                torch.arange(
-                    0,
-                    bs * draft_token_num + 1,
-                    draft_token_num,
-                    dtype=torch.int32,
-                    device=device,
-                )
-            )
            metadata.cache_seqlens_int32.copy_(
-                (seq_lens + draft_token_num).to(torch.int32)
+                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
            )

-            metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
            metadata.cu_seqlens_k.copy_(
                torch.nn.functional.pad(
                    torch.cumsum(
@@ -956,12 +1064,72 @@ class FlashAttentionBackend(AttentionBackend):
            page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

+        if encoder_lens is not None:
+            # Only support encoder size 1 for now
+            metadata.encoder_max_seq_len_k = encoder_lens[0]
+            metadata.encoder_lens_int32.copy_(encoder_lens[:1])
+            metadata.encoder_cu_seqlens_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                    (1, 0),
+                )
+            )
+
+            metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
+                self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
+            )
+
+            # Update the regular page table
+            page_table = self.req_to_token[
+                req_pool_indices,
+                metadata.encoder_max_seq_len_k : (
+                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                ),
+            ]
+            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
        self.forward_metadata = metadata

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for sequence length in CUDA graph."""
        return 0

+    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
+        """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
+        if self.attention_chunk_size is None:
+            metadata.local_attn_metadata = None
+            return
+
+        cu_seqlens_q = metadata.cu_seqlens_q
+        cache_seqlens_int32 = metadata.cache_seqlens_int32
+        page_table = metadata.page_table
+        if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
+            metadata.local_attn_metadata = None
+            return
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seq_lens_np = cache_seqlens_int32.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seq_lens_np,
+            page_table,
+            self.page_size,
+        )
+        local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+            local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
+            local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
+            local_block_table=block_table_local.to(device),
+            local_max_query_len=int(seqlens_q_local_np.max()),
+            local_max_seq_len=int(seqlens_k_local_np.max()),
+        )
+        metadata.local_attn_metadata = local_metadata
+


 class FlashAttentionMultiStepBackend:
@@ -972,14 +1140,19 @@ class FlashAttentionMultiStepBackend:
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps

+        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+        assert (
+            self.topk == 1
+        ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
+
        self.attn_backends = []
        for i in range(self.speculative_num_steps):
            self.attn_backends.append(
                FlashAttentionBackend(
                    model_runner,
+                    speculative_step_id=i,
                    topk=self.topk,
                    speculative_num_steps=self.speculative_num_steps,
-                    step_id=i,
                )
            )

@@ -1004,7 +1177,7 @@ class FlashAttentionMultiStepBackend:
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )
@@ -1021,7 +1194,7 @@ class FlashAttentionMultiStepBackend:
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.seq_lens_cpu,