quack-kernels 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/activation.py +16 -25
- quack/autotuner.py +64 -5
- quack/cross_entropy.py +6 -10
- quack/cute_dsl_utils.py +6 -7
- quack/dense_gemm_sm90.py +582 -287
- quack/gemm_act_sm90.py +70 -29
- quack/gemm_dact_sm90.py +43 -10
- quack/gemm_interface.py +453 -130
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +443 -419
- quack/gemm_wrapper_utils.py +179 -22
- quack/layernorm.py +1 -1
- quack/reduce.py +6 -7
- quack/rmsnorm.py +126 -158
- quack/softmax.py +1 -1
- quack/tile_scheduler.py +37 -49
- quack/utils.py +61 -71
- quack/varlen_utils.py +1 -6
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/METADATA +3 -3
- quack_kernels-0.2.2.dist-info/RECORD +37 -0
- quack_kernels-0.2.0.dist-info/RECORD +0 -37
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/top_level.txt +0 -0
quack/rmsnorm.py
CHANGED
@@ -19,6 +19,7 @@ from quack.reduce import row_reduce
 from quack.reduction_base import ReductionBase
 from quack.cute_dsl_utils import torch2cute_dtype_map
 
+
 class RMSNorm(ReductionBase):
     def __init__(self, dtype: cutlass.Numeric, N: int):
         super().__init__(dtype, N, stage=1)
@@ -93,7 +94,7 @@
     def __call__(
         self,
        mX: cute.Tensor,
-        mW: cute.Tensor,
+        mW: Optional[cute.Tensor],
        mB: Optional[cute.Tensor],
        mRes: Optional[cute.Tensor],
        mO: cute.Tensor,
@@ -129,10 +130,15 @@
         )
         num_threads = cute.size(tv_layout, mode=[0])
         num_warps = num_threads // cute.arch.WARP_SIZE
-        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
-        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+        if const_expr(mW is not None):
+            mW_expanded_layout = cute.prepend(
+                mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+            )
+            mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
         if const_expr(mB is not None):
-            mB_expanded_layout = cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+            mB_expanded_layout = cute.prepend(
+                mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+            )
             mB = cute.make_tensor(mB.iterator, mB_expanded_layout)
         if const_expr(mRstd is not None):
             mRstd_expanded_layout = cute.append(
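Note: both hunks above broadcast a 1-D (N,)-shaped parameter across the tiler_mn[0] rows of a tile by prepending a stride-0 mode to its layout, so every row reads the same weight/bias elements without materializing copies. A stride-0 broadcast in plain torch, for illustration only (names here are not from the package):

    import torch

    w = torch.arange(8)                        # a (N,) parameter, N = 8
    rows = 4                                   # stands in for tiler_mn[0]
    w_bcast = w.as_strided((rows, 8), (0, 1))  # prepend a size-4, stride-0 mode
    assert (w_bcast == w.unsqueeze(0).expand(rows, 8)).all()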
@@ -155,7 +161,7 @@
     def kernel(
         self,
        mX: cute.Tensor,
-        mW: cute.Tensor,
+        mW: Optional[cute.Tensor],
        mB: Optional[cute.Tensor],
        mRes: Optional[cute.Tensor],
        mO: cute.Tensor,
@@ -201,12 +207,10 @@
             for mT in (mX, mRes, mO, mResO)
         ]
         cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
-        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
-        gB = (
-            cute.local_tile(mB, tiler_mn, (0, cluster_y))
-            if const_expr(mB is not None)
-            else None
-        )
+        gW, gB = [
+            cute.local_tile(mT, tiler_mn, (0, cluster_y)) if const_expr(mT is not None) else None
+            for mT in (mW, mB)
+        ]
         gRstd = (
             cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
             if const_expr(mRstd is not None)
@@ -215,47 +219,14 @@
 
         # declare the atoms which will be used later for memory copy
         num_copy_elems_X = tv_layout.shape[1][0]
-        num_copy_bits_X = const_expr(min(128, num_copy_elems_X * mX.element_type.width))
-        copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
+        copy_atom_load_X_async = utils.get_copy_atom(
+            mX.element_type, num_copy_elems_X, is_async=True
         )
-        copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
-        copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
-        )
-        num_bits_per_copy_B = cutlass.const_expr(
-            min(128, num_copy_elems_X * mB.element_type.width)
-        ) if const_expr(mB is not None) else 0
-        copy_atom_load_B = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
-        ) if const_expr(mB is not None) else None
-        if const_expr(mRes is not None):
-            num_copy_bits_Res = const_expr(min(128, num_copy_elems_X * mRes.element_type.width))
-            copy_atom_load_Res_async = cute.make_copy_atom(
-                cute.nvgpu.cpasync.CopyG2SOp(),
-                mRes.element_type,
-                num_bits_per_copy=num_copy_bits_Res,
-            )
-        num_copy_bits_O = const_expr(min(128, num_copy_elems_X * mO.element_type.width))
-        copy_atom_store_O = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_copy_bits_O
-        )
-        if const_expr(mResO is not None):
-            num_copy_bits_ResO = const_expr(min(128, num_copy_elems_X * mResO.element_type.width))
-            copy_atom_store_ResO = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(),
-                mResO.element_type,
-                num_bits_per_copy=num_copy_bits_ResO,
-            )
-
         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
             tidx
         )
 
-        tXgW = thr_copy_X.partition_S(gW)
+        tXgW = thr_copy_X.partition_S(gW) if const_expr(mW is not None) else None
         tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
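Note: 0.2.2 collapses the per-operand make_copy_atom boilerplate into a single utils.get_copy_atom(element_type, num_copy_elems, is_async=...) helper. Its body is not part of this diff; judging from the removed code, a plausible sketch (an assumption, not the packaged implementation) is:

    import cutlass
    import cutlass.cute as cute

    def get_copy_atom(element_type, num_copy_elems, is_async=False):
        # vectorize up to 128 bits per copy, as the removed inline code did
        num_bits = cutlass.const_expr(min(128, num_copy_elems * element_type.width))
        op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
        return cute.make_copy_atom(op, element_type, num_bits_per_copy=num_bits)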
@@ -269,8 +240,9 @@
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
 
         # allocate fragments for gmem->rmem
-        tXrW = cute.make_fragment_like(tXgW)
-        tXrW.fill(0.0)
+        tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None
+        if const_expr(mW is not None):
+            tXrW.fill(0.0)
         tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
         tXrX, tXrO = [cute.make_fragment_like(t) for t in (tXgX, tXgO)]
         if const_expr(mRes is not None):
@@ -283,17 +255,21 @@
         tXpX = (
             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
         )
+        # Each copy will use the same number of elements as X and same predicate
+        copy = partial(utils.copy, pred=tXpX, num_copy_elems=num_copy_elems_X)
+
         row = tXcX[0][0]
         if row < shape[0]:
-            cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+            copy(tXgX, tXsX, is_async=True)
             if const_expr(mRes is not None):
-                cute.copy(copy_atom_load_Res_async, tXgRes, tXsRes, pred=tXpX)
+                copy(tXgRes, tXsRes, is_async=True)
             cute.arch.cp_async_commit_group()
 
         if const_expr(not delay_w_load):
-            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
+            if const_expr(mW is not None):
+                copy(tXgW, tXrW)
             if const_expr(mB is not None):
-                cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
+                copy(tXgB, tXrB)
 
         cute.arch.cp_async_wait_group(0)
         cute.autovec_copy(tXsX, tXrX)
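Note: copy is a functools.partial that freezes the predicate and vector width once, so every call site below shrinks to copy(src, dst) plus an optional is_async=True. The idiom in miniature (copy_impl is a stand-in, not utils.copy itself):

    from functools import partial

    def copy_impl(src, dst, pred=None, num_copy_elems=1, is_async=False):
        # stand-in that just records how it was called
        return {"pred": pred, "n": num_copy_elems, "is_async": is_async}

    copy = partial(copy_impl, pred="tXpX", num_copy_elems=8)
    assert copy("gmem", "smem", is_async=True)["n"] == 8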
@@ -305,7 +281,7 @@
             tXrResO = cute.make_fragment_like(tXgResO)
             tXrResO.store(x.to(tXrResO.element_type))
             if row < shape[0]:
-                cute.copy(copy_atom_store_ResO, tXrResO, tXgResO, pred=tXpX)
+                copy(tXrResO, tXgResO)
 
         threads_per_row = tv_layout.shape[0][0]
         sum_sq_x = row_reduce(
@@ -317,7 +293,7 @@
             init_val=0.0,
             hook_fn=(cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None),
         )
-        rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps)
+        rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True)
         if const_expr(mRstd is not None):
             # Only the thread corresponding to column 0 writes out the rstd to gmem
             if (
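Note: sum_sq_x / shape[1] is the row mean of x², so rstd is 1/sqrt(mean(x²) + eps), the reciprocal RMS that gives the layer its name. The identity, checked with torch (illustrative):

    import torch

    x = torch.randn(512)
    eps = 1e-6
    rstd = torch.rsqrt(x.square().mean() + eps)   # what the kernel computes per row
    assert torch.allclose(rstd * (x.square().mean() + eps).sqrt(), torch.tensor(1.0))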
@@ -327,27 +303,28 @@
         ):
             tXrRstd[0] = rstd
         if const_expr(delay_w_load):
-            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
+            if const_expr(mW is not None):
+                copy(tXgW, tXrW)
             if const_expr(mB is not None):
-                cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
+                copy(tXgB, tXrB)
         if const_expr(reload_from == "smem" or reload_from == "gmem"):
             if const_expr(reload_from == "smem"):
                 cute.autovec_copy(tXsX, tXrX)
             else:
-                cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+                copy(tXgX, tXrX)
         x = tXrX.load().to(cute.Float32)
         if const_expr(mRes is not None):
             cute.autovec_copy(tXsRes, tXrRes)
             x += tXrRes.load().to(cute.Float32)
         x_hat = x * rstd
-        w = tXrW.load().to(cute.Float32)
-        y = x_hat * w
+        y = x_hat
+        if const_expr(mW is not None):
+            y *= tXrW.load().to(cute.Float32)
         if const_expr(mB is not None):
-            b = tXrB.load().to(cute.Float32)
-            y = y + b
+            y += tXrB.load().to(cute.Float32)
         tXrO.store(y.to(tXrO.element_type))
         if row < shape[0]:
-            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpX)
+            copy(tXrO, tXgO)
 
 
 @torch.library.custom_op(
@@ -355,11 +332,11 @@
     mutates_args=("out", "rstd", "residual_out"),
     device_types="cuda",
     # We need to specify the schema manually since we're mutating an optional tensor
-    schema="(Tensor x, Tensor weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
+    schema="(Tensor x, Tensor? weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
 )
 def _rmsnorm_fwd(
     x: Tensor,
-    weight: Tensor,
+    weight: Optional[Tensor],
     out: Tensor,
     bias: Optional[Tensor] = None,
     rstd: Optional[Tensor] = None,
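Note: the schema change from Tensor weight to Tensor? weight is what lets callers pass None through the custom-op boundary; in torch.library schema syntax, Tensor? marks an optional tensor and Tensor(a2!) marks an argument mutated in place. The same pattern on a minimal standalone op (demo::scale_out is hypothetical, not part of quack):

    import torch
    from typing import Optional
    from torch import Tensor

    @torch.library.custom_op(
        "demo::scale_out",
        mutates_args=("out",),
        device_types="cuda",
        schema="(Tensor x, Tensor? weight, Tensor(a2!) out) -> ()",
    )
    def scale_out(x: Tensor, weight: Optional[Tensor], out: Tensor) -> None:
        # mutates `out` in place, matching the (a2!) annotation
        out.copy_(x * weight if weight is not None else x)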
@@ -370,21 +347,23 @@
     """RMSNorm forward pass.
     Args:
         x: Input tensor of shape (M, N)
-        weight: Weight tensor of shape (N,)
+        weight: Optional weight tensor of shape (N,)
         eps: Small value for numerical stability
     Returns:
         Normalized output tensor of same shape as x
     """
     assert x.dim() == 2, "Input must be 2D"
-    assert weight.dim() == 1, "Weight must be 1D"
-    assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
-    assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+    assert x.is_cuda, "Input tensor must be on CUDA device"
     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
-    assert weight.dtype in [
-        torch.float32,
-        torch.bfloat16,
-        torch.float16,
-    ], "Weight must be float32, float16 or bfloat16"
+    if weight is not None:
+        assert weight.dim() == 1, "Weight must be 1D"
+        assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+        assert weight.is_cuda, "Weight tensor must be on CUDA device"
+        assert weight.dtype in [
+            torch.float32,
+            torch.bfloat16,
+            torch.float16,
+        ], "Weight must be float32, float16 or bfloat16"
     if residual is not None:
         assert residual.shape == x.shape
         assert residual.is_cuda
@@ -397,11 +376,6 @@ def _rmsnorm_fwd(
     _, N = x.shape
     device = x.device
     dtype = torch2cute_dtype_map[x.dtype]
-    # convert_from_dlpack = lambda x: (
-    #     from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
-    #         mode=0, divisibility=128 // dtype.width
-    #     )
-    # )
     convert_from_dlpack = lambda x: (
         from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
     )
@@ -409,10 +383,13 @@
         convert_from_dlpack(t) if t is not None else None for t in (x, residual, out, residual_out)
     ]
     # handle weight divisibility based on weight dtype
-    weight_dtype = torch2cute_dtype_map[weight.dtype]
-    weight_tensor = utils.convert_from_dlpack(
-        weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
-    )
+    if weight is not None:
+        weight_dtype = torch2cute_dtype_map[weight.dtype]
+        weight_tensor = utils.convert_from_dlpack(
+            weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
+        )
+    else:
+        weight_tensor = None
     if bias is not None:
         bias_dtype = torch2cute_dtype_map[bias.dtype]
         bias_tensor = utils.convert_from_dlpack(
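Note: divisibility=128 // weight_dtype.width aligns the weight's leading dimension to one 128-bit vectorized copy, i.e. 8 elements for fp16/bf16 and 4 for fp32:

    widths = {"fp16": 16, "bf16": 16, "fp32": 32}    # bits per element
    print({k: 128 // w for k, w in widths.items()})  # {'fp16': 8, 'bf16': 8, 'fp32': 4}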
@@ -430,7 +407,7 @@
         N,
         dtype,
         res_tensor.element_type if residual is not None else None,
-        weight_tensor.element_type,
+        weight_tensor.element_type if weight is not None else None,
         bias_tensor.element_type if bias is not None else None,
         res_out_tensor.element_type if residual_out is not None else None,
         rstd is not None,
@@ -467,7 +444,7 @@ _rmsnorm_fwd.compile_cache = {}
 
 def rmsnorm_fwd(
     x: Tensor,
-    weight: Tensor,
+    weight: Optional[Tensor] = None,
     bias: Optional[Tensor] = None,
     residual: Optional[Tensor] = None,
     out_dtype: Optional[torch.dtype] = None,
@@ -496,12 +473,13 @@
     return out, residual_out, rstd
 
 
-def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
+def rmsnorm_ref(x, w=None, bias=None, residual=None, eps=1e-6):
     x_f32 = x.float()
     if residual is not None:
         residual_f32 = residual.float()
         x_f32 += residual_f32
-    out = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w
+    x_norm = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps))
+    out = x_norm * w if w is not None else x_norm
     if bias is not None:
         out = out + bias.float()
     if residual is None:
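Note: with w=None the reference now reduces to pure normalization, so each output row has (approximately) unit RMS. A quick sanity check against the function above (illustrative):

    import torch

    x = torch.randn(4, 512)
    out = rmsnorm_ref(x)  # w=None: no scale, no bias, no residual
    rms = out.float().square().mean(dim=-1).sqrt()
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)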
@@ -509,6 +487,7 @@ def rmsnorm_ref(x, w=None, bias=None, residual=None, eps=1e-6):
     else:
         return out.to(x.dtype), x_f32.to(residual.dtype)
 
+
 def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
     """Reference implementation for RMSNorm backward pass."""
     x_f32 = x.float()
@@ -521,6 +500,7 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
     dw = (dout * x_hat).sum(dim=0)
     return dx.to(x.dtype), dw.to(w.dtype)
 
+
 class RMSNormBackward(ReductionBase):
     def __init__(self, dtype: cutlass.Numeric, N: int):
         # 2 stages for double buffering when computing mean of x_hat * wdy
@@ -606,8 +586,11 @@
         )
         num_threads = cute.size(tv_layout, mode=[0])
         num_warps = num_threads // cute.arch.WARP_SIZE
-        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
-        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+        if const_expr(mW is not None):
+            mW_expanded_layout = cute.prepend(
+                mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+            )
+            mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
 
         num_blocks = sm_count
         self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn).launch(
@@ -660,50 +643,10 @@
         mbar_full_ptr, mbar_empty_ptr = None, None
 
         num_copy_elems_X = tv_layout.shape[1][0]
-        num_copy_bits_X = const_expr(min(128, num_copy_elems_X * mX.element_type.width))
-        copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        num_copy_bits_dO = const_expr(min(128, num_copy_elems_X * mdO.element_type.width))
-        copy_atom_load_dO_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mdO.element_type, num_bits_per_copy=num_copy_bits_dO
-        )
-        num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
-        copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
-        )
-        if const_expr(mdResO is not None):
-            num_copy_bits_dResO = const_expr(min(128, num_copy_elems_X * mdResO.element_type.width))
-            copy_atom_load_dResO = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(),
-                mdResO.element_type,
-                num_bits_per_copy=num_copy_bits_dResO,
-            )
-        num_copy_bits_dX = const_expr(min(128, num_copy_elems_X * mdX.element_type.width))
-        copy_atom_store_dX = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_dX
-        )
-        num_copy_bits_dW = const_expr(min(128, num_copy_elems_X * mdW.element_type.width))
-        copy_atom_store_dW = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_copy_bits_dW
-        )
-        if const_expr(mdB is not None):
-            num_copy_bits_dB = const_expr(min(128, num_copy_elems_X * mdB.element_type.width))
-            copy_atom_store_dB = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(), mdB.element_type, num_bits_per_copy=num_copy_bits_dB
-            )
-        if const_expr(mdRes is not None):
-            num_copy_bits_dRes = const_expr(min(128, num_copy_elems_X * mdRes.element_type.width))
-            copy_atom_load_dRes = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(),
-                mdRes.element_type,
-                num_bits_per_copy=num_copy_bits_dRes,
-            )
-
+        copy_atom_load_X = utils.get_copy_atom(mX.element_type, num_copy_elems_X, is_async=False)
         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        # Each copy will use the same number of elements as X
+        copy = partial(utils.copy, num_copy_elems=num_copy_elems_X)
 
         gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
         tXgW = thr_copy_X.partition_S(gW)
@@ -718,7 +661,7 @@
             if not is_even_N
             else None
         )
-        cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
+        copy(tXgW, tXrW, pred=tXpW)
         weight = tXrW.load().to(cute.Float32)
 
         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
@@ -744,7 +687,11 @@
         # Always compute partial weight gradients in fp32
         tXrdW = cute.make_fragment_like(tXgdW, Float32)
 
-        gdB = cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y)) if const_expr(mdB is not None) else None
+        gdB = (
+            cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y))
+            if const_expr(mdB is not None)
+            else None
+        )
         tXgdB = thr_copy_X.partition_S(gdB) if const_expr(mdB is not None) else None
         tXrdB = cute.make_fragment_like(tXgdB, Float32) if const_expr(mdB is not None) else None
 
@@ -772,21 +719,20 @@
         tXrX, tXrdO, tXrdX = [
             cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
         ]
+        tXrdResO = None
         if const_expr(mdResO is not None):
             tXrdResO = cute.make_fragment_like(tXgdResO[None, None, None, 0])
+        tXrdRes = None
         if const_expr(mdRes is not None):
             tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
 
-        copy_X = partial(cute.copy, copy_atom_load_X_async, pred=tXpX)
-        copy_dO = partial(cute.copy, copy_atom_load_dO_async, pred=tXpX)
-
         # Prefetch the first batch
         row = tXcX[None, None, None, bidx_start][0][0]
         if row < M:
             tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
             tXgdO_cur = utils.coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0]
-            copy_X(tXgX_cur, tXsX[None, None, None, 0])
-            copy_dO(tXgdO_cur, tXsdO[None, None, None, 0])
+            copy(tXgX_cur, tXsX[None, None, None, 0], pred=tXpX, is_async=True)
+            copy(tXgdO_cur, tXsdO[None, None, None, 0], pred=tXpX, is_async=True)
         elif tiler_mn[0] > 1:
             # Fill with zero, otherwise smem will be uninitialized, and we could read this back
             # later into registers, causing wrong dW.
@@ -809,8 +755,8 @@
             if row + gdim * tiler_mn[0] < M:  # Prefetch the next batch
                 tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
                 tXgdO_cur = utils.coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0]
-                copy_X(tXgX_cur, tXsX[None, None, None, stage ^ 1])
-                copy_dO(tXgdO_cur, tXsdO[None, None, None, stage ^ 1])
+                copy(tXgX_cur, tXsX[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
+                copy(tXgdO_cur, tXsdO[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
             elif tiler_mn[0] > 1:
                 utils.fill_oob(
                     tXsX[None, None, None, stage ^ 1],
|
|
|
829
775
|
if const_expr(mdResO is not None):
|
|
830
776
|
tXgdResO_cur = utils.coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
|
|
831
777
|
if row < M or tiler_mn[0] == 1:
|
|
832
|
-
|
|
778
|
+
copy(tXgdResO_cur, tXrdResO, pred=tXpX)
|
|
833
779
|
elif tiler_mn[0] > 1:
|
|
834
780
|
tXrdResO.fill(0.0)
|
|
835
781
|
cute.arch.cp_async_wait_group(1)
|
|
@@ -877,12 +823,12 @@
             tXrdX.store(dx.to(tXrdX.element_type))
             if row < M or tiler_mn[0] == 1:
                 tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
-                cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+                copy(tXrdX, tXgdX_cur, pred=tXpX)
             if const_expr(mdRes is not None):
                 tXrdRes.store(dx.to(tXrdRes.element_type))
                 tXgdRes_cur = utils.coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
                 if row < M or tiler_mn[0] == 1:
-                    cute.copy(copy_atom_load_dRes, tXrdRes, tXgdRes_cur, pred=tXpX)
+                    copy(tXrdRes, tXgdRes_cur, pred=tXpX)
             # Accumulate weight gradients in fp32
             tXrdW.store(tXrdW.load() + dout * x_hat)
             if const_expr(mdB is not None):
@@ -914,7 +860,7 @@
                 tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
                 cute.autovec_copy(tXsdW_other, tXrdW_other)
                 tXrdW.store(tXrdW.load() + tXrdW_other.load())
-            cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
+            copy(tXrdW, tXgdW, pred=tXpdW)
             cute.arch.barrier()
             if const_expr(mdB is not None):
                 sdB = cute.make_tensor(
@@ -930,15 +876,17 @@
             if row == 0:
                 for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
                     tXrdB_other = cute.make_fragment_like(tXrdB)
-                    tXsdB_other = cute.make_tensor(tXsdB.iterator + i * sdB.stride[0], tXsdB.layout)
+                    tXsdB_other = cute.make_tensor(
+                        tXsdB.iterator + i * sdB.stride[0], tXsdB.layout
+                    )
                     cute.autovec_copy(tXsdB_other, tXrdB_other)
                     tXrdB.store(tXrdB.load() + tXrdB_other.load())
-                cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
+                copy(tXrdB, tXgdB, pred=tXpdB)
         else:
             # dw is already in fp32, so we can directly copy to global memory
-            cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
+            copy(tXrdW, tXgdW, pred=tXpdW)
             if const_expr(mdB is not None):
-                cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
+                copy(tXrdB, tXgdB, pred=tXpdB)
 
 
 def _get_sm_count(N: int, device: torch.device) -> int:
@@ -963,7 +911,7 @@
     mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
    device_types="cuda",
    # We need to specify the schema manually since we're mutating an optional tensor
-    schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(…
+    schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!) dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual) -> ()",
 )
 def _rmsnorm_bwd(
     x: Tensor,
@@ -1031,14 +979,23 @@
     )
 
     dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
-    db_partial_tensor = from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) if db_partial is not None else None
+    db_partial_tensor = (
+        from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+        if db_partial is not None
+        else None
+    )
     rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
 
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
 
-    compile_key = (N, x_tensor.element_type, weight_tensor.element_type, db_partial.dtype if db_partial is not None else None,
+    compile_key = (
+        N,
+        x_tensor.element_type,
+        weight_tensor.element_type,
+        db_partial.dtype if db_partial is not None else None,
         dresidual.dtype if dresidual is not None else None,
-        dresidual_out.dtype if dresidual_out is not None else None)
+        dresidual_out.dtype if dresidual_out is not None else None,
+    )
     if compile_key not in _rmsnorm_bwd.compile_cache:
         rmsnorm_backward_op = RMSNormBackward(x_tensor.element_type, N)
         _rmsnorm_bwd.compile_cache[compile_key] = cute.compile(
@@ -1106,7 +1063,17 @@
 
 class RMSNormFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, weight, bias=None, residual=None, out_dtype=None, residual_dtype=None, eps=1e-6, prenorm=False):
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias=None,
+        residual=None,
+        out_dtype=None,
+        residual_dtype=None,
+        eps=1e-6,
+        prenorm=False,
+    ):
         x_shape_og = x.shape
         # Flatten input
         x = x.reshape(-1, x.shape[-1])
@@ -1129,7 +1096,7 @@
         ctx.x_shape_og = x_shape_og
         ctx.residual_dtype = residual.dtype if residual is not None else None
         ctx.prenorm = prenorm
-        if residual_out is None or prenorm…
+        if residual_out is None or not prenorm:
             return out.reshape(x_shape_og)
         else:
             return out.reshape(x_shape_og), residual_out.reshape(x_shape_og)
@@ -1137,6 +1104,7 @@
     @staticmethod
     def backward(ctx, dout, *args):
         x, weight, rstd = ctx.saved_tensors
+        assert weight is not None, "RMSNorm backward doesn't support weight=None yet"
         has_bias = ctx.has_bias
         if ctx.prenorm and ctx.residual_dtype is not None:
             dresidual_out = args[0]
@@ -1159,7 +1127,7 @@
 
 def rmsnorm(
     x: Tensor,
-    weight: Tensor,
+    weight: Optional[Tensor] = None,
     bias: Optional[Tensor] = None,
     residual: Optional[Tensor] = None,
     out_dtype: Optional[torch.dtype] = None,
@@ -1171,7 +1139,7 @@
 
     Args:
         x: Input tensor of shape (M, N)
-        weight: Weight tensor of shape (N,)
+        weight: Optional weight tensor of shape (N,)
         eps: Small value for numerical stability
 
     Returns:
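Note: with weight now optional, the public entry point can be called with just the input. A usage sketch based on the signature above (behavior inferred from this diff, requires a CUDA device):

    import torch
    from quack.rmsnorm import rmsnorm

    x = torch.randn(32, 4096, device="cuda", dtype=torch.bfloat16)
    y = rmsnorm(x)                 # weight-free: pure normalization
    w = torch.ones(4096, device="cuda")
    y_w = rmsnorm(x, w, eps=1e-6)  # standard weighted RMSNorm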
@@ -1213,4 +1181,4 @@ class QuackRMSNorm(torch.nn.Module):
 
     def reset_parameters(self):
         """Reset the weight parameter to ones."""
-        torch.nn.init.ones_(self.weight)
\ No newline at end of file
+        torch.nn.init.ones_(self.weight)
quack/softmax.py
CHANGED
@@ -159,7 +159,7 @@ class Softmax(ReductionBase):
             hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
         )
         log2_e = math.log2(math.e)
-        exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e))
+        exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
         denom = row_reduce(
             exp_x,
             cute.ReductionOp.ADD,