quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/mlp.py ADDED
@@ -0,0 +1,74 @@
+ # Copyright (c) 2025, Tri Dao
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+
+ from quack.linear import linear_act_func, act_linear_func
+
+
+ def mlp_func(x, weight1, weight2, activation: str, fuse_grad_accum=False, tuned=True):
+     preact, postact = linear_act_func(
+         x,
+         weight1,
+         activation,
+         store_preact=torch.is_grad_enabled(),
+         fuse_grad_accum=fuse_grad_accum,
+         tuned=tuned,
+     )
+     out = act_linear_func(
+         preact,
+         weight2,
+         postact,
+         activation=activation,
+         fuse_grad_accum=fuse_grad_accum,
+         tuned=tuned,
+     )
+     return out
+
+
+ class MLP(nn.Module):
+     def __init__(
+         self,
+         in_features,
+         hidden_features=None,
+         out_features=None,
+         bias1=False,
+         bias2=False,
+         activation="gelu",
+         device=None,
+         dtype=None,
+         fuse_grad_accum: bool = False,
+         tuned: bool = True,
+     ):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         out_features = out_features if out_features is not None else in_features
+         hidden_features = hidden_features if hidden_features is not None else 4 * in_features
+         self.activation = activation
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
+         self.fuse_grad_accum = fuse_grad_accum
+         self.tuned = tuned
+
+     def forward(self, input: Tensor) -> Tensor:
+         if (
+             self.fc1.bias is None
+             and self.fc2.bias is None
+             and input.is_cuda
+             and input.stride(-1) == 1
+             and self.fc1.in_features % 8 == 0
+             and self.fc1.out_features % 8 == 0
+             and self.fc2.out_features % 8 == 0
+         ):
+             return mlp_func(
+                 input,
+                 self.fc1.weight,
+                 self.fc2.weight,
+                 activation=self.activation,
+                 fuse_grad_accum=self.fuse_grad_accum,
+                 tuned=self.tuned,
+             )
+         else:
+             y = self.fc1(input)
+             return self.fc2(F.silu(y[..., ::2]) * y[..., 1::2])
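
The new MLP module only takes the fused quack path when both biases are disabled, the input is a contiguous CUDA tensor, and every feature dimension is a multiple of 8; otherwise it falls back to the eager PyTorch branch. A minimal usage sketch (illustrative only; the sizes, batch, dtype, and "gelu" activation below are assumed values, not taken from the package):

    import torch
    from quack.mlp import MLP

    # All feature dimensions are multiples of 8 and both biases are disabled,
    # so forward() takes the fused mlp_func path on a contiguous CUDA input.
    mlp = MLP(1024, hidden_features=4096, activation="gelu",
              device="cuda", dtype=torch.bfloat16)
    x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
    y = mlp(x)  # shape (8, 1024), since out_features defaults to in_features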
quack/pipeline.py ADDED
@@ -0,0 +1,151 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ from typing import Optional
+ from dataclasses import dataclass
+
+ import cutlass.cute as cute
+ from cutlass.cutlass_dsl import Boolean, Int32, if_generate
+ from cutlass.pipeline import CooperativeGroup, PipelineOp, pipeline_init_wait
+ from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
+
+
+ class PipelineStateWAdvance(PipelineState):
+     def advance_iters(self, num_iterations: Int32):
+         self._count += Int32(num_iterations)
+         new_index = self._index + Int32(num_iterations)
+         # How many times did we cross the stages boundary
+         num_crossings = new_index // self.stages
+         self._phase ^= num_crossings
+         self._index = new_index % self.stages
+
+     # This can be overridden by derived classes
+     def __new_from_mlir_values__(self, values):
+         return PipelineStateWAdvance(
+             self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2])
+         )
+
+
+ def make_pipeline_state(type: PipelineUserType, stages: int):
+     """
+     Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
+     """
+     if type is PipelineUserType.Producer:
+         return PipelineStateWAdvance(
+             stages,
+             Int32(0),
+             Int32(0),
+             Int32(1),
+         )
+     elif type is PipelineUserType.Consumer:
+         return PipelineStateWAdvance(
+             stages,
+             Int32(0),
+             Int32(0),
+             Int32(0),
+         )
+     else:
+         assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
+
+
+ @dataclass(frozen=True)
+ class PipelineTmaCpAsync(PipelineTmaAsync):
+     """
+     PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers
+     """
+
+     @staticmethod
+     def create(
+         *,
+         num_stages: int,
+         producer_group: CooperativeGroup,
+         consumer_group: CooperativeGroup,
+         tx_count: int,
+         barrier_storage: cute.Pointer = None,
+         cta_layout_vmnk: Optional[cute.Layout] = None,
+         tidx: Optional[Int32] = None,
+     ):
+         """
+         This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.
+         :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
+         :type barrier_storage: cute.Pointer
+         :param num_stages: Number of buffer stages for this pipeline
+         :type num_stages: Int32
+         :param producer_group: CooperativeGroup for the producer agent
+         :type producer_group: CooperativeGroup
+         :param consumer_group: CooperativeGroup for the consumer agent
+         :type consumer_group: CooperativeGroup
+         :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
+         :type tx_count: int
+         :param cta_layout_vmnk: Layout of the cluster shape
+         :type cta_layout_vmnk: cute.Layout | None
+         :param tidx: thread index to consumer async threads
+         :type tidx: Int32 | None
+         """
+         if not isinstance(barrier_storage, cute.Pointer):
+             raise ValueError(
+                 f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
+             )
+
+         producer_type = PipelineOp.TmaLoad
+         consumer_type = PipelineOp.AsyncThread
+
+         producer = (producer_type, producer_group)
+         consumer = (consumer_type, consumer_group)
+
+         sync_object_full = PipelineAsync._make_sync_object(
+             barrier_storage.align(min_align=8), num_stages, producer, tx_count
+         )
+         sync_object_empty = PipelineAsync._make_sync_object(
+             barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
+         )
+         if tidx is None:
+             tidx, _, _ = cute.arch.thread_idx()
+         if cta_layout_vmnk is None:
+             cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
+         (
+             dst_rank,
+             is_signalling_thread,
+         ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
+         if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
+             dst_rank = None
+         else:
+             dst_rank = dst_rank
+
+         producer_mask = None
+
+         pipeline_init_wait(cta_layout_vmnk)
+
+         return PipelineTmaCpAsync(
+             sync_object_full,
+             sync_object_empty,
+             num_stages,
+             producer_mask,
+             dst_rank,
+             is_signalling_thread,
+         )
+
+     def producer_acquire(
+         self,
+         state: PipelineState,
+         try_acquire_token: Optional[Boolean] = None,
+         is_tma_warp: Optional[Boolean] = True,
+     ):
+         """
+         TMA producer commit conditionally waits on buffer empty and sets the transaction barrier.
+         """
+         if_generate(
+             try_acquire_token is None or try_acquire_token == 0,
+             lambda: self.sync_object_empty.wait(state.index, state.phase),
+         )
+         # This is the difference between this and PipelineTmaAsync: we could have multiple
+         # warps calling this, but only 1 warp should do the arrive on the full barrier
+         if_generate(
+             is_tma_warp,
+             lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+         )
+
+     def producer_commit(self, state: PipelineState):
+         """
+         We need the mbarrier to track the completion of cp.async
+         """
+         cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))
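
PipelineStateWAdvance extends PipelineState with bulk advancement: the circular stage index moves forward by num_iterations and the phase is XOR-ed with the number of stage-boundary crossings. A plain-Python sketch of that arithmetic (illustrative only; the standalone function below is not part of the package and has no cutlass dependency):

    def advance_iters(index, phase, count, stages, num_iterations):
        # Mirrors PipelineStateWAdvance.advance_iters: step the circular
        # stage index forward and XOR the phase with the number of times
        # the stage boundary was crossed.
        count += num_iterations
        new_index = index + num_iterations
        num_crossings = new_index // stages
        phase ^= num_crossings
        index = new_index % stages
        return index, phase, count

    # Example: a 4-stage consumer state (index 0, phase 0) advanced by 6
    # iterations crosses the boundary once, landing on index 2 with phase 1.
    assert advance_iters(0, 0, 0, stages=4, num_iterations=6) == (2, 1, 6)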
quack/reduce.py ADDED
@@ -0,0 +1,241 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ import math
+ import operator
+ from typing import Callable, Optional
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Float32
+
+ import quack.utils as utils
+
+
+ @cute.jit
+ def warp_reduce(
+     val: cute.TensorSSA | cute.Numeric,
+     op: Callable,
+     width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
+ ) -> cute.TensorSSA | cute.Numeric:
+     if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
+         res = cute.make_fragment(val.shape, val.dtype)
+         res.store(val)
+         for i in cutlass.range_constexpr(cute.size(val.shape)):
+             res[i] = warp_reduce(res[i], op, width)
+         return res.load()
+     else:
+         for i in cutlass.range_constexpr(int(math.log2(width))):
+             val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
+         return val
+
+
+ @cute.jit
+ def block_reduce(
+     val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
+ ) -> cute.Numeric:
+     """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)"""
+     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+     warps_per_row = cute.size(reduction_buffer.shape[1])
+     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+     if lane_idx == 0:
+         reduction_buffer[row_idx, col_idx] = val
+     cute.arch.barrier()
+     block_reduce_val = init_val
+     if lane_idx < warps_per_row:
+         block_reduce_val = reduction_buffer[row_idx, lane_idx]
+     return warp_reduce(block_reduce_val, op)
+
+
+ @cute.jit
+ def cluster_reduce(
+     val: cute.Numeric,
+     op: Callable,
+     reduction_buffer: cute.Tensor,
+     mbar_ptr: cute.Pointer,
+     init_val: cute.Numeric = 0.0,
+     phase: Optional[cutlass.Int32] = None,
+ ) -> cute.Numeric:
+     """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
+     cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+     rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
+     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+     if warp_idx == 0:
+         with cute.arch.elect_one():
+             num_warps = rows_per_block * warps_per_row
+             cute.arch.mbarrier_arrive_and_expect_tx(
+                 mbar_ptr,
+                 num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+             )
+     if lane_idx < cluster_n:
+         utils.store_shared_remote(
+             val,
+             utils.elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
+             mbar_ptr,
+             peer_cta_rank_in_cluster=lane_idx,
+         )
+     cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
+     block_reduce_val = init_val
+     num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
+     for i in cutlass.range_constexpr(num_iter):
+         idx = lane_idx + i * cute.arch.WARP_SIZE
+         if idx < cute.size(reduction_buffer, mode=[1]):
+             block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
+     return warp_reduce(block_reduce_val, op)
+
+
+ @cute.jit
+ def block_or_cluster_reduce(
+     val: cute.Numeric,
+     op: Callable,
+     reduction_buffer: cute.Tensor,
+     mbar_ptr: Optional[cute.Pointer],
+     phase: Optional[cutlass.Int32] = None,
+     init_val: cute.Numeric = 0.0,
+ ) -> cute.Numeric:
+     """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
+     if cutlass.const_expr(mbar_ptr is None):
+         return block_reduce(val, op, reduction_buffer, init_val=init_val)
+     else:
+         return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
+
+
+ @cute.jit
+ def row_reduce(
+     x: cute.TensorSSA | cute.Numeric,
+     op: cute.ReductionOp,
+     threads_per_row: cutlass.Constexpr[int],
+     reduction_buffer: Optional[cute.Tensor] = None,
+     mbar_ptr: Optional[cute.Pointer] = None,
+     phase: Optional[cutlass.Int32] = None,
+     init_val: cute.Numeric = 0.0,
+     hook_fn: Optional[Callable] = None,
+ ) -> cute.Numeric:
+     """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
+     if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
+         val = x.reduce(op, init_val=init_val, reduction_profile=0)
+     else:
+         val = x
+     warp_op = {
+         cute.ReductionOp.ADD: operator.add,
+         cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
+         cute.ReductionOp.MIN: min,
+         cute.ReductionOp.MUL: operator.mul,
+     }[op]
+     val = warp_reduce(
+         val,
+         warp_op,
+         width=min(threads_per_row, cute.arch.WARP_SIZE),
+     )
+     if cutlass.const_expr(hook_fn is not None):
+         hook_fn()
+     if cutlass.const_expr(reduction_buffer is not None):
+         warps_per_row, cluster_n = reduction_buffer.shape[1]
+         assert cluster_n == 1 or mbar_ptr is not None, (
+             "mbar_ptr must be provided for cluster reduction"
+         )
+         if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+             val = block_or_cluster_reduce(
+                 val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
+             )
+     return val
+
+
+ @cute.jit
+ def online_softmax_reduce(
+     x: cute.TensorSSA,
+     threads_per_row: cutlass.Constexpr[int],
+     reduction_buffer: Optional[cute.Tensor] = None,
+     mbar_ptr: Optional[cute.Pointer] = None,
+     hook_fn: Optional[Callable] = None,
+     phase: Optional[cutlass.Int32] = None,
+     return_exp_x: bool = False,
+ ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
+     assert x.dtype == Float32, "x must be of type Float32"
+     """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
+     max_x = warp_reduce(
+         x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
+         cute.arch.fmax,
+         width=min(threads_per_row, cute.arch.WARP_SIZE),
+     )
+     log2_e = math.log2(math.e)
+     exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
+     # exp_x = exp2f((x - max_x) * log2_e)
+     sum_exp_x = warp_reduce(
+         exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
+         operator.add,
+         width=min(threads_per_row, cute.arch.WARP_SIZE),
+     )
+     if cutlass.const_expr(hook_fn is not None):
+         hook_fn()
+     if cutlass.const_expr(reduction_buffer is not None):
+         rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
+         assert cluster_n == 1 or mbar_ptr is not None, (
+             "mbar_ptr must be provided for cluster reduction"
+         )
+         if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+             assert reduction_buffer.element_type == cutlass.Int64, (
+                 "reduction_buffer must be of type cute.Int64"
+             )
+             lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+             row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+             if cutlass.const_expr(mbar_ptr is None):
+                 if lane_idx == 0:
+                     reduction_buffer[row_idx, col_idx] = utils.f32x2_to_i64(max_x, sum_exp_x)
+                 cute.arch.barrier()
+                 max_x_single_warp = -Float32.inf
+                 sum_exp_x = 0.0
+                 if lane_idx < warps_per_row:
+                     max_x_single_warp, sum_exp_x = utils.i64_to_f32x2(
+                         reduction_buffer[row_idx, lane_idx]
+                     )
+                 max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
+                 sum_exp_x *= utils.exp2f((max_x_single_warp - max_x_final) * log2_e)
+                 sum_exp_x = warp_reduce(sum_exp_x, operator.add)
+                 if cutlass.const_expr(return_exp_x):
+                     exp_x *= utils.exp2f((max_x - max_x_final) * log2_e)
+                 max_x = max_x_final
+             else:
+                 cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+                 if warp_idx == 0:
+                     with cute.arch.elect_one():
+                         num_warps = rows_per_block * warps_per_row
+                         cute.arch.mbarrier_arrive_and_expect_tx(
+                             mbar_ptr,
+                             num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+                         )
+                 if lane_idx < cluster_n:
+                     utils.store_shared_remote(
+                         utils.f32x2_to_i64(max_x, sum_exp_x),
+                         utils.elem_pointer(
+                             reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))
+                         ),
+                         mbar_ptr,
+                         peer_cta_rank_in_cluster=lane_idx,
+                     )
+                 cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
+                 num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
+                 max_x_single_warp = cute.make_fragment(num_iter, Float32)
+                 max_x_single_warp.fill(-Float32.inf)
+                 sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
+                 sum_exp_x_single_warp.fill(0.0)
+                 for i in cutlass.range_constexpr(num_iter):
+                     idx = lane_idx + i * cute.arch.WARP_SIZE
+                     if idx < cute.size(reduction_buffer, mode=[1]):
+                         max_x_single_warp[i], sum_exp_x_single_warp[i] = utils.i64_to_f32x2(
+                             reduction_buffer[row_idx, idx]
+                         )
+                 max_x_final = max_x_single_warp.load().reduce(
+                     cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
+                 )
+                 max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
+                 sum_exp_x = 0.0
+                 for i in cutlass.range_constexpr(num_iter):
+                     sum_exp_x += sum_exp_x_single_warp[i] * utils.exp2f(
+                         (max_x_single_warp[i] - max_x_final) * log2_e
+                     )
+                 sum_exp_x = warp_reduce(sum_exp_x, operator.add)
+                 if cutlass.const_expr(return_exp_x):
+                     exp_x *= utils.exp2f((max_x - max_x_final) * log2_e)
+                 max_x = max_x_final
+     return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
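
online_softmax_reduce keeps a per-row (max, sum-of-exponentials) pair and rescales partial sums whenever a larger maximum is found, which is the standard online-softmax merge. A plain-Python reference of that merge rule (illustrative only; merge_softmax_stats is a hypothetical helper, and math.exp stands in for the exp2f/log2(e) form used above):

    import math

    def merge_softmax_stats(m1, s1, m2, s2):
        # Combine two (row max, sum of exp(x - max)) partials by rescaling
        # each partial sum to the shared maximum.
        m = max(m1, m2)
        return m, s1 * math.exp(m1 - m) + s2 * math.exp(m2 - m)

    # Merging the partials of [1.0, 2.0] and [3.0] reproduces the full row:
    m, s = merge_softmax_stats(2.0, math.exp(1.0 - 2.0) + 1.0, 3.0, 1.0)
    assert m == 3.0
    assert abs(s - sum(math.exp(v - 3.0) for v in (1.0, 2.0, 3.0))) < 1e-12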
quack/reduction_base.py CHANGED
@@ -1,19 +1,11 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

- import torch
  from typing import Type, Tuple, Optional

  import cutlass
  import cutlass.cute as cute


- torch2cute_dtype_map = {
-     torch.float16: cutlass.Float16,
-     torch.bfloat16: cutlass.BFloat16,
-     torch.float32: cutlass.Float32,
- }
-
-
  class ReductionBase:
      def __init__(
          self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
@@ -32,9 +24,8 @@ class ReductionBase:
      def _get_num_threads(self):
          return 128 if self.N <= 16384 else 256

-     def _get_tv_layout(self):
-         copy_bits = 128
-         vecsize = copy_bits // self.dtype.width
+     def _get_tv_layout(self, num_copy_bits=128):
+         vecsize = num_copy_bits // self.dtype.width
          assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
          num_threads = self._get_num_threads()
          assert num_threads % cute.arch.WARP_SIZE == 0
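
The _get_tv_layout change replaces the hard-coded 128-bit copy width with a num_copy_bits parameter; the vector size is still the copy width divided by the element width. A quick worked example of that arithmetic (illustrative values only, mirroring the vecsize computation above):

    # 128-bit vectorized copies of 16-bit elements (e.g. cutlass.BFloat16)
    # move 8 elements at a time, so N must be a multiple of 8.
    num_copy_bits = 128
    dtype_width = 16
    vecsize = num_copy_bits // dtype_width
    assert vecsize == 8 and 4096 % vecsize == 0  # e.g. N = 4096 passes the check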