fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
@@ -0,0 +1,221 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+import torch
+import triton
+import triton.language as tl
+
+from .common import expect_contiguous
+
+
+@triton.jit
+def jagged2_to_padded_dense_kernel(
+    x_ptr,
+    lengths_ptr,
+    offsets_ptr,
+    output_dense_ptr,
+    stride_b,
+    stride_m,
+    stride_n,
+    max_length,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_batch = tl.program_id(2)
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    begin = tl.load(offsets_ptr + pid_batch)
+    seqlen = tl.load(lengths_ptr + pid_batch)
+
+    seqlen = tl.minimum(seqlen, max_length)
+    if seqlen == 0:
+        return
+
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    x_ptrs = x_ptr + begin + offs_m[:, None] * seqlen + offs_n[None, :]
+    x = tl.load(x_ptrs, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen)))
+
+    out_ptrs = (
+        output_dense_ptr
+        + pid_batch * stride_b
+        + offs_m[:, None] * stride_m
+        + offs_n[None, :] * stride_n
+    )
+    tl.store(
+        out_ptrs, x, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen))
+    )
+
+
+@triton.jit
+def padded_dense_to_jagged2_kernel(
+    x_ptr,
+    lengths_ptr,
+    offsets_ptr,
+    output_jagged_ptr,
+    stride_b,
+    stride_m,
+    stride_n,
+    max_length,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_batch = tl.program_id(2)
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    begin = tl.load(offsets_ptr + pid_batch)
+    # end = tl.load(offsets_ptr + pid_batch + 1)
+    seqlen = tl.load(lengths_ptr + pid_batch)
+
+    seqlen = tl.minimum(seqlen, max_length)
+
+    if seqlen == 0:
+        return
+
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    x_ptrs = (
+        x_ptr
+        + pid_batch * stride_b
+        + offs_m[:, None] * stride_m
+        + offs_n[None, :] * stride_n
+    )
+    x = tl.load(x_ptrs, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen)))
+    out_ptrs = output_jagged_ptr + begin + offs_m[:, None] * seqlen + offs_n[None, :]
+    tl.store(
+        out_ptrs, x, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen))
+    )
+
+
+def jagged2_to_padded_dense_fwd(
+    values: torch.Tensor,
+    lengths: torch.Tensor,
+    offsets: torch.Tensor,
+    max_length: int,
+    padding_value: float,
+) -> torch.Tensor:
+    B = offsets.size(0) - 1
+
+    output_dense = torch.full(
+        (B, max_length, max_length),
+        padding_value,
+        dtype=values.dtype,
+        device=values.device,
+    )
+    BLOCK_M = 32
+    BLOCK_N = 32
+    num_blocks_m = triton.cdiv(max_length, BLOCK_M)
+    num_blocks_n = triton.cdiv(max_length, BLOCK_N)
+    grid = (num_blocks_m, num_blocks_n, B)
+
+    jagged2_to_padded_dense_kernel[grid](
+        values,
+        lengths,
+        offsets,
+        output_dense,
+        output_dense.stride(0),
+        output_dense.stride(1),
+        output_dense.stride(2),
+        max_length,
+        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+        BLOCK_M,
+        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+        BLOCK_N,
+    )
+
+    return output_dense
+
+
+def padded_dense_to_jagged2_fwd(
+    values: torch.Tensor,
+    lengths: torch.Tensor,
+    offsets: torch.Tensor,
+    max_length: int,
+) -> torch.Tensor:
+    B = values.size(0)
+    output_jagged = torch.empty(
+        int(offsets[-1]), dtype=values.dtype, device=values.device
+    )
+    BLOCK_M = 32
+    BLOCK_N = 32
+    num_blocks_m = triton.cdiv(max_length, BLOCK_M)
+    num_blocks_n = triton.cdiv(max_length, BLOCK_N)
+    grid = (num_blocks_m, num_blocks_n, B)
+
+    padded_dense_to_jagged2_kernel[grid](
+        values,
+        lengths,
+        offsets,
+        output_jagged,
+        values.stride(0),
+        values.stride(1),
+        values.stride(2),
+        max_length,
+        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+        BLOCK_M,
+        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+        BLOCK_N,
+    )
+
+    return output_jagged
+
+
+class Jagged2ToPaddedDense(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        ctx,
+        values: torch.Tensor,
+        offsets: torch.Tensor,
+        max_length: int,
+        padding_value: float,
+    ) -> torch.Tensor:
+        lengths_square = offsets[1:] - offsets[0:-1:1]
+        lengths = torch.sqrt(lengths_square).to(torch.int32)
+
+        ctx.max_length = max_length
+        ctx.save_for_backward(lengths, offsets)
+
+        output = jagged2_to_padded_dense_fwd(
+            values, lengths, offsets, max_length, padding_value
+        )
+        return output
+
+    @staticmethod
+    # pyre-fixme
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> tuple[torch.Tensor, None, None, None]:
+        max_length = ctx.max_length
+        (lengths, offsets) = ctx.saved_tensors
+        grad_in = padded_dense_to_jagged2_fwd(grad_output, lengths, offsets, max_length)
+        return (grad_in, None, None, None)
+
+
+def jagged2_to_padded_dense(
+    values: torch.Tensor,
+    offsets: torch.Tensor,
+    max_length: int,
+    padding_value: float = 0.0,
+) -> torch.Tensor:
+    """
+    values: jagged tensor with size [sum(Ni * Ni)]
+    offsets: offsets for jagged tensor, with size [B + 1]
+    max_length: maximum sequence length in the batch
+    padding_value: value to use for padding
+    return padded dense tensor of size [B, N, N]
+    """
+    values = expect_contiguous(values)
+    offsets = expect_contiguous(offsets)
+
+    return Jagged2ToPaddedDense.apply(values, offsets, max_length, padding_value)
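
For context (not part of the published diff): a minimal usage sketch of jagged2_to_padded_dense from the file above. The import path is assumed from the file listing (item 59) and may differ from how the package re-exports the symbol; the Triton kernels require a CUDA device, and the values here are illustrative only.

import torch

# Assumed import path, based on the file list above (item 59).
from fbgemm_gpu.sll.triton.triton_jagged2_to_padded_dense import (
    jagged2_to_padded_dense,
)

# Two square jagged blocks of sizes 2x2 and 3x3, flattened into one values tensor.
Ni = [2, 3]
values = torch.randn(sum(n * n for n in Ni), device="cuda")  # [sum(Ni * Ni)] = [13]
offsets = torch.tensor([0, 4, 13], device="cuda")            # [B + 1], cumulative Ni * Ni

# Pad each Ni x Ni block into a max_length x max_length slice of the output.
padded = jagged2_to_padded_dense(values, offsets, max_length=3, padding_value=0.0)
print(padded.shape)  # torch.Size([2, 3, 3])

Note that Jagged2ToPaddedDense.forward recovers the per-example lengths as the square root of each offsets span, so the offsets must describe exactly square Ni x Ni blocks.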
@@ -0,0 +1,418 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+import triton
+import triton.language as tl
+
+
+def set_block_size(N: int) -> int:
+    if N > 64:
+        return 64
+    elif N > 16:
+        return 32
+    else:
+        return 16
+
+
+# TODO add autotune to find best block size
+# add supergroup to optimize GPU cache
+@triton.jit
+def jagged_dense_bmm_kernel(
+    a_ptr,
+    a_offset_ptr,
+    b_ptr,
+    c_ptr,
+    N,
+    K,
+    stride_am,
+    stride_ak,
+    stride_bl,  # batch idx
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    max_seq_len,  # max sequence length for jagged tensor
+    allow_tf32: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    """Kernel for computing the matmul C = A x B.
+    A has shape (sum_B(M_i), K), B has shape (B, K, N) and C has shape (sum_B(M_i), N)
+    """
+    pid_batch = tl.program_id(0)
+    pid = tl.program_id(1)
+
+    # a_offset_ptr has stride of 1
+    # row_start for jagged tensor
+    begin = tl.load(a_offset_ptr + pid_batch)
+    end = tl.load(a_offset_ptr + pid_batch + 1)
+    M = tl.minimum(end - begin, max_seq_len)  # in case M > max seq len
+    if M == 0:
+        return
+
+    # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+
+    # if pid_m * BLOCK_SIZE_M >= M, then this block doesn't need to be computed
+    if pid_m * BLOCK_SIZE_M >= M:
+        return
+
+    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+
+    if pid_n * BLOCK_SIZE_N >= N:
+        return
+
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (
+        offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + begin * stride_am
+    )  # jagged tensor ptr
+    b_ptrs = b_ptr + (
+        offs_k[:, None] * stride_bk
+        + offs_bn[None, :] * stride_bn
+        + pid_batch * stride_bl
+    )  # dense tensor ptr
+
+    c = tl.zeros(
+        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
+    )  # TODO, make this flexible
+
+    # Compute c[m, n] for 1 example of the batch
+    for k in range(0, K, BLOCK_SIZE_K):
+        updated_offset = k + offs_k
+        a = tl.load(
+            a_ptrs,
+            # pyre-fixme[16]: `int` has no attribute `__getitem__`.
+            mask=(updated_offset[None, :] < K) & (offs_am[:, None] < M),
+            other=0.0,
+        )
+        b = tl.load(
+            b_ptrs,
+            mask=(updated_offset[:, None] < K) & (offs_bn[None, :] < N),
+            other=0.0,
+        )
+        c += tl.dot(a, b, allow_tf32=allow_tf32)
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    c_ptrs = (
+        c_ptr
+        + stride_cm * offs_m[:, None]
+        + stride_cn * offs_n[None, :]
+        + begin * stride_cm
+    )
+    tl.store(c_ptrs, c, mask=mask)
+
+
+@triton.jit
+def jagged_jagged_bmm_kernel(
+    a_ptr,
+    a_offset_ptr,
+    b_ptr,
+    c_ptr,
+    M,
+    N,
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cl,
+    stride_cm,
+    stride_cn,
+    max_seq_len,
+    allow_tf32: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    """
+    Kernel for computing the matmul C = A x B.
+    A has shape (M, sum_B(Ki)), B has shape (sum_B(Ki), N) and C has shape (B, M, N)
+    """
+    pid_batch = tl.program_id(0)
+    pid = tl.program_id(1)
+
+    # need to make sure a_offset_ptr has stride of 1
+    begin = tl.load(a_offset_ptr + pid_batch)
+    end = tl.load(a_offset_ptr + pid_batch + 1)
+    K = end - begin  # K for current pid_batch
+    K = tl.minimum(K, max_seq_len)
+    # if K == 0:
+    #     return
+
+    # calculate pid_m and pid_n
+    # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = (
+        a_ptr
+        + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+        + begin * stride_ak
+    )
+    b_ptrs = (
+        b_ptr
+        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+        + begin * stride_bk
+    )
+
+    c = tl.zeros(
+        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
+    )  # TODO, make this flexible
+    for k in range(0, K, BLOCK_SIZE_K):
+        updated_offset = k + offs_k
+        a = tl.load(
+            a_ptrs,
+            # pyre-fixme[16]: `int` has no attribute `__getitem__`.
+            mask=((updated_offset[None, :] < K) & (offs_am[:, None] < M)),
+            other=0.0,
+        )
+        b = tl.load(
+            b_ptrs,
+            mask=((updated_offset[:, None] < K) & (offs_bn[None, :] < N)),
+            other=0.0,
+        )
+        c += tl.dot(a, b, allow_tf32=allow_tf32)
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    c_ptrs = (
+        c_ptr
+        + stride_cm * offs_m[:, None]
+        + stride_cn * offs_n[None, :]
+        + stride_cl * pid_batch
+    )
+
+    tl.store(c_ptrs, c, mask=mask)
+
+
+def triton_jagged_dense_bmm(a, b, a_offsets, max_seq_len, allow_tf32):
+    # check constraints
+    assert a.shape[1] == b.shape[1], "incompatible dimensions"
+    assert a_offsets.is_contiguous(), "A offsets must be contiguous"
+    sum_B, K = a.shape
+    B, K, N = b.shape
+    # Use zeros instead of empty to handle the corner case when the jagged tensor has length > max seq len
+    # In that case, it is possible that the output is inconsistent with the padded version if empty is used
+    c = a.new_zeros((sum_B, N))
+
+    BLOCK_SIZE_M = 32 if max_seq_len < 50 else 64
+    BLOCK_SIZE_N = set_block_size(N)
+    BLOCK_SIZE_K = set_block_size(K)
+
+    # 2D launch kernel where each block gets its own program.
+    # TODO, is this the best way to handle the launch grid?
+    # The grid size on the M axis is often larger than required due to max_seq_len
+    grid = (
+        B,
+        triton.cdiv(max_seq_len, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),
+    )
+
+    jagged_dense_bmm_kernel[grid](
+        a,
+        a_offsets,
+        b,
+        c,
+        N,
+        K,
+        a.stride(0),
+        a.stride(1),
+        b.stride(0),
+        b.stride(1),
+        b.stride(2),
+        c.stride(0),
+        c.stride(1),
+        max_seq_len,
+        allow_tf32,
+        BLOCK_SIZE_M,
+        BLOCK_SIZE_N,
+        BLOCK_SIZE_K,
+    )
+    return c
+
+
+def triton_jagged_jagged_bmm(a, b, a_offsets, max_seq_len, allow_tf32):
+    # check constraints
+    assert a.shape[1] == b.shape[0], "incompatible dimensions"
+    assert a_offsets.is_contiguous(), "A offsets must be contiguous"
+    M, _ = a.shape
+    _, N = b.shape
+    B = a_offsets.size(0) - 1
+    # allocate output
+    c = torch.empty((B, M, N), device=a.device, dtype=a.dtype)
+    # 2D launch kernel where each block gets its own program.
+    BLOCK_SIZE_M = set_block_size(M)
+    BLOCK_SIZE_N = set_block_size(N)
+    BLOCK_SIZE_K = 32
+    grid = (
+        B,
+        triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),
+    )
+    jagged_jagged_bmm_kernel[grid](
+        a,
+        a_offsets,
+        b,
+        c,
+        M,
+        N,
+        a.stride(0),
+        a.stride(1),
+        b.stride(0),
+        b.stride(1),
+        c.stride(0),
+        c.stride(1),
+        c.stride(2),
+        max_seq_len,
+        allow_tf32,
+        BLOCK_SIZE_M,
+        BLOCK_SIZE_N,
+        BLOCK_SIZE_K,
+    )
+    return c
+
+
+class JaggedDenseBmm(torch.autograd.Function):
+    """
+    Compute batch matrix multiplication between JaggedTensor and dense tensor
+    dense: [B, N, D] * [B, D, T] = [B, N, T]
+    jagged: [Sum_B, D] * [B, D, T] = [Sum_B, T]
+    """
+
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        x_offsets: torch.Tensor,
+        N: int,
+        allow_tf32: bool,
+    ):
+        ctx.save_for_backward(x, y, x_offsets)
+        ctx.N = N
+        ctx.allow_tf32 = allow_tf32
+        return triton_jagged_dense_bmm(x, y, x_offsets, N, allow_tf32=allow_tf32)
+
+    @staticmethod
+    # pyre-fixme
+    def backward(ctx, grad_output: torch.Tensor):
+        """
+        # X = [Sum_B, D]
+        # Y = [B, D, T]
+        # Z = X * Y = [Sum_B, T]
+        # dX = dZ * Y^T  # [Sum_B, T] * [B, T, D] = [Sum_B, D]
+        # dY = X^T * dZ  # [D, Sum_B] * [Sum_B, T] = [B, D, T]
+        """
+
+        # logging.info(f"Jagged bmm backward called")
+
+        (x, y, x_offsets) = ctx.saved_tensors
+        N = ctx.N
+        grad_x = triton_jagged_dense_bmm(
+            grad_output, y.permute(0, 2, 1), x_offsets, N, allow_tf32=ctx.allow_tf32
+        )
+        grad_y = triton_jagged_jagged_bmm(
+            x.T, grad_output, x_offsets, N, allow_tf32=ctx.allow_tf32
+        )
+        return grad_x, grad_y, None, None, None
+
+
+class JaggedJaggedBmm(torch.autograd.Function):
+    """
+    Compute batch matrix multiplication between JaggedTensor and Jagged Tensor
+    dense: [B, D, N] * [B, N, T] = [B, D, T]
+    jagged: [Sum_B, D].T * [Sum_B, T] = [B, D, T]
+    """
+
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        x_offsets: torch.Tensor,
+        N: int,
+        allow_tf32,
+    ):
+        ctx.save_for_backward(x, y, x_offsets)
+        ctx.N = N
+        ctx.allow_tf32 = allow_tf32
+        return triton_jagged_jagged_bmm(x.T, y, x_offsets, N, allow_tf32=allow_tf32)
+
+    @staticmethod
+    # pyre-fixme
+    def backward(ctx, grad_output: torch.Tensor):
+        """
+        # X = [Sum_B, D]
+        # Y = [Sum_B, T]
+        # Z = X^T * Y = [B, D, T]
+        # dX^T = dZ * Y^T -> dX = Y * dZ^T
+        # dY = X * dZ
+        """
+        (x, y, offsets) = ctx.saved_tensors
+        N = ctx.N
+        grad_x = triton_jagged_dense_bmm(
+            y, grad_output.permute(0, 2, 1), offsets, N, allow_tf32=ctx.allow_tf32
+        )
+        grad_y = triton_jagged_dense_bmm(
+            x, grad_output, offsets, N, allow_tf32=ctx.allow_tf32
+        )
+        return grad_x, grad_y, None, None, None
+
+
+def jagged_dense_bmm(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    x_offsets: torch.Tensor,
+    N: int,
+    allow_tf32: bool,
+    use_fbgemm_kernel: bool = True,
+) -> torch.Tensor:
+    """
+    Compute batch matrix multiplication between JaggedTensor and dense tensor
+    dense: [B, N, D] * [B, D, T] = [B, N, T]
+    jagged: [Sum_B, D] * [B, D, T] = [Sum_B, T]
+    """
+    if use_fbgemm_kernel:
+        return torch.ops.fbgemm.jagged_dense_bmm(x, x_offsets, y, N)[0]
+    else:
+        return JaggedDenseBmm.apply(x, y, x_offsets, N, allow_tf32)
+
+
+def jagged_jagged_bmm(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    x_offsets: torch.Tensor,
+    N: int,
+    allow_tf32: bool,
+    use_fbgemm_kernel: bool = True,
+):
+    """
+    Compute batch matrix multiplication between JaggedTensor and Jagged Tensor
+    dense: [B, D, N] * [B, N, T] = [B, D, T]
+    jagged: [Sum_B, D].T * [Sum_B, T] = [B, D, T]
+    """
+    if use_fbgemm_kernel:
+        return torch.ops.fbgemm.jagged_jagged_bmm(x, y, x_offsets, N)
+    else:
+        return JaggedJaggedBmm.apply(x, y, x_offsets, N, allow_tf32)
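
For context (not part of the published diff): a minimal sketch of the two wrapper functions at the end of this file, using the Triton path (use_fbgemm_kernel=False) so the autograd Functions above are exercised. The import path is assumed from the file listing (item 60), the N argument plays the role of the maximum sequence length, and a CUDA device is required.

import torch

# Assumed import path, based on the file list above (item 60).
from fbgemm_gpu.sll.triton.triton_jagged_bmm import jagged_dense_bmm, jagged_jagged_bmm

B, D, T, max_seq_len = 2, 16, 8, 4
lengths = torch.tensor([3, 4], device="cuda")
x_offsets = torch.cat(
    [torch.zeros(1, dtype=torch.int64, device="cuda"), lengths.cumsum(0)]
)                                                      # [0, 3, 7]
x = torch.randn(int(x_offsets[-1]), D, device="cuda")  # jagged values, [Sum_B, D]
y = torch.randn(B, D, T, device="cuda")                # dense, [B, D, T]

# Jagged x dense: [Sum_B, D] * [B, D, T] -> [Sum_B, T]
z = jagged_dense_bmm(x, y, x_offsets, max_seq_len, allow_tf32=True,
                     use_fbgemm_kernel=False)
print(z.shape)  # torch.Size([7, 8])

# Jagged x jagged: [Sum_B, D].T * [Sum_B, T] -> [B, D, T]
w = jagged_jagged_bmm(x, z, x_offsets, max_seq_len, allow_tf32=True,
                      use_fbgemm_kernel=False)
print(w.shape)  # torch.Size([2, 16, 8])

With the default use_fbgemm_kernel=True, both wrappers instead dispatch to the compiled torch.ops.fbgemm operators shipped in this wheel rather than the Triton kernels defined above.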