aimnet 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +0 -0
- aimnet/base.py +41 -0
- aimnet/calculators/__init__.py +15 -0
- aimnet/calculators/aimnet2ase.py +98 -0
- aimnet/calculators/aimnet2pysis.py +76 -0
- aimnet/calculators/calculator.py +320 -0
- aimnet/calculators/model_registry.py +60 -0
- aimnet/calculators/model_registry.yaml +33 -0
- aimnet/calculators/nb_kernel_cpu.py +222 -0
- aimnet/calculators/nb_kernel_cuda.py +217 -0
- aimnet/calculators/nbmat.py +220 -0
- aimnet/cli.py +22 -0
- aimnet/config.py +170 -0
- aimnet/constants.py +467 -0
- aimnet/data/__init__.py +1 -0
- aimnet/data/sgdataset.py +517 -0
- aimnet/dftd3_data.pt +0 -0
- aimnet/models/__init__.py +2 -0
- aimnet/models/aimnet2.py +188 -0
- aimnet/models/aimnet2.yaml +44 -0
- aimnet/models/aimnet2_dftd3_wb97m.yaml +51 -0
- aimnet/models/base.py +51 -0
- aimnet/modules/__init__.py +3 -0
- aimnet/modules/aev.py +201 -0
- aimnet/modules/core.py +237 -0
- aimnet/modules/lr.py +243 -0
- aimnet/nbops.py +151 -0
- aimnet/ops.py +208 -0
- aimnet/train/__init__.py +0 -0
- aimnet/train/calc_sae.py +43 -0
- aimnet/train/default_train.yaml +166 -0
- aimnet/train/loss.py +83 -0
- aimnet/train/metrics.py +188 -0
- aimnet/train/pt2jpt.py +81 -0
- aimnet/train/train.py +155 -0
- aimnet/train/utils.py +398 -0
- aimnet-0.0.1.dist-info/LICENSE +21 -0
- aimnet-0.0.1.dist-info/METADATA +78 -0
- aimnet-0.0.1.dist-info/RECORD +41 -0
- aimnet-0.0.1.dist-info/WHEEL +4 -0
- aimnet-0.0.1.dist-info/entry_points.txt +5 -0
aimnet/ops.py
ADDED
@@ -0,0 +1,208 @@
import math
from typing import Dict, Optional, Tuple

import torch
from torch import Tensor

from aimnet import nbops


def lazy_calc_dij_lr(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
    if "d_ij_lr" not in data:
        nb_mode = nbops.get_nb_mode(data)
        if nb_mode == 0:
            data["d_ij_lr"] = data["d_ij"]
        else:
            data["d_ij_lr"] = calc_distances(data, suffix="_lr")[0]
    return data


def calc_distances(data: Dict[str, Tensor], suffix: str = "", pad_value: float = 1.0) -> Tuple[Tensor, Tensor]:
    coord_i, coord_j = nbops.get_ij(data["coord"], data, suffix)
    if f"shifts{suffix}" in data:
        assert "cell" in data, "cell is required if shifts are provided"
        nb_mode = nbops.get_nb_mode(data)
        if nb_mode == 2:
            shifts = torch.einsum("bnmd,bdh->bnmh", data[f"shifts{suffix}"], data["cell"])
        else:
            shifts = data[f"shifts{suffix}"] @ data["cell"]
        coord_j = coord_j + shifts
    r_ij = coord_j - coord_i
    d_ij = torch.norm(r_ij, p=2, dim=-1)
    d_ij = nbops.mask_ij_(d_ij, data, mask_value=pad_value, inplace=False, suffix=suffix)
    return d_ij, r_ij


def center_coordinates(coord: Tensor, data: Dict[str, Tensor], masses: Optional[Tensor] = None) -> Tensor:
    if masses is not None:
        masses = masses.unsqueeze(-1)
        center = nbops.mol_sum(coord * masses, data) / nbops.mol_sum(masses, data) / data["mol_sizes"].unsqueeze(-1)
    else:
        center = nbops.mol_sum(coord, data) / data["mol_sizes"]
    nb_mode = nbops.get_nb_mode(data)
    if nb_mode in (0, 2):
        center = center.unsqueeze(-2)
    coord = coord - center
    return coord


def cosine_cutoff(d_ij: Tensor, rc: float) -> Tensor:
    fc = 0.5 * (torch.cos(d_ij.clamp(min=1e-6, max=rc) * (math.pi / rc)) + 1.0)
    return fc


def exp_cutoff(d: Tensor, rc: Tensor) -> Tensor:
    fc = torch.exp(-1.0 / (1.0 - (d / rc).clamp(0, 1.0 - 1e-6).pow(2))) / 0.36787944117144233
    return fc


def exp_expand(d_ij: Tensor, shifts: Tensor, eta: float) -> Tensor:
    # expand on axis -1, e.g. (b, n, m) -> (b, n, m, shifts)
    return torch.exp(-eta * (d_ij.unsqueeze(-1) - shifts) ** 2)


# pylint: disable=invalid-name
def nse(
    Q: Tensor,
    q_u: Tensor,
    f_u: Tensor,
    data: Dict[str, Tensor],
    epsilon: float = 1.0e-6,
) -> Tensor:
    # Q and q_u and f_u must have last dimension size 1 or 2
    F_u = nbops.mol_sum(f_u, data) + epsilon
    Q_u = nbops.mol_sum(q_u, data)
    dQ = Q - Q_u
    # for loss
    data["_dQ"] = dQ

    nb_mode = nbops.get_nb_mode(data)
    if nb_mode in (0, 2):
        F_u = F_u.unsqueeze(-2)
        dQ = dQ.unsqueeze(-2)
    elif nb_mode == 1:
        data["mol_sizes"][-1] += 1
        F_u = torch.repeat_interleave(F_u, data["mol_sizes"], dim=0)
        dQ = torch.repeat_interleave(dQ, data["mol_sizes"], dim=0)
        data["mol_sizes"][-1] -= 1
    else:
        raise ValueError(f"Invalid neighbor mode: {nb_mode}")
    f = f_u / F_u
    q = q_u + f * dQ
    return q


def coulomb_matrix_dsf(d_ij: Tensor, Rc: float, alpha: float, data: Dict[str, Tensor]) -> Tensor:
    _c1 = (alpha * d_ij).erfc() / d_ij
    _c2 = math.erfc(alpha * Rc) / Rc
    _c3 = _c2 / Rc
    _c4 = 2 * alpha * math.exp(-((alpha * Rc) ** 2)) / (Rc * math.pi**0.5)
    J = _c1 - _c2 + (d_ij - Rc) * (_c3 + _c4)
    # mask for d_ij > Rc
    mask = data["mask_ij_lr"] & (d_ij > Rc)
    J.masked_fill_(mask, 0.0)
    return J


def coulomb_matrix_sf(q_j: Tensor, d_ij: Tensor, Rc: float, data: Dict[str, Tensor]) -> Tensor:
    _c1 = 1.0 / d_ij
    _c2 = 1.0 / Rc
    _c3 = _c2 / Rc
    J = _c1 - _c2 + (d_ij - Rc) * _c3
    mask = data["mask_ij_lr"] & (d_ij > Rc)
    J.masked_fill_(mask, 0.0)
    return J


def get_shifts_within_cutoff(cell: Tensor, cutoff: Tensor) -> Tensor:
    assert cell.shape == (3, 3), "Batch cell is not supported"
    cell_inv = torch.inverse(cell).mT
    inv_distances = cell_inv.norm(p=2, dim=-1)
    num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
    device = cell.device
    shifts = torch.cartesian_prod(
        torch.arange(-num_repeats[0], num_repeats[0] + 1, device=device),  # type: ignore[attr-defined]
        torch.arange(-num_repeats[1], num_repeats[1] + 1, device=device),  # type: ignore[attr-defined]
        torch.arange(-num_repeats[2], num_repeats[2] + 1, device=device),  # type: ignore[attr-defined]
    ).to(torch.float)
    return shifts


def coulomb_matrix_ewald(coord: Tensor, cell: Tensor) -> Tensor:
    # single molecule implementation. nb_mode == 1
    assert coord.ndim == 2 and cell.ndim == 2, "Only single molecule is supported"
    accuracy = 1e-8
    N = coord.shape[0]
    volume = torch.det(cell)
    eta = ((volume**2 / N) ** (1 / 6)) / math.sqrt(2.0 * math.pi)
    cutoff_real = math.sqrt(-2.0 * math.log(accuracy)) * eta
    cutoff_recip = math.sqrt(-2.0 * math.log(accuracy)) / eta

    # real space
    _grad_mode = torch.is_grad_enabled()
    torch.set_grad_enabled(False)
    shifts = get_shifts_within_cutoff(cell, cutoff_real)  # (num_shifts, 3)
    torch.set_grad_enabled(_grad_mode)
    disps_ij = coord[None, :, :] - coord[:, None, :]
    disps = disps_ij[None, :, :, :] + torch.matmul(shifts, cell)[:, None, None, :]
    distances_all = disps.norm(p=2, dim=-1)  # (num_shifts, num_atoms, num_atoms)
    within_cutoff = (distances_all > 0.1) & (distances_all < cutoff_real)
    distances = distances_all[within_cutoff]
    e_real_matrix_aug = torch.zeros_like(distances_all)
    e_real_matrix_aug[within_cutoff] = torch.erfc(distances / (math.sqrt(2) * eta)) / distances
    e_real_matrix = e_real_matrix_aug.sum(dim=0)

    # reciprocal space
    recip = 2 * math.pi * torch.transpose(torch.linalg.inv(cell), 0, 1)
    _grad_mode = torch.is_grad_enabled()
    torch.set_grad_enabled(False)
    shifts = get_shifts_within_cutoff(recip, cutoff_recip)
    torch.set_grad_enabled(_grad_mode)
    ks_all = torch.matmul(shifts, recip)
    length_all = ks_all.norm(p=2, dim=-1)
    within_cutoff = (length_all > 0.1) & (length_all < cutoff_recip)
    ks = ks_all[within_cutoff]
    length = length_all[within_cutoff]
    # disps_ij[i, j, :] is displacement vector r_{ij}, (num_atoms, num_atoms, 3)
    # disps_ij = coord[None, :, :] - coord[:, None, :]  # computed above
    phases = torch.sum(ks[:, None, None, :] * disps_ij[None, :, :, :], dim=-1)
    e_recip_matrix_aug = (
        torch.cos(phases)
        * torch.exp(-0.5 * torch.square(eta * length[:, None, None]))
        / torch.square(length[:, None, None])
    )
    e_recip_matrix = 4.0 * math.pi / volume * torch.sum(e_recip_matrix_aug, dim=0)
    # self interaction
    device = coord.device
    diag = -math.sqrt(2.0 / math.pi) / eta * torch.ones(N, device=device)
    e_self_matrix = torch.diag(diag)

    J = e_real_matrix + e_recip_matrix + e_self_matrix
    return J


def huber(x: Tensor, delta: float = 1.0) -> Tensor:
    return torch.where(x.abs() < delta, 0.5 * x**2, delta * (x.abs() - 0.5 * delta))


def bumpfn(x: Tensor, low: float = 0.0, high: float = 1.0) -> Tensor:
    """For x > 0, return smooth transition function which is 0 for x <= low and 1 for x >= high"""
    x = (x - low) / (high - low)
    x = x.clamp(min=1e-6, max=1 - 1e-6)
    a = (-1 / x).exp()
    b = (-1 / (1 - x)).exp()
    return a / (a + b)


def smoothstep(x: Tensor, low: float = 0.0, high: float = 1.0) -> Tensor:
    """For x > 0, return smooth transition function which is 0 for x <= low and 1 for x >= high"""
    x = (x - low) / (high - low)
    x = x.clamp(min=0, max=1)
    return x.pow(3) * (x * (x * 6 - 15) + 10)


def expstep(x: Tensor, low: float = 0.0, high: float = 1.0) -> Tensor:
    """For x > 0, return smooth transition function which is 0 for x <= low and 1 for x >= high"""
    x = (x - low) / (high - low)
    x = x.clamp(min=1e-6, max=1 - 1e-6)
    return (-1 / (1 - x.pow(2))).exp() / 0.36787944117144233
aimnet/train/__init__.py
ADDED
File without changes
aimnet/train/calc_sae.py
ADDED
@@ -0,0 +1,43 @@
import logging

import click
import numpy as np

from aimnet.data import SizeGroupedDataset


@click.command(short_help="Calculate SAE for a dataset.")
@click.option("--samples", type=int, default=100000, help="Max number of samples to consider.")
@click.argument("ds", type=str)
@click.argument("output", type=str)
def calc_sae(ds, output, samples=100000):
    """Script to calculate energy SAE for a dataset DS. Writes SAE to OUTPUT file."""
    logging.info(f"Loading dataset from {ds}")
    ds = SizeGroupedDataset(ds, keys=["numbers", "energy"])
    logging.info(f"Loaded dataset with {len(ds)} samples")
    if samples > 0 and len(ds) > samples:
        ds = ds.random_split(samples / len(ds))[0]
    logging.info(f"Using {len(ds)} samples to calculate SAE")
    sae = ds.apply_peratom_shift("energy", "_energy")
    # remove up 2 percentiles from right and left
    energy = ds.concatenate("_energy")
    pct1, pct2 = np.percentile(energy, [2, 98])
    for _n, g in ds.items():
        mask = (g["_energy"] > pct1) & (g["_energy"] < pct2)
        g = g.sample(mask)
        if not len(g):
            ds._data.pop(_n)
        else:
            ds[_n] = g
    # now re-compute SAE
    sae = ds.apply_peratom_shift("energy", "_energy")
    # save
    with open(output, "w") as f:
        for k, v in sae.items():
            s = f"{k}: {v}\n"
            f.write(s)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    calc_sae()
aimnet/train/default_train.yaml
ADDED
@@ -0,0 +1,166 @@
# required to be set
run_name: ???

data:
  # path to train and validation datasets
  train: ???
  val: null # same data file by default
  # sae files
  sae:
    energy:
      file: "???"
      mode: linreg

  # fraction of training data to use for validation if a separate val file is not provided
  val_fraction: 0.1
  separate_val: true
  # in DDP mode, will load only a shard of the dataset on each worker
  ddp_load_full_dataset: false

  # data keys
  x: [coord, numbers, charge]
  y: [energy, forces, charges]

  # dataset class definition
  datasets:
    train:
      class: aimnet.data.SizeGroupedDataset
      kwargs: {}
    val:
      class: aimnet.data.SizeGroupedDataset
      kwargs: {}

  # sampler class definition
  samplers:
    train:
      class: aimnet.data.SizeGroupedSampler
      kwargs:
        # this value is per DDP worker; total batch size is `batch_size*world_size`
        batch_size: 512
        # could be set to 'atoms', then 'batch_size' could be around 16384
        batch_mode: molecules
        shuffle: True
        # for extra large datasets we want to do evaluation more often than once per full epoch
        # this value sets the size of the epoch. `batches_per_epoch*batch_size` could be smaller or larger than the dataset size
        batches_per_epoch: 10000
    val:
      class: aimnet.data.SizeGroupedSampler
      kwargs:
        batch_size: 1024
        batch_mode: molecules
        shuffle: False
        # full dataset
        batches_per_epoch: -1

  # any additional torch.util.data.DataLoader options
  # num_workers=0 and pin_memory=True are recommended
  loaders:
    train:
      num_workers: 0
      pin_memory: true
    val:
      num_workers: 0
      pin_memory: true

# definition for the loss function class. Modify if training on different targets
loss:
  class: aimnet.train.loss.MTLoss
  kwargs:
    components:
      energy:
        fn: aimnet.train.loss.energy_loss_fn
        weight: 1.0
      forces:
        fn: aimnet.train.loss.peratom_loss_fn
        weight: 0.2
        kwargs:
          key_true: forces
          key_pred: forces
      charges:
        fn: aimnet.train.loss.peratom_loss_fn
        weight: 0.05
        kwargs:
          key_true: charges
          key_pred: charges

optimizer:
  # lists of regular expressions for parameter names to enable or disable gradients
  # force_no_grad will be processed first
  force_no_train: []
  force_train: []
  class: torch.optim.RAdam
  kwargs:
    lr: 0.0004
    weight_decay: 1e-8
  # parameters with non-default optimizer settings
  param_groups:
    shifts:
      re: ".*.atomic_shift.shifts.weight$"
      weight_decay: 0.0

# class definition for learning rate scheduler
scheduler:
  class: ignite.handlers.param_scheduler.ReduceLROnPlateauScheduler
  kwargs:
    metric_name: loss
    factor: 0.75
    patience: 10
  # terminate training if learning rate is lower than this value
  # useful for ReduceLROnPlateauScheduler
  terminate_on_low_lr: 1.0e-5

trainer:
  # functions that define the training and validation loops
  trainer: aimnet.train.utils.default_trainer
  evaluator: aimnet.train.utils.default_evaluator
  # total number of epochs to train
  epochs: 100

# periodically save checkpoints, set to null to disable
checkpoint:
  dirname: checkpoints
  filename_prefix: ${run_name}
  kwargs:
    n_saved: 1
    require_empty: False

# wandb logger
wandb:
  init:
    name: ${run_name}
    mode: offline
    entity: null
    project: null
    notes: null
  watch_model:
    log: all
    log_freq: 1000
    log_graph: true

# standard set of metrics. Add an entry if training on different targets
metrics:
  class: aimnet.train.metrics.RegMultiMetric
  kwargs:
    cfg:
      energy:
        abbr: E
        scale: 23.06 # eV to kcal/mol
      dipole:
        abbr: D
        scale: 1.0
        mult: 3
      quadrupole:
        abbr: Q
        scale: 1.0
        mult: 6
      charges:
        abbr: q
        peratom: True
      volumes:
        abbr: V
        peratom: True
      forces:
        abbr: F
        peratom: True
        mult: 3
        scale: 23.06
aimnet/train/loss.py
ADDED
@@ -0,0 +1,83 @@
from functools import partial
from typing import Any, Dict

import torch
from torch import Tensor

from aimnet.config import get_module


class MTLoss:
    """Multi-target loss function with fixed weights.

    This class allows for the combination of multiple loss functions, each with a specified weight.
    The weights are normalized to sum to 1. The loss functions are applied to the predictions and
    true values, and the weighted sum of the losses is computed.

    Loss functions definition must contain keys:
        name (str): The name of the loss function.
        fn (str): The loss function (e.g. `aimnet2.train.loss.mse_loss_fn`).
        weight (float): The weight of the loss function.
        kwargs (Dict): Optional, additional keyword arguments for the loss function.

    Methods:
        __call__(y_pred, y_true):
            Computes the weighted sum of the losses from the individual loss functions.
            Args:
                y_pred (Dict[str, Tensor]): Predicted values.
                y_true (Dict[str, Tensor]): True values.
            Returns:
                Dict[str, Tensor]: total loss under key 'loss' and values for individual components.
    """

    def __init__(self, components: Dict[str, Any]):
        w_sum = sum(c["weight"] for c in components.values())
        self.components = {}
        for name, c in components.items():
            kwargs = c.get("kwargs", {})
            fn = partial(get_module(c["fn"]), **kwargs)
            self.components[name] = (fn, c["weight"] / w_sum)

    def __call__(self, y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor]) -> Dict[str, Tensor]:
        loss = {}
        for name, (fn, w) in self.components.items():
            _l = fn(y_pred=y_pred, y_true=y_true)
            loss[name] = _l * w
        # special name for the total loss
        loss["loss"] = sum(loss.values())
        return loss


def mse_loss_fn(y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred: str, key_true: str) -> Tensor:
    """General MSE loss function"""
    x = y_true[key_true]
    y = y_pred[key_pred]
    loss = torch.nn.functional.mse_loss(x, y)
    return loss


def peratom_loss_fn(y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred: str, key_true: str) -> Tensor:
    """MSE loss function with per-atom normalization correction.
    Suitable when some of the values are zero both in y_pred and y_true due to padding of inputs.
    """
    x = y_true[key_true]
    y = y_pred[key_pred]

    if y_pred["_natom"].numel() == 1:
        loss = torch.nn.functional.mse_loss(x, y)
    else:
        diff2 = (x - y).pow(2).view(x.shape[0], -1)
        dim = diff2.shape[-1]
        loss = (diff2 * (y_pred["_natom"].unsqueeze(-1) / dim)).mean()
    return loss


def energy_loss_fn(
    y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred: str = "energy", key_true: str = "energy"
) -> Tensor:
    """MSE loss normalized by the number of atoms."""
    x = y_true[key_true]
    y = y_pred[key_pred]
    s = y_pred["_natom"].sqrt()
    loss = ((x - y).pow(2) / s).mean() if y_pred["_natom"].numel() > 1 else torch.nn.functional.mse_loss(x, y) / s
    return loss
aimnet/train/metrics.py
ADDED
@@ -0,0 +1,188 @@
import logging
from collections import defaultdict
from typing import Dict, List, Tuple

import ignite.distributed as idist
import numpy as np
import torch
from ignite.exceptions import NotComputableError
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced
from torch import Tensor


def regression_stats(pred: Tensor, true: Tensor) -> Dict[str, Tensor]:
    diff = true - pred
    diff2 = diff.pow(2)
    mae = diff.abs().mean(-1)
    rmse = diff2.mean(-1).sqrt()
    true_mean = true.mean()
    tot = (true - true_mean).pow(2).to(torch.double).sum()
    res = diff2.to(torch.double).sum(-1)
    r2 = 1 - (res / tot)
    return {"mae": mae, "rmse": rmse, "r2": r2}


def cat_flatten(y_pred: Tensor, y_true: Tensor) -> Tuple[Tensor, Tensor]:
    if isinstance(y_true, (list, tuple)):
        y_true = torch.cat([x.view(-1) for x in y_true])
    if isinstance(y_pred, (list, tuple)):
        _n = sum(x.numel() for x in y_pred)
        _npass = _n // y_true.numel()
        y_pred = torch.cat([x.view(_npass, -1) for x in y_pred], dim=1)
    y_true = y_true.view(-1)
    if y_pred.ndim > y_true.ndim:
        if y_pred.shape[1] != y_true.shape[0]:
            y_pred = y_pred.view(-1, y_true.shape[0])
        _npass = y_pred.shape[0]
        y_pred = y_pred.view(_npass, -1)
    else:
        y_pred = y_pred.view(-1)
    return y_pred, y_true


def _iqr(a):
    a = a.view(-1)
    k1 = 1 + round(0.25 * (a.numel() - 1))
    k2 = 1 + round(0.75 * (a.numel() - 1))
    v1 = a.kthvalue(k1).values.item()
    v2 = a.kthvalue(k2).values.item()
    return v2 - v1


def _freedman_diaconis_bins(a, max_bins=50):
    """Calculate number of hist bins using Freedman-Diaconis rule."""
    # From https://stats.stackexchange.com/questions/798/
    if a.numel() < 2:
        return 1
    h = 2 * _iqr(a) / (a.numel() ** (1 / 3))
    # fall back to sqrt(a) bins if iqr is 0
    n_bins = int(np.sqrt(a.numel())) if h == 0 else int(np.ceil((a.max().item() - a.min().item()) / h))
    return min(n_bins, max_bins)


def calculate_metrics(result, histogram=False, corrplot=False):
    keys = [k[:-5] for k in result if k.endswith("_pred")]
    for k in keys:
        y_pred = result.pop(k + "_pred")
        y_true = result.pop(k + "_true")
        y_pred, y_true = cat_flatten(y_pred, y_true)
        stats = regression_stats(y_pred, y_true)
        npass = stats["mae"].numel()
        if k.split(".")[-1] in ("energy", "forces"):  # noqa: SIM108
            f = 23.06  # eV to kcal/mol
        else:
            f = 1.0
        for i in range(npass):
            for m, v in stats.items():
                if m in ("mae", "rmse"):
                    v[i] = v[i] * f
                result.log(f"{k}_{m}_{i}", v[i])
        if histogram:
            err = y_pred - y_true
            for i in range(npass):
                result[f"{k}_{i}_hist"] = torch.histc(err, bins=_freedman_diaconis_bins(err))
    return result


class RegMultiMetric(Metric):
    def __init__(self, cfg: List[Dict], loss_fn=None):
        super().__init__()
        self.cfg = cfg
        self.loss_fn = loss_fn

    def attach_loss(self, loss_fn):
        self.loss_fn = loss_fn

    @reinit__is_reduced
    def reset(self):
        super().reset()
        self.data = defaultdict(lambda: defaultdict(float))
        self.atoms = 0.0
        self.samples = 0.0
        self.loss = defaultdict(float)

    def _update_one(self, key: str, pred: Tensor, true: Tensor) -> None:
        e = true - pred
        e = e.view(pred.shape[0], -1) if pred.ndim > true.ndim else e.view(-1)
        d = self.data[key]
        d["sum_abs_err"] += e.abs().sum(-1).to(dtype=torch.double, device="cpu")  # type: ignore[attr-defined]
        d["sum_sq_err"] += e.pow(2).sum(-1).to(dtype=torch.double, device="cpu")  # type: ignore[attr-defined]
        d["sum_true"] += true.sum().to(dtype=torch.double, device="cpu")  # type: ignore[attr-defined]
        d["sum_sq_true"] += true.pow(2).sum().to(dtype=torch.double, device="cpu")  # type: ignore[attr-defined]

    @reinit__is_reduced
    def update(self, output) -> None:
        y_pred, y_true = output
        if y_pred is None:
            return
        for k in y_pred:
            if k not in y_true:
                continue
            with torch.no_grad():
                self._update_one(k, y_pred[k].detach(), y_true[k].detach())
            b = y_true[k].shape[0]
        self.samples += b

        _n = y_pred["_natom"]
        if _n.numel() > 1:
            self.atoms += _n.sum().item()
        else:
            self.atoms += y_pred["numbers"].shape[0] * y_pred["numbers"].shape[1]
        if self.loss_fn is not None:
            with torch.no_grad():
                loss_d = self.loss_fn(y_pred, y_true)
            for k, loss in loss_d.items():
                if isinstance(loss, Tensor):
                    if loss.numel() > 1:
                        loss = loss.mean()
                    loss = loss.item()
                self.loss[k] += loss * b

    def compute(self) -> Dict[str, float]:
        if self.samples == 0:
            raise NotComputableError
        # Use custom reduction
        if idist.get_world_size() > 1:
            self.atoms = idist.all_reduce(self.atoms)
            self.samples = idist.all_reduce(self.samples)
            for k, loss in self.loss.items():
                self.loss[k] = idist.all_reduce(loss)  # type: ignore[attr-defined]
            for k1, v1 in self.data.items():
                for k2, v2 in v1.items():
                    self.data[k1][k2] = idist.all_reduce(v2)  # type: ignore[attr-defined]
            self._is_reduced = True

        # compute
        ret = {}
        for k in self.data:
            if k not in self.cfg:
                continue
            cfg = self.cfg[k]
            _n = self.atoms if cfg.get("peratom", False) else self.samples
            _n *= cfg.get("mult", 1.0)
            name = k
            abbr = cfg["abbr"]
            v = self.data[name]
            m = {}
            m["mae"] = v["sum_abs_err"] / _n
            m["rmse"] = (v["sum_sq_err"] / _n).sqrt()
            m["r2"] = 1.0 - v["sum_sq_err"] / (v["sum_sq_true"] - (v["sum_true"].pow(2)) / _n)  # type: ignore[attr-defined]
            for k, v in m.items():
                if k in ("mae", "rmse"):
                    v *= cfg.get("scale", 1.0)
                v = v.tolist()
                if isinstance(v, list):
                    for ii, vv in enumerate(v):
                        ret[f"{abbr}_{k}_{ii}"] = vv
                else:
                    ret[f"{abbr}_{k}"] = v
        if len(self.loss):
            for k, loss in self.loss.items():
                if not k.endswith("loss"):
                    k = k + "_loss"
                ret[k] = loss / self.samples

        logging.info(str(ret))

        return ret