fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
Diff content is shown below for three of the added files. The first hunk (@@ -0,0 +1,556 @@) is fbgemm_gpu/tbe/utils/requests.py, a generator of synthetic table-batched embedding (TBE) lookup requests:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from dataclasses import dataclass
from typing import Optional

import numpy as np
import numpy.typing as npt
import torch

# pyre-fixme[21]: Could not find name `default_rng` in `numpy.random` (stubbed).
from numpy.random import default_rng

from .common import get_device
from .offsets import get_table_batched_offsets_from_dense

logging.basicConfig(level=logging.DEBUG)


@dataclass
class TBERequest:
    """
    `generate_requests`'s output wrapper
    """

    indices: torch.Tensor
    offsets: torch.Tensor
    per_sample_weights: Optional[torch.Tensor] = None
    Bs_per_feature_per_rank: Optional[list[list[int]]] = None

    def unpack_2(self) -> tuple[torch.Tensor, torch.Tensor]:
        return (self.indices, self.offsets)

    def unpack_3(
        self,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        return (self.indices, self.offsets, self.per_sample_weights)

    def unpack_4(
        self,
    ) -> tuple[
        torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[list[list[int]]]
    ]:
        return (
            self.indices,
            self.offsets,
            self.per_sample_weights,
            self.Bs_per_feature_per_rank,
        )


def generate_requests_from_data_file(
    iters: int,
    B: int,
    T: int,
    L: int,
    E: int,
    weighted: bool,
    requests_data_file: Optional[str] = None,
    indices_file: Optional[str] = None,
    offsets_file: Optional[str] = None,
    tables: Optional[str] = None,
    index_dtype: Optional[torch.dtype] = None,
    offset_dtype: Optional[torch.dtype] = None,
) -> list[TBERequest]:
    """
    Generate TBE requests from the input data file. If `requests_data_file` is provided,
    `indices_file` and `offsets_file` should not be provided. If either `indices_file`
    or `offsets_file` is provided, both must be provided.
    """
    assert not (
        requests_data_file and (indices_file or offsets_file)
    ), "If requests_data_file is provided, indices_file and offsets_file cannot be provided."
    assert (
        indices_file and offsets_file
    ), "Both indices_file and offsets_file must be provided if either is provided."

    if requests_data_file:
        indices_tensor, offsets_tensor, *rest = torch.load(requests_data_file)
    else:
        indices_tensor = torch.load(indices_file)
        offsets_tensor = torch.load(offsets_file)

    average_L = 0
    if tables is not None:
        emb_tables = tuple(int(x) for x in tables.split(","))
        indices = torch.zeros(0, dtype=indices_tensor.dtype)
        offsets = torch.zeros(1, dtype=offsets_tensor.dtype)
        total_L = 0
        for t in emb_tables:
            t_offsets = offsets_tensor[B * t : B * (t + 1) + 1]
            total_L += t_offsets[-1] - t_offsets[0]
            indices = torch.cat((indices, indices_tensor[t_offsets[0] : t_offsets[-1]]))
            offsets = torch.cat(
                (
                    offsets,
                    t_offsets[1:] - t_offsets[0] + offsets[-1],
                )
            )
        indices_tensor = indices
        offsets_tensor = offsets
        average_L = int(total_L / B)

        assert np.prod(offsets_tensor.size()) - 1 == np.prod((T, B)), (
            f"Requested tables: {emb_tables} "
            f"does not conform to inputs (T, B) = ({T}, {B})."
        )
        logging.warning(
            f"Using (indices = {indices_tensor.size()}, offsets = {offsets_tensor.size()}) based "
            f"on tables: {emb_tables}"
        )
    else:
        average_L = int((offsets_tensor[-1] - offsets_tensor[0]) / B)
        assert (np.prod(offsets_tensor.size()) - 1) == np.prod((T, B)), (
            f"Data file (indices = {indices_tensor.size()}, "
            f"offsets = {offsets_tensor.size()}, lengths = {offsets_tensor.size() - 1}) "
            f"does not conform to inputs (T, B) = ({T}, {B})."
        )

    assert (
        L == average_L
    ), f"Requested L does not align with provided data file ({L} vs. {average_L})"
    assert E > max(indices_tensor), (
        f"Number of embeddings is not enough to support maximum index "
        f"provided by data file {E} vs. {max(indices_tensor)}"
    )

    weights_tensor = (
        None
        if not weighted
        else torch.randn(indices_tensor.size(), device=get_device())
    )
    rs = []
    for _ in range(iters):
        rs.append(
            TBERequest(
                maybe_to_dtype(indices_tensor.to(get_device()), index_dtype),
                maybe_to_dtype(offsets_tensor.to(get_device()), offset_dtype),
                weights_tensor,
            )
        )
    return rs
```
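A hedged usage sketch (paths and parameter values are hypothetical). Note that the two asserts above, as written, require `indices_file` and `offsets_file` even when `requests_data_file` is given, so the two-file form is the path exercised here:

```python
from fbgemm_gpu.tbe.utils.requests import generate_requests_from_data_file

# Offsets must flatten to T * B + 1 entries and the average pooling factor
# must equal L, or the asserts above fire.
reqs = generate_requests_from_data_file(
    iters=3,
    B=512,
    T=4,
    L=40,
    E=100_000,
    weighted=False,
    indices_file="/tmp/indices.pt",  # saved earlier with torch.save
    offsets_file="/tmp/offsets.pt",
)
indices, offsets = reqs[0].unpack_2()
```

The requests.py listing continues below.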
```python
def generate_int_data_from_stats(
    mu: int,
    sigma: int,
    size: int,
    distribution: str,
) -> npt.NDArray:
    """
    Generate integer data based on stats
    """
    if distribution == "uniform":
        # TODO: either make these separate parameters or make a separate version of
        # generate_requests to handle the uniform dist case once whole
        # generate_requests function is refactored to split into helper functions
        # for each use case.
        # mu represents the lower bound when the uniform distribution is used
        lower_bound = mu
        # sigma represents the upper bound when the uniform distribution is used
        upper_bound = sigma + 1
        return np.random.randint(
            lower_bound,
            upper_bound,
            (size,),
            dtype=np.int32,
        )
    else:  # normal dist
        return np.random.normal(loc=mu, scale=sigma, size=size).astype(int)
```
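As the TODO above notes, `(mu, sigma)` are overloaded. A quick illustration of the two interpretations (values arbitrary):

```python
# distribution="uniform": (mu, sigma) act as inclusive integer bounds.
uniform_lengths = generate_int_data_from_stats(10, 20, size=8, distribution="uniform")
# -> 8 ints drawn uniformly from [10, 20]

# distribution="normal": the usual mean/stddev; astype(int) truncates toward
# zero, and callers clamp any negative draws afterwards.
normal_lengths = generate_int_data_from_stats(40, 5, size=8, distribution="normal")
```

The listing continues below.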
```python
def generate_pooling_factors_from_stats(
    iters: int,
    Bs: list[int],
    L: int,
    sigma_L: int,
    # distribution of pooling factors
    length_dist: str,
) -> tuple[int, torch.Tensor]:
    """
    Generate pooling factors for the TBE requests from the given stats
    """
    Ls_list = []
    for B in Bs:
        Ls_list.append(generate_int_data_from_stats(L, sigma_L, B, length_dist))

    # Concat all Ls
    Ls = np.concatenate(Ls_list)

    # Make sure that Ls are positive
    Ls[Ls < 0] = 0
    # Use the same L distribution across iters
    Ls = np.tile(Ls, iters)
    L = Ls.max()
    # Make it exclusive cumsum
    L_offsets = torch.from_numpy(np.insert(Ls.cumsum(), 0, 0)).to(torch.long)
    return L, L_offsets


def generate_batch_sizes_from_stats(
    B: int,
    T: int,
    sigma_B: int,
    vbe_num_ranks: int,
    # Distribution of batch sizes
    batch_size_dist: str,
) -> tuple[list[int], list[list[int]]]:
    """
    Generate batch sizes for features from the given stats
    """
    # Generate batch size per feature per rank
    Bs_feature_rank = generate_int_data_from_stats(
        B, sigma_B, T * vbe_num_ranks, batch_size_dist
    )

    # Make sure that Bs are at least one
    Bs_feature_rank = np.absolute(Bs_feature_rank)
    Bs_feature_rank[Bs_feature_rank == 0] = 1

    # Convert numpy array to Torch tensor
    Bs_feature_rank = torch.from_numpy(Bs_feature_rank).view(T, vbe_num_ranks)
    # Compute batch sizes per feature
    Bs = Bs_feature_rank.sum(1).tolist()

    return Bs, Bs_feature_rank.tolist()


def generate_indices_uniform(
    iters: int,
    Bs: list[int],
    L: int,
    E: int,
    use_variable_L: bool,
    L_offsets: torch.Tensor,
) -> torch.Tensor:
    """
    Generate indices for the TBE requests using the uniform distribution
    """
    total_B = sum(Bs)
    indices = torch.randint(
        low=0,
        high=E,
        size=(iters, total_B, L),
        device="cpu" if use_variable_L else get_device(),
        dtype=torch.int32,
    )
    # each bag is usually sorted
    (indices, _) = torch.sort(indices)
    if use_variable_L:
        # 1D layout, where row offsets are determined by L_offsets
        indices = torch.ops.fbgemm.bottom_k_per_row(
            indices.to(torch.long), L_offsets, False
        )
        indices = indices.to(get_device()).int()
    else:
        # 2D layout
        indices = indices.reshape(iters, total_B * L)
    return indices


def generate_indices_zipf(
    iters: int,
    Bs: list[int],
    L: int,
    E: int,
    alpha: float,
    zipf_oversample_ratio: int,
    use_variable_L: bool,
    L_offsets: torch.Tensor,
    deterministic_output: bool,
) -> torch.Tensor:
    """
    Generate indices for the TBE requests using the zipf distribution
    """
    assert E >= L, "num-embeddings must be greater than or equal to bag-size"
    # oversample and then remove duplicates to obtain sampling without
    # replacement
    if L == 0:
        return torch.empty(iters, 0, dtype=torch.int).to(get_device())
    total_B = sum(Bs)
    zipf_shape = (iters, total_B, zipf_oversample_ratio * L)
    if torch.cuda.is_available():
        zipf_shape_total_len = np.prod(zipf_shape)
        indices_list = []
        # process 8 GB at a time on GPU
        chunk_len = int(1e9)
        for chunk_begin in range(0, zipf_shape_total_len, chunk_len):
            indices_gpu = torch.ops.fbgemm.zipf_cuda(
                alpha,
                min(zipf_shape_total_len - chunk_begin, chunk_len),
                seed=torch.randint(2**31 - 1, (1,))[0],
            )
            indices_list.append(indices_gpu.cpu())
        indices = torch.cat(indices_list).reshape(zipf_shape)
    else:
        indices = torch.as_tensor(np.random.zipf(a=alpha, size=zipf_shape))
    indices = (indices - 1) % E
    if use_variable_L:
        indices = torch.ops.fbgemm.bottom_k_per_row(indices, L_offsets, True)
    else:
        indices = torch.ops.fbgemm.bottom_k_per_row(
            indices, torch.tensor([0, L], dtype=torch.long), True
        )
    if deterministic_output:
        rng = default_rng(12345)
    else:
        rng = default_rng()
    permutation = torch.as_tensor(
        rng.choice(E, size=indices.max().item() + 1, replace=False)
    )
    indices = permutation.gather(0, indices.flatten())
    indices = indices.to(get_device()).int()
    if not use_variable_L:
        indices = indices.reshape(iters, total_B * L)
    return indices
```
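The Zipf generator approximates sampling without replacement by oversampling and deduplicating per row (the `bottom_k_per_row` op), then relabels IDs through a random permutation so the most popular rows are not always the lowest indices. A toy sketch of just the relabeling step, in plain PyTorch rather than the fbgemm op:

```python
import torch

E = 1_000                            # number of embedding rows
raw = torch.tensor([0, 0, 1, 2, 1])  # Zipf-skewed draws: row 0 is hottest
perm = torch.randperm(E)             # random relabeling of all row IDs
relabeled = perm.gather(0, raw)      # same popularity pattern, randomized IDs
```

The listing continues below.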
```python
def update_indices_with_random_reuse(
    iters: int,
    Bs: list[int],
    L: int,
    reuse: float,
    indices: torch.Tensor,
) -> torch.Tensor:
    """
    Update the generated indices with random reuse
    """
    for it in range(iters - 1):
        B_offset = 0
        for B in Bs:
            reused_indices = torch.randperm(B * L, device=get_device())[
                : int(B * L * reuse)
            ]
            reused_indices += B_offset
            indices[it + 1, reused_indices] = indices[it, reused_indices]
            B_offset += B * L
    return indices


def update_indices_with_random_pruning(
    iters: int,
    B: int,
    T: int,
    L: int,
    indices: torch.Tensor,
) -> torch.Tensor:
    """
    Update the generated indices with random pruning
    """
    for it in range(iters):
        for t in range(T):
            num_negative_indices = B // 2
            random_locations = torch.randint(
                low=0,
                high=(B * L),
                size=(num_negative_indices,),
                device=torch.cuda.current_device(),
                dtype=torch.int32,
            )
            indices[it, t, random_locations] = -1
    return indices


def maybe_to_dtype(tensor: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.Tensor:
    return tensor if dtype is None else tensor.to(dtype)


def generate_requests(  # noqa C901
    iters: int,
    B: int,
    T: int,
    L: int,
    E: int,
    # inter-batch indices reuse rate
    reuse: float = 0.0,
    # alpha <= 1.0: use uniform distribution
    # alpha > 1.0: use zipf distribution
    alpha: float = 1.0,
    zipf_oversample_ratio: int = 3,
    weighted: bool = False,
    requests_data_file: Optional[str] = None,
    # Path to file containing indices and offsets. If provided, this will be used
    indices_file: Optional[str] = None,
    offsets_file: Optional[str] = None,
    # Comma-separated list of table numbers
    tables: Optional[str] = None,
    # If sigma_L is not None, treat L as mu_L and generate Ls from sigma_L
    # and mu_L
    sigma_L: Optional[int] = None,
    # If sigma_B is not None, treat B as mu_B and generate Bs from sigma_B
    sigma_B: Optional[int] = None,
    emulate_pruning: bool = False,
    use_cpu: bool = False,
    # generate_requests uses numpy.random.default_rng without a set random seed
    # by default, causing the indices tensor to vary with each call to
    # generate_requests - set deterministic_output to use a fixed random
    # seed instead for repeatable outputs
    deterministic_output: bool = False,
    # distribution of embedding sequence lengths
    length_dist: str = "normal",
    # distribution of batch sizes
    batch_size_dist: str = "normal",
    # Number of ranks for variable batch size generation
    vbe_num_ranks: Optional[int] = None,
    index_dtype: Optional[torch.dtype] = None,
    offset_dtype: Optional[torch.dtype] = None,
) -> list[TBERequest]:
    # TODO: refactor and split into helper functions to separate load from file,
    # generate from distribution, and other future methods of generating data
    if (
        requests_data_file is not None
        or indices_file is not None
        or offsets_file is not None
    ):
        assert sigma_L is None, "Variable pooling factors is not supported"
        assert sigma_B is None, "Variable batch sizes is not supported"
        return generate_requests_from_data_file(
            iters=iters,
            B=B,
            T=T,
            L=L,
            E=E,
            weighted=weighted,
            requests_data_file=requests_data_file,
            indices_file=indices_file,
            offsets_file=offsets_file,
            tables=tables,
            index_dtype=index_dtype,
            offset_dtype=offset_dtype,
        )

    if sigma_B is not None:
        assert (
            vbe_num_ranks is not None
        ), "vbe_num_ranks must be set for variable batch size generation"
        use_variable_B = True
        Bs, Bs_feature_rank = generate_batch_sizes_from_stats(
            B, T, sigma_B, vbe_num_ranks, batch_size_dist
        )
    else:
        use_variable_B = False
        Bs = [B] * T
        Bs_feature_rank = None

    if sigma_L is not None:
        # Generate L from stats
        use_variable_L = True
        L, L_offsets = generate_pooling_factors_from_stats(
            iters, Bs, L, sigma_L, length_dist
        )
    elif use_variable_B:
        use_variable_L = False
        Ls = [L] * (sum(Bs) * iters)
        L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
    else:
        use_variable_L = False
        # Init to suppress the pyre error
        L_offsets = torch.empty(1)

    if alpha <= 1.0:
        # Generate indices using uniform dist
        all_indices = generate_indices_uniform(
            iters, Bs, L, E, use_variable_L, L_offsets
        )
    else:
        # Generate indices using zipf dist
        all_indices = generate_indices_zipf(
            iters,
            Bs,
            L,
            E,
            alpha,
            zipf_oversample_ratio,
            use_variable_L,
            L_offsets,
            deterministic_output,
        )

    if reuse > 0.0:
        assert (
            not use_variable_L
        ), "Does not support generating Ls from stats for reuse > 0.0"
        all_indices = update_indices_with_random_reuse(iters, Bs, L, reuse, all_indices)

    # Some indices are set to -1 for emulating pruned rows.
    if emulate_pruning:
        assert (
            not use_variable_L
        ), "Does not support generating Ls from stats for emulate_pruning=True"
        assert (
            not use_variable_B
        ), "Does not support generating Bs from stats for emulate_pruning=True"

        all_indices = update_indices_with_random_pruning(
            iters, B, T, L, all_indices.view(iters, T, B * L)
        )

    # Pack requests
    rs = []
    if use_variable_L or use_variable_B:
        total_B = sum(Bs)
        all_indices = all_indices.flatten()
        for it in range(iters):
            start_offset = L_offsets[it * total_B]
            it_L_offsets = torch.concat(
                [
                    torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device),
                    L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset,
                ]
            )
            weights_tensor = (
                None
                if not weighted
                else torch.randn(
                    int(it_L_offsets[-1].item()), device=get_device()
                )  # per sample weights will always be FP32
            )
            rs.append(
                TBERequest(
                    maybe_to_dtype(
                        all_indices[start_offset : L_offsets[(it + 1) * total_B]],
                        index_dtype,
                    ),
                    maybe_to_dtype(it_L_offsets.to(get_device()), offset_dtype),
                    weights_tensor,
                    Bs_feature_rank if use_variable_B else None,
                )
            )
    else:
        for it in range(iters):
            weights_tensor = (
                None
                if not weighted
                else torch.randn(
                    T * B * L, device=get_device()
                )  # per sample weights will always be FP32
            )
            indices, offsets = get_table_batched_offsets_from_dense(
                all_indices[it].view(T, B, L), use_cpu=use_cpu
            )
            rs.append(
                TBERequest(
                    maybe_to_dtype(indices, index_dtype),
                    maybe_to_dtype(offsets, offset_dtype),
                    weights_tensor,
                )
            )
    return rs
```
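A hedged end-to-end sketch of the `generate_requests` entry point (parameter values are illustrative; the import path follows the module layout in the file list above):

```python
from fbgemm_gpu.tbe.utils.requests import generate_requests

# 10 batches for a 4-table TBE: batch size 512, pooling factor 40,
# 100k rows per table, Zipf-distributed indices (alpha > 1.0).
requests = generate_requests(
    iters=10,
    B=512,
    T=4,
    L=40,
    E=100_000,
    alpha=1.15,
    weighted=True,
    deterministic_output=True,  # fixed RNG seed for repeatable indices
)
for req in requests:
    indices, offsets, per_sample_weights = req.unpack_3()
    # feed indices/offsets/per_sample_weights into a TBE forward pass here
```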
The second hunk (@@ -0,0 +1,108 @@) is fbgemm_gpu/tbe_input_multiplexer.py, an abstract interface for tapping TBE input batches:

```python
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import abc

from dataclasses import dataclass
from typing import Optional

from torch import Tensor


@dataclass(frozen=True)
class TBEInfo:
    """
    Contains selective TBE info used for multiplexing. For more info, check https://fburl.com/code/ljnd6j65

    Args:
        table_names: table names within the tbe
        table_heights: sharded table heights (hashsize)
        tbe_uuid: a unique identifier for the TBE
        feature_table_map: feature to table map
        table_dims: sharded table dimensions
        full_table_heights: table heights before sharding
        full_table_dims: table dimensions before sharding
        row_offset: the shard offset of the current rank on row (height)
        col_offset: the shard offset of the current rank on column (dim)
    """

    table_names: list[str]
    table_heights: list[int]
    tbe_uuid: str
    feature_table_map: list[int]
    table_dims: list[int]
    full_table_heights: list[int]
    full_table_dims: list[int]
    row_offset: list[int]
    col_offset: list[int]


@dataclass(frozen=True)
class TBEInputInfo:
    """
    indices: A 1D-tensor that contains indices to be looked up
        from all embedding tables.
    offsets: A 1D-tensor that contains offsets of indices.
    batch_size_per_feature_per_rank: An optional nested list that contains
        batch sizes for every rank and every feature. This is needed to
        support VBE (variable batch sizes).
    """

    indices: Tensor
    offsets: Tensor
    batch_size_per_feature_per_rank: Optional[list[list[int]]] = None


class TBEInputMultiplexer(abc.ABC):
    """
    Interface for multiplexing TBE input data out; an actual implementation
    may store the data to files
    """

    @abc.abstractmethod
    def should_run(self, step: int) -> bool:
        """
        To check if the multiplexer should run at this step
        Args:
            step: the current step
        Returns:
            True if it should run, otherwise False
        """
        pass

    @abc.abstractmethod
    def run(
        self,
        tbe_input_info: TBEInputInfo,
    ) -> None:
        """
        To run the TBE input multiplexing; this is called for every batch that needs to be dumped
        Args:
            tbe_input_info: tbe input info that contains all the necessary info for further processing
        """
        pass


@dataclass(frozen=True)
class TBEInputMultiplexerConfig:
    """
    Configuration for TBEInputMultiplexer
    """

    # first batch to start run, -1 means no run
    start_batch: int = -1
    # total batch to multiplex
    total_batch: int = 0

    def create_tbe_input_multiplexer(
        self,
        tbe_info: TBEInfo,
    ) -> Optional[TBEInputMultiplexer]:
        assert (
            self.start_batch == -1
        ), "Cannot specify monitor_start_batch without an actual implementation."
        return None
```
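`create_tbe_input_multiplexer` returns None unless overridden, so the interface is only useful with a concrete subclass. A hypothetical implementation, not part of the wheel, that logs basic stats for `total_batch` steps starting at `start_batch`:

```python
import logging

from fbgemm_gpu.tbe_input_multiplexer import TBEInputInfo, TBEInputMultiplexer


class LoggingTBEInputMultiplexer(TBEInputMultiplexer):
    """Logs index/offset stats for total_batch steps starting at start_batch."""

    def __init__(self, start_batch: int, total_batch: int) -> None:
        self.start_batch = start_batch
        self.total_batch = total_batch

    def should_run(self, step: int) -> bool:
        return self.start_batch <= step < self.start_batch + self.total_batch

    def run(self, tbe_input_info: TBEInputInfo) -> None:
        logging.info(
            "TBE input: %d indices, %d offsets, vbe=%s",
            tbe_input_info.indices.numel(),
            tbe_input_info.offsets.numel(),
            tbe_input_info.batch_size_per_feature_per_rank is not None,
        )
```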
The third hunk (@@ -0,0 +1,22 @@) is fbgemm_gpu/triton/__init__.py, which selects between Triton MX4 quantization kernels and a reference implementation:

```python
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Attempt to import triton kernels, fallback to reference if we cannot.
from .common import RoundingMode  # noqa

try:
    from .quantize import (
        triton_dequantize_mx4 as dequantize_mx4,
        triton_quantize_mx4 as quantize_mx4,
    )
except ImportError:
    from .quantize_ref import (  # noqa: F401, E402
        py_dequantize_mx4 as dequantize_mx4,
        py_quantize_mx4 as quantize_mx4,
    )
```
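The try/except gives callers one stable pair of names whether the Triton kernels or the reference implementations are importable. A minimal round-trip sketch; the diff does not show the function signatures, so the single-argument calls below assume default settings (e.g. group size):

```python
import torch

from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4

# The Triton kernels expect GPU tensors; the reference path also runs on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 1024, device=device)

packed = quantize_mx4(x)           # packed MX4 (FP4 values + shared exponents)
restored = dequantize_mx4(packed)  # lossy reconstruction of x
```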