mslk-cuda-nightly 2026.1.19 (cp310-cp310-manylinux_2_28_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/bench/conv/conv_ops.py

@@ -0,0 +1,213 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Keep a registry of all convolution operators.
+import abc
+
+import mslk.conv  # noqa: F401
+import torch
+from mslk.bench.common.utils import BenchOptions, do_bench
+from mslk.quantize.triton.fp8_quantize import quantize_fp8_tensor
+
+
+conv_op_registry = []
+
+
+class ConvOpBase(metaclass=abc.ABCMeta):
+    """Helper abstract class to define expected methods of conv ops."""
+
+    @abc.abstractmethod
+    def quantize(self, *args):
+        """Function which quantizes inputs."""
+        pass
+
+    @abc.abstractmethod
+    def compute(self, *args, **kwargs):
+        """Function which performs main compute operation."""
+        pass
+
+    @abc.abstractmethod
+    def quantize_and_compute(self, *args, **kwargs):
+        """Function which quantizes inputs and performs main compute operation."""
+        pass
+
+    def preprocess(self, *args):
+        """Preprocess inputs before benchmarking. These outputs will be passed to quantize."""
+        return args
+
+    def benchmark(
+        self,
+        *args,
+        opts: BenchOptions,
+        bench_quantize: bool,
+    ) -> float:
+        """Benchmark runtime of this operator."""
+        return do_bench(
+            lambda *a: self.quantize_and_compute(*a)
+            if bench_quantize
+            else self.compute(*a),
+            args,
+            opts,
+        )
+
+    @abc.abstractproperty
+    def name(self) -> str:
+        """Name of the operator."""
+        pass
+
+    @abc.abstractproperty
+    def hip(self) -> bool:
+        """Whether this operator supports AMD or not."""
+        pass
+
+    @abc.abstractproperty
+    def cuda(self) -> bool:
+        """Whether this operator supports Nvidia or not."""
+        pass
+
+    @property
+    def supported(self) -> bool:
+        """Whether this op will run on the current device."""
+        if torch.version.hip is not None:
+            return self.hip
+        elif torch.version.cuda is not None:
+            return self.cuda
+        else:
+            return False
+
+
+def register_conv_op(op):
+    """Decorator function for assembling all conv ops."""
+    conv_op_registry.append(op())
+    return op
+
+
+def get_conv_ops() -> list[ConvOpBase]:
+    """Get all registered conv ops."""
+    return conv_op_registry
+
+
+@register_conv_op
+class TorchBaseline(ConvOpBase):
+    """
+    PyTorch baseline convolution.
+    """
+
+    def __init__(self):
+        self.torch_compile = False
+
+    def quantize(self, activation, filter, padding, stride, dilation):
+        return (
+            activation.to(torch.bfloat16),
+            filter.to(torch.bfloat16),
+            padding,
+            stride,
+            dilation,
+        )
+
+    def compute(self, activation, filter, padding, stride, dilation):
+        if self.torch_compile:
+            f = torch.compile(
+                torch.nn.functional.conv3d,
+                options={
+                    "max_autotune": True,
+                    "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
+                },
+            )
+        else:
+            f = torch.nn.functional.conv3d
+
+        return f(
+            activation,
+            filter,
+            bias=None,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+        )
+
+    def quantize_and_compute(self, activation, filter, padding, stride, dilation):
+        return self.compute(
+            *self.quantize(activation, filter, padding, stride, dilation)
+        )
+
+    @property
+    def name(self) -> str:
+        return "torch_baseline"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
+@register_conv_op
+class F8F8BF16Conv(ConvOpBase):
+    """
+    FP8 convolution with rowwise scaling.
+    """
+
+    def preprocess(self, activation, filter, padding, stride, dilation):
+        # Inputs and filters are provided in channels first layout.
+        # Cutlass kernels support this but require the underlying memory
+        # to be channels last. Torch enables this through the memory format
+        # transformation which we assume has been applied ahead of time.
+        activation = activation.to(memory_format=torch.channels_last_3d)
+        filter = filter.to(memory_format=torch.channels_last_3d)
+        return activation, filter, padding, stride, dilation
+
+    def _quantize_tensor(self, x):
+        """Quantize tensor to FP8 with rowwise scaling."""
+        xq, x_scale = quantize_fp8_tensor(x)
+        return xq, x_scale
+
+    def quantize(self, activation, filter, padding, stride, dilation):
+        # Quantize both input tensors
+        activation_q, activation_scale = self._quantize_tensor(activation)
+        filter_q, filter_scale = self._quantize_tensor(filter)
+
+        # Compute combined scale for output
+        # For conv, we need a single scale value
+        scale = torch.tensor(
+            [activation_scale * filter_scale],
+            device=activation.device,
+            dtype=torch.float32,
+        )
+
+        return activation_q, filter_q, scale, padding, stride, dilation
+
+    def compute(self, activation_q, filter_q, scale, padding, stride, dilation):
+        output = torch.ops.mslk.f8f8bf16_conv(
+            activation_q,
+            filter_q,
+            scale,
+            padding,
+            stride,
+            dilation,
+        )
+        return output
+
+    def quantize_and_compute(self, activation, filter, padding, stride, dilation):
+        activation_q, filter_q, scale, padding, stride, dilation = self.quantize(
+            activation, filter, padding, stride, dilation
+        )
+        return self.compute(activation_q, filter_q, scale, padding, stride, dilation)
+
+    @property
+    def name(self) -> str:
+        return "f8f8bf16_conv"
+
+    @property
+    def hip(self) -> bool:
+        # Currently only supported on CUDA
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
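
For readers skimming the diff: the registry above is the extension point that the conv benchmark driver (mslk/bench/conv/conv_bench.py, also added in this wheel) iterates over. The sketch below shows one plausible way to exercise it; the input shapes, the default BenchOptions() construction, and the CUDA device choice are illustrative assumptions, not taken from this diff.

# Minimal usage sketch. Assumptions: BenchOptions is default-constructible,
# and the hypothetical NCDHW shapes below are reasonable conv3d inputs.
import torch

from mslk.bench.common.utils import BenchOptions
from mslk.bench.conv.conv_ops import get_conv_ops

# Hypothetical 3D-conv problem: batch 4, 64 -> 128 channels, 3x3x3 filter.
activation = torch.randn(4, 64, 8, 32, 32, device="cuda", dtype=torch.bfloat16)
filter = torch.randn(128, 64, 3, 3, 3, device="cuda", dtype=torch.bfloat16)
padding, stride, dilation = (1, 1, 1), (1, 1, 1), (1, 1, 1)

opts = BenchOptions()  # assumed to accept defaults

for op in get_conv_ops():
    # Skip ops that do not run on the current backend (HIP vs. CUDA).
    if not op.supported:
        continue
    # preprocess() output feeds quantize()/compute(), mirroring ConvOpBase.benchmark().
    inputs = op.preprocess(activation, filter, padding, stride, dilation)
    runtime = op.benchmark(*inputs, opts=opts, bench_quantize=True)
    print(f"{op.name}: {runtime}")

Because register_conv_op instantiates each class at import time and appends it to conv_op_registry, simply importing mslk.bench.conv.conv_ops makes both TorchBaseline and F8F8BF16Conv visible to get_conv_ops().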