fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/tbe/ssd/inference.py
@@ -0,0 +1,586 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ # pyre-ignore-all-errors[56]
+
+ import itertools
+ import logging
+ import os
+ import tempfile
+ from math import log2
+ from typing import Optional
+
+ import torch  # usort:skip
+
+ from fbgemm_gpu.split_embedding_configs import SparseType
+ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+     CacheAlgorithm,
+     DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
+     EmbeddingLocation,
+     PoolingMode,
+ )
+ from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
+     align_to_cacheline,
+     rounded_row_size_in_bytes,
+     unpadded_row_size_in_bytes,
+ )
+
+ from torch import distributed as dist, nn, Tensor  # usort:skip
+ from torch.autograd.profiler import record_function
+
+ from .common import ASSOC
+
+
+ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
+     """
+     SSD Table-batched version of nn.EmbeddingBag(sparse=False)
+     Inference version, with FP32/FP16/FP8/INT8/INT4/INT2 supports
+     """
+
+     embedding_specs: list[tuple[str, int, int, SparseType]]
+     _local_instance_index: int = -1
+
+     def __init__(
+         self,
+         embedding_specs: list[
+             tuple[str, int, int, SparseType]
+         ],  # tuple of (feature_names, rows, dims, SparseType)
+         feature_table_map: Optional[list[int]] = None,  # [T]
+         pooling_mode: PoolingMode = PoolingMode.SUM,
+         output_dtype: SparseType = SparseType.FP16,
+         row_alignment: Optional[int] = None,
+         fp8_exponent_bits: Optional[int] = None,
+         fp8_exponent_bias: Optional[int] = None,
+         cache_assoc: int = 32,
+         scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
+         cache_sets: int = 0,
+         ssd_storage_directory: str = "/tmp",
+         ssd_shards: int = 1,
+         ssd_memtable_flush_period: int = -1,
+         ssd_memtable_flush_offset: int = -1,
+         ssd_l0_files_per_compact: int = 4,
+         ssd_rate_limit_mbps: int = 0,
+         ssd_size_ratio: int = 10,
+         ssd_compaction_trigger: int = 8,
+         ssd_write_buffer_size: int = 2 * 1024 * 1024 * 1024,
+         ssd_max_write_buffer_num: int = 16,
+         ssd_cache_location: EmbeddingLocation = EmbeddingLocation.MANAGED,
+         ssd_uniform_init_lower: float = -0.01,
+         ssd_uniform_init_upper: float = 0.01,
+         # Parameter Server Configs
+         ps_hosts: Optional[tuple[tuple[str, int]]] = None,
+         ps_max_key_per_request: Optional[int] = None,
+         ps_client_thread_num: Optional[int] = None,
+         ps_max_local_index_length: Optional[int] = None,
+         tbe_unique_id: int = -1,  # unique id for this embedding, if not set, will derive based on current rank and tbe index id
+     ) -> None:  # noqa C901  # tuple of (rows, dims,)
+         super(SSDIntNBitTableBatchedEmbeddingBags, self).__init__()
+
+         assert cache_assoc == 32, "Only 32-way cache is supported now"
+
+         self.scale_bias_size_in_bytes = scale_bias_size_in_bytes
+         self.pooling_mode = pooling_mode
+         self.embedding_specs = embedding_specs
+         T_ = len(self.embedding_specs)
+         assert T_ > 0
+         device = torch.cuda.current_device()
+         if device is None:
+             self.current_device: torch.device = torch.device(
+                 torch.cuda.current_device()
+             )
+         elif isinstance(device, torch.device):
+             self.current_device = device
+         else:
+             self.current_device = torch.device(device)
+         self.use_cpu: bool = self.current_device.type == "cpu"
+
+         self.feature_table_map: list[int] = (
+             feature_table_map if feature_table_map is not None else list(range(T_))
+         )
+         T = len(self.feature_table_map)
+         assert T_ <= T
+         table_has_feature = [False] * T_
+         for t in self.feature_table_map:
+             table_has_feature[t] = True
+         assert all(table_has_feature), "Each table must have at least one feature!"
+
+         self.output_dtype: int = output_dtype.as_int()
+         # (feature_names, rows, dims, weights_tys) = zip(*embedding_specs)
+         # Pyre workaround
+         rows: list[int] = [e[1] for e in embedding_specs]
+         dims: list[int] = [e[2] for e in embedding_specs]
+         weights_tys: list[SparseType] = [e[3] for e in embedding_specs]
+
+         D_offsets = [dims[t] for t in self.feature_table_map]
+         D_offsets = [0] + list(itertools.accumulate(D_offsets))
+         self.total_D: int = D_offsets[-1]
+         self.register_buffer(
+             "D_offsets",
+             torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
+         )
+
+         if row_alignment is None:
+             self.row_alignment: int = 1 if self.use_cpu else 16
+         else:
+             self.row_alignment = row_alignment
+
+         for dim, weight_ty in zip(dims, weights_tys):
+             if not weight_ty.is_float():
+                 assert (
+                     dim % (8 / weight_ty.bit_rate()) == 0
+                 ), f"For quantized types we need to at least pack at byte granularity, dim: {dim}, weight_ty: {weight_ty}"
+
+         def max_ty_D(ty: SparseType) -> int:
+             return max(
+                 [dim for dim, weight_ty in zip(dims, weights_tys) if weight_ty == ty],
+                 default=0,
+             )
+
+         self.max_int2_D: int = max_ty_D(SparseType.INT2)
+         self.max_int4_D: int = max_ty_D(SparseType.INT4)
+         self.max_int8_D: int = max_ty_D(SparseType.INT8)
+         self.max_float8_D: int = max_ty_D(SparseType.FP8)
+         self.max_float16_D: int = max_ty_D(SparseType.FP16)
+         self.max_float32_D: int = max_ty_D(SparseType.FP32)
+
+         cached_dims = [
+             rounded_row_size_in_bytes(
+                 embedding_spec[2], embedding_spec[3], 16, self.scale_bias_size_in_bytes
+             )
+             for embedding_spec in self.embedding_specs
+         ]
+         self.max_D_cache: int = max(cached_dims) if len(cached_dims) > 0 else 0
+
+         placements = []
+         offsets = []
+         uvm_size = 0
+         for _, num_embeddings, embedding_dim, weight_ty in embedding_specs:
+             embedding_dim = rounded_row_size_in_bytes(
+                 embedding_dim, weight_ty, self.row_alignment, scale_bias_size_in_bytes
+             )
+             state_size = num_embeddings * embedding_dim
+             state_size = align_to_cacheline(state_size)
+             placements.append(EmbeddingLocation.MANAGED_CACHING)
+             offsets.append(uvm_size)
+             uvm_size += state_size
+
+         self.weights_physical_offsets: list[int] = offsets
+
+         weights_tys_int = [weights_tys[t].as_int() for t in self.feature_table_map]
+         self.register_buffer(
+             "weights_tys",
+             torch.tensor(
+                 weights_tys_int, device=self.current_device, dtype=torch.uint8
+             ),
+         )
+         self.weight_initialized: bool = True
+
+         assert self.D_offsets.numel() == T + 1
+         hash_size_cumsum = [0] + list(itertools.accumulate(rows))
+         if hash_size_cumsum[-1] == 0:
+             self.total_hash_size_bits: int = 0
+         else:
+             self.total_hash_size_bits: int = int(log2(float(hash_size_cumsum[-1])) + 1)
+         # The last element is to easily access # of rows of each table by
+         self.total_hash_size_bits = int(log2(float(hash_size_cumsum[-1])) + 1)
+         self.total_hash_size: int = hash_size_cumsum[-1]
+         # The last element is to easily access # of rows of each table by
+         # hash_size_cumsum[t + 1] - hash_size_cumsum[t]
+         hash_size_cumsum = [hash_size_cumsum[t] for t in self.feature_table_map] + [
+             hash_size_cumsum[-1]
+         ]
+         self.register_buffer(
+             "hash_size_cumsum",
+             torch.tensor(
+                 hash_size_cumsum, device=self.current_device, dtype=torch.int64
+             ),
+         )
+         assert cache_sets > 0
+         element_size = 1
+         cache_size = cache_sets * ASSOC * element_size * self.max_D_cache
+         logging.info(
+             f"Using cache for SSD with admission algorithm "
+             f"{CacheAlgorithm.LRU}, {cache_sets} sets, stored on {'DEVICE' if ssd_cache_location is EmbeddingLocation.DEVICE else 'MANAGED'} with {ssd_shards} shards, "
+             f"SSD storage directory: {ssd_storage_directory}, "
+             f"Memtable Flush Period: {ssd_memtable_flush_period}, "
+             f"Memtable Flush Offset: {ssd_memtable_flush_offset}, "
+             f"Desired L0 files per compaction: {ssd_l0_files_per_compact}, "
+             f"{cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
+             f"output dtype: {output_dtype}"
+         )
+         self.register_buffer(
+             "lxu_cache_state",
+             torch.zeros(cache_sets, ASSOC, dtype=torch.int64).fill_(-1),
+         )
+         self.register_buffer(
+             "lru_state", torch.zeros(cache_sets, ASSOC, dtype=torch.int64)
+         )
+
+         assert ssd_cache_location in (
+             EmbeddingLocation.MANAGED,
+             EmbeddingLocation.DEVICE,
+         )
+         if ssd_cache_location == EmbeddingLocation.MANAGED:
+             self.register_buffer(
+                 "lxu_cache_weights",
+                 torch.ops.fbgemm.new_managed_tensor(
+                     torch.zeros(1, device=self.current_device, dtype=torch.uint8),
+                     [cache_sets * ASSOC, self.max_D_cache],
+                 ),
+             )
+         else:
+             self.register_buffer(
+                 "lxu_cache_weights",
+                 torch.zeros(
+                     cache_sets * ASSOC,
+                     self.max_D_cache,
+                     device=self.current_device,
+                     dtype=torch.uint8,
+                 ),
+             )
+
+         assert (
+             cache_size
+             == self.lxu_cache_weights.numel()
+             * self.lxu_cache_weights.element_size()
+         ), "The precomputed cache_size does not match the actual cache size"
+
+         os.makedirs(ssd_storage_directory, exist_ok=True)
+
+         ssd_directory = tempfile.mkdtemp(
+             prefix="ssd_table_batched_embeddings", dir=ssd_storage_directory
+         )
+         if not ps_hosts:
+             # pyre-fixme[4]: Attribute must be annotated.
+             # pyre-ignore[16]
+             self.ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
+                 ssd_directory,
+                 ssd_shards,
+                 ssd_shards,
+                 ssd_memtable_flush_period,
+                 ssd_memtable_flush_offset,
+                 ssd_l0_files_per_compact,
+                 self.max_D_cache,
+                 ssd_rate_limit_mbps,
+                 ssd_size_ratio,
+                 ssd_compaction_trigger,
+                 ssd_write_buffer_size,
+                 ssd_max_write_buffer_num,
+                 ssd_uniform_init_lower,
+                 ssd_uniform_init_upper,
+                 8,  # row_storage_bitwidth
+                 0,  # ssd_block_cache_size
+             )
+         else:
+             # create tbe unique id using rank index | pooling mode
+             if tbe_unique_id == -1:
+                 SSDIntNBitTableBatchedEmbeddingBags._local_instance_index += 1
+                 assert (
+                     SSDIntNBitTableBatchedEmbeddingBags._local_instance_index < 8
+                 ), f"{SSDIntNBitTableBatchedEmbeddingBags._local_instance_index}, more than 8 TBE instance is created in one rank, the tbe unique id won't be unique in this case."
+                 tbe_unique_id = (
+                     dist.get_rank() << 3
+                     | SSDIntNBitTableBatchedEmbeddingBags._local_instance_index
+                 )
+             logging.info(f"tbe_unique_id: {tbe_unique_id}")
+             # pyre-fixme[4]: Attribute must be annotated.
+             # pyre-ignore[16]
+             self.ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
+                 [host[0] for host in ps_hosts],
+                 [host[1] for host in ps_hosts],
+                 tbe_unique_id,
+                 (
+                     ps_max_local_index_length
+                     if ps_max_local_index_length is not None
+                     else 54
+                 ),
+                 ps_client_thread_num if ps_client_thread_num is not None else 32,
+                 ps_max_key_per_request if ps_max_key_per_request is not None else 500,
+                 0,  # ssd_block_cache_size
+                 self.max_D_cache,
+             )
+
+         # pyre-fixme[20]: Argument `self` expected.
+         (low_priority, high_priority) = torch.cuda.Stream.priority_range()
+         self.ssd_stream = torch.cuda.Stream(priority=low_priority)
+         self.ssd_set_start = torch.cuda.Event()
+         self.ssd_set_end = torch.cuda.Event()
+
+         # pyre-fixme[4]: Attribute must be annotated.
+         # pyre-ignore[16]
+         self.timestep_counter = torch.classes.fbgemm.AtomicCounter()
+         # pyre-fixme[4]: Attribute must be annotated.
+         # pyre-ignore[16]
+         self.timestep_prefetch_size = torch.classes.fbgemm.AtomicCounter()
+
+         self.weights_dev: torch.Tensor = torch.empty(
+             0,
+             device=self.current_device,
+             dtype=torch.uint8,
+         )
+         self.register_buffer(
+             "weights_uvm",
+             torch.tensor((0,), device=self.current_device, dtype=torch.uint8),
+         )
+         self.register_buffer(
+             "weights_host",
+             torch.empty(0),
+         )
+
+         self.register_buffer(
+             "weights_placements",
+             torch.tensor(
+                 [EmbeddingLocation.MANAGED_CACHING for _ in range(T_)],
+                 dtype=torch.int32,
+             ),
+         )
+         weights_offsets = [0] + list(
+             itertools.accumulate([row * dim for (row, dim) in zip(rows, dims)])
+         )
+         self.register_buffer(
+             "weights_offsets",
+             torch.tensor(
+                 weights_offsets[:-1],
+                 device=self.current_device,
+                 dtype=torch.int64,
+             ),
+         )
+
+         if self.max_float8_D > 0:
+             default_config = SparseType.FP8.default_config()
+             self.fp8_exponent_bits: int = (
+                 default_config.get("exponent_bits")
+                 if fp8_exponent_bits is None
+                 else fp8_exponent_bits
+             )
+             self.fp8_exponent_bias: int = (
+                 default_config.get("exponent_bias")
+                 if fp8_exponent_bias is None
+                 else fp8_exponent_bias
+             )
+         else:
+             self.fp8_exponent_bits = -1
+             self.fp8_exponent_bias = -1
+
+     @torch.jit.export
+     def prefetch(self, indices: Tensor, offsets: Tensor) -> Tensor:
+         (indices, offsets) = indices.long(), offsets.long()
+         linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
+             self.hash_size_cumsum,
+             indices,
+             offsets,
+         )
+         self.timestep_counter.increment()
+         self.timestep_prefetch_size.increment()
+         (
+             inserted_indices,
+             evicted_indices,
+             assigned_cache_slots,
+             actions_count_gpu,
+             _,
+             _,
+             _,
+             _,
+         ) = torch.ops.fbgemm.ssd_cache_populate_actions(
+             linear_cache_indices,
+             self.total_hash_size,
+             self.lxu_cache_state,
+             self.timestep_counter.get(),
+             1,  # for now assume prefetch_dist == 1
+             self.lru_state,
+         )
+         actions_count_cpu = torch.empty(
+             actions_count_gpu.shape, pin_memory=True, dtype=actions_count_gpu.dtype
+         )
+         actions_count_cpu.copy_(actions_count_gpu, non_blocking=True)
+         assigned_cache_slots = assigned_cache_slots.long()
+         evicted_rows = self.lxu_cache_weights[
+             assigned_cache_slots.clamp_(min=0).long(), :
+         ]
+         inserted_rows = torch.empty(
+             evicted_rows.shape,
+             dtype=self.lxu_cache_weights.dtype,
+             pin_memory=True,
+         )
+
+         current_stream = torch.cuda.current_stream()
+
+         # Ensure the previous iterations l3_db.set(..) has completed.
+         current_stream.wait_event(self.ssd_set_end)
+         inserted_indices_cpu = torch.empty(
+             inserted_indices.shape, pin_memory=True, dtype=inserted_indices.dtype
+         )
+         inserted_indices_cpu.copy_(inserted_indices, non_blocking=True)
+         self.ssd_db.get_cuda(
+             inserted_indices_cpu,
+             inserted_rows,
+             actions_count_cpu,
+         )
+         current_stream.record_event(self.ssd_set_start)
+         # TODO: T123943415 T123943414 this is a big copy that is (mostly) unnecessary with a decent cache hit rate.
+         # Should we allocate on HBM?
+         inserted_rows_gpu = inserted_rows.to(self.current_device, non_blocking=True)
+
+         # self.lxu_cache_weights[assigned_cache_slots, :] = inserted_rows.cuda(non_blocking=True)
+         torch.ops.fbgemm.masked_index_put(
+             self.lxu_cache_weights,
+             assigned_cache_slots,
+             inserted_rows_gpu,
+             actions_count_gpu,
+         )
+
+         with torch.cuda.stream(self.ssd_stream):
+             self.ssd_stream.wait_event(self.ssd_set_start)
+             evicted_rows_cpu = torch.empty(
+                 evicted_rows.shape, pin_memory=True, dtype=evicted_rows.dtype
+             )
+             evicted_rows_cpu.copy_(evicted_rows, non_blocking=True)
+             evicted_indices_cpu = torch.empty(
+                 evicted_indices.shape, pin_memory=True, dtype=evicted_indices.dtype
+             )
+             evicted_indices_cpu.copy_(evicted_indices, non_blocking=True)
+             evicted_rows.record_stream(self.ssd_stream)
+             evicted_indices.record_stream(self.ssd_stream)
+             self.ssd_db.set_cuda(
+                 evicted_indices_cpu,
+                 evicted_rows_cpu,
+                 actions_count_cpu,
+                 self.timestep_counter.get(),
+             )
+             # TODO: is this needed?
+             # Need a way to synchronize
+             # actions_count_cpu.record_stream(self.ssd_stream)
+             self.ssd_stream.record_event(self.ssd_set_end)
+         return linear_cache_indices
+
+     def forward(
+         self,
+         indices: Tensor,
+         offsets: Tensor,
+         per_sample_weights: Optional[Tensor] = None,
+     ) -> Tensor:
+         if self.timestep_prefetch_size.get() <= 0:
+             with record_function("## prefetch ##"):
+                 linear_cache_indices = self.prefetch(indices, offsets)
+         else:
+             linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
+                 self.hash_size_cumsum,
+                 indices,
+                 offsets,
+             )
+         lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
+             linear_cache_indices,
+             self.lxu_cache_state,
+             self.total_hash_size,
+         )
+
+         self.timestep_prefetch_size.decrement()
+
+         assert (
+             self.weight_initialized
+         ), "weight needs to be initialized before forward function"
+
+         # Note: CPU and CUDA ops use the same interface to facilitate JIT IR
+         # generation for CUDA/CPU. For CPU op, we don't need weights_uvm and
+         # weights_placements
+         return torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(
+             dev_weights=self.weights_dev,
+             uvm_weights=self.weights_uvm,
+             weights_placements=self.weights_placements,
+             weights_offsets=self.weights_offsets,
+             weights_tys=self.weights_tys,
+             D_offsets=self.D_offsets,
+             total_D=self.total_D,
+             max_int2_D=self.max_int2_D,
+             max_int4_D=self.max_int4_D,
+             max_int8_D=self.max_int8_D,
+             max_float16_D=self.max_float16_D,
+             max_float32_D=self.max_float32_D,
+             indices=indices,
+             offsets=offsets,
+             pooling_mode=int(self.pooling_mode),
+             indice_weights=per_sample_weights,
+             output_dtype=self.output_dtype,
+             lxu_cache_weights=self.lxu_cache_weights,
+             lxu_cache_locations=lxu_cache_locations,
+             row_alignment=self.row_alignment,
+             max_float8_D=self.max_float8_D,
+             fp8_exponent_bits=self.fp8_exponent_bits,
+             fp8_exponent_bias=self.fp8_exponent_bias,
+         )
+
+     @torch.jit.export
+     def split_embedding_weights(
+         self, split_scale_shifts: bool = True
+     ) -> list[tuple[Tensor, Optional[Tensor]]]:
+         """
+         Returns a list of weights, split by table.
+
+         Testing only, very slow.
+         """
+         splits: list[tuple[Tensor, Optional[Tensor]]] = []
+         rows_cumsum = 0
+         for _, row, dim, weight_ty in self.embedding_specs:
+             weights = torch.empty(
+                 (
+                     row,
+                     rounded_row_size_in_bytes(
+                         dim,
+                         weight_ty,
+                         self.row_alignment,
+                         self.scale_bias_size_in_bytes,
+                     ),
+                 ),
+                 dtype=torch.uint8,
+             )
+             self.ssd_db.get_cuda(
+                 torch.arange(rows_cumsum, rows_cumsum + row).to(torch.int64),
+                 weights,
+                 torch.as_tensor([row]),
+             )
+             rows_cumsum += row
+             torch.cuda.synchronize(self.current_device)
+
+             weights_shifts = weights.detach()
+
+             if split_scale_shifts:
+                 # remove the padding at the end of each row.
+                 weights_shifts = weights_shifts[
+                     :,
+                     : unpadded_row_size_in_bytes(
+                         dim, weight_ty, self.scale_bias_size_in_bytes
+                     ),
+                 ]
+                 if (
+                     weight_ty == SparseType.INT8
+                     or weight_ty == SparseType.INT4
+                     or weight_ty == SparseType.INT2
+                 ):
+                     splits.append(
+                         (
+                             weights_shifts[:, self.scale_bias_size_in_bytes :],
+                             weights_shifts[:, : self.scale_bias_size_in_bytes],
+                         )
+                     )
+                 else:
+                     assert (
+                         weight_ty == SparseType.FP8
+                         or weight_ty == SparseType.FP16
+                         or weight_ty == SparseType.FP32
+                     )
+                     splits.append(
+                         (
+                             weights_shifts,
+                             None,
+                         )
+                     )
+             else:
+                 splits.append((weights_shifts, None))
+
+         torch.cuda.synchronize(self.current_device)
+         return splits
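
For orientation, here is a minimal usage sketch of the SSDIntNBitTableBatchedEmbeddingBags module added by the diff above. The table specs, batch layout, cache_sets value, and the fbgemm_gpu.tbe.ssd import path are illustrative assumptions rather than values taken from this package; the sketch assumes a CUDA-capable host with the wheel installed and simply exercises the constructor and forward() path shown in the diff.

import torch

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
# Assumed import path; the class itself is defined in fbgemm_gpu/tbe/ssd/inference.py.
from fbgemm_gpu.tbe.ssd import SSDIntNBitTableBatchedEmbeddingBags

# Two hypothetical INT4-quantized tables, SUM-pooled, read back as FP16.
tbe = SSDIntNBitTableBatchedEmbeddingBags(
    embedding_specs=[
        ("table_a", 1_000_000, 128, SparseType.INT4),
        ("table_b", 500_000, 64, SparseType.INT4),
    ],
    pooling_mode=PoolingMode.SUM,
    output_dtype=SparseType.FP16,
    cache_sets=1024,  # must be > 0; sizes the LXU cache (cache_sets x 32 slots)
    ssd_storage_directory="/tmp",  # a RocksDB tempdir is created under this path
)

# CSR-style batch: 2 samples x 2 features -> 4 bags, so offsets has B*T + 1 entries.
indices = torch.tensor([1, 2, 3, 10, 11, 20], dtype=torch.int64, device="cuda")
offsets = torch.tensor([0, 2, 3, 5, 6], dtype=torch.int64, device="cuda")

# forward() prefetches the needed rows from the SSD-backed store into the LXU
# cache, then runs the int-nbit lookup kernel; the pooled output has shape
# [B, total_D], i.e. [2, 192] for this configuration.
pooled = tbe(indices, offsets)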