fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. fbgemm_gpu/__init__.py +112 -19
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
  4. fbgemm_gpu/config/feature_list.py +7 -1
  5. fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
  6. fbgemm_gpu/docs/sparse_ops.py +118 -0
  7. fbgemm_gpu/docs/target.default.json.py +6 -0
  8. fbgemm_gpu/enums.py +3 -4
  9. fbgemm_gpu/fbgemm.so +0 -0
  10. fbgemm_gpu/fbgemm_gpu_config.so +0 -0
  11. fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
  12. fbgemm_gpu/fbgemm_gpu_py.so +0 -0
  13. fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
  14. fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
  15. fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
  16. fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
  17. fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
  18. fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
  19. fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
  20. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
  21. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
  22. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
  23. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
  24. fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
  25. fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
  26. fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
  27. fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
  28. fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
  29. fbgemm_gpu/quantize/__init__.py +2 -0
  30. fbgemm_gpu/quantize/quantize_ops.py +1 -0
  31. fbgemm_gpu/quantize_comm.py +29 -12
  32. fbgemm_gpu/quantize_utils.py +88 -8
  33. fbgemm_gpu/runtime_monitor.py +9 -5
  34. fbgemm_gpu/sll/__init__.py +3 -0
  35. fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
  36. fbgemm_gpu/sll/triton/__init__.py +0 -10
  37. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
  38. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
  39. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
  40. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
  41. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
  42. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
  43. fbgemm_gpu/sparse_ops.py +190 -54
  44. fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
  45. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
  46. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
  47. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
  48. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
  49. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
  50. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
  51. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
  52. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
  53. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
  54. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
  55. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
  56. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
  57. fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
  58. fbgemm_gpu/split_embedding_configs.py +134 -37
  59. fbgemm_gpu/split_embedding_inference_converter.py +7 -6
  60. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
  61. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
  62. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
  63. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
  64. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
  65. fbgemm_gpu/tbe/bench/__init__.py +6 -1
  66. fbgemm_gpu/tbe/bench/bench_config.py +14 -3
  67. fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
  68. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
  69. fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
  70. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
  71. fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
  72. fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
  73. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
  74. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
  75. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
  76. fbgemm_gpu/tbe/bench/utils.py +129 -5
  77. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
  78. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
  79. fbgemm_gpu/tbe/ssd/common.py +1 -0
  80. fbgemm_gpu/tbe/ssd/inference.py +15 -15
  81. fbgemm_gpu/tbe/ssd/training.py +1292 -267
  82. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
  83. fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
  84. fbgemm_gpu/tbe/utils/offsets.py +6 -6
  85. fbgemm_gpu/tbe/utils/quantize.py +8 -8
  86. fbgemm_gpu/tbe/utils/requests.py +15 -15
  87. fbgemm_gpu/tbe_input_multiplexer.py +10 -11
  88. fbgemm_gpu/triton/common.py +0 -1
  89. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
  90. fbgemm_gpu/triton/quantize.py +14 -9
  91. fbgemm_gpu/utils/filestore.py +6 -2
  92. fbgemm_gpu/utils/torch_library.py +2 -2
  93. fbgemm_gpu/utils/writeback_util.py +124 -0
  94. fbgemm_gpu/uvm.py +1 -0
  95. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
  96. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
  97. fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
  98. fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
  99. list_versions/cli_run.py +161 -0
  100. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
  101. fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
  102. {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0

fbgemm_gpu/quantize_utils.py
@@ -10,11 +10,34 @@
 import logging
 from typing import Optional, Union
 
-import torch
+import torch  # isort:skip
 
-from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
+import fbgemm_gpu
+from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.triton.common import RoundingMode
 from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
 
+try:
+    if torch.cuda.is_available():
+        from fbgemm_gpu.triton import quantize_mx4
+        from fbgemm_gpu.triton.quantize import triton_dequantize_mx4
+except Exception:
+    pass
+
+
+try:
+    # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+    open_source = bool(getattr(fbgemm_gpu, "open_source", False))
+except NotImplementedError:
+    open_source = False
+
+# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+if not open_source:
+    from mtia.kernels.triton.mx4.quantize import (
+        triton_dequantize_mx4 as mtia_dequantize_mx4,
+        triton_quantize_mx4 as mtia_quantize_mx4,
+    )
+
 logger: logging.Logger = logging.getLogger()
 
 try:
@@ -60,7 +83,7 @@ def fp32_to_mx4(
     if rounding_mode is None:
         rounding_mode = RoundingMode.even
 
-    if not tensor.is_cuda:
+    if not tensor.is_cuda and not tensor.is_mtia:
         return py_quantize_mx4(
             tensor,
             group_size,
@@ -71,6 +94,15 @@
         )
 
     if use_triton:
+        if tensor.is_mtia:
+            return mtia_quantize_mx4(
+                tensor,
+                group_size,
+                ebits=ebits,
+                mbits=mbits,
+                rounding_mode=rounding_mode,
+                stochastic_casting=stochastic_casting,
+            )
         return quantize_mx4(
             tensor,
             group_size,
@@ -102,23 +134,71 @@
 ) -> torch.Tensor:
     """Dequantize an MX4 tensor to FP32 with triton or native cuda impl.
 
+    This function is kept for backward compatibility and always returns FP32.
+    For BF16 output, use mx4_to_float() with output_dtype=SparseType.BF16.
+    """
+    return mx4_to_float(
+        tensor,
+        group_size,
+        use_triton,
+        ebits,
+        mbits,
+        output_dtype=None,  # None = FP32 default for backward compatibility
+    )
+
+
+def mx4_to_float(
+    tensor: torch.Tensor,
+    group_size: int = 32,
+    use_triton: bool = True,
+    ebits: int = 2,
+    mbits: int = 1,
+    output_dtype: Optional[SparseType] = None,
+) -> torch.Tensor:
+    """Dequantize an MX4 tensor to FP32 or BF16 with triton or native cuda impl.
+
     Args:
         tensor (torch.Tensor): MX4 packed tensor with total elements (M / 2 + M / groupsize)
         group_size (int): Compute scale in chunks of group_size.
         use_triton (bool): If set, use triton quantization, otherwise cuda.
         ebits (int): Number of exponent bits in target mx4 format.
         mbits (int): Number of mantissa bits in target mx4 format.
+        output_dtype (Optional[SparseType]): Output dtype (FP32 or BF16).
+            Defaults to None (FP32) for backward compatibility.
 
     Return:
-        output: FP32 tensor with total elements (M).
+        output: Tensor with dtype matching output_dtype and total elements (M).
     """
+    # Validate output_dtype
+    supported_dtypes = {SparseType.FP32, SparseType.BF16}
+    if output_dtype is not None and output_dtype not in supported_dtypes:
+        raise ValueError(
+            f"output_dtype must be one of {supported_dtypes}, got {output_dtype}. "
+            f"FP16 is not supported due to potential overflow/underflow with MX4's wide exponent range. "
+            f"Use BF16 for memory savings with same dynamic range as FP32."
+        )
+
+    target_dtype = (
+        output_dtype.as_dtype() if output_dtype is not None else torch.float32
+    )
+
     # Accelerated MX4 dequantize is only available on cuda, if input is on cpu, use python.
-    if not tensor.is_cuda:
-        return py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+    if not tensor.is_cuda and not tensor.is_mtia:
+        result = py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+        return result.to(target_dtype) if output_dtype is not None else result
     if use_triton:
-        return dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+        if tensor.is_mtia:
+            return mtia_dequantize_mx4(
+                tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+            )
+        return triton_dequantize_mx4(
+            tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+        )
     else:
-        return torch.ops.fbgemm.dequantize_mx_cuda(tensor.flatten(), group_size)
+        output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
+        return torch.ops.fbgemm.dequantize_mx_cuda(
+            tensor.flatten(), group_size, output_dtype_int
+        )
 
 
 def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
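
Usage note (not part of the diff): the hunks above add MTIA dispatch and a new mx4_to_float entry point that accepts an output_dtype. Below is a minimal sketch of the CPU fallback path, assuming these hunks belong to fbgemm_gpu/quantize_utils.py and that fp32_to_mx4 keeps its existing defaults; anything not visible in the diff is illustrative.

import torch

from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_float, mx4_to_fp32
from fbgemm_gpu.split_embedding_configs import SparseType

# Quantize 1024 FP32 values into MX4 with groups of 32. CPU tensors take the
# reference Python kernels, per the `not tensor.is_cuda and not tensor.is_mtia` branch.
x = torch.randn(1024, dtype=torch.float32)
packed = fp32_to_mx4(x, group_size=32)

y_fp32 = mx4_to_fp32(packed, group_size=32)  # legacy entry point, always FP32
y_bf16 = mx4_to_float(packed, group_size=32, output_dtype=SparseType.BF16)  # new BF16 output

assert y_fp32.dtype == torch.float32 and y_fp32.numel() == x.numel()
assert y_bf16.dtype == torch.bfloat16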

fbgemm_gpu/runtime_monitor.py
@@ -12,7 +12,7 @@ import logging
 from collections import deque
 from dataclasses import dataclass
 from types import TracebackType
-from typing import Callable, Deque, Optional, Tuple, Type, TypeVar
+from typing import Callable, Optional, TypeVar
 
 import torch
 
@@ -49,6 +49,7 @@ class TBEStatsReporter(abc.ABC):
         embedding_id: str = "",
         tbe_id: str = "",
         time_unit: str = "ms",
+        enable_tb_metrics: bool = False,
     ) -> None:
         """
         Report the duration of a timed event.
@@ -63,6 +64,7 @@ class TBEStatsReporter(abc.ABC):
         data_bytes: int,
         embedding_id: str = "",
         tbe_id: str = "",
+        enable_tb_metrics: bool = False,
     ) -> None:
         """
         Report the size of some data amount.
@@ -89,9 +91,10 @@ class StdLogStatsReporter(TBEStatsReporter):
         embedding_id: str = "",
         tbe_id: str = "",
         time_unit: str = "ms",
+        enable_tb_metrics: bool = False,
     ) -> None:
         logging.info(
-            f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} took {duration_ms} {time_unit}"
+            f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} took {duration_ms} {time_unit} with {enable_tb_metrics}"
         )
 
     def report_data_amount(
@@ -101,9 +104,10 @@ class StdLogStatsReporter(TBEStatsReporter):
         data_bytes: int,
         embedding_id: str = "",
         tbe_id: str = "",
+        enable_tb_metrics: bool = False,
     ) -> None:
         logging.info(
-            f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} used {data_bytes} bytes"
+            f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} used {data_bytes} bytes with {enable_tb_metrics}"
         )
 
     def __repr__(self) -> str:
@@ -167,7 +171,7 @@ class AsyncSeriesTimerRecordedContext:
 
     def __exit__(
         self,
-        exc_type: Optional[Type[BaseException]],
+        exc_type: Optional[type[BaseException]],
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> None:
@@ -187,7 +191,7 @@ class AsyncSeriesTimer:
     """
 
     def __init__(self, report_functor: Callable[[T, float], None]) -> None:
-        self._events_queue: Deque[Tuple[torch.cuda.Event, torch.cuda.Event, T]] = (
+        self._events_queue: deque[tuple[torch.cuda.Event, torch.cuda.Event, T]] = (
             deque()
         )
         self._active_start_event: Optional[torch.cuda.Event] = None
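
Usage note (not part of the diff): every reporter method now takes an enable_tb_metrics keyword, so custom TBEStatsReporter subclasses must accept it to stay signature-compatible. A hedged sketch follows; the parameters ahead of embedding_id are inferred from the log format strings above, and should_report is assumed to be an additional abstract hook on the base class.

from fbgemm_gpu.runtime_monitor import TBEStatsReporter


class InMemoryStatsReporter(TBEStatsReporter):
    """Toy reporter that collects events in a list instead of logging them."""

    def __init__(self) -> None:
        self.events: list[dict] = []

    def should_report(self, iteration_step: int) -> bool:
        # Assumed abstract hook: report on every step in this sketch.
        return True

    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        embedding_id: str = "",
        tbe_id: str = "",
        time_unit: str = "ms",
        enable_tb_metrics: bool = False,
    ) -> None:
        self.events.append(
            {
                "step": iteration_step,
                "event": event_name,
                "duration": duration_ms,
                "unit": time_unit,
                "tb_metrics": enable_tb_metrics,
            }
        )

    def report_data_amount(
        self,
        iteration_step: int,
        event_name: str,
        data_bytes: int,
        embedding_id: str = "",
        tbe_id: str = "",
        enable_tb_metrics: bool = False,
    ) -> None:
        self.events.append(
            {"step": iteration_step, "event": event_name, "bytes": data_bytes}
        )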

fbgemm_gpu/sll/__init__.py
@@ -9,12 +9,14 @@
 
 import torch
 
+# fmt:skip
 from fbgemm_gpu.sll.cpu import op_registrations as sll_cpu_registrations
 from fbgemm_gpu.sll.meta import op_registrations as sll_meta_registrations
 from fbgemm_gpu.utils import TorchLibraryFragment
 
 lib = TorchLibraryFragment("fbgemm")
 
+# fmt:off
 lib.define(
     """sll_jagged_dense_bmm(
         Tensor x,
@@ -170,6 +172,7 @@ lib.define(
     ) -> Tensor
     """
 )
+# fmt:on
 
 # NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same
 # function however, this is not ideal because in the inference case, we don't
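
Note (not part of the diff): the added # fmt:off / # fmt:on pair fences the hand-formatted operator schema strings off from the code formatter (assumed here to be Black), and # fmt:skip marks a spot to leave untouched. A generic illustration of the off/on directives, unrelated to this package:

# fmt: off
IDENTITY_2X2 = [
    1, 0,
    0, 1,
]
# fmt: on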

fbgemm_gpu/sll/cpu/cpu_sll.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
-from typing import Any, Tuple
+from typing import Any
 
 import torch
 
@@ -65,7 +65,7 @@ class JaggedDenseBmmCPU(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx: Any, grad_output: torch.Tensor  # pyre-ignore
-    ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
+    ) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
         """
         # X = [Sum_B, D]
         # Y = [B, D, T]
@@ -73,7 +73,7 @@ class JaggedDenseBmmCPU(torch.autograd.Function):
         # dX = dZ * YT # [Sum_B, T] * [B, T, D] = [Sum_B, D]
        # dY = XT * dZ # [D, sum_B] * [sum_B, T] = [D, B, T]
         """
-        (x, y, x_offsets) = ctx.saved_tensors
+        x, y, x_offsets = ctx.saved_tensors
         N = ctx.N
         grad_x = cpu_jagged_dense_bmm_kernel(
             grad_output, y.permute(0, 2, 1), x_offsets, N
@@ -128,7 +128,7 @@ class JaggedJaggedBmm(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx: Any, grad_output: torch.Tensor  # pyre-ignore
-    ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
+    ) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
         """
         # X = [Sum_B, D]
         # Y = [Sum_B, T]
@@ -136,7 +136,7 @@ class JaggedJaggedBmm(torch.autograd.Function):
         # dXT = dZ * YT -> dX = Y * dZT
         # dY = X * dZ -> X * dZ
         """
-        (x, y, offsets) = ctx.saved_tensors
+        x, y, offsets = ctx.saved_tensors
         N = ctx.N
         grad_x = cpu_jagged_dense_bmm_kernel(
             y, grad_output.permute(0, 2, 1), offsets, N
@@ -172,7 +172,7 @@ def cpu_dense_jagged_cat_jagged_out(
     b: torch.Tensor,
     b_offsets: torch.Tensor,
     max_seq_len: int,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     assert a.size(0) == b_offsets.size(0) - 1
     c = torch.empty(b.size(0) + a.size(0), dtype=a.dtype, device=a.device)
     c_offsets = b_offsets + torch.arange(
@@ -368,7 +368,7 @@ class JaggedSoftmaxCPU(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx: Any, grad_output: torch.Tensor  # pyre-ignore
-    ) -> Tuple[torch.Tensor, None, None]:
+    ) -> tuple[torch.Tensor, None, None]:
         y, x_offsets = ctx.saved_tensors
 
         B = x_offsets.size(0) - 1
@@ -923,7 +923,7 @@ class JaggedDenseAddCPU(torch.autograd.Function):
     def backward(
         ctx,  # pyre-ignore
         grad_output: torch.Tensor,
-    ) -> Tuple[torch.Tensor, None, torch.Tensor, None]:
+    ) -> tuple[torch.Tensor, None, torch.Tensor, None]:
         (offsets,) = ctx.saved_tensors
         grad_dense = torch.ops.fbgemm.jagged_to_padded_dense(
             grad_output, [offsets], [ctx.max_seq_len]
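
Note (not part of the diff): the recurring Tuple -> tuple rewrites in this file, like the Deque/Type -> deque/type rewrites in runtime_monitor.py above, switch to PEP 585 built-in generics so the typing aliases can be dropped. A generic before/after sketch of the pattern:

# Before: typing aliases (pre-PEP 585 style)
from typing import Deque, Tuple

def bounds_old(q: Deque[int]) -> Tuple[int, int]:
    return q[0], q[-1]

# After: built-in generics (Python 3.9+), no typing import needed
from collections import deque

def bounds_new(q: deque[int]) -> tuple[int, int]:
    return q[0], q[-1]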

fbgemm_gpu/sll/triton/__init__.py
@@ -10,19 +10,16 @@
 from fbgemm_gpu.sll.triton.triton_dense_jagged_cat_jagged_out import (
     dense_jagged_cat_jagged_out,
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged2_to_padded_dense import (  # noqa F401
     jagged2_to_padded_dense,
     Jagged2ToPaddedDense,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_bmm import (  # noqa F401
     jagged_dense_bmm,
     jagged_jagged_bmm,
     JaggedDenseBmm,  # noqa F401
     JaggedJaggedBmm,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import (  # noqa F401
     array_jagged_bmm_jagged_out,
     ArrayJaggedBmmNopadding,  # noqa F401
@@ -31,38 +28,31 @@ from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import (  # noqa F401
     triton_array_jagged_bmm_jagged_out,  # noqa F401
     triton_jagged_jagged_bmm_jagged_out,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import (  # noqa F401
     jagged_dense_elementwise_add,
     JaggedDenseAdd,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_mul_jagged_out import (  # noqa F401
     jagged_dense_elementwise_mul_jagged_out,
     JaggedDenseElementwiseMul,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import (  # noqa F401
     jagged_dense_flash_attention,
     JaggedDenseFlashAttention,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_flash_attention_basic import (  # noqa F401
     jagged_flash_attention_basic,
     JaggedFlashAttentionBasic,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_self_substraction_jagged_out import (
     triton_jagged_self_substraction_jagged_out,
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_softmax import (  # noqa F401
     jagged2_softmax,
     Jagged2Softmax,  # noqa F401
     jagged_softmax,
     JaggedSoftmax,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_multi_head_jagged_flash_attention import (  # noqa F401
     multi_head_jagged_flash_attention,
     MultiHeadJaggedFlashAttention,  # noqa F401

fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py
@@ -6,7 +6,6 @@
 
 # pyre-unsafe
 
-from typing import Tuple
 
 import torch
 import triton
@@ -196,9 +195,9 @@ class Jagged2ToPaddedDense(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx, grad_output: torch.Tensor
-    ) -> Tuple[torch.Tensor, None, None, None]:
+    ) -> tuple[torch.Tensor, None, None, None]:
         max_length = ctx.max_length
-        (lengths, offsets) = ctx.saved_tensors
+        lengths, offsets = ctx.saved_tensors
         grad_in = padded_dense_to_jagged2_fwd(grad_output, lengths, offsets, max_length)
         return (grad_in, None, None, None)
 

fbgemm_gpu/sll/triton/triton_jagged_bmm.py
@@ -326,7 +326,7 @@ class JaggedDenseBmm(torch.autograd.Function):
 
         # logging.info(f"Jagged bmm backward called")
 
-        (x, y, x_offsets) = ctx.saved_tensors
+        x, y, x_offsets = ctx.saved_tensors
         N = ctx.N
         grad_x = triton_jagged_dense_bmm(
             grad_output, y.permute(0, 2, 1), x_offsets, N, allow_tf32=ctx.allow_tf32
@@ -369,7 +369,7 @@ class JaggedJaggedBmm(torch.autograd.Function):
         # dXT = dZ * YT -> dX = Y * dZT
         # dY = X * dZ -> X * dZ
         """
-        (x, y, offsets) = ctx.saved_tensors
+        x, y, offsets = ctx.saved_tensors
         N = ctx.N
         grad_x = triton_jagged_dense_bmm(
             y, grad_output.permute(0, 2, 1), offsets, N, allow_tf32=ctx.allow_tf32

fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py
@@ -8,6 +8,7 @@
 
 import torch
 
+# fmt:skip
 from fbgemm_gpu.triton.jagged.triton_jagged_tensor_ops import (
     dense_to_jagged,
     jagged_to_dense,

fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py
@@ -6,7 +6,6 @@
 
 # pyre-unsafe
 
-from typing import Tuple
 
 import torch
 import triton
@@ -171,7 +170,7 @@ def jagged_dense_flash_attention_fwd(
     jagged_offsets,
     max_seq_len,
     allow_tf32=False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Q: jagged tensor, [sum_B, D]
     K: dense tensor, [B, D, T]
@@ -192,7 +191,7 @@ def jagged_dense_flash_attention_fwd(
     assert Q.size() == V.size(), "incompatible dimensions for Q and V"
     assert jagged_offsets.is_contiguous(), "jagged_offsets must be contiguous"
 
-    (B, D, T) = K.size()
+    B, D, T = K.size()
     assert D > 0 and (D & (D - 1)) == 0, "D needs to be a power of two"
 
     attn_out = torch.zeros(B, T, D, dtype=Q.dtype, device=Q.device)
@@ -650,7 +649,7 @@ def jagged_dense_flash_attention_bwd(
     jagged_offsets,
     max_seq_len,
     allow_tf32=False,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Q: jagged tensor, [sum_B, D]
     K: dense tensor, [B, D, T]
@@ -668,7 +667,7 @@ def jagged_dense_flash_attention_bwd(
     if not do.is_contiguous():
         do = do.contiguous()
 
-    (B, D, T) = K.size()
+    B, D, T = K.size()
     BLOCK_T = 32
     BLOCK_L = 32
     BLOCK_D = D
@@ -812,7 +811,7 @@ class JaggedDenseFlashAttention(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx, do: torch.Tensor
-    ) -> Tuple[
+    ) -> tuple[
         torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, None, None, None
     ]:
         Q, K, V, attn_bias, jagged_offsets, lse, attn_out = ctx.saved_tensors

fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py
@@ -6,7 +6,6 @@
 
 # pyre-unsafe
 
-from typing import Tuple
 
 import torch
 import triton
@@ -607,7 +606,7 @@ class JaggedFlashAttentionBasic(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx, grad_output: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None, None]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None, None]:
         (
             jagged_Q,
             jagged_K,

fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py
@@ -6,7 +6,6 @@
 
 # pyre-unsafe
 
-from typing import Tuple
 
 import torch
 import triton
@@ -688,7 +687,7 @@ class MultiHeadJaggedFlashAttention(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx, grad_output: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None]:
         (
             jagged_Q,
             jagged_K,