aimnet 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +0 -0
- aimnet/base.py +41 -0
- aimnet/calculators/__init__.py +15 -0
- aimnet/calculators/aimnet2ase.py +98 -0
- aimnet/calculators/aimnet2pysis.py +76 -0
- aimnet/calculators/calculator.py +320 -0
- aimnet/calculators/model_registry.py +60 -0
- aimnet/calculators/model_registry.yaml +33 -0
- aimnet/calculators/nb_kernel_cpu.py +222 -0
- aimnet/calculators/nb_kernel_cuda.py +217 -0
- aimnet/calculators/nbmat.py +220 -0
- aimnet/cli.py +22 -0
- aimnet/config.py +170 -0
- aimnet/constants.py +467 -0
- aimnet/data/__init__.py +1 -0
- aimnet/data/sgdataset.py +517 -0
- aimnet/dftd3_data.pt +0 -0
- aimnet/models/__init__.py +2 -0
- aimnet/models/aimnet2.py +188 -0
- aimnet/models/aimnet2.yaml +44 -0
- aimnet/models/aimnet2_dftd3_wb97m.yaml +51 -0
- aimnet/models/base.py +51 -0
- aimnet/modules/__init__.py +3 -0
- aimnet/modules/aev.py +201 -0
- aimnet/modules/core.py +237 -0
- aimnet/modules/lr.py +243 -0
- aimnet/nbops.py +151 -0
- aimnet/ops.py +208 -0
- aimnet/train/__init__.py +0 -0
- aimnet/train/calc_sae.py +43 -0
- aimnet/train/default_train.yaml +166 -0
- aimnet/train/loss.py +83 -0
- aimnet/train/metrics.py +188 -0
- aimnet/train/pt2jpt.py +81 -0
- aimnet/train/train.py +155 -0
- aimnet/train/utils.py +398 -0
- aimnet-0.0.1.dist-info/LICENSE +21 -0
- aimnet-0.0.1.dist-info/METADATA +78 -0
- aimnet-0.0.1.dist-info/RECORD +41 -0
- aimnet-0.0.1.dist-info/WHEEL +4 -0
- aimnet-0.0.1.dist-info/entry_points.txt +5 -0
aimnet/config.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from importlib import import_module
|
|
3
|
+
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import yaml
|
|
6
|
+
from jinja2 import Template
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_module(name: str) -> Callable:
    """
    Resolve a dotted path such as ``'package.module.attr'`` to the object it names.

    Everything before the last dot is imported as a module; the final component
    is then looked up as an attribute of that module.

    Args:
        name (str): Dotted path in the form 'module.function' (the module part may itself be dotted).

    Returns:
        function: The attribute bound to the last path component.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the function does not exist in the module.
        ValueError: If ``name`` contains no dot (``import_module`` rejects an empty module name).
    """
    module_path, _, attr_name = name.rpartition(".")
    target_module = import_module(module_path)
    return getattr(target_module, attr_name)  # type: ignore[no-any-return]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_init_module(name: str, args: Optional[List] = None, kwargs: Optional[Dict] = None) -> Callable:
    """
    Resolve ``name`` with ``get_module`` and call the result with the given arguments.

    Typically used to instantiate a class referenced by its dotted path.

    Args:
        name (str): Dotted path of the callable/class, e.g. 'package.module.Class'.
        args (List, optional): Positional arguments for the call. ``None`` means no positional arguments.
        kwargs (Dict, optional): Keyword arguments for the call. ``None`` means no keyword arguments.

    Returns:
        Whatever the resolved callable returns (for a class, the new instance).
    """
    positional = [] if args is None else args
    keywords = {} if kwargs is None else kwargs
    factory = get_module(name)
    return factory(*positional, **keywords)  # type: ignore[no-any-return]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def load_yaml(
    config: Dict[str, Any] | List | str, hyperpar: Optional[Dict[str, Any] | str] = None
) -> Dict[str, Any] | List:
    """
    Load a YAML configuration, apply optional Jinja2 hyperparameters, and inline nested YAML files.

    Processing steps:
      1. If ``hyperpar`` is a path, it is itself loaded with ``load_yaml`` and must yield a dict.
      2. If ``config`` is already a list/dict, every string value containing ``{{`` is rendered
         as a Jinja2 template with ``hyperpar`` (the structure is modified in place).
         Otherwise ``config`` is treated as a file path: the file text is rendered as a template
         (when ``hyperpar`` is non-empty) and then parsed as YAML.
      3. Every string value ending in ``.yml``/``.yaml`` is replaced, in place, by the result of
         loading that file recursively with the same ``hyperpar``. A reference that is not an
         existing path (relative to the CWD) is resolved relative to the directory of the config
         file that contains it.

    Args:
        config (Union[str, List, Dict]): The YAML configuration file path or a YAML object.
        hyperpar (Optional[Union[Dict, str, None]]): Optional hyperparameters (or a path to a YAML
            file with them) used to render ``{{ ... }}`` templates.

    Returns:
        Union[List, Dict]: The loaded and processed configuration.

    Raises:
        TypeError: If ``hyperpar`` is a path whose content is not a mapping.
        FileNotFoundError: If the config file, or a ``.yml``/``.yaml`` reference in it, does not exist.
    """
    # Directory against which relative nested YAML references are resolved. In-memory configs
    # have no source file, so their references stay relative to the current working directory.
    basedir = ""
    if isinstance(hyperpar, str):
        hyperpar = load_yaml(hyperpar)  # type: ignore[assignment]
        if not isinstance(hyperpar, dict):
            raise TypeError("Loaded hyperpar must be a dict")
    if isinstance(config, (list, dict)):
        if hyperpar:
            for d, k, v in _iter_rec_bottomup(config):
                if isinstance(v, str) and "{{" in v:
                    d[k] = Template(v).render(**hyperpar)  # type: ignore[assignment, index]
    else:
        # Remember where this file lives so sibling/relative includes resolve next to it.
        basedir = os.path.dirname(config)
        with open(config, encoding="utf-8") as f:
            text = f.read()
        if hyperpar:
            text = Template(text).render(**hyperpar)
        # NOTE(review): FullLoader is less restrictive than yaml.SafeLoader; only load
        # configuration files from trusted sources.
        config = yaml.load(text, Loader=yaml.FullLoader)  # noqa: S506
    # plugin yaml configs: inline any referenced *.yml / *.yaml files
    for d, k, v in _iter_rec_bottomup(config):  # type: ignore[arg-type]
        if isinstance(v, str) and v.endswith((".yml", ".yaml")):
            if not os.path.isfile(v):
                v = os.path.join(basedir, v)
            d[k] = load_yaml(v, hyperpar)  # type: ignore[assignment, index]
    return config  # type: ignore[return-value]
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _iter_rec_bottomup(
|
|
91
|
+
d: Dict[str, Any] | List,
|
|
92
|
+
) -> Iterator[Tuple[Dict[str, Any] | List, str | int, Any]]:
|
|
93
|
+
if isinstance(d, list):
|
|
94
|
+
it = enumerate(d)
|
|
95
|
+
elif isinstance(d, dict):
|
|
96
|
+
it = d.items() # type: ignore[assignment]
|
|
97
|
+
else:
|
|
98
|
+
raise TypeError(f"Unknown type: {type(d)}")
|
|
99
|
+
for k, v in it:
|
|
100
|
+
if isinstance(v, (list, dict)):
|
|
101
|
+
yield from _iter_rec_bottomup(v)
|
|
102
|
+
yield d, k, v
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def build_module(
    config: Union[str, Dict, List], hyperpar: Union[str, Dict, None] = None
) -> Union[List, Dict, Callable]:
    """
    Build a module based on the provided configuration.

    The configuration is first loaded with ``load_yaml`` (templates rendered, nested YAML files
    inlined). Every (possibly nested) dictionary with a 'class' key is then replaced by
    ``get_init_module(d['class'], args=d.get('args', []), kwargs=d.get('kwargs', {}))``.
    Replacement proceeds bottom-up, so objects built from inner dictionaries are already
    instantiated when they appear inside an outer dictionary's 'args'/'kwargs'.

    Args:
        config (Union[str, Dict, List]): The configuration (or a path to a YAML file with it).
        hyperpar (Union[str, Dict, None], optional): Hyperparameters for template rendering, or a
            path to a YAML file containing them. Defaults to None.

    Returns:
        Union[List, Dict, Callable]: The built object if the root is a dict with a 'class' key,
        otherwise the processed list/dict.

    Raises:
        TypeError: If `hyperpar` is provided and is not (or does not load to) a dictionary.
    """
    if isinstance(hyperpar, str):
        hyperpar = load_yaml(hyperpar)  # type: ignore[assignment]
    if hyperpar and not isinstance(hyperpar, dict):
        raise TypeError("Hyperpar must be a dictionary")
    config = load_yaml(config, hyperpar)
    for d, k, v in _iter_rec_bottomup(config):
        if isinstance(v, dict) and "class" in v:
            d[k] = get_init_module(  # type: ignore[index]
                v["class"],
                args=v.get("args", []),  # type: ignore[assignment]
                kwargs=v.get("kwargs", {}),
            )
    # The walker never yields the root itself, so handle it separately. The isinstance guard
    # matters: for a list root, `"class" in config` would be a list-membership test and
    # `config["class"]` would then raise TypeError.
    if isinstance(config, dict) and "class" in config:
        config = get_init_module(  # type: ignore[assignment]
            config["class"],
            args=config.get("args", []),
            kwargs=config.get("kwargs", {}),
        )
    return config  # type: ignore[return-value]
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def dict_to_dotted(d, parent=""):
    """Flatten nested dicts inside ``d`` into dotted keys, modifying ``d`` in place.

    ``{"a": {"b": 1}}`` becomes ``{"a.b": 1}``. Non-dict values and empty dicts are kept as
    leaves. A non-empty ``parent`` is prepended (followed by a dot) to every resulting key.
    Each processed key is removed and re-inserted, so the result keeps the original key order
    with each nested dict expanded where its key used to be. Nested dicts are flattened in
    place as well.

    Returns:
        The same dict object ``d``.
    """
    prefix = parent + "." if parent else parent
    for key, value in list(d.items()):
        if not (isinstance(value, dict) and value):
            # Leaf: rename (or, with no prefix, just move to the end to keep ordering).
            d[prefix + key] = d.pop(key)
            continue
        d.update(dict_to_dotted(value, prefix + key))
        del d[key]
    return d
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def dotted_to_dict(d):
    """Expand dotted keys of ``d`` into nested dicts, modifying ``d`` in place.

    Counterpart of ``dict_to_dotted``. A key such as ``"a.b.c"`` is removed and its value is
    stored at ``d["a"]["b"]["c"]``; missing intermediate levels are created as empty dicts.
    Keys without a dot are left untouched. Keys are expected to be strings. An intermediate
    level that already holds a non-dict value will typically make this raise.

    Returns:
        The same dict object ``d``.
    """
    for dotted_key, value in list(d.items()):
        if "." in dotted_key:
            *path, leaf = dotted_key.split(".")
            node = d
            for part in path:
                if part not in node:
                    node[part] = {}
                node = node[part]
            node[leaf] = value
            del d[dotted_key]
    return d
|
aimnet/constants.py
ADDED
|
@@ -0,0 +1,467 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
# Physical constants / unit conversion factors copied from ``ase.units``.
# NOTE(review): ASE's base units are presumably eV (energy) and Angstrom (length); the unit
# notes below assume that convention -- confirm against ase.units before relying on them.
kB = 8.617330337217213e-05  # Boltzmann constant (presumably eV/K)
fs = 0.09822694788464063  # one femtosecond expressed in ASE time units (presumably)
Hartree = 27.211386024367243  # one Hartree (presumably in eV)
half_Hartree = 0.5 * Hartree  # convenience constant: Hartree / 2
Bohr = 0.5291772105638411  # one Bohr radius (presumably in Angstrom)
Bohr_inv = 1 / Bohr  # reciprocal of Bohr, precomputed for repeated unit conversions
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_masses(device="cpu"):
    """Return standard atomic masses as a 1-D tensor indexed by atomic number.

    Values are taken from ``ase.data.atomic_masses``. Entry 0 is a 0.0 placeholder, so
    ``masses[Z]`` is the mass of element ``Z`` for Z = 1..118 (e.g. ``masses[6] == 12.011``).

    Args:
        device: Torch device the tensor is moved to. Defaults to ``"cpu"``.

    Returns:
        Tensor of length 119 with torch's default floating dtype, on ``device``.
    """
    masses = [
        0.0,
        1.008, 4.002602, 6.94, 9.0121831, 10.81,
        12.011, 14.007, 15.999, 18.99840316, 20.1797,
        22.98976928, 24.305, 26.9815385, 28.085, 30.973762,
        32.06, 35.45, 39.948, 39.0983, 40.078,
        44.955908, 47.867, 50.9415, 51.9961, 54.938044,
        55.845, 58.933194, 58.6934, 63.546, 65.38,
        69.723, 72.63, 74.921595, 78.971, 79.904,
        83.798, 85.4678, 87.62, 88.90584, 91.224,
        92.90637, 95.95, 97.90721, 101.07, 102.9055,
        106.42, 107.8682, 112.414, 114.818, 118.71,
        121.76, 127.6, 126.90447, 131.293, 132.90545196,
        137.327, 138.90547, 140.116, 140.90766, 144.242,
        144.91276, 150.36, 151.964, 157.25, 158.92535,
        162.5, 164.93033, 167.259, 168.93422, 173.054,
        174.9668, 178.49, 180.94788, 183.84, 186.207,
        190.23, 192.217, 195.084, 196.966569, 200.592,
        204.38, 207.2, 208.9804, 208.98243, 209.98715,
        222.01758, 223.01974, 226.02541, 227.02775, 232.0377,
        231.03588, 238.02891, 237.04817, 244.06421, 243.06138,
        247.07035, 247.07031, 251.07959, 252.083, 257.09511,
        258.09843, 259.101, 262.11, 267.122, 268.126,
        271.134, 270.133, 269.1338, 278.156, 281.165,
        281.166, 285.177, 286.182, 289.19, 289.194,
        293.204, 293.208, 294.214,
    ]
    table = torch.tensor(masses)
    return table.to(device)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def get_gfn1_rep(device="cpu"):
    """Return the two per-element parameter tensors for the GFN1 repulsion function.

    Two raw tables of 87 entries each (indices 0..86) are stored here: ``alpha`` (entry 0 is
    a tiny 1e-6 placeholder) and ``Zeff`` (entry 0 is 0.0). They are combined with the module
    constants ``Bohr_inv``, ``Bohr`` and ``half_Hartree``.

    NOTE(review): the scaling presumably converts the parameters from atomic units into the
    eV/Angstrom convention of the module constants -- confirm against the repulsion term's user.

    Args:
        device: Torch device both tensors are moved to. Defaults to ``"cpu"``.

    Returns:
        Tuple ``(repa, repb)`` where ``repa = sqrt(alpha) * Bohr_inv**0.75`` and
        ``repb = Zeff * sqrt(0.5 * Hartree * Bohr)``.
    """
    alpha = torch.tensor([
        0.000001, 2.209700, 1.382907, 0.671797, 0.865377, 1.093544,
        1.281954, 1.727773, 2.004253, 2.507078, 3.038727, 0.704472,
        0.862629, 0.929219, 0.948165, 1.067197, 1.200803, 1.404155,
        1.323756, 0.581529, 0.665588, 0.841357, 0.828638, 1.061627,
        0.997051, 1.019783, 1.137174, 1.188538, 1.399197, 1.199230,
        1.145056, 1.047536, 1.129480, 1.233641, 1.270088, 1.153580,
        1.335287, 0.554032, 0.657904, 0.760144, 0.739520, 0.895357,
        0.944064, 1.028240, 1.066144, 1.131380, 1.206869, 1.058886,
        1.026434, 0.898148, 1.008192, 0.982673, 0.973410, 0.949181,
        1.074785, 0.579919, 0.606485, 1.311200, 0.839861, 0.847281,
        0.854701, 0.862121, 0.869541, 0.876961, 0.884381, 0.891801,
        0.899221, 0.906641, 0.914061, 0.921481, 0.928901, 0.936321,
        0.853744, 0.971873, 0.992643, 1.132106, 1.118216, 1.245003,
        1.304590, 1.293034, 1.181865, 0.976397, 0.988859, 1.047194,
        1.013118, 0.964652, 0.998641,
    ])
    zeff = torch.tensor([
        0.000000, 1.116244, 0.440231, 2.747587, 4.076830, 4.458376,
        4.428763, 5.498808, 5.171786, 6.931741, 9.102523, 10.591259,
        15.238107, 16.283595, 16.898359, 15.249559, 15.100323, 17.000000,
        17.153132, 20.831436, 19.840212, 18.676202, 17.084130, 22.352532,
        22.873486, 24.160655, 25.983149, 27.169215, 23.396999, 29.000000,
        31.185765, 33.128619, 35.493164, 36.125762, 32.148852, 35.000000,
        36.000000, 39.653032, 38.924904, 39.000000, 36.521516, 40.803132,
        41.939347, 43.000000, 44.492732, 45.241537, 42.105527, 43.201446,
        49.016827, 51.718417, 54.503455, 50.757213, 49.215262, 53.000000,
        52.500985, 65.029838, 46.532974, 48.337542, 30.638143, 34.130718,
        37.623294, 41.115870, 44.608445, 48.101021, 51.593596, 55.086172,
        58.578748, 62.071323, 65.563899, 69.056474, 72.549050, 76.041625,
        55.222897, 63.743065, 74.000000, 75.000000, 76.000000, 77.000000,
        78.000000, 79.000000, 80.000000, 81.000000, 79.578302, 83.000000,
        84.000000, 85.000000, 86.000000,
    ])
    # half_Hartree is defined at module level as 0.5 * Hartree, so this equals the
    # (0.5 * Hartree * Bohr) ** 0.5 factor bit-for-bit.
    b_scale = (half_Hartree * Bohr) ** 0.5
    a_scale = Bohr_inv**0.75
    rep_a = alpha.pow(0.5) * a_scale
    rep_b = zeff * b_scale
    return rep_a.to(device), rep_b.to(device)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def get_r4r2(device="cpu"):
    """Return the per-element r4r2 parameter for DFT-D3 as a 1-D tensor indexed by position.

    The raw table (119 entries, indices 0..118; zeros for index 0 and for indices 95..111)
    is combined with the index ``Z`` as ``sqrt(0.5 * value * sqrt(Z))``.
    Source table: https://github.com/dftd4/dftd4/blob/main/src/dftd4/data/r4r2.f90

    Args:
        device: Torch device the tensor is moved to. Defaults to ``"cpu"``.

    Returns:
        Tensor of length 119 on ``device``.
    """
    r4_over_r2 = [
        0.0, 8.0589, 3.4698, 29.0974, 14.8517, 11.8799,
        7.8715, 5.5588, 4.7566, 3.8025, 3.1036, 26.1552,
        17.2304, 17.7210, 12.7442, 9.5361, 8.1652, 6.7463,
        5.6004, 29.2012, 22.3934, 19.0598, 16.8590, 15.4023,
        12.5589, 13.4788, 12.2309, 11.2809, 10.5569, 10.1428,
        9.4907, 13.4606, 10.8544, 8.9386, 8.1350, 7.1251,
        6.1971, 30.0162, 24.4103, 20.3537, 17.4780, 13.5528,
        11.8451, 11.0355, 10.1997, 9.5414, 9.0061, 8.6417,
        8.9975, 14.0834, 11.8333, 10.0179, 9.3844, 8.4110,
        7.5152, 32.7622, 27.5708, 23.1671, 21.6003, 20.9615,
        20.4562, 20.1010, 19.7475, 19.4828, 15.6013, 19.2362,
        17.4717, 17.8321, 17.4237, 17.1954, 17.1631, 14.5716,
        15.8758, 13.8989, 12.4834, 11.4421, 10.2671, 8.3549,
        7.8496, 7.3278, 7.4820, 13.5124, 11.6554, 10.0959,
        9.7340, 8.8584, 8.0125, 29.8135, 26.3157, 19.1885,
        15.8542, 16.1305, 15.6161, 15.1226, 16.1576,
    ] + [0.0] * 17 + [
        5.4929, 6.7286, 6.5144, 10.9169, 10.3600, 9.4723, 8.6641,
    ]
    values = torch.tensor(r4_over_r2)
    sqrt_z = torch.arange(len(r4_over_r2)).sqrt()
    result = (0.5 * values * sqrt_z).sqrt()
    return result.to(device)
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def get_dftd3_param(device="cpu"):
    """Load the bundled collection of DFT-D3 parameters from ``dftd3_data.pt``.

    The file is looked up in the directory containing this module and loaded with
    ``torch.load(..., weights_only=True)``.

    Args:
        device: Passed to ``torch.load`` as ``map_location``. Defaults to ``"cpu"``.

    Returns:
        The loaded dict; it is guaranteed to contain the keys "c6ab", "r4r2", "rcov" and "cnmax".

    Raises:
        FileNotFoundError: If ``dftd3_data.pt`` is not present next to this module.
        TypeError: If the file does not contain a dict.
        ValueError: If any of the required keys is missing.
    """
    dirname = os.path.dirname(__file__)
    filename = os.path.join(dirname, "dftd3_data.pt")
    if not os.path.exists(filename):
        raise FileNotFoundError(f"dftd3_data.pt not found in {dirname}.")
    # weights_only=True restricts unpickling to tensors and plain containers.
    param = torch.load(filename, map_location=device, weights_only=True)
    # Explicit checks instead of `assert`, so validation still runs under `python -O`.
    if not isinstance(param, dict):
        raise TypeError(f"Expected a dict in {filename}, got {type(param).__name__}.")
    missing = [key for key in ("c6ab", "r4r2", "rcov", "cnmax") if key not in param]
    if missing:
        raise ValueError(f"{filename} is missing required DFT-D3 parameters: {', '.join(missing)}.")
    return param
|
aimnet/data/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .sgdataset import DataGroup, SizeGroupedDataset, SizeGroupedSampler # noqa: F401
|