aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
@@ -0,0 +1,478 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: MIT
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a
+ # copy of this software and associated documentation files (the "Software"),
+ # to deal in the Software without restriction, including without limitation
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ # and/or sell copies of the Software, and to permit persons to whom the
+ # Software is furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in
+ # all copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ # DEALINGS IN THE SOFTWARE.
+
+ # type: ignore
+
+ import torch
+ import warp as wp
+ from torch import Tensor
+
+ wp.init()
+
+
+ def _get_stream(device: torch.device):
+     """Get the Warp stream for the given device."""
+     if device.type == "cuda":
+         return wp.stream_from_torch(torch.cuda.current_stream(device))
+     return None
+
+
+ # =============================================================================
+ # Warp Kernels
+ # =============================================================================
+
+
+ @wp.kernel(enable_backward=False)
+ def _conv_sv_2d_sp_kernel(
+     a: wp.array3d(dtype=wp.float32),  # (B, A, G)
+     idx: wp.array2d(dtype=wp.int32),  # (B, M)
+     g: wp.array3d(dtype=wp.vec4f),  # (B, M, G, D)
+     output: wp.array3d(dtype=wp.vec4f),  # (B, A, G, D)
+ ):
+     """Forward: output[b,a,g] = sum_m a[idx[b,m],a,g] * g[b,m,g]"""
+     B, M = idx.shape[0], idx.shape[1]
+     padding_value = B - 1  # last row is padding
+
+     _b, _a, _g = wp.tid()
+
+     acc = wp.vec4f()
+     for _m in range(M):
+         _idx = idx[_b, _m]
+         if _idx >= padding_value:
+             break
+         a_val = a[_idx, _a, _g]
+         g_val = g[_b, _m, _g]
+         acc += a_val * g_val
+     output[_b, _a, _g] = acc
+
+
+ @wp.kernel(enable_backward=False)
+ def _conv_sv_2d_sp_backward_a_kernel(
+     grad_output: wp.array3d(dtype=wp.vec4f),  # (B, A, G, D)
+     idx: wp.array2d(dtype=wp.int32),  # (B, M)
+     g: wp.array3d(dtype=wp.vec4f),  # (B, M, G, D)
+     grad_a: wp.array3d(dtype=wp.float32),  # (B, A, G)
+ ):
+     """Backward w.r.t. a: grad_a[idx[b,m],a,g] += dot(grad_output[b,a,g], g[b,m,g])"""
+     B, M = idx.shape[0], idx.shape[1]
+     padding_value = B - 1  # last row is padding
+
+     _b, _a, _g = wp.tid()
+
+     grad_out = grad_output[_b, _a, _g]
+     for _m in range(M):
+         _idx = idx[_b, _m]
+         if _idx >= padding_value:
+             break
+         g_val = g[_b, _m, _g]
+         val = wp.dot(grad_out, g_val)
+         wp.atomic_add(grad_a, _idx, _a, _g, val)
+
+
+ @wp.kernel(enable_backward=False)
+ def _conv_sv_2d_sp_backward_g_kernel(
+     grad_output: wp.array3d(dtype=wp.vec4f),  # (B, A, G, D)
+     a: wp.array3d(dtype=wp.float32),  # (B, A, G)
+     idx: wp.array2d(dtype=wp.int32),  # (B, M)
+     grad_g: wp.array3d(dtype=wp.vec4f),  # (B, M, G, D)
+ ):
+     """Backward w.r.t. g: grad_g[b,m,g] = sum_a a[idx[b,m],a,g] * grad_output[b,a,g]"""
+     B = idx.shape[0]
+     A = a.shape[1]
+     padding_value = B - 1  # last row is padding
+
+     _b, _m, _g = wp.tid()
+
+     _idx = idx[_b, _m]
+     if _idx >= padding_value:
+         return
+
+     acc = wp.vec4f()
+
+     for _a in range(A):
+         grad_out = grad_output[_b, _a, _g]
+         a_val = a[_idx, _a, _g]
+         acc += a_val * grad_out
+
+     grad_g[_b, _m, _g] = acc
+
+
+ @wp.kernel(enable_backward=False)
+ def _conv_sv_2d_sp_double_backward_a_g_kernel(
+     grad_grad_a: wp.array3d(dtype=wp.float32),  # (B, A, G)
+     idx: wp.array2d(dtype=wp.int32),  # (B, M)
+     grad_output: wp.array3d(dtype=wp.vec4f),  # (B, A, G, D)
+     grad_g: wp.array3d(dtype=wp.vec4f),  # (B, M, G, D)
+ ):
+     """Double backward: d(grad_a)/dg -> grad_g"""
+     B = idx.shape[0]
+     A = grad_grad_a.shape[1]
+     padding_value = B - 1  # last row is padding
+
+     _b, _m, _g = wp.tid()
+
+     _idx = idx[_b, _m]
+     if _idx >= padding_value:
+         return
+
+     acc = wp.vec4f()
+
+     for _a in range(A):
+         grad_grad_a_val = grad_grad_a[_idx, _a, _g]
+         grad_out = grad_output[_b, _a, _g]
+         acc += grad_grad_a_val * grad_out
+
+     grad_g[_b, _m, _g] = acc
+
+
+ @wp.kernel(enable_backward=False)
+ def _conv_sv_2d_sp_double_backward_g_contrib_kernel(
+     grad2_g: wp.array3d(dtype=wp.vec4f),  # (B, M, G, D)
+     a: wp.array3d(dtype=wp.float32),  # (B, A, G)
+     idx: wp.array2d(dtype=wp.int32),  # (B, M)
+     grad_output_double: wp.array3d(dtype=wp.vec4f),  # (B, A, G, D) - OUTPUT
+ ):
+     """Double backward from grad2_g: einsum('bmgd,bmag->bagd', grad2_g, a_selected)"""
+     B, M = idx.shape[0], idx.shape[1]
+     padding_value = B - 1  # last row is padding
+
+     _b, _a, _g = wp.tid()
+
+     acc = wp.vec4f()
+     for _m in range(M):
+         _idx = idx[_b, _m]
+         if _idx >= padding_value:
+             break
+         a_val = a[_idx, _a, _g]
+         grad2_g_val = grad2_g[_b, _m, _g]
+         acc += a_val * grad2_g_val
+
+     grad_output_double[_b, _a, _g] = acc
+
+
+ @wp.kernel(enable_backward=False)
+ def _conv_sv_2d_sp_double_backward_a_contrib_kernel(
+     grad2_a: wp.array3d(dtype=wp.float32),  # (B, A, G)
+     idx: wp.array2d(dtype=wp.int32),  # (B, M)
+     g: wp.array3d(dtype=wp.vec4f),  # (B, M, G, D)
+     grad_output_double: wp.array3d(dtype=wp.vec4f),  # (B, A, G, D) - OUTPUT
+ ):
+     """Double backward from grad2_a: einsum('bmag,bmgd->bagd', grad2_a_selected, g)"""
+     B, M = idx.shape[0], idx.shape[1]
+     padding_value = B - 1  # last row is padding
+
+     _b, _a, _g = wp.tid()
+
+     acc = wp.vec4f()
+     for _m in range(M):
+         _idx = idx[_b, _m]
+         if _idx >= padding_value:
+             break
+         grad2_a_val = grad2_a[_idx, _a, _g]
+         g_val = g[_b, _m, _g]
+         acc += grad2_a_val * g_val
+
+     grad_output_double[_b, _a, _g] = acc
+
+
+ # =============================================================================
+ # PyTorch Custom Op Primitives
+ # =============================================================================
+
+
+ @torch.library.custom_op(
+     "aimnet::conv_sv_2d_sp_fwd",
+     mutates_args=(),
+     device_types=["cuda"],
+ )
+ def _(a: Tensor, idx: Tensor, g: Tensor) -> Tensor:
+     """Forward primitive for conv_sv_2d_sp."""
+     stream = _get_stream(a.device)
+     device = wp.device_from_torch(a.device)
+     B, A, G = a.shape
+     output = torch.zeros(B, A, G, 4, dtype=a.dtype, device=a.device)
+
+     wp.launch(
+         _conv_sv_2d_sp_kernel,
+         dim=(B - 1, A, G),  # B-1: exclude padding row
+         stream=stream,
+         device=device,
+         inputs=(
+             wp.from_torch(a.detach(), return_ctype=True),
+             wp.from_torch(idx.to(torch.int32), return_ctype=True),
+             wp.from_torch(g.detach(), return_ctype=True, dtype=wp.vec4f),
+             wp.from_torch(output, return_ctype=True, dtype=wp.vec4f),
+         ),
+     )
+     return output
+
+
+ @torch.library.register_fake("aimnet::conv_sv_2d_sp_fwd")
+ def _(a: Tensor, idx: Tensor, g: Tensor) -> Tensor:
+     B, A, G = a.shape
+     return torch.empty(B, A, G, 4, dtype=a.dtype, device=a.device)
+
+
+ @torch.library.custom_op(
+     "aimnet::conv_sv_2d_sp_bwd",
+     mutates_args=(),
+     device_types=["cuda"],
+ )
+ def _(grad_output: Tensor, a: Tensor, idx: Tensor, g: Tensor) -> list[Tensor]:
+     """Backward primitive for conv_sv_2d_sp."""
+     stream = _get_stream(a.device)
+     device = wp.device_from_torch(a.device)
+     B, A, G = a.shape
+     B_out, M = idx.shape
+
+     grad_a = torch.zeros_like(a)
+     grad_g = torch.zeros(B_out, M, G, 4, dtype=g.dtype, device=g.device)
+
+     grad_output_contig = grad_output.detach().contiguous()
+
+     # Launch backward w.r.t. a
+     wp.launch(
+         _conv_sv_2d_sp_backward_a_kernel,
+         dim=(B - 1, A, G),  # B-1: exclude padding row
+         stream=stream,
+         device=device,
+         inputs=(
+             wp.from_torch(grad_output_contig, return_ctype=True, dtype=wp.vec4f),
+             wp.from_torch(idx.to(torch.int32), return_ctype=True),
+             wp.from_torch(g.detach(), return_ctype=True, dtype=wp.vec4f),
+             wp.from_torch(grad_a, return_ctype=True),
+         ),
+     )
+
+     # Launch backward w.r.t. g
+     wp.launch(
+         _conv_sv_2d_sp_backward_g_kernel,
+         dim=(B_out - 1, M, G),  # B_out-1: exclude padding row
+         stream=stream,
+         device=device,
+         inputs=(
+             wp.from_torch(grad_output_contig, return_ctype=True, dtype=wp.vec4f),
+             wp.from_torch(a.detach(), return_ctype=True),
+             wp.from_torch(idx.to(torch.int32), return_ctype=True),
+             wp.from_torch(grad_g, return_ctype=True, dtype=wp.vec4f),
+         ),
+     )
+
+     return [grad_a, grad_g]
+
+
+ @torch.library.register_fake("aimnet::conv_sv_2d_sp_bwd")
+ def _(grad_output: Tensor, a: Tensor, idx: Tensor, g: Tensor) -> list[Tensor]:
+     B_out, M = idx.shape
+     G = a.shape[2]
+     return [
+         torch.empty_like(a),
+         torch.empty(B_out, M, G, 4, dtype=g.dtype, device=g.device),
+     ]
+
+
+ @torch.library.custom_op(
+     "aimnet::conv_sv_2d_sp_bwd_bwd",
+     mutates_args=(),
+     device_types=["cuda"],
+ )
+ def _(
+     grad_output: Tensor,
+     grad2_a: Tensor,
+     grad2_g: Tensor,
+     a: Tensor,
+     idx: Tensor,
+     g: Tensor,
+ ) -> list[Tensor]:
+     """Double backward primitive for conv_sv_2d_sp."""
+     stream = _get_stream(a.device)
+     device = wp.device_from_torch(a.device)
+     B, A, G = a.shape
+     B_out, M = idx.shape
+
+     grad_grad_output = torch.zeros(B, A, G, 4, dtype=a.dtype, device=a.device)
+     grad_a_double = torch.zeros_like(a)
+     grad_g_double = torch.zeros(B_out, M, G, 4, dtype=a.dtype, device=a.device)
+
+     grad_output_contig = grad_output.detach().contiguous()
+     grad2_a_contig = grad2_a.detach().contiguous()
+     grad2_g_contig = grad2_g.detach().contiguous()
+
+     # Contribution from grad2_g to grad_grad_output
+     wp.launch(
+         _conv_sv_2d_sp_double_backward_g_contrib_kernel,
+         dim=(B - 1, A, G),  # B-1: exclude padding row
+         stream=stream,
+         device=device,
+         inputs=(
+             wp.from_torch(grad2_g_contig, return_ctype=True, dtype=wp.vec4f),
+             wp.from_torch(a.detach(), return_ctype=True),
+             wp.from_torch(idx.to(torch.int32), return_ctype=True),
+             wp.from_torch(grad_grad_output, return_ctype=True, dtype=wp.vec4f),
+         ),
+     )
+
+     # Contribution from grad2_a to grad_grad_output
+     grad_output_2_a = torch.zeros(B, A, G, 4, dtype=a.dtype, device=a.device)
+     wp.launch(
+         _conv_sv_2d_sp_double_backward_a_contrib_kernel,
+         dim=(B - 1, A, G),  # B-1: exclude padding row
+         stream=stream,
+         device=device,
+         inputs=(
+             wp.from_torch(grad2_a_contig, return_ctype=True),
+             wp.from_torch(idx.to(torch.int32), return_ctype=True),
+             wp.from_torch(g.detach(), return_ctype=True, dtype=wp.vec4f),
+             wp.from_torch(grad_output_2_a, return_ctype=True, dtype=wp.vec4f),
+         ),
+     )
+     grad_grad_output = grad_grad_output + grad_output_2_a
+
+     # Mixed partial: d(grad_a)/dg -> grad_g_double
+     wp.launch(
+         _conv_sv_2d_sp_double_backward_a_g_kernel,
+         dim=(B_out - 1, M, G),  # B_out-1: exclude padding row
+         stream=stream,
+         device=device,
+         inputs=(
+             wp.from_torch(grad2_a_contig, return_ctype=True),
+             wp.from_torch(idx.to(torch.int32), return_ctype=True),
+             wp.from_torch(grad_output_contig, return_ctype=True, dtype=wp.vec4f),
+             wp.from_torch(grad_g_double, return_ctype=True, dtype=wp.vec4f),
+         ),
+     )
+
+     # Mixed partial: d(grad_g)/da -> grad_a_double
+     wp.launch(
+         _conv_sv_2d_sp_backward_a_kernel,
+         dim=(B - 1, A, G),  # B-1: exclude padding row
+         stream=stream,
+         device=device,
+         inputs=(
+             wp.from_torch(grad_output_contig, return_ctype=True, dtype=wp.vec4f),
+             wp.from_torch(idx.to(torch.int32), return_ctype=True),
+             wp.from_torch(grad2_g_contig, return_ctype=True, dtype=wp.vec4f),
+             wp.from_torch(grad_a_double, return_ctype=True),
+         ),
+     )
+
+     return [grad_grad_output, grad_a_double, grad_g_double]
+
+
+ @torch.library.register_fake("aimnet::conv_sv_2d_sp_bwd_bwd")
+ def _(
+     grad_output: Tensor,
+     grad2_a: Tensor,
+     grad2_g: Tensor,
+     a: Tensor,
+     idx: Tensor,
+     g: Tensor,
+ ) -> list[Tensor]:
+     B, A, G = a.shape
+     B_out, M = idx.shape
+     return [
+         torch.empty(B, A, G, 4, dtype=a.dtype, device=a.device),
+         torch.empty_like(a),
+         torch.empty(B_out, M, G, 4, dtype=a.dtype, device=a.device),
+     ]
+
+
+ # =============================================================================
+ # Autograd Registration
+ # =============================================================================
+
+
+ def _conv_sv_2d_sp_setup_fwd_context(ctx, inputs, output):
+     """Setup context for forward pass."""
+     a, idx, g = inputs
+     ctx.save_for_backward(a, idx, g)
+
+
+ def _conv_sv_2d_sp_setup_bwd_context(ctx, inputs, output):
+     """Setup context for backward pass."""
+     grad_output, a, idx, g = inputs
+     ctx.save_for_backward(grad_output, a, idx, g)
+
+
+ @torch.compiler.allow_in_graph
+ def _conv_sv_2d_sp_bwd(ctx, grad_output):
+     """Backward pass for conv_sv_2d_sp."""
+     a, idx, g = ctx.saved_tensors
+     grad_a, grad_g = torch.ops.aimnet.conv_sv_2d_sp_bwd(grad_output.contiguous(), a, idx, g)
+     return grad_a, None, grad_g
+
+
+ @torch.compiler.allow_in_graph
+ def _conv_sv_2d_sp_bwd_bwd(ctx, *grad_outputs):
+     """Double backward pass for conv_sv_2d_sp."""
+     grad2_a = grad_outputs[0][0]
+     grad2_g = grad_outputs[0][1]
+
+     grad_output_saved, a, idx, g = ctx.saved_tensors
+
+     if grad2_a is None:
+         grad2_a = torch.zeros_like(a)
+     if grad2_g is None:
+         B_out, M = idx.shape
+         G = a.shape[2]
+         grad2_g = torch.zeros(B_out, M, G, 4, dtype=g.dtype, device=g.device)
+
+     outputs = torch.ops.aimnet.conv_sv_2d_sp_bwd_bwd(grad_output_saved, grad2_a, grad2_g, a, idx, g)
+
+     return outputs[0], outputs[1], None, outputs[2]
+
+
+ torch.library.register_autograd(
+     "aimnet::conv_sv_2d_sp_fwd",
+     _conv_sv_2d_sp_bwd,
+     setup_context=_conv_sv_2d_sp_setup_fwd_context,
+ )
+
+ torch.library.register_autograd(
+     "aimnet::conv_sv_2d_sp_bwd",
+     _conv_sv_2d_sp_bwd_bwd,
+     setup_context=_conv_sv_2d_sp_setup_bwd_context,
+ )
+
+
+ # =============================================================================
+ # Public API
+ # =============================================================================
+
+
+ def conv_sv_2d_sp(a: Tensor, idx: Tensor, g: Tensor) -> Tensor:
+     """Compute conv_sv_2d_sp with support for 1st and 2nd order derivatives.
+
+     Parameters
+     ----------
+     a : Tensor
+         Input tensor of shape (B, A, G).
+     idx : Tensor
+         Index tensor of shape (B, M).
+     g : Tensor
+         Gate tensor of shape (B, M, G, 4).
+
+     Returns
+     -------
+     Tensor
+         Output tensor of shape (B, A, G, 4).
+     """
+     return torch.ops.aimnet.conv_sv_2d_sp_fwd(a, idx, g)
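For readers checking the kernel semantics, the sketch below restates in eager-mode PyTorch what the docstrings above describe. It is not part of the package: the helper name conv_sv_2d_sp_reference, the example sizes, and the smoke test are illustrative assumptions; only the padding convention (idx padded with B - 1, the last row acting as a dummy entry) is taken from the kernels themselves.

import torch

def conv_sv_2d_sp_reference(a, idx, g):
    # Hypothetical reference, not shipped with aimnet.
    # output[b, a, g, d] = sum_m a[idx[b, m], a, g] * g[b, m, g, d],
    # skipping padded neighbors (idx entries equal to B - 1).
    B = a.shape[0]
    valid = (idx < B - 1).to(g.dtype)      # (B, M), 0.0 for padded neighbors
    a_sel = a[idx.long()]                  # (B, M, A, G), gather rows of `a` by neighbor index
    g_masked = g * valid[..., None, None]  # (B, M, G, 4)
    return torch.einsum("bmag,bmgd->bagd", a_sel, g_masked)

if torch.cuda.is_available():
    B, A, G, M = 5, 3, 7, 4                # arbitrary example sizes
    a = torch.randn(B, A, G, device="cuda", requires_grad=True)
    g = torch.randn(B, M, G, 4, device="cuda", requires_grad=True)
    idx = torch.randint(0, B - 1, (B, M), device="cuda", dtype=torch.int32)
    idx[-1] = B - 1                        # last row is the padding row

    out = torch.ops.aimnet.conv_sv_2d_sp_fwd(a, idx, g)  # same as conv_sv_2d_sp(a, idx, g)
    torch.testing.assert_close(out, conv_sv_2d_sp_reference(a, idx, g))

    # A first derivative taken with create_graph=True builds a graph through the
    # registered backward op; backpropagating through it reaches the double-backward op.
    (grad_a,) = torch.autograd.grad(out.square().sum(), a, create_graph=True)
    grad_a.sum().backward()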
aimnet/models/__init__.py CHANGED
@@ -1,2 +1,14 @@
  from .aimnet2 import AIMNet2  # noqa: F401
- from .base import AIMNet2Base  # noqa: F401
+ from .base import AIMNet2Base, load_model  # noqa: F401
+ from .utils import (  # noqa: F401
+     extract_coulomb_rc,
+     extract_d3_params,
+     extract_species,
+     has_d3ts,
+     has_d3ts_in_config,
+     has_dftd3_in_config,
+     has_dispersion,
+     has_externalizable_dftd3,
+     has_lrcoulomb,
+     iter_lrcoulomb_mods,
+ )
aimnet/models/aimnet2.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Dict, List, Mapping, Sequence, Tuple, Union
+ from collections.abc import Mapping, Sequence

  import torch
  from torch import Tensor, nn
@@ -8,17 +8,16 @@ from aimnet.models.base import AIMNet2Base
  from aimnet.modules import AEVSV, MLP, ConvSV, Embedding


- # pylint: disable=too-many-arguments, too-many-instance-attributes
  class AIMNet2(AIMNet2Base):
      def __init__(
          self,
-         aev: Dict,
+         aev: dict,
          nfeature: int,
          d2features: bool,
          ncomb_v: int,
-         hidden: Tuple[List[int]],
+         hidden: tuple[list[int]],
          aim_size: int,
-         outputs: Union[List[nn.Module], Dict[str, nn.Module]],
+         outputs: list[nn.Module] | dict[str, nn.Module],
          num_charge_channels: int = 1,
      ):
          super().__init__()
@@ -29,7 +28,7 @@ class AIMNet2(AIMNet2Base):

          self.aev = AEVSV(**aev)
          nshifts_s = aev["nshifts_s"]
-         nshifts_v = aev.get("nshitfs_v") or nshifts_s
+         nshifts_v = aev.get("nshifts_v") or nshifts_s
          if d2features:
              if nshifts_s != nshifts_v:
                  raise ValueError("nshifts_s must be equal to nshifts_v for d2features")
@@ -49,7 +48,7 @@ class AIMNet2(AIMNet2Base):
              self.afv.weight.clone().unsqueeze(-1).expand(64, nfeature, nshifts_s).flatten(-2, -1)
          )

-         conv_param = {"nshifts_s": nshifts_s, "nshifts_v": nshifts_v, "ncomb_v": ncomb_v, "do_vector": True}
+         conv_param = {"nshifts_s": nshifts_s, "nshifts_v": nshifts_v, "ncomb_v": ncomb_v}
          self.conv_a = ConvSV(nchannel=nfeature, d2features=d2features, **conv_param)
          self.conv_q = ConvSV(nchannel=num_charge_channels, d2features=False, **conv_param)

@@ -90,7 +89,7 @@ class AIMNet2(AIMNet2Base):
          else:
              raise TypeError("`outputs` is not either list or dict")

-     def _preprocess_spin_polarized_charge(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+     def _preprocess_spin_polarized_charge(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
          if "mult" not in data:
              raise ValueError("mult key is required for NSE if two channels for charge are not provided")
          _half_spin = 0.5 * (data["mult"] - 1.0)
@@ -98,27 +97,27 @@ class AIMNet2(AIMNet2Base):
          data["charge"] = torch.stack([_half_q + _half_spin, _half_q - _half_spin], dim=-1)
          return data

-     def _postprocess_spin_polarized_charge(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+     def _postprocess_spin_polarized_charge(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
          data["spin_charges"] = data["charges"][..., 0] - data["charges"][..., 1]
          data["charges"] = data["charges"].sum(dim=-1)
          data["charge"] = data["charge"].sum(dim=-1)
          return data

-     def _prepare_in_a(self, data: Dict[str, Tensor]) -> Tensor:
-         a_i, a_j = nbops.get_ij(data["a"], data)
-         avf_a = self.conv_a(a_j, data["gs"], data["gv"])
+     def _prepare_in_a(self, data: dict[str, Tensor]) -> Tensor:
+         a_i = nbops.get_i(data["a"], data)
+         avf_a = self.conv_a(data, data["a"])
          if self.d2features:
              a_i = a_i.flatten(-2, -1)
          _in = torch.cat([a_i.squeeze(-2), avf_a], dim=-1)
          return _in

-     def _prepare_in_q(self, data: Dict[str, Tensor]) -> Tensor:
-         q_i, q_j = nbops.get_ij(data["charges"], data)
-         avf_q = self.conv_q(q_j, data["gs"], data["gv"])
+     def _prepare_in_q(self, data: dict[str, Tensor]) -> Tensor:
+         q_i = nbops.get_i(data["charges"], data)
+         avf_q = self.conv_q(data, data["charges"])
          _in = torch.cat([q_i.squeeze(-2), avf_q], dim=-1)
          return _in

-     def _update_q(self, data: Dict[str, Tensor], x: Tensor, delta_q: bool = True) -> Dict[str, Tensor]:
+     def _update_q(self, data: dict[str, Tensor], x: Tensor, delta_q: bool = True) -> dict[str, Tensor]:
          _q, _f, delta_a = x.split(
              [
                  self.num_charge_channels,
@@ -127,16 +126,17 @@ class AIMNet2(AIMNet2Base):
              ],
              dim=-1,
          )
-         # for loss
+         # Charge conservation violation penalty for training loss
          data["_delta_Q"] = data["charge"] - nbops.mol_sum(_q, data)
          q = data["charges"] + _q if delta_q else _q
+         data["charges_pre"] = q if self.num_charge_channels == 2 else q.squeeze(-1)
          f = _f.pow(2)
          q = ops.nse(data["charge"], q, f, data, epsilon=1.0e-6)
          data["charges"] = q
          data["a"] = data["a"] + delta_a.view_as(data["a"])
          return data

-     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+     def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
          data = self.prepare_input(data)

          # initial features
@@ -149,13 +149,11 @@ class AIMNet2(AIMNet2Base):
          if self.num_charge_channels == 2:
              data = self._preprocess_spin_polarized_charge(data)
          else:
-             # make sure that charge has channel dimension
+             # Ensure charge tensor has channel dimension for consistency with features
              data["charge"] = data["charge"].unsqueeze(-1)

-         # AEV
          data = self.aev(data)

-         # MP iterations
          _npass = len(self.mlps)
          for ipass, mlp in enumerate(self.mlps):
              if ipass == 0:
@@ -181,7 +179,6 @@ class AIMNet2(AIMNet2Base):
              data["charges"] = data["charges"].squeeze(-1)
              data["charge"] = data["charge"].squeeze(-1)

-         # readout
          for m in self.outputs.children():
              data = m(data)