fbgemm-gpu-genai-nightly 2025.12.19 (cp310-cp310-manylinux_2_28_x86_64.whl)
This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of fbgemm-gpu-genai-nightly has been flagged as potentially problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py
@@ -0,0 +1,221 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import torch
import triton
import triton.language as tl

from .common import expect_contiguous


@triton.jit
def jagged2_to_padded_dense_kernel(
    x_ptr,
    lengths_ptr,
    offsets_ptr,
    output_dense_ptr,
    stride_b,
    stride_m,
    stride_n,
    max_length,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_batch = tl.program_id(2)
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    begin = tl.load(offsets_ptr + pid_batch)
    seqlen = tl.load(lengths_ptr + pid_batch)

    seqlen = tl.minimum(seqlen, max_length)
    if seqlen == 0:
        return

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    x_ptrs = x_ptr + begin + offs_m[:, None] * seqlen + offs_n[None, :]
    x = tl.load(x_ptrs, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen)))

    out_ptrs = (
        output_dense_ptr
        + pid_batch * stride_b
        + offs_m[:, None] * stride_m
        + offs_n[None, :] * stride_n
    )
    tl.store(
        out_ptrs, x, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen))
    )


@triton.jit
def padded_dense_to_jagged2_kernel(
    x_ptr,
    lengths_ptr,
    offsets_ptr,
    output_jagged_ptr,
    stride_b,
    stride_m,
    stride_n,
    max_length,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_batch = tl.program_id(2)
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    begin = tl.load(offsets_ptr + pid_batch)
    # end = tl.load(offsets_ptr + pid_batch + 1)
    seqlen = tl.load(lengths_ptr + pid_batch)

    seqlen = tl.minimum(seqlen, max_length)

    if seqlen == 0:
        return

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    x_ptrs = (
        x_ptr
        + pid_batch * stride_b
        + offs_m[:, None] * stride_m
        + offs_n[None, :] * stride_n
    )
    x = tl.load(x_ptrs, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen)))
    out_ptrs = output_jagged_ptr + begin + offs_m[:, None] * seqlen + offs_n[None, :]
    tl.store(
        out_ptrs, x, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen))
    )


def jagged2_to_padded_dense_fwd(
    values: torch.Tensor,
    lengths: torch.Tensor,
    offsets: torch.Tensor,
    max_length: int,
    padding_value: float,
) -> torch.Tensor:
    B = offsets.size(0) - 1

    output_dense = torch.full(
        (B, max_length, max_length),
        padding_value,
        dtype=values.dtype,
        device=values.device,
    )
    BLOCK_M = 32
    BLOCK_N = 32
    num_blocks_m = triton.cdiv(max_length, BLOCK_M)
    num_blocks_n = triton.cdiv(max_length, BLOCK_N)
    grid = (num_blocks_m, num_blocks_n, B)

    jagged2_to_padded_dense_kernel[grid](
        values,
        lengths,
        offsets,
        output_dense,
        output_dense.stride(0),
        output_dense.stride(1),
        output_dense.stride(2),
        max_length,
        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
        BLOCK_M,
        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
        BLOCK_N,
    )

    return output_dense


def padded_dense_to_jagged2_fwd(
    values: torch.Tensor,
    lengths: torch.Tensor,
    offsets: torch.Tensor,
    max_length: int,
) -> torch.Tensor:
    B = values.size(0)
    output_jagged = torch.empty(
        int(offsets[-1]), dtype=values.dtype, device=values.device
    )
    BLOCK_M = 32
    BLOCK_N = 32
    num_blocks_m = triton.cdiv(max_length, BLOCK_M)
    num_blocks_n = triton.cdiv(max_length, BLOCK_N)
    grid = (num_blocks_m, num_blocks_n, B)

    padded_dense_to_jagged2_kernel[grid](
        values,
        lengths,
        offsets,
        output_jagged,
        values.stride(0),
        values.stride(1),
        values.stride(2),
        max_length,
        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
        BLOCK_M,
        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
        BLOCK_N,
    )

    return output_jagged


class Jagged2ToPaddedDense(torch.autograd.Function):
    @staticmethod
    # pyre-fixme
    def forward(
        ctx,
        values: torch.Tensor,
        offsets: torch.Tensor,
        max_length: int,
        padding_value: float,
    ) -> torch.Tensor:
        lengths_square = offsets[1:] - offsets[0:-1:1]
        lengths = torch.sqrt(lengths_square).to(torch.int32)

        ctx.max_length = max_length
        ctx.save_for_backward(lengths, offsets)

        output = jagged2_to_padded_dense_fwd(
            values, lengths, offsets, max_length, padding_value
        )
        return output

    @staticmethod
    # pyre-fixme
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> tuple[torch.Tensor, None, None, None]:
        max_length = ctx.max_length
        (lengths, offsets) = ctx.saved_tensors
        grad_in = padded_dense_to_jagged2_fwd(grad_output, lengths, offsets, max_length)
        return (grad_in, None, None, None)


def jagged2_to_padded_dense(
    values: torch.Tensor,
    offsets: torch.Tensor,
    max_length: int,
    padding_value: float = 0.0,
) -> torch.Tensor:
    """
    values: jagged tensor with size [sum(Ni * Ni)]
    offsets: offsets for jagged tensor, with size [B + 1]
    max_length: maximum sequence length in the batch
    padding_value: value to use for padding
    return padded dense tensor of size [B, N, N]
    """
    values = expect_contiguous(values)
    offsets = expect_contiguous(offsets)

    return Jagged2ToPaddedDense.apply(values, offsets, max_length, padding_value)
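For orientation, here is a minimal usage sketch of the jagged2_to_padded_dense entry point defined above. It is not part of the package: the import path is assumed from the wheel's file layout (fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py), and a CUDA device with a working Triton install is required. Each batch element is an Ni x Ni block stored flattened in values, so offsets holds the cumulative sums of Ni * Ni, which is how the forward pass recovers the per-block side length via a square root.

# Hypothetical usage sketch: pad two square jagged blocks (2x2 and 3x3)
# into a [B, max_length, max_length] dense tensor.
import torch

from fbgemm_gpu.sll.triton.triton_jagged2_to_padded_dense import (  # assumed module path
    jagged2_to_padded_dense,
)

# Block sizes Ni = [2, 3]; offsets are cumulative Ni * Ni: [0, 4, 13].
offsets = torch.tensor([0, 2 * 2, 2 * 2 + 3 * 3], device="cuda")
values = torch.randn(int(offsets[-1]), device="cuda")  # flattened Ni x Ni blocks

dense = jagged2_to_padded_dense(values, offsets, max_length=3, padding_value=0.0)
print(dense.shape)  # torch.Size([2, 3, 3])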
fbgemm_gpu/sll/triton/triton_jagged_bmm.py
@@ -0,0 +1,418 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
import triton
import triton.language as tl


def set_block_size(N: int) -> int:
    if N > 64:
        return 64
    elif N > 16:
        return 32
    else:
        return 16


# TODO add autotune to find best block size
# add supergroup to optimize GPU cache
@triton.jit
def jagged_dense_bmm_kernel(
    a_ptr,
    a_offset_ptr,
    b_ptr,
    c_ptr,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bl,  # batch idx
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    max_seq_len,  # max sequence length for jagged tensor
    allow_tf32: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (sum_B(M_i), K), B has shape (B, K, N) and C has shape (sum_B(M_i), N)
    """
    pid_batch = tl.program_id(0)
    pid = tl.program_id(1)

    # a_offset_ptr has stride of 1
    # row_start for jagged tensor
    begin = tl.load(a_offset_ptr + pid_batch)
    end = tl.load(a_offset_ptr + pid_batch + 1)
    M = tl.minimum(end - begin, max_seq_len)  # in case M > max seq len
    if M == 0:
        return

    # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

    # if pid_m * BLOCK_SIZE_M >= M, then this block doesn't need to be computed
    if pid_m * BLOCK_SIZE_M >= M:
        return

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    if pid_n * BLOCK_SIZE_N >= N:
        return

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + begin * stride_am
    )  # jagged tensor ptr
    b_ptrs = b_ptr + (
        offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
        + pid_batch * stride_bl
    )  # dense tensor ptr

    c = tl.zeros(
        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
    )  # TODO, make this flexible

    # Compute c[m, n] for 1 example of the batch
    for k in range(0, K, BLOCK_SIZE_K):
        updated_offset = k + offs_k
        a = tl.load(
            a_ptrs,
            # pyre-fixme[16]: `int` has no attribute `__getitem__`.
            mask=(updated_offset[None, :] < K) & (offs_am[:, None] < M),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(updated_offset[:, None] < K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        c += tl.dot(a, b, allow_tf32=allow_tf32)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    c_ptrs = (
        c_ptr
        + stride_cm * offs_m[:, None]
        + stride_cn * offs_n[None, :]
        + begin * stride_cm
    )
    tl.store(c_ptrs, c, mask=mask)


@triton.jit
def jagged_jagged_bmm_kernel(
    a_ptr,
    a_offset_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cl,
    stride_cm,
    stride_cn,
    max_seq_len,
    allow_tf32: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """
    Kernel for computing the matmul C = A x B.
    A has shape (M, sum_B(Ki)), B has shape (sum_B(Ki), N) and C has shape (B, M, N)
    """
    pid_batch = tl.program_id(0)
    pid = tl.program_id(1)

    # need to make sure a_offset_ptr has stride of 1
    begin = tl.load(a_offset_ptr + pid_batch)
    end = tl.load(a_offset_ptr + pid_batch + 1)
    K = end - begin  # K for current pid_batch
    K = tl.minimum(K, max_seq_len)
    # if K == 0:
    #     return

    # calculate pid_m and pid_n
    # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = (
        a_ptr
        + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        + begin * stride_ak
    )
    b_ptrs = (
        b_ptr
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        + begin * stride_bk
    )

    c = tl.zeros(
        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
    )  # TODO, make this flexible
    for k in range(0, K, BLOCK_SIZE_K):
        updated_offset = k + offs_k
        a = tl.load(
            a_ptrs,
            # pyre-fixme[16]: `int` has no attribute `__getitem__`.
            mask=((updated_offset[None, :] < K) & (offs_am[:, None] < M)),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=((updated_offset[:, None] < K) & (offs_bn[None, :] < N)),
            other=0.0,
        )
        c += tl.dot(a, b, allow_tf32=allow_tf32)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    c_ptrs = (
        c_ptr
        + stride_cm * offs_m[:, None]
        + stride_cn * offs_n[None, :]
        + stride_cl * pid_batch
    )

    tl.store(c_ptrs, c, mask=mask)


def triton_jagged_dense_bmm(a, b, a_offsets, max_seq_len, allow_tf32):
    # check constraints
    assert a.shape[1] == b.shape[1], "incompatible dimensions"
    assert a_offsets.is_contiguous(), "A offsets must be contiguous"
    sum_B, K = a.shape
    B, K, N = b.shape
    # Use zeros instead of empty to handle the corner case when the jagged tensor has length > max seq len
    # In that case, it is possible that the output is inconsistent with the padded version if empty is used
    c = a.new_zeros((sum_B, N))

    BLOCK_SIZE_M = 32 if max_seq_len < 50 else 64
    BLOCK_SIZE_N = set_block_size(N)
    BLOCK_SIZE_K = set_block_size(K)

    # 2D launch kernel where each block gets its own program.
    # TODO, is this the best way to handle the launch grid?
    # The grid count on the M axis is often larger than required due to max_seq_len
    grid = (
        B,
        triton.cdiv(max_seq_len, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),
    )

    jagged_dense_bmm_kernel[grid](
        a,
        a_offsets,
        b,
        c,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        b.stride(2),
        c.stride(0),
        c.stride(1),
        max_seq_len,
        allow_tf32,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
        BLOCK_SIZE_K,
    )
    return c


def triton_jagged_jagged_bmm(a, b, a_offsets, max_seq_len, allow_tf32):
    # check constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    assert a_offsets.is_contiguous(), "A offsets must be contiguous"
    M, _ = a.shape
    _, N = b.shape
    B = a_offsets.size(0) - 1
    # allocate output
    c = torch.empty((B, M, N), device=a.device, dtype=a.dtype)
    # 2D launch kernel where each block gets its own program.
    BLOCK_SIZE_M = set_block_size(M)
    BLOCK_SIZE_N = set_block_size(N)
    BLOCK_SIZE_K = 32
    grid = (
        B,
        triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),
    )
    jagged_jagged_bmm_kernel[grid](
        a,
        a_offsets,
        b,
        c,
        M,
        N,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        c.stride(2),
        max_seq_len,
        allow_tf32,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
        BLOCK_SIZE_K,
    )
    return c


class JaggedDenseBmm(torch.autograd.Function):
    """
    Compute batch matrix multiplication between JaggedTensor and dense tensor
    dense: [B, N, D] * [B, D, T] = [B, N, T]
    jagged: [Sum_B, D] * [B, D, T] = [Sum_B, T]
    """

    @staticmethod
    # pyre-fixme
    def forward(
        ctx,
        x: torch.Tensor,
        y: torch.Tensor,
        x_offsets: torch.Tensor,
        N: int,
        allow_tf32: bool,
    ):
        ctx.save_for_backward(x, y, x_offsets)
        ctx.N = N
        ctx.allow_tf32 = allow_tf32
        return triton_jagged_dense_bmm(x, y, x_offsets, N, allow_tf32=allow_tf32)

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        """
        # X = [Sum_B, D]
        # Y = [B, D, T]
        # Z = X * Y = [Sum_B, T]
        # dX = dZ * YT # [Sum_B, T] * [B, T, D] = [Sum_B, D]
        # dY = XT * dZ # [D, sum_B] * [sum_B, T] = [D, B, T]
        """

        # logging.info(f"Jagged bmm backward called")

        (x, y, x_offsets) = ctx.saved_tensors
        N = ctx.N
        grad_x = triton_jagged_dense_bmm(
            grad_output, y.permute(0, 2, 1), x_offsets, N, allow_tf32=ctx.allow_tf32
        )
        grad_y = triton_jagged_jagged_bmm(
            x.T, grad_output, x_offsets, N, allow_tf32=ctx.allow_tf32
        )
        return grad_x, grad_y, None, None, None


class JaggedJaggedBmm(torch.autograd.Function):
    """
    Compute batch matrix multiplication between JaggedTensor and Jagged Tensor
    dense: [B, D, N] * [B, N, T] = [B, D, T]
    jagged: [Sum_B, D].T * [Sum_B, T] = [B, D, T]
    """

    @staticmethod
    # pyre-fixme
    def forward(
        ctx,
        x: torch.Tensor,
        y: torch.Tensor,
        x_offsets: torch.Tensor,
        N: int,
        allow_tf32,
    ):
        ctx.save_for_backward(x, y, x_offsets)
        ctx.N = N
        ctx.allow_tf32 = allow_tf32
        return triton_jagged_jagged_bmm(x.T, y, x_offsets, N, allow_tf32=allow_tf32)

    @staticmethod
    # pyre-fixme
    def backward(ctx, grad_output: torch.Tensor):
        """
        # X = [Sum_B, D]
        # Y = [Sum_B, T]
        # Z = XT * Y = [B, D, T]
        # dXT = dZ * YT -> dX = Y * dZT
        # dY = X * dZ -> X * dZ
        """
        (x, y, offsets) = ctx.saved_tensors
        N = ctx.N
        grad_x = triton_jagged_dense_bmm(
            y, grad_output.permute(0, 2, 1), offsets, N, allow_tf32=ctx.allow_tf32
        )
        grad_y = triton_jagged_dense_bmm(
            x, grad_output, offsets, N, allow_tf32=ctx.allow_tf32
        )
        return grad_x, grad_y, None, None, None


def jagged_dense_bmm(
    x: torch.Tensor,
    y: torch.Tensor,
    x_offsets: torch.Tensor,
    N: int,
    allow_tf32: bool,
    use_fbgemm_kernel: bool = True,
) -> torch.Tensor:
    """
    Compute batch matrix multiplication between JaggedTensor and dense tensor
    dense: [B, N, D] * [B, D, T] = [B, N, T]
    jagged: [Sum_B, D] * [B, D, T] = [Sum_B, T]
    """
    if use_fbgemm_kernel:
        return torch.ops.fbgemm.jagged_dense_bmm(x, x_offsets, y, N)[0]
    else:
        return JaggedDenseBmm.apply(x, y, x_offsets, N, allow_tf32)


def jagged_jagged_bmm(
    x: torch.Tensor,
    y: torch.Tensor,
    x_offsets: torch.Tensor,
    N: int,
    allow_tf32: bool,
    use_fbgemm_kernel: bool = True,
):
    """
    Compute batch matrix multiplication between JaggedTensor and Jagged Tensor
    dense: [B, D, N] * [B, N, T] = [B, D, T]
    jagged: [Sum_B, D].T * [Sum_B, T] = [B, D, T]
    """
    if use_fbgemm_kernel:
        return torch.ops.fbgemm.jagged_jagged_bmm(x, y, x_offsets, N)
    else:
        return JaggedJaggedBmm.apply(x, y, x_offsets, N, allow_tf32)
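As a companion to the wrappers above, here is a minimal usage sketch of jagged_dense_bmm taking the pure-Triton path (use_fbgemm_kernel=False). It is not part of the package: the import path is assumed from the wheel's file layout (fbgemm_gpu/sll/triton/triton_jagged_bmm.py), and a CUDA device with Triton is required. x_offsets holds the cumulative row counts of the jagged operand, and the N argument is the max-sequence-length cap forwarded to the kernel; with the default use_fbgemm_kernel=True the wrapper dispatches to the compiled torch.ops.fbgemm.jagged_dense_bmm operator instead.

# Hypothetical usage sketch: jagged [sum_B(M_i), D] times dense [B, D, T]
# gives a jagged [sum_B(M_i), T] result, one [M_i, T] block per batch element.
import torch

from fbgemm_gpu.sll.triton.triton_jagged_bmm import jagged_dense_bmm  # assumed module path

# Per-batch row counts M_i = [3, 5], so cumulative offsets are [0, 3, 8].
x_offsets = torch.tensor([0, 3, 8], device="cuda")
D, T = 16, 4
x = torch.randn(8, D, device="cuda")   # jagged values, 8 = sum(M_i)
y = torch.randn(2, D, T, device="cuda")  # one [D, T] matrix per batch element

out = jagged_dense_bmm(
    x, y, x_offsets, N=8, allow_tf32=False, use_fbgemm_kernel=False
)
print(out.shape)  # torch.Size([8, 4])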