fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/tbe/utils/requests.py
@@ -0,0 +1,556 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import logging
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import numpy as np
+ import numpy.typing as npt
+ import torch
+
+ # pyre-fixme[21]: Could not find name `default_rng` in `numpy.random` (stubbed).
+ from numpy.random import default_rng
+
+ from .common import get_device
+ from .offsets import get_table_batched_offsets_from_dense
+
+ logging.basicConfig(level=logging.DEBUG)
+
+
+ @dataclass
+ class TBERequest:
+     """
+     `generate_requests`'s output wrapper
+     """
+
+     indices: torch.Tensor
+     offsets: torch.Tensor
+     per_sample_weights: Optional[torch.Tensor] = None
+     Bs_per_feature_per_rank: Optional[list[list[int]]] = None
+
+     def unpack_2(self) -> tuple[torch.Tensor, torch.Tensor]:
+         return (self.indices, self.offsets)
+
+     def unpack_3(
+         self,
+     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+         return (self.indices, self.offsets, self.per_sample_weights)
+
+     def unpack_4(
+         self,
+     ) -> tuple[
+         torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]]
+     ]:
+         return (
+             self.indices,
+             self.offsets,
+             self.per_sample_weights,
+             self.Bs_per_feature_per_rank,
+         )
+
+
+ def generate_requests_from_data_file(
+     iters: int,
+     B: int,
+     T: int,
+     L: int,
+     E: int,
+     weighted: bool,
+     requests_data_file: Optional[str] = None,
+     indices_file: Optional[str] = None,
+     offsets_file: Optional[str] = None,
+     tables: Optional[str] = None,
+     index_dtype: Optional[torch.dtype] = None,
+     offset_dtype: Optional[torch.dtype] = None,
+ ) -> list[TBERequest]:
+     """
+     Generate TBE requests from the input data file. If `requests_data_file` is
+     provided, `indices_file` and `offsets_file` must not be provided. If either
+     `indices_file` or `offsets_file` is provided, both must be provided.
+     """
+     assert not (
+         requests_data_file and (indices_file or offsets_file)
+     ), "If requests_data_file is provided, indices_file and offsets_file cannot be provided."
+     assert requests_data_file or (
+         indices_file and offsets_file
+     ), "Both indices_file and offsets_file must be provided if either is provided."
+
+     if requests_data_file:
+         indices_tensor, offsets_tensor, *rest = torch.load(requests_data_file)
+     else:
+         indices_tensor = torch.load(indices_file)
+         offsets_tensor = torch.load(offsets_file)
+
+     average_L = 0
+     if tables is not None:
+         emb_tables = tuple(int(x) for x in tables.split(","))
+         indices = torch.zeros(0, dtype=indices_tensor.dtype)
+         offsets = torch.zeros(1, dtype=offsets_tensor.dtype)
+         total_L = 0
+         for t in emb_tables:
+             t_offsets = offsets_tensor[B * t : B * (t + 1) + 1]
+             total_L += t_offsets[-1] - t_offsets[0]
+             indices = torch.cat((indices, indices_tensor[t_offsets[0] : t_offsets[-1]]))
+             offsets = torch.cat(
+                 (
+                     offsets,
+                     t_offsets[1:] - t_offsets[0] + offsets[-1],
+                 )
+             )
+         indices_tensor = indices
+         offsets_tensor = offsets
+         average_L = int(total_L / B)
+
+         assert np.prod(offsets_tensor.size()) - 1 == np.prod((T, B)), (
+             f"Requested tables: {emb_tables} "
+             f"do not conform to inputs (T, B) = ({T}, {B})."
+         )
+         logging.warning(
+             f"Using (indices = {indices_tensor.size()}, offsets = {offsets_tensor.size()}) "
+             f"based on tables: {emb_tables}"
+         )
+     else:
+         average_L = int((offsets_tensor[-1] - offsets_tensor[0]) / B)
+         assert (np.prod(offsets_tensor.size()) - 1) == np.prod((T, B)), (
+             f"Data file (indices = {indices_tensor.size()}, "
+             f"offsets = {offsets_tensor.size()}, lengths = {offsets_tensor.numel() - 1}) "
+             f"does not conform to inputs (T, B) = ({T}, {B})."
+         )
+
+     assert (
+         L == average_L
+     ), f"Requested L does not align with the provided data file ({L} vs. {average_L})"
+     assert E > max(indices_tensor), (
+         f"Number of embeddings is not enough to support the maximum index in the "
+         f"data file ({E} vs. {max(indices_tensor)})"
+     )
+
+     weights_tensor = (
+         None
+         if not weighted
+         else torch.randn(indices_tensor.size(), device=get_device())
+     )
+     rs = []
+     for _ in range(iters):
+         rs.append(
+             TBERequest(
+                 maybe_to_dtype(indices_tensor.to(get_device()), index_dtype),
+                 maybe_to_dtype(offsets_tensor.to(get_device()), offset_dtype),
+                 weights_tensor,
+             )
+         )
+     return rs
+
+
+ def generate_int_data_from_stats(
+     mu: int,
+     sigma: int,
+     size: int,
+     distribution: str,
+ ) -> npt.NDArray:
+     """
+     Generate integer data based on stats
+     """
+     if distribution == "uniform":
+         # TODO: either make these separate parameters or make a separate version
+         # of generate_requests to handle the uniform dist case once the whole
+         # generate_requests function is refactored into helper functions for
+         # each use case.
+         # mu represents the lower bound when the uniform distribution is used
+         lower_bound = mu
+         # sigma represents the upper bound when the uniform distribution is used
+         upper_bound = sigma + 1
+         return np.random.randint(
+             lower_bound,
+             upper_bound,
+             (size,),
+             dtype=np.int32,
+         )
+     else:  # normal dist
+         return np.random.normal(loc=mu, scale=sigma, size=size).astype(int)
+
+
+ def generate_pooling_factors_from_stats(
+     iters: int,
+     Bs: list[int],
+     L: int,
+     sigma_L: int,
+     # distribution of pooling factors
+     length_dist: str,
+ ) -> tuple[int, torch.Tensor]:
+     """
+     Generate pooling factors for the TBE requests from the given stats
+     """
+     Ls_list = []
+     for B in Bs:
+         Ls_list.append(generate_int_data_from_stats(L, sigma_L, B, length_dist))
+
+     # Concat all Ls
+     Ls = np.concatenate(Ls_list)
+
+     # Make sure that Ls are positive
+     Ls[Ls < 0] = 0
+     # Use the same L distribution across iters
+     Ls = np.tile(Ls, iters)
+     L = Ls.max()
+     # Make it exclusive cumsum
+     L_offsets = torch.from_numpy(np.insert(Ls.cumsum(), 0, 0)).to(torch.long)
+     return L, L_offsets
+
+
+ def generate_batch_sizes_from_stats(
+     B: int,
+     T: int,
+     sigma_B: int,
+     vbe_num_ranks: int,
+     # Distribution of batch sizes
+     batch_size_dist: str,
+ ) -> tuple[list[int], list[list[int]]]:
+     """
+     Generate batch sizes for features from the given stats
+     """
+     # Generate batch size per feature per rank
+     Bs_feature_rank = generate_int_data_from_stats(
+         B, sigma_B, T * vbe_num_ranks, batch_size_dist
+     )
+
+     # Make sure that Bs are at least one
+     Bs_feature_rank = np.absolute(Bs_feature_rank)
+     Bs_feature_rank[Bs_feature_rank == 0] = 1
+
+     # Convert numpy array to Torch tensor
+     Bs_feature_rank = torch.from_numpy(Bs_feature_rank).view(T, vbe_num_ranks)
+     # Compute batch sizes per feature
+     Bs = Bs_feature_rank.sum(1).tolist()
+
+     return Bs, Bs_feature_rank.tolist()
+
+
+ def generate_indices_uniform(
+     iters: int,
+     Bs: list[int],
+     L: int,
+     E: int,
+     use_variable_L: bool,
+     L_offsets: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Generate indices for the TBE requests using the uniform distribution
+     """
+     total_B = sum(Bs)
+     indices = torch.randint(
+         low=0,
+         high=E,
+         size=(iters, total_B, L),
+         device="cpu" if use_variable_L else get_device(),
+         dtype=torch.int32,
+     )
+     # each bag is usually sorted
+     (indices, _) = torch.sort(indices)
+     if use_variable_L:
+         # 1D layout, where row offsets are determined by L_offsets
+         indices = torch.ops.fbgemm.bottom_k_per_row(
+             indices.to(torch.long), L_offsets, False
+         )
+         indices = indices.to(get_device()).int()
+     else:
+         # 2D layout
+         indices = indices.reshape(iters, total_B * L)
+     return indices
+
+
+ def generate_indices_zipf(
+     iters: int,
+     Bs: list[int],
+     L: int,
+     E: int,
+     alpha: float,
+     zipf_oversample_ratio: int,
+     use_variable_L: bool,
+     L_offsets: torch.Tensor,
+     deterministic_output: bool,
+ ) -> torch.Tensor:
+     """
+     Generate indices for the TBE requests using the zipf distribution
+     """
+     assert E >= L, "num-embeddings must be greater than or equal to bag-size"
+     # oversample and then remove duplicates to obtain sampling without
+     # replacement
+     if L == 0:
+         return torch.empty(iters, 0, dtype=torch.int).to(get_device())
+     total_B = sum(Bs)
+     zipf_shape = (iters, total_B, zipf_oversample_ratio * L)
+     if torch.cuda.is_available():
+         zipf_shape_total_len = np.prod(zipf_shape)
+         indices_list = []
+         # process 8 GB at a time on GPU
+         chunk_len = int(1e9)
+         for chunk_begin in range(0, zipf_shape_total_len, chunk_len):
+             indices_gpu = torch.ops.fbgemm.zipf_cuda(
+                 alpha,
+                 min(zipf_shape_total_len - chunk_begin, chunk_len),
+                 seed=torch.randint(2**31 - 1, (1,))[0],
+             )
+             indices_list.append(indices_gpu.cpu())
+         indices = torch.cat(indices_list).reshape(zipf_shape)
+     else:
+         indices = torch.as_tensor(np.random.zipf(a=alpha, size=zipf_shape))
+     indices = (indices - 1) % E
+     if use_variable_L:
+         indices = torch.ops.fbgemm.bottom_k_per_row(indices, L_offsets, True)
+     else:
+         indices = torch.ops.fbgemm.bottom_k_per_row(
+             indices, torch.tensor([0, L], dtype=torch.long), True
+         )
+     if deterministic_output:
+         rng = default_rng(12345)
+     else:
+         rng = default_rng()
+     permutation = torch.as_tensor(
+         rng.choice(E, size=indices.max().item() + 1, replace=False)
+     )
+     indices = permutation.gather(0, indices.flatten())
+     indices = indices.to(get_device()).int()
+     if not use_variable_L:
+         indices = indices.reshape(iters, total_B * L)
+     return indices
+
+
+ def update_indices_with_random_reuse(
+     iters: int,
+     Bs: list[int],
+     L: int,
+     reuse: float,
+     indices: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Update the generated indices with random reuse
+     """
+     for it in range(iters - 1):
+         B_offset = 0
+         for B in Bs:
+             reused_indices = torch.randperm(B * L, device=get_device())[
+                 : int(B * L * reuse)
+             ]
+             reused_indices += B_offset
+             indices[it + 1, reused_indices] = indices[it, reused_indices]
+             B_offset += B * L
+     return indices
+
+
+ def update_indices_with_random_pruning(
+     iters: int,
+     B: int,
+     T: int,
+     L: int,
+     indices: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Update the generated indices with random pruning
+     """
+     for it in range(iters):
+         for t in range(T):
+             num_negative_indices = B // 2
+             random_locations = torch.randint(
+                 low=0,
+                 high=(B * L),
+                 size=(num_negative_indices,),
+                 device=torch.cuda.current_device(),
+                 dtype=torch.int32,
+             )
+             indices[it, t, random_locations] = -1
+     return indices
+
+
+ def maybe_to_dtype(tensor: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.Tensor:
+     return tensor if dtype is None else tensor.to(dtype)
+
+
+ def generate_requests(  # noqa C901
+     iters: int,
+     B: int,
+     T: int,
+     L: int,
+     E: int,
+     # inter-batch indices reuse rate
+     reuse: float = 0.0,
+     # alpha <= 1.0: use uniform distribution
+     # alpha > 1.0: use zipf distribution
+     alpha: float = 1.0,
+     zipf_oversample_ratio: int = 3,
+     weighted: bool = False,
+     requests_data_file: Optional[str] = None,
+     # Paths to files containing indices and offsets. If provided, they will be used
+     indices_file: Optional[str] = None,
+     offsets_file: Optional[str] = None,
+     # Comma-separated list of table numbers
+     tables: Optional[str] = None,
+     # If sigma_L is not None, treat L as mu_L and generate Ls from sigma_L
+     # and mu_L
+     sigma_L: Optional[int] = None,
+     # If sigma_B is not None, treat B as mu_B and generate Bs from sigma_B
+     sigma_B: Optional[int] = None,
+     emulate_pruning: bool = False,
+     use_cpu: bool = False,
+     # generate_requests uses numpy.random.default_rng without a set random seed
+     # by default, causing the indices tensor to vary with each call to
+     # generate_requests - set deterministic_output to use a fixed random
+     # seed instead for repeatable outputs
+     deterministic_output: bool = False,
+     # distribution of embedding sequence lengths
+     length_dist: str = "normal",
+     # distribution of batch sizes
+     batch_size_dist: str = "normal",
+     # Number of ranks for variable batch size generation
+     vbe_num_ranks: Optional[int] = None,
+     index_dtype: Optional[torch.dtype] = None,
+     offset_dtype: Optional[torch.dtype] = None,
+ ) -> list[TBERequest]:
+     # TODO: refactor and split into helper functions to separate load from file,
+     # generate from distribution, and other future methods of generating data
+     if (
+         requests_data_file is not None
+         or indices_file is not None
+         or offsets_file is not None
+     ):
+         assert sigma_L is None, "Variable pooling factors are not supported"
+         assert sigma_B is None, "Variable batch sizes are not supported"
+         return generate_requests_from_data_file(
+             iters=iters,
+             B=B,
+             T=T,
+             L=L,
+             E=E,
+             weighted=weighted,
+             requests_data_file=requests_data_file,
+             indices_file=indices_file,
+             offsets_file=offsets_file,
+             tables=tables,
+             index_dtype=index_dtype,
+             offset_dtype=offset_dtype,
+         )
+
+     if sigma_B is not None:
+         assert (
+             vbe_num_ranks is not None
+         ), "vbe_num_ranks must be set for variable batch size generation"
+         use_variable_B = True
+         Bs, Bs_feature_rank = generate_batch_sizes_from_stats(
+             B, T, sigma_B, vbe_num_ranks, batch_size_dist
+         )
+     else:
+         use_variable_B = False
+         Bs = [B] * T
+         Bs_feature_rank = None
+
+     if sigma_L is not None:
+         # Generate L from stats
+         use_variable_L = True
+         L, L_offsets = generate_pooling_factors_from_stats(
+             iters, Bs, L, sigma_L, length_dist
+         )
+     elif use_variable_B:
+         use_variable_L = False
+         Ls = [L] * (sum(Bs) * iters)
+         L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
+     else:
+         use_variable_L = False
+         # Init to suppress the pyre error
+         L_offsets = torch.empty(1)
+
+     if alpha <= 1.0:
+         # Generate indices using uniform dist
+         all_indices = generate_indices_uniform(
+             iters, Bs, L, E, use_variable_L, L_offsets
+         )
+     else:
+         # Generate indices using zipf dist
+         all_indices = generate_indices_zipf(
+             iters,
+             Bs,
+             L,
+             E,
+             alpha,
+             zipf_oversample_ratio,
+             use_variable_L,
+             L_offsets,
+             deterministic_output,
+         )
+
+     if reuse > 0.0:
+         assert (
+             not use_variable_L
+         ), "Generating Ls from stats is not supported for reuse > 0.0"
+         all_indices = update_indices_with_random_reuse(iters, Bs, L, reuse, all_indices)
+
+     # Some indices are set to -1 for emulating pruned rows.
+     if emulate_pruning:
+         assert (
+             not use_variable_L
+         ), "Generating Ls from stats is not supported for emulate_pruning=True"
+         assert (
+             not use_variable_B
+         ), "Generating Bs from stats is not supported for emulate_pruning=True"
+         all_indices = update_indices_with_random_pruning(
+             iters, B, T, L, all_indices.view(iters, T, B * L)
+         )
+
+     # Pack requests
+     rs = []
+     if use_variable_L or use_variable_B:
+         total_B = sum(Bs)
+         all_indices = all_indices.flatten()
+         for it in range(iters):
+             start_offset = L_offsets[it * total_B]
+             it_L_offsets = torch.concat(
+                 [
+                     torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device),
+                     L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset,
+                 ]
+             )
+             weights_tensor = (
+                 None
+                 if not weighted
+                 else torch.randn(
+                     int(it_L_offsets[-1].item()), device=get_device()
+                 )  # per sample weights will always be FP32
+             )
+             rs.append(
+                 TBERequest(
+                     maybe_to_dtype(
+                         all_indices[start_offset : L_offsets[(it + 1) * total_B]],
+                         index_dtype,
+                     ),
+                     maybe_to_dtype(it_L_offsets.to(get_device()), offset_dtype),
+                     weights_tensor,
+                     Bs_feature_rank if use_variable_B else None,
+                 )
+             )
+     else:
+         for it in range(iters):
+             weights_tensor = (
+                 None
+                 if not weighted
+                 else torch.randn(
+                     T * B * L, device=get_device()
+                 )  # per sample weights will always be FP32
+             )
+             indices, offsets = get_table_batched_offsets_from_dense(
+                 all_indices[it].view(T, B, L), use_cpu=use_cpu
+             )
+             rs.append(
+                 TBERequest(
+                     maybe_to_dtype(indices, index_dtype),
+                     maybe_to_dtype(offsets, offset_dtype),
+                     weights_tensor,
+                 )
+             )
+     return rs
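
For reference, a minimal sketch of how the generator above is typically driven from a benchmark script; the import path follows the wheel layout (fbgemm_gpu/tbe/utils/requests.py), and the parameter values are illustrative assumptions, not package defaults:

    import torch
    from fbgemm_gpu.tbe.utils.requests import TBERequest, generate_requests

    # 10 iterations over 4 tables of 100,000 rows each, batch size 512,
    # bag size 20. alpha > 1.0 selects the Zipf index distribution, and
    # reuse=0.1 replays 10% of each batch's indices in the next batch.
    requests: list[TBERequest] = generate_requests(
        iters=10,
        B=512,
        T=4,
        L=20,
        E=100_000,
        reuse=0.1,
        alpha=1.15,
        weighted=True,
    )
    for req in requests:
        indices, offsets, per_sample_weights = req.unpack_3()
        # feed (indices, offsets, per_sample_weights) into the TBE op under test

The Zipf path matters for cache benchmarks: a skewed distribution hits a small set of hot rows far more often than a uniform one, which changes cache hit rates substantially.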
fbgemm_gpu/tbe_input_multiplexer.py
@@ -0,0 +1,108 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ import abc
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+ from torch import Tensor
+
+
+ @dataclass(frozen=True)
+ class TBEInfo:
+     """
+     Contains selected TBE info used for multiplexing. For more info, check
+     https://fburl.com/code/ljnd6j65
+
+     Args:
+         table_names: table names within the TBE
+         table_heights: sharded table heights (hash size)
+         tbe_uuid: a unique identifier for the TBE
+         feature_table_map: feature-to-table map
+         table_dims: sharded table dimensions
+         full_table_heights: table heights before sharding
+         full_table_dims: table dimensions before sharding
+         row_offset: the shard offset of the current rank on rows (height)
+         col_offset: the shard offset of the current rank on columns (dim)
+     """
+
+     table_names: list[str]
+     table_heights: list[int]
+     tbe_uuid: str
+     feature_table_map: list[int]
+     table_dims: list[int]
+     full_table_heights: list[int]
+     full_table_dims: list[int]
+     row_offset: list[int]
+     col_offset: list[int]
+
+
+ @dataclass(frozen=True)
+ class TBEInputInfo:
+     """
+     indices: A 1D tensor that contains indices to be looked up
+         from all embedding tables.
+     offsets: A 1D tensor that contains offsets of indices.
+     batch_size_per_feature_per_rank: An optional 2D list that contains
+         the batch size for every rank and every feature. This is needed
+         to support VBE.
+     """
+
+     indices: Tensor
+     offsets: Tensor
+     batch_size_per_feature_per_rank: Optional[list[list[int]]] = None
+
+
+ class TBEInputMultiplexer(abc.ABC):
+     """
+     Interface for multiplexing TBE input data out; an actual implementation
+     may store the data to files.
+     """
+
+     @abc.abstractmethod
+     def should_run(self, step: int) -> bool:
+         """
+         Check whether the multiplexer should run at this step.
+         Args:
+             step: the current step
+         Returns:
+             True if it should run, otherwise False
+         """
+         pass
+
+     @abc.abstractmethod
+     def run(
+         self,
+         tbe_input_info: TBEInputInfo,
+     ) -> None:
+         """
+         Run the TBE input multiplexing; this is called for every batch that
+         needs to be dumped.
+         Args:
+             tbe_input_info: TBE input info that contains all the necessary
+                 info for further processing
+         """
+         pass
+
+
+ @dataclass(frozen=True)
+ class TBEInputMultiplexerConfig:
+     """
+     Configuration for TBEInputMultiplexer
+     """
+
+     # first batch at which to start the run; -1 means no run
+     start_batch: int = -1
+     # total number of batches to multiplex
+     total_batch: int = 0
+
+     def create_tbe_input_multiplexer(
+         self,
+         tbe_info: TBEInfo,
+     ) -> Optional[TBEInputMultiplexer]:
+         assert (
+             self.start_batch == -1
+         ), "Cannot specify start_batch without an actual implementation."
+         return None
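
The wheel ships only this abstract interface plus a config whose create_tbe_input_multiplexer returns None, so any real use requires a subclass. Below is a hypothetical file-dumping implementation; the class name, constructor arguments, and file naming scheme are invented for illustration:

    import torch
    from fbgemm_gpu.tbe_input_multiplexer import TBEInputInfo, TBEInputMultiplexer

    class FileDumpTBEInputMultiplexer(TBEInputMultiplexer):
        def __init__(self, start_batch: int, total_batch: int, out_dir: str) -> None:
            self.start_batch = start_batch
            self.end_batch = start_batch + total_batch
            self.out_dir = out_dir
            self.num_dumped = 0

        def should_run(self, step: int) -> bool:
            # Dump only batches inside the configured [start_batch, end_batch) window
            return self.start_batch <= step < self.end_batch

        def run(self, tbe_input_info: TBEInputInfo) -> None:
            # Persist indices/offsets (and VBE batch sizes, if any) for offline replay
            torch.save(
                (
                    tbe_input_info.indices,
                    tbe_input_info.offsets,
                    tbe_input_info.batch_size_per_feature_per_rank,
                ),
                f"{self.out_dir}/tbe_input_{self.num_dumped}.pt",
            )
            self.num_dumped += 1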
fbgemm_gpu/triton/__init__.py
@@ -0,0 +1,22 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ # Attempt to import triton kernels; fall back to the reference
+ # implementation if we cannot.
+ from .common import RoundingMode  # noqa
+
+ try:
+     from .quantize import (
+         triton_dequantize_mx4 as dequantize_mx4,
+         triton_quantize_mx4 as quantize_mx4,
+     )
+ except ImportError:
+     from .quantize_ref import (  # noqa: F401, E402
+         py_dequantize_mx4 as dequantize_mx4,
+         py_quantize_mx4 as quantize_mx4,
+     )
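
With this fallback in place, callers import quantize_mx4/dequantize_mx4 without caring whether the Triton kernels or the pure-Python reference (py_*_mx4) were loaded. A minimal round-trip sketch, assuming both functions take a flat float tensor whose length is a multiple of the MX4 group size and accept a group_size keyword (the actual signatures live in fbgemm_gpu/triton/quantize.py and quantize_ref.py, and the Triton path needs a CUDA device):

    import torch
    from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4096, device=device)        # length must be a multiple of group_size
    packed = quantize_mx4(x, group_size=32)     # packed 4-bit mantissas + shared exponents
    x_hat = dequantize_mx4(packed, group_size=32)
    print((x - x_hat).abs().max())              # expect a small quantization error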