fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl
This diff shows the content of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in its public registry.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py

@@ -0,0 +1,175 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def jagged_dense_elementwise_mul_jagged_out_kernel(
+    a_ptr,  # 1d jagged
+    b_ptr,  # dense
+    c_ptr,  # 1d jagged
+    a_seq_lengths_ptr,
+    a_offsets_ptr,
+    stride_a,
+    stride_bm,
+    stride_bn,
+    max_seq_len,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_batch = tl.program_id(0)
+    pid_row_block = tl.program_id(1)
+
+    batch_offset = tl.load(a_offsets_ptr + pid_batch)
+    batch_seq_len = tl.load(a_seq_lengths_ptr + pid_batch)
+    truncated_seq_len = tl.minimum(batch_seq_len, max_seq_len)
+
+    offs_row = tl.arange(0, BLOCK_M)
+    offs_col = tl.arange(0, BLOCK_N)
+
+    rows = pid_row_block * BLOCK_M + offs_row
+
+    # a start + batch offset + row offsets + initial col offsets
+    a_ptrs = (
+        a_ptr
+        + batch_offset * stride_a
+        + rows[:, None] * truncated_seq_len
+        + offs_col[None, :]
+    )
+
+    # b start + row offsets + initial col offsets
+    b_ptrs = b_ptr + rows[:, None] * stride_bm + offs_col[None, :] * stride_bn
+
+    # c start + batch offset + row offsets + initial col offsets
+    c_ptrs = (
+        c_ptr + batch_offset + rows[:, None] * truncated_seq_len + offs_col[None, :]
+    )
+
+    for block_start in range(0, truncated_seq_len, BLOCK_N):
+        cols = block_start + offs_col
+        # pyre-fixme[16]: `int` has no attribute `__getitem__`.
+        mask = (rows[:, None] < truncated_seq_len) & (cols[None, :] < truncated_seq_len)
+        a = tl.load(a_ptrs, mask=mask)
+        a_ptrs += BLOCK_N
+
+        b = tl.load(b_ptrs, mask=mask)
+        b_ptrs += BLOCK_N
+
+        c = a * b
+        tl.store(c_ptrs, c, mask=mask)
+        c_ptrs += BLOCK_N
+
+
+def triton_jagged_dense_elementwise_mul_jagged_out(
+    jagged_A,
+    dense_B,
+    seq_lengths_a,
+    offsets_a,
+    max_seq_len,
+):
+    B = seq_lengths_a.size(0)
+    total_L = jagged_A.size(0)
+
+    jagged_C = torch.zeros((total_L), device=jagged_A.device, dtype=jagged_A.dtype)
+
+    BLOCK_M = 32
+    BLOCK_N = 32
+    num_blocks_m = triton.cdiv(max_seq_len, BLOCK_M)
+    grid = (B, num_blocks_m)
+
+    jagged_dense_elementwise_mul_jagged_out_kernel[grid](
+        jagged_A,
+        dense_B,
+        jagged_C,
+        seq_lengths_a,
+        offsets_a,
+        jagged_A.stride(0),
+        dense_B.stride(0),
+        dense_B.stride(1),
+        max_seq_len,
+        BLOCK_M,
+        BLOCK_N,
+    )
+
+    return jagged_C
+
+
+class JaggedDenseElementwiseMul(torch.autograd.Function):
+    """
+    Compute elementwise multiplication between jagged tensor and dense tensor.
+    z = x * y
+    x: [sum_B(L_i)]
+    y: dense tensor
+    z: [sum_B(L_i)]
+    """
+
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        x_seq_lengths: torch.Tensor,
+        x_offsets: torch.Tensor,
+        max_seq_len: int,
+    ):
+        ctx.max_seq_len = max_seq_len
+
+        ctx.save_for_backward(
+            x,
+            y,
+            x_seq_lengths,
+            x_offsets,
+        )
+
+        return triton_jagged_dense_elementwise_mul_jagged_out(
+            x,
+            y,
+            x_seq_lengths,
+            x_offsets,
+            max_seq_len,
+        )
+
+    @staticmethod
+    # pyre-fixme
+    def backward(ctx, grad_output: torch.Tensor):
+        (
+            x,
+            y,
+            x_seq_lengths,
+            x_offsets,
+        ) = ctx.saved_tensors
+
+        grad_x = triton_jagged_dense_elementwise_mul_jagged_out(
+            grad_output,
+            y,
+            x_seq_lengths,
+            x_offsets,
+            ctx.max_seq_len,
+        )
+
+        return grad_x, None, None, None, None
+
+
+def jagged_dense_elementwise_mul_jagged_out(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    x_seq_lengths: torch.Tensor,
+    x_offsets: torch.Tensor,
+    max_seq_len: int,
+) -> torch.Tensor:
+    return JaggedDenseElementwiseMul.apply(
+        x,
+        y,
+        x_seq_lengths,
+        x_offsets,
+        max_seq_len,
+    )
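For context on how this op is called: the kernel addresses each batch's jagged segment as a row-major block of shape truncated_seq_len x truncated_seq_len (see the rows[:, None] * truncated_seq_len term), so x_offsets must mark where each batch's square block starts in the flat buffer. The sketch below is an illustration with hypothetical shapes and values inferred from that addressing, not an example shipped with the package; it assumes a CUDA device for the Triton launch and imports from the module shown above.

import torch
from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_mul_jagged_out import (
    jagged_dense_elementwise_mul_jagged_out,
)

device = "cuda"
max_seq_len = 3
lengths = torch.tensor([2, 3], device=device)   # per-batch sequence lengths L_i
sizes = lengths * lengths                       # each batch stores an L_i x L_i block
offsets = torch.cat(                            # block starts in the flat buffer: [0, 4, 13]
    [torch.zeros(1, dtype=sizes.dtype, device=device), sizes.cumsum(0)]
)
x = torch.randn(int(sizes.sum()), device=device)  # flat jagged buffer (13 elements)
y = torch.tril(torch.ones(max_seq_len, max_seq_len, device=device))  # e.g. a causal mask

z = jagged_dense_elementwise_mul_jagged_out(x, y, lengths, offsets, max_seq_len)
# z shares x's flat layout: batch b's block equals its x block * y[:L_b, :L_b].

Two design points worth noting. The launch grid is (B, cdiv(max_seq_len, BLOCK_M)) with BLOCK_M = BLOCK_N = 32, so each program instance owns one 32-row strip of one batch's block and sweeps its columns in 32-wide steps; rows and columns past truncated_seq_len are masked, and since the output is pre-zeroed by torch.zeros, untouched tail elements stay zero. And because backward returns a gradient only for the first argument, gradients flow to the jagged operand x while y is treated as a constant (such as a fixed mask).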