cuequivariance-ops-cu12 0.4.0__py3-none-manylinux_2_39_aarch64.whl → 0.5.1__py3-none-manylinux_2_39_aarch64.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective registries. It is provided for informational purposes only.

Potentially problematic release: this version of cuequivariance-ops-cu12 might be problematic.

Files changed (28)
  1. cuequivariance_ops/VERSION +1 -1
  2. cuequivariance_ops/__init__.py +3 -2
  3. cuequivariance_ops/equivariance/dtypes.hh +21 -0
  4. cuequivariance_ops/equivariance/indexed_linear.hh +36 -0
  5. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  6. cuequivariance_ops/equivariance/run_fmha_cudafree.h +77 -0
  7. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +17 -35
  8. cuequivariance_ops/lib/libcue_ops.so +0 -0
  9. cuequivariance_ops/triton/__init__.py +29 -0
  10. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37192 -0
  11. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  12. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  13. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  14. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  15. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  16. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  17. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  18. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  19. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  20. cuequivariance_ops/triton/cache_manager.py +244 -0
  21. cuequivariance_ops/triton/fused_layer_norm_triton.py +324 -0
  22. cuequivariance_ops/triton/gated_gemm_triton.py +340 -0
  23. cuequivariance_ops/triton/tuning_decorator.py +272 -0
  24. {cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.1.dist-info}/METADATA +5 -1
  25. cuequivariance_ops_cu12-0.5.1.dist-info/RECORD +32 -0
  26. {cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.1.dist-info}/WHEEL +1 -1
  27. cuequivariance_ops_cu12-0.4.0.dist-info/RECORD +0 -13
  28. {cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.1.dist-info}/licenses/LICENSE +0 -0
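The ten .json files under triton/cache/ listed above ship pre-tuned kernel configurations keyed by CUDA compute capability (8.0, 8.6, 8.9, 9.0, 10.0). A hypothetical helper showing how the suffix for the current GPU could be derived; the mapping is inferred from the filenames, not a documented API:

import torch

def bundled_cache_suffix() -> str:
    # Hypothetical: derive the "major.minor" suffix used by the bundled
    # cache files (e.g. "...forward_kernel_wrapper.9.0.json") for the
    # current CUDA device.
    major, minor = torch.cuda.get_device_capability()
    return f"{major}.{minor}"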
cuequivariance_ops/triton/gated_gemm_triton.py
@@ -0,0 +1,340 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ import enum
+
+ import triton
+ import triton.language as tl
+
+
+ class Precision(enum.Enum):
+     DEFAULT = 0
+     TF32 = 1
+     TF32x3 = 2
+     IEEE = 3
+
+
+ @triton.jit
+ def fused_sigmoid_gated_dual_gemm_forward_kernel(
+     x1_ptr,
+     x2_ptr,
+     w1_ptr,
+     w2_ptr,
+     mask_ptr,
+     o_ptr,
+     M,
+     N,
+     K,
+     TILE_M: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_K: tl.constexpr,
+     PRECISION: tl.constexpr,
+     APPLY_MASK: tl.constexpr,
+     TRANSPOSE_OUT: tl.constexpr,
+     TWO_INPUTS: tl.constexpr,
+ ):
+     # fully gated GEMM kernel with optional mask at the end
+     pid_m = tl.program_id(axis=0)
+     pid_n = tl.program_id(axis=1)
+
+     start_m = pid_m * TILE_M
+     start_n = pid_n * TILE_N
+
+     offs_xm = start_m + tl.arange(0, TILE_M)
+     offs_wn = start_n + tl.arange(0, TILE_N)
+     offs_k = tl.arange(0, TILE_K)
+
+     x1_ptrs = x1_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+     if TWO_INPUTS:
+         x2_ptrs = x2_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+
+     w_tile_offs = offs_wn[None, :] * K + offs_k[:, None]
+
+     acc_1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+     acc_2 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+
+     mask_m = offs_xm < M
+
+     if TWO_INPUTS:
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x1 = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x1, w1, acc_1)
+             elif PRECISION == 1:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x2 = tl.load(x2_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w2_ptr.type.element_ty
+             )
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x2, w2, acc_2)
+             elif PRECISION == 1:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x2_ptrs += TILE_K
+             w2_ptr += TILE_K
+
+     else:
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x, w1, acc_1)
+             elif PRECISION == 1:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x, w2, acc_2)
+             elif PRECISION == 1:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+             w2_ptr += TILE_K
+
+     offs_om = pid_m * TILE_M + tl.arange(0, TILE_M)
+     offs_on = pid_n * TILE_N + tl.arange(0, TILE_N)
+
+     acc_1 = 1.0 / (1.0 + tl.exp(-acc_1))
+     acc_gated = acc_1 * acc_2
+
+     if APPLY_MASK:
+         mask = tl.load(mask_ptr + offs_om, mask=mask_m, other=0.0)
+         acc_gated = acc_gated * mask[:, None]
+
+     acc_gated = acc_gated.to(o_ptr.dtype.element_ty)
+
+     if TRANSPOSE_OUT:
+         o_ptrs = o_ptr + offs_on[None, :] * M + offs_om[:, None]
+     else:
+         o_ptrs = o_ptr + offs_om[:, None] * N + offs_on[None, :]
+
+     o_mask = offs_om[:, None] < M
+     tl.store(o_ptrs, acc_gated, mask=o_mask)
+
+
+ @triton.jit
+ def fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel(
+     grad_xw1_ptr,
+     grad_xw2_ptr,
+     grad_mask_ptr,
+     grad_o_ptr,
+     x1_ptr,
+     x2_ptr,
+     w1_ptr,
+     w2_ptr,
+     mask_ptr,
+     M,
+     N,
+     K,
+     TILE_M: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_K: tl.constexpr,
+     PRECISION: tl.constexpr,
+     APPLY_MASK: tl.constexpr,
+     TRANSPOSE_OUT: tl.constexpr,
+     TWO_INPUTS: tl.constexpr,
+ ):
+     # recomputes the gated GEMM, then forms grads w.r.t. both pre-gate GEMM outputs
+     pid_m = tl.program_id(axis=0)
+     pid_n = tl.program_id(axis=1)
+
+     start_m = pid_m * TILE_M
+     start_n = pid_n * TILE_N
+
+     offs_xm = start_m + tl.arange(0, TILE_M)
+     offs_wn = start_n + tl.arange(0, TILE_N)
+     offs_k = tl.arange(0, TILE_K)
+
+     x1_ptrs = x1_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+     if TWO_INPUTS:
+         x2_ptrs = x2_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+     w_tile_offs = offs_wn[None, :] * K + offs_k[:, None]
+
+     acc_1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+     acc_2 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+
+     mask_m = offs_xm < M
+
+     if TWO_INPUTS:
+         # recompute acc1 and acc2
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x1 = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x1, w1, acc_1)
+             elif PRECISION == 1:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x2 = tl.load(x2_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w2_ptr.type.element_ty
+             )
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x2, w2, acc_2)
+             elif PRECISION == 1:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x2_ptrs += TILE_K
+             w2_ptr += TILE_K
+
+     else:
+         # recompute acc1 and acc2
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x, w1, acc_1)
+             elif PRECISION == 1:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x, w2, acc_2)
+             elif PRECISION == 1:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+             w2_ptr += TILE_K
+
+     offs_om = pid_m * TILE_M + tl.arange(0, TILE_M)
+     offs_on = pid_n * TILE_N + tl.arange(0, TILE_N)
+     if TRANSPOSE_OUT:
+         grad_o_ptrs = grad_o_ptr + offs_on[None, :] * M + offs_om[:, None]
+     else:
+         grad_o_ptrs = grad_o_ptr + offs_om[:, None] * N + offs_on[None, :]
+
+     grad_o = tl.load(grad_o_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
+
+     acc_sig = 1.0 / (1.0 + tl.exp(-acc_1))
+
+     if APPLY_MASK:
+         tmp = acc_sig * acc_2
+         grad_mask = grad_o * tmp
+         grad_mask = tl.sum(grad_mask, axis=1)
+         grad_mask_ptrs = grad_mask_ptr + pid_n * M + offs_om
+         tl.store(grad_mask_ptrs, grad_mask.to(grad_mask_ptr.dtype.element_ty), mask=mask_m)
+
+         mask = tl.load(mask_ptr + offs_om, mask=mask_m, other=0.0)
+         grad_o = grad_o * mask[:, None]
+
+     tmp = (1.0 - acc_sig) * acc_sig
+
+     grad_xw1 = grad_o * acc_2 * tmp
+     grad_xw2 = grad_o * acc_sig
+
+     grad_xw1_ptrs = grad_xw1_ptr + offs_om[:, None] * N + offs_on[None, :]
+     grad_xw2_ptrs = grad_xw2_ptr + offs_om[:, None] * N + offs_on[None, :]
+     tl.store(grad_xw1_ptrs, grad_xw1.to(grad_xw1_ptr.dtype.element_ty), mask=mask_m[:, None])
+     tl.store(grad_xw2_ptrs, grad_xw2.to(grad_xw2_ptr.dtype.element_ty), mask=mask_m[:, None])
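For orientation, a minimal PyTorch sketch of what the two kernels above compute, inferred from the diff; the helper names and signatures here are ours, not part of the package, and the real kernels tile both GEMMs, honor the PRECISION modes, and can write a transposed output.

import torch

def fused_sigmoid_gated_dual_gemm_ref(x1, w1, w2, x2=None, mask=None, transpose_out=False):
    # w1/w2 are indexed as [N, K] tiles in the kernel (offs_wn * K + offs_k),
    # so each GEMM multiplies by the transposed weight; accumulation is fp32.
    if x2 is None:  # TWO_INPUTS == False: both GEMMs share the same input
        x2 = x1
    xw1 = x1.float() @ w1.float().t()
    xw2 = x2.float() @ w2.float().t()
    out = torch.sigmoid(xw1) * xw2  # sigmoid(x1 @ w1.T) gates x2 @ w2.T
    if mask is not None:  # APPLY_MASK: per-row mask applied after the gate
        out = out * mask.float()[:, None]
    if transpose_out:  # TRANSPOSE_OUT lays the result out as [N, M]
        out = out.t()
    return out.to(x1.dtype)

def backward_pregemm_ref(grad_o, xw1, xw2, mask=None):
    # Grads w.r.t. the two pre-gate GEMM outputs. Note the Triton kernel
    # stores grad_mask as per-N-tile partial sums (grad_mask_ptr + pid_n * M
    # + offs_om); here it is reduced over all of N at once.
    sig = torch.sigmoid(xw1.float())
    grad_o = grad_o.float()
    grad_mask = (grad_o * sig * xw2.float()).sum(dim=1) if mask is not None else None
    if mask is not None:
        grad_o = grad_o * mask.float()[:, None]
    grad_xw1 = grad_o * xw2.float() * sig * (1.0 - sig)  # sigmoid' = sig * (1 - sig)
    grad_xw2 = grad_o * sig
    return grad_xw1, grad_xw2, grad_mask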
cuequivariance_ops/triton/tuning_decorator.py
@@ -0,0 +1,272 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ import inspect
+ import logging  # Added logging import
+ from enum import Enum
+ from typing import Any, Callable
+
+ import torch
+ from tqdm import tqdm
+
+ from .cache_manager import get_cache_manager
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+
+ class BenchmarkMode(Enum):
+     FLUSH_CACHE = 0
+     FLUSH_CACHE_PEAK_PROXY = 1
+     ROT_BUFFER = 2
+     ROT_BUFFER_PEAK_PROXY = 3
+
+
+ def run_bench(
+     f, input_dict, warmup_iter=250, run_iter=250, bench_mode=BenchmarkMode.ROT_BUFFER
+ ):
+     initial_output = f(**input_dict)
+
+     if bench_mode in (BenchmarkMode.ROT_BUFFER, BenchmarkMode.ROT_BUFFER_PEAK_PROXY):
+         len_rot = 4
+         inputs_rot = [None] * len_rot
+         for r in range(len_rot):
+             r_inputs = []
+             for key, value in input_dict.items():
+                 if isinstance(value, torch.Tensor):
+                     if bench_mode == BenchmarkMode.ROT_BUFFER_PEAK_PROXY:
+                         r_inputs.append(
+                             (
+                                 key,
+                                 torch.ones_like(
+                                     value, requires_grad=value.requires_grad
+                                 ),
+                             )
+                         )
+                     else:
+                         r_inputs.append(
+                             (
+                                 key,
+                                 torch.randn_like(
+                                     value, requires_grad=value.requires_grad
+                                 ),
+                             )
+                         )
+                 else:
+                     r_inputs.append((key, value))
+             r_inputs = dict(r_inputs)
+             inputs_rot[r] = r_inputs
+
+         for it in range(warmup_iter):
+             _ = f(**inputs_rot[it % len_rot])
+
+         start = torch.cuda.Event(enable_timing=True)
+         end = torch.cuda.Event(enable_timing=True)
+         start.record()
+         for it in range(run_iter):
+             _ = f(**inputs_rot[it % len_rot])
+         end.record()
+         torch.cuda.synchronize()
+         elapsed = start.elapsed_time(end)
+
+     elif bench_mode in (
+         BenchmarkMode.FLUSH_CACHE,
+         BenchmarkMode.FLUSH_CACHE_PEAK_PROXY,
+     ):
+         cache_filler = torch.empty(1024 * 1024 * 256, dtype=torch.int8, device="cuda")
+
+         if bench_mode == BenchmarkMode.FLUSH_CACHE_PEAK_PROXY:
+             _inputs = []
+             for key, value in input_dict.items():
+                 if isinstance(value, torch.Tensor):
+                     _inputs.append(
+                         (key, torch.ones_like(value, requires_grad=value.requires_grad))
+                     )
+                 else:
+                     _inputs.append((key, value))
+             input_dict = dict(_inputs)
+
+         for _ in range(warmup_iter):
+             cache_filler.zero_()
+             _ = f(**input_dict)
+
+         starts = [torch.cuda.Event(enable_timing=True) for _ in range(run_iter)]
+         ends = [torch.cuda.Event(enable_timing=True) for _ in range(run_iter)]
+         for i in range(run_iter):
+             cache_filler.zero_()
+             starts[i].record()
+             _ = f(**input_dict)
+             ends[i].record()
+         torch.cuda.synchronize()
+         elapsed = sum(s.elapsed_time(e) for s, e in zip(starts, ends))
+
+     return elapsed / run_iter, initial_output
+
+
+ def input_to_key_default(**args) -> str:
+     key_parts = []
+     for arg in args.values():
+         if isinstance(arg, torch.Tensor):
+             key_parts.append(f"{list(arg.shape)}_{arg.dtype}")
+         elif isinstance(arg, bool):
+             key_parts.append("True" if arg else "False")
+         elif isinstance(arg, str):
+             key_parts.append(arg)
+         else:
+             key_parts.append(str(arg.__class__.__name__))
+
+     return "_".join(key_parts)
+
+
+ def combine_all_kwargs(
+     fn: Callable,
+     args: tuple,
+     kwargs: dict[str, Any],
+ ) -> dict[str, Any]:
+     # Get the function signature
+     sig = inspect.signature(fn)
+     params = sig.parameters
+     param_names = list(params.keys())
+
+     # Create dictionary of default values
+     defaults = {
+         name: param.default
+         for name, param in params.items()
+         if param.default is not inspect.Parameter.empty
+     }
+     # Create dictionary mapping positional args to parameter names
+     args_as_kwargs = {
+         param_names[i]: args[i] for i in range(min(len(args), len(param_names)))
+     }
+     # Create combined dictionary of all parameters
+     all_kwargs = defaults.copy()  # Start with defaults
+     all_kwargs.update(args_as_kwargs)  # Override with positional args
+     all_kwargs.update(kwargs)  # Override with explicit kwargs
+
+     return all_kwargs
+
+
+ def autotune_aot(
+     input_generator: Callable,
+     input_to_key: Callable | None,
+     input_configs: list[dict[str, Any]],
+     tunable_configs: list[dict[str, Any]],
+     prune_configs_fn: Callable[
+         [list[dict[str, Any]], dict[str, Any]], list[dict[str, Any]]
+     ]
+     | None,
+     bench_mode=BenchmarkMode.ROT_BUFFER,
+     warmup_iter=25,
+     run_iter=100,
+ ) -> Callable:
+     def decorator(fn: Callable) -> Callable:
+         def wrapper(*args, **kwargs):
+             all_kwargs = combine_all_kwargs(fn, args, kwargs)
+             nonlocal input_to_key
+             nonlocal input_configs
+
+             if input_to_key is None:
+                 input_to_key = input_to_key_default
+
+             # Check if the function is already cached
+             function_key = fn.__name__
+             input_key = input_to_key(**all_kwargs)
+             cache_manager = get_cache_manager()
+             best_cached_config = cache_manager.get(function_key, input_key)
+
+             aot_mode = cache_manager.aot_mode
+
+             if best_cached_config is None and aot_mode is not None:
+                 # start autotuning process
+                 # input_configs = input_configs + [None]
+                 if aot_mode == "ONDEMAND":
+                     input_configs = [None]
+
+                 try:
+                     # Initialize the progress bar
+                     progress_bar = tqdm(
+                         input_configs, desc="Autotuning Progress", unit="config"
+                     )
+
+                     for input_config in progress_bar:
+                         # generate input based on the config
+                         input_data = (
+                             input_generator(**input_config)
+                             if input_config is not None
+                             else all_kwargs
+                         )
+
+                         # Make a copy of all_kwargs to avoid modifying the original
+                         current_kwargs = all_kwargs.copy()
+                         current_kwargs.update(input_data)
+                         current_input_key = input_to_key(**current_kwargs)
+
+                         best_cached_config = cache_manager.get(
+                             function_key, current_input_key
+                         )
+
+                         if best_cached_config is not None:
+                             continue
+
+                         # prune the tunable configs based on the all_kwargs
+                         pruned_tunable_configs = (
+                             prune_configs_fn(tunable_configs, **all_kwargs)
+                             if prune_configs_fn is not None
+                             else tunable_configs
+                         )
+
+                         best_config = None
+                         best_time = float("inf")
+                         working_config = []
+                         for tunable in pruned_tunable_configs:
+                             try:
+                                 current_kwargs.update(tunable)
+                                 fn(**current_kwargs)
+                                 torch.cuda.synchronize()
+                                 working_config.append(tunable)
+                             except Exception:
+                                 pass
+
+                         if not working_config:
+                             continue
+
+                         for tunable in working_config:
+                             current_kwargs.update(tunable)
+                             elapse, _ = run_bench(
+                                 fn,
+                                 current_kwargs,
+                                 warmup_iter=warmup_iter,
+                                 run_iter=run_iter,
+                                 bench_mode=bench_mode,
+                             )
+                             if elapse < best_time:
+                                 best_time = elapse
+                                 best_config = tunable
+
+                         cache_manager.set(
+                             function_key,
+                             current_input_key,
+                             {"config": best_config, "time": best_time},
+                         )
+                         cache_manager.save_cache(function_key)
+                 except Exception as e:
+                     logger.warning(e)
+
+                 # After tuning, try to get the best config
+                 best_cached_config = cache_manager.get(function_key, input_key)
+
+             if best_cached_config is not None:
+                 all_kwargs.update(best_cached_config["config"])
+
+             return fn(**all_kwargs)
+
+         return wrapper
+
+     return decorator
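A sketch of how the decorator above might be applied; the decorated function, shapes, and config dictionaries are illustrative assumptions, not values shipped with the package:

import torch
from cuequivariance_ops.triton.tuning_decorator import BenchmarkMode, autotune_aot

def make_inputs(M, K):
    # Hypothetical generator: builds the tensors for one pre-tuned shape.
    return {"x": torch.randn(M, K, device="cuda")}

@autotune_aot(
    input_generator=make_inputs,
    input_to_key=None,  # falls back to input_to_key_default
    input_configs=[{"M": 1024, "K": 128}],  # shapes to tune ahead of time
    tunable_configs=[{"TILE_M": 64}, {"TILE_M": 128}],  # candidate launch params
    prune_configs_fn=None,
    bench_mode=BenchmarkMode.ROT_BUFFER,
)
def my_op(x, TILE_M=64):
    ...  # launch a Triton kernel with the chosen TILE_M

Each call derives a key from its inputs, consults the cache manager, autotunes on a miss (only the live inputs when aot_mode is "ONDEMAND"), persists the winner via save_cache, and finally invokes the function with the best recorded config merged into its keyword arguments.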
{cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.1.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: cuequivariance-ops-cu12
- Version: 0.4.0
+ Version: 0.5.1
  Summary: cuequivariance-ops - GPU Accelerated Extensions for Equivariant Primitives
  Author: NVIDIA Corporation
  License: # Software License Agreement
@@ -177,6 +177,10 @@ Classifier: Programming Language :: Python
  Project-URL: Homepage, https://github.com/nvidia/cuEquivariance
  Project-URL: Documentation, https://github.com/nvidia/cuEquivariance
  Requires-Python: >=3.10
+ Requires-Dist: nvidia-cublas-cu12>=12.5.0
+ Requires-Dist: tqdm
+ Requires-Dist: pynvml
+ Requires-Dist: platformdirs
  Provides-Extra: test
  Requires-Dist: numpy; extra == "test"
  Requires-Dist: pytest; extra == "test"
cuequivariance_ops_cu12-0.5.1.dist-info/RECORD
@@ -0,0 +1,32 @@
+ cuequivariance_ops/VERSION,sha256=q6lRYmyGkM5JPLyPAYIFu0aAA_YfvwD9PTxxzrq8AXc,6
+ cuequivariance_ops/__init__.py,sha256=wvvAMuXpOg5W4oE-AnHDWoHPzcamAWK_DiUXyg3hgW8,1332
+ cuequivariance_ops/_version.py,sha256=o9Flao_mTq2Y7TrrjnSCqEAgebmA0sGozsl15qVI13Y,730
+ cuequivariance_ops/common/common.hpp,sha256=2zDyE5lGugQL43vmM4_ylmp-Tz8OBFnPRsdFra_1BdM,2787
+ cuequivariance_ops/common/nvtx.hpp,sha256=Wi6z9b-yFUNq6ShJjjcsdxQRqCygd4xGegGJrqUI9Wk,708
+ cuequivariance_ops/equivariance/dtypes.hh,sha256=w0BYWZ0LYklODXhp7PR6VYE__DE1Syj0Ur11aFaq9VM,466
+ cuequivariance_ops/equivariance/fused_tensor_product.cuh,sha256=bOXR5UWU9gNYRfdh6k28NEkV3CUU2ijmh6y7c0ND0J4,8283
+ cuequivariance_ops/equivariance/indexed_linear.hh,sha256=lNqJNafJdPyMAUp6iwWvu6RyassSXh7JqyqJ4bfjoxQ,1402
+ cuequivariance_ops/equivariance/run_fmha.h,sha256=7l62dTQJbX7BbHLB7MmVP1t26Cfpmcu3h6eY048Hof0,9505
+ cuequivariance_ops/equivariance/run_fmha_cudafree.h,sha256=bF2_nrvSfrqSVZ0eOPcq4CJ-NKqmJ2VgQv1cstvHBkU,2695
+ cuequivariance_ops/equivariance/segmented_transpose.cuh,sha256=gfSZhRBwSqwVAgFCCiGtI-NJ8yDy9tV_iCg1G2KpctY,1766
+ cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh,sha256=7PPybCWczS58XKA-iFLoCM7MDEomO4-enF6RCBj5G5M,1922
+ cuequivariance_ops/lib/libcue_ops.so,sha256=rSk0Km-M-Zu703bOMV-OF8mykMPOgFTyx73OW_r8ofM,112269112
+ cuequivariance_ops/triton/__init__.py,sha256=LCHvxif4kwr0Squy7mjgx0NCUyM2AcOjkDg5CXZZtuA,1053
+ cuequivariance_ops/triton/cache_manager.py,sha256=sXlbuCKsoRMEc1wQVcdk7Vk18LJdej97Ve9FkxdQRYU,9154
+ cuequivariance_ops/triton/fused_layer_norm_triton.py,sha256=SyQf_eJvTKm3Foe8BI0sjWZdEtGp__LQ1qgImlHLc4c,11056
+ cuequivariance_ops/triton/gated_gemm_triton.py,sha256=PEJcgNVZUk8G5Z5ukD8Ksbe71kUWHUufjUINI5JGnV8,11405
+ cuequivariance_ops/triton/tuning_decorator.py,sha256=ruN_Ck5Np6a09slN-VbK0uF4IEgR9WxNIVK4bKmPKzo,9813
+ cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json,sha256=4Gi4yJ_I-smVSPzEWZn_kZWktn6BI5sNFk5wJDE4aH8,1397798
+ cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json,sha256=HUK_ayOTS5WrJy_W_sVxyBSADLubp69spp8lfjbkHX8,1392688
+ cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json,sha256=HFZDB_XzoSSg1DToHV297NKLrCSycsA4QUQ_aOSelOs,1392431
+ cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json,sha256=eU6hDvUU8YAxCGydnXd6Dnl9x6xo52KonOBxRbrvBgw,1392528
+ cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json,sha256=3wruFbYFOLXQZXQg0FQPN11X8YYeHCowrARr61yTbew,2785263
+ cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json,sha256=9hMqFldfcq4rFnKalSNm5vO_bUbIw66MO-BrqhSodRI,1754706
+ cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json,sha256=grlXjI18H5d71mIcYY8sgo6s4Ssz3aUSCK4rMCP0dtc,2011725
+ cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json,sha256=ndWI0VB8R0RjvEC2JPZz_GbQQ_xfg6zEfrvmSgSMMMw,2010879
+ cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json,sha256=biWhNBUqJiYNuuC1hyimbUTndFINBrQpAHEcXFOluN0,2011532
+ cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json,sha256=3_v5C0cW_Ab3rNC5cu-SBRUQ7Ala18fvzvytSPa-KHI,4025067
+ cuequivariance_ops_cu12-0.5.1.dist-info/METADATA,sha256=PKjZtsBFTYmyoPzltH0OwMWljL8FJRJgWSJSp8tTDFU,20954
+ cuequivariance_ops_cu12-0.5.1.dist-info/WHEEL,sha256=RxM28Avh4PDgHOLX-AZLV1MP0dIb1yycxVPEx6_SFW0,116
+ cuequivariance_ops_cu12-0.5.1.dist-info/RECORD,,
+ cuequivariance_ops_cu12-0.5.1.dist-info/licenses/LICENSE,sha256=rvp0QV9FuOdxz_CGWTd9DgId4xh2BByyXfBBnb0ejZM,18279
{cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.1.dist-info}/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: scikit-build-core 0.11.1
+ Generator: scikit-build-core 0.11.4
  Root-Is-Purelib: false
  Tag: py3-none-manylinux_2_39_aarch64
cuequivariance_ops_cu12-0.4.0.dist-info/RECORD
@@ -1,13 +0,0 @@
- cuequivariance_ops/VERSION,sha256=QLjrQACpE6d5EJBTXykdPTaYdBYqie88nj1OiHobnnk,6
- cuequivariance_ops/__init__.py,sha256=ba7jv_WICRROtLbDU2O1u0MHxp6VkVu0-UGKuQxf9iw,1255
- cuequivariance_ops/_version.py,sha256=o9Flao_mTq2Y7TrrjnSCqEAgebmA0sGozsl15qVI13Y,730
- cuequivariance_ops/common/common.hpp,sha256=2zDyE5lGugQL43vmM4_ylmp-Tz8OBFnPRsdFra_1BdM,2787
- cuequivariance_ops/common/nvtx.hpp,sha256=Wi6z9b-yFUNq6ShJjjcsdxQRqCygd4xGegGJrqUI9Wk,708
- cuequivariance_ops/equivariance/fused_tensor_product.cuh,sha256=bOXR5UWU9gNYRfdh6k28NEkV3CUU2ijmh6y7c0ND0J4,8283
- cuequivariance_ops/equivariance/segmented_transpose.cuh,sha256=gfSZhRBwSqwVAgFCCiGtI-NJ8yDy9tV_iCg1G2KpctY,1766
- cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh,sha256=oWhSS0ZmMHlye8eTucweoGBtzN1H0nN1GX_Rz-MsPqI,2002
- cuequivariance_ops/lib/libcue_ops.so,sha256=VQP3gnNy4jpVad__bbaZiakvx5c8J63Q2knrvfIRLJc,81794536
- cuequivariance_ops_cu12-0.4.0.dist-info/METADATA,sha256=na1Ly8dpRX4aVspiUYshcsTnjVmh0zsW13AtoteoEJ0,20842
- cuequivariance_ops_cu12-0.4.0.dist-info/WHEEL,sha256=teK9zuS7Jv7dMHQejkMfDwwTIgdimcBypnObHv4zSrs,116
- cuequivariance_ops_cu12-0.4.0.dist-info/RECORD,,
- cuequivariance_ops_cu12-0.4.0.dist-info/licenses/LICENSE,sha256=rvp0QV9FuOdxz_CGWTd9DgId4xh2BByyXfBBnb0ejZM,18279