imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,202 @@
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import jaxopt
6
+ from jaxopt._src.base import Solver
7
+ from ring import algebra
8
+ from ring import base
9
+ from ring import maths
10
+ from ring.algorithms import jcalc
11
+
12
+
13
def forward_kinematics_transforms(
    sys: base.System, q: jax.Array
) -> Tuple[base.Transform, base.System]:
    """Compute the base-to-link transform of every link for the joint state `q`.

    The links are visited with `sys.scan`. For each link:
      1. the joint transform `transform2` is obtained from `jcalc.jcalc_transform`,
      2. it is composed with the link's fixed `transform1` into `transform`,
      3. it is chained onto the parent's base-to-link transform.

    Parents must be visited before their children; otherwise the lookup of the
    parent's entry below fails.

    Args:
        sys: System to evaluate.
        q: Minimal coordinates of the whole system. They are split per link by
            `sys.scan` (input type "q").

    Returns:
        - Transforms from base to links. Transforms first axis is (n_links,).
        - Updated system object with updated `transform2` and `transform` fields.
    """
    # Base-to-link transforms keyed by link index. The key -1 is the base itself
    # and is seeded with `Transform.zero()` (presumably the identity; confirm).
    base_to_link = {-1: base.Transform.zero()}

    def _visit_link(_, __, q_link, link, link_idx, parent_idx, joint_type: str):
        joint_transform = jcalc.jcalc_transform(
            joint_type, q_link, link.joint_params
        )
        link_transform = algebra.transform_mul(joint_transform, link.transform1)
        updated_link = link.replace(
            transform=link_transform, transform2=joint_transform
        )
        base_to_link[link_idx] = algebra.transform_mul(
            link_transform, base_to_link[parent_idx]
        )
        return base_to_link[link_idx], updated_link

    link_transforms, updated_links = sys.scan(
        _visit_link,
        "qllll",
        q,
        sys.links,
        list(range(sys.num_links())),
        sys.link_parents,
        sys.link_types,
    )
    return link_transforms, sys.replace(links=updated_links)
43
+
44
+
45
def forward_kinematics(
    sys: base.System, state: base.State
) -> Tuple[base.System, base.State]:
    """Run forward kinematics for the joint state stored in `state.q`.

    Returns:
        - `sys` with updated `transform` and `transform2` fields of its links.
        - `state` with `x` replaced by the base-to-link transforms.
    """
    link_transforms, updated_sys = forward_kinematics_transforms(sys, state.q)
    return updated_sys, state.replace(x=link_transforms)
55
+
56
+
57
def inverse_kinematics(
    sys: base.System,
    state: base.State,
) -> base.State:
    """Recover the minimal coordinates `q` from the link transforms in `state.x`.

    For every link, the joint-only transform `transform2` is isolated. This is done
    by removing the parent's base-to-link transform and the link's fixed
    `transform1`. The result is then inverted by the joint model's `inv_kin`.

    Args:
        sys: System whose links, parents and joint types are used.
        state: State providing `x`, the base-to-link transforms indexed by link.

    Returns:
        `state` with `q` replaced by the concatenation of the per-link coordinates.

    Raises:
        NotImplementedError: A link's joint model has no `inv_kin` function.
    """
    link_transforms = state.x
    # One coordinate chunk per link, appended in scan order.
    q_chunks = []

    def _invert_link(
        _,
        __,
        link_idx: int,
        x_link: base.Transform,
        link: base.Link,
        parent_idx: int,
        joint_type: str,
    ):
        # A parent index of -1 denotes the base.
        x_parent = (
            base.Transform.zero()
            if parent_idx == -1
            else link_transforms[parent_idx]
        )
        joint_params = jcalc._limit_scope_of_joint_params(
            joint_type, link.joint_params
        )
        parent_to_link = algebra.transform_mul(
            x_link, algebra.transform_inv(x_parent)
        )
        # Strip the fixed `transform1` so that only the joint's own motion remains.
        joint_transform = algebra.transform_mul(
            parent_to_link, algebra.transform_inv(link.transform1)
        )
        inv_kin = jcalc.get_joint_model(joint_type).inv_kin
        if inv_kin is None:
            raise NotImplementedError(
                f"Please specify for the custom joint `{joint_type}`"
                " the JointModel.inv_kin field."
            )
        q_chunks.append(inv_kin(joint_transform, joint_params))

    sys.scan(
        _invert_link,
        "lllll",
        list(range(sys.num_links())),
        link_transforms,
        sys.links,
        sys.link_parents,
        sys.link_types,
    )

    q_flat = jnp.concatenate(q_chunks)
    assert q_flat.ndim == 1
    assert q_flat.size == sys.q_size()

    return state.replace(q=q_flat)
98
+
99
+
100
def _safe_norm(vec: jax.Array) -> jax.Array:
    """Euclidean norm of `vec` with a zero (instead of NaN) gradient at the origin.

    `jnp.sqrt` has an infinite derivative at 0, so differentiating
    `sqrt(sum(vec**2))` at `vec == 0` yields `0 * inf = NaN`. The double
    `jnp.where` keeps the forward value identical (0 at the origin, the plain
    norm elsewhere) while routing the gradient through a non-zero dummy value.
    """
    squared = jnp.sum(vec**2)
    is_zero = squared == 0.0
    safe_squared = jnp.where(is_zero, 1.0, squared)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_squared))


def inverse_kinematics_endeffector(
    sys: base.System,
    endeffector_link_name: str,
    endeffector_x: base.Transform,
    error_weight_rot: float = 1.0,
    error_weight_pos: float = 1.0,
    q0: Optional[jax.Array] = None,
    random_q0_starts: Optional[int] = None,
    key: Optional[jax.Array] = None,
    custom_joints: Optional[dict[str, Callable[[jax.Array], jax.Array]]] = None,
    jaxopt_solver: Solver = jaxopt.LBFGS,
    **jaxopt_solver_kwargs,
) -> tuple[jax.Array, jax.Array, jaxopt.OptStep]:
    """Find the minimal coordinates (joint configuration) such that the endeffector
    reaches a desired rotational and positional configuration / state.

    The minimized loss is
    `error_weight_rot * angle_error + error_weight_pos * ||pos_error||_2`.
    Here `angle_error` is `maths.angle_error` between the desired and the achieved
    orientation of the endeffector link. `pos_error` is the difference between
    the desired and the achieved position.

    Args:
        sys (base.System): System under consideration.
        endeffector_link_name (str): Link in system which must reach a desired
            pos & rot state.
        endeffector_x (base.Transform): Desired position and rotation state values.
            Must be unbatched; use `jax.vmap` for batching.
        error_weight_rot (float, optional): Weight of the rotational error term in
            the optimized loss. Defaults to 1.0.
        error_weight_pos (float, optional): Weight of the positional error term in
            the optimized loss. Defaults to 1.0.
        q0 (Optional[jax.Array], optional): Initial minimal coordinates guess of
            length `sys.q_size()`. Defaults to None. If neither `q0` nor
            `random_q0_starts` is given, the `q` of `base.State.create(sys)` is
            used.
        random_q0_starts (Optional[int], optional): Number of random initial values
            to try. Each is drawn from a standard normal distribution, and the
            solution with the lowest loss is returned. Mutually exclusive with
            `q0`. Defaults to None.
        key (Optional[jax.Array], optional): PRNGKey, required whenever
            `random_q0_starts` is given. Defaults to None.
        custom_joints (Optional[dict[str, Callable[[jax.Array], jax.Array]]],
            optional): Dictionary that contains, for each custom joint type, a
            function that maps [-inf, inf] to the feasible joint value range.
            Defaults to None, which is treated as an empty dict.
            For example: By default, for a hinge joint it uses `maths.wrap_to_pi`.
        jaxopt_solver (Solver, optional): Solver class to use. It is instantiated
            as `jaxopt_solver(objective, **jaxopt_solver_kwargs)`.
            Defaults to jaxopt.LBFGS.
        **jaxopt_solver_kwargs: Forwarded to the `jaxopt_solver` constructor.

    Raises:
        NotImplementedError: Specific joint has no preprocess function given in
            `custom_joints`; but this is required. This is presumably raised
            inside `sys.coordinate_vector_to_q` (confirm).

    Returns:
        tuple[jax.Array, jax.Array, jaxopt.OptStep]:
            Minimal coordinates solution, residual loss, optimizer results.
            With `random_q0_starts`, the first two entries belong to the best
            start. The optimizer results are returned for all starts, batched
            along a leading axis of size `random_q0_starts`.
    """
    assert endeffector_x.ndim() == 1, "Use `jax.vmap` for batching"

    # Use `None` as the default instead of a shared mutable `{}`.
    if custom_joints is None:
        custom_joints = {}

    if random_q0_starts is not None:
        assert q0 is None, "Either provide `q0` or `random_q0_starts`."
        assert key is not None, "`random_q0_starts` requires `key`"

    if q0 is None:
        if random_q0_starts is None:
            q0 = base.State.create(sys).q
        else:
            # Solve once per random initial guess (vectorized with vmap).
            # Each recursive call takes the plain `q0` branch.
            q0s = jax.random.normal(key, shape=(random_q0_starts, sys.q_size()))
            qs, values, results = jax.vmap(
                lambda q0_i: inverse_kinematics_endeffector(
                    sys,
                    endeffector_link_name,
                    endeffector_x,
                    error_weight_rot,
                    error_weight_pos,
                    q0_i,
                    None,
                    None,
                    custom_joints,
                    jaxopt_solver,
                    **jaxopt_solver_kwargs,
                )
            )(q0s)

            # Find the result of the best q0 initial value. Indexing is traceable
            # because `best_q_index` is a traced value.
            best_q_index = jnp.argmin(values)
            best_q, best_q_value = jax.tree_util.tree_map(
                lambda arr: jax.lax.dynamic_index_in_dim(
                    arr, best_q_index, keepdims=False
                ),
                (qs, values),
            )
            return best_q, best_q_value, results
    else:
        assert len(q0) == sys.q_size()

    # Invariant across objective evaluations, so look it up once.
    endeffector_idx = sys.name_to_idx(endeffector_link_name)

    def objective(q: jax.Array) -> jax.Array:
        q = sys.coordinate_vector_to_q(q, custom_joints)
        xhat = forward_kinematics_transforms(sys, q)[0][endeffector_idx]
        error_rot = maths.angle_error(endeffector_x.rot, xhat.rot)
        # Same value as `jnp.sqrt(jnp.sum(diff**2))`, but the gradient stays finite
        # when the endeffector position is matched exactly.
        error_pos = _safe_norm(endeffector_x.pos - xhat.pos)
        return error_weight_rot * error_rot + error_weight_pos * error_pos

    solver = jaxopt_solver(objective, **jaxopt_solver_kwargs)
    results = solver.run(q0)
    q_sol = sys.coordinate_vector_to_q(results.params, custom_joints)
    # Stop gradients so that this value can be used to optimize, e.g.,
    # parameters in the system object such as `sys.links.joint_params`.
    q_sol_value = objective(jax.lax.stop_gradient(results.params))
    return q_sol, q_sol_value, results