aimnet 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimnet/config.py ADDED
@@ -0,0 +1,170 @@
1
+ import os
2
+ from importlib import import_module
3
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
4
+
5
+ import yaml
6
+ from jinja2 import Template
7
+
8
+
9
def get_module(name: str) -> Callable:
    """
    Resolve a dotted path such as ``'package.module.attr'`` to the named attribute.

    Everything before the last dot is imported as a module and the final component is
    looked up on it with ``getattr``.

    Args:
        name (str): Dotted path in the format 'module.function' (the module part may itself
            contain dots, e.g. 'torch.nn.Linear').

    Returns:
        function: The attribute found in the module (typically a function or class).

    Raises:
        ValueError: If ``name`` has no module part (no dot, or a leading dot).
        ImportError: If the module cannot be imported.
        AttributeError: If the attribute does not exist in the module.
    """
    module_name, _, func_name = name.rpartition(".")
    if not module_name:
        # import_module("") would fail with an opaque "Empty module name"; name the input instead.
        raise ValueError(f"Expected a dotted path 'module.attribute', got {name!r}")
    module = import_module(module_name)
    return getattr(module, func_name)  # type: ignore[no-any-return]
28
+
29
+
30
def get_init_module(name: str, args: Optional[List] = None, kwargs: Optional[Dict] = None) -> Callable:
    """
    Resolve a dotted path with :func:`get_module` and call the result.

    The resolved object is invoked as ``target(*args, **kwargs)``, so for a class this
    constructs a new instance.

    Args:
        name (str): Dotted path of the callable, e.g. 'package.module.ClassName'.
        args (List, optional): Positional arguments for the call; ``None`` means none.
        kwargs (Dict, optional): Keyword arguments for the call; ``None`` means none.

    Returns:
        Whatever the resolved callable returns (the new instance, for a class).
    """
    positional = [] if args is None else args
    keyword = {} if kwargs is None else kwargs
    target = get_module(name)
    return target(*positional, **keyword)  # type: ignore[no-any-return]
46
+
47
+
48
def load_yaml(
    config: Dict[str, Any] | List | str, hyperpar: Optional[Dict[str, Any] | str] = None
) -> Dict[str, Any] | List:
    """
    Load a YAML configuration and apply optional Jinja2 hyperparameters.

    If ``config`` is a path, the file is read, rendered as a Jinja2 template with ``hyperpar``
    (when given) and parsed as YAML. If ``config`` is already a list/dict, only string values
    containing ``{{`` are rendered, and the object is modified in place.

    Afterwards, every string value ending in ``.yml``/``.yaml`` is treated as an include and
    replaced in place by ``load_yaml(path, hyperpar)``. An include path that is not an existing
    file relative to the current working directory is resolved relative to the directory of the
    YAML file that contains it. For list/dict input there is no such directory, so the path is
    used as-is.

    Args:
        config (Union[str, List, Dict]): The YAML configuration file path or a YAML object.
        hyperpar (Optional[Union[Dict, str, None]]): Template variables, or a path to a YAML
            file containing them.

    Returns:
        Union[List, Dict]: The loaded and processed configuration.

    Raises:
        TypeError: If ``hyperpar`` is a path whose YAML content is not a dict.
        FileNotFoundError: If ``config`` or an included file does not exist.
    """
    basedir = ""
    if isinstance(hyperpar, str):
        hyperpar = load_yaml(hyperpar)  # type: ignore[assignment]
        if not isinstance(hyperpar, dict):
            raise TypeError("Loaded hyperpar must be a dict")
    if isinstance(config, (list, dict)):
        if hyperpar:
            for d, k, v in _iter_rec_bottomup(config):
                if isinstance(v, str) and "{{" in v:
                    d[k] = Template(v).render(**hyperpar)  # type: ignore[assignment, index]
    else:
        # Remember where this file lives so relative includes inside it can be resolved.
        basedir = os.path.dirname(config)
        with open(config, encoding="utf-8") as f:
            text = f.read()
        if hyperpar:
            text = Template(text).render(**hyperpar)
        # NOTE(review): FullLoader is not meant for untrusted input; config files are assumed
        # trusted here. Consider yaml.safe_load if configs can come from external sources.
        config = yaml.load(text, Loader=yaml.FullLoader)  # noqa: S506
    # plugin yaml configs
    for d, k, v in _iter_rec_bottomup(config):  # type: ignore[arg-type]
        if isinstance(v, str) and v.endswith((".yml", ".yaml")):
            if not os.path.isfile(v):
                # CWD lookup failed: fall back to the including file's directory.
                v = os.path.join(basedir, v)
            d[k] = load_yaml(v, hyperpar)  # type: ignore[assignment, index]
    return config  # type: ignore[return-value]
88
+
89
+
90
+ def _iter_rec_bottomup(
91
+ d: Dict[str, Any] | List,
92
+ ) -> Iterator[Tuple[Dict[str, Any] | List, str | int, Any]]:
93
+ if isinstance(d, list):
94
+ it = enumerate(d)
95
+ elif isinstance(d, dict):
96
+ it = d.items() # type: ignore[assignment]
97
+ else:
98
+ raise TypeError(f"Unknown type: {type(d)}")
99
+ for k, v in it:
100
+ if isinstance(v, (list, dict)):
101
+ yield from _iter_rec_bottomup(v)
102
+ yield d, k, v
103
+
104
+
105
def build_module(
    config: Union[str, Dict, List], hyperpar: Union[str, Dict, None] = None
) -> Union[List, Dict, Callable]:
    """
    Build a module based on the provided configuration.

    The configuration is loaded with :func:`load_yaml`. Every (possibly nested) dictionary with
    a 'class' key is then replaced by ``get_init_module(d['class'], d.get('args'), d.get('kwargs'))``.
    Replacement is bottom-up, so nested definitions are instantiated before the objects that
    receive them as arguments. If the top-level config is itself such a dictionary, the
    instantiated object is returned.

    Args:
        config (Union[str, Dict, List]): The configuration (YAML path or already-loaded object).
        hyperpar (Union[str, Dict, None], optional): Template variables, or a YAML path
            containing them. Defaults to None.

    Returns:
        Union[List, Dict, Callable]: The built module, or the processed list/dict.

    Raises:
        TypeError: If `hyperpar` is provided (or loaded from a file) and is not a dictionary.
    """
    if isinstance(hyperpar, str):
        hyperpar = load_yaml(hyperpar)  # type: ignore[assignment]
    if hyperpar and not isinstance(hyperpar, dict):
        raise TypeError("Hyperpar must be a dictionary")
    config = load_yaml(config, hyperpar)
    for d, k, v in _iter_rec_bottomup(config):
        if isinstance(v, dict) and "class" in v:
            d[k] = get_init_module(  # type: ignore[index]
                v["class"],
                args=v.get("args", []),  # type: ignore[assignment]
                kwargs=v.get("kwargs", {}),
            )
    # Only a mapping can describe a top-level module; for a list, `"class" in config` would be a
    # membership test and config["class"] would then raise.
    if isinstance(config, dict) and "class" in config:
        config = get_init_module(  # type: ignore[assignment]
            config["class"],
            args=config.get("args", []),
            kwargs=config.get("kwargs", {}),
        )
    return config  # type: ignore[return-value]
143
+
144
+
145
def _flatten_into(out, d, prefix):
    """Copy the leaves of nested dict ``d`` into ``out`` under dotted keys starting with ``prefix``."""
    for k, v in d.items():
        # With no prefix the key is kept unchanged (so non-string top-level keys survive);
        # otherwise it is joined to the prefix as text.
        key = f"{prefix}{k}" if prefix else k
        if isinstance(v, dict) and v:
            _flatten_into(out, v, f"{key}.")
        else:
            # Non-dict values and empty dicts are leaves.
            out[key] = v


def dict_to_dotted(d, parent=""):
    """
    Flatten a nested dict into a single level with dotted keys, in place.

    ``{"a": {"b": 1}, "c": 2}`` becomes ``{"a.b": 1, "c": 2}``. Empty nested dicts are kept as
    values. Keys keep the order in which their leaves appear.

    The flat mapping is built separately and then written back into ``d``. Renaming keys one by
    one inside ``d`` could overwrite a sibling that already uses the dotted name. Only ``d``
    itself is modified; nested dicts are left untouched.

    Args:
        d (dict): The dict to flatten; it is cleared and refilled.
        parent (str): Optional prefix prepended (with a dot) to every resulting key.

    Returns:
        dict: ``d`` itself, now flat.
    """
    flat = {}
    _flatten_into(flat, d, f"{parent}." if parent else "")
    d.clear()
    d.update(flat)
    return d
156
+
157
+
158
def dotted_to_dict(d):
    """
    Expand dotted string keys into nested dicts, in place.

    ``{"a.b": 1, "a.c": 2}`` becomes ``{"a": {"b": 1, "c": 2}}``. Intermediate dicts are created
    as needed, and existing ones are reused. Keys that are not strings, or contain no dot, are
    left unchanged.

    Args:
        d (dict): The dict to expand; modified in place.

    Returns:
        dict: ``d`` itself.
    """
    for key, value in list(d.items()):
        # YAML can produce int/bool/None keys; only string keys can carry a dotted path.
        if not isinstance(key, str) or "." not in key:
            continue
        *path, leaf = key.split(".")
        node = d
        for part in path:
            if part not in node:
                node[part] = {}
            node = node[part]
        node[leaf] = value
        d.pop(key)
    return d
aimnet/constants.py ADDED
@@ -0,0 +1,467 @@
1
+ import os
2
+
3
+ import torch
4
+
5
# Physical constants and unit conversions copied from ``ase.units``
# (ASE's unit system: energies in eV, lengths in Angstrom).
Hartree = 27.211386024367243  # 1 Hartree in eV
half_Hartree = 0.5 * Hartree  # precomputed Hartree / 2
Bohr = 0.5291772105638411  # 1 Bohr radius in Angstrom
Bohr_inv = 1 / Bohr  # inverse Bohr radius
kB = 8.617330337217213e-5  # Boltzmann constant in eV/K
fs = 0.09822694788464063  # 1 femtosecond expressed in ASE time units
12
+
13
+
14
def get_masses(device="cpu"):
    """
    Return atomic masses indexed by atomic number.

    Values are taken from ``ase.data.atomic_masses`` (atomic mass units). Entry 0 is a 0.0
    placeholder, so ``masses[Z]`` is the mass of element Z for Z = 1 (H) ... 118 (Og).

    Args:
        device: Device the returned tensor is moved to.

    Returns:
        torch.Tensor: 1-D tensor with 119 entries (default float dtype).
    """
    masses = [
        0.0,  # index 0: placeholder, not an element
        # Z = 1-10
        1.008, 4.002602, 6.94, 9.0121831, 10.81,
        12.011, 14.007, 15.999, 18.99840316, 20.1797,
        # Z = 11-20
        22.98976928, 24.305, 26.9815385, 28.085, 30.973762,
        32.06, 35.45, 39.948, 39.0983, 40.078,
        # Z = 21-30
        44.955908, 47.867, 50.9415, 51.9961, 54.938044,
        55.845, 58.933194, 58.6934, 63.546, 65.38,
        # Z = 31-40
        69.723, 72.63, 74.921595, 78.971, 79.904,
        83.798, 85.4678, 87.62, 88.90584, 91.224,
        # Z = 41-50
        92.90637, 95.95, 97.90721, 101.07, 102.9055,
        106.42, 107.8682, 112.414, 114.818, 118.71,
        # Z = 51-60
        121.76, 127.6, 126.90447, 131.293, 132.90545196,
        137.327, 138.90547, 140.116, 140.90766, 144.242,
        # Z = 61-70
        144.91276, 150.36, 151.964, 157.25, 158.92535,
        162.5, 164.93033, 167.259, 168.93422, 173.054,
        # Z = 71-80
        174.9668, 178.49, 180.94788, 183.84, 186.207,
        190.23, 192.217, 195.084, 196.966569, 200.592,
        # Z = 81-90
        204.38, 207.2, 208.9804, 208.98243, 209.98715,
        222.01758, 223.01974, 226.02541, 227.02775, 232.0377,
        # Z = 91-100
        231.03588, 238.02891, 237.04817, 244.06421, 243.06138,
        247.07035, 247.07031, 251.07959, 252.083, 257.09511,
        # Z = 101-110
        258.09843, 259.101, 262.11, 267.122, 268.126,
        271.134, 270.133, 269.1338, 278.156, 281.165,
        # Z = 111-118
        281.166, 285.177, 286.182, 289.19, 289.194,
        293.204, 293.208, 294.214,
    ]
    return torch.tensor(masses).to(device)
138
+
139
+
140
def get_gfn1_rep(device="cpu"):
    """
    Return the per-element GFN1 repulsion parameters as a pair of tensors.

    Both tables are indexed by atomic number for Z = 0-86. Index 0 is a placeholder
    (alpha = 1e-6, Zeff = 0). The raw exponents ``alpha`` and effective charges ``Zeff``
    are rescaled before being returned:

    * ``repa = sqrt(alpha) * Bohr_inv ** 0.75``
    * ``repb = Zeff * sqrt(0.5 * Hartree * Bohr)``

    NOTE(review): the rescaling presumably converts the atomic-unit parameters to this module's
    eV/Angstrom units; confirm against the repulsion-energy code that consumes them.

    Args:
        device: Device the returned tensors are moved to.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(repa, repb)``, each with 87 entries.
    """
    alpha = torch.tensor([
        0.000001,  # index 0: placeholder
        # Z = 1-10
        2.209700, 1.382907, 0.671797, 0.865377, 1.093544,
        1.281954, 1.727773, 2.004253, 2.507078, 3.038727,
        # Z = 11-20
        0.704472, 0.862629, 0.929219, 0.948165, 1.067197,
        1.200803, 1.404155, 1.323756, 0.581529, 0.665588,
        # Z = 21-30
        0.841357, 0.828638, 1.061627, 0.997051, 1.019783,
        1.137174, 1.188538, 1.399197, 1.199230, 1.145056,
        # Z = 31-40
        1.047536, 1.129480, 1.233641, 1.270088, 1.153580,
        1.335287, 0.554032, 0.657904, 0.760144, 0.739520,
        # Z = 41-50
        0.895357, 0.944064, 1.028240, 1.066144, 1.131380,
        1.206869, 1.058886, 1.026434, 0.898148, 1.008192,
        # Z = 51-60
        0.982673, 0.973410, 0.949181, 1.074785, 0.579919,
        0.606485, 1.311200, 0.839861, 0.847281, 0.854701,
        # Z = 61-70
        0.862121, 0.869541, 0.876961, 0.884381, 0.891801,
        0.899221, 0.906641, 0.914061, 0.921481, 0.928901,
        # Z = 71-80
        0.936321, 0.853744, 0.971873, 0.992643, 1.132106,
        1.118216, 1.245003, 1.304590, 1.293034, 1.181865,
        # Z = 81-86
        0.976397, 0.988859, 1.047194, 1.013118, 0.964652,
        0.998641,
    ])
    zeff = torch.tensor([
        0.000000,  # index 0: placeholder
        # Z = 1-10
        1.116244, 0.440231, 2.747587, 4.076830, 4.458376,
        4.428763, 5.498808, 5.171786, 6.931741, 9.102523,
        # Z = 11-20
        10.591259, 15.238107, 16.283595, 16.898359, 15.249559,
        15.100323, 17.000000, 17.153132, 20.831436, 19.840212,
        # Z = 21-30
        18.676202, 17.084130, 22.352532, 22.873486, 24.160655,
        25.983149, 27.169215, 23.396999, 29.000000, 31.185765,
        # Z = 31-40
        33.128619, 35.493164, 36.125762, 32.148852, 35.000000,
        36.000000, 39.653032, 38.924904, 39.000000, 36.521516,
        # Z = 41-50
        40.803132, 41.939347, 43.000000, 44.492732, 45.241537,
        42.105527, 43.201446, 49.016827, 51.718417, 54.503455,
        # Z = 51-60
        50.757213, 49.215262, 53.000000, 52.500985, 65.029838,
        46.532974, 48.337542, 30.638143, 34.130718, 37.623294,
        # Z = 61-70
        41.115870, 44.608445, 48.101021, 51.593596, 55.086172,
        58.578748, 62.071323, 65.563899, 69.056474, 72.549050,
        # Z = 71-80
        76.041625, 55.222897, 63.743065, 74.000000, 75.000000,
        76.000000, 77.000000, 78.000000, 79.000000, 80.000000,
        # Z = 81-86
        81.000000, 79.578302, 83.000000, 84.000000, 85.000000,
        86.000000,
    ])
    repa = alpha.pow(0.5) * Bohr_inv**0.75
    repb = zeff * (0.5 * Hartree * Bohr) ** 0.5
    return repa.to(device), repb.to(device)
324
+
325
+
326
def get_r4r2(device="cpu"):
    """
    Return the per-element ``r4r2`` dispersion scaling factors, indexed by atomic number.

    Each entry is ``sqrt(0.5 * r4_over_r2[Z] * sqrt(Z))``, computed from the tabulated ratios in
    https://github.com/dftd4/dftd4/blob/main/src/dftd4/data/r4r2.f90 (index 0 gives 0).

    NOTE(review): historically labelled "for DFT-D3" although the table is cited from dftd4;
    confirm which dispersion model consumes it.

    Args:
        device: Device the returned tensor is moved to.

    Returns:
        torch.Tensor: 1-D tensor with 119 entries (Z = 0-118).
    """
    # Raw r4/r2 ratios from the cited table; the sqrt(Z) factor is applied below.
    r4_over_r2 = (
        [0.0]  # index 0: placeholder
        + [
            # Z = 1-10
            8.0589, 3.4698, 29.0974, 14.8517, 11.8799,
            7.8715, 5.5588, 4.7566, 3.8025, 3.1036,
            # Z = 11-20
            26.1552, 17.2304, 17.7210, 12.7442, 9.5361,
            8.1652, 6.7463, 5.6004, 29.2012, 22.3934,
            # Z = 21-30
            19.0598, 16.8590, 15.4023, 12.5589, 13.4788,
            12.2309, 11.2809, 10.5569, 10.1428, 9.4907,
            # Z = 31-40
            13.4606, 10.8544, 8.9386, 8.1350, 7.1251,
            6.1971, 30.0162, 24.4103, 20.3537, 17.4780,
            # Z = 41-50
            13.5528, 11.8451, 11.0355, 10.1997, 9.5414,
            9.0061, 8.6417, 8.9975, 14.0834, 11.8333,
            # Z = 51-60
            10.0179, 9.3844, 8.4110, 7.5152, 32.7622,
            27.5708, 23.1671, 21.6003, 20.9615, 20.4562,
            # Z = 61-70
            20.1010, 19.7475, 19.4828, 15.6013, 19.2362,
            17.4717, 17.8321, 17.4237, 17.1954, 17.1631,
            # Z = 71-80
            14.5716, 15.8758, 13.8989, 12.4834, 11.4421,
            10.2671, 8.3549, 7.8496, 7.3278, 7.4820,
            # Z = 81-90
            13.5124, 11.6554, 10.0959, 9.7340, 8.8584,
            8.0125, 29.8135, 26.3157, 19.1885, 15.8542,
            # Z = 91-94
            16.1305, 15.6161, 15.1226, 16.1576,
        ]
        + [0.0] * 17  # Z = 95-111: tabulated as zero
        + [
            # Z = 112-118
            5.4929, 6.7286, 6.5144, 10.9169, 10.3600,
            9.4723, 8.6641,
        ]
    )
    ratios = torch.tensor(r4_over_r2)
    atomic_numbers = torch.arange(ratios.numel())
    r4r2 = torch.sqrt(0.5 * ratios * torch.sqrt(atomic_numbers))
    return r4r2.to(device)
453
+
454
+
455
def get_dftd3_param(device="cpu"):
    """
    Load the collection of DFT-D3 parameters stored in ``dftd3_data.pt``.

    The file is looked up in the directory of this module. It is loaded with
    ``torch.load(..., weights_only=True)``, using ``device`` as ``map_location``.

    Args:
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        dict: The loaded parameters. It is guaranteed to contain the keys
        ``"c6ab"``, ``"r4r2"``, ``"rcov"`` and ``"cnmax"``.

    Raises:
        FileNotFoundError: If ``dftd3_data.pt`` is missing.
        TypeError: If the file does not contain a dict.
        ValueError: If any of the required keys is missing.
    """
    dirname = os.path.dirname(__file__)
    filename = os.path.join(dirname, "dftd3_data.pt")
    if not os.path.exists(filename):
        raise FileNotFoundError(f"dftd3_data.pt not found in {dirname}.")
    param = torch.load(filename, map_location=device, weights_only=True)
    # Explicit checks rather than `assert`, which is stripped when Python runs with -O.
    if not isinstance(param, dict):
        raise TypeError(f"Expected a dict in {filename}, got {type(param).__name__}.")
    missing = [key for key in ("c6ab", "r4r2", "rcov", "cnmax") if key not in param]
    if missing:
        raise ValueError(f"{filename} is missing required DFT-D3 parameter(s): {', '.join(missing)}.")
    return param
@@ -0,0 +1 @@
1
+ from .sgdataset import DataGroup, SizeGroupedDataset, SizeGroupedSampler # noqa: F401