quack-kernels 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/cross_entropy.py +12 -9
- quack/reduction_base.py +2 -2
- quack/rmsnorm.py +13 -12
- quack/softmax.py +25 -17
- quack/utils.py +15 -12
- {quack_kernels-0.1.3.dist-info → quack_kernels-0.1.4.dist-info}/METADATA +2 -2
- quack_kernels-0.1.4.dist-info/RECORD +11 -0
- quack_kernels-0.1.3.dist-info/RECORD +0 -11
- {quack_kernels-0.1.3.dist-info → quack_kernels-0.1.4.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.3.dist-info → quack_kernels-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.3.dist-info → quack_kernels-0.1.4.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/cross_entropy.py
CHANGED
|
@@ -77,7 +77,7 @@ class CrossEntropy(ReductionBase):
|
|
|
77
77
|
self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
|
|
78
78
|
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
79
79
|
block=[num_threads, 1, 1],
|
|
80
|
-
cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
|
|
80
|
+
cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
81
81
|
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
82
82
|
stream=stream,
|
|
83
83
|
)
|
|
@@ -93,15 +93,16 @@ class CrossEntropy(ReductionBase):
|
|
|
93
93
|
tiler_mn: cute.Shape,
|
|
94
94
|
):
|
|
95
95
|
tidx, _, _ = cute.arch.thread_idx()
|
|
96
|
-
bidx,
|
|
96
|
+
bidx, _, _ = cute.arch.block_idx()
|
|
97
|
+
if cutlass.const_expr(self.cluster_n > 1):
|
|
98
|
+
cluster_y = cute.arch.block_idx()[1]
|
|
99
|
+
else:
|
|
100
|
+
cluster_y = cutlass.const_expr(0)
|
|
97
101
|
|
|
98
102
|
shape: cute.Shape = mX.shape
|
|
99
103
|
idX = cute.make_identity_tensor(shape)
|
|
100
104
|
# slice for CTAs
|
|
101
|
-
gX, cX = [
|
|
102
|
-
cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
|
|
103
|
-
for mT in (mX, idX)
|
|
104
|
-
]
|
|
105
|
+
gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
|
|
105
106
|
|
|
106
107
|
smem = cutlass.utils.SmemAllocator()
|
|
107
108
|
sX = smem.allocate_tensor(
|
|
@@ -131,7 +132,9 @@ class CrossEntropy(ReductionBase):
|
|
|
131
132
|
|
|
132
133
|
is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
|
|
133
134
|
tXpX = (
|
|
134
|
-
utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
|
135
|
+
utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
|
136
|
+
if cutlass.const_expr(not is_even_N)
|
|
137
|
+
else None
|
|
135
138
|
)
|
|
136
139
|
if row < shape[0]:
|
|
137
140
|
cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
|
|
@@ -154,7 +157,7 @@ class CrossEntropy(ReductionBase):
|
|
|
154
157
|
cute.ReductionOp.MAX,
|
|
155
158
|
threads_per_row,
|
|
156
159
|
reduction_buffer[None, None, 0],
|
|
157
|
-
mbar_ptr + 0 if self.cluster_n > 1 else None,
|
|
160
|
+
mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
158
161
|
init_val=-cutlass.Float32.inf,
|
|
159
162
|
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
160
163
|
)
|
|
@@ -172,7 +175,7 @@ class CrossEntropy(ReductionBase):
|
|
|
172
175
|
cute.ReductionOp.ADD,
|
|
173
176
|
threads_per_row,
|
|
174
177
|
reduction_buffer[None, None, 1],
|
|
175
|
-
mbar_ptr + 1 if self.cluster_n > 1 else None,
|
|
178
|
+
mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
176
179
|
init_val=0.0,
|
|
177
180
|
)
|
|
178
181
|
else:
|
quack/reduction_base.py
CHANGED
|
@@ -88,10 +88,10 @@ class ReductionBase:
|
|
|
88
88
|
def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
|
|
89
89
|
if cutlass.const_expr(self.cluster_n > 1):
|
|
90
90
|
if tidx < self.stage:
|
|
91
|
-
cute.arch.
|
|
91
|
+
cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
|
|
92
92
|
cute.arch.mbarrier_init_fence()
|
|
93
93
|
if tidx < self.stage:
|
|
94
|
-
cute.arch.
|
|
94
|
+
cute.arch.mbarrier_arrive_and_expect_tx(
|
|
95
95
|
mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
|
|
96
96
|
)
|
|
97
97
|
# Cluster arrive after barrier init
|
quack/rmsnorm.py
CHANGED
|
@@ -84,7 +84,7 @@ class RMSNorm(ReductionBase):
|
|
|
84
84
|
self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
|
|
85
85
|
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
86
86
|
block=[num_threads, 1, 1],
|
|
87
|
-
cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
|
|
87
|
+
cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
88
88
|
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
89
89
|
stream=stream,
|
|
90
90
|
)
|
|
@@ -103,7 +103,11 @@ class RMSNorm(ReductionBase):
|
|
|
103
103
|
delay_w_load: cutlass.Constexpr = False,
|
|
104
104
|
):
|
|
105
105
|
tidx, _, _ = cute.arch.thread_idx()
|
|
106
|
-
bidx,
|
|
106
|
+
bidx, _, _ = cute.arch.block_idx()
|
|
107
|
+
if cutlass.const_expr(self.cluster_n > 1):
|
|
108
|
+
cluster_y = cute.arch.block_idx()[1]
|
|
109
|
+
else:
|
|
110
|
+
cluster_y = cutlass.const_expr(0)
|
|
107
111
|
|
|
108
112
|
smem = cutlass.utils.SmemAllocator()
|
|
109
113
|
sX = smem.allocate_tensor(
|
|
@@ -114,13 +118,10 @@ class RMSNorm(ReductionBase):
|
|
|
114
118
|
shape = mX.shape
|
|
115
119
|
idX = cute.make_identity_tensor(shape)
|
|
116
120
|
# slice for CTAs
|
|
117
|
-
gX, gO, cX = [
|
|
118
|
-
|
|
119
|
-
for mT in (mX, mO, idX)
|
|
120
|
-
]
|
|
121
|
-
gW = cute.local_tile(mW, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
|
|
121
|
+
gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
|
|
122
|
+
gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
|
|
122
123
|
gRstd = (
|
|
123
|
-
cute.local_tile(mRstd, tiler_mn, (bidx,
|
|
124
|
+
cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
|
|
124
125
|
if cutlass.const_expr(mRstd is not None)
|
|
125
126
|
else None
|
|
126
127
|
)
|
|
@@ -167,7 +168,7 @@ class RMSNorm(ReductionBase):
|
|
|
167
168
|
cute.arch.cp_async_commit_group()
|
|
168
169
|
|
|
169
170
|
tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
|
|
170
|
-
if not delay_w_load:
|
|
171
|
+
if cutlass.const_expr(not delay_w_load):
|
|
171
172
|
cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
|
|
172
173
|
|
|
173
174
|
cute.arch.cp_async_wait_group(0)
|
|
@@ -192,12 +193,12 @@ class RMSNorm(ReductionBase):
|
|
|
192
193
|
and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
|
|
193
194
|
):
|
|
194
195
|
tXrRstd[0] = rstd
|
|
195
|
-
if delay_w_load:
|
|
196
|
+
if cutlass.const_expr(delay_w_load):
|
|
196
197
|
cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
|
|
197
|
-
if reload_from == "smem":
|
|
198
|
+
if cutlass.const_expr(reload_from == "smem"):
|
|
198
199
|
cute.autovec_copy(tXsX, tXrX)
|
|
199
200
|
x = tXrX.load().to(cute.Float32)
|
|
200
|
-
elif reload_from == "gmem":
|
|
201
|
+
elif cutlass.const_expr(reload_from == "gmem"):
|
|
201
202
|
cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
|
|
202
203
|
x = tXrX.load().to(cute.Float32)
|
|
203
204
|
x_hat = x * rstd
|
quack/softmax.py
CHANGED
|
@@ -75,7 +75,7 @@ class Softmax(ReductionBase):
|
|
|
75
75
|
self.kernel(mX, mO, tv_layout, tiler_mn).launch(
|
|
76
76
|
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
77
77
|
block=[num_threads, 1, 1],
|
|
78
|
-
cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
|
|
78
|
+
cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
79
79
|
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
80
80
|
stream=stream,
|
|
81
81
|
)
|
|
@@ -89,15 +89,16 @@ class Softmax(ReductionBase):
|
|
|
89
89
|
tiler_mn: cute.Shape,
|
|
90
90
|
):
|
|
91
91
|
tidx, _, _ = cute.arch.thread_idx()
|
|
92
|
-
bidx,
|
|
92
|
+
bidx, _, _ = cute.arch.block_idx()
|
|
93
|
+
if cutlass.const_expr(self.cluster_n > 1):
|
|
94
|
+
cluster_y = cute.arch.block_idx()[1]
|
|
95
|
+
else:
|
|
96
|
+
cluster_y = cutlass.const_expr(0)
|
|
93
97
|
|
|
94
98
|
shape = mX.shape
|
|
95
99
|
idX = cute.make_identity_tensor(shape)
|
|
96
100
|
# slice for CTAs
|
|
97
|
-
gX, gO, cX = [
|
|
98
|
-
cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
|
|
99
|
-
for mT in (mX, mO, idX)
|
|
100
|
-
]
|
|
101
|
+
gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
|
|
101
102
|
|
|
102
103
|
smem = cutlass.utils.SmemAllocator()
|
|
103
104
|
sX = smem.allocate_tensor(
|
|
@@ -129,7 +130,9 @@ class Softmax(ReductionBase):
|
|
|
129
130
|
|
|
130
131
|
is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
|
|
131
132
|
tXpX = (
|
|
132
|
-
utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
|
133
|
+
utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
|
134
|
+
if cutlass.const_expr(not is_even_N)
|
|
135
|
+
else None
|
|
133
136
|
)
|
|
134
137
|
if tXcX[0][0] < shape[0]:
|
|
135
138
|
cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
|
|
@@ -148,7 +151,7 @@ class Softmax(ReductionBase):
|
|
|
148
151
|
cute.ReductionOp.MAX,
|
|
149
152
|
threads_per_row,
|
|
150
153
|
reduction_buffer[None, None, 0],
|
|
151
|
-
mbar_ptr + 0 if self.cluster_n > 1 else None,
|
|
154
|
+
mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
152
155
|
init_val=-cutlass.Float32.inf,
|
|
153
156
|
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
154
157
|
)
|
|
@@ -159,7 +162,7 @@ class Softmax(ReductionBase):
|
|
|
159
162
|
cute.ReductionOp.ADD,
|
|
160
163
|
threads_per_row,
|
|
161
164
|
reduction_buffer[None, None, 1],
|
|
162
|
-
mbar_ptr + 1 if self.cluster_n > 1 else None,
|
|
165
|
+
mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
163
166
|
init_val=0.0,
|
|
164
167
|
)
|
|
165
168
|
else:
|
|
@@ -174,7 +177,9 @@ class Softmax(ReductionBase):
|
|
|
174
177
|
y = exp_x * (1.0 / denom)
|
|
175
178
|
tXrO.store(y.to(tXrO.element_type))
|
|
176
179
|
tOpO = (
|
|
177
|
-
utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
|
|
180
|
+
utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
|
|
181
|
+
if cutlass.const_expr(not is_even_N)
|
|
182
|
+
else None
|
|
178
183
|
)
|
|
179
184
|
if tXcX[0][0] < shape[0]:
|
|
180
185
|
cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
|
|
@@ -283,7 +288,7 @@ class SoftmaxBackward(ReductionBase):
|
|
|
283
288
|
self.kernel(mdY, mY, mdX, tv_layout, tiler_mn).launch(
|
|
284
289
|
grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
285
290
|
block=[num_threads, 1, 1],
|
|
286
|
-
cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
|
|
291
|
+
cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
287
292
|
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
288
293
|
stream=stream,
|
|
289
294
|
)
|
|
@@ -298,14 +303,17 @@ class SoftmaxBackward(ReductionBase):
|
|
|
298
303
|
tiler_mn: cute.Shape,
|
|
299
304
|
):
|
|
300
305
|
tidx, _, _ = cute.arch.thread_idx()
|
|
301
|
-
bidx,
|
|
306
|
+
bidx, _, _ = cute.arch.block_idx()
|
|
307
|
+
if cutlass.const_expr(self.cluster_n > 1):
|
|
308
|
+
cluster_y = cute.arch.block_idx()[1]
|
|
309
|
+
else:
|
|
310
|
+
cluster_y = cutlass.const_expr(0)
|
|
302
311
|
|
|
303
312
|
shape = mdY.shape
|
|
304
313
|
idX = cute.make_identity_tensor(shape)
|
|
305
314
|
# slice for CTAs
|
|
306
315
|
gdY, gY, gdX, cX = [
|
|
307
|
-
cute.local_tile(mT, tiler_mn, (bidx,
|
|
308
|
-
for mT in (mdY, mY, mdX, idX)
|
|
316
|
+
cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mdY, mY, mdX, idX)
|
|
309
317
|
]
|
|
310
318
|
|
|
311
319
|
smem = cutlass.utils.SmemAllocator()
|
|
@@ -344,7 +352,7 @@ class SoftmaxBackward(ReductionBase):
|
|
|
344
352
|
is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
|
|
345
353
|
tdYpdY = (
|
|
346
354
|
utils.predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
|
|
347
|
-
if not is_even_N
|
|
355
|
+
if cutlass.const_expr(not is_even_N)
|
|
348
356
|
else None
|
|
349
357
|
)
|
|
350
358
|
|
|
@@ -366,7 +374,7 @@ class SoftmaxBackward(ReductionBase):
|
|
|
366
374
|
cute.ReductionOp.ADD,
|
|
367
375
|
threads_per_row,
|
|
368
376
|
reduction_buffer[None, None, 0],
|
|
369
|
-
mbar_ptr if self.cluster_n > 1 else None,
|
|
377
|
+
mbar_ptr if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
370
378
|
init_val=0.0,
|
|
371
379
|
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
372
380
|
)
|
|
@@ -376,7 +384,7 @@ class SoftmaxBackward(ReductionBase):
|
|
|
376
384
|
tdXrdX.store(dx.to(tdXrdX.element_type))
|
|
377
385
|
tdXpdX = (
|
|
378
386
|
utils.predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
|
|
379
|
-
if not is_even_N
|
|
387
|
+
if cutlass.const_expr(not is_even_N)
|
|
380
388
|
else None
|
|
381
389
|
)
|
|
382
390
|
if tXcX[0][0] < shape[0]:
|
quack/utils.py
CHANGED
|
@@ -37,19 +37,20 @@ def min_constexpr(
|
|
|
37
37
|
return a if a < b else b
|
|
38
38
|
|
|
39
39
|
|
|
40
|
+
@cute.jit
|
|
40
41
|
def warp_reduce(
|
|
41
42
|
val: cute.TensorSSA | cute.Numeric,
|
|
42
43
|
op: Callable,
|
|
43
44
|
width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
|
|
44
45
|
) -> cute.TensorSSA | cute.Numeric:
|
|
45
|
-
if isinstance(val, cute.TensorSSA):
|
|
46
|
+
if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
|
|
46
47
|
res = cute.make_fragment(val.shape, val.dtype)
|
|
47
48
|
res.store(val)
|
|
48
|
-
for i in
|
|
49
|
+
for i in cutlass.range_constexpr(cute.size(val.shape)):
|
|
49
50
|
res[i] = warp_reduce(res[i], op, width)
|
|
50
51
|
return res.load()
|
|
51
52
|
else:
|
|
52
|
-
for i in
|
|
53
|
+
for i in cutlass.range_constexpr(int(math.log2(width))):
|
|
53
54
|
val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
|
|
54
55
|
return val
|
|
55
56
|
|
|
@@ -111,15 +112,15 @@ def store_shared_remote(
|
|
|
111
112
|
remote_mbar_ptr_i32 = set_block_rank(
|
|
112
113
|
mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
|
113
114
|
).ir_value()
|
|
114
|
-
if isinstance(val, float):
|
|
115
|
+
if cutlass.const_expr(isinstance(val, float)):
|
|
115
116
|
val = Float32(val)
|
|
116
117
|
assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
|
|
117
|
-
suffix = "f32" if isinstance(val, Float32) else "s64"
|
|
118
|
+
suffix = "f32" if cutlass.const_expr(isinstance(val, Float32)) else "s64"
|
|
118
119
|
llvm.inline_asm(
|
|
119
120
|
None,
|
|
120
121
|
[remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
|
|
121
122
|
f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
|
|
122
|
-
f"r,{'f' if isinstance(val, Float32) else 'l'},r",
|
|
123
|
+
f"r,{'f' if cutlass.const_expr(isinstance(val, Float32)) else 'l'},r",
|
|
123
124
|
has_side_effects=True,
|
|
124
125
|
is_align_stack=False,
|
|
125
126
|
asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
@@ -299,6 +300,7 @@ def online_softmax_reduce(
|
|
|
299
300
|
return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
|
|
300
301
|
|
|
301
302
|
|
|
303
|
+
@cute.jit
|
|
302
304
|
def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
|
|
303
305
|
"""exp2f calculation for both vector and scalar.
|
|
304
306
|
|
|
@@ -307,10 +309,10 @@ def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
|
|
|
307
309
|
:return: exp2 value
|
|
308
310
|
:rtype: cute.TensorSSA or Float32
|
|
309
311
|
"""
|
|
310
|
-
if isinstance(x, cute.TensorSSA):
|
|
312
|
+
if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
|
|
311
313
|
res = cute.make_fragment(x.shape, Float32)
|
|
312
314
|
res.store(x)
|
|
313
|
-
for i in
|
|
315
|
+
for i in cutlass.range_constexpr(cute.size(x.shape)):
|
|
314
316
|
res[i] = cute.arch.exp2(res[i])
|
|
315
317
|
return res.load()
|
|
316
318
|
else:
|
|
@@ -347,6 +349,7 @@ def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
|
347
349
|
)
|
|
348
350
|
|
|
349
351
|
|
|
352
|
+
@cute.jit
|
|
350
353
|
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
|
|
351
354
|
# Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
|
|
352
355
|
tApA = cute.make_fragment(
|
|
@@ -356,8 +359,8 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
|
|
|
356
359
|
),
|
|
357
360
|
cutlass.Boolean,
|
|
358
361
|
)
|
|
359
|
-
for rest_v in
|
|
360
|
-
for rest_k in
|
|
362
|
+
for rest_v in cutlass.range_constexpr(tApA.shape[0]):
|
|
363
|
+
for rest_k in cutlass.range_constexpr(tApA.shape[2]):
|
|
361
364
|
tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
|
|
362
365
|
return tApA
|
|
363
366
|
|
|
@@ -373,8 +376,8 @@ def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) ->
|
|
|
373
376
|
"""
|
|
374
377
|
tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
|
|
375
378
|
tXrX_fill.fill(fill_value)
|
|
376
|
-
for rest_v in
|
|
377
|
-
for rest_k in
|
|
379
|
+
for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
|
|
380
|
+
for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
|
|
378
381
|
if not tXpX[rest_v, 0, rest_k]:
|
|
379
382
|
cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
|
|
380
383
|
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: quack-kernels
|
|
3
|
-
Version: 0.1.3
|
|
3
|
+
Version: 0.1.4
|
|
4
4
|
Requires-Python: >=3.9
|
|
5
5
|
License-File: LICENSE
|
|
6
|
-
Requires-Dist: nvidia-cutlass-dsl==4.0.
|
|
6
|
+
Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
|
|
7
7
|
Requires-Dist: torch
|
|
8
8
|
Provides-Extra: dev
|
|
9
9
|
Requires-Dist: pre-commit; extra == "dev"
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
quack/__init__.py,sha256=cFLxO6nA_faFqHf4N-Fy7G0j8ykuYPB1uOt9uoJ2dkQ,203
|
|
2
|
+
quack/cross_entropy.py,sha256=HnF2OErEzb10SWxY6HoYE42lnvlw2DsWCks7mylPwnI,9511
|
|
3
|
+
quack/reduction_base.py,sha256=Rsj9ZeSHcKAXGn1p7mY1vrrBqxevi4feLjY0JJhKnmY,3663
|
|
4
|
+
quack/rmsnorm.py,sha256=TkOZsXJwcsoZMLnmEWQ-pEF0r-iiZhGrCNLSFCXfv6s,10676
|
|
5
|
+
quack/softmax.py,sha256=VfhlC2huRuv7olFSVFgS8LF1yF8TFV64yjjjQxYX9yk,16364
|
|
6
|
+
quack/utils.py,sha256=zVc9U-5No19trE585KqDdXx9chAruXPRIPMZdO7mkRg,15603
|
|
7
|
+
quack_kernels-0.1.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
8
|
+
quack_kernels-0.1.4.dist-info/METADATA,sha256=xl62C5WFgiUbnOICAzjldsljJ9j1Fb_JxZVksHLCI8I,289
|
|
9
|
+
quack_kernels-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
10
|
+
quack_kernels-0.1.4.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
|
|
11
|
+
quack_kernels-0.1.4.dist-info/RECORD,,
|
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
quack/__init__.py,sha256=aUR7drzgaqmbzw9H_eoFselMUVQVF3BHc9VOzZg5d-Q,203
|
|
2
|
-
quack/cross_entropy.py,sha256=_Xlyifd_YS8LaYxYlZEsuBfsi8zTH4At3i9DDggGCf8,9319
|
|
3
|
-
quack/reduction_base.py,sha256=nrRsXwTpLVQkPp2Gr_FgHRPnifqkMHRodve5ciHzx58,3667
|
|
4
|
-
quack/rmsnorm.py,sha256=YqGTTKHHXYzw3xnnjBRfaN9TDlhG8D_fSI9CHKAU40A,10548
|
|
5
|
-
quack/softmax.py,sha256=mWaUfaY6PBtO1ioYxXxS-yodQmcBNGasWVMUg9G066Y,15938
|
|
6
|
-
quack/utils.py,sha256=1-HMcFTEvGdAtqC3ucQGZ3DLa_PoJQsqwYlKd9bcXO8,15347
|
|
7
|
-
quack_kernels-0.1.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
8
|
-
quack_kernels-0.1.3.dist-info/METADATA,sha256=DDuEKHLjFx9dFTQV5YtXsnKVFZVoueO7NwhcwOtpw6g,284
|
|
9
|
-
quack_kernels-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
10
|
-
quack_kernels-0.1.3.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
|
|
11
|
-
quack_kernels-0.1.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|