imt-ring 1.6.32__tar.gz → 1.6.34__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (119) hide show
  1. {imt_ring-1.6.32 → imt_ring-1.6.34}/PKG-INFO +1 -1
  2. {imt_ring-1.6.32 → imt_ring-1.6.34}/pyproject.toml +1 -1
  3. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/generator/base.py +2 -0
  5. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/generator/batch.py +6 -1
  6. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/ml/train.py +5 -4
  7. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/dataloader_torch.py +65 -3
  8. {imt_ring-1.6.32 → imt_ring-1.6.34}/readme.md +0 -0
  9. {imt_ring-1.6.32 → imt_ring-1.6.34}/setup.cfg +0 -0
  10. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/imt_ring.egg-info/SOURCES.txt +0 -0
  11. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  12. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/imt_ring.egg-info/requires.txt +0 -0
  13. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/imt_ring.egg-info/top_level.txt +0 -0
  14. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/__init__.py +0 -0
  15. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algebra.py +0 -0
  16. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/__init__.py +0 -0
  17. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/_random.py +0 -0
  18. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  19. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  20. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  21. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/custom_joints/rsaddle_joint.py +0 -0
  22. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  23. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/dynamics.py +0 -0
  24. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/generator/__init__.py +0 -0
  25. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/generator/finalize_fns.py +0 -0
  26. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  27. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/generator/pd_control.py +0 -0
  28. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/generator/setup_fns.py +0 -0
  29. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/generator/types.py +0 -0
  30. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/jcalc.py +0 -0
  31. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/kinematics.py +0 -0
  32. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/algorithms/sensors.py +0 -0
  33. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/base.py +0 -0
  34. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/__init__.py +0 -0
  35. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/branched.xml +0 -0
  36. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  37. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  38. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  39. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/inv_pendulum.xml +0 -0
  40. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  41. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/spherical_stiff.xml +0 -0
  42. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/symmetric.xml +0 -0
  43. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/test_all_1.xml +0 -0
  44. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/test_all_2.xml +0 -0
  45. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  46. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/test_control.xml +0 -0
  47. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  48. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/test_free.xml +0 -0
  49. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/test_kinematics.xml +0 -0
  50. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  51. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  52. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/test_randomize_position.xml +0 -0
  53. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/test_sensors.xml +0 -0
  54. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  55. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/examples.py +0 -0
  56. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/test_examples.py +0 -0
  57. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/xml/__init__.py +0 -0
  58. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/xml/abstract.py +0 -0
  59. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/xml/from_xml.py +0 -0
  60. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/xml/test_from_xml.py +0 -0
  61. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/xml/test_to_xml.py +0 -0
  62. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/io/xml/to_xml.py +0 -0
  63. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/maths.py +0 -0
  64. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/ml/__init__.py +0 -0
  65. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/ml/base.py +0 -0
  66. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/ml/callbacks.py +0 -0
  67. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/ml/ml_utils.py +0 -0
  68. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/ml/optimizer.py +0 -0
  69. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  70. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  71. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/ml/ringnet.py +0 -0
  72. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/ml/rnno_v1.py +0 -0
  73. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/ml/training_loop.py +0 -0
  74. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/rendering/__init__.py +0 -0
  75. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/rendering/base_render.py +0 -0
  76. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/rendering/mujoco_render.py +0 -0
  77. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/rendering/vispy_render.py +0 -0
  78. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/rendering/vispy_visuals.py +0 -0
  79. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/sim2real/__init__.py +0 -0
  80. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/sim2real/sim2real.py +0 -0
  81. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/spatial.py +0 -0
  82. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/sys_composer/__init__.py +0 -0
  83. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/sys_composer/delete_sys.py +0 -0
  84. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/sys_composer/inject_sys.py +0 -0
  85. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/sys_composer/morph_sys.py +0 -0
  86. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/__init__.py +0 -0
  87. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/backend.py +0 -0
  88. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/batchsize.py +0 -0
  89. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/colab.py +0 -0
  90. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/dataloader.py +0 -0
  91. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/hdf5.py +0 -0
  92. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/normalizer.py +0 -0
  93. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/path.py +0 -0
  94. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/randomize_sys.py +0 -0
  95. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/register_gym_envs/__init__.py +0 -0
  96. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/register_gym_envs/saddle.py +0 -0
  97. {imt_ring-1.6.32 → imt_ring-1.6.34}/src/ring/utils/utils.py +0 -0
  98. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_algebra.py +0 -0
  99. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_base.py +0 -0
  100. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_custom_joints.py +0 -0
  101. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_dynamics.py +0 -0
  102. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_generator.py +0 -0
  103. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_jcalc.py +0 -0
  104. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_jit.py +0 -0
  105. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_kinematics.py +0 -0
  106. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_maths.py +0 -0
  107. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_ml_utils.py +0 -0
  108. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_motion_artifacts.py +0 -0
  109. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_pd_control.py +0 -0
  110. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_quickstart_example.py +0 -0
  111. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_random.py +0 -0
  112. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_randomize.py +0 -0
  113. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_rcmg.py +0 -0
  114. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_render.py +0 -0
  115. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_sensors.py +0 -0
  116. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_sim2real.py +0 -0
  117. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_sys_composer.py +0 -0
  118. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_train.py +0 -0
  119. {imt_ring-1.6.32 → imt_ring-1.6.34}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.32
3
+ Version: 1.6.34
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.6.32"
7
+ version = "1.6.34"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.32
3
+ Version: 1.6.34
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -213,6 +213,8 @@ class RCMG:
213
213
  )
214
214
  save_fn(d, file)
215
215
  i += 1
216
+ # cleanup
217
+ del data
216
218
 
217
219
  gens, n_calls = self._generators_ncalls(sizes)
218
220
  batch.generators_eager(gens, n_calls, callback, seed, self._disable_tqdm)
@@ -1,3 +1,4 @@
1
+ import gc
1
2
  from typing import Callable
2
3
 
3
4
  import jax
@@ -83,4 +84,8 @@ def generators_eager(
83
84
 
84
85
  sample_flat, _ = jax.tree_util.tree_flatten(sample)
85
86
  size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
86
- callback([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
87
+ callback([jax.tree_map(lambda a: a[i].copy(), sample) for i in range(size)])
88
+
89
+ # cleanup
90
+ del sample, sample_flat
91
+ gc.collect()
@@ -39,6 +39,7 @@ def _build_step_fn(
39
39
  filter: ml_base.AbstractFilter,
40
40
  optimizer,
41
41
  tbp,
42
+ skip_first_tbp_batch,
42
43
  ):
43
44
  """Build step function that optimizes filter parameters based on `metric_fn`.
44
45
  `initial_state` has shape (pmap, vmap, state_dim)"""
@@ -89,6 +90,8 @@ def _build_step_fn(
89
90
  ):
90
91
  (loss, state), grads = pmapped_loss_fn(params, state, X_tbp, y_tbp)
91
92
  debug_grads.append(grads)
93
+ if skip_first_tbp_batch and i == 0:
94
+ continue
92
95
  state = jax.lax.stop_gradient(state)
93
96
  params, opt_state = apply_grads(grads, params, opt_state)
94
97
 
@@ -119,6 +122,7 @@ def train_fn(
119
122
  loss_fn: LOSS_FN = _default_loss_fn,
120
123
  metrices: Optional[METRICES] = _default_metrices,
121
124
  link_names: Optional[list[str]] = None,
125
+ skip_first_tbp_batch: bool = False,
122
126
  ) -> bool:
123
127
  """Trains RNNO
124
128
 
@@ -161,10 +165,7 @@ def train_fn(
161
165
  opt_state = optimizer.init(filter_params)
162
166
 
163
167
  step_fn = _build_step_fn(
164
- loss_fn,
165
- filter,
166
- optimizer,
167
- tbp=tbp,
168
+ loss_fn, filter, optimizer, tbp=tbp, skip_first_tbp_batch=skip_first_tbp_batch
168
169
  )
169
170
 
170
171
  # always log, because we also want `i_epsiode` to be logged in wandb
@@ -1,8 +1,9 @@
1
1
  import os
2
- from typing import Optional
2
+ from typing import Any, Optional
3
3
  import warnings
4
4
 
5
5
  import jax
6
+ import numpy as np
6
7
  import torch
7
8
  from torch.utils.data import DataLoader
8
9
  from torch.utils.data import Dataset
@@ -12,7 +13,7 @@ from ring.utils import parse_path
12
13
  from ring.utils import pickle_load
13
14
 
14
15
 
15
- class FolderOfPickleFilesDataset(Dataset):
16
+ class FolderOfFilesDataset(Dataset):
16
17
  def __init__(self, path, transform=None):
17
18
  self.files = self.listdir(path)
18
19
  self.transform = transform
@@ -22,7 +23,7 @@ class FolderOfPickleFilesDataset(Dataset):
22
23
  return self.N
23
24
 
24
25
  def __getitem__(self, idx: int):
25
- element = pickle_load(self.files[idx])
26
+ element = self._load_file(self.files[idx])
26
27
  if self.transform is not None:
27
28
  element = self.transform(element)
28
29
  return element
@@ -31,6 +32,10 @@ class FolderOfPickleFilesDataset(Dataset):
31
32
  def listdir(path: str) -> list:
32
33
  return [parse_path(path, file) for file in os.listdir(path)]
33
34
 
35
+ @staticmethod
36
+ def _load_file(file_path: str) -> Any:
37
+ return pickle_load(file_path)
38
+
34
39
 
35
40
  def dataset_to_generator(
36
41
  dataset: Dataset,
@@ -84,3 +89,60 @@ def _get_number_of_logical_cores() -> int:
84
89
  )
85
90
  N = 0
86
91
  return N
92
+
93
+
94
+ class MultiDataset(Dataset):
95
+ def __init__(self, datasets, transform=None):
96
+ """
97
+ Args:
98
+ datasets: A list of datasets to sample from.
99
+ transform: A function that takes N items (one from each dataset) and combines them.
100
+ """ # noqa: E501
101
+ self.datasets = datasets
102
+ self.transform = transform
103
+
104
+ def __len__(self):
105
+ # Length is defined by the smallest dataset in the list
106
+ return min(len(ds) for ds in self.datasets)
107
+
108
+ def __getitem__(self, idx):
109
+ sampled_items = [ds[idx] for ds in self.datasets]
110
+
111
+ if self.transform:
112
+ # Apply the transformation to all sampled items
113
+ return self.transform(*sampled_items)
114
+
115
+ return tuple(sampled_items)
116
+
117
+
118
+ class ShuffledDataset(Dataset):
119
+ def __init__(self, dataset):
120
+ """
121
+ Wrapper that shuffles the dataset indices once.
122
+
123
+ Args:
124
+ dataset (Dataset): The original dataset to shuffle.
125
+ """
126
+ self.dataset = dataset
127
+ self.shuffled_indices = np.random.permutation(
128
+ len(dataset)
129
+ ) # Shuffle indices once
130
+
131
+ def __len__(self):
132
+ return len(self.dataset)
133
+
134
+ def __getitem__(self, idx):
135
+ """
136
+ Returns the data at the shuffled index.
137
+
138
+ Args:
139
+ idx (int): Index in the shuffled dataset.
140
+ """
141
+ original_idx = self.shuffled_indices[idx]
142
+ return self.dataset[original_idx]
143
+
144
+
145
+ def dataset_to_Xy(ds: Dataset):
146
+ return dataset_to_generator(ds, batch_size=len(ds), shuffle=False, num_workers=0)(
147
+ None
148
+ )
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes