imt-ring 1.6.3__py3-none-any.whl → 1.6.4__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.3
3
+ Version: 1.6.4
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -60,6 +60,36 @@ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-re
60
60
 
61
61
  Available [here](https://simipixel.github.io/ring/).
62
62
 
63
+ ## Quickstart Example
64
+ ```python
65
+ import ring
66
+ import numpy as np
67
+
68
+ T : int = 30 # sequence length [s]
69
+ Ts : float = 0.01 # sampling interval [s]
70
+ B : int = 1 # batch size
71
+ lam: list[int] = [0, 1, 2] # parent array
72
+ N : int = len(lam) # number of bodies
73
+ T_i: int = int(T/Ts) # number of timesteps
74
+
75
+ X = np.zeros((B, T_i, N, 9))
76
+ # where X is structured as follows:
77
+ # X[..., :3] = acc
78
+ # X[..., 3:6] = gyr
79
+ # X[..., 6:9] = jointaxis
80
+
81
+ # let's assume we have an IMU on each outer segment of the
82
+ # three-segment kinematic chain
83
+ X[..., 0, :3] = acc_segment1
84
+ X[..., 2, :3] = acc_segment3
85
+ X[..., 0, 3:6] = gyr_segment1
86
+ X[..., 2, 3:6] = gyr_segment3
87
+
88
+ ringnet = ring.RING(lam, Ts)
89
+ yhat, _ = ringnet.apply(X)
90
+ # yhat: unit quaternions, shape = (B, T_i, N, 4)
91
+ ```
92
+
63
93
  ### Known fixes
64
94
 
65
95
  #### Offscreen rendering with Mujoco
@@ -1,11 +1,11 @@
1
- ring/__init__.py,sha256=2v6WHlNPucj1XGhDYw-3AlMQGTqH-e4KYK0IaMnBV5s,4760
1
+ ring/__init__.py,sha256=k7tL-XgggUwWxHCXyv60rQn-OcXHPg82QcIUkKLEd-c,5057
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=kzBQ54V2xq4KsqRzflyMQ64V-jl8j7eIAsIPIE0gFDk,33127
3
+ ring/base.py,sha256=BGAJE3PSOUnTHte4UesJc1J7MQraIEiVpStkhrgXhaI,33245
4
4
  ring/maths.py,sha256=zGm5XagiKTaIJp310VcqVEVUuLhv3FPS-TJ-TFzIrwM,12207
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
7
  ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
8
- ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
8
+ ring/algorithms/dynamics.py,sha256=GOedL1STj6oXcXgMA7dB4PabvCQxPBbirJQhXBRuKqE,10929
9
9
  ring/algorithms/jcalc.py,sha256=bM8VARgqEiVPy7632geKYGk4MZddZfI8XHdW5kXF3HI,28594
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=06x7RfhoQ6dx1B_TAEuCKxNTiicQDDBxcmzRtsCAxsM,18125
@@ -14,7 +14,7 @@ ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwa
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
16
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
17
- ring/algorithms/generator/base.py,sha256=ZW6tJSUiNKtRCYIsbCf6e6kRBbmuCLzbZj4ppNLVwJY,14368
17
+ ring/algorithms/generator/base.py,sha256=00789VeBrYcKx5BjsiGu-d4UzO6FGRz-YUKiiLUOL2Q,14497
18
18
  ring/algorithms/generator/batch.py,sha256=ylootnXmj-JyuB_f5OCknHst9wFKO3gkjQbMrFNXY2g,2513
19
19
  ring/algorithms/generator/finalize_fns.py,sha256=LUw1Wc2YrmMRRh4RF704ob3bZOXktAZAbbLoBm_p1yw,9131
20
20
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
@@ -55,7 +55,7 @@ ring/ml/base.py,sha256=-3JQ27zMFESNn5zeNer14GJU2yQgiqDcJUaULOeSyp8,9799
55
55
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
56
  ring/ml/ml_utils.py,sha256=GooyH5uxA6cJM7ZcWDUfSkSKq6dg7kCIbhkbjJs_rLw,6674
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
- ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
58
+ ring/ml/ringnet.py,sha256=Tb2WJ_cc5L3mk1lo0NOfkpXIzJZXf4PJ5aLPtHQyUmY,8650
59
59
  ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
60
60
  ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
61
61
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
@@ -83,7 +83,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
83
83
  ring/utils/utils.py,sha256=k7t-QxMWrNRnjfNB9rSobmLCmhJigE8__gkT-Il0Ee4,6492
84
84
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
85
85
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
86
- imt_ring-1.6.3.dist-info/METADATA,sha256=uFTNWR0YQbQLE50Oby-D2H3NSpdycgCwzKaz_UxxSP8,3104
87
- imt_ring-1.6.3.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
88
- imt_ring-1.6.3.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
- imt_ring-1.6.3.dist-info/RECORD,,
86
+ imt_ring-1.6.4.dist-info/METADATA,sha256=3uiX-NEHrXZ4QurQPYQUPRlul7aVLBZYIr0w17Yo54E,3922
87
+ imt_ring-1.6.4.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
88
+ imt_ring-1.6.4.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
+ imt_ring-1.6.4.dist-info/RECORD,,
ring/__init__.py CHANGED
@@ -20,52 +20,55 @@ from .base import System
20
20
  from .base import Transform
21
21
 
22
22
 
23
- def RING(lam: list[int] | None, Ts: float | None, **kwargs):
23
+ def RING(lam: list[int] | None, Ts: float | None, **kwargs) -> ml.AbstractFilter:
24
24
  """Creates the RING network.
25
25
 
26
26
  Params:
27
27
  lam: parent array, if `None` must be given via `ringnet.apply(..., lam=lam)`
28
28
  Ts : sampling interval of IMU data; time delta in seconds
29
29
 
30
- Usage:
31
- >>> import ring
32
- >>> import numpy as np
33
- >>>
34
- >>> T : int = 30 # sequence length [s]
35
- >>> Ts : float = 0.01 # sampling interval [s]
36
- >>> B : int = 1 # batch size
37
- >>> lam: list[int] = [0, 1, 2] # parent array
38
- >>> N : int = len(lam) # number of bodies
39
- >>> T_i: int = int(T/Ts) # number of timesteps
40
- >>>
41
- >>> X = np.zeros((B, T_i, N, 9))
42
- >>> # where X is structured as follows:
43
- >>> # X[..., :3] = acc
44
- >>> # X[..., 3:6] = gyr
45
- >>> # X[..., 6:9] = jointaxis
46
- >>>
47
- >>> # let's assume we have an IMU on each outer segment of the
48
- >>> # three-segment kinematic chain
49
- >>> X[:, :, 0, :3] = acc_segment1
50
- >>> X[:, :, 2, :3] = acc_segment3
51
- >>> X[:, :, 0, 3:6] = gyr_segment1
52
- >>> X[:, :, 2, 3:6] = gyr_segment3
53
- >>>
54
- >>> ringnet = ring.RING(lam, Ts)
55
- >>>
56
- >>> yhat, _ = ringnet.apply(X)
57
- >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
58
- >>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
59
- >>>
60
- >>> # use `jax.jit` to compile the forward pass
61
- >>> jit_apply = jax.jit(ringnet.apply)
62
- >>> yhat, _ = jit_apply(X)
63
- >>>
64
- >>> # manually pass in and out the hidden state like so
65
- >>> initial_state = None
66
- >>> yhat, state = ringnet.apply(X, state=initial_state)
67
- >>> # state: final hidden state, shape = (B, N, 2*H)
68
-
30
+ Returns:
31
+ ring.ml.AbstractFilter: An instantiation of `ring.ml.ringnet.RING` with trained
32
+ parameters.
33
+
34
+ Examples:
35
+ >>> import ring
36
+ >>> import numpy as np
37
+ >>>
38
+ >>> T : int = 30 # sequence length [s]
39
+ >>> Ts : float = 0.01 # sampling interval [s]
40
+ >>> B : int = 1 # batch size
41
+ >>> lam: list[int] = [0, 1, 2] # parent array
42
+ >>> N : int = len(lam) # number of bodies
43
+ >>> T_i: int = int(T/Ts) # number of timesteps
44
+ >>>
45
+ >>> X = np.zeros((B, T_i, N, 9))
46
+ >>> # where X is structured as follows:
47
+ >>> # X[..., :3] = acc
48
+ >>> # X[..., 3:6] = gyr
49
+ >>> # X[..., 6:9] = jointaxis
50
+ >>>
51
+ >>> # let's assume we have an IMU on each outer segment of the
52
+ >>> # three-segment kinematic chain
53
+ >>> X[:, :, 0, :3] = acc_segment1
54
+ >>> X[:, :, 2, :3] = acc_segment3
55
+ >>> X[:, :, 0, 3:6] = gyr_segment1
56
+ >>> X[:, :, 2, 3:6] = gyr_segment3
57
+ >>>
58
+ >>> ringnet = ring.RING(lam, Ts)
59
+ >>>
60
+ >>> yhat, _ = ringnet.apply(X)
61
+ >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
62
+ >>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
63
+ >>>
64
+ >>> # use `jax.jit` to compile the forward pass
65
+ >>> jit_apply = jax.jit(ringnet.apply)
66
+ >>> yhat, _ = jit_apply(X)
67
+ >>>
68
+ >>> # manually pass in and out the hidden state like so
69
+ >>> initial_state = None
70
+ >>> yhat, state = ringnet.apply(X, state=initial_state)
71
+ >>> # state: final hidden state, shape = (B, N, 2*H)
69
72
  """
70
73
  from pathlib import Path
71
74
  import warnings
@@ -303,6 +303,7 @@ def step(
303
303
  taus: Optional[jax.Array] = None,
304
304
  n_substeps: int = 1,
305
305
  ) -> base.State:
306
+ "Steps the dynamics. Returns the state at the next timestep."
306
307
  assert sys.q_size() == state.q.size
307
308
  if taus is None:
308
309
  taus = jnp.zeros_like(state.qd)
@@ -4,6 +4,7 @@ import warnings
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+ import numpy as np
7
8
  import tree_utils
8
9
 
9
10
  from ring import base
@@ -47,6 +48,7 @@ class RCMG:
47
48
  cor: bool = False,
48
49
  disable_tqdm: bool = False,
49
50
  ) -> None:
51
+ "Random Chain Motion Generator"
50
52
 
51
53
  sys, config = utils.to_list(sys), utils.to_list(config)
52
54
  sys_ml = sys[0] if sys_ml is None else sys_ml
@@ -141,7 +143,9 @@ class RCMG:
141
143
 
142
144
  return n_calls
143
145
 
144
- def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
146
+ def to_list(
147
+ self, sizes: int | list[int] = 1, seed: int = 1
148
+ ) -> list[tree_utils.PyTree[np.ndarray]]:
145
149
  "Returns list of unbatched sequences as numpy arrays."
146
150
  repeats = self._compute_repeats(sizes)
147
151
  sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
@@ -170,7 +174,7 @@ class RCMG:
170
174
  seed: int = 1,
171
175
  overwrite: bool = True,
172
176
  ) -> None:
173
- data = tree_utils.tree_batch(self.to_list(sizes, seed))
177
+ data = tree_utils.tree_batch(self.to_list(sizes, seed), backend="numpy")
174
178
  utils.pickle_save(data, path, overwrite=overwrite)
175
179
 
176
180
  def to_eager_gen(
ring/base.py CHANGED
@@ -113,7 +113,9 @@ class _Base:
113
113
  class Transform(_Base):
114
114
  """Represents the Transformation from Plücker A to Plücker B,
115
115
  where B is located relative to A at `pos` in frame A and `rot` is the
116
- relative quaternion from A to B."""
116
+ relative quaternion from A to B.
117
+ Create using `Transform.create(pos=..., rot=...)`.
118
+ """
117
119
 
118
120
  pos: Vector
119
121
  rot: Quaternion
@@ -399,6 +401,7 @@ QD_WIDTHS = {
399
401
 
400
402
  @struct.dataclass
401
403
  class System(_Base):
404
+ "System object. Create using `System.create(path_xml)`"
402
405
  link_parents: list[int] = struct.field(False)
403
406
  links: Link
404
407
  link_types: list[str] = struct.field(False)
ring/ml/ringnet.py CHANGED
@@ -200,6 +200,7 @@ class RING(ml_base.AbstractFilter):
200
200
  forward_factory=make_ring,
201
201
  **kwargs,
202
202
  ):
203
+ "Untrained RING network"
203
204
  self.forward_lam_factory = partial(forward_factory, **kwargs)
204
205
  self.params = self._load_params(params)
205
206
  self.lam = lam