fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0

fbgemm_gpu/tbe/bench/benchmark_click_interface.py
@@ -0,0 +1,187 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import click
+
+ from fbgemm_gpu.split_embedding_configs import SparseType
+ from fbgemm_gpu.split_table_batched_embeddings_ops_common import BoundsCheckMode
+
+ from .bench_config import TBEBenchmarkingHelperText
+ from .tbe_data_config_loader import TBEDataConfigHelperText
+
+
+ class TbeBenchClickInterface:
+     @classmethod
+     # pyre-ignore [2]
+     def common_options(cls, func) -> click.Command:
+         options = [
+             click.option(
+                 "--alpha",
+                 default=1.0,
+                 help="The alpha value used for the benchmark, default is 1.0. Recommended value: alpha=1.15 for training and alpha=1.09 for inference",
+             ),
+             click.option(
+                 "--batch-size",
+                 default=512,
+                 help=TBEDataConfigHelperText.TBE_BATCH_SIZE.value + " Default is 512.",
+             ),
+             click.option(
+                 "--weights-precision",
+                 type=SparseType,
+                 default=SparseType.FP32,
+                 help="The precision type for weights, default is FP32.",
+             ),
+             click.option(
+                 "--stoc",
+                 is_flag=True,
+                 default=False,
+                 help="Flag to enable stochastic rounding, default is False.",
+             ),
+             click.option(
+                 "--iters",
+                 default=100,
+                 help=TBEBenchmarkingHelperText.BENCH_ITERATIONS.value
+                 + " Default is 100.",
+             ),
+             click.option(
+                 "--warmup-runs",
+                 default=0,
+                 help=(
+                     TBEBenchmarkingHelperText.BENCH_WARMUP_ITERATIONS.value
+                     + " Default is 0."
+                 ),
+             ),
+             click.option(  # Note: Original default for uvm benchmark is 0.1
+                 "--reuse",
+                 default=0.0,
+                 help="The inter-batch indices reuse rate for the benchmark, default is 0.0.",
+             ),
+             click.option(
+                 "--flush-gpu-cache-size-mb",
+                 default=0,
+                 help=TBEBenchmarkingHelperText.BENCH_FLUSH_GPU_CACHE_SIZE.value,
+             ),
+         ]
+
+         for option in reversed(options):
+             func = option(func)
+         return func
+
+     @classmethod
+     # pyre-ignore [2]
+     def table_options(cls, func) -> click.Command:
+         options = [
+             click.option(
+                 "--bag-size",
+                 default=20,
+                 help=TBEDataConfigHelperText.TBE_POOLING_SIZE.value + " Default is 20.",
+             ),
+             click.option(
+                 "--embedding-dim",
+                 default=128,
+                 help=TBEDataConfigHelperText.TBE_EMBEDDING_DIM.value
+                 + " Default is 128.",
+             ),
+             click.option(
+                 "--mixed",
+                 is_flag=True,
+                 default=False,
+                 help=TBEDataConfigHelperText.TBE_MIXED_DIM.value + " Default is False.",
+             ),
+             click.option(
+                 "--num-embeddings",
+                 default=int(1e5),
+                 help=TBEDataConfigHelperText.TBE_NUM_EMBEDDINGS.value
+                 + " Default is 1e5.",
+             ),
+             click.option(
+                 "--num-tables",
+                 default=32,
+                 help=TBEDataConfigHelperText.TBE_NUM_TABLES.value + " Default is 32.",
+             ),
+             click.option(
+                 "--tables",
+                 type=str,
+                 default=None,
+                 help="Comma-separated list of table numbers. Default is None.",
+             ),
+         ]
+
+         for option in reversed(options):
+             func = option(func)
+         return func
+
+     @classmethod
+     # pyre-ignore [2]
+     def device_options(cls, func) -> click.Command:
+         options = [
+             click.option(
+                 "--cache-precision",
+                 type=SparseType,
+                 default=None,
+                 help="The precision type for cache, default is None.",
+             ),
+             click.option(
+                 "--managed",
+                 type=click.Choice(
+                     ["device", "managed", "managed_caching"], case_sensitive=False
+                 ),
+                 default="device",
+                 help="The managed option for embedding location. Choices are 'device', 'managed', or 'managed_caching'. Default is 'device'.",
+             ),
+             click.option(
+                 "--row-wise/--no-row-wise",
+                 default=True,
+                 help="Flag to enable or disable row-wise optimization, default is enabled. Use --no-row-wise to disable.",
+             ),
+             click.option(
+                 "--weighted",
+                 is_flag=True,
+                 default=False,
+                 help=TBEDataConfigHelperText.TBE_WEIGHTED.value + " Default is False.",
+             ),
+             click.option(
+                 "--pooling",
+                 type=click.Choice(["sum", "mean", "none"], case_sensitive=False),
+                 default="sum",
+                 help="The pooling method to use. Choices are 'sum', 'mean', or 'none'. Default is 'sum'.",
+             ),
+             click.option(
+                 "--bounds-check-mode",
+                 type=int,
+                 default=BoundsCheckMode.NONE.value,
+                 help="The bounds check mode, default is NONE. Options are: FATAL (0) - Raise an exception (CPU) or device-side assert (CUDA), WARNING (1) - Log the first out-of-bounds instance per kernel, and set to zero, IGNORE (2) - Set to zero, NONE (3) - No bounds checks, V2_IGNORE (4) - IGNORE with V2 enabled, V2_WARNING (5) - WARNING with V2 enabled, V2_FATAL (6) - FATAL with V2 enabled.",
+             ),
+         ]
+
+         for option in reversed(options):
+             func = option(func)
+         return func
+
+     @classmethod
+     # pyre-ignore [2]
+     def vbe_options(cls, func) -> click.Command:
+         options = [
+             click.option(
+                 "--bag-size-list",
+                 type=str,
+                 default="20",
+                 help="A comma-separated list of bag sizes for each table, default is '20'.",
+             ),
+             click.option(
+                 "--bag-size-sigma-list",
+                 type=str,
+                 default="None",
+                 help="A comma-separated list of bag size standard deviations for generating bag sizes (one std per table). If set, the benchmark will treat --bag-size-list as a list of bag size means. Default is 'None'.",
+             ),
+         ]
+
+         for option in reversed(options):
+             func = option(func)
+         return func
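
For orientation, a minimal sketch of how these option groups might be stacked onto a benchmark command. The `device_bench` command name and its body are illustrative assumptions, not code from the wheel:

import click

from fbgemm_gpu.tbe.bench.benchmark_click_interface import TbeBenchClickInterface


@click.command()
@TbeBenchClickInterface.common_options
@TbeBenchClickInterface.table_options
@TbeBenchClickInterface.device_options
def device_bench(**kwargs) -> None:
    # Hypothetical command: each option group above injects its flags, so kwargs
    # carries alpha, batch_size, bag_size, managed, pooling, bounds_check_mode, etc.
    click.echo(f"batch_size={kwargs['batch_size']}, num_tables={kwargs['num_tables']}")


if __name__ == "__main__":
    device_bench()

Keeping `@click.command()` outermost means the grouped options are collected onto the final Command object, which is the same effect the `for option in reversed(options)` loop produces inside each classmethod.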

fbgemm_gpu/tbe/bench/eeg_cli.py
@@ -0,0 +1,137 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+
+ import click
+ import torch
+
+ from fbgemm_gpu.tbe.bench import IndicesParams
+
+
+ @click.group()
+ def cli() -> None:
+     pass
+
+
+ @cli.command()
+ @click.option("--indices", required=True, help="Indices tensor file (*.pt)")
+ def estimate(indices: str) -> None:
+     """
+     Estimate the distribution of indices given a tensor file
+
+     Parameters:
+         indices (str): Indices tensor file (*.pt)
+
+     Returns:
+         None
+
+     Example:
+         estimate --indices="indices.pt"
+     """
+
+     indices = torch.load(indices)
+     heavy_hitters, q, s, max_index, num_indices = (
+         torch.ops.fbgemm.tbe_estimate_indices_distribution(indices)
+     )
+
+     params = IndicesParams(
+         heavy_hitters=heavy_hitters, zipf_q=q, zipf_s=s, index_dtype=indices.dtype
+     )
+
+     print(params.json(format=True), f"max_index={max_index}\nnum_indices={num_indices}")
+
+
+ @cli.command()
+ @click.option(
+     "--hitters",
+     type=str,
+     default="",
+     help="TBE heavy hitter indices (comma-delimited list of floats)",
+ )
+ @click.option(
+     "--zipf",
+     type=(float, float),
+     default=(0.1, 0.1),
+     help="Zipf distribution parameters for indices generation (q, s)",
+ )
+ @click.option(
+     "-e",
+     "--max-index",
+     type=int,
+     default=20,
+     help="Max index value (< E)",
+ )
+ @click.option(
+     "-n",
+     "--num-indices",
+     type=int,
+     default=20,
+     help="Target number of indices to generate",
+ )
+ @click.option(
+     "--output",
+     type=str,
+     required=True,
+     help="Tensor filepath (*.pt) to save the generated indices",
+ )
+ def generate(
+     hitters: str,
+     zipf: tuple[float, float],
+     max_index: int,
+     num_indices: int,
+     output: str,
+ ) -> None:
+     """
+     Generates a tensor of indices given the indices distribution parameters
+
+     Parameters:
+         hitters (str): heavy hitter indices (comma-delimited list of floats)
+
+         zipf (Tuple[float, float]): Zipf distribution parameters for indices generation (q, s)
+
+         max_index (int): Max index value (E)
+
+         num_indices (int): Target number of indices to generate
+
+         output (str): Tensor filepath (*.pt) to save the generated indices
+
+     Returns:
+         None
+
+     Example:
+         generate --hitters="2,4,6" --zipf="1.1,1.1" --max-index=10 --num-indices=100 --output="generated_indices.pt"
+     """
+     assert max_index > 0, "Max index value (E) must be greater than 0"
+     assert num_indices > 0, "Target number of indices must be greater than 0"
+     assert zipf[0] > 0, "Zipf parameter q must be greater than 0.0"
+     assert zipf[1] > 0, "Zipf parameter s must be greater than 0.0"
+     assert output != "", "Output file path must be provided"
+
+     try:
+         _hitters: list[float] = (
+             [float(x) for x in hitters.split(",")] if hitters else []
+         )
+     except Exception as e:
+         raise AssertionError(
+             f'Error: {e}. Please ensure to use comma-delimited list of floats, e.g., --hitters="2,4,6". '
+         )
+
+     heavy_hitters = torch.tensor(_hitters)
+     assert heavy_hitters.numel() <= 20, "The number of heavy hitters should be <= 20"
+
+     indices = torch.ops.fbgemm.tbe_generate_indices_from_distribution(
+         heavy_hitters, zipf[0], zipf[1], max_index, num_indices
+     )
+
+     print(f"Generated indices: {indices}")
+     torch.save(indices, output)
+     print(f"Saved indices to: {output}")
+
+
+ if __name__ == "__main__":
+     cli()
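
As a rough illustration (not shipped in the package), the two subcommands can be driven in-process with click's test runner; the output file name is a placeholder, and the fbgemm ops must already be loaded for the underlying `torch.ops.fbgemm` calls to resolve:

from click.testing import CliRunner

from fbgemm_gpu.tbe.bench.eeg_cli import cli

runner = CliRunner()

# Generate synthetic indices (Zipf-distributed plus a few heavy hitters) ...
gen = runner.invoke(
    cli,
    [
        "generate",
        "--hitters", "2,4,6",
        "--zipf", "1.1", "1.1",
        "--max-index", "10",
        "--num-indices", "100",
        "--output", "generated_indices.pt",
    ],
)
assert gen.exit_code == 0, gen.output

# ... then estimate the distribution parameters back from the saved tensor.
est = runner.invoke(cli, ["estimate", "--indices", "generated_indices.pt"])
print(est.output)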

fbgemm_gpu/tbe/bench/embedding_ops_common_config.py
@@ -0,0 +1,149 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import dataclasses
+ from typing import Any, Optional
+
+ import click
+ import torch
+
+ from fbgemm_gpu.split_embedding_configs import SparseType
+ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+     BoundsCheckMode,
+     EmbeddingLocation,
+     PoolingMode,
+ )
+
+
+ @dataclasses.dataclass(frozen=True)
+ class EmbeddingOpsCommonConfig:
+     # Precision of the embedding weights
+     weights_dtype: SparseType
+     # Precision of the embedding cache
+     cache_dtype: Optional[SparseType]
+     # Precision of the embedding output
+     output_dtype: SparseType
+     # Enable stochastic rounding when performing quantization
+     stochastic_rounding: bool
+     # Pooling operation to perform
+     pooling_mode: PoolingMode
+     # Use host-mapped UVM buffers
+     uvm_host_mapped: bool
+     # Memory location of the embeddings
+     embedding_location: EmbeddingLocation
+     # Bounds check mode
+     bounds_check_mode: BoundsCheckMode
+
+     # pyre-ignore [3]
+     def validate(self):
+         return self
+
+     def split_args(self) -> dict[str, Any]:
+         return {
+             "weights_precision": self.weights_dtype,
+             "stochastic_rounding": self.stochastic_rounding,
+             "output_dtype": self.output_dtype,
+             "pooling_mode": self.pooling_mode,
+             "bounds_check_mode": self.bounds_check_mode,
+             "uvm_host_mapped": self.uvm_host_mapped,
+         }
+
+
+ class EmbeddingOpsCommonConfigLoader:
+     @classmethod
+     # pyre-ignore [2]
+     def options(cls, func) -> click.Command:
+         options = [
+             click.option(
+                 "--emb-weights-dtype",
+                 type=SparseType,
+                 default=SparseType.FP32,
+                 help="Precision of the embedding weights",
+             ),
+             click.option(
+                 "--emb-cache-dtype",
+                 type=SparseType,
+                 default=None,
+                 help="Precision of the embedding cache",
+             ),
+             click.option(
+                 "--emb-output-dtype",
+                 type=SparseType,
+                 default=SparseType.FP32,
+                 help="Precision of the embedding output",
+             ),
+             click.option(
+                 "--emb-stochastic-rounding",
+                 is_flag=True,
+                 default=False,
+                 help="Enable stochastic rounding when performing quantization",
+             ),
+             click.option(
+                 "--emb-pooling-mode",
+                 type=click.Choice(["sum", "mean", "none"], case_sensitive=False),
+                 default="sum",
+                 help="Pooling operation to perform",
+             ),
+             click.option(
+                 "--emb-uvm-host-mapped",
+                 is_flag=True,
+                 default=False,
+                 help="Use host-mapped UVM buffers",
+             ),
+             click.option(
+                 "--emb-location",
+                 default="device",
+                 type=click.Choice(EmbeddingLocation.str_values(), case_sensitive=False),
+                 help="Memory location of the embeddings",
+             ),
+             click.option(
+                 "--emb-bounds-check",
+                 type=int,
+                 default=BoundsCheckMode.WARNING.value,
+                 help="Bounds check mode. "
+                 f"Available modes: FATAL={BoundsCheckMode.FATAL.value}, "
+                 f"WARNING={BoundsCheckMode.WARNING.value}, "
+                 f"IGNORE={BoundsCheckMode.IGNORE.value}, "
+                 f"NONE={BoundsCheckMode.NONE.value}",
+             ),
+         ]
+
+         for option in reversed(options):
+             func = option(func)
+         return func
+
+     @classmethod
+     def load(cls, context: click.Context) -> EmbeddingOpsCommonConfig:
+         params = context.params
+
+         weights_dtype = params["emb_weights_dtype"]
+         cache_dtype = params["emb_cache_dtype"]
+         output_dtype = params["emb_output_dtype"]
+         stochastic_rounding = params["emb_stochastic_rounding"]
+         pooling_mode = PoolingMode.from_str(str(params["emb_pooling_mode"]))
+         uvm_host_mapped = params["emb_uvm_host_mapped"]
+         bounds_check_mode = BoundsCheckMode(params["emb_bounds_check"])
+
+         embedding_location = EmbeddingLocation.from_str(str(params["emb_location"]))
+         if (
+             embedding_location is EmbeddingLocation.DEVICE
+             and not torch.cuda.is_available()
+         ):
+             embedding_location = EmbeddingLocation.HOST
+
+         return EmbeddingOpsCommonConfig(
+             weights_dtype,
+             cache_dtype,
+             output_dtype,
+             stochastic_rounding,
+             pooling_mode,
+             uvm_host_mapped,
+             embedding_location,
+             bounds_check_mode,
+         ).validate()
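
A minimal sketch of how the loader might be wired into a command of its own; the `show_config` command is a hypothetical name used only for illustration:

import click

from fbgemm_gpu.tbe.bench.embedding_ops_common_config import (
    EmbeddingOpsCommonConfigLoader,
)


@click.command()
@EmbeddingOpsCommonConfigLoader.options
@click.pass_context
def show_config(ctx: click.Context, **_kwargs) -> None:
    # Reparse the --emb-* flags from ctx.params into a typed config object,
    # then print the kwargs it would forward to a split-TBE constructor.
    config = EmbeddingOpsCommonConfigLoader.load(ctx)
    click.echo(config.split_args())


if __name__ == "__main__":
    show_config()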

fbgemm_gpu/tbe/bench/eval_compression.py
@@ -0,0 +1,119 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+
+ import logging
+ import statistics
+ from dataclasses import dataclass
+ from typing import Callable
+
+ import torch
+
+ logging.basicConfig(level=logging.DEBUG)
+
+
+ @dataclass
+ class EvalCompressionBenchmarkOutput:
+     avg: float
+     fwd: float
+     bwd: float
+     compressed_avg: float
+     compressed_fwd: float
+     reindex: float
+     compressed_bwd: float
+
+
+ def benchmark_eval_compression(
+     baseline_requests: list[tuple[torch.Tensor, torch.Tensor]],
+     compressed_requests: list[tuple[torch.Tensor, torch.Tensor]],
+     baseline_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+     compressed_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+     reindex: torch.Tensor,
+     embedding_dim: int,
+ ) -> EvalCompressionBenchmarkOutput:
+     times = []
+     fwd_times = []
+     bwd_times = []
+     torch.cuda.synchronize()
+     start_event = torch.cuda.Event(enable_timing=True)
+     end_event = torch.cuda.Event(enable_timing=True)
+     for indices, offsets in baseline_requests:
+         time = 0.0
+         start_event.record()
+         # forward
+         out = baseline_func(indices, offsets)
+         end_event.record()
+         torch.cuda.synchronize()
+         it_time = start_event.elapsed_time(end_event) * 1.0e-3
+         fwd_times.append(it_time)
+         time += it_time
+
+         grad = torch.rand_like(out)
+         start_event.record()
+         # backward
+         out.backward(grad)
+         end_event.record()
+         torch.cuda.synchronize()
+         it_time = start_event.elapsed_time(end_event) * 1.0e-3
+         bwd_times.append(it_time)
+         time += it_time
+         times.append(time)
+
+     avg = statistics.median(times)
+     fwd = statistics.median(fwd_times)
+     bwd = statistics.median(bwd_times)
+
+     times.clear()
+     fwd_times.clear()
+     bwd_times.clear()
+     reindex_times = []
+
+     torch.cuda.synchronize()
+     start_event = torch.cuda.Event(enable_timing=True)
+     end_event = torch.cuda.Event(enable_timing=True)
+
+     for indices, offsets in compressed_requests:
+         time = 0.0
+         start_event.record()
+         # forward
+         out = compressed_func(indices, offsets)
+         end_event.record()
+         torch.cuda.synchronize()
+         it_time = start_event.elapsed_time(end_event) * 1.0e-3
+         fwd_times.append(it_time)
+         time += it_time
+
+         start_event.record()
+         # reindex
+         out = out.reshape(-1, embedding_dim)
+         out = torch.ops.fbgemm.index_select_dim0(out, reindex)
+         end_event.record()
+         torch.cuda.synchronize()
+         it_time = start_event.elapsed_time(end_event) * 1.0e-3
+         reindex_times.append(it_time)
+         time += it_time
+
+         grad = torch.rand_like(out)
+         start_event.record()
+         # backward
+         out.backward(grad)
+         end_event.record()
+         torch.cuda.synchronize()
+         it_time = start_event.elapsed_time(end_event) * 1.0e-3
+         bwd_times.append(it_time)
+         time += it_time
+         times.append(time)
+
+     compressed_avg = statistics.median(times)
+     compressed_fwd = statistics.median(fwd_times)
+     reindex = statistics.median(reindex_times)
+     compressed_bwd = statistics.median(bwd_times)
+
+     return EvalCompressionBenchmarkOutput(
+         avg, fwd, bwd, compressed_avg, compressed_fwd, reindex, compressed_bwd
+     )
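
For context, a hedged sketch of one possible call; the embedding bags, request shapes, and identity reindex below are invented for illustration and assume a CUDA device with the fbgemm ops loaded:

import torch

from fbgemm_gpu.tbe.bench.eval_compression import benchmark_eval_compression

E, D, B, L = 1000, 128, 64, 8  # hypothetical rows, dim, bags per request, bag size
baseline = torch.nn.EmbeddingBag(E, D, mode="sum").cuda()
compressed = torch.nn.EmbeddingBag(E // 2, D, mode="sum").cuda()


def make_requests(num_rows: int) -> list[tuple[torch.Tensor, torch.Tensor]]:
    # Each request is an (indices, offsets) pair, as consumed by the loops above.
    return [
        (
            torch.randint(num_rows, (B * L,), device="cuda"),
            torch.arange(0, B * L, L, device="cuda"),
        )
        for _ in range(10)
    ]


result = benchmark_eval_compression(
    baseline_requests=make_requests(E),
    compressed_requests=make_requests(E // 2),
    baseline_func=baseline,
    compressed_func=compressed,
    reindex=torch.arange(B, device="cuda"),  # identity row mapping for the sketch
    embedding_dim=D,
)
print(result)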

fbgemm_gpu/tbe/bench/reporter.py
@@ -0,0 +1,35 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+
+ import logging
+ from dataclasses import dataclass
+
+ haveAIBench = False
+ try:
+     from aibench_observer.utils.observer import emitMetric
+
+     haveAIBench = True
+ except Exception:
+     haveAIBench = False
+
+
+ @dataclass
+ class BenchmarkReporter:
+     report: bool
+     logger: logging.Logger = logging.getLogger()
+
+     # pyre-ignore[3]
+     def __post_init__(self):
+         self.logger.setLevel(logging.INFO)
+
+     # pyre-ignore[2]
+     def emit_metric(self, **kwargs) -> None:
+         if self.report and haveAIBench:
+             self.logger.info(emitMetric(**kwargs))
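
A small usage sketch; the metric keyword arguments are placeholders, since emitMetric's real signature lives in the internal aibench_observer package, and emission silently no-ops when that package is not installed:

from fbgemm_gpu.tbe.bench.reporter import BenchmarkReporter

reporter = BenchmarkReporter(report=True)

# Forwards the kwargs verbatim to emitMetric only when report=True and
# aibench_observer is importable; otherwise this call does nothing.
reporter.emit_metric(
    type="tbe_bench",         # placeholder kwarg
    metric="bandwidth_gbps",  # placeholder kwarg
    value=123.4,              # placeholder kwarg
)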