fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.

Files changed (127) hide show
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
@@ -0,0 +1,237 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
9
+
10
+ import abc
11
+ import logging
12
+ from collections import deque
13
+ from dataclasses import dataclass
14
+ from types import TracebackType
15
+ from typing import Callable, Optional, TypeVar
16
+
17
+ import torch
18
+
19
+
20
class TBEStatsReporter(abc.ABC):
    """
    Abstract interface for reporting TBE runtime statistics.

    A concrete implementation is free to perform custom aggregation (on an
    intended group key) and to choose its own reporting destination.

    Every report_XXX method is expected to be lightweight and fail-safe.
    """

    @abc.abstractmethod
    def should_report(self, iteration_step: int) -> bool:
        """
        Tell whether metrics should be reported for the given step.

        Implementations should keep this cheap and free of side effects, and
        should return immediately.
        """
        ...

    @abc.abstractmethod
    def register_stats(self, stats_name: str, amplifier: int = 1) -> None:
        """
        Add stats_name to the reporter's whitelist.

        NOTE(review): the meaning of `amplifier` is not documented in this
        interface -- confirm against a concrete implementation.
        """
        ...

    @abc.abstractmethod
    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        embedding_id: str = "",
        tbe_id: str = "",
        time_unit: str = "ms",
        enable_tb_metrics: bool = False,
    ) -> None:
        """
        Report how long a timed event took.
        """
        ...

    @abc.abstractmethod
    def report_data_amount(
        self,
        iteration_step: int,
        event_name: str,
        data_bytes: int,
        embedding_id: str = "",
        tbe_id: str = "",
        enable_tb_metrics: bool = False,
    ) -> None:
        """
        Report the size (in bytes) of some amount of data.
        """
        ...
73
+
74
+
75
class StdLogStatsReporter(TBEStatsReporter):
    """
    TBEStatsReporter that writes every reported stat through `logging.info`.

    Reporting is due on steps that are a multiple of `report_interval`.
    """

    def __init__(self, report_interval: int) -> None:
        assert report_interval > 0, "Report interval must be positive"
        self.report_interval = report_interval

    def register_stats(self, stats_name: str, amplifier: int = 1) -> None:
        # This reporter keeps no whitelist, so registration is a no-op.
        return None

    def should_report(self, iteration_step: int) -> bool:
        remainder = iteration_step % self.report_interval
        return remainder == 0

    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        embedding_id: str = "",
        tbe_id: str = "",
        time_unit: str = "ms",
        enable_tb_metrics: bool = False,
    ) -> None:
        # Log one line describing how long the event took.
        message = f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} took {duration_ms} {time_unit} with {enable_tb_metrics}"
        logging.info(message)

    def report_data_amount(
        self,
        iteration_step: int,
        event_name: str,
        data_bytes: int,
        embedding_id: str = "",
        tbe_id: str = "",
        enable_tb_metrics: bool = False,
    ) -> None:
        # Log one line describing how many bytes the event used.
        message = f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} used {data_bytes} bytes with {enable_tb_metrics}"
        logging.info(message)

    def __repr__(self) -> str:
        return f"StdLogStatsReporter{{ report_interval={self.report_interval} }}"
115
+
116
+
117
@dataclass(frozen=True)
class TBEStatsReporterConfig:
    """
    Configuration object for a TBEStatsReporter.

    The config is what eventually instantiates the actual reporter, so the
    config can be deep-copied without the reporter itself being copied.
    """

    # Collect the required batches once every `interval` batches. A
    # non-positive value means no collection or reporting.
    interval: int = -1

    def create_reporter(self) -> Optional[TBEStatsReporter]:
        """
        Return the reporter for this config; the base config has none.

        The base class ships no reporter implementation, so a positive
        interval is rejected by the assertion below.
        """
        reporting_disabled = self.interval <= 0
        assert (
            reporting_disabled
        ), "Cannot specify interval without an actual implementation of reporter"
        return None
134
+
135
+
136
@dataclass(frozen=True)
class StdLogStatsReporterConfig(TBEStatsReporterConfig):
    """
    Reporter config that produces a StdLogStatsReporter.
    """

    def create_reporter(self) -> Optional[TBEStatsReporter]:
        """
        Build a StdLogStatsReporter reporting every `interval` steps, or
        return None when the interval is non-positive.
        """
        return (
            None
            if self.interval <= 0
            else StdLogStatsReporter(report_interval=self.interval)
        )
142
+
143
+
144
T = TypeVar("T")


class AsyncSeriesTimerRecordedContext:
    """
    Context manager that makes AsyncSeriesTimer easier to use. Example:
    ```
    timer : AsyncSeriesTimer = ...
    with timer.recording(ctx):
        cuda_kernel1()
        cuda_kernel2()
        cuda_kernel3()
    ```
    """

    def __init__(
        self,
        timer: "AsyncSeriesTimer",
        context: T,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        # `context` is handed back to the timer when the recording stops;
        # `stream` is forwarded to both start() and stop().
        self._timer = timer
        self._context = context
        self._stream = stream

    def __enter__(self) -> None:
        # Begin the recording; nothing is bound by `with ... as`.
        timer = self._timer
        timer.start(self._stream)

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # The recording is stopped whether or not the body raised. Returning
        # None means an in-flight exception is not suppressed.
        timer = self._timer
        timer.stop(self._context, self._stream)
179
+
180
+
181
class AsyncSeriesTimer:
    """
    Wrapper around torch.cuda.Event that measures the time between a series
    of CUDA events. Each start()/stop() pair measures the GPU time between
    the two calls. A new recording cannot be started while one is ongoing.

    Reporting is asynchronous because the timing result is not available
    immediately at stop(). Results are instead reported lazily: all
    unreported events are checked on every start() and stop() call.
    """

    def __init__(self, report_functor: Callable[[T, float], None]) -> None:
        # Called as report_functor(context, elapsed) for each finished pair.
        self._report_functor = report_functor
        # Start event of the recording in progress, or None when idle.
        self._active_start_event: Optional[torch.cuda.Event] = None
        # Completed (start, stop, context) recordings not yet reported.
        self._events_queue: deque[tuple[torch.cuda.Event, torch.cuda.Event, T]] = (
            deque()
        )

    def start(self, stream: Optional[torch.cuda.Stream] = None) -> None:
        """
        Begin a recording by recording a timing event on `stream`.
        """
        assert self._active_start_event is None, "There's an active recording"
        begin_event = torch.cuda.Event(enable_timing=True)
        self._active_start_event = begin_event
        begin_event.record(stream)
        self._lazy_report()

    def stop(self, context: T, stream: Optional[torch.cuda.Stream] = None) -> None:
        """
        End the ongoing recording and queue it, tagged with `context`, for
        lazy reporting.
        """
        assert self._active_start_event is not None, "There's no active recording"
        begin_event: torch.cuda.Event = self._active_start_event

        end_event = torch.cuda.Event(enable_timing=True)
        end_event.record(stream)
        self._events_queue.append((begin_event, end_event, context))
        self._active_start_event = None
        self._lazy_report()

    def recording(
        self, context: T, stream: Optional[torch.cuda.Stream] = None
    ) -> AsyncSeriesTimerRecordedContext:
        """
        Return a context manager that wraps one start()/stop() pair.
        """
        return AsyncSeriesTimerRecordedContext(self, context, stream)

    def _lazy_report(self) -> None:
        # The recordings form a series, so the earliest recorded one is
        # expected to finish first. Checking only the leftmost stop event is
        # therefore enough to decide whether anything can be reported now.
        pending = self._events_queue
        while pending:
            oldest_stop_event = pending[0][1]
            if not oldest_stop_event.query():
                # Even the earliest recording has not completed on the GPU,
                # so nothing is reported.
                return
            begin_event, end_event, context = pending.popleft()
            assert (
                begin_event.query()
            ), "Recording has start event later than stop event"
            elapsed = float(begin_event.elapsed_time(end_event))
            self._report_functor(context, elapsed)
@@ -0,0 +1,189 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
9
+
10
+ import torch
11
+
12
+ from fbgemm_gpu.sll.cpu import op_registrations as sll_cpu_registrations
13
+ from fbgemm_gpu.sll.meta import op_registrations as sll_meta_registrations
14
+ from fbgemm_gpu.utils import TorchLibraryFragment
15
+
16
# Library fragment for the "fbgemm" namespace. Every SLL operator schema
# below is declared on it, and the per-backend implementations are attached
# further down in this file via lib.register().
# NOTE(review): TorchLibraryFragment comes from fbgemm_gpu.utils; presumably
# it wraps a torch.library fragment -- confirm against that helper.
lib = TorchLibraryFragment("fbgemm")

# --- Batched matrix multiplication schemas ---
lib.define(
    """sll_jagged_dense_bmm(
Tensor x,
Tensor y,
Tensor x_offsets,
int N,
bool allow_tf32,
bool use_fbgemm_kernel=True
) -> Tensor
"""
)

lib.define(
    """sll_jagged_jagged_bmm(
Tensor x,
Tensor y,
Tensor x_offsets,
int N,
bool allow_tf32,
bool use_fbgemm_kernel=True
) -> Tensor
"""
)

# --- Concatenation, self-subtraction, padding and elementwise schemas ---
# This is the only schema in the file that returns a pair of tensors.
lib.define(
    """sll_dense_jagged_cat_jagged_out(
Tensor a,
Tensor b,
Tensor a_offsets,
int max_seq_len
) -> (Tensor, Tensor)
"""
)

lib.define(
    """sll_jagged_self_substraction_jagged_out(
Tensor a,
Tensor offsets_a,
Tensor offsets_b,
int max_seq_len
) -> Tensor
"""
)

lib.define(
    """sll_jagged2_to_padded_dense(
Tensor values,
Tensor offsets,
int max_length,
float padding_value
) -> Tensor
"""
)

lib.define(
    """sll_jagged_dense_elementwise_mul_jagged_out(
Tensor x,
Tensor y,
Tensor x_seq_lengths,
Tensor x_offsets,
int max_seq_len
) -> Tensor
"""
)

# --- Softmax schemas ---
lib.define(
    """sll_jagged_softmax(Tensor x, Tensor x_offsets, int max_seq_len, bool use_fbgemm_kernel=True) -> Tensor
"""
)

lib.define(
    """sll_jagged2_softmax(Tensor x, Tensor offsets, Tensor offsets_total, int max_seq_len, bool transpose) -> Tensor
"""
)

# --- Batched matrix multiplication schemas with jagged output ---
# The two schemas below take identical argument lists.
lib.define(
    """sll_array_jagged_bmm_jagged_out(
Tensor x,
Tensor y,
Tensor x_lengths,
Tensor x_offsets,
Tensor y_lengths,
Tensor y_offsets,
Tensor z_lengths,
Tensor z_offsets,
int max_seq_len,
bool allow_tf32
) -> Tensor
"""
)

lib.define(
    """sll_jagged_jagged_bmm_jagged_out(
Tensor x,
Tensor y,
Tensor x_lengths,
Tensor x_offsets,
Tensor y_lengths,
Tensor y_offsets,
Tensor z_lengths,
Tensor z_offsets,
int max_seq_len,
bool allow_tf32
) -> Tensor
"""
)

# --- Flash attention and elementwise-add schemas ---
lib.define(
    """sll_jagged_flash_attention_basic(
Tensor q_weights,
Tensor k_weights,
Tensor v_weights,
Tensor offsets,
int max_seq_len,
bool use_mask=False,
bool allow_tf32=True
) -> Tensor
"""
)

lib.define(
    """sll_jagged_dense_elementwise_add(
Tensor x,
Tensor x_offsets,
Tensor y,
int max_seq_len,
bool use_fbgemm_kernel=True
) -> Tensor
"""
)

lib.define(
    """sll_jagged_dense_flash_attention(
Tensor q_weights,
Tensor k_weights,
Tensor v_weights,
Tensor attn_bias,
Tensor offsets,
int max_seq_len,
bool allow_tf32=True
) -> Tensor
"""
)

lib.define(
    """sll_multi_head_jagged_flash_attention(
Tensor q_weights,
Tensor k_weights,
Tensor v_weights,
Tensor offsets,
int max_seq_len,
bool allow_tf32=True
) -> Tensor
"""
)
173
+
174
# NOTE: the ops are registered for AutogradCUDA/CPU and for CUDA/CPU with the
# same function. This is not ideal: in the inference case there is no
# backward pass, so the autograd forward does not need to save the context.

# CPU registrations go first, then meta registrations.
for _registrations in (sll_cpu_registrations, sll_meta_registrations):
    for op_name, dispatches in _registrations.items():
        lib.register(op_name, dispatches)

# The Triton registrations are imported and registered only when CUDA is
# available.
if torch.cuda.is_available():
    from fbgemm_gpu.sll.triton import op_registrations as sll_gpu_registrations

    for op_name, dispatches in sll_gpu_registrations.items():
        lib.register(op_name, dispatches)
@@ -0,0 +1,80 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
9
+
10
+ from fbgemm_gpu.sll.cpu.cpu_sll import ( # noqa F401
11
+ cpu_array_jagged_bmm_jagged_out,
12
+ cpu_array_jagged_bmm_jagged_out_kernel, # noqa F401
13
+ cpu_dense_jagged_cat_jagged_out,
14
+ cpu_jagged2_softmax,
15
+ cpu_jagged2_to_padded_dense,
16
+ cpu_jagged_dense_bmm,
17
+ cpu_jagged_dense_elementwise_add,
18
+ cpu_jagged_dense_elementwise_mul_jagged_out,
19
+ cpu_jagged_dense_flash_attention,
20
+ cpu_jagged_flash_attention_basic,
21
+ cpu_jagged_jagged_bmm,
22
+ cpu_jagged_jagged_bmm_jagged_out,
23
+ cpu_jagged_jagged_bmm_jagged_out_kernel, # noqa F401
24
+ cpu_jagged_self_substraction_jagged_out,
25
+ cpu_jagged_softmax,
26
+ )
27
+
28
# Maps each SLL operator name to its {dispatch key: CPU implementation}
# table. Most ops use the same function for both "CPU" and "AutogradCPU";
# two ops are registered for "CPU" only.
# pyre-ignore[5]
op_registrations = {
    "sll_jagged_dense_bmm": dict.fromkeys(
        ("CPU", "AutogradCPU"), cpu_jagged_dense_bmm
    ),
    "sll_jagged_jagged_bmm": dict.fromkeys(
        ("CPU", "AutogradCPU"), cpu_jagged_jagged_bmm
    ),
    # CPU-only: no AutogradCPU entry.
    "sll_dense_jagged_cat_jagged_out": {
        "CPU": cpu_dense_jagged_cat_jagged_out,
    },
    # CPU-only: no AutogradCPU entry.
    "sll_jagged_self_substraction_jagged_out": {
        "CPU": cpu_jagged_self_substraction_jagged_out,
    },
    "sll_jagged2_to_padded_dense": dict.fromkeys(
        ("CPU", "AutogradCPU"), cpu_jagged2_to_padded_dense
    ),
    "sll_jagged_dense_elementwise_mul_jagged_out": dict.fromkeys(
        ("CPU", "AutogradCPU"), cpu_jagged_dense_elementwise_mul_jagged_out
    ),
    "sll_jagged_softmax": dict.fromkeys(
        ("CPU", "AutogradCPU"), cpu_jagged_softmax
    ),
    "sll_jagged2_softmax": dict.fromkeys(
        ("CPU", "AutogradCPU"), cpu_jagged2_softmax
    ),
    "sll_array_jagged_bmm_jagged_out": dict.fromkeys(
        ("CPU", "AutogradCPU"), cpu_array_jagged_bmm_jagged_out
    ),
    "sll_jagged_jagged_bmm_jagged_out": dict.fromkeys(
        ("CPU", "AutogradCPU"), cpu_jagged_jagged_bmm_jagged_out
    ),
    "sll_jagged_flash_attention_basic": dict.fromkeys(
        ("CPU", "AutogradCPU"), cpu_jagged_flash_attention_basic
    ),
    "sll_jagged_dense_elementwise_add": dict.fromkeys(
        ("CPU", "AutogradCPU"), cpu_jagged_dense_elementwise_add
    ),
    "sll_jagged_dense_flash_attention": dict.fromkeys(
        ("CPU", "AutogradCPU"), cpu_jagged_dense_flash_attention
    ),
}