aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +7 -0
- aimnet/base.py +24 -8
- aimnet/calculators/__init__.py +4 -4
- aimnet/calculators/aimnet2ase.py +19 -6
- aimnet/calculators/calculator.py +868 -108
- aimnet/calculators/model_registry.py +2 -5
- aimnet/calculators/model_registry.yaml +55 -17
- aimnet/cli.py +62 -6
- aimnet/config.py +8 -9
- aimnet/data/sgdataset.py +23 -22
- aimnet/kernels/__init__.py +66 -0
- aimnet/kernels/conv_sv_2d_sp_wp.py +478 -0
- aimnet/models/__init__.py +13 -1
- aimnet/models/aimnet2.py +19 -22
- aimnet/models/base.py +183 -15
- aimnet/models/convert.py +30 -0
- aimnet/models/utils.py +735 -0
- aimnet/modules/__init__.py +1 -1
- aimnet/modules/aev.py +49 -48
- aimnet/modules/core.py +14 -13
- aimnet/modules/lr.py +520 -115
- aimnet/modules/ops.py +537 -0
- aimnet/nbops.py +105 -15
- aimnet/ops.py +90 -18
- aimnet/train/export_model.py +226 -0
- aimnet/train/loss.py +7 -7
- aimnet/train/metrics.py +5 -6
- aimnet/train/train.py +4 -1
- aimnet/train/utils.py +42 -13
- aimnet-0.1.0.dist-info/METADATA +308 -0
- aimnet-0.1.0.dist-info/RECORD +43 -0
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info}/WHEEL +1 -1
- aimnet-0.1.0.dist-info/entry_points.txt +3 -0
- aimnet/calculators/nb_kernel_cpu.py +0 -222
- aimnet/calculators/nb_kernel_cuda.py +0 -217
- aimnet/calculators/nbmat.py +0 -220
- aimnet/train/pt2jpt.py +0 -81
- aimnet-0.0.1.dist-info/METADATA +0 -78
- aimnet-0.0.1.dist-info/RECORD +0 -41
- aimnet-0.0.1.dist-info/entry_points.txt +0 -5
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info/licenses}/LICENSE +0 -0
aimnet/__init__.py
CHANGED
aimnet/base.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import ClassVar,
|
|
1
|
+
from typing import ClassVar, Final
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
from torch import Tensor, nn
|
|
@@ -6,30 +6,46 @@ from torch import Tensor, nn
|
|
|
6
6
|
from aimnet import nbops
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
class AIMNet2Base(nn.Module):
|
|
9
|
+
class AIMNet2Base(nn.Module):
|
|
10
10
|
"""Base class for AIMNet2 models. Implements pre-processing data:
|
|
11
11
|
converting to right dtype and device, setting nb mode, calculating masks.
|
|
12
12
|
"""
|
|
13
13
|
|
|
14
14
|
_required_keys: Final = ["coord", "numbers", "charge"]
|
|
15
15
|
_required_keys_dtype: Final = [torch.float32, torch.int64, torch.float32]
|
|
16
|
-
_optional_keys: Final = [
|
|
17
|
-
|
|
16
|
+
_optional_keys: Final = [
|
|
17
|
+
"mult",
|
|
18
|
+
"nbmat",
|
|
19
|
+
"nbmat_lr",
|
|
20
|
+
"mol_idx",
|
|
21
|
+
"shifts",
|
|
22
|
+
"shifts_lr",
|
|
23
|
+
"cell",
|
|
24
|
+
]
|
|
25
|
+
_optional_keys_dtype: Final = [
|
|
26
|
+
torch.float32, # mult
|
|
27
|
+
torch.int64, # nbmat
|
|
28
|
+
torch.int64, # nbmat_lr
|
|
29
|
+
torch.int64, # mol_idx
|
|
30
|
+
torch.float32, # shifts
|
|
31
|
+
torch.float32, # shifts_lr
|
|
32
|
+
torch.float32, # cell
|
|
33
|
+
]
|
|
18
34
|
__constants__: ClassVar = ["_required_keys", "_required_keys_dtype", "_optional_keys", "_optional_keys_dtype"]
|
|
19
35
|
|
|
20
36
|
def __init__(self):
|
|
21
37
|
super().__init__()
|
|
22
38
|
|
|
23
|
-
def _prepare_dtype(self, data:
|
|
24
|
-
for k, d in zip(self._required_keys, self._required_keys_dtype):
|
|
39
|
+
def _prepare_dtype(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
40
|
+
for k, d in zip(self._required_keys, self._required_keys_dtype, strict=False):
|
|
25
41
|
assert k in data, f"Key {k} is required"
|
|
26
42
|
data[k] = data[k].to(d)
|
|
27
|
-
for k, d in zip(self._optional_keys, self._optional_keys_dtype):
|
|
43
|
+
for k, d in zip(self._optional_keys, self._optional_keys_dtype, strict=False):
|
|
28
44
|
if k in data:
|
|
29
45
|
data[k] = data[k].to(d)
|
|
30
46
|
return data
|
|
31
47
|
|
|
32
|
-
def prepare_input(self, data:
|
|
48
|
+
def prepare_input(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
33
49
|
"""Some sommon operations"""
|
|
34
50
|
data = self._prepare_dtype(data)
|
|
35
51
|
data = nbops.set_nb_mode(data)
|
aimnet/calculators/__init__.py
CHANGED
|
@@ -5,11 +5,11 @@ from .calculator import AIMNet2Calculator
|
|
|
5
5
|
__all__ = ["AIMNet2Calculator"]
|
|
6
6
|
|
|
7
7
|
if importlib.util.find_spec("ase") is not None:
|
|
8
|
-
from .aimnet2ase import AIMNet2ASE
|
|
8
|
+
from .aimnet2ase import AIMNet2ASE # noqa: F401
|
|
9
9
|
|
|
10
|
-
__all__.append(AIMNet2ASE)
|
|
10
|
+
__all__.append("AIMNet2ASE")
|
|
11
11
|
|
|
12
12
|
if importlib.util.find_spec("pysisyphus") is not None:
|
|
13
|
-
from .aimnet2pysis import AIMNet2Pysis
|
|
13
|
+
from .aimnet2pysis import AIMNet2Pysis # noqa: F401
|
|
14
14
|
|
|
15
|
-
__all__.append(AIMNet2Pysis)
|
|
15
|
+
__all__.append("AIMNet2Pysis")
|
aimnet/calculators/aimnet2ase.py
CHANGED
|
@@ -12,16 +12,24 @@ from .calculator import AIMNet2Calculator
|
|
|
12
12
|
class AIMNet2ASE(Calculator):
|
|
13
13
|
from typing import ClassVar
|
|
14
14
|
|
|
15
|
-
implemented_properties: ClassVar[list[str]] = [
|
|
15
|
+
implemented_properties: ClassVar[list[str]] = [
|
|
16
|
+
"energy",
|
|
17
|
+
"forces",
|
|
18
|
+
"free_energy",
|
|
19
|
+
"charges",
|
|
20
|
+
"stress",
|
|
21
|
+
"dipole_moment",
|
|
22
|
+
]
|
|
16
23
|
|
|
17
24
|
def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1):
|
|
18
25
|
super().__init__()
|
|
19
26
|
if isinstance(base_calc, str):
|
|
20
27
|
base_calc = AIMNet2Calculator(base_calc)
|
|
21
28
|
self.base_calc = base_calc
|
|
29
|
+
self.reset()
|
|
22
30
|
self.charge = charge
|
|
23
31
|
self.mult = mult
|
|
24
|
-
self.
|
|
32
|
+
self.update_tensors()
|
|
25
33
|
# list of implemented species
|
|
26
34
|
if hasattr(base_calc, "implemented_species"):
|
|
27
35
|
self.implemented_species = base_calc.implemented_species.cpu().numpy() # type: ignore
|
|
@@ -33,8 +41,6 @@ class AIMNet2ASE(Calculator):
|
|
|
33
41
|
self._t_numbers = None
|
|
34
42
|
self._t_charge = None
|
|
35
43
|
self._t_mult = None
|
|
36
|
-
self.charge = 0.0
|
|
37
|
-
self.mult = 1.0
|
|
38
44
|
|
|
39
45
|
def set_atoms(self, atoms):
|
|
40
46
|
if self.implemented_species is not None and not np.in1d(atoms.numbers, self.implemented_species).all():
|
|
@@ -53,13 +59,18 @@ class AIMNet2ASE(Calculator):
|
|
|
53
59
|
self.update_tensors()
|
|
54
60
|
|
|
55
61
|
def update_tensors(self):
|
|
56
|
-
if self._t_numbers is None:
|
|
62
|
+
if self._t_numbers is None and getattr(self, "atoms", None):
|
|
57
63
|
self._t_numbers = torch.tensor(self.atoms.numbers, dtype=torch.int64, device=self.base_calc.device)
|
|
58
64
|
if self._t_charge is None:
|
|
59
65
|
self._t_charge = torch.tensor(self.charge, dtype=torch.float32, device=self.base_calc.device)
|
|
60
66
|
if self._t_mult is None:
|
|
61
67
|
self._t_mult = torch.tensor(self.mult, dtype=torch.float32, device=self.base_calc.device)
|
|
62
68
|
|
|
69
|
+
def get_dipole_moment(self, atoms):
|
|
70
|
+
charges = self.get_charges()[:, np.newaxis]
|
|
71
|
+
positions = atoms.get_positions()
|
|
72
|
+
return np.sum(charges * positions, axis=0)
|
|
73
|
+
|
|
63
74
|
def calculate(self, atoms=None, properties=None, system_changes=all_changes):
|
|
64
75
|
if properties is None:
|
|
65
76
|
properties = ["energy"]
|
|
@@ -90,8 +101,10 @@ class AIMNet2ASE(Calculator):
|
|
|
90
101
|
v = v.squeeze(0)
|
|
91
102
|
results[k] = v.detach().cpu().numpy() # type: ignore
|
|
92
103
|
|
|
93
|
-
self.results["energy"] = results["energy"]
|
|
104
|
+
self.results["energy"] = results["energy"].item()
|
|
94
105
|
self.results["charges"] = results["charges"]
|
|
106
|
+
self.results["dipole_moment"] = self.get_dipole_moment(self.atoms)
|
|
107
|
+
|
|
95
108
|
if "forces" in properties:
|
|
96
109
|
self.results["forces"] = results["forces"]
|
|
97
110
|
if "stress" in properties:
|