fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +112 -19
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +118 -0
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +190 -54
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
- fbgemm_gpu/split_embedding_configs.py +134 -37
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +6 -1
- fbgemm_gpu/tbe/bench/bench_config.py +14 -3
- fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
- fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
- fbgemm_gpu/tbe/ssd/common.py +1 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +1292 -267
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +15 -15
- fbgemm_gpu/tbe_input_multiplexer.py +10 -11
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +6 -2
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +1 -0
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
@@ -8,14 +8,21 @@
 # pyre-strict

 import dataclasses
+import logging
+import re
 from enum import Enum

 import click
 import torch
 import yaml

-
-from .
+# fmt:skip
+from fbgemm_gpu.tbe.bench.tbe_data_config import (
+    BatchParams,
+    IndicesParams,
+    PoolingParams,
+    TBEDataConfig,
+)


 @dataclasses.dataclass(frozen=True)
@@ -40,12 +47,16 @@ class TBEDataConfigHelperText(Enum):
     TBE_INDICES_HITTERS = "Heavy hitters for indices (comma-delimited list of floats)"
     TBE_INDICES_ZIPF = "Zipf distribution parameters for indices generation (q, s)"
     TBE_INDICES_DTYPE = "The dtype of the table indices (choices: '32', '64')"
-    TBE_OFFSETS_DTYPE = "The dtype of the table
+    TBE_OFFSETS_DTYPE = "The dtype of the table offsets (choices: '32', '64')"

     # Pooling Parameters
     TBE_POOLING_SIZE = "Bag size / pooling factor (L)"
-    TBE_POOLING_VL_SIGMA = "Standard deviation of
-    TBE_POOLING_VL_DIST =
+    TBE_POOLING_VL_SIGMA = "Standard deviation of L for variable bag size"
+    TBE_POOLING_VL_DIST = (
+        "Variable bag size distribution (choices: 'uniform', 'normal')"
+    )
+    TBE_EMBEDDING_SPECS = "Embedding Specs which is List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]"
+    TBE_FEATURE_TABLE_MAP = "Mapping of feature-table"


 class TBEDataConfigLoader:
@@ -73,12 +84,26 @@ class TBEDataConfigLoader:
                 default=int(1e5),
                 help=TBEDataConfigHelperText.TBE_NUM_EMBEDDINGS.value,
             ),
+            click.option(
+                "--tbe-num-embeddings-list",
+                type=str,
+                required=False,
+                default=None,
+                help="Comma-separated list of number of embeddings (Es)",
+            ),
             click.option(
                 "--tbe-embedding-dim",
                 type=int,
                 default=128,
                 help=TBEDataConfigHelperText.TBE_EMBEDDING_DIM.value,
             ),
+            click.option(
+                "--tbe-embedding-dim-list",
+                type=str,
+                required=False,
+                default=None,
+                help="Comma-separated list of number of Embedding dimensions (D)",
+            ),
             click.option(
                 "--tbe-mixed-dim",
                 is_flag=True,
@@ -91,6 +116,13 @@ class TBEDataConfigLoader:
                 default=False,
                 help=TBEDataConfigHelperText.TBE_WEIGHTED.value,
             ),
+            click.option(
+                "--tbe-max-indices",
+                type=int,
+                required=False,
+                default=None,
+                help="(Optional) Maximum number of indices, will be calculated if not provided",
+            ),
             # Batch Parameters
             click.option(
                 "--tbe-batch-size",
@@ -98,6 +130,13 @@ class TBEDataConfigLoader:
                 default=512,
                 help=TBEDataConfigHelperText.TBE_BATCH_SIZE.value,
             ),
+            click.option(
+                "--tbe-batch-sizes-list",
+                type=str,
+                required=False,
+                default=None,
+                help="List Batch sizes per feature (Bs)",
+            ),
             click.option(
                 "--tbe-batch-vbe-sigma",
                 type=int,
@@ -160,6 +199,18 @@ class TBEDataConfigLoader:
                 required=False,
                 help=TBEDataConfigHelperText.TBE_POOLING_VL_DIST.value,
             ),
+            click.option(
+                "--tbe-embedding-specs",
+                type=str,
+                required=False,
+                help=TBEDataConfigHelperText.TBE_EMBEDDING_SPECS.value,
+            ),
+            click.option(
+                "--tbe-feature-table-map",
+                type=str,
+                required=False,
+                help=TBEDataConfigHelperText.TBE_FEATURE_TABLE_MAP.value,
+            ),
         ]

         for option in reversed(options):
@@ -180,18 +231,62 @@ class TBEDataConfigLoader:
         params = context.params

         # Read table parameters
-        T = params["tbe_num_tables"]
-        E = params["tbe_num_embeddings"]
+        T = params["tbe_num_tables"]  # number of features
+        E = params["tbe_num_embeddings"]  # feature_rows
+        if params["tbe_num_embeddings_list"] is not None:
+            Es = [int(x) for x in params["tbe_num_embeddings_list"].split(",")]
+            T = len(Es)
+            E = sum(Es) // T  # average E
+        else:
+            Es = None
         D = params["tbe_embedding_dim"]
+        if params["tbe_embedding_dim_list"] is not None:
+            Ds = [int(x) for x in params["tbe_embedding_dim_list"].split(",")]
+            assert (
+                len(Ds) == T
+            ), f"Expected tbe_embedding_dim_list to have {T} elements, but got {len(Ds)}"
+            D = sum(Ds) // T  # average D
+        else:
+            Ds = None
+
         mixed_dim = params["tbe_mixed_dim"]
         weighted = params["tbe_weighted"]
+        if params["tbe_max_indices"] is not None:
+            max_indices = params["tbe_max_indices"]
+        else:
+            max_indices = None

         # Read batch parameters
         B = params["tbe_batch_size"]
         sigma_B = params["tbe_batch_vbe_sigma"]
         vbe_distribution = params["tbe_batch_vbe_dist"]
         vbe_num_ranks = params["tbe_batch_vbe_ranks"]
-
+        if params["tbe_batch_sizes_list"] is not None:
+            Bs = [int(x) for x in params["tbe_batch_sizes_list"].split(",")]
+            B = sum(Bs) // T  # average B
+        else:
+            B = params["tbe_batch_size"]
+            Bs = None
+        batch_params = BatchParams(B, sigma_B, vbe_distribution, vbe_num_ranks, Bs)
+
+        # Parse embedding_specs: "(E,D),(E,D),..." or "(E,D,loc,dev),(E,D,loc,dev),..."
+        # Only the first two values (E, D) are extracted.
+        embedding_specs = None
+        feature_table_map = None
+        if params["tbe_embedding_specs"] is not None:
+            try:
+                tuples = re.findall(r"\(([^)]+)\)", params["tbe_embedding_specs"])
+                if tuples:
+                    embedding_specs = [
+                        (int(t.split(",")[0].strip()), int(t.split(",")[1].strip()))
+                        for t in tuples
+                    ]
+            except (ValueError, IndexError):
+                logging.warning("Failed to parse embedding_specs. Setting to None.")
+        if params["tbe_feature_table_map"] is not None:
+            feature_table_map = [
+                int(x) for x in params["tbe_feature_table_map"].split(",")
+            ]

         # Read indices parameters
         heavy_hitters = (
@@ -226,6 +321,11 @@ class TBEDataConfigLoader:
             indices_params,
             pooling_params,
             not torch.cuda.is_available(),
+            Es,
+            Ds,
+            max_indices,
+            embedding_specs,
+            feature_table_map,
         ).validate()

     @classmethod
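The list-valued options added above (--tbe-num-embeddings-list, --tbe-embedding-dim-list, --tbe-batch-sizes-list, --tbe-embedding-specs) are plain comma-separated strings that the loader reduces to per-table lists plus averaged scalars. The following is a standalone sketch of that parsing, using a hypothetical helper name and only the logic visible in the hunk above; it is not part of the fbgemm_gpu API.

import re
from typing import List, Optional, Tuple

# Hypothetical standalone mirror of the parsing added above; illustrative only.
def parse_list_options(
    num_embeddings_list: Optional[str],
    embedding_specs: Optional[str],
) -> Tuple[Optional[List[int]], Optional[List[Tuple[int, int]]]]:
    Es = None
    if num_embeddings_list is not None:
        # "1000,2000,4000" -> [1000, 2000, 4000]; T becomes len(Es), E the average.
        Es = [int(x) for x in num_embeddings_list.split(",")]
    specs = None
    if embedding_specs is not None:
        # "(E,D),(E,D),..." or "(E,D,loc,dev),..."; only the first two values are kept.
        tuples = re.findall(r"\(([^)]+)\)", embedding_specs)
        specs = [
            (int(t.split(",")[0].strip()), int(t.split(",")[1].strip())) for t in tuples
        ]
    return Es, specs

# Example: Es=[1000, 2000], specs=[(1000, 128), (2000, 64)]
print(parse_list_options("1000,2000", "(1000, 128),(2000, 64)"))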
fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py CHANGED
@@ -9,7 +9,7 @@

 import dataclasses
 import json
-from typing import Any,
+from typing import Any, Optional

 import torch

@@ -40,7 +40,7 @@ class IndicesParams:

     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data:
+    def from_dict(cls, data: dict[str, Any]):
         if not isinstance(data["heavy_hitters"], torch.Tensor):
             data["heavy_hitters"] = torch.tensor(
                 data["heavy_hitters"], dtype=torch.float32
@@ -54,7 +54,7 @@ class IndicesParams:
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))

-    def dict(self) ->
+    def dict(self) -> dict[str, Any]:
         # https://stackoverflow.com/questions/73735974/convert-dataclass-of-dataclass-to-json-string
         tmp = dataclasses.asdict(self)
         # Convert tensor to list for JSON serialization
@@ -98,10 +98,12 @@ class BatchParams:
     vbe_distribution: Optional[str] = "normal"
     # Number of ranks for variable batch size generation
     vbe_num_ranks: Optional[int] = None
+    # List of target batch sizes, i.e. number of batch lookups per feature
+    Bs: Optional[list[int]] = None

     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data:
+    def from_dict(cls, data: dict[str, Any]):
         return cls(**data)

     @classmethod
@@ -109,7 +111,7 @@ class BatchParams:
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))

-    def dict(self) ->
+    def dict(self) -> dict[str, Any]:
         return dataclasses.asdict(self)

     def json(self, format: bool = False) -> str:
@@ -117,7 +119,10 @@ class BatchParams:

     # pyre-ignore [3]
     def validate(self):
-
+        if self.Bs is not None:
+            assert all(b > 0 for b in self.Bs), "All elements in Bs must be positive"
+        else:
+            assert self.B > 0, "B must be positive"
         assert not self.sigma_B or self.sigma_B > 0, "sigma_B must be positive"
         assert (
             self.vbe_num_ranks is None or self.vbe_num_ranks > 0
@@ -137,10 +142,12 @@ class PoolingParams:
     sigma_L: Optional[int] = None
     # [Optional] Distribution of embedding sequence lengths (normal, uniform)
    length_distribution: Optional[str] = "normal"
+    # [Optional] List of target bag sizes, i.e. pooling factors per batch
+    Ls: Optional[list[float]] = None

     @classmethod
     # pyre-ignore [3]
-    def from_dict(cls, data:
+    def from_dict(cls, data: dict[str, Any]):
         return cls(**data)

     @classmethod
@@ -148,7 +155,7 @@ class PoolingParams:
     def from_json(cls, data: str):
         return cls.from_dict(json.loads(data))

-    def dict(self) ->
+    def dict(self) -> dict[str, Any]:
         return dataclasses.asdict(self)

     def json(self, format: bool = False) -> str:
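To illustrate how the new Bs field interacts with validate(), here is a minimal mock dataclass that mirrors only the BatchParams fields visible in the hunks above (the field order follows the positional BatchParams(B, sigma_B, vbe_distribution, vbe_num_ranks, Bs) call in the loader diff); it is a sketch, not the real class.

import dataclasses
from typing import Optional

# Minimal mock of the BatchParams shape shown above; fields beyond those visible
# in the hunks are not assumed.
@dataclasses.dataclass(frozen=True)
class BatchParamsMock:
    B: int
    sigma_B: Optional[int] = None
    vbe_distribution: Optional[str] = "normal"
    vbe_num_ranks: Optional[int] = None
    Bs: Optional[list[int]] = None

    def validate(self) -> None:
        # Mirrors the new branch: per-feature batch sizes take precedence over B.
        if self.Bs is not None:
            assert all(b > 0 for b in self.Bs), "All elements in Bs must be positive"
        else:
            assert self.B > 0, "B must be positive"
        assert not self.sigma_B or self.sigma_B > 0, "sigma_B must be positive"

# Average B of 512 with per-feature batch sizes; only Bs is checked here.
BatchParamsMock(B=512, Bs=[256, 512, 768]).validate()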
fbgemm_gpu/tbe/bench/utils.py CHANGED
@@ -6,15 +6,14 @@

 # pyre-strict

-import
+from typing import List, Tuple

 import numpy as np
 import torch

+# fmt:skip
 from fbgemm_gpu.split_embedding_configs import SparseType

-logging.basicConfig(level=logging.DEBUG)
-

 def fill_random_scale_bias(
     emb: torch.nn.Module,
@@ -23,9 +22,9 @@ def fill_random_scale_bias(
 ) -> None:
     for t in range(T):
         # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
-
+        weights, scale_shift = emb.split_embedding_weights()[t]
         if scale_shift is not None:
-
+            E, R = scale_shift.shape
             assert R == 4
             scales = None
             shifts = None
@@ -46,3 +45,128 @@ def fill_random_scale_bias(
                     device=scale_shift.device,
                 )
             )
+
+
+def check_oom(
+    data_size: int,
+) -> Tuple[bool, str]:
+    free_memory, total_memory = torch.cuda.mem_get_info()
+    if data_size > free_memory:
+        warning = f"Expect to allocate {round(data_size / (1024 ** 3), 2)} GB, but available memory is {round(free_memory / (1024 ** 3), 2)} GB from {round(total_memory / (1024 ** 3), 2)} GB."
+        return (True, warning)
+    return (False, "")
+
+
+def generate_batch_size_per_feature_per_rank(
+    Bs: List[int], num_ranks: int
+) -> List[List[int]]:
+    """
+    Generate batch size per feature per rank for VBE, assuming the batch size
+    is evenly distributed across ranks.
+    Args:
+        Bs (List[int]): batch size per feature
+        num_ranks (int): number of ranks
+    Returns:
+        List[List[int]]: batch size per feature per rank
+    """
+    b_per_feature_per_rank = []
+    for B in Bs:
+        b_per_feature = []
+        for i in range(num_ranks):
+            if i != num_ranks - 1:
+                b_per_feature.append(int(B / num_ranks))
+            else:
+                b_per_feature.append(B - sum(b_per_feature))
+        b_per_feature_per_rank.append(b_per_feature)
+    return b_per_feature_per_rank
+
+
+def generate_merged_output_and_offsets(
+    Ds: List[int],
+    Bs: List[int],
+    output_dtype: torch.dtype,
+    device: torch.device,
+    num_ranks: int = 2,
+    num_tbe_ops: int = 2,
+) -> Tuple[List[List[int]], torch.Tensor, torch.Tensor]:
+    """
+    Generate merged vbe_output and vbe_output_offsets tensors for VBE.
+    The vbe_output is a tensor that will contain forward output from all VBE TBE ops.
+    The vbe_output_offsets is a tensor that will contain start offsets for the output to be written to.
+
+    Args:
+        Ds (List[int]): embedding dimension per feature
+        Bs (List[int]): batch size per feature
+        num_ranks (int): number of ranks
+        num_tbe_ops (int): number of TBE ops
+    Returns:
+        Tuple[List[List[int]], torch.Tensor, torch.Tensor]: batch_size_per_feature_per_rank, merged vbe_output and vbe_output_offsets tensors
+    """
+    # The first embedding ops is the embedding op created in the benchmark
+    emb_op = {}
+    emb_op[0] = {}
+    emb_op[0]["dim"] = Ds
+    emb_op[0]["Bs"] = Bs
+    emb_op[0]["output_size"] = sum([b * d for b, d in zip(Bs, Ds)])
+    emb_op[0]["batch_size_per_feature_per_rank"] = (
+        generate_batch_size_per_feature_per_rank(Bs, num_ranks)
+    )
+    num_features = len(Bs)
+    # create other embedding ops to allocate output and offsets tensors
+    # Using representative values for additional TBE ops in multi-op scenarios:
+    # - batch_size=32000: typical large batch size for production workloads
+    # - dim=512: common embedding dimension for large models
+    for i in range(1, num_tbe_ops):
+        emb_op[i] = {}
+        emb_op[i]["batch_size_per_feature_per_rank"] = (
+            generate_batch_size_per_feature_per_rank([32000], num_ranks)
+        )
+        emb_op[i]["Bs"] = [sum(B) for B in emb_op[i]["batch_size_per_feature_per_rank"]]
+        emb_op[i]["dim"] = [512]
+        emb_op[i]["output_size"] = sum(
+            [b * d for b, d in zip(emb_op[i]["Bs"], emb_op[i]["dim"])]
+        )
+    total_output = 0
+    ranks = [[] for _ in range(num_ranks)]
+    for e in emb_op.values():
+        b_per_rank_per_feature = list(zip(*e["batch_size_per_feature_per_rank"]))
+        assert len(b_per_rank_per_feature) == num_ranks
+        dims = e["dim"]
+        for r, b_r in enumerate(b_per_rank_per_feature):
+            for f, b in enumerate(b_r):
+                output_size_per_batch = b * dims[f]
+                ranks[r].append(output_size_per_batch)
+                total_output += output_size_per_batch
+    ranks[0].insert(0, 0)
+    offsets_ranks: List[List[int]] = [[] for _ in range(num_ranks)]
+    total_output_offsets = []
+    start = 0
+    for r in range(num_ranks):
+        offsets_ranks[r] = [
+            start + sum(ranks[r][: i + 1]) for i in range(len(ranks[r]))
+        ]
+        start = offsets_ranks[r][-1]
+        total_output_offsets.extend(offsets_ranks[r])
+    check_total_output_size = sum([e["output_size"] for e in emb_op.values()])
+    assert (
+        total_output == check_total_output_size
+    ), f"{total_output} != {check_total_output_size}{[e['output_size'] for e in emb_op.values()]}"
+    assert (
+        total_output == total_output_offsets[-1]
+    ), f"{total_output} != {total_output_offsets[-1]}"
+    out = torch.empty(total_output, dtype=output_dtype, device=device)
+    offsets = []
+    offsets.append(offsets_ranks[0][:num_features])
+    for r in range(1, num_ranks):
+        start = [offsets_ranks[r - 1][-1]]
+        the_rest = offsets_ranks[r][: num_features - 1] if num_features > 1 else []
+        start.extend(the_rest)
+        offsets.append(start)
+
+    out_offsets = torch.tensor(
+        offsets,
+        dtype=torch.int64,
+        device=device,
+    )
+    batch_size_per_feature_per_rank = emb_op[0]["batch_size_per_feature_per_rank"]
+    return (batch_size_per_feature_per_rank, out, out_offsets)
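A quick usage sketch for the new VBE helper above, assuming it is importable from fbgemm_gpu.tbe.bench.utils as the file path shown above suggests:

# Assumes the import path matches the file shown above (fbgemm_gpu/tbe/bench/utils.py).
from fbgemm_gpu.tbe.bench.utils import generate_batch_size_per_feature_per_rank

# Each per-feature batch size is split evenly across ranks; the remainder goes to the last rank.
bs_per_feature_per_rank = generate_batch_size_per_feature_per_rank([10, 7], num_ranks=4)
print(bs_per_feature_per_rank)  # [[2, 2, 2, 4], [1, 1, 1, 4]]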
fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py CHANGED
@@ -10,7 +10,7 @@
 # pyre-ignore-all-errors[56]


-from typing import
+from typing import Optional, Union

 import torch  # usort:skip
 from torch import Tensor  # usort:skip
@@ -47,15 +47,15 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):

     def __init__(  # noqa C901
         self,
-        embedding_specs:
-
+        embedding_specs: list[
+            tuple[str, int, int, SparseType, EmbeddingLocation]
         ],  # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement)
-        feature_table_map: Optional[
-        index_remapping: Optional[
+        feature_table_map: Optional[list[int]] = None,  # [T]
+        index_remapping: Optional[list[Tensor]] = None,
         pooling_mode: PoolingMode = PoolingMode.SUM,
         device: Optional[Union[str, int, torch.device]] = None,
         bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
-        weight_lists: Optional[
+        weight_lists: Optional[list[tuple[Tensor, Optional[Tensor]]]] = None,
         pruning_hash_load_factor: float = 0.5,
         use_array_for_index_remapping: bool = True,
         output_dtype: SparseType = SparseType.FP16,
@@ -74,8 +74,9 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
         cacheline_alignment: bool = True,
         uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
         reverse_qparam: bool = False,  # True to load qparams at end of each row; False to load qparam at begnning of each row.
-        feature_names_per_table: Optional[
+        feature_names_per_table: Optional[list[list[str]]] = None,
         indices_dtype: torch.dtype = torch.int32,  # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
+        embedding_cache_mode: bool = False,  # True for zero initialization, False for randomized initialization
     ) -> None:  # noqa C901  # tuple of (rows, dims,)
         super(KVEmbeddingInference, self).__init__(
             embedding_specs=embedding_specs,
@@ -114,17 +115,21 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
         num_shards = 32
         uniform_init_lower: float = -0.01
         uniform_init_upper: float = 0.01
+
         # pyre-fixme[4]: Attribute must be annotated.
         self.kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
-            num_shards,
+            num_shards,
+            uniform_init_lower,
+            uniform_init_upper,
+            embedding_cache_mode,  # in embedding_cache_mode, we disable random init
         )

-        self.specs:
+        self.specs: list[tuple[int, int, int]] = [
             (rows, dims, sparse_type.as_int())
             for (_, rows, dims, sparse_type, _) in self.embedding_specs
         ]
         # table shard offset if inference sharding is enabled, otherwise, should be all zeros
-        self.table_sharding_offset:
+        self.table_sharding_offset: list[int] = [0] * len(self.embedding_specs)
         self.kv_embedding_cache_initialized = False
         self.hash_size_cumsum: torch.Tensor = torch.zeros(
             0,
@@ -137,7 +142,7 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
             dtype=torch.int64,
         )

-    def construct_hash_size_cumsum(self) ->
+    def construct_hash_size_cumsum(self) -> list[int]:
         hash_size_cumsum = [0]
         for spec in self.embedding_specs:
             rows = spec[1]
@@ -146,7 +151,7 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):

     def calculate_indices_and_weights_offsets(
         self, indices: Tensor, offsets: Tensor
-    ) ->
+    ) -> tuple[Tensor, Tensor]:
         if self.pooling_mode is not PoolingMode.NONE:
             T = self.weights_offsets.numel()
         else:
@@ -280,7 +285,7 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
         self.weight_initialized = True

     @torch.jit.export
-    def init_tbe_config(self, table_sharding_offset:
+    def init_tbe_config(self, table_sharding_offset: list[int]) -> None:
         """
         Initialize the dynamic TBE table configs, e.g. sharded table offsets, etc.
         Should be called before loading weights.
@@ -290,9 +295,9 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
     @torch.jit.export
     def embedding_inplace_update(
         self,
-        update_table_indices:
-        update_row_indices:
-        update_weights:
+        update_table_indices: list[int],
+        update_row_indices: list[list[int]],
+        update_weights: list[Tensor],
     ) -> None:
         # function is not used for now on the inference side
         for i in range(len(update_table_indices)):
@@ -355,9 +360,7 @@ class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen):
         if not self.kv_embedding_cache_initialized:
            self.initialize_logical_weights_placements_and_offsets()

-            self.row_alignment =
-                8 if self.use_cpu else self.row_alignment
-            )  # in order to use mempool implementation for kv embedding it needs to be divisible by 8
+            self.row_alignment = 8  # in order to use mempool implementation for kv embedding it needs to be divisible by 8

             hash_size_cumsum = self.construct_hash_size_cumsum()
             self.hash_size_cumsum = torch.tensor(
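The construct_hash_size_cumsum hunk above only shows the first lines of the method. As a rough, standalone illustration of a rows prefix-sum over embedding specs, here is a sketch; the helper name and the simplified spec layout are assumptions, not the class's actual code.

from typing import List, Tuple

# Illustrative sketch only; the real specs are (name, rows, dims, SparseType, EmbeddingLocation).
def hash_size_cumsum_sketch(embedding_specs: List[Tuple[str, int, int]]) -> List[int]:
    # Entry t is the global row offset of table t; the last entry is the total row count.
    cumsum = [0]
    for _, rows, _ in embedding_specs:
        cumsum.append(cumsum[-1] + rows)
    return cumsum

# Three tables with 100, 50, and 200 rows -> [0, 100, 150, 350]
print(hash_size_cumsum_sketch([("t0", 100, 16), ("t1", 50, 16), ("t2", 200, 16)]))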
fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py CHANGED
@@ -6,7 +6,7 @@

 # pyre-unsafe

-from typing import Optional,
+from typing import Optional, Union

 import torch

@@ -17,13 +17,13 @@ def get_unique_indices_v2(
     compute_count: bool = False,
     compute_inverse_indices: bool = False,
 ) -> Union[
-
-
+    tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
+    tuple[
         torch.Tensor,
         torch.Tensor,
         Optional[torch.Tensor],
     ],
-
+    tuple[torch.Tensor, torch.Tensor],
 ]:
     """
     A wrapper for get_unique_indices for overloading the return type
fbgemm_gpu/tbe/ssd/common.py CHANGED
fbgemm_gpu/tbe/ssd/inference.py CHANGED
@@ -13,7 +13,7 @@ import logging
 import os
 import tempfile
 from math import log2
-from typing import
+from typing import Optional

 import torch  # usort:skip

@@ -42,15 +42,15 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
     Inference version, with FP32/FP16/FP8/INT8/INT4/INT2 supports
     """

-    embedding_specs:
+    embedding_specs: list[tuple[str, int, int, SparseType]]
     _local_instance_index: int = -1

     def __init__(
         self,
-        embedding_specs:
-
+        embedding_specs: list[
+            tuple[str, int, int, SparseType]
         ],  # tuple of (feature_names, rows, dims, SparseType)
-        feature_table_map: Optional[
+        feature_table_map: Optional[list[int]] = None,  # [T]
         pooling_mode: PoolingMode = PoolingMode.SUM,
         output_dtype: SparseType = SparseType.FP16,
         row_alignment: Optional[int] = None,
@@ -73,7 +73,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
         ssd_uniform_init_lower: float = -0.01,
         ssd_uniform_init_upper: float = 0.01,
         # Parameter Server Configs
-        ps_hosts: Optional[
+        ps_hosts: Optional[tuple[tuple[str, int]]] = None,
         ps_max_key_per_request: Optional[int] = None,
         ps_client_thread_num: Optional[int] = None,
         ps_max_local_index_length: Optional[int] = None,
@@ -99,7 +99,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
         self.current_device = torch.device(device)
         self.use_cpu: bool = self.current_device.type == "cpu"

-        self.feature_table_map:
+        self.feature_table_map: list[int] = (
             feature_table_map if feature_table_map is not None else list(range(T_))
         )
         T = len(self.feature_table_map)
@@ -112,9 +112,9 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
         self.output_dtype: int = output_dtype.as_int()
         # (feature_names, rows, dims, weights_tys) = zip(*embedding_specs)
         # Pyre workaround
-        rows:
-        dims:
-        weights_tys:
+        rows: list[int] = [e[1] for e in embedding_specs]
+        dims: list[int] = [e[2] for e in embedding_specs]
+        weights_tys: list[SparseType] = [e[3] for e in embedding_specs]

         D_offsets = [dims[t] for t in self.feature_table_map]
         D_offsets = [0] + list(itertools.accumulate(D_offsets))
@@ -169,7 +169,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
             offsets.append(uvm_size)
             uvm_size += state_size

-        self.weights_physical_offsets:
+        self.weights_physical_offsets: list[int] = offsets

         weights_tys_int = [weights_tys[t].as_int() for t in self.feature_table_map]
         self.register_buffer(
@@ -306,7 +306,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
             )

             # pyre-fixme[20]: Argument `self` expected.
-
+            low_priority, high_priority = torch.cuda.Stream.priority_range()
             self.ssd_stream = torch.cuda.Stream(priority=low_priority)
             self.ssd_set_start = torch.cuda.Event()
             self.ssd_set_end = torch.cuda.Event()
@@ -369,7 +369,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):

     @torch.jit.export
     def prefetch(self, indices: Tensor, offsets: Tensor) -> Tensor:
-
+        indices, offsets = indices.long(), offsets.long()
         linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
             self.hash_size_cumsum,
             indices,
@@ -517,13 +517,13 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
     @torch.jit.export
     def split_embedding_weights(
         self, split_scale_shifts: bool = True
-    ) ->
+    ) -> list[tuple[Tensor, Optional[Tensor]]]:
         """
         Returns a list of weights, split by table.

         Testing only, very slow.
         """
-        splits:
+        splits: list[tuple[Tensor, Optional[Tensor]]] = []
         rows_cumsum = 0
         for _, row, dim, weight_ty in self.embedding_specs:
             weights = torch.empty(