imt-ring 1.6.9__py3-none-any.whl → 1.6.11__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.9
3
+ Version: 1.6.11
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -14,8 +14,8 @@ ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwa
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
16
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
17
- ring/algorithms/generator/base.py,sha256=YKUr9UnjFbWAYjP8V1j0FaKGVpK92W6E1iuoet7qBNg,14522
18
- ring/algorithms/generator/batch.py,sha256=ylootnXmj-JyuB_f5OCknHst9wFKO3gkjQbMrFNXY2g,2513
17
+ ring/algorithms/generator/base.py,sha256=LRAKxzrwq6fp4lgVw6IUg4i7isx3iqJLHvpFK1aTRcg,15732
18
+ ring/algorithms/generator/batch.py,sha256=9yFxVv11hij-fJXGPxA3zEh1bE2_jrZk0R7kyGaiM5c,2551
19
19
  ring/algorithms/generator/finalize_fns.py,sha256=LUw1Wc2YrmMRRh4RF704ob3bZOXktAZAbbLoBm_p1yw,9131
20
20
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
21
21
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
@@ -50,7 +50,7 @@ ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
50
50
  ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
51
51
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
52
52
  ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
53
- ring/ml/__init__.py,sha256=8SZTCs9rJ1kzR0Psh7lUzFhIMhKRPIK41mVfxJAGyMo,1471
53
+ ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
54
54
  ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
55
55
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
56
  ring/ml/ml_utils.py,sha256=GooyH5uxA6cJM7ZcWDUfSkSKq6dg7kCIbhkbjJs_rLw,6674
@@ -83,7 +83,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
83
83
  ring/utils/utils.py,sha256=oGC7kh19s5zvmnUvWy8B3fBl9loVU58ppz91osk2m3w,6550
84
84
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
85
85
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
86
- imt_ring-1.6.9.dist-info/METADATA,sha256=kFg-Ht8PsdnYA5lvEQ-KgMxlQSqiX_PtP8M8Q0vjVag,3820
87
- imt_ring-1.6.9.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
88
- imt_ring-1.6.9.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
- imt_ring-1.6.9.dist-info/RECORD,,
86
+ imt_ring-1.6.11.dist-info/METADATA,sha256=kkQfOD5LOSzB4lR7LvkHeck6fB_KPNrSKIsvPizJAKI,3821
87
+ imt_ring-1.6.11.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
88
+ imt_ring-1.6.11.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
+ imt_ring-1.6.11.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (73.0.1)
2
+ Generator: setuptools (74.1.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,3 +1,4 @@
1
+ from functools import partial
1
2
  import random
2
3
  from typing import Callable, Optional
3
4
  import warnings
@@ -6,6 +7,7 @@ import jax
6
7
  import jax.numpy as jnp
7
8
  import numpy as np
8
9
  import tree_utils
10
+ from tree_utils import PyTree
9
11
 
10
12
  from ring import base
11
13
  from ring import utils
@@ -143,9 +145,7 @@ class RCMG:
143
145
 
144
146
  return n_calls
145
147
 
146
- def to_list(
147
- self, sizes: int | list[int] = 1, seed: int = 1
148
- ) -> list[tree_utils.PyTree[np.ndarray]]:
148
+ def _generators_ncalls(self, sizes: int | list[int] = 1):
149
149
  "Returns list of unbatched sequences as numpy arrays."
150
150
  repeats = self._compute_repeats(sizes)
151
151
  sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
@@ -165,7 +165,47 @@ class RCMG:
165
165
  batch.generators_lazy([self.gens[i]], [reduced_repeats[i]], jits[i])
166
166
  )
167
167
 
168
- return batch.generators_eager_to_list(gens, n_calls, seed, self._disable_tqdm)
168
+ return gens, n_calls
169
+
170
+ def to_list(
171
+ self, sizes: int | list[int] = 1, seed: int = 1
172
+ ) -> list[tree_utils.PyTree[np.ndarray]]:
173
+ "Returns list of unbatched sequences as numpy arrays."
174
+ gens, n_calls = self._generators_ncalls(sizes)
175
+
176
+ data = []
177
+ batch.generators_eager(
178
+ gens, n_calls, lambda d: data.extend(d), seed, self._disable_tqdm
179
+ )
180
+ return data
181
+
182
+ def to_folder(
183
+ self,
184
+ path: str,
185
+ sizes: int | list[int] = 1,
186
+ seed: int = 1,
187
+ overwrite: bool = True,
188
+ file_prefix: str = "seq",
189
+ save_fn: Callable[[PyTree[np.ndarray], str], None] = partial(
190
+ utils.pickle_save, overwrite=True
191
+ ),
192
+ verbose: bool = True,
193
+ ):
194
+
195
+ i = 0
196
+
197
+ def callback(data: list[PyTree[np.ndarray]]) -> None:
198
+ nonlocal i
199
+ data = utils.replace_elements_w_nans(data, verbose=verbose)
200
+ for d in data:
201
+ file = utils.parse_path(
202
+ path, file_prefix + str(i), file_exists_ok=overwrite
203
+ )
204
+ save_fn(d, file)
205
+ i += 1
206
+
207
+ gens, n_calls = self._generators_ncalls(sizes)
208
+ batch.generators_eager(gens, n_calls, callback, seed, self._disable_tqdm)
169
209
 
170
210
  def to_pickle(
171
211
  self,
@@ -1,8 +1,10 @@
1
+ from typing import Callable
2
+
1
3
  import jax
2
4
  import jax.numpy as jnp
3
5
  import numpy as np
4
6
  from tqdm import tqdm
5
- import tree_utils
7
+ from tree_utils import PyTree
6
8
 
7
9
  from ring import utils
8
10
  from ring.algorithms.generator import types
@@ -50,15 +52,15 @@ def generators_lazy(
50
52
  return generator
51
53
 
52
54
 
53
- def generators_eager_to_list(
55
+ def generators_eager(
54
56
  generators: list[types.BatchedGenerator],
55
57
  n_calls: list[int],
58
+ callback: Callable[[list[PyTree[np.ndarray]]], None],
56
59
  seed: int = 1,
57
60
  disable_tqdm: bool = False,
58
- ) -> list[tree_utils.PyTree]:
61
+ ) -> None:
59
62
 
60
63
  key = jax.random.PRNGKey(seed)
61
- data = []
62
64
  for gen, n_call in tqdm(
63
65
  zip(generators, n_calls),
64
66
  desc="executing generators",
@@ -81,6 +83,4 @@ def generators_eager_to_list(
81
83
 
82
84
  sample_flat, _ = jax.tree_util.tree_flatten(sample)
83
85
  size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
84
- data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
85
-
86
- return data
86
+ callback([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
ring/ml/__init__.py CHANGED
@@ -23,6 +23,7 @@ def RNNO(
23
23
  eval: bool = True,
24
24
  samp_freq: float | None = None,
25
25
  v1: bool = False,
26
+ scale_X: bool = True,
26
27
  **kwargs,
27
28
  ):
28
29
  assert "message_dim" not in kwargs
@@ -47,7 +48,8 @@ def RNNO(
47
48
  **kwargs,
48
49
  )
49
50
  ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
50
- ringnet = base.ScaleX_FilterWrapper(ringnet)
51
+ if scale_X:
52
+ ringnet = base.ScaleX_FilterWrapper(ringnet)
51
53
  if eval and return_quats:
52
54
  ringnet = base.LPF_FilterWrapper(ringnet, _LPF_CUTOFF_FREQ, samp_freq=samp_freq)
53
55
  if return_quats: