cuequivariance-ops-cu12 0.4.0__py3-none-manylinux_2_39_aarch64.whl → 0.5.0__py3-none-manylinux_2_39_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cuequivariance-ops-cu12 might be problematic. Click here for more details.

@@ -0,0 +1,324 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ # property and proprietary rights in and to this material, related
6
+ # documentation and any modifications thereto. Any use, reproduction,
7
+ # disclosure or distribution of this material and related documentation
8
+ # without an express license agreement from NVIDIA CORPORATION or
9
+ # its affiliates is strictly prohibited.
10
+
11
+ import enum
12
+
13
+ import triton
14
+ import triton.language as tl
15
+
16
+
17
@enum.unique
class Layout(enum.IntEnum):
    """Input->output memory layout pairs understood by the layer-norm kernels.

    Each member is named ``<input layout>_<output layout>``; the letters give
    the axis order (``B`` batch, ``N`` rows, ``D`` normalised feature axis).
    The kernels in this file dispatch on the raw integer value
    (``LAYOUT == 0`` ... ``LAYOUT == 4``), so the numbers are part of the
    contract. ``enum.unique`` guards against two members accidentally sharing
    a value, which would silently select the wrong addressing branch.
    """

    BND_BND = 0  # (B, N, D) in -> (B, N, D) out
    BDN_BND = 1  # (B, D, N) in -> (B, N, D) out
    BND_BDN = 2  # (B, N, D) in -> (B, D, N) out
    DBN_BND = 3  # (D, B, N) in -> (B, N, D) out
    BND_DBN = 4  # (B, N, D) in -> (D, B, N) out
23
+
24
+
25
@triton.jit
def layer_norm_transpose_forward_kernel(
    x_ptr,
    out_ptr,
    w_ptr,
    b_ptr,
    mean_ptr,
    rstd_ptr,
    B,
    N,
    D: tl.constexpr,
    EPS: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_D: tl.constexpr,
    ELEMENTWISE_AFFINE: tl.constexpr,
    LAYOUT: tl.constexpr,
):
    """Layer-normalise ``x`` over its ``D`` axis, optionally permuting layout.

    Program axis 0 selects a tile of ``TILE_N`` consecutive ``n`` indices and
    program axis 1 selects the batch index ``b``. Each program makes three
    passes over the feature axis in chunks of ``TILE_D``:

    1. accumulate the mean and store it at ``mean_ptr[b * N + n]``;
    2. accumulate the biased variance (divided by ``D``) and store
       ``1 / sqrt(var + EPS)`` at ``rstd_ptr[b * N + n]``;
    3. write ``(x - mean) * rstd * w + b`` to ``out_ptr``.

    All arithmetic is done in float32. Addressing assumes densely packed
    tensors whose axis order is given by ``LAYOUT``.

    Args:
        x_ptr: Input tensor, read in the layout's input order.
        out_ptr: Output tensor, written in the layout's output order.
        w_ptr: Per-feature scale; only read when ``ELEMENTWISE_AFFINE``.
        b_ptr: Per-feature shift; only read when ``ELEMENTWISE_AFFINE``.
        mean_ptr: Receives the per-row mean, indexed ``b * N + n``.
        rstd_ptr: Receives the per-row reciprocal std, indexed ``b * N + n``.
        B: Batch size (used as a stride for the ``dbn`` layouts).
        N: Number of rows per batch element; rows ``>= N`` are masked out.
        D: Feature size (compile-time); must be a multiple of ``TILE_D``.
        EPS: Value added to the variance before the square root.
        TILE_N: Rows handled per program (compile-time).
        TILE_D: Features handled per loop iteration (compile-time).
        ELEMENTWISE_AFFINE: If false, scale 1 and shift 0 are used.
        LAYOUT: 0 bnd->bnd, 1 bdn->bnd, 2 bnd->bdn, 3 dbn->bnd, 4 bnd->dbn.
    """
    # The loops below visit exactly (D // TILE_D) * TILE_D features while the
    # statistics divide by D; a ragged last tile would silently be dropped.
    tl.static_assert(D % TILE_D == 0, "D must be a multiple of TILE_D")

    pid_n = tl.program_id(0)
    pid_b = tl.program_id(1)

    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_d = tl.arange(0, TILE_D)
    mask_n = offs_n < N
    mask_rows = mask_n[:, None]

    # NOTE(review): offsets such as pid_b * N * D use the integer width of the
    # scalar arguments as passed; if that is 32-bit, tensors with more than
    # 2**31 elements would overflow -- confirm the caller's size limits.

    # Resolve the layout once: offset of feature 0 for every row in the tile,
    # plus the element stride between consecutive features, for x and out.
    if LAYOUT == 0:  # bnd->bnd
        x_row_offs = pid_b * N * D + offs_n[:, None] * D
        x_stride_d = 1
        out_row_offs = pid_b * N * D + offs_n[:, None] * D
        out_stride_d = 1
    elif LAYOUT == 1:  # bdn->bnd
        x_row_offs = pid_b * D * N + offs_n[:, None]
        x_stride_d = N
        out_row_offs = pid_b * N * D + offs_n[:, None] * D
        out_stride_d = 1
    elif LAYOUT == 2:  # bnd->bdn
        x_row_offs = pid_b * N * D + offs_n[:, None] * D
        x_stride_d = 1
        out_row_offs = pid_b * N * D + offs_n[:, None]
        out_stride_d = N
    elif LAYOUT == 3:  # dbn->bnd
        x_row_offs = pid_b * N + offs_n[:, None]
        x_stride_d = B * N
        out_row_offs = pid_b * N * D + offs_n[:, None] * D
        out_stride_d = 1
    elif LAYOUT == 4:  # bnd->dbn
        x_row_offs = pid_b * N * D + offs_n[:, None] * D
        x_stride_d = 1
        out_row_offs = pid_b * N + offs_n[:, None]
        out_stride_d = B * N
    else:
        tl.static_assert(False, "unsupported LAYOUT")

    x_rows = x_ptr + x_row_offs
    out_rows = out_ptr + out_row_offs

    num_tiles = D // TILE_D

    # Pass 1: mean over the feature axis.
    sum_acc = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
    for t in range(0, num_tiles):
        cur_d = t * TILE_D + offs_d
        x = tl.load(
            x_rows + cur_d[None, :] * x_stride_d, mask=mask_rows, other=0.0
        ).to(tl.float32)
        sum_acc += x

    mean = tl.sum(sum_acc, axis=1) / D
    tl.store(mean_ptr + pid_b * N + offs_n, mean, mask=mask_n)

    # Pass 2: biased variance around the mean, then reciprocal std.
    sq_acc = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
    for t in range(0, num_tiles):
        cur_d = t * TILE_D + offs_d
        x = tl.load(
            x_rows + cur_d[None, :] * x_stride_d, mask=mask_rows, other=0.0
        ).to(tl.float32)
        x = x - mean[:, None]
        sq_acc += x * x

    var = tl.sum(sq_acc, axis=1) / D
    rstd = 1.0 / tl.sqrt(var + EPS)
    tl.store(rstd_ptr + pid_b * N + offs_n, rstd, mask=mask_n)

    # Pass 3: normalise, apply the optional affine transform, and store in the
    # output layout.
    for t in range(0, num_tiles):
        cur_d = t * TILE_D + offs_d

        if ELEMENTWISE_AFFINE:
            w = tl.load(w_ptr + cur_d)
            b = tl.load(b_ptr + cur_d)
        else:
            w = 1.0
            b = 0.0

        x = tl.load(
            x_rows + cur_d[None, :] * x_stride_d, mask=mask_rows, other=0.0
        ).to(tl.float32)
        x_hat = (x - mean[:, None]) * rstd[:, None]
        y = x_hat * w[None, :] + b[None, :]
        tl.store(out_rows + cur_d[None, :] * out_stride_d, y, mask=mask_rows)
168
+
169
+
170
@triton.jit
def layer_norm_transpose_backward_kernel(
    grad_out_ptr,
    grad_x_ptr,
    grad_w_ptr,
    grad_b_ptr,
    x_ptr,
    w_ptr,
    mean_ptr,
    rstd_ptr,
    B,
    N,
    D: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_D: tl.constexpr,
    ELEMENTWISE_AFFINE: tl.constexpr,
    LAYOUT: tl.constexpr,
):
    """Backward pass of the layout-permuting layer norm.

    Program axis 0 selects a tile of ``TILE_N`` consecutive ``n`` indices and
    program axis 1 selects the batch index ``b``. The kernel reads the saved
    per-row ``mean`` and ``rstd`` (indexed ``b * N + n``) and makes two passes
    over the feature axis in chunks of ``TILE_D``:

    1. accumulate, per row, the feature-axis means of ``xhat * w * grad_out``
       and ``w * grad_out``; when ``ELEMENTWISE_AFFINE`` it also stores this
       program's partial sums over its rows for ``grad_w`` and ``grad_b``;
    2. store ``grad_x = (w * grad_out - (xhat * c1 + c2)) * rstd``.

    ``x`` and ``grad_x`` are addressed in the layout's input order, and
    ``grad_out`` in its output order. All arithmetic is done in float32 and
    addressing assumes densely packed tensors.

    Args:
        grad_out_ptr: Gradient w.r.t. the forward output.
        grad_x_ptr: Receives the gradient w.r.t. ``x``.
        grad_w_ptr: Receives partial scale gradients at
            ``(b * cdiv(N, TILE_N) + pid_n) * D + d``; only written when
            ``ELEMENTWISE_AFFINE``. Presumably reduced by the caller -- confirm.
        grad_b_ptr: Same indexing as ``grad_w_ptr``, for the shift gradient.
        x_ptr: Forward input.
        w_ptr: Per-feature scale; only read when ``ELEMENTWISE_AFFINE``.
        mean_ptr: Per-row mean saved by the forward pass.
        rstd_ptr: Per-row reciprocal std saved by the forward pass.
        B: Batch size (used as a stride for the ``dbn`` layouts).
        N: Number of rows per batch element; rows ``>= N`` are masked out.
        D: Feature size (compile-time); must be a multiple of ``TILE_D``.
        TILE_N: Rows handled per program (compile-time).
        TILE_D: Features handled per loop iteration (compile-time).
        ELEMENTWISE_AFFINE: If false, a scale of 1 is used and no parameter
            gradients are written.
        LAYOUT: 0 bnd->bnd, 1 bdn->bnd, 2 bnd->bdn, 3 dbn->bnd, 4 bnd->dbn.
    """
    # The loops below visit exactly (D // TILE_D) * TILE_D features while the
    # reductions divide by D; a ragged last tile would silently be dropped.
    tl.static_assert(D % TILE_D == 0, "D must be a multiple of TILE_D")

    pid_n = tl.program_id(0)
    pid_b = tl.program_id(1)

    num_tiles = D // TILE_D
    num_tiles_n = tl.cdiv(N, TILE_N)

    offs_d = tl.arange(0, TILE_D)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    mask_n = offs_n < N
    mask_rows = mask_n[:, None]

    mean = tl.load(mean_ptr + pid_b * N + offs_n, mask=mask_n, other=0.0).to(
        tl.float32
    )
    rstd = tl.load(rstd_ptr + pid_b * N + offs_n, mask=mask_n, other=0.0).to(
        tl.float32
    )

    # NOTE(review): offsets such as pid_b * N * D use the integer width of the
    # scalar arguments as passed; if that is 32-bit, tensors with more than
    # 2**31 elements would overflow -- confirm the caller's size limits.

    # Resolve the layout once: offset of feature 0 for every row in the tile,
    # plus the element stride between consecutive features. x and grad_x share
    # the input layout; grad_out uses the output layout.
    if LAYOUT == 0:  # bnd->bnd
        x_row_offs = pid_b * N * D + offs_n[:, None] * D
        x_stride_d = 1
        go_row_offs = pid_b * N * D + offs_n[:, None] * D
        go_stride_d = 1
    elif LAYOUT == 1:  # bdn->bnd
        x_row_offs = pid_b * D * N + offs_n[:, None]
        x_stride_d = N
        go_row_offs = pid_b * N * D + offs_n[:, None] * D
        go_stride_d = 1
    elif LAYOUT == 2:  # bnd->bdn
        x_row_offs = pid_b * N * D + offs_n[:, None] * D
        x_stride_d = 1
        go_row_offs = pid_b * N * D + offs_n[:, None]
        go_stride_d = N
    elif LAYOUT == 3:  # dbn->bnd
        x_row_offs = pid_b * N + offs_n[:, None]
        x_stride_d = B * N
        go_row_offs = pid_b * N * D + offs_n[:, None] * D
        go_stride_d = 1
    elif LAYOUT == 4:  # bnd->dbn
        x_row_offs = pid_b * N * D + offs_n[:, None] * D
        x_stride_d = 1
        go_row_offs = pid_b * N + offs_n[:, None]
        go_stride_d = B * N
    else:
        tl.static_assert(False, "unsupported LAYOUT")

    x_rows = x_ptr + x_row_offs
    grad_x_rows = grad_x_ptr + x_row_offs
    grad_out_rows = grad_out_ptr + go_row_offs

    # One row of D partial sums per (batch, row-tile) program.
    grad_w_row = grad_w_ptr + pid_b * num_tiles_n * D + pid_n * D
    grad_b_row = grad_b_ptr + pid_b * num_tiles_n * D + pid_n * D

    c1 = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)
    c2 = tl.zeros([TILE_N, TILE_D], dtype=tl.float32)

    # Pass 1: per-row reductions over the feature axis, plus the parameter
    # gradient partial sums for this tile of rows.
    for t in range(num_tiles):
        cur_d = t * TILE_D + offs_d

        if ELEMENTWISE_AFFINE:
            w = tl.load(w_ptr + cur_d).to(tl.float32)
        else:
            w = 1.0

        x = tl.load(
            x_rows + cur_d[None, :] * x_stride_d, mask=mask_rows, other=0.0
        ).to(tl.float32)
        grad_out = tl.load(
            grad_out_rows + cur_d[None, :] * go_stride_d, mask=mask_rows, other=0.0
        ).to(tl.float32)

        xhat = (x - mean[:, None]) * rstd[:, None]

        if ELEMENTWISE_AFFINE:
            # Masked rows were loaded as zero, so they add nothing here.
            grad_b = tl.sum(grad_out, axis=0)
            grad_w = tl.sum(grad_out * xhat, axis=0)

            tl.store(grad_w_row + cur_d, grad_w)
            tl.store(grad_b_row + cur_d, grad_b)

        wdo = w * grad_out

        c1 += xhat * wdo
        c2 += wdo

    c1_dot = tl.sum(c1, axis=1) / D
    c2_dot = tl.sum(c2, axis=1) / D

    # Pass 2: input gradient.
    for t in range(num_tiles):
        cur_d = t * TILE_D + offs_d

        if ELEMENTWISE_AFFINE:
            w = tl.load(w_ptr + cur_d).to(tl.float32)
        else:
            w = 1.0

        x = tl.load(
            x_rows + cur_d[None, :] * x_stride_d, mask=mask_rows, other=0.0
        ).to(tl.float32)
        grad_out = tl.load(
            grad_out_rows + cur_d[None, :] * go_stride_d, mask=mask_rows, other=0.0
        ).to(tl.float32)

        xhat = (x - mean[:, None]) * rstd[:, None]
        wdo = w * grad_out

        dx = (wdo - (xhat * c1_dot[:, None] + c2_dot[:, None])) * rstd[:, None]
        tl.store(grad_x_rows + cur_d[None, :] * x_stride_d, dx, mask=mask_rows)