fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of fbgemm-gpu-genai-nightly has been marked as a potentially problematic release.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/tbe/bench/tbe_data_config.py
@@ -0,0 +1,137 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import dataclasses
+ import json
+ import logging
+ from typing import Any, Optional
+
+ import torch
+
+ from fbgemm_gpu.tbe.utils.common import get_device
+
+ from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams
+
+ try:
+     torch.ops.load_library(
+         "//deeplearning/fbgemm/fbgemm_gpu/src/tbe/eeg:indices_generator"
+     )
+ except Exception:
+     pass
+
+
+ @dataclasses.dataclass(frozen=True)
+ class TBEDataConfig:
+     # Number of tables
+     T: int
+     # Number of rows in the embedding table
+     E: int
+     # Target embedding dimension for a table (number of columns)
+     D: int
+     # Generate mixed dimensions if true
+     mixed_dim: bool
+     # Whether the lookup rows are weighted or not
+     weighted: bool
+     # Batch parameters
+     batch_params: BatchParams
+     # Indices parameters
+     indices_params: IndicesParams
+     # Pooling parameters
+     pooling_params: PoolingParams
+     # Force generated tensors to be on CPU
+     use_cpu: bool = False
+     # Number of embeddings in each embedding features (number of rows)
+     Es: Optional[list[int]] = None
+     # Target embedding dimension for each features (number of columns)
+     Ds: Optional[list[int]] = None
+     # Maximum number of indices
+     max_indices: Optional[int] = None  # Maximum number of indices
+
+     def __post_init__(self) -> None:
+         if isinstance(self.D, list):
+             object.__setattr__(self, "mixed_dim", len(set(self.D)) > 1)
+         if isinstance(self.E, list) and self.max_indices is None:
+             object.__setattr__(self, "max_indices", sum(self.E) - 1)
+         self.validate()
+
+     @staticmethod
+     def complex_fields() -> dict[str, Any]:
+         return {
+             "batch_params": BatchParams,
+             "indices_params": IndicesParams,
+             "pooling_params": PoolingParams,
+         }
+
+     @classmethod
+     # pyre-ignore [3]
+     def from_dict(cls, data: dict[str, Any]):
+         for field, Type in cls.complex_fields().items():
+             if not isinstance(data[field], Type):
+                 data[field] = Type.from_dict(data[field])
+         return cls(**data)
+
+     @classmethod
+     # pyre-ignore [3]
+     def from_json(cls, data: str):
+         raw = json.loads(data)
+         allowed = {f.name for f in dataclasses.fields(cls)}
+         existing_fields = {k: v for k, v in raw.items() if k in allowed}
+         missing_fields = allowed - set(existing_fields.keys())
+         unknown_fields = set(raw.keys()) - allowed
+         if missing_fields:
+             logging.warning(
+                 f"TBEDataConfig.from_json: Missing expected fields not loaded: {sorted(missing_fields)}"
+             )
+         if unknown_fields:
+             logging.info(
+                 f"TBEDataConfig.from_json: Ignored unknown fields from input: {sorted(unknown_fields)}"
+             )
+         return cls.from_dict(existing_fields)
+
+     def dict(self) -> dict[str, Any]:
+         tmp = dataclasses.asdict(self)
+         for field in TBEDataConfig.complex_fields().keys():
+             tmp[field] = self.__dict__[field].dict()
+         return tmp
+
+     def json(self, format: bool = False) -> str:
+         return json.dumps(self.dict(), indent=(2 if format else -1), sort_keys=True)
+
+     # pyre-ignore [3]
+     def validate(self):
+         # NOTE: Add validation logic here
+         assert self.T > 0, "T must be positive"
+         assert self.E > 0, "E must be positive"
+         if self.Es is not None:
+             assert all(e > 0 for e in self.Es), "All elements in Es must be positive"
+         assert self.D > 0, "D must be positive"
+         if self.Ds is not None:
+             assert all(d > 0 for d in self.Ds), "All elements in Ds must be positive"
+         if isinstance(self.E, list) and isinstance(self.D, list):
+             assert (
+                 len(self.E) == len(self.D) == self.T
+             ), "Lengths of Es, Lengths of Ds, and T must be equal"
+         if self.max_indices is not None:
+             assert self.max_indices == (
+                 sum(self.Es) - 1
+             ), "max_indices must be equal to sum(Es) - 1"
+         self.batch_params.validate()
+         self.indices_params.validate()
+         self.pooling_params.validate()
+         return self
+
+     def variable_B(self) -> bool:
+         return self.batch_params.sigma_B is not None
+
+     def variable_L(self) -> bool:
+         return self.pooling_params.sigma_L is not None
+
+     def _new_weights(self, size: int) -> Optional[torch.Tensor]:
+         # Per-sample weights will always be FP32
+         return None if not self.weighted else torch.randn(size, device=get_device())
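
To make the serialization contract in tbe_data_config.py concrete, here is a minimal, self-contained sketch of the same pattern: nested parameter objects are registered as "complex fields" and rebuilt from plain dicts on load, so a config survives a JSON round trip. The Params and Config classes below are simplified stand-ins invented for illustration; the real BatchParams, IndicesParams, and PoolingParams live in tbe_data_config_param_models.py (file 92 above) and are not shown in this diff.

# Minimal sketch (not part of the package): the nested-dataclass round-trip
# pattern used by TBEDataConfig above. "Params" stands in for the real
# BatchParams/IndicesParams/PoolingParams, whose fields are defined elsewhere.
import dataclasses
import json
from typing import Any


@dataclasses.dataclass(frozen=True)
class Params:
    B: int
    sigma_B: Any = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Params":
        return cls(**data)

    def dict(self) -> dict[str, Any]:
        return dataclasses.asdict(self)


@dataclasses.dataclass(frozen=True)
class Config:
    T: int
    batch_params: Params

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Config":
        # Rebuild the nested dataclass from a plain dict, as TBEDataConfig.from_dict does
        if not isinstance(data["batch_params"], Params):
            data["batch_params"] = Params.from_dict(data["batch_params"])
        return cls(**data)

    def json(self) -> str:
        return json.dumps(dataclasses.asdict(self), sort_keys=True)


cfg = Config(T=4, batch_params=Params(B=512))
assert Config.from_dict(json.loads(cfg.json())) == cfg  # lossless round trip

In the real class, from_json additionally filters unknown keys and warns on missing ones before delegating to from_dict.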
fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py
@@ -0,0 +1,323 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ from typing import Optional
+
+ import numpy as np
+ import torch
+
+ from fbgemm_gpu.tbe.bench.tbe_data_config import TBEDataConfig
+ from fbgemm_gpu.tbe.utils.common import get_device, round_up
+
+ from fbgemm_gpu.tbe.utils.requests import (
+     generate_batch_sizes_from_stats,
+     generate_pooling_factors_from_stats,
+     get_table_batched_offsets_from_dense,
+     maybe_to_dtype,
+     TBERequest,
+ )
+
+ try:
+     # pyre-ignore[21]
+     from fbgemm_gpu import open_source  # noqa: F401
+ except Exception:
+     torch.ops.load_library(
+         "//deeplearning/fbgemm/fbgemm_gpu/src/tbe/eeg:indices_generator"
+     )
+
+
+ def _generate_batch_sizes(
+     tbe_data_config: TBEDataConfig,
+ ) -> tuple[list[int], Optional[list[list[int]]]]:
+     if tbe_data_config.variable_B():
+         assert (
+             tbe_data_config.batch_params.vbe_num_ranks is not None
+         ), "vbe_num_ranks must be set for variable batch size generation"
+         return generate_batch_sizes_from_stats(
+             tbe_data_config.batch_params.B,
+             tbe_data_config.T,
+             # pyre-ignore [6]
+             tbe_data_config.batch_params.sigma_B,
+             tbe_data_config.batch_params.vbe_num_ranks,
+             # pyre-ignore [6]
+             tbe_data_config.batch_params.vbe_distribution,
+         )
+
+     else:
+         return ([tbe_data_config.batch_params.B] * tbe_data_config.T, None)
+
+
+ def _generate_pooling_info(
+     tbe_data_config: TBEDataConfig, iters: int, Bs: list[int]
+ ) -> torch.Tensor:
+     if tbe_data_config.variable_L():
+         # Generate L from stats
+         _, L_offsets = generate_pooling_factors_from_stats(
+             iters,
+             Bs,
+             tbe_data_config.pooling_params.L,
+             # pyre-ignore [6]
+             tbe_data_config.pooling_params.sigma_L,
+             # pyre-ignore [6]
+             tbe_data_config.pooling_params.length_distribution,
+         )
+     else:
+         Ls = [tbe_data_config.pooling_params.L] * (sum(Bs) * iters)
+         L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
+
+     return L_offsets
+
+
+ def _generate_indices(
+     tbe_data_config: TBEDataConfig,
+     iters: int,
+     Bs: list[int],
+     L_offsets: torch.Tensor,
+ ) -> torch.Tensor:
+
+     total_B = sum(Bs)
+     L_offsets_list = L_offsets.tolist()
+     indices_list = []
+     for it in range(iters):
+         # L_offsets is defined over the entire set of batches for a single iteration
+         start_offset = L_offsets_list[it * total_B]
+         end_offset = L_offsets_list[(it + 1) * total_B]
+
+         indices_list.append(
+             torch.ops.fbgemm.tbe_generate_indices_from_distribution(
+                 tbe_data_config.indices_params.heavy_hitters,
+                 tbe_data_config.indices_params.zipf_q,
+                 tbe_data_config.indices_params.zipf_s,
+                 # max_index = dimensions of the embedding table
+                 tbe_data_config.E,
+                 # num_indices = number of indices to generate
+                 end_offset - start_offset,
+             )
+         )
+
+     return torch.cat(indices_list)
+
+
+ def _build_requests_jagged(
+     tbe_data_config: TBEDataConfig,
+     iters: int,
+     Bs: list[int],
+     Bs_feature_rank: Optional[list[list[int]]],
+     L_offsets: torch.Tensor,
+     all_indices: torch.Tensor,
+ ) -> list[TBERequest]:
+     total_B = sum(Bs)
+     all_indices = all_indices.flatten()
+     requests = []
+     for it in range(iters):
+         start_offset = L_offsets[it * total_B]
+         it_L_offsets = torch.concat(
+             [
+                 torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device),
+                 L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset,
+             ]
+         )
+         requests.append(
+             TBERequest(
+                 maybe_to_dtype(
+                     all_indices[start_offset : L_offsets[(it + 1) * total_B]],
+                     tbe_data_config.indices_params.index_dtype,
+                 ),
+                 maybe_to_dtype(
+                     it_L_offsets.to(get_device()),
+                     tbe_data_config.indices_params.offset_dtype,
+                 ),
+                 tbe_data_config._new_weights(int(it_L_offsets[-1].item())),
+                 Bs_feature_rank if tbe_data_config.variable_B() else None,
+             )
+         )
+     return requests
+
+
+ def _build_requests_dense(
+     tbe_data_config: TBEDataConfig, iters: int, all_indices: torch.Tensor
+ ) -> list[TBERequest]:
+     # NOTE: We're using existing code from requests.py to build the
+     # requests, and since the existing code requires 2D view of all_indices,
+     # the existing all_indices must be reshaped
+     all_indices = all_indices.reshape(iters, -1)
+
+     requests = []
+     for it in range(iters):
+         indices, offsets = get_table_batched_offsets_from_dense(
+             all_indices[it].view(
+                 tbe_data_config.T,
+                 tbe_data_config.batch_params.B,
+                 tbe_data_config.pooling_params.L,
+             ),
+             use_cpu=tbe_data_config.use_cpu,
+         )
+         requests.append(
+             TBERequest(
+                 maybe_to_dtype(indices, tbe_data_config.indices_params.index_dtype),
+                 maybe_to_dtype(offsets, tbe_data_config.indices_params.offset_dtype),
+                 tbe_data_config._new_weights(
+                     tbe_data_config.T
+                     * tbe_data_config.batch_params.B
+                     * tbe_data_config.pooling_params.L
+                 ),
+             )
+         )
+     return requests
+
+
+ def generate_requests(
+     tbe_data_config: TBEDataConfig,
+     iters: int = 1,
+     batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+ ) -> list[TBERequest]:
+
+     # Generate batch sizes
+     if batch_size_per_feature_per_rank:
+         Bs = tbe_data_config.batch_params.Bs
+     else:
+         Bs, _ = _generate_batch_sizes(tbe_data_config)
+
+     assert Bs is not None, "Batch sizes (Bs) must be set"
+
+     # Generate pooling info
+     L_offsets = _generate_pooling_info(tbe_data_config, iters, Bs)
+
+     # Generate indices
+     all_indices = _generate_indices(tbe_data_config, iters, Bs, L_offsets)
+     all_indices = all_indices.to(get_device())
+
+     # Build TBE requests
+     if tbe_data_config.variable_B() or tbe_data_config.variable_L():
+         if batch_size_per_feature_per_rank:
+             return _build_requests_jagged(
+                 tbe_data_config,
+                 iters,
+                 Bs,
+                 batch_size_per_feature_per_rank,
+                 L_offsets,
+                 all_indices,
+             )
+         else:
+             return _build_requests_jagged(
+                 tbe_data_config,
+                 iters,
+                 Bs,
+                 batch_size_per_feature_per_rank,
+                 L_offsets,
+                 all_indices,
+             )
+     else:
+         return _build_requests_dense(tbe_data_config, iters, all_indices)
+
+
+ def generate_requests_with_Llist(
+     tbe_data_config: TBEDataConfig,
+     L_list: torch.Tensor,
+     iters: int = 1,
+     batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+ ) -> list[TBERequest]:
+     """
+     Generate a list of TBERequest objects based on the provided TBE data configuration and L_list
+     This function generates batch sizes and pooling information from the input L_list,
+     simulates L distributions with Gaussian noise, and creates indices for embedding lookups.
+     It supports both variable batch sizes and sequence lengths, building either jagged or dense requests accordingly.
+     Args:
+         tbe_data_config (TBEDataConfig): Configuration object containing batch parameters and pooling parameters.
+         L_list (torch.Tensor): Tensor of base sequence lengths for each batch.
+         iters (int, optional): Number of iterations to repeat the generated requests. Defaults to 1.
+         batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Optional batch size specification per feature per rank. Defaults to None.
+     Returns:
+         List[TBERequest]: A list of TBERequest objects constructed according to the configuration and input parameters.
+     Raises:
+         AssertionError: If batch sizes (Bs) are not set in the tbe_data_config.
+     Example:
+         >>> requests = generate_requests_with_Llist(tbe_data_config, L_list=torch.tensor([10, 20]), iters=2)
+         >>> len(requests)
+         2
+     """
+
+     # Generate batch sizes
+     Bs = tbe_data_config.batch_params.Bs
+     assert (
+         Bs is not None
+     ), "Batch sizes (Bs) must be set for generate_requests_with_Llist"
+
+     # Generate pooling info from L list
+     Ls_list = []
+     for i in range(len(Bs)):
+         L = L_list[i]
+         B = Bs[i]
+         Ls_iter = np.random.normal(
+             loc=L, scale=tbe_data_config.pooling_params.sigma_L, size=B
+         ).astype(int)
+         Ls_list.append(Ls_iter)
+     Ls = np.concatenate(Ls_list)
+     Ls[Ls < 0] = 0
+     # Use the same L distribution across iters
+     Ls = np.tile(Ls, iters)
+     L = Ls.max()
+     # Make it exclusive cumsum
+     L_offsets = torch.from_numpy(np.insert(Ls.cumsum(), 0, 0)).to(torch.long)
+
+     # Generate indices
+     all_indices = _generate_indices(tbe_data_config, iters, Bs, L_offsets)
+     all_indices = all_indices.to(get_device())
+
+     # Build TBE requests
+     if tbe_data_config.variable_B() or tbe_data_config.variable_L():
+         return _build_requests_jagged(
+             tbe_data_config,
+             iters,
+             Bs,
+             batch_size_per_feature_per_rank,
+             L_offsets,
+             all_indices,
+         )
+     else:
+         return _build_requests_dense(tbe_data_config, iters, all_indices)
+
+
+ def generate_embedding_dims(tbe_data_config: TBEDataConfig) -> tuple[int, list[int]]:
+     if tbe_data_config.mixed_dim:
+         Ds = [
+             round_up(
+                 int(
+                     torch.randint(
+                         low=int(0.5 * tbe_data_config.D),
+                         high=int(1.5 * tbe_data_config.D),
+                         size=(1,),
+                     ).item()
+                 ),
+                 4,
+             )
+             for _ in range(tbe_data_config.T)
+         ]
+         return (sum(Ds) // len(Ds), Ds)
+     else:
+         return (tbe_data_config.D, [tbe_data_config.D] * tbe_data_config.T)
+
+
+ def generate_feature_requires_grad(
+     tbe_data_config: TBEDataConfig, size: int
+ ) -> torch.Tensor:
+     assert (
+         size <= tbe_data_config.T
+     ), "size of feature_requires_grad must be less than T"
+     weighted_requires_grad_tables = torch.randperm(tbe_data_config.T)[:size].tolist()
+     return (
+         torch.tensor(
+             [
+                 1 if t in weighted_requires_grad_tables else 0
+                 for t in range(tbe_data_config.T)
+             ]
+         )
+         .to(get_device())
+         .int()
+     )
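
All of the request builders above funnel per-bag pooling factors into a single exclusive-cumsum offsets tensor before any indices are generated. Below is a standalone sketch of that layout, mirroring the fixed-L branch of _generate_pooling_info and the numpy construction in generate_requests_with_Llist; the pooling factors here are made-up example values, not package data.

import numpy as np
import torch

# Made-up example: two tables (T=2) with batch size 3 each -> 6 bags per iteration
Ls = [2, 0, 3, 1, 4, 2]  # per-bag pooling factors

# Fixed-L branch style: prepend 0, then an inclusive cumsum yields exclusive offsets
L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
print(L_offsets)  # tensor([ 0,  2,  2,  5,  6, 10, 12])

# generate_requests_with_Llist style: numpy cumsum with a leading 0 gives the same layout
L_offsets_np = torch.from_numpy(np.insert(np.asarray(Ls).cumsum(), 0, 0)).to(torch.long)
assert torch.equal(L_offsets, L_offsets_np)

# Bag i owns indices[L_offsets[i] : L_offsets[i + 1]]; the last entry is the total
# number of indices the generator has to produce for this iteration.
total_indices = int(L_offsets[-1].item())  # 12

_generate_indices then asks the loaded indices_generator op for exactly that many indices per iteration (end_offset - start_offset), and the jagged and dense builders slice the result back into per-request tensors using the same offsets.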