quack-kernels 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
quack/rmsnorm.py CHANGED
@@ -161,30 +161,33 @@ class RMSNorm(ReductionBase):
161
161
  copy_atom_load_X_async = cute.make_copy_atom(
162
162
  cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
163
163
  )
164
+ num_bits_per_copy_W = cutlass.const_expr(
165
+ min(128, 128 // mX.element_type.width * mW.element_type.width)
166
+ )
164
167
  copy_atom_load_W = cute.make_copy_atom(
165
- cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
168
+ cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_bits_per_copy_W
169
+ )
170
+ num_bits_per_copy_O = cutlass.const_expr(
171
+ min(128, 128 // mX.element_type.width * mO.element_type.width)
166
172
  )
167
173
  copy_atom_store_O = cute.make_copy_atom(
168
- cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
174
+ cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_bits_per_copy_O
169
175
  )
170
176
 
171
177
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
172
178
  tidx
173
179
  )
174
- thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
175
- thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
176
180
 
177
- tWgW = thr_copy_W.partition_S(gW)
181
+ tXgW = thr_copy_X.partition_S(gW)
178
182
  tXgX = thr_copy_X.partition_S(gX)
179
183
  tXsX = thr_copy_X.partition_D(sX)
180
- tXgO = thr_copy_O.partition_D(gO)
181
- tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
184
+ tXgO = thr_copy_X.partition_D(gO)
185
+ tXrRstd = thr_copy_X.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
182
186
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
183
187
 
184
188
  # allocate fragments for gmem->rmem
185
- tWrW = cute.make_fragment_like(tWgW)
186
- tWrW.fill(0.0)
187
- tXrW = thr_copy_X.retile(tWrW)
189
+ tXrW = cute.make_fragment_like(tXgW)
190
+ tXrW.fill(0.0)
188
191
  tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
189
192
 
190
193
  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
@@ -196,9 +199,9 @@ class RMSNorm(ReductionBase):
196
199
  cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
197
200
  cute.arch.cp_async_commit_group()
198
201
 
199
- tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
202
+ tXpW = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
200
203
  if cutlass.const_expr(not delay_w_load):
201
- cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
204
+ cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
202
205
 
203
206
  cute.arch.cp_async_wait_group(0)
204
207
  cute.autovec_copy(tXsX, tXrX)
@@ -223,7 +226,7 @@ class RMSNorm(ReductionBase):
223
226
  ):
224
227
  tXrRstd[0] = rstd
225
228
  if cutlass.const_expr(delay_w_load):
226
- cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
229
+ cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
227
230
  if cutlass.const_expr(reload_from == "smem"):
228
231
  cute.autovec_copy(tXsX, tXrX)
229
232
  x = tXrX.load().to(cute.Float32)
@@ -234,9 +237,9 @@ class RMSNorm(ReductionBase):
234
237
  w = tXrW.load().to(cute.Float32)
235
238
  y = x_hat * w
236
239
  tXrO.store(y.to(tXrO.element_type))
237
- tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
240
+ tXpO = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
238
241
  if row < shape[0]:
239
- cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
242
+ cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpO)
240
243
 
241
244
 
242
245
  def _rmsnorm_fwd(
@@ -460,39 +463,41 @@ class RMSNormBackward(ReductionBase):
460
463
  copy_atom_load_X_async = cute.make_copy_atom(
461
464
  cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
462
465
  )
466
+ num_bits_per_copy_W = cutlass.const_expr(
467
+ min(128, 128 // mX.element_type.width * mW.element_type.width)
468
+ )
463
469
  copy_atom_load_W = cute.make_copy_atom(
464
- cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
470
+ cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_bits_per_copy_W
471
+ )
472
+ num_bits_per_copy_dX = cutlass.const_expr(
473
+ min(128, 128 // mX.element_type.width * mdX.element_type.width)
465
474
  )
466
475
  copy_atom_store_dX = cute.make_copy_atom(
467
- cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=128
476
+ cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_bits_per_copy_dX
477
+ )
478
+ num_bits_per_copy_dW = cutlass.const_expr(
479
+ min(128, 128 // mX.element_type.width * mdW.element_type.width)
468
480
  )
469
481
  copy_atom_store_dW = cute.make_copy_atom(
470
- cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=128
482
+ cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_bits_per_copy_dW
471
483
  )
472
484
 
473
485
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
474
- thr_copy_X_async = cute.make_tiled_copy(
475
- copy_atom_load_X_async, tv_layout, tiler_mn
476
- ).get_slice(tidx)
477
- thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
478
- thr_copy_dW = cute.make_tiled_copy(copy_atom_store_dW, tv_layout, tiler_mn).get_slice(tidx)
479
- thr_store_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
480
486
 
481
487
  gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
482
- tWgW = thr_copy_W.partition_S(gW)
483
- tWrW = cute.make_fragment_like(tWgW)
488
+ tXgW = thr_copy_X.partition_S(gW)
489
+ tXrW = cute.make_fragment_like(tXgW)
484
490
  # Need this, otherwise rW can have arbitrary values that changes the reduction
485
491
  if not is_even_N:
486
- tWrW.fill(0.0)
487
- tXrW = thr_copy_X.retile(tWrW)
492
+ tXrW.fill(0.0)
488
493
 
489
494
  gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
490
- tWpW = (
491
- utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
495
+ tXpW = (
496
+ utils.predicate_k(thr_copy_X.partition_S(gW_coord), limit=shape[1])
492
497
  if not is_even_N
493
498
  else None
494
499
  )
495
- cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
500
+ cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
496
501
  weight = tXrW.load().to(cute.Float32)
497
502
 
498
503
  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
@@ -500,17 +505,16 @@ class RMSNormBackward(ReductionBase):
500
505
  self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
501
506
 
502
507
  dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
503
- tdWpdW = (
504
- utils.predicate_k(thr_copy_dW.partition_S(dw_coord), limit=shape[1])
508
+ tXpdW = (
509
+ utils.predicate_k(thr_copy_X.partition_S(dw_coord), limit=shape[1])
505
510
  if not is_even_N
506
511
  else None
507
512
  )
508
513
 
509
514
  gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
510
- tdWgdW = thr_copy_dW.partition_D(gdW)
515
+ tXgdW = thr_copy_X.partition_S(gdW)
511
516
  # Always compute partial weight gradients in fp32
512
- tdWrdW = cute.make_fragment_like(tdWgdW, Float32)
513
- tXrdW = thr_copy_X.retile(tdWrdW)
517
+ tXrdW = cute.make_fragment_like(tXgdW, Float32)
514
518
 
515
519
  gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
516
520
  gdOut = cute.local_tile(mdOut, tiler_mn, (None, cluster_y))
@@ -520,7 +524,7 @@ class RMSNormBackward(ReductionBase):
520
524
  tXsX = thr_copy_X.partition_D(sX)
521
525
  tXgdOut = thr_copy_X.partition_S(gdOut)
522
526
  tXsdOut = thr_copy_X.partition_D(sdOut)
523
- tXgdX = thr_store_dX.partition_D(gdX)
527
+ tXgdX = thr_copy_X.partition_D(gdX)
524
528
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
525
529
  # This doesn't change across iterations
526
530
  tXpX = (
@@ -670,11 +674,10 @@ class RMSNormBackward(ReductionBase):
670
674
  tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
671
675
  cute.autovec_copy(tXsdW_other, tXrdW_other)
672
676
  tXrdW.store(tXrdW.load() + tXrdW_other.load())
673
- cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
674
-
677
+ cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
675
678
  else:
676
679
  # dw is already in fp32, so we can directly copy to global memory
677
- cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
680
+ cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
678
681
 
679
682
 
680
683
  def _rmsnorm_backward(
quack/utils.py CHANGED
@@ -315,7 +315,7 @@ def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
315
315
  if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
316
316
  res = cute.make_fragment(x.shape, Float32)
317
317
  res.store(x)
318
- for i in cutlass.range_constexpr(cute.size(x.shape)):
318
+ for i in cutlass.range(cute.size(x.shape), unroll_full=True):
319
319
  res[i] = cute.arch.exp2(res[i])
320
320
  return res.load()
321
321
  else:
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: quack-kernels
3
- Version: 0.1.8
4
- Requires-Python: >=3.9
3
+ Version: 0.1.10
4
+ Requires-Python: >=3.12
5
5
  License-File: LICENSE
6
- Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
6
+ Requires-Dist: nvidia-cutlass-dsl==4.1.0
7
7
  Requires-Dist: torch
8
8
  Provides-Extra: dev
9
9
  Requires-Dist: pre-commit; extra == "dev"
@@ -0,0 +1,13 @@
1
+ quack/__init__.py,sha256=4tLchTx7d0d1ZVg6psRjjoXAWKHqzIWRF5mUk8ZdgkQ,204
2
+ quack/cross_entropy.py,sha256=xsg2bXZ4wNvusBARhN4PwAzm5PbejEcfwj71nR7bzuE,20852
3
+ quack/dense_gemm_sm90.py,sha256=jULXfAQkRh1SUAOpesx8wouY-GLDCm05Fb5LynozSl8,59932
4
+ quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
5
+ quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
6
+ quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
7
+ quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
8
+ quack/utils.py,sha256=RZq-7YA8UMUizHpVyZM1we4zGm9NaC178M2g2HXdjmE,17799
9
+ quack_kernels-0.1.10.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
+ quack_kernels-0.1.10.dist-info/METADATA,sha256=baMTwibt6u0IQb8YJFFhCY0RD3Aervf5sl6EpYF6IQ8,286
11
+ quack_kernels-0.1.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
12
+ quack_kernels-0.1.10.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
13
+ quack_kernels-0.1.10.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- quack/__init__.py,sha256=tDgX5MF1ttfEyDVFWi47DA8tDooYcBQlkuzvabGUoQI,203
2
- quack/cross_entropy.py,sha256=VYSAd28GmtnMoKQwLrorvySDtJfRhoqVd-aeM52FmsI,20866
3
- quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
4
- quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
5
- quack/rmsnorm.py,sha256=-qrKqPKk0fUuq0a5-vJmZZ7nQsHgyaqTg0EKhWT44r0,32738
6
- quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
7
- quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
8
- quack_kernels-0.1.8.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
9
- quack_kernels-0.1.8.dist-info/METADATA,sha256=b_2PxFEoVqWJbT2FtuP9FJyF-jpL2Z3q9OHoOEipqo4,289
10
- quack_kernels-0.1.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
11
- quack_kernels-0.1.8.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
12
- quack_kernels-0.1.8.dist-info/RECORD,,