imt-ring 1.6.15__py3-none-any.whl → 1.6.17__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.15.dist-info → imt_ring-1.6.17.dist-info}/METADATA +3 -3
- {imt_ring-1.6.15.dist-info → imt_ring-1.6.17.dist-info}/RECORD +10 -9
- {imt_ring-1.6.15.dist-info → imt_ring-1.6.17.dist-info}/WHEEL +1 -1
- ring/__init__.py +4 -0
- ring/algorithms/custom_joints/__init__.py +1 -0
- ring/algorithms/custom_joints/rsaddle_joint.py +40 -0
- ring/algorithms/generator/finalize_fns.py +9 -5
- ring/algorithms/jcalc.py +32 -1
- ring/ml/ringnet.py +11 -2
- {imt_ring-1.6.15.dist-info → imt_ring-1.6.17.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: imt-ring
|
3
|
-
Version: 1.6.
|
3
|
+
Version: 1.6.17
|
4
4
|
Summary: RING: Recurrent Inertial Graph-based Estimator
|
5
5
|
Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
|
6
6
|
Project-URL: Homepage, https://github.com/SimiPixel/ring
|
@@ -32,11 +32,11 @@ Requires-Dist: pytest-xdist ; extra == 'dev'
|
|
32
32
|
Requires-Dist: nbmake ; extra == 'dev'
|
33
33
|
|
34
34
|
<p align="center">
|
35
|
-
<img src="https://raw.githubusercontent.com/
|
35
|
+
<img src="https://raw.githubusercontent.com/simon-bachhuber/ring/main/docs/img/icon.svg" height="200" />
|
36
36
|
</p>
|
37
37
|
|
38
38
|
# Recurrent Inertial Graph-based Estimator (RING)
|
39
|
-
<img src="https://raw.githubusercontent.com/
|
39
|
+
<img src="https://raw.githubusercontent.com/simon-bachhuber/ring/main/docs/img/coverage_badge.svg" height="20" />
|
40
40
|
|
41
41
|
## Installation
|
42
42
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
ring/__init__.py,sha256=
|
1
|
+
ring/__init__.py,sha256=ncBRdHvge6uqpzFyk8_HUQNx4kZxENpVXTJTEK_SjCg,5216
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
3
|
ring/base.py,sha256=Ystn1EjTyOXBhVm5koroV_YPUYtFxrteJLd-XR3kEL8,33840
|
4
4
|
ring/maths.py,sha256=qPHH6TpHCK3TgExI98gNEySoSRKOwteN9McUlyUFipI,12207
|
@@ -6,17 +6,18 @@ ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
|
6
6
|
ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
|
7
7
|
ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
|
8
8
|
ring/algorithms/dynamics.py,sha256=GOedL1STj6oXcXgMA7dB4PabvCQxPBbirJQhXBRuKqE,10929
|
9
|
-
ring/algorithms/jcalc.py,sha256=
|
9
|
+
ring/algorithms/jcalc.py,sha256=kQgrRE0XoUBrcdeFbUw_xKnL4m1P2G1Q1m4n7BX2yDk,30064
|
10
10
|
ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
|
11
11
|
ring/algorithms/sensors.py,sha256=0xOzdQIc1kBF0CkoPXWWCx3MmV4SG3wj7knVnnMWq9M,18124
|
12
|
-
ring/algorithms/custom_joints/__init__.py,sha256=
|
12
|
+
ring/algorithms/custom_joints/__init__.py,sha256=3pQ-Is_HBTQDkzESCNg9VfoP8wvseWmooryG8ERnu_A,366
|
13
13
|
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
|
14
14
|
ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
|
15
|
+
ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
|
15
16
|
ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
|
16
17
|
ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
|
17
18
|
ring/algorithms/generator/base.py,sha256=vxUdA0ZeSNH3SOanL51qVRvCiJrmsWQyQX0g2fdm3Rg,15825
|
18
19
|
ring/algorithms/generator/batch.py,sha256=9yFxVv11hij-fJXGPxA3zEh1bE2_jrZk0R7kyGaiM5c,2551
|
19
|
-
ring/algorithms/generator/finalize_fns.py,sha256=
|
20
|
+
ring/algorithms/generator/finalize_fns.py,sha256=EZ5p7fuZu0Zd0rHJzCVg3vy3U9ysny6TfQfolIGPERc,10029
|
20
21
|
ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
|
21
22
|
ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
|
22
23
|
ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
|
@@ -55,7 +56,7 @@ ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
|
|
55
56
|
ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
|
56
57
|
ring/ml/ml_utils.py,sha256=1GXJfeoXbwCbRdYA2np3CbJpSupaw4eyf3quh9y4BO0,6462
|
57
58
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
|
-
ring/ml/ringnet.py,sha256=
|
59
|
+
ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
|
59
60
|
ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
|
60
61
|
ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
|
61
62
|
ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
@@ -84,7 +85,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
|
|
84
85
|
ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
|
85
86
|
ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
|
86
87
|
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
87
|
-
imt_ring-1.6.
|
88
|
-
imt_ring-1.6.
|
89
|
-
imt_ring-1.6.
|
90
|
-
imt_ring-1.6.
|
88
|
+
imt_ring-1.6.17.dist-info/METADATA,sha256=j0IKIyc6qAgz9059Z4-b46hi7qTn_ffI7oLw_OrD_Tk,3833
|
89
|
+
imt_ring-1.6.17.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
90
|
+
imt_ring-1.6.17.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
91
|
+
imt_ring-1.6.17.dist-info/RECORD,,
|
ring/__init__.py
CHANGED
@@ -121,6 +121,7 @@ _UNIQUE_ID = None
|
|
121
121
|
def setup(
|
122
122
|
rr_joint_kwargs: None | dict = dict(),
|
123
123
|
rr_imp_joint_kwargs: None | dict = dict(),
|
124
|
+
rsaddle_joint_kwargs: None | dict = dict(),
|
124
125
|
suntay_joint_kwargs: None | dict = None,
|
125
126
|
train_timing_start: None | float = None,
|
126
127
|
unique_id: None | str = None,
|
@@ -138,6 +139,9 @@ def setup(
|
|
138
139
|
if rr_imp_joint_kwargs is not None:
|
139
140
|
custom_joints.register_rr_imp_joint(**rr_imp_joint_kwargs)
|
140
141
|
|
142
|
+
if rsaddle_joint_kwargs is not None:
|
143
|
+
custom_joints.register_rsaddle_joint(**rsaddle_joint_kwargs)
|
144
|
+
|
141
145
|
if suntay_joint_kwargs is not None:
|
142
146
|
custom_joints.register_suntay(**suntay_joint_kwargs)
|
143
147
|
|
@@ -0,0 +1,40 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
|
3
|
+
import ring
|
4
|
+
from ring import maths
|
5
|
+
from ring.algorithms.jcalc import _draw_saddle
|
6
|
+
from ring.algorithms.jcalc import _p_control_term_rxyz
|
7
|
+
from ring.algorithms.jcalc import _qd_from_q_cartesian
|
8
|
+
|
9
|
+
|
10
|
+
def register_rsaddle_joint():
|
11
|
+
def _transform(q, params):
|
12
|
+
axes = params["joint_axes"]
|
13
|
+
rot1 = maths.quat_rot_axis(axes[0], q[0])
|
14
|
+
rot2 = maths.quat_rot_axis(axes[1], q[1])
|
15
|
+
rot = maths.quat_mul(rot2, rot1)
|
16
|
+
return ring.Transform.create(rot=rot)
|
17
|
+
|
18
|
+
def _motion_fn_gen(i: int):
|
19
|
+
def _motion_fn(params):
|
20
|
+
axis = params["joint_axes"][i]
|
21
|
+
return ring.base.Motion.create(ang=axis)
|
22
|
+
|
23
|
+
return _motion_fn
|
24
|
+
|
25
|
+
joint_model = ring.JointModel(
|
26
|
+
_transform,
|
27
|
+
motion=[_motion_fn_gen(i) for i in range(2)],
|
28
|
+
rcmg_draw_fn=_draw_saddle,
|
29
|
+
p_control_term=_p_control_term_rxyz,
|
30
|
+
qd_from_q=_qd_from_q_cartesian,
|
31
|
+
init_joint_params=_draw_random_joint_axes,
|
32
|
+
)
|
33
|
+
|
34
|
+
ring.register_new_joint_type("rsaddle", joint_model, 2, overwrite=True)
|
35
|
+
|
36
|
+
|
37
|
+
def _draw_random_joint_axes(key):
|
38
|
+
return dict(
|
39
|
+
joint_axes=maths.rotate(jnp.array([1.0, 0, 0]), maths.quat_random(key, (2,)))
|
40
|
+
)
|
@@ -160,6 +160,7 @@ _P_gains = {
|
|
160
160
|
"spherical": jnp.array(3 * [P_rot]),
|
161
161
|
"p3d": jnp.array(3 * [P_pos]),
|
162
162
|
"saddle": jnp.array([P_rot, P_rot]),
|
163
|
+
"rsaddle": jnp.array([P_rot, P_rot]),
|
163
164
|
"frozen": jnp.array([]),
|
164
165
|
"suntay": jnp.array([P_rot]),
|
165
166
|
}
|
@@ -182,13 +183,16 @@ class DynamicalSimulation:
|
|
182
183
|
|
183
184
|
@staticmethod
|
184
185
|
def assert_test_system(sys: base.System) -> None:
|
185
|
-
"test that system has no zero mass bodies and no joints without damping"
|
186
|
+
"test that system has no zero mass leaf bodies and no joints without damping"
|
186
187
|
|
187
188
|
def f(_, __, n, m, d):
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
189
|
+
is_leaf_body = len(sys.children(n)) == 0
|
190
|
+
if is_leaf_body:
|
191
|
+
assert d.size == 0 or m > 0, (
|
192
|
+
"Dynamic simulation is set to `True` which requires masses >= 0, "
|
193
|
+
f"but found body `{n}` with mass={float(m[0])}. This can lead to "
|
194
|
+
"NaNs."
|
195
|
+
)
|
192
196
|
|
193
197
|
assert d.size == 0 or all(d > 0.0), (
|
194
198
|
"Dynamic simulation is set to `True` which requires dampings > 0, "
|
ring/algorithms/jcalc.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
from collections import defaultdict
|
1
2
|
from dataclasses import asdict
|
2
3
|
from dataclasses import dataclass
|
3
4
|
from dataclasses import field
|
@@ -87,11 +88,39 @@ class MotionConfig:
|
|
87
88
|
assert nomotion_config.is_feasible()
|
88
89
|
return nomotion_config
|
89
90
|
|
91
|
+
@staticmethod
|
92
|
+
def overwrite_for_joint_type(joint_type: str, **changes) -> None:
|
93
|
+
"""Changes values of the `MotionConfig` used by the draw_fn for only a specific
|
94
|
+
joint.
|
95
|
+
"""
|
96
|
+
previous_changes = _overwrite_for_joint_type_changes[joint_type]
|
97
|
+
for change in changes:
|
98
|
+
assert change not in previous_changes, f"For jointtype={joint_type} you "
|
99
|
+
f"previously changed the value={change}. You can't change it again, this "
|
100
|
+
"is not supported."
|
101
|
+
previous_changes.update(changes)
|
102
|
+
|
103
|
+
jm = get_joint_model(joint_type)
|
104
|
+
|
105
|
+
def draw_fn(config, *args):
|
106
|
+
return jm.rcmg_draw_fn(replace(config, **changes), *args)
|
107
|
+
|
108
|
+
register_new_joint_type(
|
109
|
+
joint_type,
|
110
|
+
replace(jm, rcmg_draw_fn=draw_fn),
|
111
|
+
base.Q_WIDTHS[joint_type],
|
112
|
+
base.QD_WIDTHS[joint_type],
|
113
|
+
overwrite=True,
|
114
|
+
)
|
115
|
+
|
90
116
|
@staticmethod
|
91
117
|
def from_register(name: str) -> "MotionConfig":
|
92
118
|
return _registered_motion_configs[name]
|
93
119
|
|
94
120
|
|
121
|
+
_overwrite_for_joint_type_changes: dict[str, dict] = defaultdict(lambda: dict())
|
122
|
+
|
123
|
+
|
95
124
|
_registered_motion_configs = {
|
96
125
|
"hinUndHer": MotionConfig(
|
97
126
|
t_min=0.3,
|
@@ -222,7 +251,9 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
|
|
222
251
|
def inside_box_checks(x_min, x_max, x0_min, x0_max) -> bool:
|
223
252
|
return (x0_min >= x_min) and (x0_max <= x_max)
|
224
253
|
|
225
|
-
cond2 = inside_box_checks(
|
254
|
+
cond2 = inside_box_checks(
|
255
|
+
_to_float(c.pos_min, 0.0), _to_float(c.pos_max, 0.0), c.pos0_min, c.pos0_max
|
256
|
+
)
|
226
257
|
|
227
258
|
return cond1 and cond2
|
228
259
|
|
ring/ml/ringnet.py
CHANGED
@@ -87,6 +87,7 @@ def make_ring(
|
|
87
87
|
link_output_normalize: bool = True,
|
88
88
|
link_output_transform: Optional[Callable] = None,
|
89
89
|
layernorm: bool = True,
|
90
|
+
layernorm_trainable: bool = True,
|
90
91
|
) -> SimpleNamespace:
|
91
92
|
|
92
93
|
if link_output_normalize:
|
@@ -104,7 +105,11 @@ def make_ring(
|
|
104
105
|
)
|
105
106
|
|
106
107
|
inner_cell = StackedRNNCell(
|
107
|
-
celltype,
|
108
|
+
celltype,
|
109
|
+
hidden_state_dim,
|
110
|
+
stack_rnn_cells,
|
111
|
+
layernorm=layernorm,
|
112
|
+
layernorm_trainable=layernorm_trainable,
|
108
113
|
)
|
109
114
|
send_output = hk.nets.MLP([hidden_state_dim, link_output_dim])
|
110
115
|
state = hk.get_state(
|
@@ -143,6 +148,7 @@ class StackedRNNCell(hk.Module):
|
|
143
148
|
hidden_state_dim,
|
144
149
|
stacks: int,
|
145
150
|
layernorm: bool = False,
|
151
|
+
layernorm_trainable: bool = True,
|
146
152
|
name: str | None = None,
|
147
153
|
):
|
148
154
|
super().__init__(name)
|
@@ -150,6 +156,7 @@ class StackedRNNCell(hk.Module):
|
|
150
156
|
|
151
157
|
self.cells = [cell(hidden_state_dim) for _ in range(stacks)]
|
152
158
|
self.layernorm = layernorm
|
159
|
+
self.layernorm_trainable = layernorm_trainable
|
153
160
|
|
154
161
|
def __call__(self, x, state):
|
155
162
|
output = x
|
@@ -159,7 +166,9 @@ class StackedRNNCell(hk.Module):
|
|
159
166
|
next_state.append(next_state_i)
|
160
167
|
|
161
168
|
if self.layernorm:
|
162
|
-
output = hk.LayerNorm(
|
169
|
+
output = hk.LayerNorm(
|
170
|
+
-1, self.layernorm_trainable, self.layernorm_trainable
|
171
|
+
)(output)
|
163
172
|
|
164
173
|
return output, jnp.stack(next_state)
|
165
174
|
|
File without changes
|