aimnet 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimnet/__init__.py ADDED
File without changes
aimnet/base.py ADDED
@@ -0,0 +1,41 @@
1
+ from typing import ClassVar, Dict, Final
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from aimnet import nbops
7
+
8
+
9
class AIMNet2Base(nn.Module):  # pylint: disable=abstract-method
    """Base class for AIMNet2 models. Implements pre-processing data:
    converting to right dtype and device, setting nb mode, calculating masks.
    """

    # Keys every input dict must contain, paired element-wise with their
    # target dtypes (consumed via zip in _prepare_dtype).
    _required_keys: Final = ["coord", "numbers", "charge"]
    _required_keys_dtype: Final = [torch.float32, torch.int64, torch.float32]
    # Keys that may be present; cast to the matching dtype only when given.
    _optional_keys: Final = ["mult", "nbmat", "nbmat_lr", "mol_idx", "shifts", "cell"]
    _optional_keys_dtype: Final = [torch.float32, torch.int64, torch.int64, torch.int64, torch.float32, torch.float32]
    # Declared as TorchScript constants so the lists are usable after torch.jit.script.
    __constants__: ClassVar = ["_required_keys", "_required_keys_dtype", "_optional_keys", "_optional_keys_dtype"]

    def __init__(self):
        super().__init__()

    def _prepare_dtype(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Cast required and (present) optional tensors to their declared dtypes.

        Mutates and returns *data*. Raises AssertionError if a required key is missing.
        """
        for k, d in zip(self._required_keys, self._required_keys_dtype):
            assert k in data, f"Key {k} is required"
            data[k] = data[k].to(d)
        for k, d in zip(self._optional_keys, self._optional_keys_dtype):
            if k in data:
                data[k] = data[k].to(d)
        return data

    def prepare_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Some common operations: dtype casting, nb mode selection and mask
        calculation (delegated to nbops), plus shape sanity checks."""
        data = self._prepare_dtype(data)
        data = nbops.set_nb_mode(data)
        data = nbops.calc_masks(data)

        assert data["charge"].ndim == 1, "Charge should be 1D tensor."
        if "mult" in data:
            assert data["mult"].ndim == 1, "Mult should be 1D tensor."
        return data
@@ -0,0 +1,15 @@
1
import importlib.util

from .calculator import AIMNet2Calculator

__all__ = ["AIMNet2Calculator"]

# Optional integrations: exported only when the third-party package is importable.
if importlib.util.find_spec("ase") is not None:
    from .aimnet2ase import AIMNet2ASE

    # Bug fix: __all__ must list symbol *names* (strings); the original
    # appended the class object itself, which breaks `from ... import *`
    # consumers and tooling that validates __all__.
    __all__.append("AIMNet2ASE")

if importlib.util.find_spec("pysisyphus") is not None:
    from .aimnet2pysis import AIMNet2Pysis

    __all__.append("AIMNet2Pysis")
@@ -0,0 +1,98 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ try:
5
+ from ase.calculators.calculator import Calculator, all_changes # type: ignore
6
+ except ImportError:
7
+ raise ImportError("ASE is not installed. Please install ASE to use this module.") from None
8
+
9
+ from .calculator import AIMNet2Calculator
10
+
11
+
12
class AIMNet2ASE(Calculator):
    """ASE Calculator interface for AIMNet2 models.

    Wraps an :class:`AIMNet2Calculator` (or a registry model name) and exposes
    energy, forces, atomic charges and stress through the standard ASE
    calculator protocol.
    """

    # NOTE: kept as an in-class import to preserve the original module layout.
    from typing import ClassVar

    implemented_properties: ClassVar[list[str]] = ["energy", "forces", "free_energy", "charges", "stress"]

    def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1):
        """Create the calculator.

        Args:
            base_calc: An AIMNet2Calculator instance or a registry model name/alias.
            charge: Total molecular charge.
            mult: Spin multiplicity.
        """
        super().__init__()
        if isinstance(base_calc, str):
            base_calc = AIMNet2Calculator(base_calc)
        self.base_calc = base_calc
        # Bug fix: reset() restores default charge/mult, so it must run BEFORE
        # the user-supplied values are stored. The original called reset() last
        # and silently discarded the constructor's charge/mult arguments.
        self.reset()
        self.charge = charge
        self.mult = mult
        # list of implemented species
        if hasattr(base_calc, "implemented_species"):
            self.implemented_species = base_calc.implemented_species.cpu().numpy()  # type: ignore
        else:
            self.implemented_species = None

    def reset(self):
        """Clear cached device tensors and restore default charge/multiplicity."""
        super().reset()
        self._t_numbers = None
        self._t_charge = None
        self._t_mult = None
        self.charge = 0.0
        self.mult = 1.0

    def set_atoms(self, atoms):
        """Attach an Atoms object, validating that all species are supported."""
        # np.isin replaces np.in1d, which is deprecated and removed in NumPy 2.0.
        if self.implemented_species is not None and not np.isin(atoms.numbers, self.implemented_species).all():
            raise ValueError("Some species are not implemented in the AIMNet2Calculator")
        self.reset()
        self.atoms = atoms

    def set_charge(self, charge):
        """Set total charge and invalidate its cached tensor."""
        self.charge = charge
        self._t_charge = None
        self.update_tensors()

    def set_mult(self, mult):
        """Set spin multiplicity and invalidate its cached tensor."""
        self.mult = mult
        self._t_mult = None
        self.update_tensors()

    def update_tensors(self):
        """Lazily (re)build the device tensors for numbers, charge and mult."""
        if self._t_numbers is None:
            self._t_numbers = torch.tensor(self.atoms.numbers, dtype=torch.int64, device=self.base_calc.device)
        if self._t_charge is None:
            self._t_charge = torch.tensor(self.charge, dtype=torch.float32, device=self.base_calc.device)
        if self._t_mult is None:
            self._t_mult = torch.tensor(self.mult, dtype=torch.float32, device=self.base_calc.device)

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """Run the AIMNet2 model and populate ``self.results``."""
        if properties is None:
            properties = ["energy"]
        super().calculate(atoms, properties, system_changes)
        self.update_tensors()

        # Use the cell only when PBC is actually enabled on the Atoms object.
        cell = self.atoms.cell.array if self.atoms.cell is not None and self.atoms.pbc.any() else None

        _in = {
            "coord": torch.tensor(self.atoms.positions, dtype=torch.float32, device=self.base_calc.device),
            "numbers": self._t_numbers,
            "charge": self._t_charge,
            "mult": self._t_mult,
        }

        _unsqueezed = False
        if cell is not None:
            _in["cell"] = cell
        else:
            # Non-periodic input: add a batch dimension of 1 for the model.
            for k, v in _in.items():
                _in[k] = v.unsqueeze(0)
            _unsqueezed = True

        results = self.base_calc(_in, forces="forces" in properties, stress="stress" in properties)

        for k, v in results.items():
            if _unsqueezed:
                v = v.squeeze(0)
            results[k] = v.detach().cpu().numpy()  # type: ignore

        self.results["energy"] = results["energy"]
        self.results["charges"] = results["charges"]
        if "forces" in properties:
            self.results["forces"] = results["forces"]
        if "stress" in properties:
            self.results["stress"] = results["stress"]
@@ -0,0 +1,76 @@
1
+ from typing import ClassVar
2
+
3
+ import torch
4
+
5
+ from .calculator import AIMNet2Calculator
6
+
7
+ try:
8
+ import pysisyphus.run # type: ignore
9
+ from pysisyphus.calculators.Calculator import Calculator # type: ignore
10
+ from pysisyphus.constants import ANG2BOHR, AU2EV, BOHR2ANG # type: ignore
11
+ from pysisyphus.elem_data import ATOMIC_NUMBERS # type: ignore
12
+ except ImportError:
13
+ raise ImportError("Pysisyphus is not installed. Please install Pysisyphus to use this module.") from None
14
+
15
+ EV2AU = 1 / AU2EV
16
+
17
+
18
class AIMNet2Pysis(Calculator):
    """Pysisyphus calculator adapter around an AIMNet2 model.

    Converts pysisyphus atomic units (Bohr, Hartree) to/from the model's
    native Angstrom/eV units.
    """

    implemented_properties: ClassVar = ["energy", "forces", "free_energy", "charges", "stress"]

    def __init__(self, model: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1, **kwargs):
        super().__init__(charge=charge, mult=mult, **kwargs)
        self.model = AIMNet2Calculator(model) if isinstance(model, str) else model

    def _prepere_input(self, atoms, coord):
        # Assemble the model input dict; pysisyphus supplies Bohr, AIMNet2 expects Angstrom.
        dev = self.model.device
        return {
            "coord": torch.as_tensor(coord, dtype=torch.float, device=dev).view(-1, 3) * BOHR2ANG,
            "numbers": torch.as_tensor([ATOMIC_NUMBERS[sym.lower()] for sym in atoms], device=dev),
            "charge": torch.as_tensor([self.charge], dtype=torch.float, device=dev),
            "mult": torch.as_tensor([self.mult], dtype=torch.float, device=dev),
        }

    @staticmethod
    def _results_get_energy(results):
        # eV -> Hartree
        return results["energy"].item() * EV2AU

    @staticmethod
    def _results_get_forces(results):
        # eV/Angstrom -> Hartree/Bohr, flattened double-precision numpy array
        scaled = results["forces"].detach() * (EV2AU / ANG2BOHR)
        return scaled.flatten().to(torch.double).cpu().numpy()

    @staticmethod
    def _results_get_hessian(results):
        # eV/Angstrom^2 -> Hartree/Bohr^2, merged into a square matrix
        h = results["hessian"].flatten(0, 1).flatten(-2, -1) * (EV2AU / ANG2BOHR / ANG2BOHR)
        return h.to(torch.double).cpu().numpy()

    def get_energy(self, atoms, coords):
        res = self.model(self._prepere_input(atoms, coords))
        return {"energy": self._results_get_energy(res)}

    def get_forces(self, atoms, coords):
        res = self.model(self._prepere_input(atoms, coords), forces=True)
        return {
            "energy": self._results_get_energy(res),
            "forces": self._results_get_forces(res),
        }

    def get_hessian(self, atoms, coords):
        res = self.model(self._prepere_input(atoms, coords), forces=True, hessian=True)
        return {
            "energy": self._results_get_energy(res),
            "forces": self._results_get_forces(res),
            "hessian": self._results_get_hessian(res),
        }
72
+
73
+
74
def run_pysis():
    """CLI entry point: register AIMNet2Pysis under the "aimnet" key and
    hand control to the pysisyphus runner."""
    pysisyphus.run.CALC_DICT["aimnet"] = AIMNet2Pysis
    pysisyphus.run.run()
@@ -0,0 +1,320 @@
1
+ import warnings
2
+ from typing import Any, ClassVar, Dict, Literal
3
+
4
+ import torch
5
+ from torch import Tensor, nn
6
+
7
+ from .model_registry import get_model_path
8
+ from .nbmat import TooManyNeighborsError, calc_nbmat
9
+
10
+
11
class AIMNet2Calculator:
    """Generic AIMNet2 calculator
    A helper class to load AIMNet2 models and perform inference.
    """

    # Required input keys and the dtypes they are cast to.
    keys_in: ClassVar[Dict[str, torch.dtype]] = {"coord": torch.float, "numbers": torch.int, "charge": torch.float}
    # Optional input keys; converted only when present and not None.
    keys_in_optional: ClassVar[Dict[str, torch.dtype]] = {
        "mult": torch.float,
        "mol_idx": torch.int,
        "nbmat": torch.int,
        "nbmat_lr": torch.int,
        "nb_pad_mask": torch.bool,
        "nb_pad_mask_lr": torch.bool,
        "shifts": torch.float,
        "shifts_lr": torch.float,
        "cell": torch.float,
    }
    # Keys retained in the returned dict (plus their "*_std" variants).
    keys_out: ClassVar[list[str]] = ["energy", "charges", "forces", "hessian", "stress"]
    # Per-atom quantities that get flattened/unflattened and padded/unpadded.
    atom_feature_keys: ClassVar[list[str]] = ["coord", "numbers", "charges", "forces"]

    def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320):
        """Load a model (TorchScript file resolved via the registry, or an
        nn.Module) onto CUDA when available, otherwise CPU.

        Args:
            model: Registry name/alias, path to a jitted model, or an nn.Module.
            nb_threshold: Molecule-size threshold controlling whether batched
                3D input is flattened for neighbor-matrix processing.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(model, str):
            p = get_model_path(model)
            self.model = torch.jit.load(p, map_location=self.device)
        elif isinstance(model, nn.Module):
            self.model = model.to(self.device)
        else:
            raise TypeError("Invalid model type/name.")

        self.cutoff = self.model.cutoff
        # Long-range part is active only if the model defines cutoff_lr.
        self.lr = hasattr(self.model, "cutoff_lr")
        self.cutoff_lr = getattr(self.model, "cutoff_lr", float("inf"))
        # Initial neighbor-density guess; grown adaptively in make_nbmat.
        self.max_density = 0.2
        self.nb_threshold = nb_threshold

        # indicator if input was flattened
        self._batch = None
        self._max_mol_size: int = 0
        # placeholder for tensors that require grad
        self._saved_for_grad = {}
        # set flag of current Coulomb model
        coul_methods = {getattr(mod, "method", None) for mod in iter_lrcoulomb_mods(self.model)}
        if len(coul_methods) > 1:
            raise ValueError("Multiple Coulomb modules found.")
        if len(coul_methods):
            self._coulomb_method = coul_methods.pop()
        else:
            self._coulomb_method = None

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`eval`."""
        return self.eval(*args, **kwargs)

    def set_lrcoulomb_method(
        self, method: Literal["simple", "dsf", "ewald"], cutoff: float = 15.0, dsf_alpha: float = 0.2
    ):
        """Switch the long-range Coulomb treatment on every lrcoulomb submodule.

        Also updates ``self.cutoff_lr`` ("simple" means no long-range cutoff).
        """
        if method not in ("simple", "dsf", "ewald"):
            raise ValueError(f"Invalid method: {method}")
        for mod in iter_lrcoulomb_mods(self.model):
            mod.method = method  # type: ignore
            if method == "simple":
                self.cutoff_lr = float("inf")
            elif method == "dsf":
                self.cutoff_lr = cutoff
                mod.dsf_alpha = dsf_alpha  # type: ignore
                mod.dsf_rc = cutoff  # type: ignore
            elif method == "ewald":
                # current implementaion of Ewald does not use nb mat
                self.cutoff_lr = cutoff
        self._coulomb_method = method

    def eval(self, data: Dict[str, Any], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]:
        """Run inference; optionally compute forces/stress/hessian via autograd.

        Returns a dict filtered to ``keys_out`` (and their "*_std" variants).
        """
        data = self.prepare_input(data)
        if hessian and "mol_idx" in data and data["mol_idx"][-1] > 0:
            raise NotImplementedError("Hessian calculation is not supported for multiple molecules")
        data = self.set_grad_tensors(data, forces=forces, stress=stress, hessian=hessian)
        with torch.jit.optimized_execution(False):  # type: ignore
            data = self.model(data)
        data = self.get_derivatives(data, forces=forces, stress=stress, hessian=hessian)
        data = self.process_output(data)
        return data

    def prepare_input(self, data: Dict[str, Any]) -> Dict[str, Tensor]:
        """Convert raw input to device tensors, flatten batches, and (for 2D
        coordinates) build neighbor matrices and pad the input."""
        data = self.to_input_tensors(data)
        data = self.mol_flatten(data)
        if data.get("cell") is not None:
            if data["mol_idx"][-1] > 0:
                raise NotImplementedError("PBC with multiple molecules is not implemented yet.")
            if self._coulomb_method == "simple":
                warnings.warn("Switching to DSF Coulomb for PBC", stacklevel=1)
                self.set_lrcoulomb_method("dsf")
        if data["coord"].ndim == 2:
            data = self.make_nbmat(data)
            data = self.pad_input(data)
        return data

    def process_output(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Undo padding/flattening applied in prepare_input and keep only output keys."""
        if data["coord"].ndim == 2:
            data = self.unpad_output(data)
            data = self.mol_unflatten(data)
        data = self.keep_only(data)
        return data

    def to_input_tensors(self, data: Dict[str, Any]) -> Dict[str, Tensor]:
        """Copy required/optional entries into detached device tensors of the
        declared dtypes; scalars become shape-(1,) tensors."""
        ret = {}
        for k in self.keys_in:
            if k not in data:
                raise KeyError(f"Missing key {k} in the input data")
            # always detach !!
            ret[k] = torch.as_tensor(data[k], device=self.device, dtype=self.keys_in[k]).detach()
        for k in self.keys_in_optional:
            if k in data and data[k] is not None:
                ret[k] = torch.as_tensor(data[k], device=self.device, dtype=self.keys_in_optional[k]).detach()
        # convert any scalar tensors to shape (1,) tensors
        for k, v in ret.items():
            if v.ndim == 0:
                ret[k] = v.unsqueeze(0)
        return ret

    def mol_flatten(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Flatten the input data for multiple molecules.
        Will not flatten for batched input and molecule size below threshold.
        """
        ndim = data["coord"].ndim
        if ndim == 2:
            # single molecule or already flattened
            self._batch = None
            if "mol_idx" not in data:
                data["mol_idx"] = torch.zeros(data["coord"].shape[0], dtype=torch.long, device=self.device)
                self._max_mol_size = data["coord"].shape[0]
            elif data["mol_idx"][-1] == 0:
                self._max_mol_size = len(data["mol_idx"])
            else:
                # largest molecule in the flattened multi-molecule input
                self._max_mol_size = data["mol_idx"].unique(return_counts=True)[1].max().item()

        elif ndim == 3:
            # batched input
            B, N = data["coord"].shape[:2]
            if self.nb_threshold > N or self.device == "cpu":
                # small molecules (or CPU): flatten to a single (B*N, ...) system
                self._batch = B
                data["mol_idx"] = torch.repeat_interleave(
                    torch.arange(0, B, device=self.device), torch.full((B,), N, device=self.device)
                )
                for k, v in data.items():
                    if k in self.atom_feature_keys:
                        data[k] = v.flatten(0, 1)
                # NOTE(review): this branch appears not to refresh _max_mol_size
                # before make_nbmat runs — confirm a stale value cannot leak in
                # from a previous call.
            else:
                self._batch = None
                self._max_mol_size = N
        return data

    def mol_unflatten(self, data: Dict[str, Tensor], batch=None) -> Dict[str, Tensor]:
        """Restore the (batch, atoms, ...) layout for per-atom outputs if the
        input was flattened (``self._batch`` set by mol_flatten)."""
        batch = batch if batch is not None else self._batch
        if batch is not None:
            for k, v in data.items():
                if k in self.atom_feature_keys:
                    data[k] = v.view(batch, -1, *v.shape[1:])
        return data

    def make_nbmat(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Build short- (and optionally long-) range neighbor matrices,
        retrying with a higher density estimate on TooManyNeighborsError."""
        assert self._max_mol_size > 0, "Molecule size is not set"

        if "cell" in data and data["cell"] is not None:
            # wrap coordinates into the periodic cell first
            data["coord"] = move_coord_to_cell(data["coord"], data["cell"])
            cell = data["cell"]
        else:
            cell = None

        while True:
            try:
                maxnb1 = calc_max_nb(self.cutoff, self.max_density)
                maxnb2 = calc_max_nb(self.cutoff_lr, self.max_density) if self.lr else None  # type: ignore
                if cell is None:
                    # a gas-phase molecule cannot have more neighbors than atoms
                    maxnb1 = min(maxnb1, self._max_mol_size)
                    maxnb2 = min(maxnb2, self._max_mol_size) if self.lr else None  # type: ignore
                maxnb = (maxnb1, maxnb2)
                nbmat1, nbmat2, shifts1, shifts2 = calc_nbmat(
                    data["coord"],
                    (self.cutoff, self.cutoff_lr),
                    maxnb,  # type: ignore
                    cell,
                    data.get("mol_idx"),  # type: ignore
                )
                break
            except TooManyNeighborsError:
                # grow the density estimate and retry; bail out past 4
                self.max_density *= 1.2
                assert self.max_density <= 4, "Something went wrong in nbmat calculation"
        data["nbmat"] = nbmat1
        if self.lr:
            assert nbmat2 is not None
            data["nbmat_lr"] = nbmat2
        if cell is not None:
            assert shifts1 is not None
            data["shifts"] = shifts1
            if self.lr:
                assert shifts2 is not None
                data["shifts_lr"] = shifts2
        return data

    def pad_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Pad per-atom inputs with one dummy atom to match the neighbor
        matrix size (nbmat carries one extra padding row)."""
        N = data["nbmat"].shape[0]
        data["mol_idx"] = maybe_pad_dim0(data["mol_idx"], N, value=data["mol_idx"][-1].item())
        for k in ("coord", "numbers"):
            if k in data:
                data[k] = maybe_pad_dim0(data[k], N)
        return data

    def unpad_output(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Strip the dummy padding atom from per-atom outputs."""
        N = data["nbmat"].shape[0] - 1
        for k, v in data.items():
            if k in self.atom_feature_keys:
                data[k] = maybe_unpad_dim0(v, N)
        return data

    def set_grad_tensors(self, data: Dict[str, Tensor], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]:
        """Mark tensors that need gradients and stash them for get_derivatives.

        For stress, coordinates and cell are multiplied by an identity scaling
        matrix whose gradient yields the virial.
        """
        self._saved_for_grad = {}
        if forces or hessian:
            data["coord"].requires_grad_(True)
            self._saved_for_grad["coord"] = data["coord"]
        if stress:
            assert "cell" in data and data["cell"] is not None, "Stress calculation requires cell"
            scaling = torch.eye(3, requires_grad=True, dtype=data["cell"].dtype, device=data["cell"].device)
            data["coord"] = data["coord"] @ scaling
            data["cell"] = data["cell"] @ scaling
            self._saved_for_grad["scaling"] = scaling
        return data

    def keep_only(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Return only the declared output keys, plus their "*_std" variants."""
        ret = {}
        for k, v in data.items():
            if k in self.keys_out or (k.endswith("_std") and k[:-4] in self.keys_out):
                ret[k] = v
        return ret

    def get_derivatives(self, data: Dict[str, Tensor], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]:
        """Compute requested derivatives of the total energy via autograd.

        Gradients are taken w.r.t. the tensors stashed by set_grad_tensors.
        """
        training = getattr(self.model, "training", False)
        # keep the graph when a second derivative (hessian) or training needs it
        _create_graph = hessian or training
        x = []
        if hessian:
            forces = True
        if forces and ("forces" not in data or (_create_graph and not data["forces"].requires_grad)):
            forces = True
            x.append(self._saved_for_grad["coord"])
        if stress:
            x.append(self._saved_for_grad["scaling"])
        if x:
            tot_energy = data["energy"].sum()
            deriv = torch.autograd.grad(tot_energy, x, create_graph=_create_graph)
            if forces:
                # NOTE(review): if the model already produced "forces", coord is
                # not appended to x and deriv[0] would be the scaling gradient
                # here — confirm that combination cannot occur in practice.
                data["forces"] = -deriv[0]
            if stress:
                # energy derivative w.r.t. the scaling matrix, normalized by cell volume
                dedc = deriv[0] if not forces else deriv[1]
                data["stress"] = dedc / data["cell"].detach().det().abs()
        if hessian:
            data["hessian"] = self.calculate_hessian(data["forces"], self._saved_for_grad["coord"])
        return data

    @staticmethod
    def calculate_hessian(forces: Tensor, coord: Tensor) -> Tensor:
        # here forces have shape (N, 3) and coord has shape (N+1, 3)
        # return hessian with shape (N, 3, N, 3)
        # Row-by-row second derivative: one autograd.grad call per force
        # component, then the padding atom is sliced off both atom axes.
        hessian = -torch.stack([
            torch.autograd.grad(_f, coord, retain_graph=True)[0] for _f in forces.flatten().unbind()
        ]).view(-1, 3, coord.shape[0], 3)[:-1, :, :-1, :]
        return hessian
276
+
277
+
278
def maybe_pad_dim0(a: Tensor, N: int, value=0.0) -> Tensor:
    """Return *a* grown by one constant row along dim 0 if it is exactly one
    short of N; return it unchanged when it already has N rows."""
    delta = N - a.shape[0]
    assert delta in (0, 1), "Invalid shape"
    return pad_dim0(a, value=value) if delta == 1 else a
284
+
285
+
286
def pad_dim0(a: Tensor, value=0.0) -> Tensor:
    """Append one constant-valued entry along dim 0 of *a* (same dtype/device)."""
    filler = a.new_full((1, *a.shape[1:]), value)
    return torch.cat((a, filler), dim=0)
290
+
291
+
292
def maybe_unpad_dim0(a: Tensor, N: int) -> Tensor:
    """Drop the trailing padding entry along dim 0 when *a* is one longer than N."""
    extra = a.shape[0] - N
    assert extra in (0, 1), "Invalid shape"
    if extra:
        a = a[:-1]
    return a
298
+
299
+
300
def move_coord_to_cell(coord, cell):
    """Wrap Cartesian coordinates into the periodic cell.

    Converts to fractional coordinates, takes them modulo 1, and converts back.
    """
    frac = (coord @ cell.inverse()) % 1
    return frac @ cell
304
+
305
+
306
+ def _named_children_rec(module):
307
+ if isinstance(module, torch.nn.Module):
308
+ for name, child in module.named_children():
309
+ yield name, child
310
+ yield from _named_children_rec(child)
311
+
312
+
313
def iter_lrcoulomb_mods(model):
    """Yield every submodule of *model* registered under the name "lrcoulomb"."""
    for child_name, child in _named_children_rec(model):
        if child_name != "lrcoulomb":
            continue
        yield child
317
+
318
+
319
+ def calc_max_nb(cutoff: float, density: float = 0.2) -> int | float:
320
+ return int(density * 4 / 3 * 3.14159 * cutoff**3) if cutoff < float("inf") else float("inf")
@@ -0,0 +1,60 @@
1
+ import logging
2
+ import os
3
+ from typing import Dict, Optional
4
+
5
+ import click
6
+ import requests
7
+ import yaml
8
+
9
+ logging.basicConfig(level=logging.INFO)
10
+
11
+
12
def load_model_registry(registry_file: Optional[str] = None) -> Dict[str, str]:
    """Load the model registry YAML.

    Args:
        registry_file: Path to a registry YAML file. Defaults to the
            ``model_registry.yaml`` bundled next to this module.

    Returns:
        The parsed registry mapping ("models" and "aliases" sections).
    """
    registry_file = registry_file or os.path.join(os.path.dirname(__file__), "model_registry.yaml")
    # Bug fix: the resolved registry_file was previously ignored and the
    # bundled registry was always opened, making the parameter dead.
    with open(registry_file) as f:
        # safe_load is the recommended spelling of load(..., Loader=SafeLoader)
        return yaml.safe_load(f)
16
+
17
+
18
def create_assets_dir():
    """Ensure the package-local ``assets`` directory exists."""
    assets_dir = os.path.join(os.path.dirname(__file__), "assets")
    os.makedirs(assets_dir, exist_ok=True)
20
+
21
+
22
def get_registry_model_path(model_name: str) -> str:
    """Resolve a model name or alias to a local file path, downloading on demand.

    Raises ValueError when the (resolved) name has no registry entry.
    """
    registry = load_model_registry()
    create_assets_dir()
    # Resolve aliases first, then look up the concrete model entry.
    model_name = registry["aliases"].get(model_name, model_name)
    if model_name not in registry["models"]:
        raise ValueError(f"Model {model_name} not found in the registry.")
    cfg = registry["models"][model_name]  # type: ignore
    return _maybe_download_asset(**cfg)  # type: ignore
32
+
33
+
34
def _maybe_download_asset(file: str, url: str) -> str:
    """Return the local assets path for *file*, downloading from *url* on first use.

    The request now completes (and is checked) before the destination file is
    opened, so a failed download no longer leaves behind an empty file that
    would be mistaken for a cached model on the next call.
    """
    filename = os.path.join(os.path.dirname(__file__), "assets", file)
    if not os.path.exists(filename):
        print(f"Downloading {url} -> {filename}")
        response = requests.get(url, timeout=60)
        # fail loudly on HTTP errors instead of silently caching an error page
        response.raise_for_status()
        with open(filename, "wb") as f:
            f.write(response.content)
    return filename
42
+
43
+
44
def get_model_path(s: str) -> str:
    """Resolve *s* to a model file: a direct path is returned as-is,
    anything else is treated as a registry name/alias."""
    # direct file path
    if not os.path.isfile(s):
        return get_registry_model_path(s)
    print("Found model file:", s)
    return s
51
+
52
+
53
@click.command(short_help="Clear assets directory.")
def clear_assets():
    """Delete all downloaded model files from the package assets directory."""
    from glob import glob

    for fil in glob(os.path.join(os.path.dirname(__file__), "assets", "*")):
        if os.path.isfile(fil):
            # logging.warn is a deprecated alias of logging.warning
            logging.warning(f"Removing {fil}")
            os.remove(fil)
@@ -0,0 +1,33 @@
1
+ # map file name to url
2
+ models:
3
+ aimnet2_wb97m_d3_0:
4
+ file: aimnet2_wb97m_d3_0.jpt
5
+ url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_0.jpt
6
+ aimnet2_wb97m_d3_1:
7
+ file: aimnet2_wb97m_d3_1.jpt
8
+ url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_1.jpt
9
+ aimnet2_wb97m_d3_2:
10
+ file: aimnet2_wb97m_d3_2.jpt
11
+ url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_2.jpt
12
+ aimnet2_wb97m_d3_3:
13
+ file: aimnet2_wb97m_d3_3.jpt
14
+ url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_3.jpt
15
+ aimnet2_b973c_d3_0:
16
+ file: aimnet2_b973c_d3_0.jpt
17
+ url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_0.jpt
18
+ aimnet2_b973c_d3_1:
19
+ file: aimnet2_b973c_d3_1.jpt
20
+ url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_1.jpt
21
+ aimnet2_b973c_d3_2:
22
+ file: aimnet2_b973c_d3_2.jpt
23
+ url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_2.jpt
24
+ aimnet2_b973c_d3_3:
25
+ file: aimnet2_b973c_d3_3.jpt
26
+ url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_3.jpt
27
+
28
+ # map model alias to file name
29
+ aliases:
30
+ aimnet2: aimnet2_wb97m_d3_0
31
+ aimnet2_wb97m: aimnet2_wb97m_d3_0
32
+ aimnet2_b973c: aimnet2_b973c_d3_0
33
+ aimnet2_qr: aimnet2_qr_v0