imt-ring 1.6.32__py3-none-any.whl → 1.6.34__py3-none-any.whl

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: imt-ring
- Version: 1.6.32
+ Version: 1.6.34
  Summary: RING: Recurrent Inertial Graph-based Estimator
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -15,8 +15,8 @@ ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXp
  ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
  ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
- ring/algorithms/generator/base.py,sha256=rrhHg6lFPDJs72kXvzF15v1vzkaUTKtCnpcmWZONYA8,16847
- ring/algorithms/generator/batch.py,sha256=9yFxVv11hij-fJXGPxA3zEh1bE2_jrZk0R7kyGaiM5c,2551
+ ring/algorithms/generator/base.py,sha256=jGQocoNZ5tkiMazBDCv-jD6FNYwebqn0_RgVFse49pg,16890
+ ring/algorithms/generator/batch.py,sha256=P51UnAZl9TUF_eVq58VL1CsmPPStPHhRDdKjUyvu4EA,2652
  ring/algorithms/generator/finalize_fns.py,sha256=nY2RKiLbHriTkdec94lc4UGSZKd0v547MDNn4dr8I3E,10398
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
@@ -58,7 +58,7 @@ ring/ml/ml_utils.py,sha256=M--qkXRnhU7tHvgfTHfT9gyY0nhj3zMGEaK0X0drFLs,10915
  ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
  ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
  ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
- ring/ml/train.py,sha256=-6SzQKjIgktgRjaXKVg_1dqcBmAJggZSVwDnau1FnxI,10832
+ ring/ml/train.py,sha256=Da89HxiqXC7xuX2ldpTrJStqKWN-6Vcpml4PPQuihN4,10989
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
@@ -78,7 +78,7 @@ ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
  ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
  ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
- ring/utils/dataloader_torch.py,sha256=wMKJ-eCJ4cHjisGODOZgDVG2r-XQjSANBQFfC05wpzo,2092
+ ring/utils/dataloader_torch.py,sha256=bravdBqbkxxcQDieg6OnmArGwGcWpMI3MmNFwTCt0qg,3808
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
- imt_ring-1.6.32.dist-info/METADATA,sha256=6bNRA4bhdmUOnqV8ZfR-tl5wbk-awzTE7qmLzDOI1xs,4251
- imt_ring-1.6.32.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
- imt_ring-1.6.32.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
- imt_ring-1.6.32.dist-info/RECORD,,
+ imt_ring-1.6.34.dist-info/METADATA,sha256=D7FXQFI8b4iXaJiWkM4MvHNqzncvFh_wn4UpVK8iqMs,4251
+ imt_ring-1.6.34.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+ imt_ring-1.6.34.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
+ imt_ring-1.6.34.dist-info/RECORD,,
ring/algorithms/generator/base.py CHANGED
@@ -213,6 +213,8 @@ class RCMG:
  )
  save_fn(d, file)
  i += 1
+ # cleanup
+ del data

  gens, n_calls = self._generators_ncalls(sizes)
  batch.generators_eager(gens, n_calls, callback, seed, self._disable_tqdm)
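The `del data` added inside `RCMG` drops the reference to each generated chunk right after `save_fn` has written it to disk, so the host-memory footprint stays at roughly one chunk. A minimal sketch of that pattern, with a hypothetical `generate_chunk` stand-in rather than the actual RCMG internals:

import gc
import pickle

def dump_chunks(generate_chunk, n_chunks, folder):
    # Sketch only: save each chunk, then drop the reference so the next
    # chunk does not double peak memory (mirrors the `del data` above).
    for i in range(n_chunks):
        data = generate_chunk(i)
        with open(f"{folder}/chunk_{i}.pickle", "wb") as f:
            pickle.dump(data, f)
        del data
        gc.collect()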
ring/algorithms/generator/batch.py CHANGED
@@ -1,3 +1,4 @@
+ import gc
  from typing import Callable

  import jax
@@ -83,4 +84,8 @@ def generators_eager(

  sample_flat, _ = jax.tree_util.tree_flatten(sample)
  size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
- callback([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
+ callback([jax.tree_map(lambda a: a[i].copy(), sample) for i in range(size)])
+
+ # cleanup
+ del sample, sample_flat
+ gc.collect()
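In `generators_eager`, the per-sample slices handed to the callback are now explicit copies, and the batched `sample` is deleted followed by a forced `gc.collect()`. The likely rationale (an assumption, not stated in the diff) is that a NumPy basic slice is a view that keeps the whole parent buffer alive, so without the `.copy()` every retained sample would pin the full batch in memory. A small self-contained illustration, not RING code:

import gc
import numpy as np

big = np.zeros((1024, 1024))   # ~8 MB parent buffer
view = big[0]                  # basic slice -> view into `big`
copy = big[0].copy()           # independent ~8 KB buffer

print(view.base is big)        # True: `big` cannot be freed while `view` lives
print(copy.base is None)       # True: the copy does not pin the parent

del big, view
gc.collect()                   # only now can the 8 MB buffer be reclaimed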
ring/ml/train.py CHANGED
@@ -39,6 +39,7 @@ def _build_step_fn(
  filter: ml_base.AbstractFilter,
  optimizer,
  tbp,
+ skip_first_tbp_batch,
  ):
  """Build step function that optimizes filter parameters based on `metric_fn`.
  `initial_state` has shape (pmap, vmap, state_dim)"""
@@ -89,6 +90,8 @@ def _build_step_fn(
  ):
  (loss, state), grads = pmapped_loss_fn(params, state, X_tbp, y_tbp)
  debug_grads.append(grads)
+ if skip_first_tbp_batch and i == 0:
+ continue
  state = jax.lax.stop_gradient(state)
  params, opt_state = apply_grads(grads, params, opt_state)

@@ -119,6 +122,7 @@ def train_fn(
  loss_fn: LOSS_FN = _default_loss_fn,
  metrices: Optional[METRICES] = _default_metrices,
  link_names: Optional[list[str]] = None,
+ skip_first_tbp_batch: bool = False,
  ) -> bool:
  """Trains RNNO

@@ -161,10 +165,7 @@ def train_fn(
  opt_state = optimizer.init(filter_params)

  step_fn = _build_step_fn(
- loss_fn,
- filter,
- optimizer,
- tbp=tbp,
+ loss_fn, filter, optimizer, tbp=tbp, skip_first_tbp_batch=skip_first_tbp_batch
  )

  # always log, because we also want `i_epsiode` to be logged in wandb
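Taken together, these hunks thread a new `skip_first_tbp_batch` flag (default `False`) from `train_fn` into the step function: when enabled, the optimizer update for the first truncated-backpropagation chunk of each batch is skipped, while later chunks are applied as before. A hedged usage sketch; the positional arguments are placeholders, since the rest of `train_fn`'s signature is not shown in this diff:

from ring.ml.train import train_fn

# Hypothetical placeholders for the arguments not visible in the diff above;
# only `skip_first_tbp_batch` is the keyword added in 1.6.34.
generator = ...      # batched (X, y) data generator
rnno_filter = ...    # the AbstractFilter being trained

train_fn(
    generator,
    rnno_filter,
    skip_first_tbp_batch=True,  # skip the optimizer update of the first TBP chunk
)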
ring/utils/dataloader_torch.py CHANGED
@@ -1,8 +1,9 @@
  import os
- from typing import Optional
+ from typing import Any, Optional
  import warnings

  import jax
+ import numpy as np
  import torch
  from torch.utils.data import DataLoader
  from torch.utils.data import Dataset
@@ -12,7 +13,7 @@ from ring.utils import parse_path
  from ring.utils import pickle_load


- class FolderOfPickleFilesDataset(Dataset):
+ class FolderOfFilesDataset(Dataset):
  def __init__(self, path, transform=None):
  self.files = self.listdir(path)
  self.transform = transform
@@ -22,7 +23,7 @@ class FolderOfPickleFilesDataset(Dataset):
  return self.N

  def __getitem__(self, idx: int):
- element = pickle_load(self.files[idx])
+ element = self._load_file(self.files[idx])
  if self.transform is not None:
  element = self.transform(element)
  return element
@@ -31,6 +32,10 @@ class FolderOfPickleFilesDataset(Dataset):
  def listdir(path: str) -> list:
  return [parse_path(path, file) for file in os.listdir(path)]

+ @staticmethod
+ def _load_file(file_path: str) -> Any:
+ return pickle_load(file_path)
+

  def dataset_to_generator(
  dataset: Dataset,
@@ -84,3 +89,60 @@ def _get_number_of_logical_cores() -> int:
  )
  N = 0
  return N
+
+
+ class MultiDataset(Dataset):
+ def __init__(self, datasets, transform=None):
+ """
+ Args:
+ datasets: A list of datasets to sample from.
+ transform: A function that takes N items (one from each dataset) and combines them.
+ """ # noqa: E501
+ self.datasets = datasets
+ self.transform = transform
+
+ def __len__(self):
+ # Length is defined by the smallest dataset in the list
+ return min(len(ds) for ds in self.datasets)
+
+ def __getitem__(self, idx):
+ sampled_items = [ds[idx] for ds in self.datasets]
+
+ if self.transform:
+ # Apply the transformation to all sampled items
+ return self.transform(*sampled_items)
+
+ return tuple(sampled_items)
+
+
+ class ShuffledDataset(Dataset):
+ def __init__(self, dataset):
+ """
+ Wrapper that shuffles the dataset indices once.
+
+ Args:
+ dataset (Dataset): The original dataset to shuffle.
+ """
+ self.dataset = dataset
+ self.shuffled_indices = np.random.permutation(
+ len(dataset)
+ ) # Shuffle indices once
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ """
+ Returns the data at the shuffled index.
+
+ Args:
+ idx (int): Index in the shuffled dataset.
+ """
+ original_idx = self.shuffled_indices[idx]
+ return self.dataset[original_idx]
+
+
+ def dataset_to_Xy(ds: Dataset):
+ return dataset_to_generator(ds, batch_size=len(ds), shuffle=False, num_workers=0)(
+ None
+ )
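The renamed `FolderOfFilesDataset` now routes file reading through an overridable `_load_file` hook (pickle by default), and the new `MultiDataset` / `ShuffledDataset` wrappers let folder datasets be paired element-wise and iterated in one fixed random order; `dataset_to_Xy` materializes a whole dataset as a single batch. A hedged usage sketch; the folder paths, the `.npz` subclass, and the pairing transform are made up for illustration and are not part of the library:

import numpy as np
from ring.utils import dataloader_torch as dlt


class FolderOfNpzFilesDataset(dlt.FolderOfFilesDataset):
    # Illustrative subclass: reuse the folder/indexing logic but read .npz
    # files via the new `_load_file` hook instead of pickles.
    @staticmethod
    def _load_file(file_path: str):
        return dict(np.load(file_path))


ds_imu = dlt.FolderOfFilesDataset("data/imu")      # hypothetical folder of pickle files
ds_mocap = dlt.FolderOfFilesDataset("data/mocap")  # hypothetical folder of pickle files

# Pair the two folders element-wise, then fix one random sample order up front
# (useful when the DataLoader itself runs with shuffle=False).
paired = dlt.MultiDataset([ds_imu, ds_mocap], transform=lambda imu, mocap: (imu, mocap))
shuffled = dlt.ShuffledDataset(paired)

# Stream mini-batches ...
generator = dlt.dataset_to_generator(shuffled, batch_size=32, shuffle=False, num_workers=0)
batch = generator(None)  # called with None, just as dataset_to_Xy does above

# ... or load the entire dataset as one batch.
Xy = dlt.dataset_to_Xy(shuffled)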