fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/experimental/gen_ai/moe/README.md
@@ -0,0 +1,15 @@
+ # FBGEMM GenAI MoE Support
+
+ MetaShuffling MoE kernel support in the FBGEMM GenAI kernel library.
+
+ # **Overview**
+
+ Mixture-of-Experts (MoE) is a popular model architecture for large language models (LLMs). Although it reduces computation in training and inference by activating fewer parameters per token, it poses additional challenges for achieving optimal compute efficiency: high memory and communication pressure, plus the complexity of handling the model's dynamic and sparse nature. Here we introduce a new MoE inference solution, MetaShuffling, which enables us to deploy Llama 4 models efficiently for real-world inference.
+
+ [Technical design blog](https://pytorch.org/blog/metashuffling-accelerating-llama-4-moe-inference/).
+
+ # **Updates**
+
+ - 2025-05-01: Initial release of MetaShuffling MoE PyTorch examples.
+
+ - 2025-04-17: Initial release of MetaShuffling MoE GPU kernels.
fbgemm_gpu/experimental/gen_ai/moe/__init__.py
@@ -0,0 +1,66 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ import os
+
+ import torch
+
+ try:
+     # pyre-ignore[21]
+     # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
+     from fbgemm_gpu import open_source
+ except Exception:
+     open_source: bool = False
+
+ # pyre-ignore[16]
+ if open_source:
+     torch.ops.load_library(
+         os.path.join(
+             os.path.dirname(os.path.dirname(__file__)),
+             "fbgemm_gpu_experimental_gen_ai.so",
+         )
+     )
+     torch.classes.load_library(
+         os.path.join(
+             os.path.dirname(os.path.dirname(__file__)),
+             "fbgemm_gpu_experimental_gen_ai.so",
+         )
+     )
+ else:
+     torch.ops.load_library(
+         "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:index_shuffling_ops"
+     )
+     torch.ops.load_library(
+         "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:gather_scatter_ops"
+     )
+
+ index_shuffling = None
+ gather_along_first_dim = None
+ scatter_add_along_first_dim = None
+
+ if torch.cuda.is_available():
+     index_shuffling = torch.ops.fbgemm.index_shuffling  # noqa F401
+     if not torch.version.hip:
+         # SM90 support
+         gather_along_first_dim = torch.ops.fbgemm.gather_along_first_dim  # noqa F401
+         scatter_add_along_first_dim = torch.ops.fbgemm.scatter_add_along_first_dim  # noqa F401
+
+ from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (  # noqa F401
+     grouped_gemm,
+     grouped_gemm_fp8_rowwise,
+ )
+
+ from .activation import silu_mul, silu_mul_quant  # noqa F401
+
+ from .gather_scatter import (  # noqa F401
+     gather_scale_dense_tokens,
+     gather_scale_quant_dense_tokens,
+     scatter_add_dense_tokens,
+     scatter_add_padded_tokens,
+ )
+ from .shuffling import combine_shuffling, split_shuffling  # noqa F401
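
The `__init__.py` above loads the native ops from the bundled `fbgemm_gpu_experimental_gen_ai.so` (open-source wheels) or from Buck targets (internal builds), and re-exports the Triton-based helpers. A minimal sketch of how a downstream script might check which entry points its platform exposes, assuming this nightly wheel is installed; it only uses names visible in the file above, and the availability comments simply restate the branches in that file.

```python
# Hedged sketch (not part of the wheel): probe which MoE entry points this build exposes.
import torch
import fbgemm_gpu.experimental.gen_ai.moe as moe  # importing loads fbgemm_gpu_experimental_gen_ai.so

# Triton-backed helpers are plain Python functions re-exported by __init__.py.
assert callable(moe.silu_mul) and callable(moe.combine_shuffling)

# The native ops are bound conditionally: index_shuffling only when
# torch.cuda.is_available(), and the gather/scatter ops additionally only on
# non-HIP (CUDA) builds, so guard for None before calling them.
if moe.index_shuffling is not None:
    print("index_shuffling op available")
if moe.gather_along_first_dim is not None:
    print("gather/scatter_add_along_first_dim available (CUDA, non-ROCm build)")
```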
fbgemm_gpu/experimental/gen_ai/moe/activation.py
@@ -0,0 +1,292 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import get_fp8_constants
+
+
+ # Function APIs
+ def silu_mul(
+     x0: torch.Tensor,
+     x1: torch.Tensor,
+     valid_token_count: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     """
+     Fused silu and mul operations.
+
+     y = x0 * sigmoid(x0) * x1
+
+     Args:
+         x0: input tensor of shape (T, D)
+         x1: input tensor of shape (T, D)
+         valid_token_count: tensor of shape (1,) to indicate the number of valid tokens.
+
+     Returns:
+         output tensor of shape (T, D)
+     """
+
+     assert x0.ndim == 2 and x0.stride(1) == 1
+     assert x1.ndim == 2 and x1.stride(1) == 1
+     assert x0.shape == x1.shape
+     assert x0.dtype == x1.dtype
+
+     T, D = x0.shape
+     stride_0 = x0.stride(0)
+     stride_1 = x1.stride(0)
+
+     out = torch.empty((T, D), device="cuda", dtype=x0.dtype)
+
+     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+     if T >= NUM_SMS:
+         BLOCK_D_OUTER = D
+         BLOCK_D_INNER = 1024
+         assert D % BLOCK_D_INNER == 0
+     else:
+         BLOCK_D_OUTER = 512
+         BLOCK_D_INNER = 256
+         assert D % BLOCK_D_OUTER == 0
+     grid = (T, D // BLOCK_D_OUTER)
+     _fbgemm_silu_mul[grid](
+         out,
+         x0,
+         x1,
+         stride_0,
+         stride_1,
+         valid_token_count,
+         D,  # pyre-ignore
+         BLOCK_D_OUTER,  # pyre-ignore
+         BLOCK_D_INNER,  # pyre-ignore
+     )
+     return out
+
+
+ def silu_mul_quant(
+     x0: torch.Tensor,
+     x1: torch.Tensor,
+     scale_ub: Optional[torch.Tensor] = None,
+     valid_token_count: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Fused silu, mul, and FP8 rowwise quantization operations.
+
+     y, y_scale = quantize(x0 * sigmoid(x0) * x1)
+
+     Args:
+         x0: input tensor of shape (T, D)
+         x1: input tensor of shape (T, D)
+         scale_ub: tensor of shape (1,) to indicate the upper bound of the scale.
+         valid_token_count: tensor of shape (1,) to indicate the number of valid tokens.
+
+     Returns:
+         output quantized tensor of shape (T, D) and its inverse scale of shape (T,)
+     """
+
+     assert x0.ndim == 2 and x0.stride(1) == 1
+     assert x1.ndim == 2 and x1.stride(1) == 1
+     assert x0.shape == x1.shape
+     assert x0.dtype == x1.dtype
+
+     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+
+     T, D = x0.shape
+     stride_0 = x0.stride(0)
+     stride_1 = x1.stride(0)
+
+     out = torch.empty((T, D), device="cuda", dtype=pt_dtype)
+     out_inv_scale = torch.empty((T,), device="cuda", dtype=torch.float32)
+     if T == 0:
+         return out, out_inv_scale
+
+     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+     BLOCK_T = triton.cdiv(T, NUM_SMS)
+
+     NUM_CTAS = triton.cdiv(T, BLOCK_T)
+
+     grid = (NUM_CTAS,)
+     _fbgemm_silu_mul_quant[grid](
+         out,
+         out_inv_scale,
+         x0,
+         x1,
+         scale_ub,
+         stride_0,
+         stride_1,
+         valid_token_count,
+         T,
+         D,  # pyre-ignore
+         BLOCK_T,
+         TL_FP8_DTYPE=tl_dtype,  # pyre-ignore
+         MAX_FP8=max_fp8,  # pyre-ignore
+         EPS=eps,  # pyre-ignore
+         CLAMP_MAX=scale_ub is not None,  # pyre-ignore
+     )
+     return out, out_inv_scale
+
+
+ # Torch Custom Op Registrations
+ _SILU_MUL_OP_NAME = "fbgemm::silu_mul"
+
+ torch.library.define(
+     "fbgemm::silu_mul",
+     "(Tensor x0, Tensor x1, Tensor? valid_token_count=None) -> Tensor",
+ )
+
+
+ @torch.library.impl(_SILU_MUL_OP_NAME, "Meta")
+ def silu_mul_meta(x0, x1, valid_token_count):
+     return x0.new_empty(x0.shape)
+
+
+ @torch.library.impl(_SILU_MUL_OP_NAME, "CUDA")
+ def silu_mul_cuda(x0, x1, valid_token_count):
+     return silu_mul(x0, x1, valid_token_count)
+
+
+ _SILU_MUL_OP_QUANT_NAME = "fbgemm::silu_mul_quant"
+
+ torch.library.define(
+     "fbgemm::silu_mul_quant",
+     "(Tensor x0, Tensor x1, Tensor? scale_ub=None, Tensor? valid_token_count=None) -> (Tensor, Tensor)",
+ )
+
+
+ @torch.library.impl(_SILU_MUL_OP_QUANT_NAME, "Meta")
+ def silu_mul_quant_meta(x0, x1, scale_ub, valid_token_count):
+     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+     return torch.empty(x0.shape, device=x0.device, dtype=pt_dtype)
+
+
+ @torch.library.impl(_SILU_MUL_OP_QUANT_NAME, "CUDA")
+ def silu_mul_quant_cuda(x0, x1, scale_ub=None, valid_token_count=None):
+     return silu_mul_quant(x0, x1, scale_ub, valid_token_count)
+
+
+ # Kernel Implementations
+ @triton.jit
+ def _fbgemm_silu_mul(
+     y_ptr,
+     x0_ptr,
+     x1_ptr,
+     stride_0,
+     stride_1,
+     valid_token_count,
+     D: tl.constexpr,
+     BLOCK_D_OUTER: tl.constexpr,
+     BLOCK_D_INNER: tl.constexpr,
+ ) -> None:
+     token_index = tl.program_id(0)
+     feature_offset = tl.program_id(1) * BLOCK_D_OUTER + tl.arange(0, BLOCK_D_INNER)[:]
+
+     if valid_token_count is not None:
+         valid_token_count = tl.load(
+             valid_token_count, None, eviction_policy="evict_last"
+         )
+         if token_index >= valid_token_count:
+             return
+
+     for _ in tl.range(0, BLOCK_D_OUTER // BLOCK_D_INNER, num_stages=3):
+         x0 = tl.load(
+             x0_ptr + token_index * stride_0 + feature_offset,
+             None,
+             eviction_policy="evict_first",
+         ).to(tl.float32)
+         x1 = tl.load(
+             x1_ptr + token_index * stride_1 + feature_offset,
+             None,
+             eviction_policy="evict_first",
+         ).to(tl.float32)
+
+         y = x0 * tl.sigmoid(x0) * x1
+
+         tl.store(
+             y_ptr + token_index * D + feature_offset,
+             y,
+             None,
+         )
+         feature_offset += BLOCK_D_INNER
+
+
+ @triton.jit
+ def _fbgemm_silu_mul_quant(
+     y_ptr,
+     y_inv_scale_ptr,
+     x0_ptr,
+     x1_ptr,
+     scale_ub_ptr,
+     stride_0,
+     stride_1,
+     valid_token_count,
+     T,
+     D: tl.constexpr,
+     BLOCK_T: tl.constexpr,
+     TL_FP8_DTYPE: tl.constexpr,
+     MAX_FP8: tl.constexpr,
+     EPS: tl.constexpr,
+     CLAMP_MAX: tl.constexpr,
+ ) -> None:
+     PADDED_D: tl.constexpr = triton.next_power_of_2(D)  # pyre-ignore
+
+     tidx = tl.program_id(0)
+     start_idx = tidx * BLOCK_T
+     end_idx = tl.minimum(start_idx + BLOCK_T, T)
+
+     if valid_token_count is not None:
+         valid_token_count = tl.load(
+             valid_token_count, None, eviction_policy="evict_last"
+         )
+         if start_idx >= valid_token_count:
+             return
+
+     offsets = tl.arange(0, PADDED_D)[:]
+     mask = offsets < D
+
+     if CLAMP_MAX:
+         ub = tl.load(scale_ub_ptr, eviction_policy="evict_last")
+     else:
+         ub = float("inf")
+
+     for token_index in tl.range(start_idx, end_idx, 1, num_stages=2):
+         x0 = tl.load(
+             x0_ptr + token_index * stride_0 + offsets,
+             mask,
+             eviction_policy="evict_first",
+         ).to(tl.float32)
+         x1 = tl.load(
+             x1_ptr + token_index * stride_1 + offsets,
+             mask,
+             eviction_policy="evict_first",
+         ).to(tl.float32)
+
+         y = x0 * tl.sigmoid(x0) * x1
+
+         # Masked values are set to 0.0.
+         row_max = tl.max(tl.where(mask, tl.abs(y), 0.0))
+         if CLAMP_MAX:
+             row_max = tl.clamp(row_max, EPS, ub)
+         else:
+             row_max = tl.maximum(row_max, EPS)
+
+         y_scale = MAX_FP8 / row_max
+         tl.store(y_inv_scale_ptr + token_index, 1.0 / y_scale)
+
+         y = y * y_scale
+         # Clamp A to fp8 range to make sure there's no overflow.
+         # This is required for AMD. Nvidia's default saturation
+         # handles it, but it's nice to have anyway.
+         y_fp8 = tl.clamp(y, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+
+         tl.store(
+             y_ptr + token_index * D + offsets,
+             y_fp8,
+             mask,
+         )
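
`activation.py` exposes the fused activation in two flavors: `silu_mul` computes `y = x0 * sigmoid(x0) * x1`, and `silu_mul_quant` additionally applies FP8 rowwise quantization, returning the quantized tensor together with a per-row inverse scale. A hedged usage sketch, not taken from the package: it checks the fused kernels against an eager PyTorch reference and dequantizes the FP8 output with the returned inverse scale. The shape (`D` a multiple of 512) is chosen to satisfy the asserts in `silu_mul` above, and the tolerances are illustrative assumptions.

```python
# Hedged usage sketch: exercise silu_mul / silu_mul_quant on a CUDA device,
# assuming fbgemm-gpu-genai-nightly is installed.
import torch
from fbgemm_gpu.experimental.gen_ai.moe import silu_mul, silu_mul_quant

T, D = 4, 1024  # D must divide evenly by the block sizes asserted in silu_mul
x0 = torch.randn(T, D, device="cuda", dtype=torch.bfloat16)
x1 = torch.randn(T, D, device="cuda", dtype=torch.bfloat16)

# Fused kernel vs. eager reference: y = x0 * sigmoid(x0) * x1.
y = silu_mul(x0, x1)
y_ref = torch.nn.functional.silu(x0.float()) * x1.float()
torch.testing.assert_close(y.float(), y_ref, atol=2e-2, rtol=2e-2)  # illustrative tolerance

# FP8 rowwise-quantized variant: dequantize with the returned inverse scale (shape (T,)).
y_fp8, inv_scale = silu_mul_quant(x0, x1)
y_deq = y_fp8.float() * inv_scale[:, None]
torch.testing.assert_close(y_deq, y_ref, atol=1e-1, rtol=1e-1)  # illustrative tolerance
```

Because both functions are also registered through `torch.library.define` / `torch.library.impl`, the same kernels are reachable as custom ops (e.g. `torch.ops.fbgemm.silu_mul(x0, x1)`), with Meta implementations provided for shape propagation.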