humming-kernels 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. humming/__init__.py +2 -0
  2. humming/config/__init__.py +13 -0
  3. humming/config/base.py +129 -0
  4. humming/config/config.py +196 -0
  5. humming/config/enum.py +21 -0
  6. humming/config/mma.py +301 -0
  7. humming/csrc/launcher/elf.h +125 -0
  8. humming/csrc/launcher/launcher.cpp +262 -0
  9. humming/csrc/launcher/tensor.h +288 -0
  10. humming/csrc/launcher/tma.h +77 -0
  11. humming/csrc/launcher/torch_api.h +63 -0
  12. humming/csrc/launcher/utils.h +105 -0
  13. humming/dtypes.py +205 -0
  14. humming/include/humming/arith/epilogue_arith.cuh +182 -0
  15. humming/include/humming/arith/exp_offset.cuh +126 -0
  16. humming/include/humming/arith/mainloop_arith.cuh +511 -0
  17. humming/include/humming/datatype/base_conversion.cuh +125 -0
  18. humming/include/humming/datatype/dequant.cuh +110 -0
  19. humming/include/humming/datatype/dequant_fused.cuh +90 -0
  20. humming/include/humming/datatype/dequant_prepare.cuh +68 -0
  21. humming/include/humming/datatype/dequant_single.cuh +154 -0
  22. humming/include/humming/datatype/dtypes.cuh +70 -0
  23. humming/include/humming/epilogue/gmem_writer.cuh +171 -0
  24. humming/include/humming/epilogue/pipeline.cuh +112 -0
  25. humming/include/humming/epilogue/smem_reducer.cuh +92 -0
  26. humming/include/humming/epilogue/smem_writer.cuh +212 -0
  27. humming/include/humming/kernel/dequant_weight.cuh +47 -0
  28. humming/include/humming/kernel/humming.cuh +156 -0
  29. humming/include/humming/kernel/humming_ws.cuh +191 -0
  30. humming/include/humming/kernel/pack_weight.cuh +95 -0
  31. humming/include/humming/kernel/process.cuh +263 -0
  32. humming/include/humming/kernel/process_mxfp4.cuh +69 -0
  33. humming/include/humming/kernel/quant_weight.cuh +277 -0
  34. humming/include/humming/kernel/tops_bench.cuh +50 -0
  35. humming/include/humming/memory/g2s_loader/loader_a.cuh +183 -0
  36. humming/include/humming/memory/g2s_loader/loader_as.cuh +131 -0
  37. humming/include/humming/memory/g2s_loader/loader_b.cuh +83 -0
  38. humming/include/humming/memory/g2s_loader/loader_bias.cuh +57 -0
  39. humming/include/humming/memory/g2s_loader/loader_bs.cuh +114 -0
  40. humming/include/humming/memory/g2s_loader/loader_bzp.cuh +91 -0
  41. humming/include/humming/memory/g2s_pipeline.cuh +343 -0
  42. humming/include/humming/memory/s2r_loader/loader_a.cuh +61 -0
  43. humming/include/humming/memory/s2r_loader/loader_as.cuh +65 -0
  44. humming/include/humming/memory/s2r_loader/loader_b.cuh +57 -0
  45. humming/include/humming/memory/s2r_loader/loader_bias.cuh +51 -0
  46. humming/include/humming/memory/s2r_loader/loader_bs.cuh +104 -0
  47. humming/include/humming/memory/s2r_loader/loader_bzp.cuh +64 -0
  48. humming/include/humming/memory/s2r_pipeline.cuh +79 -0
  49. humming/include/humming/mma/wgmma.cuh +175 -0
  50. humming/include/humming/mma/wmma.cuh +124 -0
  51. humming/include/humming/scheduler.cuh +335 -0
  52. humming/include/humming/utils/all.cuh +13 -0
  53. humming/include/humming/utils/base.cuh +71 -0
  54. humming/include/humming/utils/enum.cuh +24 -0
  55. humming/include/humming/utils/ptx/barrier.cuh +122 -0
  56. humming/include/humming/utils/ptx/legacy_load.cuh +180 -0
  57. humming/include/humming/utils/ptx/math.cuh +20 -0
  58. humming/include/humming/utils/ptx/shared.cuh +45 -0
  59. humming/include/humming/utils/ptx/tma.cuh +139 -0
  60. humming/include/humming/utils/ptx/warp.cuh +17 -0
  61. humming/include/humming/utils/ptx/wgmma.cuh +24 -0
  62. humming/include/humming/utils/storage.cuh +163 -0
  63. humming/jit/__init__.py +3 -0
  64. humming/jit/compiler.py +278 -0
  65. humming/jit/runtime.py +136 -0
  66. humming/kernel/__init__.py +17 -0
  67. humming/kernel/dequant_weight.py +65 -0
  68. humming/kernel/humming.py +404 -0
  69. humming/kernel/pack_weight.py +42 -0
  70. humming/kernel/process_mxfp4.py +64 -0
  71. humming/kernel/quant_weight.py +81 -0
  72. humming/kernel/repack_weight.py +103 -0
  73. humming/kernel/tops_bench.py +86 -0
  74. humming/kernel/unpack_weight.py +43 -0
  75. humming/layer.py +821 -0
  76. humming/ops/__init__.py +144 -0
  77. humming/ops/bench.py +43 -0
  78. humming/ops/input.py +229 -0
  79. humming/ops/moe.py +188 -0
  80. humming/ops/utils.py +165 -0
  81. humming/ops/weight.py +212 -0
  82. humming/schema/__init__.py +43 -0
  83. humming/schema/awq.py +118 -0
  84. humming/schema/base.py +316 -0
  85. humming/schema/bitnet.py +108 -0
  86. humming/schema/compressed_tensors.py +330 -0
  87. humming/schema/fp8.py +144 -0
  88. humming/schema/gpt_oss_mxfp4.py +43 -0
  89. humming/schema/gptq.py +110 -0
  90. humming/schema/humming.py +276 -0
  91. humming/schema/modelopt.py +253 -0
  92. humming/schema/mxfp4.py +75 -0
  93. humming/tune/__init__.py +80 -0
  94. humming/tune/base.py +271 -0
  95. humming/tune/sm100.py +9 -0
  96. humming/tune/sm75.py +53 -0
  97. humming/tune/sm8x.py +110 -0
  98. humming/tune/sm90.py +167 -0
  99. humming/tune/sm90_h20.py +198 -0
  100. humming/utils/__init__.py +0 -0
  101. humming/utils/cuda.py +220 -0
  102. humming/utils/device.py +84 -0
  103. humming/utils/jit.py +190 -0
  104. humming/utils/smem.py +91 -0
  105. humming/utils/test.py +386 -0
  106. humming/utils/weight.py +302 -0
  107. humming_kernels-0.1.0.dist-info/METADATA +111 -0
  108. humming_kernels-0.1.0.dist-info/RECORD +111 -0
  109. humming_kernels-0.1.0.dist-info/WHEEL +5 -0
  110. humming_kernels-0.1.0.dist-info/licenses/LICENSE +202 -0
  111. humming_kernels-0.1.0.dist-info/top_level.txt +1 -0
humming/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ import humming.ops # noqa
2
+ import humming.dtypes # noqa
@@ -0,0 +1,13 @@
1
+ from humming.config.config import ComputeConfig, LayerConfig, TuningConfig
2
+ from humming.config.enum import GemmType, MmaType, WeightScaleType
3
+ from humming.config.mma import MmaOpClass
4
+
5
# Public API of the humming.config package: each name listed here is one of the
# names imported at the top of this module.
+ __all__ = [
6
+ "LayerConfig",
7
+ "ComputeConfig",
8
+ "TuningConfig",
9
+ "MmaType",
10
+ "WeightScaleType",
11
+ "GemmType",
12
+ "MmaOpClass",
13
+ ]
humming/config/base.py ADDED
@@ -0,0 +1,129 @@
1
+ import dataclasses
2
+ import json
3
+ import re
4
+ from enum import Enum
5
+ from typing import Any, ClassVar
6
+
7
+ from humming import dtypes
8
+
9
+
10
def name_to_google_cpp_const_style(name: str) -> str:
    """Convert an identifier into a Google-style C++ constant name.

    The input is trimmed and lower-cased, split on runs of underscores,
    spaces and other non-word characters, and every non-empty piece is
    capitalised and appended to a leading ``k`` (``"use_tma"`` becomes
    ``"kUseTma"``). An empty input yields an empty string.
    """
    if name:
        normalized = name.strip().lower()
        pieces = filter(None, re.split(r"[_ \W]+", normalized))
        return "k" + "".join(piece.capitalize() for piece in pieces)
    return ""
17
+
18
+
19
def name_value_to_google_cpp_const_style(name: str, value: Any, keep_name: bool = False) -> str:
    """Render ``name``/``value`` as a ``static constexpr auto`` C++ declaration.

    Args:
        name: Attribute name; converted with ``name_to_google_cpp_const_style``
            unless ``keep_name`` is true.
        value: Value to emit. Booleans become ``true``/``false``, floats get an
            ``f`` suffix, ints get a ``u`` suffix, and anything else is
            stringified with ``.`` replaced by ``::``.
        keep_name: When true, ``name`` is emitted verbatim.

    Returns:
        A single C++ declaration line terminated by ``;``.
    """
    const_name = name if keep_name else name_to_google_cpp_const_style(name)

    # bool has to be checked ahead of int because bool is an int subclass.
    if isinstance(value, bool):
        literal = ("false", "true")[value]
    elif isinstance(value, float):
        literal = str(value) + "f"
    elif isinstance(value, int):
        literal = str(value) + "u"
    else:
        literal = str(value).replace(".", "::")

    return f"static constexpr auto {const_name} = {literal};"
32
+
33
+
34
def name_value_to_extern_const_style(name: str, value: Any) -> str:
    """Render a bool/int as an ``extern "C" __constant__ uint32_t`` definition.

    The name is upper-cased and the value is emitted as an integer. Values
    that are neither ``bool`` nor ``int`` produce an empty string.
    """
    if not isinstance(value, (bool, int)):
        return ""
    return f'extern "C" __constant__ uint32_t {name.upper()} = {int(value)};'
40
+
41
+
42
def name_value_to_macro_style(name: str, value: Any) -> str:
    """Render a bool/int as a ``#define HUMMING_<NAME> <int>`` line.

    The name is upper-cased and the value is emitted as an integer. Values
    that are neither ``bool`` nor ``int`` produce an empty string.
    """
    if not isinstance(value, (bool, int)):
        return ""
    return f"#define HUMMING_{name.upper()} {int(value)}"
48
+
49
+
50
+ @dataclasses.dataclass
51
+ class BaseHummingConfig:
52
+ _name_map: ClassVar[dict[str, str]] = {}
53
+ _cpp_extra_names: ClassVar[tuple[str, ...]] = ()
54
+
55
+ def __post_init__(self):
56
+ pass
57
+
58
+ def to_cpp_str(
59
+ self,
60
+ cls: type["BaseHummingConfig"] | None = None,
61
+ include_class_name: bool = False,
62
+ ) -> str:
63
+ cls = cls or self.__class__
64
+ str_list = []
65
+ names = [x.name for x in dataclasses.fields(cls)]
66
+ names += list(cls._cpp_extra_names)
67
+ for name in names:
68
+ value = getattr(self, name)
69
+ if not isinstance(value, (bool, int, Enum)):
70
+ continue
71
+ keep_name = name in cls._name_map
72
+ if keep_name:
73
+ name = cls._name_map[name]
74
+ line = name_value_to_google_cpp_const_style(name, value, keep_name)
75
+ str_list.append(line)
76
+
77
+ code = "\n".join(" " + x for x in str_list)
78
+ class_name = cls.__name__
79
+
80
+ if include_class_name:
81
+ code = f"class {class_name} {{\n{code}\n}};"
82
+
83
+ return code
84
+
85
+ def to_macro_cpp_str(self, cls: type["BaseHummingConfig"] | None = None) -> str:
86
+ cls = cls or self.__class__
87
+ str_list = []
88
+ names = [x.name for x in dataclasses.fields(cls)]
89
+ names += list(cls._cpp_extra_names)
90
+ for name in names:
91
+ value = getattr(self, name)
92
+ if not isinstance(value, (bool, int, Enum)):
93
+ continue
94
+ line = name_value_to_macro_style(name, value)
95
+ str_list.append(line)
96
+
97
+ str_list = [x for x in str_list if x]
98
+ code = "\n".join(x for x in str_list if x)
99
+
100
+ return code
101
+
102
+ def to_extern_cpp_str(self, cls: type["BaseHummingConfig"] | None = None) -> str:
103
+ cls = cls or self.__class__
104
+ str_list = []
105
+ names = [x.name for x in dataclasses.fields(cls)]
106
+ names += list(cls._cpp_extra_names)
107
+ for name in names:
108
+ value = getattr(self, name)
109
+ if not isinstance(value, (bool, int, Enum)):
110
+ continue
111
+ line = name_value_to_extern_const_style(name, value)
112
+ str_list.append(line)
113
+
114
+ str_list = [x for x in str_list if x]
115
+ code = "\n".join(x for x in str_list if x)
116
+
117
+ return code
118
+
119
+ def to_str(self) -> str:
120
+ res = {}
121
+ for field in dataclasses.fields(self):
122
+ value = getattr(self, field.name)
123
+ if isinstance(value, Enum):
124
+ value = value.value
125
+ elif isinstance(value, dtypes.DataType):
126
+ value = str(value)
127
+ res[field.name] = value
128
+
129
+ return json.dumps(res)
@@ -0,0 +1,196 @@
1
+ import dataclasses
2
+ import math
3
+ from typing import ClassVar
4
+
5
+ import torch
6
+
7
+ from humming import dtypes
8
+ from humming.config.base import BaseHummingConfig
9
+ from humming.config.enum import GemmType, MmaType, WeightScaleType
10
+
11
+
12
@dataclasses.dataclass(kw_only=True)
class LayerConfig(BaseHummingConfig):
    """Layer-level configuration: shapes, dtypes, quantization and MMA flavour.

    String values are accepted for ``weight_scale_type``, ``mma_type`` and
    the ``*_dtype`` fields; ``__post_init__`` converts them to the
    corresponding enum / ``dtypes.DataType`` objects and derives the boolean
    flags listed in ``_cpp_extra_names``.
    """

    # shape config
    shape_n: int
    shape_k: int
    pad_shape_n: int = 0
    pad_shape_k: int = 0
    num_experts: int = 0

    # datatype config
    b_dtype: dtypes.DataType
    a_dtype: dtypes.DataType
    c_dtype: dtypes.DataType
    bs_dtype: dtypes.DataType | None = None

    # quant param config
    input_scale_group_size: int = 0
    weight_scale_group_size: int = 0
    weight_scale_group_size_n: int = 0
    weight_scale_type: WeightScaleType | None = None
    use_int_weight_scale: bool = False
    use_fused_e8m0_scale: bool = False
    has_zero_point: bool = False
    is_fp_zero_point: bool = False

    # bias config
    has_bias: bool = False

    # mma config
    mma_type: MmaType | None = None

    _cpp_extra_names: ClassVar[tuple[str, ...]] = (
        "is_channel_weight_scale",
        "is_block_weight_scale",
        "is_group_weight_scale",
        "is_tensor_weight_scale",
        "has_input_scale",
    )

    def __post_init__(self):
        """Fill in defaults, normalise string inputs and derive helper flags."""
        self.problem_shape = (0, self.shape_n, self.shape_k)
        self.pad_shape = (0, self.pad_shape_n, self.pad_shape_k)

        # The weight-scale dtype falls back to the output dtype.
        if self.bs_dtype is None:
            self.bs_dtype = self.c_dtype

        # Infer the scale layout from the group sizes when not given.
        # A negative weight_scale_group_size with group_size_n <= 1 leaves
        # the type as None. The original repeated this inference a second
        # time below; that copy could never assign anything and was removed.
        if self.weight_scale_type is None:
            if self.weight_scale_group_size_n > 1:
                self.weight_scale_type = WeightScaleType.BLOCK
            elif self.weight_scale_group_size == 0:
                self.weight_scale_type = WeightScaleType.CHANNEL
            elif self.weight_scale_group_size > 0:
                self.weight_scale_type = WeightScaleType.GROUP

        if isinstance(self.weight_scale_type, str):
            self.weight_scale_type = WeightScaleType(self.weight_scale_type)

        # Without an explicit choice, pick WGMMA when the current CUDA
        # device reports compute-capability major version 9, otherwise MMA.
        if self.mma_type is None:
            sm_version = torch.cuda.get_device_capability()[0]
            self.mma_type = MmaType.WGMMA if sm_version == 9 else MmaType.MMA
        if isinstance(self.mma_type, str):
            self.mma_type = MmaType(self.mma_type)

        for name in ["a", "b", "c", "bs"]:
            value = getattr(self, f"{name}_dtype")
            if isinstance(value, str):
                value = dtypes.DataType.from_str(value)
            setattr(self, f"{name}_dtype", value)

        # Derived flags (emitted to C++ through _cpp_extra_names).
        self.has_input_scale = self.a_dtype.num_bits != 16
        self.is_channel_weight_scale = self.weight_scale_type == WeightScaleType.CHANNEL
        self.is_tensor_weight_scale = self.weight_scale_type in [
            WeightScaleType.TENSOR,
            WeightScaleType.GROUP_TENSOR,
        ]
        self.is_block_weight_scale = self.weight_scale_type == WeightScaleType.BLOCK
        self.is_group_weight_scale = self.weight_scale_type in [
            WeightScaleType.GROUP,
            WeightScaleType.GROUP_TENSOR,
        ]
99
+
100
+
101
@dataclasses.dataclass(kw_only=True)
class ComputeConfig(BaseHummingConfig):
    """Compute-level options: accumulation mode, batch invariance, GEMM kind.

    ``gemm_type`` may be given as a ``GemmType`` member or its string value;
    ``__post_init__`` normalises it and derives the ``is_*_gemm`` flags.
    """

    use_f16_accum: bool = False
    use_batch_invariant: bool = False
    gemm_type: GemmType | None = None

    _cpp_extra_names: ClassVar[tuple[str, ...]] = (
        "gemm_type_id",
        "is_indexed_gemm",
        "is_grouped_gemm",
        "is_grouped_contiguous_gemm",
        "is_grouped_masked_gemm",
    )

    def __post_init__(self):
        """Normalise ``gemm_type`` and compute the derived boolean flags."""
        kind = self.gemm_type
        if isinstance(kind, str):
            kind = GemmType(kind)
            self.gemm_type = kind

        self.is_indexed_gemm = kind == GemmType.INDEXED
        self.is_grouped_contiguous_gemm = kind == GemmType.GROUPED_CONTIGUOUS
        self.is_grouped_masked_gemm = kind == GemmType.GROUPED_MASKED
        self.is_grouped_gemm = self.is_grouped_contiguous_gemm or self.is_grouped_masked_gemm

    @property
    def gemm_type_id(self):
        """Integer id of ``gemm_type`` (dense=0, indexed=1, grouped_contiguous=2, grouped_masked=3).

        Asserts that ``gemm_type`` has been set.
        """
        assert self.gemm_type is not None
        ordering = ["dense", "indexed", "grouped_contiguous", "grouped_masked"]
        return ordering.index(self.gemm_type.value.lower())
128
+
129
+
130
@dataclasses.dataclass(kw_only=True)
class TuningConfig(BaseHummingConfig):
    """Kernel tuning knobs: tile shapes, pipelining and load-path switches.

    Options typed ``bool | None`` are resolved in ``__post_init__``; the
    thread counts listed in ``_cpp_extra_names`` are derived there as well.
    """

    block_shape: tuple[int, int, int]
    warp_shape: tuple[int, int, int]

    use_stream_k: bool = True

    num_stages: int = 2
    num_ctas_per_sm: int = 1

    use_warp_spec: bool | None = None
    use_mbarrier: bool | None = None
    use_cp_async: bool | None = None

    use_tma: bool | None = None
    use_tma_a: bool | None = None
    use_tma_b: bool | None = None
    use_tma_c: bool | None = None
    use_tma_bs: bool | None = None
    use_tma_bzp: bool | None = None
    use_tma_bias: bool | None = None

    num_write_splits: int = 1
    multi_cast_size_a: int = 1
    multi_cast_size_b: int = 1

    _cpp_extra_names: ClassVar[tuple[str, ...]] = (
        "num_threads",
        "num_math_threads",
        "num_load_threads",
    )

    _name_map = {
        "use_mbarrier": "kUseMBarrier",
        "use_tma_bs": "kUseTmaBS",
        "use_tma_bzp": "kUseTmaBZP",
    }

    def __post_init__(self):
        """Resolve unset switches and derive the thread counts."""
        if self.use_warp_spec is None:
            self.use_warp_spec = False
        if self.use_tma is None:
            self.use_tma = False

        # mbarrier defaults to on whenever TMA or warp specialisation is on.
        if self.use_mbarrier is None:
            self.use_mbarrier = self.use_tma or self.use_warp_spec

        # cp.async defaults to on for compute-capability major >= 8.
        if self.use_cp_async is None:
            capability = torch.cuda.get_device_capability()
            self.use_cp_async = capability[0] >= 8

        math_threads = math.prod(self.block_shape) // math.prod(self.warp_shape) * 32
        self.num_math_threads = math_threads
        if self.use_warp_spec:
            # Warp-specialised kernels add 128 dedicated load threads.
            self.num_load_threads = 128
            self.num_threads = math_threads + 128
        else:
            self.num_load_threads = math_threads
            self.num_threads = math_threads

        # Per-operand TMA switches: none may be True while use_tma is off,
        # and any left as None inherits the global use_tma setting.
        tma_flag_names = [attr for attr in dir(self) if attr.startswith("use_tma_")]
        for flag_name in tma_flag_names:
            if not self.use_tma:
                assert getattr(self, flag_name) is not True
            if getattr(self, flag_name) is None:
                setattr(self, flag_name, self.use_tma)
humming/config/enum.py ADDED
@@ -0,0 +1,21 @@
1
+ import enum
2
+
3
+
4
class MmaType(enum.Enum):
    """Matrix-multiply instruction family selected for a kernel."""

    MMA = "mma"
    WGMMA = "wgmma"
7
+
8
+
9
class WeightScaleType(enum.Enum):
    """Granularity at which weight scales are stored."""

    GROUP = "group"
    BLOCK = "block"
    CHANNEL = "channel"
    TENSOR = "tensor"
    GROUP_TENSOR = "group_tensor"
15
+
16
+
17
class GemmType(enum.Enum):
    """Kind of GEMM problem a kernel is compiled for."""

    DENSE = "dense"
    INDEXED = "indexed"
    GROUPED_CONTIGUOUS = "grouped_contiguous"
    GROUPED_MASKED = "grouped_masked"
humming/config/mma.py ADDED
@@ -0,0 +1,301 @@
1
+ import math
2
+ import re
3
+
4
+ import humming.dtypes as dtypes
5
+ from humming.config.enum import MmaType
6
+
7
# Bit width of each PTX dtype suffix used by the code generators below.
+ DTYPE_BIT_WIDTH_MAP = {
8
+ "f32": 32,
9
+ "s32": 32,
10
+ "f16": 16,
11
+ "bf16": 16,
12
+ "e4m3": 8,
13
+ "e5m2": 8,
14
+ "s8": 8,
15
+ "e2m1": 4,
16
+ "s4": 4,
17
+ }
18
+
19
# Maps humming.dtypes objects to the PTX dtype suffix strings that key
# DTYPE_BIT_WIDTH_MAP.
+ DTYPE_MAP = {
20
+ dtypes.float32: "f32",
21
+ dtypes.int32: "s32",
22
+ dtypes.float16: "f16",
23
+ dtypes.bfloat16: "bf16",
24
+ dtypes.float8e4m3: "e4m3",
25
+ dtypes.float8e5m2: "e5m2",
26
+ dtypes.int8: "s8",
27
+ dtypes.float4e2m1: "e2m1",
28
+ dtypes.int4: "s4",
29
+ }
30
+
31
+
32
def calc_reg_count(rows, cols, ptx_dtype):
    """Return how many 1024-bit units a ``rows x cols`` tile of ``ptx_dtype`` occupies.

    The tile's total bit count must be an exact multiple of ``32 * 32``
    (presumably 32 threads times one 32-bit register each; confirm against
    the callers); an ``AssertionError`` is raised otherwise.
    """
    unit_bits = 32 * 32
    tile_bits = rows * cols * DTYPE_BIT_WIDTH_MAP[ptx_dtype]
    reg_count, remainder = divmod(tile_bits, unit_bits)
    assert remainder == 0
    return reg_count
37
+
38
+
39
+ class MmaOpClassImpl:
# Generates C++ source text for an `MmaOpClass` whose `fma` wraps one PTX
# `mma.sync.aligned` instruction in inline asm.
40
# m, n, k: MMA tile shape. Each dtype argument is either a PTX suffix string
# (a key of DTYPE_BIT_WIDTH_MAP) or a key of DTYPE_MAP.
# Raises ValueError when cd_dtype is not one of f16 / bf16 / f32 / s32.
+ def __init__(self, m, n, k, a_dtype, b_dtype, cd_dtype):
41
+ self.shape = (m, n, k)
42
+ self.a_dtype = a_dtype if isinstance(a_dtype, str) else DTYPE_MAP[a_dtype]
43
+ self.b_dtype = b_dtype if isinstance(b_dtype, str) else DTYPE_MAP[b_dtype]
44
+ self.cd_dtype = cd_dtype if isinstance(cd_dtype, str) else DTYPE_MAP[cd_dtype]
45
+
46
# Number of uint32/float array elements declared for each operand.
+ self.reg_a_count = calc_reg_count(m, k, self.a_dtype)
47
+ self.reg_b_count = calc_reg_count(k, n, self.b_dtype)
48
+ self.reg_cd_count = calc_reg_count(m, n, self.cd_dtype)
49
# val_type_cd is the C++ element type; reg_cd_type is the C++ type of the
# accumulator register array (float only for f32, uint32_t otherwise).
+ if self.cd_dtype == "f16":
50
+ self.val_type_cd = "half"
51
+ self.reg_cd_type = "uint32_t"
52
+ elif self.cd_dtype == "bf16":
53
+ self.val_type_cd = "nv_bfloat16"
54
+ self.reg_cd_type = "uint32_t"
55
+ elif self.cd_dtype == "f32":
56
+ self.val_type_cd = "float"
57
+ self.reg_cd_type = "float"
58
+ elif self.cd_dtype == "s32":
59
+ self.val_type_cd = "int32_t"
60
+ self.reg_cd_type = "uint32_t"
61
+ else:
62
+ raise ValueError(f"Invalid cd_dtype: {cd_dtype}")
63
+
64
# Returns the C++ member declarations (type aliases, bit-width constants and
# the `fma` function); wrapped in `class MmaOpClass { ... };` when
# include_class_name is true.
+ def to_cpp_str(self, include_class_name=False):
65
+ reg_cd_type = self.reg_cd_type
66
+ lines = [
67
+ "static constexpr MmaType kMmaType = MmaType::MMA;",
68
+ f"using MmaShape = Shape<{self.shape[0]}, {self.shape[1]}, {self.shape[2]}>;",
69
+ "",
70
+ f"using ValTypeC = {self.val_type_cd};",
71
+ f"using ValTypeD = {self.val_type_cd};",
72
+ "",
73
+ f"static constexpr uint32_t kATypeBits = {DTYPE_BIT_WIDTH_MAP[self.a_dtype]};",
74
+ f"static constexpr uint32_t kBTypeBits = {DTYPE_BIT_WIDTH_MAP[self.b_dtype]};",
75
+ f"static constexpr uint32_t kCTypeBits = {DTYPE_BIT_WIDTH_MAP[self.cd_dtype]};",
76
+ f"static constexpr uint32_t kDTypeBits = {DTYPE_BIT_WIDTH_MAP[self.cd_dtype]};",
77
+ "",
78
+ f"using ARegisters = uint32_t[{self.reg_a_count}];",
79
+ f"using BRegisters = uint32_t[{self.reg_b_count}];",
80
+ f"using CRegisters = {self.reg_cd_type}[{self.reg_cd_count}];",
81
+ f"using DRegisters = {self.reg_cd_type}[{self.reg_cd_count}];",
82
+ "",
83
+ "CUDA_INLINE",
84
+ f"static void fma(uint32_t *a, uint32_t *b, {reg_cd_type} *c, {reg_cd_type} *d) {{",
85
+ *self.generate_ptx(indent=2).strip("\n").split("\n"),
86
+ "};",
87
+ ]
88
+
89
# Non-empty lines get a leading indent; empty separator lines stay empty.
+ code = "\n".join(" " + x if x else x for x in lines)
90
+ if include_class_name:
91
+ code = f"class MmaOpClass {{\n{code}\n}};"
92
+
93
+ return code
94
+
95
# Returns the `asm volatile(...)` statement as text; every output line is
# prefixed with a newline and `indent` spaces.
+ def generate_ptx(self, indent=0):
96
+ a_dtype = self.a_dtype
97
+ b_dtype = self.b_dtype
98
+ cd_dtype = self.cd_dtype
99
+ shape = self.shape
100
+
101
+ asm_op = f"mma.sync.aligned.m{shape[0]}n{shape[1]}k{shape[2]}.row.col"
102
+ asm_op += f".{cd_dtype}.{a_dtype}.{b_dtype}.{cd_dtype}"
103
# Of the known dtype suffixes only the integer ones (s8, s4, s32) contain "s".
# NOTE(review): the PTX ISA grammar lists .satfinite before the type suffixes
# for integer mma.sync; confirm the assembler accepts it in trailing position.
+ if "s" in a_dtype:
104
+ asm_op += ".satfinite"
105
+
106
# Build one "{%i, ...}" placeholder group per operand, numbered consecutively
# in the order d, a, b, c (matching the asm operand lists below).
+ start = 0
107
+ end = 0
108
+ param_placeholders_list = []
109
+ counts = [self.reg_cd_count, self.reg_a_count, self.reg_b_count, self.reg_cd_count]
110
+ for i in range(len(counts)):
111
+ end += counts[i]
112
+ placeholder_str = ", ".join(f"%{x}" for x in range(start, end))
113
+ param_placeholders_list.append("{" + placeholder_str + "}")
114
+ start += counts[i]
115
+
116
+ a_params = []
117
+ b_params = []
118
+ c_params = []
119
+ d_params = []
120
+ for i in range(self.reg_a_count):
121
+ a_params.append(f' "r"(a[{i}])')
122
+ for i in range(self.reg_b_count):
123
+ b_params.append(f' "r"(b[{i}])')
124
+ for i in range(self.reg_cd_count):
125
# Constraint letter: "f" for f32 accumulators, "r" for everything else.
+ t = "f" if cd_dtype == "f32" else "r"
126
+ c_params.append(f' "{t}"(c[{i}])')
127
+ d_params.append(f'"+{t}"(d[{i}])')
128
+
129
+ asm_code = f"""
130
+ asm volatile(
131
+ "{asm_op} "
132
+ "{", ".join(param_placeholders_list)};\\n"
133
+ : {", ".join(d_params)}
134
+ : {", ".join(a_params)},
135
+ {", ".join(b_params)},
136
+ {", ".join(c_params)}
137
+ );
138
+ """
139
+
140
# Dedent: measure the spaces that follow the template's first newline and
# strip that many from the start of every line, then re-indent by `indent`.
+ space_count = len(re.findall("^\n( +)", asm_code)[0])
141
+ asm_code = asm_code.replace("\n" + " " * space_count, "\n").strip()
142
+ asm_code = "".join("\n" + " " * indent + x for x in asm_code.split("\n"))
143
+
144
+ return asm_code
145
+
146
+
147
+ class WgmmaOpClassImpl:
# Generates C++ source text for an `MmaOpClass` whose `fma` wraps one PTX
# `wgmma.mma_async.sync.aligned` instruction in inline asm.
148
# m, n, k: tile shape. Each dtype argument is either a PTX suffix string
# (a key of DTYPE_BIT_WIDTH_MAP) or a key of DTYPE_MAP.
# Raises ValueError when cd_dtype is not one of f16 / bf16 / f32 / s32.
+ def __init__(self, m, n, k, a_dtype, b_dtype, cd_dtype):
149
+ self.shape = (m, n, k)
150
+ self.a_dtype = a_dtype if isinstance(a_dtype, str) else DTYPE_MAP[a_dtype]
151
+ self.b_dtype = b_dtype if isinstance(b_dtype, str) else DTYPE_MAP[b_dtype]
152
+ self.cd_dtype = cd_dtype if isinstance(cd_dtype, str) else DTYPE_MAP[cd_dtype]
153
+
154
+ # Project B (registers) is sized (project N) x (project K) in b_dtype after
155
+ # the transpose — that's what fills the wgmma A register operand.
156
# The counts are a quarter of calc_reg_count's result
# (presumably because a warpgroup is 4 warps; TODO confirm).
+ self.reg_b_count = calc_reg_count(n, k, self.b_dtype) // 4
157
+ self.reg_cd_count = calc_reg_count(m, n, self.cd_dtype) // 4
158
# val_type_cd is the C++ element type; reg_cd_type is the C++ type of the
# accumulator register array (float only for f32, uint32_t otherwise).
+ if self.cd_dtype == "f16":
159
+ self.val_type_cd = "half"
160
+ self.reg_cd_type = "uint32_t"
161
+ elif self.cd_dtype == "bf16":
162
+ self.val_type_cd = "nv_bfloat16"
163
+ self.reg_cd_type = "uint32_t"
164
+ elif self.cd_dtype == "f32":
165
+ self.val_type_cd = "float"
166
+ self.reg_cd_type = "float"
167
+ elif self.cd_dtype == "s32":
168
+ self.val_type_cd = "int32_t"
169
+ self.reg_cd_type = "uint32_t"
170
+ else:
171
+ raise ValueError(f"Invalid cd_dtype: {cd_dtype}")
172
+
173
# Returns the C++ member declarations (type aliases, bit-width constants and
# the `fma` function); wrapped in `class MmaOpClass { ... };` when
# include_class_name is true. No ARegisters alias is emitted here: `fma` takes
# a uint64_t descriptor in place of an A register array.
+ def to_cpp_str(self, include_class_name=False):
174
+ reg_cd_type = self.reg_cd_type
175
+ lines = [
176
+ "static constexpr MmaType kMmaType = MmaType::WGMMA;",
177
+ f"using MmaShape = Shape<{self.shape[0]}, {self.shape[1]}, {self.shape[2]}>;",
178
+ "",
179
+ f"using ValTypeC = {self.val_type_cd};",
180
+ f"using ValTypeD = {self.val_type_cd};",
181
+ "",
182
+ f"static constexpr uint32_t kATypeBits = {DTYPE_BIT_WIDTH_MAP[self.a_dtype]};",
183
+ f"static constexpr uint32_t kBTypeBits = {DTYPE_BIT_WIDTH_MAP[self.b_dtype]};",
184
+ f"static constexpr uint32_t kCTypeBits = {DTYPE_BIT_WIDTH_MAP[self.cd_dtype]};",
185
+ f"static constexpr uint32_t kDTypeBits = {DTYPE_BIT_WIDTH_MAP[self.cd_dtype]};",
186
+ "",
187
+ f"using BRegisters = uint32_t[{self.reg_b_count}];",
188
+ f"using CRegisters = {self.reg_cd_type}[{self.reg_cd_count}];",
189
+ f"using DRegisters = {self.reg_cd_type}[{self.reg_cd_count}];",
190
+ "",
191
+ "CUDA_INLINE",
192
+ f"static void fma(uint64_t &desc, uint32_t *b, {reg_cd_type} *d, bool pred = true) {{",
193
+ *self.generate_ptx(indent=2, has_scale_d=True).strip("\n").split("\n"),
194
+ "};",
195
+ ]
196
+
197
# Non-empty lines get a leading indent; empty separator lines stay empty.
+ code = "\n".join(" " + x if x else x for x in lines)
198
+ if include_class_name:
199
+ code = f"class MmaOpClass {{\n{code}\n}};"
200
+
201
+ return code
202
+
203
# Returns the `asm volatile(...)` statement as text; every output line is
# prefixed with a newline and `indent` spaces. With has_scale_d the asm
# declares a predicate `p` set from the `pred` argument and passes it after the
# descriptor operand; otherwise the literal 1 is passed in that position.
+ def generate_ptx(self, indent=2, has_scale_d=True):
204
+ a_dtype = self.a_dtype
205
+ b_dtype = self.b_dtype
206
+ cd_dtype = self.cd_dtype
207
+ m, n, k = self.shape
208
+
209
+ # Swap M<->N and A-dtype<->B-dtype in PTX: project's A becomes wgmma's B and
210
+ # project's B becomes wgmma's A. The PTX dtype suffix order is .cd.a.b, so
211
+ # the wgmma A slot takes project's b_dtype and the wgmma B slot takes a_dtype.
212
+ asm_op = f"wgmma.mma_async.sync.aligned.m{n}n{m}k{k}"
213
+ asm_op += f".{cd_dtype}.{b_dtype}.{a_dtype}"
214
+ # satfinite gates on the wgmma-A operand dtype (= project's B).
215
# NOTE(review): the PTX ISA grammar lists .satfinite before the type suffixes
# for integer wgmma; confirm the assembler accepts it in trailing position.
+ if "s" in b_dtype:
216
+ asm_op += ".satfinite"
217
+
218
# Placeholder groups for d and b are numbered consecutively; the descriptor
# operand takes the next index (sum(counts)) as a bare "%N".
+ start = 0
219
+ end = 0
220
+ param_placeholders_list = []
221
+ counts = [self.reg_cd_count, self.reg_b_count]
222
+ for i in range(len(counts)):
223
+ end += counts[i]
224
+ placeholder_str = ", ".join(f"%{x}" for x in range(start, end))
225
+ param_placeholders_list.append("{" + placeholder_str + "}")
226
+ start += counts[i]
227
+ param_placeholders_list.append(f"%{sum(counts)}")
228
+
229
+ other_ptx_args = ", p" if has_scale_d else ", 1"
230
+ # The dtype-specific PTX tail args (scale/trans flags) gate on the wgmma-A
231
+ # operand dtype, which after the swap is project's b_dtype.
232
# Integer dtypes (s8, s4) match neither branch and get no extra tail args.
+ if self.b_dtype in ["f16", "bf16"]:
233
+ other_ptx_args += ", 1, 1, 0"
234
+ elif self.b_dtype in ["e4m3", "e5m2", "e2m1"]:
235
+ other_ptx_args += ", 1, 1"
236
+
237
+ # Project A's smem descriptor fills the wgmma B operand.
238
+ a_desc_param = ' "l"(desc)'
239
+ # Project B's registers fill the wgmma A operand.
240
+ b_params = []
241
+ cd_params = []
242
+ for i in range(self.reg_b_count):
243
+ b_params.append(f' "r"(b[{i}])')
244
+ for i in range(self.reg_cd_count):
245
# Constraint letter: "f" for f32 accumulators, "r" for everything else.
+ t = "f" if cd_dtype == "f32" else "r"
246
+ cd_params.append(f'"+{t}"(d[{i}])')
247
+
248
# Lay the accumulator operands out four per line; lines after the first get a
# leading pad, and the trailing ",\n" of the last line is stripped afterwards.
+ cd_param_str = ""
249
+ for i in range(math.ceil(len(cd_params) / 4)):
250
+ cd_params_part = cd_params[i * 4 : (i + 1) * 4]
251
+ cd_params_part_str = ", ".join(cd_params_part) + ",\n"
252
+ if cd_param_str:
253
+ cd_params_part_str = " " + cd_params_part_str
254
+
255
+ cd_param_str += cd_params_part_str
256
+
257
+ cd_param_str = cd_param_str.strip().strip(",")
258
+
259
# In the predicated form, %{sum(counts) + 1} refers to the `pred` input, which
# is listed right after the descriptor in the asm input operands.
+ if has_scale_d:
260
+ asm_code = f"""
261
+ asm volatile(
262
+ "{{\\n"
263
+ ".reg .pred p;\\n"
264
+ "setp.ne.b32 p, %{sum(counts) + 1}, 0;\\n"
265
+ "{asm_op} "
266
+ "{", ".join(param_placeholders_list)}{other_ptx_args};\\n"
267
+ "}}\\n"
268
+ : {cd_param_str}
269
+ : {", ".join(b_params)},
270
+ {a_desc_param}, "r"((uint32_t)pred)
271
+ );
272
+ """
273
+ else:
274
+ asm_code = f"""
275
+ asm volatile(
276
+ "{asm_op} "
277
+ "{", ".join(param_placeholders_list)}{other_ptx_args};\\n"
278
+ : {cd_param_str}
279
+ : {", ".join(b_params)},
280
+ {a_desc_param}
281
+ );
282
+ """
283
+
284
# Dedent: measure the spaces that follow the template's first newline and
# strip that many from the start of every line, then re-indent by `indent`.
+ space_count = len(re.findall("^\n( +)", asm_code)[0])
285
+ asm_code = asm_code.replace("\n" + " " * space_count, "\n").strip()
286
+ asm_code = "".join("\n" + " " * indent + x for x in asm_code.split("\n"))
287
+
288
+ return asm_code
289
+
290
+
291
class MmaOpClass:
    """Factory for the code generator matching an ``MmaType``."""

    @classmethod
    def from_config(cls, mma_type, m, n, k, a_dtype, b_dtype, cd_dtype):
        """Build the generator for ``mma_type``.

        Args:
            mma_type: An ``MmaType`` member, or a string that is upper-cased
                and looked up as an ``MmaType`` member name.
            m, n, k: Tile shape forwarded to the generator.
            a_dtype, b_dtype, cd_dtype: Operand dtypes forwarded unchanged.

        Returns:
            ``MmaOpClassImpl`` for ``MmaType.MMA`` or ``WgmmaOpClassImpl``
            for ``MmaType.WGMMA``.

        Raises:
            ValueError: If ``mma_type`` is neither of the two members.
        """
        if not isinstance(mma_type, MmaType):
            mma_type = getattr(MmaType, mma_type.upper())

        if mma_type == MmaType.MMA:
            impl_cls = MmaOpClassImpl
        elif mma_type == MmaType.WGMMA:
            impl_cls = WgmmaOpClassImpl
        else:
            raise ValueError(f"Invalid MMA Type: {mma_type}")

        return impl_cls(m, n, k, a_dtype, b_dtype, cd_dtype)