imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,202 @@
|
|
1
|
+
from typing import Callable, Optional, Tuple
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import jaxopt
|
6
|
+
from jaxopt._src.base import Solver
|
7
|
+
from ring import algebra
|
8
|
+
from ring import base
|
9
|
+
from ring import maths
|
10
|
+
from ring.algorithms import jcalc
|
11
|
+
|
12
|
+
|
13
|
+
def forward_kinematics_transforms(
    sys: base.System, q: jax.Array
) -> Tuple[base.Transform, base.System]:
    """Run forward kinematics over all links of `sys` for coordinates `q`.

    Args:
        sys: System whose links are traversed with `sys.scan`.
        q: Minimal coordinates of the whole system.

    Returns:
        A pair `(transforms, sys)`:
        - Transforms from base to links; the leading axis is (n_links,).
        - A copy of `sys` whose links carry the recomputed `transform2`
          (joint transform) and `transform` (joint transform composed with
          the link's `transform1`) fields.
    """
    # Base-to-link transforms keyed by link index; index -1 is the base itself.
    base_to_link = {-1: base.Transform.zero()}

    def _visit(_, __, q_link, link, idx, parent_idx, joint_type: str):
        joint_trafo = jcalc.jcalc_transform(joint_type, q_link, link.joint_params)
        parent_to_link = algebra.transform_mul(joint_trafo, link.transform1)
        new_link = link.replace(transform=parent_to_link, transform2=joint_trafo)
        # Relies on the parent's entry already existing, i.e. sys.scan is
        # presumed to visit a parent before its children.
        base_to_link[idx] = algebra.transform_mul(
            parent_to_link, base_to_link[parent_idx]
        )
        return base_to_link[idx], new_link

    scanned = (
        q,
        sys.links,
        list(range(sys.num_links())),
        sys.link_parents,
        sys.link_types,
    )
    transforms, new_links = sys.scan(_visit, "qllll", *scanned)
    return transforms, sys.replace(links=new_links)
|
43
|
+
|
44
|
+
|
45
|
+
def forward_kinematics(
    sys: base.System, state: base.State
) -> Tuple[base.System, base.State]:
    """Run forward kinematics on `state.q` and propagate the results.

    Returns:
        `(sys, state)` where the links of `sys` have refreshed `transform` and
        `transform2` fields and `state.x` holds the base-to-link transforms.
    """
    transforms, updated_sys = forward_kinematics_transforms(sys, state.q)
    return updated_sys, state.replace(x=transforms)
|
55
|
+
|
56
|
+
|
57
|
+
def inverse_kinematics(
    sys: base.System,
    state: base.State,
) -> base.State:
    """Recover the minimal coordinates `q` from the link transforms `state.x`.

    For every link, the joint transform (`transform2`) is rebuilt from the
    parent's and the link's base-to-link transforms. It is then mapped back to
    joint coordinates by the joint model's `inv_kin` function.

    Returns:
        A copy of `state` whose `q` is the recovered coordinate vector.

    Raises:
        NotImplementedError: If a link's joint model provides no `inv_kin`.
    """
    transforms = state.x
    q_parts = []

    def _invert_link(
        _,
        __,
        idx: int,
        x_link: base.Transform,
        link: base.Link,
        parent: int,
        joint_type: str,
    ):
        x_parent = base.Transform.zero() if parent == -1 else transforms[parent]
        params = jcalc._limit_scope_of_joint_params(joint_type, link.joint_params)
        parent_to_link = algebra.transform_mul(
            x_link, algebra.transform_inv(x_parent)
        )
        # Remove the fixed `transform1` to isolate the joint's own transform.
        joint_trafo = algebra.transform_mul(
            parent_to_link, algebra.transform_inv(link.transform1)
        )
        inv_kin = jcalc.get_joint_model(joint_type).inv_kin
        if inv_kin is None:
            raise NotImplementedError(
                f"Please specify for the custom joint `{joint_type}`"
                " the JointModel.inv_kin field."
            )
        q_parts.append(inv_kin(joint_trafo, params))

    link_indices = list(range(sys.num_links()))
    sys.scan(
        _invert_link,
        "lllll",
        link_indices,
        transforms,
        sys.links,
        sys.link_parents,
        sys.link_types,
    )

    q_full = jnp.concatenate(q_parts)
    assert q_full.ndim == 1
    assert q_full.size == sys.q_size()
    return state.replace(q=q_full)
|
98
|
+
|
99
|
+
|
100
|
+
def inverse_kinematics_endeffector(
    sys: base.System,
    endeffector_link_name: str,
    endeffector_x: base.Transform,
    error_weight_rot: float = 1.0,
    error_weight_pos: float = 1.0,
    q0: Optional[jax.Array] = None,
    random_q0_starts: Optional[int] = None,
    key: Optional[jax.Array] = None,
    custom_joints: Optional[dict[str, Callable[[jax.Array], jax.Array]]] = None,
    jaxopt_solver: type[Solver] = jaxopt.LBFGS,
    **jaxopt_solver_kwargs,
) -> tuple[jax.Array, jax.Array, jaxopt.OptStep]:
    """Find the minimal coordinates (joint configuration) such that the endeffector
    reaches a desired rotational and positional configuration / state.

    The minimized objective is
    `error_weight_rot * angle_error + error_weight_pos * ||pos_error||_2`.

    Args:
        sys (base.System): System under consideration.
        endeffector_link_name (str): Link in system which must reach a desired
            pos & rot state.
        endeffector_x (base.Transform): Desired position and rotation state
            values. Must be unbatched; use `jax.vmap` for batching.
        error_weight_rot (float, optional): Weight of the rotational error term
            in the optimized loss. Defaults to 1.0.
        error_weight_pos (float, optional): Weight of the positional error term
            in the optimized loss. Defaults to 1.0.
        q0 (Optional[jax.Array], optional): Initial minimal coordinates guess.
            Defaults to None, which uses the coordinates of
            `base.State.create(sys)`.
        random_q0_starts (Optional[int], optional): Number of random
            (standard-normal) initial values to try; the best result is
            returned. Mutually exclusive with `q0`. Defaults to None.
        key (Optional[jax.Array], optional): PRNGKey, required whenever
            `random_q0_starts` is given. Defaults to None.
        custom_joints (Optional[dict[str, Callable[[jax.Array], jax.Array]]],
            optional): Dictionary that contains for each custom joint type a
            function that maps from [-inf, inf] -> feasible joint value range.
            For example, by default a hinge joint uses `maths.wrap_to_pi`.
            Defaults to None (treated as an empty dict).
        jaxopt_solver (type[Solver], optional): Solver class to use.
            Defaults to jaxopt.LBFGS.
        **jaxopt_solver_kwargs: Forwarded to the `jaxopt_solver` constructor.

    Raises:
        NotImplementedError: Specific joint has no preprocess function given in
            `custom_joints`; but this is required.

    Returns:
        tuple[jax.Array, jax.Array, jaxopt.OptStep]:
            Minimal coordinates solution, residual loss at that solution, and
            optimizer results. With `random_q0_starts`, the optimizer results
            are those of all (vmapped) starts.
    """
    assert endeffector_x.ndim() == 1, "Use `jax.vmap` for batching"

    # Avoid a shared mutable default; an empty dict means "no custom joints".
    if custom_joints is None:
        custom_joints = {}

    if random_q0_starts is not None:
        assert q0 is None, "Either provide `q0` or `random_q0_starts`."
        assert key is not None, "`random_q0_starts` requires `key`"

        q0s = jax.random.normal(key, shape=(random_q0_starts, sys.q_size()))
        qs, values, results = jax.vmap(
            lambda q0_i: inverse_kinematics_endeffector(
                sys,
                endeffector_link_name,
                endeffector_x,
                error_weight_rot,
                error_weight_pos,
                q0_i,
                None,
                None,
                custom_joints,
                jaxopt_solver,
                **jaxopt_solver_kwargs,
            )
        )(q0s)

        # find result of best q0 initial value
        best_q_index = jnp.argmin(values)
        best_q, best_q_value = jax.tree_util.tree_map(
            lambda arr: jax.lax.dynamic_index_in_dim(
                arr, best_q_index, keepdims=False
            ),
            (qs, values),
        )
        return best_q, best_q_value, results

    if q0 is None:
        q0 = base.State.create(sys).q
    else:
        assert len(q0) == sys.q_size()

    # Resolve the link index once instead of on every objective evaluation.
    endeffector_idx = sys.name_to_idx(endeffector_link_name)

    def objective(q: jax.Array) -> jax.Array:
        q = sys.coordinate_vector_to_q(q, custom_joints)
        xhat = forward_kinematics_transforms(sys, q)[0][endeffector_idx]
        error_rot = maths.angle_error(endeffector_x.rot, xhat.rot)
        sq_dist = jnp.sum((endeffector_x.pos - xhat.pos) ** 2)
        # d/dx sqrt(x) is infinite at x == 0, which yields NaN gradients once
        # the endeffector exactly reaches the target position. The
        # double-`where` keeps the value identical (sqrt(0) == 0) but gives a
        # zero gradient at that point.
        nonzero = sq_dist > 0.0
        error_pos = jnp.where(
            nonzero, jnp.sqrt(jnp.where(nonzero, sq_dist, 1.0)), 0.0
        )
        return error_weight_rot * error_rot + error_weight_pos * error_pos

    solver = jaxopt_solver(objective, **jaxopt_solver_kwargs)
    results = solver.run(q0)
    q_sol = sys.coordinate_vector_to_q(results.params, custom_joints)
    # stop gradients such that this value can be used for optimizing e.g.
    # parameters in the system object, such as sys.links.joint_params
    q_sol_value = objective(jax.lax.stop_gradient(results.params))
    return q_sol, q_sol_value, results
|