quack-kernels 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +11 -1
- quack/activation.py +72 -64
- quack/broadcast_utils.py +1 -1
- quack/copy_utils.py +143 -20
- quack/cute_dsl_ptxas.py +151 -0
- quack/fast_math.py +29 -76
- quack/gemm_act.py +296 -8
- quack/gemm_dact.py +520 -4
- quack/gemm_default_epi.py +4 -4
- quack/gemm_interface.py +363 -0
- quack/gemm_sm100.py +62 -88
- quack/gemm_sm90.py +68 -114
- quack/gemm_symmetric.py +2 -6
- quack/layout_utils.py +10 -4
- quack/linear.py +37 -0
- quack/pipeline.py +87 -99
- quack/reduce.py +2 -2
- quack/rmsnorm.py +1 -3
- quack/sm90_utils.py +34 -2
- quack/sort/bitonic_sort.py +4 -4
- quack/tile_scheduler.py +310 -256
- quack/topk.py +4 -4
- quack/utils.py +76 -40
- {quack_kernels-0.2.4.dist-info → quack_kernels-0.2.6.dist-info}/METADATA +2 -2
- quack_kernels-0.2.6.dist-info/RECORD +45 -0
- {quack_kernels-0.2.4.dist-info → quack_kernels-0.2.6.dist-info}/WHEEL +1 -1
- quack_kernels-0.2.4.dist-info/RECORD +0 -44
- {quack_kernels-0.2.4.dist-info → quack_kernels-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.4.dist-info → quack_kernels-0.2.6.dist-info}/top_level.txt +0 -0
quack/pipeline.py
CHANGED
|
@@ -5,14 +5,16 @@ from dataclasses import dataclass
|
|
|
5
5
|
|
|
6
6
|
import cutlass.cute as cute
|
|
7
7
|
from cutlass import Boolean, Int32, const_expr
|
|
8
|
-
from cutlass.cutlass_dsl import if_generate, and_
|
|
9
|
-
from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp
|
|
10
|
-
from cutlass.pipeline import
|
|
8
|
+
from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op
|
|
9
|
+
from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp
|
|
10
|
+
from cutlass.pipeline import PipelineTmaAsync, PipelineState, PipelineUserType
|
|
11
11
|
from cutlass.pipeline import PipelineTmaUmma
|
|
12
|
+
from cutlass.pipeline import Agent, agent_sync
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
class PipelineStateWAdvance(PipelineState):
|
|
15
|
-
|
|
16
|
+
@dsl_user_op
|
|
17
|
+
def advance_iters(self, num_iterations: Int32, *, loc=None, ip=None):
|
|
16
18
|
self._count += Int32(num_iterations)
|
|
17
19
|
new_index = self._index + Int32(num_iterations)
|
|
18
20
|
# How many times did we cross the stages boundary
|
|
@@ -56,104 +58,53 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
|
|
|
56
58
|
"""
|
|
57
59
|
|
|
58
60
|
@staticmethod
|
|
59
|
-
def create(
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
tidx: Optional[Int32] = None,
|
|
68
|
-
):
|
|
69
|
-
"""
|
|
70
|
-
This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.
|
|
71
|
-
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
|
|
72
|
-
:type barrier_storage: cute.Pointer
|
|
73
|
-
:param num_stages: Number of buffer stages for this pipeline
|
|
74
|
-
:type num_stages: Int32
|
|
75
|
-
:param producer_group: CooperativeGroup for the producer agent
|
|
76
|
-
:type producer_group: CooperativeGroup
|
|
77
|
-
:param consumer_group: CooperativeGroup for the consumer agent
|
|
78
|
-
:type consumer_group: CooperativeGroup
|
|
79
|
-
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
|
|
80
|
-
:type tx_count: int
|
|
81
|
-
:param cta_layout_vmnk: Layout of the cluster shape
|
|
82
|
-
:type cta_layout_vmnk: cute.Layout | None
|
|
83
|
-
:param tidx: thread index to consumer async threads
|
|
84
|
-
:type tidx: Int32 | None
|
|
85
|
-
"""
|
|
86
|
-
if not isinstance(barrier_storage, cute.Pointer):
|
|
87
|
-
raise ValueError(
|
|
88
|
-
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
|
|
89
|
-
)
|
|
90
|
-
|
|
91
|
-
producer_type = PipelineOp.TmaLoad
|
|
92
|
-
consumer_type = PipelineOp.AsyncThread
|
|
93
|
-
|
|
94
|
-
producer = (producer_type, producer_group)
|
|
95
|
-
consumer = (consumer_type, consumer_group)
|
|
96
|
-
|
|
97
|
-
sync_object_full = PipelineAsync._make_sync_object(
|
|
98
|
-
barrier_storage.align(min_align=8), num_stages, producer, tx_count
|
|
99
|
-
)
|
|
100
|
-
sync_object_empty = PipelineAsync._make_sync_object(
|
|
101
|
-
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
|
|
102
|
-
)
|
|
103
|
-
if tidx is None:
|
|
104
|
-
tidx, _, _ = cute.arch.thread_idx()
|
|
105
|
-
if cta_layout_vmnk is None:
|
|
106
|
-
cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
|
|
107
|
-
(
|
|
108
|
-
dst_rank,
|
|
109
|
-
is_signalling_thread,
|
|
110
|
-
) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
|
|
111
|
-
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
|
|
112
|
-
dst_rank = None
|
|
113
|
-
else:
|
|
114
|
-
dst_rank = dst_rank
|
|
115
|
-
|
|
116
|
-
producer_mask = None
|
|
117
|
-
|
|
118
|
-
pipeline_init_wait(cta_layout_vmnk)
|
|
119
|
-
|
|
120
|
-
return PipelineTmaCpAsync(
|
|
121
|
-
sync_object_full,
|
|
122
|
-
sync_object_empty,
|
|
123
|
-
num_stages,
|
|
124
|
-
producer_mask,
|
|
125
|
-
dst_rank,
|
|
126
|
-
is_signalling_thread,
|
|
127
|
-
)
|
|
128
|
-
|
|
61
|
+
def create(*args, **kwargs):
|
|
62
|
+
obj = PipelineTmaAsync.create(*args, **kwargs)
|
|
63
|
+
# Can't assign to __class__ directly since the dataclass is frozen
|
|
64
|
+
# obj.__class__ = PipelineTmaCpAsync
|
|
65
|
+
object.__setattr__(obj, "__class__", PipelineTmaCpAsync)
|
|
66
|
+
return obj
|
|
67
|
+
|
|
68
|
+
@dsl_user_op
|
|
129
69
|
def producer_acquire(
|
|
130
70
|
self,
|
|
131
71
|
state: PipelineState,
|
|
132
72
|
try_acquire_token: Optional[Boolean] = None,
|
|
133
73
|
is_tma_warp: Optional[Boolean] = True,
|
|
74
|
+
*,
|
|
75
|
+
loc=None,
|
|
76
|
+
ip=None,
|
|
134
77
|
):
|
|
135
78
|
"""
|
|
136
79
|
TMA producer commit conditionally waits on buffer empty and sets the transaction barrier.
|
|
137
80
|
"""
|
|
138
81
|
if_generate(
|
|
139
82
|
try_acquire_token is None or try_acquire_token == 0,
|
|
140
|
-
lambda: self.sync_object_empty.wait(state.index, state.phase),
|
|
83
|
+
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
|
|
84
|
+
loc=loc,
|
|
85
|
+
ip=ip,
|
|
141
86
|
)
|
|
142
87
|
# This is the difference between this and PipelineTmaAsync: we could have multiple
|
|
143
88
|
# warps calling this, but only 1 warp should do the arrive on the full barrier
|
|
144
89
|
if_generate(
|
|
145
90
|
is_tma_warp,
|
|
146
|
-
lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
|
|
91
|
+
lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
|
|
92
|
+
loc=loc,
|
|
93
|
+
ip=ip,
|
|
147
94
|
)
|
|
148
95
|
|
|
149
|
-
|
|
96
|
+
@dsl_user_op
|
|
97
|
+
def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
|
|
150
98
|
"""
|
|
151
99
|
We need the mbarrier to track the completion of cp.async
|
|
152
100
|
"""
|
|
153
|
-
cute.arch.cp_async_mbarrier_arrive_noinc(
|
|
101
|
+
cute.arch.cp_async_mbarrier_arrive_noinc(
|
|
102
|
+
self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip
|
|
103
|
+
)
|
|
154
104
|
|
|
155
105
|
|
|
156
106
|
class MbarrierArrayWDropCount(MbarrierArray):
|
|
107
|
+
@dsl_user_op
|
|
157
108
|
def __init__(
|
|
158
109
|
self,
|
|
159
110
|
barrier_storage: cute.Pointer,
|
|
@@ -161,6 +112,9 @@ class MbarrierArrayWDropCount(MbarrierArray):
|
|
|
161
112
|
agent: tuple[PipelineOp, CooperativeGroup],
|
|
162
113
|
tx_count: int = 0,
|
|
163
114
|
drop_count: Optional[Int32] = None,
|
|
115
|
+
*,
|
|
116
|
+
loc=None,
|
|
117
|
+
ip=None,
|
|
164
118
|
) -> None:
|
|
165
119
|
self.barrier_storage = barrier_storage
|
|
166
120
|
self.tx_count = tx_count
|
|
@@ -183,7 +137,7 @@ class MbarrierArrayWDropCount(MbarrierArray):
|
|
|
183
137
|
self.mbarrier_base = self.barrier_storage
|
|
184
138
|
|
|
185
139
|
# Mbarrier initialization in constructor
|
|
186
|
-
self.mbarrier_init()
|
|
140
|
+
self.mbarrier_init(loc=loc, ip=ip)
|
|
187
141
|
|
|
188
142
|
def __extract_mlir_values__(self):
|
|
189
143
|
return [self.barrier_storage, self.drop_count]
|
|
@@ -201,6 +155,7 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
|
|
|
201
155
|
(e.g. Blackwell mainloops)
|
|
202
156
|
"""
|
|
203
157
|
|
|
158
|
+
@dsl_user_op
|
|
204
159
|
@staticmethod
|
|
205
160
|
def create(
|
|
206
161
|
*,
|
|
@@ -210,25 +165,34 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
|
|
|
210
165
|
tx_count: int,
|
|
211
166
|
barrier_storage: cute.Pointer = None,
|
|
212
167
|
cta_layout_vmnk: Optional[cute.Layout] = None,
|
|
168
|
+
mcast_mode_mn: tuple[int, int] = (1, 1),
|
|
169
|
+
defer_sync: bool = False,
|
|
213
170
|
producer_drop_count: Optional[Int32] = None,
|
|
171
|
+
loc=None,
|
|
172
|
+
ip=None,
|
|
214
173
|
):
|
|
215
|
-
"""
|
|
216
|
-
|
|
217
|
-
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
|
|
218
|
-
:type barrier_storage: cute.Pointer
|
|
174
|
+
"""Creates and initializes a new PipelineTmaUmma instance.
|
|
175
|
+
|
|
219
176
|
:param num_stages: Number of buffer stages for this pipeline
|
|
220
|
-
:type num_stages:
|
|
221
|
-
:param producer_group:
|
|
177
|
+
:type num_stages: int
|
|
178
|
+
:param producer_group: CooperativeGroup for the producer agent
|
|
222
179
|
:type producer_group: CooperativeGroup
|
|
223
|
-
:param consumer_group:
|
|
180
|
+
:param consumer_group: CooperativeGroup for the consumer agent
|
|
224
181
|
:type consumer_group: CooperativeGroup
|
|
225
182
|
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
|
|
226
183
|
:type tx_count: int
|
|
184
|
+
:param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers
|
|
185
|
+
:type barrier_storage: cute.Pointer, optional
|
|
227
186
|
:param cta_layout_vmnk: Layout of the cluster shape
|
|
228
|
-
:type cta_layout_vmnk: cute.Layout
|
|
187
|
+
:type cta_layout_vmnk: cute.Layout, optional
|
|
188
|
+
:param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1)
|
|
189
|
+
:type mcast_mode_mn: tuple[int, int], optional
|
|
190
|
+
:raises ValueError: If barrier_storage is not a cute.Pointer instance
|
|
191
|
+
:return: A new PipelineTmaUmma instance configured with the provided parameters
|
|
192
|
+
:rtype: PipelineTmaUmma
|
|
229
193
|
"""
|
|
230
194
|
if not isinstance(barrier_storage, cute.Pointer):
|
|
231
|
-
raise
|
|
195
|
+
raise TypeError(
|
|
232
196
|
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
|
|
233
197
|
)
|
|
234
198
|
|
|
@@ -244,29 +208,42 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
|
|
|
244
208
|
producer,
|
|
245
209
|
tx_count,
|
|
246
210
|
drop_count=producer_drop_count,
|
|
211
|
+
loc=loc,
|
|
212
|
+
ip=ip,
|
|
247
213
|
)
|
|
248
|
-
sync_object_empty =
|
|
249
|
-
barrier_storage.align(min_align=8) + num_stages,
|
|
214
|
+
sync_object_empty = PipelineTmaUmma._make_sync_object(
|
|
215
|
+
barrier_storage.align(min_align=8) + num_stages,
|
|
216
|
+
num_stages,
|
|
217
|
+
consumer,
|
|
218
|
+
loc=loc,
|
|
219
|
+
ip=ip,
|
|
250
220
|
)
|
|
251
221
|
|
|
252
|
-
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
|
|
222
|
+
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1:
|
|
253
223
|
# No mcast mask if not using clusters
|
|
254
224
|
producer_mask = None
|
|
255
225
|
# All threadblocks are leaders if not using clusters
|
|
256
226
|
is_leader_cta = True
|
|
257
227
|
else:
|
|
258
|
-
producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(
|
|
259
|
-
|
|
228
|
+
producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(
|
|
229
|
+
cta_layout_vmnk, mcast_mode_mn, loc=loc, ip=ip
|
|
230
|
+
)
|
|
231
|
+
is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk, loc=loc, ip=ip)
|
|
260
232
|
|
|
261
233
|
cta_group = (
|
|
262
234
|
cute.nvgpu.tcgen05.CtaGroup.ONE
|
|
263
|
-
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
|
|
235
|
+
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0], loc=loc, ip=ip) == 1
|
|
264
236
|
else cute.nvgpu.tcgen05.CtaGroup.TWO
|
|
265
237
|
)
|
|
266
238
|
|
|
267
239
|
consumer_mask = producer_mask
|
|
268
240
|
|
|
269
|
-
|
|
241
|
+
if not defer_sync:
|
|
242
|
+
cute.arch.mbarrier_init_fence()
|
|
243
|
+
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1:
|
|
244
|
+
agent_sync(Agent.ThreadBlock)
|
|
245
|
+
else:
|
|
246
|
+
agent_sync(Agent.ThreadBlockCluster, is_relaxed=True)
|
|
270
247
|
|
|
271
248
|
return PipelineTmaCpAsyncUmma(
|
|
272
249
|
sync_object_full,
|
|
@@ -278,11 +255,15 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
|
|
|
278
255
|
cta_group,
|
|
279
256
|
)
|
|
280
257
|
|
|
258
|
+
@dsl_user_op
|
|
281
259
|
def producer_acquire(
|
|
282
260
|
self,
|
|
283
261
|
state: PipelineState,
|
|
284
262
|
try_acquire_token: Optional[Boolean] = None,
|
|
285
263
|
is_tma_warp: Optional[Boolean] = True,
|
|
264
|
+
*,
|
|
265
|
+
loc=None,
|
|
266
|
+
ip=None,
|
|
286
267
|
):
|
|
287
268
|
"""
|
|
288
269
|
TMA producer commit conditionally waits on buffer empty and sets the
|
|
@@ -290,17 +271,24 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
|
|
|
290
271
|
"""
|
|
291
272
|
if_generate(
|
|
292
273
|
try_acquire_token is None or try_acquire_token == 0,
|
|
293
|
-
lambda: self.sync_object_empty.wait(state.index, state.phase),
|
|
274
|
+
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
|
|
275
|
+
loc=loc,
|
|
276
|
+
ip=ip,
|
|
294
277
|
)
|
|
295
278
|
# This is the difference between this and PipelineTmaAsync: we could have multiple
|
|
296
279
|
# warps calling this, but only 1 warp should do the arrive on the full barrier
|
|
297
280
|
if_generate(
|
|
298
281
|
and_(self.is_leader_cta, is_tma_warp),
|
|
299
|
-
lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
|
|
282
|
+
lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
|
|
283
|
+
loc=loc,
|
|
284
|
+
ip=ip,
|
|
300
285
|
)
|
|
301
286
|
|
|
302
|
-
|
|
287
|
+
@dsl_user_op
|
|
288
|
+
def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
|
|
303
289
|
"""
|
|
304
290
|
We need the mbarrier to track the completion of cp.async
|
|
305
291
|
"""
|
|
306
|
-
cute.arch.cp_async_mbarrier_arrive_noinc(
|
|
292
|
+
cute.arch.cp_async_mbarrier_arrive_noinc(
|
|
293
|
+
self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip
|
|
294
|
+
)
|
quack/reduce.py
CHANGED
|
@@ -196,9 +196,9 @@ def online_softmax_reduce(
|
|
|
196
196
|
)
|
|
197
197
|
cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
|
|
198
198
|
num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
|
|
199
|
-
max_x_single_warp = cute.
|
|
199
|
+
max_x_single_warp = cute.make_rmem_tensor(num_iter, Float32)
|
|
200
200
|
max_x_single_warp.fill(-Float32.inf)
|
|
201
|
-
sum_exp_x_single_warp = cute.
|
|
201
|
+
sum_exp_x_single_warp = cute.make_rmem_tensor(num_iter, Float32)
|
|
202
202
|
sum_exp_x_single_warp.fill(0.0)
|
|
203
203
|
for i in cutlass.range_constexpr(num_iter):
|
|
204
204
|
idx = lane_idx + i * cute.arch.WARP_SIZE
|
quack/rmsnorm.py
CHANGED
|
@@ -686,9 +686,7 @@ class RMSNormBackward(ReductionBase):
|
|
|
686
686
|
|
|
687
687
|
if const_expr(self.cluster_n > 1):
|
|
688
688
|
# Need this fence since the STAS from the producer is using the async proxy.
|
|
689
|
-
cute.arch.
|
|
690
|
-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
|
691
|
-
)
|
|
689
|
+
cute.arch.fence_view_async_shared()
|
|
692
690
|
# It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
|
|
693
691
|
# Requires adjusting the thread_count when initializing the mbar
|
|
694
692
|
cute.arch.sync_warp()
|
quack/sm90_utils.py
CHANGED
|
@@ -27,10 +27,11 @@ def make_smem_layout(
|
|
|
27
27
|
sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size),
|
|
28
28
|
dtype,
|
|
29
29
|
)
|
|
30
|
+
order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2)
|
|
30
31
|
smem_layout_staged = cute.tile_to_shape(
|
|
31
32
|
smem_layout_atom,
|
|
32
33
|
cute.append(shape, stage) if const_expr(stage is not None) else shape,
|
|
33
|
-
order=(
|
|
34
|
+
order=order if const_expr(stage is not None) else order[:2],
|
|
34
35
|
)
|
|
35
36
|
return smem_layout_staged
|
|
36
37
|
|
|
@@ -101,7 +102,7 @@ def gemm_zero_init(
|
|
|
101
102
|
tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False
|
|
102
103
|
)
|
|
103
104
|
else:
|
|
104
|
-
acc = cute.
|
|
105
|
+
acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C(shape), Float32)
|
|
105
106
|
rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
|
|
106
107
|
rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
|
|
107
108
|
gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait)
|
|
@@ -125,3 +126,34 @@ def gemm_w_idx(
|
|
|
125
126
|
rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
|
|
126
127
|
rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
|
|
127
128
|
gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def partition_fragment_ABC(
|
|
132
|
+
thr_mma: cute.ThrMma,
|
|
133
|
+
shape_mnk: cute.Shape,
|
|
134
|
+
sA: Optional[cute.Tensor],
|
|
135
|
+
sB: Optional[cute.Tensor],
|
|
136
|
+
swap_AB: bool = False,
|
|
137
|
+
):
|
|
138
|
+
is_rs = thr_mma.op.a_src == warpgroup.OperandSource.RMEM
|
|
139
|
+
if const_expr(not swap_AB):
|
|
140
|
+
acc = cute.make_rmem_tensor(thr_mma.partition_shape_C(shape_mnk[:2]), Float32)
|
|
141
|
+
if const_expr(not is_rs):
|
|
142
|
+
assert sA is not None
|
|
143
|
+
tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA))
|
|
144
|
+
else:
|
|
145
|
+
tCrA = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[0], shape_mnk[2])))
|
|
146
|
+
assert sB is not None
|
|
147
|
+
tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB))
|
|
148
|
+
else:
|
|
149
|
+
acc = cute.make_rmem_tensor(
|
|
150
|
+
thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32
|
|
151
|
+
)
|
|
152
|
+
if const_expr(not is_rs):
|
|
153
|
+
assert sB is not None
|
|
154
|
+
tCrB = thr_mma.make_fragment_A(thr_mma.partition_A(sB))
|
|
155
|
+
else: # B in rmem
|
|
156
|
+
tCrB = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[1], shape_mnk[2])))
|
|
157
|
+
assert sA is not None
|
|
158
|
+
tCrA = thr_mma.make_fragment_B(thr_mma.partition_B(sA))
|
|
159
|
+
return acc, tCrA, tCrB
|
quack/sort/bitonic_sort.py
CHANGED
|
@@ -83,7 +83,7 @@ def bitonic_topk_merge(
|
|
|
83
83
|
else:
|
|
84
84
|
minmax_fn = min if ascending else max
|
|
85
85
|
# Write the top k elements to the first half of the array
|
|
86
|
-
for i in cutlass.range(k,
|
|
86
|
+
for i in cutlass.range(k, unroll_full=True):
|
|
87
87
|
arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i])
|
|
88
88
|
# Now the 1st half is bitonic, we just need to merge it
|
|
89
89
|
bitonic_merge(arr0, k, start0, ascending)
|
|
@@ -108,12 +108,12 @@ def bitonic_topk(
|
|
|
108
108
|
n = cute.size(arr.shape)
|
|
109
109
|
assert k == 1 << int(math.log2(k)), "k must be a power of 2"
|
|
110
110
|
assert n % k == 0, "n must be divisible by k"
|
|
111
|
-
topk_vals = cute.
|
|
111
|
+
topk_vals = cute.make_rmem_tensor(k, arr.element_type)
|
|
112
112
|
for v in cutlass.range(k, unroll_full=True):
|
|
113
113
|
topk_vals[v] = arr[v]
|
|
114
114
|
bitonic_sort(topk_vals, ascending=ascending)
|
|
115
115
|
for i in cutlass.range(1, n // k, unroll_full=True):
|
|
116
|
-
other_vals = cute.
|
|
116
|
+
other_vals = cute.make_rmem_tensor(k, arr.element_type)
|
|
117
117
|
for v in cutlass.range(k, unroll_full=True):
|
|
118
118
|
other_vals[v] = arr[i * k + v]
|
|
119
119
|
bitonic_sort(other_vals, ascending=ascending)
|
|
@@ -122,7 +122,7 @@ def bitonic_topk(
|
|
|
122
122
|
# TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps
|
|
123
123
|
# do duplicate work.
|
|
124
124
|
for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True):
|
|
125
|
-
other_vals = cute.
|
|
125
|
+
other_vals = cute.make_rmem_tensor(k, arr.element_type)
|
|
126
126
|
for v in cutlass.range(k, unroll_full=True):
|
|
127
127
|
other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i)
|
|
128
128
|
bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
|