quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those package versions.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/linear.py CHANGED
@@ -61,10 +61,11 @@ class LinearFunc(torch.autograd.Function):
     # Use classmethod instead of staticmethod to allow inheritance
     @classmethod
     @custom_fwd(device_type="cuda")
-    def forward(cls, ctx, x, weight, fuse_grad_accum=False):
+    def forward(cls, ctx, x, weight, bias=None, fuse_grad_accum=False):
         """
         x: (..., in_features)
         weight: (out_features, in_features)
+        bias: (out_features,) or None
         out: (..., out_features)
         """
         ctx.weight_dtype = weight.dtype
@@ -74,8 +75,9 @@ class LinearFunc(torch.autograd.Function):
         batch_shape = x.shape[:-1]
         x = x.reshape(-1, x.shape[-1])
         # out = F.linear(x, weight)
-        out = cls.matmul_fwd_fn(x, weight.T)
+        out = cls.matmul_fwd_fn(x, weight.T, bias=bias)
         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
+        ctx.bias_dtype = bias.dtype if bias is not None else None
         return out.reshape(*batch_shape, out.shape[-1])

     @classmethod
@@ -87,13 +89,18 @@ class LinearFunc(torch.autograd.Function):
         x, weight, weight_og = ctx.saved_tensors  # weight_og is None if not ctx.fuse_grad_accum
         batch_shape = dout.shape[:-1]
         dout = dout.reshape(-1, dout.shape[-1])
+        dbias = (
+            dout.sum(0, dtype=ctx.bias_dtype)
+            if ctx.bias_dtype is not None and ctx.needs_input_grad[2]
+            else None
+        )
         dx = linear_bwd_compute_input_grad(ctx, dout, weight, cls.matmul_bwd_dx)
         dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
         dweight = linear_bwd_compute_weight_grad(
             ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
         )
         # return extra Nones for other classes that inherit from LinearFunc
-        return dx, dweight, *([None] * 10)
+        return dx, dweight, dbias, *([None] * 10)


 class LinearUntunedFunc(LinearFunc):
@@ -104,9 +111,9 @@ class LinearUntunedFunc(LinearFunc):
     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)


-def linear_func(x, weight, fuse_grad_accum=False, tuned=True):
+def linear_func(x, weight, bias=None, fuse_grad_accum=False, tuned=True):
     fn_cls = LinearFunc if tuned else LinearUntunedFunc
-    return fn_cls.apply(x, weight, fuse_grad_accum)
+    return fn_cls.apply(x, weight, bias, fuse_grad_accum)


 class LinearActFunc(LinearFunc):
@@ -115,10 +122,13 @@ class LinearActFunc(LinearFunc):
     # Use classmethod instead of staticmethod to allow inheritance
     @classmethod
     @custom_fwd(device_type="cuda")
-    def forward(cls, ctx, x, weight, activation, store_preact=True, fuse_grad_accum=False):
+    def forward(
+        cls, ctx, x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False
+    ):
         """
         x: (..., in_features)
         weight: (out_features, in_features)
+        bias: (out_features,) or None
         out: (..., out_features)
         Return both out and post-activation, but only out is differentiable.
         """
@@ -129,11 +139,12 @@ class LinearActFunc(LinearFunc):
         batch_shape = x.shape[:-1]
         x = x.reshape(-1, x.shape[-1])
         out, postact = cls.matmul_fwd_fn(
-            x, weight.T, activation=activation, store_preact=store_preact
+            x, weight.T, bias=bias, activation=activation, store_preact=store_preact
         )
         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
         if out is not None:
             out = out.reshape(*batch_shape, out.shape[-1])
+        ctx.bias_dtype = bias.dtype if bias is not None else None
         ctx.mark_non_differentiable(postact)
         ctx.set_materialize_grads(False)  # We don't want to materialize grads for postact
         return out, postact.reshape(*batch_shape, postact.shape[-1])
@@ -147,9 +158,11 @@ class LinearActUntunedFunc(LinearActFunc):
     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)


-def linear_act_func(x, weight, activation, store_preact=True, fuse_grad_accum=False, tuned=True):
+def linear_act_func(
+    x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False, tuned=True
+):
     fn_cls = LinearActFunc if tuned else LinearActUntunedFunc
-    return fn_cls.apply(x, weight, activation, store_preact, fuse_grad_accum)
+    return fn_cls.apply(x, weight, activation, bias, store_preact, fuse_grad_accum)


 class DActLinearFunc(LinearFunc):
@@ -229,12 +242,7 @@ class Linear(nn.Linear):
         self.fuse_grad_accum = fuse_grad_accum

     def forward(self, input: Tensor) -> Tensor:
-        if (
-            self.bias is None
-            and input.is_cuda
-            and self.in_features % 8 == 0
-            and self.out_features % 8 == 0
-        ):
-            return linear_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
+        if input.is_cuda and self.in_features % 8 == 0 and self.out_features % 8 == 0:
+            return linear_func(input, self.weight, self.bias, fuse_grad_accum=self.fuse_grad_accum)
         else:
             return F.linear(input, self.weight, self.bias)
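The linear.py changes thread an optional bias through the fused path: forward passes bias to the underlying GEMM, and backward reduces dout over the batch dimension (in the bias dtype) to produce dbias. A minimal usage sketch, assuming a CUDA device, feature dimensions divisible by 8, and the quack.linear import path implied by the file list above:

import torch
from quack.linear import linear_func  # module path assumed from this diff

# Hypothetical shapes; in/out features are multiples of 8 so the fused kernel path is taken.
x = torch.randn(4, 512, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
weight = torch.randn(2048, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
bias = torch.randn(2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = linear_func(x, weight, bias)  # bias is now passed into the forward GEMM
out.sum().backward()                # backward also returns dbias = dout.sum(0) in bias.dtype
print(out.shape, bias.grad.shape)   # torch.Size([4, 512, 2048]) torch.Size([2048])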
quack/pipeline.py CHANGED
@@ -4,9 +4,11 @@ from typing import Optional
 from dataclasses import dataclass

 import cutlass.cute as cute
-from cutlass.cutlass_dsl import Boolean, Int32, if_generate
-from cutlass.pipeline import CooperativeGroup, PipelineOp, pipeline_init_wait
+from cutlass import Boolean, Int32, const_expr
+from cutlass.cutlass_dsl import if_generate, and_
+from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait
 from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
+from cutlass.pipeline import PipelineTmaUmma


 class PipelineStateWAdvance(PipelineState):
@@ -144,7 +146,160 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
             lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
         )

-    def producer_commit(self, state: PipelineState):
+    def producer_cpasync_commit(self, state: PipelineState):
+        """
+        We need the mbarrier to track the completion of cp.async
+        """
+        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))
+
+
+class MbarrierArrayWDropCount(MbarrierArray):
+    def __init__(
+        self,
+        barrier_storage: cute.Pointer,
+        num_stages: int,
+        agent: tuple[PipelineOp, CooperativeGroup],
+        tx_count: int = 0,
+        drop_count: Optional[Int32] = None,
+    ) -> None:
+        self.barrier_storage = barrier_storage
+        self.tx_count = tx_count
+        self.num_stages = num_stages
+        self.op_type, self.cg = agent
+        self.arrive_count = self.cg.size
+        self.drop_count = drop_count
+
+        if self.num_stages <= 0:
+            raise ValueError("Error: Mbarrier stage count must be greater than 0.")
+        if self.arrive_count <= 0:
+            raise ValueError("Error: Mbarrier arrive count must be greater than 0.")
+        if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0:
+            raise ValueError("Error: Mbarrier tx count must not be less than 0 for TMA ops.")
+
+        if const_expr(drop_count is not None):
+            self.arrive_count = self.arrive_count - drop_count
+
+        # Store mbarrier base pointer
+        self.mbarrier_base = self.barrier_storage
+
+        # Mbarrier initialization in constructor
+        self.mbarrier_init()
+
+    def __extract_mlir_values__(self):
+        return [self.barrier_storage, self.drop_count]
+
+    def __new_from_mlir_values__(self, values):
+        return MbarrierArrayWDropCount(
+            values[0], self.num_stages, (self.op_type, self.cg), self.tx_count, values[1]
+        )
+
+
+@dataclass(frozen=True)
+class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
+    """
+    PipelineTmaCpAsync is used for CpAsync + TMA producers and UMMA consumers
+    (e.g. Blackwell mainloops)
+    """
+
+    @staticmethod
+    def create(
+        *,
+        num_stages: int,
+        producer_group: CooperativeGroup,
+        consumer_group: CooperativeGroup,
+        tx_count: int,
+        barrier_storage: cute.Pointer = None,
+        cta_layout_vmnk: Optional[cute.Layout] = None,
+        producer_drop_count: Optional[Int32] = None,
+    ):
+        """
+        This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
+        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
+        :type barrier_storage: cute.Pointer
+        :param num_stages: Number of buffer stages for this pipeline
+        :type num_stages: Int32
+        :param producer_group: `CooperativeGroup` for the producer agent
+        :type producer_group: CooperativeGroup
+        :param consumer_group: `CooperativeGroup` for the consumer agent
+        :type consumer_group: CooperativeGroup
+        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
+        :type tx_count: int
+        :param cta_layout_vmnk: Layout of the cluster shape
+        :type cta_layout_vmnk: cute.Layout | None
+        """
+        if not isinstance(barrier_storage, cute.Pointer):
+            raise ValueError(
+                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
+            )
+
+        producer_type = PipelineOp.TmaLoad
+        consumer_type = PipelineOp.TCGen05Mma
+
+        producer = (producer_type, producer_group)
+        consumer = (consumer_type, consumer_group)
+
+        sync_object_full = MbarrierArrayWDropCount(
+            barrier_storage.align(min_align=8),
+            num_stages,
+            producer,
+            tx_count,
+            drop_count=producer_drop_count,
+        )
+        sync_object_empty = PipelineAsync._make_sync_object(
+            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
+        )
+
+        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
+            # No mcast mask if not using clusters
+            producer_mask = None
+            # All threadblocks are leaders if not using clusters
+            is_leader_cta = True
+        else:
+            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
+            is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)

+        cta_group = (
+            cute.nvgpu.tcgen05.CtaGroup.ONE
+            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
+            else cute.nvgpu.tcgen05.CtaGroup.TWO
+        )
+
+        consumer_mask = producer_mask
+
+        pipeline_init_wait(cta_layout_vmnk)
+
+        return PipelineTmaCpAsyncUmma(
+            sync_object_full,
+            sync_object_empty,
+            num_stages,
+            producer_mask,
+            consumer_mask,
+            is_leader_cta,
+            cta_group,
+        )
+
+    def producer_acquire(
+        self,
+        state: PipelineState,
+        try_acquire_token: Optional[Boolean] = None,
+        is_tma_warp: Optional[Boolean] = True,
+    ):
+        """
+        TMA producer commit conditionally waits on buffer empty and sets the
+        transaction barrier for leader threadblocks.
+        """
+        if_generate(
+            try_acquire_token is None or try_acquire_token == 0,
+            lambda: self.sync_object_empty.wait(state.index, state.phase),
+        )
+        # This is the difference between this and PipelineTmaAsync: we could have multiple
+        # warps calling this, but only 1 warp should do the arrive on the full barrier
+        if_generate(
+            and_(self.is_leader_cta, is_tma_warp),
+            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+        )
+
+    def producer_cpasync_commit(self, state: PipelineState):
         """
         We need the mbarrier to track the completion of cp.async
         """
quack/reduce.py CHANGED
@@ -6,29 +6,11 @@ from typing import Callable, Optional

 import cutlass
 import cutlass.cute as cute
-from cutlass import Float32
+from cutlass import Int32, Int64, Float32, Boolean, const_expr

 import quack.utils as utils


-@cute.jit
-def warp_reduce(
-    val: cute.TensorSSA | cute.Numeric,
-    op: Callable,
-    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
-) -> cute.TensorSSA | cute.Numeric:
-    if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
-        res = cute.make_fragment(val.shape, val.dtype)
-        res.store(val)
-        for i in cutlass.range_constexpr(cute.size(val.shape)):
-            res[i] = warp_reduce(res[i], op, width)
-        return res.load()
-    else:
-        for i in cutlass.range_constexpr(int(math.log2(width))):
-            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
-        return val
-
-

 @cute.jit
 def block_reduce(
@@ -43,7 +25,7 @@ def block_reduce(
     block_reduce_val = init_val
     if lane_idx < warps_per_row:
         block_reduce_val = reduction_buffer[row_idx, lane_idx]
-    return warp_reduce(block_reduce_val, op)
+    return cute.arch.warp_reduction(block_reduce_val, op)


 @cute.jit
@@ -53,7 +35,7 @@ def cluster_reduce(
     reduction_buffer: cute.Tensor,
     mbar_ptr: cute.Pointer,
     init_val: cute.Numeric = 0.0,
-    phase: Optional[cutlass.Int32] = None,
+    phase: Optional[Int32] = None,
 ) -> cute.Numeric:
     """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
     cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
@@ -81,7 +63,7 @@
         idx = lane_idx + i * cute.arch.WARP_SIZE
         if idx < cute.size(reduction_buffer, mode=[1]):
             block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
-    return warp_reduce(block_reduce_val, op)
+    return cute.arch.warp_reduction(block_reduce_val, op)


 @cute.jit
@@ -90,11 +72,11 @@ def block_or_cluster_reduce(
     op: Callable,
     reduction_buffer: cute.Tensor,
     mbar_ptr: Optional[cute.Pointer],
-    phase: Optional[cutlass.Int32] = None,
+    phase: Optional[Int32] = None,
     init_val: cute.Numeric = 0.0,
 ) -> cute.Numeric:
     """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
-    if cutlass.const_expr(mbar_ptr is None):
+    if const_expr(mbar_ptr is None):
         return block_reduce(val, op, reduction_buffer, init_val=init_val)
     else:
         return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
@@ -107,34 +89,34 @@ def row_reduce(
     threads_per_row: cutlass.Constexpr[int],
     reduction_buffer: Optional[cute.Tensor] = None,
     mbar_ptr: Optional[cute.Pointer] = None,
-    phase: Optional[cutlass.Int32] = None,
+    phase: Optional[Int32] = None,
     init_val: cute.Numeric = 0.0,
     hook_fn: Optional[Callable] = None,
 ) -> cute.Numeric:
     """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
-    if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
+    if const_expr(isinstance(x, cute.TensorSSA)):
         val = x.reduce(op, init_val=init_val, reduction_profile=0)
     else:
         val = x
     warp_op = {
         cute.ReductionOp.ADD: operator.add,
-        cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
+        cute.ReductionOp.MAX: cute.arch.fmax if const_expr(x.dtype == Float32) else max,
         cute.ReductionOp.MIN: min,
         cute.ReductionOp.MUL: operator.mul,
     }[op]
-    val = warp_reduce(
+    val = cute.arch.warp_reduction(
         val,
         warp_op,
-        width=min(threads_per_row, cute.arch.WARP_SIZE),
+        threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
     )
-    if cutlass.const_expr(hook_fn is not None):
+    if const_expr(hook_fn is not None):
         hook_fn()
-    if cutlass.const_expr(reduction_buffer is not None):
+    if const_expr(reduction_buffer is not None):
         warps_per_row, cluster_n = reduction_buffer.shape[1]
         assert cluster_n == 1 or mbar_ptr is not None, (
             "mbar_ptr must be provided for cluster reduction"
         )
-        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+        if const_expr(warps_per_row > 1 or cluster_n > 1):
             val = block_or_cluster_reduce(
                 val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
             )
@@ -148,37 +130,37 @@ def online_softmax_reduce(
     reduction_buffer: Optional[cute.Tensor] = None,
     mbar_ptr: Optional[cute.Pointer] = None,
     hook_fn: Optional[Callable] = None,
-    phase: Optional[cutlass.Int32] = None,
+    phase: Optional[Int32] = None,
     return_exp_x: bool = False,
 ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
     assert x.dtype == Float32, "x must be of type Float32"
     """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
-    max_x = warp_reduce(
+    max_x = cute.arch.warp_reduction(
         x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
         cute.arch.fmax,
-        width=min(threads_per_row, cute.arch.WARP_SIZE),
+        threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
     )
     log2_e = math.log2(math.e)
     exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
-    sum_exp_x = warp_reduce(
+    sum_exp_x = cute.arch.warp_reduction(
         exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
         operator.add,
-        width=min(threads_per_row, cute.arch.WARP_SIZE),
+        threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
     )
-    if cutlass.const_expr(hook_fn is not None):
+    if const_expr(hook_fn is not None):
         hook_fn()
-    if cutlass.const_expr(reduction_buffer is not None):
+    if const_expr(reduction_buffer is not None):
         rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
         assert cluster_n == 1 or mbar_ptr is not None, (
            "mbar_ptr must be provided for cluster reduction"
        )
-        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-            assert reduction_buffer.element_type == cutlass.Int64, (
+        if const_expr(warps_per_row > 1 or cluster_n > 1):
+            assert reduction_buffer.element_type == Int64, (
                 "reduction_buffer must be of type cute.Int64"
             )
             lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
             row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
-            if cutlass.const_expr(mbar_ptr is None):
+            if const_expr(mbar_ptr is None):
                 if lane_idx == 0:
                     reduction_buffer[row_idx, col_idx] = utils.f32x2_to_i64(max_x, sum_exp_x)
                 cute.arch.barrier()
@@ -188,10 +170,10 @@
                     max_x_single_warp, sum_exp_x = utils.i64_to_f32x2(
                         reduction_buffer[row_idx, lane_idx]
                     )
-                max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
+                max_x_final = cute.arch.warp_reduction(max_x_single_warp, cute.arch.fmax)
                 sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True)
-                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
-                if cutlass.const_expr(return_exp_x):
+                sum_exp_x = cute.arch.warp_reduction(sum_exp_x, operator.add)
+                if const_expr(return_exp_x):
                     exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
                 max_x = max_x_final
             else:
@@ -227,14 +209,71 @@ def online_softmax_reduce(
                 max_x_final = max_x_single_warp.load().reduce(
                     cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
                 )
-                max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
+                max_x_final = cute.arch.warp_reduction(max_x_final, cute.arch.fmax)
                 sum_exp_x = 0.0
                 for i in cutlass.range_constexpr(num_iter):
                     sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp(
                         max_x_single_warp[i] - max_x_final, fastmath=True
                     )
-                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
-                if cutlass.const_expr(return_exp_x):
+                sum_exp_x = cute.arch.warp_reduction(sum_exp_x, operator.add)
+                if const_expr(return_exp_x):
                     exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
                 max_x = max_x_final
-    return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
+    return max_x, sum_exp_x, (exp_x if const_expr(return_exp_x) else None)
+
+
+@cute.jit
+def sum_swap_shuffle(
+    X: cute.Tensor, elem_per_lane: int = 1, subwarp_size: int = 1, warp_size: int = 32
+) -> cute.Tensor:
+    """
+    For warp reduction, we use Swap Shuffle
+    The normal way to reduction among threads:
+    use shuffle to let *** the first half of threads *** have *** whole data *** from the second half of threads.
+    After each step of reduction, a half of threads won't work in the following steps.
+    That is, as the reduction progresses, the efficiency of shuffle & reduction instructions gradually change from 1/2, 1/4 to 1/32 (the worst case).
+    To overcome this shortcoming, for a NxN matrix to be reduced among N threads as a 1XN vectors,
+    we use swap & shuffle aiming to let *** each half of threads *** have *** a half of data *** from the other half of threads.
+    After reduction, each half of threads should deal with a (N/2)x(N/2) sub-matrix independently in the following step.
+    We can recursively do this until the problem size is 1.
+    """
+    assert (
+        subwarp_size >= 1
+        and subwarp_size <= 32
+        and subwarp_size == 1 << int(math.log2(subwarp_size))
+    )
+    assert (
+        warp_size <= 32
+        and warp_size % subwarp_size == 0
+        and warp_size == 1 << int(math.log2(warp_size))
+    )
+    lane_idx = cute.arch.lane_idx() // subwarp_size
+    X = cute.logical_divide(X, cute.make_layout(elem_per_lane))  # (elem_per_lane, M)
+    numvec = cute.size(X, mode=[1])
+    assert numvec <= 32 // subwarp_size
+    # If X has more values than warp_size // subwarp_size, we first do a normal warp reduction
+    # to sum up values held by lanes further than size(X) away
+    for i in cutlass.range(
+        int(math.log2(numvec)), int(math.log2(warp_size // subwarp_size)), unroll_full=True
+    ):
+        for v in cutlass.range(cute.size(X), unroll_full=True):
+            shfl_val = cute.arch.shuffle_sync_bfly(X[v], offset=(1 << i) * subwarp_size)
+            X[v] = X[v] + shfl_val
+    for logm in cutlass.range_constexpr(int(math.log2(cute.size(X, mode=[1]))) - 1, -1, -1):
+        m = 1 << logm
+        for r in cutlass.range(m, unroll_full=True):
+            frg_A = X[None, r]
+            frg_B = X[None, r + m]
+            # First half of threads swap fragments from the first half of data to the second
+            should_swap = not Boolean(lane_idx & m)
+            for v in cutlass.range(cute.size(frg_A), unroll_full=True):
+                # Step 1: swap
+                lower, upper = frg_A[v], frg_B[v]
+                frg_A[v] = upper if should_swap else lower
+                frg_B[v] = lower if should_swap else upper
+                # Step 2: shuffle
                # each half of threads get a half of data from the other half of threads
+                shfl_val = cute.arch.shuffle_sync_bfly(frg_A[v], offset=m * subwarp_size)
+                # Step 3: reduction
+                frg_A[v] = frg_B[v] + shfl_val
+    return X[None, 0]
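The sum_swap_shuffle docstring describes the swap-shuffle trick: rather than idling half of the lanes at each butterfly step, every lane keeps reducing a shrinking slice of the data, so the shuffle and add instructions stay fully utilized. The following plain-Python sketch (an illustration of the idea for the elem_per_lane=1, subwarp_size=1 case, not the CuTe DSL code above) simulates the lane-level data movement and checks that lane l ends up holding the sum of column l of the distributed NxN matrix:

import math
import random

N = 8  # number of lanes == per-lane vector length
A = [[random.randint(0, 9) for _ in range(N)] for _ in range(N)]
X = [row[:] for row in A]  # X[lane] is that lane's private fragment (one matrix row)

for logm in range(int(math.log2(N)) - 1, -1, -1):
    m = 1 << logm
    for r in range(m):
        # Step 1: lanes whose bit `m` is 0 swap their r and r+m slots
        for lane in range(N):
            if not (lane & m):
                X[lane][r], X[lane][r + m] = X[lane][r + m], X[lane][r]
        # Step 2: butterfly shuffle of slot r with lane ^ m (all lanes in lockstep)
        shfl = [X[lane ^ m][r] for lane in range(N)]
        # Step 3: reduce into slot r
        for lane in range(N):
            X[lane][r] = X[lane][r + m] + shfl[lane]

column_sums = [sum(A[row][col] for row in range(N)) for col in range(N)]
assert [X[lane][0] for lane in range(N)] == column_sums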
quack/reduction_base.py CHANGED
@@ -4,55 +4,44 @@ from typing import Type, Tuple, Optional

 import cutlass
 import cutlass.cute as cute
+from cutlass import Int32, Int64, Float32, const_expr
+
+import quack.copy_utils as copy_utils


 class ReductionBase:
-    def __init__(
-        self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
-    ):
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=Float32):
         self.dtype = dtype
         self.N = N
         self.stage = stage
         self.reduction_dtype = reduction_dtype

-    def _calculate_threads_per_row(self):
+    def _threads_per_row(self):
         raise NotImplementedError()

+    def _num_threads(self):
+        return 128 if self.N <= 16384 else 256
+
     def _set_cluster_n(self):
         self.cluster_n = 1

-    def _get_num_threads(self):
-        return 128 if self.N <= 16384 else 256
-
-    def _get_tv_layout(self, num_copy_bits=128):
-        vecsize = num_copy_bits // self.dtype.width
+    def _get_tiled_copy(self, vecsize: int = 1):
         assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
-        num_threads = self._get_num_threads()
+        threads_per_row = self._threads_per_row()
+        num_threads = self._num_threads()
         assert num_threads % cute.arch.WARP_SIZE == 0
-
-        threads_per_row = self._calculate_threads_per_row()
         num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
-        cols_per_block = num_threads // threads_per_row
-        tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
-        tv_layout = cute.make_layout(
-            ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-            stride=(
-                (vecsize * cols_per_block, 1),
-                (cols_per_block, cols_per_block * vecsize * threads_per_row),
-            ),
-        )
-        return tiler_mn, tv_layout
-
-    def _smem_size_in_bytes(self, tiler_mn, num_warps):
-        return (
-            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
-            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
-            + self.stage * (cutlass.Int64.width // 8)
-        )
+        tiler_mn = (num_threads // threads_per_row, vecsize * num_blocks_N * threads_per_row)
+        tiled_copy = copy_utils.tiled_copy_2d(self.dtype, threads_per_row, num_threads, vecsize)
+        return tiled_copy, tiler_mn, threads_per_row

     def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-        warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+        warps_per_row = (
+            num_warps
+            if cute.rank(tv_layout.shape[0]) == 1
+            else max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+        )
         return cute.make_ordered_layout(
             (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
             order=(1, 0, 2),
@@ -64,11 +53,11 @@
         reduction_buffer = smem.allocate_tensor(
             self.reduction_dtype,
             self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
-            byte_alignment=4,
+            byte_alignment=8,
         )
-        if cutlass.const_expr(self.cluster_n > 1):
+        if const_expr(self.cluster_n > 1):
             mbar_ptr = smem.allocate_array(
-                cutlass.Int64, num_elems=self.stage if not is_persistent else self.stage * 2
+                Int64, num_elems=self.stage if not is_persistent else self.stage * 2
             )
         else:
             mbar_ptr = None
@@ -77,15 +66,15 @@
     @cute.jit
     def _initialize_cluster(
         self,
-        tidx: cutlass.Int32,
+        tidx: Int32,
         mbar_ptr: cute.Pointer,
        num_warps: int,
         is_persistent: bool = False,
     ):
-        if cutlass.const_expr(self.cluster_n > 1):
+        if const_expr(self.cluster_n > 1):
            if tidx < self.stage:  # Initialize full barrier
                cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
-            if cutlass.const_expr(is_persistent):  # Initialize empty barrier
+            if const_expr(is_persistent):  # Initialize empty barrier
                cute.arch.mbarrier_init(
                    mbar_ptr + self.stage + tidx, num_warps * self.cluster_n
                )
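For reference, the tile shape that _get_tiled_copy derives follows directly from the arithmetic shown above. A worked example in plain Python, using hypothetical values (fp16 input, 128-bit vectorized copies as in the old _get_tv_layout default, threads_per_row=32 standing in for a subclass's _threads_per_row, no cluster):

import math

N = 4096                       # row length (elements)
dtype_width = 16               # bits per element (fp16/bf16)
vecsize = 128 // dtype_width   # 8 elements per 128-bit copy
num_threads = 128 if N <= 16384 else 256   # mirrors _num_threads()
threads_per_row = 32           # hypothetical subclass choice (_threads_per_row)
cluster_n = 1                  # _set_cluster_n() default

num_blocks_N = math.ceil((N // vecsize) / (threads_per_row * cluster_n))  # 16
tiler_mn = (num_threads // threads_per_row, vecsize * num_blocks_N * threads_per_row)
print(tiler_mn)  # (4, 4096): 4 rows per CTA, each row-thread handling 16 vectors of 8 elements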