imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,56 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
import ring
|
3
|
+
|
4
|
+
sys_str = """
|
5
|
+
<x_xy model="model">
|
6
|
+
<options gravity=".1 2 3" dt=".03"/>
|
7
|
+
<worldbody>
|
8
|
+
<body name="name" joint="rx" pos="1 2 3" euler="30 30 30" damping=".7" armature=".8" spring_stiff="1" spring_zero=".9">
|
9
|
+
<geom type="box" mass="2.7" dim="0.2 0.3 0.4" color="black" edge_color="pink"/>
|
10
|
+
</body>
|
11
|
+
</worldbody>
|
12
|
+
</x_xy>
|
13
|
+
""" # noqa: E501
|
14
|
+
|
15
|
+
|
16
|
+
def test_from_xml():
    """Parsing `sys_str` must give the same system as constructing it manually."""
    body_pos = jnp.array([1.0, 2, 3])
    # euler="30 30 30" in the XML, converted to radians.
    body_rot = ring.maths.quat_euler(
        jnp.array([jnp.deg2rad(30), jnp.deg2rad(30), jnp.deg2rad(30)])
    )
    body_link = ring.base.Link(
        ring.base.Transform(pos=body_pos, rot=body_rot),
        pos_min=body_pos,
        pos_max=body_pos,
    ).batch()

    # Mirrors <geom type="box" mass="2.7" dim="0.2 0.3 0.4" .../> in `sys_str`.
    box_geom = ring.base.Box(
        jnp.array(2.7),
        ring.Transform.zero(),
        0,
        "black",
        "pink",
        jnp.array(0.2),
        jnp.array(0.3),
        jnp.array(0.4),
    )

    expected_sys = ring.System(
        [-1],
        body_link,
        ["rx"],
        link_damping=jnp.array([0.7]),
        link_armature=jnp.array([0.8]),
        link_spring_zeropoint=jnp.array([0.9]),
        link_spring_stiffness=jnp.array([1.0]),
        dt=0.03,
        geoms=[box_geom],
        gravity=jnp.array([0.1, 2, 3.0]),
        link_names=["name"],
        model_name="model",
        omc=[None],
    ).parse()
    loaded_sys = ring.io.load_sys_from_str(sys_str)

    assert ring.utils.sys_compare(expected_sys, loaded_sys)
|
@@ -0,0 +1,31 @@
|
|
1
|
+
import logging
|
2
|
+
|
3
|
+
import ring
|
4
|
+
from ring.base import System
|
5
|
+
from ring.utils import sys_compare
|
6
|
+
|
7
|
+
|
8
|
+
def test_save_sys_to_str():
    """XML export followed by re-import must round-trip every bundled example,
    and must not make two different examples compare equal."""
    # Part 1: save -> load round trip of every bundled example.
    for example_sys in ring.io.list_load_examples():
        xml_text = ring.io.save_sys_to_str(example_sys)
        logging.debug(xml_text)
        reloaded_sys = ring.io.load_sys_from_str(xml_text)

        label = f"{example_sys.model_name}.xml"
        assert sys_compare(example_sys, reloaded_sys), f"Failed {label}"
        print(f"Passed {label}")

    # Part 2: the round trip must preserve differences between systems.
    def double_load_xml_to_sys(example: str) -> System:
        exported_xml = ring.io.save_sys_to_str(ring.io.load_example(example))
        return ring.io.load_sys_from_str(exported_xml)

    first_sys, second_sys = (
        double_load_xml_to_sys(name) for name in ("test_all_1.xml", "test_all_2.xml")
    )
    assert not sys_compare(first_sys, second_sys)
|
ring/io/xml/to_xml.py
ADDED
@@ -0,0 +1,94 @@
|
|
1
|
+
import warnings
|
2
|
+
from xml.dom.minidom import parseString
|
3
|
+
from xml.etree.ElementTree import Element
|
4
|
+
from xml.etree.ElementTree import SubElement
|
5
|
+
from xml.etree.ElementTree import tostring
|
6
|
+
|
7
|
+
import jax.numpy as jnp
|
8
|
+
from ring import base
|
9
|
+
from tree_utils import batch_concat
|
10
|
+
|
11
|
+
from . import abstract
|
12
|
+
from .abstract import _to_str
|
13
|
+
|
14
|
+
|
15
|
+
def save_sys_to_str(sys: base.System) -> str:
    """Serialize `sys` into a pretty-printed ``<x_xy>`` XML document.

    `sys.links.joint_params` are not written to the XML. For every
    (joint type, link) pair whose flattened joint params are not all zero a
    warning is emitted, because those values are lost.

    Args:
        sys: the system to serialize.

    Returns:
        The XML document as a string.
    """
    # Warn about joint params that the XML format cannot carry.
    for joint_type in sys.links.joint_params:
        for i, link_name in enumerate(sys.link_names):
            joint_params_flat = batch_concat((sys.links[i]).joint_params[joint_type], 0)
            if jnp.all(joint_params_flat == 0.0):
                continue
            warnings.warn(
                "The system has `sys.links.joint_params` unequal to the 'default'"
                f" value (of zeros). In particular the link `{link_name}` has for"
                f" the jointtype `{joint_type}` the values {joint_params_flat}. "
                "This will not be preserved in the xml."
            )

    # Per-link index maps into the flat q-sized and qd-sized system arrays.
    q_index = sys.idx_map("q")
    d_index = sys.idx_map("d")

    root = Element("x_xy")
    root.set("model", sys.model_name)

    options = SubElement(root, "options")
    options.set("dt", str(sys.dt))
    options.set("gravity", _to_str(sys.gravity))

    worldbody = SubElement(root, "worldbody")

    def add_body(idx: int, parent: Element) -> None:
        """Append the <body> element of link `idx` to `parent`, then its children."""
        link = sys.links[idx]
        joint_type = sys.link_types[idx]
        name = sys.link_names[idx]

        # Attribute insertion order determines attribute order in the output.
        body = SubElement(parent, "body")
        body.set("joint", joint_type)
        body.set("name", name)

        abstract.AbsTrans.to_xml(body, link.transform1)
        abstract.AbsPosMinMax.to_xml(body, link.pos_min, link.pos_max)
        d_slice = d_index[name]
        abstract.AbsDampArmaStiffZero.to_xml(
            body,
            sys.link_damping[d_slice],
            sys.link_armature[d_slice],
            sys.link_spring_stiffness[d_slice],
            # the spring zero point lives in q-space, not qd-space
            sys.link_spring_zeropoint[q_index[name]],
            base.Q_WIDTHS[joint_type],
            base.QD_WIDTHS[joint_type],
            joint_type,
        )

        for geom in sys.geoms:
            if geom.link_idx != idx:
                continue
            geom_elem = SubElement(body, "geom")
            abstract.geometry_to_abstract[type(geom)].to_xml(geom_elem, geom)

        omc = sys.omc[idx]
        if omc is not None:
            abstract.AbsMaxCoordOMC.to_xml(SubElement(body, "omc"), omc)

        for child_idx, parent_idx in enumerate(sys.link_parents):
            if parent_idx == idx:
                add_body(child_idx, body)

    # Links with parent -1 hang directly below <worldbody>.
    for idx, parent_idx in enumerate(sys.link_parents):
        if parent_idx == -1:
            add_body(idx, worldbody)

    return parseString(tostring(root)).toprettyxml(indent=" ")
|
89
|
+
|
90
|
+
|
91
|
+
def save_sys_to_xml(sys: base.System, xml_path: str) -> None:
    """Serialize `sys` with `save_sys_to_str` and write the XML to `xml_path`.

    The file is written as UTF-8 explicitly. The document produced by
    `toprettyxml()` starts with ``<?xml version="1.0" ?>`` and has no
    ``encoding`` attribute, so XML readers decode it as UTF-8. Relying on
    the platform default encoding (e.g. cp1252 on Windows) would therefore
    corrupt non-ASCII model or link names.

    Args:
        sys: the system to serialize.
        xml_path: destination file path; an existing file is overwritten.
    """
    xml_str = save_sys_to_str(sys)
    with open(xml_path, "w", encoding="utf-8") as f:
        f.write(xml_str)
|
ring/maths.py
ADDED
@@ -0,0 +1,397 @@
|
|
1
|
+
from functools import partial
|
2
|
+
|
3
|
+
import jax
|
4
|
+
from jax import custom_jvp
|
5
|
+
import jax.numpy as jnp
|
6
|
+
import jax.random as jrand
|
7
|
+
|
8
|
+
|
9
|
+
def wrap_to_pi(phi):
    """Map the angle `phi` (radians) into the interval [-pi, pi)."""
    full_turn = 2 * jnp.pi
    shifted = phi + jnp.pi
    return shifted % full_turn - jnp.pi
|
12
|
+
|
13
|
+
|
14
|
+
# Cartesian unit vectors, returned by `unit_vectors` by axis name or index.
x_unit_vector = jnp.array([1.0, 0, 0])
y_unit_vector = jnp.array([0.0, 1, 0])
z_unit_vector = jnp.array([0.0, 0, 1])
|
17
|
+
|
18
|
+
|
19
|
+
def unit_vectors(xyz: int | str):
    """Return the Cartesian unit vector for an axis.

    Args:
        xyz: axis name ("x", "y" or "z") or axis index (0, 1 or 2).
    """
    index = {"x": 0, "y": 1, "z": 2}[xyz] if isinstance(xyz, str) else xyz
    return [x_unit_vector, y_unit_vector, z_unit_vector][index]
|
23
|
+
|
24
|
+
|
25
|
+
@partial(jnp.vectorize, signature="(k)->(1)")
def safe_norm(x):
    """Euclidean norm along the last axis, safe to differentiate at x = 0.

    The result keeps a trailing axis of size 1. If all entries of `x` are
    (numerically) zero, the result is 0 and so is its gradient.

    The "double where" pattern is used instead of `jax.lax.cond` for the
    following reason. With batch dimensions, `jnp.vectorize` maps this function
    with `vmap`, which lowers a `cond` with a batched predicate to a `select`.
    A `select` evaluates *both* branches, so the norm branch would still be
    differentiated at zero and inject NaN (0 * inf) into the gradient. Here a
    dummy non-zero vector is fed to the norm instead, keeping it finite.
    """
    assert x.ndim == 1

    is_zero = jnp.all(jnp.isclose(x, 0.0), axis=-1, keepdims=False)
    # For zero input, evaluate the norm on ones; that result is discarded below.
    x_safe = jnp.where(is_zero, jnp.ones_like(x), x)
    norm = jnp.linalg.norm(x_safe, keepdims=True)
    return jnp.where(is_zero, jnp.zeros_like(norm), norm)
|
37
|
+
|
38
|
+
|
39
|
+
@partial(jnp.vectorize, signature="(k)->(k)")
def safe_normalize(x):
    """Scale `x` to unit length along the last axis; zero input maps to zeros.

    Safe both to execute and to differentiate at x = 0.
    """
    assert x.ndim == 1

    is_zero = jnp.allclose(x, 0.0)

    def _zeros(v):
        return jnp.zeros_like(v)

    def _unit(v):
        # Guard the divisor again so this branch stays finite even when it is
        # evaluated for zero input (e.g. if `cond` is batched into a `select`).
        divisor = jnp.where(is_zero, 1.0, safe_norm(v))
        return v / divisor

    return jax.lax.cond(is_zero, _zeros, _unit, x)
|
51
|
+
|
52
|
+
|
53
|
+
@custom_jvp
def safe_arccos(x: jnp.ndarray) -> jnp.ndarray:
    """Element-wise inverse cosine whose derivative stays finite at |x| = 1.

    The value is plain `jnp.arccos(x)`. Only the JVP below clips `x` into
    [-1 + 1e-7, 1 - 1e-7], so that 1 / sqrt(1 - x**2) cannot blow up.
    """
    return jnp.arccos(x)


@safe_arccos.defjvp
def _safe_arccos_jvp(primal, tangent):
    # d/dx arccos(x) = -1 / sqrt(1 - x^2), evaluated at a clipped x.
    x, = primal
    x_dot, = tangent
    x_clipped = jnp.clip(x, -1 + 1e-7, 1 - 1e-7)
    slope_denominator = jnp.sqrt(1.0 - x_clipped ** 2.0)
    return safe_arccos(x), -x_dot / slope_denominator
|
66
|
+
|
67
|
+
|
68
|
+
@custom_jvp
def safe_arcsin(x: jnp.ndarray) -> jnp.ndarray:
    """Element-wise inverse sine whose derivative stays finite at |x| = 1.

    The value is plain `jnp.arcsin(x)`. Only the JVP below clips `x` into
    [-1 + 1e-7, 1 - 1e-7], so that 1 / sqrt(1 - x**2) cannot blow up.
    """
    return jnp.arcsin(x)


@safe_arcsin.defjvp
def _safe_arcsin_jvp(primal, tangent):
    # d/dx arcsin(x) = 1 / sqrt(1 - x^2), evaluated at a clipped x.
    (x,) = primal
    (x_dot,) = tangent
    # The primal output must be arcsin(x) (not arccos): under jvp/grad it is
    # used as the forward value by everything downstream.
    primal_out = safe_arcsin(x)
    tangent_out = x_dot / jnp.sqrt(1.0 - jnp.clip(x, -1 + 1e-7, 1 - 1e-7) ** 2.0)
    return primal_out, tangent_out
|
81
|
+
|
82
|
+
|
83
|
+
@partial(jnp.vectorize, signature="(4)->(4)")
def ensure_positive_w(q):
    """Return `q` or `-q`, whichever has a non-negative first (w) component.

    Quaternions with w == 0 are returned unchanged.
    """
    flip = q[0] < 0
    return jnp.where(flip, -q, q)
|
86
|
+
|
87
|
+
|
88
|
+
def angle_error(q, qhat):
    """Absolute rotation angle (radians) of the relative rotation q^-1 * qhat."""
    relative = quat_mul(quat_inv(q), qhat)
    return jnp.abs(quat_angle(relative))
|
91
|
+
|
92
|
+
|
93
|
+
def unit_quats_like(array):
    """Return identity *unit* quaternions ``[1, 0, 0, 0]`` shaped like `array`.

    Args:
        array: array whose last axis has size 4; only its shape is used.

    Returns:
        Array of shape ``array.shape`` filled with ``[1, 0, 0, 0]`` along the
        last axis.

    Raises:
        ValueError: if `array` is 0-d or its last axis does not have size 4.
            (`ValueError` subclasses `Exception`, so existing
            ``except Exception`` handlers still catch it.)
    """
    if not array.shape or array.shape[-1] != 4:
        raise ValueError(
            "Expected an array with a trailing quaternion axis of size 4, "
            f"got shape {array.shape}."
        )

    return jnp.ones(array.shape[:-1])[..., None] * jnp.array([1.0, 0, 0, 0])
|
99
|
+
|
100
|
+
|
101
|
+
@partial(jnp.vectorize, signature="(4),(4)->(4)")
def quat_mul(u: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
    """Hamilton product ``u * v`` of two quaternions stored as [w, x, y, z]."""
    uw, ux, uy, uz = u
    vw, vx, vy, vz = v
    w = uw * vw - ux * vx - uy * vy - uz * vz
    x = uw * vx + ux * vw + uy * vz - uz * vy
    y = uw * vy - ux * vz + uy * vw + uz * vx
    z = uw * vz + ux * vy - uy * vx + uz * vw
    return jnp.array([w, x, y, z])
|
113
|
+
|
114
|
+
|
115
|
+
def quat_inv(q: jnp.ndarray) -> jnp.ndarray:
    """Conjugate of `q` (vector part negated), which is its inverse for *unit* q."""
    conjugation_signs = jnp.array([1.0, -1, -1, -1])
    return q * conjugation_signs
|
118
|
+
|
119
|
+
|
120
|
+
@partial(jnp.vectorize, signature="(3),(4)->(3)")
def rotate(vector: jnp.ndarray, quat: jnp.ndarray):
    """Rotate the 3-vector `vector` by the *unit* quaternion `quat`.

    The vector is embedded as the pure quaternion [0, vector], conjugated by
    `quat` via `rotate_quat`, and the vector part of the result is returned.
    """
    pure_quat = jnp.array([0, *vector])
    rotated = rotate_quat(pure_quat, quat)
    return rotated[1:4]
|
125
|
+
|
126
|
+
|
127
|
+
def rotate_matrix(matrix: jax.Array, quat: jax.Array):
    """Rotate `matrix` by the *unit* quaternion `quat`: returns E @ matrix @ E^T.

    E is ``quat_to_3x3(quat)``. Only the last two axes of E are transposed.
    For a single quaternion this is identical to ``E.T``. For a batch of
    quaternions of shape (..., 4), E has shape (..., 3, 3), and ``E.T`` would
    also reverse the batch axes, giving a shape error or silently wrong
    results. Swapping only the last two axes yields a correctly batched
    (..., 3, 3) result, with `matrix` broadcast against it.
    """
    E = quat_to_3x3(quat)
    return E @ matrix @ jnp.swapaxes(E, -1, -2)
|
131
|
+
|
132
|
+
|
133
|
+
def rotate_quat(q: jax.Array, quat: jax.Array):
    """Conjugate quaternion `q` by `quat`, i.e. return quat * q * quat^-1."""
    right_product = quat_mul(q, quat_inv(quat))
    return quat_mul(quat, right_product)
|
136
|
+
|
137
|
+
|
138
|
+
@partial(jnp.vectorize, signature="(3),()->(4)")
def quat_rot_axis(axis: jnp.ndarray, angle: jnp.ndarray) -> jnp.ndarray:
    """Construct the *unit* quaternion for rotating about `axis` by `angle` (radians).

    `axis` is normalized with `safe_normalize` before use.

    This is the interpretation of rotating the vector and *not* the frame.
    For the interpretation of rotating the frame and *not* the vector, use
    angle -> -angle.
    NOTE: Usually we actually want the second interpretation. Quaternions are
    used to re-express vectors in other frames, while the vectors themselves
    stay the same; they are only transformed into a common frame.

    Implementation detail: the angle is negated internally (see CONVENTION
    below), so the result is [cos(-angle/2), sin(-angle/2) * unit_axis].
    """
    assert axis.shape == (3,)
    assert angle.shape == ()

    unit_axis = safe_normalize(axis)
    # NOTE: CONVENTION (23.04.23)
    # Negating the angle fixes prismatic joints being inverted w.r.t. the
    # gravity vector: it inverts how revolute joints behave, such that
    # prismatic joints work by inverting gravity.
    half_angle = (angle * -1.0) / 2
    s, c = jnp.sin(half_angle), jnp.cos(half_angle)
    return jnp.array([c, *(unit_axis * s)])
|
164
|
+
|
165
|
+
|
166
|
+
@partial(jnp.vectorize, signature="(3,3)->(4)")
def quat_from_3x3(m: jnp.ndarray) -> jnp.ndarray:
    """Convert a 3x3 rotation matrix into a *unit* quaternion [w, x, y, z].

    Uses Shepperd's method. The four quantities 4w^2, 4x^2, 4y^2 and 4z^2
    follow from the diagonal of `m`. The largest of them gives its component
    through a square root, and the other three components follow from
    off-diagonal sums and differences. This avoids the division by w ~ 0
    (NaN/inf) that the trace-only formula suffers for rotations near 180
    degrees.

    The result has w >= 0. When w is the largest component, the computation
    is the same as the trace-based formula, so those results are unchanged.
    """
    m00, m11, m22 = m[0, 0], m[1, 1], m[2, 2]
    # 4w^2, 4x^2, 4y^2, 4z^2 for a rotation matrix. They sum to 4, so the
    # largest one is >= 1 and its square root is well away from zero.
    four_sq = jnp.array(
        [
            1 + m00 + m11 + m22,
            1 + m00 - m11 - m22,
            1 - m00 + m11 - m22,
            1 - m00 - m11 + m22,
        ]
    )
    # Clamp before sqrt so unselected candidates stay finite (values and grads).
    r = jnp.sqrt(jnp.maximum(four_sq, 1e-12))
    denom = 2 * r  # equals 4 * (the component recovered from r)

    diff_x = m[2][1] - m[1][2]  # 4wx
    diff_y = m[0][2] - m[2][0]  # 4wy
    diff_z = m[1][0] - m[0][1]  # 4wz
    sum_xy = m[0][1] + m[1][0]  # 4xy
    sum_xz = m[0][2] + m[2][0]  # 4xz
    sum_yz = m[1][2] + m[2][1]  # 4yz

    candidates = jnp.array(
        [
            [r[0] / 2.0, diff_x / denom[0], diff_y / denom[0], diff_z / denom[0]],
            [diff_x / denom[1], r[1] / 2.0, sum_xy / denom[1], sum_xz / denom[1]],
            [diff_y / denom[2], sum_xy / denom[2], r[2] / 2.0, sum_yz / denom[2]],
            [diff_z / denom[3], sum_xz / denom[3], sum_yz / denom[3], r[3] / 2.0],
        ]
    )
    # argmax picks the first maximum, so ties prefer the w-based formula.
    q = candidates[jnp.argmax(four_sq)]
    # q and -q are the same rotation; keep the w >= 0 sign the formula always had.
    return jnp.where(q[0] < 0, -q, q)
|
174
|
+
|
175
|
+
|
176
|
+
@partial(jnp.vectorize, signature="(4)->(3,3)")
def quat_to_3x3(q: jnp.ndarray) -> jnp.ndarray:
    """Convert a quaternion [w, x, y, z] into a 3x3 rotation matrix.

    The quadratic terms are scaled by 2 / (q . q), so the input's norm cancels
    out of the result.
    """
    w, x, y, z = q
    scale = 2 / jnp.dot(q, q)
    xs, ys, zs = x * scale, y * scale, z * scale
    wx, wy, wz = w * xs, w * ys, w * zs
    xx, xy, xz = x * xs, x * ys, x * zs
    yy, yz, zz = y * ys, y * zs, z * zs

    first_row = jnp.array([1 - (yy + zz), xy - wz, xz + wy])
    second_row = jnp.array([xy + wz, 1 - (xx + zz), yz - wx])
    third_row = jnp.array([xz - wy, yz + wx, 1 - (xx + yy)])
    return jnp.stack([first_row, second_row, third_row])
|
194
|
+
|
195
|
+
|
196
|
+
def quat_random(
    key: jrand.PRNGKey, batch_shape: tuple = (), maxval: float = jnp.pi
) -> jax.Array:
    """Draw random *unit* quaternions of shape ``batch_shape + (4,)``.

    Standard-normal 4-vectors are normalized, which samples unit quaternions
    uniformly. If `maxval` is not pi, each rotation angle is rescaled
    linearly by ``maxval / pi``; the rotation axis is kept.
    """
    assert key.shape == (2,), f"{key.shape}"
    samples = jrand.normal(key, batch_shape + (4,))
    qs = safe_normalize(samples)

    def _full_range():
        return qs

    def _limited_range():
        axis, angle = quat_to_rot_axis(qs)
        return quat_rot_axis(axis, angle * maxval / jnp.pi)

    return jax.lax.cond(maxval == jnp.pi, _full_range, _limited_range)
|
210
|
+
|
211
|
+
|
212
|
+
def quat_euler(angles, intrinsic=True, convention="zyx"):
    """Construct a *unit* quaternion from three Euler angles (radians).

    Args:
        angles: array whose last axis holds the three angles; leading axes are
            vectorized over.
        intrinsic: with q_i the rotation about axis ``convention[i]`` by
            ``angles[..., i]``, return ``q3 * q2 * q1`` if True, otherwise
            ``q1 * q2 * q3``.
        convention: axis names, one of "x", "y" or "z" per angle.
    """

    @partial(jnp.vectorize, signature="(3)->(4)")
    def _single(angles):
        axis_by_name = {
            "x": jnp.array([1.0, 0.0, 0.0]),
            "y": jnp.array([0.0, 1.0, 0.0]),
            "z": jnp.array([0.0, 0.0, 1.0]),
        }
        q1, q2, q3 = (
            quat_rot_axis(axis_by_name[convention[i]], angles[i]) for i in range(3)
        )
        if intrinsic:
            return quat_mul(q3, quat_mul(q2, q1))
        return quat_mul(q1, quat_mul(q2, q3))

    return _single(angles)
|
237
|
+
|
238
|
+
|
239
|
+
@partial(jnp.vectorize, signature="(4)->()")
def quat_angle(q):
    """Rotation angle (radians) encoded by quaternion `q`, wrapped to [-pi, pi)."""
    vector_norm = safe_norm(q[1:])[0]
    return wrap_to_pi(2 * jnp.arctan2(vector_norm, q[0]))
|
244
|
+
|
245
|
+
|
246
|
+
def quat_angle_constantAxisOverTime(qs):
    """Signed rotation angles for a time series of quaternions.

    The normalized vector part of the first quaternion is the reference axis.
    A sample's angle from `quat_angle` is negated when that sample's unit axis
    is closer to the negated reference than to the reference itself.

    Args:
        qs: quaternions of shape (T, 4).

    Returns:
        Angles of shape (T,).
    """
    assert qs.ndim == 2
    assert qs.shape[-1] == 4

    def _l2(v):
        return jnp.sqrt(jnp.sum(v**2, axis=-1))

    axes = safe_normalize(qs[:, 1:])
    angles = quat_angle(qs)[:, None]
    reference = axes[0]
    flipped = (_l2(reference - axes) > _l2(reference + axes))[..., None]
    return jnp.where(flipped, -angles, angles)[:, 0]
|
257
|
+
|
258
|
+
|
259
|
+
@partial(jnp.vectorize, signature="(4)->(3),()")
def quat_to_rot_axis(q):
    """Split quaternion `q` into a unit rotation axis and an angle (radians).

    NOTE: CONVENTION -- the returned angle is the negated `quat_angle(q)`,
    mirroring the angle negation in `quat_rot_axis`.
    """
    axis = safe_normalize(q[1:])
    angle = quat_angle(q) * -1.0
    return axis, angle
|
267
|
+
|
268
|
+
|
269
|
+
@partial(jnp.vectorize, signature="(3)->(4)")
def euler_to_quat(angles: jnp.ndarray) -> jnp.ndarray:
    """Convert Euler angles (radians) into a quaternion.

    Follows the Tait-Bryan intrinsic rotation formalism x-y'-z''. The
    resulting quaternion is then passed through `quat_inv` (NOTE: CONVENTION).
    """
    half_angles = angles / 2
    c1, c2, c3 = jnp.cos(half_angles)
    s1, s2, s3 = jnp.sin(half_angles)
    textbook_quat = jnp.array(
        [
            c1 * c2 * c3 - s1 * s2 * s3,
            s1 * c2 * c3 + c1 * s2 * s3,
            c1 * s2 * c3 - s1 * c2 * s3,
            c1 * c2 * s3 + s1 * s2 * c3,
        ]
    )
    # NOTE: CONVENTION
    return quat_inv(textbook_quat)
|
281
|
+
|
282
|
+
|
283
|
+
@partial(jnp.vectorize, signature="(4)->(3)")
def quat_to_euler(q: jnp.ndarray) -> jnp.ndarray:
    """Convert a quaternion into Euler angles [x, y, z] (radians).

    Uses the same Tait-Bryan intrinsic x-y'-z'' formalism as `euler_to_quat`.
    The input is first passed through `quat_inv` (NOTE: CONVENTION).
    """
    # NOTE: CONVENTION
    qw, qx, qy, qz = quat_inv(q)

    angle_z = jnp.arctan2(
        -2 * qx * qy + 2 * qw * qz,
        qx * qx + qw * qw - qz * qz - qy * qy,
    )
    # TODO: Investigate why quaternions go so big we need to clip.
    angle_y = safe_arcsin(jnp.clip(2 * qx * qz + 2 * qw * qy, -1.0, 1.0))
    angle_x = jnp.arctan2(
        -2 * qy * qz + 2 * qw * qx,
        qz * qz - qy * qy - qx * qx + qw * qw,
    )

    return jnp.array([angle_x, angle_y, angle_z])
|
303
|
+
|
304
|
+
|
305
|
+
@partial(jnp.vectorize, signature="(4),(3)->(4),(4)")
def quat_project(q: jax.Array, k: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Decompose `q` into a primary rotation around axis `k` and a residual.

    The primary rotation is chosen such that the residual rotation's angle is
    minimized, and ``q == residual * primary``.

    Args:
        q (jax.Array): Quaternion to decompose.
        k (jax.Array): Primary axis direction.

    Returns:
        tuple[jax.Array, jax.Array]: Primary quaternion, residual quaternion.
    """
    primary_angle = 2 * jnp.arctan2(q[1:] @ k, q[0])
    # NOTE: CONVENTION -- `quat_rot_axis` negates its angle, hence the minus.
    primary = quat_rot_axis(k, -primary_angle)
    residual = quat_mul(q, quat_inv(primary))
    return primary, residual
|
322
|
+
|
323
|
+
|
324
|
+
def quat_avg(qs: jax.Array):
|
325
|
+
"Tolga Birdal's algorithm."
|
326
|
+
if qs.ndim == 1:
|
327
|
+
qs = qs[None, :]
|
328
|
+
assert qs.ndim == 2
|
329
|
+
return jnp.linalg.eigh(
|
330
|
+
jnp.einsum("ij,ik,i->...jk", qs, qs, jnp.ones((qs.shape[0],)))
|
331
|
+
)[1][:, -1]
|
332
|
+
|
333
|
+
|
334
|
+
# Resulting smoothing factor `alpha` for samp_freq=100.0:
#   cutoff_freq=20.0 -> alpha = 0.55686
#   cutoff_freq=15.0 -> alpha = 0.48519
def quat_lowpassfilter(
    qs: jax.Array,
    cutoff_freq: float = 20.0,
    samp_freq: float = 100.0,
    filtfilt: bool = False,
) -> jax.Array:
    """First-order low-pass filter for a time series of quaternions.

    At every step the filter state is rotated by a fraction `alpha` of the
    error rotation towards the next sample. Here
    alpha = w*Ts / (1 + w*Ts) with w*Ts = 2*pi*cutoff_freq/samp_freq.

    Args:
        qs: quaternions of shape (T, 4).
        cutoff_freq: cutoff frequency, in the same units as `samp_freq`.
        samp_freq: sampling frequency of `qs`.
        filtfilt: if True, run the filter forward over time and then again
            over the time-reversed result (filtfilt-style).

    Returns:
        Filtered quaternions of shape (T, 4), re-normalized to unit length.
        The first sample is passed through as the initial state.
    """
    assert qs.ndim == 2
    assert qs.shape[1] == 4

    if filtfilt:
        forward = quat_lowpassfilter(qs, cutoff_freq, samp_freq, filtfilt=False)
        backward = quat_lowpassfilter(
            jnp.flip(forward, 0), cutoff_freq, samp_freq, filtfilt=False
        )
        return jnp.flip(backward, 0)

    omega_ts = 2 * jnp.pi * cutoff_freq / samp_freq
    alpha = omega_ts / (1 + omega_ts)

    def step(state, target):
        # Error rotation that takes the current state onto the new sample.
        q_err = quat_mul(target, quat_inv(state))
        axis, angle = quat_to_rot_axis(q_err)
        # Canonicalize to a non-negative angle by flipping axis and angle.
        axis, angle = jax.lax.cond(
            angle < 0,
            lambda ax, an: (-ax, -an),
            lambda ax, an: (ax, an),
            axis,
            angle,
        )
        # Apply only the fraction `alpha` of the error rotation.
        err_fraction = quat_rot_axis(axis, angle * alpha)
        new_state = quat_mul(err_fraction, state)
        return new_state, new_state

    _, filtered = jax.lax.scan(step, qs[0], qs[1:])

    # Pad with the first sample so the output length equals the input length.
    filtered = jnp.vstack((qs[0:1], filtered))

    # Re-normalize: float32 round-off accumulates over the scan.
    return filtered / jnp.linalg.norm(filtered, axis=-1, keepdims=True)
|
381
|
+
|
382
|
+
|
383
|
+
def quat_inclinationAngle(q: jax.Array):
    """Angle (radians) of the residual of `q` after removing its z-axis rotation."""
    _, inclination = quat_project(q, jnp.array([0.0, 0, 1]))
    return quat_angle(inclination)
|
386
|
+
|
387
|
+
|
388
|
+
def quat_headingAngle(q: jax.Array):
    """Angle (radians) of the rotation component of `q` about the z-axis."""
    heading, _ = quat_project(q, jnp.array([0.0, 0, 1]))
    return quat_angle(heading)
|
391
|
+
|
392
|
+
|
393
|
+
def quat_transfer_heading(q_from: jax.Array, q_to: jax.Array):
    """Give `q_to` the heading (z-axis rotation component) of `q_from`.

    `q_to` is reduced to its heading-free residual, which is then composed
    with the heading of `q_from` as ``residual * heading``.
    """
    z_axis = jnp.array([0.0, 0, 1])
    heading_from = quat_project(q_from, z_axis)[0]
    # set heading to zero in the `q_to` quaternions
    q_to_without_heading = quat_project(q_to, z_axis)[1]
    return quat_mul(q_to_without_heading, heading_from)
|
ring/ml/__init__.py
ADDED
@@ -0,0 +1,33 @@
|
|
1
|
+
from . import base
|
2
|
+
from . import callbacks
|
3
|
+
from . import ml_utils
|
4
|
+
from . import optimizer
|
5
|
+
from . import ringnet
|
6
|
+
from . import train
|
7
|
+
from . import training_loop
|
8
|
+
from .base import AbstractFilter
|
9
|
+
from .ml_utils import on_cluster
|
10
|
+
from .ml_utils import unique_id
|
11
|
+
from .optimizer import make_optimizer
|
12
|
+
from .ringnet import RING
|
13
|
+
from .train import train_fn
|
14
|
+
|
15
|
+
|
16
|
+
def RING_ICML24(params=None, **kwargs):
    """Create the RING network used in the icml24 paper.

    X[..., :3] = acc
    X[..., 3:6] = gyr
    X[..., 6:9] = jointaxis
    X[..., 9:] = dt

    Args:
        params: parameters handed to `RING`. Defaults to the bundled file
            ``params/0x13e3518065c21cd8.pickle`` next to this module.
        **kwargs: forwarded to `RING`.

    Returns:
        The RING network wrapped, innermost first, by `ScaleX_FilterWrapper`,
        `LPF_FilterWrapper` (10.0, ``samp_freq=None``) and
        `GroundTruthHeading_FilterWrapper`.
    """
    from pathlib import Path

    if params is None:
        params = Path(__file__).parent.joinpath("params/0x13e3518065c21cd8.pickle")

    ringnet = RING(params=params, **kwargs)  # noqa: F811
    wrappers = (
        base.ScaleX_FilterWrapper,
        lambda net: base.LPF_FilterWrapper(net, 10.0, samp_freq=None),
        base.GroundTruthHeading_FilterWrapper,
    )
    for wrap in wrappers:
        ringnet = wrap(ringnet)
    return ringnet
|