aimnet-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +0 -0
- aimnet/base.py +41 -0
- aimnet/calculators/__init__.py +15 -0
- aimnet/calculators/aimnet2ase.py +98 -0
- aimnet/calculators/aimnet2pysis.py +76 -0
- aimnet/calculators/calculator.py +320 -0
- aimnet/calculators/model_registry.py +60 -0
- aimnet/calculators/model_registry.yaml +33 -0
- aimnet/calculators/nb_kernel_cpu.py +222 -0
- aimnet/calculators/nb_kernel_cuda.py +217 -0
- aimnet/calculators/nbmat.py +220 -0
- aimnet/cli.py +22 -0
- aimnet/config.py +170 -0
- aimnet/constants.py +467 -0
- aimnet/data/__init__.py +1 -0
- aimnet/data/sgdataset.py +517 -0
- aimnet/dftd3_data.pt +0 -0
- aimnet/models/__init__.py +2 -0
- aimnet/models/aimnet2.py +188 -0
- aimnet/models/aimnet2.yaml +44 -0
- aimnet/models/aimnet2_dftd3_wb97m.yaml +51 -0
- aimnet/models/base.py +51 -0
- aimnet/modules/__init__.py +3 -0
- aimnet/modules/aev.py +201 -0
- aimnet/modules/core.py +237 -0
- aimnet/modules/lr.py +243 -0
- aimnet/nbops.py +151 -0
- aimnet/ops.py +208 -0
- aimnet/train/__init__.py +0 -0
- aimnet/train/calc_sae.py +43 -0
- aimnet/train/default_train.yaml +166 -0
- aimnet/train/loss.py +83 -0
- aimnet/train/metrics.py +188 -0
- aimnet/train/pt2jpt.py +81 -0
- aimnet/train/train.py +155 -0
- aimnet/train/utils.py +398 -0
- aimnet-0.0.1.dist-info/LICENSE +21 -0
- aimnet-0.0.1.dist-info/METADATA +78 -0
- aimnet-0.0.1.dist-info/RECORD +41 -0
- aimnet-0.0.1.dist-info/WHEEL +4 -0
- aimnet-0.0.1.dist-info/entry_points.txt +5 -0
aimnet/models/aimnet2.py
ADDED
from typing import Dict, List, Mapping, Sequence, Tuple, Union

import torch
from torch import Tensor, nn

from aimnet import nbops, ops
from aimnet.models.base import AIMNet2Base
from aimnet.modules import AEVSV, MLP, ConvSV, Embedding


# pylint: disable=too-many-arguments, too-many-instance-attributes
class AIMNet2(AIMNet2Base):
    def __init__(
        self,
        aev: Dict,
        nfeature: int,
        d2features: bool,
        ncomb_v: int,
        hidden: Tuple[List[int]],
        aim_size: int,
        outputs: Union[List[nn.Module], Dict[str, nn.Module]],
        num_charge_channels: int = 1,
    ):
        super().__init__()

        if num_charge_channels not in [1, 2]:
            raise ValueError("num_charge_channels must be 1 (closed shell) or 2 (NSE for open-shell).")
        self.num_charge_channels = num_charge_channels

        self.aev = AEVSV(**aev)
        nshifts_s = aev["nshifts_s"]
        nshifts_v = aev.get("nshifts_v") or nshifts_s
        if d2features:
            if nshifts_s != nshifts_v:
                raise ValueError("nshifts_s must be equal to nshifts_v for d2features")
            nfeature_tot = nshifts_s * nfeature
        else:
            nfeature_tot = nfeature
        self.nfeature = nfeature
        self.nshifts_s = nshifts_s
        self.d2features = d2features

        self.afv = Embedding(num_embeddings=64, embedding_dim=nfeature, padding_idx=0)

        with torch.no_grad():
            nn.init.orthogonal_(self.afv.weight[1:])
            if d2features:
                self.afv.weight = nn.Parameter(
                    self.afv.weight.clone().unsqueeze(-1).expand(64, nfeature, nshifts_s).flatten(-2, -1)
                )

        conv_param = {"nshifts_s": nshifts_s, "nshifts_v": nshifts_v, "ncomb_v": ncomb_v, "do_vector": True}
        self.conv_a = ConvSV(nchannel=nfeature, d2features=d2features, **conv_param)
        self.conv_q = ConvSV(nchannel=num_charge_channels, d2features=False, **conv_param)

        mlp_param = {"activation_fn": nn.GELU(), "last_linear": True}
        mlps = [
            MLP(
                n_in=self.conv_a.output_size() + nfeature_tot,
                n_out=nfeature_tot + 2 * num_charge_channels,
                hidden=hidden[0],
                **mlp_param,
            )
        ]
        mlp_param = {"activation_fn": nn.GELU(), "last_linear": False}
        for h in hidden[1:-1]:
            mlps.append(
                MLP(
                    n_in=self.conv_a.output_size() + self.conv_q.output_size() + nfeature_tot + num_charge_channels,
                    n_out=nfeature_tot + 2 * num_charge_channels,
                    hidden=h,
                    **mlp_param,
                )
            )
        mlp_param = {"activation_fn": nn.GELU(), "last_linear": False}
        mlps.append(
            MLP(
                n_in=self.conv_a.output_size() + self.conv_q.output_size() + nfeature_tot + num_charge_channels,
                n_out=aim_size,
                hidden=hidden[-1],
                **mlp_param,
            )
        )
        self.mlps = nn.ModuleList(mlps)

        if isinstance(outputs, Sequence):
            self.outputs = nn.ModuleList(outputs)
        elif isinstance(outputs, Mapping):
            self.outputs = nn.ModuleDict(outputs)
        else:
            raise TypeError("`outputs` must be either a list or a dict")

    def _preprocess_spin_polarized_charge(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        if "mult" not in data:
            raise ValueError("The `mult` key is required for NSE if two charge channels are not provided")
        _half_spin = 0.5 * (data["mult"] - 1.0)
        _half_q = 0.5 * data["charge"]
        data["charge"] = torch.stack([_half_q + _half_spin, _half_q - _half_spin], dim=-1)
        return data

    def _postprocess_spin_polarized_charge(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        data["spin_charges"] = data["charges"][..., 0] - data["charges"][..., 1]
        data["charges"] = data["charges"].sum(dim=-1)
        data["charge"] = data["charge"].sum(dim=-1)
        return data

    def _prepare_in_a(self, data: Dict[str, Tensor]) -> Tensor:
        a_i, a_j = nbops.get_ij(data["a"], data)
        avf_a = self.conv_a(a_j, data["gs"], data["gv"])
        if self.d2features:
            a_i = a_i.flatten(-2, -1)
        _in = torch.cat([a_i.squeeze(-2), avf_a], dim=-1)
        return _in

    def _prepare_in_q(self, data: Dict[str, Tensor]) -> Tensor:
        q_i, q_j = nbops.get_ij(data["charges"], data)
        avf_q = self.conv_q(q_j, data["gs"], data["gv"])
        _in = torch.cat([q_i.squeeze(-2), avf_q], dim=-1)
        return _in

    def _update_q(self, data: Dict[str, Tensor], x: Tensor, delta_q: bool = True) -> Dict[str, Tensor]:
        _q, _f, delta_a = x.split(
            [
                self.num_charge_channels,
                self.num_charge_channels,
                x.shape[-1] - 2 * self.num_charge_channels,
            ],
            dim=-1,
        )
        # for loss
        data["_delta_Q"] = data["charge"] - nbops.mol_sum(_q, data)
        q = data["charges"] + _q if delta_q else _q
        f = _f.pow(2)
        q = ops.nse(data["charge"], q, f, data, epsilon=1.0e-6)
        data["charges"] = q
        data["a"] = data["a"] + delta_a.view_as(data["a"])
        return data

    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        data = self.prepare_input(data)

        # initial features
        a: Tensor = self.afv(data["numbers"])
        if self.d2features:
            a = a.unflatten(-1, (self.nfeature, self.nshifts_s))
        data["a"] = a

        # NSE case
        if self.num_charge_channels == 2:
            data = self._preprocess_spin_polarized_charge(data)
        else:
            # make sure that charge has a channel dimension
            data["charge"] = data["charge"].unsqueeze(-1)

        # AEV
        data = self.aev(data)

        # MP iterations
        _npass = len(self.mlps)
        for ipass, mlp in enumerate(self.mlps):
            if ipass == 0:
                _in = self._prepare_in_a(data)
            else:
                _in = torch.cat([self._prepare_in_a(data), self._prepare_in_q(data)], dim=-1)

            _out = mlp(_in)
            if data["_input_padded"].item():
                _out = nbops.mask_i_(_out, data, mask_value=0.0)

            if ipass == 0:
                data = self._update_q(data, _out, delta_q=False)
            elif ipass < _npass - 1:
                data = self._update_q(data, _out, delta_q=True)
            else:
                data["aim"] = _out

        # squeeze charges
        if self.num_charge_channels == 2:
            data = self._postprocess_spin_polarized_charge(data)
        else:
            data["charges"] = data["charges"].squeeze(-1)
            data["charge"] = data["charge"].squeeze(-1)

        # readout
        for m in self.outputs.children():
            data = m(data)

        return data
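A note on the charge handling above. `_preprocess_spin_polarized_charge` splits total charge and multiplicity into two channels as `q/2 ± (mult-1)/2`, and `ops.nse` (defined in `aimnet/ops.py`, not shown in this excerpt) enforces total-charge conservation. The sketch below illustrates the channel split with concrete numbers; `nse_sketch` is a hypothetical stand-in that assumes the usual scheme of redistributing the charge deficit proportionally to per-atom weights `f`, which may differ from the actual `ops.nse`.

```python
import torch

# Channel split for a neutral doublet radical (charge 0, mult 2),
# following _preprocess_spin_polarized_charge above.
charge = torch.tensor([0.0])
mult = torch.tensor([2.0])
half_spin = 0.5 * (mult - 1.0)  # 0.5
half_q = 0.5 * charge           # 0.0
channels = torch.stack([half_q + half_spin, half_q - half_spin], dim=-1)
print(channels)  # tensor([[ 0.5000, -0.5000]]) -> alpha/beta channels

def nse_sketch(Q, q, f, eps=1e-6):
    # Hypothetical charge normalization: shift each atomic charge q_i
    # proportionally to its weight f_i so the total matches Q exactly.
    return q + f * (Q - q.sum()) / (f.sum() + eps)

q = torch.tensor([0.3, -0.2, 0.1])
f = torch.tensor([1.0, 2.0, 1.0])
print(nse_sketch(torch.tensor(0.0), q, f).sum())  # ~0: total conserved
```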
aimnet/models/aimnet2.yaml
ADDED

class: aimnet.models.AIMNet2
kwargs:
  nfeature: 16
  d2features: true
  ncomb_v: 12
  hidden:
    - [512, 380]
    - [512, 380]
    - [512, 380, 380]
  aim_size: 256
  aev:
    rc_s: 5.0
    nshifts_s: 16
  outputs:
    energy_mlp:
      class: aimnet.modules.Output
      kwargs:
        n_in: 256
        n_out: 1
        key_in: aim
        key_out: energy
        mlp:
          activation_fn: torch.nn.GELU
          last_linear: true
          hidden: [128, 128]

    atomic_shift:
      class: aimnet.modules.AtomicShift
      kwargs:
        key_in: energy
        key_out: energy

    atomic_sum:
      class: aimnet.modules.AtomicSum
      kwargs:
        key_in: energy
        key_out: energy

    lrcoulomb:
      class: aimnet.modules.LRCoulomb
      kwargs:
        rc: 4.6
        key_in: charges
        key_out: energy
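The YAML is a declarative model definition: every node names a `class` and its `kwargs`, and nested nodes (the `outputs` entries) are built first and passed up to the `AIMNet2` constructor. The package ships `aimnet/config.py` for this purpose; its actual API is not part of this diff, so the loader below is only a minimal sketch of the pattern.

```python
# Hypothetical config loader; the real one lives in aimnet/config.py.
import importlib
from typing import Any

import yaml  # PyYAML

def _resolve(path: str) -> Any:
    module, name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module), name)

def build(node: Any) -> Any:
    # A {"class": ..., "kwargs": ...} node becomes an instantiated object;
    # everything else is recursed into or returned verbatim.
    if isinstance(node, dict) and "class" in node:
        cls = _resolve(node["class"])
        kwargs = {k: build(v) for k, v in node.get("kwargs", {}).items()}
        return cls(**kwargs)
    if isinstance(node, dict):
        return {k: build(v) for k, v in node.items()}
    if isinstance(node, list):
        return [build(v) for v in node]
    return node

# NB: simplified — the real loader presumably also resolves dotted names
# such as `torch.nn.GELU`; this sketch leaves them as plain strings.
with open("aimnet/models/aimnet2.yaml") as f:
    model = build(yaml.safe_load(f))
```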
aimnet/models/aimnet2_dftd3_wb97m.yaml
ADDED

class: aimnet.models.AIMNet2
kwargs:
  nfeature: 16
  d2features: true
  ncomb_v: 12
  hidden:
    - [512, 380]
    - [512, 380]
    - [512, 380, 380]
  aim_size: 256
  aev:
    rc_s: 5.0
    nshifts_s: 16
  outputs:
    energy_mlp:
      class: aimnet.modules.Output
      kwargs:
        n_in: 256
        n_out: 1
        key_in: aim
        key_out: energy
        mlp:
          activation_fn: torch.nn.GELU
          last_linear: true
          hidden: [128, 128]

    atomic_shift:
      class: aimnet.modules.AtomicShift
      kwargs:
        key_in: energy
        key_out: energy

    atomic_sum:
      class: aimnet.modules.AtomicSum
      kwargs:
        key_in: energy
        key_out: energy

    lrcoulomb:
      class: aimnet.modules.LRCoulomb
      kwargs:
        rc: 4.6
        key_in: charges
        key_out: energy

    dftd3:
      class: aimnet.modules.DFTD3
      kwargs:
        s8: 0.3908
        a1: 0.5660
        a2: 3.1280
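`s8`, `a1`, and `a2` are standard DFT-D3(BJ) damping parameters (here for the ωB97M functional). As a reminder of what they control, a minimal sketch of the Becke-Johnson-damped two-body dispersion term follows; the `c6`/`c8` coefficients are placeholders, since in DFT-D3 they depend on the element pair and coordination numbers (this package ships them in `aimnet/dftd3_data.pt`).

```python
import math

def d3bj_pair_energy(r, c6, c8, s6=1.0, s8=0.3908, a1=0.5660, a2=3.1280):
    # Standard D3(BJ) two-body pair energy (atomic units assumed):
    # E = -[ s6*C6/(r^6 + f^6) + s8*C8/(r^8 + f^8) ],  f = a1*R0 + a2
    r0 = math.sqrt(c8 / c6)  # critical radius for the pair
    damp = a1 * r0 + a2      # BJ damping length
    e6 = s6 * c6 / (r**6 + damp**6)
    e8 = s8 * c8 / (r**8 + damp**8)
    return -(e6 + e8)        # attractive dispersion contribution
```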
aimnet/models/base.py
ADDED
from typing import ClassVar, Dict, Final

import torch
from torch import Tensor, nn

from aimnet import nbops


class AIMNet2Base(nn.Module):  # pylint: disable=abstract-method
    """Base class for AIMNet2 models. Implements input data pre-processing:
    converting to the right dtype and device, setting the neighbor mode, and calculating masks.
    """

    __default_dtype = torch.get_default_dtype()

    _required_keys: Final = ["coord", "numbers", "charge"]
    _required_keys_dtype: Final = [__default_dtype, torch.int64, __default_dtype]
    _optional_keys: Final = ["mult", "nbmat", "nbmat_lr", "mol_idx", "shifts", "shifts_lr", "cell"]
    _optional_keys_dtype: Final = [
        __default_dtype,
        torch.int64,
        torch.int64,
        torch.int64,
        __default_dtype,
        __default_dtype,
        __default_dtype,
    ]
    __constants__: ClassVar = ["_required_keys", "_required_keys_dtype", "_optional_keys", "_optional_keys_dtype"]

    def __init__(self):
        super().__init__()

    def _prepare_dtype(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        for k, d in zip(self._required_keys, self._required_keys_dtype):
            assert k in data, f"Key {k} is required"
            data[k] = data[k].to(d)
        for k, d in zip(self._optional_keys, self._optional_keys_dtype):
            if k in data:
                data[k] = data[k].to(d)
        return data

    def prepare_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Some common operations"""
        data = self._prepare_dtype(data)
        data = nbops.set_nb_mode(data)
        data = nbops.calc_masks(data)

        assert data["charge"].ndim == 1, "Charge should be a 1D tensor."
        if "mult" in data:
            assert data["mult"].ndim == 1, "Mult should be a 1D tensor."
        return data
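`prepare_input` pins down the input contract: `coord`, `numbers`, and `charge` are required; `charge` (and `mult`, if present) must be 1D with one entry per molecule; dtypes are coerced as listed above. A minimal sketch of a valid single-molecule input follows. The batched `(molecules, atoms, 3)` coordinate layout is an assumption here; the flat-atoms-plus-`mol_idx` layout handled by `aimnet.nbops` is equally plausible depending on the neighbor mode.

```python
import torch

# Water molecule, assuming a padded-batch layout of 1 molecule x 3 atoms.
data = {
    "coord": torch.tensor([[[0.0000,  0.0000,  0.1173],
                            [0.0000,  0.7572, -0.4692],
                            [0.0000, -0.7572, -0.4692]]]),   # float, Angstrom
    "numbers": torch.tensor([[8, 1, 1]], dtype=torch.int64),  # atomic numbers
    "charge": torch.tensor([0.0]),  # 1D: one total charge per molecule
}
```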
aimnet/modules/aev.py
ADDED
import math
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor, nn

from aimnet import nbops, ops


class AEVSV(nn.Module):
    """AEV module to expand distances and vectors to neighbors over shifted Gaussian basis functions.

    Parameters:
    -----------
    rmin : float, optional
        Minimum distance for the Gaussian basis functions. Default is 0.8.
    rc_s : float, optional
        Cutoff radius for scalar features. Default is 5.0.
    nshifts_s : int, optional
        Number of shifts for scalar features. Default is 16.
    eta_s : Optional[float], optional
        Width of the Gaussian basis functions for scalar features. If not provided, a reasonable default is estimated.
    rc_v : Optional[float], optional
        Cutoff radius for vector features. Defaults to the value of `rc_s`.
    nshifts_v : Optional[int], optional
        Number of shifts for vector features. Defaults to the value of `nshifts_s`.
    eta_v : Optional[float], optional
        Width of the Gaussian basis functions for vector features. If not provided, a reasonable default is estimated.
    shifts_s : Optional[List[float]], optional
        List of shifts for scalar features. Default is equidistant between `rmin` and `rc_s`.
    shifts_v : Optional[List[float]], optional
        List of shifts for vector features. Default is equidistant between `rmin` and `rc_v`.
    """

    def __init__(
        self,
        rmin: float = 0.8,
        rc_s: float = 5.0,
        nshifts_s: int = 16,
        eta_s: Optional[float] = None,
        rc_v: Optional[float] = None,
        nshifts_v: Optional[int] = None,
        eta_v: Optional[float] = None,
        shifts_s: Optional[List[float]] = None,
        shifts_v: Optional[List[float]] = None,
    ):
        super().__init__()

        self._init_basis(rc_s, eta_s, nshifts_s, shifts_s, rmin, mod="_s")
        if rc_v is not None:
            if rc_v > rc_s:
                raise ValueError("rc_v must be less than or equal to rc_s")
            if nshifts_v is None:
                raise ValueError("nshifts_v must not be None")
            self._init_basis(rc_v, eta_v, nshifts_v, shifts_v, rmin, mod="_v")
            self._dual_basis = True
        else:
            # dummy init
            self._init_basis(rc_s, eta_s, nshifts_s, shifts_s, rmin, mod="_v")
            self._dual_basis = False

        self.dmat_fill = rc_s

    def _init_basis(self, rc, eta, nshifts, shifts, rmin, mod="_s"):
        self.register_parameter(
            "rc" + mod,
            nn.Parameter(torch.tensor(rc, dtype=torch.float), requires_grad=False),
        )
        if eta is None:
            eta = (1 / ((rc - rmin) / nshifts)) ** 2
        self.register_parameter(
            "eta" + mod,
            nn.Parameter(torch.tensor(eta, dtype=torch.float), requires_grad=False),
        )
        if shifts is None:
            shifts = torch.linspace(rmin, rc, nshifts + 1)[:nshifts]
        else:
            shifts = torch.as_tensor(shifts, dtype=torch.float)
        self.register_parameter("shifts" + mod, nn.Parameter(shifts, requires_grad=False))

    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        # shapes (..., m) and (..., m, 3)
        d_ij, r_ij = ops.calc_distances(data)
        data["d_ij"] = d_ij
        # shapes (..., nshifts, m) and (..., nshifts, 3, m)
        u_ij, gs, gv = self._calc_aev(r_ij, d_ij, data)  # pylint: disable=unused-variable
        # for now, do not save u_ij
        data["gs"], data["gv"] = gs, gv
        return data

    def _calc_aev(self, r_ij: Tensor, d_ij: Tensor, data: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
        fc_ij = ops.cosine_cutoff(d_ij, self.rc_s)  # (..., m)
        fc_ij = nbops.mask_ij_(fc_ij, data, 0.0)
        gs = ops.exp_expand(d_ij, self.shifts_s, self.eta_s) * fc_ij.unsqueeze(
            -1
        )  # (..., m, nshifts) * (..., m, 1) -> (..., m, nshifts)
        u_ij = r_ij / d_ij.unsqueeze(-1)  # (..., m, 3) / (..., m, 1) -> (..., m, 3)
        if self._dual_basis:
            fc_ij = ops.cosine_cutoff(d_ij, self.rc_v)
            gsv = ops.exp_expand(d_ij, self.shifts_v, self.eta_v) * fc_ij.unsqueeze(-1)
            gv = gsv.unsqueeze(-2) * u_ij.unsqueeze(-1)
        else:
            # (..., m, 1, nshifts), (..., m, 3, 1) -> (..., m, 3, nshifts)
            gv = gs.unsqueeze(-2) * u_ij.unsqueeze(-1)
        return u_ij, gs, gv


class ConvSV(nn.Module):
    """AIMNet2-type convolution: an encoding of the local environment that combines its geometry with atomic features.

    Parameters:
    -----------
    nshifts_s : int
        Number of shifts (Gaussian basis functions) for scalar convolution.
    nchannel : int
        Number of feature channels for atomic features.
    d2features : bool, optional
        Flag indicating whether to use 2D features. Default is False.
    do_vector : bool, optional
        Flag indicating whether to perform vector convolution. Default is True.
    nshifts_v : Optional[int], optional
        Number of shifts for vector convolution. If not provided, defaults to the value of nshifts_s.
    ncomb_v : Optional[int], optional
        Number of linear combinations for vector features. If not provided, defaults to the value of nshifts_v.
    """

    def __init__(
        self,
        nshifts_s: int,
        nchannel: int,
        d2features: bool = False,
        do_vector: bool = True,
        nshifts_v: Optional[int] = None,
        ncomb_v: Optional[int] = None,
    ):
        super().__init__()
        nshifts_v = nshifts_v or nshifts_s
        ncomb_v = ncomb_v or nshifts_v
        agh = _init_ahg(nchannel, nshifts_v, ncomb_v)
        self.register_parameter("agh", nn.Parameter(agh, requires_grad=True))
        self.do_vector = do_vector
        self.nchannel = nchannel
        self.d2features = d2features
        self.nshifts_s = nshifts_s
        self.nshifts_v = nshifts_v
        self.ncomb_v = ncomb_v

    def output_size(self):
        n = self.nchannel * self.nshifts_s
        if self.do_vector:
            n += self.nchannel * self.ncomb_v
        return n

    def forward(self, a: Tensor, gs: Tensor, gv: Optional[Tensor] = None) -> Tensor:
        avf = []
        if self.d2features:
            avf_s = torch.einsum("...mag,...mg->...ag", a, gs)
        else:
            avf_s = torch.einsum("...mg,...ma->...ag", gs, a)
        avf.append(avf_s.flatten(-2, -1))
        if self.do_vector:
            assert gv is not None
            agh = self.agh
            if self.d2features:
                avf_v = torch.einsum("...mag,...mdg,agh->...ahd", a, gv, agh)
            else:
                avf_v = torch.einsum("...ma,...mdg,agh->...ahd", a, gv, agh)
            avf.append(avf_v.pow(2).sum(-1).flatten(-2, -1))
        return torch.cat(avf, dim=-1)


def _init_ahg(b: int, m: int, n: int):
    ret = torch.zeros(b, m, n)
    for i in range(b):
        ret[i] = _init_ahg_one(m, n)  # pylint: disable=arguments-out-of-order
    return ret


def _init_ahg_one(m: int, n: int):
    # make 8x more vectors than needed, then select the most diverse
    x = torch.arange(m).unsqueeze(0)
    a1, a2, a3, a4 = torch.randn(8 * n, 4).unsqueeze(-2).unbind(-1)
    y = a1 * torch.sin(a2 * 2 * x * math.pi / m) + a3 * torch.cos(a4 * 2 * x * math.pi / m)
    y -= y.mean(dim=-1, keepdim=True)
    y /= y.std(dim=-1, keepdim=True)

    dmat = torch.cdist(y, y)
    # most distant point
    ret = torch.zeros(n, m)
    mask = torch.ones(y.shape[0], dtype=torch.bool)
    i = dmat.sum(-1).argmax()
    ret[0] = y[i]
    mask[i] = False

    # simple max-min implementation
    for j in range(1, n):
        mindist, _ = torch.cdist(ret[:j], y).min(dim=0)
        maxidx = torch.argsort(mindist)[mask][-1]
        ret[j] = y[maxidx]
        mask[maxidx] = False
    return ret.t()
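For intuition about the radial basis that `_calc_aev` assembles: the sketch below reproduces the `gs` tensor, assuming `ops.exp_expand` and `ops.cosine_cutoff` (defined in `aimnet/ops.py`, outside this excerpt) implement the conventional shifted-Gaussian and cosine-cutoff forms.

```python
import math
import torch

rmin, rc, nshifts = 0.8, 5.0, 16
shifts = torch.linspace(rmin, rc, nshifts + 1)[:nshifts]  # as in _init_basis
eta = (1 / ((rc - rmin) / nshifts)) ** 2                  # default width

def cosine_cutoff(d, rc):
    # Smoothly decays from 1 at d=0 to 0 at d=rc; assumed form of ops.cosine_cutoff.
    return torch.where(d < rc, 0.5 * (torch.cos(math.pi * d / rc) + 1.0),
                       torch.zeros_like(d))

def exp_expand(d, shifts, eta):
    # Gaussian centered at each shift; assumed form of ops.exp_expand.
    return torch.exp(-eta * (d.unsqueeze(-1) - shifts) ** 2)

d_ij = torch.tensor([1.0, 2.5, 4.9])  # example pair distances, Angstrom
gs = exp_expand(d_ij, shifts, eta) * cosine_cutoff(d_ij, rc).unsqueeze(-1)
print(gs.shape)  # torch.Size([3, 16]) -> (..., m, nshifts), as in AEVSV
```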