imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,56 @@
1
+ import jax.numpy as jnp
2
+ import ring
3
+
4
# XML fixture describing a single "rx" body that carries one box geom.
# `test_from_xml` rebuilds the same system by hand and compares both.
sys_str = """
<x_xy model="model">
<options gravity=".1 2 3" dt=".03"/>
<worldbody>
<body name="name" joint="rx" pos="1 2 3" euler="30 30 30" damping=".7" armature=".8" spring_stiff="1" spring_zero=".9">
<geom type="box" mass="2.7" dim="0.2 0.3 0.4" color="black" edge_color="pink"/>
</body>
</worldbody>
</x_xy>
""" # noqa: E501


def test_from_xml():
    """Loading `sys_str` must give the same system as constructing it in code."""
    pos = jnp.array([1.0, 2, 3])

    # The XML states the euler angles in degrees (30 each); build the
    # matching rotation from radians.
    thirty_deg = jnp.deg2rad(30)
    rot = ring.maths.quat_euler(jnp.array([thirty_deg, thirty_deg, thirty_deg]))

    link = ring.base.Link(
        ring.base.Transform(pos=pos, rot=rot),
        pos_min=pos,
        pos_max=pos,
    )
    box = ring.base.Box(
        jnp.array(2.7),
        ring.Transform.zero(),
        0,
        "black",
        "pink",
        jnp.array(0.2),
        jnp.array(0.3),
        jnp.array(0.4),
    )

    hand_built = ring.System(
        [-1],
        link.batch(),
        ["rx"],
        link_damping=jnp.array([0.7]),
        link_armature=jnp.array([0.8]),
        link_spring_zeropoint=jnp.array([0.9]),
        link_spring_stiffness=jnp.array([1.0]),
        dt=0.03,
        geoms=[box],
        gravity=jnp.array([0.1, 2, 3.0]),
        link_names=["name"],
        model_name="model",
        omc=[None],
    )
    hand_built = hand_built.parse()
    from_xml = ring.io.load_sys_from_str(sys_str)

    assert ring.utils.sys_compare(hand_built, from_xml)
@@ -0,0 +1,31 @@
1
+ import logging
2
+
3
+ import ring
4
+ from ring.base import System
5
+ from ring.utils import sys_compare
6
+
7
+
8
def test_save_sys_to_str():
    """Check that exporting a system to XML and re-loading it is lossless.

    Every bundled example must survive the save -> load round trip, and two
    different examples must still compare as different after the round trip.
    """
    for example_sys in ring.io.list_load_examples():
        xml_text = ring.io.save_sys_to_str(example_sys)
        logging.debug(xml_text)

        reloaded_sys = ring.io.load_sys_from_str(xml_text)
        roundtrip_ok = sys_compare(example_sys, reloaded_sys)
        assert roundtrip_ok, f"Failed {example_sys.model_name}.xml"

        print(f"Passed {example_sys.model_name}.xml")

    def double_load_xml_to_sys(example: str) -> System:
        # load example -> export to xml string -> load again
        exported_xml = ring.io.save_sys_to_str(ring.io.load_example(example))
        return ring.io.load_sys_from_str(exported_xml)

    first, second = (
        double_load_xml_to_sys(name) for name in ("test_all_1.xml", "test_all_2.xml")
    )

    # guards against `sys_compare` trivially returning True
    assert not sys_compare(first, second)
ring/io/xml/to_xml.py ADDED
@@ -0,0 +1,94 @@
1
+ import warnings
2
+ from xml.dom.minidom import parseString
3
+ from xml.etree.ElementTree import Element
4
+ from xml.etree.ElementTree import SubElement
5
+ from xml.etree.ElementTree import tostring
6
+
7
+ import jax.numpy as jnp
8
+ from ring import base
9
+ from tree_utils import batch_concat
10
+
11
+ from . import abstract
12
+ from .abstract import _to_str
13
+
14
+
15
def save_sys_to_str(sys: base.System) -> str:
    """Serialize `sys` into a pretty-printed `x_xy` XML string.

    A warning is emitted for every (joint type, link) pair whose entry in
    `sys.links.joint_params` is not all zeros, because these values are not
    written to the XML.

    Args:
        sys: The system to export.

    Returns:
        The XML document as a string.
    """
    for joint_type in sys.links.joint_params:
        for i, link_name in enumerate(sys.link_names):
            flat_params = batch_concat((sys.links[i]).joint_params[joint_type], 0)
            if jnp.all(flat_params == 0.0):
                continue
            warnings.warn(
                "The system has `sys.links.joint_params` unequal to the 'default'"
                f" value (of zeros). In particular the link `{link_name}` has for"
                f" the jointtype `{joint_type}` the values {flat_params}. "
                "This will not be preserved in the xml."
            )

    # per-link index lookups into the flat `q`- and `d`-sized arrays
    index_maps = {qd: sys.idx_map(qd) for qd in ["q", "d"]}

    # root element with the global options
    root = Element("x_xy")
    root.set("model", sys.model_name)

    options = SubElement(root, "options")
    options.set("dt", str(sys.dt))
    options.set("gravity", _to_str(sys.gravity))

    worldbody = SubElement(root, "worldbody")

    def process_link(link_idx: int, parent_elem: Element):
        # Write one <body> for `link_idx` below `parent_elem`, then recurse
        # into all links that name `link_idx` as their parent.
        link = sys.links[link_idx]
        link_typ = sys.link_types[link_idx]
        link_name = sys.link_names[link_idx]

        body = SubElement(parent_elem, "body")
        body.set("joint", link_typ)
        body.set("name", link_name)

        # body attributes
        abstract.AbsTrans.to_xml(body, link.transform1)
        abstract.AbsPosMinMax.to_xml(body, link.pos_min, link.pos_max)
        d_index = index_maps["d"][link_name]
        q_index = index_maps["q"][link_name]
        abstract.AbsDampArmaStiffZero.to_xml(
            body,
            sys.link_damping[d_index],
            sys.link_armature[d_index],
            sys.link_spring_stiffness[d_index],
            sys.link_spring_zeropoint[q_index],
            base.Q_WIDTHS[link_typ],
            base.QD_WIDTHS[link_typ],
            link_typ,
        )

        # geometries attached to this link
        own_geoms = [geom for geom in sys.geoms if geom.link_idx == link_idx]
        for geom in own_geoms:
            geom_elem = SubElement(body, "geom")
            abstract.geometry_to_abstract[type(geom)].to_xml(geom_elem, geom)

        # optional omc element
        omc_link = sys.omc[link_idx]
        if omc_link is not None:
            abstract.AbsMaxCoordOMC.to_xml(SubElement(body, "omc"), omc_link)

        # children, in link order
        child_indices = [
            idx for idx, parent in enumerate(sys.link_parents) if parent == link_idx
        ]
        for child_idx in child_indices:
            process_link(child_idx, body)

    # links whose parent is -1 hang directly below the worldbody
    root_indices = [
        idx for idx, parent in enumerate(sys.link_parents) if parent == -1
    ]
    for root_idx in root_indices:
        process_link(root_idx, worldbody)

    # pretty print via minidom
    document = parseString(tostring(root))
    return document.toprettyxml(indent=" ")
89
+
90
+
91
def save_sys_to_xml(sys: base.System, xml_path: str) -> None:
    """Serialize `sys` with `save_sys_to_str` and write it to `xml_path`.

    Args:
        sys: The system to export.
        xml_path: Destination file; an existing file is overwritten.
    """
    xml_str = save_sys_to_str(sys)
    # The XML declaration produced by `toprettyxml` names no encoding, which
    # XML parsers interpret as UTF-8. Write UTF-8 explicitly so the file does
    # not depend on the platform's locale encoding.
    with open(xml_path, "w", encoding="utf-8") as f:
        f.write(xml_str)
ring/maths.py ADDED
@@ -0,0 +1,397 @@
1
+ from functools import partial
2
+
3
+ import jax
4
+ from jax import custom_jvp
5
+ import jax.numpy as jnp
6
+ import jax.random as jrand
7
+
8
+
9
def wrap_to_pi(phi):
    """Wrap angle `phi` (radians) into the half-open interval [-pi, pi).

    The angle is shifted by pi, reduced modulo 2*pi and shifted back, so an
    input of exactly +pi is mapped to -pi.
    """
    full_turn = 2 * jnp.pi
    shifted = phi + jnp.pi
    return shifted % full_turn - jnp.pi
12
+
13
+
14
# Cartesian basis vectors, shared by `unit_vectors` below.
x_unit_vector = jnp.array([1.0, 0.0, 0.0])
y_unit_vector = jnp.array([0.0, 1.0, 0.0])
z_unit_vector = jnp.array([0.0, 0.0, 1.0])
17
+
18
+
19
def unit_vectors(xyz: int | str):
    """Return the Cartesian basis vector selected by `xyz`.

    `xyz` is either an index (0, 1, 2) or one of the letters "x", "y", "z".
    """
    basis = (x_unit_vector, y_unit_vector, z_unit_vector)
    if isinstance(xyz, str):
        letter_to_index = {"x": 0, "y": 1, "z": 2}
        xyz = letter_to_index[xyz]
    return basis[xyz]
23
+
24
+
25
@partial(jnp.vectorize, signature="(k)->(1)")
def safe_norm(x):
    """Euclidean norm along the last axis whose gradient is defined at x=0.

    The length-1 output axis is kept. When all entries of `x` are close to
    zero a constant zero is returned instead of evaluating the norm.
    """
    assert x.ndim == 1

    def _constant_zero(vec):
        return jnp.array([0.0], dtype=vec.dtype)

    def _euclidean(vec):
        return jnp.linalg.norm(vec, keepdims=True)

    entries_near_zero = jnp.isclose(x, 0.0)
    is_zero = jnp.all(entries_near_zero, axis=-1, keepdims=False)
    return jax.lax.cond(is_zero, _constant_zero, _euclidean, x)
37
+
38
+
39
@partial(jnp.vectorize, signature="(k)->(k)")
def safe_normalize(x):
    """Normalize along the last axis; safe to execute and differentiate at x=0.

    A vector whose entries are all close to zero is mapped to the zero vector.
    """
    assert x.ndim == 1

    is_zero = jnp.allclose(x, 0.0)

    def _zero_vector(vec):
        return jnp.zeros_like(vec)

    def _scaled_to_unit(vec):
        # divisor is forced to 1 in the zero case so that this branch stays
        # finite even when it is evaluated for a zero vector
        divisor = jnp.where(is_zero, 1.0, safe_norm(vec))
        return vec / divisor

    return jax.lax.cond(is_zero, _zero_vector, _scaled_to_unit, x)
51
+
52
+
53
@custom_jvp
def safe_arccos(x: jnp.ndarray) -> jnp.ndarray:
    """Element-wise inverse cosine whose derivative is clipped near |x| = 1."""
    return jnp.arccos(x)


@safe_arccos.defjvp
def _safe_arccos_jvp(primal, tangent):
    # Custom derivative rule: the argument is clipped slightly inside
    # (-1, 1) before evaluating d/dx arccos(x) = -1 / sqrt(1 - x^2), which
    # keeps the tangent finite at the interval ends.
    x = primal[0]
    x_dot = tangent[0]
    x_clipped = jnp.clip(x, -1 + 1e-7, 1 - 1e-7)
    slope_denominator = jnp.sqrt(1.0 - x_clipped ** 2.0)
    return safe_arccos(x), -x_dot / slope_denominator
66
+
67
+
68
@custom_jvp
def safe_arcsin(x: jnp.ndarray) -> jnp.ndarray:
    """Element-wise inverse sine whose derivative is clipped near |x| = 1."""
    return jnp.arcsin(x)


@safe_arcsin.defjvp
def _safe_arcsin_jvp(primal, tangent):
    # Custom derivative rule: the argument is clipped slightly inside
    # (-1, 1) before evaluating d/dx arcsin(x) = 1 / sqrt(1 - x^2), which
    # keeps the tangent finite at the interval ends.
    (x,) = primal
    (x_dot,) = tangent
    # The primal output must be the arcsin itself; returning the arccos here
    # made `jax.jvp` / `jax.value_and_grad` report the wrong function value.
    primal_out = safe_arcsin(x)
    tangent_out = x_dot / jnp.sqrt(1.0 - jnp.clip(x, -1 + 1e-7, 1 - 1e-7) ** 2.0)
    return primal_out, tangent_out
81
+
82
+
83
@partial(jnp.vectorize, signature="(4)->(4)")
def ensure_positive_w(q):
    """Negate quaternion `q` when its first component is negative."""
    w_is_negative = q[0] < 0
    negated = -q
    return jnp.where(w_is_negative, negated, q)
86
+
87
+
88
def angle_error(q, qhat):
    """Absolute angle in radians between `q` and `qhat`.

    Computed as the absolute rotation angle of the product
    `quat_inv(q) * qhat`.
    """
    relative = quat_mul(quat_inv(q), qhat)
    signed_angle = quat_angle(relative)
    return jnp.abs(signed_angle)
91
+
92
+
93
def unit_quats_like(array):
    """Return identity quaternions (1, 0, 0, 0) with the same shape as `array`.

    Args:
        array: Array whose last axis has length 4. Only its shape is used.

    Returns:
        Array of shape `array.shape` in which every quaternion along the last
        axis equals (1, 0, 0, 0).

    Raises:
        ValueError: If the last axis of `array` does not have length 4.
    """
    if array.shape[-1] != 4:
        # ValueError is still an `Exception`, so callers that catch the
        # previously raised bare `Exception` keep working.
        raise ValueError(
            "Expected an array of quaternions with a last axis of length 4, "
            f"but got shape {array.shape}."
        )

    return jnp.ones(array.shape[:-1])[..., None] * jnp.array([1.0, 0, 0, 0])
99
+
100
+
101
@partial(jnp.vectorize, signature="(4),(4)->(4)")
def quat_mul(u: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
    """Multiply two quaternions `u * v` (first component is the scalar part)."""
    uw, ux, uy, uz = u
    vw, vx, vy, vz = v
    w = uw * vw - ux * vx - uy * vy - uz * vz
    x = uw * vx + ux * vw + uy * vz - uz * vy
    y = uw * vy - ux * vz + uy * vw + uz * vx
    z = uw * vz + ux * vy - uy * vx + uz * vw
    return jnp.array([w, x, y, z])
113
+
114
+
115
def quat_inv(q: jnp.ndarray) -> jnp.ndarray:
    """Return the conjugate of `q`, i.e. its inverse when `q` is a unit quaternion.

    The first component is kept and the remaining three are negated.
    """
    conjugation_signs = jnp.array([1.0, -1, -1, -1])
    return q * conjugation_signs
118
+
119
+
120
@partial(jnp.vectorize, signature="(3),(4)->(3)")
def rotate(vector: jnp.ndarray, quat: jnp.ndarray):
    """Rotate `vector` by the *unit* quaternion `quat`.

    The vector is embedded as a quaternion with zero scalar part, rotated
    with `rotate_quat`, and its last three components are returned.
    """
    vector_as_quat = jnp.array([0, *vector])
    rotated = rotate_quat(vector_as_quat, quat)
    return rotated[1:4]
125
+
126
+
127
def rotate_matrix(matrix: jax.Array, quat: jax.Array):
    """Rotate `matrix` by the *unit* quaternion `quat`, i.e. `E @ matrix @ E^T`.

    Args:
        matrix: Array of shape (..., 3, 3).
        quat: Array of shape (..., 4); leading axes broadcast with `matrix`.

    Returns:
        The rotated matrix/matrices.
    """
    E = quat_to_3x3(quat)
    # Transpose only the two matrix axes. `E.T` would reverse *all* axes and
    # therefore only works for a single (unbatched) quaternion.
    return E @ matrix @ jnp.swapaxes(E, -1, -2)
131
+
132
+
133
def rotate_quat(q: jax.Array, quat: jax.Array):
    """Rotate quaternion `q` by `quat`: `quat * q * quat_inv(quat)`."""
    quat_conj = quat_inv(quat)
    inner = quat_mul(q, quat_conj)
    return quat_mul(quat, inner)
136
+
137
+
138
@partial(jnp.vectorize, signature="(3),()->(4)")
def quat_rot_axis(axis: jnp.ndarray, angle: jnp.ndarray) -> jnp.ndarray:
    """Build the *unit* quaternion for a rotation about `axis` by `angle` (radians).

    This is the interpretation of rotating the vector and *not* the frame.
    For the interpretation of rotating the frame and *not* the vector, use
    angle -> -angle.
    NOTE: Usually the second interpretation is what is wanted: quaternions are
    used to re-express vectors in other frames while the vectors themselves
    stay the same; they are only transformed into a common frame.

    `axis` does not have to be normalized; it is normalized internally.
    """
    assert axis.shape == (3,)
    assert angle.shape == ()

    unit_axis = safe_normalize(axis)
    # NOTE: CONVENTION (23.04.23)
    # The sign of the angle is flipped. This fixes prismatic joints being
    # inverted w.r.t. the gravity vector: it inverts how revolute joints
    # behave, such that prismatic joints work by inverting gravity.
    flipped_angle = angle * -1.0
    half_angle = flipped_angle / 2
    sin_half, cos_half = jnp.sin(half_angle), jnp.cos(half_angle)
    return jnp.array([cos_half, *(unit_axis * sin_half)])
164
+
165
+
166
@partial(jnp.vectorize, signature="(3,3)->(4)")
def quat_from_3x3(m: jnp.ndarray) -> jnp.ndarray:
    """Convert a 3x3 rotation matrix into a *unit* quaternion (w, x, y, z).

    NOTE(review): `w` goes to zero as the trace of `m` approaches -1, and the
    remaining components are divided by `4 * w`, so such inputs are not
    handled by this formula.
    """
    w = jnp.sqrt(1 + m[0, 0] + m[1, 1] + m[2, 2]) / 2.0
    four_w = w * 4
    # off-diagonal differences give the vector part
    x = (m[2][1] - m[1][2]) / four_w
    y = (m[0][2] - m[2][0]) / four_w
    z = (m[1][0] - m[0][1]) / four_w
    return jnp.array([w, x, y, z])
174
+
175
+
176
@partial(jnp.vectorize, signature="(4)->(3,3)")
def quat_to_3x3(q: jnp.ndarray) -> jnp.ndarray:
    """Convert a *unit* quaternion (w, x, y, z) into a 3x3 rotation matrix.

    The products are scaled by `2 / dot(q, q)`, so the squared norm of `q`
    is divided out.
    """
    norm_squared = jnp.dot(q, q)
    w, x, y, z = q
    scale = 2 / norm_squared

    xs, ys, zs = x * scale, y * scale, z * scale
    wx, wy, wz = w * xs, w * ys, w * zs
    xx, xy, xz = x * xs, x * ys, x * zs
    yy, yz, zz = y * ys, y * zs, z * zs

    row_0 = jnp.array([1 - (yy + zz), xy - wz, xz + wy])
    row_1 = jnp.array([xy + wz, 1 - (xx + zz), yz - wx])
    row_2 = jnp.array([xz - wy, yz + wx, 1 - (xx + yy)])
    return jnp.array([row_0, row_1, row_2])
194
+
195
+
196
def quat_random(
    key: jrand.PRNGKey, batch_shape: tuple = (), maxval: float = jnp.pi
) -> jax.Array:
    """Draw random *unit* quaternions of shape `batch_shape + (4,)`.

    The quaternions are obtained by normalizing standard-normal samples. If
    `maxval` differs from pi, the rotation angle of every sample is rescaled
    by `maxval / pi` while its axis is kept.
    """
    assert key.shape == (2,), f"{key.shape}"
    gaussian_samples = jrand.normal(key, batch_shape + (4,))
    qs = safe_normalize(gaussian_samples)

    def _keep_as_is():
        return qs

    def _shrink_angle():
        axis, angle = quat_to_rot_axis(qs)
        return quat_rot_axis(axis, angle * maxval / jnp.pi)

    return jax.lax.cond(maxval == jnp.pi, _keep_as_is, _shrink_angle)
210
+
211
+
212
def quat_euler(angles, intrinsic=True, convention="zyx"):
    """Construct a *unit* quaternion from three Euler angles (radians).

    Args:
        angles: Array with three angles along the last axis.
        intrinsic: Selects the order in which the three elementary rotations
            are multiplied (`q3 * q2 * q1` if True, `q1 * q2 * q3` otherwise).
        convention: Three letters out of "x", "y", "z"; the i-th letter is the
            rotation axis used for the i-th angle.
    """

    @partial(jnp.vectorize, signature="(3)->(4)")
    def _quat_euler(angles):
        axis_of = {
            "x": jnp.array([1.0, 0.0, 0.0]),
            "y": jnp.array([0.0, 1.0, 0.0]),
            "z": jnp.array([0.0, 0.0, 1.0]),
        }

        # one elementary rotation per (axis letter, angle) pair
        q1, q2, q3 = (
            quat_rot_axis(axis_of[convention[i]], angles[i]) for i in range(3)
        )

        if not intrinsic:
            return quat_mul(q1, quat_mul(q2, q3))
        return quat_mul(q3, quat_mul(q2, q1))

    return _quat_euler(angles)
237
+
238
+
239
@partial(jnp.vectorize, signature="(4)->()")
def quat_angle(q):
    """Extract the rotation angle (radians) of quaternion `q`, wrapped by `wrap_to_pi`."""
    vector_norm = safe_norm(q[1:])[0]
    scalar_part = q[0]
    unwrapped = 2 * jnp.arctan2(vector_norm, scalar_part)
    return wrap_to_pi(unwrapped)
244
+
245
+
246
def quat_angle_constantAxisOverTime(qs):
    """Signed rotation angles of a quaternion sequence `qs` of shape (N, 4).

    The axis of the first quaternion serves as sign reference: whenever the
    axis of a later quaternion is closer to the negated reference than to the
    reference itself, the sign of its angle is flipped.
    """
    assert qs.ndim == 2
    assert qs.shape[-1] == 4

    def _l2norm(x):
        return jnp.sqrt(jnp.sum(x**2, axis=-1))

    axes = safe_normalize(qs[:, 1:])
    angles = quat_angle(qs)[:, None]
    reference_axis = axes[0]

    dist_to_reference = _l2norm(reference_axis - axes)
    dist_to_negated = _l2norm(reference_axis + axes)
    axis_is_flipped = (dist_to_reference > dist_to_negated)[..., None]
    return jnp.where(axis_is_flipped, -angles, angles)[:, 0]
257
+
258
+
259
@partial(jnp.vectorize, signature="(4)->(3),()")
def quat_to_rot_axis(q):
    """Extract the unit axis and the angle (radians) from quaternion `q`.

    The sign of the angle is flipped, mirroring the sign flip inside
    `quat_rot_axis`.
    """
    # NOTE: CONVENTION
    flipped_angle = quat_angle(q) * -1.0
    unit_axis = safe_normalize(q[1:])
    return unit_axis, flipped_angle
267
+
268
+
269
@partial(jnp.vectorize, signature="(3)->(4)")
def euler_to_quat(angles: jnp.ndarray) -> jnp.ndarray:
    """Convert Euler rotations (radians) into a quaternion.

    Follows the Tait-Bryan intrinsic rotation formalism x-y'-z''. The result
    is passed through `quat_inv` before it is returned.
    """
    half_angles = angles / 2
    c1, c2, c3 = jnp.cos(half_angles)
    s1, s2, s3 = jnp.sin(half_angles)

    w = c1 * c2 * c3 - s1 * s2 * s3
    x = s1 * c2 * c3 + c1 * s2 * s3
    y = c1 * s2 * c3 - s1 * c2 * s3
    z = c1 * c2 * s3 + s1 * s2 * c3

    # NOTE: CONVENTION
    unconjugated = jnp.array([w, x, y, z])
    return quat_inv(unconjugated)
281
+
282
+
283
@partial(jnp.vectorize, signature="(4)->(3)")
def quat_to_euler(q: jnp.ndarray) -> jnp.ndarray:
    """Convert a quaternion into Euler rotations (radians), returned as (x, y, z).

    Follows the Tait-Bryan intrinsic rotation formalism x-y'-z''. The input
    is passed through `quat_inv` first.
    """
    # NOTE: CONVENTION
    qw, qx, qy, qz = quat_inv(q)

    angle_z = jnp.arctan2(
        -2 * qx * qy + 2 * qw * qz,
        qx * qx + qw * qw - qz * qz - qy * qy,
    )
    # TODO: Investigate why quaternions go so big we need to clip.
    sin_y = jnp.clip(2 * qx * qz + 2 * qw * qy, -1.0, 1.0)
    angle_y = safe_arcsin(sin_y)
    angle_x = jnp.arctan2(
        -2 * qy * qz + 2 * qw * qx,
        qz * qz - qy * qy - qx * qx + qw * qw,
    )

    return jnp.array([angle_x, angle_y, angle_z])
303
+
304
+
305
@partial(jnp.vectorize, signature="(4),(3)->(4),(4)")
def quat_project(q: jax.Array, k: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Split `q` into a primary rotation about axis `k` and a residual rotation.

    The primary rotation is chosen such that the angle of the residual
    rotation is minimized. The parts satisfy `q = q_res * q_pri`.

    Args:
        q (jax.Array): Quaternion to decompose.
        k (jax.Array): Primary axis direction.

    Returns:
        tuple[jax.Array, jax.Array]: Primary quaternion, residual quaternion
    """
    projection_on_axis = q[1:] @ k
    primary_angle = 2 * jnp.arctan2(projection_on_axis, q[0])
    # NOTE: CONVENTION
    # negated because `quat_rot_axis` flips the sign of its angle internally
    primary = quat_rot_axis(k, -primary_angle)
    residual = quat_mul(q, quat_inv(primary))
    return primary, residual
322
+
323
+
324
def quat_avg(qs: jax.Array):
    """Average the quaternions `qs` of shape (N, 4) (Tolga Birdal's algorithm).

    Returns the eigenvector that belongs to the largest eigenvalue of the
    (uniformly weighted) sum of outer products of the quaternions. A single
    quaternion of shape (4,) is treated as a batch of one.
    """
    if qs.ndim == 1:
        qs = qs[None, :]
    assert qs.ndim == 2

    uniform_weights = jnp.ones((qs.shape[0],))
    outer_product_sum = jnp.einsum("ij,ik,i->...jk", qs, qs, uniform_weights)
    # `eigh` sorts eigenvalues in ascending order -> take the last column
    eigenvectors = jnp.linalg.eigh(outer_product_sum)[1]
    return eigenvectors[:, -1]
332
+
333
+
334
+ # cutoff_freq=20.0; samp_freq=100.0
335
+ # -> alpha = 0.55686
336
+ # cutoff_freq=15.0
337
+ # -> alpha = 0.48519
338
def quat_lowpassfilter(
    qs: jax.Array,
    cutoff_freq: float = 20.0,
    samp_freq: float = 100.0,
    filtfilt: bool = False,
) -> jax.Array:
    """Low-pass filter a quaternion sequence `qs` of shape (N, 4).

    Starting from `qs[0]`, the filter state is moved in every step by the
    fraction `alpha` of the rotation that separates it from the current
    sample, where `alpha = w / (1 + w)` and `w = 2*pi*cutoff_freq/samp_freq`.

    Args:
        qs: Quaternion sequence, shape (N, 4).
        cutoff_freq: Cutoff frequency, same unit as `samp_freq`.
        samp_freq: Sampling frequency of `qs`.
        filtfilt: If True, filter once forward and once on the time-reversed
            result, then reverse back.

    Returns:
        The filtered and re-normalized sequence, shape (N, 4).
    """
    assert qs.ndim == 2
    assert qs.shape[1] == 4

    if filtfilt:
        forward_pass = quat_lowpassfilter(qs, cutoff_freq, samp_freq, filtfilt=False)
        backward_pass = quat_lowpassfilter(
            jnp.flip(forward_pass, 0), cutoff_freq, samp_freq, filtfilt=False
        )
        return jnp.flip(backward_pass, 0)

    omega_times_Ts = 2 * jnp.pi * cutoff_freq / samp_freq
    alpha = omega_times_Ts / (1 + omega_times_Ts)

    def _negate(axis, angle):
        return -axis, -angle

    def _keep(axis, angle):
        return axis, angle

    def _filter_step(state, sample):
        # error quaternion; current state -> target
        q_err = quat_mul(sample, quat_inv(state))
        axis, angle = quat_to_rot_axis(q_err)
        # ensure angle >= 0
        axis, angle = jax.lax.cond(angle < 0, _negate, _keep, axis, angle)
        # scale down the error rotation and move a small step toward the target
        q_err_scaled = quat_rot_axis(axis, angle * alpha)
        next_state = quat_mul(q_err_scaled, state)
        return next_state, next_state

    _, filtered_tail = jax.lax.scan(_filter_step, qs[0], qs[1:])

    # pad with the first value, such that the length remains equal
    filtered = jnp.vstack((qs[0:1], filtered_tail))

    # renormalize due to float32 numerical errors accumulating
    norms = jnp.linalg.norm(filtered, axis=-1, keepdims=True)
    return filtered / norms
381
+
382
+
383
def quat_inclinationAngle(q: jax.Array):
    """Angle of the residual rotation of `q` after projecting onto the z-axis.

    `q` is decomposed with `quat_project` about (0, 0, 1); the angle of the
    residual part is returned.
    """
    z_axis = jnp.array([0.0, 0, 1])
    _, inclination = quat_project(q, z_axis)
    return quat_angle(inclination)
386
+
387
+
388
def quat_headingAngle(q: jax.Array):
    """Angle of the primary rotation of `q` about the z-axis.

    `q` is decomposed with `quat_project` about (0, 0, 1); the angle of the
    primary part is returned.
    """
    z_axis = jnp.array([0.0, 0, 1])
    heading, _ = quat_project(q, z_axis)
    return quat_angle(heading)
391
+
392
+
393
def quat_transfer_heading(q_from: jax.Array, q_to: jax.Array):
    """Replace the heading of `q_to` by the heading of `q_from`.

    Heading is the primary rotation about (0, 0, 1) as returned by
    `quat_project`.
    """
    z_axis = jnp.array([0.0, 0, 1])
    heading_from, _ = quat_project(q_from, z_axis)
    # residual part of `q_to`, i.e. `q_to` with its heading set to zero
    _, q_to_without_heading = quat_project(q_to, z_axis)
    return quat_mul(q_to_without_heading, heading_from)
ring/ml/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ from . import base
2
+ from . import callbacks
3
+ from . import ml_utils
4
+ from . import optimizer
5
+ from . import ringnet
6
+ from . import train
7
+ from . import training_loop
8
+ from .base import AbstractFilter
9
+ from .ml_utils import on_cluster
10
+ from .ml_utils import unique_id
11
+ from .optimizer import make_optimizer
12
+ from .ringnet import RING
13
+ from .train import train_fn
14
+
15
+
16
def RING_ICML24(params=None, **kwargs):
    """Create the RING network used in the icml24 paper.

    Layout of the input features:

    X[..., :3] = acc
    X[..., 3:6] = gyr
    X[..., 6:9] = jointaxis
    X[..., 9:] = dt

    Args:
        params: Parameters handed to `RING`; defaults to the pickle file
            shipped in the `params` folder next to this module.
        **kwargs: Forwarded to `RING`.

    Returns:
        The `RING` instance wrapped, from inside to outside, by
        `ScaleX_FilterWrapper`, `LPF_FilterWrapper` and
        `GroundTruthHeading_FilterWrapper`.
    """
    from pathlib import Path

    if params is None:
        params = Path(__file__).parent / "params/0x13e3518065c21cd8.pickle"

    network = RING(params=params, **kwargs)
    scaled = base.ScaleX_FilterWrapper(network)
    low_passed = base.LPF_FilterWrapper(scaled, 10.0, samp_freq=None)
    return base.GroundTruthHeading_FilterWrapper(low_passed)