quack-kernels 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- quack/__init__.py +1 -1
- quack/cross_entropy.py +56 -15
- quack/rmsnorm.py +232 -106
- {quack_kernels-0.1.7.dist-info → quack_kernels-0.1.9.dist-info}/METADATA +1 -1
- quack_kernels-0.1.9.dist-info/RECORD +12 -0
- quack_kernels-0.1.7.dist-info/RECORD +0 -12
- {quack_kernels-0.1.7.dist-info → quack_kernels-0.1.9.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.7.dist-info → quack_kernels-0.1.9.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.7.dist-info → quack_kernels-0.1.9.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/cross_entropy.py
CHANGED
@@ -1,16 +1,16 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
 import math
-import torch
 from typing import Optional, Type
 
 import cuda.bindings.driver as cuda
 
 import cutlass
 import cutlass.cute as cute
-from cutlass.cute.runtime import from_dlpack
 
 import quack.utils as utils
+import torch
+from cutlass.cute.runtime import from_dlpack
 from quack.reduction_base import ReductionBase, torch2cute_dtype_map
 
 
@@ -79,7 +79,7 @@ class CrossEntropy(ReductionBase):
         self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+            cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
             stream=stream,
         )
@@ -111,7 +111,9 @@ class CrossEntropy(ReductionBase):
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
-            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+            mX.element_type,
+            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+            byte_alignment=16,
         )
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
 
@@ -166,7 +168,9 @@ class CrossEntropy(ReductionBase):
             reduction_buffer[None, None, 0],
             mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
             init_val=-cutlass.Float32.inf,
-            hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+            hook_fn=(
+                cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+            ),
         )
         if cutlass.const_expr(self.reload_from == "smem"):
             cute.autovec_copy(tXsX, tXrX)
@@ -191,7 +195,9 @@ class CrossEntropy(ReductionBase):
             threads_per_row,
             reduction_buffer[None, None, 0],
             mbar_ptr,
-            hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+            hook_fn=(
+                cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+            ),
         )
 
         if (
@@ -225,7 +231,11 @@ def _cross_entropy(
     assert target.dim() == 1, "Target must be 1D"
     assert x.shape[0] == target.shape[0], "Batch dimensions must match"
     assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+    assert x.dtype in [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ], "Unsupported input dtype"
     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
     M, N = x.shape
     device = x.device
@@ -314,13 +324,16 @@ class CrossEntropyBackward:
         num_threads = cute.size(tv_layout, mode=[0])
 
         mDLoss = cute.make_tensor(
-            mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+            mDLoss.iterator,
+            cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,))),
         )
         mTarget = cute.make_tensor(
-            mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+            mTarget.iterator,
+            cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
         )
         mLSE = cute.make_tensor(
-            mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+            mLSE.iterator,
+            cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
         )
 
         smem_size = cute.size_in_bytes(
@@ -364,7 +377,9 @@ class CrossEntropyBackward:
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
-            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+            mX.element_type,
+            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+            byte_alignment=16,
        )
 
         idX = cute.make_identity_tensor(shape)
@@ -474,7 +489,11 @@ def _cross_entropy_backward(
     assert (
         x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
     ), "Tensors must be on CUDA device"
-    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+    assert x.dtype in [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ], "Unsupported input dtype"
     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
 
     M, N = x.shape
@@ -532,15 +551,37 @@ class CrossEntropyFunction(torch.autograd.Function):
 
 
 def cross_entropy(
-    x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = True
+    x: torch.Tensor,
+    target: torch.Tensor,
+    inplace_backward: bool = True,
+    reduction: str = "none",
 ) -> torch.Tensor:
     """Cross entropy loss with automatic differentiation support.
 
     Args:
         x: Input logits tensor of shape (M, N)
         target: Target class indices tensor of shape (M,)
+        inplace_backward: Whether to perform backward pass in-place
+        reduction: Specifies the reduction to apply to the output:
+            'none': no reduction will be applied (default)
+            'mean': the sum of the output will be divided by the number of elements
+            'sum': the output will be summed
 
     Returns:
-        Cross entropy loss tensor
+        Cross entropy loss tensor:
+        - If reduction='none': tensor of shape (M,) with per-example losses
+        - If reduction='mean': scalar tensor with mean loss
+        - If reduction='sum': scalar tensor with sum of losses
     """
-    return CrossEntropyFunction.apply(x, target, inplace_backward)
+    loss = CrossEntropyFunction.apply(x, target, inplace_backward)
+
+    if reduction == "mean":
+        return loss.mean()
+    elif reduction == "sum":
+        return loss.sum()
+    elif reduction == "none":
+        return loss
+    else:
+        raise ValueError(
+            f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'"
+        )
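The main user-facing change in cross_entropy.py is the new `reduction` argument, which mirrors `torch.nn.functional.cross_entropy`. A minimal usage sketch against the 0.1.9 signature (shapes, dtypes, and values below are illustrative, and assume quack-kernels 0.1.9 plus a CUDA device):

    import torch
    from quack.cross_entropy import cross_entropy

    x = torch.randn(8, 32768, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    target = torch.randint(0, 32768, (8,), device="cuda")  # int64 targets are accepted

    per_row = cross_entropy(x, target)                 # reduction="none" (default): shape (8,)
    loss = cross_entropy(x, target, reduction="mean")  # scalar
    loss.backward()                                    # inplace_backward=True by default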
quack/rmsnorm.py
CHANGED
@@ -1,14 +1,16 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
-import torch
 from typing import Optional
 
 import cuda.bindings.driver as cuda
 
 import cutlass
 import cutlass.cute as cute
+from cutlass import Float32, Int32
 from cutlass.cute.runtime import from_dlpack
+
 import quack.utils as utils
+import torch
 from quack.reduction_base import ReductionBase, torch2cute_dtype_map
 
 
@@ -19,41 +21,55 @@ class RMSNorm(ReductionBase):
         self.delay_w_load = False
 
     def _calculate_threads_per_row(self):
+        """Calculate the number of threads per row for the RMSNorm kernel."""
         N = self.N
-
-            8
-
-
-
-
-
-
-
+        if N <= 64:
+            return 8
+        elif N <= 128:
+            return 16
+        elif N <= 3072:
+            return 32
+        elif N <= 6144:
+            return 64
+        elif N <= 16384:
+            return 128
+        else:
+            return 256
 
     def _set_cluster_n(self):
+        """
+        Set the number of clusters for the RMSNorm kernel.
+        Stored in self.cluster_n.
+        """
         N = self.N
+
         # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
         # Similarly cluster_n = 8 is faster for N=128k
         if cutlass.const_expr(self.dtype.width == 16):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # 16-bit types (fp16, bf16)
+            if N <= 16 * 1024:
+                cluster_n = 1
+            elif N <= 32 * 1024:
+                cluster_n = 2
+            elif N <= 64 * 1024:
+                cluster_n = 4
+            elif N <= 128 * 1024:
+                cluster_n = 8
+            else:
+                cluster_n = 16
+        else:
+            # 32-bit types (fp32)
+            if N <= 32 * 1024:
+                cluster_n = 1
+            elif N <= 64 * 1024:
+                cluster_n = 2
+            elif N <= 128 * 1024:
+                cluster_n = 4
+            elif N <= 256 * 1024:
+                cluster_n = 8
+            else:
+                cluster_n = 16
+
         self.cluster_n = cluster_n
 
     @cute.jit
@@ -64,8 +80,17 @@ class RMSNorm(ReductionBase):
         mO: cute.Tensor,
         mRstd: Optional[cute.Tensor],
         stream: cuda.CUstream,
-        eps:
+        eps: Float32 = 1e-6,
     ):
+        semistatic_shape = (*mX.shape[:-1], self.N)  # Set last dimension to be statically N
+        new_stride = lambda t: (
+            cute.assume(t.stride[0], divby=128 // t.element_type.width),
+            t.stride[1],
+        )
+        mX, mO = [
+            cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+            for t in (mX, mO)
+        ]
         assert mX.element_type == self.dtype
         assert mO.element_type == self.dtype
         self._set_cluster_n()
@@ -82,7 +107,7 @@ class RMSNorm(ReductionBase):
         self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+            cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
             stream=stream,
         )
@@ -109,7 +134,9 @@ class RMSNorm(ReductionBase):
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
-            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+            mX.element_type,
+            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+            byte_alignment=16,
         )
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
 
@@ -134,30 +161,33 @@ class RMSNorm(ReductionBase):
         copy_atom_load_X_async = cute.make_copy_atom(
             cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
         )
+        num_bits_per_copy_W = cutlass.const_expr(
+            min(128, 128 // mX.element_type.width * mW.element_type.width)
+        )
         copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_bits_per_copy_W
+        )
+        num_bits_per_copy_O = cutlass.const_expr(
+            min(128, 128 // mX.element_type.width * mO.element_type.width)
         )
         copy_atom_store_O = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_bits_per_copy_O
         )
 
         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
             tidx
         )
-        thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
-        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
 
-
+        tXgW = thr_copy_X.partition_S(gW)
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
-        tXgO =
-        tXrRstd =
+        tXgO = thr_copy_X.partition_D(gO)
+        tXrRstd = thr_copy_X.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
 
         # allocate fragments for gmem->rmem
-
-
-        tXrW = thr_copy_X.retile(tWrW)
+        tXrW = cute.make_fragment_like(tXgW)
+        tXrW.fill(0.0)
         tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
 
         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
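The `num_bits_per_copy_*` expressions added above size each vectorized copy so that every thread moves the same number of elements as the 128-bit input copy (`128 // mX.element_type.width` elements), scaled by the other tensor's element width and capped at 128 bits. A plain-Python restatement of the formula (the helper name is illustrative only):

    def bits_per_copy(x_width: int, other_width: int) -> int:
        # elements per 128-bit copy of x, times the other tensor's element width, capped at 128
        return min(128, 128 // x_width * other_width)

    assert bits_per_copy(16, 16) == 128  # bf16 input, bf16 weight: 8 elements x 16 bits
    assert bits_per_copy(16, 32) == 128  # bf16 input, fp32 weight: capped (8 x 32 = 256)
    assert bits_per_copy(32, 16) == 64   # fp32 input, bf16 weight: 4 elements x 16 bits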
@@ -169,9 +199,9 @@ class RMSNorm(ReductionBase):
         cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
         cute.arch.cp_async_commit_group()
 
-
+        tXpW = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
         if cutlass.const_expr(not delay_w_load):
-            cute.copy(copy_atom_load_W,
+            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
 
         cute.arch.cp_async_wait_group(0)
         cute.autovec_copy(tXsX, tXrX)
@@ -184,7 +214,7 @@ class RMSNorm(ReductionBase):
             reduction_buffer[None, None, 0],
             mbar_ptr,
             init_val=0.0,
-            hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+            hook_fn=(cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None),
         )
         rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
         if cutlass.const_expr(mRstd is not None):
@@ -196,7 +226,7 @@ class RMSNorm(ReductionBase):
         ):
             tXrRstd[0] = rstd
         if cutlass.const_expr(delay_w_load):
-            cute.copy(copy_atom_load_W,
+            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
         if cutlass.const_expr(reload_from == "smem"):
             cute.autovec_copy(tXsX, tXrX)
         x = tXrX.load().to(cute.Float32)
@@ -207,9 +237,9 @@ class RMSNorm(ReductionBase):
         w = tXrW.load().to(cute.Float32)
         y = x_hat * w
         tXrO.store(y.to(tXrO.element_type))
-
+        tXpO = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
         if row < shape[0]:
-            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpO)
 
 
 def _rmsnorm_fwd(
@@ -232,25 +262,36 @@ def _rmsnorm_fwd(
     assert weight.dim() == 1, "Weight must be 1D"
     assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
     assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
-
+    assert x.dtype in [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ], "Unsupported dtype"
+
+    assert weight.dtype in [
+        torch.float32,
+        torch.bfloat16,
+        torch.float16,
+    ], "Weight must be float32, float16 or bfloat16"
+
     M, N = x.shape
     device = x.device
     out = torch.empty_like(x)
     rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
     dtype = torch2cute_dtype_map[x.dtype]
+    # convert_from_dlpack = lambda x: (
+    #     from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+    #         mode=0, divisibility=128 // dtype.width
+    #     )
+    # )
     convert_from_dlpack = lambda x: (
-        from_dlpack(x.detach(), assumed_align=16).
-            mode=0, stride_order=(0, 1)
-        )
+        from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
     )
-    x_tensor, out_tensor = [
-
-
-        for t in (x, out)
-    ]
+    x_tensor, out_tensor = [convert_from_dlpack(t) for t in (x, out)]
+    # handle weight divisibility based on weight dtype
+    weight_dtype = torch2cute_dtype_map[weight.dtype]
     weight_tensor = utils.convert_from_dlpack(
-        weight.detach(), leading_dim=0, divisibility=128 //
+        weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
     )
     rstd_tensor = (
         from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
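Together with the `weight.dtype` entry added to the compile key in the next hunk, these relaxed assertions let the forward path take a weight whose dtype differs from the input's. A hedged sketch of a call that the new checks appear to allow (sizes and values are illustrative):

    import torch
    from quack.rmsnorm import rmsnorm

    x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
    w = torch.ones(8192, device="cuda", dtype=torch.float32)  # fp32 weight with bf16 input
    out = rmsnorm(x, w)  # one kernel is compiled per (input dtype, N, weight dtype)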
@@ -258,7 +299,7 @@ def _rmsnorm_fwd(
         else None
     )
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-    compile_key = (dtype, N, rstd is not None)
+    compile_key = (dtype, N, rstd is not None, weight.dtype)
    if compile_key not in _rmsnorm_fwd.compile_cache:
        rmsnorm_op = RMSNorm(dtype, N)
        _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
@@ -301,7 +342,8 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
 class RMSNormBackward(ReductionBase):
     def __init__(self, dtype: cutlass.Numeric, N: int):
         # 2 stages for double buffering when computing mean of x_hat * wdy
-        super().__init__(dtype, N, stage=2, reduction_dtype=
+        super().__init__(dtype, N, stage=2, reduction_dtype=Float32)
+        self.reload_wdy = None if N <= 16 * 1024 else "smem"
         if self.N > 128 * 1024 and self.dtype.width >= 32:
             # Not enough smem
             raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
@@ -348,9 +390,18 @@ class RMSNormBackward(ReductionBase):
         mRstd: cute.Tensor,
         mdX: cute.Tensor,
         mdW: cute.Tensor,
-        sm_count:
+        sm_count: Int32,
         stream: cuda.CUstream,
     ):
+        semistatic_shape = (*mX.shape[:-1], self.N)  # Set last dimension to be statically N
+        new_stride = lambda t: (
+            cute.assume(t.stride[0], divby=128 // t.element_type.width),
+            t.stride[1],
+        )
+        mX, mdOut, mdX = [
+            cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+            for t in (mX, mdOut, mdX)
+        ]
         self._set_cluster_n()
         tiler_mn, tv_layout = self._get_tv_layout()
         num_threads = cute.size(tv_layout, mode=[0])
@@ -412,39 +463,41 @@ class RMSNormBackward(ReductionBase):
         copy_atom_load_X_async = cute.make_copy_atom(
             cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
         )
+        num_bits_per_copy_W = cutlass.const_expr(
+            min(128, 128 // mX.element_type.width * mW.element_type.width)
+        )
         copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_bits_per_copy_W
+        )
+        num_bits_per_copy_dX = cutlass.const_expr(
+            min(128, 128 // mX.element_type.width * mdX.element_type.width)
         )
         copy_atom_store_dX = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_bits_per_copy_dX
+        )
+        num_bits_per_copy_dW = cutlass.const_expr(
+            min(128, 128 // mX.element_type.width * mdW.element_type.width)
         )
         copy_atom_store_dW = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=
+            cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_bits_per_copy_dW
         )
 
         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-        thr_copy_X_async = cute.make_tiled_copy(
-            copy_atom_load_X_async, tv_layout, tiler_mn
-        ).get_slice(tidx)
-        thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
-        thr_copy_dW = cute.make_tiled_copy(copy_atom_store_dW, tv_layout, tiler_mn).get_slice(tidx)
-        thr_store_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
 
         gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
-
-
+        tXgW = thr_copy_X.partition_S(gW)
+        tXrW = cute.make_fragment_like(tXgW)
         # Need this, otherwise rW can have arbitrary values that changes the reduction
         if not is_even_N:
-
-            tXrW = thr_copy_X.retile(tWrW)
+            tXrW.fill(0.0)
 
         gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
-
-        utils.predicate_k(
+        tXpW = (
+            utils.predicate_k(thr_copy_X.partition_S(gW_coord), limit=shape[1])
             if not is_even_N
             else None
         )
-        cute.copy(copy_atom_load_W,
+        cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
         weight = tXrW.load().to(cute.Float32)
 
         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
@@ -452,16 +505,16 @@ class RMSNormBackward(ReductionBase):
         self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
 
         dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
-
-        utils.predicate_k(
+        tXpdW = (
+            utils.predicate_k(thr_copy_X.partition_S(dw_coord), limit=shape[1])
             if not is_even_N
             else None
         )
 
         gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
-
-
-        tXrdW =
+        tXgdW = thr_copy_X.partition_S(gdW)
+        # Always compute partial weight gradients in fp32
+        tXrdW = cute.make_fragment_like(tXgdW, Float32)
 
         gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
         gdOut = cute.local_tile(mdOut, tiler_mn, (None, cluster_y))
@@ -471,7 +524,7 @@ class RMSNormBackward(ReductionBase):
         tXsX = thr_copy_X.partition_D(sX)
         tXgdOut = thr_copy_X.partition_S(gdOut)
         tXsdOut = thr_copy_X.partition_D(sdOut)
-        tXgdX =
+        tXgdX = thr_copy_X.partition_D(gdX)
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
         # This doesn't change across iterations
         tXpX = (
@@ -513,9 +566,9 @@ class RMSNormBackward(ReductionBase):
 
         threads_per_row = tv_layout.shape[0][0]
         tXrdW.fill(0.0)
-        stage =
-        producer_phase =
-        consumer_phase =
+        stage = Int32(0)
+        producer_phase = Int32(1)
+        consumer_phase = Int32(0)
         for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
             row = tXcX[None, None, None, bidx][0][0]
             rstd = cutlass.Float32.zero
@@ -538,10 +591,14 @@ class RMSNormBackward(ReductionBase):
                 )
             elif tiler_mn[0] > 1:
                 utils.fill_oob(
-                    tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero
+                    tXsX[None, None, None, stage ^ 1],
+                    None,
+                    fill_value=mX.element_type.zero,
                 )
                 utils.fill_oob(
-                    tXsdOut[None, None, None, stage ^ 1], None, fill_value=mdOut.element_type.zero
+                    tXsdOut[None, None, None, stage ^ 1],
+                    None,
+                    fill_value=mdOut.element_type.zero,
                 )
             cute.arch.cp_async_commit_group()
             if row < M or tiler_mn[0] == 1:
@@ -561,12 +618,13 @@ class RMSNormBackward(ReductionBase):
                     cute.ReductionOp.ADD,
                     threads_per_row,
                     reduction_buffer[None, None, stage],
-                    mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None,
+                    (mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None),
                     phase=consumer_phase,
                     init_val=0.0,
                 )
                 / shape[1]
             )
+
             if cutlass.const_expr(self.cluster_n > 1):
                 # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
                 # Requires adjusting the thread_count when initializing the mbar
@@ -576,12 +634,20 @@ class RMSNormBackward(ReductionBase):
                     cute.arch.mbarrier_arrive(
                         mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
                     )
+
+            if cutlass.const_expr(self.reload_wdy == "smem"):
+                cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
+                dout = tXrdOut.load().to(cute.Float32)
+                wdy = dout * weight
+
             dx = (wdy - x_hat * mean_xhat_wdy) * rstd
             tXrdX.store(dx.to(tXrdOut.element_type))
             if row < M or tiler_mn[0] == 1:
                 tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
                 cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+            # Accumulate weight gradients in fp32
             tXrdW.store(tXrdW.load() + dout * x_hat)
+
             stage ^= 1
             if stage == 0:
                 consumer_phase ^= 1
@@ -608,9 +674,10 @@ class RMSNormBackward(ReductionBase):
                 tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
                 cute.autovec_copy(tXsdW_other, tXrdW_other)
                 tXrdW.store(tXrdW.load() + tXrdW_other.load())
-            cute.copy(copy_atom_store_dW,
+            cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
         else:
-
+            # dw is already in fp32, so we can directly copy to global memory
+            cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
 
 
 def _rmsnorm_backward(
@@ -634,8 +701,17 @@ def _rmsnorm_backward(
     assert weight.dim() == 1, "Weight must be 1D"
     assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
     assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
-
+    assert x.dtype in [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ], "Unsupported dtype"
+
+    assert weight.dtype in [
+        torch.float32,
+        torch.bfloat16,
+        torch.float16,
+    ], "Weight must be float32, float16 or bfloat16"
 
     M, N = x.shape
     dx = torch.empty_like(x)
@@ -654,28 +730,29 @@ def _rmsnorm_backward(
     sm_count = (
         sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
     )
-
+
+    # Always store partial gradients in fp32 for numerical accuracy
+    dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
 
     dtype = torch2cute_dtype_map[x.dtype]
 
-    convert_from_dlpack = lambda
-        from_dlpack(
-            mode=0, stride_order=(0, 1)
-        )
+    convert_from_dlpack = lambda x: (
+        from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
     )
-
     x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]
 
+    # Handle weight div based on weight dtype
+    weight_dtype = torch2cute_dtype_map[weight.dtype]
     weight_tensor = utils.convert_from_dlpack(
-        weight.detach(), leading_dim=0, divisibility=128 //
+        weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
     )
 
-    dw_partial_tensor =
+    dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
     rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
 
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
 
-    compile_key = (dtype, N)
+    compile_key = (dtype, N, weight.dtype)
     if compile_key not in _rmsnorm_backward.compile_cache:
         rmsnorm_backward_op = RMSNormBackward(dtype, N)
         _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
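Storing the per-SM partials in fp32 (rather than the weight dtype) avoids quantizing each partial before the final reduction. A toy CPU-side comparison of the two storage choices (shapes and values are arbitrary):

    import torch

    parts = torch.randn(132, 8192) * 1e-3          # stand-in for per-SM partial gradients
    dw_fp32 = parts.sum(dim=0)                     # fp32 partials, cast once at the end
    dw_bf16 = parts.to(torch.bfloat16).sum(dim=0)  # partials quantized to bf16 first
    print((dw_fp32 - dw_bf16.float()).abs().max()) # nonzero: bf16 storage drops mantissa bits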
@@ -700,7 +777,7 @@ def _rmsnorm_backward(
         sm_count,
         current_stream,
     )
-
+    # we have summed the partial gradients in fp32, now we convert back to the weight dtype
     dw = dw_partial.sum(dim=0).to(weight.dtype)
     return dx, dw
 
@@ -711,16 +788,29 @@ _rmsnorm_backward.compile_cache = {}
 class RMSNormFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, weight, eps):
+        x_shape_start = x.shape
+
+        # Flatten input
+        x = x.view(-1, x.shape[-1])
+
         out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
         ctx.save_for_backward(x, weight, rstd)
         ctx.eps = eps
-        return out
+        ctx.x_shape_start = x_shape_start
+
+        return out.reshape(x_shape_start)
 
     @staticmethod
     def backward(ctx, dout):
         x, weight, rstd = ctx.saved_tensors
+        x_shape_start = ctx.x_shape_start
+        # Reshape dout to match the flattened shape used in forward
+        dout = dout.view(-1, dout.shape[-1])
         dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
-
+        dx = dx.view(x_shape_start)
+        # dx is returned for input gradient,
+        # dw is returned for weight gradient,
+        # None for eps gradient
         return dx, dw, None
 
 
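Because forward now flattens the input to 2D and restores the original shape (with backward mirroring the reshape), inputs with more than two dimensions should flow through autograd. An illustrative sketch:

    import torch
    from quack.rmsnorm import rmsnorm

    x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16, requires_grad=True)
    w = torch.ones(4096, device="cuda", dtype=torch.float16, requires_grad=True)
    y = rmsnorm(x, w)   # internally viewed as (1024, 4096), returned as (2, 512, 4096)
    y.sum().backward()  # x.grad matches x's shape; w.grad matches w's shape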
@@ -736,3 +826,39 @@ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
         Normalized output tensor of same shape as x
     """
     return RMSNormFunction.apply(x, weight, eps)
+
+
+class QuackRMSNorm(torch.nn.Module):
+    """RMSNorm module that behaves like torch.nn.RMSNorm.
+
+    This class provides a drop-in replacement for torch.nn.RMSNorm that uses
+    the quack.rmsnorm implementation under the hood.
+
+    Args:
+        dim (int): The dimension to normalize over
+        eps (float, optional): A small constant for numerical stability. Default: 1e-6
+
+    Attributes:
+        weight (torch.nn.Parameter): The learnable weight parameter
+        eps (float): A small constant for numerical stability
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(dim))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply RMSNorm to the input tensor.
+
+        Args:
+            x (torch.Tensor): Input tensor
+
+        Returns:
+            torch.Tensor: Normalized tensor
+        """
+        return rmsnorm(x, self.weight, self.eps)
+
+    def reset_parameters(self):
+        """Reset the weight parameter to ones."""
+        torch.nn.init.ones_(self.weight)
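The new `QuackRMSNorm` wrapper is intended as a drop-in for `torch.nn.RMSNorm`. A usage sketch (module placement and sizes are illustrative):

    import torch
    from quack.rmsnorm import QuackRMSNorm

    norm = QuackRMSNorm(dim=4096, eps=1e-6).cuda()
    x = torch.randn(16, 4096, device="cuda")
    y = norm(x)  # same call pattern as torch.nn.RMSNorm(4096)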
quack_kernels-0.1.9.dist-info/RECORD
ADDED
@@ -0,0 +1,12 @@
+quack/__init__.py,sha256=CT76CeRNh5bzQ9f13yVuRz9Sj7V3MvwzHH4fB1iQIf0,203
+quack/cross_entropy.py,sha256=VYSAd28GmtnMoKQwLrorvySDtJfRhoqVd-aeM52FmsI,20866
+quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
+quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
+quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
+quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
+quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
+quack_kernels-0.1.9.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.1.9.dist-info/METADATA,sha256=vOnpbShNHRiUXKAnOUxzfRM7zkpW3RmjW4hIgvYda08,289
+quack_kernels-0.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.1.9.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.1.9.dist-info/RECORD,,
quack_kernels-0.1.7.dist-info/RECORD
REMOVED
@@ -1,12 +0,0 @@
-quack/__init__.py,sha256=R9cZd_vslI5oZjjS-ojfWAd9tCZAqsLUiFVqEbUaGnw,203
-quack/cross_entropy.py,sha256=bg66wECki5I71SMPIRUa-6-oFJ93aIKpK1jqT__SCBM,19775
-quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
-quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
-quack/rmsnorm.py,sha256=3jiwWhVmaG0n5vuUnGGrpg3StAB4lnzziNF97QVMLGQ,28870
-quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
-quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
-quack_kernels-0.1.7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-quack_kernels-0.1.7.dist-info/METADATA,sha256=9RlqUmX3-7BI2aZk88r84B8o2FzZkQgkfV1UxwN8GlE,289
-quack_kernels-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-quack_kernels-0.1.7.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
-quack_kernels-0.1.7.dist-info/RECORD,,
{quack_kernels-0.1.7.dist-info → quack_kernels-0.1.9.dist-info}/WHEEL
File without changes
{quack_kernels-0.1.7.dist-info → quack_kernels-0.1.9.dist-info}/licenses/LICENSE
File without changes
{quack_kernels-0.1.7.dist-info → quack_kernels-0.1.9.dist-info}/top_level.txt
File without changes