fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
@@ -0,0 +1,824 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ # pyre-ignore-all-errors[6]
+
+ from typing import Optional, Union
+
+ import torch
+ import triton  # @manual
+ import triton.language as tl  # @manual
+ from torch._tensor import Tensor
+
+
+ @triton.jit
+ def jagged_jagged_elementwise_arithmetic_ops(
+     # pyre-fixme[2]: Parameter must be annotated.
+     x_ptr,  # x_ptr and y_ptr are pointers to the jagged tensors' values
+     # pyre-fixme[2]: Parameter must be annotated.
+     y_ptr,
+     M: tl.constexpr,  # M and N are the sizes of a tensor of shape (M, N)
+     N: tl.constexpr,
+     stride_row: tl.constexpr,  # shared row stride for the tensors
+     stride_col: tl.constexpr,  # shared column stride for the tensors
+     # pyre-fixme[2]: Parameter must be annotated.
+     output,
+     thread_block_row_size: tl.constexpr,  # row and column sizes of the current thread block, (thread_block_row_size x thread_block_col_size)
+     thread_block_col_size: tl.constexpr,
+     ops_func: tl.constexpr,  # operation used for the computation, either add or multiplication
+ ) -> None:
+     pid = tl.program_id(0)
+     # number of column groups needed to cover all N columns
+     num_group_n = (N + thread_block_col_size - 1) // thread_block_col_size
+     # pid position along the columns, in range(0, num_group_n)
+     pid_n = pid % num_group_n
+     # pid position along the rows: the row advances once every num_group_n programs
+     pid_m = pid // num_group_n
+
+     offset_m = pid_m * thread_block_row_size + tl.arange(0, thread_block_row_size)
+     offset_n = pid_n * thread_block_col_size + tl.arange(0, thread_block_col_size)
+     mask = (offset_m[:, None] < M) & (offset_n[None, :] < N)
+     offset = offset_m[:, None] * stride_row + offset_n[None, :] * stride_col
+
+     x_ptr += offset
+     y_ptr += offset
+
+     x = tl.load(x_ptr, mask=mask)
+     y = tl.load(y_ptr, mask=mask)
+
+     if ops_func == "add":
+         z = tensor_elementwise_add(x, y)
+     else:
+         z = tensor_elementwise_mul(x, y)
+
+     output += offset
+     tl.store(output, z, mask=mask)
+
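+ # Worked example of the program-id decomposition above (illustrative,
+ # hypothetical values): with M = 50, N = 100, and 32 x 32 thread blocks,
+ # num_group_n = (100 + 31) // 32 = 4 column groups and cdiv(50, 32) = 2 row
+ # groups, so the launch grid has 2 * 4 = 8 programs. Program pid = 5 handles
+ # pid_n = 5 % 4 = 1 and pid_m = 5 // 4 = 1, i.e. the tile covering rows
+ # 32..63 and columns 32..63, with the mask clipping rows >= 50.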
+
+
+ @triton.jit
+ # pyre-fixme[3]: Return type must be annotated.
+ # pyre-fixme[2]: Parameter must be annotated.
+ def tensor_elementwise_add(x, y):
+     return x + y
+
+
+ @triton.jit
+ # pyre-fixme[3]: Return type must be annotated.
+ # pyre-fixme[2]: Parameter must be annotated.
+ def tensor_elementwise_mul(x, y):
+     return x * y
+
+
+ def triton_jagged_add_jagged(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+
+     # x and y need to have the same shape for the addition
+     assert x.shape == y.shape
+
+     thread_block_row_size = 32
+     thread_block_col_size = 32
+
+     # x and y are tensors with the same dimensions (M, N)
+     M, N = x.shape
+
+     output = torch.empty((M, N), device="cuda", dtype=x.dtype)
+
+     # pyre-fixme[53]: Captured variable `M` is not annotated.
+     # pyre-fixme[53]: Captured variable `N` is not annotated.
+     # pyre-fixme[53]: Captured variable `thread_block_col_size` is not annotated.
+     # pyre-fixme[53]: Captured variable `thread_block_row_size` is not annotated.
+     # pyre-fixme[3]: Return type must be annotated.
+     # pyre-fixme[2]: Parameter must be annotated.
+     def grid(META):
+         return (
+             triton.cdiv(M, thread_block_row_size)
+             * triton.cdiv(N, thread_block_col_size),
+         )
+
+     jagged_jagged_elementwise_arithmetic_ops[grid](
+         x,
+         y,
+         M,
+         N,
+         x.stride(0),
+         x.stride(1),
+         output,
+         thread_block_row_size,
+         thread_block_col_size,
+         ops_func="add",
+     )
+
+     return output
+
+
+ def triton_jagged_mul_jagged(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+
+     # x and y need to have the same shape for the multiplication
+     assert x.shape == y.shape
+
+     thread_block_row_size = 32
+     thread_block_col_size = 32
+     # x and y are tensors with the same dimensions (M, N)
+     M, N = x.shape
+
+     output = torch.empty((M, N), device="cuda", dtype=x.dtype)
+
+     # pyre-fixme[53]: Captured variable `M` is not annotated.
+     # pyre-fixme[53]: Captured variable `N` is not annotated.
+     # pyre-fixme[53]: Captured variable `thread_block_col_size` is not annotated.
+     # pyre-fixme[53]: Captured variable `thread_block_row_size` is not annotated.
+     # pyre-fixme[3]: Return type must be annotated.
+     # pyre-fixme[2]: Parameter must be annotated.
+     def grid(META):
+         return (
+             triton.cdiv(M, thread_block_row_size)
+             * triton.cdiv(N, thread_block_col_size),
+         )
+
+     jagged_jagged_elementwise_arithmetic_ops[grid](
+         x,
+         y,
+         M,
+         N,
+         x.stride(0),
+         x.stride(1),
+         output,
+         thread_block_row_size,
+         thread_block_col_size,
+         ops_func="mul",
+     )
+
+     return output
+
+
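+ # Usage sketch for the two wrappers above (hypothetical shapes, not part of
+ # the original source). Both operate on the dense `values` tensor backing a
+ # jagged tensor, so any two jagged tensors sharing the same offsets can be
+ # combined:
+ #
+ #     values_x = torch.randn(1024, 64, device="cuda")
+ #     values_y = torch.randn(1024, 64, device="cuda")
+ #     summed = triton_jagged_add_jagged(values_x, values_y)   # values_x + values_y
+ #     product = triton_jagged_mul_jagged(values_x, values_y)  # values_x * values_y
+
+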
+ # with bmm([B * H, 1, N], [B * H, N, D]),
+ # each kernel instance handles the matmul of (1, N) * (N, D)
+ @triton.jit
+ def triton_batched_dense_vec_jagged_2d_matmul(
+     # pyre-fixme[2]: Parameter must be annotated.
+     jagged_tensor_ptr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     dense_ptr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     jagged_offset,
+     thread_block_col_size: tl.constexpr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     dense_row_stride,
+     # pyre-fixme[2]: Parameter must be annotated.
+     jagged_value_row_stride,
+     # pyre-fixme[2]: Parameter must be annotated.
+     D,
+     H: tl.constexpr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     output_ptr,
+ ) -> None:
+
+     pid = tl.program_id(0)
+
+     # number of kernel instances needed per (N, D) matrix, i.e. ceil(D / thread_block_col_size)
+     GRID_DIM_COL = (D + thread_block_col_size - 1) // thread_block_col_size
+
+     # current output row index
+     output_row_idx = pid // GRID_DIM_COL
+
+     # current jagged tensor offset index
+     jagged_offset_id = output_row_idx // H
+
+     # current index along D, since the real shape of the jagged values is [B, N, H * D]
+     D_refer_idx = output_row_idx % H
+
+     # current part of the [N * D] block
+     group_id = pid % GRID_DIM_COL
+
+     # size of the tile
+     offset = group_id * thread_block_col_size + tl.arange(0, thread_block_col_size)
+
+     # begin and end indices of the values
+     begin = tl.load(jagged_offset + jagged_offset_id)
+     end = tl.load(jagged_offset + (jagged_offset_id + 1))
+
+     # update each pointer to the correct address
+     dense_ptr += output_row_idx * dense_row_stride
+     jagged_tensor_ptr += begin * jagged_value_row_stride + D_refer_idx * D
+     output_ptr += D * output_row_idx
+
+     # number of rows each kernel instance will go through
+     num_row = tl.minimum(end - begin, dense_row_stride)
+
+     # accumulator used for the matmul
+     acc = tl.zeros((thread_block_col_size,), dtype=tl.float32)
+     mask = offset < D
+     for i in range(num_row):
+         val1 = tl.load(dense_ptr + i)
+         val2 = tl.load(jagged_tensor_ptr + offset, mask=mask, other=0.0)
+         result = val1 * val2
+         acc += result
+         jagged_tensor_ptr += jagged_value_row_stride
+
+     tl.store(output_ptr + offset, acc, mask=mask)
+
+
+ # For torch.bmm, see: https://pytorch.org/docs/stable/generated/torch.bmm.html
+ # This operation takes a dense tensor of shape [B * H, N], where N is the max_length;
+ # in the logical representation we treat the dense tensor as [B * H, 1, N].
+ # It also takes a 2D jagged tensor whose values have shape [B, N, H * D]; in the
+ # logical representation we treat the values as [B * H, N, D].
+ # In the 2D jagged tensor case the offsets are a tensor instead of a list of tensors.
+ # It creates an output dense tensor of shape [B * H, 1, D]:
+ # dense * jagged_tensor = output_dense -> [B * H, 1, N] * [B * H, N, D] = [B * H, 1, D]
+ def batched_dense_vec_jagged_2d_matmul(
+     dense: torch.Tensor,
+     values: torch.Tensor,
+     offset: torch.Tensor,
+ ) -> torch.Tensor:
+     B = offset.size(0) - 1
+     H = dense.size(0) // B
+     D = values.size(-1) // H
+     thread_block_col_size = 32
+
+     output_dense = torch.empty((B * H, D), device="cuda", dtype=values.dtype)
+
+     # number of thread blocks needed for a jagged tensor of shape [B * H, N, D]
+     # pyre-fixme[53]: Captured variable `B` is not annotated.
+     # pyre-fixme[53]: Captured variable `D` is not annotated.
+     # pyre-fixme[53]: Captured variable `H` is not annotated.
+     # pyre-fixme[53]: Captured variable `thread_block_col_size` is not annotated.
+     # pyre-fixme[3]: Return type must be annotated.
+     # pyre-fixme[2]: Parameter must be annotated.
+     def grid(META):
+         return (B * H * triton.cdiv(D, thread_block_col_size),)
+
+     triton_batched_dense_vec_jagged_2d_matmul[grid](
+         values,
+         dense,
+         offset,
+         thread_block_col_size,
+         dense.stride(0),
+         values.stride(0),
+         D,
+         H,
+         output_dense,
+     )
+
+     return output_dense
+
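+ # Usage sketch (hypothetical sizes, not part of the original source): with
+ # B = 2 bags, H = 4 heads, head dim D = 8, and max length N = 16, the dense
+ # vectors have shape [B * H, N] and the jagged values [sum(lengths), H * D]:
+ #
+ #     offsets = torch.tensor([0, 10, 16], device="cuda")       # B = 2 bags
+ #     dense = torch.randn(2 * 4, 16, device="cuda")            # [B * H, N]
+ #     values = torch.randn(16, 4 * 8, device="cuda")           # [sum(lengths), H * D]
+ #     out = batched_dense_vec_jagged_2d_matmul(dense, values, offsets)  # [B * H, D]
+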
+
+ # each kernel instance handles the conversion of one jagged tensor offset range to the corresponding dense index
+ @triton.jit
+ def triton_jagged_to_dense(
+     # only constexpr annotations are supported in triton now
+     # pyre-fixme[2]: Parameter must be annotated.
+     jagged_value_ptr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     jagged_offsets_ptr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     jagged_value_row_stride,
+     # pyre-fixme[2]: Parameter must be annotated.
+     output_dense_ptr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     dense_indices_ptr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     dense_col_stride,  # strides of the output dense tensor with dimensions (z, y, x)
+     # pyre-fixme[2]: Parameter must be annotated.
+     dense_row_stride,
+     # pyre-fixme[2]: Parameter must be annotated.
+     dense_matrix_stride,
+     JAGGED_DIM: tl.constexpr,  # number of dimensions of the jagged tensor
+     thread_block_row_size: tl.constexpr,
+     thread_block_col_size: tl.constexpr,
+     operation_function: tl.constexpr,  # fused arithmetic operation and its input dense tensor
+     # pyre-fixme[2]: Parameter must be annotated.
+     operation_dense,
+ ) -> None:
+     pid = tl.program_id(0)
+
+     # begin and end indices of the jagged tensor values
+     begin = tl.load(jagged_offsets_ptr + pid)
+     end = tl.load(jagged_offsets_ptr + (pid + 1))
+
+     # adjust the jagged tensor values pointer to the correct address
+     jagged_value_ptr += begin * jagged_value_row_stride
+
+     # for a 2D (or 1D) jagged tensor we can directly use the offset in offsets (since there is only one offsets tensor);
+     # otherwise we need the preprocessed indices to find the correct address in the dense tensor
+     if JAGGED_DIM > 2:
+         # read the index for the current kernel instance
+         dense_indice = tl.load(dense_indices_ptr + pid)
+
+         # a dense_indice of -1 marks a truncation case;
+         # in that case nothing needs to be done, since the dense tensor
+         # is initialized with the padding value
+         if dense_indice == -1:
+             return
+
+         # adjust the output dense pointer to the correct address
+         output_dense_ptr += dense_indice
+
+         # the operation dense pointer, if present, needs the same update;
+         # note the two use the same dense_indice because we assume
+         # the two dense tensors being combined have the same size
+         if operation_function is not None:
+             operation_dense += dense_indice
+     else:
+         output_dense_ptr += pid * dense_matrix_stride
+
+         if operation_function is not None:
+             operation_dense += pid * dense_matrix_stride
+
+     offset_row = tl.arange(0, thread_block_row_size)
+
+     # boundaries needed for the mask, since the dense tensor can be smaller than the jagged tensor or vice versa
+     N = tl.minimum(dense_row_stride, jagged_value_row_stride)
+     M = tl.minimum(dense_matrix_stride // dense_row_stride, end - begin)
+
+     for _i in range(begin, end, thread_block_row_size):
+         offset_col = tl.arange(0, thread_block_col_size)
+         block_offset = (
+             offset_row[:, None] * dense_row_stride
+             + offset_col[None, :] * dense_col_stride
+         )
+         for _j in range(0, N, thread_block_col_size):
+             mask = (offset_row[:, None] < M) & (offset_col[None, :] < N)
+             jagged_val = tl.load(jagged_value_ptr + block_offset, mask=mask, other=0)
+
+             # if an arithmetic operation was requested, do the fused computation
+             if operation_function is not None:
+                 val1 = jagged_val
+                 val2 = tl.load(operation_dense + block_offset, mask=mask, other=0)
+                 # do the arithmetic operation
+                 if operation_function == "add":
+                     jagged_val = tensor_elementwise_add(val1, val2)
+                 else:
+                     jagged_val = tensor_elementwise_mul(val1, val2)
+
+             # store the result
+             tl.store(output_dense_ptr + block_offset, jagged_val, mask=mask)
+
+             # update the block offsets
+             offset_col += thread_block_col_size
+             block_offset += thread_block_col_size
+         offset_row += thread_block_row_size
+
+
+ # This function handles the 2D jagged tensor to dense conversion.
+ # Each kernel instance goes through all the elements of one 2D tensor in
+ # the dense output (notice that since the jagged tensor is 2D, the dense tensor is 3D).
+ # Each kernel instance checks whether the current value of the 2D tensor is in
+ # range or out of range: if it is within the jagged tensor's range, it loads the
+ # corresponding value; otherwise it loads the padding value into the dense tensor.
+ # In contrast, in the function triton_jagged_to_dense we are only able to fill
+ # values from the jagged tensor into the corresponding dense positions, but we
+ # are not able to fill the dense tensor with the padding value inside the kernel;
+ # therefore, in the previous function, we fill the dense tensor with the padding
+ # value first and then load the corresponding values. This function instead fills
+ # the values directly in the kernel to avoid the extra latency.
+ @triton.jit
+ def triton_jagged_to_dense_optimization_2d(
+     # pyre-fixme[2]: Parameter must be annotated.
+     input_jagged_values_ptr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     input_jagged_offset_ptr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     input_jagged_row_stride,
+     # pyre-fixme[2]: Parameter must be annotated.
+     output_dense_ptr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     output_dense_row_stride,
+     # pyre-fixme[2]: Parameter must be annotated.
+     output_dense_matrix_stride,
+     thread_block_row_size: tl.constexpr,
+     thread_block_col_size: tl.constexpr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     padded_value,
+     operation_function: tl.constexpr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     operation_dense,
+ ) -> None:
+     pid = tl.program_id(0)
+
+     # index of the current corresponding offset
+     offset_idx = pid
+
+     # begin and end indices of the jagged tensor values
+     begin = tl.load(input_jagged_offset_ptr + offset_idx)
+     end = tl.load(input_jagged_offset_ptr + offset_idx + 1)
+
+     # row size of the current sub-tensor
+     cur_jagged_tensor_row_size = end - begin
+
+     # advance the dense and jagged tensor values pointers to the corresponding addresses
+     output_dense_ptr += pid * output_dense_matrix_stride
+     input_jagged_values_ptr += begin * input_jagged_row_stride
+
+     # the operation dense pointer, if present, needs the same update;
+     # note the two use the same index because we assume
+     # the two dense tensors being combined have the same size
+     if operation_function is not None:
+         operation_dense += pid * output_dense_matrix_stride
+
+     # jagged tensor row block
+     offset_row = tl.arange(0, thread_block_row_size)
+
+     # dense row and column blocks; notice the jagged tensor and the dense tensor
+     # share the same column block since the embedding dimension is the same
+     dense_col_size = output_dense_row_stride
+     dense_row_size = output_dense_matrix_stride // output_dense_row_stride
+
+     for _i in range(0, dense_row_size, thread_block_row_size):
+         offset_col = tl.arange(0, thread_block_col_size)
+         block_offset = (
+             offset_row[:, None] * output_dense_row_stride + offset_col[None, :]
+         )
+
+         for _j in range(0, dense_col_size, thread_block_col_size):
+
+             # create masks for the dense and jagged tensors for the boundary check
+             dense_mask = (offset_row[:, None] < dense_row_size) & (
+                 offset_col[None, :] < dense_col_size
+             )
+             jagged_mask = (offset_row[:, None] < cur_jagged_tensor_row_size) & (
+                 offset_col[None, :] < input_jagged_row_stride
+             )
+
+             # get the value from the jagged tensor
+             jagged_val = tl.load(
+                 input_jagged_values_ptr + block_offset,
+                 mask=jagged_mask,
+                 other=padded_value,
+             )
+
+             # do the fused operation if needed
+             # (dispatch on the operation name, as in triton_jagged_to_dense)
+             if operation_function is not None:
+                 operation_dense_val = tl.load(
+                     operation_dense + block_offset, mask=dense_mask, other=0.0
+                 )
+                 if operation_function == "add":
+                     jagged_val = tensor_elementwise_add(operation_dense_val, jagged_val)
+                 else:
+                     jagged_val = tensor_elementwise_mul(operation_dense_val, jagged_val)
+
+             # store the value into the output dense tensor
+             tl.store(output_dense_ptr + block_offset, jagged_val, mask=dense_mask)
+
+             # update each block
+             offset_col += thread_block_col_size
+             block_offset += thread_block_col_size
+         offset_row += thread_block_row_size
+
+
+ # This function maps the jagged tensor offsets to the corresponding dense index positions.
+ # For details, see the Quip note: https://fb.quip.com/gnzpA7d13vqO
+ # The FBGEMM implementation is at: https://www.internalfb.com/code/fbsource/[308212b2902c3182edcb5b204768321e032e8175]/fbcode/deeplearning/fbgemm/fbgemm_gpu/src/jagged_tensor_ops.cu?lines=280
+ # In FBGEMM this is computed on the GPU, but Triton currently has a compilation issue,
+ # so we use CPU computation as a workaround.
+ # In the real-world case, however, if we only deal with 2D jagged tensors, this function is not needed at all.
+ def _jagged_offsets_to_dense_indice(
+     offsets: list[torch.Tensor], dense_strides: list[int], dense_sizes: list[int]
+ ) -> torch.Tensor:
+
+     output_offset = torch.zeros(len(offsets[-1]) - 1, device="cpu", dtype=torch.int32)
+
+     offsets_cpu = []
+
+     for offset in offsets:
+         offsets_cpu.append(offset.cpu())
+
+     for i in range(0, len(offsets_cpu[-1]) - 1):
+         idx = i
+         result = 0
+
+         # flag checking whether the current offset is within the range of the dense tensor
+         in_range = True
+         for j in range(len(offsets_cpu) - 2, -1, -1):
+             left = 0
+             right = offsets_cpu[j].size(0)
+
+             # binary search to find the offset group containing the current index
+             while left < right:
+                 mid = left + (right - left) // 2
+
+                 if offsets_cpu[j][mid] > idx:
+                     right = mid
+                 else:
+                     left = mid + 1
+
+             cur_val = idx - offsets_cpu[j][left - 1]
+
+             if dense_sizes and cur_val >= dense_sizes[j + 1]:
+                 in_range = False
+                 break
+
+             result += cur_val * dense_strides[j + 1]
+             idx = left - 1
+
+         if in_range:
+             result += idx * dense_strides[0]
+
+             # another case of being out of the output dense range
+             if dense_sizes and idx > dense_sizes[0]:
+                 result = -1
+             output_offset[i] = result
+         else:
+             output_offset[i] = -1
+
+     return output_offset.cuda()
+
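+ # Worked example (hypothetical values, not part of the original source): for a
+ # 3D jagged tensor with outer offsets [0, 2, 3] (bag 0 holds sub-lists 0-1,
+ # bag 1 holds sub-list 2) and a padded dense output of shape (2, 2, 3, 4),
+ # strides (24, 12, 4, 1), we pass dense_strides = (24, 12) and
+ # dense_sizes = (2, 2). Sub-list 0 lands at element offset 0*24 + 0*12 = 0,
+ # sub-list 1 at 0*24 + 1*12 = 12, and sub-list 2 at 1*24 + 0*12 = 24; a
+ # sub-list falling outside dense_sizes would map to -1 (truncated).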
+
+ # Converts a jagged tensor to a dense tensor; for details see the Quip note: https://fb.quip.com/gnzpA7d13vqO
+ # Currently, parts of the dense tensor that are not loaded from the jagged tensor values
+ # are skipped during the conversion, which means we initialize the tensor with the padding
+ # value instead of filling it in during the conversion. The optimized approach currently
+ # hits an LLVM compilation issue in Triton; we will look for a workaround when comparing
+ # against jagged tensors of multiple dimensions. However, if we only deal with 2D jagged
+ # tensors in the real-world case, this should not be affected at all.
+ def jagged_to_dense(
+     jagged_values: torch.Tensor,
+     jagged_offsets: list[torch.Tensor],
+     jagged_max_lengths: list[int],
+     padding_value: float = 0.0,  # padding value, 0.0 by default
+     operation_function: Union[
+         str, None
+     ] = None,  # fused operation, currently either add or multiplication
+     operation_dense: Union[
+         torch.Tensor, None
+     ] = None,  # dense tensor to add/multiply with the output dense tensor
+ ) -> torch.Tensor:
+     outer_dense_size = len(jagged_offsets[0]) - 1
+     inner_dense_size = jagged_values.size(-1)
+
+     # number of dimensions of the jagged tensor
+     JAGGED_DIM = len(jagged_offsets) + 1
+
+     output_dense = None
+
+     # fill the dense tensor with the padding value if it has multiple dimensions;
+     # otherwise create an empty dense tensor.
+     # This handles the multi-dimensional cases:
+     # filling the padding value inside the kernel function
+     # can trigger a compile error there
+     if JAGGED_DIM > 2:
+         output_dense = torch.full(
+             ((outer_dense_size,) + tuple(jagged_max_lengths) + (inner_dense_size,)),
+             padding_value,
+             device="cuda",
+             dtype=jagged_values.dtype,
+         )
+     else:
+         output_dense = torch.empty(
+             ((outer_dense_size,) + tuple(jagged_max_lengths) + (inner_dense_size,)),
+             device="cuda",
+             dtype=jagged_values.dtype,
+         )
+
+     thread_block_row_size = 32
+     thread_block_col_size = 32
+
+     grid = (len(jagged_offsets[-1]) - 1,)
+
+     # dense indices from the address perspective
+     dense_indices = None
+
+     # if the jagged tensor has more than two dimensions (i.e. more than one offsets tensor),
+     # we need to compute the dense indices corresponding to the jagged offsets
+     if JAGGED_DIM > 2:
+         dense_indices = _jagged_offsets_to_dense_indice(
+             jagged_offsets,
+             output_dense.stride()[:-2],
+             output_dense.size()[:-2],
+         )
+
+     # dense strides for each column, row, and matrix
+     dense_col_stride = output_dense.stride(-1)
+     dense_row_stride = output_dense.stride(-2)
+     dense_matrix_stride = output_dense.stride(-3)
+
+     if JAGGED_DIM > 2:
+         triton_jagged_to_dense[grid](
+             jagged_values,
+             jagged_offsets[-1],
+             jagged_values.stride(0),
+             output_dense,
+             dense_indices,
+             dense_col_stride,
+             dense_row_stride,
+             dense_matrix_stride,
+             JAGGED_DIM,
+             thread_block_row_size,
+             thread_block_col_size,
+             operation_function=operation_function,
+             operation_dense=operation_dense,
+         )
+     else:
+         grid = (output_dense.size(0),)
+         triton_jagged_to_dense_optimization_2d[grid](
+             jagged_values,
+             jagged_offsets[-1],
+             jagged_values.stride(0),
+             output_dense,
+             dense_row_stride,
+             dense_matrix_stride,
+             thread_block_row_size,
+             thread_block_col_size,
+             padded_value=padding_value,
+             operation_function=operation_function,
+             operation_dense=operation_dense,
+         )
+
+     return output_dense
+
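+ # Usage sketch for the 2D case (hypothetical values, not part of the original
+ # source): three rows of embedding dim 4, split into bags of lengths 2 and 1,
+ # padded to max length 2:
+ #
+ #     values = torch.randn(3, 4, device="cuda")
+ #     offsets = [torch.tensor([0, 2, 3], device="cuda")]
+ #     dense = jagged_to_dense(values, offsets, [2])   # shape (2, 2, 4)
+ #
+ # dense[0] holds rows 0-1; dense[1, 0] holds row 2 and dense[1, 1] is padding.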
+
+ # each kernel instance handles the conversion of one jagged tensor offset range from the corresponding dense index
+ @triton.jit
+ def triton_dense_to_jagged(
+     # pyre-fixme[2]: Parameter must be annotated.
+     jagged_value_ptr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     jagged_offsets_ptr,
+     jagged_value_row_stride: int,
+     # pyre-fixme[2]: Parameter must be annotated.
+     output_dense_ptr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     dense_indices_ptr,
+     # pyre-fixme[2]: Parameter must be annotated.
+     dense_col_stride,  # strides of the output dense tensor with dimensions (z, y, x)
+     dense_row_stride: int,
+     # pyre-fixme[2]: Parameter must be annotated.
+     dense_matrix_stride,
+     JAGGED_DIM: tl.constexpr,  # number of dimensions of the jagged tensor
+     thread_block_row_size: tl.constexpr,
+     thread_block_col_size: tl.constexpr,
+     operation_function: tl.constexpr,  # fused arithmetic operation and its input jagged values
+     # pyre-fixme[2]: Parameter must be annotated.
+     operation_jagged_value_ptr,
+ ) -> None:
+     pid = tl.program_id(0)
+
+     begin = tl.load(jagged_offsets_ptr + pid)
+     end = tl.load(jagged_offsets_ptr + (pid + 1))
+
+     # size (M, N) of the current value offset range
+     N = jagged_value_row_stride
+     M = end - begin
+
+     dense_boundary_col = dense_row_stride
+     # tl.minimum would change the return type and cause a compile issue,
+     # so use an if statement instead
+     if N < dense_row_stride:
+         dense_boundary_col = N
+
+     dense_boundary_row = tl.minimum(dense_matrix_stride // dense_row_stride, M)
+
+     jagged_value_ptr += begin * jagged_value_row_stride
+     if JAGGED_DIM > 2:
+         dense_indice = tl.load(dense_indices_ptr + pid)
+         # if the dense index is out of the output range we set dense_boundary_col to -1,
+         # which means no dense values will be loaded through the mask;
+         # since we still need the fusion computation step,
+         # we do not return here
+         if dense_indice == -1:
+             dense_boundary_col = -1
+         else:
+             output_dense_ptr += dense_indice
+     else:
+         output_dense_ptr += pid * dense_matrix_stride
+
+     if operation_function is not None:
+         operation_jagged_value_ptr += begin * jagged_value_row_stride
+
+     offset_row = tl.arange(0, thread_block_row_size)
+
+     for _i in range(begin, end, thread_block_row_size):
+         offset_col = tl.arange(0, thread_block_col_size)
+         block_offset = (
+             offset_row[:, None] * dense_row_stride
+             + offset_col[None, :] * dense_col_stride
+         )
+
+         for _j in range(0, N, thread_block_col_size):
+             dense_mask = (offset_row[:, None] < dense_boundary_row) & (
+                 offset_col[None, :] < dense_boundary_col
+             )
+             jagged_mask = (offset_row[:, None] < M) & (offset_col[None, :] < N)
+             dense_values = tl.load(
+                 output_dense_ptr + block_offset, mask=dense_mask, other=0
+             )
+             if operation_function is not None:
+                 operation_jagged_value = tl.load(
+                     operation_jagged_value_ptr + block_offset, mask=jagged_mask, other=0
+                 )
+                 if operation_function == "add":
+                     dense_values = tensor_elementwise_add(
+                         dense_values, operation_jagged_value
+                     )
+                 else:
+                     dense_values = tensor_elementwise_mul(
+                         dense_values, operation_jagged_value
+                     )
+             tl.store(jagged_value_ptr + block_offset, dense_values, mask=jagged_mask)
+             offset_col += thread_block_col_size
+             block_offset += thread_block_col_size
+         offset_row += thread_block_row_size
+
+
+ def dense_to_jagged(
+     dense: torch.Tensor,
+     jagged_offsets: list[torch.Tensor],
+     operation_function: Union[str, None] = None,
+     operation_jagged_values: Union[torch.Tensor, None] = None,
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+
+     thread_block_row_size = 32
+     thread_block_col_size = 32
+
+     if operation_function is None:
+         output_jagged_value = torch.empty(
+             (jagged_offsets[-1][-1], dense.size(-1)),
+             device="cuda",
+             dtype=dense.dtype,
+         )
+     else:
+         output_jagged_value = torch.empty(
+             # pyre-fixme[16]: Optional type has no attribute `shape`.
+             operation_jagged_values.shape,
+             device="cuda",
+             dtype=dense.dtype,
+         )
+
+     grid = (jagged_offsets[-1].size(0) - 1,)
+
+     JAGGED_DIM = len(jagged_offsets) + 1
+     dense_indices = None
+     if len(jagged_offsets) > 1:
+         dense_indices = _jagged_offsets_to_dense_indice(
+             jagged_offsets,
+             dense.stride()[:-2],
+             dense.size()[:-2],
+         )
+
+     # dense strides for each column, row, and matrix
+     dense_col_stride = dense.stride(-1)
+     dense_row_stride = dense.stride(-2)
+     dense_matrix_stride = dense.stride(-3)
+
+     triton_dense_to_jagged[grid](
+         output_jagged_value,
+         jagged_offsets[-1],
+         output_jagged_value.stride(0),
+         dense,
+         dense_indices,
+         dense_col_stride,
+         dense_row_stride,
+         dense_matrix_stride,
+         JAGGED_DIM,
+         thread_block_row_size,
+         thread_block_col_size,
+         operation_function=operation_function,
+         operation_jagged_value_ptr=operation_jagged_values,
+     )
+
+     return output_jagged_value, jagged_offsets
+
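+ # Usage sketch (hypothetical values, not part of the original source): the
+ # inverse of jagged_to_dense, gathering the in-range rows of a padded dense
+ # tensor back into jagged values:
+ #
+ #     dense = torch.randn(2, 2, 4, device="cuda")
+ #     offsets = [torch.tensor([0, 2, 3], device="cuda")]
+ #     values, offsets = dense_to_jagged(dense, offsets)  # values: shape (3, 4)
+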
+
+ # jagged_tensor + dense -> dense
+ def jagged_dense_elementwise_add_dense_output(
+     jagged_values: Tensor,
+     jagged_offsets: list[Tensor],
+     # pyre-fixme[2]: Parameter must be annotated.
+     dense,
+ ) -> Tensor:
+
+     # max_length used to build an output dense tensor
+     # with the same size as the input dense tensor
+     max_length = dense.size()[1:-1]
+
+     # convert the jagged tensor to a dense tensor
+     converted_dense = jagged_to_dense(jagged_values, jagged_offsets, max_length)
+
+     # the add operation adds two dense tensors of the same shape.
+     # Once this is optimized we can remove this statement
+     # and directly return converted_dense
+     return converted_dense + dense
+
+
+ # jagged_tensor + dense -> jagged_tensor
+ def jagged_dense_elementwise_add_jagged_output(
+     jagged_values: Optional[Tensor], jagged_offsets: list[Tensor], dense: Tensor
+ ) -> tuple[Tensor, list[Tensor]]:
+
+     return dense_to_jagged(
+         dense,
+         jagged_offsets,
+         operation_function="add",
+         operation_jagged_values=jagged_values,
+     )
+
+
+ # jagged_tensor * dense -> jagged_tensor
+ def jagged_dense_elementwise_mul_jagged_output(
+     jagged_values: Optional[Tensor], jagged_offsets: list[Tensor], dense: Tensor
+ ) -> tuple[Tensor, list[Tensor]]:
+
+     return dense_to_jagged(
+         dense,
+         jagged_offsets,
+         operation_function="mul",
+         operation_jagged_values=jagged_values,
+     )
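+
+
+ # Usage sketch for the fused elementwise ops above (hypothetical values, not
+ # part of the original source): the jagged-output variants fuse the
+ # dense-to-jagged gather with the add/mul against the existing jagged values:
+ #
+ #     values = torch.randn(3, 4, device="cuda")
+ #     offsets = [torch.tensor([0, 2, 3], device="cuda")]
+ #     dense = torch.randn(2, 2, 4, device="cuda")
+ #
+ #     out_dense = jagged_dense_elementwise_add_dense_output(values, offsets, dense)
+ #     out_vals, _ = jagged_dense_elementwise_add_jagged_output(values, offsets, dense)
+ #     out_vals, _ = jagged_dense_elementwise_mul_jagged_output(values, offsets, dense)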