mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,754 @@
1
+ # @nolint # fbcode
2
+ # Copyright (c) 2025, Tri Dao.
3
+ from typing import Optional, Tuple
4
+
5
+ import cutlass
6
+ import cutlass.cute as cute
7
+ from cutlass import Int32, Boolean, const_expr
8
+ from cutlass.cute.nvgpu import tcgen05
9
+ from cutlass._mlir.dialects import llvm
10
+
11
+ import mslk.attention.flash_attn.mma_sm100_desc as sm100_desc
12
+ from mslk.attention.flash_attn.utils import parse_swizzle_from_pointer
13
+
14
+
15
@cute.jit
def gemm_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    swap_AB: bool = False,
) -> None:
    """Run the MMA over the K-block mode of tCrA/tCrB, optionally picking one
    stage from a trailing (4th) mode of each operand.

    Args:
        tiled_mma: TiledMma whose op is rebuilt into a fresh atom below.
        acc: accumulator tensor, both input and output of each cute.gemm call.
        tCrA, tCrB: MMA-partitioned operand fragments; mode 2 is iterated as
            the K-block loop. When A_idx/B_idx is given, the operand is sliced
            at that index along a 4th mode (presumably a pipeline stage —
            TODO confirm against callers).
        zero_init: if truthy, the first K-block does not accumulate, so acc is
            overwritten rather than added to.
        swap_AB: if set (compile-time), recurse once with A/B operands and
            indices exchanged.
    """
    if const_expr(swap_AB):
        # Swap is resolved at trace time by a single self-call with the
        # operands (and their stage indices) exchanged.
        return gemm_w_idx(
            tiled_mma, acc, tCrB, tCrA, B_idx, A_idx, zero_init=zero_init, swap_AB=False
        )
    else:
        # Select the current stage only when an index was provided; the
        # is-None checks are compile-time (const_expr) branches.
        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
        # Fresh atom so the ACCUMULATE field toggle below does not mutate the
        # caller's tiled_mma.
        mma_atom = cute.make_mma_atom(tiled_mma.op)
        for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
            # Accumulate on every step except the first when zero_init is set.
            mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)
            cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc)
37
+
38
+
39
@cute.jit
def gemm_ptx_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    **kwargs,
) -> None:
    """Stage-selecting wrapper over gemm_ptx_partial (the inline-PTX MMA path).

    Slices tCrA/tCrB (and the matching smem tensors sA/sB) at A_idx/B_idx
    along their 4th mode when an index is given, then forwards everything to
    gemm_ptx_partial together with the accumulator's raw tmem address.

    Args:
        sA: smem view of A; may be None (gemm_ptx_partial asserts it is
            present when the op does not source A from TMEM).
        sB: smem view of B.
        **kwargs: passed through to gemm_ptx_partial (e.g. mbar_ptr/mbar_phase).
    """
    rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
    rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
    sA_cur = None
    if const_expr(sA is not None):
        sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx]
    sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx]
    mma_atom = cute.make_mma_atom(tiled_mma.op)
    # gemm_ptx_partial takes the accumulator as a raw integer tmem address,
    # not as a tensor.
    acc_tmem_addr = acc.iterator.toint()
    gemm_ptx_partial(
        mma_atom.op, acc_tmem_addr, rA, rB, sA_cur, sB_cur, zero_init=zero_init, **kwargs
    )
63
+
64
+
65
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> cute.TiledMma:
    """Issue one cute.gemm per K-block (mode 2 of tCrA) into acc.

    Unlike gemm_w_idx, this toggles the ACCUMULATE field on the caller's
    tiled_mma directly (mutating it) and returns it so the caller can reuse
    the updated state. With zero_init truthy, only the k == 0 step clears
    the accumulator; all later steps accumulate.
    """
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        tiled_mma.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)
        cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
    return tiled_mma
77
+
78
+
79
def i64_to_i32x2(i: int) -> Tuple[int, int]:
    """Split a 64-bit integer into its (low, high) 32-bit halves."""
    hi, lo = divmod(i, 1 << 32)
    return lo, hi & 0xFFFF_FFFF
82
+
83
+
84
@cute.jit
def gemm_ptx(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Emit tcgen05.mma via hand-written inline PTX, one asm block per K-block.

    Builds the instruction descriptor (idesc) and the smem matrix descriptors
    at trace time (const_expr), then for each K-block advances the descriptor
    low word by the byte offset of that block (>> 4, descriptor address
    granularity) and issues one `tcgen05.mma.cta_group::1.kind::f16` from a
    single elected thread.

    Two operand-sourcing modes, selected by op.a_src at compile time:
      * SMEM A (not is_ts): A and B both via smem descriptors; sA required.
      * TMEM A (is_ts): A is addressed directly by its tmem address ([$1]).

    Args:
        acc: accumulator; its raw tmem address is passed as asm operand $0.
        zero_init: when truthy, the k == 0 issue has its accumulate predicate
            cleared so acc is overwritten on the first K-block.
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else None
    sB_layout = sB.layout
    # Instruction descriptor is a pure function of the MMA op; fold to a
    # compile-time constant.
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = parse_swizzle_from_pointer(sA.iterator)
        # Descriptor layout is computed on a 128-bit recast of the smem
        # layout's first mode; major-ness follows the op's A major mode.
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        # High word is fully static; only the low (address) word varies.
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = parse_swizzle_from_pointer(sB.iterator)
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # Runtime low words: static base OR'd with the encoded start address of
    # K-block 0; per-block offsets are added in the loop below.
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr(
            sA[None, None, 0].iterator
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr(
        sB[None, None, 0].iterator
    )
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        if const_expr(not is_ts):
            # Byte offset of K-block k, scaled to descriptor granularity (16B).
            smem_desc_a_lo = smem_desc_start_a_lo + (
                (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
            )
        smem_desc_b_lo = smem_desc_start_b_lo + (
            (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
        )
        # with cute.arch.elect_one():
        # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo)
        # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct)
        # MMA is issued by exactly one elected thread.
        with cute.arch.elect_one():
            if const_expr(not is_ts):
                # SS path: both descriptors assembled in registers from the
                # runtime low word ($1/$2) and the static high word; $3 drives
                # the accumulate predicate p.
                llvm.inline_asm(
                    None,
                    [
                        acc.iterator.toint().ir_value(),
                        smem_desc_a_lo.ir_value(),
                        smem_desc_b_lo.ir_value(),
                        Int32(not zero_init or k != 0).ir_value(),
                    ],
                    "{\n\t"
                    ".reg .pred p;\n\t"
                    ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
                    ".reg .b32 idesc;\n\t"
                    f"mov.b32 idesc, {hex(idesc)};\n\t"
                    f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t"
                    f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
                    "setp.ne.b32 p, $3, 0;\n\t"
                    f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t"
                    "}\n",
                    "r,r,r,r",
                    has_side_effects=True,
                    is_align_stack=False,
                    asm_dialect=llvm.AsmDialect.AD_ATT,
                )
            else:
                # TS path: A is addressed directly by its tmem address ([$1]);
                # only B needs an smem descriptor.
                llvm.inline_asm(
                    None,
                    [
                        acc.iterator.toint().ir_value(),
                        tCrA[None, None, k].iterator.toint().ir_value(),
                        smem_desc_b_lo.ir_value(),
                        Int32(not zero_init or k != 0).ir_value(),
                    ],
                    "{\n\t"
                    ".reg .pred p;\n\t"
                    ".reg .b64 smem_desc_b;\n\t"
                    f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
                    "setp.ne.b32 p, $3, 0;\n\t"
                    f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t"
                    "}\n",
                    "r,r,r,r",
                    has_side_effects=True,
                    is_align_stack=False,
                    asm_dialect=llvm.AsmDialect.AD_ATT,
                )
197
+
198
+
199
@cute.jit
def gemm_ptx_loop(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Like gemm_ptx, but emits ONE inline-asm block whose body is unrolled
    over all K-blocks at trace time.

    The per-block descriptor deltas (offset_a_diff / offset_b_diff) are
    compile-time constants baked into the PTX as `add.u32` immediates, and
    thread election is done inside the asm with `elect.sync` + a predicate
    guard on each MMA, so the Python-level loop and elect_one of gemm_ptx
    disappear.

    The first MMA's accumulate flag is the immediate pred_str when zero_init
    is a Python bool, or predicate p (driven by operand $3) when it is a
    runtime Boolean; every subsequent MMA accumulates unconditionally (1).
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else tCrA.layout
    sB_layout = sB.layout
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = parse_swizzle_from_pointer(sA.iterator)
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        # Only the low (address) word of the descriptor varies per K-block;
        # the high word is a static immediate.
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = parse_swizzle_from_pointer(sB.iterator)
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # Static per-K-block offsets. SMEM offsets are in 16-byte descriptor
    # units (>> 4); TMEM offsets are in 32-bit words (// 32).
    if const_expr(not is_ts):
        offset_a = [
            (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
        ]
    else:
        offset_a = [
            cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
        ]
    offset_a_diff = [
        offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
    ]
    offset_b = [
        (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
        for k in cutlass.range_constexpr(cute.size(tCrB.shape[2]))
    ]
    offset_b_diff = [
        offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2]))
    ]

    # Runtime descriptor low words for K-block 0 (static base | encoded addr).
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(
            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(
        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
    )
    # Accumulate flag for the first MMA: runtime predicate "p" if zero_init is
    # a dynamic Boolean, otherwise the constant 0 (clear) / 1 (accumulate).
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # SS path: descriptors for both A and B, advanced by static deltas.
        llvm.inline_asm(
            None,
            [
                acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 smem_desc_a_lo, $1;\n\t"
            "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # TS path: A from TMEM via [tmem_a + static offset]; only B uses a
        # descriptor.
        llvm.inline_asm(
            None,
            [
                acc.iterator.toint().ir_value(),
                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
                Int32(smem_desc_start_b_lo).ir_value(),
                Int32(not zero_init).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 tmem_a, $1;\n\t"
            "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
360
+
361
+
362
@cute.jit
def gemm_ptx_partial(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc_tmem_addr: Int32,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    mbar_ptr: Optional[cutlass.Pointer] = None,
    mbar_phase: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    # sA_offset: Int32 = 0,
    # acc_offset: Int32 = 0,
    tA_addr: Optional[Int32] = None,
) -> None:
    """Single unrolled inline-asm tcgen05.mma sequence, taking the accumulator
    as a raw tmem address and optionally splicing an mbarrier wait into the
    unrolled K-block chain.

    TMEM-A only: when mbar_ptr is given (with mbar_phase), the first 3/4 of
    the K-blocks are issued, then the asm spins on
    `mbarrier.try_wait.parity.shared::cta.b64` against operands $4/$5, then
    the remaining quarter is issued — presumably overlapping the wait for a
    producer with most of the MMA work (TODO confirm against callers).

    Args:
        acc_tmem_addr: accumulator tmem address (asm register operand, not a
            tensor like in gemm_ptx/gemm_ptx_loop).
        tA_addr: explicit TMEM address of A; see the workaround note in the
            TS branch below.
        zero_init: first-MMA accumulate flag; encoded as a predicate when it
            is a runtime Boolean, or as a constant 0/1 immediate otherwise.
    """
    # acc_tmem_addr += acc_offset
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else tCrA.layout
    sB_layout = sB.layout
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = parse_swizzle_from_pointer(sA.iterator)
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        # High descriptor word is static; only the low word carries addresses.
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = parse_swizzle_from_pointer(sB.iterator)
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # Per-K-block offsets come from the fragment layouts here (unlike
    # gemm_ptx_loop, which uses the smem layouts directly); the TS variant
    # recasts A's layout to 32-bit words for TMEM addressing.
    tCrA_layout = (
        tCrA.layout
        if const_expr(not is_ts)
        else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout)
    )
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))]
    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]
    offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))]
    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]

    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(
            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
        )
        # ) + sA_offset
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(
        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
    )
    # First-MMA accumulate flag: runtime predicate vs. constant immediate.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # SS path (no mbarrier support): both operands via smem descriptors,
        # per-block low words computed from the block-0 start + static offset.
        assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM"
        llvm.inline_asm(
            None,
            [
                # acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 smem_desc_a_lo_start, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            # "r,r,r",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to
        # explicitly pass in the tA_addr for correctness.
        tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr
        input_args = [
            # Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(),
            Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(),
            Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
            Int32(not zero_init).ir_value(),
            Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
        ]
        if const_expr(mbar_ptr is not None):
            assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None"
            # mbar pointer and phase ride along as asm operands $4 and $5.
            input_args.append(mbar_ptr.toint().ir_value())
            input_args.append(Int32(mbar_phase).ir_value())
            # Spin loop: try_wait with a timeout hint, branch back until the
            # parity phase is observed.
            mbar_wait_str = (
                ".reg .pred P1; \n\t"
                "LAB_WAIT: \n\t"
                "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t"
                "@P1 bra DONE; \n\t"
                "bra LAB_WAIT; \n\t"
                "DONE: \n\t"
            )
        else:
            mbar_wait_str = ""
        llvm.inline_asm(
            None,
            # [
            # # acc.iterator.toint().ir_value(),
            # Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
            # Int32(smem_desc_start_b_lo).ir_value(),
            # Int32(not zero_init).ir_value(),
            # ],
            input_args,
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            f"mov.b32 tmem_a, $0;\n\t"
            f"mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            # Without an mbarrier: unroll all remaining K-blocks here. With
            # one: stop at 3/4 of the blocks, wait, then finish below.
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(
                    1,
                    cute.size(tCrA.shape[2])
                    if const_expr(mbar_ptr is None)
                    else cute.size(tCrA.shape[2]) // 4 * 3,
                )
            )
            + mbar_wait_str
            + (
                # Tail quarter of the K-blocks, issued only after the wait;
                # note this leg advances the low word incrementally by diffs.
                "".join(
                    (
                        f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                        f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                        f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                    )
                    for k in range(cute.size(tCrA.shape[2]) // 4 * 3, cute.size(tCrA.shape[2]))
                )
                if const_expr(mbar_ptr is not None)
                else ""
            )
            + "}\n",
            "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
576
+
577
+
578
+ @cute.jit
579
+ def gemm_ptx_partial1(
580
+ op: cute.nvgpu.tcgen05.mma.MmaOp,
581
+ acc_tmem_addr: cutlass.Constexpr[int],
582
+ tCrA: cute.Tensor,
583
+ tCrB: cute.Tensor,
584
+ sA_base_addr_for_desc: Int32,
585
+ sA_addr_offset_for_desc: cutlass.Constexpr[int],
586
+ sA_stage: Int32,
587
+ sB_base_addr_for_desc: Int32,
588
+ sB_addr_offset_for_desc: cutlass.Constexpr[int],
589
+ sB_stage: Int32,
590
+ sA_layout: Optional[cute.Layout],
591
+ sB_layout: Optional[cute.Layout],
592
+ sA_swizzle: Optional[cute.Swizzle],
593
+ sB_swizzle: cute.Swizzle,
594
+ zero_init: bool | Boolean = False,
595
+ ) -> None:
596
+ is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
597
+ if const_expr(not is_ts):
598
+ assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM"
599
+ assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM"
600
+ idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
601
+ if const_expr(not is_ts):
602
+ smem_desc_base_a: int = const_expr(
603
+ sm100_desc.make_smem_desc_base(
604
+ cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
605
+ sA_swizzle,
606
+ sm100_desc.Major.K
607
+ if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
608
+ else sm100_desc.Major.MN,
609
+ )
610
+ )
611
+ smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
612
+ smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
613
+ smem_desc_a_hi = const_expr(smem_desc_a_hi)
614
+ else:
615
+ smem_desc_base_a = None
616
+ smem_desc_base_a_lo, smem_desc_a_hi = None, None
617
+ smem_desc_base_b: int = const_expr(
618
+ sm100_desc.make_smem_desc_base(
619
+ cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
620
+ sB_swizzle,
621
+ sm100_desc.Major.K
622
+ if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
623
+ else sm100_desc.Major.MN,
624
+ )
625
+ )
626
+ smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
627
+ smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
628
+ smem_desc_b_hi = const_expr(smem_desc_b_hi)
629
+ mask = [Int32(0)] * 4
630
+
631
+ if const_expr(not is_ts):
632
+ offset_a = [
633
+ (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4
634
+ for k in range(cute.size(tCrA.shape[2]))
635
+ ]
636
+ else:
637
+ offset_a = [
638
+ cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
639
+ for k in range(cute.size(tCrA.shape[2]))
640
+ ]
641
+ offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]
642
+ offset_b = [
643
+ (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4
644
+ for k in range(cute.size(tCrB.shape[2]))
645
+ ]
646
+ offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]
647
+
648
+ if const_expr(not is_ts):
649
+ # smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator))
650
+ smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo)
651
+ else:
652
+ smem_desc_start_a_lo = None
653
+ # smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator))
654
+ smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo)
655
+ pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
656
+ if const_expr(not is_ts):
657
+ llvm.inline_asm(
658
+ None,
659
+ [
660
+ # acc.iterator.toint().ir_value(),
661
+ # Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
662
+ Int32(sA_base_addr_for_desc).ir_value(),
663
+ Int32(sA_stage).ir_value(),
664
+ # Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
665
+ Int32(sB_base_addr_for_desc).ir_value(),
666
+ Int32(sB_stage).ir_value(),
667
+ Int32(not zero_init).ir_value(),
668
+ mask[0].ir_value(),
669
+ mask[1].ir_value(),
670
+ mask[2].ir_value(),
671
+ mask[3].ir_value(),
672
+ ],
673
+ "{\n\t"
674
+ ".reg .pred leader_thread;\n\t"
675
+ ".reg .pred p;\n\t"
676
+ ".reg .b32 idesc;\n\t"
677
+ ".reg .b32 tmem_acc;\n\t"
678
+ ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
679
+ ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
680
+ ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
681
+ "elect.sync _|leader_thread, -1;\n\t"
682
+ f"mov.b32 idesc, {hex(idesc)};\n\t"
683
+ f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
684
+ # "mov.b32 smem_desc_a_lo, $0;\n\t"
685
+ # f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t"
686
+ f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t"
687
+ # "mov.b32 smem_desc_b_lo, $2;\n\t"
688
+ f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t"
689
+ f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
690
+ f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
691
+ f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
692
+ f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
693
+ "setp.ne.b32 p, $4, 0;\n\t"
694
+ f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t"
695
+ + "".join(
696
+ (
697
+ f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
698
+ f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
699
+ f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
700
+ f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
701
+ f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t"
702
+ )
703
+ for k in range(1, cute.size(tCrA.shape[2]))
704
+ )
705
+ + "}\n",
706
+ "r,r,r,r,r,r,r,r,r",
707
+ has_side_effects=True,
708
+ is_align_stack=False,
709
+ asm_dialect=llvm.AsmDialect.AD_ATT,
710
+ )
711
+ else:
712
+ llvm.inline_asm(
713
+ None,
714
+ [
715
+ # acc.iterator.toint().ir_value(),
716
+ Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
717
+ Int32(smem_desc_start_b_lo).ir_value(),
718
+ Int32(not zero_init).ir_value(),
719
+ mask[0].ir_value(),
720
+ mask[1].ir_value(),
721
+ mask[2].ir_value(),
722
+ mask[3].ir_value(),
723
+ ],
724
+ "{\n\t"
725
+ ".reg .pred leader_thread;\n\t"
726
+ ".reg .pred p;\n\t"
727
+ ".reg .b32 idesc;\n\t"
728
+ ".reg .b32 tmem_a;\n\t"
729
+ ".reg .b32 smem_desc_b_lo;\n\t"
730
+ ".reg .b32 smem_desc_b_hi;\n\t"
731
+ ".reg .b64 smem_desc_b;\n\t"
732
+ "elect.sync _|leader_thread, -1;\n\t"
733
+ f"mov.b32 idesc, {hex(idesc)};\n\t"
734
+ f"mov.b32 tmem_a, $1;\n\t"
735
+ f"mov.b32 smem_desc_b_lo, $2;\n\t"
736
+ f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
737
+ f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
738
+ "setp.ne.b32 p, $3, 0;\n\t"
739
+ f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t"
740
+ + "".join(
741
+ (
742
+ f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
743
+ f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
744
+ f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
745
+ f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t"
746
+ )
747
+ for k in range(1, cute.size(tCrA.shape[2]))
748
+ )
749
+ + "}\n",
750
+ "r,r,r,r,r,r,r,r",
751
+ has_side_effects=True,
752
+ is_align_stack=False,
753
+ asm_dialect=llvm.AsmDialect.AD_ATT,
754
+ )