quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/rmsnorm.py
CHANGED
@@ -1,23 +1,28 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

-from typing import Optional
+from typing import Optional, Tuple
+from functools import partial

 import cuda.bindings.driver as cuda

 import cutlass
 import cutlass.cute as cute
 from cutlass import Float32, Int32
+from cutlass import const_expr
 from cutlass.cute.runtime import from_dlpack

-import quack.utils as utils
 import torch
-from
+from torch import Tensor

+import quack.utils as utils
+from quack.reduce import row_reduce
+from quack.reduction_base import ReductionBase
+from quack.cute_dsl_utils import torch2cute_dtype_map

 class RMSNorm(ReductionBase):
     def __init__(self, dtype: cutlass.Numeric, N: int):
         super().__init__(dtype, N, stage=1)
-        self.reload_from = None if N <=
+        self.reload_from = None if N <= 8192 else "smem"
         self.delay_w_load = False

     def _calculate_threads_per_row(self):
@@ -45,7 +50,7 @@ class RMSNorm(ReductionBase):

         # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
         # Similarly cluster_n = 8 is faster for N=128k
-        if
+        if const_expr(self.dtype.width == 16):
             # 16-bit types (fp16, bf16)
             if N <= 16 * 1024:
                 cluster_n = 1
@@ -72,12 +77,27 @@ class RMSNorm(ReductionBase):

         self.cluster_n = cluster_n

+    def _smem_size_in_bytes(self, tiler_mn, num_warps, dtype_res=None):
+        return (
+            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
+            + (
+                cute.size_in_bytes(dtype_res, cute.make_layout(tiler_mn))
+                if dtype_res is not None
+                else 0
+            )
+            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+            + self.stage * (cutlass.Int64.width // 8)
+        )
+
     @cute.jit
     def __call__(
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
+        mB: Optional[cute.Tensor],
+        mRes: Optional[cute.Tensor],
         mO: cute.Tensor,
+        mResO: Optional[cute.Tensor],
         mRstd: Optional[cute.Tensor],
         stream: cuda.CUstream,
         eps: Float32 = 1e-6,
@@ -87,28 +107,47 @@ class RMSNorm(ReductionBase):
             cute.assume(t.stride[0], divby=128 // t.element_type.width),
             t.stride[1],
         )
-        mX, mO = [
+        mX, mRes, mO, mResO = [
             cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
-
+            if const_expr(t is not None)
+            else None
+            for t in (mX, mRes, mO, mResO)
         ]
         assert mX.element_type == self.dtype
         assert mO.element_type == self.dtype
         self._set_cluster_n()
-
+        largest_dtype_width = const_expr(
+            max(
+                mX.element_type.width,
+                mRes.element_type.width if mRes is not None else 0,
+                mO.element_type.width,
+                mResO.element_type.width if mResO is not None else 0,
+            )
+        )
+        tiler_mn, tv_layout = self._get_tv_layout(
+            num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+        )
         num_threads = cute.size(tv_layout, mode=[0])
         num_warps = num_threads // cute.arch.WARP_SIZE
         mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
         mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
-        if
+        if const_expr(mB is not None):
+            mB_expanded_layout = cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+            mB = cute.make_tensor(mB.iterator, mB_expanded_layout)
+        if const_expr(mRstd is not None):
             mRstd_expanded_layout = cute.append(
                 mRstd.layout, cute.make_layout((self.N,), stride=(0,))
             )
             mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-        self.kernel(
+        self.kernel(
+            mX, mW, mB, mRes, mO, mResO, mRstd, eps, tv_layout, tiler_mn, self.reload_from
+        ).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=([1, self.cluster_n, 1] if
-            smem=self._smem_size_in_bytes(
+            cluster=([1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None),
+            smem=self._smem_size_in_bytes(
+                tiler_mn, num_warps, dtype_res=mRes.element_type if mRes is not None else None
+            ),
             stream=stream,
         )

@@ -117,7 +156,10 @@ class RMSNorm(ReductionBase):
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
+        mB: Optional[cute.Tensor],
+        mRes: Optional[cute.Tensor],
         mO: cute.Tensor,
+        mResO: Optional[cute.Tensor],
         mRstd: Optional[cute.Tensor],
         eps: cute.Float32,
         tv_layout: cute.Layout,
@@ -127,10 +169,10 @@ class RMSNorm(ReductionBase):
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()
-        if
+        if const_expr(self.cluster_n > 1):
             cluster_y = cute.arch.block_idx()[1]
         else:
-            cluster_y =
+            cluster_y = const_expr(0)

         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
@@ -138,86 +180,145 @@ class RMSNorm(ReductionBase):
             cute.make_ordered_layout(tiler_mn, order=(1, 0)),
             byte_alignment=16,
         )
+        if const_expr(mRes is not None):
+            sRes = smem.allocate_tensor(
+                mRes.element_type,
+                cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+                byte_alignment=16,
+            )
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
         # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
-        mX,
-
+        mX, mRes, mO, mResO = [
+            utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) if mT is not None else None
+            for mT in (mX, mRes, mO, mResO)
+        ]
+        gX, gRes, gO, gResO = [
+            cute.local_tile(mT, tiler_mn, (0, cluster_y)) if mT is not None else None
+            for mT in (mX, mRes, mO, mResO)
+        ]
         cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
         gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
+        gB = (
+            cute.local_tile(mB, tiler_mn, (0, cluster_y))
+            if const_expr(mB is not None)
+            else None
+        )
         gRstd = (
             cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
-            if
+            if const_expr(mRstd is not None)
             else None
         )

         # declare the atoms which will be used later for memory copy
+        num_copy_elems_X = tv_layout.shape[1][0]
+        num_copy_bits_X = mX.element_type.width * num_copy_elems_X
         copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
         )
         copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=
-        )
-        num_bits_per_copy_W = cutlass.const_expr(
-            min(128, 128 // mX.element_type.width * mW.element_type.width)
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
         )
+        num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
         copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=
-        )
-        num_bits_per_copy_O = cutlass.const_expr(
-            min(128, 128 // mX.element_type.width * mO.element_type.width)
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
         )
+        num_bits_per_copy_B = cutlass.const_expr(
+            min(128, num_copy_elems_X * mB.element_type.width)
+        ) if const_expr(mB is not None) else 0
+        copy_atom_load_B = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
+        ) if const_expr(mB is not None) else None
+        if const_expr(mRes is not None):
+            num_copy_bits_Res = const_expr(min(128, num_copy_elems_X * mRes.element_type.width))
+            copy_atom_load_Res_async = cute.make_copy_atom(
+                cute.nvgpu.cpasync.CopyG2SOp(),
+                mRes.element_type,
+                num_bits_per_copy=num_copy_bits_Res,
+            )
+        num_copy_bits_O = const_expr(min(128, num_copy_elems_X * mO.element_type.width))
         copy_atom_store_O = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_copy_bits_O
         )
+        if const_expr(mResO is not None):
+            num_copy_bits_ResO = const_expr(min(128, num_copy_elems_X * mResO.element_type.width))
+            copy_atom_store_ResO = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(),
+                mResO.element_type,
+                num_bits_per_copy=num_copy_bits_ResO,
+            )

         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
             tidx
         )

         tXgW = thr_copy_X.partition_S(gW)
+        tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
+        if const_expr(mRes is not None):
+            tXgRes = thr_copy_X.partition_S(gRes)
+            tXsRes = thr_copy_X.partition_D(sRes)
         tXgO = thr_copy_X.partition_D(gO)
-
+        if const_expr(mResO is not None):
+            tXgResO = thr_copy_X.partition_D(gResO)
+        tXrRstd = thr_copy_X.partition_D(gRstd) if const_expr(mRstd is not None) else None
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]

         # allocate fragments for gmem->rmem
         tXrW = cute.make_fragment_like(tXgW)
         tXrW.fill(0.0)
-
+        tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
+        tXrX, tXrO = [cute.make_fragment_like(t) for t in (tXgX, tXgO)]
+        if const_expr(mRes is not None):
+            tXrRes = cute.make_fragment_like(tXgRes)

         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
         self._initialize_cluster(tidx, mbar_ptr, num_warps)

-
+        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        tXpX = (
+            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+        )
         row = tXcX[0][0]
         if row < shape[0]:
             cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+            if const_expr(mRes is not None):
+                cute.copy(copy_atom_load_Res_async, tXgRes, tXsRes, pred=tXpX)
             cute.arch.cp_async_commit_group()

-
-
-
+        if const_expr(not delay_w_load):
+            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
+            if const_expr(mB is not None):
+                cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)

         cute.arch.cp_async_wait_group(0)
         cute.autovec_copy(tXsX, tXrX)
         x = tXrX.load().to(cute.Float32)
+        if const_expr(mRes is not None):
+            cute.autovec_copy(tXsRes, tXrRes)
+            x += tXrRes.load().to(cute.Float32)
+        if const_expr(mResO is not None):
+            tXrResO = cute.make_fragment_like(tXgResO)
+            tXrResO.store(x.to(tXrResO.element_type))
+            if row < shape[0]:
+                cute.copy(copy_atom_store_ResO, tXrResO, tXgResO, pred=tXpX)
+
         threads_per_row = tv_layout.shape[0][0]
-        sum_sq_x =
+        sum_sq_x = row_reduce(
             x * x,
             cute.ReductionOp.ADD,
             threads_per_row,
             reduction_buffer[None, None, 0],
             mbar_ptr,
             init_val=0.0,
-            hook_fn=(cute.arch.cluster_wait if
+            hook_fn=(cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None),
         )
         rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
-        if
+        if const_expr(mRstd is not None):
             # Only the thread corresponding to column 0 writes out the rstd to gmem
             if (
                 tXcX[0][1] == 0
@@ -225,59 +326,76 @@ class RMSNorm(ReductionBase):
                 and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
             ):
                 tXrRstd[0] = rstd
-        if
-            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=
-
-
-
-
-
+        if const_expr(delay_w_load):
+            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
+            if const_expr(mB is not None):
+                cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
+        if const_expr(reload_from == "smem" or reload_from == "gmem"):
+            if const_expr(reload_from == "smem"):
+                cute.autovec_copy(tXsX, tXrX)
+            else:
+                cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
             x = tXrX.load().to(cute.Float32)
+            if const_expr(mRes is not None):
+                cute.autovec_copy(tXsRes, tXrRes)
+                x += tXrRes.load().to(cute.Float32)
         x_hat = x * rstd
         w = tXrW.load().to(cute.Float32)
         y = x_hat * w
+        if const_expr(mB is not None):
+            b = tXrB.load().to(cute.Float32)
+            y = y + b
         tXrO.store(y.to(tXrO.element_type))
-        tXpO = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
         if row < shape[0]:
-            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpX)


+@torch.library.custom_op(
+    "quack::_rmsnorm_fwd",
+    mutates_args=("out", "rstd", "residual_out"),
+    device_types="cuda",
+    # We need to specify the schema manually since we're mutating an optional tensor
+    schema="(Tensor x, Tensor weight, Tensor(a!) out, Tensor? bias, Tensor(a!)? rstd, Tensor? residual, Tensor(a!)? residual_out, float eps=1e-6) -> ()",
+)
 def _rmsnorm_fwd(
-    x:
-    weight:
+    x: Tensor,
+    weight: Tensor,
+    out: Tensor,
+    bias: Optional[Tensor] = None,
+    rstd: Optional[Tensor] = None,
+    residual: Optional[Tensor] = None,
+    residual_out: Optional[Tensor] = None,
     eps: float = 1e-6,
-
-) -> torch.Tensor:
+) -> None:
     """RMSNorm forward pass.
     Args:
         x: Input tensor of shape (M, N)
         weight: Weight tensor of shape (N,)
         eps: Small value for numerical stability
-        return_rstd: Whether to return the reciprocal standard deviation
     Returns:
         Normalized output tensor of same shape as x
-        If return_rstd is True, also returns rstd tensor of shape (M,)
     """
     assert x.dim() == 2, "Input must be 2D"
     assert weight.dim() == 1, "Weight must be 1D"
     assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
     assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ], "Unsupported dtype"
-
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     assert weight.dtype in [
         torch.float32,
         torch.bfloat16,
         torch.float16,
     ], "Weight must be float32, float16 or bfloat16"
-
-
+    if residual is not None:
+        assert residual.shape == x.shape
+        assert residual.is_cuda
+        assert residual.dtype in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+        ], "Residual must be float16, bfloat16, or float32"
+
+    _, N = x.shape
     device = x.device
-    out = torch.empty_like(x)
-    rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
     dtype = torch2cute_dtype_map[x.dtype]
     # convert_from_dlpack = lambda x: (
     #     from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
@@ -287,44 +405,109 @@ def _rmsnorm_fwd(
     convert_from_dlpack = lambda x: (
         from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
     )
-    x_tensor, out_tensor = [
+    x_tensor, res_tensor, out_tensor, res_out_tensor = [
+        convert_from_dlpack(t) if t is not None else None for t in (x, residual, out, residual_out)
+    ]
     # handle weight divisibility based on weight dtype
     weight_dtype = torch2cute_dtype_map[weight.dtype]
     weight_tensor = utils.convert_from_dlpack(
         weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
     )
+    if bias is not None:
+        bias_dtype = torch2cute_dtype_map[bias.dtype]
+        bias_tensor = utils.convert_from_dlpack(
+            bias.detach(), leading_dim=0, divisibility=128 // bias_dtype.width
+        )
+    else:
+        bias_tensor = None
     rstd_tensor = (
         from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
         if rstd is not None
         else None
     )
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-    compile_key = (
+    compile_key = (
+        N,
+        dtype,
+        res_tensor.element_type if residual is not None else None,
+        weight_tensor.element_type,
+        bias_tensor.element_type if bias is not None else None,
+        res_out_tensor.element_type if residual_out is not None else None,
+        rstd is not None,
+    )
     if compile_key not in _rmsnorm_fwd.compile_cache:
         rmsnorm_op = RMSNorm(dtype, N)
         _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
-            rmsnorm_op,
+            rmsnorm_op,
+            x_tensor,
+            weight_tensor,
+            bias_tensor,
+            res_tensor,
+            out_tensor,
+            res_out_tensor,
+            rstd_tensor,
+            current_stream,
+            eps,
         )
     _rmsnorm_fwd.compile_cache[compile_key](
-        x_tensor,
+        x_tensor,
+        weight_tensor,
+        bias_tensor,
+        res_tensor,
+        out_tensor,
+        res_out_tensor,
+        rstd_tensor,
+        current_stream,
+        eps,
     )
-    return (out, rstd) if return_rstd else out


 _rmsnorm_fwd.compile_cache = {}


-def
-
-
-
-
+def rmsnorm_fwd(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    residual: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    residual_dtype: Optional[torch.dtype] = None,
+    eps: float = 1e-6,
+    store_rstd: bool = False,
+) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+    # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library
+    # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
+    # so that _layer_norm_fwd_impl doesn't have to return them.
+    out_dtype = x.dtype if out_dtype is None else out_dtype
+    out = torch.empty_like(x, dtype=out_dtype)
+    rstd = torch.empty(x.shape[0], device=x.device, dtype=torch.float32) if store_rstd else None
+    if residual is not None:
+        residual_dtype = residual.dtype
+    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
+        residual_out = torch.empty_like(
+            x, dtype=residual_dtype if residual_dtype is not None else x.dtype
+        )
+    else:
+        residual_out = None
+    _rmsnorm_fwd(x, weight, out, bias, rstd, residual, residual_out, eps=eps)
+    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
+    if residual_out is None:
+        residual_out = x
+    return out, residual_out, rstd


-def
+def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
     x_f32 = x.float()
-
-
+    if residual is not None:
+        residual_f32 = residual.float()
+        x_f32 += residual_f32
+    out = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w
+    if bias is not None:
+        out = out + bias.float()
+    if residual is None:
+        return out.to(x.dtype)
+    else:
+        return out.to(x.dtype), x_f32.to(residual.dtype)

 def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
     """Reference implementation for RMSNorm backward pass."""
@@ -338,7 +521,6 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
     dw = (dout * x_hat).sum(dim=0)
     return dx.to(x.dtype), dw.to(w.dtype)

-
 class RMSNormBackward(ReductionBase):
     def __init__(self, dtype: cutlass.Numeric, N: int):
         # 2 stages for double buffering when computing mean of x_hat * wdy
@@ -372,11 +554,13 @@ class RMSNormBackward(ReductionBase):
         )
         self.cluster_n = cluster_n

-    def _smem_size_in_bytes(self, tiler_mn, num_warps):
+    def _smem_size_in_bytes(self, tiler_mn, num_warps, do_dtype=None):
+        if do_dtype is None:
+            do_dtype = self.dtype
         return (
-            #
-
-            cute.size_in_bytes(
+            # We need space for X and dO, and multiply by 2 due to double buffering
+            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
+            + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2
             + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
             + self.stage * (cutlass.Int64.width // 8) * 2  # mult 2 as we need 2 mbar per stage
         )
@@ -386,10 +570,13 @@ class RMSNormBackward(ReductionBase):
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
-
+        mdO: cute.Tensor,
+        mdResO: Optional[cute.Tensor],
         mRstd: cute.Tensor,
         mdX: cute.Tensor,
         mdW: cute.Tensor,
+        mdRes: Optional[cute.Tensor],
+        mdB: Optional[cute.Tensor],
         sm_count: Int32,
         stream: cuda.CUstream,
     ):
@@ -398,24 +585,36 @@ class RMSNormBackward(ReductionBase):
             cute.assume(t.stride[0], divby=128 // t.element_type.width),
             t.stride[1],
         )
-        mX,
+        mX, mdO, mdResO, mdX, mdRes = [
             cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
-
+            if const_expr(t is not None)
+            else None
+            for t in (mX, mdO, mdResO, mdX, mdRes)
         ]
         self._set_cluster_n()
-
+        largest_dtype_width = const_expr(
+            max(
+                mX.element_type.width,
+                mdO.element_type.width,
+                mdX.element_type.width,
+                mdResO.element_type.width if mdResO is not None else 0,
+                mdRes.element_type.width if mdRes is not None else 0,
+            )
+        )
+        tiler_mn, tv_layout = self._get_tv_layout(
+            num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+        )
         num_threads = cute.size(tv_layout, mode=[0])
         num_warps = num_threads // cute.arch.WARP_SIZE
-
         mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
         mW = cute.make_tensor(mW.iterator, mW_expanded_layout)

         num_blocks = sm_count
-        self.kernel(mX, mW,
+        self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn).launch(
             grid=[num_blocks, self.cluster_n, 1],
             block=[num_threads, 1, 1],
             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
-            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type),
             stream=stream,
         )

@@ -424,63 +623,85 @@ class RMSNormBackward(ReductionBase):
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
-
+        mdO: cute.Tensor,
+        mdResO: Optional[cute.Tensor],
         mRstd: cute.Tensor,
         mdX: cute.Tensor,
         mdW: cute.Tensor,
+        mdB: Optional[cute.Tensor],
+        mdRes: Optional[cute.Tensor],
         tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx_start, _, _ = cute.arch.block_idx()
         gdim, _, _ = cute.arch.grid_dim()
-        if
+        if const_expr(self.cluster_n > 1):
             cluster_y = cute.arch.block_idx()[1]
         else:
-            cluster_y =
+            cluster_y = const_expr(0)

         shape = mX.shape
         M, N = shape[0], shape[1]
-        is_even_N =
+        is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)

         idX = cute.make_identity_tensor(shape)

         smem = cutlass.utils.SmemAllocator()
         smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2))
         sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
-
+        sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16)
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
             smem, tv_layout, is_persistent=True
         )
-        if
+        if const_expr(mbar_ptr is not None):
             mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
         else:
             mbar_full_ptr, mbar_empty_ptr = None, None

+        num_copy_elems_X = tv_layout.shape[1][0]
+        num_copy_bits_X = mX.element_type.width * num_copy_elems_X
         copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
         )
         copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
         )
-
-
+        num_copy_bits_dO = const_expr(min(128, num_copy_elems_X * mdO.element_type.width))
+        copy_atom_load_dO_async = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), mdO.element_type, num_bits_per_copy=num_copy_bits_dO
         )
+        num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
         copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=
-        )
-        num_bits_per_copy_dX = cutlass.const_expr(
-            min(128, 128 // mX.element_type.width * mdX.element_type.width)
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
         )
+        if const_expr(mdResO is not None):
+            num_copy_bits_dResO = const_expr(min(128, num_copy_elems_X * mdResO.element_type.width))
+            copy_atom_load_dResO = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(),
+                mdResO.element_type,
+                num_bits_per_copy=num_copy_bits_dResO,
+            )
+        num_copy_bits_dX = const_expr(min(128, num_copy_elems_X * mdX.element_type.width))
         copy_atom_store_dX = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=
-        )
-        num_bits_per_copy_dW = cutlass.const_expr(
-            min(128, 128 // mX.element_type.width * mdW.element_type.width)
+            cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_dX
         )
+        num_copy_bits_dW = const_expr(min(128, num_copy_elems_X * mdW.element_type.width))
         copy_atom_store_dW = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_copy_bits_dW
         )
+        if const_expr(mdB is not None):
+            num_copy_bits_dB = const_expr(min(128, num_copy_elems_X * mdB.element_type.width))
+            copy_atom_store_dB = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(), mdB.element_type, num_bits_per_copy=num_copy_bits_dB
+            )
+        if const_expr(mdRes is not None):
+            num_copy_bits_dRes = const_expr(min(128, num_copy_elems_X * mdRes.element_type.width))
+            copy_atom_load_dRes = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(),
+                mdRes.element_type,
+                num_bits_per_copy=num_copy_bits_dRes,
+            )

         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)

@@ -510,21 +731,36 @@ class RMSNormBackward(ReductionBase):
             if not is_even_N
             else None
         )
+        if const_expr(mdB is not None):
+            db_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+            tXpdB = (
+                utils.predicate_k(thr_copy_X.partition_S(db_coord), limit=shape[1])
+                if not is_even_N
+                else None
+            )

         gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
         tXgdW = thr_copy_X.partition_S(gdW)
         # Always compute partial weight gradients in fp32
         tXrdW = cute.make_fragment_like(tXgdW, Float32)

-
-
-
-
+        gdB = cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y)) if const_expr(mdB is not None) else None
+        tXgdB = thr_copy_X.partition_S(gdB) if const_expr(mdB is not None) else None
+        tXrdB = cute.make_fragment_like(tXgdB, Float32) if const_expr(mdB is not None) else None
+
+        gX, gdO, gdResO, gdX, gdRes, cX = [
+            cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None
+            for mT in (mX, mdO, mdResO, mdX, mdRes, idX)
+        ]
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
-
-
+        tXgdO = thr_copy_X.partition_S(gdO)
+        tXsdO = thr_copy_X.partition_D(sdO)
         tXgdX = thr_copy_X.partition_D(gdX)
+        if const_expr(mdResO is not None):
+            tXgdResO = thr_copy_X.partition_S(gdResO)
+        if const_expr(mdRes is not None):
+            tXgdRes = thr_copy_X.partition_D(gdRes)
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
         # This doesn't change across iterations
         tXpX = (
@@ -533,62 +769,48 @@ class RMSNormBackward(ReductionBase):
             else None
         )

-        tXrX,
-            cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX,
+        tXrX, tXrdO, tXrdX = [
+            cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
         ]
+        if const_expr(mdResO is not None):
+            tXrdResO = cute.make_fragment_like(tXgdResO[None, None, None, 0])
+        if const_expr(mdRes is not None):
+            tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
+
+        copy_X = partial(cute.copy, copy_atom_load_X_async, pred=tXpX)
+        copy_dO = partial(cute.copy, copy_atom_load_dO_async, pred=tXpX)

         # Prefetch the first batch
         row = tXcX[None, None, None, bidx_start][0][0]
         if row < M:
             tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
-
-
-
-                tXgX_cur,
-                tXsX[None, None, None, 0],
-                pred=tXpX,
-            )
-            cute.copy(
-                copy_atom_load_X_async,
-                tXgdOut_cur,
-                tXsdOut[None, None, None, 0],
-                pred=tXpX,
-            )
+            tXgdO_cur = utils.coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0]
+            copy_X(tXgX_cur, tXsX[None, None, None, 0])
+            copy_dO(tXgdO_cur, tXsdO[None, None, None, 0])
         elif tiler_mn[0] > 1:
             # Fill with zero, otherwise smem will be uninitialized, and we could read this back
             # later into registers, causing wrong dW.
             utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
-            utils.fill_oob(
+            utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
         cute.arch.cp_async_commit_group()

-        if
+        if const_expr(self.cluster_n > 1):
             cute.arch.cluster_wait()

         threads_per_row = tv_layout.shape[0][0]
         tXrdW.fill(0.0)
+        if const_expr(mdB is not None):
+            tXrdB.fill(0.0)
         stage = Int32(0)
         producer_phase = Int32(1)
         consumer_phase = Int32(0)
         for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
             row = tXcX[None, None, None, bidx][0][0]
-            rstd = cutlass.Float.zero
             if row + gdim * tiler_mn[0] < M:  # Prefetch the next batch
                 tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
-
-
-                ]
-                cute.copy(
-                    copy_atom_load_X_async,
-                    tXgX_cur,
-                    tXsX[None, None, None, stage ^ 1],
-                    pred=tXpX,
-                )
-                cute.copy(
-                    copy_atom_load_X_async,
-                    tXgdOut_cur,
-                    tXsdOut[None, None, None, stage ^ 1],
-                    pred=tXpX,
-                )
+                tXgdO_cur = utils.coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0]
+                copy_X(tXgX_cur, tXsX[None, None, None, stage ^ 1])
+                copy_dO(tXgdO_cur, tXsdO[None, None, None, stage ^ 1])
             elif tiler_mn[0] > 1:
                 utils.fill_oob(
                     tXsX[None, None, None, stage ^ 1],
@@ -596,36 +818,45 @@ class RMSNormBackward(ReductionBase):
                     fill_value=mX.element_type.zero,
                 )
                 utils.fill_oob(
-
+                    tXsdO[None, None, None, stage ^ 1],
                     None,
-                    fill_value=
+                    fill_value=mdO.element_type.zero,
                 )
             cute.arch.cp_async_commit_group()
+            rstd = cutlass.Float.zero
             if row < M or tiler_mn[0] == 1:
                 rstd = mRstd[row]
+            if const_expr(mdResO is not None):
+                tXgdResO_cur = utils.coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
+                if row < M or tiler_mn[0] == 1:
+                    cute.copy(copy_atom_load_dResO, tXgdResO_cur, tXrdResO, pred=tXpX)
+                elif tiler_mn[0] > 1:
+                    tXrdResO.fill(0.0)
            cute.arch.cp_async_wait_group(1)
             cute.autovec_copy(tXsX[None, None, None, stage], tXrX)
             x = tXrX.load().to(cute.Float32)
-            cute.autovec_copy(
-            dout =
+            cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
+            dout = tXrdO.load().to(cute.Float32)
+            if const_expr(mdResO is not None):
+                dout += tXrdResO.load().to(cute.Float32)
             x_hat = x * rstd
             wdy = dout * weight
-            if
+            if const_expr(self.cluster_n > 1):
                 cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
             mean_xhat_wdy = (
-
+                row_reduce(
                     x_hat * wdy,
                     cute.ReductionOp.ADD,
                     threads_per_row,
                     reduction_buffer[None, None, stage],
-                    (mbar_full_ptr + stage if
+                    (mbar_full_ptr + stage if const_expr(self.cluster_n > 1) else None),
                     phase=consumer_phase,
                     init_val=0.0,
                 )
                 / shape[1]
             )

-            if
+            if const_expr(self.cluster_n > 1):
                 # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
                 # Requires adjusting the thread_count when initializing the mbar
                 cute.arch.sync_warp()
@@ -635,28 +866,37 @@ class RMSNormBackward(ReductionBase):
                     mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
                 )

-            if
-                cute.autovec_copy(
-                dout =
+            if const_expr(self.reload_wdy == "smem"):
+                cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
+                dout = tXrdO.load().to(cute.Float32)
+                if const_expr(mdResO is not None):
+                    dout += tXrdResO.load().to(cute.Float32)
                 wdy = dout * weight

             dx = (wdy - x_hat * mean_xhat_wdy) * rstd
-            tXrdX.store(dx.to(
+            tXrdX.store(dx.to(tXrdX.element_type))
             if row < M or tiler_mn[0] == 1:
                 tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
                 cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+            if const_expr(mdRes is not None):
+                tXrdRes.store(dx.to(tXrdRes.element_type))
+                tXgdRes_cur = utils.coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
+                if row < M or tiler_mn[0] == 1:
+                    cute.copy(copy_atom_load_dRes, tXrdRes, tXgdRes_cur, pred=tXpX)
             # Accumulate weight gradients in fp32
             tXrdW.store(tXrdW.load() + dout * x_hat)
+            if const_expr(mdB is not None):
+                tXrdB.store(tXrdB.load() + dout)

             stage ^= 1
             if stage == 0:
                 consumer_phase ^= 1
                 producer_phase ^= 1

-        if
+        if const_expr(self.cluster_n > 1):  # Prevent cluster from exiting early
             cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)

-        if
+        if const_expr(tiler_mn[0] > 1):
             # reduction of dw_partial within the same threadblock
             sdW = cute.make_tensor(
                 cute.recast_ptr(sX.iterator, dtype=cute.Float32),
@@ -669,23 +909,73 @@ class RMSNormBackward(ReductionBase):
             cute.autovec_copy(tXrdW, tXsdW)
             cute.arch.barrier()
             if row == 0:
-                for i in cutlass.range_constexpr(1,
+                for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
                     tXrdW_other = cute.make_fragment_like(tXrdW)
                     tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
                     cute.autovec_copy(tXsdW_other, tXrdW_other)
                     tXrdW.store(tXrdW.load() + tXrdW_other.load())
                 cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
+            cute.arch.barrier()
+            if const_expr(mdB is not None):
+                sdB = cute.make_tensor(
+                    cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+                    cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+                )
+                tXsdB = thr_copy_X.partition_D(sdB)
+                cute.arch.barrier()
+                row = tXcX[None, None, None, 0][0][0]
+                if row > 0:
+                    cute.autovec_copy(tXrdB, tXsdB)
+                cute.arch.barrier()
+                if row == 0:
+                    for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
+                        tXrdB_other = cute.make_fragment_like(tXrdB)
+                        tXsdB_other = cute.make_tensor(tXsdB.iterator + i * sdB.stride[0], tXsdB.layout)
+                        cute.autovec_copy(tXsdB_other, tXrdB_other)
+                        tXrdB.store(tXrdB.load() + tXrdB_other.load())
+                    cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
         else:
             # dw is already in fp32, so we can directly copy to global memory
             cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
+            if const_expr(mdB is not None):
+                cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)


-def
-
-
-
-
-
+def _get_sm_count(N: int, device: torch.device) -> int:
+    # This should be tuned on how many CTAs can be launched on each SM
+    sm_count_multiple = (
+        16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+    )
+    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
+    # By right, if we're using cluster, this should be cluster_count not sm_count.
+    # But for cluster >= 4, due to quantization we would need to query active max cluster.
+    # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
+    # avoid wave quantization.
+    sm_count = (
+        sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
+    )
+
+    return sm_count
+
+
+@torch.library.custom_op(
+    "quack::_rmsnorm_bwd",
+    mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
+    device_types="cuda",
+    # We need to specify the schema manually since we're mutating an optional tensor
+    schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a!) dx, Tensor(a!) dw_partial, Tensor(a!)? db_partial, Tensor? dresidual_out, Tensor(a!)? dresidual) -> ()",
+)
+def _rmsnorm_bwd(
+    x: Tensor,
+    weight: Tensor,
+    dout: Tensor,
+    rstd: Tensor,
+    dx: Tensor,
+    dw_partial: Tensor,
+    db_partial: Optional[Tensor] = None,
+    dresidual_out: Optional[Tensor] = None,
+    dresidual: Optional[Tensor] = None,
+) -> None:
     """RMSNorm backward pass.
     Args:
         x: Input tensor of shape (M, N)
@@ -701,46 +991,39 @@ def _rmsnorm_backward(
     assert weight.dim() == 1, "Weight must be 1D"
     assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
     assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ], "Unsupported dtype"
-
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     assert weight.dtype in [
         torch.float32,
         torch.bfloat16,
         torch.float16,
     ], "Weight must be float32, float16 or bfloat16"
-
-
-
-
+    if dresidual_out is not None:
+        assert dresidual_out.shape == x.shape
+        assert dresidual_out.is_cuda
+        assert dresidual_out.dtype in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+        ], "Residual must be float16, bfloat16, or float32"
+    if dresidual is not None:
+        assert dresidual.shape == x.shape
+        assert dresidual.is_cuda
+        assert dresidual.dtype in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+        ], "Residual must be float16, bfloat16, or float32"
+
+    N = x.size(1)
     device = x.device
-
-    # This should be tuned on how many CTAs can be launched on each SM
-    sm_count_multiple = (
-        16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
-    )
-    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
-    # By right, if we're using cluster, this should be cluster_count not sm_count.
-    # But for cluster >= 4, due to quantization we would need to query active max cluster.
-    # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
-    # avoid wave quantization.
-    sm_count = (
-        sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
-    )
-
-    # Always store partial gradients in fp32 for numerical accuracy
-    dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
-
-    dtype = torch2cute_dtype_map[x.dtype]
-
+    sm_count = dw_partial.shape[0]
     convert_from_dlpack = lambda x: (
         from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
     )
-    x_tensor, dout_tensor, dx_tensor = [
-
+    x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [
+        convert_from_dlpack(t) if t is not None else None
+        for t in (x, dout, dresidual_out, dx, dresidual)
+    ]
     # Handle weight div based on weight dtype
     weight_dtype = torch2cute_dtype_map[weight.dtype]
     weight_tensor = utils.convert_from_dlpack(
@@ -748,74 +1031,143 @@ def _rmsnorm_backward(
     )

     dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+    db_partial_tensor = from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) if db_partial is not None else None
     rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)

     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

-    compile_key = (
-
-
-
+    compile_key = (N, x_tensor.element_type, weight_tensor.element_type, db_partial.dtype if db_partial is not None else None,
+                   dresidual.dtype if dresidual is not None else None,
+                   dresidual_out.dtype if dresidual_out is not None else None)
+    if compile_key not in _rmsnorm_bwd.compile_cache:
+        rmsnorm_backward_op = RMSNormBackward(x_tensor.element_type, N)
+        _rmsnorm_bwd.compile_cache[compile_key] = cute.compile(
             rmsnorm_backward_op,
             x_tensor,
             weight_tensor,
             dout_tensor,
+            dres_out_tensor,
             rstd_tensor,
             dx_tensor,
             dw_partial_tensor,
+            dres_tensor,
+            db_partial_tensor,
             sm_count,
             current_stream,
         )

-
+    _rmsnorm_bwd.compile_cache[compile_key](
         x_tensor,
         weight_tensor,
         dout_tensor,
+        dres_out_tensor,
         rstd_tensor,
         dx_tensor,
         dw_partial_tensor,
+        dres_tensor,
+        db_partial_tensor,
         sm_count,
         current_stream,
     )
-    # we have summed the partial gradients in fp32, now we convert back to the weight dtype
-    dw = dw_partial.sum(dim=0).to(weight.dtype)
-    return dx, dw


-
+_rmsnorm_bwd.compile_cache = {}
+
+
+def rmsnorm_bwd(
+    x: Tensor,
+    weight: Tensor,
+    dout: Tensor,
+    rstd: Tensor,
+    dresidual_out: Optional[Tensor] = None,  # grad wrt residual_out
+    has_bias: bool = False,
+) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+    device = x.device
+    N = x.size(1)
+    sm_count = _get_sm_count(N, device)
+    dx = torch.empty_like(x)
+
+    if dresidual_out is not None and dresidual_out.dtype != dx.dtype:
+        dresidual = torch.empty_like(x, dtype=dresidual_out.dtype)
+    else:
+        dresidual = None
+    # Always store partial gradients in fp32 for numerical accuracy
+    dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+    db_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None
+    _rmsnorm_bwd(x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual)
+    # we have summed the partial gradients in fp32, now we convert back to the weight dtype
+    dw = dw_partial.sum(dim=0).to(weight.dtype)
+    db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None
+    # dresidual is the same as dx in this case
+    if dresidual_out is not None and dresidual_out.dtype == dx.dtype:
+        dresidual = dx
+    return dx, dw, db, dresidual


 class RMSNormFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, weight, eps):
-
-
+    def forward(ctx, x, weight, bias=None, residual=None, out_dtype=None, residual_dtype=None, eps=1e-6, prenorm=False):
+        x_shape_og = x.shape
         # Flatten input
-        x = x.
-
-
-        ctx.
+        x = x.reshape(-1, x.shape[-1])
+        if residual is not None:
+            residual = residual.reshape(-1, residual.shape[-1])
+        need_grad = any(ctx.needs_input_grad[:3])
+        out, residual_out, rstd = rmsnorm_fwd(
+            x,
+            weight,
+            bias=bias,
+            residual=residual,
+            out_dtype=out_dtype,
+            residual_dtype=residual_dtype,
+            eps=eps,
+            store_rstd=need_grad,
+        )
+        ctx.save_for_backward(x if residual is None else residual_out, weight, rstd)
+        ctx.has_bias = bias is not None
         ctx.eps = eps
-        ctx.
-
-
+        ctx.x_shape_og = x_shape_og
+        ctx.residual_dtype = residual.dtype if residual is not None else None
+        ctx.prenorm = prenorm
+        if residual_out is None or prenorm == False:
+            return out.reshape(x_shape_og)
+        else:
+            return out.reshape(x_shape_og), residual_out.reshape(x_shape_og)

     @staticmethod
-    def backward(ctx, dout):
+    def backward(ctx, dout, *args):
         x, weight, rstd = ctx.saved_tensors
-
+        has_bias = ctx.has_bias
+        if ctx.prenorm and ctx.residual_dtype is not None:
+            dresidual_out = args[0]
+            dresidual_out = dresidual_out.reshape(-1, dresidual_out.shape[-1])
+        else:
+            dresidual_out = None
+        x_shape_og = ctx.x_shape_og
         # Reshape dout to match the flattened shape used in forward
         dout = dout.view(-1, dout.shape[-1])
-        dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
-        dx = dx.view(x_shape_start)
-        # dx is returned for input gradient,
-        # dw is returned for weight gradient,
-        # None for eps gradient
-        return dx, dw, None

+        dx, dw, db, dresidual = rmsnorm_bwd(x, weight, dout, rstd, dresidual_out, has_bias)
+        dx = dx.view(x_shape_og)
+        if dresidual_out is not None:
+            dresidual_out = dresidual_out.reshape(x_shape_og)
+        if dresidual is not None:
+            dresidual = dresidual.reshape(x_shape_og)
+
+        return dx, dw, db, dresidual, *([None] * 4)

-
-
+
+def rmsnorm(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    residual: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    residual_dtype: Optional[torch.dtype] = None,
+    eps: float = 1e-6,
+    prenorm: bool = False,
+) -> Tensor:
+    """RMSNorm with automatic differentiation support.

     Args:
         x: Input tensor of shape (M, N)
@@ -825,7 +1177,7 @@ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.T
     Returns:
         Normalized output tensor of same shape as x
     """
-    return RMSNormFunction.apply(x, weight, eps)
+    return RMSNormFunction.apply(x, weight, bias, residual, out_dtype, residual_dtype, eps, prenorm)


 class QuackRMSNorm(torch.nn.Module):
@@ -848,17 +1200,17 @@ class QuackRMSNorm(torch.nn.Module):
         self.weight = torch.nn.Parameter(torch.ones(dim))
         self.eps = eps

-    def forward(self, x:
+    def forward(self, x: Tensor) -> Tensor:
         """Apply RMSNorm to the input tensor.

         Args:
-            x (
+            x (Tensor): Input tensor

         Returns:
-
+            Tensor: Normalized tensor
         """
-        return rmsnorm(x, self.weight, self.eps)
+        return rmsnorm(x, self.weight, eps=self.eps)

     def reset_parameters(self):
         """Reset the weight parameter to ones."""
-        torch.nn.init.ones_(self.weight)
+        torch.nn.init.ones_(self.weight)