sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,513 @@
+from __future__ import annotations
+
+"""
+end to end attention solution with aiter kernels
+"""
+
+import math
+import os
+from dataclasses import dataclass
+from enum import Enum, auto
+from functools import partial
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.global_config import global_config
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+try:
+    from aiter import mha_batch_prefill_func, paged_attention_ragged
+except ImportError:
+    print(
+        "aiter is an AMD-specific kernel library. Please make sure aiter is installed on your AMD device."
+    )
+
+
+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
+@dataclass
+class ForwardMetadata:
+    kv_indptr: torch.Tensor
+    kv_indices: torch.Tensor
+    max_q_len: int
+    max_kv_len: int
+
+
+global_workspace_buffer = None
+
+_AITER_PARTITION_SIZE_ROCM = 256
+
+
+class AiterAttnBackend(AttentionBackend):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__()
+
+        self.device = model_runner.device
+        self.is_multimodal = model_runner.model_config.is_multimodal
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+        # Parse constants
+        self.max_context_len = model_runner.model_config.context_len
+        self.skip_prefill = skip_prefill
+
+        max_bs = model_runner.req_to_token_pool.size
+
+        if kv_indptr_buf is None:
+            self.kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            self.kv_indptr = kv_indptr_buf
+
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
+        self.qo_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
+
+        # Create prefill indices updater
+        if not skip_prefill:
+            self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
+                model_runner, self
+            )
+
+        # aiter kernel related initialization
+        self.max_num_partitions = (
+            self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
+        ) // _AITER_PARTITION_SIZE_ROCM
+
+        nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
+
+        self.workspace_buffer = torch.empty(
+            (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
+            * nbyes_per_qo_elem
+            + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
+            dtype=torch.uint8,
+            device=self.device,
+        )
+
+        self.scale = float(1.0 / (self.head_dim**0.5))
+        self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
+            self.device
+        )
+        self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
+            self.device
+        )
+
+        self.logits_soft_cap = 0.0
+
+        self.forward_metadata: ForwardMetadata = None
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode_or_idle():
+            # update for aiter
+            # create kv_indices and kv_indptr
+            bs = forward_batch.batch_size
+            kv_indptr = self.kv_indptr
+            spec_info = forward_batch.spec_info
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.zeros(
+                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+                bs = kv_indptr.shape[0] - 1
+
+            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
+
+        elif forward_batch.forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = ForwardMetadata(
+                self.indices_updater_prefill.kv_indptr,
+                self.indices_updater_prefill.kv_indices,
+                self.indices_updater_prefill.max_q_len,
+                self.indices_updater_prefill.max_kv_len,
+            )
+        elif forward_batch.forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = ForwardMetadata(
+                self.indices_updater_prefill.kv_indptr,
+                self.indices_updater_prefill.kv_indices,
+                self.indices_updater_prefill.max_q_len,
+                self.indices_updater_prefill.max_kv_len,
+            )
+        else:
+            prefix_lens = forward_batch.extend_prefix_lens
+
+            if self.is_multimodal:
+                extend_no_prefix = False
+            else:
+                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
+
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=None,
+            )
+            self.forward_metadata = ForwardMetadata(
+                self.indices_updater_prefill.kv_indptr,
+                self.indices_updater_prefill.kv_indices,
+                self.indices_updater_prefill.max_q_len,
+                self.indices_updater_prefill.max_kv_len,
+            )
+
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
+        if kv_indices_buf is None:
+            self.cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_kv_indices = kv_indices_buf
+
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device=self.device,
+            )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = self.cuda_graph_kv_indices
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
+
+        elif forward_mode.is_target_verify():
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.forward_metadata = ForwardMetadata(
+                self.indices_updater_prefill.kv_indptr,
+                self.indices_updater_prefill.kv_indices,
+                self.indices_updater_prefill.max_q_len,
+                self.indices_updater_prefill.max_kv_len,
+            )
+
+        else:
+            raise ValueError(f"Invalid mode: {forward_mode=}")
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        if forward_mode.is_decode_or_idle():
+            kv_indptr = self.kv_indptr
+            kv_indices = self.cuda_graph_kv_indices
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices[:bs],
+                    seq_lens[:bs],
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
+                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+
+        elif forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        else:
+            raise ValueError("Invalid forward mode")
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )
+
+        self.logits_soft_cap = layer.logit_cap
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+
+        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+
+        bs0 = forward_batch.batch_size + 1
+
+        o = mha_batch_prefill_func(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            k_cache,
+            v_cache,
+            self.qo_indptr[:bs0],
+            self.forward_metadata.kv_indptr[:bs0],
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.max_q_len,
+            self.forward_metadata.max_kv_len,
+            causal=True,
+            logits_soft_cap=self.logits_soft_cap,
+            alibi_slopes=None,
+            return_lse=False,
+            return_attn_probs=False,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        self.logits_soft_cap = layer.logit_cap
+        paged_attention_ragged(
+            o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            self.workspace_buffer,
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
+                -1, 1, layer.tp_k_head_num, layer.qk_head_dim
+            ),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
+                -1, 1, layer.tp_v_head_num, layer.v_head_dim
+            ),
+            self.scale,
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.kv_last_page_lens,
+            1,
+            self.max_num_partitions,
+            None,
+            "auto",
+            "NHD",
+            self.logits_soft_cap,
+            self.k_scale,
+            self.v_scale,
+            None,
+            _AITER_PARTITION_SIZE_ROCM,
+        )
+
+        return o
+
+
+class AiterIndicesUpdaterPrefill:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Parse Constants
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.sliding_window_size = model_runner.sliding_window_size
+        self.attn_backend = attn_backend
+
+        # Buffers and wrappers
+        self.kv_indptr = attn_backend.kv_indptr
+        self.kv_last_page_len = attn_backend.kv_last_page_len
+        self.qo_indptr = attn_backend.qo_indptr
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.update = self.update_single_wrapper
+
+        self.kv_indices = None
+        self.max_q_len = 0
+        self.max_kv_len = 0
+
+    def update(
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
+    ):
+        # Keep the signature for type checking. It will be assigned during runtime.
+        raise NotImplementedError()
+
+    def update_single_wrapper(
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
+    ):
+
+        kv_start_idx = None
+        kv_indptr = self.kv_indptr
+        qo_indptr = self.qo_indptr
+        paged_kernel_lens = seq_lens
+        paged_kernel_lens_sum = seq_lens_sum
+
+        bs = len(req_pool_indices)
+        if spec_info is None:
+            # Normal extend
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum + 256,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
+
+            self.max_kv_len = torch.max(paged_kernel_lens).item()
+
+            extend_lens = seq_lens - prefix_lens
+            self.max_q_len = torch.max(extend_lens).item()
+
+            qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+        else:
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    self.req_to_token,
+                )
+            )
+
+        self.kv_indices = kv_indices
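
For orientation, here is a small standalone sketch (not part of the package) of how the decode workspace allocated in AiterAttnBackend.__init__ above is sized; the batch size, context length, and head shape below are hypothetical example values.

# Hypothetical sizes, only to illustrate the arithmetic in AiterAttnBackend.__init__.
PARTITION_SIZE = 256          # mirrors _AITER_PARTITION_SIZE_ROCM
max_bs, context_len = 128, 8192
num_head, head_dim = 32, 128

# Ceil-divide the context into fixed-size partitions, as the backend does.
max_num_partitions = (context_len + PARTITION_SIZE - 1) // PARTITION_SIZE  # 32

fp32_bytes = 4
# float32 partial outputs: one head_dim vector per (request, head, partition) ...
out_bytes = max_bs * num_head * max_num_partitions * head_dim * fp32_bytes
# ... plus two float32 values per (request, head, partition), presumably the
# per-partition softmax statistics.
aux_bytes = 2 * max_bs * num_head * max_num_partitions * fp32_bytes
workspace_bytes = out_bytes + aux_bytes  # 68,157,440 bytes (~65 MiB), held once as a uint8 buffer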
@@ -918,8 +918,11 @@ class FlashAttentionBackend(AttentionBackend):
             and local_attn_metadata is not None
             and (hasattr(layer, "use_irope") and layer.use_irope)
         )
-        # We do cascade attention for Draft Decode with topk > 1
-        use_cascade_attn = self.topk > 1
+
+        # When spec decode is enabled, forward_decode is called in two modes:
+        # 1. DRAFT_DECODE: enable cascade attention when topk > 1
+        # 2. IDLE: no cascade attention is needed; spec_info is None in this case
+        use_cascade_attn = forward_batch.spec_info is not None and self.topk > 1
 
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
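
The intent of the new guard can be restated as a tiny helper (illustrative only; this function name is not part of the backend):

def should_use_cascade_attn(spec_info, topk: int) -> bool:
    # DRAFT_DECODE batches carry spec_info, so cascade attention applies for topk > 1;
    # IDLE batches pass spec_info=None and must never take the cascade path.
    return spec_info is not None and topk > 1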
@@ -1165,7 +1168,6 @@ class FlashAttentionBackend(AttentionBackend):
         max_virtual_batches = max_bs * (
             (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
         )
-        max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
         max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
 
         self.decode_cuda_graph_local_attn_metadata = {
@@ -1177,7 +1179,7 @@ class FlashAttentionBackend(AttentionBackend):
             ),
             "local_block_table": torch.zeros(
                 max_virtual_batches,
-                max_blocks_per_seq * max_pages_per_block,
+                max_pages_per_block,
                 dtype=torch.int32,
                 device=self.device,
             ),
@@ -1435,19 +1437,7 @@ class FlashAttentionBackend(AttentionBackend):
             self.decode_cuda_graph_metadata[bs] = metadata
 
             if self.attention_chunk_size is not None:
-                metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
-                        "local_query_start_loc"
-                    ],
-                    local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
-                        "local_seqused_k"
-                    ],
-                    local_block_table=self.decode_cuda_graph_local_attn_metadata[
-                        "local_block_table"
-                    ],
-                    local_max_query_len=1,
-                    local_max_seq_len=1,
-                )
+                self._update_local_attn_metadata_for_capture(metadata, batch_size)
 
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
@@ -1808,6 +1798,62 @@ class FlashAttentionBackend(AttentionBackend):
         )
         metadata.local_attn_metadata = local_metadata
 
+    def _update_local_attn_metadata_for_capture(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update local attention metadata during CUDA graph capture phase.
+
+        This method calculates the exact buffer sizes needed for local attention metadata
+        during the CUDA graph capture phase, optimizing memory usage by creating views of
+        pre-allocated buffers with exactly the sizes needed.
+        """
+        seq_lens_capture = metadata.cache_seqlens_int32
+        max_seq_len = int(seq_lens_capture.max().item())
+        page_table_capture = metadata.page_table
+
+        cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
+        seqlens_np = seq_lens_capture.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local_np,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            page_table_capture,
+            self.page_size,
+        )
+
+        # Get exact dimensions from the calculation
+        q_len = len(cu_seqlens_q_local_np)
+        k_len = len(seqlens_k_local_np)
+        b0 = block_table_local_np.shape[0] if block_table_local_np.shape[0] > 0 else bs
+        b1 = block_table_local_np.shape[1] if block_table_local_np.shape[1] > 0 else 1
+
+        # Create views of the pre-allocated buffers with exactly these sizes
+        # This is the key optimization - we only use the memory we actually need
+        local_query_start_loc = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ][:q_len]
+
+        local_seqused_k = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"][
+            :k_len
+        ]
+
+        local_block_table = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ][:b0, :b1]
+
+        metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+            local_query_start_loc=local_query_start_loc,
+            local_seqused_k=local_seqused_k,
+            local_block_table=local_block_table,
+            local_max_query_len=1,
+            local_max_seq_len=max_seq_len,
+        )
+
     def _update_local_attn_metadata_for_replay(
         self, metadata: FlashAttentionMetadata, bs: int
     ):
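
The capture-time helper above follows a common CUDA-graph pattern: allocate metadata buffers once at their worst-case size, then hand the kernel exact-size views of them. A minimal standalone sketch of that pattern, with made-up buffer names and sizes:

import torch

# Allocated once, at worst-case size (as init_cuda_graph_state does).
buffers = {
    "local_query_start_loc": torch.zeros(1025, dtype=torch.int32),
    "local_seqused_k": torch.zeros(1024, dtype=torch.int32),
    "local_block_table": torch.zeros(1024, 64, dtype=torch.int32),
}

def views_for_capture(q_len: int, k_len: int, b0: int, b1: int):
    # Slicing returns views into the same storage, so the captured graph keeps
    # valid pointers and replay only has to refresh the buffer contents.
    return (
        buffers["local_query_start_loc"][:q_len],
        buffers["local_seqused_k"][:k_len],
        buffers["local_block_table"][:b0, :b1],
    )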
@@ -346,7 +346,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
         # Save kv cache
         if save_kv_cache and k is not None:
@@ -381,6 +380,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             )
         else:
             # mla paged prefill
+            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+                q.dtype
+            )
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
@@ -442,7 +444,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_nope = reshaped_q[:, :, : layer.v_head_dim]
         q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
-        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+            q.dtype
+        )
 
         o = q_nope.new_empty(q_nope.shape)
         # Direct call to run without the wrapper
@@ -467,7 +471,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
         self.scaling = model_runner.model_config.scaling
-        self.data_type = model_runner.kv_cache_dtype
+        self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
 
         # Buffers and wrappers
@@ -577,7 +581,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
         self.v_head_dim = model_runner.model_config.v_head_dim
         self.scaling = model_runner.model_config.scaling
-        self.data_type = model_runner.kv_cache_dtype
+        self.data_type = model_runner.dtype
         self.q_data_type = model_runner.dtype
         self.attn_backend = attn_backend
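
The .to(q.dtype) casts above keep the MLA prefill inputs in a single dtype when the KV cache is stored in a lower-precision format than the activations. A minimal illustration, not taken from the package (the fp8 storage dtype and shapes are only examples):

import torch

q = torch.randn(16, 128, dtype=torch.bfloat16)
k_cache = torch.randn(16, 128).to(torch.float8_e4m3fn)  # KV cache kept in fp8

k_for_prefill = k_cache.to(q.dtype)  # cast the cache view up to the query dtype
scores = q @ k_for_prefill.T         # mixing bf16 with fp8 here would fail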