quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/autotuner.py +64 -5
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -35
- quack/gemm.py +194 -0
- quack/gemm_act.py +510 -0
- quack/gemm_config.py +72 -46
- quack/gemm_dact.py +215 -0
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +615 -146
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +182 -23
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +508 -624
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +55 -61
- quack/topk.py +409 -85
- quack/utils.py +37 -172
- quack/varlen_utils.py +370 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/gemm_act_sm90.py +0 -368
- quack/gemm_dact_sm90.py +0 -150
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.1.dist-info/RECORD +0 -37
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/layout_utils.py
ADDED
@@ -0,0 +1,287 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+
+import cutlass
+import cutlass.cute as cute
+
+from cutlass import Int32, const_expr
+
+from quack.utils import prmt
+
+
+def transpose_view(a: cute.Tensor) -> cute.Tensor:
+    """Transpose the first two dimensions of a tensor on smem."""
+    shape = (a.shape[1], a.shape[0], *a.shape[2:])
+    order = (1, 0, *range(2, cute.rank(a)))
+    return cute.composition(a, cute.make_ordered_layout(shape, order=order))
+
+
+def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
+    return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
+
+
+def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
+    shape = (*a.shape[:dim], size, *a.shape[dim:])
+    stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:])
+    return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride))
+
+
+@cute.jit
+def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
+    assert t.element_type.width == 16
+    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
+    t_u32 = cute.recast_tensor(t, Int32)
+
+    quad_idx = cute.arch.lane_idx() % 4
+    lane_03 = quad_idx == 0 or quad_idx == 3
+    selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
+    selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
+    # upper_map = [0, 3, 1, 2]
+    # lower_map = [1, 2, 0, 3]
+    # upper_idx = upper_map[quad_idx]
+    # indexing isn't supported so we have to do arithmetic
+    upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
+    lower_idx = upper_idx ^ 1
+
+    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+    width = 4
+    mask = cute.arch.WARP_SIZE - width
+    clamp = cute.arch.WARP_SIZE - 1
+    mask_and_clamp = mask << 8 | clamp
+
+    for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
+        upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
+        upper0 = upper if lane_03 else lower
+        lower0 = lower if lane_03 else upper
+        upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
+        lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
+        t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
+        t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
+
+
+@cute.jit
+def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None:
+    """Permute and shuffle within 4 threads to change the layout from
+    T0  | T1  | T2  | T3
+    a b | c d | e f | g h
+    to
+    T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
+    a  | b  | c  | d  | e  | f  | g  | h
+    This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict.
+    """
+
+    assert t.element_type.width == 32
+    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"
+
+    quad_idx = cute.arch.lane_idx() % 4
+    # left_map = [0, 2, 1, 3]
+    # right_map = [2, 0, 3, 1]
+    # indexing isn't supported so we have to do arithmetic
+    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
+    right_idx = left_idx ^ 0b10
+
+    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+    width = 4
+    mask = cute.arch.WARP_SIZE - width
+    clamp = cute.arch.WARP_SIZE - 1
+    mask_and_clamp = mask << 8 | clamp
+
+    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
+        for r in cutlass.range(2, unroll_full=True):
+            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
+            # a b | c d | e f | g h -> a b | c d | f e | h g
+            left0 = left if quad_idx < 2 else right
+            right0 = right if quad_idx < 2 else left
+            # a b | c d | f e | h g -> a b | f d | c e | h g
+            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
+            # a b | f d | c e | h g -> a e | f b | c g | h d
+            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
+            # a e | f b | c g | h d -> a e | b f | c g | d h
+            t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0
+            t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0
+        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
+
+
+@cute.jit
+def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None:
+    """Permute and shuffle within 4 threads to change the layout from
+    T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
+    a  | b  | c  | d  | e  | f  | g  | h
+    to
+    T0  | T1  | T2  | T3
+    a b | c d | e f | g h
+    This is so that we can use LDSM (instead of LDS.64) to load C registers without bank conflict.
+    """
+
+    assert t.element_type.width == 32
+    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"
+
+    quad_idx = cute.arch.lane_idx() % 4
+    # left_map = [0, 2, 1, 3]
+    # right_map = [1, 3, 0, 2]
+    # indexing isn't supported so we have to do arithmetic
+    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
+    right_idx = left_idx ^ 0b01
+
+    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+    width = 4
+    mask = cute.arch.WARP_SIZE - width
+    clamp = cute.arch.WARP_SIZE - 1
+    mask_and_clamp = mask << 8 | clamp
+
+    # This is just the inverse of permute_Cregs_b32_for_stsm
+    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
+        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
+        for r in cutlass.range(2, unroll_full=True):
+            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
+            # a e | b f | c g | d h -> a e | f b | c g | h d
+            left0 = left if quad_idx % 2 == 0 else right
+            right0 = right if quad_idx % 2 == 0 else left
+            # a e | f b | c g | h d -> a b | f d | c e | h g
+            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
+            # a b | f d | c e | h g -> a b | c d | f e | h g
+            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
+            # a b | c d | f e | h g -> a b | c d | e f | g h
+            t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0
+            t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0
+
+
+@cute.jit
+def concat_layout(*layouts: cute.Layout) -> cute.Layout:
+    return cute.make_layout(
+        tuple(l.shape for l in layouts),
+        stride=tuple(l.stride for l in layouts),
+    )
+
+
+def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
+    """
+    For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
+    For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
+    """
+    acc_layout_col_major = cute.make_layout(acc_layout.shape)
+    acc_layout_mn = cute.make_layout(
+        (
+            (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
+            (
+                acc_layout_col_major.shape[0][0],
+                *acc_layout_col_major.shape[0][2:],
+                acc_layout_col_major.shape[2],
+            ),  # MMA_N
+            *acc_layout_col_major.shape[3:],
+        ),
+        stride=(
+            (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
+            (
+                acc_layout_col_major.stride[0][0],
+                *acc_layout_col_major.stride[0][2:],
+                acc_layout_col_major.stride[2],
+            ),  # MMA_N
+            *acc_layout_col_major.stride[3:],
+        ),
+    )
+    return cute.composition(acc_layout, acc_layout_mn)
+
+
+def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
+    return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
+
+
+@cute.jit
+def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
+    # For back-to-back gemm, convert the layout of acc0 to the layout that gemm 1 accepts.
+    # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+    # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
+    # TODO: Sm90 FP8
+    if const_expr(cute.rank(acc_layout.shape[0]) == 3):  # Sm90
+        l = cute.logical_divide(
+            acc_layout, ((None, None, 2), None, None)
+        )  # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
+        rA_mma_view = cute.make_layout(
+            (
+                (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]),
+                l.shape[1],
+                (l.shape[0][2][1], l.shape[2]),
+            ),
+            stride=(
+                (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]),
+                l.stride[1],
+                (l.stride[0][2][1], l.stride[2]),
+            ),
+        )
+    else:  # Sm80
+        # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
+        l = cute.logical_divide(acc_layout, (None, None, 2))
+        rA_mma_view = cute.make_layout(
+            (
+                (l.shape[0], l.shape[2][0]),
+                l.shape[1],
+                l.shape[2][1],
+            ),
+            stride=(
+                (l.stride[0], l.stride[2][0]),
+                l.stride[1],
+                l.stride[2][1],
+            ),
+        )
+    return rA_mma_view
+
+
+def convert_layout_zero_stride(
+    input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
+) -> cute.Layout:
+    layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input
+    # Group the modes with non-zero stride in the ref_layout together,
+    # and the modes with zero stride together
+    layout_flat = cute.flatten(layout)
+    ref_layout_flat = cute.flatten(ref_layout)
+    nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0]
+    zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0]
+    # There's an edge case when all modes are zero stride
+    new_shape = (
+        tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,),
+        tuple(layout_flat[i].shape for i in zero_modes),
+    )
+    new_stride = (
+        tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,),
+        tuple(layout_flat[i].stride for i in zero_modes),
+    )
+    out_layout = cute.make_layout(new_shape, stride=new_stride)
+    if const_expr(isinstance(input, cute.Tensor)):
+        return cute.make_tensor(input.iterator, out_layout)
+    else:
+        return out_layout
+
+
+def mma_partition_C_vec(
+    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
+) -> cute.Tensor:
+    assert cute.rank(sVec) == 2
+    assert sVec.stride[0] == 1
+    stage = sVec.shape[1]
+    shape = (
+        (sVec.shape[0], expand_shape, stage)
+        if const_expr(is_colvec)
+        else (expand_shape, sVec.shape[0], stage)
+    )
+    stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
+    sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
+    tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma))
+    return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]


+def mma_partition_A_vec(
+    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
+) -> cute.Tensor:
+    assert cute.rank(sVec) == 2
+    assert sVec.stride[0] == 1
+    stage = sVec.shape[1]
+    shape = (
+        (sVec.shape[0], expand_shape, stage)
+        if const_expr(is_colvec)
+        else (expand_shape, sVec.shape[0], stage)
+    )
+    stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
+    sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
+    tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
+    return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
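
The quad-wide permutations above are easiest to check by replaying them. Below is a pure-Python sketch (illustration only, not part of the package) that simulates one (left, right) register pair of permute_Cregs_b32_for_stsm across the four lanes of a quad; the indexed shuffles act within a width-4 segment, which is what mask_and_clamp = (WARP_SIZE - 4) << 8 | (WARP_SIZE - 1) requests from shfl.sync:

    def shfl_idx_quad(vals, src_lane):
        # Model of an indexed shuffle restricted to a 4-lane segment:
        # lane q receives the value held by lane src_lane[q] of its quad.
        return [vals[src_lane[q]] for q in range(4)]

    def stsm_permute_pair(regs):
        # regs[q] = [left, right] registers held by lane q (quad_idx == q).
        left_map = [0, 2, 1, 3]   # same tables the kernel computes arithmetically
        right_map = [2, 0, 3, 1]
        # a b | c d | e f | g h -> a b | c d | f e | h g  (swap halves on lanes 2, 3)
        left0 = [regs[q][0] if q < 2 else regs[q][1] for q in range(4)]
        right0 = [regs[q][1] if q < 2 else regs[q][0] for q in range(4)]
        left0 = shfl_idx_quad(left0, left_map)
        right0 = shfl_idx_quad(right0, right_map)
        # undo the left/right swap on odd lanes
        return [[left0[q], right0[q]] if q % 2 == 0 else [right0[q], left0[q]]
                for q in range(4)]

    before = [["a", "b"], ["c", "d"], ["e", "f"], ["g", "h"]]
    assert stsm_permute_pair(before) == [["a", "e"], ["b", "f"], ["c", "g"], ["d", "h"]]

This reproduces the T0..T3 diagram in the docstring; permute_Cregs_b32_for_ldsm applies the same steps in reverse order. In the same spirit, the mode grouping done by convert_layout_zero_stride can be mirrored on plain (shape, stride) tuples (again an editor sketch, not quack's API):

    def group_by_ref_stride(shape, stride, ref_stride):
        # Split the modes the reference layout broadcasts (stride 0) from the rest.
        nonzero = [i for i, s in enumerate(ref_stride) if s != 0]
        zero = [i for i, s in enumerate(ref_stride) if s == 0]
        new_shape = (
            tuple(shape[i] for i in nonzero) if nonzero else (1,),
            tuple(shape[i] for i in zero),
        )
        new_stride = (
            tuple(stride[i] for i in nonzero) if nonzero else (0,),
            tuple(stride[i] for i in zero),
        )
        return new_shape, new_stride

    # A rank-3 layout whose reference layout broadcasts the middle mode:
    assert group_by_ref_stride((4, 8, 2), (1, 0, 4), (2, 0, 8)) == (
        ((4, 2), (8,)),
        ((1, 4), (0,)),
    )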
quack/linear.py
CHANGED
@@ -61,10 +61,11 @@ class LinearFunc(torch.autograd.Function):
     # Use classmethod instead of staticmethod to allow inheritance
     @classmethod
     @custom_fwd(device_type="cuda")
-    def forward(cls, ctx, x, weight, fuse_grad_accum=False):
+    def forward(cls, ctx, x, weight, bias=None, fuse_grad_accum=False):
         """
         x: (..., in_features)
         weight: (out_features, in_features)
+        bias: (out_features,) or None
         out: (..., out_features)
         """
         ctx.weight_dtype = weight.dtype
@@ -74,8 +75,9 @@ class LinearFunc(torch.autograd.Function):
         batch_shape = x.shape[:-1]
         x = x.reshape(-1, x.shape[-1])
         # out = F.linear(x, weight)
-        out = cls.matmul_fwd_fn(x, weight.T)
+        out = cls.matmul_fwd_fn(x, weight.T, bias=bias)
         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
+        ctx.bias_dtype = bias.dtype if bias is not None else None
         return out.reshape(*batch_shape, out.shape[-1])

     @classmethod
@@ -87,13 +89,18 @@ class LinearFunc(torch.autograd.Function):
         x, weight, weight_og = ctx.saved_tensors  # weight_og is None if not ctx.fuse_grad_accum
         batch_shape = dout.shape[:-1]
         dout = dout.reshape(-1, dout.shape[-1])
+        dbias = (
+            dout.sum(0, dtype=ctx.bias_dtype)
+            if ctx.bias_dtype is not None and ctx.needs_input_grad[2]
+            else None
+        )
         dx = linear_bwd_compute_input_grad(ctx, dout, weight, cls.matmul_bwd_dx)
         dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
         dweight = linear_bwd_compute_weight_grad(
             ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
         )
         # return extra Nones for other classes that inherit from LinearFunc
-        return dx, dweight, *([None] * 10)
+        return dx, dweight, dbias, *([None] * 10)


 class LinearUntunedFunc(LinearFunc):
@@ -104,9 +111,9 @@ class LinearUntunedFunc(LinearFunc):
     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)


-def linear_func(x, weight, fuse_grad_accum=False, tuned=True):
+def linear_func(x, weight, bias=None, fuse_grad_accum=False, tuned=True):
     fn_cls = LinearFunc if tuned else LinearUntunedFunc
-    return fn_cls.apply(x, weight, fuse_grad_accum)
+    return fn_cls.apply(x, weight, bias, fuse_grad_accum)


 class LinearActFunc(LinearFunc):
@@ -115,10 +122,13 @@ class LinearActFunc(LinearFunc):
     # Use classmethod instead of staticmethod to allow inheritance
     @classmethod
     @custom_fwd(device_type="cuda")
-    def forward(cls, ctx, x, weight, activation, store_preact=True, fuse_grad_accum=False):
+    def forward(
+        cls, ctx, x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False
+    ):
         """
         x: (..., in_features)
         weight: (out_features, in_features)
+        bias: (out_features,) or None
         out: (..., out_features)
         Return both out and post-activation, but only out is differentiable.
         """
@@ -129,11 +139,12 @@ class LinearActFunc(LinearFunc):
         batch_shape = x.shape[:-1]
         x = x.reshape(-1, x.shape[-1])
         out, postact = cls.matmul_fwd_fn(
-            x, weight.T, activation=activation, store_preact=store_preact
+            x, weight.T, bias=bias, activation=activation, store_preact=store_preact
         )
         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
         if out is not None:
             out = out.reshape(*batch_shape, out.shape[-1])
+        ctx.bias_dtype = bias.dtype if bias is not None else None
         ctx.mark_non_differentiable(postact)
         ctx.set_materialize_grads(False)  # We don't want to materialize grads for postact
         return out, postact.reshape(*batch_shape, postact.shape[-1])
@@ -147,9 +158,11 @@ class LinearActUntunedFunc(LinearActFunc):
     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)


-def linear_act_func(x, weight, activation, store_preact=True, fuse_grad_accum=False, tuned=True):
+def linear_act_func(
+    x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False, tuned=True
+):
     fn_cls = LinearActFunc if tuned else LinearActUntunedFunc
-    return fn_cls.apply(x, weight, activation, store_preact, fuse_grad_accum)
+    return fn_cls.apply(x, weight, activation, bias, store_preact, fuse_grad_accum)


 class DActLinearFunc(LinearFunc):
@@ -229,12 +242,7 @@ class Linear(nn.Linear):
         self.fuse_grad_accum = fuse_grad_accum

     def forward(self, input: Tensor) -> Tensor:
-        if (
-            self.bias is None
-            and input.is_cuda
-            and self.in_features % 8 == 0
-            and self.out_features % 8 == 0
-        ):
-            return linear_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
+        if input.is_cuda and self.in_features % 8 == 0 and self.out_features % 8 == 0:
+            return linear_func(input, self.weight, self.bias, fuse_grad_accum=self.fuse_grad_accum)
         else:
             return F.linear(input, self.weight, self.bias)
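
Net effect of the linear.py changes: layers with a bias no longer fall back to F.linear; the bias is passed through to the fused GEMM, and its gradient is the dout.sum(0) reduction added in backward. A hypothetical usage sketch (editor-added; the shapes, device, and tolerances are illustrative assumptions, not from the diff):

    import torch
    import torch.nn.functional as F
    from quack.linear import linear_func

    # in_features and out_features must be multiples of 8 to take the fused path
    x = torch.randn(4, 512, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    w = torch.randn(2048, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    b = torch.randn(2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    out = linear_func(x, w, b)  # previously this path required bias to be None
    torch.testing.assert_close(out, F.linear(x, w, b), rtol=2e-2, atol=2e-2)

    out.sum().backward()
    # dbias sums dout over every non-feature position; dout is all ones here,
    # so each entry of b.grad is approximately 4 * 512
    assert b.grad is not None and b.grad.shape == b.shape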
quack/pipeline.py
CHANGED
@@ -4,9 +4,11 @@ from typing import Optional
 from dataclasses import dataclass

 import cutlass.cute as cute
-from cutlass
-from cutlass.
+from cutlass import Boolean, Int32, const_expr
+from cutlass.cutlass_dsl import if_generate, and_
+from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait
 from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
+from cutlass.pipeline import PipelineTmaUmma


 class PipelineStateWAdvance(PipelineState):
@@ -144,7 +146,160 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
             lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
         )

-    def
+    def producer_cpasync_commit(self, state: PipelineState):
+        """
+        We need the mbarrier to track the completion of cp.async
+        """
+        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))
+
+
+class MbarrierArrayWDropCount(MbarrierArray):
+    def __init__(
+        self,
+        barrier_storage: cute.Pointer,
+        num_stages: int,
+        agent: tuple[PipelineOp, CooperativeGroup],
+        tx_count: int = 0,
+        drop_count: Optional[Int32] = None,
+    ) -> None:
+        self.barrier_storage = barrier_storage
+        self.tx_count = tx_count
+        self.num_stages = num_stages
+        self.op_type, self.cg = agent
+        self.arrive_count = self.cg.size
+        self.drop_count = drop_count
+
+        if self.num_stages <= 0:
+            raise ValueError("Error: Mbarrier stage count must be greater than 0.")
+        if self.arrive_count <= 0:
+            raise ValueError("Error: Mbarrier arrive count must be greater than 0.")
+        if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0:
+            raise ValueError("Error: Mbarrier tx count must not be less than 0 for TMA ops.")
+
+        if const_expr(drop_count is not None):
+            self.arrive_count = self.arrive_count - drop_count
+
+        # Store mbarrier base pointer
+        self.mbarrier_base = self.barrier_storage
+
+        # Mbarrier initialization in constructor
+        self.mbarrier_init()
+
+    def __extract_mlir_values__(self):
+        return [self.barrier_storage, self.drop_count]
+
+    def __new_from_mlir_values__(self, values):
+        return MbarrierArrayWDropCount(
+            values[0], self.num_stages, (self.op_type, self.cg), self.tx_count, values[1]
+        )
+
+
+@dataclass(frozen=True)
+class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
+    """
+    PipelineTmaCpAsyncUmma is used for CpAsync + TMA producers and UMMA consumers
+    (e.g. Blackwell mainloops)
+    """
+
+    @staticmethod
+    def create(
+        *,
+        num_stages: int,
+        producer_group: CooperativeGroup,
+        consumer_group: CooperativeGroup,
+        tx_count: int,
+        barrier_storage: cute.Pointer = None,
+        cta_layout_vmnk: Optional[cute.Layout] = None,
+        producer_drop_count: Optional[Int32] = None,
+    ):
+        """
+        This helper function computes any necessary attributes and returns an instance of PipelineTmaCpAsyncUmma.
+        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
+        :type barrier_storage: cute.Pointer
+        :param num_stages: Number of buffer stages for this pipeline
+        :type num_stages: Int32
+        :param producer_group: `CooperativeGroup` for the producer agent
+        :type producer_group: CooperativeGroup
+        :param consumer_group: `CooperativeGroup` for the consumer agent
+        :type consumer_group: CooperativeGroup
+        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
+        :type tx_count: int
+        :param cta_layout_vmnk: Layout of the cluster shape
+        :type cta_layout_vmnk: cute.Layout | None
+        """
+        if not isinstance(barrier_storage, cute.Pointer):
+            raise ValueError(
+                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
+            )
+
+        producer_type = PipelineOp.TmaLoad
+        consumer_type = PipelineOp.TCGen05Mma
+
+        producer = (producer_type, producer_group)
+        consumer = (consumer_type, consumer_group)
+
+        sync_object_full = MbarrierArrayWDropCount(
+            barrier_storage.align(min_align=8),
+            num_stages,
+            producer,
+            tx_count,
+            drop_count=producer_drop_count,
+        )
+        sync_object_empty = PipelineAsync._make_sync_object(
+            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
+        )
+
+        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
+            # No mcast mask if not using clusters
+            producer_mask = None
+            # All threadblocks are leaders if not using clusters
+            is_leader_cta = True
+        else:
+            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
+            is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)

+        cta_group = (
+            cute.nvgpu.tcgen05.CtaGroup.ONE
+            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
+            else cute.nvgpu.tcgen05.CtaGroup.TWO
+        )
+
+        consumer_mask = producer_mask
+
+        pipeline_init_wait(cta_layout_vmnk)
+
+        return PipelineTmaCpAsyncUmma(
+            sync_object_full,
+            sync_object_empty,
+            num_stages,
+            producer_mask,
+            consumer_mask,
+            is_leader_cta,
+            cta_group,
+        )
+
+    def producer_acquire(
+        self,
+        state: PipelineState,
+        try_acquire_token: Optional[Boolean] = None,
+        is_tma_warp: Optional[Boolean] = True,
+    ):
+        """
+        TMA producer acquire conditionally waits on buffer empty and sets the
+        transaction barrier for leader threadblocks.
+        """
+        if_generate(
+            try_acquire_token is None or try_acquire_token == 0,
+            lambda: self.sync_object_empty.wait(state.index, state.phase),
+        )
+        # This is the difference between this and PipelineTmaAsync: we could have multiple
+        # warps calling this, but only 1 warp should do the arrive on the full barrier
+        if_generate(
+            and_(self.is_leader_cta, is_tma_warp),
+            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+        )
+
+    def producer_cpasync_commit(self, state: PipelineState):
         """
         We need the mbarrier to track the completion of cp.async
         """
|