quack-kernels 0.1.5__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {quack_kernels-0.1.5/quack_kernels.egg-info → quack_kernels-0.1.6}/PKG-INFO +1 -1
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/README.md +16 -5
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/__init__.py +1 -1
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/cross_entropy.py +11 -7
- quack_kernels-0.1.6/quack/layernorm.py +351 -0
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/rmsnorm.py +4 -1
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/softmax.py +8 -3
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/utils.py +16 -0
- {quack_kernels-0.1.5 → quack_kernels-0.1.6/quack_kernels.egg-info}/PKG-INFO +1 -1
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack_kernels.egg-info/SOURCES.txt +2 -0
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack_kernels.egg-info/top_level.txt +1 -0
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/tests/test_cross_entropy.py +13 -49
- quack_kernels-0.1.6/tests/test_layernorm.py +162 -0
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/tests/test_rmsnorm.py +36 -5
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/tests/test_softmax.py +2 -3
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/LICENSE +0 -0
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/pyproject.toml +0 -0
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/reduction_base.py +0 -0
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack_kernels.egg-info/dependency_links.txt +0 -0
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack_kernels.egg-info/requires.txt +0 -0
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/setup.cfg +0 -0
- {quack_kernels-0.1.5 → quack_kernels-0.1.6}/setup.py +0 -0
{quack_kernels-0.1.5 → quack_kernels-0.1.6}/README.md

@@ -19,6 +19,7 @@ pip install quack-kernels
 - 🦆 RMSNorm forward
 - 🦆 Softmax forward + backward
 - 🦆 Cross entropy forward + backward
+- 🦆 Layernorm forward
 
 Upcoming:
 - 🦆 RMSNorm backward
@@ -30,13 +31,23 @@ Upcoming:
 from quack import rmsnorm, softmax, cross_entropy
 ```
 
-##
+## Documentations
 
-
+[2025-07-10] We have a comprehensive
+[blogpost](media/2025-07-10-membound-sol.md) on how to get memory-bound kernels
+to speed-of-light, right in the comfort of Python thanks to the [CuTe-DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html).
 
-
-
-
+## Performance
+
+<div align="center">
+<figure>
+<img
+src="media/bf16_kernel_benchmarks_single_row.svg"
+>
+</figure>
+</div>
+
+See our [blogpost](media/2025-07-10-membound-sol.md) for the details.
 
 ## Development
 
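A note on the new Performance section: "speed-of-light" for a memory-bound kernel simply means running no slower than the time HBM needs to move the kernel's bytes. A hedged back-of-the-envelope sketch of that bound (the bandwidth figure is an assumed H100-class number, not a measurement from this package):

```python
# Rough speed-of-light estimate for a memory-bound kernel such as LayerNorm forward.
# Assumptions: bf16 input/output, weight traffic negligible, ~3.35 TB/s peak HBM bandwidth.
M, N = 32 * 1024, 8192
bytes_per_elem = 2                        # bf16
bytes_moved = 2 * M * N * bytes_per_elem  # read x once + write out once
hbm_bw = 3.35e12                          # assumed peak bandwidth in bytes/s
sol_time_us = bytes_moved / hbm_bw * 1e6
print(f"speed-of-light ~= {sol_time_us:.0f} us")  # ~320 us for this shape
```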
{quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/cross_entropy.py

@@ -104,7 +104,10 @@ class CrossEntropy(ReductionBase):
         shape: cute.Shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
+        gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y))
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
@@ -150,7 +153,9 @@ class CrossEntropy(ReductionBase):
 
         target_logit = cute.Float32.zero
         if row < shape[0] and tXcX[0][1] == 0:
-
+            # Use Int64 for indexing to deal with large tensors
+            mX_off = utils.domain_offset_i64((row, 0), mX)
+            target_logit = cute.Float32(mX_off[0, target])
 
         threads_per_row = tv_layout.shape[0][0]
         if cutlass.const_expr(not self.online_softmax):
@@ -363,11 +368,10 @@ class CrossEntropyBackward:
         )
 
         idX = cute.make_identity_tensor(shape)
-
-
-
-
-        ]
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX, mdX = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mdX)]
+        gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))
 
         copy_atom_load_X = cute.make_copy_atom(
             cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
quack_kernels-0.1.6/quack/layernorm.py (new file)

@@ -0,0 +1,351 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+
+import torch
+from typing import Optional
+
+import cuda.bindings.driver as cuda
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.runtime import from_dlpack
+import quack.utils as utils
+from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+class LayerNorm(ReductionBase):
+    def __init__(self, dtype: cutlass.Numeric, N: int):
+        super().__init__(dtype, N, stage=2)  # 2 stages for mean and var
+        self.reload_from = None if N <= 16384 else "smem"
+        self.delay_w_load = False
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
+
+    def _set_cluster_n(self):
+        N = self.N
+        # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
+        # Similarly cluster_n = 8 is faster for N=128k
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 32 * 1024
+                else (
+                    2
+                    if N <= 64 * 1024
+                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mO: cute.Tensor,
+        mRstd: Optional[cute.Tensor],
+        mMean: Optional[cute.Tensor],
+        stream: cuda.CUstream,
+        eps: cutlass.Float32 = 1e-6,
+    ):
+        assert mX.element_type == self.dtype
+        assert mO.element_type == self.dtype
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+        if cutlass.const_expr(mRstd is not None):
+            mRstd_expanded_layout = cute.append(
+                mRstd.layout, cute.make_layout((self.N,), stride=(0,))
+            )
+            mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+        if cutlass.const_expr(mMean is not None):
+            mMean_expanded_layout = cute.append(
+                mMean.layout, cute.make_layout((self.N,), stride=(0,))
+            )
+            mMean = cute.make_tensor(mMean.iterator, mMean_expanded_layout)
+        self.kernel(mX, mW, mO, mRstd, mMean, eps, tv_layout, tiler_mn, self.reload_from).launch(
+            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mO: cute.Tensor,
+        mRstd: Optional[cute.Tensor],
+        mMean: Optional[cute.Tensor],
+        eps: cute.Float32,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+        reload_from: cutlass.Constexpr = None,
+        delay_w_load: cutlass.Constexpr = False,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)
+
+        smem = cutlass.utils.SmemAllocator()
+        sX = smem.allocate_tensor(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        shape = mX.shape
+        idX = cute.make_identity_tensor(shape)
+        # slice for CTAs
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+        gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
+        gRstd = (
+            cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
+            if cutlass.const_expr(mRstd is not None)
+            else None
+        )
+        gMean = (
+            cute.local_tile(mMean, tiler_mn, (bidx, cluster_y))
+            if cutlass.const_expr(mMean is not None)
+            else None
+        )
+
+        # declare the atoms which will be used later for memory copy
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_load_X_async = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_load_W = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+        )
+        copy_atom_store_O = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
+            tidx
+        )
+        thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+        tWgW = thr_copy_W.partition_S(gW)
+        tXgX = thr_copy_X.partition_S(gX)
+        tXsX = thr_copy_X.partition_D(sX)
+        tXgO = thr_copy_O.partition_D(gO)
+        tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
+        tXrMean = thr_copy_O.partition_D(gMean) if cutlass.const_expr(mMean is not None) else None
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+        # allocate fragments for gmem->rmem
+        tWrW = cute.make_fragment_like(tWgW)
+        tXrW = thr_copy_X.retile(tWrW)
+        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+        row = tXcX[0][0]
+        if row < shape[0]:
+            cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+        cute.arch.cp_async_commit_group()
+
+        tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
+        if cutlass.const_expr(not delay_w_load):
+            cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+
+        cute.arch.cp_async_wait_group(0)
+        cute.autovec_copy(tXsX, tXrX)
+        x = tXrX.load().to(cute.Float32)
+        threads_per_row = tv_layout.shape[0][0]
+        sum_x = utils.row_reduce(
+            x,
+            cute.ReductionOp.ADD,
+            threads_per_row,
+            reduction_buffer[None, None, 0],
+            mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+            init_val=0.0,
+            hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+        )
+        mean = sum_x / shape[1]
+        if cutlass.const_expr(reload_from == "smem"):
+            cute.autovec_copy(tXsX, tXrX)
+            x = tXrX.load().to(cute.Float32)
+        elif cutlass.const_expr(reload_from == "gmem"):
+            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+            x = tXrX.load().to(cute.Float32)
+
+        sum_sq_x_sub_mean = utils.row_reduce(
+            (x - mean) * (x - mean),
+            cute.ReductionOp.ADD,
+            threads_per_row,
+            reduction_buffer[None, None, 1],
+            mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+            init_val=0.0,
+        )
+        rstd = utils.rsqrt(sum_sq_x_sub_mean / shape[1] + eps)
+        if cutlass.const_expr(mRstd is not None):
+            # Only the thread corresponding to column 0 writes out the rstd to gmem
+            if (
+                tXcX[0][1] == 0
+                and row < shape[0]
+                and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+            ):
+                tXrRstd[0] = rstd
+        if cutlass.const_expr(mMean is not None):
+            # Only the thread corresponding to column 0 writes out the mean to gmem
+            if (
+                tXcX[0][1] == 0
+                and row < shape[0]
+                and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+            ):
+                tXrMean[0] = mean
+        if cutlass.const_expr(delay_w_load):
+            cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+        if cutlass.const_expr(reload_from == "smem"):
+            cute.autovec_copy(tXsX, tXrX)
+            x = tXrX.load().to(cute.Float32)
+        elif cutlass.const_expr(reload_from == "gmem"):
+            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+            x = tXrX.load().to(cute.Float32)
+        x_hat = (x - mean) * rstd
+        w = tXrW.load().to(cute.Float32)
+        y = x_hat * w
+        tXrO.store(y.to(tXrO.element_type))
+        tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+        if row < shape[0]:
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+def layernorm(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    return_rstd: bool = False,
+    return_mean: bool = False,
+) -> torch.Tensor:
+    """LayerNorm forward pass.
+
+    Args:
+        x: Input tensor of shape (M, N)
+        weight: Weight tensor of shape (N,)
+        eps: Small value for numerical stability
+        return_rstd: Whether to return the reciprocal standard deviation
+        return_mean: Whether to return the mean
+
+    Returns:
+        Normalized output tensor of same shape as x
+        If return_rstd is True, also returns rstd tensor of shape (M,)
+        If return_mean is True, also returns mean tensor of shape (M,)
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert weight.dim() == 1, "Weight must be 1D"
+    assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+    assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+    assert weight.dtype == torch.float32, "Weight must be float32"
+    M, N = x.shape
+    device = x.device
+    out = torch.empty_like(x)
+    rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
+    mean = torch.empty(M, device=device, dtype=torch.float32) if return_mean else None
+    dtype = torch2cute_dtype_map[x.dtype]
+    convert_from_dlpack = lambda x: (
+        from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+    x_tensor, out_tensor = [
+        # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
+        convert_from_dlpack(t)
+        for t in (x, out)
+    ]
+    weight_tensor = utils.convert_from_dlpack(
+        weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+    )
+    rstd_tensor = (
+        from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+        if rstd is not None
+        else None
+    )
+    mean_tensor = (
+        from_dlpack(mean.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+        if mean is not None
+        else None
+    )
+    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+    compile_key = (dtype, N, rstd is not None, mean is not None)
+    if compile_key not in layernorm.compile_cache:
+        rmsnorm_op = LayerNorm(dtype, N)
+        layernorm.compile_cache[compile_key] = cute.compile(
+            rmsnorm_op,
+            x_tensor,
+            weight_tensor,
+            out_tensor,
+            rstd_tensor,
+            mean_tensor,
+            current_stream,
+        )
+    layernorm.compile_cache[compile_key](
+        x_tensor, weight_tensor, out_tensor, rstd_tensor, mean_tensor, current_stream, eps
+    )
+    return (
+        (out, rstd, mean)
+        if return_mean and return_rstd
+        else (
+            (out, rstd)
+            if return_rstd and not return_mean
+            else ((out, mean) if return_mean and not return_rstd else (out))
+        )
+    )
+
+
+layernorm.compile_cache = {}
+
+
+def layernorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
+    x_f32 = x.float()
+    return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)
+
+
+def rstd_ref(x: torch.Tensor, eps: float = 1e-6):
+    x_f32 = x.float()
+    mean = x_f32.mean(dim=-1, keepdim=True)
+    var = ((x_f32 - mean) ** 2).mean(dim=-1)
+    return 1.0 / torch.sqrt(var + eps)
+
+
+def mean_ref(x: torch.Tensor) -> torch.Tensor:
+    return x.float().mean(dim=-1)
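The host-side entry point added above is `layernorm(x, weight, eps, return_rstd, return_mean)`; its docstring and asserts fix the contract (2D CUDA input in fp16/bf16/fp32, fp32 weight of shape (N,), optional fp32 per-row rstd/mean). A minimal usage sketch based on that signature; shapes and tolerances below are illustrative, not from the package:

```python
import torch
from quack.layernorm import layernorm, layernorm_ref

x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
w = torch.randn(8192, device="cuda", dtype=torch.float32)  # weight must be fp32

out = layernorm(x, w, eps=1e-6)  # forward only

# Also return the per-row statistics (fp32 tensors of shape (M,))
out, rstd, mean = layernorm(x, w, eps=1e-6, return_rstd=True, return_mean=True)

# Cross-check against the PyTorch reference shipped in the same module
torch.testing.assert_close(out, layernorm_ref(x, w, eps=1e-6), atol=1e-2, rtol=1e-2)
```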
{quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/rmsnorm.py

@@ -117,7 +117,10 @@ class RMSNorm(ReductionBase):
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+        gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
         gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
         gRstd = (
             cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
{quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/softmax.py

@@ -98,7 +98,10 @@ class Softmax(ReductionBase):
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+        gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
@@ -312,9 +315,11 @@ class SoftmaxBackward(ReductionBase):
         shape = mdY.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-
-
+        mdY, mY, mdX = [
+            utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mdY, mY, mdX)
         ]
+        gdY, gY, gdX = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
 
         smem = cutlass.utils.SmemAllocator()
         sdY = smem.allocate_tensor(
{quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/utils.py

@@ -390,3 +390,19 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
         vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
     )
     return res0, res1
+
+
+@dsl_user_op
+def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+    flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
+    flat_stride = cute.flatten_to_tuple(tensor.stride)
+    offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
+    assert isinstance(tensor.iterator, cute.Pointer)
+    # HACK: we assume that applying the offset does not change the pointer alignment
+    new_ptr = cute.make_ptr(
+        tensor.element_type,
+        tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+        tensor.memspace,
+        assumed_align=tensor.iterator.max_alignment,
+    )
+    return cute.make_tensor(new_ptr, tensor.layout)
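The new `domain_offset_i64` helper exists because a flat element offset into a tensor with more than 2^31 elements no longer fits in a signed 32-bit integer, so the kernels above fold the per-CTA row offset into the base pointer with Int64 arithmetic before tiling. A small plain-Python illustration of the threshold (not part of the package):

```python
# One of the large-tensor shapes exercised in tests/test_rmsnorm.py below.
M, N = 32 * 1024, 131072
print(M * N)                        # 4_294_967_296 elements, i.e. more than 2**31
row_offset = (M - 1) * N            # element offset of the last CTA's row
print(row_offset > 2**31 - 1)       # True: an int32 offset would overflow here
byte_offset = row_offset * 16 // 8  # same scaling as offset * element_type.width // 8 for bf16
print(byte_offset < 2**63)          # True: comfortably representable as Int64
```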
{quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack_kernels.egg-info/SOURCES.txt

@@ -4,6 +4,7 @@ pyproject.toml
 setup.py
 quack/__init__.py
 quack/cross_entropy.py
+quack/layernorm.py
 quack/reduction_base.py
 quack/rmsnorm.py
 quack/softmax.py
@@ -14,5 +15,6 @@ quack_kernels.egg-info/dependency_links.txt
 quack_kernels.egg-info/requires.txt
 quack_kernels.egg-info/top_level.txt
 tests/test_cross_entropy.py
+tests/test_layernorm.py
 tests/test_rmsnorm.py
 tests/test_softmax.py
{quack_kernels-0.1.5 → quack_kernels-0.1.6}/tests/test_cross_entropy.py

@@ -16,18 +16,19 @@ import cutlass
 )
 @pytest.mark.parametrize("M", [1, 77, 289])
 # @pytest.mark.parametrize("M", [1])
-def
+def test_cross_entropy(M, N, input_dtype):
     """Test Cross Entropy forward pass against reference implementation."""
     device = "cuda"
-    atol, rtol =
+    atol, rtol = 5e-5, 1e-5
     torch.random.manual_seed(0)
+    cutlass.cuda.initialize_cuda_context()
     # Create input tensors (scale down to avoid overflow)
-    x = 0.1 * torch.randn(M, N, device=device, dtype=input_dtype
+    x = (0.1 * torch.randn(M, N, device=device, dtype=input_dtype)).requires_grad_()
     target = torch.randint(0, N, (M,), device=device, dtype=torch.int64)
-    x_ref = x.detach().clone()
+    x_ref = x.detach().clone().requires_grad_()
     target_ref = target.detach().clone()
     # Forward pass
-    loss =
+    loss = cross_entropy(x, target)
     loss_ref = F.cross_entropy(x_ref.float(), target_ref, reduction='none')
     # Check output shape and dtype
     assert loss.shape == (M,)
@@ -40,6 +41,13 @@ def test_cross_entropy_forward(M, N, input_dtype):
     # Check that loss is reasonable (not inf or nan)
     assert not torch.isnan(loss).any()
     assert not torch.isinf(loss).any()
+    # Test backward pass
+    dloss = torch.randn_like(loss)
+    torch.cuda.synchronize()
+    dx_ref, = torch.autograd.grad(loss_ref, x_ref, grad_outputs=dloss)
+    dx, = torch.autograd.grad(loss, x, grad_outputs=dloss)
+    assert dx.shape == x.shape
+    torch.testing.assert_close(dx, dx_ref.to(input_dtype), atol=atol, rtol=rtol)
 
 
 @pytest.mark.parametrize("input_dtype", [torch.float16, torch.float32])
@@ -99,47 +107,3 @@ def test_cross_entropy_edge_targets():
     loss_last = _cross_entropy(x, target_last)
     loss_ref_last = F.cross_entropy(x, target_last, reduction='none')
     torch.testing.assert_close(loss_last, loss_ref_last, atol=1e-4, rtol=1e-4)
-
-
-
-
-
-@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32])
-@pytest.mark.parametrize(
-    "N",
-    [192, 256, 512, 760, 1024, 1128, 2048, 4096, 8192, 16384, 32768, 65536, 128256, 131072, 256128, 262144]  # A representative subset to keep compile time reasonable
-)
-@pytest.mark.parametrize("M", [1, 37, 77])
-def test_cross_entropy_autograd_backward(M, N, input_dtype):
-    device = "cuda"
-
-    if input_dtype == torch.bfloat16:
-        atol = 1e-3
-        rtol = 1e-3
-    else:
-        atol = 1e-5
-        rtol = 1e-5
-
-    torch.random.manual_seed(0)
-
-    x = 0.1 * torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True)
-    target = torch.randint(0, N, (M,), device=device, dtype=torch.int64)
-
-    x_ref = x.detach().clone().requires_grad_(True)
-    target_ref = target.detach().clone()
-
-    cutlass.cuda.initialize_cuda_context()
-
-    loss = cross_entropy(x, target)  # our autograd-enabled op
-    loss_ref = F.cross_entropy(x_ref.float(), target_ref, reduction='none')
-
-    torch.testing.assert_close(loss, loss_ref, atol=atol, rtol=rtol)
-
-    dloss = torch.randn_like(loss)
-
-    dx_ref, = torch.autograd.grad(loss_ref, x_ref, grad_outputs=dloss)
-
-    dx, = torch.autograd.grad(loss, x, grad_outputs=dloss)
-
-    assert dx.shape == x.shape
-    torch.testing.assert_close(dx, dx_ref.to(input_dtype), atol=atol, rtol=rtol)
quack_kernels-0.1.6/tests/test_layernorm.py (new file)

@@ -0,0 +1,162 @@
+# tests/test_layernorm.py
+
+import pytest
+import torch
+
+from quack.layernorm import layernorm, layernorm_ref, rstd_ref, mean_ref
+
+
+@pytest.mark.parametrize("eps", [1e-5, 1e-6])
+@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32])
+@pytest.mark.parametrize("M", [1, 37, 199])
+@pytest.mark.parametrize(
+    "N", [256, 512, 760, 1024, 1128, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144]
+)  # , 32768])
+def test_layernorm_forward(M, N, input_dtype, eps):
+    """Test LayerNorm forward pass against reference implementation."""
+    device = "cuda"
+
+    # tolerance depends on precision
+    if input_dtype == torch.bfloat16:
+        atol = 1e-2
+        rtol = 1e-2
+    elif input_dtype == torch.float16:
+        atol = 1e-3
+        rtol = 1e-3
+    else:
+        atol = 1e-4
+        rtol = 1e-4
+
+    torch.random.manual_seed(0)
+    x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True)
+    weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
+
+    # pure‐PyTorch refs
+    x_ref = x.detach().clone().requires_grad_()
+    weight_ref = weight.detach().clone().requires_grad_()
+
+    out, rstd, mean = layernorm(x, weight, eps=eps, return_rstd=True, return_mean=True)
+    out_ref = layernorm_ref(x_ref, weight_ref, eps=eps)
+    rstd_ref_val = rstd_ref(x_ref, eps=eps)
+    mean_ref_val = mean_ref(x_ref)
+
+    # shapes & dtypes
+    assert out.shape == x.shape
+    assert out.dtype == input_dtype
+    assert rstd.shape == (M,) and rstd.dtype == torch.float32
+    assert mean.shape == (M,) and mean.dtype == torch.float32
+
+    # numeric check
+    torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)
+    torch.testing.assert_close(rstd, rstd_ref_val, atol=6e-4, rtol=6e-4)
+    torch.testing.assert_close(mean, mean_ref_val, atol=6e-4, rtol=6e-4)
+
+
+@pytest.mark.parametrize("return_rstd", [True, False])
+@pytest.mark.parametrize("return_mean", [True, False])
+def test_layernormnorm_return_rstd_option(return_rstd, return_mean):
+    """Test that return_rstd option works correctly."""
+    device = "cuda"
+    M, N = 32, 1024
+    eps = 1e-6
+
+    x = torch.randn(M, N, device=device, dtype=torch.float16)
+    weight = torch.randn(N, device=device, dtype=torch.float32)
+
+    if return_rstd and return_mean:
+        out, rstd, mean = layernorm(x, weight, eps=eps, return_rstd=True, return_mean=True)
+        assert out.shape == (M, N)
+        assert rstd.shape == (M,)
+        assert rstd.dtype == torch.float32
+        assert mean.shape == (M,)
+        assert mean.dtype == torch.float32
+    elif return_rstd and not return_mean:
+        out, rstd = layernorm(x, weight, eps=eps, return_rstd=True, return_mean=False)
+        assert out.shape == (M, N)
+        assert rstd.shape == (M,)
+        assert rstd.dtype == torch.float32
+    elif not return_rstd and return_mean:
+        out, mean = layernorm(x, weight, eps=eps, return_rstd=False, return_mean=True)
+        assert out.shape == (M, N)
+        assert mean.shape == (M,)
+        assert mean.dtype == torch.float32
+    else:
+        out = layernorm(x, weight, eps=eps, return_rstd=False, return_mean=False)
+        assert out.shape == (M, N)
+        assert isinstance(out, torch.Tensor)
+
+
+def test_layernorm_input_validation():
+    """Test input validation and error handling."""
+    device = "cuda"
+
+    # Test 3D input (should fail)
+    x_3d = torch.randn(2, 32, 1024, device=device, dtype=torch.float16)
+    weight = torch.randn(1024, device=device, dtype=torch.float32)
+
+    with pytest.raises(AssertionError, match="Input must be 2D"):
+        layernorm(x_3d, weight)
+
+    # Test weight dimension mismatch
+    x = torch.randn(32, 1024, device=device, dtype=torch.float16)
+    weight_wrong = torch.randn(512, device=device, dtype=torch.float32)
+
+    with pytest.raises(AssertionError, match="Last dimension of input must match weight dimension"):
+        layernorm(x, weight_wrong)
+
+    # Test CPU tensors (should fail)
+    x_cpu = torch.randn(32, 1024, dtype=torch.float16)
+    weight_cpu = torch.randn(1024, dtype=torch.float32)
+
+    with pytest.raises(AssertionError, match="Tensors must be on CUDA device"):
+        layernorm(x_cpu, weight_cpu)
+
+    # Test unsupported dtype
+    x = torch.randn(32, 1024, device=device, dtype=torch.float64)
+    weight = torch.randn(1024, device=device, dtype=torch.float32)
+
+    with pytest.raises(AssertionError, match="Unsupported dtype"):
+        layernorm(x, weight)
+
+    # Test wrong weight dtype
+    x = torch.randn(32, 1024, device=device, dtype=torch.float16)
+    weight_wrong_dtype = torch.randn(1024, device=device, dtype=torch.float16)
+
+    with pytest.raises(AssertionError, match="Weight must be float32"):
+        layernorm(x, weight_wrong_dtype)
+
+
+def test_layernorm_compile_cache():
+    """Test that compile cache works correctly for repeated calls."""
+    device = "cuda"
+    M, N = 32, 1024
+    eps = 1e-6
+
+    # Clear cache
+    layernorm.compile_cache.clear()
+    assert len(layernorm.compile_cache) == 0
+
+    x1 = torch.randn(M, N, device=device, dtype=torch.float16)
+    weight1 = torch.randn(N, device=device, dtype=torch.float32)
+
+    # First call should compile
+    out1 = layernorm(x1, weight1, eps=eps)
+    assert len(layernorm.compile_cache) == 1
+
+    # Same shape should reuse cache
+    x2 = torch.randn(M, N, device=device, dtype=torch.float16)
+    weight2 = torch.randn(N, device=device, dtype=torch.float32)
+    out2 = layernorm(x2, weight2, eps=eps)
+    assert len(layernorm.compile_cache) == 1
+
+    # Different shape should create new cache entry
+    x3 = torch.randn(M, N * 2, device=device, dtype=torch.float16)
+    weight3 = torch.randn(N * 2, device=device, dtype=torch.float32)
+    out3 = layernorm(x3, weight3, eps=eps)
+    assert len(layernorm.compile_cache) == 2
+
+    # Different dtype should create new cache entry
+    x4 = torch.randn(M, N, device=device, dtype=torch.float32)
+    weight4 = torch.randn(N, device=device, dtype=torch.float32)
+    out4 = layernorm(x4, weight4, eps=eps)
+    assert len(layernorm.compile_cache) == 3
{quack_kernels-0.1.5 → quack_kernels-0.1.6}/tests/test_rmsnorm.py

@@ -32,19 +32,17 @@ def test_rmsnorm_forward(M, N, input_dtype, eps):
     weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
     x_ref = x.detach().clone().requires_grad_()
     weight_ref = weight.detach().clone().requires_grad_()
-    out
+    out = rmsnorm(x, weight, eps=eps)
     out_ref = rmsnorm_ref(x_ref, weight_ref, eps=eps)
-    rstd_ref_val = rstd_ref(x_ref, eps=eps)
+    # rstd_ref_val = rstd_ref(x_ref, eps=eps)
 
     # Check output shape and dtype
     assert out.shape == x.shape
     assert out.dtype == input_dtype
-    assert rstd.shape == (M,)
-    assert rstd.dtype == torch.float32
 
     # Check accuracy
     torch.testing.assert_close(out, out_ref, atol=atol, rtol=1e-3)
-    torch.testing.assert_close(rstd, rstd_ref_val, atol=atol, rtol=1e-3)
+    # torch.testing.assert_close(rstd, rstd_ref_val, atol=atol, rtol=1e-3)
 
 
 # @pytest.mark.parametrize("eps", [1e-5, 1e-6])
@@ -88,6 +86,39 @@ def test_rmsnorm_forward(M, N, input_dtype, eps):
     # torch.testing.assert_close(weight.grad, weight_ref.grad, atol=atol, rtol=1e-3)
 
 
+@pytest.mark.parametrize("eps", [1e-5])
+@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
+@pytest.mark.parametrize(
+    "N",
+    [131072, 262144]
+    # [262144]
+)
+@pytest.mark.parametrize("M", [32 * 1024])
+def test_rmsnorm_large_tensor(M, N, input_dtype, eps):
+    """Test RMSNorm forward pass against reference implementation."""
+    device = "cuda"
+    # Set tolerance based on dtype
+    if input_dtype == torch.bfloat16:
+        atol = 1e-1
+    elif input_dtype == torch.float16:
+        atol = 1e-2
+    else:
+        atol = 1e-4
+    torch.random.manual_seed(0)
+    torch.cuda.empty_cache()
+    x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=False)
+    weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=False)
+    out = rmsnorm(x, weight, eps=eps)
+    # Need to compile, otherwise it OOMs
+    rmsnorm_compiled = torch.compile(rmsnorm_ref)
+    # Run once with smaller input to avoid OOMs
+    rmsnorm_compiled(x[:32], weight, eps=eps)
+    out_ref = rmsnorm_compiled(x, weight, eps=eps)
+    # Need to chunk, otherwise it OOMs
+    assert all((out_c - out_ref_c).abs().max() < atol
+               for out_c, out_ref_c in zip(out.chunk(16), out_ref.chunk(16)))
+
+
 @pytest.mark.parametrize("return_rstd", [True, False])
 def test_rmsnorm_return_rstd_option(return_rstd):
     """Test that return_rstd option works correctly."""
{quack_kernels-0.1.5 → quack_kernels-0.1.6}/tests/test_softmax.py

@@ -34,7 +34,7 @@ def test_softmax(M, N, input_dtype):
 
     torch.random.manual_seed(0)
     # Create input tensors (scale down to avoid overflow in softmax)
-    x = 0.1 * torch.randn(M, N, device=device, dtype=input_dtype
+    x = (0.1 * torch.randn(M, N, device=device, dtype=input_dtype)).requires_grad_()
     x_ref = x.detach().clone().requires_grad_(True)
 
     # Forward pass
@@ -58,13 +58,12 @@ def test_softmax(M, N, input_dtype):
 
     # Test backward pass
     dy = torch.randn_like(out)
+    torch.cuda.synchronize()  # without sync, torch.autograd gets wrong results
    dx_ref, = torch.autograd.grad(out_ref, x_ref, grad_outputs=dy)
     # Call our implementation later, otherwise getting CUDA_ERROR_INVALID_CONTEXT
     dx, = torch.autograd.grad(out, x, grad_outputs=dy)
-    # Check output shape and dtype
     assert dx.shape == dy.shape
     assert dx.dtype == input_dtype
-    # Check accuracy against reference
     torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol)
 
 
7 files without changes: LICENSE, pyproject.toml, quack/reduction_base.py, quack_kernels.egg-info/dependency_links.txt, quack_kernels.egg-info/requires.txt, setup.cfg, setup.py