quack-kernels 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/cross_entropy.py +11 -7
- quack/layernorm.py +351 -0
- quack/reduction_base.py +16 -8
- quack/rmsnorm.py +227 -151
- quack/softmax.py +9 -6
- quack/utils.py +66 -10
- {quack_kernels-0.1.5.dist-info → quack_kernels-0.1.7.dist-info}/METADATA +1 -1
- quack_kernels-0.1.7.dist-info/RECORD +12 -0
- quack_kernels-0.1.5.dist-info/RECORD +0 -11
- {quack_kernels-0.1.5.dist-info → quack_kernels-0.1.7.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.5.dist-info → quack_kernels-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.5.dist-info → quack_kernels-0.1.7.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/cross_entropy.py
CHANGED
@@ -104,7 +104,10 @@ class CrossEntropy(ReductionBase):
 shape: cute.Shape = mX.shape
 idX = cute.make_identity_tensor(shape)
 # slice for CTAs
-
+# We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+mX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
+gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y))
+cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))

 smem = cutlass.utils.SmemAllocator()
 sX = smem.allocate_tensor(
@@ -150,7 +153,9 @@ class CrossEntropy(ReductionBase):

 target_logit = cute.Float32.zero
 if row < shape[0] and tXcX[0][1] == 0:
-
+# Use Int64 for indexing to deal with large tensors
+mX_off = utils.domain_offset_i64((row, 0), mX)
+target_logit = cute.Float32(mX_off[0, target])

 threads_per_row = tv_layout.shape[0][0]
 if cutlass.const_expr(not self.online_softmax):
@@ -363,11 +368,10 @@ class CrossEntropyBackward:
 )

 idX = cute.make_identity_tensor(shape)
-
-
-
-
-]
+# We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+mX, mdX = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mdX)]
+gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
+cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))

 copy_atom_load_X = cute.make_copy_atom(
 cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
quack/layernorm.py
ADDED
@@ -0,0 +1,351 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+
+import torch
+from typing import Optional
+
+import cuda.bindings.driver as cuda
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.runtime import from_dlpack
+import quack.utils as utils
+from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+class LayerNorm(ReductionBase):
+    def __init__(self, dtype: cutlass.Numeric, N: int):
+        super().__init__(dtype, N, stage=2)  # 2 stages for mean and var
+        self.reload_from = None if N <= 16384 else "smem"
+        self.delay_w_load = False
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
+
+    def _set_cluster_n(self):
+        N = self.N
+        # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
+        # Similarly cluster_n = 8 is faster for N=128k
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 32 * 1024
+                else (
+                    2
+                    if N <= 64 * 1024
+                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mO: cute.Tensor,
+        mRstd: Optional[cute.Tensor],
+        mMean: Optional[cute.Tensor],
+        stream: cuda.CUstream,
+        eps: cutlass.Float32 = 1e-6,
+    ):
+        assert mX.element_type == self.dtype
+        assert mO.element_type == self.dtype
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+        if cutlass.const_expr(mRstd is not None):
+            mRstd_expanded_layout = cute.append(
+                mRstd.layout, cute.make_layout((self.N,), stride=(0,))
+            )
+            mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+        if cutlass.const_expr(mMean is not None):
+            mMean_expanded_layout = cute.append(
+                mMean.layout, cute.make_layout((self.N,), stride=(0,))
+            )
+            mMean = cute.make_tensor(mMean.iterator, mMean_expanded_layout)
+        self.kernel(mX, mW, mO, mRstd, mMean, eps, tv_layout, tiler_mn, self.reload_from).launch(
+            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mO: cute.Tensor,
+        mRstd: Optional[cute.Tensor],
+        mMean: Optional[cute.Tensor],
+        eps: cute.Float32,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+        reload_from: cutlass.Constexpr = None,
+        delay_w_load: cutlass.Constexpr = False,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)
+
+        smem = cutlass.utils.SmemAllocator()
+        sX = smem.allocate_tensor(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        shape = mX.shape
+        idX = cute.make_identity_tensor(shape)
+        # slice for CTAs
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+        gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
+        gRstd = (
+            cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
+            if cutlass.const_expr(mRstd is not None)
+            else None
+        )
+        gMean = (
+            cute.local_tile(mMean, tiler_mn, (bidx, cluster_y))
+            if cutlass.const_expr(mMean is not None)
+            else None
+        )
+
+        # declare the atoms which will be used later for memory copy
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_load_X_async = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_load_W = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+        )
+        copy_atom_store_O = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
+            tidx
+        )
+        thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+        tWgW = thr_copy_W.partition_S(gW)
+        tXgX = thr_copy_X.partition_S(gX)
+        tXsX = thr_copy_X.partition_D(sX)
+        tXgO = thr_copy_O.partition_D(gO)
+        tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
+        tXrMean = thr_copy_O.partition_D(gMean) if cutlass.const_expr(mMean is not None) else None
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+        # allocate fragments for gmem->rmem
+        tWrW = cute.make_fragment_like(tWgW)
+        tXrW = thr_copy_X.retile(tWrW)
+        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+        row = tXcX[0][0]
+        if row < shape[0]:
+            cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+        cute.arch.cp_async_commit_group()
+
+        tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
+        if cutlass.const_expr(not delay_w_load):
+            cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+
+        cute.arch.cp_async_wait_group(0)
+        cute.autovec_copy(tXsX, tXrX)
+        x = tXrX.load().to(cute.Float32)
+        threads_per_row = tv_layout.shape[0][0]
+        sum_x = utils.row_reduce(
+            x,
+            cute.ReductionOp.ADD,
+            threads_per_row,
+            reduction_buffer[None, None, 0],
+            mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+            init_val=0.0,
+            hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+        )
+        mean = sum_x / shape[1]
+        if cutlass.const_expr(reload_from == "smem"):
+            cute.autovec_copy(tXsX, tXrX)
+            x = tXrX.load().to(cute.Float32)
+        elif cutlass.const_expr(reload_from == "gmem"):
+            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+            x = tXrX.load().to(cute.Float32)
+
+        sum_sq_x_sub_mean = utils.row_reduce(
+            (x - mean) * (x - mean),
+            cute.ReductionOp.ADD,
+            threads_per_row,
+            reduction_buffer[None, None, 1],
+            mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+            init_val=0.0,
+        )
+        rstd = utils.rsqrt(sum_sq_x_sub_mean / shape[1] + eps)
+        if cutlass.const_expr(mRstd is not None):
+            # Only the thread corresponding to column 0 writes out the rstd to gmem
+            if (
+                tXcX[0][1] == 0
+                and row < shape[0]
+                and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+            ):
+                tXrRstd[0] = rstd
+        if cutlass.const_expr(mMean is not None):
+            # Only the thread corresponding to column 0 writes out the mean to gmem
+            if (
+                tXcX[0][1] == 0
+                and row < shape[0]
+                and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+            ):
+                tXrMean[0] = mean
+        if cutlass.const_expr(delay_w_load):
+            cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+        if cutlass.const_expr(reload_from == "smem"):
+            cute.autovec_copy(tXsX, tXrX)
+            x = tXrX.load().to(cute.Float32)
+        elif cutlass.const_expr(reload_from == "gmem"):
+            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+            x = tXrX.load().to(cute.Float32)
+        x_hat = (x - mean) * rstd
+        w = tXrW.load().to(cute.Float32)
+        y = x_hat * w
+        tXrO.store(y.to(tXrO.element_type))
+        tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+        if row < shape[0]:
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+def layernorm(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    return_rstd: bool = False,
+    return_mean: bool = False,
+) -> torch.Tensor:
+    """LayerNorm forward pass.
+
+    Args:
+        x: Input tensor of shape (M, N)
+        weight: Weight tensor of shape (N,)
+        eps: Small value for numerical stability
+        return_rstd: Whether to return the reciprocal standard deviation
+        return_mean: Whether to return the mean
+
+    Returns:
+        Normalized output tensor of same shape as x
+        If return_rstd is True, also returns rstd tensor of shape (M,)
+        If return_mean is True, also returns mean tensor of shape (M,)
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert weight.dim() == 1, "Weight must be 1D"
+    assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+    assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+    assert weight.dtype == torch.float32, "Weight must be float32"
+    M, N = x.shape
+    device = x.device
+    out = torch.empty_like(x)
+    rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
+    mean = torch.empty(M, device=device, dtype=torch.float32) if return_mean else None
+    dtype = torch2cute_dtype_map[x.dtype]
+    convert_from_dlpack = lambda x: (
+        from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+    x_tensor, out_tensor = [
+        # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
+        convert_from_dlpack(t)
+        for t in (x, out)
+    ]
+    weight_tensor = utils.convert_from_dlpack(
+        weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+    )
+    rstd_tensor = (
+        from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+        if rstd is not None
+        else None
+    )
+    mean_tensor = (
+        from_dlpack(mean.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+        if mean is not None
+        else None
+    )
+    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+    compile_key = (dtype, N, rstd is not None, mean is not None)
+    if compile_key not in layernorm.compile_cache:
+        rmsnorm_op = LayerNorm(dtype, N)
+        layernorm.compile_cache[compile_key] = cute.compile(
+            rmsnorm_op,
+            x_tensor,
+            weight_tensor,
+            out_tensor,
+            rstd_tensor,
+            mean_tensor,
+            current_stream,
+        )
+    layernorm.compile_cache[compile_key](
+        x_tensor, weight_tensor, out_tensor, rstd_tensor, mean_tensor, current_stream, eps
+    )
+    return (
+        (out, rstd, mean)
+        if return_mean and return_rstd
+        else (
+            (out, rstd)
+            if return_rstd and not return_mean
+            else ((out, mean) if return_mean and not return_rstd else (out))
+        )
+    )
+
+
+layernorm.compile_cache = {}
+
+
+def layernorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
+    x_f32 = x.float()
+    return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)
+
+
+def rstd_ref(x: torch.Tensor, eps: float = 1e-6):
+    x_f32 = x.float()
+    mean = x_f32.mean(dim=-1, keepdim=True)
+    var = ((x_f32 - mean) ** 2).mean(dim=-1)
+    return 1.0 / torch.sqrt(var + eps)
+
+
+def mean_ref(x: torch.Tensor) -> torch.Tensor:
+    return x.float().mean(dim=-1)
quack/reduction_base.py
CHANGED
@@ -68,7 +68,7 @@ class ReductionBase:
 )

 def _allocate_reduction_buffer_and_mbar(
-self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
+self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout, is_persistent: bool = False
 ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
 reduction_buffer = smem.allocate_tensor(
 self.reduction_dtype,
@@ -76,20 +76,28 @@ class ReductionBase:
 byte_alignment=4,
 )
 if cutlass.const_expr(self.cluster_n > 1):
-mbar_ptr = smem.allocate_array(
+mbar_ptr = smem.allocate_array(
+cutlass.Int64, num_elems=self.stage if not is_persistent else self.stage * 2
+)
 else:
 mbar_ptr = None
 return reduction_buffer, mbar_ptr

 @cute.jit
-def _initialize_cluster(
+def _initialize_cluster(
+self,
+tidx: cutlass.Int32,
+mbar_ptr: cute.Pointer,
+num_warps: int,
+is_persistent: bool = False,
+):
 if cutlass.const_expr(self.cluster_n > 1):
-if tidx < self.stage:
+if tidx < self.stage:  # Initialize full barrier
 cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+if cutlass.const_expr(is_persistent):  # Initialize empty barrier
+cute.arch.mbarrier_init(
+mbar_ptr + self.stage + tidx, num_warps * self.cluster_n
+)
 cute.arch.mbarrier_init_fence()
-if tidx < self.stage:
-cute.arch.mbarrier_arrive_and_expect_tx(
-mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
-)
 # Cluster arrive after barrier init
 cute.arch.cluster_arrive_relaxed()
quack/rmsnorm.py
CHANGED
@@ -1,6 +1,5 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

-
 import torch
 from typing import Optional

@@ -117,7 +116,10 @@ class RMSNorm(ReductionBase):
 shape = mX.shape
 idX = cute.make_identity_tensor(shape)
 # slice for CTAs
-
+# We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
 gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
 gRstd = (
 cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
@@ -154,6 +156,7 @@ class RMSNorm(ReductionBase):

 # allocate fragments for gmem->rmem
 tWrW = cute.make_fragment_like(tWgW)
+tWrW.fill(0.0)
 tXrW = thr_copy_X.retile(tWrW)
 tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]

@@ -297,8 +300,14 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):

 class RMSNormBackward(ReductionBase):
 def __init__(self, dtype: cutlass.Numeric, N: int):
-#
-super().__init__(dtype, N, stage=
+# 2 stages for double buffering when computing mean of x_hat * wdy
+super().__init__(dtype, N, stage=2, reduction_dtype=cutlass.Float32)
+if self.N > 128 * 1024 and self.dtype.width >= 32:
+# Not enough smem
+raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
+
+def _get_num_threads(self):
+return 128 if self.N <= 4096 else 256

 def _calculate_threads_per_row(self):
 N = self.N
@@ -308,44 +317,38 @@ class RMSNormBackward(ReductionBase):
 else (
 16
 if N <= 128
-else (32 if N <=
+else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256)))
 )
 )

 def _set_cluster_n(self):
 N = self.N
-
-
-
-
-
-2
-if N <= 32 * 1024
-else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
-)
-)
-else: # fp32
-cluster_n = (
-1
-if N <= 32 * 1024
-else (
-2
-if N <= 64 * 1024
-else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
-)
-)
+cluster_n = (
+1
+if N <= 8 * 1024
+else (2 if N <= 16 * 1024 else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16)))
+)
 self.cluster_n = cluster_n

+def _smem_size_in_bytes(self, tiler_mn, num_warps):
+return (
+# Multiply by 2 since we need space for X and dOut,
+# and multiply by another 2 due to double buffering
+cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 * 2
++ self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
++ self.stage * (cutlass.Int64.width // 8) * 2 # mult 2 as we need 2 mbar per stage
+)
+
 @cute.jit
 def __call__(
 self,
 mX: cute.Tensor,
 mW: cute.Tensor,
-
+mdOut: cute.Tensor,
 mRstd: cute.Tensor,
-
-
-sm_count: cutlass.
+mdX: cute.Tensor,
+mdW: cute.Tensor,
+sm_count: cutlass.Int32,
 stream: cuda.CUstream,
 ):
 self._set_cluster_n()
@@ -356,14 +359,8 @@ class RMSNormBackward(ReductionBase):
 mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
 mW = cute.make_tensor(mW.iterator, mW_expanded_layout)

-
-
-
-num_blocks = (
-sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
-)
-
-self.kernel(mX, mW, mDout, mRstd, mDx, mDw, sm_count, tv_layout, tiler_mn).launch(
+num_blocks = sm_count
+self.kernel(mX, mW, mdOut, mRstd, mdX, mdW, tv_layout, tiler_mn).launch(
 grid=[num_blocks, self.cluster_n, 1],
 block=[num_threads, 1, 1],
 cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
@@ -376,177 +373,244 @@ class RMSNormBackward(ReductionBase):
 self,
 mX: cute.Tensor,
 mW: cute.Tensor,
-
+mdOut: cute.Tensor,
 mRstd: cute.Tensor,
-
-
-sm_count: cutlass.Constexpr,
+mdX: cute.Tensor,
+mdW: cute.Tensor,
 tv_layout: cute.Layout,
 tiler_mn: cute.Shape,
 ):
 tidx, _, _ = cute.arch.thread_idx()
-
+bidx_start, _, _ = cute.arch.block_idx()
 gdim, _, _ = cute.arch.grid_dim()
+if cutlass.const_expr(self.cluster_n > 1):
+cluster_y = cute.arch.block_idx()[1]
+else:
+cluster_y = cutlass.const_expr(0)

 shape = mX.shape
 M, N = shape[0], shape[1]
+is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)

 idX = cute.make_identity_tensor(shape)

 smem = cutlass.utils.SmemAllocator()
-
+smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2))
+sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
+sdOut = smem.allocate_tensor(mdOut.element_type, smem_layout, byte_alignment=16)
+reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+smem, tv_layout, is_persistent=True
+)
+if cutlass.const_expr(mbar_ptr is not None):
+mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
+else:
+mbar_full_ptr, mbar_empty_ptr = None, None

 copy_atom_load_X = cute.make_copy_atom(
 cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
 )
-
+copy_atom_load_X_async = cute.make_copy_atom(
+cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+)
 copy_atom_load_W = cute.make_copy_atom(
 cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
 )
-
 copy_atom_store_dX = cute.make_copy_atom(
-cute.nvgpu.CopyUniversalOp(),
+cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=128
 )
-
-
-cute.nvgpu.CopyUniversalOp(), mDw.element_type, num_bits_per_copy=128
+copy_atom_store_dW = cute.make_copy_atom(
+cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=128
 )

 thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+thr_copy_X_async = cute.make_tiled_copy(
+copy_atom_load_X_async, tv_layout, tiler_mn
+).get_slice(tidx)
 thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
-
-
+thr_copy_dW = cute.make_tiled_copy(copy_atom_store_dW, tv_layout, tiler_mn).get_slice(tidx)
+thr_store_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)

-gW = cute.local_tile(mW, tiler_mn, (
+gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
 tWgW = thr_copy_W.partition_S(gW)
 tWrW = cute.make_fragment_like(tWgW)
+# Need this, otherwise rW can have arbitrary values that changes the reduction
+if not is_even_N:
+tWrW.fill(0.0)
 tXrW = thr_copy_X.retile(tWrW)

-gW_coord = cute.local_tile(idX, tiler_mn, (0,
-
-
+gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+tWpW = (
+utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+if not is_even_N
+else None
+)
 cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
 weight = tXrW.load().to(cute.Float32)

 num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE

-self._initialize_cluster(tidx, mbar_ptr, num_warps)
-
-dw_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
-tDwpDw = utils.predicate_k(thr_copy_dw.partition_S(dw_coord), limit=shape[1])
+self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)

-
-
-
-
-
-
-M_pad = ((M + sm_count - 1) // sm_count) * sm_count
+dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+tdWpdW = (
+utils.predicate_k(thr_copy_dW.partition_S(dw_coord), limit=shape[1])
+if not is_even_N
+else None
+)

-
+gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
+tdWgdW = thr_copy_dW.partition_D(gdW)
+tdWrdW = cute.make_fragment_like(tdWgdW, cutlass.Float32)
+tXrdW = thr_copy_X.retile(tdWrdW)

-
-
-
+gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
+gdOut = cute.local_tile(mdOut, tiler_mn, (None, cluster_y))
+gdX = cute.local_tile(mdX, tiler_mn, (None, cluster_y))
+cX = cute.local_tile(idX, tiler_mn, (None, cluster_y))
+tXgX = thr_copy_X.partition_S(gX)
+tXsX = thr_copy_X.partition_D(sX)
+tXgdOut = thr_copy_X.partition_S(gdOut)
+tXsdOut = thr_copy_X.partition_D(sdOut)
+tXgdX = thr_store_dX.partition_D(gdX)
+tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
+# This doesn't change across iterations
+tXpX = (
+utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
+if not is_even_N
+else None
+)

-
-
-
-
+tXrX, tXrdOut, tXrdX = [
+cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdOut, tXgdX)
+]
+
+# Prefetch the first batch
+row = tXcX[None, None, None, bidx_start][0][0]
+if row < M:
+tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
+tXgdOut_cur = utils.coord_offset_i64(bidx_start, tXgdOut, dim=3)[None, None, None, 0]
+cute.copy(
+copy_atom_load_X_async,
+tXgX_cur,
+tXsX[None, None, None, 0],
+pred=tXpX,
 )
-
-
+cute.copy(
+copy_atom_load_X_async,
+tXgdOut_cur,
+tXsdOut[None, None, None, 0],
+pred=tXpX,
 )
-
-
-
-
-
-
-cX = cute.local_tile(
-idX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
-)
-
-tXgX = thr_copy_X.partition_S(gX)
-thrDout = thr_copy_X.partition_S(gDout)
-tXrRstd = thr_copy_W.partition_S(gRstd)
-thrDx = thr_store_dx.partition_D(gDx)
-tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
-tXrX, frgDout, frgDx = [cute.make_fragment_like(thr) for thr in (tXgX, thrDout, thrDx)]
-
-tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+elif tiler_mn[0] > 1:
+# Fill with zero, otherwise smem will be uninitialized, and we could read this back
+# later into registers, causing wrong dW.
+utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
+utils.fill_oob(tXsdOut[None, None, None, 0], None, fill_value=mdOut.element_type.zero)
+cute.arch.cp_async_commit_group()

-
-
-cute.copy(copy_atom_load_X, thrDout, frgDout, pred=tXpX)
+if cutlass.const_expr(self.cluster_n > 1):
+cute.arch.cluster_wait()

+threads_per_row = tv_layout.shape[0][0]
+tXrdW.fill(0.0)
+stage = cutlass.Int32(0)
+producer_phase = cutlass.Int32(1)
+consumer_phase = cutlass.Int32(0)
+for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
+row = tXcX[None, None, None, bidx][0][0]
+rstd = cutlass.Float.zero
+if row + gdim * tiler_mn[0] < M: # Prefetch the next batch
+tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
+tXgdOut_cur = utils.coord_offset_i64(bidx + gdim, tXgdOut, dim=3)[
+None, None, None, 0
+]
+cute.copy(
+copy_atom_load_X_async,
+tXgX_cur,
+tXsX[None, None, None, stage ^ 1],
+pred=tXpX,
+)
+cute.copy(
+copy_atom_load_X_async,
+tXgdOut_cur,
+tXsdOut[None, None, None, stage ^ 1],
+pred=tXpX,
+)
+elif tiler_mn[0] > 1:
+utils.fill_oob(
+tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero
+)
+utils.fill_oob(
+tXsdOut[None, None, None, stage ^ 1], None, fill_value=mdOut.element_type.zero
+)
+cute.arch.cp_async_commit_group()
+if row < M or tiler_mn[0] == 1:
+rstd = mRstd[row]
+cute.arch.cp_async_wait_group(1)
+cute.autovec_copy(tXsX[None, None, None, stage], tXrX)
 x = tXrX.load().to(cute.Float32)
-
-
-rstd = tXrRstd[0]
+cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
+dout = tXrdOut.load().to(cute.Float32)
 x_hat = x * rstd
 wdy = dout * weight
-
-threads_per_row = tv_layout.shape[0][0]
-
-row = tXcX[0][0]
 if cutlass.const_expr(self.cluster_n > 1):
-cute.arch.
-cute.arch.cluster_wait()
-else:
-cute.arch.barrier()
-
+cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
 mean_xhat_wdy = (
 utils.row_reduce(
 x_hat * wdy,
 cute.ReductionOp.ADD,
 threads_per_row,
-reduction_buffer[None, None,
-
+reduction_buffer[None, None, stage],
+mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None,
+phase=consumer_phase,
 init_val=0.0,
-hook_fn=cute.arch.cluster_wait
-if cutlass.const_expr(self.cluster_n > 1)
-else None,
 )
 / shape[1]
 )
-
-dx = (wdy - x_hat * mean_xhat_wdy) * rstd
-frgDx.store(dx.to(frgDout.element_type))
-
-if row < M:
-cute.copy(copy_atom_store_dX, frgDx, thrDx, pred=tXpX)
-
 if cutlass.const_expr(self.cluster_n > 1):
-
-
-
-cute.arch.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-if cutlass.const_expr(self.cluster_n > 1):
-cute.arch.
-
-
+# It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
+# Requires adjusting the thread_count when initializing the mbar
+cute.arch.sync_warp()
+lane_idx = cute.arch.lane_idx()
+if lane_idx < self.cluster_n:
+cute.arch.mbarrier_arrive(
+mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
+)
+dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+tXrdX.store(dx.to(tXrdOut.element_type))
+if row < M or tiler_mn[0] == 1:
+tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
+cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+tXrdW.store(tXrdW.load() + dout * x_hat)
+stage ^= 1
+if stage == 0:
+consumer_phase ^= 1
+producer_phase ^= 1
+
+if cutlass.const_expr(self.cluster_n > 1): # Prevent cluster from exiting early
+cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
+
+if cutlass.const_expr(tiler_mn[0] > 1):
+# reduction of dw_partial within the same threadblock
+sdW = cute.make_tensor(
+cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+)
+tXsdW = thr_copy_X.partition_D(sdW)
 cute.arch.barrier()
-
-
-
-
+row = tXcX[None, None, None, 0][0][0]
+if row > 0:
+cute.autovec_copy(tXrdW, tXsdW)
+cute.arch.barrier()
+if row == 0:
+for i in cutlass.range_constexpr(1, cutlass.const_expr(tiler_mn[0])):
+tXrdW_other = cute.make_fragment_like(tXrdW)
+tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
+cute.autovec_copy(tXsdW_other, tXrdW_other)
+tXrdW.store(tXrdW.load() + tXrdW_other.load())
+cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
+else:
+cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)


 def _rmsnorm_backward(
@@ -578,8 +642,19 @@ def _rmsnorm_backward(

 device = x.device

-
-
+# This should be tuned on how many CTAs can be launched on each SM
+sm_count_multiple = (
+16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+)
+sm_count = torch.cuda.get_device_properties(device).multi_processor_count
+# By right, if we're using cluster, this should be cluster_count not sm_count.
+# But for cluster >= 4, due to quantization we would need to query active max cluster.
+# Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
+# avoid wave quantization.
+sm_count = (
+sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
+)
+dw_partial = torch.empty(sm_count, N, device=device, dtype=weight.dtype)

 dtype = torch2cute_dtype_map[x.dtype]

@@ -622,6 +697,7 @@ def _rmsnorm_backward(
 rstd_tensor,
 dx_tensor,
 dw_partial_tensor,
+sm_count,
 current_stream,
 )

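The dataflow in the rewritten backward kernel is easier to follow against the plain math it implements. Below is a minimal PyTorch sketch of that math, reconstructed from the expressions visible in the hunk above (x_hat = x * rstd, wdy = dout * w, dx = (wdy - x_hat * mean(x_hat * wdy)) * rstd, and per-row dW contributions dout * x_hat); the function name is illustrative, not part of the package's API:

import torch

def rmsnorm_bwd_sketch(x, w, dout, rstd):
    # all math in fp32, mirroring the kernel's accumulation dtype
    x, dout = x.float(), dout.float()
    x_hat = x * rstd[:, None]
    wdy = dout * w
    mean_xhat_wdy = (x_hat * wdy).mean(dim=-1, keepdim=True)
    dx = (wdy - x_hat * mean_xhat_wdy) * rstd[:, None]
    dw = (dout * x_hat).sum(dim=0)  # the kernel instead accumulates per-CTA rows into dw_partial
    return dx, dw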
quack/softmax.py
CHANGED
@@ -98,7 +98,10 @@ class Softmax(ReductionBase):
 shape = mX.shape
 idX = cute.make_identity_tensor(shape)
 # slice for CTAs
-
+# We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))

 smem = cutlass.utils.SmemAllocator()
 sX = smem.allocate_tensor(
@@ -130,9 +133,7 @@ class Softmax(ReductionBase):

 is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
 tXpX = (
-utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
-if cutlass.const_expr(not is_even_N)
-else None
+utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
 )
 if tXcX[0][0] < shape[0]:
 cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
@@ -312,9 +313,11 @@ class SoftmaxBackward(ReductionBase):
 shape = mdY.shape
 idX = cute.make_identity_tensor(shape)
 # slice for CTAs
-
-
+mdY, mY, mdX = [
+utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mdY, mY, mdX)
 ]
+gdY, gY, gdX = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)]
+cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))

 smem = cutlass.utils.SmemAllocator()
 sdY = smem.allocate_tensor(
quack/utils.py
CHANGED
@@ -120,12 +120,20 @@ def cluster_reduce(
 reduction_buffer: cute.Tensor,
 mbar_ptr: cute.Pointer,
 init_val: cute.Numeric = 0.0,
+phase: Optional[cutlass.Int32] = None,
 ) -> cute.Numeric:
 """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
 cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
 lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-warps_per_row, cluster_n = reduction_buffer.shape
+rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
 row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+if warp_idx == 0:
+with cute.arch.elect_one():
+num_warps = rows_per_block * warps_per_row
+cute.arch.mbarrier_arrive_and_expect_tx(
+mbar_ptr,
+num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+)
 if lane_idx < cluster_n:
 store_shared_remote(
 val,
@@ -133,7 +141,7 @@ def cluster_reduce(
 mbar_ptr,
 peer_cta_rank_in_cluster=lane_idx,
 )
-cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
 block_reduce_val = init_val
 num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
 for i in cutlass.range_constexpr(num_iter):
@@ -149,13 +157,14 @@ def block_or_cluster_reduce(
 op: Callable,
 reduction_buffer: cute.Tensor,
 mbar_ptr: Optional[cute.Pointer],
+phase: Optional[cutlass.Int32] = None,
 init_val: cute.Numeric = 0.0,
 ) -> cute.Numeric:
 """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
 if cutlass.const_expr(mbar_ptr is None):
 return block_reduce(val, op, reduction_buffer, init_val=init_val)
 else:
-return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val)
+return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)


 @cute.jit
@@ -165,6 +174,7 @@ def row_reduce(
 threads_per_row: cutlass.Constexpr[int],
 reduction_buffer: Optional[cute.Tensor] = None,
 mbar_ptr: Optional[cute.Pointer] = None,
+phase: Optional[cutlass.Int32] = None,
 init_val: cute.Numeric = 0.0,
 hook_fn: Optional[Callable] = None,
 ) -> cute.Numeric:
@@ -193,7 +203,7 @@ def row_reduce(
 ), "mbar_ptr must be provided for cluster reduction"
 if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
 val = block_or_cluster_reduce(
-val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
+val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
 )
 return val

@@ -205,6 +215,7 @@ def online_softmax_reduce(
 reduction_buffer: Optional[cute.Tensor] = None,
 mbar_ptr: Optional[cute.Pointer] = None,
 hook_fn: Optional[Callable] = None,
+phase: Optional[cutlass.Int32] = None,
 return_exp_x: bool = False,
 ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
 assert x.dtype == Float32, "x must be of type Float32"
@@ -225,7 +236,7 @@ def online_softmax_reduce(
 if cutlass.const_expr(hook_fn is not None):
 hook_fn()
 if cutlass.const_expr(reduction_buffer is not None):
-warps_per_row, cluster_n = reduction_buffer.shape
+rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
 assert (
 cluster_n == 1 or mbar_ptr is not None
 ), "mbar_ptr must be provided for cluster reduction"
@@ -251,6 +262,13 @@ def online_softmax_reduce(
 max_x = max_x_final
 else:
 cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+if warp_idx == 0:
+with cute.arch.elect_one():
+num_warps = rows_per_block * warps_per_row
+cute.arch.mbarrier_arrive_and_expect_tx(
+mbar_ptr,
+num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+)
 if lane_idx < cluster_n:
 store_shared_remote(
 f32x2_to_i64(max_x, sum_exp_x),
@@ -258,7 +276,7 @@ def online_softmax_reduce(
 mbar_ptr,
 peer_cta_rank_in_cluster=lane_idx,
 )
-cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
 num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
 max_x_single_warp = cute.make_fragment(num_iter, Float32)
 max_x_single_warp.fill(-Float32.inf)
@@ -351,7 +369,7 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:


 @cute.jit
-def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) -> None:
+def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
 """Fill out-of-bounds values in shared memory tensor.

 Args:
@@ -361,9 +379,12 @@ def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) ->
 """
 tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
 tXrX_fill.fill(fill_value)
-for rest_v in cutlass.range_constexpr(
-for rest_k in cutlass.range_constexpr(
-if
+for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
+for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
+if cutlass.const_expr(tXpX is not None):
+if not tXpX[rest_v, 0, rest_k]:
+cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
+else:
 cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])


@@ -390,3 +411,38 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
 vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
 )
 return res0, res1
+
+
+@dsl_user_op
+def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
+flat_stride = cute.flatten_to_tuple(tensor.stride)
+assert len(flat_coord_i64) == len(
+flat_stride
+), "Coordinate and stride must have the same length"
+offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
+assert isinstance(tensor.iterator, cute.Pointer)
+# HACK: we assume that applying the offset does not change the pointer alignment
+new_ptr = cute.make_ptr(
+tensor.element_type,
+tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+tensor.memspace,
+assumed_align=tensor.iterator.max_alignment,
+)
+return cute.make_tensor(new_ptr, tensor.layout)
+
+
+@dsl_user_op
+def coord_offset_i64(
+idx: cute.typing.Int, tensor: cute.Tensor, dim: int, *, loc=None, ip=None
+) -> cute.Tensor:
+offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
+assert isinstance(tensor.iterator, cute.Pointer)
+# HACK: we assume that applying the offset does not change the pointer alignment
+new_ptr = cute.make_ptr(
+tensor.element_type,
+tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+tensor.memspace,
+assumed_align=tensor.iterator.max_alignment,
+)
+return cute.make_tensor(new_ptr, tensor.layout)
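In plain terms, the new helpers compute the element offset of a coordinate as a 64-bit dot product with the tensor's strides and rebase the pointer by that many elements, so the per-tile offsets used inside the kernels stay small. A rough pure-Python analogue of just that arithmetic (hypothetical helper, ignoring cute's pointer, memspace, and alignment machinery):

def domain_offset_i64_sketch(coord, strides, base_ptr, elem_bytes):
    # 64-bit element offset = sum(coordinate * stride), then a byte offset off the base pointer
    offset_elems = sum(int(c) * int(s) for c, s in zip(coord, strides))
    return base_ptr + offset_elems * elem_bytes

# e.g. skipping to row bidx * tiler_m of a row-major (M, N) fp16 tensor:
# new_base = domain_offset_i64_sketch((bidx * tiler_m, 0), (N, 1), base_ptr, 2)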
quack_kernels-0.1.7.dist-info/RECORD
ADDED
@@ -0,0 +1,12 @@
+quack/__init__.py,sha256=R9cZd_vslI5oZjjS-ojfWAd9tCZAqsLUiFVqEbUaGnw,203
+quack/cross_entropy.py,sha256=bg66wECki5I71SMPIRUa-6-oFJ93aIKpK1jqT__SCBM,19775
+quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
+quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
+quack/rmsnorm.py,sha256=3jiwWhVmaG0n5vuUnGGrpg3StAB4lnzziNF97QVMLGQ,28870
+quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
+quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
+quack_kernels-0.1.7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.1.7.dist-info/METADATA,sha256=9RlqUmX3-7BI2aZk88r84B8o2FzZkQgkfV1UxwN8GlE,289
+quack_kernels-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.1.7.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.1.7.dist-info/RECORD,,
quack_kernels-0.1.5.dist-info/RECORD
DELETED
@@ -1,11 +0,0 @@
-quack/__init__.py,sha256=GPoImcynY5-OkMep5RhQhXrnZyxgqZG3RoHhsYQFSL4,203
-quack/cross_entropy.py,sha256=WkngPY8uk4RCjCFtHtB7h9GF_8xt4NnyvDzvw73gIL4,19320
-quack/reduction_base.py,sha256=fFuGXPR3lDq2yw_m86ujmkni6R51jzNAzy_r9R6C8tA,3563
-quack/rmsnorm.py,sha256=N9NavrR85ws4cZgkfpeRLjYkVSq2yfyzJQWvfKf98pY,23935
-quack/softmax.py,sha256=VfhlC2huRuv7olFSVFgS8LF1yF8TFV64yjjjQxYX9yk,16364
-quack/utils.py,sha256=6EyWgf0z3wcbhGUivHmWB8hVBnEzMyOhmAuZ2Te82k0,15226
-quack_kernels-0.1.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-quack_kernels-0.1.5.dist-info/METADATA,sha256=WI-2CP1mRH05V9Fjdx7HsErNOkrc6fUhheoH4ynlo-U,289
-quack_kernels-0.1.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-quack_kernels-0.1.5.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
-quack_kernels-0.1.5.dist-info/RECORD,,
{quack_kernels-0.1.5.dist-info → quack_kernels-0.1.7.dist-info}/WHEEL
File without changes
{quack_kernels-0.1.5.dist-info → quack_kernels-0.1.7.dist-info}/licenses/LICENSE
File without changes
{quack_kernels-0.1.5.dist-info → quack_kernels-0.1.7.dist-info}/top_level.txt
File without changes