quack-kernels 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.4.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
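Three modules were renamed in place (items 9, 11, and 15 above), so downstream code that imports them by module path needs updating. A minimal sketch of the implied migration, assuming the modules are imported directly (the diff does not say which names are public API):

# quack-kernels 0.2.2
import quack.gemm_act_sm90
import quack.gemm_dact_sm90
import quack.dense_gemm_sm90

# quack-kernels 0.2.4
import quack.gemm_act
import quack.gemm_dact
import quack.gemm_sm90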
{quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA CHANGED
@@ -1,10 +1,12 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.2.2
+Version: 0.2.4
 Requires-Python: >=3.10
 License-File: LICENSE
-Requires-Dist: nvidia-cutlass-dsl==4.2.1
+Requires-Dist: nvidia-cutlass-dsl<4.4.0,>=4.3.4
 Requires-Dist: torch
+Requires-Dist: apache-tvm-ffi<0.2,>=0.1.6
+Requires-Dist: torch-c-dlpack-ext
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
 Requires-Dist: ruff; extra == "dev"
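The dependency changes above can be verified against an installed wheel with the standard library; a minimal sketch, assuming quack-kernels 0.2.4 is installed in the current environment:

from importlib.metadata import metadata, version

assert version("quack-kernels") == "0.2.4"
# Per the METADATA diff above, this should list the loosened cutlass pin
# (>=4.3.4,<4.4.0 instead of ==4.2.1) plus the two new runtime deps,
# apache-tvm-ffi and torch-c-dlpack-ext.
print(metadata("quack-kernels").get_all("Requires-Dist"))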
quack_kernels-0.2.4.dist-info/RECORD ADDED
@@ -0,0 +1,44 @@
+quack/__init__.py,sha256=_vZWQp7kr01iQb3frKnJuzUf11z7ID0upA7oR_8mRTE,203
+quack/activation.py,sha256=-lZgojraqdyLjOzgOXBehoVeRBhBq30UX7kOkXsCpGI,20855
+quack/autotuner.py,sha256=atw0ntedi22RPwSdjWOoge4S56S8VFvRocJQcYhpAlo,13454
+quack/broadcast_utils.py,sha256=X5vWg2RtIIWU9Z7nEUW6m0EP0Cfd9XtCKxp4tSyp4Mg,1283
+quack/compile_utils.py,sha256=qJ3oTsDlbAiddrJHtEO7LPYVqn_s-neNfiw-_KvfXZU,591
+quack/copy_utils.py,sha256=J1Hcw18iNHHpOP2wNFhF8Lz16NEmXtoQMu59mmLrRCs,18761
+quack/cross_entropy.py,sha256=w6fjHC_vXt5ji2KfoLrSOdAvpLrQszrYU9rmRij2yY8,24899
+quack/cute_dsl_utils.py,sha256=4uQx5aYDG9UvVzbWwJTjjJLrnoympz70_CD8b37FQWo,3854
+quack/fast_math.py,sha256=E1XUqfUt0_n9BPZNggF-UDzZ6anso9bYUrwqafemWvQ,2297
+quack/gemm.py,sha256=8V23MPq49QbV3csv-_AxjfE9qf8R3NIqFK9Q9db6t2c,7417
+quack/gemm_act.py,sha256=Y8HJKfw3tCoFKecwhwhd5xpXd9jCQCGZT_V2xXf-CnU,20823
+quack/gemm_config.py,sha256=94o3g9x7H0wi7aBbsb7H67H8nSzTurwL2zgvKDtQUas,3575
+quack/gemm_dact.py,sha256=l__UhCrFbPjD9a1TAVgP7_C7p5lLfX5DkRcM6z0ofOw,7789
+quack/gemm_default_epi.py,sha256=6qO8Ovtcw8sQQ_kXTBTTQ5IHh1lS6RBCGZG0lgLHNrs,11916
+quack/gemm_interface.py,sha256=AF5PYTNgEHjb3MNXcNvvEpOcShAHtak0Xu12l1zrOAw,44804
+quack/gemm_sm100.py,sha256=U9jmzpST_d1W6CBFf1ZHhTtr0K8hENCsUz7dXvHaMZc,122344
+quack/gemm_sm90.py,sha256=u-Q3fN6DPm1fEdz0LcMecMbGTBcRunUCWopufwO8cHU,92015
+quack/gemm_symmetric.py,sha256=mqx7wgOCY6Dh9hjL6gR9PBstMD476GhpA_NkGeaEtik,13349
+quack/gemm_wrapper_utils.py,sha256=EaPyR3Lq19z_RkdB2_xxRj0IPSJMgyfpkrTXyvY3B6M,12775
+quack/layout_utils.py,sha256=QjFFlvDcLiyGGfA2FKWKI75twHIkOJ2AotE0cIpBAlI,11923
+quack/linear.py,sha256=mhN2A98w7H7X4MS63XCCK3gpOm1eS8H7a4WO9ovkt5U,9791
+quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
+quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
+quack/pipeline.py,sha256=mMdIlpUaHdRDOkvQzgKdCdJydJq6C2eYrny5Bui4KFs,11311
+quack/reduce.py,sha256=ySKT2xh1_pIlbJX29BPmwH6yJ7MxIrRZyxHIPPYVpm0,12698
+quack/reduction_base.py,sha256=QqlPs5L2VCxwDrO4CHPq-KY6f_BAYRbvsR6k81LPzTU,3180
+quack/rmsnorm.py,sha256=esy18s5JtT7KBPRPhWf_anLRTrtromwqeJmg2yzOm60,44678
+quack/sm100_utils.py,sha256=-p5qj3Wi9n4WDLy2sl-fApYpGp5rH3JvZQb712OTxPs,1901
+quack/sm90_utils.py,sha256=hg8qq7S8NODZlUSaxNpdZcsnxcR0jM921rMn1VmBo7o,4278
+quack/softmax.py,sha256=ZqeVbnGfzwkro1LfWBHagbS7B7ug7b9SLZWuGx_Y3Kc,14367
+quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
+quack/tile_scheduler.py,sha256=vbKq0xp94eII0uJ63yY_3sgvJkQI7Irc8y1OttO6cRA,42514
+quack/topk.py,sha256=43xHpRGbwZCSRsulmfrG4WA_r2eLHc3sniaUFU7wn-o,22522
+quack/utils.py,sha256=WIttE1iiwyPIwR1NpaeO26Pn9YkZb361TDxFTUDH-IE,7354
+quack/varlen_utils.py,sha256=SOYkomxX2FoqjYlybg99CqNhS9IARM6F9ba2AkIVvT4,15811
+quack/sort/bitonic_sort.py,sha256=VJPVjPulW_jEr3myBE7AiBYGtsc5T9FEy3sjXFukF7s,4831
+quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
+quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
+quack/sort/utils.py,sha256=RbubEY1GcEpsjiz_6o5o2WB47IeMOzaajW6Jis0s444,1059
+quack_kernels-0.2.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.2.4.dist-info/METADATA,sha256=vMKNVe5-xDcELyrpCllppMWMRLp0T3M0wFqkHsT7hw0,368
+quack_kernels-0.2.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.2.4.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.2.4.dist-info/RECORD,,
quack/layernorm.py DELETED
@@ -1,353 +0,0 @@
-# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
-
-
-import torch
-from typing import Optional
-
-import cuda.bindings.driver as cuda
-
-import cutlass
-import cutlass.cute as cute
-from cutlass.cute.runtime import from_dlpack
-import quack.utils as utils
-from quack.reduce import row_reduce
-from quack.reduction_base import ReductionBase
-from quack.cute_dsl_utils import torch2cute_dtype_map
-
-
-class LayerNorm(ReductionBase):
-    def __init__(self, dtype: cutlass.Numeric, N: int):
-        super().__init__(dtype, N, stage=2)  # 2 stages for mean and var
-        self.reload_from = None if N <= 16384 else "smem"
-        self.delay_w_load = False
-
-    def _calculate_threads_per_row(self):
-        N = self.N
-        return (
-            8
-            if N <= 64
-            else (
-                16
-                if N <= 128
-                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
-            )
-        )
-
-    def _set_cluster_n(self):
-        N = self.N
-        # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
-        # Similarly cluster_n = 8 is faster for N=128k
-        if cutlass.const_expr(self.dtype.width == 16):
-            cluster_n = (
-                1
-                if N <= 16 * 1024
-                else (
-                    2
-                    if N <= 32 * 1024
-                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
-                )
-            )
-        else:  # fp32
-            cluster_n = (
-                1
-                if N <= 32 * 1024
-                else (
-                    2
-                    if N <= 64 * 1024
-                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
-                )
-            )
-        self.cluster_n = cluster_n
-
-    @cute.jit
-    def __call__(
-        self,
-        mX: cute.Tensor,
-        mW: cute.Tensor,
-        mO: cute.Tensor,
-        mRstd: Optional[cute.Tensor],
-        mMean: Optional[cute.Tensor],
-        stream: cuda.CUstream,
-        eps: cutlass.Float32 = 1e-6,
-    ):
-        assert mX.element_type == self.dtype
-        assert mO.element_type == self.dtype
-        self._set_cluster_n()
-        tiler_mn, tv_layout = self._get_tv_layout()
-        num_threads = cute.size(tv_layout, mode=[0])
-        num_warps = num_threads // cute.arch.WARP_SIZE
-        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
-        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
-        if cutlass.const_expr(mRstd is not None):
-            mRstd_expanded_layout = cute.append(
-                mRstd.layout, cute.make_layout((self.N,), stride=(0,))
-            )
-            mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-        if cutlass.const_expr(mMean is not None):
-            mMean_expanded_layout = cute.append(
-                mMean.layout, cute.make_layout((self.N,), stride=(0,))
-            )
-            mMean = cute.make_tensor(mMean.iterator, mMean_expanded_layout)
-        self.kernel(mX, mW, mO, mRstd, mMean, eps, tv_layout, tiler_mn, self.reload_from).launch(
-            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
-            block=[num_threads, 1, 1],
-            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
-            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
-            stream=stream,
-        )
-
-    @cute.kernel
-    def kernel(
-        self,
-        mX: cute.Tensor,
-        mW: cute.Tensor,
-        mO: cute.Tensor,
-        mRstd: Optional[cute.Tensor],
-        mMean: Optional[cute.Tensor],
-        eps: cute.Float32,
-        tv_layout: cute.Layout,
-        tiler_mn: cute.Shape,
-        reload_from: cutlass.Constexpr = None,
-        delay_w_load: cutlass.Constexpr = False,
-    ):
-        tidx, _, _ = cute.arch.thread_idx()
-        bidx, _, _ = cute.arch.block_idx()
-        if cutlass.const_expr(self.cluster_n > 1):
-            cluster_y = cute.arch.block_idx()[1]
-        else:
-            cluster_y = cutlass.const_expr(0)
-
-        smem = cutlass.utils.SmemAllocator()
-        sX = smem.allocate_tensor(
-            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
-        )
-        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
-
-        shape = mX.shape
-        idX = cute.make_identity_tensor(shape)
-        # slice for CTAs
-        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
-        mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
-        gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
-        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
-        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
-        gRstd = (
-            cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
-            if cutlass.const_expr(mRstd is not None)
-            else None
-        )
-        gMean = (
-            cute.local_tile(mMean, tiler_mn, (bidx, cluster_y))
-            if cutlass.const_expr(mMean is not None)
-            else None
-        )
-
-        # declare the atoms which will be used later for memory copy
-        copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
-        )
-        copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
-        )
-        copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
-        )
-        copy_atom_store_O = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
-        )
-
-        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
-            tidx
-        )
-        thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
-        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
-
-        tWgW = thr_copy_W.partition_S(gW)
-        tXgX = thr_copy_X.partition_S(gX)
-        tXsX = thr_copy_X.partition_D(sX)
-        tXgO = thr_copy_O.partition_D(gO)
-        tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
-        tXrMean = thr_copy_O.partition_D(gMean) if cutlass.const_expr(mMean is not None) else None
-        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
-        # allocate fragments for gmem->rmem
-        tWrW = cute.make_fragment_like(tWgW)
-        tXrW = thr_copy_X.retile(tWrW)
-        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
-
-        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-        self._initialize_cluster(tidx, mbar_ptr, num_warps)
-
-        tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
-        row = tXcX[0][0]
-        if row < shape[0]:
-            cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
-        cute.arch.cp_async_commit_group()
-
-        tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
-        if cutlass.const_expr(not delay_w_load):
-            cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
-
-        cute.arch.cp_async_wait_group(0)
-        cute.autovec_copy(tXsX, tXrX)
-        x = tXrX.load().to(cute.Float32)
-        threads_per_row = tv_layout.shape[0][0]
-        sum_x = row_reduce(
-            x,
-            cute.ReductionOp.ADD,
-            threads_per_row,
-            reduction_buffer[None, None, 0],
-            mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
-            init_val=0.0,
-            hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
-        )
-        mean = sum_x / shape[1]
-        if cutlass.const_expr(reload_from == "smem"):
-            cute.autovec_copy(tXsX, tXrX)
-            x = tXrX.load().to(cute.Float32)
-        elif cutlass.const_expr(reload_from == "gmem"):
-            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
-            x = tXrX.load().to(cute.Float32)
-
-        sum_sq_x_sub_mean = row_reduce(
-            (x - mean) * (x - mean),
-            cute.ReductionOp.ADD,
-            threads_per_row,
-            reduction_buffer[None, None, 1],
-            mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
-            init_val=0.0,
-        )
-        rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True)
-        if cutlass.const_expr(mRstd is not None):
-            # Only the thread corresponding to column 0 writes out the rstd to gmem
-            if (
-                tXcX[0][1] == 0
-                and row < shape[0]
-                and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
-            ):
-                tXrRstd[0] = rstd
-        if cutlass.const_expr(mMean is not None):
-            # Only the thread corresponding to column 0 writes out the mean to gmem
-            if (
-                tXcX[0][1] == 0
-                and row < shape[0]
-                and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
-            ):
-                tXrMean[0] = mean
-        if cutlass.const_expr(delay_w_load):
-            cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
-        if cutlass.const_expr(reload_from == "smem"):
-            cute.autovec_copy(tXsX, tXrX)
-            x = tXrX.load().to(cute.Float32)
-        elif cutlass.const_expr(reload_from == "gmem"):
-            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
-            x = tXrX.load().to(cute.Float32)
-        x_hat = (x - mean) * rstd
-        w = tXrW.load().to(cute.Float32)
-        y = x_hat * w
-        tXrO.store(y.to(tXrO.element_type))
-        tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
-        if row < shape[0]:
-            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
-
-
-def layernorm(
-    x: torch.Tensor,
-    weight: torch.Tensor,
-    eps: float = 1e-6,
-    return_rstd: bool = False,
-    return_mean: bool = False,
-) -> torch.Tensor:
-    """LayerNorm forward pass.
-
-    Args:
-        x: Input tensor of shape (M, N)
-        weight: Weight tensor of shape (N,)
-        eps: Small value for numerical stability
-        return_rstd: Whether to return the reciprocal standard deviation
-        return_mean: Whether to return the mean
-
-    Returns:
-        Normalized output tensor of same shape as x
-        If return_rstd is True, also returns rstd tensor of shape (M,)
-        If return_mean is True, also returns mean tensor of shape (M,)
-    """
-    assert x.dim() == 2, "Input must be 2D"
-    assert weight.dim() == 1, "Weight must be 1D"
-    assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
-    assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
-    assert weight.dtype == torch.float32, "Weight must be float32"
-    M, N = x.shape
-    device = x.device
-    out = torch.empty_like(x)
-    rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
-    mean = torch.empty(M, device=device, dtype=torch.float32) if return_mean else None
-    dtype = torch2cute_dtype_map[x.dtype]
-    convert_from_dlpack = lambda x: (
-        from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
-            mode=0, stride_order=(0, 1)
-        )
-    )
-    x_tensor, out_tensor = [
-        # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
-        convert_from_dlpack(t)
-        for t in (x, out)
-    ]
-    weight_tensor = utils.convert_from_dlpack(
-        weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
-    )
-    rstd_tensor = (
-        from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
-        if rstd is not None
-        else None
-    )
-    mean_tensor = (
-        from_dlpack(mean.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
-        if mean is not None
-        else None
-    )
-    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-    compile_key = (dtype, N, rstd is not None, mean is not None)
-    if compile_key not in layernorm.compile_cache:
-        rmsnorm_op = LayerNorm(dtype, N)
-        layernorm.compile_cache[compile_key] = cute.compile(
-            rmsnorm_op,
-            x_tensor,
-            weight_tensor,
-            out_tensor,
-            rstd_tensor,
-            mean_tensor,
-            current_stream,
-        )
-    layernorm.compile_cache[compile_key](
-        x_tensor, weight_tensor, out_tensor, rstd_tensor, mean_tensor, current_stream, eps
-    )
-    return (
-        (out, rstd, mean)
-        if return_mean and return_rstd
-        else (
-            (out, rstd)
-            if return_rstd and not return_mean
-            else ((out, mean) if return_mean and not return_rstd else (out))
-        )
-    )
-
-
-layernorm.compile_cache = {}
-
-
-def layernorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
-    x_f32 = x.float()
-    return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)
-
-
-def rstd_ref(x: torch.Tensor, eps: float = 1e-6):
-    x_f32 = x.float()
-    mean = x_f32.mean(dim=-1, keepdim=True)
-    var = ((x_f32 - mean) ** 2).mean(dim=-1)
-    return 1.0 / torch.sqrt(var + eps)
-
-
-def mean_ref(x: torch.Tensor) -> torch.Tensor:
-    return x.float().mean(dim=-1)
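For reference, the deleted module's own reference helper (layernorm_ref, kept verbatim in the diff above) pins down the semantics the kernel implemented: normalize in fp32 with no bias term, scale by a weight, and cast back to the input dtype. A minimal standalone equivalent plus a call matching the shapes and dtypes asserted by the removed layernorm() wrapper (a sketch for context; this is not part of the 0.2.4 wheel):

import torch

def layernorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same semantics as the deleted quack.layernorm reference:
    # compute in fp32, no bias, cast back to the input dtype.
    x_f32 = x.float()
    return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)

# The removed wrapper required x of shape (M, N) in fp16/bf16/fp32 on CUDA
# and a fp32 weight of shape (N,); the pure-PyTorch version also runs on CPU.
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, device="cuda", dtype=torch.float32)
y = layernorm_ref(x, w)  # same shape and dtype as x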