fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,1001 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-strict
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def cpu_jagged_jagged_bmm_kernel(
    x: torch.Tensor, y: torch.Tensor, x_offsets: torch.Tensor, max_seq_len: int
) -> torch.Tensor:
    """Reference CPU kernel computing ``x[:, seg_b] @ y[seg_b, :]`` per segment.

    Args:
        x: 2-D tensor of shape ``[D, Sum_B]``; segment ``b`` is the column
            range ``x_offsets[b]:x_offsets[b + 1]``.
        y: 2-D tensor of shape ``[Sum_B, T]``; segment ``b`` is the row range
            ``x_offsets[b]:x_offsets[b + 1]``.
        x_offsets: ``B + 1`` offsets shared by ``x`` (columns) and ``y`` (rows).
        max_seq_len: unused by this implementation.

    Returns:
        Dense tensor of shape ``[B, D, T]``.
    """
    assert x.size(1) == y.size(0), "incompatible dimensions"
    num_segments = x_offsets.size(0) - 1
    D, _ = x.size()
    _, T = y.size()
    out = torch.empty((num_segments, D, T), dtype=x.dtype, device=x.device)

    for seg in range(num_segments):
        lo, hi = x_offsets[seg], x_offsets[seg + 1]
        out[seg, :, :] = torch.mm(x[:, lo:hi], y[lo:hi, :])
    return out
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def cpu_jagged_dense_bmm_kernel(
    x: torch.Tensor, y: torch.Tensor, x_offsets: torch.Tensor, max_seq_len: int
) -> torch.Tensor:
    """Reference CPU kernel multiplying each jagged segment by its own matrix.

    Args:
        x: jagged values of shape ``[Sum_B, D]``; segment ``b`` is the row
            range ``x_offsets[b]:x_offsets[b + 1]``.
        y: dense batch of shape ``[B, D, T]``.
        x_offsets: ``B + 1`` segment offsets into ``x``.
        max_seq_len: unused by this implementation.

    Returns:
        Tensor of shape ``[x.size(0), T]``. The rows of segment ``b`` hold
        ``x[seg_b] @ y[b]``; rows not covered by any segment stay zero.
    """
    assert x.size(1) == y.size(1), "incompatible dimensions"
    num_segments = x_offsets.size(0) - 1
    out = x.new_zeros((x.size(0), y.size(2)))

    for seg in range(num_segments):
        lo, hi = x_offsets[seg], x_offsets[seg + 1]
        out[lo:hi, :] = torch.mm(x[lo:hi, :], y[seg, :, :])
    return out
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class JaggedDenseBmmCPU(torch.autograd.Function):
    """
    Compute batch matrix multiplication between JaggedTensor and dense tensor
    dense: [B, N, D] * [B, D, T] = [B, N, T]
    jagged: [Sum_B, D] * [B, D, T] = [Sum_B, T]
    """

    @staticmethod
    # pyre-fixme
    def forward(
        ctx: Any,  # pyre-ignore
        x: torch.Tensor,
        y: torch.Tensor,
        x_offsets: torch.Tensor,
        N: int,
    ) -> torch.Tensor:
        """Return the jagged product ``[Sum_B, T]`` of ``x`` and ``y``."""
        ctx.save_for_backward(x, y, x_offsets)
        ctx.N = N
        return cpu_jagged_dense_bmm_kernel(x, y, x_offsets, N)

    @staticmethod
    # pyre-fixme
    def backward(
        ctx: Any, grad_output: torch.Tensor  # pyre-ignore
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, None, None]:
        """
        # X = [Sum_B, D]
        # Y = [B, D, T]
        # Z = X * Y = [Sum_B, T]
        # dX = dZ * YT # [Sum_B, T] * [B, T, D] = [Sum_B, D]
        # dY = XT * dZ # [D, sum_B] * [sum_B, T] = [B, D, T]

        Only the gradients autograd actually requested are computed. Exactly
        one value is returned per forward input (x, y, x_offsets, N).
        """
        (x, y, x_offsets) = ctx.saved_tensors
        N = ctx.N
        grad_x = None
        grad_y = None
        if ctx.needs_input_grad[0]:
            grad_x = cpu_jagged_dense_bmm_kernel(
                grad_output, y.permute(0, 2, 1), x_offsets, N
            )
        if ctx.needs_input_grad[1]:
            grad_y = cpu_jagged_jagged_bmm_kernel(x.T, grad_output, x_offsets, N)
        return grad_x, grad_y, None, None
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def cpu_jagged_dense_bmm(
    x: torch.Tensor,
    y: torch.Tensor,
    x_offsets: torch.Tensor,
    N: int,
    allow_tf32: bool,
    use_fbgemm_kernel: bool = True,
) -> torch.Tensor:
    """
    Compute batch matrix multiplication between JaggedTensor and dense tensor
    dense: [B, N, D] * [B, D, T] = [B, N, T]
    jagged: [Sum_B, D] * [B, D, T] = [Sum_B, T]

    Args:
        x: jagged values of shape [Sum_B, D], segmented by ``x_offsets``.
        y: dense batch of shape [B, D, T].
        x_offsets: B + 1 offsets delimiting the segments of ``x``.
        N: forwarded unchanged to the kernel (the reference kernels in this
            module call this argument ``max_seq_len``).
        allow_tf32: accepted for interface compatibility; not used here.
        use_fbgemm_kernel: ignored. The fbgemm operator is always selected.

    Returns:
        Jagged result of shape [Sum_B, T].
    """

    # Force the CPU backend to use fbgemm kernel as it has better performance
    use_fbgemm_kernel = True
    if use_fbgemm_kernel:
        return torch.ops.fbgemm.jagged_dense_bmm(x, x_offsets, y, N)[0]
    else:
        return JaggedDenseBmmCPU.apply(x, y, x_offsets, N)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class JaggedJaggedBmm(torch.autograd.Function):
    """
    Compute batch matrix multiplication between JaggedTensor and Jagged Tensor
    dense: [B, D, N] * [B, N, T] = [B, D, T]
    jagged: [Sum_B, D].T * [Sum_B, T] = [B, D, T]
    """

    @staticmethod
    # pyre-fixme
    def forward(
        ctx: Any,  # pyre-ignore
        x: torch.Tensor,
        y: torch.Tensor,
        x_offsets: torch.Tensor,
        N: int,
    ) -> torch.Tensor:
        """Return the dense ``[B, D, T]`` product of per-segment ``x_b^T @ y_b``."""
        ctx.save_for_backward(x, y, x_offsets)
        ctx.N = N
        return cpu_jagged_jagged_bmm_kernel(x.T, y, x_offsets, N)

    @staticmethod
    # pyre-fixme
    def backward(
        ctx: Any, grad_output: torch.Tensor  # pyre-ignore
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, None, None]:
        """
        # X = [Sum_B, D]
        # Y = [Sum_B, T]
        # Z = XT * Y = [B, D, T]
        # dXT = dZ * YT -> dX = Y * dZT
        # dY = X * dZ -> X * dZ

        Only the gradients autograd actually requested are computed. Exactly
        one value is returned per forward input (x, y, x_offsets, N).
        """
        (x, y, offsets) = ctx.saved_tensors
        N = ctx.N
        grad_x = None
        grad_y = None
        if ctx.needs_input_grad[0]:
            grad_x = cpu_jagged_dense_bmm_kernel(
                y, grad_output.permute(0, 2, 1), offsets, N
            )
        if ctx.needs_input_grad[1]:
            grad_y = cpu_jagged_dense_bmm_kernel(x, grad_output, offsets, N)
        return grad_x, grad_y, None, None
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def cpu_jagged_jagged_bmm(
    x: torch.Tensor,
    y: torch.Tensor,
    x_offsets: torch.Tensor,
    N: int,
    allow_tf32: bool,
    use_fbgemm_kernel: bool = True,
) -> torch.Tensor:
    """
    Batched ``x_b^T @ y_b`` over the segments shared by two jagged tensors.

    dense:  [B, D, N] * [B, N, T] = [B, D, T]
    jagged: [Sum_B, D].T * [Sum_B, T] = [B, D, T]

    ``allow_tf32`` is unused. ``use_fbgemm_kernel`` is overridden: the
    fbgemm operator is always chosen on CPU for performance, so the
    ``JaggedJaggedBmm`` fallback is never reached.
    """
    use_fbgemm_kernel = True  # forced: fbgemm is faster on the CPU backend
    if not use_fbgemm_kernel:
        return JaggedJaggedBmm.apply(x, y, x_offsets, N)
    return torch.ops.fbgemm.jagged_jagged_bmm(x, y, x_offsets, N)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def cpu_dense_jagged_cat_jagged_out(
    a: torch.Tensor,
    b: torch.Tensor,
    b_offsets: torch.Tensor,
    max_seq_len: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Prepend one dense value to every segment of a jagged tensor.

    For batch ``i`` the output segment is ``a[i]`` followed by
    ``b[b_offsets[i]:b_offsets[i + 1]]``. Values are concatenated along the
    last dimension.

    Args:
        a: one entry per batch; ``a.size(0)`` must equal ``B``.
        b: jagged values.
        b_offsets: ``B + 1`` offsets delimiting the segments of ``b``.
        max_seq_len: unused by this implementation.

    Returns:
        ``(c, c_offsets)``. ``c`` is the concatenation of all output
        segments. ``c_offsets = b_offsets + arange(B + 1)``, because every
        segment grows by exactly one element.
    """
    assert a.size(0) == b_offsets.size(0) - 1
    c_offsets = b_offsets + torch.arange(
        b_offsets.size(0), dtype=torch.int64, device=a.device
    )
    if a.size(0) == 0:
        # torch.cat rejects an empty list; with no batches the output is empty.
        return torch.empty_like(a), c_offsets

    lengths = torch.diff(b_offsets)
    c = torch.cat(
        [
            # Empty segments contribute only their dense head element.
            (
                torch.cat((a[i : i + 1], b[b_offsets[i] : b_offsets[i + 1]]), dim=-1)
                if lengths[i] > 0
                else a[i : i + 1]
            )
            for i in range(a.size(0))
        ],
        dim=-1,
    )
    return c, c_offsets
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def cpu_jagged_self_substraction_jagged_out(
    jagged_A: torch.Tensor,
    offsets_a: torch.Tensor,
    offsets_b: torch.Tensor,
    max_seq_len: int,
) -> torch.Tensor:
    """Per-segment pairwise differences of a jagged tensor.

    For a segment ``a`` of length ``L``, the output segment
    ``offsets_b[i]:offsets_b[i + 1]`` receives the ``(L - 1) x (L - 1)``
    matrix ``a[:-1, None] - a[None, 1:]``, flattened row-major.

    Args:
        jagged_A: input values, segmented by ``offsets_a``.
        offsets_a: ``B + 1`` offsets into ``jagged_A``.
        offsets_b: ``B + 1`` offsets into the output. Its last entry is the
            output length.
        max_seq_len: unused by this implementation.

    Returns:
        Tensor of length ``offsets_b[-1]``. Entries not written by any
        segment are left uninitialized (allocated with ``torch.empty``).
    """
    total = int(offsets_b[-1].item())
    out = torch.empty((total), device=jagged_A.device, dtype=jagged_A.dtype)
    num_segments = len(offsets_a) - 1

    for seg in range(num_segments):
        start, stop = offsets_a[seg], offsets_a[seg + 1]
        # A single-element segment yields an empty 0 x 0 difference matrix.
        if stop - start == 1:
            continue
        vals = jagged_A[start:stop]
        diffs = vals[:-1, None] - vals[None, 1:]
        out[offsets_b[seg] : offsets_b[seg + 1]] = diffs.reshape(-1)
    return out
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def cpu_jagged2_to_padded_dense(
    values: torch.Tensor,
    offsets: torch.Tensor,
    max_length: int,
    padding_value: float = 0.0,
) -> torch.Tensor:
    """Scatter flattened square blocks into a padded ``[B, max_length, max_length]`` tensor.

    Args:
        values: jagged tensor of size ``[sum(N_i * N_i)]``. Block ``b`` is
            ``values[offsets[b]:offsets[b + 1]]`` and must hold a perfect
            square number of elements.
        offsets: ``B + 1`` offsets into ``values``.
        max_length: side of the padded output. If ``N_b > max_length``, the
            block assignment raises.
        padding_value: fill value for every position not covered by a block.

    Returns:
        Dense tensor of shape ``[B, max_length, max_length]``. Block ``b``
        is placed at ``[b, :N_b, :N_b]``.
    """
    num_batches = offsets.size(0) - 1
    padded = torch.full(
        (num_batches, max_length, max_length),
        padding_value,
        dtype=values.dtype,
        device=values.device,
    )
    for idx in range(num_batches):
        start, stop = offsets[idx], offsets[idx + 1]
        side = int(torch.sqrt(stop - start))
        if side != 0:
            padded[idx, 0:side, 0:side] = values[start:stop].view(side, side)

    return padded
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
class CPUJaggedDenseElementwiseMul(torch.autograd.Function):
    # NOTE: CPU, GPU, CUDA versions all have their own autograd.Function implementations,
    # ideally we should use one autograd.Function for all of them and do the dispatching
    # inside the autograd.Function.
    """
    Elementwise product of a jagged tensor of square blocks with a dense matrix.

    z = x * y
    x: [sum_B(L_i * L_i)]; block i is an L_i x L_i matrix stored row-major.
    y: dense matrix; block i is multiplied by its top-left L_i x L_i corner.
    z: same layout as x.

    Only ``x`` receives a gradient (``dz * y``). ``y`` gets ``None``.
    """

    @staticmethod
    def jagged_dense_elementwise_mul_jagged_out(
        jagged: torch.Tensor,
        dense: torch.Tensor,
        seq_lengths: torch.Tensor,
        offsets: torch.Tensor,
        max_seq_len: int,
    ) -> torch.Tensor:
        """Multiply each flattened block by the matching corner of ``dense``.

        ``max_seq_len`` is unused. Output positions belonging to zero-length
        blocks are not written.
        """
        result = torch.empty_like(jagged)
        for idx in range(seq_lengths.size(0)):
            if seq_lengths[idx] == 0:
                continue
            n = int(seq_lengths[idx])
            lo, hi = offsets[idx], offsets[idx + 1]
            block = jagged[lo:hi].view(n, n)
            result[lo:hi] = (block * dense[0:n, 0:n]).flatten()
        return result

    @staticmethod
    # pyre-fixme
    def forward(
        ctx,  # pyre-ignore [2]
        x: torch.Tensor,
        y: torch.Tensor,
        x_seq_lengths: torch.Tensor,
        x_offsets: torch.Tensor,
        max_seq_len: int,
    ):
        ctx.max_seq_len = max_seq_len
        ctx.save_for_backward(x, y, x_seq_lengths, x_offsets)
        return CPUJaggedDenseElementwiseMul.jagged_dense_elementwise_mul_jagged_out(
            x, y, x_seq_lengths, x_offsets, max_seq_len
        )

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        _, y, x_seq_lengths, x_offsets = ctx.saved_tensors
        # d(x * y)/dx = y, so the gradient is the same blockwise product
        # applied to grad_output.
        grad_x = CPUJaggedDenseElementwiseMul.jagged_dense_elementwise_mul_jagged_out(
            grad_output, y, x_seq_lengths, x_offsets, ctx.max_seq_len
        )
        return grad_x, None, None, None, None
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def cpu_jagged_dense_elementwise_mul_jagged_out(
    x: torch.Tensor,
    y: torch.Tensor,
    x_seq_lengths: torch.Tensor,
    x_offsets: torch.Tensor,
    max_seq_len: int,
) -> torch.Tensor:
    """Differentiable jagged-times-dense elementwise product.

    Thin wrapper around ``CPUJaggedDenseElementwiseMul.apply``. Gradients
    flow to ``x`` only.
    """
    inputs = (x, y, x_seq_lengths, x_offsets, max_seq_len)
    return CPUJaggedDenseElementwiseMul.apply(*inputs)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
class JaggedSoftmaxCPU(torch.autograd.Function):
    """Column-wise softmax applied independently to each segment of a jagged 2-D tensor."""

    @staticmethod
    # pyre-fixme
    def forward(
        ctx: Any,  # pyre-ignore
        x: torch.Tensor,
        x_offsets: torch.Tensor,
        max_seq_len: int,
    ) -> torch.Tensor:
        """
        input shape is [SUM_B, D]
        output shape is [SUM_B, D]

        Rows ``x_offsets[b]:x_offsets[b + 1]`` are softmax-normalized along
        dim 0. Rows outside every segment stay zero. ``max_seq_len`` is unused.
        """
        out = x.new_zeros(x.size())
        for b in range(x_offsets.size(0) - 1):
            rows = slice(x_offsets[b], x_offsets[b + 1])
            out[rows, :] = torch.nn.functional.softmax(x[rows, :], dim=0)

        ctx.save_for_backward(out, x_offsets)
        return out

    @staticmethod
    # pyre-fixme
    def backward(
        ctx: Any, grad_output: torch.Tensor  # pyre-ignore
    ) -> tuple[torch.Tensor, None, None]:
        """Softmax VJP per segment: ``s * (g - sum(g * s, dim=0))``."""
        y, x_offsets = ctx.saved_tensors
        grad = y.new_zeros(y.size())
        for b in range(x_offsets.size(0) - 1):
            rows = slice(x_offsets[b], x_offsets[b + 1])
            s = y[rows]
            g = grad_output[rows]
            grad[rows] = s * (g - (g * s).sum(dim=0, keepdim=True))
        return grad, None, None
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def cpu_jagged_softmax(
    x: torch.Tensor,
    x_offsets: torch.Tensor,
    max_seq_len: int,
    use_fbgemm_kernel: bool = True,
) -> torch.Tensor:
    """
    CPU version of jagged softmax: [sum(softmax([B_i, D]))]

    ``use_fbgemm_kernel`` is overridden: the fbgemm operator is always
    chosen for performance, so the ``JaggedSoftmaxCPU`` fallback is never
    reached.
    """
    use_fbgemm_kernel = True  # forced: fbgemm is faster on the CPU backend
    if not use_fbgemm_kernel:
        return JaggedSoftmaxCPU.apply(x, x_offsets, max_seq_len)
    return torch.ops.fbgemm.jagged_softmax(x, x_offsets, max_seq_len)[0]
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
class Jagged2SoftmaxCPU(torch.autograd.Function):
    """Softmax over each square block of a jagged tensor of flattened matrices.

    Block ``i`` occupies ``x[x_offsets[i]:x_offsets[i + 1]]``. Its side is
    ``N_i = row_offsets[i + 1] - row_offsets[i]``.

    With ``transpose=True`` each block is normalized along dim 0 (down
    columns); otherwise along its last dim (across rows). ``head_offsets``
    and the two max-length arguments are saved but not used in the
    computation.
    """

    @staticmethod
    # pyre-fixme
    def forward(
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        x: torch.Tensor,
        x_offsets: torch.Tensor,
        row_offsets: torch.Tensor,
        head_offsets: torch.Tensor,
        max_seq_len_row: int,
        max_seq_len_head: int,
        transpose: bool = True,
    ) -> torch.Tensor:
        dim = 0 if transpose else 1
        y = torch.zeros(x.size(0), device=x.device, dtype=x.dtype)

        for i in range(x_offsets.size(0) - 1):
            lo, hi = x_offsets[i], x_offsets[i + 1]
            side = int(row_offsets[i + 1] - row_offsets[i])
            block = x[lo:hi].reshape((side, side))
            y[lo:hi] = torch.nn.functional.softmax(block, dim=dim).view(-1)

        ctx.save_for_backward(y, x_offsets, row_offsets, head_offsets)
        ctx.max_seq_len_row = max_seq_len_row
        ctx.max_seq_len_head = max_seq_len_head
        ctx.transpose = transpose
        return y

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        y, x_offsets, row_offsets, _head_offsets = ctx.saved_tensors
        dim = 0 if ctx.transpose else -1
        grad = torch.zeros(y.size(0), device=y.device, dtype=y.dtype)

        for i in range(x_offsets.size(0) - 1):
            lo, hi = x_offsets[i], x_offsets[i + 1]
            side = int(row_offsets[i + 1] - row_offsets[i])
            s = y[lo:hi].reshape(side, side)
            g = grad_output[lo:hi].reshape(side, side)
            # Softmax VJP along the normalized dim: s * (g - sum(g * s)).
            dot = (g * s).sum(dim=dim, keepdim=True)
            grad[lo:hi] = (s * (g - dot)).view(-1)

        return grad, None, None, None, None, None, None
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def cpu_jagged2_softmax(
    x: torch.Tensor,
    offsets: torch.Tensor,
    offsets_total: torch.Tensor,
    max_seq_len: int,
    transpose: bool,
) -> torch.Tensor:
    """
    CPU version of jagged2 softmax: [sum(softmax([B_i, B_i]))]

    Args:
        x: flattened square blocks, one per batch.
        offsets: B + 1 offsets whose differences give each block's side
            length. Also passed through as ``head_offsets``.
        offsets_total: B + 1 offsets of each flattened block inside ``x``.
        max_seq_len: passed as both max-length arguments of
            ``Jagged2SoftmaxCPU``.
        transpose: if True, normalize down the columns (dim 0) of each
            block; otherwise across its rows.
    """
    return Jagged2SoftmaxCPU.apply(
        x,
        offsets_total,
        offsets,
        offsets,
        max_seq_len,
        max_seq_len,
        transpose,
    )
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def cpu_jagged_jagged_bmm_jagged_out_kernel(
    jagged_A: torch.Tensor,
    jagged_B: torch.Tensor,
    max_seq_len: int,
    lengths_m: torch.Tensor,
    lengths_n: torch.Tensor,
    lengths_mn: torch.Tensor,
    offsets_m: torch.Tensor,
    offsets_n: torch.Tensor,
    offsets_mn: torch.Tensor,
    allow_tf32: bool = False,
) -> torch.Tensor:
    """Reference CPU kernel computing ``A_i @ B_i.T`` per batch with a jagged output.

    Args:
        jagged_A: stacked rows of all ``A_i``, where
            ``A_i = jagged_A[offsets_m[i]:offsets_m[i + 1]]``.
        jagged_B: stacked rows of all ``B_i``, where
            ``B_i = jagged_B[offsets_n[i]:offsets_n[i + 1]]``.
        max_seq_len: unused by this implementation.
        lengths_m: only its size (the batch count) is read.
        lengths_n: unused by this implementation.
        lengths_mn: per-batch output sizes. Their sum is the output length.
        offsets_m: B + 1 offsets into ``jagged_A``.
        offsets_n: B + 1 offsets into ``jagged_B``.
        offsets_mn: B + 1 offsets into the output.
        allow_tf32: unused by this implementation.

    Returns:
        1-D tensor with
        ``out[offsets_mn[i]:offsets_mn[i + 1]] == (A_i @ B_i.T).flatten()``.
    """
    # Allocate directly on the inputs' device, instead of allocating on the
    # default device and copying over.
    jagged_C = torch.empty(
        (int(lengths_mn.sum().item()),),
        dtype=jagged_A.dtype,
        device=jagged_A.device,
    )
    B = lengths_m.size(0)

    for i in range(B):
        jagged_C[offsets_mn[i] : offsets_mn[i + 1]] = torch.matmul(
            jagged_A[offsets_m[i] : offsets_m[i + 1]],
            jagged_B[offsets_n[i] : offsets_n[i + 1]].T,
        ).flatten()
    return jagged_C
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
# pyre-fixme[3]: Return type must be annotated.
|
|
518
|
+
def cpu_array_jagged_bmm_jagged_out_kernel(
    # pyre-fixme[2]: Parameter must be annotated.
    array_A,
    # pyre-fixme[2]: Parameter must be annotated.
    jagged_B,
    # pyre-fixme[2]: Parameter must be annotated.
    lengths_am,
    # pyre-fixme[2]: Parameter must be annotated.
    lengths_bk,
    # pyre-fixme[2]: Parameter must be annotated.
    lengths_cm,
    # pyre-fixme[2]: Parameter must be annotated.
    offsets_am,
    # pyre-fixme[2]: Parameter must be annotated.
    offsets_bk,
    # pyre-fixme[2]: Parameter must be annotated.
    offsets_cm,
    # pyre-fixme[2]: Parameter must be annotated.
    max_seq_len,
    # pyre-fixme[2]: Parameter must be annotated.
    allow_tf32=False,
    # pyre-fixme[2]: Parameter must be annotated.
    transpose=0,  # one if a is transpose, otherwise zero
):
    """Reference CPU kernel computing ``A_i @ B_i`` per batch, with square ``A_i`` stored flat.

    For batch ``i``, let ``n = lengths_bk[i]``:
      * ``A_i`` is the ``n x n`` matrix stored row-major in
        ``array_A[offsets_am[i]:offsets_am[i + 1]]``.
      * ``B_i`` is the block of rows of ``jagged_B`` starting at
        ``offsets_bk[i]``.

    Both are truncated to ``k = min(n, max_seq_len)``. Output rows
    ``offsets_cm[i]:offsets_cm[i] + k`` receive ``A_i[:k, :k] @ B_i[:k]``,
    or ``A_i[:k, :k].T @ B_i[:k]`` when ``transpose`` is truthy.

    Output rows not written by any batch stay zero. ``lengths_am`` is only
    used for the batch count and ``lengths_cm`` only for the total number of
    output rows. ``allow_tf32`` is unused.
    """
    num_batches = lengths_am.size(0)
    width = jagged_B.size(1)
    out = torch.zeros(
        (int(lengths_cm.sum()), width), device=jagged_B.device, dtype=jagged_B.dtype
    )

    for idx in range(num_batches):
        n = int(lengths_bk[idx])
        k = min(n, max_seq_len)
        square = array_A[offsets_am[idx] : offsets_am[idx + 1]].view(n, n)[:k, :k]
        lhs = square.T if transpose else square
        b_start = offsets_bk[idx]
        c_start = offsets_cm[idx]
        out[c_start : c_start + k] = torch.matmul(lhs, jagged_B[b_start : b_start + k])

    return out
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
class ArrayJaggedBmmNopaddingCPU(torch.autograd.Function):
    """
    Batch matmul of flattened square matrices with a jagged tensor, without padding.

    z = X * Y
    x: [Sum_B(N_i, N_i)], one flattened N_i x N_i matrix per batch
    y: [sum_B(N_i), D]
    z: [sum_B(N_i), D]
    """

    @staticmethod
    # pyre-fixme
    def forward(
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        x: torch.Tensor,
        y: torch.Tensor,
        x_lengths: torch.Tensor,
        x_offsets: torch.Tensor,
        y_lengths: torch.Tensor,
        y_offsets: torch.Tensor,
        z_lengths: torch.Tensor,
        z_offsets: torch.Tensor,
        max_seq_len: int,
        # pyre-fixme[2]: Parameter must be annotated.
        allow_tf32,
    ):
        ctx.allow_tf32 = allow_tf32
        ctx.max_seq_len = max_seq_len
        ctx.save_for_backward(
            x, y, x_lengths, y_lengths, z_lengths, x_offsets, y_offsets, z_offsets
        )
        return cpu_array_jagged_bmm_jagged_out_kernel(
            array_A=x,
            jagged_B=y,
            lengths_am=x_lengths,
            lengths_bk=y_lengths,
            lengths_cm=z_lengths,
            offsets_am=x_offsets,
            offsets_bk=y_offsets,
            offsets_cm=z_offsets,
            max_seq_len=max_seq_len,
            allow_tf32=allow_tf32,
            transpose=0,
        )

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        """
        z = X * Y
        dX = dZ * YT
        dY = XT * dZ

        dZ: [sum_B(N_i), D]
        YT: [D, sum_B(N_i)] call Y.T
        XT: transposed
        Z: [sum_B(N_i), D]
        """
        (
            x,
            y,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
        ) = ctx.saved_tensors

        # dX_i = dZ_i @ Y_i^T, written back in X's flattened-square layout.
        # NOTE(review): the forward kernel truncates each batch to
        # max_seq_len, but cpu_jagged_jagged_bmm_jagged_out_kernel does not
        # read max_seq_len. For N_i > max_seq_len, grad_x may therefore
        # include rows the forward ignored -- confirm the intended semantics.
        grad_x = cpu_jagged_jagged_bmm_jagged_out_kernel(
            jagged_A=grad_output,
            jagged_B=y,
            max_seq_len=ctx.max_seq_len,
            lengths_m=z_lengths,
            lengths_n=y_lengths,
            lengths_mn=x_lengths,
            offsets_m=z_offsets,
            offsets_n=y_offsets,
            offsets_mn=x_offsets,
            allow_tf32=ctx.allow_tf32,
        )

        # dY_i = X_i^T @ dZ_i; transpose=1 selects X_i^T.
        grad_y = cpu_array_jagged_bmm_jagged_out_kernel(
            array_A=x,
            jagged_B=grad_output,
            lengths_am=x_lengths,
            lengths_bk=y_lengths,
            lengths_cm=z_lengths,
            offsets_am=x_offsets,
            offsets_bk=y_offsets,
            offsets_cm=z_offsets,
            max_seq_len=ctx.max_seq_len,
            allow_tf32=ctx.allow_tf32,
            transpose=1,
        )
        return grad_x, grad_y, None, None, None, None, None, None, None, None
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
# pyre-fixme[3]: Return type must be annotated.
|
|
671
|
+
def cpu_array_jagged_bmm_jagged_out(
    x: torch.Tensor,
    y: torch.Tensor,
    x_lengths: torch.Tensor,
    x_offsets: torch.Tensor,
    y_lengths: torch.Tensor,
    y_offsets: torch.Tensor,
    z_lengths: torch.Tensor,
    z_offsets: torch.Tensor,
    max_seq_len: int,
    allow_tf32: bool = True,
):
    """Differentiable per-batch product of flattened square matrices ``x`` with jagged ``y``.

    Thin wrapper around ``ArrayJaggedBmmNopaddingCPU.apply``. Arguments are
    forwarded in the same order.
    """
    inputs = (
        x,
        y,
        x_lengths,
        x_offsets,
        y_lengths,
        y_offsets,
        z_lengths,
        z_offsets,
        max_seq_len,
        allow_tf32,
    )
    return ArrayJaggedBmmNopaddingCPU.apply(*inputs)
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
class JaggedJaggedBmmNoPaddingCPU(torch.autograd.Function):
    """
    Batched matmul of two jagged tensors without padding, producing a jagged output.

    z = x x y^T
    x: [sum_B(M_i), D]
    y: [sum_B(N_i), D]
    z: [sum_B(M_i * N_i)], assuming M_i = N_i
    """

    @staticmethod
    # pyre-fixme
    def forward(
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        x: torch.Tensor,
        y: torch.Tensor,
        x_lengths: torch.Tensor,
        x_offsets: torch.Tensor,
        y_lengths: torch.Tensor,
        y_offsets: torch.Tensor,
        z_lengths: torch.Tensor,
        z_offsets: torch.Tensor,
        max_seq_len: int,
        # pyre-fixme[2]: Parameter must be annotated.
        allow_tf32,
    ):
        # Non-tensor settings live on ctx; tensors go through save_for_backward.
        ctx.max_seq_len = max_seq_len
        ctx.allow_tf32 = allow_tf32
        # Saved order (read back positionally in backward): the two operands,
        # then lengths for x, y, z, then offsets for x, y, z.
        ctx.save_for_backward(
            x,
            y,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
        )

        return cpu_jagged_jagged_bmm_jagged_out_kernel(
            x,
            y,
            max_seq_len,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
            allow_tf32,
        )

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        """
        Gradients of z = x x y^T with respect to x and y.

        x: [sum_B(M_i), D]
        y: [sum_B(N_i), D]
        z: [sum_B(M_i * N_i)], assuming M_i = N_i
        dx = dz x (y^T)^T => dx = dz x y
        d(y^T) = x^T x dz => dy = dz^T x x
        """
        saved = ctx.saved_tensors
        x, y = saved[0], saved[1]
        x_lengths, y_lengths, z_lengths = saved[2:5]
        x_offsets, y_offsets, z_offsets = saved[5:8]
        max_seq_len = ctx.max_seq_len
        allow_tf32 = ctx.allow_tf32

        # dx = dz x y: dz (laid out by z) is the array operand, y is jagged,
        # and the result is laid out like x.
        grad_x = cpu_array_jagged_bmm_jagged_out_kernel(
            grad_output,
            y,
            z_lengths,
            y_lengths,
            x_lengths,
            z_offsets,
            y_offsets,
            x_offsets,
            max_seq_len,
            allow_tf32,
            transpose=0,
        )
        # dy = dz^T x x: same kernel with transpose=1 (presumably selects dz^T),
        # result laid out like y.
        grad_y = cpu_array_jagged_bmm_jagged_out_kernel(
            grad_output,
            x,
            z_lengths,
            x_lengths,
            y_lengths,
            z_offsets,
            x_offsets,
            y_offsets,
            max_seq_len,
            allow_tf32,
            transpose=1,
        )
        # One slot per forward input (10 total); only x and y get gradients.
        return (grad_x, grad_y) + (None,) * 8
|
|
799
|
+
|
|
800
|
+
|
|
801
|
+
# pyre-fixme[3]: Return type must be annotated.
def cpu_jagged_jagged_bmm_jagged_out(
    x: torch.Tensor,
    y: torch.Tensor,
    x_lengths: torch.Tensor,
    x_offsets: torch.Tensor,
    y_lengths: torch.Tensor,
    y_offsets: torch.Tensor,
    z_lengths: torch.Tensor,
    z_offsets: torch.Tensor,
    max_seq_len: int,
    allow_tf32: bool = True,
):
    """Differentiable CPU entry point for the jagged x jagged^T BMM (jagged output).

    Forwards every argument, unchanged and in the same order, to
    ``JaggedJaggedBmmNoPaddingCPU.apply`` so that autograd records the op.
    """
    forwarded_args = (
        x,
        y,
        x_lengths,
        x_offsets,
        y_lengths,
        y_offsets,
        z_lengths,
        z_offsets,
        max_seq_len,
        allow_tf32,
    )
    return JaggedJaggedBmmNoPaddingCPU.apply(*forwarded_args)
|
|
826
|
+
|
|
827
|
+
|
|
828
|
+
def cpu_jagged_flash_attention_basic(
    q_weights: torch.Tensor,
    k_weights: torch.Tensor,
    v_weights: torch.Tensor,
    offsets: torch.Tensor,
    max_seq_len: int,
    use_mask: bool = False,
    allow_tf32: bool = True,
) -> torch.Tensor:
    """Reference CPU jagged attention built from SLL jagged ops.

    Computes ``softmax(Q K^T) / max_seq_len`` per sequence, optionally
    multiplies it by an upper-triangular boolean mask, and then multiplies the
    result by V. Q, K and V share the jagged layout described by ``offsets``.
    """
    # Length of each sequence, and the length/offsets of each per-sequence
    # (n x n) score block packed into one jagged buffer.
    seq_lengths = offsets[1:] - offsets[:-1]
    block_lengths = seq_lengths * seq_lengths
    block_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(block_lengths)

    # S = Q K^T; the op transposes k_weights internally.
    scores = torch.ops.fbgemm.sll_jagged_jagged_bmm_jagged_out(
        x=q_weights,
        y=k_weights,
        x_lengths=seq_lengths,
        x_offsets=offsets,
        y_lengths=seq_lengths,
        y_offsets=offsets,
        z_lengths=block_lengths,
        z_offsets=block_offsets,
        max_seq_len=max_seq_len,
        allow_tf32=allow_tf32,
    )

    softmaxed = torch.ops.fbgemm.sll_jagged2_softmax(
        x=scores,
        offsets=offsets,
        offsets_total=block_offsets,
        max_seq_len=max_seq_len,
        transpose=False,
    )
    # The division by max_seq_len is part of this reference definition.
    probs = softmaxed / max_seq_len

    if use_mask:
        # Dense (max_seq_len, max_seq_len) upper-triangular boolean mask;
        # the elementwise op below applies probs = probs * attn_mask per block.
        all_true = torch.ones(
            (max_seq_len, max_seq_len),
            dtype=torch.bool,
            device=q_weights.device,
        )
        attn_mask = torch.triu(all_true).requires_grad_(False)
        probs = torch.ops.fbgemm.sll_jagged_dense_elementwise_mul_jagged_out(
            x=probs,
            y=attn_mask,
            x_seq_lengths=seq_lengths,
            x_offsets=block_offsets,
            max_seq_len=max_seq_len,
        )

    # O = P V, laid out with the same jagged offsets as the inputs.
    return torch.ops.fbgemm.sll_array_jagged_bmm_jagged_out(
        x=probs,
        y=v_weights,
        x_lengths=block_lengths,
        x_offsets=block_offsets,
        y_lengths=seq_lengths,
        y_offsets=offsets,
        z_lengths=seq_lengths,
        z_offsets=offsets,
        max_seq_len=max_seq_len,
        allow_tf32=allow_tf32,
    )
|
|
896
|
+
|
|
897
|
+
|
|
898
|
+
class JaggedDenseAddCPU(torch.autograd.Function):
    """Autograd function for jagged + dense elementwise add on CPU via padding.

    Forward pads ``x`` to ``max_seq_len`` with zeros, adds the dense ``y``, and
    converts the sum back to the jagged layout given by ``x_offsets``.
    """

    @staticmethod
    # pyre-fixme
    def forward(
        ctx: Any,  # pyre-ignore
        x: torch.Tensor,
        x_offsets: torch.Tensor,
        y: torch.Tensor,
        max_seq_len: int,
    ) -> torch.Tensor:
        ctx.save_for_backward(x_offsets)
        ctx.max_seq_len = max_seq_len
        # TODO: what should be the correct behavior when jagged values has length > max seq len?
        # current behavior is to not truncate jagged values
        # similar for backward grad_output
        dense_x = torch.ops.fbgemm.jagged_to_padded_dense(
            x,
            [x_offsets],
            max_lengths=[max_seq_len],
            padding_value=0.0,
        )  # [B, max_seq_len, D]
        dense_sum = dense_x + y
        jagged_result = torch.ops.fbgemm.dense_to_jagged(dense_sum, [x_offsets])
        return jagged_result[0]

    @staticmethod
    # pyre-fixme
    def backward(
        ctx,  # pyre-ignore
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, None, torch.Tensor, None]:
        (x_offsets,) = ctx.saved_tensors
        # The gradient w.r.t. x is grad_output itself. The gradient w.r.t. y is
        # grad_output converted to the padded dense layout.
        grad_y = torch.ops.fbgemm.jagged_to_padded_dense(
            grad_output, [x_offsets], [ctx.max_seq_len]
        )
        # Slots: x, x_offsets, y, max_seq_len.
        return grad_output, None, grad_y, None
|
|
932
|
+
|
|
933
|
+
|
|
934
|
+
def cpu_jagged_dense_elementwise_add(
    x: torch.Tensor,
    x_offsets: torch.Tensor,
    y: torch.Tensor,
    max_seq_len: int,
    use_fbgemm_kernel: bool = True,
) -> torch.Tensor:
    """Add dense ``y`` to jagged ``x`` and return the jagged result values.

    ``use_fbgemm_kernel`` is accepted for interface compatibility but is
    deliberately overridden. The CPU backend always uses the fbgemm kernel
    instead of the ``JaggedDenseAddCPU`` fallback.
    """
    # Force the CPU backend to use fbgemm kernel as it has better performance
    use_fbgemm_kernel = True
    if not use_fbgemm_kernel:
        # Unreachable while the override above is in place.
        return JaggedDenseAddCPU.apply(x, x_offsets, y, max_seq_len)

    op_result = torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
        x, [x_offsets], y
    )
    # Only element 0 of the op's result (the jagged output values) is returned.
    return op_result[0]
|
|
949
|
+
|
|
950
|
+
|
|
951
|
+
def cpu_jagged_dense_flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_bias: torch.Tensor,
    offsets: torch.Tensor,
    max_seq_len: int,
    allow_tf32: bool = True,
) -> torch.Tensor:
    """
    Reference CPU attention with jagged Q/V and dense K / bias.

    Computes softmax(Q K + attn_bias) V using SLL jagged ops.

    q: jagged tensor, [sum_B, D]
    k: dense tensor, [B, D, T]
    v: jagged tensor [sum_B, D]
    attn_bias: dense tensor [B, N, T] (N presumably max_seq_len; confirm)
    offsets: offsets for jagged tensor [B + 1]
    """
    # Scores: [sum_B, D] x [B, D, T] -> [sum_B, T]; k is cast to q's dtype first.
    k_in_q_dtype = k.to(q.dtype)
    scores = torch.ops.fbgemm.sll_jagged_dense_bmm(
        q,
        k_in_q_dtype,
        offsets,
        max_seq_len,
        allow_tf32=allow_tf32,
        use_fbgemm_kernel=True,
    )

    biased_scores = torch.ops.fbgemm.sll_jagged_dense_elementwise_add(
        scores,
        offsets,
        attn_bias,
        max_seq_len,
        use_fbgemm_kernel=True,
    )

    attn_weights = torch.ops.fbgemm.sll_jagged_softmax(
        biased_scores,
        offsets,
        max_seq_len,
        use_fbgemm_kernel=True,
    )  # [sum_B, T]

    # [sum_B, T] x [sum_B, D] -> [B, T, D]; v is cast to the weights' dtype.
    v_in_weight_dtype = v.to(attn_weights.dtype)
    return torch.ops.fbgemm.sll_jagged_jagged_bmm(
        attn_weights,
        v_in_weight_dtype,
        offsets,
        max_seq_len,
        allow_tf32=allow_tf32,
        use_fbgemm_kernel=True,
    )
|