quack-kernels 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- quack/__init__.py +1 -1
- quack/activation.py +72 -64
- quack/broadcast_utils.py +1 -1
- quack/copy_utils.py +14 -18
- quack/fast_math.py +29 -76
- quack/gemm_act.py +296 -8
- quack/gemm_dact.py +520 -4
- quack/gemm_default_epi.py +4 -4
- quack/gemm_interface.py +363 -0
- quack/gemm_sm100.py +62 -88
- quack/gemm_sm90.py +68 -114
- quack/gemm_symmetric.py +2 -6
- quack/layout_utils.py +2 -4
- quack/linear.py +37 -0
- quack/pipeline.py +59 -89
- quack/reduce.py +2 -2
- quack/rmsnorm.py +1 -3
- quack/sm90_utils.py +5 -3
- quack/sort/bitonic_sort.py +3 -3
- quack/tile_scheduler.py +310 -256
- quack/topk.py +4 -4
- quack/utils.py +76 -40
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/METADATA +2 -2
- quack_kernels-0.2.6.dist-info/RECORD +45 -0
- quack_kernels-0.2.5.dist-info/RECORD +0 -45
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/top_level.txt +0 -0
quack/pipeline.py
CHANGED

@@ -6,9 +6,10 @@ from dataclasses import dataclass
 import cutlass.cute as cute
 from cutlass import Boolean, Int32, const_expr
 from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op
-from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp
-from cutlass.pipeline import
+from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp
+from cutlass.pipeline import PipelineTmaAsync, PipelineState, PipelineUserType
 from cutlass.pipeline import PipelineTmaUmma
+from cutlass.pipeline import Agent, agent_sync


 class PipelineStateWAdvance(PipelineState):
@@ -57,75 +58,12 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
     """

     @staticmethod
-    def create(
-
-
-
-
-        barrier_storage: cute.Pointer = None,
-        cta_layout_vmnk: Optional[cute.Layout] = None,
-        tidx: Optional[Int32] = None,
-    ):
-        """
-        This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.
-        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
-        :type barrier_storage: cute.Pointer
-        :param num_stages: Number of buffer stages for this pipeline
-        :type num_stages: Int32
-        :param producer_group: CooperativeGroup for the producer agent
-        :type producer_group: CooperativeGroup
-        :param consumer_group: CooperativeGroup for the consumer agent
-        :type consumer_group: CooperativeGroup
-        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
-        :type tx_count: int
-        :param cta_layout_vmnk: Layout of the cluster shape
-        :type cta_layout_vmnk: cute.Layout | None
-        :param tidx: thread index to consumer async threads
-        :type tidx: Int32 | None
-        """
-        if not isinstance(barrier_storage, cute.Pointer):
-            raise ValueError(
-                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
-            )
-
-        producer_type = PipelineOp.TmaLoad
-        consumer_type = PipelineOp.AsyncThread
-
-        producer = (producer_type, producer_group)
-        consumer = (consumer_type, consumer_group)
-
-        sync_object_full = PipelineAsync._make_sync_object(
-            barrier_storage.align(min_align=8), num_stages, producer, tx_count
-        )
-        sync_object_empty = PipelineAsync._make_sync_object(
-            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
-        )
-        if tidx is None:
-            tidx, _, _ = cute.arch.thread_idx()
-        if cta_layout_vmnk is None:
-            cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
-        (
-            dst_rank,
-            is_signalling_thread,
-        ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
-        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
-            dst_rank = None
-        else:
-            dst_rank = dst_rank
-
-        producer_mask = None
-
-        pipeline_init_wait(cta_layout_vmnk)
-
-        return PipelineTmaCpAsync(
-            sync_object_full,
-            sync_object_empty,
-            num_stages,
-            producer_mask,
-            dst_rank,
-            is_signalling_thread,
-        )
+    def create(*args, **kwargs):
+        obj = PipelineTmaAsync.create(*args, **kwargs)
+        # Can't assign to __class__ directly since the dataclass is frozen
+        # obj.__class__ = PipelineTmaCpAsync
+        object.__setattr__(obj, "__class__", PipelineTmaCpAsync)
+        return obj

     @dsl_user_op
     def producer_acquire(
@@ -143,12 +81,16 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
         if_generate(
             try_acquire_token is None or try_acquire_token == 0,
             lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
+            loc=loc,
+            ip=ip,
         )
         # This is the difference between this and PipelineTmaAsync: we could have multiple
         # warps calling this, but only 1 warp should do the arrive on the full barrier
         if_generate(
             is_tma_warp,
             lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
+            loc=loc,
+            ip=ip,
         )

     @dsl_user_op
@@ -156,7 +98,9 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
         """
         We need the mbarrier to track the completion of cp.async
         """
-        cute.arch.cp_async_mbarrier_arrive_noinc(
+        cute.arch.cp_async_mbarrier_arrive_noinc(
+            self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip
+        )


 class MbarrierArrayWDropCount(MbarrierArray):
@@ -211,6 +155,7 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
     (e.g. Blackwell mainloops)
     """

+    @dsl_user_op
     @staticmethod
     def create(
         *,
@@ -220,28 +165,34 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
         tx_count: int,
         barrier_storage: cute.Pointer = None,
         cta_layout_vmnk: Optional[cute.Layout] = None,
-        producer_drop_count: Optional[Int32] = None,
         mcast_mode_mn: tuple[int, int] = (1, 1),
+        defer_sync: bool = False,
+        producer_drop_count: Optional[Int32] = None,
+        loc=None,
+        ip=None,
     ):
-        """
-
-        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
-        :type barrier_storage: cute.Pointer
+        """Creates and initializes a new PipelineTmaUmma instance.
+
         :param num_stages: Number of buffer stages for this pipeline
-        :type num_stages:
-        :param producer_group:
+        :type num_stages: int
+        :param producer_group: CooperativeGroup for the producer agent
         :type producer_group: CooperativeGroup
-        :param consumer_group:
+        :param consumer_group: CooperativeGroup for the consumer agent
         :type consumer_group: CooperativeGroup
         :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
         :type tx_count: int
+        :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers
+        :type barrier_storage: cute.Pointer, optional
         :param cta_layout_vmnk: Layout of the cluster shape
-        :type cta_layout_vmnk: cute.Layout
+        :type cta_layout_vmnk: cute.Layout, optional
         :param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1)
         :type mcast_mode_mn: tuple[int, int], optional
+        :raises ValueError: If barrier_storage is not a cute.Pointer instance
+        :return: A new PipelineTmaUmma instance configured with the provided parameters
+        :rtype: PipelineTmaUmma
         """
         if not isinstance(barrier_storage, cute.Pointer):
-            raise
+            raise TypeError(
                 f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
             )

@@ -257,29 +208,42 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
             producer,
             tx_count,
             drop_count=producer_drop_count,
+            loc=loc,
+            ip=ip,
         )
         sync_object_empty = PipelineTmaUmma._make_sync_object(
-            barrier_storage.align(min_align=8) + num_stages,
+            barrier_storage.align(min_align=8) + num_stages,
+            num_stages,
+            consumer,
+            loc=loc,
+            ip=ip,
         )

-        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
+        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1:
             # No mcast mask if not using clusters
             producer_mask = None
             # All threadblocks are leaders if not using clusters
             is_leader_cta = True
         else:
-            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(
-
+            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(
+                cta_layout_vmnk, mcast_mode_mn, loc=loc, ip=ip
+            )
+            is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk, loc=loc, ip=ip)

         cta_group = (
             cute.nvgpu.tcgen05.CtaGroup.ONE
-            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
+            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0], loc=loc, ip=ip) == 1
             else cute.nvgpu.tcgen05.CtaGroup.TWO
         )

         consumer_mask = producer_mask

-
+        if not defer_sync:
+            cute.arch.mbarrier_init_fence()
+            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1:
+                agent_sync(Agent.ThreadBlock)
+            else:
+                agent_sync(Agent.ThreadBlockCluster, is_relaxed=True)

         return PipelineTmaCpAsyncUmma(
             sync_object_full,
@@ -308,12 +272,16 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
         if_generate(
             try_acquire_token is None or try_acquire_token == 0,
             lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
+            loc=loc,
+            ip=ip,
         )
         # This is the difference between this and PipelineTmaAsync: we could have multiple
         # warps calling this, but only 1 warp should do the arrive on the full barrier
         if_generate(
             and_(self.is_leader_cta, is_tma_warp),
             lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
+            loc=loc,
+            ip=ip,
         )

     @dsl_user_op
@@ -321,4 +289,6 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
         """
         We need the mbarrier to track the completion of cp.async
         """
-        cute.arch.cp_async_mbarrier_arrive_noinc(
+        cute.arch.cp_async_mbarrier_arrive_noinc(
+            self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip
+        )
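Note on the rewritten create above: the cutlass pipeline objects are frozen dataclasses, so the new wrapper cannot simply assign obj.__class__ after delegating to PipelineTmaAsync.create; it bypasses the frozen __setattr__ with object.__setattr__ instead, exactly as the inline comment says. A minimal standalone sketch of that reclassing pattern (Base and Derived are illustrative names, not part of the package):

from dataclasses import dataclass, FrozenInstanceError


@dataclass(frozen=True)
class Base:
    num_stages: int

    @staticmethod
    def create(num_stages: int) -> "Base":
        # Stand-in for the parent factory (e.g. PipelineTmaAsync.create).
        return Base(num_stages)


@dataclass(frozen=True)
class Derived(Base):
    @staticmethod
    def create(num_stages: int) -> "Derived":
        obj = Base.create(num_stages)
        try:
            obj.__class__ = Derived  # rejected: frozen dataclass __setattr__ raises
        except FrozenInstanceError:
            # Bypass the frozen check; safe because Derived adds no fields.
            object.__setattr__(obj, "__class__", Derived)
        return obj


pipe = Derived.create(num_stages=4)
assert type(pipe) is Derived and pipe.num_stages == 4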
quack/reduce.py
CHANGED

@@ -196,9 +196,9 @@ def online_softmax_reduce(
         )
         cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
         num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
-        max_x_single_warp = cute.
+        max_x_single_warp = cute.make_rmem_tensor(num_iter, Float32)
         max_x_single_warp.fill(-Float32.inf)
-        sum_exp_x_single_warp = cute.
+        sum_exp_x_single_warp = cute.make_rmem_tensor(num_iter, Float32)
         sum_exp_x_single_warp.fill(0.0)
         for i in cutlass.range_constexpr(num_iter):
             idx = lane_idx + i * cute.arch.WARP_SIZE
quack/rmsnorm.py
CHANGED

@@ -686,9 +686,7 @@ class RMSNormBackward(ReductionBase):

        if const_expr(self.cluster_n > 1):
            # Need this fence since the STAS from the producer is using the async proxy.
-            cute.arch.
-                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-            )
+            cute.arch.fence_view_async_shared()
            # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
            # Requires adjusting the thread_count when initializing the mbar
            cute.arch.sync_warp()
quack/sm90_utils.py
CHANGED

@@ -102,7 +102,7 @@ def gemm_zero_init(
             tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False
         )
     else:
-        acc = cute.
+        acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C(shape), Float32)
         rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
         rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
         gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait)
@@ -137,7 +137,7 @@ def partition_fragment_ABC(
 ):
     is_rs = thr_mma.op.a_src == warpgroup.OperandSource.RMEM
     if const_expr(not swap_AB):
-        acc = cute.
+        acc = cute.make_rmem_tensor(thr_mma.partition_shape_C(shape_mnk[:2]), Float32)
         if const_expr(not is_rs):
             assert sA is not None
             tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA))
@@ -146,7 +146,9 @@ def partition_fragment_ABC(
             assert sB is not None
             tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB))
     else:
-        acc = cute.
+        acc = cute.make_rmem_tensor(
+            thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32
+        )
         if const_expr(not is_rs):
             assert sB is not None
             tCrB = thr_mma.make_fragment_A(thr_mma.partition_A(sB))
quack/sort/bitonic_sort.py
CHANGED

@@ -108,12 +108,12 @@ def bitonic_topk(
     n = cute.size(arr.shape)
     assert k == 1 << int(math.log2(k)), "k must be a power of 2"
     assert n % k == 0, "n must be divisible by k"
-    topk_vals = cute.
+    topk_vals = cute.make_rmem_tensor(k, arr.element_type)
     for v in cutlass.range(k, unroll_full=True):
         topk_vals[v] = arr[v]
     bitonic_sort(topk_vals, ascending=ascending)
     for i in cutlass.range(1, n // k, unroll_full=True):
-        other_vals = cute.
+        other_vals = cute.make_rmem_tensor(k, arr.element_type)
         for v in cutlass.range(k, unroll_full=True):
             other_vals[v] = arr[i * k + v]
         bitonic_sort(other_vals, ascending=ascending)
@@ -122,7 +122,7 @@ def bitonic_topk(
     # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps
     # do duplicate work.
     for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True):
-        other_vals = cute.
+        other_vals = cute.make_rmem_tensor(k, arr.element_type)
         for v in cutlass.range(k, unroll_full=True):
             other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i)
         bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)