imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,222 @@
|
|
1
|
+
import warnings
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
from ring import base
|
6
|
+
from ring import io
|
7
|
+
import tree_utils
|
8
|
+
|
9
|
+
|
10
|
+
def imu_reference_link_name(imu_link_name: str) -> str:
    """Return the name used for the reference link of the IMU `imu_link_name`.

    The reference name is the IMU link name prefixed with a single underscore,
    e.g. ``"imu1" -> "_imu1"``.
    """
    prefix = "_"
    return prefix + imu_link_name
|
12
|
+
|
13
|
+
|
14
|
+
def unactuated_subsystem(sys) -> list[str]:
    """Return the reference-link name (``"_" + imu``) of every IMU in `sys`.

    The names are produced by `imu_reference_link_name`, in the order returned
    by `sys.findall_imus()`. Presumably these are the links that should not be
    driven by the motion generator (per the function name) — confirm at call
    sites.
    """
    return list(map(imu_reference_link_name, sys.findall_imus()))
|
16
|
+
|
17
|
+
|
18
|
+
def _subsystem_factory(imu_name: str, pos_min_max: float) -> base.System:
    """Build a single-body system whose body carries a springy `p3d` joint.

    The body is named `imu_name` and is attached to the world with a `p3d`
    joint. It gets `spring_stiff="50 50 50"` and `damping="5 5 5"`. If
    `pos_min_max` is non-zero, `pos_min` / `pos_max` attributes of
    `-pos_min_max` / `pos_min_max` are added for each of the three axes.

    Args:
        imu_name: Name of the body in the returned system.
        pos_min_max: Non-negative bound; `0.0` omits `pos_min`/`pos_max`.

    Returns:
        The system parsed from the generated XML string.
    """
    assert pos_min_max >= 0
    b = pos_min_max
    bounds = (
        f'pos_min="-{b} -{b} -{b}" pos_max="{b} {b} {b}"' if b != 0.0 else ""
    )
    spring = 'spring_stiff="50 50 50"'
    damper = 'damping="5 5 5"'
    xml = f"""
<x_xy>
    <worldbody>
        <body name="{imu_name}" joint="p3d" {bounds} {spring} {damper}/>
    </worldbody>
</x_xy>
"""  # noqa: E501
    return io.load_sys_from_str(xml)
|
32
|
+
|
33
|
+
|
34
|
+
def inject_subsystems(
    sys: base.System,
    pos_min_max: float = 0.0,
    **kwargs,
) -> base.System:
    """Give every IMU of `sys` a compliant two-joint attachment.

    For each IMU link `imu` returned by `sys.findall_imus()`:

    1. its (frozen) joint is unfrozen into a `spherical` joint, with default
       spring stiffness 0.3 and damping 0.03 on each of its DoFs;
    2. the link is renamed to `imu_reference_link_name(imu)` (i.e. `_imu`);
    3. the one-body `p3d` system `_subsystem_factory(imu, pos_min_max)`, whose
       body is named `imu`, is injected at `_imu`.

    Geoms whose link index belonged to an IMU in the original system are
    re-attached to the link that now carries that IMU's name (the injected
    body). Afterwards, all `sys.links.joint_params` are reset to zero, and a
    `UserWarning` is emitted once. Finally, the system is round-tripped through
    its XML string representation.

    Args:
        sys: System whose IMUs receive the injected subsystem.
        pos_min_max: Forwarded to `_subsystem_factory`; `0.0` adds no position
            bounds to the `p3d` joint.
        **kwargs: Accepted but unused.

    Returns:
        The modified system.
    """
    imus = sys.findall_imus()
    # Indices refer to the *original* system. They are used below to find the
    # geoms that belonged to an IMU link before it was renamed.
    imu_idx_to_name_map = {sys.name_to_idx(imu): imu for imu in imus}

    spher_stiffness = jnp.ones((3,)) * 0.3
    spher_damping = spher_stiffness * 0.1

    for imu in imus:
        sys = sys.unfreeze(imu, "spherical")
        # The frozen joint had no DoFs, so writing the defaults here cannot
        # overwrite any previously set values.
        dofs = sys.idx_map("d")[imu]
        sys = sys.replace(
            link_spring_stiffness=sys.link_spring_stiffness.at[dofs].set(
                spher_stiffness
            ),
            link_damping=sys.link_damping.at[dofs].set(spher_damping),
        )

        ref_name = imu_reference_link_name(imu)
        sys = sys.change_link_name(imu, ref_name)
        sys = sys.inject_system(_subsystem_factory(imu, pos_min_max), ref_name)

    # Point each former-IMU geom at the link that now carries the IMU's name.
    reattached_geoms = [
        geom.replace(link_idx=sys.name_to_idx(imu_idx_to_name_map[geom.link_idx]))
        if geom.link_idx in imu_idx_to_name_map
        else geom
        for geom in sys.geoms
    ]
    sys = sys.replace(geoms=reattached_geoms)

    # TODO investigate whether this parse is needed; it is meant to recompute
    # the inertia matrices after the geoms were re-attached.
    sys = sys.parse()

    # joint_params cannot be preserved through the injection anyway. Zero them
    # all and warn once here, rather than triggering many warnings later.
    warnings.warn(
        "`sys.links.joint_params` has been set to zero, this might lead to "
        "unexpected behaviour unless you use `randomize_joint_params`"
    )
    zeroed_joint_params = tree_utils.tree_zeros_like(sys.links.joint_params)
    sys = sys.replace(links=sys.links.replace(joint_params=zeroed_joint_params))

    # Round-trip through XML. This fixes the issue that injected links would
    # otherwise stay appended at the end of the link list.
    return io.load_sys_from_str(io.save_sys_to_str(sys))
|
86
|
+
|
87
|
+
|
88
|
+
# Log-uniform sampling ranges used by `setup_fn_randomize_damping_stiffness_factory`
# for non-rigid IMUs.
# Spring stiffness range of the spherical joint of the `_imu` reference link.
_STIF_MIN_SPH = 0.2
_STIF_MAX_SPH = 10.0
# Spring stiffness range of the p3d joint of the injected `imu` body.
_STIF_MIN_P3D = 25.0
_STIF_MAX_P3D = 1_000.0
# Damping is sampled as a factor of stiffness:
#     damping = factor * stiffness, factor ~ log-uniform[_DAMP_MIN, _DAMP_MAX]
_DAMP_MIN = 0.05
_DAMP_MAX = 0.5
|
95
|
+
|
96
|
+
|
97
|
+
def _log_uniform(key, shape, minval, maxval):
    """Sample `shape` values log-uniformly distributed between `minval` and `maxval`.

    Samples are drawn uniformly in log-space on ``[log(minval), log(maxval))``
    and then exponentiated. The results lie in ``[minval, maxval)`` up to
    floating-point rounding.

    Args:
        key: JAX PRNG key.
        shape: Output shape.
        minval: Strictly positive lower bound.
        maxval: Upper bound, ``>= minval``.

    Raises:
        AssertionError: If ``not 0 < minval <= maxval``.
    """
    # `minval` must be strictly positive. With minval == 0, log(0) = -inf and
    # `jax.random.uniform(-inf, b)` evaluates `u * inf + (-inf)`, giving NaN
    # samples.
    assert 0 < minval <= maxval
    log_min, log_max = jnp.log(minval), jnp.log(maxval)
    return jnp.exp(jax.random.uniform(key, shape, minval=log_min, maxval=log_max))
|
101
|
+
|
102
|
+
|
103
|
+
def setup_fn_randomize_damping_stiffness_factory(
    prob_rigid: float,
    all_imus_either_rigid_or_flex: bool,
    imus_surely_rigid: list[str],
):
    """Build a setup function that randomizes IMU attachment stiffness/damping.

    The returned `setup_fn(key, sys) -> sys` works on each IMU `imu`. It writes
    6 stiffness and 6 damping values into the DoFs of the spherical `_imu`
    joint (first 3) followed by those of the p3d `imu` joint (last 3):

    - rigid: stiffness 200 (spherical) / 2e4 (p3d), damping = 0.2 * stiffness;
    - non-rigid: stiffness log-uniform in `[_STIF_MIN_SPH, _STIF_MAX_SPH]` /
      `[_STIF_MIN_P3D, _STIF_MAX_P3D]`, damping = stiffness times a factor
      log-uniform in `[_DAMP_MIN, _DAMP_MAX]`.

    Args:
        prob_rigid: Probability in ``[0, 1)`` that an IMU is treated as rigid.
            With `0`, every IMU is non-rigid.
        all_imus_either_rigid_or_flex: If True, one Bernoulli draw per call
            decides rigid/non-rigid for all IMUs not in `imus_surely_rigid`.
            Otherwise, each IMU draws independently.
        imus_surely_rigid: IMUs that are always rigid. Must be empty when
            `prob_rigid == 0`. Every name must be an IMU of the system passed
            to `setup_fn` (asserted there).
    """
    assert 0 <= prob_rigid <= 1
    assert prob_rigid != 1, "Use `imu_motion_artifacts`=False instead."
    if prob_rigid == 0.0:
        assert len(imus_surely_rigid) == 0

    def stif_damp_rigid(key):
        # `key` is unused. It is kept so both branches share the signature that
        # `jax.lax.cond` requires.
        stif = jnp.concatenate((200.0 * jnp.ones((3,)), 2e4 * jnp.ones((3,))))
        return stif, stif * 0.2

    def stif_damp_nonrigid(key):
        k_sph, k_p3d, k_damp = jax.random.split(key, 3)
        stif = jnp.concatenate(
            (
                _log_uniform(k_sph, (3,), _STIF_MIN_SPH, _STIF_MAX_SPH),
                _log_uniform(k_p3d, (3,), _STIF_MIN_P3D, _STIF_MAX_P3D),
            )
        )
        factor = _log_uniform(k_damp, (6,), _DAMP_MIN, _DAMP_MAX)
        return stif, stif * factor

    def setup_fn_randomize_damping_stiffness(key, sys: base.System) -> base.System:
        damping = sys.link_damping
        stiffness = sys.link_spring_stiffness
        qd_map = sys.idx_map("d")

        # Drawn up-front so it exists even when it is not redrawn per IMU,
        # i.e. when `all_imus_either_rigid_or_flex` is set.
        key, consume = jax.random.split(key)
        shared_is_rigid = jax.random.bernoulli(consume, prob_rigid)

        # Only used by the consistency assertion below.
        triggered_surely_rigid = []

        for imu in sys.findall_imus():
            # The `_imu` link has the spherical joint; `imu` has the p3d joint.
            dofs = jnp.r_[qd_map[imu_reference_link_name(imu)], qd_map[imu]]
            key, c1, c2 = jax.random.split(key, 3)

            if prob_rigid > 0 and imu in imus_surely_rigid:
                triggered_surely_rigid.append(imu)
                stif, damp = stif_damp_rigid(c2)
            elif prob_rigid > 0:
                is_rigid = (
                    shared_is_rigid
                    if all_imus_either_rigid_or_flex
                    else jax.random.bernoulli(c1, prob_rigid)
                )
                stif, damp = jax.lax.cond(
                    is_rigid, stif_damp_rigid, stif_damp_nonrigid, c2
                )
            else:
                stif, damp = stif_damp_nonrigid(c2)

            stiffness = stiffness.at[dofs].set(stif)
            damping = damping.at[dofs].set(damp)

        assert len(imus_surely_rigid) == len(triggered_surely_rigid)
        assert all(name in triggered_surely_rigid for name in imus_surely_rigid)

        return sys.replace(link_damping=damping, link_spring_stiffness=stiffness)

    return setup_fn_randomize_damping_stiffness
|
172
|
+
|
173
|
+
|
174
|
+
def _match_q_x_between_sys(
    sys_small: base.System,
    q_large: jax.Array,
    x_large: base.Transform,
    sys_large: base.System,
    q_large_skip: list[str],
) -> tuple[jax.Array, base.Transform]:
    """Select the parts of `q_large` / `x_large` that belong to `sys_small`.

    Every link of `sys_small` is looked up by name in `sys_large`:

    - its transform is gathered from `x_large` along axis 1 (the link axis);
    - its columns of `q_large` are gathered, unless the link name is in
      `q_large_skip`.

    Args:
        sys_small: System whose links select the output.
        q_large: Joint positions of `sys_large`, shape (T, sys_large.q_size()).
        x_large: Link transforms of `sys_large`; axis 1 has length
            `sys_large.num_links()`.
        sys_large: System that produced `q_large` and `x_large`.
        q_large_skip: Link names whose q-columns are not copied (see below).

    Returns:
        `(q_small, x_small)`, ordered by the link scan order of `sys_small`.
        If every link was skipped, `q_small` has shape (T, 0).
    """
    assert q_large.ndim == 2
    assert q_large.shape[1] == sys_large.q_size()
    assert x_large.shape(1) == sys_large.num_links()

    x_small_indices = []
    q_small = []
    q_idx_map = sys_large.idx_map("q")

    def f(_, __, name: str):
        x_small_indices.append(sys_large.name_to_idx(name))
        # For the imu links, the joint type was changed from spherical to
        # frozen. Thus `q_idx_map` has slices of length 4 for them, while
        # `sys_small` has those imus with a frozen joint type (slices of
        # length 0), so skip them.
        if name in q_large_skip:
            return
        q_small.append(q_large[:, q_idx_map[name]])

    sys_small.scan(f, "l", sys_small.link_names)

    x_small = tree_utils.tree_indices(x_large, jnp.array(x_small_indices), axis=1)
    # `jnp.concatenate` rejects an empty list; return an empty (T, 0) block.
    if not q_small:
        return q_large[:, :0], x_small
    return jnp.concatenate(q_small, axis=1), x_small
|
203
|
+
|
204
|
+
|
205
|
+
class GeneratorTrafoHideInjectedBodies:
    """Generator transform that hides the bodies added by `inject_subsystems`.

    The wrapped generator must return ``((X, y), (key, q, x, sys_x))``. The
    transformed generator returns the same structure, with these changes:

    - the IMU bodies of `sys_x` are deleted;
    - each `_imu` reference link is renamed back to `imu` and given a
      `frozen` joint;
    - `q` and `x` are reduced to match that smaller system.
    """

    def __call__(self, gen):
        def _gen(*args):
            (X, y), (key, q, x, sys_x) = gen(*args)

            imu_names = sys_x.findall_imus()
            sys_small = sys_x.delete_system(imu_names)
            for imu in imu_names:
                sys_small = sys_small.change_link_name(
                    imu_reference_link_name(imu), imu
                ).change_joint_type(imu, "frozen")

            # Second axis of q/x is the link axis.
            q_small, x_small = _match_q_x_between_sys(
                sys_small, q, x, sys_x, q_large_skip=imu_names
            )
            return (X, y), (key, q_small, x_small, sys_small)

        return _gen
|
@@ -0,0 +1,182 @@
|
|
1
|
+
from types import SimpleNamespace
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
from flax import struct
|
5
|
+
import jax
|
6
|
+
import jax.numpy as jnp
|
7
|
+
from ring import base
|
8
|
+
from ring.algorithms import dynamics
|
9
|
+
from ring.algorithms import jcalc
|
10
|
+
|
11
|
+
|
12
|
+
@struct.dataclass
class PDControllerState:
    """Carry state of the PD controller returned by `_pd_control`.

    All dicts are keyed by link name.
    """

    # Time index into the reference trajectories; `apply` increments it.
    i: int
    # Reference joint positions per link; row `i` is used at step `i`.
    q_ref_as_dict: dict[str, jax.Array]
    # Reference joint velocities per link (only filled when D gains are given).
    qd_ref_as_dict: dict[str, jax.Array]
    # Proportional gains per link.
    P_gains: dict[str, jax.Array]
    # Derivative gains per link (only filled when D gains are given).
    D_gains: dict[str, jax.Array]
|
19
|
+
|
20
|
+
|
21
|
+
def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
    """Computes tau using a PD controller. Returns a pair of (init, apply) functions.

    NOTE: Gains around ~10_000 are good for spherical joints, everything else ~250-300
    works just fine. Damping should be about 2500 for spherical joints, and
    about 25 for everything else.

    Args:
        P: jax.Array of P gains. Shape: (sys_init.qd_size())
        D: jax.Array of D gains. Shape: (sys_init.qd_size()) where `sys_init` is the
            system that recorded the reference trajectory `q_ref`.
            If not given, then no D control is applied.

    Returns: A `SimpleNamespace` with two functions
        init: (sys, q_ref) -> controller_state
            `q_ref` has shape (T, sys.q_size()). Raises `NotImplementedError`
            if `D` is given and a joint type has no `JointModel.qd_from_q`.
        apply: (controller_state, sys, state) -> (controller_state, tau)
            `tau` has shape (sys.qd_size(),). Links of `sys` that are not part
            of the reference system receive zero torque. Raises
            `NotImplementedError` if a joint type has no
            `JointModel.p_control_term`.

    Example:
        >>> gains = jnp.array([250.0] * sys1.qd_size())
        >>> controller = _pd_control(gains, gains)
        >>> q_ref = rcmg(sys1)
        >>> cs = controller.init(sys1, q_ref)
        >>> for t in range(1000):
        >>>     cs, tau = controller.apply(cs, sys2, state)
        >>>     state = dynamics.step(sys2, state, tau)
    """

    def init(sys: base.System, q_ref: jax.Array) -> PDControllerState:
        assert sys.q_size() == q_ref.shape[1], f"q_ref.shape = {q_ref.shape}"
        assert sys.qd_size() == P.size
        if D is not None:
            # Previously a bare comparison whose result was discarded, so a
            # wrongly sized `D` was never detected.
            assert sys.qd_size() == D.size

        q_ref_as_dict = {}
        qd_ref_as_dict = {}
        P_as_dict = {}
        D_as_dict = {}

        def f(_, __, q_ref_link, name, typ, P_link, D_link):
            P_as_dict[name] = P_link
            # The scan slices `q_ref.T`; transpose back to (T, q_size_of_link).
            q_ref_link = q_ref_link.T
            q_ref_as_dict[name] = q_ref_link

            if D is not None:
                qd_from_q = jcalc.get_joint_model(typ).qd_from_q
                if qd_from_q is None:
                    raise NotImplementedError(
                        f"Please specify `JointModel.qd_from_q` for joint type `{typ}`"
                    )
                qd_ref_as_dict[name] = qd_from_q(q_ref_link, sys.dt)
                D_as_dict[name] = D_link

        sys.scan(
            f,
            "qlldd",
            q_ref.T,
            sys.link_names,
            sys.link_types,
            P,
            D if D is not None else jnp.zeros((sys.qd_size(),)),
        )
        return PDControllerState(0, q_ref_as_dict, qd_ref_as_dict, P_as_dict, D_as_dict)

    def apply(
        controller_state: PDControllerState, sys: base.System, state: base.State
    ) -> tuple[PDControllerState, jax.Array]:
        taus = jnp.zeros((sys.qd_size()))
        # Pick row `i` (the current time step) of every reference trajectory.
        q_ref, qd_ref = jax.tree_util.tree_map(
            lambda arr: jax.lax.dynamic_index_in_dim(
                arr, controller_state.i, keepdims=False
            ),
            (controller_state.q_ref_as_dict, controller_state.qd_ref_as_dict),
        )

        def f(_, idx_map, idx, name, typ, q_curr, qd_curr):
            nonlocal taus

            # Links unknown to the reference system keep a zero torque.
            if name not in controller_state.q_ref_as_dict:
                return

            p_control_term = jcalc.get_joint_model(typ).p_control_term
            if p_control_term is None:
                raise NotImplementedError(
                    f"Please specify `JointModel.p_control_term` for joint type `{typ}`"
                )
            P_term = p_control_term(q_curr, q_ref[name])
            tau = P_term * controller_state.P_gains[name]

            if name in controller_state.qd_ref_as_dict:
                D_term = (qd_ref[name] - qd_curr) * controller_state.D_gains[name]
                tau += D_term

            taus = taus.at[idx_map["d"](idx)].set(tau)

        sys.scan(
            f,
            "lllqd",
            list(range(sys.num_links())),
            sys.link_names,
            sys.link_types,
            state.q,
            state.qd,
        )

        return controller_state.replace(i=controller_state.i + 1), taus

    return SimpleNamespace(init=init, apply=apply)
|
128
|
+
|
129
|
+
|
130
|
+
def _unroll_dynamics_pd_control(
    sys: base.System,
    q_ref: jax.Array,
    P: jax.Array,
    D: Optional[jax.Array] = None,
    nograv: bool = False,
    sys_q_ref: Optional[base.System] = None,
    initial_sim_state_is_zeros: bool = False,
    clip_taus: Optional[float] = None,
):
    """Simulate `sys` while a PD controller tracks the trajectory `q_ref`.

    Args:
        sys: System that is simulated with `dynamics.step`.
        q_ref: Reference joint positions, 2D with time on axis 0.
        P: P gains, forwarded to `_pd_control`.
        D: Optional D gains, forwarded to `_pd_control`.
        nograv: If True, the gravity of `sys` is multiplied by zero.
        sys_q_ref: System that recorded `q_ref`; defaults to `sys`.
        initial_sim_state_is_zeros: If True, start from `base.State.create(sys)`.
            Otherwise, start from the first row of `q_ref`.
        clip_taus: If given (must be > 0), torques are clipped to
            ``[-clip_taus, clip_taus]``.

    Returns:
        The simulated states of all `q_ref.shape[0]` steps, stacked by
        `jax.lax.scan`.
    """
    assert q_ref.ndim == 2

    # Resolved before gravity is removed: only `sys` is modified below.
    sys_q_ref = sys if sys_q_ref is None else sys_q_ref

    if nograv:
        sys = sys.replace(gravity=sys.gravity * 0.0)

    init_state = (
        base.State.create(sys)
        if initial_sim_state_is_zeros
        else _initial_q_is_q_ref(sys, sys_q_ref, q_ref[0])
    )

    controller = _pd_control(P, D)
    init_controller_state = controller.init(sys_q_ref, q_ref)

    def _step(carry, _):
        sim_state, ctrl_state = carry
        ctrl_state, tau = controller.apply(ctrl_state, sys, sim_state)
        if clip_taus is not None:
            assert clip_taus > 0.0
            tau = jnp.clip(tau, -clip_taus, clip_taus)
        sim_state = dynamics.step(sys, sim_state, tau)
        return (sim_state, ctrl_state), sim_state

    _, states = jax.lax.scan(
        _step, (init_state, init_controller_state), None, length=q_ref.shape[0]
    )
    return states
|
168
|
+
|
169
|
+
|
170
|
+
def _initial_q_is_q_ref(sys: base.System, sys_q_ref: base.System, q_ref):
    """Create a state of `sys` whose `q` is copied, link by link, from `q_ref`.

    `q_ref` is a single q-vector laid out according to `sys_q_ref`. Each link
    of `sys_q_ref` has its entries written to the slot of the same-named link
    in `sys`. Entries of `sys` not covered by `sys_q_ref` keep their values
    from `base.State.create(sys)`.
    """
    # Start from the default state rather than zeros: an all-zero q is not
    # valid for quaternion-parameterized joints.
    q_init = base.State.create(sys).q
    q_slices = sys.idx_map("q")

    def _copy_link(_, __, link_name, q_ref_link):
        nonlocal q_init
        q_init = q_init.at[q_slices[link_name]].set(q_ref_link)

    sys_q_ref.scan(_copy_link, "lq", sys_q_ref.link_names, q_ref)

    return base.State.create(sys, q=q_init)
|
@@ -0,0 +1,119 @@
|
|
1
|
+
"""Randomization by modifying System and MotionConfig objects before building
|
2
|
+
generator."""
|
3
|
+
|
4
|
+
from dataclasses import replace
|
5
|
+
import itertools
|
6
|
+
from typing import Optional
|
7
|
+
import warnings
|
8
|
+
|
9
|
+
import jax.numpy as jnp
|
10
|
+
from ring import base
|
11
|
+
from ring.algorithms import jcalc
|
12
|
+
from ring.algorithms.generator import types
|
13
|
+
|
14
|
+
|
15
|
+
def _find_children(lam: list[int], body: int) -> list[int]:
    """Return every descendant of `body` in the parent array `lam`.

    `lam[i]` is the parent index of node `i`. Descendants are listed in
    depth-first pre-order; siblings are visited in increasing index order.
    `body` itself is not included.
    """

    def _descend(parent: int) -> list[int]:
        found = []
        for idx, p in enumerate(lam):
            if p == parent:
                found.append(idx)
                found.extend(_descend(idx))
        return found

    return _descend(body)
|
27
|
+
|
28
|
+
|
29
|
+
def _find_root_of_subsys_that_contains_body(sys: base.System, body: str) -> str:
    """Return the name of the root link whose subtree contains `body`.

    Roots are links whose parent is -1. `body` may itself be a root. If no root
    matches, the function returns None.
    """
    target = sys.name_to_idx(body)
    parents = sys.link_parents
    for root in (i for i, p in enumerate(parents) if p == -1):
        if target == root or target in _find_children(parents, root):
            return sys.idx_to_name(root)
    return None
|
35
|
+
|
36
|
+
|
37
|
+
def _assign_anchors_to_subsys(sys: base.System, anchors: list[str]) -> list[list[str]]:
    """Group `anchors` by the subsystem (root link and its subtree) they belong to.

    Returns one list per root link (parent == -1), in link order. Each list
    holds the anchors inside that subsystem, in the order given by `anchors`.
    If a subsystem contains none of the anchors, its root link name is used as
    its only anchor.
    """
    grouped = []
    for root, parent in enumerate(sys.link_parents):
        if parent != -1:
            continue
        members = {
            sys.idx_to_name(j)
            for j in [root] + _find_children(sys.link_parents, root)
        }
        picked = [name for name in anchors if name in members]
        grouped.append(picked or [sys.idx_to_name(root)])
    return grouped
|
50
|
+
|
51
|
+
|
52
|
+
def _morph_extract_subsys(sys: base.System, anchor: str):
    """Keep only the subsystem that contains `anchor` and re-anchor it there.

    All other bodies connected to the world (as reported by
    `sys.findall_bodies_to_world`) are deleted. The remainder is then passed
    to `morph_system(new_anchor=anchor)`.
    """
    own_root = _find_root_of_subsys_that_contains_body(sys, anchor)
    all_roots = set(sys.findall_bodies_to_world(names=True))
    pruned = sys.delete_system(list(all_roots - {own_root}))
    return pruned.morph_system(new_anchor=anchor)
|
57
|
+
|
58
|
+
|
59
|
+
def randomize_anchors(
    sys: base.System, anchors: Optional[list[str]] = None
) -> list[base.System]:
    """Create one system per combination of anchors across the subsystems.

    The candidate `anchors` default to `sys.findall_segments()` and are grouped
    per subsystem by `_assign_anchors_to_subsys`. For every element of the
    Cartesian product of these groups, each subsystem is extracted and
    re-anchored with `_morph_extract_subsys`. The first result receives all
    others through `inject_system`.
    """
    if anchors is None:
        anchors = sys.findall_segments()

    anchors_per_subsys = _assign_anchors_to_subsys(sys, anchors)

    def _build(combo):
        combined = _morph_extract_subsys(sys, combo[0])
        for other_anchor in combo[1:]:
            combined = combined.inject_system(_morph_extract_subsys(sys, other_anchor))
        return combined

    return [_build(combo) for combo in itertools.product(*anchors_per_subsys)]
|
75
|
+
|
76
|
+
|
77
|
+
# `randomize_hz` warns for any sampling rate (in Hz) below this value.
_WARN_HZ_Threshold: float = 40.0
|
78
|
+
|
79
|
+
|
80
|
+
def randomize_hz(
    sys: list[base.System],
    configs: list[jcalc.MotionConfig],
    sampling_rates: list[float],
) -> tuple[list[base.System], list[jcalc.MotionConfig]]:
    """Create every (system, config, sampling rate) combination.

    For each system `_sys`, config and rate `hz`, the output system gets
    `dt = 1 / hz`. The output config gets `T = (T_global / _sys.dt) * dt`.
    This keeps the number of time steps at `T_global / _sys.dt`, where
    `T_global` is the shared duration of all `configs`.

    Args:
        sys: Systems to re-sample.
        configs: Motion configs; all must have the same `T`.
        sampling_rates: Sampling rates in Hz. A warning is emitted for any
            rate below `_WARN_HZ_Threshold`.

    Returns:
        `(systems, configs)` lists of length
        `len(sys) * len(configs) * len(sampling_rates)`, in nested loop order
        system -> config -> rate.

    Raises:
        AssertionError: If the configs do not share a single `T`.
    """
    Ts = [c.T for c in configs]
    # Previously `assert len(set(Ts))`, which is true for any non-empty input,
    # so disagreeing durations went unnoticed.
    assert len(set(Ts)) == 1, f"Time length between configs does not agree {Ts}"
    T_global = Ts[0]

    for hz in sampling_rates:
        if hz < _WARN_HZ_Threshold:
            # Both parts are f-strings so that `{hz}` is actually interpolated.
            warnings.warn(
                f"The sampling rate {hz} is below the warning threshold of "
                f"{_WARN_HZ_Threshold}. This might lead to NaNs."
            )

    sys_out, configs_out = [], []
    for _sys in sys:
        for _config in configs:
            for hz in sampling_rates:
                dt = 1 / hz
                T = (T_global / _sys.dt) * dt

                sys_out.append(_sys.replace(dt=dt))
                configs_out.append(replace(_config, T=T))
    return sys_out, configs_out
|
106
|
+
|
107
|
+
|
108
|
+
def randomize_hz_finalize_fn_factory(finalize_fn_user: types.FINALIZE_FN | None):
    """Wrap a user finalize function so that the sampling interval is exported.

    The returned `finalize_fn(key, q, x, sys)` calls `finalize_fn_user` (if
    given) to obtain `(X, y)`; otherwise it starts from two empty dicts. It then
    sets `X["dt"]` to `sys.dt` as a float32 array of shape (1,).

    Raises:
        AssertionError: If `X` already contains a `"dt"` entry.
    """

    def finalize_fn(key, q, x, sys: base.System):
        X, y = (
            ({}, {})
            if finalize_fn_user is None
            else finalize_fn_user(key, q, x, sys)
        )

        assert "dt" not in X
        X["dt"] = jnp.array([sys.dt], dtype=jnp.float32)

        return X, y

    return finalize_fn
|