sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/dual_chunk_flashattention_backend.py (new file, +1700)
@@ -0,0 +1,1700 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Attention layer with Dual chunk flash attention and sparse attention.
3
+ """
4
+ import functools
5
+ import logging
6
+ import math
7
+ from dataclasses import dataclass
8
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
13
+ from sgl_kernel.sparse_flash_attn import (
14
+ convert_vertical_slash_indexes,
15
+ convert_vertical_slash_indexes_mergehead,
16
+ sparse_attn_func,
17
+ )
18
+
19
+ from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
20
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
21
+ from sglang.srt.layers.attention.flashattention_backend import FlashAttentionMetadata
22
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
23
+
24
+ if TYPE_CHECKING:
25
+ from sglang.srt.layers.radix_attention import RadixAttention
26
+ from sglang.srt.model_executor.model_runner import ModelRunner
27
+
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ @dataclass
33
+ class DualChunkFlashAttentionMetadata:
34
+ """Metadata for FlashAttentionBackend.
35
+
36
+ NOTE: Any python object stored here is not updated when it is
37
+ cuda-graph replayed. If you have values that need to be changed
38
+ dynamically, they should be stored in a tensor. The tensor has to be
39
+ updated from `CUDAGraphRunner.forward` API.
40
+ """
41
+
42
+ # (batch_size,). The sequence length per sequence. Sequence length means
43
+ # the computed tokens + new tokens. None if it is a decode-only batch.
44
+ seq_lens: Optional[List[int]] = None
45
+ # seq_lens stored as a tensor.
46
+ seq_lens_tensor: Optional[torch.Tensor] = None
47
+ # Maximum sequence length in the prefill batch. 0 if there are decoding
48
+ # requests only.
49
+ max_seq_len: int = None
50
+
51
+ # (batch_size,). The orig sequence length per sequence.
52
+ orig_seq_lens: Optional[List[int]] = None
53
+
54
+ # orig_seq_lens stored as a tensor.
55
+ orig_seq_lens_tensor: Optional[torch.Tensor] = None
56
+
57
+ # Block addresses per sequence. (Seq id -> list of physical block)
58
+ # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
59
+ # in the kv cache. Each block can contain up to block_size tokens.
60
+ # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
61
+ # captured.
62
+ block_tables: Optional[torch.Tensor] = None
63
+
64
+ # (batch_size + 1,). The cumulative subquery lengths of the sequences in
65
+ # the batch, used to index into subquery. E.g., if the subquery length
66
+ # is [4, 6], it is [0, 4, 10].
67
+ query_start_loc: Optional[torch.Tensor] = None
68
+ # (batch_size + 1,). The cumulative sequence lengths of the sequences in
69
+ # the batch, used to index into sequence. E.g., if the sequence length is
70
+ # [4, 6], it is [0, 4, 10].
71
+ seq_start_loc: Optional[torch.Tensor] = None
72
+
73
+ # Length scaling factor
74
+ scaling_factor: Optional[torch.Tensor] = None
75
+
76
+ # (batch_size,). Sequence lengths for intra attention.
77
+ seq_lens_intra: Optional[torch.Tensor] = None
78
+
79
+ # Max sequence length for intra attention.
80
+ max_seq_len_intra: Optional[int] = None
81
+
82
+ # (batch_size, num_blocks). Block table for intra attention.
83
+ block_tables_intra: Optional[torch.Tensor] = None
84
+
85
+ # (batch_size,). Sequence lengths for succ attention.
86
+ seq_lens_succ: Optional[torch.Tensor] = None
87
+
88
+ # Max sequence length for succ attention.
89
+ max_seq_len_succ: Optional[int] = None
90
+
91
+ # (batch_size, num_blocks). Block table for succ attention.
92
+ block_tables_succ: Optional[torch.Tensor] = None
93
+
94
+ # (batch_size,). Sequence lengths for inter attention.
95
+ seq_lens_inter: Optional[torch.Tensor] = None
96
+
97
+ # Max sequence length for inter attention.
98
+ max_seq_len_inter: Optional[int] = None
99
+
100
+
101
+ class DualChunkFlashAttentionBackend(AttentionBackend):
102
+ def __init__(
103
+ self,
104
+ model_runner: "ModelRunner",
105
+ ) -> None:
106
+ self.forward_metadata: FlashAttentionMetadata = None
107
+ self.device = model_runner.device
108
+ self.max_context_len = model_runner.model_config.context_len
109
+ self.num_heads = model_runner.model_config.get_num_attention_heads(
110
+ model_runner.server_args.tp_size
111
+ )
112
+ self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
113
+ model_runner.server_args.tp_size
114
+ )
115
+ self.head_size = model_runner.model_config.head_dim
116
+
117
+ self.req_to_token = model_runner.req_to_token_pool.req_to_token
118
+ self.kv_cache_dtype = model_runner.kv_cache_dtype
119
+ self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
120
+ self.page_size = model_runner.page_size
121
+
122
+ assert self.num_heads % self.num_kv_heads == 0
123
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
124
+
125
+ dual_chunk_attention_config = getattr(
126
+ model_runner.model_config.hf_config, "dual_chunk_attention_config", None
127
+ )
128
+ assert dual_chunk_attention_config is not None
129
+ self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192)
130
+ self.local_size = dual_chunk_attention_config.get("local_size", 1024)
131
+ self.original_max_position_embeddings = dual_chunk_attention_config.get(
132
+ "original_max_position_embeddings", 0
133
+ )
134
+ self.sparse_attention_config = dual_chunk_attention_config.get(
135
+ "sparse_attention_config", None
136
+ )
137
+ if not self.sparse_attention_config:
138
+ logger.warning_once(
139
+ "Sparse attention will not be enabled as "
140
+ "sparse attention config is not provided."
141
+ )
142
+ self.sparse_attention_enabled = dual_chunk_attention_config.get(
143
+ "sparse_attention_enabled", self.sparse_attention_config is not None
144
+ )
145
+ self.sparse_attention_threshold = dual_chunk_attention_config.get(
146
+ "sparse_attention_threshold", 32768
147
+ )
148
+ self.sparse_attention_last_q = dual_chunk_attention_config.get(
149
+ "sparse_attention_last_q", 64
150
+ )
151
+ self.dual_chunk_attention_config = dual_chunk_attention_config
152
+
153
+ if self.sparse_attention_enabled:
154
+ self.arange = torch.arange(self.sparse_attention_last_q, device="cuda")
155
+ self.last_q_mask = (
156
+ self.arange[None, None, :, None] >= self.arange[None, None, None, :]
157
+ )
158
+
159
+ @functools.lru_cache()
160
+ def get_sparse_attention_config(self, layer_idx) -> List[Dict[str, Any]]:
161
+ layer_sparse_attention_config = {
162
+ int(i): j for i, j in self.sparse_attention_config[layer_idx].items()
163
+ }
164
+ start_head = self.num_heads * get_tensor_model_parallel_rank()
165
+ end_head = start_head + self.num_heads
166
+ return [layer_sparse_attention_config[i] for i in range(start_head, end_head)]
167
+
168
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
169
+ """Initialize forward metadata hence all layers in the forward pass can reuse it."""
170
+
171
+ forward_mode: ForwardMode = forward_batch.forward_mode
172
+ assert forward_mode.is_prefill() or forward_mode.is_decode()
173
+ batch_size = forward_batch.batch_size
174
+
175
+ metadata = DualChunkFlashAttentionMetadata()
176
+ metadata.seq_lens_tensor = forward_batch.seq_lens.to(torch.int32)
177
+ metadata.seq_lens = forward_batch.seq_lens.tolist()
178
+ metadata.max_seq_len = forward_batch.seq_lens.max().item()
179
+
180
+ metadata.orig_seq_lens_tensor = forward_batch.orig_seq_lens
181
+ metadata.orig_seq_lens = forward_batch.orig_seq_lens.tolist()
182
+
183
+ metadata.block_tables = forward_batch.req_to_token_pool.req_to_token[
184
+ forward_batch.req_pool_indices, : metadata.max_seq_len
185
+ ]
186
+ # Convert the block table to a strided format.
187
+ if self.page_size > 1:
188
+ strided_indices = torch.arange(
189
+ 0, metadata.block_tables.shape[1], self.page_size, device=self.device
190
+ )
191
+ metadata.block_tables = (
192
+ metadata.block_tables[:, strided_indices] // self.page_size
193
+ )
194
+
195
+ metadata.query_start_loc = torch.zeros(
196
+ batch_size + 1, dtype=torch.int32, device=metadata.seq_lens_tensor.device
197
+ )
198
+ if forward_mode.is_prefill():
199
+ metadata.query_start_loc[1:] = torch.cumsum(
200
+ forward_batch.extend_seq_lens.to(torch.int32), dim=0, dtype=torch.int32
201
+ )
202
+ else:
203
+ metadata.query_start_loc[1:] = torch.cumsum(
204
+ torch.arange(
205
+ batch_size,
206
+ dtype=metadata.query_start_loc.dtype,
207
+ device=metadata.query_start_loc.device,
208
+ ),
209
+ dim=0,
210
+ dtype=torch.int32,
211
+ )
212
+ metadata.seq_start_loc = torch.zeros(
213
+ batch_size + 1, dtype=torch.int32, device=metadata.seq_lens_tensor.device
214
+ )
215
+ metadata.seq_start_loc[1:] = torch.cumsum(
216
+ metadata.seq_lens_tensor, dim=0, dtype=torch.int32
217
+ )
218
+
219
+ if self.original_max_position_embeddings > 0:
220
+ if forward_mode.is_prefill():
221
+ metadata.scaling_factor = (
222
+ 0.1
223
+ * torch.log(
224
+ metadata.orig_seq_lens_tensor
225
+ / self.original_max_position_embeddings
226
+ )
227
+ + 1.0
228
+ ).clip(min=1)
229
+ else:
230
+ metadata.scaling_factor = (
231
+ 0.1
232
+ * torch.log(
233
+ metadata.orig_seq_lens_tensor
234
+ / self.original_max_position_embeddings
235
+ )
236
+ + 1.0
237
+ ).clip(min=1)
238
+
239
+ if forward_mode.is_decode():
240
+ cache_seq_lens = metadata.orig_seq_lens_tensor
241
+
242
+ chunk_len = self.chunk_size - self.local_size
243
+ chunk_num_curr = (cache_seq_lens - 1) // chunk_len
244
+
245
+ seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len
246
+ max_seq_len_intra = seq_lens_intra.max().item()
247
+ metadata.seq_lens_intra = seq_lens_intra
248
+ metadata.max_seq_len_intra = max_seq_len_intra
249
+
250
+ block_tables_intra = torch.zeros(
251
+ batch_size,
252
+ (max_seq_len_intra - 1) // self.page_size + 1,
253
+ dtype=metadata.block_tables.dtype,
254
+ device=metadata.block_tables.device,
255
+ )
256
+ for i in range(batch_size):
257
+ st = chunk_num_curr[i] * chunk_len // self.page_size
258
+ ed = min(
259
+ st + (max_seq_len_intra - 1) // self.page_size + 1,
260
+ (cache_seq_lens[i] - 1) // self.page_size + 1,
261
+ )
262
+ block_tables_intra[i, : ed - st] = metadata.block_tables[i, st:ed]
263
+ metadata.block_tables_intra = block_tables_intra
264
+
265
+ metadata.seq_lens_succ = (
266
+ chunk_num_curr - (chunk_num_curr - 1).clip(min=0)
267
+ ) * chunk_len
268
+ metadata.max_seq_len_succ = metadata.seq_lens_succ.max().item()
269
+ if metadata.max_seq_len_succ:
270
+ block_tables_succ = torch.zeros(
271
+ batch_size,
272
+ (metadata.max_seq_len_succ - 1) // self.page_size + 1,
273
+ dtype=metadata.block_tables.dtype,
274
+ device=metadata.block_tables.device,
275
+ )
276
+ for i in range(batch_size):
277
+ start = (
278
+ (chunk_num_curr[i] - 1).clip(min=0)
279
+ * chunk_len
280
+ // self.page_size
281
+ )
282
+ end = min(
283
+ start + (metadata.max_seq_len_succ - 1) // self.page_size + 1,
284
+ (cache_seq_lens[i] - 1) // self.page_size + 1,
285
+ )
286
+ block_tables_succ[i, : end - start] = metadata.block_tables[
287
+ i, start:end
288
+ ]
289
+ metadata.block_tables_succ = block_tables_succ
290
+
291
+ metadata.seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len
292
+ metadata.max_seq_len_inter = metadata.seq_lens_inter.max().item()
293
+
294
+ self.forward_metadata = metadata
295
+
296
+ def forward_extend(
297
+ self,
298
+ q: torch.Tensor,
299
+ k: torch.Tensor,
300
+ v: torch.Tensor,
301
+ layer: "RadixAttention",
302
+ forward_batch: ForwardBatch,
303
+ save_kv_cache=True,
304
+ ):
305
+ # Use precomputed metadata across all layers
306
+ metadata = self.forward_metadata
307
+
308
+ (
309
+ query,
310
+ query_succ,
311
+ query_inter,
312
+ query_succ_critical,
313
+ query_inter_critical,
314
+ ) = torch.split(q, q.shape[-1] // 5, dim=-1)
315
+
316
+ # Reshape the query, key, and value tensors.
317
+ query = query.view(-1, self.num_heads, self.head_size)
318
+ query_succ = query_succ.view(-1, self.num_heads, self.head_size)
319
+ query_inter = query_inter.view(-1, self.num_heads, self.head_size)
320
+ query_succ_critical = query_succ_critical.view(
321
+ -1, self.num_heads, self.head_size
322
+ )
323
+ query_inter_critical = query_inter_critical.view(
324
+ -1, self.num_heads, self.head_size
325
+ )
326
+ key = k.view(-1, self.num_kv_heads, self.head_size)
327
+ value = v.view(-1, self.num_kv_heads, self.head_size)
328
+
329
+ # apply DCA scaling
330
+ if self.original_max_position_embeddings > 0:
331
+ assert metadata.scaling_factor is not None
332
+ assert metadata.query_start_loc is not None
333
+ assert metadata.orig_seq_lens is not None
334
+ current_start = 0
335
+ query_start_loc_cpu = metadata.query_start_loc.cpu()
336
+ for i in range(len(metadata.orig_seq_lens)):
337
+ current_end = (
338
+ current_start
339
+ + (query_start_loc_cpu[i + 1] - query_start_loc_cpu[i]).item()
340
+ )
341
+ key[current_start:current_end].mul_(metadata.scaling_factor[i])
342
+ current_start = current_end
343
+ assert current_end <= self.max_context_len
344
+
345
+ # Do multi-head attention
346
+ key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
347
+ layer.layer_id
348
+ )
349
+ key_cache = key_cache.view(
350
+ -1, self.page_size, layer.tp_k_head_num, layer.head_dim
351
+ )
352
+ value_cache = value_cache.view(
353
+ -1, self.page_size, layer.tp_v_head_num, layer.head_dim
354
+ )
355
+
356
+ if key is not None and value is not None:
357
+ if save_kv_cache:
358
+ forward_batch.token_to_kv_pool.set_kv_buffer(
359
+ layer,
360
+ forward_batch.out_cache_loc,
361
+ key,
362
+ value,
363
+ layer.k_scale,
364
+ layer.v_scale,
365
+ )
366
+
367
+ if not save_kv_cache:
368
+ # profile run
369
+ o = flash_attn_varlen_func(
370
+ q=query,
371
+ k=key,
372
+ v=value,
373
+ cu_seqlens_q=metadata.seq_start_loc,
374
+ cu_seqlens_k=metadata.seq_start_loc,
375
+ max_seqlen_q=metadata.max_seq_len,
376
+ max_seqlen_k=metadata.max_seq_len,
377
+ softmax_scale=layer.scaling,
378
+ causal=True,
379
+ )
380
+ else:
381
+ # prefill/chunked-prefill
382
+ # get per layer sparse attention config
383
+ if self.sparse_attention_enabled:
384
+ self.layer_sparse_attention_config = self.get_sparse_attention_config(
385
+ layer.layer_id
386
+ )
387
+ assert metadata.orig_seq_lens is not None
388
+ o = self._dual_chunk_flash_attn_prefill(
389
+ q=query,
390
+ q_succ=query_succ,
391
+ q_inter=query_inter,
392
+ q_succ_critical=query_succ_critical,
393
+ q_inter_critical=query_inter_critical,
394
+ k=key_cache,
395
+ v=value_cache,
396
+ cu_seqlens_q=metadata.query_start_loc,
397
+ cu_seqlens_k=metadata.seq_start_loc,
398
+ orig_seq_lens=metadata.orig_seq_lens,
399
+ scaling_factor=metadata.scaling_factor,
400
+ softmax_scale=layer.scaling,
401
+ causal=True,
402
+ window_size=(-1, -1),
403
+ block_table=metadata.block_tables,
404
+ chunk_size=self.chunk_size,
405
+ local_size=self.local_size,
406
+ )
407
+ return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
408
+
409
+ def forward_decode(
410
+ self,
411
+ q: torch.Tensor,
412
+ k: torch.Tensor,
413
+ v: torch.Tensor,
414
+ layer: "RadixAttention",
415
+ forward_batch: ForwardBatch,
416
+ save_kv_cache=True,
417
+ ) -> torch.Tensor:
418
+ # Use precomputed metadata across all layers
419
+ metadata = self.forward_metadata
420
+
421
+ (
422
+ query,
423
+ query_succ,
424
+ query_inter,
425
+ query_succ_critical,
426
+ query_inter_critical,
427
+ ) = torch.split(q, q.shape[-1] // 5, dim=-1)
428
+
429
+ # Reshape the query, key, and value tensors.
430
+ query = query.view(-1, self.num_heads, self.head_size)
431
+ query_succ = query_succ.view(-1, self.num_heads, self.head_size)
432
+ query_inter = query_inter.view(-1, self.num_heads, self.head_size)
433
+ query_succ_critical = query_succ_critical.view(
434
+ -1, self.num_heads, self.head_size
435
+ )
436
+ query_inter_critical = query_inter_critical.view(
437
+ -1, self.num_heads, self.head_size
438
+ )
439
+ key = k.view(-1, self.num_kv_heads, self.head_size)
440
+ value = v.view(-1, self.num_kv_heads, self.head_size)
441
+
442
+ key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
443
+ layer.layer_id
444
+ )
445
+ key_cache = key_cache.view(
446
+ -1, self.page_size, layer.tp_k_head_num, layer.head_dim
447
+ )
448
+ value_cache = value_cache.view(
449
+ -1, self.page_size, layer.tp_v_head_num, layer.head_dim
450
+ )
451
+
452
+ if key is not None and value is not None:
453
+ if save_kv_cache:
454
+ forward_batch.token_to_kv_pool.set_kv_buffer(
455
+ layer,
456
+ forward_batch.out_cache_loc,
457
+ key,
458
+ value,
459
+ layer.k_scale,
460
+ layer.v_scale,
461
+ )
462
+
463
+ # apply DCA scaling
464
+ if self.original_max_position_embeddings > 0:
465
+ assert metadata.scaling_factor is not None
466
+ scaling_factor = metadata.scaling_factor
467
+ key.mul_(scaling_factor.unsqueeze(-1).unsqueeze(-1))
468
+
469
+ o = self._dual_chunk_flash_attn_decoding(
470
+ query.unsqueeze(1),
471
+ query_succ.unsqueeze(1),
472
+ query_inter.unsqueeze(1),
473
+ key_cache,
474
+ value_cache,
475
+ block_table=metadata.block_tables,
476
+ cache_seqlens=metadata.seq_lens_tensor,
477
+ softmax_scale=layer.scaling,
478
+ causal=True,
479
+ chunk_size=self.chunk_size,
480
+ local_size=self.local_size,
481
+ original_max_position_embeddings=self.original_max_position_embeddings,
482
+ decode_meta=metadata,
483
+ ).squeeze(1)
484
+ return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
485
+
486
+ def init_cuda_graph_state(self, max_bs: int):
487
+ """Initialize CUDA graph state for the attention backend.
488
+
489
+ Args:
490
+ max_bs (int): Maximum batch size to support in CUDA graphs
491
+
492
+ This creates fixed-size tensors that will be reused during CUDA graph replay
493
+ to avoid memory allocations.
494
+ """
495
+ self.decode_metadata = {
496
+ "seq_lens_tensor": torch.zeros(
497
+ max_bs, dtype=torch.int32, device=self.device
498
+ ),
499
+ "orig_seq_lens_tensor": torch.zeros(
500
+ max_bs, dtype=torch.int32, device=self.device
501
+ ),
502
+ "scaling_factor": torch.zeros(
503
+ max_bs, dtype=torch.float32, device=self.device
504
+ ),
505
+ "block_tables": torch.zeros(
506
+ max_bs,
507
+ (self.max_context_len - 1) // self.page_size + 1,
508
+ dtype=torch.int32,
509
+ device=self.device,
510
+ ),
511
+ "block_tables_intra": torch.zeros(
512
+ max_bs,
513
+ (self.max_context_len - 1) // self.page_size + 1,
514
+ dtype=torch.int32,
515
+ device=self.device,
516
+ ),
517
+ "seq_lens_intra": torch.zeros(
518
+ max_bs, dtype=torch.int32, device=self.device
519
+ ),
520
+ "block_tables_succ": torch.zeros(
521
+ max_bs,
522
+ (self.max_context_len - 1) // self.page_size + 1,
523
+ dtype=torch.int32,
524
+ device=self.device,
525
+ ),
526
+ "seq_lens_succ": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
527
+ "seq_lens_inter": torch.zeros(
528
+ max_bs, dtype=torch.int32, device=self.device
529
+ ),
530
+ }
531
+
532
+ def init_forward_metadata_capture_cuda_graph(
533
+ self,
534
+ bs: int,
535
+ num_tokens: int,
536
+ req_pool_indices: torch.Tensor,
537
+ seq_lens: torch.Tensor,
538
+ encoder_lens: Optional[torch.Tensor],
539
+ forward_mode: ForwardMode,
540
+ spec_info: Optional[None],
541
+ ):
542
+ metadata = DualChunkFlashAttentionMetadata()
543
+
544
+ if forward_mode.is_decode_or_idle():
545
+ if self.original_max_position_embeddings > 0:
546
+ metadata.scaling_factor = self.decode_metadata["scaling_factor"][:bs]
547
+
548
+ metadata.seq_lens_tensor = self.decode_metadata["seq_lens_tensor"][:bs]
549
+ metadata.orig_seq_lens_tensor = self.decode_metadata[
550
+ "orig_seq_lens_tensor"
551
+ ][:bs]
552
+ metadata.max_seq_len = self.max_context_len
553
+ metadata.block_tables = self.decode_metadata["block_tables"][
554
+ req_pool_indices, :
555
+ ]
556
+
557
+ # intra
558
+ metadata.max_seq_len_intra = self.max_context_len
559
+ metadata.seq_lens_intra = self.decode_metadata["seq_lens_intra"][:bs]
560
+
561
+ metadata.block_tables_intra = self.decode_metadata["block_tables_intra"][
562
+ :bs, :
563
+ ]
564
+
565
+ # succ
566
+ metadata.seq_lens_succ = self.decode_metadata["seq_lens_succ"][:bs]
567
+ metadata.max_seq_len_succ = self.max_context_len
568
+
569
+ metadata.block_tables_succ = self.decode_metadata["block_tables_succ"][
570
+ :bs, :
571
+ ]
572
+
573
+ metadata.seq_lens_inter = self.decode_metadata["seq_lens_inter"][:bs]
574
+ metadata.max_seq_len_inter = self.max_context_len
575
+
576
+ self.decode_metadata[bs] = metadata
577
+
578
+ self.forward_metadata = metadata
579
+
580
+ def init_forward_metadata_replay_cuda_graph(
581
+ self,
582
+ bs: int,
583
+ req_pool_indices: torch.Tensor,
584
+ seq_lens: torch.Tensor,
585
+ seq_lens_sum: int,
586
+ encoder_lens: Optional[torch.Tensor],
587
+ forward_mode: ForwardMode,
588
+ spec_info: Optional[None],
589
+ seq_lens_cpu: Optional[torch.Tensor],
590
+ out_cache_loc: torch.Tensor = None,
591
+ ):
592
+ """Initialize forward metadata for replaying CUDA graph."""
593
+ assert forward_mode.is_decode()
594
+ seq_lens = seq_lens[:bs]
595
+ req_pool_indices = req_pool_indices[:bs]
596
+ metadata = self.decode_metadata[bs]
597
+
598
+ metadata.seq_lens_tensor.copy_(seq_lens.to(torch.int32))
599
+ metadata.seq_lens = seq_lens.tolist()
600
+ metadata.max_seq_len = seq_lens.max().item()
601
+
602
+ metadata.orig_seq_lens_tensor.copy_(seq_lens)
603
+ metadata.orig_seq_lens = seq_lens.tolist()
604
+
605
+ block_tables = self.req_to_token[req_pool_indices, : metadata.max_seq_len]
606
+ # Convert the block table to a strided format.
607
+ if self.page_size > 1:
608
+ strided_indices = torch.arange(
609
+ 0, block_tables.shape[1], self.page_size, device=self.device
610
+ )
611
+ block_tables = block_tables[:, strided_indices] // self.page_size
612
+ metadata.block_tables.fill_(0)
613
+ metadata.block_tables[: block_tables.shape[0], : block_tables.shape[1]].copy_(
614
+ block_tables
615
+ )
616
+
617
+ if self.original_max_position_embeddings > 0:
618
+ scaling_factor = (
619
+ 0.1
620
+ * torch.log(
621
+ metadata.orig_seq_lens_tensor
622
+ / self.original_max_position_embeddings
623
+ )
624
+ + 1.0
625
+ ).clip(min=1)
626
+ metadata.scaling_factor.copy_(scaling_factor)
627
+
628
+ cache_seq_lens = metadata.orig_seq_lens_tensor
629
+
630
+ chunk_len = self.chunk_size - self.local_size
631
+ chunk_num_curr = (cache_seq_lens - 1) // chunk_len
632
+
633
+ seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len
634
+ max_seq_len_intra = seq_lens_intra.max().item()
635
+ metadata.seq_lens_intra.copy_(seq_lens_intra)
636
+ metadata.max_seq_len_intra = max_seq_len_intra
637
+
638
+ metadata.block_tables_intra.fill_(0)
639
+ for i in range(bs):
640
+ st = chunk_num_curr[i] * chunk_len // self.page_size
641
+ ed = min(
642
+ st + (max_seq_len_intra - 1) // self.page_size + 1,
643
+ (cache_seq_lens[i] - 1) // self.page_size + 1,
644
+ )
645
+ metadata.block_tables_intra[i, : ed - st] = metadata.block_tables[i, st:ed]
646
+
647
+ seq_lens_succ = (chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len
648
+ metadata.seq_lens_succ.copy_(seq_lens_succ)
649
+ metadata.max_seq_len_succ = metadata.seq_lens_succ.max().item()
650
+ if metadata.max_seq_len_succ:
651
+ metadata.block_tables_succ.fill_(0)
652
+ for i in range(bs):
653
+ start = (
654
+ (chunk_num_curr[i] - 1).clip(min=0) * chunk_len // self.page_size
655
+ )
656
+ end = min(
657
+ start + (metadata.max_seq_len_succ - 1) // self.page_size + 1,
658
+ (cache_seq_lens[i] - 1) // self.page_size + 1,
659
+ )
660
+ metadata.block_tables_succ[i, : end - start] = metadata.block_tables[
661
+ i, start:end
662
+ ]
663
+
664
+ seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len
665
+ metadata.seq_lens_inter.copy_(seq_lens_inter)
666
+ metadata.max_seq_len_inter = metadata.seq_lens_inter.max().item()
667
+
668
+ self.forward_metadata = metadata
669
+
670
+ def get_cuda_graph_seq_len_fill_value(self):
671
+ """Get the fill value for sequence length in CUDA graph."""
672
+ return 1
673
+
674
+ def _dual_chunk_flash_attn_prefill(
675
+ self,
676
+ q,
677
+ q_succ,
678
+ q_inter,
679
+ q_succ_critical,
680
+ q_inter_critical,
681
+ k,
682
+ v,
683
+ cu_seqlens_q,
684
+ cu_seqlens_k,
685
+ orig_seq_lens: List[int],
686
+ scaling_factor: torch.Tensor,
687
+ softmax_scale: float,
688
+ causal: Optional[bool] = True,
689
+ window_size: Tuple[int, int] = (-1, -1),
690
+ block_table: Optional[torch.Tensor] = None,
691
+ chunk_size: int = 8192,
692
+ local_size: int = 1024,
693
+ ):
694
+ if not causal:
695
+ raise ValueError("Dual Chunk Attention does not support causal=False")
696
+ if window_size != (-1, -1):
697
+ raise ValueError("Dual Chunk Attention does not support window_size")
698
+
699
+ cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist()
700
+ cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist()
701
+ all_outputs = []
702
+
703
+ for i in range(0, len(cu_seqlens_q_cpu) - 1):
704
+ qs = cu_seqlens_q_cpu[i]
705
+ qe = cu_seqlens_q_cpu[i : i + 2][-1]
706
+ ks = cu_seqlens_k_cpu[i]
707
+ ke = cu_seqlens_k_cpu[i : i + 2][-1]
708
+
709
+ current_q = q[qs:qe]
710
+ current_q_succ = q_succ[qs:qe]
711
+ current_q_inter = q_inter[qs:qe]
712
+ current_q_succ_critical = q_succ_critical[qs:qe]
713
+ current_q_inter_critical = q_inter_critical[qs:qe]
714
+
715
+ if block_table is None:
716
+ current_k = k[ks:ke]
717
+ current_v = v[ks:ke]
718
+ current_block_table = None
719
+ current_orig_seq_len = orig_seq_lens[i]
720
+ else:
721
+ current_block_table = block_table[i]
722
+ current_orig_seq_len = orig_seq_lens[i]
723
+ current_k = k
724
+ current_v = v
725
+ sparse_attn_enabled = (
726
+ self.sparse_attention_enabled
727
+ and current_orig_seq_len > self.sparse_attention_threshold
728
+ )
729
+
730
+ if current_q.shape[0] == 0:
731
+ continue
732
+
733
+ if current_k.shape[0] == 0:
734
+ all_outputs.append(
735
+ torch.zeros(
736
+ (current_q.shape[0], current_q.shape[1], v.shape[2]),
737
+ device=q.device,
738
+ dtype=q.dtype,
739
+ )
740
+ )
741
+ continue
742
+
743
+ current_output = torch.empty_like(current_q)
744
+ group_size = int(current_q.size(-2) / current_k.size(-2))
745
+
746
+ if sparse_attn_enabled:
747
+ num_device_q_heads = current_q.size(-2)
748
+ heads_vertical_size = torch.empty(
749
+ size=(num_device_q_heads,), dtype=torch.int32
750
+ )
751
+ heads_slash_size = torch.empty(
752
+ size=(num_device_q_heads,), dtype=torch.int32
753
+ )
754
+ for head_id in range(current_q.size(-2)):
755
+ (
756
+ ty,
757
+ vertical_size,
758
+ slash_size,
759
+ _,
760
+ ) = self.layer_sparse_attention_config[head_id]
761
+ assert ty == "vertical_and_slash", "only support slash mode"
762
+
763
+ if vertical_size == 30:
764
+ vertical_size += 100
765
+ heads_vertical_size[head_id] = vertical_size
766
+ heads_slash_size[head_id] = slash_size
767
+
768
+ current_output = self._dual_chunk_flash_attn_prefill_func(
769
+ current_q, # allheads
770
+ current_q_succ,
771
+ current_q_inter,
772
+ current_q_succ_critical,
773
+ current_q_inter_critical,
774
+ current_k,
775
+ current_v,
776
+ current_block_table,
777
+ softmax_scale,
778
+ chunk_size,
779
+ local_size,
780
+ scaling_factor[i].item(),
781
+ ke - ks,
782
+ sparse_attn_enabled=sparse_attn_enabled,
783
+ heads_vertical_size=heads_vertical_size,
784
+ heads_slash_size=heads_slash_size,
785
+ group_size=group_size,
786
+ )
787
+ else:
788
+ for head_id in range(current_q.size(-2)):
789
+ # (seq_len, num_heads, head_size)
790
+ current_q_head = current_q[:, head_id, :].unsqueeze(1)
791
+ current_q_succ_head = current_q_succ[:, head_id, :].unsqueeze(1)
792
+ current_q_inter_head = current_q_inter[:, head_id, :].unsqueeze(1)
793
+ current_q_succ_head_critical = current_q_succ_critical[
794
+ :, head_id, :
795
+ ].unsqueeze(1)
796
+ current_q_inter_head_critical = current_q_inter_critical[
797
+ :, head_id, :
798
+ ].unsqueeze(1)
799
+ if block_table is not None:
800
+ current_k_head = current_k[
801
+ ..., head_id // group_size, :
802
+ ].unsqueeze(2)
803
+ current_v_head = current_v[
804
+ ..., head_id // group_size, :
805
+ ].unsqueeze(2)
806
+
807
+ else:
808
+ current_k_head = current_k[:, head_id, :].unsqueeze(1)
809
+ current_v_head = current_v[:, head_id, :].unsqueeze(1)
810
+
811
+ current_out = self._dual_chunk_flash_attn_prefill_func(
812
+ current_q_head,
813
+ current_q_succ_head,
814
+ current_q_inter_head,
815
+ current_q_succ_head_critical,
816
+ current_q_inter_head_critical,
817
+ current_k_head,
818
+ current_v_head,
819
+ current_block_table,
820
+ softmax_scale,
821
+ chunk_size,
822
+ local_size,
823
+ scaling_factor[i].item(),
824
+ ke - ks,
825
+ sparse_attn_enabled=sparse_attn_enabled,
826
+ )
827
+ current_output[:, head_id : head_id + 1, :] = current_out
828
+ all_outputs.append(current_output)
829
+ return torch.cat(all_outputs, dim=0)
830
+
831
+ def _dual_chunk_flash_attn_prefill_func(
832
+ self,
833
+ q,
834
+ q_succ,
835
+ q_inter,
836
+ q_succ_critical,
837
+ q_inter_critical,
838
+ k,
839
+ v,
840
+ block_table,
841
+ softmax_scale: float,
842
+ chunk_size: int,
843
+ local_size: int,
844
+ scaling_factor: float,
845
+ k_length: int,
846
+ sparse_attn_enabled: Optional[bool] = True,
847
+ heads_vertical_size=None,
848
+ heads_slash_size=None,
849
+ group_size=None,
850
+ ):
851
+ flash_results = []
852
+ chunk_len = chunk_size - local_size
853
+
854
+ if block_table is not None:
855
+ block_size = v.shape[1]
856
+ if chunk_len % block_size != 0:
857
+ raise ValueError("chunk_len must be divisible by block_size.")
858
+ else:
859
+ block_size = 1
860
+
861
+ if self.original_max_position_embeddings > 0:
862
+ softmax_scale = softmax_scale * scaling_factor
863
+
864
+ begin = k_length - q.shape[0]
865
+ while begin < k_length:
866
+ flash_per_chunk = []
867
+
868
+ prev_chunk_end_pos = (begin // chunk_len) * chunk_len
869
+ next_chunk_end_pos = prev_chunk_end_pos + chunk_len
870
+ end = min(next_chunk_end_pos, k_length)
871
+ qbegin = begin - (k_length - q.shape[0])
872
+ qend = end - (k_length - q.shape[0])
873
+
874
+ qk_chunks = []
875
+ q_states_intra = q[qbegin:qend]
876
+ # choose critical token
877
+ if block_table is not None:
878
+ block_tables_intra = _get_block(
879
+ block_table, block_size, prev_chunk_end_pos, end
880
+ )
881
+ k_states_intra = k[block_tables_intra].view(-1, *k.shape[-2:])[
882
+ : (end - prev_chunk_end_pos)
883
+ ]
884
+ v_states_intra = v[block_tables_intra].view(-1, *v.shape[-2:])[
885
+ : (end - prev_chunk_end_pos)
886
+ ]
887
+ else:
888
+ block_tables_intra = None
889
+ k_states_intra = k[prev_chunk_end_pos:end]
890
+ v_states_intra = v[prev_chunk_end_pos:end]
891
+
892
+ if sparse_attn_enabled:
893
+ last_q_size = min(qend - qbegin, self.sparse_attention_last_q)
894
+ _, num_device_k_heads, head_dim = k_states_intra.shape
895
+ k_states_intra = (
896
+ k_states_intra.unsqueeze(2)
897
+ .repeat(1, 1, group_size, 1)
898
+ .reshape(-1, num_device_k_heads * group_size, head_dim)
899
+ )
900
+ v_states_intra = (
901
+ v_states_intra.unsqueeze(2)
902
+ .repeat(1, 1, group_size, 1)
903
+ .reshape(-1, num_device_k_heads * group_size, head_dim)
904
+ )
905
+ qk_chunks.append(
906
+ (q_states_intra.transpose(0, 1)[:, -last_q_size:] * softmax_scale)
907
+ @ k_states_intra.permute(1, 2, 0)
908
+ )
909
+
910
+ if prev_chunk_end_pos - chunk_len >= 0:
911
+ q_states_succ = q_succ[qbegin:qend]
912
+ q_states_succ_critical = q_succ_critical[qbegin:qend]
913
+ if block_table is not None:
914
+ block_tables_succ = _get_block(
915
+ block_table,
916
+ block_size,
917
+ prev_chunk_end_pos - chunk_len,
918
+ prev_chunk_end_pos,
919
+ )
920
+ k_states_succ = k[block_tables_succ].view(-1, *k.shape[-2:])[
921
+ :chunk_len
922
+ ]
923
+ v_states_succ = v[block_tables_succ].view(-1, *v.shape[-2:])[
924
+ :chunk_len
925
+ ]
926
+ else:
927
+ k_states_succ = k[
928
+ prev_chunk_end_pos - chunk_len : prev_chunk_end_pos
929
+ ]
930
+ v_states_succ = v[
931
+ prev_chunk_end_pos - chunk_len : prev_chunk_end_pos
932
+ ]
933
+
934
+ if sparse_attn_enabled:
935
+ k_states_succ = (
936
+ k_states_succ.unsqueeze(2)
937
+ .repeat(1, 1, group_size, 1)
938
+ .reshape(-1, num_device_k_heads * group_size, head_dim)
939
+ )
940
+ v_states_succ = (
941
+ v_states_succ.unsqueeze(2)
942
+ .repeat(1, 1, group_size, 1)
943
+ .reshape(-1, num_device_k_heads * group_size, head_dim)
944
+ )
945
+ qk_chunks.append(
946
+ (
947
+ q_states_succ_critical.transpose(0, 1)[:, -last_q_size:]
948
+ * softmax_scale
949
+ )
950
+ @ k_states_succ.permute(1, 2, 0)
951
+ )
952
+
953
+ if prev_chunk_end_pos - chunk_len * 2 >= 0:
954
+ q_states_inter = q_inter[qbegin:qend]
955
+ q_states_inter_critical = q_inter_critical[qbegin:qend]
956
+ if block_table is not None:
957
+ block_tables_inter = _get_block(
958
+ block_table, block_size, 0, prev_chunk_end_pos - chunk_len
959
+ )
960
+ k_states_inter = k[block_tables_inter].view(-1, *k.shape[-2:])[
961
+ : (prev_chunk_end_pos - chunk_len)
962
+ ]
963
+ v_states_inter = v[block_tables_inter].view(-1, *v.shape[-2:])[
964
+ : (prev_chunk_end_pos - chunk_len)
965
+ ]
966
+ else:
967
+ k_states_inter = k[: prev_chunk_end_pos - chunk_len]
968
+ v_states_inter = v[: prev_chunk_end_pos - chunk_len]
969
+
970
+ if sparse_attn_enabled:
971
+ k_states_inter = (
972
+ k_states_inter.unsqueeze(2)
973
+ .repeat(1, 1, group_size, 1)
974
+ .reshape(-1, num_device_k_heads * group_size, head_dim)
975
+ )
976
+ v_states_inter = (
977
+ v_states_inter.unsqueeze(2)
978
+ .repeat(1, 1, group_size, 1)
979
+ .reshape(-1, num_device_k_heads * group_size, head_dim)
980
+ )
981
+ qk_chunks.append(
982
+ (
983
+ q_states_inter_critical.transpose(0, 1)[:, -last_q_size:]
984
+ * softmax_scale
985
+ )
986
+ @ k_states_inter.permute(1, 2, 0)
987
+ )
988
+
989
+ if sparse_attn_enabled:
990
+ reversed_qk = qk_chunks[::-1]
991
+ qk = torch.cat(reversed_qk, dim=-1)
992
+
993
+ qk[:, :, -last_q_size:] = torch.where(
994
+ self.last_q_mask[..., -last_q_size:, -last_q_size:].to(qk.device),
995
+ qk[:, :, -last_q_size:],
996
+ -torch.inf,
997
+ )
998
+ qk = F.softmax(qk, dim=-1, dtype=torch.float32)
999
+
1000
+ vertical = qk.sum(-2, keepdim=True)
1001
+ vertical[..., :30] = torch.inf
1002
+
1003
+ # Avoid sorting by using the min/max ints to fill the indexer
1004
+ # buffers.
1005
+ int32_max = torch.iinfo(torch.int32).max
1006
+ int32_min = torch.iinfo(torch.int32).min
1007
+ n_heads = qk.size()[0]
1008
+ max_slash_topk = torch.max(heads_slash_size).item()
1009
+ max_vertical_topk = torch.max(heads_vertical_size).item()
1010
+ # store each head's slash topk, vertical topk
1011
+ vertical = vertical.reshape((n_heads, -1))
1012
+ # prevent out of range when prompt size < max_vertical_topk
1013
+ max_vertical_topk = min(vertical.shape[-1], max_vertical_topk)
1014
+ vertical_topk_buffer = torch.topk(
1015
+ vertical, max_vertical_topk, -1
1016
+ ).indices
1017
+ slash_topk_buffer = torch.empty(
1018
+ size=(n_heads, max_slash_topk), dtype=torch.int64, device=qk.device
1019
+ )
1020
+ for head_i in range(n_heads):
1021
+ # (nqheads=1, lastq, k_len)
1022
+ head_score = qk[head_i : head_i + 1, :, :]
1023
+ slash_scores = _sum_all_diagonal_matrix(head_score)
1024
+ if head_score.size(1) != 1:
1025
+ # drop the upper-right corner
1026
+ slash_scores = slash_scores[..., : -last_q_size + 1]
1027
+ slash_scores[..., -100:] = torch.inf
1028
+
1029
+ head_slash_size = heads_slash_size[head_i]
1030
+ head_slash_size = min(head_slash_size, vertical.size(-1))
1031
+ slash_topk = torch.topk(slash_scores, head_slash_size, -1).indices
1032
+ # (nheads, max_topk)
1033
+ slash_topk_buffer[head_i, :head_slash_size] = slash_topk
1034
+
1035
+ # reset heads topk
1036
+ heads_slash_size[head_i] = head_slash_size
1037
+ heads_vertical_size[head_i] = min(
1038
+ heads_vertical_size[head_i], max_vertical_topk
1039
+ )
1040
+
1041
+ # store
1042
+ vertical_buffer = torch.full(
1043
+ (n_heads, max_vertical_topk),
1044
+ int32_max,
1045
+ dtype=torch.int64,
1046
+ device=q.device,
1047
+ )
1048
+ slash_buffer = torch.full(
1049
+ (n_heads, max_slash_topk),
1050
+ int32_min,
1051
+ dtype=torch.int64,
1052
+ device=q.device,
1053
+ )
1054
+ succ_vertical_buffer = torch.full(
1055
+ (n_heads, max_vertical_topk),
1056
+ int32_max,
1057
+ dtype=torch.int64,
1058
+ device=q.device,
1059
+ )
1060
+ succ_slash_buffer = torch.full(
1061
+ (n_heads, max_slash_topk),
1062
+ int32_min,
1063
+ dtype=torch.int64,
1064
+ device=q.device,
1065
+ )
1066
+ inter_vertical_buffer = torch.full(
1067
+ (n_heads, max_vertical_topk),
1068
+ int32_max,
1069
+ dtype=torch.int64,
1070
+ device=q.device,
1071
+ )
1072
+ inter_slash_buffer = torch.full(
1073
+ (n_heads, max_slash_topk),
1074
+ int32_min,
1075
+ dtype=torch.int64,
1076
+ device=q.device,
1077
+ )
1078
+
1079
+ vertical_size_buffer = torch.empty(
1080
+ size=(n_heads,), dtype=torch.int32, device=q.device
1081
+ )
1082
+ slash_sizes_buffer = torch.empty(
1083
+ size=(n_heads,), dtype=torch.int32, device=q.device
1084
+ )
1085
+ succ_vertical_size_buffer = torch.empty(
1086
+ size=(n_heads,), dtype=torch.int32, device=q.device
1087
+ )
1088
+ succ_slash_sizes_buffer = torch.empty(
1089
+ size=(n_heads,), dtype=torch.int32, device=q.device
1090
+ )
1091
+ inter_vertical_size_buffer = torch.empty(
1092
+ size=(n_heads,), dtype=torch.int32, device=q.device
1093
+ )
1094
+ inter_slash_sizes_buffer = torch.empty(
1095
+ size=(n_heads,), dtype=torch.int32, device=q.device
1096
+ )
1097
+
1098
+ for head_i in range(n_heads):
1099
+ vertical_topk = vertical_topk_buffer[
1100
+ head_i, : heads_vertical_size[head_i]
1101
+ ]
1102
+ # intra
1103
+ intra_vertical_indices = (
1104
+ vertical_topk[vertical_topk >= prev_chunk_end_pos]
1105
+ - prev_chunk_end_pos
1106
+ )
1107
+ if intra_vertical_indices.nelement() == 0:
1108
+ intra_vertical_indices = torch.cat(
1109
+ [
1110
+ intra_vertical_indices,
1111
+ torch.arange(
1112
+ 0,
1113
+ k_states_intra.size(0),
1114
+ max(1, k_states_intra.size(0) / 5),
1115
+ dtype=torch.int32,
1116
+ device=intra_vertical_indices.device,
1117
+ ),
1118
+ ]
1119
+ )
1120
+ slash_topk = slash_topk_buffer[head_i, : heads_slash_size[head_i]]
1121
+ intra_slash_indices = (qk.size(-1) - 1) - slash_topk[
1122
+ slash_topk >= prev_chunk_end_pos
1123
+ ]
1124
+ # fill buffer
1125
+ v_count = intra_vertical_indices.nelement()
1126
+ s_count = intra_slash_indices.nelement()
1127
+ vertical_size_buffer[head_i] = v_count
1128
+ slash_sizes_buffer[head_i] = s_count
1129
+ vertical_buffer[head_i, :v_count].copy_(intra_vertical_indices)
1130
+ slash_buffer[head_i, :s_count].copy_(intra_slash_indices)
1131
+ # succ
1132
+ if prev_chunk_end_pos - chunk_len >= 0:
1133
+ succ_vertical_indices = vertical_topk[
1134
+ (vertical_topk < prev_chunk_end_pos)
1135
+ & (vertical_topk >= prev_chunk_end_pos - chunk_len)
1136
+ ] - (prev_chunk_end_pos - chunk_len)
1137
+ # TODO: support no vertical
1138
+ if succ_vertical_indices.nelement() == 0:
1139
+ succ_vertical_indices = torch.cat(
1140
+ [
1141
+ succ_vertical_indices,
1142
+ torch.arange(
1143
+ 0,
1144
+ k_states_succ.size(0),
1145
+ max(1, k_states_succ.size(0) / 5),
1146
+ dtype=torch.int32,
1147
+ device=intra_vertical_indices.device,
1148
+ ),
1149
+ ]
1150
+ )
1151
+ succ_slash_indices = (
1152
+ prev_chunk_end_pos + (qend - qbegin) - 1
1153
+ ) - slash_topk[
1154
+ (
1155
+ (slash_topk >= (prev_chunk_end_pos - chunk_len))
1156
+ & (slash_topk < (prev_chunk_end_pos + (qend - qbegin)))
1157
+ )
1158
+ ]
1159
+ if succ_slash_indices.nelement() == 0:
1160
+ succ_slash_indices = torch.cat(
1161
+ [
1162
+ succ_slash_indices,
1163
+ torch.arange(
1164
+ 0,
1165
+ k_states_succ.size(0),
1166
+ max(1, k_states_succ.size(0) / 5),
1167
+ dtype=torch.int32,
1168
+ device=intra_vertical_indices.device,
1169
+ ),
1170
+ ]
1171
+ )
1172
+ # fill buffer
1173
+ v_count = succ_vertical_indices.nelement()
1174
+ s_count = succ_slash_indices.nelement()
1175
+ succ_vertical_size_buffer[head_i] = v_count
1176
+ succ_slash_sizes_buffer[head_i] = s_count
1177
+ succ_vertical_buffer[head_i, :v_count].copy_(
1178
+ succ_vertical_indices
1179
+ )
1180
+ succ_slash_buffer[head_i, :s_count].copy_(succ_slash_indices)
1181
+
1182
+ if prev_chunk_end_pos - 2 * chunk_len >= 0:
1183
+ inter_vertical_indices = vertical_topk[
1184
+ vertical_topk < prev_chunk_end_pos - chunk_len
1185
+ ]
1186
+
1187
+ if inter_vertical_indices.nelement() == 0:
1188
+ inter_vertical_indices = torch.cat(
1189
+ [
1190
+ inter_vertical_indices,
1191
+ torch.arange(
1192
+ 0,
1193
+ k_states_inter.size(0),
1194
+ max(1, k_states_inter.size(0) / 5),
1195
+ dtype=torch.int32,
1196
+ device=intra_vertical_indices.device,
1197
+ ),
1198
+ ]
1199
+ )
1200
+ inter_slash_indices = (
1201
+ prev_chunk_end_pos - chunk_len + (qend - qbegin) - 1
1202
+ ) - slash_topk[
1203
+ slash_topk
1204
+ < (prev_chunk_end_pos - chunk_len + (qend - qbegin))
1205
+ ]
1206
+ if inter_slash_indices.nelement() == 0:
1207
+ inter_slash_indices = torch.cat(
1208
+ [
1209
+ inter_slash_indices,
1210
+ torch.arange(
1211
+ 0,
1212
+ k_states_inter.size(0),
1213
+ max(1, k_states_inter.size(0) / 5),
1214
+ dtype=torch.int32,
1215
+ device=intra_vertical_indices.device,
1216
+ ),
1217
+ ]
1218
+ )
1219
+ # fill buffer
1220
+ v_count = inter_vertical_indices.nelement()
1221
+ s_count = inter_slash_indices.nelement()
1222
+ inter_vertical_size_buffer[head_i] = v_count
1223
+ inter_slash_sizes_buffer[head_i] = s_count
1224
+ inter_vertical_buffer[head_i, :v_count].copy_(
1225
+ inter_vertical_indices
1226
+ )
1227
+ inter_slash_buffer[head_i, :s_count].copy_(inter_slash_indices)
1228
+ else:
1229
+ intra_vertical_indices, intra_slash_indices = None, None
1230
+ succ_vertical_indices, succ_slash_indices = None, None
1231
+ inter_vertical_indices, inter_slash_indices = None, None
1232
+
1233
+ if sparse_attn_enabled:
1234
+ flash_result = self._do_flash_attn(
1235
+ q_states_intra,
1236
+ k_states_intra,
1237
+ v_states_intra,
1238
+ softmax_scale=softmax_scale,
1239
+ causal=True,
1240
+ stage="intra",
1241
+ vertical_indices=vertical_buffer,
1242
+ slash_indices=slash_buffer,
1243
+ vertical_indices_count=vertical_size_buffer,
1244
+ slash_indices_count=slash_sizes_buffer,
1245
+ mergehead_softmax_scale=softmax_scale,
1246
+ sparse_attn_enabled=sparse_attn_enabled,
1247
+ )
1248
+ else:
1249
+ flash_result = self._do_flash_attn(
1250
+ q_states_intra,
1251
+ k_states_intra,
1252
+ v_states_intra,
1253
+ softmax_scale=softmax_scale,
1254
+ causal=True,
1255
+ stage="intra",
1256
+ vertical_indices=intra_vertical_indices,
1257
+ slash_indices=intra_slash_indices,
1258
+ sparse_attn_enabled=sparse_attn_enabled,
1259
+ )
1260
+ flash_per_chunk.append(flash_result)
1261
+
1262
+ if prev_chunk_end_pos - chunk_len >= 0:
1263
+ if sparse_attn_enabled:
1264
+ flash_result = self._do_flash_attn(
1265
+ q_states_succ,
1266
+ k_states_succ,
1267
+ v_states_succ,
1268
+ softmax_scale=softmax_scale,
1269
+ causal=False,
1270
+ stage="succ",
1271
+ vertical_indices=succ_vertical_buffer,
1272
+ slash_indices=succ_slash_buffer,
1273
+ vertical_indices_count=succ_vertical_size_buffer,
1274
+ slash_indices_count=succ_slash_sizes_buffer,
1275
+ mergehead_softmax_scale=softmax_scale,
1276
+ sparse_attn_enabled=sparse_attn_enabled,
1277
+ )
1278
+ else:
1279
+ flash_result = self._do_flash_attn(
1280
+ q_states_succ,
1281
+ k_states_succ,
1282
+ v_states_succ,
1283
+ softmax_scale=softmax_scale,
1284
+ causal=False,
1285
+ stage="succ",
1286
+ vertical_indices=succ_vertical_indices,
1287
+ slash_indices=succ_slash_indices,
1288
+ sparse_attn_enabled=sparse_attn_enabled,
1289
+ )
1290
+ flash_per_chunk.append(flash_result)
1291
+
1292
+ if prev_chunk_end_pos - chunk_len * 2 >= 0:
1293
+ if sparse_attn_enabled:
1294
+ flash_result = self._do_flash_attn(
1295
+ q_states_inter,
1296
+ k_states_inter,
1297
+ v_states_inter,
1298
+ softmax_scale=softmax_scale,
1299
+ causal=False,
1300
+ stage="inter",
1301
+ vertical_indices=inter_vertical_buffer,
1302
+ slash_indices=inter_slash_buffer,
1303
+ vertical_indices_count=inter_vertical_size_buffer,
1304
+ slash_indices_count=inter_slash_sizes_buffer,
1305
+ mergehead_softmax_scale=softmax_scale,
1306
+ sparse_attn_enabled=sparse_attn_enabled,
1307
+ )
1308
+ else:
1309
+ flash_result = self._do_flash_attn(
1310
+ q_states_inter,
1311
+ k_states_inter,
1312
+ v_states_inter,
1313
+ softmax_scale=softmax_scale,
1314
+ causal=False,
1315
+ stage="inter",
1316
+ vertical_indices=inter_vertical_indices,
1317
+ slash_indices=inter_slash_indices,
1318
+ sparse_attn_enabled=sparse_attn_enabled,
1319
+ )
1320
+ flash_per_chunk.append(flash_result)
1321
+
1322
+ flash_results.append(flash_per_chunk)
1323
+ begin = end
1324
+
1325
+ attn_output = self._merge_attn_outputs(flash_results)
1326
+ del flash_results
1327
+ return attn_output
1328
+
1329
+ def _do_flash_attn(
1330
+ self,
1331
+ query_states: torch.Tensor,
1332
+ key_states: torch.Tensor,
1333
+ value_states: torch.Tensor,
1334
+ softmax_scale: float,
1335
+ causal: bool = True,
1336
+ max_seqlen_k: Optional[int] = None,
1337
+ stage: str = "intra",
1338
+ vertical_indices: Optional[torch.Tensor] = None,
1339
+ slash_indices: Optional[torch.Tensor] = None,
1340
+ vertical_indices_count: Optional[torch.Tensor] = None,
1341
+ slash_indices_count: Optional[torch.Tensor] = None,
1342
+ mergehead_softmax_scale: Optional[float] = None,
1343
+ sparse_attn_enabled: Optional[bool] = False,
1344
+ ):
1345
+ if max_seqlen_k is None:
1346
+ max_seqlen_k = key_states.shape[0]
1347
+
1348
+ q_len = query_states.shape[0]
1349
+ q_heads = query_states.shape[1]
1350
+ h_dim = query_states.shape[-1]
1351
+
1352
+ if sparse_attn_enabled:
1353
+ assert slash_indices is not None
1354
+ if stage == "intra":
1355
+ assert causal
1356
+ else:
1357
+ assert not causal
1358
+
1359
+ query_states = query_states.unsqueeze(0).transpose(1, 2)
1360
+ key_states = key_states.unsqueeze(0).transpose(1, 2)
1361
+ value_states = value_states.unsqueeze(0).transpose(1, 2)
1362
+
1363
+ q = query_states
1364
+ k = key_states
1365
+ v = value_states
1366
+
1367
+ if vertical_indices_count is not None and slash_indices_count is not None:
1368
+ assert mergehead_softmax_scale is not None
1369
+
1370
+ res, s_lse = _vertical_slash_sparse_attention(
1371
+ q,
1372
+ k,
1373
+ v,
1374
+ vertical_indices,
1375
+ slash_indices,
1376
+ mergehead_softmax_scale,
1377
+ causal=causal,
1378
+ stage=stage,
1379
+ vertical_indices_count=vertical_indices_count,
1380
+ slash_indices_count=slash_indices_count,
1381
+ )
1382
+ res = res.view(q_heads, q_len, h_dim).transpose(
1383
+ 0, 1
1384
+ ) # (qlen,nhead,h_dim)
1385
+ s_lse = (
1386
+ s_lse.view(q_heads, q_len, 1).squeeze(-1).unsqueeze(0).float()
1387
+ ) # (1, nhead,qlen)
1388
+ else:
1389
+ res, s_lse = _vertical_slash_sparse_attention(
1390
+ q,
1391
+ k,
1392
+ v,
1393
+ vertical_indices,
1394
+ slash_indices,
1395
+ softmax_scale,
1396
+ causal=causal,
1397
+ stage=stage,
1398
+ )
1399
+ res = res.view(q_len, q_heads, h_dim)
1400
+ s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float()
1401
+ return res, s_lse
1402
+
1403
+ output, softmax_lse, *rest = flash_attn_varlen_func(
1404
+ q=query_states,
1405
+ k=key_states,
1406
+ v=value_states,
1407
+ softmax_scale=softmax_scale,
1408
+ cu_seqlens_q=torch.tensor(
1409
+ [0, query_states.shape[0]],
1410
+ dtype=torch.int32,
1411
+ device=query_states.device,
1412
+ ),
1413
+ max_seqlen_q=query_states.shape[0],
1414
+ cu_seqlens_k=torch.tensor(
1415
+ [0, max_seqlen_k], dtype=torch.int32, device=query_states.device
1416
+ ),
1417
+ max_seqlen_k=max_seqlen_k,
1418
+ causal=causal,
1419
+ return_softmax_lse=True,
1420
+ )
1421
+ softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, 2).float()
1422
+ return output, softmax_lse
1423
+
1424
+ def _merge_attn_outputs(
1425
+ self,
1426
+ flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]],
1427
+ return_lse: Optional[bool] = False,
1428
+ ) -> torch.Tensor:
1429
+ attn_outputs_all = []
1430
+ logits_all = []
1431
+
1432
+ for flash_per_chunk in flash_results:
1433
+ if len(flash_per_chunk) == 1:
1434
+ attn_outputs_all.append(flash_per_chunk[0][0])
1435
+ if return_lse:
1436
+ logits_all.append(flash_per_chunk[0][1])
1437
+ continue
1438
+
1439
+ attn_outputs = torch.stack(
1440
+ [flash_attn_output[0] for flash_attn_output in flash_per_chunk]
1441
+ )
1442
+ logits = torch.stack(
1443
+ [flash_attn_output[1] for flash_attn_output in flash_per_chunk]
1444
+ )
1445
+ logits = logits.to(torch.float32)
1446
+
1447
+ if return_lse:
1448
+ max_val = torch.max(logits, dim=0).values
1449
+ diff = torch.abs(logits[0] - logits[1])
1450
+ log_sum_exp = max_val + torch.log1p(torch.exp(-diff))
1451
+ logits_all.append(log_sum_exp)
1452
+
1453
+ max_logits = torch.max(logits, dim=0).values
1454
+ stable_logits = logits - max_logits.unsqueeze(0)
1455
+ lse_s = torch.exp(stable_logits).detach()
1456
+ lse_sum = torch.sum(lse_s, dim=0)
1457
+ lse_s /= lse_sum
1458
+ attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1)
1459
+ attn_outputs_all.append(attn_outputs.sum(dim=0))
1460
+
1461
+ if return_lse:
1462
+ return (torch.cat(attn_outputs_all, dim=0), torch.cat(logits_all, dim=-1))
1463
+ else:
1464
+ return torch.cat(attn_outputs_all, dim=0)
1465
+
1466
+ def _dual_chunk_flash_attn_decoding(
1467
+ self,
1468
+ query: torch.Tensor,
1469
+ query_succ: torch.Tensor,
1470
+ query_inter: torch.Tensor,
1471
+ key_cache: torch.Tensor,
1472
+ value_cache: torch.Tensor,
1473
+ block_table: torch.Tensor,
1474
+ cache_seqlens: torch.Tensor,
1475
+ softmax_scale: float,
1476
+ causal: bool,
1477
+ chunk_size: int,
1478
+ local_size: int,
1479
+ original_max_position_embeddings: int,
1480
+ decode_meta: DualChunkFlashAttentionMetadata,
1481
+ ):
1482
+ if not causal:
1483
+ raise ValueError("Dual Chunk Attention does not support causal=False")
1484
+
1485
+ block_size = value_cache.shape[1]
1486
+ chunk_len = chunk_size - local_size
1487
+ if chunk_len % block_size != 0:
1488
+ raise ValueError("chunk_len must be divisible by block_size.")
1489
+ if original_max_position_embeddings > 0:
1490
+ assert decode_meta.scaling_factor is not None
1491
+ scaling_factor = decode_meta.scaling_factor
1492
+ query = (query * scaling_factor.view(-1, 1, 1, 1)).to(
1493
+ query.dtype
1494
+ ) # possible numerical issue; should be fused into the kernel
1495
+ query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to(query.dtype)
1496
+ query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to(
1497
+ query.dtype
1498
+ )
1499
+ outputs_list = []
1500
+ softmax_lses_list = []
1501
+
1502
+ # intra-attention
1503
+ intra_output, intra_softmax_lse = (
1504
+ self._dual_chunk_flash_attn_decoding_with_exp_sums(
1505
+ query,
1506
+ key_cache,
1507
+ value_cache,
1508
+ decode_meta.block_tables_intra,
1509
+ decode_meta.seq_lens_intra,
1510
+ softmax_scale,
1511
+ causal=False,
1512
+ )
1513
+ )
1514
+ outputs_list.append(intra_output)
1515
+ softmax_lses_list.append(intra_softmax_lse)
1516
+
1517
+ # succ-attention
1518
+ if decode_meta.max_seq_len_succ:
1519
+ succ_output, succ_softmax_lse = (
1520
+ self._dual_chunk_flash_attn_decoding_with_exp_sums(
1521
+ query_succ,
1522
+ key_cache,
1523
+ value_cache,
1524
+ decode_meta.block_tables_succ,
1525
+ decode_meta.seq_lens_succ,
1526
+ softmax_scale,
1527
+ causal=False,
1528
+ )
1529
+ )
1530
+ outputs_list.append(succ_output)
1531
+ softmax_lses_list.append(succ_softmax_lse)
1532
+
1533
+ # inter-attention
1534
+ if decode_meta.max_seq_len_inter:
1535
+ inter_output, inter_softmax_lse = (
1536
+ self._dual_chunk_flash_attn_decoding_with_exp_sums(
1537
+ query_inter,
1538
+ key_cache,
1539
+ value_cache,
1540
+ block_table[:, : decode_meta.max_seq_len_inter],
1541
+ decode_meta.seq_lens_inter,
1542
+ softmax_scale,
1543
+ causal=False,
1544
+ )
1545
+ )
1546
+ outputs_list.append(inter_output)
1547
+ softmax_lses_list.append(inter_softmax_lse)
1548
+ outputs = torch.stack(outputs_list, dim=0)
1549
+ del outputs_list
1550
+ softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32)
1551
+ del softmax_lses_list
1552
+ max_logits = torch.max(softmax_lses, dim=0).values
1553
+ stable_logits = softmax_lses - max_logits.unsqueeze(0)
1554
+ lse_s = torch.exp(stable_logits).detach()
1555
+ lse_sum = torch.sum(lse_s, dim=0)
1556
+ lse_s /= lse_sum
1557
+ outputs *= lse_s.unsqueeze(-1).transpose(2, 3)
1558
+ return outputs.sum(0)
1559
+
1560
+ def _dual_chunk_flash_attn_decoding_with_exp_sums(
1561
+ self,
1562
+ query: torch.Tensor,
1563
+ key_cache: torch.Tensor,
1564
+ value_cache: torch.Tensor,
1565
+ block_table: torch.Tensor,
1566
+ cache_seqlens: torch.Tensor,
1567
+ softmax_scale: float,
1568
+ causal: bool,
1569
+ ):
1570
+ out, softmax_lse, *rest_expand = flash_attn_with_kvcache(
1571
+ q=query,
1572
+ k_cache=key_cache,
1573
+ v_cache=value_cache,
1574
+ page_table=block_table,
1575
+ cache_seqlens=cache_seqlens,
1576
+ softmax_scale=softmax_scale,
1577
+ causal=causal,
1578
+ return_softmax_lse=True,
1579
+ )
1580
+ mask = cache_seqlens == 0
1581
+ out[mask] = 0
1582
+ softmax_lse[mask] = -float("inf")
1583
+ return out, softmax_lse
1584
+
1585
+
1586
+ def _vertical_slash_sparse_attention(
1587
+ query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
1588
+ key: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD]
1589
+ value: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD]
1590
+ v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
1591
+ s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
1592
+ softmax_scale: float,
1593
+ causal: bool = True,
1594
+ stage: str = "intra",
1595
+ block_size_M: int = 64,
1596
+ block_size_N: int = 64,
1597
+ vertical_indices_count: torch.Tensor = None, # [N_HEADS,]
1598
+ slash_indices_count: torch.Tensor = None,
1599
+ ):
1600
+ if stage == "intra":
1601
+ assert causal
1602
+ else:
1603
+ assert not causal
1604
+
1605
+ batch_size, num_heads, context_size, head_dim = query.shape
1606
+ _, _, kv_seq_len, _ = key.shape
1607
+
1608
+ if head_dim not in [16, 32, 64, 128, 256, 512]:
1609
+ target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim
1610
+ query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])
1611
+ key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])
1612
+ value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])
1613
+
1614
+ v_idx = (
1615
+ v_idx.to(torch.int32)
1616
+ .reshape((batch_size, num_heads, -1))
1617
+ .sort(dim=-1, descending=False)[0]
1618
+ )
1619
+ s_idx = (
1620
+ s_idx.to(torch.int32)
1621
+ .reshape((batch_size, num_heads, -1))
1622
+ .sort(dim=-1, descending=True)[0]
1623
+ )
1624
+ q_seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device)
1625
+ kv_seqlens = torch.tensor([kv_seq_len], dtype=torch.int32, device=query.device)
1626
+
1627
+ if vertical_indices_count is not None and slash_indices_count is not None:
1628
+ (
1629
+ block_count,
1630
+ block_offset,
1631
+ column_count,
1632
+ column_index,
1633
+ ) = convert_vertical_slash_indexes_mergehead(
1634
+ q_seqlens,
1635
+ kv_seqlens,
1636
+ v_idx,
1637
+ s_idx,
1638
+ vertical_indices_count,
1639
+ slash_indices_count,
1640
+ context_size,
1641
+ block_size_M,
1642
+ block_size_N,
1643
+ causal,
1644
+ )
1645
+ else:
1646
+ (
1647
+ block_count,
1648
+ block_offset,
1649
+ column_count,
1650
+ column_index,
1651
+ ) = convert_vertical_slash_indexes(
1652
+ q_seqlens,
1653
+ kv_seqlens,
1654
+ v_idx,
1655
+ s_idx,
1656
+ context_size,
1657
+ block_size_M,
1658
+ block_size_N,
1659
+ causal,
1660
+ )
1661
+
1662
+ q = query.transpose(1, 2).contiguous()
1663
+ k = key.transpose(1, 2).contiguous()
1664
+ v = value.transpose(1, 2).contiguous()
1665
+ out, lse = sparse_attn_func(
1666
+ q,
1667
+ k,
1668
+ v,
1669
+ block_count,
1670
+ block_offset,
1671
+ column_count,
1672
+ column_index,
1673
+ causal=causal,
1674
+ softmax_scale=softmax_scale,
1675
+ return_softmax_lse=True,
1676
+ )
1677
+ out = out.transpose(1, 2).contiguous()
1678
+ softmax_lse = lse.reshape(*lse.shape, 1)
1679
+ return (out[..., :context_size, :head_dim], softmax_lse[..., :context_size, :])
1680
+
1681
+
1682
+ def _sum_all_diagonal_matrix(mat: torch.tensor):
1683
+ h, n, m = mat.shape
1684
+ # Zero matrix used for padding
1685
+ zero_mat = torch.zeros((h, n, n), device=mat.device)
1686
+ # pads the matrix on left and right
1687
+ mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)
1688
+ # Change the strides
1689
+ mat_strided = mat_padded.as_strided(
1690
+ (1, n, n + m), (n * (2 * n + m), 2 * n + m + 1, 1)
1691
+ )
1692
+ # Sums the resulting matrix's columns
1693
+ sum_diags = torch.sum(mat_strided, 1)
1694
+ return sum_diags[:, 1:] # drop left bottom corner
1695
+
1696
+
1697
+ def _get_block(block_table: torch.Tensor, block_size: int, begin: int, end: int):
1698
+ begin_block = begin // block_size
1699
+ end_block = (end - 1) // block_size + 1
1700
+ return block_table[begin_block:end_block]
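
Quick reference for the new dual-chunk backend: the arithmetic below is a minimal, self-contained sketch of the decode-time length bookkeeping in init_forward_metadata and of the log-sum-exp merge used by _merge_attn_outputs and _dual_chunk_flash_attn_decoding. The concrete numbers (an original_max_position_embeddings of 32768, the toy cache lengths, the random placeholder tensors, and the simplified tensor layout) are illustrative assumptions, not values taken from the package.

import torch

# Defaults that the backend reads from dual_chunk_attention_config above.
chunk_size, local_size = 8192, 1024
original_max_position_embeddings = 32768  # illustrative assumption
chunk_len = chunk_size - local_size  # 7168

# Toy per-request KV-cache lengths for a decode batch (illustrative).
cache_seq_lens = torch.tensor([5000, 20000, 70000])

# Same formulas as init_forward_metadata (decode path):
chunk_num_curr = (cache_seq_lens - 1) // chunk_len
# "intra": the partially filled current chunk plus the local window.
seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len
# "succ": at most one full preceding chunk.
seq_lens_succ = (chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len
# "inter": all earlier chunks.
seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len
# The three pieces always partition the cached sequence.
assert torch.equal(seq_lens_intra + seq_lens_succ + seq_lens_inter, cache_seq_lens)

# DCA length scaling, applied only when original_max_position_embeddings > 0.
scaling_factor = (
    0.1 * torch.log(cache_seq_lens / original_max_position_embeddings) + 1.0
).clip(min=1)

# Log-sum-exp merge of the intra/succ/inter partial results, as in
# _merge_attn_outputs. The layout here is simplified to
# (parts, tokens, heads, head_dim) / (parts, tokens, heads); the backend keeps
# the LSE in a different memory layout, but the reduction over parts is the same.
outputs = torch.randn(3, 4, 2, 8)  # placeholder partial attention outputs
lses = torch.randn(3, 4, 2)        # placeholder softmax log-sum-exps
weights = torch.exp(lses - lses.max(dim=0).values)  # numerically stable
weights = weights / weights.sum(dim=0)              # == softmax over the parts axis
merged = (outputs * weights.unsqueeze(-1)).sum(dim=0)

Running the sketch shows, for example, that with chunk_len = 7168 a 20000-token cache splits into 5664 intra, 7168 succ, and 7168 inter tokens, which is exactly how the backend routes a decode query's three projections over the KV cache.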