aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +7 -0
- aimnet/base.py +24 -8
- aimnet/calculators/__init__.py +4 -4
- aimnet/calculators/aimnet2ase.py +19 -6
- aimnet/calculators/calculator.py +868 -108
- aimnet/calculators/model_registry.py +2 -5
- aimnet/calculators/model_registry.yaml +55 -17
- aimnet/cli.py +62 -6
- aimnet/config.py +8 -9
- aimnet/data/sgdataset.py +23 -22
- aimnet/kernels/__init__.py +66 -0
- aimnet/kernels/conv_sv_2d_sp_wp.py +478 -0
- aimnet/models/__init__.py +13 -1
- aimnet/models/aimnet2.py +19 -22
- aimnet/models/base.py +183 -15
- aimnet/models/convert.py +30 -0
- aimnet/models/utils.py +735 -0
- aimnet/modules/__init__.py +1 -1
- aimnet/modules/aev.py +49 -48
- aimnet/modules/core.py +14 -13
- aimnet/modules/lr.py +520 -115
- aimnet/modules/ops.py +537 -0
- aimnet/nbops.py +105 -15
- aimnet/ops.py +90 -18
- aimnet/train/export_model.py +226 -0
- aimnet/train/loss.py +7 -7
- aimnet/train/metrics.py +5 -6
- aimnet/train/train.py +4 -1
- aimnet/train/utils.py +42 -13
- aimnet-0.1.0.dist-info/METADATA +308 -0
- aimnet-0.1.0.dist-info/RECORD +43 -0
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info}/WHEEL +1 -1
- aimnet-0.1.0.dist-info/entry_points.txt +3 -0
- aimnet/calculators/nb_kernel_cpu.py +0 -222
- aimnet/calculators/nb_kernel_cuda.py +0 -217
- aimnet/calculators/nbmat.py +0 -220
- aimnet/train/pt2jpt.py +0 -81
- aimnet-0.0.1.dist-info/METADATA +0 -78
- aimnet-0.0.1.dist-info/RECORD +0 -41
- aimnet-0.0.1.dist-info/entry_points.txt +0 -5
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info/licenses}/LICENSE +0 -0
aimnet/modules/__init__.py
CHANGED
aimnet/modules/aev.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
import math
|
|
2
|
-
from typing import Dict, List, Optional, Tuple
|
|
3
2
|
|
|
4
3
|
import torch
|
|
5
4
|
from torch import Tensor, nn
|
|
6
5
|
|
|
7
6
|
from aimnet import nbops, ops
|
|
7
|
+
from aimnet.kernels import conv_sv_2d_sp
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class AEVSV(nn.Module):
|
|
@@ -37,12 +37,12 @@ class AEVSV(nn.Module):
|
|
|
37
37
|
rmin: float = 0.8,
|
|
38
38
|
rc_s: float = 5.0,
|
|
39
39
|
nshifts_s: int = 16,
|
|
40
|
-
eta_s:
|
|
41
|
-
rc_v:
|
|
42
|
-
nshifts_v:
|
|
43
|
-
eta_v:
|
|
44
|
-
shifts_s:
|
|
45
|
-
shifts_v:
|
|
40
|
+
eta_s: float | None = None,
|
|
41
|
+
rc_v: float | None = None,
|
|
42
|
+
nshifts_v: int | None = None,
|
|
43
|
+
eta_v: float | None = None,
|
|
44
|
+
shifts_s: list[float] | None = None,
|
|
45
|
+
shifts_v: list[float] | None = None,
|
|
46
46
|
):
|
|
47
47
|
super().__init__()
|
|
48
48
|
|
|
@@ -78,31 +78,31 @@ class AEVSV(nn.Module):
|
|
|
78
78
|
shifts = torch.as_tensor(shifts, dtype=torch.float)
|
|
79
79
|
self.register_parameter("shifts" + mod, nn.Parameter(shifts, requires_grad=False))
|
|
80
80
|
|
|
81
|
-
def forward(self, data:
|
|
82
|
-
#
|
|
81
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
82
|
+
# d_ij: distances to neighbors (..., m)
|
|
83
|
+
# r_ij: displacement vectors to neighbors (..., m, 3)
|
|
83
84
|
d_ij, r_ij = ops.calc_distances(data)
|
|
84
85
|
data["d_ij"] = d_ij
|
|
85
|
-
#
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
data["
|
|
86
|
+
# Atomic environment vectors: (..., m, g, 4)
|
|
87
|
+
# 4 components = 1 scalar (radial) + 3 vector (directional)
|
|
88
|
+
g_sv = self._calc_aev(r_ij, d_ij, data)
|
|
89
|
+
data["g_sv"] = g_sv
|
|
89
90
|
return data
|
|
90
91
|
|
|
91
|
-
def _calc_aev(self, r_ij: Tensor, d_ij: Tensor, data:
|
|
92
|
+
def _calc_aev(self, r_ij: Tensor, d_ij: Tensor, data: dict[str, Tensor]) -> Tensor:
|
|
92
93
|
fc_ij = ops.cosine_cutoff(d_ij, self.rc_s) # (..., m)
|
|
93
94
|
fc_ij = nbops.mask_ij_(fc_ij, data, 0.0)
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
return u_ij, gs, gv
|
|
95
|
+
# Apply cutoff envelope to Gaussian-expanded distances
|
|
96
|
+
# Shape: (..., m, nshifts) * (..., m, 1) -> (..., m, nshifts)
|
|
97
|
+
gs = ops.exp_expand(d_ij, self.shifts_s, self.eta_s) * fc_ij.unsqueeze(-1)
|
|
98
|
+
# Normalize displacement vectors to unit direction vectors
|
|
99
|
+
# Shape: (..., m, 3) / (..., m, 1) -> (..., m, 3)
|
|
100
|
+
u_ij = r_ij / d_ij.unsqueeze(-1)
|
|
101
|
+
# Combine radial basis with directional info for vector features
|
|
102
|
+
# Shape: (..., m, nshifts, 1) * (..., m, 1, 3) -> (..., m, nshifts, 3)
|
|
103
|
+
gv = gs.unsqueeze(-1) * u_ij.unsqueeze(-2)
|
|
104
|
+
g_sv = torch.cat([gs.unsqueeze(-1), gv], dim=-1)
|
|
105
|
+
return g_sv
|
|
106
106
|
|
|
107
107
|
|
|
108
108
|
class ConvSV(nn.Module):
|
|
@@ -116,8 +116,6 @@ class ConvSV(nn.Module):
|
|
|
116
116
|
Number of feature channels for atomic features.
|
|
117
117
|
d2features : bool, optional
|
|
118
118
|
Flag indicating whether to use 2D features. Default is False.
|
|
119
|
-
do_vector : bool, optional
|
|
120
|
-
Flag indicating whether to perform vector convolution. Default is True.
|
|
121
119
|
nshifts_v : Optional[int], optional
|
|
122
120
|
Number of shifts for vector convolution. If not provided, defaults to the value of nshifts_s.
|
|
123
121
|
ncomb_v : Optional[int], optional
|
|
@@ -129,16 +127,15 @@ class ConvSV(nn.Module):
|
|
|
129
127
|
nshifts_s: int,
|
|
130
128
|
nchannel: int,
|
|
131
129
|
d2features: bool = False,
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
ncomb_v: Optional[int] = None,
|
|
130
|
+
nshifts_v: int | None = None,
|
|
131
|
+
ncomb_v: int | None = None,
|
|
135
132
|
):
|
|
136
133
|
super().__init__()
|
|
137
134
|
nshifts_v = nshifts_v or nshifts_s
|
|
138
135
|
ncomb_v = ncomb_v or nshifts_v
|
|
139
136
|
agh = _init_ahg(nchannel, nshifts_v, ncomb_v)
|
|
140
137
|
self.register_parameter("agh", nn.Parameter(agh, requires_grad=True))
|
|
141
|
-
self.do_vector =
|
|
138
|
+
self.do_vector = True
|
|
142
139
|
self.nchannel = nchannel
|
|
143
140
|
self.d2features = d2features
|
|
144
141
|
self.nshifts_s = nshifts_s
|
|
@@ -151,33 +148,37 @@ class ConvSV(nn.Module):
|
|
|
151
148
|
n += self.nchannel * self.ncomb_v
|
|
152
149
|
return n
|
|
153
150
|
|
|
154
|
-
def forward(self,
|
|
155
|
-
|
|
151
|
+
def forward(self, data: dict[str, Tensor], a: Tensor) -> Tensor:
|
|
152
|
+
g_sv = data["g_sv"]
|
|
153
|
+
mode = nbops.get_nb_mode(data)
|
|
156
154
|
if self.d2features:
|
|
157
|
-
|
|
155
|
+
if mode > 0 and a.device.type == "cuda":
|
|
156
|
+
avf_sv = conv_sv_2d_sp(a, data["nbmat"], g_sv)
|
|
157
|
+
elif mode > 0:
|
|
158
|
+
a_j = a.index_select(0, data["nbmat"].flatten()).unflatten(0, data["nbmat"].shape)
|
|
159
|
+
avf_sv = torch.einsum("...mag,...mgd->...agd", a_j, g_sv)
|
|
160
|
+
else:
|
|
161
|
+
avf_sv = torch.einsum("...mag,...mgd->...agd", a.unsqueeze(1), g_sv)
|
|
158
162
|
else:
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
assert gv is not None
|
|
163
|
-
agh = self.agh
|
|
164
|
-
if self.d2features:
|
|
165
|
-
avf_v = torch.einsum("...mag,...mdg,agh->...ahd", a, gv, agh)
|
|
163
|
+
if mode > 0:
|
|
164
|
+
a_j = a.index_select(0, data["nbmat"].flatten()).unflatten(0, data["nbmat"].shape)
|
|
165
|
+
avf_sv = torch.einsum("...ma,...mgd->...agd", a_j, g_sv)
|
|
166
166
|
else:
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
167
|
+
avf_sv = torch.einsum("...ma,...mgd->...agd", a.unsqueeze(1), g_sv)
|
|
168
|
+
avf_s, avf_v = avf_sv.split([1, 3], dim=-1)
|
|
169
|
+
avf_v = torch.einsum("agh,...agd->...ahd", self.agh, avf_v).pow(2).sum(-1)
|
|
170
|
+
return torch.cat([avf_s.squeeze(-1).flatten(-2, -1), avf_v.flatten(-2, -1)], dim=-1)
|
|
170
171
|
|
|
171
172
|
|
|
172
173
|
def _init_ahg(b: int, m: int, n: int):
|
|
173
174
|
ret = torch.zeros(b, m, n)
|
|
174
175
|
for i in range(b):
|
|
175
|
-
ret[i] = _init_ahg_one(m, n) #
|
|
176
|
+
ret[i] = _init_ahg_one(m, n) # pylint: disable=arguments-out-of-order
|
|
176
177
|
return ret
|
|
177
178
|
|
|
178
179
|
|
|
179
180
|
def _init_ahg_one(m: int, n: int):
|
|
180
|
-
#
|
|
181
|
+
# Oversample by 8x to select most orthogonal vectors via maxmin algorithm
|
|
181
182
|
x = torch.arange(m).unsqueeze(0)
|
|
182
183
|
a1, a2, a3, a4 = torch.randn(8 * n, 4).unsqueeze(-2).unbind(-1)
|
|
183
184
|
y = a1 * torch.sin(a2 * 2 * x * math.pi / m) + a3 * torch.cos(a4 * 2 * x * math.pi / m)
|
|
@@ -192,7 +193,7 @@ def _init_ahg_one(m: int, n: int):
|
|
|
192
193
|
ret[0] = y[i]
|
|
193
194
|
mask[i] = False
|
|
194
195
|
|
|
195
|
-
# simple maxmin
|
|
196
|
+
# simple maxmin implementation
|
|
196
197
|
for j in range(1, n):
|
|
197
198
|
mindist, _ = torch.cdist(ret[:j], y).min(dim=0)
|
|
198
199
|
maxidx = torch.argsort(mindist)[mask][-1]
|
aimnet/modules/core.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
from
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import Any
|
|
2
3
|
|
|
3
4
|
import torch
|
|
4
5
|
from torch import Tensor, nn
|
|
@@ -10,9 +11,9 @@ from aimnet.config import get_init_module, get_module
|
|
|
10
11
|
def MLP(
|
|
11
12
|
n_in: int,
|
|
12
13
|
n_out: int,
|
|
13
|
-
hidden:
|
|
14
|
+
hidden: list[int] | None = None,
|
|
14
15
|
activation_fn: Callable | str = "torch.nn.GELU",
|
|
15
|
-
activation_kwargs:
|
|
16
|
+
activation_kwargs: dict[str, Any] | None = None,
|
|
16
17
|
weight_init_fn: Callable | str = "torch.nn.init.xavier_normal_",
|
|
17
18
|
bias: bool = True,
|
|
18
19
|
last_linear: bool = True,
|
|
@@ -44,7 +45,7 @@ def MLP(
|
|
|
44
45
|
|
|
45
46
|
|
|
46
47
|
class Embedding(nn.Embedding):
|
|
47
|
-
def __init__(self, init:
|
|
48
|
+
def __init__(self, init: dict[int, Any] | None = None, **kwargs):
|
|
48
49
|
super().__init__(**kwargs)
|
|
49
50
|
with torch.no_grad():
|
|
50
51
|
if init is not None:
|
|
@@ -70,7 +71,7 @@ class DSequential(nn.Module):
|
|
|
70
71
|
super().__init__()
|
|
71
72
|
self.module = nn.ModuleList(modules)
|
|
72
73
|
|
|
73
|
-
def forward(self, data:
|
|
74
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
74
75
|
for m in self.module:
|
|
75
76
|
data = m(data)
|
|
76
77
|
return data
|
|
@@ -97,7 +98,7 @@ class AtomicShift(nn.Module):
|
|
|
97
98
|
def extra_repr(self) -> str:
|
|
98
99
|
return f"key_in: {self.key_in}, key_out: {self.key_out}"
|
|
99
100
|
|
|
100
|
-
def forward(self, data:
|
|
101
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
101
102
|
shifts = self.shifts(data["numbers"]).squeeze(-1)
|
|
102
103
|
if self.reduce_sum:
|
|
103
104
|
shifts = nbops.mol_sum(shifts, data)
|
|
@@ -114,13 +115,13 @@ class AtomicSum(nn.Module):
|
|
|
114
115
|
def extra_repr(self) -> str:
|
|
115
116
|
return f"key_in: {self.key_in}, key_out: {self.key_out}"
|
|
116
117
|
|
|
117
|
-
def forward(self, data:
|
|
118
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
118
119
|
data[self.key_out] = nbops.mol_sum(data[self.key_in], data)
|
|
119
120
|
return data
|
|
120
121
|
|
|
121
122
|
|
|
122
123
|
class Output(nn.Module):
|
|
123
|
-
def __init__(self, mlp:
|
|
124
|
+
def __init__(self, mlp: dict | nn.Module, n_in: int, n_out: int, key_in: str, key_out: str):
|
|
124
125
|
super().__init__()
|
|
125
126
|
self.key_in = key_in
|
|
126
127
|
self.key_out = key_out
|
|
@@ -131,7 +132,7 @@ class Output(nn.Module):
|
|
|
131
132
|
def extra_repr(self) -> str:
|
|
132
133
|
return f"key_in: {self.key_in}, key_out: {self.key_out}"
|
|
133
134
|
|
|
134
|
-
def forward(self, data:
|
|
135
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
135
136
|
v = self.mlp(data[self.key_in]).squeeze(-1)
|
|
136
137
|
if data["_input_padded"].item():
|
|
137
138
|
v = nbops.mask_i_(v, data, mask_value=0.0)
|
|
@@ -147,7 +148,7 @@ class Forces(nn.Module):
|
|
|
147
148
|
self.y = y
|
|
148
149
|
self.key_out = key_out
|
|
149
150
|
|
|
150
|
-
def forward(self, data:
|
|
151
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
151
152
|
prev = torch.is_grad_enabled()
|
|
152
153
|
torch.set_grad_enabled(True)
|
|
153
154
|
data[self.x].requires_grad_(True)
|
|
@@ -171,7 +172,7 @@ class Dipole(nn.Module):
|
|
|
171
172
|
def extra_repr(self) -> str:
|
|
172
173
|
return f"key_in: {self.key_in}, key_out: {self.key_out}, center_coord: {self.center_coord}"
|
|
173
174
|
|
|
174
|
-
def forward(self, data:
|
|
175
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
175
176
|
q = data[self.key_in]
|
|
176
177
|
r = data["coord"]
|
|
177
178
|
if self.center_coord:
|
|
@@ -184,7 +185,7 @@ class Quadrupole(Dipole):
|
|
|
184
185
|
def __init__(self, key_in: str = "charges", key_out: str = "quadrupole", center_coord: bool = False):
|
|
185
186
|
super().__init__(key_in=key_in, key_out=key_out, center_coord=center_coord)
|
|
186
187
|
|
|
187
|
-
def forward(self, data:
|
|
188
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
188
189
|
q = data[self.key_in]
|
|
189
190
|
r = data["coord"]
|
|
190
191
|
if self.center_coord:
|
|
@@ -215,7 +216,7 @@ class SRRep(nn.Module):
|
|
|
215
216
|
self.params = nn.Embedding(87, 2, padding_idx=0, _weight=weight)
|
|
216
217
|
self.params.weight.requires_grad_(False)
|
|
217
218
|
|
|
218
|
-
def forward(self, data:
|
|
219
|
+
def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
219
220
|
p = self.params(data["numbers"])
|
|
220
221
|
p_i, p_j = nbops.get_ij(p, data)
|
|
221
222
|
p_ij = p_i * p_j
|