cuequivariance-ops-cu12 0.8.1__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. cuequivariance_ops/VERSION +1 -0
  2. cuequivariance_ops/__init__.py +42 -0
  3. cuequivariance_ops/_version.py +20 -0
  4. cuequivariance_ops/common/common.hpp +98 -0
  5. cuequivariance_ops/common/cudart.hpp +286 -0
  6. cuequivariance_ops/common/error.hpp +66 -0
  7. cuequivariance_ops/common/error_raft.hpp +323 -0
  8. cuequivariance_ops/common/nvtx.hpp +29 -0
  9. cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
  10. cuequivariance_ops/equivariance/dtypes.hh +65 -0
  11. cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
  12. cuequivariance_ops/equivariance/indexed_linear.hh +41 -0
  13. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  14. cuequivariance_ops/equivariance/run_fmha_cudafree.h +176 -0
  15. cuequivariance_ops/equivariance/run_fmha_sm100.h +135 -0
  16. cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
  17. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
  18. cuequivariance_ops/gpu_timing_kernels.hh +42 -0
  19. cuequivariance_ops/lib/libcue_ops.so +0 -0
  20. cuequivariance_ops/sleep.hh +40 -0
  21. cuequivariance_ops/triton/__init__.py +66 -0
  22. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37142 -0
  23. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.12.0.json +37132 -0
  24. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  25. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  26. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  27. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  28. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  29. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.12.0.json +55692 -0
  30. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  31. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  32. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  33. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  34. cuequivariance_ops/triton/cache_manager.py +336 -0
  35. cuequivariance_ops/triton/fused_layer_norm_triton.py +546 -0
  36. cuequivariance_ops/triton/gated_gemm_triton.py +394 -0
  37. cuequivariance_ops/triton/pair_bias.py +365 -0
  38. cuequivariance_ops/triton/tuning_decorator.py +188 -0
  39. cuequivariance_ops/triton/utils.py +29 -0
  40. cuequivariance_ops_cu12-0.8.1.dist-info/METADATA +182 -0
  41. cuequivariance_ops_cu12-0.8.1.dist-info/RECORD +46 -0
  42. cuequivariance_ops_cu12-0.8.1.dist-info/WHEEL +6 -0
  43. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/LICENSE +142 -0
  44. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/Third_party_attr.txt +24 -0
  45. cuequivariance_ops_cu12-0.8.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  46. cuequivariance_ops_cu12.libs/libnvfatbin-b51d3b3f.so.12.8.90 +0 -0
cuequivariance_ops/triton/gated_gemm_triton.py
@@ -0,0 +1,394 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+
+ import triton
+ import triton.language as tl
+
+ from cuequivariance_ops.triton.utils import cvt_tf32_rn
+
+
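The helper `cvt_tf32_rn` is imported from `cuequivariance_ops/triton/utils.py`, which this diff does not show. Its name and its use just before `tl.dot(..., input_precision="tf32")` suggest it rounds float32 operands to TF32 precision (10 mantissa bits, round to nearest). A minimal sketch of such a helper, assuming a bit-level emulation rather than the packaged implementation:

```python
import triton
import triton.language as tl


@triton.jit
def cvt_tf32_rn_sketch(x):
    # Illustrative only: emulate rounding an fp32 value to TF32 precision by
    # rounding away the low 13 mantissa bits (NaN/Inf handling is ignored).
    i = x.to(tl.int32, bitcast=True)
    i = (i + 0x1000) & -0x2000  # round half away from zero, then clear low 13 bits
    return i.to(tl.float32, bitcast=True)
```

Rounding the operands explicitly, rather than relying on the dot unit alone, likely ensures the forward pass and the backward recompute see identically rounded inputs.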
+ @triton.jit
+ def fused_sigmoid_gated_dual_gemm_forward_kernel(
+     # inputs
+     x1_ptr,
+     x2_ptr,
+     w1_ptr,
+     w2_ptr,
+     b1_ptr,
+     b2_ptr,
+     mask_ptr,
+     M,
+     N,
+     K,
+     # outputs
+     o_ptr,
+     TILE_M: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_K: tl.constexpr,
+     PRECISION: tl.constexpr,
+     APPLY_MASK: tl.constexpr,
+     TRANSPOSE_OUT: tl.constexpr,
+     TWO_INPUTS: tl.constexpr,
+     HAS_B1: tl.constexpr,
+     HAS_B2: tl.constexpr,
+     NEEDS_INT64: tl.constexpr = True,
+ ):
+     # fully gated GEMM kernel with optional mask at the end
+     pid_m = tl.program_id(axis=0)
+     pid_n = tl.program_id(axis=1)
+
+     if NEEDS_INT64:
+         pid_m = tl.cast(pid_m, tl.int64)
+         pid_n = tl.cast(pid_n, tl.int64)
+         M = tl.cast(M, tl.int64)
+         N = tl.cast(N, tl.int64)
+
+     start_m = pid_m * TILE_M
+     start_n = pid_n * TILE_N
+
+     offs_xm = start_m + tl.arange(0, TILE_M)
+     offs_wn = start_n + tl.arange(0, TILE_N)
+     offs_k = tl.arange(0, TILE_K)
+
+     x1_ptrs = x1_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+     if TWO_INPUTS:
+         x2_ptrs = x2_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+
+     w_tile_offs = offs_wn[None, :] * K + offs_k[:, None]
+
+     acc_1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+     acc_2 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+
+     mask_m = offs_xm < M
+
+     if TWO_INPUTS:
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x1 = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x1, w1, acc_1)
+             elif PRECISION == 1:
+                 x1 = cvt_tf32_rn(x1)
+                 w1 = cvt_tf32_rn(w1)
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x2 = tl.load(x2_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w2_ptr.type.element_ty
+             )
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x2, w2, acc_2)
+             elif PRECISION == 1:
+                 x2 = cvt_tf32_rn(x2)
+                 w2 = cvt_tf32_rn(w2)
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x2_ptrs += TILE_K
+             w2_ptr += TILE_K
+
+     else:
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x, w1, acc_1)
+             elif PRECISION == 1:
+                 x = cvt_tf32_rn(x)
+                 w1 = cvt_tf32_rn(w1)
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x, w2, acc_2)
+             elif PRECISION == 1:
+                 x = cvt_tf32_rn(x)
+                 w2 = cvt_tf32_rn(w2)
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+             w2_ptr += TILE_K
+
+     offs_om = pid_m * TILE_M + tl.arange(0, TILE_M)
+     offs_on = pid_n * TILE_N + tl.arange(0, TILE_N)
+
+     if HAS_B1:
+         b1_ptrs = b1_ptr + offs_on
+         b1_tile = tl.load(b1_ptrs).to(tl.float32)
+         acc_1 += b1_tile
+
+     if HAS_B2:
+         b2_ptrs = b2_ptr + offs_on
+         b2_tile = tl.load(b2_ptrs).to(tl.float32)
+         acc_2 += b2_tile
+
+     acc_1 = 1.0 / (1.0 + tl.exp(-acc_1))
+     acc_gated = acc_1 * acc_2
+
+     if APPLY_MASK:
+         mask = tl.load(mask_ptr + offs_om, mask=mask_m, other=0.0).to(tl.float32)
+         acc_gated = acc_gated * mask[:, None]
+
+     if TRANSPOSE_OUT:
+         o_ptrs = o_ptr + offs_on[None, :] * M + offs_om[:, None]
+     else:
+         o_ptrs = o_ptr + offs_om[:, None] * N + offs_on[None, :]
+
+     o_mask = offs_om[:, None] < M
+     tl.store(o_ptrs, acc_gated, mask=o_mask)
+
+
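Reading the pointer arithmetic above, the forward kernel computes sigmoid(x1 @ w1.T + b1) * (x2 @ w2.T + b2) tile by tile, with an optional per-row mask and an optionally transposed (N, M) output; when TWO_INPUTS is false, x1 feeds both GEMMs. A rough PyTorch reference of that computation (the function name and signature are illustrative, not the package API; x1/x2 are (M, K), w1/w2 are (N, K) row-major, mask is (M,)):

```python
import torch


def gated_dual_gemm_forward_reference(x1, x2, w1, w2, b1=None, b2=None,
                                       mask=None, transpose_out=False):
    a1 = x1 @ w1.T                      # acc_1
    a2 = x2 @ w2.T                      # acc_2 (x2 is x1 when TWO_INPUTS is false)
    if b1 is not None:
        a1 = a1 + b1                    # bias broadcast over rows
    if b2 is not None:
        a2 = a2 + b2
    out = torch.sigmoid(a1) * a2        # sigmoid gate
    if mask is not None:
        out = out * mask[:, None]       # per-row mask
    return out.T.contiguous() if transpose_out else out
```

The constexpr tile sizes and the PRECISION flag are presumably supplied by the wrapper and autotuning layer (tuning_decorator.py, cache_manager.py and the per-architecture JSON caches listed above), with the kernel launched on a (cdiv(M, TILE_M), cdiv(N, TILE_N)) grid.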
+ @triton.jit
+ def fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel(
+     # inputs
+     grad_o_ptr,
+     x1_ptr,
+     x2_ptr,
+     w1_ptr,
+     w2_ptr,
+     b1_ptr,
+     b2_ptr,
+     mask_ptr,
+     M,
+     N,
+     K,
+     # outputs
+     grad_xw1_ptr,
+     grad_xw2_ptr,
+     grad_mask_ptr,
+     TILE_M: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_K: tl.constexpr,
+     PRECISION: tl.constexpr,
+     APPLY_MASK: tl.constexpr,
+     TRANSPOSE_OUT: tl.constexpr,
+     TWO_INPUTS: tl.constexpr,
+     HAS_B1: tl.constexpr,
+     HAS_B2: tl.constexpr,
+     NEEDS_INT64: tl.constexpr = True,
+ ):
+     # fully gated GEMM kernel with optional mask at the end
+     pid_m = tl.program_id(axis=0)
+     pid_n = tl.program_id(axis=1)
+
+     if NEEDS_INT64:
+         pid_m = tl.cast(pid_m, tl.int64)
+         pid_n = tl.cast(pid_n, tl.int64)
+         M = tl.cast(M, tl.int64)
+         N = tl.cast(N, tl.int64)
+
+     start_m = pid_m * TILE_M
+     start_n = pid_n * TILE_N
+
+     offs_xm = start_m + tl.arange(0, TILE_M)
+     offs_wn = start_n + tl.arange(0, TILE_N)
+     offs_k = tl.arange(0, TILE_K)
+
+     x1_ptrs = x1_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+     if TWO_INPUTS:
+         x2_ptrs = x2_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+     w_tile_offs = offs_wn[None, :] * K + offs_k[:, None]
+
+     acc_1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+     acc_2 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+
+     mask_m = offs_xm < M
+
+     if TWO_INPUTS:
+         # recompute acc1 and acc2
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x1 = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x1, w1, acc_1)
+             elif PRECISION == 1:
+                 x1 = cvt_tf32_rn(x1)
+                 w1 = cvt_tf32_rn(w1)
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x2 = tl.load(x2_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w2_ptr.type.element_ty
+             )
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x2, w2, acc_2)
+             elif PRECISION == 1:
+                 x2 = cvt_tf32_rn(x2)
+                 w2 = cvt_tf32_rn(w2)
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x2_ptrs += TILE_K
+             w2_ptr += TILE_K
+
+     else:
+         # recompute acc1 and acc2
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x, w1, acc_1)
+             elif PRECISION == 1:
+                 x = cvt_tf32_rn(x)
+                 w1 = cvt_tf32_rn(w1)
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x, w2, acc_2)
+             elif PRECISION == 1:
+                 x = cvt_tf32_rn(x)
+                 w2 = cvt_tf32_rn(w2)
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+             w2_ptr += TILE_K
+
+     offs_om = pid_m * TILE_M + tl.arange(0, TILE_M)
+     offs_on = pid_n * TILE_N + tl.arange(0, TILE_N)
+
+     if HAS_B1:
+         b1_ptrs = b1_ptr + offs_on
+         b1_tile = tl.load(b1_ptrs).to(tl.float32)
+         acc_1 += b1_tile
+
+     if HAS_B2:
+         b2_ptrs = b2_ptr + offs_on
+         b2_tile = tl.load(b2_ptrs).to(tl.float32)
+         acc_2 += b2_tile
+
+     if TRANSPOSE_OUT:
+         grad_o_ptrs = grad_o_ptr + offs_on[None, :] * M + offs_om[:, None]
+     else:
+         grad_o_ptrs = grad_o_ptr + offs_om[:, None] * N + offs_on[None, :]
+
+     grad_o = tl.load(grad_o_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
+
+     acc_sig = 1.0 / (1.0 + tl.exp(-acc_1))
+
+     if APPLY_MASK:
+         tmp = acc_sig * acc_2
+         grad_mask = grad_o * tmp
+         grad_mask = tl.sum(grad_mask, axis=1)
+         grad_mask_ptrs = grad_mask_ptr + pid_n * M + offs_om
+         tl.store(grad_mask_ptrs, grad_mask, mask=mask_m)
+
+         mask = tl.load(mask_ptr + offs_om, mask=mask_m, other=0.0).to(tl.float32)
+         grad_o = grad_o * mask[:, None]
+
+     grad_xw2 = grad_o * acc_sig
+     grad_xw2_ptrs = grad_xw2_ptr + offs_om[:, None] * N + offs_on[None, :]
+     tl.store(grad_xw2_ptrs, grad_xw2, mask=mask_m[:, None])
+
+     tmp = (1.0 - acc_sig) * acc_sig
+     grad_xw1 = grad_o * acc_2 * tmp
+     grad_xw1_ptrs = grad_xw1_ptr + offs_om[:, None] * N + offs_on[None, :]
+     tl.store(grad_xw1_ptrs, grad_xw1, mask=mask_m[:, None])
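With a1 = x1 @ w1.T + b1, a2 = x2 @ w2.T + b2 and s = sigmoid(a1), the chain rule for out = s * a2 (times the optional mask) gives d out / d a1 = a2 * s * (1 - s) and d out / d a2 = s, which is exactly what the kernel stores as grad_xw1 and grad_xw2; the mask gradient is the row sum of grad_o * s * a2. A rough PyTorch reference of these "pregemm" quantities (illustrative names, not the package API):

```python
import torch


def gated_dual_gemm_backward_pregemm_reference(grad_o, a1, a2, mask=None):
    # a1, a2, grad_o: (M, N); mask: (M,) or None
    s = torch.sigmoid(a1)
    grad_mask = (grad_o * s * a2).sum(dim=1) if mask is not None else None
    if mask is not None:
        grad_o = grad_o * mask[:, None]
    grad_a2 = grad_o * s                   # matches grad_xw2 in the kernel
    grad_a1 = grad_o * a2 * s * (1.0 - s)  # matches grad_xw1 in the kernel
    return grad_a1, grad_a2, grad_mask
```

Note that the kernel writes grad_mask as per-N-tile partial sums into a buffer indexed by `pid_n * M + offs_om`, so a final reduction over the N-tile axis is still needed, and the grad_xw1/grad_xw2 tiles are presumably contracted with x and w in subsequent GEMMs (hence "pregemm") to produce the gradients with respect to the inputs, weights and biases.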