fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/triton/quantize.py
@@ -0,0 +1,647 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import math
+ from typing import Union
+
+ import torch
+ import triton  # @manual
+
+ import triton.language as tl  # @manual
+
+ from .common import get_mx4_exp_bias, get_mx4_lookup_table, RoundingMode
+
+
+ @triton.jit
+ def _floor_log2(x):
+     """Helper function to efficiently compute floor(log2(x))
+
+     Args:
+         x (Tensor): FP32 Input tensor to operate on.
+
+     Returns:
+         Tensor: Floor of log2(x).
+     """
+     # Helpful bit constants.
+     FP32_EXP_MASK: tl.constexpr = 0x7F800000  # type: ignore[Incompatible variable type]
+     FP32_EXP_OFFSET: tl.constexpr = 23  # type: ignore[Incompatible variable type]
+     FP32_EXP_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]
+
+     # View x as an integer and extract its exponent.
+     x = x.to(tl.int32, bitcast=True) & FP32_EXP_MASK
+     # Shift exponent down to bottom bits.
+     x = x >> FP32_EXP_OFFSET
+     # Remove FP32 exponent bias and return.
+     return (x - FP32_EXP_BIAS).to(tl.float32)
+
+
+ @triton.jit
+ def _compute_exp(
+     group_max,
+     rounding_mode,
+     rand_bits,
+     MBITS: tl.constexpr,
+ ):
+     """Compute shared exponent of group using specified rounding mode.
+
+     Args:
+         group_max (Tensor): Group of values to compute exponent of.
+         rounding_mode (int or RoundingMode): Which rounding mode to use.
+         rand_bits (int): Random integer values used for stochastic rounding.
+         MBITS (int): Number of mantissa bits in target mx4 format.
+
+     Returns:
+         Tensor: Shared exponent of group.
+     """
+     # Define some helpful constants.
+     MBITS_FP32: tl.constexpr = 23  # type: ignore[Incompatible variable type]
+     M_ROUND: tl.constexpr = (1 << (MBITS_FP32 - MBITS - 1)) - 1  # type: ignore[Incompatible variable type]
+     RAND_MASK: tl.constexpr = (1 << (MBITS_FP32 - MBITS)) - 1  # type: ignore[Incompatible variable type]
+
+     # Nearest rounding mode.
+     if rounding_mode == 0:
+         return tl.floor(tl.log2(group_max) + 0.5)
+     # Floor rounding mode. This can be done with fast bit ops.
+     if rounding_mode == 1:
+         return _floor_log2(group_max)
+     # Even pre-rounding mode.
+     elif rounding_mode == 2:
+         # Add fixed rounding to the mantissa bits of the input to round during truncation.
+         group_max = group_max.to(tl.int32, bitcast=True) + M_ROUND
+         # Then perform floor rounding of log.
+         return _floor_log2(group_max)
+     # Stochastic rounding mode.
+     elif rounding_mode == 3:
+         # Use random bits to add noise to mantissa that would otherwise
+         # be rounded away.
+         group_max = group_max.to(tl.int32, bitcast=True) + (RAND_MASK & rand_bits)
+         # Now compute log and truncate.
+         return _floor_log2(group_max)
+     else:
+         return tl.ceil(tl.log2(group_max))
+
+
+ @triton.jit
+ def _kernel_quantize_mx4(
+     A,
+     out,
+     rand_bits,
+     M,
+     K,
+     GROUPS_PER_ROW,
+     GROUPS_PER_THREAD,
+     ROW_PADDING,
+     GROUP_SIZE: tl.constexpr,
+     EBITS: tl.constexpr,
+     MBITS: tl.constexpr,
+     ROUNDING_MODE: tl.constexpr,
+     STOCHASTIC_CASTING: tl.constexpr,
+     FP4_EXP_BIAS: tl.constexpr,
+     GROUP_LOAD: tl.constexpr,
+     USE_INT64: tl.constexpr,
+ ) -> None:
+     """Quantize a 1D float tensor into a packed MX4 tensor.
+
+     Args:
+         A (Tensor): [M] float tensor to be quantized.
+         out (Tensor): [M / 2 + M / GROUP_SIZE] output containing packed mx4 values.
+         rand_bits (Optional Tensor): [M, K / 2] random integers used for stochastic rounding.
+         M (int): Number of input rows.
+         K (int): Number of input columns.
+         GROUPS_PER_ROW (int): Number of groups in each row of the input.
+         GROUPS_PER_THREAD (int): Number of groups to process per thread.
+         ROW_PADDING (int): Number of elements of padding to insert into each row.
+         GROUP_SIZE (int): Size of chunks that use the same shared exponent.
+         EBITS (int): Number of exponent bits in target mx4 format.
+         MBITS (int): Number of mantissa bits in target mx4 format.
+         ROUNDING_MODE (int): Which rounding method to use when calculating shared exponent.
+         STOCHASTIC_CASTING (bool): Whether to use stochastic rounding when downcasting.
+         FP4_EXP_BIAS (int): Exponent bias of target mx4 format.
+         GROUP_LOAD (int): Number of groups to process simultaneously.
+         USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
+     """
+     # Define Constant Expressions.
+     FP32_EXP_MASK: tl.constexpr = 0x7F800000  # type: ignore[Incompatible variable type]
+     FP32_EXP_OFFSET: tl.constexpr = 23  # type: ignore[Incompatible variable type]
+     FP32_EXP_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]
+     FP32_SIGN_OFFSET: tl.constexpr = 31  # type: ignore[Incompatible variable type]
+     SIGN_MASK: tl.constexpr = 0x1  # type: ignore[Incompatible variable type]
+     FP32_MANTISSA_MASK: tl.constexpr = 0x007FFFFF  # type: ignore[Incompatible variable type]
+     # FP4 has 2 mantissa bits, one explicit one implicit.
+     MBITS_IMPLICIT: tl.constexpr = MBITS + 1  # type: ignore[Incompatible variable type]
+     MAX_FP32_MANTISSA_BITS: tl.constexpr = 24  # type: ignore[Incompatible variable type]
+     IMPLIED_1_BIT: tl.constexpr = 1 << 23  # type: ignore[Incompatible variable type]
+     FP32_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
+     MANTISSA_OVERFLOW_THRESHOLD: tl.constexpr = (1 << MBITS_IMPLICIT) - 1  # type: ignore[Incompatible variable type]
+     EXPONENT_OVERFLOW_THRESHOLD: tl.constexpr = (1 << EBITS) - 1  # type: ignore[Incompatible variable type]
+     IMPLICIT_1_MASK = (1 << (MBITS_IMPLICIT - 1)) - 1
+     RAND_MASK: tl.constexpr = (1 << (FP32_EXP_OFFSET - MBITS)) - 1  # type: ignore[Incompatible variable type]
+
+     # Get the current thread number.
+     pid = tl.program_id(0)
+     # For very large inputs, we need to use int64 indexes. This is slower but necessary.
+     if USE_INT64:
+         pid = pid.to(tl.int64)
+         M = tl.cast(M, tl.int64)
+         K = tl.cast(K, tl.int64)
+         GROUPS_PER_THREAD = tl.cast(GROUPS_PER_THREAD, tl.int64)
+
+     # Boundaries for writing to output tensor.
+     PACKED_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 + 1  # type: ignore[Incompatible variable type]
+     NUM_GROUPS = M * GROUPS_PER_ROW
+     OUTPUT_CHUNK_SIZE = (GROUPS_PER_THREAD * GROUP_SIZE) // 2 + GROUPS_PER_THREAD
+     OUTPUT_SIZE = (GROUP_SIZE * NUM_GROUPS) // 2 + NUM_GROUPS
+
+     # Find starting offsets for this thread. These are calculated before adjusting for padding.
+     input_start = pid * (GROUPS_PER_THREAD * GROUP_SIZE)
+     output_start = pid * OUTPUT_CHUNK_SIZE
+     exp_start = output_start + GROUP_SIZE // 2
+     # Initiate offset ranges used in kernel.
+     input_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE) + input_start
+     output_offset = tl.arange(0, GROUP_LOAD * (GROUP_SIZE // 2))
+     # Stochastic rounding loads chunks of random values.
+     if ROUNDING_MODE == 3:
+         rand_bits_offset = tl.arange(0, GROUP_LOAD) + pid * GROUPS_PER_THREAD
+     # Ceil rounding uses single values as a seed.
+     else:
+         rand_bits_offset = pid * GROUPS_PER_THREAD
+     # We need to shift output offsets to make space for shared exponent storage.
+     output_offset += output_offset // (GROUP_SIZE // 2) + output_start
+     # Now create offsets for writing the shared exponent.
+     exp_offset = tl.arange(0, GROUP_LOAD) * PACKED_GROUP_SIZE + exp_start
+
+     # Load and process blocks of values for this chunk.
+     for _k in range(0, tl.cdiv(GROUPS_PER_THREAD, GROUP_LOAD)):
+         # We need to make some adjustments to allow for padding.
+         pad_mask = (input_offset % (GROUPS_PER_ROW * GROUP_SIZE)) < K
+         if ROW_PADDING != 0:
+             # Shift the input to account for padding.
+             padded_input_offset = (
+                 input_offset
+                 - (input_offset // (GROUPS_PER_ROW * GROUP_SIZE)) * ROW_PADDING
+             )
+         # When there's no padding we can simplify indexing.
+         else:
+             padded_input_offset = input_offset
+
+         # Load a block of values.
+         a = tl.load(
+             A + padded_input_offset,
+             # Mask values out of range for both the main array and this chunk. Also pad if needed.
+             mask=(padded_input_offset < (M * K))
+             & (padded_input_offset < ((pid + 1) * GROUPS_PER_THREAD * GROUP_SIZE))
+             & pad_mask,
+             other=0,
+         )
+
+         # Scaling step
+         ##############
+
+         # View the block in terms of groups.
+         a_groups = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE])
+         # Compute the shared exponent of each group.
+         group_max = tl.max(tl.abs(a_groups), axis=1)
+         # Prevent infinite values in log.
+         group_max = tl.where(group_max == 0, FP32_MIN_NORMAL, group_max)
+         # Load relevant random values if doing stochastic rounding
+         # or stochastic casting.
+         group_rand_bits = None
+         if (ROUNDING_MODE) == 3 or STOCHASTIC_CASTING:
+             group_rand_bits = tl.load(
+                 rand_bits + rand_bits_offset,
+                 mask=rand_bits_offset < K // GROUP_SIZE,
+                 other=0,
+             )
+             rand_bits_offset += GROUP_LOAD
+         # Compute shared exponent using specified rounding mode.
+         group_exp = _compute_exp(group_max, ROUNDING_MODE, group_rand_bits, MBITS)
+         # Subtract largest exponent in target datatype and remove bias.
+         group_exp = group_exp - EBITS
+         # Make sure exponent is in valid range.
+         group_exp = tl.clamp(group_exp, -127, 125)
+
+         # Next we scale A in preparation for quantization.
+         scale = tl.exp2(group_exp.to(tl.float64)).to(tl.float32)
+         # Apply scale to input. We do this by broadcasting scale.
+         scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) / tl.reshape(
+             scale, [GROUP_LOAD, 1]
+         )
+         # Reshape back to a flat array.
+         scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
+
+         # We're done with group_exp now so we can write it out.
+         # We re-add fp32_exp_bias for compatibility with cuda dequant.
+         tl.store(
+             out + exp_offset,
+             (group_exp + FP32_EXP_BIAS).to(tl.uint8),
+             # Prevent writing outside this chunk or the main array.
+             mask=(exp_offset < OUTPUT_SIZE)
+             & (exp_offset < (OUTPUT_CHUNK_SIZE * (pid + 1))),
+         )
+
+         # Quantization step
+         ###################
+
+         # During quantization, we're going to be doing a lot of bitwise operations.
+         # This is easier to work with in int32.
+         scaled_a = scaled_a.to(tl.int32, bitcast=True)
+
+         # When doing stochastic downcasting, generate random values for this block
+         # and apply it to the mantissa.
+         if STOCHASTIC_CASTING:
+             # We're going to generate 4 blocks at once so we only need
+             # one fourth of the input offsets.
+             # Start by splitting down to half of offsets.
+             philox_4x_offset = tl.split(
+                 tl.reshape(
+                     input_offset,
+                     [GROUP_LOAD * GROUP_SIZE // 2, 2],
+                     can_reorder=True,
+                 )
+             )
+             # Split down to fourth.
+             philox_4x_offset = tl.split(
+                 tl.reshape(
+                     philox_4x_offset,
+                     [GROUP_LOAD * GROUP_SIZE // 4, 2],
+                     can_reorder=True,
+                 )
+             )
+             # Generate 4 blocks of random bits for this block.
+             a_4x, b_4x, c_4x, d_4x = tl.randint4x(
+                 group_rand_bits, philox_4x_offset, n_rounds=7
+             )
+             # Combine the 4 blocks into a single chunk of random values.
+             # This needs to be done incrementally.
+             stochastic_round_bits = tl.join(tl.join(a_4x, b_4x), tl.join(c_4x, d_4x))
+             # Flatten back to simple array.
+             stochastic_round_bits = tl.reshape(
+                 stochastic_round_bits, [GROUP_LOAD * GROUP_SIZE]
+             ).to(tl.int32, bitcast=True)
+
+             # Mask off mantissa bits of random value and add to mantissa.
+             scaled_a = scaled_a + (stochastic_round_bits & RAND_MASK)
+
+         # Extract sign bit of value.
+         sign_bit = (scaled_a >> FP32_SIGN_OFFSET) & SIGN_MASK
+
+         # Extract exponent.
+         biased_exp = (scaled_a & FP32_EXP_MASK) >> FP32_EXP_OFFSET
+
+         # Extract mantissa.
+         trailing_mantissa = scaled_a & FP32_MANTISSA_MASK
+
+         # Adjust exponent bias for FP4.
+         new_biased_exp = biased_exp - FP32_EXP_BIAS + FP4_EXP_BIAS
+
+         # Compute difference between ideal exponent and what fp4 can represent.
+         exp_diff = tl.where(new_biased_exp <= 0, 1 - new_biased_exp, 0)
+
+         # Clip this difference to maximum number of fp32 mantissa bits.
+         exp_diff = tl.minimum(exp_diff, MAX_FP32_MANTISSA_BITS)
+
+         # Now we round our fp32 mantissa down to fp4.
+         is_subnorm = biased_exp == 0
+         # Add implied 1 bit to normal values.
+         mantissa = tl.where(
+             is_subnorm, trailing_mantissa, trailing_mantissa + IMPLIED_1_BIT
+         )
+         # Compute base number of bits corresponding to the mantissa, smaller for subnorms
+         # since implied one is included in exp_diff.
+         fp32_sig_bits = tl.where(is_subnorm, 23, 24).to(tl.int32)
+         # Now we're ready to shift down to target bitwidth (with an extra bit for rounding).
+         mantissa = mantissa >> (fp32_sig_bits + exp_diff - MBITS_IMPLICIT - 1)
+         # Perform rounding by adding 1 and shifting down.
+         mantissa = (mantissa + 1) >> 1
+
+         # Check for overflow and adjust exponent accordingly.
+         overflow = mantissa > MANTISSA_OVERFLOW_THRESHOLD
+         # Allow subnorms to overflow into normals, otherwise shift away overflow.
+         mantissa = tl.where(overflow and (not is_subnorm), mantissa >> 1, mantissa)
+         # Special case where a value is subnormal and has a large mantissa, overflow it.
+         new_biased_exp = tl.where(
+             (new_biased_exp <= 0) and (mantissa == 2), 1, new_biased_exp
+         )
+         # Remove implicit 1.
+         mantissa = mantissa & IMPLICIT_1_MASK
+         # Add overflow to exponent.
+         new_biased_exp = tl.where(overflow, new_biased_exp + 1, new_biased_exp)
+         # If exp overflows, set mantissa to maximum value (equivalent to clamping).
+         mantissa = tl.where(new_biased_exp > EXPONENT_OVERFLOW_THRESHOLD, 1, mantissa)
+
+         # Construct FP4 value from components.
+         new_biased_exp = tl.maximum(
+             tl.minimum(new_biased_exp, EXPONENT_OVERFLOW_THRESHOLD), 0
+         )
+         mx4_value = (new_biased_exp << (MBITS_IMPLICIT - 1)) | mantissa
+         mx4_value = (sign_bit << (EBITS + MBITS)) | mx4_value
+
+         # Extract low and high bits from values.
+         low_mx4, high_mx4 = tl.split(
+             tl.reshape(mx4_value, [(GROUP_LOAD * GROUP_SIZE) // 2, 2])
+         )
+         # Shift mx4 values together so they are packed into int8.
+         packed_mx4 = ((high_mx4 << 4) | (low_mx4)).to(tl.int8)
+
+         # Write out packed values to output tensor.
+         tl.store(
+             out + output_offset,
+             packed_mx4,
+             # Prevent writing outside this chunk or the main array.
+             mask=(output_offset < OUTPUT_SIZE)
+             & (output_offset < (OUTPUT_CHUNK_SIZE * (pid + 1))),
+         )
+
+         # Update offsets so we work on the next block.
+         input_offset += GROUP_LOAD * GROUP_SIZE
+         exp_offset += GROUP_LOAD * PACKED_GROUP_SIZE
+         output_offset += GROUP_LOAD * PACKED_GROUP_SIZE
+
+
+ def triton_quantize_mx4(
+     a: torch.Tensor,
+     group_size: int = 32,
+     ebits: int = 2,
+     mbits: int = 1,
+     rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil,
+     stochastic_casting: bool = False,
+ ) -> torch.Tensor:
+     """
+     Quantize a tensor to mx4 format using efficient triton kernels.
+
+     Args:
+         a (Tensor): [M] higher precision input tensor.
+         group_size (int): Size of chunks that will use the same shared exponent.
+         ebits (int): Number of bits to use for exponent in target mx4 format.
+         mbits (int): Number of bits to use for mantissa in target mx4 format.
+         rounding_mode (Union[RoundingMode, int]): Which type of rounding to use
+             when calculating shared exponent. Defaults to RoundingMode.ceil.
+         stochastic_casting (bool): Whether to use stochastic casting.
+
+     Returns:
+         torch.Tensor: [M / 2 + M / group_size] mx4 scaled tensor packed into int8
+             with group exponents attached to each row.
+
+         eg.
+         Input with shape [1, 8192] will be quantized to [1, 4096 + 256] as
+         each int8 value contains two packed elements and each row has
+         256 groups of 32 values, each contributing one shared exponent byte.
+     """
+     # If given an empty shape, return an empty tensor.
+     if a.numel() == 0:
+         return torch.empty(a.shape, device=a.device, dtype=torch.uint8)
+     # Make sure input is contiguous in memory.
+     assert a.is_contiguous(), "Inputs to mx4 quantize must be contiguous in memory."
+
+     orig_shape = a.shape
+     # For simplicity, view input as a 2D array.
+     a = a.view(-1, a.shape[-1])
+     # Extract rows and columns.
+     M, K = a.shape
+     # In this kernel, we want each row to be divisible by group_size.
+     # If the rows are not, then we will pad them. Find the number of
+     # groups per row after padding.
+     groups_per_row = math.ceil(K / group_size)
+     num_groups = M * groups_per_row
+     # Find how many groups each thread should process. We do this
+     # by assuming that it is good to distribute work evenly over threads.
+     num_threads = math.ceil(math.sqrt(a.numel()))
+     # Data is loaded in chunks of GROUP_LOAD groups, so there's no reason
+     # to ever have fewer groups per thread than that.
+     GROUP_LOAD = 64
+     groups_per_thread = max(math.ceil(num_groups / num_threads), GROUP_LOAD)
+     # Determine how much padding, if any, is needed for each row.
+     if K % group_size != 0:
+         padding = group_size - (K % group_size)
+     else:
+         padding = 0
+
+     # Create output tensor.
+     out_elems = (num_groups * group_size) // 2 + num_groups
+     out = torch.empty([out_elems], device=a.device, dtype=torch.uint8)
+
+     # If using stochastic rounding, create random noise for each group.
+     # We use the same random bits as seeds when doing stochastic downcasting.
+     if rounding_mode == RoundingMode.stochastic or stochastic_casting:
+         # Each group will need a seed.
+         rand_bits = torch.randint(
+             low=0,
+             high=2**31 - 1,
+             size=(num_groups,),
+             dtype=torch.int32,
+             device=a.device,
+         )
+     else:
+         rand_bits = None
+
+     # Check if we need to use int64 for indexing.
+     use_int64 = num_threads * groups_per_thread * group_size > 2**31 - 1
+
+     # Invoke triton quantization kernel over rows.
+     grid = (num_threads,)
+     _kernel_quantize_mx4[grid](
+         a,
+         out,
+         rand_bits=rand_bits,
+         M=M,
+         K=K,
+         GROUPS_PER_ROW=groups_per_row,
+         GROUPS_PER_THREAD=groups_per_thread,
+         ROW_PADDING=padding,
+         # pyre-ignore[6]
+         GROUP_SIZE=group_size,
+         # pyre-ignore[6]
+         EBITS=ebits,
+         # pyre-ignore[6]
+         MBITS=mbits,
+         # pyre-ignore[6]
+         ROUNDING_MODE=rounding_mode,
+         # pyre-ignore[6]
+         STOCHASTIC_CASTING=stochastic_casting,
+         FP4_EXP_BIAS=get_mx4_exp_bias(ebits),
+         # pyre-ignore[6]
+         GROUP_LOAD=GROUP_LOAD,
+         # pyre-ignore[6]
+         USE_INT64=use_int64,
+     )
+     # Inputs are now fully quantized and ready to return.
+     # Try to return in the original shape if possible.
+     try:
+         output_shape = list(orig_shape[:-1]) + [-1]
+         return out.view(output_shape)
+     # If we can't, return as a flat array.
+     except RuntimeError:
+         return out.view(-1)
+
+
+ @triton.jit
+ def _kernel_dequantize_mx4(
+     A,
+     mx4_lookup_table,
+     out,
+     M,
+     GROUPS_PER_THREAD,
+     GROUP_SIZE: tl.constexpr,
+     GROUP_LOAD: tl.constexpr,
+     USE_INT64: tl.constexpr,
+ ) -> None:
+     """Dequantize a packed MX4 tensor and apply scaling.
+
+     Args:
+         A (Tensor): [M] MX4 tensor packed into int8, with the shared exponent
+             of each group stored alongside its packed values.
+         mx4_lookup_table (Tensor): Map from mx4 integer value to floating point.
+         out (Tensor): Output tensor that dequantized fp32 values are written to.
+         M (int): Total number of elements in input.
+         GROUPS_PER_THREAD (int): Number of groups each thread is responsible for.
+         GROUP_SIZE (int): Size of chunks that use the same shared exponent.
+         GROUP_LOAD (int): Number of groups to process simultaneously.
+         USE_INT64 (bool): Whether to use int64 for indexing.
+     """
+     # Define constants.
+     MX4_BIT_MASK: tl.constexpr = 0xF  # type: ignore[Incompatible variable type]
+     FP32_EXP_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]
+     PACKED_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 + 1  # type: ignore[Incompatible variable type]
+
+     # Get the current thread number.
+     pid = tl.program_id(0)
+     # For very large tensors, use int64 for indexing. This is slower but necessary.
+     if USE_INT64:
+         pid = pid.to(tl.int64)
+         M = tl.cast(M, tl.int64)
+         GROUPS_PER_THREAD = tl.cast(GROUPS_PER_THREAD, tl.int64)
+
+     # Boundaries for reading input and writing to output tensor.
+     INPUT_CHUNK_SIZE = GROUPS_PER_THREAD * PACKED_GROUP_SIZE
+     OUTPUT_CHUNK_SIZE = GROUPS_PER_THREAD * GROUP_SIZE
+     OUTPUT_SIZE = (M // PACKED_GROUP_SIZE) * GROUP_SIZE
+
+     # Find the starting offsets for this thread.
+     input_start = pid * (GROUPS_PER_THREAD * PACKED_GROUP_SIZE)
+     exp_start = input_start + GROUP_SIZE // 2
+     # Remove shared exponents from output offset.
+     output_start = pid * OUTPUT_CHUNK_SIZE
+     # Initiate offset ranges used in this thread.
+     # This is a little complicated because we need to skip one value (the shared exponent)
+     # every group_size elements.
+     input_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE // 2)
+     # Add 1 every GROUP_SIZE / 2 steps so we skip shared exponent.
+     exp_indices = input_offset // (GROUP_SIZE // 2)
+     input_offset = input_offset + exp_indices + input_start
+     # We need to space out each group of the input by 1 since that's the shared exp.
+     output_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE) + output_start
+     # Stride exponent access across packed groups.
+     exp_offset = exp_indices * PACKED_GROUP_SIZE + exp_start
+
+     # Iterate over input tensor and unpack mx4 values.
+     for _k in range(0, tl.cdiv(GROUPS_PER_THREAD, GROUP_LOAD)):
+         a = tl.load(
+             A + input_offset,
+             # Mask values that are out of this chunk or the main array.
+             mask=(input_offset < M) & (input_offset < (INPUT_CHUNK_SIZE * (pid + 1))),
+             other=0.0,
+         )
+         # Extract high and low values from loaded mx4 tile.
+         low_mx4 = a & MX4_BIT_MASK
+         high_mx4 = (a >> 4) & MX4_BIT_MASK
+
+         # Get equivalent fp32 values.
+         low_fp32 = tl.load(mx4_lookup_table + low_mx4)
+         high_fp32 = tl.load(mx4_lookup_table + high_mx4)
+
+         # Get proper shared exponent and convert it to a float scale.
+         exp = tl.load(
+             A + exp_offset,
+             mask=(exp_offset < M) & (exp_offset < (INPUT_CHUNK_SIZE * (pid + 1))),
+             other=0.0,
+         )
+         # Remove fp32 exponent bias.
+         exp = exp.to(tl.int16) - FP32_EXP_BIAS
+
+         # Convert exponent to scale and apply to input.
+         # Requires higher precision to avoid rounding out small values.
+         # This might be slow so we should consider just letting them round away.
+         scale = tl.exp2(exp.to(tl.float64)).to(tl.float32)
+         scaled_low_fp32 = scale * low_fp32
+         scaled_high_fp32 = scale * high_fp32
+
+         # Combine the two components into a single tensor, interweave them.
+         scaled_fp32 = tl.interleave(scaled_low_fp32, scaled_high_fp32)
+
+         # Write final outputs.
+         tl.store(
+             out + output_offset,
+             scaled_fp32,
+             # Mask values that are out of this chunk or the main array.
+             mask=(output_offset < OUTPUT_SIZE)
+             & (output_offset < OUTPUT_CHUNK_SIZE * (pid + 1)),
+         )
+
+         # Update indices for next group.
+         input_offset += GROUP_LOAD * PACKED_GROUP_SIZE
+         exp_offset += GROUP_LOAD * PACKED_GROUP_SIZE
+         output_offset += GROUP_LOAD * GROUP_SIZE
+
+
+ def triton_dequantize_mx4(
+     a: torch.Tensor, group_size: int = 32, ebits: int = 2, mbits: int = 1
+ ) -> torch.Tensor:
+     """
+     Dequantize a tensor from mx4 format to fp32.
+
+     Args:
+         a (Tensor): [M / 2 + M / group_size] MX4 tensor packed into int8 values
+             with group exponents attached to end of each row.
+         group_size (int): Size of chunks that use the same shared exponent.
+         ebits (int): Number of bits to use for exponent in target mx4 format.
+         mbits (int): Number of bits to use for mantissa in target mx4 format.
+
+     Returns:
+         torch.Tensor: [M, K] dequantized fp32 tensor.
+     """
+     # If given an empty shape, return an empty tensor.
+     if a.numel() == 0:
+         return torch.empty(a.shape, device=a.device, dtype=torch.float32)
+     # Record the original shape and flatten the input for simplicity.
+     orig_shape = a.shape
+     a = a.flatten()
+     # Find number of groups.
+     packed_group_size = group_size // 2 + 1
+     num_groups = a.numel() // packed_group_size
+     # Find a workload that distributes work evenly over threads.
+     num_threads = math.ceil(math.sqrt(a.numel()))
+     # There is no need to ever have fewer groups per thread than the amount
+     # loaded at once.
+     GROUP_LOAD = 64
+     groups_per_thread = max(math.ceil(num_groups / num_threads), GROUP_LOAD)
+
+     # Use a lookup table to convert mx4 values to their float equivalents.
+     mx4_to_fp_values = get_mx4_lookup_table(ebits, mbits, a.device)
+
+     # Create output tensor.
+     output_elems = num_groups * group_size
+     out = torch.empty([output_elems], device=a.device, dtype=torch.float)
+     # Check if we need to use int64 for indexing.
+     use_int64 = num_threads * groups_per_thread * group_size > 2**31 - 1
+     # Invoke triton dequantization kernel over rows.
+     grid = (num_threads,)
+     _kernel_dequantize_mx4[grid](
+         a,
+         mx4_to_fp_values,
+         out,
+         a.numel(),
+         GROUPS_PER_THREAD=groups_per_thread,
+         # pyre-ignore[6]
+         GROUP_SIZE=group_size,
+         # pyre-ignore[6]
+         GROUP_LOAD=GROUP_LOAD,
+         # pyre-ignore[6]
+         USE_INT64=use_int64,
+     )
+
+     out_shape = list(orig_shape[:-1]) + [-1]
+     return out.view(out_shape)
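
As a usage reference (not part of the packaged file above): a minimal round-trip sketch of the two public entry points defined in this module, triton_quantize_mx4 and triton_dequantize_mx4, assuming the wheel is installed on a machine with a CUDA GPU and a working Triton install. The size arithmetic follows the docstrings: 8192 fp32 values with group_size=32 pack into 8192 / 2 = 4096 data bytes plus 8192 / 32 = 256 shared-exponent bytes per row.

import torch

# Import path mirrors this wheel's layout (fbgemm_gpu/triton/quantize.py).
from fbgemm_gpu.triton.quantize import triton_dequantize_mx4, triton_quantize_mx4

x = torch.randn(1, 8192, device="cuda", dtype=torch.float32)

# Pack pairs of fp4 values into uint8 and append one shared exponent byte per group of 32.
packed = triton_quantize_mx4(x, group_size=32)
assert packed.shape == (1, 8192 // 2 + 8192 // 32)  # [1, 4352], dtype uint8

# Reverse the packing; the result matches x up to mx4 quantization error.
recovered = triton_dequantize_mx4(packed, group_size=32)
assert recovered.shape == (1, 8192)
print((x - recovered).abs().max())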