mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/tree_attention.py
@@ -0,0 +1,718 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ # pyre-strict
+
+ import itertools
+ from dataclasses import dataclass
+ from functools import lru_cache
+ from typing import List, Optional, Sequence, Tuple, Type
+
+ import torch
+
+ from .. import fmha
+ from . import (
+     flash,
+     flash3,
+     memory_efficient_attention_forward_requires_grad,
+     memory_efficient_attention_partial,
+     merge_attentions,
+     triton_splitk,
+ )
+ from .attn_bias import (
+     AttentionBias,
+     PagedBlockDiagonalGappyKeysMask,
+     PagedBlockDiagonalPaddedKeysMask,
+ )
+ from .common import AttentionFwOpBase, pack_fp8_tensorwise_per_head
+ from .dispatch import _get_use_fa3, fa3_available
+
+
+ @dataclass
+ class TreeAttnMetadata:
+     """
+     tree_choices: definition of the tree, tuples sorted by length, each corresponding
+         to a node. See the docstring of TreeAttnMetadata.from_tree_choices.
+     attention_bias: Medusa-style tree attention bias as an explicit tensor
+         of shape (tree_size, tree_size), where tree_size is the total number
+         of nodes in the tree. It can be used as a spec_attn_bias ("right"
+         or "suffix" attention part) in tree_attention.
+         See tree_attention_with_sync for a usage example.
+     tree_indices: 1D tensor of size tree_size which maps tree nodes to draft tokens.
+         Tree nodes are assumed to be in the same order as in tree_choices
+         (see TreeAttnMetadata.from_tree_choices).
+     retrieval_indices: a tensor of shape (number of leaves, depth + 1), where one
+         row corresponds to one path, and contains indices of the tree nodes
+         on that path. Paths are padded with -1 from the right.
+         The paths (row dim) are unsorted.
+     path_lengths: real lengths for each of the paths.
+     tree_seq_position_ids: 1D tensor of size tree_size which indicates which head
+         a node belongs to. Equivalently, it shows the sequence position of the
+         node within the corresponding path.
+     parent_node_indices: 1D tensor of size tree_size which for each node contains
+         the position of its parent + 1. For root node(s) it contains 0.
+     child_node_indices: a tensor of shape (tree_size, max_num_children_per_node),
+         in which each row contains indices of children of the corresponding node.
+         Rows corresponding to nodes which have fewer than max_num_children_per_node
+         children are padded by repeating the last child index.
+         For leaf nodes the values are meaningless and filled with 0.
+     num_children_per_node: 1D tensor of size tree_size which contains the number of
+         children for each node.
+     candidate_idx: 1D tensor of size tree_size, contains the index of each node among its "siblings".
+         Takes values from 0 to the number of children of the parent node minus 1.
+     num_nodes_per_level: 1D tensor of the number of nodes at each level (including root).
+     num_children_per_node_at_level: List of 1D tensors, each containing the number of children per node at that tree level.
+     subtree_sizes: List of integers, each containing the number of nodes in the subtree truncated at that tree level.
+     Example:
+         Tree choices
+         `[(0,), (0, 0), (0, 1), (0, 2), (1,), (1, 0), (1, 1), (1, 2), (2,), (2, 0), (2, 1), (2, 2)]`
+         represents a tree that looks like this:
+         0
+         |-- 1
+         |   |-- 4
+         |   |-- 5
+         |   |-- 6
+         |
+         |-- 2
+         |   |-- 7
+         |   |-- 8
+         |   |-- 9
+         |
+         |-- 3
+             |-- 10
+             |-- 11
+             |-- 12
+
+         with TreeAttnMetadata
+         tree_indices=tensor([0, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6])
+         retrieval_indices=tensor([[ 0,  1,  5],
+                 [ 0,  2,  9],
+                 [ 0,  3, 11],
+                 [ 0,  1,  4],
+                 [ 0,  2,  8],
+                 [ 0,  3, 10],
+                 [ 0,  1,  6],
+                 [ 0,  2,  7],
+                 [ 0,  3, 12]])
+         path_lengths=[3, 3, 3, 3, 3, 3, 3, 3, 3]
+         tree_seq_position_ids=tensor([0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2])
+         child_node_indices=tensor([[ 0,  1,  2],
+                 [ 3,  4,  5],
+                 [ 6,  7,  8],
+                 [ 9, 10, 11],
+                 [ 0,  0,  0],
+                 ...
+                 [ 0,  0,  0]])
+         num_children_per_node=tensor([3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0])
+         candidate_idx=tensor([0, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])
+         num_nodes_per_level=tensor([1, 3, 3])
+         num_children_per_node_at_level=[tensor([3]), tensor([3, 3, 3]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0])]
+         subtree_sizes=[1, 4, 13]
+     """
+
+     tree_choices: Sequence[Tuple[int, ...]]
+     attention_bias: torch.Tensor
+     tree_indices: torch.Tensor
+     retrieval_indices: torch.Tensor
+     path_lengths: List[int]
+     tree_seq_position_ids: torch.Tensor
+     parent_node_indices: torch.Tensor
+     child_node_indices: torch.Tensor
+     num_children_per_node: torch.Tensor
+     candidate_idx: torch.Tensor
+     num_nodes_per_level: torch.Tensor
+     num_nodes_per_level_cpu: torch.Tensor
+     num_children_per_node_at_level: List[torch.Tensor]
+     num_children_per_node_at_level_cpu: List[torch.Tensor]
+     subtree_sizes: List[int]
+
+     @classmethod
+     @lru_cache
+     def from_tree_choices_cached(
+         cls,
+         tree_choices: Tuple[Tuple[int, ...]],
+         dtype: Optional[torch.dtype] = None,
+         device: Optional[torch.device] = None,
+     ) -> "TreeAttnMetadata":
+         return cls.from_tree_choices(tree_choices, dtype, device)
+
+     @classmethod
+     def from_tree_choices(
+         cls,
+         tree_choices: Sequence[Tuple[int, ...]],
+         dtype: Optional[torch.dtype] = None,
+         device: Optional[torch.device] = None,
+     ) -> "TreeAttnMetadata":
+         """
+         Args:
+             tree_choices: tree description in the style of
+                 https://github.com/FasterDecoding/Medusa/blob/5e9805386/medusa/model/medusa_choices.py
+                 A typical tree description would look like:
+                 [(node0, node1, ...), (node0, node2), (node0, node3), (node1, node3), ..., (node0, node2, ..., nodeN)]
+                 Every tuple corresponds to one node in the tree, encoded as a path from one of the root nodes to the
+                 node in question.
+                 For example, a node encoded as (1, 0, 3, ..., 2) is understood as:
+                     list all the root nodes and take node number 1
+                     list all children of that node and take node number 0
+                     list all children of that node and take node number 3
+                     ...
+                     list all children of that node and take node number 2 - that's the node encoded by this tuple.
+
+             dtype: data type of the output mask tensor.
+             device: device of the output tensors.
+         Returns:
+             TreeAttnMetadata object with all the fields.
+         """
+         # from https://github.com/SafeAILab/EAGLE/blob/e98fc7c/model/utils.py#L89C1-L117C1
+         sorted_tree_choices = sorted(tree_choices, key=lambda x: (len(x), x))
+         tree_len = len(sorted_tree_choices) + 1
+
+         depth_counts = _get_depth_counts(sorted_tree_choices)
+         tree_indices = _prepare_tree_indices(sorted_tree_choices, depth_counts, device)
+         retrieval_indices, path_lengths = _prepare_retrieval_indices(
+             sorted_tree_choices, device
+         )
+         tree_seq_position_ids = _prepare_tree_position_ids(
+             depth_counts, tree_len, device
+         )
+         tree_attn_mask = _prepare_tree_attn_bias(
+             sorted_tree_choices, depth_counts, dtype, device
+         )
+         parent_node_indices = _prepare_parent_node_indices(sorted_tree_choices, device)
+         child_node_indices, num_children_per_node = _prepare_child_node_indices(
+             sorted_tree_choices, device
+         )
+         candidate_idx = _prepare_candidate_idx(sorted_tree_choices, device)
+
+         num_nodes_per_level = _get_num_nodes_per_level(depth_counts, device)
+         num_nodes_per_level_cpu = num_nodes_per_level.cpu()
+         (
+             subtree_sizes,
+             num_children_per_node_at_level,
+         ) = _get_subtree_size_and_num_children_per_node_at_level(
+             num_nodes_per_level, num_children_per_node, device
+         )
+         num_children_per_node_at_level_cpu = [
+             row.cpu() for row in num_children_per_node_at_level
+         ]
+         return TreeAttnMetadata(
+             sorted_tree_choices,
+             tree_attn_mask,
+             tree_indices,
+             retrieval_indices,
+             path_lengths,
+             tree_seq_position_ids,
+             parent_node_indices,
+             child_node_indices,
+             num_children_per_node,
+             candidate_idx,
+             num_nodes_per_level,
+             num_nodes_per_level_cpu,
+             num_children_per_node_at_level,
+             num_children_per_node_at_level_cpu,
+             subtree_sizes,
+         )
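
A minimal usage sketch (illustrative, not part of the packaged file), assuming the wheel is installed, that this diff corresponds to mslk/attention/fmha/tree_attention.py from the file list above, and that a CUDA device is available:

    import torch
    from mslk.attention.fmha.tree_attention import TreeAttnMetadata

    # Two root candidates; the first root has one child, so the tree has
    # 3 explicit nodes plus the implicit root prepended by the helpers.
    tree_choices = [(0,), (1,), (0, 0)]
    meta = TreeAttnMetadata.from_tree_choices(
        tree_choices, dtype=torch.bfloat16, device=torch.device("cuda")
    )
    assert meta.attention_bias.shape == (4, 4)  # (tree_len, tree_len)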
+
+
+ def _get_subtree_size_and_num_children_per_node_at_level(
+     num_nodes_per_level: torch.Tensor,
+     num_children_per_node: torch.Tensor,
+     device: Optional[torch.device] = None,
+ ) -> Tuple[List[int], List[torch.Tensor]]:
+     depth: int = len(num_nodes_per_level)
+     subtree_sizes: List[int] = [
+         1,
+     ]
+     num_children_per_node_at_level: List[torch.Tensor] = [
+         num_children_per_node[0].unsqueeze(0)
+     ]
+     for i in range(1, depth):
+         subtree_sizes.append(int(torch.sum(num_nodes_per_level[: (i + 1)])))
+         num_children_per_node_at_level.append(
+             num_children_per_node[subtree_sizes[i - 1] : subtree_sizes[i]]
+         )
+     return subtree_sizes, num_children_per_node_at_level
+
+
+ def _get_depth_counts(sorted_tree_choices: List[Tuple[int, ...]]) -> List[int]:
+     # Initialize depth_counts to keep track of how many choices have a particular depth
+     depth_counts = []
+     prev_depth = 0
+     for path in sorted_tree_choices:
+         depth = len(path)
+         if depth != prev_depth:
+             depth_counts.append(0)
+         depth_counts[depth - 1] += 1
+         prev_depth = depth
+     return depth_counts
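
For intuition (values chosen for illustration): a sorted tree with three depth-1 nodes and two depth-2 nodes yields one count per level:

    _get_depth_counts([(0,), (1,), (2,), (0, 0), (1, 0)])  # -> [3, 2]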
+
+
+ def _get_num_nodes_per_level(
+     depth_counts: List[int], device: Optional[torch.device]
+ ) -> torch.Tensor:
+     depth_counts_tensor: torch.Tensor = torch.tensor([1] + depth_counts, device=device)
+     return depth_counts_tensor[depth_counts_tensor != 0]
+
+
+ def _prepare_tree_attn_bias(
+     sorted_tree_choices: List[Tuple[int, ...]],
+     depth_counts: List[int],
+     dtype: Optional[torch.dtype],
+     device: Optional[torch.device],
+ ) -> torch.Tensor:
+     """
+     Construct a Medusa-style tree attention bias as an explicit tensor.
+     It can be used as a spec_attn_bias ("right" or "suffix" attention part)
+     in tree_attention. See run_tree_attention_inner in the tests for a usage example.
+     Args:
+         sorted_tree_choices: tree description in the style of
+             https://github.com/FasterDecoding/Medusa/blob/5e9805386/medusa/model/medusa_choices.py
+             A typical tree description would look like:
+             [(node0, node1, ...), (node0, node2), (node0, node3), (node1, node3), ..., (node0, node2, ..., nodeN)]
+             Every tuple corresponds to one node in the tree, encoded as a path from one of the root nodes to the
+             node in question. Passed in sorted order.
+             For example, a node encoded as (1, 0, 3, ..., 2) is understood as:
+                 list all the root nodes and take node number 1
+                 list all children of that node and take node number 0
+                 list all children of that node and take node number 3
+                 ...
+                 list all children of that node and take node number 2 - that's the node encoded by this tuple
+         depth_counts: a list of integers, where the i-th element is the number of choices with depth i.
+         dtype: data type of the output tensor.
+         device: device of the output tensor.
+     Returns:
+         attention bias of shape (tree_size, tree_size),
+         where tree_size is the total number of nodes in the tree.
+     """
+     # +1 comes from the additional root node
+     tree_len = len(sorted_tree_choices) + 1
+     tree_attn_mask = torch.full(
+         (tree_len, tree_len), -torch.inf, device=device, dtype=dtype
+     )
+
+     mask_val = 0
+     for i in range(tree_len):
+         tree_attn_mask[i, i] = mask_val
+
+     tree_attn_mask[:, 0] = mask_val
+     start = 0
+     for i in range(len(depth_counts)):
+         for j in range(depth_counts[i]):
+             cur_tree_choice = sorted_tree_choices[start + j]
+             # retrieve ancestor position
+             if len(cur_tree_choice) == 1:
+                 continue
+             ancestor_idx = []
+             for c in range(len(cur_tree_choice) - 1):
+                 ancestor_idx.append(
+                     sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
+                 )
+             tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
+         start += depth_counts[i]
+     return tree_attn_mask
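
A worked example of the bias produced by the helper above (private helpers called directly, purely for illustration; the module import path is assumed from the file listing):

    import torch
    from mslk.attention.fmha.tree_attention import (
        _get_depth_counts,
        _prepare_tree_attn_bias,
    )

    sorted_tree_choices = [(0,), (1,), (0, 0)]
    depth_counts = _get_depth_counts(sorted_tree_choices)  # [2, 1]
    bias = _prepare_tree_attn_bias(sorted_tree_choices, depth_counts, torch.float32, None)
    # Rows/columns are [root, (0,), (1,), (0, 0)]; 0 marks "may attend", -inf masks out:
    # [[  0., -inf, -inf, -inf],
    #  [  0.,   0., -inf, -inf],
    #  [  0., -inf,   0., -inf],
    #  [  0.,   0., -inf,   0.]]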
+
+
+ def _prepare_tree_indices(
+     sorted_tree_choices: List[Tuple[int, ...]],
+     depth_counts: List[int],
+     device: Optional[torch.device],
+ ) -> torch.Tensor:
+     """
+     Construct an index tensor for choices in the tree and their corresponding index in the draft tokens.
+     Args:
+         sorted_tree_choices: the sorted tree_choices input of TreeAttnMetadata.from_tree_choices.
+         depth_counts: a list of integers, where the i-th element is the number of choices with depth i.
+         device: device of the output tensor.
+     Returns:
+         tree indices of shape (tree_len,). See docstring of TreeAttnMetadata for details.
+     """
+     # Generate tree indices for the tree_choices structure
+     # add root node from main head prediction to the tree_len
+     tree_len = len(sorted_tree_choices) + 1
+     tree_indices = torch.zeros(tree_len, device=device, dtype=torch.long)
+     tree_indices[0] = 0
+     start, max_idx_prev_level = 0, 0
+     for i in range(len(depth_counts)):
+         cur_offset = max_idx_prev_level
+         for j in range(depth_counts[i]):
+             cur_tree_choice = sorted_tree_choices[start + j]
+             tree_idx = cur_tree_choice[-1] + cur_offset + 1
+             tree_indices[start + j + 1] = tree_idx
+             max_idx_prev_level = max(tree_idx, max_idx_prev_level)
+         start += depth_counts[i]
+     return tree_indices
+
+
+ def _prepare_retrieval_indices(
+     tree_choices: List[Tuple[int, ...]], device: Optional[torch.device]
+ ) -> Tuple[torch.Tensor, List[int]]:
+     """
+     Convert tree definition from the format used by Medusa and EAGLE (tree_choices, see docstring of
+     TreeAttnMetadata.from_tree_choices) to a list of paths:
+     [
+         (node_index0_path0, node_index1_path0, ...),
+         (node_index0_path1, node_index1_path1, ...),
+         ...
+     ]
+     where each value is an index of a node inside the corresponding level of a tree.
+     Returns:
+         retrieval indices of shape (number of leaves, depth + 1)
+         length of each path.
+     """
+     tree_depth = max(len(node) for node in tree_choices) + 1 if tree_choices else 1
+
+     leaves = set(tree_choices)
+
+     for node in tree_choices[::-1]:
+         if node[:-1] in leaves:
+             leaves.remove(node[:-1])
+
+     paths, path_lengths = [], []
+     for leaf in leaves:
+         path = [0] + [
+             tree_choices.index(leaf[:level]) + 1 for level in range(1, len(leaf) + 1)
+         ]
+         path_lengths.append(len(path))
+         paths.append(path + [-1] * (tree_depth - len(path)))
+     paths_tensor = torch.tensor(paths, dtype=torch.long, device=device)
+     return paths_tensor, path_lengths
+
+
+ def _prepare_tree_position_ids(
+     depth_counts: List[int], tree_len: int, device: Optional[torch.device]
+ ) -> torch.Tensor:
+     """
+     Construct the sequence position of each node within its path; it can be used for positional embeddings.
+     Args:
+         depth_counts: number of nodes at each of the levels of the tree.
+         tree_len: total number of nodes in the tree including the root.
+         device: device of the output tensor.
+     Returns:
+         tree position ids of shape (tree_len,). See docstring of TreeAttnMetadata for details.
+     """
+     tree_position_ids = torch.zeros(tree_len, dtype=torch.int32, device=device)
+     start = 0
+     for i in range(len(depth_counts)):
+         tree_position_ids[start + 1 : start + depth_counts[i] + 1] = i + 1
+         start += depth_counts[i]
+
+     return tree_position_ids
+
+
+ def _prepare_parent_node_indices(
+     sorted_tree_choices: List[Tuple[int, ...]], device: Optional[torch.device]
+ ) -> torch.Tensor:
+     ancestor_idx = []
+     for cur_medusa_choice in sorted_tree_choices:
+         try:
+             ancestor_idx.append(sorted_tree_choices.index(cur_medusa_choice[:-1]) + 1)
+         except ValueError:
+             ancestor_idx.append(0)
+     return torch.tensor(ancestor_idx, dtype=torch.long, device=device)
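
For example (values chosen for illustration): with sorted choices [(0,), (1,), (0, 0)], both roots map to 0 and the child points one past its parent's position:

    _prepare_parent_node_indices([(0,), (1,), (0, 0)], None)  # -> tensor([0, 0, 1])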
+
+
+ def _prepare_child_node_indices(
+     tree_choices: List[Tuple[int, ...]], device: Optional[torch.device]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     res = []
+     num_children_per_node = []
+     for x in [()] + tree_choices:
+         curr_children = [
+             i
+             for i, y in enumerate(tree_choices)
+             if len(x) + 1 == len(y) and y[:-1] == x
+         ]
+         num_children_per_node.append(len(curr_children))
+         if curr_children:
+             res.append(curr_children)
+         else:
+             res.append([0])
+
+     # pad all children lists by repeating the last element
+     max_num_children = max(len(x) for x in res)
+     res = [x + x[-1:] * (max_num_children - len(x)) for x in res]
+
+     # Check that all padded children lists have the same length.
+     assert all(len(x) == len(res[0]) for x in res)
+     return (
+         torch.tensor(res, dtype=torch.long, device=device),
+         torch.tensor(num_children_per_node, dtype=torch.long, device=device),
+     )
+
+
+ def _prepare_candidate_idx(
+     tree_choices: List[Tuple[int, ...]], device: Optional[torch.device]
+ ) -> torch.Tensor:
+     candidate_idx = [
+         sum(
+             curr_node[:-1] == another_node[:-1]
+             for another_node in tree_choices[:curr_node_idx]
+         )
+         for curr_node_idx, curr_node in enumerate(tree_choices)
+     ]
+     return torch.tensor(candidate_idx, dtype=torch.long, device=device)
+
+
+ def use_triton_splitk_for_prefix(B: int, G: int, tree_size: int) -> bool:
+     """
+     Heuristic to decide whether to use Triton Split-K or the default (Flash Attention) for prefix attention.
+     """
+     return (
+         (B * G <= 128 and tree_size <= 64)
+         or (B * G < 4 and tree_size < 100)
+         or B * G < 2
+     )
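
Two illustrative data points for this heuristic (shapes chosen arbitrarily):

    use_triton_splitk_for_prefix(B=4, G=8, tree_size=32)    # True: B*G = 32 <= 128 and tree_size <= 64
    use_triton_splitk_for_prefix(B=64, G=8, tree_size=128)  # False: the default (Flash Attention) path is used instead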
+
+
+ def select_prefix_op(
+     B: int,
+     G: int,
+     tree_size: int,
+     autotune: bool,
+     attn_bias: AttentionBias,
+     kv_cache_dtype: torch.dtype,
+ ) -> Optional[Type[AttentionFwOpBase]]:
+     """
+     Heuristic to select the backend for prefix attention: Triton Split-K, FA3
+     (optionally with KV-split), or None to fall back to the default dispatcher.
+     """
+     triton_splitk_op = SplitKAutotune if autotune else triton_splitk.FwOp
+     if torch.version.hip:
+         # TODO: further tune heuristics once CK splitK is ready
+         return triton_splitk_op
+
+     # Triton Split-K is not present in the dispatcher list for some shapes.
+     # However, we need to dispatch to it if no other op is available.
+     # FA3 splitKV doesn't yet support gappy or paged biases.
+     fa3_splitkv_supported = isinstance(
+         attn_bias,
+         flash3.FwOp_KVSplit.SUPPORTED_ATTN_BIAS_TYPES,  # type: ignore
+     )
+     fa3_supported = isinstance(attn_bias, flash3.FwOp.SUPPORTED_ATTN_BIAS_TYPES)  # type: ignore
+     flash2_supported = isinstance(attn_bias, flash.FwOp.SUPPORTED_ATTN_BIAS_TYPES)  # type: ignore
+     if not (fa3_splitkv_supported or fa3_supported or flash2_supported):
+         return triton_splitk_op
+
+     assert torch.version.cuda
+     use_fa3 = _get_use_fa3() and fa3_available()
+     # override heuristics for quantized kv cache for decode
+     if kv_cache_dtype == torch.uint8:
+         return triton_splitk_op
+     # select FA3 when bs >= 64
+     if B >= 64 and use_fa3:
+         if fa3_splitkv_supported:
+             return flash3.FwOp_KVSplit
+         return flash3.FwOp
+     elif use_triton_splitk_for_prefix(B, G, tree_size):
+         return triton_splitk_op
+     else:
+         # use default heuristics from xformers
+         return None
+
+
+ def tree_attention(
+     q: torch.Tensor,
+     spec_k: torch.Tensor,
+     spec_v: torch.Tensor,
+     cache_k: torch.Tensor,
+     cache_v: torch.Tensor,
+     spec_attn_bias: torch.Tensor,
+     prefix_attn_bias: AttentionBias,
+     prefix_op: Optional[Type[AttentionFwOpBase]] = None,
+     suffix_op: Optional[Type[AttentionFwOpBase]] = None,
+     autotune: bool = False,
+     quantized_kv_scales: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+     q_fp8: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     """
+     Compute Medusa/EAGLE/Hydra-style tree attention.
+     Notice that this function takes as arguments biases for the left (prefix)
+     and right (speculative suffix) parts of the attention.
+     This way we avoid creating these biases on the fly, and
+     allow this function to be used in performance-critical decoding
+     jobs, including in CUDA graph mode. In the latter case one should
+     construct the biases once, and update prefix_attn_bias with
+     current seqlens before every graph replay; spec_attn_bias stays static,
+     as it's determined by the tree structure.
+     Args:
+         q: query from speculative tokens, of shape (B, tree_size_q, (G), H, D)
+         spec_k, spec_v: keys/values from speculative tokens, each of shape (B, tree_size_kv, (G), H, D).
+             If tree_size_q < tree_size_kv, we assume the end of the query sequence aligns with the end of the
+             k/v sequence, like in "from-bottom-right" attention masks. Such rectangular attention masks can be
+             used when we are adding new nodes to the tree, and want to avoid recomputing attention for the
+             existing nodes. For example, this can be used during draft token generation in EAGLE.
+         cache_k/cache_v: keys/values from the existing context, each of shape (B, Mk, (G), H, D)
+         spec_attn_bias: attention bias of the "right" part of the attention (tree_size_q x spec tokens).
+             This would typically be an explicit tensor mask, precomputed once and not changing during decoding.
+         prefix_attn_bias: attention bias of the "left" part of the attention (tree_size_q x existing context).
+             This bias would typically be block-diagonal padded non-causal (BlockDiagonalPaddedKeysMask), and it
+             changes at every decoding step as K/V sequence lengths grow during decoding.
+         prefix_op: attention backend which will be passed to memory_efficient_attention to compute prefix attention.
+             If None, will use Triton Split-K or Flash Attention depending on the heuristics.
+         suffix_op: same as prefix_op, but for the suffix.
+         autotune: If True, Triton Split-K will use autotuning when chosen
+             as a default backend for prefix/suffix attention.
+     Returns:
+         attention output of shape (B, tree_size_q, (G), H, D)
+
+     :Usage example:
+
+     See also tree_attention_with_sync in tests/test_tree_attention.py
+
+     .. code-block:: python
+
+         # Create an attention bias for the prefix part of the attention
+         prefix_attn_bias = BlockDiagonalPaddedKeysMask.from_seqlens(
+             q_seqlen=[tree_size_q for _ in range(B)], kv_seqlen=kv_lens, kv_padding=Mk
+         )
+         # Create an explicit attention bias for the speculative part of the attention
+         spec_attn_bias = TreeAttnMetadata.from_tree_choices(tree_choices, q.dtype, q.device).attention_bias
+         attn_output = tree_attention(
+             q, spec_k, spec_v, cache_k, cache_v, spec_attn_bias, prefix_attn_bias
+         )
+     """
+
+     is_bmhk = q.ndim == 4
+     if is_bmhk:
+         q = q.unsqueeze(2)
+         spec_k, spec_v = spec_k.unsqueeze(2), spec_v.unsqueeze(2)
+         cache_k, cache_v = cache_k.unsqueeze(2), cache_v.unsqueeze(2)
+
+     B, tree_size_q, G, H, D = q.shape
+     Bkv, Mk, G1, H1, D1 = cache_k.shape
+     tree_size_q1, tree_size_kv = spec_attn_bias.shape
+     if isinstance(
+         prefix_attn_bias,
+         (PagedBlockDiagonalPaddedKeysMask, PagedBlockDiagonalGappyKeysMask),
+     ):
+         assert Bkv == 1
+     else:
+         assert Bkv == B
+
+     assert H == H1 and D == D1 and G == G1
+     assert cache_k.shape == cache_v.shape
+     assert tree_size_q1 == tree_size_q <= tree_size_kv, (
+         f"{tree_size_q1=} {tree_size_q=} {tree_size_kv=}"
+     )
+     assert q.shape[2:] == spec_k.shape[2:] == spec_v.shape[2:], (
+         f"{q.shape=} {spec_k.shape=} {spec_v.shape=}"
+     )
+
+     spec_attn_bias = spec_attn_bias.expand(B, G, H, tree_size_q, tree_size_kv)
+
+     triton_splitk_op = SplitKAutotune if autotune else triton_splitk.FwOp
+
+     # TODO: improve this heuristic
+     if prefix_op is None:
+         prefix_op = select_prefix_op(
+             B, G, tree_size_kv, autotune, prefix_attn_bias, cache_k.dtype
+         )
+     if cache_k.dtype == torch.uint8:
+         assert quantized_kv_scales is not None
+         assert prefix_op is triton_splitk.FwOp
+         fp8_inp = triton_splitk.InputsFp8(
+             query=q.view(1, B * tree_size_q, G, H, D),
+             key=cache_k.view(1, Bkv * Mk, G, H, D).view(torch.int32),
+             value=cache_v.view(1, Bkv * Mk, G, H, D).view(torch.int32),
+             attn_bias=prefix_attn_bias,
+             k_fp8_scale_shift=quantized_kv_scales[0].view(1, Bkv * Mk, G, H),
+             v_fp8_scale_shift=quantized_kv_scales[1].view(1, Bkv * Mk, G, H),
+             is_partial=True,
+         )
+         out, ctx = fmha._memory_efficient_attention_forward_requires_grad(
+             fp8_inp,
+             op=prefix_op,
+         )
+         attn_prefix, lse_prefix = out, ctx.lse
+     elif cache_k.dtype == torch.float8_e4m3fn:
+         assert q_fp8 is not None
+         packed_q = pack_fp8_tensorwise_per_head(
+             q_fp8.view(1, B * tree_size_q, G, H, D),
+             q_fp8.scale,  # type: ignore
+             torch.bfloat16,
+         )
+         packed_k = pack_fp8_tensorwise_per_head(
+             cache_k.view(1, Bkv * Mk, G, H, D), cache_k.scale, torch.bfloat16
+         )
+         packed_v = pack_fp8_tensorwise_per_head(
+             cache_v.view(1, Bkv * Mk, G, H, D), cache_v.scale, torch.bfloat16
+         )
+         attn_prefix, lse_prefix = memory_efficient_attention_partial(
+             packed_q,
+             packed_k,
+             packed_v,
+             attn_bias=prefix_attn_bias,
+             op=fmha.flash3.FwOp_KVSplit,
+         )
+
+     else:
+         attn_prefix, lse_prefix = memory_efficient_attention_partial(
+             q.view(1, B * tree_size_q, G, H, D),
+             cache_k.view(1, Bkv * Mk, G, H, D),
+             cache_v.view(1, Bkv * Mk, G, H, D),
+             attn_bias=prefix_attn_bias,
+             op=prefix_op,
+         )
+     attn_prefix = attn_prefix.view(B, tree_size_q, G, H, D)
+     lse_prefix = lse_prefix.view(G, H, B, tree_size_q).permute(2, 0, 1, 3)
+
+     # attn_suffix ~ (B, tree_size_q, G, H, D)
+     # lse_suffix ~ (B, G, H, tree_size_q)
+     attn_suffix, lse_suffix = memory_efficient_attention_forward_requires_grad(
+         q,
+         spec_k,
+         spec_v,
+         attn_bias=spec_attn_bias,
+         op=suffix_op or triton_splitk_op,
+     )
+
+     # attn_output ~ [B, tree_size_q, G, H, D]
+     # attn input [B, M, G, H, Kq]
+     # lse input [B, G, H, M]
+     attn_output, _ = merge_attentions(
+         [attn_prefix, attn_suffix],
+         [lse_prefix, lse_suffix],
+         output_dtype=q.dtype,
+     )
+
+     if is_bmhk:
+         attn_output = attn_output.squeeze(2)
+     return attn_output
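
Filling in concrete shapes for the docstring example above, as a hedged sketch: the BlockDiagonalPaddedKeysMask import location is assumed, the sizes are chosen for illustration, and a CUDA device with bf16 support is required.

    import torch
    from mslk.attention.fmha.attn_bias import BlockDiagonalPaddedKeysMask  # assumed location
    from mslk.attention.fmha.tree_attention import (
        TreeAttnMetadata,
        construct_full_tree_choices,
        tree_attention,
    )

    B, G, H, D, Mk = 2, 1, 8, 128, 256
    device = torch.device("cuda")
    tree_choices = construct_full_tree_choices(tree_depth=2, branching=3)
    tree_size = len(tree_choices) + 1  # 12 explicit nodes + implicit root = 13

    meta = TreeAttnMetadata.from_tree_choices(tree_choices, torch.bfloat16, device)
    q = torch.randn(B, tree_size, G, H, D, device=device, dtype=torch.bfloat16)
    spec_k, spec_v = torch.randn_like(q), torch.randn_like(q)
    cache_k = torch.randn(B, Mk, G, H, D, device=device, dtype=torch.bfloat16)
    cache_v = torch.randn_like(cache_k)

    # Same call pattern as in the docstring example above.
    prefix_attn_bias = BlockDiagonalPaddedKeysMask.from_seqlens(
        q_seqlen=[tree_size] * B, kv_seqlen=[128, 200], kv_padding=Mk
    )
    out = tree_attention(
        q, spec_k, spec_v, cache_k, cache_v, meta.attention_bias, prefix_attn_bias
    )
    assert out.shape == (B, tree_size, G, H, D)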
+
+
+ class SplitKAutotune(triton_splitk.FwOp):
+     AUTOTUNE = True
+
+
+ @lru_cache
+ def construct_full_tree_choices(
+     tree_depth: int, branching: int
+ ) -> List[Tuple[int, ...]]:
+     """
+     Construct a full tree of a given depth where each node (except for leaves) has a given number of children.
+     The format is compatible with that used by Medusa and EAGLE:
+     https://github.com/FasterDecoding/Medusa/blob/5e98053/medusa/model/medusa_choices.py
+     For detailed description, see docstring of
+     xformers.ops.tree_attention.TreeAttnMetadata.from_tree_choices .
+     """
+     return construct_tree_choices(branching=[branching] * tree_depth)
+
+
+ def construct_tree_choices(
+     branching: List[int],
+ ) -> List[Tuple[int, ...]]:
+     """
+     Construct a tree based on given branching factor for each non-root level.
+     """
+     choices: List[Tuple[int, ...]] = []
+     for i in range(len(branching)):
+         choices.extend(itertools.product(*[range(branching[k]) for k in range(i + 1)]))
+     return choices
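
For example, with two levels of branching factor 2 the choices come out level by level:

    construct_tree_choices([2, 2])
    # -> [(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]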
+
+
+ def get_full_tree_size(tree_depth: int, branching: int) -> int:
+     """
+     Number of nodes in a full tree of a given depth (including the root node) and branching factor.
+     """
+     return sum(branching**i for i in range(tree_depth))
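
For example:

    get_full_tree_size(tree_depth=3, branching=2)  # 1 + 2 + 4 = 7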