fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic. Click here for more details.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
import abc
|
|
11
|
+
import logging
|
|
12
|
+
from collections import deque
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from types import TracebackType
|
|
15
|
+
from typing import Callable, Optional, TypeVar
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TBEStatsReporter(abc.ABC):
    """
    Abstract interface for reporting TBE runtime statistics. A concrete
    reporter decides how the reported values are aggregated (for example on
    an intended group key) and where they are sent.

    Every report_XXX method is expected to be lightweight and fail-safe.
    """

    @abc.abstractmethod
    def should_report(self, iteration_step: int) -> bool:
        """
        Tell the caller whether metrics should be reported for
        `iteration_step`.

        Implementations are expected to be cheap, free of side effects and to
        return immediately.
        """

    @abc.abstractmethod
    def register_stats(self, stats_name: str, amplifier: int = 1) -> None:
        """
        Add `stats_name` to the reporter's whitelist.

        NOTE(review): the meaning of `amplifier` is not defined by this
        interface; confirm against a concrete implementation.
        """

    @abc.abstractmethod
    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        embedding_id: str = "",
        tbe_id: str = "",
        time_unit: str = "ms",
        enable_tb_metrics: bool = False,
    ) -> None:
        """
        Report how long a timed event took.

        Args:
            iteration_step: step the measurement belongs to.
            event_name: name of the timed event.
            duration_ms: measured duration.
            embedding_id: optional identifier of the embedding table.
            tbe_id: optional identifier of the TBE instance.
            time_unit: unit label for the duration (defaults to "ms").
            enable_tb_metrics: implementation-defined flag (presumably
                toggles TensorBoard metrics -- TODO confirm).
        """

    @abc.abstractmethod
    def report_data_amount(
        self,
        iteration_step: int,
        event_name: str,
        data_bytes: int,
        embedding_id: str = "",
        tbe_id: str = "",
        enable_tb_metrics: bool = False,
    ) -> None:
        """
        Report the size, in bytes, of some amount of data.

        Args:
            iteration_step: step the measurement belongs to.
            event_name: name of the event the data belongs to.
            data_bytes: number of bytes being reported.
            embedding_id: optional identifier of the embedding table.
            tbe_id: optional identifier of the TBE instance.
            enable_tb_metrics: implementation-defined flag (presumably
                toggles TensorBoard metrics -- TODO confirm).
        """
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class StdLogStatsReporter(TBEStatsReporter):
    """
    TBEStatsReporter that writes every reported stat to the root logger at
    INFO level. Reporting is due on every step that is a multiple of
    `report_interval`.
    """

    def __init__(self, report_interval: int) -> None:
        """
        Args:
            report_interval: report once every this many iteration steps;
                must be positive.
        """
        assert report_interval > 0, "Report interval must be positive"
        self.report_interval = report_interval

    def register_stats(self, stats_name: str, amplifier: int = 1) -> None:
        # No whitelist is kept: whatever is passed to report_* gets logged.
        return

    def should_report(self, iteration_step: int) -> bool:
        """
        Return True when `iteration_step` is a multiple of the interval.
        """
        return iteration_step % self.report_interval == 0

    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        embedding_id: str = "",
        tbe_id: str = "",
        time_unit: str = "ms",
        enable_tb_metrics: bool = False,
    ) -> None:
        """
        Log the duration of a timed event at INFO level.
        """
        # Pass the values as logging args so the message is only rendered
        # when a handler actually accepts INFO records; the rendered text is
        # the same as the eager f-string would produce.
        logging.info(
            "[Batch #%s][TBE:%s][Table:%s] The event %s took %s %s with %s",
            iteration_step,
            tbe_id,
            embedding_id,
            event_name,
            duration_ms,
            time_unit,
            enable_tb_metrics,
        )

    def report_data_amount(
        self,
        iteration_step: int,
        event_name: str,
        data_bytes: int,
        embedding_id: str = "",
        tbe_id: str = "",
        enable_tb_metrics: bool = False,
    ) -> None:
        """
        Log the number of bytes used by an event at INFO level.
        """
        # Lazy formatting, for the same reason as in report_duration.
        logging.info(
            "[Batch #%s][TBE:%s][Table:%s] The event %s used %s bytes with %s",
            iteration_step,
            tbe_id,
            embedding_id,
            event_name,
            data_bytes,
            enable_tb_metrics,
        )

    def __repr__(self) -> str:
        return f"StdLogStatsReporter{{ report_interval={self.report_interval} }}"
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@dataclass(frozen=True)
class TBEStatsReporterConfig:
    """
    Configuration from which a TBEStatsReporter is created. The reporter is
    only instantiated on demand through create_reporter(), so the config can
    be deep-copied without the actual reporter being copied along with it.
    """

    # Report once every `interval` batches. A non-positive value means no
    # collection and no reporting.
    interval: int = -1

    def create_reporter(self) -> Optional[TBEStatsReporter]:
        """
        Return None. This base config has no reporter implementation, so a
        positive interval is rejected by the assertion.
        """
        reporting_disabled = self.interval <= 0
        assert (
            reporting_disabled
        ), "Cannot specify interval without an actual implementation of reporter"
        return None
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@dataclass(frozen=True)
class StdLogStatsReporterConfig(TBEStatsReporterConfig):
    """
    Config that produces a StdLogStatsReporter when the interval is positive.
    """

    def create_reporter(self) -> Optional[TBEStatsReporter]:
        """
        Return a StdLogStatsReporter reporting every `interval` steps, or
        None when `interval` is non-positive.
        """
        return (
            None
            if self.interval <= 0
            else StdLogStatsReporter(report_interval=self.interval)
        )
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
T = TypeVar("T")


class AsyncSeriesTimerRecordedContext:
    """
    Context-manager front end for AsyncSeriesTimer: entering the block starts
    a recording and leaving it stops the recording. Example:
    ```
    timer : AsyncSeriesTimer = ...
    with timer.recording(ctx):
        cuda_kernel1()
        cuda_kernel2()
        cuda_kernel3()
    ```
    """

    def __init__(
        self,
        timer: "AsyncSeriesTimer",
        context: T,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """
        Args:
            timer: the timer whose start()/stop() are driven by this object.
            context: value handed to timer.stop() when the block exits.
            stream: optional stream passed to both start() and stop().
        """
        self._timer, self._context, self._stream = timer, context, stream

    def __enter__(self) -> None:
        timer = self._timer
        timer.start(self._stream)

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # The recording is stopped whether or not the block raised; returning
        # None lets any exception propagate.
        timer = self._timer
        timer.stop(self._context, self._stream)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class AsyncSeriesTimer:
    """
    Thin wrapper around torch.cuda.Event that times a series of start()/stop()
    pairs on the GPU. Only one recording can be active at a time: start()
    asserts that no recording is ongoing.

    Results are reported asynchronously because the timing is not available
    right when stop() is called. Instead, all events that have not been
    reported yet are checked lazily on every start() and stop() call.
    """

    def __init__(self, report_functor: Callable[[T, float], None]) -> None:
        """
        Args:
            report_functor: called as report_functor(context, elapsed) for
                each finished recording, where `elapsed` is the value of
                start_event.elapsed_time(stop_event) converted to float.
        """
        self._report_functor = report_functor
        self._active_start_event: Optional[torch.cuda.Event] = None
        # Finished-but-unreported recordings: (start, stop, context).
        self._events_queue: deque[tuple[torch.cuda.Event, torch.cuda.Event, T]] = (
            deque()
        )

    def start(self, stream: Optional[torch.cuda.Stream] = None) -> None:
        """
        Begin a recording by recording a start event on `stream`.
        """
        assert self._active_start_event is None, "There's an active recording"
        begin_event = torch.cuda.Event(enable_timing=True)
        self._active_start_event = begin_event
        begin_event.record(stream)
        self._lazy_report()

    def stop(self, context: T, stream: Optional[torch.cuda.Stream] = None) -> None:
        """
        End the active recording, queue it with `context` for later reporting.
        """
        assert self._active_start_event is not None, "There's no active recording"
        begin_event: torch.cuda.Event = self._active_start_event

        end_event = torch.cuda.Event(enable_timing=True)
        end_event.record(stream)
        self._events_queue.append((begin_event, end_event, context))
        self._active_start_event = None
        self._lazy_report()

    def recording(
        self, context: T, stream: Optional[torch.cuda.Stream] = None
    ) -> AsyncSeriesTimerRecordedContext:
        """
        Return a context manager that wraps a start()/stop() pair.
        """
        return AsyncSeriesTimerRecordedContext(self, context, stream)

    def _lazy_report(self) -> None:
        """
        Report every queued recording whose stop event has completed.
        """
        # Recordings form a series, so the earliest one queued also finishes
        # earliest; inspecting the leftmost stop event is enough to know
        # whether anything can be reported now.
        pending = self._events_queue
        while pending:
            _, oldest_stop, _ = pending[0]
            if not oldest_stop.query():
                # The earliest recording has not completed on the GPU yet, so
                # nothing behind it can be reported either.
                return
            start_event, stop_event, context = pending.popleft()
            assert (
                start_event.query()
            ), "Recording has start event later than stop event"
            elapsed = float(start_event.elapsed_time(stop_event))
            self._report_functor(context, elapsed)
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from fbgemm_gpu.sll.cpu import op_registrations as sll_cpu_registrations
|
|
13
|
+
from fbgemm_gpu.sll.meta import op_registrations as sll_meta_registrations
|
|
14
|
+
from fbgemm_gpu.utils import TorchLibraryFragment
|
|
15
|
+
|
|
16
|
+
# Handle used to declare and register the SLL operators under the "fbgemm"
# namespace. NOTE(review): TorchLibraryFragment comes from fbgemm_gpu.utils and
# presumably wraps a torch.library fragment -- confirm there.
lib = TorchLibraryFragment("fbgemm")

# Operator schemas. Each lib.define() call below only declares a signature;
# the implementations are attached by the lib.register() loops later in this
# module.

# Jagged x dense and jagged x jagged batched matmul.
lib.define(
    """sll_jagged_dense_bmm(
        Tensor x,
        Tensor y,
        Tensor x_offsets,
        int N,
        bool allow_tf32,
        bool use_fbgemm_kernel=True
    ) -> Tensor
    """
)

lib.define(
    """sll_jagged_jagged_bmm(
        Tensor x,
        Tensor y,
        Tensor x_offsets,
        int N,
        bool allow_tf32,
        bool use_fbgemm_kernel=True
    ) -> Tensor
    """
)

# The only schema in this module that returns two tensors.
lib.define(
    """sll_dense_jagged_cat_jagged_out(
        Tensor a,
        Tensor b,
        Tensor a_offsets,
        int max_seq_len
    ) -> (Tensor, Tensor)
    """
)

lib.define(
    """sll_jagged_self_substraction_jagged_out(
        Tensor a,
        Tensor offsets_a,
        Tensor offsets_b,
        int max_seq_len
    ) -> Tensor
    """
)

lib.define(
    """sll_jagged2_to_padded_dense(
        Tensor values,
        Tensor offsets,
        int max_length,
        float padding_value
    ) -> Tensor
    """
)

lib.define(
    """sll_jagged_dense_elementwise_mul_jagged_out(
        Tensor x,
        Tensor y,
        Tensor x_seq_lengths,
        Tensor x_offsets,
        int max_seq_len
    ) -> Tensor
    """
)

# Softmax variants.
lib.define(
    """sll_jagged_softmax(Tensor x, Tensor x_offsets, int max_seq_len, bool use_fbgemm_kernel=True) -> Tensor
    """
)

lib.define(
    """sll_jagged2_softmax(Tensor x, Tensor offsets, Tensor offsets_total, int max_seq_len, bool transpose) -> Tensor
    """
)

# Batched matmuls with jagged output; both take lengths and offsets for x, y
# and the output z.
lib.define(
    """sll_array_jagged_bmm_jagged_out(
        Tensor x,
        Tensor y,
        Tensor x_lengths,
        Tensor x_offsets,
        Tensor y_lengths,
        Tensor y_offsets,
        Tensor z_lengths,
        Tensor z_offsets,
        int max_seq_len,
        bool allow_tf32
    ) -> Tensor
    """
)

lib.define(
    """sll_jagged_jagged_bmm_jagged_out(
        Tensor x,
        Tensor y,
        Tensor x_lengths,
        Tensor x_offsets,
        Tensor y_lengths,
        Tensor y_offsets,
        Tensor z_lengths,
        Tensor z_offsets,
        int max_seq_len,
        bool allow_tf32
    ) -> Tensor
    """
)

lib.define(
    """sll_jagged_flash_attention_basic(
        Tensor q_weights,
        Tensor k_weights,
        Tensor v_weights,
        Tensor offsets,
        int max_seq_len,
        bool use_mask=False,
        bool allow_tf32=True
    ) -> Tensor
    """
)

lib.define(
    """sll_jagged_dense_elementwise_add(
        Tensor x,
        Tensor x_offsets,
        Tensor y,
        int max_seq_len,
        bool use_fbgemm_kernel=True
    ) -> Tensor
    """
)

lib.define(
    """sll_jagged_dense_flash_attention(
        Tensor q_weights,
        Tensor k_weights,
        Tensor v_weights,
        Tensor attn_bias,
        Tensor offsets,
        int max_seq_len,
        bool allow_tf32=True
    ) -> Tensor
    """
)

lib.define(
    """sll_multi_head_jagged_flash_attention(
        Tensor q_weights,
        Tensor k_weights,
        Tensor v_weights,
        Tensor offsets,
        int max_seq_len,
        bool allow_tf32=True
    ) -> Tensor
    """
)
|
|
173
|
+
|
|
174
|
+
# NOTE: each op is registered for the backend key (CPU/CUDA) and for its
# autograd key (AutogradCPU/AutogradCUDA) with the same function. This is not
# ideal: in the inference case there is no backward pass, so the autograd
# forward does not need to save the context.

# CPU implementations are registered unconditionally.
for op_name, dispatches in sll_cpu_registrations.items():
    lib.register(op_name, dispatches)

# The registrations from fbgemm_gpu.sll.meta are likewise unconditional.
for op_name, dispatches in sll_meta_registrations.items():
    lib.register(op_name, dispatches)

# The Triton-backed table is imported and registered only when CUDA is
# available, so fbgemm_gpu.sll.triton is never imported otherwise.
if torch.cuda.is_available():
    from fbgemm_gpu.sll.triton import op_registrations as sll_gpu_registrations

    for op_name, dispatches in sll_gpu_registrations.items():
        lib.register(op_name, dispatches)
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the BSD-style license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
10
|
+
from fbgemm_gpu.sll.cpu.cpu_sll import ( # noqa F401
|
|
11
|
+
cpu_array_jagged_bmm_jagged_out,
|
|
12
|
+
cpu_array_jagged_bmm_jagged_out_kernel, # noqa F401
|
|
13
|
+
cpu_dense_jagged_cat_jagged_out,
|
|
14
|
+
cpu_jagged2_softmax,
|
|
15
|
+
cpu_jagged2_to_padded_dense,
|
|
16
|
+
cpu_jagged_dense_bmm,
|
|
17
|
+
cpu_jagged_dense_elementwise_add,
|
|
18
|
+
cpu_jagged_dense_elementwise_mul_jagged_out,
|
|
19
|
+
cpu_jagged_dense_flash_attention,
|
|
20
|
+
cpu_jagged_flash_attention_basic,
|
|
21
|
+
cpu_jagged_jagged_bmm,
|
|
22
|
+
cpu_jagged_jagged_bmm_jagged_out,
|
|
23
|
+
cpu_jagged_jagged_bmm_jagged_out_kernel, # noqa F401
|
|
24
|
+
cpu_jagged_self_substraction_jagged_out,
|
|
25
|
+
cpu_jagged_softmax,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# CPU dispatch table: maps each SLL operator name to a dict of
# {dispatch key: implementation}. fbgemm_gpu/sll/__init__.py iterates this
# table and passes each entry to lib.register().
#
# Most ops use the same function for both "CPU" and "AutogradCPU". The two
# entries with only a "CPU" key (sll_dense_jagged_cat_jagged_out and
# sll_jagged_self_substraction_jagged_out) have no AutogradCPU registration
# here.
# NOTE(review): there is no entry for sll_multi_head_jagged_flash_attention in
# this table -- confirm that the op is intentionally not provided on CPU.
# pyre-ignore[5]
op_registrations = {
    "sll_jagged_dense_bmm": {
        "CPU": cpu_jagged_dense_bmm,
        "AutogradCPU": cpu_jagged_dense_bmm,
    },
    "sll_jagged_jagged_bmm": {
        "CPU": cpu_jagged_jagged_bmm,
        "AutogradCPU": cpu_jagged_jagged_bmm,
    },
    "sll_dense_jagged_cat_jagged_out": {
        "CPU": cpu_dense_jagged_cat_jagged_out,
    },
    "sll_jagged_self_substraction_jagged_out": {
        "CPU": cpu_jagged_self_substraction_jagged_out,
    },
    "sll_jagged2_to_padded_dense": {
        "CPU": cpu_jagged2_to_padded_dense,
        "AutogradCPU": cpu_jagged2_to_padded_dense,
    },
    "sll_jagged_dense_elementwise_mul_jagged_out": {
        "CPU": cpu_jagged_dense_elementwise_mul_jagged_out,
        "AutogradCPU": cpu_jagged_dense_elementwise_mul_jagged_out,
    },
    "sll_jagged_softmax": {
        "CPU": cpu_jagged_softmax,
        "AutogradCPU": cpu_jagged_softmax,
    },
    "sll_jagged2_softmax": {
        "CPU": cpu_jagged2_softmax,
        "AutogradCPU": cpu_jagged2_softmax,
    },
    "sll_array_jagged_bmm_jagged_out": {
        "CPU": cpu_array_jagged_bmm_jagged_out,
        "AutogradCPU": cpu_array_jagged_bmm_jagged_out,
    },
    "sll_jagged_jagged_bmm_jagged_out": {
        "CPU": cpu_jagged_jagged_bmm_jagged_out,
        "AutogradCPU": cpu_jagged_jagged_bmm_jagged_out,
    },
    "sll_jagged_flash_attention_basic": {
        "CPU": cpu_jagged_flash_attention_basic,
        "AutogradCPU": cpu_jagged_flash_attention_basic,
    },
    "sll_jagged_dense_elementwise_add": {
        "CPU": cpu_jagged_dense_elementwise_add,
        "AutogradCPU": cpu_jagged_dense_elementwise_add,
    },
    "sll_jagged_dense_flash_attention": {
        "CPU": cpu_jagged_dense_flash_attention,
        "AutogradCPU": cpu_jagged_dense_flash_attention,
    },
}
|