cuequivariance-ops-cu12 0.6.0__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cuequivariance-ops-cu12 might be problematic.

Files changed (37)
  1. cuequivariance_ops/VERSION +1 -0
  2. cuequivariance_ops/__init__.py +42 -0
  3. cuequivariance_ops/_version.py +20 -0
  4. cuequivariance_ops/common/common.hpp +98 -0
  5. cuequivariance_ops/common/nvtx.hpp +29 -0
  6. cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
  7. cuequivariance_ops/equivariance/dtypes.hh +65 -0
  8. cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
  9. cuequivariance_ops/equivariance/indexed_linear.hh +36 -0
  10. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  11. cuequivariance_ops/equivariance/run_fmha_cudafree.h +77 -0
  12. cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
  13. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
  14. cuequivariance_ops/lib/libcue_ops.so +0 -0
  15. cuequivariance_ops/sleep.hh +18 -0
  16. cuequivariance_ops/triton/__init__.py +66 -0
  17. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37192 -0
  18. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  19. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  20. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  21. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  22. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  23. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  24. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  25. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  26. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  27. cuequivariance_ops/triton/cache_manager.py +259 -0
  28. cuequivariance_ops/triton/fused_layer_norm_triton.py +518 -0
  29. cuequivariance_ops/triton/gated_gemm_triton.py +380 -0
  30. cuequivariance_ops/triton/pair_bias.py +324 -0
  31. cuequivariance_ops/triton/tuning_decorator.py +177 -0
  32. cuequivariance_ops/triton/utils.py +28 -0
  33. cuequivariance_ops_cu12-0.6.0.dist-info/METADATA +182 -0
  34. cuequivariance_ops_cu12-0.6.0.dist-info/RECORD +37 -0
  35. cuequivariance_ops_cu12-0.6.0.dist-info/WHEEL +6 -0
  36. cuequivariance_ops_cu12-0.6.0.dist-info/licenses/LICENSE +142 -0
  37. cuequivariance_ops_cu12-0.6.0.dist-info/licenses/Third_party_attr.txt +24 -0
cuequivariance_ops/triton/fused_layer_norm_triton.py
@@ -0,0 +1,518 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ import enum
+
+ import triton
+ import triton.language as tl
+
+
+ class Layout(enum.IntEnum):
+     BND_BND = 0
+     BDN_BND = 1
+     BND_BDN = 2
+     DBN_BND = 3
+     BND_DBN = 4
+
+
+ @triton.jit
+ def layer_norm_transpose_forward_single_pass_kernel(
+     # inputs:
+     x_ptr,
+     w_ptr,
+     b_ptr,
+     # outputs: (order matters for jax_triton)
+     out_ptr,
+     mean_ptr,
+     rstd_ptr,
+     B,
+     N,
+     D: tl.constexpr,
+     EPS: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_D: tl.constexpr,
+     ELEMENTWISE_AFFINE: tl.constexpr,
+     LAYOUT: tl.constexpr,
+ ):
+     pid_n = tl.program_id(0)
+     pid_b = tl.program_id(1)
+
+     offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
+     offs_d = tl.arange(0, TILE_D)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 1:  # bdn->bnd
+         x_ptrs = x_ptr + pid_b * D * N + offs_d[None, :] * N + offs_n[:, None]
+     elif LAYOUT == 2:  # bnd->bdn
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 3:  # dbn->bnd
+         x_ptrs = x_ptr + offs_d[None, :] * B * N + pid_b * N + offs_n[:, None]
+     elif LAYOUT == 4:  # bnd->dbn
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+
+     mean_ptrs = mean_ptr + pid_b * N + offs_n
+     rstd_ptrs = rstd_ptr + pid_b * N + offs_n
+     mask_n = offs_n < N
+
+     x = tl.load(x_ptrs, mask=mask_n[:, None], other=0.0).to(tl.float32)
+     mean = tl.sum(x, axis=1) / D
+     x_centered = x - mean[:, None]
+     var = tl.sum(x_centered * x_centered, axis=1) / D
+     rstd = tl.rsqrt(var + EPS)
+
+     tl.store(mean_ptrs, mean, mask=mask_n)
+     tl.store(rstd_ptrs, rstd, mask=mask_n)
+
+     x_hat = x_centered * rstd[:, None]
+
+     if ELEMENTWISE_AFFINE:
+         w_ptrs = w_ptr + offs_d
+         b_ptrs = b_ptr + offs_d
+         w = tl.load(w_ptrs).to(tl.float32)
+         b = tl.load(b_ptrs).to(tl.float32)
+         y = x_hat * w[None, :] + b[None, :]
+     else:
+         y = x_hat
+
+     if LAYOUT == 0:  # bnd->bnd
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 1:  # bdn->bnd
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 2:  # bnd->bdn
+         out_ptrs = out_ptr + pid_b * N * D + offs_d[None, :] * N + offs_n[:, None]
+     elif LAYOUT == 3:  # dbn->bnd
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 4:  # bnd->dbn
+         out_ptrs = out_ptr + offs_d[None, :] * B * N + pid_b * N + offs_n[:, None]
+
+     tl.store(out_ptrs, y, mask=mask_n[:, None])
+
+
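The Layout constants name the input->output memory layouts this single-pass kernel fuses with the normalization: b is the batch axis, n indexes the rows being normalized, and d is the feature axis the statistics are taken over, so BDN_BND reads a (B, D, N) tensor and writes the normalized result as (B, N, D). A rough PyTorch reference for the same computation (a sketch for orientation only; the function name and the b/n/d axis reading are inferred from the kernel's indexing, not taken from the package):

import torch
import torch.nn.functional as F

def layer_norm_transpose_reference(x, w, b, layout, eps=1e-5):
    # Bring d to the last axis, normalize over it, then emit the output layout.
    if layout == 1:            # BDN_BND: (B, D, N) -> (B, N, D)
        x = x.transpose(1, 2)
    elif layout == 3:          # DBN_BND: (D, B, N) -> (B, N, D)
        x = x.permute(1, 2, 0)
    y = F.layer_norm(x, x.shape[-1:], w, b, eps)  # mean/rstd over d, as in the kernel
    if layout == 2:            # BND_BDN: (B, N, D) -> (B, D, N)
        y = y.transpose(1, 2)
    elif layout == 4:          # BND_DBN: (B, N, D) -> (D, B, N)
        y = y.permute(2, 0, 1)
    return y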
+ @triton.jit
+ def layer_norm_transpose_forward_kernel(
+     # inputs:
+     x_ptr,
+     w_ptr,
+     b_ptr,
+     # outputs: (order matters for jax_triton)
+     out_ptr,
+     mean_ptr,
+     rstd_ptr,
+     B,
+     N,
+     D: tl.constexpr,
+     EPS: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_D: tl.constexpr,
+     ELEMENTWISE_AFFINE: tl.constexpr,
+     LAYOUT: tl.constexpr,
+ ):
+     pid_n = tl.program_id(0)
+     pid_b = tl.program_id(1)
+
+     num_tiles_d = tl.cdiv(D, TILE_D)
+     D_CEIL = num_tiles_d * TILE_D
+
+     offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
+     offs_d = tl.arange(0, TILE_D)
+     mask_n = offs_n < N
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 1:  # bdn->bnd
+         x_ptrs = x_ptr + pid_b * D * N + offs_d[None, :] * N + offs_n[:, None]
+     elif LAYOUT == 2:  # bnd->bdn
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 3:  # dbn->bnd
+         x_ptrs = x_ptr + offs_d[None, :] * B * N + pid_b * N + offs_n[:, None]
+     elif LAYOUT == 4:  # bnd->dbn
+         x_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+
+     mean_ptrs = mean_ptr + pid_b * N + offs_n
+     rstd_ptrs = rstd_ptr + pid_b * N + offs_n
+
+     _mean = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
+     for di in range(0, num_tiles_d):
+         mask_d = offs_d < (D - di * TILE_D)
+         mask_nd = mask_n[:, None] & mask_d[None, :]
+
+         x = tl.load(x_ptrs, mask=mask_nd, other=0.0).to(tl.float32)
+         _mean += x
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs += TILE_D
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs += TILE_D * N
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs += TILE_D
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs += TILE_D * B * N
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs += TILE_D
+
+     mean = tl.sum(_mean, axis=1) / D
+     tl.store(mean_ptrs, mean, mask=mask_n)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_ptrs -= D_CEIL
+     elif LAYOUT == 1:  # bdn->bnd
+         x_ptrs -= D_CEIL * N
+     elif LAYOUT == 2:  # bnd->bdn
+         x_ptrs -= D_CEIL
+     elif LAYOUT == 3:  # dbn->bnd
+         x_ptrs -= D_CEIL * B * N
+     elif LAYOUT == 4:  # bnd->dbn
+         x_ptrs -= D_CEIL
+
+     _var = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
+     for di in range(0, num_tiles_d):
+         mask_d = offs_d < (D - di * TILE_D)
+         mask_nd = mask_n[:, None] & mask_d[None, :]
+
+         x = tl.load(x_ptrs, mask=mask_nd, other=mean[:, None]).to(tl.float32)
+         x = x - mean[:, None]
+         _var += x * x
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs += TILE_D
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs += TILE_D * N
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs += TILE_D
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs += TILE_D * B * N
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs += TILE_D
+
+     var = tl.sum(_var, axis=1) / D
+     rstd = tl.rsqrt(var + EPS)
+     tl.store(rstd_ptrs, rstd, mask=mask_n)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_ptrs -= D_CEIL
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 1:  # bdn->bnd
+         x_ptrs -= D_CEIL * N
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 2:  # bnd->bdn
+         x_ptrs -= D_CEIL
+         out_ptrs = out_ptr + pid_b * N * D + offs_d[None, :] * N + offs_n[:, None]
+     elif LAYOUT == 3:  # dbn->bnd
+         x_ptrs -= D_CEIL * B * N
+         out_ptrs = out_ptr + pid_b * N * D + offs_n[:, None] * D + offs_d[None, :]
+     elif LAYOUT == 4:  # bnd->dbn
+         x_ptrs -= D_CEIL
+         out_ptrs = out_ptr + offs_d[None, :] * B * N + pid_b * N + offs_n[:, None]
+
+     if ELEMENTWISE_AFFINE:
+         w_ptrs = w_ptr + offs_d
+         b_ptrs = b_ptr + offs_d
+
+     for di in range(0, num_tiles_d):
+         mask_d = offs_d < (D - di * TILE_D)
+         mask_nd = mask_n[:, None] & mask_d[None, :]
+
+         if ELEMENTWISE_AFFINE:
+             w = tl.load(w_ptrs, mask=mask_d, other=0.0).to(tl.float32)
+             b = tl.load(b_ptrs, mask=mask_d, other=0.0).to(tl.float32)
+         else:
+             w = 1.0
+             b = 0.0
+
+         x = tl.load(x_ptrs, mask=mask_nd, other=0.0).to(tl.float32)
+         x_hat = (x - mean[:, None]) * rstd[:, None]
+         y = x_hat * w[None, :] + b[None, :]
+         tl.store(out_ptrs, y, mask=mask_nd)
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs += TILE_D
+             out_ptrs += TILE_D
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs += TILE_D * N
+             out_ptrs += TILE_D
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs += TILE_D
+             out_ptrs += TILE_D * N
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs += TILE_D * B * N
+             out_ptrs += TILE_D
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs += TILE_D
+             out_ptrs += TILE_D * B * N
+
+         if ELEMENTWISE_AFFINE:
+             w_ptrs += TILE_D
+             b_ptrs += TILE_D
+
+
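Unlike the single-pass variant, this kernel streams over d in TILE_D chunks and makes three passes: one accumulating the mean, one the variance, and one normalizing and storing, rewinding the x pointers by D_CEIL = num_tiles_d * TILE_D (times each layout's d-stride) between passes to undo the increments. The running sums stay exact because masked lanes load other=0.0 in the mean pass and other=mean (so x - mean is 0 there) in the variance pass. A small NumPy check of that tiling arithmetic (illustrative only; the variable names are mine):

import numpy as np

D, TILE_D = 100, 32
num_tiles_d = -(-D // TILE_D)        # ceiling division, i.e. tl.cdiv(D, TILE_D) -> 4
D_CEIL = num_tiles_d * TILE_D        # 128; how far the pointers advance per pass

x = np.random.randn(D)

# Mean pass: out-of-range lanes load `other=0.0`, contributing nothing to the sum.
x_pad = np.concatenate([x, np.zeros(D_CEIL - D)])
mean = x_pad.reshape(num_tiles_d, TILE_D).sum() / D

# Variance pass: out-of-range lanes load `other=mean`, so (x - mean) is 0 there.
x_pad = np.concatenate([x, np.full(D_CEIL - D, mean)])
var = ((x_pad.reshape(num_tiles_d, TILE_D) - mean) ** 2).sum() / D

assert np.isclose(mean, x.mean()) and np.isclose(var, x.var())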
+ @triton.jit
+ def layer_norm_transpose_backward_single_pass_kernel(
+     # inputs:
+     grad_out_ptr,
+     x_ptr,
+     w_ptr,
+     mean_ptr,
+     rstd_ptr,
+     # outputs: (order matters for jax_triton)
+     grad_x_ptr,
+     grad_w_ptr,
+     grad_b_ptr,
+     B,
+     N,
+     D: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_D: tl.constexpr,
+     ELEMENTWISE_AFFINE: tl.constexpr,
+     LAYOUT: tl.constexpr,
+ ):
+     pid_n = tl.program_id(0)
+     pid_b = tl.program_id(1)
+
+     num_tiles_n = tl.cdiv(N, TILE_N)
+
+     offs_d = tl.arange(0, TILE_D)
+     offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
+     mask_n = offs_n < N
+
+     mean_ptrs = mean_ptr + pid_b * N + offs_n
+     rstd_ptrs = rstd_ptr + pid_b * N + offs_n
+     mean = tl.load(mean_ptrs, mask=mask_n, other=0.0).to(tl.float32)
+     rstd = tl.load(rstd_ptrs, mask=mask_n, other=0.0).to(tl.float32)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         x_ptrs = x_base_ptrs + offs_d[None, :]
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+     elif LAYOUT == 1:  # bdn->bnd
+         x_base_ptrs = x_ptr + pid_b * D * N + offs_n[:, None]
+         x_ptrs = x_base_ptrs + offs_d[None, :] * N
+         grad_x_base_ptrs = grad_x_ptr + pid_b * D * N + offs_n[:, None]
+         grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :] * N
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+     elif LAYOUT == 2:  # bnd->bdn
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         x_ptrs = x_base_ptrs + offs_d[None, :]
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None]
+         grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * N
+     elif LAYOUT == 3:  # dbn->bnd
+         x_base_ptrs = x_ptr + pid_b * N + offs_n[:, None]
+         x_ptrs = x_base_ptrs + offs_d[None, :] * B * N
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N + offs_n[:, None]
+         grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :] * B * N
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+     elif LAYOUT == 4:  # bnd->dbn
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         x_ptrs = x_base_ptrs + offs_d[None, :]
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N + offs_n[:, None]
+         grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * B * N
+
+     grad_w_base_ptrs = grad_w_ptr + pid_b * num_tiles_n * D + pid_n * D
+     grad_b_base_ptrs = grad_b_ptr + pid_b * num_tiles_n * D + pid_n * D
+
+     x = tl.load(x_ptrs, mask=mask_n[:, None], other=0.0).to(tl.float32)
+     grad_out = tl.load(grad_out_ptrs, mask=mask_n[:, None], other=0.0).to(tl.float32)
+
+     xhat = (x - mean[:, None]) * rstd[:, None]
+
+     if ELEMENTWISE_AFFINE:
+         grad_b = grad_out
+         grad_b = tl.sum(grad_b, axis=0)
+         grad_b_ptrs = grad_b_base_ptrs + offs_d
+         tl.store(grad_b_ptrs, grad_b)
+
+         grad_w = grad_out * xhat
+         grad_w = tl.sum(grad_w, axis=0)
+         grad_w_ptrs = grad_w_base_ptrs + offs_d
+         tl.store(grad_w_ptrs, grad_w)
+
+         w_ptrs = w_ptr + offs_d
+         w = tl.load(w_ptrs).to(tl.float32)
+         wdo = w * grad_out
+
+     else:
+         wdo = grad_out
+
+     c1 = xhat * wdo
+     c2 = wdo
+
+     c1_dot = tl.sum(c1, axis=1) / D
+     c2_dot = tl.sum(c2, axis=1) / D
+
+     dx = (wdo - (xhat * c1_dot[:, None] + c2_dot[:, None])) * rstd[:, None]
+     tl.store(grad_x_ptrs, dx, mask=mask_n[:, None])
+
+
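Note that grad_w_ptr and grad_b_ptr do not receive the final weight gradients: each program stores its tile's column sums at offset pid_b * num_tiles_n * D + pid_n * D, so the buffers hold one partial length-D row per (batch, n-tile) program and still need a reduction over those rows. A hedged sketch of that finishing step (the host-side wrapper is not in this hunk; the buffer and variable names here are mine):

import torch

B, N, D, TILE_N = 4, 1000, 128, 64
num_tiles_n = (N + TILE_N - 1) // TILE_N          # one partial row per n-tile

# Partial buffers as laid out by the kernel: row (b * num_tiles_n + n_tile)
# holds that program's column sums over its TILE_N rows.
grad_w_partial = torch.zeros(B * num_tiles_n, D)
grad_b_partial = torch.zeros(B * num_tiles_n, D)
# ... kernel launch fills the partials ...
grad_w = grad_w_partial.sum(dim=0)                # final (D,) gradient of the scale
grad_b = grad_b_partial.sum(dim=0)                # final (D,) gradient of the bias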
+ @triton.jit
+ def layer_norm_transpose_backward_kernel(
+     # inputs:
+     grad_out_ptr,
+     x_ptr,
+     w_ptr,
+     mean_ptr,
+     rstd_ptr,
+     # outputs: (order matters for jax_triton)
+     grad_x_ptr,
+     grad_w_ptr,
+     grad_b_ptr,
+     B,
+     N,
+     D: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_D: tl.constexpr,
+     ELEMENTWISE_AFFINE: tl.constexpr,
+     LAYOUT: tl.constexpr,
+ ):
+     pid_n = tl.program_id(0)
+     pid_b = tl.program_id(1)
+
+     num_tiles_d = tl.cdiv(D, TILE_D)
+     num_tiles_n = tl.cdiv(N, TILE_N)
+
+     offs_d = tl.arange(0, TILE_D)
+     offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
+     mask_n = offs_n < N
+
+     mean_ptrs = mean_ptr + pid_b * N + offs_n
+     rstd_ptrs = rstd_ptr + pid_b * N + offs_n
+     mean = tl.load(mean_ptrs, mask=mask_n, other=0.0).to(tl.float32)
+     rstd = tl.load(rstd_ptrs, mask=mask_n, other=0.0).to(tl.float32)
+
+     if LAYOUT == 0:  # bnd->bnd
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+     elif LAYOUT == 1:  # bdn->bnd
+         x_base_ptrs = x_ptr + pid_b * D * N + offs_n[:, None]
+         grad_x_base_ptrs = grad_x_ptr + pid_b * D * N + offs_n[:, None]
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+     elif LAYOUT == 2:  # bnd->bdn
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None]
+     elif LAYOUT == 3:  # dbn->bnd
+         x_base_ptrs = x_ptr + pid_b * N + offs_n[:, None]
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N + offs_n[:, None]
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N * D + offs_n[:, None] * D
+     elif LAYOUT == 4:  # bnd->dbn
+         x_base_ptrs = x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_x_base_ptrs = grad_x_ptr + pid_b * N * D + offs_n[:, None] * D
+         grad_out_base_ptrs = grad_out_ptr + pid_b * N + offs_n[:, None]
+
+     grad_w_base_ptrs = grad_w_ptr + pid_b * num_tiles_n * D + pid_n * D
+     grad_b_base_ptrs = grad_b_ptr + pid_b * num_tiles_n * D + pid_n * D
+
+     c1 = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
+     c2 = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
+
+     for di in range(num_tiles_d):
+         mask_d = offs_d < D
+         mask_nd = mask_n[:, None] & mask_d[None, :]
+
+         if ELEMENTWISE_AFFINE:
+             w_ptrs = w_ptr + offs_d
+             w = tl.load(w_ptrs, mask=mask_d, other=1.0).to(tl.float32)
+         else:
+             w = 1.0
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :] * N
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * N
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :] * B * N
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * B * N
+
+         x = tl.load(x_ptrs, mask=mask_nd, other=mean[:, None]).to(tl.float32)
+         grad_out = tl.load(grad_out_ptrs, mask=mask_nd, other=0.0).to(tl.float32)
+
+         xhat = (x - mean[:, None]) * rstd[:, None]
+         wdo = w * grad_out
+
+         c1 += xhat * wdo
+         c2 += wdo
+
+         offs_d += TILE_D
+
+     c1_dot = tl.sum(c1, axis=1) / D
+     c2_dot = tl.sum(c2, axis=1) / D
+
+     offs_d -= TILE_D * num_tiles_d
+
+     for di in range(num_tiles_d):
+         mask_d = offs_d < D
+         mask_nd = mask_n[:, None] & mask_d[None, :]
+
+         if ELEMENTWISE_AFFINE:
+             w_ptrs = w_ptr + offs_d
+             w = tl.load(w_ptrs, mask=mask_d, other=0.0).to(tl.float32)
+         else:
+             w = 1.0
+
+         if LAYOUT == 0:  # bnd->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 1:  # bdn->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :] * N
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :] * N
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 2:  # bnd->bdn
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * N
+         elif LAYOUT == 3:  # dbn->bnd
+             x_ptrs = x_base_ptrs + offs_d[None, :] * B * N
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :] * B * N
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :]
+         elif LAYOUT == 4:  # bnd->dbn
+             x_ptrs = x_base_ptrs + offs_d[None, :]
+             grad_x_ptrs = grad_x_base_ptrs + offs_d[None, :]
+             grad_out_ptrs = grad_out_base_ptrs + offs_d[None, :] * B * N
+
+         x = tl.load(x_ptrs, mask=mask_nd, other=mean[:, None]).to(tl.float32)
+         grad_out = tl.load(grad_out_ptrs, mask=mask_nd, other=0.0).to(tl.float32)
+
+         xhat = (x - mean[:, None]) * rstd[:, None]
+
+         if ELEMENTWISE_AFFINE:
+             grad_b = grad_out
+             grad_b = tl.sum(grad_b, axis=0)
+             grad_b_ptrs = grad_b_base_ptrs + offs_d
+             tl.store(grad_b_ptrs, grad_b, mask=mask_d)
+
+             grad_w = grad_out * xhat
+             grad_w = tl.sum(grad_w, axis=0)
+             grad_w_ptrs = grad_w_base_ptrs + offs_d
+             tl.store(grad_w_ptrs, grad_w, mask=mask_d)
+
+         wdo = w * grad_out
+
+         dx = (wdo - (xhat * c1_dot[:, None] + c2_dot[:, None])) * rstd[:, None]
+         tl.store(grad_x_ptrs, dx, mask=mask_nd)
+
+         offs_d += TILE_D
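The host-side wrappers that pick tile sizes and launch these kernels live elsewhere in fused_layer_norm_triton.py and are not shown in this hunk. For orientation, a minimal launch sketch for the single-pass forward kernel, assuming PyTorch tensors, Layout.BDN_BND, and a D that is a power of two (the single-pass kernel applies no mask along d, so TILE_D must equal D exactly); the wrapper name and TILE_N choice here are mine, not the package's:

import torch
import triton

def layer_norm_bdn_to_bnd(x, w, b, eps=1e-5):
    # Sketch only, not the package's wrapper: normalize a (B, D, N) tensor
    # over d and write the result as (B, N, D).
    B, D, N = x.shape
    assert D & (D - 1) == 0, "single-pass kernel: TILE_D == D must be a power of 2"
    out = torch.empty(B, N, D, device=x.device, dtype=x.dtype)
    mean = torch.empty(B, N, device=x.device, dtype=torch.float32)
    rstd = torch.empty(B, N, device=x.device, dtype=torch.float32)
    TILE_N = 32
    grid = (triton.cdiv(N, TILE_N), B)  # axis 0 tiles n (pid_n), axis 1 is batch (pid_b)
    layer_norm_transpose_forward_single_pass_kernel[grid](
        x, w, b, out, mean, rstd, B, N,
        D=D, EPS=eps, TILE_N=TILE_N, TILE_D=D,
        ELEMENTWISE_AFFINE=True, LAYOUT=int(Layout.BDN_BND),
    )
    return out, mean, rstd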