fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +118 -23
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +142 -1
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +244 -76
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
- fbgemm_gpu/split_embedding_configs.py +287 -3
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
- fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +13 -2
- fbgemm_gpu/tbe/bench/bench_config.py +37 -9
- fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/__init__.py +1 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
- fbgemm_gpu/tbe/ssd/common.py +27 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +2930 -195
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +53 -28
- fbgemm_gpu/tbe_input_multiplexer.py +16 -7
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +56 -5
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +3 -0
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
- fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
|
@@ -9,21 +9,18 @@
|
|
|
9
9
|
|
|
10
10
|
import dataclasses
|
|
11
11
|
import json
|
|
12
|
-
|
|
12
|
+
import logging
|
|
13
|
+
from typing import Any, List, Optional, Tuple
|
|
13
14
|
|
|
14
|
-
import numpy as np
|
|
15
15
|
import torch
|
|
16
16
|
|
|
17
|
-
|
|
18
|
-
from fbgemm_gpu.tbe.utils.
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
)
|
|
25
|
-
|
|
26
|
-
from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams
|
|
17
|
+
# fmt:skip
|
|
18
|
+
from fbgemm_gpu.tbe.utils.common import get_device
|
|
19
|
+
from .tbe_data_config_param_models import (
|
|
20
|
+
BatchParams,
|
|
21
|
+
IndicesParams,
|
|
22
|
+
PoolingParams,
|
|
23
|
+
) # usort:skip
|
|
27
24
|
|
|
28
25
|
try:
|
|
29
26
|
torch.ops.load_library(
|
|
@@ -35,27 +32,87 @@ except Exception:
|
|
|
35
32
|
|
|
36
33
|
@dataclasses.dataclass(frozen=True)
|
|
37
34
|
class TBEDataConfig:
|
|
38
|
-
# Number of tables
|
|
39
35
|
T: int
|
|
40
|
-
# Number of rows in the embedding table
|
|
41
36
|
E: int
|
|
42
|
-
# Target embedding dimension for a table (number of columns)
|
|
43
37
|
D: int
|
|
44
|
-
# Generate mixed dimensions if true
|
|
45
38
|
mixed_dim: bool
|
|
46
|
-
# Whether the table is weighted or not
|
|
47
39
|
weighted: bool
|
|
48
|
-
# Batch parameters
|
|
49
40
|
batch_params: BatchParams
|
|
50
|
-
# Indices parameters
|
|
51
41
|
indices_params: IndicesParams
|
|
52
|
-
# Pooling parameters
|
|
53
42
|
pooling_params: PoolingParams
|
|
54
|
-
# Force generated tensors to be on CPU
|
|
55
43
|
use_cpu: bool = False
|
|
44
|
+
Es: Optional[list[int]] = None
|
|
45
|
+
Ds: Optional[list[int]] = None
|
|
46
|
+
max_indices: Optional[int] = None
|
|
47
|
+
embedding_specs: Optional[List[Tuple[int, int]]] = None
|
|
48
|
+
feature_table_map: Optional[List[int]] = None
|
|
49
|
+
"""
|
|
50
|
+
Configuration for TBE (Table Batched Embedding) benchmark data collection and generation.
|
|
51
|
+
|
|
52
|
+
This dataclass holds parameters required to generate synthetic data for
|
|
53
|
+
TBE benchmarking, including table specifications, batch parameters, indices
|
|
54
|
+
distribution parameters, and pooling parameters.
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
T (int): Number of embedding tables (features). Must be positive.
|
|
58
|
+
E (int): Number of rows in the embedding table (feature). If T > 1, this
|
|
59
|
+
represents the averaged number of rows across all features.
|
|
60
|
+
D (int): Target embedding dimension for a table (feature), i.e., number of
|
|
61
|
+
columns. If T > 1, this represents the averaged dimension across
|
|
62
|
+
all features.
|
|
63
|
+
mixed_dim (bool): If True, generate embeddings with mixed dimensions
|
|
64
|
+
across tables (features). This is automatically set to True if D is provided
|
|
65
|
+
as a list with non-uniform values.
|
|
66
|
+
weighted (bool): If True, the lookup rows are weighted (per-sample
|
|
67
|
+
weights). The weights will be generated as FP32 tensors.
|
|
68
|
+
batch_params (BatchParams): Parameters controlling batch generation.
|
|
69
|
+
Contains:
|
|
70
|
+
(1) `B` = target batch size (number of batch lookups per features)
|
|
71
|
+
(2) `sigma_B` = optional standard deviation for variable batch size
|
|
72
|
+
(3) `vbe_distribution` = distribution type ("normal" or "uniform")
|
|
73
|
+
(4) `vbe_num_ranks` = number of ranks for variable batch size
|
|
74
|
+
(5) `Bs` = per-feature batch sizes
|
|
75
|
+
indices_params (IndicesParams): Parameters controlling index generation
|
|
76
|
+
following a Zipf distribution. Contains:
|
|
77
|
+
(1) `heavy_hitters` = probability density map for hot indices
|
|
78
|
+
(2) `zipf_q` = q parameter in Zipf distribution (x+q)^{-s}
|
|
79
|
+
(3) `zipf_s` = s parameter (alpha) in Zipf distribution
|
|
80
|
+
(4) `index_dtype` = optional dtype for indices tensor
|
|
81
|
+
(5) `offset_dtype` = optional dtype for offsets tensor
|
|
82
|
+
pooling_params (PoolingParams): Parameters controlling pooling behavior.
|
|
83
|
+
Contains:
|
|
84
|
+
(1) `L` = target bag size (pooling factor, indices per lookup)
|
|
85
|
+
(2) `sigma_L` = optional standard deviation for variable bag size
|
|
86
|
+
(3) `length_distribution` = distribution type ("normal" or "uniform")
|
|
87
|
+
(4) `Ls` = per-feature bag sizes
|
|
88
|
+
use_cpu (bool = False): If True, force generated tensors to be placed
|
|
89
|
+
on CPU instead of the default compute device.
|
|
90
|
+
Es (Optional[List[int]] = None): Number of embeddings (rows) for each
|
|
91
|
+
individual embedding feature. If provided, must have length equal
|
|
92
|
+
to T. All elements must be positive.
|
|
93
|
+
Ds (Optional[List[int]] = None): Target embedding dimension (columns)
|
|
94
|
+
for each individual feature. If provided, must have length equal
|
|
95
|
+
to T. All elements must be positive.
|
|
96
|
+
max_indices (Optional[int] = None): Maximum number of indices for
|
|
97
|
+
bounds checking. If Es is provided as a list and max_indices is
|
|
98
|
+
None, it is automatically computed as sum(Es) - 1.
|
|
99
|
+
embedding_specs (Optional[List[Tuple[int, int]]] = None): A list of
|
|
100
|
+
embedding specs consisting of a list of tuples of (num_rows, embedding_dim).
|
|
101
|
+
See https://fburl.com/tbe_embedding_specs for details.
|
|
102
|
+
feature_table_map (Optional[List[int]] = None): An optional list that
|
|
103
|
+
specifies feature-table mapping. feature_table_map[i] indicates the
|
|
104
|
+
physical embedding table that feature i maps to.
|
|
105
|
+
"""
|
|
106
|
+
|
|
107
|
+
def __post_init__(self) -> None:
|
|
108
|
+
if isinstance(self.D, list):
|
|
109
|
+
object.__setattr__(self, "mixed_dim", len(set(self.D)) > 1)
|
|
110
|
+
if isinstance(self.E, list) and self.max_indices is None:
|
|
111
|
+
object.__setattr__(self, "max_indices", sum(self.E) - 1)
|
|
112
|
+
self.validate()
|
|
56
113
|
|
|
57
114
|
@staticmethod
|
|
58
|
-
def complex_fields() ->
|
|
115
|
+
def complex_fields() -> dict[str, Any]:
|
|
59
116
|
return {
|
|
60
117
|
"batch_params": BatchParams,
|
|
61
118
|
"indices_params": IndicesParams,
|
|
@@ -64,7 +121,7 @@ class TBEDataConfig:
|
|
|
64
121
|
|
|
65
122
|
@classmethod
|
|
66
123
|
# pyre-ignore [3]
|
|
67
|
-
def from_dict(cls, data:
|
|
124
|
+
def from_dict(cls, data: dict[str, Any]):
|
|
68
125
|
for field, Type in cls.complex_fields().items():
|
|
69
126
|
if not isinstance(data[field], Type):
|
|
70
127
|
data[field] = Type.from_dict(data[field])
|
|
@@ -73,9 +130,22 @@ class TBEDataConfig:
|
|
|
73
130
|
@classmethod
|
|
74
131
|
# pyre-ignore [3]
|
|
75
132
|
def from_json(cls, data: str):
|
|
76
|
-
|
|
133
|
+
raw = json.loads(data)
|
|
134
|
+
allowed = {f.name for f in dataclasses.fields(cls)}
|
|
135
|
+
existing_fields = {k: v for k, v in raw.items() if k in allowed}
|
|
136
|
+
missing_fields = allowed - set(existing_fields.keys())
|
|
137
|
+
unknown_fields = set(raw.keys()) - allowed
|
|
138
|
+
if missing_fields:
|
|
139
|
+
logging.warning(
|
|
140
|
+
f"TBEDataConfig.from_json: Missing expected fields not loaded: {sorted(missing_fields)}"
|
|
141
|
+
)
|
|
142
|
+
if unknown_fields:
|
|
143
|
+
logging.info(
|
|
144
|
+
f"TBEDataConfig.from_json: Ignored unknown fields from input: {sorted(unknown_fields)}"
|
|
145
|
+
)
|
|
146
|
+
return cls.from_dict(existing_fields)
|
|
77
147
|
|
|
78
|
-
def dict(self) ->
|
|
148
|
+
def dict(self) -> dict[str, Any]:
|
|
79
149
|
tmp = dataclasses.asdict(self)
|
|
80
150
|
for field in TBEDataConfig.complex_fields().keys():
|
|
81
151
|
tmp[field] = self.__dict__[field].dict()
|
|
@@ -89,10 +159,30 @@ class TBEDataConfig:
|
|
|
89
159
|
# NOTE: Add validation logic here
|
|
90
160
|
assert self.T > 0, "T must be positive"
|
|
91
161
|
assert self.E > 0, "E must be positive"
|
|
162
|
+
if self.Es is not None:
|
|
163
|
+
assert all(e > 0 for e in self.Es), "All elements in Es must be positive"
|
|
92
164
|
assert self.D > 0, "D must be positive"
|
|
165
|
+
if self.Ds is not None:
|
|
166
|
+
assert all(d > 0 for d in self.Ds), "All elements in Ds must be positive"
|
|
167
|
+
if isinstance(self.Es, list) and isinstance(self.Ds, list):
|
|
168
|
+
assert (
|
|
169
|
+
len(self.Es) == len(self.Ds) == self.T
|
|
170
|
+
), "Lengths of Es, Lengths of Ds, and T must be equal"
|
|
171
|
+
if self.max_indices is not None:
|
|
172
|
+
assert self.max_indices == (
|
|
173
|
+
sum(self.Es) - 1
|
|
174
|
+
), "max_indices must be equal to sum(Es) - 1"
|
|
93
175
|
self.batch_params.validate()
|
|
176
|
+
if self.batch_params.Bs is not None:
|
|
177
|
+
assert (
|
|
178
|
+
len(self.batch_params.Bs) == self.T
|
|
179
|
+
), f"Length of Bs must be equal to T. Expected: {self.T}, but got: {len(self.batch_params.Bs)}"
|
|
94
180
|
self.indices_params.validate()
|
|
95
181
|
self.pooling_params.validate()
|
|
182
|
+
if self.pooling_params.Ls is not None:
|
|
183
|
+
assert (
|
|
184
|
+
len(self.pooling_params.Ls) == self.T
|
|
185
|
+
), f"Length of Ls must be equal to T. Expected: {self.T}, but got: {len(self.pooling_params.Ls)}"
|
|
96
186
|
return self
|
|
97
187
|
|
|
98
188
|
def variable_B(self) -> bool:
|
|
@@ -102,177 +192,5 @@ class TBEDataConfig:
|
|
|
102
192
|
return self.pooling_params.sigma_L is not None
|
|
103
193
|
|
|
104
194
|
def _new_weights(self, size: int) -> Optional[torch.Tensor]:
|
|
105
|
-
#
|
|
195
|
+
# Per-sample weights will always be FP32
|
|
106
196
|
return None if not self.weighted else torch.randn(size, device=get_device())
|
|
107
|
-
|
|
108
|
-
def _generate_batch_sizes(self) -> Tuple[List[int], Optional[List[List[int]]]]:
|
|
109
|
-
if self.variable_B():
|
|
110
|
-
assert (
|
|
111
|
-
self.batch_params.vbe_num_ranks is not None
|
|
112
|
-
), "vbe_num_ranks must be set for varaible batch size generation"
|
|
113
|
-
return generate_batch_sizes_from_stats(
|
|
114
|
-
self.batch_params.B,
|
|
115
|
-
self.T,
|
|
116
|
-
# pyre-ignore [6]
|
|
117
|
-
self.batch_params.sigma_B,
|
|
118
|
-
self.batch_params.vbe_num_ranks,
|
|
119
|
-
# pyre-ignore [6]
|
|
120
|
-
self.batch_params.vbe_distribution,
|
|
121
|
-
)
|
|
122
|
-
|
|
123
|
-
else:
|
|
124
|
-
return ([self.batch_params.B] * self.T, None)
|
|
125
|
-
|
|
126
|
-
def _generate_pooling_info(self, iters: int, Bs: List[int]) -> torch.Tensor:
|
|
127
|
-
if self.variable_L():
|
|
128
|
-
# Generate L from stats
|
|
129
|
-
_, L_offsets = generate_pooling_factors_from_stats(
|
|
130
|
-
iters,
|
|
131
|
-
Bs,
|
|
132
|
-
self.pooling_params.L,
|
|
133
|
-
# pyre-ignore [6]
|
|
134
|
-
self.pooling_params.sigma_L,
|
|
135
|
-
# pyre-ignore [6]
|
|
136
|
-
self.pooling_params.length_distribution,
|
|
137
|
-
)
|
|
138
|
-
|
|
139
|
-
else:
|
|
140
|
-
Ls = [self.pooling_params.L] * (sum(Bs) * iters)
|
|
141
|
-
L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
|
|
142
|
-
|
|
143
|
-
return L_offsets
|
|
144
|
-
|
|
145
|
-
def _generate_indices(
|
|
146
|
-
self,
|
|
147
|
-
iters: int,
|
|
148
|
-
Bs: List[int],
|
|
149
|
-
L_offsets: torch.Tensor,
|
|
150
|
-
) -> torch.Tensor:
|
|
151
|
-
total_B = sum(Bs)
|
|
152
|
-
L_offsets_list = L_offsets.tolist()
|
|
153
|
-
indices_list = []
|
|
154
|
-
for it in range(iters):
|
|
155
|
-
# L_offsets is defined over the entire set of batches for a single iteration
|
|
156
|
-
start_offset = L_offsets_list[it * total_B]
|
|
157
|
-
end_offset = L_offsets_list[(it + 1) * total_B]
|
|
158
|
-
|
|
159
|
-
indices_list.append(
|
|
160
|
-
torch.ops.fbgemm.tbe_generate_indices_from_distribution(
|
|
161
|
-
self.indices_params.heavy_hitters,
|
|
162
|
-
self.indices_params.zipf_q,
|
|
163
|
-
self.indices_params.zipf_s,
|
|
164
|
-
# max_index = dimensions of the embedding table
|
|
165
|
-
self.E,
|
|
166
|
-
# num_indices = number of indices to generate
|
|
167
|
-
end_offset - start_offset,
|
|
168
|
-
)
|
|
169
|
-
)
|
|
170
|
-
|
|
171
|
-
return torch.cat(indices_list)
|
|
172
|
-
|
|
173
|
-
def _build_requests_jagged(
|
|
174
|
-
self,
|
|
175
|
-
iters: int,
|
|
176
|
-
Bs: List[int],
|
|
177
|
-
Bs_feature_rank: Optional[List[List[int]]],
|
|
178
|
-
L_offsets: torch.Tensor,
|
|
179
|
-
all_indices: torch.Tensor,
|
|
180
|
-
) -> List[TBERequest]:
|
|
181
|
-
total_B = sum(Bs)
|
|
182
|
-
all_indices = all_indices.flatten()
|
|
183
|
-
requests = []
|
|
184
|
-
for it in range(iters):
|
|
185
|
-
start_offset = L_offsets[it * total_B]
|
|
186
|
-
it_L_offsets = torch.concat(
|
|
187
|
-
[
|
|
188
|
-
torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device),
|
|
189
|
-
L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset,
|
|
190
|
-
]
|
|
191
|
-
)
|
|
192
|
-
requests.append(
|
|
193
|
-
TBERequest(
|
|
194
|
-
maybe_to_dtype(
|
|
195
|
-
all_indices[start_offset : L_offsets[(it + 1) * total_B]],
|
|
196
|
-
self.indices_params.index_dtype,
|
|
197
|
-
),
|
|
198
|
-
maybe_to_dtype(
|
|
199
|
-
it_L_offsets.to(get_device()), self.indices_params.offset_dtype
|
|
200
|
-
),
|
|
201
|
-
self._new_weights(int(it_L_offsets[-1].item())),
|
|
202
|
-
Bs_feature_rank if self.variable_B() else None,
|
|
203
|
-
)
|
|
204
|
-
)
|
|
205
|
-
return requests
|
|
206
|
-
|
|
207
|
-
def _build_requests_dense(
|
|
208
|
-
self, iters: int, all_indices: torch.Tensor
|
|
209
|
-
) -> List[TBERequest]:
|
|
210
|
-
# NOTE: We're using existing code from requests.py to build the
|
|
211
|
-
# requests, and since the existing code requires 2D view of all_indices,
|
|
212
|
-
# the existing all_indices must be reshaped
|
|
213
|
-
all_indices = all_indices.reshape(iters, -1)
|
|
214
|
-
|
|
215
|
-
requests = []
|
|
216
|
-
for it in range(iters):
|
|
217
|
-
indices, offsets = get_table_batched_offsets_from_dense(
|
|
218
|
-
all_indices[it].view(
|
|
219
|
-
self.T, self.batch_params.B, self.pooling_params.L
|
|
220
|
-
),
|
|
221
|
-
use_cpu=self.use_cpu,
|
|
222
|
-
)
|
|
223
|
-
requests.append(
|
|
224
|
-
TBERequest(
|
|
225
|
-
maybe_to_dtype(indices, self.indices_params.index_dtype),
|
|
226
|
-
maybe_to_dtype(offsets, self.indices_params.offset_dtype),
|
|
227
|
-
self._new_weights(
|
|
228
|
-
self.T * self.batch_params.B * self.pooling_params.L
|
|
229
|
-
),
|
|
230
|
-
)
|
|
231
|
-
)
|
|
232
|
-
return requests
|
|
233
|
-
|
|
234
|
-
def generate_requests(
|
|
235
|
-
self,
|
|
236
|
-
iters: int = 1,
|
|
237
|
-
) -> List[TBERequest]:
|
|
238
|
-
# Generate batch sizes
|
|
239
|
-
Bs, Bs_feature_rank = self._generate_batch_sizes()
|
|
240
|
-
|
|
241
|
-
# Generate pooling info
|
|
242
|
-
L_offsets = self._generate_pooling_info(iters, Bs)
|
|
243
|
-
|
|
244
|
-
# Generate indices
|
|
245
|
-
all_indices = self._generate_indices(iters, Bs, L_offsets)
|
|
246
|
-
|
|
247
|
-
# Build TBE requests
|
|
248
|
-
if self.variable_B() or self.variable_L():
|
|
249
|
-
return self._build_requests_jagged(
|
|
250
|
-
iters, Bs, Bs_feature_rank, L_offsets, all_indices
|
|
251
|
-
)
|
|
252
|
-
else:
|
|
253
|
-
return self._build_requests_dense(iters, all_indices)
|
|
254
|
-
|
|
255
|
-
def generate_embedding_dims(self) -> Tuple[int, List[int]]:
|
|
256
|
-
if self.mixed_dim:
|
|
257
|
-
Ds = [
|
|
258
|
-
round_up(
|
|
259
|
-
np.random.randint(low=int(0.5 * self.D), high=int(1.5 * self.D)), 4
|
|
260
|
-
)
|
|
261
|
-
for _ in range(self.T)
|
|
262
|
-
]
|
|
263
|
-
return (int(np.average(Ds)), Ds)
|
|
264
|
-
else:
|
|
265
|
-
return (self.D, [self.D] * self.T)
|
|
266
|
-
|
|
267
|
-
def generate_feature_requires_grad(self, size: int) -> torch.Tensor:
|
|
268
|
-
assert size <= self.T, "size of feature_requires_grad must be less than T"
|
|
269
|
-
weighted_requires_grad_tables = np.random.choice(
|
|
270
|
-
self.T, replace=False, size=(size,)
|
|
271
|
-
).tolist()
|
|
272
|
-
return (
|
|
273
|
-
torch.tensor(
|
|
274
|
-
[1 if t in weighted_requires_grad_tables else 0 for t in range(self.T)]
|
|
275
|
-
)
|
|
276
|
-
.to(get_device())
|
|
277
|
-
.int()
|
|
278
|
-
)
|
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from typing import Optional
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import torch
|
|
15
|
+
|
|
16
|
+
# fmt:skip
|
|
17
|
+
from fbgemm_gpu.tbe.bench.tbe_data_config import TBEDataConfig
|
|
18
|
+
from fbgemm_gpu.tbe.utils.common import get_device, round_up
|
|
19
|
+
from fbgemm_gpu.tbe.utils.requests import (
|
|
20
|
+
generate_batch_sizes_from_stats,
|
|
21
|
+
generate_pooling_factors_from_stats,
|
|
22
|
+
get_table_batched_offsets_from_dense,
|
|
23
|
+
maybe_to_dtype,
|
|
24
|
+
TBERequest,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
try:
|
|
28
|
+
# pyre-ignore[21]
|
|
29
|
+
from fbgemm_gpu import open_source # noqa: F401
|
|
30
|
+
except Exception:
|
|
31
|
+
torch.ops.load_library(
|
|
32
|
+
"//deeplearning/fbgemm/fbgemm_gpu/src/tbe/eeg:indices_generator"
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _generate_batch_sizes(
|
|
37
|
+
tbe_data_config: TBEDataConfig,
|
|
38
|
+
) -> tuple[list[int], Optional[list[list[int]]]]:
|
|
39
|
+
logging.info(
|
|
40
|
+
f"DEBUG_TBE: [_generate_batch_sizes] VBE tbe_data_config.variable_B()={tbe_data_config.variable_B()}"
|
|
41
|
+
)
|
|
42
|
+
if tbe_data_config.variable_B():
|
|
43
|
+
assert (
|
|
44
|
+
tbe_data_config.batch_params.vbe_num_ranks is not None
|
|
45
|
+
), "vbe_num_ranks must be set for varaible batch size generation"
|
|
46
|
+
return generate_batch_sizes_from_stats(
|
|
47
|
+
tbe_data_config.batch_params.B,
|
|
48
|
+
tbe_data_config.T,
|
|
49
|
+
# pyre-ignore [6]
|
|
50
|
+
tbe_data_config.batch_params.sigma_B,
|
|
51
|
+
tbe_data_config.batch_params.vbe_num_ranks,
|
|
52
|
+
# pyre-ignore [6]
|
|
53
|
+
tbe_data_config.batch_params.vbe_distribution,
|
|
54
|
+
)
|
|
55
|
+
else:
|
|
56
|
+
return ([tbe_data_config.batch_params.B] * tbe_data_config.T, None)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _generate_pooling_info(
|
|
60
|
+
tbe_data_config: TBEDataConfig, iters: int, Bs: list[int]
|
|
61
|
+
) -> torch.Tensor:
|
|
62
|
+
if tbe_data_config.variable_L():
|
|
63
|
+
# Generate L from stats
|
|
64
|
+
_, L_offsets = generate_pooling_factors_from_stats(
|
|
65
|
+
iters,
|
|
66
|
+
Bs,
|
|
67
|
+
tbe_data_config.pooling_params.L,
|
|
68
|
+
# pyre-ignore [6]
|
|
69
|
+
tbe_data_config.pooling_params.sigma_L,
|
|
70
|
+
# pyre-ignore [6]
|
|
71
|
+
tbe_data_config.pooling_params.length_distribution,
|
|
72
|
+
)
|
|
73
|
+
else:
|
|
74
|
+
Ls = [tbe_data_config.pooling_params.L] * (sum(Bs) * iters)
|
|
75
|
+
L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
|
|
76
|
+
|
|
77
|
+
return L_offsets
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _generate_indices(
|
|
81
|
+
tbe_data_config: TBEDataConfig,
|
|
82
|
+
iters: int,
|
|
83
|
+
Bs: list[int],
|
|
84
|
+
L_offsets: torch.Tensor,
|
|
85
|
+
) -> torch.Tensor:
|
|
86
|
+
|
|
87
|
+
total_B = sum(Bs)
|
|
88
|
+
L_offsets_list = L_offsets.tolist()
|
|
89
|
+
indices_list = []
|
|
90
|
+
for it in range(iters):
|
|
91
|
+
# L_offsets is defined over the entire set of batches for a single iteration
|
|
92
|
+
start_offset = L_offsets_list[it * total_B]
|
|
93
|
+
end_offset = L_offsets_list[(it + 1) * total_B]
|
|
94
|
+
|
|
95
|
+
logging.info(f"DEBUG_TBE: _generate_indices E = {tbe_data_config.E=}")
|
|
96
|
+
|
|
97
|
+
indices_list.append(
|
|
98
|
+
torch.ops.fbgemm.tbe_generate_indices_from_distribution(
|
|
99
|
+
tbe_data_config.indices_params.heavy_hitters,
|
|
100
|
+
tbe_data_config.indices_params.zipf_q,
|
|
101
|
+
tbe_data_config.indices_params.zipf_s,
|
|
102
|
+
# max_index = dimensions of the embedding table
|
|
103
|
+
int(tbe_data_config.E),
|
|
104
|
+
# num_indices = number of indices to generate
|
|
105
|
+
end_offset - start_offset,
|
|
106
|
+
)
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
return torch.cat(indices_list)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _build_requests_jagged(
|
|
113
|
+
tbe_data_config: TBEDataConfig,
|
|
114
|
+
iters: int,
|
|
115
|
+
Bs: list[int],
|
|
116
|
+
Bs_feature_rank: Optional[list[list[int]]],
|
|
117
|
+
L_offsets: torch.Tensor,
|
|
118
|
+
all_indices: torch.Tensor,
|
|
119
|
+
) -> list[TBERequest]:
|
|
120
|
+
total_B = sum(Bs)
|
|
121
|
+
all_indices = all_indices.flatten()
|
|
122
|
+
requests = []
|
|
123
|
+
for it in range(iters):
|
|
124
|
+
start_offset = L_offsets[it * total_B]
|
|
125
|
+
it_L_offsets = torch.concat(
|
|
126
|
+
[
|
|
127
|
+
torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device),
|
|
128
|
+
L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset,
|
|
129
|
+
]
|
|
130
|
+
)
|
|
131
|
+
requests.append(
|
|
132
|
+
TBERequest(
|
|
133
|
+
maybe_to_dtype(
|
|
134
|
+
all_indices[start_offset : L_offsets[(it + 1) * total_B]],
|
|
135
|
+
tbe_data_config.indices_params.index_dtype,
|
|
136
|
+
),
|
|
137
|
+
maybe_to_dtype(
|
|
138
|
+
it_L_offsets.to(get_device()),
|
|
139
|
+
tbe_data_config.indices_params.offset_dtype,
|
|
140
|
+
),
|
|
141
|
+
tbe_data_config._new_weights(int(it_L_offsets[-1].item())),
|
|
142
|
+
Bs_feature_rank if tbe_data_config.variable_B() else None,
|
|
143
|
+
)
|
|
144
|
+
)
|
|
145
|
+
return requests
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _build_requests_dense(
|
|
149
|
+
tbe_data_config: TBEDataConfig, iters: int, all_indices: torch.Tensor
|
|
150
|
+
) -> list[TBERequest]:
|
|
151
|
+
# NOTE: We're using existing code from requests.py to build the
|
|
152
|
+
# requests, and since the existing code requires 2D view of all_indices,
|
|
153
|
+
# the existing all_indices must be reshaped
|
|
154
|
+
all_indices = all_indices.reshape(iters, -1)
|
|
155
|
+
|
|
156
|
+
requests = []
|
|
157
|
+
for it in range(iters):
|
|
158
|
+
indices, offsets = get_table_batched_offsets_from_dense(
|
|
159
|
+
all_indices[it].view(
|
|
160
|
+
tbe_data_config.T,
|
|
161
|
+
tbe_data_config.batch_params.B,
|
|
162
|
+
tbe_data_config.pooling_params.L,
|
|
163
|
+
),
|
|
164
|
+
use_cpu=tbe_data_config.use_cpu,
|
|
165
|
+
)
|
|
166
|
+
requests.append(
|
|
167
|
+
TBERequest(
|
|
168
|
+
maybe_to_dtype(indices, tbe_data_config.indices_params.index_dtype),
|
|
169
|
+
maybe_to_dtype(offsets, tbe_data_config.indices_params.offset_dtype),
|
|
170
|
+
tbe_data_config._new_weights(
|
|
171
|
+
tbe_data_config.T
|
|
172
|
+
* tbe_data_config.batch_params.B
|
|
173
|
+
* tbe_data_config.pooling_params.L
|
|
174
|
+
),
|
|
175
|
+
)
|
|
176
|
+
)
|
|
177
|
+
return requests
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def generate_requests(
|
|
181
|
+
tbe_data_config: TBEDataConfig,
|
|
182
|
+
iters: int = 1,
|
|
183
|
+
batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
|
|
184
|
+
) -> list[TBERequest]:
|
|
185
|
+
|
|
186
|
+
# Generate batch sizes
|
|
187
|
+
if batch_size_per_feature_per_rank:
|
|
188
|
+
Bs = tbe_data_config.batch_params.Bs
|
|
189
|
+
else:
|
|
190
|
+
Bs, _ = _generate_batch_sizes(tbe_data_config)
|
|
191
|
+
|
|
192
|
+
logging.info(
|
|
193
|
+
f"DEBUG_TBE: VBE [generate_requests] batch_size_per_feature_per_rank={batch_size_per_feature_per_rank} Bs={Bs}"
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
assert Bs is not None, "Batch sizes (Bs) must be set"
|
|
197
|
+
|
|
198
|
+
# Generate pooling info
|
|
199
|
+
L_offsets = _generate_pooling_info(tbe_data_config, iters, Bs)
|
|
200
|
+
|
|
201
|
+
# Generate indices
|
|
202
|
+
all_indices = _generate_indices(tbe_data_config, iters, Bs, L_offsets)
|
|
203
|
+
all_indices = all_indices.to(get_device())
|
|
204
|
+
|
|
205
|
+
# Build TBE requests
|
|
206
|
+
if tbe_data_config.variable_B() or tbe_data_config.variable_L():
|
|
207
|
+
if batch_size_per_feature_per_rank:
|
|
208
|
+
return _build_requests_jagged(
|
|
209
|
+
tbe_data_config,
|
|
210
|
+
iters,
|
|
211
|
+
Bs,
|
|
212
|
+
batch_size_per_feature_per_rank,
|
|
213
|
+
L_offsets,
|
|
214
|
+
all_indices,
|
|
215
|
+
)
|
|
216
|
+
else:
|
|
217
|
+
return _build_requests_jagged(
|
|
218
|
+
tbe_data_config,
|
|
219
|
+
iters,
|
|
220
|
+
Bs,
|
|
221
|
+
batch_size_per_feature_per_rank,
|
|
222
|
+
L_offsets,
|
|
223
|
+
all_indices,
|
|
224
|
+
)
|
|
225
|
+
else:
|
|
226
|
+
return _build_requests_dense(tbe_data_config, iters, all_indices)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def generate_requests_with_Llist(
|
|
230
|
+
tbe_data_config: TBEDataConfig,
|
|
231
|
+
L_list: torch.Tensor,
|
|
232
|
+
iters: int = 1,
|
|
233
|
+
batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
|
|
234
|
+
) -> list[TBERequest]:
|
|
235
|
+
"""
|
|
236
|
+
Generate a list of TBERequest objects based on the provided TBE data configuration and L_list
|
|
237
|
+
This function generates batch sizes and pooling information from the input L_list,
|
|
238
|
+
simulates L distributions with Gaussian noise, and creates indices for embedding lookups.
|
|
239
|
+
It supports both variable batch sizes and sequence lengths, building either jagged or dense requests accordingly.
|
|
240
|
+
Args:
|
|
241
|
+
tbe_data_config (TBEDataConfig): Configuration object containing batch parameters and pooling parameters.
|
|
242
|
+
L_list (torch.Tensor): Tensor of base sequence lengths for each batch.
|
|
243
|
+
iters (int, optional): Number of iterations to repeat the generated requests. Defaults to 1.
|
|
244
|
+
batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Optional batch size specification per feature per rank. Defaults to None.
|
|
245
|
+
Returns:
|
|
246
|
+
List[TBERequest]: A list of TBERequest objects constructed according to the configuration and input parameters.
|
|
247
|
+
Raises:
|
|
248
|
+
AssertionError: If batch sizes (Bs) are not set in the tbe_data_config.
|
|
249
|
+
Example:
|
|
250
|
+
>>> requests = generate_requests_with_Llist(tbe_data_config, L_list=torch.tensor([10, 20]), iters=2)
|
|
251
|
+
>>> len(requests)
|
|
252
|
+
2
|
|
253
|
+
"""
|
|
254
|
+
|
|
255
|
+
# Generate batch sizes
|
|
256
|
+
Bs = tbe_data_config.batch_params.Bs
|
|
257
|
+
assert (
|
|
258
|
+
Bs is not None
|
|
259
|
+
), "Batch sizes (Bs) must be set for generate_requests_with_Llist"
|
|
260
|
+
|
|
261
|
+
# Generate pooling info from L list
|
|
262
|
+
Ls_list = []
|
|
263
|
+
for i in range(len(Bs)):
|
|
264
|
+
L = L_list[i]
|
|
265
|
+
B = Bs[i]
|
|
266
|
+
Ls_iter = np.random.normal(
|
|
267
|
+
loc=L, scale=tbe_data_config.pooling_params.sigma_L, size=B
|
|
268
|
+
).astype(int)
|
|
269
|
+
Ls_list.append(Ls_iter)
|
|
270
|
+
Ls = np.concatenate(Ls_list)
|
|
271
|
+
Ls[Ls < 0] = 0
|
|
272
|
+
# Use the same L distribution across iters
|
|
273
|
+
Ls = np.tile(Ls, iters)
|
|
274
|
+
L = Ls.max()
|
|
275
|
+
# Make it exclusive cumsum
|
|
276
|
+
L_offsets = torch.from_numpy(np.insert(Ls.cumsum(), 0, 0)).to(torch.long)
|
|
277
|
+
|
|
278
|
+
# Generate indices
|
|
279
|
+
all_indices = _generate_indices(tbe_data_config, iters, Bs, L_offsets)
|
|
280
|
+
all_indices = all_indices.to(get_device())
|
|
281
|
+
|
|
282
|
+
# Build TBE requests
|
|
283
|
+
if tbe_data_config.variable_B() or tbe_data_config.variable_L():
|
|
284
|
+
return _build_requests_jagged(
|
|
285
|
+
tbe_data_config,
|
|
286
|
+
iters,
|
|
287
|
+
Bs,
|
|
288
|
+
batch_size_per_feature_per_rank,
|
|
289
|
+
L_offsets,
|
|
290
|
+
all_indices,
|
|
291
|
+
)
|
|
292
|
+
else:
|
|
293
|
+
return _build_requests_dense(tbe_data_config, iters, all_indices)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def generate_embedding_dims(tbe_data_config: TBEDataConfig) -> tuple[int, list[int]]:
|
|
297
|
+
if tbe_data_config.mixed_dim:
|
|
298
|
+
Ds = [
|
|
299
|
+
round_up(
|
|
300
|
+
int(
|
|
301
|
+
torch.randint(
|
|
302
|
+
low=int(0.5 * tbe_data_config.D),
|
|
303
|
+
high=int(1.5 * tbe_data_config.D),
|
|
304
|
+
size=(1,),
|
|
305
|
+
).item()
|
|
306
|
+
),
|
|
307
|
+
4,
|
|
308
|
+
)
|
|
309
|
+
for _ in range(tbe_data_config.T)
|
|
310
|
+
]
|
|
311
|
+
return (sum(Ds) // len(Ds), Ds)
|
|
312
|
+
else:
|
|
313
|
+
return (tbe_data_config.D, [tbe_data_config.D] * tbe_data_config.T)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def generate_feature_requires_grad(
|
|
317
|
+
tbe_data_config: TBEDataConfig, size: int
|
|
318
|
+
) -> torch.Tensor:
|
|
319
|
+
assert (
|
|
320
|
+
size <= tbe_data_config.T
|
|
321
|
+
), "size of feature_requires_grad must be less than T"
|
|
322
|
+
weighted_requires_grad_tables = torch.randperm(tbe_data_config.T)[:size].tolist()
|
|
323
|
+
return (
|
|
324
|
+
torch.tensor(
|
|
325
|
+
[
|
|
326
|
+
1 if t in weighted_requires_grad_tables else 0
|
|
327
|
+
for t in range(tbe_data_config.T)
|
|
328
|
+
]
|
|
329
|
+
)
|
|
330
|
+
.to(get_device())
|
|
331
|
+
.int()
|
|
332
|
+
)
|