aimnet-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,188 @@
+ from typing import Dict, List, Mapping, Sequence, Tuple, Union
+
+ import torch
+ from torch import Tensor, nn
+
+ from aimnet import nbops, ops
+ from aimnet.models.base import AIMNet2Base
+ from aimnet.modules import AEVSV, MLP, ConvSV, Embedding
+
+
+ # pylint: disable=too-many-arguments, too-many-instance-attributes
+ class AIMNet2(AIMNet2Base):
+     def __init__(
+         self,
+         aev: Dict,
+         nfeature: int,
+         d2features: bool,
+         ncomb_v: int,
+         hidden: Tuple[List[int]],
+         aim_size: int,
+         outputs: Union[List[nn.Module], Dict[str, nn.Module]],
+         num_charge_channels: int = 1,
+     ):
+         super().__init__()
+
+         if num_charge_channels not in [1, 2]:
+             raise ValueError("num_charge_channels must be 1 (closed shell) or 2 (NSE for open-shell).")
+         self.num_charge_channels = num_charge_channels
+
+         self.aev = AEVSV(**aev)
+         nshifts_s = aev["nshifts_s"]
+         nshifts_v = aev.get("nshifts_v") or nshifts_s
+         if d2features:
+             if nshifts_s != nshifts_v:
+                 raise ValueError("nshifts_s must be equal to nshifts_v for d2features")
+             nfeature_tot = nshifts_s * nfeature
+         else:
+             nfeature_tot = nfeature
+         self.nfeature = nfeature
+         self.nshifts_s = nshifts_s
+         self.d2features = d2features
+
+         self.afv = Embedding(num_embeddings=64, embedding_dim=nfeature, padding_idx=0)
+
+         with torch.no_grad():
+             nn.init.orthogonal_(self.afv.weight[1:])
+             if d2features:
+                 self.afv.weight = nn.Parameter(
+                     self.afv.weight.clone().unsqueeze(-1).expand(64, nfeature, nshifts_s).flatten(-2, -1)
+                 )
+
+         conv_param = {"nshifts_s": nshifts_s, "nshifts_v": nshifts_v, "ncomb_v": ncomb_v, "do_vector": True}
+         self.conv_a = ConvSV(nchannel=nfeature, d2features=d2features, **conv_param)
+         self.conv_q = ConvSV(nchannel=num_charge_channels, d2features=False, **conv_param)
+
+         mlp_param = {"activation_fn": nn.GELU(), "last_linear": True}
+         mlps = [
+             MLP(
+                 n_in=self.conv_a.output_size() + nfeature_tot,
+                 n_out=nfeature_tot + 2 * num_charge_channels,
+                 hidden=hidden[0],
+                 **mlp_param,
+             )
+         ]
+         mlp_param = {"activation_fn": nn.GELU(), "last_linear": False}
+         for h in hidden[1:-1]:
+             mlps.append(
+                 MLP(
+                     n_in=self.conv_a.output_size() + self.conv_q.output_size() + nfeature_tot + num_charge_channels,
+                     n_out=nfeature_tot + 2 * num_charge_channels,
+                     hidden=h,
+                     **mlp_param,
+                 )
+             )
+         mlp_param = {"activation_fn": nn.GELU(), "last_linear": False}
+         mlps.append(
+             MLP(
+                 n_in=self.conv_a.output_size() + self.conv_q.output_size() + nfeature_tot + num_charge_channels,
+                 n_out=aim_size,
+                 hidden=hidden[-1],
+                 **mlp_param,
+             )
+         )
+         self.mlps = nn.ModuleList(mlps)
+
+         if isinstance(outputs, Sequence):
+             self.outputs = nn.ModuleList(outputs)
+         elif isinstance(outputs, Mapping):
+             self.outputs = nn.ModuleDict(outputs)
+         else:
+             raise TypeError("`outputs` must be either a list or a dict")
+
+     def _preprocess_spin_polarized_charge(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         if "mult" not in data:
+             raise ValueError("mult key is required for NSE if two channels for charge are not provided")
+         _half_spin = 0.5 * (data["mult"] - 1.0)
+         _half_q = 0.5 * data["charge"]
+         data["charge"] = torch.stack([_half_q + _half_spin, _half_q - _half_spin], dim=-1)
+         return data
+
+     def _postprocess_spin_polarized_charge(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         data["spin_charges"] = data["charges"][..., 0] - data["charges"][..., 1]
+         data["charges"] = data["charges"].sum(dim=-1)
+         data["charge"] = data["charge"].sum(dim=-1)
+         return data
+
+     def _prepare_in_a(self, data: Dict[str, Tensor]) -> Tensor:
+         a_i, a_j = nbops.get_ij(data["a"], data)
+         avf_a = self.conv_a(a_j, data["gs"], data["gv"])
+         if self.d2features:
+             a_i = a_i.flatten(-2, -1)
+         _in = torch.cat([a_i.squeeze(-2), avf_a], dim=-1)
+         return _in
+
+     def _prepare_in_q(self, data: Dict[str, Tensor]) -> Tensor:
+         q_i, q_j = nbops.get_ij(data["charges"], data)
+         avf_q = self.conv_q(q_j, data["gs"], data["gv"])
+         _in = torch.cat([q_i.squeeze(-2), avf_q], dim=-1)
+         return _in
+
+     def _update_q(self, data: Dict[str, Tensor], x: Tensor, delta_q: bool = True) -> Dict[str, Tensor]:
+         _q, _f, delta_a = x.split(
+             [
+                 self.num_charge_channels,
+                 self.num_charge_channels,
+                 x.shape[-1] - 2 * self.num_charge_channels,
+             ],
+             dim=-1,
+         )
+         # for loss
+         data["_delta_Q"] = data["charge"] - nbops.mol_sum(_q, data)
+         q = data["charges"] + _q if delta_q else _q
+         f = _f.pow(2)
+         q = ops.nse(data["charge"], q, f, data, epsilon=1.0e-6)
+         data["charges"] = q
+         data["a"] = data["a"] + delta_a.view_as(data["a"])
+         return data
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         data = self.prepare_input(data)
+
+         # initial features
+         a: Tensor = self.afv(data["numbers"])
+         if self.d2features:
+             a = a.unflatten(-1, (self.nfeature, self.nshifts_s))
+         data["a"] = a
+
+         # NSE case
+         if self.num_charge_channels == 2:
+             data = self._preprocess_spin_polarized_charge(data)
+         else:
+             # make sure that charge has channel dimension
+             data["charge"] = data["charge"].unsqueeze(-1)
+
+         # AEV
+         data = self.aev(data)
+
+         # MP iterations
+         _npass = len(self.mlps)
+         for ipass, mlp in enumerate(self.mlps):
+             if ipass == 0:
+                 _in = self._prepare_in_a(data)
+             else:
+                 _in = torch.cat([self._prepare_in_a(data), self._prepare_in_q(data)], dim=-1)
+
+             _out = mlp(_in)
+             if data["_input_padded"].item():
+                 _out = nbops.mask_i_(_out, data, mask_value=0.0)
+
+             if ipass == 0:
+                 data = self._update_q(data, _out, delta_q=False)
+             elif ipass < _npass - 1:
+                 data = self._update_q(data, _out, delta_q=True)
+             else:
+                 data["aim"] = _out
+
+         # squeeze charges
+         if self.num_charge_channels == 2:
+             data = self._postprocess_spin_polarized_charge(data)
+         else:
+             data["charges"] = data["charges"].squeeze(-1)
+             data["charge"] = data["charge"].squeeze(-1)
+
+         # readout
+         for m in self.outputs.children():
+             data = m(data)
+
+         return data
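
Note: the number of message-passing iterations equals len(hidden). The first MLP output initializes per-atom charges, the intermediate MLPs predict charge and feature updates, and the final MLP produces the `aim` vector consumed by the output heads. As an illustrative sketch of how one pass splits its MLP output (shapes and values here are assumptions for illustration, not taken from the package):

import torch

num_charge_channels = 1                # closed shell
nfeature_tot = 16 * 16                 # nfeature * nshifts_s when d2features is used
out = torch.randn(32, nfeature_tot + 2 * num_charge_channels)   # stand-in for one MLP output over 32 atoms

# mirrors the split in AIMNet2._update_q above
q, f, delta_a = out.split(
    [num_charge_channels, num_charge_channels, out.shape[-1] - 2 * num_charge_channels], dim=-1
)
f = f.pow(2)   # non-negative weights that ops.nse uses to redistribute charge to the target total
print(q.shape, f.shape, delta_a.shape)   # torch.Size([32, 1]) torch.Size([32, 1]) torch.Size([32, 256])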
@@ -0,0 +1,44 @@
+ class: aimnet.models.AIMNet2
+ kwargs:
+   nfeature: 16
+   d2features: true
+   ncomb_v: 12
+   hidden:
+     - [512, 380]
+     - [512, 380]
+     - [512, 380, 380]
+   aim_size: 256
+   aev:
+     rc_s: 5.0
+     nshifts_s: 16
+   outputs:
+     energy_mlp:
+       class: aimnet.modules.Output
+       kwargs:
+         n_in: 256
+         n_out: 1
+         key_in: aim
+         key_out: energy
+         mlp:
+           activation_fn: torch.nn.GELU
+           last_linear: true
+           hidden: [128, 128]
+
+     atomic_shift:
+       class: aimnet.modules.AtomicShift
+       kwargs:
+         key_in: energy
+         key_out: energy
+
+     atomic_sum:
+       class: aimnet.modules.AtomicSum
+       kwargs:
+         key_in: energy
+         key_out: energy
+
+     lrcoulomb:
+       class: aimnet.modules.LRCoulomb
+       kwargs:
+         rc: 4.6
+         key_in: charges
+         key_out: energy
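
A note on the sizes in this config: with d2features: true the internal per-atom feature size is nfeature * nshifts_s = 16 * 16 = 256, and the energy head reads the aim vector, so its n_in must equal aim_size (256 here). A quick check mirroring the constructor logic in AIMNet2.__init__ above:

nfeature, nshifts_s, d2features, aim_size = 16, 16, True, 256
nfeature_tot = nshifts_s * nfeature if d2features else nfeature
assert nfeature_tot == 256
assert aim_size == 256   # energy_mlp n_in must match aim_size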
@@ -0,0 +1,51 @@
+ class: aimnet.models.AIMNet2
+ kwargs:
+   nfeature: 16
+   d2features: true
+   ncomb_v: 12
+   hidden:
+     - [512, 380]
+     - [512, 380]
+     - [512, 380, 380]
+   aim_size: 256
+   aev:
+     rc_s: 5.0
+     nshifts_s: 16
+   outputs:
+     energy_mlp:
+       class: aimnet.modules.Output
+       kwargs:
+         n_in: 256
+         n_out: 1
+         key_in: aim
+         key_out: energy
+         mlp:
+           activation_fn: torch.nn.GELU
+           last_linear: true
+           hidden: [128, 128]
+
+     atomic_shift:
+       class: aimnet.modules.AtomicShift
+       kwargs:
+         key_in: energy
+         key_out: energy
+
+     atomic_sum:
+       class: aimnet.modules.AtomicSum
+       kwargs:
+         key_in: energy
+         key_out: energy
+
+     lrcoulomb:
+       class: aimnet.modules.LRCoulomb
+       kwargs:
+         rc: 4.6
+         key_in: charges
+         key_out: energy
+
+     dftd3:
+       class: aimnet.modules.DFTD3
+       kwargs:
+         s8: 0.3908
+         a1: 0.5660
+         a2: 3.1280
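
This configuration is identical to the previous one except for the added dftd3 output module, which contributes a D3 dispersion correction with the listed damping parameters. Assuming aimnet.modules.DFTD3 follows the usual D3(BJ) convention (its implementation is not part of this diff), s8 scales the C8 term and a1, a2 enter the Becke-Johnson damping:

E_disp = -sum_{AB} [ s6 * C6_AB / (R_AB^6 + f(R0_AB)^6) + s8 * C8_AB / (R_AB^8 + f(R0_AB)^8) ],  with  f(R0_AB) = a1 * R0_AB + a2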
aimnet/models/base.py ADDED
@@ -0,0 +1,51 @@
+ from typing import ClassVar, Dict, Final
+
+ import torch
+ from torch import Tensor, nn
+
+ from aimnet import nbops
+
+
+ class AIMNet2Base(nn.Module):  # pylint: disable=abstract-method
+     """Base class for AIMNet2 models. Implements input pre-processing:
+     converting to the right dtype and device, setting the neighbor-list (nb) mode, and calculating masks.
+     """
+
+     __default_dtype = torch.get_default_dtype()
+
+     _required_keys: Final = ["coord", "numbers", "charge"]
+     _required_keys_dtype: Final = [__default_dtype, torch.int64, __default_dtype]
+     _optional_keys: Final = ["mult", "nbmat", "nbmat_lr", "mol_idx", "shifts", "shifts_lr", "cell"]
+     _optional_keys_dtype: Final = [
+         __default_dtype,
+         torch.int64,
+         torch.int64,
+         torch.int64,
+         __default_dtype,
+         __default_dtype,
+         __default_dtype,
+     ]
+     __constants__: ClassVar = ["_required_keys", "_required_keys_dtype", "_optional_keys", "_optional_keys_dtype"]
+
+     def __init__(self):
+         super().__init__()
+
+     def _prepare_dtype(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         for k, d in zip(self._required_keys, self._required_keys_dtype):
+             assert k in data, f"Key {k} is required"
+             data[k] = data[k].to(d)
+         for k, d in zip(self._optional_keys, self._optional_keys_dtype):
+             if k in data:
+                 data[k] = data[k].to(d)
+         return data
+
+     def prepare_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         """Some common operations"""
+         data = self._prepare_dtype(data)
+         data = nbops.set_nb_mode(data)
+         data = nbops.calc_masks(data)
+
+         assert data["charge"].ndim == 1, "Charge should be 1D tensor."
+         if "mult" in data:
+             assert data["mult"].ndim == 1, "Mult should be 1D tensor."
+         return data
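
For reference, the minimal input dict accepted by prepare_input carries the three required keys with the dtypes listed above (charge as a 1D per-molecule tensor). A rough sketch, with tensor shapes assumed for illustration since the neighbor-list layout handled by nbops is not part of this diff:

import torch

data = {
    "coord": torch.zeros(1, 3, 3),          # atomic positions, default float dtype
    "numbers": torch.tensor([[8, 1, 1]]),   # atomic numbers, int64
    "charge": torch.zeros(1),               # total molecular charge, 1D tensor
    # optional for open-shell / NSE models: "mult": torch.ones(1)
}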
@@ -0,0 +1,3 @@
+ from .aev import AEVSV, ConvSV  # noqa: F401
+ from .core import MLP, AtomicShift, AtomicSum, Dipole, Embedding, Forces, Output, Quadrupole  # noqa: F401
+ from .lr import D3TS, DFTD3, LRCoulomb  # noqa: F401
aimnet/modules/aev.py ADDED
@@ -0,0 +1,201 @@
+ import math
+ from typing import Dict, List, Optional, Tuple
+
+ import torch
+ from torch import Tensor, nn
+
+ from aimnet import nbops, ops
+
+
+ class AEVSV(nn.Module):
+     """AEV module to expand distances and vectors to neighbors over shifted Gaussian basis functions.
+
+     Parameters:
+     -----------
+     rmin : float, optional
+         Minimum distance for the Gaussian basis functions. Default is 0.8.
+     rc_s : float, optional
+         Cutoff radius for scalar features. Default is 5.0.
+     nshifts_s : int, optional
+         Number of shifts for scalar features. Default is 16.
+     eta_s : Optional[float], optional
+         Width of the Gaussian basis functions for scalar features. If not given, a reasonable default is estimated.
+     rc_v : Optional[float], optional
+         Cutoff radius for vector features. Default is the same as `rc_s`.
+     nshifts_v : Optional[int], optional
+         Number of shifts for vector features. Default is the same as `nshifts_s`.
+     eta_v : Optional[float], optional
+         Width of the Gaussian basis functions for vector features. If not given, a reasonable default is estimated.
+     shifts_s : Optional[List[float]], optional
+         List of shifts for scalar features. Default is equidistant between `rmin` and `rc_s`.
+     shifts_v : Optional[List[float]], optional
+         List of shifts for vector features. Default is equidistant between `rmin` and `rc_v`.
+     """
+
+     def __init__(
+         self,
+         rmin: float = 0.8,
+         rc_s: float = 5.0,
+         nshifts_s: int = 16,
+         eta_s: Optional[float] = None,
+         rc_v: Optional[float] = None,
+         nshifts_v: Optional[int] = None,
+         eta_v: Optional[float] = None,
+         shifts_s: Optional[List[float]] = None,
+         shifts_v: Optional[List[float]] = None,
+     ):
+         super().__init__()
+
+         self._init_basis(rc_s, eta_s, nshifts_s, shifts_s, rmin, mod="_s")
+         if rc_v is not None:
+             if rc_v > rc_s:
+                 raise ValueError("rc_v must be less than or equal to rc_s")
+             if nshifts_v is None:
+                 raise ValueError("nshifts_v must not be None")
+             self._init_basis(rc_v, eta_v, nshifts_v, shifts_v, rmin, mod="_v")
+             self._dual_basis = True
+         else:
+             # dummy init
+             self._init_basis(rc_s, eta_s, nshifts_s, shifts_s, rmin, mod="_v")
+             self._dual_basis = False
+
+         self.dmat_fill = rc_s
+
+     def _init_basis(self, rc, eta, nshifts, shifts, rmin, mod="_s"):
+         self.register_parameter(
+             "rc" + mod,
+             nn.Parameter(torch.tensor(rc, dtype=torch.float), requires_grad=False),
+         )
+         if eta is None:
+             eta = (1 / ((rc - rmin) / nshifts)) ** 2
+         self.register_parameter(
+             "eta" + mod,
+             nn.Parameter(torch.tensor(eta, dtype=torch.float), requires_grad=False),
+         )
+         if shifts is None:
+             shifts = torch.linspace(rmin, rc, nshifts + 1)[:nshifts]
+         else:
+             shifts = torch.as_tensor(shifts, dtype=torch.float)
+         self.register_parameter("shifts" + mod, nn.Parameter(shifts, requires_grad=False))
+
+     def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         # shapes (..., m) and (..., m, 3)
+         d_ij, r_ij = ops.calc_distances(data)
+         data["d_ij"] = d_ij
+         # shapes (..., m, nshifts) and (..., m, 3, nshifts)
+         u_ij, gs, gv = self._calc_aev(r_ij, d_ij, data)  # pylint: disable=unused-variable
+         # for now, do not save u_ij
+         data["gs"], data["gv"] = gs, gv
+         return data
+
+     def _calc_aev(self, r_ij: Tensor, d_ij: Tensor, data: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
+         fc_ij = ops.cosine_cutoff(d_ij, self.rc_s)  # (..., m)
+         fc_ij = nbops.mask_ij_(fc_ij, data, 0.0)
+         gs = ops.exp_expand(d_ij, self.shifts_s, self.eta_s) * fc_ij.unsqueeze(
+             -1
+         )  # (..., m, nshifts) * (..., m, 1) -> (..., m, nshifts)
+         u_ij = r_ij / d_ij.unsqueeze(-1)  # (..., m, 3) / (..., m, 1) -> (..., m, 3)
+         if self._dual_basis:
+             fc_ij = ops.cosine_cutoff(d_ij, self.rc_v)
+             gsv = ops.exp_expand(d_ij, self.shifts_v, self.eta_v) * fc_ij.unsqueeze(-1)
+             gv = gsv.unsqueeze(-2) * u_ij.unsqueeze(-1)
+         else:
+             # (..., m, 1, nshifts), (..., m, 3, 1) -> (..., m, 3, nshifts)
+             gv = gs.unsqueeze(-2) * u_ij.unsqueeze(-1)
+         return u_ij, gs, gv
+
+
+ class ConvSV(nn.Module):
+     """AIMNet2-type convolution: encodes the local environment by combining its geometry with atomic features.
+
+     Parameters:
+     -----------
+     nshifts_s : int
+         Number of shifts (Gaussian basis functions) for scalar convolution.
+     nchannel : int
+         Number of feature channels for atomic features.
+     d2features : bool, optional
+         Flag indicating whether to use 2D features. Default is False.
+     do_vector : bool, optional
+         Flag indicating whether to perform vector convolution. Default is True.
+     nshifts_v : Optional[int], optional
+         Number of shifts for vector convolution. If not provided, defaults to the value of nshifts_s.
+     ncomb_v : Optional[int], optional
+         Number of linear combinations for vector features. If not provided, defaults to the value of nshifts_v.
+     """
+
+     def __init__(
+         self,
+         nshifts_s: int,
+         nchannel: int,
+         d2features: bool = False,
+         do_vector: bool = True,
+         nshifts_v: Optional[int] = None,
+         ncomb_v: Optional[int] = None,
+     ):
+         super().__init__()
+         nshifts_v = nshifts_v or nshifts_s
+         ncomb_v = ncomb_v or nshifts_v
+         agh = _init_ahg(nchannel, nshifts_v, ncomb_v)
+         self.register_parameter("agh", nn.Parameter(agh, requires_grad=True))
+         self.do_vector = do_vector
+         self.nchannel = nchannel
+         self.d2features = d2features
+         self.nshifts_s = nshifts_s
+         self.nshifts_v = nshifts_v
+         self.ncomb_v = ncomb_v
+
+     def output_size(self):
+         n = self.nchannel * self.nshifts_s
+         if self.do_vector:
+             n += self.nchannel * self.ncomb_v
+         return n
+
+     def forward(self, a: Tensor, gs: Tensor, gv: Optional[Tensor] = None) -> Tensor:
+         avf = []
+         if self.d2features:
+             avf_s = torch.einsum("...mag,...mg->...ag", a, gs)
+         else:
+             avf_s = torch.einsum("...mg,...ma->...ag", gs, a)
+         avf.append(avf_s.flatten(-2, -1))
+         if self.do_vector:
+             assert gv is not None
+             agh = self.agh
+             if self.d2features:
+                 avf_v = torch.einsum("...mag,...mdg,agh->...ahd", a, gv, agh)
+             else:
+                 avf_v = torch.einsum("...ma,...mdg,agh->...ahd", a, gv, agh)
+             avf.append(avf_v.pow(2).sum(-1).flatten(-2, -1))
+         return torch.cat(avf, dim=-1)
+
+
+ def _init_ahg(b: int, m: int, n: int):
+     ret = torch.zeros(b, m, n)
+     for i in range(b):
+         ret[i] = _init_ahg_one(m, n)  # pylint: disable=arguments-out-of-order
+     return ret
+
+
+ def _init_ahg_one(m: int, n: int):
+     # make 8x more vectors than needed to select the most diverse ones
+     x = torch.arange(m).unsqueeze(0)
+     a1, a2, a3, a4 = torch.randn(8 * n, 4).unsqueeze(-2).unbind(-1)
+     y = a1 * torch.sin(a2 * 2 * x * math.pi / m) + a3 * torch.cos(a4 * 2 * x * math.pi / m)
+     y -= y.mean(dim=-1, keepdim=True)
+     y /= y.std(dim=-1, keepdim=True)
+
+     dmat = torch.cdist(y, y)
+     # most distant point
+     ret = torch.zeros(n, m)
+     mask = torch.ones(y.shape[0], dtype=torch.bool)
+     i = dmat.sum(-1).argmax()
+     ret[0] = y[i]
+     mask[i] = False
+
+     # simple maxmin implementation
+     for j in range(1, n):
+         mindist, _ = torch.cdist(ret[:j], y).min(dim=0)
+         maxidx = torch.argsort(mindist)[mask][-1]
+         ret[j] = y[maxidx]
+         mask[maxidx] = False
+     return ret.t()
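
For orientation, AEVSV expands each neighbor distance over shifted Gaussians scaled by a smooth cutoff at rc. A minimal sketch of that radial basis; ops.exp_expand and ops.cosine_cutoff are not part of this diff, so their exact forms below are assumptions, while the shifts and eta defaults mirror _init_basis above:

import math
import torch

rmin, rc, nshifts = 0.8, 5.0, 16
shifts = torch.linspace(rmin, rc, nshifts + 1)[:nshifts]   # same default as _init_basis
eta = (1 / ((rc - rmin) / nshifts)) ** 2                    # same default as _init_basis

def radial_basis(d: torch.Tensor) -> torch.Tensor:
    # assumed forms: Gaussian expansion over shifts, cosine cutoff that goes to zero at rc
    g = torch.exp(-eta * (d.unsqueeze(-1) - shifts) ** 2)   # (..., m, nshifts)
    fc = torch.where(d < rc, 0.5 * (torch.cos(math.pi * d / rc) + 1.0), torch.zeros_like(d))
    return g * fc.unsqueeze(-1)

print(radial_basis(torch.tensor([1.0, 3.0, 6.0])).shape)    # torch.Size([3, 16])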