quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/cute_dsl_utils.py CHANGED
@@ -1,8 +1,7 @@
 # Copyright (c) 2025, Tri Dao.
 
-import os
-import pathlib
-from functools import partial, lru_cache
+from typing import Tuple
+from functools import lru_cache
 
 from dataclasses import dataclass, fields
 
 import torch
@@ -14,6 +13,7 @@ except ImportError:
 
 import cutlass
 import cutlass.cute as cute
+from cutlass import Int32, Int64, Float16, BFloat16, Float32
 from cutlass.base_dsl.typing import JitArgument
 from cutlass.cutlass_dsl import NumericMeta
 
@@ -26,9 +26,11 @@ cute_compile_og = cute.compile
 
 
 torch2cute_dtype_map = {
-    torch.float16: cutlass.Float16,
-    torch.bfloat16: cutlass.BFloat16,
-    torch.float32: cutlass.Float32,
+    torch.float16: Float16,
+    torch.bfloat16: BFloat16,
+    torch.float32: Float32,
+    torch.int32: Int32,
+    torch.int64: Int64,
 }
 
 
@@ -37,6 +39,11 @@ def get_max_active_clusters(cluster_size):
     return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
 
 
+@lru_cache
+def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
+    return torch.cuda.get_device_capability(device)
+
+
 @dataclass
 class ParamsBase:
     def __extract_mlir_values__(self):
@@ -75,10 +82,14 @@ class ArgumentsBase(JitArgument):
     def __get_mlir_types__(self):
         all_fields = [getattr(self, field.name) for field in fields(self)]
         non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
-        types = []
+        types, self._values_pos = [], []
         for obj in non_constexpr_fields:
             if hasattr(obj, "__get_mlir_types__"):
-                types.extend(obj.__get_mlir_types__())
+                obj_types = obj.__get_mlir_types__()
+                types.extend(obj_types)
+                self._values_pos.append(len(obj_types))
+            else:
+                self._values_pos.append(0)
         return types
 
     def __new_from_mlir_values__(self, values):
@@ -87,32 +98,7 @@ class ArgumentsBase(JitArgument):
         non_constexpr_fields = {
             n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
         }
-        # for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
-        for name, field in non_constexpr_fields.items():
-            # non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
-            # values = values[n_items:]
-            n_items = 1
+        for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
             non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
             values = values[n_items:]
         return self.__class__(**non_constexpr_fields, **constexpr_fields)
-
-
-def load_cubin_module_data_patched(cubin_data, filepath):
-    pathlib.Path(filepath).write_bytes(cubin_data)
-    return load_cubin_module_data_og(cubin_data)
-
-
-def cute_compile_patched(*args, **kwargs):
-    """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set."""
-    cubin_path = os.getenv("CUTE_CUBIN_PATH", None)
-    if cubin_path is not None:
-        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
-            load_cubin_module_data_patched, filepath=cubin_path
-        )
-    output = cute_compile_og(*args, **kwargs)
-    if cubin_path is not None:
-        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
-        if extract is not None:
-            sass = extract(cubin_path, None)
-            pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
-    return output
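
Why the `_values_pos` bookkeeping above matters: in 0.2.2, `__new_from_mlir_values__` hardcoded `n_items = 1`, so any field that flattens to more (or fewer) than one MLIR value would be rebuilt from the wrong slice of `values`. The new code records each field's value count during `__get_mlir_types__` and replays it during reconstruction. Below is a minimal, self-contained sketch of that round-trip in plain Python, using a hypothetical FakeField stand-in rather than the actual cutlass JitArgument machinery:

# Sketch only: FakeField mimics a nested field that contributes N MLIR values;
# a plain `object()` mimics a field that contributes none.
class FakeField:
    def __init__(self, n):
        self.n = n

    def __get_mlir_types__(self):
        return [f"type{i}" for i in range(self.n)]


class Args:
    def __init__(self, fields_):
        self.fields_ = fields_

    def get_mlir_types(self):
        # Record how many values each field contributes (the _values_pos fix).
        types, self._values_pos = [], []
        for obj in self.fields_:
            if hasattr(obj, "__get_mlir_types__"):
                obj_types = obj.__get_mlir_types__()
                types.extend(obj_types)
                self._values_pos.append(len(obj_types))
            else:
                self._values_pos.append(0)
        return types

    def rebuild(self, values):
        # Slice the flat value list per field using the recorded counts,
        # instead of assuming one value per field as 0.2.2 did.
        out = []
        for obj, n_items in zip(self.fields_, self._values_pos):
            out.append(values[:n_items])
            values = values[n_items:]
        return out


args = Args([FakeField(2), object(), FakeField(1)])
assert len(args.get_mlir_types()) == 3
assert args.rebuild(["a", "b", "c"]) == [["a", "b"], [], ["c"]]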
quack/gemm.py ADDED
@@ -0,0 +1,194 @@
+from typing import Optional
+from functools import partial
+
+from torch import Tensor
+
+import cutlass.cute as cute
+import cutlass.torch as cutlass_torch
+from cutlass import Float32
+from cutlass.cute.runtime import from_dlpack, make_ptr
+
+from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
+from quack.gemm_wrapper_utils import GemmWrapperBase
+from quack.gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100
+
+
+def gemm(
+    # (l, m, k) or (total_m, k) if varlen_m or (m, total_k) if varlen_k or (whatever, k) if gather_A_varlen_m or (m, whatever) if gather_A_varlen_k
+    A: Tensor,
+    B: Tensor,  # (l, n, k) or (n, total_k) if varlen_k
+    D: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
+    C: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
+    tile_count_semaphore: Optional[Tensor],  # (1,)
+    tile_M: int,
+    tile_N: int,
+    cluster_M: int,
+    cluster_N: int,
+    pingpong: bool = False,
+    persistent: bool = True,
+    max_swizzle_size: int = 8,
+    rowvec_bias: Optional[Tensor] = None,  # (l, n)
+    colvec_bias: Optional[Tensor] = None,  # (l, m), or (total_m,) if varlen_m
+    alpha: float | Tensor = 1.0,
+    beta: float | Tensor = 1.0,
+    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
+    cu_seqlens_k: Optional[Tensor] = None,  # (l+1,) cumulative sum of k values for variable length
+    A_idx: Optional[Tensor] = None,  # (total_m,) or (total_k,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (l,) permutation of batch indices for scheduler
+    add_to_output: bool = False,
+) -> None:
+    varlen = cu_seqlens_m is not None or cu_seqlens_k is not None
+    assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
+        "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
+    )
+    gather_A = A_idx is not None
+    if gather_A:
+        assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)"
+        assert cluster_N == 1, "gather_A requires cluster_N=1"
+    if varlen:
+        assert persistent, "varlen requires persistent=True"
+    if add_to_output:
+        assert cu_seqlens_m is None, "Add to output not supported with varlen_m"
+    if cu_seqlens_m is not None:
+        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
+        assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
+    if cu_seqlens_k is not None:
+        assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
+        assert B.stride(-2) == 1, "varlen_k requires B to be n-major"
+
+    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
+        A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx
+    )
+    GemmWrapperBase.permute_tensors(
+        tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k is not None
+    )
+    GemmWrapperBase.extract_dtypes(tensor_infos)
+    major_configs = {
+        "A": ("m", "k", "l"),
+        "B": ("n", "k", "l"),
+        "D": ("m", "n", "l"),
+        "C": ("m", "n", "l"),
+    }
+    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+    device_capacity = get_device_capacity(A.device)
+    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
+    GemmCls = GemmDefaultSm100 if device_capacity[0] > 9 else GemmDefaultSm90
+
+    acc_dtype = Float32
+    tile_shape_mn = (tile_M, tile_N)
+    cluster_shape_mnk = (cluster_M, cluster_N, 1)
+    if not GemmCls.is_valid_dtypes(
+        tensor_infos["A"].dtype,
+        tensor_infos["B"].dtype,
+        acc_dtype,
+        tensor_infos["D"].dtype,
+        tensor_infos["A"].major,
+        tensor_infos["B"].major,
+    ):
+        raise TypeError("Skipping due to unsupported combination of types and majors")
+
+    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
+
+    def scalar_arg(scalar: float | Tensor):
+        if isinstance(scalar, float):
+            return Float32(scalar) if scalar != 1.0 else None
+        else:
+            assert isinstance(scalar, Tensor)
+            return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
+
+    epi_args = GemmCls.EpilogueArguments(
+        scalar_arg(alpha),
+        scalar_arg(beta),
+        mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
+            leading_dim=1
+        )
+        if rowvec_bias is not None
+        else None,
+        mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
+            leading_dim=1 if cu_seqlens_m is None else 0
+        )
+        if colvec_bias is not None
+        else None,
+        add_to_output=add_to_output,
+    )
+    scheduler_args = GemmWrapperBase.create_scheduler_args(
+        max_active_clusters,
+        tile_count_semaphore,
+        batch_idx_permute,
+        max_swizzle_size,
+    )
+
+    # Create varlen arguments if needed (assumes persistent=True when varlen)
+    varlen_args = GemmWrapperBase.create_varlen_args(
+        cu_seqlens_m,
+        cu_seqlens_k,
+        A_idx,
+        max_active_clusters,
+        cluster_shape_mnk,
+        tensor_infos,
+        GemmCls.num_epi_tensormaps,
+        pingpong,
+    )
+
+    current_stream = cutlass_torch.current_stream()
+    compile_key = GemmWrapperBase.get_compile_key(
+        tensor_infos,
+        None,  # activation
+        tile_shape_mn,
+        cluster_shape_mnk,
+        pingpong,
+        persistent,
+        tile_count_semaphore is not None,
+        device_capacity,
+        # Technically we don't need to recompile for different max_swizzle_size, but currently
+        # not recompiling will skew the autotuning results due to power throttling.
+        # Effectively we're recompiling as a way to pause between benchmarks during autotuning.
+        max_swizzle_size,
+        rowvec_bias.dtype if rowvec_bias is not None else None,
+        colvec_bias.dtype if colvec_bias is not None else None,
+        2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
+        2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
+        add_to_output,
+        cu_seqlens_m is not None,
+        cu_seqlens_k is not None,
+        gather_A,
+        batch_idx_permute is not None,
+        key_tensor_names=("A", "B", "D", "C"),
+    )
+    cache = gemm.compile_cache
+    if compile_key not in cache:
+        if device_capacity[0] == 9:
+            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
+        gemm_obj = GemmCls(
+            acc_dtype,
+            tensor_infos["A"].dtype,
+            tile_shape_mn,
+            cluster_shape_mnk,
+            gather_A=gather_A,
+        )
+        cache[compile_key] = cute.compile(
+            gemm_obj,
+            tensor_infos["A"].cute_tensor,
+            tensor_infos["B"].cute_tensor,
+            tensor_infos["D"].cute_tensor,
+            tensor_infos["C"].cute_tensor,
+            epi_args,
+            scheduler_args,
+            varlen_args,
+            current_stream,
+        )
+    cache[compile_key](
+        tensor_infos["A"].cute_tensor,
+        tensor_infos["B"].cute_tensor,
+        tensor_infos["D"].cute_tensor,
+        tensor_infos["C"].cute_tensor,
+        epi_args,
+        scheduler_args,
+        varlen_args,
+        current_stream,
+    )
+
+
+gemm.compile_cache = {}
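
For orientation, here is a hedged usage sketch of the new `quack.gemm.gemm` entry point, inferred only from the signature and shape comments above. The tile and cluster sizes are illustrative, not tuned values, and given B's (l, n, k) layout plus the alpha/beta epilogue arguments, the call appears to compute D = alpha * A @ B^T (+ beta * C) per batch:

# Sketch under the assumptions stated above; requires an SM90/SM100 GPU.
import torch
from quack.gemm import gemm

l, m, n, k = 4, 1024, 1024, 512
A = torch.randn(l, m, k, device="cuda", dtype=torch.bfloat16)
B = torch.randn(l, n, k, device="cuda", dtype=torch.bfloat16)  # note: (l, n, k), not (l, k, n)
D = torch.empty(l, m, n, device="cuda", dtype=torch.bfloat16)

gemm(
    A, B, D,
    C=None,                     # no residual input
    tile_count_semaphore=None,  # no tile-count scheduling
    tile_M=128, tile_N=192,     # illustrative tile shape
    cluster_M=2, cluster_N=1,   # illustrative cluster shape
)

Per the assertions in the function body, the variable-length paths additionally require persistent=True and exactly one of cu_seqlens_m / cu_seqlens_k, and gather_A (via A_idx) further requires cluster_N=1.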