imt-ring 1.6.17__py3-none-any.whl → 1.6.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.17
3
+ Version: 1.6.18
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -78,6 +78,7 @@ ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
78
78
  ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
79
79
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
80
80
  ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
81
+ ring/utils/dataloader_torch.py,sha256=DR2uUiA9x49_6EBjnbVLfWu7GBX7wtKjgHSIlF80HO0,1502
81
82
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
82
83
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
83
84
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
@@ -85,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
85
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
86
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
87
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
88
- imt_ring-1.6.17.dist-info/METADATA,sha256=j0IKIyc6qAgz9059Z4-b46hi7qTn_ffI7oLw_OrD_Tk,3833
89
- imt_ring-1.6.17.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
90
- imt_ring-1.6.17.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
91
- imt_ring-1.6.17.dist-info/RECORD,,
89
+ imt_ring-1.6.18.dist-info/METADATA,sha256=BhREl-3Q3LDO-ugNEfitQuW2ZDCR-ng-Hleb9FVG6Ps,3833
90
+ imt_ring-1.6.18.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
91
+ imt_ring-1.6.18.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.18.dist-info/RECORD,,
@@ -0,0 +1,62 @@
1
+ import os
2
+
3
+ import jax
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+ from torch.utils.data import Dataset
7
+ from tree_utils import PyTree
8
+
9
+ from ring.utils import parse_path
10
+ from ring.utils import pickle_load
11
+
12
+
13
class FolderOfPickleFilesDataset(Dataset):
    """Map-style torch ``Dataset`` over every entry of a single directory.

    Each item is produced by calling ``pickle_load`` on one directory entry
    and, if a ``transform`` was given, passing the loaded object through it.

    Args:
        path: Directory whose entries make up the dataset.
        transform: Optional callable applied to every loaded element.
    """

    def __init__(self, path, transform=None):
        self.files = self.listdir(path)
        self.transform = transform
        # Length is fixed at construction; the directory is not re-scanned.
        self.N = len(self.files)

    def __len__(self):
        return self.N

    def __getitem__(self, idx: int):
        # NOTE(review): assuming ``pickle_load`` unpickles the file, this can
        # execute arbitrary code -- only use directories you trust.
        element = pickle_load(self.files[idx])
        if self.transform is not None:
            element = self.transform(element)
        return element

    @staticmethod
    def listdir(path: str) -> list:
        """Return ``parse_path(path, entry)`` for every entry of ``path``, sorted by name.

        ``os.listdir`` yields entries in arbitrary, filesystem-dependent order.
        Sorting makes the index -> file mapping, and therefore any seeded
        shuffle over this dataset, reproducible across machines. Every entry
        is included; subdirectories and non-pickle files are not filtered out.
        """
        return [parse_path(path, file) for file in sorted(os.listdir(path))]
31
+
32
+
33
def dataset_to_generator(
    dataset: Dataset,
    batch_size: int,
    shuffle=True,
    seed: int = 1,
    **kwargs,
):
    """Wrap ``dataset`` in a ``DataLoader`` and expose it as an endless batch function.

    Args:
        dataset: Torch dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Whether the ``DataLoader`` shuffles (reshuffled every epoch).
        seed: Passed to ``torch.manual_seed``. Note that this seeds torch's
            *global* RNG, so it also affects other torch code in the process.
        **kwargs: Forwarded to ``DataLoader``. If ``num_workers > 0`` and the
            caller gives no ``multiprocessing_context``, ``"spawn"`` is used.

    Returns:
        A function ``generator(_)`` that ignores its argument and returns the
        next batch, with every tensor leaf converted via ``Tensor.numpy()``.
        When an epoch is exhausted, a new one is started transparently.

    Raises:
        RuntimeError: From ``generator``, if a fresh epoch yields no batch
            (e.g. an empty dataset, or ``drop_last=True`` with fewer samples
            than ``batch_size``).
    """
    torch.manual_seed(seed)

    if kwargs.get("num_workers", 0) > 0:
        # Default to "spawn" but respect an explicit caller choice; passing
        # it both explicitly and via **kwargs would be a duplicate keyword.
        # NOTE(review): presumably "spawn" avoids forking a process that has
        # JAX initialised -- confirm.
        kwargs.setdefault("multiprocessing_context", "spawn")

    dl = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
    dl_iter = iter(dl)

    def to_numpy(tree: "PyTree[torch.Tensor]"):
        # ``jax.tree_map`` is deprecated/removed; ``jax.tree_util`` is stable.
        return jax.tree_util.tree_map(lambda tensor: tensor.numpy(), tree)

    def generator(_):
        nonlocal dl_iter
        try:
            batch = next(dl_iter)
        except StopIteration:
            # Epoch exhausted: start a new pass (reshuffled if shuffle=True).
            dl_iter = iter(dl)
            try:
                batch = next(dl_iter)
            except StopIteration:
                # A bare StopIteration escaping a plain function can silently
                # terminate a caller's loop; surface the problem instead.
                raise RuntimeError(
                    "DataLoader produced no batch in a fresh epoch (empty "
                    "dataset, or drop_last=True with fewer samples than "
                    "batch_size)."
                ) from None
        return to_numpy(batch)

    return generator