cuequivariance-ops-cu12 0.4.0__py3-none-manylinux_2_39_aarch64.whl → 0.5.0__py3-none-manylinux_2_39_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cuequivariance-ops-cu12 has been flagged; see the registry listing for more details.

@@ -0,0 +1,340 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ import enum
+
+ import triton
+ import triton.language as tl
+
+
+ class Precision(enum.Enum):
+     DEFAULT = 0
+     TF32 = 1
+     TF32x3 = 2
+     IEEE = 3
+
+
+ @triton.jit
+ def fused_sigmoid_gated_dual_gemm_forward_kernel(
+     x1_ptr,
+     x2_ptr,
+     w1_ptr,
+     w2_ptr,
+     mask_ptr,
+     o_ptr,
+     M,
+     N,
+     K,
+     TILE_M: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_K: tl.constexpr,
+     PRECISION: tl.constexpr,
+     APPLY_MASK: tl.constexpr,
+     TRANSPOSE_OUT: tl.constexpr,
+     TWO_INPUTS: tl.constexpr,
+ ):
+     # fully gated GEMM kernel with optional mask at the end
+     pid_m = tl.program_id(axis=0)
+     pid_n = tl.program_id(axis=1)
+
+     start_m = pid_m * TILE_M
+     start_n = pid_n * TILE_N
+
+     offs_xm = start_m + tl.arange(0, TILE_M)
+     offs_wn = start_n + tl.arange(0, TILE_N)
+     offs_k = tl.arange(0, TILE_K)
+
+     x1_ptrs = x1_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+     if TWO_INPUTS:
+         x2_ptrs = x2_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+
+     w_tile_offs = offs_wn[None, :] * K + offs_k[:, None]
+
+     acc_1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+     acc_2 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+
+     mask_m = offs_xm < M
+
+     if TWO_INPUTS:
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x1 = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x1, w1, acc_1)
+             elif PRECISION == 1:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x2 = tl.load(x2_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w2_ptr.type.element_ty
+             )
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x2, w2, acc_2)
+             elif PRECISION == 1:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x2_ptrs += TILE_K
+             w2_ptr += TILE_K
+
+     else:
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x, w1, acc_1)
+             elif PRECISION == 1:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x, w2, acc_2)
+             elif PRECISION == 1:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+             w2_ptr += TILE_K
+
+     offs_om = pid_m * TILE_M + tl.arange(0, TILE_M)
+     offs_on = pid_n * TILE_N + tl.arange(0, TILE_N)
+
+     acc_1 = 1.0 / (1.0 + tl.exp(-acc_1))
+     acc_gated = acc_1 * acc_2
+
+     if APPLY_MASK:
+         mask = tl.load(mask_ptr + offs_om, mask=mask_m, other=0.0)
+         acc_gated = acc_gated * mask[:, None]
+
+     acc_gated = acc_gated.to(o_ptr.dtype.element_ty)
+
+     if TRANSPOSE_OUT:
+         o_ptrs = o_ptr + offs_on[None, :] * M + offs_om[:, None]
+     else:
+         o_ptrs = o_ptr + offs_om[:, None] * N + offs_on[None, :]
+
+     o_mask = offs_om[:, None] < M
+     tl.store(o_ptrs, acc_gated, mask=o_mask)
+
+
+ @triton.jit
+ def fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel(
+     grad_xw1_ptr,
+     grad_xw2_ptr,
+     grad_mask_ptr,
+     grad_o_ptr,
+     x1_ptr,
+     x2_ptr,
+     w1_ptr,
+     w2_ptr,
+     mask_ptr,
+     M,
+     N,
+     K,
+     TILE_M: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_K: tl.constexpr,
+     PRECISION: tl.constexpr,
+     APPLY_MASK: tl.constexpr,
+     TRANSPOSE_OUT: tl.constexpr,
+     TWO_INPUTS: tl.constexpr,
+ ):
+     # fully gated GEMM kernel with optional mask at the end
+     pid_m = tl.program_id(axis=0)
+     pid_n = tl.program_id(axis=1)
+
+     start_m = pid_m * TILE_M
+     start_n = pid_n * TILE_N
+
+     offs_xm = start_m + tl.arange(0, TILE_M)
+     offs_wn = start_n + tl.arange(0, TILE_N)
+     offs_k = tl.arange(0, TILE_K)
+
+     x1_ptrs = x1_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+     if TWO_INPUTS:
+         x2_ptrs = x2_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+     w_tile_offs = offs_wn[None, :] * K + offs_k[:, None]
+
+     acc_1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+     acc_2 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+
+     mask_m = offs_xm < M
+
+     if TWO_INPUTS:
+         # recompute acc1 and acc2
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x1 = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x1, w1, acc_1)
+             elif PRECISION == 1:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x2 = tl.load(x2_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w2_ptr.type.element_ty
+             )
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x2, w2, acc_2)
+             elif PRECISION == 1:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x2_ptrs += TILE_K
+             w2_ptr += TILE_K
+
+     else:
+         # recompute acc1 and acc2
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x, w1, acc_1)
+             elif PRECISION == 1:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x, w2, acc_2)
+             elif PRECISION == 1:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+             w2_ptr += TILE_K
+
+
+     offs_om = pid_m * TILE_M + tl.arange(0, TILE_M)
+     offs_on = pid_n * TILE_N + tl.arange(0, TILE_N)
+     if TRANSPOSE_OUT:
+         grad_o_ptrs = grad_o_ptr + offs_on[None, :] * M + offs_om[:, None]
+     else:
+         grad_o_ptrs = grad_o_ptr + offs_om[:, None] * N + offs_on[None, :]
+
+     grad_o = tl.load(grad_o_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
+
+     acc_sig = 1.0 / (1.0 + tl.exp(-acc_1))
+
+     if APPLY_MASK:
+         tmp = acc_sig * acc_2
+         grad_mask = grad_o * tmp
+         grad_mask = tl.sum(grad_mask, axis=1)
+         grad_mask_ptrs = grad_mask_ptr + pid_n * M + offs_om
+         tl.store(grad_mask_ptrs, grad_mask.to(grad_mask.type.element_ty), mask=mask_m)
+
+         mask = tl.load(mask_ptr + offs_om, mask=mask_m, other=0.0)
+         grad_o = grad_o * mask[:, None]
+
+     tmp = (1.0 - acc_sig) * acc_sig
+
+     grad_xw1 = grad_o * acc_2 * tmp
+     grad_xw2 = grad_o * acc_sig
+
+     grad_xw1_ptrs = grad_xw1_ptr + offs_om[:, None] * N + offs_on[None, :]
+     grad_xw2_ptrs = grad_xw2_ptr + offs_om[:, None] * N + offs_on[None, :]
+     tl.store(grad_xw1_ptrs, grad_xw1.to(grad_xw1.type.element_ty), mask=mask_m[:, None])
+     tl.store(grad_xw2_ptrs, grad_xw2.to(grad_xw2.type.element_ty), mask=mask_m[:, None])
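For reference, the forward kernel above computes a sigmoid-gated product of two GEMMs that share the same input rows: o = sigmoid(x1 @ w1^T) * (x2 @ w2^T), with an optional per-row mask and an optionally transposed output layout. A minimal eager-mode PyTorch sketch of the same computation (the function and argument names are illustrative, not part of the package; pass the same tensor as x1 and x2 to mirror the single-input TWO_INPUTS=False path):

import torch

def sigmoid_gated_dual_gemm_reference(x1, x2, w1, w2, mask=None, transpose_out=False):
    # w1 and w2 are stored [N, K] and read transposed, matching w_tile_offs above
    xw1 = x1.float() @ w1.float().t()          # acc_1 in the kernel
    xw2 = x2.float() @ w2.float().t()          # acc_2 in the kernel
    out = torch.sigmoid(xw1) * xw2             # sigmoid gate applied in the epilogue
    if mask is not None:                       # APPLY_MASK: per-row mask
        out = out * mask[:, None]
    return out.t() if transpose_out else out   # TRANSPOSE_OUT stores an [N, M] layout

The backward pre-GEMM kernel recomputes xw1 and xw2 tile by tile and writes grad_xw1 = grad_o * xw2 * s * (1 - s) and grad_xw2 = grad_o * s, with s = sigmoid(xw1), plus per-N-tile partial sums of grad_o * s * xw2 for the mask gradient; as the _pregemm suffix suggests, the remaining gradients with respect to the inputs and weights are presumably formed by separate GEMMs on these intermediates.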
Binary file
@@ -0,0 +1,328 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ import inspect
+
+ # import logging # Added logging import
+ import os
+ from enum import Enum
+ from typing import Any, Callable
+
+ import torch
+ from tqdm import tqdm
+
+ from .cache_manager import CacheManager
+
+ # Configure logging
+ # logger = logging.getLogger(__name__)
+
+
+ class BenchmarkMode(Enum):
+     FLUSH_CACHE = 0
+     FLUSH_CACHE_PEAK_PROXY = 1
+     ROT_BUFFER = 2
+     ROT_BUFFER_PEAK_PROXY = 3
+
+
+ def run_bench(
+     f, input_dict, warmup_iter=250, run_iter=250, bench_mode=BenchmarkMode.ROT_BUFFER
+ ):
+     initial_output = f(**input_dict)
+
+     if bench_mode in (BenchmarkMode.ROT_BUFFER, BenchmarkMode.ROT_BUFFER_PEAK_PROXY):
+         len_rot = 4
+         inputs_rot = [None] * len_rot
+         for r in range(len_rot):
+             r_inputs = []
+             for key, value in input_dict.items():
+                 if isinstance(value, torch.Tensor):
+                     if bench_mode == BenchmarkMode.ROT_BUFFER_PEAK_PROXY:
+                         r_inputs.append(
+                             (
+                                 key,
+                                 torch.ones_like(
+                                     value, requires_grad=value.requires_grad
+                                 ),
+                             )
+                         )
+                     else:
+                         r_inputs.append(
+                             (
+                                 key,
+                                 torch.randn_like(
+                                     value, requires_grad=value.requires_grad
+                                 ),
+                             )
+                         )
+                 else:
+                     r_inputs.append((key, value))
+             r_inputs = dict(r_inputs)
+             inputs_rot[r] = r_inputs
+
+         for it in range(warmup_iter):
+             _ = f(**inputs_rot[it % len_rot])
+
+         start = torch.cuda.Event(enable_timing=True)
+         end = torch.cuda.Event(enable_timing=True)
+         start.record()
+         for it in range(run_iter):
+             _ = f(**inputs_rot[it % len_rot])
+         end.record()
+         torch.cuda.synchronize()
+         elapsed = start.elapsed_time(end)
+
+     elif bench_mode in (
+         BenchmarkMode.FLUSH_CACHE,
+         BenchmarkMode.FLUSH_CACHE_PEAK_PROXY,
+     ):
+         cache_filler = torch.empty(1024 * 1024 * 256, dtype=torch.int8, device="cuda")
+
+         if bench_mode == BenchmarkMode.FLUSH_CACHE_PEAK_PROXY:
+             _inputs = {}
+             for key, value in input_dict.items():
+                 if isinstance(value, torch.Tensor):
+                     _inputs.append(
+                         (key, torch.ones_like(value, requires_grad=value.requires_grad))
+                     )
+                 else:
+                     _inputs.append((key, value))
+             input_dict = _inputs
+
+         for _ in range(warmup_iter):
+             cache_filler.zero_()
+             _ = f(**input_dict)
+
+         starts = [torch.cuda.Event(enable_timing=True) for _ in range(run_iter)]
+         ends = [torch.cuda.Event(enable_timing=True) for _ in range(run_iter)]
+         for i in range(run_iter):
+             cache_filler.zero_()
+             starts[i].record()
+             _ = f(**input_dict)
+             ends[i].record()
+         torch.cuda.synchronize()
+         elapsed = sum(s.elapsed_time(e) for s, e in zip(starts, ends))
+
+     return elapsed / run_iter, initial_output
+
+
+ def input_to_key_default(**args) -> str:
+     key_parts = []
+     for arg in args:
+         if isinstance(arg, torch.Tensor):
+             key_parts.append(f"{list(arg.shape)}_{arg.dtype}")
+         elif isinstance(arg, bool):
+             key_parts.append("True" if arg else "False")
+         elif isinstance(arg, str):
+             key_parts.append(arg)
+         else:
+             key_parts.append(str(arg.__class__.__name__))
+
+     return "_".join(key_parts)
+
+
+ def combine_all_kwargs(
+     fn: Callable,
+     args: tuple,
+     kwargs: dict[str, Any],
+ ) -> dict[str, Any]:
+     # Get the function signature
+     sig = inspect.signature(fn)
+     params = sig.parameters
+     param_names = list(params.keys())
+
+     # Create dictionary of default values
+     defaults = {
+         name: param.default
+         for name, param in params.items()
+         if param.default is not inspect.Parameter.empty
+     }
+     # Create dictionary mapping positional args to parameter names
+     args_as_kwargs = {
+         param_names[i]: args[i] for i in range(min(len(args), len(param_names)))
+     }
+     # Create combined dictionary of all parameters
+     all_kwargs = defaults.copy()  # Start with defaults
+     all_kwargs.update(args_as_kwargs)  # Override with positional args
+     all_kwargs.update(kwargs)  # Override with explicit kwargs
+
+     return all_kwargs
+
+
+ def autotune_aot(
+     input_generator: Callable,
+     input_to_key: Callable | None,
+     input_configs: list[dict[str, Any]],
+     tunable_configs: list[dict[str, Any]],
+     prune_configs_fn: Callable[
+         [list[dict[str, Any]], dict[str, Any]], list[dict[str, Any]]
+     ]
+     | None,
+     bench_mode=BenchmarkMode.ROT_BUFFER,
+     warmup_iter=25,
+     run_iter=100,
+ ) -> None:
+     def decorator(fn: Callable) -> Callable:
+         def wrapper(*args, **kwargs):
+             all_kwargs = combine_all_kwargs(fn, args, kwargs)
+             nonlocal input_to_key
+             nonlocal input_configs
+             # Early exit if AOT is disabled
+             if os.environ.get("CUEQ_DISABLE_AOT_TUNING") == "1":
+                 if input_to_key is None:
+                     input_to_key = input_to_key_default
+
+                 function_key = fn.__name__
+                 cache_manager = CacheManager()
+                 input_key = input_to_key(**all_kwargs)
+
+                 try:
+                     best_cached_config = cache_manager.get(function_key, input_key)
+                     if best_cached_config is not None:
+                         all_kwargs.update(best_cached_config["config"])
+                         return fn(**all_kwargs)
+                     elif os.environ.get("CUEQ_DEFAULT_CONFIG") != "1":
+                         input_configs = [None]
+                         # Continue to rest of function for tuning
+                         pass
+                 except (PermissionError, OSError):
+                     pass
+
+             if os.environ.get("CUEQ_DEFAULT_CONFIG") == "1":
+                 return fn(**all_kwargs)
+
+             if input_to_key is None:
+                 input_to_key = input_to_key_default
+
+             # Check if the function is already cached
+             function_key = fn.__name__
+             cache_manager = CacheManager()
+             input_key = input_to_key(**all_kwargs)
+
+             try:
+                 best_cached_config = cache_manager.get(function_key, input_key)
+             except (PermissionError, OSError):
+                 # If there's a permission error, fall back to JIT compilation
+                 # logger.warning(f"Permission error accessing cache: {e}. Falling back to JIT compilation.")
+                 return fn(**all_kwargs)
+
+             if best_cached_config is not None:
+                 # If cached, return the cached result
+                 all_kwargs.update(best_cached_config["config"])
+                 result = fn(**all_kwargs)
+                 return result
+             else:
+                 # start autotuning process
+                 input_configs = input_configs + [None]
+
+                 try:
+                     # Initialize the progress bar
+                     progress_bar = tqdm(
+                         input_configs, desc="Autotuning Progress", unit="config"
+                     )
+
+                     for input_config in progress_bar:
+                         # generate input based on the config
+                         input_data = (
+                             input_generator(**input_config)
+                             if input_config is not None
+                             else all_kwargs
+                         )
+
+                         # Make a copy of all_kwargs to avoid modifying the original
+                         current_kwargs = all_kwargs.copy()
+                         current_kwargs.update(input_data)
+                         current_input_key = input_to_key(**current_kwargs)
+
+                         try:
+                             best_cached_config = cache_manager.get(
+                                 function_key, current_input_key
+                             )
+                         except (PermissionError, OSError):
+                             # If there's a permission error, skip this config and continue
+                             # logger.warning(f"Permission error accessing cache for config: {e}. Skipping.")
+                             continue
+
+                         if best_cached_config is not None:
+                             continue
+
+                         # prune the tunable configs based on the all_kwargs
+                         pruned_tunable_configs = (
+                             prune_configs_fn(tunable_configs, **all_kwargs)
+                             if prune_configs_fn is not None
+                             else tunable_configs
+                         )
+
+                         best_config = None
+                         best_time = float("inf")
+                         working_config = []
+                         for tunable in pruned_tunable_configs:
+                             try:
+                                 current_kwargs.update(tunable)
+                                 fn(**current_kwargs)
+                                 torch.cuda.synchronize()
+                                 working_config.append(tunable)
+                             except Exception:
+                                 pass
+
+                         if not working_config:
+                             continue
+
+                         for tunable in working_config:
+                             current_kwargs.update(tunable)
+                             elapse, _ = run_bench(
+                                 fn,
+                                 current_kwargs,
+                                 warmup_iter=warmup_iter,
+                                 run_iter=run_iter,
+                                 bench_mode=bench_mode,
+                             )
+                             if elapse < best_time:
+                                 best_time = elapse
+                                 best_config = tunable
+
+                         try:
+                             cache_manager.set(
+                                 function_key,
+                                 current_input_key,
+                                 {"config": best_config, "time": best_time},
+                             )
+                         except (PermissionError, OSError):
+                             # If there's a permission error saving the cache, continue without saving
+                             # logger.warning(f"Permission error saving cache: {e}. Continuing without saving.")
+                             pass
+
+                     try:
+                         # Save the cache to a file
+                         cache_manager.save_cache(function_key)
+                     except (PermissionError, OSError):
+                         # If there's a permission error saving the final cache, continue without saving
+                         # logger.warning(f"Permission error saving final cache: {e}. Continuing without saving.")
+                         pass
+
+                     # After tuning, try to get the best config
+                     try:
+                         best_cached_config = cache_manager.get(function_key, input_key)
+                         if best_cached_config is not None:
+                             all_kwargs.update(best_cached_config["config"])
+                     except (PermissionError, OSError):
+                         # If there's a permission error getting the final config, continue without it
+                         # logger.warning(f"Permission error getting final config: {e}. Continuing without config.")
+                         pass
+
+                 except Exception:
+                     # If any other error occurs during tuning, fall back to JIT compilation
+                     # logger.warning(f"Error during tuning: {e}. Falling back to JIT compilation.")
+                     pass
+
+                 return fn(**all_kwargs)
+
+         return wrapper
+
+     return decorator
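The second new module implements ahead-of-time autotuning rather than relying on Triton's JIT autotuner: autotune_aot benchmarks every entry of tunable_configs with run_bench, stores the fastest one per input key through CacheManager, and honors two environment switches (CUEQ_DEFAULT_CONFIG=1 bypasses tuning altogether, while CUEQ_DISABLE_AOT_TUNING=1 restricts tuning to the inputs of the current call). A hypothetical usage sketch, assuming a caller-defined wrapper whose tunable knobs are keyword arguments (matmul_fp16, tile_m, and the shapes below are illustrative, not part of the package):

import torch

@autotune_aot(
    input_generator=lambda m, k: {
        "x": torch.randn(m, k, device="cuda", dtype=torch.float16),
        "w": torch.randn(k, k, device="cuda", dtype=torch.float16),
    },
    input_to_key=None,                      # falls back to input_to_key_default
    input_configs=[{"m": 1024, "k": 128}],  # problem sizes to pre-tune
    tunable_configs=[{"tile_m": 64}, {"tile_m": 128}],
    prune_configs_fn=None,
    bench_mode=BenchmarkMode.ROT_BUFFER,
    warmup_iter=25,
    run_iter=100,
)
def matmul_fp16(x, w, tile_m=64):
    ...  # launch a Triton kernel, using tile_m as a meta-parameter

Because each tunable config is merged into the call's keyword arguments before fn is re-invoked, every key in tunable_configs must be a real parameter of the decorated function; configs that raise during the initial probe are discarded before timing.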
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: cuequivariance-ops-cu12
- Version: 0.4.0
+ Version: 0.5.0
  Summary: cuequivariance-ops - GPU Accelerated Extensions for Equivariant Primitives
  Author: NVIDIA Corporation
  License: # Software License Agreement
@@ -177,6 +177,9 @@ Classifier: Programming Language :: Python
  Project-URL: Homepage, https://github.com/nvidia/cuEquivariance
  Project-URL: Documentation, https://github.com/nvidia/cuEquivariance
  Requires-Python: >=3.10
+ Requires-Dist: nvidia-cublas-cu12>=12.5.0
+ Requires-Dist: tqdm
+ Requires-Dist: pynvml
  Provides-Extra: test
  Requires-Dist: numpy; extra == "test"
  Requires-Dist: pytest; extra == "test"
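The metadata diff adds three unconditional runtime dependencies: nvidia-cublas-cu12 (>= 12.5.0), tqdm (used by the autotuning progress bar above), and pynvml. A quick way to confirm they resolved in a given environment after installing 0.5.0 (this snippet is illustrative and not shipped with the package):

from importlib.metadata import version

# Distribution names taken from the new Requires-Dist lines above.
for dist in ("nvidia-cublas-cu12", "tqdm", "pynvml"):
    print(dist, version(dist))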