fbgemm-gpu-nightly-cpu 2025.7.19__cp311-cp311-manylinux_2_28_aarch64.whl → 2026.1.29__cp311-cp311-manylinux_2_28_aarch64.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- fbgemm_gpu/__init__.py +112 -19
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +3 -3
- fbgemm_gpu/config/feature_list.py +7 -1
- fbgemm_gpu/docs/jagged_tensor_ops.py +0 -1
- fbgemm_gpu/docs/sparse_ops.py +118 -0
- fbgemm_gpu/docs/target.default.json.py +6 -0
- fbgemm_gpu/enums.py +3 -4
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/fbgemm_gpu_config.so +0 -0
- fbgemm_gpu/fbgemm_gpu_embedding_inplace_ops.so +0 -0
- fbgemm_gpu/fbgemm_gpu_py.so +0 -0
- fbgemm_gpu/fbgemm_gpu_sparse_async_cumsum.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_cache.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_common.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_index_select.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_inference.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_optimizers.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_dense.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_gwd.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_pt2.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_split_host.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_backward_vbe.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_training_forward.so +0 -0
- fbgemm_gpu/fbgemm_gpu_tbe_utils.so +0 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +5 -4
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +4 -4
- fbgemm_gpu/quantize/__init__.py +2 -0
- fbgemm_gpu/quantize/quantize_ops.py +1 -0
- fbgemm_gpu/quantize_comm.py +29 -12
- fbgemm_gpu/quantize_utils.py +88 -8
- fbgemm_gpu/runtime_monitor.py +9 -5
- fbgemm_gpu/sll/__init__.py +3 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +8 -8
- fbgemm_gpu/sll/triton/__init__.py +0 -10
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +2 -3
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +2 -2
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +1 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +5 -6
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +1 -2
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +1 -2
- fbgemm_gpu/sparse_ops.py +190 -54
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/__init__.py +12 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py +14 -7
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_args_ssd.py +2 -0
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_lars_sgd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_none.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_adam.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_partial_rowwise_lamb.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_ssd.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_rowwise_adagrad_with_counter.py +12 -5
- fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_sgd.py +12 -5
- fbgemm_gpu/split_embedding_configs.py +134 -37
- fbgemm_gpu/split_embedding_inference_converter.py +7 -6
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +117 -24
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +37 -37
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +764 -123
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +0 -1
- fbgemm_gpu/tbe/bench/__init__.py +6 -1
- fbgemm_gpu/tbe/bench/bench_config.py +14 -3
- fbgemm_gpu/tbe/bench/bench_runs.py +163 -14
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +5 -2
- fbgemm_gpu/tbe/bench/eeg_cli.py +3 -3
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +3 -2
- fbgemm_gpu/tbe/bench/eval_compression.py +3 -3
- fbgemm_gpu/tbe/bench/tbe_data_config.py +115 -197
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +332 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +108 -8
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +15 -8
- fbgemm_gpu/tbe/bench/utils.py +129 -5
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +22 -19
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +4 -4
- fbgemm_gpu/tbe/ssd/common.py +1 -0
- fbgemm_gpu/tbe/ssd/inference.py +15 -15
- fbgemm_gpu/tbe/ssd/training.py +1292 -267
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +2 -3
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +198 -42
- fbgemm_gpu/tbe/utils/offsets.py +6 -6
- fbgemm_gpu/tbe/utils/quantize.py +8 -8
- fbgemm_gpu/tbe/utils/requests.py +15 -15
- fbgemm_gpu/tbe_input_multiplexer.py +10 -11
- fbgemm_gpu/triton/common.py +0 -1
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +11 -11
- fbgemm_gpu/triton/quantize.py +14 -9
- fbgemm_gpu/utils/filestore.py +6 -2
- fbgemm_gpu/utils/torch_library.py +2 -2
- fbgemm_gpu/utils/writeback_util.py +124 -0
- fbgemm_gpu/uvm.py +1 -0
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/METADATA +2 -2
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/RECORD +135 -0
- fbgemm_gpu_nightly_cpu-2026.1.29.dist-info/top_level.txt +2 -0
- fbgemm_gpu/docs/version.py → list_versions/__init__.py +5 -4
- list_versions/cli_run.py +161 -0
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/RECORD +0 -131
- fbgemm_gpu_nightly_cpu-2025.7.19.dist-info/top_level.txt +0 -1
- {fbgemm_gpu_nightly_cpu-2025.7.19.dist-info → fbgemm_gpu_nightly_cpu-2026.1.29.dist-info}/WHEEL +0 -0
fbgemm_gpu/quantize_utils.py
CHANGED
@@ -10,11 +10,34 @@
 import logging
 from typing import Optional, Union
 
-import torch
+import torch  # isort:skip
 
-
+import fbgemm_gpu
+from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.triton.common import RoundingMode
 from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
 
+try:
+    if torch.cuda.is_available():
+        from fbgemm_gpu.triton import quantize_mx4
+        from fbgemm_gpu.triton.quantize import triton_dequantize_mx4
+except Exception:
+    pass
+
+
+try:
+    # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+    open_source = bool(getattr(fbgemm_gpu, "open_source", False))
+except NotImplementedError:
+    open_source = False
+
+# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+if not open_source:
+    from mtia.kernels.triton.mx4.quantize import (
+        triton_dequantize_mx4 as mtia_dequantize_mx4,
+        triton_quantize_mx4 as mtia_quantize_mx4,
+    )
+
 logger: logging.Logger = logging.getLogger()
 
 try:
@@ -60,7 +83,7 @@ def fp32_to_mx4(
     if rounding_mode is None:
         rounding_mode = RoundingMode.even
 
-    if not tensor.is_cuda:
+    if not tensor.is_cuda and not tensor.is_mtia:
         return py_quantize_mx4(
             tensor,
             group_size,
@@ -71,6 +94,15 @@ def fp32_to_mx4(
         )
 
     if use_triton:
+        if tensor.is_mtia:
+            return mtia_quantize_mx4(
+                tensor,
+                group_size,
+                ebits=ebits,
+                mbits=mbits,
+                rounding_mode=rounding_mode,
+                stochastic_casting=stochastic_casting,
+            )
         return quantize_mx4(
             tensor,
             group_size,
@@ -102,23 +134,71 @@ def mx4_to_fp32(
 ) -> torch.Tensor:
     """Dequantize an MX4 tensor to FP32 with triton or native cuda impl.
 
+    This function is kept for backward compatibility and always returns FP32.
+    For BF16 output, use mx4_to_float() with output_dtype=SparseType.BF16.
+    """
+    return mx4_to_float(
+        tensor,
+        group_size,
+        use_triton,
+        ebits,
+        mbits,
+        output_dtype=None,  # None = FP32 default for backward compatibility
+    )
+
+
+def mx4_to_float(
+    tensor: torch.Tensor,
+    group_size: int = 32,
+    use_triton: bool = True,
+    ebits: int = 2,
+    mbits: int = 1,
+    output_dtype: Optional[SparseType] = None,
+) -> torch.Tensor:
+    """Dequantize an MX4 tensor to FP32 or BF16 with triton or native cuda impl.
+
     Args:
         tensor (torch.Tensor): MX4 packed tensor with total elements (M / 2 + M / groupsize)
         group_size (int): Compute scale in chunks of group_size.
         use_triton (bool): If set, use triton quantization, otherwise cuda.
         ebits (int): Number of exponent bits in target mx4 format.
         mbits (int): Number of mantissa bits in target mx4 format.
+        output_dtype (Optional[SparseType]): Output dtype (FP32 or BF16).
+            Defaults to None (FP32) for backward compatibility.
 
     Return:
-        output:
+        output: Tensor with dtype matching output_dtype and total elements (M).
     """
+    # Validate output_dtype
+    supported_dtypes = {SparseType.FP32, SparseType.BF16}
+    if output_dtype is not None and output_dtype not in supported_dtypes:
+        raise ValueError(
+            f"output_dtype must be one of {supported_dtypes}, got {output_dtype}. "
+            f"FP16 is not supported due to potential overflow/underflow with MX4's wide exponent range. "
+            f"Use BF16 for memory savings with same dynamic range as FP32."
+        )
+
+    target_dtype = (
+        output_dtype.as_dtype() if output_dtype is not None else torch.float32
+    )
+
     # Accelerated MX4 dequantize is only available on cuda, if input is on cpu, use python.
-    if not tensor.is_cuda:
-
+    if not tensor.is_cuda and not tensor.is_mtia:
+        result = py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+        return result.to(target_dtype) if output_dtype is not None else result
     if use_triton:
-
+        if tensor.is_mtia:
+            return mtia_dequantize_mx4(
+                tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+            )
+        return triton_dequantize_mx4(
+            tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+        )
    else:
-
+        output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
+        return torch.ops.fbgemm.dequantize_mx_cuda(
+            tensor.flatten(), group_size, output_dtype_int
+        )
 
 
 def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
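The hunks above add an MTIA dispatch path to fp32_to_mx4 and split dequantization into a backward-compatible mx4_to_fp32 wrapper plus a new mx4_to_float that can return BF16. A minimal round-trip sketch, assuming a CPU-only environment so the py_quantize_mx4/py_dequantize_mx4 fallback shown above is taken; the tensor size and group size are arbitrary choices for illustration:

import torch

from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_float, mx4_to_fp32
from fbgemm_gpu.split_embedding_configs import SparseType

# Quantize 1024 FP32 values to MX4 in groups of 32.
x = torch.randn(1024, dtype=torch.float32)
packed = fp32_to_mx4(x, group_size=32)

# Legacy entry point: always dequantizes back to FP32.
x_fp32 = mx4_to_fp32(packed, group_size=32)

# New entry point: request BF16 output; FP16 is rejected with a ValueError per the check above.
x_bf16 = mx4_to_float(packed, group_size=32, output_dtype=SparseType.BF16)

assert x_fp32.dtype == torch.float32 and x_bf16.dtype == torch.bfloat16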
fbgemm_gpu/runtime_monitor.py
CHANGED
@@ -12,7 +12,7 @@ import logging
 from collections import deque
 from dataclasses import dataclass
 from types import TracebackType
-from typing import Callable,
+from typing import Callable, Optional, TypeVar
 
 import torch
 
@@ -49,6 +49,7 @@ class TBEStatsReporter(abc.ABC):
         embedding_id: str = "",
         tbe_id: str = "",
         time_unit: str = "ms",
+        enable_tb_metrics: bool = False,
     ) -> None:
         """
         Report the duration of a timed event.
@@ -63,6 +64,7 @@ class TBEStatsReporter(abc.ABC):
         data_bytes: int,
         embedding_id: str = "",
         tbe_id: str = "",
+        enable_tb_metrics: bool = False,
     ) -> None:
         """
         Report the size of some data amount.
@@ -89,9 +91,10 @@ class StdLogStatsReporter(TBEStatsReporter):
         embedding_id: str = "",
         tbe_id: str = "",
         time_unit: str = "ms",
+        enable_tb_metrics: bool = False,
     ) -> None:
         logging.info(
-            f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} took {duration_ms} {time_unit}"
+            f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} took {duration_ms} {time_unit} with {enable_tb_metrics}"
         )
 
     def report_data_amount(
@@ -101,9 +104,10 @@ class StdLogStatsReporter(TBEStatsReporter):
         data_bytes: int,
         embedding_id: str = "",
         tbe_id: str = "",
+        enable_tb_metrics: bool = False,
     ) -> None:
         logging.info(
-            f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} used {data_bytes} bytes"
+            f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} used {data_bytes} bytes with {enable_tb_metrics}"
         )
 
     def __repr__(self) -> str:
@@ -167,7 +171,7 @@ class AsyncSeriesTimerRecordedContext:
 
     def __exit__(
         self,
-        exc_type: Optional[
+        exc_type: Optional[type[BaseException]],
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> None:
@@ -187,7 +191,7 @@ class AsyncSeriesTimer:
     """
 
     def __init__(self, report_functor: Callable[[T, float], None]) -> None:
-        self._events_queue:
+        self._events_queue: deque[tuple[torch.cuda.Event, torch.cuda.Event, T]] = (
             deque()
         )
         self._active_start_event: Optional[torch.cuda.Event] = None
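Both reporting hooks on TBEStatsReporter gain an enable_tb_metrics keyword. A hypothetical subclass, sketched only to show the updated signatures; the class name and log format are illustrative, and it assumes should_report, report_duration, and report_data_amount are the hooks to override:

import logging

from fbgemm_gpu.runtime_monitor import TBEStatsReporter


class PrintEveryNReporter(TBEStatsReporter):
    """Illustrative reporter that logs every 100th step (the interval is arbitrary)."""

    def should_report(self, iteration_step: int) -> bool:
        return iteration_step % 100 == 0

    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        embedding_id: str = "",
        tbe_id: str = "",
        time_unit: str = "ms",
        enable_tb_metrics: bool = False,
    ) -> None:
        logging.info(
            "step %d: %s took %.3f %s (tb_metrics=%s)",
            iteration_step, event_name, duration_ms, time_unit, enable_tb_metrics,
        )

    def report_data_amount(
        self,
        iteration_step: int,
        event_name: str,
        data_bytes: int,
        embedding_id: str = "",
        tbe_id: str = "",
        enable_tb_metrics: bool = False,
    ) -> None:
        logging.info(
            "step %d: %s used %d bytes (tb_metrics=%s)",
            iteration_step, event_name, data_bytes, enable_tb_metrics,
        )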
fbgemm_gpu/sll/__init__.py
CHANGED
@@ -9,12 +9,14 @@
 
 import torch
 
+# fmt:skip
 from fbgemm_gpu.sll.cpu import op_registrations as sll_cpu_registrations
 from fbgemm_gpu.sll.meta import op_registrations as sll_meta_registrations
 from fbgemm_gpu.utils import TorchLibraryFragment
 
 lib = TorchLibraryFragment("fbgemm")
 
+# fmt:off
 lib.define(
     """sll_jagged_dense_bmm(
         Tensor x,
@@ -170,6 +172,7 @@ lib.define(
     ) -> Tensor
     """
 )
+# fmt:on
 
 # NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same
 # function however, this is not ideal because in the inference case, we don't
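The # fmt:off / # fmt:on pair fences the hand-aligned lib.define(...) schema strings off from the auto-formatter so their layout survives reformatting. A standalone sketch of the same fencing pattern using torch.library directly; the "toy" namespace and schema below are made up for illustration and are not fbgemm ops:

import torch

# fmt: off
# The formatter leaves everything between the markers exactly as written,
# so the multi-line schema string keeps its hand-aligned layout.
toy_lib = torch.library.Library("toy", "DEF")  # hypothetical namespace
toy_lib.define(
    """toy_scale(
        Tensor x,
        float alpha
    ) -> Tensor
    """
)
# fmt: on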
fbgemm_gpu/sll/cpu/cpu_sll.py
CHANGED
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
-from typing import Any
+from typing import Any
 
 import torch
 
@@ -65,7 +65,7 @@ class JaggedDenseBmmCPU(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx: Any, grad_output: torch.Tensor  # pyre-ignore
-    ) ->
+    ) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
         """
         # X = [Sum_B, D]
         # Y = [B, D, T]
@@ -73,7 +73,7 @@ class JaggedDenseBmmCPU(torch.autograd.Function):
         # dX = dZ * YT # [Sum_B, T] * [B, T, D] = [Sum_B, D]
         # dY = XT * dZ # [D, sum_B] * [sum_B, T] = [D, B, T]
         """
-
+        x, y, x_offsets = ctx.saved_tensors
         N = ctx.N
         grad_x = cpu_jagged_dense_bmm_kernel(
             grad_output, y.permute(0, 2, 1), x_offsets, N
@@ -128,7 +128,7 @@ class JaggedJaggedBmm(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx: Any, grad_output: torch.Tensor  # pyre-ignore
-    ) ->
+    ) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
         """
         # X = [Sum_B, D]
         # Y = [Sum_B, T]
@@ -136,7 +136,7 @@ class JaggedJaggedBmm(torch.autograd.Function):
         # dXT = dZ * YT -> dX = Y * dZT
         # dY = X * dZ -> X * dZ
         """
-
+        x, y, offsets = ctx.saved_tensors
         N = ctx.N
         grad_x = cpu_jagged_dense_bmm_kernel(
             y, grad_output.permute(0, 2, 1), offsets, N
@@ -172,7 +172,7 @@ def cpu_dense_jagged_cat_jagged_out(
     b: torch.Tensor,
     b_offsets: torch.Tensor,
     max_seq_len: int,
-) ->
+) -> tuple[torch.Tensor, torch.Tensor]:
     assert a.size(0) == b_offsets.size(0) - 1
     c = torch.empty(b.size(0) + a.size(0), dtype=a.dtype, device=a.device)
     c_offsets = b_offsets + torch.arange(
@@ -368,7 +368,7 @@ class JaggedSoftmaxCPU(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx: Any, grad_output: torch.Tensor  # pyre-ignore
-    ) ->
+    ) -> tuple[torch.Tensor, None, None]:
         y, x_offsets = ctx.saved_tensors
 
         B = x_offsets.size(0) - 1
@@ -923,7 +923,7 @@ class JaggedDenseAddCPU(torch.autograd.Function):
     def backward(
         ctx,  # pyre-ignore
         grad_output: torch.Tensor,
-    ) ->
+    ) -> tuple[torch.Tensor, None, torch.Tensor, None]:
         (offsets,) = ctx.saved_tensors
         grad_dense = torch.ops.fbgemm.jagged_to_padded_dense(
             grad_output, [offsets], [ctx.max_seq_len]
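These hunks restore the ctx.saved_tensors unpacking and spell out each backward return type as a tuple with one slot per forward input (None for non-differentiable arguments). A generic sketch of that torch.autograd.Function convention, independent of the fbgemm kernels:

import torch


class ScaleWithOffsets(torch.autograd.Function):
    """Toy op: y = 2 * x; offsets and max_seq_len are carried along but not differentiated."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, offsets: torch.Tensor, max_seq_len: int) -> torch.Tensor:
        ctx.save_for_backward(offsets)
        ctx.max_seq_len = max_seq_len
        return x * 2.0

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> tuple[torch.Tensor, None, None]:
        # One slot per forward input: a gradient for x, None for offsets and max_seq_len.
        (offsets,) = ctx.saved_tensors  # re-read saved tensors, as in the hunks above
        return grad_output * 2.0, None, None


x = torch.randn(4, requires_grad=True)
y = ScaleWithOffsets.apply(x, torch.tensor([0, 2, 4]), 2)
y.sum().backward()  # fills x.grad; the non-tensor and offset inputs get no gradient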
fbgemm_gpu/sll/triton/__init__.py
CHANGED
@@ -10,19 +10,16 @@
 from fbgemm_gpu.sll.triton.triton_dense_jagged_cat_jagged_out import (
     dense_jagged_cat_jagged_out,
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged2_to_padded_dense import (  # noqa F401
     jagged2_to_padded_dense,
     Jagged2ToPaddedDense,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_bmm import (  # noqa F401
     jagged_dense_bmm,
     jagged_jagged_bmm,
     JaggedDenseBmm,  # noqa F401
     JaggedJaggedBmm,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import (  # noqa F401
     array_jagged_bmm_jagged_out,
     ArrayJaggedBmmNopadding,  # noqa F401
@@ -31,38 +28,31 @@ from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import (  # noqa F401
     triton_array_jagged_bmm_jagged_out,  # noqa F401
     triton_jagged_jagged_bmm_jagged_out,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import (  # noqa F401
     jagged_dense_elementwise_add,
     JaggedDenseAdd,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_mul_jagged_out import (  # noqa F401
     jagged_dense_elementwise_mul_jagged_out,
     JaggedDenseElementwiseMul,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import (  # noqa F401
     jagged_dense_flash_attention,
     JaggedDenseFlashAttention,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_flash_attention_basic import (  # noqa F401
     jagged_flash_attention_basic,
     JaggedFlashAttentionBasic,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_self_substraction_jagged_out import (
     triton_jagged_self_substraction_jagged_out,
 )
-
 from fbgemm_gpu.sll.triton.triton_jagged_softmax import (  # noqa F401
     jagged2_softmax,
     Jagged2Softmax,  # noqa F401
     jagged_softmax,
     JaggedSoftmax,  # noqa F401
 )
-
 from fbgemm_gpu.sll.triton.triton_multi_head_jagged_flash_attention import (  # noqa F401
     multi_head_jagged_flash_attention,
     MultiHeadJaggedFlashAttention,  # noqa F401
fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py
CHANGED
@@ -6,7 +6,6 @@
 
 # pyre-unsafe
 
-from typing import Tuple
 
 import torch
 import triton
@@ -196,9 +195,9 @@ class Jagged2ToPaddedDense(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx, grad_output: torch.Tensor
-    ) ->
+    ) -> tuple[torch.Tensor, None, None, None]:
         max_length = ctx.max_length
-
+        lengths, offsets = ctx.saved_tensors
         grad_in = padded_dense_to_jagged2_fwd(grad_output, lengths, offsets, max_length)
         return (grad_in, None, None, None)
 
fbgemm_gpu/sll/triton/triton_jagged_bmm.py
CHANGED
@@ -326,7 +326,7 @@ class JaggedDenseBmm(torch.autograd.Function):
 
         # logging.info(f"Jagged bmm backward called")
 
-
+        x, y, x_offsets = ctx.saved_tensors
         N = ctx.N
         grad_x = triton_jagged_dense_bmm(
             grad_output, y.permute(0, 2, 1), x_offsets, N, allow_tf32=ctx.allow_tf32
@@ -369,7 +369,7 @@ class JaggedJaggedBmm(torch.autograd.Function):
         # dXT = dZ * YT -> dX = Y * dZT
         # dY = X * dZ -> X * dZ
         """
-
+        x, y, offsets = ctx.saved_tensors
         N = ctx.N
         grad_x = triton_jagged_dense_bmm(
             y, grad_output.permute(0, 2, 1), offsets, N, allow_tf32=ctx.allow_tf32
fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py
CHANGED
@@ -6,7 +6,6 @@
 
 # pyre-unsafe
 
-from typing import Tuple
 
 import torch
 import triton
@@ -171,7 +170,7 @@ def jagged_dense_flash_attention_fwd(
     jagged_offsets,
     max_seq_len,
     allow_tf32=False,
-) ->
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Q: jagged tensor, [sum_B, D]
     K: dense tensor, [B, D, T]
@@ -192,7 +191,7 @@ def jagged_dense_flash_attention_fwd(
     assert Q.size() == V.size(), "incompatible dimensions for Q and V"
     assert jagged_offsets.is_contiguous(), "jagged_offsets must be contiguous"
 
-
+    B, D, T = K.size()
     assert D > 0 and (D & (D - 1)) == 0, "D needs to be a power of two"
 
     attn_out = torch.zeros(B, T, D, dtype=Q.dtype, device=Q.device)
@@ -650,7 +649,7 @@ def jagged_dense_flash_attention_bwd(
     jagged_offsets,
     max_seq_len,
     allow_tf32=False,
-) ->
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Q: jagged tensor, [sum_B, D]
     K: dense tensor, [B, D, T]
@@ -668,7 +667,7 @@ def jagged_dense_flash_attention_bwd(
     if not do.is_contiguous():
         do = do.contiguous()
 
-
+    B, D, T = K.size()
     BLOCK_T = 32
     BLOCK_L = 32
     BLOCK_D = D
@@ -812,7 +811,7 @@ class JaggedDenseFlashAttention(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx, do: torch.Tensor
-    ) ->
+    ) -> tuple[
         torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, None, None, None
     ]:
         Q, K, V, attn_bias, jagged_offsets, lse, attn_out = ctx.saved_tensors
fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py
CHANGED
@@ -6,7 +6,6 @@
 
 # pyre-unsafe
 
-from typing import Tuple
 
 import torch
 import triton
@@ -607,7 +606,7 @@ class JaggedFlashAttentionBasic(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx, grad_output: torch.Tensor
-    ) ->
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None, None]:
         (
             jagged_Q,
             jagged_K,
fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py
CHANGED
@@ -6,7 +6,6 @@
 
 # pyre-unsafe
 
-from typing import Tuple
 
 import torch
 import triton
@@ -688,7 +687,7 @@ class MultiHeadJaggedFlashAttention(torch.autograd.Function):
     # pyre-fixme
     def backward(
         ctx, grad_output: torch.Tensor
-    ) ->
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None]:
         (
             jagged_Q,
             jagged_K,