quack-kernels 0.2.2-py3-none-any.whl → 0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/linear.py
CHANGED
@@ -61,10 +61,11 @@ class LinearFunc(torch.autograd.Function):
     # Use classmethod instead of staticmethod to allow inheritance
     @classmethod
     @custom_fwd(device_type="cuda")
-    def forward(cls, ctx, x, weight, fuse_grad_accum=False):
+    def forward(cls, ctx, x, weight, bias=None, fuse_grad_accum=False):
         """
         x: (..., in_features)
         weight: (out_features, in_features)
+        bias: (out_features,) or None
         out: (..., out_features)
         """
         ctx.weight_dtype = weight.dtype
@@ -74,8 +75,9 @@ class LinearFunc(torch.autograd.Function):
         batch_shape = x.shape[:-1]
         x = x.reshape(-1, x.shape[-1])
         # out = F.linear(x, weight)
-        out = cls.matmul_fwd_fn(x, weight.T)
+        out = cls.matmul_fwd_fn(x, weight.T, bias=bias)
         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
+        ctx.bias_dtype = bias.dtype if bias is not None else None
         return out.reshape(*batch_shape, out.shape[-1])

     @classmethod
@@ -87,13 +89,18 @@ class LinearFunc(torch.autograd.Function):
         x, weight, weight_og = ctx.saved_tensors  # weight_og is None if not ctx.fuse_grad_accum
         batch_shape = dout.shape[:-1]
         dout = dout.reshape(-1, dout.shape[-1])
+        dbias = (
+            dout.sum(0, dtype=ctx.bias_dtype)
+            if ctx.bias_dtype is not None and ctx.needs_input_grad[2]
+            else None
+        )
         dx = linear_bwd_compute_input_grad(ctx, dout, weight, cls.matmul_bwd_dx)
         dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
         dweight = linear_bwd_compute_weight_grad(
             ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
         )
         # return extra Nones for other classes that inherit from LinearFunc
-        return dx, dweight, *([None] * 10)
+        return dx, dweight, dbias, *([None] * 10)


 class LinearUntunedFunc(LinearFunc):
@@ -104,9 +111,9 @@ class LinearUntunedFunc(LinearFunc):
     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)


-def linear_func(x, weight, fuse_grad_accum=False, tuned=True):
+def linear_func(x, weight, bias=None, fuse_grad_accum=False, tuned=True):
     fn_cls = LinearFunc if tuned else LinearUntunedFunc
-    return fn_cls.apply(x, weight, fuse_grad_accum)
+    return fn_cls.apply(x, weight, bias, fuse_grad_accum)


 class LinearActFunc(LinearFunc):
@@ -115,10 +122,13 @@ class LinearActFunc(LinearFunc):
     # Use classmethod instead of staticmethod to allow inheritance
     @classmethod
     @custom_fwd(device_type="cuda")
-    def forward(cls, ctx, x, weight, activation, store_preact=True, fuse_grad_accum=False):
+    def forward(
+        cls, ctx, x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False
+    ):
         """
         x: (..., in_features)
         weight: (out_features, in_features)
+        bias: (out_features,) or None
         out: (..., out_features)
         Return both out and post-activation, but only out is differentiable.
         """
@@ -129,11 +139,12 @@ class LinearActFunc(LinearFunc):
         batch_shape = x.shape[:-1]
         x = x.reshape(-1, x.shape[-1])
         out, postact = cls.matmul_fwd_fn(
-            x, weight.T, activation=activation, store_preact=store_preact
+            x, weight.T, bias=bias, activation=activation, store_preact=store_preact
         )
         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
         if out is not None:
             out = out.reshape(*batch_shape, out.shape[-1])
+        ctx.bias_dtype = bias.dtype if bias is not None else None
         ctx.mark_non_differentiable(postact)
         ctx.set_materialize_grads(False)  # We don't want to materialize grads for postact
         return out, postact.reshape(*batch_shape, postact.shape[-1])
@@ -147,9 +158,11 @@ class LinearActUntunedFunc(LinearActFunc):
     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)


-def linear_act_func(x, weight, activation, store_preact=True, fuse_grad_accum=False, tuned=True):
+def linear_act_func(
+    x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False, tuned=True
+):
     fn_cls = LinearActFunc if tuned else LinearActUntunedFunc
-    return fn_cls.apply(x, weight, activation, store_preact, fuse_grad_accum)
+    return fn_cls.apply(x, weight, activation, bias, store_preact, fuse_grad_accum)


 class DActLinearFunc(LinearFunc):
@@ -229,12 +242,7 @@ class Linear(nn.Linear):
         self.fuse_grad_accum = fuse_grad_accum

     def forward(self, input: Tensor) -> Tensor:
-        if (
-            self.bias is None
-            and input.is_cuda
-            and self.in_features % 8 == 0
-            and self.out_features % 8 == 0
-        ):
-            return linear_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
+        if input.is_cuda and self.in_features % 8 == 0 and self.out_features % 8 == 0:
+            return linear_func(input, self.weight, self.bias, fuse_grad_accum=self.fuse_grad_accum)
         else:
             return F.linear(input, self.weight, self.bias)
quack/pipeline.py
CHANGED
@@ -4,9 +4,11 @@ from typing import Optional
 from dataclasses import dataclass

 import cutlass.cute as cute
-from cutlass
-from cutlass.
+from cutlass import Boolean, Int32, const_expr
+from cutlass.cutlass_dsl import if_generate, and_
+from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait
 from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
+from cutlass.pipeline import PipelineTmaUmma


 class PipelineStateWAdvance(PipelineState):
@@ -144,7 +146,160 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
             lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
         )

-    def
+    def producer_cpasync_commit(self, state: PipelineState):
+        """
+        We need the mbarrier to track the completion of cp.async
+        """
+        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))
+
+
+class MbarrierArrayWDropCount(MbarrierArray):
+    def __init__(
+        self,
+        barrier_storage: cute.Pointer,
+        num_stages: int,
+        agent: tuple[PipelineOp, CooperativeGroup],
+        tx_count: int = 0,
+        drop_count: Optional[Int32] = None,
+    ) -> None:
+        self.barrier_storage = barrier_storage
+        self.tx_count = tx_count
+        self.num_stages = num_stages
+        self.op_type, self.cg = agent
+        self.arrive_count = self.cg.size
+        self.drop_count = drop_count
+
+        if self.num_stages <= 0:
+            raise ValueError("Error: Mbarrier stage count must be greater than 0.")
+        if self.arrive_count <= 0:
+            raise ValueError("Error: Mbarrier arrive count must be greater than 0.")
+        if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0:
+            raise ValueError("Error: Mbarrier tx count must not be less than 0 for TMA ops.")
+
+        if const_expr(drop_count is not None):
+            self.arrive_count = self.arrive_count - drop_count
+
+        # Store mbarrier base pointer
+        self.mbarrier_base = self.barrier_storage
+
+        # Mbarrier initialization in constructor
+        self.mbarrier_init()
+
+    def __extract_mlir_values__(self):
+        return [self.barrier_storage, self.drop_count]
+
+    def __new_from_mlir_values__(self, values):
+        return MbarrierArrayWDropCount(
+            values[0], self.num_stages, (self.op_type, self.cg), self.tx_count, values[1]
+        )
+
+
+@dataclass(frozen=True)
+class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
+    """
+    PipelineTmaCpAsync is used for CpAsync + TMA producers and UMMA consumers
+    (e.g. Blackwell mainloops)
+    """
+
+    @staticmethod
+    def create(
+        *,
+        num_stages: int,
+        producer_group: CooperativeGroup,
+        consumer_group: CooperativeGroup,
+        tx_count: int,
+        barrier_storage: cute.Pointer = None,
+        cta_layout_vmnk: Optional[cute.Layout] = None,
+        producer_drop_count: Optional[Int32] = None,
+    ):
+        """
+        This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
+        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
+        :type barrier_storage: cute.Pointer
+        :param num_stages: Number of buffer stages for this pipeline
+        :type num_stages: Int32
+        :param producer_group: `CooperativeGroup` for the producer agent
+        :type producer_group: CooperativeGroup
+        :param consumer_group: `CooperativeGroup` for the consumer agent
+        :type consumer_group: CooperativeGroup
+        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
+        :type tx_count: int
+        :param cta_layout_vmnk: Layout of the cluster shape
+        :type cta_layout_vmnk: cute.Layout | None
+        """
+        if not isinstance(barrier_storage, cute.Pointer):
+            raise ValueError(
+                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
+            )
+
+        producer_type = PipelineOp.TmaLoad
+        consumer_type = PipelineOp.TCGen05Mma
+
+        producer = (producer_type, producer_group)
+        consumer = (consumer_type, consumer_group)
+
+        sync_object_full = MbarrierArrayWDropCount(
+            barrier_storage.align(min_align=8),
+            num_stages,
+            producer,
+            tx_count,
+            drop_count=producer_drop_count,
+        )
+        sync_object_empty = PipelineAsync._make_sync_object(
+            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
+        )
+
+        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
+            # No mcast mask if not using clusters
+            producer_mask = None
+            # All threadblocks are leaders if not using clusters
+            is_leader_cta = True
+        else:
+            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
+            is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)

+        cta_group = (
+            cute.nvgpu.tcgen05.CtaGroup.ONE
+            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
+            else cute.nvgpu.tcgen05.CtaGroup.TWO
+        )
+
+        consumer_mask = producer_mask
+
+        pipeline_init_wait(cta_layout_vmnk)
+
+        return PipelineTmaCpAsyncUmma(
+            sync_object_full,
+            sync_object_empty,
+            num_stages,
+            producer_mask,
+            consumer_mask,
+            is_leader_cta,
+            cta_group,
+        )
+
+    def producer_acquire(
+        self,
+        state: PipelineState,
+        try_acquire_token: Optional[Boolean] = None,
+        is_tma_warp: Optional[Boolean] = True,
+    ):
+        """
+        TMA producer commit conditionally waits on buffer empty and sets the
+        transaction barrier for leader threadblocks.
+        """
+        if_generate(
+            try_acquire_token is None or try_acquire_token == 0,
+            lambda: self.sync_object_empty.wait(state.index, state.phase),
+        )
+        # This is the difference between this and PipelineTmaAsync: we could have multiple
+        # warps calling this, but only 1 warp should do the arrive on the full barrier
+        if_generate(
+            and_(self.is_leader_cta, is_tma_warp),
+            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+        )
+
+    def producer_cpasync_commit(self, state: PipelineState):
         """
         We need the mbarrier to track the completion of cp.async
         """
quack/reduce.py
CHANGED
@@ -6,29 +6,11 @@ from typing import Callable, Optional

 import cutlass
 import cutlass.cute as cute
-from cutlass import Float32
+from cutlass import Int32, Int64, Float32, Boolean, const_expr

 import quack.utils as utils


-@cute.jit
-def warp_reduce(
-    val: cute.TensorSSA | cute.Numeric,
-    op: Callable,
-    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
-) -> cute.TensorSSA | cute.Numeric:
-    if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
-        res = cute.make_fragment(val.shape, val.dtype)
-        res.store(val)
-        for i in cutlass.range_constexpr(cute.size(val.shape)):
-            res[i] = warp_reduce(res[i], op, width)
-        return res.load()
-    else:
-        for i in cutlass.range_constexpr(int(math.log2(width))):
-            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
-        return val
-
-
 @cute.jit
 def block_reduce(
     val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
@@ -43,7 +25,7 @@ def block_reduce(
     block_reduce_val = init_val
     if lane_idx < warps_per_row:
         block_reduce_val = reduction_buffer[row_idx, lane_idx]
-    return warp_reduce(block_reduce_val, op)
+    return cute.arch.warp_reduction(block_reduce_val, op)


 @cute.jit
@@ -53,7 +35,7 @@ def cluster_reduce(
     reduction_buffer: cute.Tensor,
     mbar_ptr: cute.Pointer,
     init_val: cute.Numeric = 0.0,
-    phase: Optional[cutlass.Int32] = None,
+    phase: Optional[Int32] = None,
 ) -> cute.Numeric:
     """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
     cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
@@ -81,7 +63,7 @@ def cluster_reduce(
             idx = lane_idx + i * cute.arch.WARP_SIZE
             if idx < cute.size(reduction_buffer, mode=[1]):
                 block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
-    return warp_reduce(block_reduce_val, op)
+    return cute.arch.warp_reduction(block_reduce_val, op)


 @cute.jit
@@ -90,11 +72,11 @@ def block_or_cluster_reduce(
     op: Callable,
     reduction_buffer: cute.Tensor,
     mbar_ptr: Optional[cute.Pointer],
-    phase: Optional[cutlass.Int32] = None,
+    phase: Optional[Int32] = None,
     init_val: cute.Numeric = 0.0,
 ) -> cute.Numeric:
     """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
-    if cutlass.const_expr(mbar_ptr is None):
+    if const_expr(mbar_ptr is None):
         return block_reduce(val, op, reduction_buffer, init_val=init_val)
     else:
         return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
@@ -107,34 +89,34 @@ def row_reduce(
     threads_per_row: cutlass.Constexpr[int],
     reduction_buffer: Optional[cute.Tensor] = None,
     mbar_ptr: Optional[cute.Pointer] = None,
-    phase: Optional[cutlass.Int32] = None,
+    phase: Optional[Int32] = None,
     init_val: cute.Numeric = 0.0,
     hook_fn: Optional[Callable] = None,
 ) -> cute.Numeric:
     """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
-    if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
+    if const_expr(isinstance(x, cute.TensorSSA)):
         val = x.reduce(op, init_val=init_val, reduction_profile=0)
     else:
         val = x
     warp_op = {
         cute.ReductionOp.ADD: operator.add,
-        cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
+        cute.ReductionOp.MAX: cute.arch.fmax if const_expr(x.dtype == Float32) else max,
         cute.ReductionOp.MIN: min,
         cute.ReductionOp.MUL: operator.mul,
     }[op]
-    val = warp_reduce(
+    val = cute.arch.warp_reduction(
         val,
         warp_op,
-        width=min(threads_per_row, cute.arch.WARP_SIZE),
+        threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
     )
-    if cutlass.const_expr(hook_fn is not None):
+    if const_expr(hook_fn is not None):
         hook_fn()
-    if cutlass.const_expr(reduction_buffer is not None):
+    if const_expr(reduction_buffer is not None):
         warps_per_row, cluster_n = reduction_buffer.shape[1]
         assert cluster_n == 1 or mbar_ptr is not None, (
             "mbar_ptr must be provided for cluster reduction"
         )
-        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+        if const_expr(warps_per_row > 1 or cluster_n > 1):
            val = block_or_cluster_reduce(
                 val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
             )
@@ -148,37 +130,37 @@ def online_softmax_reduce(
     reduction_buffer: Optional[cute.Tensor] = None,
     mbar_ptr: Optional[cute.Pointer] = None,
     hook_fn: Optional[Callable] = None,
-    phase: Optional[cutlass.Int32] = None,
+    phase: Optional[Int32] = None,
     return_exp_x: bool = False,
 ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
     assert x.dtype == Float32, "x must be of type Float32"
     """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
-    max_x = warp_reduce(
+    max_x = cute.arch.warp_reduction(
         x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
         cute.arch.fmax,
-        width=min(threads_per_row, cute.arch.WARP_SIZE),
+        threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
     )
     log2_e = math.log2(math.e)
     exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
-    sum_exp_x = warp_reduce(
+    sum_exp_x = cute.arch.warp_reduction(
         exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
         operator.add,
-        width=min(threads_per_row, cute.arch.WARP_SIZE),
+        threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
     )
-    if cutlass.const_expr(hook_fn is not None):
+    if const_expr(hook_fn is not None):
         hook_fn()
-    if cutlass.const_expr(reduction_buffer is not None):
+    if const_expr(reduction_buffer is not None):
         rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
         assert cluster_n == 1 or mbar_ptr is not None, (
             "mbar_ptr must be provided for cluster reduction"
         )
-        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-            assert reduction_buffer.element_type == cutlass.Int64, (
+        if const_expr(warps_per_row > 1 or cluster_n > 1):
+            assert reduction_buffer.element_type == Int64, (
                 "reduction_buffer must be of type cute.Int64"
             )
             lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
             row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
-            if cutlass.const_expr(mbar_ptr is None):
+            if const_expr(mbar_ptr is None):
                 if lane_idx == 0:
                     reduction_buffer[row_idx, col_idx] = utils.f32x2_to_i64(max_x, sum_exp_x)
                 cute.arch.barrier()
@@ -188,10 +170,10 @@ def online_softmax_reduce(
                 max_x_single_warp, sum_exp_x = utils.i64_to_f32x2(
                     reduction_buffer[row_idx, lane_idx]
                 )
-                max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
+                max_x_final = cute.arch.warp_reduction(max_x_single_warp, cute.arch.fmax)
                 sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True)
-                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
-                if cutlass.const_expr(return_exp_x):
+                sum_exp_x = cute.arch.warp_reduction(sum_exp_x, operator.add)
+                if const_expr(return_exp_x):
                     exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
                 max_x = max_x_final
             else:
@@ -227,14 +209,71 @@ def online_softmax_reduce(
                 max_x_final = max_x_single_warp.load().reduce(
                     cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
                 )
-                max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
+                max_x_final = cute.arch.warp_reduction(max_x_final, cute.arch.fmax)
                 sum_exp_x = 0.0
                 for i in cutlass.range_constexpr(num_iter):
                     sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp(
                         max_x_single_warp[i] - max_x_final, fastmath=True
                     )
-                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
-                if cutlass.const_expr(return_exp_x):
+                sum_exp_x = cute.arch.warp_reduction(sum_exp_x, operator.add)
+                if const_expr(return_exp_x):
                     exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
                 max_x = max_x_final
-    return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
+    return max_x, sum_exp_x, (exp_x if const_expr(return_exp_x) else None)
+
+
+@cute.jit
+def sum_swap_shuffle(
+    X: cute.Tensor, elem_per_lane: int = 1, subwarp_size: int = 1, warp_size: int = 32
+) -> cute.Tensor:
+    """
+    For warp reduction, we use Swap Shuffle
+    The normal way to reduction among threads:
+    use shuffle to let *** the first half of threads *** have *** whole data *** from the second half of threads.
+    After each step of reduction, a half of threads won't work in the following steps.
+    That is, as the reduction progresses, the efficiency of shuffle & reduction instructions gradually change from 1/2, 1/4 to 1/32 (the worst case).
+    To overcome this shortcoming, for a NxN matrix to be reduced among N threads as a 1XN vectors,
+    we use swap & shuffle aiming to let *** each half of threads *** have *** a half of data *** from the other half of threads.
+    After reduction, each half of threads should deal with a (N/2)x(N/2) sub-matrix independently in the following step.
+    We can recursively do this until the problem size is 1.
+    """
+    assert (
+        subwarp_size >= 1
+        and subwarp_size <= 32
+        and subwarp_size == 1 << int(math.log2(subwarp_size))
+    )
+    assert (
+        warp_size <= 32
+        and warp_size % subwarp_size == 0
+        and warp_size == 1 << int(math.log2(warp_size))
+    )
+    lane_idx = cute.arch.lane_idx() // subwarp_size
+    X = cute.logical_divide(X, cute.make_layout(elem_per_lane))  # (elem_per_lane, M)
+    numvec = cute.size(X, mode=[1])
+    assert numvec <= 32 // subwarp_size
+    # If X has more values than warp_size // subwarp_size, we first do a normal warp reduction
+    # to sum up values held by lanes further than size(X) away
+    for i in cutlass.range(
+        int(math.log2(numvec)), int(math.log2(warp_size // subwarp_size)), unroll_full=True
+    ):
+        for v in cutlass.range(cute.size(X), unroll_full=True):
+            shfl_val = cute.arch.shuffle_sync_bfly(X[v], offset=(1 << i) * subwarp_size)
+            X[v] = X[v] + shfl_val
+    for logm in cutlass.range_constexpr(int(math.log2(cute.size(X, mode=[1]))) - 1, -1, -1):
+        m = 1 << logm
+        for r in cutlass.range(m, unroll_full=True):
+            frg_A = X[None, r]
+            frg_B = X[None, r + m]
+            # First half of threads swap fragments from the first half of data to the second
+            should_swap = not Boolean(lane_idx & m)
+            for v in cutlass.range(cute.size(frg_A), unroll_full=True):
+                # Step 1: swap
+                lower, upper = frg_A[v], frg_B[v]
+                frg_A[v] = upper if should_swap else lower
+                frg_B[v] = lower if should_swap else upper
+                # Step 2: shuffle
+                # each half of threads get a half of data from the other half of threads
+                shfl_val = cute.arch.shuffle_sync_bfly(frg_A[v], offset=m * subwarp_size)
+                # Step 3: reduction
+                frg_A[v] = frg_B[v] + shfl_val
    return X[None, 0]
quack/reduction_base.py
CHANGED
@@ -4,55 +4,44 @@ from typing import Type, Tuple, Optional

 import cutlass
 import cutlass.cute as cute
+from cutlass import Int32, Int64, Float32, const_expr
+
+import quack.copy_utils as copy_utils


 class ReductionBase:
-    def __init__(
-        self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
-    ):
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=Float32):
         self.dtype = dtype
         self.N = N
         self.stage = stage
         self.reduction_dtype = reduction_dtype

-    def _calculate_threads_per_row(self):
+    def _threads_per_row(self):
         raise NotImplementedError()

+    def _num_threads(self):
+        return 128 if self.N <= 16384 else 256
+
     def _set_cluster_n(self):
         self.cluster_n = 1

-    def _get_num_threads(self):
-        return 128 if self.N <= 16384 else 256
-
-    def _get_tv_layout(self, num_copy_bits=128):
-        vecsize = num_copy_bits // self.dtype.width
+    def _get_tiled_copy(self, vecsize: int = 1):
         assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
-        num_threads = self._get_num_threads()
+        threads_per_row = self._threads_per_row()
+        num_threads = self._num_threads()
         assert num_threads % cute.arch.WARP_SIZE == 0
-
-        threads_per_row = self._calculate_threads_per_row()
         num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
-        cols_per_block = num_threads // threads_per_row
-        tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
-        tv_layout = cute.make_layout(
-            ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-            stride=(
-                (vecsize * cols_per_block, 1),
-                (cols_per_block, cols_per_block * vecsize * threads_per_row),
-            ),
-        )
-        return tiler_mn, tv_layout
-
-    def _smem_size_in_bytes(self, tiler_mn, num_warps):
-        return (
-            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
-            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
-            + self.stage * (cutlass.Int64.width // 8)
-        )
+        tiler_mn = (num_threads // threads_per_row, vecsize * num_blocks_N * threads_per_row)
+        tiled_copy = copy_utils.tiled_copy_2d(self.dtype, threads_per_row, num_threads, vecsize)
+        return tiled_copy, tiler_mn, threads_per_row

     def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-        warps_per_row =
+        warps_per_row = (
+            num_warps
+            if cute.rank(tv_layout.shape[0]) == 1
+            else max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+        )
         return cute.make_ordered_layout(
             (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
             order=(1, 0, 2),
@@ -64,11 +53,11 @@ class ReductionBase:
         reduction_buffer = smem.allocate_tensor(
             self.reduction_dtype,
             self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
-            byte_alignment=
+            byte_alignment=8,
         )
-        if cutlass.const_expr(self.cluster_n > 1):
+        if const_expr(self.cluster_n > 1):
             mbar_ptr = smem.allocate_array(
-                cutlass.Int64, num_elems=self.stage if not is_persistent else self.stage * 2
+                Int64, num_elems=self.stage if not is_persistent else self.stage * 2
             )
         else:
             mbar_ptr = None
@@ -77,15 +66,15 @@ class ReductionBase:
     @cute.jit
     def _initialize_cluster(
         self,
-        tidx: cutlass.Int32,
+        tidx: Int32,
         mbar_ptr: cute.Pointer,
         num_warps: int,
         is_persistent: bool = False,
     ):
-        if cutlass.const_expr(self.cluster_n > 1):
+        if const_expr(self.cluster_n > 1):
             if tidx < self.stage:  # Initialize full barrier
                 cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
-            if cutlass.const_expr(is_persistent):  # Initialize empty barrier
+            if const_expr(is_persistent):  # Initialize empty barrier
                 cute.arch.mbarrier_init(
                     mbar_ptr + self.stage + tidx, num_warps * self.cluster_n
                 )