quack-kernels 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/layout_utils.py ADDED
@@ -0,0 +1,287 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+
+import cutlass
+import cutlass.cute as cute
+
+from cutlass import Int32, const_expr
+
+from quack.utils import prmt
+
+
+def transpose_view(a: cute.Tensor) -> cute.Tensor:
+    """Transpose the first two dimensions of a tensor on smem."""
+    shape = (a.shape[1], a.shape[0], *a.shape[2:])
+    order = (1, 0, *range(2, cute.rank(a)))
+    return cute.composition(a, cute.make_ordered_layout(shape, order=order))
+
+
+def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
+    return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
+
+
+def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
+    shape = (*a.shape[:dim], size, *a.shape[dim:])
+    stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:])
+    return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride))
+
+
+@cute.jit
+def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
+    assert t.element_type.width == 16
+    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
+    t_u32 = cute.recast_tensor(t, Int32)
+
+    quad_idx = cute.arch.lane_idx() % 4
+    lane_03 = quad_idx == 0 or quad_idx == 3
+    selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
+    selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
+    # upper_map = [0, 3, 1, 2]
+    # lower_map = [1, 2, 0, 3]
+    # upper_idx = upper_map[quad_idx]
+    # indexing isn't supported so we have to do arithmetic
+    upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
+    lower_idx = upper_idx ^ 1
+
+    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+    width = 4
+    mask = cute.arch.WARP_SIZE - width
+    clamp = cute.arch.WARP_SIZE - 1
+    mask_and_clamp = mask << 8 | clamp
+
+    for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
+        upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
+        upper0 = upper if lane_03 else lower
+        lower0 = lower if lane_03 else upper
+        upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
+        lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
+        t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
+        t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
+
+
+@cute.jit
+def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None:
+    """Permute and shuffle within 4 threads to change the layout from
+    T0 | T1 | T2 | T3
+    a b | c d | e f | g h
+    to
+    T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
+    a | b | c | d | e | f | g | h
+    This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict.
+    """
+
+    assert t.element_type.width == 32
+    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"
+
+    quad_idx = cute.arch.lane_idx() % 4
+    # left_map = [0, 2, 1, 3]
+    # right_map = [2, 0, 3, 1]
+    # indexing isn't supported so we have to do arithmetic
+    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
+    right_idx = left_idx ^ 0b10
+
+    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+    width = 4
+    mask = cute.arch.WARP_SIZE - width
+    clamp = cute.arch.WARP_SIZE - 1
+    mask_and_clamp = mask << 8 | clamp
+
+    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
+        for r in cutlass.range(2, unroll_full=True):
+            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
+            # a b | c d | e f | g h -> a b | c d | f e | h g
+            left0 = left if quad_idx < 2 else right
+            right0 = right if quad_idx < 2 else left
+            # a b | c d | f e | h g -> a b | f d | c e | h g
+            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
+            # a b | f d | c e | h g -> a e | f b | c g | h d
+            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
+            # a e | f b | c g | h d -> a e | b f | c g | d h
+            t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0
+            t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0
+        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
+
+
+@cute.jit
+def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None:
+    """Permute and shuffle within 4 threads to change the layout from
+    T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
+    a | b | c | d | e | f | g | h
+    to
+    T0 | T1 | T2 | T3
+    a b | c d | e f | g h
+    This is so that we can use LDSM (instead of LDS.64) to store C registers without bank conflict.
+    """
+
+    assert t.element_type.width == 32
+    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"
+
+    quad_idx = cute.arch.lane_idx() % 4
+    # left_map = [0, 2, 1, 3]
+    # right_map = [1, 3, 0, 2]
+    # indexing isn't supported so we have to do arithmetic
+    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
+    right_idx = left_idx ^ 0b01
+
+    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+    width = 4
+    mask = cute.arch.WARP_SIZE - width
+    clamp = cute.arch.WARP_SIZE - 1
+    mask_and_clamp = mask << 8 | clamp
+
+    # This is just the inverse of permute_Cregs_b32_for_stsm
+    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
+        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
+        for r in cutlass.range(2, unroll_full=True):
+            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
+            # a e | b f | c g | d h -> a e | f b | c g | h d
+            left0 = left if quad_idx % 2 == 0 else right
+            right0 = right if quad_idx % 2 == 0 else left
+            # a e | f b | c g | h d -> a b | f d | c e | h g
+            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
+            # a b | f d | c e | h g -> a b | c d | f e | h g
+            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
+            # a b | c d | f e | h g -> a b | c d | e f | g h
+            t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0
+            t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0
+
+
+@cute.jit
+def concat_layout(*layouts: cute.Layout) -> cute.Layout:
+    return cute.make_layout(
+        tuple(l.shape for l in layouts),
+        stride=tuple(l.stride for l in layouts),
+    )
+
+
+def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
+    """
+    For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
+    For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
+    """
+    acc_layout_col_major = cute.make_layout(acc_layout.shape)
+    acc_layout_mn = cute.make_layout(
+        (
+            (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
+            (
+                acc_layout_col_major.shape[0][0],
+                *acc_layout_col_major.shape[0][2:],
+                acc_layout_col_major.shape[2],
+            ),  # MMA_N
+            *acc_layout_col_major.shape[3:],
+        ),
+        stride=(
+            (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
+            (
+                acc_layout_col_major.stride[0][0],
+                *acc_layout_col_major.stride[0][2:],
+                acc_layout_col_major.stride[2],
+            ),  # MMA_N
+            *acc_layout_col_major.stride[3:],
+        ),
+    )
+    return cute.composition(acc_layout, acc_layout_mn)
+
+
+def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
+    return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
+
+
+@cute.jit
+def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
+    # For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
+    # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+    # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
+    # TODO: Sm90 FP8
+    if const_expr(cute.rank(acc_layout.shape[0]) == 3):  # Sm90
+        l = cute.logical_divide(
+            acc_layout, ((None, None, 2), None, None)
+        )  # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
+        rA_mma_view = cute.make_layout(
+            (
+                (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]),
+                l.shape[1],
+                (l.shape[0][2][1], l.shape[2]),
+            ),
+            stride=(
+                (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]),
+                l.stride[1],
+                (l.stride[0][2][1], l.stride[2]),
+            ),
+        )
+    else:  # Sm80
+        # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
+        l = cute.logical_divide(acc_layout, (None, None, 2))
+        rA_mma_view = cute.make_layout(
+            (
+                (l.shape[0], l.shape[2][0]),
+                l.shape[1],
+                l.shape[2][1],
+            ),
+            stride=(
+                (l.stride[0], l.stride[2][0]),
+                l.stride[1],
+                l.stride[2][1],
+            ),
+        )
+    return rA_mma_view
+
+
+def convert_layout_zero_stride(
+    input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
+) -> cute.Layout:
+    layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input
+    # Group the modes with non-zero stride in the ref_layout together,
+    # and the modes with zero stride together
+    layout_flat = cute.flatten(layout)
+    ref_layout_flat = cute.flatten(ref_layout)
+    nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0]
+    zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0]
+    # There's an edge case when all modes are zero stride
+    new_shape = (
+        tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,),
+        tuple(layout_flat[i].shape for i in zero_modes),
+    )
+    new_stride = (
+        tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,),
+        tuple(layout_flat[i].stride for i in zero_modes),
+    )
+    out_layout = cute.make_layout(new_shape, stride=new_stride)
+    if const_expr(isinstance(input, cute.Tensor)):
+        return cute.make_tensor(input.iterator, out_layout)
+    else:
+        return out_layout
+
+
+def mma_partition_C_vec(
+    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
+) -> cute.Tensor:
+    assert cute.rank(sVec) == 2
+    assert sVec.stride[0] == 1
+    stage = sVec.shape[1]
+    shape = (
+        (sVec.shape[0], expand_shape, stage)
+        if const_expr(is_colvec)
+        else (expand_shape, sVec.shape[0], stage)
+    )
+    stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
+    sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
+    tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma))
+    return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
+
+
+def mma_partition_A_vec(
+    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
+) -> cute.Tensor:
+    assert cute.rank(sVec) == 2
+    assert sVec.stride[0] == 1
+    stage = sVec.shape[1]
+    shape = (
+        (sVec.shape[0], expand_shape, stage)
+        if const_expr(is_colvec)
+        else (expand_shape, sVec.shape[0], stage)
+    )
+    stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
+    sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
+    tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
+    return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
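
The permute_Cregs_b32_for_stsm/ldsm pair above documents the net data movement within a quad of threads, but the shuffle-based implementation is hard to follow on its own. As a minimal pure-Python sketch of that movement for one group of eight b32 values (hypothetical helper names, not part of the package, and not modeling the shuffle_sync/register mechanics):

    def stsm_permute_reference(quad_regs):
        # Before: thread k of the quad holds the pair (v[2k], v[2k+1]);
        # after: it holds (v[k], v[k+4]), the STSM-friendly layout from the docstring.
        flat = [v for pair in quad_regs for v in pair]     # a b c d e f g h
        return [[flat[k], flat[k + 4]] for k in range(4)]  # [a e] [b f] [c g] [d h]

    def ldsm_permute_reference(quad_regs):
        # Inverse movement, matching permute_Cregs_b32_for_ldsm.
        flat = [v for pair in quad_regs for v in pair]     # a e b f c g d h
        order = flat[0::2] + flat[1::2]                    # a b c d e f g h
        return [[order[2 * k], order[2 * k + 1]] for k in range(4)]

    before = [["a", "b"], ["c", "d"], ["e", "f"], ["g", "h"]]
    assert ldsm_permute_reference(stsm_permute_reference(before)) == before

The kernels above realize the same movement in registers via shuffle_sync exchanges within each quad plus the swap of registers 1 and 2 in every 4-register group.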
quack/linear.py CHANGED
@@ -61,10 +61,11 @@ class LinearFunc(torch.autograd.Function):
     # Use classmethod instead of staticmethod to allow inheritance
     @classmethod
     @custom_fwd(device_type="cuda")
-    def forward(cls, ctx, x, weight, fuse_grad_accum=False):
+    def forward(cls, ctx, x, weight, bias=None, fuse_grad_accum=False):
         """
         x: (..., in_features)
         weight: (out_features, in_features)
+        bias: (out_features,) or None
         out: (..., out_features)
         """
         ctx.weight_dtype = weight.dtype
@@ -74,8 +75,9 @@ class LinearFunc(torch.autograd.Function):
         batch_shape = x.shape[:-1]
         x = x.reshape(-1, x.shape[-1])
         # out = F.linear(x, weight)
-        out = cls.matmul_fwd_fn(x, weight.T)
+        out = cls.matmul_fwd_fn(x, weight.T, bias=bias)
         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
+        ctx.bias_dtype = bias.dtype if bias is not None else None
         return out.reshape(*batch_shape, out.shape[-1])
 
     @classmethod
@@ -87,13 +89,18 @@ class LinearFunc(torch.autograd.Function):
         x, weight, weight_og = ctx.saved_tensors  # weight_og is None if not ctx.fuse_grad_accum
         batch_shape = dout.shape[:-1]
         dout = dout.reshape(-1, dout.shape[-1])
+        dbias = (
+            dout.sum(0, dtype=ctx.bias_dtype)
+            if ctx.bias_dtype is not None and ctx.needs_input_grad[2]
+            else None
+        )
         dx = linear_bwd_compute_input_grad(ctx, dout, weight, cls.matmul_bwd_dx)
         dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
         dweight = linear_bwd_compute_weight_grad(
             ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
         )
         # return extra Nones for other classes that inherit from LinearFunc
-        return dx, dweight, *([None] * 10)
+        return dx, dweight, dbias, *([None] * 10)
 
 
 class LinearUntunedFunc(LinearFunc):
@@ -104,9 +111,9 @@ class LinearUntunedFunc(LinearFunc):
     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
 
 
-def linear_func(x, weight, fuse_grad_accum=False, tuned=True):
+def linear_func(x, weight, bias=None, fuse_grad_accum=False, tuned=True):
     fn_cls = LinearFunc if tuned else LinearUntunedFunc
-    return fn_cls.apply(x, weight, fuse_grad_accum)
+    return fn_cls.apply(x, weight, bias, fuse_grad_accum)
 
 
 class LinearActFunc(LinearFunc):
@@ -115,10 +122,13 @@ class LinearActFunc(LinearFunc):
     # Use classmethod instead of staticmethod to allow inheritance
     @classmethod
     @custom_fwd(device_type="cuda")
-    def forward(cls, ctx, x, weight, activation, store_preact=True, fuse_grad_accum=False):
+    def forward(
+        cls, ctx, x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False
+    ):
         """
         x: (..., in_features)
         weight: (out_features, in_features)
+        bias: (out_features,) or None
         out: (..., out_features)
         Return both out and post-activation, but only out is differentiable.
         """
@@ -129,11 +139,12 @@ class LinearActFunc(LinearFunc):
         batch_shape = x.shape[:-1]
         x = x.reshape(-1, x.shape[-1])
         out, postact = cls.matmul_fwd_fn(
-            x, weight.T, activation=activation, store_preact=store_preact
+            x, weight.T, bias=bias, activation=activation, store_preact=store_preact
        )
         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
         if out is not None:
             out = out.reshape(*batch_shape, out.shape[-1])
+        ctx.bias_dtype = bias.dtype if bias is not None else None
         ctx.mark_non_differentiable(postact)
         ctx.set_materialize_grads(False)  # We don't want to materialize grads for postact
         return out, postact.reshape(*batch_shape, postact.shape[-1])
@@ -147,9 +158,11 @@ class LinearActUntunedFunc(LinearActFunc):
     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
 
 
-def linear_act_func(x, weight, activation, store_preact=True, fuse_grad_accum=False, tuned=True):
+def linear_act_func(
+    x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False, tuned=True
+):
     fn_cls = LinearActFunc if tuned else LinearActUntunedFunc
-    return fn_cls.apply(x, weight, activation, store_preact, fuse_grad_accum)
+    return fn_cls.apply(x, weight, activation, bias, store_preact, fuse_grad_accum)
 
 
 class DActLinearFunc(LinearFunc):
@@ -229,12 +242,7 @@ class Linear(nn.Linear):
         self.fuse_grad_accum = fuse_grad_accum
 
     def forward(self, input: Tensor) -> Tensor:
-        if (
-            self.bias is None
-            and input.is_cuda
-            and self.in_features % 8 == 0
-            and self.out_features % 8 == 0
-        ):
-            return linear_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
+        if input.is_cuda and self.in_features % 8 == 0 and self.out_features % 8 == 0:
+            return linear_func(input, self.weight, self.bias, fuse_grad_accum=self.fuse_grad_accum)
         else:
             return F.linear(input, self.weight, self.bias)
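
The change above threads an optional bias through linear_func and Linear.forward instead of falling back to F.linear whenever a bias is present. A hedged usage sketch of the new signature (the shapes, dtype, and CUDA placement are illustrative assumptions; the divisible-by-8 feature sizes mirror the guard in Linear.forward shown above):

    import torch
    from quack.linear import linear_func

    x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    weight = torch.randn(2048, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    bias = torch.randn(2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    out = linear_func(x, weight, bias)   # (8, 512, 2048); bias is passed straight to matmul_fwd_fn
    out.sum().backward()
    # The backward pass now also produces dbias = dout.sum(0) in the bias dtype.
    print(bias.grad.shape)               # torch.Size([2048])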
quack/pipeline.py CHANGED
@@ -4,9 +4,11 @@ from typing import Optional
 from dataclasses import dataclass
 
 import cutlass.cute as cute
-from cutlass.cutlass_dsl import Boolean, Int32, if_generate
-from cutlass.pipeline import CooperativeGroup, PipelineOp, pipeline_init_wait
+from cutlass import Boolean, Int32, const_expr
+from cutlass.cutlass_dsl import if_generate, and_
+from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait
 from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
+from cutlass.pipeline import PipelineTmaUmma
 
 
 class PipelineStateWAdvance(PipelineState):
@@ -144,7 +146,160 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
             lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
         )
 
-    def producer_commit(self, state: PipelineState):
+    def producer_cpasync_commit(self, state: PipelineState):
+        """
+        We need the mbarrier to track the completion of cp.async
+        """
+        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))
+
+
+class MbarrierArrayWDropCount(MbarrierArray):
+    def __init__(
+        self,
+        barrier_storage: cute.Pointer,
+        num_stages: int,
+        agent: tuple[PipelineOp, CooperativeGroup],
+        tx_count: int = 0,
+        drop_count: Optional[Int32] = None,
+    ) -> None:
+        self.barrier_storage = barrier_storage
+        self.tx_count = tx_count
+        self.num_stages = num_stages
+        self.op_type, self.cg = agent
+        self.arrive_count = self.cg.size
+        self.drop_count = drop_count
+
+        if self.num_stages <= 0:
+            raise ValueError("Error: Mbarrier stage count must be greater than 0.")
+        if self.arrive_count <= 0:
+            raise ValueError("Error: Mbarrier arrive count must be greater than 0.")
+        if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0:
+            raise ValueError("Error: Mbarrier tx count must not be less than 0 for TMA ops.")
+
+        if const_expr(drop_count is not None):
+            self.arrive_count = self.arrive_count - drop_count
+
+        # Store mbarrier base pointer
+        self.mbarrier_base = self.barrier_storage
+
+        # Mbarrier initialization in constructor
+        self.mbarrier_init()
+
+    def __extract_mlir_values__(self):
+        return [self.barrier_storage, self.drop_count]
+
+    def __new_from_mlir_values__(self, values):
+        return MbarrierArrayWDropCount(
+            values[0], self.num_stages, (self.op_type, self.cg), self.tx_count, values[1]
+        )
+
+
+@dataclass(frozen=True)
+class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
+    """
+    PipelineTmaCpAsync is used for CpAsync + TMA producers and UMMA consumers
+    (e.g. Blackwell mainloops)
+    """
+
+    @staticmethod
+    def create(
+        *,
+        num_stages: int,
+        producer_group: CooperativeGroup,
+        consumer_group: CooperativeGroup,
+        tx_count: int,
+        barrier_storage: cute.Pointer = None,
+        cta_layout_vmnk: Optional[cute.Layout] = None,
+        producer_drop_count: Optional[Int32] = None,
+    ):
+        """
+        This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
+        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
+        :type barrier_storage: cute.Pointer
+        :param num_stages: Number of buffer stages for this pipeline
+        :type num_stages: Int32
+        :param producer_group: `CooperativeGroup` for the producer agent
+        :type producer_group: CooperativeGroup
+        :param consumer_group: `CooperativeGroup` for the consumer agent
+        :type consumer_group: CooperativeGroup
+        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
+        :type tx_count: int
+        :param cta_layout_vmnk: Layout of the cluster shape
+        :type cta_layout_vmnk: cute.Layout | None
+        """
+        if not isinstance(barrier_storage, cute.Pointer):
+            raise ValueError(
+                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
+            )
+
+        producer_type = PipelineOp.TmaLoad
+        consumer_type = PipelineOp.TCGen05Mma
+
+        producer = (producer_type, producer_group)
+        consumer = (consumer_type, consumer_group)
+
+        sync_object_full = MbarrierArrayWDropCount(
+            barrier_storage.align(min_align=8),
+            num_stages,
+            producer,
+            tx_count,
+            drop_count=producer_drop_count,
+        )
+        sync_object_empty = PipelineAsync._make_sync_object(
+            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
+        )
+
+        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
+            # No mcast mask if not using clusters
+            producer_mask = None
+            # All threadblocks are leaders if not using clusters
+            is_leader_cta = True
+        else:
+            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
+            is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
+
+        cta_group = (
+            cute.nvgpu.tcgen05.CtaGroup.ONE
+            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
+            else cute.nvgpu.tcgen05.CtaGroup.TWO
+        )
+
+        consumer_mask = producer_mask
+
+        pipeline_init_wait(cta_layout_vmnk)
+
+        return PipelineTmaCpAsyncUmma(
+            sync_object_full,
+            sync_object_empty,
+            num_stages,
+            producer_mask,
+            consumer_mask,
+            is_leader_cta,
+            cta_group,
+        )
+
+    def producer_acquire(
+        self,
+        state: PipelineState,
+        try_acquire_token: Optional[Boolean] = None,
+        is_tma_warp: Optional[Boolean] = True,
+    ):
+        """
+        TMA producer commit conditionally waits on buffer empty and sets the
+        transaction barrier for leader threadblocks.
+        """
+        if_generate(
+            try_acquire_token is None or try_acquire_token == 0,
+            lambda: self.sync_object_empty.wait(state.index, state.phase),
+        )
+        # This is the difference between this and PipelineTmaAsync: we could have multiple
+        # warps calling this, but only 1 warp should do the arrive on the full barrier
+        if_generate(
+            and_(self.is_leader_cta, is_tma_warp),
+            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+        )
+
+    def producer_cpasync_commit(self, state: PipelineState):
         """
         We need the mbarrier to track the completion of cp.async
         """