quack-kernels 0.1.1-py3-none-any.whl → 0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/cross_entropy.py +62 -62
- quack/rmsnorm.py +54 -71
- quack/softmax.py +54 -61
- quack/utils.py +57 -1
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.2.dist-info}/METADATA +1 -1
- quack_kernels-0.1.2.dist-info/RECORD +10 -0
- quack_kernels-0.1.1.dist-info/RECORD +0 -10
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.2.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.2.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/cross_entropy.py
CHANGED
@@ -1,7 +1,7 @@
 import math
 import torch
 import operator
-from typing import Callable, Union
+from typing import Callable, Union, Optional
 
 import cuda.bindings.driver as cuda
 
@@ -17,37 +17,29 @@ def cross_entropy_kernel(
     mX: cute.Tensor,  # (M, N)
     mTarget: cute.Tensor,  # (M,)
     mLoss: cute.Tensor,  # (M,)
+    mLSE: Optional[cute.Tensor],  # (M,)
     tv_layout: cute.Layout,
     tiler_mn: cute.Shape,
     cluster_n: cutlass.Constexpr = 1,
 ):
     tidx, _, _ = cute.arch.thread_idx()
     bidx, cluster_y, _ = cute.arch.block_idx()
-    gdim, _, _ = cute.arch.grid_dim()
 
     shape: cute.Shape = mX.shape
-    idX = cute.make_identity_tensor(
-
-
-
-
-
-    copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_scalar = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=gX.element_type.width)
-
-    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_X_async = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
+    idX = cute.make_identity_tensor(shape)
+    # slice for CTAs
+    gX, cX = [
+        cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
+        for mT in (mX, idX)
+    ]
 
     smem = cutlass.utils.SmemAllocator()
-
-    # Don't use blkX.layout here, because the stride is N, not N_rounded
-    sX = smem.allocate_tensor(gX.element_type, cute.make_ordered_layout(blkX.shape, order=(1, 0)), byte_alignment=16)
+    sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-
     reduction_buffer_layout = cute.make_ordered_layout(
         # 2 stages: 1 for max, 1 for sum
-        (num_warps // warps_per_row,
+        (num_warps // warps_per_row, (warps_per_row, cluster_n), 2),
         order=(1, 0, 2)
     )
     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
@@ -57,14 +49,15 @@ def cross_entropy_kernel(
     else:
         mbar_ptr = None
 
-
-
-
-
-    tXcX = thr_copy_X.partition_S(blkCrd)[(0, None), None, None]
+    # declare the atoms which will be used later for memory copy
+    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
+    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
 
-
-
+    #### Thread View
+    tXgX = thr_copy_X.partition_S(gX)
+    tXsX = thr_copy_X.partition_D(sX)
+    tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+    tXrX = cute.make_fragment_like(tXgX)
 
     if cluster_n > 1:
         if tidx < 2:
@@ -80,54 +73,57 @@ def cross_entropy_kernel(
     if row < shape[0] and tXcX[0][1] == 0:
         target = cute.Int32(mTarget[row])
 
-
-
-        tXpX[i] = cute.elem_less(tXcX[i][1], shape[1])
+    is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * cluster_n)
+    tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
     if row < shape[0]:
-        cute.copy(
+        cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
     cute.arch.cp_async_commit_group()
     cute.arch.cp_async_wait_group(0)
     cute.autovec_copy(tXsX, tXrX)
     x = tXrX.load().to(cute.Float32)
+    # Fill OOB values with -inf
+    if cutlass.const_expr(not is_even_N):
+        tXrX_fp32 = cute.make_fragment_like(tXrX, cutlass.Float32)
+        tXrX_fp32.store(x)
+        for rest_v in range(tXpX.shape[0]):
+            for rest_k in range(tXpX.shape[2]):
+                if not tXpX[rest_v, 0, rest_k]:
+                    tXrX_fp32[(None, rest_v), None, rest_k].fill(-cutlass.Float32.inf)
+        x = tXrX_fp32.load()
 
     target_logit = cute.Float32.zero
     if row < shape[0] and tXcX[0][1] == 0:
         target_logit = cute.Float32(mX[row, target])
 
-
-
-
-
+    threads_per_row = tv_layout.shape[0][0]
+    max_x = utils.row_reduce(
+        x,
+        cute.ReductionOp.MAX,
+        threads_per_row,
+        reduction_buffer[None, None, 0],
+        mbar_ptr + 0 if cluster_n > 1 else None,
+        init_val=-cutlass.Float32.inf,
+        hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
    )
-    if cutlass.const_expr(cluster_n > 1):
-        cute.arch.cluster_wait()
-    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-        max_mbar_ptr = mbar_ptr + 0 if cluster_n > 1 else None
-        max_x = utils.block_or_cluster_reduce(
-            max_x, cute.arch.fmax, reduction_buffer[None, None, 0], max_mbar_ptr, init_val=-cutlass.Float32.inf
-        )
     log2_e = math.log2(math.e)
     # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
     exp_x = utils.exp2f((x - max_x) * log2_e)  # a bit faster, idk why
-    denom = utils.
-        exp_x
-
-
+    denom = utils.row_reduce(
+        exp_x,
+        cute.ReductionOp.ADD,
+        threads_per_row,
+        reduction_buffer[None, None, 1],
+        mbar_ptr + 1 if cluster_n > 1 else None,
+        init_val=0.0,
    )
-    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-        sum_mbar_ptr = mbar_ptr + 1 if cluster_n > 1 else None
-        denom = utils.block_or_cluster_reduce(
-            denom, operator.add, reduction_buffer[None, None, 1], sum_mbar_ptr, init_val=0.0
-        )
 
-    if tXcX[0][1] == 0 and row < shape[0]:
+    if tXcX[0][1] == 0 and row < shape[0] and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0):
         ln_2 = math.log(2.0)
-
-
-
-
-
-        mLoss[row] = loss_val.to(mLoss.element_type)
+        lse = max_x + utils.log2f(denom) * ln_2
+        loss_val = lse - target_logit
+        mLoss[row] = loss_val.to(mLoss.element_type)
+        if cutlass.const_expr(mLSE is not None):
+            mLSE[row] = lse
 
 
 @cute.jit
@@ -135,6 +131,7 @@ def cross_entropy_interface(
     mX: cute.Tensor,
     mTarget: cute.Tensor,
     mLoss: cute.Tensor,
+    mLSE: Optional[cute.Tensor],
     stream: cuda.CUstream,
     N: cutlass.Constexpr,
     copy_bits: cutlass.Constexpr = 128
@@ -161,7 +158,7 @@ def cross_entropy_interface(
     )
 
     smem_allocated = cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + 2 * num_warps * cluster_n * (cutlass.Float32.width // 8) + 2 * (cutlass.Int64.width // 8)
-    cross_entropy_kernel(mX, mTarget, mLoss, tv_layout, tiler_mn, cluster_n).launch(
+    cross_entropy_kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn, cluster_n).launch(
         grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
         block=[cute.size(tv_layout, mode=[0]), 1, 1],
         # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
@@ -181,6 +178,7 @@ torch2cute_dtype_map = {
 def cross_entropy(
     x: torch.Tensor,
     target: torch.Tensor,
+    return_lse: bool = False,
 ) -> torch.Tensor:
     """Cross entropy forward pass.
 
@@ -199,7 +197,8 @@ def cross_entropy(
     assert target.dtype == torch.int64, "Target must be int64"
     M, N = x.shape
     device = x.device
-    loss = torch.empty(M, device=device, dtype=
+    loss = torch.empty(M, device=device, dtype=torch.float32)
+    lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
     dtype = torch2cute_dtype_map[x.dtype]
     convert_from_dlpack = lambda tensor: (
         from_dlpack(tensor.detach(), assumed_align=16)
@@ -207,15 +206,16 @@ def cross_entropy(
     )
     x_tensor, = [convert_from_dlpack(tensor) for tensor in (x,)]
     loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+    lse_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0) if lse is not None else None
     target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-    compile_key = (dtype, N)
+    compile_key = (dtype, N, lse_tensor is not None)
     if compile_key not in cross_entropy.compile_cache:
         cross_entropy.compile_cache[compile_key] = cute.compile(
-            cross_entropy_interface, x_tensor, target_tensor, loss_tensor, stream, N
+            cross_entropy_interface, x_tensor, target_tensor, loss_tensor, lse_tensor, stream, N
        )
-    cross_entropy.compile_cache[compile_key](x_tensor, target_tensor, loss_tensor, stream)
-    return loss
+    cross_entropy.compile_cache[compile_key](x_tensor, target_tensor, loss_tensor, lse_tensor, stream)
+    return loss if not return_lse else (loss, lse)
 
 
 cross_entropy.compile_cache = {}
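The user-visible change in this module is the new return_lse flag on the cross_entropy wrapper, which also returns the per-row log-sum-exp computed by the kernel. A minimal usage sketch (the shapes and the bfloat16 dtype are illustrative assumptions, not taken from the package; any dtype present in torch2cute_dtype_map should work):

import torch
from quack.cross_entropy import cross_entropy

# Hypothetical logits/targets on a CUDA device; M rows, N classes.
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
target = torch.randint(0, 4096, (8,), device="cuda", dtype=torch.int64)

loss = cross_entropy(x, target)                         # fp32 per-row loss, as in 0.1.1
loss, lse = cross_entropy(x, target, return_lse=True)   # new in 0.1.2: per-row logsumexp as well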
quack/rmsnorm.py
CHANGED
@@ -16,13 +16,11 @@ import quack.utils as utils
 
 @cute.kernel
 def rmsnorm_kernel(
-
-
-
-
-    cX: cute.Tensor,  # coordinate tensor
+    mX: cute.Tensor,
+    mW: cute.Tensor,
+    mO: cute.Tensor,
+    mRstd: cute.Tensor,
     eps: cute.Float32,
-    shape: cute.Shape,
     tv_layout: cute.Layout,
     tiler_mn: cute.Shape,
     cluster_n: cutlass.Constexpr = 1,
@@ -31,42 +29,45 @@ def rmsnorm_kernel(
 ):
     tidx, _, _ = cute.arch.thread_idx()
     bidx, cluster_y, _ = cute.arch.block_idx()
-    gdim, _, _ = cute.arch.grid_dim()
-
-    # slice for CTAs
-    # logical id -> address
-    blkX, blkOut, blkRstd, blkCrd = [gT[(None, None), bidx if cluster_n == 1 else (bidx, cluster_y)] for gT in (gX, gO, gRstd, cX)]
-    blkW = gW[(None, None), 0 if cluster_n == 1 else (0, cluster_y)]
-
-    # declare the atoms which will be used later for memory copy
-    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_load_W = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gW.element_type, num_bits_per_copy=128)
-    copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128)
-
-    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_X_async = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
 
     smem = cutlass.utils.SmemAllocator()
-
-    sX = smem.allocate_tensor(gX.element_type, cute.make_ordered_layout(blkX.shape, order=(1, 0)), byte_alignment=16)
+    sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-
-
+    reduction_buffer_layout = cute.make_ordered_layout(
+        (num_warps // warps_per_row, (warps_per_row, cluster_n)),
+        order=(1, 0)
+    )
     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
     if cutlass.const_expr(cluster_n > 1):
         mbar_ptr = smem.allocate(cutlass.Int64.width // 8, byte_alignment=8)
     else:
         mbar_ptr = None
 
-
-
-
-
-
+    shape = mX.shape
+    idX = cute.make_identity_tensor(shape)
+    # slice for CTAs
+    gX, gO, gRstd, cX = [
+        cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
+        for mT in (mX, mO, mRstd, idX)
+    ]
+    gW = cute.local_tile(mW, tiler_mn, (0, 0 if cluster_n == 1 else cluster_y))
+
+    # declare the atoms which will be used later for memory copy
+    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128)
+    copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128)
+    copy_atom_load_W = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128)
+    copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128)
+
+    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
+    thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+    thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+    tWgW = thr_copy_W.partition_S(gW)
+    tXgX = thr_copy_X.partition_S(gX)
+    tXsX = thr_copy_X.partition_D(sX)
+    tXgO, tXrRstd = [thr_copy_O.partition_D(gT) for gT in (gO, gRstd)]
+    tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
 
     # allocate fragments for gmem->rmem
     tWrW = cute.make_fragment_like(tWgW)
@@ -82,44 +83,33 @@ def rmsnorm_kernel(
     # Cluster arrive after barrier init
     cute.arch.cluster_arrive_relaxed()
 
-    tXpX =
-
-
-    # tXrX.fill(0.0)
-    if tXcX[0][0] < shape[0]:
-        # cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+    tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+    row = tXcX[0][0]
+    if row < shape[0]:
         cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
     cute.arch.cp_async_commit_group()
 
-    tWpW =
-    tWcX = thr_copy_W.partition_S(blkCrd)[(0, None), None, None]
-    for i in range(cute.size(tWpW)):
-        tWpW[i] = cute.elem_less(tWcX[i][1], shape[1])
+    tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
     if not delay_w_load:
         cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
 
     cute.arch.cp_async_wait_group(0)
     cute.autovec_copy(tXsX, tXrX)
     x = tXrX.load().to(cute.Float32)
-
-
-
-
+    threads_per_row = tv_layout.shape[0][0]
+    sum_sq_x = utils.row_reduce(
+        x * x,
+        cute.ReductionOp.ADD,
+        threads_per_row,
+        reduction_buffer,
+        mbar_ptr,
+        init_val=0.0,
+        hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
    )
-    if cutlass.const_expr(cluster_n > 1):
-        cute.arch.cluster_wait()
-    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-        sum_sq_x = utils.block_or_cluster_reduce(
-            sum_sq_x, operator.add, reduction_buffer, mbar_ptr, init_val=0.0
-        )
     rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
     # Only the thread corresponding to column 0 writes out the rstd to gmem
-    if tXcX[0][1] == 0 and
-
-        tXrRstd[0] = rstd
-    else:
-        if cute.arch.block_idx_in_cluster() == 0:
-            tXrRstd[0] = rstd
+    if tXcX[0][1] == 0 and row < shape[0] and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0):
+        tXrRstd[0] = rstd
     if delay_w_load:
         cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
     if reload_from == "smem":
@@ -132,20 +122,16 @@ def rmsnorm_kernel(
     w = tXrW.load().to(cute.Float32)
     y = x_hat * w
     tXrO.store(y.to(tXrO.element_type))
-
-
-    for i in range(cute.size(tOpO)):
-        tOpO[i] = cute.elem_less(tOcX[i][1], shape[1])
-    if tXcX[0][0] < shape[0]:
+    tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+    if row < shape[0]:
         cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
 
 
 @cute.jit
 def rmsnorm_interface(
-    # mX_: cute.Tensor,
     mX: cute.Tensor,
     mW: cute.Tensor,
-
+    mO: cute.Tensor,
     mRstd: cute.Tensor,
     stream: cuda.CUstream,
     N: cutlass.Constexpr,
@@ -180,21 +166,18 @@ def rmsnorm_interface(
     mW_expanded = cute.make_tensor(mW.iterator, mW_expanded_layout)
     mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((N,), stride=(0,)))
     mRstd_expanded = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-    idX = cute.make_identity_tensor(mX.shape)
-    gX, gW, gO, gRstd, cX = [cute.zipped_divide(mT, tiler_mn) for mT in (mX, mW_expanded, mOut, mRstd_expanded, idX)]  # ((TileM,TileN),(RestM,RestN))
 
     # reload_from = None if N <= 16384 else ("smem" if N <= 32768 else "gmem")
     reload_from = None if N <= 16384 else "smem"
     # delay_w_load = N > 64 * 1024
     delay_w_load = False
     N_rounded = tiler_mn[1]
-    rmsnorm_kernel(
-        grid=[cute.
+    rmsnorm_kernel(mX, mW_expanded, mO, mRstd_expanded, eps, tv_layout, tiler_mn, cluster_n, reload_from).launch(
+        grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
         block=[cute.size(tv_layout, mode=[0]), 1, 1],
         # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
         cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
-
-        smem=cute.size_in_bytes(mX.element_type, cute.make_layout(gX.shape[0])) + num_warps * cluster_n * (cutlass.Float32.width // 8) + (cutlass.Int64.width // 8),
+        smem=cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + num_warps * cluster_n * (cutlass.Float32.width // 8) + (cutlass.Int64.width // 8),
         stream=stream,
     )
 
quack/softmax.py
CHANGED
@@ -15,40 +15,30 @@ import quack.utils as utils
 
 @cute.kernel
 def softmax_kernel(
-
-
-    cX: cute.Tensor,  # coordinate tensor
-    shape: cute.Shape,
+    mX: cute.Tensor,
+    mO: cute.Tensor,
     tv_layout: cute.Layout,
     tiler_mn: cute.Shape,
     cluster_n: cutlass.Constexpr = 1,
 ):
     tidx, _, _ = cute.arch.thread_idx()
     bidx, cluster_y, _ = cute.arch.block_idx()
-    gdim, _, _ = cute.arch.grid_dim()
 
+    shape = mX.shape
+    idX = cute.make_identity_tensor(shape)
     # slice for CTAs
-
-
-
-
-    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128)
-
-    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_X_async = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+    gX, gO, cX = [
+        cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
+        for mT in (mX, mO, idX)
+    ]
 
     smem = cutlass.utils.SmemAllocator()
-
-    sX = smem.allocate_tensor(gX.element_type, cute.make_ordered_layout(blkX.shape, order=(1, 0)), byte_alignment=16)
+    sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-
     reduction_buffer_layout = cute.make_ordered_layout(
         # 2 stages: 1 for max, 1 for sum
-        (num_warps // warps_per_row,
+        (num_warps // warps_per_row, (warps_per_row, cluster_n), 2),
         order=(1, 0, 2)
     )
     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
@@ -58,10 +48,17 @@ def softmax_kernel(
     else:
         mbar_ptr = None
 
-
-
-
-
+    # declare the atoms which will be used later for memory copy
+    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128)
+    copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128)
+
+    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+    thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+    tXgX = thr_copy_X.partition_S(gX)
+    tXsX = thr_copy_X.partition_D(sX)
+    tXgO = thr_copy_O.partition_D(gO)
+    tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
 
     # allocate fragments for gmem->rmem
     tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
@@ -75,49 +72,48 @@ def softmax_kernel(
     # Cluster arrive after barrier init
     cute.arch.cluster_arrive_relaxed()
 
-
-
-        tXpX[i] = cute.elem_less(tXcX[i][1], shape[1])
-
+    is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * cluster_n)
+    tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
     if tXcX[0][0] < shape[0]:
-        cute.copy(
+        cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
     cute.arch.cp_async_commit_group()
     cute.arch.cp_async_wait_group(0)
 
     cute.autovec_copy(tXsX, tXrX)
     x = tXrX.load().to(cute.Float32)
-
-
-    cute.
-
+    # Fill OOB values with -inf
+    if cutlass.const_expr(not is_even_N):
+        tXrX_fp32 = cute.make_fragment_like(tXrX, cutlass.Float32)
+        tXrX_fp32.store(x)
+        for rest_v in range(tXpX.shape[0]):
+            for rest_k in range(tXpX.shape[2]):
+                if not tXpX[rest_v, 0, rest_k]:
+                    tXrX_fp32[(None, rest_v), None, rest_k].fill(-cutlass.Float32.inf)
+        x = tXrX_fp32.load()
+    threads_per_row = tv_layout.shape[0][0]
+    max_x = utils.row_reduce(
+        x,
+        cute.ReductionOp.MAX,
+        threads_per_row,
+        reduction_buffer[None, None, 0],
+        mbar_ptr + 0 if cluster_n > 1 else None,
+        init_val=-cutlass.Float32.inf,
+        hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
    )
-    if cutlass.const_expr(cluster_n > 1):
-        cute.arch.cluster_wait()
-    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-        max_mbar_ptr = mbar_ptr + 0 if cluster_n > 1 else None
-        max_x = utils.block_or_cluster_reduce(
-            max_x, cute.arch.fmax, reduction_buffer[None, None, 0], max_mbar_ptr, init_val=-cutlass.Float32.inf
-        )
     log2_e = math.log2(math.e)
     exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-    denom = utils.
-        exp_x
-
-
+    denom = utils.row_reduce(
+        exp_x,
+        cute.ReductionOp.ADD,
+        threads_per_row,
+        reduction_buffer[None, None, 1],
+        mbar_ptr + 1 if cluster_n > 1 else None,
+        init_val=0.0,
    )
-    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-        sum_mbar_ptr = mbar_ptr + 1 if cluster_n > 1 else None
-        denom = utils.block_or_cluster_reduce(
-            denom, operator.add, reduction_buffer[None, None, 1], sum_mbar_ptr, init_val=0.0
-        )
     inv = 1.0 / denom
     y = exp_x * inv
-
     tXrO.store(y.to(tXrO.element_type))
-
-    tOpO = cute.make_fragment_like(tXgO[(0, None), None, None], cutlass.Boolean)
-    for i in range(cute.size(tOpO)):
-        tOpO[i] = cute.elem_less(tOcX[i][1], shape[1])
+    tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
     if tXcX[0][0] < shape[0]:
         cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
 
@@ -125,7 +121,7 @@ def softmax_kernel(
 @cute.jit
 def softmax_interface(
     mX: cute.Tensor,
-
+    mO: cute.Tensor,
     stream: cuda.CUstream,
     N: cutlass.Constexpr,
     copy_bits: cutlass.Constexpr = 128
@@ -149,12 +145,9 @@ def softmax_interface(
         stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
     )
 
-
-
-
-    smem_allocated = cute.size_in_bytes(mX.element_type, cute.make_layout(gX.shape[0])) + 2 * num_warps * cluster_n * (cutlass.Float32.width // 8) + 2 * (cutlass.Int64.width // 8)
-    softmax_kernel(gX, gO, cX, mX.shape, tv_layout, tiler_mn, cluster_n).launch(
-        grid=[cute.size(gX, mode=[1, 0]), cluster_n, 1],
+    smem_allocated = cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + 2 * num_warps * cluster_n * (cutlass.Float32.width // 8) + 2 * (cutlass.Int64.width // 8)
+    softmax_kernel(mX, mO, tv_layout, tiler_mn, cluster_n).launch(
+        grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
         block=[cute.size(tv_layout, mode=[0]), 1, 1],
        # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
        cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
quack/utils.py
CHANGED
@@ -1,5 +1,6 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
+import operator
 import math
 from typing import Type, Callable, Optional
 
@@ -57,7 +58,7 @@ def block_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor,
     """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)
     """
     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-    warps_per_row = reduction_buffer.shape[1]
+    warps_per_row = cute.size(reduction_buffer.shape[1])
     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
     if lane_idx == 0:
         reduction_buffer[row_idx, col_idx] = val
@@ -142,6 +143,46 @@ def block_or_cluster_reduce(val: cute.Numeric, op: Callable, reduction_buffer: c
     return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val)
 
 
+@cute.jit
+def row_reduce(
+    x: cute.TensorSSA | cute.Numeric,
+    op: cute.ReductionOp,
+    threads_per_row: cutlass.Constexpr[int],
+    reduction_buffer: Optional[cute.Tensor] = None,
+    mbar_ptr: Optional[cute.Pointer] = None,
+    init_val: cute.Numeric = 0.0,
+    hook_fn: Optional[Callable] = None,
+) -> cute.Numeric:
+    """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))
+    """
+    if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
+        val = x.reduce(op, init_val=init_val, reduction_profile=0)
+    else:
+        val = x
+    warp_op = {
+        cute.ReductionOp.ADD: operator.add,
+        cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == cute.Float32) else max,
+        cute.ReductionOp.MIN: min,
+        cute.ReductionOp.MUL: operator.mul,
+    }[op]
+    val = warp_reduce(
+        val,
+        warp_op,
+        width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+    )
+    if cutlass.const_expr(hook_fn is not None):
+        hook_fn()
+    if cutlass.const_expr(reduction_buffer is not None):
+        warps_per_row, cluster_n = reduction_buffer.shape[1]
+        assert cluster_n == 1 or mbar_ptr is not None, "mbar_ptr must be provided for cluster reduction"
+        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+            val = block_or_cluster_reduce(
+                val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
+            )
+    return val
+
+
+
 def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float32:
     """exp2f calculation for both vector and scalar.
 
@@ -188,3 +229,18 @@ def rsqrt(a: float | cute.Float32, *, loc=None, ip=None) -> cute.Float32:
             asm_dialect=llvm.AsmDialect.AD_ATT,
         )
     )
+
+
+def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
+    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
+    tApA = cute.make_fragment(
+        cute.make_layout(
+            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
+            stride=(cute.size(tAcA, mode=[2]), 0, 1),
+        ),
+        cutlass.Boolean,
+    )
+    for rest_v in range(tApA.shape[0]):
+        for rest_k in range(tApA.shape[2]):
+            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
+    return tApA
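The new utils.row_reduce helper is what lets the three kernels above drop their separate cluster_wait / block_or_cluster_reduce branches: it performs the in-thread TensorSSA.reduce, a warp_reduce, an optional hook_fn, and then the block/cluster step through the shared-memory buffer only when warps_per_row > 1 or cluster_n > 1. A rough sketch of a call site, mirroring the max reduction in the cross_entropy and softmax hunks (the surrounding reduction_buffer, mbar_ptr, and tv_layout setup is assumed to match those kernels):

threads_per_row = tv_layout.shape[0][0]
max_x = utils.row_reduce(
    x,                                  # cute.TensorSSA holding this thread's slice of the row
    cute.ReductionOp.MAX,
    threads_per_row,
    reduction_buffer[None, None, 0],    # stage 0 of the (rows, (warps_per_row, cluster_n), 2) buffer
    mbar_ptr + 0 if cluster_n > 1 else None,
    init_val=-cutlass.Float32.inf,
    hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None,
)

Passing reduction_buffer=None stops after the warp shuffle, and hook_fn runs immediately after it, which is how the kernels overlap the cluster-wait with the warp-level reduction.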
quack_kernels-0.1.2.dist-info/RECORD
ADDED
@@ -0,0 +1,10 @@
+quack/__init__.py,sha256=Nf01m1CGrOjSkqGJom6P65hSLkckljRMhlkSoqqlO9k,137
+quack/cross_entropy.py,sha256=gdo8sR9KT5TsrShbgAmy-bwRZLu0gTs_ykXBF2RMbFI,8900
+quack/rmsnorm.py,sha256=JhwJSAPDDpB_hV90xU9ymiLU-zu4WScrSHc5JX2JarY,10470
+quack/softmax.py,sha256=C8e8ZNaF5ePJ1NlrWZN1goCcvsx1C60FWlRyuFCcYoM,7737
+quack/utils.py,sha256=PRdu-P7azA_PeHUNdtoy1zyxZwg_QyVrSiVwE1iXaWo,8961
+quack_kernels-0.1.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.1.2.dist-info/METADATA,sha256=3WjugLu1IhLlgsg2qUcLBZq1HI4-BIyyJIuQc5Hk-rU,186
+quack_kernels-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.1.2.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.1.2.dist-info/RECORD,,
quack_kernels-0.1.1.dist-info/RECORD
DELETED
@@ -1,10 +0,0 @@
-quack/__init__.py,sha256=y3Oa4OVPqaGU_P1miI435DzfpMgIwKVmU8-Eogv58jQ,137
-quack/cross_entropy.py,sha256=V0kG8DCNh2735sPIDwe68NB50rAqDF3XQApnGyo-sKg,9220
-quack/rmsnorm.py,sha256=RNqcT-q4uvMbF6ejpzuqQH8l8VVuTRlnueXf28V47sc,11954
-quack/softmax.py,sha256=QABgOESH5JjDm3yuUkyZZKXXpzn7CTuMSs0NEBnFD80,8536
-quack/utils.py,sha256=ofV7QLDuq80h3nEA3TwZW-ti8CnYwMgnz1dpxpvhHpk,6859
-quack_kernels-0.1.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-quack_kernels-0.1.1.dist-info/METADATA,sha256=XG3zS0_q48TzkoR7CemzaJGVYHS731yVOrzH49_uRK8,186
-quack_kernels-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-quack_kernels-0.1.1.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
-quack_kernels-0.1.1.dist-info/RECORD,,
{quack_kernels-0.1.1.dist-info → quack_kernels-0.1.2.dist-info}/WHEEL
File without changes

{quack_kernels-0.1.1.dist-info → quack_kernels-0.1.2.dist-info}/licenses/LICENSE
File without changes

{quack_kernels-0.1.1.dist-info → quack_kernels-0.1.2.dist-info}/top_level.txt
File without changes