quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/mlp.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from quack.linear import linear_act_func, act_linear_func
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def mlp_func(x, weight1, weight2, activation: str, fuse_grad_accum=False, tuned=True):
    """Run a bias-free two-layer MLP through quack's fused linear/activation kernels.

    ``linear_act_func`` applies ``weight1`` fused with ``activation`` and returns a
    (pre-activation, post-activation) pair. Both tensors are then handed, together
    with ``weight2``, to ``act_linear_func``, which produces the final projection.

    Args:
        x: Input tensor.
        weight1: Weight of the first projection.
        weight2: Weight of the second projection.
        activation: Activation name forwarded to both fused kernels.
        fuse_grad_accum: Forwarded unchanged to both kernels.
        tuned: Forwarded unchanged to both kernels.

    Returns:
        Whatever ``act_linear_func`` returns for the second projection.
    """
    # The pre-activation is only requested while autograd is recording
    # (presumably because only the backward pass of the activation needs it).
    keep_preact = torch.is_grad_enabled()
    shared_kwargs = dict(fuse_grad_accum=fuse_grad_accum, tuned=tuned)
    pre, post = linear_act_func(
        x, weight1, activation, store_preact=keep_preact, **shared_kwargs
    )
    return act_linear_func(pre, weight2, post, activation=activation, **shared_kwargs)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class MLP(nn.Module):
    """Two-layer MLP computing ``fc2(act(fc1(x)))``.

    ``forward`` dispatches to the fused quack kernels (``mlp_func``) when:
    - both layers are bias-free;
    - the input is a CUDA tensor with unit stride in its last dimension;
    - the relevant feature sizes are multiples of 8 (presumably an alignment
      requirement of those kernels).

    Otherwise it falls back to plain PyTorch ops that apply the same activation.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer (default ``4 * in_features``).
        out_features: Size of the output features (default ``in_features``).
        bias1: Whether ``fc1`` has a bias.
        bias2: Whether ``fc2`` has a bias.
        activation: Name of the activation applied between the two layers.
        device: Device for the parameters.
        dtype: Dtype for the parameters.
        fuse_grad_accum: Forwarded to ``mlp_func``.
        tuned: Forwarded to ``mlp_func``.
    """

    # PyTorch reference implementations used by the fallback path in ``forward``,
    # keyed by activation name. NOTE(review): keep in sync with the activations
    # supported by the fused kernels.
    _TORCH_ACTIVATIONS = {
        None: lambda t: t,
        "relu": F.relu,
        "relu_sq": lambda t: F.relu(t).square(),
        "gelu": F.gelu,
        "gelu_tanh_approx": lambda t: F.gelu(t, approximate="tanh"),
        "silu": F.silu,
    }

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        bias1=False,
        bias2=False,
        activation="gelu",
        device=None,
        dtype=None,
        fuse_grad_accum: bool = False,
        tuned: bool = True,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else 4 * in_features
        self.activation = activation
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
        self.fuse_grad_accum = fuse_grad_accum
        self.tuned = tuned

    def forward(self, input: Tensor) -> Tensor:
        """Apply the MLP to ``input`` (features in the last dimension).

        Raises:
            ValueError: On the PyTorch fallback path, if ``self.activation`` has no
                entry in ``_TORCH_ACTIVATIONS``.
        """
        if (
            self.fc1.bias is None
            and self.fc2.bias is None
            and input.is_cuda
            and input.stride(-1) == 1
            and self.fc1.in_features % 8 == 0
            and self.fc1.out_features % 8 == 0
            and self.fc2.out_features % 8 == 0
        ):
            return mlp_func(
                input,
                self.fc1.weight,
                self.fc2.weight,
                activation=self.activation,
                fuse_grad_accum=self.fuse_grad_accum,
                tuned=self.tuned,
            )
        # Fallback: apply the configured activation between the two layers. fc1
        # emits hidden_features values and fc2 consumes exactly that many, so the
        # activation must be elementwise (not a gated half-width reduction).
        try:
            act_fn = self._TORCH_ACTIVATIONS[self.activation]
        except KeyError:
            raise ValueError(
                f"Unsupported activation for the PyTorch fallback path: {self.activation!r}"
            ) from None
        return self.fc2(act_fn(self.fc1(input)))
|
quack/pipeline.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
import cutlass.cute as cute
|
|
7
|
+
from cutlass.cutlass_dsl import Boolean, Int32, if_generate
|
|
8
|
+
from cutlass.pipeline import CooperativeGroup, PipelineOp, pipeline_init_wait
|
|
9
|
+
from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PipelineStateWAdvance(PipelineState):
    """``PipelineState`` that can also be advanced by many iterations in one call."""

    def advance_iters(self, num_iterations: Int32):
        """Advance the state by ``num_iterations`` pipeline iterations at once.

        The count grows by ``num_iterations`` and the stage index wraps modulo
        ``self.stages``. The phase bit flips once for every wrap around the stage
        ring, so only the parity of the number of wraps matters.
        """
        self._count += Int32(num_iterations)
        new_index = self._index + Int32(num_iterations)
        # How many times did we cross the stages boundary
        num_crossings = new_index // self.stages
        # The phase is a single parity bit. XOR-ing in the full crossing count would
        # push _phase outside {0, 1} after two or more wraps; only its low bit flips.
        self._phase ^= num_crossings & 1
        self._index = new_index % self.stages

    # This can be overridden by derived classes.
    # Rebuilds the state from its three MLIR values so the subclass (and its
    # advance_iters) survives DSL value round-trips. NOTE(review): values are
    # presumably (count, index, phase), matching PipelineState's constructor order.
    def __new_from_mlir_values__(self, values):
        return PipelineStateWAdvance(
            self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2])
        )
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def make_pipeline_state(type: PipelineUserType, stages: int):
    """
    Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.

    Both producer and consumer states start with zero count and zero index; only
    the initial phase differs (1 for producers, 0 for consumers).

    :param type: Which side of the pipeline the state is for.
    :type type: PipelineUserType
    :param stages: Number of pipeline stages.
    :type stages: int
    :raises ValueError: If ``type`` is neither ``Producer`` nor ``Consumer``.
    """
    if type is PipelineUserType.Producer:
        initial_phase = 1
    elif type is PipelineUserType.Consumer:
        initial_phase = 0
    else:
        # An explicit raise: an ``assert`` is stripped under ``python -O`` and the
        # function would then silently return None.
        raise ValueError("Error: invalid PipelineUserType specified for make_pipeline_state.")
    return PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(initial_phase))
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass(frozen=True)
class PipelineTmaCpAsync(PipelineTmaAsync):
    """
    PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers.

    Differences from ``PipelineTmaAsync`` visible in this class:
    - ``producer_acquire`` takes an ``is_tma_warp`` flag, so several producer warps
      can wait on the empty barrier while only one of them arrives on the full
      barrier.
    - ``producer_commit`` arrives on the full barrier with
      ``cp_async_mbarrier_arrive_noinc``, so the barrier also tracks the completion
      of cp.async copies.
    """

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
        consumer_group: CooperativeGroup,
        tx_count: int,
        barrier_storage: cute.Pointer = None,
        cta_layout_vmnk: Optional[cute.Layout] = None,
        tidx: Optional[Int32] = None,
    ):
        """
        Compute the pipeline's barriers and attributes and return a ``PipelineTmaCpAsync``.

        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
        :type barrier_storage: cute.Pointer
        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: int
        :param producer_group: CooperativeGroup for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group: CooperativeGroup for the consumer agent
        :type consumer_group: CooperativeGroup
        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
        :type tx_count: int
        :param cta_layout_vmnk: Layout of the cluster shape (defaults to a single CTA)
        :type cta_layout_vmnk: cute.Layout | None
        :param tidx: Thread index used to pick the empty-barrier signalling thread
            (defaults to the current thread's x index)
        :type tidx: Int32 | None
        :raises ValueError: If ``barrier_storage`` is not a ``cute.Pointer``.
        """
        if not isinstance(barrier_storage, cute.Pointer):
            raise ValueError(
                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
            )

        producer_type = PipelineOp.TmaLoad
        consumer_type = PipelineOp.AsyncThread

        producer = (producer_type, producer_group)
        consumer = (consumer_type, consumer_group)

        # "Full" barriers occupy the first num_stages mbarrier slots of barrier_storage,
        # "empty" barriers the following num_stages slots.
        sync_object_full = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8), num_stages, producer, tx_count
        )
        sync_object_empty = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
        )
        if tidx is None:
            tidx, _, _ = cute.arch.thread_idx()
        if cta_layout_vmnk is None:
            cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
        (
            dst_rank,
            is_signalling_thread,
        ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
        # cta_layout_vmnk was defaulted above, so it is never None here. With a
        # single CTA there is no destination rank to signal.
        if cute.size(cta_layout_vmnk) == 1:
            dst_rank = None

        producer_mask = None

        pipeline_init_wait(cta_layout_vmnk)

        return PipelineTmaCpAsync(
            sync_object_full,
            sync_object_empty,
            num_stages,
            producer_mask,
            dst_rank,
            is_signalling_thread,
        )

    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        is_tma_warp: Optional[Boolean] = True,
    ):
        """
        TMA producer acquire: conditionally waits on buffer empty and sets the transaction barrier.

        The empty-barrier wait is skipped only when ``try_acquire_token`` is given and
        non-zero. The arrive on the full barrier happens only when ``is_tma_warp`` is true.
        """
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase),
        )
        # This is the difference between this and PipelineTmaAsync: we could have multiple
        # warps calling this, but only 1 warp should do the arrive on the full barrier
        if_generate(
            is_tma_warp,
            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
        )

    def producer_commit(self, state: PipelineState):
        """
        Arrive on this stage's full barrier via ``cp_async_mbarrier_arrive_noinc``.

        We need the mbarrier to track the completion of cp.async: per the PTX
        ``cp.async.mbarrier.arrive`` semantics, the arrive takes effect once the
        executing thread's prior cp.async operations complete.
        """
        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))
|
quack/reduce.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import operator
|
|
5
|
+
from typing import Callable, Optional
|
|
6
|
+
|
|
7
|
+
import cutlass
|
|
8
|
+
import cutlass.cute as cute
|
|
9
|
+
from cutlass import Float32
|
|
10
|
+
|
|
11
|
+
import quack.utils as utils
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@cute.jit
def warp_reduce(
    val: cute.TensorSSA | cute.Numeric,
    op: Callable,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
    """Reduce ``val`` with ``op`` across groups of ``width`` lanes using butterfly shuffles.

    For a scalar, performs ``int(math.log2(width))`` rounds of ``shuffle_sync_bfly``
    with offsets 1, 2, 4, ..., combining each lane's value with its partner's via
    ``op``. For a ``TensorSSA``, every element is reduced independently, and a
    ``TensorSSA`` of the same shape is returned.

    NOTE(review): ``width`` is assumed to be a power of two; ``int(math.log2(width))``
    silently truncates otherwise.
    """
    if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
        # Copy the SSA value into a fragment so elements can be reduced one at a time.
        res = cute.make_fragment(val.shape, val.dtype)
        res.store(val)
        for i in cutlass.range_constexpr(cute.size(val.shape)):
            res[i] = warp_reduce(res[i], op, width)
        return res.load()
    else:
        # Unrolled at trace time: one shuffle + combine per power-of-two offset.
        for i in cutlass.range_constexpr(int(math.log2(width))):
            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
        return val
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@cute.jit
def block_reduce(
    val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
) -> cute.Numeric:
    """Reduce ``val`` with ``op`` across the warps that share a row.

    reduction_buffer has shape (num_warps / warps_per_row, warps_per_row).

    Warp ``w`` maps to row ``w // warps_per_row`` and column ``w % warps_per_row``.
    1. Lane 0 of every warp writes ``val`` to its slot; ``val`` is presumably already
       warp-reduced, so any lane's copy is representative.
    2. A block-wide barrier follows.
    3. Each warp reloads its row, one entry per lane, with lanes beyond
       ``warps_per_row`` contributing ``init_val``.
    4. The row is reduced with ``warp_reduce``.

    ``init_val`` should be the identity of ``op``. NOTE(review): only the first
    32 entries of a row are read, so this assumes ``warps_per_row`` <= 32.
    """
    lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
    warps_per_row = cute.size(reduction_buffer.shape[1])
    row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
    if lane_idx == 0:
        reduction_buffer[row_idx, col_idx] = val
    # Make every warp's slot visible before anyone reads the row back.
    cute.arch.barrier()
    block_reduce_val = init_val
    if lane_idx < warps_per_row:
        block_reduce_val = reduction_buffer[row_idx, lane_idx]
    return warp_reduce(block_reduce_val, op)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@cute.jit
def cluster_reduce(
    val: cute.Numeric,
    op: Callable,
    reduction_buffer: cute.Tensor,
    mbar_ptr: cute.Pointer,
    init_val: cute.Numeric = 0.0,
    phase: Optional[cutlass.Int32] = None,
) -> cute.Numeric:
    """Reduce ``val`` with ``op`` across a row's warps and across every CTA in the cluster.

    reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n)).

    1. One elected thread of warp 0 arms ``mbar_ptr`` to expect
       ``num_warps * cluster_n`` elements' worth of bytes.
    2. In every warp, lane ``i`` (``i < cluster_n``) stores ``val`` into CTA ``i``'s
       buffer at ``(row_idx, (col_idx, cta_rank_in_cluster))``. ``mbar_ptr`` is passed
       along, so the stores are presumably tracked by each peer's mbarrier.
    3. After waiting on ``mbar_ptr`` with ``phase`` (default 0), each warp folds its
       row of the buffer, strided by the warp size.
    4. The per-lane partials are combined with ``warp_reduce``.

    ``init_val`` should be the identity of ``op``: lanes without a buffer entry
    contribute it. NOTE(review): if the same mbarrier is reused across calls, callers
    presumably need to alternate ``phase`` — confirm at the call sites.
    """
    cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
    lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
    rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
    row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
    if warp_idx == 0:
        with cute.arch.elect_one():
            num_warps = rows_per_block * warps_per_row
            # Expected bytes: one element from every warp of every CTA in the cluster.
            cute.arch.mbarrier_arrive_and_expect_tx(
                mbar_ptr,
                num_warps * cluster_n * reduction_buffer.element_type.width // 8,
            )
    if lane_idx < cluster_n:
        # Lane i sends this warp's value to CTA i's copy of the buffer.
        utils.store_shared_remote(
            val,
            utils.elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
            mbar_ptr,
            peer_cta_rank_in_cluster=lane_idx,
        )
    cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
    block_reduce_val = init_val
    # Mode 1 holds warps_per_row * cluster_n entries; cover it in warp-sized strides.
    num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
    for i in cutlass.range_constexpr(num_iter):
        idx = lane_idx + i * cute.arch.WARP_SIZE
        if idx < cute.size(reduction_buffer, mode=[1]):
            block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
    return warp_reduce(block_reduce_val, op)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@cute.jit
def block_or_cluster_reduce(
    val: cute.Numeric,
    op: Callable,
    reduction_buffer: cute.Tensor,
    mbar_ptr: Optional[cute.Pointer],
    phase: Optional[cutlass.Int32] = None,
    init_val: cute.Numeric = 0.0,
) -> cute.Numeric:
    """Perform either block or cluster reduction based on whether mbar_ptr is provided.

    The choice is made at trace time (``const_expr``).
    - ``mbar_ptr is None``: ``block_reduce`` is used and ``phase`` is ignored.
    - Otherwise: ``cluster_reduce`` is used. ``reduction_buffer`` must then have the
      cluster layout ``(rows, (warps_per_row, cluster_n))``.
    """
    if cutlass.const_expr(mbar_ptr is None):
        return block_reduce(val, op, reduction_buffer, init_val=init_val)
    else:
        return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@cute.jit
def row_reduce(
    x: cute.TensorSSA | cute.Numeric,
    op: cute.ReductionOp,
    threads_per_row: cutlass.Constexpr[int],
    reduction_buffer: Optional[cute.Tensor] = None,
    mbar_ptr: Optional[cute.Pointer] = None,
    phase: Optional[cutlass.Int32] = None,
    init_val: cute.Numeric = 0.0,
    hook_fn: Optional[Callable] = None,
) -> cute.Numeric:
    """Reduce a row with ``op`` across the thread, warp, block and (optionally) cluster levels.

    reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n)).

    1. A ``TensorSSA`` ``x`` is first reduced locally within the thread.
    2. The result is combined across ``min(threads_per_row, WARP_SIZE)`` lanes.
    3. ``hook_fn`` (if given) is called.
    4. If ``reduction_buffer`` is given and a row spans several warps or CTAs, the
       partials are merged with ``block_or_cluster_reduce``. ``mbar_ptr`` is required
       when ``cluster_n > 1``.

    ``init_val`` should be the identity of ``op``.
    """
    if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
        val = x.reduce(op, init_val=init_val, reduction_profile=0)
    else:
        val = x
    # Map the cute reduction op to a binary scalar op usable with warp shuffles.
    # NOTE(review): MAX uses cute.arch.fmax only for Float32; the builtin max/min are
    # used otherwise and for MIN — confirm they trace correctly on dynamic DSL values.
    # Also requires ``x`` to expose ``.dtype`` even when it is a scalar.
    warp_op = {
        cute.ReductionOp.ADD: operator.add,
        cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
        cute.ReductionOp.MIN: min,
        cute.ReductionOp.MUL: operator.mul,
    }[op]
    val = warp_reduce(
        val,
        warp_op,
        width=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    if cutlass.const_expr(hook_fn is not None):
        hook_fn()
    if cutlass.const_expr(reduction_buffer is not None):
        warps_per_row, cluster_n = reduction_buffer.shape[1]
        assert cluster_n == 1 or mbar_ptr is not None, (
            "mbar_ptr must be provided for cluster reduction"
        )
        # Only go through shared memory when a row actually spans >1 warp or >1 CTA.
        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
            val = block_or_cluster_reduce(
                val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
            )
    return val
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@cute.jit
def online_softmax_reduce(
    x: cute.TensorSSA,
    threads_per_row: cutlass.Constexpr[int],
    reduction_buffer: Optional[cute.Tensor] = None,
    mbar_ptr: Optional[cute.Pointer] = None,
    hook_fn: Optional[Callable] = None,
    phase: Optional[cutlass.Int32] = None,
    return_exp_x: bool = False,
) -> tuple[Float32, Float32, Optional[cute.TensorSSA]]:
    """Compute the row max and the row sum of ``exp(x - max)`` in one online pass.

    reduction_buffer, when given, must have shape
    (num_warps / warps_per_row, (warps_per_row, cluster_n)) with element type Int64.
    Each entry packs one warp's (max, sum_exp) pair of Float32 values via
    ``utils.f32x2_to_i64``.

    Steps:
    1. Each thread reduces its slice of ``x``.
    2. Lanes are combined with ``warp_reduce`` over ``min(threads_per_row, WARP_SIZE)``
       lanes. This yields a per-warp max and the sum of ``exp(x - max)`` relative
       to that max.
    3. ``hook_fn`` (if given) is called.
    4. If a row spans several warps or CTAs, the packed (max, sum) pairs are
       exchanged through ``reduction_buffer``. A block barrier is used when
       ``mbar_ptr`` is None; otherwise an mbarrier plus remote smem stores.
    5. The pairs are merged, rescaling every partial sum by
       ``exp(partial_max - final_max)``.

    Returns:
        ``(max_x, sum_exp_x, exp_x)``. ``exp_x`` is ``exp(x - max_x)`` with respect
        to the final max when ``return_exp_x`` is true, otherwise ``None``.
    """
    assert x.dtype == Float32, "x must be of type Float32"
    max_x = warp_reduce(
        x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
        cute.arch.fmax,
        width=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    log2_e = math.log2(math.e)
    # exp(x - max_x) expressed with exp2; mathematically equal to
    # exp2f((x - max_x) * log2_e).
    exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
    sum_exp_x = warp_reduce(
        exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
        operator.add,
        width=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    if cutlass.const_expr(hook_fn is not None):
        hook_fn()
    if cutlass.const_expr(reduction_buffer is not None):
        rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
        assert cluster_n == 1 or mbar_ptr is not None, (
            "mbar_ptr must be provided for cluster reduction"
        )
        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
            assert reduction_buffer.element_type == cutlass.Int64, (
                "reduction_buffer must be of type cute.Int64"
            )
            lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
            row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
            if cutlass.const_expr(mbar_ptr is None):
                # Block-level merge: publish each warp's packed (max, sum) pair.
                if lane_idx == 0:
                    reduction_buffer[row_idx, col_idx] = utils.f32x2_to_i64(max_x, sum_exp_x)
                cute.arch.barrier()
                max_x_single_warp = -Float32.inf
                sum_exp_x = 0.0
                if lane_idx < warps_per_row:
                    max_x_single_warp, sum_exp_x = utils.i64_to_f32x2(
                        reduction_buffer[row_idx, lane_idx]
                    )
                max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
                # Rescale each warp's partial sum from its own max to the final max.
                sum_exp_x *= utils.exp2f((max_x_single_warp - max_x_final) * log2_e)
                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
                if cutlass.const_expr(return_exp_x):
                    exp_x *= utils.exp2f((max_x - max_x_final) * log2_e)
                max_x = max_x_final
            else:
                # Cluster-level merge: every warp sends its packed pair to every CTA.
                cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
                if warp_idx == 0:
                    with cute.arch.elect_one():
                        num_warps = rows_per_block * warps_per_row
                        cute.arch.mbarrier_arrive_and_expect_tx(
                            mbar_ptr,
                            num_warps * cluster_n * reduction_buffer.element_type.width // 8,
                        )
                if lane_idx < cluster_n:
                    utils.store_shared_remote(
                        utils.f32x2_to_i64(max_x, sum_exp_x),
                        utils.elem_pointer(
                            reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))
                        ),
                        mbar_ptr,
                        peer_cta_rank_in_cluster=lane_idx,
                    )
                cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
                num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
                max_x_single_warp = cute.make_fragment(num_iter, Float32)
                max_x_single_warp.fill(-Float32.inf)
                sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
                sum_exp_x_single_warp.fill(0.0)
                for i in cutlass.range_constexpr(num_iter):
                    idx = lane_idx + i * cute.arch.WARP_SIZE
                    if idx < cute.size(reduction_buffer, mode=[1]):
                        max_x_single_warp[i], sum_exp_x_single_warp[i] = utils.i64_to_f32x2(
                            reduction_buffer[row_idx, idx]
                        )
                max_x_final = max_x_single_warp.load().reduce(
                    cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
                )
                max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
                sum_exp_x = 0.0
                for i in cutlass.range_constexpr(num_iter):
                    sum_exp_x += sum_exp_x_single_warp[i] * utils.exp2f(
                        (max_x_single_warp[i] - max_x_final) * log2_e
                    )
                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
                if cutlass.const_expr(return_exp_x):
                    exp_x *= utils.exp2f((max_x - max_x_final) * log2_e)
                max_x = max_x_final
    return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
|
quack/reduction_base.py
CHANGED
|
@@ -1,19 +1,11 @@
|
|
|
1
1
|
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
2
|
|
|
3
|
-
import torch
|
|
4
3
|
from typing import Type, Tuple, Optional
|
|
5
4
|
|
|
6
5
|
import cutlass
|
|
7
6
|
import cutlass.cute as cute
|
|
8
7
|
|
|
9
8
|
|
|
10
|
-
torch2cute_dtype_map = {
|
|
11
|
-
torch.float16: cutlass.Float16,
|
|
12
|
-
torch.bfloat16: cutlass.BFloat16,
|
|
13
|
-
torch.float32: cutlass.Float32,
|
|
14
|
-
}
|
|
15
|
-
|
|
16
|
-
|
|
17
9
|
class ReductionBase:
|
|
18
10
|
def __init__(
|
|
19
11
|
self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
|
|
@@ -32,9 +24,8 @@ class ReductionBase:
|
|
|
32
24
|
def _get_num_threads(self):
|
|
33
25
|
return 128 if self.N <= 16384 else 256
|
|
34
26
|
|
|
35
|
-
def _get_tv_layout(self):
|
|
36
|
-
|
|
37
|
-
vecsize = copy_bits // self.dtype.width
|
|
27
|
+
def _get_tv_layout(self, num_copy_bits=128):
|
|
28
|
+
vecsize = num_copy_bits // self.dtype.width
|
|
38
29
|
assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
|
|
39
30
|
num_threads = self._get_num_threads()
|
|
40
31
|
assert num_threads % cute.arch.WARP_SIZE == 0
|