fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+import triton
+import triton.language as tl
+
+from .common import next_power_of_two
+
+
+@triton.jit
+def jagged_self_substraction_jagged_out_kernel(
+    a_ptr,  # jagged
+    b_ptr,  # jagged
+    a_offsets_ptr,
+    b_offsets_ptr,
+    max_seq_len,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_batch = tl.program_id(0)
+    pid_index = tl.program_id(1)
+
+    a_offset = tl.load(a_offsets_ptr + pid_batch)
+    a_length = tl.load(a_offsets_ptr + pid_batch + 1) - a_offset
+    a_length = tl.minimum(a_length, max_seq_len + 1)
+
+    if a_length <= 1:
+        return
+
+    N = a_length - 1
+    if pid_index >= N:
+        return
+
+    a_cur = tl.load(a_ptr + a_offset + pid_index)
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < N
+    a_row = tl.load(a_ptr + a_offset + offs + 1, mask=mask)
+    b = a_cur - a_row
+
+    b_offset = tl.load(b_offsets_ptr + pid_batch)
+    tl.store(b_ptr + b_offset + pid_index * N + offs, b, mask=mask)
+
+
+def triton_jagged_self_substraction_jagged_out(
+    jagged_A: torch.Tensor,
+    offsets_a: torch.Tensor,
+    offsets_b: torch.Tensor,
+    max_seq_len,
+) -> torch.Tensor:
+    B = offsets_a.size(0) - 1
+
+    jagged_B = torch.empty(
+        (int(offsets_b[-1].item())), device=jagged_A.device, dtype=jagged_A.dtype
+    )
+
+    BLOCK_SIZE = max(next_power_of_two(max_seq_len), 16)
+    grid = (B, max_seq_len)
+
+    jagged_self_substraction_jagged_out_kernel[grid](
+        jagged_A,
+        jagged_B,
+        offsets_a,
+        offsets_b,
+        max_seq_len,
+        BLOCK_SIZE,
+    )
+
+    return jagged_B
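The caller is responsible for sizing offsets_b so that sample i gets an N_i * N_i slot, where N_i = min(len_i, max_seq_len + 1) - 1 (empty when len_i <= 1). A minimal usage sketch, assuming a CUDA device; the offsets construction here is illustrative scaffolding, not part of the package:

import torch

# Per-sample lengths of the input jagged tensor.
lengths = torch.tensor([4, 1, 6], device="cuda")
offsets_a = torch.zeros(lengths.numel() + 1, dtype=torch.int32, device="cuda")
offsets_a[1:] = torch.cumsum(lengths, dim=0)

# Each sample of length L contributes an N x N output block,
# with N = min(L, max_seq_len + 1) - 1, clamped to >= 0.
max_seq_len = 8
n = (torch.clamp(lengths, max=max_seq_len + 1) - 1).clamp(min=0)
offsets_b = torch.zeros_like(offsets_a)
offsets_b[1:] = torch.cumsum(n * n, dim=0)

jagged_A = torch.randn(int(offsets_a[-1]), device="cuda")
jagged_B = triton_jagged_self_substraction_jagged_out(
    jagged_A, offsets_a, offsets_b, max_seq_len
)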
fbgemm_gpu/sll/triton/triton_jagged_softmax.py
@@ -0,0 +1,463 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def jagged_softmax_kernel(
+    input_ptr,
+    output_ptr,
+    input_offsets_ptr,
+    input_row_stride,
+    input_head_stride,
+    output_row_stride,
+    output_head_stride,
+    max_seq_len: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,  # BLOCK_SIZE > N (seq len)
+):
+    """
+    input shape is [SUM_B, H]
+    output shape is [SUM_B, H]
+    """
+
+    pid_batch = tl.program_id(0)
+    pid_head = tl.program_id(1)
+    row_begin = tl.load(input_offsets_ptr + pid_batch)
+    row_end = tl.load(input_offsets_ptr + pid_batch + 1)
+    N = tl.minimum(
+        max_seq_len, row_end - row_begin
+    )  # number of rows to consider for softmax
+    if N == 0:
+        return
+
+    row_start_ptr = input_ptr + row_begin * input_row_stride
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    input_ptrs = (
+        row_start_ptr + col_offsets * input_row_stride + pid_head * input_head_stride
+    )
+    row = tl.load(input_ptrs, mask=col_offsets < N, other=-float("inf"))
+    row_mins_max = row - tl.max(row, axis=0)
+    numerator = tl.exp(row_mins_max)
+    denominator = tl.sum(numerator, axis=0)
+    softmax_output = numerator / denominator
+
+    output_row_start_ptr = output_ptr + row_begin * output_row_stride
+    output_ptrs = (
+        output_row_start_ptr
+        + col_offsets * output_row_stride
+        + pid_head * output_head_stride
+    )
+
+    tl.store(output_ptrs, softmax_output, mask=col_offsets < N)
+
+
+def jagged_softmax_(x: torch.Tensor, x_offsets: torch.Tensor, max_seq_len: int):
+    sum_B, H = x.shape
+    B = x_offsets.size(0) - 1
+    BLOCK_SIZE = max(triton.next_power_of_2(max_seq_len), 8)
+
+    y = torch.zeros(
+        sum_B, H, device=x.device, dtype=x.dtype
+    )  # use zeros instead of empty to ensure behavior consistent with the padded version
+    jagged_softmax_kernel[(B, H)](
+        x,
+        y,
+        x_offsets,
+        x.stride(0),
+        x.stride(1),
+        y.stride(0),
+        y.stride(1),
+        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+        max_seq_len,
+        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+        BLOCK_SIZE,
+    )
+
+    return y
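For intuition, the kernel's semantics match a per-segment loop over torch.softmax. A reference sketch (a hypothetical helper, not part of the package) that mirrors the clamping and zero-fill behavior above:

def jagged_softmax_ref(x: torch.Tensor, x_offsets: torch.Tensor, max_seq_len: int) -> torch.Tensor:
    # Per-head softmax over the sequence dimension of each jagged segment;
    # rows beyond max_seq_len stay zero, matching the zero-initialized output.
    y = torch.zeros_like(x)
    for b in range(x_offsets.numel() - 1):
        begin, end = int(x_offsets[b]), int(x_offsets[b + 1])
        n = min(max_seq_len, end - begin)
        if n == 0:
            continue
        y[begin : begin + n] = torch.softmax(x[begin : begin + n], dim=0)
    return y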
+
+
+@triton.jit
+def jagged_softmax_backward_kernel(
+    grad_output_ptr,
+    softmax_output_ptr,
+    grad_input_ptr,  # return value
+    input_offsets_ptr,
+    grad_output_row_stride,
+    grad_output_head_stride,
+    softmax_output_row_stride,
+    softmax_output_head_stride,
+    grad_input_row_stride,
+    grad_input_head_stride,
+    max_seq_len: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    grad_output shape is [SUM_B, H]
+    softmax_output shape is [SUM_B, H]
+    grad_input shape is [SUM_B, H]
+    """
+
+    pid_batch = tl.program_id(0)
+    pid_head = tl.program_id(1)
+    row_begin = tl.load(input_offsets_ptr + pid_batch)
+    row_end = tl.load(input_offsets_ptr + pid_batch + 1)
+    N = tl.minimum(
+        max_seq_len, row_end - row_begin
+    )  # number of rows to consider for softmax
+
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    grad_output_ptrs = (
+        grad_output_ptr
+        + row_begin * grad_output_row_stride
+        + col_offsets * grad_output_row_stride
+        + pid_head * grad_output_head_stride
+    )
+    softmax_output_ptrs = (
+        softmax_output_ptr
+        + row_begin * softmax_output_row_stride
+        + col_offsets * softmax_output_row_stride
+        + pid_head * softmax_output_head_stride
+    )
+    grad_output_row = tl.load(grad_output_ptrs, mask=col_offsets < N, other=0.0)
+    softmax_output_row = tl.load(softmax_output_ptrs, mask=col_offsets < N, other=0.0)
+
+    sum_value = tl.sum(grad_output_row * softmax_output_row, axis=0)
+    grad_input_row = (grad_output_row - sum_value) * softmax_output_row
+    grad_input_ptrs = (
+        grad_input_ptr
+        + row_begin * grad_input_row_stride
+        + col_offsets * grad_input_row_stride
+        + pid_head * grad_input_head_stride
+    )
+    tl.store(grad_input_ptrs, grad_input_row, mask=col_offsets < N)
+
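The backward kernel applies the standard softmax Jacobian-vector product per segment and head: for y = softmax(x) along the sequence axis and upstream gradient g,

\[
\frac{\partial L}{\partial x_i} = y_i \Bigl( g_i - \sum_j g_j \, y_j \Bigr),
\]

where sum_value in the kernel holds the inner product \(\sum_j g_j y_j\).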
+
+class JaggedSoftmax(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme
+    def forward(ctx, x: torch.Tensor, x_offsets: torch.Tensor, max_seq_len: int):
+        y = jagged_softmax_(x, x_offsets, max_seq_len)
+        ctx.save_for_backward(y, x_offsets)
+        ctx.max_seq_len = max_seq_len
+
+        return y
+
+    @staticmethod
+    # pyre-fixme
+    def backward(ctx, grad_output: torch.Tensor):
+        y, x_offsets = ctx.saved_tensors
+        max_seq_len = ctx.max_seq_len
+
+        sum_B, H = y.shape
+        B = x_offsets.size(0) - 1
+        BLOCK_SIZE = max(triton.next_power_of_2(max_seq_len), 8)
+        grad = torch.zeros(
+            sum_B, H, device=y.device, dtype=y.dtype
+        )  # use zeros instead of empty to guarantee the behavior
+
+        jagged_softmax_backward_kernel[(B, H)](
+            grad_output,
+            y,
+            grad,
+            x_offsets,
+            grad_output.stride(0),
+            grad_output.stride(1),
+            y.stride(0),
+            y.stride(1),
+            grad.stride(0),
+            grad.stride(1),
+            max_seq_len,
+            # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+            BLOCK_SIZE,
+        )
+
+        return grad, None, None
+
+
+def jagged_softmax(
+    x: torch.Tensor,
+    x_offsets: torch.Tensor,
+    max_seq_len: int,
+    use_fbgemm_kernel: bool = True,
+):
+    """
+    GPU version of jagged softmax: [sum(softmax([B_i, D]))]
+    """
+    if use_fbgemm_kernel:
+        return torch.ops.fbgemm.jagged_softmax(x, x_offsets, max_seq_len)[0]
+    else:
+        return JaggedSoftmax.apply(x, x_offsets, max_seq_len)
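A quick end-to-end sketch of the Triton path (use_fbgemm_kernel=False), assuming a CUDA device; shapes and values here are illustrative:

# 3 segments of lengths 3, 4, and 3; H = 4 heads per row.
x = torch.randn(10, 4, device="cuda", requires_grad=True)
x_offsets = torch.tensor([0, 3, 7, 10], device="cuda")

y = jagged_softmax(x, x_offsets, max_seq_len=8, use_fbgemm_kernel=False)
# Each head of each segment now sums to 1 along the sequence dim.
y.sum().backward()  # exercises jagged_softmax_backward_kernel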
+
+
+# works now
+# we use row offsets for the softmax calculation
+# for now, offsets row == offsets col
+@triton.jit
+def jagged_2_softmax_kernel(
+    input_ptr,
+    output_ptr,
+    offsets_row_ptr,  # seq
+    offsets_col_ptr,  # head
+    offsets_overall_ptr,  # offsets for the overall matrix = seq_length_i * head_i
+    input_stride,
+    output_stride,
+    transpose,  # one if the input is transposed, otherwise zero
+    max_seq_len_row,  # max_seq_len for row (seq)
+    max_seq_len_col,  # max_seq_len for col (head)
+    BLOCK_SIZE: tl.constexpr,  # BLOCK_SIZE > seq_length
+):
+    """
+    input shape is [sum_B(Ni * Hi)]
+    output shape is [sum_B(Ni * Hi)]
+    Padded version = [B, N, H]
+    Calculate softmax along the N dim
+    Each kernel instance calculates softmax for 1 sample and 1 head
+    offsets_row.size == offsets_col.size == offsets_overall.size
+    """
+
+    pid_batch = tl.program_id(0)
+    pid_head = tl.program_id(1)
+    # start location of current example
+    begin = tl.load(offsets_overall_ptr + pid_batch)
+    # end = tl.load(offsets_overall_ptr + pid_batch + 1)  # noqa F841
+    # end - begin = M_i * N_i
+
+    # softmax on row
+    if transpose:
+        N = tl.load(offsets_row_ptr + pid_batch + 1) - tl.load(
+            offsets_row_ptr + pid_batch
+        )
+        H = tl.load(offsets_col_ptr + pid_batch + 1) - tl.load(
+            offsets_col_ptr + pid_batch
+        )
+        stride_n = H
+        stride_h = H // H  # 1
+        # sometimes H is larger than max_seq_len_col
+        H = tl.minimum(max_seq_len_col, H)
+        N = tl.minimum(max_seq_len_row, N)
+    # softmax on col
+    else:
+        N = tl.load(offsets_col_ptr + pid_batch + 1) - tl.load(
+            offsets_col_ptr + pid_batch
+        )
+        H = tl.load(offsets_row_ptr + pid_batch + 1) - tl.load(
+            offsets_row_ptr + pid_batch
+        )
+        stride_h = N
+        stride_n = N // N  # 1
+        H = tl.minimum(max_seq_len_row, H)
+        N = tl.minimum(max_seq_len_col, N)
+
+    if pid_head >= H:  # TODO: double-check the equality here
+        return
+    if H == 0 or N == 0:
+        return
+
+    # start of the current example
+    start_ptr = input_ptr + begin * input_stride
+    # offsets for n
+    offsets = tl.arange(0, BLOCK_SIZE)
+
+    # Load a softmax row
+    input_ptrs = (
+        start_ptr
+        + offsets * input_stride * stride_n
+        + pid_head * input_stride * stride_h
+    )  # start + n offsets + head offset
+    row = tl.load(input_ptrs, mask=offsets < N, other=-float("inf"))
+    row_mins_max = row - tl.max(row, axis=0)
+    numerator = tl.exp(row_mins_max)
+    denominator = tl.sum(numerator, axis=0)
+    softmax_output = numerator / denominator
+
+    # calculate the output ptr; should be similar to the input
+    output_start_ptr = output_ptr + begin * output_stride
+    output_ptrs = (
+        output_start_ptr
+        + offsets * output_stride * stride_n
+        + pid_head * output_stride * stride_h
+    )
+    tl.store(output_ptrs, softmax_output, mask=offsets < N)
+
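Since row offsets equal column offsets for now (per the comment above), each sample is a flattened square N_i x N_i matrix, and the transpose=True layout stores it row-major. A plain-PyTorch sketch of the forward semantics (a hypothetical reference helper, assuming square per-sample matrices):

def jagged2_softmax_ref(
    x: torch.Tensor,
    offsets: torch.Tensor,        # per-sample sequence offsets (N_i)
    offsets_total: torch.Tensor,  # per-sample offsets of N_i * N_i
    max_seq_len: int,
) -> torch.Tensor:
    y = torch.zeros_like(x)
    for b in range(offsets.numel() - 1):
        ln = int(offsets[b + 1] - offsets[b])  # true N_i (== H_i here)
        begin = int(offsets_total[b])
        n = min(max_seq_len, ln)
        if n == 0:
            continue
        mat = x[begin : begin + ln * ln].view(ln, ln)
        out = y[begin : begin + ln * ln].view(ln, ln)
        # Softmax along the row (sequence) dim of the clamped top-left
        # n x n block; everything else stays zero.
        out[:n, :n] = torch.softmax(mat[:n, :n], dim=0)
    return y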
+
+# TODO: pending test
+@triton.jit
+def jagged_2_softmax_backward_kernel(
+    grad_output_ptr,  # input
+    softmax_output_ptr,
+    grad_input_ptr,  # return value
+    offsets_row_ptr,
+    offsets_col_ptr,
+    offsets_overall_ptr,
+    grad_output_stride,
+    softmax_output_stride,
+    grad_input_stride,
+    transpose,  # transpose
+    max_seq_len_row: tl.constexpr,
+    max_seq_len_col: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_batch = tl.program_id(0)
+    pid_head = tl.program_id(1)
+    begin = tl.load(offsets_overall_ptr + pid_batch)
+    # end = tl.load(offsets_overall_ptr + pid_batch + 1)  # noqa F841
+
+    # softmax on row
+    if transpose:
+        N = tl.load(offsets_row_ptr + pid_batch + 1) - tl.load(
+            offsets_row_ptr + pid_batch
+        )
+        H = tl.load(offsets_col_ptr + pid_batch + 1) - tl.load(
+            offsets_col_ptr + pid_batch
+        )
+        stride_n = H
+        stride_h = H // H  # 1
+        # sometimes H is larger than max_seq_len_col
+        H = tl.minimum(max_seq_len_col, H)
+        N = tl.minimum(max_seq_len_row, N)
+    # softmax on col
+    else:
+        N = tl.load(offsets_col_ptr + pid_batch + 1) - tl.load(
+            offsets_col_ptr + pid_batch
+        )
+        H = tl.load(offsets_row_ptr + pid_batch + 1) - tl.load(
+            offsets_row_ptr + pid_batch
+        )
+        stride_h = N
+        stride_n = N // N  # 1
+        H = tl.minimum(max_seq_len_row, H)
+        N = tl.minimum(max_seq_len_col, N)
+
+    if pid_head >= H:
+        return
+    if H == 0 or N == 0:
+        return
+
+    start_ptr = grad_output_ptr + begin * grad_output_stride
+    offsets = tl.arange(0, BLOCK_SIZE)
+
+    grad_output_ptrs = (
+        start_ptr
+        + offsets * grad_output_stride * stride_n
+        + pid_head * grad_output_stride * stride_h
+    )
+    softmax_output_ptrs = (
+        softmax_output_ptr
+        + begin * softmax_output_stride
+        + offsets * softmax_output_stride * stride_n
+        + pid_head * softmax_output_stride * stride_h
+    )
+
+    grad_output_row = tl.load(grad_output_ptrs, mask=offsets < N, other=0.0)
+    softmax_output_row = tl.load(softmax_output_ptrs, mask=offsets < N, other=0.0)
+
+    sum_value = tl.sum(grad_output_row * softmax_output_row, axis=0)
+    grad_input_row = (grad_output_row - sum_value) * softmax_output_row
+
+    grad_input_row_start_ptr = grad_input_ptr + begin * grad_input_stride
+    grad_input_ptrs = (
+        grad_input_row_start_ptr
+        + offsets * grad_input_stride * stride_n
+        + pid_head * grad_input_stride * stride_h
+    )
+    tl.store(grad_input_ptrs, grad_input_row, mask=offsets < N)
+
+
+
+class Jagged2Softmax(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        x_offsets: torch.Tensor,
+        row_offsets: torch.Tensor,
+        head_offsets: torch.Tensor,
+        max_seq_len_row: int,
+        max_seq_len_head: int,
+        transpose: bool = True,
+    ) -> torch.Tensor:
+        B = x_offsets.size(0) - 1
+        BLOCK_SIZE = max(triton.next_power_of_2(max_seq_len_row), 8)
+
+        y = torch.zeros(x.size(0), device=x.device, dtype=x.dtype)
+        jagged_2_softmax_kernel[(B, max_seq_len_head)](
+            x,
+            y,
+            row_offsets,
+            head_offsets,
+            x_offsets,
+            x.stride(0),
+            y.stride(0),
+            transpose,  # transpose
+            max_seq_len_row,
+            max_seq_len_head,
+            # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+            BLOCK_SIZE,
+        )
+
+        ctx.save_for_backward(y, x_offsets, row_offsets, head_offsets)
+        ctx.max_seq_len_row = max_seq_len_row
+        ctx.max_seq_len_head = max_seq_len_head
+        ctx.transpose = transpose
+
+        return y
+
+    @staticmethod
+    # pyre-fixme
+    def backward(ctx, grad_output: torch.Tensor):
+        # TODO: the backward kernel currently has small numerical issues.
+        y, x_offsets, row_offsets, head_offsets = ctx.saved_tensors
+        B = x_offsets.size(0) - 1
+        max_seq_len_row = ctx.max_seq_len_row
+        max_seq_len_head = ctx.max_seq_len_head
+        BLOCK_SIZE = max(triton.next_power_of_2(max_seq_len_row), 8)
+
+        grad = torch.zeros(y.size(0), device=y.device, dtype=y.dtype)
+
+        jagged_2_softmax_backward_kernel[(B, max_seq_len_head)](
+            grad_output,
+            y,
+            grad,
+            row_offsets,
+            head_offsets,
+            x_offsets,
+            grad_output.stride(0),
+            softmax_output_stride=y.stride(0),
+            grad_input_stride=grad.stride(0),
+            transpose=ctx.transpose,  # transpose
+            max_seq_len_row=max_seq_len_row,
+            max_seq_len_col=max_seq_len_head,
+            # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+
+        return grad, None, None, None, None, None, None
+
+
+def jagged2_softmax(
+    x: torch.Tensor,
+    offsets: torch.Tensor,
+    offsets_total: torch.Tensor,
+    max_seq_len: int,
+    transpose: bool,
+):
+    """
+    GPU version of jagged2 softmax: [sum(softmax([B_i, B_i]))]
+    """
+    return Jagged2Softmax.apply(
+        x,
+        offsets_total,
+        offsets,
+        offsets,
+        max_seq_len,
+        max_seq_len,
+        transpose,
+    )