mslk_cuda_nightly-2026.1.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/bench/conv/conv_ops.py
@@ -0,0 +1,213 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Keep a registry of all convolution operators.
+import abc
+
+import mslk.conv  # noqa: F401
+import torch
+from mslk.bench.common.utils import BenchOptions, do_bench
+from mslk.quantize.triton.fp8_quantize import quantize_fp8_tensor
+
+
+conv_op_registry = []
+
+
+class ConvOpBase(metaclass=abc.ABCMeta):
+    """Helper abstract class to define expected methods of conv ops."""
+
+    @abc.abstractmethod
+    def quantize(self, *args):
+        """Function which quantizes inputs."""
+        pass
+
+    @abc.abstractmethod
+    def compute(self, *args, **kwargs):
+        """Function which performs main compute operation."""
+        pass
+
+    @abc.abstractmethod
+    def quantize_and_compute(self, *args, **kwargs):
+        """Function which quantizes inputs and performs main compute operation."""
+        pass
+
+    def preprocess(self, *args):
+        """Preprocess inputs before benchmarking. These outputs will be passed to quantize."""
+        return args
+
+    def benchmark(
+        self,
+        *args,
+        opts: BenchOptions,
+        bench_quantize: bool,
+    ) -> float:
+        """Benchmark runtime of this operator."""
+        return do_bench(
+            lambda *a: self.quantize_and_compute(*a)
+            if bench_quantize
+            else self.compute(*a),
+            args,
+            opts,
+        )
+
+    @abc.abstractproperty
+    def name(self) -> str:
+        """Name of the operator."""
+        pass
+
+    @abc.abstractproperty
+    def hip(self) -> bool:
+        """Whether this operator supports AMD or not."""
+        pass
+
+    @abc.abstractproperty
+    def cuda(self) -> bool:
+        """Whether this operator supports Nvidia or not."""
+        pass
+
+    @property
+    def supported(self) -> bool:
+        """Whether this op will run on the current device."""
+        if torch.version.hip is not None:
+            return self.hip
+        elif torch.version.cuda is not None:
+            return self.cuda
+        else:
+            return False
+
+
+def register_conv_op(op):
+    """Decorator function for assembling all conv ops."""
+    conv_op_registry.append(op())
+    return op
+
+
+def get_conv_ops() -> list[ConvOpBase]:
+    """Get all registered conv ops."""
+    return conv_op_registry
+
+
+@register_conv_op
+class TorchBaseline(ConvOpBase):
+    """
+    PyTorch baseline convolution.
+    """
+
+    def __init__(self):
+        self.torch_compile = False
+
+    def quantize(self, activation, filter, padding, stride, dilation):
+        return (
+            activation.to(torch.bfloat16),
+            filter.to(torch.bfloat16),
+            padding,
+            stride,
+            dilation,
+        )
+
+    def compute(self, activation, filter, padding, stride, dilation):
+        if self.torch_compile:
+            f = torch.compile(
+                torch.nn.functional.conv3d,
+                options={
+                    "max_autotune": True,
+                    "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
+                },
+            )
+        else:
+            f = torch.nn.functional.conv3d
+
+        return f(
+            activation,
+            filter,
+            bias=None,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+        )
+
+    def quantize_and_compute(self, activation, filter, padding, stride, dilation):
+        return self.compute(
+            *self.quantize(activation, filter, padding, stride, dilation)
+        )
+
+    @property
+    def name(self) -> str:
+        return "torch_baseline"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
+@register_conv_op
+class F8F8BF16Conv(ConvOpBase):
+    """
+    FP8 convolution with rowwise scaling.
+    """
+
+    def preprocess(self, activation, filter, padding, stride, dilation):
+        # Inputs and filters are provided in channels first layout.
+        # Cutlass kernels support this but require the underlying memory
+        # to be channels last. Torch enables this through the memory format
+        # transformation which we assume has been applied ahead of time.
+        activation = activation.to(memory_format=torch.channels_last_3d)
+        filter = filter.to(memory_format=torch.channels_last_3d)
+        return activation, filter, padding, stride, dilation
+
+    def _quantize_tensor(self, x):
+        """Quantize tensor to FP8 with rowwise scaling."""
+        xq, x_scale = quantize_fp8_tensor(x)
+        return xq, x_scale
+
+    def quantize(self, activation, filter, padding, stride, dilation):
+        # Quantize both input tensors
+        activation_q, activation_scale = self._quantize_tensor(activation)
+        filter_q, filter_scale = self._quantize_tensor(filter)
+
+        # Compute combined scale for output
+        # For conv, we need a single scale value
+        scale = torch.tensor(
+            [activation_scale * filter_scale],
+            device=activation.device,
+            dtype=torch.float32,
+        )
+
+        return activation_q, filter_q, scale, padding, stride, dilation
+
+    def compute(self, activation_q, filter_q, scale, padding, stride, dilation):
+        output = torch.ops.mslk.f8f8bf16_conv(
+            activation_q,
+            filter_q,
+            scale,
+            padding,
+            stride,
+            dilation,
+        )
+        return output
+
+    def quantize_and_compute(self, activation, filter, padding, stride, dilation):
+        activation_q, filter_q, scale, padding, stride, dilation = self.quantize(
+            activation, filter, padding, stride, dilation
+        )
+        return self.compute(activation_q, filter_q, scale, padding, stride, dilation)
+
+    @property
+    def name(self) -> str:
+        return "f8f8bf16_conv"
+
+    @property
+    def hip(self) -> bool:
+        # Currently only supported on CUDA
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
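The registry and ConvOpBase contract above are what a benchmark driver (such as the conv_bench.py module listed in this wheel) would typically consume: enumerate get_conv_ops(), filter on supported, then time each op through benchmark(). The sketch below illustrates that flow; the input shapes, the default-constructed BenchOptions, and the CUDA device are assumptions for illustration and are not taken from this diff.

# Hypothetical usage sketch -- not part of the wheel. Assumes a CUDA device,
# synthetic conv3d shapes, and a default-constructible BenchOptions; only
# get_conv_ops() and the ConvOpBase methods come from the hunk above.
import torch

from mslk.bench.common.utils import BenchOptions
from mslk.bench.conv.conv_ops import get_conv_ops

# Synthetic channels-first conv3d problem: (N, C, D, H, W) activation,
# (K, C, T, R, S) filter, padding of 1, unit stride and dilation.
activation = torch.randn(4, 64, 8, 32, 32, device="cuda", dtype=torch.bfloat16)
filter = torch.randn(128, 64, 3, 3, 3, device="cuda", dtype=torch.bfloat16)
padding, stride, dilation = (1, 1, 1), (1, 1, 1), (1, 1, 1)

opts = BenchOptions()  # assumed default-constructible; set fields per the real dataclass

for op in get_conv_ops():
    if not op.supported:
        # Skip ops whose backend (HIP vs. CUDA) does not match this device.
        continue
    # The ConvOpBase contract: preprocess -> quantize -> compute. With
    # bench_quantize=False, benchmark() times compute() on pre-quantized args.
    args = op.preprocess(activation, filter, padding, stride, dilation)
    quantized = op.quantize(*args)
    runtime = op.benchmark(*quantized, opts=opts, bench_quantize=False)
    print(f"{op.name}: {runtime:.3f} (units as reported by do_bench)")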
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict