quack-kernels 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/rmsnorm.py CHANGED
@@ -1,9 +1,8 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

- import math
- import operator

  import torch
+ from typing import Optional

  import cuda.bindings.driver as cuda

@@ -12,198 +11,202 @@ import cutlass.cute as cute
  from cutlass.cute.runtime import from_dlpack

  import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+ class RMSNorm(ReductionBase):
+     def __init__(self, dtype: cutlass.Numeric, N: int):
+         super().__init__(dtype, N, stage=1)
+         self.reload_from = None if N <= 16384 else "smem"
+         self.delay_w_load = False
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )

+     def _set_cluster_n(self):
+         N = self.N
+         # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
+         # Similarly cluster_n = 8 is faster for N=128k
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else: # fp32
+             cluster_n = (
+                 1
+                 if N <= 32 * 1024
+                 else (
+                     2
+                     if N <= 64 * 1024
+                     else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n
+
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mW: cute.Tensor,
+         mO: cute.Tensor,
+         mRstd: Optional[cute.Tensor],
+         stream: cuda.CUstream,
+         eps: cutlass.Float32 = 1e-6,
+     ):
+         assert mX.element_type == self.dtype
+         assert mO.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+         mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+         if cutlass.const_expr(mRstd is not None):
+             mRstd_expanded_layout = cute.append(
+                 mRstd.layout, cute.make_layout((self.N,), stride=(0,))
+             )
+             mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+         self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )

- @cute.kernel
- def rmsnorm_kernel(
-     gX: cute.Tensor,
-     gW: cute.Tensor,
-     gO: cute.Tensor,
-     gRstd: cute.Tensor,
-     cX: cute.Tensor, # coordinate tensor
-     eps: cute.Float32,
-     shape: cute.Shape,
-     tv_layout: cute.Layout,
-     tiler_mn: cute.Shape,
-     cluster_n: cutlass.Constexpr = 1,
-     reload_from: cutlass.Constexpr = None,
-     delay_w_load: cutlass.Constexpr = False,
- ):
-     tidx, _, _ = cute.arch.thread_idx()
-     bidx, cluster_y, _ = cute.arch.block_idx()
-     gdim, _, _ = cute.arch.grid_dim()
-
-     # slice for CTAs
-     # logical id -> address
-     blkX, blkOut, blkRstd, blkCrd = [gT[(None, None), bidx if cluster_n == 1 else (bidx, cluster_y)] for gT in (gX, gO, gRstd, cX)]
-     blkW = gW[(None, None), 0 if cluster_n == 1 else (0, cluster_y)]
-
-     # declare the atoms which will be used later for memory copy
-     copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
-     copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
-     copy_atom_load_W = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gW.element_type, num_bits_per_copy=128)
-     copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128)
-
-     thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-     thr_copy_X_async = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
-     thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
-     thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
-
-     smem = cutlass.utils.SmemAllocator()
-     # Don't use blkX.layout here, because the stride is N, not N_rounded
-     sX = smem.allocate_tensor(gX.element_type, cute.make_ordered_layout(blkX.shape, order=(1, 0)), byte_alignment=16)
-     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-     # reduction_buffer_layout = cute.make_ordered_layout((num_warps // warps_per_row, warps_per_row), order=(1, 0))
-     reduction_buffer_layout = cute.make_ordered_layout((num_warps // warps_per_row, warps_per_row if cluster_n == 1 else (warps_per_row, cluster_n)), order=(1, 0))
-     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
-     if cutlass.const_expr(cluster_n > 1):
-         mbar_ptr = smem.allocate(cutlass.Int64.width // 8, byte_alignment=8)
-     else:
-         mbar_ptr = None
-
-     tWgW = thr_copy_W.partition_S(blkW)
-     tXgX = thr_copy_X_async.partition_S(blkX)
-     tXsX = thr_copy_X_async.partition_S(sX)
-     tXgO, tXrRstd = [thr_copy_O.partition_D(blk) for blk in (blkOut, blkRstd)]
-     tXcX = thr_copy_X.partition_S(blkCrd)[(0, None), None, None]
-
-     # allocate fragments for gmem->rmem
-     tWrW = cute.make_fragment_like(tWgW)
-     tXrW = thr_copy_X.retile(tWrW)
-     tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
-
-     if cluster_n > 1:
-         if tidx == 0:
-             cute.arch.mbarrier_init_arrive_cnt(mbar_ptr, 1)
-         cute.arch.mbarrier_init_fence()
-         if tidx == 0:
-             cute.arch.mbarrier_init_tx_bytes(mbar_ptr, num_warps * cluster_n * cutlass.Float32.width // 8)
-         # Cluster arrive after barrier init
-         cute.arch.cluster_arrive_relaxed()
-
-     tXpX = cute.make_fragment_like(tXgX[(0, None), None, None], cutlass.Boolean)
-     for i in range(cute.size(tXpX)):
-         tXpX[i] = cute.elem_less(tXcX[i][1], shape[1])
-     # tXrX.fill(0.0)
-     if tXcX[0][0] < shape[0]:
-         # cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
-         cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
-         cute.arch.cp_async_commit_group()
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,
+         mW: cute.Tensor,
+         mO: cute.Tensor,
+         mRstd: Optional[cute.Tensor],
+         eps: cute.Float32,
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+         reload_from: cutlass.Constexpr = None,
+         delay_w_load: cutlass.Constexpr = False,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, cluster_y, _ = cute.arch.block_idx()
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+         shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gX, gO, cX = [
+             cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+             for mT in (mX, mO, idX)
+         ]
+         gW = cute.local_tile(mW, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+         gRstd = (
+             cute.local_tile(mRstd, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+             if cutlass.const_expr(mRstd is not None)
+             else None
+         )

-     tWpW = cute.make_fragment_like(tWgW[(0, None), None, None], cutlass.Boolean)
-     tWcX = thr_copy_W.partition_S(blkCrd)[(0, None), None, None]
-     for i in range(cute.size(tWpW)):
-         tWpW[i] = cute.elem_less(tWcX[i][1], shape[1])
-     if not delay_w_load:
-         cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+         )
+         copy_atom_load_X_async = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+         )
+         copy_atom_load_W = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+         )
+         copy_atom_store_O = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
+         )

-     cute.arch.cp_async_wait_group(0)
-     cute.autovec_copy(tXsX, tXrX)
-     x = tXrX.load().to(cute.Float32)
-     sum_sq_x = utils.warp_reduce(
-         (x * x).reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
-         operator.add,
-         width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
-     )
-     if cutlass.const_expr(cluster_n > 1):
-         cute.arch.cluster_wait()
-     if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-         sum_sq_x = utils.block_or_cluster_reduce(
-             sum_sq_x, operator.add, reduction_buffer, mbar_ptr, init_val=0.0
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
+             tidx
          )
-     rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
-     # Only the thread corresponding to column 0 writes out the rstd to gmem
-     if tXcX[0][1] == 0 and tXcX[0][0] < shape[0]:
-         if cutlass.const_expr(cluster_n == 1):
-             tXrRstd[0] = rstd
-         else:
-             if cute.arch.block_idx_in_cluster() == 0:
-                 tXrRstd[0] = rstd
-     if delay_w_load:
-         cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
-     if reload_from == "smem":
+         thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+         thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+         tWgW = thr_copy_W.partition_S(gW)
+         tXgX = thr_copy_X.partition_S(gX)
+         tXsX = thr_copy_X.partition_D(sX)
+         tXgO = thr_copy_O.partition_D(gO)
+         tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+         # allocate fragments for gmem->rmem
+         tWrW = cute.make_fragment_like(tWgW)
+         tXrW = thr_copy_X.retile(tWrW)
+         tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+         tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+         row = tXcX[0][0]
+         if row < shape[0]:
+             cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+             cute.arch.cp_async_commit_group()
+
+         tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
+         if not delay_w_load:
+             cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+
+         cute.arch.cp_async_wait_group(0)
          cute.autovec_copy(tXsX, tXrX)
          x = tXrX.load().to(cute.Float32)
-     elif reload_from == "gmem":
-         cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
-         x = tXrX.load().to(cute.Float32)
-     x_hat = x * rstd
-     w = tXrW.load().to(cute.Float32)
-     y = x_hat * w
-     tXrO.store(y.to(tXrO.element_type))
-     tOcX = thr_copy_O.partition_S(blkCrd)[(0, None), None, None]
-     tOpO = cute.make_fragment_like(tXgO[(0, None), None, None], cutlass.Boolean)
-     for i in range(cute.size(tOpO)):
-         tOpO[i] = cute.elem_less(tOcX[i][1], shape[1])
-     if tXcX[0][0] < shape[0]:
-         cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
-
-
- @cute.jit
- def rmsnorm_interface(
-     # mX_: cute.Tensor,
-     mX: cute.Tensor,
-     mW: cute.Tensor,
-     mOut: cute.Tensor,
-     mRstd: cute.Tensor,
-     stream: cuda.CUstream,
-     N: cutlass.Constexpr,
-     eps: cutlass.Float32 = 1e-6,
-     copy_bits: cutlass.Constexpr = 128
- ):
-     # new_shape = (mX_.shape[0], cute.assume(mX_.shape[1], 128))
-     # breakpoint()
-     # mX = cute.make_tensor(mX_.iterator, cute.make_layout(new_shape, stride=mX_.stride))
-     vecsize = copy_bits // mX.element_type.width
-     assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
-     num_threads = 128 if N <= 16384 else 256
-     num_warps = num_threads // cute.arch.WARP_SIZE
-     assert num_threads % cute.arch.WARP_SIZE == 0
-     threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
-     # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
-     # Similarly cluster_n = 8 is faster for N=128k
-     if cutlass.const_expr(mX.element_type.width == 16):
-         cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
-     else: # fp32
-         cluster_n = 1 if N <= 32 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)))
-
-     num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
-     cols_per_block = num_threads // threads_per_row
-     tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) # This rounds up N
-     tv_layout = cute.make_layout(
-         ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-         stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
-     )
-
-     mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
-     mW_expanded = cute.make_tensor(mW.iterator, mW_expanded_layout)
-     mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((N,), stride=(0,)))
-     mRstd_expanded = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-     idX = cute.make_identity_tensor(mX.shape)
-     gX, gW, gO, gRstd, cX = [cute.zipped_divide(mT, tiler_mn) for mT in (mX, mW_expanded, mOut, mRstd_expanded, idX)] # ((TileM,TileN),(RestM,RestN))
-
-     # reload_from = None if N <= 16384 else ("smem" if N <= 32768 else "gmem")
-     reload_from = None if N <= 16384 else "smem"
-     # delay_w_load = N > 64 * 1024
-     delay_w_load = False
-     N_rounded = tiler_mn[1]
-     rmsnorm_kernel(gX, gW, gO, gRstd, cX, eps, mX.shape, tv_layout, tiler_mn, cluster_n, reload_from).launch(
-         grid=[cute.size(gX, mode=[1, 0]), cluster_n, 1],
-         block=[cute.size(tv_layout, mode=[0]), 1, 1],
-         # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
-         cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
-         # We don't want to use gX.layout[0] here since that has stride in N, not N_rounded, leading IMA on smem
-         smem=cute.size_in_bytes(mX.element_type, cute.make_layout(gX.shape[0])) + num_warps * cluster_n * (cutlass.Float32.width // 8) + (cutlass.Int64.width // 8),
-         stream=stream,
-     )
-
-
- torch2cute_dtype_map = {
-     torch.float16: cutlass.Float16,
-     torch.bfloat16: cutlass.BFloat16,
-     torch.float32: cutlass.Float32,
- }
+         threads_per_row = tv_layout.shape[0][0]
+         sum_sq_x = utils.row_reduce(
+             x * x,
+             cute.ReductionOp.ADD,
+             threads_per_row,
+             reduction_buffer[None, None, 0],
+             mbar_ptr,
+             init_val=0.0,
+             hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+         )
+         rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
+         if cutlass.const_expr(mRstd is not None):
+             # Only the thread corresponding to column 0 writes out the rstd to gmem
+             if (
+                 tXcX[0][1] == 0
+                 and row < shape[0]
+                 and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+             ):
+                 tXrRstd[0] = rstd
+         if delay_w_load:
+             cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+         if reload_from == "smem":
+             cute.autovec_copy(tXsX, tXrX)
+             x = tXrX.load().to(cute.Float32)
+         elif reload_from == "gmem":
+             cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+             x = tXrX.load().to(cute.Float32)
+         x_hat = x * rstd
+         w = tXrW.load().to(cute.Float32)
+         y = x_hat * w
+         tXrO.store(y.to(tXrO.element_type))
+         tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+         if row < shape[0]:
+             cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)


  def rmsnorm(
@@ -233,24 +236,32 @@ def rmsnorm(
      M, N = x.shape
      device = x.device
      out = torch.empty_like(x)
-     rstd = torch.empty(M, device=device, dtype=torch.float32)
+     rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
      dtype = torch2cute_dtype_map[x.dtype]
      convert_from_dlpack = lambda x: (
-         from_dlpack(x.detach(), assumed_align=16)
-         .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
+         from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
      )
      x_tensor, out_tensor = [
          # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
          convert_from_dlpack(t)
          for t in (x, out)
      ]
-     weight_tensor = utils.convert_from_dlpack(weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width)
-     rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+     weight_tensor = utils.convert_from_dlpack(
+         weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+     )
+     rstd_tensor = (
+         from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+         if rstd is not None
+         else None
+     )
      current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-     compile_key = (dtype, N)
+     compile_key = (dtype, N, rstd is not None)
      if compile_key not in rmsnorm.compile_cache:
+         rmsnorm_op = RMSNorm(dtype, N)
          rmsnorm.compile_cache[compile_key] = cute.compile(
-             rmsnorm_interface, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, N
+             rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
          )
      rmsnorm.compile_cache[compile_key](
          x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
@@ -263,7 +274,9 @@ rmsnorm.compile_cache = {}

  def rmsnorm_ref(x, w, eps=1e-6):
      x_f32 = x.float()
-     return (x_f32 / (torch.sqrt(torch.mean(x_f32 * x_f32, dim=-1, keepdim=True) + eps)) * w).to(x.dtype)
+     return (x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w).to(
+         x.dtype
+     )


  def rstd_ref(x, eps=1e-6):
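
For orientation, below is a minimal usage sketch of the refactored module. Only rmsnorm_ref(x, w, eps=1e-6) and the class-based RMSNorm op are fully visible in these hunks; the public rmsnorm signature is not shown, so the call below assumes rmsnorm(x, weight) with a return_rstd keyword defaulting to False and the output tensor as the return value. Treat those details as assumptions inferred from the diff, not as the package's documented API.

# Hypothetical usage sketch; rmsnorm's exact signature and return value are assumed.
import torch
from quack.rmsnorm import rmsnorm, rmsnorm_ref

x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
w = torch.randn(8192, device="cuda", dtype=torch.float32)  # weight is converted with Float32 divisibility above

out = rmsnorm(x, w)      # first call compiles and caches per (dtype, N, rstd is not None)
ref = rmsnorm_ref(x, w)  # pure-PyTorch reference defined at the end of the file
torch.testing.assert_close(out.float(), ref.float(), rtol=1e-2, atol=1e-2)

The cache key change in this release, (dtype, N) to (dtype, N, rstd is not None), means the kernel is specialized on whether an rstd buffer is requested, which matches the new Optional[cute.Tensor] mRstd plumbing through __call__ and kernel.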