cuequivariance-ops-cu12 0.4.0__py3-none-manylinux_2_39_aarch64.whl → 0.5.0__py3-none-manylinux_2_39_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cuequivariance-ops-cu12 might be problematic.
- cuequivariance_ops/VERSION +1 -1
- cuequivariance_ops/__init__.py +3 -2
- cuequivariance_ops/cache_manager.py +130 -0
- cuequivariance_ops/equivariance/dtypes.hh +21 -0
- cuequivariance_ops/equivariance/indexed_linear.hh +36 -0
- cuequivariance_ops/equivariance/run_fmha.h +192 -0
- cuequivariance_ops/equivariance/run_fmha_cudafree.h +77 -0
- cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +17 -35
- cuequivariance_ops/fused_layer_norm_triton.py +324 -0
- cuequivariance_ops/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.json +222844 -0
- cuequivariance_ops/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.json +326932 -0
- cuequivariance_ops/gated_gemm_triton.py +340 -0
- cuequivariance_ops/lib/libcue_ops.so +0 -0
- cuequivariance_ops/tuning_decorator.py +328 -0
- {cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.0.dist-info}/METADATA +4 -1
- cuequivariance_ops_cu12-0.5.0.dist-info/RECORD +23 -0
- {cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.0.dist-info}/WHEEL +1 -1
- cuequivariance_ops_cu12-0.4.0.dist-info/RECORD +0 -13
- {cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.0.dist-info}/licenses/LICENSE +0 -0
cuequivariance_ops/gated_gemm_triton.py ADDED

@@ -0,0 +1,340 @@

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import enum

import triton
import triton.language as tl


class Precision(enum.Enum):
    DEFAULT = 0
    TF32 = 1
    TF32x3 = 2
    IEEE = 3


@triton.jit
def fused_sigmoid_gated_dual_gemm_forward_kernel(
    x1_ptr,
    x2_ptr,
    w1_ptr,
    w2_ptr,
    mask_ptr,
    o_ptr,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    PRECISION: tl.constexpr,
    APPLY_MASK: tl.constexpr,
    TRANSPOSE_OUT: tl.constexpr,
    TWO_INPUTS: tl.constexpr,
):
    # fully gated GEMM kernel with optional mask at the end
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    start_m = pid_m * TILE_M
    start_n = pid_n * TILE_N

    offs_xm = start_m + tl.arange(0, TILE_M)
    offs_wn = start_n + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    x1_ptrs = x1_ptr + (offs_xm[:, None] * K + offs_k[None, :])
    if TWO_INPUTS:
        x2_ptrs = x2_ptr + (offs_xm[:, None] * K + offs_k[None, :])

    w_tile_offs = offs_wn[None, :] * K + offs_k[:, None]

    acc_1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    acc_2 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)

    mask_m = offs_xm < M

    if TWO_INPUTS:
        for _ in range(0, tl.cdiv(K, TILE_K)):
            x1 = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
                w1_ptr.type.element_ty
            )
            w1_ptrs = w1_ptr + w_tile_offs
            w1 = tl.load(w1_ptrs)
            if PRECISION == 0:
                acc_1 = tl.dot(x1, w1, acc_1)
            elif PRECISION == 1:
                acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32")
            elif PRECISION == 2:
                acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32x3")
            elif PRECISION == 3:
                acc_1 = tl.dot(x1, w1, acc_1, input_precision="ieee")
            else:
                tl.static_assert(
                    False,
                    "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
                )

            x1_ptrs += TILE_K
            w1_ptr += TILE_K

        for _ in range(0, tl.cdiv(K, TILE_K)):
            x2 = tl.load(x2_ptrs, mask=mask_m[:, None], other=0.0).to(
                w2_ptr.type.element_ty
            )
            w2_ptrs = w2_ptr + w_tile_offs
            w2 = tl.load(w2_ptrs)
            if PRECISION == 0:
                acc_2 = tl.dot(x2, w2, acc_2)
            elif PRECISION == 1:
                acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32")
            elif PRECISION == 2:
                acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32x3")
            elif PRECISION == 3:
                acc_2 = tl.dot(x2, w2, acc_2, input_precision="ieee")
            else:
                tl.static_assert(
                    False,
                    "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
                )

            x2_ptrs += TILE_K
            w2_ptr += TILE_K

    else:
        for _ in range(0, tl.cdiv(K, TILE_K)):
            x = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
                w1_ptr.type.element_ty
            )

            w1_ptrs = w1_ptr + w_tile_offs
            w1 = tl.load(w1_ptrs)
            if PRECISION == 0:
                acc_1 = tl.dot(x, w1, acc_1)
            elif PRECISION == 1:
                acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32")
            elif PRECISION == 2:
                acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32x3")
            elif PRECISION == 3:
                acc_1 = tl.dot(x, w1, acc_1, input_precision="ieee")
            else:
                tl.static_assert(
                    False,
                    "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
                )

            w2_ptrs = w2_ptr + w_tile_offs
            w2 = tl.load(w2_ptrs)
            if PRECISION == 0:
                acc_2 = tl.dot(x, w2, acc_2)
            elif PRECISION == 1:
                acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32")
            elif PRECISION == 2:
                acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32x3")
            elif PRECISION == 3:
                acc_2 = tl.dot(x, w2, acc_2, input_precision="ieee")
            else:
                tl.static_assert(
                    False,
                    "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
                )

            x1_ptrs += TILE_K
            w1_ptr += TILE_K
            w2_ptr += TILE_K

    offs_om = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_on = pid_n * TILE_N + tl.arange(0, TILE_N)

    acc_1 = 1.0 / (1.0 + tl.exp(-acc_1))
    acc_gated = acc_1 * acc_2

    if APPLY_MASK:
        mask = tl.load(mask_ptr + offs_om, mask=mask_m, other=0.0)
        acc_gated = acc_gated * mask[:, None]

    acc_gated = acc_gated.to(o_ptr.dtype.element_ty)

    if TRANSPOSE_OUT:
        o_ptrs = o_ptr + offs_on[None, :] * M + offs_om[:, None]
    else:
        o_ptrs = o_ptr + offs_om[:, None] * N + offs_on[None, :]

    o_mask = offs_om[:, None] < M
    tl.store(o_ptrs, acc_gated, mask=o_mask)


@triton.jit
def fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel(
    grad_xw1_ptr,
    grad_xw2_ptr,
    grad_mask_ptr,
    grad_o_ptr,
    x1_ptr,
    x2_ptr,
    w1_ptr,
    w2_ptr,
    mask_ptr,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    PRECISION: tl.constexpr,
    APPLY_MASK: tl.constexpr,
    TRANSPOSE_OUT: tl.constexpr,
    TWO_INPUTS: tl.constexpr,
):
    # fully gated GEMM kernel with optional mask at the end
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    start_m = pid_m * TILE_M
    start_n = pid_n * TILE_N

    offs_xm = start_m + tl.arange(0, TILE_M)
    offs_wn = start_n + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    x1_ptrs = x1_ptr + (offs_xm[:, None] * K + offs_k[None, :])
    if TWO_INPUTS:
        x2_ptrs = x2_ptr + (offs_xm[:, None] * K + offs_k[None, :])
    w_tile_offs = offs_wn[None, :] * K + offs_k[:, None]

    acc_1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    acc_2 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)

    mask_m = offs_xm < M

    if TWO_INPUTS:
        # recompute acc1 and acc2
        for _ in range(0, tl.cdiv(K, TILE_K)):
            x1 = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
                w1_ptr.type.element_ty
            )
            w1_ptrs = w1_ptr + w_tile_offs
            w1 = tl.load(w1_ptrs)

            if PRECISION == 0:
                acc_1 = tl.dot(x1, w1, acc_1)
            elif PRECISION == 1:
                acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32")
            elif PRECISION == 2:
                acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32x3")
            elif PRECISION == 3:
                acc_1 = tl.dot(x1, w1, acc_1, input_precision="ieee")
            else:
                tl.static_assert(
                    False,
                    "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
                )

            x1_ptrs += TILE_K
            w1_ptr += TILE_K

        for _ in range(0, tl.cdiv(K, TILE_K)):
            x2 = tl.load(x2_ptrs, mask=mask_m[:, None], other=0.0).to(
                w2_ptr.type.element_ty
            )
            w2_ptrs = w2_ptr + w_tile_offs
            w2 = tl.load(w2_ptrs)

            if PRECISION == 0:
                acc_2 = tl.dot(x2, w2, acc_2)
            elif PRECISION == 1:
                acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32")
            elif PRECISION == 2:
                acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32x3")
            elif PRECISION == 3:
                acc_2 = tl.dot(x2, w2, acc_2, input_precision="ieee")
            else:
                tl.static_assert(
                    False,
                    "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
                )

            x2_ptrs += TILE_K
            w2_ptr += TILE_K

    else:
        # recompute acc1 and acc2
        for _ in range(0, tl.cdiv(K, TILE_K)):
            x = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
                w1_ptr.type.element_ty
            )

            w1_ptrs = w1_ptr + w_tile_offs
            w1 = tl.load(w1_ptrs)
            if PRECISION == 0:
                acc_1 = tl.dot(x, w1, acc_1)
            elif PRECISION == 1:
                acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32")
            elif PRECISION == 2:
                acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32x3")
            elif PRECISION == 3:
                acc_1 = tl.dot(x, w1, acc_1, input_precision="ieee")
            else:
                tl.static_assert(
                    False,
                    "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
                )

            w2_ptrs = w2_ptr + w_tile_offs
            w2 = tl.load(w2_ptrs)
            if PRECISION == 0:
                acc_2 = tl.dot(x, w2, acc_2)
            elif PRECISION == 1:
                acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32")
            elif PRECISION == 2:
                acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32x3")
            elif PRECISION == 3:
                acc_2 = tl.dot(x, w2, acc_2, input_precision="ieee")
            else:
                tl.static_assert(
                    False,
                    "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
                )

            x1_ptrs += TILE_K
            w1_ptr += TILE_K
            w2_ptr += TILE_K

    offs_om = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_on = pid_n * TILE_N + tl.arange(0, TILE_N)
    if TRANSPOSE_OUT:
        grad_o_ptrs = grad_o_ptr + offs_on[None, :] * M + offs_om[:, None]
    else:
        grad_o_ptrs = grad_o_ptr + offs_om[:, None] * N + offs_on[None, :]

    grad_o = tl.load(grad_o_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)

    acc_sig = 1.0 / (1.0 + tl.exp(-acc_1))

    if APPLY_MASK:
        tmp = acc_sig * acc_2
        grad_mask = grad_o * tmp
        grad_mask = tl.sum(grad_mask, axis=1)
        grad_mask_ptrs = grad_mask_ptr + pid_n * M + offs_om
        tl.store(grad_mask_ptrs, grad_mask.to(grad_mask.type.element_ty), mask=mask_m)

        mask = tl.load(mask_ptr + offs_om, mask=mask_m, other=0.0)
        grad_o = grad_o * mask[:, None]

    tmp = (1.0 - acc_sig) * acc_sig

    grad_xw1 = grad_o * acc_2 * tmp
    grad_xw2 = grad_o * acc_sig

    grad_xw1_ptrs = grad_xw1_ptr + offs_om[:, None] * N + offs_on[None, :]
    grad_xw2_ptrs = grad_xw2_ptr + offs_om[:, None] * N + offs_on[None, :]
    tl.store(grad_xw1_ptrs, grad_xw1.to(grad_xw1.type.element_ty), mask=mask_m[:, None])
    tl.store(grad_xw2_ptrs, grad_xw2.to(grad_xw2.type.element_ty), mask=mask_m[:, None])
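In plain terms, the forward kernel computes sigmoid(x1 @ w1.T) * (x2 @ w2.T) with fp32 accumulation, an optional per-row mask, and an optionally transposed (N, M) output layout; when TWO_INPUTS is false the same activation feeds both GEMMs. The backward pre-GEMM kernel recomputes both accumulators and writes grad_xw1 = grad_o * acc_2 * s * (1 - s) and grad_xw2 = grad_o * s with s = sigmoid(acc_1), plus a per-row grad_mask reduction when the mask is applied. The following PyTorch reference of the forward math is an illustrative sketch only: the actual host-side launcher is part of gated_gemm_triton.py but is not shown in this excerpt, and the function name below is hypothetical.

import torch


def reference_fused_sigmoid_gated_dual_gemm(
    x1, w1, w2, x2=None, mask=None, transpose_out=False
):
    # Hypothetical reference, not part of the package. x1/x2: (M, K); w1/w2: (N, K),
    # indexed as w[n, k] in the kernel, so each GEMM is effectively x @ w.T.
    if x2 is None:  # TWO_INPUTS == False reuses the same activation for both GEMMs
        x2 = x1
    acc_1 = x1.float() @ w1.float().t()  # gate logits, fp32 accumulation like the kernel
    acc_2 = x2.float() @ w2.float().t()  # gated branch
    out = torch.sigmoid(acc_1) * acc_2  # sigmoid gate
    if mask is not None:  # APPLY_MASK: one scalar per row of x
        out = out * mask.float()[:, None]
    if transpose_out:  # TRANSPOSE_OUT stores an (N, M) layout
        out = out.t().contiguous()
    return out.to(x1.dtype)

A reference like this is mainly useful for validating the kernel numerically at small sizes, for example with PRECISION=3 (ieee), where tl.dot should track an fp32 matmul closely.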
cuequivariance_ops/lib/libcue_ops.so CHANGED

Binary file
cuequivariance_ops/tuning_decorator.py ADDED

@@ -0,0 +1,328 @@

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import inspect

# import logging # Added logging import
import os
from enum import Enum
from typing import Any, Callable

import torch
from tqdm import tqdm

from .cache_manager import CacheManager

# Configure logging
# logger = logging.getLogger(__name__)


class BenchmarkMode(Enum):
    FLUSH_CACHE = 0
    FLUSH_CACHE_PEAK_PROXY = 1
    ROT_BUFFER = 2
    ROT_BUFFER_PEAK_PROXY = 3


def run_bench(
    f, input_dict, warmup_iter=250, run_iter=250, bench_mode=BenchmarkMode.ROT_BUFFER
):
    initial_output = f(**input_dict)

    if bench_mode in (BenchmarkMode.ROT_BUFFER, BenchmarkMode.ROT_BUFFER_PEAK_PROXY):
        len_rot = 4
        inputs_rot = [None] * len_rot
        for r in range(len_rot):
            r_inputs = []
            for key, value in input_dict.items():
                if isinstance(value, torch.Tensor):
                    if bench_mode == BenchmarkMode.ROT_BUFFER_PEAK_PROXY:
                        r_inputs.append(
                            (
                                key,
                                torch.ones_like(
                                    value, requires_grad=value.requires_grad
                                ),
                            )
                        )
                    else:
                        r_inputs.append(
                            (
                                key,
                                torch.randn_like(
                                    value, requires_grad=value.requires_grad
                                ),
                            )
                        )
                else:
                    r_inputs.append((key, value))
            r_inputs = dict(r_inputs)
            inputs_rot[r] = r_inputs

        for it in range(warmup_iter):
            _ = f(**inputs_rot[it % len_rot])

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for it in range(run_iter):
            _ = f(**inputs_rot[it % len_rot])
        end.record()
        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end)

    elif bench_mode in (
        BenchmarkMode.FLUSH_CACHE,
        BenchmarkMode.FLUSH_CACHE_PEAK_PROXY,
    ):
        cache_filler = torch.empty(1024 * 1024 * 256, dtype=torch.int8, device="cuda")

        if bench_mode == BenchmarkMode.FLUSH_CACHE_PEAK_PROXY:
            _inputs = {}
            for key, value in input_dict.items():
                if isinstance(value, torch.Tensor):
                    _inputs.append(
                        (key, torch.ones_like(value, requires_grad=value.requires_grad))
                    )
                else:
                    _inputs.append((key, value))
            input_dict = _inputs

        for _ in range(warmup_iter):
            cache_filler.zero_()
            _ = f(**input_dict)

        starts = [torch.cuda.Event(enable_timing=True) for _ in range(run_iter)]
        ends = [torch.cuda.Event(enable_timing=True) for _ in range(run_iter)]
        for i in range(run_iter):
            cache_filler.zero_()
            starts[i].record()
            _ = f(**input_dict)
            ends[i].record()
        torch.cuda.synchronize()
        elapsed = sum(s.elapsed_time(e) for s, e in zip(starts, ends))

    return elapsed / run_iter, initial_output


def input_to_key_default(**args) -> str:
    key_parts = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            key_parts.append(f"{list(arg.shape)}_{arg.dtype}")
        elif isinstance(arg, bool):
            key_parts.append("True" if arg else "False")
        elif isinstance(arg, str):
            key_parts.append(arg)
        else:
            key_parts.append(str(arg.__class__.__name__))

    return "_".join(key_parts)


def combine_all_kwargs(
    fn: Callable,
    args: tuple,
    kwargs: dict[str, Any],
) -> dict[str, Any]:
    # Get the function signature
    sig = inspect.signature(fn)
    params = sig.parameters
    param_names = list(params.keys())

    # Create dictionary of default values
    defaults = {
        name: param.default
        for name, param in params.items()
        if param.default is not inspect.Parameter.empty
    }
    # Create dictionary mapping positional args to parameter names
    args_as_kwargs = {
        param_names[i]: args[i] for i in range(min(len(args), len(param_names)))
    }
    # Create combined dictionary of all parameters
    all_kwargs = defaults.copy()  # Start with defaults
    all_kwargs.update(args_as_kwargs)  # Override with positional args
    all_kwargs.update(kwargs)  # Override with explicit kwargs

    return all_kwargs


def autotune_aot(
    input_generator: Callable,
    input_to_key: Callable | None,
    input_configs: list[dict[str, Any]],
    tunable_configs: list[dict[str, Any]],
    prune_configs_fn: Callable[
        [list[dict[str, Any]], dict[str, Any]], list[dict[str, Any]]
    ]
    | None,
    bench_mode=BenchmarkMode.ROT_BUFFER,
    warmup_iter=25,
    run_iter=100,
) -> None:
    def decorator(fn: Callable) -> Callable:
        def wrapper(*args, **kwargs):
            all_kwargs = combine_all_kwargs(fn, args, kwargs)
            nonlocal input_to_key
            nonlocal input_configs
            # Early exit if AOT is disabled
            if os.environ.get("CUEQ_DISABLE_AOT_TUNING") == "1":
                if input_to_key is None:
                    input_to_key = input_to_key_default

                function_key = fn.__name__
                cache_manager = CacheManager()
                input_key = input_to_key(**all_kwargs)

                try:
                    best_cached_config = cache_manager.get(function_key, input_key)
                    if best_cached_config is not None:
                        all_kwargs.update(best_cached_config["config"])
                        return fn(**all_kwargs)
                    elif os.environ.get("CUEQ_DEFAULT_CONFIG") != "1":
                        input_configs = [None]
                        # Continue to rest of function for tuning
                        pass
                except (PermissionError, OSError):
                    pass

            if os.environ.get("CUEQ_DEFAULT_CONFIG") == "1":
                return fn(**all_kwargs)

            if input_to_key is None:
                input_to_key = input_to_key_default

            # Check if the function is already cached
            function_key = fn.__name__
            cache_manager = CacheManager()
            input_key = input_to_key(**all_kwargs)

            try:
                best_cached_config = cache_manager.get(function_key, input_key)
            except (PermissionError, OSError):
                # If there's a permission error, fall back to JIT compilation
                # logger.warning(f"Permission error accessing cache: {e}. Falling back to JIT compilation.")
                return fn(**all_kwargs)

            if best_cached_config is not None:
                # If cached, return the cached result
                all_kwargs.update(best_cached_config["config"])
                result = fn(**all_kwargs)
                return result
            else:
                # start autotuning process
                input_configs = input_configs + [None]

                try:
                    # Initialize the progress bar
                    progress_bar = tqdm(
                        input_configs, desc="Autotuning Progress", unit="config"
                    )

                    for input_config in progress_bar:
                        # generate input based on the config
                        input_data = (
                            input_generator(**input_config)
                            if input_config is not None
                            else all_kwargs
                        )

                        # Make a copy of all_kwargs to avoid modifying the original
                        current_kwargs = all_kwargs.copy()
                        current_kwargs.update(input_data)
                        current_input_key = input_to_key(**current_kwargs)

                        try:
                            best_cached_config = cache_manager.get(
                                function_key, current_input_key
                            )
                        except (PermissionError, OSError):
                            # If there's a permission error, skip this config and continue
                            # logger.warning(f"Permission error accessing cache for config: {e}. Skipping.")
                            continue

                        if best_cached_config is not None:
                            continue

                        # prune the tunable configs based on the all_kwargs
                        pruned_tunable_configs = (
                            prune_configs_fn(tunable_configs, **all_kwargs)
                            if prune_configs_fn is not None
                            else tunable_configs
                        )

                        best_config = None
                        best_time = float("inf")
                        working_config = []
                        for tunable in pruned_tunable_configs:
                            try:
                                current_kwargs.update(tunable)
                                fn(**current_kwargs)
                                torch.cuda.synchronize()
                                working_config.append(tunable)
                            except Exception:
                                pass

                        if not working_config:
                            continue

                        for tunable in working_config:
                            current_kwargs.update(tunable)
                            elapse, _ = run_bench(
                                fn,
                                current_kwargs,
                                warmup_iter=warmup_iter,
                                run_iter=run_iter,
                                bench_mode=bench_mode,
                            )
                            if elapse < best_time:
                                best_time = elapse
                                best_config = tunable

                        try:
                            cache_manager.set(
                                function_key,
                                current_input_key,
                                {"config": best_config, "time": best_time},
                            )
                        except (PermissionError, OSError):
                            # If there's a permission error saving the cache, continue without saving
                            # logger.warning(f"Permission error saving cache: {e}. Continuing without saving.")
                            pass

                    try:
                        # Save the cache to a file
                        cache_manager.save_cache(function_key)
                    except (PermissionError, OSError):
                        # If there's a permission error saving the final cache, continue without saving
                        # logger.warning(f"Permission error saving final cache: {e}. Continuing without saving.")
                        pass

                    # After tuning, try to get the best config
                    try:
                        best_cached_config = cache_manager.get(function_key, input_key)
                        if best_cached_config is not None:
                            all_kwargs.update(best_cached_config["config"])
                    except (PermissionError, OSError):
                        # If there's a permission error getting the final config, continue without it
                        # logger.warning(f"Permission error getting final config: {e}. Continuing without config.")
                        pass

                except Exception:
                    # If any other error occurs during tuning, fall back to JIT compilation
                    # logger.warning(f"Error during tuning: {e}. Falling back to JIT compilation.")
                    pass

                return fn(**all_kwargs)

        return wrapper

    return decorator
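The decorator above collects the wrapped function's full keyword set via combine_all_kwargs, derives a cache key from it, and either replays the best cached config or benchmarks every tunable config that runs successfully (run_bench), persisting the winner through CacheManager. Setting CUEQ_DEFAULT_CONFIG=1 bypasses tuning entirely, and CUEQ_DISABLE_AOT_TUNING=1 restricts tuning to the current call's inputs instead of the ahead-of-time input_configs. A hedged usage sketch follows; the wrapped function, its tile parameters, and the config lists are illustrative placeholders rather than APIs taken from this package:

import torch

from cuequivariance_ops.tuning_decorator import BenchmarkMode, autotune_aot


def make_inputs(M: int, K: int) -> dict:
    # Called as input_generator(**input_config) to build representative inputs ahead of time.
    return {
        "x": torch.randn(M, K, device="cuda", dtype=torch.float16),
        "w": torch.randn(K, K, device="cuda", dtype=torch.float16),
    }


@autotune_aot(
    input_generator=make_inputs,
    input_to_key=None,  # None falls back to input_to_key_default
    input_configs=[{"M": 4096, "K": 128}, {"M": 16384, "K": 128}],  # shapes to pre-tune
    tunable_configs=[
        {"TILE_M": 64, "TILE_N": 64, "TILE_K": 32},
        {"TILE_M": 128, "TILE_N": 64, "TILE_K": 32},
    ],
    prune_configs_fn=None,
    bench_mode=BenchmarkMode.ROT_BUFFER,
)
def my_matmul(x, w, TILE_M=64, TILE_N=64, TILE_K=32):
    # Placeholder body; a real wrapper would launch a Triton kernel with these tile sizes.
    return x @ w

Because the best config is stored under fn.__name__ plus the input key (tensor shapes and dtypes by default), later calls with the same shapes reuse it without re-benchmarking.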
{cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.0.dist-info}/METADATA RENAMED

@@ -1,6 +1,6 @@

  Metadata-Version: 2.2
  Name: cuequivariance-ops-cu12
- Version: 0.4.0
+ Version: 0.5.0
  Summary: cuequivariance-ops - GPU Accelerated Extensions for Equivariant Primitives
  Author: NVIDIA Corporation
  License: # Software License Agreement

@@ -177,6 +177,9 @@ Classifier: Programming Language :: Python

  Project-URL: Homepage, https://github.com/nvidia/cuEquivariance
  Project-URL: Documentation, https://github.com/nvidia/cuEquivariance
  Requires-Python: >=3.10
+ Requires-Dist: nvidia-cublas-cu12>=12.5.0
+ Requires-Dist: tqdm
+ Requires-Dist: pynvml
  Provides-Extra: test
  Requires-Dist: numpy; extra == "test"
  Requires-Dist: pytest; extra == "test"