aimnet 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +0 -0
- aimnet/base.py +41 -0
- aimnet/calculators/__init__.py +15 -0
- aimnet/calculators/aimnet2ase.py +98 -0
- aimnet/calculators/aimnet2pysis.py +76 -0
- aimnet/calculators/calculator.py +320 -0
- aimnet/calculators/model_registry.py +60 -0
- aimnet/calculators/model_registry.yaml +33 -0
- aimnet/calculators/nb_kernel_cpu.py +222 -0
- aimnet/calculators/nb_kernel_cuda.py +217 -0
- aimnet/calculators/nbmat.py +220 -0
- aimnet/cli.py +22 -0
- aimnet/config.py +170 -0
- aimnet/constants.py +467 -0
- aimnet/data/__init__.py +1 -0
- aimnet/data/sgdataset.py +517 -0
- aimnet/dftd3_data.pt +0 -0
- aimnet/models/__init__.py +2 -0
- aimnet/models/aimnet2.py +188 -0
- aimnet/models/aimnet2.yaml +44 -0
- aimnet/models/aimnet2_dftd3_wb97m.yaml +51 -0
- aimnet/models/base.py +51 -0
- aimnet/modules/__init__.py +3 -0
- aimnet/modules/aev.py +201 -0
- aimnet/modules/core.py +237 -0
- aimnet/modules/lr.py +243 -0
- aimnet/nbops.py +151 -0
- aimnet/ops.py +208 -0
- aimnet/train/__init__.py +0 -0
- aimnet/train/calc_sae.py +43 -0
- aimnet/train/default_train.yaml +166 -0
- aimnet/train/loss.py +83 -0
- aimnet/train/metrics.py +188 -0
- aimnet/train/pt2jpt.py +81 -0
- aimnet/train/train.py +155 -0
- aimnet/train/utils.py +398 -0
- aimnet-0.0.1.dist-info/LICENSE +21 -0
- aimnet-0.0.1.dist-info/METADATA +78 -0
- aimnet-0.0.1.dist-info/RECORD +41 -0
- aimnet-0.0.1.dist-info/WHEEL +4 -0
- aimnet-0.0.1.dist-info/entry_points.txt +5 -0
aimnet/modules/core.py
ADDED
@@ -0,0 +1,237 @@
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from torch import Tensor, nn
+
+from aimnet import constants, nbops, ops
+from aimnet.config import get_init_module, get_module
+
+
+def MLP(
+    n_in: int,
+    n_out: int,
+    hidden: Optional[List[int]] = None,
+    activation_fn: Callable | str = "torch.nn.GELU",
+    activation_kwargs: Optional[Dict[str, Any]] = None,
+    weight_init_fn: Callable | str = "torch.nn.init.xavier_normal_",
+    bias: bool = True,
+    last_linear: bool = True,
+):
+    """Convenience function to build MLP from config"""
+    if hidden is None:
+        hidden = []
+    if activation_kwargs is None:
+        activation_kwargs = {}
+    # hp search hack
+    hidden = [x for x in hidden if x > 0]
+    if isinstance(activation_fn, str):
+        activation_fn = get_init_module(activation_fn, kwargs=activation_kwargs)
+    if isinstance(weight_init_fn, str):
+        weight_init_fn = get_module(weight_init_fn)
+    sizes = [n_in, *hidden, n_out]
+    layers = []
+    for i in range(1, len(sizes)):
+        n_in, n_out = sizes[i - 1], sizes[i]
+        layer = nn.Linear(n_in, n_out, bias=bias)
+        with torch.no_grad():
+            weight_init_fn(layer.weight)
+            if bias:
+                nn.init.zeros_(layer.bias)
+        layers.append(layer)
+        if not (last_linear and i == len(sizes) - 1):
+            layers.append(activation_fn)
+    return nn.Sequential(*layers)
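
For orientation, a minimal usage sketch of the MLP factory above (illustrative, not part of the package source). With the defaults, every hidden layer is followed by GELU and the final Linear is left bare because last_linear=True:

import torch
from aimnet.modules.core import MLP

net = MLP(n_in=16, n_out=1, hidden=[64, 32])
# -> Sequential(Linear(16, 64), GELU(), Linear(64, 32), GELU(), Linear(32, 1))
out = net(torch.randn(8, 16))  # shape (8, 1)
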
+
+
+class Embedding(nn.Embedding):
+    def __init__(self, init: Optional[Dict[int, Any]] = None, **kwargs):
+        super().__init__(**kwargs)
+        with torch.no_grad():
+            if init is not None:
+                for i in range(self.weight.shape[0]):
+                    if self.padding_idx is not None and i == self.padding_idx:
+                        continue
+                    if i in init:
+                        self.weight[i] = init[i]
+                    else:
+                        self.weight[i].fill_(float("nan"))
+                for k, v in init.items():
+                    self.weight[k] = v
+
+    def reset_parameters(self) -> None:
+        nn.init.orthogonal_(self.weight)
+        if self.padding_idx is not None:
+            with torch.no_grad():
+                self.weight[self.padding_idx].fill_(0)
+
+
+class DSequential(nn.Module):
+    def __init__(self, *modules):
+        super().__init__()
+        self.module = nn.ModuleList(modules)
+
+    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        for m in self.module:
+            data = m(data)
+        return data
+
+
+class AtomicShift(nn.Module):
+    def __init__(
+        self,
+        key_in: str,
+        key_out: str,
+        num_types: int = 64,
+        dtype: torch.dtype = torch.float,
+        requires_grad: bool = True,
+        reduce_sum=False,
+    ):
+        super().__init__()
+        shifts = nn.Embedding(num_types, 1, padding_idx=0, dtype=dtype)
+        shifts.weight.requires_grad_(requires_grad)
+        self.shifts = shifts
+        self.key_in = key_in
+        self.key_out = key_out
+        self.reduce_sum = reduce_sum
+
+    def extra_repr(self) -> str:
+        return f"key_in: {self.key_in}, key_out: {self.key_out}"
+
+    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        shifts = self.shifts(data["numbers"]).squeeze(-1)
+        if self.reduce_sum:
+            shifts = nbops.mol_sum(shifts, data)
+        data[self.key_out] = data[self.key_in] + shifts
+        return data
+
+
+class AtomicSum(nn.Module):
+    def __init__(self, key_in: str, key_out: str):
+        super().__init__()
+        self.key_in = key_in
+        self.key_out = key_out
+
+    def extra_repr(self) -> str:
+        return f"key_in: {self.key_in}, key_out: {self.key_out}"
+
+    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        data[self.key_out] = nbops.mol_sum(data[self.key_in], data)
+        return data
+
+
+class Output(nn.Module):
+    def __init__(self, mlp: Dict | nn.Module, n_in: int, n_out: int, key_in: str, key_out: str):
+        super().__init__()
+        self.key_in = key_in
+        self.key_out = key_out
+        if not isinstance(mlp, nn.Module):
+            mlp = MLP(n_in=n_in, n_out=n_out, **mlp)
+        self.mlp = mlp
+
+    def extra_repr(self) -> str:
+        return f"key_in: {self.key_in}, key_out: {self.key_out}"
+
+    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        v = self.mlp(data[self.key_in]).squeeze(-1)
+        if data["_input_padded"].item():
+            v = nbops.mask_i_(v, data, mask_value=0.0)
+        data[self.key_out] = v
+        return data
+
+
+class Forces(nn.Module):
+    def __init__(self, module: nn.Module, x: str = "coord", y: str = "energy", key_out: str = "forces"):
+        super().__init__()
+        self.module = module
+        self.x = x
+        self.y = y
+        self.key_out = key_out
+
+    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        prev = torch.is_grad_enabled()
+        torch.set_grad_enabled(True)
+        data[self.x].requires_grad_(True)
+        data = self.module(data)
+        y = data[self.y]
+        g = torch.autograd.grad([y.sum()], [data[self.x]], create_graph=self.training)[0]
+        assert g is not None
+        data[self.key_out] = -g
+        torch.set_grad_enabled(prev)
+        return data
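
Forces implements the standard autograd pattern: run the wrapped module, differentiate the scalar energy with respect to coordinates, and negate the gradient. A self-contained sketch of the same pattern with a toy quadratic energy (hypothetical stand-in module, not from the package):

import torch
from torch import nn

class ToyEnergy(nn.Module):
    # Stand-in for a real model: E = sum of squared coordinates.
    def forward(self, data):
        data["energy"] = data["coord"].pow(2).sum()
        return data

data = {"coord": torch.randn(5, 3, requires_grad=True)}
data = ToyEnergy()(data)
forces = -torch.autograd.grad(data["energy"], data["coord"])[0]
# For this toy energy, forces == -2 * data["coord"].
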
+
+
+class Dipole(nn.Module):
+    def __init__(self, key_in: str = "charges", key_out: str = "dipole", center_coord: bool = False):
+        super().__init__()
+        self.center_coord = center_coord
+        self.key_out = key_out
+        self.key_in = key_in
+        self.register_buffer("mass", constants.get_masses())
+
+    def extra_repr(self) -> str:
+        return f"key_in: {self.key_in}, key_out: {self.key_out}, center_coord: {self.center_coord}"
+
+    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        q = data[self.key_in]
+        r = data["coord"]
+        if self.center_coord:
+            r = ops.center_coordinates(r, data, self.mass[data["numbers"]])
+        data[self.key_out] = nbops.mol_sum(q.unsqueeze(-1) * r, data)
+        return data
+
+
+class Quadrupole(Dipole):
+    def __init__(self, key_in: str = "charges", key_out: str = "quadrupole", center_coord: bool = False):
+        super().__init__(key_in=key_in, key_out=key_out, center_coord=center_coord)
+
+    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        q = data[self.key_in]
+        r = data["coord"]
+        if self.center_coord:
+            r = ops.center_coordinates(r, data, self.mass[data["numbers"]])
+        _x = torch.cat([r.pow(2), r * r.roll(-1, -1)], dim=-1)
+        quad = nbops.mol_sum(q.unsqueeze(-1) * _x, data)
+        _x1, _x2 = quad.split(3, dim=-1)
+        _x1 = _x1 - _x1.mean(dim=-1, keepdim=True)
+        quad = torch.cat([_x1, _x2], dim=-1)
+        data[self.key_out] = quad
+        return data
+
+
+class SRRep(nn.Module):
+    """GFN1-style short-range repulsion function"""
+
+    def __init__(self, key_out="e_rep", cutoff_fn="none", rc=5.2, reduce_sum=True):
+        super().__init__()
+        from aimnet.constants import get_gfn1_rep
+
+        self.key_out = key_out
+        self.cutoff_fn = cutoff_fn
+        self.reduce_sum = reduce_sum
+
+        self.register_buffer("rc", torch.tensor(rc))
+        gfn1_repa, gfn1_repb = get_gfn1_rep()
+        weight = torch.stack([gfn1_repa, gfn1_repb], dim=-1)
+        self.params = nn.Embedding(87, 2, padding_idx=0, _weight=weight)
+        self.params.weight.requires_grad_(False)
+
+    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        p = self.params(data["numbers"])
+        p_i, p_j = nbops.get_ij(p, data)
+        p_ij = p_i * p_j
+        alpha_ij, zeff_ij = p_ij.unbind(-1)
+        d_ij = data["d_ij"]
+        e = torch.exp(-alpha_ij * d_ij.pow(1.5)) * zeff_ij / d_ij
+        e = nbops.mask_ij_(e, data, 0.0)
+        if self.cutoff_fn == "exp_cutoff":
+            e = e * ops.exp_cutoff(d_ij, self.rc)
+        elif self.cutoff_fn == "cosine_cutoff":
+            e = e * ops.cosine_cutoff(d_ij, self.rc)
+        e = e.sum(-1)
+        if self.reduce_sum:
+            e = nbops.mol_sum(e, data)
+        if self.key_out in data:
+            data[self.key_out] = data[self.key_out] + e
+        else:
+            data[self.key_out] = e
+        return data
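
Restating the SRRep.forward pass above as a formula: with per-element parameters $\alpha$ and $Z^{\mathrm{eff}}$ from the GFN1 tables, each neighbor pair contributes

$$e_{ij} = \frac{Z^{\mathrm{eff}}_i Z^{\mathrm{eff}}_j}{d_{ij}} \exp\!\left(-\alpha_i \alpha_j\, d_{ij}^{3/2}\right),$$

optionally scaled by the chosen cutoff function before the per-atom and per-molecule sums.
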
aimnet/modules/lr.py
ADDED
@@ -0,0 +1,243 @@
+from typing import Dict, Optional
+
+import torch
+from torch import Tensor, nn
+
+from aimnet import constants, nbops, ops
+
+
+class LRCoulomb(nn.Module):
+    def __init__(
+        self,
+        key_in: str = "charges",
+        key_out: str = "e_h",
+        rc: float = 4.6,
+        method: str = "simple",
+        dsf_alpha: float = 0.2,
+        dsf_rc: float = 15.0,
+    ):
+        super().__init__()
+        self.key_in = key_in
+        self.key_out = key_out
+        self._factor = constants.half_Hartree * constants.Bohr
+        self.register_buffer("rc", torch.tensor(rc))
+        self.dsf_alpha = dsf_alpha
+        self.dsf_rc = dsf_rc
+        if method in ("simple", "dsf", "ewald"):
+            self.method = method
+        else:
+            raise ValueError(f"Unknown method {method}")
+
+    def coul_simple(self, data: Dict[str, Tensor]) -> Tensor:
+        data = ops.lazy_calc_dij_lr(data)
+        d_ij = data["d_ij_lr"]
+        q = data[self.key_in]
+        q_i, q_j = nbops.get_ij(q, data, suffix="_lr")
+        q_ij = q_i * q_j
+        fc = 1.0 - ops.exp_cutoff(d_ij, self.rc)
+        e_ij = fc * q_ij / d_ij
+        e_ij = nbops.mask_ij_(e_ij, data, 0.0, suffix="_lr")
+        e_i = e_ij.sum(-1)
+        e = self._factor * nbops.mol_sum(e_i, data)
+        return e
+
+    def coul_simple_sr(self, data: Dict[str, Tensor]) -> Tensor:
+        d_ij = data["d_ij"]
+        q = data[self.key_in]
+        q_i, q_j = nbops.get_ij(q, data)
+        q_ij = q_i * q_j
+        fc = ops.exp_cutoff(d_ij, self.rc)
+        e_ij = fc * q_ij / d_ij
+        e_ij = nbops.mask_ij_(e_ij, data, 0.0)
+        e_i = e_ij.sum(-1)
+        e = self._factor * nbops.mol_sum(e_i, data)
+        return e
+
+    def coul_dsf(self, data: Dict[str, Tensor]) -> Tensor:
+        data = ops.lazy_calc_dij_lr(data)
+        d_ij = data["d_ij_lr"]
+        q = data[self.key_in]
+        q_i, q_j = nbops.get_ij(q, data, suffix="_lr")
+        J = ops.coulomb_matrix_dsf(d_ij, self.dsf_rc, self.dsf_alpha, data)
+        e = (q_i * q_j * J).sum(-1)
+        e = self._factor * nbops.mol_sum(e, data)
+        e = e - self.coul_simple_sr(data)
+        return e
+
+    def coul_ewald(self, data: Dict[str, Tensor]) -> Tensor:
+        J = ops.coulomb_matrix_ewald(data["coord"], data["cell"])
+        q_i, q_j = data["charges"].unsqueeze(-1), data["charges"].unsqueeze(-2)
+        e = self._factor * (q_i * q_j * J).flatten(-2, -1).sum(-1)
+        e = e - self.coul_simple_sr(data)
+        return e
+
+    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        if self.method == "simple":
+            e = self.coul_simple(data)
+        elif self.method == "dsf":
+            e = self.coul_dsf(data)
+        elif self.method == "ewald":
+            e = self.coul_ewald(data)
+        else:
+            raise ValueError(f"Unknown method {self.method}")
+        if self.key_out in data:
+            data[self.key_out] = data[self.key_out] + e
+        else:
+            data[self.key_out] = e
+        return data
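
A construction sketch for LRCoulomb (illustrative, not from the package): the module reads charges from data[key_in] and accumulates the electrostatic energy into key_out. Note that coul_simple damps the short range with 1 - exp_cutoff, while the dsf and ewald variants instead subtract coul_simple_sr, so the short-range window already handled elsewhere in the model is not double counted.

from aimnet.modules.lr import LRCoulomb

# Damped-shifted-force electrostatics, accumulated into data["e_h"].
coul = LRCoulomb(key_in="charges", key_out="e_h", method="dsf", dsf_alpha=0.2, dsf_rc=15.0)
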
+
+
+class DispParam(nn.Module):
+    def __init__(
+        self,
+        ref_c6: Optional[Dict[int, Tensor] | Tensor] = None,
+        ref_alpha: Optional[Dict[int, Tensor] | Tensor] = None,
+        ptfile: Optional[str] = None,
+        key_in: str = "disp_param",
+        key_out: str = "disp_param",
+    ):
+        super().__init__()
+        if (
+            ptfile is None
+            and (ref_c6 is None or ref_alpha is None)
+            or ptfile is not None
+            and (ref_c6 is not None or ref_alpha is not None)
+        ):
+            raise ValueError("Either ptfile or ref_c6 and ref_alpha should be supplied.")
+        # load data
+        ref = torch.load(ptfile) if ptfile is not None else torch.zeros(87, 2)
+        for i, p in enumerate([ref_c6, ref_alpha]):
+            if p is not None:
+                if isinstance(p, Tensor):
+                    ref[: p.shape[0], i] = p
+                else:
+                    for k, v in p.items():
+                        ref[k, i] = v
+        # c6=0 and alpha=1 for dummy atom
+        ref[0, 0] = 0.0
+        ref[0, 1] = 1.0
+        self.register_buffer("disp_param0", ref)
+        self.key_in = key_in
+        self.key_out = key_out
+
+    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        disp_param_mult = data[self.key_in].clamp(min=-4, max=4).exp()
+        disp_param = self.disp_param0[data["numbers"]]
+        vals = disp_param * disp_param_mult
+        data[self.key_out] = vals
+        return data
+
+
+class D3TS(nn.Module):
+    """DFT-D3-like pairwise dispersion with TS combination rule"""
+
+    def __init__(self, a1: float, a2: float, s8: float, s6: float = 1.0, key_in="disp_param", key_out="energy"):
+        super().__init__()
+        self.register_buffer("r4r2", constants.get_r4r2())
+        self.a1 = a1
+        self.a2 = a2
+        self.s6 = s6
+        self.s8 = s8
+        self.key_in = key_in
+        self.key_out = key_out
+
+    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        disp_param = data[self.key_in]
+        disp_param_i, disp_param_j = nbops.get_ij(disp_param, data, suffix="_lr")
+        c6_i, alpha_i = disp_param_i.unbind(dim=-1)
+        c6_j, alpha_j = disp_param_j.unbind(dim=-1)
+
+        # TS combination rule
+        c6ij = 2 * c6_i * c6_j / (c6_i * alpha_j / alpha_i + c6_j * alpha_i / alpha_j).clamp(min=1e-4)
+
+        rr = self.r4r2[data["numbers"]]
+        rr_i, rr_j = nbops.get_ij(rr, data, suffix="_lr")
+        rrij = 3 * rr_i * rr_j
+        rrij = nbops.mask_ij_(rrij, data, 1.0, suffix="_lr")
+        r0ij = self.a1 * rrij.sqrt() + self.a2
+
+        ops.lazy_calc_dij_lr(data)
+        d_ij = data["d_ij_lr"] * constants.Bohr_inv
+        e_ij = c6ij * (self.s6 / (d_ij.pow(6) + r0ij.pow(6)) + self.s8 * rrij / (d_ij.pow(8) + r0ij.pow(8)))
+        e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data)
+
+        if self.key_out in data:
+            data[self.key_out] = data[self.key_out] + e
+        else:
+            data[self.key_out] = e
+
+        return data
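
The c6ij line above is the Tkatchenko-Scheffler combination rule; written out:

$$C_{6,ij} = \frac{2\, C_{6,i}\, C_{6,j}}{\frac{\alpha_j}{\alpha_i} C_{6,i} + \frac{\alpha_i}{\alpha_j} C_{6,j}},$$

with the clamp(min=1e-4) guarding against a vanishing denominator.
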
+
+
+class DFTD3(nn.Module):
+    """DFT-D3 implementation.
+    BJ damping, C6 and C8 terms, without the 3-body term.
+    """
+
+    def __init__(self, s8: float, a1: float, a2: float, s6: float = 1.0, key_out="energy"):
+        super().__init__()
+        self.key_out = key_out
+        # BJ damping parameters
+        self.s6 = s6
+        self.s8 = s8
+        self.s9 = 4.0 / 3.0
+        self.a1 = a1
+        self.a2 = a2
+        self.a3 = 16.0
+        # CN parameters
+        self.k1 = -16.0
+        self.k3 = -4.0
+        # data
+        self.register_buffer("c6ab", torch.zeros(95, 95, 5, 5, 3))
+        self.register_buffer("r4r2", torch.zeros(95))
+        self.register_buffer("rcov", torch.zeros(95))
+        self.register_buffer("cnmax", torch.zeros(95))
+        sd = constants.get_dftd3_param()
+        self.load_state_dict(sd)
+
+    def _calc_c6ij(self, data: Dict[str, Tensor]) -> Tensor:
+        # CN part
+        # short range for CN
+        # d_ij = data["d_ij"] * constants.Bohr_inv
+        data = ops.lazy_calc_dij_lr(data)
+        d_ij = data["d_ij_lr"] * constants.Bohr_inv
+
+        numbers = data["numbers"]
+        numbers_i, numbers_j = nbops.get_ij(numbers, data, suffix="_lr")
+        rcov_i, rcov_j = nbops.get_ij(self.rcov[numbers], data, suffix="_lr")
+        rcov_ij = rcov_i + rcov_j
+        cn_ij = 1.0 / (1.0 + torch.exp(self.k1 * (rcov_ij / d_ij - 1.0)))
+        cn_ij = nbops.mask_ij_(cn_ij, data, 0.0, suffix="_lr")
+        cn = cn_ij.sum(-1)
+        cn = torch.clamp(cn, max=self.cnmax[numbers]).unsqueeze(-1).unsqueeze(-1)
+        cn_i, cn_j = nbops.get_ij(cn, data, suffix="_lr")
+        c6ab = self.c6ab[numbers_i, numbers_j]
+        c6ref, cnref_i, cnref_j = torch.unbind(c6ab, dim=-1)
+        c6ref = nbops.mask_ij_(c6ref, data, 0.0, suffix="_lr")
+        l_ij = torch.exp(self.k3 * ((cn_i - cnref_i).pow(2) + (cn_j - cnref_j).pow(2)))
+        w = l_ij.flatten(-2, -1).sum(-1)
+        z = torch.einsum("...ij,...ij->...", c6ref, l_ij)
+        _w = w < 1e-5
+        z[_w] = 0.0
+        c6_ij = z / w.clamp(min=1e-5)
+        return c6_ij
+
+    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        c6ij = self._calc_c6ij(data)
+
+        rr = self.r4r2[data["numbers"]]
+        rr_i, rr_j = nbops.get_ij(rr, data, suffix="_lr")
+        rrij = 3 * rr_i * rr_j
+        rrij = nbops.mask_ij_(rrij, data, 1.0, suffix="_lr")
+        r0ij = self.a1 * rrij.sqrt() + self.a2
+
+        ops.lazy_calc_dij_lr(data)
+        d_ij = data["d_ij_lr"] * constants.Bohr_inv
+        e_ij = c6ij * (self.s6 / (d_ij.pow(6) + r0ij.pow(6)) + self.s8 * rrij / (d_ij.pow(8) + r0ij.pow(8)))
+        e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data)
+
+        if self.key_out in data:
+            data[self.key_out] = data[self.key_out] + e
+        else:
+            data[self.key_out] = e
+        return data
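
D3TS and DFTD3 share the same Becke-Johnson-damped two-term energy (restating the code, distances $d_{ij}$ in Bohr):

$$E = -\tfrac{1}{2} \sum_{i \ne j} C_{6,ij} \left( \frac{s_6}{d_{ij}^{6} + R_{0,ij}^{6}} + \frac{s_8\, \bar{r}_{ij}}{d_{ij}^{8} + R_{0,ij}^{8}} \right), \qquad \bar{r}_{ij} = 3\, r_i r_j, \quad R_{0,ij} = a_1 \sqrt{\bar{r}_{ij}} + a_2,$$

where $r_i$ are the tabulated $\sqrt{\langle r^4 \rangle / \langle r^2 \rangle}$ values and the prefactor constants.half_Hartree supplies both the double-counting factor of one half and the Hartree-to-output-unit conversion. The two modules differ only in where $C_{6,ij}$ comes from: learned per-atom disp_param combined with the TS rule, versus tabulated, coordination-number-interpolated D3 reference values.
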
aimnet/nbops.py
ADDED
@@ -0,0 +1,151 @@
+from typing import Dict, Tuple
+
+import torch
+from torch import Tensor
+
+
+def set_nb_mode(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    """Logic to guess and set the neighbor mode."""
+    if "nbmat" in data:
+        if data["nbmat"].ndim == 2:
+            data["_nb_mode"] = torch.tensor(1)
+        elif data["nbmat"].ndim == 3:
+            data["_nb_mode"] = torch.tensor(2)
+        else:
+            raise ValueError(f"Invalid neighbor matrix shape: {data['nbmat'].shape}")
+    else:
+        data["_nb_mode"] = torch.tensor(0)
+    return data
+
+
+def get_nb_mode(data: Dict[str, Tensor]) -> int:
+    """Get the neighbor mode."""
+    return int(data["_nb_mode"].item())
+
+
+def calc_masks(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    """Calculate neighbor masks"""
+    nb_mode = get_nb_mode(data)
+    if nb_mode == 0:
+        data["mask_i"] = data["numbers"] == 0
+        data["mask_ij"] = torch.eye(
+            data["numbers"].shape[1], device=data["numbers"].device, dtype=torch.bool
+        ).unsqueeze(0)
+        if data["mask_i"].any():
+            data["_input_padded"] = torch.tensor(True)
+            data["_natom"] = data["mask_i"].logical_not().sum(-1)
+            data["mol_sizes"] = (~data["mask_i"]).sum(-1)
+            data["mask_ij"] = data["mask_ij"] | (data["mask_i"].unsqueeze(-2) + data["mask_i"].unsqueeze(-1))
+        else:
+            data["_input_padded"] = torch.tensor(False)
+            data["_natom"] = torch.tensor(data["numbers"].shape[1], device=data["numbers"].device)
+            data["mol_sizes"] = torch.tensor(data["numbers"].shape[1], device=data["numbers"].device)
+        data["mask_ij_lr"] = data["mask_ij"]
+    elif nb_mode == 1:
+        # padding must be the last atom
+        data["mask_i"] = torch.zeros(data["numbers"].shape[0], device=data["numbers"].device, dtype=torch.bool)
+        data["mask_i"][-1] = True
+        for suffix in ("", "_lr"):
+            if f"nbmat{suffix}" in data:
+                data[f"mask_ij{suffix}"] = data[f"nbmat{suffix}"] == data["numbers"].shape[0] - 1
+        data["_input_padded"] = torch.tensor(True)
+        data["mol_sizes"] = torch.bincount(data["mol_idx"])
+        # last atom is padding
+        data["mol_sizes"][-1] -= 1
+    elif nb_mode == 2:
+        data["mask_i"] = data["numbers"] == 0
+        w = torch.where(data["mask_i"])
+        pad_idx = w[0] * data["numbers"].shape[1] + w[1]
+        for suffix in ("", "_lr"):
+            if f"nbmat{suffix}" in data:
+                data[f"mask_ij{suffix}"] = torch.isin(data[f"nbmat{suffix}"], pad_idx)
+        data["_input_padded"] = torch.tensor(True)
+        data["mol_sizes"] = (~data["mask_i"]).sum(-1)
+    else:
+        raise ValueError(f"Invalid neighbor mode: {nb_mode}")
+
+    return data
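
For reference, the three neighbor modes this module dispatches on, as set by set_nb_mode above: mode 0 is dense batched input, with numbers of shape (B, N) zero-padded and no neighbor matrix; mode 1 is a single flattened atom list with a 2D nbmat, a per-atom mol_idx, and one trailing padding atom; mode 2 is batched with a 3D nbmat and zero-padded numbers. A minimal mode-0 input sketch (illustrative values):

import torch

# Two molecules padded to three atoms each; zeros in "numbers" mark padding.
data = {
    "numbers": torch.tensor([[8, 1, 1], [6, 1, 0]]),
    "coord": torch.randn(2, 3, 3),
}
# No "nbmat" key, so set_nb_mode() assigns mode 0 (all pairs within a molecule).
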
+
+
+def mask_ij_(
+    x: Tensor,
+    data: Dict[str, Tensor],
+    mask_value: float = 0.0,
+    inplace: bool = True,
+    suffix: str = "",
+) -> Tensor:
+    mask = data[f"mask_ij{suffix}"]
+    for _i in range(x.ndim - mask.ndim):
+        mask = mask.unsqueeze(-1)
+    if inplace:
+        x.masked_fill_(mask, mask_value)
+    else:
+        x = x.masked_fill(mask, mask_value)
+    return x
+
+
+def mask_i_(x: Tensor, data: Dict[str, Tensor], mask_value: float = 0.0, inplace: bool = True) -> Tensor:
+    nb_mode = get_nb_mode(data)
+    if nb_mode == 0:
+        if data["_input_padded"].item():
+            mask = data["mask_i"]
+            for _i in range(x.ndim - mask.ndim):
+                mask = mask.unsqueeze(-1)
+            if inplace:
+                x.masked_fill_(mask, mask_value)
+            else:
+                x = x.masked_fill(mask, mask_value)
+    elif nb_mode == 1:
+        if inplace:
+            x[-1] = mask_value
+        else:
+            x = torch.cat([x[:-1], torch.zeros_like(x[:1])], dim=0)
+    elif nb_mode == 2:
+        if inplace:
+            x[:, -1] = mask_value
+        else:
+            x = torch.cat([x[:, :-1], torch.zeros_like(x[:, :1])], dim=1)
+    else:
+        raise ValueError(f"Invalid neighbor mode: {nb_mode}")
+    return x
+
+
+def get_ij(x: Tensor, data: Dict[str, Tensor], suffix: str = "") -> Tuple[Tensor, Tensor]:
+    nb_mode = get_nb_mode(data)
+    if nb_mode == 0:
+        x_i = x.unsqueeze(2)
+        x_j = x.unsqueeze(1)
+    elif nb_mode == 1:
+        x_i = x.unsqueeze(1)
+        idx = data[f"nbmat{suffix}"]
+        x_j = torch.index_select(x, 0, idx.flatten()).unflatten(0, idx.shape)
+    elif nb_mode == 2:
+        x_i = x.unsqueeze(2)
+        idx = data[f"nbmat{suffix}"]
+        x_j = torch.index_select(x.flatten(0, 1), 0, idx.flatten()).unflatten(0, idx.shape)
+    else:
+        raise ValueError(f"Invalid neighbor mode: {nb_mode}")
+    return x_i, x_j
+
+
+def mol_sum(x: Tensor, data: Dict[str, Tensor]) -> Tensor:
+    nb_mode = get_nb_mode(data)
+    if nb_mode in (0, 2):
+        res = x.sum(dim=1)
+    elif nb_mode == 1:
+        assert x.ndim in (
+            1,
+            2,
+        ), "Invalid tensor shape for mol_sum, ndim should be 1 or 2"
+        idx = data["mol_idx"]
+        # assuming mol_idx is sorted, replace with max if not
+        out_size = int(idx[-1].item()) + 1
+        if x.ndim == 1:
+            res = torch.zeros(out_size, device=x.device, dtype=x.dtype)
+        else:
+            idx = idx.unsqueeze(-1).expand(-1, x.shape[1])
+            res = torch.zeros(out_size, x.shape[1], device=x.device, dtype=x.dtype)
+        res.scatter_add_(0, idx, x)
+    else:
+        raise ValueError(f"Invalid neighbor mode: {nb_mode}")
+    return res
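
A standalone sketch of the mode-1 branch of mol_sum (plain torch, illustrative values):

import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])  # per-atom values, molecules flattened
mol_idx = torch.tensor([0, 0, 0, 1, 1])      # sorted molecule index per atom
res = torch.zeros(int(mol_idx[-1]) + 1)
res.scatter_add_(0, mol_idx, x)              # -> tensor([6., 9.])
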