fbgemm-gpu-nightly-cpu 2025.3.27__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. fbgemm_gpu/__init__.py +118 -23
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +142 -1
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +244 -76
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +26 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +208 -105
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +261 -53
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +9 -58
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +10 -59
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +225 -41
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +211 -36
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +195 -26
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +225 -41
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +225 -41
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +216 -111
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +221 -37
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +259 -53
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +192 -96
  58. fbgemm_gpu/split_embedding_configs.py +287 -3
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_embedding_optimizer_codegen/optimizer_args.py +2 -0
  61. fbgemm_gpu/split_embedding_optimizer_codegen/split_embedding_optimizer_rowwise_adagrad.py +2 -0
  62. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +275 -9
  63. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +44 -37
  64. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +900 -126
  65. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  66. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  67. fbgemm_gpu/tbe/bench/__init__.py +13 -2
  68. fbgemm_gpu/tbe/bench/bench_config.py +37 -9
  69. fbgemm_gpu/tbe/bench/bench_runs.py +301 -12
  70. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +189 -0
  71. fbgemm_gpu/tbe/bench/eeg_cli.py +138 -0
  72. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +4 -5
  73. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  74. fbgemm_gpu/tbe/bench/tbe_data_config.py +116 -198
  75. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  76. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +158 -32
  77. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +16 -8
  78. fbgemm_gpu/tbe/bench/utils.py +129 -5
  79. fbgemm_gpu/tbe/cache/__init__.py +1 -0
  80. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  81. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -5
  82. fbgemm_gpu/tbe/ssd/common.py +27 -0
  83. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  84. fbgemm_gpu/tbe/ssd/training.py +2930 -195
  85. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +34 -3
  86. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  87. fbgemm_gpu/tbe/stats/bench_params_reporter.py +349 -0
  88. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  89. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  90. fbgemm_gpu/tbe/utils/requests.py +53 -28
  91. fbgemm_gpu/tbe_input_multiplexer.py +16 -7
  92. fbgemm_gpu/triton/common.py +0 -1
  93. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  94. fbgemm_gpu/triton/quantize.py +14 -9
  95. fbgemm_gpu/utils/filestore.py +56 -5
  96. fbgemm_gpu/utils/torch_library.py +2 -2
  97. fbgemm_gpu/utils/writeback_util.py +124 -0
  98. fbgemm_gpu/uvm.py +3 -0
  99. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +3 -6
  100. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  101. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  102. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -3
  103. list_versions/cli_run.py +161 -0
  104. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/RECORD +0 -126
  105. fbgemm_gpu_nightly_cpu-2025.3.27.dist-info/top_level.txt +0 -1
  106. {fbgemm_gpu_nightly_cpu-2025.3.27.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0

fbgemm_gpu/quantize_comm.py
@@ -13,10 +13,11 @@


 import logging
-from typing import List, Optional, Tuple, TypeVar
+from typing import Optional, TypeVar

 import torch

+# fmt:skip
 from fbgemm_gpu.quantize_utils import (
     bf16_to_fp32,
     fp16_to_fp32,
@@ -25,12 +26,10 @@ from fbgemm_gpu.quantize_utils import (
     fp32_to_hfp8_with_clamp,
     fp32_to_mx4,
     hfp8_to_fp32,
-    mx4_to_fp32,
+    mx4_to_float,
     RoundingMode,
 )
-
 from fbgemm_gpu.split_embedding_configs import SparseType
-
 from torch.autograd.profiler import record_function  # usort:skip
 from dataclasses import dataclass

@@ -66,8 +65,8 @@ class QuantizationContext:
     row_dim: int = ROW_DIM_DEFAULT
     row_dim_quant: int = -1
     mx_group_size: int = MX_GROUP_SIZE_DEFAULT
-    rounding_mode: RoundingMode = RoundingMode.even
-    padded_dim_sum_per_rank: Optional[List[int]] = None
+    rounding_mode: Optional[RoundingMode] = RoundingMode.even
+    padded_dim_sum_per_rank: Optional[list[int]] = None


 def _quantize_tensor(
@@ -123,6 +122,7 @@ def _dequantize_tensor(
     comm_precision: SparseType,
     ctx: Optional[QuantizationContext] = None,
     is_fwd: bool = True,
+    output_dtype: Optional[SparseType] = None,
 ) -> torch.Tensor:
     if comm_precision == SparseType.FP32:
         assert quantized_tensor.dtype == torch.float
@@ -137,8 +137,12 @@ def _dequantize_tensor(
         if ctx is not None and ctx.row_dim > 0:
             row_dim_quant = ctx.row_dim_quant
             quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
+            # use provided output_dtype or default to FP32 (0)
+            output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
             dequant_tensor = torch.ops.fbgemm.FP8RowwiseQuantizedToFloat(
-                quantized_tensor_2d, is_fwd
+                quantized_tensor_2d,
+                is_fwd,
+                output_dtype_int,
             )
             return dequant_tensor.view(-1)
         else:
@@ -154,7 +158,7 @@ def _dequantize_tensor(
         return dequant_tensor.view(-1)
     elif comm_precision == SparseType.MX4:
         mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
-        return mx4_to_fp32(quantized_tensor, mx_group_size)
+        return mx4_to_float(quantized_tensor, mx_group_size, output_dtype=output_dtype)
     else:
         raise ValueError(f"comm_precision={comm_precision} is not supported")

@@ -167,6 +171,8 @@ class QuantizedCommCodec:
         loss_scale: Optional[float] = None,
         row_dim: Optional[int] = None,
         is_fwd: bool = True,
+        rounding_mode: Optional[RoundingMode] = None,
+        output_dtype: Optional[SparseType] = None,
     ) -> None:
         if loss_scale is not None:
             if comm_precision not in [SparseType.FP16, SparseType.BF16]:
@@ -183,8 +189,13 @@ class QuantizedCommCodec:
         self._loss_scale = loss_scale
         self._is_fwd = is_fwd
         self._row_dim: int = -1 if row_dim is None else row_dim
+        self._rounding_mode: Optional[RoundingMode] = rounding_mode
+        self._output_dtype: Optional[SparseType] = output_dtype
         if self._comm_precision == SparseType.MX4:
             self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
+            self._rounding_mode = (
+                RoundingMode.even if rounding_mode is None else rounding_mode
+            )

     def encode(
         self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
@@ -211,7 +222,11 @@ class QuantizedCommCodec:
             f"## decoder {self._comm_precision} {self._loss_scale} ##"
         ):
             dequantized_tensor = _dequantize_tensor(
-                input_tensor, self._comm_precision, ctx, self._is_fwd
+                input_tensor,
+                self._comm_precision,
+                ctx,
+                self._is_fwd,
+                output_dtype=self._output_dtype,
             )
         return dequantized_tensor

@@ -258,7 +273,9 @@ class QuantizedCommCodec:
             return QuantizationContext(self._row_dim)
         if self._comm_precision == SparseType.MX4:
             return QuantizationContext(
-                row_dim=self._row_dim, mx_group_size=self._row_dim
+                row_dim=self._row_dim,
+                mx_group_size=self._row_dim,
+                rounding_mode=self._rounding_mode,
             )
         # int8 rowwise is default
         return QuantizationContext()
@@ -266,10 +283,10 @@ class QuantizedCommCodec:
     def padded_size(
         self,
         input_tensor: torch.Tensor,
-        dim_per_rank: List[int],
+        dim_per_rank: list[int],
         my_rank: int,
         qcomm_ctx: QuantizationContext,
-    ) -> Tuple[int, int]:
+    ) -> tuple[int, int]:
         if input_tensor.ndim == 1:
             return input_tensor.shape[0], 0
         # return padded size for the feature dimension (dim 1), 0 if no padding needed.
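
The hunks above give QuantizedCommCodec two new constructor knobs: rounding_mode, which is threaded into the MX4 QuantizationContext, and output_dtype, which is forwarded to _dequantize_tensor so decode() can return BF16 instead of FP32. The sketch below is illustrative only; it assumes the constructor's first positional argument is the comm precision and that the @@ -258 hunk sits inside a create_context() method, neither of which is visible in the diff itself.

# Illustrative sketch of the new QuantizedCommCodec knobs (assumptions noted above).
import torch

from fbgemm_gpu.quantize_comm import QuantizedCommCodec
from fbgemm_gpu.quantize_utils import RoundingMode
from fbgemm_gpu.split_embedding_configs import SparseType

codec = QuantizedCommCodec(
    SparseType.MX4,                   # assumed first positional arg: comm precision
    rounding_mode=RoundingMode.even,  # new; defaults to even for MX4 when omitted
    output_dtype=SparseType.BF16,     # new; decode() dequantizes to BF16 instead of FP32
)

grads = torch.randn(4, 128).flatten()  # MX4 packs rows in group_size chunks
ctx = codec.create_context()           # assumed name; carries row_dim, mx_group_size, rounding_mode
packed = codec.encode(grads, ctx)
restored = codec.decode(packed, ctx)   # dtype follows output_dtype (BF16 here)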

fbgemm_gpu/quantize_utils.py
@@ -10,11 +10,34 @@
 import logging
 from typing import Optional, Union

-import torch
+import torch  # isort:skip

-from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
+import fbgemm_gpu
+from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.triton.common import RoundingMode
 from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4

+try:
+    if torch.cuda.is_available():
+        from fbgemm_gpu.triton import quantize_mx4
+        from fbgemm_gpu.triton.quantize import triton_dequantize_mx4
+except Exception:
+    pass
+
+
+try:
+    # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+    open_source = bool(getattr(fbgemm_gpu, "open_source", False))
+except NotImplementedError:
+    open_source = False
+
+# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+if not open_source:
+    from mtia.kernels.triton.mx4.quantize import (
+        triton_dequantize_mx4 as mtia_dequantize_mx4,
+        triton_quantize_mx4 as mtia_quantize_mx4,
+    )
+
 logger: logging.Logger = logging.getLogger()

 try:
@@ -60,7 +83,7 @@ def fp32_to_mx4(
     if rounding_mode is None:
         rounding_mode = RoundingMode.even

-    if not tensor.is_cuda:
+    if not tensor.is_cuda and not tensor.is_mtia:
         return py_quantize_mx4(
             tensor,
             group_size,
@@ -71,6 +94,15 @@
         )

     if use_triton:
+        if tensor.is_mtia:
+            return mtia_quantize_mx4(
+                tensor,
+                group_size,
+                ebits=ebits,
+                mbits=mbits,
+                rounding_mode=rounding_mode,
+                stochastic_casting=stochastic_casting,
+            )
         return quantize_mx4(
             tensor,
             group_size,
@@ -102,23 +134,71 @@ def mx4_to_fp32(
 ) -> torch.Tensor:
     """Dequantize an MX4 tensor to FP32 with triton or native cuda impl.

+    This function is kept for backward compatibility and always returns FP32.
+    For BF16 output, use mx4_to_float() with output_dtype=SparseType.BF16.
+    """
+    return mx4_to_float(
+        tensor,
+        group_size,
+        use_triton,
+        ebits,
+        mbits,
+        output_dtype=None,  # None = FP32 default for backward compatibility
+    )
+
+
+def mx4_to_float(
+    tensor: torch.Tensor,
+    group_size: int = 32,
+    use_triton: bool = True,
+    ebits: int = 2,
+    mbits: int = 1,
+    output_dtype: Optional[SparseType] = None,
+) -> torch.Tensor:
+    """Dequantize an MX4 tensor to FP32 or BF16 with triton or native cuda impl.
+
     Args:
         tensor (torch.Tensor): MX4 packed tensor with total elements (M / 2 + M / groupsize)
         group_size (int): Compute scale in chunks of group_size.
         use_triton (bool): If set, use triton quantization, otherwise cuda.
         ebits (int): Number of exponent bits in target mx4 format.
         mbits (int): Number of mantissa bits in target mx4 format.
+        output_dtype (Optional[SparseType]): Output dtype (FP32 or BF16).
+            Defaults to None (FP32) for backward compatibility.

     Return:
-        output: FP32 tensor with total elements (M).
+        output: Tensor with dtype matching output_dtype and total elements (M).
     """
+    # Validate output_dtype
+    supported_dtypes = {SparseType.FP32, SparseType.BF16}
+    if output_dtype is not None and output_dtype not in supported_dtypes:
+        raise ValueError(
+            f"output_dtype must be one of {supported_dtypes}, got {output_dtype}. "
+            f"FP16 is not supported due to potential overflow/underflow with MX4's wide exponent range. "
+            f"Use BF16 for memory savings with same dynamic range as FP32."
+        )
+
+    target_dtype = (
+        output_dtype.as_dtype() if output_dtype is not None else torch.float32
+    )
+
     # Accelerated MX4 dequantize is only available on cuda, if input is on cpu, use python.
-    if not tensor.is_cuda:
-        return py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+    if not tensor.is_cuda and not tensor.is_mtia:
+        result = py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+        return result.to(target_dtype) if output_dtype is not None else result
     if use_triton:
-        return dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+        if tensor.is_mtia:
+            return mtia_dequantize_mx4(
+                tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+            )
+        return triton_dequantize_mx4(
+            tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+        )
     else:
-        return torch.ops.fbgemm.dequantize_mx_cuda(tensor.flatten(), group_size)
+        output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
+        return torch.ops.fbgemm.dequantize_mx_cuda(
+            tensor.flatten(), group_size, output_dtype_int
+        )


 def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
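
mx4_to_float() is the new entry point shown above; mx4_to_fp32() now simply forwards to it with output_dtype=None. Below is a minimal round trip on CPU using only names visible in the hunks; the input sizing is illustrative.

# Round trip through the MX4 helpers; on CPU both calls fall back to the
# py_quantize_mx4 / py_dequantize_mx4 reference implementations.
import torch

from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_float
from fbgemm_gpu.split_embedding_configs import SparseType

x = torch.randn(1024)                    # element count chosen as a multiple of group_size
packed = fp32_to_mx4(x, group_size=32)   # packed MX4 payload plus per-group scales

y_fp32 = mx4_to_float(packed, group_size=32)                                # default FP32
y_bf16 = mx4_to_float(packed, group_size=32, output_dtype=SparseType.BF16)  # new BF16 path
# Passing SparseType.FP16 raises ValueError: MX4's exponent range can overflow FP16.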

fbgemm_gpu/runtime_monitor.py
@@ -12,7 +12,7 @@ import logging
 from collections import deque
 from dataclasses import dataclass
 from types import TracebackType
-from typing import Callable, Deque, Optional, Tuple, Type, TypeVar
+from typing import Callable, Optional, TypeVar

 import torch

@@ -49,6 +49,7 @@ class TBEStatsReporter(abc.ABC):
         embedding_id: str = "",
         tbe_id: str = "",
         time_unit: str = "ms",
+        enable_tb_metrics: bool = False,
     ) -> None:
         """
         Report the duration of a timed event.
@@ -63,6 +64,7 @@ class TBEStatsReporter(abc.ABC):
         data_bytes: int,
         embedding_id: str = "",
         tbe_id: str = "",
+        enable_tb_metrics: bool = False,
     ) -> None:
         """
         Report the size of some data amount.
@@ -89,9 +91,10 @@ class StdLogStatsReporter(TBEStatsReporter):
         embedding_id: str = "",
         tbe_id: str = "",
         time_unit: str = "ms",
+        enable_tb_metrics: bool = False,
     ) -> None:
         logging.info(
-            f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} took {duration_ms} {time_unit}"
+            f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} took {duration_ms} {time_unit} with {enable_tb_metrics}"
         )

     def report_data_amount(
@@ -101,9 +104,10 @@ class StdLogStatsReporter(TBEStatsReporter):
         data_bytes: int,
         embedding_id: str = "",
         tbe_id: str = "",
+        enable_tb_metrics: bool = False,
     ) -> None:
         logging.info(
-            f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} used {data_bytes} bytes"
+            f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} used {data_bytes} bytes with {enable_tb_metrics}"
         )

     def __repr__(self) -> str:
@@ -167,7 +171,7 @@ class AsyncSeriesTimerRecordedContext:

     def __exit__(
         self,
-        exc_type: Optional[Type[BaseException]],
+        exc_type: Optional[type[BaseException]],
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> None:
@@ -187,7 +191,7 @@ class AsyncSeriesTimer:
     """

     def __init__(self, report_functor: Callable[[T, float], None]) -> None:
-        self._events_queue: Deque[Tuple[torch.cuda.Event, torch.cuda.Event, T]] = (
+        self._events_queue: deque[tuple[torch.cuda.Event, torch.cuda.Event, T]] = (
            deque()
        )
        self._active_start_event: Optional[torch.cuda.Event] = None
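
Both reporter hooks gain an enable_tb_metrics flag with a False default, so existing subclasses keep working. The following sketch of a custom reporter against the abstract interface is illustrative: the leading parameters (iteration_step, event_name, duration_ms, data_bytes) and the should_report() hook are inferred from the log strings and the surrounding class, not shown verbatim in these hunks.

# Hypothetical in-memory reporter built on TBEStatsReporter (see caveats above).
from fbgemm_gpu.runtime_monitor import TBEStatsReporter


class InMemoryStatsReporter(TBEStatsReporter):
    def __init__(self) -> None:
        self.durations: list[tuple[str, float]] = []
        self.data_amounts: list[tuple[str, int]] = []

    def should_report(self, iteration_step: int) -> bool:
        return True  # assumed abstract hook; report every step

    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        embedding_id: str = "",
        tbe_id: str = "",
        time_unit: str = "ms",
        enable_tb_metrics: bool = False,  # new in this release
    ) -> None:
        self.durations.append((event_name, duration_ms))

    def report_data_amount(
        self,
        iteration_step: int,
        event_name: str,
        data_bytes: int,
        embedding_id: str = "",
        tbe_id: str = "",
        enable_tb_metrics: bool = False,  # new in this release
    ) -> None:
        self.data_amounts.append((event_name, data_bytes))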

fbgemm_gpu/sll/__init__.py
@@ -9,12 +9,14 @@

 import torch

+# fmt:skip
 from fbgemm_gpu.sll.cpu import op_registrations as sll_cpu_registrations
 from fbgemm_gpu.sll.meta import op_registrations as sll_meta_registrations
 from fbgemm_gpu.utils import TorchLibraryFragment

 lib = TorchLibraryFragment("fbgemm")

+# fmt:off
 lib.define(
     """sll_jagged_dense_bmm(
         Tensor x,
@@ -170,6 +172,7 @@ lib.define(
     ) -> Tensor
     """
 )
+# fmt:on

 # NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same
 # function however, this is not ideal because in the inference case, we don't

fbgemm_gpu/sll/cpu/cpu_sll.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-strict
-from typing import Any, Tuple
+from typing import Any

 import torch

@@ -65,7 +65,7 @@ class JaggedDenseBmmCPU(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx: Any, grad_output: torch.Tensor  # pyre-ignore
-    ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
+    ) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
         """
         # X = [Sum_B, D]
         # Y = [B, D, T]
@@ -73,7 +73,7 @@ class JaggedDenseBmmCPU(torch.autograd.Function):
         # dX = dZ * YT # [Sum_B, T] * [B, T, D] = [Sum_B, D]
         # dY = XT * dZ # [D, sum_B] * [sum_B, T] = [D, B, T]
         """
-        (x, y, x_offsets) = ctx.saved_tensors
+        x, y, x_offsets = ctx.saved_tensors
         N = ctx.N
         grad_x = cpu_jagged_dense_bmm_kernel(
             grad_output, y.permute(0, 2, 1), x_offsets, N
@@ -128,7 +128,7 @@ class JaggedJaggedBmm(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx: Any, grad_output: torch.Tensor  # pyre-ignore
-    ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
+    ) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
         """
         # X = [Sum_B, D]
         # Y = [Sum_B, T]
@@ -136,7 +136,7 @@ class JaggedJaggedBmm(torch.autograd.Function):
         # dXT = dZ * YT -> dX = Y * dZT
         # dY = X * dZ -> X * dZ
         """
-        (x, y, offsets) = ctx.saved_tensors
+        x, y, offsets = ctx.saved_tensors
         N = ctx.N
         grad_x = cpu_jagged_dense_bmm_kernel(
             y, grad_output.permute(0, 2, 1), offsets, N
@@ -172,7 +172,7 @@ def cpu_dense_jagged_cat_jagged_out(
     b: torch.Tensor,
     b_offsets: torch.Tensor,
     max_seq_len: int,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     assert a.size(0) == b_offsets.size(0) - 1
     c = torch.empty(b.size(0) + a.size(0), dtype=a.dtype, device=a.device)
     c_offsets = b_offsets + torch.arange(
@@ -368,7 +368,7 @@ class JaggedSoftmaxCPU(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx: Any, grad_output: torch.Tensor  # pyre-ignore
-    ) -> Tuple[torch.Tensor, None, None]:
+    ) -> tuple[torch.Tensor, None, None]:
         y, x_offsets = ctx.saved_tensors

         B = x_offsets.size(0) - 1
@@ -923,7 +923,7 @@ class JaggedDenseAddCPU(torch.autograd.Function):
     def backward(
         ctx,  # pyre-ignore
         grad_output: torch.Tensor,
-    ) -> Tuple[torch.Tensor, None, torch.Tensor, None]:
+    ) -> tuple[torch.Tensor, None, torch.Tensor, None]:
         (offsets,) = ctx.saved_tensors
         grad_dense = torch.ops.fbgemm.jagged_to_padded_dense(
             grad_output, [offsets], [ctx.max_seq_len]
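
The recurring Tuple/List/Deque/Type to tuple/list/deque/type rewrites in this file (and in the triton files below) switch the annotations to PEP 585 builtin generics, which need no typing import on Python 3.9+; these wheels target CPython 3.11, so the change is safe. A small illustrative example of the style:

# PEP 585 style: builtin generics in annotations, no `from typing import Tuple, List`.
def split_halves(xs: list[int]) -> tuple[list[int], list[int]]:
    mid = len(xs) // 2
    return xs[:mid], xs[mid:]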

fbgemm_gpu/sll/triton/__init__.py
@@ -10,19 +10,16 @@
 from fbgemm_gpu.sll.triton.triton_dense_jagged_cat_jagged_out import (
     dense_jagged_cat_jagged_out,
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged2_to_padded_dense import (  # noqa F401
     jagged2_to_padded_dense,
     Jagged2ToPaddedDense,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_bmm import (  # noqa F401
     jagged_dense_bmm,
     jagged_jagged_bmm,
     JaggedDenseBmm,  # noqa F401
     JaggedJaggedBmm,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import (  # noqa F401
     array_jagged_bmm_jagged_out,
     ArrayJaggedBmmNopadding,  # noqa F401
@@ -31,38 +28,31 @@ from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import (  # noqa F401
     triton_array_jagged_bmm_jagged_out,  # noqa F401
     triton_jagged_jagged_bmm_jagged_out,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import (  # noqa F401
     jagged_dense_elementwise_add,
     JaggedDenseAdd,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_mul_jagged_out import (  # noqa F401
     jagged_dense_elementwise_mul_jagged_out,
     JaggedDenseElementwiseMul,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import (  # noqa F401
     jagged_dense_flash_attention,
     JaggedDenseFlashAttention,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_flash_attention_basic import (  # noqa F401
     jagged_flash_attention_basic,
     JaggedFlashAttentionBasic,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_self_substraction_jagged_out import (
     triton_jagged_self_substraction_jagged_out,
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_softmax import (  # noqa F401
     jagged2_softmax,
     Jagged2Softmax,  # noqa F401
     jagged_softmax,
     JaggedSoftmax,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_multi_head_jagged_flash_attention import (  # noqa F401
     multi_head_jagged_flash_attention,
     MultiHeadJaggedFlashAttention,  # noqa F401

fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py
@@ -6,7 +6,6 @@

 # pyre-unsafe

-from typing import Tuple

 import torch
 import triton
@@ -196,9 +195,9 @@ class Jagged2ToPaddedDense(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx, grad_output: torch.Tensor
-    ) -> Tuple[torch.Tensor, None, None, None]:
+    ) -> tuple[torch.Tensor, None, None, None]:
         max_length = ctx.max_length
-        (lengths, offsets) = ctx.saved_tensors
+        lengths, offsets = ctx.saved_tensors
         grad_in = padded_dense_to_jagged2_fwd(grad_output, lengths, offsets, max_length)
         return (grad_in, None, None, None)


fbgemm_gpu/sll/triton/triton_jagged_bmm.py
@@ -326,7 +326,7 @@ class JaggedDenseBmm(torch.autograd.Function):

         # logging.info(f"Jagged bmm backward called")

-        (x, y, x_offsets) = ctx.saved_tensors
+        x, y, x_offsets = ctx.saved_tensors
         N = ctx.N
         grad_x = triton_jagged_dense_bmm(
             grad_output, y.permute(0, 2, 1), x_offsets, N, allow_tf32=ctx.allow_tf32
@@ -369,7 +369,7 @@ class JaggedJaggedBmm(torch.autograd.Function):
         # dXT = dZ * YT -> dX = Y * dZT
         # dY = X * dZ -> X * dZ
         """
-        (x, y, offsets) = ctx.saved_tensors
+        x, y, offsets = ctx.saved_tensors
         N = ctx.N
         grad_x = triton_jagged_dense_bmm(
             y, grad_output.permute(0, 2, 1), offsets, N, allow_tf32=ctx.allow_tf32

fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py
@@ -8,6 +8,7 @@

 import torch

+# fmt:skip
 from fbgemm_gpu.triton.jagged.triton_jagged_tensor_ops import (
     dense_to_jagged,
     jagged_to_dense,

fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py
@@ -6,7 +6,6 @@

 # pyre-unsafe

-from typing import Tuple

 import torch
 import triton
@@ -171,7 +170,7 @@ def jagged_dense_flash_attention_fwd(
    jagged_offsets,
    max_seq_len,
    allow_tf32=False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Q: jagged tensor, [sum_B, D]
    K: dense tensor, [B, D, T]
@@ -192,7 +191,7 @@ def jagged_dense_flash_attention_fwd(
     assert Q.size() == V.size(), "incompatible dimensions for Q and V"
     assert jagged_offsets.is_contiguous(), "jagged_offsets must be contiguous"

-    (B, D, T) = K.size()
+    B, D, T = K.size()
     assert D > 0 and (D & (D - 1)) == 0, "D needs to be a power of two"

     attn_out = torch.zeros(B, T, D, dtype=Q.dtype, device=Q.device)
@@ -650,7 +649,7 @@ def jagged_dense_flash_attention_bwd(
    jagged_offsets,
    max_seq_len,
    allow_tf32=False,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Q: jagged tensor, [sum_B, D]
    K: dense tensor, [B, D, T]
@@ -668,7 +667,7 @@ def jagged_dense_flash_attention_bwd(
     if not do.is_contiguous():
         do = do.contiguous()

-    (B, D, T) = K.size()
+    B, D, T = K.size()
     BLOCK_T = 32
     BLOCK_L = 32
     BLOCK_D = D
@@ -812,7 +811,7 @@ class JaggedDenseFlashAttention(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx, do: torch.Tensor
-    ) -> Tuple[
+    ) -> tuple[
         torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, None, None, None
     ]:
         Q, K, V, attn_bias, jagged_offsets, lse, attn_out = ctx.saved_tensors

fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py
@@ -6,7 +6,6 @@

 # pyre-unsafe

-from typing import Tuple

 import torch
 import triton
@@ -607,7 +606,7 @@ class JaggedFlashAttentionBasic(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx, grad_output: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None, None]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None, None]:
         (
             jagged_Q,
             jagged_K,

fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py
@@ -6,7 +6,6 @@

 # pyre-unsafe

-from typing import Tuple

 import torch
 import triton
@@ -688,7 +687,7 @@ class MultiHeadJaggedFlashAttention(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx, grad_output: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None]:
         (
             jagged_Q,
             jagged_K,