aimnet 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +0 -0
- aimnet/base.py +41 -0
- aimnet/calculators/__init__.py +15 -0
- aimnet/calculators/aimnet2ase.py +98 -0
- aimnet/calculators/aimnet2pysis.py +76 -0
- aimnet/calculators/calculator.py +320 -0
- aimnet/calculators/model_registry.py +60 -0
- aimnet/calculators/model_registry.yaml +33 -0
- aimnet/calculators/nb_kernel_cpu.py +222 -0
- aimnet/calculators/nb_kernel_cuda.py +217 -0
- aimnet/calculators/nbmat.py +220 -0
- aimnet/cli.py +22 -0
- aimnet/config.py +170 -0
- aimnet/constants.py +467 -0
- aimnet/data/__init__.py +1 -0
- aimnet/data/sgdataset.py +517 -0
- aimnet/dftd3_data.pt +0 -0
- aimnet/models/__init__.py +2 -0
- aimnet/models/aimnet2.py +188 -0
- aimnet/models/aimnet2.yaml +44 -0
- aimnet/models/aimnet2_dftd3_wb97m.yaml +51 -0
- aimnet/models/base.py +51 -0
- aimnet/modules/__init__.py +3 -0
- aimnet/modules/aev.py +201 -0
- aimnet/modules/core.py +237 -0
- aimnet/modules/lr.py +243 -0
- aimnet/nbops.py +151 -0
- aimnet/ops.py +208 -0
- aimnet/train/__init__.py +0 -0
- aimnet/train/calc_sae.py +43 -0
- aimnet/train/default_train.yaml +166 -0
- aimnet/train/loss.py +83 -0
- aimnet/train/metrics.py +188 -0
- aimnet/train/pt2jpt.py +81 -0
- aimnet/train/train.py +155 -0
- aimnet/train/utils.py +398 -0
- aimnet-0.0.1.dist-info/LICENSE +21 -0
- aimnet-0.0.1.dist-info/METADATA +78 -0
- aimnet-0.0.1.dist-info/RECORD +41 -0
- aimnet-0.0.1.dist-info/WHEEL +4 -0
- aimnet-0.0.1.dist-info/entry_points.txt +5 -0
aimnet/__init__.py
ADDED
|
File without changes
|
aimnet/base.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
from typing import ClassVar, Dict, Final

import torch
from torch import Tensor, nn

from aimnet import nbops


class AIMNet2Base(nn.Module):  # pylint: disable=abstract-method
    """Base class for AIMNet2 models. Implements pre-processing data:
    converting to right dtype and device, setting nb mode, calculating masks.
    """

    # Keys that must appear in every input batch, paired index-wise with the
    # dtype each is cast to (parallel lists so TorchScript can treat them as
    # constants via __constants__).
    _required_keys: Final = ["coord", "numbers", "charge"]
    _required_keys_dtype: Final = [torch.float32, torch.int64, torch.float32]
    # Keys that may appear; cast to the matching dtype only when present.
    _optional_keys: Final = ["mult", "nbmat", "nbmat_lr", "mol_idx", "shifts", "cell"]
    _optional_keys_dtype: Final = [torch.float32, torch.int64, torch.int64, torch.int64, torch.float32, torch.float32]
    __constants__: ClassVar = ["_required_keys", "_required_keys_dtype", "_optional_keys", "_optional_keys_dtype"]

    def __init__(self):
        super().__init__()

    def _prepare_dtype(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Cast required (and any present optional) tensors to their expected
        dtypes, mutating and returning the same dict.
        """
        for k, d in zip(self._required_keys, self._required_keys_dtype):
            assert k in data, f"Key {k} is required"
            data[k] = data[k].to(d)
        for k, d in zip(self._optional_keys, self._optional_keys_dtype):
            if k in data:
                data[k] = data[k].to(d)
        return data

    def prepare_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Some common operations: dtype casting, neighbor-mode setup and mask
        calculation (both delegated to `aimnet.nbops`), plus shape sanity checks
        on the per-molecule `charge`/`mult` tensors.
        """
        data = self._prepare_dtype(data)
        data = nbops.set_nb_mode(data)
        data = nbops.calc_masks(data)

        assert data["charge"].ndim == 1, "Charge should be 1D tensor."
        if "mult" in data:
            assert data["mult"].ndim == 1, "Mult should be 1D tensor."
        return data
|
@@ -0,0 +1,15 @@
|
|
|
import importlib.util

from .calculator import AIMNet2Calculator

__all__ = ["AIMNet2Calculator"]

# Optional integrations: exposed only when the host package is importable.
# Bug fix: the original appended the class OBJECTS to __all__; entries must be
# strings, otherwise `from aimnet.calculators import *` raises TypeError.
if importlib.util.find_spec("ase") is not None:
    from .aimnet2ase import AIMNet2ASE

    __all__.append("AIMNet2ASE")

if importlib.util.find_spec("pysisyphus") is not None:
    from .aimnet2pysis import AIMNet2Pysis

    __all__.append("AIMNet2Pysis")
|
@@ -0,0 +1,98 @@
|
|
|
from typing import ClassVar

import numpy as np
import torch

try:
    from ase.calculators.calculator import Calculator, all_changes  # type: ignore
except ImportError:
    raise ImportError("ASE is not installed. Please install ASE to use this module.") from None

from .calculator import AIMNet2Calculator


class AIMNet2ASE(Calculator):
    """ASE Calculator interface to AIMNet2 models.

    Wraps an :class:`AIMNet2Calculator` and translates between ASE ``Atoms``
    objects and the tensor dict the underlying model expects.
    """

    implemented_properties: ClassVar[list[str]] = ["energy", "forces", "free_energy", "charges", "stress"]

    def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1):
        """
        Args:
            base_calc: an AIMNet2Calculator instance or a registry model name.
            charge: total molecular charge.
            mult: spin multiplicity.
        """
        super().__init__()
        if isinstance(base_calc, str):
            base_calc = AIMNet2Calculator(base_calc)
        self.base_calc = base_calc
        # Bug fix: call reset() BEFORE storing charge/mult.  reset() restores
        # the 0/1 defaults, so the original order silently discarded the
        # values passed to the constructor.
        self.reset()
        self.charge = charge
        self.mult = mult
        # list of implemented species
        if hasattr(base_calc, "implemented_species"):
            self.implemented_species = base_calc.implemented_species.cpu().numpy()  # type: ignore
        else:
            self.implemented_species = None

    def reset(self):
        """Clear cached input tensors and restore default charge/multiplicity."""
        super().reset()
        self._t_numbers = None
        self._t_charge = None
        self._t_mult = None
        self.charge = 0.0
        self.mult = 1.0

    def set_atoms(self, atoms):
        """Attach an Atoms object after verifying every species is supported."""
        # np.in1d is deprecated; np.isin is the supported equivalent.
        if self.implemented_species is not None and not np.isin(atoms.numbers, self.implemented_species).all():
            raise ValueError("Some species are not implemented in the AIMNet2Calculator")
        self.reset()
        self.atoms = atoms

    def set_charge(self, charge):
        """Set the total charge and invalidate its cached tensor."""
        self.charge = charge
        self._t_charge = None
        self.update_tensors()

    def set_mult(self, mult):
        """Set the spin multiplicity and invalidate its cached tensor."""
        self.mult = mult
        self._t_mult = None
        self.update_tensors()

    def update_tensors(self):
        """(Re)build whichever cached input tensors are currently invalidated."""
        if self._t_numbers is None:
            self._t_numbers = torch.tensor(self.atoms.numbers, dtype=torch.int64, device=self.base_calc.device)
        if self._t_charge is None:
            self._t_charge = torch.tensor(self.charge, dtype=torch.float32, device=self.base_calc.device)
        if self._t_mult is None:
            self._t_mult = torch.tensor(self.mult, dtype=torch.float32, device=self.base_calc.device)

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """Run the underlying AIMNet2 model and populate ``self.results``."""
        if properties is None:
            properties = ["energy"]
        super().calculate(atoms, properties, system_changes)
        self.update_tensors()

        # Pass the cell only for periodic systems.
        cell = self.atoms.cell.array if self.atoms.cell is not None and self.atoms.pbc.any() else None

        _in = {
            "coord": torch.tensor(self.atoms.positions, dtype=torch.float32, device=self.base_calc.device),
            "numbers": self._t_numbers,
            "charge": self._t_charge,
            "mult": self._t_mult,
        }

        _unsqueezed = False
        if cell is not None:
            _in["cell"] = cell
        else:
            # Non-periodic path: add a batch dimension of 1.
            for k, v in _in.items():
                _in[k] = v.unsqueeze(0)
            _unsqueezed = True

        results = self.base_calc(_in, forces="forces" in properties, stress="stress" in properties)

        for k, v in results.items():
            if _unsqueezed:
                v = v.squeeze(0)
            results[k] = v.detach().cpu().numpy()  # type: ignore

        self.results["energy"] = results["energy"]
        # Bug fix: "free_energy" is advertised in implemented_properties but was
        # never populated; for this model it equals the potential energy.
        self.results["free_energy"] = results["energy"]
        self.results["charges"] = results["charges"]
        if "forces" in properties:
            self.results["forces"] = results["forces"]
        if "stress" in properties:
            self.results["stress"] = results["stress"]
|
@@ -0,0 +1,76 @@
|
|
|
from typing import ClassVar

import torch

from .calculator import AIMNet2Calculator

try:
    import pysisyphus.run  # type: ignore
    from pysisyphus.calculators.Calculator import Calculator  # type: ignore
    from pysisyphus.constants import ANG2BOHR, AU2EV, BOHR2ANG  # type: ignore
    from pysisyphus.elem_data import ATOMIC_NUMBERS  # type: ignore
except ImportError:
    raise ImportError("Pysisyphus is not installed. Please install Pysisyphus to use this module.") from None

# Hartree per eV.
EV2AU = 1 / AU2EV


class AIMNet2Pysis(Calculator):
    """Pysisyphus calculator backed by an AIMNet2 model.

    Converts pysisyphus Bohr coordinates to Angstrom for the model and
    converts energies/derivatives back to atomic units.
    """

    implemented_properties: ClassVar = ["energy", "forces", "free_energy", "charges", "stress"]

    def __init__(self, model: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1, **kwargs):
        """Accept an AIMNet2Calculator instance or a registry model name."""
        super().__init__(charge=charge, mult=mult, **kwargs)
        if isinstance(model, str):
            model = AIMNet2Calculator(model)
        self.model = model

    def _prepare_input(self, atoms, coord):
        """Build the model input dict; `coord` arrives in Bohr and is converted
        to Angstrom.  (Renamed from the misspelled `_prepere_input`.)
        """
        device = self.model.device
        numbers = torch.as_tensor([ATOMIC_NUMBERS[a.lower()] for a in atoms], device=device)
        coord = torch.as_tensor(coord, dtype=torch.float, device=device).view(-1, 3) * BOHR2ANG
        charge = torch.as_tensor([self.charge], dtype=torch.float, device=device)
        mult = torch.as_tensor([self.mult], dtype=torch.float, device=device)
        return {"coord": coord, "numbers": numbers, "charge": charge, "mult": mult}

    @staticmethod
    def _results_get_energy(results):
        """Model energy (eV) -> Hartree."""
        return results["energy"].item() * EV2AU

    @staticmethod
    def _results_get_forces(results):
        """Forces (eV/Angstrom) -> Hartree/Bohr as a flat float64 numpy array."""
        return (results["forces"].detach() * (EV2AU / ANG2BOHR)).flatten().to(torch.double).cpu().numpy()

    @staticmethod
    def _results_get_hessian(results):
        """Hessian (eV/Angstrom^2) -> Hartree/Bohr^2 as a 2D numpy array."""
        return (
            (results["hessian"].flatten(0, 1).flatten(-2, -1) * (EV2AU / ANG2BOHR / ANG2BOHR))
            .to(torch.double)
            .cpu()
            .numpy()
        )

    def get_energy(self, atoms, coords):
        _in = self._prepare_input(atoms, coords)
        res = self.model(_in)
        return {"energy": self._results_get_energy(res)}

    def get_forces(self, atoms, coords):
        _in = self._prepare_input(atoms, coords)
        res = self.model(_in, forces=True)
        return {"energy": self._results_get_energy(res), "forces": self._results_get_forces(res)}

    def get_hessian(self, atoms, coords):
        _in = self._prepare_input(atoms, coords)
        res = self.model(_in, forces=True, hessian=True)
        return {
            "energy": self._results_get_energy(res),
            "forces": self._results_get_forces(res),
            "hessian": self._results_get_hessian(res),
        }


def run_pysis():
    """Entry point: register the AIMNet2 calculator with pysisyphus and run its CLI."""
    pysisyphus.run.CALC_DICT["aimnet"] = AIMNet2Pysis
    pysisyphus.run.run()
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from typing import Any, ClassVar, Dict, Literal
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor, nn
|
|
6
|
+
|
|
7
|
+
from .model_registry import get_model_path
|
|
8
|
+
from .nbmat import TooManyNeighborsError, calc_nbmat
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AIMNet2Calculator:
    """Generic AIMNet2 calculator
    A helper class to load AIMNet2 models and perform inference.
    """

    # Required input keys and the dtype each is cast to on ingestion.
    # NOTE(review): torch.int is int32, while AIMNet2Base casts "numbers" to
    # int64 — verify the model tolerates int32 indices here.
    keys_in: ClassVar[Dict[str, torch.dtype]] = {"coord": torch.float, "numbers": torch.int, "charge": torch.float}
    # Optional input keys; converted only when present and not None.
    keys_in_optional: ClassVar[Dict[str, torch.dtype]] = {
        "mult": torch.float,
        "mol_idx": torch.int,
        "nbmat": torch.int,
        "nbmat_lr": torch.int,
        "nb_pad_mask": torch.bool,
        "nb_pad_mask_lr": torch.bool,
        "shifts": torch.float,
        "shifts_lr": torch.float,
        "cell": torch.float,
    }
    # Keys kept in the returned dict (plus their "*_std" uncertainty twins).
    keys_out: ClassVar[list[str]] = ["energy", "charges", "forces", "hessian", "stress"]
    # Per-atom quantities that are flattened/unflattened/unpadded along dim 0.
    atom_feature_keys: ClassVar[list[str]] = ["coord", "numbers", "charges", "forces"]

    def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320):
        """Load a model (registry name or file path via TorchScript, or an
        already-built nn.Module) onto CUDA when available, else CPU.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(model, str):
            p = get_model_path(model)
            self.model = torch.jit.load(p, map_location=self.device)
        elif isinstance(model, nn.Module):
            self.model = model.to(self.device)
        else:
            raise TypeError("Invalid model type/name.")

        # Short-range cutoff comes from the model; long-range handling only if
        # the model exposes a `cutoff_lr` attribute.
        self.cutoff = self.model.cutoff
        self.lr = hasattr(self.model, "cutoff_lr")
        self.cutoff_lr = getattr(self.model, "cutoff_lr", float("inf"))
        # Initial neighbor-density guess; grown adaptively in make_nbmat.
        self.max_density = 0.2
        self.nb_threshold = nb_threshold

        # indicator if input was flattened
        self._batch = None
        self._max_mol_size: int = 0
        # placeholder for tensors that require grad
        self._saved_for_grad = {}
        # set flag of current Coulomb model
        coul_methods = {getattr(mod, "method", None) for mod in iter_lrcoulomb_mods(self.model)}
        if len(coul_methods) > 1:
            raise ValueError("Multiple Coulomb modules found.")
        if len(coul_methods):
            self._coulomb_method = coul_methods.pop()
        else:
            self._coulomb_method = None

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`eval`."""
        return self.eval(*args, **kwargs)

    def set_lrcoulomb_method(
        self, method: Literal["simple", "dsf", "ewald"], cutoff: float = 15.0, dsf_alpha: float = 0.2
    ):
        """Switch the long-range Coulomb treatment on every `lrcoulomb` module
        and adjust the long-range neighbor cutoff accordingly.
        """
        if method not in ("simple", "dsf", "ewald"):
            raise ValueError(f"Invalid method: {method}")
        for mod in iter_lrcoulomb_mods(self.model):
            mod.method = method  # type: ignore
            if method == "simple":
                # "simple" uses all pairs — no long-range cutoff.
                self.cutoff_lr = float("inf")
            elif method == "dsf":
                self.cutoff_lr = cutoff
                mod.dsf_alpha = dsf_alpha  # type: ignore
                mod.dsf_rc = cutoff  # type: ignore
            elif method == "ewald":
                # current implementaion of Ewald does not use nb mat
                self.cutoff_lr = cutoff
        self._coulomb_method = method

    def eval(self, data: Dict[str, Any], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]:
        """Run inference: prepare input, set grad tensors, run the model,
        compute requested derivatives and post-process the output dict.
        """
        data = self.prepare_input(data)
        if hessian and "mol_idx" in data and data["mol_idx"][-1] > 0:
            raise NotImplementedError("Hessian calculation is not supported for multiple molecules")
        data = self.set_grad_tensors(data, forces=forces, stress=stress, hessian=hessian)
        # Disable the JIT profiling executor for stable repeated-call behavior.
        with torch.jit.optimized_execution(False):  # type: ignore
            data = self.model(data)
        data = self.get_derivatives(data, forces=forces, stress=stress, hessian=hessian)
        data = self.process_output(data)
        return data

    def prepare_input(self, data: Dict[str, Any]) -> Dict[str, Tensor]:
        """Convert to tensors, flatten batches, validate PBC constraints and
        (for flattened 2D input) build neighbor matrices and padding.
        """
        data = self.to_input_tensors(data)
        data = self.mol_flatten(data)
        if data.get("cell") is not None:
            if data["mol_idx"][-1] > 0:
                raise NotImplementedError("PBC with multiple molecules is not implemented yet.")
            if self._coulomb_method == "simple":
                warnings.warn("Switching to DSF Coulomb for PBC", stacklevel=1)
                self.set_lrcoulomb_method("dsf")
        if data["coord"].ndim == 2:
            data = self.make_nbmat(data)
            data = self.pad_input(data)
        return data

    def process_output(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Reverse the input transformations (unpad, unflatten) and keep only
        the public output keys.
        """
        if data["coord"].ndim == 2:
            data = self.unpad_output(data)
        data = self.mol_unflatten(data)
        data = self.keep_only(data)
        return data

    def to_input_tensors(self, data: Dict[str, Any]) -> Dict[str, Tensor]:
        """Cast required/optional inputs to tensors on self.device with the
        declared dtypes; promote scalars to shape-(1,) tensors.
        """
        ret = {}
        for k in self.keys_in:
            if k not in data:
                raise KeyError(f"Missing key {k} in the input data")
            # always detach !!
            ret[k] = torch.as_tensor(data[k], device=self.device, dtype=self.keys_in[k]).detach()
        for k in self.keys_in_optional:
            if k in data and data[k] is not None:
                ret[k] = torch.as_tensor(data[k], device=self.device, dtype=self.keys_in_optional[k]).detach()
        # convert any scalar tensors to shape (1,) tensors
        for k, v in ret.items():
            if v.ndim == 0:
                ret[k] = v.unsqueeze(0)
        return ret

    def mol_flatten(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Flatten the input data for multiple molecules.
        Will not flatten for batched input and molecule size below threshold.
        """
        ndim = data["coord"].ndim
        if ndim == 2:
            # single molecule or already flattened
            self._batch = None
            if "mol_idx" not in data:
                data["mol_idx"] = torch.zeros(data["coord"].shape[0], dtype=torch.long, device=self.device)
                self._max_mol_size = data["coord"].shape[0]
            elif data["mol_idx"][-1] == 0:
                self._max_mol_size = len(data["mol_idx"])
            else:
                # largest molecule in the flattened multi-molecule input
                self._max_mol_size = data["mol_idx"].unique(return_counts=True)[1].max().item()

        elif ndim == 3:
            # batched input
            B, N = data["coord"].shape[:2]
            if self.nb_threshold > N or self.device == "cpu":
                # Flatten small batches into one long molecule list.
                # NOTE(review): this branch does not update self._max_mol_size,
                # yet make_nbmat asserts it is > 0 and uses it to cap maxnb —
                # verify the intended value (N?) on this path.
                self._batch = B
                data["mol_idx"] = torch.repeat_interleave(
                    torch.arange(0, B, device=self.device), torch.full((B,), N, device=self.device)
                )
                for k, v in data.items():
                    if k in self.atom_feature_keys:
                        data[k] = v.flatten(0, 1)
            else:
                self._batch = None
                self._max_mol_size = N
        return data

    def mol_unflatten(self, data: Dict[str, Tensor], batch=None) -> Dict[str, Tensor]:
        """Restore the (batch, atoms, ...) layout for per-atom outputs when the
        input had been flattened by mol_flatten.
        """
        batch = batch if batch is not None else self._batch
        if batch is not None:
            for k, v in data.items():
                if k in self.atom_feature_keys:
                    data[k] = v.view(batch, -1, *v.shape[1:])
        return data

    def make_nbmat(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Build short- and (optionally) long-range neighbor matrices, growing
        the assumed neighbor density on TooManyNeighborsError until it fits.
        """
        assert self._max_mol_size > 0, "Molecule size is not set"

        if "cell" in data and data["cell"] is not None:
            # Wrap coordinates into the periodic cell before neighbor search.
            data["coord"] = move_coord_to_cell(data["coord"], data["cell"])
            cell = data["cell"]
        else:
            cell = None

        while True:
            try:
                maxnb1 = calc_max_nb(self.cutoff, self.max_density)
                maxnb2 = calc_max_nb(self.cutoff_lr, self.max_density) if self.lr else None  # type: ignore
                if cell is None:
                    # In the non-periodic case a molecule cannot have more
                    # neighbors than atoms.
                    maxnb1 = min(maxnb1, self._max_mol_size)
                    maxnb2 = min(maxnb2, self._max_mol_size) if self.lr else None  # type: ignore
                maxnb = (maxnb1, maxnb2)
                nbmat1, nbmat2, shifts1, shifts2 = calc_nbmat(
                    data["coord"],
                    (self.cutoff, self.cutoff_lr),
                    maxnb,  # type: ignore
                    cell,
                    data.get("mol_idx"),  # type: ignore
                )
                break
            except TooManyNeighborsError:
                # Estimate was too small; retry with 20% higher density.
                self.max_density *= 1.2
                assert self.max_density <= 4, "Something went wrong in nbmat calculation"
        data["nbmat"] = nbmat1
        if self.lr:
            assert nbmat2 is not None
            data["nbmat_lr"] = nbmat2
        if cell is not None:
            assert shifts1 is not None
            data["shifts"] = shifts1
            if self.lr:
                assert shifts2 is not None
                data["shifts_lr"] = shifts2
        return data

    def pad_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Pad per-atom inputs to the neighbor matrix's row count (which
        contains one extra "dummy atom" row used for padding entries).
        """
        N = data["nbmat"].shape[0]
        data["mol_idx"] = maybe_pad_dim0(data["mol_idx"], N, value=data["mol_idx"][-1].item())
        for k in ("coord", "numbers"):
            if k in data:
                data[k] = maybe_pad_dim0(data[k], N)
        return data

    def unpad_output(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Strip the dummy-atom row from per-atom outputs."""
        N = data["nbmat"].shape[0] - 1
        for k, v in data.items():
            if k in self.atom_feature_keys:
                data[k] = maybe_unpad_dim0(v, N)
        return data

    def set_grad_tensors(self, data: Dict[str, Tensor], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]:
        """Enable autograd on the tensors needed for the requested derivatives
        and remember them in self._saved_for_grad.
        """
        self._saved_for_grad = {}
        if forces or hessian:
            data["coord"].requires_grad_(True)
            self._saved_for_grad["coord"] = data["coord"]
        if stress:
            assert "cell" in data and data["cell"] is not None, "Stress calculation requires cell"
            # Strain trick: multiply coord and cell by an identity "scaling"
            # matrix so that dE/d(scaling) gives the virial.
            scaling = torch.eye(3, requires_grad=True, dtype=data["cell"].dtype, device=data["cell"].device)
            data["coord"] = data["coord"] @ scaling
            data["cell"] = data["cell"] @ scaling
            self._saved_for_grad["scaling"] = scaling
        return data

    def keep_only(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Filter the output dict down to keys_out plus their "*_std" twins."""
        ret = {}
        for k, v in data.items():
            if k in self.keys_out or (k.endswith("_std") and k[:-4] in self.keys_out):
                ret[k] = v
        return ret

    def get_derivatives(self, data: Dict[str, Tensor], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]:
        """Compute forces/stress (one autograd.grad call) and optionally the
        Hessian from the energy in `data`.
        """
        training = getattr(self.model, "training", False)
        # Keep the graph when higher derivatives (hessian) or training need it.
        _create_graph = hessian or training
        x = []
        if hessian:
            forces = True
        if forces and ("forces" not in data or (_create_graph and not data["forces"].requires_grad)):
            forces = True
            x.append(self._saved_for_grad["coord"])
        if stress:
            x.append(self._saved_for_grad["scaling"])
        if x:
            tot_energy = data["energy"].sum()
            deriv = torch.autograd.grad(tot_energy, x, create_graph=_create_graph)
            if forces:
                # NOTE(review): if the model already produced "forces" (so coord
                # was not appended to x) while stress is requested, deriv[0] is
                # the scaling gradient and the indexing below looks wrong —
                # verify this path is never taken with such models.
                data["forces"] = -deriv[0]
            if stress:
                dedc = deriv[0] if not forces else deriv[1]
                # stress = dE/d(strain) / volume; |det(cell)| is the cell volume.
                data["stress"] = dedc / data["cell"].detach().det().abs()
        if hessian:
            data["hessian"] = self.calculate_hessian(data["forces"], self._saved_for_grad["coord"])
        return data

    @staticmethod
    def calculate_hessian(forces: Tensor, coord: Tensor) -> Tensor:
        # here forces have shape (N, 3) and coord has shape (N+1, 3)
        # return hessian with shape (N, 3, N, 3)
        # One grad call per force component; the trailing [:-1] slices drop the
        # dummy padding atom from both axes.
        hessian = -torch.stack([
            torch.autograd.grad(_f, coord, retain_graph=True)[0] for _f in forces.flatten().unbind()
        ]).view(-1, 3, coord.shape[0], 3)[:-1, :, :-1, :]
        return hessian
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def maybe_pad_dim0(a: Tensor, N: int, value=0.0) -> Tensor:
    """Pad `a` with one constant row along dim 0 when it is exactly one short of N.

    Raises AssertionError when a.shape[0] differs from N by anything other than 0 or 1.
    """
    missing = N - a.shape[0]
    assert missing in (0, 1), "Invalid shape"
    return pad_dim0(a, value=value) if missing == 1 else a
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def pad_dim0(a: Tensor, value=0.0) -> Tensor:
    """Append a single constant-valued entry along the first dimension of `a`."""
    # F.pad's spec is ordered last-dim-first: pad nothing on every trailing
    # dimension, then one element at the end of dim 0.
    spec = [0, 0] * (a.ndim - 1) + [0, 1]
    return torch.nn.functional.pad(a, spec, mode="constant", value=value)
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def maybe_unpad_dim0(a: Tensor, N: int) -> Tensor:
    """Drop the trailing dim-0 row when `a` is exactly one longer than N.

    Raises AssertionError when a.shape[0] differs from N by anything other than 0 or 1.
    """
    extra = a.shape[0] - N
    assert extra in (0, 1), "Invalid shape"
    return a[:-1] if extra == 1 else a
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def move_coord_to_cell(coord, cell):
    """Wrap Cartesian coordinates back into the periodic cell.

    Converts to fractional coordinates, takes them modulo 1, and converts back.
    """
    frac = (coord @ cell.inverse()) % 1
    return frac @ cell
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _named_children_rec(module):
|
|
307
|
+
if isinstance(module, torch.nn.Module):
|
|
308
|
+
for name, child in module.named_children():
|
|
309
|
+
yield name, child
|
|
310
|
+
yield from _named_children_rec(child)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def iter_lrcoulomb_mods(model):
    """Yield every submodule of `model` registered under the name "lrcoulomb"."""
    yield from (mod for name, mod in _named_children_rec(model) if name == "lrcoulomb")
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def calc_max_nb(cutoff: float, density: float = 0.2) -> int | float:
|
|
320
|
+
return int(density * 4 / 3 * 3.14159 * cutoff**3) if cutoff < float("inf") else float("inf")
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from typing import Dict, Optional
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
import requests
|
|
7
|
+
import yaml
|
|
8
|
+
|
|
9
|
+
logging.basicConfig(level=logging.INFO)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def load_model_registry(registry_file: Optional[str] = None) -> Dict[str, str]:
    """Load the model registry mapping from a YAML file.

    Args:
        registry_file: path to a registry YAML; defaults to the packaged
            ``model_registry.yaml`` next to this module.

    Bug fix: the original computed `registry_file` and then ignored it,
    always opening the hard-coded packaged path.
    """
    registry_file = registry_file or os.path.join(os.path.dirname(__file__), "model_registry.yaml")
    with open(registry_file) as f:
        return yaml.load(f, Loader=yaml.SafeLoader)
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def create_assets_dir():
    """Ensure the package-local ``assets`` directory exists (idempotent)."""
    assets_dir = os.path.join(os.path.dirname(__file__), "assets")
    os.makedirs(assets_dir, exist_ok=True)
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_registry_model_path(model_name: str) -> str:
    """Resolve a registry model name (or alias) to a local file path,
    downloading the asset first if it is not cached.

    Raises ValueError for names absent from the registry.
    """
    model_registry = load_model_registry()
    create_assets_dir()
    # Aliases map short names onto canonical registry entries.
    model_name = model_registry["aliases"].get(model_name, model_name)  # type: ignore
    if model_name not in model_registry["models"]:
        raise ValueError(f"Model {model_name} not found in the registry.")
    cfg = model_registry["models"][model_name]  # type: ignore
    return _maybe_download_asset(**cfg)  # type: ignore
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _maybe_download_asset(file: str, url: str) -> str:
    """Return the local path of `file` in the assets dir, downloading from
    `url` on first use.

    Bug fixes: the destination used to be opened for writing BEFORE the HTTP
    request, so a failed download left a truncated/empty file that later calls
    mistook for a cached model; HTTP error statuses were silently written to
    disk; the log line printed a literal "(unknown)" instead of the path.
    """
    filename = os.path.join(os.path.dirname(__file__), "assets", file)
    if not os.path.exists(filename):
        print(f"Downloading {url} -> {filename}")
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        # Write only after a successful download so no partial file persists.
        with open(filename, "wb") as f:
            f.write(response.content)
    return filename
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_model_path(s: str) -> str:
    """Return a usable model file path.

    If `s` is an existing file it is returned as-is; otherwise it is treated
    as a registry model name and resolved (downloading if necessary).
    """
    if not os.path.isfile(s):
        return get_registry_model_path(s)
    print("Found model file:", s)
    return s
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@click.command(short_help="Clear assets directory.")
def clear_assets():
    """Delete every downloaded model file from the package ``assets`` directory."""
    from glob import glob

    for fil in glob(os.path.join(os.path.dirname(__file__), "assets", "*")):
        if os.path.isfile(fil):
            # logging.warn is a deprecated alias; logging.warning is canonical.
            logging.warning(f"Removing {fil}")
            os.remove(fil)
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# map file name to url
|
|
2
|
+
models:
|
|
3
|
+
aimnet2_wb97m_d3_0:
|
|
4
|
+
file: aimnet2_wb97m_d3_0.jpt
|
|
5
|
+
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_0.jpt
|
|
6
|
+
aimnet2_wb97m_d3_1:
|
|
7
|
+
file: aimnet2_wb97m_d3_1.jpt
|
|
8
|
+
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_1.jpt
|
|
9
|
+
aimnet2_wb97m_d3_2:
|
|
10
|
+
file: aimnet2_wb97m_d3_2.jpt
|
|
11
|
+
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_2.jpt
|
|
12
|
+
aimnet2_wb97m_d3_3:
|
|
13
|
+
file: aimnet2_wb97m_d3_3.jpt
|
|
14
|
+
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_3.jpt
|
|
15
|
+
aimnet2_b973c_d3_0:
|
|
16
|
+
file: aimnet2_b973c_d3_0.jpt
|
|
17
|
+
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_0.jpt
|
|
18
|
+
aimnet2_b973c_d3_1:
|
|
19
|
+
file: aimnet2_b973c_d3_1.jpt
|
|
20
|
+
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_1.jpt
|
|
21
|
+
aimnet2_b973c_d3_2:
|
|
22
|
+
file: aimnet2_b973c_d3_2.jpt
|
|
23
|
+
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_2.jpt
|
|
24
|
+
aimnet2_b973c_d3_3:
|
|
25
|
+
file: aimnet2_b973c_d3_3.jpt
|
|
26
|
+
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_3.jpt
|
|
27
|
+
|
|
28
|
+
# map model alias to file name
|
|
29
|
+
aliases:
|
|
30
|
+
aimnet2: aimnet2_wb97m_d3_0
|
|
31
|
+
aimnet2_wb97m: aimnet2_wb97m_d3_0
|
|
32
|
+
aimnet2_b973c: aimnet2_b973c_d3_0
|
|
33
|
+
aimnet2_qr: aimnet2_qr_v0
|