imt-ring 1.6.2__tar.gz → 1.6.4__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (116) hide show
  1. {imt_ring-1.6.2 → imt_ring-1.6.4}/PKG-INFO +31 -1
  2. {imt_ring-1.6.2 → imt_ring-1.6.4}/pyproject.toml +1 -1
  3. {imt_ring-1.6.2 → imt_ring-1.6.4}/readme.md +30 -0
  4. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/imt_ring.egg-info/PKG-INFO +31 -1
  5. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/__init__.py +43 -40
  6. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/dynamics.py +1 -0
  7. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/generator/base.py +9 -2
  8. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/generator/finalize_fns.py +3 -2
  9. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/sensors.py +16 -6
  10. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/base.py +4 -1
  11. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/ml/ringnet.py +1 -0
  12. {imt_ring-1.6.2 → imt_ring-1.6.4}/setup.cfg +0 -0
  13. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/imt_ring.egg-info/SOURCES.txt +0 -0
  14. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  15. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/imt_ring.egg-info/requires.txt +0 -0
  16. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/imt_ring.egg-info/top_level.txt +0 -0
  17. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algebra.py +0 -0
  18. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/__init__.py +0 -0
  19. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/_random.py +0 -0
  20. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  21. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  22. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  23. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  24. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/generator/__init__.py +0 -0
  25. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/generator/batch.py +0 -0
  26. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  27. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/generator/pd_control.py +0 -0
  28. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/generator/setup_fns.py +0 -0
  29. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/generator/types.py +0 -0
  30. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/jcalc.py +0 -0
  31. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/algorithms/kinematics.py +0 -0
  32. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/__init__.py +0 -0
  33. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/branched.xml +0 -0
  34. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  35. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  36. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  37. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/inv_pendulum.xml +0 -0
  38. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  39. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/spherical_stiff.xml +0 -0
  40. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/symmetric.xml +0 -0
  41. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/test_all_1.xml +0 -0
  42. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/test_all_2.xml +0 -0
  43. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  44. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/test_control.xml +0 -0
  45. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  46. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/test_free.xml +0 -0
  47. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/test_kinematics.xml +0 -0
  48. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  49. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  50. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/test_randomize_position.xml +0 -0
  51. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/test_sensors.xml +0 -0
  52. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  53. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/examples.py +0 -0
  54. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/test_examples.py +0 -0
  55. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/xml/__init__.py +0 -0
  56. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/xml/abstract.py +0 -0
  57. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/xml/from_xml.py +0 -0
  58. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/xml/test_from_xml.py +0 -0
  59. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/xml/test_to_xml.py +0 -0
  60. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/io/xml/to_xml.py +0 -0
  61. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/maths.py +0 -0
  62. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/ml/__init__.py +0 -0
  63. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/ml/base.py +0 -0
  64. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/ml/callbacks.py +0 -0
  65. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/ml/ml_utils.py +0 -0
  66. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/ml/optimizer.py +0 -0
  67. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  68. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  69. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/ml/rnno_v1.py +0 -0
  70. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/ml/train.py +0 -0
  71. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/ml/training_loop.py +0 -0
  72. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/rendering/__init__.py +0 -0
  73. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/rendering/base_render.py +0 -0
  74. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/rendering/mujoco_render.py +0 -0
  75. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/rendering/vispy_render.py +0 -0
  76. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/rendering/vispy_visuals.py +0 -0
  77. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/sim2real/__init__.py +0 -0
  78. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/sim2real/sim2real.py +0 -0
  79. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/spatial.py +0 -0
  80. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/sys_composer/__init__.py +0 -0
  81. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/sys_composer/delete_sys.py +0 -0
  82. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/sys_composer/inject_sys.py +0 -0
  83. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/sys_composer/morph_sys.py +0 -0
  84. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/utils/__init__.py +0 -0
  85. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/utils/backend.py +0 -0
  86. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/utils/batchsize.py +0 -0
  87. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/utils/colab.py +0 -0
  88. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/utils/hdf5.py +0 -0
  89. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/utils/normalizer.py +0 -0
  90. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/utils/path.py +0 -0
  91. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/utils/randomize_sys.py +0 -0
  92. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/utils/register_gym_envs/__init__.py +0 -0
  93. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/utils/register_gym_envs/saddle.py +0 -0
  94. {imt_ring-1.6.2 → imt_ring-1.6.4}/src/ring/utils/utils.py +0 -0
  95. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_algebra.py +0 -0
  96. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_base.py +0 -0
  97. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_custom_joints.py +0 -0
  98. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_dynamics.py +0 -0
  99. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_generator.py +0 -0
  100. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_jcalc.py +0 -0
  101. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_jit.py +0 -0
  102. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_kinematics.py +0 -0
  103. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_maths.py +0 -0
  104. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_ml_utils.py +0 -0
  105. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_motion_artifacts.py +0 -0
  106. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_pd_control.py +0 -0
  107. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_quickstart_example.py +0 -0
  108. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_random.py +0 -0
  109. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_randomize.py +0 -0
  110. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_rcmg.py +0 -0
  111. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_render.py +0 -0
  112. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_sensors.py +0 -0
  113. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_sim2real.py +0 -0
  114. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_sys_composer.py +0 -0
  115. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_train.py +0 -0
  116. {imt_ring-1.6.2 → imt_ring-1.6.4}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.2
3
+ Version: 1.6.4
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -60,6 +60,36 @@ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-re
60
60
 
61
61
  Available [here](https://simipixel.github.io/ring/).
62
62
 
63
+ ## Quickstart Example
64
+ ```python
65
+ import ring
66
+ import numpy as np
67
+
68
+ T : int = 30 # sequence length [s]
69
+ Ts : float = 0.01 # sampling interval [s]
70
+ B : int = 1 # batch size
71
+ lam: list[int] = [0, 1, 2] # parent array
72
+ N : int = len(lam) # number of bodies
73
+ T_i: int = int(T/Ts) # number of timesteps
74
+
75
+ X = np.zeros((B, T_i, N, 9))
76
+ # where X is structured as follows:
77
+ # X[..., :3] = acc
78
+ # X[..., 3:6] = gyr
79
+ # X[..., 6:9] = jointaxis
80
+
81
+ # let's assume we have an IMU on each outer segment of the
82
+ # three-segment kinematic chain; acc_segment*/gyr_segment* are your IMU arrays, broadcastable to shape (B, T_i, 3)
83
+ X[..., 0, :3] = acc_segment1
84
+ X[..., 2, :3] = acc_segment3
85
+ X[..., 0, 3:6] = gyr_segment1
86
+ X[..., 2, 3:6] = gyr_segment3
87
+
88
+ ringnet = ring.RING(lam, Ts)
89
+ yhat, _ = ringnet.apply(X)
90
+ # yhat: unit quaternions, shape = (B, T_i, N, 4)
91
+ ```
92
+
63
93
  ### Known fixes
64
94
 
65
95
  #### Offscreen rendering with Mujoco
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.6.2"
7
+ version = "1.6.4"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -22,6 +22,36 @@ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-re
22
22
 
23
23
  Available [here](https://simipixel.github.io/ring/).
24
24
 
25
+ ## Quickstart Example
26
+ ```python
27
+ import ring
28
+ import numpy as np
29
+
30
+ T : int = 30 # sequence length [s]
31
+ Ts : float = 0.01 # sampling interval [s]
32
+ B : int = 1 # batch size
33
+ lam: list[int] = [0, 1, 2] # parent array
34
+ N : int = len(lam) # number of bodies
35
+ T_i: int = int(T/Ts) # number of timesteps
36
+
37
+ X = np.zeros((B, T_i, N, 9))
38
+ # where X is structured as follows:
39
+ # X[..., :3] = acc
40
+ # X[..., 3:6] = gyr
41
+ # X[..., 6:9] = jointaxis
42
+
43
+ # let's assume we have an IMU on each outer segment of the
44
+ # three-segment kinematic chain; acc_segment*/gyr_segment* are your IMU arrays, broadcastable to shape (B, T_i, 3)
45
+ X[..., 0, :3] = acc_segment1
46
+ X[..., 2, :3] = acc_segment3
47
+ X[..., 0, 3:6] = gyr_segment1
48
+ X[..., 2, 3:6] = gyr_segment3
49
+
50
+ ringnet = ring.RING(lam, Ts)
51
+ yhat, _ = ringnet.apply(X)
52
+ # yhat: unit quaternions, shape = (B, T_i, N, 4)
53
+ ```
54
+
25
55
  ### Known fixes
26
56
 
27
57
  #### Offscreen rendering with Mujoco
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.2
3
+ Version: 1.6.4
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -60,6 +60,36 @@ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-re
60
60
 
61
61
  Available [here](https://simipixel.github.io/ring/).
62
62
 
63
+ ## Quickstart Example
64
+ ```python
65
+ import ring
66
+ import numpy as np
67
+
68
+ T : int = 30 # sequence length [s]
69
+ Ts : float = 0.01 # sampling interval [s]
70
+ B : int = 1 # batch size
71
+ lam: list[int] = [0, 1, 2] # parent array
72
+ N : int = len(lam) # number of bodies
73
+ T_i: int = int(T/Ts) # number of timesteps
74
+
75
+ X = np.zeros((B, T_i, N, 9))
76
+ # where X is structured as follows:
77
+ # X[..., :3] = acc
78
+ # X[..., 3:6] = gyr
79
+ # X[..., 6:9] = jointaxis
80
+
81
+ # let's assume we have an IMU on each outer segment of the
82
+ # three-segment kinematic chain; acc_segment*/gyr_segment* are your IMU arrays, broadcastable to shape (B, T_i, 3)
83
+ X[..., 0, :3] = acc_segment1
84
+ X[..., 2, :3] = acc_segment3
85
+ X[..., 0, 3:6] = gyr_segment1
86
+ X[..., 2, 3:6] = gyr_segment3
87
+
88
+ ringnet = ring.RING(lam, Ts)
89
+ yhat, _ = ringnet.apply(X)
90
+ # yhat: unit quaternions, shape = (B, T_i, N, 4)
91
+ ```
92
+
63
93
  ### Known fixes
64
94
 
65
95
  #### Offscreen rendering with Mujoco
@@ -20,52 +20,55 @@ from .base import System
20
20
  from .base import Transform
21
21
 
22
22
 
23
- def RING(lam: list[int] | None, Ts: float | None, **kwargs):
23
+ def RING(lam: list[int] | None, Ts: float | None, **kwargs) -> ml.AbstractFilter:
24
24
  """Creates the RING network.
25
25
 
26
26
  Params:
27
27
  lam: parent array, if `None` must be given via `ringnet.apply(..., lam=lam)`
28
28
  Ts : sampling interval of IMU data; time delta in seconds
29
29
 
30
- Usage:
31
- >>> import ring
32
- >>> import numpy as np
33
- >>>
34
- >>> T : int = 30 # sequence length [s]
35
- >>> Ts : float = 0.01 # sampling interval [s]
36
- >>> B : int = 1 # batch size
37
- >>> lam: list[int] = [0, 1, 2] # parent array
38
- >>> N : int = len(lam) # number of bodies
39
- >>> T_i: int = int(T/Ts) # number of timesteps
40
- >>>
41
- >>> X = np.zeros((B, T_i, N, 9))
42
- >>> # where X is structured as follows:
43
- >>> # X[..., :3] = acc
44
- >>> # X[..., 3:6] = gyr
45
- >>> # X[..., 6:9] = jointaxis
46
- >>>
47
- >>> # let's assume we have an IMU on each outer segment of the
48
- >>> # three-segment kinematic chain
49
- >>> X[:, :, 0, :3] = acc_segment1
50
- >>> X[:, :, 2, :3] = acc_segment3
51
- >>> X[:, :, 0, 3:6] = gyr_segment1
52
- >>> X[:, :, 2, 3:6] = gyr_segment3
53
- >>>
54
- >>> ringnet = ring.RING(lam, Ts)
55
- >>>
56
- >>> yhat, _ = ringnet.apply(X)
57
- >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
58
- >>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
59
- >>>
60
- >>> # use `jax.jit` to compile the forward pass
61
- >>> jit_apply = jax.jit(ringnet.apply)
62
- >>> yhat, _ = jit_apply(X)
63
- >>>
64
- >>> # manually pass in and out the hidden state like so
65
- >>> initial_state = None
66
- >>> yhat, state = ringnet.apply(X, state=initial_state)
67
- >>> # state: final hidden state, shape = (B, N, 2*H)
68
-
30
+ Returns:
31
+ ring.ml.AbstractFilter: An instantiation of `ring.ml.ringnet.RING` with trained
32
+ parameters.
33
+
34
+ Examples:
35
+ >>> import ring
36
+ >>> import numpy as np
37
+ >>>
38
+ >>> T : int = 30 # sequence length [s]
39
+ >>> Ts : float = 0.01 # sampling interval [s]
40
+ >>> B : int = 1 # batch size
41
+ >>> lam: list[int] = [0, 1, 2] # parent array
42
+ >>> N : int = len(lam) # number of bodies
43
+ >>> T_i: int = int(T/Ts) # number of timesteps
44
+ >>>
45
+ >>> X = np.zeros((B, T_i, N, 9))
46
+ >>> # where X is structured as follows:
47
+ >>> # X[..., :3] = acc
48
+ >>> # X[..., 3:6] = gyr
49
+ >>> # X[..., 6:9] = jointaxis
50
+ >>>
51
+ >>> # let's assume we have an IMU on each outer segment of the
52
+ >>> # three-segment kinematic chain; acc_segment*/gyr_segment* are your IMU arrays, broadcastable to shape (B, T_i, 3)
53
+ >>> X[:, :, 0, :3] = acc_segment1
54
+ >>> X[:, :, 2, :3] = acc_segment3
55
+ >>> X[:, :, 0, 3:6] = gyr_segment1
56
+ >>> X[:, :, 2, 3:6] = gyr_segment3
57
+ >>>
58
+ >>> ringnet = ring.RING(lam, Ts)
59
+ >>>
60
+ >>> yhat, _ = ringnet.apply(X)
61
+ >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
62
+ >>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
63
+ >>>
64
+ >>> # use `jax.jit` to compile the forward pass (requires `import jax`)
65
+ >>> jit_apply = jax.jit(ringnet.apply)
66
+ >>> yhat, _ = jit_apply(X)
67
+ >>>
68
+ >>> # manually pass in and out the hidden state like so
69
+ >>> initial_state = None
70
+ >>> yhat, state = ringnet.apply(X, state=initial_state)
71
+ >>> # state: final hidden state, shape = (B, N, 2*H)
69
72
  """
70
73
  from pathlib import Path
71
74
  import warnings
@@ -303,6 +303,7 @@ def step(
303
303
  taus: Optional[jax.Array] = None,
304
304
  n_substeps: int = 1,
305
305
  ) -> base.State:
306
+ "Steps the dynamics. Returns the state of the next timestep."
306
307
  assert sys.q_size() == state.q.size
307
308
  if taus is None:
308
309
  taus = jnp.zeros_like(state.qd)
@@ -4,6 +4,7 @@ import warnings
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+ import numpy as np
7
8
  import tree_utils
8
9
 
9
10
  from ring import base
@@ -30,6 +31,7 @@ class RCMG:
30
31
  add_X_jointaxes_kwargs: dict = dict(),
31
32
  add_y_relpose: bool = False,
32
33
  add_y_rootincl: bool = False,
34
+ add_y_rootincl_kwargs: dict = dict(),
33
35
  sys_ml: Optional[base.System] = None,
34
36
  randomize_positions: bool = False,
35
37
  randomize_motion_artifacts: bool = False,
@@ -46,6 +48,7 @@ class RCMG:
46
48
  cor: bool = False,
47
49
  disable_tqdm: bool = False,
48
50
  ) -> None:
51
+ "Random Chain Motion Generator"
49
52
 
50
53
  sys, config = utils.to_list(sys), utils.to_list(config)
51
54
  sys_ml = sys[0] if sys_ml is None else sys_ml
@@ -67,6 +70,7 @@ class RCMG:
67
70
  add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
68
71
  add_y_relpose=add_y_relpose,
69
72
  add_y_rootincl=add_y_rootincl,
73
+ add_y_rootincl_kwargs=add_y_rootincl_kwargs,
70
74
  sys_ml=sys_ml,
71
75
  randomize_positions=randomize_positions,
72
76
  randomize_motion_artifacts=randomize_motion_artifacts,
@@ -139,7 +143,9 @@ class RCMG:
139
143
 
140
144
  return n_calls
141
145
 
142
- def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
146
+ def to_list(
147
+ self, sizes: int | list[int] = 1, seed: int = 1
148
+ ) -> list[tree_utils.PyTree[np.ndarray]]:
143
149
  "Returns list of unbatched sequences as numpy arrays."
144
150
  repeats = self._compute_repeats(sizes)
145
151
  sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
@@ -168,7 +174,7 @@ class RCMG:
168
174
  seed: int = 1,
169
175
  overwrite: bool = True,
170
176
  ) -> None:
171
- data = tree_utils.tree_batch(self.to_list(sizes, seed))
177
+ data = tree_utils.tree_batch(self.to_list(sizes, seed), backend="numpy")
172
178
  utils.pickle_save(data, path, overwrite=overwrite)
173
179
 
174
180
  def to_eager_gen(
@@ -232,6 +238,7 @@ def _build_mconfig_batched_generator(
232
238
  add_X_jointaxes_kwargs: dict,
233
239
  add_y_relpose: bool,
234
240
  add_y_rootincl: bool,
241
+ add_y_rootincl_kwargs: dict,
235
242
  sys_ml: base.System,
236
243
  randomize_positions: bool,
237
244
  randomize_motion_artifacts: bool,
@@ -77,12 +77,13 @@ class RelPose:
77
77
 
78
78
 
79
79
  class RootIncl:
80
- def __init__(self, sys: base.System):
80
+ def __init__(self, sys: base.System, **kwargs):
81
81
  self.sys = sys
82
+ self.kwargs = kwargs
82
83
 
83
84
  def __call__(self, Xy, extras):
84
85
  (X, y), (key, q, x, sys_x) = Xy, extras
85
- y_root_incl = sensors.root_incl(self.sys, x, sys_x)
86
+ y_root_incl = sensors.root_incl(self.sys, x, sys_x, **self.kwargs)
86
87
  y = utils.dict_union(y, y_root_incl)
87
88
  return (X, y), (key, q, x, sys_x)
88
89
 
@@ -330,7 +330,10 @@ def rel_pose(
330
330
 
331
331
 
332
332
  def root_incl(
333
- sys: base.System, x: base.Transform, sys_x: base.System
333
+ sys: base.System,
334
+ x: base.Transform,
335
+ sys_x: base.System,
336
+ child_to_parent: bool = False,
334
337
  ) -> dict[str, jax.Array]:
335
338
  # (time, nlinks, 4) -> (nlinks, time, 4)
336
339
  rots = x.rot.transpose((1, 0, 2))
@@ -341,8 +344,10 @@ def root_incl(
341
344
  def f(_, __, name: str, parent: int):
342
345
  if parent != -1:
343
346
  return
344
- q_eps_to_i = maths.quat_project(rots[l_map[name]], jnp.array([0.0, 0, 1]))[1]
345
- y[name] = maths.quat_inv(q_eps_to_i)
347
+ q_i = maths.quat_project(rots[l_map[name]], jnp.array([0.0, 0, 1]))[1]
348
+ if child_to_parent:
349
+ q_i = maths.quat_inv(q_i)
350
+ y[name] = q_i
346
351
 
347
352
  sys.scan(f, "ll", sys.link_names, sys.link_parents)
348
353
 
@@ -350,7 +355,10 @@ def root_incl(
350
355
 
351
356
 
352
357
  def root_full(
353
- sys: base.System, x: base.Transform, sys_x: base.System
358
+ sys: base.System,
359
+ x: base.Transform,
360
+ sys_x: base.System,
361
+ child_to_parent: bool = False,
354
362
  ) -> dict[str, jax.Array]:
355
363
  # (time, nlinks, 4) -> (nlinks, time, 4)
356
364
  rots = x.rot.transpose((1, 0, 2))
@@ -361,8 +369,10 @@ def root_full(
361
369
  def f(_, __, name: str, parent: int):
362
370
  if parent != -1:
363
371
  return
364
- q_eps_to_i = rots[l_map[name]]
365
- y[name] = maths.quat_inv(q_eps_to_i)
372
+ q_i = rots[l_map[name]]
373
+ if child_to_parent:
374
+ q_i = maths.quat_inv(q_i)
375
+ y[name] = q_i
366
376
 
367
377
  sys.scan(f, "ll", sys.link_names, sys.link_parents)
368
378
 
@@ -113,7 +113,9 @@ class _Base:
113
113
  class Transform(_Base):
114
114
  """Represents the Transformation from Plücker A to Plücker B,
115
115
  where B is located relative to A at `pos` in frame A and `rot` is the
116
- relative quaternion from A to B."""
116
+ relative quaternion from A to B.
117
+ Create using `Transform.create(pos=..., rot=...)`.
118
+ """
117
119
 
118
120
  pos: Vector
119
121
  rot: Quaternion
@@ -399,6 +401,7 @@ QD_WIDTHS = {
399
401
 
400
402
  @struct.dataclass
401
403
  class System(_Base):
404
+ "System object. Create using `System.create(path_xml)`"
402
405
  link_parents: list[int] = struct.field(False)
403
406
  links: Link
404
407
  link_types: list[str] = struct.field(False)
@@ -200,6 +200,7 @@ class RING(ml_base.AbstractFilter):
200
200
  forward_factory=make_ring,
201
201
  **kwargs,
202
202
  ):
203
+ "Untrained RING network"
203
204
  self.forward_lam_factory = partial(forward_factory, **kwargs)
204
205
  self.params = self._load_params(params)
205
206
  self.lam = lam
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes