imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,403 @@
|
|
1
|
+
from typing import Callable, Optional
|
2
|
+
import warnings
|
3
|
+
|
4
|
+
import jax
|
5
|
+
from jax import random
|
6
|
+
import jax.numpy as jnp
|
7
|
+
from ring import maths
|
8
|
+
|
9
|
+
Float = jax.Array
|
10
|
+
TimeDependentFloat = Callable[[Float], Float]
|
11
|
+
|
12
|
+
|
13
|
+
def _to_float(scalar: Float | TimeDependentFloat, t: Float) -> Float:
|
14
|
+
if isinstance(scalar, Callable):
|
15
|
+
return scalar(t)
|
16
|
+
return scalar
|
17
|
+
|
18
|
+
|
19
|
+
# APPROVED
|
20
|
+
def random_angle_over_time(
    key_t: random.PRNGKey,
    key_ang: random.PRNGKey,
    ANG_0: float,
    dang_min: float | TimeDependentFloat,
    dang_max: float | TimeDependentFloat,
    delta_ang_min: float | TimeDependentFloat,
    delta_ang_max: float | TimeDependentFloat,
    t_min: float,
    t_max: float | TimeDependentFloat,
    T: float,
    Ts: float,
    max_iter: int = 5,
    randomized_interpolation: bool = False,
    range_of_motion: bool = False,
    range_of_motion_method: str = "uniform",
    cdf_bins_min: int = 5,
    cdf_bins_max: Optional[int] = None,
    interpolation_method: str = "cosine",
) -> jax.Array:
    """Draw a random angle trajectory sampled on `jnp.arange(T, step=Ts)`.

    Random knots `(time, angle)` are generated in a `jax.lax.while_loop` until
    the knot time exceeds `T`: time increments are uniform in
    `[t_min, t_max(t))`, and each new angle is drawn by
    `_resolve_range_of_motion` from the previous one. The knots are then
    resampled with `cosInterpolate`, or with `interpolate` (random warping of
    the interpolation weight) when `randomized_interpolation` is set.

    Args:
        key_t: PRNG key for the knot time increments.
        key_ang: PRNG key for the knot angles; its final state also seeds the
            randomized interpolation.
        ANG_0: Angle of the first knot (at time 0).
        dang_min, dang_max: Angular-velocity bounds passed to
            `_resolve_range_of_motion`; constants or callables of the current
            knot time.
        delta_ang_min, delta_ang_max: Bounds on the absolute angle change
            between consecutive knots (constants or callables of time).
        t_min: Minimum time increment between knots; also sizes the knot buffer.
        t_max: Upper bound of the time increment (constant or callable of time).
        T: Trajectory duration.
        Ts: Sampling interval of the returned trajectory.
        max_iter: Maximum number of re-proposals per knot.
        randomized_interpolation: Resample with `interpolate` instead of
            `cosInterpolate`.
        range_of_motion, range_of_motion_method: See `_resolve_range_of_motion`.
        cdf_bins_min, cdf_bins_max, interpolation_method: Forwarded to
            `interpolate`; only used when `randomized_interpolation` is set.

    Returns:
        Array with one angle per entry of `jnp.arange(T, step=Ts)`. Unless
        `range_of_motion` is set (knots are then already clipped to
        [-pi, pi]), the result is passed through `maths.wrap_to_pi`.
    """

    def body_fn_outer(val):
        i, t, phi, key_t, key_ang, ANG = val

        # time increment to the next knot; `t_max` may depend on the current time
        key_t, consume = random.split(key_t)
        dt = random.uniform(consume, minval=t_min, maxval=_to_float(t_max, t))

        key_ang, consume = random.split(key_ang)
        phi = _resolve_range_of_motion(
            range_of_motion,
            range_of_motion_method,
            _to_float(dang_min, t),
            _to_float(dang_max, t),
            _to_float(delta_ang_min, t),
            _to_float(delta_ang_max, t),
            dt,
            phi,
            consume,
            max_iter,
        )
        t += dt

        # TODO do we really need the `jnp.floor(t / Ts) * Ts` since we resample later
        # anyways
        ANG_i = jnp.array([[jnp.floor(t / Ts) * Ts, phi]])
        ANG = jax.lax.dynamic_update_slice_in_dim(ANG, ANG_i, start_index=i, axis=0)

        return i + 1, t, phi, key_t, key_ang, ANG

    def cond_fn_outer(val):
        i, t, phi, key_t, key_ang, ANG = val
        return t <= T

    # Preallocate the knot buffer: one row per knot, columns are (time, angle).
    # NOTE(review): when all increments are close to `t_min` the loop can write
    # row `floor(T / t_min) + 1`, one past the last row of this buffer;
    # `dynamic_update_slice_in_dim` clamps the start index, so that knot would
    # overwrite the previous one. Growing the buffer would change `len(ANG)` and
    # thereby the key split inside `interpolate`, so it is left as-is — confirm.
    _warn_huge_preallocation(t_min, T)
    ANG = jnp.zeros((int(T // t_min) + 1, 2))
    ANG = ANG.at[0, 1].set(ANG_0)

    val_outer = (1, 0.0, ANG_0, key_t, key_ang, ANG)
    # `consume` is the final state of `key_ang`
    end, *_, consume, ANG = jax.lax.while_loop(cond_fn_outer, body_fn_outer, val_outer)
    # rows that were never written are filled with the last written knot
    ANG = jnp.where(
        (jnp.arange(len(ANG)) < end)[:, None],
        ANG,
        jax.lax.dynamic_index_in_dim(ANG, end - 1),
    )

    # resample
    t = jnp.arange(T, step=Ts)
    if randomized_interpolation:
        q = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
            t, ANG[:, 0], ANG[:, 1], consume
        )
    else:
        if interpolation_method != "cosine":
            warnings.warn(
                f"You have selected interpolation method {interpolation_method}. "
                "Different choices of interpolation method are only available if "
                "`randomized_interpolation` is set."
            )
        q = cosInterpolate(t, ANG[:, 0], ANG[:, 1])

    # if range_of_motion is true, then it is wrapped already
    if not range_of_motion:
        q = maths.wrap_to_pi(q)

    return q
|
105
|
+
|
106
|
+
|
107
|
+
# APPROVED
|
108
|
+
def random_position_over_time(
    key: random.PRNGKey,
    POS_0: float,
    pos_min: float | TimeDependentFloat,
    pos_max: float | TimeDependentFloat,
    dpos_min: float | TimeDependentFloat,
    dpos_max: float | TimeDependentFloat,
    t_min: float,
    t_max: float | TimeDependentFloat,
    T: float,
    Ts: float,
    max_it: int,
    randomized_interpolation: bool = False,
    cdf_bins_min: int = 5,
    cdf_bins_max: Optional[int] = None,
    interpolation_method: str = "cosine",
) -> jax.Array:
    """Draw a random position trajectory sampled on `jnp.arange(T, step=Ts)`.

    Knot times advance by uniform increments in `[t_min, t_max(t))` until they
    exceed `T`. Starting from `POS_0`, each knot proposes a step
    `sign * U(dpos_min, dpos_max) * dt` (random sign), re-proposing while the
    implied speed `|dx / dt|` is not strictly inside `(dpos_min, dpos_max)` or
    the new position is outside `[pos_min, pos_max]`. After `max_it + 1`
    proposals the last one is kept even if it violates these bounds.

    Args:
        key: PRNG key.
        POS_0: Position of the first knot (at time 0); the random walk starts here.
        pos_min, pos_max: Position bounds (constants or callables of time).
        dpos_min, dpos_max: Speed bounds (constants or callables of time).
        t_min: Minimum time increment between knots; also sizes the knot buffer.
        t_max: Upper bound of the time increment (constant or callable of time).
        T: Trajectory duration.
        Ts: Sampling interval of the returned trajectory.
        max_it: Maximum number of re-proposals per knot.
        randomized_interpolation: Resample with `interpolate` instead of
            `cosInterpolate`.
        cdf_bins_min, cdf_bins_max, interpolation_method: Forwarded to
            `interpolate`; only used when `randomized_interpolation` is set.

    Returns:
        Array with one position per entry of `jnp.arange(T, step=Ts)`.
    """

    def body_fn_inner(val):
        i, t, t_pre, x, x_pre, key = val
        dt = t - t_pre

        def sample_dx(key):
            key, consume1, consume2 = random.split(key, 3)
            sign = random.choice(consume1, jnp.array([-1.0, 1.0]))
            dx = (
                sign
                * random.uniform(
                    consume2,
                    minval=_to_float(dpos_min, t_pre),
                    maxval=_to_float(dpos_max, t_pre),
                )
                * dt
            )
            return key, dx

        # NOTE(review): `cond_fn_inner` stops the loop as soon as `i > max_it`,
        # so this body never runs with `i > max_it` and the zero-step branch is
        # unreachable; the last (possibly rejected) proposal is kept instead.
        # Confirm which behaviour is intended.
        key, dx = jax.lax.cond(i > max_it, (lambda key: (key, 0.0)), sample_dx, key)
        x = x_pre + dx

        return i + 1, t, t_pre, x, x_pre, key

    def cond_fn_inner(val):
        i, t, t_pre, x, x_pre, key = val
        # keep proposing while the speed or the new position is out of bounds
        dpos = jnp.abs((x - x_pre) / (t - t_pre))
        break_if_true1 = (
            (dpos < _to_float(dpos_max, t_pre))
            & (dpos > _to_float(dpos_min, t_pre))
            & (x >= _to_float(pos_min, t_pre))
            & (x <= _to_float(pos_max, t_pre))
        )
        break_if_true2 = i > max_it
        return jnp.logical_not(break_if_true1 | break_if_true2)

    def body_fn_outer(val):
        i, t, t_pre, x, x_pre, key, POS = val
        key, consume = random.split(key)
        t += random.uniform(consume, minval=t_min, maxval=_to_float(t_max, t_pre))

        # that zero resets the max_it count
        val_inner = (0, t, t_pre, x, x_pre, key)
        _, t, t_pre, x, x_pre, key = jax.lax.while_loop(
            cond_fn_inner, body_fn_inner, val_inner
        )

        POS_i = jnp.array([[jnp.floor(t / Ts) * Ts, x]])
        POS = jax.lax.dynamic_update_slice_in_dim(POS, POS_i, start_index=i, axis=0)
        t_pre = t
        x_pre = x
        return i + 1, t, t_pre, x, x_pre, key, POS

    def cond_fn_outer(val):
        i, t, t_pre, x, x_pre, key, POS = val
        return t <= T

    # Preallocate the knot buffer: one row per knot, columns are (time, position).
    _warn_huge_preallocation(t_min, T)
    POS = jnp.zeros((int(T // t_min) + 1, 2))
    POS = POS.at[0, 1].set(POS_0)

    # Start the random walk at the first knot's position so that the second
    # knot is drawn relative to `POS_0` rather than relative to 0.
    x0 = POS[0, 1]
    val_outer = (1, 0.0, 0.0, x0, x0, key, POS)
    # `consume` is the final state of `key`
    end, *_, consume, POS = jax.lax.while_loop(cond_fn_outer, body_fn_outer, val_outer)
    # rows that were never written are filled with the last written knot
    POS = jnp.where(
        (jnp.arange(len(POS)) < end)[:, None],
        POS,
        jax.lax.dynamic_index_in_dim(POS, end - 1),
    )

    # resample
    t = jnp.arange(T, step=Ts)
    if randomized_interpolation:
        r = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
            t, POS[:, 0], POS[:, 1], consume
        )
    else:
        # Intentionally no warning here (unlike for angle trajectories) when a
        # non-cosine `interpolation_method` is ignored.
        r = cosInterpolate(t, POS[:, 0], POS[:, 1])
    return r
|
222
|
+
|
223
|
+
|
224
|
+
_PREALLOCATION_WARN_LIMIT = 6000
|
225
|
+
|
226
|
+
|
227
|
+
def _warn_huge_preallocation(t_min, T):
|
228
|
+
N = int(T // t_min) + 1
|
229
|
+
if N > _PREALLOCATION_WARN_LIMIT:
|
230
|
+
warnings.warn(
|
231
|
+
f"The combination of `T`={T} and `t_min`={t_min} requires preallocating an "
|
232
|
+
f"array with axis-length of {N} which is larger than the warn limit of "
|
233
|
+
f"{_PREALLOCATION_WARN_LIMIT}. This might lead to large memory requirements"
|
234
|
+
" and/or large jit-times, consider reducing `t_min`."
|
235
|
+
)
|
236
|
+
|
237
|
+
|
238
|
+
def _clip_to_pi(phi):
|
239
|
+
return jnp.clip(phi, -jnp.pi, jnp.pi)
|
240
|
+
|
241
|
+
|
242
|
+
def _resolve_range_of_motion(
    range_of_motion,
    range_of_motion_method,
    dang_min,
    dang_max,
    delta_ang_min,
    delta_ang_max,
    dt,
    prev_phi,
    key,
    max_iter,
):
    """Draw the angle of the next knot given the previous angle `prev_phi`.

    Candidates are proposed by `_next_phi` inside a `jax.lax.while_loop` until
    `|next_phi - prev_phi|` lies in `[delta_ang_min, delta_ang_max]`. After
    `max_iter + 1` proposals the last one is returned regardless.

    Args:
        range_of_motion: If True, the step direction is drawn with
            `range_of_motion_method` probabilities and the proposal is clipped
            to [-pi, pi]; otherwise an unbiased random sign is used and no
            clipping happens.
        range_of_motion_method: "coinflip", "uniform", or any string starting
            with "sigmoid" (optionally "sigmoid-<scale>", e.g. "sigmoid-2.0").
        dang_min, dang_max: Step size is drawn between `dang_min * dt` and
            `dang_max * dt`.
        delta_ang_min, delta_ang_max: Accepted range of `|next_phi - prev_phi|`.
        dt: Time increment to the next knot.
        prev_phi: Angle of the previous knot.
        key: PRNG key.
        max_iter: Re-proposal limit (see above).

    Returns:
        The next knot angle.

    Raises:
        NotImplementedError: For an unknown `range_of_motion_method` (only
            checked when `range_of_motion` is True).
    """

    def _next_phi(key):
        key, consume = random.split(key)

        if range_of_motion:
            # `probs` = (P(sign=+1), P(sign=-1)) for the step direction
            if range_of_motion_method == "coinflip":
                probs = jnp.array([0.5, 0.5])
            elif range_of_motion_method == "uniform":
                # P(+1) falls linearly from 1 at -pi to 0 at +pi
                p = 0.5 * (1 - prev_phi / jnp.pi)
                probs = jnp.array([p, (1 - p)])
            elif range_of_motion_method[:7] == "sigmoid":
                scale = 1.5
                provided_params = range_of_motion_method.split("-")
                if len(provided_params) == 2:
                    scale = float(provided_params[-1])
                # force the direction back inwards when very close to +-pi
                hardcut = jnp.pi - 0.01
                p = jnp.where(
                    prev_phi > hardcut,
                    0.0,
                    jnp.where(
                        prev_phi < -hardcut, 1.0, jax.nn.sigmoid(-scale * prev_phi)
                    ),
                )
                probs = jnp.array([p, (1 - p)])
            else:
                raise NotImplementedError

            sign = random.choice(consume, jnp.array([1.0, -1.0]), p=probs)
            lower = _clip_to_pi(prev_phi + sign * dang_min * dt)
            upper = _clip_to_pi(prev_phi + sign * dang_max * dt)

            # swap if lower > upper
            lower, upper = jnp.sort(jnp.hstack((lower, upper)))

            key, consume = random.split(key)
            return random.uniform(consume, minval=lower, maxval=upper)

        else:
            dphi = random.uniform(consume, minval=dang_min, maxval=dang_max) * dt
            key, consume = random.split(key)
            sign = random.choice(consume, jnp.array([1.0, -1.0]))
            return prev_phi + sign * dphi

    def body_fn(val):
        key, _, i = val
        key, consume = jax.random.split(key)
        next_phi = _next_phi(consume)
        return key, next_phi, i + 1

    def cond_fn(val):
        _, next_phi, i = val
        delta_phi = jnp.abs(next_phi - prev_phi)
        # delta is in bounds
        break_if_true1 = (delta_phi >= delta_ang_min) & (delta_phi <= delta_ang_max)
        break_if_true2 = i > max_iter
        # `i == 0` forces at least one proposal: the initial carry holds only a
        # placeholder angle
        return (i == 0) | (jnp.logical_not(break_if_true1 | break_if_true2))

    # the `prev_phi` here is unused
    return jax.lax.while_loop(cond_fn, body_fn, (key, prev_phi, 0))[1]
|
312
|
+
|
313
|
+
|
314
|
+
def cosInterpolate(x, xp, fp):
    """Cosine-interpolate the knots `(xp, fp)` at the query points `x`.

    `xp` is searched with `jnp.searchsorted`, so it must be sorted ascending.
    Zero-length segments take the right knot value, and queries beyond
    `xp[-1]` take `fp[-1]`.
    """
    seg = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, len(xp) - 1)
    x_left = xp[seg - 1]
    width = xp[seg] - x_left
    weight = (x - x_left) / width

    def _blend(f_a, f_b, a):
        # a=0 -> f_a, a=1 -> f_b along a half cosine period
        return (f_a + f_b) / 2 + (f_a - f_b) / 2 * jnp.cos(a * jnp.pi)

    blended = jax.vmap(_blend)(fp[seg - 1], fp[seg], weight)
    inside = jnp.where((width == 0), fp[seg], blended)
    return jnp.where(x > xp[-1], fp[-1], inside)
|
326
|
+
|
327
|
+
|
328
|
+
def _biject_alpha(alpha, cdf):
|
329
|
+
cdf_dx = 1 / (len(cdf) - 1)
|
330
|
+
left_idx = (alpha // cdf_dx).astype(int)
|
331
|
+
a = (alpha - left_idx * cdf_dx) / cdf_dx
|
332
|
+
return (1 - a) * cdf[left_idx] + a * cdf[left_idx + 1]
|
333
|
+
|
334
|
+
|
335
|
+
def _generate_cdf(cdf_bins_min, cdf_bins_max=None):
|
336
|
+
if cdf_bins_max is None:
|
337
|
+
|
338
|
+
def _generate_cdf_min_eq_max(cdf_bins):
|
339
|
+
def __generate_cdf(key):
|
340
|
+
samples = random.uniform(key, (cdf_bins,), minval=1e-6, maxval=1.0)
|
341
|
+
samples = jnp.hstack((jnp.array([0.0]), samples))
|
342
|
+
montonous = jnp.cumsum(samples)
|
343
|
+
cdf = montonous / montonous[-1]
|
344
|
+
return cdf
|
345
|
+
|
346
|
+
return __generate_cdf
|
347
|
+
|
348
|
+
return _generate_cdf_min_eq_max(cdf_bins=cdf_bins_min)
|
349
|
+
|
350
|
+
def _generate_cdf_min_uneq_max(dy_min, dy_max):
|
351
|
+
assert dy_max >= dy_min
|
352
|
+
|
353
|
+
def __generate_cdf(key):
|
354
|
+
key, consume = random.split(key)
|
355
|
+
cdf_bins = random.randint(consume, (), dy_min, dy_max + 1)
|
356
|
+
mask = jnp.where(jnp.arange(dy_max) < cdf_bins, 1, 0)
|
357
|
+
key, consume = random.split(key)
|
358
|
+
mask = random.permutation(consume, mask)
|
359
|
+
dy = random.uniform(key, (dy_max,), minval=1e-6, maxval=1.0)
|
360
|
+
dy = dy[jnp.cumsum(mask) - 1]
|
361
|
+
y = jnp.hstack((jnp.array([0.0]), dy))
|
362
|
+
montonous = jnp.cumsum(y)
|
363
|
+
cdf = montonous / montonous[-1]
|
364
|
+
return cdf
|
365
|
+
|
366
|
+
return __generate_cdf
|
367
|
+
|
368
|
+
return _generate_cdf_min_uneq_max(cdf_bins_min, cdf_bins_max)
|
369
|
+
|
370
|
+
|
371
|
+
def interpolate(
    cdf_bins_min: int = 1, cdf_bins_max: Optional[int] = None, method: str = "cosine"
):
    """Build an interpolator that randomly warps the interpolation weight.

    The returned `_interpolate(x, xp, fp, key)` interpolates the knots
    `(xp, fp)` at `x` like `cosInterpolate`, but first maps each local weight
    `alpha` through a random monotone piecewise-linear CDF
    (`_generate_cdf` + `_biject_alpha`). One CDF is drawn per knot segment, so
    all queries within a segment share the same warping. With the default of a
    single bin the CDF is `[0, 1]`, i.e. the warping is the identity.

    Args:
        cdf_bins_min: Number of CDF bins, or the minimum number when
            `cdf_bins_max` is given.
        cdf_bins_max: If given, the number of bins is random in
            `[cdf_bins_min, cdf_bins_max]` (see `_generate_cdf`).
        method: Two-point blend, either "cosine" or "linear".

    Raises:
        NotImplementedError: If `method` is not supported. This is raised
            here, not later when the returned function is traced.
    """
    if method not in ("cosine", "linear"):
        raise NotImplementedError(
            f"Interpolation method `{method}` is not implemented; "
            "expected one of 'cosine', 'linear'."
        )
    generate_cdf = _generate_cdf(cdf_bins_min, cdf_bins_max)

    def _interpolate(x, xp, fp, key):
        i = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, len(xp) - 1)
        dx = xp[i] - xp[i - 1]
        alpha = (x - xp[i - 1]) / dx

        # one key per knot segment; each query uses its segment's key
        key, *consume = random.split(key, len(xp) + 1)
        consume = jnp.array(consume).reshape((len(xp), 2))
        consume = consume[i - 1]
        cdfs = jax.vmap(generate_cdf)(consume)
        alpha = jax.vmap(_biject_alpha)(alpha, cdfs)

        def two_point_interp(x1, x2, alpha):
            """Blend `x1` (alpha=0) and `x2` (alpha=1)."""
            if method == "cosine":
                return (x1 + x2) / 2 + (x1 - x2) / 2 * jnp.cos(alpha * jnp.pi)
            return (1 - alpha) * x1 + alpha * x2

        f = jnp.where(
            (dx == 0), fp[i], jax.vmap(two_point_interp)(fp[i - 1], fp[i], alpha)
        )
        f = jnp.where(x > xp[-1], fp[-1], f)
        return f

    return _interpolate
|
@@ -0,0 +1,69 @@
|
|
1
|
+
from dataclasses import replace
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import ring
|
6
|
+
from ring import maths
|
7
|
+
from ring.algorithms.jcalc import _draw_rxyz
|
8
|
+
from ring.algorithms.jcalc import _p_control_term_rxyz
|
9
|
+
from ring.algorithms.jcalc import _qd_from_q_cartesian
|
10
|
+
|
11
|
+
|
12
|
+
def register_rr_imp_joint(
    config_res=ring.MotionConfig(dang_max=5.0, t_max=0.4),
    ang_max_deg: float = 7.5,
    name: str = "rr_imp",
):
    """Register a two-DoF revolute joint type with a small residual axis.

    `q[0]` rotates about `params["joint_axes"]` and `q[1]` about
    `params["residual"]`; the two quaternions are combined as
    `quat_mul(rot_res, rot_pri)`. Joint axes come from
    `_draw_random_joint_axes`.

    Args:
        config_res: Motion config for the residual-angle trajectory; its `T`
            is replaced by the `T` of the config used for the primary angle.
        ang_max_deg: The residual trajectory is multiplied by
            `deg2rad(ang_max_deg) / pi` and then mean-centred (presumably so
            that angles in [-pi, pi] map to +-`ang_max_deg` degrees — assumes
            `_draw_rxyz` returns angles in [-pi, pi]; confirm).
        name: Name of the registered joint type; an existing type of that name
            is overwritten.
    """

    def _rr_imp_transform(q, params):
        rot_pri = maths.quat_rot_axis(params["joint_axes"], q[0])
        rot_res = maths.quat_rot_axis(params["residual"], q[1])
        return ring.Transform.create(rot=ring.maths.quat_mul(rot_res, rot_pri))

    def _draw_rr_imp(config, key_t, key_value, dt, _):
        kt_pri, kt_res = jax.random.split(key_t)
        kv_pri, kv_res = jax.random.split(key_value)
        q_pri = _draw_rxyz(config, kt_pri, kv_pri, dt, _)
        q_res = _draw_rxyz(replace(config_res, T=config.T), kt_res, kv_res, dt, _)
        # shrink the residual angle, then remove its mean
        q_res = q_res * (jnp.deg2rad(ang_max_deg) / jnp.pi)
        q_res = q_res - jnp.mean(q_res)
        return jnp.concatenate((q_pri[:, None], q_res[:, None]), axis=1)

    def _axis_motion(param_key: str):
        def _motion(params):
            return ring.base.Motion.create(ang=params[param_key])

        return _motion

    joint = ring.JointModel(
        _rr_imp_transform,
        motion=[_axis_motion("joint_axes"), _axis_motion("residual")],
        rcmg_draw_fn=_draw_rr_imp,
        p_control_term=_p_control_term_rxyz,
        qd_from_q=_qd_from_q_cartesian,
        init_joint_params=_draw_random_joint_axes,
    )
    ring.register_new_joint_type(name, joint, 2, 2, overwrite=True)
|
59
|
+
|
60
|
+
|
61
|
+
def _draw_random_joint_axes(key):
    """Sample a random, mutually orthogonal (primary, residual) axis pair.

    Before rotation the primary axis is +z and the residual axis is a random
    direction in the xy-plane. Both are then rotated by the same random
    quaternion from `maths.quat_random`, which preserves their orthogonality.
    """
    key_phi, key_rot = jax.random.split(key)
    phi = jax.random.uniform(key_phi, maxval=2 * jnp.pi)
    q_rand = maths.quat_random(key_rot)
    unrotated = {
        "joint_axes": jnp.array([0, 0, 1.0]),
        "residual": jnp.array([jnp.cos(phi), jnp.sin(phi), 0.0]),
    }
    return {name: maths.rotate(axis, q_rand) for name, axis in unrotated.items()}
|
@@ -0,0 +1,33 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
import ring
|
3
|
+
from ring import maths
|
4
|
+
from ring.algorithms.jcalc import _draw_rxyz
|
5
|
+
from ring.algorithms.jcalc import _p_control_term_rxyz
|
6
|
+
from ring.algorithms.jcalc import _qd_from_q_cartesian
|
7
|
+
|
8
|
+
|
9
|
+
def register_rr_joint():
    """Register the single-DoF revolute joint type "rr" with a random axis.

    The joint rotates by the scalar `q` about `params["joint_axes"]`. The axis
    is drawn by `_draw_random_joint_axis`, and trajectory drawing, P-control
    term and `qd` reuse the `rxyz` helpers from `ring.algorithms.jcalc`. An
    existing "rr" registration is overwritten.
    """

    def _rr_transform(q, params):
        rot = ring.maths.quat_rot_axis(params["joint_axes"], jnp.squeeze(q))
        return ring.Transform.create(rot=rot)

    def _motion_fn(params):
        return ring.base.Motion.create(ang=params["joint_axes"])

    ring.register_new_joint_type(
        "rr",
        ring.JointModel(
            _rr_transform,
            motion=[_motion_fn],
            rcmg_draw_fn=_draw_rxyz,
            p_control_term=_p_control_term_rxyz,
            qd_from_q=_qd_from_q_cartesian,
            init_joint_params=_draw_random_joint_axis,
        ),
        1,
        overwrite=True,
    )
|
30
|
+
|
31
|
+
|
32
|
+
def _draw_random_joint_axis(key):
    """Return `{"joint_axes": axis}`, where `axis` is +x rotated by `maths.quat_random(key)`."""
    x_axis = jnp.array([1.0, 0, 0])
    random_rot = maths.quat_random(key)
    return {"joint_axes": maths.rotate(x_axis, random_rot)}
|