imt-ring 1.6.36__tar.gz → 1.6.38__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.36 → imt_ring-1.6.38}/PKG-INFO +2 -2
- {imt_ring-1.6.36 → imt_ring-1.6.38}/pyproject.toml +1 -1
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/imt_ring.egg-info/PKG-INFO +2 -2
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/custom_joints/suntay.py +1 -1
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/batch.py +2 -2
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/finalize_fns.py +1 -1
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/pd_control.py +1 -1
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/kinematics.py +2 -1
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/sensors.py +12 -10
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/base.py +1 -1
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/xml/from_xml.py +1 -1
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/base.py +2 -2
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/ml_utils.py +3 -3
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/ringnet.py +1 -1
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/train.py +2 -2
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/rendering/mujoco_render.py +11 -7
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/rendering/vispy_render.py +5 -4
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/sys_composer/inject_sys.py +3 -2
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/batchsize.py +3 -3
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/dataloader.py +4 -3
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/dataloader_torch.py +14 -5
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/hdf5.py +1 -1
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/normalizer.py +6 -5
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/utils.py +18 -2
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_ml_utils.py +1 -1
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_sim2real.py +3 -2
- {imt_ring-1.6.36 → imt_ring-1.6.38}/readme.md +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/setup.cfg +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/imt_ring.egg-info/SOURCES.txt +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/imt_ring.egg-info/dependency_links.txt +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/imt_ring.egg-info/requires.txt +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/imt_ring.egg-info/top_level.txt +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/__init__.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algebra.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/__init__.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/_random.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/custom_joints/__init__.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/custom_joints/rsaddle_joint.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/dynamics.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/__init__.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/base.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/setup_fns.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/types.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/jcalc.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/__init__.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/branched.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/inv_pendulum.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/spherical_stiff.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/symmetric.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_all_1.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_all_2.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_control.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_double_pendulum.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_free.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_kinematics.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_randomize_position.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_sensors.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/test_examples.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/xml/__init__.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/xml/abstract.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/xml/test_from_xml.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/xml/test_to_xml.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/xml/to_xml.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/maths.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/__init__.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/callbacks.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/optimizer.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/rnno_v1.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/training_loop.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/rendering/__init__.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/rendering/base_render.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/rendering/vispy_visuals.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/sim2real/__init__.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/sim2real/sim2real.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/spatial.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/sys_composer/__init__.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/sys_composer/delete_sys.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/sys_composer/morph_sys.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/__init__.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/backend.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/colab.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/path.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/randomize_sys.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/register_gym_envs/__init__.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/register_gym_envs/saddle.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_algebra.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_base.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_custom_joints.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_dynamics.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_generator.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_jcalc.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_jit.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_kinematics.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_maths.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_motion_artifacts.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_pd_control.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_quickstart_example.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_random.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_randomize.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_rcmg.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_render.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_sensors.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_sys_composer.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_train.py +0 -0
- {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: imt-ring
|
3
|
-
Version: 1.6.
|
3
|
+
Version: 1.6.38
|
4
4
|
Summary: RING: Recurrent Inertial Graph-based Estimator
|
5
5
|
Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
|
6
6
|
Project-URL: Homepage, https://github.com/SimiPixel/ring
|
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: imt-ring
|
3
|
-
Version: 1.6.
|
3
|
+
Version: 1.6.38
|
4
4
|
Summary: RING: Recurrent Inertial Graph-based Estimator
|
5
5
|
Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
|
6
6
|
Project-URL: Homepage, https://github.com/SimiPixel/ring
|
@@ -184,7 +184,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
184
184
|
|
185
185
|
suntay_link_name = _utils_find_suntay_joint(sys)
|
186
186
|
|
187
|
-
params = jax.
|
187
|
+
params = jax.tree.map(
|
188
188
|
lambda arr: arr[sys.idx_map("l")[suntay_link_name]],
|
189
189
|
sys.links.joint_params[name],
|
190
190
|
)
|
@@ -80,11 +80,11 @@ def generators_eager(
|
|
80
80
|
# converts also to numpy; but with np.array.flags.writeable = False
|
81
81
|
sample = jax.device_get(sample)
|
82
82
|
# this then sets this flag to True
|
83
|
-
sample = jax.
|
83
|
+
sample = jax.tree.map(np.array, sample)
|
84
84
|
|
85
85
|
sample_flat, _ = jax.tree_util.tree_flatten(sample)
|
86
86
|
size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
|
87
|
-
callback([jax.
|
87
|
+
callback([jax.tree.map(lambda a: a[i].copy(), sample) for i in range(size)])
|
88
88
|
|
89
89
|
# cleanup
|
90
90
|
del sample, sample_flat
|
@@ -86,7 +86,7 @@ def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
|
|
86
86
|
controller_state: PDControllerState, sys: base.System, state: base.State
|
87
87
|
) -> jax.Array:
|
88
88
|
taus = jnp.zeros((sys.qd_size()))
|
89
|
-
q_ref, qd_ref = jax.
|
89
|
+
q_ref, qd_ref = jax.tree.map(
|
90
90
|
lambda arr: jax.lax.dynamic_index_in_dim(
|
91
91
|
arr, controller_state.i, keepdims=False
|
92
92
|
),
|
@@ -4,6 +4,7 @@ import jax
|
|
4
4
|
import jax.numpy as jnp
|
5
5
|
import jaxopt
|
6
6
|
from jaxopt._src.base import Solver
|
7
|
+
|
7
8
|
from ring import algebra
|
8
9
|
from ring import base
|
9
10
|
from ring import maths
|
@@ -171,7 +172,7 @@ def inverse_kinematics_endeffector(
|
|
171
172
|
|
172
173
|
# find result of best q0 initial value
|
173
174
|
best_q_index = jnp.argmin(values)
|
174
|
-
best_q, best_q_value = jax.
|
175
|
+
best_q, best_q_value = jax.tree.map(
|
175
176
|
lambda arr: jax.lax.dynamic_index_in_dim(
|
176
177
|
arr, best_q_index, keepdims=False
|
177
178
|
),
|
@@ -244,7 +244,7 @@ def imu(
|
|
244
244
|
measurements["mag"] = magnetometer(xs.rot, magvec)
|
245
245
|
|
246
246
|
if smoothen_degree is not None:
|
247
|
-
measurements = jax.
|
247
|
+
measurements = jax.tree.map(
|
248
248
|
lambda arr: _moving_average(arr, smoothen_degree),
|
249
249
|
measurements,
|
250
250
|
)
|
@@ -257,7 +257,7 @@ def imu(
|
|
257
257
|
delay = half_window
|
258
258
|
|
259
259
|
if delay is not None and delay > 0:
|
260
|
-
measurements = jax.
|
260
|
+
measurements = jax.tree.map(
|
261
261
|
lambda arr: (jnp.pad(arr, ((delay, 0), (0, 0)))[:-delay]), measurements
|
262
262
|
)
|
263
263
|
|
@@ -473,7 +473,7 @@ def _joint_axes_from_sys(sys: base.Transform, N: int) -> dict:
|
|
473
473
|
X[name] = {"joint_axes": joint_axes}
|
474
474
|
|
475
475
|
sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
|
476
|
-
X = jax.
|
476
|
+
X = jax.tree.map(lambda arr: jnp.repeat(arr[None], N, axis=0), X)
|
477
477
|
return X
|
478
478
|
|
479
479
|
|
@@ -498,12 +498,12 @@ _quasi_physical_sys_str = r"""
|
|
498
498
|
<x_xy>
|
499
499
|
<options gravity="0 0 0"/>
|
500
500
|
<worldbody>
|
501
|
-
<body name="IMU" joint="
|
502
|
-
<geom type="box" mass="
|
501
|
+
<body name="IMU" joint="free" damping="1 1 1 10 10 10" spring_stiff="20 20 20 500 500 500">
|
502
|
+
<geom type="box" mass="1" dim="0.01 0.01 0.01"/>
|
503
503
|
</body>
|
504
504
|
</worldbody>
|
505
505
|
</x_xy>
|
506
|
-
"""
|
506
|
+
""" # noqa: E501
|
507
507
|
|
508
508
|
|
509
509
|
def _quasi_physical_simulation_beautiful(
|
@@ -512,12 +512,14 @@ def _quasi_physical_simulation_beautiful(
|
|
512
512
|
sys = io.load_sys_from_str(_quasi_physical_sys_str).replace(dt=dt)
|
513
513
|
|
514
514
|
def step_dynamics(state: base.State, x):
|
515
|
-
state = algorithms.step(
|
515
|
+
state = algorithms.step(
|
516
|
+
sys.replace(link_spring_zeropoint=jnp.concatenate((x.rot, x.pos))), state
|
517
|
+
)
|
516
518
|
return state, state.q
|
517
519
|
|
518
|
-
state = base.State.create(sys, q=xs.pos[0])
|
519
|
-
_,
|
520
|
-
return xs.replace(pos=
|
520
|
+
state = base.State.create(sys, q=jnp.concatenate((xs.rot[0], xs.pos[0])))
|
521
|
+
_, qs = jax.lax.scan(step_dynamics, state, xs)
|
522
|
+
return xs.replace(rot=qs[:, :4], pos=qs[:, 4:])
|
521
523
|
|
522
524
|
|
523
525
|
_constants = {
|
@@ -252,7 +252,7 @@ def load_sys_from_str(xml_str: str, seed: int = 1) -> base.System:
|
|
252
252
|
|
253
253
|
# numpy -> jax
|
254
254
|
# we load using numpy in order to have float64 precision
|
255
|
-
sys = jax.
|
255
|
+
sys = jax.tree.map(jax.numpy.asarray, sys)
|
256
256
|
|
257
257
|
sys = jcalc._init_joint_params(jax.random.PRNGKey(seed), sys)
|
258
258
|
|
@@ -13,13 +13,13 @@ from ring.utils import pickle_save
|
|
13
13
|
def _to_3d(tree):
|
14
14
|
if tree is None:
|
15
15
|
return None
|
16
|
-
return jax.
|
16
|
+
return jax.tree.map(lambda arr: arr[None], tree)
|
17
17
|
|
18
18
|
|
19
19
|
def _to_2d(tree, i: int = 0):
|
20
20
|
if tree is None:
|
21
21
|
return None
|
22
|
-
return jax.
|
22
|
+
return jax.tree.map(lambda arr: arr[i], tree)
|
23
23
|
|
24
24
|
|
25
25
|
class AbstractFilter(ABC):
|
@@ -161,7 +161,7 @@ def _flatten_convert_filter_nested_dict(
|
|
161
161
|
metrices: NestedDict, filter_nan_inf: bool = True
|
162
162
|
):
|
163
163
|
metrices = _flatten_dict(metrices)
|
164
|
-
metrices = jax.
|
164
|
+
metrices = jax.tree.map(_to_float_if_not_string, metrices)
|
165
165
|
|
166
166
|
if not filter_nan_inf:
|
167
167
|
return metrices
|
@@ -216,7 +216,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
|
216
216
|
from jax.experimental import jax2tf
|
217
217
|
import tensorflow as tf
|
218
218
|
|
219
|
-
signature = jax.
|
219
|
+
signature = jax.tree.map(
|
220
220
|
lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
|
221
221
|
)
|
222
222
|
|
@@ -241,7 +241,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
|
241
241
|
if validate:
|
242
242
|
output_jax = jax_func(*input)
|
243
243
|
output_tf = tf.saved_model.load(path)(*input)
|
244
|
-
jax.
|
244
|
+
jax.tree.map(
|
245
245
|
lambda a1, a2: np.allclose(a1, a2, atol=1e-5, rtol=1e-5),
|
246
246
|
output_jax,
|
247
247
|
output_tf,
|
@@ -248,7 +248,7 @@ class RING(ml_base.AbstractFilter):
|
|
248
248
|
params, state = self.forward_lam_factory(lam=lam).init(key, X)
|
249
249
|
|
250
250
|
if bs is not None:
|
251
|
-
state = jax.
|
251
|
+
state = jax.tree.map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
|
252
252
|
|
253
253
|
return params, state
|
254
254
|
|
@@ -50,7 +50,7 @@ def _build_step_fn(
|
|
50
50
|
# this vmap maps along batch-axis, not time-axis
|
51
51
|
# time-axis is handled by `metric_fn`
|
52
52
|
pipe = lambda q, qhat: jnp.mean(jax.vmap(metric_fn)(q, qhat))
|
53
|
-
error_tree = jax.
|
53
|
+
error_tree = jax.tree.map(pipe, y, yhat)
|
54
54
|
return jnp.mean(tree_utils.batch_concat(error_tree, 0)), state
|
55
55
|
|
56
56
|
@partial(
|
@@ -274,7 +274,7 @@ def _build_eval_fn(
|
|
274
274
|
), f"The metric identitifier {metric_name} is not unique"
|
275
275
|
|
276
276
|
pipe = lambda q, qhat: reduce_fn(jax.vmap(jax.vmap(metric_fn))(q, qhat))
|
277
|
-
values.update({metric_name: jax.
|
277
|
+
values.update({metric_name: jax.tree.map(pipe, y, yhat)})
|
278
278
|
|
279
279
|
return values
|
280
280
|
|
@@ -10,8 +10,8 @@ _skybox = """<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6
|
|
10
10
|
_skybox_white = """<texture name="skybox" type="skybox" builtin="gradient" rgb1="1 1 1" rgb2="1 1 1" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
|
11
11
|
|
12
12
|
|
13
|
-
def _floor(
|
14
|
-
return f"""<geom name="floor" pos="0 0 {
|
13
|
+
def _floor(z: float, material: str) -> str:
|
14
|
+
return f"""<geom name="floor" pos="0 0 {z}" size="0 0 1" type="plane" material="{material}" mass="0"/>""" # noqa: E501
|
15
15
|
|
16
16
|
|
17
17
|
def _build_model_of_geoms(
|
@@ -19,7 +19,7 @@ def _build_model_of_geoms(
|
|
19
19
|
cameras: dict[int, Sequence[str]],
|
20
20
|
lights: dict[int, Sequence[str]],
|
21
21
|
floor: bool,
|
22
|
-
|
22
|
+
floor_kwargs: dict,
|
23
23
|
stars: bool,
|
24
24
|
debug: bool,
|
25
25
|
) -> mujoco.MjModel:
|
@@ -77,10 +77,13 @@ def _build_model_of_geoms(
|
|
77
77
|
xml_str = f""" # noqa: E501
|
78
78
|
<mujoco>
|
79
79
|
<asset>
|
80
|
-
<texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".
|
80
|
+
<texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".3 .3 .3"/>
|
81
81
|
<material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
|
82
82
|
<texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.2 0.3 0.4" rgb2="0.1 0.2 0.3" markrgb="0.8 0.8 0.8" width="300" height="300"/>
|
83
83
|
<material name="groundplane" texture="groundplane" texuniform="true" texrepeat="2 2" reflectance="0.2"/>
|
84
|
+
<material name="beige" rgba="0.76 0.80 0.50 1.0" specular="0.3" shininess="0.1" />
|
85
|
+
<material name="white" rgba="0.9 0.9 0.9 1.0" reflectance="0"/>
|
86
|
+
<material name="gray" rgba="0.4 0.5 0.5 1.0" reflectance="0.25"/>
|
84
87
|
{_skybox if stars else ''}
|
85
88
|
<texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3" rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
|
86
89
|
<material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
|
@@ -98,7 +101,7 @@ def _build_model_of_geoms(
|
|
98
101
|
<camera pos="0 -1 1" name="target" mode="targetbodycom" target="{targetbody}"/>
|
99
102
|
<camera pos="0 -3 3" name="targetfar" mode="targetbodycom" target="{targetbody}"/>
|
100
103
|
<camera pos="0 -5 5" name="targetFar" mode="targetbodycom" target="{targetbody}"/>
|
101
|
-
{_floor(
|
104
|
+
{_floor(**floor_kwargs) if floor else ''}
|
102
105
|
{inside_worldbody_cameras}
|
103
106
|
{inside_worldbody_lights}
|
104
107
|
{inside_worldbody}
|
@@ -176,6 +179,7 @@ class MujocoScene:
|
|
176
179
|
show_stars: bool = True,
|
177
180
|
show_floor: bool = True,
|
178
181
|
floor_z: float = -0.84,
|
182
|
+
floor_material: str = "matplane",
|
179
183
|
debug: bool = False,
|
180
184
|
) -> None:
|
181
185
|
self.debug = debug
|
@@ -190,7 +194,7 @@ class MujocoScene:
|
|
190
194
|
self.add_cameras, self.add_lights = to_list(add_cameras), to_list(add_lights)
|
191
195
|
self.show_stars = show_stars
|
192
196
|
self.show_floor = show_floor
|
193
|
-
self.
|
197
|
+
self.floor_kwargs = dict(z=floor_z, material=floor_material)
|
194
198
|
|
195
199
|
def init(self, geoms: list[base.Geometry]):
|
196
200
|
self._parent_ids = list(set([geom.link_idx for geom in geoms]))
|
@@ -199,7 +203,7 @@ class MujocoScene:
|
|
199
203
|
self.add_cameras,
|
200
204
|
self.add_lights,
|
201
205
|
floor=self.show_floor,
|
202
|
-
|
206
|
+
floor_kwargs=self.floor_kwargs,
|
203
207
|
stars=self.show_stars,
|
204
208
|
debug=self.debug,
|
205
209
|
)
|
@@ -7,14 +7,15 @@ from typing import Optional, TypeVar
|
|
7
7
|
import jax
|
8
8
|
import jax.numpy as jnp
|
9
9
|
import numpy as np
|
10
|
-
from ring import algebra
|
11
|
-
from ring import base
|
12
|
-
from ring import maths
|
13
10
|
from tree_utils import PyTree
|
14
11
|
from tree_utils import tree_batch
|
15
12
|
from vispy import scene
|
16
13
|
from vispy.scene import MatrixTransform
|
17
14
|
|
15
|
+
from ring import algebra
|
16
|
+
from ring import base
|
17
|
+
from ring import maths
|
18
|
+
|
18
19
|
from . import vispy_visuals
|
19
20
|
|
20
21
|
Camera = TypeVar("Camera")
|
@@ -192,7 +193,7 @@ class Scene(ABC):
|
|
192
193
|
|
193
194
|
# step 3: update visuals
|
194
195
|
for i, (visual, geom) in enumerate(zip(self.visuals, self.geoms)):
|
195
|
-
t = jax.
|
196
|
+
t = jax.tree.map(lambda arr: arr[i], transform_per_visual)
|
196
197
|
if self._fresh_init:
|
197
198
|
self._init_visual(visual, t, geom)
|
198
199
|
else:
|
@@ -2,12 +2,13 @@ from typing import Optional
|
|
2
2
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
|
-
from ring import base
|
6
5
|
from tree_utils import tree_batch
|
7
6
|
|
7
|
+
from ring import base
|
8
|
+
|
8
9
|
|
9
10
|
def _tree_nan_like(tree, repeats: int):
|
10
|
-
return jax.
|
11
|
+
return jax.tree.map(
|
11
12
|
lambda arr: jnp.repeat(arr[0:1] * jnp.nan, repeats, axis=0), tree
|
12
13
|
)
|
13
14
|
|
@@ -39,19 +39,19 @@ def merge_batchsize(
|
|
39
39
|
tree: PyTree, pmap_size: int, vmap_size: int, third_dim_also: bool = False
|
40
40
|
) -> PyTree:
|
41
41
|
if third_dim_also:
|
42
|
-
return jax.
|
42
|
+
return jax.tree.map(
|
43
43
|
lambda arr: arr.reshape(
|
44
44
|
(pmap_size * vmap_size * arr.shape[2],) + arr.shape[3:]
|
45
45
|
),
|
46
46
|
tree,
|
47
47
|
)
|
48
|
-
return jax.
|
48
|
+
return jax.tree.map(
|
49
49
|
lambda arr: arr.reshape((pmap_size * vmap_size,) + arr.shape[2:]), tree
|
50
50
|
)
|
51
51
|
|
52
52
|
|
53
53
|
def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
|
54
|
-
return jax.
|
54
|
+
return jax.tree.map(
|
55
55
|
lambda arr: arr.reshape(
|
56
56
|
(
|
57
57
|
pmap_size,
|
@@ -4,14 +4,15 @@ from typing import Callable, Optional
|
|
4
4
|
|
5
5
|
import jax
|
6
6
|
import numpy as np
|
7
|
-
from ring.utils import parse_path
|
8
|
-
from ring.utils import pickle_load
|
9
7
|
import torch
|
10
8
|
from torch.utils.data import DataLoader
|
11
9
|
from torch.utils.data import Dataset
|
12
10
|
import tqdm
|
13
11
|
from tree_utils import PyTree
|
14
12
|
|
13
|
+
from ring.utils import parse_path
|
14
|
+
from ring.utils import pickle_load
|
15
|
+
|
15
16
|
|
16
17
|
def make_generator(
|
17
18
|
*paths,
|
@@ -103,7 +104,7 @@ def pytorch_generator(
|
|
103
104
|
dl_iter = iter(dl)
|
104
105
|
|
105
106
|
def to_numpy(tree: PyTree[torch.Tensor]):
|
106
|
-
return jax.
|
107
|
+
return jax.tree.map(lambda tensor: tensor.numpy(), tree)
|
107
108
|
|
108
109
|
def generator(_):
|
109
110
|
nonlocal dl, dl_iter
|
@@ -1,16 +1,25 @@
|
|
1
1
|
import os
|
2
|
+
import pickle
|
2
3
|
from typing import Any, Optional
|
3
4
|
import warnings
|
4
5
|
|
5
|
-
import jax
|
6
6
|
import numpy as np
|
7
7
|
import torch
|
8
8
|
from torch.utils.data import DataLoader
|
9
9
|
from torch.utils.data import Dataset
|
10
|
+
import tree
|
10
11
|
from tree_utils import PyTree
|
11
12
|
|
12
|
-
from ring.utils import parse_path
|
13
|
-
|
13
|
+
from ring.utils.path import parse_path
|
14
|
+
|
15
|
+
|
16
|
+
def pickle_load(
|
17
|
+
path,
|
18
|
+
):
|
19
|
+
path = parse_path(path, extension="pickle", require_is_file=True)
|
20
|
+
with open(path, "rb") as file:
|
21
|
+
obj = pickle.load(file)
|
22
|
+
return obj
|
14
23
|
|
15
24
|
|
16
25
|
class FolderOfFilesDataset(Dataset):
|
@@ -60,8 +69,8 @@ def dataset_to_generator(
|
|
60
69
|
)
|
61
70
|
dl_iter = iter(dl)
|
62
71
|
|
63
|
-
def to_numpy(
|
64
|
-
return
|
72
|
+
def to_numpy(data: PyTree[torch.Tensor]):
|
73
|
+
return tree.map_structure(lambda tensor: tensor.numpy(), data)
|
65
74
|
|
66
75
|
def generator(_):
|
67
76
|
nonlocal dl, dl_iter
|
@@ -121,7 +121,7 @@ def _parse_path(
|
|
121
121
|
|
122
122
|
def _tree_concat(trees: list):
|
123
123
|
# otherwise scalar-arrays will lead to indexing error
|
124
|
-
trees = jax.
|
124
|
+
trees = jax.tree.map(lambda arr: np.atleast_1d(arr), trees)
|
125
125
|
|
126
126
|
if len(trees) == 0:
|
127
127
|
return trees
|
@@ -3,9 +3,10 @@ from typing import Callable, TypeVar
|
|
3
3
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
6
|
-
from ring.algorithms.generator import types
|
7
6
|
import tree_utils
|
8
7
|
|
8
|
+
from ring.algorithms.generator import types
|
9
|
+
|
9
10
|
KEY = jax.random.PRNGKey(777)
|
10
11
|
KEY_PERMUTATION = jax.random.PRNGKey(888)
|
11
12
|
|
@@ -37,12 +38,12 @@ def make_normalizer_from_generator(
|
|
37
38
|
# permute 0-th axis, since batchsize of generator might be larger than
|
38
39
|
# `approx_with_large_batchsize`, then we would not get a representative
|
39
40
|
# subsample otherwise
|
40
|
-
Xs = jax.
|
41
|
+
Xs = jax.tree.map(lambda arr: jax.random.permutation(KEY_PERMUTATION, arr), Xs)
|
41
42
|
Xs = tree_utils.tree_slice(Xs, start=0, slice_size=approx_with_large_batchsize)
|
42
43
|
|
43
44
|
# obtain statistics
|
44
|
-
mean = jax.
|
45
|
-
std = jax.
|
45
|
+
mean = jax.tree.map(lambda arr: jnp.mean(arr, axis=(0, 1)), Xs)
|
46
|
+
std = jax.tree.map(lambda arr: jnp.std(arr, axis=(0, 1)), Xs)
|
46
47
|
|
47
48
|
if verbose:
|
48
49
|
print("Mean: ", mean)
|
@@ -51,6 +52,6 @@ def make_normalizer_from_generator(
|
|
51
52
|
eps = 1e-8
|
52
53
|
|
53
54
|
def normalizer(X):
|
54
|
-
return jax.
|
55
|
+
return jax.tree.map(lambda a, b, c: (a - b) / (c + eps), X, mean, std)
|
55
56
|
|
56
57
|
return normalizer
|
@@ -1,5 +1,6 @@
|
|
1
1
|
from importlib import import_module as _import_module
|
2
2
|
import io
|
3
|
+
from pathlib import Path
|
3
4
|
import pickle
|
4
5
|
import random
|
5
6
|
from typing import Optional
|
@@ -152,13 +153,28 @@ def import_lib(
|
|
152
153
|
|
153
154
|
def pickle_save(obj, path, overwrite: bool = False):
|
154
155
|
path = parse_path(path, extension="pickle", file_exists_ok=overwrite)
|
155
|
-
|
156
|
-
|
156
|
+
try:
|
157
|
+
with open(path, "wb") as file:
|
158
|
+
pickle.dump(obj, file, protocol=5)
|
159
|
+
except OSError as e:
|
160
|
+
print(
|
161
|
+
f"saving with `pickle` throws exception {e}. "
|
162
|
+
+ "Attempting to save using `joblib`"
|
163
|
+
)
|
164
|
+
path = parse_path(path, extension="joblib", file_exists_ok=overwrite)
|
165
|
+
import joblib
|
166
|
+
|
167
|
+
joblib.dump(obj, path)
|
157
168
|
|
158
169
|
|
159
170
|
def pickle_load(
|
160
171
|
path,
|
161
172
|
):
|
173
|
+
if Path(path).suffix == ".joblib":
|
174
|
+
import joblib
|
175
|
+
|
176
|
+
return joblib.load(path)
|
177
|
+
|
162
178
|
path = parse_path(path, extension="pickle", require_is_file=True)
|
163
179
|
with open(path, "rb") as file:
|
164
180
|
obj = pickle.load(file)
|
@@ -41,7 +41,7 @@ def test_save_load_generators():
|
|
41
41
|
data = rcmg.to_list()[0]
|
42
42
|
rcmg.to_pickle(path)
|
43
43
|
|
44
|
-
data_list = [jax.
|
44
|
+
data_list = [jax.tree.map(lambda a: a[0], utils.pickle_load(path))]
|
45
45
|
gen_reloaded = ring.RCMG.eager_gen_from_list(data_list, 1)
|
46
46
|
data_reloaded = unbatch_gen(gen_reloaded)(jax.random.PRNGKey(1))
|
47
47
|
|
@@ -2,6 +2,7 @@ from _compat import unbatch_gen
|
|
2
2
|
import jax
|
3
3
|
import jax.numpy as jnp
|
4
4
|
import numpy as np
|
5
|
+
|
5
6
|
import ring
|
6
7
|
from ring import maths
|
7
8
|
from ring import sim2real
|
@@ -49,7 +50,7 @@ def test_forward_kinematics_omc():
|
|
49
50
|
# t1_omc should be used when p == -1, else t1_sys
|
50
51
|
@jax.vmap
|
51
52
|
def merge_transform1(t1_omc):
|
52
|
-
return jax.
|
53
|
+
return jax.tree.map(
|
53
54
|
lambda a, b: jnp.where(
|
54
55
|
jnp.repeat(
|
55
56
|
jnp.array(sys.link_parents)[:, None] == -1,
|
@@ -138,7 +139,7 @@ def test_zip_unzip_scale():
|
|
138
139
|
t1, t2 = sim2real.unzip_xs(sys, xs)
|
139
140
|
xs_re = sim2real.zip_xs(sys, t1, t2)
|
140
141
|
|
141
|
-
jax.
|
142
|
+
jax.tree.map(
|
142
143
|
lambda a, b: np.testing.assert_allclose(a, b, rtol=1e-3, atol=1e-5),
|
143
144
|
xs,
|
144
145
|
xs_re,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml
RENAMED
File without changes
|
{imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|