aimnet 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimnet/ops.py ADDED
@@ -0,0 +1,208 @@
+ import math
+ from typing import Dict, Optional, Tuple
+
+ import torch
+ from torch import Tensor
+
+ from aimnet import nbops
+
+
+ def lazy_calc_dij_lr(data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+     if "d_ij_lr" not in data:
+         nb_mode = nbops.get_nb_mode(data)
+         if nb_mode == 0:
+             data["d_ij_lr"] = data["d_ij"]
+         else:
+             data["d_ij_lr"] = calc_distances(data, suffix="_lr")[0]
+     return data
+
+
+ def calc_distances(data: Dict[str, Tensor], suffix: str = "", pad_value: float = 1.0) -> Tuple[Tensor, Tensor]:
+     coord_i, coord_j = nbops.get_ij(data["coord"], data, suffix)
+     if f"shifts{suffix}" in data:
+         assert "cell" in data, "cell is required if shifts are provided"
+         nb_mode = nbops.get_nb_mode(data)
+         if nb_mode == 2:
+             shifts = torch.einsum("bnmd,bdh->bnmh", data[f"shifts{suffix}"], data["cell"])
+         else:
+             shifts = data[f"shifts{suffix}"] @ data["cell"]
+         coord_j = coord_j + shifts
+     r_ij = coord_j - coord_i
+     d_ij = torch.norm(r_ij, p=2, dim=-1)
+     d_ij = nbops.mask_ij_(d_ij, data, mask_value=pad_value, inplace=False, suffix=suffix)
+     return d_ij, r_ij
+
+
+ def center_coordinates(coord: Tensor, data: Dict[str, Tensor], masses: Optional[Tensor] = None) -> Tensor:
+     if masses is not None:
+         masses = masses.unsqueeze(-1)
+         center = nbops.mol_sum(coord * masses, data) / nbops.mol_sum(masses, data) / data["mol_sizes"].unsqueeze(-1)
+     else:
+         center = nbops.mol_sum(coord, data) / data["mol_sizes"]
+     nb_mode = nbops.get_nb_mode(data)
+     if nb_mode in (0, 2):
+         center = center.unsqueeze(-2)
+     coord = coord - center
+     return coord
+
+
+ def cosine_cutoff(d_ij: Tensor, rc: float) -> Tensor:
+     fc = 0.5 * (torch.cos(d_ij.clamp(min=1e-6, max=rc) * (math.pi / rc)) + 1.0)
+     return fc
+
+
+ def exp_cutoff(d: Tensor, rc: Tensor) -> Tensor:
+     fc = torch.exp(-1.0 / (1.0 - (d / rc).clamp(0, 1.0 - 1e-6).pow(2))) / 0.36787944117144233
+     return fc
+
+
+ def exp_expand(d_ij: Tensor, shifts: Tensor, eta: float) -> Tensor:
+     # expand on axis -1, e.g. (b, n, m) -> (b, n, m, shifts)
+     return torch.exp(-eta * (d_ij.unsqueeze(-1) - shifts) ** 2)
+
+
+ # pylint: disable=invalid-name
+ def nse(
+     Q: Tensor,
+     q_u: Tensor,
+     f_u: Tensor,
+     data: Dict[str, Tensor],
+     epsilon: float = 1.0e-6,
+ ) -> Tensor:
+     # Q and q_u and f_u must have last dimension size 1 or 2
+     F_u = nbops.mol_sum(f_u, data) + epsilon
+     Q_u = nbops.mol_sum(q_u, data)
+     dQ = Q - Q_u
+     # for loss
+     data["_dQ"] = dQ
+
+     nb_mode = nbops.get_nb_mode(data)
+     if nb_mode in (0, 2):
+         F_u = F_u.unsqueeze(-2)
+         dQ = dQ.unsqueeze(-2)
+     elif nb_mode == 1:
+         data["mol_sizes"][-1] += 1
+         F_u = torch.repeat_interleave(F_u, data["mol_sizes"], dim=0)
+         dQ = torch.repeat_interleave(dQ, data["mol_sizes"], dim=0)
+         data["mol_sizes"][-1] -= 1
+     else:
+         raise ValueError(f"Invalid neighbor mode: {nb_mode}")
+     f = f_u / F_u
+     q = q_u + f * dQ
+     return q
+
+
+ def coulomb_matrix_dsf(d_ij: Tensor, Rc: float, alpha: float, data: Dict[str, Tensor]) -> Tensor:
+     _c1 = (alpha * d_ij).erfc() / d_ij
+     _c2 = math.erfc(alpha * Rc) / Rc
+     _c3 = _c2 / Rc
+     _c4 = 2 * alpha * math.exp(-((alpha * Rc) ** 2)) / (Rc * math.pi**0.5)
+     J = _c1 - _c2 + (d_ij - Rc) * (_c3 + _c4)
+     # mask for d_ij > Rc
+     mask = data["mask_ij_lr"] & (d_ij > Rc)
+     J.masked_fill_(mask, 0.0)
+     return J
+
+
+ def coulomb_matrix_sf(q_j: Tensor, d_ij: Tensor, Rc: float, data: Dict[str, Tensor]) -> Tensor:
+     _c1 = 1.0 / d_ij
+     _c2 = 1.0 / Rc
+     _c3 = _c2 / Rc
+     J = _c1 - _c2 + (d_ij - Rc) * _c3
+     mask = data["mask_ij_lr"] & (d_ij > Rc)
+     J.masked_fill_(mask, 0.0)
+     return J
+
+
+ def get_shifts_within_cutoff(cell: Tensor, cutoff: Tensor) -> Tensor:
+     assert cell.shape == (3, 3), "Batch cell is not supported"
+     cell_inv = torch.inverse(cell).mT
+     inv_distances = cell_inv.norm(p=2, dim=-1)
+     num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
+     device = cell.device
+     shifts = torch.cartesian_prod(
+         torch.arange(-num_repeats[0], num_repeats[0] + 1, device=device),  # type: ignore[attr-defined]
+         torch.arange(-num_repeats[1], num_repeats[1] + 1, device=device),  # type: ignore[attr-defined]
+         torch.arange(-num_repeats[2], num_repeats[2] + 1, device=device),  # type: ignore[attr-defined]
+     ).to(torch.float)
+     return shifts
+
+
+ def coulomb_matrix_ewald(coord: Tensor, cell: Tensor) -> Tensor:
+     # single molecule implementation. nb_mode == 1
+     assert coord.ndim == 2 and cell.ndim == 2, "Only single molecule is supported"
+     accuracy = 1e-8
+     N = coord.shape[0]
+     volume = torch.det(cell)
+     eta = ((volume**2 / N) ** (1 / 6)) / math.sqrt(2.0 * math.pi)
+     cutoff_real = math.sqrt(-2.0 * math.log(accuracy)) * eta
+     cutoff_recip = math.sqrt(-2.0 * math.log(accuracy)) / eta
+
+     # real space
+     _grad_mode = torch.is_grad_enabled()
+     torch.set_grad_enabled(False)
+     shifts = get_shifts_within_cutoff(cell, cutoff_real)  # (num_shifts, 3)
+     torch.set_grad_enabled(_grad_mode)
+     disps_ij = coord[None, :, :] - coord[:, None, :]
+     disps = disps_ij[None, :, :, :] + torch.matmul(shifts, cell)[:, None, None, :]
+     distances_all = disps.norm(p=2, dim=-1)  # (num_shifts, num_atoms, num_atoms)
+     within_cutoff = (distances_all > 0.1) & (distances_all < cutoff_real)
+     distances = distances_all[within_cutoff]
+     e_real_matrix_aug = torch.zeros_like(distances_all)
+     e_real_matrix_aug[within_cutoff] = torch.erfc(distances / (math.sqrt(2) * eta)) / distances
+     e_real_matrix = e_real_matrix_aug.sum(dim=0)
+
+     # reciprocal space
+     recip = 2 * math.pi * torch.transpose(torch.linalg.inv(cell), 0, 1)
+     _grad_mode = torch.is_grad_enabled()
+     torch.set_grad_enabled(False)
+     shifts = get_shifts_within_cutoff(recip, cutoff_recip)
+     torch.set_grad_enabled(_grad_mode)
+     ks_all = torch.matmul(shifts, recip)
+     length_all = ks_all.norm(p=2, dim=-1)
+     within_cutoff = (length_all > 0.1) & (length_all < cutoff_recip)
+     ks = ks_all[within_cutoff]
+     length = length_all[within_cutoff]
+     # disps_ij[i, j, :] is displacement vector r_{ij}, (num_atoms, num_atoms, 3)
+     # disps_ij = coord[None, :, :] - coord[:, None, :]  # computed above
+     phases = torch.sum(ks[:, None, None, :] * disps_ij[None, :, :, :], dim=-1)
+     e_recip_matrix_aug = (
+         torch.cos(phases)
+         * torch.exp(-0.5 * torch.square(eta * length[:, None, None]))
+         / torch.square(length[:, None, None])
+     )
+     e_recip_matrix = 4.0 * math.pi / volume * torch.sum(e_recip_matrix_aug, dim=0)
+     # self interaction
+     device = coord.device
+     diag = -math.sqrt(2.0 / math.pi) / eta * torch.ones(N, device=device)
+     e_self_matrix = torch.diag(diag)
+
+     J = e_real_matrix + e_recip_matrix + e_self_matrix
+     return J
+
+
+ def huber(x: Tensor, delta: float = 1.0) -> Tensor:
+     return torch.where(x.abs() < delta, 0.5 * x**2, delta * (x.abs() - 0.5 * delta))
+
+
+ def bumpfn(x: Tensor, low: float = 0.0, high: float = 1.0) -> Tensor:
+     """For x > 0, return a smooth transition function which is 0 for x <= low and 1 for x >= high"""
+     x = (x - low) / (high - low)
+     x = x.clamp(min=1e-6, max=1 - 1e-6)
+     a = (-1 / x).exp()
+     b = (-1 / (1 - x)).exp()
+     return a / (a + b)
+
+
+ def smoothstep(x: Tensor, low: float = 0.0, high: float = 1.0) -> Tensor:
+     """For x > 0, return a smooth transition function which is 0 for x <= low and 1 for x >= high"""
+     x = (x - low) / (high - low)
+     x = x.clamp(min=0, max=1)
+     return x.pow(3) * (x * (x * 6 - 15) + 10)
+
+
+ def expstep(x: Tensor, low: float = 0.0, high: float = 1.0) -> Tensor:
+     """For x > 0, return a smooth transition function which is 1 for x <= low and 0 for x >= high"""
+     x = (x - low) / (high - low)
+     x = x.clamp(min=1e-6, max=1 - 1e-6)
+     return (-1 / (1 - x.pow(2))).exp() / 0.36787944117144233
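The switching helpers above (exp_cutoff, bumpfn, smoothstep, expstep) all map a distance or reduced coordinate onto a smooth factor between 0 and 1; the constant 0.36787944117144233 is exp(-1) and simply rescales the value at the lower end to exactly 1. A minimal endpoint check, assuming the aimnet package is installed (values in the comments are approximate):

    import torch
    from aimnet.ops import bumpfn, exp_cutoff, smoothstep

    d = torch.tensor([0.0, 2.5, 5.0])
    print(exp_cutoff(d, torch.tensor(5.0)))   # ~[1.00, 0.72, 0.00]: decays from 1 at d=0 to 0 at d=rc
    print(bumpfn(d, low=0.0, high=5.0))       # ~[0.00, 0.50, 1.00]: rises from 0 at low to 1 at high
    print(smoothstep(d, low=0.0, high=5.0))   # ~[0.00, 0.50, 1.00]: same endpoints, polynomial form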
File without changes
@@ -0,0 +1,43 @@
+ import logging
+
+ import click
+ import numpy as np
+
+ from aimnet.data import SizeGroupedDataset
+
+
+ @click.command(short_help="Calculate SAE for a dataset.")
+ @click.option("--samples", type=int, default=100000, help="Max number of samples to consider.")
+ @click.argument("ds", type=str)
+ @click.argument("output", type=str)
+ def calc_sae(ds, output, samples=100000):
+     """Script to calculate energy SAE for a dataset DS. Writes SAE to OUTPUT file."""
+     logging.info(f"Loading dataset from {ds}")
+     ds = SizeGroupedDataset(ds, keys=["numbers", "energy"])
+     logging.info(f"Loaded dataset with {len(ds)} samples")
+     if samples > 0 and len(ds) > samples:
+         ds = ds.random_split(samples / len(ds))[0]
+     logging.info(f"Using {len(ds)} samples to calculate SAE")
+     sae = ds.apply_peratom_shift("energy", "_energy")
+     # drop the lowest and highest 2 percentiles of per-atom-shifted energies
+     energy = ds.concatenate("_energy")
+     pct1, pct2 = np.percentile(energy, [2, 98])
+     for _n, g in ds.items():
+         mask = (g["_energy"] > pct1) & (g["_energy"] < pct2)
+         g = g.sample(mask)
+         if not len(g):
+             ds._data.pop(_n)
+         else:
+             ds[_n] = g
+     # now re-compute SAE
+     sae = ds.apply_peratom_shift("energy", "_energy")
+     # save
+     with open(output, "w") as f:
+         for k, v in sae.items():
+             s = f"{k}: {v}\n"
+             f.write(s)
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO)
+     calc_sae()
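calc_sae writes one "key: value" pair per line, so the OUTPUT file is a flat YAML-style mapping of per-atom energy shifts. A minimal sketch of reading it back, assuming PyYAML is available; the file names and the interpretation of the keys as atomic numbers are assumptions, since neither the module path nor the dataset format is shown in this diff:

    import yaml

    # produced by something like:  python -m <module.path.of.calc_sae> dataset.h5 sae.yaml --samples 50000
    with open("sae.yaml") as f:
        sae = yaml.safe_load(f)  # assumed: {atomic_number: per-atom energy shift, ...}
    print(sorted(sae.items()))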
@@ -0,0 +1,166 @@
+ # required to be set
+ run_name: ???
+
+ data:
+   # path to train and validation datasets
+   train: ???
+   val: null # same data file by default
+   # sae files
+   sae:
+     energy:
+       file: "???"
+       mode: linreg
+
+   # fraction of training data to use for validation if a separate val file is not provided
+   val_fraction: 0.1
+   separate_val: true
+   # in DDP mode, will load only a shard of the dataset on each worker
+   ddp_load_full_dataset: false
+
+   # data keys
+   x: [coord, numbers, charge]
+   y: [energy, forces, charges]
+
+   # dataset class definition
+   datasets:
+     train:
+       class: aimnet.data.SizeGroupedDataset
+       kwargs: {}
+     val:
+       class: aimnet.data.SizeGroupedDataset
+       kwargs: {}
+
+   # sampler class definition
+   samplers:
+     train:
+       class: aimnet.data.SizeGroupedSampler
+       kwargs:
+         # this value is per DDP worker; total batch size is `batch_size*world_size`
+         batch_size: 512
+         # could be set to 'atoms'; then 'batch_size' could be around 16384
+         batch_mode: molecules
+         shuffle: True
+         # for extra large datasets we want to do evaluation more often than once per full epoch
+         # this value sets the size of the epoch; `batches_per_epoch*batch_size` can be smaller or larger than the dataset size
+         batches_per_epoch: 10000
+     val:
+       class: aimnet.data.SizeGroupedSampler
+       kwargs:
+         batch_size: 1024
+         batch_mode: molecules
+         shuffle: False
+         # full dataset
+         batches_per_epoch: -1
+
+   # any additional torch.utils.data.DataLoader options
+   # num_workers=0 and pin_memory=True are recommended
+   loaders:
+     train:
+       num_workers: 0
+       pin_memory: true
+     val:
+       num_workers: 0
+       pin_memory: true
+
+ # definition for the loss function class. Modify if training on different targets
+ loss:
+   class: aimnet.train.loss.MTLoss
+   kwargs:
+     components:
+       energy:
+         fn: aimnet.train.loss.energy_loss_fn
+         weight: 1.0
+       forces:
+         fn: aimnet.train.loss.peratom_loss_fn
+         weight: 0.2
+         kwargs:
+           key_true: forces
+           key_pred: forces
+       charges:
+         fn: aimnet.train.loss.peratom_loss_fn
+         weight: 0.05
+         kwargs:
+           key_true: charges
+           key_pred: charges
+
+ optimizer:
+   # lists of regular expressions for parameter names to enable or disable gradients
+   # force_no_train will be processed first
+   force_no_train: []
+   force_train: []
+   class: torch.optim.RAdam
+   kwargs:
+     lr: 0.0004
+     weight_decay: 1e-8
+   # parameters with non-default optimizer settings
+   param_groups:
+     shifts:
+       re: ".*.atomic_shift.shifts.weight$"
+       weight_decay: 0.0
+
+ # class definition for learning rate scheduler
+ scheduler:
+   class: ignite.handlers.param_scheduler.ReduceLROnPlateauScheduler
+   kwargs:
+     metric_name: loss
+     factor: 0.75
+     patience: 10
+   # terminate training if learning rate is lower than this value
+   # useful for ReduceLROnPlateauScheduler
+   terminate_on_low_lr: 1.0e-5
+
+ trainer:
+   # functions that define training and validation loops
+   trainer: aimnet.train.utils.default_trainer
+   evaluator: aimnet.train.utils.default_evaluator
+   # total number of epochs to train
+   epochs: 100
+
+ # periodically save checkpoints; set to null to disable
+ checkpoint:
+   dirname: checkpoints
+   filename_prefix: ${run_name}
+   kwargs:
+     n_saved: 1
+     require_empty: False
+
+ # wandb logger
+ wandb:
+   init:
+     name: ${run_name}
+     mode: offline
+     entity: null
+     project: null
+     notes: null
+   watch_model:
+     log: all
+     log_freq: 1000
+     log_graph: true
+
+ # standard set of metrics. Add an entry if training on different targets
+ metrics:
+   class: aimnet.train.metrics.RegMultiMetric
+   kwargs:
+     cfg:
+       energy:
+         abbr: E
+         scale: 23.06 # eV to kcal/mol
+       dipole:
+         abbr: D
+         scale: 1.0
+         mult: 3
+       quadrupole:
+         abbr: Q
+         scale: 1.0
+         mult: 6
+       charges:
+         abbr: q
+         peratom: True
+       volumes:
+         abbr: V
+         peratom: True
+       forces:
+         abbr: F
+         peratom: True
+         mult: 3
+         scale: 23.06
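The ??? markers and the ${run_name} interpolation above follow OmegaConf conventions: ??? flags values that must be provided before the config is used, and ${run_name} is resolved when the value is accessed. A minimal sketch of filling in the required fields, assuming the file is saved locally as train.yaml and consumed with OmegaConf (both assumptions; the loading code is not part of this diff):

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("train.yaml")
    cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist([
        "run_name=aimnet2_test",
        "data.train=train.h5",
        "data.sae.energy.file=sae.yaml",
    ]))
    print(cfg.trainer.epochs)              # 100
    print(cfg.checkpoint.filename_prefix)  # 'aimnet2_test', resolved from ${run_name}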
aimnet/train/loss.py ADDED
@@ -0,0 +1,83 @@
+ from functools import partial
+ from typing import Any, Dict
+
+ import torch
+ from torch import Tensor
+
+ from aimnet.config import get_module
+
+
+ class MTLoss:
+     """Multi-target loss function with fixed weights.
+
+     This class allows for the combination of multiple loss functions, each with a specified weight.
+     The weights are normalized to sum to 1. The loss functions are applied to the predictions and
+     true values, and the weighted sum of the losses is computed.
+
+     Each loss function definition must contain the keys:
+         name (str): The name of the loss function.
+         fn (str): The loss function (e.g. `aimnet2.train.loss.mse_loss_fn`).
+         weight (float): The weight of the loss function.
+         kwargs (Dict): Optional, additional keyword arguments for the loss function.
+
+     Methods:
+         __call__(y_pred, y_true):
+             Computes the weighted sum of the losses from the individual loss functions.
+             Args:
+                 y_pred (Dict[str, Tensor]): Predicted values.
+                 y_true (Dict[str, Tensor]): True values.
+             Returns:
+                 Dict[str, Tensor]: total loss under key 'loss' and values for individual components.
+     """
+
+     def __init__(self, components: Dict[str, Any]):
+         w_sum = sum(c["weight"] for c in components.values())
+         self.components = {}
+         for name, c in components.items():
+             kwargs = c.get("kwargs", {})
+             fn = partial(get_module(c["fn"]), **kwargs)
+             self.components[name] = (fn, c["weight"] / w_sum)
+
+     def __call__(self, y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         loss = {}
+         for name, (fn, w) in self.components.items():
+             _l = fn(y_pred=y_pred, y_true=y_true)
+             loss[name] = _l * w
+         # special name for the total loss
+         loss["loss"] = sum(loss.values())
+         return loss
+
+
+ def mse_loss_fn(y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred: str, key_true: str) -> Tensor:
+     """General MSE loss function"""
+     x = y_true[key_true]
+     y = y_pred[key_pred]
+     loss = torch.nn.functional.mse_loss(x, y)
+     return loss
+
+
+ def peratom_loss_fn(y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred: str, key_true: str) -> Tensor:
+     """MSE loss function with per-atom normalization correction.
+     Suitable when some of the values are zero both in y_pred and y_true due to padding of inputs.
+     """
+     x = y_true[key_true]
+     y = y_pred[key_pred]
+
+     if y_pred["_natom"].numel() == 1:
+         loss = torch.nn.functional.mse_loss(x, y)
+     else:
+         diff2 = (x - y).pow(2).view(x.shape[0], -1)
+         dim = diff2.shape[-1]
+         loss = (diff2 * (y_pred["_natom"].unsqueeze(-1) / dim)).mean()
+     return loss
+
+
+ def energy_loss_fn(
+     y_pred: Dict[str, Tensor], y_true: Dict[str, Tensor], key_pred: str = "energy", key_true: str = "energy"
+ ) -> Tensor:
+     """MSE loss normalized by the number of atoms."""
+     x = y_true[key_true]
+     y = y_pred[key_pred]
+     s = y_pred["_natom"].sqrt()
+     loss = ((x - y).pow(2) / s).mean() if y_pred["_natom"].numel() > 1 else torch.nn.functional.mse_loss(x, y) / s
+     return loss
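A minimal usage sketch of MTLoss wired the same way as the loss: section of the config above, assuming the aimnet package is importable; the tensor shapes (2 molecules, 3 atoms, 3 force components) are illustrative only:

    import torch
    from aimnet.train.loss import MTLoss

    components = {
        "energy": {"fn": "aimnet.train.loss.energy_loss_fn", "weight": 1.0},
        "forces": {"fn": "aimnet.train.loss.peratom_loss_fn", "weight": 0.2,
                   "kwargs": {"key_true": "forces", "key_pred": "forces"}},
    }
    loss_fn = MTLoss(components)  # weights are normalized to 1.0/1.2 and 0.2/1.2

    natom = torch.tensor([3.0, 3.0])  # per-molecule atom counts, used for normalization
    y_pred = {"energy": torch.randn(2), "forces": torch.randn(2, 3, 3), "_natom": natom}
    y_true = {"energy": torch.randn(2), "forces": torch.randn(2, 3, 3)}
    print(loss_fn(y_pred, y_true))  # {'energy': ..., 'forces': ..., 'loss': ...}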
@@ -0,0 +1,188 @@
+ import logging
+ from collections import defaultdict
+ from typing import Dict, Tuple
+
+ import ignite.distributed as idist
+ import numpy as np
+ import torch
+ from ignite.exceptions import NotComputableError
+ from ignite.metrics import Metric
+ from ignite.metrics.metric import reinit__is_reduced
+ from torch import Tensor
+
+
+ def regression_stats(pred: Tensor, true: Tensor) -> Dict[str, Tensor]:
+     diff = true - pred
+     diff2 = diff.pow(2)
+     mae = diff.abs().mean(-1)
+     rmse = diff2.mean(-1).sqrt()
+     true_mean = true.mean()
+     tot = (true - true_mean).pow(2).to(torch.double).sum()
+     res = diff2.to(torch.double).sum(-1)
+     r2 = 1 - (res / tot)
+     return {"mae": mae, "rmse": rmse, "r2": r2}
+
+
+ def cat_flatten(y_pred: Tensor, y_true: Tensor) -> Tuple[Tensor, Tensor]:
+     if isinstance(y_true, (list, tuple)):
+         y_true = torch.cat([x.view(-1) for x in y_true])
+     if isinstance(y_pred, (list, tuple)):
+         _n = sum(x.numel() for x in y_pred)
+         _npass = _n // y_true.numel()
+         y_pred = torch.cat([x.view(_npass, -1) for x in y_pred], dim=1)
+     y_true = y_true.view(-1)
+     if y_pred.ndim > y_true.ndim:
+         if y_pred.shape[1] != y_true.shape[0]:
+             y_pred = y_pred.view(-1, y_true.shape[0])
+         _npass = y_pred.shape[0]
+         y_pred = y_pred.view(_npass, -1)
+     else:
+         y_pred = y_pred.view(-1)
+     return y_pred, y_true
+
+
+ def _iqr(a):
+     a = a.view(-1)
+     k1 = 1 + round(0.25 * (a.numel() - 1))
+     k2 = 1 + round(0.75 * (a.numel() - 1))
+     v1 = a.kthvalue(k1).values.item()
+     v2 = a.kthvalue(k2).values.item()
+     return v2 - v1
+
+
+ def _freedman_diaconis_bins(a, max_bins=50):
+     """Calculate number of hist bins using Freedman-Diaconis rule."""
+     # From https://stats.stackexchange.com/questions/798/
+     if a.numel() < 2:
+         return 1
+     h = 2 * _iqr(a) / (a.numel() ** (1 / 3))
+     # fall back to sqrt(numel) bins if the IQR is 0
+     n_bins = int(np.sqrt(a.numel())) if h == 0 else int(np.ceil((a.max().item() - a.min().item()) / h))
+     return min(n_bins, max_bins)
+
+
+ def calculate_metrics(result, histogram=False, corrplot=False):
+     keys = [k[:-5] for k in result if k.endswith("_pred")]
+     for k in keys:
+         y_pred = result.pop(k + "_pred")
+         y_true = result.pop(k + "_true")
+         y_pred, y_true = cat_flatten(y_pred, y_true)
+         stats = regression_stats(y_pred, y_true)
+         npass = stats["mae"].numel()
+         if k.split(".")[-1] in ("energy", "forces"):  # noqa: SIM108
+             f = 23.06  # eV to kcal/mol
+         else:
+             f = 1.0
+         for i in range(npass):
+             for m, v in stats.items():
+                 if m in ("mae", "rmse"):
+                     v[i] = v[i] * f
+                 result.log(f"{k}_{m}_{i}", v[i])
+         if histogram:
+             err = y_pred - y_true
+             for i in range(npass):
+                 result[f"{k}_{i}_hist"] = torch.histc(err, bins=_freedman_diaconis_bins(err))
+     return result
+
+
+ class RegMultiMetric(Metric):
+     def __init__(self, cfg: Dict[str, Dict], loss_fn=None):
+         super().__init__()
+         self.cfg = cfg
+         self.loss_fn = loss_fn
+
+     def attach_loss(self, loss_fn):
+         self.loss_fn = loss_fn
+
+     @reinit__is_reduced
+     def reset(self):
+         super().reset()
+         self.data = defaultdict(lambda: defaultdict(float))
+         self.atoms = 0.0
+         self.samples = 0.0
+         self.loss = defaultdict(float)
+
+     def _update_one(self, key: str, pred: Tensor, true: Tensor) -> None:
+         e = true - pred
+         e = e.view(pred.shape[0], -1) if pred.ndim > true.ndim else e.view(-1)
+         d = self.data[key]
+         d["sum_abs_err"] += e.abs().sum(-1).to(dtype=torch.double, device="cpu")  # type: ignore[attr-defined]
+         d["sum_sq_err"] += e.pow(2).sum(-1).to(dtype=torch.double, device="cpu")  # type: ignore[attr-defined]
+         d["sum_true"] += true.sum().to(dtype=torch.double, device="cpu")  # type: ignore[attr-defined]
+         d["sum_sq_true"] += true.pow(2).sum().to(dtype=torch.double, device="cpu")  # type: ignore[attr-defined]
+
+     @reinit__is_reduced
+     def update(self, output) -> None:
+         y_pred, y_true = output
+         if y_pred is None:
+             return
+         for k in y_pred:
+             if k not in y_true:
+                 continue
+             with torch.no_grad():
+                 self._update_one(k, y_pred[k].detach(), y_true[k].detach())
+             b = y_true[k].shape[0]
+         self.samples += b
+
+         _n = y_pred["_natom"]
+         if _n.numel() > 1:
+             self.atoms += _n.sum().item()
+         else:
+             self.atoms += y_pred["numbers"].shape[0] * y_pred["numbers"].shape[1]
+         if self.loss_fn is not None:
+             with torch.no_grad():
+                 loss_d = self.loss_fn(y_pred, y_true)
+             for k, loss in loss_d.items():
+                 if isinstance(loss, Tensor):
+                     if loss.numel() > 1:
+                         loss = loss.mean()
+                     loss = loss.item()
+                 self.loss[k] += loss * b
+
+     def compute(self) -> Dict[str, float]:
+         if self.samples == 0:
+             raise NotComputableError
+         # Use custom reduction
+         if idist.get_world_size() > 1:
+             self.atoms = idist.all_reduce(self.atoms)
+             self.samples = idist.all_reduce(self.samples)
+             for k, loss in self.loss.items():
+                 self.loss[k] = idist.all_reduce(loss)  # type: ignore[attr-defined]
+             for k1, v1 in self.data.items():
+                 for k2, v2 in v1.items():
+                     self.data[k1][k2] = idist.all_reduce(v2)  # type: ignore[attr-defined]
+             self._is_reduced = True
+
+         # compute
+         ret = {}
+         for k in self.data:
+             if k not in self.cfg:
+                 continue
+             cfg = self.cfg[k]
+             _n = self.atoms if cfg.get("peratom", False) else self.samples
+             _n *= cfg.get("mult", 1.0)
+             name = k
+             abbr = cfg["abbr"]
+             v = self.data[name]
+             m = {}
+             m["mae"] = v["sum_abs_err"] / _n
+             m["rmse"] = (v["sum_sq_err"] / _n).sqrt()
+             m["r2"] = 1.0 - v["sum_sq_err"] / (v["sum_sq_true"] - (v["sum_true"].pow(2)) / _n)  # type: ignore[attr-defined]
+             for k, v in m.items():
+                 if k in ("mae", "rmse"):
+                     v *= cfg.get("scale", 1.0)
+                 v = v.tolist()
+                 if isinstance(v, list):
+                     for ii, vv in enumerate(v):
+                         ret[f"{abbr}_{k}_{ii}"] = vv
+                 else:
+                     ret[f"{abbr}_{k}"] = v
+         if len(self.loss):
+             for k, loss in self.loss.items():
+                 if not k.endswith("loss"):
+                     k = k + "_loss"
+                 ret[k] = loss / self.samples
+
+         logging.info(str(ret))
+
+         return ret
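RegMultiMetric is an ignite Metric that accumulates per-target error sums and turns them into MAE, RMSE, and R2 on compute(). A minimal sketch of feeding it (y_pred, y_true) pairs directly, assuming torch and pytorch-ignite are installed; the cfg dict mirrors the metrics: section of the config above, and the tensor shapes (4 molecules, 5 atoms) are illustrative only:

    import torch
    from aimnet.train.metrics import RegMultiMetric

    cfg = {"energy": {"abbr": "E", "scale": 23.06}, "charges": {"abbr": "q", "peratom": True}}
    metric = RegMultiMetric(cfg)
    metric.reset()
    y_pred = {"energy": torch.randn(4), "charges": torch.randn(4, 5), "_natom": torch.full((4,), 5.0)}
    y_true = {"energy": torch.randn(4), "charges": torch.randn(4, 5)}
    metric.update((y_pred, y_true))
    print(metric.compute())  # {'E_mae': ..., 'E_rmse': ..., 'E_r2': ..., 'q_mae': ..., 'q_rmse': ..., 'q_r2': ...}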