imt-ring 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: imt-ring
3
- Version: 1.7.0
3
+ Version: 1.7.1
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,13 +4,13 @@ ring/base.py,sha256=AkG_Gpk7i2j77MzxzjiolJ9WNGcSq_3aOcpu0l6-0e0,50543
4
4
  ring/maths.py,sha256=R22SNQutkf9v7Hp9klo0wvJVIyBQz0O8_5oJaDQcFis,12652
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
- ring/algorithms/_random.py,sha256=UMyv-VPZLcErrKqs0XB83QJjs8GrmoNsv-zRSxGXvnI,14490
7
+ ring/algorithms/_random.py,sha256=9mSP7M1On_CBz-dEbvE0iDzl_buf125EiYb1y10Ylxc,15805
8
8
  ring/algorithms/dynamics.py,sha256=NFOZawjrFoS5RgiWOpG1pQCC8l7RBOEZIi9ok6gvf9U,12268
9
- ring/algorithms/jcalc.py,sha256=l6BXOmXwrZ_AKKRm4gEHq_k2LSUQ4wd--1qL1qNTcKk,46794
9
+ ring/algorithms/jcalc.py,sha256=zUMKEndSF6zaU7qHjJEU-iEQ_YAFNTNPcOwzdyu33uk,47632
10
10
  ring/algorithms/kinematics.py,sha256=IXeTQ-afzeEzLVmQVQ1oTXJxz_lTwvaWlgHeJkhO_8o,7423
11
11
  ring/algorithms/sensors.py,sha256=v_TZMyWjffDpPwoyS1fy8X-1i9y1vDf6mk1EmGS2ztc,18251
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=3pQ-Is_HBTQDkzESCNg9VfoP8wvseWmooryG8ERnu_A,366
13
- ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
13
+ ring/algorithms/custom_joints/rr_imp_joint.py,sha256=1sOej-D3q4bSh5KJqcK5fZA3iOHXUlrlT9u-tIa-D6c,2351
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
16
16
  ring/algorithms/custom_joints/suntay.py,sha256=TZG307NqdMiXnNY63xEx8AkAjbQBQ4eO6DQ7R4j4D08,16726
@@ -28,7 +28,7 @@ ring/extras/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
28
28
  ring/extras/dataloader.py,sha256=dfNPjnxDoKxWGKSImuJ_49CWgBn73vxSEek8COq9nNk,3749
29
29
  ring/extras/dataloader_torch.py,sha256=t2DDiB9ZHb_SzFlVbntCGGIybj4F-NoA0PaB4_afjGw,3983
30
30
  ring/extras/hdf5.py,sha256=XPIrwogD-d544yy08UJyfLVp1ZKRUtiZukW7RA8VUxQ,5856
31
- ring/extras/interactive_viewer.py,sha256=vQEzcBDdG3BPqTGEktC74DsCfvgKktj9DKWK8gBzRtE,3805
31
+ ring/extras/interactive_viewer.py,sha256=-jpoSsDrdzshZw-_MSI96QLvm9foRWTUS_aa_CNFk74,3867
32
32
  ring/extras/normalizer.py,sha256=o26stPP6EHasZQxQX0vKqTrhUNZBaJ2O17L6W_gBMN4,1699
33
33
  ring/extras/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
34
34
  ring/extras/torch_loss_fn.py,sha256=1LnWTmtxXPxoQFr4QixW12AjpRUfrseSDBmifhu6ErE,2676
@@ -89,8 +89,8 @@ ring/utils/__init__.py,sha256=Q37bjy2wjRGggd77MHlgl_50i2zOuVnPny4yOLiTe-8,567
89
89
  ring/utils/batchsize.py,sha256=uCj8LG7elbjEUUzuK29Z3I9T8bxJTcsybY3DdGeqhQs,1786
90
90
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
91
91
  ring/utils/utils.py,sha256=gKwOXLxWraeZfX6EbBcg3hkq30DcXN0mcRUeOSTNiMo,7336
92
- imt_ring-1.7.0.dist-info/METADATA,sha256=CNwgvWr9Yu7MgIfcNwXkuByr7_8vkxvO5IkJg3iDKbs,5887
93
- imt_ring-1.7.0.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
94
- imt_ring-1.7.0.dist-info/entry_points.txt,sha256=npNqSOvNiBR0BNa_GL3J66q8Gky3h0G_PHzHzk8oyE0,66
95
- imt_ring-1.7.0.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
96
- imt_ring-1.7.0.dist-info/RECORD,,
92
+ imt_ring-1.7.1.dist-info/METADATA,sha256=cFujMyd1Xpqa8NiYn3iMqnhPHzGvlb_mLBbIPplQA6M,5887
93
+ imt_ring-1.7.1.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
94
+ imt_ring-1.7.1.dist-info/entry_points.txt,sha256=npNqSOvNiBR0BNa_GL3J66q8Gky3h0G_PHzHzk8oyE0,66
95
+ imt_ring-1.7.1.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
96
+ imt_ring-1.7.1.dist-info/RECORD,,
@@ -41,30 +41,48 @@ def random_angle_over_time(
41
41
  cdf_bins_min: int = 5,
42
42
  cdf_bins_max: Optional[int] = None,
43
43
  interpolation_method: str = "cosine",
44
+ include_standstills_prob: float = 0.0, # 0.0 means no standstills
45
+ include_standstills_t_min: float = 0.5,
46
+ include_standstills_t_max: float = 5.0,
44
47
  ) -> jax.Array:
45
48
  def body_fn_outer(val):
46
49
  i, t, phi, key_t, key_ang, ANG = val
47
50
 
48
- key_t, consume_t = random.split(key_t)
51
+ key_t, consume_t, consume_standstill = random.split(key_t, 3)
49
52
  key_ang, consume_ang = random.split(key_ang)
50
53
  rom_halfsize_float = _to_float(rom_halfsize, t)
51
54
  rom_lower = ANG_0 - rom_halfsize_float
52
55
  rom_upper = ANG_0 + rom_halfsize_float
53
- dt, phi = _resolve_range_of_motion(
54
- range_of_motion,
55
- range_of_motion_method,
56
- rom_lower,
57
- rom_upper,
58
- _to_float(dang_min, t),
59
- _to_float(dang_max, t),
60
- _to_float(delta_ang_min, t),
61
- _to_float(delta_ang_max, t),
62
- t_min,
63
- _to_float(t_max, t),
64
- phi,
65
- consume_t,
66
- consume_ang,
67
- max_iter,
56
+
57
+ is_standstill = jax.random.bernoulli(
58
+ consume_standstill, include_standstills_prob
59
+ )
60
+ dt, phi = jax.lax.cond(
61
+ is_standstill,
62
+ lambda: (
63
+ jax.random.uniform(
64
+ consume_t,
65
+ minval=include_standstills_t_min,
66
+ maxval=include_standstills_t_max,
67
+ ),
68
+ phi,
69
+ ),
70
+ lambda: _resolve_range_of_motion(
71
+ range_of_motion,
72
+ range_of_motion_method,
73
+ rom_lower,
74
+ rom_upper,
75
+ _to_float(dang_min, t),
76
+ _to_float(dang_max, t),
77
+ _to_float(delta_ang_min, t),
78
+ _to_float(delta_ang_max, t),
79
+ t_min,
80
+ _to_float(t_max, t),
81
+ phi,
82
+ consume_t,
83
+ consume_ang,
84
+ max_iter,
85
+ ),
68
86
  )
69
87
  t += dt
70
88
 
@@ -119,7 +137,8 @@ def random_angle_over_time(
119
137
 
120
138
  # APPROVED
121
139
  def random_position_over_time(
122
- key: random.PRNGKey,
140
+ key_t: random.PRNGKey,
141
+ key_value: random.PRNGKey,
123
142
  POS_0: float,
124
143
  pos_min: float | TimeDependentFloat,
125
144
  pos_max: float | TimeDependentFloat,
@@ -135,19 +154,14 @@ def random_position_over_time(
135
154
  cdf_bins_min: int = 5,
136
155
  cdf_bins_max: Optional[int] = None,
137
156
  interpolation_method: str = "cosine",
157
+ include_standstills_prob: float = 0.0, # 0.0 means no standstills
158
+ include_standstills_t_min: float = 0.5,
159
+ include_standstills_t_max: float = 5.0,
138
160
  ) -> jax.Array:
139
161
  def body_fn_inner(val):
140
162
  i, t, t_pre, x, x_pre, key = val
141
163
  dt = t - t_pre
142
164
 
143
- def sample_dx_squared(key):
144
- key, consume = random.split(key)
145
- dx = (
146
- random.uniform(consume) * (2 * dpos_max * t_max**2)
147
- - dpos_max * t_max**2
148
- )
149
- return key, dx
150
-
151
165
  def sample_dx(key):
152
166
  key, consume1, consume2 = random.split(key, 3)
153
167
  sign = random.choice(consume1, jnp.array([-1.0, 1.0]))
@@ -182,24 +196,43 @@ def random_position_over_time(
182
196
  return jnp.logical_not(break_if_true1 | break_if_true2)
183
197
 
184
198
  def body_fn_outer(val):
185
- i, t, t_pre, x, x_pre, key, POS = val
186
- key, consume = random.split(key)
187
- t += random.uniform(consume, minval=t_min, maxval=_to_float(t_max, t_pre))
199
+ i, t, t_pre, x, x_pre, key_t, key_value, POS = val
200
+ key_t, consume_t, consume_standstill = random.split(key_t, 3)
188
201
 
189
- # that zero resets the max_it count
190
- val_inner = (0, t, t_pre, x, x_pre, key)
191
- _, t, t_pre, x, x_pre, key = jax.lax.while_loop(
192
- cond_fn_inner, body_fn_inner, val_inner
202
+ is_standstill = jax.random.bernoulli(
203
+ consume_standstill, include_standstills_prob
204
+ )
205
+
206
+ def is_standstill_branch():
207
+ dt = random.uniform(
208
+ consume_t,
209
+ minval=include_standstills_t_min,
210
+ maxval=include_standstills_t_max,
211
+ )
212
+ t = t_pre + dt
213
+ return 0, t, t_pre, x, x_pre, key_value
214
+
215
+ def no_standstill_branch():
216
+ dt = random.uniform(consume_t, minval=t_min, maxval=_to_float(t_max, t_pre))
217
+ t = t_pre + dt
218
+ # that zero resets the max_it count
219
+ val_inner = (0, t, t_pre, x, x_pre, key_value)
220
+ return jax.lax.while_loop(cond_fn_inner, body_fn_inner, val_inner)
221
+
222
+ _, t, t_pre, x, x_pre, key_value = jax.lax.cond(
223
+ is_standstill,
224
+ is_standstill_branch,
225
+ no_standstill_branch,
193
226
  )
194
227
 
195
228
  POS_i = jnp.array([[jnp.floor(t / Ts) * Ts, x]])
196
229
  POS = jax.lax.dynamic_update_slice_in_dim(POS, POS_i, start_index=i, axis=0)
197
230
  t_pre = t
198
231
  x_pre = x
199
- return i + 1, t, t_pre, x, x_pre, key, POS
232
+ return i + 1, t, t_pre, x, x_pre, key_t, key_value, POS
200
233
 
201
234
  def cond_fn_outer(val):
202
- i, t, t_pre, x, x_pre, key, POS = val
235
+ i, t, t_pre, x, x_pre, key_t, key_value, POS = val
203
236
  return t <= T
204
237
 
205
238
  # preallocate POS array
@@ -207,7 +240,7 @@ def random_position_over_time(
207
240
  POS = jnp.zeros((int(T // t_min) + 1, 2))
208
241
  POS = POS.at[0, 1].set(POS_0)
209
242
 
210
- val_outer = (1, 0.0, 0.0, POS_0, POS_0, key, POS)
243
+ val_outer = (1, 0.0, 0.0, POS_0, POS_0, key_t, key_value, POS)
211
244
  end, *_, consume, POS = jax.lax.while_loop(cond_fn_outer, body_fn_outer, val_outer)
212
245
  POS = jnp.where(
213
246
  (jnp.arange(len(POS)) < end)[:, None],
@@ -23,11 +23,10 @@ def register_rr_imp_joint(
23
23
  return ring.Transform.create(rot=rot)
24
24
 
25
25
  def _draw_rr_imp(config, key_t, key_value, dt, N, _):
26
- key_t1, key_t2 = jax.random.split(key_t)
27
26
  key_value1, key_value2 = jax.random.split(key_value)
28
- q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, N, _)
27
+ q_traj_pri = _draw_rxyz(config, key_t, key_value1, dt, N, _)
29
28
  q_traj_res = _draw_rxyz(
30
- replace(config_res, T=config.T), key_t2, key_value2, dt, N, _
29
+ replace(config_res, T=config.T), key_t, key_value2, dt, N, _
31
30
  )
32
31
  # scale to be within bounds
33
32
  q_traj_res = q_traj_res * (jnp.deg2rad(ang_max_deg) / jnp.pi)
ring/algorithms/jcalc.py CHANGED
@@ -174,6 +174,16 @@ class MotionConfig:
174
174
  default_factory=lambda: dict()
175
175
  )
176
176
 
177
+ # fields related to simulating standstills (no motion time periods)
178
+ # these are "Joint Standstills" so the standstills are calculated on
179
+ # a joint level, for each joint independently
180
+ # For example, an `include_standstills_prob` of 0.2 means that each time a
181
+ # joint draws its next dt \in [t_min, t_max], it has a probability of 20%
182
+ # of instead just staying at its current joint value for a standstill period
183
+ include_standstills_prob: float = 0.0 # probability in [0, 1] (0.2 = 20%), not percent; 0.0 means no standstills
184
+ include_standstills_t_min: float = 0.5
185
+ include_standstills_t_max: float = 5.0
186
+
177
187
  def is_feasible(self) -> bool:
178
188
  return _is_feasible_config1(self)
179
189
 
@@ -791,12 +801,15 @@ def _draw_rxyz(
791
801
  config.cdf_bins_min,
792
802
  config.cdf_bins_max,
793
803
  config.interpolation_method,
804
+ config.include_standstills_prob,
805
+ config.include_standstills_t_min,
806
+ config.include_standstills_t_max,
794
807
  )
795
808
 
796
809
 
797
810
  def _draw_pxyz(
798
811
  config: MotionConfig,
799
- _: jax.random.PRNGKey,
812
+ key_t: jax.random.PRNGKey,
800
813
  key_value: jax.random.PRNGKey,
801
814
  dt: float | jax.Array,
802
815
  N: int | None,
@@ -811,6 +824,7 @@ def _draw_pxyz(
811
824
  )
812
825
  max_iter = 100
813
826
  return _random.random_position_over_time(
827
+ key_t,
814
828
  key_value,
815
829
  POS_0,
816
830
  config.cor_pos_min if cor else config.pos_min,
@@ -827,6 +841,9 @@ def _draw_pxyz(
827
841
  config.cdf_bins_min,
828
842
  config.cdf_bins_max,
829
843
  config.interpolation_method,
844
+ config.include_standstills_prob,
845
+ config.include_standstills_t_min,
846
+ config.include_standstills_t_max,
830
847
  )
831
848
 
832
849
 
@@ -840,7 +857,6 @@ def _draw_spherical(
840
857
  ) -> jax.Array:
841
858
  # NOTE: We draw 3 euler angles and then build a quaternion.
842
859
  # Not ideal, but i am unaware of a better way.
843
- @jax.vmap
844
860
  def draw_euler_angles(key_t, key_value):
845
861
  return _draw_rxyz(
846
862
  config,
@@ -853,8 +869,9 @@ def _draw_spherical(
853
869
  free_spherical=True,
854
870
  )
855
871
 
856
- triple = lambda key: jax.random.split(key, 3)
857
- euler_angles = draw_euler_angles(triple(key_t), triple(key_value)).T
872
+ euler_angles = jax.vmap(draw_euler_angles, in_axes=(None, 0))(
873
+ key_t, jax.random.split(key_value, 3)
874
+ ).T
858
875
  q = maths.quat_euler(euler_angles)
859
876
  return q
860
877
 
@@ -867,7 +884,6 @@ def _draw_saddle(
867
884
  N: int | None,
868
885
  _: jax.Array,
869
886
  ) -> jax.Array:
870
- @jax.vmap
871
887
  def draw_euler_angles(key_t, key_value):
872
888
  return _draw_rxyz(
873
889
  config,
@@ -880,14 +896,15 @@ def _draw_saddle(
880
896
  free_spherical=False,
881
897
  )
882
898
 
883
- double = lambda key: jax.random.split(key)
884
- yz_euler_angles = draw_euler_angles(double(key_t), double(key_value)).T
899
+ yz_euler_angles = jax.vmap(draw_euler_angles, in_axes=(None, 0))(
900
+ key_t, jax.random.split(key_value)
901
+ ).T
885
902
  return yz_euler_angles
886
903
 
887
904
 
888
905
  def _draw_p3d_and_cor(
889
906
  config: MotionConfig,
890
- _: jax.random.PRNGKey,
907
+ key_t: jax.random.PRNGKey,
891
908
  key_value: jax.random.PRNGKey,
892
909
  dt: float | jax.Array,
893
910
  N: int | None,
@@ -896,7 +913,7 @@ def _draw_p3d_and_cor(
896
913
  ) -> jax.Array:
897
914
  keys = jax.random.split(key_value, 3)
898
915
 
899
- def draw(key, xyz: str):
916
+ def draw(key_value, xyz: str):
900
917
  return _draw_pxyz(
901
918
  replace(
902
919
  config,
@@ -905,8 +922,8 @@ def _draw_p3d_and_cor(
905
922
  pos0_min=getattr(config, f"pos0_min_p3d_{xyz}"),
906
923
  pos0_max=getattr(config, f"pos0_max_p3d_{xyz}"),
907
924
  ),
908
- None,
909
- key,
925
+ key_t,
926
+ key_value,
910
927
  dt,
911
928
  N,
912
929
  None,
@@ -919,26 +936,26 @@ def _draw_p3d_and_cor(
919
936
 
920
937
  def _draw_p3d(
921
938
  config: MotionConfig,
922
- _: jax.random.PRNGKey,
939
+ key_t: jax.random.PRNGKey,
923
940
  key_value: jax.random.PRNGKey,
924
941
  dt: float | jax.Array,
925
942
  N: int | None,
926
943
  __: jax.Array,
927
944
  ) -> jax.Array:
928
- return _draw_p3d_and_cor(config, _, key_value, dt, N, None, cor=False)
945
+ return _draw_p3d_and_cor(config, key_t, key_value, dt, N, None, cor=False)
929
946
 
930
947
 
931
948
  def _draw_cor(
932
949
  config: MotionConfig,
933
- _: jax.random.PRNGKey,
950
+ key_t: jax.random.PRNGKey,
934
951
  key_value: jax.random.PRNGKey,
935
952
  dt: float | jax.Array,
936
953
  N: int | None,
937
954
  __: jax.Array,
938
955
  ) -> jax.Array:
939
956
  key_value1, key_value2 = jax.random.split(key_value)
940
- q_free = _draw_free(config, _, key_value1, dt, N, None)
941
- q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, N, None, cor=True)
957
+ q_free = _draw_free(config, key_t, key_value1, dt, N, None)
958
+ q_p3d = _draw_p3d_and_cor(config, key_t, key_value2, dt, N, None, cor=True)
942
959
  return jnp.concatenate((q_free, q_p3d), axis=1)
943
960
 
944
961
 
@@ -952,7 +969,7 @@ def _draw_free(
952
969
  ) -> jax.Array:
953
970
  key_value1, key_value2 = jax.random.split(key_value)
954
971
  q = _draw_spherical(config, key_t, key_value1, dt, N, None)
955
- pos = _draw_p3d(config, None, key_value2, dt, N, None)
972
+ pos = _draw_p3d(config, key_t, key_value2, dt, N, None)
956
973
  return jnp.concatenate((q, pos), axis=1)
957
974
 
958
975
 
@@ -975,7 +992,7 @@ def _draw_free_2d(
975
992
  enable_range_of_motion=False,
976
993
  free_spherical=True,
977
994
  )[:, None]
978
- pos_yz = _draw_p3d(config, None, key_value2, dt, N, None)[:, :2]
995
+ pos_yz = _draw_p3d(config, key_t, key_value2, dt, N, None)[:, :2]
979
996
  return jnp.concatenate((angle_x, pos_yz), axis=1)
980
997
 
981
998
 
@@ -82,10 +82,10 @@ def _fire_main(path_sys_xml: str, path_qs_np: Optional[str] = None, **scene_kwar
82
82
 
83
83
  assert qs.ndim == 2, f"qs.shape = {qs.shape}"
84
84
  T, Q = qs.shape
85
- assert Q == sys.q_size()
85
+ assert Q == sys.q_size(), f"Q={Q} != sys.q_size={sys.q_size()}"
86
86
  dt_target = sys.dt
87
87
 
88
- with InteractiveViewer(sys, **scene_kwargs) as viewer:
88
+ with InteractiveViewer(sys, width=640, height=480, **scene_kwargs) as viewer:
89
89
  dt = dt_target
90
90
  last_t = time.time()
91
91
  t = -1