imt-ring 1.6.9__py3-none-any.whl → 1.6.10__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.9
3
+ Version: 1.6.10
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -14,8 +14,8 @@ ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwa
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
16
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
17
- ring/algorithms/generator/base.py,sha256=YKUr9UnjFbWAYjP8V1j0FaKGVpK92W6E1iuoet7qBNg,14522
18
- ring/algorithms/generator/batch.py,sha256=ylootnXmj-JyuB_f5OCknHst9wFKO3gkjQbMrFNXY2g,2513
17
+ ring/algorithms/generator/base.py,sha256=JHCTbHtmYEdmsHyQnJN9vMP6rzrlBoqzTZ27c3zhCDI,15655
18
+ ring/algorithms/generator/batch.py,sha256=9yFxVv11hij-fJXGPxA3zEh1bE2_jrZk0R7kyGaiM5c,2551
19
19
  ring/algorithms/generator/finalize_fns.py,sha256=LUw1Wc2YrmMRRh4RF704ob3bZOXktAZAbbLoBm_p1yw,9131
20
20
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
21
21
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
@@ -83,7 +83,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
83
83
  ring/utils/utils.py,sha256=oGC7kh19s5zvmnUvWy8B3fBl9loVU58ppz91osk2m3w,6550
84
84
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
85
85
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
86
- imt_ring-1.6.9.dist-info/METADATA,sha256=kFg-Ht8PsdnYA5lvEQ-KgMxlQSqiX_PtP8M8Q0vjVag,3820
87
- imt_ring-1.6.9.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
88
- imt_ring-1.6.9.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
- imt_ring-1.6.9.dist-info/RECORD,,
86
+ imt_ring-1.6.10.dist-info/METADATA,sha256=j5LtBbakAQBMu_XP_642TYIaq0PFkG0S9h49nrnReoc,3821
87
+ imt_ring-1.6.10.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
88
+ imt_ring-1.6.10.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
+ imt_ring-1.6.10.dist-info/RECORD,,
@@ -6,6 +6,7 @@ import jax
6
6
  import jax.numpy as jnp
7
7
  import numpy as np
8
8
  import tree_utils
9
+ from tree_utils import PyTree
9
10
 
10
11
  from ring import base
11
12
  from ring import utils
@@ -143,9 +144,7 @@ class RCMG:
143
144
 
144
145
  return n_calls
145
146
 
146
- def to_list(
147
- self, sizes: int | list[int] = 1, seed: int = 1
148
- ) -> list[tree_utils.PyTree[np.ndarray]]:
147
+ def _generators_ncalls(self, sizes: int | list[int] = 1):
149
148
  "Returns list of unbatched sequences as numpy arrays."
150
149
  repeats = self._compute_repeats(sizes)
151
150
  sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
@@ -165,7 +164,45 @@ class RCMG:
165
164
  batch.generators_lazy([self.gens[i]], [reduced_repeats[i]], jits[i])
166
165
  )
167
166
 
168
- return batch.generators_eager_to_list(gens, n_calls, seed, self._disable_tqdm)
167
+ return gens, n_calls
168
+
169
+ def to_list(
170
+ self, sizes: int | list[int] = 1, seed: int = 1
171
+ ) -> list[tree_utils.PyTree[np.ndarray]]:
172
+ "Returns list of unbatched sequences as numpy arrays."
173
+ gens, n_calls = self._generators_ncalls(sizes)
174
+
175
+ data = []
176
+ batch.generators_eager(
177
+ gens, n_calls, lambda d: data.extend(d), seed, self._disable_tqdm
178
+ )
179
+ return data
180
+
181
+ def to_folder(
182
+ self,
183
+ path: str,
184
+ sizes: int | list[int] = 1,
185
+ seed: int = 1,
186
+ overwrite: bool = True,
187
+ file_prefix: str = "seq",
188
+ save_fn: Callable[[PyTree[np.ndarray], str], None] = utils.pickle_save,
189
+ verbose: bool = True,
190
+ ):
191
+
192
+ i = 0
193
+
194
+ def callback(data: list[PyTree[np.ndarray]]) -> None:
195
+ nonlocal i
196
+ data = utils.replace_elements_w_nans(data, verbose=verbose)
197
+ for d in data:
198
+ file = utils.parse_path(
199
+ path, file_prefix + str(i), file_exists_ok=overwrite
200
+ )
201
+ save_fn(d, file)
202
+ i += 1
203
+
204
+ gens, n_calls = self._generators_ncalls(sizes)
205
+ batch.generators_eager(gens, n_calls, callback, seed, self._disable_tqdm)
169
206
 
170
207
  def to_pickle(
171
208
  self,
@@ -1,8 +1,10 @@
1
+ from typing import Callable
2
+
1
3
  import jax
2
4
  import jax.numpy as jnp
3
5
  import numpy as np
4
6
  from tqdm import tqdm
5
- import tree_utils
7
+ from tree_utils import PyTree
6
8
 
7
9
  from ring import utils
8
10
  from ring.algorithms.generator import types
@@ -50,15 +52,15 @@ def generators_lazy(
50
52
  return generator
51
53
 
52
54
 
53
- def generators_eager_to_list(
55
+ def generators_eager(
54
56
  generators: list[types.BatchedGenerator],
55
57
  n_calls: list[int],
58
+ callback: Callable[[list[PyTree[np.ndarray]]], None],
56
59
  seed: int = 1,
57
60
  disable_tqdm: bool = False,
58
- ) -> list[tree_utils.PyTree]:
61
+ ) -> None:
59
62
 
60
63
  key = jax.random.PRNGKey(seed)
61
- data = []
62
64
  for gen, n_call in tqdm(
63
65
  zip(generators, n_calls),
64
66
  desc="executing generators",
@@ -81,6 +83,4 @@ def generators_eager_to_list(
81
83
 
82
84
  sample_flat, _ = jax.tree_util.tree_flatten(sample)
83
85
  size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
84
- data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
85
-
86
- return data
86
+ callback([jax.tree_map(lambda a: a[i], sample) for i in range(size)])