imt-ring 1.6.8__py3-none-any.whl → 1.6.10__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.8
3
+ Version: 1.6.10
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,6 +1,6 @@
1
1
  ring/__init__.py,sha256=k7tL-XgggUwWxHCXyv60rQn-OcXHPg82QcIUkKLEd-c,5057
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=BGAJE3PSOUnTHte4UesJc1J7MQraIEiVpStkhrgXhaI,33245
3
+ ring/base.py,sha256=MkkziQx01sdMOpB8MFUDFgFlZUrXCFjpb8hS9yKHUyM,33751
4
4
  ring/maths.py,sha256=qPHH6TpHCK3TgExI98gNEySoSRKOwteN9McUlyUFipI,12207
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
@@ -14,8 +14,8 @@ ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwa
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
16
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
17
- ring/algorithms/generator/base.py,sha256=YKUr9UnjFbWAYjP8V1j0FaKGVpK92W6E1iuoet7qBNg,14522
18
- ring/algorithms/generator/batch.py,sha256=ylootnXmj-JyuB_f5OCknHst9wFKO3gkjQbMrFNXY2g,2513
17
+ ring/algorithms/generator/base.py,sha256=JHCTbHtmYEdmsHyQnJN9vMP6rzrlBoqzTZ27c3zhCDI,15655
18
+ ring/algorithms/generator/batch.py,sha256=9yFxVv11hij-fJXGPxA3zEh1bE2_jrZk0R7kyGaiM5c,2551
19
19
  ring/algorithms/generator/finalize_fns.py,sha256=LUw1Wc2YrmMRRh4RF704ob3bZOXktAZAbbLoBm_p1yw,9131
20
20
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
21
21
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
@@ -83,7 +83,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
83
83
  ring/utils/utils.py,sha256=oGC7kh19s5zvmnUvWy8B3fBl9loVU58ppz91osk2m3w,6550
84
84
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
85
85
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
86
- imt_ring-1.6.8.dist-info/METADATA,sha256=FxO8BZbMRGagmYW1cQCHPuUIuiXxsIr8F-zlBkBVHNQ,3820
87
- imt_ring-1.6.8.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
88
- imt_ring-1.6.8.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
- imt_ring-1.6.8.dist-info/RECORD,,
86
+ imt_ring-1.6.10.dist-info/METADATA,sha256=j5LtBbakAQBMu_XP_642TYIaq0PFkG0S9h49nrnReoc,3821
87
+ imt_ring-1.6.10.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
88
+ imt_ring-1.6.10.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
+ imt_ring-1.6.10.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (72.2.0)
2
+ Generator: setuptools (73.0.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -6,6 +6,7 @@ import jax
6
6
  import jax.numpy as jnp
7
7
  import numpy as np
8
8
  import tree_utils
9
+ from tree_utils import PyTree
9
10
 
10
11
  from ring import base
11
12
  from ring import utils
@@ -143,9 +144,7 @@ class RCMG:
143
144
 
144
145
  return n_calls
145
146
 
146
- def to_list(
147
- self, sizes: int | list[int] = 1, seed: int = 1
148
- ) -> list[tree_utils.PyTree[np.ndarray]]:
147
+ def _generators_ncalls(self, sizes: int | list[int] = 1):
149
148
  "Returns list of unbatched sequences as numpy arrays."
150
149
  repeats = self._compute_repeats(sizes)
151
150
  sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
@@ -165,7 +164,45 @@ class RCMG:
165
164
  batch.generators_lazy([self.gens[i]], [reduced_repeats[i]], jits[i])
166
165
  )
167
166
 
168
- return batch.generators_eager_to_list(gens, n_calls, seed, self._disable_tqdm)
167
+ return gens, n_calls
168
+
169
+ def to_list(
170
+ self, sizes: int | list[int] = 1, seed: int = 1
171
+ ) -> list[tree_utils.PyTree[np.ndarray]]:
172
+ "Returns list of unbatched sequences as numpy arrays."
173
+ gens, n_calls = self._generators_ncalls(sizes)
174
+
175
+ data = []
176
+ batch.generators_eager(
177
+ gens, n_calls, lambda d: data.extend(d), seed, self._disable_tqdm
178
+ )
179
+ return data
180
+
181
+ def to_folder(
182
+ self,
183
+ path: str,
184
+ sizes: int | list[int] = 1,
185
+ seed: int = 1,
186
+ overwrite: bool = True,
187
+ file_prefix: str = "seq",
188
+ save_fn: Callable[[PyTree[np.ndarray], str], None] = utils.pickle_save,
189
+ verbose: bool = True,
190
+ ):
191
+
192
+ i = 0
193
+
194
+ def callback(data: list[PyTree[np.ndarray]]) -> None:
195
+ nonlocal i
196
+ data = utils.replace_elements_w_nans(data, verbose=verbose)
197
+ for d in data:
198
+ file = utils.parse_path(
199
+ path, file_prefix + str(i), file_exists_ok=overwrite
200
+ )
201
+ save_fn(d, file)
202
+ i += 1
203
+
204
+ gens, n_calls = self._generators_ncalls(sizes)
205
+ batch.generators_eager(gens, n_calls, callback, seed, self._disable_tqdm)
169
206
 
170
207
  def to_pickle(
171
208
  self,
@@ -1,8 +1,10 @@
1
+ from typing import Callable
2
+
1
3
  import jax
2
4
  import jax.numpy as jnp
3
5
  import numpy as np
4
6
  from tqdm import tqdm
5
- import tree_utils
7
+ from tree_utils import PyTree
6
8
 
7
9
  from ring import utils
8
10
  from ring.algorithms.generator import types
@@ -50,15 +52,15 @@ def generators_lazy(
50
52
  return generator
51
53
 
52
54
 
53
- def generators_eager_to_list(
55
+ def generators_eager(
54
56
  generators: list[types.BatchedGenerator],
55
57
  n_calls: list[int],
58
+ callback: Callable[[list[PyTree[np.ndarray]]], None],
56
59
  seed: int = 1,
57
60
  disable_tqdm: bool = False,
58
- ) -> list[tree_utils.PyTree]:
61
+ ) -> None:
59
62
 
60
63
  key = jax.random.PRNGKey(seed)
61
- data = []
62
64
  for gen, n_call in tqdm(
63
65
  zip(generators, n_calls),
64
66
  desc="executing generators",
@@ -81,6 +83,4 @@ def generators_eager_to_list(
81
83
 
82
84
  sample_flat, _ = jax.tree_util.tree_flatten(sample)
83
85
  size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
84
- data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
85
-
86
- return data
86
+ callback([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
ring/base.py CHANGED
@@ -590,12 +590,14 @@ class System(_Base):
590
590
 
591
591
  return sys
592
592
 
593
- def findall_imus(self) -> list[str]:
594
- return [name for name in self.link_names if name[:3] == "imu"]
593
+ def findall_imus(self, names: bool = True) -> list[str] | list[int]:
594
+ bodies = [name for name in self.link_names if name[:3] == "imu"]
595
+ return bodies if names else [self.name_to_idx(n) for n in bodies]
595
596
 
596
- def findall_segments(self) -> list[str]:
597
- imus = self.findall_imus()
598
- return [name for name in self.link_names if name not in imus]
597
+ def findall_segments(self, names: bool = True) -> list[str] | list[int]:
598
+ imus = self.findall_imus(names=True)
599
+ bodies = [name for name in self.link_names if name not in imus]
600
+ return bodies if names else [self.name_to_idx(n) for n in bodies]
599
601
 
600
602
  def _bodies_indices_to_bodies_name(self, bodies: list[int]) -> list[str]:
601
603
  return [self.idx_to_name(i) for i in bodies]
@@ -615,6 +617,11 @@ class System(_Base):
615
617
  bodies = [i for i, _typ in enumerate(self.link_types) if _typ == typ]
616
618
  return self._bodies_indices_to_bodies_name(bodies) if names else bodies
617
619
 
620
+ def children(self, name: str, names: bool = False) -> list[int] | list[str]:
621
+ p = self.name_to_idx(name)
622
+ bodies = [i for i in range(self.num_links()) if self.link_parents[i] == p]
623
+ return bodies if (not names) else [self.idx_to_name(i) for i in bodies]
624
+
618
625
  def scan(self, f: Callable, in_types: str, *args, reverse: bool = False):
619
626
  """Scan `f` along each link in system whilst carrying along state.
620
627