imt-ring 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.4.1.dist-info → imt_ring-1.5.1.dist-info}/METADATA +1 -1
- {imt_ring-1.4.1.dist-info → imt_ring-1.5.1.dist-info}/RECORD +19 -18
- {imt_ring-1.4.1.dist-info → imt_ring-1.5.1.dist-info}/WHEEL +1 -1
- ring/__init__.py +21 -10
- ring/algorithms/__init__.py +1 -11
- ring/algorithms/generator/__init__.py +2 -16
- ring/algorithms/generator/base.py +245 -276
- ring/algorithms/generator/batch.py +26 -109
- ring/algorithms/generator/finalize_fns.py +306 -0
- ring/algorithms/generator/motion_artifacts.py +31 -19
- ring/algorithms/generator/setup_fns.py +43 -0
- ring/algorithms/generator/types.py +3 -18
- ring/algorithms/jcalc.py +0 -9
- ring/rendering/mujoco_render.py +2 -1
- ring/utils/__init__.py +3 -4
- ring/utils/batchsize.py +12 -4
- ring/utils/utils.py +6 -0
- ring/algorithms/generator/transforms.py +0 -411
- {imt_ring-1.4.1.dist-info → imt_ring-1.5.1.dist-info}/top_level.txt +0 -0
- /ring/{algorithms/generator/randomize.py → utils/randomize_sys.py} +0 -0
ring/algorithms/jcalc.py
CHANGED
@@ -60,7 +60,6 @@ class MotionConfig:
|
|
60
60
|
pos0_max: float = 0.0
|
61
61
|
|
62
62
|
# cor (center of rotation) custom fields
|
63
|
-
cor: bool = False
|
64
63
|
cor_t_min: float = 0.2
|
65
64
|
cor_t_max: float | TimeDependentFloat = 2.0
|
66
65
|
cor_dpos_min: float | TimeDependentFloat = 0.00001
|
@@ -102,7 +101,6 @@ _registered_motion_configs = {
|
|
102
101
|
pos_min=-1.5,
|
103
102
|
pos_max=1.5,
|
104
103
|
randomized_interpolation_angle=True,
|
105
|
-
cor=True,
|
106
104
|
),
|
107
105
|
"langsam": MotionConfig(
|
108
106
|
t_min=0.2,
|
@@ -114,13 +112,11 @@ _registered_motion_configs = {
|
|
114
112
|
cdf_bins_max=3,
|
115
113
|
pos_min=-1.5,
|
116
114
|
pos_max=1.5,
|
117
|
-
cor=True,
|
118
115
|
),
|
119
116
|
"standard": MotionConfig(
|
120
117
|
randomized_interpolation_angle=True,
|
121
118
|
cdf_bins_min=1,
|
122
119
|
cdf_bins_max=5,
|
123
|
-
cor=True,
|
124
120
|
),
|
125
121
|
"expFast": MotionConfig(
|
126
122
|
t_min=0.4,
|
@@ -134,7 +130,6 @@ _registered_motion_configs = {
|
|
134
130
|
randomized_interpolation_angle=True,
|
135
131
|
cdf_bins_min=1,
|
136
132
|
cdf_bins_max=3,
|
137
|
-
cor=True,
|
138
133
|
),
|
139
134
|
"expSlow": MotionConfig(
|
140
135
|
t_min=0.75,
|
@@ -151,7 +146,6 @@ _registered_motion_configs = {
|
|
151
146
|
randomized_interpolation_angle=True,
|
152
147
|
cdf_bins_min=1,
|
153
148
|
cdf_bins_max=5,
|
154
|
-
cor=True,
|
155
149
|
),
|
156
150
|
"expFastNoSig": MotionConfig(
|
157
151
|
t_min=0.4,
|
@@ -164,7 +158,6 @@ _registered_motion_configs = {
|
|
164
158
|
randomized_interpolation_angle=True,
|
165
159
|
cdf_bins_min=1,
|
166
160
|
cdf_bins_max=3,
|
167
|
-
cor=True,
|
168
161
|
),
|
169
162
|
"expSlowNoSig": MotionConfig(
|
170
163
|
t_min=0.75,
|
@@ -180,7 +173,6 @@ _registered_motion_configs = {
|
|
180
173
|
randomized_interpolation_angle=True,
|
181
174
|
cdf_bins_min=1,
|
182
175
|
cdf_bins_max=3,
|
183
|
-
cor=True,
|
184
176
|
),
|
185
177
|
"verySlow": MotionConfig(
|
186
178
|
t_min=1.5,
|
@@ -196,7 +188,6 @@ _registered_motion_configs = {
|
|
196
188
|
randomized_interpolation_angle=True,
|
197
189
|
cdf_bins_min=1,
|
198
190
|
cdf_bins_max=3,
|
199
|
-
cor=True,
|
200
191
|
),
|
201
192
|
}
|
202
193
|
|
ring/rendering/mujoco_render.py
CHANGED
@@ -2,10 +2,12 @@ from typing import Optional, Sequence
|
|
2
2
|
|
3
3
|
import mujoco
|
4
4
|
import numpy as np
|
5
|
+
|
5
6
|
from ring import base
|
6
7
|
from ring import maths
|
7
8
|
|
8
9
|
_skybox = """<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6 .8" rgb2="0 0 0" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
|
10
|
+
_skybox_white = """<texture name="skybox" type="skybox" builtin="gradient" rgb1="1 1 1" rgb2="1 1 1" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
|
9
11
|
_floor = """<geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane" mass="0"/>""" # noqa: E501
|
10
12
|
|
11
13
|
|
@@ -90,7 +92,6 @@ def _build_model_of_geoms(
|
|
90
92
|
<camera pos="0 -1 1" name="target" mode="targetbodycom" target="{targetbody}"/>
|
91
93
|
<camera pos="0 -3 3" name="targetfar" mode="targetbodycom" target="{targetbody}"/>
|
92
94
|
<camera pos="0 -5 5" name="targetFar" mode="targetbodycom" target="{targetbody}"/>
|
93
|
-
<light pos="0 0 4" dir="0 0 -1"/>
|
94
95
|
{_floor if floor else ''}
|
95
96
|
{inside_worldbody_cameras}
|
96
97
|
{inside_worldbody_lights}
|
ring/utils/__init__.py
CHANGED
@@ -1,17 +1,16 @@
|
|
1
|
+
from . import hdf5
|
2
|
+
from . import randomize_sys
|
1
3
|
from .batchsize import batchsize_thresholds
|
2
4
|
from .batchsize import distribute_batchsize
|
3
5
|
from .batchsize import expand_batchsize
|
4
6
|
from .batchsize import merge_batchsize
|
5
7
|
from .colab import setup_colab_env
|
6
|
-
from .hdf5 import load as hdf5_load
|
7
|
-
from .hdf5 import load_from_multiple as hdf5_load_from_multiple
|
8
|
-
from .hdf5 import load_length as hdf5_load_length
|
9
|
-
from .hdf5 import save as hdf5_save
|
10
8
|
from .normalizer import make_normalizer_from_generator
|
11
9
|
from .normalizer import Normalizer
|
12
10
|
from .path import parse_path
|
13
11
|
from .utils import dict_to_nested
|
14
12
|
from .utils import dict_union
|
13
|
+
from .utils import gcd
|
15
14
|
from .utils import import_lib
|
16
15
|
from .utils import pickle_load
|
17
16
|
from .utils import pickle_save
|
ring/utils/batchsize.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1
|
-
from typing import Tuple
|
1
|
+
from typing import Tuple
|
2
2
|
|
3
3
|
import jax
|
4
|
-
|
5
|
-
PyTree = TypeVar("PyTree")
|
4
|
+
from tree_utils import PyTree
|
6
5
|
|
7
6
|
|
8
7
|
def batchsize_thresholds():
|
@@ -36,7 +35,16 @@ def distribute_batchsize(batchsize: int) -> Tuple[int, int]:
|
|
36
35
|
return int(batchsize / vmap_size), vmap_size
|
37
36
|
|
38
37
|
|
39
|
-
def merge_batchsize(
|
38
|
+
def merge_batchsize(
|
39
|
+
tree: PyTree, pmap_size: int, vmap_size: int, third_dim_also: bool = False
|
40
|
+
) -> PyTree:
|
41
|
+
if third_dim_also:
|
42
|
+
return jax.tree_map(
|
43
|
+
lambda arr: arr.reshape(
|
44
|
+
(pmap_size * vmap_size * arr.shape[2],) + arr.shape[3:]
|
45
|
+
),
|
46
|
+
tree,
|
47
|
+
)
|
40
48
|
return jax.tree_map(
|
41
49
|
lambda arr: arr.reshape((pmap_size * vmap_size,) + arr.shape[2:]), tree
|
42
50
|
)
|
ring/utils/utils.py
CHANGED
@@ -1,411 +0,0 @@
|
|
1
|
-
from typing import Optional
|
2
|
-
import warnings
|
3
|
-
|
4
|
-
import jax
|
5
|
-
import jax.numpy as jnp
|
6
|
-
import numpy as np
|
7
|
-
import tree_utils
|
8
|
-
|
9
|
-
from ring import base
|
10
|
-
from ring import maths
|
11
|
-
from ring import utils
|
12
|
-
from ring.algorithms import sensors
|
13
|
-
from ring.algorithms.generator import pd_control
|
14
|
-
from ring.algorithms.generator import types
|
15
|
-
|
16
|
-
|
17
|
-
class GeneratorTrafoLambda(types.GeneratorTrafo):
|
18
|
-
def __init__(self, f, input: bool = False):
|
19
|
-
self.f = f
|
20
|
-
self.input = input
|
21
|
-
|
22
|
-
def __call__(self, gen):
|
23
|
-
if self.input:
|
24
|
-
|
25
|
-
def _gen(*args):
|
26
|
-
return gen(*self.f(*args))
|
27
|
-
|
28
|
-
else:
|
29
|
-
|
30
|
-
def _gen(*args):
|
31
|
-
return self.f(gen(*args))
|
32
|
-
|
33
|
-
return _gen
|
34
|
-
|
35
|
-
|
36
|
-
def _rename_links(d: dict[str, dict], names: list[str]) -> dict[int, dict]:
|
37
|
-
for key in list(d.keys()):
|
38
|
-
if key in names:
|
39
|
-
d[str(names.index(key))] = d.pop(key)
|
40
|
-
else:
|
41
|
-
warnings.warn(
|
42
|
-
f"The key `{key}` was not found in names `{names}`. "
|
43
|
-
"It will not be renamed."
|
44
|
-
)
|
45
|
-
|
46
|
-
return d
|
47
|
-
|
48
|
-
|
49
|
-
class GeneratorTrafoNames2Indices(types.GeneratorTrafo):
|
50
|
-
def __init__(self, sys_noimu: base.System) -> None:
|
51
|
-
self.sys_noimu = sys_noimu
|
52
|
-
|
53
|
-
def __call__(self, gen: types.GeneratorWithInputOutputExtras):
|
54
|
-
def _gen(*args):
|
55
|
-
(X, y), extras = gen(*args)
|
56
|
-
X = _rename_links(X, self.sys_noimu.link_names)
|
57
|
-
y = _rename_links(y, self.sys_noimu.link_names)
|
58
|
-
return (X, y), extras
|
59
|
-
|
60
|
-
return _gen
|
61
|
-
|
62
|
-
|
63
|
-
class GeneratorTrafoSetupFn(types.GeneratorTrafo):
|
64
|
-
def __init__(self, setup_fn: types.SETUP_FN):
|
65
|
-
self.setup_fn = setup_fn
|
66
|
-
|
67
|
-
def __call__(
|
68
|
-
self,
|
69
|
-
gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
|
70
|
-
) -> types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras:
|
71
|
-
def _gen(key, sys):
|
72
|
-
key, consume = jax.random.split(key)
|
73
|
-
sys = self.setup_fn(consume, sys)
|
74
|
-
return gen(key, sys)
|
75
|
-
|
76
|
-
return _gen
|
77
|
-
|
78
|
-
|
79
|
-
class GeneratorTrafoFinalizeFn(types.GeneratorTrafo):
|
80
|
-
def __init__(self, finalize_fn: types.FINALIZE_FN):
|
81
|
-
self.finalize_fn = finalize_fn
|
82
|
-
|
83
|
-
def __call__(
|
84
|
-
self,
|
85
|
-
gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
|
86
|
-
) -> types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras:
|
87
|
-
def _gen(*args):
|
88
|
-
(X, y), (key, *extras) = gen(*args)
|
89
|
-
# make sure we aren't overwriting anything
|
90
|
-
assert len(X) == len(y) == 0, f"X.keys={X.keys()}, y.keys={y.keys()}"
|
91
|
-
key, consume = jax.random.split(key)
|
92
|
-
Xy = self.finalize_fn(consume, *extras)
|
93
|
-
return Xy, tuple([key] + extras)
|
94
|
-
|
95
|
-
return _gen
|
96
|
-
|
97
|
-
|
98
|
-
class GeneratorTrafoRandomizePositions(types.GeneratorTrafo):
|
99
|
-
def __call__(
|
100
|
-
self,
|
101
|
-
gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
|
102
|
-
) -> types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras:
|
103
|
-
return GeneratorTrafoSetupFn(_setup_fn_randomize_positions)(gen)
|
104
|
-
|
105
|
-
|
106
|
-
def _setup_fn_randomize_positions(key: jax.Array, sys: base.System) -> base.System:
|
107
|
-
ts = sys.links.transform1
|
108
|
-
|
109
|
-
for i in range(sys.num_links()):
|
110
|
-
link = sys.links[i]
|
111
|
-
key, new_pos = _draw_pos_uniform(key, link.pos_min, link.pos_max)
|
112
|
-
ts = ts.index_set(i, ts[i].replace(pos=new_pos))
|
113
|
-
|
114
|
-
return sys.replace(links=sys.links.replace(transform1=ts))
|
115
|
-
|
116
|
-
|
117
|
-
def _draw_pos_uniform(key, pos_min, pos_max):
|
118
|
-
key, c1, c2, c3 = jax.random.split(key, num=4)
|
119
|
-
pos = jnp.array(
|
120
|
-
[
|
121
|
-
jax.random.uniform(c1, minval=pos_min[0], maxval=pos_max[0]),
|
122
|
-
jax.random.uniform(c2, minval=pos_min[1], maxval=pos_max[1]),
|
123
|
-
jax.random.uniform(c3, minval=pos_min[2], maxval=pos_max[2]),
|
124
|
-
]
|
125
|
-
)
|
126
|
-
return key, pos
|
127
|
-
|
128
|
-
|
129
|
-
class GeneratorTrafoRandomizeTransform1Rot(types.GeneratorTrafo):
|
130
|
-
def __init__(self, maxval_deg: float):
|
131
|
-
self.maxval = jnp.deg2rad(maxval_deg)
|
132
|
-
|
133
|
-
def __call__(self, gen):
|
134
|
-
setup_fn = lambda key, sys: _setup_fn_randomize_transform1_rot(
|
135
|
-
key, sys, self.maxval
|
136
|
-
)
|
137
|
-
return GeneratorTrafoSetupFn(setup_fn)(gen)
|
138
|
-
|
139
|
-
|
140
|
-
def _setup_fn_randomize_transform1_rot(
|
141
|
-
key, sys, maxval: float, not_imus: bool = True
|
142
|
-
) -> base.System:
|
143
|
-
new_transform1 = sys.links.transform1.replace(
|
144
|
-
rot=maths.quat_random(key, (sys.num_links(),), maxval=maxval)
|
145
|
-
)
|
146
|
-
if not_imus:
|
147
|
-
imus = [name for name in sys.link_names if name[:3] == "imu"]
|
148
|
-
new_rot = new_transform1.rot
|
149
|
-
for imu in imus:
|
150
|
-
new_rot = new_rot.at[sys.name_to_idx(imu)].set(jnp.array([1.0, 0, 0, 0]))
|
151
|
-
new_transform1 = new_transform1.replace(rot=new_rot)
|
152
|
-
return sys.replace(links=sys.links.replace(transform1=new_transform1))
|
153
|
-
|
154
|
-
|
155
|
-
class GeneratorTrafoJointAxisSensor(types.GeneratorTrafo):
|
156
|
-
def __init__(self, sys: base.System, **kwargs):
|
157
|
-
self.sys = sys
|
158
|
-
self.kwargs = kwargs
|
159
|
-
|
160
|
-
def __call__(self, gen):
|
161
|
-
def _gen(*args):
|
162
|
-
(X, y), (key, q, x, sys_x) = gen(*args)
|
163
|
-
key, consume = jax.random.split(key)
|
164
|
-
X_joint_axes = sensors.joint_axes(
|
165
|
-
self.sys, x, sys_x, key=consume, **self.kwargs
|
166
|
-
)
|
167
|
-
X = utils.dict_union(X, X_joint_axes)
|
168
|
-
return (X, y), (key, q, x, sys_x)
|
169
|
-
|
170
|
-
return _gen
|
171
|
-
|
172
|
-
|
173
|
-
class GeneratorTrafoRelPose(types.GeneratorTrafo):
|
174
|
-
def __init__(self, sys: base.System):
|
175
|
-
self.sys = sys
|
176
|
-
|
177
|
-
def __call__(self, gen):
|
178
|
-
def _gen(*args):
|
179
|
-
(X, y), (key, q, x, sys_x) = gen(*args)
|
180
|
-
y_relpose = sensors.rel_pose(self.sys, x, sys_x)
|
181
|
-
y = utils.dict_union(y, y_relpose)
|
182
|
-
return (X, y), (key, q, x, sys_x)
|
183
|
-
|
184
|
-
return _gen
|
185
|
-
|
186
|
-
|
187
|
-
class GeneratorTrafoRootIncl(types.GeneratorTrafo):
|
188
|
-
def __init__(self, sys: base.System):
|
189
|
-
self.sys = sys
|
190
|
-
|
191
|
-
def __call__(self, gen):
|
192
|
-
def _gen(*args):
|
193
|
-
(X, y), (key, q, x, sys_x) = gen(*args)
|
194
|
-
y_root_incl = sensors.root_incl(self.sys, x, sys_x)
|
195
|
-
y = utils.dict_union(y, y_root_incl)
|
196
|
-
return (X, y), (key, q, x, sys_x)
|
197
|
-
|
198
|
-
return _gen
|
199
|
-
|
200
|
-
|
201
|
-
_default_imu_kwargs = dict(
|
202
|
-
noisy=True,
|
203
|
-
low_pass_filter_pos_f_cutoff=13.5,
|
204
|
-
low_pass_filter_rot_cutoff=16.0,
|
205
|
-
)
|
206
|
-
|
207
|
-
|
208
|
-
class GeneratorTrafoIMU(types.GeneratorTrafo):
|
209
|
-
def __init__(self, **imu_kwargs):
|
210
|
-
self.kwargs = _default_imu_kwargs.copy()
|
211
|
-
self.kwargs.update(imu_kwargs)
|
212
|
-
|
213
|
-
def __call__(
|
214
|
-
self,
|
215
|
-
gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
|
216
|
-
):
|
217
|
-
def _gen(*args):
|
218
|
-
(X, y), (key, q, x, sys) = gen(*args)
|
219
|
-
key, consume = jax.random.split(key)
|
220
|
-
X_imu = _imu_data(consume, x, sys, **self.kwargs)
|
221
|
-
X = utils.dict_union(X, X_imu)
|
222
|
-
return (X, y), (key, q, x, sys)
|
223
|
-
|
224
|
-
return _gen
|
225
|
-
|
226
|
-
|
227
|
-
def _imu_data(key, xs, sys_xs, **kwargs) -> dict:
|
228
|
-
sys_noimu, imu_attachment = sys_xs.make_sys_noimu()
|
229
|
-
inv_imu_attachment = {val: key for key, val in imu_attachment.items()}
|
230
|
-
X = {}
|
231
|
-
N = xs.shape()
|
232
|
-
for segment in sys_noimu.link_names:
|
233
|
-
if segment in inv_imu_attachment:
|
234
|
-
imu = inv_imu_attachment[segment]
|
235
|
-
key, consume = jax.random.split(key)
|
236
|
-
imu_measurements = sensors.imu(
|
237
|
-
xs=xs.take(sys_xs.name_to_idx(imu), 1),
|
238
|
-
gravity=sys_xs.gravity,
|
239
|
-
dt=sys_xs.dt,
|
240
|
-
key=consume,
|
241
|
-
**kwargs,
|
242
|
-
)
|
243
|
-
else:
|
244
|
-
imu_measurements = {
|
245
|
-
"acc": jnp.zeros(
|
246
|
-
(
|
247
|
-
N,
|
248
|
-
3,
|
249
|
-
)
|
250
|
-
),
|
251
|
-
"gyr": jnp.zeros(
|
252
|
-
(
|
253
|
-
N,
|
254
|
-
3,
|
255
|
-
)
|
256
|
-
),
|
257
|
-
}
|
258
|
-
X[segment] = imu_measurements
|
259
|
-
return X
|
260
|
-
|
261
|
-
|
262
|
-
P_rot, P_pos = 100.0, 250.0
|
263
|
-
_P_gains = {
|
264
|
-
"free": jnp.array(3 * [P_rot] + 3 * [P_pos]),
|
265
|
-
"free_2d": jnp.array(1 * [P_rot] + 2 * [P_pos]),
|
266
|
-
"px": jnp.array([P_pos]),
|
267
|
-
"py": jnp.array([P_pos]),
|
268
|
-
"pz": jnp.array([P_pos]),
|
269
|
-
"rx": jnp.array([P_rot]),
|
270
|
-
"ry": jnp.array([P_rot]),
|
271
|
-
"rz": jnp.array([P_rot]),
|
272
|
-
"rr": jnp.array([P_rot]),
|
273
|
-
# primary, residual
|
274
|
-
"rr_imp": jnp.array([P_rot, P_rot]),
|
275
|
-
"cor": jnp.array(3 * [P_rot] + 6 * [P_pos]),
|
276
|
-
"spherical": jnp.array(3 * [P_rot]),
|
277
|
-
"p3d": jnp.array(3 * [P_pos]),
|
278
|
-
"saddle": jnp.array([P_rot, P_rot]),
|
279
|
-
"frozen": jnp.array([]),
|
280
|
-
"suntay": jnp.array([P_rot]),
|
281
|
-
}
|
282
|
-
|
283
|
-
|
284
|
-
class GeneratorTrafoDynamicalSimulation(types.GeneratorTrafo):
|
285
|
-
def __init__(
|
286
|
-
self,
|
287
|
-
custom_P_gains: dict[str, jax.Array] = dict(),
|
288
|
-
unactuated_subsystems: list[str] = [],
|
289
|
-
return_q_ref: bool = False,
|
290
|
-
overwrite_q_ref: Optional[tuple[jax.Array, dict[str, slice]]] = None,
|
291
|
-
**unroll_kwargs,
|
292
|
-
):
|
293
|
-
self.unactuated_links = unactuated_subsystems
|
294
|
-
self.custom_P_gains = custom_P_gains
|
295
|
-
self.return_q_ref = return_q_ref
|
296
|
-
self.overwrite_q_ref = overwrite_q_ref
|
297
|
-
self.unroll_kwargs = unroll_kwargs
|
298
|
-
|
299
|
-
def __call__(self, gen):
|
300
|
-
def _gen(*args):
|
301
|
-
(X, y), (key, q, _, sys_x) = gen(*args)
|
302
|
-
idx_map_q = sys_x.idx_map("q")
|
303
|
-
|
304
|
-
if self.overwrite_q_ref is not None:
|
305
|
-
q, idx_map_q = self.overwrite_q_ref
|
306
|
-
assert q.shape[-1] == sum(
|
307
|
-
[s.stop - s.start for s in idx_map_q.values()]
|
308
|
-
)
|
309
|
-
|
310
|
-
sys_q_ref = sys_x
|
311
|
-
if len(self.unactuated_links) > 0:
|
312
|
-
sys_q_ref = sys_x.delete_system(self.unactuated_links)
|
313
|
-
|
314
|
-
q_ref = []
|
315
|
-
p_gains_list = []
|
316
|
-
q = q.T
|
317
|
-
|
318
|
-
def build_q_ref(_, __, name, link_type):
|
319
|
-
q_ref.append(q[idx_map_q[name]])
|
320
|
-
|
321
|
-
if link_type in self.custom_P_gains:
|
322
|
-
p_gain_this_link = self.custom_P_gains[link_type]
|
323
|
-
elif link_type in _P_gains:
|
324
|
-
p_gain_this_link = _P_gains[link_type]
|
325
|
-
else:
|
326
|
-
raise RuntimeError(
|
327
|
-
f"Please proved gain parameters for the joint typ `{link_type}`"
|
328
|
-
" via the argument `custom_P_gains: dict[str, Array]`"
|
329
|
-
)
|
330
|
-
|
331
|
-
required_qd_size = base.QD_WIDTHS[link_type]
|
332
|
-
assert (
|
333
|
-
required_qd_size == p_gain_this_link.size
|
334
|
-
), f"The gain parameters must be of qd_size=`{required_qd_size}`"
|
335
|
-
f" but got `{p_gain_this_link.size}`. This happened for the link "
|
336
|
-
f"`{name}` of type `{link_type}`."
|
337
|
-
p_gains_list.append(p_gain_this_link)
|
338
|
-
|
339
|
-
sys_q_ref.scan(
|
340
|
-
build_q_ref, "ll", sys_q_ref.link_names, sys_q_ref.link_types
|
341
|
-
)
|
342
|
-
q_ref, p_gains_array = jnp.concatenate(q_ref).T, jnp.concatenate(
|
343
|
-
p_gains_list
|
344
|
-
)
|
345
|
-
|
346
|
-
# perform dynamical simulation
|
347
|
-
states = pd_control._unroll_dynamics_pd_control(
|
348
|
-
sys_x, q_ref, p_gains_array, sys_q_ref=sys_q_ref, **self.unroll_kwargs
|
349
|
-
)
|
350
|
-
|
351
|
-
if self.return_q_ref:
|
352
|
-
X = utils.dict_union(X, dict(q_ref=q_ref))
|
353
|
-
|
354
|
-
return (X, y), (key, states.q, states.x, sys_x)
|
355
|
-
|
356
|
-
return _gen
|
357
|
-
|
358
|
-
|
359
|
-
def _flatten(seq: list):
|
360
|
-
seq = tree_utils.tree_batch(seq, backend=None)
|
361
|
-
seq = tree_utils.batch_concat_acme(seq, num_batch_dims=3).transpose((1, 2, 0, 3))
|
362
|
-
return seq
|
363
|
-
|
364
|
-
|
365
|
-
def _expand_dt(X: dict, T: int):
|
366
|
-
dt = X.pop("dt", None)
|
367
|
-
if dt is not None:
|
368
|
-
if isinstance(dt, np.ndarray):
|
369
|
-
numpy = np
|
370
|
-
else:
|
371
|
-
numpy = jnp
|
372
|
-
dt = numpy.repeat(dt[:, None, :], T, axis=1)
|
373
|
-
for seg in X:
|
374
|
-
X[seg]["dt"] = dt
|
375
|
-
return X
|
376
|
-
|
377
|
-
|
378
|
-
def _expand_then_flatten(args):
|
379
|
-
X, y = args
|
380
|
-
gyr = X["0"]["gyr"]
|
381
|
-
|
382
|
-
batched = True
|
383
|
-
if gyr.ndim == 2:
|
384
|
-
batched = False
|
385
|
-
X, y = tree_utils.add_batch_dim((X, y))
|
386
|
-
|
387
|
-
X = _expand_dt(X, gyr.shape[-2])
|
388
|
-
|
389
|
-
N = len(X)
|
390
|
-
|
391
|
-
def dict_to_tuple(d: dict[str, jax.Array]):
|
392
|
-
tup = (d["acc"], d["gyr"])
|
393
|
-
if "joint_axes" in d:
|
394
|
-
tup = tup + (d["joint_axes"],)
|
395
|
-
if "dt" in d:
|
396
|
-
tup = tup + (d["dt"],)
|
397
|
-
return tup
|
398
|
-
|
399
|
-
X = [dict_to_tuple(X[str(i)]) for i in range(N)]
|
400
|
-
y = [y[str(i)] for i in range(N)]
|
401
|
-
|
402
|
-
X, y = _flatten(X), _flatten(y)
|
403
|
-
if not batched:
|
404
|
-
X, y = jax.tree_map(lambda arr: arr[0], (X, y))
|
405
|
-
return X, y
|
406
|
-
|
407
|
-
|
408
|
-
def GeneratorTrafoExpandFlatten(gen, jit: bool = False):
|
409
|
-
if jit:
|
410
|
-
return GeneratorTrafoLambda(jax.jit(_expand_then_flatten))(gen)
|
411
|
-
return GeneratorTrafoLambda(_expand_then_flatten)(gen)
|
File without changes
|
File without changes
|