humming-kernels 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- humming/__init__.py +2 -0
- humming/config/__init__.py +13 -0
- humming/config/base.py +129 -0
- humming/config/config.py +196 -0
- humming/config/enum.py +21 -0
- humming/config/mma.py +301 -0
- humming/csrc/launcher/elf.h +125 -0
- humming/csrc/launcher/launcher.cpp +262 -0
- humming/csrc/launcher/tensor.h +288 -0
- humming/csrc/launcher/tma.h +77 -0
- humming/csrc/launcher/torch_api.h +63 -0
- humming/csrc/launcher/utils.h +105 -0
- humming/dtypes.py +205 -0
- humming/include/humming/arith/epilogue_arith.cuh +182 -0
- humming/include/humming/arith/exp_offset.cuh +126 -0
- humming/include/humming/arith/mainloop_arith.cuh +511 -0
- humming/include/humming/datatype/base_conversion.cuh +125 -0
- humming/include/humming/datatype/dequant.cuh +110 -0
- humming/include/humming/datatype/dequant_fused.cuh +90 -0
- humming/include/humming/datatype/dequant_prepare.cuh +68 -0
- humming/include/humming/datatype/dequant_single.cuh +154 -0
- humming/include/humming/datatype/dtypes.cuh +70 -0
- humming/include/humming/epilogue/gmem_writer.cuh +171 -0
- humming/include/humming/epilogue/pipeline.cuh +112 -0
- humming/include/humming/epilogue/smem_reducer.cuh +92 -0
- humming/include/humming/epilogue/smem_writer.cuh +212 -0
- humming/include/humming/kernel/dequant_weight.cuh +47 -0
- humming/include/humming/kernel/humming.cuh +156 -0
- humming/include/humming/kernel/humming_ws.cuh +191 -0
- humming/include/humming/kernel/pack_weight.cuh +95 -0
- humming/include/humming/kernel/process.cuh +263 -0
- humming/include/humming/kernel/process_mxfp4.cuh +69 -0
- humming/include/humming/kernel/quant_weight.cuh +277 -0
- humming/include/humming/kernel/tops_bench.cuh +50 -0
- humming/include/humming/memory/g2s_loader/loader_a.cuh +183 -0
- humming/include/humming/memory/g2s_loader/loader_as.cuh +131 -0
- humming/include/humming/memory/g2s_loader/loader_b.cuh +83 -0
- humming/include/humming/memory/g2s_loader/loader_bias.cuh +57 -0
- humming/include/humming/memory/g2s_loader/loader_bs.cuh +114 -0
- humming/include/humming/memory/g2s_loader/loader_bzp.cuh +91 -0
- humming/include/humming/memory/g2s_pipeline.cuh +343 -0
- humming/include/humming/memory/s2r_loader/loader_a.cuh +61 -0
- humming/include/humming/memory/s2r_loader/loader_as.cuh +65 -0
- humming/include/humming/memory/s2r_loader/loader_b.cuh +57 -0
- humming/include/humming/memory/s2r_loader/loader_bias.cuh +51 -0
- humming/include/humming/memory/s2r_loader/loader_bs.cuh +104 -0
- humming/include/humming/memory/s2r_loader/loader_bzp.cuh +64 -0
- humming/include/humming/memory/s2r_pipeline.cuh +79 -0
- humming/include/humming/mma/wgmma.cuh +175 -0
- humming/include/humming/mma/wmma.cuh +124 -0
- humming/include/humming/scheduler.cuh +335 -0
- humming/include/humming/utils/all.cuh +13 -0
- humming/include/humming/utils/base.cuh +71 -0
- humming/include/humming/utils/enum.cuh +24 -0
- humming/include/humming/utils/ptx/barrier.cuh +122 -0
- humming/include/humming/utils/ptx/legacy_load.cuh +180 -0
- humming/include/humming/utils/ptx/math.cuh +20 -0
- humming/include/humming/utils/ptx/shared.cuh +45 -0
- humming/include/humming/utils/ptx/tma.cuh +139 -0
- humming/include/humming/utils/ptx/warp.cuh +17 -0
- humming/include/humming/utils/ptx/wgmma.cuh +24 -0
- humming/include/humming/utils/storage.cuh +163 -0
- humming/jit/__init__.py +3 -0
- humming/jit/compiler.py +278 -0
- humming/jit/runtime.py +136 -0
- humming/kernel/__init__.py +17 -0
- humming/kernel/dequant_weight.py +65 -0
- humming/kernel/humming.py +404 -0
- humming/kernel/pack_weight.py +42 -0
- humming/kernel/process_mxfp4.py +64 -0
- humming/kernel/quant_weight.py +81 -0
- humming/kernel/repack_weight.py +103 -0
- humming/kernel/tops_bench.py +86 -0
- humming/kernel/unpack_weight.py +43 -0
- humming/layer.py +821 -0
- humming/ops/__init__.py +144 -0
- humming/ops/bench.py +43 -0
- humming/ops/input.py +229 -0
- humming/ops/moe.py +188 -0
- humming/ops/utils.py +165 -0
- humming/ops/weight.py +212 -0
- humming/schema/__init__.py +43 -0
- humming/schema/awq.py +118 -0
- humming/schema/base.py +316 -0
- humming/schema/bitnet.py +108 -0
- humming/schema/compressed_tensors.py +330 -0
- humming/schema/fp8.py +144 -0
- humming/schema/gpt_oss_mxfp4.py +43 -0
- humming/schema/gptq.py +110 -0
- humming/schema/humming.py +276 -0
- humming/schema/modelopt.py +253 -0
- humming/schema/mxfp4.py +75 -0
- humming/tune/__init__.py +80 -0
- humming/tune/base.py +271 -0
- humming/tune/sm100.py +9 -0
- humming/tune/sm75.py +53 -0
- humming/tune/sm8x.py +110 -0
- humming/tune/sm90.py +167 -0
- humming/tune/sm90_h20.py +198 -0
- humming/utils/__init__.py +0 -0
- humming/utils/cuda.py +220 -0
- humming/utils/device.py +84 -0
- humming/utils/jit.py +190 -0
- humming/utils/smem.py +91 -0
- humming/utils/test.py +386 -0
- humming/utils/weight.py +302 -0
- humming_kernels-0.1.0.dist-info/METADATA +111 -0
- humming_kernels-0.1.0.dist-info/RECORD +111 -0
- humming_kernels-0.1.0.dist-info/WHEEL +5 -0
- humming_kernels-0.1.0.dist-info/licenses/LICENSE +202 -0
- humming_kernels-0.1.0.dist-info/top_level.txt +1 -0
humming/config/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from humming.config.config import ComputeConfig, LayerConfig, TuningConfig
|
|
2
|
+
from humming.config.enum import GemmType, MmaType, WeightScaleType
|
|
3
|
+
from humming.config.mma import MmaOpClass
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"LayerConfig",
|
|
7
|
+
"ComputeConfig",
|
|
8
|
+
"TuningConfig",
|
|
9
|
+
"MmaType",
|
|
10
|
+
"WeightScaleType",
|
|
11
|
+
"GemmType",
|
|
12
|
+
"MmaOpClass",
|
|
13
|
+
]
|
humming/config/base.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import json
|
|
3
|
+
import re
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Any, ClassVar
|
|
6
|
+
|
|
7
|
+
from humming import dtypes
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def name_to_google_cpp_const_style(name: str) -> str:
    """Convert an identifier into a Google-style C++ constant name.

    The name is stripped and lower-cased, split on underscores, spaces and any
    other non-word characters, and the pieces are joined in PascalCase behind a
    leading ``k`` (``"use_tma"`` -> ``"kUseTma"``).

    Args:
        name: Identifier to convert, e.g. a dataclass field name.

    Returns:
        The converted name, or ``""`` when ``name`` is empty.
    """
    if not name:
        return ""
    words = re.split(r"[_ \W]+", name.strip().lower())
    # Leading or trailing separators leave empty strings in the split result.
    return "k" + "".join(word.capitalize() for word in words if word)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def name_value_to_google_cpp_const_style(name: str, value: Any, keep_name: bool = False) -> str:
    """Render ``name``/``value`` as a ``static constexpr auto`` C++ declaration.

    Args:
        name: Attribute name. It is converted with
            ``name_to_google_cpp_const_style`` unless ``keep_name`` is true.
        value: Value to emit. ``bool`` becomes ``true``/``false``, ``float``
            gets an ``f`` suffix, ``int`` gets a ``u`` suffix, and anything
            else is stringified with ``.`` replaced by ``::`` (so an enum
            member ``Foo.BAR`` becomes ``Foo::BAR``).
        keep_name: When true, ``name`` is used verbatim.

    Returns:
        A single line of C++ source ending in ``;``.
    """
    if not keep_name:
        name = name_to_google_cpp_const_style(name)

    # bool must be tested before int: bool is a subclass of int in Python.
    if isinstance(value, bool):
        literal = "true" if value else "false"
    elif isinstance(value, float):
        literal = str(value) + "f"
    elif isinstance(value, int):
        literal = str(value) + "u"
    else:
        literal = str(value).replace(".", "::")

    return f"static constexpr auto {name} = {literal};"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def name_value_to_extern_const_style(name: str, value: Any) -> str:
    """Render ``name``/``value`` as an ``extern "C" __constant__ uint32_t`` line.

    Args:
        name: Attribute name; it is upper-cased for the C symbol.
        value: Value to emit. Only ``bool`` and ``int`` are supported; booleans
            are written as ``0``/``1``.

    Returns:
        One line of C++ source, or ``""`` for unsupported value types so the
        caller can filter the entry out.
    """
    if not isinstance(value, (bool, int)):
        return ""
    return f'extern "C" __constant__ uint32_t {name.upper()} = {int(value)};'
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def name_value_to_macro_style(name: str, value: Any) -> str:
    """Render ``name``/``value`` as a ``#define HUMMING_<NAME> <int>`` line.

    Args:
        name: Attribute name; it is upper-cased and prefixed with ``HUMMING_``.
        value: Value to emit. Only ``bool`` and ``int`` are supported; booleans
            are written as ``0``/``1``.

    Returns:
        One preprocessor line, or ``""`` for unsupported value types so the
        caller can filter the entry out.
    """
    if not isinstance(value, (bool, int)):
        return ""
    return f"#define HUMMING_{name.upper()} {int(value)}"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclasses.dataclass
class BaseHummingConfig:
    """Base for config dataclasses that can be rendered as C++ source or JSON.

    Subclasses declare dataclass fields and may list further attribute names in
    ``_cpp_extra_names``. Every field or extra attribute whose value is a
    ``bool``, ``int`` or ``Enum`` is a candidate for export; all other values
    are skipped by the C++ exporters.
    """

    # Attribute name -> exact C++ constant name, used verbatim instead of the
    # automatic ``kPascalCase`` conversion.
    _name_map: ClassVar[dict[str, str]] = {}
    # Names of non-field attributes (e.g. values derived in ``__post_init__``
    # or properties) that are exported alongside the dataclass fields.
    _cpp_extra_names: ClassVar[tuple[str, ...]] = ()

    def __post_init__(self):
        pass

    def _iter_cpp_values(self, cls: type["BaseHummingConfig"] | None = None):
        """Yield ``(name, value)`` for every exportable attribute of ``cls``.

        Fields come first, in declaration order, followed by
        ``cls._cpp_extra_names``. Values that are not ``bool``/``int``/``Enum``
        are skipped.
        """
        cls = cls or self.__class__
        names = [field.name for field in dataclasses.fields(cls)]
        names += list(cls._cpp_extra_names)
        for name in names:
            value = getattr(self, name)
            if isinstance(value, (bool, int, Enum)):
                yield name, value

    def to_cpp_str(
        self,
        cls: type["BaseHummingConfig"] | None = None,
        include_class_name: bool = False,
    ) -> str:
        """Render the config as ``static constexpr`` C++ declarations.

        Args:
            cls: Class whose fields, extra names and name map are used;
                defaults to the instance's own class.
            include_class_name: Wrap the declarations in
                ``class <ClassName> { ... };`` when true.

        Returns:
            The generated C++ source, one declaration per line.
        """
        cls = cls or self.__class__
        lines = []
        for name, value in self._iter_cpp_values(cls):
            keep_name = name in cls._name_map
            if keep_name:
                name = cls._name_map[name]
            lines.append(name_value_to_google_cpp_const_style(name, value, keep_name))

        code = "\n".join(" " + line for line in lines)
        class_name = cls.__name__

        if include_class_name:
            code = f"class {class_name} {{\n{code}\n}};"

        return code

    def to_macro_cpp_str(self, cls: type["BaseHummingConfig"] | None = None) -> str:
        """Render the bool/int attributes as ``#define HUMMING_<NAME> <int>`` lines.

        Enum values produce no line because the macro formatter returns ``""``
        for them.
        """
        rendered = (
            name_value_to_macro_style(name, value)
            for name, value in self._iter_cpp_values(cls)
        )
        return "\n".join(line for line in rendered if line)

    def to_extern_cpp_str(self, cls: type["BaseHummingConfig"] | None = None) -> str:
        """Render the bool/int attributes as ``extern "C" __constant__`` lines.

        Enum values produce no line because the extern formatter returns ``""``
        for them.
        """
        rendered = (
            name_value_to_extern_const_style(name, value)
            for name, value in self._iter_cpp_values(cls)
        )
        return "\n".join(line for line in rendered if line)

    def to_str(self) -> str:
        """Serialize the dataclass fields to a JSON string.

        Enum members are replaced by their ``.value`` and ``dtypes.DataType``
        instances by ``str(value)``; all other values are passed to
        ``json.dumps`` unchanged.
        """
        res = {}
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if isinstance(value, Enum):
                value = value.value
            elif isinstance(value, dtypes.DataType):
                value = str(value)
            res[field.name] = value

        return json.dumps(res)
|
humming/config/config.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import math
|
|
3
|
+
from typing import ClassVar
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from humming import dtypes
|
|
8
|
+
from humming.config.base import BaseHummingConfig
|
|
9
|
+
from humming.config.enum import GemmType, MmaType, WeightScaleType
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclasses.dataclass(kw_only=True)
class LayerConfig(BaseHummingConfig):
    """Shape, datatype and quantization description of one layer.

    String values are accepted for the dtype fields, ``weight_scale_type`` and
    ``mma_type``; ``__post_init__`` converts them to their typed equivalents.
    """

    # shape config
    shape_n: int
    shape_k: int
    pad_shape_n: int = 0
    pad_shape_k: int = 0
    num_experts: int = 0

    # datatype config
    b_dtype: dtypes.DataType
    a_dtype: dtypes.DataType
    c_dtype: dtypes.DataType
    bs_dtype: dtypes.DataType | None = None  # defaults to c_dtype

    # quant param config
    input_scale_group_size: int = 0
    weight_scale_group_size: int = 0
    weight_scale_group_size_n: int = 0
    weight_scale_type: WeightScaleType | None = None  # inferred when None
    use_int_weight_scale: bool = False
    use_fused_e8m0_scale: bool = False
    has_zero_point: bool = False
    is_fp_zero_point: bool = False

    # bias config
    has_bias: bool = False

    # mma config
    mma_type: MmaType | None = None  # chosen from the CUDA device when None

    # Derived flags set in __post_init__ that are exported to C++ as well.
    _cpp_extra_names: ClassVar[tuple[str, ...]] = (
        "is_channel_weight_scale",
        "is_block_weight_scale",
        "is_group_weight_scale",
        "is_tensor_weight_scale",
        "has_input_scale",
    )

    def __post_init__(self):
        """Fill defaults, normalize string inputs and compute derived flags.

        Queries ``torch.cuda.get_device_capability()`` when ``mma_type`` is
        ``None``, so a CUDA device is needed in that case.
        """
        # The leading 0 is presumably a placeholder for the M dimension, which
        # is not part of the layer config -- TODO confirm against the kernels.
        self.problem_shape = (0, self.shape_n, self.shape_k)
        self.pad_shape = (0, self.pad_shape_n, self.pad_shape_k)

        if self.bs_dtype is None:
            self.bs_dtype = self.c_dtype

        # Infer the scale layout from the group sizes. A negative
        # weight_scale_group_size with no N grouping matches no branch and
        # leaves weight_scale_type as None (all derived flags become False).
        if self.weight_scale_type is None:
            if self.weight_scale_group_size_n > 1:
                self.weight_scale_type = WeightScaleType.BLOCK
            elif self.weight_scale_group_size == 0:
                self.weight_scale_type = WeightScaleType.CHANNEL
            elif self.weight_scale_group_size > 0:
                self.weight_scale_type = WeightScaleType.GROUP

        if isinstance(self.weight_scale_type, str):
            self.weight_scale_type = WeightScaleType(self.weight_scale_type)

        if self.mma_type is None:
            sm_version = torch.cuda.get_device_capability()[0]
            self.mma_type = MmaType.WGMMA if sm_version == 9 else MmaType.MMA
        if isinstance(self.mma_type, str):
            self.mma_type = MmaType(self.mma_type)

        # Runs after the bs_dtype default above, so a string c_dtype copied
        # into bs_dtype is converted here too.
        for name in ["a", "b", "c", "bs"]:
            value = getattr(self, f"{name}_dtype")
            if isinstance(value, str):
                value = dtypes.DataType.from_str(value)
            setattr(self, f"{name}_dtype", value)

        self.has_input_scale = self.a_dtype.num_bits != 16
        self.is_channel_weight_scale = self.weight_scale_type == WeightScaleType.CHANNEL
        # GROUP_TENSOR counts as both a tensor scale and a group scale.
        self.is_tensor_weight_scale = self.weight_scale_type in [
            WeightScaleType.TENSOR,
            WeightScaleType.GROUP_TENSOR,
        ]
        self.is_block_weight_scale = self.weight_scale_type == WeightScaleType.BLOCK
        self.is_group_weight_scale = self.weight_scale_type in [
            WeightScaleType.GROUP,
            WeightScaleType.GROUP_TENSOR,
        ]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclasses.dataclass(kw_only=True)
class ComputeConfig(BaseHummingConfig):
    """Compute options: accumulation mode, batch invariance and GEMM variant."""

    use_f16_accum: bool = False
    use_batch_invariant: bool = False
    gemm_type: GemmType | None = None  # a GemmType member or its string value

    # Derived values exported to C++ in addition to the fields. Note that
    # ``gemm_type_id`` is a property and asserts that gemm_type is set.
    _cpp_extra_names: ClassVar[tuple[str, ...]] = (
        "gemm_type_id",
        "is_indexed_gemm",
        "is_grouped_gemm",
        "is_grouped_contiguous_gemm",
        "is_grouped_masked_gemm",
    )

    def __post_init__(self):
        """Normalize a string ``gemm_type`` and derive the boolean GEMM flags."""
        if isinstance(self.gemm_type, str):
            self.gemm_type = GemmType(self.gemm_type)
        self.is_indexed_gemm = self.gemm_type == GemmType.INDEXED
        self.is_grouped_contiguous_gemm = self.gemm_type == GemmType.GROUPED_CONTIGUOUS
        self.is_grouped_masked_gemm = self.gemm_type == GemmType.GROUPED_MASKED
        self.is_grouped_gemm = self.is_grouped_contiguous_gemm or self.is_grouped_masked_gemm

    @property
    def gemm_type_id(self):
        """Integer id of ``gemm_type``.

        The id is the position of the lower-cased enum value in the fixed list
        dense(0), indexed(1), grouped_contiguous(2), grouped_masked(3).

        Raises:
            AssertionError: If ``gemm_type`` is ``None``.
        """
        assert self.gemm_type is not None
        value = self.gemm_type.value.lower()
        return ["dense", "indexed", "grouped_contiguous", "grouped_masked"].index(value)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@dataclasses.dataclass(kw_only=True)
class TuningConfig(BaseHummingConfig):
    """Kernel tuning parameters: tile shapes, pipelining and load strategy.

    Flags left as ``None`` are resolved in ``__post_init__``.
    """

    block_shape: tuple[int, int, int]
    warp_shape: tuple[int, int, int]

    use_stream_k: bool = True

    num_stages: int = 2
    num_ctas_per_sm: int = 1

    use_warp_spec: bool | None = None  # None -> False
    use_mbarrier: bool | None = None  # None -> use_tma or use_warp_spec
    use_cp_async: bool | None = None  # None -> device major capability >= 8

    use_tma: bool | None = None  # None -> False
    # Per-operand TMA switches; None inherits the value of use_tma.
    use_tma_a: bool | None = None
    use_tma_b: bool | None = None
    use_tma_c: bool | None = None
    use_tma_bs: bool | None = None
    use_tma_bzp: bool | None = None
    use_tma_bias: bool | None = None

    num_write_splits: int = 1
    multi_cast_size_a: int = 1
    multi_cast_size_b: int = 1

    # Thread counts derived in __post_init__ that are exported to C++ as well.
    _cpp_extra_names: ClassVar[tuple[str, ...]] = (
        "num_threads",
        "num_math_threads",
        "num_load_threads",
    )

    # C++ names that the automatic kPascalCase conversion would spell
    # differently (e.g. it would produce kUseMbarrier / kUseTmaBs).
    _name_map: ClassVar[dict[str, str]] = {
        "use_mbarrier": "kUseMBarrier",
        "use_tma_bs": "kUseTmaBS",
        "use_tma_bzp": "kUseTmaBZP",
    }

    def __post_init__(self):
        """Resolve ``None`` flags and derive the thread counts.

        Queries ``torch.cuda.get_device_capability()`` when ``use_cp_async`` is
        ``None``, so a CUDA device is needed in that case.

        Raises:
            AssertionError: If a ``use_tma_*`` flag is ``True`` while
                ``use_tma`` is false.
        """
        if self.use_warp_spec is None:
            self.use_warp_spec = False

        if self.use_tma is None:
            self.use_tma = False

        if self.use_mbarrier is None:
            self.use_mbarrier = self.use_tma or self.use_warp_spec

        if self.use_cp_async is None:
            sm_version = torch.cuda.get_device_capability()
            self.use_cp_async = sm_version[0] >= 8

        # Number of warp tiles per block tile times 32 -- presumably one warp
        # of 32 threads per warp tile; TODO confirm.
        self.num_math_threads = math.prod(self.block_shape) // math.prod(self.warp_shape) * 32
        if self.use_warp_spec:
            # Warp specialization adds 128 dedicated load threads.
            self.num_load_threads = 128
            self.num_threads = self.num_math_threads + 128
        else:
            self.num_load_threads = self.num_math_threads
            self.num_threads = self.num_math_threads

        # Walk the declared per-operand TMA flags (fields, rather than every
        # attribute returned by dir()).
        for field in dataclasses.fields(self):
            name = field.name
            if not name.startswith("use_tma_"):
                continue
            if not self.use_tma:
                assert getattr(self, name) is not True, f"{name} requires use_tma"
            if getattr(self, name) is None:
                setattr(self, name, self.use_tma)
|
humming/config/enum.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class MmaType(enum.Enum):
    """Matrix-multiply instruction family used by the generated kernel."""

    MMA = "mma"
    WGMMA = "wgmma"
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class WeightScaleType(enum.Enum):
    """Layout of the weight quantization scales."""

    GROUP = "group"
    BLOCK = "block"
    CHANNEL = "channel"
    TENSOR = "tensor"
    GROUP_TENSOR = "group_tensor"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class GemmType(enum.Enum):
    """GEMM variant implemented by the kernel."""

    DENSE = "dense"
    INDEXED = "indexed"
    GROUPED_CONTIGUOUS = "grouped_contiguous"
    GROUPED_MASKED = "grouped_masked"
|
humming/config/mma.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
import humming.dtypes as dtypes
|
|
5
|
+
from humming.config.enum import MmaType
|
|
6
|
+
|
|
7
|
+
# Bit width of each PTX dtype suffix used by the MMA code generators.
DTYPE_BIT_WIDTH_MAP = {
    "f32": 32,
    "s32": 32,
    "f16": 16,
    "bf16": 16,
    "e4m3": 8,
    "e5m2": 8,
    "s8": 8,
    "e2m1": 4,
    "s4": 4,
}
|
|
18
|
+
|
|
19
|
+
# Humming dtype object -> PTX dtype suffix (the keys of DTYPE_BIT_WIDTH_MAP).
DTYPE_MAP = {
    dtypes.float32: "f32",
    dtypes.int32: "s32",
    dtypes.float16: "f16",
    dtypes.bfloat16: "bf16",
    dtypes.float8e4m3: "e4m3",
    dtypes.float8e5m2: "e5m2",
    dtypes.int8: "s8",
    dtypes.float4e2m1: "e2m1",
    dtypes.int4: "s4",
}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def calc_reg_count(rows, cols, ptx_dtype):
    """Return how many ``32 * 32``-bit units a ``rows x cols`` tile occupies.

    Args:
        rows: Number of rows in the tile.
        cols: Number of columns in the tile.
        ptx_dtype: PTX dtype suffix; must be a key of ``DTYPE_BIT_WIDTH_MAP``.

    Returns:
        ``rows * cols * bit_width // (32 * 32)``.

    Raises:
        KeyError: If ``ptx_dtype`` is not in ``DTYPE_BIT_WIDTH_MAP``.
        AssertionError: If the tile size in bits is not a multiple of 1024.
    """
    total_bits = rows * cols * DTYPE_BIT_WIDTH_MAP[ptx_dtype]
    # 32 * 32 is presumably 32 threads per warp times 32 bits per register,
    # giving a per-thread register count -- TODO confirm.
    reg_count, remainder = divmod(total_bits, 32 * 32)
    assert remainder == 0
    return reg_count
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class MmaOpClassImpl:
    """Generates the C++ ``MmaOpClass`` body for a ``mma.sync`` instruction.

    Dtypes may be given as PTX suffix strings (e.g. ``"f16"``) or as keys of
    ``DTYPE_MAP``.
    """

    # cd_dtype -> (C++ value type, C++ type of the accumulator registers)
    _CD_TYPES = {
        "f16": ("half", "uint32_t"),
        "bf16": ("nv_bfloat16", "uint32_t"),
        "f32": ("float", "float"),
        "s32": ("int32_t", "uint32_t"),
    }

    def __init__(self, m, n, k, a_dtype, b_dtype, cd_dtype):
        """Store the instruction shape/dtypes and derive the register counts.

        Raises:
            KeyError: If a dtype is unknown to ``DTYPE_MAP`` /
                ``DTYPE_BIT_WIDTH_MAP``.
            ValueError: If ``cd_dtype`` is not one of f16, bf16, f32, s32.
        """
        self.shape = (m, n, k)
        self.a_dtype = a_dtype if isinstance(a_dtype, str) else DTYPE_MAP[a_dtype]
        self.b_dtype = b_dtype if isinstance(b_dtype, str) else DTYPE_MAP[b_dtype]
        self.cd_dtype = cd_dtype if isinstance(cd_dtype, str) else DTYPE_MAP[cd_dtype]

        self.reg_a_count = calc_reg_count(m, k, self.a_dtype)
        self.reg_b_count = calc_reg_count(k, n, self.b_dtype)
        self.reg_cd_count = calc_reg_count(m, n, self.cd_dtype)

        if self.cd_dtype not in self._CD_TYPES:
            raise ValueError(f"Invalid cd_dtype: {cd_dtype}")
        self.val_type_cd, self.reg_cd_type = self._CD_TYPES[self.cd_dtype]

    def to_cpp_str(self, include_class_name=False):
        """Return the C++ members (shape, type aliases, ``fma``) as source text.

        Args:
            include_class_name: Wrap the members in ``class MmaOpClass { ... };``
                when true.
        """
        reg_cd_type = self.reg_cd_type
        lines = [
            "static constexpr MmaType kMmaType = MmaType::MMA;",
            f"using MmaShape = Shape<{self.shape[0]}, {self.shape[1]}, {self.shape[2]}>;",
            "",
            f"using ValTypeC = {self.val_type_cd};",
            f"using ValTypeD = {self.val_type_cd};",
            "",
            f"static constexpr uint32_t kATypeBits = {DTYPE_BIT_WIDTH_MAP[self.a_dtype]};",
            f"static constexpr uint32_t kBTypeBits = {DTYPE_BIT_WIDTH_MAP[self.b_dtype]};",
            f"static constexpr uint32_t kCTypeBits = {DTYPE_BIT_WIDTH_MAP[self.cd_dtype]};",
            f"static constexpr uint32_t kDTypeBits = {DTYPE_BIT_WIDTH_MAP[self.cd_dtype]};",
            "",
            f"using ARegisters = uint32_t[{self.reg_a_count}];",
            f"using BRegisters = uint32_t[{self.reg_b_count}];",
            f"using CRegisters = {self.reg_cd_type}[{self.reg_cd_count}];",
            f"using DRegisters = {self.reg_cd_type}[{self.reg_cd_count}];",
            "",
            "CUDA_INLINE",
            f"static void fma(uint32_t *a, uint32_t *b, {reg_cd_type} *c, {reg_cd_type} *d) {{",
            *self.generate_ptx(indent=2).strip("\n").split("\n"),
            "};",
        ]

        # Blank separator lines stay empty instead of receiving the indent.
        code = "\n".join(" " + x if x else x for x in lines)
        if include_class_name:
            code = f"class MmaOpClass {{\n{code}\n}};"

        return code

    def generate_ptx(self, indent=0):
        """Return the ``asm volatile(...)`` statement for the mma instruction.

        The statement is assembled from explicit lines, so its layout no longer
        depends on how a template literal happens to be indented in this file.
        Each output line is preceded by a newline and ``indent`` spaces.

        Args:
            indent: Number of spaces put in front of every emitted line.
        """
        a_dtype = self.a_dtype
        b_dtype = self.b_dtype
        cd_dtype = self.cd_dtype
        shape = self.shape

        asm_op = f"mma.sync.aligned.m{shape[0]}n{shape[1]}k{shape[2]}.row.col"
        asm_op += f".{cd_dtype}.{a_dtype}.{b_dtype}.{cd_dtype}"
        # Only the integer dtypes ("s8", "s4", "s32") contain an "s".
        if "s" in a_dtype:
            asm_op += ".satfinite"

        # Operand placeholders %0.. are numbered in the order d, a, b, c, which
        # matches the constraint lists below (outputs first, then inputs).
        counts = [self.reg_cd_count, self.reg_a_count, self.reg_b_count, self.reg_cd_count]
        placeholder_groups = []
        start = 0
        for count in counts:
            regs = ", ".join(f"%{x}" for x in range(start, start + count))
            placeholder_groups.append("{" + regs + "}")
            start += count
        operand_list = ", ".join(placeholder_groups)

        # f32 accumulators use the "f" constraint, everything else "r".
        t = "f" if cd_dtype == "f32" else "r"
        # The leading space lines the input constraints up with the "+x" outputs.
        a_params = [f' "r"(a[{i}])' for i in range(self.reg_a_count)]
        b_params = [f' "r"(b[{i}])' for i in range(self.reg_b_count)]
        c_params = [f' "{t}"(c[{i}])' for i in range(self.reg_cd_count)]
        d_params = [f'"+{t}"(d[{i}])' for i in range(self.reg_cd_count)]

        step = " " * 4
        cont = " " * 6
        asm_lines = [
            "asm volatile(",
            f'{step}"{asm_op} "',
            f'{step}"{operand_list};\\n"',
            f"{step}: {', '.join(d_params)}",
            f"{step}: {', '.join(a_params)},",
            f"{cont}{', '.join(b_params)},",
            f"{cont}{', '.join(c_params)}",
            ");",
        ]

        return "".join("\n" + " " * indent + line for line in asm_lines)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class WgmmaOpClassImpl:
    """Generates the C++ ``MmaOpClass`` body for a ``wgmma.mma_async`` instruction.

    The project's operands are swapped when mapped onto wgmma: project B
    (registers) fills the wgmma A operand and project A (shared-memory
    descriptor) fills the wgmma B operand, so M and N are exchanged in the
    emitted instruction shape.

    Dtypes may be given as PTX suffix strings (e.g. ``"f16"``) or as keys of
    ``DTYPE_MAP``.
    """

    # cd_dtype -> (C++ value type, C++ type of the accumulator registers)
    _CD_TYPES = {
        "f16": ("half", "uint32_t"),
        "bf16": ("nv_bfloat16", "uint32_t"),
        "f32": ("float", "float"),
        "s32": ("int32_t", "uint32_t"),
    }

    def __init__(self, m, n, k, a_dtype, b_dtype, cd_dtype):
        """Store the instruction shape/dtypes and derive the register counts.

        Raises:
            KeyError: If a dtype is unknown to ``DTYPE_MAP`` /
                ``DTYPE_BIT_WIDTH_MAP``.
            ValueError: If ``cd_dtype`` is not one of f16, bf16, f32, s32.
        """
        self.shape = (m, n, k)
        self.a_dtype = a_dtype if isinstance(a_dtype, str) else DTYPE_MAP[a_dtype]
        self.b_dtype = b_dtype if isinstance(b_dtype, str) else DTYPE_MAP[b_dtype]
        self.cd_dtype = cd_dtype if isinstance(cd_dtype, str) else DTYPE_MAP[cd_dtype]

        # Project B (registers) is sized (project N) x (project K) in b_dtype
        # after the transpose -- that is what fills the wgmma A register
        # operand. The division by 4 is presumably because a warpgroup is four
        # warps -- TODO confirm.
        self.reg_b_count = calc_reg_count(n, k, self.b_dtype) // 4
        self.reg_cd_count = calc_reg_count(m, n, self.cd_dtype) // 4

        if self.cd_dtype not in self._CD_TYPES:
            raise ValueError(f"Invalid cd_dtype: {cd_dtype}")
        self.val_type_cd, self.reg_cd_type = self._CD_TYPES[self.cd_dtype]

    def to_cpp_str(self, include_class_name=False):
        """Return the C++ members (shape, type aliases, ``fma``) as source text.

        Args:
            include_class_name: Wrap the members in ``class MmaOpClass { ... };``
                when true.
        """
        reg_cd_type = self.reg_cd_type
        lines = [
            "static constexpr MmaType kMmaType = MmaType::WGMMA;",
            f"using MmaShape = Shape<{self.shape[0]}, {self.shape[1]}, {self.shape[2]}>;",
            "",
            f"using ValTypeC = {self.val_type_cd};",
            f"using ValTypeD = {self.val_type_cd};",
            "",
            f"static constexpr uint32_t kATypeBits = {DTYPE_BIT_WIDTH_MAP[self.a_dtype]};",
            f"static constexpr uint32_t kBTypeBits = {DTYPE_BIT_WIDTH_MAP[self.b_dtype]};",
            f"static constexpr uint32_t kCTypeBits = {DTYPE_BIT_WIDTH_MAP[self.cd_dtype]};",
            f"static constexpr uint32_t kDTypeBits = {DTYPE_BIT_WIDTH_MAP[self.cd_dtype]};",
            "",
            f"using BRegisters = uint32_t[{self.reg_b_count}];",
            f"using CRegisters = {self.reg_cd_type}[{self.reg_cd_count}];",
            f"using DRegisters = {self.reg_cd_type}[{self.reg_cd_count}];",
            "",
            "CUDA_INLINE",
            f"static void fma(uint64_t &desc, uint32_t *b, {reg_cd_type} *d, bool pred = true) {{",
            *self.generate_ptx(indent=2, has_scale_d=True).strip("\n").split("\n"),
            "};",
        ]

        # Blank separator lines stay empty instead of receiving the indent.
        code = "\n".join(" " + x if x else x for x in lines)
        if include_class_name:
            code = f"class MmaOpClass {{\n{code}\n}};"

        return code

    def generate_ptx(self, indent=2, has_scale_d=True):
        """Return the ``asm volatile(...)`` statement for the wgmma instruction.

        The statement is assembled from explicit lines, so its layout no longer
        depends on how a template literal happens to be indented in this file.
        Each output line is preceded by a newline and ``indent`` spaces.

        Args:
            indent: Number of spaces put in front of every emitted line.
            has_scale_d: When true, the scale-d argument is a predicate ``p``
                computed from the C++ ``pred`` parameter; otherwise the literal
                ``1`` is emitted in its place.
        """
        a_dtype = self.a_dtype
        b_dtype = self.b_dtype
        cd_dtype = self.cd_dtype
        m, n, k = self.shape

        # Swap M<->N and A-dtype<->B-dtype in PTX: project's A becomes wgmma's B
        # and project's B becomes wgmma's A. The PTX dtype suffix order is
        # .cd.a.b, so the wgmma A slot takes project's b_dtype and the wgmma B
        # slot takes a_dtype.
        asm_op = f"wgmma.mma_async.sync.aligned.m{n}n{m}k{k}"
        asm_op += f".{cd_dtype}.{b_dtype}.{a_dtype}"
        # satfinite gates on the wgmma-A operand dtype (= project's B).
        if "s" in b_dtype:
            asm_op += ".satfinite"

        # Placeholders %0.. are numbered d registers, then b registers, then
        # the descriptor; with has_scale_d the predicate source follows it.
        counts = [self.reg_cd_count, self.reg_b_count]
        placeholder_groups = []
        start = 0
        for count in counts:
            regs = ", ".join(f"%{x}" for x in range(start, start + count))
            placeholder_groups.append("{" + regs + "}")
            start += count
        placeholder_groups.append(f"%{sum(counts)}")

        other_ptx_args = ", p" if has_scale_d else ", 1"
        # The dtype-specific PTX tail args (scale/trans flags) gate on the
        # wgmma-A operand dtype, which after the swap is project's b_dtype.
        if self.b_dtype in ["f16", "bf16"]:
            other_ptx_args += ", 1, 1, 0"
        elif self.b_dtype in ["e4m3", "e5m2", "e2m1"]:
            other_ptx_args += ", 1, 1"
        operand_list = ", ".join(placeholder_groups) + other_ptx_args

        # Project A's smem descriptor fills the wgmma B operand.
        a_desc_param = ' "l"(desc)'
        # Project B's registers fill the wgmma A operand.
        b_params = [f' "r"(b[{i}])' for i in range(self.reg_b_count)]
        # f32 accumulators use the "f" constraint, everything else "r".
        t = "f" if cd_dtype == "f32" else "r"
        cd_params = [f'"+{t}"(d[{i}])' for i in range(self.reg_cd_count)]

        step = " " * 4
        cont = " " * 6
        # Accumulator constraints are wrapped four per line for readability.
        cd_rows = [", ".join(cd_params[i : i + 4]) for i in range(0, len(cd_params), 4)]
        cd_param_str = (",\n" + cont).join(cd_rows)

        if has_scale_d:
            asm_strings = [
                '"{\\n"',
                '".reg .pred p;\\n"',
                f'"setp.ne.b32 p, %{sum(counts) + 1}, 0;\\n"',
                f'"{asm_op} "',
                f'"{operand_list};\\n"',
                '"}\\n"',
            ]
            last_input = f'{a_desc_param}, "r"((uint32_t)pred)'
        else:
            asm_strings = [
                f'"{asm_op} "',
                f'"{operand_list};\\n"',
            ]
            last_input = a_desc_param

        asm_code = "\n".join(
            ["asm volatile("]
            + [step + text for text in asm_strings]
            + [
                f"{step}: {cd_param_str}",
                f"{step}: {', '.join(b_params)},",
                f"{cont}{last_input}",
                ");",
            ]
        )

        # cd_param_str may contain newlines, so split again before indenting.
        return "".join("\n" + " " * indent + line for line in asm_code.split("\n"))
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class MmaOpClass:
    """Factory that selects the code generator for an MMA instruction family."""

    @classmethod
    def from_config(cls, mma_type, m, n, k, a_dtype, b_dtype, cd_dtype):
        """Build the generator matching ``mma_type``.

        Args:
            mma_type: An ``MmaType`` member or its member name in any letter
                case (e.g. ``"mma"``, ``"WGMMA"``).
            m, n, k: Instruction shape, forwarded unchanged.
            a_dtype, b_dtype, cd_dtype: Operand and accumulator dtypes,
                forwarded unchanged.

        Returns:
            An ``MmaOpClassImpl`` for ``MmaType.MMA`` or a ``WgmmaOpClassImpl``
            for ``MmaType.WGMMA``.

        Raises:
            AttributeError: If ``mma_type`` is a string that names no
                ``MmaType`` member.
            ValueError: If ``mma_type`` is an ``MmaType`` member without a
                generator.
        """
        if not isinstance(mma_type, MmaType):
            mma_type = getattr(MmaType, mma_type.upper())

        if mma_type == MmaType.MMA:
            return MmaOpClassImpl(m, n, k, a_dtype, b_dtype, cd_dtype)
        if mma_type == MmaType.WGMMA:
            return WgmmaOpClassImpl(m, n, k, a_dtype, b_dtype, cd_dtype)
        raise ValueError(f"Invalid MMA Type: {mma_type}")
|