aimnet-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,517 @@
+ import os
+ from glob import glob
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
+
+ import h5py
+ import numpy as np
+ from torch.utils.data.dataloader import DataLoader, default_collate
+
+
+ class DataGroup:
+     """Dict-like container of numpy arrays that share the same first dimension.
+
+     Args:
+         data (str | Dict[str, np.ndarray] | h5py.Group | None): The data to be used in the
+             dataset: a path to an NPZ file, a dictionary mapping string keys to numpy arrays,
+             or an open HDF5 group.
+         keys (List[str] | None): If given, only these keys are taken from the data.
+         shard (Tuple[int, int] | None): Optional (shard index, total shards) pair; only every
+             `total`-th sample starting from `index` is kept.
+     """
+
+     def __init__(
+         self,
+         data: Optional[str | Dict[str, np.ndarray] | h5py.Group] = None,
+         keys: Optional[List[str]] = None,
+         shard: Optional[Tuple[int, int]] = None,
+     ):
+         # main container for data
+         self._data: Dict[str, np.ndarray] = {}
+
+         if data is None:
+             data = {}
+
+         # shard (i, n) keeps every n-th sample starting from index i
+         s = slice(shard[0], None, shard[1]) if shard is not None else slice(None)
+
+         # load to dict
+         if isinstance(data, str):
+             if not os.path.isfile(data):
+                 raise FileNotFoundError(f"{data} does not exist or is not a file.")
+             data = np.load(data, mmap_mode="r")
+             if not hasattr(data, "files"):
+                 raise ValueError(f"Data file {data} does not contain named arrays.")
+
+         # take only requested keys
+         if keys is None:
+             keys = data.keys()  # type: ignore[union-attr]
+         data = {k: v[s] for k, v in data.items() if k in keys}  # type: ignore[union-attr]
+
+         # check that all arrays share the same first dimension
+         _n = None
+         for k, v in data.items():
+             if not isinstance(k, str):
+                 raise TypeError(f"Expected key to be of type str, but got {type(k).__name__}")
+             if keys is not None and k not in keys:
+                 continue
+             if _n is None:
+                 _n = len(v)
+             if len(v) != _n:
+                 raise ValueError(f"Inconsistent data shape for key {k}. Expected first dimension {_n}, got {len(v)}.")
+             self._data[k] = v
+
+     def __getitem__(self, key):
+         return self._data[key]
+
+     def __setitem__(self, key, value):
+         if not isinstance(key, str):
+             raise TypeError(f"Failed to set key of type {type(key)}, expected str.")
+         if not isinstance(value, np.ndarray):
+             raise TypeError(f"Failed to set item of wrong type. Expected np.ndarray, got {type(value).__name__}.")
+         if len(self) and len(value) != len(self):
+             raise ValueError(f"Failed to set item of wrong shape. Expected {len(self)}, got {len(value)}.")
+         self._data[key] = value
+
+     def __delitem__(self, key):
+         del self._data[key]
+
+     def __contains__(self, key):
+         return key in self._data
+
+     def __len__(self):
+         return len(next(iter(self.values()))) if self._data else 0
+
+     def to_dict(self):
+         return self._data
+
+     def items(self):
+         return self._data.items()
+
+     def values(self):
+         return self._data.values()
+
+     def keys(self):
+         return self._data.keys()
+
+     def pop(self, key):
+         return self._data.pop(key)
+
+     def rename_key(self, old, new):
+         self[new] = self.pop(old)
+
+     def sample(self, idx, keys=None) -> "DataGroup":
+         """Return a new `DataGroup` with the data indexed by `idx`."""
+         if keys is None:
+             keys = self.keys()
+         if isinstance(idx, int):
+             idx = slice(idx, idx + 1)
+         return self.__class__({k: self[k][idx] for k in keys})
+
+     def random_split(self, *fractions, seed=None):
+         if not (0 < sum(fractions) <= 1):
+             raise ValueError("Sum of fractions must be greater than 0 and not exceed 1.")
+         if not all(f > 0 for f in fractions):
+             raise ValueError("All fractions must be greater than 0.")
+         idx = np.arange(len(self))
+         np.random.seed(seed)
+         np.random.shuffle(idx)
+         sections = np.around(np.cumsum(fractions) * len(self)).astype(np.int64)
+         return [self.sample(sidx) if len(sidx) else self.__class__() for sidx in np.array_split(idx, sections)]
+
+     def cv_split(self, cv: int = 5, seed=None):
+         """Return a list of `cv` (train, val) `DataGroup` tuples."""
+         fractions = [1 / cv] * cv
+         parts = self.random_split(*fractions, seed=seed)
+         splits = []
+         for icv in range(cv):
+             val = parts[icv]
+             _idx = [_i for _i in range(cv) if _i != icv]
+             # copy the first train part, otherwise `cat` below would grow it in place
+             # and corrupt the remaining folds
+             train = parts[_idx[0]].sample(slice(None))
+             train.cat(*[parts[_i] for _i in _idx[1:]])
+             splits.append((train, val))
+         return splits
+
+     def save(self, filename, compress=False):
+         op = np.savez_compressed if compress else np.savez
+         if len(self):
+             op(filename, **self._data)
+
+     def shuffle(self, seed=None):
+         idx = np.arange(len(self))
+         np.random.seed(seed)
+         np.random.shuffle(idx)
+         for k, v in self.items():
+             self[k] = v[idx]
+
+     def cat(self, *others):
+         """Concatenate `others` to this group along the first dimension."""
+         _n = set(self.keys())
+         for other in others:
+             if set(other.keys()) != _n:
+                 raise ValueError("Inconsistent data keys.")
+         for k, v in self.items():
+             self._data[k] = np.concatenate([v, *[other[k] for other in others]], axis=0)
+
+     def iter_batched(self, batch_size=128, keys=None):
+         idx = np.arange(len(self))
+         idxs = np.array_split(idx, int(np.ceil(len(self) / batch_size)))
+         if keys is None:
+             keys = self.keys()
+         for idx in idxs:
+             yield {k: v[idx] for k, v in self.items() if k in keys}
+
+     def merge(self, other, strict=True):
+         if strict:
+             if set(self.keys()) != set(other.keys()):
+                 raise ValueError("Data keys do not match between the datasets.")
+             keys = self.keys()
+         else:
+             keys = set(self.keys()) & set(other.keys())
+         for k in list(self.keys()):
+             if k in keys:
+                 self._data[k] = np.concatenate([self[k], other[k]], axis=0)
+             else:
+                 self.pop(k)
+
+     def apply_peratom_shift(self, sap_dict, key_in="energy", key_out="energy", numbers_key="numbers"):
+         # build a lookup table of per-element shifts and subtract their per-sample sum
+         ntyp = max(sap_dict.keys()) + 1
+         sap = np.full(ntyp, np.nan)
+         for k, v in sap_dict.items():
+             sap[k] = v
+         self._data[key_out] = self[key_in] - sap[self[numbers_key]].sum(axis=-1)
+
+
+ class SizeGroupedDataset:
+     """Collection of `DataGroup`s keyed by an integer group index (molecule size)."""
+
+     def __init__(
+         self,
+         data: Optional[str | List[str] | Dict[int, Dict[str, np.ndarray]] | Dict[int, DataGroup]] = None,
+         keys: Optional[List[str]] = None,
+         shard: Optional[Tuple[int, int]] = None,
+     ):
+         # main containers
+         self._data: Dict[int, DataGroup] = {}
+         self._meta: Dict[str, str] = {}
+
+         # load data: directory of NPZ files, single HDF5 file, list of NPZ files, or dict
+         if isinstance(data, str):
+             if os.path.isdir(data):
+                 self.load_datadir(data, keys=keys, shard=shard)
+             else:
+                 self.load_h5(data, keys=keys, shard=shard)
+         elif isinstance(data, (list, tuple)):
+             self.load_files(data, keys=keys, shard=shard)
+         elif isinstance(data, dict):
+             self.load_dict(data, keys=keys)
+         self.loader_mode = False
+         self.x: List[str] = []
+         self.y: List[str] = []
+
+     def load_datadir(self, path, keys=None, shard: Optional[Tuple[int, int]] = None):
+         if not os.path.isdir(path):
+             raise FileNotFoundError(f"{path} does not exist or is not a directory.")
+         for f in glob(os.path.join(path, "???.npz")):
+             k = int(os.path.basename(f)[:3])
+             self[k] = DataGroup(f, keys=keys, shard=shard)
+
+     def load_files(self, files, keys=None, shard: Optional[Tuple[int, int]] = None):
+         for fil in files:
+             if not os.path.isfile(fil):
+                 raise FileNotFoundError(f"{fil} does not exist or is not a file.")
+             k = int(os.path.splitext(os.path.basename(fil))[0])
+             self[k] = DataGroup(fil, keys=keys, shard=shard)
+
+     def load_dict(self, data, keys=None):
+         for k, v in data.items():
+             self[k] = DataGroup(v, keys=keys)
+
+     def load_h5(self, data, keys=None, shard: Optional[Tuple[int, int]] = None):
+         with h5py.File(data, "r") as f:
+             for k, g in f.items():
+                 k = int(k)
+                 self[k] = DataGroup(g, keys=keys, shard=shard)
+             self._meta = dict(f.attrs)  # type: ignore[attr-defined]
+
+     def keys(self) -> List[int]:
+         return sorted(self._data.keys())
+
+     def values(self) -> List:
+         return [self[k] for k in self.keys()]
+
+     def items(self) -> List[Tuple[int, Any]]:
+         return [(k, self[k]) for k in self.keys()]
+
+     def datakeys(self):
+         return next(iter(self._data.values())).keys() if self._data else set()
+
+     @property
+     def groups(self) -> List[DataGroup]:
+         return self.values()
+
+     def __len__(self):
+         return sum(len(d) for d in self.values())
+
+     def __setitem__(self, key: int, value: DataGroup):
+         if not isinstance(key, int):
+             raise TypeError(f"Failed to set key of type {type(key)}, expected int.")
+         if not isinstance(value, DataGroup):
+             raise TypeError(f"Failed to set item of wrong type. Expected DataGroup, got {type(value)}.")
+         if self._data and set(self.datakeys()) != set(value.keys()):
+             raise ValueError("Wrong set of data keys.")
+         self._data[key] = value
+
+     def __getitem__(self, item: int | Tuple[int, Sequence]) -> Dict | Tuple[Dict, Dict]:
+         if isinstance(item, int):
+             # single group
+             ret = self._data[item]
+         else:
+             # (group, indices) pair, as produced by SizeGroupedSampler
+             grp, idx = item
+             if self.loader_mode:
+                 ret = (
+                     {k: v[idx] for k, v in self[grp].items() if k in self.x},  # type: ignore[union-attr, assignment]
+                     {k: v[idx] for k, v in self[grp].items() if k in self.y},  # type: ignore[union-attr, assignment]
+                 )
+             else:
+                 ret = {k: v[idx] for k, v in self[grp].items()}  # type: ignore[union-attr, assignment]
+         return ret  # type: ignore[return-value]
+
+     def __contains__(self, value):
+         return value in self.keys()
+
+     def rename_datakey(self, old, new):
+         for g in self.groups:
+             g.rename_key(old, new)
+
+     def apply(self, fn):
+         for grp in self.groups:
+             fn(grp)
+
+     def merge(self, other, strict=True):
+         if not isinstance(other, self.__class__):
+             other = self.__class__(other)
+         if strict:
+             if set(other.datakeys()) != set(self.datakeys()):
+                 raise ValueError("Data keys do not match between the datasets.")
+         else:
+             # keep only the data keys common to both datasets
+             keys = set(other.datakeys()) & set(self.datakeys())
+             for k in list(self.datakeys()):
+                 if k not in keys:
+                     for g in self.groups:
+                         g.pop(k)
+             for k in list(other.datakeys()):
+                 if k not in keys:
+                     for g in other.groups:
+                         g.pop(k)
+         for k in other.keys():  # the class defines no __iter__, so iterate group keys explicitly
+             if k in self:
+                 self[k].cat(other[k])  # type: ignore[attr-defined]
+             else:
+                 self[k] = other[k]  # type: ignore[attr-defined]
+
+     def random_split(self, *fractions, seed=None):
+         splitted_groups = {}
+         for k, v in self.items():
+             splitted_groups[k] = v.random_split(*fractions, seed=seed)
+         datasets = []
+         for i in range(len(fractions)):
+             datasets.append(
+                 self.__class__({k: splitted_groups[k][i] for k in splitted_groups if len(splitted_groups[k][i]) > 0})
+             )
+         return datasets
+
+     def cv_split(self, cv: int = 5, seed=None):
+         splitted_groups = {}
+         for k, v in self.items():
+             splitted_groups[k] = v.cv_split(cv, seed)
+         datasets = []
+         for i in range(cv):
+             train = self.__class__({k: splitted_groups[k][i][0] for k in splitted_groups})
+             val = self.__class__({k: splitted_groups[k][i][1] for k in splitted_groups})
+             datasets.append((train, val))
+         return datasets
+
+     def shuffle(self, seed=None):
+         for v in self.values():
+             v.shuffle(seed)
+
+     def save(self, dirname, namemap_fn: Optional[Callable] = None, compress: bool = False):
+         os.makedirs(dirname, exist_ok=True)
+         if namemap_fn is None:
+             namemap_fn = lambda x: f"{x:03d}.npz"
+         for k, v in self.items():
+             fname = os.path.join(dirname, namemap_fn(k))
+             v.save(fname, compress=compress)
+
+     def save_h5(self, filename):
+         with h5py.File(filename, "w") as f:
+             for n, g in self.items():
+                 n = f"{n:03d}"
+                 h5g = f.create_group(n)
+                 for k, v in g.items():
+                     h5g.create_dataset(k, data=v)
+             for k, v in self._meta.items():
+                 f.attrs[k] = v
+
+     def merge_groups(self, min_size=1, mode_atoms=False, atom_key="numbers"):
+         # build a list of supergroups, each holding at least `min_size` molecules
+         # (or atoms, if mode_atoms is True)
+         sgroups = []
+         n = 0
+         sg = []
+         for k, v in self.items():
+             _n = len(v)
+             if mode_atoms:
+                 _n *= v[atom_key].shape[1]
+             n += _n
+             sg.append(k)
+             if n >= min_size:
+                 sgroups.append(sg)
+                 n = 0
+                 sg = []
+         # attach any leftover groups to the last supergroup
+         if sg:
+             if sgroups:
+                 sgroups[-1].extend(sg)
+             else:
+                 sgroups.append(sg)
+
+         # merge each supergroup into its last member, padding arrays as needed
+         keys = self.datakeys()
+         for sg in sgroups:
+             for k in keys:
+                 arrs = [self[n][k] for n in sg]  # type: ignore
+                 arrs = self._collate(arrs)
+                 self[sg[-1]]._data[k] = arrs  # type: ignore
+             for n in sg[:-1]:
+                 del self._data[n]
+
+     @staticmethod
+     def _collate(arrs, pad_value=0):
+         # stack arrays with different trailing shapes into one zero-padded array
+         n = sum(a.shape[0] for a in arrs)
+         shape = np.stack([a.shape[1:] for a in arrs], axis=0).max(axis=0)
+         arr = np.full((n, *shape), pad_value, dtype=arrs[0].dtype)
+         i = 0
+         for a in arrs:
+             n = a.shape[0]
+             slices = tuple([slice(i, i + n)] + [slice(0, x) for x in a.shape[1:]])
+             arr[slices] = a
+             i += n
+         return arr
+
+     def concatenate(self, key):
+         try:
+             c = np.concatenate([g[key] for g in self.values() if len(g)], axis=0)
+         except ValueError:
+             # groups have different trailing shapes; flatten before concatenating
+             c = np.concatenate([g[key].flatten() for g in self.values() if len(g)], axis=0)
+         return c
+
+     def apply_peratom_shift(self, key_in="energy", key_out="energy", numbers_key="numbers", sap_dict=None):
+         if sap_dict is None:
+             # least-squares fit of per-element shifts from element counts
+             E = self.concatenate(key_in)
+             ntyp = max(g[numbers_key].max() for g in self.groups) + 1
+             eye = np.eye(ntyp, dtype=np.min_scalar_type(ntyp))
+             F = np.concatenate([eye[g[numbers_key]].sum(-2) for g in self.values()])
+             sap = np.linalg.lstsq(F, E, rcond=None)[0]
+             present_elements = np.nonzero(F.sum(0))[0]
+         else:
+             ntyp = max(sap_dict.keys()) + 1
+             sap = np.full(ntyp, np.nan)
+             for k, v in sap_dict.items():
+                 sap[k] = v
+             present_elements = sap_dict.keys()
+
+         def fn(g):
+             g[key_out] = g[key_in] - sap[g[numbers_key]].sum(axis=-1)
+
+         self.apply(fn)
+
+         return {i: sap[i] for i in present_elements}
+
+     def apply_pertype_logratio(self, key_in="volumes", key_out="volumes", numbers_key="numbers", sap_dict=None):
+         if sap_dict is None:
+             # per-element reference value: median of key_in over all atoms of that element
+             numbers = self.concatenate(numbers_key)
+             present_elements = sorted(np.unique(numbers))
+             x = self.concatenate(key_in)
+             sap_dict = {}
+             for n in present_elements:
+                 sap_dict[n] = np.median(x[numbers == n])
+         sap = np.zeros(max(sap_dict.keys()) + 1)
+         for n, v in sap_dict.items():
+             sap[n] = v
+
+         def fn(g):
+             g[key_out] = np.log(g[key_in] / sap[g[numbers_key]])
+
+         self.apply(fn)
+         return sap_dict
+
+     def numpy_batches(self, batch_size=128, keys=None):
+         for g in self.values():
+             yield from g.iter_batched(batch_size, keys)
+
+     def get_loader(self, sampler, x: List[str], y: Optional[List[str]] = None, **loader_kwargs):
+         self.loader_mode = True
+         self.x = x
+         self.y = y or []
+
+         loader = DataLoader(self, batch_sampler=sampler, **loader_kwargs)  # type: ignore
+
+         def _squeeze(t):
+             # default_collate adds a leading batch dimension of 1; remove it in place
+             for d in t:
+                 for v in d.values():
+                     v.squeeze_(0)
+             return t
+
+         loader.collate_fn = lambda x: _squeeze(default_collate(x))
+         return loader
+
+
+ class SizeGroupedSampler:
+     """Batch sampler yielding one-element batches of (group_key, indices) pairs for a SizeGroupedDataset."""
+
+     def __init__(
+         self,
+         ds: SizeGroupedDataset,
+         batch_size: int,
+         batch_mode: str = "molecules",
+         shuffle: bool = False,
+         batches_per_epoch: int = -1,
+     ):
+         self.ds = ds
+         self.batch_size = batch_size
+         if batch_mode not in ["molecules", "atoms"]:
+             raise ValueError(f"Unknown batch_mode {batch_mode}")
+         self.batch_mode = batch_mode
+         self.shuffle = shuffle
+         self.batches_per_epoch = batches_per_epoch
+
+     def __len__(self):
+         if self.batches_per_epoch > 0:
+             return self.batches_per_epoch
+         else:
+             return sum(self._get_num_batches_for_group(g) for g in self.ds.groups)
+
+     def __iter__(self):
+         return iter(self._samples_list())
+
+     def _get_num_batches_for_group(self, g):
+         # batch_size counts molecules or atoms, depending on batch_mode
+         if self.batch_mode == "molecules":
+             return int(np.ceil(len(g) / self.batch_size))
+         elif self.batch_mode == "atoms":
+             return int(np.ceil(len(g) * g["numbers"].shape[1] / self.batch_size))
+         else:
+             raise ValueError(f"Unknown batch_mode: {self.batch_mode}")
+
+     def _samples_list(self):
+         samples = []
+         for group_key, g in self.ds.items():
+             n = len(g)
+             if n == 0:
+                 continue
+             idx = np.arange(n)
+             if self.shuffle:
+                 np.random.shuffle(idx)
+             n_batches = self._get_num_batches_for_group(g)
+             # each batch is a single (group_key, indices) item for SizeGroupedDataset.__getitem__
+             samples.extend(((group_key, idx_batch),) for idx_batch in np.array_split(idx, n_batches))
+         if self.shuffle:
+             np.random.shuffle(samples)
+         if self.batches_per_epoch > 0:
+             if len(samples) > self.batches_per_epoch:
+                 samples = samples[: self.batches_per_epoch]
+             else:
+                 # pad the epoch with randomly duplicated batches
+                 idx = np.arange(len(samples))
+                 np.random.shuffle(idx)
+                 n = self.batches_per_epoch - len(samples)
+                 samples.extend([samples[i] for i in np.random.choice(idx, n, replace=True)])
+         return samples
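
For orientation, the sketch below shows how the three classes added in this hunk could be used together. It is illustrative only and is not part of the package: the import path `aimnet.data` is a guess (the diff does not name the file for this hunk), and the data directory layout and the "coord" key are assumptions; "numbers" and "energy" are the default key names used in the code above.

    from aimnet.data import SizeGroupedDataset, SizeGroupedSampler  # assumed module path

    # Assumed layout: one NPZ file per molecule size, e.g. data/003.npz, data/004.npz, ...,
    # each with at least "numbers", "coord" and "energy" arrays ("coord" is an assumed key).
    ds = SizeGroupedDataset("data/")

    # 90/10 random split, applied independently within every size group.
    train, val = ds.random_split(0.9, 0.1, seed=0)

    # Fit per-element energy shifts on the training set and apply the same shifts to validation.
    sap = train.apply_peratom_shift(key_in="energy", key_out="energy")
    val.apply_peratom_shift(sap_dict=sap)

    # Batch by total atom count so every batch does a similar amount of work, then build a
    # DataLoader that yields (inputs, targets) dict pairs, each drawn from a single size group.
    sampler = SizeGroupedSampler(train, batch_size=4096, batch_mode="atoms", shuffle=True)
    loader = train.get_loader(sampler, x=["coord", "numbers"], y=["energy"], num_workers=0)

    for inputs, targets in loader:
        ...  # inputs["coord"], inputs["numbers"] and targets["energy"] are torch tensors

Grouping records by molecule size keeps every array within a group rectangular, so no padding is needed until merge_groups combines small groups (its _collate helper then pads with zeros).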
aimnet/dftd3_data.pt ADDED
Binary file
@@ -0,0 +1,2 @@
+ from .aimnet2 import AIMNet2  # noqa: F401
+ from .base import AIMNet2Base  # noqa: F401