aimnet-0.0.1-py3-none-any.whl → aimnet-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +7 -0
- aimnet/base.py +24 -8
- aimnet/calculators/__init__.py +4 -4
- aimnet/calculators/aimnet2ase.py +19 -6
- aimnet/calculators/calculator.py +868 -108
- aimnet/calculators/model_registry.py +2 -5
- aimnet/calculators/model_registry.yaml +55 -17
- aimnet/cli.py +62 -6
- aimnet/config.py +8 -9
- aimnet/data/sgdataset.py +23 -22
- aimnet/kernels/__init__.py +66 -0
- aimnet/kernels/conv_sv_2d_sp_wp.py +478 -0
- aimnet/models/__init__.py +13 -1
- aimnet/models/aimnet2.py +19 -22
- aimnet/models/base.py +183 -15
- aimnet/models/convert.py +30 -0
- aimnet/models/utils.py +735 -0
- aimnet/modules/__init__.py +1 -1
- aimnet/modules/aev.py +49 -48
- aimnet/modules/core.py +14 -13
- aimnet/modules/lr.py +520 -115
- aimnet/modules/ops.py +537 -0
- aimnet/nbops.py +105 -15
- aimnet/ops.py +90 -18
- aimnet/train/export_model.py +226 -0
- aimnet/train/loss.py +7 -7
- aimnet/train/metrics.py +5 -6
- aimnet/train/train.py +4 -1
- aimnet/train/utils.py +42 -13
- aimnet-0.1.0.dist-info/METADATA +308 -0
- aimnet-0.1.0.dist-info/RECORD +43 -0
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info}/WHEEL +1 -1
- aimnet-0.1.0.dist-info/entry_points.txt +3 -0
- aimnet/calculators/nb_kernel_cpu.py +0 -222
- aimnet/calculators/nb_kernel_cuda.py +0 -217
- aimnet/calculators/nbmat.py +0 -220
- aimnet/train/pt2jpt.py +0 -81
- aimnet-0.0.1.dist-info/METADATA +0 -78
- aimnet-0.0.1.dist-info/RECORD +0 -41
- aimnet-0.0.1.dist-info/entry_points.txt +0 -5
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info/licenses}/LICENSE +0 -0
aimnet/kernels/conv_sv_2d_sp_wp.py
ADDED
@@ -0,0 +1,478 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+# type: ignore
+
+import torch
+import warp as wp
+from torch import Tensor
+
+wp.init()
+
+
+def _get_stream(device: torch.device):
+    """Get the Warp stream for the given device."""
+    if device.type == "cuda":
+        return wp.stream_from_torch(torch.cuda.current_stream(device))
+    return None
+
+
+# =============================================================================
+# Warp Kernels
+# =============================================================================
+
+
+@wp.kernel(enable_backward=False)
+def _conv_sv_2d_sp_kernel(
+    a: wp.array3d(dtype=wp.float32),  # (B, A, G)
+    idx: wp.array2d(dtype=wp.int32),  # (B, M)
+    g: wp.array3d(dtype=wp.vec4f),  # (B, M, G, D)
+    output: wp.array3d(dtype=wp.vec4f),  # (B, A, G, D)
+):
+    """Forward: output[b,a,g] = sum_m a[idx[b,m],a,g] * g[b,m,g]"""
+    B, M = idx.shape[0], idx.shape[1]
+    padding_value = B - 1  # last row is padding
+
+    _b, _a, _g = wp.tid()
+
+    acc = wp.vec4f()
+    for _m in range(M):
+        _idx = idx[_b, _m]
+        if _idx >= padding_value:
+            break
+        a_val = a[_idx, _a, _g]
+        g_val = g[_b, _m, _g]
+        acc += a_val * g_val
+    output[_b, _a, _g] = acc
+
+
+@wp.kernel(enable_backward=False)
+def _conv_sv_2d_sp_backward_a_kernel(
+    grad_output: wp.array3d(dtype=wp.vec4f),  # (B, A, G, D)
+    idx: wp.array2d(dtype=wp.int32),  # (B, M)
+    g: wp.array3d(dtype=wp.vec4f),  # (B, M, G, D)
+    grad_a: wp.array3d(dtype=wp.float32),  # (B, A, G)
+):
+    """Backward w.r.t. a: grad_a[idx[b,m],a,g] += dot(grad_output[b,a,g], g[b,m,g])"""
+    B, M = idx.shape[0], idx.shape[1]
+    padding_value = B - 1  # last row is padding
+
+    _b, _a, _g = wp.tid()
+
+    grad_out = grad_output[_b, _a, _g]
+    for _m in range(M):
+        _idx = idx[_b, _m]
+        if _idx >= padding_value:
+            break
+        g_val = g[_b, _m, _g]
+        val = wp.dot(grad_out, g_val)
+        wp.atomic_add(grad_a, _idx, _a, _g, val)
+
+
+@wp.kernel(enable_backward=False)
+def _conv_sv_2d_sp_backward_g_kernel(
+    grad_output: wp.array3d(dtype=wp.vec4f),  # (B, A, G, D)
+    a: wp.array3d(dtype=wp.float32),  # (B, A, G)
+    idx: wp.array2d(dtype=wp.int32),  # (B, M)
+    grad_g: wp.array3d(dtype=wp.vec4f),  # (B, M, G, D)
+):
+    """Backward w.r.t. g: grad_g[b,m,g] = sum_a a[idx[b,m],a,g] * grad_output[b,a,g]"""
+    B = idx.shape[0]
+    A = a.shape[1]
+    padding_value = B - 1  # last row is padding
+
+    _b, _m, _g = wp.tid()
+
+    _idx = idx[_b, _m]
+    if _idx >= padding_value:
+        return
+
+    acc = wp.vec4f()
+
+    for _a in range(A):
+        grad_out = grad_output[_b, _a, _g]
+        a_val = a[_idx, _a, _g]
+        acc += a_val * grad_out
+
+    grad_g[_b, _m, _g] = acc
+
+
+@wp.kernel(enable_backward=False)
+def _conv_sv_2d_sp_double_backward_a_g_kernel(
+    grad_grad_a: wp.array3d(dtype=wp.float32),  # (B, A, G)
+    idx: wp.array2d(dtype=wp.int32),  # (B, M)
+    grad_output: wp.array3d(dtype=wp.vec4f),  # (B, A, G, D)
+    grad_g: wp.array3d(dtype=wp.vec4f),  # (B, M, G, D)
+):
+    """Double backward: d(grad_a)/dg -> grad_g"""
+    B = idx.shape[0]
+    A = grad_grad_a.shape[1]
+    padding_value = B - 1  # last row is padding
+
+    _b, _m, _g = wp.tid()
+
+    _idx = idx[_b, _m]
+    if _idx >= padding_value:
+        return
+
+    acc = wp.vec4f()
+
+    for _a in range(A):
+        grad_grad_a_val = grad_grad_a[_idx, _a, _g]
+        grad_out = grad_output[_b, _a, _g]
+        acc += grad_grad_a_val * grad_out
+
+    grad_g[_b, _m, _g] = acc
+
+
+@wp.kernel(enable_backward=False)
+def _conv_sv_2d_sp_double_backward_g_contrib_kernel(
+    grad2_g: wp.array3d(dtype=wp.vec4f),  # (B, M, G, D)
+    a: wp.array3d(dtype=wp.float32),  # (B, A, G)
+    idx: wp.array2d(dtype=wp.int32),  # (B, M)
+    grad_output_double: wp.array3d(dtype=wp.vec4f),  # (B, A, G, D) - OUTPUT
+):
+    """Double backward from grad2_g: einsum('bmgd,bmag->bagd', grad2_g, a_selected)"""
+    B, M = idx.shape[0], idx.shape[1]
+    padding_value = B - 1  # last row is padding
+
+    _b, _a, _g = wp.tid()
+
+    acc = wp.vec4f()
+    for _m in range(M):
+        _idx = idx[_b, _m]
+        if _idx >= padding_value:
+            break
+        a_val = a[_idx, _a, _g]
+        grad2_g_val = grad2_g[_b, _m, _g]
+        acc += a_val * grad2_g_val
+
+    grad_output_double[_b, _a, _g] = acc
+
+
+@wp.kernel(enable_backward=False)
+def _conv_sv_2d_sp_double_backward_a_contrib_kernel(
+    grad2_a: wp.array3d(dtype=wp.float32),  # (B, A, G)
+    idx: wp.array2d(dtype=wp.int32),  # (B, M)
+    g: wp.array3d(dtype=wp.vec4f),  # (B, M, G, D)
+    grad_output_double: wp.array3d(dtype=wp.vec4f),  # (B, A, G, D) - OUTPUT
+):
+    """Double backward from grad2_a: einsum('bmag,bmgd->bagd', grad2_a_selected, g)"""
+    B, M = idx.shape[0], idx.shape[1]
+    padding_value = B - 1  # last row is padding
+
+    _b, _a, _g = wp.tid()
+
+    acc = wp.vec4f()
+    for _m in range(M):
+        _idx = idx[_b, _m]
+        if _idx >= padding_value:
+            break
+        grad2_a_val = grad2_a[_idx, _a, _g]
+        g_val = g[_b, _m, _g]
+        acc += grad2_a_val * g_val
+
+    grad_output_double[_b, _a, _g] = acc
+
+
+# =============================================================================
+# PyTorch Custom Op Primitives
+# =============================================================================
+
+
+@torch.library.custom_op(
+    "aimnet::conv_sv_2d_sp_fwd",
+    mutates_args=(),
+    device_types=["cuda"],
+)
+def _(a: Tensor, idx: Tensor, g: Tensor) -> Tensor:
+    """Forward primitive for conv_sv_2d_sp."""
+    stream = _get_stream(a.device)
+    device = wp.device_from_torch(a.device)
+    B, A, G = a.shape
+    output = torch.zeros(B, A, G, 4, dtype=a.dtype, device=a.device)
+
+    wp.launch(
+        _conv_sv_2d_sp_kernel,
+        dim=(B - 1, A, G),  # B-1: exclude padding row
+        stream=stream,
+        device=device,
+        inputs=(
+            wp.from_torch(a.detach(), return_ctype=True),
+            wp.from_torch(idx.to(torch.int32), return_ctype=True),
+            wp.from_torch(g.detach(), return_ctype=True, dtype=wp.vec4f),
+            wp.from_torch(output, return_ctype=True, dtype=wp.vec4f),
+        ),
+    )
+    return output
+
+
+@torch.library.register_fake("aimnet::conv_sv_2d_sp_fwd")
+def _(a: Tensor, idx: Tensor, g: Tensor) -> Tensor:
+    B, A, G = a.shape
+    return torch.empty(B, A, G, 4, dtype=a.dtype, device=a.device)
+
+
+@torch.library.custom_op(
+    "aimnet::conv_sv_2d_sp_bwd",
+    mutates_args=(),
+    device_types=["cuda"],
+)
+def _(grad_output: Tensor, a: Tensor, idx: Tensor, g: Tensor) -> list[Tensor]:
+    """Backward primitive for conv_sv_2d_sp."""
+    stream = _get_stream(a.device)
+    device = wp.device_from_torch(a.device)
+    B, A, G = a.shape
+    B_out, M = idx.shape
+
+    grad_a = torch.zeros_like(a)
+    grad_g = torch.zeros(B_out, M, G, 4, dtype=g.dtype, device=g.device)
+
+    grad_output_contig = grad_output.detach().contiguous()
+
+    # Launch backward w.r.t. a
+    wp.launch(
+        _conv_sv_2d_sp_backward_a_kernel,
+        dim=(B - 1, A, G),  # B-1: exclude padding row
+        stream=stream,
+        device=device,
+        inputs=(
+            wp.from_torch(grad_output_contig, return_ctype=True, dtype=wp.vec4f),
+            wp.from_torch(idx.to(torch.int32), return_ctype=True),
+            wp.from_torch(g.detach(), return_ctype=True, dtype=wp.vec4f),
+            wp.from_torch(grad_a, return_ctype=True),
+        ),
+    )
+
+    # Launch backward w.r.t. g
+    wp.launch(
+        _conv_sv_2d_sp_backward_g_kernel,
+        dim=(B_out - 1, M, G),  # B_out-1: exclude padding row
+        stream=stream,
+        device=device,
+        inputs=(
+            wp.from_torch(grad_output_contig, return_ctype=True, dtype=wp.vec4f),
+            wp.from_torch(a.detach(), return_ctype=True),
+            wp.from_torch(idx.to(torch.int32), return_ctype=True),
+            wp.from_torch(grad_g, return_ctype=True, dtype=wp.vec4f),
+        ),
+    )
+
+    return [grad_a, grad_g]
+
+
+@torch.library.register_fake("aimnet::conv_sv_2d_sp_bwd")
+def _(grad_output: Tensor, a: Tensor, idx: Tensor, g: Tensor) -> list[Tensor]:
+    B_out, M = idx.shape
+    G = a.shape[2]
+    return [
+        torch.empty_like(a),
+        torch.empty(B_out, M, G, 4, dtype=g.dtype, device=g.device),
+    ]
+
+
+@torch.library.custom_op(
+    "aimnet::conv_sv_2d_sp_bwd_bwd",
+    mutates_args=(),
+    device_types=["cuda"],
+)
+def _(
+    grad_output: Tensor,
+    grad2_a: Tensor,
+    grad2_g: Tensor,
+    a: Tensor,
+    idx: Tensor,
+    g: Tensor,
+) -> list[Tensor]:
+    """Double backward primitive for conv_sv_2d_sp."""
+    stream = _get_stream(a.device)
+    device = wp.device_from_torch(a.device)
+    B, A, G = a.shape
+    B_out, M = idx.shape
+
+    grad_grad_output = torch.zeros(B, A, G, 4, dtype=a.dtype, device=a.device)
+    grad_a_double = torch.zeros_like(a)
+    grad_g_double = torch.zeros(B_out, M, G, 4, dtype=a.dtype, device=a.device)
+
+    grad_output_contig = grad_output.detach().contiguous()
+    grad2_a_contig = grad2_a.detach().contiguous()
+    grad2_g_contig = grad2_g.detach().contiguous()
+
+    # Contribution from grad2_g to grad_grad_output
+    wp.launch(
+        _conv_sv_2d_sp_double_backward_g_contrib_kernel,
+        dim=(B - 1, A, G),  # B-1: exclude padding row
+        stream=stream,
+        device=device,
+        inputs=(
+            wp.from_torch(grad2_g_contig, return_ctype=True, dtype=wp.vec4f),
+            wp.from_torch(a.detach(), return_ctype=True),
+            wp.from_torch(idx.to(torch.int32), return_ctype=True),
+            wp.from_torch(grad_grad_output, return_ctype=True, dtype=wp.vec4f),
+        ),
+    )
+
+    # Contribution from grad2_a to grad_grad_output
+    grad_output_2_a = torch.zeros(B, A, G, 4, dtype=a.dtype, device=a.device)
+    wp.launch(
+        _conv_sv_2d_sp_double_backward_a_contrib_kernel,
+        dim=(B - 1, A, G),  # B-1: exclude padding row
+        stream=stream,
+        device=device,
+        inputs=(
+            wp.from_torch(grad2_a_contig, return_ctype=True),
+            wp.from_torch(idx.to(torch.int32), return_ctype=True),
+            wp.from_torch(g.detach(), return_ctype=True, dtype=wp.vec4f),
+            wp.from_torch(grad_output_2_a, return_ctype=True, dtype=wp.vec4f),
+        ),
+    )
+    grad_grad_output = grad_grad_output + grad_output_2_a
+
+    # Mixed partial: d(grad_a)/dg -> grad_g_double
+    wp.launch(
+        _conv_sv_2d_sp_double_backward_a_g_kernel,
+        dim=(B_out - 1, M, G),  # B_out-1: exclude padding row
+        stream=stream,
+        device=device,
+        inputs=(
+            wp.from_torch(grad2_a_contig, return_ctype=True),
+            wp.from_torch(idx.to(torch.int32), return_ctype=True),
+            wp.from_torch(grad_output_contig, return_ctype=True, dtype=wp.vec4f),
+            wp.from_torch(grad_g_double, return_ctype=True, dtype=wp.vec4f),
+        ),
+    )
+
+    # Mixed partial: d(grad_g)/da -> grad_a_double
+    wp.launch(
+        _conv_sv_2d_sp_backward_a_kernel,
+        dim=(B - 1, A, G),  # B-1: exclude padding row
+        stream=stream,
+        device=device,
+        inputs=(
+            wp.from_torch(grad_output_contig, return_ctype=True, dtype=wp.vec4f),
+            wp.from_torch(idx.to(torch.int32), return_ctype=True),
+            wp.from_torch(grad2_g_contig, return_ctype=True, dtype=wp.vec4f),
+            wp.from_torch(grad_a_double, return_ctype=True),
+        ),
+    )
+
+    return [grad_grad_output, grad_a_double, grad_g_double]
+
+
+@torch.library.register_fake("aimnet::conv_sv_2d_sp_bwd_bwd")
+def _(
+    grad_output: Tensor,
+    grad2_a: Tensor,
+    grad2_g: Tensor,
+    a: Tensor,
+    idx: Tensor,
+    g: Tensor,
+) -> list[Tensor]:
+    B, A, G = a.shape
+    B_out, M = idx.shape
+    return [
+        torch.empty(B, A, G, 4, dtype=a.dtype, device=a.device),
+        torch.empty_like(a),
+        torch.empty(B_out, M, G, 4, dtype=a.dtype, device=a.device),
+    ]
+
+
+# =============================================================================
+# Autograd Registration
+# =============================================================================
+
+
+def _conv_sv_2d_sp_setup_fwd_context(ctx, inputs, output):
+    """Setup context for forward pass."""
+    a, idx, g = inputs
+    ctx.save_for_backward(a, idx, g)
+
+
+def _conv_sv_2d_sp_setup_bwd_context(ctx, inputs, output):
+    """Setup context for backward pass."""
+    grad_output, a, idx, g = inputs
+    ctx.save_for_backward(grad_output, a, idx, g)
+
+
+@torch.compiler.allow_in_graph
+def _conv_sv_2d_sp_bwd(ctx, grad_output):
+    """Backward pass for conv_sv_2d_sp."""
+    a, idx, g = ctx.saved_tensors
+    grad_a, grad_g = torch.ops.aimnet.conv_sv_2d_sp_bwd(grad_output.contiguous(), a, idx, g)
+    return grad_a, None, grad_g
+
+
+@torch.compiler.allow_in_graph
+def _conv_sv_2d_sp_bwd_bwd(ctx, *grad_outputs):
+    """Double backward pass for conv_sv_2d_sp."""
+    grad2_a = grad_outputs[0][0]
+    grad2_g = grad_outputs[0][1]
+
+    grad_output_saved, a, idx, g = ctx.saved_tensors
+
+    if grad2_a is None:
+        grad2_a = torch.zeros_like(a)
+    if grad2_g is None:
+        B_out, M = idx.shape
+        G = a.shape[2]
+        grad2_g = torch.zeros(B_out, M, G, 4, dtype=g.dtype, device=g.device)
+
+    outputs = torch.ops.aimnet.conv_sv_2d_sp_bwd_bwd(grad_output_saved, grad2_a, grad2_g, a, idx, g)
+
+    return outputs[0], outputs[1], None, outputs[2]
+
+
+torch.library.register_autograd(
+    "aimnet::conv_sv_2d_sp_fwd",
+    _conv_sv_2d_sp_bwd,
+    setup_context=_conv_sv_2d_sp_setup_fwd_context,
+)
+
+torch.library.register_autograd(
+    "aimnet::conv_sv_2d_sp_bwd",
+    _conv_sv_2d_sp_bwd_bwd,
+    setup_context=_conv_sv_2d_sp_setup_bwd_context,
+)
+
+
+# =============================================================================
+# Public API
+# =============================================================================
+
+
+def conv_sv_2d_sp(a: Tensor, idx: Tensor, g: Tensor) -> Tensor:
+    """Compute conv_sv_2d_sp with support for 1st and 2nd order derivatives.
+
+    Parameters
+    ----------
+    a : Tensor
+        Input tensor of shape (B, A, G).
+    idx : Tensor
+        Index tensor of shape (B, M).
+    g : Tensor
+        Gate tensor of shape (B, M, G, 4).
+
+    Returns
+    -------
+    Tensor
+        Output tensor of shape (B, A, G, 4).
+    """
+    return torch.ops.aimnet.conv_sv_2d_sp_fwd(a, idx, g)
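The kernels above implement a batched gather-and-gate contraction: for every atom `a` and shift channel `g`, neighbor features `a[idx[b,m]]` are weighted by the four-component gate `g[b,m,g]` and summed over the neighbor dimension `m`, with batch row `B - 1` reserved as padding. As a reading aid, here is a minimal pure-PyTorch sketch of the same forward semantics; the helper `conv_sv_2d_sp_ref` is hypothetical and not part of the wheel:

```python
import torch

def conv_sv_2d_sp_ref(a: torch.Tensor, idx: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """Reference for the forward docstring: output[b,a,g] = sum_m a[idx[b,m],a,g] * g[b,m,g]."""
    B = a.shape[0]
    mask = idx < B - 1                      # entries >= B-1 mark padded neighbors
    idx_safe = idx.clamp(max=B - 1).long()  # keep padded indices in range for the gather
    a_sel = a[idx_safe]                     # (B, M, A, G): neighbor features gathered per m
    g_masked = g * mask[..., None, None].to(g.dtype)
    out = torch.einsum("bmag,bmgd->bagd", a_sel, g_masked)
    out[-1] = 0.0                           # kernels launch over B-1 rows, so the padding row stays zero
    return out
```

Since the Warp kernels `break` at the first padding index, this masked reference agrees with them when padding entries are trailing in `idx`, which the launch comments imply. Registering both `conv_sv_2d_sp_bwd` and `conv_sv_2d_sp_bwd_bwd` means a quantity differentiated twice through the op (for example forces inside a force-matching loss) stays on the fused CUDA path.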
aimnet/models/__init__.py
CHANGED
@@ -1,2 +1,14 @@
 from .aimnet2 import AIMNet2  # noqa: F401
-from .base import AIMNet2Base  # noqa: F401
+from .base import AIMNet2Base, load_model  # noqa: F401
+from .utils import (  # noqa: F401
+    extract_coulomb_rc,
+    extract_d3_params,
+    extract_species,
+    has_d3ts,
+    has_d3ts_in_config,
+    has_dftd3_in_config,
+    has_dispersion,
+    has_externalizable_dftd3,
+    has_lrcoulomb,
+    iter_lrcoulomb_mods,
+)
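Version 0.1.0 widens the `aimnet.models` namespace: `load_model` is re-exported from `aimnet.models.base`, alongside introspection helpers from the new `aimnet/models/utils.py`. Their signatures are outside this diff, so the snippet below only illustrates the import surface; the argument passed to `load_model` is an assumption:

```python
# Illustrative only: load_model and the has_*/extract_* helpers are defined in
# aimnet/models/base.py and aimnet/models/utils.py, which this diff does not show.
from aimnet.models import has_dispersion, has_lrcoulomb, load_model

model = load_model("aimnet2")  # assumed: a registry name or checkpoint path
print(has_dispersion(model), has_lrcoulomb(model))  # query dispersion / long-range Coulomb blocks
```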
aimnet/models/aimnet2.py
CHANGED
@@ -1,4 +1,4 @@
-from …
+from collections.abc import Mapping, Sequence
 
 import torch
 from torch import Tensor, nn
@@ -8,17 +8,16 @@ from aimnet.models.base import AIMNet2Base
 from aimnet.modules import AEVSV, MLP, ConvSV, Embedding
 
 
-# pylint: disable=too-many-arguments, too-many-instance-attributes
 class AIMNet2(AIMNet2Base):
     def __init__(
         self,
-        aev: …
+        aev: dict,
         nfeature: int,
         d2features: bool,
         ncomb_v: int,
-        hidden: …
+        hidden: tuple[list[int]],
         aim_size: int,
-        outputs: …
+        outputs: list[nn.Module] | dict[str, nn.Module],
         num_charge_channels: int = 1,
     ):
         super().__init__()
@@ -29,7 +28,7 @@ class AIMNet2(AIMNet2Base):
 
         self.aev = AEVSV(**aev)
         nshifts_s = aev["nshifts_s"]
-        nshifts_v = aev.get("…
+        nshifts_v = aev.get("nshifts_v") or nshifts_s
         if d2features:
             if nshifts_s != nshifts_v:
                 raise ValueError("nshifts_s must be equal to nshifts_v for d2features")
@@ -49,7 +48,7 @@ class AIMNet2(AIMNet2Base):
             self.afv.weight.clone().unsqueeze(-1).expand(64, nfeature, nshifts_s).flatten(-2, -1)
         )
 
-        conv_param = {"nshifts_s": nshifts_s, "nshifts_v": nshifts_v, "ncomb_v": ncomb_v…
+        conv_param = {"nshifts_s": nshifts_s, "nshifts_v": nshifts_v, "ncomb_v": ncomb_v}
         self.conv_a = ConvSV(nchannel=nfeature, d2features=d2features, **conv_param)
         self.conv_q = ConvSV(nchannel=num_charge_channels, d2features=False, **conv_param)
 
@@ -90,7 +89,7 @@ class AIMNet2(AIMNet2Base):
         else:
             raise TypeError("`outputs` is not either list or dict")
 
-    def _preprocess_spin_polarized_charge(self, data: …
+    def _preprocess_spin_polarized_charge(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         if "mult" not in data:
             raise ValueError("mult key is required for NSE if two channels for charge are not provided")
         _half_spin = 0.5 * (data["mult"] - 1.0)
@@ -98,27 +97,27 @@ class AIMNet2(AIMNet2Base):
         data["charge"] = torch.stack([_half_q + _half_spin, _half_q - _half_spin], dim=-1)
         return data
 
-    def _postprocess_spin_polarized_charge(self, data: …
+    def _postprocess_spin_polarized_charge(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         data["spin_charges"] = data["charges"][..., 0] - data["charges"][..., 1]
         data["charges"] = data["charges"].sum(dim=-1)
         data["charge"] = data["charge"].sum(dim=-1)
         return data
 
-    def _prepare_in_a(self, data: …
-        a_i …
-        avf_a = self.conv_a(…
+    def _prepare_in_a(self, data: dict[str, Tensor]) -> Tensor:
+        a_i = nbops.get_i(data["a"], data)
+        avf_a = self.conv_a(data, data["a"])
         if self.d2features:
             a_i = a_i.flatten(-2, -1)
         _in = torch.cat([a_i.squeeze(-2), avf_a], dim=-1)
         return _in
 
-    def _prepare_in_q(self, data: …
-        q_i …
-        avf_q = self.conv_q(…
+    def _prepare_in_q(self, data: dict[str, Tensor]) -> Tensor:
+        q_i = nbops.get_i(data["charges"], data)
+        avf_q = self.conv_q(data, data["charges"])
         _in = torch.cat([q_i.squeeze(-2), avf_q], dim=-1)
         return _in
 
-    def _update_q(self, data: …
+    def _update_q(self, data: dict[str, Tensor], x: Tensor, delta_q: bool = True) -> dict[str, Tensor]:
         _q, _f, delta_a = x.split(
             [
                 self.num_charge_channels,
@@ -127,16 +126,17 @@ class AIMNet2(AIMNet2Base):
             ],
             dim=-1,
         )
-        # for loss
+        # Charge conservation violation penalty for training loss
         data["_delta_Q"] = data["charge"] - nbops.mol_sum(_q, data)
         q = data["charges"] + _q if delta_q else _q
+        data["charges_pre"] = q if self.num_charge_channels == 2 else q.squeeze(-1)
         f = _f.pow(2)
         q = ops.nse(data["charge"], q, f, data, epsilon=1.0e-6)
         data["charges"] = q
         data["a"] = data["a"] + delta_a.view_as(data["a"])
         return data
 
-    def forward(self, data: …
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         data = self.prepare_input(data)
 
         # initial features
@@ -149,13 +149,11 @@ class AIMNet2(AIMNet2Base):
         if self.num_charge_channels == 2:
             data = self._preprocess_spin_polarized_charge(data)
         else:
-            # …
+            # Ensure charge tensor has channel dimension for consistency with features
            data["charge"] = data["charge"].unsqueeze(-1)
 
-        # AEV
         data = self.aev(data)
 
-        # MP iterations
         _npass = len(self.mlps)
         for ipass, mlp in enumerate(self.mlps):
             if ipass == 0:
@@ -181,7 +179,6 @@ class AIMNet2(AIMNet2Base):
         data["charges"] = data["charges"].squeeze(-1)
         data["charge"] = data["charge"].squeeze(-1)
 
-        # readout
         for m in self.outputs.children():
             data = m(data)
 
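When `num_charge_channels == 2`, `_preprocess_spin_polarized_charge` splits the scalar molecular charge into two channels whose sum reproduces the total charge and whose difference equals `mult - 1`, the number of unpaired electrons; `_postprocess_spin_polarized_charge` inverts this into `charges` and `spin_charges`. A small numeric sketch of that arithmetic (assuming `_half_q = 0.5 * data["charge"]`, whose defining line falls outside the hunk shown):

```python
import torch

charge = torch.tensor(0.0)  # neutral molecule
mult = torch.tensor(2.0)    # doublet: one unpaired electron

half_q = 0.5 * charge
half_spin = 0.5 * (mult - 1.0)
charge_2ch = torch.stack([half_q + half_spin, half_q - half_spin], dim=-1)

print(charge_2ch)                               # tensor([ 0.5000, -0.5000])
print(charge_2ch.sum(-1))                       # total charge preserved: tensor(0.)
print(charge_2ch[..., 0] - charge_2ch[..., 1])  # mult - 1 unpaired electrons: tensor(1.)
```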