mslk-cuda-nightly 2026.1.19 (cp310-cp310-manylinux_2_28_x86_64.whl)

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/bench/quantize/quantize_bench.py ADDED
@@ -0,0 +1,345 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import itertools
+ import os
+ import sys
+ from dataclasses import dataclass
+ from datetime import datetime
+ from typing import Callable, Optional
+
+ import click
+ import pandas as pd
+ import torch
+ import triton # @manual=//triton:triton
+ from mslk.bench.common.utils import BenchOptions, profiler
+ from mslk.bench.quantize.quantize_ops import get_ops, QuantizeOpBase
+ from tabulate import tabulate
+
+ type ShapeFunction = Callable[[], list[tuple[int, int]]]
+
+ shape_registry: dict[str, ShapeFunction] = {}
+
+
+ def register_shapes(name: str) -> Callable[[ShapeFunction], ShapeFunction]:
+     def decorator(
+         shape_function: ShapeFunction,
+     ) -> ShapeFunction:
+         shape_registry[name] = shape_function
+         return shape_function
+
+     return decorator
+
+
+ @register_shapes("llm_eval")
+ def llm_eval() -> list[tuple[int, int]]:
+     return [
+         (1, 5120),
+         (1024, 5120),
+         (2000, 5120),
+         (4096, 5120),
+         (16384, 5120),
+         (1024, 7168),
+         (4096, 4096),
+     ]
+
+
+ @register_shapes("decode_1024")
+ def decode_1024_shapes() -> list[tuple[int, int]]:
+     return [
+         (1, 1024),
+         (1, 2048),
+         (1, 4096),
+         (1, 5120),
+         (1, 6144),
+         (1, 7168),
+         (1, 8192),
+     ]
+
+
+ @register_shapes("prefill_1024")
+ def prefill_1024_shapes() -> list[tuple[int, int]]:
+     shapes = []
+     for M in [2048, 4096, 8192, 16384]:
+         shapes += [
+             (M, 1024),
+             (M, 2048),
+             (M, 4096),
+             (M, 5120),
+             (M, 6144),
+             (M, 7168),
+             (M, 8192),
+         ]
+     return shapes
+
+
+ @dataclass
+ class Metrics:
+     op: str
+     M: int = 0
+     K: int = 0
+     sim: float = 0.0
+     us: float = 0.0
+     gbps: float = 0.0
+     memory_bw_util: float = 0.0
+
+     @staticmethod
+     def header() -> str:
+         header = f"{'OpName':<20} {'Problem Shape':<15} {'Sim':<10} {'Us':<10} {'GB/s':<10} {'Mem BW Util %':<10}"
+         divider = "-" * len(header)
+         return f"Quantize Bench\n{divider}\n{header}\n{divider}"
+
+     def __str__(self) -> str:
+         problem_shape = f"({self.M}, {self.K})"
+         return f"{self.op:<20} {problem_shape:<15} {self.sim:<10.3f} {self.us:<10.3f} {self.gbps:<10.2f} {self.memory_bw_util:<10.2f}"
+
+     def as_dict(self) -> dict[str, float]:
+         return {
+             "M": self.M,
+             "K": self.K,
+             f"{self.op}_sim": self.sim,
+             f"{self.op}_us": self.us,
+             f"{self.op}_gb/s": self.gbps,
+             f"{self.op}_memory_bw_util": self.memory_bw_util,
+         }
+
+
+ def get_problem_shapes(
+     shapes: Optional[str],
+     m: Optional[str],
+     k: Optional[str],
+     pair_mk: bool,
+ ) -> list[tuple[int, int]]:
+     if shapes:
+         all_shapes = set()
+
+         for shape in shapes.strip().split(","):
+             if shape not in shape_registry:
+                 print(
+                     f"Shape {shape} not found in shape registry. Valid shapes: {', '.join(shape_registry.keys())}."
+                 )
+                 sys.exit(1)
+             all_shapes.update(shape_registry[shape]())
+
+         return list(all_shapes)
+
+     if m is None:
+         raise Exception("M must be non-empty.")
+     M = [int(m_val) for m_val in m.strip().split(",")]
+     if k is None:
+         raise Exception("K must be non-empty.")
+     K = [int(k_val) for k_val in k.strip().split(",")]
+
+     if pair_mk:
+         if len(M) != len(K):
+             raise Exception("M and K must be the same length in pair_MK mode.")
+         return list(zip(M, K))
+     else:
+         return list(itertools.product(M, K))
+
+
+ def benchmark(
+     quantize_ops: list[QuantizeOpBase],
+     m: int,
+     k: int,
+     mem_bw_roofline_gbps: float,
+     opts: BenchOptions,
+ ) -> list[Metrics]:
+     # Create input tensors.
+     input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
+
+     # Keep track of results.
+     results = []
+     # Benchmark each operator.
+     for quantize_op in quantize_ops:
+         metrics = Metrics(op=quantize_op.name, M=m, K=k)
+         args = quantize_op.preprocess(input)
+         quantized = quantize_op.quantize(input, *args)
+         dequantized = quantize_op.dequantize(*quantized)
+         metrics.sim = torch.mean(torch.pow(dequantized - input, 2)).item()
+
+         for _ in range(opts.num_iters):
+             with profiler(enabled=opts.trace, with_stack=True):
+                 ms_runtime = quantize_op.benchmark(
+                     input,
+                     args,
+                     opts=opts,
+                 )
+
+             input_bytes = input.numel() * input.element_size()
+             output_bytes = sum(t.numel() * t.element_size() for t in quantized)
+             total_size_bytes = input_bytes + output_bytes
+             gbps = (total_size_bytes / 1e9) / (ms_runtime / 1e3)
+             metrics.gbps += gbps
+             metrics.us += ms_runtime * 1000
+             metrics.memory_bw_util += (gbps / mem_bw_roofline_gbps) * 100
+
+         metrics.us /= opts.num_iters
+         metrics.gbps /= opts.num_iters
+         metrics.memory_bw_util /= opts.num_iters
+
+         results.append(metrics)
+
+     return results
+
+
+ def collect_kernels_to_profile(kernels: Optional[list[str]]) -> list[QuantizeOpBase]:
+     # Get existing quantization operators.
+     quantize_ops = [op for op in get_ops() if op.supported]
+     if kernels is None:
+         return quantize_ops
+     return [op for op in quantize_ops if op.name in kernels]
+
+
+ def print_kernels(kernels: Optional[list[str]]) -> None:
+     data = sorted(
+         [
+             (op.name, "Yes" if op.cuda else "No", "Yes" if op.hip else "No")
+             for op in get_ops()
+         ]
+     )
+     print(tabulate(data, headers=["Name", "CUDA", "ROCm"], tablefmt="orgtbl"))
+
+
+ @click.command()
+ @click.option(
+     "--output-dir",
+     default="/tmp",
+     help="Directory to save plots and csvs to",
+ )
+ @click.option(
+     "--num-iters",
+     default=1,
+     type=int,
+     help="Number of iterations to repeat each benchmark.",
+ )
+ @click.option(
+     "--export-csv",
+     is_flag=True,
+     help="Export results to a CSV file.",
+ )
+ @click.option(
+     "--kernels",
+     default=None,
+     help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
+ )
+ @click.option(
+     "--M",
+     default=None,
+     help="Comma separated list of M values to benchmark.",
+ )
+ @click.option(
+     "--K",
+     default=None,
+     help="Comma separated list of K values to benchmark.",
+ )
+ @click.option(
+     "--pair-MK",
+     is_flag=True,
+     help="If set, instead of benchmarking cartesian product of M * K, benchmark consecutive MK pairs together.",
+ )
+ @click.option(
+     "--no-cuda-graph",
+     is_flag=True,
+     help="If set, do not use cuda graph for benchmarking.",
+ )
+ @click.option(
+     "--no-rotating-buffer",
+     is_flag=True,
+     help="If set, do not use rotating buffer for benchmarking.",
+ )
+ @click.option(
+     "--shapes",
+     default=None,
+     help=f"Specific model shapes to use, options: {', '.join(shape_registry.keys())}.",
+ )
+ @click.option(
+     "--trace",
+     is_flag=True,
+     help="If set, produce a performance trace of the benchmark.",
+ )
+ def invoke_main(
+     output_dir: str,
+     num_iters: int,
+     export_csv: bool,
+     kernels: Optional[str],
+     m: Optional[str],
+     k: Optional[str],
+     pair_mk: bool,
+     no_cuda_graph: bool,
+     no_rotating_buffer: bool,
+     shapes: Optional[str],
+     trace: bool,
+ ) -> None:
+     # If kernel filter is provided, parse it. Else, benchmark all kernels.
+     all_kernels = kernels.strip().split(",") if kernels else None
+     quantize_ops = collect_kernels_to_profile(all_kernels)
+
+     if len(quantize_ops) == 0:
+         print("No valid kernels to benchmark. Available kernels:")
+         print_kernels(all_kernels)
+         sys.exit(1)
+
+     if num_iters < 1:
+         print("Warning: Number of iterations must be at least 1.")
+         num_iters = 1
+
+     mem_bw_roofline_gbps = triton.testing.get_dram_gbps()
+     MK = get_problem_shapes(shapes, m, k, pair_mk)
+
+     opts = BenchOptions(
+         num_iters=num_iters,
+         cuda_graph=not no_cuda_graph,
+         rotating_buffer=not no_rotating_buffer,
+         trace=trace,
+     )
+
+     # Iterate over shapes and benchmark.
+     benchmark_results = []
+     csv = []
+     for M, K in MK:
+         quantize_measurements = benchmark(
+             quantize_ops,
+             M,
+             K,
+             mem_bw_roofline_gbps,
+             opts,
+         )
+         benchmark_results.extend(quantize_measurements)
+         csv_row = {}
+         for metric in quantize_measurements:
+             csv_row.update(metric.as_dict())
+         csv.append(csv_row)
+
+     print(Metrics.header())
+     for metric in benchmark_results:
+         print(metric)
+
+     print("")
+     print(f"Hardware: {torch.cuda.get_device_name()}")
+     print(f" Memory BW Roofline: {mem_bw_roofline_gbps} GB/s")
+
+     print("")
+     print("Benchmark Settings:")
+     print(f" CUDA graph: {opts.cuda_graph}")
+     print(f" Buffer rotation: {opts.rotating_buffer}")
+
+     if export_csv:
+         os.makedirs(output_dir, exist_ok=True)
+         datetime_str = datetime.now().strftime("%Y%m%d_%H%M%S")
+         csv_file = os.path.join(
+             output_dir, f"quantize_ops_benchmark_{datetime_str}.csv"
+         )
+         # Export results to a CSV file.
+         df = pd.DataFrame(csv)
+         df.to_csv(csv_file, na_rep="NaN", index=False)
+         print(f"CSV saved to {csv_file}")
+
+
+ if __name__ == "__main__":
+     invoke_main()
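
For orientation, the following is a minimal usage sketch and not part of the diff. It assumes the hunk above is mslk/bench/quantize/quantize_bench.py (the only file in the manifest adding 345 lines) and drives its click command through click's test runner; the option values shown are illustrative only.

from click.testing import CliRunner

from mslk.bench.quantize.quantize_bench import invoke_main

# Benchmark two explicit (M, K) pairs, repeat each measurement 3 times,
# and write a CSV into the default /tmp output directory.
runner = CliRunner()
result = runner.invoke(
    invoke_main,
    ["--M", "1024,4096", "--K", "5120,8192", "--pair-MK", "--num-iters", "3", "--export-csv"],
)
print(result.output)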
mslk/bench/quantize/quantize_ops.py ADDED
@@ -0,0 +1,266 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import abc
+ from typing import Any, TypeVar
+
+ import torch
+ from mslk.bench.common.utils import BenchOptions, do_bench
+ from mslk.quantize.triton.fp4_quantize import triton_quantize_nvfp4
+ from mslk.quantize.triton.fp8_quantize import (
+     dequantize_fp8_block,
+     dequantize_fp8_row,
+     triton_quantize_fp8_block,
+     triton_quantize_fp8_group,
+     triton_quantize_fp8_row,
+     triton_quantize_fp8_tensor,
+ )
+ from mslk.test.quantize.triton.fp4_quantize_test import (
+     dequantize_nvfp4,
+     global_scale_nvfp4,
+ )
+
+
+ class QuantizeOpBase(metaclass=abc.ABCMeta):
+     """Helper abstract class to define expected methods of quantize ops."""
+
+     @abc.abstractmethod
+     def quantize(self, input: torch.Tensor) -> Any:
+         """Function which quantizes inputs."""
+         pass
+
+     @abc.abstractmethod
+     def dequantize(self, *args: Any) -> torch.Tensor:
+         """Function which dequantizes inputs. Used for sanity checking."""
+         pass
+
+     @abc.abstractproperty
+     def hip(self) -> bool:
+         """Whether this operator supports AMD or not."""
+         pass
+
+     @abc.abstractproperty
+     def cuda(self) -> bool:
+         """Whether this operator supports Nvidia or not."""
+         pass
+
+     def preprocess(self, input: torch.Tensor) -> Any:
+         """This is used for ops that require additional preprocessing. This method will not be counted in benchmarking."""
+         return ()
+
+     @property
+     def name(self) -> str:
+         """Name of this operator."""
+         return self.__class__.__name__
+
+     @property
+     def supported(self) -> bool:
+         """Whether this op will run on the current device."""
+         if torch.version.hip is not None:
+             return self.hip
+         elif torch.version.cuda is not None:
+             return self.cuda
+         else:
+             return False
+
+     def benchmark(
+         self,
+         input: torch.Tensor,
+         args: Any,
+         opts: BenchOptions,
+     ) -> float:
+         """Benchmark runtime of this operator using do_bench from common."""
+         return do_bench(
+             lambda inp, *a: self.quantize(inp, *a),
+             (input, *args),
+             opts,
+         )
+
+
+ op_registry: dict[str, QuantizeOpBase] = {}
+
+ T = TypeVar("T", bound=QuantizeOpBase)
+
+
+ def register_op(op_class: type[T]) -> type[T]:
+     """Decorator function for assembling all quantize ops."""
+     op_registry[op_class.__name__] = op_class()
+     return op_class
+
+
+ def get_ops() -> list[QuantizeOpBase]:
+     """Get all registered quantize ops."""
+     return list(op_registry.values())
+
+
+ @register_op
+ class TritonFP8Rowwise(QuantizeOpBase):
+     def quantize(self, input: torch.Tensor) -> Any:
+         return triton_quantize_fp8_row(input)
+
+     def dequantize(self, *args: Any) -> torch.Tensor:
+         input_quantized: torch.Tensor
+         scale: torch.Tensor
+         input_quantized, scale = args
+         return dequantize_fp8_row(input_quantized, scale)
+
+     @property
+     def hip(self) -> bool:
+         return True
+
+     @property
+     def cuda(self) -> bool:
+         return True
+
+
+ @register_op
+ class TritonFP8Blockwise(QuantizeOpBase):
+     def __init__(self) -> None:
+         super().__init__()
+         self.block_m = 128
+         self.block_k = 128
+
+     def quantize(self, input: torch.Tensor) -> Any:
+         return triton_quantize_fp8_block(
+             input, block_m=self.block_m, block_k=self.block_k
+         )
+
+     def dequantize(self, *args: Any) -> torch.Tensor:
+         input_quantized: torch.Tensor
+         scale: torch.Tensor
+         input_quantized, scale = args
+         return dequantize_fp8_block(input_quantized, scale, self.block_m, self.block_k)
+
+     @property
+     def hip(self) -> bool:
+         return True
+
+     @property
+     def cuda(self) -> bool:
+         return True
+
+
+ @register_op
+ class TritonFP8Groupwise(QuantizeOpBase):
+     def __init__(self) -> None:
+         super().__init__()
+         self.group_size = 128
+
+     def quantize(self, input: torch.Tensor) -> Any:
+         return triton_quantize_fp8_group(input, group_size=self.group_size)
+
+     def dequantize(self, *args: Any) -> torch.Tensor:
+         input_quantized: torch.Tensor
+         scale: torch.Tensor
+         input_quantized, scale = args
+
+         input_quantized = input_quantized.to(torch.float)
+         dequantized = input_quantized.view(
+             -1, input_quantized.shape[1] // self.group_size, self.group_size
+         ) * scale.unsqueeze(-1)
+         return dequantized.view(input_quantized.shape)
+
+     @property
+     def hip(self) -> bool:
+         return True
+
+     @property
+     def cuda(self) -> bool:
+         return True
+
+
+ @register_op
+ class TritonNVFP4(QuantizeOpBase):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def preprocess(self, input: torch.Tensor) -> Any:
+         global_scale = global_scale_nvfp4(input)
+         return (global_scale,)
+
+     def quantize(self, input: torch.Tensor, *args: Any) -> Any:
+         global_scale: torch.Tensor
+         global_scale = args[0]
+         input_quantized, scales = triton_quantize_nvfp4(input, global_scale)
+         return input_quantized.view(torch.uint8), scales, global_scale
+
+     def dequantize(self, *args: Any) -> torch.Tensor:
+         input_quantized: torch.Tensor
+         scale: torch.Tensor
+         global_scale: torch.Tensor
+         input_quantized, scale, global_scale = args
+
+         return dequantize_nvfp4(input_quantized, scale, global_scale)
+
+     @property
+     def hip(self) -> bool:
+         return False
+
+     @property
+     def cuda(self) -> bool:
+         return True
+
+
+ @register_op
+ class CudaFP8Rowwise(QuantizeOpBase):
+     def quantize(self, input: torch.Tensor) -> Any:
+         return torch.ops.mslk.quantize_fp8_per_row(input)
+
+     def dequantize(self, *args: Any) -> torch.Tensor:
+         input_quantized: torch.Tensor
+         scale: torch.Tensor
+         input_quantized, scale = args
+         return dequantize_fp8_row(input_quantized, scale)
+
+     @property
+     def hip(self) -> bool:
+         return True
+
+     @property
+     def cuda(self) -> bool:
+         return True
+
+
+ @register_op
+ class CudaFP8Tensorwise(QuantizeOpBase):
+     def quantize(self, input: torch.Tensor) -> Any:
+         return torch.ops.mslk.quantize_fp8_per_tensor(input)
+
+     def dequantize(self, *args: Any) -> torch.Tensor:
+         input_quantized: torch.Tensor
+         scale: torch.Tensor
+         input_quantized, scale = args
+         return input_quantized.to(torch.float32) * scale
+
+     @property
+     def hip(self) -> bool:
+         return True
+
+     @property
+     def cuda(self) -> bool:
+         return True
+
+
+ @register_op
+ class TritonFP8Tensorwise(QuantizeOpBase):
+     def quantize(self, input: torch.Tensor) -> Any:
+         return triton_quantize_fp8_tensor(input)
+
+     def dequantize(self, *args: Any) -> torch.Tensor:
+         input_quantized: torch.Tensor
+         scale: torch.Tensor
+         input_quantized, scale = args
+         return input_quantized.to(torch.float32) * scale
+
+     @property
+     def hip(self) -> bool:
+         return True
+
+     @property
+     def cuda(self) -> bool:
+         return True
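
The registry above is extended by subclassing QuantizeOpBase and decorating the subclass with @register_op, after which get_ops() and the benchmark CLI pick it up automatically. The following is a minimal sketch, not part of the diff; BF16Copy is a hypothetical baseline op used purely for illustration, and the module path mslk.bench.quantize.quantize_ops is taken from the file manifest above.

import torch

from mslk.bench.quantize.quantize_ops import QuantizeOpBase, register_op


@register_op
class BF16Copy(QuantizeOpBase):
    """Hypothetical no-op baseline: 'quantizes' by copying the bf16 input."""

    def quantize(self, input: torch.Tensor) -> tuple[torch.Tensor]:
        # Return a tuple so the harness can unpack it into dequantize(*quantized)
        # and sum element sizes when computing achieved bandwidth.
        return (input.clone(),)

    def dequantize(self, *args: torch.Tensor) -> torch.Tensor:
        (copied,) = args
        return copied

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True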
mslk/comm/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ from mslk.utils.torch.library import load_library_buck
+
+ load_library_buck("//mslk/csrc/comm:car_ops")
mslk/conv/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ from mslk.utils.torch.library import load_library_buck
+
+ load_library_buck("//mslk/csrc/conv:conv_ops")
mslk/gemm/__init__.py ADDED
@@ -0,0 +1,18 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ from mslk.utils.torch.library import load_library_buck
+
+ load_library_buck("//mslk/csrc/gemm:gemm_ops")
+
+ gemm_ops = [
+     "//mslk/csrc/gemm/cutlass:cutlass_bf16bf16bf16_grouped_grad",
+     "//mslk/csrc/gemm/cutlass:cutlass_bf16bf16bf16_grouped_wgrad",
+ ]
+ for op in gemm_ops:
+     load_library_buck(op)
@@ -0,0 +1,7 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
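
The __init__.py stubs above (mslk/comm, mslk/conv, mslk/gemm) follow one pattern: importing the Python package calls load_library_buck, which appears to load the compiled extension so its operators become callable under torch.ops.mslk, the namespace already used by CudaFP8Rowwise and CudaFP8Tensorwise above. A minimal sketch of what that means for callers, not part of the diff and assuming mslk/quantize/__init__.py follows the same pattern and registers the quantize_fp8_per_row operator:

import torch

import mslk.quantize  # assumed to load the native library and register torch.ops.mslk.*

x = torch.randn(16, 5120, device="cuda", dtype=torch.bfloat16)
# Same call CudaFP8Rowwise.quantize makes: returns the fp8 tensor and per-row scales.
xq, x_scale = torch.ops.mslk.quantize_fp8_per_row(x)
print(xq.shape, x_scale.shape)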