aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +7 -0
- aimnet/base.py +24 -8
- aimnet/calculators/__init__.py +4 -4
- aimnet/calculators/aimnet2ase.py +19 -6
- aimnet/calculators/calculator.py +868 -108
- aimnet/calculators/model_registry.py +2 -5
- aimnet/calculators/model_registry.yaml +55 -17
- aimnet/cli.py +62 -6
- aimnet/config.py +8 -9
- aimnet/data/sgdataset.py +23 -22
- aimnet/kernels/__init__.py +66 -0
- aimnet/kernels/conv_sv_2d_sp_wp.py +478 -0
- aimnet/models/__init__.py +13 -1
- aimnet/models/aimnet2.py +19 -22
- aimnet/models/base.py +183 -15
- aimnet/models/convert.py +30 -0
- aimnet/models/utils.py +735 -0
- aimnet/modules/__init__.py +1 -1
- aimnet/modules/aev.py +49 -48
- aimnet/modules/core.py +14 -13
- aimnet/modules/lr.py +520 -115
- aimnet/modules/ops.py +537 -0
- aimnet/nbops.py +105 -15
- aimnet/ops.py +90 -18
- aimnet/train/export_model.py +226 -0
- aimnet/train/loss.py +7 -7
- aimnet/train/metrics.py +5 -6
- aimnet/train/train.py +4 -1
- aimnet/train/utils.py +42 -13
- aimnet-0.1.0.dist-info/METADATA +308 -0
- aimnet-0.1.0.dist-info/RECORD +43 -0
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info}/WHEEL +1 -1
- aimnet-0.1.0.dist-info/entry_points.txt +3 -0
- aimnet/calculators/nb_kernel_cpu.py +0 -222
- aimnet/calculators/nb_kernel_cuda.py +0 -217
- aimnet/calculators/nbmat.py +0 -220
- aimnet/train/pt2jpt.py +0 -81
- aimnet-0.0.1.dist-info/METADATA +0 -78
- aimnet-0.0.1.dist-info/RECORD +0 -41
- aimnet-0.0.1.dist-info/entry_points.txt +0 -5
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info/licenses}/LICENSE +0 -0
aimnet/modules/lr.py
CHANGED
|
@@ -1,12 +1,96 @@
|
|
|
1
|
-
|
|
1
|
+
import os
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
from torch import Tensor, nn
|
|
5
5
|
|
|
6
6
|
from aimnet import constants, nbops, ops
|
|
7
|
+
from aimnet.modules.ops import dftd3_energy
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _calc_coulomb_sr(
|
|
11
|
+
data: dict[str, Tensor],
|
|
12
|
+
rc: Tensor,
|
|
13
|
+
envelope: str,
|
|
14
|
+
key_in: str,
|
|
15
|
+
factor: float,
|
|
16
|
+
) -> Tensor:
|
|
17
|
+
"""Shared short-range Coulomb energy calculation.
|
|
18
|
+
|
|
19
|
+
Computes pairwise Coulomb energy with envelope-weighted cutoff.
|
|
20
|
+
|
|
21
|
+
Parameters
|
|
22
|
+
----------
|
|
23
|
+
data : dict
|
|
24
|
+
Data dictionary containing d_ij distances and charges.
|
|
25
|
+
rc : Tensor
|
|
26
|
+
Cutoff radius tensor.
|
|
27
|
+
envelope : str
|
|
28
|
+
Envelope function: "exp" or "cosine".
|
|
29
|
+
key_in : str
|
|
30
|
+
Key for charges in data dict.
|
|
31
|
+
factor : float
|
|
32
|
+
Unit conversion factor (half_Hartree * Bohr).
|
|
33
|
+
|
|
34
|
+
Returns
|
|
35
|
+
-------
|
|
36
|
+
Tensor
|
|
37
|
+
Short-range Coulomb energy per molecule.
|
|
38
|
+
"""
|
|
39
|
+
d_ij = data["d_ij"]
|
|
40
|
+
q = data[key_in]
|
|
41
|
+
q_i, q_j = nbops.get_ij(q, data)
|
|
42
|
+
q_ij = q_i * q_j
|
|
43
|
+
if envelope == "exp":
|
|
44
|
+
fc = ops.exp_cutoff(d_ij, rc)
|
|
45
|
+
else: # cosine
|
|
46
|
+
fc = ops.cosine_cutoff(d_ij, rc.item())
|
|
47
|
+
e_ij = fc * q_ij / d_ij
|
|
48
|
+
e_ij = nbops.mask_ij_(e_ij, data, 0.0)
|
|
49
|
+
# Accumulate in float64 for precision
|
|
50
|
+
e_i = e_ij.sum(-1, dtype=torch.float64)
|
|
51
|
+
return factor * nbops.mol_sum(e_i, data)
|
|
7
52
|
|
|
8
53
|
|
|
9
54
|
class LRCoulomb(nn.Module):
|
|
55
|
+
"""Long-range Coulomb energy module.
|
|
56
|
+
|
|
57
|
+
Computes electrostatic energy using one of several methods:
|
|
58
|
+
simple (all pairs), DSF (damped shifted force), or Ewald summation.
|
|
59
|
+
|
|
60
|
+
Parameters
|
|
61
|
+
----------
|
|
62
|
+
key_in : str
|
|
63
|
+
Key for input charges in data dict. Default is "charges".
|
|
64
|
+
key_out : str
|
|
65
|
+
Key for output energy in data dict. Default is "e_h".
|
|
66
|
+
rc : float
|
|
67
|
+
Short-range cutoff radius. Default is 4.6 Angstrom.
|
|
68
|
+
method : str
|
|
69
|
+
Coulomb method: "simple", "dsf", or "ewald". Default is "simple".
|
|
70
|
+
dsf_alpha : float
|
|
71
|
+
Alpha parameter for DSF method. Default is 0.2.
|
|
72
|
+
dsf_rc : float
|
|
73
|
+
Cutoff for DSF method. Default is 15.0.
|
|
74
|
+
ewald_accuracy : float
|
|
75
|
+
Target accuracy for Ewald summation. Controls real-space and
|
|
76
|
+
reciprocal-space cutoffs. Lower values give higher accuracy.
|
|
77
|
+
Default is 1e-8.
|
|
78
|
+
subtract_sr : bool
|
|
79
|
+
Whether to subtract short-range contribution. Default is True.
|
|
80
|
+
envelope : str
|
|
81
|
+
Envelope function for SR cutoff: "exp" or "cosine". Default is "exp".
|
|
82
|
+
|
|
83
|
+
Notes
|
|
84
|
+
-----
|
|
85
|
+
Energy accumulation uses float64 for numerical precision, particularly
|
|
86
|
+
important for large systems where many small contributions can suffer
|
|
87
|
+
from floating-point error accumulation.
|
|
88
|
+
|
|
89
|
+
Neighbor list keys follow a suffix resolution pattern: methods first look
|
|
90
|
+
for module-specific keys (e.g., nbmat_coulomb, shifts_coulomb), falling
|
|
91
|
+
back to shared _lr suffix (nbmat_lr, shifts_lr) if not found.
|
|
92
|
+
"""
|
|
93
|
+
|
|
10
94
|
def __init__(
|
|
11
95
|
self,
|
|
12
96
|
key_in: str = "charges",
|
|
@@ -15,6 +99,9 @@ class LRCoulomb(nn.Module):
|
|
|
15
99
|
method: str = "simple",
|
|
16
100
|
dsf_alpha: float = 0.2,
|
|
17
101
|
dsf_rc: float = 15.0,
|
|
102
|
+
ewald_accuracy: float = 1e-8,
|
|
103
|
+
subtract_sr: bool = True,
|
|
104
|
+
envelope: str = "exp",
|
|
18
105
|
):
|
|
19
106
|
super().__init__()
|
|
20
107
|
self.key_in = key_in
|
|
@@ -23,55 +110,63 @@ class LRCoulomb(nn.Module):
|
|
|
23
110
|
self.register_buffer("rc", torch.tensor(rc))
|
|
24
111
|
self.dsf_alpha = dsf_alpha
|
|
25
112
|
self.dsf_rc = dsf_rc
|
|
113
|
+
self.ewald_accuracy = ewald_accuracy
|
|
114
|
+
self.subtract_sr = subtract_sr
|
|
115
|
+
if envelope not in ("exp", "cosine"):
|
|
116
|
+
raise ValueError(f"Unknown envelope {envelope}, must be 'exp' or 'cosine'")
|
|
117
|
+
self.envelope = envelope
|
|
26
118
|
if method in ("simple", "dsf", "ewald"):
|
|
27
119
|
self.method = method
|
|
28
120
|
else:
|
|
29
121
|
raise ValueError(f"Unknown method {method}")
|
|
30
122
|
|
|
31
|
-
def coul_simple(self, data:
|
|
32
|
-
|
|
33
|
-
d_ij = data["d_ij_lr"]
|
|
34
|
-
q = data[self.key_in]
|
|
35
|
-
q_i, q_j = nbops.get_ij(q, data, suffix="_lr")
|
|
36
|
-
q_ij = q_i * q_j
|
|
37
|
-
fc = 1.0 - ops.exp_cutoff(d_ij, self.rc)
|
|
38
|
-
e_ij = fc * q_ij / d_ij
|
|
39
|
-
e_ij = nbops.mask_ij_(e_ij, data, 0.0, suffix="_lr")
|
|
40
|
-
e_i = e_ij.sum(-1)
|
|
41
|
-
e = self._factor * nbops.mol_sum(e_i, data)
|
|
42
|
-
return e
|
|
123
|
+
def coul_simple(self, data: dict[str, Tensor]) -> Tensor:
|
|
124
|
+
"""Compute pairwise Coulomb energy.
|
|
43
125
|
|
|
44
|
-
|
|
45
|
-
|
|
126
|
+
With subtract_sr=True (default): Returns LR only (FULL - SR)
|
|
127
|
+
With subtract_sr=False: Returns FULL pairwise Coulomb
|
|
128
|
+
"""
|
|
129
|
+
suffix = nbops.resolve_suffix(data, ["_coulomb", "_lr"])
|
|
130
|
+
data = ops.lazy_calc_dij(data, suffix)
|
|
131
|
+
d_ij = data[f"d_ij{suffix}"]
|
|
46
132
|
q = data[self.key_in]
|
|
47
|
-
q_i, q_j = nbops.get_ij(q, data)
|
|
133
|
+
q_i, q_j = nbops.get_ij(q, data, suffix=suffix)
|
|
48
134
|
q_ij = q_i * q_j
|
|
49
|
-
|
|
50
|
-
e_ij =
|
|
51
|
-
e_ij = nbops.mask_ij_(e_ij, data, 0.0)
|
|
52
|
-
e_i = e_ij.sum(-1)
|
|
135
|
+
# Compute FULL pairwise Coulomb (no exp_cutoff weighting)
|
|
136
|
+
e_ij = q_ij / d_ij
|
|
137
|
+
e_ij = nbops.mask_ij_(e_ij, data, 0.0, suffix=suffix)
|
|
138
|
+
e_i = e_ij.sum(-1, dtype=torch.float64)
|
|
53
139
|
e = self._factor * nbops.mol_sum(e_i, data)
|
|
140
|
+
# Same pattern as dsf/ewald - subtract SR to get LR
|
|
141
|
+
if self.subtract_sr:
|
|
142
|
+
e = e - self.coul_simple_sr(data)
|
|
54
143
|
return e
|
|
55
144
|
|
|
56
|
-
def
|
|
57
|
-
data
|
|
58
|
-
|
|
145
|
+
def coul_simple_sr(self, data: dict[str, Tensor]) -> Tensor:
|
|
146
|
+
return _calc_coulomb_sr(data, self.rc, self.envelope, self.key_in, self._factor)
|
|
147
|
+
|
|
148
|
+
def coul_dsf(self, data: dict[str, Tensor]) -> Tensor:
|
|
149
|
+
suffix = nbops.resolve_suffix(data, ["_coulomb", "_lr"])
|
|
150
|
+
data = ops.lazy_calc_dij(data, suffix)
|
|
151
|
+
d_ij = data[f"d_ij{suffix}"]
|
|
59
152
|
q = data[self.key_in]
|
|
60
|
-
q_i, q_j = nbops.get_ij(q, data, suffix=
|
|
153
|
+
q_i, q_j = nbops.get_ij(q, data, suffix=suffix)
|
|
61
154
|
J = ops.coulomb_matrix_dsf(d_ij, self.dsf_rc, self.dsf_alpha, data)
|
|
62
|
-
e = (q_i * q_j * J).sum(-1)
|
|
155
|
+
e = (q_i * q_j * J).sum(-1, dtype=torch.float64)
|
|
63
156
|
e = self._factor * nbops.mol_sum(e, data)
|
|
64
|
-
|
|
157
|
+
if self.subtract_sr:
|
|
158
|
+
e = e - self.coul_simple_sr(data)
|
|
65
159
|
return e
|
|
66
160
|
|
|
67
|
-
def coul_ewald(self, data:
|
|
68
|
-
J = ops.coulomb_matrix_ewald(data["coord"], data["cell"])
|
|
161
|
+
def coul_ewald(self, data: dict[str, Tensor]) -> Tensor:
|
|
162
|
+
J = ops.coulomb_matrix_ewald(data["coord"], data["cell"], accuracy=self.ewald_accuracy)
|
|
69
163
|
q_i, q_j = data["charges"].unsqueeze(-1), data["charges"].unsqueeze(-2)
|
|
70
|
-
e = self._factor * (q_i * q_j * J).flatten(-2, -1).sum(-1)
|
|
71
|
-
|
|
164
|
+
e = self._factor * (q_i * q_j * J).flatten(-2, -1).sum(-1, dtype=torch.float64)
|
|
165
|
+
if self.subtract_sr:
|
|
166
|
+
e = e - self.coul_simple_sr(data)
|
|
72
167
|
return e
|
|
73
168
|
|
|
74
|
-
def forward(self, data:
|
|
169
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
75
170
|
if self.method == "simple":
|
|
76
171
|
e = self.coul_simple(data)
|
|
77
172
|
elif self.method == "dsf":
|
|
@@ -81,46 +176,131 @@ class LRCoulomb(nn.Module):
|
|
|
81
176
|
else:
|
|
82
177
|
raise ValueError(f"Unknown method {self.method}")
|
|
83
178
|
if self.key_out in data:
|
|
84
|
-
data[self.key_out] = data[self.key_out] + e
|
|
179
|
+
data[self.key_out] = data[self.key_out].double() + e
|
|
85
180
|
else:
|
|
86
181
|
data[self.key_out] = e
|
|
87
182
|
return data
|
|
88
183
|
|
|
89
184
|
|
|
185
|
+
class SRCoulomb(nn.Module):
|
|
186
|
+
"""Subtract short-range Coulomb contribution from energy.
|
|
187
|
+
|
|
188
|
+
For models trained with "simple" Coulomb mode, the NN has implicitly learned
|
|
189
|
+
the short-range Coulomb interaction. When using DSF or Ewald summation for
|
|
190
|
+
the full Coulomb energy, we need to subtract this short-range contribution
|
|
191
|
+
to avoid double-counting.
|
|
192
|
+
|
|
193
|
+
Parameters
|
|
194
|
+
----------
|
|
195
|
+
rc : float
|
|
196
|
+
Cutoff radius for short-range Coulomb. Default is 4.6 Angstrom.
|
|
197
|
+
key_in : str
|
|
198
|
+
Key for input charges in data dict. Default is "charges".
|
|
199
|
+
key_out : str
|
|
200
|
+
Key for output energy in data dict. Default is "energy".
|
|
201
|
+
envelope : str
|
|
202
|
+
Envelope function for cutoff: "exp" (mollifier) or "cosine". Default is "exp".
|
|
203
|
+
"""
|
|
204
|
+
|
|
205
|
+
def __init__(
|
|
206
|
+
self,
|
|
207
|
+
rc: float = 4.6,
|
|
208
|
+
key_in: str = "charges",
|
|
209
|
+
key_out: str = "energy",
|
|
210
|
+
envelope: str = "exp",
|
|
211
|
+
):
|
|
212
|
+
super().__init__()
|
|
213
|
+
self.key_in = key_in
|
|
214
|
+
self.key_out = key_out
|
|
215
|
+
self._factor = constants.half_Hartree * constants.Bohr
|
|
216
|
+
self.register_buffer("rc", torch.tensor(rc))
|
|
217
|
+
if envelope not in ("exp", "cosine"):
|
|
218
|
+
raise ValueError(f"Unknown envelope {envelope}, must be 'exp' or 'cosine'")
|
|
219
|
+
self.envelope = envelope
|
|
220
|
+
|
|
221
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
222
|
+
"""Subtract short-range Coulomb from energy."""
|
|
223
|
+
e_sr = _calc_coulomb_sr(data, self.rc, self.envelope, self.key_in, self._factor)
|
|
224
|
+
|
|
225
|
+
# Subtract short-range Coulomb from energy (in float64)
|
|
226
|
+
if self.key_out in data:
|
|
227
|
+
data[self.key_out] = data[self.key_out].double() - e_sr
|
|
228
|
+
else:
|
|
229
|
+
data[self.key_out] = -e_sr
|
|
230
|
+
return data
|
|
231
|
+
|
|
232
|
+
|
|
90
233
|
class DispParam(nn.Module):
|
|
91
234
|
def __init__(
|
|
92
235
|
self,
|
|
93
|
-
ref_c6:
|
|
94
|
-
ref_alpha:
|
|
95
|
-
ptfile:
|
|
236
|
+
ref_c6: dict[int, Tensor] | Tensor | None = None,
|
|
237
|
+
ref_alpha: dict[int, Tensor] | Tensor | None = None,
|
|
238
|
+
ptfile: str | None = None,
|
|
96
239
|
key_in: str = "disp_param",
|
|
97
240
|
key_out: str = "disp_param",
|
|
98
241
|
):
|
|
99
242
|
super().__init__()
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
243
|
+
# Validate: cannot mix ptfile with ref_c6/ref_alpha
|
|
244
|
+
if ptfile is not None and (ref_c6 is not None or ref_alpha is not None):
|
|
245
|
+
raise ValueError("Cannot specify both ptfile and ref_c6/ref_alpha.")
|
|
246
|
+
|
|
247
|
+
# Load reference data
|
|
248
|
+
if ptfile is not None:
|
|
249
|
+
ref = torch.load(ptfile, weights_only=True)
|
|
250
|
+
elif ref_c6 is not None or ref_alpha is not None:
|
|
251
|
+
ref = torch.zeros(87, 2)
|
|
252
|
+
for i, p in enumerate([ref_c6, ref_alpha]):
|
|
253
|
+
if p is not None:
|
|
254
|
+
if isinstance(p, Tensor):
|
|
255
|
+
ref[: p.shape[0], i] = p
|
|
256
|
+
else:
|
|
257
|
+
for k, v in p.items():
|
|
258
|
+
ref[k, i] = v
|
|
259
|
+
else:
|
|
260
|
+
# Placeholder - will be populated by load_state_dict
|
|
261
|
+
ref = torch.zeros(87, 2)
|
|
262
|
+
|
|
263
|
+
# Element 0 represents dummy atoms with c6=0 and alpha=1
|
|
117
264
|
ref[0, 0] = 0.0
|
|
118
265
|
ref[0, 1] = 1.0
|
|
119
266
|
self.register_buffer("disp_param0", ref)
|
|
120
267
|
self.key_in = key_in
|
|
121
268
|
self.key_out = key_out
|
|
122
269
|
|
|
123
|
-
def
|
|
270
|
+
def _load_from_state_dict(
|
|
271
|
+
self,
|
|
272
|
+
state_dict: dict,
|
|
273
|
+
prefix: str,
|
|
274
|
+
local_metadata: dict,
|
|
275
|
+
strict: bool,
|
|
276
|
+
missing_keys: list,
|
|
277
|
+
unexpected_keys: list,
|
|
278
|
+
error_msgs: list,
|
|
279
|
+
) -> None:
|
|
280
|
+
# Resize placeholder buffer to match checkpoint size before loading
|
|
281
|
+
key = prefix + "disp_param0"
|
|
282
|
+
if key in state_dict:
|
|
283
|
+
buf = state_dict[key]
|
|
284
|
+
if buf.shape != self.disp_param0.shape:
|
|
285
|
+
# Resize placeholder to match checkpoint
|
|
286
|
+
self.disp_param0 = torch.zeros_like(buf)
|
|
287
|
+
|
|
288
|
+
# Validate buffer has non-zero values (safety check)
|
|
289
|
+
nonzero = (buf != 0).sum() / buf.numel()
|
|
290
|
+
if nonzero < 0.1:
|
|
291
|
+
import warnings
|
|
292
|
+
|
|
293
|
+
warnings.warn(
|
|
294
|
+
f"DispParam buffer appears to have mostly zero values (nonzero: {nonzero:.1%}). "
|
|
295
|
+
"This may indicate a loading issue.",
|
|
296
|
+
stacklevel=2,
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
super()._load_from_state_dict(
|
|
300
|
+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
|
301
|
+
)
|
|
302
|
+
|
|
303
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
124
304
|
disp_param_mult = data[self.key_in].clamp(min=-4, max=4).exp()
|
|
125
305
|
disp_param = self.disp_param0[data["numbers"]]
|
|
126
306
|
vals = disp_param * disp_param_mult
|
|
@@ -141,24 +321,28 @@ class D3TS(nn.Module):
|
|
|
141
321
|
self.key_in = key_in
|
|
142
322
|
self.key_out = key_out
|
|
143
323
|
|
|
144
|
-
def forward(self, data:
|
|
324
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
325
|
+
suffix = nbops.resolve_suffix(data, ["_dftd3", "_lr"])
|
|
326
|
+
|
|
145
327
|
disp_param = data[self.key_in]
|
|
146
|
-
disp_param_i, disp_param_j = nbops.get_ij(disp_param, data, suffix=
|
|
328
|
+
disp_param_i, disp_param_j = nbops.get_ij(disp_param, data, suffix=suffix)
|
|
147
329
|
c6_i, alpha_i = disp_param_i.unbind(dim=-1)
|
|
148
330
|
c6_j, alpha_j = disp_param_j.unbind(dim=-1)
|
|
149
331
|
|
|
150
332
|
# TS combination rule
|
|
151
333
|
c6ij = 2 * c6_i * c6_j / (c6_i * alpha_j / alpha_i + c6_j * alpha_i / alpha_j).clamp(min=1e-4)
|
|
334
|
+
c6ij = nbops.mask_ij_(c6ij, data, 0.0, suffix=suffix)
|
|
152
335
|
|
|
153
336
|
rr = self.r4r2[data["numbers"]]
|
|
154
|
-
rr_i, rr_j = nbops.get_ij(rr, data, suffix=
|
|
337
|
+
rr_i, rr_j = nbops.get_ij(rr, data, suffix=suffix)
|
|
155
338
|
rrij = 3 * rr_i * rr_j
|
|
156
|
-
rrij = nbops.mask_ij_(rrij, data, 1.0, suffix=
|
|
339
|
+
rrij = nbops.mask_ij_(rrij, data, 1.0, suffix=suffix)
|
|
157
340
|
r0ij = self.a1 * rrij.sqrt() + self.a2
|
|
158
341
|
|
|
159
|
-
ops.
|
|
160
|
-
d_ij = data["
|
|
342
|
+
ops.lazy_calc_dij(data, suffix)
|
|
343
|
+
d_ij = data[f"d_ij{suffix}"] * constants.Bohr_inv
|
|
161
344
|
e_ij = c6ij * (self.s6 / (d_ij.pow(6) + r0ij.pow(6)) + self.s8 * rrij / (d_ij.pow(8) + r0ij.pow(8)))
|
|
345
|
+
|
|
162
346
|
e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data)
|
|
163
347
|
|
|
164
348
|
if self.key_out in data:
|
|
@@ -170,74 +354,295 @@ class D3TS(nn.Module):
|
|
|
170
354
|
|
|
171
355
|
|
|
172
356
|
class DFTD3(nn.Module):
|
|
173
|
-
"""DFT-D3 implementation.
|
|
174
|
-
|
|
357
|
+
"""DFT-D3 implementation using nvalchemiops GPU-accelerated kernels.
|
|
358
|
+
|
|
359
|
+
BJ damping, C6 and C8 terms, without 3-body term.
|
|
360
|
+
|
|
361
|
+
This implementation uses nvalchemiops.interactions.dispersion.dftd3 for
|
|
362
|
+
GPU-accelerated computation of dispersion energies and forces. It is
|
|
363
|
+
differentiable through a custom autograd function.
|
|
364
|
+
|
|
365
|
+
Parameters
|
|
366
|
+
----------
|
|
367
|
+
s8 : float
|
|
368
|
+
Scaling factor for C8 term.
|
|
369
|
+
a1 : float
|
|
370
|
+
BJ damping parameter 1.
|
|
371
|
+
a2 : float
|
|
372
|
+
BJ damping parameter 2.
|
|
373
|
+
s6 : float, optional
|
|
374
|
+
Scaling factor for C6 term. Default is 1.0.
|
|
375
|
+
cutoff : float, optional
|
|
376
|
+
Cutoff distance in Angstroms for smoothing. Default is 15.0.
|
|
377
|
+
smoothing_fraction : float, optional
|
|
378
|
+
Fraction of cutoff distance used for smoothing window width.
|
|
379
|
+
Smoothing starts at cutoff * (1 - smoothing_fraction) and ends at cutoff.
|
|
380
|
+
Example: With cutoff=15.0 and smoothing_fraction=0.2:
|
|
381
|
+
- Smoothing starts at 12.0 Å (15.0 * 0.8)
|
|
382
|
+
- Smoothing ends at 15.0 Å
|
|
383
|
+
Default is 0.2 (20% of cutoff as smoothing window).
|
|
384
|
+
key_out : str, optional
|
|
385
|
+
Key for output energy in data dict. Default is "energy".
|
|
386
|
+
compute_forces : bool, optional
|
|
387
|
+
Whether to add forces to data dict. Default is False.
|
|
388
|
+
compute_virial : bool, optional
|
|
389
|
+
Whether to compute virial for cell gradients. Default is False.
|
|
390
|
+
|
|
391
|
+
Attributes
|
|
392
|
+
----------
|
|
393
|
+
smoothing_on : float
|
|
394
|
+
Distance where smoothing starts (Angstroms).
|
|
395
|
+
smoothing_off : float
|
|
396
|
+
Distance where smoothing ends / cutoff (Angstroms).
|
|
397
|
+
s6, s8, a1, a2 : float
|
|
398
|
+
BJ damping parameters.
|
|
399
|
+
|
|
400
|
+
Notes
|
|
401
|
+
-----
|
|
402
|
+
Neighbor list keys follow a suffix resolution pattern: methods first look
|
|
403
|
+
for module-specific keys (e.g., nbmat_dftd3, shifts_dftd3), falling back
|
|
404
|
+
to shared _lr suffix (nbmat_lr, shifts_lr) if not found.
|
|
175
405
|
"""
|
|
176
406
|
|
|
177
|
-
def __init__(
|
|
407
|
+
def __init__(
|
|
408
|
+
self,
|
|
409
|
+
s8: float,
|
|
410
|
+
a1: float,
|
|
411
|
+
a2: float,
|
|
412
|
+
s6: float = 1.0,
|
|
413
|
+
cutoff: float = 15.0,
|
|
414
|
+
smoothing_fraction: float = 0.2,
|
|
415
|
+
key_out: str = "energy",
|
|
416
|
+
compute_forces: bool = False,
|
|
417
|
+
compute_virial: bool = False,
|
|
418
|
+
):
|
|
178
419
|
super().__init__()
|
|
179
420
|
self.key_out = key_out
|
|
421
|
+
self.compute_forces = compute_forces
|
|
422
|
+
self.compute_virial = compute_virial
|
|
180
423
|
# BJ damping parameters
|
|
181
424
|
self.s6 = s6
|
|
182
425
|
self.s8 = s8
|
|
183
|
-
self.s9 = 4.0 / 3.0
|
|
184
426
|
self.a1 = a1
|
|
185
427
|
self.a2 = a2
|
|
186
|
-
self.a3 = 16.0
|
|
187
|
-
# CN parameters
|
|
188
|
-
self.k1 = -16.0
|
|
189
|
-
self.k3 = -4.0
|
|
190
|
-
# data
|
|
191
|
-
self.register_buffer("c6ab", torch.zeros(95, 95, 5, 5, 3))
|
|
192
|
-
self.register_buffer("r4r2", torch.zeros(95))
|
|
193
|
-
self.register_buffer("rcov", torch.zeros(95))
|
|
194
|
-
self.register_buffer("cnmax", torch.zeros(95))
|
|
195
|
-
sd = constants.get_dftd3_param()
|
|
196
|
-
self.load_state_dict(sd)
|
|
197
|
-
|
|
198
|
-
def _calc_c6ij(self, data: Dict[str, Tensor]) -> Tensor:
|
|
199
|
-
# CN part
|
|
200
|
-
# short range for CN
|
|
201
|
-
# d_ij = data["d_ij"] * constants.Bohr_inv
|
|
202
|
-
data = ops.lazy_calc_dij_lr(data)
|
|
203
|
-
d_ij = data["d_ij_lr"] * constants.Bohr_inv
|
|
204
|
-
|
|
205
|
-
numbers = data["numbers"]
|
|
206
|
-
numbers_i, numbers_j = nbops.get_ij(numbers, data, suffix="_lr")
|
|
207
|
-
rcov_i, rcov_j = nbops.get_ij(self.rcov[numbers], data, suffix="_lr")
|
|
208
|
-
rcov_ij = rcov_i + rcov_j
|
|
209
|
-
cn_ij = 1.0 / (1.0 + torch.exp(self.k1 * (rcov_ij / d_ij - 1.0)))
|
|
210
|
-
cn_ij = nbops.mask_ij_(cn_ij, data, 0.0, suffix="_lr")
|
|
211
|
-
cn = cn_ij.sum(-1)
|
|
212
|
-
cn = torch.clamp(cn, max=self.cnmax[numbers]).unsqueeze(-1).unsqueeze(-1)
|
|
213
|
-
cn_i, cn_j = nbops.get_ij(cn, data, suffix="_lr")
|
|
214
|
-
c6ab = self.c6ab[numbers_i, numbers_j]
|
|
215
|
-
c6ref, cnref_i, cnref_j = torch.unbind(c6ab, dim=-1)
|
|
216
|
-
c6ref = nbops.mask_ij_(c6ref, data, 0.0, suffix="_lr")
|
|
217
|
-
l_ij = torch.exp(self.k3 * ((cn_i - cnref_i).pow(2) + (cn_j - cnref_j).pow(2)))
|
|
218
|
-
w = l_ij.flatten(-2, -1).sum(-1)
|
|
219
|
-
z = torch.einsum("...ij,...ij->...", c6ref, l_ij)
|
|
220
|
-
_w = w < 1e-5
|
|
221
|
-
z[_w] = 0.0
|
|
222
|
-
c6_ij = z / w.clamp(min=1e-5)
|
|
223
|
-
return c6_ij
|
|
224
|
-
|
|
225
|
-
def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
|
226
|
-
c6ij = self._calc_c6ij(data)
|
|
227
428
|
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
rrij = nbops.mask_ij_(rrij, data, 1.0, suffix="_lr")
|
|
232
|
-
r0ij = self.a1 * rrij.sqrt() + self.a2
|
|
429
|
+
# Smoothing parameters as module attributes
|
|
430
|
+
self.smoothing_on: float = cutoff * (1 - smoothing_fraction)
|
|
431
|
+
self.smoothing_off: float = cutoff
|
|
233
432
|
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
433
|
+
# Load D3 reference parameters and convert to nvalchemiops format
|
|
434
|
+
dirname = os.path.dirname(os.path.dirname(__file__))
|
|
435
|
+
filename = os.path.join(dirname, "dftd3_data.pt")
|
|
436
|
+
param = torch.load(filename, map_location="cpu", weights_only=True)
|
|
437
|
+
|
|
438
|
+
c6ab_packed = param["c6ab"]
|
|
439
|
+
c6ab = c6ab_packed[..., 0].contiguous()
|
|
440
|
+
cn_ref = c6ab_packed[..., 1].contiguous()
|
|
441
|
+
|
|
442
|
+
# Register buffers for D3 parameters
|
|
443
|
+
self.register_buffer("rcov", param["rcov"].float())
|
|
444
|
+
self.register_buffer("r4r2", param["r4r2"].float())
|
|
445
|
+
self.register_buffer("c6ab", c6ab.float())
|
|
446
|
+
self.register_buffer("cn_ref", cn_ref.float())
|
|
447
|
+
|
|
448
|
+
def set_smoothing(self, cutoff: float, smoothing_fraction: float = 0.2) -> None:
|
|
449
|
+
"""Update smoothing parameters based on new cutoff and fraction.
|
|
450
|
+
|
|
451
|
+
Parameters
|
|
452
|
+
----------
|
|
453
|
+
cutoff : float
|
|
454
|
+
Cutoff distance in Angstroms.
|
|
455
|
+
smoothing_fraction : float
|
|
456
|
+
Fraction of cutoff used as smoothing window width.
|
|
457
|
+
Smoothing occurs from cutoff * (1 - smoothing_fraction) to cutoff.
|
|
458
|
+
Example: smoothing_fraction=0.2 means smoothing over last 20%
|
|
459
|
+
of cutoff distance (from 0.8*cutoff to cutoff). Default is 0.2.
|
|
460
|
+
"""
|
|
461
|
+
self.smoothing_on = cutoff * (1 - smoothing_fraction)
|
|
462
|
+
self.smoothing_off = cutoff
|
|
463
|
+
|
|
464
|
+
def _load_from_state_dict(
|
|
465
|
+
self,
|
|
466
|
+
state_dict: dict,
|
|
467
|
+
prefix: str,
|
|
468
|
+
local_metadata: dict,
|
|
469
|
+
strict: bool,
|
|
470
|
+
missing_keys: list,
|
|
471
|
+
unexpected_keys: list,
|
|
472
|
+
error_msgs: list,
|
|
473
|
+
) -> None:
|
|
474
|
+
"""Handle loading from old state dict format with packed c6ab.
|
|
475
|
+
|
|
476
|
+
Migrates from legacy format where c6ab had shape [95, 95, 5, 5, 3]
|
|
477
|
+
with last dimension containing (c6ref, cnref_i, cnref_j) to new format
|
|
478
|
+
where c6ab is [95, 95, 5, 5] and cn_ref is separate [95, 95, 5, 5].
|
|
479
|
+
Also removes deprecated cnmax parameter if present.
|
|
480
|
+
"""
|
|
481
|
+
c6ab_key = prefix + "c6ab"
|
|
482
|
+
cn_ref_key = prefix + "cn_ref"
|
|
483
|
+
cnmax_key = prefix + "cnmax"
|
|
484
|
+
|
|
485
|
+
if c6ab_key in state_dict and state_dict[c6ab_key].ndim == 5:
|
|
486
|
+
c6ab_packed = state_dict[c6ab_key]
|
|
487
|
+
state_dict[c6ab_key] = c6ab_packed[..., 0].contiguous()
|
|
488
|
+
state_dict[cn_ref_key] = c6ab_packed[..., 1].contiguous()
|
|
489
|
+
|
|
490
|
+
state_dict.pop(cnmax_key, None)
|
|
491
|
+
|
|
492
|
+
super()._load_from_state_dict(
|
|
493
|
+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
|
494
|
+
)
|
|
238
495
|
|
|
496
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
497
|
+
nb_mode = nbops.get_nb_mode(data)
|
|
498
|
+
coord = data["coord"]
|
|
499
|
+
numbers = data["numbers"].to(torch.int32)
|
|
500
|
+
cell = data.get("cell")
|
|
501
|
+
|
|
502
|
+
# Prepare inputs based on nb_mode
|
|
503
|
+
if nb_mode == 0:
|
|
504
|
+
# Batched mode without neighbor matrix - construct full neighbor matrix
|
|
505
|
+
# This only applies when no nbmat is provided in data
|
|
506
|
+
B, N = coord.shape[:2]
|
|
507
|
+
coord_flat = coord.flatten(0, 1) # (B*N, 3)
|
|
508
|
+
numbers_flat = numbers.flatten() # (B*N,)
|
|
509
|
+
batch_idx = torch.arange(B, device=coord.device, dtype=torch.int32).repeat_interleave(N)
|
|
510
|
+
num_systems = B
|
|
511
|
+
total_atoms = B * N
|
|
512
|
+
max_neighbors = N - 1
|
|
513
|
+
|
|
514
|
+
# Build dense all-to-all neighbor matrix for mode 0 (no pre-computed neighbor list)
|
|
515
|
+
# Creates (N, N-1) template where each atom connects to all others except itself
|
|
516
|
+
arange_n = torch.arange(N, device=coord.device, dtype=torch.int32)
|
|
517
|
+
all_indices = arange_n.unsqueeze(0).expand(N, -1)
|
|
518
|
+
mask = all_indices != arange_n.unsqueeze(1)
|
|
519
|
+
template = all_indices[mask].view(N, N - 1)
|
|
520
|
+
batch_offsets = torch.arange(B, device=coord.device, dtype=torch.int32).unsqueeze(1).unsqueeze(2) * N
|
|
521
|
+
neighbor_matrix = (template.unsqueeze(0) + batch_offsets).view(total_atoms, max_neighbors)
|
|
522
|
+
|
|
523
|
+
fill_value = total_atoms
|
|
524
|
+
neighbor_matrix_shifts: Tensor | None = None
|
|
525
|
+
cell_for_autograd: Tensor | None = None
|
|
526
|
+
|
|
527
|
+
elif nb_mode == 1:
|
|
528
|
+
# Flat mode with neighbor matrix
|
|
529
|
+
suffix = nbops.resolve_suffix(data, ["_dftd3", "_lr"])
|
|
530
|
+
N = coord.shape[0]
|
|
531
|
+
coord_flat = coord
|
|
532
|
+
numbers_flat = numbers
|
|
533
|
+
neighbor_matrix = data[f"nbmat{suffix}"].to(torch.int32)
|
|
534
|
+
|
|
535
|
+
mol_idx = data.get("mol_idx")
|
|
536
|
+
if mol_idx is not None:
|
|
537
|
+
batch_idx = mol_idx.to(torch.int32)
|
|
538
|
+
num_systems = int(mol_idx.max().item()) + 1
|
|
539
|
+
else:
|
|
540
|
+
batch_idx = torch.zeros(N, dtype=torch.int32, device=coord.device)
|
|
541
|
+
num_systems = 1
|
|
542
|
+
|
|
543
|
+
shifts = data.get(f"shifts{suffix}")
|
|
544
|
+
neighbor_matrix_shifts = shifts.to(torch.int32) if shifts is not None else None
|
|
545
|
+
|
|
546
|
+
fill_value = N
|
|
547
|
+
cell_for_autograd = cell
|
|
548
|
+
|
|
549
|
+
elif nb_mode == 2:
|
|
550
|
+
# Batched mode with neighbor matrix
|
|
551
|
+
suffix = nbops.resolve_suffix(data, ["_dftd3", "_lr"])
|
|
552
|
+
B, N = coord.shape[:2]
|
|
553
|
+
coord_flat = coord.flatten(0, 1)
|
|
554
|
+
numbers_flat = numbers.flatten()
|
|
555
|
+
batch_idx = torch.arange(B, device=coord.device, dtype=torch.int32).repeat_interleave(N)
|
|
556
|
+
num_systems = B
|
|
557
|
+
|
|
558
|
+
nbmat = data[f"nbmat{suffix}"]
|
|
559
|
+
offsets = torch.arange(B, device=coord.device).unsqueeze(1) * N
|
|
560
|
+
neighbor_matrix = (nbmat + offsets.unsqueeze(-1)).flatten(0, 1).to(torch.int32)
|
|
561
|
+
|
|
562
|
+
shifts = data.get(f"shifts{suffix}")
|
|
563
|
+
if shifts is not None:
|
|
564
|
+
neighbor_matrix_shifts = shifts.flatten(0, 1).to(torch.int32)
|
|
565
|
+
else:
|
|
566
|
+
neighbor_matrix_shifts = None
|
|
567
|
+
|
|
568
|
+
fill_value = B * N
|
|
569
|
+
cell_for_autograd = cell
|
|
570
|
+
|
|
571
|
+
else:
|
|
572
|
+
raise ValueError(f"Unsupported neighbor mode: {nb_mode}")
|
|
573
|
+
|
|
574
|
+
# Compute energy using autograd function
|
|
575
|
+
energy_ev = self._compute_energy_autograd(
|
|
576
|
+
coord_flat,
|
|
577
|
+
cell_for_autograd,
|
|
578
|
+
numbers_flat,
|
|
579
|
+
batch_idx,
|
|
580
|
+
neighbor_matrix,
|
|
581
|
+
neighbor_matrix_shifts,
|
|
582
|
+
num_systems,
|
|
583
|
+
fill_value,
|
|
584
|
+
)
|
|
585
|
+
|
|
586
|
+
# Add dispersion energy to output
|
|
239
587
|
if self.key_out in data:
|
|
240
|
-
data[self.key_out] = data[self.key_out] +
|
|
588
|
+
data[self.key_out] = data[self.key_out] + energy_ev.to(data[self.key_out].dtype)
|
|
241
589
|
else:
|
|
242
|
-
data[self.key_out] =
|
|
590
|
+
data[self.key_out] = energy_ev
|
|
591
|
+
|
|
592
|
+
# Optionally compute and add forces to data dict
|
|
593
|
+
# Compute forces via autograd (will use saved forces from DFTD3Function)
|
|
594
|
+
if self.compute_forces and not torch.jit.is_scripting() and coord_flat.requires_grad:
|
|
595
|
+
# Forces are -grad of energy
|
|
596
|
+
forces_flat = torch.autograd.grad(
|
|
597
|
+
energy_ev.sum(),
|
|
598
|
+
coord_flat,
|
|
599
|
+
create_graph=self.training,
|
|
600
|
+
retain_graph=True,
|
|
601
|
+
)[0]
|
|
602
|
+
forces = -forces_flat
|
|
603
|
+
|
|
604
|
+
# Reshape if needed
|
|
605
|
+
if nb_mode == 0 or nb_mode == 2:
|
|
606
|
+
B, N = coord.shape[:2]
|
|
607
|
+
forces = forces.view(B, N, 3)
|
|
608
|
+
|
|
609
|
+
if "forces" in data:
|
|
610
|
+
data["forces"] = data["forces"] + forces
|
|
611
|
+
else:
|
|
612
|
+
data["forces"] = forces
|
|
613
|
+
|
|
243
614
|
return data
|
|
615
|
+
|
|
616
|
+
def _compute_energy_autograd(
|
|
617
|
+
self,
|
|
618
|
+
coord: Tensor,
|
|
619
|
+
cell: Tensor | None,
|
|
620
|
+
numbers: Tensor,
|
|
621
|
+
batch_idx: Tensor,
|
|
622
|
+
neighbor_matrix: Tensor,
|
|
623
|
+
neighbor_matrix_shifts: Tensor | None,
|
|
624
|
+
num_systems: int,
|
|
625
|
+
fill_value: int,
|
|
626
|
+
) -> Tensor:
|
|
627
|
+
"""Compute DFT-D3 energy using custom op for differentiability and TorchScript."""
|
|
628
|
+
return dftd3_energy(
|
|
629
|
+
coord=coord,
|
|
630
|
+
cell=cell,
|
|
631
|
+
numbers=numbers,
|
|
632
|
+
batch_idx=batch_idx,
|
|
633
|
+
neighbor_matrix=neighbor_matrix,
|
|
634
|
+
shifts=neighbor_matrix_shifts,
|
|
635
|
+
rcov=self.rcov,
|
|
636
|
+
r4r2=self.r4r2,
|
|
637
|
+
c6ab=self.c6ab,
|
|
638
|
+
cn_ref=self.cn_ref,
|
|
639
|
+
a1=self.a1,
|
|
640
|
+
a2=self.a2,
|
|
641
|
+
s6=self.s6,
|
|
642
|
+
s8=self.s8,
|
|
643
|
+
num_systems=num_systems,
|
|
644
|
+
fill_value=fill_value,
|
|
645
|
+
smoothing_on=self.smoothing_on,
|
|
646
|
+
smoothing_off=self.smoothing_off,
|
|
647
|
+
compute_virial=self.compute_virial,
|
|
648
|
+
)
|