imt-ring 1.6.2__py3-none-any.whl → 1.6.4__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.2
3
+ Version: 1.6.4
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -60,6 +60,36 @@ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-re
60
60
 
61
61
  Available [here](https://simipixel.github.io/ring/).
62
62
 
63
+ ## Quickstart Example
64
+ ```python
65
+ import ring
66
+ import numpy as np
67
+
68
+ T : int = 30 # sequence length [s]
69
+ Ts : float = 0.01 # sampling interval [s]
70
+ B : int = 1 # batch size
71
+ lam: list[int] = [0, 1, 2] # parent array
72
+ N : int = len(lam) # number of bodies
73
+ T_i: int = int(T/Ts) # number of timesteps
74
+
75
+ X = np.zeros((B, T_i, N, 9))
76
+ # where X is structured as follows:
77
+ # X[..., :3] = acc
78
+ # X[..., 3:6] = gyr
79
+ # X[..., 6:9] = jointaxis
80
+
81
+ # let's assume we have an IMU on each outer segment of the
82
+ # three-segment kinematic chain
83
+ X[..., 0, :3] = acc_segment1
84
+ X[..., 2, :3] = acc_segment3
85
+ X[..., 0, 3:6] = gyr_segment1
86
+ X[..., 2, 3:6] = gyr_segment3
87
+
88
+ ringnet = ring.RING(lam, Ts)
89
+ yhat, _ = ringnet.apply(X)
90
+ # yhat: unit quaternions, shape = (B, T_i, N, 4)
91
+ ```
92
+
63
93
  ### Known fixes
64
94
 
65
95
  #### Offscreen rendering with Mujoco
@@ -1,22 +1,22 @@
1
- ring/__init__.py,sha256=2v6WHlNPucj1XGhDYw-3AlMQGTqH-e4KYK0IaMnBV5s,4760
1
+ ring/__init__.py,sha256=k7tL-XgggUwWxHCXyv60rQn-OcXHPg82QcIUkKLEd-c,5057
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=kzBQ54V2xq4KsqRzflyMQ64V-jl8j7eIAsIPIE0gFDk,33127
3
+ ring/base.py,sha256=BGAJE3PSOUnTHte4UesJc1J7MQraIEiVpStkhrgXhaI,33245
4
4
  ring/maths.py,sha256=zGm5XagiKTaIJp310VcqVEVUuLhv3FPS-TJ-TFzIrwM,12207
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
7
  ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
8
- ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
8
+ ring/algorithms/dynamics.py,sha256=GOedL1STj6oXcXgMA7dB4PabvCQxPBbirJQhXBRuKqE,10929
9
9
  ring/algorithms/jcalc.py,sha256=bM8VARgqEiVPy7632geKYGk4MZddZfI8XHdW5kXF3HI,28594
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
- ring/algorithms/sensors.py,sha256=QSIcU_sEB_tRo-ADD_66ZD01LMmJlSG0op6YnM-Gai8,17965
11
+ ring/algorithms/sensors.py,sha256=06x7RfhoQ6dx1B_TAEuCKxNTiicQDDBxcmzRtsCAxsM,18125
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
16
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
17
- ring/algorithms/generator/base.py,sha256=KQSg9uhhR-rC563busVFx4gJrqOx3BXdaChozO9gwTA,14224
17
+ ring/algorithms/generator/base.py,sha256=00789VeBrYcKx5BjsiGu-d4UzO6FGRz-YUKiiLUOL2Q,14497
18
18
  ring/algorithms/generator/batch.py,sha256=ylootnXmj-JyuB_f5OCknHst9wFKO3gkjQbMrFNXY2g,2513
19
- ring/algorithms/generator/finalize_fns.py,sha256=L_5wIVA7g0P4P2U6EmgcvsoI-YuF3TOaHBwk5_oEaUU,9077
19
+ ring/algorithms/generator/finalize_fns.py,sha256=LUw1Wc2YrmMRRh4RF704ob3bZOXktAZAbbLoBm_p1yw,9131
20
20
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
21
21
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
22
22
  ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
@@ -55,7 +55,7 @@ ring/ml/base.py,sha256=-3JQ27zMFESNn5zeNer14GJU2yQgiqDcJUaULOeSyp8,9799
55
55
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
56
  ring/ml/ml_utils.py,sha256=GooyH5uxA6cJM7ZcWDUfSkSKq6dg7kCIbhkbjJs_rLw,6674
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
- ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
58
+ ring/ml/ringnet.py,sha256=Tb2WJ_cc5L3mk1lo0NOfkpXIzJZXf4PJ5aLPtHQyUmY,8650
59
59
  ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
60
60
  ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
61
61
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
@@ -83,7 +83,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
83
83
  ring/utils/utils.py,sha256=k7t-QxMWrNRnjfNB9rSobmLCmhJigE8__gkT-Il0Ee4,6492
84
84
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
85
85
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
86
- imt_ring-1.6.2.dist-info/METADATA,sha256=S_LlVrmdRPQCzT5aeRoSWyOQ3eBJBL1D33tUgXMUEso,3104
87
- imt_ring-1.6.2.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
88
- imt_ring-1.6.2.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
- imt_ring-1.6.2.dist-info/RECORD,,
86
+ imt_ring-1.6.4.dist-info/METADATA,sha256=3uiX-NEHrXZ4QurQPYQUPRlul7aVLBZYIr0w17Yo54E,3922
87
+ imt_ring-1.6.4.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
88
+ imt_ring-1.6.4.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
+ imt_ring-1.6.4.dist-info/RECORD,,
ring/__init__.py CHANGED
@@ -20,52 +20,55 @@ from .base import System
20
20
  from .base import Transform
21
21
 
22
22
 
23
- def RING(lam: list[int] | None, Ts: float | None, **kwargs):
23
+ def RING(lam: list[int] | None, Ts: float | None, **kwargs) -> ml.AbstractFilter:
24
24
  """Creates the RING network.
25
25
 
26
26
  Params:
27
27
  lam: parent array, if `None` must be given via `ringnet.apply(..., lam=lam)`
28
28
  Ts : sampling interval of IMU data; time delta in seconds
29
29
 
30
- Usage:
31
- >>> import ring
32
- >>> import numpy as np
33
- >>>
34
- >>> T : int = 30 # sequence length [s]
35
- >>> Ts : float = 0.01 # sampling interval [s]
36
- >>> B : int = 1 # batch size
37
- >>> lam: list[int] = [0, 1, 2] # parent array
38
- >>> N : int = len(lam) # number of bodies
39
- >>> T_i: int = int(T/Ts) # number of timesteps
40
- >>>
41
- >>> X = np.zeros((B, T_i, N, 9))
42
- >>> # where X is structured as follows:
43
- >>> # X[..., :3] = acc
44
- >>> # X[..., 3:6] = gyr
45
- >>> # X[..., 6:9] = jointaxis
46
- >>>
47
- >>> # let's assume we have an IMU on each outer segment of the
48
- >>> # three-segment kinematic chain
49
- >>> X[:, :, 0, :3] = acc_segment1
50
- >>> X[:, :, 2, :3] = acc_segment3
51
- >>> X[:, :, 0, 3:6] = gyr_segment1
52
- >>> X[:, :, 2, 3:6] = gyr_segment3
53
- >>>
54
- >>> ringnet = ring.RING(lam, Ts)
55
- >>>
56
- >>> yhat, _ = ringnet.apply(X)
57
- >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
58
- >>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
59
- >>>
60
- >>> # use `jax.jit` to compile the forward pass
61
- >>> jit_apply = jax.jit(ringnet.apply)
62
- >>> yhat, _ = jit_apply(X)
63
- >>>
64
- >>> # manually pass in and out the hidden state like so
65
- >>> initial_state = None
66
- >>> yhat, state = ringnet.apply(X, state=initial_state)
67
- >>> # state: final hidden state, shape = (B, N, 2*H)
68
-
30
+ Returns:
31
+ ring.ml.AbstractFilter: An instantiation of `ring.ml.ringnet.RING` with trained
32
+ parameters.
33
+
34
+ Examples:
35
+ >>> import ring
36
+ >>> import numpy as np
37
+ >>>
38
+ >>> T : int = 30 # sequence length [s]
39
+ >>> Ts : float = 0.01 # sampling interval [s]
40
+ >>> B : int = 1 # batch size
41
+ >>> lam: list[int] = [0, 1, 2] # parent array
42
+ >>> N : int = len(lam) # number of bodies
43
+ >>> T_i: int = int(T/Ts) # number of timesteps
44
+ >>>
45
+ >>> X = np.zeros((B, T_i, N, 9))
46
+ >>> # where X is structured as follows:
47
+ >>> # X[..., :3] = acc
48
+ >>> # X[..., 3:6] = gyr
49
+ >>> # X[..., 6:9] = jointaxis
50
+ >>>
51
+ >>> # let's assume we have an IMU on each outer segment of the
52
+ >>> # three-segment kinematic chain
53
+ >>> X[:, :, 0, :3] = acc_segment1
54
+ >>> X[:, :, 2, :3] = acc_segment3
55
+ >>> X[:, :, 0, 3:6] = gyr_segment1
56
+ >>> X[:, :, 2, 3:6] = gyr_segment3
57
+ >>>
58
+ >>> ringnet = ring.RING(lam, Ts)
59
+ >>>
60
+ >>> yhat, _ = ringnet.apply(X)
61
+ >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
62
+ >>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
63
+ >>>
64
+ >>> # use `jax.jit` to compile the forward pass
65
+ >>> jit_apply = jax.jit(ringnet.apply)
66
+ >>> yhat, _ = jit_apply(X)
67
+ >>>
68
+ >>> # manually pass in and out the hidden state like so
69
+ >>> initial_state = None
70
+ >>> yhat, state = ringnet.apply(X, state=initial_state)
71
+ >>> # state: final hidden state, shape = (B, N, 2*H)
69
72
  """
70
73
  from pathlib import Path
71
74
  import warnings
@@ -303,6 +303,7 @@ def step(
303
303
  taus: Optional[jax.Array] = None,
304
304
  n_substeps: int = 1,
305
305
  ) -> base.State:
306
+ "Steps the dynamics. Returns the state at the next timestep."
306
307
  assert sys.q_size() == state.q.size
307
308
  if taus is None:
308
309
  taus = jnp.zeros_like(state.qd)
@@ -4,6 +4,7 @@ import warnings
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+ import numpy as np
7
8
  import tree_utils
8
9
 
9
10
  from ring import base
@@ -30,6 +31,7 @@ class RCMG:
30
31
  add_X_jointaxes_kwargs: dict = dict(),
31
32
  add_y_relpose: bool = False,
32
33
  add_y_rootincl: bool = False,
34
+ add_y_rootincl_kwargs: dict = dict(),
33
35
  sys_ml: Optional[base.System] = None,
34
36
  randomize_positions: bool = False,
35
37
  randomize_motion_artifacts: bool = False,
@@ -46,6 +48,7 @@ class RCMG:
46
48
  cor: bool = False,
47
49
  disable_tqdm: bool = False,
48
50
  ) -> None:
51
+ "Random Chain Motion Generator"
49
52
 
50
53
  sys, config = utils.to_list(sys), utils.to_list(config)
51
54
  sys_ml = sys[0] if sys_ml is None else sys_ml
@@ -67,6 +70,7 @@ class RCMG:
67
70
  add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
68
71
  add_y_relpose=add_y_relpose,
69
72
  add_y_rootincl=add_y_rootincl,
73
+ add_y_rootincl_kwargs=add_y_rootincl_kwargs,
70
74
  sys_ml=sys_ml,
71
75
  randomize_positions=randomize_positions,
72
76
  randomize_motion_artifacts=randomize_motion_artifacts,
@@ -139,7 +143,9 @@ class RCMG:
139
143
 
140
144
  return n_calls
141
145
 
142
- def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
146
+ def to_list(
147
+ self, sizes: int | list[int] = 1, seed: int = 1
148
+ ) -> list[tree_utils.PyTree[np.ndarray]]:
143
149
  "Returns list of unbatched sequences as numpy arrays."
144
150
  repeats = self._compute_repeats(sizes)
145
151
  sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
@@ -168,7 +174,7 @@ class RCMG:
168
174
  seed: int = 1,
169
175
  overwrite: bool = True,
170
176
  ) -> None:
171
- data = tree_utils.tree_batch(self.to_list(sizes, seed))
177
+ data = tree_utils.tree_batch(self.to_list(sizes, seed), backend="numpy")
172
178
  utils.pickle_save(data, path, overwrite=overwrite)
173
179
 
174
180
  def to_eager_gen(
@@ -232,6 +238,7 @@ def _build_mconfig_batched_generator(
232
238
  add_X_jointaxes_kwargs: dict,
233
239
  add_y_relpose: bool,
234
240
  add_y_rootincl: bool,
241
+ add_y_rootincl_kwargs: dict,
235
242
  sys_ml: base.System,
236
243
  randomize_positions: bool,
237
244
  randomize_motion_artifacts: bool,
@@ -77,12 +77,13 @@ class RelPose:
77
77
 
78
78
 
79
79
  class RootIncl:
80
- def __init__(self, sys: base.System):
80
+ def __init__(self, sys: base.System, **kwargs):
81
81
  self.sys = sys
82
+ self.kwargs = kwargs
82
83
 
83
84
  def __call__(self, Xy, extras):
84
85
  (X, y), (key, q, x, sys_x) = Xy, extras
85
- y_root_incl = sensors.root_incl(self.sys, x, sys_x)
86
+ y_root_incl = sensors.root_incl(self.sys, x, sys_x, **self.kwargs)
86
87
  y = utils.dict_union(y, y_root_incl)
87
88
  return (X, y), (key, q, x, sys_x)
88
89
 
@@ -330,7 +330,10 @@ def rel_pose(
330
330
 
331
331
 
332
332
  def root_incl(
333
- sys: base.System, x: base.Transform, sys_x: base.System
333
+ sys: base.System,
334
+ x: base.Transform,
335
+ sys_x: base.System,
336
+ child_to_parent: bool = False,
334
337
  ) -> dict[str, jax.Array]:
335
338
  # (time, nlinks, 4) -> (nlinks, time, 4)
336
339
  rots = x.rot.transpose((1, 0, 2))
@@ -341,8 +344,10 @@ def root_incl(
341
344
  def f(_, __, name: str, parent: int):
342
345
  if parent != -1:
343
346
  return
344
- q_eps_to_i = maths.quat_project(rots[l_map[name]], jnp.array([0.0, 0, 1]))[1]
345
- y[name] = maths.quat_inv(q_eps_to_i)
347
+ q_i = maths.quat_project(rots[l_map[name]], jnp.array([0.0, 0, 1]))[1]
348
+ if child_to_parent:
349
+ q_i = maths.quat_inv(q_i)
350
+ y[name] = q_i
346
351
 
347
352
  sys.scan(f, "ll", sys.link_names, sys.link_parents)
348
353
 
@@ -350,7 +355,10 @@ def root_incl(
350
355
 
351
356
 
352
357
  def root_full(
353
- sys: base.System, x: base.Transform, sys_x: base.System
358
+ sys: base.System,
359
+ x: base.Transform,
360
+ sys_x: base.System,
361
+ child_to_parent: bool = False,
354
362
  ) -> dict[str, jax.Array]:
355
363
  # (time, nlinks, 4) -> (nlinks, time, 4)
356
364
  rots = x.rot.transpose((1, 0, 2))
@@ -361,8 +369,10 @@ def root_full(
361
369
  def f(_, __, name: str, parent: int):
362
370
  if parent != -1:
363
371
  return
364
- q_eps_to_i = rots[l_map[name]]
365
- y[name] = maths.quat_inv(q_eps_to_i)
372
+ q_i = rots[l_map[name]]
373
+ if child_to_parent:
374
+ q_i = maths.quat_inv(q_i)
375
+ y[name] = q_i
366
376
 
367
377
  sys.scan(f, "ll", sys.link_names, sys.link_parents)
368
378
 
ring/base.py CHANGED
@@ -113,7 +113,9 @@ class _Base:
113
113
  class Transform(_Base):
114
114
  """Represents the Transformation from Plücker A to Plücker B,
115
115
  where B is located relative to A at `pos` in frame A and `rot` is the
116
- relative quaternion from A to B."""
116
+ relative quaternion from A to B.
117
+ Create using `Transform.create(pos=..., rot=...)`.
118
+ """
117
119
 
118
120
  pos: Vector
119
121
  rot: Quaternion
@@ -399,6 +401,7 @@ QD_WIDTHS = {
399
401
 
400
402
  @struct.dataclass
401
403
  class System(_Base):
404
+ "System object. Create using `System.create(path_xml)`"
402
405
  link_parents: list[int] = struct.field(False)
403
406
  links: Link
404
407
  link_types: list[str] = struct.field(False)
ring/ml/ringnet.py CHANGED
@@ -200,6 +200,7 @@ class RING(ml_base.AbstractFilter):
200
200
  forward_factory=make_ring,
201
201
  **kwargs,
202
202
  ):
203
+ "Untrained RING network"
203
204
  self.forward_lam_factory = partial(forward_factory, **kwargs)
204
205
  self.params = self._load_params(params)
205
206
  self.lam = lam