imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,403 @@
1
+ from typing import Callable, Optional
2
+ import warnings
3
+
4
+ import jax
5
+ from jax import random
6
+ import jax.numpy as jnp
7
+ from ring import maths
8
+
9
+ Float = jax.Array
10
+ TimeDependentFloat = Callable[[Float], Float]
11
+
12
+
13
+ def _to_float(scalar: Float | TimeDependentFloat, t: Float) -> Float:
14
+ if isinstance(scalar, Callable):
15
+ return scalar(t)
16
+ return scalar
17
+
18
+
19
+ # APPROVED
20
def random_angle_over_time(
    key_t: random.PRNGKey,
    key_ang: random.PRNGKey,
    ANG_0: float,
    dang_min: float | TimeDependentFloat,
    dang_max: float | TimeDependentFloat,
    delta_ang_min: float | TimeDependentFloat,
    delta_ang_max: float | TimeDependentFloat,
    t_min: float,
    t_max: float | TimeDependentFloat,
    T: float,
    Ts: float,
    max_iter: int = 5,
    randomized_interpolation: bool = False,
    range_of_motion: bool = False,
    range_of_motion_method: str = "uniform",
    cdf_bins_min: int = 5,
    cdf_bins_max: Optional[int] = None,
    interpolation_method: str = "cosine",
) -> jax.Array:
    """Sample a random joint-angle trajectory on the grid ``jnp.arange(T, step=Ts)``.

    Knots ``(t_k, phi_k)`` are generated in a ``jax.lax.while_loop`` starting at
    ``(0, ANG_0)``. Each step advances time by ``dt ~ U(t_min, t_max(t))`` and
    draws the next angle with ``_resolve_range_of_motion``. The loop stops once
    the knot time exceeds ``T``. The knots are then resampled onto the output
    grid, either with ``cosInterpolate`` or, if ``randomized_interpolation`` is
    set, with ``interpolate(cdf_bins_min, cdf_bins_max, interpolation_method)``.

    Args:
        key_t: PRNG key for the knot spacing.
        key_ang: PRNG key for the knot angles. The key left over after the loop
            drives the randomized interpolation.
        ANG_0: Angle at ``t = 0``.
        dang_min, dang_max: Bounds of the angular velocity drawn per segment.
            They may be functions of time.
        delta_ang_min, delta_ang_max: Accepted range of ``|phi_k - phi_{k-1}|``.
            They may be functions of time.
        t_min, t_max: Bounds of the knot spacing. ``t_max`` may be a function
            of time.
        T: Trajectory duration.
        Ts: Sampling interval of the returned trajectory.
        max_iter: Retry budget when the angle change is out of bounds.
        randomized_interpolation: Randomly re-time each knot interval.
        range_of_motion: Keep knots inside ``[-pi, pi]``.
        range_of_motion_method: How the direction of motion is chosen when
            ``range_of_motion`` is set (see ``_resolve_range_of_motion``).
        cdf_bins_min, cdf_bins_max: Forwarded to ``interpolate``.
        interpolation_method: Forwarded to ``interpolate``. It is only honoured
            when ``randomized_interpolation`` is set; otherwise a warning is
            emitted for anything but ``"cosine"``.

    Returns:
        The resampled angles, one per entry of ``jnp.arange(T, step=Ts)``.
        Unless ``range_of_motion`` is set, they are wrapped with
        ``maths.wrap_to_pi``.
    """

    def body_fn_outer(val):
        i, t, phi, key_t, key_ang, ANG = val

        key_t, consume = random.split(key_t)
        dt = random.uniform(consume, minval=t_min, maxval=_to_float(t_max, t))

        key_ang, consume = random.split(key_ang)
        phi = _resolve_range_of_motion(
            range_of_motion,
            range_of_motion_method,
            _to_float(dang_min, t),
            _to_float(dang_max, t),
            _to_float(delta_ang_min, t),
            _to_float(delta_ang_max, t),
            dt,
            phi,
            consume,
            max_iter,
        )
        t += dt

        # TODO do we really need the `jnp.floor(t / Ts) * Ts` since we resample later
        # anyways
        ANG_i = jnp.array([[jnp.floor(t / Ts) * Ts, phi]])
        ANG = jax.lax.dynamic_update_slice_in_dim(ANG, ANG_i, start_index=i, axis=0)

        return i + 1, t, phi, key_t, key_ang, ANG

    def cond_fn_outer(val):
        # Keep adding knots until the knot time has passed `T`.
        return val[1] <= T

    # Preallocate the knot buffer. Every iteration advances `t` by at least
    # `t_min` and the loop runs while `t <= T`. That means up to
    # `T // t_min + 1` knots are written after the initial knot at index 0, so
    # `T // t_min + 2` rows are needed. With one row fewer, the last write would
    # be clamped by `dynamic_update_slice_in_dim` and silently overwrite the
    # previous knot.
    _warn_huge_preallocation(t_min, T)
    ANG = jnp.zeros((int(T // t_min) + 2, 2))
    ANG = ANG.at[0, 1].set(ANG_0)

    val_outer = (1, 0.0, ANG_0, key_t, key_ang, ANG)
    end, *_, consume, ANG = jax.lax.while_loop(cond_fn_outer, body_fn_outer, val_outer)
    # Rows that were never written are filled with the last knot, so the
    # trajectory holds its final value.
    ANG = jnp.where(
        (jnp.arange(len(ANG)) < end)[:, None],
        ANG,
        jax.lax.dynamic_index_in_dim(ANG, end - 1),
    )

    # resample
    t = jnp.arange(T, step=Ts)
    if randomized_interpolation:
        q = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
            t, ANG[:, 0], ANG[:, 1], consume
        )
    else:
        if interpolation_method != "cosine":
            warnings.warn(
                f"You have select interpolation method {interpolation_method}. "
                "Differnt choices of interpolation method are only available if "
                "`randomized_interpolation` is set."
            )
        q = cosInterpolate(t, ANG[:, 0], ANG[:, 1])

    # if range_of_motion is true, then it is wrapped already
    if not range_of_motion:
        q = maths.wrap_to_pi(q)

    return q
105
+
106
+
107
+ # APPROVED
108
def random_position_over_time(
    key: random.PRNGKey,
    POS_0: float,
    pos_min: float | TimeDependentFloat,
    pos_max: float | TimeDependentFloat,
    dpos_min: float | TimeDependentFloat,
    dpos_max: float | TimeDependentFloat,
    t_min: float,
    t_max: float | TimeDependentFloat,
    T: float,
    Ts: float,
    max_it: int,
    randomized_interpolation: bool = False,
    cdf_bins_min: int = 5,
    cdf_bins_max: Optional[int] = None,
    interpolation_method: str = "cosine",
) -> jax.Array:
    """Sample a random 1D position trajectory on the grid ``jnp.arange(T, step=Ts)``.

    Knots start at ``(0, POS_0)``, and each knot time advances by
    ``U(t_min, t_max(t))``. For every new knot, the position change is proposed
    as ``+/- U(dpos_min, dpos_max) * dt``. It is re-proposed until both of these
    hold:

    - the slope ``|dx / dt|`` lies strictly inside ``(dpos_min, dpos_max)``;
    - the position lies inside ``[pos_min, pos_max]``.

    All bounds are evaluated at the previous knot time. After ``max_it + 1``
    proposals the loop gives up, and the last proposal is kept even if it is out
    of bounds. The knots are then resampled with ``cosInterpolate``, or with
    ``interpolate`` if ``randomized_interpolation`` is set.

    Args:
        key: PRNG key. The key left over after the loop drives the randomized
            interpolation.
        POS_0: Position at ``t = 0``; the random walk starts from here.
        pos_min, pos_max: Position bounds. They may be functions of time.
        dpos_min, dpos_max: Velocity bounds. They may be functions of time.
        t_min, t_max: Bounds of the knot spacing. ``t_max`` may be a function
            of time.
        T: Trajectory duration.
        Ts: Sampling interval of the returned trajectory.
        max_it: Retry budget for out-of-bounds proposals.
        randomized_interpolation: Randomly re-time each knot interval.
        cdf_bins_min, cdf_bins_max, interpolation_method: Forwarded to
            ``interpolate``. They are ignored (without a warning) unless
            ``randomized_interpolation`` is set.

    Returns:
        The resampled positions, one per entry of ``jnp.arange(T, step=Ts)``.
    """

    def body_fn_inner(val):
        # Propose a new position `x` for the (fixed) knot time `t`.
        i, t, t_pre, x, x_pre, key = val
        dt = t - t_pre

        key, consume_sign, consume_speed = random.split(key, 3)
        sign = random.choice(consume_sign, jnp.array([-1.0, 1.0]))
        speed = random.uniform(
            consume_speed,
            minval=_to_float(dpos_min, t_pre),
            maxval=_to_float(dpos_max, t_pre),
        )
        x = x_pre + sign * speed * dt

        return i + 1, t, t_pre, x, x_pre, key

    def cond_fn_inner(val):
        # Continue while the proposal is out of bounds and the retry budget is
        # not yet exhausted.
        i, t, t_pre, x, x_pre, key = val
        dpos = jnp.abs((x - x_pre) / (t - t_pre))
        in_bounds = (
            (dpos < _to_float(dpos_max, t_pre))
            & (dpos > _to_float(dpos_min, t_pre))
            & (x >= _to_float(pos_min, t_pre))
            & (x <= _to_float(pos_max, t_pre))
        )
        give_up = i > max_it
        return jnp.logical_not(in_bounds | give_up)

    def body_fn_outer(val):
        i, t, t_pre, x, x_pre, key, POS = val
        key, consume = random.split(key)
        t += random.uniform(consume, minval=t_min, maxval=_to_float(t_max, t_pre))

        # that zero resets the max_it count
        val_inner = (0, t, t_pre, x, x_pre, key)
        _, t, t_pre, x, x_pre, key = jax.lax.while_loop(
            cond_fn_inner, body_fn_inner, val_inner
        )

        POS_i = jnp.array([[jnp.floor(t / Ts) * Ts, x]])
        POS = jax.lax.dynamic_update_slice_in_dim(POS, POS_i, start_index=i, axis=0)
        t_pre = t
        x_pre = x
        return i + 1, t, t_pre, x, x_pre, key, POS

    def cond_fn_outer(val):
        # Keep adding knots until the knot time has passed `T`.
        return val[1] <= T

    # Preallocate the knot buffer. The loop runs while `t <= T` and advances by
    # at least `t_min`, so up to `T // t_min + 1` knots follow the initial one:
    # `T // t_min + 2` rows are needed to avoid a clamped (overwriting) write.
    _warn_huge_preallocation(t_min, T)
    POS = jnp.zeros((int(T // t_min) + 2, 2))
    POS = POS.at[0, 1].set(POS_0)

    # The random walk must start from the same position as the first knot.
    # Previously it started from 0.0, which made the first segment jump from
    # `POS_0` towards 0 whenever `POS_0 != 0`.
    x0 = jnp.asarray(POS_0, dtype=POS.dtype)
    val_outer = (1, 0.0, 0.0, x0, x0, key, POS)
    end, *_, consume, POS = jax.lax.while_loop(cond_fn_outer, body_fn_outer, val_outer)
    # Rows that were never written are filled with the last knot.
    POS = jnp.where(
        (jnp.arange(len(POS)) < end)[:, None],
        POS,
        jax.lax.dynamic_index_in_dim(POS, end - 1),
    )

    # resample
    t = jnp.arange(T, step=Ts)
    if randomized_interpolation:
        r = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
            t, POS[:, 0], POS[:, 1], consume
        )
    else:
        # Unlike `random_angle_over_time`, a non-cosine `interpolation_method`
        # is deliberately ignored here without a warning.
        r = cosInterpolate(t, POS[:, 0], POS[:, 1])
    return r
222
+
223
+
224
+ _PREALLOCATION_WARN_LIMIT = 6000
225
+
226
+
227
+ def _warn_huge_preallocation(t_min, T):
228
+ N = int(T // t_min) + 1
229
+ if N > _PREALLOCATION_WARN_LIMIT:
230
+ warnings.warn(
231
+ f"The combination of `T`={T} and `t_min`={t_min} requires preallocating an "
232
+ f"array with axis-length of {N} which is larger than the warn limit of "
233
+ f"{_PREALLOCATION_WARN_LIMIT}. This might lead to large memory requirements"
234
+ " and/or large jit-times, consider reducing `t_min`."
235
+ )
236
+
237
+
238
+ def _clip_to_pi(phi):
239
+ return jnp.clip(phi, -jnp.pi, jnp.pi)
240
+
241
+
242
def _resolve_range_of_motion(
    range_of_motion,
    range_of_motion_method,
    dang_min,
    dang_max,
    delta_ang_min,
    delta_ang_max,
    dt,
    prev_phi,
    key,
    max_iter,
):
    """Draw the angle of the next knot, given the previous knot angle ``prev_phi``.

    A candidate moves away from ``prev_phi`` by ``dang * dt`` in a random
    direction, with ``dang`` drawn from ``[dang_min, dang_max]``. Candidates are
    redrawn until ``delta_ang_min <= |phi - prev_phi| <= delta_ang_max``. At most
    ``max_iter + 1`` candidates are drawn, and the last one is returned even if
    it is still out of bounds.

    If ``range_of_motion`` is set, the candidate is drawn uniformly between the
    endpoints ``prev_phi + sign * dang_min * dt`` and
    ``prev_phi + sign * dang_max * dt``, each clipped to ``[-pi, pi]``. The
    direction ``sign = +1`` is chosen with probability ``p``, set by
    ``range_of_motion_method``:

    - ``"coinflip"``: ``p = 0.5``.
    - ``"uniform"``: ``p = (1 - prev_phi / pi) / 2``, clipped to ``[0, 1]``.
    - ``"sigmoid"`` or ``"sigmoid-<scale>"``: ``p = sigmoid(-scale * prev_phi)``
      with default ``scale = 1.5``. ``p`` is forced to 0 (or 1) when
      ``prev_phi`` is within 0.01 of ``+pi`` (or ``-pi``).

    Args:
        range_of_motion: Whether to keep the angle inside ``[-pi, pi]``.
        range_of_motion_method: See above. Only read if ``range_of_motion``.
        dang_min, dang_max: Bounds of the angular velocity.
        delta_ang_min, delta_ang_max: Accepted range of the angle change.
        dt: Time between the previous and the next knot.
        prev_phi: Angle of the previous knot.
        key: PRNG key.
        max_iter: Retry budget.

    Returns:
        The next knot angle.

    Raises:
        NotImplementedError: If ``range_of_motion`` is set and
            ``range_of_motion_method`` is unknown.
    """
    if range_of_motion:
        # Probability of moving in the positive direction. It only depends on
        # `prev_phi`, so it is computed once, outside the retry loop.
        if range_of_motion_method == "coinflip":
            p = 0.5
        elif range_of_motion_method == "uniform":
            # Clip so that a starting angle outside [-pi, pi] cannot yield a
            # negative "probability".
            p = jnp.clip(0.5 * (1 - prev_phi / jnp.pi), 0.0, 1.0)
        elif range_of_motion_method.startswith("sigmoid"):
            # "sigmoid-<scale>"; partition keeps scales such as "1e-3" intact.
            _, sep, scale_str = range_of_motion_method.partition("-")
            scale = float(scale_str) if sep else 1.5
            hardcut = jnp.pi - 0.01
            p = jnp.where(
                prev_phi > hardcut,
                0.0,
                jnp.where(
                    prev_phi < -hardcut, 1.0, jax.nn.sigmoid(-scale * prev_phi)
                ),
            )
        else:
            raise NotImplementedError(
                f"Unknown `range_of_motion_method`={range_of_motion_method!r}; "
                "expected 'coinflip', 'uniform' or 'sigmoid[-<scale>]'."
            )
        probs = jnp.array([p, 1 - p])

    def _next_phi(key):
        key, consume = random.split(key)

        if range_of_motion:
            sign = random.choice(consume, jnp.array([1.0, -1.0]), p=probs)
            end_a = _clip_to_pi(prev_phi + sign * dang_min * dt)
            end_b = _clip_to_pi(prev_phi + sign * dang_max * dt)
            # For sign = -1 the endpoints come out in descending order.
            lower, upper = jnp.minimum(end_a, end_b), jnp.maximum(end_a, end_b)

            key, consume = random.split(key)
            return random.uniform(consume, minval=lower, maxval=upper)

        dphi = random.uniform(consume, minval=dang_min, maxval=dang_max) * dt
        key, consume = random.split(key)
        sign = random.choice(consume, jnp.array([1.0, -1.0]))
        return prev_phi + sign * dphi

    def body_fn(val):
        key, _, i = val
        key, consume = jax.random.split(key)
        next_phi = _next_phi(consume)
        return key, next_phi, i + 1

    def cond_fn(val):
        _, next_phi, i = val
        delta_phi = jnp.abs(next_phi - prev_phi)
        delta_in_bounds = (delta_phi >= delta_ang_min) & (delta_phi <= delta_ang_max)
        budget_exhausted = i > max_iter
        # `i == 0` forces at least one draw.
        return (i == 0) | jnp.logical_not(delta_in_bounds | budget_exhausted)

    # The `prev_phi` in the initial carry is only a placeholder for the
    # candidate slot; it is always replaced by the first draw.
    return jax.lax.while_loop(cond_fn, body_fn, (key, prev_phi, 0))[1]
312
+
313
+
314
def cosInterpolate(x, xp, fp):
    """Cosine-interpolate the knots ``(xp, fp)`` at the query points ``x``.

    Consider neighbouring knots ``(xp[k-1], f0)`` and ``(xp[k], f1)``, with
    ``alpha = (x - xp[k-1]) / (xp[k] - xp[k-1])``. Between them the value is
    ``(f0 + f1) / 2 + (f0 - f1) / 2 * cos(pi * alpha)``, which goes from ``f0``
    at ``alpha = 0`` to ``f1`` at ``alpha = 1``.

    Args:
        x: Query points, of any shape (including a scalar).
        xp: Knot positions, assumed sorted ascending (as ``jnp.searchsorted``
            requires). Repeated values are allowed.
        fp: Knot values, same length as ``xp``.

    Returns:
        Interpolated values with the shape of ``x``:

        - Queries beyond ``xp[-1]`` return ``fp[-1]``.
        - On a zero-width interval (repeated knot), the right knot value
          ``fp[k]`` is returned.
        - Queries before ``xp[0]`` extend the first interval's cosine curve.
    """
    i = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, len(xp) - 1)
    dx = xp[i] - xp[i - 1]
    # Divide by a dummy 1.0 on zero-width intervals. That branch is discarded
    # by the `jnp.where` below, but a 0/0 there would still make gradients NaN.
    safe_dx = jnp.where(dx == 0, 1.0, dx)
    alpha = (x - xp[i - 1]) / safe_dx

    f0, f1 = fp[i - 1], fp[i]
    # Purely elementwise, so no `vmap` is needed.
    f = (f0 + f1) / 2 + (f0 - f1) / 2 * jnp.cos(alpha * jnp.pi)
    f = jnp.where(dx == 0, f1, f)
    return jnp.where(x > xp[-1], fp[-1], f)
326
+
327
+
328
+ def _biject_alpha(alpha, cdf):
329
+ cdf_dx = 1 / (len(cdf) - 1)
330
+ left_idx = (alpha // cdf_dx).astype(int)
331
+ a = (alpha - left_idx * cdf_dx) / cdf_dx
332
+ return (1 - a) * cdf[left_idx] + a * cdf[left_idx + 1]
333
+
334
+
335
+ def _generate_cdf(cdf_bins_min, cdf_bins_max=None):
336
+ if cdf_bins_max is None:
337
+
338
+ def _generate_cdf_min_eq_max(cdf_bins):
339
+ def __generate_cdf(key):
340
+ samples = random.uniform(key, (cdf_bins,), minval=1e-6, maxval=1.0)
341
+ samples = jnp.hstack((jnp.array([0.0]), samples))
342
+ montonous = jnp.cumsum(samples)
343
+ cdf = montonous / montonous[-1]
344
+ return cdf
345
+
346
+ return __generate_cdf
347
+
348
+ return _generate_cdf_min_eq_max(cdf_bins=cdf_bins_min)
349
+
350
+ def _generate_cdf_min_uneq_max(dy_min, dy_max):
351
+ assert dy_max >= dy_min
352
+
353
+ def __generate_cdf(key):
354
+ key, consume = random.split(key)
355
+ cdf_bins = random.randint(consume, (), dy_min, dy_max + 1)
356
+ mask = jnp.where(jnp.arange(dy_max) < cdf_bins, 1, 0)
357
+ key, consume = random.split(key)
358
+ mask = random.permutation(consume, mask)
359
+ dy = random.uniform(key, (dy_max,), minval=1e-6, maxval=1.0)
360
+ dy = dy[jnp.cumsum(mask) - 1]
361
+ y = jnp.hstack((jnp.array([0.0]), dy))
362
+ montonous = jnp.cumsum(y)
363
+ cdf = montonous / montonous[-1]
364
+ return cdf
365
+
366
+ return __generate_cdf
367
+
368
+ return _generate_cdf_min_uneq_max(cdf_bins_min, cdf_bins_max)
369
+
370
+
371
def interpolate(
    cdf_bins_min: int = 1, cdf_bins_max: Optional[int] = None, method: str = "cosine"
):
    """Build an interpolator that randomly re-times each knot interval.

    The returned function ``_interpolate(x, xp, fp, key)`` works like
    ``cosInterpolate``, with one difference: inside every knot interval, the
    normalised position ``alpha`` is first warped by a random monotone CDF. The
    CDF is sampled by ``_generate_cdf(cdf_bins_min, cdf_bins_max)`` with one key
    per interval, so all query points in the same interval share the same CDF.
    The random trajectory samplers only use this when
    ``randomized_interpolation`` is set.

    Args:
        cdf_bins_min, cdf_bins_max: Forwarded to ``_generate_cdf``.
        method: Two-point blend, either ``"cosine"`` or ``"linear"``.

    Returns:
        ``_interpolate(x, xp, fp, key)``. Here ``x`` is a 1D array of query
        points (it is mapped over with ``jax.vmap``), ``xp`` holds sorted knot
        positions and ``fp`` the knot values. Queries beyond ``xp[-1]`` return
        ``fp[-1]``; zero-width intervals return the right knot value.

    Raises:
        NotImplementedError: If ``method`` is not supported.
        AssertionError: If ``cdf_bins_max < cdf_bins_min`` (from
            ``_generate_cdf``).
    """
    if method not in ("cosine", "linear"):
        raise NotImplementedError(
            f"Unknown interpolation `method`={method!r}; expected 'cosine' or 'linear'."
        )
    generate_cdf = _generate_cdf(cdf_bins_min, cdf_bins_max)

    def two_point_interp(f0, f1, alpha):
        # Elementwise; goes from f0 at alpha=0 to f1 at alpha=1.
        if method == "cosine":
            return (f0 + f1) / 2 + (f0 - f1) / 2 * jnp.cos(alpha * jnp.pi)
        return (1 - alpha) * f0 + alpha * f1

    def _interpolate(x, xp, fp, key):
        i = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, len(xp) - 1)
        dx = xp[i] - xp[i - 1]
        # Dummy denominator on zero-width intervals. The value is discarded
        # below, but a 0/0 would still make gradients NaN.
        safe_dx = jnp.where(dx == 0, 1.0, dx)
        alpha = (x - xp[i - 1]) / safe_dx

        # One key per knot interval. The first of the `len(xp) + 1` split keys
        # is unused, which keeps the same key assignment as before. Slicing
        # (instead of unpacking and reshaping to (n, 2)) also works with typed
        # keys.
        interval_keys = random.split(key, len(xp) + 1)[1:]
        cdfs = jax.vmap(generate_cdf)(interval_keys[i - 1])
        alpha = jax.vmap(_biject_alpha)(alpha, cdfs)

        f = jnp.where(dx == 0, fp[i], two_point_interp(fp[i - 1], fp[i], alpha))
        return jnp.where(x > xp[-1], fp[-1], f)

    return _interpolate
@@ -0,0 +1,6 @@
1
+ from .rr_imp_joint import register_rr_imp_joint
2
+ from .rr_joint import register_rr_joint
3
+ from .suntay import GP_DrawFnPair
4
+ from .suntay import MLP_DrawnFnPair
5
+ from .suntay import register_suntay
6
+ from .suntay import SuntayConfig
@@ -0,0 +1,69 @@
1
+ from dataclasses import replace
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import ring
6
+ from ring import maths
7
+ from ring.algorithms.jcalc import _draw_rxyz
8
+ from ring.algorithms.jcalc import _p_control_term_rxyz
9
+ from ring.algorithms.jcalc import _qd_from_q_cartesian
10
+
11
+
12
def register_rr_imp_joint(
    config_res=ring.MotionConfig(dang_max=5.0, t_max=0.4),
    ang_max_deg: float = 7.5,
    name: str = "rr_imp",
):
    """Register a two-DoF revolute joint: a primary rotation plus a small residual.

    ``q[0]`` rotates about ``params["joint_axes"]`` and ``q[1]`` about
    ``params["residual"]``. The two rotations are composed as
    ``maths.quat_mul(rot_residual, rot_primary)``.

    Motion sampling draws the primary angle with ``_draw_rxyz`` using the
    caller's config. The residual angle is drawn with ``config_res`` (with its
    ``T`` replaced by the caller's), scaled by ``deg2rad(ang_max_deg) / pi``,
    and then mean-centred.

    NOTE(review): the scaling presumably assumes ``_draw_rxyz`` returns angles
    within ``[-pi, pi]``; the subsequent centring can shift values beyond
    ``+/- ang_max_deg``. Confirm whether that is intended.

    Args:
        config_res: Motion config for the residual angle.
        ang_max_deg: Target amplitude of the residual angle, in degrees.
        name: Joint-type name to register (overwrites an existing one).
    """

    def transform(q, params):
        rot_primary = maths.quat_rot_axis(params["joint_axes"], q[0])
        rot_residual = maths.quat_rot_axis(params["residual"], q[1])
        return ring.Transform.create(rot=maths.quat_mul(rot_residual, rot_primary))

    def draw(config, key_t, key_value, dt, _):
        kt_primary, kt_residual = jax.random.split(key_t)
        kv_primary, kv_residual = jax.random.split(key_value)
        q_primary = _draw_rxyz(config, kt_primary, kv_primary, dt, _)
        residual_config = replace(config_res, T=config.T)
        q_residual = _draw_rxyz(residual_config, kt_residual, kv_residual, dt, _)
        # Scale down to the residual amplitude, then centre on zero.
        q_residual = q_residual * (jnp.deg2rad(ang_max_deg) / jnp.pi)
        q_residual = q_residual - jnp.mean(q_residual)
        return jnp.stack((q_primary, q_residual), axis=1)

    def motion_about(param_key):
        return lambda params: ring.base.Motion.create(ang=params[param_key])

    joint_model = ring.JointModel(
        transform,
        motion=[motion_about(k) for k in ("joint_axes", "residual")],
        rcmg_draw_fn=draw,
        p_control_term=_p_control_term_rxyz,
        qd_from_q=_qd_from_q_cartesian,
        init_joint_params=_draw_random_joint_axes,
    )
    # Registered with sizes (2, 2), presumably the q and qd dimensions.
    ring.register_new_joint_type(name, joint_model, 2, 2, overwrite=True)
59
+
60
+
61
def _draw_random_joint_axes(key):
    """Draw a randomly oriented pair of axes for the ``rr_imp`` joint.

    The primary axis starts as ``z`` and the residual axis as a random unit
    vector in the xy-plane. Both are then rotated by the same random quaternion,
    so they stay perpendicular.

    Returns:
        ``{"joint_axes": primary_axis, "residual": residual_axis}``.
    """
    key_phi, key_rot = jax.random.split(key)
    phi = jax.random.uniform(key_phi, maxval=2 * jnp.pi)
    q_rand = maths.quat_random(key_rot)
    primary = maths.rotate(jnp.array([0, 0, 1.0]), q_rand)
    residual = maths.rotate(jnp.array([jnp.cos(phi), jnp.sin(phi), 0.0]), q_rand)
    return {"joint_axes": primary, "residual": residual}
@@ -0,0 +1,33 @@
1
+ import jax.numpy as jnp
2
+ import ring
3
+ from ring import maths
4
+ from ring.algorithms.jcalc import _draw_rxyz
5
+ from ring.algorithms.jcalc import _p_control_term_rxyz
6
+ from ring.algorithms.jcalc import _qd_from_q_cartesian
7
+
8
+
9
def register_rr_joint():
    """Register the one-DoF revolute joint type ``"rr"`` with a per-joint axis.

    The joint rotates by ``q`` (squeezed to a scalar) about
    ``params["joint_axes"]``. It reuses the ``rxyz`` motion sampler,
    P-control term and ``qd_from_q``. Axes are initialised by
    ``_draw_random_joint_axis``. An existing ``"rr"`` registration is
    overwritten.
    """

    def transform(q, params):
        angle = jnp.squeeze(q)
        rot = ring.maths.quat_rot_axis(params["joint_axes"], angle)
        return ring.Transform.create(rot=rot)

    def motion(params):
        return ring.base.Motion.create(ang=params["joint_axes"])

    joint_model = ring.JointModel(
        transform,
        motion=[motion],
        rcmg_draw_fn=_draw_rxyz,
        p_control_term=_p_control_term_rxyz,
        qd_from_q=_qd_from_q_cartesian,
        init_joint_params=_draw_random_joint_axis,
    )
    ring.register_new_joint_type("rr", joint_model, 1, overwrite=True)
30
+
31
+
32
def _draw_random_joint_axis(key):
    """Return ``{"joint_axes": axis}``: the x unit vector under a random rotation from ``key``."""
    random_rot = maths.quat_random(key)
    axis = maths.rotate(jnp.array([1.0, 0, 0]), random_rot)
    return {"joint_axes": axis}