imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,424 @@
1
+ from dataclasses import dataclass
2
+ from typing import Callable, NamedTuple, Optional
3
+
4
+ import haiku as hk
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from tree_utils import PyTree
8
+
9
+ import ring
10
+ from ring import maths
11
+ from ring.algorithms._random import random_angle_over_time
12
+
13
# Parameters of a drawn function are an arbitrary pytree.
Params = PyTree


class DrawnFnPair(NamedTuple):
    """A randomly drawn scalar function of the flexion angle, split into init/apply.

    Fields:
        init: ``init(key) -> params`` draws the function's parameters (a pytree)
            from a PRNG key.
        apply: ``apply(params, q) -> value`` evaluates the drawn function at the
            flexion angle ``q``.
    """

    init: Callable[[jax.Array], Params]
    apply: Callable[[Params, jax.Array], jax.Array]


# A factory builds a DrawnFnPair from (flexion grid `xs`, output min, output max).
DrawnFnPairFactory = Callable[[jax.Array, float, float], DrawnFnPair]
25
+
26
+
27
def deg2rad(deg: float):
    """Convert an angle from degrees to radians.

    Uses ``jnp.pi`` (full double precision) rather than a truncated literal,
    which keeps it consistent with the other ``jnp.pi`` uses in this module.

    Args:
        deg: Angle in degrees (a Python number or an array).

    Returns:
        The angle in radians.
    """
    return (deg / 180.0) * jnp.pi
29
+
30
+
31
def GP_DrawFnPair(
    length_scale: float = 1.4, large_abs_values_of_gps: float = 0.25
) -> DrawnFnPairFactory:
    """Build a factory of Gaussian-process-sampled, piecewise-linear functions.

    The returned factory takes ``(xs, mn, mx)``. Its ``init`` samples one GP
    draw (RBF kernel with ``length_scale``) on the grid ``xs`` and maps it into
    ``[mn, mx]`` via ``_gp_draw_and_rom``. It passes
    ``±large_abs_values_of_gps`` as the reference bounds for the min-max
    rescaling. Its ``apply`` linearly interpolates the stored samples at ``q``.

    Args:
        length_scale: RBF kernel length scale for the GP draw.
        large_abs_values_of_gps: Magnitude of the reference bounds used when
            rescaling the GP sample into ``[mn, mx]``.
    """
    bound = large_abs_values_of_gps

    def factory(xs, mn, mx):
        def init(key):
            sampled_ys = _gp_draw_and_rom(
                key=key,
                xs=xs,
                ys=None,
                length_scale=length_scale,
                mn=mn,
                mx=mx,
                amin=-bound,
                amax=bound,
            )
            return {"xs": xs, "ys": sampled_ys}

        def apply(params, q):
            # piecewise-linear lookup of the drawn curve
            return jnp.interp(q, params["xs"], params["ys"])

        return DrawnFnPair(init=init, apply=apply)

    return factory
57
+
58
+
59
@dataclass
class SuntayConfig:
    """Configuration consumed by ``register_suntay``.

    Angles are in radians (built with ``deg2rad``). Each ``<prefix>_min`` /
    ``<prefix>_max`` pair is the output range handed to the matching
    ``<prefix>_factory``. The factory draws that quantity as a function of the
    flexion angle. Translation units are presumably metres (TODO confirm).
    """

    # Flexion angle range: the joint's single coordinate, and the domain on
    # which every other drawn function is defined.
    flexion_rot_min: float = deg2rad(-5.0)
    flexion_rot_max: float = deg2rad(95.0)
    # `method` passed to `restrict` when bounding a drawn flexion trajectory.
    flexion_rot_restrict_method: str = "minmax"

    # Translation S1 (params key "ys_S1").
    flexion_pos_min: float = -0.015
    flexion_pos_max: float = 0.015
    flexion_pos_factory: DrawnFnPairFactory = GP_DrawFnPair()

    # Abduction rotation (params key "ys_beta"; enters beta with a minus sign).
    abduction_rot_min: float = deg2rad(-4)
    abduction_rot_max: float = deg2rad(4)
    abduction_rot_factory: DrawnFnPairFactory = GP_DrawFnPair()

    # Translation S2 (params key "ys_S2").
    abduction_pos_min: float = -0.015
    abduction_pos_max: float = 0.015
    abduction_pos_factory: DrawnFnPairFactory = GP_DrawFnPair()

    # External rotation gamma (params key "ys_gamma").
    external_rot_min: float = deg2rad(-10)
    external_rot_max: float = deg2rad(10)
    external_rot_factory: DrawnFnPairFactory = GP_DrawFnPair()

    # Translation S3 (params key "ys_S3").
    external_pos_min: float = -0.06
    external_pos_max: float = 0.0
    external_pos_factory: DrawnFnPairFactory = GP_DrawFnPair()

    # Number of flexion grid points the drawn functions are built on.
    num_points: int = 50
    # If set, replaces the MotionConfig passed to the flexion-angle draw function.
    mconfig: Optional[ring.MotionConfig] = None
87
+
88
+
89
def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
    """Register a Grood & Suntay knee joint as a new ring joint type.

    Ref to 'E.S. Grood and W.J. Suntay' paper.

    The joint has one generalized coordinate, the flexion angle ``alpha``.
    Five further quantities are functions of flexion. Each is drawn by the
    matching ``DrawnFnPairFactory`` in ``sconfig`` on a grid of
    ``sconfig.num_points`` flexion angles spanning
    ``[sconfig.flexion_rot_min, sconfig.flexion_rot_max]``:

    * ``ys_S1``    <- ``flexion_pos_*``   (translation S1)
    * ``ys_beta``  <- ``abduction_rot_*`` (abduction; enters beta negated)
    * ``ys_S2``    <- ``abduction_pos_*`` (translation S2)
    * ``ys_gamma`` <- ``external_rot_*``  (rotation gamma)
    * ``ys_S3``    <- ``external_pos_*``  (translation S3)

    Args:
        sconfig: Ranges, function factories and optional motion config.
        name: Joint type name. It is passed to ``ring.register_new_joint_type``
            with ``overwrite=True``, so an existing registration is replaced.
    """

    flexion_xs = jnp.linspace(
        sconfig.flexion_rot_min, sconfig.flexion_rot_max, num=sconfig.num_points
    )

    # Insertion order matters: `_init_joint_params_suntay` splits the PRNG key
    # in this order.
    draw_fn_pairs = {}
    for config_name, params_name in zip(
        [
            "flexion_pos",
            "abduction_rot",
            "abduction_pos",
            "external_rot",
            "external_pos",
        ],
        ["ys_S1", "ys_beta", "ys_S2", "ys_gamma", "ys_S3"],
    ):
        # Read the config fields eagerly rather than through a loop-local
        # lambda, which would late-bind `config_name`.
        factory = getattr(sconfig, config_name + "_factory")
        mn = getattr(sconfig, config_name + "_min")
        mx = getattr(sconfig, config_name + "_max")
        draw_fn_pairs[params_name] = factory(flexion_xs, mn, mx)

    def _suntay_rotation_matrix_R_transpose_eq26(alpha, beta, gamma):
        """Return the 3x3 matrix R^T (named after eq. (26) of the paper)."""
        sin_alp, sin_bet, sin_gam = jnp.sin(alpha), jnp.sin(beta), jnp.sin(gamma)
        cos_alp, cos_bet, cos_gam = jnp.cos(alpha), jnp.cos(beta), jnp.cos(gamma)
        return jnp.array(
            [
                [cos_gam * sin_bet, sin_gam * sin_bet, cos_bet],
                [
                    -cos_alp * sin_gam - cos_gam * sin_alp * cos_bet,
                    cos_alp * cos_gam - sin_gam * sin_alp * cos_bet,
                    sin_bet * sin_alp,
                ],
                [
                    sin_alp * sin_gam - cos_gam * cos_alp * cos_bet,
                    -cos_gam * sin_alp - cos_alp * cos_bet * sin_gam,
                    cos_alp * sin_bet,
                ],
            ]
        ).T

    def _suntay_translation_vector_H_eq9(alpha, beta, S):
        """Return the translation H = U @ S (eq. (9)), with U from eq. (10)."""
        sin_alp, sin_bet = jnp.sin(alpha), jnp.sin(beta)
        cos_alp, cos_bet = jnp.cos(alpha), jnp.cos(beta)
        # eq (10)
        U = jnp.array(
            [
                [1, 0, cos_bet],
                [0, cos_alp, sin_alp * sin_bet],
                [0, -sin_alp, cos_alp * sin_bet],
            ]
        )
        return U @ S

    def _alpha_beta_gamma_S(q_flexion, params):
        """Evaluate the drawn functions at one flexion angle.

        Args:
            q_flexion: Array of shape (1,) holding the flexion angle.
            params: Joint params keyed by "ys_S1", "ys_S2", "ys_S3",
                "ys_beta" and "ys_gamma".

        Returns:
            ``(alpha, beta, gamma, S)``. Here ``alpha`` is the flexion angle,
            ``beta = pi/2 - abduction(alpha)``, ``gamma = external_rot(alpha)``,
            and ``S`` stacks ``[S1, S2, S3]``.
        """
        assert q_flexion.shape == (1,)

        # (1,) -> (,)
        q_flexion = q_flexion[0]

        S_123 = []
        for i in range(1, 4):
            key = f"ys_S{i}"
            S_123.append(draw_fn_pairs[key].apply(params[key], q_flexion))
        S = jnp.stack(S_123)
        # table 2 of suntay paper
        alpha = q_flexion
        # note the minus sign, because in config we specify `abduction` not `adduction`
        adduction = -draw_fn_pairs["ys_beta"].apply(params["ys_beta"], q_flexion)
        beta = jnp.pi / 2 + adduction
        gamma = draw_fn_pairs["ys_gamma"].apply(params["ys_gamma"], q_flexion)
        return alpha, beta, gamma, S

    def _utils_find_suntay_joint(sys: ring.System) -> str:
        """Return the name of the single link whose type is `name`.

        Raises:
            Exception: if no such link exists, or more than one does.
        """
        suntay_link_name = None
        for link_name, link_type in zip(sys.link_names, sys.link_types):
            if link_type == name:
                if suntay_link_name is not None:
                    raise Exception(
                        f"multiple links of type `{name}` found, link_names "
                        f"are [{suntay_link_name}, {link_name}]"
                    )
                suntay_link_name = link_name

        if suntay_link_name is None:
            raise Exception(
                f"no link with type `{name}` found, link_types are {sys.link_types}"
            )
        return suntay_link_name

    def _utils_Q_S_H_alpha_beta_gamma(sys: ring.System, qs: jax.Array):
        """Compute Q, S, H, alpha, beta, gamma for every timestep of `qs`.

        Args:
            sys: System containing exactly one link of type `name`.
            qs: Array of shape (timesteps, sys.q_size()).

        Returns:
            Tuple ``(Q, S, H, alpha, beta, gamma)``, each stacked over timesteps.
        """
        # qs.shape = (timesteps, q_size)
        assert qs.ndim == 2
        assert qs.shape[-1] == sys.q_size()

        suntay_link_name = _utils_find_suntay_joint(sys)

        # `jax.tree_map` is deprecated/removed in recent JAX releases;
        # `jax.tree_util.tree_map` is the stable spelling.
        params = jax.tree_util.tree_map(
            lambda arr: arr[sys.idx_map("l")[suntay_link_name]],
            sys.links.joint_params[name],
        )
        # shape = (timesteps, 1)
        q_flexion = qs[:, sys.idx_map("q")[suntay_link_name]]

        @jax.vmap
        def _Q_S_H_alpha_beta_gamma_from_q_flexion(q_flexion):
            alpha, beta, gamma, S = _alpha_beta_gamma_S(q_flexion, params)
            cos_bet = jnp.cos(beta)
            Q = jnp.array([S[0] + S[2] * cos_bet, S[1], -S[2] - S[0] * cos_bet])
            # translation from femur to tibia
            H = _suntay_translation_vector_H_eq9(alpha, beta, S)
            return Q, S, H, alpha, beta, gamma

        return _Q_S_H_alpha_beta_gamma_from_q_flexion(q_flexion)

    def _transform_suntay(q_flexion, params):
        """Joint transform: rotation from R^T (as a quaternion) plus translation H."""
        alpha, beta, gamma, S = _alpha_beta_gamma_S(q_flexion, params)

        # rotation from femur to tibia
        R_T = _suntay_rotation_matrix_R_transpose_eq26(alpha, beta, gamma)
        q_fem_tib = maths.quat_from_3x3(R_T)
        # translation from femur to tibia
        H = _suntay_translation_vector_H_eq9(alpha, beta, S)

        return ring.Transform.create(pos=H, rot=q_fem_tib)

    def _init_joint_params_suntay(key):
        """Draw parameters for every drawn function, one fresh subkey each."""
        params = dict()
        for params_name, draw_fn_pair in draw_fn_pairs.items():
            key, consume = jax.random.split(key)
            params[params_name] = draw_fn_pair.init(consume)

        return params

    def _draw_flexion_angle(
        mconfig: ring.MotionConfig,
        key_t: jax.random.PRNGKey,
        key_value: jax.random.PRNGKey,
        dt: float,
        _: jax.Array,
    ) -> jax.Array:
        """Draw a random flexion-angle trajectory.

        The initial angle is uniform in ``[mconfig.ang0_min, mconfig.ang0_max]``
        and wrapped to pi. It is passed to ``random_angle_over_time`` and the
        result is bounded to ``[flexion_rot_min, flexion_rot_max]`` via
        ``restrict`` (reference range ``[-pi, pi]``). ``sconfig.mconfig``, when
        set, takes precedence over the `mconfig` argument.
        """
        key_value, consume = jax.random.split(key_value)

        if sconfig.mconfig is not None:
            mconfig = sconfig.mconfig

        ANG_0 = jax.random.uniform(
            consume, minval=mconfig.ang0_min, maxval=mconfig.ang0_max
        )
        # `random_angle_over_time` always returns wrapped angles, thus it would be
        # inconsistent to allow an initial value that is not wrapped
        ANG_0 = maths.wrap_to_pi(ANG_0)
        qs_flexion = random_angle_over_time(
            key_t,
            key_value,
            ANG_0,
            mconfig.dang_min,
            mconfig.dang_max,
            mconfig.delta_ang_min,
            mconfig.delta_ang_max,
            mconfig.t_min,
            mconfig.t_max,
            mconfig.T,
            dt,
            # positional literal; its meaning is defined by random_angle_over_time
            5,
            mconfig.randomized_interpolation_angle,
            mconfig.range_of_motion_hinge,
            mconfig.range_of_motion_hinge_method,
            mconfig.cdf_bins_min,
            mconfig.cdf_bins_max,
            mconfig.interpolation_method,
        )
        return restrict(
            qs_flexion,
            sconfig.flexion_rot_min,
            sconfig.flexion_rot_max,
            -jnp.pi,
            jnp.pi,
            method=sconfig.flexion_rot_restrict_method,
        )

    joint_model = ring.JointModel(
        transform=_transform_suntay,
        rcmg_draw_fn=_draw_flexion_angle,
        init_joint_params=_init_joint_params_suntay,
        utilities=dict(
            Q_S_H_alpha_beta_gamma=_utils_Q_S_H_alpha_beta_gamma,
            find_suntay_joint=_utils_find_suntay_joint,
        ),
    )
    ring.register_new_joint_type(name, joint_model, 1, qd_width=0, overwrite=True)
280
+
281
+
282
def MLP_DrawnFnPair(
    center: bool = False, flexion_center: Optional[float] = None
) -> DrawnFnPairFactory:
    """Build a factory of randomly initialised MLP functions of flexion.

    For a factory call ``(xs, mn, mx)``, the input ``q`` is first normalised
    with the min/max of ``xs`` and centred to roughly [-0.5, 0.5]. It is then
    fed through an MLP with layer sizes [10, 5, 1], ``tanh`` activations and
    ``RandomNormal`` weight init. The output is squashed with a sigmoid and
    mapped into ``[mn, mx]``.

    Args:
        center: If True, ``apply`` returns ``f(q) - f(flexion_center)``.
        flexion_center: Flexion angle at which the centred function is zero.
            Defaults to the midpoint of ``xs``, resolved separately for
            each factory call.
    """

    def factory(xs, mn, mx):
        flexion_mn = jnp.min(xs)
        flexion_mx = jnp.max(xs)

        # Resolve the centring point per factory call. (Rebinding the
        # enclosing `flexion_center` via `nonlocal` would make the first
        # call's midpoint stick for every later call.)
        if flexion_center is None:
            center_q = (flexion_mn + flexion_mx) / 2
        else:
            center_q = flexion_center

        @hk.without_apply_rng
        @hk.transform
        def mlp(x):
            # normalize the x input; [0, 1]
            x = _shift(x, flexion_mn, flexion_mx)
            # center the x input; [-0.5, 0.5]
            x = x - 0.5
            net = hk.nets.MLP(
                [10, 5, 1],
                activation=jnp.tanh,
                w_init=hk.initializers.RandomNormal(),
            )
            return net(x)

        example_q = jnp.zeros((1,))

        def init(key):
            return mlp.init(key, example_q)

        def _apply(params, q):
            # `jnp.asarray` so Python scalars (e.g. a user-given
            # `flexion_center`) can be indexed like arrays.
            q = jnp.asarray(q)[None]
            return jnp.squeeze(_shift_inv(jax.nn.sigmoid(mlp.apply(params, q)), mn, mx))

        if center:

            def apply(params, q):
                return _apply(params, q) - _apply(params, center_q)

        else:
            apply = _apply

        return DrawnFnPair(init, apply)

    return factory
329
+
330
+
331
def _gp_draw_and_rom(key, xs, ys, length_scale, mn, mx, amin, amax):
    """Draw one GP sample on `xs` and map it into the range ``[mn, mx]``.

    The sample is rescaled by ``restrict`` (default ``"minmax"``) with
    reference bounds ``amin``/``amax``. When a mean ``ys`` is supplied, those
    bounds are offset by ``min(ys)`` and ``max(ys)`` respectively.
    """
    sample = _gp_draw(key, xs, ys, length_scale)
    if ys is None:
        return restrict(sample, mn, mx, amin, amax)
    return restrict(sample, mn, mx, amin + jnp.min(ys), amax + jnp.max(ys))
337
+
338
+
339
def _gp_draw(key, xs, ys=None, length: float = 1.0, noise=0.0, method="svd", **kwargs):
    """Sample one Gaussian-process draw evaluated on the 1-D grid `xs`.

    The mean is `ys`, or zeros when `ys` is None. The covariance is the RBF
    Gram matrix over `xs` (length scale `length`) plus `noise` on the diagonal.
    `method` and any extra keyword arguments are forwarded to
    ``jax.random.multivariate_normal``.
    """
    mean = jnp.zeros_like(xs) if ys is None else ys

    def kernel(*args):
        return _rbf_kernel(*args, length=length)

    covariance = _gp_K(kernel, xs, noise)
    return jax.random.multivariate_normal(
        key=key, mean=mean, cov=covariance, method=method, **kwargs
    )
346
+
347
+
348
def _gp_K(kernel, xs, noise: float):
    """Build the Gram matrix of `kernel` over the 1-D points `xs`.

    Each point is lifted to shape (1,) before being passed to `kernel`. The
    kernel must return shape (1,), which the shape assertion checks.

    Returns:
        Array of shape (N, N) with ``K[i, j] = kernel(xs[i], xs[j])``, plus
        `noise` added on the diagonal.
    """
    assert xs.ndim == 1
    n_points = xs.shape[0]
    points = xs[:, None]

    def gram_row(x_i):
        return jax.vmap(lambda x_j: kernel(x_i, x_j))(points)

    gram = jax.vmap(gram_row)(points)
    assert gram.shape == (n_points, n_points, 1)
    return gram[..., 0] + jnp.eye(n_points) * noise
356
+
357
+
358
def _rbf_kernel(x1: float, x2: float, length: float):
    """Squared-exponential (RBF) kernel: ``exp(-(x1 - x2)**2 / (2 * length**2))``."""
    squared_distance = (x1 - x2) ** 2
    return jnp.exp(-squared_distance / (2 * length**2))
360
+
361
+
362
def _shift(ys, min, max):
    """Affinely map `ys` so that `min` goes to 0 and `max` goes to 1."""
    span = max - min
    offset = ys - min
    return offset / span
364
+
365
+
366
def _shift_inv(ys, min, max):
    """Inverse of `_shift`: map 0 to `min` and 1 to `max`."""
    span = max - min
    return ys * span + min
368
+
369
+
370
def _normalize(ys, amin=None, amax=None):
    """Min-max normalise `ys` to [0, 1].

    The normalisation range starts as ``[min(ys), max(ys)]``. It is widened,
    never narrowed, by the optional bounds: the lower end becomes
    ``min(amin, min(ys))`` and the upper end becomes ``max(amax, max(ys))``.
    A degenerate range (upper == lower) divides by zero in `_shift`.
    """
    ys_min = jnp.min(ys)
    lower = ys_min if amin is None else jnp.min(jnp.array([amin, ys_min]))
    ys_max = jnp.max(ys)
    upper = ys_max if amax is None else jnp.max(jnp.array([amax, ys_max]))
    return _shift(ys, lower, upper)
380
+
381
+
382
def _smoothclamp(x, mi, mx):
    """Clamp `x` to ``[mi, mx]`` with a smoothstep (``3t^2 - 2t^3``) transition.

    Values below `mi` map to `mi` and values above `mx` map to `mx`. In
    between, the position ``t`` within the interval is eased by the
    smoothstep polynomial.
    """
    t = (x - mi) / (mx - mi)
    smoothstep = 3 * t**2 - 2 * t**3
    eased = jnp.where(t < 0, 0, jnp.where(t <= 1, smoothstep, 1))
    return mi + (mx - mi) * eased
386
+
387
+
388
+ def _sigmoidclamp(x, mi, mx):
389
+ return mi + (mx - mi) * (lambda t: (1 + 200 ** (-t + 0.5)) ** (-1))(
390
+ (x - mi) / (mx - mi)
391
+ )
392
+
393
+
394
def restrict(
    ys,
    min: float,
    max: float,
    actual_min=None,
    actual_max=None,
    method: str = "minmax",
    method_kwargs=None,
):
    """Map or clamp `ys` into the interval ``[min, max]``.

    Methods:
        * ``"minmax"``: normalise to [0, 1] with ``_normalize`` (range widened
          by `actual_min`/`actual_max` if given), then scale to [min, max].
        * ``"clip"``: hard clip with ``jnp.clip``.
        * ``"smoothclamp"``: smoothstep clamp via ``_smoothclamp``.
        * ``"sigmoidclamp"``: logistic clamp via ``_sigmoidclamp``.
        * ``"sigmoid"``: normalise to [0, 1], spread to
          ``[-stepness, stepness]`` (``method_kwargs["stepness"]``, default
          3.0), apply a sigmoid, then scale to [min, max].

    Args:
        ys: Values to restrict.
        min: Lower end of the target interval.
        max: Upper end of the target interval.
        actual_min: Optional reference lower bound (``"minmax"``/``"sigmoid"``).
        actual_max: Optional reference upper bound (``"minmax"``/``"sigmoid"``).
        method: One of the method names above.
        method_kwargs: Optional extra options; only ``"sigmoid"`` reads it.
            Defaults to no options. ``None`` replaces a shared mutable
            ``dict()`` default.

    Raises:
        NotImplementedError: if `method` is not one of the names above.
    """
    if method_kwargs is None:
        method_kwargs = {}

    if method == "minmax":
        # scale to [0, 1]
        ys = _normalize(ys, actual_min, actual_max)
        # scale to [min, max]
        return _shift_inv(ys, min, max)
    elif method == "clip":
        return jnp.clip(ys, min, max)
    elif method == "smoothclamp":
        return _smoothclamp(ys, min, max)
    elif method == "sigmoidclamp":
        return _sigmoidclamp(ys, min, max)
    elif method == "sigmoid":
        # scale to [0, 1]
        ys = _normalize(ys, actual_min, actual_max)
        # scale to [-stepness, stepness]
        stepness = method_kwargs.get("stepness", 3.0)
        ys = _shift_inv(ys, -stepness, stepness)
        # scale to [0, 1]
        ys = jax.nn.sigmoid(ys)
        return _shift_inv(ys, min, max)
    else:
        raise NotImplementedError(
            f"unknown `restrict` method {method!r}; expected one of "
            "'minmax', 'clip', 'smoothclamp', 'sigmoidclamp', 'sigmoid'"
        )