fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,647 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-unsafe
|
|
9
|
+
import math
|
|
10
|
+
from typing import Union
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import triton # @manual
|
|
14
|
+
|
|
15
|
+
import triton.language as tl # @manual
|
|
16
|
+
|
|
17
|
+
from .common import get_mx4_exp_bias, get_mx4_lookup_table, RoundingMode
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@triton.jit
def _floor_log2(x):
    """Cheaply compute floor(log2(x)) by reading the FP32 exponent field.

    Args:
        x (Tensor): FP32 input tensor (an int32 tensor holding FP32 bit
            patterns is also accepted, since the bitcast is then a no-op).

    Returns:
        Tensor: floor(log2(x)) as FP32.
    """
    # IEEE-754 binary32: 1 sign bit | 8 exponent bits | 23 mantissa bits.
    EXPONENT_FIELD: tl.constexpr = 0x7F800000  # type: ignore[Incompatible variable type]
    MANTISSA_WIDTH: tl.constexpr = 23  # type: ignore[Incompatible variable type]
    EXPONENT_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]

    # Reinterpret the float bits as an integer, isolate the exponent field
    # and move it down to the low bits.
    raw_bits = x.to(tl.int32, bitcast=True)
    biased_exponent = (raw_bits & EXPONENT_FIELD) >> MANTISSA_WIDTH
    # Unbias the exponent to get floor(log2(x)).
    return (biased_exponent - EXPONENT_BIAS).to(tl.float32)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@triton.jit
def _compute_exp(
    group_max,
    rounding_mode,
    rand_bits,
    MBITS: tl.constexpr,
):
    """Compute the shared exponent of each group using the specified rounding mode.

    Args:
        group_max (Tensor): Per-group maximum absolute value to compute the
            exponent of (the quantize kernel passes max(|group|) per group).
        rounding_mode (int or RoundingMode): Which rounding mode to use:
            0 = nearest, 1 = floor, 2 = even pre-rounding, 3 = stochastic,
            any other value = ceil.
        rand_bits (Tensor or None): Random integer values used for stochastic
            rounding. Only read when rounding_mode == 3.
        MBITS (int): Number of mantissa bits in target mx4 format.

    Returns:
        Tensor: Shared exponent of each group, as an integral float value.
    """
    # Define some helpful constants.
    MBITS_FP32: tl.constexpr = 23  # type: ignore[Incompatible variable type]
    # Just under half an ulp at MBITS of mantissa precision: adding it to the raw
    # FP32 bits carries into the exponent field only when the value would round
    # up to the next power of two at that precision.
    M_ROUND: tl.constexpr = (1 << (MBITS_FP32 - MBITS - 1)) - 1  # type: ignore[Incompatible variable type]
    # Mask selecting the FP32 mantissa bits that are dropped when keeping MBITS.
    RAND_MASK: tl.constexpr = (1 << (MBITS_FP32 - MBITS)) - 1  # type: ignore[Incompatible variable type]

    # Nearest rounding mode.
    if rounding_mode == 0:
        return tl.floor(tl.log2(group_max) + 0.5)
    # Floor rounding mode. This can be done with fast bit ops.
    if rounding_mode == 1:
        return _floor_log2(group_max)
    # Even pre-rounding mode.
    elif rounding_mode == 2:
        # Add fixed rounding to the mantissa bits of the input to round during truncation.
        group_max = group_max.to(tl.int32, bitcast=True) + M_ROUND
        # Then perform floor rounding of log.
        return _floor_log2(group_max)
    # Stochastic rounding mode.
    elif rounding_mode == 3:
        # Use random bits to add noise to mantissa that would otherwise
        # be rounded away.
        group_max = group_max.to(tl.int32, bitcast=True) + (RAND_MASK & rand_bits)
        # Now compute log and truncate.
        return _floor_log2(group_max)
    # Ceil rounding mode (every remaining value).
    else:
        return tl.ceil(tl.log2(group_max))
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@triton.jit
def _kernel_quantize_mx4(
    A,
    out,
    rand_bits,
    M,
    K,
    GROUPS_PER_ROW,
    GROUPS_PER_THREAD,
    ROW_PADDING,
    GROUP_SIZE: tl.constexpr,
    EBITS: tl.constexpr,
    MBITS: tl.constexpr,
    ROUNDING_MODE: tl.constexpr,
    STOCHASTIC_CASTING: tl.constexpr,
    FP4_EXP_BIAS: tl.constexpr,
    GROUP_LOAD: tl.constexpr,
    USE_INT64: tl.constexpr,
) -> None:
    """Quantize a float tensor into a packed MX4 tensor.

    The input is treated as M rows of K columns. Each row is logically padded
    (with zeros) up to GROUPS_PER_ROW * GROUP_SIZE elements, and every
    GROUP_SIZE consecutive elements share one exponent. For each group the
    output holds GROUP_SIZE // 2 bytes of packed 4-bit values followed by one
    byte containing the group's shared exponent (re-biased by +127).

    Args:
        A (Tensor): [M, K] float tensor to be quantized.
        out (Tensor): [M * GROUPS_PER_ROW * (GROUP_SIZE // 2 + 1)] output
            containing packed mx4 values and shared exponents.
        rand_bits (Optional Tensor): [M * GROUPS_PER_ROW] random integers, one
            per group. Used for stochastic rounding of the shared exponent and as
            seeds for stochastic casting. Only read when ROUNDING_MODE == 3 or
            STOCHASTIC_CASTING is set.
        M (int): Number of input rows.
        K (int): Number of input columns.
        GROUPS_PER_ROW (int): Number of groups in each (padded) row of the input.
        GROUPS_PER_THREAD (int): Number of groups to process per thread.
        ROW_PADDING (int): Number of elements of padding to insert into each row.
        GROUP_SIZE (int): Size of chunks that use the same shared exponent.
        EBITS (int): Number of exponent bits in target mx4 format.
        MBITS (int): Number of mantissa bits in target mx4 format.
        ROUNDING_MODE (int): Which rounding method to use when calculating shared exponent.
        STOCHASTIC_CASTING (bool): Whether to use stochastic rounding when downcasting.
        FP4_EXP_BIAS (int): Exponent bias of target mx4 format.
        GROUP_LOAD (int): Number of groups to process simultaneously.
        USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
    """
    # Define Constant Expressions.
    FP32_EXP_MASK: tl.constexpr = 0x7F800000  # type: ignore[Incompatible variable type]
    FP32_EXP_OFFSET: tl.constexpr = 23  # type: ignore[Incompatible variable type]
    FP32_EXP_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]
    FP32_SIGN_OFFSET: tl.constexpr = 31  # type: ignore[Incompatible variable type]
    SIGN_MASK: tl.constexpr = 0x1  # type: ignore[Incompatible variable type]
    FP32_MANTISSA_MASK: tl.constexpr = 0x007FFFFF  # type: ignore[Incompatible variable type]
    # Target significand width including the implicit leading 1
    # (2 bits for the default e2m1 format: one explicit, one implicit).
    MBITS_IMPLICIT: tl.constexpr = MBITS + 1  # type: ignore[Incompatible variable type]
    MAX_FP32_MANTISSA_BITS: tl.constexpr = 24  # type: ignore[Incompatible variable type]
    IMPLIED_1_BIT: tl.constexpr = 1 << 23  # type: ignore[Incompatible variable type]
    FP32_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
    MANTISSA_OVERFLOW_THRESHOLD: tl.constexpr = (1 << MBITS_IMPLICIT) - 1  # type: ignore[Incompatible variable type]
    EXPONENT_OVERFLOW_THRESHOLD: tl.constexpr = (1 << EBITS) - 1  # type: ignore[Incompatible variable type]
    # All MBITS explicit mantissa bits set: strips the implicit 1 and is also
    # the largest representable explicit mantissa (used when clamping).
    IMPLICIT_1_MASK: tl.constexpr = (1 << (MBITS_IMPLICIT - 1)) - 1  # type: ignore[Incompatible variable type]
    # FP32 mantissa bits below the MBITS that survive the downcast.
    RAND_MASK: tl.constexpr = (1 << (FP32_EXP_OFFSET - MBITS)) - 1  # type: ignore[Incompatible variable type]

    # Get the current thread number.
    pid = tl.program_id(0)
    # For very large inputs, we need to use int64 indexes. This is slower but necessary.
    if USE_INT64:
        pid = pid.to(tl.int64)
        M = tl.cast(M, tl.int64)
        K = tl.cast(K, tl.int64)
        GROUPS_PER_THREAD = tl.cast(GROUPS_PER_THREAD, tl.int64)

    # Boundaries for writing to output tensor.
    PACKED_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 + 1  # type: ignore[Incompatible variable type]
    NUM_GROUPS = M * GROUPS_PER_ROW
    OUTPUT_CHUNK_SIZE = (GROUPS_PER_THREAD * GROUP_SIZE) // 2 + GROUPS_PER_THREAD
    OUTPUT_SIZE = (GROUP_SIZE * NUM_GROUPS) // 2 + NUM_GROUPS

    # Find starting offsets for this thread. These are calculated before adjusting for padding.
    input_start = pid * (GROUPS_PER_THREAD * GROUP_SIZE)
    output_start = pid * OUTPUT_CHUNK_SIZE
    exp_start = output_start + GROUP_SIZE // 2
    # Initiate offset ranges used in kernel.
    input_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE) + input_start
    output_offset = tl.arange(0, GROUP_LOAD * (GROUP_SIZE // 2))
    # Stochastic rounding loads one random value per group in the block.
    if ROUNDING_MODE == 3:
        rand_bits_offset = tl.arange(0, GROUP_LOAD) + pid * GROUPS_PER_THREAD
    # Otherwise (stochastic casting only) a single value per block is used as a seed.
    else:
        rand_bits_offset = pid * GROUPS_PER_THREAD
    # We need to shift output offsets to make space for shared exponent storage.
    output_offset += output_offset // (GROUP_SIZE // 2) + output_start
    # Now create offsets for writing the shared exponent.
    exp_offset = tl.arange(0, GROUP_LOAD) * PACKED_GROUP_SIZE + exp_start

    # Load and process blocks of values for this chunk.
    for _k in range(0, tl.cdiv(GROUPS_PER_THREAD, GROUP_LOAD)):
        # We need to make some adjustments to allow for padding.
        pad_mask = (input_offset % (GROUPS_PER_ROW * GROUP_SIZE)) < K
        if ROW_PADDING != 0:
            # Shift the input to account for padding.
            padded_input_offset = (
                input_offset
                - (input_offset // (GROUPS_PER_ROW * GROUP_SIZE)) * ROW_PADDING
            )
        # When there's no padding we can simplify indexing.
        else:
            padded_input_offset = input_offset

        # Load a block of values.
        a = tl.load(
            A + padded_input_offset,
            # Mask values out of range for both the main array and this chunk. Also pad if needed.
            mask=(padded_input_offset < (M * K))
            & (padded_input_offset < ((pid + 1) * GROUPS_PER_THREAD * GROUP_SIZE))
            & pad_mask,
            other=0,
        )

        # Scaling step
        ##############

        # View the block in terms of groups.
        a_groups = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE])
        # Compute the shared exponent of each group.
        group_max = tl.max(tl.abs(a_groups), axis=1)
        # Prevent infinite values in log.
        group_max = tl.where(group_max == 0, FP32_MIN_NORMAL, group_max)
        # Load relevant random values if doing stochastic rounding
        # or stochastic casting.
        group_rand_bits = None
        if (ROUNDING_MODE) == 3 or STOCHASTIC_CASTING:
            # rand_bits holds one value per group of the whole input
            # (NUM_GROUPS entries), so bound the load by that. Bounding by the
            # group count of a single row would zero the random bits of every
            # row after the first.
            # NOTE(review): when ROUNDING_MODE == 3 this is a [GROUP_LOAD]
            # vector that is also passed as the casting seed below; confirm
            # tl.randint4x accepts a non-scalar seed of that shape.
            group_rand_bits = tl.load(
                rand_bits + rand_bits_offset,
                mask=rand_bits_offset < NUM_GROUPS,
                other=0,
            )
            rand_bits_offset += GROUP_LOAD
        # Compute shared exponent using specified rounding mode.
        group_exp = _compute_exp(group_max, ROUNDING_MODE, group_rand_bits, MBITS)
        # Subtract largest exponent in target datatype and remove bias.
        # NOTE(review): EBITS stands in for the target's largest unbiased
        # exponent here (they coincide for the default e2m1 format); confirm
        # for other formats.
        group_exp = group_exp - EBITS
        # Make sure exponent is in valid range.
        group_exp = tl.clamp(group_exp, -127, 125)

        # Next we scale A in preparation for quantization.
        scale = tl.exp2(group_exp.to(tl.float64)).to(tl.float32)
        # Apply scale to input. We do this by broadcasting scale.
        scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) / tl.reshape(
            scale, [GROUP_LOAD, 1]
        )
        # Reshape back to a flat array.
        scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])

        # We're done with group_exp now so we can write it out.
        # We readd fp32_exp_bias for compatibility with cuda dequant.
        tl.store(
            out + exp_offset,
            (group_exp + FP32_EXP_BIAS).to(tl.uint8),
            # Prevent writing outside this chunk or the main array.
            mask=(exp_offset < OUTPUT_SIZE)
            & (exp_offset < (OUTPUT_CHUNK_SIZE * (pid + 1))),
        )

        # Quantization step
        ###################

        # During quantization, we're going to be doing a lot of bitwise operations.
        # This is easier to work with in int32.
        scaled_a = scaled_a.to(tl.int32, bitcast=True)

        # When doing stochastic downcasting, generate random values for this block
        # and apply it to the mantissa.
        if STOCHASTIC_CASTING:
            # We're going to generate 4 blocks at once so we only need
            # one fourth of the input offsets.
            # Start by splitting down to half of offsets.
            philox_4x_offset = tl.split(
                tl.reshape(
                    input_offset,
                    [GROUP_LOAD * GROUP_SIZE // 2, 2],
                    can_reorder=True,
                )
            )
            # Split down to fourth.
            philox_4x_offset = tl.split(
                tl.reshape(
                    philox_4x_offset,
                    [GROUP_LOAD * GROUP_SIZE // 4, 2],
                    can_reorder=True,
                )
            )
            # Generate 4 blocks of random bits for this block.
            a_4x, b_4x, c_4x, d_4x = tl.randint4x(
                group_rand_bits, philox_4x_offset, n_rounds=7
            )
            # Combine the 4 blocks into a single chunk of random values.
            # This needs to be done incrementally.
            stochastic_round_bits = tl.join(tl.join(a_4x, b_4x), tl.join(c_4x, d_4x))
            # Flatten back to simple array.
            stochastic_round_bits = tl.reshape(
                stochastic_round_bits, [GROUP_LOAD * GROUP_SIZE]
            ).to(tl.int32, bitcast=True)

            # Mask off mantissa bits of random value and add to mantissa.
            scaled_a = scaled_a + (stochastic_round_bits & RAND_MASK)

        # Extract sign bit of value.
        sign_bit = (scaled_a >> FP32_SIGN_OFFSET) & SIGN_MASK

        # Extract exponent.
        biased_exp = (scaled_a & FP32_EXP_MASK) >> FP32_EXP_OFFSET

        # Extract mantissa.
        trailing_mantissa = scaled_a & FP32_MANTISSA_MASK

        # Adjust exponent bias for FP4.
        new_biased_exp = biased_exp - FP32_EXP_BIAS + FP4_EXP_BIAS

        # Compute difference between ideal exponent and what fp4 can represent.
        exp_diff = tl.where(new_biased_exp <= 0, 1 - new_biased_exp, 0)

        # Clip this difference to maximum number of fp32 mantissa bits.
        exp_diff = tl.minimum(exp_diff, MAX_FP32_MANTISSA_BITS)

        # Now we round our fp32 mantissa down to fp4.
        is_subnorm = biased_exp == 0
        # Add implied 1 bit to normal values.
        mantissa = tl.where(
            is_subnorm, trailing_mantissa, trailing_mantissa + IMPLIED_1_BIT
        )
        # Compute base number of bits corresponding to the mantissa, smaller for subnorms
        # since implied one is included in exp_diff.
        fp32_sig_bits = tl.where(is_subnorm, 23, 24).to(tl.int32)
        # Now we're ready to shift down to target bitwidth (with an extra bit for rounding).
        mantissa = mantissa >> (fp32_sig_bits + exp_diff - MBITS_IMPLICIT - 1)
        # Perform rounding by adding 1 and shifting down.
        mantissa = (mantissa + 1) >> 1

        # Check for overflow and adjust exponent accordingly.
        overflow = mantissa > MANTISSA_OVERFLOW_THRESHOLD
        # Allow subnorms to overflow into normals, otherwise shift away overflow.
        mantissa = tl.where(overflow and (not is_subnorm), mantissa >> 1, mantissa)
        # Special case where a value is subnormal and has a large mantissa, overflow it.
        new_biased_exp = tl.where(
            (new_biased_exp <= 0) and (mantissa == 2), 1, new_biased_exp
        )
        # Remove implicit 1.
        mantissa = mantissa & IMPLICIT_1_MASK
        # Add overflow to exponent.
        new_biased_exp = tl.where(overflow, new_biased_exp + 1, new_biased_exp)
        # If exp overflows, set mantissa to its maximum value (equivalent to
        # clamping). IMPLICIT_1_MASK has every explicit mantissa bit set, which
        # is 1 for the default MBITS == 1 and stays correct for other MBITS.
        mantissa = tl.where(
            new_biased_exp > EXPONENT_OVERFLOW_THRESHOLD, IMPLICIT_1_MASK, mantissa
        )

        # Construct FP4 value from components.
        new_biased_exp = tl.maximum(
            tl.minimum(new_biased_exp, EXPONENT_OVERFLOW_THRESHOLD), 0
        )
        mx4_value = (new_biased_exp << (MBITS_IMPLICIT - 1)) | mantissa
        mx4_value = (sign_bit << (EBITS + MBITS)) | mx4_value

        # Extract low and high bits from values.
        low_mx4, high_mx4 = tl.split(
            tl.reshape(mx4_value, [(GROUP_LOAD * GROUP_SIZE) // 2, 2])
        )
        # Shift mx4 values together so they are packed into int8.
        packed_mx4 = ((high_mx4 << 4) | (low_mx4)).to(tl.int8)

        # Write out packed values to output tensor.
        tl.store(
            out + output_offset,
            packed_mx4,
            # Prevent writing outside this chunk or the main array.
            mask=(output_offset < OUTPUT_SIZE)
            & (output_offset < (OUTPUT_CHUNK_SIZE * (pid + 1))),
        )

        # Update offsets so we work on the next block.
        input_offset += GROUP_LOAD * GROUP_SIZE
        exp_offset += GROUP_LOAD * PACKED_GROUP_SIZE
        output_offset += GROUP_LOAD * PACKED_GROUP_SIZE
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def triton_quantize_mx4(
    a: torch.Tensor,
    group_size: int = 32,
    ebits: int = 2,
    mbits: int = 1,
    rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil,
    stochastic_casting: bool = False,
) -> torch.Tensor:
    """
    Quantize a tensor to mx4 format using efficient triton kernels.

    The last dimension of ``a`` is treated as a row. Rows whose length is not a
    multiple of ``group_size`` are zero-padded up to the next multiple before
    quantization.

    Args:
        a (Tensor): [..., K] higher precision input tensor. Must be contiguous.
        group_size (int): Size of chunks that will use the same shared exponent.
        ebits (int): Number of bits to use for exponent in target mx4 format.
        mbits (int): Number of bits to use for mantissa in target mx4 format.
        rounding_mode (Union[RoundingMode, int]): Which type of rounding to use
            when calculating shared exponent. Defaults to ``RoundingMode.ceil``.
        stochastic_casting (bool): Whether to use stochastic casting when
            downcasting values to mx4.

    Returns:
        torch.Tensor: uint8 tensor that holds, for every group, ``group_size // 2``
        bytes of packed mx4 values (two values per byte) followed by one byte
        with the group's shared exponent. The leading dimensions of ``a`` are
        kept when possible; otherwise the result is returned flat. An empty
        input returns an empty uint8 tensor with the same shape as ``a``.

        eg.
        Input with shape [1, 8192] will be quantized to [1, 4096 + 256] as
        each byte contains two elements packed into a uint8 and
        there are 256 groups (one exponent byte each) in each row.

    Raises:
        AssertionError: If ``a`` is not contiguous (when assertions are enabled).
    """
    # If given an empty shape, return an empty tensor.
    if a.numel() == 0:
        return torch.empty(a.shape, device=a.device, dtype=torch.uint8)
    # Make sure input is contiguous in memory.
    assert a.is_contiguous(), "Inputs to mx4 quantize must be contiguous in memory."

    orig_shape = a.shape
    # For simplicity, view input as a 2D array.
    a = a.view(-1, a.shape[-1])
    # Extract rows and columns.
    M, K = a.shape
    # In this kernel, we want each row to be divisible by group_size.
    # If the rows are not, then we will pad them. Find the number of
    # groups per row after padding.
    groups_per_row = math.ceil(K / group_size)
    num_groups = M * groups_per_row
    # Find how many groups each thread should process. We do this
    # by assuming that it is good to distribute work evenly over threads.
    num_threads = math.ceil(math.sqrt(a.numel()))
    # Data is loaded in chunks of GROUP_LOAD groups, so there's no reason
    # to ever use fewer groups per thread than that.
    GROUP_LOAD = 64
    groups_per_thread = max(math.ceil(num_groups / num_threads), GROUP_LOAD)
    # Determine how much padding, if any, is needed for each row.
    if K % group_size != 0:
        padding = group_size - (K % group_size)
    else:
        padding = 0

    # Create output tensor.
    out_elems = (num_groups * group_size) // 2 + num_groups
    out = torch.empty([out_elems], device=a.device, dtype=torch.uint8)

    # If using stochastic rounding, create random noise for each group.
    # We use the same random bits as seeds when doing stochastic downcasting.
    if rounding_mode == RoundingMode.stochastic or stochastic_casting:
        # Each group will need a seed.
        rand_bits = torch.randint(
            low=0,
            high=2**31 - 1,
            size=(num_groups,),
            dtype=torch.int32,
            device=a.device,
        )
    else:
        rand_bits = None

    # Check if we need to use int64 for indexing. The kernel's loop runs
    # cdiv(groups_per_thread, GROUP_LOAD) blocks of GROUP_LOAD groups, so the
    # last program can compute offsets up to GROUP_LOAD - 1 groups past the
    # end of its chunk; include that slack so no offset overflows int32.
    use_int64 = (num_threads * groups_per_thread + GROUP_LOAD) * group_size > 2**31 - 1

    # Invoke triton quantization kernel over rows.
    grid = (num_threads,)
    _kernel_quantize_mx4[grid](
        a,
        out,
        rand_bits=rand_bits,
        M=M,
        K=K,
        GROUPS_PER_ROW=groups_per_row,
        GROUPS_PER_THREAD=groups_per_thread,
        ROW_PADDING=padding,
        # pyre-ignore[6]
        GROUP_SIZE=group_size,
        # pyre-ignore[6]
        EBITS=ebits,
        # pyre-ignore[6]
        MBITS=mbits,
        # pyre-ignore[6]
        ROUNDING_MODE=rounding_mode,
        # pyre-ignore[6]
        STOCHASTIC_CASTING=stochastic_casting,
        FP4_EXP_BIAS=get_mx4_exp_bias(ebits),
        # pyre-ignore[6]
        GROUP_LOAD=GROUP_LOAD,
        # pyre-ignore[6]
        USE_INT64=use_int64,
    )
    # Inputs are now fully quantized and ready to return.
    # Try to return in the original shape if possible.
    try:
        output_shape = list(orig_shape[:-1]) + [-1]
        return out.view(output_shape)
    # If we can't, return as a flat array.
    except RuntimeError:
        return out.view(-1)
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
@triton.jit
def _kernel_dequantize_mx4(
    A,
    mx4_lookup_table,
    out,
    M,
    GROUPS_PER_THREAD,
    GROUP_SIZE: tl.constexpr,
    GROUP_LOAD: tl.constexpr,
    USE_INT64: tl.constexpr,
) -> None:
    """Dequantize a packed MX4 tensor and apply scaling.

    Input layout (as indexed below): A is a flat sequence of packed groups of
    PACKED_GROUP_SIZE = GROUP_SIZE // 2 + 1 bytes each. The first
    GROUP_SIZE // 2 bytes of a group hold two 4-bit codes per byte (low nibble
    is the earlier output element, high nibble the next one). The last byte is
    the group's shared exponent, stored with an fp32 bias of 127.

    Args:
        A (Tensor): [M] flat packed MX4 bytes, laid out as described above.
            The exponent byte is widened with a sign-preserving cast, so the
            storage is expected to be unsigned (uint8); int8 storage would
            misread exponent bytes >= 128.
        mx4_lookup_table (Tensor): Map from 4-bit mx4 code (0..15) to its
            floating point value.
        out (Tensor): Flat fp32 output; receives GROUP_SIZE values per
            complete input group.
        M (int): Total number of packed bytes in A.
        GROUPS_PER_THREAD (int): Number of groups each program instance
            ("thread") is responsible for.
        GROUP_SIZE (int): Size of chunks that use the same shared exponent.
        GROUP_LOAD (int): Number of groups to process simultaneously.
        USE_INT64 (bool): Whether to use int64 for indexing.
    """
    # Define constants.
    MX4_BIT_MASK: tl.constexpr = 0xF  # type: ignore[Incompatible variable type]
    FP32_EXP_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]
    # Bytes per packed group: GROUP_SIZE // 2 nibble bytes plus one exponent byte.
    PACKED_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 + 1  # type: ignore[Incompatible variable type]

    # Index of this program instance (called a "thread" throughout this file).
    pid = tl.program_id(0)
    # For very large tensors, use int64 for indexing. This is slower but necessary.
    if USE_INT64:
        pid = pid.to(tl.int64)
        M = tl.cast(M, tl.int64)
        GROUPS_PER_THREAD = tl.cast(GROUPS_PER_THREAD, tl.int64)

    # Boundaries for reading input and writing to output tensor.
    INPUT_CHUNK_SIZE = GROUPS_PER_THREAD * PACKED_GROUP_SIZE
    OUTPUT_CHUNK_SIZE = GROUPS_PER_THREAD * GROUP_SIZE
    # Only complete groups produce output; trailing bytes that do not form a
    # full packed group are ignored.
    OUTPUT_SIZE = (M // PACKED_GROUP_SIZE) * GROUP_SIZE

    # Find the starting offsets for this thread.
    input_start = pid * (GROUPS_PER_THREAD * PACKED_GROUP_SIZE)
    # The shared exponent is the last byte of each packed group.
    exp_start = input_start + GROUP_SIZE // 2
    # The output holds no exponent bytes, so it advances GROUP_SIZE per group.
    output_start = pid * OUTPUT_CHUNK_SIZE
    # Initiate offset ranges used in this thread.
    # Input offsets must skip one byte (the shared exponent) after every
    # GROUP_SIZE // 2 packed data bytes.
    input_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE // 2)
    # Which group (within this GROUP_LOAD tile) each data byte belongs to.
    exp_indices = input_offset // (GROUP_SIZE // 2)
    # Adding the group index shifts each group past the preceding exponent bytes.
    input_offset = input_offset + exp_indices + input_start
    # Output offsets are contiguous: each data byte yields two consecutive values.
    output_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE) + output_start
    # Stride exponent access across packed groups; every data byte of a group
    # maps to that group's exponent byte.
    exp_offset = exp_indices * PACKED_GROUP_SIZE + exp_start

    # Iterate over input tensor and unpack mx4 values, GROUP_LOAD groups per step.
    for _k in range(0, tl.cdiv(GROUPS_PER_THREAD, GROUP_LOAD)):
        a = tl.load(
            A + input_offset,
            # Mask values that are out of this chunk or the main array.
            mask=(input_offset < M) & (input_offset < (INPUT_CHUNK_SIZE * (pid + 1))),
            other=0.0,
        )
        # Extract low and high 4-bit codes from each loaded byte.
        low_mx4 = a & MX4_BIT_MASK
        high_mx4 = (a >> 4) & MX4_BIT_MASK

        # Get equivalent fp32 values via the lookup table.
        low_fp32 = tl.load(mx4_lookup_table + low_mx4)
        high_fp32 = tl.load(mx4_lookup_table + high_mx4)

        # Get proper shared exponent and convert it to a float scale.
        exp = tl.load(
            A + exp_offset,
            mask=(exp_offset < M) & (exp_offset < (INPUT_CHUNK_SIZE * (pid + 1))),
            other=0.0,
        )
        # Remove fp32 exponent bias. Widening to int16 first keeps the
        # subtraction from wrapping in 8 bits.
        exp = exp.to(tl.int16) - FP32_EXP_BIAS

        # Convert exponent to scale and apply to input.
        # Requires higher precision to avoid rounding out small values
        # (presumably because 2**exp is subnormal in fp32 for the most negative
        # exponents).
        # This might be slow so we should consider just letting them round away.
        scale = tl.exp2(exp.to(tl.float64)).to(tl.float32)
        scaled_low_fp32 = scale * low_fp32
        scaled_high_fp32 = scale * high_fp32

        # Interleave so each byte's low-nibble value precedes its high-nibble value.
        scaled_fp32 = tl.interleave(scaled_low_fp32, scaled_high_fp32)

        # Write final outputs.
        tl.store(
            out + output_offset,
            scaled_fp32,
            # Mask values that are out of this chunk or the main array.
            mask=(output_offset < OUTPUT_SIZE)
            & (output_offset < OUTPUT_CHUNK_SIZE * (pid + 1)),
        )

        # Advance all offsets by GROUP_LOAD groups for the next iteration.
        input_offset += GROUP_LOAD * PACKED_GROUP_SIZE
        exp_offset += GROUP_LOAD * PACKED_GROUP_SIZE
        output_offset += GROUP_LOAD * GROUP_SIZE
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def triton_dequantize_mx4(
    a: torch.Tensor, group_size: int = 32, ebits: int = 2, mbits: int = 1
) -> torch.Tensor:
    """
    Dequantize a tensor from mx4 format to fp32.

    Args:
        a (Tensor): Packed MX4 bytes. Each group occupies
            ``group_size // 2 + 1`` bytes: ``group_size // 2`` bytes holding two
            4-bit codes each (low nibble first), followed by one byte holding
            the group's shared exponent with an fp32 bias of 127. uint8 is the
            expected dtype; int8 input is reinterpreted bit-for-bit as uint8 so
            the exponent byte is not sign-extended.
        group_size (int): Size of chunks that use the same shared exponent.
        ebits (int): Number of bits to use for exponent in target mx4 format.
        mbits (int): Number of bits to use for mantissa in target mx4 format.

    Returns:
        torch.Tensor: fp32 tensor with the same leading dimensions as ``a``;
        the last dimension holds the dequantized values (``group_size`` per
        complete packed group). Trailing bytes that do not form a complete
        group are ignored.
    """
    # If given an empty shape, return an empty tensor.
    if a.numel() == 0:
        return torch.empty(a.shape, device=a.device, dtype=torch.float32)
    # Work on a flat 1D view; the kernel indexes the input linearly.
    orig_shape = a.shape
    a = a.flatten()
    # The kernel widens the exponent byte with a sign-preserving cast, so int8
    # storage would turn biased exponents >= 128 into negative exponents.
    # Reinterpreting as uint8 is a free bit-level view.
    if a.dtype == torch.int8:
        a = a.view(torch.uint8)
    # Find number of groups.
    packed_group_size = group_size // 2 + 1
    num_groups = a.numel() // packed_group_size
    # Find a workload that distributes work evenly over threads.
    num_threads = math.ceil(math.sqrt(a.numel()))
    # There is no need to ever have fewer groups per thread than the amount
    # loaded at once.
    GROUP_LOAD = 64
    groups_per_thread = max(math.ceil(num_groups / num_threads), GROUP_LOAD)

    # Use a lookup table to convert 4-bit codes to floating point values.
    mx4_to_fp_values = get_mx4_lookup_table(ebits, mbits, a.device)

    # Create output tensor.
    output_elems = num_groups * group_size
    out = torch.empty([output_elems], device=a.device, dtype=torch.float)
    # Check if we need to use int64 for indexing. Each program's loop computes
    # offsets up to GROUP_LOAD groups past its own chunk before masking them,
    # so include that overshoot; otherwise unmasked int32 offsets near 2**31
    # could wrap negative and slip through the bounds masks.
    use_int64 = (
        num_threads * groups_per_thread + GROUP_LOAD
    ) * group_size > 2**31 - 1
    # Invoke triton dequantization kernel over rows.
    grid = (num_threads,)
    _kernel_dequantize_mx4[grid](
        a,
        mx4_to_fp_values,
        out,
        a.numel(),
        GROUPS_PER_THREAD=groups_per_thread,
        # pyre-ignore[6]
        GROUP_SIZE=group_size,
        # pyre-ignore[6]
        GROUP_LOAD=GROUP_LOAD,
        # pyre-ignore[6]
        USE_INT64=use_int64,
    )

    # Restore the leading dimensions of the input.
    out_shape = list(orig_shape[:-1]) + [-1]
    return out.view(out_shape)
|