aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects each version's contents as they appear in the respective public registry.
aimnet/__init__.py CHANGED
@@ -0,0 +1,7 @@
1
+ # Version is managed by hatch-vcs from git tags
2
+ try:
3
+ from importlib.metadata import version
4
+
5
+ __version__ = version("aimnet")
6
+ except Exception:
7
+ __version__ = "0.0.0+unknown"
aimnet/base.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import ClassVar, Dict, Final
1
+ from typing import ClassVar, Final
2
2
 
3
3
  import torch
4
4
  from torch import Tensor, nn
@@ -6,30 +6,46 @@ from torch import Tensor, nn
6
6
  from aimnet import nbops
7
7
 
8
8
 
9
- class AIMNet2Base(nn.Module): # pylint: disable=abstract-method
9
+ class AIMNet2Base(nn.Module):
10
10
  """Base class for AIMNet2 models. Implements pre-processing data:
11
11
  converting to right dtype and device, setting nb mode, calculating masks.
12
12
  """
13
13
 
14
14
  _required_keys: Final = ["coord", "numbers", "charge"]
15
15
  _required_keys_dtype: Final = [torch.float32, torch.int64, torch.float32]
16
- _optional_keys: Final = ["mult", "nbmat", "nbmat_lr", "mol_idx", "shifts", "cell"]
17
- _optional_keys_dtype: Final = [torch.float32, torch.int64, torch.int64, torch.int64, torch.float32, torch.float32]
16
+ _optional_keys: Final = [
17
+ "mult",
18
+ "nbmat",
19
+ "nbmat_lr",
20
+ "mol_idx",
21
+ "shifts",
22
+ "shifts_lr",
23
+ "cell",
24
+ ]
25
+ _optional_keys_dtype: Final = [
26
+ torch.float32, # mult
27
+ torch.int64, # nbmat
28
+ torch.int64, # nbmat_lr
29
+ torch.int64, # mol_idx
30
+ torch.float32, # shifts
31
+ torch.float32, # shifts_lr
32
+ torch.float32, # cell
33
+ ]
18
34
  __constants__: ClassVar = ["_required_keys", "_required_keys_dtype", "_optional_keys", "_optional_keys_dtype"]
19
35
 
20
36
  def __init__(self):
21
37
  super().__init__()
22
38
 
23
- def _prepare_dtype(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
24
- for k, d in zip(self._required_keys, self._required_keys_dtype):
39
+ def _prepare_dtype(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
40
+ for k, d in zip(self._required_keys, self._required_keys_dtype, strict=False):
25
41
  assert k in data, f"Key {k} is required"
26
42
  data[k] = data[k].to(d)
27
- for k, d in zip(self._optional_keys, self._optional_keys_dtype):
43
+ for k, d in zip(self._optional_keys, self._optional_keys_dtype, strict=False):
28
44
  if k in data:
29
45
  data[k] = data[k].to(d)
30
46
  return data
31
47
 
32
- def prepare_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
48
+ def prepare_input(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
33
49
  """Some sommon operations"""
34
50
  data = self._prepare_dtype(data)
35
51
  data = nbops.set_nb_mode(data)
@@ -5,11 +5,11 @@ from .calculator import AIMNet2Calculator
5
5
  __all__ = ["AIMNet2Calculator"]
6
6
 
7
7
  if importlib.util.find_spec("ase") is not None:
8
- from .aimnet2ase import AIMNet2ASE
8
+ from .aimnet2ase import AIMNet2ASE # noqa: F401
9
9
 
10
- __all__.append(AIMNet2ASE) # type: ignore
10
+ __all__.append("AIMNet2ASE")
11
11
 
12
12
  if importlib.util.find_spec("pysisyphus") is not None:
13
- from .aimnet2pysis import AIMNet2Pysis
13
+ from .aimnet2pysis import AIMNet2Pysis # noqa: F401
14
14
 
15
- __all__.append(AIMNet2Pysis) # type: ignore
15
+ __all__.append("AIMNet2Pysis")
@@ -12,16 +12,24 @@ from .calculator import AIMNet2Calculator
12
12
  class AIMNet2ASE(Calculator):
13
13
  from typing import ClassVar
14
14
 
15
- implemented_properties: ClassVar[list[str]] = ["energy", "forces", "free_energy", "charges", "stress"]
15
+ implemented_properties: ClassVar[list[str]] = [
16
+ "energy",
17
+ "forces",
18
+ "free_energy",
19
+ "charges",
20
+ "stress",
21
+ "dipole_moment",
22
+ ]
16
23
 
17
24
  def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1):
18
25
  super().__init__()
19
26
  if isinstance(base_calc, str):
20
27
  base_calc = AIMNet2Calculator(base_calc)
21
28
  self.base_calc = base_calc
29
+ self.reset()
22
30
  self.charge = charge
23
31
  self.mult = mult
24
- self.reset()
32
+ self.update_tensors()
25
33
  # list of implemented species
26
34
  if hasattr(base_calc, "implemented_species"):
27
35
  self.implemented_species = base_calc.implemented_species.cpu().numpy() # type: ignore
@@ -33,8 +41,6 @@ class AIMNet2ASE(Calculator):
33
41
  self._t_numbers = None
34
42
  self._t_charge = None
35
43
  self._t_mult = None
36
- self.charge = 0.0
37
- self.mult = 1.0
38
44
 
39
45
  def set_atoms(self, atoms):
40
46
  if self.implemented_species is not None and not np.in1d(atoms.numbers, self.implemented_species).all():
@@ -53,13 +59,18 @@ class AIMNet2ASE(Calculator):
53
59
  self.update_tensors()
54
60
 
55
61
  def update_tensors(self):
56
- if self._t_numbers is None:
62
+ if self._t_numbers is None and getattr(self, "atoms", None):
57
63
  self._t_numbers = torch.tensor(self.atoms.numbers, dtype=torch.int64, device=self.base_calc.device)
58
64
  if self._t_charge is None:
59
65
  self._t_charge = torch.tensor(self.charge, dtype=torch.float32, device=self.base_calc.device)
60
66
  if self._t_mult is None:
61
67
  self._t_mult = torch.tensor(self.mult, dtype=torch.float32, device=self.base_calc.device)
62
68
 
69
+ def get_dipole_moment(self, atoms):
70
+ charges = self.get_charges()[:, np.newaxis]
71
+ positions = atoms.get_positions()
72
+ return np.sum(charges * positions, axis=0)
73
+
63
74
  def calculate(self, atoms=None, properties=None, system_changes=all_changes):
64
75
  if properties is None:
65
76
  properties = ["energy"]
@@ -90,8 +101,10 @@ class AIMNet2ASE(Calculator):
90
101
  v = v.squeeze(0)
91
102
  results[k] = v.detach().cpu().numpy() # type: ignore
92
103
 
93
- self.results["energy"] = results["energy"]
104
+ self.results["energy"] = results["energy"].item()
94
105
  self.results["charges"] = results["charges"]
106
+ self.results["dipole_moment"] = self.get_dipole_moment(self.atoms)
107
+
95
108
  if "forces" in properties:
96
109
  self.results["forces"] = results["forces"]
97
110
  if "stress" in properties: