quack-kernels 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {quack_kernels-0.2.0/quack_kernels.egg-info → quack_kernels-0.2.2}/PKG-INFO +3 -3
  2. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/pyproject.toml +2 -2
  3. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/__init__.py +1 -1
  4. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/activation.py +16 -25
  5. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/autotuner.py +64 -5
  6. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/cross_entropy.py +6 -10
  7. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/cute_dsl_utils.py +6 -7
  8. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/dense_gemm_sm90.py +582 -287
  9. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/gemm_act_sm90.py +70 -29
  10. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/gemm_dact_sm90.py +43 -10
  11. quack_kernels-0.2.2/quack/gemm_interface.py +892 -0
  12. quack_kernels-0.2.0/quack/dense_gemm_sm100.py → quack_kernels-0.2.2/quack/gemm_sm100.py +443 -419
  13. quack_kernels-0.2.2/quack/gemm_wrapper_utils.py +315 -0
  14. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/layernorm.py +1 -1
  15. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/reduce.py +6 -7
  16. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/rmsnorm.py +126 -158
  17. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/softmax.py +1 -1
  18. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/tile_scheduler.py +37 -49
  19. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/utils.py +61 -71
  20. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/varlen_utils.py +1 -6
  21. {quack_kernels-0.2.0 → quack_kernels-0.2.2/quack_kernels.egg-info}/PKG-INFO +3 -3
  22. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack_kernels.egg-info/SOURCES.txt +3 -1
  23. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack_kernels.egg-info/requires.txt +1 -1
  24. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_linear.py +6 -1
  25. quack_kernels-0.2.2/tests/test_linear_varlen_k.py +266 -0
  26. quack_kernels-0.2.2/tests/test_linear_varlen_m.py +376 -0
  27. quack_kernels-0.2.0/quack/gemm_interface.py +0 -569
  28. quack_kernels-0.2.0/quack/gemm_wrapper_utils.py +0 -158
  29. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/LICENSE +0 -0
  30. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/README.md +0 -0
  31. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/fast_math.py +0 -0
  32. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/gemm_config.py +0 -0
  33. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/linear.py +0 -0
  34. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/linear_cross_entropy.py +0 -0
  35. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/mlp.py +0 -0
  36. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/pipeline.py +0 -0
  37. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/reduction_base.py +0 -0
  38. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/sort/bitonic_sort.py +0 -0
  39. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/sort/generate_sorting_networks.py +0 -0
  40. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/sort/sorting_networks.py +0 -0
  41. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/sort/utils.py +0 -0
  42. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/symmetric_dense_gemm_sm90.py +0 -0
  43. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/tensormap_manager.py +0 -0
  44. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack/topk.py +0 -0
  45. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack_kernels.egg-info/dependency_links.txt +0 -0
  46. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/quack_kernels.egg-info/top_level.txt +0 -0
  47. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/setup.cfg +0 -0
  48. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_cross_entropy.py +0 -0
  49. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_layernorm.py +0 -0
  50. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_linear_cross_entropy.py +0 -0
  51. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_rmsnorm.py +0 -0
  52. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_softmax.py +0 -0
  53. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_symmetric_dense_gemm_sm90.py +0 -0
  54. {quack_kernels-0.2.0 → quack_kernels-0.2.2}/tests/test_topk.py +0 -0

PKG-INFO
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.2.0
- Requires-Python: >=3.12
+ Version: 0.2.2
+ Requires-Python: >=3.10
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.2.0
+ Requires-Dist: nvidia-cutlass-dsl==4.2.1
  Requires-Dist: torch
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"

pyproject.toml
@@ -5,9 +5,9 @@ build-backend = "setuptools.build_meta"
  [project]
  name = "quack-kernels"
  dynamic = ["version"]
- requires-python = ">=3.12"
+ requires-python = ">=3.10"
  dependencies = [
-     "nvidia-cutlass-dsl==4.2.0",
+     "nvidia-cutlass-dsl==4.2.1",
      "torch",
  ]
 

quack/__init__.py
@@ -1,4 +1,4 @@
- __version__ = "0.2.0"
+ __version__ = "0.2.2"
 
  import cutlass.cute as cute
 

quack/activation.py
@@ -6,23 +6,12 @@ from typing import Tuple
  import cutlass
  import cutlass.cute as cute
  from cutlass import Float32
- from cutlass.cutlass_dsl import T, dsl_user_op
- from cutlass._mlir.dialects import llvm
+ from cutlass.cutlass_dsl import dsl_user_op
 
 
  @dsl_user_op
- def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
-     return Float32(
-         llvm.inline_asm(
-             T.f32(),
-             [Float32(a).ir_value(loc=loc, ip=ip)],
-             "tanh.approx.f32 $0, $1;",
-             "=f,f",
-             has_side_effects=False,
-             is_align_stack=False,
-             asm_dialect=llvm.AsmDialect.AD_ATT,
-         )
-     )
+ def sigmoid(x: Float32, *, loc=None, ip=None) -> Float32:
+     return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
 
 
  @dsl_user_op
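
The new sigmoid helper leans on the identity sigmoid(x) = 0.5 + 0.5 * tanh(x / 2), which the later comments describe as an FMUL, a MUFU.TANH, and an FFMA, replacing the old inline-asm tanh wrapper. A quick sanity check of that identity in plain Python (illustrative only; math.tanh stands in for the fastmath cute.math.tanh):

import math

def sigmoid_ref(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_via_tanh(x):
    # Same rewrite as the new quack.activation.sigmoid, minus fastmath
    return 0.5 + 0.5 * math.tanh(0.5 * x)

for x in (-6.0, -1.0, 0.0, 0.5, 4.0):
    assert abs(sigmoid_ref(x) - sigmoid_via_tanh(x)) < 1e-12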

@@ -67,7 +56,10 @@ def gelu_tanh_approx(x: Float32, *, loc=None, ip=None) -> Float32:
      """
      sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
      sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
-     return 0.5 * (x * (1 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)))))
+     return 0.5 * (
+         x
+         * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
+     )
 
 
  @dsl_user_op
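
This is the usual tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), factored so the inner polynomial becomes x * (c1 + c2 * x^2). A minimal comparison against the exact erf-based GELU (plain Python; the tolerance is illustrative):

import math

def gelu_exact(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh_ref(x):
    c1 = math.sqrt(2.0 / math.pi)   # ~0.797885
    c2 = 0.044715 * c1              # ~0.0356774
    return 0.5 * (x * (1.0 + math.tanh(x * (c1 + c2 * (x * x)))))

for x in (-3.0, -0.5, 0.0, 1.0, 2.5):
    assert abs(gelu_exact(x) - gelu_tanh_ref(x)) < 2e-3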

@@ -88,7 +80,7 @@ def dgelu_tanh_approx(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[
 
      # Compute z = x * (c1 + c2 * x^2)
      x_sq = x * x
-     tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
+     tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
      half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
      gelu_out = x * half_tanh_z_plus_one
 
@@ -111,7 +103,7 @@ def silu(x: Float32, *, loc=None, ip=None) -> Float32:
      This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
      """
      x_half = 0.5 * x
-     return x_half * tanh(x_half) + x_half
+     return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
 
 
  @dsl_user_op
@@ -134,8 +126,8 @@ def dswiglu(
      to use FFMA instead of FADD and FMUL).
      """
      # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
-     x_half = 0.5 * x  # FMUL
-     sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+     # FMUL, MUFU.TANH, then FFMA
+     sigmoid_x = sigmoid(x)
      silu_x = x * sigmoid_x  # FMUL
      silu_x_dout = silu_x * dout  # FMUL
      # d_silu(x) * dout
@@ -161,7 +153,7 @@ def swiglu_oai(x: Float32, y: Float32, alpha: float = 1.702, *, loc=None, ip=Non
      """
      # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
      x_half = 0.5 * x
-     silu_x = x_half * tanh(alpha * x_half) + x_half
+     silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
      return silu_x * y + silu_x
 
 
@@ -179,7 +171,8 @@ def dswiglu_oai(
      """
      # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
      alpha_x_half = (0.5 * alpha) * x  # FMUL
-     sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)  # MUFU.TANH, then FFMA
+     # MUFU.TANH, then FFMA
+     sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
      silu_x = x * sigmoid_alpha_x  # FMUL
      silu_x_dout = silu_x * dout  # FMUL
      # FFMA, FFMA, FMUL
@@ -197,8 +190,7 @@ def glu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
      glu(x, y) = sigmoid(x) * y
      Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
      """
-     x_half = 0.5 * x  # FMUL
-     sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+     sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
      return sigmoid_x * y  # FMUL
 
 
@@ -215,8 +207,7 @@ def dglu(
      - glu_out = sigmoid(x) * y
      """
      # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
-     x_half = 0.5 * x  # FMUL
-     sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+     sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
      sigmoid_x_dout = sigmoid_x * dout  # FMUL
      glu_out = sigmoid_x * y  # FMUL
      # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
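
The comment above spells out the GLU backward rule: for glu_out = sigmoid(x) * y, the gradients are dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout and dy = sigmoid(x) * dout. A small finite-difference check of those formulas in plain Python (independent of the kernel code; the test values are arbitrary):

import math

def sigmoid(x):
    return 0.5 + 0.5 * math.tanh(0.5 * x)

def glu(x, y):
    return sigmoid(x) * y

def dglu(x, y, dout):
    s = sigmoid(x)
    return y * s * (1.0 - s) * dout, s * dout  # (dx, dy)

x, y, dout, eps = 0.7, -1.3, 2.0, 1e-6
dx, dy = dglu(x, y, dout)
dx_num = (glu(x + eps, y) - glu(x - eps, y)) / (2 * eps) * dout
dy_num = (glu(x, y + eps) - glu(x, y - eps)) / (2 * eps) * dout
assert abs(dx - dx_num) < 1e-5 and abs(dy - dy_num) < 1e-5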

quack/autotuner.py
@@ -11,7 +11,7 @@ import hashlib
  import json
  from pathlib import Path
  from functools import cached_property, partial
- from typing import Dict, Tuple
+ from typing import Dict, Tuple, List, Optional, Any
 
  import torch
  from torch import Tensor
@@ -53,7 +53,22 @@ def _base32(key):
 
 
  class Autotuner:
-     def __init__(self, fn, key, configs, restore_value=None, do_bench=None, cache_results=False):
+     def __init__(
+         self,
+         fn,
+         key,
+         configs,
+         restore_value=None,
+         prune_configs_by: Optional[Dict] = None,
+         do_bench=None,
+         cache_results=False,
+     ):
+         """
+         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+             'perf_model': performance model used to predicate running time with different configs, returns running time
+             'top_k': number of configs to bench
+             'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
+         """
          if not configs:
              self.configs = [AutotuneConfig()]
          else:
@@ -90,6 +105,16 @@ class Autotuner:
          else:
              self.post_hook = None
 
+         self.perf_model = None
+         self.configs_top_k = 1.0
+         self.early_config_prune = None
+         if prune_configs_by:
+             self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
+             self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
+             self.early_config_prune = prune_configs_by.get(
+                 "early_config_prune", self.early_config_prune
+             )
+
          self.fn = fn
          self._do_bench = do_bench
 
@@ -198,13 +223,14 @@ class Autotuner:
          key = tuple(key)
          if key not in self.cache:
              used_cached_result = False
+             pruned_configs = self.prune_configs(kwargs)
 
              @torch.compiler.disable  # Don't want any tracing here
              def benchmark():
                  bench_start = time.time()
                  timings = {
                      config: self._bench(*args, config=config, **kwargs)
-                     for config in self.configs
+                     for config in pruned_configs
                  }
                  bench_end = time.time()
                  if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
@@ -215,7 +241,7 @@ class Autotuner:
                  self.configs_timings = timings
 
              if self.cache_results:
-                 self.check_disk_cache(key, self.configs, benchmark)
+                 self.check_disk_cache(key, pruned_configs, benchmark)
              else:
                  benchmark()
 
@@ -239,6 +265,32 @@ class Autotuner:
          self.nargs = None
          return ret
 
+     def prune_configs(self, kwargs: Dict) -> List[Any]:
+         pruned_configs = self.configs
+         if self.early_config_prune:
+             pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+         if self.perf_model:
+             top_k = self.configs_top_k
+             if isinstance(top_k, float) and top_k <= 1.0:
+                 top_k = int(len(self.configs) * top_k)
+             elif not isinstance(top_k, int):
+                 # Slice index must be an integer
+                 raise TypeError(
+                     "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
+                 )
+
+             if len(pruned_configs) > top_k:
+                 est_timing = {
+                     config: self.perf_model(
+                         **self.nargs,
+                         **kwargs,
+                         **config.all_kwargs(),
+                     )
+                     for config in pruned_configs
+                 }
+                 pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+         return pruned_configs
+
 
  class AutotuneConfig:
      """
@@ -272,7 +324,9 @@ class AutotuneConfig:
          return self_tuple == other_tuple
 
 
- def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results=True):
+ def autotune(
+     configs, key=None, prune_configs_by=None, restore_value=None, do_bench=None, cache_results=True
+ ):
      f"""
      Decorator for auto-tuning a function function.
 
@@ -286,6 +340,10 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
      :type configs: list[AutotuneConfig]
      :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
      :type key: list[str]
+     :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+         'perf_model': performance model used to predicate running time with different configs, returns running time
+         'top_k': number of configs to bench
+         'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
      :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
      :type restore_value: list[str]
      :param do_bench: a benchmark function to measure the time of each run.
@@ -303,6 +361,7 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
              key,
              configs,
              restore_value=restore_value,
+             prune_configs_by=prune_configs_by,
              do_bench=do_bench,
              cache_results=cache_results,
          )
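
Putting the new pieces together: prune_configs_by lets a caller rank candidate configs with a cheap perf_model and only benchmark the top_k best, optionally after an early_config_prune pass. A hedged usage sketch based on the signature above; the tuned function, its tile_m/tile_n config fields, and the cost model are made-up placeholders, only the decorator arguments come from this diff:

from quack.autotuner import autotune, AutotuneConfig

def flops_model(M=None, N=None, K=None, tile_m=None, tile_n=None, **kwargs):
    # Hypothetical cost model: called with the function's named args plus each
    # config's kwargs; must return an estimated running time (lower is better).
    return (M * N * K) / (tile_m * tile_n)

configs = [AutotuneConfig(tile_m=m, tile_n=n) for m in (64, 128) for n in (64, 128)]

@autotune(
    configs,
    key=["M", "N", "K"],
    prune_configs_by={"perf_model": flops_model, "top_k": 2},
)
def my_gemm(A, B, M, N, K, tile_m=128, tile_n=128):
    ...  # launch the kernel with the chosen tile sizes

Per prune_configs, top_k may also be a float <= 1.0, in which case it is treated as a fraction of the config list rather than an absolute count.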

quack/cross_entropy.py
@@ -199,11 +199,8 @@ class CrossEntropy(ReductionBase):
          cute.autovec_copy(tXsX, tXrX)
          x = tXrX.load().to(Float32)
          log2_e = math.log2(math.e)
-         # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-         # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
-         # exp_x = utils.exp2f((x - max_x) * log2_e)
          # This would use ffma instead of fadd then fmul
-         exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
+         exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=False)
          denom = row_reduce(
              exp_x,
              cute.ReductionOp.ADD,
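
Both the old utils.exp2f call and the new cute.math.exp2 call compute exp(x - max_x) through a base-2 exponential, and writing the argument as x * log2_e - (max_x * log2_e) keeps the per-element work to a single fused multiply-add feeding the exponential, as the comment notes. The identity behind it, in plain Python:

import math

log2_e = math.log2(math.e)
xs = [-3.0, -0.25, 1.5, 7.0]
max_x = max(xs)

for x in xs:
    direct = math.exp(x - max_x)                        # what the kernel needs
    via_exp2 = 2.0 ** (x * log2_e - (max_x * log2_e))   # FFMA-friendly form
    assert abs(direct - via_exp2) < 1e-12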

@@ -228,8 +225,7 @@ class CrossEntropy(ReductionBase):
              and row < shape[0]
              and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
          ):
-             ln_2 = math.log(2.0)
-             lse = max_x + utils.log2f(denom) * ln_2
+             lse = max_x + cute.math.log(denom, fastmath=True)
              # Set loss to 0 if this index should be ignored, otherwise compute normally
              loss_val = (lse - target_logit) if not should_ignore else Float32.zero
              mLoss[row] = mLoss.element_type(loss_val)
@@ -552,7 +548,7 @@ class CrossEntropyBackward:
          lse = Float32(mLSE[row])
 
          log2_e = math.log2(math.e)
-         probs = utils.exp2f(x * log2_e - lse * log2_e)
+         probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True)
          prob_shifted = probs - 1.0
          mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
          for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
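
In the backward kernel, probs = exp2(x * log2_e - lse * log2_e) is softmax(x) along the row (lse being the row's log-sum-exp), and prob_shifted = probs - 1.0 is the correction applied at the target column. A plain-Python restatement of the standard cross-entropy gradient these lines compute, not the kernel itself:

import math

def ce_grad_row(logits, target, dloss):
    # d/dx_i of cross-entropy loss = (softmax(x)_i - [i == target]) * dloss
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [(math.exp(v - lse) - (1.0 if i == target else 0.0)) * dloss
            for i, v in enumerate(logits)]

grad = ce_grad_row([1.0, 2.0, 0.5], target=1, dloss=1.0)
assert abs(sum(grad)) < 1e-12  # softmax CE gradients sum to zero over a row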

@@ -594,9 +590,9 @@ def _cross_entropy_backward(
      assert x.shape[0] == target.shape[0], "Batch dimensions must match"
      assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
      assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
-     assert (
-         x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
-     ), "Tensors must be on CUDA device"
+     assert x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, (
+         "Tensors must be on CUDA device"
+     )
      assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
      assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
 

quack/cute_dsl_utils.py
@@ -98,22 +98,21 @@ class ArgumentsBase(JitArgument):
 
 
  def load_cubin_module_data_patched(cubin_data, filepath):
-     path = pathlib.Path(filepath)
-     path.write_bytes(cubin_data)
+     pathlib.Path(filepath).write_bytes(cubin_data)
      return load_cubin_module_data_og(cubin_data)
 
 
  def cute_compile_patched(*args, **kwargs):
      """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set."""
-     if os.getenv("CUTE_CUBIN_PATH") is not None:
+     cubin_path = os.getenv("CUTE_CUBIN_PATH", None)
+     if cubin_path is not None:
          cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
-             load_cubin_module_data_patched, filepath=os.getenv("CUTE_CUBIN_PATH")
+             load_cubin_module_data_patched, filepath=cubin_path
          )
      output = cute_compile_og(*args, **kwargs)
-     if os.getenv("CUTE_CUBIN_PATH") is not None:
+     if cubin_path is not None:
          cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
          if extract is not None:
-             cubin_path = pathlib.Path(os.getenv("CUTE_CUBIN_PATH"))
              sass = extract(cubin_path, None)
-             cubin_path.with_suffix(".annotated.sass").write_text(sass)
+             pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
      return output
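
With this in place, pointing CUTE_CUBIN_PATH at a file makes the patched compile path write the raw cubin there and, when the SASS extractor import succeeded, an .annotated.sass file next to it. A hedged sketch of how the environment variable would be used (the kernel object and the compile call are placeholders, and cute_compile_patched is presumably installed in place of cute.compile elsewhere in the module):

import os

os.environ["CUTE_CUBIN_PATH"] = "/tmp/my_kernel.cubin"

# ... build the cute kernel/host function as usual, then compile through the
# patched entry point:
# compiled = cute.compile(kernel, *example_args)

# Afterwards /tmp/my_kernel.cubin holds the cubin; if the SASS extractor is
# available, /tmp/my_kernel.annotated.sass holds the annotated SASS.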