imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
ring/algorithms/jcalc.py
ADDED
@@ -0,0 +1,840 @@
|
|
1
|
+
from dataclasses import asdict
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from dataclasses import field
|
4
|
+
from dataclasses import replace
|
5
|
+
from typing import Any, Callable, get_type_hints, Optional
|
6
|
+
|
7
|
+
import jax
|
8
|
+
import jax.numpy as jnp
|
9
|
+
import tree_utils
|
10
|
+
|
11
|
+
from ring import algebra
|
12
|
+
from ring import base
|
13
|
+
from ring import maths
|
14
|
+
from ring.algorithms import _random
|
15
|
+
from ring.algorithms._random import _to_float
|
16
|
+
from ring.algorithms._random import TimeDependentFloat
|
17
|
+
|
18
|
+
|
19
|
+
@dataclass
class MotionConfig:
    """Parameters of the random chain motion generator.

    Fields annotated ``float | TimeDependentFloat`` may also be given as a
    function of time; `join_motionconfigs` relies on exactly this annotation
    to decide which fields are allowed to vary over time.
    """

    # total length of the random motion
    T: float = 60.0
    # minimum / maximum time between two consecutively drawn values
    t_min: float = 0.05
    t_max: float | TimeDependentFloat = 0.30

    # angular velocity limits in rad/s (used unless the `*_free_spherical`
    # variants below apply)
    dang_min: float | TimeDependentFloat = 0.1
    dang_max: float | TimeDependentFloat = 3.0

    # angular velocity limits of the euler angles drawn for `free` and
    # `spherical` joints
    dang_min_free_spherical: float | TimeDependentFloat = 0.1
    dang_max_free_spherical: float | TimeDependentFloat = 3.0

    # minimum / maximum allowed change of angle between two drawn values,
    # in radians
    delta_ang_min: float | TimeDependentFloat = 0.0
    delta_ang_max: float | TimeDependentFloat = 2 * jnp.pi
    delta_ang_min_free_spherical: float | TimeDependentFloat = 0.0
    delta_ang_max_free_spherical: float | TimeDependentFloat = 2 * jnp.pi

    # translational speed limits and position range
    dpos_min: float | TimeDependentFloat = 0.001
    dpos_max: float | TimeDependentFloat = 0.7
    pos_min: float | TimeDependentFloat = -2.5
    pos_max: float | TimeDependentFloat = +2.5

    # Number of cdf bins; shared by angle and position drawing and only
    # relevant when the matching `randomized_interpolation_*` flag is set.
    cdf_bins_min: int = 5
    # `None` -> same as `cdf_bins_min`
    cdf_bins_max: Optional[int] = None

    # behaviour flags
    randomized_interpolation_angle: bool = False
    randomized_interpolation_position: bool = False
    interpolation_method: str = "cosine"
    range_of_motion_hinge: bool = True
    range_of_motion_hinge_method: str = "uniform"

    # ranges from which the initial joint values are drawn
    ang0_min: float = -jnp.pi
    ang0_max: float = jnp.pi
    pos0_min: float = 0.0
    pos0_max: float = 0.0

    # settings of the `cor` (center of rotation) joint's translational part
    cor: bool = False
    cor_t_min: float = 0.2
    cor_t_max: float | TimeDependentFloat = 2.0
    cor_dpos_min: float | TimeDependentFloat = 0.00001
    cor_dpos_max: float | TimeDependentFloat = 0.5
    cor_pos_min: float | TimeDependentFloat = -0.4
    cor_pos_max: float | TimeDependentFloat = 0.4

    def is_feasible(self) -> bool:
        """Return whether the angular limits of this config are compatible."""
        return _is_feasible_config1(self)

    def to_nomotion_config(self) -> "MotionConfig":
        """Return a copy with velocity limits and minimum angle changes zeroed.

        Raises:
            AssertionError: if the resulting config is not feasible.
        """
        zeroed = dict.fromkeys(
            (
                "dang_min",
                "dang_max",
                "delta_ang_min",
                "dang_min_free_spherical",
                "dang_max_free_spherical",
                "delta_ang_min_free_spherical",
                "dpos_min",
                "dpos_max",
            ),
            0.0,
        )
        nomotion_config = MotionConfig(**{**asdict(self), **zeroed})
        assert nomotion_config.is_feasible()
        return nomotion_config
|
90
|
+
|
91
|
+
|
92
|
+
def _is_feasible_config1(c: MotionConfig) -> bool:
    """Check that the angular limits of `c` can be satisfied simultaneously.

    The check runs separately for the regular limits and for the
    `*_free_spherical` limits. Each group is infeasible if either:

    - the largest allowed angle change covered in the shortest interval
      (`t_min`) is still slower than the minimum angular velocity, or
    - the smallest allowed angle change covered in the longest interval
      (`t_max`) is still faster than the maximum angular velocity.

    Time-dependent values are evaluated at t=0.
    """
    t_min = c.t_min
    t_max = _to_float(c.t_max, 0.0)

    def _limits_compatible(*limits) -> bool:
        dx_min, dx_max, deltax_min, deltax_max = (
            _to_float(limit, 0.0) for limit in limits
        )
        if deltax_max / t_min < dx_min:
            return False
        return not (deltax_min / t_max > dx_max)

    regular = (c.dang_min, c.dang_max, c.delta_ang_min, c.delta_ang_max)
    free_spherical = (
        c.dang_min_free_spherical,
        c.dang_max_free_spherical,
        c.delta_ang_min_free_spherical,
        c.delta_ang_max_free_spherical,
    )
    # Both groups are always evaluated (no short-circuit between them).
    results = [_limits_compatible(*regular), _limits_compatible(*free_spherical)]
    return all(results)
|
116
|
+
|
117
|
+
|
118
|
+
def _find_interval(t: jax.Array, boundaries: jax.Array):
    """Return the index of the interval of `boundaries` in which `t` lies.

    The index is the number of boundaries that are ``<= t``, so it ranges
    from 0 (before the first boundary) to ``len(boundaries)``.

    Args:
        t: Scalar float (e.g. time)
        boundaries: 1D array of floats

    Example: (from `test_jcalc.py`)
        >> _find_interval(1.5, jnp.array([0.0, 1.0, 2.0])) -> 2
        >> _find_interval(0.5, jnp.array([0.0])) -> 1
        >> _find_interval(-0.5, jnp.array([0.0])) -> 0
    """
    assert boundaries.ndim == 1

    reached = jax.vmap(lambda boundary: jnp.where(t >= boundary, 1, 0))(boundaries)
    return jnp.sum(reached)
|
137
|
+
|
138
|
+
|
139
|
+
def join_motionconfigs(
    configs: list[MotionConfig], boundaries: list[float]
) -> MotionConfig:
    """Join several motion configs into one whose parameters switch over time.

    ``configs[i]`` is active for times ``t`` with exactly ``i`` boundaries
    ``<= t`` (see `_find_interval`): the first config before
    ``boundaries[0]`` and the last one from ``boundaries[-1]`` on. This
    presumes `boundaries` is sorted increasingly.

    Only fields annotated ``float | TimeDependentFloat`` may differ between
    the configs; each of them is replaced by a callable ``t -> value``. Every
    other field must have one unique value across all configs.

    Args:
        configs: Configs to join; ``len(configs) == len(boundaries) + 1``.
        boundaries: Switching times.

    Returns:
        A copy of ``configs[0]`` with its time-dependent fields replaced.

    Raises:
        AssertionError: on a length mismatch, or if a time-independent field
            differs between the configs.
    """
    assert len(configs) == (
        len(boundaries) + 1
    ), "length of `boundaries` should be one less than length of `configs`"
    boundaries = jnp.array(boundaries, dtype=float)

    def time_dependent_value(name: str):
        # One candidate value per config, selected by the interval of `t`.
        options = jnp.array([getattr(c, name) for c in configs])

        def value_at(t):
            return jax.lax.dynamic_index_in_dim(
                options, _find_interval(t, boundaries), keepdims=False
            )

        return value_at

    # The class' resolved annotations, keyed by field name in field order.
    hints = get_type_hints(MotionConfig)

    def is_time_dependent(name: str) -> bool:
        return hints[name] == (float | TimeDependentFloat)

    time_dependent_fields = [name for name in hints if is_time_dependent(name)]
    time_independent_fields = [name for name in hints if not is_time_dependent(name)]

    for name in time_independent_fields:
        unique_values = set([getattr(config, name) for config in configs])
        assert (
            len(unique_values) == 1
        ), f"MotionConfig.{name}={unique_values}. Should be one unique value.."

    changes = {name: time_dependent_value(name) for name in time_dependent_fields}
    return replace(configs[0], **changes)
|
171
|
+
|
172
|
+
|
173
|
+
# Random-motion sampler of one joint (the `rcmg_draw_fn` of a JointModel).
# Signature: (config, key_t, key_value, dt, joint_params) -> q trajectory
DRAW_FN = Callable[
    [MotionConfig, jax.random.PRNGKey, jax.random.PRNGKey, float, jax.Array],
    jax.Array,
]

# Proportional term of the PD controller.
# Signature: (q, q_ref) -> term; shapes (q_size,), (q_size,) -> (qd_size,)
P_CONTROL_TERM = Callable[
    [jax.Array, jax.Array],
    jax.Array,
]

# Derives a velocity reference trajectory from a reference trajectory q.
# Both are needed by the PD control, which in turn is only needed when the
# simulation is dynamic rather than kinematic.
# Signature: (qs, dt) -> dqs; shapes (N, q_size), (1,) -> (N, qd_size)
QD_FROM_Q = Callable[
    [jax.Array, jax.Array],
    jax.Array,
]

# Maps an unconstrained coordinate vector in [-inf, inf] onto the joint's
# feasible value range (used by ring.algorithms.inverse_kinematics_endeffector).
# E.g. a hinge joint uses `maths.wrap_to_pi` by default, and a spherical joint
# normalizes to a unit quaternion.
# Signature: (q_size,) -> (q_size,)
COORDINATE_VECTOR_TO_Q = Callable[
    [jax.Array],
    jax.Array,
]

# Projects a transform onto the subspace the joint can actually realize and
# returns the projected transform; only used by `sim2real.project_xs`.
# Signature: (base.Transform, joint_params pytree) -> base.Transform
PROJECT_TRANSFORM_TO_FEASIBLE = Callable[
    [base.Transform, tree_utils.PyTree],
    base.Transform,
]

# Initializes the joint parameters of a custom joint from a PRNG key; used by
# ring.System.from_xml and ring.RCMG. If a joint provides none, it receives
# the default parameters, i.e. joint_params['default'].
# Signature: (key) -> pytree
INIT_JOINT_PARAMS = Callable[[jax.Array], tree_utils.PyTree]

# Inverse kinematics: (transform2_p_to_i, joint_params) -> q of shape (q_size,)
INV_KIN = Callable[[base.Transform, tree_utils.PyTree], jax.Array]
|
221
|
+
|
222
|
+
|
223
|
+
@dataclass
class JointModel:
    """Bundle of the functions that define one joint type.

    Only `transform` is required; every other hook is optional and consumed
    by the algorithm named in its comment.
    """

    # (q, joint_params) -> Transform
    transform: Callable[[jax.Array, jax.Array], base.Transform]
    # One entry per velocity coordinate, so len(motion) == len(qd). An entry is
    # either a fixed base.Motion or a callable joint_params -> base.Motion.
    motion: list[base.Motion | Callable[[jax.Array], base.Motion]] = field(
        default_factory=list
    )
    # (config, key_t, key_value, dt, joint_params) -> jax.Array
    rcmg_draw_fn: Optional[DRAW_FN] = None

    # both only used by `pd_control`
    p_control_term: Optional[P_CONTROL_TERM] = None
    qd_from_q: Optional[QD_FROM_Q] = None

    # used by `inverse_kinematics_endeffector` and System.coordinate_vector_to_q
    coordinate_vector_to_q: Optional[COORDINATE_VECTOR_TO_Q] = None

    # only used by `inverse_kinematics`
    inv_kin: Optional[INV_KIN] = None

    # see `_init_joint_params`; None -> joint uses joint_params['default']
    init_joint_params: Optional[INIT_JOINT_PARAMS] = None

    # free-form extra objects attached to the joint type
    utilities: Optional[dict[str, Any]] = field(default_factory=dict)
|
250
|
+
|
251
|
+
|
252
|
+
def _free_transform(q, _):
    """Free joint: q[:4] is the rotation, q[4:] the translation."""
    return base.Transform(q[4:], q[:4])


def _free_2d_transform(q, _):
    """Planar joint: q[0] rotates about x, q[1:] translates along y and z."""
    rot = maths.quat_rot_axis(maths.x_unit_vector, q[0])
    pos = jnp.concatenate((jnp.array([0.0]), q[1:]))
    return base.Transform(pos, rot)


def _rxyz_transform(q, _, axis):
    """Hinge joint: rotation by the scalar angle q about `axis`."""
    return base.Transform.create(rot=maths.quat_rot_axis(axis, jnp.squeeze(q)))


def _pxyz_transform(q, _, direction):
    """Prismatic joint: translation by q along `direction`."""
    return base.Transform.create(pos=direction * q)


def _frozen_transform(_, __):
    """Frozen joint: always the zero transform."""
    return base.Transform.zero()


def _spherical_transform(q, _):
    """Spherical joint: q is used directly as the rotation."""
    return base.Transform.create(rot=q)


def _saddle_transform(q, _):
    """Saddle joint: euler angles (0, q[0], q[1]) converted to a quaternion."""
    euler_angles = jnp.array([0.0, q[0], q[1]])
    return base.Transform.create(rot=maths.euler_to_quat(euler_angles))


def _p3d_transform(q, _):
    """3D prismatic joint: q is used directly as the translation."""
    return base.Transform.create(pos=q)


def _cor_transform(q, _):
    """`cor` joint: a free joint (q[:7]) composed with a p3d joint (q[7:])."""
    x_free = _free_transform(q[:7], _)
    x_p3d = _p3d_transform(q[7:], _)
    return algebra.transform_mul(x_p3d, x_free)
|
296
|
+
|
297
|
+
|
298
|
+
# Unit motion subspaces: `mr*` rotate about, `mp*` translate along the
# x / y / z axis.
mrx = base.Motion.create(ang=jnp.array([1.0, 0.0, 0.0]))
mry = base.Motion.create(ang=jnp.array([0.0, 1.0, 0.0]))
mrz = base.Motion.create(ang=jnp.array([0.0, 0.0, 1.0]))
mpx = base.Motion.create(vel=jnp.array([1.0, 0.0, 0.0]))
mpy = base.Motion.create(vel=jnp.array([0.0, 1.0, 0.0]))
mpz = base.Motion.create(vel=jnp.array([0.0, 0.0, 1.0]))
|
304
|
+
|
305
|
+
|
306
|
+
def _draw_rxyz(
    config: MotionConfig,
    key_t: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    _: jax.Array,
    # TODO, delete these args and pass a modified `config` with `replace` instead
    enable_range_of_motion: bool = True,
    free_spherical: bool = False,
) -> jax.Array:
    """Draw a random angle trajectory for a single rotational coordinate.

    Args:
        config: Motion parameters.
        key_t: PRNG key for the timing of the drawn angles.
        key_value: PRNG key for the initial angle and the angle values.
        dt: Sampling time.
        _: Joint parameters (unused).
        enable_range_of_motion: If False, `config.range_of_motion_hinge` is
            ignored and treated as False.
        free_spherical: Use the `*_free_spherical` angular limits.

    Returns:
        The trajectory produced by `_random.random_angle_over_time`.
    """
    key_value, key_ang0 = jax.random.split(key_value)
    # `random_angle_over_time` always returns wrapped angles, so the initial
    # value is wrapped too, for consistency.
    ang0 = maths.wrap_to_pi(
        jax.random.uniform(key_ang0, minval=config.ang0_min, maxval=config.ang0_max)
    )

    if free_spherical:
        dang_min = config.dang_min_free_spherical
        dang_max = config.dang_max_free_spherical
        delta_ang_min = config.delta_ang_min_free_spherical
        delta_ang_max = config.delta_ang_max_free_spherical
    else:
        dang_min = config.dang_min
        dang_max = config.dang_max
        delta_ang_min = config.delta_ang_min
        delta_ang_max = config.delta_ang_max

    range_of_motion = config.range_of_motion_hinge if enable_range_of_motion else False
    # only relevant for the `delta_ang_min/max` logic
    max_iter = 5

    return _random.random_angle_over_time(
        key_t,
        key_value,
        ang0,
        dang_min,
        dang_max,
        delta_ang_min,
        delta_ang_max,
        config.t_min,
        config.t_max,
        config.T,
        dt,
        max_iter,
        config.randomized_interpolation_angle,
        range_of_motion,
        config.range_of_motion_hinge_method,
        config.cdf_bins_min,
        config.cdf_bins_max,
        config.interpolation_method,
    )
|
343
|
+
|
344
|
+
|
345
|
+
def _draw_pxyz(
    config: MotionConfig,
    _: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    __: jax.Array,
    cor: bool = False,
) -> jax.Array:
    """Draw a random position trajectory for a single translational coordinate.

    Args:
        config: Motion parameters.
        _: Unused time key (all randomness comes from `key_value`).
        key_value: PRNG key for the initial position and the trajectory.
        dt: Sampling time.
        __: Joint parameters (unused).
        cor: Use the `cor_*` limits instead of the regular ones.

    Returns:
        The trajectory produced by `_random.random_position_over_time`.
    """
    key_value, key_pos0 = jax.random.split(key_value)
    pos0 = jax.random.uniform(
        key_pos0, minval=config.pos0_min, maxval=config.pos0_max
    )

    if cor:
        pos_min, pos_max = config.cor_pos_min, config.cor_pos_max
        dpos_min, dpos_max = config.cor_dpos_min, config.cor_dpos_max
        t_min, t_max = config.cor_t_min, config.cor_t_max
    else:
        pos_min, pos_max = config.pos_min, config.pos_max
        dpos_min, dpos_max = config.dpos_min, config.dpos_max
        t_min, t_max = config.t_min, config.t_max

    max_iter = 100
    return _random.random_position_over_time(
        key_value,
        pos0,
        pos_min,
        pos_max,
        dpos_min,
        dpos_max,
        t_min,
        t_max,
        config.T,
        dt,
        max_iter,
        config.randomized_interpolation_position,
        config.cdf_bins_min,
        config.cdf_bins_max,
        config.interpolation_method,
    )
|
373
|
+
|
374
|
+
|
375
|
+
def _draw_spherical(
    config: MotionConfig,
    key_t: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    _: jax.Array,
) -> jax.Array:
    """Draw a random quaternion trajectory for a spherical joint.

    Three independent euler-angle trajectories are drawn, using the
    free/spherical limits and no range of motion, and converted with
    `maths.quat_euler`. Not ideal, but no better approach is known to the
    author.
    """

    def _one_euler_angle(k_t, k_value):
        return _draw_rxyz(
            config,
            k_t,
            k_value,
            dt,
            None,
            enable_range_of_motion=False,
            free_spherical=True,
        )

    keys_t = jax.random.split(key_t, 3)
    keys_value = jax.random.split(key_value, 3)
    # vmap yields one row per angle; transpose to (timesteps, 3)
    euler_angles = jax.vmap(_one_euler_angle)(keys_t, keys_value).T
    return maths.quat_euler(euler_angles)
|
400
|
+
|
401
|
+
|
402
|
+
def _draw_saddle(
    config: MotionConfig,
    key_t: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    _: jax.Array,
) -> jax.Array:
    """Draw two independent angle trajectories (the y and z euler angles).

    Uses the regular (non free/spherical) angular limits and no range of
    motion. Returns an array with one column per angle.
    """

    def _one_angle(k_t, k_value):
        return _draw_rxyz(
            config,
            k_t,
            k_value,
            dt,
            None,
            enable_range_of_motion=False,
            free_spherical=False,
        )

    keys_t = jax.random.split(key_t)
    keys_value = jax.random.split(key_value)
    return jax.vmap(_one_angle)(keys_t, keys_value).T
|
424
|
+
|
425
|
+
|
426
|
+
def _draw_p3d_and_cor(
    config: MotionConfig,
    _: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    __: jax.Array,
    cor: bool,
) -> jax.Array:
    """Draw three independent position trajectories (x, y, z).

    `cor` selects the `cor_*` limits of `MotionConfig`. Returns an array
    with one column per axis.
    """

    def _one_axis(key):
        return _draw_pxyz(config, None, key, dt, None, cor)

    per_axis = jax.vmap(_one_axis)(jax.random.split(key_value, 3))
    return per_axis.T


def _draw_p3d(
    config: MotionConfig,
    _: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    __: jax.Array,
) -> jax.Array:
    """Draw a 3D position trajectory with the regular (non-cor) limits."""
    return _draw_p3d_and_cor(config, _, key_value, dt, None, cor=False)
|
448
|
+
|
449
|
+
|
450
|
+
def _draw_cor(
    config: MotionConfig,
    _: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    __: jax.Array,
) -> jax.Array:
    """Draw a `cor` trajectory: free-joint columns followed by cor-p3d columns."""
    key_free, key_p3d = jax.random.split(key_value)
    q_free = _draw_free(config, _, key_free, dt, None)
    q_p3d = _draw_p3d_and_cor(config, _, key_p3d, dt, None, cor=True)
    return jnp.concatenate((q_free, q_p3d), axis=1)


def _draw_free(
    config: MotionConfig,
    key_t: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    __: jax.Array,
) -> jax.Array:
    """Draw a free-joint trajectory: quaternion columns, then position columns."""
    key_rot, key_pos = jax.random.split(key_value)
    rot = _draw_spherical(config, key_t, key_rot, dt, None)
    pos = _draw_p3d(config, None, key_pos, dt, None)
    return jnp.concatenate((rot, pos), axis=1)
|
474
|
+
|
475
|
+
|
476
|
+
def _draw_free_2d(
    config: MotionConfig,
    key_t: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    __: jax.Array,
) -> jax.Array:
    """Draw a planar trajectory: columns are (angle_x, pos_y, pos_z).

    The angle uses the free/spherical limits without range of motion; the
    two positions are the first two columns of a 3D position draw.
    """
    key_angle, key_pos = jax.random.split(key_value)
    angle_x = _draw_rxyz(
        config,
        key_t,
        key_angle,
        dt,
        None,
        enable_range_of_motion=False,
        free_spherical=True,
    )
    pos_yz = _draw_p3d(config, None, key_pos, dt, None)[:, :2]
    return jnp.concatenate((angle_x[:, None], pos_yz), axis=1)


def _draw_frozen(config: MotionConfig, _, __, dt: float, ___) -> jax.Array:
    """A frozen joint has no coordinates: return an empty (N, 0) trajectory."""
    n_timesteps = int(config.T / dt)
    return jnp.zeros((n_timesteps, 0))
|
500
|
+
|
501
|
+
|
502
|
+
def qrel(q1, q2):
    """Relative rotation ``q1 * q2^-1``."""
    return maths.quat_mul(q1, maths.quat_inv(q2))


def _qd_from_q_quaternion(qs, dt):
    """Angular velocity of a quaternion trajectory via central differences.

    The first and last rows are zero because they have no two-sided neighbours.
    """
    axis, angle = maths.quat_to_rot_axis(qrel(qs[2:], qs[:-2]))
    # axis has shape (n_timesteps, 3) and angle (n_timesteps,), so angle needs
    # a trailing singleton axis to broadcast.
    interior = axis * angle[:, None] / (2 * dt)
    pad = jnp.zeros((3,))
    return jnp.vstack((pad, interior, pad))


def _qd_from_q_cartesian(qs, dt):
    """Central-difference derivative of `qs` along axis 0, zero at both ends."""
    pad = jnp.zeros_like(qs[0])
    interior = (qs[2:] - qs[:-2]) / (2 * dt)
    return jnp.vstack((pad, interior, pad))
|
519
|
+
|
520
|
+
|
521
|
+
def _p_control_quaternion(q, q_ref):
    """Axis-angle vector of the relative rotation ``q_ref * q^-1``."""
    axis, angle = maths.quat_to_rot_axis(qrel(q_ref, q))
    return axis * angle


def _p_control_term_rxyz(q, q_ref):
    """Wrapped angle error for hinge-type coordinates."""
    # q_ref comes from the RCMG and is therefore already wrapped.
    # TODO: Currently state.q is not wrapped. Change that?
    error = q_ref - maths.wrap_to_pi(q)
    return maths.wrap_to_pi(error)


def _p_control_term_pxyz_p3d(q, q_ref):
    """Plain difference for translational coordinates."""
    return q_ref - q


def _p_control_term_frozen(q, q_ref):
    """Frozen joints have no coordinates, hence an empty error."""
    return jnp.array([])


def _p_control_term_spherical(q, q_ref):
    """Rotational error of a spherical joint."""
    return _p_control_quaternion(q, q_ref)


def _p_control_term_free(q, q_ref):
    """Rotational error of q[:4] followed by the difference of q[4:]."""
    rot_error = _p_control_quaternion(q[:4], q_ref[:4])
    pos_error = q_ref[4:] - q[4:]
    return jnp.concatenate((rot_error, pos_error))


def _p_control_term_free_2d(q, q_ref):
    """Wrapped angle error of q[:1] followed by the difference of q[1:]."""
    angle_error = _p_control_term_rxyz(q[:1], q_ref[:1])
    pos_error = q_ref[1:] - q[1:]
    return jnp.concatenate((angle_error, pos_error))


def _p_control_term_cor(q, q_ref):
    """`cor` uses the free-joint error: q[:4] rotational, q[4:] translational."""
    return _p_control_term_free(q, q_ref)


def _qd_from_q_free(qs, dt):
    """Velocity of a free-joint trajectory: quaternion part, then position part."""
    rot_velocity = _qd_from_q_quaternion(qs[:, :4], dt)
    pos_velocity = _qd_from_q_cartesian(qs[:, 4:], dt)
    return jnp.hstack((rot_velocity, pos_velocity))
|
570
|
+
|
571
|
+
|
572
|
+
def _coordinate_vector_to_q_free_spherical_cor(q):
    """Normalize the leading quaternion q[:4]; the other entries are unchanged."""
    unit_quat = maths.safe_normalize(q[:4])
    return q.at[:4].set(unit_quat)


def _coordinate_vector_to_q_free_2d(q):
    """Wrap the angle q[0] to [-pi, pi]; the positions are unchanged."""
    wrapped_angle = maths.wrap_to_pi(q[0])
    return q.at[0].set(wrapped_angle)


# axis name -> one-element slice selecting that component of a 3-vector
_str2idx = {axis: slice(i, i + 1) for i, axis in enumerate("xyz")}
|
581
|
+
|
582
|
+
|
583
|
+
def _inv_kin_rxyz_factory(xyz: str):
    """Build the inverse kinematics of a hinge joint about axis `xyz`.

    Args:
        xyz: Axis name, passed to `maths.unit_vectors`.

    Returns:
        A function ``(transform, joint_params) -> q`` that recovers the hinge
        angle from ``transform.rot`` as an array of shape ``(1,)``.
    """
    k = maths.unit_vectors(xyz)

    def _inv_kin_rxyz(x: base.Transform, _) -> jax.Array:
        # TODO
        # NOTE: CONVENTION
        # Fast path: the angle is read directly from the quaternion
        # components as 2 * atan2(<q_vec, k>, q_w), and its sign is then
        # flipped. The original author flagged that this version suffers from
        # a convention issue.
        # A slower alternative, which the author described as equivalent and
        # free of that issue, is:
        #   1. project `q` onto `k` with `maths.quat_project(q, k)[0]`;
        #   2. convert the result to (axis, angle) with `maths.quat_to_rot_axis`;
        #   3. return `angle` if `k @ axis > 0`, otherwise `-angle`.
        # It used to sit here as unreachable code after the return.
        q = x.rot
        angle = 2 * jnp.arctan2(q[1:] @ k, q[0])
        return -angle[None]

    return _inv_kin_rxyz
|
599
|
+
|
600
|
+
|
601
|
+
def _inv_kin_pxyz_factory(xyz: str):
    """Build the inverse kinematics of a prismatic joint along axis `xyz`.

    The returned function ``(transform, joint_params) -> q`` selects the
    matching component of ``transform.pos`` as an array of shape ``(1,)``.
    Raises KeyError for an unknown axis name.
    """
    component = _str2idx[xyz]

    def _inv_kin_pxyz(x: base.Transform, _) -> jax.Array:
        return x.pos[component]

    return _inv_kin_pxyz
|
608
|
+
|
609
|
+
|
610
|
+
def _inv_kin_free_2d(x: base.Transform, _) -> jax.Array:
    """Inverse kinematics of the `free_2d` joint.

    Returns ``(angle_x, pos_y, pos_z)``, the inverse of `_free_2d_transform`.
    The angle about x is recovered with the same hinge inverse kinematics
    used by the `rx` joint.
    """
    # The IK functions built by `_inv_kin_rxyz_factory` take
    # (transform, joint_params). Calling them with only the transform raised
    # a TypeError.
    angle_x = _inv_kin_rxyz_factory("x")(x, _)
    return jnp.concatenate((angle_x, x.pos[1:]))
|
613
|
+
|
614
|
+
|
615
|
+
# Registry of all built-in joint types (extended by `register_new_joint_type`).
_joint_types = {
    "free": JointModel(
        transform=_free_transform,
        motion=[mrx, mry, mrz, mpx, mpy, mpz],
        rcmg_draw_fn=_draw_free,
        p_control_term=_p_control_term_free,
        qd_from_q=_qd_from_q_free,
        coordinate_vector_to_q=_coordinate_vector_to_q_free_spherical_cor,
        inv_kin=lambda x, _: jnp.concatenate((x.rot, x.pos)),
    ),
    "free_2d": JointModel(
        transform=_free_2d_transform,
        motion=[mrx, mpy, mpz],
        rcmg_draw_fn=_draw_free_2d,
        p_control_term=_p_control_term_free_2d,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=_coordinate_vector_to_q_free_2d,
        inv_kin=_inv_kin_free_2d,
    ),
    "frozen": JointModel(
        transform=_frozen_transform,
        motion=[],
        rcmg_draw_fn=_draw_frozen,
        p_control_term=_p_control_term_frozen,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=lambda q: q,
        inv_kin=lambda x, _: jnp.array([]),
    ),
    "spherical": JointModel(
        transform=_spherical_transform,
        motion=[mrx, mry, mrz],
        rcmg_draw_fn=_draw_spherical,
        p_control_term=_p_control_term_spherical,
        qd_from_q=_qd_from_q_quaternion,
        coordinate_vector_to_q=_coordinate_vector_to_q_free_spherical_cor,
        inv_kin=lambda x, _: x.rot,
    ),
    "p3d": JointModel(
        transform=_p3d_transform,
        motion=[mpx, mpy, mpz],
        rcmg_draw_fn=_draw_p3d,
        p_control_term=_p_control_term_pxyz_p3d,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=lambda q: q,
        inv_kin=lambda x, _: x.pos,
    ),
    "cor": JointModel(
        transform=_cor_transform,
        # free joint (6 dofs) followed by a p3d joint (3 dofs)
        motion=[mrx, mry, mrz, mpx, mpy, mpz, mpx, mpy, mpz],
        rcmg_draw_fn=_draw_cor,
        p_control_term=_p_control_term_cor,
        qd_from_q=_qd_from_q_free,
        coordinate_vector_to_q=_coordinate_vector_to_q_free_spherical_cor,
    ),
    "rx": JointModel(
        transform=lambda q, _: _rxyz_transform(q, _, jnp.array([1.0, 0.0, 0.0])),
        motion=[mrx],
        rcmg_draw_fn=_draw_rxyz,
        p_control_term=_p_control_term_rxyz,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=maths.wrap_to_pi,
        inv_kin=_inv_kin_rxyz_factory("x"),
    ),
    "ry": JointModel(
        transform=lambda q, _: _rxyz_transform(q, _, jnp.array([0.0, 1.0, 0.0])),
        motion=[mry],
        rcmg_draw_fn=_draw_rxyz,
        p_control_term=_p_control_term_rxyz,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=maths.wrap_to_pi,
        inv_kin=_inv_kin_rxyz_factory("y"),
    ),
    "rz": JointModel(
        transform=lambda q, _: _rxyz_transform(q, _, jnp.array([0.0, 0.0, 1.0])),
        motion=[mrz],
        rcmg_draw_fn=_draw_rxyz,
        p_control_term=_p_control_term_rxyz,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=maths.wrap_to_pi,
        inv_kin=_inv_kin_rxyz_factory("z"),
    ),
    "px": JointModel(
        transform=lambda q, _: _pxyz_transform(q, _, jnp.array([1.0, 0.0, 0.0])),
        motion=[mpx],
        rcmg_draw_fn=_draw_pxyz,
        p_control_term=_p_control_term_pxyz_p3d,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=lambda q: q,
        inv_kin=_inv_kin_pxyz_factory("x"),
    ),
    "py": JointModel(
        transform=lambda q, _: _pxyz_transform(q, _, jnp.array([0.0, 1.0, 0.0])),
        motion=[mpy],
        rcmg_draw_fn=_draw_pxyz,
        p_control_term=_p_control_term_pxyz_p3d,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=lambda q: q,
        inv_kin=_inv_kin_pxyz_factory("y"),
    ),
    "pz": JointModel(
        transform=lambda q, _: _pxyz_transform(q, _, jnp.array([0.0, 0.0, 1.0])),
        motion=[mpz],
        rcmg_draw_fn=_draw_pxyz,
        p_control_term=_p_control_term_pxyz_p3d,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=lambda q: q,
        inv_kin=_inv_kin_pxyz_factory("z"),
    ),
    "saddle": JointModel(
        transform=_saddle_transform,
        motion=[mry, mrz],
        rcmg_draw_fn=_draw_saddle,
        p_control_term=_p_control_term_rxyz,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=maths.wrap_to_pi,
    ),
}
|
732
|
+
|
733
|
+
|
734
|
+
def get_joint_model(joint_type: str) -> JointModel:
    """Look up the registered `JointModel` of `joint_type`.

    Raises:
        AssertionError: if the joint type is not registered.
    """
    registered = _joint_types
    assert joint_type in registered, f"{joint_type} not in {list(registered)}"
    return registered[joint_type]
|
739
|
+
|
740
|
+
|
741
|
+
def register_new_joint_type(
    joint_type: str,
    joint_model: JointModel,
    q_width: int,
    qd_width: Optional[int] = None,
    overwrite: bool = False,
):
    """Register a custom joint type alongside the built-in ones.

    Args:
        joint_type: Name of the new joint type. ``"default"`` is reserved
            because it is the fallback key of the joint-parameter dict.
        joint_model: The `JointModel` implementing the joint.
        q_width: Number of position coordinates (`q`) of the joint.
        qd_width: Number of velocity coordinates (`qd`); defaults to `q_width`.
        overwrite: Allow replacing an already registered joint type.

    Raises:
        AssertionError: if the name is reserved, if it already exists and
            `overwrite` is False, or if ``len(joint_model.motion) != qd_width``.

    All checks run before any registry is modified. A failed registration
    therefore never leaves `_joint_types`, `base.Q_WIDTHS` and
    `base.QD_WIDTHS` partially updated, even with ``overwrite=True``.
    """
    # this name is used as the fallback key in `_limit_scope_of_joint_params`
    assert joint_type != "default", "Please use another name."

    exists = joint_type in _joint_types
    assert (
        overwrite or not exists
    ), f"joint type `{joint_type}` already exists, use `overwrite=True`"

    if qd_width is None:
        qd_width = q_width

    # one motion subspace entry per velocity coordinate
    assert len(joint_model.motion) == qd_width

    if exists:
        # Drop the stale entries so that the new ones are (re-)inserted at the
        # end; tolerate registries that do not contain the name.
        for registry in (base.Q_WIDTHS, base.QD_WIDTHS, _joint_types):
            registry.pop(joint_type, None)

    _joint_types[joint_type] = joint_model
    base.Q_WIDTHS[joint_type] = q_width
    base.QD_WIDTHS[joint_type] = qd_width
|
772
|
+
|
773
|
+
|
774
|
+
def _limit_scope_of_joint_params(
    joint_type: str, joint_params: dict[str, tree_utils.PyTree]
) -> tree_utils.PyTree:
    """Pick the parameters of `joint_type`, else fall back to ``"default"``."""
    scope = joint_type if joint_type in joint_params else "default"
    return joint_params[scope]
|
781
|
+
|
782
|
+
|
783
|
+
def jcalc_transform(
    joint_type: str, q: jax.Array, joint_params: dict[str, tree_utils.PyTree]
) -> base.Transform:
    """Evaluate the joint transform of `joint_type` at coordinates `q`."""
    scoped_params = _limit_scope_of_joint_params(joint_type, joint_params)
    transform_fn = _joint_types[joint_type].transform
    return transform_fn(q, scoped_params)
|
788
|
+
|
789
|
+
|
790
|
+
def _to_motion(
    m: base.Motion | Callable[[jax.Array], base.Motion], joint_params: tree_utils.PyTree
) -> base.Motion:
    """Return `m` itself if it is a Motion, else ``m(joint_params)``."""
    return m if isinstance(m, base.Motion) else m(joint_params)
|
796
|
+
|
797
|
+
|
798
|
+
def jcalc_motion(
    joint_type: str, qd: jax.Array, joint_params: dict[str, tree_utils.PyTree]
) -> base.Motion:
    """Joint velocity: sum over dofs of ``motion_subspace[dof] * qd[dof]``."""
    scoped_params = _limit_scope_of_joint_params(joint_type, joint_params)
    velocity = base.Motion.zero()
    for dof, subspace in enumerate(_joint_types[joint_type].motion):
        velocity += _to_motion(subspace, scoped_params) * qd[dof]
    return velocity
|
807
|
+
|
808
|
+
|
809
|
+
def jcalc_tau(
    joint_type: str, f: base.Force, joint_params: dict[str, tree_utils.PyTree]
) -> jax.Array:
    """Project the force `f` onto each motion subspace entry of the joint."""
    scoped_params = _limit_scope_of_joint_params(joint_type, joint_params)
    subspaces = _joint_types[joint_type].motion
    taus = [
        algebra.motion_dot(_to_motion(subspace, scoped_params), f)
        for subspace in subspaces
    ]
    return jnp.array(taus)
|
817
|
+
|
818
|
+
|
819
|
+
def _init_joint_params(key: jax.Array, sys: base.System) -> base.System:
    """Initialize the joint parameters of all custom joints in `sys`.

    For every joint type in ``sys.link_types`` whose `JointModel` defines
    `init_joint_params`, that function is vmapped over one fresh PRNG key per
    link. The result is stored under the type's name. A ``"default"`` entry
    of zeros with shape ``(n_links, 0)`` is always added. Returns `sys` with
    ``links.joint_params`` replaced.
    """
    # Collect init functions in order of first appearance; the order fixes
    # how `key` is consumed below.
    init_fns = {}
    for link_type in sys.link_types:
        if link_type in init_fns:
            continue
        init_fn = _joint_types[link_type].init_joint_params
        if init_fn is not None:
            init_fns[link_type] = init_fn

    joint_params: dict[str, tree_utils.PyTree] = {}
    n_links = sys.num_links()
    for link_type, init_fn in init_fns.items():
        split_keys = jax.random.split(key, num=n_links + 1)
        key, link_keys = split_keys[0], split_keys[1:]
        joint_params[link_type] = jax.vmap(init_fn)(link_keys)

    # batched default parameters (one empty row per link)
    joint_params["default"] = jnp.zeros((n_links, 0))

    return sys.replace(links=sys.links.replace(joint_params=joint_params))
|