fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of fbgemm-gpu-genai-nightly has been flagged as potentially problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
fbgemm_gpu/sll/meta/__init__.py
@@ -0,0 +1,35 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from fbgemm_gpu.sll.meta.meta_sll import (  # noqa F401
    meta_array_jagged_bmm_jagged_out,
    meta_jagged2_softmax,
    meta_jagged_dense_elementwise_mul_jagged_out,
    meta_jagged_jagged_bmm_jagged_out,
    meta_jagged_self_substraction_jagged_out,
)

# pyre-ignore[5]
op_registrations = {
    "sll_jagged_self_substraction_jagged_out": {
        "Meta": meta_jagged_self_substraction_jagged_out,
    },
    "sll_jagged_dense_elementwise_mul_jagged_out": {
        "Meta": meta_jagged_dense_elementwise_mul_jagged_out,
    },
    "sll_jagged2_softmax": {
        "AutogradMeta": meta_jagged2_softmax,
    },
    "sll_array_jagged_bmm_jagged_out": {
        "AutogradMeta": meta_array_jagged_bmm_jagged_out,
    },
    "sll_jagged_jagged_bmm_jagged_out": {
        "AutogradMeta": meta_jagged_jagged_bmm_jagged_out,
    },
}
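Each entry in op_registrations maps an sll operator name to the dispatch key its Python implementation should be registered under: "Meta" provides a shape-only fake implementation used by tracing and torch.compile, while "AutogradMeta" additionally routes the meta path through autograd. A table like this would be consumed by a registration loop along the following lines (a minimal sketch, assuming the operator schemas are already defined under the "fbgemm" namespace elsewhere in the package; the actual wiring lives outside this file):

import torch

# Sketch only: attach each meta implementation from the table above to an
# already-defined "fbgemm::sll_*" operator schema under its dispatch key.
lib = torch.library.Library("fbgemm", "FRAGMENT")  # extend an existing namespace
for op_name, backends in op_registrations.items():
    for dispatch_key, fn in backends.items():
        lib.impl(op_name, fn, dispatch_key)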
fbgemm_gpu/sll/meta/meta_sll.py
@@ -0,0 +1,337 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import torch


def meta_jagged_self_substraction_jagged_out(
    jagged_A: torch.Tensor,
    offsets_a: torch.Tensor,
    offsets_b: torch.Tensor,
    max_seq_len: int,
) -> torch.Tensor:
    return torch.empty(
        [torch.library.get_ctx().new_dynamic_size()],
        dtype=jagged_A.dtype,
        device=jagged_A.device,
    )


class MetaJaggedDenseElementwiseMul(torch.autograd.Function):
    @staticmethod
    # pyre-fixme
    def forward(
        ctx,  # pyre-ignore [2]
        x: torch.Tensor,
        y: torch.Tensor,
        x_seq_lengths: torch.Tensor,
        x_offsets: torch.Tensor,
        max_seq_len: int,
    ) -> torch.Tensor:
        ctx.max_seq_len = max_seq_len

        ctx.save_for_backward(
            x,
            y,
            x_seq_lengths,
            x_offsets,
        )

        total_L = x.size(0)
        jagged_C = torch.zeros((total_L), device=x.device, dtype=x.dtype)

        return jagged_C

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        (
            x,
            y,
            x_seq_lengths,
            x_offsets,
        ) = ctx.saved_tensors

        total_L = grad_output.size(0)
        jagged_C = torch.zeros(
            (total_L), device=grad_output.device, dtype=grad_output.dtype
        )

        return jagged_C, None, None, None, None


def meta_jagged_dense_elementwise_mul_jagged_out(
    x: torch.Tensor,
    y: torch.Tensor,
    x_seq_lengths: torch.Tensor,
    x_offsets: torch.Tensor,
    max_seq_len: int,
) -> torch.Tensor:
    return MetaJaggedDenseElementwiseMul.apply(
        x,
        y,
        x_seq_lengths,
        x_offsets,
        max_seq_len,
    )


class Jagged2SoftmaxMeta(torch.autograd.Function):
    @staticmethod
    # pyre-fixme
    def forward(
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        x: torch.Tensor,
        x_offsets: torch.Tensor,
        row_offsets: torch.Tensor,
        head_offsets: torch.Tensor,
        max_seq_len_row: int,
        max_seq_len_head: int,
        transpose: bool = True,
    ) -> torch.Tensor:
        y = torch.rand(x.size(0), device=x.device, dtype=x.dtype)

        ctx.save_for_backward(y, x_offsets, row_offsets, head_offsets)
        ctx.max_seq_len_row = max_seq_len_row
        ctx.max_seq_len_head = max_seq_len_head

        return y

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        y, x_offsets, row_offsets, head_offsets = ctx.saved_tensors
        grad = torch.rand(y.size(), device=y.device, dtype=y.dtype)

        return grad, None, None, None, None, None, None


def meta_jagged2_softmax(
    x: torch.Tensor,
    offsets: torch.Tensor,
    offsets_total: torch.Tensor,
    max_seq_len: int,
    transpose: bool,
) -> torch.Tensor:
    """
    Meta version of jagged2 softmax: [sum(softmax([B_i, B_i]))]
    """
    return Jagged2SoftmaxMeta.apply(
        x,
        offsets_total,
        offsets,
        offsets,
        max_seq_len,
        max_seq_len,
        transpose,
    )


class ArrayJaggedBmmNopadding(torch.autograd.Function):
    """
    Compute batch matrix multiplication between JaggedTensor and JaggedTensor without padding.
    z = X * Y
    x: [Sum_B(N_i, N_i)]
    y: [sum_B(N_i), D]
    z: [sum_B(N_i), D]
    """

    @staticmethod
    # pyre-fixme
    def forward(
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        x: torch.Tensor,
        y: torch.Tensor,
        x_lengths: torch.Tensor,
        x_offsets: torch.Tensor,
        y_lengths: torch.Tensor,
        y_offsets: torch.Tensor,
        z_lengths: torch.Tensor,
        z_offsets: torch.Tensor,
        max_seq_len: int,
        # pyre-fixme[2]: Parameter must be annotated.
        allow_tf32,
    ) -> torch.Tensor:
        ctx.allow_tf32 = allow_tf32
        ctx.max_seq_len = max_seq_len

        ctx.save_for_backward(
            x,
            y,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
        )

        D = y.size(1)
        L = y.size(0)
        # gradients of the emb vectors beyond max_seq_len are set to zero
        jagged_C = torch.zeros((L, D), device=y.device, dtype=y.dtype)
        return jagged_C

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        """
        z = X * Y
        dX = dZ * YT
        dY = XT * dZ

        dZ: [sum_B(N_i), D]
        YT: [D, sum_B(N_i)], i.e. Y transposed
        XT: X transposed
        Z: [sum_B(N_i), D]
        """

        (
            x,
            y,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
        ) = ctx.saved_tensors

        grad_x = torch.zeros(
            (x.size()), device=grad_output.device, dtype=grad_output.dtype
        )

        # gradients of the emb vectors beyond max_seq_len are set to zero
        grad_y = torch.zeros(
            y.size(), device=grad_output.device, dtype=grad_output.dtype
        )
        return (
            grad_x,
            grad_y,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


# pyre-fixme[3]: Return type must be annotated.
def meta_array_jagged_bmm_jagged_out(
    x: torch.Tensor,
    y: torch.Tensor,
    x_lengths: torch.Tensor,
    x_offsets: torch.Tensor,
    y_lengths: torch.Tensor,
    y_offsets: torch.Tensor,
    z_lengths: torch.Tensor,
    z_offsets: torch.Tensor,
    max_seq_len: int,
    allow_tf32: bool = True,
):
    return ArrayJaggedBmmNopadding.apply(
        x,
        y,
        x_lengths,
        x_offsets,
        y_lengths,
        y_offsets,
        z_lengths,
        z_offsets,
        max_seq_len,
        allow_tf32,
    )


class JaggedJaggedBmmNoPaddingMeta(torch.autograd.Function):
    @staticmethod
    # pyre-fixme
    def forward(
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        x: torch.Tensor,
        y: torch.Tensor,
        x_lengths: torch.Tensor,
        x_offsets: torch.Tensor,
        y_lengths: torch.Tensor,
        y_offsets: torch.Tensor,
        z_lengths: torch.Tensor,
        z_offsets: torch.Tensor,
        max_seq_len: int,
        # pyre-fixme[2]: Parameter must be annotated.
        allow_tf32,
    ):
        assert x.size(1) == y.size(0), "incompatible dimensions"

        ctx.allow_tf32 = allow_tf32
        ctx.max_seq_len = max_seq_len

        ctx.save_for_backward(
            x,
            y,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
        )

        # pyre-fixme[6]: For 1st argument expected `Sequence[Union[int, SymInt]]`
        #  but got `Tensor`.
        c = torch.rand((z_lengths.sum()), device=x.device, dtype=x.dtype)
        return c

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        (
            x,
            y,
            x_lengths,
            y_lengths,
            z_lengths,
            x_offsets,
            y_offsets,
            z_offsets,
        ) = ctx.saved_tensors

        grad_x = torch.rand(x.size(), device=x.device, dtype=x.dtype)
        grad_y = torch.rand(y.size(), device=y.device, dtype=y.dtype)
        return grad_x, grad_y, None, None, None, None, None, None, None, None


# pyre-fixme[3]: Return type must be annotated.
def meta_jagged_jagged_bmm_jagged_out(
    x: torch.Tensor,
    y: torch.Tensor,
    x_lengths: torch.Tensor,
    x_offsets: torch.Tensor,
    y_lengths: torch.Tensor,
    y_offsets: torch.Tensor,
    z_lengths: torch.Tensor,
    z_offsets: torch.Tensor,
    max_seq_len: int,
    allow_tf32: bool = True,
):
    return JaggedJaggedBmmNoPaddingMeta.apply(
        x,
        y,
        x_lengths,
        x_offsets,
        y_lengths,
        y_offsets,
        z_lengths,
        z_offsets,
        max_seq_len,
        allow_tf32,
    )
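Because these are meta implementations, they can be exercised on storage-less "meta" tensors to check shape propagation without any real computation. A small sketch (assumes fbgemm_gpu is installed; the jagged layout of 7 values over 3 rows is made up for illustration):

import torch
from fbgemm_gpu.sll.meta.meta_sll import meta_jagged_dense_elementwise_mul_jagged_out

x = torch.empty(7, device="meta")       # flattened jagged values
y = torch.empty(3, 4, device="meta")    # dense operand
x_seq_lengths = torch.empty(3, dtype=torch.int64, device="meta")
x_offsets = torch.empty(4, dtype=torch.int64, device="meta")

out = meta_jagged_dense_elementwise_mul_jagged_out(
    x, y, x_seq_lengths, x_offsets, max_seq_len=4
)
print(out.shape, out.device)  # torch.Size([7]) meta -- shape/dtype only, no values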
fbgemm_gpu/sll/triton/__init__.py
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from fbgemm_gpu.sll.triton.triton_dense_jagged_cat_jagged_out import (
    dense_jagged_cat_jagged_out,
)

from fbgemm_gpu.sll.triton.triton_jagged2_to_padded_dense import (  # noqa F401
    jagged2_to_padded_dense,
    Jagged2ToPaddedDense,  # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_bmm import (  # noqa F401
    jagged_dense_bmm,
    jagged_jagged_bmm,
    JaggedDenseBmm,  # noqa F401
    JaggedJaggedBmm,  # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import (  # noqa F401
    array_jagged_bmm_jagged_out,
    ArrayJaggedBmmNopadding,  # noqa F401
    jagged_jagged_bmm_jagged_out,
    JaggedJaggedBmmNoPadding,  # noqa F401
    triton_array_jagged_bmm_jagged_out,  # noqa F401
    triton_jagged_jagged_bmm_jagged_out,  # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import (  # noqa F401
    jagged_dense_elementwise_add,
    JaggedDenseAdd,  # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_mul_jagged_out import (  # noqa F401
    jagged_dense_elementwise_mul_jagged_out,
    JaggedDenseElementwiseMul,  # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import (  # noqa F401
    jagged_dense_flash_attention,
    JaggedDenseFlashAttention,  # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_flash_attention_basic import (  # noqa F401
    jagged_flash_attention_basic,
    JaggedFlashAttentionBasic,  # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_self_substraction_jagged_out import (
    triton_jagged_self_substraction_jagged_out,
)

from fbgemm_gpu.sll.triton.triton_jagged_softmax import (  # noqa F401
    jagged2_softmax,
    Jagged2Softmax,  # noqa F401
    jagged_softmax,
    JaggedSoftmax,  # noqa F401
)

from fbgemm_gpu.sll.triton.triton_multi_head_jagged_flash_attention import (  # noqa F401
    multi_head_jagged_flash_attention,
    MultiHeadJaggedFlashAttention,  # noqa F401
)

# pyre-ignore[5]
op_registrations = {
    "sll_dense_jagged_cat_jagged_out": {
        "CUDA": dense_jagged_cat_jagged_out,
    },
    "sll_jagged_dense_bmm": {
        "CUDA": jagged_dense_bmm,
        "AutogradCUDA": jagged_dense_bmm,
    },
    "sll_jagged_jagged_bmm": {
        "CUDA": jagged_jagged_bmm,
        "AutogradCUDA": jagged_jagged_bmm,
    },
    "sll_jagged2_to_padded_dense": {
        "CUDA": jagged2_to_padded_dense,
        "AutogradCUDA": jagged2_to_padded_dense,
    },
    "sll_array_jagged_bmm_jagged_out": {
        "CUDA": array_jagged_bmm_jagged_out,
        "AutogradCUDA": array_jagged_bmm_jagged_out,
    },
    "sll_jagged_jagged_bmm_jagged_out": {
        "CUDA": jagged_jagged_bmm_jagged_out,
        "AutogradCUDA": jagged_jagged_bmm_jagged_out,
    },
    "sll_jagged_softmax": {
        "CUDA": jagged_softmax,
        "AutogradCUDA": jagged_softmax,
    },
    "sll_jagged2_softmax": {
        "CUDA": jagged2_softmax,
        "AutogradCUDA": jagged2_softmax,
    },
    "sll_jagged_dense_elementwise_add": {
        "CUDA": jagged_dense_elementwise_add,
        "AutogradCUDA": jagged_dense_elementwise_add,
    },
    "sll_jagged_dense_flash_attention": {
        "CUDA": jagged_dense_flash_attention,
        "AutogradCUDA": jagged_dense_flash_attention,
    },
    "sll_jagged_flash_attention_basic": {
        "CUDA": jagged_flash_attention_basic,
        "AutogradCUDA": jagged_flash_attention_basic,
    },
    "sll_multi_head_jagged_flash_attention": {
        "CUDA": multi_head_jagged_flash_attention,
        "AutogradCUDA": multi_head_jagged_flash_attention,
    },
    "sll_jagged_self_substraction_jagged_out": {
        "CUDA": triton_jagged_self_substraction_jagged_out,
    },
    "sll_jagged_dense_elementwise_mul_jagged_out": {
        "CUDA": jagged_dense_elementwise_mul_jagged_out,
        "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
    },
}
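Note that most entries register the same callable under both "CUDA" and "AutogradCUDA": these callables presumably wrap the torch.autograd.Function classes imported alongside them above, so one implementation can serve the plain backend key and the autograd key alike. A self-contained toy illustration of that pattern (the toy_sll namespace and double_it op are invented for this sketch and are not part of fbgemm_gpu):

import torch

toy_lib = torch.library.Library("toy_sll", "DEF")
toy_lib.define("double_it(Tensor x) -> Tensor")

def double_it(x: torch.Tensor) -> torch.Tensor:
    return x * 2

# Same callable behind both dispatch keys, mirroring the table above
# (CPU keys used here so the sketch runs without a GPU).
toy_lib.impl("double_it", double_it, "CPU")
toy_lib.impl("double_it", double_it, "AutogradCPU")

print(torch.ops.toy_sll.double_it(torch.ones(3)))  # tensor([2., 2., 2.])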
fbgemm_gpu/sll/triton/common.py
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch


def next_power_of_two(N: int) -> int:
    if N > 4096:
        raise Exception(f"{N} is too large; sizes above 4096 are not supported yet")

    if N > 2048:
        return 4096
    elif N > 1024:
        return 2048
    elif N > 512:
        return 1024
    elif N > 256:
        return 512
    elif N > 128:
        return 256
    elif N > 64:
        return 128
    elif N > 32:
        return 64
    else:
        return 32


def expect_contiguous(x: torch.Tensor) -> torch.Tensor:
    if not x.is_contiguous():
        return x.contiguous()
    else:
        return x
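next_power_of_two rounds a Triton block dimension up to the next power of two, with a floor of 32 and a ceiling of 4096. For reference, an equivalent closed form (a sketch for illustration, not part of the package):

def next_power_of_two_closed_form(n: int) -> int:
    # Equivalent for 0 < n <= 4096: round up to a power of two, floor at 32.
    assert 0 < n <= 4096
    return max(32, 1 << (n - 1).bit_length())

assert next_power_of_two_closed_form(100) == 128
assert next_power_of_two_closed_form(32) == 32
assert next_power_of_two_closed_form(5) == 32  # floor at 32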
fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
import triton
import triton.language as tl


@triton.jit
def dense_jagged_cat_jagged_out_kernel(
    a_ptr,  # dense
    b_ptr,  # jagged
    c_ptr,  # jagged
    b_offsets_ptr,
    c_offsets_ptr,
    max_seq_len,
    BLOCK_SIZE: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    b_start = tl.load(b_offsets_ptr + pid_batch)
    b_end = tl.load(b_offsets_ptr + pid_batch + 1)
    c_start = b_start + pid_batch
    N = b_end - b_start
    N = tl.minimum(N, max_seq_len)

    a = tl.load(a_ptr + pid_batch)
    tl.store(c_ptr + c_start, a)

    offs_k = tl.arange(0, BLOCK_SIZE)
    for k in range(0, N, BLOCK_SIZE):
        b_offset = k + offs_k
        b_ptrs = b_ptr + b_start + b_offset
        b = tl.load(b_ptrs, mask=b_offset < N, other=0.0)
        tl.store(c_ptr + c_start + 1 + b_offset, b, mask=b_offset < N)
    tl.store(c_offsets_ptr + pid_batch, b_start + pid_batch)


def dense_jagged_cat_jagged_out(
    a: torch.Tensor,
    b: torch.Tensor,
    b_offsets: torch.Tensor,
    max_seq_len: int,
):
    assert a.is_contiguous()
    assert b.is_contiguous()
    assert b_offsets.is_contiguous()
    B = a.size(0)
    BLOCK_SIZE = 128
    c = torch.zeros(b.size(0) + a.size(0), dtype=a.dtype, device=a.device)
    c_offsets = torch.empty(
        b_offsets.size(0), dtype=b_offsets.dtype, device=b_offsets.device
    )  # B + 1

    dense_jagged_cat_jagged_out_kernel[(B,)](
        a,
        b,
        c,
        b_offsets,
        c_offsets,
        max_seq_len,
        # pyre-fixme[6]: For 7th argument expected `constexpr` but got `int`.
        BLOCK_SIZE,
    )

    c_offsets[-1] = b_offsets[-1] + B

    return c, c_offsets
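The kernel prepends the i-th dense scalar a[i] to the i-th jagged row of b (truncated to max_seq_len) and shifts the offsets accordingly. A worked example of the expected behavior (requires a CUDA build of fbgemm_gpu with Triton available; the numbers are illustrative):

import torch
from fbgemm_gpu.sll.triton import dense_jagged_cat_jagged_out

a = torch.tensor([10.0, 20.0], device="cuda")      # one scalar per batch row
b = torch.tensor([1.0, 2.0, 3.0], device="cuda")   # jagged rows [1, 2] and [3]
b_offsets = torch.tensor([0, 2, 3], device="cuda")

c, c_offsets = dense_jagged_cat_jagged_out(a, b, b_offsets, max_seq_len=8)
# c         == [10, 1, 2, 20, 3]  (each row gains its dense scalar as element 0)
# c_offsets == [0, 3, 5]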