quack-kernels 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/cross_entropy.py +56 -15
- quack/reduction_base.py +16 -8
- quack/rmsnorm.py +401 -205
- quack/softmax.py +1 -3
- quack/utils.py +50 -10
- {quack_kernels-0.1.6.dist-info → quack_kernels-0.1.8.dist-info}/METADATA +1 -1
- quack_kernels-0.1.8.dist-info/RECORD +12 -0
- quack_kernels-0.1.6.dist-info/RECORD +0 -12
- {quack_kernels-0.1.6.dist-info → quack_kernels-0.1.8.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.6.dist-info → quack_kernels-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.6.dist-info → quack_kernels-0.1.8.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/cross_entropy.py
CHANGED
@@ -1,16 +1,16 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

  import math
- import torch
  from typing import Optional, Type

  import cuda.bindings.driver as cuda

  import cutlass
  import cutlass.cute as cute
- from cutlass.cute.runtime import from_dlpack

  import quack.utils as utils
+ import torch
+ from cutlass.cute.runtime import from_dlpack
  from quack.reduction_base import ReductionBase, torch2cute_dtype_map


@@ -79,7 +79,7 @@ class CrossEntropy(ReductionBase):
  self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+ cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
  stream=stream,
  )
@@ -111,7 +111,9 @@ class CrossEntropy(ReductionBase):

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type,
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
  )
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

@@ -166,7 +168,9 @@ class CrossEntropy(ReductionBase):
  reduction_buffer[None, None, 0],
  mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
  init_val=-cutlass.Float32.inf,
- hook_fn=
+ hook_fn=(
+ cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+ ),
  )
  if cutlass.const_expr(self.reload_from == "smem"):
  cute.autovec_copy(tXsX, tXrX)
@@ -191,7 +195,9 @@ class CrossEntropy(ReductionBase):
  threads_per_row,
  reduction_buffer[None, None, 0],
  mbar_ptr,
- hook_fn=
+ hook_fn=(
+ cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+ ),
  )

  if (
@@ -225,7 +231,11 @@ def _cross_entropy(
  assert target.dim() == 1, "Target must be 1D"
  assert x.shape[0] == target.shape[0], "Batch dimensions must match"
  assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported input dtype"
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
  M, N = x.shape
  device = x.device
@@ -314,13 +324,16 @@ class CrossEntropyBackward:
  num_threads = cute.size(tv_layout, mode=[0])

  mDLoss = cute.make_tensor(
- mDLoss.iterator,
+ mDLoss.iterator,
+ cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,))),
  )
  mTarget = cute.make_tensor(
- mTarget.iterator,
+ mTarget.iterator,
+ cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
  )
  mLSE = cute.make_tensor(
- mLSE.iterator,
+ mLSE.iterator,
+ cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
  )

  smem_size = cute.size_in_bytes(
@@ -364,7 +377,9 @@ class CrossEntropyBackward:

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type,
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
  )

  idX = cute.make_identity_tensor(shape)
@@ -474,7 +489,11 @@ def _cross_entropy_backward(
  assert (
  x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
  ), "Tensors must be on CUDA device"
- assert x.dtype in [
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported input dtype"
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"

  M, N = x.shape
@@ -532,15 +551,37 @@ class CrossEntropyFunction(torch.autograd.Function):


  def cross_entropy(
- x: torch.Tensor,
+ x: torch.Tensor,
+ target: torch.Tensor,
+ inplace_backward: bool = True,
+ reduction: str = "none",
  ) -> torch.Tensor:
  """Cross entropy loss with automatic differentiation support.

  Args:
  x: Input logits tensor of shape (M, N)
  target: Target class indices tensor of shape (M,)
+ inplace_backward: Whether to perform backward pass in-place
+ reduction: Specifies the reduction to apply to the output:
+ 'none': no reduction will be applied (default)
+ 'mean': the sum of the output will be divided by the number of elements
+ 'sum': the output will be summed

  Returns:
- Cross entropy loss tensor
+ Cross entropy loss tensor:
+ - If reduction='none': tensor of shape (M,) with per-example losses
+ - If reduction='mean': scalar tensor with mean loss
+ - If reduction='sum': scalar tensor with sum of losses
  """
-
+ loss = CrossEntropyFunction.apply(x, target, inplace_backward)
+
+ if reduction == "mean":
+ return loss.mean()
+ elif reduction == "sum":
+ return loss.sum()
+ elif reduction == "none":
+ return loss
+ else:
+ raise ValueError(
+ f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'"
+ )
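Note: a minimal usage sketch of the reduction argument added to cross_entropy in 0.1.8, based only on the signature and docstring shown in the hunk above; the shapes below are arbitrary illustration values.

    import torch
    from quack.cross_entropy import cross_entropy

    # Assumed example shapes: 8 rows of 4096-class logits on GPU.
    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    target = torch.randint(0, 4096, (8,), device="cuda")

    per_row = cross_entropy(x, target, reduction="none")    # shape (8,), per-example losses
    mean_loss = cross_entropy(x, target, reduction="mean")  # scalar, equals per_row.mean()
    mean_loss.backward()                                     # gradients flow through CrossEntropyFunction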
quack/reduction_base.py
CHANGED
@@ -68,7 +68,7 @@ class ReductionBase:
  )

  def _allocate_reduction_buffer_and_mbar(
- self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
+ self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout, is_persistent: bool = False
  ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
  reduction_buffer = smem.allocate_tensor(
  self.reduction_dtype,
@@ -76,20 +76,28 @@ class ReductionBase:
  byte_alignment=4,
  )
  if cutlass.const_expr(self.cluster_n > 1):
- mbar_ptr = smem.allocate_array(
+ mbar_ptr = smem.allocate_array(
+ cutlass.Int64, num_elems=self.stage if not is_persistent else self.stage * 2
+ )
  else:
  mbar_ptr = None
  return reduction_buffer, mbar_ptr

  @cute.jit
- def _initialize_cluster(
+ def _initialize_cluster(
+ self,
+ tidx: cutlass.Int32,
+ mbar_ptr: cute.Pointer,
+ num_warps: int,
+ is_persistent: bool = False,
+ ):
  if cutlass.const_expr(self.cluster_n > 1):
- if tidx < self.stage:
+ if tidx < self.stage:  # Initialize full barrier
  cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+ if cutlass.const_expr(is_persistent):  # Initialize empty barrier
+ cute.arch.mbarrier_init(
+ mbar_ptr + self.stage + tidx, num_warps * self.cluster_n
+ )
  cute.arch.mbarrier_init_fence()
- if tidx < self.stage:
- cute.arch.mbarrier_arrive_and_expect_tx(
- mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
- )
  # Cluster arrive after barrier init
  cute.arch.cluster_arrive_relaxed()
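Note: a plain-Python sketch (not CuTe DSL) of the barrier bookkeeping this change introduces; the helper name describe_mbar_layout is hypothetical and only mirrors the allocation and init counts visible in the hunk above.

    def describe_mbar_layout(stage: int, num_warps: int, cluster_n: int, is_persistent: bool):
        """Return (num_mbarriers, full_indices, empty_indices, empty_arrive_count)."""
        num_mbarriers = stage if not is_persistent else stage * 2
        full_indices = list(range(stage))                  # one "full" barrier per stage
        empty_indices = list(range(stage, num_mbarriers))  # "empty" barriers only when persistent
        empty_arrive_count = num_warps * cluster_n         # arrival count used to init empty barriers
        return num_mbarriers, full_indices, empty_indices, empty_arrive_count

    # Example: the persistent, double-buffered backward kernel later in this diff uses stage=2.
    print(describe_mbar_layout(stage=2, num_warps=8, cluster_n=2, is_persistent=True))
    # -> (4, [0, 1], [2, 3], 16)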
quack/rmsnorm.py
CHANGED
@@ -1,15 +1,16 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

-
- import torch
  from typing import Optional

  import cuda.bindings.driver as cuda

  import cutlass
  import cutlass.cute as cute
+ from cutlass import Float32, Int32
  from cutlass.cute.runtime import from_dlpack
+
  import quack.utils as utils
+ import torch
  from quack.reduction_base import ReductionBase, torch2cute_dtype_map


@@ -20,41 +21,55 @@ class RMSNorm(ReductionBase):
  self.delay_w_load = False

  def _calculate_threads_per_row(self):
+ """Calculate the number of threads per row for the RMSNorm kernel."""
  N = self.N
-
- 8
-
-
-
-
-
-
+ if N <= 64:
+ return 8
+ elif N <= 128:
+ return 16
+ elif N <= 3072:
+ return 32
+ elif N <= 6144:
+ return 64
+ elif N <= 16384:
+ return 128
+ else:
+ return 256

  def _set_cluster_n(self):
+ """
+ Set the number of clusters for the RMSNorm kernel.
+ Stored in self.cluster_n.
+ """
  N = self.N
+
  # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
  # Similarly cluster_n = 8 is faster for N=128k
  if cutlass.const_expr(self.dtype.width == 16):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+ # 16-bit types (fp16, bf16)
+ if N <= 16 * 1024:
+ cluster_n = 1
+ elif N <= 32 * 1024:
+ cluster_n = 2
+ elif N <= 64 * 1024:
+ cluster_n = 4
+ elif N <= 128 * 1024:
+ cluster_n = 8
+ else:
+ cluster_n = 16
+ else:
+ # 32-bit types (fp32)
+ if N <= 32 * 1024:
+ cluster_n = 1
+ elif N <= 64 * 1024:
+ cluster_n = 2
+ elif N <= 128 * 1024:
+ cluster_n = 4
+ elif N <= 256 * 1024:
+ cluster_n = 8
+ else:
+ cluster_n = 16
+
  self.cluster_n = cluster_n

  @cute.jit
@@ -65,8 +80,17 @@ class RMSNorm(ReductionBase):
  mO: cute.Tensor,
  mRstd: Optional[cute.Tensor],
  stream: cuda.CUstream,
- eps:
+ eps: Float32 = 1e-6,
  ):
+ semistatic_shape = (*mX.shape[:-1], self.N)  # Set last dimension to be statically N
+ new_stride = lambda t: (
+ cute.assume(t.stride[0], divby=128 // t.element_type.width),
+ t.stride[1],
+ )
+ mX, mO = [
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ for t in (mX, mO)
+ ]
  assert mX.element_type == self.dtype
  assert mO.element_type == self.dtype
  self._set_cluster_n()
@@ -83,7 +107,7 @@ class RMSNorm(ReductionBase):
  self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+ cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
  stream=stream,
  )
@@ -110,7 +134,9 @@ class RMSNorm(ReductionBase):

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type,
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
  )
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

@@ -157,6 +183,7 @@ class RMSNorm(ReductionBase):

  # allocate fragments for gmem->rmem
  tWrW = cute.make_fragment_like(tWgW)
+ tWrW.fill(0.0)
  tXrW = thr_copy_X.retile(tWrW)
  tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]

@@ -184,7 +211,7 @@ class RMSNorm(ReductionBase):
  reduction_buffer[None, None, 0],
  mbar_ptr,
  init_val=0.0,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ hook_fn=(cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None),
  )
  rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
  if cutlass.const_expr(mRstd is not None):
@@ -232,25 +259,36 @@ def _rmsnorm_fwd(
  assert weight.dim() == 1, "Weight must be 1D"
  assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
  assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [
-
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported dtype"
+
+ assert weight.dtype in [
+ torch.float32,
+ torch.bfloat16,
+ torch.float16,
+ ], "Weight must be float32, float16 or bfloat16"
+
  M, N = x.shape
  device = x.device
  out = torch.empty_like(x)
  rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
  dtype = torch2cute_dtype_map[x.dtype]
+ # convert_from_dlpack = lambda x: (
+ # from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ # mode=0, divisibility=128 // dtype.width
+ # )
+ # )
  convert_from_dlpack = lambda x: (
- from_dlpack(x.detach(), assumed_align=16).
- mode=0, stride_order=(0, 1)
- )
+ from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
  )
- x_tensor, out_tensor = [
-
-
- for t in (x, out)
- ]
+ x_tensor, out_tensor = [convert_from_dlpack(t) for t in (x, out)]
+ # handle weight divisibility based on weight dtype
+ weight_dtype = torch2cute_dtype_map[weight.dtype]
  weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 //
+ weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
  )
  rstd_tensor = (
  from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
@@ -258,7 +296,7 @@ def _rmsnorm_fwd(
  else None
  )
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
- compile_key = (dtype, N, rstd is not None)
+ compile_key = (dtype, N, rstd is not None, weight.dtype)
  if compile_key not in _rmsnorm_fwd.compile_cache:
  rmsnorm_op = RMSNorm(dtype, N)
  _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
@@ -300,8 +338,15 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):

  class RMSNormBackward(ReductionBase):
  def __init__(self, dtype: cutlass.Numeric, N: int):
- #
- super().__init__(dtype, N, stage=
+ # 2 stages for double buffering when computing mean of x_hat * wdy
+ super().__init__(dtype, N, stage=2, reduction_dtype=Float32)
+ self.reload_wdy = None if N <= 16 * 1024 else "smem"
+ if self.N > 128 * 1024 and self.dtype.width >= 32:
+ # Not enough smem
+ raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
+
+ def _get_num_threads(self):
+ return 128 if self.N <= 4096 else 256

  def _calculate_threads_per_row(self):
  N = self.N
@@ -311,46 +356,49 @@ class RMSNormBackward(ReductionBase):
  else (
  16
  if N <= 128
- else (32 if N <=
+ else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256)))
  )
  )

  def _set_cluster_n(self):
  N = self.N
-
-
-
-
-
- 2
- if N <= 32 * 1024
- else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
- )
- )
- else:  # fp32
- cluster_n = (
- 1
- if N <= 32 * 1024
- else (
- 2
- if N <= 64 * 1024
- else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
- )
- )
+ cluster_n = (
+ 1
+ if N <= 8 * 1024
+ else (2 if N <= 16 * 1024 else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16)))
+ )
  self.cluster_n = cluster_n

+ def _smem_size_in_bytes(self, tiler_mn, num_warps):
+ return (
+ # Multiply by 2 since we need space for X and dOut,
+ # and multiply by another 2 due to double buffering
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 * 2
+ + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+ + self.stage * (cutlass.Int64.width // 8) * 2  # mult 2 as we need 2 mbar per stage
+ )
+
  @cute.jit
  def __call__(
  self,
  mX: cute.Tensor,
  mW: cute.Tensor,
-
+ mdOut: cute.Tensor,
  mRstd: cute.Tensor,
-
-
- sm_count:
+ mdX: cute.Tensor,
+ mdW: cute.Tensor,
+ sm_count: Int32,
  stream: cuda.CUstream,
  ):
+ semistatic_shape = (*mX.shape[:-1], self.N)  # Set last dimension to be statically N
+ new_stride = lambda t: (
+ cute.assume(t.stride[0], divby=128 // t.element_type.width),
+ t.stride[1],
+ )
+ mX, mdOut, mdX = [
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ for t in (mX, mdOut, mdX)
+ ]
  self._set_cluster_n()
  tiler_mn, tv_layout = self._get_tv_layout()
  num_threads = cute.size(tv_layout, mode=[0])
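Note: an illustrative back-of-the-envelope check of the new _smem_size_in_bytes formula above; the tile shape, warp count and cluster size below are assumed example values, not ones taken from this diff.

    # Assumed example: bf16 (2 bytes/elt), tiler_mn = (1, 8192), 8 warps, cluster_n = 1, stage = 2.
    elt_bytes, tile_elems = 2, 1 * 8192
    stage, num_warps, cluster_n = 2, 8, 1
    reduction_dtype_bytes, int64_bytes = 4, 8

    smem = (
        elt_bytes * tile_elems * 2 * 2                           # X and dOut, double buffered
        + stage * num_warps * cluster_n * reduction_dtype_bytes  # reduction buffer
        + stage * int64_bytes * 2                                # two mbarriers per stage
    )
    print(smem)  # 65632 bytes, well within the shared memory available per SM on recent GPUs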
@@ -359,14 +407,8 @@ class RMSNormBackward(ReductionBase):
  mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
  mW = cute.make_tensor(mW.iterator, mW_expanded_layout)

-
-
-
- num_blocks = (
- sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
- )
-
- self.kernel(mX, mW, mDout, mRstd, mDx, mDw, sm_count, tv_layout, tiler_mn).launch(
+ num_blocks = sm_count
+ self.kernel(mX, mW, mdOut, mRstd, mdX, mdW, tv_layout, tiler_mn).launch(
  grid=[num_blocks, self.cluster_n, 1],
  block=[num_threads, 1, 1],
  cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
@@ -379,177 +421,260 @@ class RMSNormBackward(ReductionBase):
  self,
  mX: cute.Tensor,
  mW: cute.Tensor,
-
+ mdOut: cute.Tensor,
  mRstd: cute.Tensor,
-
-
- sm_count: cutlass.Constexpr,
+ mdX: cute.Tensor,
+ mdW: cute.Tensor,
  tv_layout: cute.Layout,
  tiler_mn: cute.Shape,
  ):
  tidx, _, _ = cute.arch.thread_idx()
-
+ bidx_start, _, _ = cute.arch.block_idx()
  gdim, _, _ = cute.arch.grid_dim()
+ if cutlass.const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = cutlass.const_expr(0)

  shape = mX.shape
  M, N = shape[0], shape[1]
+ is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)

  idX = cute.make_identity_tensor(shape)

  smem = cutlass.utils.SmemAllocator()
-
+ smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2))
+ sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
+ sdOut = smem.allocate_tensor(mdOut.element_type, smem_layout, byte_alignment=16)
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout, is_persistent=True
+ )
+ if cutlass.const_expr(mbar_ptr is not None):
+ mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
+ else:
+ mbar_full_ptr, mbar_empty_ptr = None, None

  copy_atom_load_X = cute.make_copy_atom(
  cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
  )
-
+ copy_atom_load_X_async = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+ )
  copy_atom_load_W = cute.make_copy_atom(
  cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
  )
-
  copy_atom_store_dX = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(),
+ cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=128
  )
-
-
- cute.nvgpu.CopyUniversalOp(), mDw.element_type, num_bits_per_copy=128
+ copy_atom_store_dW = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=128
  )

  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_X_async = cute.make_tiled_copy(
+ copy_atom_load_X_async, tv_layout, tiler_mn
+ ).get_slice(tidx)
  thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
-
-
+ thr_copy_dW = cute.make_tiled_copy(copy_atom_store_dW, tv_layout, tiler_mn).get_slice(tidx)
+ thr_store_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)

- gW = cute.local_tile(mW, tiler_mn, (
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
  tWgW = thr_copy_W.partition_S(gW)
  tWrW = cute.make_fragment_like(tWgW)
+ # Need this, otherwise rW can have arbitrary values that changes the reduction
+ if not is_even_N:
+ tWrW.fill(0.0)
  tXrW = thr_copy_X.retile(tWrW)

- gW_coord = cute.local_tile(idX, tiler_mn, (0,
-
-
+ gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+ tWpW = (
+ utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+ if not is_even_N
+ else None
+ )
  cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
  weight = tXrW.load().to(cute.Float32)

  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE

- self._initialize_cluster(tidx, mbar_ptr, num_warps)
-
- dw_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
- tDwpDw = utils.predicate_k(thr_copy_dw.partition_S(dw_coord), limit=shape[1])
+ self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)

-
-
-
-
-
-
- M_pad = ((M + sm_count - 1) // sm_count) * sm_count
+ dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+ tdWpdW = (
+ utils.predicate_k(thr_copy_dW.partition_S(dw_coord), limit=shape[1])
+ if not is_even_N
+ else None
+ )

-
+ gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
+ tdWgdW = thr_copy_dW.partition_D(gdW)
+ # Always compute partial weight gradients in fp32
+ tdWrdW = cute.make_fragment_like(tdWgdW, Float32)
+ tXrdW = thr_copy_X.retile(tdWrdW)

-
-
-
+ gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
+ gdOut = cute.local_tile(mdOut, tiler_mn, (None, cluster_y))
+ gdX = cute.local_tile(mdX, tiler_mn, (None, cluster_y))
+ cX = cute.local_tile(idX, tiler_mn, (None, cluster_y))
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXgdOut = thr_copy_X.partition_S(gdOut)
+ tXsdOut = thr_copy_X.partition_D(sdOut)
+ tXgdX = thr_store_dX.partition_D(gdX)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
+ # This doesn't change across iterations
+ tXpX = (
+ utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
+ if not is_even_N
+ else None
+ )

-
-
-
-
-
-
-
- )
-
-
+ tXrX, tXrdOut, tXrdX = [
+ cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdOut, tXgdX)
+ ]
+
+ # Prefetch the first batch
+ row = tXcX[None, None, None, bidx_start][0][0]
+ if row < M:
+ tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
+ tXgdOut_cur = utils.coord_offset_i64(bidx_start, tXgdOut, dim=3)[None, None, None, 0]
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgX_cur,
+ tXsX[None, None, None, 0],
+ pred=tXpX,
  )
-
-
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgdOut_cur,
+ tXsdOut[None, None, None, 0],
+ pred=tXpX,
  )
-
-
-
-
-
-
- tXrRstd = thr_copy_W.partition_S(gRstd)
- thrDx = thr_store_dx.partition_D(gDx)
- tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
- tXrX, frgDout, frgDx = [cute.make_fragment_like(thr) for thr in (tXgX, thrDout, thrDx)]
-
- tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ elif tiler_mn[0] > 1:
+ # Fill with zero, otherwise smem will be uninitialized, and we could read this back
+ # later into registers, causing wrong dW.
+ utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
+ utils.fill_oob(tXsdOut[None, None, None, 0], None, fill_value=mdOut.element_type.zero)
+ cute.arch.cp_async_commit_group()

-
-
- cute.copy(copy_atom_load_X, thrDout, frgDout, pred=tXpX)
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_wait()

+ threads_per_row = tv_layout.shape[0][0]
+ tXrdW.fill(0.0)
+ stage = Int32(0)
+ producer_phase = Int32(1)
+ consumer_phase = Int32(0)
+ for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
+ row = tXcX[None, None, None, bidx][0][0]
+ rstd = cutlass.Float.zero
+ if row + gdim * tiler_mn[0] < M:  # Prefetch the next batch
+ tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
+ tXgdOut_cur = utils.coord_offset_i64(bidx + gdim, tXgdOut, dim=3)[
+ None, None, None, 0
+ ]
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgX_cur,
+ tXsX[None, None, None, stage ^ 1],
+ pred=tXpX,
+ )
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgdOut_cur,
+ tXsdOut[None, None, None, stage ^ 1],
+ pred=tXpX,
+ )
+ elif tiler_mn[0] > 1:
+ utils.fill_oob(
+ tXsX[None, None, None, stage ^ 1],
+ None,
+ fill_value=mX.element_type.zero,
+ )
+ utils.fill_oob(
+ tXsdOut[None, None, None, stage ^ 1],
+ None,
+ fill_value=mdOut.element_type.zero,
+ )
+ cute.arch.cp_async_commit_group()
+ if row < M or tiler_mn[0] == 1:
+ rstd = mRstd[row]
+ cute.arch.cp_async_wait_group(1)
+ cute.autovec_copy(tXsX[None, None, None, stage], tXrX)
  x = tXrX.load().to(cute.Float32)
-
-
- rstd = tXrRstd[0]
+ cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
+ dout = tXrdOut.load().to(cute.Float32)
  x_hat = x * rstd
  wdy = dout * weight
-
- threads_per_row = tv_layout.shape[0][0]
-
- row = tXcX[0][0]
  if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.
- cute.arch.cluster_wait()
- else:
- cute.arch.barrier()
-
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
  mean_xhat_wdy = (
  utils.row_reduce(
  x_hat * wdy,
  cute.ReductionOp.ADD,
  threads_per_row,
- reduction_buffer[None, None,
-
+ reduction_buffer[None, None, stage],
+ (mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None),
+ phase=consumer_phase,
  init_val=0.0,
- hook_fn=cute.arch.cluster_wait
- if cutlass.const_expr(self.cluster_n > 1)
- else None,
  )
  / shape[1]
  )

- dx = (wdy - x_hat * mean_xhat_wdy) * rstd
- frgDx.store(dx.to(frgDout.element_type))
-
- if row < M:
- cute.copy(copy_atom_store_dX, frgDx, thrDx, pred=tXpX)
-
  if cutlass.const_expr(self.cluster_n > 1):
-
-
-
- cute.arch.
-
-
-
-
-
-
+ # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
+ # Requires adjusting the thread_count when initializing the mbar
+ cute.arch.sync_warp()
+ lane_idx = cute.arch.lane_idx()
+ if lane_idx < self.cluster_n:
+ cute.arch.mbarrier_arrive(
+ mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
+ )
+
+ if cutlass.const_expr(self.reload_wdy == "smem"):
+ cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
+ dout = tXrdOut.load().to(cute.Float32)
+ wdy = dout * weight

-
-
-
-
-
-
-
-
-
-
-
-
+ dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+ tXrdX.store(dx.to(tXrdOut.element_type))
+ if row < M or tiler_mn[0] == 1:
+ tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
+ cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+ # Accumulate weight gradients in fp32
+ tXrdW.store(tXrdW.load() + dout * x_hat)
+
+ stage ^= 1
+ if stage == 0:
+ consumer_phase ^= 1
+ producer_phase ^= 1
+
+ if cutlass.const_expr(self.cluster_n > 1):  # Prevent cluster from exiting early
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
+
+ if cutlass.const_expr(tiler_mn[0] > 1):
+ # reduction of dw_partial within the same threadblock
+ sdW = cute.make_tensor(
+ cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ )
+ tXsdW = thr_copy_X.partition_D(sdW)
  cute.arch.barrier()
-
+ row = tXcX[None, None, None, 0][0][0]
+ if row > 0:
+ cute.autovec_copy(tXrdW, tXsdW)
+ cute.arch.barrier()
+ if row == 0:
+ for i in cutlass.range_constexpr(1, cutlass.const_expr(tiler_mn[0])):
+ tXrdW_other = cute.make_fragment_like(tXrdW)
+ tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
+ cute.autovec_copy(tXsdW_other, tXrdW_other)
+ tXrdW.store(tXrdW.load() + tXrdW_other.load())
+ cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)

-
-
+ else:
+ # dw is already in fp32, so we can directly copy to global memory
+ cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)


  def _rmsnorm_backward(
@@ -573,37 +698,58 @@ def _rmsnorm_backward(
  assert weight.dim() == 1, "Weight must be 1D"
  assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
  assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [
-
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported dtype"
+
+ assert weight.dtype in [
+ torch.float32,
+ torch.bfloat16,
+ torch.float16,
+ ], "Weight must be float32, float16 or bfloat16"

  M, N = x.shape
  dx = torch.empty_like(x)

  device = x.device

-
-
+ # This should be tuned on how many CTAs can be launched on each SM
+ sm_count_multiple = (
+ 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+ )
+ sm_count = torch.cuda.get_device_properties(device).multi_processor_count
+ # By right, if we're using cluster, this should be cluster_count not sm_count.
+ # But for cluster >= 4, due to quantization we would need to query active max cluster.
+ # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
+ # avoid wave quantization.
+ sm_count = (
+ sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
+ )
+
+ # Always store partial gradients in fp32 for numerical accuracy
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)

  dtype = torch2cute_dtype_map[x.dtype]

- convert_from_dlpack = lambda
- from_dlpack(
- mode=0, stride_order=(0, 1)
- )
+ convert_from_dlpack = lambda x: (
+ from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
  )
-
  x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]

+ # Handle weight div based on weight dtype
+ weight_dtype = torch2cute_dtype_map[weight.dtype]
  weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 //
+ weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
  )

- dw_partial_tensor =
+ dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
  rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)

  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

- compile_key = (dtype, N)
+ compile_key = (dtype, N, weight.dtype)
  if compile_key not in _rmsnorm_backward.compile_cache:
  rmsnorm_backward_op = RMSNormBackward(dtype, N)
  _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
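Note: a small worked example of the grid-size heuristic introduced above; the SM count of 132 is an assumed example value, not something stated in this diff.

    def backward_grid_size(n: int, multi_processor_count: int = 132) -> int:
        # Mirrors the sm_count computation shown in the _rmsnorm_backward hunk above.
        sm_count_multiple = (
            16 if n <= 256 else (8 if n <= 1024 else (4 if n <= 2048 else (2 if n <= 4096 else 1)))
        )
        sm_count = multi_processor_count
        return (
            sm_count * sm_count_multiple if n <= 8192 else sm_count // 2 if n <= 16384 else sm_count * 2
        )

    print(backward_grid_size(4096))   # 264 blocks -> dw_partial has shape (264, 4096), fp32
    print(backward_grid_size(16384))  # 66 blocks
    print(backward_grid_size(65536))  # 264 blocks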
@@ -625,9 +771,10 @@ def _rmsnorm_backward(
  rstd_tensor,
  dx_tensor,
  dw_partial_tensor,
+ sm_count,
  current_stream,
  )
-
+ # we have summed the partial gradients in fp32, now we convert back to the weight dtype
  dw = dw_partial.sum(dim=0).to(weight.dtype)
  return dx, dw

@@ -638,16 +785,29 @@ _rmsnorm_backward.compile_cache = {}
  class RMSNormFunction(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x, weight, eps):
+ x_shape_start = x.shape
+
+ # Flatten input
+ x = x.view(-1, x.shape[-1])
+
  out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
  ctx.save_for_backward(x, weight, rstd)
  ctx.eps = eps
-
+ ctx.x_shape_start = x_shape_start
+
+ return out.reshape(x_shape_start)

  @staticmethod
  def backward(ctx, dout):
  x, weight, rstd = ctx.saved_tensors
+ x_shape_start = ctx.x_shape_start
+ # Reshape dout to match the flattened shape used in forward
+ dout = dout.view(-1, dout.shape[-1])
  dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
-
+ dx = dx.view(x_shape_start)
+ # dx is returned for input gradient,
+ # dw is returned for weight gradient,
+ # None for eps gradient
  return dx, dw, None


@@ -663,3 +823,39 @@ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.T
  Normalized output tensor of same shape as x
  """
  return RMSNormFunction.apply(x, weight, eps)
+
+
+ class QuackRMSNorm(torch.nn.Module):
+ """RMSNorm module that behaves like torch.nn.RMSNorm.
+
+ This class provides a drop-in replacement for torch.nn.RMSNorm that uses
+ the quack.rmsnorm implementation under the hood.
+
+ Args:
+ dim (int): The dimension to normalize over
+ eps (float, optional): A small constant for numerical stability. Default: 1e-6
+
+ Attributes:
+ weight (torch.nn.Parameter): The learnable weight parameter
+ eps (float): A small constant for numerical stability
+ """
+
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply RMSNorm to the input tensor.
+
+ Args:
+ x (torch.Tensor): Input tensor
+
+ Returns:
+ torch.Tensor: Normalized tensor
+ """
+ return rmsnorm(x, self.weight, self.eps)
+
+ def reset_parameters(self):
+ """Reset the weight parameter to ones."""
+ torch.nn.init.ones_(self.weight)
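Note: a minimal usage sketch of the new QuackRMSNorm module added in 0.1.8, written against the __init__/forward signatures shown above; the tensor shapes are arbitrary illustration values.

    import torch
    from quack.rmsnorm import QuackRMSNorm

    norm = QuackRMSNorm(dim=4096, eps=1e-6).cuda()
    x = torch.randn(2, 128, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    y = norm(x)            # same shape as x; 3D input is flattened to 2D internally and reshaped back
    y.sum().backward()     # produces x.grad and norm.weight.grad via RMSNormFunction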
quack/softmax.py
CHANGED
@@ -133,9 +133,7 @@ class Softmax(ReductionBase):

  is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
  tXpX = (
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
- if cutlass.const_expr(not is_even_N)
- else None
+ utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
  )
  if tXcX[0][0] < shape[0]:
  cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
quack/utils.py
CHANGED
@@ -120,12 +120,20 @@ def cluster_reduce(
  reduction_buffer: cute.Tensor,
  mbar_ptr: cute.Pointer,
  init_val: cute.Numeric = 0.0,
+ phase: Optional[cutlass.Int32] = None,
  ) -> cute.Numeric:
  """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
  cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
  lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
- warps_per_row, cluster_n = reduction_buffer.shape
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
  row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+ if warp_idx == 0:
+ with cute.arch.elect_one():
+ num_warps = rows_per_block * warps_per_row
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr,
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+ )
  if lane_idx < cluster_n:
  store_shared_remote(
  val,
@@ -133,7 +141,7 @@ def cluster_reduce(
  mbar_ptr,
  peer_cta_rank_in_cluster=lane_idx,
  )
- cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
  block_reduce_val = init_val
  num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
  for i in cutlass.range_constexpr(num_iter):
@@ -149,13 +157,14 @@ def block_or_cluster_reduce(
  op: Callable,
  reduction_buffer: cute.Tensor,
  mbar_ptr: Optional[cute.Pointer],
+ phase: Optional[cutlass.Int32] = None,
  init_val: cute.Numeric = 0.0,
  ) -> cute.Numeric:
  """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
  if cutlass.const_expr(mbar_ptr is None):
  return block_reduce(val, op, reduction_buffer, init_val=init_val)
  else:
- return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val)
+ return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)


  @cute.jit
@@ -165,6 +174,7 @@ def row_reduce(
  threads_per_row: cutlass.Constexpr[int],
  reduction_buffer: Optional[cute.Tensor] = None,
  mbar_ptr: Optional[cute.Pointer] = None,
+ phase: Optional[cutlass.Int32] = None,
  init_val: cute.Numeric = 0.0,
  hook_fn: Optional[Callable] = None,
  ) -> cute.Numeric:
@@ -193,7 +203,7 @@ def row_reduce(
  ), "mbar_ptr must be provided for cluster reduction"
  if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
  val = block_or_cluster_reduce(
- val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
+ val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
  )
  return val

@@ -205,6 +215,7 @@ def online_softmax_reduce(
  reduction_buffer: Optional[cute.Tensor] = None,
  mbar_ptr: Optional[cute.Pointer] = None,
  hook_fn: Optional[Callable] = None,
+ phase: Optional[cutlass.Int32] = None,
  return_exp_x: bool = False,
  ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
  assert x.dtype == Float32, "x must be of type Float32"
@@ -225,7 +236,7 @@ def online_softmax_reduce(
  if cutlass.const_expr(hook_fn is not None):
  hook_fn()
  if cutlass.const_expr(reduction_buffer is not None):
- warps_per_row, cluster_n = reduction_buffer.shape
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
  assert (
  cluster_n == 1 or mbar_ptr is not None
  ), "mbar_ptr must be provided for cluster reduction"
@@ -251,6 +262,13 @@ def online_softmax_reduce(
  max_x = max_x_final
  else:
  cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+ if warp_idx == 0:
+ with cute.arch.elect_one():
+ num_warps = rows_per_block * warps_per_row
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr,
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+ )
  if lane_idx < cluster_n:
  store_shared_remote(
  f32x2_to_i64(max_x, sum_exp_x),
@@ -258,7 +276,7 @@ def online_softmax_reduce(
  mbar_ptr,
  peer_cta_rank_in_cluster=lane_idx,
  )
- cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
  num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
  max_x_single_warp = cute.make_fragment(num_iter, Float32)
  max_x_single_warp.fill(-Float32.inf)
@@ -351,7 +369,7 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:


  @cute.jit
- def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) -> None:
+ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
  """Fill out-of-bounds values in shared memory tensor.

  Args:
@@ -361,9 +379,12 @@ def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) ->
  """
  tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
  tXrX_fill.fill(fill_value)
- for rest_v in cutlass.range_constexpr(
- for rest_k in cutlass.range_constexpr(
- if
+ for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
+ for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
+ if cutlass.const_expr(tXpX is not None):
+ if not tXpX[rest_v, 0, rest_k]:
+ cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
+ else:
  cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])


@@ -396,6 +417,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
  def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
  flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
  flat_stride = cute.flatten_to_tuple(tensor.stride)
+ assert len(flat_coord_i64) == len(
+ flat_stride
+ ), "Coordinate and stride must have the same length"
  offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
  assert isinstance(tensor.iterator, cute.Pointer)
  # HACK: we assume that applying the offset does not change the pointer alignment
@@ -406,3 +430,19 @@ def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=No
  assumed_align=tensor.iterator.max_alignment,
  )
  return cute.make_tensor(new_ptr, tensor.layout)
+
+
+ @dsl_user_op
+ def coord_offset_i64(
+ idx: cute.typing.Int, tensor: cute.Tensor, dim: int, *, loc=None, ip=None
+ ) -> cute.Tensor:
+ offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
+ assert isinstance(tensor.iterator, cute.Pointer)
+ # HACK: we assume that applying the offset does not change the pointer alignment
+ new_ptr = cute.make_ptr(
+ tensor.element_type,
+ tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+ tensor.memspace,
+ assumed_align=tensor.iterator.max_alignment,
+ )
+ return cute.make_tensor(new_ptr, tensor.layout)
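Note: a plain-Python sketch of the 64-bit offset arithmetic that the new coord_offset_i64 helper performs (advancing a tensor by idx tiles along one mode without overflowing 32-bit indices); the element width and stride below are assumed example values.

    def byte_offset_for(idx: int, stride_at_dim: int, element_width_bits: int) -> int:
        # Mirrors: offset = Int64(idx) * stride[dim]; bytes = offset * width // 8
        offset_elems = idx * stride_at_dim          # computed in 64-bit to avoid overflow
        return offset_elems * element_width_bits // 8

    # Assumed example: bf16 tensor with stride 65536 along dim, advancing by 40000 tiles.
    print(byte_offset_for(40_000, 65_536, 16))  # 5_242_880_000 bytes, far beyond the int32 range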
quack_kernels-0.1.8.dist-info/RECORD
ADDED
@@ -0,0 +1,12 @@
+ quack/__init__.py,sha256=tDgX5MF1ttfEyDVFWi47DA8tDooYcBQlkuzvabGUoQI,203
+ quack/cross_entropy.py,sha256=VYSAd28GmtnMoKQwLrorvySDtJfRhoqVd-aeM52FmsI,20866
+ quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
+ quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
+ quack/rmsnorm.py,sha256=-qrKqPKk0fUuq0a5-vJmZZ7nQsHgyaqTg0EKhWT44r0,32738
+ quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
+ quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
+ quack_kernels-0.1.8.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ quack_kernels-0.1.8.dist-info/METADATA,sha256=b_2PxFEoVqWJbT2FtuP9FJyF-jpL2Z3q9OHoOEipqo4,289
+ quack_kernels-0.1.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ quack_kernels-0.1.8.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+ quack_kernels-0.1.8.dist-info/RECORD,,
quack_kernels-0.1.6.dist-info/RECORD
DELETED
@@ -1,12 +0,0 @@
- quack/__init__.py,sha256=EV_43VfcxQUsFu_Nfq0944ImloZ2T9X94-8T9IQM0I8,203
- quack/cross_entropy.py,sha256=bg66wECki5I71SMPIRUa-6-oFJ93aIKpK1jqT__SCBM,19775
- quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
- quack/reduction_base.py,sha256=fFuGXPR3lDq2yw_m86ujmkni6R51jzNAzy_r9R6C8tA,3563
- quack/rmsnorm.py,sha256=rbRP_O-EJYvrvxYPFwfy0yCbG7Qs_DgbzgacxTLvih4,24159
- quack/softmax.py,sha256=b-QEQiGjOEPidolE2--K21kwEaLJ-3wQoGIz_BsEcSI,16742
- quack/utils.py,sha256=laz_lqeggiIOY_xCs3c3VvCPP8bJmSKBfjFpFR1Al80,15946
- quack_kernels-0.1.6.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- quack_kernels-0.1.6.dist-info/METADATA,sha256=NKhSoudW9lNdYryNiMkyjKXYl5n37cuFMbyLv3zr5js,289
- quack_kernels-0.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- quack_kernels-0.1.6.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
- quack_kernels-0.1.6.dist-info/RECORD,,
{quack_kernels-0.1.6.dist-info → quack_kernels-0.1.8.dist-info}/WHEEL
File without changes
{quack_kernels-0.1.6.dist-info → quack_kernels-0.1.8.dist-info}/licenses/LICENSE
File without changes
{quack_kernels-0.1.6.dist-info → quack_kernels-0.1.8.dist-info}/top_level.txt
File without changes