imt-ring 1.6.2__py3-none-any.whl → 1.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.2
3
+ Version: 1.6.4
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -60,6 +60,36 @@ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-re
60
60
 
61
61
  Available [here](https://simipixel.github.io/ring/).
62
62
 
63
+ ## Quickstart Example
64
+ ```python
65
+ import ring
66
+ import numpy as np
67
+
68
+ T : int = 30 # sequence length [s]
69
+ Ts : float = 0.01 # sampling interval [s]
70
+ B : int = 1 # batch size
71
+ lam: list[int] = [0, 1, 2] # parent array
72
+ N : int = len(lam) # number of bodies
73
+ T_i: int = int(T/Ts) # number of timesteps
74
+
75
+ X = np.zeros((B, T_i, N, 9))
76
+ # where X is structured as follows:
77
+ # X[..., :3] = acc
78
+ # X[..., 3:6] = gyr
79
+ # X[..., 6:9] = jointaxis
80
+
81
+ # let's assume we have an IMU on each outer segment of the
82
+ # three-segment kinematic chain; `acc_segment*` / `gyr_segment*` stand for your own measured arrays of shape (B, T_i, 3)
83
+ X[..., 0, :3] = acc_segment1
84
+ X[..., 2, :3] = acc_segment3
85
+ X[..., 0, 3:6] = gyr_segment1
86
+ X[..., 2, 3:6] = gyr_segment3
87
+
88
+ ringnet = ring.RING(lam, Ts)
89
+ yhat, _ = ringnet.apply(X)
90
+ # yhat: unit quaternions, shape = (B, T_i, N, 4)
91
+ ```
92
+
63
93
  ### Known fixes
64
94
 
65
95
  #### Offscreen rendering with Mujoco
@@ -1,22 +1,22 @@
1
- ring/__init__.py,sha256=2v6WHlNPucj1XGhDYw-3AlMQGTqH-e4KYK0IaMnBV5s,4760
1
+ ring/__init__.py,sha256=k7tL-XgggUwWxHCXyv60rQn-OcXHPg82QcIUkKLEd-c,5057
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=kzBQ54V2xq4KsqRzflyMQ64V-jl8j7eIAsIPIE0gFDk,33127
3
+ ring/base.py,sha256=BGAJE3PSOUnTHte4UesJc1J7MQraIEiVpStkhrgXhaI,33245
4
4
  ring/maths.py,sha256=zGm5XagiKTaIJp310VcqVEVUuLhv3FPS-TJ-TFzIrwM,12207
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
7
  ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
8
- ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
8
+ ring/algorithms/dynamics.py,sha256=GOedL1STj6oXcXgMA7dB4PabvCQxPBbirJQhXBRuKqE,10929
9
9
  ring/algorithms/jcalc.py,sha256=bM8VARgqEiVPy7632geKYGk4MZddZfI8XHdW5kXF3HI,28594
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
- ring/algorithms/sensors.py,sha256=QSIcU_sEB_tRo-ADD_66ZD01LMmJlSG0op6YnM-Gai8,17965
11
+ ring/algorithms/sensors.py,sha256=06x7RfhoQ6dx1B_TAEuCKxNTiicQDDBxcmzRtsCAxsM,18125
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
16
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
17
- ring/algorithms/generator/base.py,sha256=KQSg9uhhR-rC563busVFx4gJrqOx3BXdaChozO9gwTA,14224
17
+ ring/algorithms/generator/base.py,sha256=00789VeBrYcKx5BjsiGu-d4UzO6FGRz-YUKiiLUOL2Q,14497
18
18
  ring/algorithms/generator/batch.py,sha256=ylootnXmj-JyuB_f5OCknHst9wFKO3gkjQbMrFNXY2g,2513
19
- ring/algorithms/generator/finalize_fns.py,sha256=L_5wIVA7g0P4P2U6EmgcvsoI-YuF3TOaHBwk5_oEaUU,9077
19
+ ring/algorithms/generator/finalize_fns.py,sha256=LUw1Wc2YrmMRRh4RF704ob3bZOXktAZAbbLoBm_p1yw,9131
20
20
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
21
21
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
22
22
  ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
@@ -55,7 +55,7 @@ ring/ml/base.py,sha256=-3JQ27zMFESNn5zeNer14GJU2yQgiqDcJUaULOeSyp8,9799
55
55
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
56
  ring/ml/ml_utils.py,sha256=GooyH5uxA6cJM7ZcWDUfSkSKq6dg7kCIbhkbjJs_rLw,6674
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
- ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
58
+ ring/ml/ringnet.py,sha256=Tb2WJ_cc5L3mk1lo0NOfkpXIzJZXf4PJ5aLPtHQyUmY,8650
59
59
  ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
60
60
  ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
61
61
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
@@ -83,7 +83,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
83
83
  ring/utils/utils.py,sha256=k7t-QxMWrNRnjfNB9rSobmLCmhJigE8__gkT-Il0Ee4,6492
84
84
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
85
85
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
86
- imt_ring-1.6.2.dist-info/METADATA,sha256=S_LlVrmdRPQCzT5aeRoSWyOQ3eBJBL1D33tUgXMUEso,3104
87
- imt_ring-1.6.2.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
88
- imt_ring-1.6.2.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
- imt_ring-1.6.2.dist-info/RECORD,,
86
+ imt_ring-1.6.4.dist-info/METADATA,sha256=3uiX-NEHrXZ4QurQPYQUPRlul7aVLBZYIr0w17Yo54E,3922
87
+ imt_ring-1.6.4.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
88
+ imt_ring-1.6.4.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
+ imt_ring-1.6.4.dist-info/RECORD,,
ring/__init__.py CHANGED
@@ -20,52 +20,55 @@ from .base import System
20
20
  from .base import Transform
21
21
 
22
22
 
23
- def RING(lam: list[int] | None, Ts: float | None, **kwargs):
23
+ def RING(lam: list[int] | None, Ts: float | None, **kwargs) -> ml.AbstractFilter:
24
24
  """Creates the RING network.
25
25
 
26
26
  Params:
27
27
  lam: parent array, if `None` must be given via `ringnet.apply(..., lam=lam)`
28
28
  Ts : sampling interval of IMU data; time delta in seconds
29
29
 
30
- Usage:
31
- >>> import ring
32
- >>> import numpy as np
33
- >>>
34
- >>> T : int = 30 # sequence length [s]
35
- >>> Ts : float = 0.01 # sampling interval [s]
36
- >>> B : int = 1 # batch size
37
- >>> lam: list[int] = [0, 1, 2] # parent array
38
- >>> N : int = len(lam) # number of bodies
39
- >>> T_i: int = int(T/Ts) # number of timesteps
40
- >>>
41
- >>> X = np.zeros((B, T_i, N, 9))
42
- >>> # where X is structured as follows:
43
- >>> # X[..., :3] = acc
44
- >>> # X[..., 3:6] = gyr
45
- >>> # X[..., 6:9] = jointaxis
46
- >>>
47
- >>> # let's assume we have an IMU on each outer segment of the
48
- >>> # three-segment kinematic chain
49
- >>> X[:, :, 0, :3] = acc_segment1
50
- >>> X[:, :, 2, :3] = acc_segment3
51
- >>> X[:, :, 0, 3:6] = gyr_segment1
52
- >>> X[:, :, 2, 3:6] = gyr_segment3
53
- >>>
54
- >>> ringnet = ring.RING(lam, Ts)
55
- >>>
56
- >>> yhat, _ = ringnet.apply(X)
57
- >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
58
- >>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
59
- >>>
60
- >>> # use `jax.jit` to compile the forward pass
61
- >>> jit_apply = jax.jit(ringnet.apply)
62
- >>> yhat, _ = jit_apply(X)
63
- >>>
64
- >>> # manually pass in and out the hidden state like so
65
- >>> initial_state = None
66
- >>> yhat, state = ringnet.apply(X, state=initial_state)
67
- >>> # state: final hidden state, shape = (B, N, 2*H)
68
-
30
+ Returns:
31
+ ring.ml.AbstractFilter: An instantiation of `ring.ml.ringnet.RING` with trained
32
+ parameters.
33
+
34
+ Examples:
35
+ >>> import ring
36
+ >>> import numpy as np
37
+ >>>
38
+ >>> T : int = 30 # sequence length [s]
39
+ >>> Ts : float = 0.01 # sampling interval [s]
40
+ >>> B : int = 1 # batch size
41
+ >>> lam: list[int] = [0, 1, 2] # parent array
42
+ >>> N : int = len(lam) # number of bodies
43
+ >>> T_i: int = int(T/Ts) # number of timesteps
44
+ >>>
45
+ >>> X = np.zeros((B, T_i, N, 9))
46
+ >>> # where X is structured as follows:
47
+ >>> # X[..., :3] = acc
48
+ >>> # X[..., 3:6] = gyr
49
+ >>> # X[..., 6:9] = jointaxis
50
+ >>>
51
+ >>> # let's assume we have an IMU on each outer segment of the
52
+ >>> # three-segment kinematic chain
53
+ >>> X[:, :, 0, :3] = acc_segment1
54
+ >>> X[:, :, 2, :3] = acc_segment3
55
+ >>> X[:, :, 0, 3:6] = gyr_segment1
56
+ >>> X[:, :, 2, 3:6] = gyr_segment3
57
+ >>>
58
+ >>> ringnet = ring.RING(lam, Ts)
59
+ >>>
60
+ >>> yhat, _ = ringnet.apply(X)
61
+ >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
62
+ >>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
63
+ >>>
64
+ >>> # use `jax.jit` to compile the forward pass
65
+ >>> jit_apply = jax.jit(ringnet.apply)
66
+ >>> yhat, _ = jit_apply(X)
67
+ >>>
68
+ >>> # manually pass in and out the hidden state like so
69
+ >>> initial_state = None
70
+ >>> yhat, state = ringnet.apply(X, state=initial_state)
71
+ >>> # state: final hidden state, shape = (B, N, 2*H)
69
72
  """
70
73
  from pathlib import Path
71
74
  import warnings
@@ -303,6 +303,7 @@ def step(
303
303
  taus: Optional[jax.Array] = None,
304
304
  n_substeps: int = 1,
305
305
  ) -> base.State:
306
+ "Steps the dynamics. Returns the state of next timestep."
306
307
  assert sys.q_size() == state.q.size
307
308
  if taus is None:
308
309
  taus = jnp.zeros_like(state.qd)
@@ -4,6 +4,7 @@ import warnings
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+ import numpy as np
7
8
  import tree_utils
8
9
 
9
10
  from ring import base
@@ -30,6 +31,7 @@ class RCMG:
30
31
  add_X_jointaxes_kwargs: dict = dict(),
31
32
  add_y_relpose: bool = False,
32
33
  add_y_rootincl: bool = False,
34
+ add_y_rootincl_kwargs: dict = dict(),
33
35
  sys_ml: Optional[base.System] = None,
34
36
  randomize_positions: bool = False,
35
37
  randomize_motion_artifacts: bool = False,
@@ -46,6 +48,7 @@ class RCMG:
46
48
  cor: bool = False,
47
49
  disable_tqdm: bool = False,
48
50
  ) -> None:
51
+ "Random Chain Motion Generator"
49
52
 
50
53
  sys, config = utils.to_list(sys), utils.to_list(config)
51
54
  sys_ml = sys[0] if sys_ml is None else sys_ml
@@ -67,6 +70,7 @@ class RCMG:
67
70
  add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
68
71
  add_y_relpose=add_y_relpose,
69
72
  add_y_rootincl=add_y_rootincl,
73
+ add_y_rootincl_kwargs=add_y_rootincl_kwargs,
70
74
  sys_ml=sys_ml,
71
75
  randomize_positions=randomize_positions,
72
76
  randomize_motion_artifacts=randomize_motion_artifacts,
@@ -139,7 +143,9 @@ class RCMG:
139
143
 
140
144
  return n_calls
141
145
 
142
- def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
146
+ def to_list(
147
+ self, sizes: int | list[int] = 1, seed: int = 1
148
+ ) -> list[tree_utils.PyTree[np.ndarray]]:
143
149
  "Returns list of unbatched sequences as numpy arrays."
144
150
  repeats = self._compute_repeats(sizes)
145
151
  sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
@@ -168,7 +174,7 @@ class RCMG:
168
174
  seed: int = 1,
169
175
  overwrite: bool = True,
170
176
  ) -> None:
171
- data = tree_utils.tree_batch(self.to_list(sizes, seed))
177
+ data = tree_utils.tree_batch(self.to_list(sizes, seed), backend="numpy")
172
178
  utils.pickle_save(data, path, overwrite=overwrite)
173
179
 
174
180
  def to_eager_gen(
@@ -232,6 +238,7 @@ def _build_mconfig_batched_generator(
232
238
  add_X_jointaxes_kwargs: dict,
233
239
  add_y_relpose: bool,
234
240
  add_y_rootincl: bool,
241
+ add_y_rootincl_kwargs: dict,
235
242
  sys_ml: base.System,
236
243
  randomize_positions: bool,
237
244
  randomize_motion_artifacts: bool,
@@ -77,12 +77,13 @@ class RelPose:
77
77
 
78
78
 
79
79
  class RootIncl:
80
- def __init__(self, sys: base.System):
80
+ def __init__(self, sys: base.System, **kwargs):
81
81
  self.sys = sys
82
+ self.kwargs = kwargs
82
83
 
83
84
  def __call__(self, Xy, extras):
84
85
  (X, y), (key, q, x, sys_x) = Xy, extras
85
- y_root_incl = sensors.root_incl(self.sys, x, sys_x)
86
+ y_root_incl = sensors.root_incl(self.sys, x, sys_x, **self.kwargs)
86
87
  y = utils.dict_union(y, y_root_incl)
87
88
  return (X, y), (key, q, x, sys_x)
88
89
 
@@ -330,7 +330,10 @@ def rel_pose(
330
330
 
331
331
 
332
332
  def root_incl(
333
- sys: base.System, x: base.Transform, sys_x: base.System
333
+ sys: base.System,
334
+ x: base.Transform,
335
+ sys_x: base.System,
336
+ child_to_parent: bool = False,
334
337
  ) -> dict[str, jax.Array]:
335
338
  # (time, nlinks, 4) -> (nlinks, time, 4)
336
339
  rots = x.rot.transpose((1, 0, 2))
@@ -341,8 +344,10 @@ def root_incl(
341
344
  def f(_, __, name: str, parent: int):
342
345
  if parent != -1:
343
346
  return
344
- q_eps_to_i = maths.quat_project(rots[l_map[name]], jnp.array([0.0, 0, 1]))[1]
345
- y[name] = maths.quat_inv(q_eps_to_i)
347
+ q_i = maths.quat_project(rots[l_map[name]], jnp.array([0.0, 0, 1]))[1]
348
+ if child_to_parent:
349
+ q_i = maths.quat_inv(q_i)
350
+ y[name] = q_i
346
351
 
347
352
  sys.scan(f, "ll", sys.link_names, sys.link_parents)
348
353
 
@@ -350,7 +355,10 @@ def root_incl(
350
355
 
351
356
 
352
357
  def root_full(
353
- sys: base.System, x: base.Transform, sys_x: base.System
358
+ sys: base.System,
359
+ x: base.Transform,
360
+ sys_x: base.System,
361
+ child_to_parent: bool = False,
354
362
  ) -> dict[str, jax.Array]:
355
363
  # (time, nlinks, 4) -> (nlinks, time, 4)
356
364
  rots = x.rot.transpose((1, 0, 2))
@@ -361,8 +369,10 @@ def root_full(
361
369
  def f(_, __, name: str, parent: int):
362
370
  if parent != -1:
363
371
  return
364
- q_eps_to_i = rots[l_map[name]]
365
- y[name] = maths.quat_inv(q_eps_to_i)
372
+ q_i = rots[l_map[name]]
373
+ if child_to_parent:
374
+ q_i = maths.quat_inv(q_i)
375
+ y[name] = q_i
366
376
 
367
377
  sys.scan(f, "ll", sys.link_names, sys.link_parents)
368
378
 
ring/base.py CHANGED
@@ -113,7 +113,9 @@ class _Base:
113
113
  class Transform(_Base):
114
114
  """Represents the Transformation from Plücker A to Plücker B,
115
115
  where B is located relative to A at `pos` in frame A and `rot` is the
116
- relative quaternion from A to B."""
116
+ relative quaternion from A to B.
117
+ Create using `Transform.create(pos=..., rot=...)`.
118
+ """
117
119
 
118
120
  pos: Vector
119
121
  rot: Quaternion
@@ -399,6 +401,7 @@ QD_WIDTHS = {
399
401
 
400
402
  @struct.dataclass
401
403
  class System(_Base):
404
+ "System object. Create using `System.create(path_xml)`"
402
405
  link_parents: list[int] = struct.field(False)
403
406
  links: Link
404
407
  link_types: list[str] = struct.field(False)
ring/ml/ringnet.py CHANGED
@@ -200,6 +200,7 @@ class RING(ml_base.AbstractFilter):
200
200
  forward_factory=make_ring,
201
201
  **kwargs,
202
202
  ):
203
+ "Untrained RING network"
203
204
  self.forward_lam_factory = partial(forward_factory, **kwargs)
204
205
  self.params = self._load_params(params)
205
206
  self.lam = lam