quack-kernels 0.1.6.tar.gz → 0.1.7.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {quack_kernels-0.1.6/quack_kernels.egg-info → quack_kernels-0.1.7}/PKG-INFO +1 -1
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/README.md +1 -2
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/quack/__init__.py +1 -1
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/quack/reduction_base.py +16 -8
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/quack/rmsnorm.py +223 -150
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/quack/softmax.py +1 -3
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/quack/utils.py +50 -10
- {quack_kernels-0.1.6 → quack_kernels-0.1.7/quack_kernels.egg-info}/PKG-INFO +1 -1
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/tests/test_rmsnorm.py +17 -50
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/LICENSE +0 -0
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/pyproject.toml +0 -0
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/quack/cross_entropy.py +0 -0
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/quack/layernorm.py +0 -0
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/quack_kernels.egg-info/SOURCES.txt +0 -0
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/quack_kernels.egg-info/dependency_links.txt +0 -0
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/quack_kernels.egg-info/requires.txt +0 -0
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/quack_kernels.egg-info/top_level.txt +0 -0
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/setup.cfg +0 -0
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/setup.py +0 -0
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/tests/test_cross_entropy.py +0 -0
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/tests/test_layernorm.py +0 -0
- {quack_kernels-0.1.6 → quack_kernels-0.1.7}/tests/test_softmax.py +0 -0
--- quack_kernels-0.1.6/README.md
+++ quack_kernels-0.1.7/README.md
@@ -16,13 +16,12 @@ pip install quack-kernels
 
 ## Kernels 🐥
 
-- 🦆 RMSNorm forward
+- 🦆 RMSNorm forward + backward
 - 🦆 Softmax forward + backward
 - 🦆 Cross entropy forward + backward
 - 🦆 Layernorm forward
 
 Upcoming:
-- 🦆 RMSNorm backward
 - 🦆 Rotary forward + backward
 
 ## Usage
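For reference, here is a minimal sketch of how the newly advertised RMSNorm backward can be exercised through PyTorch autograd. It only uses calls that appear in tests/test_rmsnorm.py later in this diff; the shapes and eps value are illustrative, not prescribed by the package:

```python
# Illustrative only; mirrors the usage exercised in tests/test_rmsnorm.py.
import torch
from quack.rmsnorm import rmsnorm

x = torch.randn(199, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
weight = torch.randn(4096, device="cuda", dtype=torch.float32, requires_grad=True)

out = rmsnorm(x, weight, eps=1e-6)    # forward kernel
out.backward(torch.randn_like(out))   # new in 0.1.7: fills x.grad and weight.grad
```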
--- quack_kernels-0.1.6/quack/reduction_base.py
+++ quack_kernels-0.1.7/quack/reduction_base.py
@@ -68,7 +68,7 @@ class ReductionBase:
         )
 
     def _allocate_reduction_buffer_and_mbar(
-        self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
+        self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout, is_persistent: bool = False
     ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
         reduction_buffer = smem.allocate_tensor(
             self.reduction_dtype,
@@ -76,20 +76,28 @@ class ReductionBase:
             byte_alignment=4,
         )
         if cutlass.const_expr(self.cluster_n > 1):
-            mbar_ptr = smem.allocate_array(
+            mbar_ptr = smem.allocate_array(
+                cutlass.Int64, num_elems=self.stage if not is_persistent else self.stage * 2
+            )
         else:
             mbar_ptr = None
         return reduction_buffer, mbar_ptr
 
     @cute.jit
-    def _initialize_cluster(
+    def _initialize_cluster(
+        self,
+        tidx: cutlass.Int32,
+        mbar_ptr: cute.Pointer,
+        num_warps: int,
+        is_persistent: bool = False,
+    ):
         if cutlass.const_expr(self.cluster_n > 1):
-            if tidx < self.stage:
+            if tidx < self.stage:  # Initialize full barrier
                 cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+            if cutlass.const_expr(is_persistent):  # Initialize empty barrier
+                cute.arch.mbarrier_init(
+                    mbar_ptr + self.stage + tidx, num_warps * self.cluster_n
+                )
             cute.arch.mbarrier_init_fence()
-            if tidx < self.stage:
-                cute.arch.mbarrier_arrive_and_expect_tx(
-                    mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
-                )
             # Cluster arrive after barrier init
             cute.arch.cluster_arrive_relaxed()
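To make the persistent path above concrete, here is a small back-of-the-envelope sketch (ours, not package code) of how many Int64 mbarrier slots the allocation reserves and what arrival count the empty barriers are initialized with, assuming stage = 2, 256 threads (8 warps), and a 2-CTA cluster:

```python
# Hypothetical numbers, only to illustrate the allocation/initialization logic above.
stage, num_warps, cluster_n = 2, 8, 2

def mbar_slots(is_persistent: bool) -> int:
    # mirrors: num_elems = stage if not is_persistent else stage * 2
    return stage if not is_persistent else stage * 2

print(mbar_slots(False))  # 2: one full barrier per stage
print(mbar_slots(True))   # 4: full barriers at [0, stage), empty barriers at [stage, 2*stage)
# Each empty barrier is initialized to expect num_warps * cluster_n arrivals:
print(num_warps * cluster_n)  # 16
```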
--- quack_kernels-0.1.6/quack/rmsnorm.py
+++ quack_kernels-0.1.7/quack/rmsnorm.py
@@ -1,6 +1,5 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
-
 import torch
 from typing import Optional
 
@@ -157,6 +156,7 @@ class RMSNorm(ReductionBase):
 
         # allocate fragments for gmem->rmem
         tWrW = cute.make_fragment_like(tWgW)
+        tWrW.fill(0.0)
         tXrW = thr_copy_X.retile(tWrW)
         tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
 
@@ -300,8 +300,14 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
 
 class RMSNormBackward(ReductionBase):
     def __init__(self, dtype: cutlass.Numeric, N: int):
-        #
-        super().__init__(dtype, N, stage=
+        # 2 stages for double buffering when computing mean of x_hat * wdy
+        super().__init__(dtype, N, stage=2, reduction_dtype=cutlass.Float32)
+        if self.N > 128 * 1024 and self.dtype.width >= 32:
+            # Not enough smem
+            raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
+
+    def _get_num_threads(self):
+        return 128 if self.N <= 4096 else 256
 
     def _calculate_threads_per_row(self):
         N = self.N
@@ -311,44 +317,38 @@ class RMSNormBackward(ReductionBase):
             else (
                 16
                 if N <= 128
-                else (32 if N <=
+                else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256)))
             )
         )
 
     def _set_cluster_n(self):
         N = self.N
-
-
-
-
-
-                2
-                if N <= 32 * 1024
-                else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
-            )
-        )
-        else:  # fp32
-            cluster_n = (
-                1
-                if N <= 32 * 1024
-                else (
-                    2
-                    if N <= 64 * 1024
-                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
-                )
-            )
+        cluster_n = (
+            1
+            if N <= 8 * 1024
+            else (2 if N <= 16 * 1024 else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16)))
+        )
         self.cluster_n = cluster_n
 
+    def _smem_size_in_bytes(self, tiler_mn, num_warps):
+        return (
+            # Multiply by 2 since we need space for X and dOut,
+            # and multiply by another 2 due to double buffering
+            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 * 2
+            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+            + self.stage * (cutlass.Int64.width // 8) * 2  # mult 2 as we need 2 mbar per stage
+        )
+
     @cute.jit
     def __call__(
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
-
+        mdOut: cute.Tensor,
         mRstd: cute.Tensor,
-
-
-        sm_count: cutlass.
+        mdX: cute.Tensor,
+        mdW: cute.Tensor,
+        sm_count: cutlass.Int32,
         stream: cuda.CUstream,
     ):
         self._set_cluster_n()
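A rough sanity check (assumptions ours, not from the package) of why the constructor above rejects N > 128k for 32-bit dtypes: with the new cluster heuristic a 256k-column fp32 row maps to cluster_n = 16, so each CTA holds about 256k/16 = 16384 fp32 columns per buffer, and the X/dOut double buffers alone already exceed the roughly 227 KB of dynamic shared memory available per CTA on Hopper (assuming one row per CTA):

```python
# Illustrative arithmetic only; one row per CTA and the ~227 KB smem ceiling are assumptions.
N, dtype_bytes, cluster_n = 256 * 1024, 4, 16        # fp32, heuristic above picks cluster_n = 16
cols_per_cta = N // cluster_n                         # 16384
x_and_dout_double_buffered = cols_per_cta * dtype_bytes * 2 * 2
print(x_and_dout_double_buffered)                     # 262144 bytes = 256 KiB > ~227 KB of smem
```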
@@ -359,14 +359,8 @@ class RMSNormBackward(ReductionBase):
         mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
         mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
 
-
-
-
-        num_blocks = (
-            sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
-        )
-
-        self.kernel(mX, mW, mDout, mRstd, mDx, mDw, sm_count, tv_layout, tiler_mn).launch(
+        num_blocks = sm_count
+        self.kernel(mX, mW, mdOut, mRstd, mdX, mdW, tv_layout, tiler_mn).launch(
             grid=[num_blocks, self.cluster_n, 1],
             block=[num_threads, 1, 1],
             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
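With num_blocks = sm_count the backward kernel becomes persistent: each CTA stays resident and strides over row tiles, matching the cutlass.range loop added further down in this file. A plain-Python sketch of that scheduling (function name and numbers are hypothetical):

```python
# Hypothetical scheduling model: which row tiles block `bidx_start` of `num_blocks` processes.
def row_tiles_for_block(bidx_start: int, num_blocks: int, M: int, tile_m: int):
    num_tiles = -(-M // tile_m)  # ceil_div(M, tile_m)
    return list(range(bidx_start, num_tiles, num_blocks))

print(row_tiles_for_block(bidx_start=0, num_blocks=4, M=10, tile_m=1))  # [0, 4, 8]
print(row_tiles_for_block(bidx_start=3, num_blocks=4, M=10, tile_m=1))  # [3, 7]
```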
@@ -379,177 +373,244 @@ class RMSNormBackward(ReductionBase):
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
-
+        mdOut: cute.Tensor,
         mRstd: cute.Tensor,
-
-
-        sm_count: cutlass.Constexpr,
+        mdX: cute.Tensor,
+        mdW: cute.Tensor,
         tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
     ):
         tidx, _, _ = cute.arch.thread_idx()
-
+        bidx_start, _, _ = cute.arch.block_idx()
         gdim, _, _ = cute.arch.grid_dim()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)
 
         shape = mX.shape
         M, N = shape[0], shape[1]
+        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
 
         idX = cute.make_identity_tensor(shape)
 
         smem = cutlass.utils.SmemAllocator()
-
+        smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2))
+        sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
+        sdOut = smem.allocate_tensor(mdOut.element_type, smem_layout, byte_alignment=16)
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+            smem, tv_layout, is_persistent=True
+        )
+        if cutlass.const_expr(mbar_ptr is not None):
+            mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
+        else:
+            mbar_full_ptr, mbar_empty_ptr = None, None
 
         copy_atom_load_X = cute.make_copy_atom(
             cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
         )
-
+        copy_atom_load_X_async = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+        )
         copy_atom_load_W = cute.make_copy_atom(
             cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
         )
-
         copy_atom_store_dX = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(),
+            cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=128
         )
-
-
-            cute.nvgpu.CopyUniversalOp(), mDw.element_type, num_bits_per_copy=128
+        copy_atom_store_dW = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=128
         )
 
         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_X_async = cute.make_tiled_copy(
+            copy_atom_load_X_async, tv_layout, tiler_mn
+        ).get_slice(tidx)
         thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
-
-
+        thr_copy_dW = cute.make_tiled_copy(copy_atom_store_dW, tv_layout, tiler_mn).get_slice(tidx)
+        thr_store_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
 
-        gW = cute.local_tile(mW, tiler_mn, (
+        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
         tWgW = thr_copy_W.partition_S(gW)
         tWrW = cute.make_fragment_like(tWgW)
+        # Need this, otherwise rW can have arbitrary values that changes the reduction
+        if not is_even_N:
+            tWrW.fill(0.0)
         tXrW = thr_copy_X.retile(tWrW)
 
-        gW_coord = cute.local_tile(idX, tiler_mn, (0,
-
-
+        gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+        tWpW = (
+            utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+            if not is_even_N
+            else None
+        )
         cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
         weight = tXrW.load().to(cute.Float32)
 
         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
 
-        self._initialize_cluster(tidx, mbar_ptr, num_warps)
-
-        dw_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
-        tDwpDw = utils.predicate_k(thr_copy_dw.partition_S(dw_coord), limit=shape[1])
+        self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
 
-
-
-
-
-
-
-        M_pad = ((M + sm_count - 1) // sm_count) * sm_count
+        dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+        tdWpdW = (
+            utils.predicate_k(thr_copy_dW.partition_S(dw_coord), limit=shape[1])
+            if not is_even_N
+            else None
+        )
 
-
+        gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
+        tdWgdW = thr_copy_dW.partition_D(gdW)
+        tdWrdW = cute.make_fragment_like(tdWgdW, cutlass.Float32)
+        tXrdW = thr_copy_X.retile(tdWrdW)
 
-
-
-
+        gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
+        gdOut = cute.local_tile(mdOut, tiler_mn, (None, cluster_y))
+        gdX = cute.local_tile(mdX, tiler_mn, (None, cluster_y))
+        cX = cute.local_tile(idX, tiler_mn, (None, cluster_y))
+        tXgX = thr_copy_X.partition_S(gX)
+        tXsX = thr_copy_X.partition_D(sX)
+        tXgdOut = thr_copy_X.partition_S(gdOut)
+        tXsdOut = thr_copy_X.partition_D(sdOut)
+        tXgdX = thr_store_dX.partition_D(gdX)
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
+        # This doesn't change across iterations
+        tXpX = (
+            utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
+            if not is_even_N
+            else None
+        )
 
-
-
-
-
+        tXrX, tXrdOut, tXrdX = [
+            cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdOut, tXgdX)
+        ]
+
+        # Prefetch the first batch
+        row = tXcX[None, None, None, bidx_start][0][0]
+        if row < M:
+            tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
+            tXgdOut_cur = utils.coord_offset_i64(bidx_start, tXgdOut, dim=3)[None, None, None, 0]
+            cute.copy(
+                copy_atom_load_X_async,
+                tXgX_cur,
+                tXsX[None, None, None, 0],
+                pred=tXpX,
             )
-
-
+            cute.copy(
+                copy_atom_load_X_async,
+                tXgdOut_cur,
+                tXsdOut[None, None, None, 0],
+                pred=tXpX,
             )
-
-
-
-
-
-
-        cX = cute.local_tile(
-            idX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
-        )
-
-        tXgX = thr_copy_X.partition_S(gX)
-        thrDout = thr_copy_X.partition_S(gDout)
-        tXrRstd = thr_copy_W.partition_S(gRstd)
-        thrDx = thr_store_dx.partition_D(gDx)
-        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
-        tXrX, frgDout, frgDx = [cute.make_fragment_like(thr) for thr in (tXgX, thrDout, thrDx)]
-
-        tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+        elif tiler_mn[0] > 1:
+            # Fill with zero, otherwise smem will be uninitialized, and we could read this back
+            # later into registers, causing wrong dW.
+            utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
+            utils.fill_oob(tXsdOut[None, None, None, 0], None, fill_value=mdOut.element_type.zero)
+        cute.arch.cp_async_commit_group()
 
-
-
-        cute.copy(copy_atom_load_X, thrDout, frgDout, pred=tXpX)
+        if cutlass.const_expr(self.cluster_n > 1):
+            cute.arch.cluster_wait()
 
+        threads_per_row = tv_layout.shape[0][0]
+        tXrdW.fill(0.0)
+        stage = cutlass.Int32(0)
+        producer_phase = cutlass.Int32(1)
+        consumer_phase = cutlass.Int32(0)
+        for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
+            row = tXcX[None, None, None, bidx][0][0]
+            rstd = cutlass.Float.zero
+            if row + gdim * tiler_mn[0] < M:  # Prefetch the next batch
+                tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
+                tXgdOut_cur = utils.coord_offset_i64(bidx + gdim, tXgdOut, dim=3)[
+                    None, None, None, 0
+                ]
+                cute.copy(
+                    copy_atom_load_X_async,
+                    tXgX_cur,
+                    tXsX[None, None, None, stage ^ 1],
+                    pred=tXpX,
+                )
+                cute.copy(
+                    copy_atom_load_X_async,
+                    tXgdOut_cur,
+                    tXsdOut[None, None, None, stage ^ 1],
+                    pred=tXpX,
+                )
+            elif tiler_mn[0] > 1:
+                utils.fill_oob(
+                    tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero
+                )
+                utils.fill_oob(
+                    tXsdOut[None, None, None, stage ^ 1], None, fill_value=mdOut.element_type.zero
+                )
+            cute.arch.cp_async_commit_group()
+            if row < M or tiler_mn[0] == 1:
+                rstd = mRstd[row]
+            cute.arch.cp_async_wait_group(1)
+            cute.autovec_copy(tXsX[None, None, None, stage], tXrX)
             x = tXrX.load().to(cute.Float32)
-
-
-            rstd = tXrRstd[0]
+            cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
+            dout = tXrdOut.load().to(cute.Float32)
             x_hat = x * rstd
             wdy = dout * weight
-
-            threads_per_row = tv_layout.shape[0][0]
-
-            row = tXcX[0][0]
             if cutlass.const_expr(self.cluster_n > 1):
-                cute.arch.
-                cute.arch.cluster_wait()
-            else:
-                cute.arch.barrier()
-
+                cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
             mean_xhat_wdy = (
                 utils.row_reduce(
                     x_hat * wdy,
                     cute.ReductionOp.ADD,
                     threads_per_row,
-                    reduction_buffer[None, None,
-
+                    reduction_buffer[None, None, stage],
+                    mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None,
+                    phase=consumer_phase,
                     init_val=0.0,
-                    hook_fn=cute.arch.cluster_wait
-                    if cutlass.const_expr(self.cluster_n > 1)
-                    else None,
                 )
                 / shape[1]
             )
-
-            dx = (wdy - x_hat * mean_xhat_wdy) * rstd
-            frgDx.store(dx.to(frgDout.element_type))
-
-            if row < M:
-                cute.copy(copy_atom_store_dX, frgDx, thrDx, pred=tXpX)
-
             if cutlass.const_expr(self.cluster_n > 1):
-
-
-
-                cute.arch.
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if cutlass.const_expr(self.cluster_n > 1):
-            cute.arch.
-
-
+                # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
+                # Requires adjusting the thread_count when initializing the mbar
+                cute.arch.sync_warp()
+                lane_idx = cute.arch.lane_idx()
+                if lane_idx < self.cluster_n:
+                    cute.arch.mbarrier_arrive(
+                        mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
+                    )
+            dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+            tXrdX.store(dx.to(tXrdOut.element_type))
+            if row < M or tiler_mn[0] == 1:
+                tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
+                cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+            tXrdW.store(tXrdW.load() + dout * x_hat)
+            stage ^= 1
+            if stage == 0:
+                consumer_phase ^= 1
+                producer_phase ^= 1
+
+        if cutlass.const_expr(self.cluster_n > 1):  # Prevent cluster from exiting early
+            cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
+
+        if cutlass.const_expr(tiler_mn[0] > 1):
+            # reduction of dw_partial within the same threadblock
+            sdW = cute.make_tensor(
+                cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+                cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+            )
+            tXsdW = thr_copy_X.partition_D(sdW)
             cute.arch.barrier()
-
-
-
-
+            row = tXcX[None, None, None, 0][0][0]
+            if row > 0:
+                cute.autovec_copy(tXrdW, tXsdW)
+            cute.arch.barrier()
+            if row == 0:
+                for i in cutlass.range_constexpr(1, cutlass.const_expr(tiler_mn[0])):
+                    tXrdW_other = cute.make_fragment_like(tXrdW)
+                    tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
+                    cute.autovec_copy(tXsdW_other, tXrdW_other)
+                    tXrdW.store(tXrdW.load() + tXrdW_other.load())
+                cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
+        else:
+            cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
 
 
 def _rmsnorm_backward(
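The new kernel double-buffers X and dOut in shared memory and tracks the mbarrier phases explicitly. A small standalone model (ours, not package code) of how stage, consumer_phase, and producer_phase evolve across loop iterations, matching the stage ^= 1 / phase-flip logic at the bottom of the loop above:

```python
# Standalone model of the double-buffer bookkeeping in the loop above.
stage, consumer_phase, producer_phase = 0, 0, 1
history = []
for it in range(6):
    history.append((it, stage, consumer_phase, producer_phase))
    stage ^= 1
    if stage == 0:          # both buffers have been used once -> flip phases
        consumer_phase ^= 1
        producer_phase ^= 1
print(history)
# [(0, 0, 0, 1), (1, 1, 0, 1), (2, 0, 1, 0), (3, 1, 1, 0), (4, 0, 0, 1), (5, 1, 0, 1)]
```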
@@ -581,8 +642,19 @@ def _rmsnorm_backward(
 
     device = x.device
 
-
-
+    # This should be tuned on how many CTAs can be launched on each SM
+    sm_count_multiple = (
+        16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+    )
+    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
+    # By right, if we're using cluster, this should be cluster_count not sm_count.
+    # But for cluster >= 4, due to quantization we would need to query active max cluster.
+    # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
+    # avoid wave quantization.
+    sm_count = (
+        sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
+    )
+    dw_partial = torch.empty(sm_count, N, device=device, dtype=weight.dtype)
 
     dtype = torch2cute_dtype_map[x.dtype]
 
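To make the grid-size heuristic above concrete, here is the same arithmetic evaluated for a hypothetical 132-SM GPU (the multipliers and thresholds are copied from the diff; the SM count is an assumption):

```python
# Mirrors the sm_count heuristic above for an assumed 132-SM device.
def grid_size(N: int, sms: int = 132) -> int:
    mult = 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
    return sms * mult if N <= 8192 else (sms // 2 if N <= 16384 else sms * 2)

for n in (256, 1024, 4096, 8192, 16384, 65536):
    print(n, grid_size(n))  # 256->2112, 1024->1056, 4096->264, 8192->132, 16384->66, 65536->264
```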
@@ -625,6 +697,7 @@ def _rmsnorm_backward(
         rstd_tensor,
         dx_tensor,
         dw_partial_tensor,
+        sm_count,
         current_stream,
     )
 
--- quack_kernels-0.1.6/quack/softmax.py
+++ quack_kernels-0.1.7/quack/softmax.py
@@ -133,9 +133,7 @@ class Softmax(ReductionBase):
 
         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
         tXpX = (
-            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
-            if cutlass.const_expr(not is_even_N)
-            else None
+            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
         )
         if tXcX[0][0] < shape[0]:
             cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
--- quack_kernels-0.1.6/quack/utils.py
+++ quack_kernels-0.1.7/quack/utils.py
@@ -120,12 +120,20 @@ def cluster_reduce(
     reduction_buffer: cute.Tensor,
     mbar_ptr: cute.Pointer,
     init_val: cute.Numeric = 0.0,
+    phase: Optional[cutlass.Int32] = None,
 ) -> cute.Numeric:
     """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
     cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-    warps_per_row, cluster_n = reduction_buffer.shape
+    rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+    if warp_idx == 0:
+        with cute.arch.elect_one():
+            num_warps = rows_per_block * warps_per_row
+            cute.arch.mbarrier_arrive_and_expect_tx(
+                mbar_ptr,
+                num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+            )
     if lane_idx < cluster_n:
         store_shared_remote(
             val,
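The arrive-and-expect-tx call that moved into cluster_reduce sizes the transaction count from the reduction buffer itself; for example (numbers assumed, not from the package):

```python
# Illustrative byte count for the expect-tx above.
num_warps, cluster_n, elem_bits = 8, 2, 32       # 256 threads, 2-CTA cluster, Float32 buffer
print(num_warps * cluster_n * elem_bits // 8)    # 64 bytes expected before the mbarrier_wait
```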
@@ -133,7 +141,7 @@ def cluster_reduce(
             mbar_ptr,
             peer_cta_rank_in_cluster=lane_idx,
         )
-    cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+    cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
     block_reduce_val = init_val
     num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
     for i in cutlass.range_constexpr(num_iter):
@@ -149,13 +157,14 @@ def block_or_cluster_reduce(
     op: Callable,
     reduction_buffer: cute.Tensor,
     mbar_ptr: Optional[cute.Pointer],
+    phase: Optional[cutlass.Int32] = None,
     init_val: cute.Numeric = 0.0,
 ) -> cute.Numeric:
     """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
     if cutlass.const_expr(mbar_ptr is None):
         return block_reduce(val, op, reduction_buffer, init_val=init_val)
     else:
-        return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val)
+        return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
 
 
 @cute.jit
@@ -165,6 +174,7 @@ def row_reduce(
     threads_per_row: cutlass.Constexpr[int],
     reduction_buffer: Optional[cute.Tensor] = None,
     mbar_ptr: Optional[cute.Pointer] = None,
+    phase: Optional[cutlass.Int32] = None,
     init_val: cute.Numeric = 0.0,
     hook_fn: Optional[Callable] = None,
 ) -> cute.Numeric:
@@ -193,7 +203,7 @@ def row_reduce(
     ), "mbar_ptr must be provided for cluster reduction"
     if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
         val = block_or_cluster_reduce(
-            val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
+            val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
         )
     return val
 
@@ -205,6 +215,7 @@ def online_softmax_reduce(
     reduction_buffer: Optional[cute.Tensor] = None,
     mbar_ptr: Optional[cute.Pointer] = None,
     hook_fn: Optional[Callable] = None,
+    phase: Optional[cutlass.Int32] = None,
     return_exp_x: bool = False,
 ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
     assert x.dtype == Float32, "x must be of type Float32"
@@ -225,7 +236,7 @@ def online_softmax_reduce(
     if cutlass.const_expr(hook_fn is not None):
         hook_fn()
     if cutlass.const_expr(reduction_buffer is not None):
-        warps_per_row, cluster_n = reduction_buffer.shape
+        rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
         assert (
             cluster_n == 1 or mbar_ptr is not None
         ), "mbar_ptr must be provided for cluster reduction"
@@ -251,6 +262,13 @@ def online_softmax_reduce(
             max_x = max_x_final
         else:
             cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+            if warp_idx == 0:
+                with cute.arch.elect_one():
+                    num_warps = rows_per_block * warps_per_row
+                    cute.arch.mbarrier_arrive_and_expect_tx(
+                        mbar_ptr,
+                        num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+                    )
             if lane_idx < cluster_n:
                 store_shared_remote(
                     f32x2_to_i64(max_x, sum_exp_x),
@@ -258,7 +276,7 @@ def online_softmax_reduce(
                     mbar_ptr,
                     peer_cta_rank_in_cluster=lane_idx,
                 )
-            cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+            cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
             num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
             max_x_single_warp = cute.make_fragment(num_iter, Float32)
             max_x_single_warp.fill(-Float32.inf)
@@ -351,7 +369,7 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
 
 
 @cute.jit
-def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) -> None:
+def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
     """Fill out-of-bounds values in shared memory tensor.
 
     Args:
@@ -361,9 +379,12 @@ def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) ->
     """
     tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
     tXrX_fill.fill(fill_value)
-    for rest_v in cutlass.range_constexpr(
-        for rest_k in cutlass.range_constexpr(
-            if
+    for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
+        for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
+            if cutlass.const_expr(tXpX is not None):
+                if not tXpX[rest_v, 0, rest_k]:
+                    cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
+            else:
                 cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
 
 
@@ -396,6 +417,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
 def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
     flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
     flat_stride = cute.flatten_to_tuple(tensor.stride)
+    assert len(flat_coord_i64) == len(
+        flat_stride
+    ), "Coordinate and stride must have the same length"
     offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
     assert isinstance(tensor.iterator, cute.Pointer)
     # HACK: we assume that applying the offset does not change the pointer alignment
@@ -406,3 +430,19 @@ def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=No
         assumed_align=tensor.iterator.max_alignment,
     )
     return cute.make_tensor(new_ptr, tensor.layout)
+
+
+@dsl_user_op
+def coord_offset_i64(
+    idx: cute.typing.Int, tensor: cute.Tensor, dim: int, *, loc=None, ip=None
+) -> cute.Tensor:
+    offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
+    assert isinstance(tensor.iterator, cute.Pointer)
+    # HACK: we assume that applying the offset does not change the pointer alignment
+    new_ptr = cute.make_ptr(
+        tensor.element_type,
+        tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+        tensor.memspace,
+        assumed_align=tensor.iterator.max_alignment,
+    )
+    return cute.make_tensor(new_ptr, tensor.layout)
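The new coord_offset_i64 helper advances the base pointer by idx * stride[dim] elements, promoting the index to Int64 before the multiply. A plain-Python illustration (all numbers assumed) of why keeping that arithmetic in 64 bits matters for large, persistently iterated tensors:

```python
# Assumed shapes, only to show the offset arithmetic leaving the 32-bit range.
idx, row_stride_elems, elem_bytes = 300_000, 8192, 4   # row-tile index, fp32 row stride
byte_offset = idx * row_stride_elems * elem_bytes       # Python ints are arbitrary precision
print(byte_offset)                                      # 9_830_400_000 bytes
print(byte_offset > 2**31 - 1)                          # True: must not be done in int32
```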
--- quack_kernels-0.1.6/tests/test_rmsnorm.py
+++ quack_kernels-0.1.7/tests/test_rmsnorm.py
@@ -9,20 +9,22 @@ from quack.rmsnorm import rmsnorm, rmsnorm_ref, rstd_ref
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 # @pytest.mark.parametrize("eps", [1e-5])
 @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32])
-# @pytest.mark.parametrize("input_dtype", [torch.
+# @pytest.mark.parametrize("input_dtype", [torch.float16])
 @pytest.mark.parametrize(
     "N",
     [192, 256, 512, 760, 1024, 1128, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144]
-    # [
+    # [262144]
 )
-@pytest.mark.parametrize("M", [1, 37, 199])
+@pytest.mark.parametrize("M", [1, 37, 199, 8 * 1024])
 # @pytest.mark.parametrize("M", [1])
-def
+def test_rmsnorm_forward_backward(M, N, input_dtype, eps):
     """Test RMSNorm forward pass against reference implementation."""
+    if N >= 256 * 1024 and input_dtype == torch.float32 and M >= 8 * 1024:
+        pytest.skip("Skipping large tensor test for float32 to avoid OOM")
     device = "cuda"
     # Set tolerance based on dtype
     if input_dtype == torch.bfloat16:
-        atol =
+        atol = 1e-1
     elif input_dtype == torch.float16:
         atol = 1e-2
     else:
@@ -35,55 +37,20 @@ def test_rmsnorm_forward(M, N, input_dtype, eps):
     out = rmsnorm(x, weight, eps=eps)
     out_ref = rmsnorm_ref(x_ref, weight_ref, eps=eps)
     # rstd_ref_val = rstd_ref(x_ref, eps=eps)
-
-    # Check output shape and dtype
     assert out.shape == x.shape
     assert out.dtype == input_dtype
-
-    # Check accuracy
     torch.testing.assert_close(out, out_ref, atol=atol, rtol=1e-3)
     # torch.testing.assert_close(rstd, rstd_ref_val, atol=atol, rtol=1e-3)
-
-
-    #
-
-
-
-
-
-
-    # # Set tolerance based on dtype
-    # if input_dtype == torch.bfloat16:
-    #     atol = 5e-2
-    # elif input_dtype == torch.float16:
-    #     atol = 1e-2
-    # else:
-    #     atol = 1e-4
-
-    # # Set seed for reproducibility
-    # torch.random.manual_seed(0)
-
-    # # Create input tensors
-    # x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True)
-    # weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
-
-    # # Clone for reference
-    # x_ref = x.detach().clone().requires_grad_()
-    # weight_ref = weight.detach().clone().requires_grad_()
-
-    # # Forward pass
-    # out = rmsnorm(x, weight, eps=eps)
-    # out_ref = rmsnorm_ref(x_ref, weight_ref, eps=eps)
-
-    # # Backward pass
-    # grad_out = torch.randn_like(out)
-    # out.backward(grad_out)
-    # out_ref.backward(grad_out)
-
-    # # Check gradients
-    # torch.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=1e-3)
-    # torch.testing.assert_close(weight.grad, weight_ref.grad, atol=atol, rtol=1e-3)
+    # Backward pass
+    if N > 128 * 1024 and input_dtype == torch.float32:
+        # Skip backward pass for due to not enough smem
+        return
+    grad_out = torch.randn_like(out)
+    torch.cuda.synchronize()
+    out_ref.backward(grad_out)
+    out.backward(grad_out)
+    torch.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=1e-3)
+    torch.testing.assert_close(weight.grad, weight_ref.grad, atol=atol, rtol=1e-3)
 
 
 @pytest.mark.parametrize("eps", [1e-5])