quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- quack/__init__.py +7 -3
- quack/activation.py +279 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +330 -184
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +6 -4
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +240 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +614 -228
- quack/softmax.py +28 -16
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +64 -61
- quack/topk.py +14 -8
- quack/utils.py +14 -322
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/METADATA +3 -3
- quack_kernels-0.2.1.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/top_level.txt +0 -0
quack/rmsnorm.py
CHANGED
@@ -1,23 +1,29 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

-from typing import Optional
+from typing import Optional, Tuple
+from functools import partial

 import cuda.bindings.driver as cuda

 import cutlass
 import cutlass.cute as cute
 from cutlass import Float32, Int32
+from cutlass import const_expr
 from cutlass.cute.runtime import from_dlpack

-import quack.utils as utils
 import torch
-from
+from torch import Tensor
+
+import quack.utils as utils
+from quack.reduce import row_reduce
+from quack.reduction_base import ReductionBase
+from quack.cute_dsl_utils import torch2cute_dtype_map


 class RMSNorm(ReductionBase):
     def __init__(self, dtype: cutlass.Numeric, N: int):
         super().__init__(dtype, N, stage=1)
-        self.reload_from = None if N <=
+        self.reload_from = None if N <= 8192 else "smem"
         self.delay_w_load = False

     def _calculate_threads_per_row(self):
@@ -45,7 +51,7 @@ class RMSNorm(ReductionBase):

         # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
         # Similarly cluster_n = 8 is faster for N=128k
-        if
+        if const_expr(self.dtype.width == 16):
             # 16-bit types (fp16, bf16)
             if N <= 16 * 1024:
                 cluster_n = 1
@@ -72,12 +78,27 @@

         self.cluster_n = cluster_n

+    def _smem_size_in_bytes(self, tiler_mn, num_warps, dtype_res=None):
+        return (
+            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
+            + (
+                cute.size_in_bytes(dtype_res, cute.make_layout(tiler_mn))
+                if dtype_res is not None
+                else 0
+            )
+            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+            + self.stage * (cutlass.Int64.width // 8)
+        )
+
     @cute.jit
     def __call__(
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
+        mB: Optional[cute.Tensor],
+        mRes: Optional[cute.Tensor],
         mO: cute.Tensor,
+        mResO: Optional[cute.Tensor],
         mRstd: Optional[cute.Tensor],
         stream: cuda.CUstream,
         eps: Float32 = 1e-6,
@@ -87,28 +108,49 @@
             cute.assume(t.stride[0], divby=128 // t.element_type.width),
             t.stride[1],
         )
-        mX, mO = [
+        mX, mRes, mO, mResO = [
             cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
-
+            if const_expr(t is not None)
+            else None
+            for t in (mX, mRes, mO, mResO)
         ]
         assert mX.element_type == self.dtype
         assert mO.element_type == self.dtype
         self._set_cluster_n()
-
+        largest_dtype_width = const_expr(
+            max(
+                mX.element_type.width,
+                mRes.element_type.width if mRes is not None else 0,
+                mO.element_type.width,
+                mResO.element_type.width if mResO is not None else 0,
+            )
+        )
+        tiler_mn, tv_layout = self._get_tv_layout(
+            num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+        )
         num_threads = cute.size(tv_layout, mode=[0])
         num_warps = num_threads // cute.arch.WARP_SIZE
         mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
         mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
-        if
+        if const_expr(mB is not None):
+            mB_expanded_layout = cute.prepend(
+                mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+            )
+            mB = cute.make_tensor(mB.iterator, mB_expanded_layout)
+        if const_expr(mRstd is not None):
             mRstd_expanded_layout = cute.append(
                 mRstd.layout, cute.make_layout((self.N,), stride=(0,))
             )
             mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-        self.kernel(
+        self.kernel(
+            mX, mW, mB, mRes, mO, mResO, mRstd, eps, tv_layout, tiler_mn, self.reload_from
+        ).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=([1, self.cluster_n, 1] if
-            smem=self._smem_size_in_bytes(
+            cluster=([1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None),
+            smem=self._smem_size_in_bytes(
+                tiler_mn, num_warps, dtype_res=mRes.element_type if mRes is not None else None
+            ),
             stream=stream,
         )

@@ -117,7 +159,10 @@
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
+        mB: Optional[cute.Tensor],
+        mRes: Optional[cute.Tensor],
         mO: cute.Tensor,
+        mResO: Optional[cute.Tensor],
         mRstd: Optional[cute.Tensor],
         eps: cute.Float32,
         tv_layout: cute.Layout,
@@ -127,10 +172,10 @@
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()
-        if
+        if const_expr(self.cluster_n > 1):
             cluster_y = cute.arch.block_idx()[1]
         else:
-            cluster_y =
+            cluster_y = const_expr(0)

         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
@@ -138,86 +183,147 @@
             cute.make_ordered_layout(tiler_mn, order=(1, 0)),
             byte_alignment=16,
         )
+        if const_expr(mRes is not None):
+            sRes = smem.allocate_tensor(
+                mRes.element_type,
+                cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+                byte_alignment=16,
+            )
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
         # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
-        mX,
-
+        mX, mRes, mO, mResO = [
+            utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) if mT is not None else None
+            for mT in (mX, mRes, mO, mResO)
+        ]
+        gX, gRes, gO, gResO = [
+            cute.local_tile(mT, tiler_mn, (0, cluster_y)) if mT is not None else None
+            for mT in (mX, mRes, mO, mResO)
+        ]
         cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
         gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
+        gB = cute.local_tile(mB, tiler_mn, (0, cluster_y)) if const_expr(mB is not None) else None
         gRstd = (
             cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
-            if
+            if const_expr(mRstd is not None)
             else None
         )

         # declare the atoms which will be used later for memory copy
+        num_copy_elems_X = tv_layout.shape[1][0]
+        num_copy_bits_X = mX.element_type.width * num_copy_elems_X
         copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
         )
         copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=
-        )
-        num_bits_per_copy_W = cutlass.const_expr(
-            min(128, 128 // mX.element_type.width * mW.element_type.width)
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
         )
+        num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
         copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
        )
-
-            min(128,
+        num_bits_per_copy_B = (
+            cutlass.const_expr(min(128, num_copy_elems_X * mB.element_type.width))
+            if const_expr(mB is not None)
+            else 0
         )
+        copy_atom_load_B = (
+            cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
+            )
+            if const_expr(mB is not None)
+            else None
+        )
+        if const_expr(mRes is not None):
+            num_copy_bits_Res = const_expr(min(128, num_copy_elems_X * mRes.element_type.width))
+            copy_atom_load_Res_async = cute.make_copy_atom(
+                cute.nvgpu.cpasync.CopyG2SOp(),
+                mRes.element_type,
+                num_bits_per_copy=num_copy_bits_Res,
+            )
+        num_copy_bits_O = const_expr(min(128, num_copy_elems_X * mO.element_type.width))
         copy_atom_store_O = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_copy_bits_O
         )
+        if const_expr(mResO is not None):
+            num_copy_bits_ResO = const_expr(min(128, num_copy_elems_X * mResO.element_type.width))
+            copy_atom_store_ResO = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(),
+                mResO.element_type,
+                num_bits_per_copy=num_copy_bits_ResO,
+            )

         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
             tidx
         )

         tXgW = thr_copy_X.partition_S(gW)
+        tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
+        if const_expr(mRes is not None):
+            tXgRes = thr_copy_X.partition_S(gRes)
+            tXsRes = thr_copy_X.partition_D(sRes)
         tXgO = thr_copy_X.partition_D(gO)
-
+        if const_expr(mResO is not None):
+            tXgResO = thr_copy_X.partition_D(gResO)
+        tXrRstd = thr_copy_X.partition_D(gRstd) if const_expr(mRstd is not None) else None
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]

         # allocate fragments for gmem->rmem
         tXrW = cute.make_fragment_like(tXgW)
         tXrW.fill(0.0)
-
+        tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
+        tXrX, tXrO = [cute.make_fragment_like(t) for t in (tXgX, tXgO)]
+        if const_expr(mRes is not None):
+            tXrRes = cute.make_fragment_like(tXgRes)

         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
         self._initialize_cluster(tidx, mbar_ptr, num_warps)

-
+        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        tXpX = (
+            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+        )
         row = tXcX[0][0]
         if row < shape[0]:
             cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+            if const_expr(mRes is not None):
+                cute.copy(copy_atom_load_Res_async, tXgRes, tXsRes, pred=tXpX)
         cute.arch.cp_async_commit_group()

-
-
-
+        if const_expr(not delay_w_load):
+            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
+            if const_expr(mB is not None):
+                cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)

         cute.arch.cp_async_wait_group(0)
         cute.autovec_copy(tXsX, tXrX)
         x = tXrX.load().to(cute.Float32)
+        if const_expr(mRes is not None):
+            cute.autovec_copy(tXsRes, tXrRes)
+            x += tXrRes.load().to(cute.Float32)
+        if const_expr(mResO is not None):
+            tXrResO = cute.make_fragment_like(tXgResO)
+            tXrResO.store(x.to(tXrResO.element_type))
+            if row < shape[0]:
+                cute.copy(copy_atom_store_ResO, tXrResO, tXgResO, pred=tXpX)
+
         threads_per_row = tv_layout.shape[0][0]
-        sum_sq_x =
+        sum_sq_x = row_reduce(
             x * x,
             cute.ReductionOp.ADD,
             threads_per_row,
             reduction_buffer[None, None, 0],
             mbar_ptr,
             init_val=0.0,
-            hook_fn=(cute.arch.cluster_wait if
+            hook_fn=(cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None),
         )
-        rstd =
-        if
+        rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True)
+        if const_expr(mRstd is not None):
             # Only the thread corresponding to column 0 writes out the rstd to gmem
             if (
                 tXcX[0][1] == 0
@@ -225,59 +331,76 @@ class RMSNorm(ReductionBase):
                 and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
             ):
                 tXrRstd[0] = rstd
-        if
-            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=
-
-
-
-
-
+        if const_expr(delay_w_load):
+            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
+            if const_expr(mB is not None):
+                cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
+        if const_expr(reload_from == "smem" or reload_from == "gmem"):
+            if const_expr(reload_from == "smem"):
+                cute.autovec_copy(tXsX, tXrX)
+            else:
+                cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
             x = tXrX.load().to(cute.Float32)
+            if const_expr(mRes is not None):
+                cute.autovec_copy(tXsRes, tXrRes)
+                x += tXrRes.load().to(cute.Float32)
         x_hat = x * rstd
         w = tXrW.load().to(cute.Float32)
         y = x_hat * w
+        if const_expr(mB is not None):
+            b = tXrB.load().to(cute.Float32)
+            y = y + b
         tXrO.store(y.to(tXrO.element_type))
-        tXpO = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
         if row < shape[0]:
-            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpX)


+@torch.library.custom_op(
+    "quack::_rmsnorm_fwd",
+    mutates_args=("out", "rstd", "residual_out"),
+    device_types="cuda",
+    # We need to specify the schema manually since we're mutating an optional tensor
+    schema="(Tensor x, Tensor weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
+)
 def _rmsnorm_fwd(
-    x:
-    weight:
+    x: Tensor,
+    weight: Tensor,
+    out: Tensor,
+    bias: Optional[Tensor] = None,
+    rstd: Optional[Tensor] = None,
+    residual: Optional[Tensor] = None,
+    residual_out: Optional[Tensor] = None,
     eps: float = 1e-6,
-
-) -> torch.Tensor:
+) -> None:
     """RMSNorm forward pass.
     Args:
         x: Input tensor of shape (M, N)
         weight: Weight tensor of shape (N,)
         eps: Small value for numerical stability
-        return_rstd: Whether to return the reciprocal standard deviation
     Returns:
         Normalized output tensor of same shape as x
-        If return_rstd is True, also returns rstd tensor of shape (M,)
     """
     assert x.dim() == 2, "Input must be 2D"
     assert weight.dim() == 1, "Weight must be 1D"
     assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
     assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ], "Unsupported dtype"
-
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     assert weight.dtype in [
         torch.float32,
         torch.bfloat16,
         torch.float16,
     ], "Weight must be float32, float16 or bfloat16"
-
-
+    if residual is not None:
+        assert residual.shape == x.shape
+        assert residual.is_cuda
+        assert residual.dtype in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+        ], "Residual must be float16, bfloat16, or float32"
+
+    _, N = x.shape
     device = x.device
-    out = torch.empty_like(x)
-    rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
     dtype = torch2cute_dtype_map[x.dtype]
     # convert_from_dlpack = lambda x: (
     # from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
@@ -287,43 +410,109 @@ def _rmsnorm_fwd(
     convert_from_dlpack = lambda x: (
         from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
     )
-    x_tensor, out_tensor = [
+    x_tensor, res_tensor, out_tensor, res_out_tensor = [
+        convert_from_dlpack(t) if t is not None else None for t in (x, residual, out, residual_out)
+    ]
     # handle weight divisibility based on weight dtype
     weight_dtype = torch2cute_dtype_map[weight.dtype]
     weight_tensor = utils.convert_from_dlpack(
         weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
     )
+    if bias is not None:
+        bias_dtype = torch2cute_dtype_map[bias.dtype]
+        bias_tensor = utils.convert_from_dlpack(
+            bias.detach(), leading_dim=0, divisibility=128 // bias_dtype.width
+        )
+    else:
+        bias_tensor = None
     rstd_tensor = (
         from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
         if rstd is not None
         else None
     )
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-    compile_key = (
+    compile_key = (
+        N,
+        dtype,
+        res_tensor.element_type if residual is not None else None,
+        weight_tensor.element_type,
+        bias_tensor.element_type if bias is not None else None,
+        res_out_tensor.element_type if residual_out is not None else None,
+        rstd is not None,
+    )
     if compile_key not in _rmsnorm_fwd.compile_cache:
         rmsnorm_op = RMSNorm(dtype, N)
         _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
-            rmsnorm_op,
+            rmsnorm_op,
+            x_tensor,
+            weight_tensor,
+            bias_tensor,
+            res_tensor,
+            out_tensor,
+            res_out_tensor,
+            rstd_tensor,
+            current_stream,
+            eps,
         )
     _rmsnorm_fwd.compile_cache[compile_key](
-        x_tensor,
+        x_tensor,
+        weight_tensor,
+        bias_tensor,
+        res_tensor,
+        out_tensor,
+        res_out_tensor,
+        rstd_tensor,
+        current_stream,
+        eps,
     )
-    return (out, rstd) if return_rstd else out


 _rmsnorm_fwd.compile_cache = {}


-def
-
-
-
-
+def rmsnorm_fwd(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    residual: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    residual_dtype: Optional[torch.dtype] = None,
+    eps: float = 1e-6,
+    store_rstd: bool = False,
+) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+    # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library
+    # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
+    # so that _layer_norm_fwd_impl doesn't have to return them.
+    out_dtype = x.dtype if out_dtype is None else out_dtype
+    out = torch.empty_like(x, dtype=out_dtype)
+    rstd = torch.empty(x.shape[0], device=x.device, dtype=torch.float32) if store_rstd else None
+    if residual is not None:
+        residual_dtype = residual.dtype
+    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
+        residual_out = torch.empty_like(
+            x, dtype=residual_dtype if residual_dtype is not None else x.dtype
+        )
+    else:
+        residual_out = None
+    _rmsnorm_fwd(x, weight, out, bias, rstd, residual, residual_out, eps=eps)
+    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
+    if residual_out is None:
+        residual_out = x
+    return out, residual_out, rstd


-def
+def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
     x_f32 = x.float()
-
+    if residual is not None:
+        residual_f32 = residual.float()
+        x_f32 += residual_f32
+    out = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w
+    if bias is not None:
+        out = out + bias.float()
+    if residual is None:
+        return out.to(x.dtype)
+    else:
+        return out.to(x.dtype), x_f32.to(residual.dtype)


 def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
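The forward-path hunks above add fused residual and bias support and register the kernel as a torch custom op. Below is a minimal sketch of how the new rmsnorm_fwd wrapper can be compared against the rmsnorm_ref reference introduced in the same hunk; it assumes a CUDA device with the 0.2.1 wheel installed, and the shapes, dtypes, and printed checks are illustrative rather than part of the package.

# Hedged sketch: compare the fused forward against the pure-PyTorch reference.
import torch
from quack.rmsnorm import rmsnorm_fwd, rmsnorm_ref

torch.manual_seed(0)
x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
residual = torch.randn_like(x)
weight = torch.randn(8192, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(8192, device="cuda", dtype=torch.bfloat16)

# New in 0.2.1: fused residual add and optional bias; rstd can be stored for backward.
out, residual_out, rstd = rmsnorm_fwd(x, weight, bias=bias, residual=residual, store_rstd=True)
out_ref, residual_ref = rmsnorm_ref(x, weight, bias=bias, residual=residual)

print((out.float() - out_ref.float()).abs().max())            # output mismatch
print((residual_out.float() - residual_ref.float()).abs().max())  # residual stream mismatch
print(rstd.shape)  # per-row fp32 reciprocal RMS, shape (4096,)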
@@ -372,11 +561,13 @@ class RMSNormBackward(ReductionBase):
         )
         self.cluster_n = cluster_n

-    def _smem_size_in_bytes(self, tiler_mn, num_warps):
+    def _smem_size_in_bytes(self, tiler_mn, num_warps, do_dtype=None):
+        if do_dtype is None:
+            do_dtype = self.dtype
         return (
-            #
-
-            cute.size_in_bytes(
+            # We need space for X and dO, and multiply by 2 due to double buffering
+            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
+            + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2
             + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
             + self.stage * (cutlass.Int64.width // 8) * 2 # mult 2 as we need 2 mbar per stage
         )
@@ -386,10 +577,13 @@ class RMSNormBackward(ReductionBase):
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
-
+        mdO: cute.Tensor,
+        mdResO: Optional[cute.Tensor],
         mRstd: cute.Tensor,
         mdX: cute.Tensor,
         mdW: cute.Tensor,
+        mdRes: Optional[cute.Tensor],
+        mdB: Optional[cute.Tensor],
         sm_count: Int32,
         stream: cuda.CUstream,
     ):
@@ -398,24 +592,36 @@ class RMSNormBackward(ReductionBase):
             cute.assume(t.stride[0], divby=128 // t.element_type.width),
             t.stride[1],
         )
-        mX,
+        mX, mdO, mdResO, mdX, mdRes = [
             cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
-
+            if const_expr(t is not None)
+            else None
+            for t in (mX, mdO, mdResO, mdX, mdRes)
         ]
         self._set_cluster_n()
-
+        largest_dtype_width = const_expr(
+            max(
+                mX.element_type.width,
+                mdO.element_type.width,
+                mdX.element_type.width,
+                mdResO.element_type.width if mdResO is not None else 0,
+                mdRes.element_type.width if mdRes is not None else 0,
+            )
+        )
+        tiler_mn, tv_layout = self._get_tv_layout(
+            num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+        )
         num_threads = cute.size(tv_layout, mode=[0])
         num_warps = num_threads // cute.arch.WARP_SIZE
-
         mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
         mW = cute.make_tensor(mW.iterator, mW_expanded_layout)

         num_blocks = sm_count
-        self.kernel(mX, mW,
+        self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn).launch(
             grid=[num_blocks, self.cluster_n, 1],
             block=[num_threads, 1, 1],
             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
-            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type),
             stream=stream,
         )

@@ -424,63 +630,85 @@ class RMSNormBackward(ReductionBase):
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
-
+        mdO: cute.Tensor,
+        mdResO: Optional[cute.Tensor],
         mRstd: cute.Tensor,
         mdX: cute.Tensor,
         mdW: cute.Tensor,
+        mdB: Optional[cute.Tensor],
+        mdRes: Optional[cute.Tensor],
         tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx_start, _, _ = cute.arch.block_idx()
         gdim, _, _ = cute.arch.grid_dim()
-        if
+        if const_expr(self.cluster_n > 1):
             cluster_y = cute.arch.block_idx()[1]
         else:
-            cluster_y =
+            cluster_y = const_expr(0)

         shape = mX.shape
         M, N = shape[0], shape[1]
-        is_even_N =
+        is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)

         idX = cute.make_identity_tensor(shape)

         smem = cutlass.utils.SmemAllocator()
         smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2))
         sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
-
+        sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16)
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
             smem, tv_layout, is_persistent=True
         )
-        if
+        if const_expr(mbar_ptr is not None):
             mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
         else:
             mbar_full_ptr, mbar_empty_ptr = None, None

+        num_copy_elems_X = tv_layout.shape[1][0]
+        num_copy_bits_X = mX.element_type.width * num_copy_elems_X
         copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
         )
         copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
         )
-
-
+        num_copy_bits_dO = const_expr(min(128, num_copy_elems_X * mdO.element_type.width))
+        copy_atom_load_dO_async = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), mdO.element_type, num_bits_per_copy=num_copy_bits_dO
         )
+        num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
         copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=
-        )
-        num_bits_per_copy_dX = cutlass.const_expr(
-            min(128, 128 // mX.element_type.width * mdX.element_type.width)
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
         )
+        if const_expr(mdResO is not None):
+            num_copy_bits_dResO = const_expr(min(128, num_copy_elems_X * mdResO.element_type.width))
+            copy_atom_load_dResO = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(),
+                mdResO.element_type,
+                num_bits_per_copy=num_copy_bits_dResO,
+            )
+        num_copy_bits_dX = const_expr(min(128, num_copy_elems_X * mdX.element_type.width))
         copy_atom_store_dX = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=
-        )
-        num_bits_per_copy_dW = cutlass.const_expr(
-            min(128, 128 // mX.element_type.width * mdW.element_type.width)
+            cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_dX
         )
+        num_copy_bits_dW = const_expr(min(128, num_copy_elems_X * mdW.element_type.width))
         copy_atom_store_dW = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_copy_bits_dW
         )
+        if const_expr(mdB is not None):
+            num_copy_bits_dB = const_expr(min(128, num_copy_elems_X * mdB.element_type.width))
+            copy_atom_store_dB = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(), mdB.element_type, num_bits_per_copy=num_copy_bits_dB
+            )
+        if const_expr(mdRes is not None):
+            num_copy_bits_dRes = const_expr(min(128, num_copy_elems_X * mdRes.element_type.width))
+            copy_atom_load_dRes = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(),
+                mdRes.element_type,
+                num_bits_per_copy=num_copy_bits_dRes,
+            )

         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)

@@ -510,21 +738,40 @@ class RMSNormBackward(ReductionBase):
            if not is_even_N
            else None
        )
+        if const_expr(mdB is not None):
+            db_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+            tXpdB = (
+                utils.predicate_k(thr_copy_X.partition_S(db_coord), limit=shape[1])
+                if not is_even_N
+                else None
+            )

         gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
         tXgdW = thr_copy_X.partition_S(gdW)
         # Always compute partial weight gradients in fp32
         tXrdW = cute.make_fragment_like(tXgdW, Float32)

-
-
-
-
+        gdB = (
+            cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y))
+            if const_expr(mdB is not None)
+            else None
+        )
+        tXgdB = thr_copy_X.partition_S(gdB) if const_expr(mdB is not None) else None
+        tXrdB = cute.make_fragment_like(tXgdB, Float32) if const_expr(mdB is not None) else None
+
+        gX, gdO, gdResO, gdX, gdRes, cX = [
+            cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None
+            for mT in (mX, mdO, mdResO, mdX, mdRes, idX)
+        ]
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
-
-
+        tXgdO = thr_copy_X.partition_S(gdO)
+        tXsdO = thr_copy_X.partition_D(sdO)
         tXgdX = thr_copy_X.partition_D(gdX)
+        if const_expr(mdResO is not None):
+            tXgdResO = thr_copy_X.partition_S(gdResO)
+        if const_expr(mdRes is not None):
+            tXgdRes = thr_copy_X.partition_D(gdRes)
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
         # This doesn't change across iterations
         tXpX = (
@@ -533,62 +780,50 @@ class RMSNormBackward(ReductionBase):
            else None
        )

-        tXrX,
-            cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX,
+        tXrX, tXrdO, tXrdX = [
+            cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
         ]
+        tXrdResO = None
+        if const_expr(mdResO is not None):
+            tXrdResO = cute.make_fragment_like(tXgdResO[None, None, None, 0])
+        tXrdRes = None
+        if const_expr(mdRes is not None):
+            tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
+
+        copy_X = partial(cute.copy, copy_atom_load_X_async, pred=tXpX)
+        copy_dO = partial(cute.copy, copy_atom_load_dO_async, pred=tXpX)

         # Prefetch the first batch
         row = tXcX[None, None, None, bidx_start][0][0]
         if row < M:
             tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
-
-
-
-                tXgX_cur,
-                tXsX[None, None, None, 0],
-                pred=tXpX,
-            )
-            cute.copy(
-                copy_atom_load_X_async,
-                tXgdOut_cur,
-                tXsdOut[None, None, None, 0],
-                pred=tXpX,
-            )
+            tXgdO_cur = utils.coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0]
+            copy_X(tXgX_cur, tXsX[None, None, None, 0])
+            copy_dO(tXgdO_cur, tXsdO[None, None, None, 0])
         elif tiler_mn[0] > 1:
             # Fill with zero, otherwise smem will be uninitialized, and we could read this back
             # later into registers, causing wrong dW.
             utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
-            utils.fill_oob(
+            utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
         cute.arch.cp_async_commit_group()

-        if
+        if const_expr(self.cluster_n > 1):
             cute.arch.cluster_wait()

         threads_per_row = tv_layout.shape[0][0]
         tXrdW.fill(0.0)
+        if const_expr(mdB is not None):
+            tXrdB.fill(0.0)
         stage = Int32(0)
         producer_phase = Int32(1)
         consumer_phase = Int32(0)
         for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
             row = tXcX[None, None, None, bidx][0][0]
-            rstd = cutlass.Float.zero
             if row + gdim * tiler_mn[0] < M: # Prefetch the next batch
                 tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
-
-
-                ]
-                cute.copy(
-                    copy_atom_load_X_async,
-                    tXgX_cur,
-                    tXsX[None, None, None, stage ^ 1],
-                    pred=tXpX,
-                )
-                cute.copy(
-                    copy_atom_load_X_async,
-                    tXgdOut_cur,
-                    tXsdOut[None, None, None, stage ^ 1],
-                    pred=tXpX,
-                )
+                tXgdO_cur = utils.coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0]
+                copy_X(tXgX_cur, tXsX[None, None, None, stage ^ 1])
+                copy_dO(tXgdO_cur, tXsdO[None, None, None, stage ^ 1])
             elif tiler_mn[0] > 1:
                 utils.fill_oob(
                     tXsX[None, None, None, stage ^ 1],
@@ -596,36 +831,45 @@ class RMSNormBackward(ReductionBase):
                     fill_value=mX.element_type.zero,
                 )
                 utils.fill_oob(
-
+                    tXsdO[None, None, None, stage ^ 1],
                     None,
-                    fill_value=
+                    fill_value=mdO.element_type.zero,
                 )
             cute.arch.cp_async_commit_group()
+            rstd = cutlass.Float.zero
             if row < M or tiler_mn[0] == 1:
                 rstd = mRstd[row]
+            if const_expr(mdResO is not None):
+                tXgdResO_cur = utils.coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
+                if row < M or tiler_mn[0] == 1:
+                    cute.copy(copy_atom_load_dResO, tXgdResO_cur, tXrdResO, pred=tXpX)
+                elif tiler_mn[0] > 1:
+                    tXrdResO.fill(0.0)
             cute.arch.cp_async_wait_group(1)
             cute.autovec_copy(tXsX[None, None, None, stage], tXrX)
             x = tXrX.load().to(cute.Float32)
-            cute.autovec_copy(
-            dout =
+            cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
+            dout = tXrdO.load().to(cute.Float32)
+            if const_expr(mdResO is not None):
+                dout += tXrdResO.load().to(cute.Float32)
             x_hat = x * rstd
             wdy = dout * weight
-            if
+            if const_expr(self.cluster_n > 1):
                 cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
             mean_xhat_wdy = (
-
+                row_reduce(
                     x_hat * wdy,
                     cute.ReductionOp.ADD,
                     threads_per_row,
                     reduction_buffer[None, None, stage],
-                    (mbar_full_ptr + stage if
+                    (mbar_full_ptr + stage if const_expr(self.cluster_n > 1) else None),
                     phase=consumer_phase,
                     init_val=0.0,
                 )
                 / shape[1]
             )

-            if
+            if const_expr(self.cluster_n > 1):
                 # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
                 # Requires adjusting the thread_count when initializing the mbar
                 cute.arch.sync_warp()
@@ -635,28 +879,37 @@ class RMSNormBackward(ReductionBase):
                     mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
                 )

-            if
-                cute.autovec_copy(
-                dout =
+            if const_expr(self.reload_wdy == "smem"):
+                cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
+                dout = tXrdO.load().to(cute.Float32)
+                if const_expr(mdResO is not None):
+                    dout += tXrdResO.load().to(cute.Float32)
                 wdy = dout * weight

             dx = (wdy - x_hat * mean_xhat_wdy) * rstd
-            tXrdX.store(dx.to(
+            tXrdX.store(dx.to(tXrdX.element_type))
             if row < M or tiler_mn[0] == 1:
                 tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
                 cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+            if const_expr(mdRes is not None):
+                tXrdRes.store(dx.to(tXrdRes.element_type))
+                tXgdRes_cur = utils.coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
+                if row < M or tiler_mn[0] == 1:
+                    cute.copy(copy_atom_load_dRes, tXrdRes, tXgdRes_cur, pred=tXpX)
             # Accumulate weight gradients in fp32
             tXrdW.store(tXrdW.load() + dout * x_hat)
+            if const_expr(mdB is not None):
+                tXrdB.store(tXrdB.load() + dout)

             stage ^= 1
             if stage == 0:
                 consumer_phase ^= 1
                 producer_phase ^= 1

-        if
+        if const_expr(self.cluster_n > 1):  # Prevent cluster from exiting early
             cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)

-        if
+        if const_expr(tiler_mn[0] > 1):
             # reduction of dw_partial within the same threadblock
             sdW = cute.make_tensor(
                 cute.recast_ptr(sX.iterator, dtype=cute.Float32),
@@ -669,23 +922,75 @@ class RMSNormBackward(ReductionBase):
             cute.autovec_copy(tXrdW, tXsdW)
             cute.arch.barrier()
             if row == 0:
-                for i in cutlass.range_constexpr(1,
+                for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
                     tXrdW_other = cute.make_fragment_like(tXrdW)
                     tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
                     cute.autovec_copy(tXsdW_other, tXrdW_other)
                     tXrdW.store(tXrdW.load() + tXrdW_other.load())
                 cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
+            cute.arch.barrier()
+            if const_expr(mdB is not None):
+                sdB = cute.make_tensor(
+                    cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+                    cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+                )
+                tXsdB = thr_copy_X.partition_D(sdB)
+                cute.arch.barrier()
+                row = tXcX[None, None, None, 0][0][0]
+                if row > 0:
+                    cute.autovec_copy(tXrdB, tXsdB)
+                cute.arch.barrier()
+                if row == 0:
+                    for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
+                        tXrdB_other = cute.make_fragment_like(tXrdB)
+                        tXsdB_other = cute.make_tensor(
+                            tXsdB.iterator + i * sdB.stride[0], tXsdB.layout
+                        )
+                        cute.autovec_copy(tXsdB_other, tXrdB_other)
+                        tXrdB.store(tXrdB.load() + tXrdB_other.load())
+                    cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
         else:
             # dw is already in fp32, so we can directly copy to global memory
             cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
+            if const_expr(mdB is not None):
+                cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)


-def
-
-
-
-
-
+def _get_sm_count(N: int, device: torch.device) -> int:
+    # This should be tuned on how many CTAs can be launched on each SM
+    sm_count_multiple = (
+        16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+    )
+    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
+    # By right, if we're using cluster, this should be cluster_count not sm_count.
+    # But for cluster >= 4, due to quantization we would need to query active max cluster.
+    # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
+    # avoid wave quantization.
+    sm_count = (
+        sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
+    )
+
+    return sm_count
+
+
+@torch.library.custom_op(
+    "quack::_rmsnorm_bwd",
+    mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
+    device_types="cuda",
+    # We need to specify the schema manually since we're mutating an optional tensor
+    schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!) dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual) -> ()",
+)
+def _rmsnorm_bwd(
+    x: Tensor,
+    weight: Tensor,
+    dout: Tensor,
+    rstd: Tensor,
+    dx: Tensor,
+    dw_partial: Tensor,
+    db_partial: Optional[Tensor] = None,
+    dresidual_out: Optional[Tensor] = None,
+    dresidual: Optional[Tensor] = None,
+) -> None:
     """RMSNorm backward pass.
     Args:
         x: Input tensor of shape (M, N)
@@ -701,46 +1006,39 @@ def _rmsnorm_backward(
     assert weight.dim() == 1, "Weight must be 1D"
     assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
     assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ], "Unsupported dtype"
-
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     assert weight.dtype in [
         torch.float32,
         torch.bfloat16,
         torch.float16,
     ], "Weight must be float32, float16 or bfloat16"
-
-
-
-
+    if dresidual_out is not None:
+        assert dresidual_out.shape == x.shape
+        assert dresidual_out.is_cuda
+        assert dresidual_out.dtype in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+        ], "Residual must be float16, bfloat16, or float32"
+    if dresidual is not None:
+        assert dresidual.shape == x.shape
+        assert dresidual.is_cuda
+        assert dresidual.dtype in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+        ], "Residual must be float16, bfloat16, or float32"
+
+    N = x.size(1)
     device = x.device
-
-    # This should be tuned on how many CTAs can be launched on each SM
-    sm_count_multiple = (
-        16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
-    )
-    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
-    # By right, if we're using cluster, this should be cluster_count not sm_count.
-    # But for cluster >= 4, due to quantization we would need to query active max cluster.
-    # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
-    # avoid wave quantization.
-    sm_count = (
-        sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
-    )
-
-    # Always store partial gradients in fp32 for numerical accuracy
-    dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
-
-    dtype = torch2cute_dtype_map[x.dtype]
-
+    sm_count = dw_partial.shape[0]
     convert_from_dlpack = lambda x: (
         from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
     )
-    x_tensor, dout_tensor, dx_tensor = [
-
+    x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [
+        convert_from_dlpack(t) if t is not None else None
+        for t in (x, dout, dresidual_out, dx, dresidual)
+    ]
     # Handle weight div based on weight dtype
     weight_dtype = torch2cute_dtype_map[weight.dtype]
     weight_tensor = utils.convert_from_dlpack(
@@ -748,74 +1046,162 @@ def _rmsnorm_backward(
     )

     dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+    db_partial_tensor = (
+        from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+        if db_partial is not None
+        else None
+    )
     rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)

     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

-    compile_key = (
-
-
-
+    compile_key = (
+        N,
+        x_tensor.element_type,
+        weight_tensor.element_type,
+        db_partial.dtype if db_partial is not None else None,
+        dresidual.dtype if dresidual is not None else None,
+        dresidual_out.dtype if dresidual_out is not None else None,
+    )
+    if compile_key not in _rmsnorm_bwd.compile_cache:
+        rmsnorm_backward_op = RMSNormBackward(x_tensor.element_type, N)
+        _rmsnorm_bwd.compile_cache[compile_key] = cute.compile(
             rmsnorm_backward_op,
             x_tensor,
             weight_tensor,
             dout_tensor,
+            dres_out_tensor,
             rstd_tensor,
             dx_tensor,
             dw_partial_tensor,
+            dres_tensor,
+            db_partial_tensor,
             sm_count,
             current_stream,
         )

-
+    _rmsnorm_bwd.compile_cache[compile_key](
         x_tensor,
         weight_tensor,
         dout_tensor,
+        dres_out_tensor,
         rstd_tensor,
         dx_tensor,
         dw_partial_tensor,
+        dres_tensor,
+        db_partial_tensor,
         sm_count,
         current_stream,
     )
-    # we have summed the partial gradients in fp32, now we convert back to the weight dtype
-    dw = dw_partial.sum(dim=0).to(weight.dtype)
-    return dx, dw


-
+_rmsnorm_bwd.compile_cache = {}
+
+
+def rmsnorm_bwd(
+    x: Tensor,
+    weight: Tensor,
+    dout: Tensor,
+    rstd: Tensor,
+    dresidual_out: Optional[Tensor] = None,  # grad wrt residual_out
+    has_bias: bool = False,
+) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+    device = x.device
+    N = x.size(1)
+    sm_count = _get_sm_count(N, device)
+    dx = torch.empty_like(x)
+
+    if dresidual_out is not None and dresidual_out.dtype != dx.dtype:
+        dresidual = torch.empty_like(x, dtype=dresidual_out.dtype)
+    else:
+        dresidual = None
+    # Always store partial gradients in fp32 for numerical accuracy
+    dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+    db_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None
+    _rmsnorm_bwd(x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual)
+    # we have summed the partial gradients in fp32, now we convert back to the weight dtype
+    dw = dw_partial.sum(dim=0).to(weight.dtype)
+    db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None
+    # dresidual is the same as dx in this case
+    if dresidual_out is not None and dresidual_out.dtype == dx.dtype:
+        dresidual = dx
+    return dx, dw, db, dresidual


 class RMSNormFunction(torch.autograd.Function):
     @staticmethod
-    def forward(
-
-
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias=None,
+        residual=None,
+        out_dtype=None,
+        residual_dtype=None,
+        eps=1e-6,
+        prenorm=False,
+    ):
+        x_shape_og = x.shape
         # Flatten input
-        x = x.
-
-
-        ctx.
+        x = x.reshape(-1, x.shape[-1])
+        if residual is not None:
+            residual = residual.reshape(-1, residual.shape[-1])
+        need_grad = any(ctx.needs_input_grad[:3])
+        out, residual_out, rstd = rmsnorm_fwd(
+            x,
+            weight,
+            bias=bias,
+            residual=residual,
+            out_dtype=out_dtype,
+            residual_dtype=residual_dtype,
+            eps=eps,
+            store_rstd=need_grad,
+        )
+        ctx.save_for_backward(x if residual is None else residual_out, weight, rstd)
+        ctx.has_bias = bias is not None
         ctx.eps = eps
-        ctx.
-
-
+        ctx.x_shape_og = x_shape_og
+        ctx.residual_dtype = residual.dtype if residual is not None else None
+        ctx.prenorm = prenorm
+        if residual_out is None or not prenorm:
+            return out.reshape(x_shape_og)
+        else:
+            return out.reshape(x_shape_og), residual_out.reshape(x_shape_og)

     @staticmethod
-    def backward(ctx, dout):
+    def backward(ctx, dout, *args):
         x, weight, rstd = ctx.saved_tensors
-
+        has_bias = ctx.has_bias
+        if ctx.prenorm and ctx.residual_dtype is not None:
+            dresidual_out = args[0]
+            dresidual_out = dresidual_out.reshape(-1, dresidual_out.shape[-1])
+        else:
+            dresidual_out = None
+        x_shape_og = ctx.x_shape_og
         # Reshape dout to match the flattened shape used in forward
         dout = dout.view(-1, dout.shape[-1])
-        dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
-        dx = dx.view(x_shape_start)
-        # dx is returned for input gradient,
-        # dw is returned for weight gradient,
-        # None for eps gradient
-        return dx, dw, None

+        dx, dw, db, dresidual = rmsnorm_bwd(x, weight, dout, rstd, dresidual_out, has_bias)
+        dx = dx.view(x_shape_og)
+        if dresidual_out is not None:
+            dresidual_out = dresidual_out.reshape(x_shape_og)
+        if dresidual is not None:
+            dresidual = dresidual.reshape(x_shape_og)
+
+        return dx, dw, db, dresidual, *([None] * 4)

-
-
+
+def rmsnorm(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    residual: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    residual_dtype: Optional[torch.dtype] = None,
+    eps: float = 1e-6,
+    prenorm: bool = False,
+) -> Tensor:
+    """RMSNorm with automatic differentiation support.

     Args:
         x: Input tensor of shape (M, N)
@@ -825,7 +1211,7 @@ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.T
     Returns:
         Normalized output tensor of same shape as x
     """
-    return RMSNormFunction.apply(x, weight, eps)
+    return RMSNormFunction.apply(x, weight, bias, residual, out_dtype, residual_dtype, eps, prenorm)


 class QuackRMSNorm(torch.nn.Module):
@@ -848,16 +1234,16 @@ class QuackRMSNorm(torch.nn.Module):
         self.weight = torch.nn.Parameter(torch.ones(dim))
         self.eps = eps

-    def forward(self, x:
+    def forward(self, x: Tensor) -> Tensor:
         """Apply RMSNorm to the input tensor.

         Args:
-            x (
+            x (Tensor): Input tensor

         Returns:
-
+            Tensor: Normalized tensor
         """
-        return rmsnorm(x, self.weight, self.eps)
+        return rmsnorm(x, self.weight, eps=self.eps)

     def reset_parameters(self):
         """Reset the weight parameter to ones."""
|