mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,718 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
# pyre-strict
|
|
7
|
+
|
|
8
|
+
import itertools
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from functools import lru_cache
|
|
11
|
+
from typing import List, Optional, Sequence, Tuple, Type
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
|
|
15
|
+
from .. import fmha
|
|
16
|
+
from . import (
|
|
17
|
+
flash,
|
|
18
|
+
flash3,
|
|
19
|
+
memory_efficient_attention_forward_requires_grad,
|
|
20
|
+
memory_efficient_attention_partial,
|
|
21
|
+
merge_attentions,
|
|
22
|
+
triton_splitk,
|
|
23
|
+
)
|
|
24
|
+
from .attn_bias import (
|
|
25
|
+
AttentionBias,
|
|
26
|
+
PagedBlockDiagonalGappyKeysMask,
|
|
27
|
+
PagedBlockDiagonalPaddedKeysMask,
|
|
28
|
+
)
|
|
29
|
+
from .common import AttentionFwOpBase, pack_fp8_tensorwise_per_head
|
|
30
|
+
from .dispatch import _get_use_fa3, fa3_available
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class TreeAttnMetadata:
    """
    Tensors describing a Medusa/EAGLE-style speculative decoding tree.

    Node numbering used below: node 0 is the root, and node k >= 1 is
    tree_choices[k - 1]; tree_size = len(tree_choices) + 1.

    tree_choices: definition of the tree, tuples sorted by (length, value), each
        corresponding to a non-root node. See the docstring of
        TreeAttnMetadata.from_tree_choices.
    attention_bias: Medusa-style tree attention bias as an explicit tensor
        of shape (tree_size, tree_size), holding 0 where attention is allowed
        (the node itself, its ancestors and the root) and -inf elsewhere.
        It can be used as a spec_attn_bias ("right" or "suffix" attention part)
        in tree_attention. See tree_attention_with_sync for a usage example.
    tree_indices: 1D tensor of size tree_size which maps tree nodes to draft tokens.
        Tree nodes are assumed to be in the same order as in tree_choices
        (see TreeAttnMetadata.from_tree_choices).
    retrieval_indices: a tensor of shape (number of leaves, depth + 1), where one
        row corresponds to one path, and contains indices of the tree nodes
        on that path. Paths are padded with -1 from the right.
        The paths (row dim) are unsorted.
    path_lengths: real lengths for each of the paths.
    tree_seq_position_ids: 1D tensor of size tree_size which indicates which head
        a node belongs to. Equivalently, it shows the sequence position of the
        node within the corresponding path.
    parent_node_indices: 1D tensor with one entry per non-root node
        (size tree_size - 1) containing the node index of its parent, i.e. the
        position of the parent in tree_choices + 1. For children of the root it
        contains 0.
    child_node_indices: a tensor of shape (tree_size, max_num_children_per_node),
        in which row r contains the children of node r, given as positions in
        tree_choices (i.e. node index - 1).
        Rows corresponding to nodes which have less than max_num_children_per_node
        children are padded by repeating the last child index.
        For leaf nodes the values are meaningless and filled with 0.
    num_children_per_node: 1D tensor of size tree_size which contains the number of
        children for each node.
    candidate_idx: 1D tensor with one entry per non-root node (size tree_size - 1),
        containing the index of each node among its "siblings".
        Takes values from 0 to the number of children of the parent node minus 1.
    num_nodes_per_level: 1D tensor of the number of nodes at each level (including root).
    num_nodes_per_level_cpu: CPU copy of num_nodes_per_level.
    num_children_per_node_at_level: List of 1D tensors; the i-th one contains the
        number of children of every node on tree level i.
    num_children_per_node_at_level_cpu: CPU copies of num_children_per_node_at_level.
    subtree_sizes: List of integers; subtree_sizes[i] is the total number of nodes
        on levels 0..i.

    Example:
        Tree choices
        `[(0,), (0, 0), (0, 1), (0, 2), (1,), (1, 0), (1, 1), (1, 2), (2,), (2, 0), (2, 1), (2, 2)]`
        represents a tree that looks like this:
            0
            |-- 1
            |   |-- 4
            |   |-- 5
            |   |-- 6
            |
            |-- 2
            |   |-- 7
            |   |-- 8
            |   |-- 9
            |
            |-- 3
                |-- 10
                |-- 11
                |-- 12

        with TreeAttnMetadata
            tree_indices=tensor([0, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6])
            retrieval_indices=tensor([[ 0,  1,  5],
                                      [ 0,  2,  9],
                                      [ 0,  3, 11],
                                      [ 0,  1,  4],
                                      [ 0,  2,  8],
                                      [ 0,  3, 10],
                                      [ 0,  1,  6],
                                      [ 0,  2,  7],
                                      [ 0,  3, 12]])
            path_lengths=[3, 3, 3, 3, 3, 3, 3, 3, 3]
            tree_seq_position_ids=tensor([0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2])
            parent_node_indices=tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
            child_node_indices=tensor([[ 0,  1,  2],
                                       [ 3,  4,  5],
                                       [ 6,  7,  8],
                                       [ 9, 10, 11],
                                       [ 0,  0,  0],
                                       ...
                                       [ 0,  0,  0]])
            num_children_per_node=tensor([3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0])
            candidate_idx=tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])
            num_nodes_per_level=tensor([1, 3, 9])
            num_children_per_node_at_level=[tensor([3]), tensor([3, 3, 3]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0])]
            subtree_sizes=[1, 4, 13]
    """

    tree_choices: Sequence[Tuple[int, ...]]
    attention_bias: torch.Tensor
    tree_indices: torch.Tensor
    retrieval_indices: torch.Tensor
    path_lengths: List[int]
    tree_seq_position_ids: torch.Tensor
    parent_node_indices: torch.Tensor
    child_node_indices: torch.Tensor
    num_children_per_node: torch.Tensor
    candidate_idx: torch.Tensor
    num_nodes_per_level: torch.Tensor
    num_nodes_per_level_cpu: torch.Tensor
    num_children_per_node_at_level: List[torch.Tensor]
    num_children_per_node_at_level_cpu: List[torch.Tensor]
    subtree_sizes: List[int]

    @classmethod
    @lru_cache
    def from_tree_choices_cached(
        cls,
        tree_choices: Tuple[Tuple[int, ...], ...],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ) -> "TreeAttnMetadata":
        """
        Memoized version of from_tree_choices.

        All arguments are used as cache keys and must be hashable, so tree_choices
        has to be a tuple of tuples. Calls with equal arguments may return the same
        (shared) object, so callers should not modify its tensors in place.
        """
        return cls.from_tree_choices(tree_choices, dtype, device)

    @classmethod
    def from_tree_choices(
        cls,
        tree_choices: Sequence[Tuple[int, ...]],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ) -> "TreeAttnMetadata":
        """
        Build TreeAttnMetadata from a tree description.

        Args:
            tree_choices: tree description in the style of
                https://github.com/FasterDecoding/Medusa/blob/5e9805386/medusa/model/medusa_choices.py
                A typical tree description would look like:
                [(node0, node1, ...), (node0, node2), (node0, node3), (node1, node3), ..., (node0, node2, ..., nodeN)]
                Every tuple corresponds to one node in the tree, encoded as a path from one of the root nodes to the
                node in question. The tuples may be given in any order; they are sorted by (length, value) here.
                For example, a node encoded as (1, 0, 3, ..., 2) is understood as:
                list all the root nodes and take node number 1
                list all children of that node and take node number 0
                list all children of that node and take node number 3
                ...
                list all children of that node and take node number 2 - that's the node encoded by this tuple.

            dtype: data type of the output mask tensor.
            device: device of the output tensors.
        Returns:
            TreeAttnMetadata object (of type cls) with all the fields.
        """
        # from https://github.com/SafeAILab/EAGLE/blob/e98fc7c/model/utils.py#L89C1-L117C1
        sorted_tree_choices = sorted(tree_choices, key=lambda x: (len(x), x))
        # +1 for the root node
        tree_len = len(sorted_tree_choices) + 1

        depth_counts = _get_depth_counts(sorted_tree_choices)
        tree_indices = _prepare_tree_indices(sorted_tree_choices, depth_counts, device)
        retrieval_indices, path_lengths = _prepare_retrieval_indices(
            sorted_tree_choices, device
        )
        tree_seq_position_ids = _prepare_tree_position_ids(
            depth_counts, tree_len, device
        )
        tree_attn_mask = _prepare_tree_attn_bias(
            sorted_tree_choices, depth_counts, dtype, device
        )
        parent_node_indices = _prepare_parent_node_indices(sorted_tree_choices, device)
        child_node_indices, num_children_per_node = _prepare_child_node_indices(
            sorted_tree_choices, device
        )
        candidate_idx = _prepare_candidate_idx(sorted_tree_choices, device)

        num_nodes_per_level = _get_num_nodes_per_level(depth_counts, device)
        num_nodes_per_level_cpu = num_nodes_per_level.cpu()
        (
            subtree_sizes,
            num_children_per_node_at_level,
        ) = _get_subtree_size_and_num_children_per_node_at_level(
            num_nodes_per_level, num_children_per_node, device
        )
        num_children_per_node_at_level_cpu = [
            row.cpu() for row in num_children_per_node_at_level
        ]
        # Use cls (not the hard-coded class) so subclasses get their own type back.
        return cls(
            tree_choices=sorted_tree_choices,
            attention_bias=tree_attn_mask,
            tree_indices=tree_indices,
            retrieval_indices=retrieval_indices,
            path_lengths=path_lengths,
            tree_seq_position_ids=tree_seq_position_ids,
            parent_node_indices=parent_node_indices,
            child_node_indices=child_node_indices,
            num_children_per_node=num_children_per_node,
            candidate_idx=candidate_idx,
            num_nodes_per_level=num_nodes_per_level,
            num_nodes_per_level_cpu=num_nodes_per_level_cpu,
            num_children_per_node_at_level=num_children_per_node_at_level,
            num_children_per_node_at_level_cpu=num_children_per_node_at_level_cpu,
            subtree_sizes=subtree_sizes,
        )
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _get_subtree_size_and_num_children_per_node_at_level(
    num_nodes_per_level: torch.Tensor,
    num_children_per_node: torch.Tensor,
    device: Optional[torch.device] = None,
) -> Tuple[List[int], List[torch.Tensor]]:
    """
    Split the per-node child counts into one chunk per tree level.

    Args:
        num_nodes_per_level: 1D tensor with the number of nodes on each level,
            root level first.
        num_children_per_node: 1D tensor with the number of children of every node,
            nodes laid out level by level.
        device: unused; kept for interface compatibility.
    Returns:
        subtree_sizes: subtree_sizes[i] is the total number of nodes on levels
            0..i (subtree_sizes[0] is always 1, the root).
        num_children_per_node_at_level: the i-th entry is the slice of
            num_children_per_node covering the nodes of level i.
    """
    # One cumulative sum and a single host transfer, instead of re-summing the
    # prefix (and syncing with the device) once per level.
    cumulative = torch.cumsum(num_nodes_per_level, dim=0).tolist()
    subtree_sizes: List[int] = [1] + [int(size) for size in cumulative[1:]]

    num_children_per_node_at_level: List[torch.Tensor] = [
        num_children_per_node[0].unsqueeze(0)
    ]
    for level in range(1, len(subtree_sizes)):
        num_children_per_node_at_level.append(
            num_children_per_node[subtree_sizes[level - 1] : subtree_sizes[level]]
        )
    return subtree_sizes, num_children_per_node_at_level
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _get_depth_counts(sorted_tree_choices: List[Tuple[int, ...]]) -> List[int]:
    """
    Count how many tree choices sit at each depth.

    Args:
        sorted_tree_choices: tree choices sorted by length
            (see TreeAttnMetadata.from_tree_choices).
    Returns:
        A list whose (d - 1)-th element is the number of choices of length d.
        Assumes every length from 1 up to the maximum occurs at least once.
    """
    counts: List[int] = []
    last_len = 0
    for choice in sorted_tree_choices:
        cur_len = len(choice)
        # A change of length means we entered the next level: open a counter for it.
        if cur_len != last_len:
            counts.append(0)
            last_len = cur_len
        counts[cur_len - 1] += 1
    return counts
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def _get_num_nodes_per_level(
    depth_counts: List[int], device: Optional[torch.device]
) -> torch.Tensor:
    """
    Number of nodes on each tree level, starting with the single root node.

    Args:
        depth_counts: number of non-root nodes at each depth.
        device: device of the output tensor.
    Returns:
        1D tensor ``[1] + depth_counts`` with zero entries removed.
    """
    per_level = torch.tensor([1] + depth_counts, device=device)
    non_empty = per_level != 0
    return torch.masked_select(per_level, non_empty)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def _prepare_tree_attn_bias(
    sorted_tree_choices: List[Tuple[int, ...]],
    depth_counts: List[int],
    dtype: Optional[torch.dtype],
    device: Optional[torch.device],
) -> torch.Tensor:
    """
    Construct a Medusa-style tree attention bias as an explicit tensor.
    It can be used as a spec_attn_bias ("right" or "suffix" attention part)
    in tree_attention. See run_tree_attention_inner in test for a usage example.

    Every node may attend to itself, to the root (column 0) and to all of its
    ancestors; every other position is masked with -inf.

    Args:
        sorted_tree_choices: tree description in the style of
            https://github.com/FasterDecoding/Medusa/blob/5e9805386/medusa/model/medusa_choices.py
            A typical tree description would look like:
            [(node0, node1, ...), (node0, node2), (node0, node3), (node1, node3), ..., (node0, node2, ..., nodeN)]
            Every tuple corresponds to one node in the tree, encoded as a path from one of the root nodes to the
            node in question. Passed in sorted order.
            For example, a node encoded as (1, 0, 3, ..., 2) is understood as:
            list all the root nodes and take node number 1
            list all children of that node and take node number 0
            list all children of that node and take node number 3
            ...
            list all children of that node and take node number 2 - that's the node encoded by this tuple
        depth_counts: a list of integers, where depth_counts[i] is the number of
            choices of length i + 1.
        dtype: data type of the output tensor (torch's default dtype when None).
        device: device of the output tensor.
    Returns:
        attention bias of shape (tree_size, tree_size), where tree_size is the total
        number of nodes in the tree (including the root); 0 marks allowed positions
        and -inf masked ones.
    Raises:
        ValueError: if an ancestor of some node is missing from sorted_tree_choices.
    """
    # +1 comes from the additional root node
    tree_len = len(sorted_tree_choices) + 1
    num_nodes = sum(depth_counts)

    # Position of every choice in the sorted list. Keep the first occurrence so the
    # result matches list.index semantics for (unexpected) duplicate choices.
    position: dict[Tuple[int, ...], int] = {}
    for idx, choice in enumerate(sorted_tree_choices):
        position.setdefault(choice, idx)

    # Collect all (node, ancestor) pairs on the host so the mask can be filled with
    # a single indexed write instead of one device operation per node.
    rows: List[int] = []
    cols: List[int] = []
    for row, choice in enumerate(sorted_tree_choices[:num_nodes], start=1):
        # Proper prefixes of the path are the ancestors (depth-1 nodes have none).
        for prefix_len in range(1, len(choice)):
            prefix = choice[:prefix_len]
            if prefix not in position:
                raise ValueError(
                    f"Tree choice {choice} has ancestor {prefix} "
                    "which is missing from tree_choices"
                )
            rows.append(row)
            cols.append(position[prefix] + 1)  # +1 skips the root

    mask_val = 0
    tree_attn_mask = torch.full(
        (tree_len, tree_len), -torch.inf, device=device, dtype=dtype
    )
    tree_attn_mask.fill_diagonal_(mask_val)  # every node sees itself
    tree_attn_mask[:, 0] = mask_val  # every node sees the root
    if rows:
        tree_attn_mask[
            torch.tensor(rows, dtype=torch.long, device=device),
            torch.tensor(cols, dtype=torch.long, device=device),
        ] = mask_val
    return tree_attn_mask
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def _prepare_tree_indices(
    sorted_tree_choices: List[Tuple[int, ...]],
    depth_counts: List[int],
    device: Optional[torch.device],
) -> torch.Tensor:
    """
    Construct an index tensor mapping choices in the tree to their index in the draft tokens.

    Args:
        sorted_tree_choices: tree choices sorted by (length, value), as produced in
            TreeAttnMetadata.from_tree_choices.
        depth_counts: a list of integers, where depth_counts[i] is the number of
            choices of length i + 1.
        device: device of the output tensor.
    Returns:
        tree indices of shape (tree_len,). See docstring of TreeAttnMetadata for details.
    """
    # Generate tree indices for the tree_choices structure.
    # The root node (main head prediction) is added to tree_len and maps to token 0.
    tree_len = len(sorted_tree_choices) + 1
    indices: List[int] = [0] * tree_len
    start, max_idx_prev_level = 0, 0
    for count in depth_counts:
        # Draft tokens of this level are numbered after those of previous levels.
        cur_offset = max_idx_prev_level
        for j in range(count):
            tree_idx = sorted_tree_choices[start + j][-1] + cur_offset + 1
            indices[start + j + 1] = tree_idx
            max_idx_prev_level = max(tree_idx, max_idx_prev_level)
        start += count
    # Build on the host and copy once, rather than one device write per node.
    return torch.tensor(indices, dtype=torch.long, device=device)
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def _prepare_retrieval_indices(
    tree_choices: List[Tuple[int, ...]], device: Optional[torch.device]
) -> Tuple[torch.Tensor, List[int]]:
    """
    Turn a Medusa/EAGLE tree definition into root-to-leaf paths of node indices.

    tree_choices follows the format described in TreeAttnMetadata.from_tree_choices.
    Each returned row lists the tree-node indices (0 = root, k = tree_choices[k - 1])
    along one root-to-leaf path:
    [
        (node_index0_path0, node_index1_path0, ...),
        (node_index0_path1, node_index1_path1, ...),
        ...
    ]
    Returns:
        retrieval indices of shape (number of leaves, depth + 1), right-padded with -1;
        the real length of each path.
    """
    if tree_choices:
        tree_depth = max(len(node) for node in tree_choices) + 1
    else:
        tree_depth = 1

    # Every node starts as a leaf candidate; anything that is some node's parent
    # is struck out.
    leaves = set(tree_choices)
    for node in reversed(tree_choices):
        leaves.discard(node[:-1])

    padded_paths: List[List[int]] = []
    path_lengths: List[int] = []
    for leaf in leaves:
        path = [0]
        for prefix_len in range(1, len(leaf) + 1):
            path.append(tree_choices.index(leaf[:prefix_len]) + 1)
        path_lengths.append(len(path))
        padded_paths.append(path + [-1] * (tree_depth - len(path)))
    return torch.tensor(padded_paths, dtype=torch.long, device=device), path_lengths
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def _prepare_tree_position_ids(
    depth_counts: List[int], tree_len: int, device: Optional[torch.device]
) -> torch.Tensor:
    """
    Sequence position of each node within its path; usable for positional embeddings.

    The root gets position 0 and every node on level d (d >= 1) gets position d.

    Args:
        depth_counts: number of nodes at each of the levels of the tree.
        tree_len: total number of nodes in the tree including the root.
        device: device of the output tensor.
    Returns:
        int32 tree position ids of shape (tree_len,). See docstring of TreeAttnMetadata for details.
    """
    position_ids = torch.zeros(tree_len, dtype=torch.int32, device=device)
    # Node 0 is the root, so the first level starts right after it.
    first = 1
    for level, count in enumerate(depth_counts, start=1):
        position_ids[first : first + count] = level
        first += count
    return position_ids
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def _prepare_parent_node_indices(
    sorted_tree_choices: List[Tuple[int, ...]], device: Optional[torch.device]
) -> torch.Tensor:
    """
    For each choice, the tree-node index of its parent (position in the list + 1).

    Choices whose parent prefix is not in the list (depth-1 nodes, i.e. children of
    the root) get 0.

    Args:
        sorted_tree_choices: non-root nodes of the tree.
        device: device of the output tensor.
    Returns:
        1D long tensor with one entry per element of sorted_tree_choices.
    """

    def _parent_slot(choice: Tuple[int, ...]) -> int:
        try:
            return sorted_tree_choices.index(choice[:-1]) + 1
        except ValueError:
            return 0

    parents = [_parent_slot(choice) for choice in sorted_tree_choices]
    return torch.tensor(parents, dtype=torch.long, device=device)
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def _prepare_child_node_indices(
    tree_choices: List[Tuple[int, ...]], device: Optional[torch.device]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    List the children of every tree node (root first, then tree_choices in order).

    Args:
        tree_choices: non-root nodes of the tree (see TreeAttnMetadata.from_tree_choices).
        device: device of the output tensors.
    Returns:
        child_node_indices: long tensor of shape (len(tree_choices) + 1, max_num_children).
            Row 0 describes the root and row k describes tree_choices[k - 1]. Each row
            holds the children as positions in tree_choices, padded by repeating the
            last child; rows of leaf nodes are filled with 0.
        num_children_per_node: long tensor of shape (len(tree_choices) + 1,) with the
            real number of children of each node.
    """
    # Group child positions by parent prefix in a single pass, instead of rescanning
    # every choice for every node. The empty tuple has no parent and is skipped.
    children_of: dict[Tuple[int, ...], List[int]] = {}
    for idx, choice in enumerate(tree_choices):
        if choice:
            children_of.setdefault(choice[:-1], []).append(idx)

    rows: List[List[int]] = []
    num_children_per_node: List[int] = []
    for node in [()] + list(tree_choices):
        children = children_of.get(node, [])
        num_children_per_node.append(len(children))
        rows.append(children if children else [0])

    # Pad all children lists by repeating their last element so the result is
    # rectangular (this also makes every row the same length by construction).
    max_num_children = max(len(row) for row in rows)
    rows = [row + row[-1:] * (max_num_children - len(row)) for row in rows]

    return (
        torch.tensor(rows, dtype=torch.long, device=device),
        torch.tensor(num_children_per_node, dtype=torch.long, device=device),
    )
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def _prepare_candidate_idx(
    tree_choices: List[Tuple[int, ...]], device: Optional[torch.device]
) -> torch.Tensor:
    """
    Index of each node among its siblings.

    Args:
        tree_choices: non-root nodes of the tree.
        device: device of the output tensor.
    Returns:
        1D long tensor of length len(tree_choices); entry k is the number of earlier
        entries of tree_choices that share the parent of tree_choices[k].
    """
    # Running count of already-seen children per parent prefix: one pass instead of
    # rescanning all previous nodes for every node.
    seen_per_parent: dict[Tuple[int, ...], int] = {}
    candidate_idx: List[int] = []
    for node in tree_choices:
        parent = node[:-1]
        rank = seen_per_parent.get(parent, 0)
        candidate_idx.append(rank)
        seen_per_parent[parent] = rank + 1
    return torch.tensor(candidate_idx, dtype=torch.long, device=device)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def use_triton_splitk_for_prefix(B: int, G: int, tree_size: int) -> bool:
    """
    Heuristic to decide whether to use Triton Split-k or default (Flash Attention) for prefix attention.

    Args:
        B, G: sizes of the B and G dimensions of the query (see tree_attention).
        tree_size: number of nodes in the speculative tree.
    Returns:
        True when Triton Split-k should be used.
    """
    num_batch_groups = B * G
    if num_batch_groups < 2:
        return True
    if num_batch_groups < 4:
        return tree_size < 100
    if num_batch_groups <= 128:
        return tree_size <= 64
    return False
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def select_prefix_op(
    B: int,
    G: int,
    tree_size: int,
    autotune: bool,
    attn_bias: AttentionBias,
    kv_cache_dtype: torch.dtype,
) -> Optional[Type[AttentionFwOpBase]]:
    """
    Choose the attention operator for the prefix (KV-cache) part of tree attention.

    Args:
        B, G: sizes of the B and G dimensions of the query (see tree_attention).
        tree_size: number of nodes in the speculative tree.
        autotune: when Triton Split-k is selected, return SplitKAutotune instead
            of triton_splitk.FwOp.
        attn_bias: prefix attention bias; operators whose SUPPORTED_ATTN_BIAS_TYPES
            do not include its type are not selected.
        kv_cache_dtype: dtype of the KV cache; torch.uint8 (quantized cache) forces
            Triton Split-k on CUDA.
    Returns:
        The operator class to use, or None to let the default dispatcher decide.
    """
    triton_splitk_op = SplitKAutotune if autotune else triton_splitk.FwOp
    if torch.version.hip:
        # TODO: further tune heuristics once CK splitK is ready
        return triton_splitk_op

    # Triton Split-k is not present in the dispatcher list for some shapes.
    # However, we need to dispatch to it if no other op is available.
    # FA3 splitKV doesn't yet support gappy or paged biases.
    fa3_splitkv_supported = isinstance(
        attn_bias,
        flash3.FwOp_KVSplit.SUPPORTED_ATTN_BIAS_TYPES,  # type: ignore
    )
    fa3_supported = isinstance(attn_bias, flash3.FwOp.SUPPORTED_ATTN_BIAS_TYPES)  # type: ignore
    flash2_supported = isinstance(attn_bias, flash.FwOp.SUPPORTED_ATTN_BIAS_TYPES)  # type: ignore
    if not (fa3_splitkv_supported or fa3_supported or flash2_supported):
        return triton_splitk_op

    assert torch.version.cuda
    use_fa3 = _get_use_fa3() and fa3_available()
    # override heuristics for quantized kv cache for decode
    if kv_cache_dtype == torch.uint8:
        return triton_splitk_op
    # select FA3 when bs >= 64, but only a variant that supports this bias type;
    # if only FA2 supports it, fall through to the remaining heuristics instead of
    # returning an FA3 op that would reject the bias.
    if B >= 64 and use_fa3:
        if fa3_splitkv_supported:
            return flash3.FwOp_KVSplit
        if fa3_supported:
            return flash3.FwOp
    if use_triton_splitk_for_prefix(B, G, tree_size):
        return triton_splitk_op
    # use default heuristics from xformers
    return None
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def tree_attention(
    q: torch.Tensor,
    spec_k: torch.Tensor,
    spec_v: torch.Tensor,
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    spec_attn_bias: torch.Tensor,
    prefix_attn_bias: AttentionBias,
    prefix_op: Optional[Type[AttentionFwOpBase]] = None,
    suffix_op: Optional[Type[AttentionFwOpBase]] = None,
    autotune: bool = False,
    quantized_kv_scales: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    q_fp8: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute Medusa/EAGLE/Hydra-style tree attention.
    Notice that this function takes as arguments biases for the left (prefix)
    and right (speculative suffix) parts of the attention.
    This way we avoid creating these biases on the fly, and
    allow this function to be used in performance-critical decoding
    jobs, including in CUDA graph mode. In the latter case one should
    construct the biases once, and update prefix_attn_bias with
    current seqlens before every graph replay; spec_attn_bias stays static,
    as it's determined by the tree structure.
    Args:
        q: query from speculative tokens, of shape (B, tree_size_q, (G), H, D)
        spec_k, spec_v: keys/values from speculative tokens, each of shape (B, tree_size_kv, (G), H, D).
            If tree_size_q < tree_size_kv, we assume the end of the query sequence aligns with the end of the k/v sequence,
            like in "from-bottom-right" attention masks. Such rectangular attention masks can be used when we are
            adding new nodes to the tree, and want to avoid recomputing attention for the existing nodes. For example,
            this can be used during draft token generation in EAGLE.
        cache_k/cache_v: keys/values from the existing context, each of shape (B, Mk, (G), H, D).
            When prefix_attn_bias is a paged bias (PagedBlockDiagonalPaddedKeysMask or
            PagedBlockDiagonalGappyKeysMask), the leading (batch) dimension must be 1 instead of B.
        spec_attn_bias: attention bias of the "right" part of the attention, of shape (tree_size_q, tree_size_kv).
            This would typically be an explicit tensor mask, precomputed once and not changing during decoding.
        prefix_attn_bias: attention bias of the "left" part of the attention (tree_size_q x existing context).
            This bias would typically be block-diagonal padded non-causal (BlockDiagonalPaddedKeysMask), and it
            changes at every decoding step as K/V sequence lengths grow during decoding.
        prefix_op: attention backend which will be passed to memory_efficient_attention to compute prefix attention.
            If None, will use Triton Split-K or Flash Attention depending on the heuristics.
            For a torch.uint8 cache it must be triton_splitk.FwOp or a subclass of it (e.g. SplitKAutotune);
            for a torch.float8_e4m3fn cache it is not used (FA3 KV-split is always used).
        suffix_op: same as prefix_op, but for the suffix.
        autotune: If True, Triton Split-K will use autotuning when chosen
            as a default backend for prefix/suffix attention.
        quantized_kv_scales: required when cache_k/cache_v have dtype torch.uint8: a pair of
            (k_scale_shift, v_scale_shift) tensors, each viewable as (1, Bkv * Mk, G, H).
        q_fp8: required when cache_k/cache_v have dtype torch.float8_e4m3fn: an FP8 copy of q.
            It must expose a ``.scale`` attribute, as must cache_k and cache_v in that mode.
    Returns:
        attention output of shape (B, tree_size_q, (G), H, D)

    :Usage example:

    See also tree_attention_with_sync in tests/test_tree_attention.py

    .. code-block:: python

        # Create an attention bias for the prefix part of the attention
        prefix_attn_bias = BlockDiagonalPaddedKeysMask.from_seqlens(
            q_seqlen=[tree_size_q for _ in range(B)], kv_seqlen=kv_lens, kv_padding=Mk
        )
        # Create an explicit attention bias for the speculative part of the attention
        spec_attn_bias = TreeAttnMetadata.from_tree_choices(tree_choices, q.dtype, q.device).attention_bias
        attn_output = tree_attention(
            q, spec_k, spec_v, cache_k, cache_v, spec_attn_bias, prefix_attn_bias
        )
    """

    is_bmhk = q.ndim == 4
    if is_bmhk:
        # Insert a singleton G axis so the code below can always unpack 5 dims.
        q = q.unsqueeze(2)
        spec_k, spec_v = spec_k.unsqueeze(2), spec_v.unsqueeze(2)
        cache_k, cache_v = cache_k.unsqueeze(2), cache_v.unsqueeze(2)

    B, tree_size_q, G, H, D = q.shape
    Bkv, Mk, G1, H1, D1 = cache_k.shape
    tree_size_q1, tree_size_kv = spec_attn_bias.shape
    if isinstance(
        prefix_attn_bias,
        (PagedBlockDiagonalPaddedKeysMask, PagedBlockDiagonalGappyKeysMask),
    ):
        # Paged biases expect the K/V cache with a batch dimension of 1.
        assert Bkv == 1
    else:
        assert Bkv == B

    assert H == H1 and D == D1 and G == G1
    assert cache_k.shape == cache_v.shape
    assert tree_size_q1 == tree_size_q <= tree_size_kv, (
        f"{tree_size_q1=} {tree_size_q=} {tree_size_kv=}"
    )
    assert q.shape[2:] == spec_k.shape[2:] == spec_v.shape[2:], (
        f"{q.shape=} {spec_k.shape=} {spec_v.shape=}"
    )

    spec_attn_bias = spec_attn_bias.expand(B, G, H, tree_size_q, tree_size_kv)

    triton_splitk_op = SplitKAutotune if autotune else triton_splitk.FwOp

    # TODO: improve this heuristic
    if prefix_op is None:
        prefix_op = select_prefix_op(
            B, G, tree_size_kv, autotune, prefix_attn_bias, cache_k.dtype
        )
    if cache_k.dtype == torch.uint8:
        assert quantized_kv_scales is not None
        # SplitKAutotune (selected when autotune=True) subclasses
        # triton_splitk.FwOp, so accept subclasses, not only the exact class.
        assert isinstance(prefix_op, type) and issubclass(
            prefix_op, triton_splitk.FwOp
        )
        fp8_inp = triton_splitk.InputsFp8(
            query=q.view(1, B * tree_size_q, G, H, D),
            key=cache_k.view(1, Bkv * Mk, G, H, D).view(torch.int32),
            value=cache_v.view(1, Bkv * Mk, G, H, D).view(torch.int32),
            attn_bias=prefix_attn_bias,
            k_fp8_scale_shift=quantized_kv_scales[0].view(1, Bkv * Mk, G, H),
            v_fp8_scale_shift=quantized_kv_scales[1].view(1, Bkv * Mk, G, H),
            is_partial=True,
        )
        out, ctx = fmha._memory_efficient_attention_forward_requires_grad(
            fp8_inp,
            op=prefix_op,
        )
        attn_prefix, lse_prefix = out, ctx.lse
    elif cache_k.dtype == torch.float8_e4m3fn:
        # FP8 cache: always FA3 KV-split; prefix_op is not used on this path.
        assert q_fp8 is not None
        packed_q = pack_fp8_tensorwise_per_head(
            q_fp8.view(1, B * tree_size_q, G, H, D),
            q_fp8.scale,  # type: ignore
            torch.bfloat16,
        )
        packed_k = pack_fp8_tensorwise_per_head(
            cache_k.view(1, Bkv * Mk, G, H, D), cache_k.scale, torch.bfloat16  # type: ignore
        )
        packed_v = pack_fp8_tensorwise_per_head(
            cache_v.view(1, Bkv * Mk, G, H, D), cache_v.scale, torch.bfloat16  # type: ignore
        )
        attn_prefix, lse_prefix = memory_efficient_attention_partial(
            packed_q,
            packed_k,
            packed_v,
            attn_bias=prefix_attn_bias,
            op=fmha.flash3.FwOp_KVSplit,
        )

    else:
        attn_prefix, lse_prefix = memory_efficient_attention_partial(
            q.view(1, B * tree_size_q, G, H, D),
            cache_k.view(1, Bkv * Mk, G, H, D),
            cache_v.view(1, Bkv * Mk, G, H, D),
            attn_bias=prefix_attn_bias,
            op=prefix_op,
        )
    # The prefix was computed on the batch flattened into one sequence; restore
    # the per-batch layout so it can be merged with the suffix.
    attn_prefix = attn_prefix.view(B, tree_size_q, G, H, D)
    lse_prefix = lse_prefix.view(G, H, B, tree_size_q).permute(2, 0, 1, 3)

    # attn_suffix ~ (B, tree_size_q, G, H, D)
    # lse_suffix ~ (B, G, H, tree_size_q)
    attn_suffix, lse_suffix = memory_efficient_attention_forward_requires_grad(
        q,
        spec_k,
        spec_v,
        attn_bias=spec_attn_bias,
        op=suffix_op or triton_splitk_op,
    )

    # attn_output ~ [B, tree_size_q, G, H, D]
    # attn input [B, M, G, H, Kq]
    # lse input [B, G, H, M]
    attn_output, _ = merge_attentions(
        [attn_prefix, attn_suffix],
        [lse_prefix, lse_suffix],
        output_dtype=q.dtype,
    )

    if is_bmhk:
        attn_output = attn_output.squeeze(2)
    return attn_output
|
|
682
|
+
|
|
683
|
+
|
|
684
|
+
class SplitKAutotune(triton_splitk.FwOp):
    """
    Triton Split-K forward op with autotuning enabled.

    Same as ``triton_splitk.FwOp`` except that the ``AUTOTUNE`` class
    attribute is overridden to ``True``. ``select_prefix_op`` and
    ``tree_attention`` use it in place of ``triton_splitk.FwOp`` when they
    are called with ``autotune=True``.
    """

    AUTOTUNE = True
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
@lru_cache
def _full_tree_choices_cached(
    tree_depth: int, branching: int
) -> Tuple[Tuple[int, ...], ...]:
    # Immutable memoized backing store for construct_full_tree_choices, so the
    # cached value cannot be corrupted by callers mutating a returned list.
    return tuple(construct_tree_choices(branching=[branching] * tree_depth))


def construct_full_tree_choices(
    tree_depth: int, branching: int
) -> List[Tuple[int, ...]]:
    """
    Construct a full tree of a given depth where each node (except for leaves) has a given number of children.
    The format is compatible with that used by Medusa and EAGLE:
    https://github.com/FasterDecoding/Medusa/blob/5e98053/medusa/model/medusa_choices.py
    For detailed description, see docstring of
    xformers.ops.tree_attention.TreeAttnMetadata.from_tree_choices .

    Here ``tree_depth`` is the number of non-root levels. The root itself is
    not listed, so the result has ``branching + branching**2 + ... +
    branching**tree_depth`` entries.

    The result is memoized, but each call returns a fresh list. Callers may
    therefore mutate it without affecting later calls.
    """
    return list(_full_tree_choices_cached(tree_depth, branching))


# Keep the lru_cache introspection API the decorated function used to expose.
construct_full_tree_choices.cache_info = _full_tree_choices_cached.cache_info  # type: ignore[attr-defined]
construct_full_tree_choices.cache_clear = _full_tree_choices_cached.cache_clear  # type: ignore[attr-defined]
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
def construct_tree_choices(
    branching: List[int],
) -> List[Tuple[int, ...]]:
    """
    Construct a tree based on given branching factor for each non-root level.

    ``branching[k]`` is the number of children of every node at depth ``k``
    (the root is at depth 0). Each non-root node is returned as the tuple of
    child indices along its path from the root. Nodes are ordered level by
    level and lexicographically within a level; the root is not included.

    Example: ``construct_tree_choices([2, 1])`` returns
    ``[(0,), (1,), (0, 0), (1, 0)]``.
    """
    per_level_ranges = [range(width) for width in branching]
    return [
        path
        for depth in range(1, len(per_level_ranges) + 1)
        for path in itertools.product(*per_level_ranges[:depth])
    ]
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
def get_full_tree_size(tree_depth: int, branching: int) -> int:
    """
    Number of nodes in a full tree of a given depth (including the root node) and branching factor.

    Here ``tree_depth`` counts levels *including* the root, so the result is
    ``1 + branching + ... + branching**(tree_depth - 1)``. This differs from
    ``construct_full_tree_choices``, whose ``tree_depth`` counts only non-root
    levels: ``get_full_tree_size(d + 1, b) == len(construct_full_tree_choices(d, b)) + 1``.
    """
    total = 0
    for level in range(tree_depth):
        total += branching**level
    return total
|