quack-kernels 0.2.0-py3-none-any.whl → 0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/activation.py +16 -25
- quack/cross_entropy.py +6 -10
- quack/layernorm.py +1 -1
- quack/reduce.py +6 -7
- quack/rmsnorm.py +57 -23
- quack/softmax.py +1 -1
- quack/tile_scheduler.py +3 -2
- quack/utils.py +0 -63
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.1.dist-info}/METADATA +2 -2
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.1.dist-info}/RECORD +14 -14
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.1.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.1.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/activation.py
CHANGED
@@ -6,23 +6,12 @@ from typing import Tuple
 import cutlass
 import cutlass.cute as cute
 from cutlass import Float32
-from cutlass.cutlass_dsl import
-from cutlass._mlir.dialects import llvm
+from cutlass.cutlass_dsl import dsl_user_op


 @dsl_user_op
-def
-return
-    llvm.inline_asm(
-        T.f32(),
-        [Float32(a).ir_value(loc=loc, ip=ip)],
-        "tanh.approx.f32 $0, $1;",
-        "=f,f",
-        has_side_effects=False,
-        is_align_stack=False,
-        asm_dialect=llvm.AsmDialect.AD_ATT,
-    )
-)
+def sigmoid(x: Float32, *, loc=None, ip=None) -> Float32:
+    return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)


 @dsl_user_op

@@ -67,7 +56,10 @@ def gelu_tanh_approx(x: Float32, *, loc=None, ip=None) -> Float32:
 """
 sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
 sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
-return 0.5 * (
+return 0.5 * (
+    x
+    * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
+)


 @dsl_user_op

@@ -88,7 +80,7 @@ def dgelu_tanh_approx(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[

 # Compute z = x * (c1 + c2 * x^2)
 x_sq = x * x
-tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
+tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
 half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
 gelu_out = x * half_tanh_z_plus_one


@@ -111,7 +103,7 @@ def silu(x: Float32, *, loc=None, ip=None) -> Float32:
 This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
 """
 x_half = 0.5 * x
-return x_half * tanh(x_half) + x_half
+return x_half * cute.math.tanh(x_half, fastmath=True) + x_half


 @dsl_user_op

@@ -134,8 +126,8 @@ def dswiglu(
 to use FFMA instead of FADD and FMUL).
 """
 # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
-
-sigmoid_x =
+# FMUL, MUFU.TANH, then FFMA
+sigmoid_x = sigmoid(x)
 silu_x = x * sigmoid_x  # FMUL
 silu_x_dout = silu_x * dout  # FMUL
 # d_silu(x) * dout

@@ -161,7 +153,7 @@ def swiglu_oai(x: Float32, y: Float32, alpha: float = 1.702, *, loc=None, ip=Non
 """
 # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
 x_half = 0.5 * x
-silu_x = x_half * tanh(alpha * x_half) + x_half
+silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
 return silu_x * y + silu_x


@@ -179,7 +171,8 @@ def dswiglu_oai(
 """
 # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
 alpha_x_half = (0.5 * alpha) * x  # FMUL
-
+# MUFU.TANH, then FFMA
+sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
 silu_x = x * sigmoid_alpha_x  # FMUL
 silu_x_dout = silu_x * dout  # FMUL
 # FFMA, FFMA, FMUL

@@ -197,8 +190,7 @@ def glu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
 glu(x, y) = sigmoid(x) * y
 Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
 """
-
-sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
 return sigmoid_x * y  # FMUL


@@ -215,8 +207,7 @@ def dglu(
 - glu_out = sigmoid(x) * y
 """
 # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
-
-sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
 sigmoid_x_dout = sigmoid_x * dout  # FMUL
 glu_out = sigmoid_x * y  # FMUL
 # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
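The activation changes all follow one pattern: the module-local inline-asm tanh helper is gone, tanh comes from cute.math.tanh(..., fastmath=True), and every sigmoid goes through the identity sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x), which the comments note lowers to FMUL, MUFU.TANH, FFMA. A minimal plain-Python sketch of that identity (illustrative only, not the CuTe DSL code):

    import math

    def sigmoid_via_tanh(x):
        # Identity used by the new sigmoid() helper: sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x)
        return 0.5 + 0.5 * math.tanh(0.5 * x)

    def sigmoid_ref(x):
        return 1.0 / (1.0 + math.exp(-x))

    for v in (-8.0, -1.0, 0.0, 0.5, 3.0):
        assert abs(sigmoid_via_tanh(v) - sigmoid_ref(v)) < 1e-12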
quack/cross_entropy.py
CHANGED
@@ -199,11 +199,8 @@ class CrossEntropy(ReductionBase):
 cute.autovec_copy(tXsX, tXrX)
 x = tXrX.load().to(Float32)
 log2_e = math.log2(math.e)
-# exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-# a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
-# exp_x = utils.exp2f((x - max_x) * log2_e)
 # This would use ffma instead of fadd then fmul
-exp_x =
+exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=False)
 denom = row_reduce(
 exp_x,
 cute.ReductionOp.ADD,

@@ -228,8 +225,7 @@ class CrossEntropy(ReductionBase):
 and row < shape[0]
 and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
 ):
-
-lse = max_x + utils.log2f(denom) * ln_2
+lse = max_x + cute.math.log(denom, fastmath=True)
 # Set loss to 0 if this index should be ignored, otherwise compute normally
 loss_val = (lse - target_logit) if not should_ignore else Float32.zero
 mLoss[row] = mLoss.element_type(loss_val)

@@ -552,7 +548,7 @@ class CrossEntropyBackward:
 lse = Float32(mLSE[row])

 log2_e = math.log2(math.e)
-probs =
+probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True)
 prob_shifted = probs - 1.0
 mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
 for i in cutlass.range(cute.size(tXcFull), unroll_full=True):

@@ -594,9 +590,9 @@ def _cross_entropy_backward(
 assert x.shape[0] == target.shape[0], "Batch dimensions must match"
 assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
 assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
-assert (
-
-)
+assert x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, (
+    "Tensors must be on CUDA device"
+)
 assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
 assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"

quack/layernorm.py
CHANGED
@@ -217,7 +217,7 @@ class LayerNorm(ReductionBase):
 mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
 init_val=0.0,
 )
-rstd =
+rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True)
 if cutlass.const_expr(mRstd is not None):
 # Only the thread corresponding to column 0 writes out the rstd to gmem
 if (
quack/reduce.py
CHANGED
@@ -159,8 +159,7 @@ def online_softmax_reduce(
 width=min(threads_per_row, cute.arch.WARP_SIZE),
 )
 log2_e = math.log2(math.e)
-exp_x =
-# exp_x = exp2f((x - max_x) * log2_e)
+exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
 sum_exp_x = warp_reduce(
 exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
 operator.add,

@@ -190,10 +189,10 @@ def online_softmax_reduce(
 reduction_buffer[row_idx, lane_idx]
 )
 max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
-sum_exp_x *=
+sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True)
 sum_exp_x = warp_reduce(sum_exp_x, operator.add)
 if cutlass.const_expr(return_exp_x):
-exp_x *=
+exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
 max_x = max_x_final
 else:
 cta_rank_in_cluster = cute.arch.block_idx_in_cluster()

@@ -231,11 +230,11 @@ def online_softmax_reduce(
 max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
 sum_exp_x = 0.0
 for i in cutlass.range_constexpr(num_iter):
-sum_exp_x += sum_exp_x_single_warp[i] *
-
+sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp(
+    max_x_single_warp[i] - max_x_final, fastmath=True
 )
 sum_exp_x = warp_reduce(sum_exp_x, operator.add)
 if cutlass.const_expr(return_exp_x):
-exp_x *=
+exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
 max_x = max_x_final
 return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
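The online_softmax_reduce changes rescale each partial sum by exp(old_max - final_max) whenever partial results from different warps or CTAs are combined. A plain-Python sketch of that merge rule, with a made-up two-chunk split just to show the rescaling reproduces the full log-sum-exp exactly:

    import math

    def merge(p, q):
        # Combine two (running max, sum of exp(x - running max)) partials by
        # rescaling each sum with exp(old_max - new_max).
        (m1, s1), (m2, s2) = p, q
        m = max(m1, m2)
        return m, s1 * math.exp(m1 - m) + s2 * math.exp(m2 - m)

    xs = [0.1, 2.5, -1.0, 3.0, 0.7, 1.2]
    chunks = (xs[:3], xs[3:])
    partials = [(max(c), sum(math.exp(v - max(c)) for v in c)) for c in chunks]
    m, s = merge(*partials)
    assert math.isclose(m + math.log(s), math.log(sum(math.exp(v) for v in xs)))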
quack/rmsnorm.py
CHANGED
@@ -19,6 +19,7 @@ from quack.reduce import row_reduce
 from quack.reduction_base import ReductionBase
 from quack.cute_dsl_utils import torch2cute_dtype_map

+
 class RMSNorm(ReductionBase):
 def __init__(self, dtype: cutlass.Numeric, N: int):
 super().__init__(dtype, N, stage=1)

@@ -132,7 +133,9 @@ class RMSNorm(ReductionBase):
 mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
 mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
 if const_expr(mB is not None):
-mB_expanded_layout = cute.prepend(
+mB_expanded_layout = cute.prepend(
+    mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+)
 mB = cute.make_tensor(mB.iterator, mB_expanded_layout)
 if const_expr(mRstd is not None):
 mRstd_expanded_layout = cute.append(

@@ -202,11 +205,7 @@ class RMSNorm(ReductionBase):
 ]
 cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
 gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
-gB = (
-    cute.local_tile(mB, tiler_mn, (0, cluster_y))
-    if const_expr(mB is not None)
-    else None
-)
+gB = cute.local_tile(mB, tiler_mn, (0, cluster_y)) if const_expr(mB is not None) else None
 gRstd = (
 cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
 if const_expr(mRstd is not None)

@@ -226,12 +225,18 @@ class RMSNorm(ReductionBase):
 copy_atom_load_W = cute.make_copy_atom(
 cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
 )
-num_bits_per_copy_B =
-min(128, num_copy_elems_X * mB.element_type.width)
-
-
-
-
+num_bits_per_copy_B = (
+    cutlass.const_expr(min(128, num_copy_elems_X * mB.element_type.width))
+    if const_expr(mB is not None)
+    else 0
+)
+copy_atom_load_B = (
+    cute.make_copy_atom(
+        cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
+    )
+    if const_expr(mB is not None)
+    else None
+)
 if const_expr(mRes is not None):
 num_copy_bits_Res = const_expr(min(128, num_copy_elems_X * mRes.element_type.width))
 copy_atom_load_Res_async = cute.make_copy_atom(

@@ -317,7 +322,7 @@ class RMSNorm(ReductionBase):
 init_val=0.0,
 hook_fn=(cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None),
 )
-rstd =
+rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True)
 if const_expr(mRstd is not None):
 # Only the thread corresponding to column 0 writes out the rstd to gmem
 if (

@@ -355,7 +360,7 @@ class RMSNorm(ReductionBase):
 mutates_args=("out", "rstd", "residual_out"),
 device_types="cuda",
 # We need to specify the schema manually since we're mutating an optional tensor
-schema="(Tensor x, Tensor weight, Tensor(
+schema="(Tensor x, Tensor weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
 )
 def _rmsnorm_fwd(
 x: Tensor,

@@ -509,6 +514,7 @@ def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
 else:
 return out.to(x.dtype), x_f32.to(residual.dtype)

+
 def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
 """Reference implementation for RMSNorm backward pass."""
 x_f32 = x.float()

@@ -521,6 +527,7 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
 dw = (dout * x_hat).sum(dim=0)
 return dx.to(x.dtype), dw.to(w.dtype)

+
 class RMSNormBackward(ReductionBase):
 def __init__(self, dtype: cutlass.Numeric, N: int):
 # 2 stages for double buffering when computing mean of x_hat * wdy

@@ -744,7 +751,11 @@ class RMSNormBackward(ReductionBase):
 # Always compute partial weight gradients in fp32
 tXrdW = cute.make_fragment_like(tXgdW, Float32)

-gdB =
+gdB = (
+    cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y))
+    if const_expr(mdB is not None)
+    else None
+)
 tXgdB = thr_copy_X.partition_S(gdB) if const_expr(mdB is not None) else None
 tXrdB = cute.make_fragment_like(tXgdB, Float32) if const_expr(mdB is not None) else None

@@ -772,8 +783,10 @@ class RMSNormBackward(ReductionBase):
 tXrX, tXrdO, tXrdX = [
 cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
 ]
+tXrdResO = None
 if const_expr(mdResO is not None):
 tXrdResO = cute.make_fragment_like(tXgdResO[None, None, None, 0])
+tXrdRes = None
 if const_expr(mdRes is not None):
 tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])

@@ -930,7 +943,9 @@ class RMSNormBackward(ReductionBase):
 if row == 0:
 for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
 tXrdB_other = cute.make_fragment_like(tXrdB)
-tXsdB_other = cute.make_tensor(
+tXsdB_other = cute.make_tensor(
+    tXsdB.iterator + i * sdB.stride[0], tXsdB.layout
+)
 cute.autovec_copy(tXsdB_other, tXrdB_other)
 tXrdB.store(tXrdB.load() + tXrdB_other.load())
 cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)

@@ -963,7 +978,7 @@ def _get_sm_count(N: int, device: torch.device) -> int:
 mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
 device_types="cuda",
 # We need to specify the schema manually since we're mutating an optional tensor
-schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(
+schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!) dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual) -> ()",
 )
 def _rmsnorm_bwd(
 x: Tensor,

@@ -1031,14 +1046,23 @@ def _rmsnorm_bwd(
 )

 dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
-db_partial_tensor =
+db_partial_tensor = (
+    from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+    if db_partial is not None
+    else None
+)
 rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)

 current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

-compile_key = (
+compile_key = (
+    N,
+    x_tensor.element_type,
+    weight_tensor.element_type,
+    db_partial.dtype if db_partial is not None else None,
 dresidual.dtype if dresidual is not None else None,
-dresidual_out.dtype if dresidual_out is not None else None
+    dresidual_out.dtype if dresidual_out is not None else None,
+)
 if compile_key not in _rmsnorm_bwd.compile_cache:
 rmsnorm_backward_op = RMSNormBackward(x_tensor.element_type, N)
 _rmsnorm_bwd.compile_cache[compile_key] = cute.compile(

@@ -1106,7 +1130,17 @@ def rmsnorm_bwd(

 class RMSNormFunction(torch.autograd.Function):
 @staticmethod
-def forward(
+def forward(
+    ctx,
+    x,
+    weight,
+    bias=None,
+    residual=None,
+    out_dtype=None,
+    residual_dtype=None,
+    eps=1e-6,
+    prenorm=False,
+):
 x_shape_og = x.shape
 # Flatten input
 x = x.reshape(-1, x.shape[-1])

@@ -1129,7 +1163,7 @@ class RMSNormFunction(torch.autograd.Function):
 ctx.x_shape_og = x_shape_og
 ctx.residual_dtype = residual.dtype if residual is not None else None
 ctx.prenorm = prenorm
-if residual_out is None or prenorm
+if residual_out is None or not prenorm:
 return out.reshape(x_shape_og)
 else:
 return out.reshape(x_shape_og), residual_out.reshape(x_shape_og)

@@ -1213,4 +1247,4 @@ class QuackRMSNorm(torch.nn.Module):

 def reset_parameters(self):
 """Reset the weight parameter to ones."""
-torch.nn.init.ones_(self.weight)
+torch.nn.init.ones_(self.weight)
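The forward kernel's new rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True) is the standard RMSNorm statistic. As a reminder of what that line computes, here is a hedged single-row reference in plain Python (no bias, residual, or dtype handling; not the kernel itself):

    import math

    def rmsnorm_row(x, w, eps=1e-6):
        # rstd = 1 / sqrt(mean(x^2) + eps); out = x * rstd * w
        rstd = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
        return [v * rstd * wi for v, wi in zip(x, w)], rstd

    out, rstd = rmsnorm_row([1.0, -2.0, 3.0], [0.5, 1.0, 2.0])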
quack/softmax.py
CHANGED
@@ -159,7 +159,7 @@ class Softmax(ReductionBase):
 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
 )
 log2_e = math.log2(math.e)
-exp_x = cute.math.exp2(
+exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
 denom = row_reduce(
 exp_x,
 cute.ReductionOp.ADD,
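As in cross_entropy.py and reduce.py above, the softmax kernel now writes exp(x - max_x) as exp2(x * log2(e) - max_x * log2(e)), so the shift and scale can fold into a fused multiply-add feeding the exp2. A small plain-Python check of that rewrite and of the resulting normalizer (illustrative values only):

    import math

    log2_e = math.log2(math.e)

    def exp_shifted(x, m):
        # exp(x - m) written as exp2(x*log2_e - m*log2_e), the form used in the kernels
        return 2.0 ** (x * log2_e - m * log2_e)

    logits = [-2.0, 0.3, 1.7, 4.0]
    max_x = max(logits)
    denom = sum(exp_shifted(x, max_x) for x in logits)
    lse = max_x + math.log(denom)  # the log-sum-exp form written out in cross_entropy.py
    assert abs(lse - math.log(sum(math.exp(x) for x in logits))) < 1e-12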
quack/tile_scheduler.py
CHANGED
@@ -390,7 +390,7 @@ def triangular_idx_to_coord(idx: Int32) -> Tuple[Int32, Int32]:
 Convert a triangular index to 2D coordinates.
 This is used to convert the linear index to 2D coordinates for triangular matrices.
 """
-row = utils.ceil((
+row = utils.ceil((cute.math.sqrt(2 * idx + 2.25, fastmath=True) - 0.5)) - 1
 col = idx - (row * (row + 1)) // 2
 return row, col

@@ -524,7 +524,8 @@ class TriangularTileScheduler(TileScheduler):
 group_size = params.group_size_divmod.divisor
 group_id = (
 utils.ceil(
-(
+(cute.math.sqrt(2 * cluster_id_in_problem + 2.25, fastmath=True) - 0.5)
+* params.group_size_inv_f32
 )
 - 1
 )
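triangular_idx_to_coord keeps the same mapping and only swaps the removed utils.sqrt for cute.math.sqrt(..., fastmath=True): for a lower-triangular, row-major enumeration with idx = row*(row+1)/2 + col, the row is recovered as ceil(sqrt(2*idx + 2.25) - 0.5) - 1. A plain-Python check of that formula (exact math.sqrt here; the kernel uses the approximate GPU sqrt):

    import math

    def tri_idx_to_coord(idx):
        # row = ceil(sqrt(2*idx + 2.25) - 0.5) - 1, col = idx - row*(row+1)/2
        row = math.ceil(math.sqrt(2 * idx + 2.25) - 0.5) - 1
        col = idx - (row * (row + 1)) // 2
        return row, col

    # Reference: enumerate lower-triangular coordinates in row-major order.
    coords = [(r, c) for r in range(64) for c in range(r + 1)]
    for idx, rc in enumerate(coords):
        assert tri_idx_to_coord(idx) == rc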
quack/utils.py
CHANGED
@@ -100,69 +100,6 @@ def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=Non
 )


-@cute.jit
-def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
-    """exp2f calculation for both vector and scalar.
-    :param x: input value
-    :type x: cute.TensorSSA or Float32
-    :return: exp2 value
-    :rtype: cute.TensorSSA or Float32
-    """
-    if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
-        res = cute.make_fragment(x.shape, Float32)
-        res.store(x)
-        for i in cutlass.range(cute.size(x.shape), unroll_full=True):
-            res[i] = cute.arch.exp2(res[i])
-        return res.load()
-    else:
-        return cute.arch.exp2(x)
-
-
-@dsl_user_op
-def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
-    return Float32(
-        llvm.inline_asm(
-            T.f32(),
-            [Float32(a).ir_value(loc=loc, ip=ip)],
-            "lg2.approx.ftz.f32 $0, $1;",
-            "=f,f",
-            has_side_effects=False,
-            is_align_stack=False,
-            asm_dialect=llvm.AsmDialect.AD_ATT,
-        )
-    )
-
-
-@dsl_user_op
-def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
-    return Float32(
-        llvm.inline_asm(
-            T.f32(),
-            [Float32(a).ir_value(loc=loc, ip=ip)],
-            "sqrt.approx.ftz.f32 $0, $1;",
-            "=f,f",
-            has_side_effects=False,
-            is_align_stack=False,
-            asm_dialect=llvm.AsmDialect.AD_ATT,
-        )
-    )
-
-
-@dsl_user_op
-def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
-    return Float32(
-        llvm.inline_asm(
-            T.f32(),
-            [Float32(a).ir_value(loc=loc, ip=ip)],
-            "rsqrt.approx.ftz.f32 $0, $1;",
-            "=f,f",
-            has_side_effects=False,
-            is_align_stack=False,
-            asm_dialect=llvm.AsmDialect.AD_ATT,
-        )
-    )
-
-
 @dsl_user_op
 def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
 return Int32(
{quack_kernels-0.2.0.dist-info → quack_kernels-0.2.1.dist-info}/RECORD
CHANGED

@@ -1,7 +1,7 @@
-quack/__init__.py,sha256=
-quack/activation.py,sha256=
+quack/__init__.py,sha256=H1m0CnfPidSSmprZeTGJc8LVh7stdBPmPLEuZwgN_7M,364
+quack/activation.py,sha256=SzQDUCB-kccqsy1aYUrHYJ2cGxKMXxxqpjJaJoqBYaE,10017
 quack/autotuner.py,sha256=czO6JrYL0EJpOeJOYDSsVdrJaFuwfL3vTdG8QfL1F34,10792
-quack/cross_entropy.py,sha256=
+quack/cross_entropy.py,sha256=TE8j21c-7E4cInKtFjcKsgKXNhKCRFkNfhCJpgpasj8,28409
 quack/cute_dsl_utils.py,sha256=D2Pw7rzX9jY8u8wikIPvPvinmFLCDeZg95HPBLqGej4,4635
 quack/dense_gemm_sm100.py,sha256=hKBNC34UxdctrTKVP68nvANZl4Dq2rnUjRcweESEq3g,109965
 quack/dense_gemm_sm90.py,sha256=TjnjHnjhAwWH5YQWsFlADq07xSxtsprkw_p2Cy0yw7I,100407

@@ -11,27 +11,27 @@ quack/gemm_config.py,sha256=gbYjPFeyT5wAhVwFQroRHlHoMKEJqAWX9P8wWy04l8Q,2258
 quack/gemm_dact_sm90.py,sha256=KCXgjOzdamSDexwrwf_pX2r-ippPRirbClrlU6BP7b8,4990
 quack/gemm_interface.py,sha256=_JTpE7zQw6NUw-v65Wql_XUOZBfW0oSEgiMnharTJU4,20501
 quack/gemm_wrapper_utils.py,sha256=aMMtu-Ojhtjay_5xJH4AjP-JRVks1AB8jmtNme_DIqU,5960
-quack/layernorm.py,sha256=
+quack/layernorm.py,sha256=AOe95-YqhFPw96x8pJq7FfBe26ROX9ZTvH025lM1ILs,13579
 quack/linear.py,sha256=SrhRiAFjC7ONIMVmiNu-kSPLHNUyaCXt59a1f_5nNXo,9383
 quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
 quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
 quack/pipeline.py,sha256=DyCwZX8WvoUBFcMBz7CeYm9VUM31haEGgBhAzmxu8cE,5519
-quack/reduce.py,sha256=
+quack/reduce.py,sha256=0hRFMFfn6xC5QLk32Qmgc17XVkQ1yKC-3TfksccSBaU,10341
 quack/reduction_base.py,sha256=CT-t_j7z8H1ByD9FkQYDRik_-THMDFv9QoXHmr9Xx9E,3636
-quack/rmsnorm.py,sha256=
-quack/softmax.py,sha256=
+quack/rmsnorm.py,sha256=PrW2zuaQs_Gr6g8B6DMsGSJFZdEsWf32if_EwUR_IDQ,49386
+quack/softmax.py,sha256=WFWtgc40iLPFBpdStBBTC9803Npnv9rZjOzb_nK-RDs,17110
 quack/symmetric_dense_gemm_sm90.py,sha256=2UXooIpClT2izdyGis1XaIgYYlLj-7MrcOMg2yR7YCk,88694
 quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
-quack/tile_scheduler.py,sha256=
+quack/tile_scheduler.py,sha256=BQ-SeW5wxulKuwmpq0CAIjkuirv4KWdUdoIGQB88aGE,42319
 quack/topk.py,sha256=RQl-23lIicQ9ry9Njur8i0JGem_WbO_Gchr6jy8EtVM,9185
-quack/utils.py,sha256=
+quack/utils.py,sha256=wOgNw9VL40FCsLwN52juPfk48zVpX-rta3MQhAQe8Wc,12767
 quack/varlen_utils.py,sha256=vkduMEpo5bJJvZRNnIcKPb6pp1wD34vaIpMIB0ZGIZA,681
 quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
 quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
 quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
 quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
-quack_kernels-0.2.
-quack_kernels-0.2.
-quack_kernels-0.2.
-quack_kernels-0.2.
-quack_kernels-0.2.
+quack_kernels-0.2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.2.1.dist-info/METADATA,sha256=_AFigx6aFt-25GzUP6YWalDBwHvwzgK9EU85WjZXvsI,285
+quack_kernels-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.2.1.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.2.1.dist-info/RECORD,,
File without changes
File without changes
File without changes