quack-kernels 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/rmsnorm.py +43 -40
- {quack_kernels-0.1.8.dist-info → quack_kernels-0.1.9.dist-info}/METADATA +1 -1
- {quack_kernels-0.1.8.dist-info → quack_kernels-0.1.9.dist-info}/RECORD +7 -7
- {quack_kernels-0.1.8.dist-info → quack_kernels-0.1.9.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.8.dist-info → quack_kernels-0.1.9.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.8.dist-info → quack_kernels-0.1.9.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/rmsnorm.py
CHANGED
|
@@ -161,30 +161,33 @@ class RMSNorm(ReductionBase):
|
|
|
161
161
|
copy_atom_load_X_async = cute.make_copy_atom(
|
|
162
162
|
cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
|
|
163
163
|
)
|
|
164
|
+
num_bits_per_copy_W = cutlass.const_expr(
|
|
165
|
+
min(128, 128 // mX.element_type.width * mW.element_type.width)
|
|
166
|
+
)
|
|
164
167
|
copy_atom_load_W = cute.make_copy_atom(
|
|
165
|
-
cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=
|
|
168
|
+
cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_bits_per_copy_W
|
|
169
|
+
)
|
|
170
|
+
num_bits_per_copy_O = cutlass.const_expr(
|
|
171
|
+
min(128, 128 // mX.element_type.width * mO.element_type.width)
|
|
166
172
|
)
|
|
167
173
|
copy_atom_store_O = cute.make_copy_atom(
|
|
168
|
-
cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=
|
|
174
|
+
cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_bits_per_copy_O
|
|
169
175
|
)
|
|
170
176
|
|
|
171
177
|
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
|
|
172
178
|
tidx
|
|
173
179
|
)
|
|
174
|
-
thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
|
|
175
|
-
thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
|
|
176
180
|
|
|
177
|
-
|
|
181
|
+
tXgW = thr_copy_X.partition_S(gW)
|
|
178
182
|
tXgX = thr_copy_X.partition_S(gX)
|
|
179
183
|
tXsX = thr_copy_X.partition_D(sX)
|
|
180
|
-
tXgO =
|
|
181
|
-
tXrRstd =
|
|
184
|
+
tXgO = thr_copy_X.partition_D(gO)
|
|
185
|
+
tXrRstd = thr_copy_X.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
|
|
182
186
|
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
|
|
183
187
|
|
|
184
188
|
# allocate fragments for gmem->rmem
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
tXrW = thr_copy_X.retile(tWrW)
|
|
189
|
+
tXrW = cute.make_fragment_like(tXgW)
|
|
190
|
+
tXrW.fill(0.0)
|
|
188
191
|
tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
|
|
189
192
|
|
|
190
193
|
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
|
|
@@ -196,9 +199,9 @@ class RMSNorm(ReductionBase):
|
|
|
196
199
|
cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
|
|
197
200
|
cute.arch.cp_async_commit_group()
|
|
198
201
|
|
|
199
|
-
|
|
202
|
+
tXpW = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
|
200
203
|
if cutlass.const_expr(not delay_w_load):
|
|
201
|
-
cute.copy(copy_atom_load_W,
|
|
204
|
+
cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
|
|
202
205
|
|
|
203
206
|
cute.arch.cp_async_wait_group(0)
|
|
204
207
|
cute.autovec_copy(tXsX, tXrX)
|
|
@@ -223,7 +226,7 @@ class RMSNorm(ReductionBase):
|
|
|
223
226
|
):
|
|
224
227
|
tXrRstd[0] = rstd
|
|
225
228
|
if cutlass.const_expr(delay_w_load):
|
|
226
|
-
cute.copy(copy_atom_load_W,
|
|
229
|
+
cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
|
|
227
230
|
if cutlass.const_expr(reload_from == "smem"):
|
|
228
231
|
cute.autovec_copy(tXsX, tXrX)
|
|
229
232
|
x = tXrX.load().to(cute.Float32)
|
|
@@ -234,9 +237,9 @@ class RMSNorm(ReductionBase):
|
|
|
234
237
|
w = tXrW.load().to(cute.Float32)
|
|
235
238
|
y = x_hat * w
|
|
236
239
|
tXrO.store(y.to(tXrO.element_type))
|
|
237
|
-
|
|
240
|
+
tXpO = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
|
238
241
|
if row < shape[0]:
|
|
239
|
-
cute.copy(copy_atom_store_O, tXrO, tXgO, pred=
|
|
242
|
+
cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpO)
|
|
240
243
|
|
|
241
244
|
|
|
242
245
|
def _rmsnorm_fwd(
|
|
@@ -460,39 +463,41 @@ class RMSNormBackward(ReductionBase):
|
|
|
460
463
|
copy_atom_load_X_async = cute.make_copy_atom(
|
|
461
464
|
cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
|
|
462
465
|
)
|
|
466
|
+
num_bits_per_copy_W = cutlass.const_expr(
|
|
467
|
+
min(128, 128 // mX.element_type.width * mW.element_type.width)
|
|
468
|
+
)
|
|
463
469
|
copy_atom_load_W = cute.make_copy_atom(
|
|
464
|
-
cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=
|
|
470
|
+
cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_bits_per_copy_W
|
|
471
|
+
)
|
|
472
|
+
num_bits_per_copy_dX = cutlass.const_expr(
|
|
473
|
+
min(128, 128 // mX.element_type.width * mdX.element_type.width)
|
|
465
474
|
)
|
|
466
475
|
copy_atom_store_dX = cute.make_copy_atom(
|
|
467
|
-
cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=
|
|
476
|
+
cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_bits_per_copy_dX
|
|
477
|
+
)
|
|
478
|
+
num_bits_per_copy_dW = cutlass.const_expr(
|
|
479
|
+
min(128, 128 // mX.element_type.width * mdW.element_type.width)
|
|
468
480
|
)
|
|
469
481
|
copy_atom_store_dW = cute.make_copy_atom(
|
|
470
|
-
cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=
|
|
482
|
+
cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_bits_per_copy_dW
|
|
471
483
|
)
|
|
472
484
|
|
|
473
485
|
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
|
|
474
|
-
thr_copy_X_async = cute.make_tiled_copy(
|
|
475
|
-
copy_atom_load_X_async, tv_layout, tiler_mn
|
|
476
|
-
).get_slice(tidx)
|
|
477
|
-
thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
|
|
478
|
-
thr_copy_dW = cute.make_tiled_copy(copy_atom_store_dW, tv_layout, tiler_mn).get_slice(tidx)
|
|
479
|
-
thr_store_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
|
|
480
486
|
|
|
481
487
|
gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
|
|
482
|
-
|
|
483
|
-
|
|
488
|
+
tXgW = thr_copy_X.partition_S(gW)
|
|
489
|
+
tXrW = cute.make_fragment_like(tXgW)
|
|
484
490
|
# Need this, otherwise rW can have arbitrary values that changes the reduction
|
|
485
491
|
if not is_even_N:
|
|
486
|
-
|
|
487
|
-
tXrW = thr_copy_X.retile(tWrW)
|
|
492
|
+
tXrW.fill(0.0)
|
|
488
493
|
|
|
489
494
|
gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
|
|
490
|
-
|
|
491
|
-
utils.predicate_k(
|
|
495
|
+
tXpW = (
|
|
496
|
+
utils.predicate_k(thr_copy_X.partition_S(gW_coord), limit=shape[1])
|
|
492
497
|
if not is_even_N
|
|
493
498
|
else None
|
|
494
499
|
)
|
|
495
|
-
cute.copy(copy_atom_load_W,
|
|
500
|
+
cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
|
|
496
501
|
weight = tXrW.load().to(cute.Float32)
|
|
497
502
|
|
|
498
503
|
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
|
|
@@ -500,17 +505,16 @@ class RMSNormBackward(ReductionBase):
|
|
|
500
505
|
self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
|
|
501
506
|
|
|
502
507
|
dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
|
|
503
|
-
|
|
504
|
-
utils.predicate_k(
|
|
508
|
+
tXpdW = (
|
|
509
|
+
utils.predicate_k(thr_copy_X.partition_S(dw_coord), limit=shape[1])
|
|
505
510
|
if not is_even_N
|
|
506
511
|
else None
|
|
507
512
|
)
|
|
508
513
|
|
|
509
514
|
gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
|
|
510
|
-
|
|
515
|
+
tXgdW = thr_copy_X.partition_S(gdW)
|
|
511
516
|
# Always compute partial weight gradients in fp32
|
|
512
|
-
|
|
513
|
-
tXrdW = thr_copy_X.retile(tdWrdW)
|
|
517
|
+
tXrdW = cute.make_fragment_like(tXgdW, Float32)
|
|
514
518
|
|
|
515
519
|
gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
|
|
516
520
|
gdOut = cute.local_tile(mdOut, tiler_mn, (None, cluster_y))
|
|
@@ -520,7 +524,7 @@ class RMSNormBackward(ReductionBase):
|
|
|
520
524
|
tXsX = thr_copy_X.partition_D(sX)
|
|
521
525
|
tXgdOut = thr_copy_X.partition_S(gdOut)
|
|
522
526
|
tXsdOut = thr_copy_X.partition_D(sdOut)
|
|
523
|
-
tXgdX =
|
|
527
|
+
tXgdX = thr_copy_X.partition_D(gdX)
|
|
524
528
|
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
|
|
525
529
|
# This doesn't change across iterations
|
|
526
530
|
tXpX = (
|
|
@@ -670,11 +674,10 @@ class RMSNormBackward(ReductionBase):
|
|
|
670
674
|
tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
|
|
671
675
|
cute.autovec_copy(tXsdW_other, tXrdW_other)
|
|
672
676
|
tXrdW.store(tXrdW.load() + tXrdW_other.load())
|
|
673
|
-
cute.copy(copy_atom_store_dW,
|
|
674
|
-
|
|
677
|
+
cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
|
|
675
678
|
else:
|
|
676
679
|
# dw is already in fp32, so we can directly copy to global memory
|
|
677
|
-
cute.copy(copy_atom_store_dW,
|
|
680
|
+
cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
|
|
678
681
|
|
|
679
682
|
|
|
680
683
|
def _rmsnorm_backward(
|
|
@@ -1,12 +1,12 @@
|
|
|
1
|
-
quack/__init__.py,sha256=
|
|
1
|
+
quack/__init__.py,sha256=CT76CeRNh5bzQ9f13yVuRz9Sj7V3MvwzHH4fB1iQIf0,203
|
|
2
2
|
quack/cross_entropy.py,sha256=VYSAd28GmtnMoKQwLrorvySDtJfRhoqVd-aeM52FmsI,20866
|
|
3
3
|
quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
|
|
4
4
|
quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
|
|
5
|
-
quack/rmsnorm.py,sha256
|
|
5
|
+
quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
|
|
6
6
|
quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
|
|
7
7
|
quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
|
|
8
|
-
quack_kernels-0.1.
|
|
9
|
-
quack_kernels-0.1.
|
|
10
|
-
quack_kernels-0.1.
|
|
11
|
-
quack_kernels-0.1.
|
|
12
|
-
quack_kernels-0.1.
|
|
8
|
+
quack_kernels-0.1.9.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
9
|
+
quack_kernels-0.1.9.dist-info/METADATA,sha256=vOnpbShNHRiUXKAnOUxzfRM7zkpW3RmjW4hIgvYda08,289
|
|
10
|
+
quack_kernels-0.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
11
|
+
quack_kernels-0.1.9.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
|
|
12
|
+
quack_kernels-0.1.9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|