quack-kernels 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/autotuner.py +64 -5
- quack/cute_dsl_utils.py +6 -7
- quack/dense_gemm_sm90.py +582 -287
- quack/gemm_act_sm90.py +70 -29
- quack/gemm_dact_sm90.py +43 -10
- quack/gemm_interface.py +453 -130
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +443 -419
- quack/gemm_wrapper_utils.py +179 -22
- quack/rmsnorm.py +83 -149
- quack/tile_scheduler.py +34 -47
- quack/utils.py +61 -8
- quack/varlen_utils.py +1 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/METADATA +2 -2
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/RECORD +18 -18
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/top_level.txt +0 -0
quack/rmsnorm.py
CHANGED
@@ -94,7 +94,7 @@ class RMSNorm(ReductionBase):
     def __call__(
         self,
         mX: cute.Tensor,
-        mW: cute.Tensor,
+        mW: Optional[cute.Tensor],
         mB: Optional[cute.Tensor],
         mRes: Optional[cute.Tensor],
         mO: cute.Tensor,
@@ -130,8 +130,11 @@ class RMSNorm(ReductionBase):
         )
         num_threads = cute.size(tv_layout, mode=[0])
         num_warps = num_threads // cute.arch.WARP_SIZE
-        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
-        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+        if const_expr(mW is not None):
+            mW_expanded_layout = cute.prepend(
+                mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+            )
+            mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
         if const_expr(mB is not None):
             mB_expanded_layout = cute.prepend(
                 mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
@@ -158,7 +161,7 @@ class RMSNorm(ReductionBase):
     def kernel(
         self,
         mX: cute.Tensor,
-        mW: cute.Tensor,
+        mW: Optional[cute.Tensor],
         mB: Optional[cute.Tensor],
         mRes: Optional[cute.Tensor],
         mO: cute.Tensor,
@@ -204,8 +207,10 @@ class RMSNorm(ReductionBase):
             for mT in (mX, mRes, mO, mResO)
         ]
         cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
-        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
-        gB = cute.local_tile(mB, tiler_mn, (0, cluster_y)) if const_expr(mB is not None) else None
+        gW, gB = [
+            cute.local_tile(mT, tiler_mn, (0, cluster_y)) if const_expr(mT is not None) else None
+            for mT in (mW, mB)
+        ]
         gRstd = (
             cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
             if const_expr(mRstd is not None)
@@ -214,53 +219,14 @@ class RMSNorm(ReductionBase):
 
         # declare the atoms which will be used later for memory copy
         num_copy_elems_X = tv_layout.shape[1][0]
-        num_copy_bits_X = const_expr(min(128, num_copy_elems_X * mX.element_type.width))
-        copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
-        copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
-        )
-        num_bits_per_copy_B = (
-            cutlass.const_expr(min(128, num_copy_elems_X * mB.element_type.width))
-            if const_expr(mB is not None)
-            else 0
-        )
-        copy_atom_load_B = (
-            cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
-            )
-            if const_expr(mB is not None)
-            else None
-        )
-        if const_expr(mRes is not None):
-            num_copy_bits_Res = const_expr(min(128, num_copy_elems_X * mRes.element_type.width))
-            copy_atom_load_Res_async = cute.make_copy_atom(
-                cute.nvgpu.cpasync.CopyG2SOp(),
-                mRes.element_type,
-                num_bits_per_copy=num_copy_bits_Res,
-            )
-        num_copy_bits_O = const_expr(min(128, num_copy_elems_X * mO.element_type.width))
-        copy_atom_store_O = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_copy_bits_O
+        copy_atom_load_X_async = utils.get_copy_atom(
+            mX.element_type, num_copy_elems_X, is_async=True
         )
-        if const_expr(mResO is not None):
-            num_copy_bits_ResO = const_expr(min(128, num_copy_elems_X * mResO.element_type.width))
-            copy_atom_store_ResO = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(),
-                mResO.element_type,
-                num_bits_per_copy=num_copy_bits_ResO,
-            )
-
         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
             tidx
         )
 
-        tXgW = thr_copy_X.partition_S(gW)
+        tXgW = thr_copy_X.partition_S(gW) if const_expr(mW is not None) else None
         tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
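The long block of removed cute.make_copy_atom boilerplate above is what the new utils.get_copy_atom call folds away. A rough sketch of what such a helper might look like, inferred only from the deleted lines in this diff (the real implementation in quack/utils.py may differ):

```python
import cutlass.cute as cute


def get_copy_atom(dtype, num_copy_elems, *, is_async=False):
    # Hypothetical reconstruction based on the removed code above: cap each
    # copy instruction at 128 bits and pick the async (gmem -> smem) or
    # universal copy op accordingly.
    num_bits = min(128, num_copy_elems * dtype.width)
    op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    return cute.make_copy_atom(op, dtype, num_bits_per_copy=num_bits)
```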
@@ -274,8 +240,9 @@ class RMSNorm(ReductionBase):
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
 
         # allocate fragments for gmem->rmem
-        tXrW = cute.make_fragment_like(tXgW)
-        tXrW.fill(0.0)
+        tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None
+        if const_expr(mW is not None):
+            tXrW.fill(0.0)
         tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
         tXrX, tXrO = [cute.make_fragment_like(t) for t in (tXgX, tXgO)]
         if const_expr(mRes is not None):
@@ -288,17 +255,21 @@ class RMSNorm(ReductionBase):
         tXpX = (
             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
         )
+        # Each copy will use the same number of elements as X and same predicate
+        copy = partial(utils.copy, pred=tXpX, num_copy_elems=num_copy_elems_X)
+
         row = tXcX[0][0]
         if row < shape[0]:
-
+            copy(tXgX, tXsX, is_async=True)
             if const_expr(mRes is not None):
-
+                copy(tXgRes, tXsRes, is_async=True)
             cute.arch.cp_async_commit_group()
 
         if const_expr(not delay_w_load):
-
+            if const_expr(mW is not None):
+                copy(tXgW, tXrW)
             if const_expr(mB is not None):
-
+                copy(tXgB, tXrB)
 
         cute.arch.cp_async_wait_group(0)
         cute.autovec_copy(tXsX, tXrX)
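The partial(...) above pre-binds the predicate and vector width once, so every later call only names the source and destination tensors. A toy, pure-Python illustration of the same pattern (not the CuTe DSL; utils.copy's real signature is only partially visible in this diff):

```python
from functools import partial


def copy(src, dst, *, pred=None, num_copy_elems=1, is_async=False):
    # Stand-in for utils.copy: just report what a real copy would have done.
    mode = "async" if is_async else "sync"
    print(f"{mode} copy {src} -> {dst} (pred={pred}, elems={num_copy_elems})")


# Bind the arguments shared by every copy in the kernel once.
copy_bound = partial(copy, pred="tXpX", num_copy_elems=8)

copy_bound("tXgX", "tXsX", is_async=True)  # gmem -> smem prefetch
copy_bound("tXrO", "tXgO")                 # rmem -> gmem store
```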
@@ -310,7 +281,7 @@ class RMSNorm(ReductionBase):
             tXrResO = cute.make_fragment_like(tXgResO)
             tXrResO.store(x.to(tXrResO.element_type))
             if row < shape[0]:
-
+                copy(tXrResO, tXgResO)
 
         threads_per_row = tv_layout.shape[0][0]
         sum_sq_x = row_reduce(
@@ -332,27 +303,28 @@ class RMSNorm(ReductionBase):
         ):
             tXrRstd[0] = rstd
         if const_expr(delay_w_load):
-
+            if const_expr(mW is not None):
+                copy(tXgW, tXrW)
             if const_expr(mB is not None):
-
+                copy(tXgB, tXrB)
         if const_expr(reload_from == "smem" or reload_from == "gmem"):
             if const_expr(reload_from == "smem"):
                 cute.autovec_copy(tXsX, tXrX)
             else:
-
+                copy(tXgX, tXrX)
         x = tXrX.load().to(cute.Float32)
         if const_expr(mRes is not None):
             cute.autovec_copy(tXsRes, tXrRes)
             x += tXrRes.load().to(cute.Float32)
         x_hat = x * rstd
-        w = tXrW.load().to(cute.Float32)
-        y = x_hat * w
+        y = x_hat
+        if const_expr(mW is not None):
+            y *= tXrW.load().to(cute.Float32)
         if const_expr(mB is not None):
-            b = tXrB.load().to(cute.Float32)
-            y = y + b
+            y += tXrB.load().to(cute.Float32)
         tXrO.store(y.to(tXrO.element_type))
         if row < shape[0]:
-
+            copy(tXrO, tXgO)
 
 
 @torch.library.custom_op(
@@ -360,11 +332,11 @@ class RMSNorm(ReductionBase):
     mutates_args=("out", "rstd", "residual_out"),
     device_types="cuda",
     # We need to specify the schema manually since we're mutating an optional tensor
-    schema="(Tensor x, Tensor weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
+    schema="(Tensor x, Tensor? weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
 )
 def _rmsnorm_fwd(
     x: Tensor,
-    weight: Tensor,
+    weight: Optional[Tensor],
     out: Tensor,
     bias: Optional[Tensor] = None,
     rstd: Optional[Tensor] = None,
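Making weight optional also means marking it as Tensor? in the manually supplied op schema. A minimal standalone example of the same registration pattern (hypothetical op name, not part of quack):

```python
from typing import Optional

import torch


@torch.library.custom_op(
    "mylib::scale_or_copy",  # hypothetical namespace::name
    mutates_args=("out",),
    device_types="cuda",
    schema="(Tensor x, Tensor? weight, Tensor(a2!) out) -> ()",
)
def scale_or_copy(x: torch.Tensor, weight: Optional[torch.Tensor], out: torch.Tensor) -> None:
    # `weight` may legitimately be None at call time, matching `Tensor?` in the schema.
    out.copy_(x * weight if weight is not None else x)
```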
@@ -375,21 +347,23 @@ def _rmsnorm_fwd(
     """RMSNorm forward pass.
     Args:
         x: Input tensor of shape (M, N)
-        weight:
+        weight: Optional weight tensor of shape (N,)
         eps: Small value for numerical stability
     Returns:
         Normalized output tensor of same shape as x
     """
     assert x.dim() == 2, "Input must be 2D"
-    assert weight.dim() == 1, "Weight must be 1D"
-    assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
-    assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+    assert x.is_cuda, "Input tensor must be on CUDA device"
     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
-    assert weight.dtype in [
-        torch.float32,
-        torch.bfloat16,
-        torch.float16,
-    ], "Weight must be float32, float16 or bfloat16"
+    if weight is not None:
+        assert weight.dim() == 1, "Weight must be 1D"
+        assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+        assert weight.is_cuda, "Weight tensor must be on CUDA device"
+        assert weight.dtype in [
+            torch.float32,
+            torch.bfloat16,
+            torch.float16,
+        ], "Weight must be float32, float16 or bfloat16"
     if residual is not None:
         assert residual.shape == x.shape
         assert residual.is_cuda
@@ -402,11 +376,6 @@ def _rmsnorm_fwd(
     _, N = x.shape
     device = x.device
     dtype = torch2cute_dtype_map[x.dtype]
-    # convert_from_dlpack = lambda x: (
-    #     from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
-    #         mode=0, divisibility=128 // dtype.width
-    #     )
-    # )
     convert_from_dlpack = lambda x: (
         from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
     )
@@ -414,10 +383,13 @@ def _rmsnorm_fwd(
         convert_from_dlpack(t) if t is not None else None for t in (x, residual, out, residual_out)
     ]
     # handle weight divisibility based on weight dtype
-    weight_dtype = torch2cute_dtype_map[weight.dtype]
-    weight_tensor = utils.convert_from_dlpack(
-        weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
-    )
+    if weight is not None:
+        weight_dtype = torch2cute_dtype_map[weight.dtype]
+        weight_tensor = utils.convert_from_dlpack(
+            weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
+        )
+    else:
+        weight_tensor = None
     if bias is not None:
         bias_dtype = torch2cute_dtype_map[bias.dtype]
         bias_tensor = utils.convert_from_dlpack(
@@ -435,7 +407,7 @@ def _rmsnorm_fwd(
         N,
         dtype,
         res_tensor.element_type if residual is not None else None,
-        weight_tensor.element_type,
+        weight_tensor.element_type if weight is not None else None,
         bias_tensor.element_type if bias is not None else None,
         res_out_tensor.element_type if residual_out is not None else None,
         rstd is not None,
@@ -472,7 +444,7 @@ _rmsnorm_fwd.compile_cache = {}
 
 def rmsnorm_fwd(
     x: Tensor,
-    weight: Tensor,
+    weight: Optional[Tensor] = None,
     bias: Optional[Tensor] = None,
     residual: Optional[Tensor] = None,
     out_dtype: Optional[torch.dtype] = None,
@@ -501,12 +473,13 @@ def rmsnorm_fwd(
     return out, residual_out, rstd
 
 
-def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
+def rmsnorm_ref(x, w=None, bias=None, residual=None, eps=1e-6):
     x_f32 = x.float()
     if residual is not None:
         residual_f32 = residual.float()
         x_f32 += residual_f32
-
+    x_norm = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps))
+    out = x_norm * w if w is not None else x_norm
     if bias is not None:
         out = out + bias.float()
     if residual is None:
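The reference path now treats the scale as optional: with w=None the output is simply the RMS-normalized input. A self-contained PyTorch sketch of the same semantics (standalone, not importing quack; the residual/prenorm handling is omitted):

```python
import torch


def rmsnorm_ref_sketch(x, w=None, bias=None, eps=1e-6):
    # Mirror of the updated reference above: normalize in fp32, then apply the
    # optional scale and bias.
    x_f32 = x.float()
    x_norm = x_f32 / torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)
    out = x_norm * w if w is not None else x_norm
    if bias is not None:
        out = out + bias.float()
    return out


x = torch.randn(4, 8)
# Omitting the weight is equivalent to passing an all-ones weight.
assert torch.allclose(rmsnorm_ref_sketch(x), rmsnorm_ref_sketch(x, w=torch.ones(8)))
```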
@@ -613,8 +586,11 @@ class RMSNormBackward(ReductionBase):
         )
         num_threads = cute.size(tv_layout, mode=[0])
         num_warps = num_threads // cute.arch.WARP_SIZE
-        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
-        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+        if const_expr(mW is not None):
+            mW_expanded_layout = cute.prepend(
+                mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+            )
+            mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
 
         num_blocks = sm_count
         self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn).launch(
@@ -667,50 +643,10 @@ class RMSNormBackward(ReductionBase):
         mbar_full_ptr, mbar_empty_ptr = None, None
 
         num_copy_elems_X = tv_layout.shape[1][0]
-        num_copy_bits_X = const_expr(min(128, num_copy_elems_X * mX.element_type.width))
-        copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        num_copy_bits_dO = const_expr(min(128, num_copy_elems_X * mdO.element_type.width))
-        copy_atom_load_dO_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mdO.element_type, num_bits_per_copy=num_copy_bits_dO
-        )
-        num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
-        copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
-        )
-        if const_expr(mdResO is not None):
-            num_copy_bits_dResO = const_expr(min(128, num_copy_elems_X * mdResO.element_type.width))
-            copy_atom_load_dResO = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(),
-                mdResO.element_type,
-                num_bits_per_copy=num_copy_bits_dResO,
-            )
-        num_copy_bits_dX = const_expr(min(128, num_copy_elems_X * mdX.element_type.width))
-        copy_atom_store_dX = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_dX
-        )
-        num_copy_bits_dW = const_expr(min(128, num_copy_elems_X * mdW.element_type.width))
-        copy_atom_store_dW = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_copy_bits_dW
-        )
-        if const_expr(mdB is not None):
-            num_copy_bits_dB = const_expr(min(128, num_copy_elems_X * mdB.element_type.width))
-            copy_atom_store_dB = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(), mdB.element_type, num_bits_per_copy=num_copy_bits_dB
-            )
-        if const_expr(mdRes is not None):
-            num_copy_bits_dRes = const_expr(min(128, num_copy_elems_X * mdRes.element_type.width))
-            copy_atom_load_dRes = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(),
-                mdRes.element_type,
-                num_bits_per_copy=num_copy_bits_dRes,
-            )
-
+        copy_atom_load_X = utils.get_copy_atom(mX.element_type, num_copy_elems_X, is_async=False)
         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        # Each copy will use the same number of elements as X
+        copy = partial(utils.copy, num_copy_elems=num_copy_elems_X)
 
         gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
         tXgW = thr_copy_X.partition_S(gW)
@@ -725,7 +661,7 @@ class RMSNormBackward(ReductionBase):
             if not is_even_N
             else None
         )
-
+        copy(tXgW, tXrW, pred=tXpW)
         weight = tXrW.load().to(cute.Float32)
 
         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
@@ -790,16 +726,13 @@ class RMSNormBackward(ReductionBase):
         if const_expr(mdRes is not None):
             tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
 
-        copy_X = partial(cute.copy, copy_atom_load_X_async, pred=tXpX)
-        copy_dO = partial(cute.copy, copy_atom_load_dO_async, pred=tXpX)
-
         # Prefetch the first batch
         row = tXcX[None, None, None, bidx_start][0][0]
         if row < M:
             tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
             tXgdO_cur = utils.coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0]
-
-
+            copy(tXgX_cur, tXsX[None, None, None, 0], pred=tXpX, is_async=True)
+            copy(tXgdO_cur, tXsdO[None, None, None, 0], pred=tXpX, is_async=True)
         elif tiler_mn[0] > 1:
             # Fill with zero, otherwise smem will be uninitialized, and we could read this back
             # later into registers, causing wrong dW.
@@ -822,8 +755,8 @@ class RMSNormBackward(ReductionBase):
             if row + gdim * tiler_mn[0] < M:  # Prefetch the next batch
                 tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
                 tXgdO_cur = utils.coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0]
-
-
+                copy(tXgX_cur, tXsX[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
+                copy(tXgdO_cur, tXsdO[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
             elif tiler_mn[0] > 1:
                 utils.fill_oob(
                     tXsX[None, None, None, stage ^ 1],
@@ -842,7 +775,7 @@ class RMSNormBackward(ReductionBase):
             if const_expr(mdResO is not None):
                 tXgdResO_cur = utils.coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
                 if row < M or tiler_mn[0] == 1:
-
+                    copy(tXgdResO_cur, tXrdResO, pred=tXpX)
                 elif tiler_mn[0] > 1:
                     tXrdResO.fill(0.0)
             cute.arch.cp_async_wait_group(1)
@@ -890,12 +823,12 @@ class RMSNormBackward(ReductionBase):
             tXrdX.store(dx.to(tXrdX.element_type))
             if row < M or tiler_mn[0] == 1:
                 tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
-
+                copy(tXrdX, tXgdX_cur, pred=tXpX)
             if const_expr(mdRes is not None):
                 tXrdRes.store(dx.to(tXrdRes.element_type))
                 tXgdRes_cur = utils.coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
                 if row < M or tiler_mn[0] == 1:
-
+                    copy(tXrdRes, tXgdRes_cur, pred=tXpX)
             # Accumulate weight gradients in fp32
             tXrdW.store(tXrdW.load() + dout * x_hat)
             if const_expr(mdB is not None):
@@ -927,7 +860,7 @@ class RMSNormBackward(ReductionBase):
                 tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
                 cute.autovec_copy(tXsdW_other, tXrdW_other)
                 tXrdW.store(tXrdW.load() + tXrdW_other.load())
-
+            copy(tXrdW, tXgdW, pred=tXpdW)
             cute.arch.barrier()
             if const_expr(mdB is not None):
                 sdB = cute.make_tensor(
@@ -948,12 +881,12 @@ class RMSNormBackward(ReductionBase):
                 )
                 cute.autovec_copy(tXsdB_other, tXrdB_other)
                 tXrdB.store(tXrdB.load() + tXrdB_other.load())
-
+                copy(tXrdB, tXgdB, pred=tXpdB)
         else:
             # dw is already in fp32, so we can directly copy to global memory
-
+            copy(tXrdW, tXgdW, pred=tXpdW)
             if const_expr(mdB is not None):
-
+                copy(tXrdB, tXgdB, pred=tXpdB)
 
 
 def _get_sm_count(N: int, device: torch.device) -> int:
@@ -1171,6 +1104,7 @@ class RMSNormFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dout, *args):
        x, weight, rstd = ctx.saved_tensors
+        assert weight is not None, "RMSNorm backward doesn't support weight=None yet"
        has_bias = ctx.has_bias
        if ctx.prenorm and ctx.residual_dtype is not None:
            dresidual_out = args[0]
@@ -1193,7 +1127,7 @@ class RMSNormFunction(torch.autograd.Function):
 
 def rmsnorm(
     x: Tensor,
-    weight: Tensor,
+    weight: Optional[Tensor] = None,
     bias: Optional[Tensor] = None,
     residual: Optional[Tensor] = None,
     out_dtype: Optional[torch.dtype] = None,
@@ -1205,7 +1139,7 @@ def rmsnorm(
 
     Args:
         x: Input tensor of shape (M, N)
-        weight:
+        weight: Optional weight tensor of shape (N,)
         eps: Small value for numerical stability
 
     Returns:
quack/tile_scheduler.py
CHANGED
@@ -135,7 +135,7 @@ class TileScheduler:
         ip=None,
     ):
         self._current_work_linear_idx = current_work_linear_idx
-        self._num_tiles_executed = num_tiles_executed
+        self.num_tiles_executed = num_tiles_executed
         self._tile_count = tile_count
         self._scheduler_pipeline = scheduler_pipeline
         self._pipeline_state = pipeline_state
@@ -251,7 +251,7 @@ class TileScheduler:
         )
         tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
         if const_expr(not params.is_persistent):
-            is_valid = self._num_tiles_executed == 0
+            is_valid = self.num_tiles_executed == 0
         else:
             is_valid = self._current_work_linear_idx < cute.size(params.problem_shape_ncluster_mnl)
         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
@@ -276,38 +276,6 @@ class TileScheduler:
         current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
         self._current_work_linear_idx = current_work_linear_idx
 
-    # We have to split broadcast_next_work and advance_to_next_work into two functions
-    # due to a bug in cute-dsl 4.2: https://github.com/NVIDIA/cutlass/issues/2647
-    @cute.jit
-    def broadcast_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
-        """is_scheduler_warp should only be true for one warp in the whole cluster"""
-        params = self.params
-        if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
-            current_work_linear_idx = self._current_work_linear_idx
-            if is_scheduler_warp:
-                self._scheduler_pipeline.producer_acquire(self._pipeline_state)
-                lane_idx = cute.arch.lane_idx()
-                if lane_idx < cute.size(params.cluster_shape_mn):
-                    # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
-                    if const_expr(cute.size(params.cluster_shape_mn) == 1):
-                        self._tile_count[self._pipeline_state.index] = current_work_linear_idx
-                        self._scheduler_pipeline.producer_commit(self._pipeline_state)
-                    else:
-                        peer_cta_rank_in_cluster = lane_idx
-                        mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
-                            self._pipeline_state
-                        )
-                        cute.arch.mbarrier_arrive_and_expect_tx(
-                            mbar_ptr, 4, peer_cta_rank_in_cluster
-                        )
-                        utils.store_shared_remote(
-                            val=current_work_linear_idx,
-                            smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
-                            mbar_ptr=mbar_ptr,
-                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
-                        )
-                # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
-
     @cute.jit
     def advance_to_next_work(
         self,
@@ -328,7 +296,30 @@ class TileScheduler:
         if const_expr(advance_count > 1):
             self._pipeline_state.advance_iters(advance_count - 1)
         current_work_linear_idx = self._current_work_linear_idx
-        if not is_scheduler_warp:
+        if is_scheduler_warp:
+            self._scheduler_pipeline.producer_acquire(self._pipeline_state)
+            lane_idx = cute.arch.lane_idx()
+            if lane_idx < cute.size(params.cluster_shape_mn):
+                # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                if const_expr(cute.size(params.cluster_shape_mn) == 1):
+                    self._tile_count[self._pipeline_state.index] = current_work_linear_idx
+                    self._scheduler_pipeline.producer_commit(self._pipeline_state)
+                else:
+                    peer_cta_rank_in_cluster = lane_idx
+                    mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
+                        self._pipeline_state
+                    )
+                    cute.arch.mbarrier_arrive_and_expect_tx(
+                        mbar_ptr, 4, peer_cta_rank_in_cluster
+                    )
+                    utils.store_shared_remote(
+                        val=current_work_linear_idx,
+                        smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
+                        mbar_ptr=mbar_ptr,
+                        peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
+                    )
+            # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
+        else:
             # if tidx % 64 == 0: cute.printf("bidx = {},tidx = {}, before full wait, idx = {}", bidx, tidx, current_work_linear_idx)
             self._scheduler_pipeline.consumer_wait(self._pipeline_state)
             # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after full wait, idx = {}", bidx, tidx, current_work_linear_idx)
@@ -341,21 +332,17 @@ class TileScheduler:
             # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after empty arrive", bidx, tidx)
         self._current_work_linear_idx = current_work_linear_idx
         self._pipeline_state.advance()
-        self._num_tiles_executed += Int32(advance_count)
+        self.num_tiles_executed += Int32(advance_count)
 
     def producer_tail(self):
         if const_expr(self.params.is_persistent and self.params.tile_count_semaphore is not None):
             self._scheduler_pipeline.producer_tail(self._pipeline_state)
 
-    @property
-    def num_tiles_executed(self) -> Int32:
-        return self._num_tiles_executed
-
     def __extract_mlir_values__(self):
         values, self._values_pos = [], []
         for obj in [
             self._current_work_linear_idx,
-            self._num_tiles_executed,
+            self.num_tiles_executed,
             self._tile_count,
             self._scheduler_pipeline,
             self._pipeline_state,
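With the read-only property gone, num_tiles_executed is now an ordinary attribute that the scheduler mutates directly. A minimal generic sketch of the before/after pattern (hypothetical class names, plain Python rather than the DSL):

```python
class SchedulerBefore:
    """Old shape: a private counter exposed through a read-only property."""

    def __init__(self, num_tiles_executed=0):
        self._num_tiles_executed = num_tiles_executed

    @property
    def num_tiles_executed(self):
        return self._num_tiles_executed


class SchedulerAfter:
    """New shape: a plain attribute that can be read and incremented in place."""

    def __init__(self, num_tiles_executed=0):
        self.num_tiles_executed = num_tiles_executed


s = SchedulerAfter()
s.num_tiles_executed += 1  # what advance_to_next_work now does with advance_count
assert s.num_tiles_executed == 1
```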
@@ -371,7 +358,7 @@ class TileScheduler:
         for obj, n_items in zip(
             [
                 self._current_work_linear_idx,
-                self._num_tiles_executed,
+                self.num_tiles_executed,
                 self._tile_count,
                 self._scheduler_pipeline,
                 self._pipeline_state,
@@ -562,7 +549,7 @@ class TriangularTileScheduler(TileScheduler):
             pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
         tile_coord_mnkl = (pid_m, pid_n, None, bidz)
         if const_expr(not params.is_persistent):
-            is_valid = self._num_tiles_executed == 0
+            is_valid = self.num_tiles_executed == 0
         else:
             is_valid = (
                 self._current_work_linear_idx
@@ -681,7 +668,7 @@ class VarlenMTileScheduler(TileScheduler):
         ip=None,
     ):
         self._current_work_linear_idx = current_work_linear_idx
-        self._num_tiles_executed = num_tiles_executed
+        self.num_tiles_executed = num_tiles_executed
         self._current_batch_idx = current_batch_idx
         self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
         self._tile_count = tile_count
@@ -878,7 +865,7 @@ class VarlenMTileScheduler(TileScheduler):
             pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
         tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
         if const_expr(not params.is_persistent):
-            is_valid = self._num_tiles_executed == 0 and batch_idx < num_batch
+            is_valid = self.num_tiles_executed == 0 and batch_idx < num_batch
         else:
             is_valid = batch_idx < num_batch
         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
@@ -905,7 +892,7 @@ class VarlenMTileScheduler(TileScheduler):
         values, self._values_pos = [], []
         for obj in [
             self._current_work_linear_idx,
-            self._num_tiles_executed,
+            self.num_tiles_executed,
             self._current_batch_idx,
             self._num_work_idx_before_cur_batch,
             self._tile_count,
@@ -923,7 +910,7 @@ class VarlenMTileScheduler(TileScheduler):
         for obj, n_items in zip(
             [
                 self._current_work_linear_idx,
-                self._num_tiles_executed,
+                self.num_tiles_executed,
                 self._current_batch_idx,
                 self._num_work_idx_before_cur_batch,
                 self._tile_count,