imt-ring 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.2.2
3
+ Version: 1.3.1
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,12 +1,12 @@
1
1
  ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=lG1SAVi6VpJT20Xvdhv_NrObMb4008leEqPEaQ0anR8,33566
3
+ ring/base.py,sha256=gqdXejZ4E4liB5mZ6gPof3EDYTThlfro2MQs0bc5eOM,33530
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
7
7
  ring/algorithms/_random.py,sha256=6EG0GHYe6tCq0qUt4Jes8W1EaqqaLa0sSZhnwBbEjCE,13340
8
8
  ring/algorithms/dynamics.py,sha256=nqq5I0RYSbHNlGiLMlohz08IfL9Njsrid4upDnwkGbI,10629
9
- ring/algorithms/jcalc.py,sha256=6bO-_zFbHGDG5oq0t-HZvOSiYSWmZz_k6Z6VlNmSThA,25270
9
+ ring/algorithms/jcalc.py,sha256=oqSiwz3Be1VfIpmJXEFTNM_9_o3tyuTtyZt2aqttyN4,28213
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=Y3Wo9qj3BWKoIHB0V04QwyD-Z5m4BrAjfBX8Pn6y9Lg,18005
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=33WBnaBJMtq3vVcpMm7zmyeMrLY9PyV_8-wk5oSF65g,227
@@ -14,8 +14,8 @@ ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSe
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/suntay.py,sha256=d0Z54tIXiepMixE40W5H8JKxrT5U6VskPm2L2kKnQPw,13680
16
16
  ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
17
- ring/algorithms/generator/base.py,sha256=QDmzMAgtaK5M9WDl39qjXYfBa99d83vCPWEkYYmsplk,14952
18
- ring/algorithms/generator/batch.py,sha256=MZurZmQDH1vncoNbCspVNGNlfP0R87J6_HC7MMIqQ6A,8478
17
+ ring/algorithms/generator/base.py,sha256=zmrRK_I6BWoo4WbEcEVK7iFKdPfetc6txs7U8iu1xEk,14771
18
+ ring/algorithms/generator/batch.py,sha256=BGzmwH1AItXjPRyHtsYnAfYnoogw8jxhng9oyVw72lw,9019
19
19
  ring/algorithms/generator/motion_artifacts.py,sha256=aKdkZU5OF4_aKyL4Yo-ftZRwrDCve1LuuREGAUlTqtI,8551
20
20
  ring/algorithms/generator/pd_control.py,sha256=3pOaYig26vmp8gippDfy2KNJRZO3kr0rGd_PBIuEROM,5759
21
21
  ring/algorithms/generator/randomize.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
@@ -40,8 +40,8 @@ ring/io/examples/test_randomize_position.xml,sha256=h6Yo5tT8QJBOJEwhE0cpppF-rvbO
40
40
  ring/io/examples/test_sensors.xml,sha256=urI_19gzwpDyWjtnse1Iy7CWCB0ezbsDP7FMLFlNw_4,494
41
41
  ring/io/examples/test_three_seg_seg2.xml,sha256=g85tx4V6PahkSbYXKWqk5vqurLos15WB6tBm1dQ_V_o,1022
42
42
  ring/io/examples/exclude/knee_trans_dof.xml,sha256=4Cuv6c7Yqa4T-RirRbrJKTT_41vRTuDlLPRLE_NopjU,1379
43
- ring/io/examples/exclude/standard_sys.xml,sha256=zTn_TVOBVmp0rq-g3aOOpjxHt6lTvPOKlHdvvqEGm-Y,8967
44
- ring/io/examples/exclude/standard_sys_rr_imp.xml,sha256=6aR_eA8RGfAMi36xojn5KyXdSSwAnc9sEotB0ukaPQM,9015
43
+ ring/io/examples/exclude/standard_sys.xml,sha256=1QxyN01TaFPKdzyacomQeA_8elcB2XRQp5D4NTUYNJw,8967
44
+ ring/io/examples/exclude/standard_sys_rr_imp.xml,sha256=1K8aLe4n97qFMQaytXdHVHDfshDhoZf78QzEe-gLjiU,9015
45
45
  ring/io/examples/test_morph_system/four_seg_seg1.xml,sha256=XJvGtEnvedejs_OmCVfQULWJNK8MLDQQ3raqPNRCJbA,1283
46
46
  ring/io/examples/test_morph_system/four_seg_seg3.xml,sha256=HktN7_a_Ly3YflWit5W-WncxApWGMORAGnRXyMEqnoA,1265
47
47
  ring/io/xml/__init__.py,sha256=-3k6ffvFyc4zm0oTyVz3ez-o3Lb9bPp2sjwSub_K1AA,242
@@ -50,9 +50,9 @@ ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
50
50
  ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
51
51
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
52
52
  ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
53
- ring/ml/__init__.py,sha256=669ueX_IMTmhAE-4xCjjp90mTlj28u8voq5_5KE0ZY0,944
54
- ring/ml/base.py,sha256=5TpJtdfmlAv2j_f8yDW1U_wY4jZ2lA74pNR524JQTts,8905
55
- ring/ml/callbacks.py,sha256=yrh9YWdEATEJq3fi9lQR0OU3hFENPpPO2UD4cyTRlIk,13109
53
+ ring/ml/__init__.py,sha256=4eK8P-pjAe_TcURaXaHKykZ3IfTbmxnQyOaI-EGQzg4,1795
54
+ ring/ml/base.py,sha256=PQ72VasEqlecBZgWP5HE5rWYyLiLq7nCVLymXo9f0dw,8959
55
+ ring/ml/callbacks.py,sha256=DkSy5c7IRqAAks2dx8acEBExYxUv-xiUFwZn4odPYq4,13253
56
56
  ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
57
57
  ring/ml/optimizer.py,sha256=OP70P70YcX-2Z-cuoMluFk-L5Vhh_MmqiHdM9OZqyhI,4703
58
58
  ring/ml/ringnet.py,sha256=OWRDu2COmptzbpJWlRLbPIn_ioKZCAd_iu-eiY_aPjk,8521
@@ -60,7 +60,7 @@ ring/ml/train.py,sha256=ftt2MOSSNGCdL7ZoAXcbIgeHW1Wkpgp6XYyLIBUIClI,10872
60
60
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
61
61
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
62
62
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
63
- ring/rendering/base_render.py,sha256=c3NTRE0VjnWmcHqCalvfQhCwiPyoMRr_2eiU04Y-mzU,8764
63
+ ring/rendering/base_render.py,sha256=s5dF-GVBqjiWkqVuPQMtTLuM7EtA-YrB7RVWFfIaQ1I,8956
64
64
  ring/rendering/mujoco_render.py,sha256=aluzQJp3jrDdPfAyNmQuXIHRfgfBTCCZQqxKOx_0D2s,7770
65
65
  ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
66
66
  ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
@@ -70,14 +70,14 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
70
70
  ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
71
71
  ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
72
72
  ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
73
- ring/utils/__init__.py,sha256=6BWUMmbQ4E-Qwd-SNfRlpbzJ0UJ1DpEclstrgbLdDvk,773
73
+ ring/utils/__init__.py,sha256=rTvSA4RiJAVCY_A64FUMd8IJTv94LgoSA3Ps5X63_jA,799
74
74
  ring/utils/batchsize.py,sha256=mPFGD7AedFMycHtyIuZtNWCaAvKLLWSWaB7X6u54xvM,1358
75
75
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
76
76
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
77
77
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
78
78
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
79
- ring/utils/utils.py,sha256=AzOzR95oOyfdtJhjt5iIb35u611NlTb1Ds4QDKrGMOM,4967
80
- imt_ring-1.2.2.dist-info/METADATA,sha256=rGaRlA9bTJH-N8QLbO4tokVge6UMnx-2Pddy5QwMvXA,3104
81
- imt_ring-1.2.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
- imt_ring-1.2.2.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
- imt_ring-1.2.2.dist-info/RECORD,,
79
+ ring/utils/utils.py,sha256=I2f6-DMBrrgy5tpLzPLlezifQgkO2fERZWyX3cfb4sI,5303
80
+ imt_ring-1.3.1.dist-info/METADATA,sha256=sCl08586u_XLy0LUsEuIhyIUPxj3R3pzmXtXgFuRw1c,3104
81
+ imt_ring-1.3.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
+ imt_ring-1.3.1.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
+ imt_ring-1.3.1.dist-info/RECORD,,
@@ -4,6 +4,8 @@ import warnings
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+ import tree_utils
8
+
7
9
  from ring import base
8
10
  from ring import utils
9
11
  from ring.algorithms import jcalc
@@ -13,7 +15,6 @@ from ring.algorithms.generator import motion_artifacts
13
15
  from ring.algorithms.generator import randomize
14
16
  from ring.algorithms.generator import transforms
15
17
  from ring.algorithms.generator import types
16
- import tree_utils
17
18
 
18
19
 
19
20
  class RCMG:
@@ -108,23 +109,20 @@ class RCMG:
108
109
  partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
109
110
  )
110
111
 
111
- def _to_data(self, sizes, seed, jit):
112
- return batch.batch_generators_eager_to_list(
113
- self.gens, sizes, seed=seed, jit=jit
114
- )
112
+ def _to_data(self, sizes, seed):
113
+ return batch.batch_generators_eager_to_list(self.gens, sizes, seed=seed)
115
114
 
116
- def to_list(self, sizes: int | list[int] = 1, seed: int = 1, jit: bool = False):
117
- return self._to_data(sizes, seed, jit)
115
+ def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
116
+ return self._to_data(sizes, seed)
118
117
 
119
118
  def to_pickle(
120
119
  self,
121
120
  path: str,
122
121
  sizes: int | list[int] = 1,
123
122
  seed: int = 1,
124
- jit: bool = False,
125
123
  overwrite: bool = True,
126
124
  ) -> None:
127
- data = tree_utils.tree_batch(self._to_data(sizes, seed, jit))
125
+ data = tree_utils.tree_batch(self._to_data(sizes, seed))
128
126
  utils.pickle_save(data, path, overwrite=overwrite)
129
127
 
130
128
  def to_hdf5(
@@ -132,10 +130,9 @@ class RCMG:
132
130
  path: str,
133
131
  sizes: int | list[int] = 1,
134
132
  seed: int = 1,
135
- jit: bool = False,
136
133
  overwrite: bool = True,
137
134
  ) -> None:
138
- data = tree_utils.tree_batch(self._to_data(sizes, seed, jit))
135
+ data = tree_utils.tree_batch(self._to_data(sizes, seed))
139
136
  utils.hdf5_save(path, data, overwrite=overwrite)
140
137
 
141
138
  def to_eager_gen(
@@ -143,11 +140,8 @@ class RCMG:
143
140
  batchsize: int = 1,
144
141
  sizes: int | list[int] = 1,
145
142
  seed: int = 1,
146
- jit: bool = False,
147
143
  ) -> types.BatchedGenerator:
148
- return batch.batch_generators_eager(
149
- self.gens, sizes, batchsize, seed=seed, jit=jit
150
- )
144
+ return batch.batch_generators_eager(self.gens, sizes, batchsize, seed=seed)
151
145
 
152
146
  def to_lazy_gen(
153
147
  self, sizes: int | list[int] = 1, jit: bool = True
@@ -6,12 +6,13 @@ import warnings
6
6
  import jax
7
7
  import jax.numpy as jnp
8
8
  import numpy as np
9
- from ring import utils
10
- from ring.algorithms.generator import types
11
9
  from tqdm import tqdm
12
10
  import tree_utils
13
11
  from tree_utils import tree_batch
14
12
 
13
+ from ring import utils
14
+ from ring.algorithms.generator import types
15
+
15
16
 
16
17
  def _build_batch_matrix(batchsizes: list[int]) -> jax.Array:
17
18
  arr = []
@@ -61,11 +62,24 @@ def batch_generators_lazy(
61
62
  return generator
62
63
 
63
64
 
65
+ def _number_of_executions_required(size: int) -> int:
66
+ vmap_threshold = 128
67
+ _, vmap = utils.distribute_batchsize(size)
68
+
69
+ primes = iter(utils.primes(vmap))
70
+ n_calls = 1
71
+ while vmap > vmap_threshold:
72
+ prime = next(primes)
73
+ n_calls *= prime
74
+ vmap /= prime
75
+
76
+ return n_calls
77
+
78
+
64
79
  def batch_generators_eager_to_list(
65
80
  generators: types.Generator | list[types.Generator],
66
81
  sizes: int | list[int],
67
82
  seed: int = 1,
68
- jit: bool = True,
69
83
  ) -> list[tree_utils.PyTree]:
70
84
  "Returns list of unbatched sequences as numpy arrays."
71
85
  generators, sizes = _process_sizes_batchsizes_generators(generators, sizes)
@@ -73,11 +87,20 @@ def batch_generators_eager_to_list(
73
87
  key = jax.random.PRNGKey(seed)
74
88
  data = []
75
89
  for gen, size in tqdm(zip(generators, sizes), desc="eager data generation"):
76
- key, consume = jax.random.split(key)
77
- sample = batch_generators_lazy(gen, size, jit=jit)(consume)
78
- # converts also to numpy
79
- sample = jax.device_get(sample)
80
- data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
90
+
91
+ n_calls = _number_of_executions_required(size)
92
+ # decrease size by n_calls times
93
+ size = int(size / n_calls)
94
+ jit = True if n_calls > 1 else False
95
+ gen_jit = batch_generators_lazy(gen, size, jit=jit)
96
+
97
+ for _ in range(n_calls):
98
+ key, consume = jax.random.split(key)
99
+ sample = gen_jit(consume)
100
+ # converts also to numpy
101
+ sample = jax.device_get(sample)
102
+ data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
103
+
81
104
  return data
82
105
 
83
106
 
@@ -243,12 +266,11 @@ def batch_generators_eager(
243
266
  shuffle: bool = True,
244
267
  drop_last: bool = True,
245
268
  seed: int = 1,
246
- jit: bool = True,
247
269
  ) -> types.BatchedGenerator:
248
270
  """Eagerly create a large precomputed generator by calling multiple generators
249
271
  and stacking their output."""
250
272
 
251
- data = batch_generators_eager_to_list(generators, sizes, seed=seed, jit=jit)
273
+ data = batch_generators_eager_to_list(generators, sizes, seed=seed)
252
274
  return batched_generator_from_list(data, batchsize, shuffle, drop_last)
253
275
 
254
276
 
@@ -270,7 +292,7 @@ def _process_sizes_batchsizes_generators(
270
292
 
271
293
  assert len(generators) == len(list_sizes)
272
294
 
273
- _WARN_SIZE = 4096
295
+ _WARN_SIZE = 1e6 # disable this warning
274
296
  for size in list_sizes:
275
297
  if size >= _WARN_SIZE:
276
298
  warnings.warn(
ring/algorithms/jcalc.py CHANGED
@@ -88,6 +88,118 @@ class MotionConfig:
88
88
  assert nomotion_config.is_feasible()
89
89
  return nomotion_config
90
90
 
91
+ @staticmethod
92
+ def from_register(name: str) -> "MotionConfig":
93
+ return _registered_motion_configs[name]
94
+
95
+
96
+ _registered_motion_configs = {
97
+ "hinUndHer": MotionConfig(
98
+ t_min=0.3,
99
+ t_max=1.5,
100
+ dang_max=3.0,
101
+ delta_ang_min=0.5,
102
+ pos_min=-1.5,
103
+ pos_max=1.5,
104
+ randomized_interpolation_angle=True,
105
+ cor=True,
106
+ ),
107
+ "langsam": MotionConfig(
108
+ t_min=0.2,
109
+ t_max=1.25,
110
+ dang_max=2.0,
111
+ randomized_interpolation_angle=True,
112
+ dang_max_free_spherical=2.0,
113
+ cdf_bins_min=1,
114
+ cdf_bins_max=3,
115
+ pos_min=-1.5,
116
+ pos_max=1.5,
117
+ cor=True,
118
+ ),
119
+ "standard": MotionConfig(
120
+ randomized_interpolation_angle=True,
121
+ cdf_bins_min=1,
122
+ cdf_bins_max=5,
123
+ cor=True,
124
+ ),
125
+ "expFast": MotionConfig(
126
+ t_min=0.4,
127
+ t_max=1.1,
128
+ dang_max=jnp.deg2rad(180),
129
+ delta_ang_min=jnp.deg2rad(60),
130
+ delta_ang_max=jnp.deg2rad(110),
131
+ pos_min=-1.5,
132
+ pos_max=1.5,
133
+ range_of_motion_hinge_method="sigmoid",
134
+ randomized_interpolation_angle=True,
135
+ cdf_bins_min=1,
136
+ cdf_bins_max=3,
137
+ cor=True,
138
+ ),
139
+ "expSlow": MotionConfig(
140
+ t_min=0.75,
141
+ t_max=3.0,
142
+ dang_min=0.1,
143
+ dang_max=1.0,
144
+ dang_min_free_spherical=0.1,
145
+ delta_ang_min=0.4,
146
+ dang_max_free_spherical=1.0,
147
+ delta_ang_max_free_spherical=1.0,
148
+ dpos_max=0.3,
149
+ cor_dpos_max=0.3,
150
+ range_of_motion_hinge_method="sigmoid",
151
+ randomized_interpolation_angle=True,
152
+ cdf_bins_min=1,
153
+ cdf_bins_max=5,
154
+ cor=True,
155
+ ),
156
+ "expFastNoSig": MotionConfig(
157
+ t_min=0.4,
158
+ t_max=1.1,
159
+ dang_max=jnp.deg2rad(180),
160
+ delta_ang_min=jnp.deg2rad(60),
161
+ delta_ang_max=jnp.deg2rad(110),
162
+ pos_min=-1.5,
163
+ pos_max=1.5,
164
+ randomized_interpolation_angle=True,
165
+ cdf_bins_min=1,
166
+ cdf_bins_max=3,
167
+ cor=True,
168
+ ),
169
+ "expSlowNoSig": MotionConfig(
170
+ t_min=0.75,
171
+ t_max=3.0,
172
+ dang_min=0.1,
173
+ dang_max=1.0,
174
+ dang_min_free_spherical=0.1,
175
+ delta_ang_min=0.4,
176
+ dang_max_free_spherical=1.0,
177
+ delta_ang_max_free_spherical=1.0,
178
+ dpos_max=0.3,
179
+ cor_dpos_max=0.3,
180
+ randomized_interpolation_angle=True,
181
+ cdf_bins_min=1,
182
+ cdf_bins_max=3,
183
+ cor=True,
184
+ ),
185
+ "verySlow": MotionConfig(
186
+ t_min=1.5,
187
+ t_max=5.0,
188
+ dang_min=jnp.deg2rad(1),
189
+ dang_max=jnp.deg2rad(30),
190
+ delta_ang_min=jnp.deg2rad(20),
191
+ dang_min_free_spherical=jnp.deg2rad(1),
192
+ dang_max_free_spherical=jnp.deg2rad(10),
193
+ delta_ang_min_free_spherical=jnp.deg2rad(5),
194
+ dpos_max=0.3,
195
+ cor_dpos_max=0.3,
196
+ randomized_interpolation_angle=True,
197
+ cdf_bins_min=1,
198
+ cdf_bins_max=3,
199
+ cor=True,
200
+ ),
201
+ }
202
+
91
203
 
92
204
  def _is_feasible_config1(c: MotionConfig) -> bool:
93
205
  t_min, t_max = c.t_min, _to_float(c.t_max, 0.0)
ring/base.py CHANGED
@@ -99,15 +99,15 @@ class _Base:
99
99
  def ndim(self):
100
100
  return tu.tree_ndim(self)
101
101
 
102
- def shape(self, axis=0) -> int:
103
- return tu.tree_shape(self, axis)
104
-
105
- def __len__(self) -> int:
106
- Bs = tree_map(lambda arr: arr.shape[0], self)
102
+ def shape(self, axis: int = 0) -> int:
103
+ Bs = tree_map(lambda arr: arr.shape[axis], self)
107
104
  Bs = set(jax.tree_util.tree_flatten(Bs)[0])
108
105
  assert len(Bs) == 1
109
106
  return list(Bs)[0]
110
107
 
108
+ def __len__(self) -> int:
109
+ return self.shape(axis=0)
110
+
111
111
 
112
112
  @struct.dataclass
113
113
  class Transform(_Base):
@@ -685,14 +685,13 @@ class System(_Base):
685
685
  self,
686
686
  xs: Transform | list[Transform],
687
687
  yhat: dict | jax.Array | np.ndarray,
688
- stepframe: int = 1,
689
688
  # by default we don't predict the global rotation
690
689
  transparent_segment_to_root: bool = True,
691
690
  **kwargs,
692
691
  ):
693
692
  "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
694
693
  return ring.rendering.render_prediction(
695
- self, xs, yhat, stepframe, transparent_segment_to_root, **kwargs
694
+ self, xs, yhat, transparent_segment_to_root, **kwargs
696
695
  )
697
696
 
698
697
  def delete_system(self, link_name: str | list[str], strict: bool = True):
@@ -2,101 +2,101 @@
2
2
  <x_xy model="arm_1Seg">
3
3
  <options dt="0.01" gravity="0.0 0.0 9.81"/>
4
4
  <worldbody>
5
- <body joint="free" name="seg2_1Seg" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
5
+ <body joint="free" name="seg3_1Seg" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
6
6
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
7
7
  <geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
8
8
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
9
- <omc pos="0.0 0.0 -0.02" name="seg2" pos_marker="1"/>
10
- <body joint="frozen" name="imu2_1Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
9
+ <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="1"/>
10
+ <body joint="frozen" name="imu3_1Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
11
11
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
12
- <omc pos="0.1 0.0 0.015" name="seg2" pos_marker="1"/>
12
+ <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="1"/>
13
13
  </body>
14
14
  </body>
15
- <body joint="free" name="seg2_2Seg" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
15
+ <body joint="free" name="seg3_2Seg" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
16
16
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
17
17
  <geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
18
18
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
19
- <omc pos="0.0 0.0 -0.02" name="seg2" pos_marker="1"/>
20
- <body joint="frozen" name="imu2_2Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
19
+ <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="1"/>
20
+ <body joint="frozen" name="imu3_2Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
21
21
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
22
- <omc pos="0.1 0.0 0.015" name="seg2" pos_marker="1"/>
22
+ <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="1"/>
23
23
  </body>
24
- <body joint="ry" name="seg3_2Seg" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0">
24
+ <body joint="ry" name="seg4_2Seg" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0">
25
25
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_white" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
26
26
  <geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
27
27
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
28
- <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="2"/>
29
- <body joint="frozen" name="imu3_2Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
28
+ <omc pos="0.0 0.0 -0.02" name="seg4" pos_marker="2"/>
29
+ <body joint="frozen" name="imu4_2Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
30
30
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
31
- <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="2"/>
31
+ <omc pos="0.1 0.0 0.015" name="seg4" pos_marker="2"/>
32
32
  </body>
33
33
  </body>
34
34
  </body>
35
- <body joint="free" name="seg2_3Seg" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
35
+ <body joint="free" name="seg3_3Seg" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
36
36
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
37
37
  <geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
38
38
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
39
- <omc pos="0.0 0.0 -0.02" name="seg2" pos_marker="1"/>
40
- <body joint="frozen" name="imu2_3Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
39
+ <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="1"/>
40
+ <body joint="frozen" name="imu3_3Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
41
41
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
42
- <omc pos="0.1 0.0 0.015" name="seg2" pos_marker="1"/>
42
+ <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="1"/>
43
43
  </body>
44
- <body joint="ry" name="seg3_3Seg" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0">
44
+ <body joint="ry" name="seg4_3Seg" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0">
45
45
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_white" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
46
46
  <geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
47
47
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
48
- <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="2"/>
49
- <body joint="frozen" name="imu3_3Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
48
+ <omc pos="0.0 0.0 -0.02" name="seg4" pos_marker="2"/>
49
+ <body joint="frozen" name="imu4_3Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
50
50
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
51
- <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="2"/>
51
+ <omc pos="0.1 0.0 0.015" name="seg4" pos_marker="2"/>
52
52
  </body>
53
- <body joint="rz" name="seg4_3Seg" pos="0.19999999 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35000002 0.05 0.05" damping="3.0">
53
+ <body joint="rz" name="seg5_3Seg" pos="0.19999999 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35000002 0.05 0.05" damping="3.0">
54
54
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_white" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
55
55
  <geom pos="0.03 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
56
56
  <geom pos="0.17 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
57
- <omc pos="0.0 0.0 -0.02" name="seg4" pos_marker="4"/>
58
- <body joint="frozen" name="imu4_3Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15000004 0.05 0.05">
57
+ <omc pos="0.0 0.0 -0.02" name="seg5" pos_marker="4"/>
58
+ <body joint="frozen" name="imu5_3Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15000004 0.05 0.05">
59
59
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
60
- <omc pos="0.1 0.0 0.015" name="seg4" pos_marker="4"/>
60
+ <omc pos="0.1 0.0 0.015" name="seg5" pos_marker="4"/>
61
61
  </body>
62
62
  </body>
63
63
  </body>
64
64
  </body>
65
- <body joint="free" name="seg5_4Seg" pos="0.2 0.0 0.0" pos_min="0.15 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
65
+ <body joint="free" name="seg2_4Seg" pos="0.2 0.0 0.0" pos_min="0.15 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
66
66
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_white" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
67
67
  <geom pos="0.03 -0.05 0.0" mass="0.1" color="dustin_exp_white" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
68
68
  <geom pos="0.17 -0.05 0.0" mass="0.1" color="dustin_exp_white" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
69
- <omc pos="0.0 0.0 -0.02" name="seg5" pos_marker="2"/>
70
- <body joint="frozen" name="imu5_4Seg" pos="0.10000001 0.0 0.035" pos_min="0.049999997 -0.05 -0.05" pos_max="0.15000002 0.05 0.05">
69
+ <omc pos="0.0 0.0 -0.02" name="seg2" pos_marker="2"/>
70
+ <body joint="frozen" name="imu2_4Seg" pos="0.10000001 0.0 0.035" pos_min="0.049999997 -0.05 -0.05" pos_max="0.15000002 0.05 0.05">
71
71
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
72
- <omc pos="0.1 0.0 0.015" name="seg5" pos_marker="2"/>
72
+ <omc pos="0.1 0.0 0.015" name="seg2" pos_marker="2"/>
73
73
  </body>
74
- <body joint="rx" name="seg2_4Seg" pos="0.2 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35000002 0.05 0.05" damping="3.0">
74
+ <body joint="rx" name="seg3_4Seg" pos="0.2 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35000002 0.05 0.05" damping="3.0">
75
75
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
76
76
  <geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
77
77
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
78
- <omc pos="0.0 0.0 -0.02" name="seg2" pos_marker="1"/>
79
- <body joint="frozen" name="imu2_4Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
78
+ <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="1"/>
79
+ <body joint="frozen" name="imu3_4Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
80
80
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
81
- <omc pos="0.1 0.0 0.015" name="seg2" pos_marker="1"/>
81
+ <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="1"/>
82
82
  </body>
83
- <body joint="ry" name="seg3_4Seg" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0">
83
+ <body joint="ry" name="seg4_4Seg" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0">
84
84
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_white" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
85
85
  <geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
86
86
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
87
- <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="2"/>
88
- <body joint="frozen" name="imu3_4Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
87
+ <omc pos="0.0 0.0 -0.02" name="seg4" pos_marker="2"/>
88
+ <body joint="frozen" name="imu4_4Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
89
89
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
90
- <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="2"/>
90
+ <omc pos="0.1 0.0 0.015" name="seg4" pos_marker="2"/>
91
91
  </body>
92
- <body joint="rz" name="seg4_4Seg" pos="0.19999999 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35000002 0.05 0.05" damping="3.0">
92
+ <body joint="rz" name="seg5_4Seg" pos="0.19999999 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35000002 0.05 0.05" damping="3.0">
93
93
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_white" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
94
94
  <geom pos="0.03 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
95
95
  <geom pos="0.17 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
96
- <omc pos="0.0 0.0 -0.02" name="seg4" pos_marker="4"/>
97
- <body joint="frozen" name="imu4_4Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15000004 0.05 0.05">
96
+ <omc pos="0.0 0.0 -0.02" name="seg5" pos_marker="4"/>
97
+ <body joint="frozen" name="imu5_4Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15000004 0.05 0.05">
98
98
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
99
- <omc pos="0.1 0.0 0.015" name="seg4" pos_marker="4"/>
99
+ <omc pos="0.1 0.0 0.015" name="seg5" pos_marker="4"/>
100
100
  </body>
101
101
  </body>
102
102
  </body>
@@ -2,101 +2,101 @@
2
2
  <x_xy model="arm_1Seg">
3
3
  <options dt="0.01" gravity="0.0 0.0 9.81"/>
4
4
  <worldbody>
5
- <body joint="free" name="seg2_1Seg" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
5
+ <body joint="free" name="seg3_1Seg" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
6
6
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
7
7
  <geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
8
8
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
9
- <omc pos="0.0 0.0 -0.02" name="seg2" pos_marker="1"/>
10
- <body joint="frozen" name="imu2_1Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
9
+ <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="1"/>
10
+ <body joint="frozen" name="imu3_1Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
11
11
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
12
- <omc pos="0.1 0.0 0.015" name="seg2" pos_marker="1"/>
12
+ <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="1"/>
13
13
  </body>
14
14
  </body>
15
- <body joint="free" name="seg2_2Seg" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
15
+ <body joint="free" name="seg3_2Seg" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
16
16
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
17
17
  <geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
18
18
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
19
- <omc pos="0.0 0.0 -0.02" name="seg2" pos_marker="1"/>
20
- <body joint="frozen" name="imu2_2Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
19
+ <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="1"/>
20
+ <body joint="frozen" name="imu3_2Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
21
21
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
22
- <omc pos="0.1 0.0 0.015" name="seg2" pos_marker="1"/>
22
+ <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="1"/>
23
23
  </body>
24
- <body joint="rr_imp" name="seg3_2Seg" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
24
+ <body joint="rr_imp" name="seg4_2Seg" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
25
25
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_white" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
26
26
  <geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
27
27
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
28
- <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="2"/>
29
- <body joint="frozen" name="imu3_2Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
28
+ <omc pos="0.0 0.0 -0.02" name="seg4" pos_marker="2"/>
29
+ <body joint="frozen" name="imu4_2Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
30
30
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
31
- <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="2"/>
31
+ <omc pos="0.1 0.0 0.015" name="seg4" pos_marker="2"/>
32
32
  </body>
33
33
  </body>
34
34
  </body>
35
- <body joint="free" name="seg2_3Seg" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
35
+ <body joint="free" name="seg3_3Seg" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
36
36
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
37
37
  <geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
38
38
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
39
- <omc pos="0.0 0.0 -0.02" name="seg2" pos_marker="1"/>
40
- <body joint="frozen" name="imu2_3Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
39
+ <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="1"/>
40
+ <body joint="frozen" name="imu3_3Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
41
41
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
42
- <omc pos="0.1 0.0 0.015" name="seg2" pos_marker="1"/>
42
+ <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="1"/>
43
43
  </body>
44
- <body joint="rr_imp" name="seg3_3Seg" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
44
+ <body joint="rr_imp" name="seg4_3Seg" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
45
45
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_white" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
46
46
  <geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
47
47
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
48
- <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="2"/>
49
- <body joint="frozen" name="imu3_3Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
48
+ <omc pos="0.0 0.0 -0.02" name="seg4" pos_marker="2"/>
49
+ <body joint="frozen" name="imu4_3Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
50
50
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
51
- <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="2"/>
51
+ <omc pos="0.1 0.0 0.015" name="seg4" pos_marker="2"/>
52
52
  </body>
53
- <body joint="rr_imp" name="seg4_3Seg" pos="0.19999999 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35000002 0.05 0.05" damping="3.0 3.0">
53
+ <body joint="rr_imp" name="seg5_3Seg" pos="0.19999999 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35000002 0.05 0.05" damping="3.0 3.0">
54
54
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_white" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
55
55
  <geom pos="0.03 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
56
56
  <geom pos="0.17 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
57
- <omc pos="0.0 0.0 -0.02" name="seg4" pos_marker="4"/>
58
- <body joint="frozen" name="imu4_3Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15000004 0.05 0.05">
57
+ <omc pos="0.0 0.0 -0.02" name="seg5" pos_marker="4"/>
58
+ <body joint="frozen" name="imu5_3Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15000004 0.05 0.05">
59
59
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
60
- <omc pos="0.1 0.0 0.015" name="seg4" pos_marker="4"/>
60
+ <omc pos="0.1 0.0 0.015" name="seg5" pos_marker="4"/>
61
61
  </body>
62
62
  </body>
63
63
  </body>
64
64
  </body>
65
- <body joint="free" name="seg5_4Seg" pos="0.2 0.0 0.0" pos_min="0.15 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
65
+ <body joint="free" name="seg2_4Seg" pos="0.2 0.0 0.0" pos_min="0.15 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
66
66
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_white" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
67
67
  <geom pos="0.03 -0.05 0.0" mass="0.1" color="dustin_exp_white" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
68
68
  <geom pos="0.17 -0.05 0.0" mass="0.1" color="dustin_exp_white" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
69
- <omc pos="0.0 0.0 -0.02" name="seg5" pos_marker="2"/>
70
- <body joint="frozen" name="imu5_4Seg" pos="0.10000001 0.0 0.035" pos_min="0.049999997 -0.05 -0.05" pos_max="0.15000002 0.05 0.05">
69
+ <omc pos="0.0 0.0 -0.02" name="seg2" pos_marker="2"/>
70
+ <body joint="frozen" name="imu2_4Seg" pos="0.10000001 0.0 0.035" pos_min="0.049999997 -0.05 -0.05" pos_max="0.15000002 0.05 0.05">
71
71
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
72
- <omc pos="0.1 0.0 0.015" name="seg5" pos_marker="2"/>
72
+ <omc pos="0.1 0.0 0.015" name="seg2" pos_marker="2"/>
73
73
  </body>
74
- <body joint="rr_imp" name="seg2_4Seg" pos="0.2 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35000002 0.05 0.05" damping="3.0 3.0">
74
+ <body joint="rr_imp" name="seg3_4Seg" pos="0.2 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35000002 0.05 0.05" damping="3.0 3.0">
75
75
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
76
76
  <geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
77
77
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
78
- <omc pos="0.0 0.0 -0.02" name="seg2" pos_marker="1"/>
79
- <body joint="frozen" name="imu2_4Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
78
+ <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="1"/>
79
+ <body joint="frozen" name="imu3_4Seg" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
80
80
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
81
- <omc pos="0.1 0.0 0.015" name="seg2" pos_marker="1"/>
81
+ <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="1"/>
82
82
  </body>
83
- <body joint="rr_imp" name="seg3_4Seg" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
83
+ <body joint="rr_imp" name="seg4_4Seg" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
84
84
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_white" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
85
85
  <geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
86
86
  <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
87
- <omc pos="0.0 0.0 -0.02" name="seg3" pos_marker="2"/>
88
- <body joint="frozen" name="imu3_4Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
87
+ <omc pos="0.0 0.0 -0.02" name="seg4" pos_marker="2"/>
88
+ <body joint="frozen" name="imu4_4Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
89
89
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
90
- <omc pos="0.1 0.0 0.015" name="seg3" pos_marker="2"/>
90
+ <omc pos="0.1 0.0 0.015" name="seg4" pos_marker="2"/>
91
91
  </body>
92
- <body joint="rr_imp" name="seg4_4Seg" pos="0.19999999 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35000002 0.05 0.05" damping="3.0 3.0">
92
+ <body joint="rr_imp" name="seg5_4Seg" pos="0.19999999 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35000002 0.05 0.05" damping="3.0 3.0">
93
93
  <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_white" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
94
94
  <geom pos="0.03 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
95
95
  <geom pos="0.17 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
96
- <omc pos="0.0 0.0 -0.02" name="seg4" pos_marker="4"/>
97
- <body joint="frozen" name="imu4_4Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15000004 0.05 0.05">
96
+ <omc pos="0.0 0.0 -0.02" name="seg5" pos_marker="4"/>
97
+ <body joint="frozen" name="imu5_4Seg" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15000004 0.05 0.05">
98
98
  <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
99
- <omc pos="0.1 0.0 0.015" name="seg4" pos_marker="4"/>
99
+ <omc pos="0.1 0.0 0.015" name="seg5" pos_marker="4"/>
100
100
  </body>
101
101
  </body>
102
102
  </body>
ring/ml/__init__.py CHANGED
@@ -12,8 +12,10 @@ from .optimizer import make_optimizer
12
12
  from .ringnet import RING
13
13
  from .train import train_fn
14
14
 
15
+ _lpf_cutoff_freq = 10.0
15
16
 
16
- def RING_ICML24(params=None, **kwargs):
17
+
18
+ def RING_ICML24(params=None, eval: bool = True, **kwargs):
17
19
  """Create the RING network used in the icml24 paper.
18
20
 
19
21
  X[..., :3] = acc
@@ -28,6 +30,34 @@ def RING_ICML24(params=None, **kwargs):
28
30
 
29
31
  ringnet = RING(params=params, **kwargs) # noqa: F811
30
32
  ringnet = base.ScaleX_FilterWrapper(ringnet)
31
- ringnet = base.LPF_FilterWrapper(ringnet, 10.0, samp_freq=None)
33
+ if eval:
34
+ ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=None)
32
35
  ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
33
36
  return ringnet
37
+
38
+
39
+ def RNNO(
40
+ output_dim: int,
41
+ return_quats: bool = False,
42
+ params=None,
43
+ eval: bool = True,
44
+ **kwargs,
45
+ ):
46
+ assert "message_dim" not in kwargs
47
+ assert "link_output_normalize" not in kwargs
48
+ assert "link_output_dim" not in kwargs
49
+
50
+ ringnet = RING( # noqa: F811
51
+ params=params,
52
+ message_dim=0,
53
+ link_output_normalize=False,
54
+ link_output_dim=output_dim,
55
+ **kwargs,
56
+ )
57
+ ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
58
+ ringnet = base.ScaleX_FilterWrapper(ringnet)
59
+ if eval and return_quats:
60
+ ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=None)
61
+ if return_quats:
62
+ ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
63
+ return ringnet
ring/ml/base.py CHANGED
@@ -34,9 +34,9 @@ class AbstractFilter(ABC):
34
34
  def _apply_batched(self, X, params, state, y, lam):
35
35
  pass
36
36
 
37
- @abstractmethod
38
37
  def init(self, bs, X, lam, seed: int):
39
- pass
38
+ params = state = None
39
+ return params, state
40
40
 
41
41
  def apply(self, X, params=None, state=None, y=None, lam=None):
42
42
  "X.shape = (B, T, N, F) or (T, N, F)"
@@ -286,7 +286,7 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
286
286
  yhat = yhat.reshape((T, N, -1))
287
287
 
288
288
  if self._quat_normalize:
289
- assert yhat.shape[-1] == 4
289
+ assert yhat.shape[-1] == 4, f"yhat.shape={yhat.shape}"
290
290
  yhat = ring.maths.safe_normalize(yhat)
291
291
 
292
292
  return yhat, state
ring/ml/callbacks.py CHANGED
@@ -108,6 +108,10 @@ class EvalXyTrainingLoopCallback(training_loop.TrainingLoopCallback):
108
108
  if (i_episode % self.eval_every) == 0:
109
109
  point_estimates = self.eval_fn(params)
110
110
  self.last_metrices = {self.metric_identifier: point_estimates}
111
+
112
+ assert (
113
+ self.metric_identifier not in metrices
114
+ ), f"`{self.metric_identifier}` is already in `{metrices.keys()}`"
111
115
  metrices.update(self.last_metrices)
112
116
 
113
117
 
@@ -136,12 +136,15 @@ def render_prediction(
136
136
  sys: base.System,
137
137
  xs: base.Transform | list[base.Transform],
138
138
  yhat: dict | jax.Array | np.ndarray,
139
- stepframe: int = 1,
140
139
  # by default we don't predict the global rotation
141
140
  transparent_segment_to_root: bool = True,
142
141
  **kwargs,
143
142
  ):
144
143
  "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
144
+
145
+ offset_truth = kwargs.pop("offset_truth", [0, 0, 0])
146
+ offset_pred = kwargs.pop("offset_pred", [0, 0, 0])
147
+
145
148
  if isinstance(xs, list):
146
149
  # list -> batched Transform
147
150
  xs = xs[0].batch(*xs[1:])
@@ -180,18 +183,23 @@ def render_prediction(
180
183
 
181
184
  # swap time axis, and link axis
182
185
  xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
186
+
187
+ add_offset = lambda x, offset: algebra.transform_mul(
188
+ x, base.Transform.create(pos=jnp.array(offset, dtype=jnp.float32))
189
+ )
190
+
183
191
  # create mapping from `name` -> Transform
184
192
  xs_dict = dict(
185
193
  zip(
186
194
  ["hat_" + name for name in sys_noimu.link_names],
187
- [xshat[i] for i in range(sys_noimu.num_links())],
195
+ [add_offset(xshat[i], offset_pred) for i in range(sys_noimu.num_links())],
188
196
  )
189
197
  )
190
198
  xs_dict.update(
191
199
  dict(
192
200
  zip(
193
201
  sys.link_names,
194
- [xs[i] for i in range(sys.num_links())],
202
+ [add_offset(xs[i], offset_truth) for i in range(sys.num_links())],
195
203
  )
196
204
  )
197
205
  )
@@ -202,11 +210,8 @@ def render_prediction(
202
210
  xs_render.append(xs_dict[name])
203
211
  xs_render = xs_render[0].batch(*xs_render[1:])
204
212
  xs_render = xs_render.transpose((1, 0, 2))
205
- N = xs_render.shape()
206
- xs_render = [xs_render[t] for t in range(0, N, stepframe)]
207
213
 
208
214
  frames = render(sys_render, xs_render, **kwargs)
209
-
210
215
  return frames
211
216
 
212
217
 
@@ -229,7 +234,7 @@ def _color_to_rgba(geom: base.Geometry) -> base.Geometry:
229
234
  def _xyz_to_three_capsules(xyz: base.XYZ) -> list[base.Geometry]:
230
235
  capsules = []
231
236
  length = xyz.size
232
- radius = length / 6
237
+ radius = length / 7
233
238
  colors = ["red", "green", "blue"]
234
239
  rot_axis = [1, 0, 2]
235
240
 
ring/utils/__init__.py CHANGED
@@ -15,6 +15,7 @@ from .utils import dict_union
15
15
  from .utils import import_lib
16
16
  from .utils import pickle_load
17
17
  from .utils import pickle_save
18
+ from .utils import primes
18
19
  from .utils import pytree_deepcopy
19
20
  from .utils import sys_compare
20
21
  from .utils import to_list
ring/utils/utils.py CHANGED
@@ -159,3 +159,17 @@ def pickle_load(
159
159
  with open(path, "rb") as file:
160
160
  obj = pickle.load(file)
161
161
  return obj
162
+
163
+
164
+ def primes(n: int) -> list[int]:
165
+ "Primefactor decomposition in ascending order."
166
+ primfac = []
167
+ d = 2
168
+ while d * d <= n:
169
+ while (n % d) == 0:
170
+ primfac.append(d) # supposing you want multiple factors repeated
171
+ n //= d
172
+ d += 1
173
+ if n > 1:
174
+ primfac.append(n)
175
+ return primfac