imt-ring 1.6.32__py3-none-any.whl → 1.6.33__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.32
3
+ Version: 1.6.33
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -15,8 +15,8 @@ ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXp
15
15
  ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
16
16
  ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
17
17
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
18
- ring/algorithms/generator/base.py,sha256=rrhHg6lFPDJs72kXvzF15v1vzkaUTKtCnpcmWZONYA8,16847
19
- ring/algorithms/generator/batch.py,sha256=9yFxVv11hij-fJXGPxA3zEh1bE2_jrZk0R7kyGaiM5c,2551
18
+ ring/algorithms/generator/base.py,sha256=jGQocoNZ5tkiMazBDCv-jD6FNYwebqn0_RgVFse49pg,16890
19
+ ring/algorithms/generator/batch.py,sha256=P51UnAZl9TUF_eVq58VL1CsmPPStPHhRDdKjUyvu4EA,2652
20
20
  ring/algorithms/generator/finalize_fns.py,sha256=nY2RKiLbHriTkdec94lc4UGSZKd0v547MDNn4dr8I3E,10398
21
21
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
22
22
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
@@ -78,7 +78,7 @@ ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
78
78
  ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
79
79
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
80
80
  ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
81
- ring/utils/dataloader_torch.py,sha256=wMKJ-eCJ4cHjisGODOZgDVG2r-XQjSANBQFfC05wpzo,2092
81
+ ring/utils/dataloader_torch.py,sha256=bravdBqbkxxcQDieg6OnmArGwGcWpMI3MmNFwTCt0qg,3808
82
82
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
83
83
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
84
84
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
86
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.32.dist-info/METADATA,sha256=6bNRA4bhdmUOnqV8ZfR-tl5wbk-awzTE7qmLzDOI1xs,4251
90
- imt_ring-1.6.32.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
- imt_ring-1.6.32.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.32.dist-info/RECORD,,
89
+ imt_ring-1.6.33.dist-info/METADATA,sha256=FYe4G7jx8u4IblPmrFrpnqCxL3Nv-ITa6LG9fVGaOng,4251
90
+ imt_ring-1.6.33.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
+ imt_ring-1.6.33.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.33.dist-info/RECORD,,
@@ -213,6 +213,8 @@ class RCMG:
213
213
  )
214
214
  save_fn(d, file)
215
215
  i += 1
216
+ # cleanup
217
+ del data
216
218
 
217
219
  gens, n_calls = self._generators_ncalls(sizes)
218
220
  batch.generators_eager(gens, n_calls, callback, seed, self._disable_tqdm)
@@ -1,3 +1,4 @@
1
+ import gc
1
2
  from typing import Callable
2
3
 
3
4
  import jax
@@ -83,4 +84,8 @@ def generators_eager(
83
84
 
84
85
  sample_flat, _ = jax.tree_util.tree_flatten(sample)
85
86
  size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
86
- callback([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
87
+ callback([jax.tree_map(lambda a: a[i].copy(), sample) for i in range(size)])
88
+
89
+ # cleanup
90
+ del sample, sample_flat
91
+ gc.collect()
@@ -1,8 +1,9 @@
1
1
  import os
2
- from typing import Optional
2
+ from typing import Any, Optional
3
3
  import warnings
4
4
 
5
5
  import jax
6
+ import numpy as np
6
7
  import torch
7
8
  from torch.utils.data import DataLoader
8
9
  from torch.utils.data import Dataset
@@ -12,7 +13,7 @@ from ring.utils import parse_path
12
13
  from ring.utils import pickle_load
13
14
 
14
15
 
15
- class FolderOfPickleFilesDataset(Dataset):
16
+ class FolderOfFilesDataset(Dataset):
16
17
  def __init__(self, path, transform=None):
17
18
  self.files = self.listdir(path)
18
19
  self.transform = transform
@@ -22,7 +23,7 @@ class FolderOfPickleFilesDataset(Dataset):
22
23
  return self.N
23
24
 
24
25
  def __getitem__(self, idx: int):
25
- element = pickle_load(self.files[idx])
26
+ element = self._load_file(self.files[idx])
26
27
  if self.transform is not None:
27
28
  element = self.transform(element)
28
29
  return element
@@ -31,6 +32,10 @@ class FolderOfPickleFilesDataset(Dataset):
31
32
  def listdir(path: str) -> list:
32
33
  return [parse_path(path, file) for file in os.listdir(path)]
33
34
 
35
+ @staticmethod
36
+ def _load_file(file_path: str) -> Any:
37
+ return pickle_load(file_path)
38
+
34
39
 
35
40
  def dataset_to_generator(
36
41
  dataset: Dataset,
@@ -84,3 +89,60 @@ def _get_number_of_logical_cores() -> int:
84
89
  )
85
90
  N = 0
86
91
  return N
92
+
93
+
94
+ class MultiDataset(Dataset):
95
+ def __init__(self, datasets, transform=None):
96
+ """
97
+ Args:
98
+ datasets: A list of datasets to sample from.
99
+ transform: A function that takes N items (one from each dataset) and combines them.
100
+ """ # noqa: E501
101
+ self.datasets = datasets
102
+ self.transform = transform
103
+
104
+ def __len__(self):
105
+ # Length is defined by the smallest dataset in the list
106
+ return min(len(ds) for ds in self.datasets)
107
+
108
+ def __getitem__(self, idx):
109
+ sampled_items = [ds[idx] for ds in self.datasets]
110
+
111
+ if self.transform:
112
+ # Apply the transformation to all sampled items
113
+ return self.transform(*sampled_items)
114
+
115
+ return tuple(sampled_items)
116
+
117
+
118
+ class ShuffledDataset(Dataset):
119
+ def __init__(self, dataset):
120
+ """
121
+ Wrapper that shuffles the dataset indices once.
122
+
123
+ Args:
124
+ dataset (Dataset): The original dataset to shuffle.
125
+ """
126
+ self.dataset = dataset
127
+ self.shuffled_indices = np.random.permutation(
128
+ len(dataset)
129
+ ) # Shuffle indices once
130
+
131
+ def __len__(self):
132
+ return len(self.dataset)
133
+
134
+ def __getitem__(self, idx):
135
+ """
136
+ Returns the data at the shuffled index.
137
+
138
+ Args:
139
+ idx (int): Index in the shuffled dataset.
140
+ """
141
+ original_idx = self.shuffled_indices[idx]
142
+ return self.dataset[original_idx]
143
+
144
+
145
+ def dataset_to_Xy(ds: Dataset):
146
+ return dataset_to_generator(ds, batch_size=len(ds), shuffle=False, num_workers=0)(
147
+ None
148
+ )