quack-kernels 0.2.2__tar.gz → 0.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {quack_kernels-0.2.2/quack_kernels.egg-info → quack_kernels-0.2.4}/PKG-INFO +4 -2
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/pyproject.toml +3 -1
- quack_kernels-0.2.4/quack/__init__.py +11 -0
- quack_kernels-0.2.4/quack/activation.py +524 -0
- quack_kernels-0.2.4/quack/broadcast_utils.py +29 -0
- quack_kernels-0.2.4/quack/compile_utils.py +19 -0
- quack_kernels-0.2.4/quack/copy_utils.py +487 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/cross_entropy.py +157 -233
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/cute_dsl_utils.py +20 -34
- quack_kernels-0.2.4/quack/gemm.py +194 -0
- quack_kernels-0.2.2/quack/gemm_act_sm90.py → quack_kernels-0.2.4/quack/gemm_act.py +218 -117
- quack_kernels-0.2.4/quack/gemm_config.py +95 -0
- quack_kernels-0.2.2/quack/gemm_dact_sm90.py → quack_kernels-0.2.4/quack/gemm_dact.py +53 -21
- quack_kernels-0.2.4/quack/gemm_default_epi.py +259 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/gemm_interface.py +177 -31
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/gemm_sm100.py +729 -506
- quack_kernels-0.2.2/quack/dense_gemm_sm90.py → quack_kernels-0.2.4/quack/gemm_sm90.py +344 -814
- quack_kernels-0.2.4/quack/gemm_symmetric.py +330 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/gemm_wrapper_utils.py +3 -1
- quack_kernels-0.2.4/quack/layout_utils.py +287 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/linear.py +24 -16
- quack_kernels-0.2.4/quack/pipeline.py +306 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/reduce.py +88 -49
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/reduction_base.py +25 -36
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/rmsnorm.py +476 -526
- quack_kernels-0.2.4/quack/sm100_utils.py +62 -0
- quack_kernels-0.2.4/quack/sm90_utils.py +127 -0
- quack_kernels-0.2.4/quack/softmax.py +403 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/sort/bitonic_sort.py +13 -10
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/sort/utils.py +6 -6
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/tile_scheduler.py +23 -16
- quack_kernels-0.2.4/quack/topk.py +551 -0
- quack_kernels-0.2.4/quack/utils.py +223 -0
- quack_kernels-0.2.4/quack/varlen_utils.py +386 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4/quack_kernels.egg-info}/PKG-INFO +4 -2
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack_kernels.egg-info/SOURCES.txt +13 -6
- quack_kernels-0.2.4/quack_kernels.egg-info/requires.txt +8 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_layernorm.py +17 -51
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_linear.py +37 -17
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_linear_varlen_k.py +49 -3
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_linear_varlen_m.py +43 -24
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_rmsnorm.py +26 -17
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_softmax.py +1 -2
- quack_kernels-0.2.2/tests/test_symmetric_dense_gemm_sm90.py → quack_kernels-0.2.4/tests/test_symmetric_gemm.py +15 -16
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_topk.py +42 -23
- quack_kernels-0.2.2/quack/__init__.py +0 -18
- quack_kernels-0.2.2/quack/activation.py +0 -279
- quack_kernels-0.2.2/quack/gemm_config.py +0 -69
- quack_kernels-0.2.2/quack/layernorm.py +0 -353
- quack_kernels-0.2.2/quack/pipeline.py +0 -151
- quack_kernels-0.2.2/quack/softmax.py +0 -471
- quack_kernels-0.2.2/quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2/quack/topk.py +0 -227
- quack_kernels-0.2.2/quack/utils.py +0 -411
- quack_kernels-0.2.2/quack/varlen_utils.py +0 -17
- quack_kernels-0.2.2/quack_kernels.egg-info/requires.txt +0 -6
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/LICENSE +0 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/README.md +0 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/autotuner.py +0 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/fast_math.py +0 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/linear_cross_entropy.py +0 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/mlp.py +0 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/sort/generate_sorting_networks.py +0 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/sort/sorting_networks.py +0 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/tensormap_manager.py +0 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack_kernels.egg-info/dependency_links.txt +0 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack_kernels.egg-info/top_level.txt +0 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/setup.cfg +0 -0
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_cross_entropy.py +1 -1
- {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_linear_cross_entropy.py +0 -0
{quack_kernels-0.2.2/quack_kernels.egg-info → quack_kernels-0.2.4}/PKG-INFO
@@ -1,10 +1,12 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.2.2
+Version: 0.2.4
 Requires-Python: >=3.10
 License-File: LICENSE
-Requires-Dist: nvidia-cutlass-dsl
+Requires-Dist: nvidia-cutlass-dsl<4.4.0,>=4.3.4
 Requires-Dist: torch
+Requires-Dist: apache-tvm-ffi<0.2,>=0.1.6
+Requires-Dist: torch-c-dlpack-ext
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
 Requires-Dist: ruff; extra == "dev"
{quack_kernels-0.2.2 → quack_kernels-0.2.4}/pyproject.toml
@@ -7,8 +7,10 @@ name = "quack-kernels"
 dynamic = ["version"]
 requires-python = ">=3.10"
 dependencies = [
-    "nvidia-cutlass-dsl",
+    "nvidia-cutlass-dsl>=4.3.4,<4.4.0",
     "torch",
+    "apache-tvm-ffi>=0.1.6,<0.2",
+    "torch-c-dlpack-ext",
 ]
 
 [project.optional-dependencies]
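Note on the new pins: nvidia-cutlass-dsl is now constrained to >=4.3.4,<4.4.0 and apache-tvm-ffi to >=0.1.6,<0.2. A minimal sketch for checking an environment against the dependency list above (importlib.metadata only reports what pip has already installed; the distribution names are taken verbatim from the diff):

    import importlib.metadata as md

    for pkg in ["nvidia-cutlass-dsl", "torch", "apache-tvm-ffi", "torch-c-dlpack-ext"]:
        try:
            print(f"{pkg}: {md.version(pkg)}")
        except md.PackageNotFoundError:
            print(f"{pkg}: not installed")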
quack_kernels-0.2.4/quack/activation.py
@@ -0,0 +1,524 @@
+# Copyright (c) 2025, Tri Dao.
+
+import math
+from typing import Tuple
+
+import cutlass.cute as cute
+from cutlass import Float32, Boolean, const_expr
+from cutlass.cutlass_dsl import T, dsl_user_op
+from cutlass._mlir.dialects import llvm
+
+import quack.utils as utils
+
+
+F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
+
+
+@dsl_user_op
+def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
+    return Float32(
+        llvm.inline_asm(
+            T.f32(),
+            [Float32(a).ir_value(loc=loc, ip=ip)],
+            "tanh.approx.f32 $0, $1;",
+            "=f,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
+@dsl_user_op
+def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    if const_expr(not isinstance(x, tuple)):
+        # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
+        return 0.5 + 0.5 * tanh(0.5 * x)
+    else:
+        x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
+        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
+        return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
+
+
+@dsl_user_op
+def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
+    # return dout * out * (1.0 - out)
+    return dout * (out - out * out)
+
+
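The helpers above lean on the identity sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x), so one MUFU.TANH serves for both tanh and sigmoid. A plain-Python sanity check of that identity and of the dsigmoid_from_output refactoring (math.tanh stands in for the PTX tanh.approx.f32, so this validates the algebra, not the device approximation):

    import math

    def sigmoid_ref(x):
        return 1.0 / (1.0 + math.exp(-x))

    for x in [-4.0, -1.0, 0.0, 0.5, 3.0]:
        # sigmoid via tanh, as in sigmoid() above
        assert abs((0.5 + 0.5 * math.tanh(0.5 * x)) - sigmoid_ref(x)) < 1e-12
        # dsigmoid_from_output refactors out * (1 - out) into (out - out * out)
        out = sigmoid_ref(x)
        assert abs((out - out * out) - out * (1.0 - out)) < 1e-15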
+@dsl_user_op
+def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    if const_expr(not isinstance(x, tuple)):
+        return cute.arch.fmax(x, Float32(0.0))
+    else:
+        return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))
+
+
+@dsl_user_op
+@cute.jit
+def drelu(
+    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+    if const_expr(not isinstance(x, tuple)):
+        x_pos = Boolean(x > 0)
+        return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
+    else:
+        x0_pos = Boolean(x[0] > 0)
+        x1_pos = Boolean(x[1] > 0)
+        dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0))
+        return dx, relu(x)
+
+
+@dsl_user_op
+def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    if const_expr(not isinstance(x, tuple)):
+        return cute.arch.fmax(x, Float32(0.0)) * x
+    else:
+        relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
+        return utils.mul_packed_f32x2(relu_x, x)
+
+
+@dsl_user_op
+@cute.jit
+def drelu_sq(
+    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+    """
+    ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
+    Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
+    Returns: (dx, relu_sq_out) where:
+    - dx = dout * 2 * x if x > 0, else 0
+    - relu_sq_out = max(x, 0) * x
+    """
+    if const_expr(not isinstance(x, tuple)):
+        relu_x = relu(x)
+        relu_sq_out = relu_x * x
+        # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
+        dx = 2.0 * (dout * relu_x)
+        return dx, relu_sq_out
+    else:
+        relu_x = relu(x)
+        relu_sq_out = utils.mul_packed_f32x2(relu_x, x)
+        dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x))
+        return dx, relu_sq_out
+
+
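drelu_sq's docstring claims d/dx[max(x, 0) * x] = 2*x for x > 0 and 0 otherwise, computed above as 2.0 * (dout * relu_x). A small finite-difference check of that claim in plain Python (a sketch only, sampled away from the kink at x = 0):

    def relu_sq(x):
        return max(x, 0.0) * x

    for x in [-2.0, -0.5, 0.7, 3.0]:
        h = 1e-6
        fd = (relu_sq(x + h) - relu_sq(x - h)) / (2 * h)  # numerical derivative
        analytic = 2.0 * max(x, 0.0)                      # dx above with dout = 1
        assert abs(fd - analytic) < 1e-4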
+@dsl_user_op
+def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    """
+    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
+            = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
+    """
+    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
+    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
+    if const_expr(not isinstance(x, tuple)):
+        return 0.5 * (
+            x
+            # Currently cute.math.tanh(x, fastmath=True) generates very slow code
+            # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
+            * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
+        )
+    else:
+        x_sq = utils.mul_packed_f32x2(x, x)
+        x_sq_scaled = utils.fma_packed_f32x2(
+            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
+        )
+        z = utils.mul_packed_f32x2(x, x_sq_scaled)
+        tanh_z = (tanh(z[0]), tanh(z[1]))
+        x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x)
+        return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
+
+
+@dsl_user_op
+def dgelu_tanh_approx(
+    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+    """
+    GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
+    Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
+    Returns: (dx, gelu_out)
+
+    Derivative uses the chain rule:
+    d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+    where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
+    and sech^2(z) = 1 - tanh^2(z)
+    """
+    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
+    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
+    sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.1070322
+
+    if const_expr(not isinstance(x, tuple)):
+        # Compute z = x * (c1 + c2 * x^2)
+        x_sq = x * x
+        # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
+        tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
+        half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
+        gelu_out = x * half_tanh_z_plus_one
+
+        # Compute gradient
+        # sech^2(z) = 1 - tanh^2(z)
+        sech2_z = 1 - tanh_z * tanh_z
+        # dz/dx = c1 + 3 * c2 * x^2
+        dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
+        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+        dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))
+
+        dx = dout * dgelu
+        return dx, gelu_out
+    else:
+        # Compute z = x * (c1 + c2 * x^2)
+        x_sq = utils.mul_packed_f32x2(x, x)
+        x_sq_scaled = utils.fma_packed_f32x2(
+            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
+        )
+        z = utils.mul_packed_f32x2(x, x_sq_scaled)
+        tanh_z = (tanh(z[0]), tanh(z[1]))
+        half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
+        gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one)
+
+        # Compute gradient
+        # sech^2(z) = 1 - tanh^2(z)
+        sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
+        # dz/dx = c1 + 3 * c2 * x^2
+        dz_dx = utils.fma_packed_f32x2(
+            x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
+        )
+        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+        sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx)
+        x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx)
+        dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
+
+        dx = utils.mul_packed_f32x2(dout, dgelu)
+        return dx, gelu_out
+
+
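The dgelu_tanh_approx algebra (sech^2(z) = 1 - tanh^2(z), dz/dx = c1 + 3 * c2 * x^2) can be checked against a finite difference of the tanh-approximation forward pass in plain Python (math.tanh standing in for the device tanh):

    import math

    C1 = math.sqrt(2 / math.pi)   # sqrt_2_over_pi
    C2 = 0.044715 * C1            # sqrt_2_over_pi_coeff

    def gelu(x):
        return 0.5 * x * (1.0 + math.tanh(x * (C1 + C2 * x * x)))

    def dgelu(x):  # same algebra as the scalar branch above, with dout = 1
        x_sq = x * x
        tanh_z = math.tanh(x * (C1 + C2 * x_sq))
        half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
        sech2_z = 1.0 - tanh_z * tanh_z
        dz_dx = C1 + 3.0 * C2 * x_sq
        return half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))

    for x in [-3.0, -1.0, 0.0, 0.5, 2.0]:
        h = 1e-5
        fd = (gelu(x + h) - gelu(x - h)) / (2 * h)
        assert abs(fd - dgelu(x)) < 1e-6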
+@dsl_user_op
+@cute.jit
+def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    if const_expr(not isinstance(x, tuple)):
+        use_linear = Boolean(x > 20.0)
+        return (
+            cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True)
+            if not use_linear
+            else x
+        )
+    else:
+        log2_e = math.log2(math.e)
+        x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e))
+        x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
+        x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0))
+        log_x_exp_p1 = (
+            cute.math.log2(x_exp_p1[0], fastmath=True),
+            cute.math.log2(x_exp_p1[1], fastmath=True),
+        )
+        ln2 = math.log(2.0)
+        softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
+        use_linear_0 = Boolean(x[0] > 20.0)
+        use_linear_1 = Boolean(x[1] > 20.0)
+        return (
+            softplus_x[0] if not use_linear_0 else x[0],
+            softplus_x[1] if not use_linear_1 else x[1],
+        )
+
+
+@dsl_user_op
+@cute.jit
+def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
+    use_linear = Boolean(out > 20.0)
+    # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout
+    dx = dout - dout * cute.math.exp(-out, fastmath=True)
+    return dx if not use_linear else dout
+
+
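dsoftplus_from_output relies on the identity 1 - exp(-softplus(x)) = sigmoid(x) = d/dx softplus(x), which is what lets it recover the gradient from the saved output alone. A quick plain-Python check (below the x > 20 linear cutoff used above):

    import math

    for x in [-5.0, -1.0, 0.0, 2.0, 10.0]:
        out = math.log1p(math.exp(x))        # softplus(x)
        from_output = 1.0 - math.exp(-out)   # the dx/dout factor used above
        assert abs(from_output - 1.0 / (1.0 + math.exp(-x))) < 1e-12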
+@dsl_user_op
+def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2:
+    """
+    silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
+    This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
+    """
+    if const_expr(not isinstance(x, tuple)):
+        x_half = 0.5 * x if const_expr(not already_halved) else x
+        # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
+        return x_half * tanh(x_half) + x_half
+    else:
+        x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
+        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
+        return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half)
+
+
+@dsl_user_op
+def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    if const_expr(not isinstance(x, tuple)):
+        return silu(x) * y
+    else:
+        return utils.mul_packed_f32x2(silu(x), y)
+
+
+@dsl_user_op
+def dswiglu(
+    x: F32_or_F32x2,
+    y: F32_or_F32x2,
+    dout: F32_or_F32x2,
+    *,
+    already_halved: bool = False,
+    loc=None,
+    ip=None,
+) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+    """
+    SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+    Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
+    Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)
+
+    d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+
+    This has been optimized to use fewer instructions (i.e. we expand things out
+    to use FFMA instead of FADD and FMUL).
+    """
+    if const_expr(not isinstance(x, tuple)):
+        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
+        # FMUL, MUFU.TANH, then FFMA
+        if const_expr(not already_halved):
+            sigmoid_x = sigmoid(x)
+            silu_x = x * sigmoid_x  # FMUL
+        else:
+            tanh_x = tanh(x)  # MUFU.TANH
+            sigmoid_x = 0.5 * tanh_x + 0.5  # FFMA
+            silu_x = x * tanh_x + x  # FFMA
+        silu_x_dout = silu_x * dout  # FMUL
+        # d_silu(x) * dout
+        # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
+        # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
+        # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
+        # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
+        # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+        d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
+        dx = d_silu_x_dout * y  # FMUL
+        dy = silu_x_dout
+        swiglu_out = silu_x * y  # FMUL
+        # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
+        return dx, dy, swiglu_out
+    else:
+        # Compute sigmoid(x) and silu(x)
+        if const_expr(not already_halved):
+            sigmoid_x = sigmoid(x)
+            silu_x = utils.mul_packed_f32x2(x, sigmoid_x)
+        else:
+            tanh_x = (tanh(x[0]), tanh(x[1]))
+            sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
+            silu_x = utils.fma_packed_f32x2(x, tanh_x, x)
+        silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+        # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+        sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2(
+            sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
+        )
+        d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout)
+        dx = utils.mul_packed_f32x2(d_silu_x_dout, y)
+        dy = silu_x_dout
+        swiglu_out = utils.mul_packed_f32x2(silu_x, y)
+        return dx, dy, swiglu_out
+
+
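Both silu's docstring identity and dswiglu's FFMA-oriented regrouping can be checked in plain Python; the sketch below uses exact math.tanh/exp, so it validates the algebra rather than the device approximation:

    import math

    def sigmoid(x):
        return 1.0 / (1.0 + math.exp(-x))

    def silu(x):
        return x * sigmoid(x)

    for x in [-2.0, 0.3, 1.5]:
        x_half = 0.5 * x  # the FMUL + MUFU.TANH + FFMA form used by silu() above
        assert abs((x_half * math.tanh(x_half) + x_half) - silu(x)) < 1e-12

    for x, y, dout in [(-1.0, 2.0, 0.7), (0.5, -3.0, 1.3)]:
        s, h = sigmoid(x), 1e-6
        # dswiglu regrouping: d_silu(x) * dout = (s - silu(x) * s) * dout + silu(x) * dout
        dx = ((s - silu(x) * s) * dout + silu(x) * dout) * y
        fd = (silu(x + h) - silu(x - h)) / (2 * h) * y * dout
        assert abs(dx - fd) < 1e-4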
+@dsl_user_op
+def swiglu_oai(
+    x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
+) -> F32_or_F32x2:
+    """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
+    https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
+    x * sigmoid(alpha * x) * (y + 1)
+    Compile down to FMUL, FMUL, TANH, FFMA, FFMA
+    """
+    # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+    if const_expr(not isinstance(x, tuple)):
+        x_half = 0.5 * x
+        # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
+        silu_x = x_half * tanh(alpha * x_half) + x_half
+        return silu_x * y + silu_x
+    else:
+        x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
+        alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half)
+        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
+        silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
+        return utils.fma_packed_f32x2(silu_x, y, silu_x)
+
+
+@dsl_user_op
+def dswiglu_oai(
+    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+    """
+    Swiglu OAI backward pass: computes gradients w.r.t. x and y
+    Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
+    Returns: (dx, dy, swiglu_oai_out)
+
+    Derivative of x * sigmoid(alpha * x) w.r.t. x:
+    d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
+    """
+    if const_expr(not isinstance(x, tuple)):
+        # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+        alpha_x_half = (0.5 * alpha) * x  # FMUL
+        # MUFU.TANH, then FFMA
+        # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
+        sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)
+        silu_x = x * sigmoid_alpha_x  # FMUL
+        silu_x_dout = silu_x * dout  # FMUL
+        # FFMA, FFMA, FMUL
+        d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+        dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
+        dy = silu_x_dout
+        swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
+        # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
+        return dx, dy, swiglu_out
+    else:
+        # Compute sigmoid(alpha * x)
+        alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
+        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
+        sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
+        silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x)
+        silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+        # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+        silu_x_minus_product = utils.fma_packed_f32x2(
+            silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
+        )
+        sigmoid_plus_alpha_diff = utils.fma_packed_f32x2(
+            (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
+        )
+        d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
+        dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
+        dy = silu_x_dout
+        swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x)
+        return dx, dy, swiglu_out
+
+
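swiglu_oai folds the (y + 1) bias into an FFMA (silu_x * y + silu_x) and routes sigmoid(alpha * x) through tanh. A plain-Python check against the reference form quoted in the docstring:

    import math

    def sigmoid(z):
        return 1.0 / (1.0 + math.exp(-z))

    def swiglu_oai_ref(x, y, alpha=1.702):  # x * sigmoid(alpha * x) * (y + 1)
        return x * sigmoid(alpha * x) * (y + 1.0)

    for x, y in [(-1.2, 0.4), (0.8, -2.0), (2.5, 1.0)]:
        x_half = 0.5 * x
        silu_x = x_half * math.tanh(1.702 * x_half) + x_half  # tanh form above
        assert abs((silu_x * y + silu_x) - swiglu_oai_ref(x, y)) < 1e-12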
+@dsl_user_op
+def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    """GLU: Gated Linear Unit
+    glu(x, y) = sigmoid(x) * y
+    Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+    """
+    if const_expr(not isinstance(x, tuple)):
+        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
+        return sigmoid_x * y  # FMUL
+    else:
+        sigmoid_x = sigmoid(x)
+        return utils.mul_packed_f32x2(sigmoid_x, y)
+
+
+@dsl_user_op
+def dglu(
+    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+    """
+    GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+    Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
+    Returns: (dx, dy, glu_out) where:
+    - dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
+    - dy = dout * sigmoid(x)
+    - glu_out = sigmoid(x) * y
+    """
+    if const_expr(not isinstance(x, tuple)):
+        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
+        sigmoid_x_dout = sigmoid_x * dout  # FMUL
+        glu_out = sigmoid_x * y  # FMUL
+        # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
+        # = y * (1 - sigmoid(x)) * sigmoid_x_dout
+        # = (y - y * sigmoid(x)) * sigmoid_x_dout
+        # = (y - glu_out) * sigmoid_x_dout
+        dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
+        dy = sigmoid_x_dout
+        # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
+        return dx, dy, glu_out
+    else:
+        sigmoid_x = sigmoid(x)
+        sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout)
+        glu_out = utils.mul_packed_f32x2(sigmoid_x, y)
+        # dx = (y - glu_out) * sigmoid_x_dout
+        y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out)
+        dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
+        dy = sigmoid_x_dout
+        return dx, dy, glu_out
+
+
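dglu's refactoring dx = (y - glu_out) * sigmoid_x_dout is the textbook gradient dout * y * sigmoid(x) * (1 - sigmoid(x)) with the common factors pulled out; a short plain-Python check:

    import math

    def sigmoid(x):
        return 1.0 / (1.0 + math.exp(-x))

    for x, y, dout in [(-0.7, 1.5, 0.9), (1.2, -0.3, 2.0)]:
        s = sigmoid(x)
        refactored = (y - s * y) * (s * dout)  # dglu's form above
        textbook = dout * y * s * (1.0 - s)
        assert abs(refactored - textbook) < 1e-12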
+@dsl_user_op
+def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    """ReGLU: ReLU Gated Linear Unit
+    reglu(x, y) = relu(x) * y = max(x, 0) * y
+    """
+    if const_expr(not isinstance(x, tuple)):
+        return cute.arch.fmax(x, Float32(0.0)) * y
+    else:
+        relu_x = relu(x)
+        return utils.mul_packed_f32x2(relu_x, y)
+
+
+@dsl_user_op
+@cute.jit
+def dreglu(
+    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+    """
+    ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+    Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
+    Returns: (dx, dy, reglu_out) where:
+    - dx = dout * y if x > 0, else 0
+    - dy = dout * relu(x)
+    - reglu_out = relu(x) * y
+    """
+    if const_expr(not isinstance(x, tuple)):
+        x_pos = Boolean(x > 0)
+        relu_x = cute.arch.fmax(x, Float32(0.0))
+        dx = (dout * y) if x_pos else Float32(0.0)
+        dy = dout * relu_x
+        reglu_out = relu_x * y
+        return dx, dy, reglu_out
+    else:
+        x0_pos = Boolean(x[0] > 0)
+        x1_pos = Boolean(x[1] > 0)
+        relu_x = relu(x)
+        dout_y = utils.mul_packed_f32x2(dout, y)
+        dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
+        dy = utils.mul_packed_f32x2(dout, relu_x)
+        reglu_out = utils.mul_packed_f32x2(relu_x, y)
+        return dx, dy, reglu_out
+
+
+@dsl_user_op
+def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    """GeGLU: GELU Gated Linear Unit
+    geglu(x, y) = gelu(x) * y
+    Uses the tanh approximation of GELU
+    """
+    if const_expr(not isinstance(x, tuple)):
+        return gelu_tanh_approx(x) * y
+    else:
+        return utils.mul_packed_f32x2(gelu_tanh_approx(x), y)
+
+
+@dsl_user_op
+def dgeglu(
+    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+    """
+    GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+    Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
+    Returns: (dx, dy, geglu_out) where:
+    - dx = dout * y * d_gelu(x)
+    - dy = dout * gelu(x)
+    - geglu_out = gelu(x) * y
+    """
+    if const_expr(not isinstance(x, tuple)):
+        # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+        dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+        # Compute gradients for geglu
+        dx = dgelu_x_dout * y
+        dy = gelu_x * dout
+        geglu_out = gelu_x * y
+        return dx, dy, geglu_out
+    else:
+        # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+        dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+        # Compute gradients for geglu
+        dx = utils.mul_packed_f32x2(dgelu_x_dout, y)
+        dy = utils.mul_packed_f32x2(gelu_x, dout)
+        geglu_out = utils.mul_packed_f32x2(gelu_x, y)
+        return dx, dy, geglu_out
quack_kernels-0.2.4/quack/broadcast_utils.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2025, Tri Dao.
+from typing import Callable
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Float32, const_expr
+
+from quack.layout_utils import make_acc_tensor_mn_view
+
+
+@cute.jit
+def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
+    if const_expr(tCrC.element_type != Float32):  # Convert to f32
+        tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
+        tCrC_f32.store(tCrC.load().to(Float32))
+    else:
+        tCrC_f32 = tCrC
+    # this happens to work for frgA layout too, not just acc layout
+    tCrC_f32_mn = make_acc_tensor_mn_view(tCrC_f32)
+    if const_expr(is_colvec):
+        assert cute.size(tCrC_f32_mn, mode=[0]) == cute.size(tCrVec)
+        for r in cutlass.range(cute.size(tCrC_f32_mn, mode=[0]), unroll_full=True):
+            tCrC_f32_mn[r, None].store(op(tCrC_f32_mn[r, None].load(), tCrVec[r]))
+    else:
+        assert cute.size(tCrC_f32_mn, mode=[1]) == cute.size(tCrVec)
+        for c in cutlass.range(cute.size(tCrC_f32_mn, mode=[1]), unroll_full=True):
+            tCrC_f32_mn[None, c].store(op(tCrC_f32_mn[None, c].load(), tCrVec[c]))
+    if const_expr(tCrC.element_type != Float32):  # Convert back to original dtype
+        tCrC.store(tCrC_f32.load().to(tCrC.element_type))
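vec_op applies a binary op between an accumulator tile (viewed as (M, N)) and a length-M column vector or length-N row vector, upcasting to f32 around the op. A NumPy analogy of the broadcasting it performs (assumes numpy is available; this models the semantics, not the CuTe code):

    import numpy as np

    def vec_op_ref(C, vec, op, is_colvec):
        # C plays the role of tCrC_f32_mn; vec of tCrVec
        return op(C, vec[:, None]) if is_colvec else op(C, vec[None, :])

    C = np.arange(12, dtype=np.float32).reshape(3, 4)
    print(vec_op_ref(C, np.array([10.0, 20.0, 30.0], np.float32), np.add, True))
    print(vec_op_ref(C, np.array([1.0, 2.0, 3.0, 4.0], np.float32), np.multiply, False))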
quack_kernels-0.2.4/quack/compile_utils.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+from typing import Optional
+
+import cutlass.cute as cute
+
+
+def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
+    if leading_dim < 0:
+        leading_dim = len(shape) + leading_dim
+    if dtype is None:
+        return None
+    stride = tuple(
+        cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1
+        for i in range(len(shape))
+    )
+    return cute.runtime.make_fake_tensor(
+        dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8
+    )
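make_fake_tensor builds a compile-time stand-in: the leading dimension gets stride 1 and every other dimension a symbolic int64 stride that the compiler may only assume is divisible by `divisibility`. A plain-Python sketch of just the stride-pattern logic (the string "sym" stands in for cute.sym_int64(...)):

    def fake_stride_pattern(shape, leading_dim=-1):
        if leading_dim < 0:
            leading_dim = len(shape) + leading_dim
        return tuple("sym" if i != leading_dim else 1 for i in range(len(shape)))

    print(fake_stride_pattern((8, 128)))                 # ('sym', 1)  row-major-like
    print(fake_stride_pattern((8, 128), leading_dim=0))  # (1, 'sym')  column-major-like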