mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/moe/shuffling.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
|
|
9
|
+
from typing import Optional, Union
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import triton
|
|
13
|
+
import triton.language as tl
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Function APIs
|
|
17
|
+
def combine_shuffling(
    tokens: torch.Tensor,
    token_counts: torch.Tensor,
    expert_start: Optional[int] = None,
    expert_end: Optional[int] = None,
    is_padded: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reorder tokens from rank-major to expert-major layout.

    Forwards to ``_combine_or_split_shuffling`` in combine mode.

    Args:
        tokens: Contiguous ``[T, D]`` token matrix in rank-major order.
        token_counts: Contiguous ``[EP, E]`` per-(rank, expert) row counts.
        expert_start: First expert (inclusive) to gather; defaults to 0.
        expert_end: Last expert (exclusive) to gather; defaults to ``E``.
        is_padded: Whether each rank occupies a fixed ``T // EP`` row block.

    Returns:
        ``(output_tokens, output_token_counts)``.
    """
    # pyre-ignore
    return _combine_or_split_shuffling(
        tokens,
        token_counts,
        expert_start,
        expert_end,
        is_padded,
        is_combine=True,
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def split_shuffling(
    tokens: torch.Tensor,
    token_counts: torch.Tensor,
    expert_start: Optional[int] = None,
    expert_end: Optional[int] = None,
    is_padded: bool = False,
    init_with_zeros: bool = False,
) -> torch.Tensor:
    """Reorder tokens from expert-major back to rank-major layout.

    This is the reverse direction of ``combine_shuffling``. It forwards to
    ``_combine_or_split_shuffling`` in split mode.

    Args:
        tokens: Contiguous ``[T, D]`` token matrix in expert-major order.
        token_counts: Contiguous ``[EP, E]`` per-(rank, expert) row counts.
        expert_start: First expert (inclusive) to scatter; defaults to 0.
        expert_end: Last expert (exclusive) to scatter; defaults to ``E``.
        is_padded: Whether each rank occupies a fixed ``T // EP`` row block.
        init_with_zeros: Zero-fill the output before the kernel writes it.

    Returns:
        The reordered tokens, shaped like ``tokens``.
    """
    # pyre-ignore
    return _combine_or_split_shuffling(
        tokens,
        token_counts,
        expert_start,
        expert_end,
        is_padded,
        is_combine=False,
        init_with_zeros=init_with_zeros,
    )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _combine_or_split_shuffling(
    tokens: torch.Tensor,
    token_counts: torch.Tensor,
    expert_start: Optional[int],
    expert_end: Optional[int],
    is_padded: bool,
    is_combine: bool,
    init_with_zeros: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Launch the shuffling kernel in combine or split mode.

    Args:
        tokens: Contiguous 2-D tensor of shape ``[T, D]``.
        token_counts: Contiguous 2-D tensor of shape ``[EP, E]``. Entry
            ``[r, e]`` is the number of rows the kernel copies for rank ``r``
            and expert ``e``.
        expert_start: First expert (inclusive) to process; defaults to 0.
        expert_end: Last expert (exclusive) to process; defaults to ``E``.
        is_padded: If True, each rank's rows start at a fixed stride of
            ``T // EP`` in the rank-major layout instead of being packed.
        is_combine: True to copy rank-major -> expert-major and also produce
            output counts; False for the reverse direction.
        init_with_zeros: Zero-fill the output tokens first, so rows the kernel
            does not write (e.g. padding) are zero instead of uninitialized.

    Returns:
        ``(output_tokens, output_token_counts)`` when ``is_combine``.
        ``output_token_counts`` has ``EG + 1`` entries: per-local-expert
        totals followed by their sum. Otherwise returns ``output_tokens``.

    Raises:
        ValueError: If ``[expert_start, expert_end)`` is empty or not within
            ``[0, E]``.
        AssertionError: If an input is not contiguous, or if ``is_padded`` is
            set and ``T`` is not divisible by ``EP``.
    """
    # T is intentionally not passed to the kernel, to avoid recompilation.
    # NOTE: in padded mode B_T = T // EP is a kernel constexpr, so distinct
    # T values still produce distinct kernel specializations there.
    assert tokens.is_contiguous()
    assert token_counts.is_contiguous()

    T, D = tokens.shape
    EP, E = token_counts.shape
    B_T = -1
    if is_padded:
        assert T % EP == 0
        B_T = T // EP

    if expert_start is None:
        expert_start = 0
    if expert_end is None:
        expert_end = E

    # Reject invalid ranges up front. An empty window (EG == 0) would divide
    # by zero below. A window reaching past E would make the kernel read past
    # the end of the token_counts rows.
    if not 0 <= expert_start < expert_end <= E:
        raise ValueError(
            f"invalid expert range [{expert_start}, {expert_end}) for "
            f"token_counts with {E} experts"
        )

    EG: int = expert_end - expert_start

    # Split D across several programs per (rank, expert) pair, scaled by the
    # SM count of the device that actually holds the tokens.
    NUM_SMS = torch.cuda.get_device_properties(tokens.device).multi_processor_count
    SPLIT_D = max(NUM_SMS // (EP * EG), 1)
    SPLIT_D = triton.next_power_of_2(SPLIT_D + 1)
    if T <= 1024:
        SPLIT_D //= 2

    # Combine mode launches one extra program that only writes output counts.
    grid = (EP * EG * SPLIT_D + (1 if is_combine else 0),)

    output_tokens = (
        torch.zeros_like(tokens) if init_with_zeros else torch.empty_like(tokens)
    )
    if is_combine:
        output_token_counts = torch.empty(
            EG + 1, dtype=token_counts.dtype, device=token_counts.device
        )
    else:
        output_token_counts = None

    BLOCK_E = max(triton.next_power_of_2(E), 8)
    BLOCK_EG = max(triton.next_power_of_2(EG), 8)
    BLOCK_EP = max(triton.next_power_of_2(EP), 8)

    _mslk_combine_or_split_shuffling[grid](
        tokens,
        token_counts,
        output_tokens,
        output_token_counts,
        is_combine,
        expert_start,
        is_padded,
        B_T,
        EG,
        EP,
        E,
        D,
        BLOCK_E,
        BLOCK_EG,
        BLOCK_EP,
        SPLIT_D,
    )

    if is_combine:
        assert output_token_counts is not None
        return output_tokens, output_token_counts
    else:
        return output_tokens
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# Torch Custom Op Registrations
|
|
134
|
+
_COMBINE_SHUFFLING_OP_NAME = "mslk::combine_shuffling"
|
|
135
|
+
|
|
136
|
+
torch.library.define(
|
|
137
|
+
_COMBINE_SHUFFLING_OP_NAME,
|
|
138
|
+
"(Tensor tokens, Tensor token_counts, int? expert_start = None, int? expert_end = None, bool? is_padded = False) -> (Tensor, Tensor)",
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@torch.library.impl(_COMBINE_SHUFFLING_OP_NAME, "Meta")
def combine_shuffling_meta(
    tokens,
    token_counts,
    expert_start=None,
    expert_end=None,
    is_padded=False,
):
    """Meta (shape-only) implementation of ``mslk::combine_shuffling``.

    The optional parameters default to the values declared in the op schema,
    matching ``combine_shuffling_cuda``. The dispatcher may call Python kernels
    without trailing arguments that equal their schema defaults, so this
    kernel must work with only the leading arguments.

    Args:
        tokens: Input tokens; the output tokens have the same shape and dtype.
        token_counts: 2-D tensor whose second dimension is ``E``.
        expert_start: First expert (inclusive); defaults to 0.
        expert_end: Last expert (exclusive); defaults to ``E``.
        is_padded: Does not affect output shapes.

    Returns:
        ``(output_tokens, output_token_counts)``, where ``output_token_counts``
        has ``EG + 1`` elements and ``EG = expert_end - expert_start``.
    """
    _, E = token_counts.shape
    if expert_start is None:
        expert_start = 0
    if expert_end is None:
        expert_end = E

    EG: int = expert_end - expert_start
    output_tokens = torch.empty_like(tokens)
    output_token_counts = torch.empty(
        EG + 1, dtype=token_counts.dtype, device=token_counts.device
    )
    return output_tokens, output_token_counts
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@torch.library.impl(_COMBINE_SHUFFLING_OP_NAME, "CUDA")
def combine_shuffling_cuda(
    tokens,
    token_counts,
    expert_start=None,
    expert_end=None,
    is_padded=False,
):
    """CUDA implementation of ``mslk::combine_shuffling``.

    Delegates directly to ``combine_shuffling``.
    """
    return combine_shuffling(
        tokens=tokens,
        token_counts=token_counts,
        expert_start=expert_start,
        expert_end=expert_end,
        is_padded=is_padded,
    )
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
_SPLIT_SHUFFLING_OP_NAME = "mslk::split_shuffling"
|
|
182
|
+
|
|
183
|
+
torch.library.define(
|
|
184
|
+
_SPLIT_SHUFFLING_OP_NAME,
|
|
185
|
+
"(Tensor tokens, Tensor token_counts, int? expert_start = None, int? expert_end = None, bool? is_padded = False, bool? init_with_zeros = False) -> Tensor",
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@torch.library.impl(_SPLIT_SHUFFLING_OP_NAME, "Meta")
def split_shuffling_meta(
    tokens,
    token_counts,
    expert_start=None,
    expert_end=None,
    is_padded=False,
    init_with_zeros=False,
):
    """Meta (shape-only) implementation of ``mslk::split_shuffling``.

    Accepts every argument of the registered schema, including
    ``init_with_zeros``. Each optional argument defaults to its schema
    default, so the kernel works whether or not the dispatcher passes
    defaulted trailing arguments.

    Returns:
        An uninitialized tensor shaped like ``tokens``. The other arguments
        do not affect the output shape.
    """
    output_tokens = torch.empty_like(tokens)
    return output_tokens
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
@torch.library.impl(_SPLIT_SHUFFLING_OP_NAME, "CUDA")
def split_shuffling_cuda(
    tokens,
    token_counts,
    expert_start=None,
    expert_end=None,
    is_padded=False,
    init_with_zeros=False,
):
    """CUDA implementation of ``mslk::split_shuffling``.

    Accepts every argument of the registered schema and forwards them all to
    ``split_shuffling``. This includes ``init_with_zeros``, so passing
    ``init_with_zeros=True`` through ``torch.ops`` takes effect instead of
    raising ``TypeError``.
    """
    return split_shuffling(
        tokens,
        token_counts,
        expert_start,
        expert_end,
        is_padded,
        init_with_zeros,
    )
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# Kernel Implementations
|
|
219
|
+
# Autotuning search space for NVIDIA GPUs: every combination of token-tile
# size, hidden-dim tile size, pipeline stages and warp count, with one CTA.
_NV_CONFIGS = [
    triton.Config(
        {"BLOCK_T": bt, "BLOCK_D": bd},
        num_stages=stages,
        num_warps=warps,
        num_ctas=ctas,
    )
    for bt in (32, 64)
    for bd in (256, 512, 1024)
    for stages in (1, 3)
    for warps in (8, 16)
    for ctas in (1,)
]
|
|
235
|
+
|
|
236
|
+
# Autotuning search space for AMD (ROCm) GPUs. Warp count and the
# AMD-specific ``waves_per_eu`` setting are tuned as fixed pairs.
_AMD_CONFIGS = [
    triton.Config(
        {"BLOCK_T": bt, "BLOCK_D": bd, "waves_per_eu": waves},
        num_stages=stages,
        num_warps=warps,
    )
    for bt in (32, 64)
    for bd in (256, 512, 1024)
    for stages in (1, 3)
    for warps, waves in ((8, 2), (16, 4))
]
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=[
        "COMBINE",
        "EG",
        "EP",
        "E",
        "D",
    ],
)
@triton.jit
def _mslk_combine_or_split_shuffling(
    input_tokens_ptr,
    input_token_counts_ptr,
    output_tokens_ptr,
    output_token_counts_ptr,
    COMBINE: tl.constexpr,
    EG_START,
    PADDED,
    B_T: tl.constexpr,
    EG: tl.constexpr,
    EP: tl.constexpr,
    E: tl.constexpr,
    D: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EG: tl.constexpr,
    BLOCK_EP: tl.constexpr,
    SPLIT_D: tl.constexpr,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
) -> None:
    """
    Copy token rows between a rank-major ("EP-first") layout and an
    expert-major ("expert-first") layout, one (rank, local expert) group
    per program.

    COMBINE=True reads the rank-major layout and writes the expert-major
    layout; COMBINE=False does the reverse. When PADDED, the rank-major
    layout gives every rank a fixed block of B_T rows instead of packing
    ranks back to back. Only experts in [EG_START, EG_START + EG) are copied,
    and the expert-major layout contains only those experts.

    tokens: [T, D]
    input_token_counts: [EP, E]
    output_tokens: [T, D]
    output_token_counts: [EG + 1] (COMBINE only: per-local-expert totals
        summed over ranks, followed by their grand total)
    """
    tidx = tl.program_id(0)

    # Number of BLOCK_D-wide column tiles each D-split walks through.
    NUM_D_BLOCKS: tl.constexpr = (D + SPLIT_D * BLOCK_D - 1) // (SPLIT_D * BLOCK_D)

    # Program id -> (rank, local expert, D-split index); didx varies fastest.
    rank = tidx // (EG * SPLIT_D)
    local_expert = (tidx % (EG * SPLIT_D)) // SPLIT_D
    didx = tidx % SPLIT_D
    # All experts in communication group
    offs_e = tl.arange(0, BLOCK_E)
    # Local experts
    offs_eg = tl.arange(0, BLOCK_EG)
    # Ranks
    offs_ep = tl.arange(0, BLOCK_EP)

    global_expert = local_expert + EG_START

    # Full count matrix; out-of-range lanes of the padded block read as 0.
    input_token_counts = tl.load(
        input_token_counts_ptr + offs_ep[:, None] * E + offs_e[None, :],
        eviction_policy="evict_last",
        mask=((offs_ep[:, None] < EP) & (offs_e[None, :] < E)),
        other=0,
    )  # [EP, E]

    # Counts restricted to the local expert window [EG_START, EG_START + EG).
    if E == EG:
        input_token_counts_eg = input_token_counts
    else:
        input_token_counts_eg = tl.load(
            input_token_counts_ptr + offs_ep[:, None] * E + EG_START + offs_eg[None, :],
            eviction_policy="evict_last",
            mask=((offs_ep[:, None] < EP) & (offs_eg[None, :] < EG)),
            other=0,
        )  # [EP, EG]

    if COMBINE:
        LAST_TILE: tl.constexpr = EP * EG * SPLIT_D

        # The extra program with id LAST_TILE only writes the output counts:
        # per-local-expert sums over ranks at [0, EG), grand total at EG.
        if tidx == LAST_TILE:
            output_token_counts_eg = tl.sum(input_token_counts_eg, axis=0)
            tl.store(
                output_token_counts_ptr + offs_eg,
                output_token_counts_eg,
                mask=(offs_eg < EG),
            )
            output_token_counts_eg = tl.sum(output_token_counts_eg)
            tl.store(output_token_counts_ptr + EG, output_token_counts_eg)
            return

    cond0 = offs_ep[:, None] < rank
    cond1 = offs_ep[:, None] == rank

    cond2 = offs_e[None, :] < global_expert

    # First row of this (rank, expert) group in the rank-major layout.
    if PADDED:
        tl.device_assert(B_T >= 0)
        # Only need information from previous experts in the same rank;
        # each rank's rows start at a fixed stride of B_T.
        ep_first_order = (
            tl.sum(tl.where(cond1 and cond2, input_token_counts, 0)) + B_T * rank
        )
    else:
        # r < rank || (r == rank && e < expert)
        ep_first_order = tl.sum(
            tl.where(cond0 or (cond1 and cond2), input_token_counts, 0)
        )

    cond4 = offs_eg[None, :] < local_expert
    cond5 = offs_eg[None, :] == local_expert

    # First row of this group in the expert-major layout (local experts only).
    # Expert first only need information from local experts across ranks.
    # e < expert || (e == expert && r < rank)
    expert_first_order = tl.sum(
        tl.where(cond4 or (cond5 and cond0), input_token_counts_eg, 0)
    )

    # Combine: rank-major -> expert-major; split: expert-major -> rank-major.
    if COMBINE:
        input_offset = ep_first_order
        output_offset = expert_first_order
    else:
        input_offset = expert_first_order
        output_offset = ep_first_order

    # Promote to int64 so the row * D pointer offsets below are 64-bit.
    input_offset = input_offset.to(tl.int64)
    output_offset = output_offset.to(tl.int64)

    # Rows in this (rank, expert) group; nothing to copy if empty.
    num_copy_tokens = tl.load(input_token_counts_ptr + rank * E + global_expert)
    if num_copy_tokens == 0:
        return

    STEP_D: tl.constexpr = SPLIT_D * BLOCK_D
    # Column masking is only needed when D is not a multiple of STEP_D.
    MASK_D: tl.constexpr = D % STEP_D != 0

    num_t_blocks = tl.cdiv(num_copy_tokens, BLOCK_T)

    # Row indices relative to the group's first row, plus absolute rows.
    t_1d_ptr = tl.arange(0, BLOCK_T)[:, None]
    ti_1d_ptr = input_offset + t_1d_ptr
    to_1d_ptr = output_offset + t_1d_ptr

    # This D-split starts at column didx * NUM_D_BLOCKS * BLOCK_D and owns
    # NUM_D_BLOCKS consecutive BLOCK_D-wide tiles.
    d_1d_ptr = didx * NUM_D_BLOCKS * BLOCK_D + tl.arange(0, BLOCK_D)[None, :]

    # Addresses assume rows of D contiguous elements.
    i_2d_ptr = input_tokens_ptr + ti_1d_ptr * D + d_1d_ptr
    o_2d_ptr = output_tokens_ptr + to_1d_ptr * D + d_1d_ptr

    # Walk D tiles (inner) within each BLOCK_T-row tile (outer).
    for i in range(num_t_blocks * NUM_D_BLOCKS):
        mask = t_1d_ptr < num_copy_tokens
        if MASK_D:
            mask &= d_1d_ptr < D

        block = tl.load(
            i_2d_ptr,
            mask=mask,
        )
        tl.store(
            o_2d_ptr,
            value=block,
            mask=mask,
        )

        if i % NUM_D_BLOCKS == (NUM_D_BLOCKS - 1):  # pyre-ignore
            # just to make sure constant folding happens
            D_1D_SHIFT: tl.constexpr = -(NUM_D_BLOCKS - 1) * BLOCK_D
            TD_2D_SHIFT: tl.constexpr = BLOCK_T * D + D_1D_SHIFT
            # Last D tile done: advance BLOCK_T rows and rewind to the first
            # D tile of this split.
            t_1d_ptr += BLOCK_T
            i_2d_ptr += TD_2D_SHIFT
            o_2d_ptr += TD_2D_SHIFT
            # d_1d_ptr is only used for masking, so it is only tracked then.
            if MASK_D:
                d_1d_ptr += D_1D_SHIFT
        else:
            # increment D
            i_2d_ptr += BLOCK_D
            o_2d_ptr += BLOCK_D
            if MASK_D:
                d_1d_ptr += BLOCK_D
|
mslk/mslk.so
ADDED
|
Binary file
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-strict
|
|
8
|
+
|
|
9
|
+
from mslk.utils.torch.library import load_library_buck
|
|
10
|
+
|
|
11
|
+
# Buck target of the compiled quantize operator library; it is loaded when
# this module is imported (presumably registering its torch ops — confirm
# against load_library_buck).
_QUANTIZE_OPS_TARGET = "//mslk/csrc/quantize:quantize_ops"
load_library_buck(_QUANTIZE_OPS_TARGET)
|