fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
@@ -0,0 +1,289 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import dataclasses
+ from enum import Enum
+
+ import click
+ import torch
+ import yaml
+
+ from fbgemm_gpu.tbe.bench.tbe_data_config import (
+     BatchParams,
+     IndicesParams,
+     PoolingParams,
+     TBEDataConfig,
+ )
+
+
+ @dataclasses.dataclass(frozen=True)
+ class TBEDataConfigHelperText(Enum):
+     # Config File
+     TBE_CONFIG = "TBE data configuration filepath. If provided, all other `--tbe-*` options are ignored."
+
+     # Table Parameters
+     TBE_NUM_TABLES = "Number of tables (T)"
+     TBE_NUM_EMBEDDINGS = "Number of embeddings (E)"
+     TBE_EMBEDDING_DIM = "Embedding dimensions (D)"
+     TBE_MIXED_DIM = "Use mixed dimensions"
+     TBE_WEIGHTED = "Flag to indicate if the table is weighted"
+
+     # Batch Parameters
+     TBE_BATCH_SIZE = "Batch size (B)"
+     TBE_BATCH_VBE_SIGMA = "Standard deviation of B for VBE"
+     TBE_BATCH_VBE_DIST = "VBE distribution (choices: 'uniform', 'normal')"
+     TBE_BATCH_VBE_RANKS = "Number of ranks for VBE"
+
+     # Indices Parameters
+     TBE_INDICES_HITTERS = "Heavy hitters for indices (comma-delimited list of floats)"
+     TBE_INDICES_ZIPF = "Zipf distribution parameters for indices generation (q, s)"
+     TBE_INDICES_DTYPE = "The dtype of the table indices (choices: '32', '64')"
+     TBE_OFFSETS_DTYPE = "The dtype of the table offsets (choices: '32', '64')"
+
+     # Pooling Parameters
+     TBE_POOLING_SIZE = "Bag size / pooling factor (L)"
+     TBE_POOLING_VL_SIGMA = "Standard deviation of L for variable bag size (VL)"
+     TBE_POOLING_VL_DIST = "VL distribution (choices: 'uniform', 'normal')"
+
+
+ class TBEDataConfigLoader:
+     @classmethod
+     # pyre-ignore [2]
+     def options(cls, func) -> click.Command:
+         options = [
+             # Config File
+             click.option(
+                 "--tbe-config",
+                 type=str,
+                 required=False,
+                 help=TBEDataConfigHelperText.TBE_CONFIG.value,
+             ),
+             # Table Parameters
+             click.option(
+                 "--tbe-num-tables",
+                 type=int,
+                 default=32,
+                 help=TBEDataConfigHelperText.TBE_NUM_TABLES.value,
+             ),
+             click.option(
+                 "--tbe-num-embeddings",
+                 type=int,
+                 default=int(1e5),
+                 help=TBEDataConfigHelperText.TBE_NUM_EMBEDDINGS.value,
+             ),
+             click.option(
+                 "--tbe-num-embeddings-list",
+                 type=str,
+                 required=False,
+                 default=None,
+                 help="Comma-separated list of number of embeddings (Es)",
+             ),
+             click.option(
+                 "--tbe-embedding-dim",
+                 type=int,
+                 default=128,
+                 help=TBEDataConfigHelperText.TBE_EMBEDDING_DIM.value,
+             ),
+             click.option(
+                 "--tbe-embedding-dim-list",
+                 type=str,
+                 required=False,
+                 default=None,
+                 help="Comma-separated list of embedding dimensions (Ds)",
+             ),
+             click.option(
+                 "--tbe-mixed-dim",
+                 is_flag=True,
+                 default=False,
+                 help=TBEDataConfigHelperText.TBE_MIXED_DIM.value,
+             ),
+             click.option(
+                 "--tbe-weighted",
+                 is_flag=True,
+                 default=False,
+                 help=TBEDataConfigHelperText.TBE_WEIGHTED.value,
+             ),
+             click.option(
+                 "--tbe-max-indices",
+                 type=int,
+                 required=False,
+                 default=None,
+                 help="(Optional) Maximum number of indices, will be calculated if not provided",
+             ),
+             # Batch Parameters
+             click.option(
+                 "--tbe-batch-size",
+                 type=int,
+                 default=512,
+                 help=TBEDataConfigHelperText.TBE_BATCH_SIZE.value,
+             ),
+             click.option(
+                 "--tbe-batch-sizes-list",
+                 type=str,
+                 required=False,
+                 default=None,
+                 help="Comma-separated list of batch sizes per feature (Bs)",
+             ),
+             click.option(
+                 "--tbe-batch-vbe-sigma",
+                 type=int,
+                 required=False,
+                 help=TBEDataConfigHelperText.TBE_BATCH_VBE_SIGMA.value,
+             ),
+             click.option(
+                 "--tbe-batch-vbe-dist",
+                 type=click.Choice(["uniform", "normal"]),
+                 required=False,
+                 help=TBEDataConfigHelperText.TBE_BATCH_VBE_DIST.value,
+             ),
+             click.option(
+                 "--tbe-batch-vbe-ranks",
+                 type=int,
+                 required=False,
+                 help=TBEDataConfigHelperText.TBE_BATCH_VBE_RANKS.value,
+             ),
+             # Indices Parameters
+             click.option(
+                 "--tbe-indices-hitters",
+                 type=str,
+                 default="",
+                 help=TBEDataConfigHelperText.TBE_INDICES_HITTERS.value,
+             ),
+             click.option(
+                 "--tbe-indices-zipf",
+                 type=(float, float),
+                 default=(0.1, 0.1),
+                 help=TBEDataConfigHelperText.TBE_INDICES_ZIPF.value,
+             ),
+             click.option(
+                 "--tbe-indices-dtype",
+                 type=click.Choice(["32", "64"]),
+                 default="64",
+                 help=TBEDataConfigHelperText.TBE_INDICES_DTYPE.value,
+             ),
+             click.option(
+                 "--tbe-offsets-dtype",
+                 type=click.Choice(["32", "64"]),
+                 default="64",
+                 help=TBEDataConfigHelperText.TBE_OFFSETS_DTYPE.value,
+             ),
+             # Pooling Parameters
+             click.option(
+                 "--tbe-pooling-size",
+                 type=int,
+                 default=20,
+                 help=TBEDataConfigHelperText.TBE_POOLING_SIZE.value,
+             ),
+             click.option(
+                 "--tbe-pooling-vl-sigma",
+                 type=int,
+                 required=False,
+                 help=TBEDataConfigHelperText.TBE_POOLING_VL_SIGMA.value,
+             ),
+             click.option(
+                 "--tbe-pooling-vl-dist",
+                 type=click.Choice(["uniform", "normal"]),
+                 required=False,
+                 help=TBEDataConfigHelperText.TBE_POOLING_VL_DIST.value,
+             ),
+         ]
+
+         for option in reversed(options):
+             func = option(func)
+         return func
+
+     @classmethod
+     def load_from_file(cls, filepath: str) -> TBEDataConfig:
+         with open(filepath, "r") as f:
+             if filepath.endswith(".yaml") or filepath.endswith(".yml"):
+                 data = yaml.safe_load(f)
+                 return TBEDataConfig.from_dict(data).validate()
+             else:
+                 return TBEDataConfig.from_json(f.read()).validate()
+
+     @classmethod
+     def load_from_context(cls, context: click.Context) -> TBEDataConfig:
+         params = context.params
+
+         # Read table parameters
+         T = params["tbe_num_tables"]
+         E = params["tbe_num_embeddings"]
+         if params["tbe_num_embeddings_list"] is not None:
+             Es = [int(x) for x in params["tbe_num_embeddings_list"].split(",")]
+         else:
+             Es = None
+         D = params["tbe_embedding_dim"]
+         if params["tbe_embedding_dim_list"] is not None:
+             Ds = [int(x) for x in params["tbe_embedding_dim_list"].split(",")]
+         else:
+             Ds = None
+
+         mixed_dim = params["tbe_mixed_dim"]
+         weighted = params["tbe_weighted"]
+         if params["tbe_max_indices"] is not None:
+             max_indices = params["tbe_max_indices"]
+         else:
+             max_indices = None
+
+         # Read batch parameters
+         B = params["tbe_batch_size"]
+         sigma_B = params["tbe_batch_vbe_sigma"]
+         vbe_distribution = params["tbe_batch_vbe_dist"]
+         vbe_num_ranks = params["tbe_batch_vbe_ranks"]
+         if params["tbe_batch_sizes_list"] is not None:
+             Bs = [int(x) for x in params["tbe_batch_sizes_list"].split(",")]
+         else:
+             Bs = None
+         batch_params = BatchParams(B, sigma_B, vbe_distribution, vbe_num_ranks, Bs)
+
+         # Read indices parameters
+         heavy_hitters = (
+             torch.tensor([float(x) for x in params["tbe_indices_hitters"].split(",")])
+             if params["tbe_indices_hitters"]
+             else torch.tensor([])
+         )
+         zipf_q, zipf_s = params["tbe_indices_zipf"]
+         index_dtype = (
+             torch.int32 if int(params["tbe_indices_dtype"]) == 32 else torch.int64
+         )
+         offset_dtype = (
+             torch.int32 if int(params["tbe_offsets_dtype"]) == 32 else torch.int64
+         )
+         indices_params = IndicesParams(
+             heavy_hitters, zipf_q, zipf_s, index_dtype, offset_dtype
+         )
+
+         # Read pooling parameters
+         L = params["tbe_pooling_size"]
+         sigma_L = params["tbe_pooling_vl_sigma"]
+         length_distribution = params["tbe_pooling_vl_dist"]
+         pooling_params = PoolingParams(L, sigma_L, length_distribution)
+
+         return TBEDataConfig(
+             T,
+             E,
+             D,
+             mixed_dim,
+             weighted,
+             batch_params,
+             indices_params,
+             pooling_params,
+             not torch.cuda.is_available(),
+             Es,
+             Ds,
+             max_indices,
+         ).validate()
+
+     @classmethod
+     def load(cls, context: click.Context) -> TBEDataConfig:
+         tbe_config_filepath = context.params["tbe_config"]
+         if tbe_config_filepath is not None:
+             return cls.load_from_file(tbe_config_filepath)
+         else:
+             return cls.load_from_context(context)
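
For reference, a minimal sketch of how this loader is typically consumed: `TBEDataConfigLoader.options` attaches the `--tbe-*` flags to a click command, and `TBEDataConfigLoader.load` turns the parsed context into a validated `TBEDataConfig`. The `show_config` command below is illustrative only and is not part of the wheel.

```python
import click

from fbgemm_gpu.tbe.bench.tbe_data_config_loader import TBEDataConfigLoader


@click.command()
@TBEDataConfigLoader.options  # registers every --tbe-* option shown above
@click.pass_context
def show_config(context: click.Context) -> None:
    # Either loads --tbe-config (YAML/JSON) or assembles the config
    # from the individual --tbe-* flags, then validates it.
    config = TBEDataConfigLoader.load(context)
    print(config)


if __name__ == "__main__":
    show_config()
```
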
@@ -0,0 +1,170 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import dataclasses
+ import json
+ from typing import Any, Optional
+
+ import torch
+
+
+ def str_to_int_dtype(dtype: str) -> torch.dtype:
+     if dtype == "torch.int32":
+         return torch.int32
+     elif dtype == "torch.int64":
+         return torch.int64
+     else:
+         raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+ @dataclasses.dataclass(frozen=True, eq=False)
+ class IndicesParams:
+     # Heavy hitters for the Zipf distribution, i.e. a probability density map
+     # for the most hot indices. There should not ever be more than 100
+     # elements, and currently it is limited to 20 entries (kHeavyHittersMaxSize)
+     heavy_hitters: torch.Tensor
+     # zipf*: parameters for the Zipf distribution (x+q)^{-s}
+     zipf_q: float
+     # zipf_s is synonymous with alpha in the literature
+     zipf_s: float
+     # [Optional] dtype for indices tensor
+     index_dtype: Optional[torch.dtype] = None
+     # [Optional] dtype for offsets tensor
+     offset_dtype: Optional[torch.dtype] = None
+
+     @classmethod
+     # pyre-ignore [3]
+     def from_dict(cls, data: dict[str, Any]):
+         if not isinstance(data["heavy_hitters"], torch.Tensor):
+             data["heavy_hitters"] = torch.tensor(
+                 data["heavy_hitters"], dtype=torch.float32
+             )
+         data["index_dtype"] = str_to_int_dtype(data["index_dtype"])
+         data["offset_dtype"] = str_to_int_dtype(data["offset_dtype"])
+         return cls(**data)
+
+     @classmethod
+     # pyre-ignore [3]
+     def from_json(cls, data: str):
+         return cls.from_dict(json.loads(data))
+
+     def dict(self) -> dict[str, Any]:
+         # https://stackoverflow.com/questions/73735974/convert-dataclass-of-dataclass-to-json-string
+         tmp = dataclasses.asdict(self)
+         # Convert tensor to list for JSON serialization
+         tmp["heavy_hitters"] = self.heavy_hitters.tolist()
+         tmp["index_dtype"] = str(self.index_dtype)
+         tmp["offset_dtype"] = str(self.offset_dtype)
+         return tmp
+
+     def json(self, format: bool = False) -> str:
+         return json.dumps(self.dict(), indent=(2 if format else -1), sort_keys=True)
+
+     # pyre-ignore [2]
+     def __eq__(self, other) -> bool:
+         return (
+             (self.zipf_q, self.zipf_s, self.index_dtype, self.offset_dtype)
+             == (other.zipf_q, other.zipf_s, other.index_dtype, other.offset_dtype)
+         ) and bool((self.heavy_hitters - other.heavy_hitters).abs().max() < 1e-6)
+
+     # pyre-ignore [3]
+     def validate(self):
+         assert self.zipf_q > 0, "zipf_q must be positive"
+         assert self.zipf_s > 0, "zipf_s must be positive"
+         assert self.index_dtype is None or self.index_dtype in [
+             torch.int32,
+             torch.int64,
+         ], "index_dtype must be one of [torch.int32, torch.int64]"
+         assert self.offset_dtype is None or self.offset_dtype in [
+             torch.int32,
+             torch.int64,
+         ], "offset_dtype must be one of [torch.int32, torch.int64]"
+         return self
+
+
+ @dataclasses.dataclass(frozen=True)
+ class BatchParams:
+     # Target batch size, i.e. number of batch lookups per table
+     B: int
+     # [Optional] Standard deviation of B (for variable batch size configuration)
+     sigma_B: Optional[int] = None
+     # [Optional] Distribution of batch sizes (normal, uniform)
+     vbe_distribution: Optional[str] = "normal"
+     # Number of ranks for variable batch size generation
+     vbe_num_ranks: Optional[int] = None
+     # List of target batch sizes, i.e. number of batch lookups per table
+     Bs: Optional[list[int]] = None
+
+     @classmethod
+     # pyre-ignore [3]
+     def from_dict(cls, data: dict[str, Any]):
+         return cls(**data)
+
+     @classmethod
+     # pyre-ignore [3]
+     def from_json(cls, data: str):
+         return cls.from_dict(json.loads(data))
+
+     def dict(self) -> dict[str, Any]:
+         return dataclasses.asdict(self)
+
+     def json(self, format: bool = False) -> str:
+         return json.dumps(self.dict(), indent=(2 if format else -1), sort_keys=True)
+
+     # pyre-ignore [3]
+     def validate(self):
+         if self.Bs is not None:
+             assert all(b > 0 for b in self.Bs), "All elements in Bs must be positive"
+         else:
+             assert self.B > 0, "B must be positive"
+         assert not self.sigma_B or self.sigma_B > 0, "sigma_B must be positive"
+         assert (
+             self.vbe_num_ranks is None or self.vbe_num_ranks > 0
+         ), "vbe_num_ranks must be positive"
+         assert self.vbe_distribution is None or self.vbe_distribution in [
+             "normal",
+             "uniform",
+         ], "vbe_distribution must be one of [normal, uniform]"
+         return self
+
+
+ @dataclasses.dataclass(frozen=True)
+ class PoolingParams:
+     # Target bag size, i.e. pooling factor, or number of indices per batch lookup
+     L: int
+     # [Optional] Standard deviation of L (for variable bag size configuration)
+     sigma_L: Optional[int] = None
+     # [Optional] Distribution of embedding sequence lengths (normal, uniform)
+     length_distribution: Optional[str] = "normal"
+
+     @classmethod
+     # pyre-ignore [3]
+     def from_dict(cls, data: dict[str, Any]):
+         return cls(**data)
+
+     @classmethod
+     # pyre-ignore [3]
+     def from_json(cls, data: str):
+         return cls.from_dict(json.loads(data))
+
+     def dict(self) -> dict[str, Any]:
+         return dataclasses.asdict(self)
+
+     def json(self, format: bool = False) -> str:
+         return json.dumps(self.dict(), indent=(2 if format else -1), sort_keys=True)
+
+     # pyre-ignore [3]
+     def validate(self):
+         assert self.L > 0, "L must be positive"
+         assert not self.sigma_L or self.sigma_L > 0, "sigma_L must be positive"
+         assert self.length_distribution is None or self.length_distribution in [
+             "normal",
+             "uniform",
+         ], "length_distribution must be one of [normal, uniform]"
+         return self
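
A short round-trip sketch of the parameter dataclasses above (values are made up; the classes are imported from `fbgemm_gpu.tbe.bench.tbe_data_config`, matching the loader's import): each model serializes to a JSON-friendly dict via `dict()`/`json()` and is rebuilt via `from_dict()`/`from_json()`, with dtypes stored as strings and `heavy_hitters` stored as a plain list.

```python
import torch

from fbgemm_gpu.tbe.bench.tbe_data_config import (
    BatchParams,
    IndicesParams,
    PoolingParams,
)

indices = IndicesParams(
    heavy_hitters=torch.tensor([0.3, 0.2, 0.1]),  # probability mass of the hottest ids
    zipf_q=0.1,
    zipf_s=1.05,  # the "alpha" of the (x + q)^{-s} distribution
    index_dtype=torch.int64,
    offset_dtype=torch.int64,
).validate()
batch = BatchParams(B=512, sigma_B=128, vbe_distribution="normal", vbe_num_ranks=4).validate()
pooling = PoolingParams(L=20, sigma_L=5, length_distribution="uniform").validate()

# dict()/json() and from_dict()/from_json() are inverses of each other
assert IndicesParams.from_json(indices.json()) == indices
assert BatchParams.from_json(batch.json()) == batch
print(pooling.json(format=True))
```
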
@@ -0,0 +1,48 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import logging
+
+ import numpy as np
+ import torch
+
+ from fbgemm_gpu.split_embedding_configs import SparseType
+
+ logging.basicConfig(level=logging.DEBUG)
+
+
+ def fill_random_scale_bias(
+     emb: torch.nn.Module,
+     T: int,
+     weights_precision: SparseType,
+ ) -> None:
+     for t in range(T):
+         # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
+         (weights, scale_shift) = emb.split_embedding_weights()[t]
+         if scale_shift is not None:
+             (E, R) = scale_shift.shape
+             assert R == 4
+             scales = None
+             shifts = None
+             if weights_precision == SparseType.INT8:
+                 scales = np.random.uniform(0.001, 0.01, size=(E,)).astype(np.float16)
+                 shifts = np.random.normal(-2, 2, size=(E,)).astype(np.float16)
+             elif weights_precision == SparseType.INT4:
+                 scales = np.random.uniform(0.01, 0.1, size=(E,)).astype(np.float16)
+                 shifts = np.random.normal(-2, 2, size=(E,)).astype(np.float16)
+             elif weights_precision == SparseType.INT2:
+                 scales = np.random.uniform(0.1, 1, size=(E,)).astype(np.float16)
+                 shifts = np.random.normal(-2, 2, size=(E,)).astype(np.float16)
+             scale_shift.copy_(
+                 torch.tensor(
+                     np.stack([scales, shifts], axis=1)
+                     .astype(np.float16)
+                     .view(np.uint8),
+                     device=scale_shift.device,
+                 )
+             )
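
A self-contained sketch (not from the package) of the byte layout that `fill_random_scale_bias` writes: each of the `E` rows of `scale_shift` holds one fp16 scale and one fp16 shift reinterpreted as 4 uint8 bytes, which is why the function asserts `R == 4`.

```python
import numpy as np
import torch

E = 3
scales = np.random.uniform(0.01, 0.1, size=(E,)).astype(np.float16)
shifts = np.random.normal(-2, 2, size=(E,)).astype(np.float16)

# Pack: (E, 2) float16 -> (E, 4) uint8, exactly as in the function above
packed = torch.tensor(
    np.stack([scales, shifts], axis=1).astype(np.float16).view(np.uint8)
)
assert packed.shape == (E, 4)

# Unpack: reinterpret the 4 bytes per row back into (scale, shift) float16 pairs
unpacked = packed.numpy().view(np.float16)  # shape (E, 2)
assert np.array_equal(unpacked[:, 0], scales)
assert np.array_equal(unpacked[:, 1], shifts)
```
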
@@ -0,0 +1,11 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ from .kv_embedding_ops_inference import KVEmbeddingInference  # noqa: F401
+ from .split_embeddings_cache_ops import get_unique_indices_v2  # noqa: F401