quack-kernels 0.1.7__tar.gz → 0.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {quack_kernels-0.1.7/quack_kernels.egg-info → quack_kernels-0.1.8}/PKG-INFO +1 -1
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/__init__.py +1 -1
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/cross_entropy.py +56 -15
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/rmsnorm.py +191 -68
- {quack_kernels-0.1.7 → quack_kernels-0.1.8/quack_kernels.egg-info}/PKG-INFO +1 -1
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack_kernels.egg-info/top_level.txt +1 -0
- quack_kernels-0.1.8/tests/test_rmsnorm.py +392 -0
- quack_kernels-0.1.7/tests/test_rmsnorm.py +0 -183
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/LICENSE +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/README.md +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/pyproject.toml +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/layernorm.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/reduction_base.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/softmax.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/utils.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack_kernels.egg-info/SOURCES.txt +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack_kernels.egg-info/dependency_links.txt +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack_kernels.egg-info/requires.txt +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/setup.cfg +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/setup.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/tests/test_cross_entropy.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/tests/test_layernorm.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.8}/tests/test_softmax.py +0 -0
|
@@ -1,16 +1,16 @@
|
|
|
1
1
|
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
2
|
|
|
3
3
|
import math
|
|
4
|
-
import torch
|
|
5
4
|
from typing import Optional, Type
|
|
6
5
|
|
|
7
6
|
import cuda.bindings.driver as cuda
|
|
8
7
|
|
|
9
8
|
import cutlass
|
|
10
9
|
import cutlass.cute as cute
|
|
11
|
-
from cutlass.cute.runtime import from_dlpack
|
|
12
10
|
|
|
13
11
|
import quack.utils as utils
|
|
12
|
+
import torch
|
|
13
|
+
from cutlass.cute.runtime import from_dlpack
|
|
14
14
|
from quack.reduction_base import ReductionBase, torch2cute_dtype_map
|
|
15
15
|
|
|
16
16
|
|
|
@@ -79,7 +79,7 @@ class CrossEntropy(ReductionBase):
|
|
|
79
79
|
self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
|
|
80
80
|
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
81
81
|
block=[num_threads, 1, 1],
|
|
82
|
-
cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
82
|
+
cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
|
|
83
83
|
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
84
84
|
stream=stream,
|
|
85
85
|
)
|
|
@@ -111,7 +111,9 @@ class CrossEntropy(ReductionBase):
|
|
|
111
111
|
|
|
112
112
|
smem = cutlass.utils.SmemAllocator()
|
|
113
113
|
sX = smem.allocate_tensor(
|
|
114
|
-
mX.element_type,
|
|
114
|
+
mX.element_type,
|
|
115
|
+
cute.make_ordered_layout(tiler_mn, order=(1, 0)),
|
|
116
|
+
byte_alignment=16,
|
|
115
117
|
)
|
|
116
118
|
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
|
|
117
119
|
|
|
@@ -166,7 +168,9 @@ class CrossEntropy(ReductionBase):
|
|
|
166
168
|
reduction_buffer[None, None, 0],
|
|
167
169
|
mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
168
170
|
init_val=-cutlass.Float32.inf,
|
|
169
|
-
hook_fn=
|
|
171
|
+
hook_fn=(
|
|
172
|
+
cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
|
|
173
|
+
),
|
|
170
174
|
)
|
|
171
175
|
if cutlass.const_expr(self.reload_from == "smem"):
|
|
172
176
|
cute.autovec_copy(tXsX, tXrX)
|
|
@@ -191,7 +195,9 @@ class CrossEntropy(ReductionBase):
|
|
|
191
195
|
threads_per_row,
|
|
192
196
|
reduction_buffer[None, None, 0],
|
|
193
197
|
mbar_ptr,
|
|
194
|
-
hook_fn=
|
|
198
|
+
hook_fn=(
|
|
199
|
+
cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
|
|
200
|
+
),
|
|
195
201
|
)
|
|
196
202
|
|
|
197
203
|
if (
|
|
@@ -225,7 +231,11 @@ def _cross_entropy(
|
|
|
225
231
|
assert target.dim() == 1, "Target must be 1D"
|
|
226
232
|
assert x.shape[0] == target.shape[0], "Batch dimensions must match"
|
|
227
233
|
assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
|
|
228
|
-
assert x.dtype in [
|
|
234
|
+
assert x.dtype in [
|
|
235
|
+
torch.float16,
|
|
236
|
+
torch.bfloat16,
|
|
237
|
+
torch.float32,
|
|
238
|
+
], "Unsupported input dtype"
|
|
229
239
|
assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
|
|
230
240
|
M, N = x.shape
|
|
231
241
|
device = x.device
|
|
@@ -314,13 +324,16 @@ class CrossEntropyBackward:
|
|
|
314
324
|
num_threads = cute.size(tv_layout, mode=[0])
|
|
315
325
|
|
|
316
326
|
mDLoss = cute.make_tensor(
|
|
317
|
-
mDLoss.iterator,
|
|
327
|
+
mDLoss.iterator,
|
|
328
|
+
cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,))),
|
|
318
329
|
)
|
|
319
330
|
mTarget = cute.make_tensor(
|
|
320
|
-
mTarget.iterator,
|
|
331
|
+
mTarget.iterator,
|
|
332
|
+
cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
|
|
321
333
|
)
|
|
322
334
|
mLSE = cute.make_tensor(
|
|
323
|
-
mLSE.iterator,
|
|
335
|
+
mLSE.iterator,
|
|
336
|
+
cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
|
|
324
337
|
)
|
|
325
338
|
|
|
326
339
|
smem_size = cute.size_in_bytes(
|
|
@@ -364,7 +377,9 @@ class CrossEntropyBackward:
|
|
|
364
377
|
|
|
365
378
|
smem = cutlass.utils.SmemAllocator()
|
|
366
379
|
sX = smem.allocate_tensor(
|
|
367
|
-
mX.element_type,
|
|
380
|
+
mX.element_type,
|
|
381
|
+
cute.make_ordered_layout(tiler_mn, order=(1, 0)),
|
|
382
|
+
byte_alignment=16,
|
|
368
383
|
)
|
|
369
384
|
|
|
370
385
|
idX = cute.make_identity_tensor(shape)
|
|
@@ -474,7 +489,11 @@ def _cross_entropy_backward(
|
|
|
474
489
|
assert (
|
|
475
490
|
x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
|
|
476
491
|
), "Tensors must be on CUDA device"
|
|
477
|
-
assert x.dtype in [
|
|
492
|
+
assert x.dtype in [
|
|
493
|
+
torch.float16,
|
|
494
|
+
torch.bfloat16,
|
|
495
|
+
torch.float32,
|
|
496
|
+
], "Unsupported input dtype"
|
|
478
497
|
assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
|
|
479
498
|
|
|
480
499
|
M, N = x.shape
|
|
@@ -532,15 +551,37 @@ class CrossEntropyFunction(torch.autograd.Function):
|
|
|
532
551
|
|
|
533
552
|
|
|
534
553
|
def cross_entropy(
|
|
535
|
-
x: torch.Tensor,
|
|
554
|
+
x: torch.Tensor,
|
|
555
|
+
target: torch.Tensor,
|
|
556
|
+
inplace_backward: bool = True,
|
|
557
|
+
reduction: str = "none",
|
|
536
558
|
) -> torch.Tensor:
|
|
537
559
|
"""Cross entropy loss with automatic differentiation support.
|
|
538
560
|
|
|
539
561
|
Args:
|
|
540
562
|
x: Input logits tensor of shape (M, N)
|
|
541
563
|
target: Target class indices tensor of shape (M,)
|
|
564
|
+
inplace_backward: Whether to perform backward pass in-place
|
|
565
|
+
reduction: Specifies the reduction to apply to the output:
|
|
566
|
+
'none': no reduction will be applied (default)
|
|
567
|
+
'mean': the sum of the output will be divided by the number of elements
|
|
568
|
+
'sum': the output will be summed
|
|
542
569
|
|
|
543
570
|
Returns:
|
|
544
|
-
Cross entropy loss tensor
|
|
571
|
+
Cross entropy loss tensor:
|
|
572
|
+
- If reduction='none': tensor of shape (M,) with per-example losses
|
|
573
|
+
- If reduction='mean': scalar tensor with mean loss
|
|
574
|
+
- If reduction='sum': scalar tensor with sum of losses
|
|
545
575
|
"""
|
|
546
|
-
|
|
576
|
+
loss = CrossEntropyFunction.apply(x, target, inplace_backward)
|
|
577
|
+
|
|
578
|
+
if reduction == "mean":
|
|
579
|
+
return loss.mean()
|
|
580
|
+
elif reduction == "sum":
|
|
581
|
+
return loss.sum()
|
|
582
|
+
elif reduction == "none":
|
|
583
|
+
return loss
|
|
584
|
+
else:
|
|
585
|
+
raise ValueError(
|
|
586
|
+
f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'"
|
|
587
|
+
)
|
|
@@ -1,14 +1,16 @@
|
|
|
1
1
|
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
2
|
|
|
3
|
-
import torch
|
|
4
3
|
from typing import Optional
|
|
5
4
|
|
|
6
5
|
import cuda.bindings.driver as cuda
|
|
7
6
|
|
|
8
7
|
import cutlass
|
|
9
8
|
import cutlass.cute as cute
|
|
9
|
+
from cutlass import Float32, Int32
|
|
10
10
|
from cutlass.cute.runtime import from_dlpack
|
|
11
|
+
|
|
11
12
|
import quack.utils as utils
|
|
13
|
+
import torch
|
|
12
14
|
from quack.reduction_base import ReductionBase, torch2cute_dtype_map
|
|
13
15
|
|
|
14
16
|
|
|
@@ -19,41 +21,55 @@ class RMSNorm(ReductionBase):
|
|
|
19
21
|
self.delay_w_load = False
|
|
20
22
|
|
|
21
23
|
def _calculate_threads_per_row(self):
|
|
24
|
+
"""Calculate the number of threads per row for the RMSNorm kernel."""
|
|
22
25
|
N = self.N
|
|
23
|
-
|
|
24
|
-
8
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
26
|
+
if N <= 64:
|
|
27
|
+
return 8
|
|
28
|
+
elif N <= 128:
|
|
29
|
+
return 16
|
|
30
|
+
elif N <= 3072:
|
|
31
|
+
return 32
|
|
32
|
+
elif N <= 6144:
|
|
33
|
+
return 64
|
|
34
|
+
elif N <= 16384:
|
|
35
|
+
return 128
|
|
36
|
+
else:
|
|
37
|
+
return 256
|
|
32
38
|
|
|
33
39
|
def _set_cluster_n(self):
|
|
40
|
+
"""
|
|
41
|
+
Set the number of clusters for the RMSNorm kernel.
|
|
42
|
+
Stored in self.cluster_n.
|
|
43
|
+
"""
|
|
34
44
|
N = self.N
|
|
45
|
+
|
|
35
46
|
# cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
|
|
36
47
|
# Similarly cluster_n = 8 is faster for N=128k
|
|
37
48
|
if cutlass.const_expr(self.dtype.width == 16):
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
49
|
+
# 16-bit types (fp16, bf16)
|
|
50
|
+
if N <= 16 * 1024:
|
|
51
|
+
cluster_n = 1
|
|
52
|
+
elif N <= 32 * 1024:
|
|
53
|
+
cluster_n = 2
|
|
54
|
+
elif N <= 64 * 1024:
|
|
55
|
+
cluster_n = 4
|
|
56
|
+
elif N <= 128 * 1024:
|
|
57
|
+
cluster_n = 8
|
|
58
|
+
else:
|
|
59
|
+
cluster_n = 16
|
|
60
|
+
else:
|
|
61
|
+
# 32-bit types (fp32)
|
|
62
|
+
if N <= 32 * 1024:
|
|
63
|
+
cluster_n = 1
|
|
64
|
+
elif N <= 64 * 1024:
|
|
65
|
+
cluster_n = 2
|
|
66
|
+
elif N <= 128 * 1024:
|
|
67
|
+
cluster_n = 4
|
|
68
|
+
elif N <= 256 * 1024:
|
|
69
|
+
cluster_n = 8
|
|
70
|
+
else:
|
|
71
|
+
cluster_n = 16
|
|
72
|
+
|
|
57
73
|
self.cluster_n = cluster_n
|
|
58
74
|
|
|
59
75
|
@cute.jit
|
|
@@ -64,8 +80,17 @@ class RMSNorm(ReductionBase):
|
|
|
64
80
|
mO: cute.Tensor,
|
|
65
81
|
mRstd: Optional[cute.Tensor],
|
|
66
82
|
stream: cuda.CUstream,
|
|
67
|
-
eps:
|
|
83
|
+
eps: Float32 = 1e-6,
|
|
68
84
|
):
|
|
85
|
+
semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
|
|
86
|
+
new_stride = lambda t: (
|
|
87
|
+
cute.assume(t.stride[0], divby=128 // t.element_type.width),
|
|
88
|
+
t.stride[1],
|
|
89
|
+
)
|
|
90
|
+
mX, mO = [
|
|
91
|
+
cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
|
|
92
|
+
for t in (mX, mO)
|
|
93
|
+
]
|
|
69
94
|
assert mX.element_type == self.dtype
|
|
70
95
|
assert mO.element_type == self.dtype
|
|
71
96
|
self._set_cluster_n()
|
|
@@ -82,7 +107,7 @@ class RMSNorm(ReductionBase):
|
|
|
82
107
|
self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
|
|
83
108
|
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
84
109
|
block=[num_threads, 1, 1],
|
|
85
|
-
cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
110
|
+
cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
|
|
86
111
|
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
87
112
|
stream=stream,
|
|
88
113
|
)
|
|
@@ -109,7 +134,9 @@ class RMSNorm(ReductionBase):
|
|
|
109
134
|
|
|
110
135
|
smem = cutlass.utils.SmemAllocator()
|
|
111
136
|
sX = smem.allocate_tensor(
|
|
112
|
-
mX.element_type,
|
|
137
|
+
mX.element_type,
|
|
138
|
+
cute.make_ordered_layout(tiler_mn, order=(1, 0)),
|
|
139
|
+
byte_alignment=16,
|
|
113
140
|
)
|
|
114
141
|
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
|
|
115
142
|
|
|
@@ -184,7 +211,7 @@ class RMSNorm(ReductionBase):
|
|
|
184
211
|
reduction_buffer[None, None, 0],
|
|
185
212
|
mbar_ptr,
|
|
186
213
|
init_val=0.0,
|
|
187
|
-
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
214
|
+
hook_fn=(cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None),
|
|
188
215
|
)
|
|
189
216
|
rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
|
|
190
217
|
if cutlass.const_expr(mRstd is not None):
|
|
@@ -232,25 +259,36 @@ def _rmsnorm_fwd(
|
|
|
232
259
|
assert weight.dim() == 1, "Weight must be 1D"
|
|
233
260
|
assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
|
|
234
261
|
assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
|
|
235
|
-
assert x.dtype in [
|
|
236
|
-
|
|
262
|
+
assert x.dtype in [
|
|
263
|
+
torch.float16,
|
|
264
|
+
torch.bfloat16,
|
|
265
|
+
torch.float32,
|
|
266
|
+
], "Unsupported dtype"
|
|
267
|
+
|
|
268
|
+
assert weight.dtype in [
|
|
269
|
+
torch.float32,
|
|
270
|
+
torch.bfloat16,
|
|
271
|
+
torch.float16,
|
|
272
|
+
], "Weight must be float32, float16 or bfloat16"
|
|
273
|
+
|
|
237
274
|
M, N = x.shape
|
|
238
275
|
device = x.device
|
|
239
276
|
out = torch.empty_like(x)
|
|
240
277
|
rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
|
|
241
278
|
dtype = torch2cute_dtype_map[x.dtype]
|
|
279
|
+
# convert_from_dlpack = lambda x: (
|
|
280
|
+
# from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
|
|
281
|
+
# mode=0, divisibility=128 // dtype.width
|
|
282
|
+
# )
|
|
283
|
+
# )
|
|
242
284
|
convert_from_dlpack = lambda x: (
|
|
243
|
-
from_dlpack(x.detach(), assumed_align=16).
|
|
244
|
-
mode=0, stride_order=(0, 1)
|
|
245
|
-
)
|
|
285
|
+
from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
|
|
246
286
|
)
|
|
247
|
-
x_tensor, out_tensor = [
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
for t in (x, out)
|
|
251
|
-
]
|
|
287
|
+
x_tensor, out_tensor = [convert_from_dlpack(t) for t in (x, out)]
|
|
288
|
+
# handle weight divisibility based on weight dtype
|
|
289
|
+
weight_dtype = torch2cute_dtype_map[weight.dtype]
|
|
252
290
|
weight_tensor = utils.convert_from_dlpack(
|
|
253
|
-
weight.detach(), leading_dim=0, divisibility=128 //
|
|
291
|
+
weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
|
|
254
292
|
)
|
|
255
293
|
rstd_tensor = (
|
|
256
294
|
from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
@@ -258,7 +296,7 @@ def _rmsnorm_fwd(
|
|
|
258
296
|
else None
|
|
259
297
|
)
|
|
260
298
|
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
261
|
-
compile_key = (dtype, N, rstd is not None)
|
|
299
|
+
compile_key = (dtype, N, rstd is not None, weight.dtype)
|
|
262
300
|
if compile_key not in _rmsnorm_fwd.compile_cache:
|
|
263
301
|
rmsnorm_op = RMSNorm(dtype, N)
|
|
264
302
|
_rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
|
|
@@ -301,7 +339,8 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
|
|
|
301
339
|
class RMSNormBackward(ReductionBase):
|
|
302
340
|
def __init__(self, dtype: cutlass.Numeric, N: int):
|
|
303
341
|
# 2 stages for double buffering when computing mean of x_hat * wdy
|
|
304
|
-
super().__init__(dtype, N, stage=2, reduction_dtype=
|
|
342
|
+
super().__init__(dtype, N, stage=2, reduction_dtype=Float32)
|
|
343
|
+
self.reload_wdy = None if N <= 16 * 1024 else "smem"
|
|
305
344
|
if self.N > 128 * 1024 and self.dtype.width >= 32:
|
|
306
345
|
# Not enough smem
|
|
307
346
|
raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
|
|
@@ -348,9 +387,18 @@ class RMSNormBackward(ReductionBase):
|
|
|
348
387
|
mRstd: cute.Tensor,
|
|
349
388
|
mdX: cute.Tensor,
|
|
350
389
|
mdW: cute.Tensor,
|
|
351
|
-
sm_count:
|
|
390
|
+
sm_count: Int32,
|
|
352
391
|
stream: cuda.CUstream,
|
|
353
392
|
):
|
|
393
|
+
semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
|
|
394
|
+
new_stride = lambda t: (
|
|
395
|
+
cute.assume(t.stride[0], divby=128 // t.element_type.width),
|
|
396
|
+
t.stride[1],
|
|
397
|
+
)
|
|
398
|
+
mX, mdOut, mdX = [
|
|
399
|
+
cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
|
|
400
|
+
for t in (mX, mdOut, mdX)
|
|
401
|
+
]
|
|
354
402
|
self._set_cluster_n()
|
|
355
403
|
tiler_mn, tv_layout = self._get_tv_layout()
|
|
356
404
|
num_threads = cute.size(tv_layout, mode=[0])
|
|
@@ -460,7 +508,8 @@ class RMSNormBackward(ReductionBase):
|
|
|
460
508
|
|
|
461
509
|
gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
|
|
462
510
|
tdWgdW = thr_copy_dW.partition_D(gdW)
|
|
463
|
-
|
|
511
|
+
# Always compute partial weight gradients in fp32
|
|
512
|
+
tdWrdW = cute.make_fragment_like(tdWgdW, Float32)
|
|
464
513
|
tXrdW = thr_copy_X.retile(tdWrdW)
|
|
465
514
|
|
|
466
515
|
gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
|
|
@@ -513,9 +562,9 @@ class RMSNormBackward(ReductionBase):
|
|
|
513
562
|
|
|
514
563
|
threads_per_row = tv_layout.shape[0][0]
|
|
515
564
|
tXrdW.fill(0.0)
|
|
516
|
-
stage =
|
|
517
|
-
producer_phase =
|
|
518
|
-
consumer_phase =
|
|
565
|
+
stage = Int32(0)
|
|
566
|
+
producer_phase = Int32(1)
|
|
567
|
+
consumer_phase = Int32(0)
|
|
519
568
|
for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
|
|
520
569
|
row = tXcX[None, None, None, bidx][0][0]
|
|
521
570
|
rstd = cutlass.Float.zero
|
|
@@ -538,10 +587,14 @@ class RMSNormBackward(ReductionBase):
|
|
|
538
587
|
)
|
|
539
588
|
elif tiler_mn[0] > 1:
|
|
540
589
|
utils.fill_oob(
|
|
541
|
-
tXsX[None, None, None, stage ^ 1],
|
|
590
|
+
tXsX[None, None, None, stage ^ 1],
|
|
591
|
+
None,
|
|
592
|
+
fill_value=mX.element_type.zero,
|
|
542
593
|
)
|
|
543
594
|
utils.fill_oob(
|
|
544
|
-
tXsdOut[None, None, None, stage ^ 1],
|
|
595
|
+
tXsdOut[None, None, None, stage ^ 1],
|
|
596
|
+
None,
|
|
597
|
+
fill_value=mdOut.element_type.zero,
|
|
545
598
|
)
|
|
546
599
|
cute.arch.cp_async_commit_group()
|
|
547
600
|
if row < M or tiler_mn[0] == 1:
|
|
@@ -561,12 +614,13 @@ class RMSNormBackward(ReductionBase):
|
|
|
561
614
|
cute.ReductionOp.ADD,
|
|
562
615
|
threads_per_row,
|
|
563
616
|
reduction_buffer[None, None, stage],
|
|
564
|
-
mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
617
|
+
(mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None),
|
|
565
618
|
phase=consumer_phase,
|
|
566
619
|
init_val=0.0,
|
|
567
620
|
)
|
|
568
621
|
/ shape[1]
|
|
569
622
|
)
|
|
623
|
+
|
|
570
624
|
if cutlass.const_expr(self.cluster_n > 1):
|
|
571
625
|
# It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
|
|
572
626
|
# Requires adjusting the thread_count when initializing the mbar
|
|
@@ -576,12 +630,20 @@ class RMSNormBackward(ReductionBase):
|
|
|
576
630
|
cute.arch.mbarrier_arrive(
|
|
577
631
|
mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
|
|
578
632
|
)
|
|
633
|
+
|
|
634
|
+
if cutlass.const_expr(self.reload_wdy == "smem"):
|
|
635
|
+
cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
|
|
636
|
+
dout = tXrdOut.load().to(cute.Float32)
|
|
637
|
+
wdy = dout * weight
|
|
638
|
+
|
|
579
639
|
dx = (wdy - x_hat * mean_xhat_wdy) * rstd
|
|
580
640
|
tXrdX.store(dx.to(tXrdOut.element_type))
|
|
581
641
|
if row < M or tiler_mn[0] == 1:
|
|
582
642
|
tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
|
|
583
643
|
cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
|
|
644
|
+
# Accumulate weight gradients in fp32
|
|
584
645
|
tXrdW.store(tXrdW.load() + dout * x_hat)
|
|
646
|
+
|
|
585
647
|
stage ^= 1
|
|
586
648
|
if stage == 0:
|
|
587
649
|
consumer_phase ^= 1
|
|
@@ -609,7 +671,9 @@ class RMSNormBackward(ReductionBase):
|
|
|
609
671
|
cute.autovec_copy(tXsdW_other, tXrdW_other)
|
|
610
672
|
tXrdW.store(tXrdW.load() + tXrdW_other.load())
|
|
611
673
|
cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
|
|
674
|
+
|
|
612
675
|
else:
|
|
676
|
+
# dw is already in fp32, so we can directly copy to global memory
|
|
613
677
|
cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
|
|
614
678
|
|
|
615
679
|
|
|
@@ -634,8 +698,17 @@ def _rmsnorm_backward(
|
|
|
634
698
|
assert weight.dim() == 1, "Weight must be 1D"
|
|
635
699
|
assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
|
|
636
700
|
assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
|
|
637
|
-
assert x.dtype in [
|
|
638
|
-
|
|
701
|
+
assert x.dtype in [
|
|
702
|
+
torch.float16,
|
|
703
|
+
torch.bfloat16,
|
|
704
|
+
torch.float32,
|
|
705
|
+
], "Unsupported dtype"
|
|
706
|
+
|
|
707
|
+
assert weight.dtype in [
|
|
708
|
+
torch.float32,
|
|
709
|
+
torch.bfloat16,
|
|
710
|
+
torch.float16,
|
|
711
|
+
], "Weight must be float32, float16 or bfloat16"
|
|
639
712
|
|
|
640
713
|
M, N = x.shape
|
|
641
714
|
dx = torch.empty_like(x)
|
|
@@ -654,28 +727,29 @@ def _rmsnorm_backward(
|
|
|
654
727
|
sm_count = (
|
|
655
728
|
sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
|
|
656
729
|
)
|
|
657
|
-
|
|
730
|
+
|
|
731
|
+
# Always store partial gradients in fp32 for numerical accuracy
|
|
732
|
+
dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
|
|
658
733
|
|
|
659
734
|
dtype = torch2cute_dtype_map[x.dtype]
|
|
660
735
|
|
|
661
|
-
convert_from_dlpack = lambda
|
|
662
|
-
from_dlpack(
|
|
663
|
-
mode=0, stride_order=(0, 1)
|
|
664
|
-
)
|
|
736
|
+
convert_from_dlpack = lambda x: (
|
|
737
|
+
from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
|
|
665
738
|
)
|
|
666
|
-
|
|
667
739
|
x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]
|
|
668
740
|
|
|
741
|
+
# Handle weight div based on weight dtype
|
|
742
|
+
weight_dtype = torch2cute_dtype_map[weight.dtype]
|
|
669
743
|
weight_tensor = utils.convert_from_dlpack(
|
|
670
|
-
weight.detach(), leading_dim=0, divisibility=128 //
|
|
744
|
+
weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
|
|
671
745
|
)
|
|
672
746
|
|
|
673
|
-
dw_partial_tensor =
|
|
747
|
+
dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
|
|
674
748
|
rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
675
749
|
|
|
676
750
|
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
677
751
|
|
|
678
|
-
compile_key = (dtype, N)
|
|
752
|
+
compile_key = (dtype, N, weight.dtype)
|
|
679
753
|
if compile_key not in _rmsnorm_backward.compile_cache:
|
|
680
754
|
rmsnorm_backward_op = RMSNormBackward(dtype, N)
|
|
681
755
|
_rmsnorm_backward.compile_cache[compile_key] = cute.compile(
|
|
@@ -700,7 +774,7 @@ def _rmsnorm_backward(
|
|
|
700
774
|
sm_count,
|
|
701
775
|
current_stream,
|
|
702
776
|
)
|
|
703
|
-
|
|
777
|
+
# we have summed the partial gradients in fp32, now we convert back to the weight dtype
|
|
704
778
|
dw = dw_partial.sum(dim=0).to(weight.dtype)
|
|
705
779
|
return dx, dw
|
|
706
780
|
|
|
@@ -711,16 +785,29 @@ _rmsnorm_backward.compile_cache = {}
|
|
|
711
785
|
class RMSNormFunction(torch.autograd.Function):
|
|
712
786
|
@staticmethod
|
|
713
787
|
def forward(ctx, x, weight, eps):
|
|
788
|
+
x_shape_start = x.shape
|
|
789
|
+
|
|
790
|
+
# Flatten input
|
|
791
|
+
x = x.view(-1, x.shape[-1])
|
|
792
|
+
|
|
714
793
|
out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
|
|
715
794
|
ctx.save_for_backward(x, weight, rstd)
|
|
716
795
|
ctx.eps = eps
|
|
717
|
-
|
|
796
|
+
ctx.x_shape_start = x_shape_start
|
|
797
|
+
|
|
798
|
+
return out.reshape(x_shape_start)
|
|
718
799
|
|
|
719
800
|
@staticmethod
|
|
720
801
|
def backward(ctx, dout):
|
|
721
802
|
x, weight, rstd = ctx.saved_tensors
|
|
803
|
+
x_shape_start = ctx.x_shape_start
|
|
804
|
+
# Reshape dout to match the flattened shape used in forward
|
|
805
|
+
dout = dout.view(-1, dout.shape[-1])
|
|
722
806
|
dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
|
|
723
|
-
|
|
807
|
+
dx = dx.view(x_shape_start)
|
|
808
|
+
# dx is returned for input gradient,
|
|
809
|
+
# dw is returned for weight gradient,
|
|
810
|
+
# None for eps gradient
|
|
724
811
|
return dx, dw, None
|
|
725
812
|
|
|
726
813
|
|
|
@@ -736,3 +823,39 @@ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.T
|
|
|
736
823
|
Normalized output tensor of same shape as x
|
|
737
824
|
"""
|
|
738
825
|
return RMSNormFunction.apply(x, weight, eps)
|
|
826
|
+
|
|
827
|
+
|
|
828
|
+
class QuackRMSNorm(torch.nn.Module):
|
|
829
|
+
"""RMSNorm module that behaves like torch.nn.RMSNorm.
|
|
830
|
+
|
|
831
|
+
This class provides a drop-in replacement for torch.nn.RMSNorm that uses
|
|
832
|
+
the quack.rmsnorm implementation under the hood.
|
|
833
|
+
|
|
834
|
+
Args:
|
|
835
|
+
dim (int): The dimension to normalize over
|
|
836
|
+
eps (float, optional): A small constant for numerical stability. Default: 1e-6
|
|
837
|
+
|
|
838
|
+
Attributes:
|
|
839
|
+
weight (torch.nn.Parameter): The learnable weight parameter
|
|
840
|
+
eps (float): A small constant for numerical stability
|
|
841
|
+
"""
|
|
842
|
+
|
|
843
|
+
def __init__(self, dim: int, eps: float = 1e-6):
|
|
844
|
+
super().__init__()
|
|
845
|
+
self.weight = torch.nn.Parameter(torch.ones(dim))
|
|
846
|
+
self.eps = eps
|
|
847
|
+
|
|
848
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
849
|
+
"""Apply RMSNorm to the input tensor.
|
|
850
|
+
|
|
851
|
+
Args:
|
|
852
|
+
x (torch.Tensor): Input tensor
|
|
853
|
+
|
|
854
|
+
Returns:
|
|
855
|
+
torch.Tensor: Normalized tensor
|
|
856
|
+
"""
|
|
857
|
+
return rmsnorm(x, self.weight, self.eps)
|
|
858
|
+
|
|
859
|
+
def reset_parameters(self):
|
|
860
|
+
"""Reset the weight parameter to ones."""
|
|
861
|
+
torch.nn.init.ones_(self.weight)
|
|
@@ -0,0 +1,392 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from quack.rmsnorm import rmsnorm, rmsnorm_ref, rstd_ref, _rmsnorm_fwd
|
|
7
|
+
|
|
8
|
+
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
|
9
|
+
# @pytest.mark.parametrize("eps", [1e-5])
|
|
10
|
+
@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32])
|
|
11
|
+
# @pytest.mark.parametrize("input_dtype", [torch.float16])
|
|
12
|
+
@pytest.mark.parametrize(
|
|
13
|
+
"N",
|
|
14
|
+
[
|
|
15
|
+
192,
|
|
16
|
+
256,
|
|
17
|
+
512,
|
|
18
|
+
760,
|
|
19
|
+
1024,
|
|
20
|
+
1128,
|
|
21
|
+
2048,
|
|
22
|
+
4096,
|
|
23
|
+
8192,
|
|
24
|
+
16384,
|
|
25
|
+
32768,
|
|
26
|
+
65536,
|
|
27
|
+
131072,
|
|
28
|
+
262144,
|
|
29
|
+
],
|
|
30
|
+
# [262144]
|
|
31
|
+
)
|
|
32
|
+
@pytest.mark.parametrize("M", [1, 37, 199, 8 * 1024])
|
|
33
|
+
# @pytest.mark.parametrize("M", [1])
|
|
34
|
+
def test_rmsnorm_forward_backward(M, N, input_dtype, eps):
|
|
35
|
+
"""Test RMSNorm forward pass against reference implementation."""
|
|
36
|
+
if N >= 256 * 1024 and input_dtype == torch.float32 and M >= 8 * 1024:
|
|
37
|
+
pytest.skip("Skipping large tensor test for float32 to avoid OOM")
|
|
38
|
+
device = "cuda"
|
|
39
|
+
# Set tolerance based on dtype
|
|
40
|
+
if input_dtype == torch.bfloat16:
|
|
41
|
+
atol = 1e-1
|
|
42
|
+
elif input_dtype == torch.float16:
|
|
43
|
+
atol = 1e-2
|
|
44
|
+
else:
|
|
45
|
+
atol = 1e-4
|
|
46
|
+
torch.random.manual_seed(0)
|
|
47
|
+
x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True)
|
|
48
|
+
weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
|
|
49
|
+
x_ref = x.detach().clone().requires_grad_()
|
|
50
|
+
weight_ref = weight.detach().clone().requires_grad_()
|
|
51
|
+
out = rmsnorm(x, weight, eps=eps)
|
|
52
|
+
out_ref = rmsnorm_ref(x_ref, weight_ref, eps=eps)
|
|
53
|
+
# rstd_ref_val = rstd_ref(x_ref, eps=eps)
|
|
54
|
+
assert out.shape == x.shape
|
|
55
|
+
assert out.dtype == input_dtype
|
|
56
|
+
torch.testing.assert_close(out, out_ref, atol=atol, rtol=1e-3)
|
|
57
|
+
# torch.testing.assert_close(rstd, rstd_ref_val, atol=atol, rtol=1e-3)
|
|
58
|
+
# Backward pass
|
|
59
|
+
if N > 128 * 1024 and input_dtype == torch.float32:
|
|
60
|
+
# Skip backward pass for due to not enough smem
|
|
61
|
+
return
|
|
62
|
+
grad_out = torch.randn_like(out)
|
|
63
|
+
torch.cuda.synchronize()
|
|
64
|
+
out_ref.backward(grad_out)
|
|
65
|
+
out.backward(grad_out)
|
|
66
|
+
torch.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=1e-3)
|
|
67
|
+
torch.testing.assert_close(weight.grad, weight_ref.grad, atol=atol, rtol=1e-3)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def test_rmsnorm_strided_tensor():
|
|
71
|
+
"""Test RMSNorm with strided tensor input where shape is (8, 4096, 512) and stride is (sth, 576, 1)."""
|
|
72
|
+
device = "cuda"
|
|
73
|
+
dtype = torch.bfloat16
|
|
74
|
+
atol = 1e-1
|
|
75
|
+
eps = 1e-5
|
|
76
|
+
# Create a larger tensor with 576 features
|
|
77
|
+
full_tensor = torch.randn(8, 4096, 576, device=device, dtype=dtype)
|
|
78
|
+
# Take a slice of the top 512 dimensions - this creates a strided view
|
|
79
|
+
x = full_tensor[:, :, :512].detach().requires_grad_()
|
|
80
|
+
# Create weight tensor
|
|
81
|
+
weight = torch.randn(512, device=device, dtype=torch.float32, requires_grad=True)
|
|
82
|
+
# Reference implementation
|
|
83
|
+
x_ref = x.detach().clone().requires_grad_()
|
|
84
|
+
weight_ref = weight.detach().clone().requires_grad_()
|
|
85
|
+
out = rmsnorm(x, weight, eps=eps)
|
|
86
|
+
out_ref = rmsnorm_ref(x_ref, weight_ref, eps=eps)
|
|
87
|
+
assert out.shape == x.shape
|
|
88
|
+
torch.testing.assert_close(out, out_ref, atol=atol, rtol=1e-3)
|
|
89
|
+
grad_out = torch.randn_like(out)
|
|
90
|
+
torch.cuda.synchronize()
|
|
91
|
+
out_ref.backward(grad_out)
|
|
92
|
+
out.backward(grad_out)
|
|
93
|
+
torch.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=1e-3)
|
|
94
|
+
torch.testing.assert_close(weight.grad, weight_ref.grad, atol=atol, rtol=1e-3)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@pytest.mark.parametrize("eps", [1e-5])
@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
@pytest.mark.parametrize(
    "N",
    [131072, 262144],
    # [262144]
)
@pytest.mark.parametrize("M", [32 * 1024])
def test_rmsnorm_large_tensor(M, N, input_dtype, eps):
    """Check the kernel on very wide rows against a compiled reference."""
    device = "cuda"
    # Absolute tolerance scales with the precision of the input dtype.
    atol = {torch.bfloat16: 1e-1, torch.float16: 1e-2}.get(input_dtype, 1e-4)
    torch.random.manual_seed(0)
    torch.cuda.empty_cache()
    x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=False)
    weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=False)
    out = rmsnorm(x, weight, eps=eps)
    # The eager reference OOMs at this size, so run it through torch.compile.
    rmsnorm_compiled = torch.compile(rmsnorm_ref)
    # Warm up compilation on a small slice to avoid an OOM on the first call.
    rmsnorm_compiled(x[:32], weight, eps=eps)
    out_ref = rmsnorm_compiled(x, weight, eps=eps)
    # Compare chunk by chunk to keep the peak memory of the diff bounded.
    for got, expected in zip(out.chunk(16), out_ref.chunk(16)):
        assert (got - expected).abs().max() < atol
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@pytest.mark.parametrize("return_rstd", [True, False])
def test_rmsnorm_return_rstd_option(return_rstd):
    """The return_rstd flag controls whether _rmsnorm_fwd also returns rstd."""
    device = "cuda"
    M, N = 32, 1024
    eps = 1e-6

    x = torch.randn(M, N, device=device, dtype=torch.float16)
    weight = torch.randn(N, device=device, dtype=torch.float32)

    result = _rmsnorm_fwd(x, weight, eps=eps, return_rstd=return_rstd)
    if return_rstd:
        # Two-tuple: normalized output plus per-row reciprocal std in fp32.
        out, rstd = result
        assert rstd.shape == (M,)
        assert rstd.dtype == torch.float32
    else:
        # Single tensor when the statistic is not requested.
        out = result
        assert isinstance(out, torch.Tensor)
    assert out.shape == (M, N)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def test_rmsnorm_input_validation():
    """Exercise the argument checks in rmsnorm."""
    device = "cuda"

    # 3D input: accepted since rmsnorm was extended beyond 2D tensors.
    x_3d = torch.randn(2, 32, 1024, device=device, dtype=torch.float16)
    weight = torch.randn(1024, device=device, dtype=torch.float32)

    # Must run without raising, preserving shape and dtype.
    out = rmsnorm(x_3d, weight)
    assert out.shape == x_3d.shape
    assert out.dtype == x_3d.dtype

    # A weight whose length disagrees with the last input dim is rejected.
    x = torch.randn(32, 1024, device=device, dtype=torch.float16)
    bad_weight = torch.randn(512, device=device, dtype=torch.float32)

    with pytest.raises(AssertionError, match="Last dimension of input must match weight dimension"):
        rmsnorm(x, bad_weight)

    # CPU tensors are rejected — the kernel is CUDA-only.
    x_cpu = torch.randn(32, 1024, dtype=torch.float16)
    weight_cpu = torch.randn(1024, dtype=torch.float32)

    with pytest.raises(AssertionError, match="Tensors must be on CUDA device"):
        rmsnorm(x_cpu, weight_cpu)

    # float64 input is not a supported input dtype.
    x = torch.randn(32, 1024, device=device, dtype=torch.float64)
    weight = torch.randn(1024, device=device, dtype=torch.float32)

    with pytest.raises(AssertionError, match="Unsupported dtype"):
        rmsnorm(x, weight)

    # float64 weight is likewise rejected.
    x = torch.randn(32, 1024, device=device, dtype=torch.float16)
    bad_dtype_weight = torch.randn(1024, device=device, dtype=torch.float64)

    with pytest.raises(AssertionError, match="Weight must be float32, float16 or bfloat16"):
        rmsnorm(x, bad_dtype_weight)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def test_rmsnorm_bf16_weights():
    """rmsnorm should accept bfloat16 weights and match a float32 reference."""
    device = "cuda"
    M, N = 32, 1024
    eps = 1e-6

    # bfloat16 for both the activations and the weights.
    x = torch.randn(M, N, device=device, dtype=torch.bfloat16)
    weight_bf16 = torch.randn(N, device=device, dtype=torch.bfloat16)

    out_bf16 = rmsnorm(x, weight_bf16, eps=eps)

    # Output keeps the input's shape and dtype.
    assert out_bf16.shape == x.shape
    assert out_bf16.dtype == torch.bfloat16

    # Reference path: compute in float32, then cast back for the comparison.
    x_fp32 = x.to(torch.float32)
    weight_fp32 = weight_bf16.to(torch.float32)

    expected = rmsnorm_ref(x_fp32, weight_fp32, eps=eps).to(torch.bfloat16)

    # Loose tolerances to account for bfloat16 rounding.
    torch.testing.assert_close(out_bf16, expected, atol=1e-1, rtol=1e-2)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def test_rmsnorm_bf16_weights_backward():
    """Backward pass with bfloat16 weights should track a float32-weight run."""
    device = "cuda"
    M, N = 32, 1024
    eps = 1e-6
    atol = 1e-1  # bfloat16 needs a generous tolerance

    # Leaf tensors for the bf16-weight path.
    x = torch.randn(M, N, device=device, dtype=torch.bfloat16, requires_grad=True)
    weight_bf16 = torch.randn(N, device=device, dtype=torch.bfloat16, requires_grad=True)

    # Reference path: same data, but weights promoted to float32.
    x_ref = x.detach().clone().requires_grad_()
    weight_fp32 = weight_bf16.to(torch.float32).detach().requires_grad_()

    # Forward through both paths.
    out_bf16 = rmsnorm(x, weight_bf16, eps=eps)
    out_ref = rmsnorm(x_ref, weight_fp32, eps=eps)

    # Shared upstream gradient.
    grad_out = torch.randn_like(out_bf16)
    grad_out_ref = grad_out.clone()

    # Backward through both paths.
    torch.cuda.synchronize()
    out_bf16.backward(grad_out)
    out_ref.backward(grad_out_ref)

    # Input grads agree directly; weight grads after casting the fp32 ones down.
    torch.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=1e-2)
    torch.testing.assert_close(
        weight_bf16.grad, weight_fp32.grad.to(torch.bfloat16), atol=atol, rtol=1e-2
    )

    # Mixed precision: bfloat16 activations with float32 weights.
    x = torch.randn(M, N, device=device, dtype=torch.bfloat16, requires_grad=True)
    weight_fp32 = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)

    out_mixed = rmsnorm(x, weight_fp32, eps=eps)

    grad_out = torch.randn_like(out_mixed)

    torch.cuda.synchronize()
    out_mixed.backward(grad_out)

    # Smoke check only: the mixed-precision backward must populate gradients.
    assert x.grad is not None
    assert weight_fp32.grad is not None
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def test_rmsnorm_fp16_weights():
    """rmsnorm should accept float16 weights and match a float32 reference."""
    device = "cuda"
    M, N = 32, 1024
    eps = 1e-6

    # float16 for both the activations and the weights.
    x = torch.randn(M, N, device=device, dtype=torch.float16)
    weight_fp16 = torch.randn(N, device=device, dtype=torch.float16)

    out_fp16 = rmsnorm(x, weight_fp16, eps=eps)

    # Output keeps the input's shape and dtype.
    assert out_fp16.shape == x.shape
    assert out_fp16.dtype == torch.float16

    # Reference path: compute in float32, then cast back for the comparison.
    x_fp32 = x.to(torch.float32)
    weight_fp32 = weight_fp16.to(torch.float32)

    expected = rmsnorm_ref(x_fp32, weight_fp32, eps=eps).to(torch.float16)

    # float16 carries more mantissa bits than bfloat16, so tighter tolerances.
    torch.testing.assert_close(out_fp16, expected, atol=1e-2, rtol=1e-2)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def test_rmsnorm_fp16_weights_backward():
    """Backward pass with float16 weights should track a float32-weight run."""
    device = "cuda"
    M, N = 32, 1024
    eps = 1e-6
    atol = 1e-2  # float16 tolerance

    # Leaf tensors for the fp16-weight path.
    x = torch.randn(M, N, device=device, dtype=torch.float16, requires_grad=True)
    weight_fp16 = torch.randn(N, device=device, dtype=torch.float16, requires_grad=True)

    # Reference path: same data, but weights promoted to float32.
    x_ref = x.detach().clone().requires_grad_()
    weight_fp32 = weight_fp16.to(torch.float32).detach().requires_grad_()

    # Forward through both paths.
    out_fp16 = rmsnorm(x, weight_fp16, eps=eps)
    out_ref = rmsnorm(x_ref, weight_fp32, eps=eps)

    # Shared upstream gradient.
    grad_out = torch.randn_like(out_fp16)
    grad_out_ref = grad_out.clone()

    # Backward through both paths.
    torch.cuda.synchronize()
    out_fp16.backward(grad_out)
    out_ref.backward(grad_out_ref)

    # Input grads agree directly; weight grads after casting the fp32 ones down.
    torch.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=1e-2)
    torch.testing.assert_close(
        weight_fp16.grad, weight_fp32.grad.to(torch.float16), atol=atol, rtol=1e-2
    )

    # Mixed precision: float16 activations with float32 weights.
    x = torch.randn(M, N, device=device, dtype=torch.float16, requires_grad=True)
    weight_fp32 = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)

    out_mixed = rmsnorm(x, weight_fp32, eps=eps)

    grad_out = torch.randn_like(out_mixed)

    torch.cuda.synchronize()
    out_mixed.backward(grad_out)

    # Smoke check only: the mixed-precision backward must populate gradients.
    assert x.grad is not None
    assert weight_fp32.grad is not None
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def test_rmsnorm_compile_cache():
    """Calls sharing (shape, dtype) must reuse one compiled kernel; a new
    shape or dtype must add exactly one cache entry."""
    device = "cuda"
    M, N = 32, 1024
    eps = 1e-6

    # Start from an empty cache so entry counts are deterministic.
    _rmsnorm_fwd.compile_cache.clear()
    assert len(_rmsnorm_fwd.compile_cache) == 0

    def run(rows, cols, dtype):
        # Fresh tensors each call; only (shape, dtype) should matter for caching.
        inp = torch.randn(rows, cols, device=device, dtype=dtype)
        w = torch.randn(cols, device=device, dtype=torch.float32)
        return _rmsnorm_fwd(inp, w, eps=eps)

    # First call compiles and populates the cache.
    run(M, N, torch.float16)
    assert len(_rmsnorm_fwd.compile_cache) == 1

    # Identical shape and dtype hits the existing entry.
    run(M, N, torch.float16)
    assert len(_rmsnorm_fwd.compile_cache) == 1

    # A different N forces a fresh compilation.
    run(M, N * 2, torch.float16)
    assert len(_rmsnorm_fwd.compile_cache) == 2

    # A different input dtype does as well.
    run(M, N, torch.float32)
    assert len(_rmsnorm_fwd.compile_cache) == 3
|
|
@@ -1,183 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
|
-
|
|
3
|
-
import pytest
|
|
4
|
-
import torch
|
|
5
|
-
|
|
6
|
-
from quack.rmsnorm import rmsnorm, rmsnorm_ref, rstd_ref
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
|
10
|
-
# @pytest.mark.parametrize("eps", [1e-5])
|
|
11
|
-
@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32])
|
|
12
|
-
# @pytest.mark.parametrize("input_dtype", [torch.float16])
|
|
13
|
-
@pytest.mark.parametrize(
|
|
14
|
-
"N",
|
|
15
|
-
[192, 256, 512, 760, 1024, 1128, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144]
|
|
16
|
-
# [262144]
|
|
17
|
-
)
|
|
18
|
-
@pytest.mark.parametrize("M", [1, 37, 199, 8 * 1024])
|
|
19
|
-
# @pytest.mark.parametrize("M", [1])
|
|
20
|
-
def test_rmsnorm_forward_backward(M, N, input_dtype, eps):
|
|
21
|
-
"""Test RMSNorm forward pass against reference implementation."""
|
|
22
|
-
if N >= 256 * 1024 and input_dtype == torch.float32 and M >= 8 * 1024:
|
|
23
|
-
pytest.skip("Skipping large tensor test for float32 to avoid OOM")
|
|
24
|
-
device = "cuda"
|
|
25
|
-
# Set tolerance based on dtype
|
|
26
|
-
if input_dtype == torch.bfloat16:
|
|
27
|
-
atol = 1e-1
|
|
28
|
-
elif input_dtype == torch.float16:
|
|
29
|
-
atol = 1e-2
|
|
30
|
-
else:
|
|
31
|
-
atol = 1e-4
|
|
32
|
-
torch.random.manual_seed(0)
|
|
33
|
-
x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True)
|
|
34
|
-
weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
|
|
35
|
-
x_ref = x.detach().clone().requires_grad_()
|
|
36
|
-
weight_ref = weight.detach().clone().requires_grad_()
|
|
37
|
-
out = rmsnorm(x, weight, eps=eps)
|
|
38
|
-
out_ref = rmsnorm_ref(x_ref, weight_ref, eps=eps)
|
|
39
|
-
# rstd_ref_val = rstd_ref(x_ref, eps=eps)
|
|
40
|
-
assert out.shape == x.shape
|
|
41
|
-
assert out.dtype == input_dtype
|
|
42
|
-
torch.testing.assert_close(out, out_ref, atol=atol, rtol=1e-3)
|
|
43
|
-
# torch.testing.assert_close(rstd, rstd_ref_val, atol=atol, rtol=1e-3)
|
|
44
|
-
# Backward pass
|
|
45
|
-
if N > 128 * 1024 and input_dtype == torch.float32:
|
|
46
|
-
# Skip backward pass for due to not enough smem
|
|
47
|
-
return
|
|
48
|
-
grad_out = torch.randn_like(out)
|
|
49
|
-
torch.cuda.synchronize()
|
|
50
|
-
out_ref.backward(grad_out)
|
|
51
|
-
out.backward(grad_out)
|
|
52
|
-
torch.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=1e-3)
|
|
53
|
-
torch.testing.assert_close(weight.grad, weight_ref.grad, atol=atol, rtol=1e-3)
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
@pytest.mark.parametrize("eps", [1e-5])
|
|
57
|
-
@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
|
|
58
|
-
@pytest.mark.parametrize(
|
|
59
|
-
"N",
|
|
60
|
-
[131072, 262144]
|
|
61
|
-
# [262144]
|
|
62
|
-
)
|
|
63
|
-
@pytest.mark.parametrize("M", [32 * 1024])
|
|
64
|
-
def test_rmsnorm_large_tensor(M, N, input_dtype, eps):
|
|
65
|
-
"""Test RMSNorm forward pass against reference implementation."""
|
|
66
|
-
device = "cuda"
|
|
67
|
-
# Set tolerance based on dtype
|
|
68
|
-
if input_dtype == torch.bfloat16:
|
|
69
|
-
atol = 1e-1
|
|
70
|
-
elif input_dtype == torch.float16:
|
|
71
|
-
atol = 1e-2
|
|
72
|
-
else:
|
|
73
|
-
atol = 1e-4
|
|
74
|
-
torch.random.manual_seed(0)
|
|
75
|
-
torch.cuda.empty_cache()
|
|
76
|
-
x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=False)
|
|
77
|
-
weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=False)
|
|
78
|
-
out = rmsnorm(x, weight, eps=eps)
|
|
79
|
-
# Need to compile, otherwise it OOMs
|
|
80
|
-
rmsnorm_compiled = torch.compile(rmsnorm_ref)
|
|
81
|
-
# Run once with smaller input to avoid OOMs
|
|
82
|
-
rmsnorm_compiled(x[:32], weight, eps=eps)
|
|
83
|
-
out_ref = rmsnorm_compiled(x, weight, eps=eps)
|
|
84
|
-
# Need to chunk, otherwise it OOMs
|
|
85
|
-
assert all((out_c - out_ref_c).abs().max() < atol
|
|
86
|
-
for out_c, out_ref_c in zip(out.chunk(16), out_ref.chunk(16)))
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
@pytest.mark.parametrize("return_rstd", [True, False])
|
|
90
|
-
def test_rmsnorm_return_rstd_option(return_rstd):
|
|
91
|
-
"""Test that return_rstd option works correctly."""
|
|
92
|
-
device = "cuda"
|
|
93
|
-
M, N = 32, 1024
|
|
94
|
-
eps = 1e-6
|
|
95
|
-
|
|
96
|
-
x = torch.randn(M, N, device=device, dtype=torch.float16)
|
|
97
|
-
weight = torch.randn(N, device=device, dtype=torch.float32)
|
|
98
|
-
|
|
99
|
-
if return_rstd:
|
|
100
|
-
out, rstd = rmsnorm(x, weight, eps=eps, return_rstd=True)
|
|
101
|
-
assert out.shape == (M, N)
|
|
102
|
-
assert rstd.shape == (M,)
|
|
103
|
-
assert rstd.dtype == torch.float32
|
|
104
|
-
else:
|
|
105
|
-
out = rmsnorm(x, weight, eps=eps, return_rstd=False)
|
|
106
|
-
assert out.shape == (M, N)
|
|
107
|
-
assert isinstance(out, torch.Tensor)
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
def test_rmsnorm_input_validation():
|
|
111
|
-
"""Test input validation and error handling."""
|
|
112
|
-
device = "cuda"
|
|
113
|
-
|
|
114
|
-
# Test 3D input (should fail)
|
|
115
|
-
x_3d = torch.randn(2, 32, 1024, device=device, dtype=torch.float16)
|
|
116
|
-
weight = torch.randn(1024, device=device, dtype=torch.float32)
|
|
117
|
-
|
|
118
|
-
with pytest.raises(AssertionError, match="Input must be 2D"):
|
|
119
|
-
rmsnorm(x_3d, weight)
|
|
120
|
-
|
|
121
|
-
# Test weight dimension mismatch
|
|
122
|
-
x = torch.randn(32, 1024, device=device, dtype=torch.float16)
|
|
123
|
-
weight_wrong = torch.randn(512, device=device, dtype=torch.float32)
|
|
124
|
-
|
|
125
|
-
with pytest.raises(AssertionError, match="Last dimension of input must match weight dimension"):
|
|
126
|
-
rmsnorm(x, weight_wrong)
|
|
127
|
-
|
|
128
|
-
# Test CPU tensors (should fail)
|
|
129
|
-
x_cpu = torch.randn(32, 1024, dtype=torch.float16)
|
|
130
|
-
weight_cpu = torch.randn(1024, dtype=torch.float32)
|
|
131
|
-
|
|
132
|
-
with pytest.raises(AssertionError, match="Tensors must be on CUDA device"):
|
|
133
|
-
rmsnorm(x_cpu, weight_cpu)
|
|
134
|
-
|
|
135
|
-
# Test unsupported dtype
|
|
136
|
-
x = torch.randn(32, 1024, device=device, dtype=torch.float64)
|
|
137
|
-
weight = torch.randn(1024, device=device, dtype=torch.float32)
|
|
138
|
-
|
|
139
|
-
with pytest.raises(AssertionError, match="Unsupported dtype"):
|
|
140
|
-
rmsnorm(x, weight)
|
|
141
|
-
|
|
142
|
-
# Test wrong weight dtype
|
|
143
|
-
x = torch.randn(32, 1024, device=device, dtype=torch.float16)
|
|
144
|
-
weight_wrong_dtype = torch.randn(1024, device=device, dtype=torch.float16)
|
|
145
|
-
|
|
146
|
-
with pytest.raises(AssertionError, match="Weight must be float32"):
|
|
147
|
-
rmsnorm(x, weight_wrong_dtype)
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
def test_rmsnorm_compile_cache():
|
|
151
|
-
"""Test that compile cache works correctly for repeated calls."""
|
|
152
|
-
device = "cuda"
|
|
153
|
-
M, N = 32, 1024
|
|
154
|
-
eps = 1e-6
|
|
155
|
-
|
|
156
|
-
# Clear cache
|
|
157
|
-
rmsnorm.compile_cache.clear()
|
|
158
|
-
assert len(rmsnorm.compile_cache) == 0
|
|
159
|
-
|
|
160
|
-
x1 = torch.randn(M, N, device=device, dtype=torch.float16)
|
|
161
|
-
weight1 = torch.randn(N, device=device, dtype=torch.float32)
|
|
162
|
-
|
|
163
|
-
# First call should compile
|
|
164
|
-
out1 = rmsnorm(x1, weight1, eps=eps)
|
|
165
|
-
assert len(rmsnorm.compile_cache) == 1
|
|
166
|
-
|
|
167
|
-
# Same shape should reuse cache
|
|
168
|
-
x2 = torch.randn(M, N, device=device, dtype=torch.float16)
|
|
169
|
-
weight2 = torch.randn(N, device=device, dtype=torch.float32)
|
|
170
|
-
out2 = rmsnorm(x2, weight2, eps=eps)
|
|
171
|
-
assert len(rmsnorm.compile_cache) == 1
|
|
172
|
-
|
|
173
|
-
# Different shape should create new cache entry
|
|
174
|
-
x3 = torch.randn(M, N * 2, device=device, dtype=torch.float16)
|
|
175
|
-
weight3 = torch.randn(N * 2, device=device, dtype=torch.float32)
|
|
176
|
-
out3 = rmsnorm(x3, weight3, eps=eps)
|
|
177
|
-
assert len(rmsnorm.compile_cache) == 2
|
|
178
|
-
|
|
179
|
-
# Different dtype should create new cache entry
|
|
180
|
-
x4 = torch.randn(M, N, device=device, dtype=torch.float32)
|
|
181
|
-
weight4 = torch.randn(N, device=device, dtype=torch.float32)
|
|
182
|
-
out4 = rmsnorm(x4, weight4, eps=eps)
|
|
183
|
-
assert len(rmsnorm.compile_cache) == 3
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|