fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/tbe/ssd/utils/__init__.py
@@ -0,0 +1,7 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from ..common import ASSOC  # noqa: F401
fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py
@@ -0,0 +1,273 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ from __future__ import annotations
+
+ import functools
+ from typing import Optional, Union
+
+ import torch
+
+ _HANDLED_FUNCTIONS = {}
+
+
+ def implements(torch_function):
+     def decorator(func):
+         functools.update_wrapper(func, torch_function)
+         _HANDLED_FUNCTIONS[torch_function] = func
+         return func
+
+     return decorator
+
+
+ class PartiallyMaterializedTensor:
+     """
+     A tensor-like object that represents a partially materialized tensor in memory.
+
+     Callers can use `narrow()` to get a view of the backing storage,
+     or `full_tensor()` to get the full tensor (which could OOM).
+     """
+
+     def __init__(self, wrapped, is_virtual: bool = False) -> None:
+         """
+         Ensure the caller loads the module before creating this object.
+
+         ```
+         load_torch_module(
+             "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings"
+         )
+         ```
+
+         Args:
+             wrapped: torch.classes.fbgemm.KVTensorWrapper
+         """
+         self._wrapped = wrapped
+         self._is_virtual = is_virtual
+         self._requires_grad = False
+
+     @property
+     def wrapped(self):
+         """
+         Get the wrapped extension class for C++ interop.
+         """
+         return self._wrapped
+
+     @property
+     def is_virtual(self):
+         """
+         Indicate whether this PMT is a virtual tensor.
+         Checkpointing and publishing need this indicator to tell whether the
+         PMT backs a KV-ZCH table or a normal embedding table; for KV-ZCH they
+         must call all-gather to recalculate the correct metadata of the
+         ShardedTensor.
+         """
+         return self._is_virtual
+
+     @classmethod
+     def __torch_function__(cls, func, types, args=(), kwargs=None):
+         if kwargs is None:
+             kwargs = {}
+         if func not in _HANDLED_FUNCTIONS:
+             return NotImplemented
+         return _HANDLED_FUNCTIONS[func](cls, *args, **kwargs)
+
+     @implements(torch.narrow)
+     def narrow(self, dim: int, start: int, length: int) -> torch.Tensor:
+         """
+         This loads a narrowed view of the backing storage.
+
+         Returns:
+             a torch tensor
+         """
+         return self._wrapped.narrow(dim, start, length)
+
+     def set_weights_and_ids(self, weights: torch.Tensor, ids: torch.Tensor) -> None:
+         self._wrapped.set_weights_and_ids(weights, ids)
+
+     def get_weights_by_ids(self, ids: torch.Tensor) -> torch.Tensor:
+         return self._wrapped.get_weights_by_ids(ids)
+
+     def __reduce__(self):
+         return (
+             PartiallyMaterializedTensor,
+             (self._wrapped,),
+         )
+
+     def full_tensor(self) -> torch.Tensor:
+         """
+         This loads the full tensor into memory (may OOM).
+
+         Returns:
+             a torch tensor
+         """
+         return self.narrow(0, 0, self.size(0))
+
+     @implements(torch.detach)
+     def detach(self) -> PartiallyMaterializedTensor:
+         self._requires_grad = False
+         return self
+
+     def to(self, *args, **kwargs) -> PartiallyMaterializedTensor:
+         return self
+
+     def is_floating_point(self):
+         # this class only deals with embedding vectors
+         return True
+
+     @implements(torch._has_compatible_shallow_copy_type)
+     def _has_compatible_shallow_copy_type(*args, **kwargs):
+         return False
+
+     def requires_grad_(self, requires_grad=True) -> PartiallyMaterializedTensor:
+         self._requires_grad = requires_grad
+         return self
+
+     @property
+     def requires_grad(self) -> bool:
+         return self._requires_grad
+
+     @property
+     def grad(self) -> Optional[torch.Tensor]:
+         return None
+
+     @property
+     def is_leaf(self) -> bool:
+         return True
+
+     @property
+     def shape(self) -> torch.Size:
+         """
+         Shape of the full tensor.
+         """
+         return torch.Size(self._wrapped.shape)
+
+     def size(self, dim: Optional[int] = None) -> Union[int, torch.Size]:
+         sz = self.shape
+         if dim is None:
+             return sz
+         if dim >= len(sz) or dim < 0:
+             raise IndexError(
+                 f"Dimension out of range (expected to be in [0, {len(sz) - 1}], but got {dim})"
+             )
+         return sz[dim]
+
+     def is_contiguous(self):
+         return True
+
+     def is_pinned(self):
+         return False
+
+     @property
+     def dtype(self) -> torch.dtype:
+         if isinstance(self._wrapped, torch.Tensor):
+             return self._wrapped.dtype
+         mapping = {"c10::Half": "half"}
+         dtype_str: str = self._wrapped.dtype_str
+         dtype_str = mapping.get(dtype_str, dtype_str)
+
+         dtype = getattr(torch, dtype_str)
+         assert isinstance(dtype, torch.dtype)
+         return dtype
+
+     @property
+     def device(self) -> torch.device:
+         if isinstance(self._wrapped, torch.Tensor):
+             return self._wrapped.device
+         device_str: str = self._wrapped.device_str
+         device = torch.device(device_str)
+         assert isinstance(device, torch.device)
+         return device
+
+     @property
+     def layout(self) -> torch.layout:
+         if isinstance(self._wrapped, torch.Tensor):
+             return self._wrapped.layout
+         layout_str_mapping = {
+             "SparseCsr": "sparse_csr",
+             "Strided": "strided",
+             "SparseCsc": "sparse_csc",
+             "Jagged": "jagged",
+         }
+         layout_str: str = self._wrapped.layout_str
+         layout_str = layout_str_mapping[layout_str]
+         layout = getattr(torch, layout_str)
+         assert isinstance(layout, torch.layout)
+         return layout
+
+     @property
+     def __class__(self):
+         # this is a hack to avoid an assertion error in torch.nn.Module.register_parameter()
+         return torch.nn.Parameter
+
+     @property
+     def grad_fn(self):
+         return None
+
+     def view(self, *args, **kwargs):
+         return self
+
+     def is_meta(*args, **kwargs):
+         return False
+
+     def copy_(self, src, non_blocking=False):
+         # noop
+         pass
+
+     def numel(self):
+         return torch.tensor(self.shape).prod().item()
+
+     def nelement(self):
+         return torch.tensor(self.shape).prod().item()
+
+     def element_size(self):
+         return torch.tensor([], dtype=self.dtype).element_size()
+
+     def __deepcopy__(self, memo):
+         # torch.classes.fbgemm.KVTensorWrapper doesn't support deepcopy
+         new_obj = PartiallyMaterializedTensor(self._wrapped)
+         memo[id(self)] = new_obj
+         return new_obj
+
+     def required_grad(self) -> bool:
+         return True
+
+     @property
+     def is_quantized(self) -> bool:
+         return False
+
+     @implements(torch.equal)
+     def __eq__(self, tensor1, tensor2, **kwargs):
+         if not isinstance(tensor2, PartiallyMaterializedTensor):
+             return False
+
+         return torch.equal(tensor1.full_tensor(), tensor2.full_tensor())
+
+     def get_kvtensor_serializable_metadata(self) -> list[str]:
+         return self._wrapped.get_kvtensor_serializable_metadata()
+
+     def __hash__(self):
+         return id(self)
+
+     @property
+     def is_mps(self):
+         return False
+
+     @property
+     def is_sparse(self):
+         return False
+
+     @implements(torch.isclose)
+     def isclose(self, tensor1, tensor2, rtol=1e-05, atol=1e-08, equal_nan=False):
+         return torch.isclose(
+             tensor1.full_tensor(),
+             tensor2.full_tensor(),
+             rtol=rtol,
+             atol=atol,
+             equal_nan=equal_nan,
+         )
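
The `implements`/`__torch_function__` pair above is what lets selected `torch.*` calls (`torch.narrow`, `torch.detach`, `torch.equal`, `torch.isclose`) dispatch to the handlers registered on this class. A minimal usage sketch, assuming the wheel is installed; in production the wrapped object is a `torch.classes.fbgemm.KVTensorWrapper` (see the `__init__` docstring), but a plain tensor stands in here, which the `dtype`/`device`/`layout` properties explicitly special-case:

```
import torch

from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
    PartiallyMaterializedTensor,
)

# A plain tensor as a stand-in backing store (normally a KVTensorWrapper).
backing = torch.randn(4096, 128)
pmt = PartiallyMaterializedTensor(backing)

# Load a bounded window of rows instead of materializing the whole table.
window = pmt.narrow(0, 0, 1024)
assert window.shape == (1024, 128)

# full_tensor() is narrow() over the entire first dimension and may OOM
# on a real SSD-backed table.
full = pmt.full_tensor()

# torch.equal routes through __torch_function__ to the handler registered
# with @implements(torch.equal), comparing the fully materialized tensors.
assert torch.equal(pmt, PartiallyMaterializedTensor(backing.clone()))
```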
fbgemm_gpu/tbe/stats/__init__.py
@@ -0,0 +1,10 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ from .bench_params_reporter import TBEBenchmarkParamsReporter  # noqa F401
fbgemm_gpu/tbe/stats/bench_params_reporter.py
@@ -0,0 +1,339 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import io
+ import json
+ import logging
+ import os
+ from typing import Optional
+
+ import fbgemm_gpu  # noqa F401
+ import torch  # usort:skip
+
+ from fbgemm_gpu.tbe.bench.tbe_data_config import (
+     BatchParams,
+     IndicesParams,
+     PoolingParams,
+     TBEDataConfig,
+ )
+
+ open_source: bool = False
+ # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+ open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+
+ if open_source:
+     from fbgemm_gpu.utils import FileStore
+ else:
+     try:
+         from fbgemm_gpu.fb.utils.manifold_wrapper import FileStore
+
+         torch.ops.load_library(
+             "//deeplearning/fbgemm/fbgemm_gpu/src/tbe/eeg:indices_estimator"
+         )
+     except Exception:
+         pass
+
+
+ class TBEBenchmarkParamsReporter:
+     """
+     TBEBenchmarkParamsReporter is responsible for extracting and reporting the
+     configuration data of TBE processes.
+     """
+
+     def __init__(
+         self,
+         report_interval: int,
+         report_iter_start: int = 0,
+         report_iter_end: int = -1,
+         bucket: Optional[str] = None,
+         path_prefix: Optional[str] = None,
+     ) -> None:
+         """
+         Initializes the TBEBenchmarkParamsReporter with the specified parameters.
+
+         Args:
+             report_interval (int): The interval at which reports are generated.
+             report_iter_start (int): The start of the iteration range to capture. Defaults to 0.
+             report_iter_end (int): The end of the iteration range to capture. Defaults to -1 (last iteration).
+             bucket (Optional[str], optional): The storage bucket for reports. Defaults to None.
+             path_prefix (Optional[str], optional): The path prefix for report storage. Defaults to None.
+         """
+         assert report_interval > 0, "report_interval must be greater than 0"
+         assert (
+             report_iter_start >= 0
+         ), "report_iter_start must be greater than or equal to 0"
+         assert (
+             report_iter_end >= -1
+         ), "report_iter_end must be greater than or equal to -1"
+         assert (
+             report_iter_end == -1 or report_iter_start <= report_iter_end
+         ), "report_iter_start must be less than or equal to report_iter_end"
+
+         self.report_interval = report_interval
+         self.report_iter_start = report_iter_start
+         self.report_iter_end = report_iter_end
+
+         if path_prefix is not None and path_prefix.endswith("/"):
+             path_prefix = path_prefix[:-1]
+
+         self.path_prefix = path_prefix
+
+         default_bucket = "/tmp" if open_source else "tlparse_reports"
+         bucket = (
+             bucket
+             if bucket is not None
+             else os.environ.get("FBGEMM_TBE_REPORTING_BUCKET", default_bucket)
+         )
+         self.filestore = FileStore(bucket)
+
+         if self.path_prefix is not None and not self.filestore.exists(self.path_prefix):
+             self.filestore.create_directory(self.path_prefix)
+
+         self.logger: logging.Logger = logging.getLogger(__name__)
+         self.logger.setLevel(logging.INFO)
+
+     @classmethod
+     def create(cls) -> "TBEBenchmarkParamsReporter":
+         """
+         This method returns an instance of TBEBenchmarkParamsReporter based on environment variables.
+
+         If the `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` environment variable is set to a value greater than 0, it creates an instance that:
+         - Reports input parameters (TBEDataConfig).
+         - Writes the output as a JSON file.
+
+         Additionally, the following environment variables are considered:
+         - `FBGEMM_REPORT_INPUT_PARAMS_ITER_START`: Specifies the start of the iteration range to capture.
+         - `FBGEMM_REPORT_INPUT_PARAMS_ITER_END`: Specifies the end of the iteration range to capture.
+         - `FBGEMM_REPORT_INPUT_PARAMS_BUCKET`: Specifies the bucket for reporting.
+         - `FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX`: Specifies the path prefix for reporting.
+
+         Returns:
+             TBEBenchmarkParamsReporter: An instance configured based on the environment variables.
+         """
+         report_interval = int(
+             os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_INTERVAL", "1")
+         )
+         report_iter_start = int(
+             os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_START", "0")
+         )
+         report_iter_end = int(
+             os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_END", "-1")
+         )
+         bucket = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_BUCKET", "")
+         path_prefix = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX", "")
+
+         return cls(
+             report_interval=report_interval,
+             report_iter_start=report_iter_start,
+             report_iter_end=report_iter_end,
+             bucket=bucket,
+             path_prefix=path_prefix,
+         )
+
+     def extract_params(
+         self,
+         feature_rows: torch.Tensor,
+         feature_dims: torch.Tensor,
+         indices: torch.Tensor,
+         offsets: torch.Tensor,
+         per_sample_weights: Optional[torch.Tensor] = None,
+         batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+     ) -> TBEDataConfig:
+         """
+         Extracts parameters from the embedding operation, input indices, and offsets to create a TBEDataConfig.
+
+         Args:
+             feature_rows (torch.Tensor): Number of rows in each feature.
+             feature_dims (torch.Tensor): Number of dimensions in each feature.
+             indices (torch.Tensor): The input indices tensor.
+             offsets (torch.Tensor): The input offsets tensor.
+             per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
+             batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None.
+
+         Returns:
+             TBEDataConfig: The configuration data for TBE benchmarking.
+         """
+         Es = feature_rows.tolist()
+         Ds = feature_dims.tolist()
+
+         assert len(Es) == len(
+             Ds
+         ), "feature_rows and feature_dims must have the same length"
+
+         # Transfer indices back to CPU for EEG analysis
+         indices_cpu = indices.cpu()
+
+         # Set T to be the number of features we are looking at
+         T = len(Ds)
+         # Set E to be the mean of the rowcounts to avoid biasing
+         E = (
+             Es[0]
+             if len(set(Es)) == 1
+             else torch.ceil(
+                 torch.mean(torch.tensor(feature_rows, dtype=torch.float))
+             ).item()
+         )
+         # Set mixed_dim to be True if there are multiple dims
+         mixed_dim = len(set(Ds)) > 1
+         # Set D to be the mean of the dims to avoid biasing
+         D = (
+             Ds[0]
+             if not mixed_dim
+             else torch.ceil(
+                 torch.mean(torch.tensor(feature_dims, dtype=torch.float))
+             ).item()
+         )
+
+         # Compute indices distribution parameters
+         heavy_hitters, q, s, _, _ = torch.ops.fbgemm.tbe_estimate_indices_distribution(
+             indices_cpu
+         )
+         indices_params = IndicesParams(
+             heavy_hitters, q, s, indices.dtype, offsets.dtype
+         )
+
+         # Compute batch parameters
+         batch_params = BatchParams(
+             B=int((offsets.numel() - 1) // T),
+             sigma_B=(
+                 int(
+                     torch.ceil(
+                         torch.std(
+                             torch.tensor(
+                                 [
+                                     b
+                                     for bs in batch_size_per_feature_per_rank
+                                     for b in bs
+                                 ]
+                             ).float()
+                         )
+                     )
+                 )
+                 if batch_size_per_feature_per_rank
+                 else None
+             ),
+             vbe_distribution=("normal" if batch_size_per_feature_per_rank else None),
+             vbe_num_ranks=(
+                 len(batch_size_per_feature_per_rank)
+                 if batch_size_per_feature_per_rank
+                 else None
+             ),
+         )
+
+         # Compute pooling parameters
+         bag_sizes = offsets[1:] - offsets[:-1]
+         mixed_bag_sizes = len(set(bag_sizes)) > 1
+         pooling_params = PoolingParams(
+             L=(
+                 int(torch.ceil(torch.mean(bag_sizes.float())))
+                 if mixed_bag_sizes
+                 else int(bag_sizes[0])
+             ),
+             sigma_L=(
+                 int(torch.ceil(torch.std(bag_sizes.float())))
+                 if mixed_bag_sizes
+                 else None
+             ),
+             length_distribution=("normal" if mixed_bag_sizes else None),
+         )
+
+         return TBEDataConfig(
+             T=T,
+             E=E,
+             D=D,
+             mixed_dim=mixed_dim,
+             weighted=(per_sample_weights is not None),
+             batch_params=batch_params,
+             indices_params=indices_params,
+             pooling_params=pooling_params,
+             use_cpu=(not torch.cuda.is_available()),
+         )
+
+     def report_stats(
+         self,
+         feature_rows: torch.Tensor,
+         feature_dims: torch.Tensor,
+         iteration: int,
+         indices: torch.Tensor,
+         offsets: torch.Tensor,
+         op_id: str = "",
+         per_sample_weights: Optional[torch.Tensor] = None,
+         batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+     ) -> None:
+         """
+         Reports the configuration of the embedding operation and input data, then writes the TBE configuration to the filestore.
+
+         Args:
+             feature_rows (torch.Tensor): Number of rows in each feature.
+             feature_dims (torch.Tensor): Number of dimensions in each feature.
+             iteration (int): The current iteration number.
+             indices (torch.Tensor): The input indices tensor.
+             offsets (torch.Tensor): The input offsets tensor.
+             op_id (str, optional): The operation identifier. Defaults to an empty string.
+             per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None.
+             batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None.
+         """
+         if (
+             (iteration - self.report_iter_start) % self.report_interval == 0
+             and (iteration >= self.report_iter_start)
+             and (self.report_iter_end == -1 or iteration <= self.report_iter_end)
+         ):
+             # If the indices tensor is empty (indices.numel() == 0), skip reporting
+             # TODO: Remove this once we have a better way to handle empty indices tensors
+             if indices.numel() == 0:
+                 return
+
+             # Extract TBE config
+             config = self.extract_params(
+                 feature_rows=feature_rows,
+                 feature_dims=feature_dims,
+                 indices=indices,
+                 offsets=offsets,
+                 per_sample_weights=per_sample_weights,
+                 batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+             )
+
+             # Ad-hoc fix for adding Es and Ds to the JSON output
+             # TODO: Remove this once Es and Ds have been moved into TBEDataConfig
+             adhoc_config = config.dict()
+             adhoc_config["Es"] = feature_rows.tolist()
+             adhoc_config["Ds"] = feature_dims.tolist()
+             if batch_size_per_feature_per_rank:
+                 adhoc_config["Bs"] = [
+                     sum(batch_size_per_feature_per_rank[f])
+                     for f in range(len(adhoc_config["Es"]))
+                 ]
+
+             bag_sizes = (offsets[1:] - offsets[:-1]).tolist()
+             adhoc_config["Ls"] = []
+             pointer_counter = 0
+             if batch_size_per_feature_per_rank:
+                 for batch_size in adhoc_config["Bs"]:
+                     current_L = 0
+                     for _i in range(batch_size):
+                         current_L += bag_sizes[pointer_counter]
+                         pointer_counter += 1
+                     adhoc_config["Ls"].append(current_L / batch_size)
+             else:
+                 batch_size = int(len(bag_sizes) // len(adhoc_config["Es"]))
+
+                 for _j in range(len(adhoc_config["Es"])):
+                     current_L = 0
+                     for _i in range(batch_size):
+                         current_L += bag_sizes[pointer_counter]
+                         pointer_counter += 1
+                     adhoc_config["Ls"].append(current_L / batch_size)
+
+             # Write the TBE config to the FileStore
+             self.filestore.write(
+                 f"{self.path_prefix}/tbe-{op_id}-config-estimation-{iteration}.json",
+                 io.BytesIO(json.dumps(adhoc_config, indent=2).encode()),
+             )
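
A sketch of how the reporter is driven, based on the `create()` classmethod and environment variables above. The variable values, table shapes, and batch below are illustrative placeholders, and this assumes the `tbe_estimate_indices_distribution` op and the open-source `FileStore` backend are available in the installed build:

```
import os

import torch

from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter

# The FBGEMM_REPORT_INPUT_PARAMS_* variables are the ones create() reads;
# the values here are illustrative placeholders, not shipped defaults.
os.environ["FBGEMM_REPORT_INPUT_PARAMS_INTERVAL"] = "100"
os.environ["FBGEMM_REPORT_INPUT_PARAMS_BUCKET"] = "/tmp"
os.environ["FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX"] = "tbe_reports"

reporter = TBEBenchmarkParamsReporter.create()

# Two hypothetical tables (1000 and 2000 rows, dim 128 each); offsets has
# T * B + 1 entries, so this describes B = 2 bags per table. On a matching
# iteration this writes tbe-tbe0-config-estimation-100.json to the filestore.
reporter.report_stats(
    feature_rows=torch.tensor([1000, 2000]),
    feature_dims=torch.tensor([128, 128]),
    iteration=100,
    indices=torch.tensor([1, 2, 3, 4], dtype=torch.long),
    offsets=torch.tensor([0, 2, 3, 3, 4], dtype=torch.long),
    op_id="tbe0",
)
```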
fbgemm_gpu/tbe/utils/__init__.py
@@ -0,0 +1,13 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ from .common import get_device, round_up, to_device  # noqa: F401
+ from .offsets import b_indices, get_table_batched_offsets_from_dense  # noqa: F401
+ from .quantize import dequantize_embs, fake_quantize_embs, quantize_embs  # noqa: F401
+ from .requests import generate_requests, TBERequest  # noqa: F401
fbgemm_gpu/tbe/utils/common.py
@@ -0,0 +1,42 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ from typing import TypeVar
+
+ import torch
+
+ Deviceable = TypeVar(
+     "Deviceable", torch.nn.EmbeddingBag, torch.nn.Embedding, torch.Tensor
+ )
+
+
+ def round_up(a: int, b: int) -> int:
+     return int((a + b - 1) // b) * b
+
+
+ def get_device() -> torch.device:
+     if torch.cuda.is_available():
+         # pyre-fixme[7]: Expected `device` but got `Union[int, device]`.
+         return torch.cuda.current_device()
+     elif torch.mtia.is_available():
+         # pyre-fixme[7]: Expected `device` but got `Union[int, device]`.
+         return torch.mtia.current_device()
+     else:
+         return torch.device("cpu")
+
+
+ def to_device(t: Deviceable, use_cpu: bool) -> Deviceable:
+     if use_cpu:
+         # pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor, torch.nn.EmbeddingBag]`.
+         return t.cpu()
+     elif torch.cuda.is_available():
+         # pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor, torch.nn.EmbeddingBag]`.
+         return t.cuda()
+     else:
+         # pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor, torch.nn.EmbeddingBag]`.
+         return t.to(device="mtia")
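
A short sketch of the helpers in `common.py`; the values are illustrative:

```
import torch

from fbgemm_gpu.tbe.utils import get_device, round_up, to_device

# round_up pads `a` to the next multiple of `b` (a no-op when already aligned),
# e.g. for aligning embedding dimensions to a vector-friendly size.
assert round_up(10, 4) == 12
assert round_up(12, 4) == 12

# to_device prefers CPU when requested, then CUDA, then falls back to MTIA.
table = torch.nn.EmbeddingBag(num_embeddings=100, embedding_dim=16, mode="sum")
table = to_device(table, use_cpu=not torch.cuda.is_available())

# get_device() mirrors the same preference order; note that on CUDA it
# returns the current device index rather than a torch.device (see the
# pyre-fixme annotations above).
dev = get_device()
```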