fbgemm-gpu-genai-nightly 2025.12.17__cp313-cp313-manylinux_2_28_x86_64.whl → 2026.1.9__cp313-cp313-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/config/feature_list.py +3 -0
- fbgemm_gpu/docs/target.genai.json.py +1 -1
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +10 -18
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/quantize_comm.py +8 -10
- fbgemm_gpu/quantize_utils.py +58 -6
- fbgemm_gpu/split_embedding_configs.py +34 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +57 -35
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/tbe/ssd/training.py +24 -2
- fbgemm_gpu/triton/quantize.py +13 -7
- fbgemm_gpu/utils/writeback_util.py +124 -0
- {fbgemm_gpu_genai_nightly-2025.12.17.dist-info → fbgemm_gpu_genai_nightly-2026.1.9.dist-info}/METADATA +1 -1
- {fbgemm_gpu_genai_nightly-2025.12.17.dist-info → fbgemm_gpu_genai_nightly-2026.1.9.dist-info}/RECORD +19 -18
- {fbgemm_gpu_genai_nightly-2025.12.17.dist-info → fbgemm_gpu_genai_nightly-2026.1.9.dist-info}/WHEEL +0 -0
- {fbgemm_gpu_genai_nightly-2025.12.17.dist-info → fbgemm_gpu_genai_nightly-2026.1.9.dist-info}/top_level.txt +0 -0
fbgemm_gpu/asmjit.so
CHANGED
Binary file
fbgemm_gpu/config/feature_list.py
CHANGED
@@ -63,6 +63,9 @@ class FeatureGateName(Enum):
     # Enable TBE input parameters extraction
     TBE_REPORT_INPUT_PARAMS = auto()

+    # Enable tuned max segment length per CTA for B200
+    TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200 = auto()
+
     def is_enabled(self) -> bool:
         return FeatureGate.is_enabled(self)

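For context, a minimal sketch of how a caller could query the new gate. FeatureGateName and is_enabled() come from the diff above; the segment-length values are made-up placeholders, not tuned numbers from this release.

# Illustrative sketch only.
from fbgemm_gpu.config.feature_list import FeatureGateName

if FeatureGateName.TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200.is_enabled():
    max_segment_length_per_cta = 1024  # placeholder "tuned" value for B200
else:
    max_segment_length_per_cta = 4096  # placeholder default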
fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so
CHANGED
Binary file
fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py
CHANGED
@@ -289,7 +289,7 @@ def cutlass_blackwell_fmha_decode_forward(
     window_left: int = -1,
     window_right: int = -1,
     bottom_right: bool = True,
-    split_k_size: int =
+    split_k_size: int = 0,
     use_heuristic: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
@@ -318,14 +318,9 @@
             split size using the heuristic. Default is True.

     Returns:
-
-        -
-
-            with bfloat16 dtype
-        lse: [B, H, 1] (always float32)
-        - Split case (split_k_size > 0 or use_heuristic=True):
-        out: [B, H, num_splits, D] with float32 dtype (partial outputs for later reduction)
-        lse: [B, num_splits, H] (always float32)
+        Kernel output with Q dimension added:
+        - out: [B, 1, H, num_splits, D] (num_splits=1 when split-K disabled)
+        - lse: [B, num_splits, H, 1]
     """
     _validate_decode_inputs(q, k, v, seqlen_kv)

@@ -365,15 +360,12 @@ def cutlass_blackwell_fmha_decode_forward(
         split_k_size=split_k_size,
     )

-    #
-
-
-
-
-
-    # lse shape: [B, Splits = 1, H] -> [B, H, 1]
-    lse = lse.view(batch_size, -1, 1)
-
+    # Kernel returns: out [B, H, num_splits, D], lse [B, num_splits, H]
+    # Reshape to consistent format with Q dimension:
+    # out: [B, H, num_splits, D] -> [B, 1, H, num_splits, D]
+    # lse: [B, num_splits, H] -> [B, num_splits, H, 1]
+    out = out.unsqueeze(1)  # [B, 1, H, num_splits, D]
+    lse = lse.unsqueeze(-1)  # [B, num_splits, H, 1]
     return out, lse

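For context, the decode path now always returns split-K style shapes: out [B, 1, H, num_splits, D] and lse [B, num_splits, H, 1]. A standard log-sum-exp combine over the split dimension might look like the sketch below; whether the FBGEMM op reduces the splits internally or expects the caller to do it is not shown in this diff, so treat this purely as an illustration of what the shapes mean.

import torch

def combine_split_k(out: torch.Tensor, lse: torch.Tensor) -> torch.Tensor:
    # out: [B, 1, H, num_splits, D] partial outputs; lse: [B, num_splits, H, 1]
    lse = lse.squeeze(-1).permute(0, 2, 1)                  # -> [B, H, num_splits]
    total_lse = torch.logsumexp(lse, dim=-1, keepdim=True)  # -> [B, H, 1]
    weights = torch.exp(lse - total_lse)                    # per-split weights
    combined = (out.squeeze(1) * weights.unsqueeze(-1)).sum(dim=2)  # -> [B, H, D]
    return combined.unsqueeze(1)                            # restore Q dim -> [B, 1, H, D]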
fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so
CHANGED
Binary file

fbgemm_gpu/fbgemm.so
CHANGED
Binary file
fbgemm_gpu/quantize_comm.py
CHANGED
@@ -25,7 +25,7 @@ from fbgemm_gpu.quantize_utils import (
     fp32_to_hfp8_with_clamp,
     fp32_to_mx4,
     hfp8_to_fp32,
-
+    mx4_to_float,
     RoundingMode,
 )

@@ -123,7 +123,7 @@ def _dequantize_tensor(
     comm_precision: SparseType,
     ctx: Optional[QuantizationContext] = None,
     is_fwd: bool = True,
-
+    output_dtype: Optional[SparseType] = None,
 ) -> torch.Tensor:
     if comm_precision == SparseType.FP32:
         assert quantized_tensor.dtype == torch.float
@@ -138,10 +138,8 @@ def _dequantize_tensor(
         if ctx is not None and ctx.row_dim > 0:
             row_dim_quant = ctx.row_dim_quant
             quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
-            # use provided
-            output_dtype_int = (
-                fp8_output_dtype.as_int() if fp8_output_dtype is not None else 0
-            )
+            # use provided output_dtype or default to FP32 (0)
+            output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
             dequant_tensor = torch.ops.fbgemm.FP8RowwiseQuantizedToFloat(
                 quantized_tensor_2d,
                 is_fwd,
@@ -161,7 +159,7 @@ def _dequantize_tensor(
         return dequant_tensor.view(-1)
     elif comm_precision == SparseType.MX4:
         mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
-        return
+        return mx4_to_float(quantized_tensor, mx_group_size, output_dtype=output_dtype)
     else:
         raise ValueError(f"comm_precision={comm_precision} is not supported")

@@ -175,7 +173,7 @@ class QuantizedCommCodec:
         row_dim: Optional[int] = None,
         is_fwd: bool = True,
         rounding_mode: Optional[RoundingMode] = None,
-
+        output_dtype: Optional[SparseType] = None,
     ) -> None:
         if loss_scale is not None:
             if comm_precision not in [SparseType.FP16, SparseType.BF16]:
@@ -193,7 +191,7 @@ class QuantizedCommCodec:
         self._is_fwd = is_fwd
         self._row_dim: int = -1 if row_dim is None else row_dim
         self._rounding_mode: Optional[RoundingMode] = rounding_mode
-        self.
+        self._output_dtype: Optional[SparseType] = output_dtype
         if self._comm_precision == SparseType.MX4:
             self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
             self._rounding_mode = (
@@ -229,7 +227,7 @@ class QuantizedCommCodec:
             self._comm_precision,
            ctx,
             self._is_fwd,
-
+            output_dtype=self._output_dtype,
         )
         return dequantized_tensor

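For context, a hedged usage sketch of the new output_dtype plumbing: constructing the codec with output_dtype=SparseType.BF16 should make MX4 payloads dequantize straight to BF16 instead of FP32. The encode/decode call pattern, the remaining constructor arguments, and the CUDA tensor below are assumptions for illustration, not taken from this diff.

import torch
from fbgemm_gpu.quantize_comm import QuantizedCommCodec
from fbgemm_gpu.split_embedding_configs import SparseType

codec = QuantizedCommCodec(
    comm_precision=SparseType.MX4,   # quantize comm payloads to MX4
    loss_scale=None,
    output_dtype=SparseType.BF16,    # new: dequantize directly to BF16
)
payload = torch.randn(1024, device="cuda")
restored = codec.decode(codec.encode(payload))  # assumed codec API
assert restored.dtype == torch.bfloat16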
fbgemm_gpu/quantize_utils.py
CHANGED
@@ -14,9 +14,15 @@ import torch  # isort:skip

 import fbgemm_gpu

-from fbgemm_gpu.
+from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.triton.common import RoundingMode
 from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4

+if torch.cuda.is_available():
+    from fbgemm_gpu.triton import quantize_mx4
+    from fbgemm_gpu.triton.quantize import triton_dequantize_mx4
+
+
 try:
     # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
     open_source = bool(getattr(fbgemm_gpu, "open_source", False))
@@ -126,25 +132,71 @@ def mx4_to_fp32(
 ) -> torch.Tensor:
     """Dequantize an MX4 tensor to FP32 with triton or native cuda impl.

+    This function is kept for backward compatibility and always returns FP32.
+    For BF16 output, use mx4_to_float() with output_dtype=SparseType.BF16.
+    """
+    return mx4_to_float(
+        tensor,
+        group_size,
+        use_triton,
+        ebits,
+        mbits,
+        output_dtype=None,  # None = FP32 default for backward compatibility
+    )
+
+
+def mx4_to_float(
+    tensor: torch.Tensor,
+    group_size: int = 32,
+    use_triton: bool = True,
+    ebits: int = 2,
+    mbits: int = 1,
+    output_dtype: Optional[SparseType] = None,
+) -> torch.Tensor:
+    """Dequantize an MX4 tensor to FP32 or BF16 with triton or native cuda impl.
+
     Args:
         tensor (torch.Tensor): MX4 packed tensor with total elements (M / 2 + M / groupsize)
         group_size (int): Compute scale in chunks of group_size.
         use_triton (bool): If set, use triton quantization, otherwise cuda.
         ebits (int): Number of exponent bits in target mx4 format.
         mbits (int): Number of mantissa bits in target mx4 format.
+        output_dtype (Optional[SparseType]): Output dtype (FP32 or BF16).
+            Defaults to None (FP32) for backward compatibility.

     Return:
-        output:
+        output: Tensor with dtype matching output_dtype and total elements (M).
     """
+    # Validate output_dtype
+    supported_dtypes = {SparseType.FP32, SparseType.BF16}
+    if output_dtype is not None and output_dtype not in supported_dtypes:
+        raise ValueError(
+            f"output_dtype must be one of {supported_dtypes}, got {output_dtype}. "
+            f"FP16 is not supported due to potential overflow/underflow with MX4's wide exponent range. "
+            f"Use BF16 for memory savings with same dynamic range as FP32."
+        )
+
+    target_dtype = (
+        output_dtype.as_dtype() if output_dtype is not None else torch.float32
+    )
+
     # Accelerated MX4 dequantize is only available on cuda, if input is on cpu, use python.
     if not tensor.is_cuda and not tensor.is_mtia:
-
+        result = py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+        return result.to(target_dtype) if output_dtype is not None else result
     if use_triton:
         if tensor.is_mtia:
-            return mtia_dequantize_mx4(
-
+            return mtia_dequantize_mx4(
+                tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+            )
+        return triton_dequantize_mx4(
+            tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+        )
     else:
-
+        output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
+        return torch.ops.fbgemm.dequantize_mx_cuda(
+            tensor.flatten(), group_size, output_dtype_int
+        )


 def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
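For context, a small round trip through the new API, assuming a CUDA device; fp32_to_mx4 is the packing counterpart imported alongside mx4_to_float in quantize_comm above, and the sizes here are illustrative.

import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_float
from fbgemm_gpu.split_embedding_configs import SparseType

x = torch.randn(128, device="cuda")                    # multiple of group_size
packed = fp32_to_mx4(x, group_size=32)                 # MX4-packed uint8 payload
y_fp32 = mx4_to_float(packed, group_size=32)           # default output: float32
y_bf16 = mx4_to_float(packed, group_size=32, output_dtype=SparseType.BF16)
assert y_fp32.dtype == torch.float32 and y_bf16.dtype == torch.bfloat16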
fbgemm_gpu/split_embedding_configs.py
CHANGED
@@ -313,6 +313,40 @@ def sparse_type_to_int(sparse_type: "SparseType") -> int:
     }[sparse_type.value]


+def sparse_type_int_to_dtype(ty: int) -> torch.dtype:
+    """
+    TorchScript-compatible function to convert a SparseType enum value (as an integer) to torch.dtype.
+
+    This is a standalone function equivalent to SparseType.from_int(dtype_int).as_dtype() that works
+    with TorchScript. TorchScript does not support @staticmethod on Enum classes,
+    so this function provides a workaround.
+    """
+    if ty == 0:  # fp32
+        return torch.float32
+    elif ty == 1:  # fp16
+        return torch.float16
+    elif ty == 2:  # int8
+        return torch.uint8
+    elif ty == 3:  # int4
+        return torch.quint4x2
+    elif ty == 4:  # int2
+        return torch.quint2x4
+    elif ty == 5:  # bf16
+        return torch.bfloat16
+    elif ty == 6:  # fp8
+        return torch.uint8
+    elif ty == 7:  # mx4
+        return torch.uint8
+    elif ty == 9:
+        return (
+            torch.float8_e4m3fnuz
+            if torch.version.hip is not None
+            else torch.float8_e4m3fn
+        )
+    else:  # Invalid is 7 or non enumerated.
+        raise ValueError(f"Unsupported sparse type: {ty}")
+
+
 @enum.unique
 class SparseType(enum.Enum):
     FP32 = "fp32"
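For context, the intent stated in the docstring above can be sanity-checked with the enum's existing helpers (SparseType.as_int() and as_dtype() already exist in split_embedding_configs):

import torch
from fbgemm_gpu.split_embedding_configs import SparseType, sparse_type_int_to_dtype

assert sparse_type_int_to_dtype(SparseType.BF16.as_int()) == torch.bfloat16
assert sparse_type_int_to_dtype(SparseType.FP16.as_int()) == torch.float16
assert sparse_type_int_to_dtype(SparseType.INT8.as_int()) == torch.uint8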
fbgemm_gpu/split_table_batched_embeddings_ops_training.py
CHANGED
@@ -49,6 +49,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     SplitState,
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
+    check_allocated_vbe_output,
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
@@ -60,6 +61,7 @@ from fbgemm_gpu.tbe_input_multiplexer import (
 )

 from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
+from fbgemm_gpu.utils.writeback_util import writeback_gradient

 try:
     load_torch_module(
@@ -159,6 +161,7 @@ class UserEnabledConfigDefinition:
     # More details can be found in D64848802.
     use_rowwise_bias_correction: bool = False
     use_writeback_bwd_prehook: bool = False
+    writeback_first_feature_only: bool = False


 @dataclass(frozen=True)
@@ -1181,6 +1184,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self.use_writeback_bwd_prehook: bool = (
             extra_optimizer_config.use_writeback_bwd_prehook
         )
+
+        writeback_first_feature_only: bool = (
+            extra_optimizer_config.writeback_first_feature_only
+        )
         self.log(f"self.extra_optimizer_config is {extra_optimizer_config}")
         if self.use_rowwise_bias_correction and not self.optimizer == OptimType.ADAM:
             raise AssertionError(
@@ -1469,6 +1476,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         # self.log("TBE_V2 Knob is set to True; Using experimental TBE")

         self.is_experimental: bool = is_experimental
+        self._writeback_first_feature_only: bool = writeback_first_feature_only

         # Get a debug function pointer
         self._debug_print_input_stats: Callable[..., None] = (
@@ -1483,7 +1491,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
             # Register writeback hook for Exact_SGD optimizer
             self.log(
-                "SplitTableBatchedEmbeddingBagsCodegen: use_writeback_bwd_prehook is enabled."
+                f"SplitTableBatchedEmbeddingBagsCodegen: use_writeback_bwd_prehook is enabled with first feature only={self._writeback_first_feature_only}"
             )
             # pyre-fixme[6]: Expected `typing.Callable[[Module, Union[Tensor, typing.Tuple[Tensor, ...]]], Union[None, Tensor, typing.Tuple[Tensor, ...]]]`
             self.register_full_backward_pre_hook(self.writeback_hook)
@@ -2003,6 +2011,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self,
         offsets: Tensor,
         batch_size_per_feature_per_rank: Optional[list[list[int]]],
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> invokers.lookup_args.VBEMetadata:
         # Blocking D2H copy, but only runs at first call
         self.feature_dims = self.feature_dims.cpu()
@@ -2025,6 +2035,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             self.pooling_mode,
             self.feature_dims,
             self.current_device,
+            vbe_output,
+            vbe_output_offsets,
         )

     @torch.jit.ignore
@@ -2033,40 +2045,17 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         # This allows models using this class to compile correctly
         return FeatureGate.is_enabled(feature)

-    def writeback_update_gradient(
-        self, indices: torch.Tensor, offsets: torch.Tensor, grad: Tensor
-    ) -> Tensor:
-        if indices.numel() == 0:
-            return grad[0]
-        num_of_tables = len(set(self.feature_table_map))
-        assert num_of_tables * indices.max() < torch.iinfo(indices.dtype).max
-        batch_size = offsets.shape[0] // num_of_tables
-        max_indices = indices.max()
-        non_empty_index = (offsets[1:] - offsets[:-1]).nonzero().flatten()
-        # disable dedup across different table
-        indices = ((offsets[non_empty_index]) // batch_size) * (
-            1 + max_indices
-        ) + indices
-        grad = grad[0]
-        _, idx, counts = torch.unique(
-            indices, dim=0, sorted=True, return_inverse=True, return_counts=True
-        )
-        _, ind_sorted = torch.sort(idx, stable=True)
-        cum_sum = counts.cumsum(0)
-        cum_sum = torch.cat((torch.tensor([0]).to(indices.device), cum_sum[:-1]))
-        first_indicies = ind_sorted[cum_sum]
-        mask = torch.zeros_like(grad, device=grad.device)
-        original_index = non_empty_index[first_indicies]
-
-        mask[original_index] = grad[original_index]
-        return mask
-
     # pyre-fixme[2]: For 1st argument expected not ANY
     def writeback_hook(self, module: Any, grad: Tensor) -> tuple[Tensor]:
         indices = self._indices
         offsets = self._offsets
-
-
+        return writeback_gradient(
+            grad,
+            indices,
+            offsets,
+            self.feature_table_map,
+            self._writeback_first_feature_only,
+        )

     def forward(  # noqa: C901
         self,
@@ -2078,6 +2067,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         total_unique_indices: Optional[int] = None,
         hash_zch_identities: Optional[Tensor] = None,
         hash_zch_runtime_meta: Optional[Tensor] = None,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> Tensor:
         """
         The forward pass function that
@@ -2130,13 +2121,22 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 be set when using `OptimType.NONE`. This is because TBE
                 requires this information for allocating the weight gradient
                 tensor in the backward pass.
-
             hash_zch_identities (Optional[Tensor]): The original raw IDs before
                 remapping to ZCH (Zero-Collision Hash) table slots. This tensor is
                 populated when using Multi-Probe Zero Collision Hash (MPZCH) modules
                 and is required for Raw Embedding Streaming (RES) to maintain
                 consistency between training and inference.
-
+            vbe_output (Optional[Tensor]): An optional 2-D tensor of size that
+                contains output for TBE VBE. The shape of the tensor is
+                [1, total_vbe_output_size] where total_vbe_output_size is the
+                output size across all ranks and all embedding tables.
+                If this tensor is not None, the TBE VBE forward output is written
+                to this tensor at the locations specified by `vbe_output_offsets`.
+            vbe_output_offsets (Optional[Tensor]): An optional 2-D tensor that
+                contains VBE output offsets to `vbe_output`. The shape of the
+                tensor is [num_ranks, num_features].
+                vbe_output_offsets[r][f] represents the starting offset for rank `r`
+                and feature `f`.
         Returns:
             A 2D-tensor containing looked up data. Shape `(B, total_D)` where `B` =
             batch size and `total_D` = the sum of all embedding dimensions in the
@@ -2210,8 +2210,16 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             batch_size_per_feature_per_rank,
             force_cast_input_types=True,
             prefetch_pipeline=False,
+            vbe_output=vbe_output,
+            vbe_output_offsets=vbe_output_offsets,
         )

+        # Only enable VBE if batch_size_per_feature_per_rank is not None
+        assert not (
+            batch_size_per_feature_per_rank is not None
+            and self.use_writeback_bwd_prehook
+        ), "VBE is not supported with writeback_bwd_prehook"
+
         # Print input stats if enable (for debugging purpose only)
         self._debug_print_input_stats(indices, offsets, per_sample_weights)

@@ -3875,6 +3883,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
         force_cast_input_types: bool = True,
         prefetch_pipeline: bool = False,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
         """
         Prepare TBE inputs as follows:
@@ -3901,9 +3911,20 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             metadata
         """

+        if vbe_output is not None or vbe_output_offsets is not None:
+            assert (
+                not self.use_cpu
+            ), "[TBE API v2] Using pre-allocated vbe_output is not supported on CPU"
+            check_allocated_vbe_output(
+                self.output_dtype,
+                batch_size_per_feature_per_rank,
+                vbe_output,
+                vbe_output_offsets,
+            )
+
         # Generate VBE metadata
         vbe_metadata = self._generate_vbe_metadata(
-            offsets, batch_size_per_feature_per_rank
+            offsets, batch_size_per_feature_per_rank, vbe_output, vbe_output_offsets
         )

         vbe = vbe_metadata.B_offsets is not None
@@ -3976,7 +3997,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 self.is_nobag,
                 vbe_metadata.max_B_feature_rank,
                 self.info_B_num_bits,
-                offsets.numel() - 1,  # total_B
+                offsets.numel() - 1,  # total_B,
+                vbe_output_offsets,
             )
         else:
             b_t_map = None
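For context, a hedged sketch of calling forward() with a pre-allocated merged VBE output. The module `tbe`, the `indices`/`offsets` tensors and all sizes below are placeholders; the flat 1-D buffer and the [num_ranks, num_features] offset grid follow the check_allocated_vbe_output validation in the next file's diff, and the buffer dtype must match the module's output_dtype (FP32 assumed here).

import torch

# Placeholders: `tbe` is an existing SplitTableBatchedEmbeddingBagsCodegen on GPU,
# `indices`/`offsets` are its usual inputs; the numbers are illustrative only.
batch_size_per_feature_per_rank = [[2, 3], [2, 3]]   # 2 features x 2 ranks
total_vbe_output_size = 1024                          # placeholder: sum of B * D over ranks/features
vbe_output = torch.empty(total_vbe_output_size, dtype=torch.float32, device="cuda")
vbe_output_offsets = torch.tensor([[0, 128], [512, 640]], device="cuda")  # [num_ranks, num_features]

pooled = tbe(
    indices,
    offsets,
    batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
    vbe_output=vbe_output,
    vbe_output_offsets=vbe_output_offsets,
)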
fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py
CHANGED
@@ -7,7 +7,7 @@

 # pyre-unsafe

-from typing import Optional
+from typing import List, Optional

 import torch
 from torch import Tensor
@@ -31,6 +31,7 @@ except Exception:

 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
+from fbgemm_gpu.split_embedding_configs import sparse_type_int_to_dtype
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode


@@ -40,6 +41,8 @@ def generate_vbe_metadata(
     pooling_mode: PoolingMode,
     feature_dims_cpu: Tensor,
     device: torch.device,
+    vbe_output: Optional[Tensor] = None,
+    vbe_output_offsets: Optional[Tensor] = None,
 ) -> invokers.lookup_args.VBEMetadata:
     """
     Generate VBE metadata based on batch_size_per_feature_per_rank.
@@ -133,6 +136,8 @@ def generate_vbe_metadata(
             max_B_feature_rank=max_B_feature_rank,
             # pyre-ignore
             output_size=output_size,
+            vbe_output=vbe_output,
+            vbe_output_offsets=vbe_output_offsets,
         )
     else:
         vbe_metadata = invokers.lookup_args.VBEMetadata(
@@ -142,5 +147,43 @@ def generate_vbe_metadata(
             max_B=-1,
             max_B_feature_rank=-1,
             output_size=-1,
+            vbe_output=None,
+            vbe_output_offsets=None,
         )
     return vbe_metadata
+
+
+def check_allocated_vbe_output(
+    output_dtype: int,
+    batch_size_per_feature_per_rank: Optional[List[List[int]]],
+    vbe_output: Optional[Tensor] = None,
+    vbe_output_offsets: Optional[Tensor] = None,
+) -> None:
+    assert (
+        batch_size_per_feature_per_rank is not None
+    ), "[Merged_VBE] vbe_output is passed, batch_size_per_feature_per_rank cannot be None"
+    assert (
+        vbe_output is not None
+    ), "[Merged_VBE] vbe_output_offsets is not None, vbe_output cannot be None"
+    assert (
+        vbe_output_offsets is not None
+    ), "[Merged_VBE] vbe_output is not None, vbe_output_offsets cannot be None"
+    num_features = len(batch_size_per_feature_per_rank)
+    num_ranks = len(batch_size_per_feature_per_rank[0])
+    assert vbe_output_offsets.shape == torch.Size(
+        [num_ranks, num_features]
+    ), f"[Merged_VBE] Mismatched vbe_output_offsets shape. batch_size_per_feature_per_rank={batch_size_per_feature_per_rank}. Expected: {torch.Size([num_ranks, num_features])}, Actual: {vbe_output_offsets.shape}"
+    assert (
+        vbe_output.dim() == 1
+    ), f"[Merged_VBE] vbe_output must have 1 dimension, but got {vbe_output.dim()}. vbe_output shape is {vbe_output.shape}"
+    assert (
+        vbe_output_offsets.device == vbe_output.device
+    ), "[Merged_VBE] vbe_output_offsets and vbe_output must be on the same device"
+    _output_dtype = sparse_type_int_to_dtype(output_dtype)
+    assert (
+        vbe_output.dtype == _output_dtype
+    ), f"[Merged_VBE] vbe_output dtype must match TBE output dtype {_output_dtype} (SparseType {output_dtype}), but got {vbe_output.dtype}"
+    assert (
+        vbe_output_offsets.is_contiguous()
+    ), "[Merged_VBE] vbe_output_offsets needs to be contiguous"
+    assert vbe_output.is_contiguous(), "[Merged_VBE] vbe_output needs to be contiguous"
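For context, one way the [num_ranks, num_features] offset grid could be produced on the caller's side; the rank-major, element-unit layout here is an assumption made for illustration, not something this diff prescribes, and build_vbe_output_offsets is a hypothetical helper name.

import torch

def build_vbe_output_offsets(
    batch_size_per_feature_per_rank: list[list[int]],  # [num_features][num_ranks]
    embedding_dims: list[int],                         # output dim D per feature
) -> tuple[torch.Tensor, int]:
    num_features = len(batch_size_per_feature_per_rank)
    num_ranks = len(batch_size_per_feature_per_rank[0])
    offsets = torch.zeros(num_ranks, num_features, dtype=torch.int64)
    running = 0
    for r in range(num_ranks):
        for f in range(num_features):
            offsets[r][f] = running
            running += batch_size_per_feature_per_rank[f][r] * embedding_dims[f]
    # `running` is the total element count the flat vbe_output buffer must hold
    return offsets, running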
fbgemm_gpu/tbe/ssd/training.py
CHANGED
@@ -50,6 +50,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
     WeightDecayMode,
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
+    check_allocated_vbe_output,
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
@@ -2308,6 +2309,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self,
         offsets: Tensor,
         batch_size_per_feature_per_rank: Optional[list[list[int]]],
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> invokers.lookup_args.VBEMetadata:
         # Blocking D2H copy, but only runs at first call
         self.feature_dims = self.feature_dims.cpu()
@@ -2326,6 +2329,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             self.pooling_mode,
             self.feature_dims,
             self.current_device,
+            vbe_output,
+            vbe_output_offsets,
         )

     def _increment_iteration(self) -> int:
@@ -2356,11 +2361,26 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         per_sample_weights: Optional[Tensor] = None,
         feature_requires_grad: Optional[Tensor] = None,
         batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
         # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
     ) -> Tensor:
         self.clear_cache()
+        if vbe_output is not None or vbe_output_offsets is not None:
+            # CPU is not supported in SSD TBE
+            check_allocated_vbe_output(
+                self.output_dtype,
+                batch_size_per_feature_per_rank,
+                vbe_output,
+                vbe_output_offsets,
+            )
         indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
-            indices,
+            indices,
+            offsets,
+            per_sample_weights,
+            batch_size_per_feature_per_rank,
+            vbe_output=vbe_output,
+            vbe_output_offsets=vbe_output_offsets,
         )

         if len(self.timesteps_prefetched) == 0:
@@ -3691,13 +3711,15 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         offsets: Tensor,
         per_sample_weights: Optional[Tensor] = None,
         batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
         """
         Prepare TBE inputs
         """
         # Generate VBE metadata
         vbe_metadata = self._generate_vbe_metadata(
-            offsets, batch_size_per_feature_per_rank
+            offsets, batch_size_per_feature_per_rank, vbe_output, vbe_output_offsets
         )

         # Force casting indices and offsets to long
fbgemm_gpu/triton/quantize.py
CHANGED
@@ -575,7 +575,7 @@ def _kernel_dequantize_mx4(
     # Write final outputs.
     tl.store(
         out + output_offset,
-        scaled_fp32,
+        scaled_fp32.to(out.dtype.element_ty),
         # Mask values that are out of this chunk or the main array.
         mask=(output_offset < OUTPUT_SIZE)
         & (output_offset < OUTPUT_CHUNK_SIZE * (pid + 1)),
@@ -588,10 +588,14 @@ def _kernel_dequantize_mx4(


 def triton_dequantize_mx4(
-    a: torch.Tensor,
+    a: torch.Tensor,
+    group_size: int = 32,
+    ebits: int = 2,
+    mbits: int = 1,
+    output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     """
-    Dequantize a tensor from mx4 format to fp32.
+    Dequantize a tensor from mx4 format to fp32 or bf16.

     Args:
         a (Tensor): [M / 2 + M / group_size] MX4 tensor packed into int8 values
@@ -599,13 +603,15 @@ def triton_dequantize_mx4(
         group_size (int): Size of chunks that use the same shared exponent.
         ebits (int): Number of bits to use for exponent in target mx4 format.
         mbits (int): Number of bits to use for mantissa in target mx4 format.
+        output_dtype (torch.dtype): Output dtype (FP32 or BF16).
+            Defaults to torch.float32 for backward compatibility.

     Returns:
-        torch.Tensor: [M, K] dequantized
+        torch.Tensor: [M, K] dequantized tensor in the specified dtype.
     """
     # If given an empty shape, return an empty tensor.
     if a.numel() == 0:
-        return torch.empty(a.shape, device=a.device, dtype=
+        return torch.empty(a.shape, device=a.device, dtype=output_dtype)
     # View a as 2D for simplicity.
     orig_shape = a.shape
     a = a.flatten()
@@ -622,9 +628,9 @@ def triton_dequantize_mx4(
     # Use a lookup table to convert
     mx4_to_fp_values = get_mx4_lookup_table(ebits, mbits, a.device)

-    # Create output tensor.
+    # Create output tensor in target dtype.
     output_elems = num_groups * group_size
-    out = torch.empty([output_elems], device=a.device, dtype=
+    out = torch.empty([output_elems], device=a.device, dtype=output_dtype)
     # Check if we need to use int64 for indexing.
     use_int64 = num_threads * groups_per_thread * group_size > 2**31 - 1
     # Invoke triton dequantization kernel over rows.
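For context, a hedged usage sketch of the extended signature, assuming a CUDA device and that quantize_mx4 (imported in quantize_utils above) packs with the same group_size:

import torch
from fbgemm_gpu.triton import quantize_mx4
from fbgemm_gpu.triton.quantize import triton_dequantize_mx4

x = torch.randn(64, device="cuda")
packed = quantize_mx4(x, group_size=32)
y = triton_dequantize_mx4(packed, group_size=32, output_dtype=torch.bfloat16)
assert y.dtype == torch.bfloat16 and y.numel() == x.numel()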
fbgemm_gpu/utils/writeback_util.py
ADDED
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+def writeback_update_gradient(
+    indices: torch.Tensor,
+    offsets: torch.Tensor,
+    grad: torch.Tensor,
+    feature_table_map: list[int],
+) -> torch.Tensor:
+    """
+    Update gradient tensor by deduplicating indices across all features/tables.
+    For duplicate indices, only the first occurrence receives the gradient to achieve the assign purpose via gradient update
+
+    NOTE: This function is not supporting VBE yet
+
+    Args:
+        indices (torch.Tensor): Embedding indices tensor
+        offsets (torch.Tensor): Offsets tensor for batched embeddings
+        grad (torch.Tensor): Gradient tensor to be updated
+        feature_table_map (list[int]): Mapping from feature to table
+
+    Returns:
+        torch.Tensor: Updated gradient tensor with duplicates masked out
+    """
+    if indices.numel() == 0:
+        return grad[0]
+    # get num of feature to estimate batch size
+    num_of_tables = len(feature_table_map)
+    assert num_of_tables * indices.max() < torch.iinfo(indices.dtype).max
+    batch_size = offsets.shape[0] // num_of_tables
+    max_indices = indices.max()
+    non_empty_index = (offsets[1:] - offsets[:-1]).nonzero().flatten()
+    # disable dedup across different table
+    indices = ((offsets[non_empty_index]) // batch_size) * (1 + max_indices) + indices
+    grad = grad[0]
+    _, idx, counts = torch.unique(
+        indices, dim=0, sorted=True, return_inverse=True, return_counts=True
+    )
+    _, ind_sorted = torch.sort(idx, stable=True)
+    cum_sum = counts.cumsum(0)
+    cum_sum = torch.cat((torch.tensor([0]).to(indices.device), cum_sum[:-1]))
+    first_indicies = ind_sorted[cum_sum]
+    mask = torch.zeros_like(grad, device=grad.device)
+    original_index = non_empty_index[first_indicies]
+
+    mask[original_index] = grad[original_index]
+    return mask
+
+
+def writeback_update_gradient_first_feature_only(
+    indices: torch.Tensor,
+    offsets: torch.Tensor,
+    grad: torch.Tensor,
+    feature_table_map: list[int],
+) -> torch.Tensor:
+    """
+    Special case of writeback_update_gradient where gradient only needs to be updated for the first feature. Other features will be forward-only
+
+    NOTE: This function is not supporting VBE yet
+
+    Args:
+        indices (torch.Tensor): Embedding indices tensor
+        offsets (torch.Tensor): Offsets tensor for batched embeddings
+        grad (torch.Tensor): Gradient tensor to be updated
+        feature_table_map (list[int]): Mapping from feature to table
+
+    Returns:
+        torch.Tensor: Updated gradient tensor with duplicates masked out
+    """
+    num_of_tables = len(feature_table_map)
+    batch_size = (offsets.shape[0] - 1) // num_of_tables
+    shrink_indices = indices[: offsets[batch_size]]
+    if shrink_indices.numel() == 0 or indices.numel() == 0:
+        return grad[0]
+    assert num_of_tables * indices.max() < torch.iinfo(indices.dtype).max
+
+    grad = grad[0]
+    _, idx, counts = torch.unique(
+        shrink_indices, dim=0, sorted=True, return_inverse=True, return_counts=True
+    )
+    _, ind_sorted = torch.sort(idx, stable=True)
+    cum_sum = counts.cumsum(0)
+    cum_sum = torch.cat((torch.tensor([0]).to(shrink_indices.device), cum_sum[:-1]))
+    first_indicies = ind_sorted[cum_sum]
+    mask = torch.zeros_like(grad, device=grad.device)
+
+    mask[first_indicies] = grad[first_indicies]
+    return mask
+
+
+def writeback_gradient(
+    grad: torch.Tensor,
+    indices: torch.Tensor,
+    offsets: torch.Tensor,
+    feature_table_map: list[int],
+    writeback_first_feature_only: bool = False,
+) -> tuple[torch.Tensor]:
+    """
+    Compute deduplicated gradient for writeback operation.
+
+    Args:
+        grad (torch.Tensor): Gradient tensor to be updated
+        indices (torch.Tensor): Embedding indices tensor
+        offsets (torch.Tensor): Offsets tensor for batched embeddings
+        feature_table_map (list[int]): Mapping from feature to table
+        writeback_first_feature_only (bool): If True, only first feature will apply gradient update, other features will be read-only
+
+    Returns:
+        tuple[torch.Tensor]: Tuple containing the updated gradient tensor
+    """
+    if writeback_first_feature_only:
+        return (
+            writeback_update_gradient_first_feature_only(
+                indices, offsets, grad, feature_table_map
+            ),
+        )
+    else:
+        return (writeback_update_gradient(indices, offsets, grad, feature_table_map),)
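For context, the dedup step both helpers rely on, isolated into a tiny self-contained example: torch.unique's inverse mapping plus a stable sort recovers the first occurrence of each duplicated index, and only those positions keep their gradient.

import torch

indices = torch.tensor([7, 3, 7, 5, 3])
_, inverse, counts = torch.unique(
    indices, sorted=True, return_inverse=True, return_counts=True
)
_, ind_sorted = torch.sort(inverse, stable=True)
cum_sum = torch.cat((torch.tensor([0]), counts.cumsum(0)[:-1]))
first_occurrence = ind_sorted[cum_sum]
# first_occurrence == tensor([1, 3, 0]): the first 3 (pos 1), first 5 (pos 3), first 7 (pos 0)

grad = torch.ones(5, 4)
mask = torch.zeros_like(grad)
mask[first_occurrence] = grad[first_occurrence]  # duplicate occurrences contribute nothing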
{fbgemm_gpu_genai_nightly-2025.12.17.dist-info → fbgemm_gpu_genai_nightly-2026.1.9.dist-info}/RECORD
RENAMED
@@ -1,29 +1,29 @@
 fbgemm_gpu/__init__.py,sha256=bL2dL7uYeXb1GvdjIDUTcLXLRGNfmnI4MQoE3-Gg5m8,6361
-fbgemm_gpu/asmjit.so,sha256=
+fbgemm_gpu/asmjit.so,sha256=UxnhHlu9LgmoRXa8fZwSX56b5QKffBxfAOs0AZLxRfk,501728
 fbgemm_gpu/batched_unary_embeddings_ops.py,sha256=GYeJ9pg-Wc9FokXVci_npDsL6UV18-pJXID2xzrJ9O8,2904
 fbgemm_gpu/enums.py,sha256=37ewGSfO1x7sO31ZkRiqV1yKuklfHXT5qZIxzeeGogo,755
-fbgemm_gpu/fbgemm.so,sha256=
+fbgemm_gpu/fbgemm.so,sha256=HQUXhk9ikCtui6125NmRjOJvAMaWRgZXLlBcQHsh2Xo,5646712
 fbgemm_gpu/metrics.py,sha256=TsurFLJf0nJvPDN7urWb4LMQlf5RgdWPTTTDO7S4wtI,5663
 fbgemm_gpu/permute_pooled_embedding_modules.py,sha256=vOXMYclaGnwSt0St_SOAlAe18kz6WjMyTeHnC9jLhcE,5130
 fbgemm_gpu/permute_pooled_embedding_modules_split.py,sha256=f3VJvH_kw9Ltd_DXtaf_PJPHmlmEWrQgzQ7MDkhh5Nw,2746
-fbgemm_gpu/quantize_comm.py,sha256=
-fbgemm_gpu/quantize_utils.py,sha256=
+fbgemm_gpu/quantize_comm.py,sha256=j4-wBqWRtXjhtQBKi7IOAftNDzv8-AeX9YXlD8e682c,11983
+fbgemm_gpu/quantize_utils.py,sha256=fK4Dk9Qpjsu4qASCwAxkLjbiFRLI71Hd-AtHY4NyMZ8,10200
 fbgemm_gpu/runtime_monitor.py,sha256=YXRUv6nXCsoTgh5_RzailTGvCYzwoYDb-eR4rlGwtaw,7619
 fbgemm_gpu/sparse_ops.py,sha256=_EJC1pAbNnAnVQQ5JBg4DAV2TboIj-4XQkiKMmg1vXI,50417
-fbgemm_gpu/split_embedding_configs.py,sha256=
+fbgemm_gpu/split_embedding_configs.py,sha256=EuVFKIDrgRQpRC5mmB4Du6WftK5GXJvDue9_ezt_eBI,16575
 fbgemm_gpu/split_embedding_inference_converter.py,sha256=AghGW22MgMsdHzdwdPMPYDjgas5AE_estckY8rMgXVU,7056
 fbgemm_gpu/split_embedding_optimizer_ops.py,sha256=wXuGazClBMk62yL_r9udUIKaPgQP7SlkSb5ugB75wrQ,711
 fbgemm_gpu/split_embedding_utils.py,sha256=Gb40ZKeATxIKEKI3aVQMgDDBanNpKMc53Z43mnzdR_I,851
 fbgemm_gpu/split_table_batched_embeddings_ops.py,sha256=_MIp6uHYHLn4GxGdrGsfddfSsZ2Z9mjsYIrih3ncI1I,2339
 fbgemm_gpu/split_table_batched_embeddings_ops_common.py,sha256=eFxb_bDfBV8G76pmd-SxDXXXnqgbuGYOS4pSU8JS5dg,19295
 fbgemm_gpu/split_table_batched_embeddings_ops_inference.py,sha256=dGC85xjQiRUrequBibSf9oMAVHT5Q49zsVo2zW4n_88,81679
-fbgemm_gpu/split_table_batched_embeddings_ops_training.py,sha256=
-fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py,sha256=
+fbgemm_gpu/split_table_batched_embeddings_ops_training.py,sha256=rNGMELM_xFIsdS_340PB7bsn9h_VjONq_JJG1SjHyvQ,188992
+fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py,sha256=jofAN2UB_iSk53Id6MBvn9Bi3Qxw67IL0_VE_EHlw_Q,7593
 fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py,sha256=7qGkO8FARku38mFYl4Bc4qL8dS1wrfyorS9l1m5ZAVA,718
 fbgemm_gpu/tbe_input_multiplexer.py,sha256=TQjwkJ2JkOaQsMYuRdk9RbNa9759EPEtx8bYclChtZY,3063
 fbgemm_gpu/uvm.py,sha256=guNK8ZzR80jmv-CyRgEhxhVYhjz3R9d6tB8Hu1uWDUo,1047
 fbgemm_gpu/config/__init__.py,sha256=yN0KAneCICgF2BTfOYGsd0qU1PvZX_6msC6YHHZKLMg,292
-fbgemm_gpu/config/feature_list.py,sha256=
+fbgemm_gpu/config/feature_list.py,sha256=hhDNkkafd-Oetvuqv9ylBVTNM-lKPi029mpRqq-JZCA,2467
 fbgemm_gpu/docs/__init__.py,sha256=DR6hMSQrsZALfH2AnuJQ4Zq2CfBUUhMN8YjD6APjiAE,523
 fbgemm_gpu/docs/common.py,sha256=8ipXTwVb222X-aZ71O6n8fhxHCHPNhJEHMFiO7epcIs,273
 fbgemm_gpu/docs/examples.py,sha256=ZMN_6sL74LH_hrp2bF_hmg8gi29GhcgvwV3kCMjxkoE,2377
@@ -32,9 +32,9 @@ fbgemm_gpu/docs/merge_pooled_embedding_ops.py,sha256=oJLgSgZQmhsyGLbTmZTxNgQrk65
 fbgemm_gpu/docs/permute_pooled_embedding_ops.py,sha256=tZUqLVXlk5O6VAKKDA-OEMx2fCu5QPOOeoAPZA9_nLY,4454
 fbgemm_gpu/docs/quantize_ops.py,sha256=xTtOaVK1P02ymreE_i21YiyYDZCqhoZY9eWp_mEIRlo,1297
 fbgemm_gpu/docs/sparse_ops.py,sha256=gSLUFdnu8lle_6gLewFkM20wL3ek2jKLvDGMKR6POaY,27292
-fbgemm_gpu/docs/target.genai.json.py,sha256=
+fbgemm_gpu/docs/target.genai.json.py,sha256=TVO8vYaBQPaEdT-bYeXlTdOGiTw4ceWeAWa9m9Wnerg,77
 fbgemm_gpu/experimental/example/__init__.py,sha256=OvJHZgWnycL1gWKyCXFJCTKuys3KAqx4iadjx3R-tBQ,723
-fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so,sha256=
+fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so,sha256=9NH_0L5RRD5NwIOILKiOGjKo25isXYYkmwFvbIlUGe0,190656
 fbgemm_gpu/experimental/example/utils.py,sha256=Je__VkMlBMLOhh7NXOocOdvaa2gz9kl9Dkqeu25tpFA,562
 fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py,sha256=1CqUfzlYyXTvU-BNaUq4RZpLV-2lKAVCAHeJzSIZFWw,419
 fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py,sha256=R4VNZdPSgmRmwDfTt2CShED2SGUF6dCXSUW2C4LISgE,215713
@@ -43,11 +43,11 @@ fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py,sha256=5ClZ-GDrx6q0uaqW
 fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py,sha256=SltbY_dsit5e7B8lDIB_VYPrEq0t9kckthj9mQaVNfA,7571
 fbgemm_gpu/experimental/gemm/triton_gemm/utils.py,sha256=rULXIpVaaRS3GKUZ1RHcWUrUyy0xMVREwS1SFShGgcw,4302
 fbgemm_gpu/experimental/gen_ai/__init__.py,sha256=r3NlNCXuIh0pfKwKU5v14y6AZkpoIkKWbtzxSprgeKA,1713
-fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so,sha256=
+fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so,sha256=_1HEs4E69AIPZ7zauPqcHB01egv5eNf2IKNZuihMOAA,65230568
 fbgemm_gpu/experimental/gen_ai/quantize.py,sha256=KAljWSdN-1_c5DWfT-3MDxWLMULK49Yu36t6TmQI9Tw,12599
 fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py,sha256=-R_LxyHpdXMILU9TNuYoRisBCkfK0_VLyixefaeZf4g,1463
 fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py,sha256=gbhNU3mDTKJb3yt3inIDbiUjX_SG1oZfzgDygtHvMpk,10101
-fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py,sha256=
+fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py,sha256=fD39_WH7TfNCiP5Vl46ToX6PsLMLUFLhizT26Qe7TWg,17282
 fbgemm_gpu/experimental/gen_ai/bench/__init__.py,sha256=XpAK_eyqDSKeFC5J9KpnKtbZG07mrDh9d2j1LFKzr-8,404
 fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py,sha256=ApEyJOf_rdIo8V_EgvhZXBGNov8ITC_dnB95v8szulI,8515
 fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py,sha256=K9Nib6D7xJbw1QwEVuCJrVyI1qs988moo3cieVKYuFY,12057
@@ -99,7 +99,7 @@ fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py,sha256=vZHj7KIe1DoJDy5eft29Xt
 fbgemm_gpu/tbe/ssd/__init__.py,sha256=wzfMT10cp_dqK2lrebC449hOdexBnizcf_98lA1NyHs,483
 fbgemm_gpu/tbe/ssd/common.py,sha256=1J8K7sTQswgCYWaVwF-ZdCJj7mNN6O9GI70AaZWzJGE,1044
 fbgemm_gpu/tbe/ssd/inference.py,sha256=B_uX66ajGA9YKGlFa5TmGWs7b-b1RFigzwxmENZ9Oio,22816
-fbgemm_gpu/tbe/ssd/training.py,sha256=
+fbgemm_gpu/tbe/ssd/training.py,sha256=C6M3H_f8oWWRkC4R-BJED73au-Gl9SUVllxOoFSiDkI,212234
 fbgemm_gpu/tbe/ssd/utils/__init__.py,sha256=5DgmR2HA6NtmYh2ddkUgpDsZ6a7hF0DPedA1gMpdh18,250
 fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py,sha256=SFg2-29b-i49LWm-FlaWUkTz2XzXbicYi_AzVj4jKNE,7601
 fbgemm_gpu/tbe/stats/__init__.py,sha256=on29iDtq7cVNh90JR9aeFNG-K9DDoYq0JryzoplL49I,322
@@ -111,7 +111,7 @@ fbgemm_gpu/tbe/utils/quantize.py,sha256=icN2MXnl5rNqtKhGKkjpelx5pYBMYUv-6CrghxeV
 fbgemm_gpu/tbe/utils/requests.py,sha256=rQkEoaUUWEYCQM-1K_Lxg1wPcyIVw8sbdaGFTpsaE5I,18040
 fbgemm_gpu/triton/__init__.py,sha256=kPn_Ye6J9DAzWtqi76KYGwfKSqw0IhqG3Bir5aUpkWM,658
 fbgemm_gpu/triton/common.py,sha256=wnkLd2a8fKpefymLL-LjNKEL4hDVSxFiF5g3aF8mzsw,2131
-fbgemm_gpu/triton/quantize.py,sha256=
+fbgemm_gpu/triton/quantize.py,sha256=bjMPgcUOcuG_d9I_EjSCpkU3Fr5f3FU7CIWIqzc_N3w,27074
 fbgemm_gpu/triton/quantize_ref.py,sha256=q4RBmFaqPVPELU52lbSgB0n26Aun7apeK7bRF2MWS80,11553
 fbgemm_gpu/triton/jagged/__init__.py,sha256=om0yhjuzKuE1UQakFMWHsXN4WNb8mvNkZtYofQ8hdn4,246
 fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py,sha256=F2eQWjkWMR5RWQ48oIr-8OU_CRZyLazDpT7DFrDWS6g,29871
@@ -119,9 +119,10 @@ fbgemm_gpu/utils/__init__.py,sha256=JQQNdcTTaEU6ptK-OW-ZQBwTFxEZZpWOtBXWwEZm39o,
 fbgemm_gpu/utils/filestore.py,sha256=oVtbKGaPQki1JgbJCkrkElukOFVyxntQpSC0lYBKgho,6455
 fbgemm_gpu/utils/loader.py,sha256=1hCEhNvkflniH46fGcrguLeP1z-6uyOu2QFwqKU5CIM,990
 fbgemm_gpu/utils/torch_library.py,sha256=ywsAHjbuwesj50LjEu99WkAH17FlaVgePZ9OmFg6YE4,4193
+fbgemm_gpu/utils/writeback_util.py,sha256=PyVbHp1EuF-GKrJv_CTP6B50Z0oBblXKucf7Rhd6KKY,4614
 list_versions/__init__.py,sha256=UmTeqCk-UJWFtlZQWvZao3xvui2w9E3X_JdOXVjRaNw,315
 list_versions/cli_run.py,sha256=CChZoXQ-tiKaWboXAYlPVJ5w8K5zAKiKcncA087I1sc,4508
-fbgemm_gpu_genai_nightly-
-fbgemm_gpu_genai_nightly-
-fbgemm_gpu_genai_nightly-
-fbgemm_gpu_genai_nightly-
+fbgemm_gpu_genai_nightly-2026.1.9.dist-info/METADATA,sha256=WQ7sQGvWWGQao3Wcjk39i_bFY2BvrmWfGHxHgWxEsug,2655
+fbgemm_gpu_genai_nightly-2026.1.9.dist-info/WHEEL,sha256=Nkv8TSWVt7XcnRf1cdq5HOzycTl6Pjzlmn7gPSv4NiQ,108
+fbgemm_gpu_genai_nightly-2026.1.9.dist-info/top_level.txt,sha256=_2s1Aa08r_eDn0JP4FjOhzK09Q8bVlEI7q8pMep51UY,25
+fbgemm_gpu_genai_nightly-2026.1.9.dist-info/RECORD,,
{fbgemm_gpu_genai_nightly-2025.12.17.dist-info → fbgemm_gpu_genai_nightly-2026.1.9.dist-info}/WHEEL
RENAMED
File without changes

{fbgemm_gpu_genai_nightly-2025.12.17.dist-info → fbgemm_gpu_genai_nightly-2026.1.9.dist-info}/top_level.txt
RENAMED
File without changes