fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
@@ -0,0 +1,707 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import itertools
8
+ import os
9
+ import sys
10
+
11
+ from dataclasses import dataclass
12
+ from datetime import datetime
13
+ from typing import Any, Optional
14
+
15
+ import click
16
+
17
+ import matplotlib.pyplot as plt
18
+ import numpy as np
19
+
20
+ import pandas as pd
21
+ import seaborn as sns
22
+ import torch
23
+ from tabulate import tabulate
24
+
25
try:
    from accelerators.utils.torch_profiler import profiler_or_nullcontext
except ImportError:
    # The optional profiler helper is not installed: substitute a context
    # manager that accepts the same call signature but does nothing.
    from contextlib import nullcontext

    class profiler_or_nullcontext(nullcontext):
        """No-op stand-in for the optional profiler context manager.

        Every positional and keyword argument (e.g. ``enabled``,
        ``with_stack``) is accepted and discarded; entering the context
        yields ``None``.
        """

        def __init__(self, *_ignored_args, **_ignored_kwargs):
            super().__init__(None)
33
+
34
+
35
+ from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import (
36
+ get_quantize_ops,
37
+ QuantizeOpBase,
38
+ )
39
+
40
+
41
def generate_group_tensor(G, M):
    """
    Randomly split the integer M into G non-negative integer parts.

    Args:
        G (int): Number of parts (groups). Must be at least 1.
        M (int): Total that the parts must sum to.

    Returns:
        list[int]: G non-negative integers whose sum is exactly M.

    Raises:
        ValueError: If G is less than 1.
    """
    if G < 1:
        raise ValueError(f"G must be at least 1, got {G}.")

    # Random proportions, normalized to sum to 1 and scaled to M. float64
    # keeps the accumulated rounding error far below one unit for large M.
    weights = torch.rand(G, dtype=torch.float64)
    scaled = weights / weights.sum() * M

    # Floor every share, then give the leftover units to the shares with the
    # largest fractional parts (largest-remainder method). Rounding each share
    # independently can overshoot M, and only topping up the last element
    # cannot correct that; this scheme always sums to exactly M and never
    # yields a negative part.
    parts = torch.floor(scaled).to(torch.int64)
    shortfall = M - int(parts.sum().item())
    if shortfall > 0:
        order = torch.argsort(scaled - parts, descending=True)
        parts[order[:shortfall]] += 1
    return parts.tolist()
62
+
63
+
64
def set_amd_env_vars() -> None:
    """Export environment variables used when benchmarking on AMD GPUs.

    Enables PyTorch TunableOp (with tuning turned on and bounded tuning /
    warmup durations) plus a couple of HIP-specific switches. The variables
    are written into this process's ``os.environ``.
    """
    print("Setting environment variables for AMD GPU performance")
    amd_settings = {
        "DISABLE_ADDMM_HIP_LT": "0",
        "HIP_FORCE_DEV_KERNARG": "1",
        "PYTORCH_TUNABLEOP_VERBOSE": "0",
        "PYTORCH_TUNABLEOP_ENABLED": "1",
        "PYTORCH_TUNABLEOP_TUNING": "1",
        "PYTORCH_TUNABLEOP_FILENAME": "hipblas_tuning_pt_llama.csv",
        "PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS": "30",
        "PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS": "30",
    }
    os.environ.update(amd_settings)
74
+
75
+
76
def get_llama_shapes() -> list[tuple[int, int, int, int]]:
    """Return (B, M, N, K) GEMM shapes relevant to llama workloads.

    B is always 1. For every M in a fixed sweep, the same twelve (N, K)
    pairs are emitted in order: llama3 70B, then llama3 405B, then llama4
    Scout/Maverick (17Bx{16,128}).
    """
    llama3_70b_nk = [(1280, 8192), (8192, 1024), (7168, 8192), (8192, 3584)]
    llama3_405b_nk = [(13312, 6656), (13312, 16384), (16384, 6656), (16384, 16384)]
    llama4_nk = [(896, 5120), (5120, 640), (2048, 5120), (5120, 1024)]
    nk_pairs = llama3_70b_nk + llama3_405b_nk + llama4_nk

    m_sweep = (1, 16, 32, 64, 96, 128, 16384)
    return [
        (1, m_dim, n_dim, k_dim)
        for m_dim in m_sweep
        for n_dim, k_dim in nk_pairs
    ]
104
+
105
+
106
def get_ldm_shapes() -> list[tuple[int, int, int, int]]:
    """Return fixed (B, M, N, K) GEMM shapes relevant to ldm workloads.

    B is always 1; the (M, N, K) triples are listed in benchmark order.
    """
    mnk_triples = [
        (1536, 3584, 3584),
        (8192, 9728, 3584),
        (8192, 3584, 9728),
        (8192, 3584, 3584),
        (4096, 3584, 3584),
        (768, 3584, 3584),
        (4096, 9728, 3584),
        (4096, 3584, 9728),
        (7200, 3584, 3584),
        (7200, 9728, 3584),
        (7200, 3584, 9728),
        (3600, 3584, 3584),
        (3600, 9728, 3584),
        (3600, 3584, 9728),
        (1536, 4096, 4096),
        (3600, 4096, 4096),
        (3600, 11008, 4096),
        (3600, 4096, 11008),
        (4096, 4096, 4096),
        (4096, 11008, 4096),
        (4096, 4096, 11008),
        (32768, 128, 8192),
        (32768, 8192, 1024),
        (32768, 8192, 3072),
        (32768, 3072, 8192),
        (32768, 1024, 8192),
    ]
    return [(1, *mnk) for mnk in mnk_triples]
136
+
137
+
138
@dataclass
class Metrics:
    """Measurements collected for one quantize op during a benchmark run."""

    # Name of the op; used as the prefix of every printed line.
    op_name: str

    # Squared-error comparison of the op output against the bf16 reference.
    sim: float = 0.0
    # Runtime in milliseconds.
    ms: float = 0.0
    # Achieved throughput in TFLOPS.
    tflops: float = 0.0
    # Achieved memory bandwidth in GB/s.
    gbps: float = 0.0

    def __str__(self) -> str:
        name = self.op_name
        return (
            f"{name} sim: {self.sim:.3f}.\n"
            f"{name} ms: {self.ms:.3f}. \n"
            f"{name} TFLOPS: {self.tflops:.3f}. \n"
            f"{name} GB/s: {self.gbps:.3f}."
        )
160
+
161
+
162
def benchmark_grouped(
    quantize_ops: list[QuantizeOpBase],
    b: list[int],
    m: list[int],
    n: list[int],
    k: list[int],
    bench_quantize: bool = False,
    use_rotating_buffer_bench: bool = False,
    use_cuda_graph: bool = True,
    trace: bool = False,
    num_iters: int = 1,
    fast_accum: bool = True,
    torch_compile: bool = False,
) -> dict[str, Any]:
    """Benchmark every quantize op on a single grouped GEMM.

    Group ``i`` multiplies an ``(m[i], k[i])`` bf16 activation by the
    transpose of an ``(n[i], k[i])`` bf16 weight. When ``b[i] > 1`` both
    operands get a leading batch dimension of size ``b[i]``.

    Args:
        quantize_ops: Operators to benchmark.
        b, m, n, k: Per-group batch / M / N / K sizes (all the same length).
        bench_quantize: If true, time quantization together with compute.
        use_rotating_buffer_bench: Forwarded to ``quantize_op.benchmark``.
        use_cuda_graph: Forwarded to ``quantize_op.benchmark``.
        trace: If true, wrap each timed run in the profiler context.
        num_iters: Number of timed repetitions; reported metrics are averages.
        fast_accum: Applied to ops exposing a ``fast_accum`` attribute.
        torch_compile: Applied to ops exposing a ``torch_compile`` attribute.

    Returns:
        A dict with the shape description ("M", "N", "K", "groups"; a single
        value when all groups share it, else the full list) plus, per op,
        ``<name>_sim``, ``<name>_ms``, ``<name>_tflops`` and ``<name>_gb/s``.
        ``sim`` is the sum of per-group mean squared errors against a bf16
        ``torch.matmul`` reference.
    """
    num_groups = len(m)
    # Create input tensors; batched groups get a leading batch dimension.
    A = []
    B = []
    for i in range(num_groups):
        a_shape = (b[i], m[i], k[i]) if b[i] > 1 else (m[i], k[i])
        w_shape = (b[i], n[i], k[i]) if b[i] > 1 else (n[i], k[i])
        A.append(torch.randn(*a_shape, device="cuda", dtype=torch.bfloat16))
        B.append(torch.randn(*w_shape, device="cuda", dtype=torch.bfloat16))
    # Baseline output for correctness checking. transpose(-2, -1) (rather than
    # .t(), which rejects >2-D tensors) keeps batched groups working.
    out_ref = [torch.matmul(a, w.transpose(-2, -1)) for a, w in zip(A, B)]

    # Only log every shape in a group if they are not all identical.
    log_m = m[0] if len(np.unique(m)) == 1 else m
    log_n = n[0] if len(np.unique(n)) == 1 else n
    log_k = k[0] if len(np.unique(k)) == 1 else k
    results: dict[str, Any] = {"M": log_m, "N": log_n, "K": log_k, "groups": num_groups}

    for quantize_op in quantize_ops:
        metrics = Metrics(op_name=quantize_op.name)
        # Set fast accum / torch.compile modes if applicable.
        if hasattr(quantize_op, "fast_accum"):
            quantize_op.fast_accum = fast_accum
        if hasattr(quantize_op, "torch_compile"):
            quantize_op.torch_compile = torch_compile

        preprocessed_args = quantize_op.preprocess(A, B)
        quantized_vals = quantize_op.quantize(*preprocessed_args)
        output = quantize_op.compute(*quantized_vals)
        if isinstance(output, torch.Tensor) and output.ndim == 2:
            # Output is stacked along M and needs to be split per group.
            output = torch.split(output, m, dim=0)
        else:
            # Per-group outputs may be padded along M. M is dim -2, so this
            # trims rows for both 2-D and batched (3-D) group outputs.
            output = [o[..., : m[i], :] for i, o in enumerate(output)]

        # Compare the quantize op output to reference as a sanity check.
        for i in range(num_groups):
            if m[i] > 0:
                metrics.sim += float(
                    torch.mean(torch.pow(output[i] - out_ref[i], 2)).item()
                )

        # Time either quantize+compute or compute alone.
        timed_args = preprocessed_args if bench_quantize else quantized_vals
        # Ops with fused scatter-add read and write the output, doubling its traffic.
        output_multiplier = 2 if "fuse_scatter_add" in quantize_op.name else 1
        for _ in range(num_iters):
            with profiler_or_nullcontext(enabled=trace, with_stack=True):
                ms_runtime = quantize_op.benchmark(
                    *timed_args,
                    bench_quantize=bool(bench_quantize),
                    use_rotating_buffer_bench=use_rotating_buffer_bench,
                    use_cuda_graph=use_cuda_graph,
                )

            seconds = ms_runtime / 1e3
            for i in range(num_groups):
                metrics.tflops += 2 * b[i] * m[i] * n[i] * k[i] / seconds / 1e12
                if m[i] > 0:
                    bytes_moved = (
                        b[i] * m[i] * k[i] * quantized_vals[0][0].element_size()
                        + b[i] * n[i] * k[i] * quantized_vals[1][0].element_size()
                        + output_multiplier
                        * b[i]
                        * m[i]
                        * n[i]
                        * output[0].element_size()
                    )
                    metrics.gbps += bytes_moved / seconds / 1e9
            metrics.ms += ms_runtime

        # Average the accumulated measurements over all iterations.
        metrics.ms /= num_iters
        metrics.tflops /= num_iters
        metrics.gbps /= num_iters
        print(f"Average metrics over {num_iters} iterations:")
        print(metrics)

        # Save results for this operator.
        results[f"{quantize_op.name}_sim"] = metrics.sim
        results[f"{quantize_op.name}_ms"] = metrics.ms
        results[f"{quantize_op.name}_tflops"] = metrics.tflops
        results[f"{quantize_op.name}_gb/s"] = metrics.gbps

    return results
277
+
278
+
279
def benchmark(
    quantize_ops: list[QuantizeOpBase],
    b: int,
    m: int,
    n: int,
    k: int,
    bench_quantize: bool = False,
    use_rotating_buffer_bench: bool = False,
    use_cuda_graph: bool = True,
    trace: bool = False,
    num_iters: int = 1,
    fast_accum: bool = True,
    torch_compile: bool = False,
) -> dict[str, Any]:
    """Benchmark every quantize op on a single (optionally batched) GEMM.

    Multiplies a random bf16 ``(m, k)`` activation by the transpose of a
    random bf16 ``(n, k)`` weight. When ``b > 1`` both operands carry a
    leading batch dimension of size ``b``.

    Returns:
        A dict with keys "B", "M", "N", "K" plus, per op, ``<name>_sim``
        (mean squared error vs. ``torch.matmul``), ``<name>_ms``,
        ``<name>_tflops`` and ``<name>_gb/s``, averaged over ``num_iters``.
    """
    a_shape = (b, m, k) if b > 1 else (m, k)
    w_shape = (b, n, k) if b > 1 else (n, k)
    A = torch.randn(*a_shape, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(*w_shape, device="cuda", dtype=torch.bfloat16)

    # Reference output for the correctness check.
    out_ref = torch.matmul(A, torch.transpose(B, -2, -1))
    results: dict[str, Any] = {"B": b, "M": m, "N": n, "K": k}

    for quantize_op in quantize_ops:
        metrics = Metrics(op_name=quantize_op.name)
        # Configure optional op modes.
        if hasattr(quantize_op, "fast_accum"):
            quantize_op.fast_accum = fast_accum
        if hasattr(quantize_op, "torch_compile"):
            quantize_op.torch_compile = torch_compile

        preprocessed_args = quantize_op.preprocess(A, B)
        quantized_vals = quantize_op.quantize(*preprocessed_args)
        output = quantize_op.compute(*quantized_vals)
        # TODO(shikaili): This calculation is incorrect for scatter add fusion.
        metrics.sim = torch.mean(torch.pow(output - out_ref, 2)).item()

        # Including quantization means timing from the preprocessed inputs.
        timed_args = preprocessed_args if bench_quantize else quantized_vals
        for _ in range(num_iters):
            with profiler_or_nullcontext(enabled=trace, with_stack=True):
                ms_runtime = quantize_op.benchmark(
                    *timed_args,
                    bench_quantize=bool(bench_quantize),
                    use_rotating_buffer_bench=use_rotating_buffer_bench,
                    use_cuda_graph=use_cuda_graph,
                )

            elapsed_s = ms_runtime / 1e3
            metrics.tflops += 2 * b * m * n * k / elapsed_s / 1e12
            bytes_moved = (
                quantized_vals[0].numel() * quantized_vals[0].element_size()
                + quantized_vals[1].numel() * quantized_vals[1].element_size()
                + output.numel() * output.element_size()
            )
            metrics.gbps += bytes_moved / elapsed_s / 1e9
            metrics.ms += ms_runtime

        metrics.ms /= num_iters
        metrics.tflops /= num_iters
        metrics.gbps /= num_iters
        print(f"Average metrics over {num_iters} iterations:")
        print(metrics)

        for suffix, value in (
            ("sim", metrics.sim),
            ("ms", metrics.ms),
            ("tflops", metrics.tflops),
            ("gb/s", metrics.gbps),
        ):
            results[f"{quantize_op.name}_{suffix}"] = value

    return results
369
+
370
+
371
def plot_benchmark(results: list[dict[str, Any]], output_dir: str) -> None:
    """Create a barplot visualizing the TFLOPS of each kernel.

    Args:
        results: One dict per benchmarked shape, as returned by ``benchmark``
            or ``benchmark_grouped``. Keys ending in ``_tflops`` are plotted.
        output_dir: Directory where ``quantize_ops_benchmark.png`` is written.
    """
    suffix = "_tflops"
    # Reshape into long format: one row per (shape, kernel) pair.
    data = []
    for impl in results:
        mnk = f"{impl['M']}, {impl['N']}, {impl['K']}"
        for key, value in impl.items():
            # Match only the suffix so op names containing "tflops" are not
            # truncated or mistaken for throughput entries.
            if key.endswith(suffix):
                data.append(
                    {"MNK": mnk, "kernel": key[: -len(suffix)], "TFLOPS": value}
                )

    df = pd.DataFrame(data)
    fig = plt.figure()
    try:
        plt.xticks(rotation=30)
        plt.yscale("log")
        ax = sns.barplot(x="MNK", y="TFLOPS", hue="kernel", data=df)
        ax.tick_params(axis="x", labelsize=3)
        img_fn = os.path.join(output_dir, "quantize_ops_benchmark.png")
        fig.savefig(img_fn, dpi=300)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
    print(f"Plot saved to {img_fn}")
395
+
396
+
397
def collect_kernels_to_profile(kernels: Optional[list[str]]) -> list[QuantizeOpBase]:
    """Return the supported quantize ops, optionally filtered by name.

    Args:
        kernels: Names of the ops to keep, or ``None`` to keep every supported
            op. Surrounding whitespace in each name is ignored, so CLI input
            such as ``"a, b"`` matches both ``a`` and ``b``.

    Returns:
        The matching supported ops, in the order ``get_quantize_ops`` yields them.
    """
    supported_ops = [op for op in get_quantize_ops() if op.supported]
    if kernels is None:
        return supported_ops
    wanted = {name.strip() for name in kernels}
    return [op for op in supported_ops if op.name in wanted]
404
+
405
+
406
def print_kernels(kernels: Optional[list[str]]) -> None:
    """Print a table of every registered quantize op and its backend support.

    Args:
        kernels: Currently unused; kept so existing call sites keep working.
            The table always lists every op from ``get_quantize_ops`` (not only
            the ones supported on this machine), sorted by name.
    """
    rows = sorted(
        (op.name, "Yes" if op.cuda else "No", "Yes" if op.hip else "No")
        for op in get_quantize_ops()
    )
    print(tabulate(rows, headers=["Name", "CUDA", "ROCm"], tablefmt="orgtbl"))
414
+
415
+
416
def _parse_int_list(values: str) -> list[int]:
    """Parse a comma separated CLI value such as ``"1,2,4"`` into ints."""
    return [int(value) for value in values.strip().split(",")]


@click.command()
@click.option(
    "--output-dir",
    default="/tmp",
    help="Directory to save plots and csvs to",
)
@click.option(
    "--num-iters",
    default=1,
    type=int,
    help="Number of iterations to repeat each benchmark.",
)
@click.option(
    "--export-csv",
    is_flag=True,
    help="Export results to a CSV file.",
)
@click.option(
    "--plot",
    is_flag=True,
    help="Create a plot of the benchmark measurements.",
)
@click.option(
    "--enable-amd-env-vars",
    is_flag=True,
    help="Enable a set of environment variables for AMD GPU performance",
)
@click.option(
    "--bench-quantize",
    is_flag=True,
    help="If set, include quantization cost in benchmark.",
)
@click.option(
    "--kernels",
    default=None,
    help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
)
@click.option(
    "--B",
    default=None,
    help="Comma separated list of batches to benchmark.",
)
@click.option(
    "--M",
    default=None,
    help="Comma separated list of M values to benchmark.",
)
@click.option(
    "--N",
    default=None,
    help="Comma separated list of N values to benchmark",
)
@click.option(
    "--K",
    default=None,
    help="Comma separated list of K values to benchmark.",
)
@click.option(
    "--pair-NK",
    is_flag=True,
    help="If set, instead of benchmarking cartesian product of N * K, benchmark consecutive NK pairs together.",
)
@click.option(
    "--grouped",
    is_flag=True,
    help="If set, do grouped gemm. In this mode, M, N, and K are interpreted "
    "as the size of groups. The length of each must be the same.",
)
@click.option(
    "--groups",
    default=None,
    help="If set, repeat input shapes this many times as a grouped gemm "
    "(implies --grouped). Comma separated list of groups to benchmark",
)
@click.option(
    "--total-K",
    default=None,
    help="If set, adjusts the K values to sum to this number. "
    "This can help simulate real grouped workloads in backward wgrad. "
    "Comma separated list of total-K values to benchmark.",
)
@click.option(
    "--total-M",
    default=None,
    help="If set, adjusts the M values to sum to this number. "
    "This can help simulate real grouped workloads. "
    "Comma separated list of total-M values to benchmark.",
)
@click.option(
    "--no-cuda-graph",
    is_flag=True,
    help="If set, do not use cuda graph for benchmarking.",
)
@click.option(
    "--use-rotating-buffer-bench",
    is_flag=True,
    help="If set, use rotating buffer to benchmark.",
)
@click.option(
    "--use-llama-shapes",
    is_flag=True,
    help="If set, benchmark using fixed shapes relevant to llama workloads.",
)
@click.option(
    "--use-ldm-shapes",
    is_flag=True,
    help="If set, benchmark using fixed shapes relevant to ldm workloads.",
)
@click.option(
    "--trace",
    is_flag=True,
    help="If set, produce a performance trace of the benchmark.",
)
@click.option(
    "--disable-fast-accum",
    is_flag=True,
    help="If set, disable fast accumulation for FP8 implementations.",
)
@click.option(
    "--torch-compile",
    is_flag=True,
    help="If set, torch.compile will be used for scaled_mm backed ops.",
)
def invoke_main(
    output_dir: str,
    num_iters: int,
    export_csv: bool,
    plot: bool,
    enable_amd_env_vars: bool,
    bench_quantize: bool,
    kernels: Optional[str],
    b: Optional[str],
    m: Optional[str],
    n: Optional[str],
    k: Optional[str],
    pair_nk: bool,
    grouped: bool,
    groups: Optional[str],
    total_k: Optional[str],
    total_m: Optional[str],
    no_cuda_graph: bool,
    use_rotating_buffer_bench: bool,
    use_llama_shapes: bool,
    use_ldm_shapes: bool,
    trace: bool,
    disable_fast_accum: bool,
    torch_compile: bool,
):
    """Benchmark quantized GEMM kernels over a set of shapes."""
    if enable_amd_env_vars:
        set_amd_env_vars()

    # Both options rewrite the group sizes, so only one may be given.
    if total_m is not None and total_k is not None:
        raise ValueError(
            "total_m and total_k cannot be specified at the same time. "
            "Please provide only one of them."
        )

    # If kernel filter is provided, parse it. Else, benchmark all kernels.
    all_kernels = kernels.strip().split(",") if kernels else None
    quantize_ops = collect_kernels_to_profile(all_kernels)
    if not quantize_ops:
        print("No valid kernels to benchmark. Available kernels:")
        print_kernels(all_kernels)
        sys.exit(1)

    if num_iters < 1:
        print("Warning: Number of iterations must be at least 1.")
        num_iters = 1

    # --groups turns every enumerated shape into a grouped gemm, so it must be
    # dispatched to the grouped benchmark even without --grouped.
    use_grouped = grouped or bool(groups)

    # Enumerate shapes to benchmark.
    if grouped and not groups:
        # In grouped mode, M, N, and K represent the groups of a single gemm.
        if m is None or n is None or k is None:
            raise ValueError("M, N, and K must be provided in grouped mode.")
        M = _parse_int_list(m)
        N = _parse_int_list(n)
        K = _parse_int_list(k)
        B = [1] * len(M) if b is None else _parse_int_list(b)
        if not (len(M) == len(N) == len(K) == len(B)):
            raise ValueError("B, M, N, and K must be the same length in grouped mode.")
        # Note this is a single grouped gemm.
        MNK = [[B, M, N, K]]
    else:
        B = [1] if b is None else _parse_int_list(b)
        if use_llama_shapes:
            MNK = get_llama_shapes()
        elif use_ldm_shapes:
            MNK = get_ldm_shapes()
        else:
            M = (
                [1, 4, 8, 16, 32, 64, 128, 2048, 4096, 8192, 16384]
                if m is None
                else _parse_int_list(m)
            )
            N = [1280, 2304, 7168, 8192, 16384] if n is None else _parse_int_list(n)
            K = [1024, 3584, 8192, 16384] if k is None else _parse_int_list(k)
            # List all shapes for simplicity.
            if pair_nk:
                if len(N) != len(K):
                    raise ValueError("N and K must be the same length in pair_NK mode.")
                MNK = [
                    (b_val, m_val, n_val, k_val)
                    for b_val, m_val, (n_val, k_val) in itertools.product(
                        B, M, zip(N, K)
                    )
                ]
            else:
                MNK = list(itertools.product(B, M, N, K))

        # When groups is provided transform shapes into grouped format.
        if groups:
            groups_list = _parse_int_list(groups)
            if total_m:
                total_m_list = _parse_int_list(total_m)
                MNK = [
                    [[b_val] * g, generate_group_tensor(g, tm), [n_val] * g, [k_val] * g]
                    for g in groups_list
                    for tm in total_m_list
                    for b_val, _, n_val, k_val in MNK
                ]
            elif total_k:
                total_k_list = _parse_int_list(total_k)
                MNK = [
                    [[b_val] * g, [m_val] * g, [n_val] * g, generate_group_tensor(g, tk)]
                    for g in groups_list
                    for tk in total_k_list
                    for b_val, m_val, n_val, _ in MNK
                ]
            else:
                MNK = [
                    [[b_val] * g, [m_val] * g, [n_val] * g, [k_val] * g]
                    for g in groups_list
                    for b_val, m_val, n_val, k_val in MNK
                ]

    # Iterate over shapes and benchmark.
    benchmark_func = benchmark_grouped if use_grouped else benchmark
    benchmark_results = []
    for b_val, m_val, n_val, k_val in MNK:
        print(f"Benchmarking B={b_val}, M={m_val}, N={n_val}, K={k_val}.")
        quantize_measurements = benchmark_func(
            quantize_ops,
            b_val,  # pyre-ignore[6]: Incompatible parameter type [6]
            m_val,  # pyre-ignore[6]: Incompatible parameter type [6]
            n_val,  # pyre-ignore[6]: Incompatible parameter type [6]
            k_val,  # pyre-ignore[6]: Incompatible parameter type [6]
            bench_quantize=bench_quantize,
            use_rotating_buffer_bench=use_rotating_buffer_bench,
            use_cuda_graph=not no_cuda_graph,
            trace=trace,
            num_iters=num_iters,
            fast_accum=not disable_fast_accum,
            torch_compile=torch_compile,
        )
        benchmark_results.append(quantize_measurements)

    if export_csv or plot:
        os.makedirs(output_dir, exist_ok=True)
    if export_csv:
        datetime_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_file = os.path.join(
            output_dir, f"quantize_ops_benchmark_{datetime_str}.csv"
        )
        # Export results to a CSV file, then report where it went.
        pd.DataFrame(benchmark_results).to_csv(csv_file, index=False)
        print(f"CSV saved to {csv_file}")
    if plot:
        plot_benchmark(benchmark_results, output_dir)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover