imt-ring 1.6.17__py3-none-any.whl → 1.6.18__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.17
3
+ Version: 1.6.18
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -78,6 +78,7 @@ ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
78
78
  ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
79
79
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
80
80
  ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
81
+ ring/utils/dataloader_torch.py,sha256=DR2uUiA9x49_6EBjnbVLfWu7GBX7wtKjgHSIlF80HO0,1502
81
82
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
82
83
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
83
84
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
@@ -85,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
85
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
86
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
87
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
88
- imt_ring-1.6.17.dist-info/METADATA,sha256=j0IKIyc6qAgz9059Z4-b46hi7qTn_ffI7oLw_OrD_Tk,3833
89
- imt_ring-1.6.17.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
90
- imt_ring-1.6.17.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
91
- imt_ring-1.6.17.dist-info/RECORD,,
89
+ imt_ring-1.6.18.dist-info/METADATA,sha256=BhREl-3Q3LDO-ugNEfitQuW2ZDCR-ng-Hleb9FVG6Ps,3833
90
+ imt_ring-1.6.18.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
91
+ imt_ring-1.6.18.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.18.dist-info/RECORD,,
@@ -0,0 +1,62 @@
1
+ import os
2
+
3
+ import jax
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+ from torch.utils.data import Dataset
7
+ from tree_utils import PyTree
8
+
9
+ from ring.utils import parse_path
10
+ from ring.utils import pickle_load
11
+
12
+
13
class FolderOfPickleFilesDataset(Dataset):
    """Map-style torch ``Dataset`` where each entry of a directory is one sample.

    On construction, every entry returned by ``os.listdir(path)`` is joined
    onto ``path`` via ``parse_path``. There is no filtering by extension or
    file type. ``__getitem__`` loads the entry at ``idx`` with
    ``pickle_load`` and, if given, applies ``transform`` to the result.

    NOTE(review): ``pickle_load`` presumably unpickles the file. Only point
    this at trusted data, since unpickling can execute arbitrary code.

    Args:
        path: Directory whose entries form the dataset.
        transform: Optional callable applied to each loaded element.
    """

    def __init__(self, path, transform=None):
        self.files = self.listdir(path)
        self.transform = transform
        self.N = len(self.files)

    def __len__(self):
        return self.N

    def __getitem__(self, idx: int):
        element = pickle_load(self.files[idx])
        if self.transform is not None:
            element = self.transform(element)
        return element

    @staticmethod
    def listdir(path: str) -> list:
        """Return the entries of ``path``, each joined onto ``path``, in sorted order."""
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sort them so that the index -> file mapping (and unshuffled
        # iteration order) is reproducible across runs and machines.
        return [parse_path(path, file) for file in sorted(os.listdir(path))]
31
+
32
+
33
def dataset_to_generator(
    dataset: Dataset,
    batch_size: int,
    shuffle=True,
    seed: int = 1,
    **kwargs,
):
    """Wrap a torch ``Dataset`` as a callable returning one numpy batch per call.

    Args:
        dataset: The torch ``Dataset`` to draw samples from.
        batch_size: Number of samples per batch, forwarded to ``DataLoader``.
        shuffle: Whether the ``DataLoader`` shuffles samples (per epoch).
        seed: Passed to ``torch.manual_seed`` before the ``DataLoader`` is
            built. Note that this seeds torch's *global* RNG.
        **kwargs: Extra keyword arguments forwarded to ``DataLoader``. If
            ``multiprocessing_context`` is not given, it defaults to
            ``"spawn"`` when ``num_workers > 0`` and to ``None`` otherwise.

    Returns:
        A function ``generator(_)`` that ignores its argument and returns the
        next batch, with every ``torch.Tensor`` leaf converted via
        ``.numpy()``. Other leaves are passed through unchanged. When the
        loader is exhausted it is restarted, so batches cycle over epochs
        indefinitely. If the loader yields no batches at all, the
        ``StopIteration`` from the restart propagates to the caller.
    """
    torch.manual_seed(seed)

    # Only supply a default start method. Passing it explicitly alongside
    # **kwargs would raise "got multiple values for keyword argument" when
    # the caller also sets multiprocessing_context.
    if "multiprocessing_context" not in kwargs:
        kwargs["multiprocessing_context"] = (
            "spawn" if kwargs.get("num_workers", 0) > 0 else None
        )

    dl = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
    dl_iter = iter(dl)

    def to_numpy(tree: "PyTree[torch.Tensor]"):
        # Non-tensor leaves (e.g. strings, which default_collate keeps as a
        # list of str) have no .numpy(), so they are passed through as-is.
        return jax.tree_util.tree_map(
            lambda leaf: leaf.numpy() if isinstance(leaf, torch.Tensor) else leaf,
            tree,
        )

    def generator(_):
        nonlocal dl_iter
        try:
            batch = next(dl_iter)
        except StopIteration:
            # Epoch finished: start a fresh pass over the loader
            # (the sampler reshuffles if shuffle=True).
            dl_iter = iter(dl)
            batch = next(dl_iter)
        return to_numpy(batch)

    return generator