cuequivariance-ops-cu12 0.8.1__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. cuequivariance_ops/VERSION +1 -0
  2. cuequivariance_ops/__init__.py +42 -0
  3. cuequivariance_ops/_version.py +20 -0
  4. cuequivariance_ops/common/common.hpp +98 -0
  5. cuequivariance_ops/common/cudart.hpp +286 -0
  6. cuequivariance_ops/common/error.hpp +66 -0
  7. cuequivariance_ops/common/error_raft.hpp +323 -0
  8. cuequivariance_ops/common/nvtx.hpp +29 -0
  9. cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
  10. cuequivariance_ops/equivariance/dtypes.hh +65 -0
  11. cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
  12. cuequivariance_ops/equivariance/indexed_linear.hh +41 -0
  13. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  14. cuequivariance_ops/equivariance/run_fmha_cudafree.h +176 -0
  15. cuequivariance_ops/equivariance/run_fmha_sm100.h +135 -0
  16. cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
  17. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
  18. cuequivariance_ops/gpu_timing_kernels.hh +42 -0
  19. cuequivariance_ops/lib/libcue_ops.so +0 -0
  20. cuequivariance_ops/sleep.hh +40 -0
  21. cuequivariance_ops/triton/__init__.py +66 -0
  22. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37142 -0
  23. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.12.0.json +37132 -0
  24. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  25. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  26. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  27. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  28. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  29. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.12.0.json +55692 -0
  30. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  31. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  32. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  33. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  34. cuequivariance_ops/triton/cache_manager.py +336 -0
  35. cuequivariance_ops/triton/fused_layer_norm_triton.py +546 -0
  36. cuequivariance_ops/triton/gated_gemm_triton.py +394 -0
  37. cuequivariance_ops/triton/pair_bias.py +365 -0
  38. cuequivariance_ops/triton/tuning_decorator.py +188 -0
  39. cuequivariance_ops/triton/utils.py +29 -0
  40. cuequivariance_ops_cu12-0.8.1.dist-info/METADATA +182 -0
  41. cuequivariance_ops_cu12-0.8.1.dist-info/RECORD +46 -0
  42. cuequivariance_ops_cu12-0.8.1.dist-info/WHEEL +6 -0
  43. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/LICENSE +142 -0
  44. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/Third_party_attr.txt +24 -0
  45. cuequivariance_ops_cu12-0.8.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  46. cuequivariance_ops_cu12.libs/libnvfatbin-b51d3b3f.so.12.8.90 +0 -0
cuequivariance_ops/triton/fused_layer_norm_triton.py
@@ -0,0 +1,546 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ import enum
+
+ import triton
+ import triton.language as tl
+
+
+ class Layout(enum.IntEnum):
+     BND_BND = 0
+     BDN_BND = 1
+     BND_BDN = 2
+     DBN_BND = 3
+     BND_DBN = 4
+
+
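The enum values name the input-to-output memory layouts handled by the kernels below: b is the batch axis, n indexes the rows being normalized, and d is the feature axis reduced over (so BDN_BND, for example, reads a (B, D, N) tensor and writes the normalized result in (B, N, D) order). As a reading aid, here is a minimal sketch of that mapping; the dict and helper are ours, derived from the "# bnd->bnd" style comments in the kernels, and are not part of the shipped module:

# Illustrative only: (input axis order, output axis order) per Layout.
_LAYOUT_AXES = {
    Layout.BND_BND: ("bnd", "bnd"),
    Layout.BDN_BND: ("bdn", "bnd"),
    Layout.BND_BDN: ("bnd", "bdn"),
    Layout.DBN_BND: ("dbn", "bnd"),
    Layout.BND_DBN: ("bnd", "dbn"),
}

def output_shape(b: int, n: int, d: int, layout: Layout) -> tuple:
    """Shape of the kernel's output tensor for a (B, N, D)-sized problem."""
    sizes = {"b": b, "n": n, "d": d}
    _, out_axes = _LAYOUT_AXES[layout]
    return tuple(sizes[a] for a in out_axes)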
+ @triton.jit
+ def layer_norm_transpose_forward_single_pass_kernel(
+     # inputs:
+     x_ptr,
+     w_ptr,
+     b_ptr,
+     # outputs: (order matters for jax_triton)
+     out_ptr,
+     mean_ptr,
+     rstd_ptr,
+     B,
+     N,
+     D: tl.constexpr,
+     EPS: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_D: tl.constexpr,
+     ELEMENTWISE_AFFINE: tl.constexpr,
+     LAYOUT: tl.constexpr,
+     NEEDS_INT64: tl.constexpr = True,
+ ):
+     pid_n = tl.program_id(0)
+     pid_b = tl.program_id(1)
+
+     if NEEDS_INT64:
+         pid_n = tl.cast(pid_n, tl.int64)
+         pid_b = tl.cast(pid_b, tl.int64)
+         B = tl.cast(B, tl.int64)
+         N = tl.cast(N, tl.int64)
+
+     offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
+     offs_d = tl.arange(0, TILE_D)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 1:  # bdn->bnd
+         x_ptrs = x_ptr + pid_b * D * N + offs_d[None, :] * N + offs_n[:, None]
+     elif LAYOUT == 2:  # bnd->bdn
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 3:  # dbn->bnd
+         x_ptrs = x_ptr + offs_d[None, :] * B * N + pid_b * N + offs_n[:, None]
+     elif LAYOUT == 4:  # bnd->dbn
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+
+     mean_ptrs = mean_ptr + pid_b * N + offs_n
+     rstd_ptrs = rstd_ptr + pid_b * N + offs_n
+     mask_n = offs_n < N
+
+     x = tl.load(x_ptrs, mask=mask_n[:, None], other=0.0).to(tl.float32)
+     mean = tl.sum(x, axis=1) / D
+     x_centered = x - mean[:, None]
+     var = tl.sum(x_centered * x_centered, axis=1) / D
+     rstd = tl.rsqrt(var + EPS)
+
+     tl.store(mean_ptrs, mean, mask=mask_n)
+     tl.store(rstd_ptrs, rstd, mask=mask_n)
+
+     x_hat = x_centered * rstd[:, None]
+
+     if ELEMENTWISE_AFFINE:
+         w_ptrs = w_ptr + offs_d
+         b_ptrs = b_ptr + offs_d
+         w = tl.load(w_ptrs).to(tl.float32)
+         b = tl.load(b_ptrs).to(tl.float32)
+         y = x_hat * w[None, :] + b[None, :]
+     else:
+         y = x_hat
+
+     if LAYOUT == 0:  # bnd->bnd
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 1:  # bdn->bnd
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 2:  # bnd->bdn
+         out_ptrs = out_ptr + pid_b * N * D + offs_d[None, :] * N + offs_n[:, None]
+     elif LAYOUT == 3:  # dbn->bnd
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 4:  # bnd->dbn
+         out_ptrs = out_ptr + offs_d[None, :] * B * N + pid_b * N + offs_n[:, None]
+
+     tl.store(out_ptrs, y, mask=mask_n[:, None])
+
+
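The single-pass variant covers the whole feature axis with one tl.arange(0, TILE_D) tile and masks only the n axis, so it requires TILE_D == D (and, like all tl.arange sizes, a power of two). The wheel ships its own wrappers for these kernels (see cache_manager.py and tuning_decorator.py in this package); purely as an illustration, a hypothetical minimal PyTorch launcher for the bnd->bnd case might look like this. Everything except the kernel and Layout names is our own:

import torch
import triton

def layer_norm_bnd(x, w, b, eps=1e-5, tile_n=4):
    # Hypothetical launcher: x is (B, N, D) contiguous, normalized over D
    # (Layout.BND_BND). Assumes D is a power of two so one tile spans it.
    B, N, D = x.shape
    out = torch.empty_like(x)
    mean = torch.empty((B, N), device=x.device, dtype=torch.float32)
    rstd = torch.empty((B, N), device=x.device, dtype=torch.float32)
    grid = (triton.cdiv(N, tile_n), B)  # program_id(0)=n tile, (1)=batch
    layer_norm_transpose_forward_single_pass_kernel[grid](
        x, w, b, out, mean, rstd, B, N,
        D=D, EPS=eps, TILE_N=tile_n, TILE_D=D,
        ELEMENTWISE_AFFINE=True, LAYOUT=int(Layout.BND_BND),
        NEEDS_INT64=B * N * D >= 2**31,  # offsets may overflow int32
    )
    return out, mean, rstd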
+ @triton.jit
+ def layer_norm_transpose_forward_kernel(
+     # inputs:
+     x_ptr,
+     w_ptr,
+     b_ptr,
+     # outputs: (order matters for jax_triton)
+     out_ptr,
+     mean_ptr,
+     rstd_ptr,
+     B,
+     N,
+     D: tl.constexpr,
+     EPS: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_D: tl.constexpr,
+     ELEMENTWISE_AFFINE: tl.constexpr,
+     LAYOUT: tl.constexpr,
+     NEEDS_INT64: tl.constexpr = True,
+ ):
+     pid_n = tl.program_id(0)
+     pid_b = tl.program_id(1)
+
+     if NEEDS_INT64:
+         pid_n = tl.cast(pid_n, tl.int64)
+         pid_b = tl.cast(pid_b, tl.int64)
+         N = tl.cast(N, tl.int64)
+         B = tl.cast(B, tl.int64)
+
+     num_tiles_d = tl.cdiv(D, TILE_D)
+     D_CEIL = num_tiles_d * TILE_D
+
+     offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
+     offs_d = tl.arange(0, TILE_D)
+     mask_n = offs_n < N
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 1:  # bdn->bnd
+         x_ptrs = x_ptr + pid_b * D * N + offs_d[None, :] * N + offs_n[:, None]
+     elif LAYOUT == 2:  # bnd->bdn
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 3:  # dbn->bnd
+         x_ptrs = x_ptr + offs_d[None, :] * B * N + pid_b * N + offs_n[:, None]
+     elif LAYOUT == 4:  # bnd->dbn
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+
+     mean_ptrs = mean_ptr + pid_b * N + offs_n
+     rstd_ptrs = rstd_ptr + pid_b * N + offs_n
+
+     _mean = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
+     for di in range(0, num_tiles_d):
+         mask_d = offs_d < (D - di * TILE_D)
+         mask_nd = mask_n[:, None] & mask_d[None, :]
+
+         x = tl.load(x_ptrs, mask=mask_nd, other=0.0).to(tl.float32)
+         _mean += x
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs += TILE_D
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs += TILE_D * N
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs += TILE_D
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs += TILE_D * B * N
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs += TILE_D
+
+     mean = tl.sum(_mean, axis=1) / D
+     tl.store(mean_ptrs, mean, mask=mask_n)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_ptrs -= D_CEIL
+     elif LAYOUT == 1:  # bdn->bnd
+         x_ptrs -= D_CEIL * N
+     elif LAYOUT == 2:  # bnd->bdn
+         x_ptrs -= D_CEIL
+     elif LAYOUT == 3:  # dbn->bnd
+         x_ptrs -= D_CEIL * B * N
+     elif LAYOUT == 4:  # bnd->dbn
+         x_ptrs -= D_CEIL
+
+     _var = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
+     for di in range(0, num_tiles_d):
+         mask_d = offs_d < (D - di * TILE_D)
+         mask_nd = mask_n[:, None] & mask_d[None, :]
+
+         x = tl.load(x_ptrs, mask=mask_nd, other=mean[:, None]).to(tl.float32)
+         x = x - mean[:, None]
+         _var += x * x
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs += TILE_D
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs += TILE_D * N
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs += TILE_D
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs += TILE_D * B * N
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs += TILE_D
+
+     var = tl.sum(_var, axis=1) / D
+     rstd = tl.rsqrt(var + EPS)
+     tl.store(rstd_ptrs, rstd, mask=mask_n)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_ptrs -= D_CEIL
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 1:  # bdn->bnd
+         x_ptrs -= D_CEIL * N
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 2:  # bnd->bdn
+         x_ptrs -= D_CEIL
+         out_ptrs = out_ptr + pid_b * N * D + offs_d[None, :] * N + offs_n[:, None]
+     elif LAYOUT == 3:  # dbn->bnd
+         x_ptrs -= D_CEIL * B * N
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 4:  # bnd->dbn
+         x_ptrs -= D_CEIL
+         out_ptrs = out_ptr + offs_d[None, :] * B * N + pid_b * N + offs_n[:, None]
+
+     if ELEMENTWISE_AFFINE:
+         w_ptrs = w_ptr + offs_d
+         b_ptrs = b_ptr + offs_d
+
+     for di in range(0, num_tiles_d):
+         mask_d = offs_d < (D - di * TILE_D)
+         mask_nd = mask_n[:, None] & mask_d[None, :]
+
+         if ELEMENTWISE_AFFINE:
+             w = tl.load(w_ptrs, mask=mask_d, other=0.0).to(tl.float32)
+             b = tl.load(b_ptrs, mask=mask_d, other=0.0).to(tl.float32)
+         else:
+             w = 1.0
+             b = 0.0
+
+         x = tl.load(x_ptrs, mask=mask_nd, other=0.0).to(tl.float32)
+         x_hat = (x - mean[:, None]) * rstd[:, None]
+         y = x_hat * w[None, :] + b[None, :]
+         tl.store(out_ptrs, y, mask=mask_nd)
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs += TILE_D
+             out_ptrs += TILE_D
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs += TILE_D * N
+             out_ptrs += TILE_D
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs += TILE_D
+             out_ptrs += TILE_D * N
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs += TILE_D * B * N
+             out_ptrs += TILE_D
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs += TILE_D
+             out_ptrs += TILE_D * B * N
+
+         if ELEMENTWISE_AFFINE:
+             w_ptrs += TILE_D
+             b_ptrs += TILE_D
+
+
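Both forward kernels compute the same quantities, differing only in whether the feature axis fits in a single tile: the per-row mean, rstd = 1/sqrt(var + eps) with biased variance, and y = x_hat * w + b, with mean and rstd persisted for the backward pass. For testing against eager mode, a reference of what they store; the helper is ours, not shipped:

import torch

def layer_norm_reference(x, w=None, b=None, eps=1e-5):
    # Eager equivalent of the forward kernels for a (..., D) input:
    # statistics in float32, biased variance, rstd kept for backward.
    xf = x.float()
    mean = xf.mean(dim=-1)
    var = xf.var(dim=-1, unbiased=False)
    rstd = torch.rsqrt(var + eps)
    y = (xf - mean[..., None]) * rstd[..., None]
    if w is not None:
        y = y * w.float() + b.float()
    return y.to(x.dtype), mean, rstd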
+ @triton.jit
+ def layer_norm_transpose_backward_single_pass_kernel(
+     # inputs:
+     grad_out_ptr,
+     x_ptr,
+     w_ptr,
+     mean_ptr,
+     rstd_ptr,
+     # outputs: (order matters for jax_triton)
+     grad_x_ptr,
+     grad_w_ptr,
+     grad_b_ptr,
+     B,
+     N,
+     D: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_D: tl.constexpr,
+     ELEMENTWISE_AFFINE: tl.constexpr,
+     LAYOUT: tl.constexpr,
+     NEEDS_INT64: tl.constexpr = True,
+ ):
+     pid_n = tl.program_id(0)
+     pid_b = tl.program_id(1)
+
+     if NEEDS_INT64:
+         pid_n = tl.cast(pid_n, tl.int64)
+         pid_b = tl.cast(pid_b, tl.int64)
+         N = tl.cast(N, tl.int64)
+         B = tl.cast(B, tl.int64)
+
+     num_tiles_n = tl.cdiv(N, TILE_N)
+
+     offs_d = tl.arange(0, TILE_D)
+     offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
+     mask_n = offs_n < N
+
+     mean_ptrs = mean_ptr + pid_b * N + offs_n
+     rstd_ptrs = rstd_ptr + pid_b * N + offs_n
+     mean = tl.load(mean_ptrs, mask=mask_n, other=0.0).to(tl.float32)
+     rstd = tl.load(rstd_ptrs, mask=mask_n, other=0.0).to(tl.float32)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         x_ptrs = x_base_ptrs + offs_d[None, :]
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+     elif LAYOUT == 1:  # bdn->bnd
+         x_base_ptrs = x_ptr + pid_b * D * N + offs_n[:, None]
+         x_ptrs = x_base_ptrs + offs_d[None, :] * N
+         grad_x_base_ptrs = grad_x_ptr + pid_b * D * N + offs_n[:, None]
+         grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :] * N
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+     elif LAYOUT == 2:  # bnd->bdn
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         x_ptrs = x_base_ptrs + offs_d[None, :]
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None]
+         grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * N
+     elif LAYOUT == 3:  # dbn->bnd
+         x_base_ptrs = x_ptr + pid_b * N + offs_n[:, None]
+         x_ptrs = x_base_ptrs + offs_d[None, :] * B * N
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N + offs_n[:, None]
+         grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :] * B * N
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+     elif LAYOUT == 4:  # bnd->dbn
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         x_ptrs = x_base_ptrs + offs_d[None, :]
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N + offs_n[:, None]
+         grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * B * N
+
+     grad_w_base_ptrs = grad_w_ptr + pid_b * num_tiles_n * D + pid_n * D
+     grad_b_base_ptrs = grad_b_ptr + pid_b * num_tiles_n * D + pid_n * D
+
+     x = tl.load(x_ptrs, mask=mask_n[:, None], other=0.0).to(tl.float32)
+     grad_out = tl.load(grad_out_ptrs, mask=mask_n[:, None], other=0.0).to(tl.float32)
+
+     xhat = (x - mean[:, None]) * rstd[:, None]
+
+     if ELEMENTWISE_AFFINE:
+         grad_b = grad_out
+         grad_b = tl.sum(grad_b, axis=0)
+         grad_b_ptrs = grad_b_base_ptrs + offs_d
+         tl.store(grad_b_ptrs, grad_b)
+
+         grad_w = grad_out * xhat
+         grad_w = tl.sum(grad_w, axis=0)
+         grad_w_ptrs = grad_w_base_ptrs + offs_d
+         tl.store(grad_w_ptrs, grad_w)
+
+         w_ptrs = w_ptr + offs_d
+         w = tl.load(w_ptrs).to(tl.float32)
+         wdo = w * grad_out
+
+     else:
+         wdo = grad_out
+
+     c1 = xhat * wdo
+     c2 = wdo
+
+     c1_dot = tl.sum(c1, axis=1) / D
+     c2_dot = tl.sum(c2, axis=1) / D
+
+     dx = (wdo - (xhat * c1_dot[:, None] + c2_dot[:, None])) * rstd[:, None]
+     tl.store(grad_x_ptrs, dx, mask=mask_n[:, None])
+
+
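The dx expression closing both backward kernels is the standard layer-norm input gradient. In our notation, with g the incoming gradient (grad_out), w the affine weight (or 1 when ELEMENTWISE_AFFINE is false), and sums running over the D features of one row:

\hat{x}_j = (x_j - \mu)\, r, \qquad r = (\sigma^2 + \epsilon)^{-1/2}

c_1 = \frac{1}{D} \sum_j \hat{x}_j\, w_j\, g_j, \qquad c_2 = \frac{1}{D} \sum_j w_j\, g_j

\frac{\partial L}{\partial x_i} = \left( w_i\, g_i - \hat{x}_i\, c_1 - c_2 \right) r

These map directly onto the code: c_1 and c_2 are c1_dot and c2_dot, w_i g_i is wdo, and the last line is dx = (wdo - (xhat * c1_dot + c2_dot)) * rstd.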
+ @triton.jit
+ def layer_norm_transpose_backward_kernel(
+     # inputs:
+     grad_out_ptr,
+     x_ptr,
+     w_ptr,
+     mean_ptr,
+     rstd_ptr,
+     # outputs: (order matters for jax_triton)
+     grad_x_ptr,
+     grad_w_ptr,
+     grad_b_ptr,
+     B,
+     N,
+     D: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_D: tl.constexpr,
+     ELEMENTWISE_AFFINE: tl.constexpr,
+     LAYOUT: tl.constexpr,
+     NEEDS_INT64: tl.constexpr = True,
+ ):
+     pid_n = tl.program_id(0)
+     pid_b = tl.program_id(1)
+
+     if NEEDS_INT64:
+         pid_n = tl.cast(pid_n, tl.int64)
+         pid_b = tl.cast(pid_b, tl.int64)
+         N = tl.cast(N, tl.int64)
+         B = tl.cast(B, tl.int64)
+
+     num_tiles_d = tl.cdiv(D, TILE_D)
+     num_tiles_n = tl.cdiv(N, TILE_N)
+
+     offs_d = tl.arange(0, TILE_D)
+     offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
+     mask_n = offs_n < N
+
+     mean_ptrs = mean_ptr + pid_b * N + offs_n
+     rstd_ptrs = rstd_ptr + pid_b * N + offs_n
+     mean = tl.load(mean_ptrs, mask=mask_n, other=0.0).to(tl.float32)
+     rstd = tl.load(rstd_ptrs, mask=mask_n, other=0.0).to(tl.float32)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+     elif LAYOUT == 1:  # bdn->bnd
+         x_base_ptrs = x_ptr + pid_b * D * N + offs_n[:, None]
+         grad_x_base_ptrs = grad_x_ptr + pid_b * D * N + offs_n[:, None]
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+     elif LAYOUT == 2:  # bnd->bdn
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None]
+     elif LAYOUT == 3:  # dbn->bnd
+         x_base_ptrs = x_ptr + pid_b * N + offs_n[:, None]
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N + offs_n[:, None]
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+     elif LAYOUT == 4:  # bnd->dbn
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N + offs_n[:, None]
+
+     grad_w_base_ptrs = grad_w_ptr + pid_b * num_tiles_n * D + pid_n * D
+     grad_b_base_ptrs = grad_b_ptr + pid_b * num_tiles_n * D + pid_n * D
+
+     c1 = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
+     c2 = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
+
+     for di in range(num_tiles_d):
+         mask_d = offs_d < D
+         mask_nd = mask_n[:, None] & mask_d[None, :]
+
+         if ELEMENTWISE_AFFINE:
+             w_ptrs = w_ptr + offs_d
+             w = tl.load(w_ptrs, mask=mask_d, other=1.0).to(tl.float32)
+         else:
+             w = 1.0
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :] * N
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * N
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :] * B * N
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * B * N
+
+         x = tl.load(x_ptrs, mask=mask_nd, other=mean[:, None]).to(tl.float32)
+         grad_out = tl.load(grad_out_ptrs, mask=mask_nd, other=0.0).to(tl.float32)
+
+         xhat = (x - mean[:, None]) * rstd[:, None]
+         wdo = w * grad_out
+
+         c1 += xhat * wdo
+         c2 += wdo
+
+         offs_d += TILE_D
+
+     c1_dot = tl.sum(c1, axis=1) / D
+     c2_dot = tl.sum(c2, axis=1) / D
+
+     offs_d -= TILE_D * num_tiles_d
+
+     for di in range(num_tiles_d):
+         mask_d = offs_d < D
+         mask_nd = mask_n[:, None] & mask_d[None, :]
+
+         if ELEMENTWISE_AFFINE:
+             w_ptrs = w_ptr + offs_d
+             w = tl.load(w_ptrs, mask=mask_d, other=0.0).to(tl.float32)
+         else:
+             w = 1.0
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :] * N
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :] * N
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * N
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :] * B * N
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :] * B * N
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * B * N
+
+         x = tl.load(x_ptrs, mask=mask_nd, other=mean[:, None]).to(tl.float32)
+         grad_out = tl.load(grad_out_ptrs, mask=mask_nd, other=0.0).to(tl.float32)
+
+         xhat = (x - mean[:, None]) * rstd[:, None]
+
+         if ELEMENTWISE_AFFINE:
+             grad_b = grad_out
+             grad_b = tl.sum(grad_b, axis=0)
+             grad_b_ptrs = grad_b_base_ptrs + offs_d
+             tl.store(grad_b_ptrs, grad_b, mask=mask_d)
+
+             grad_w = grad_out * xhat
+             grad_w = tl.sum(grad_w, axis=0)
+             grad_w_ptrs = grad_w_base_ptrs + offs_d
+             tl.store(grad_w_ptrs, grad_w, mask=mask_d)
+
+         wdo = w * grad_out
+
+         dx = (wdo - (xhat * c1_dot[:, None] + c2_dot[:, None])) * rstd[:, None]
+         tl.store(grad_x_ptrs, dx, mask=mask_nd)
+
+         offs_d += TILE_D
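
Note that both backward kernels write grad_w and grad_b as partial sums, not final gradients: each (pid_b, pid_n) program reduces only its own TILE_N rows and stores a (D,)-slice at offset pid_b * num_tiles_n * D + pid_n * D, so the buffers hold B * cdiv(N, TILE_N) partial rows of width D, and the caller is presumably expected to finish the reduction. A sketch of that last step under that assumption (our code, float32 partial buffers):

import torch

def finalize_weight_grads(grad_w_partial, grad_b_partial):
    # Partial buffers laid out as (B * num_tiles_n, D); the parameter
    # gradients are the sum of the partials over all row tiles.
    grad_w = grad_w_partial.view(-1, grad_w_partial.shape[-1]).sum(dim=0)
    grad_b = grad_b_partial.view(-1, grad_b_partial.shape[-1]).sum(dim=0)
    return grad_w, grad_b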