fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,824 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
# pyre-ignore-all-errors[6]
|
|
11
|
+
|
|
12
|
+
from typing import Optional, Union
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
import triton # @manual
|
|
16
|
+
import triton.language as tl # @manual
|
|
17
|
+
from torch._tensor import Tensor
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Elementwise add or multiply of two equally-shaped (M, N) value tensors.
#
# The (M, N) index space is cut into tiles of
# thread_block_row_size x thread_block_col_size; program ``pid`` owns tile
# (pid // col_tiles, pid % col_tiles) where col_tiles = ceil(N / tile width).
# x, y and output are all addressed with the same (stride_row, stride_col)
# pair, so they must share one memory layout. ``ops_func == "add"`` selects
# addition; any other value selects multiplication.
@triton.jit
def jagged_jagged_elementwise_arithmetic_ops(
    # pyre-fixme[2]: Parameter must be annotated.
    x_ptr,  # pointer to the first operand's values
    # pyre-fixme[2]: Parameter must be annotated.
    y_ptr,  # pointer to the second operand's values
    M: tl.constexpr,  # number of rows
    N: tl.constexpr,  # number of columns
    stride_row: tl.constexpr,  # row stride shared by x, y and output
    stride_col: tl.constexpr,  # column stride shared by x, y and output
    # pyre-fixme[2]: Parameter must be annotated.
    output,  # pointer to the result values
    thread_block_row_size: tl.constexpr,  # tile height
    thread_block_col_size: tl.constexpr,  # tile width
    ops_func: tl.constexpr,  # "add" adds; anything else multiplies
) -> None:
    tile_id = tl.program_id(0)

    # Tiles along the column axis, i.e. ceil(N / thread_block_col_size).
    col_tiles = (N + thread_block_col_size - 1) // thread_block_col_size
    tile_row = tile_id // col_tiles
    tile_col = tile_id % col_tiles

    rows = tile_row * thread_block_row_size + tl.arange(0, thread_block_row_size)
    cols = tile_col * thread_block_col_size + tl.arange(0, thread_block_col_size)
    in_bounds = (rows[:, None] < M) & (cols[None, :] < N)
    elem_offset = rows[:, None] * stride_row + cols[None, :] * stride_col

    lhs = tl.load(x_ptr + elem_offset, mask=in_bounds)
    rhs = tl.load(y_ptr + elem_offset, mask=in_bounds)

    if ops_func == "add":
        res = tensor_elementwise_add(lhs, rhs)
    else:
        res = tensor_elementwise_mul(lhs, rhs)

    tl.store(output + elem_offset, res, mask=in_bounds)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Device-side helper: elementwise sum of two Triton values/blocks.
@triton.jit
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def tensor_elementwise_add(x, y):
    total = x + y
    return total
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# Device-side helper: elementwise product of two Triton values/blocks.
@triton.jit
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def tensor_elementwise_mul(x, y):
    product = x * y
    return product
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def triton_jagged_add_jagged(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``x + y`` for two equally-shaped 2D value tensors via Triton.

    Args:
        x: (M, N) tensor on a CUDA device.
        y: tensor with the same shape as ``x``.

    Returns:
        A new contiguous (M, N) tensor of ``x.dtype`` on ``x.device``.

    Raises:
        AssertionError: if ``x`` and ``y`` differ in shape.
    """
    # x and y need to have the same shape to be added elementwise
    assert x.shape == y.shape, f"shape mismatch: {x.shape} vs {y.shape}"

    # The kernel addresses x, y and the output with one shared pair of
    # strides (x's), so all three must have the same layout. contiguous() is
    # a no-op for inputs that already are contiguous.
    x = x.contiguous()
    y = y.contiguous()

    thread_block_row_size = 32
    thread_block_col_size = 32

    M, N = x.shape

    # Allocate on x's device rather than the default CUDA device.
    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
    if output.numel() == 0:
        return output

    # One program per (thread_block_row_size x thread_block_col_size) tile.
    grid = (
        triton.cdiv(M, thread_block_row_size) * triton.cdiv(N, thread_block_col_size),
    )

    jagged_jagged_elementwise_arithmetic_ops[grid](
        x,
        y,
        M,
        N,
        x.stride(0),
        x.stride(1),
        output,
        thread_block_row_size,
        thread_block_col_size,
        ops_func="add",
    )

    return output
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def triton_jagged_mul_jagged(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``x * y`` (elementwise) for two equally-shaped 2D value tensors via Triton.

    Args:
        x: (M, N) tensor on a CUDA device.
        y: tensor with the same shape as ``x``.

    Returns:
        A new contiguous (M, N) tensor of ``x.dtype`` on ``x.device``.

    Raises:
        AssertionError: if ``x`` and ``y`` differ in shape.
    """
    # x and y need to have the same shape to be multiplied elementwise
    assert x.shape == y.shape, f"shape mismatch: {x.shape} vs {y.shape}"

    # The kernel addresses x, y and the output with one shared pair of
    # strides (x's), so all three must have the same layout. contiguous() is
    # a no-op for inputs that already are contiguous.
    x = x.contiguous()
    y = y.contiguous()

    thread_block_row_size = 32
    thread_block_col_size = 32

    M, N = x.shape

    # Allocate on x's device rather than the default CUDA device.
    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
    if output.numel() == 0:
        return output

    # One program per (thread_block_row_size x thread_block_col_size) tile.
    grid = (
        triton.cdiv(M, thread_block_row_size) * triton.cdiv(N, thread_block_col_size),
    )

    jagged_jagged_elementwise_arithmetic_ops[grid](
        x,
        y,
        M,
        N,
        x.stride(0),
        x.stride(1),
        output,
        thread_block_row_size,
        thread_block_col_size,
        ops_func="mul",
    )

    return output
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# Batched vector x jagged-matrix product: bmm([B * H, 1, N], [B * H, N, D]).
#
# The jagged values are stored as [total_rows, H * D]; head h of batch b uses
# columns [h * D, (h + 1) * D) of rows [offsets[b], offsets[b + 1]).
# Program ``pid`` computes output row ``pid // col_tiles`` restricted to
# column tile ``pid % col_tiles`` (col_tiles = ceil(D / thread_block_col_size)),
# accumulating in float32.
@triton.jit
def triton_batched_dense_vec_jagged_2d_matmul(
    # pyre-fixme[2]: Parameter must be annotated.
    jagged_tensor_ptr,  # jagged values; rows are jagged_value_row_stride apart
    # pyre-fixme[2]: Parameter must be annotated.
    dense_ptr,  # dense vectors, one row per output row
    # pyre-fixme[2]: Parameter must be annotated.
    jagged_offset,  # row offsets into the jagged values, one more than the batch size
    thread_block_col_size: tl.constexpr,  # output columns handled per program
    # pyre-fixme[2]: Parameter must be annotated.
    dense_row_stride,  # row stride of dense; also caps the rows summed per batch
    # pyre-fixme[2]: Parameter must be annotated.
    jagged_value_row_stride,
    # pyre-fixme[2]: Parameter must be annotated.
    D,  # per-head output width
    H: tl.constexpr,  # number of heads
    # pyre-fixme[2]: Parameter must be annotated.
    output_ptr,  # output rows of width D, row stride D
) -> None:
    program_id = tl.program_id(0)

    col_tiles = (D + thread_block_col_size - 1) // thread_block_col_size
    out_row = program_id // col_tiles
    col_tile = program_id % col_tiles

    # Output row index decomposes into (batch, head) because rows are B * H.
    batch = out_row // H
    head = out_row % H

    cols = col_tile * thread_block_col_size + tl.arange(0, thread_block_col_size)
    col_mask = cols < D

    row_begin = tl.load(jagged_offset + batch)
    row_end = tl.load(jagged_offset + (batch + 1))
    # Rows beyond the dense vector length are ignored.
    n_rows = tl.minimum(row_end - row_begin, dense_row_stride)

    vec_ptr = dense_ptr + out_row * dense_row_stride
    mat_ptr = jagged_tensor_ptr + row_begin * jagged_value_row_stride + head * D

    acc = tl.zeros((thread_block_col_size,), dtype=tl.float32)
    for k in range(n_rows):
        coeff = tl.load(vec_ptr + k)
        row_vals = tl.load(mat_ptr + cols, mask=col_mask, other=0.0)
        acc += coeff * row_vals
        mat_ptr += jagged_value_row_stride

    tl.store(output_ptr + D * out_row + cols, acc, mask=col_mask)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def batched_dense_vec_jagged_2d_matmul(
    dense: torch.Tensor,
    values: torch.Tensor,
    offset: torch.Tensor,
) -> torch.Tensor:
    """Batched vector x jagged-matrix product (a jagged ``torch.bmm``).

    Logically computes ``[B * H, 1, N] @ [B * H, N, D] -> [B * H, 1, D]``:

    * ``dense`` has shape ``[B * H, N]`` where N is the maximum sequence
      length; each row is treated as a ``[1, N]`` vector.
    * ``values`` is a 2D jagged tensor of shape ``[total_rows, H * D]``;
      the rows of batch ``b`` are ``values[offset[b]:offset[b + 1]]`` and head
      ``h`` uses columns ``[h * D, (h + 1) * D)``, giving a logical
      ``[B * H, N, D]`` operand.

    Sequence rows beyond N are ignored; missing rows contribute nothing.
    See https://pytorch.org/docs/stable/generated/torch.bmm.html.

    Args:
        dense: ``[B * H, N]`` tensor on a CUDA device.
        values: ``[total_rows, H * D]`` jagged values on the same device.
        offset: ``[B + 1]`` row offsets into ``values``.

    Returns:
        ``[B * H, D]`` tensor of ``values.dtype`` on ``values.device``.

    Raises:
        AssertionError: if ``dense.size(0)`` is not a multiple of B or
            ``values.size(-1)`` is not a multiple of H.
        ZeroDivisionError: if ``offset`` has fewer than two entries or H is 0.
    """
    B = offset.size(0) - 1
    assert dense.size(0) % B == 0, "dense.size(0) must be a multiple of the batch size"
    H = dense.size(0) // B
    assert values.size(-1) % H == 0, "values.size(-1) must be a multiple of H"
    D = values.size(-1) // H
    thread_block_col_size = 32

    # The kernel reads each dense row as a unit-stride vector whose length is
    # taken from dense.stride(0), and reads value rows with unit column
    # stride; force a contiguous layout (a no-op for contiguous inputs).
    dense = dense.contiguous()
    values = values.contiguous()

    output_dense = torch.empty((B * H, D), device=values.device, dtype=values.dtype)
    if output_dense.numel() == 0:
        return output_dense

    # One program per (output row, column tile of width thread_block_col_size).
    grid = (B * H * triton.cdiv(D, thread_block_col_size),)

    triton_batched_dense_vec_jagged_2d_matmul[grid](
        values,
        dense,
        offset,
        thread_block_col_size,
        dense.stride(0),
        values.stride(0),
        D,
        H,
        output_dense,
    )

    return output_dense
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
# Copy one row-group of a jagged tensor into an already-allocated dense tensor,
# optionally fusing an elementwise add/mul with a second dense operand.
#
# Program ``pid`` copies the jagged rows ``[offsets[pid], offsets[pid + 1])``
# into a 2D slab of the output, in tiles of
# thread_block_row_size x thread_block_col_size. Dense positions without a
# corresponding jagged value are never written by this kernel, so the output
# must already contain the padding value before launch.
@triton.jit
def triton_jagged_to_dense(
    # Triton currently only supports tl.constexpr annotations on kernel parameters.
    # pyre-fixme[2]: Parameter must be annotated.
    jagged_value_ptr,  # base pointer of the jagged values
    # pyre-fixme[2]: Parameter must be annotated.
    jagged_offsets_ptr,  # innermost offsets; entries pid and pid + 1 bound this program's rows
    # pyre-fixme[2]: Parameter must be annotated.
    jagged_value_row_stride,  # element stride between consecutive jagged rows
    # pyre-fixme[2]: Parameter must be annotated.
    output_dense_ptr,
    # pyre-fixme[2]: Parameter must be annotated.
    dense_indices_ptr,  # per-pid element offset into the dense (-1 = truncated); read only when JAGGED_DIM > 2
    # pyre-fixme[2]: Parameter must be annotated.
    dense_col_stride,  # column stride of the 2D slab written by each program
    # pyre-fixme[2]: Parameter must be annotated.
    dense_row_stride,  # row stride of that slab; also bounds the column count via N below
    # pyre-fixme[2]: Parameter must be annotated.
    dense_matrix_stride,  # stride between slabs; slab rows = dense_matrix_stride // dense_row_stride
    JAGGED_DIM: tl.constexpr,  # number of dimensions of the jagged tensor
    thread_block_row_size: tl.constexpr,
    thread_block_col_size: tl.constexpr,
    operation_function: tl.constexpr,  # fused op: None, "add", or any other value (multiply)
    # pyre-fixme[2]: Parameter must be annotated.
    operation_dense,  # second operand of the fused op; assumed to share the output's layout
) -> None:
    pid = tl.program_id(0)

    # begin and end index of this program's rows in the jagged values
    begin = tl.load(jagged_offsets_ptr + pid)
    end = tl.load(jagged_offsets_ptr + (pid + 1))

    # move the jagged pointer to the first row of this group
    jagged_value_ptr += begin * jagged_value_row_stride

    # For a 2D (or 1D) jagged tensor there is only one offsets level, so the
    # destination slab is simply pid * dense_matrix_stride. For higher
    # dimensions the destination offset was precomputed (on the host) and is
    # read from dense_indices_ptr.
    if JAGGED_DIM > 2:
        # read the precomputed dense element offset for this program
        dense_indice = tl.load(dense_indices_ptr + pid)

        # -1 marks a truncated group (outside the dense extent). Nothing is
        # written, so the output keeps whatever it was initialized with.
        if dense_indice == -1:
            return

        # move the output pointer to the destination slab
        output_dense_ptr += dense_indice

        # the fused-op operand uses the same offset because it is assumed to
        # have the same shape/layout as the output dense
        if operation_function is not None:
            operation_dense += dense_indice
    else:
        output_dense_ptr += pid * dense_matrix_stride

        if operation_function is not None:
            operation_dense += pid * dense_matrix_stride

    offset_row = tl.arange(0, thread_block_row_size)

    # Mask bounds: the dense slab may be smaller or larger than the jagged
    # group in either dimension. N bounds columns, M bounds rows.
    N = tl.minimum(dense_row_stride, jagged_value_row_stride)
    M = tl.minimum(dense_matrix_stride // dense_row_stride, end - begin)

    # one iteration per tile of thread_block_row_size jagged rows
    for _i in range(begin, end, thread_block_row_size):
        offset_col = tl.arange(0, thread_block_col_size)
        # NOTE(review): the same offsets address both the jagged values and
        # the dense slab, i.e. jagged rows are stepped by dense_row_stride
        # rather than jagged_value_row_stride. This appears to assume the two
        # strides are equal (contiguous inputs) — confirm with callers.
        block_offset = (
            offset_row[:, None] * dense_row_stride
            + offset_col[None, :] * dense_col_stride
        )
        for _j in range(0, N, thread_block_col_size):
            mask = (offset_row[:, None] < M) & (offset_col[None, :] < N)
            jagged_val = tl.load(jagged_value_ptr + block_offset, mask=mask, other=0)

            # fused elementwise op: "add" adds, any other non-None value multiplies
            if operation_function is not None:
                val1 = jagged_val
                val2 = tl.load(operation_dense + block_offset, mask=mask, other=0)
                # do the arithmetic operation
                if operation_function == "add":
                    jagged_val = tensor_elementwise_add(val1, val2)
                else:
                    jagged_val = tensor_elementwise_mul(val1, val2)

            # store the result (masked: out-of-range positions are untouched)
            tl.store(output_dense_ptr + block_offset, jagged_val, mask=mask)

            # advance to the next column tile
            # NOTE(review): block_offset advances by thread_block_col_size
            # elements, which matches offset_col only if dense_col_stride == 1.
            offset_col += thread_block_col_size
            block_offset += thread_block_col_size
        # advance to the next row tile
        offset_row += thread_block_row_size
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
# Convert one row-group of a 2D jagged tensor to dense, writing padding too.
#
# The output dense is 3D; program ``pid`` owns slab ``pid`` of shape
# [dense_row_size, dense_col_size]. Every position of that slab is written:
# with the jagged value where the jagged row/column exists, otherwise with
# ``padded_value``. triton_jagged_to_dense, by contrast, only copies existing
# jagged values and needs the dense to be pre-filled with padding; writing the
# padding here lets the caller allocate the output uninitialized and avoids a
# separate fill pass.
#
# When ``operation_function`` is not None, the stored value is
# ``op(operation_dense[...], value)`` where ``op`` is addition for "add",
# multiplication for "mul", or ``operation_function`` itself when a
# @triton.jit binary function is passed. (Previously the constexpr was always
# called, which cannot work for the "add"/"mul" strings used elsewhere in this
# module.)
@triton.jit
def triton_jagged_to_dense_optimization_2d(
    # pyre-fixme[2]: Parameter must be annotated.
    input_jagged_values_ptr,
    # pyre-fixme[2]: Parameter must be annotated.
    input_jagged_offset_ptr,  # entries pid and pid + 1 bound this program's rows
    # pyre-fixme[2]: Parameter must be annotated.
    input_jagged_row_stride,  # row stride of the jagged values (also used as their column count)
    # pyre-fixme[2]: Parameter must be annotated.
    output_dense_ptr,
    # pyre-fixme[2]: Parameter must be annotated.
    output_dense_row_stride,  # row stride of each slab (also its column count)
    # pyre-fixme[2]: Parameter must be annotated.
    output_dense_matrix_stride,  # stride between slabs
    thread_block_row_size: tl.constexpr,
    thread_block_col_size: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    padded_value,  # value stored where no jagged element exists
    operation_function: tl.constexpr,  # None, "add", "mul", or a @triton.jit binary function
    # pyre-fixme[2]: Parameter must be annotated.
    operation_dense,  # second operand of the fused op; same layout as the output
) -> None:
    pid = tl.program_id(0)

    # index of the current row-group in the offsets array
    offset_idx = pid

    # begin and end index of this group's rows in the jagged values
    begin = tl.load(input_jagged_offset_ptr + offset_idx)
    end = tl.load(input_jagged_offset_ptr + offset_idx + 1)

    # number of jagged rows in the current group
    cur_jagged_tensor_row_size = end - begin

    # move the dense and jagged pointers to this group's data
    output_dense_ptr += pid * output_dense_matrix_stride
    input_jagged_values_ptr += begin * input_jagged_row_stride

    # the fused-op operand is assumed to have the same layout as the output,
    # so it is offset identically
    if operation_function is not None:
        operation_dense += pid * output_dense_matrix_stride

    # jagged tensor row block
    offset_row = tl.arange(0, thread_block_row_size)

    # dense slab extents; jagged values and dense share the column block
    # because the embedding dimension is the same
    dense_col_size = output_dense_row_stride
    dense_row_size = output_dense_matrix_stride // output_dense_row_stride

    for _i in range(0, dense_row_size, thread_block_row_size):
        offset_col = tl.arange(0, thread_block_col_size)
        # NOTE: the same offsets address the jagged values and the dense,
        # which assumes input_jagged_row_stride == output_dense_row_stride.
        block_offset = (
            offset_row[:, None] * output_dense_row_stride + offset_col[None, :]
        )

        for _j in range(0, dense_col_size, thread_block_col_size):
            # boundary masks: dense_mask covers the whole slab, jagged_mask
            # only the positions that hold real jagged data
            dense_mask = (offset_row[:, None] < dense_row_size) & (
                offset_col[None, :] < dense_col_size
            )
            jagged_mask = (offset_row[:, None] < cur_jagged_tensor_row_size) & (
                offset_col[None, :] < input_jagged_row_stride
            )

            # jagged value, or padded_value where the jagged data ends
            jagged_val = tl.load(
                input_jagged_values_ptr + block_offset,
                mask=jagged_mask,
                other=padded_value,
            )

            # fused elementwise op, resolved at compile time from the constexpr
            if operation_function is not None:
                operation_dense_val = tl.load(
                    operation_dense + block_offset, mask=dense_mask, other=0.0
                )
                if operation_function == "add":
                    jagged_val = tensor_elementwise_add(operation_dense_val, jagged_val)
                elif operation_function == "mul":
                    jagged_val = tensor_elementwise_mul(operation_dense_val, jagged_val)
                else:
                    jagged_val = operation_function(operation_dense_val, jagged_val)

            # write every in-slab position (value or padding)
            tl.store(output_dense_ptr + block_offset, jagged_val, mask=dense_mask)

            # advance to the next column tile
            offset_col += thread_block_col_size
            block_offset += thread_block_col_size
        # advance to the next row tile
        offset_row += thread_block_row_size
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
# Map jagged-tensor offsets to element offsets in the corresponding dense tensor.
# Design note: https://fb.quip.com/gnzpA7d13vqO
# FBGEMM CUDA implementation: https://www.internalfb.com/code/fbsource/[308212b2902c3182edcb5b204768321e032e8175]/fbcode/deeplearning/fbgemm/fbgemm_gpu/src/jagged_tensor_ops.cu?lines=280
# FBGEMM computes this on the GPU; the Triton version hit compilation issues,
# so it is computed on the CPU as a workaround. It is only needed for jagged
# tensors with more than two dimensions.
def _jagged_offsets_to_dense_indice(
    offsets: list[torch.Tensor],
    dense_strides: list[int],
    dense_sizes: list[int],
    device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
    """Return, for each innermost jagged row-group, its element offset in the dense.

    For every group ``i`` of ``offsets[-1]`` the offset levels are walked from
    innermost to outermost. At each level ``j`` a binary search finds the group
    containing the current index and the position inside that group. The
    positions are combined with ``dense_strides`` into one flat element
    offset. If any position falls outside the extent given by ``dense_sizes``
    (i.e. the group is truncated), the entry is ``-1``.

    Args:
        offsets: per-level offset tensors, outermost level first.
        dense_strides: strides of the dense dims addressed by ``offsets``;
            ``dense_strides[0]`` is the outermost dim.
        dense_sizes: extents of those same dims. When empty, no range check
            is performed.
        device: device of the returned tensor. ``None`` (default) keeps the
            original behavior of returning a CUDA tensor.

    Returns:
        int32 tensor of length ``len(offsets[-1]) - 1``.
    """
    import bisect

    # Plain Python lists keep the per-element binary searches cheap; indexing
    # a tensor element by element creates a 0-d tensor on every access.
    offsets_cpu = [offset.cpu().tolist() for offset in offsets]

    num_groups = max(len(offsets_cpu[-1]) - 1, 0)
    output_offset = torch.zeros(num_groups, device="cpu", dtype=torch.int32)

    for i in range(num_groups):
        idx = i
        result = 0

        # flag to check if current offset is in the range of dense
        in_range = True
        for j in range(len(offsets_cpu) - 2, -1, -1):
            level_offsets = offsets_cpu[j]
            # Group at level j containing idx: the last start <= idx
            # (bisect_right skips empty groups with repeated offsets).
            group = bisect.bisect_right(level_offsets, idx) - 1
            cur_val = idx - level_offsets[group]

            if dense_sizes and cur_val >= dense_sizes[j + 1]:
                in_range = False
                break

            result += cur_val * dense_strides[j + 1]
            idx = group

        # The outermost index must also fit in the dense's first dimension;
        # an index equal to the extent is already out of range.
        if in_range and not (dense_sizes and idx >= dense_sizes[0]):
            output_offset[i] = result + idx * dense_strides[0]
        else:
            output_offset[i] = -1

    if device is None:
        return output_offset.cuda()
    return output_offset.to(device)
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
# Convert a jagged tensor to a dense tensor. Design notes:
# https://fb.quip.com/gnzpA7d13vqO
#
# For jagged tensors with more than two dimensions, the conversion kernel skips
# dense positions that no jagged value maps to. In that case the output is
# pre-filled with the padding value instead of being padded during conversion.
# Padding inside the Triton kernel for the multi-dimensional case hit an LLVM
# compile issue when comparing across multiple dimensions; a workaround is
# still to be investigated. The common real-world case (2D jagged tensors)
# uses a dedicated kernel and is not affected.
def jagged_to_dense(
    jagged_values: torch.Tensor,
    jagged_offsets: list[torch.Tensor],
    jagged_max_lengths: list[int],
    padding_value: float = 0.0,  # value for dense positions with no jagged data
    operation_function: Union[
        str, None
    ] = None,  # optional fused elementwise operation (add or multiplication)
    operation_dense: Union[
        torch.Tensor, None
    ] = None,  # dense operand for the fused add/mul with the output dense
) -> torch.Tensor:
    """Densify a jagged tensor, optionally fusing an add/mul with a dense operand.

    Args:
        jagged_values: Jagged values. The size of the last dimension becomes the
            innermost dense dimension, and ``stride(0)`` is passed to the kernel
            as the row stride.
        jagged_offsets: Offsets tensors, one per jagged dimension. The outer
            dense size is ``len(jagged_offsets[0]) - 1``.
        jagged_max_lengths: Sizes of the dense dimensions between the outer
            dimension and the innermost one.
        padding_value: Value for dense positions not covered by jagged data.
        operation_function: Optional fused operation forwarded to the kernel.
        operation_dense: Dense operand for ``operation_function``.

    Returns:
        A tensor of shape ``(len(jagged_offsets[0]) - 1, *jagged_max_lengths,
        jagged_values.size(-1))``. It has ``jagged_values``'s dtype and is
        allocated on ``jagged_values``'s device.
    """
    # Allocate on the input's device rather than "cuda" (the *current* CUDA
    # device), so inputs living on a non-default GPU do not trigger a
    # cross-device kernel launch.
    device = jagged_values.device

    outer_dense_size = len(jagged_offsets[0]) - 1
    inner_dense_size = jagged_values.size(-1)
    dense_shape = (
        (outer_dense_size,) + tuple(jagged_max_lengths) + (inner_dense_size,)
    )

    # Number of dimensions of the jagged tensor: one per offsets tensor, plus one.
    JAGGED_DIM = len(jagged_offsets) + 1

    thread_block_row_size = 32
    thread_block_col_size = 32

    if JAGGED_DIM > 2:
        # The multi-dimensional kernel leaves uncovered positions untouched, and
        # filling padding inside it caused a compile error, so pre-fill here.
        output_dense = torch.full(
            dense_shape,
            padding_value,
            device=device,
            dtype=jagged_values.dtype,
        )
        # Flat dense offset of each innermost jagged segment (-1 when the
        # segment falls outside the dense output). The helper returns a tensor
        # on the current CUDA device, so move it next to the output.
        dense_indices = _jagged_offsets_to_dense_indice(
            jagged_offsets,
            output_dense.stride()[:-2],
            output_dense.size()[:-2],
        ).to(device)

        # One Triton program per segment of the innermost offsets tensor.
        grid = (len(jagged_offsets[-1]) - 1,)
        triton_jagged_to_dense[grid](
            jagged_values,
            jagged_offsets[-1],
            jagged_values.stride(0),
            output_dense,
            dense_indices,
            output_dense.stride(-1),  # dense column stride
            output_dense.stride(-2),  # dense row stride
            output_dense.stride(-3),  # dense matrix stride
            JAGGED_DIM,
            thread_block_row_size,
            thread_block_col_size,
            operation_function=operation_function,
            operation_dense=operation_dense,
        )
    else:
        # The specialized 2D kernel receives padding_value itself, so the output
        # can be left uninitialized here.
        output_dense = torch.empty(
            dense_shape,
            device=device,
            dtype=jagged_values.dtype,
        )
        # One Triton program per outer (batch) row.
        grid = (output_dense.size(0),)
        triton_jagged_to_dense_optimization_2d[grid](
            jagged_values,
            jagged_offsets[-1],
            jagged_values.stride(0),
            output_dense,
            output_dense.stride(-2),  # dense row stride
            output_dense.stride(-3),  # dense matrix stride
            thread_block_row_size,
            thread_block_col_size,
            padded_value=padding_value,
            operation_function=operation_function,
            operation_dense=operation_dense,
        )

    return output_dense
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
# Triton kernel: each program (program_id(0) = pid) handles one jagged segment
# [jagged_offsets[pid], jagged_offsets[pid + 1]). It copies the matching dense
# entries into the jagged values buffer at jagged_value_ptr. Optionally it
# fuses an elementwise op ("add", otherwise multiplication) with the jagged
# operand at operation_jagged_value_ptr.
#
# Dense reads are masked: rows at or beyond dense_boundary_row, columns at or
# beyond dense_boundary_col, and whole segments whose dense index is -1 load
# as 0 (other=0). Stores into the jagged buffer are masked only by the jagged
# extent (M rows x N columns).
#
# NOTE(review): block_offset is built from the dense strides but is also used
# to address both jagged buffers, and the column step adds
# thread_block_col_size to block_offset directly. This looks correct only when
# dense_row_stride == jagged_value_row_stride and dense_col_stride == 1
# (contiguous rows of equal width). Confirm against callers.
@triton.jit
def triton_dense_to_jagged(
    # pyre-fixme[2]: Parameter must be annotated.
    jagged_value_ptr,
    # pyre-fixme[2]: Parameter must be annotated.
    jagged_offsets_ptr,
    jagged_value_row_stride: int,
    # pyre-fixme[2]: Parameter must be annotated.
    output_dense_ptr,
    # pyre-fixme[2]: Parameter must be annotated.
    dense_indices_ptr,
    # pyre-fixme[2]: Parameter must be annotated.
    dense_col_stride,  # strides of the dense input: column, row, matrix
    dense_row_stride: int,
    # pyre-fixme[2]: Parameter must be annotated.
    dense_matrix_stride,
    JAGGED_DIM: tl.constexpr,  # number of dimensions of the jagged tensor
    thread_block_row_size: tl.constexpr,
    thread_block_col_size: tl.constexpr,
    operation_function: tl.constexpr,  # fused arithmetic op ("add" or mul), or None
    # pyre-fixme[2]: Parameter must be annotated.
    operation_jagged_value_ptr,  # jagged operand of the fused op
) -> None:
    pid = tl.program_id(0)

    # Row range of this program's segment within the jagged values.
    begin = tl.load(jagged_offsets_ptr + pid)
    end = tl.load(jagged_offsets_ptr + (pid + 1))

    # Segment extent: M rows; N columns, taken from the jagged row stride
    # (assumes jagged rows are contiguous, TODO confirm).
    N = jagged_value_row_stride
    M = end - begin

    # Column bound on the dense side: min(dense_row_stride, N). This treats
    # dense_row_stride as the dense row width (assumes dense_col_stride == 1).
    dense_boundary_col = dense_row_stride
    # Using tl.minimum here changed the result type and caused a compile issue,
    # so an if statement is used instead.
    if N < dense_row_stride:
        dense_boundary_col = N

    # Row bound on the dense side: rows per dense matrix (matrix stride / row
    # stride), capped at the segment length.
    dense_boundary_row = tl.minimum(dense_matrix_stride // dense_row_stride, M)

    jagged_value_ptr += begin * jagged_value_row_stride
    if JAGGED_DIM > 2:
        dense_indice = tl.load(dense_indices_ptr + pid)
        # If the segment lies outside the dense output, set the column bound
        # to -1 so that every dense load is masked out (reads 0). We do not
        # return early because the fused operation must still be computed and
        # stored.
        if dense_indice == -1:
            dense_boundary_col = -1
        else:
            output_dense_ptr += dense_indice
    else:
        output_dense_ptr += pid * dense_matrix_stride

    if operation_function is not None:
        operation_jagged_value_ptr += begin * jagged_value_row_stride

    offset_row = tl.arange(0, thread_block_row_size)

    # Tile the segment: outer loop over row tiles, inner loop over column tiles.
    for _i in range(begin, end, thread_block_row_size):
        offset_col = tl.arange(0, thread_block_col_size)
        block_offset = (
            offset_row[:, None] * dense_row_stride
            + offset_col[None, :] * dense_col_stride
        )

        for _j in range(0, N, thread_block_col_size):
            dense_mask = (offset_row[:, None] < dense_boundary_row) & (
                offset_col[None, :] < dense_boundary_col
            )
            jagged_mask = (offset_row[:, None] < M) & (offset_col[None, :] < N)
            dense_values = tl.load(
                output_dense_ptr + block_offset, mask=dense_mask, other=0
            )
            if operation_function is not None:
                operation_jagged_value = tl.load(
                    operation_jagged_value_ptr + block_offset, mask=jagged_mask, other=0
                )
                if operation_function == "add":
                    dense_values = tensor_elementwise_add(
                        dense_values, operation_jagged_value
                    )
                else:
                    dense_values = tensor_elementwise_mul(
                        dense_values, operation_jagged_value
                    )
            tl.store(jagged_value_ptr + block_offset, dense_values, mask=jagged_mask)
            # Advance to the next column tile (step assumes unit column stride).
            offset_col += thread_block_col_size
            block_offset += thread_block_col_size
        offset_row += thread_block_row_size
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
def dense_to_jagged(
    dense: torch.Tensor,
    jagged_offsets: list[torch.Tensor],
    operation_function: Union[str, None] = None,
    operation_jagged_values: Union[torch.Tensor, None] = None,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Gather the dense entries addressed by ``jagged_offsets`` into jagged values.

    Optionally fuses an elementwise operation with ``operation_jagged_values``.
    The kernel applies addition for ``"add"`` and multiplication for any other
    non-None value.

    Args:
        dense: Dense tensor with at least three dimensions. Its last three
            strides (column, row, matrix) are passed to the kernel.
        jagged_offsets: Offsets tensors, one per jagged dimension. The last one
            delimits the output rows.
        operation_function: Optional fused operation (``"add"`` or ``"mul"``).
        operation_jagged_values: Jagged operand of the fused operation. It is
            required when ``operation_function`` is set, and its shape defines
            the output shape.

    Returns:
        ``(jagged_values, jagged_offsets)``. ``jagged_values`` has ``dense``'s
        dtype and device; ``jagged_offsets`` is the input list itself.

    Raises:
        ValueError: If ``operation_function`` is set but
            ``operation_jagged_values`` is None.
    """
    thread_block_row_size = 32
    thread_block_col_size = 32

    # Allocate on dense's device rather than "cuda" (the *current* CUDA device),
    # so the kernel never receives buffers that live on different GPUs.
    device = dense.device

    if operation_function is None:
        # Plain conversion: one output row per jagged row, as wide as dense.
        output_shape = (int(jagged_offsets[-1][-1]), dense.size(-1))
    elif operation_jagged_values is None:
        # Once an operation is requested, the kernel loads the fused operand
        # unconditionally, so fail early with a clear message.
        raise ValueError(
            "operation_jagged_values must be provided when operation_function is set"
        )
    else:
        # Fused add/mul: the output mirrors the jagged operand's layout.
        output_shape = operation_jagged_values.shape
    output_jagged_value = torch.empty(output_shape, device=device, dtype=dense.dtype)

    # One Triton program per segment of the innermost offsets tensor.
    grid = (jagged_offsets[-1].size(0) - 1,)

    JAGGED_DIM = len(jagged_offsets) + 1
    dense_indices = None
    if len(jagged_offsets) > 1:
        # Flat offset into dense for each innermost segment (-1 = outside
        # dense). The helper returns a tensor on the current CUDA device, so
        # move it next to the other kernel arguments.
        dense_indices = _jagged_offsets_to_dense_indice(
            jagged_offsets,
            dense.stride()[:-2],
            dense.size()[:-2],
        ).to(device)

    # dense stride for each column, row, and matrix
    dense_col_stride = dense.stride(-1)
    dense_row_stride = dense.stride(-2)
    dense_matrix_stride = dense.stride(-3)

    triton_dense_to_jagged[grid](
        output_jagged_value,
        jagged_offsets[-1],
        output_jagged_value.stride(0),
        dense,
        dense_indices,
        dense_col_stride,
        dense_row_stride,
        dense_matrix_stride,
        JAGGED_DIM,
        thread_block_row_size,
        thread_block_col_size,
        operation_function=operation_function,
        operation_jagged_value_ptr=operation_jagged_values,
    )

    return output_jagged_value, jagged_offsets
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
# jagged_tensor + dense -> dense
def jagged_dense_elementwise_add_dense_output(
    jagged_values: Tensor,
    jagged_offsets: list[Tensor],
    # pyre-fixme[2]: Parameter must be annotated.
    dense,
) -> Tensor:
    """Add a jagged tensor to ``dense`` and return the sum as a dense tensor.

    The jagged input is first densified. Its max lengths are the dimensions of
    ``dense`` between the first and the last, so the densified inner dimensions
    line up with ``dense``. The two dense tensors are then added.
    """
    # The fused operation_function path of jagged_to_dense is not used yet.
    # Once it is optimized, the separate addition below can be dropped and the
    # densified tensor returned directly.
    densified = jagged_to_dense(
        jagged_values,
        jagged_offsets,
        jagged_max_lengths=dense.size()[1:-1],
    )
    return densified + dense
|
|
799
|
+
|
|
800
|
+
|
|
801
|
+
# jagged_tensor + dense -> jagged_tensor
def jagged_dense_elementwise_add_jagged_output(
    jagged_values: Optional[Tensor], jagged_offsets: list[Tensor], dense: Tensor
) -> tuple[Tensor, list[Tensor]]:
    """Return ``jagged + dense`` in jagged form, together with ``jagged_offsets``.

    The dense entries addressed by ``jagged_offsets`` are gathered and added
    to ``jagged_values`` in a single fused kernel pass.
    """
    return dense_to_jagged(
        dense=dense,
        jagged_offsets=jagged_offsets,
        operation_jagged_values=jagged_values,
        operation_function="add",
    )
|
|
812
|
+
|
|
813
|
+
|
|
814
|
+
# jagged_tensor * dense -> jagged_tensor
def jagged_dense_elementwise_mul_jagged_output(
    jagged_values: Optional[Tensor], jagged_offsets: list[Tensor], dense: Tensor
) -> tuple[Tensor, list[Tensor]]:
    """Return ``jagged * dense`` in jagged form, together with ``jagged_offsets``.

    The dense entries addressed by ``jagged_offsets`` are gathered and
    multiplied with ``jagged_values`` in a single fused kernel pass.
    """
    return dense_to_jagged(
        dense=dense,
        jagged_offsets=jagged_offsets,
        operation_jagged_values=jagged_values,
        operation_function="mul",
    )
|