mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/utils.py
@@ -0,0 +1,860 @@
+ # @nolint # fbcode
+ # Copyright (c) 2025, Tri Dao.
+
+ import math
+ import hashlib
+ import inspect
+ import re
+ from typing import Type, Callable, Optional, Tuple, overload
+ from functools import partial
+
+ import cutlass
+ import cutlass.cute as cute
+
+ from cutlass import Float32, const_expr
+ from cutlass.cutlass_dsl import T, dsl_user_op
+ from cutlass._mlir.dialects import nvvm, llvm
+ from cutlass.cute.runtime import from_dlpack
+
+
+ # cute.arch.{fma,mul,add}_packed_f32x2 use RZ rounding mode by default
+ fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
+ mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
+ add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
+ sub_packed_f32x2 = partial(
+     cute.arch.calc_packed_f32x2_op,
+     src_c=None,
+     calc_func=nvvm.sub_packed_f32x2,
+     rnd=nvvm.RoundingModeKind.RN,
+ )
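+
+ # Example (annotation, not part of the packaged file): these helpers operate
+ # lanewise on 2-tuples of f32 with round-to-nearest-even instead of the RZ
+ # default, e.g.
+ #
+ #     sub_packed_f32x2((a0, a1), (b0, b1))            # -> (a0 - b0, a1 - b1)
+ #     fma_packed_f32x2((a0, a1), (b0, b1), (c0, c1))  # -> (a0*b0 + c0, a1*b1 + c1)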
+
+
+ def hash_callable(func: Callable, set_cute_hash=True) -> str:
+     """Hash a callable based on the source code or bytecode and closure values.
+
+     Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__``
+     attribute, that value is returned immediately. Code-generation backends such
+     as Inductor can set this attribute to avoid expensive runtime hashing.
+
+     set_cute_hash: whether or not to set func.__cute_hash__ if not present
+     """
+     if hasattr(func, "__cute_hash__"):
+         return func.__cute_hash__
+
+     # Unwrap decorated functions (e.g., cute.jit wrappers).
+     if hasattr(func, "__wrapped__"):
+         base_func = func.__wrapped__
+         if hasattr(base_func, "__cute_hash__"):
+             return base_func.__cute_hash__
+         func = base_func
+
+     try:
+         data = inspect.getsource(func).encode()
+     except (OSError, TypeError):
+         if hasattr(func, "__code__") and func.__code__ is not None:
+             data = func.__code__.co_code
+         else:
+             data = repr(func).encode()
+
+     hasher = hashlib.sha256(data)
+
+     if hasattr(func, "__closure__") and func.__closure__ is not None:
+         for cell in func.__closure__:
+             cell_value = cell.cell_contents
+             hasher.update(repr(cell_value).encode())
+
+     digest = hasher.hexdigest()
+
+     if set_cute_hash:
+         func.__cute_hash__ = digest
+
+     return digest
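+
+ # Usage sketch (annotation, not part of the packaged file):
+ #
+ #     def score_mod(s, *args):
+ #         return s * 2.0
+ #
+ #     h1 = hash_callable(score_mod)   # sha256 over the function source
+ #     h2 = hash_callable(score_mod)   # fast path: cached in __cute_hash__
+ #     assert h1 == h2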
+
+
+ def create_softcap_scoremod(softcap_val):
+     inv_softcap = 1.0 / softcap_val
+
+     @cute.jit
+     def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
+         scores = acc_S_SSA * inv_softcap
+         return scores * cute.math.tanh(scores, fastmath=True)
+
+     return scoremod_premask_fn
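+
+ # Note (annotation, not part of the packaged file): `inv_softcap` is captured
+ # in the closure, so - assuming the cute.jit wrapper exposes __wrapped__ -
+ # hash_callable folds the softcap value into the digest, and score-mods built
+ # with different softcap values hash (and hence compile) separately:
+ #
+ #     assert hash_callable(create_softcap_scoremod(30.0)) != \
+ #            hash_callable(create_softcap_scoremod(50.0))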
+
+
+ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
+     return (
+         from_dlpack(x, assumed_align=alignment)
+         .mark_layout_dynamic(leading_dim=leading_dim)
+         .mark_compact_shape_dynamic(
+             mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility
+         )
+     )
+
+
+ def convert_from_dlpack_leading_static(
+     x, leading_dim, alignment=16, static_modes=None, stride_order=None
+ ) -> cute.Tensor:
+     if stride_order is None:
+         stride_order = x.dim_order()
+     x_ = from_dlpack(x, assumed_align=alignment)
+     for i in range(x.ndim):
+         if i != leading_dim and (static_modes is None or i not in static_modes):
+             x_ = x_.mark_compact_shape_dynamic(mode=i, stride_order=stride_order)
+     return x_
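+
+ # Usage sketch (annotation, not part of the packaged file; `q` is a
+ # hypothetical contiguous torch tensor of shape (batch, seqlen, nheads, headdim)):
+ #
+ #     q_cute = convert_from_dlpack(q, leading_dim=3, divisibility=8)
+ #
+ # which keeps mode 3 as the compact leading mode (extent asserted divisible by
+ # 8) and leaves the remaining modes dynamic.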
+
+
+ def make_tiled_copy_A(
+     copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
+ ) -> cute.TiledCopy:
+     if const_expr(swapAB):
+         return cute.make_tiled_copy_B(copy_atom, tiled_mma)
+     else:
+         return cute.make_tiled_copy_A(copy_atom, tiled_mma)
+
+
+ def make_tiled_copy_B(
+     copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
+ ) -> cute.TiledCopy:
+     if const_expr(swapAB):
+         return cute.make_tiled_copy_A(copy_atom, tiled_mma)
+     else:
+         return cute.make_tiled_copy_B(copy_atom, tiled_mma)
+
+
+ def mma_make_fragment_A(
+     smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
+ ) -> cute.Tensor:
+     if const_expr(swapAB):
+         return mma_make_fragment_B(smem, thr_mma)
+     else:
+         return thr_mma.make_fragment_A(thr_mma.partition_A(smem))
+
+
+ def mma_make_fragment_B(
+     smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
+ ) -> cute.Tensor:
+     if const_expr(swapAB):
+         return mma_make_fragment_A(smem, thr_mma)
+     else:
+         return thr_mma.make_fragment_B(thr_mma.partition_B(smem))
+
+
+ def get_smem_store_atom(
+     arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
+ ) -> cute.CopyAtom:
+     if const_expr(arch < 90 or element_type.width != 16):
+         return cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(),
+             element_type,
+             num_bits_per_copy=2 * element_type.width,
+         )
+     else:
+         return cute.make_copy_atom(
+             cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
+             element_type,
+         )
+
+
+ @cute.jit
+ def warp_reduce(
+     val: cute.TensorSSA | cute.Numeric,
+     op: Callable,
+     width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
+ ) -> cute.TensorSSA | cute.Numeric:
+     if const_expr(isinstance(val, cute.TensorSSA)):
+         res = cute.make_fragment(val.shape, val.dtype)
+         res.store(val)
+         for i in cutlass.range_constexpr(cute.size(val.shape)):
+             res[i] = warp_reduce(res[i], op, width)
+         return res.load()
+     else:
+         for i in cutlass.range_constexpr(int(math.log2(width))):
+             val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
+         return val
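+
+ # Usage sketch (annotation, not part of the packaged file): butterfly-reduce a
+ # per-thread scalar across the full warp, or across groups of 4 adjacent lanes:
+ #
+ #     total = warp_reduce(partial_sum, lambda a, b: a + b)         # 32 lanes
+ #     group_max = warp_reduce(local_max, cute.arch.fmax, width=4)  # 4-lane groups
+ #
+ # After log2(width) shuffle rounds every participating lane holds the result.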
+
+
+ def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout:
+     """
+     For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
+     For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
+     """
+     acc_layout_col_major = cute.make_layout(acc_layout.shape)
+     shape = (
+         (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
+         (
+             acc_layout_col_major.shape[0][0],
+             *acc_layout_col_major.shape[0][2:],
+             acc_layout_col_major.shape[2],
+         ),  # MMA_N
+         *acc_layout_col_major.shape[3:],
+     )
+     stride = (
+         (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
+         (
+             acc_layout_col_major.stride[0][0],
+             *acc_layout_col_major.stride[0][2:],
+             acc_layout_col_major.stride[2],
+         ),  # MMA_N
+         *acc_layout_col_major.stride[3:],
+     )
+     if const_expr(transpose):
+         shape = (shape[1], shape[0], *shape[2:])
+         stride = (stride[1], stride[0], *stride[2:])
+     acc_layout_mn = cute.make_layout(shape, stride=stride)
+     return cute.composition(acc_layout, acc_layout_mn)
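+
+ # Worked example (annotation, not part of the packaged file): on Sm80 an
+ # accumulator layout ((2, 2), 4, 8) becomes ((2, 4), (2, 8)); after the
+ # conversion, mode 0 walks the rows and mode 1 the columns of the tile, which
+ # is the view rowwise ops such as the softmax in this package want.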
+
+
+ def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
+     return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose))
+
+
+ @cute.jit
+ def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
+     # For back-to-back gemm, convert the layout of acc0 to a layout that gemm 1 accepts.
+     # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from
+     # (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2).
+     # For Sm90 FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N)
+     # to ((2, 2, 2), MMA_M, (N / 16, MMA_N)).
+     # TODO: Sm90 FP8
+     if const_expr(cute.rank(acc_layout.shape[0]) == 3):  # Sm90
+         l = cute.logical_divide(
+             acc_layout, ((None, None, 2), None, None)
+         )  # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
+         rA_mma_view = cute.make_layout(
+             (
+                 (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]),
+                 l.shape[1],
+                 (l.shape[0][2][1], l.shape[2]),
+             ),
+             stride=(
+                 (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]),
+                 l.stride[1],
+                 (l.stride[0][2][1], l.stride[2]),
+             ),
+         )
+     else:  # Sm80
+         # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
+         l = cute.logical_divide(acc_layout, (None, None, 2))
+         rA_mma_view = cute.make_layout(
+             (
+                 (l.shape[0], l.shape[2][0]),
+                 l.shape[1],
+                 l.shape[2][1],
+             ),
+             stride=(
+                 (l.stride[0], l.stride[2][0]),
+                 l.stride[1],
+                 l.stride[2][1],
+             ),
+         )
+     return rA_mma_view
+
+
+ def make_acc_tensor_frgA_view(acc: cute.Tensor) -> cute.Tensor:
+     return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout))
+
+
+ def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
+     return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
+
+
+ def transpose_view(a: cute.Tensor) -> cute.Tensor:
+     """Transpose the first two dimensions of a tensor on smem."""
+     shape = (a.shape[1], a.shape[0], *a.shape[2:])
+     order = (1, 0, *range(2, cute.rank(a)))
+     return cute.composition(a, cute.make_ordered_layout(shape, order=order))
+     # stride = (a.layout.stride[1], a.layout.stride[0], *a.layout.stride[2:])
+     # return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride))
+
+
+ def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle:
+     """Extract swizzle parameters from a pointer's swizzle_type.
+
+     The swizzle_type string has the form '!cute.swizzle<"S<b,m,s>">' where
+     b, m, s are the swizzle parameters (bits, base, shift).
+
+     Returns:
+         A cute.Swizzle object constructed from the extracted parameters
+
+     Raises:
+         ValueError: If the swizzle_type string cannot be parsed
+     """
+     # Ideally there should be a better API to get swizzle parameters, but we'll just parse
+     # the string here.
+     swizzle_str = str(ptr.type.swizzle_type)
+     # Extract the inner part "S<b,m,s>"
+     match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
+     if match:
+         b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3))
+         return cute.make_swizzle(b, m, s)
+     else:
+         raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
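+
+ # Worked example (annotation, not part of the packaged file): a pointer whose
+ # swizzle_type prints as '!cute.swizzle<"S<3,4,3>">' yields
+ # cute.make_swizzle(3, 4, 3).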
+
+
+ @cute.jit
+ def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
+     """exp2f calculation for both vector and scalar.
+
+     :param x: input value
+     :type x: cute.TensorSSA or Float32
+     :return: exp2 value
+     :rtype: cute.TensorSSA or Float32
+     """
+     if const_expr(isinstance(x, cute.TensorSSA)):
+         res = cute.make_fragment(x.shape, Float32)
+         res.store(x)
+         for i in cutlass.range_constexpr(cute.size(x.shape)):
+             res[i] = cute.arch.exp2(res[i])
+         return res.load()
+     else:
+         return cute.arch.exp2(x)
+
+
+ @dsl_user_op
+ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
+     return Float32(
+         llvm.inline_asm(
+             T.f32(),
+             [Float32(a).ir_value(loc=loc, ip=ip)],
+             "lg2.approx.ftz.f32 $0, $1;",
+             "=f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ @dsl_user_op
+ def logf(a: float | Float32, *, loc=None, ip=None) -> Float32:
+     return log2f(a, loc=loc, ip=ip) * math.log(2.0)
+
+
+ @dsl_user_op
+ def fmax(
+     a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
+ ) -> Float32:
+     return Float32(
+         nvvm.fmax(
+             T.f32(),
+             Float32(a).ir_value(loc=loc, ip=ip),
+             Float32(b).ir_value(loc=loc, ip=ip),
+             c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
+             loc=loc,
+             ip=ip,
+         )
+     )
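+
+ # Note (annotation, not part of the packaged file): logf relies on
+ # ln(x) = log2(x) * ln(2). The optional third operand of fmax lets callers
+ # such as fmax_reduce below emit a single three-input max on architectures
+ # that support it.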
+
+
+ @cute.jit
+ def fmax_reduce(
+     x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
+ ) -> Float32:
+     if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
+         # if const_expr(init_val is None):
+         #     init_val = -cutlass.Float32.inf
+         # return x.reduce(cute.ReductionOp.MAX, init_val, 0)
+         res = cute.make_fragment(x.shape, Float32)
+         res.store(x)
+         # local_max = [res[0], res[1]]
+         # for i in cutlass.range_constexpr(2, cute.size(x.shape), 2):
+         #     local_max[0] = fmax(local_max[0], res[i + 0])
+         #     local_max[1] = fmax(local_max[1], res[i + 1])
+         # local_max[0] = fmax(local_max[0], local_max[1])
+         # return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
+         local_max = [res[0], res[1], res[2], res[3]]
+         for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
+             local_max[0] = fmax(local_max[0], res[i + 0])
+             local_max[1] = fmax(local_max[1], res[i + 1])
+             local_max[2] = fmax(local_max[2], res[i + 2])
+             local_max[3] = fmax(local_max[3], res[i + 3])
+         local_max[0] = fmax(local_max[0], local_max[1])
+         local_max[2] = fmax(local_max[2], local_max[3])
+         local_max[0] = fmax(local_max[0], local_max[2])
+         return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
+     else:
+         # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max.
+         # We instead force the 3-input max.
+         res = cute.make_fragment(x.shape, Float32)
+         res.store(x)
+         local_max_0 = (
+             fmax(init_val, res[0], res[1])
+             if const_expr(init_val is not None)
+             else fmax(res[0], res[1])
+         )
+         local_max = [
+             local_max_0,
+             fmax(res[2], res[3]),
+             fmax(res[4], res[5]),
+             fmax(res[6], res[7]),
+         ]
+         for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
+             local_max[0] = fmax(local_max[0], res[i], res[i + 1])
+             local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3])
+             local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5])
+             local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7])
+         local_max[0] = fmax(local_max[0], local_max[1])
+         return fmax(local_max[0], local_max[2], local_max[3])
+
+
+ @cute.jit
+ def fadd_reduce(
+     x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
+ ) -> Float32:
+     if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
+         if const_expr(init_val is None):
+             init_val = Float32.zero
+         return x.reduce(cute.ReductionOp.ADD, init_val, 0)
+         # res = cute.make_fragment(x.shape, Float32)
+         # res.store(x)
+         # local_sum = [res[0], res[1], res[2], res[3]]
+         # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
+         #     local_sum[0] += res[i + 0]
+         #     local_sum[1] += res[i + 1]
+         #     local_sum[2] += res[i + 2]
+         #     local_sum[3] += res[i + 3]
+         # local_sum[0] += local_sum[1]
+         # local_sum[2] += local_sum[3]
+         # local_sum[0] += local_sum[2]
+         # return local_sum[0] if const_expr(init_val is None) else local_sum[0] + init_val
+     else:
+         res = cute.make_fragment(x.shape, Float32)
+         res.store(x)
+         local_sum_0 = (
+             add_packed_f32x2((init_val, 0.0), (res[0], res[1]))
+             # add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1]))
+             if const_expr(init_val is not None)
+             else (res[0], res[1])
+         )
+         local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])]
+         for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
+             local_sum[0] = add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1]))
+             local_sum[1] = add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3]))
+             local_sum[2] = add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5]))
+             local_sum[3] = add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7]))
+         local_sum[0] = add_packed_f32x2(local_sum[0], local_sum[1])
+         local_sum[2] = add_packed_f32x2(local_sum[2], local_sum[3])
+         local_sum[0] = add_packed_f32x2(local_sum[0], local_sum[2])
+         return local_sum[0][0] + local_sum[0][1]
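+
+ # Worked example (annotation, not part of the packaged file): for an 8-element
+ # fragment [1, 2, 3, 4, 5, 6, 7, 8] the packed path keeps four 2-wide partial
+ # sums (1,2), (3,4), (5,6), (7,8), folds them pairwise into
+ # (1+3+5+7, 2+4+6+8) = (16, 20), and returns 16 + 20 = 36.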
+
+
+ @dsl_user_op
+ def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
+     # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
+     # # cache_hint = cutlass.Int64(0x12F0000000000000)
+     # llvm.inline_asm(
+     #     None,
+     #     [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip)],
+     #     # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()],
+     #     "red.global.add.f32 [$0], $1;",
+     #     # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;",
+     #     # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;",
+     #     "l,f",
+     #     # "l,f,l",
+     #     has_side_effects=True,
+     #     is_align_stack=False,
+     #     asm_dialect=llvm.AsmDialect.AD_ATT,
+     # )
+     nvvm.atomicrmw(
+         res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value()
+     )
+
+
+ @dsl_user_op
+ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
+     return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
+
+
+ @dsl_user_op
+ def elem_pointer_i64(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
+     flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
+     flat_stride = cute.flatten_to_tuple(x.stride)
+     assert len(flat_coord_i64) == len(flat_stride), (
+         "Coordinate and stride must have the same length"
+     )
+     offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
+     # HACK: we assume that applying the offset does not change the pointer alignment
+     byte_offset = offset * x.element_type.width // 8
+     return cute.make_ptr(
+         x.element_type,
+         x.iterator.toint() + byte_offset,
+         x.memspace,
+         assumed_align=x.iterator.alignment,
+     )
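+
+ # Note (annotation, not part of the packaged file): the _i64 variant widens the
+ # index arithmetic to Int64 before multiplying by the strides, so very large
+ # tensors (byte offsets past 2**31) don't overflow 32-bit index math the way a
+ # plain crd2idx-based offset could.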
+
+
+ @cute.jit
+ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
+     # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
+     tApA = cute.make_fragment(
+         cute.make_layout(
+             (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
+             stride=(cute.size(tAcA, mode=[2]), 0, 1),
+         ),
+         cutlass.Boolean,
+     )
+     for rest_v in cutlass.range_constexpr(tApA.shape[0]):
+         for rest_k in cutlass.range_constexpr(tApA.shape[2]):
+             tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
+     return tApA
+
+
+ def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
+     warp_group_idx = cute.arch.thread_idx()[0] // 128
+     if const_expr(sync):
+         warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx)
+     return warp_group_idx
+
+
+ # @dsl_user_op
+ # def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean:
+ #     mask = cutlass.Int32(-1)
+ #     return cutlass.Boolean(
+ #         llvm.inline_asm(
+ #             T.i32(),
+ #             [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)],
+ #             ".pred p1, p2;\n"
+ #             "setp.lt.f32 p1, $1, $2;\n"
+ #             "vote.sync.any.pred p2, p1, $3;\n"
+ #             "selp.u32 $0, 1, 0, p2;",
+ #             # "selp.u32 $0, 1, 0, p1;",
+ #             "=r,f,f,r",
+ #             has_side_effects=False,
+ #             is_align_stack=False,
+ #             asm_dialect=llvm.AsmDialect.AD_ATT,
+ #         )
+ #     )
+
+
+ @cute.jit
+ def shuffle_sync(
+     value: cute.Numeric,
+     offset: cute.typing.Int,
+     width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
+ ) -> cute.Numeric:
+     assert value.width % 32 == 0, "value type must be a multiple of 32 bits"
+     # width=1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+     mask = cute.arch.WARP_SIZE - width
+     clamp = cute.arch.WARP_SIZE - 1
+     mask_and_clamp = mask << 8 | clamp
+     # important: need stride 1 and not 0 for recast_tensor to work
+     val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))
+     val[0] = value
+     val_i32 = cute.recast_tensor(val, cutlass.Int32)
+     for i in cutlass.range_constexpr(cute.size(val_i32)):
+         val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp)
+     return val[0]
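+
+ # Worked example (annotation, not part of the packaged file): width=16 gives
+ # mask = 32 - 16 = 16 and clamp = 31, so mask_and_clamp = (16 << 8) | 31 =
+ # 0x101f, i.e. shuffles stay within each 16-lane segment. A 64-bit value is
+ # recast to two Int32 words and shuffled word by word.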
+
+
+ @dsl_user_op
+ def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
+     return cutlass.Uint32(
+         llvm.inline_asm(
+             T.i32(),
+             [
+                 cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
+                 cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
+             ],
+             "shr.s32 $0, $1, $2;",
+             "=r,r,r",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ @cute.jit
+ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
+     if const_expr(lane is None):
+         lane = cute.arch.lane_idx()
+     # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 32, val)
+     for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
+         offset = 1 << i
+         # Very important that we set mask_and_clamp to 0
+         partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
+         if lane >= offset:
+             val += partial_sum
+         # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val)
+     return val
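+
+ # Worked example (annotation, not part of the packaged file): this is the
+ # classic Hillis-Steele inclusive scan. If every lane starts with val = 1,
+ # lane i ends with i + 1 after the log2(32) = 5 shuffle-up rounds.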
+
+
+ @dsl_user_op
+ def cvt_f16x2_f32(
+     a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
+ ) -> cutlass.Int32:
+     assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
+     # cvt.rn.f16x2.f32 d, a, b packs convert(a) into the upper half of d, so the
+     # sources are passed as ($2, $1) to put `a` in the low 16 bits and `b` in the high.
+     return cutlass.Int32(
+         llvm.inline_asm(
+             T.i32(),
+             [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)],
+             f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;",
+             "=r,f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ @overload
+ def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...
+
+
+ @overload
+ def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...
+
+
+ @cute.jit
+ def cvt_f16(src: cute.Tensor, dst_or_dtype):
+     """Convert a Float32 tensor to Float16/BFloat16.
+
+     Args:
+         src: Source tensor with Float32 element type
+         dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16)
+
+     Returns:
+         None if dst is a tensor, or a new tensor if a dtype is provided
+     """
+     if const_expr(isinstance(dst_or_dtype, type)):
+         # dtype variant: create a new tensor and call the tensor variant
+         dtype = dst_or_dtype
+         dst = cute.make_fragment(src.shape, dtype)
+         cvt_f16(src, dst)
+         return dst
+     else:
+         # tensor variant: write to dst
+         dst = dst_or_dtype
+         assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
+         assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
+         assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
+             "dst must be BFloat16 or Float16"
+         )
+         assert src.element_type is Float32, "src must be Float32"
+         dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
+         assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
+         for i in cutlass.range_constexpr(cute.size(dst_i32)):
+             dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)
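+
+ # Usage sketch (annotation, not part of the packaged file): converting an fp32
+ # accumulator fragment before the epilogue store, two elements per cvt:
+ #
+ #     acc_bf16 = cvt_f16(acc, cutlass.BFloat16)   # new fragment
+ #     cvt_f16(acc, out_frag)                      # or write into an existing one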
+
+
+ @dsl_user_op
+ @cute.jit
+ def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:
+     deg = len(poly) - 1
+     out = poly[deg]
+     for i in cutlass.range_constexpr(deg - 1, -1, -1):
+         out = out * x + poly[i]
+     return out
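+
+ # Worked example (annotation, not part of the packaged file): this is Horner's
+ # rule with coefficients in ascending order, so poly = (c0, c1, c2) evaluates
+ # c0 + c1*x + c2*x**2; e.g. (1, 2, 3) at x = 2 gives (3*2 + 2)*2 + 1 = 17.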
+
+
+ @dsl_user_op
+ @cute.jit
+ def evaluate_polynomial_2(
+     x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
+ ) -> Tuple[Float32, Float32]:
+     deg = len(poly) - 1
+     out = (poly[deg], poly[deg])
+     for i in cutlass.range_constexpr(deg - 1, -1, -1):
+         out = fma_packed_f32x2(out, (x, y), (poly[i], poly[i]))
+     return out
+
+
+ @dsl_user_op
+ def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:
+     # There's probably a way to call llvm or nvvm to do this instead of ptx
+     return cutlass.Float32(
+         llvm.inline_asm(
+             T.f32(),
+             [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
+             "add.rm.ftz.f32 $0, $1, $2;",
+             "=f,f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ @dsl_user_op
+ def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:
+     return cutlass.Float32(
+         llvm.inline_asm(
+             T.f32(),
+             [
+                 Float32(x_rounded).ir_value(loc=loc, ip=ip),
+                 Float32(frac_ex2).ir_value(loc=loc, ip=ip),
+             ],
+             "{\n\t"
+             ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
+             "mov.b32 x_rounded_i, $1;\n\t"
+             "mov.b32 frac_ex_i, $2;\n\t"
+             "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t"
+             # add.u32 generates IMAD instruction and add.s32 generates LEA instruction
+             # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik
+             "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t"
+             "mov.b32 $0, out_i;\n\t"
+             "}\n",
+             "=f,f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ @dsl_user_op
+ def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32:
+     # We assume x <= 127.0
+     poly_ex2_deg3 = (
+         1.0,
+         0.695146143436431884765625,
+         0.227564394474029541015625,
+         0.077119089663028717041015625,
+     )
+     fp32_round_int = float(2**23 + 2**22)
+     x_clamped = cute.arch.fmax(x, -127.0)
+     # We want to round down here, so that the fractional part is in [0, 1)
+     x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip)
+     # The integer floor of x is now in the last 8 bits of x_rounded
+     # We assume the next 2 ops round to nearest even. The rounding mode is important.
+     x_rounded_back = x_rounded - fp32_round_int
+     x_frac = x_clamped - x_rounded_back
+     x_frac_ex2 = evaluate_polynomial(x_frac, poly_ex2_deg3, loc=loc, ip=ip)
+     return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip)
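+
+ # Walkthrough (annotation, not part of the packaged file), e.g. x = 3.7:
+ # adding 2**23 + 2**22 forces the float to a magnitude where the integer part
+ # lands in the low mantissa bits, and round-down leaves floor(x) = 3 there;
+ # subtracting the constant back yields 3.0, so x_frac = 0.7 in [0, 1). The
+ # cubic approximates 2**0.7, and combine_int_frac_ex2 shifts the stored floor
+ # into the exponent field, scaling the result by 2**3 to get ~2**3.7.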
+
+
+ # TODO: check that ex2_emulation_2 produces the same SASS as the ptx version
+ @dsl_user_op
+ def ex2_emulation_2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+     # We assume x <= 127.0 and y <= 127.0
+     poly_ex2_deg3 = (
+         1.0,
+         0.695146143436431884765625,
+         0.227564394474029541015625,
+         0.077119089663028717041015625,
+     )
+     fp32_round_int = float(2**23 + 2**22)
+     xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
+     # We want to round down here, so that the fractional part is in [0, 1)
+     xy_rounded = cute.arch.add_packed_f32x2(
+         xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM
+     )
+     # The integer floors of x & y are now in the last 8 bits of xy_rounded
+     # We want the next 2 ops to round to nearest even. The rounding mode is important.
+     xy_rounded_back = sub_packed_f32x2(xy_rounded, (fp32_round_int, fp32_round_int))
+     xy_frac = sub_packed_f32x2(xy_clamped, xy_rounded_back)
+     xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, poly_ex2_deg3, loc=loc, ip=ip)
+     x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip)
+     y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip)
+     return x_out, y_out
+
+
+ @dsl_user_op
+ def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+     out_f32x2 = llvm.inline_asm(
+         llvm.StructType.get_literal([T.f32(), T.f32()]),
+         [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
+         "{\n\t"
+         ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
+         ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
+         ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
+         "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t"
+         "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t"
+         "mov.b64 l1, {f1, f2};\n\t"
+         "mov.f32 f3, 0f4B400000;\n\t"
+         "mov.b64 l2, {f3, f3};\n\t"
+         "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
+         "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
+         "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
+         "mov.f32 f7, 0f3D9DF09D;\n\t"
+         "mov.b64 l6, {f7, f7};\n\t"
+         "mov.f32 f6, 0f3E6906A4;\n\t"
+         "mov.b64 l5, {f6, f6};\n\t"
+         "mov.f32 f5, 0f3F31F519;\n\t"
+         "mov.b64 l4, {f5, f5};\n\t"
+         "mov.f32 f4, 0f3F800000;\n\t"
+         "mov.b64 l3, {f4, f4};\n\t"
+         "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
+         "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
+         "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
+         "mov.b64 {r1, r2}, l7;\n\t"
+         "mov.b64 {r3, r4}, l10;\n\t"
+         "shl.b32 r5, r1, 23;\n\t"
+         "add.s32 r7, r5, r3;\n\t"
+         "shl.b32 r6, r2, 23;\n\t"
+         "add.s32 r8, r6, r4;\n\t"
+         "mov.b32 $0, r7;\n\t"
+         "mov.b32 $1, r8;\n\t"
+         "}\n",
+         "=r,=r,f,f",
+         has_side_effects=False,
+         is_align_stack=False,
+         asm_dialect=llvm.AsmDialect.AD_ATT,
+     )
+     out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
+     out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
+     return out0, out1
+
+
+ @dsl_user_op
+ def domain_offset_aligned(
+     coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
+ ) -> cute.Tensor:
+     assert isinstance(tensor.iterator, cute.Pointer)
+     # We assume that applying the offset does not change the pointer alignment
+     new_ptr = cute.make_ptr(
+         tensor.element_type,
+         elem_pointer(tensor, coord).toint(),
+         tensor.memspace,
+         assumed_align=tensor.iterator.alignment,
+     )
+     return cute.make_tensor(new_ptr, tensor.layout)
+
+
+ @dsl_user_op
+ def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+     flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
+     flat_stride = cute.flatten_to_tuple(tensor.stride)
+     assert len(flat_coord_i64) == len(flat_stride), (
+         "Coordinate and stride must have the same length"
+     )
+     offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
+     assert isinstance(tensor.iterator, cute.Pointer)
+     # HACK: we assume that applying the offset does not change the pointer alignment
+     new_ptr = cute.make_ptr(
+         tensor.element_type,
+         tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+         tensor.memspace,
+         assumed_align=tensor.iterator.max_alignment,
+     )
+     return cute.make_tensor(new_ptr, tensor.layout)
+
+
+ @dsl_user_op
+ def coord_offset_i64(
+     tensor: cute.Tensor, idx: cute.typing.Int, dim: int, *, loc=None, ip=None
+ ) -> cute.Tensor:
+     offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
+     assert isinstance(tensor.iterator, cute.Pointer)
+     # HACK: we assume that applying the offset does not change the pointer alignment
+     new_ptr = cute.make_ptr(
+         tensor.element_type,
+         tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+         tensor.memspace,
+         assumed_align=tensor.iterator.max_alignment,
+     )
+     new_layout = cute.slice_(
+         tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1))
+     )
+     return cute.make_tensor(new_ptr, new_layout)
+
+
+ @cute.jit
+ def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
+     """Convert a scalar to a cute TensorSSA of shape (1,) and the given dtype."""
+     vec = cute.make_fragment(1, dtype)
+     vec[0] = a
+     return vec.load()
+
+
+ def ssa_to_scalar(val):
+     """Could be inlined, but kept to mirror the scalar_to_ssa API above."""
+     return val[0]
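+
+ # Round-trip sketch (annotation, not part of the packaged file; `x` is any
+ # scalar cute.Numeric):
+ #
+ #     ssa = scalar_to_ssa(x, Float32)   # TensorSSA of shape (1,)
+ #     assert ssa_to_scalar(ssa) == x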