aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +7 -0
- aimnet/base.py +24 -8
- aimnet/calculators/__init__.py +4 -4
- aimnet/calculators/aimnet2ase.py +19 -6
- aimnet/calculators/calculator.py +868 -108
- aimnet/calculators/model_registry.py +2 -5
- aimnet/calculators/model_registry.yaml +55 -17
- aimnet/cli.py +62 -6
- aimnet/config.py +8 -9
- aimnet/data/sgdataset.py +23 -22
- aimnet/kernels/__init__.py +66 -0
- aimnet/kernels/conv_sv_2d_sp_wp.py +478 -0
- aimnet/models/__init__.py +13 -1
- aimnet/models/aimnet2.py +19 -22
- aimnet/models/base.py +183 -15
- aimnet/models/convert.py +30 -0
- aimnet/models/utils.py +735 -0
- aimnet/modules/__init__.py +1 -1
- aimnet/modules/aev.py +49 -48
- aimnet/modules/core.py +14 -13
- aimnet/modules/lr.py +520 -115
- aimnet/modules/ops.py +537 -0
- aimnet/nbops.py +105 -15
- aimnet/ops.py +90 -18
- aimnet/train/export_model.py +226 -0
- aimnet/train/loss.py +7 -7
- aimnet/train/metrics.py +5 -6
- aimnet/train/train.py +4 -1
- aimnet/train/utils.py +42 -13
- aimnet-0.1.0.dist-info/METADATA +308 -0
- aimnet-0.1.0.dist-info/RECORD +43 -0
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info}/WHEEL +1 -1
- aimnet-0.1.0.dist-info/entry_points.txt +3 -0
- aimnet/calculators/nb_kernel_cpu.py +0 -222
- aimnet/calculators/nb_kernel_cuda.py +0 -217
- aimnet/calculators/nbmat.py +0 -220
- aimnet/train/pt2jpt.py +0 -81
- aimnet-0.0.1.dist-info/METADATA +0 -78
- aimnet-0.0.1.dist-info/RECORD +0 -41
- aimnet-0.0.1.dist-info/entry_points.txt +0 -5
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info/licenses}/LICENSE +0 -0
aimnet/nbops.py
CHANGED

@@ -1,10 +1,8 @@
-from typing import Dict, Tuple
-
 import torch
 from torch import Tensor


-def set_nb_mode(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+def set_nb_mode(data: dict[str, Tensor]) -> dict[str, Tensor]:
     """Logic to guess and set the neighbor model."""
     if "nbmat" in data:
         if data["nbmat"].ndim == 2:
@@ -18,12 +16,12 @@ def set_nb_mode(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
     return data


-def get_nb_mode(data: Dict[str, Tensor]) -> int:
+def get_nb_mode(data: dict[str, Tensor]) -> int:
     """Get the neighbor model."""
     return int(data["_nb_mode"].item())


-def calc_masks(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+def calc_masks(data: dict[str, Tensor]) -> dict[str, Tensor]:
     """Calculate neighbor masks"""
     nb_mode = get_nb_mode(data)
     if nb_mode == 0:
@@ -45,9 +43,20 @@ def calc_masks(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
         # padding must be the last atom
         data["mask_i"] = torch.zeros(data["numbers"].shape[0], device=data["numbers"].device, dtype=torch.bool)
         data["mask_i"][-1] = True
-
-
-
+        # Track processed arrays by their data pointer to avoid redundant mask calculations
+        processed: dict[int, str] = {}  # data_ptr -> mask_suffix
+        for suffix in ("", "_lr", "_coulomb", "_dftd3"):
+            nbmat_key = f"nbmat{suffix}"
+            if nbmat_key in data:
+                if not torch.jit.is_scripting():
+                    # data_ptr() not supported in TorchScript
+                    ptr = data[nbmat_key].data_ptr()
+                    if ptr in processed:
+                        # Same array - reuse existing mask
+                        data[f"mask_ij{suffix}"] = data[f"mask_ij{processed[ptr]}"]
+                        continue
+                    processed[ptr] = suffix
+                data[f"mask_ij{suffix}"] = data[nbmat_key] == data["numbers"].shape[0] - 1
         data["_input_padded"] = torch.tensor(True)
         data["mol_sizes"] = torch.bincount(data["mol_idx"])
         # last atom is padding
@@ -56,9 +65,20 @@ def calc_masks(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
         data["mask_i"] = data["numbers"] == 0
         w = torch.where(data["mask_i"])
         pad_idx = w[0] * data["numbers"].shape[1] + w[1]
-
-
-
+        # Track processed arrays by their data pointer to avoid redundant mask calculations
+        processed: dict[int, str] = {}  # data_ptr -> mask_suffix
+        for suffix in ("", "_lr", "_coulomb", "_dftd3"):
+            nbmat_key = f"nbmat{suffix}"
+            if nbmat_key in data:
+                if not torch.jit.is_scripting():
+                    # data_ptr() not supported in TorchScript
+                    ptr = data[nbmat_key].data_ptr()
+                    if ptr in processed:
+                        # Same array - reuse existing mask
+                        data[f"mask_ij{suffix}"] = data[f"mask_ij{processed[ptr]}"]
+                        continue
+                    processed[ptr] = suffix
+                data[f"mask_ij{suffix}"] = torch.isin(data[nbmat_key], pad_idx)
         data["_input_padded"] = torch.tensor(True)
         data["mol_sizes"] = (~data["mask_i"]).sum(-1)
     else:
@@ -69,7 +89,7 @@ def calc_masks(data: Dict[str, Tensor]) -> Dict[str, Tensor]:

 def mask_ij_(
     x: Tensor,
-    data: Dict[str, Tensor],
+    data: dict[str, Tensor],
     mask_value: float = 0.0,
     inplace: bool = True,
     suffix: str = "",
@@ -84,7 +104,7 @@ def mask_ij_(
     return x


-def mask_i_(x: Tensor, data: Dict[str, Tensor], mask_value: float = 0.0, inplace: bool = True) -> Tensor:
+def mask_i_(x: Tensor, data: dict[str, Tensor], mask_value: float = 0.0, inplace: bool = True) -> Tensor:
     nb_mode = get_nb_mode(data)
     if nb_mode == 0:
         if data["_input_padded"].item():
@@ -110,7 +130,47 @@ def mask_i_(x: Tensor, data: Dict[str, Tensor], mask_value: float = 0.0, inplace
     return x


-def get_ij(x: Tensor, data: Dict[str, Tensor], suffix: str = "") -> Tuple[Tensor, Tensor]:
+def resolve_suffix(data: dict[str, Tensor], suffixes: list[str]) -> str:
+    """Try suffixes in order, return first found, raise if none exist.
+
+    This function makes fallback behavior explicit by requiring a list
+    of acceptable suffixes. Each module controls which neighbor lists
+    are acceptable for its operations.
+
+    For nb_mode=0 (no neighbor matrix), returns empty string since
+    neighbor lists are not used in that mode.
+
+    Parameters
+    ----------
+    data : dict
+        Data dictionary containing neighbor matrices.
+    suffixes : list[str]
+        List of suffixes to try in priority order (e.g., ["_dftd3", "_lr"]).
+        Empty string "" can be included for fallback to base nbmat.
+
+    Returns
+    -------
+    str
+        The first suffix that has a corresponding nbmat{suffix} in data.
+
+    Raises
+    ------
+    KeyError
+        If none of the suffixes have corresponding neighbor matrices.
+    """
+    # In nb_mode=0, there are no neighbor matrices - suffix is unused
+    nb_mode = get_nb_mode(data)
+    if nb_mode == 0:
+        return ""
+
+    for suffix in suffixes:
+        if f"nbmat{suffix}" in data:
+            return suffix
+
+    raise KeyError(f"No neighbor matrix found for any suffix in {suffixes}")
+
+
+def get_ij(x: Tensor, data: dict[str, Tensor], suffix: str = "") -> tuple[Tensor, Tensor]:
     nb_mode = get_nb_mode(data)
     if nb_mode == 0:
         x_i = x.unsqueeze(2)
@@ -128,7 +188,36 @@ def get_ij(x: Tensor, data: Dict[str, Tensor], suffix: str = "") -> Tuple[Tensor
     return x_i, x_j


-def mol_sum(x: Tensor, data: Dict[str, Tensor]) -> Tensor:
+def get_i(x: Tensor, data: dict[str, Tensor]) -> Tensor:
+    """Get the i-component of pairwise expansion without computing j.
+
+    This is an optimized version of get_ij when only x_i is needed,
+    avoiding the expensive index_select operation for x_j.
+
+    Parameters
+    ----------
+    x : Tensor
+        Input tensor to expand.
+    data : dict[str, Tensor]
+        Data dictionary containing neighbor mode information.
+
+    Returns
+    -------
+    Tensor
+        The i-component with appropriate unsqueeze for the neighbor mode.
+    """
+    nb_mode = get_nb_mode(data)
+    if nb_mode == 0:
+        return x.unsqueeze(2)
+    elif nb_mode == 1:
+        return x.unsqueeze(1)
+    elif nb_mode == 2:
+        return x.unsqueeze(2)
+    else:
+        raise ValueError(f"Invalid neighbor mode: {nb_mode}")
+
+
+def mol_sum(x: Tensor, data: dict[str, Tensor]) -> Tensor:
     nb_mode = get_nb_mode(data)
     if nb_mode in (0, 2):
         res = x.sum(dim=1)
@@ -140,6 +229,7 @@ def mol_sum(x: Tensor, data: Dict[str, Tensor]) -> Tensor:
         idx = data["mol_idx"]
         # assuming mol_idx is sorted, replace with max if not
         out_size = int(idx[-1].item()) + 1
+
         if x.ndim == 1:
             res = torch.zeros(out_size, device=x.device, dtype=x.dtype)
         else:
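
The nbops.py changes above add two helpers, resolve_suffix() and get_i(). The short sketch below is not part of the package and uses a made-up data dict; it only illustrates how the new 0.1.0 helpers would be called: resolve_suffix() returns the first suffix whose nbmat{suffix} exists, and get_i() expands a per-atom tensor for the active neighbor mode.

import torch
from aimnet import nbops

data = {
    "_nb_mode": torch.tensor(1),  # flat neighbor-matrix mode
    "nbmat_lr": torch.zeros(5, 8, dtype=torch.long),  # dummy long-range neighbor matrix
}
# no "nbmat_coulomb" in data, so the lookup falls back to "_lr"
suffix = nbops.resolve_suffix(data, ["_coulomb", "_lr"])  # -> "_lr"

charges = torch.randn(5, 1)
q_i = nbops.get_i(charges, data)  # unsqueeze(1) for nb_mode == 1 -> shape (5, 1, 1)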
aimnet/ops.py
CHANGED

@@ -1,5 +1,4 @@
 import math
-from typing import Dict, Optional, Tuple

 import torch
 from torch import Tensor
@@ -7,7 +6,7 @@ from torch import Tensor
 from aimnet import nbops


-def lazy_calc_dij_lr(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+def lazy_calc_dij_lr(data: dict[str, Tensor]) -> dict[str, Tensor]:
     if "d_ij_lr" not in data:
         nb_mode = nbops.get_nb_mode(data)
         if nb_mode == 0:
@@ -17,23 +16,67 @@ def lazy_calc_dij_lr(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
     return data


-def calc_distances(data: Dict[str, Tensor], suffix: str = "", pad_value: float = 1.0) -> Tuple[Tensor, Tensor]:
+def lazy_calc_dij(data: dict[str, Tensor], suffix: str) -> dict[str, Tensor]:
+    """Lazily calculate distances for a given suffix.
+
+    Computes and caches d_ij{suffix} in data dict if not present.
+    For nb_mode=0 (no neighbor list), reuses d_ij.
+
+    Parameters
+    ----------
+    data : dict
+        Data dictionary.
+    suffix : str
+        Suffix for neighbor matrix (e.g., "_coulomb", "_dftd3", "_lr").
+
+    Returns
+    -------
+    dict
+        Data dictionary with d_ij{suffix} added.
+    """
+    key = f"d_ij{suffix}"
+    if key not in data:
+        nb_mode = nbops.get_nb_mode(data)
+        if nb_mode == 0:
+            data[key] = data["d_ij"]
+        else:
+            data[key] = calc_distances(data, suffix=suffix)[0]
+    return data
+
+
+def calc_distances(data: dict[str, Tensor], suffix: str = "", pad_value: float = 1.0) -> tuple[Tensor, Tensor]:
     coord_i, coord_j = nbops.get_ij(data["coord"], data, suffix)
     if f"shifts{suffix}" in data:
         assert "cell" in data, "cell is required if shifts are provided"
         nb_mode = nbops.get_nb_mode(data)
+        cell = data["cell"]
         if nb_mode == 2:
-            shifts
+            # Batched format: shifts (B, N, M, 3), cell (B, 3, 3) or (3, 3)
+            if cell.ndim == 2:
+                shifts = torch.einsum("bnmd,dh->bnmh", data[f"shifts{suffix}"], cell)
+            else:
+                shifts = torch.einsum("bnmd,bdh->bnmh", data[f"shifts{suffix}"], cell)
+        elif nb_mode == 1:
+            # Flat format: shifts (N_total, M, 3), cell (3, 3) or (B, 3, 3)
+            if cell.ndim == 2:
+                shifts = data[f"shifts{suffix}"] @ cell
+            else:
+                # Batched cells - need mol_idx to select correct cell for each atom
+                mol_idx = data["mol_idx"]
+                atom_cell = cell[mol_idx]  # (N_total, 3, 3)
+                # shifts: (N_total, M, 3), atom_cell: (N_total, 3, 3)
+                shifts = torch.einsum("nmd,ndh->nmh", data[f"shifts{suffix}"], atom_cell)
         else:
-
+            # nb_mode == 0: no neighbor matrix, shouldn't have shifts
+            shifts = data[f"shifts{suffix}"] @ cell
         coord_j = coord_j + shifts
     r_ij = coord_j - coord_i
+    r_ij = nbops.mask_ij_(r_ij, data, mask_value=pad_value, inplace=False, suffix=suffix)
     d_ij = torch.norm(r_ij, p=2, dim=-1)
-    d_ij = nbops.mask_ij_(d_ij, data, mask_value=pad_value, inplace=False, suffix=suffix)
     return d_ij, r_ij


-def center_coordinates(coord: Tensor, data: Dict[str, Tensor], masses: Optional[Tensor] = None) -> Tensor:
+def center_coordinates(coord: Tensor, data: dict[str, Tensor], masses: Tensor | None = None) -> Tensor:
     if masses is not None:
         masses = masses.unsqueeze(-1)
         center = nbops.mol_sum(coord * masses, data) / nbops.mol_sum(masses, data) / data["mol_sizes"].unsqueeze(-1)
@@ -61,16 +104,17 @@ def exp_expand(d_ij: Tensor, shifts: Tensor, eta: float) -> Tensor:
     return torch.exp(-eta * (d_ij.unsqueeze(-1) - shifts) ** 2)


-# pylint: disable=invalid-name
 def nse(
     Q: Tensor,
     q_u: Tensor,
     f_u: Tensor,
-    data: Dict[str, Tensor],
+    data: dict[str, Tensor],
     epsilon: float = 1.0e-6,
 ) -> Tensor:
     # Q and q_u and f_u must have last dimension size 1 or 2
-    F_u = nbops.mol_sum(f_u, data)
+    F_u = nbops.mol_sum(f_u, data)
+    if epsilon > 0:
+        F_u = F_u + epsilon
     Q_u = nbops.mol_sum(q_u, data)
     dQ = Q - Q_u
     # for loss
@@ -92,30 +136,36 @@ def nse(
     return q


-def coulomb_matrix_dsf(d_ij: Tensor, Rc: float, alpha: float, data: Dict[str, Tensor]) -> Tensor:
+def coulomb_matrix_dsf(d_ij: Tensor, Rc: float, alpha: float, data: dict[str, Tensor]) -> Tensor:
     _c1 = (alpha * d_ij).erfc() / d_ij
     _c2 = math.erfc(alpha * Rc) / Rc
     _c3 = _c2 / Rc
     _c4 = 2 * alpha * math.exp(-((alpha * Rc) ** 2)) / (Rc * math.pi**0.5)
     J = _c1 - _c2 + (d_ij - Rc) * (_c3 + _c4)
-    #
-    mask = data["mask_ij_lr"]
+    # Zero invalid pairs: padding/diagonal (mask_ij_lr) OR beyond cutoff
+    mask = data["mask_ij_lr"] | (d_ij > Rc)
     J.masked_fill_(mask, 0.0)
     return J


-def coulomb_matrix_sf(q_j: Tensor, d_ij: Tensor, Rc: float, data: Dict[str, Tensor]) -> Tensor:
+def coulomb_matrix_sf(q_j: Tensor, d_ij: Tensor, Rc: float, data: dict[str, Tensor]) -> Tensor:
     _c1 = 1.0 / d_ij
     _c2 = 1.0 / Rc
     _c3 = _c2 / Rc
     J = _c1 - _c2 + (d_ij - Rc) * _c3
-
+    # Zero invalid pairs: padding/diagonal (mask_ij_lr) OR beyond cutoff
+    mask = data["mask_ij_lr"] | (d_ij > Rc)
     J.masked_fill_(mask, 0.0)
     return J


 def get_shifts_within_cutoff(cell: Tensor, cutoff: Tensor) -> Tensor:
-
+    """Get all lattice shift vectors within cutoff distance.
+
+    Note: Batched cells are not supported - this function is only used by Ewald summation
+    which is a single-molecule calculation.
+    """
+    assert cell.ndim == 2 and cell.shape == (3, 3), "Batched cells not supported for Ewald summation"
     cell_inv = torch.inverse(cell).mT
     inv_distances = cell_inv.norm(p=2, dim=-1)
     num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
@@ -128,10 +178,32 @@ def get_shifts_within_cutoff(cell: Tensor, cutoff: Tensor) -> Tensor:
     return shifts


-def coulomb_matrix_ewald(coord: Tensor, cell: Tensor) -> Tensor:
+def coulomb_matrix_ewald(coord: Tensor, cell: Tensor, accuracy: float = 1e-8) -> Tensor:
+    """Compute Coulomb matrix using Ewald summation.
+
+    Parameters
+    ----------
+    coord : Tensor
+        Atomic coordinates, shape (N, 3).
+    cell : Tensor
+        Unit cell vectors, shape (3, 3).
+    accuracy : float
+        Target accuracy for the Ewald summation. Controls the real-space
+        and reciprocal-space cutoffs. Lower values give higher accuracy
+        but require more computation. Default is 1e-8.
+
+        The cutoffs are computed as:
+        - eta = (V^2 / N)^(1/6) / sqrt(2*pi)
+        - cutoff_real = sqrt(-2 * ln(accuracy)) * eta
+        - cutoff_recip = sqrt(-2 * ln(accuracy)) / eta
+
+    Returns
+    -------
+    Tensor
+        Coulomb matrix J, shape (N, N).
+    """
     # single molecule implementation. nb_mode == 1
     assert coord.ndim == 2 and cell.ndim == 2, "Only single molecule is supported"
-    accuracy = 1e-8
     N = coord.shape[0]
     volume = torch.det(cell)
     eta = ((volume**2 / N) ** (1 / 6)) / math.sqrt(2.0 * math.pi)
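
The coulomb_matrix_ewald() docstring added above spells out how the new accuracy argument sets the Ewald cutoffs. The standalone sketch below only restates that arithmetic with made-up volume and atom-count values; it is not package code.

import math

def ewald_cutoffs(volume: float, n_atoms: int, accuracy: float = 1e-8) -> tuple[float, float, float]:
    # eta = (V^2 / N)^(1/6) / sqrt(2*pi), as in the docstring above
    eta = (volume**2 / n_atoms) ** (1 / 6) / math.sqrt(2.0 * math.pi)
    cutoff_real = math.sqrt(-2.0 * math.log(accuracy)) * eta
    cutoff_recip = math.sqrt(-2.0 * math.log(accuracy)) / eta
    return eta, cutoff_real, cutoff_recip

print(ewald_cutoffs(volume=1000.0, n_atoms=100))
print(ewald_cutoffs(volume=1000.0, n_atoms=100, accuracy=1e-4))  # looser accuracy -> smaller cutoffs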
aimnet/train/export_model.py
ADDED

@@ -0,0 +1,226 @@
+#!/usr/bin/env python3
+"""Export trained model to distributable state dict format.
+
+This script creates a self-contained .pt file from training artifacts:
+- Raw PyTorch weights (.pt)
+- Self-atomic energies (.sae)
+- Model YAML configuration
+
+The output file contains:
+- model_yaml: Core model config (without LRCoulomb/DFTD3, with SRCoulomb if needed)
+- cutoff: Model cutoff
+- needs_coulomb: Whether calculator should add external Coulomb
+- needs_dispersion: Whether calculator should add external DFTD3
+- coulomb_mode: "sr_embedded" | "none" (describes what's in the model)
+- coulomb_sr_rc: Coulomb short-range cutoff (optional, if coulomb_mode="sr_embedded")
+- coulomb_sr_envelope: Envelope function ("exp" or "cosine", optional)
+- d3_params: D3 parameters {s8, a1, a2, s6} (optional, if needs_dispersion=True)
+- has_embedded_lr: Whether model has embedded LR modules (D3TS, SRCoulomb) needing nbmat_lr
+- implemented_species: Parametrized atomic numbers
+- state_dict: Model weights with SAE baked into atomic_shift (float64)
+"""
+
+import copy
+
+import click
+import torch
+import yaml
+from torch import nn
+
+from aimnet.config import build_module, load_yaml
+from aimnet.models.utils import strip_lr_modules_from_yaml, validate_state_dict_keys
+
+
+def load_sae(sae_file: str) -> dict[int, float]:
+    """Load SAE file (YAML-like format: atomic_number: energy)."""
+    sae = load_yaml(sae_file)
+    if not isinstance(sae, dict):
+        raise TypeError("SAE file must contain a dictionary.")
+    return {int(k): float(v) for k, v in sae.items()}
+
+
+def bake_sae_into_model(model: nn.Module, sae: dict[int, float]) -> nn.Module:
+    """Add SAE values to atomic_shift.shifts.weight (converted to float64)."""
+    # Disable gradients before in-place operation
+    for p in model.parameters():
+        p.requires_grad_(False)
+    model.outputs.atomic_shift.double()  # type: ignore
+    for k, v in sae.items():
+        model.outputs.atomic_shift.shifts.weight[k] += v  # type: ignore
+    return model
+
+
+def extract_cutoff(model: nn.Module) -> float:
+    """Extract cutoff from model's AEV module."""
+    return float(model.aev.rc_s.item())  # type: ignore
+
+
+def get_implemented_species(sae: dict[int, float]) -> list[int]:
+    """Get list of implemented species from SAE."""
+    return sorted(sae.keys())
+
+
+def mask_not_implemented_species(model: nn.Module, species: list[int]) -> nn.Module:
+    """Set NaN for species not in the SAE."""
+    weight = model.afv.weight  # type: ignore
+    for i in range(1, weight.shape[0]):  # type: ignore
+        if i not in species:
+            weight[i, :] = torch.nan  # type: ignore
+    return model
+
+
+@click.command()
+@click.argument("weights", type=click.Path(exists=True))
+@click.argument("output", type=str)
+@click.option("--model", "-m", type=click.Path(exists=True), required=True, help="Path to model definition YAML file")
+@click.option("--sae", "-s", type=click.Path(exists=True), required=True, help="Path to the SAE YAML file")
+@click.option(
+    "--needs-coulomb/--no-coulomb", default=None, help="Override Coulomb detection. Default: auto-detect from YAML"
+)
+@click.option(
+    "--needs-dispersion/--no-dispersion",
+    default=None,
+    help="Override dispersion detection. Default: auto-detect from YAML",
+)
+def export_model(
+    weights: str,
+    output: str,
+    model: str,
+    sae: str,
+    needs_coulomb: bool | None,
+    needs_dispersion: bool | None,
+):
+    """Export trained model to distributable state dict format.
+
+    weights: Path to the raw PyTorch weights file (.pt).
+    outoput: Path to the output .pt file.
+
+    Example:
+        aimnet export weights.pt model.pt --model config.yaml --sae model.sae
+    """
+    # Load model YAML
+    print(f"Loading config from {model}")
+    with open(model, encoding="utf-8") as f:
+        model_config = yaml.safe_load(f)
+
+    # Load SAE
+    print(f"Loading SAE from {sae}")
+    sae_dict = load_sae(sae)
+    implemented_species = get_implemented_species(sae_dict)
+
+    # Load source state dict
+    print(f"Loading weights from {weights}")
+    source_sd = torch.load(weights, map_location="cpu", weights_only=True)
+
+    # Strip LR modules and detect flags
+    core_config, coulomb_mode, needs_dispersion_auto, d3_params, coulomb_sr_rc, coulomb_sr_envelope, disp_ptfile = (
+        strip_lr_modules_from_yaml(model_config, source_sd)
+    )
+
+    # Serialize YAML BEFORE building module (build_module mutates the dict)
+    core_yaml_str = yaml.dump(core_config, default_flow_style=False, sort_keys=False)
+
+    # Build model from modified config
+    print("Building model...")
+    core_model = build_module(copy.deepcopy(core_config))
+    if not isinstance(core_model, nn.Module):
+        raise TypeError("Built module is not an nn.Module")
+
+    # Load weights with strict=False (modules may differ)
+    load_result = core_model.load_state_dict(source_sd, strict=False)
+
+    # Check for unexpected missing/extra keys
+    real_missing, real_unexpected = validate_state_dict_keys(load_result.missing_keys, load_result.unexpected_keys)
+    if real_missing:
+        print(f"WARNING: Unexpected missing keys: {real_missing}")
+    if real_unexpected:
+        print(f"WARNING: Unexpected extra keys in source: {real_unexpected}")
+    if not real_missing and not real_unexpected:
+        print("Loaded weights successfully")
+
+    # Load dispersion parameters from ptfile and inject into model
+    # (raw training weights don't contain disp_param0 buffer)
+    if disp_ptfile is not None:
+        disp_params = torch.load(disp_ptfile, map_location="cpu", weights_only=True)
+        for _name, module in core_model.named_modules():
+            if hasattr(module, "disp_param0"):
+                # Resize buffer if needed (ptfile may have different shape than placeholder)
+                if module.disp_param0.shape != disp_params.shape:
+                    module.disp_param0 = torch.zeros_like(disp_params)
+                module.disp_param0.copy_(disp_params)
+                print(f"Loaded disp_param0 from {disp_ptfile}")
+                break
+
+    # Bake SAE into atomic_shift (float64)
+    print("Baking SAE into atomic_shift...")
+    core_model = bake_sae_into_model(core_model, sae_dict)
+
+    # Mask not-implemented species
+    core_model = mask_not_implemented_species(core_model, implemented_species)
+
+    # Extract cutoff
+    cutoff = extract_cutoff(core_model)
+
+    # Set model to eval mode
+    core_model.eval()
+
+    # Determine final flags (CLI overrides auto-detection)
+    auto_needs_coulomb = coulomb_mode == "sr_embedded"
+    auto_needs_dispersion = needs_dispersion_auto
+
+    final_needs_coulomb = needs_coulomb if needs_coulomb is not None else auto_needs_coulomb
+    final_needs_dispersion = needs_dispersion if needs_dispersion is not None else auto_needs_dispersion
+
+    # Warn if overriding auto-detection
+    if needs_coulomb is not None and needs_coulomb != auto_needs_coulomb:
+        print(f" Overriding needs_coulomb: {auto_needs_coulomb} -> {needs_coulomb}")
+    if needs_dispersion is not None and needs_dispersion != auto_needs_dispersion:
+        print(f" Overriding needs_dispersion: {auto_needs_dispersion} -> {needs_dispersion}")
+
+    # Detect if model has any embedded LR modules that need nbmat_lr
+    outputs = model_config.get("kwargs", {}).get("outputs", {})
+    has_embedded_lr = False
+
+    # Check for embedded D3TS (uses NN-predicted C6/alpha, must stay embedded)
+    has_d3ts = any("D3TS" in outputs.get(k, {}).get("class", "") for k in ["dftd3", "d3bj", "d3ts"])
+    if has_d3ts:
+        has_embedded_lr = True
+
+    # Check for embedded SRCoulomb (model had LRCoulomb before conversion)
+    if coulomb_mode == "sr_embedded":
+        has_embedded_lr = True
+
+    # Create new format dict
+    new_format = {
+        "format_version": 2,  # v2 = new .pt format (v1 = legacy .jpt)
+        "model_yaml": core_yaml_str,
+        "cutoff": cutoff,
+        "needs_coulomb": final_needs_coulomb,
+        "needs_dispersion": final_needs_dispersion,
+        "coulomb_mode": coulomb_mode,
+        "coulomb_sr_rc": coulomb_sr_rc if final_needs_coulomb else None,
+        "coulomb_sr_envelope": coulomb_sr_envelope if final_needs_coulomb else None,
+        "d3_params": d3_params if final_needs_dispersion else None,
+        "has_embedded_lr": has_embedded_lr,
+        "implemented_species": implemented_species,
+        "state_dict": core_model.state_dict(),
+    }
+
+    # Save
+    torch.save(new_format, output)
+    print(f"\nSaved model to {output}")
+    print(f" cutoff: {cutoff}")
+    print(f" needs_coulomb: {final_needs_coulomb}")
+    print(f" needs_dispersion: {final_needs_dispersion}")
+    print(f" coulomb_mode: {coulomb_mode}")
+    if final_needs_coulomb:
+        print(f" coulomb_sr_rc: {coulomb_sr_rc}")
+        print(f" coulomb_sr_envelope: {coulomb_sr_envelope}")
+    if final_needs_dispersion:
+        print(f" d3_params: {d3_params}")
+    print(f" has_embedded_lr: {has_embedded_lr}")
+    print(f" implemented_species: {implemented_species}")
+
+
+if __name__ == "__main__":
+    export_model()
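
Per the docstring of the new export_model command above, a checkpoint is produced with "aimnet export weights.pt model.pt --model config.yaml --sae model.sae". The sketch below is not package code (the file name model.pt is taken from that example); it just inspects the resulting v2 file using the keys written into new_format above.

import torch

ckpt = torch.load("model.pt", map_location="cpu")
print(ckpt["format_version"])  # 2 (v1 was the legacy .jpt format)
print(ckpt["cutoff"], ckpt["needs_coulomb"], ckpt["needs_dispersion"], ckpt["coulomb_mode"])
print(ckpt["implemented_species"])  # atomic numbers covered by the SAE
# Rebuilding a model mirrors the export: build from ckpt["model_yaml"], then load ckpt["state_dict"].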
aimnet/train/loss.py
CHANGED

@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any
+from typing import Any

 import torch
 from torch import Tensor
@@ -30,7 +30,7 @@ class MTLoss:
         Dict[str, Tensor]: total loss under key 'loss' and values for individual components.
     """

-    def __init__(self, components: Dict[str, Any]):
+    def __init__(self, components: dict[str, Any]):
         w_sum = sum(c["weight"] for c in components.values())
         self.components = {}
         for name, c in components.items():
@@ -38,7 +38,7 @@ class MTLoss:
             fn = partial(get_module(c["fn"]), **kwargs)
             self.components[name] = (fn, c["weight"] / w_sum)

-    def __call__(self, y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def __call__(self, y_pred: dict[str, Tensor], y_true: dict[str, Tensor]) -> dict[str, Tensor]:
         loss = {}
         for name, (fn, w) in self.components.items():
             _l = fn(y_pred=y_pred, y_true=y_true)
@@ -48,7 +48,7 @@ class MTLoss:
         return loss


-def mse_loss_fn(y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred: str, key_true: str) -> Tensor:
+def mse_loss_fn(y_pred: dict[str, Tensor], y_true: dict[str, Tensor], key_pred: str, key_true: str) -> Tensor:
     """General MSE loss function"""
     x = y_true[key_true]
     y = y_pred[key_pred]
@@ -56,7 +56,7 @@ def mse_loss_fn(y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred:
     return loss


-def peratom_loss_fn(y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred: str, key_true: str) -> Tensor:
+def peratom_loss_fn(y_pred: dict[str, Tensor], y_true: dict[str, Tensor], key_pred: str, key_true: str) -> Tensor:
     """MSE loss function with per-atom normalization correction.
     Suitable when some of the values are zero both in y_pred and y_true due to padding of inputs.
     """
@@ -73,11 +73,11 @@ def peratom_loss_fn(y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pr


 def energy_loss_fn(
-    y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred: str = "energy", key_true: str = "energy"
+    y_pred: dict[str, Tensor], y_true: dict[str, Tensor], key_pred: str = "energy", key_true: str = "energy"
 ) -> Tensor:
     """MSE loss normalized by the number of atoms."""
     x = y_true[key_true]
     y = y_pred[key_pred]
-    s = y_pred["_natom"].
+    s = y_pred["_natom"] ** 0.5
     loss = ((x - y).pow(2) / s).mean() if y_pred["_natom"].numel() > 1 else torch.nn.functional.mse_loss(x, y) / s
     return loss