quack-kernels 0.2.0__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {quack_kernels-0.2.0/quack_kernels.egg-info → quack_kernels-0.2.2}/PKG-INFO +3 -3
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/pyproject.toml +2 -2
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/__init__.py +1 -1
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/activation.py +16 -25
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/autotuner.py +64 -5
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/cross_entropy.py +6 -10
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/cute_dsl_utils.py +6 -7
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/dense_gemm_sm90.py +582 -287
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/gemm_act_sm90.py +70 -29
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/gemm_dact_sm90.py +43 -10
- quack_kernels-0.2.2/quack/gemm_interface.py +892 -0
- quack_kernels-0.2.0/quack/dense_gemm_sm100.py → quack_kernels-0.2.2/quack/gemm_sm100.py +443 -419
- quack_kernels-0.2.2/quack/gemm_wrapper_utils.py +315 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/layernorm.py +1 -1
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/reduce.py +6 -7
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/rmsnorm.py +126 -158
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/softmax.py +1 -1
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/tile_scheduler.py +37 -49
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/utils.py +61 -71
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/varlen_utils.py +1 -6
- {quack_kernels-0.2.0 → quack_kernels-0.2.2/quack_kernels.egg-info}/PKG-INFO +3 -3
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack_kernels.egg-info/SOURCES.txt +3 -1
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack_kernels.egg-info/requires.txt +1 -1
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_linear.py +6 -1
- quack_kernels-0.2.2/tests/test_linear_varlen_k.py +266 -0
- quack_kernels-0.2.2/tests/test_linear_varlen_m.py +376 -0
- quack_kernels-0.2.0/quack/gemm_interface.py +0 -569
- quack_kernels-0.2.0/quack/gemm_wrapper_utils.py +0 -158
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/LICENSE +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/README.md +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/fast_math.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/gemm_config.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/linear.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/linear_cross_entropy.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/mlp.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/pipeline.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/reduction_base.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/sort/bitonic_sort.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/sort/generate_sorting_networks.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/sort/sorting_networks.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/sort/utils.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/symmetric_dense_gemm_sm90.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/tensormap_manager.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/topk.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack_kernels.egg-info/dependency_links.txt +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack_kernels.egg-info/top_level.txt +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/setup.cfg +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_cross_entropy.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_layernorm.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_linear_cross_entropy.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_rmsnorm.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_softmax.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_symmetric_dense_gemm_sm90.py +0 -0
- {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_topk.py +0 -0
PKG-INFO

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.2.0
-Requires-Python: >=3.
+Version: 0.2.2
+Requires-Python: >=3.10
 License-File: LICENSE
-Requires-Dist: nvidia-cutlass-dsl==4.2.
+Requires-Dist: nvidia-cutlass-dsl==4.2.1
 Requires-Dist: torch
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
pyproject.toml

@@ -5,9 +5,9 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "quack-kernels"
 dynamic = ["version"]
-requires-python = ">=3.
+requires-python = ">=3.10"
 dependencies = [
-    "nvidia-cutlass-dsl==4.2.
+    "nvidia-cutlass-dsl==4.2.1",
     "torch",
 ]
 
quack/activation.py

@@ -6,23 +6,12 @@ from typing import Tuple
 import cutlass
 import cutlass.cute as cute
 from cutlass import Float32
-from cutlass.cutlass_dsl import
-from cutlass._mlir.dialects import llvm
+from cutlass.cutlass_dsl import dsl_user_op
 
 
 @dsl_user_op
-def
-    return
-        llvm.inline_asm(
-            T.f32(),
-            [Float32(a).ir_value(loc=loc, ip=ip)],
-            "tanh.approx.f32 $0, $1;",
-            "=f,f",
-            has_side_effects=False,
-            is_align_stack=False,
-            asm_dialect=llvm.AsmDialect.AD_ATT,
-        )
-    )
+def sigmoid(x: Float32, *, loc=None, ip=None) -> Float32:
+    return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
 
 
 @dsl_user_op
@@ -67,7 +56,10 @@ def gelu_tanh_approx(x: Float32, *, loc=None, ip=None) -> Float32:
     """
     sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
-    return 0.5 * (
+    return 0.5 * (
+        x
+        * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
+    )
 
 
 @dsl_user_op
@@ -88,7 +80,7 @@ def dgelu_tanh_approx(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[
 
     # Compute z = x * (c1 + c2 * x^2)
     x_sq = x * x
-    tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
+    tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
     half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
     gelu_out = x * half_tanh_z_plus_one
 
@@ -111,7 +103,7 @@ def silu(x: Float32, *, loc=None, ip=None) -> Float32:
     This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
     """
     x_half = 0.5 * x
-    return x_half * tanh(x_half) + x_half
+    return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
 
 
 @dsl_user_op
@@ -134,8 +126,8 @@ def dswiglu(
     to use FFMA instead of FADD and FMUL).
     """
     # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
-
-    sigmoid_x =
+    # FMUL, MUFU.TANH, then FFMA
+    sigmoid_x = sigmoid(x)
     silu_x = x * sigmoid_x  # FMUL
     silu_x_dout = silu_x * dout  # FMUL
     # d_silu(x) * dout
@@ -161,7 +153,7 @@ def swiglu_oai(x: Float32, y: Float32, alpha: float = 1.702, *, loc=None, ip=Non
     """
     # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
     x_half = 0.5 * x
-    silu_x = x_half * tanh(alpha * x_half) + x_half
+    silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
     return silu_x * y + silu_x
 
 
@@ -179,7 +171,8 @@ def dswiglu_oai(
     """
     # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
     alpha_x_half = (0.5 * alpha) * x  # FMUL
-
+    # MUFU.TANH, then FFMA
+    sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
     silu_x = x * sigmoid_alpha_x  # FMUL
     silu_x_dout = silu_x * dout  # FMUL
     # FFMA, FFMA, FMUL
@@ -197,8 +190,7 @@ def glu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
     glu(x, y) = sigmoid(x) * y
     Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
     """
-
-    sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+    sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
     return sigmoid_x * y  # FMUL
 
 
@@ -215,8 +207,7 @@ def dglu(
     - glu_out = sigmoid(x) * y
     """
     # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
-
-    sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+    sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
     sigmoid_x_dout = sigmoid_x * dout  # FMUL
     glu_out = sigmoid_x * y  # FMUL
     # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
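A note on the activation.py changes above: the hand-written `tanh.approx.f32` inline assembly is dropped in favor of `cute.math.tanh(..., fastmath=True)`, while sigmoid and SiLU keep the tanh identities sigmoid(x) = 0.5 + 0.5*tanh(0.5*x) and silu(x) = 0.5*x*tanh(0.5*x) + 0.5*x (one FMUL, one MUFU.TANH, one FFMA on the GPU). The sketch below only checks those identities numerically with NumPy; it is illustrative and does not use the CuTe DSL.

    import numpy as np

    x = np.linspace(-6.0, 6.0, 101)

    # sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x)
    sigmoid_ref = 1.0 / (1.0 + np.exp(-x))
    sigmoid_tanh = 0.5 + 0.5 * np.tanh(0.5 * x)
    assert np.allclose(sigmoid_ref, sigmoid_tanh)

    # silu(x) = x * sigmoid(x) = 0.5*x*tanh(0.5*x) + 0.5*x
    x_half = 0.5 * x
    silu_ref = x * sigmoid_ref
    silu_tanh = x_half * np.tanh(x_half) + x_half
    assert np.allclose(silu_ref, silu_tanh)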
quack/autotuner.py

@@ -11,7 +11,7 @@ import hashlib
 import json
 from pathlib import Path
 from functools import cached_property, partial
-from typing import Dict, Tuple
+from typing import Dict, Tuple, List, Optional, Any
 
 import torch
 from torch import Tensor
@@ -53,7 +53,22 @@ def _base32(key):
 
 
 class Autotuner:
-    def __init__(
+    def __init__(
+        self,
+        fn,
+        key,
+        configs,
+        restore_value=None,
+        prune_configs_by: Optional[Dict] = None,
+        do_bench=None,
+        cache_results=False,
+    ):
+        """
+        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+            'perf_model': performance model used to predicate running time with different configs, returns running time
+            'top_k': number of configs to bench
+            'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
+        """
         if not configs:
             self.configs = [AutotuneConfig()]
         else:
@@ -90,6 +105,16 @@ class Autotuner:
         else:
             self.post_hook = None
 
+        self.perf_model = None
+        self.configs_top_k = 1.0
+        self.early_config_prune = None
+        if prune_configs_by:
+            self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
+            self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
+            self.early_config_prune = prune_configs_by.get(
+                "early_config_prune", self.early_config_prune
+            )
+
         self.fn = fn
         self._do_bench = do_bench
 
@@ -198,13 +223,14 @@ class Autotuner:
         key = tuple(key)
         if key not in self.cache:
             used_cached_result = False
+            pruned_configs = self.prune_configs(kwargs)
 
             @torch.compiler.disable  # Don't want any tracing here
             def benchmark():
                 bench_start = time.time()
                 timings = {
                     config: self._bench(*args, config=config, **kwargs)
-                    for config in
+                    for config in pruned_configs
                 }
                 bench_end = time.time()
                 if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
@@ -215,7 +241,7 @@ class Autotuner:
             self.configs_timings = timings
 
             if self.cache_results:
-                self.check_disk_cache(key,
+                self.check_disk_cache(key, pruned_configs, benchmark)
             else:
                 benchmark()
 
@@ -239,6 +265,32 @@ class Autotuner:
         self.nargs = None
         return ret
 
+    def prune_configs(self, kwargs: Dict) -> List[Any]:
+        pruned_configs = self.configs
+        if self.early_config_prune:
+            pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+        if self.perf_model:
+            top_k = self.configs_top_k
+            if isinstance(top_k, float) and top_k <= 1.0:
+                top_k = int(len(self.configs) * top_k)
+            elif not isinstance(top_k, int):
+                # Slice index must be an integer
+                raise TypeError(
+                    "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
+                )
+
+            if len(pruned_configs) > top_k:
+                est_timing = {
+                    config: self.perf_model(
+                        **self.nargs,
+                        **kwargs,
+                        **config.all_kwargs(),
+                    )
+                    for config in pruned_configs
+                }
+                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+        return pruned_configs
+
 
 class AutotuneConfig:
     """
@@ -272,7 +324,9 @@ class AutotuneConfig:
         return self_tuple == other_tuple
 
 
-def autotune(
+def autotune(
+    configs, key=None, prune_configs_by=None, restore_value=None, do_bench=None, cache_results=True
+):
     f"""
     Decorator for auto-tuning a function function.
 
@@ -286,6 +340,10 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
     :type configs: list[AutotuneConfig]
     :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
     :type key: list[str]
+    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+        'perf_model': performance model used to predicate running time with different configs, returns running time
+        'top_k': number of configs to bench
+        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
     :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
     :type restore_value: list[str]
     :param do_bench: a benchmark function to measure the time of each run.
@@ -303,6 +361,7 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
         key,
         configs,
         restore_value=restore_value,
+        prune_configs_by=prune_configs_by,
        do_bench=do_bench,
        cache_results=cache_results,
    )
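A note on the autotuner.py changes above: `autotune` now accepts a `prune_configs_by` dict (in the style of Triton's autotuner) with a `perf_model`, a `top_k`, and an optional `early_config_prune` callable, and configs are filtered through the new `Autotuner.prune_configs` before benchmarking. A minimal usage sketch follows; it assumes the names import from `quack.autotuner`, that `AutotuneConfig` accepts a kwargs dict like Triton's `Config`, and it uses a made-up kernel `my_kernel` and perf model purely for illustration.

    from quack.autotuner import autotune, AutotuneConfig  # assumed import path

    def flops_model(*, M, N, block_m, **kwargs):
        # Hypothetical perf model: lower estimated time ranks a config higher.
        return (M * N) / block_m

    @autotune(
        configs=[AutotuneConfig(dict(block_m=bm)) for bm in (64, 128, 256)],  # assumed ctor signature
        key=["M", "N"],
        prune_configs_by={"perf_model": flops_model, "top_k": 2},
    )
    def my_kernel(x, M, N, block_m=128):
        ...  # launch the actual kernel here

With this, only the two configs ranked best by `flops_model` are actually benchmarked for each new (M, N) key.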
quack/cross_entropy.py

@@ -199,11 +199,8 @@ class CrossEntropy(ReductionBase):
         cute.autovec_copy(tXsX, tXrX)
         x = tXrX.load().to(Float32)
         log2_e = math.log2(math.e)
-        # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-        # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
-        # exp_x = utils.exp2f((x - max_x) * log2_e)
         # This would use ffma instead of fadd then fmul
-        exp_x =
+        exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=False)
         denom = row_reduce(
             exp_x,
             cute.ReductionOp.ADD,
@@ -228,8 +225,7 @@ class CrossEntropy(ReductionBase):
             and row < shape[0]
             and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
         ):
-
-            lse = max_x + utils.log2f(denom) * ln_2
+            lse = max_x + cute.math.log(denom, fastmath=True)
             # Set loss to 0 if this index should be ignored, otherwise compute normally
             loss_val = (lse - target_logit) if not should_ignore else Float32.zero
             mLoss[row] = mLoss.element_type(loss_val)
@@ -552,7 +548,7 @@ class CrossEntropyBackward:
         lse = Float32(mLSE[row])
 
         log2_e = math.log2(math.e)
-        probs =
+        probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True)
         prob_shifted = probs - 1.0
         mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
         for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
@@ -594,9 +590,9 @@ def _cross_entropy_backward(
     assert x.shape[0] == target.shape[0], "Batch dimensions must match"
     assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
     assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
-    assert (
-
-    )
+    assert x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, (
+        "Tensors must be on CUDA device"
+    )
     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
 
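A note on the cross_entropy.py changes above: the softmax denominator is computed in base 2, exp(x - max_x) = exp2(x*log2(e) - max_x*log2(e)), written so the subtract-and-scale can be fused into an FFMA, and the log-sum-exp then becomes lse = max_x + log(denom); the backward pass recovers the probabilities the same way from the stored lse. A plain NumPy sketch of that arithmetic (illustrative only, not the kernel code):

    import math
    import numpy as np

    logits = np.array([2.0, -1.0, 0.5, 3.0], dtype=np.float32)
    log2_e = math.log2(math.e)

    max_x = logits.max()
    # exp(x - max_x) via exp2, matching exp_x = exp2(x*log2_e - max_x*log2_e)
    exp_x = np.exp2(logits * log2_e - max_x * log2_e)
    denom = exp_x.sum()
    lse = max_x + math.log(denom)                     # forward: log-sum-exp
    probs = np.exp2(logits * log2_e - lse * log2_e)   # backward: softmax probabilities

    assert np.isclose(lse, np.log(np.exp(logits).sum()))
    assert np.allclose(probs, np.exp(logits) / np.exp(logits).sum())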
quack/cute_dsl_utils.py

@@ -98,22 +98,21 @@ class ArgumentsBase(JitArgument):
 
 
 def load_cubin_module_data_patched(cubin_data, filepath):
-
-    path.write_bytes(cubin_data)
+    pathlib.Path(filepath).write_bytes(cubin_data)
     return load_cubin_module_data_og(cubin_data)
 
 
 def cute_compile_patched(*args, **kwargs):
     """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set."""
-
+    cubin_path = os.getenv("CUTE_CUBIN_PATH", None)
+    if cubin_path is not None:
         cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
-            load_cubin_module_data_patched, filepath=
+            load_cubin_module_data_patched, filepath=cubin_path
         )
     output = cute_compile_og(*args, **kwargs)
-    if
+    if cubin_path is not None:
         cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
         if extract is not None:
-            cubin_path = pathlib.Path(os.getenv("CUTE_CUBIN_PATH"))
             sass = extract(cubin_path, None)
-            cubin_path.with_suffix(".annotated.sass").write_text(sass)
+            pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
     return output