fbgemm-gpu-genai-nightly 2025.12.17-cp313-cp313-manylinux_2_28_x86_64.whl → 2026.1.4-cp313-cp313-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fbgemm_gpu/asmjit.so CHANGED
Binary file
fbgemm_gpu/config/feature_list.py CHANGED
@@ -63,6 +63,9 @@ class FeatureGateName(Enum):
     # Enable TBE input parameters extraction
     TBE_REPORT_INPUT_PARAMS = auto()
 
+    # Enable tuned max segment length per CTA for B200
+    TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200 = auto()
+
     def is_enabled(self) -> bool:
         return FeatureGate.is_enabled(self)
 
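Note: the new gate follows the same runtime check pattern as the existing entries. A minimal usage sketch, assuming fbgemm_gpu.config re-exports FeatureGateName (as in recent FBGEMM releases); the call site is illustrative only:

import fbgemm_gpu  # assumed available; loads the FBGEMM operators
from fbgemm_gpu.config import FeatureGateName

# Hypothetical call site: TBE would switch to the tuned per-CTA max segment
# length on B200 when this gate is enabled.
if FeatureGateName.TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200.is_enabled():
    pass  # tuned path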
fbgemm_gpu/docs/target.genai.json.py CHANGED
@@ -1,6 +1,6 @@
 
 {
-    "version": "2025.12.17",
+    "version": "2026.1.4",
     "target": "genai",
     "variant": "cuda"
 }
fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py CHANGED
@@ -289,7 +289,7 @@ def cutlass_blackwell_fmha_decode_forward(
     window_left: int = -1,
     window_right: int = -1,
     bottom_right: bool = True,
-    split_k_size: int = 1024,
+    split_k_size: int = 0,
     use_heuristic: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
@@ -318,14 +318,9 @@ def cutlass_blackwell_fmha_decode_forward(
             split size using the heuristic. Default is True.
 
     Returns:
-        Conditional return based on split-K mode:
-        - Non-split case (split_k_size <= 0 and use_heuristic=False):
-            out: Same shape as input q ([B, H, D] for varlen or [B, 1, H, D] for batch)
-                with bfloat16 dtype
-            lse: [B, H, 1] (always float32)
-        - Split case (split_k_size > 0 or use_heuristic=True):
-            out: [B, H, num_splits, D] with float32 dtype (partial outputs for later reduction)
-            lse: [B, num_splits, H] (always float32)
+        Kernel output with Q dimension added:
+        - out: [B, 1, H, num_splits, D] (num_splits=1 when split-K disabled)
+        - lse: [B, num_splits, H, 1]
     """
     _validate_decode_inputs(q, k, v, seqlen_kv)
 
@@ -365,15 +360,12 @@ def cutlass_blackwell_fmha_decode_forward(
         split_k_size=split_k_size,
     )
 
-    # Handle output based on split-K mode
-    is_split = split_k_size > 0
-
-    if not is_split:
-        # out shape: [B, H, Splits = 1, D] -> original shape
-        out = out.view(*original_shape)
-        # lse shape: [B, Splits = 1, H] -> [B, H, 1]
-        lse = lse.view(batch_size, -1, 1)
-
+    # Kernel returns: out [B, H, num_splits, D], lse [B, num_splits, H]
+    # Reshape to consistent format with Q dimension:
+    #   out: [B, H, num_splits, D] -> [B, 1, H, num_splits, D]
+    #   lse: [B, num_splits, H]    -> [B, num_splits, H, 1]
+    out = out.unsqueeze(1)  # [B, 1, H, num_splits, D]
+    lse = lse.unsqueeze(-1)  # [B, num_splits, H, 1]
     return out, lse
 
 
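Note: with this change the decode forward always returns split-K-shaped tensors (num_splits=1 when splitting is disabled), so callers no longer branch on the mode. Below is a minimal sketch of how a caller could combine the partials with a standard log-sum-exp merge; it illustrates the math only and is not necessarily the library's fused reduction:

import torch

def merge_splits(out: torch.Tensor, lse: torch.Tensor):
    # out: [B, 1, H, S, D] float32 partial outputs; lse: [B, S, H, 1] float32
    lse_bhs = lse.squeeze(-1).permute(0, 2, 1)          # [B, H, S]
    lse_max = lse_bhs.max(dim=-1, keepdim=True).values  # [B, H, 1]
    lse_merged = lse_max + torch.exp(lse_bhs - lse_max).sum(-1, keepdim=True).log()
    weights = torch.exp(lse_bhs - lse_merged)           # per-split softmax weights
    out_merged = (out * weights.unsqueeze(1).unsqueeze(-1)).sum(dim=3)  # [B, 1, H, D]
    return out_merged, lse_merged                       # merged output and LSE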
fbgemm_gpu/fbgemm.so CHANGED
Binary file
fbgemm_gpu/split_embedding_configs.py CHANGED
@@ -313,6 +313,40 @@ def sparse_type_to_int(sparse_type: "SparseType") -> int:
     }[sparse_type.value]
 
 
+def sparse_type_int_to_dtype(ty: int) -> torch.dtype:
+    """
+    TorchScript-compatible function to convert a SparseType enum (as an integer) to torch.dtype.
+
+    This is a standalone function equivalent to SparseType.from_int(dtype_int).as_dtype() that works
+    with TorchScript. TorchScript does not support @staticmethod on Enum classes,
+    so this function provides a workaround.
+    """
+    if ty == 0:  # fp32
+        return torch.float32
+    elif ty == 1:  # fp16
+        return torch.float16
+    elif ty == 2:  # int8
+        return torch.uint8
+    elif ty == 3:  # int4
+        return torch.quint4x2
+    elif ty == 4:  # int2
+        return torch.quint2x4
+    elif ty == 5:  # bf16
+        return torch.bfloat16
+    elif ty == 6:  # fp8
+        return torch.uint8
+    elif ty == 7:  # mx4
+        return torch.uint8
+    elif ty == 9:
+        return (
+            torch.float8_e4m3fnuz
+            if torch.version.hip is not None
+            else torch.float8_e4m3fn
+        )
+    else:  # invalid or non-enumerated value
+        raise ValueError(f"Unsupported sparse type: {ty}")
+
+
 @enum.unique
 class SparseType(enum.Enum):
     FP32 = "fp32"
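Note: a small round-trip sketch pairing the new helper with the existing sparse_type_to_int shown above; the BF16-to-5 mapping follows the comments in the function body:

import torch
from fbgemm_gpu.split_embedding_configs import (
    SparseType,
    sparse_type_int_to_dtype,
    sparse_type_to_int,
)

# enum -> int code -> torch.dtype (the TorchScript-friendly direction)
ty = sparse_type_to_int(SparseType.BF16)  # 5, per the mapping above
assert sparse_type_int_to_dtype(ty) == torch.bfloat16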
fbgemm_gpu/split_table_batched_embeddings_ops_training.py CHANGED
@@ -49,6 +49,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     SplitState,
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
+    check_allocated_vbe_output,
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
@@ -60,6 +61,7 @@ from fbgemm_gpu.tbe_input_multiplexer import (
 )
 
 from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
+from fbgemm_gpu.utils.writeback_util import writeback_gradient
 
 try:
     load_torch_module(
@@ -159,6 +161,7 @@ class UserEnabledConfigDefinition:
     # More details can be found in D64848802.
     use_rowwise_bias_correction: bool = False
     use_writeback_bwd_prehook: bool = False
+    writeback_first_feature_only: bool = False
 
 
 @dataclass(frozen=True)
@@ -1181,6 +1184,10 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self.use_writeback_bwd_prehook: bool = (
             extra_optimizer_config.use_writeback_bwd_prehook
         )
+
+        writeback_first_feature_only: bool = (
+            extra_optimizer_config.writeback_first_feature_only
+        )
         self.log(f"self.extra_optimizer_config is {extra_optimizer_config}")
         if self.use_rowwise_bias_correction and not self.optimizer == OptimType.ADAM:
             raise AssertionError(
@@ -1469,6 +1476,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         # self.log("TBE_V2 Knob is set to True; Using experimental TBE")
 
         self.is_experimental: bool = is_experimental
+        self._writeback_first_feature_only: bool = writeback_first_feature_only
 
         # Get a debug function pointer
         self._debug_print_input_stats: Callable[..., None] = (
@@ -1483,7 +1491,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
             # Register writeback hook for Exact_SGD optimizer
             self.log(
-                "SplitTableBatchedEmbeddingBagsCodegen: use_writeback_bwd_prehook is enabled."
+                f"SplitTableBatchedEmbeddingBagsCodegen: use_writeback_bwd_prehook is enabled with first feature only={self._writeback_first_feature_only}"
             )
             # pyre-fixme[6]: Expected `typing.Callable[[Module, Union[Tensor, typing.Tuple[Tensor, ...]]], Union[None, Tensor, typing.Tuple[Tensor, ...]]]`
             self.register_full_backward_pre_hook(self.writeback_hook)
@@ -2003,6 +2011,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         self,
         offsets: Tensor,
         batch_size_per_feature_per_rank: Optional[list[list[int]]],
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> invokers.lookup_args.VBEMetadata:
         # Blocking D2H copy, but only runs at first call
         self.feature_dims = self.feature_dims.cpu()
@@ -2025,6 +2035,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             self.pooling_mode,
             self.feature_dims,
             self.current_device,
+            vbe_output,
+            vbe_output_offsets,
         )
 
     @torch.jit.ignore
@@ -2033,40 +2045,17 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         # This allows models using this class to compile correctly
         return FeatureGate.is_enabled(feature)
 
-    def writeback_update_gradient(
-        self, indices: torch.Tensor, offsets: torch.Tensor, grad: Tensor
-    ) -> Tensor:
-        if indices.numel() == 0:
-            return grad[0]
-        num_of_tables = len(set(self.feature_table_map))
-        assert num_of_tables * indices.max() < torch.iinfo(indices.dtype).max
-        batch_size = offsets.shape[0] // num_of_tables
-        max_indices = indices.max()
-        non_empty_index = (offsets[1:] - offsets[:-1]).nonzero().flatten()
-        # disable dedup across different table
-        indices = ((offsets[non_empty_index]) // batch_size) * (
-            1 + max_indices
-        ) + indices
-        grad = grad[0]
-        _, idx, counts = torch.unique(
-            indices, dim=0, sorted=True, return_inverse=True, return_counts=True
-        )
-        _, ind_sorted = torch.sort(idx, stable=True)
-        cum_sum = counts.cumsum(0)
-        cum_sum = torch.cat((torch.tensor([0]).to(indices.device), cum_sum[:-1]))
-        first_indicies = ind_sorted[cum_sum]
-        mask = torch.zeros_like(grad, device=grad.device)
-        original_index = non_empty_index[first_indicies]
-
-        mask[original_index] = grad[original_index]
-        return mask
-
     # pyre-fixme[2]: For 1st argument expected not ANY
     def writeback_hook(self, module: Any, grad: Tensor) -> tuple[Tensor]:
         indices = self._indices
         offsets = self._offsets
-
-        return (self.writeback_update_gradient(indices, offsets, grad),)
+        return writeback_gradient(
+            grad,
+            indices,
+            offsets,
+            self.feature_table_map,
+            self._writeback_first_feature_only,
+        )
 
     def forward(  # noqa: C901
         self,
@@ -2078,6 +2067,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         total_unique_indices: Optional[int] = None,
         hash_zch_identities: Optional[Tensor] = None,
         hash_zch_runtime_meta: Optional[Tensor] = None,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> Tensor:
         """
         The forward pass function that
@@ -2130,13 +2121,22 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 be set when using `OptimType.NONE`. This is because TBE
                 requires this information for allocating the weight gradient
                 tensor in the backward pass.
-
             hash_zch_identities (Optional[Tensor]): The original raw IDs before
                 remapping to ZCH (Zero-Collision Hash) table slots. This tensor is
                 populated when using Multi-Probe Zero Collision Hash (MPZCH) modules
                 and is required for Raw Embedding Streaming (RES) to maintain
                 consistency between training and inference.
-
+            vbe_output (Optional[Tensor]): An optional 2-D tensor that contains
+                the output for TBE VBE. The shape of the tensor is
+                [1, total_vbe_output_size] where total_vbe_output_size is the
+                output size across all ranks and all embedding tables.
+                If this tensor is not None, the TBE VBE forward output is written
+                to this tensor at the locations specified by `vbe_output_offsets`.
+            vbe_output_offsets (Optional[Tensor]): An optional 2-D tensor that
+                contains VBE output offsets into `vbe_output`. The shape of the
+                tensor is [num_ranks, num_features].
+                vbe_output_offsets[r][f] represents the starting offset for rank `r`
+                and feature `f`.
         Returns:
             A 2D-tensor containing looked up data. Shape `(B, total_D)` where `B` =
             batch size and `total_D` = the sum of all embedding dimensions in the
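Note: as a rough illustration of the buffer contract above, a caller might size the merged buffer and compute offsets like this. This is a sketch under the assumption that pooled output rows are laid out rank-major then feature-major with B * D elements per (rank, feature) block; the authoritative layout is defined by the TBE kernels, and the buffer is validated later by check_allocated_vbe_output:

import torch

def allocate_vbe_output(
    batch_size_per_feature_per_rank,  # [num_features][num_ranks] batch sizes
    feature_dims,                     # embedding dim per feature
    dtype,                            # must match the TBE output dtype
    device,
):
    num_features = len(batch_size_per_feature_per_rank)
    num_ranks = len(batch_size_per_feature_per_rank[0])
    offsets = torch.empty(num_ranks, num_features, dtype=torch.long)
    total = 0
    for r in range(num_ranks):
        for f in range(num_features):
            offsets[r][f] = total  # starting offset for rank r, feature f
            total += batch_size_per_feature_per_rank[f][r] * feature_dims[f]
    # the checks below require a 1-D, contiguous buffer of the TBE output dtype
    return torch.zeros(total, dtype=dtype, device=device), offsets.to(device)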
@@ -2210,8 +2210,16 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             batch_size_per_feature_per_rank,
             force_cast_input_types=True,
             prefetch_pipeline=False,
+            vbe_output=vbe_output,
+            vbe_output_offsets=vbe_output_offsets,
         )
 
+        # VBE (batch_size_per_feature_per_rank is not None) is not supported
+        # together with the writeback backward pre-hook
+        assert not (
+            batch_size_per_feature_per_rank is not None
+            and self.use_writeback_bwd_prehook
+        ), "VBE is not supported with writeback_bwd_prehook"
+
         # Print input stats if enable (for debugging purpose only)
         self._debug_print_input_stats(indices, offsets, per_sample_weights)
 
@@ -3875,6 +3883,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
         force_cast_input_types: bool = True,
         prefetch_pipeline: bool = False,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
         """
         Prepare TBE inputs as follows:
@@ -3901,9 +3911,20 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             metadata
         """
 
+        if vbe_output is not None or vbe_output_offsets is not None:
+            assert (
+                not self.use_cpu
+            ), "[TBE API v2] Using pre-allocated vbe_output is not supported on CPU"
+            check_allocated_vbe_output(
+                self.output_dtype,
+                batch_size_per_feature_per_rank,
+                vbe_output,
+                vbe_output_offsets,
+            )
+
         # Generate VBE metadata
         vbe_metadata = self._generate_vbe_metadata(
-            offsets, batch_size_per_feature_per_rank
+            offsets, batch_size_per_feature_per_rank, vbe_output, vbe_output_offsets
         )
 
         vbe = vbe_metadata.B_offsets is not None
@@ -3976,7 +3997,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
                 self.is_nobag,
                 vbe_metadata.max_B_feature_rank,
                 self.info_B_num_bits,
-                offsets.numel() - 1,  # total_B
+                offsets.numel() - 1,  # total_B
+                vbe_output_offsets,
             )
         else:
             b_t_map = None
fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py CHANGED
@@ -7,7 +7,7 @@
 
 # pyre-unsafe
 
-from typing import Optional
+from typing import List, Optional
 
 import torch
 from torch import Tensor
@@ -31,6 +31,7 @@ except Exception:
 
 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
+from fbgemm_gpu.split_embedding_configs import sparse_type_int_to_dtype
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
 
 
@@ -40,6 +41,8 @@ def generate_vbe_metadata(
     pooling_mode: PoolingMode,
     feature_dims_cpu: Tensor,
     device: torch.device,
+    vbe_output: Optional[Tensor] = None,
+    vbe_output_offsets: Optional[Tensor] = None,
 ) -> invokers.lookup_args.VBEMetadata:
     """
     Generate VBE metadata based on batch_size_per_feature_per_rank.
@@ -133,6 +136,8 @@ def generate_vbe_metadata(
             max_B_feature_rank=max_B_feature_rank,
             # pyre-ignore
             output_size=output_size,
+            vbe_output=vbe_output,
+            vbe_output_offsets=vbe_output_offsets,
         )
     else:
         vbe_metadata = invokers.lookup_args.VBEMetadata(
@@ -142,5 +147,43 @@ def generate_vbe_metadata(
             max_B=-1,
             max_B_feature_rank=-1,
             output_size=-1,
+            vbe_output=None,
+            vbe_output_offsets=None,
         )
     return vbe_metadata
+
+
+def check_allocated_vbe_output(
+    output_dtype: int,
+    batch_size_per_feature_per_rank: Optional[List[List[int]]],
+    vbe_output: Optional[Tensor] = None,
+    vbe_output_offsets: Optional[Tensor] = None,
+) -> None:
+    assert (
+        batch_size_per_feature_per_rank is not None
+    ), "[Merged_VBE] vbe_output is passed, batch_size_per_feature_per_rank cannot be None"
+    assert (
+        vbe_output is not None
+    ), "[Merged_VBE] vbe_output_offsets is not None, vbe_output cannot be None"
+    assert (
+        vbe_output_offsets is not None
+    ), "[Merged_VBE] vbe_output is not None, vbe_output_offsets cannot be None"
+    num_features = len(batch_size_per_feature_per_rank)
+    num_ranks = len(batch_size_per_feature_per_rank[0])
+    assert vbe_output_offsets.shape == torch.Size(
+        [num_ranks, num_features]
+    ), f"[Merged_VBE] Mismatched vbe_output_offsets shape. batch_size_per_feature_per_rank={batch_size_per_feature_per_rank}. Expected: {torch.Size([num_ranks, num_features])}, Actual: {vbe_output_offsets.shape}"
+    assert (
+        vbe_output.dim() == 1
+    ), f"[Merged_VBE] vbe_output must have 1 dimension, but got {vbe_output.dim()}. vbe_output shape is {vbe_output.shape}"
+    assert (
+        vbe_output_offsets.device == vbe_output.device
+    ), "[Merged_VBE] vbe_output_offsets and vbe_output must be on the same device"
+    _output_dtype = sparse_type_int_to_dtype(output_dtype)
+    assert (
+        vbe_output.dtype == _output_dtype
+    ), f"[Merged_VBE] vbe_output dtype must match TBE output dtype {_output_dtype} (SparseType {output_dtype}), but got {vbe_output.dtype}"
+    assert (
+        vbe_output_offsets.is_contiguous()
+    ), "[Merged_VBE] vbe_output_offsets needs to be contiguous"
+    assert vbe_output.is_contiguous(), "[Merged_VBE] vbe_output needs to be contiguous"
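Note: threading a pre-allocated buffer through the forward call would then look roughly like this; `emb` and the buffer names are illustrative, with the buffer built as in the earlier allocation sketch:

# emb: a SplitTableBatchedEmbeddingBagsCodegen (or the SSD TBE module below)
output = emb(
    indices,
    offsets,
    batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
    vbe_output=vbe_out,              # 1-D, contiguous, TBE output dtype
    vbe_output_offsets=vbe_offsets,  # [num_ranks, num_features], same device
)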
fbgemm_gpu/tbe/ssd/training.py CHANGED
@@ -50,6 +50,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
     WeightDecayMode,
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
+    check_allocated_vbe_output,
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
@@ -2308,6 +2309,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         self,
         offsets: Tensor,
         batch_size_per_feature_per_rank: Optional[list[list[int]]],
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> invokers.lookup_args.VBEMetadata:
         # Blocking D2H copy, but only runs at first call
         self.feature_dims = self.feature_dims.cpu()
@@ -2326,6 +2329,8 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
             self.pooling_mode,
             self.feature_dims,
             self.current_device,
+            vbe_output,
+            vbe_output_offsets,
         )
 
     def _increment_iteration(self) -> int:
@@ -2356,11 +2361,26 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         per_sample_weights: Optional[Tensor] = None,
         feature_requires_grad: Optional[Tensor] = None,
         batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
         # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
     ) -> Tensor:
         self.clear_cache()
+        if vbe_output is not None or vbe_output_offsets is not None:
+            # CPU is not supported in SSD TBE
+            check_allocated_vbe_output(
+                self.output_dtype,
+                batch_size_per_feature_per_rank,
+                vbe_output,
+                vbe_output_offsets,
+            )
         indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
-            indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
+            indices,
+            offsets,
+            per_sample_weights,
+            batch_size_per_feature_per_rank,
+            vbe_output=vbe_output,
+            vbe_output_offsets=vbe_output_offsets,
         )
 
         if len(self.timesteps_prefetched) == 0:
@@ -3691,13 +3711,15 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
         offsets: Tensor,
         per_sample_weights: Optional[Tensor] = None,
         batch_size_per_feature_per_rank: Optional[list[list[int]]] = None,
+        vbe_output: Optional[Tensor] = None,
+        vbe_output_offsets: Optional[Tensor] = None,
     ) -> tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
         """
         Prepare TBE inputs
         """
         # Generate VBE metadata
         vbe_metadata = self._generate_vbe_metadata(
-            offsets, batch_size_per_feature_per_rank
+            offsets, batch_size_per_feature_per_rank, vbe_output, vbe_output_offsets
         )
 
         # Force casting indices and offsets to long
fbgemm_gpu/utils/writeback_util.py ADDED
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+def writeback_update_gradient(
+    indices: torch.Tensor,
+    offsets: torch.Tensor,
+    grad: torch.Tensor,
+    feature_table_map: list[int],
+) -> torch.Tensor:
+    """
+    Update the gradient tensor by deduplicating indices across all features/tables.
+    For duplicate indices, only the first occurrence receives the gradient, so the
+    gradient update behaves like an assignment.
+
+    NOTE: This function does not support VBE yet
+
+    Args:
+        indices (torch.Tensor): Embedding indices tensor
+        offsets (torch.Tensor): Offsets tensor for batched embeddings
+        grad (torch.Tensor): Gradient tensor to be updated
+        feature_table_map (list[int]): Mapping from feature to table
+
+    Returns:
+        torch.Tensor: Updated gradient tensor with duplicates masked out
+    """
+    if indices.numel() == 0:
+        return grad[0]
+    # get the number of features to estimate the batch size
+    num_of_tables = len(feature_table_map)
+    assert num_of_tables * indices.max() < torch.iinfo(indices.dtype).max
+    batch_size = offsets.shape[0] // num_of_tables
+    max_indices = indices.max()
+    non_empty_index = (offsets[1:] - offsets[:-1]).nonzero().flatten()
+    # disable dedup across different tables
+    indices = ((offsets[non_empty_index]) // batch_size) * (1 + max_indices) + indices
+    grad = grad[0]
+    _, idx, counts = torch.unique(
+        indices, dim=0, sorted=True, return_inverse=True, return_counts=True
+    )
+    _, ind_sorted = torch.sort(idx, stable=True)
+    cum_sum = counts.cumsum(0)
+    cum_sum = torch.cat((torch.tensor([0]).to(indices.device), cum_sum[:-1]))
+    first_indicies = ind_sorted[cum_sum]
+    mask = torch.zeros_like(grad, device=grad.device)
+    original_index = non_empty_index[first_indicies]
+
+    mask[original_index] = grad[original_index]
+    return mask
+
+
+def writeback_update_gradient_first_feature_only(
+    indices: torch.Tensor,
+    offsets: torch.Tensor,
+    grad: torch.Tensor,
+    feature_table_map: list[int],
+) -> torch.Tensor:
+    """
+    Special case of writeback_update_gradient where the gradient only needs to be
+    updated for the first feature. Other features will be forward-only.
+
+    NOTE: This function does not support VBE yet
+
+    Args:
+        indices (torch.Tensor): Embedding indices tensor
+        offsets (torch.Tensor): Offsets tensor for batched embeddings
+        grad (torch.Tensor): Gradient tensor to be updated
+        feature_table_map (list[int]): Mapping from feature to table
+
+    Returns:
+        torch.Tensor: Updated gradient tensor with duplicates masked out
+    """
+    num_of_tables = len(feature_table_map)
+    batch_size = (offsets.shape[0] - 1) // num_of_tables
+    shrink_indices = indices[: offsets[batch_size]]
+    if shrink_indices.numel() == 0 or indices.numel() == 0:
+        return grad[0]
+    assert num_of_tables * indices.max() < torch.iinfo(indices.dtype).max
+
+    grad = grad[0]
+    _, idx, counts = torch.unique(
+        shrink_indices, dim=0, sorted=True, return_inverse=True, return_counts=True
+    )
+    _, ind_sorted = torch.sort(idx, stable=True)
+    cum_sum = counts.cumsum(0)
+    cum_sum = torch.cat((torch.tensor([0]).to(shrink_indices.device), cum_sum[:-1]))
+    first_indicies = ind_sorted[cum_sum]
+    mask = torch.zeros_like(grad, device=grad.device)
+
+    mask[first_indicies] = grad[first_indicies]
+    return mask
+
+
+def writeback_gradient(
+    grad: torch.Tensor,
+    indices: torch.Tensor,
+    offsets: torch.Tensor,
+    feature_table_map: list[int],
+    writeback_first_feature_only: bool = False,
+) -> tuple[torch.Tensor]:
+    """
+    Compute the deduplicated gradient for the writeback operation.
+
+    Args:
+        grad (torch.Tensor): Gradient tensor to be updated
+        indices (torch.Tensor): Embedding indices tensor
+        offsets (torch.Tensor): Offsets tensor for batched embeddings
+        feature_table_map (list[int]): Mapping from feature to table
+        writeback_first_feature_only (bool): If True, only the first feature will
+            apply the gradient update; other features will be read-only
+
+    Returns:
+        tuple[torch.Tensor]: Tuple containing the updated gradient tensor
+    """
+    if writeback_first_feature_only:
+        return (
+            writeback_update_gradient_first_feature_only(
+                indices, offsets, grad, feature_table_map
+            ),
+        )
+    else:
+        return (writeback_update_gradient(indices, offsets, grad, feature_table_map),)
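Note: a small self-contained check of the dedup behavior, with made-up values. Two features share one table with batch size 2; index 3 appears twice in the first feature, so only its first occurrence keeps a gradient row:

import torch
from fbgemm_gpu.utils.writeback_util import writeback_gradient

indices = torch.tensor([3, 3, 5, 7])     # feature 0 bags: [3], [3]; feature 1 bags: [5], [7]
offsets = torch.tensor([0, 1, 2, 3, 4])  # 2 features x batch size 2, one index per bag
grad = (torch.ones(4, 2),)               # the backward pre-hook passes a tuple of grads
(masked,) = writeback_gradient(grad, indices, offsets, feature_table_map=[0, 0])
assert torch.equal(masked[1], torch.zeros(2))  # duplicate of index 3 is zeroed
assert torch.equal(masked[0], torch.ones(2))   # first occurrence keeps its gradient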
fbgemm_gpu_genai_nightly-2025.12.17.dist-info/METADATA → fbgemm_gpu_genai_nightly-2026.1.4.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fbgemm_gpu_genai_nightly
-Version: 2025.12.17
+Version: 2026.1.4
 Home-page: https://github.com/pytorch/fbgemm
 Author: FBGEMM Team
 Author-email: packages@pytorch.org
fbgemm_gpu_genai_nightly-2025.12.17.dist-info/RECORD → fbgemm_gpu_genai_nightly-2026.1.4.dist-info/RECORD RENAMED
@@ -1,8 +1,8 @@
 fbgemm_gpu/__init__.py,sha256=bL2dL7uYeXb1GvdjIDUTcLXLRGNfmnI4MQoE3-Gg5m8,6361
-fbgemm_gpu/asmjit.so,sha256=231yAFvSUfy_B5xni9sAPQlsi5so9alFN3tXN7GFcMQ,484232
+fbgemm_gpu/asmjit.so,sha256=UxnhHlu9LgmoRXa8fZwSX56b5QKffBxfAOs0AZLxRfk,501728
 fbgemm_gpu/batched_unary_embeddings_ops.py,sha256=GYeJ9pg-Wc9FokXVci_npDsL6UV18-pJXID2xzrJ9O8,2904
 fbgemm_gpu/enums.py,sha256=37ewGSfO1x7sO31ZkRiqV1yKuklfHXT5qZIxzeeGogo,755
-fbgemm_gpu/fbgemm.so,sha256=_fCdNktofSTSuedF0cLL3AKDTeKca5tty8RnRzKFCdg,5803160
+fbgemm_gpu/fbgemm.so,sha256=U864UANx-CVyFYk5ADawCd0uWRfntHaVcyl6AVty_3Q,5642616
 fbgemm_gpu/metrics.py,sha256=TsurFLJf0nJvPDN7urWb4LMQlf5RgdWPTTTDO7S4wtI,5663
 fbgemm_gpu/permute_pooled_embedding_modules.py,sha256=vOXMYclaGnwSt0St_SOAlAe18kz6WjMyTeHnC9jLhcE,5130
 fbgemm_gpu/permute_pooled_embedding_modules_split.py,sha256=f3VJvH_kw9Ltd_DXtaf_PJPHmlmEWrQgzQ7MDkhh5Nw,2746
@@ -10,20 +10,20 @@ fbgemm_gpu/quantize_comm.py,sha256=ZfXtRHfqpVpV6k2PDL6oTUkKYzopqAV2M6vavp_RLSM,1
 fbgemm_gpu/quantize_utils.py,sha256=q8Aokk6nlHbXF6HcDBbhBCAGSZV4klM8uPF-MUFFtAw,8324
 fbgemm_gpu/runtime_monitor.py,sha256=YXRUv6nXCsoTgh5_RzailTGvCYzwoYDb-eR4rlGwtaw,7619
 fbgemm_gpu/sparse_ops.py,sha256=_EJC1pAbNnAnVQQ5JBg4DAV2TboIj-4XQkiKMmg1vXI,50417
-fbgemm_gpu/split_embedding_configs.py,sha256=fv29efZGD_cvh5KwdvTFD6GZtqJLYjWXW_0vMeyT_6k,15483
+fbgemm_gpu/split_embedding_configs.py,sha256=EuVFKIDrgRQpRC5mmB4Du6WftK5GXJvDue9_ezt_eBI,16575
 fbgemm_gpu/split_embedding_inference_converter.py,sha256=AghGW22MgMsdHzdwdPMPYDjgas5AE_estckY8rMgXVU,7056
 fbgemm_gpu/split_embedding_optimizer_ops.py,sha256=wXuGazClBMk62yL_r9udUIKaPgQP7SlkSb5ugB75wrQ,711
 fbgemm_gpu/split_embedding_utils.py,sha256=Gb40ZKeATxIKEKI3aVQMgDDBanNpKMc53Z43mnzdR_I,851
 fbgemm_gpu/split_table_batched_embeddings_ops.py,sha256=_MIp6uHYHLn4GxGdrGsfddfSsZ2Z9mjsYIrih3ncI1I,2339
 fbgemm_gpu/split_table_batched_embeddings_ops_common.py,sha256=eFxb_bDfBV8G76pmd-SxDXXXnqgbuGYOS4pSU8JS5dg,19295
 fbgemm_gpu/split_table_batched_embeddings_ops_inference.py,sha256=dGC85xjQiRUrequBibSf9oMAVHT5Q49zsVo2zW4n_88,81679
-fbgemm_gpu/split_table_batched_embeddings_ops_training.py,sha256=D72laY5iFC3_6f_qHnPMizDDxwI0QW7-21RyY0ZikK4,187705
-fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py,sha256=e3O9ElaWBGvG7TdT3Ok_8cB06jhskXuyCQ0t40dzsEY,5449
+fbgemm_gpu/split_table_batched_embeddings_ops_training.py,sha256=rNGMELM_xFIsdS_340PB7bsn9h_VjONq_JJG1SjHyvQ,188992
+fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py,sha256=jofAN2UB_iSk53Id6MBvn9Bi3Qxw67IL0_VE_EHlw_Q,7593
 fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py,sha256=7qGkO8FARku38mFYl4Bc4qL8dS1wrfyorS9l1m5ZAVA,718
 fbgemm_gpu/tbe_input_multiplexer.py,sha256=TQjwkJ2JkOaQsMYuRdk9RbNa9759EPEtx8bYclChtZY,3063
 fbgemm_gpu/uvm.py,sha256=guNK8ZzR80jmv-CyRgEhxhVYhjz3R9d6tB8Hu1uWDUo,1047
 fbgemm_gpu/config/__init__.py,sha256=yN0KAneCICgF2BTfOYGsd0qU1PvZX_6msC6YHHZKLMg,292
-fbgemm_gpu/config/feature_list.py,sha256=iDOGr9nwTqUhWsqOefRIqIo1jwLSeSII4jGnLeU01kg,2359
+fbgemm_gpu/config/feature_list.py,sha256=hhDNkkafd-Oetvuqv9ylBVTNM-lKPi029mpRqq-JZCA,2467
 fbgemm_gpu/docs/__init__.py,sha256=DR6hMSQrsZALfH2AnuJQ4Zq2CfBUUhMN8YjD6APjiAE,523
 fbgemm_gpu/docs/common.py,sha256=8ipXTwVb222X-aZ71O6n8fhxHCHPNhJEHMFiO7epcIs,273
 fbgemm_gpu/docs/examples.py,sha256=ZMN_6sL74LH_hrp2bF_hmg8gi29GhcgvwV3kCMjxkoE,2377
@@ -32,9 +32,9 @@ fbgemm_gpu/docs/merge_pooled_embedding_ops.py,sha256=oJLgSgZQmhsyGLbTmZTxNgQrk65
 fbgemm_gpu/docs/permute_pooled_embedding_ops.py,sha256=tZUqLVXlk5O6VAKKDA-OEMx2fCu5QPOOeoAPZA9_nLY,4454
 fbgemm_gpu/docs/quantize_ops.py,sha256=xTtOaVK1P02ymreE_i21YiyYDZCqhoZY9eWp_mEIRlo,1297
 fbgemm_gpu/docs/sparse_ops.py,sha256=gSLUFdnu8lle_6gLewFkM20wL3ek2jKLvDGMKR6POaY,27292
-fbgemm_gpu/docs/target.genai.json.py,sha256=Zzc84wR-3UYjzYFUQk2gX2r6FEia8mMClTg1gA1HVoc,79
+fbgemm_gpu/docs/target.genai.json.py,sha256=5TMzQCJ6eRjDaUActAOucxjizI7IZg56rn512-ujiE4,77
 fbgemm_gpu/experimental/example/__init__.py,sha256=OvJHZgWnycL1gWKyCXFJCTKuys3KAqx4iadjx3R-tBQ,723
-fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so,sha256=Mt99lNGcaYTxWVGqPP8Q2l-n_7lj2DNmPHura1eHAMM,183392
+fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so,sha256=y0Z22D1LnOkH9vXtRVPYWJ5raZC27OTViPEtnqi8TyY,190656
 fbgemm_gpu/experimental/example/utils.py,sha256=Je__VkMlBMLOhh7NXOocOdvaa2gz9kl9Dkqeu25tpFA,562
 fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py,sha256=1CqUfzlYyXTvU-BNaUq4RZpLV-2lKAVCAHeJzSIZFWw,419
 fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py,sha256=R4VNZdPSgmRmwDfTt2CShED2SGUF6dCXSUW2C4LISgE,215713
@@ -43,11 +43,11 @@ fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py,sha256=5ClZ-GDrx6q0uaqW
 fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py,sha256=SltbY_dsit5e7B8lDIB_VYPrEq0t9kckthj9mQaVNfA,7571
 fbgemm_gpu/experimental/gemm/triton_gemm/utils.py,sha256=rULXIpVaaRS3GKUZ1RHcWUrUyy0xMVREwS1SFShGgcw,4302
 fbgemm_gpu/experimental/gen_ai/__init__.py,sha256=r3NlNCXuIh0pfKwKU5v14y6AZkpoIkKWbtzxSprgeKA,1713
-fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so,sha256=rR2xW3Km17SqFFHLL-1WKIQ2hxd7-UpiEbEQmsvx8z8,64298336
+fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so,sha256=2iHWrQDzhysRNMPbjFQpsxNdAkIRq__vTHy75sa4kJo,65238760
 fbgemm_gpu/experimental/gen_ai/quantize.py,sha256=KAljWSdN-1_c5DWfT-3MDxWLMULK49Yu36t6TmQI9Tw,12599
 fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py,sha256=-R_LxyHpdXMILU9TNuYoRisBCkfK0_VLyixefaeZf4g,1463
 fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py,sha256=gbhNU3mDTKJb3yt3inIDbiUjX_SG1oZfzgDygtHvMpk,10101
-fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py,sha256=r1AhV2qdIqxtYYeze6yr6_wg_Xzfzc4QJEBeNsGY4Gw,17570
+fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py,sha256=fD39_WH7TfNCiP5Vl46ToX6PsLMLUFLhizT26Qe7TWg,17282
 fbgemm_gpu/experimental/gen_ai/bench/__init__.py,sha256=XpAK_eyqDSKeFC5J9KpnKtbZG07mrDh9d2j1LFKzr-8,404
 fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py,sha256=ApEyJOf_rdIo8V_EgvhZXBGNov8ITC_dnB95v8szulI,8515
 fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py,sha256=K9Nib6D7xJbw1QwEVuCJrVyI1qs988moo3cieVKYuFY,12057
@@ -99,7 +99,7 @@ fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py,sha256=vZHj7KIe1DoJDy5eft29Xt
 fbgemm_gpu/tbe/ssd/__init__.py,sha256=wzfMT10cp_dqK2lrebC449hOdexBnizcf_98lA1NyHs,483
 fbgemm_gpu/tbe/ssd/common.py,sha256=1J8K7sTQswgCYWaVwF-ZdCJj7mNN6O9GI70AaZWzJGE,1044
 fbgemm_gpu/tbe/ssd/inference.py,sha256=B_uX66ajGA9YKGlFa5TmGWs7b-b1RFigzwxmENZ9Oio,22816
-fbgemm_gpu/tbe/ssd/training.py,sha256=ElFvQHF5wQBzrqU34F6ZR2IEBVzKO3j3symntP15S3E,211380
+fbgemm_gpu/tbe/ssd/training.py,sha256=C6M3H_f8oWWRkC4R-BJED73au-Gl9SUVllxOoFSiDkI,212234
 fbgemm_gpu/tbe/ssd/utils/__init__.py,sha256=5DgmR2HA6NtmYh2ddkUgpDsZ6a7hF0DPedA1gMpdh18,250
 fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py,sha256=SFg2-29b-i49LWm-FlaWUkTz2XzXbicYi_AzVj4jKNE,7601
 fbgemm_gpu/tbe/stats/__init__.py,sha256=on29iDtq7cVNh90JR9aeFNG-K9DDoYq0JryzoplL49I,322
@@ -119,9 +119,10 @@ fbgemm_gpu/utils/__init__.py,sha256=JQQNdcTTaEU6ptK-OW-ZQBwTFxEZZpWOtBXWwEZm39o,
 fbgemm_gpu/utils/filestore.py,sha256=oVtbKGaPQki1JgbJCkrkElukOFVyxntQpSC0lYBKgho,6455
 fbgemm_gpu/utils/loader.py,sha256=1hCEhNvkflniH46fGcrguLeP1z-6uyOu2QFwqKU5CIM,990
 fbgemm_gpu/utils/torch_library.py,sha256=ywsAHjbuwesj50LjEu99WkAH17FlaVgePZ9OmFg6YE4,4193
+fbgemm_gpu/utils/writeback_util.py,sha256=PyVbHp1EuF-GKrJv_CTP6B50Z0oBblXKucf7Rhd6KKY,4614
 list_versions/__init__.py,sha256=UmTeqCk-UJWFtlZQWvZao3xvui2w9E3X_JdOXVjRaNw,315
 list_versions/cli_run.py,sha256=CChZoXQ-tiKaWboXAYlPVJ5w8K5zAKiKcncA087I1sc,4508
-fbgemm_gpu_genai_nightly-2025.12.17.dist-info/METADATA,sha256=oJzBJPiPBYhvls7W-MDbX-yBH6y4CRyGDHNvDFaAyBU,2657
-fbgemm_gpu_genai_nightly-2025.12.17.dist-info/WHEEL,sha256=Nkv8TSWVt7XcnRf1cdq5HOzycTl6Pjzlmn7gPSv4NiQ,108
-fbgemm_gpu_genai_nightly-2025.12.17.dist-info/top_level.txt,sha256=_2s1Aa08r_eDn0JP4FjOhzK09Q8bVlEI7q8pMep51UY,25
-fbgemm_gpu_genai_nightly-2025.12.17.dist-info/RECORD,,
+fbgemm_gpu_genai_nightly-2026.1.4.dist-info/METADATA,sha256=MjhefCkWlccqGa-waygmSKkW1vaKWbpxX1U8VLRrMJ0,2655
+fbgemm_gpu_genai_nightly-2026.1.4.dist-info/WHEEL,sha256=Nkv8TSWVt7XcnRf1cdq5HOzycTl6Pjzlmn7gPSv4NiQ,108
+fbgemm_gpu_genai_nightly-2026.1.4.dist-info/top_level.txt,sha256=_2s1Aa08r_eDn0JP4FjOhzK09Q8bVlEI7q8pMep51UY,25
+fbgemm_gpu_genai_nightly-2026.1.4.dist-info/RECORD,,