fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of fbgemm-gpu-genai-nightly has been flagged as potentially problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
fbgemm_gpu/tbe/bench/tbe_data_config.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import dataclasses
+import json
+import logging
+from typing import Any, Optional
+
+import torch
+
+from fbgemm_gpu.tbe.utils.common import get_device
+
+from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams
+
+try:
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu/src/tbe/eeg:indices_generator"
+    )
+except Exception:
+    pass
+
+
+@dataclasses.dataclass(frozen=True)
+class TBEDataConfig:
+    # Number of tables
+    T: int
+    # Number of rows in the embedding table
+    E: int
+    # Target embedding dimension for a table (number of columns)
+    D: int
+    # Generate mixed dimensions if true
+    mixed_dim: bool
+    # Whether the lookup rows are weighted or not
+    weighted: bool
+    # Batch parameters
+    batch_params: BatchParams
+    # Indices parameters
+    indices_params: IndicesParams
+    # Pooling parameters
+    pooling_params: PoolingParams
+    # Force generated tensors to be on CPU
+    use_cpu: bool = False
+    # Number of embeddings in each embedding features (number of rows)
+    Es: Optional[list[int]] = None
+    # Target embedding dimension for each features (number of columns)
+    Ds: Optional[list[int]] = None
+    # Maximum number of indices
+    max_indices: Optional[int] = None  # Maximum number of indices
+
+    def __post_init__(self) -> None:
+        if isinstance(self.D, list):
+            object.__setattr__(self, "mixed_dim", len(set(self.D)) > 1)
+        if isinstance(self.E, list) and self.max_indices is None:
+            object.__setattr__(self, "max_indices", sum(self.E) - 1)
+        self.validate()
+
+    @staticmethod
+    def complex_fields() -> dict[str, Any]:
+        return {
+            "batch_params": BatchParams,
+            "indices_params": IndicesParams,
+            "pooling_params": PoolingParams,
+        }
+
+    @classmethod
+    # pyre-ignore [3]
+    def from_dict(cls, data: dict[str, Any]):
+        for field, Type in cls.complex_fields().items():
+            if not isinstance(data[field], Type):
+                data[field] = Type.from_dict(data[field])
+        return cls(**data)
+
+    @classmethod
+    # pyre-ignore [3]
+    def from_json(cls, data: str):
+        raw = json.loads(data)
+        allowed = {f.name for f in dataclasses.fields(cls)}
+        existing_fields = {k: v for k, v in raw.items() if k in allowed}
+        missing_fields = allowed - set(existing_fields.keys())
+        unknown_fields = set(raw.keys()) - allowed
+        if missing_fields:
+            logging.warning(
+                f"TBEDataConfig.from_json: Missing expected fields not loaded: {sorted(missing_fields)}"
+            )
+        if unknown_fields:
+            logging.info(
+                f"TBEDataConfig.from_json: Ignored unknown fields from input: {sorted(unknown_fields)}"
+            )
+        return cls.from_dict(existing_fields)
+
+    def dict(self) -> dict[str, Any]:
+        tmp = dataclasses.asdict(self)
+        for field in TBEDataConfig.complex_fields().keys():
+            tmp[field] = self.__dict__[field].dict()
+        return tmp
+
+    def json(self, format: bool = False) -> str:
+        return json.dumps(self.dict(), indent=(2 if format else -1), sort_keys=True)
+
+    # pyre-ignore [3]
+    def validate(self):
+        # NOTE: Add validation logic here
+        assert self.T > 0, "T must be positive"
+        assert self.E > 0, "E must be positive"
+        if self.Es is not None:
+            assert all(e > 0 for e in self.Es), "All elements in Es must be positive"
+        assert self.D > 0, "D must be positive"
+        if self.Ds is not None:
+            assert all(d > 0 for d in self.Ds), "All elements in Ds must be positive"
+        if isinstance(self.E, list) and isinstance(self.D, list):
+            assert (
+                len(self.E) == len(self.D) == self.T
+            ), "Lengths of Es, Lengths of Ds, and T must be equal"
+        if self.max_indices is not None:
+            assert self.max_indices == (
+                sum(self.Es) - 1
+            ), "max_indices must be equal to sum(Es) - 1"
+        self.batch_params.validate()
+        self.indices_params.validate()
+        self.pooling_params.validate()
+        return self
+
+    def variable_B(self) -> bool:
+        return self.batch_params.sigma_B is not None
+
+    def variable_L(self) -> bool:
+        return self.pooling_params.sigma_L is not None
+
+    def _new_weights(self, size: int) -> Optional[torch.Tensor]:
+        # Per-sample weights will always be FP32
+        return None if not self.weighted else torch.randn(size, device=get_device())
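For orientation, the dataclass above is the serializable description of a TBE (table-batched embedding) benchmark workload. The following minimal sketch shows the JSON round-trip through TBEDataConfig.json() and TBEDataConfig.from_json(). The constructors of BatchParams, IndicesParams, and PoolingParams live in tbe_data_config_param_models.py, which is not shown in this diff, so the keyword arguments passed to them below are assumptions inferred from attribute accesses elsewhere in this release, not confirmed signatures.

# Illustrative sketch only; not part of the packaged files.
# BatchParams/IndicesParams/PoolingParams keyword arguments are assumed.
import torch

from fbgemm_gpu.tbe.bench.tbe_data_config import TBEDataConfig
from fbgemm_gpu.tbe.bench.tbe_data_config_param_models import (
    BatchParams,
    IndicesParams,
    PoolingParams,
)

config = TBEDataConfig(
    T=4,            # number of tables
    E=100_000,      # rows per table
    D=128,          # embedding dimension
    mixed_dim=False,
    weighted=False,
    batch_params=BatchParams(B=512),          # assumed constructor
    indices_params=IndicesParams(             # assumed constructor
        heavy_hitters=torch.tensor([]),
        zipf_q=0.1,
        zipf_s=0.1,
        index_dtype=torch.int64,
        offset_dtype=torch.int64,
    ),
    pooling_params=PoolingParams(L=20),       # assumed constructor
    use_cpu=True,
)

# Round-trip through JSON; from_json() logs (rather than fails on) missing or
# unknown keys, then delegates nested fields to the param models' from_dict(),
# just as dict() delegates to their dict() helpers.
restored = TBEDataConfig.from_json(config.json(format=True))
assert restored.T == config.T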
fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py
@@ -0,0 +1,323 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Optional
+
+import numpy as np
+import torch
+
+from fbgemm_gpu.tbe.bench.tbe_data_config import TBEDataConfig
+from fbgemm_gpu.tbe.utils.common import get_device, round_up
+
+from fbgemm_gpu.tbe.utils.requests import (
+    generate_batch_sizes_from_stats,
+    generate_pooling_factors_from_stats,
+    get_table_batched_offsets_from_dense,
+    maybe_to_dtype,
+    TBERequest,
+)
+
+try:
+    # pyre-ignore[21]
+    from fbgemm_gpu import open_source  # noqa: F401
+except Exception:
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu/src/tbe/eeg:indices_generator"
+    )
+
+
+def _generate_batch_sizes(
+    tbe_data_config: TBEDataConfig,
+) -> tuple[list[int], Optional[list[list[int]]]]:
+    if tbe_data_config.variable_B():
+        assert (
+            tbe_data_config.batch_params.vbe_num_ranks is not None
+        ), "vbe_num_ranks must be set for varaible batch size generation"
+        return generate_batch_sizes_from_stats(
+            tbe_data_config.batch_params.B,
+            tbe_data_config.T,
+            # pyre-ignore [6]
+            tbe_data_config.batch_params.sigma_B,
+            tbe_data_config.batch_params.vbe_num_ranks,
+            # pyre-ignore [6]
+            tbe_data_config.batch_params.vbe_distribution,
+        )
+
+    else:
+        return ([tbe_data_config.batch_params.B] * tbe_data_config.T, None)
+
+
+def _generate_pooling_info(
+    tbe_data_config: TBEDataConfig, iters: int, Bs: list[int]
+) -> torch.Tensor:
+    if tbe_data_config.variable_L():
+        # Generate L from stats
+        _, L_offsets = generate_pooling_factors_from_stats(
+            iters,
+            Bs,
+            tbe_data_config.pooling_params.L,
+            # pyre-ignore [6]
+            tbe_data_config.pooling_params.sigma_L,
+            # pyre-ignore [6]
+            tbe_data_config.pooling_params.length_distribution,
+        )
+    else:
+        Ls = [tbe_data_config.pooling_params.L] * (sum(Bs) * iters)
+        L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
+
+    return L_offsets
+
+
+def _generate_indices(
+    tbe_data_config: TBEDataConfig,
+    iters: int,
+    Bs: list[int],
+    L_offsets: torch.Tensor,
+) -> torch.Tensor:
+
+    total_B = sum(Bs)
+    L_offsets_list = L_offsets.tolist()
+    indices_list = []
+    for it in range(iters):
+        # L_offsets is defined over the entire set of batches for a single iteration
+        start_offset = L_offsets_list[it * total_B]
+        end_offset = L_offsets_list[(it + 1) * total_B]
+
+        indices_list.append(
+            torch.ops.fbgemm.tbe_generate_indices_from_distribution(
+                tbe_data_config.indices_params.heavy_hitters,
+                tbe_data_config.indices_params.zipf_q,
+                tbe_data_config.indices_params.zipf_s,
+                # max_index = dimensions of the embedding table
+                tbe_data_config.E,
+                # num_indices = number of indices to generate
+                end_offset - start_offset,
+            )
+        )
+
+    return torch.cat(indices_list)
+
+
+def _build_requests_jagged(
+    tbe_data_config: TBEDataConfig,
+    iters: int,
+    Bs: list[int],
+    Bs_feature_rank: Optional[list[list[int]]],
+    L_offsets: torch.Tensor,
+    all_indices: torch.Tensor,
+) -> list[TBERequest]:
+    total_B = sum(Bs)
+    all_indices = all_indices.flatten()
+    requests = []
+    for it in range(iters):
+        start_offset = L_offsets[it * total_B]
+        it_L_offsets = torch.concat(
+            [
+                torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device),
+                L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset,
+            ]
+        )
+        requests.append(
+            TBERequest(
+                maybe_to_dtype(
+                    all_indices[start_offset : L_offsets[(it + 1) * total_B]],
+                    tbe_data_config.indices_params.index_dtype,
+                ),
+                maybe_to_dtype(
+                    it_L_offsets.to(get_device()),
+                    tbe_data_config.indices_params.offset_dtype,
+                ),
+                tbe_data_config._new_weights(int(it_L_offsets[-1].item())),
+                Bs_feature_rank if tbe_data_config.variable_B() else None,
+            )
+        )
+    return requests
+
+
+def _build_requests_dense(
+    tbe_data_config: TBEDataConfig, iters: int, all_indices: torch.Tensor
+) -> list[TBERequest]:
+    # NOTE: We're using existing code from requests.py to build the
+    # requests, and since the existing code requires 2D view of all_indices,
+    # the existing all_indices must be reshaped
+    all_indices = all_indices.reshape(iters, -1)
+
+    requests = []
+    for it in range(iters):
+        indices, offsets = get_table_batched_offsets_from_dense(
+            all_indices[it].view(
+                tbe_data_config.T,
+                tbe_data_config.batch_params.B,
+                tbe_data_config.pooling_params.L,
+            ),
+            use_cpu=tbe_data_config.use_cpu,
+        )
+        requests.append(
+            TBERequest(
+                maybe_to_dtype(indices, tbe_data_config.indices_params.index_dtype),
+                maybe_to_dtype(offsets, tbe_data_config.indices_params.offset_dtype),
+                tbe_data_config._new_weights(
+                    tbe_data_config.T
+                    * tbe_data_config.batch_params.B
+                    * tbe_data_config.pooling_params.L
+                ),
+            )
+        )
+    return requests
+
+
+def generate_requests(
+    tbe_data_config: TBEDataConfig,
+    iters: int = 1,
+    batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+) -> list[TBERequest]:
+
+    # Generate batch sizes
+    if batch_size_per_feature_per_rank:
+        Bs = tbe_data_config.batch_params.Bs
+    else:
+        Bs, _ = _generate_batch_sizes(tbe_data_config)
+
+    assert Bs is not None, "Batch sizes (Bs) must be set"
+
+    # Generate pooling info
+    L_offsets = _generate_pooling_info(tbe_data_config, iters, Bs)
+
+    # Generate indices
+    all_indices = _generate_indices(tbe_data_config, iters, Bs, L_offsets)
+    all_indices = all_indices.to(get_device())
+
+    # Build TBE requests
+    if tbe_data_config.variable_B() or tbe_data_config.variable_L():
+        if batch_size_per_feature_per_rank:
+            return _build_requests_jagged(
+                tbe_data_config,
+                iters,
+                Bs,
+                batch_size_per_feature_per_rank,
+                L_offsets,
+                all_indices,
+            )
+        else:
+            return _build_requests_jagged(
+                tbe_data_config,
+                iters,
+                Bs,
+                batch_size_per_feature_per_rank,
+                L_offsets,
+                all_indices,
+            )
+    else:
+        return _build_requests_dense(tbe_data_config, iters, all_indices)
+
+
+def generate_requests_with_Llist(
+    tbe_data_config: TBEDataConfig,
+    L_list: torch.Tensor,
+    iters: int = 1,
+    batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+) -> list[TBERequest]:
+    """
+    Generate a list of TBERequest objects based on the provided TBE data configuration and L_list
+    This function generates batch sizes and pooling information from the input L_list,
+    simulates L distributions with Gaussian noise, and creates indices for embedding lookups.
+    It supports both variable batch sizes and sequence lengths, building either jagged or dense requests accordingly.
+    Args:
+        tbe_data_config (TBEDataConfig): Configuration object containing batch parameters and pooling parameters.
+        L_list (torch.Tensor): Tensor of base sequence lengths for each batch.
+        iters (int, optional): Number of iterations to repeat the generated requests. Defaults to 1.
+        batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Optional batch size specification per feature per rank. Defaults to None.
+    Returns:
+        List[TBERequest]: A list of TBERequest objects constructed according to the configuration and input parameters.
+    Raises:
+        AssertionError: If batch sizes (Bs) are not set in the tbe_data_config.
+    Example:
+        >>> requests = generate_requests_with_Llist(tbe_data_config, L_list=torch.tensor([10, 20]), iters=2)
+        >>> len(requests)
+        2
+    """

+    # Generate batch sizes
+    Bs = tbe_data_config.batch_params.Bs
+    assert (
+        Bs is not None
+    ), "Batch sizes (Bs) must be set for generate_requests_with_Llist"
+
+    # Generate pooling info from L list
+    Ls_list = []
+    for i in range(len(Bs)):
+        L = L_list[i]
+        B = Bs[i]
+        Ls_iter = np.random.normal(
+            loc=L, scale=tbe_data_config.pooling_params.sigma_L, size=B
+        ).astype(int)
+        Ls_list.append(Ls_iter)
+    Ls = np.concatenate(Ls_list)
+    Ls[Ls < 0] = 0
+    # Use the same L distribution across iters
+    Ls = np.tile(Ls, iters)
+    L = Ls.max()
+    # Make it exclusive cumsum
+    L_offsets = torch.from_numpy(np.insert(Ls.cumsum(), 0, 0)).to(torch.long)
+
+    # Generate indices
+    all_indices = _generate_indices(tbe_data_config, iters, Bs, L_offsets)
+    all_indices = all_indices.to(get_device())
+
+    # Build TBE requests
+    if tbe_data_config.variable_B() or tbe_data_config.variable_L():
+        return _build_requests_jagged(
+            tbe_data_config,
+            iters,
+            Bs,
+            batch_size_per_feature_per_rank,
+            L_offsets,
+            all_indices,
+        )
+    else:
+        return _build_requests_dense(tbe_data_config, iters, all_indices)
+
+
+def generate_embedding_dims(tbe_data_config: TBEDataConfig) -> tuple[int, list[int]]:
+    if tbe_data_config.mixed_dim:
+        Ds = [
+            round_up(
+                int(
+                    torch.randint(
+                        low=int(0.5 * tbe_data_config.D),
+                        high=int(1.5 * tbe_data_config.D),
+                        size=(1,),
+                    ).item()
+                ),
+                4,
+            )
+            for _ in range(tbe_data_config.T)
+        ]
+        return (sum(Ds) // len(Ds), Ds)
+    else:
+        return (tbe_data_config.D, [tbe_data_config.D] * tbe_data_config.T)
+
+
+def generate_feature_requires_grad(
+    tbe_data_config: TBEDataConfig, size: int
+) -> torch.Tensor:
+    assert (
+        size <= tbe_data_config.T
+    ), "size of feature_requires_grad must be less than T"
+    weighted_requires_grad_tables = torch.randperm(tbe_data_config.T)[:size].tolist()
+    return (
+        torch.tensor(
+            [
+                1 if t in weighted_requires_grad_tables else 0
+                for t in range(tbe_data_config.T)
+            ]
+        )
+        .to(get_device())
+        .int()
+    )
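Taken together, the helpers above turn a TBEDataConfig into synthetic TBERequest batches: _generate_batch_sizes and _generate_pooling_info expand the batch and pooling statistics, _generate_indices draws indices through the tbe_generate_indices_from_distribution op, and _build_requests_jagged / _build_requests_dense package the results. A minimal driver sketch follows; it assumes a config object like the one sketched after tbe_data_config.py above, and the TBERequest.unpack_3() call is an assumption about the request API in fbgemm_gpu/tbe/utils/requests.py (not shown in this diff).

# Illustrative sketch only; not part of the packaged files.
from fbgemm_gpu.tbe.bench.tbe_data_config_bench_helper import (
    generate_embedding_dims,
    generate_feature_requires_grad,
    generate_requests,
)

# Per-table embedding dims; uniform here because mixed_dim=False in `config`.
D_avg, Ds = generate_embedding_dims(config)

# Mark two of the four tables as requiring per-feature gradients (illustration).
feature_requires_grad = generate_feature_requires_grad(config, size=2)

# One TBERequest per iteration. With sigma_B and sigma_L unset, variable_B()
# and variable_L() are False, so the dense request-building path is taken.
requests = generate_requests(config, iters=10)

for req in requests:
    # Assumed TBERequest accessor; the diff constructs requests positionally
    # as (indices, offsets, per_sample_weights, Bs_per_feature_per_rank).
    indices, offsets, weights = req.unpack_3()
    # ... feed (indices, offsets, weights) to the embedding op under test ...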