imt-ring 1.6.36__py3-none-any.whl → 1.6.38__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.36.dist-info → imt_ring-1.6.38.dist-info}/METADATA +2 -2
- {imt_ring-1.6.36.dist-info → imt_ring-1.6.38.dist-info}/RECORD +25 -25
- {imt_ring-1.6.36.dist-info → imt_ring-1.6.38.dist-info}/WHEEL +1 -1
- ring/algorithms/custom_joints/suntay.py +1 -1
- ring/algorithms/generator/batch.py +2 -2
- ring/algorithms/generator/finalize_fns.py +1 -1
- ring/algorithms/generator/pd_control.py +1 -1
- ring/algorithms/kinematics.py +2 -1
- ring/algorithms/sensors.py +12 -10
- ring/base.py +1 -1
- ring/io/xml/from_xml.py +1 -1
- ring/ml/base.py +2 -2
- ring/ml/ml_utils.py +3 -3
- ring/ml/ringnet.py +1 -1
- ring/ml/train.py +2 -2
- ring/rendering/mujoco_render.py +11 -7
- ring/rendering/vispy_render.py +5 -4
- ring/sys_composer/inject_sys.py +3 -2
- ring/utils/batchsize.py +3 -3
- ring/utils/dataloader.py +4 -3
- ring/utils/dataloader_torch.py +14 -5
- ring/utils/hdf5.py +1 -1
- ring/utils/normalizer.py +6 -5
- ring/utils/utils.py +18 -2
- {imt_ring-1.6.36.dist-info → imt_ring-1.6.38.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: imt-ring
|
3
|
-
Version: 1.6.
|
3
|
+
Version: 1.6.38
|
4
4
|
Summary: RING: Recurrent Inertial Graph-based Estimator
|
5
5
|
Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
|
6
6
|
Project-URL: Homepage, https://github.com/SimiPixel/ring
|
@@ -1,25 +1,25 @@
|
|
1
1
|
ring/__init__.py,sha256=H1Rd2uXVkux4Z792XyHIkQ8OpDSZBiPqFwyAFDWDU3E,5260
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
|
-
ring/base.py,sha256=
|
3
|
+
ring/base.py,sha256=4Yxk6jk-B4UUm_6YYshxmHSHqOg0mhTOxtZP5fFS8nw,35373
|
4
4
|
ring/maths.py,sha256=R22SNQutkf9v7Hp9klo0wvJVIyBQz0O8_5oJaDQcFis,12652
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
|
7
7
|
ring/algorithms/_random.py,sha256=UMyv-VPZLcErrKqs0XB83QJjs8GrmoNsv-zRSxGXvnI,14490
|
8
8
|
ring/algorithms/dynamics.py,sha256=dpe-F3Yq4sY2dY6DQW3v7TnPLRdxdkePtdbGPQIrijg,10997
|
9
9
|
ring/algorithms/jcalc.py,sha256=QafnCKa1mjEl7nL_KuadPJB5ebW31NKnkdcKn2YtSsM,36171
|
10
|
-
ring/algorithms/kinematics.py,sha256=
|
11
|
-
ring/algorithms/sensors.py,sha256=
|
10
|
+
ring/algorithms/kinematics.py,sha256=IXeTQ-afzeEzLVmQVQ1oTXJxz_lTwvaWlgHeJkhO_8o,7423
|
11
|
+
ring/algorithms/sensors.py,sha256=v_TZMyWjffDpPwoyS1fy8X-1i9y1vDf6mk1EmGS2ztc,18251
|
12
12
|
ring/algorithms/custom_joints/__init__.py,sha256=3pQ-Is_HBTQDkzESCNg9VfoP8wvseWmooryG8ERnu_A,366
|
13
13
|
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
|
14
14
|
ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
|
15
15
|
ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
|
16
|
-
ring/algorithms/custom_joints/suntay.py,sha256=
|
16
|
+
ring/algorithms/custom_joints/suntay.py,sha256=TZG307NqdMiXnNY63xEx8AkAjbQBQ4eO6DQ7R4j4D08,16726
|
17
17
|
ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
|
18
18
|
ring/algorithms/generator/base.py,sha256=jGQocoNZ5tkiMazBDCv-jD6FNYwebqn0_RgVFse49pg,16890
|
19
|
-
ring/algorithms/generator/batch.py,sha256=
|
20
|
-
ring/algorithms/generator/finalize_fns.py,sha256=
|
19
|
+
ring/algorithms/generator/batch.py,sha256=xp1X8oYtwI6l2cH4GRu9zw-P8dnh-X1FWTSyixEfgr8,2652
|
20
|
+
ring/algorithms/generator/finalize_fns.py,sha256=ty1NaU-Mghx1RL-voivDjS0TWSKNtjTmbdmBnShhn7k,10398
|
21
21
|
ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
|
22
|
-
ring/algorithms/generator/pd_control.py,sha256=
|
22
|
+
ring/algorithms/generator/pd_control.py,sha256=dHnhJZx_FqrHD4xFXpQZH-R7rputFkAVGwoBGccZnz4,5767
|
23
23
|
ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
|
24
24
|
ring/algorithms/generator/types.py,sha256=HjNyATFSLfHkXlzdJhvUkiqnhzpXFDDXmWS3LYBlOtU,721
|
25
25
|
ring/io/__init__.py,sha256=1gEJdyDCbldbbm8QeZbLmhzSKmaQ-UqTmQgu4DBH2Z4,328
|
@@ -47,46 +47,46 @@ ring/io/examples/test_morph_system/four_seg_seg1.xml,sha256=XJvGtEnvedejs_OmCVfQ
|
|
47
47
|
ring/io/examples/test_morph_system/four_seg_seg3.xml,sha256=HktN7_a_Ly3YflWit5W-WncxApWGMORAGnRXyMEqnoA,1265
|
48
48
|
ring/io/xml/__init__.py,sha256=-3k6ffvFyc4zm0oTyVz3ez-o3Lb9bPp2sjwSub_K1AA,242
|
49
49
|
ring/io/xml/abstract.py,sha256=8Q2ebnUYLmuS9HJAQwDVrDTrRfD5z4G5RAB7MW8Oa60,9742
|
50
|
-
ring/io/xml/from_xml.py,sha256=
|
50
|
+
ring/io/xml/from_xml.py,sha256=CR3OaBxoDuHK8k5N79XziHwS90lCaaw49UGzQirWiIw,9356
|
51
51
|
ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
|
52
52
|
ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
|
53
53
|
ring/io/xml/to_xml.py,sha256=Wo4iySLw9nM-iVW42AGvMRqjtU2qRc2FD_Zlc7w1IrE,3438
|
54
54
|
ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
|
55
|
-
ring/ml/base.py,sha256=
|
55
|
+
ring/ml/base.py,sha256=eEpLuXlhFoJ-cpXoSGjctLaYduQhnSVpbv-FEYftNRs,9972
|
56
56
|
ring/ml/callbacks.py,sha256=oCPXl4_Zcw3g0KRgyyUDmdiGxV0phnDVc_t8rEG4Lls,13737
|
57
|
-
ring/ml/ml_utils.py,sha256=
|
57
|
+
ring/ml/ml_utils.py,sha256=hu189AnHcmkhkpEPZZ19O0gWz3T-YKpWQW9buqDTMow,10915
|
58
58
|
ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
|
59
|
-
ring/ml/ringnet.py,sha256=
|
59
|
+
ring/ml/ringnet.py,sha256=uybMtLdQV0oTFldlbffrqku_l79y3coJYHwxX8P9QZQ,8973
|
60
60
|
ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
|
61
|
-
ring/ml/train.py,sha256=
|
61
|
+
ring/ml/train.py,sha256=_CtQM3w9L01V5yn23lz0aaIPJN5sOWlL9e7G_9__11c,10989
|
62
62
|
ring/ml/training_loop.py,sha256=yxuUua_4RExq_0GUYm4eUZJsBmtrwDSVL94bWUpYfdo,3586
|
63
63
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
64
64
|
ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
|
65
65
|
ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
|
66
66
|
ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
|
67
|
-
ring/rendering/mujoco_render.py,sha256=
|
68
|
-
ring/rendering/vispy_render.py,sha256=
|
67
|
+
ring/rendering/mujoco_render.py,sha256=HMvZc04I0-lXPBL3hcnBzV2bNiXQAQM7QcHlG_Obmj4,8757
|
68
|
+
ring/rendering/vispy_render.py,sha256=6Z6S5LNZ7iy9BN1GVb9EDe-Tix5N_SQ1s7ZsfiTSDEA,10261
|
69
69
|
ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
|
70
70
|
ring/sim2real/__init__.py,sha256=gCLYg8IoMdzUagzhCFcfjZ5GavtIU772L7HR0G5hUtM,251
|
71
71
|
ring/sim2real/sim2real.py,sha256=B4nqBBnjGXhM-7PfTyxEq44ZidGNghqaq--qdFILX5A,9675
|
72
72
|
ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E,193
|
73
73
|
ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
|
74
|
-
ring/sys_composer/inject_sys.py,sha256=
|
74
|
+
ring/sys_composer/inject_sys.py,sha256=PLuxLbXU7hPtAsqvpsEim9hkoVE26ddrg3OipZNvnhU,3504
|
75
75
|
ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
|
76
76
|
ring/utils/__init__.py,sha256=MHHavc8YfjBlmB-zAV42QEQS_ebW7cy0lhWXEVyQU7s,720
|
77
77
|
ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
|
78
|
-
ring/utils/batchsize.py,sha256=
|
78
|
+
ring/utils/batchsize.py,sha256=uCj8LG7elbjEUUzuK29Z3I9T8bxJTcsybY3DdGeqhQs,1786
|
79
79
|
ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
|
80
|
-
ring/utils/dataloader.py,sha256=
|
81
|
-
ring/utils/dataloader_torch.py,sha256=
|
82
|
-
ring/utils/hdf5.py,sha256=
|
83
|
-
ring/utils/normalizer.py,sha256=
|
80
|
+
ring/utils/dataloader.py,sha256=dfNPjnxDoKxWGKSImuJ_49CWgBn73vxSEek8COq9nNk,3749
|
81
|
+
ring/utils/dataloader_torch.py,sha256=t2DDiB9ZHb_SzFlVbntCGGIybj4F-NoA0PaB4_afjGw,3983
|
82
|
+
ring/utils/hdf5.py,sha256=XPIrwogD-d544yy08UJyfLVp1ZKRUtiZukW7RA8VUxQ,5856
|
83
|
+
ring/utils/normalizer.py,sha256=o26stPP6EHasZQxQX0vKqTrhUNZBaJ2O17L6W_gBMN4,1699
|
84
84
|
ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
|
85
85
|
ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
|
86
|
-
ring/utils/utils.py,sha256=
|
86
|
+
ring/utils/utils.py,sha256=gKwOXLxWraeZfX6EbBcg3hkq30DcXN0mcRUeOSTNiMo,7336
|
87
87
|
ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
|
88
88
|
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
89
|
-
imt_ring-1.6.
|
90
|
-
imt_ring-1.6.
|
91
|
-
imt_ring-1.6.
|
92
|
-
imt_ring-1.6.
|
89
|
+
imt_ring-1.6.38.dist-info/METADATA,sha256=9rN1VzsIlGU8eyABz9-pTxj0OTCFOZRilEEzkB4gyvg,4251
|
90
|
+
imt_ring-1.6.38.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
91
|
+
imt_ring-1.6.38.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
92
|
+
imt_ring-1.6.38.dist-info/RECORD,,
|
@@ -184,7 +184,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
184
184
|
|
185
185
|
suntay_link_name = _utils_find_suntay_joint(sys)
|
186
186
|
|
187
|
-
params = jax.
|
187
|
+
params = jax.tree.map(
|
188
188
|
lambda arr: arr[sys.idx_map("l")[suntay_link_name]],
|
189
189
|
sys.links.joint_params[name],
|
190
190
|
)
|
@@ -80,11 +80,11 @@ def generators_eager(
|
|
80
80
|
# converts also to numpy; but with np.array.flags.writeable = False
|
81
81
|
sample = jax.device_get(sample)
|
82
82
|
# this then sets this flag to True
|
83
|
-
sample = jax.
|
83
|
+
sample = jax.tree.map(np.array, sample)
|
84
84
|
|
85
85
|
sample_flat, _ = jax.tree_util.tree_flatten(sample)
|
86
86
|
size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
|
87
|
-
callback([jax.
|
87
|
+
callback([jax.tree.map(lambda a: a[i].copy(), sample) for i in range(size)])
|
88
88
|
|
89
89
|
# cleanup
|
90
90
|
del sample, sample_flat
|
@@ -86,7 +86,7 @@ def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
|
|
86
86
|
controller_state: PDControllerState, sys: base.System, state: base.State
|
87
87
|
) -> jax.Array:
|
88
88
|
taus = jnp.zeros((sys.qd_size()))
|
89
|
-
q_ref, qd_ref = jax.
|
89
|
+
q_ref, qd_ref = jax.tree.map(
|
90
90
|
lambda arr: jax.lax.dynamic_index_in_dim(
|
91
91
|
arr, controller_state.i, keepdims=False
|
92
92
|
),
|
ring/algorithms/kinematics.py
CHANGED
@@ -4,6 +4,7 @@ import jax
|
|
4
4
|
import jax.numpy as jnp
|
5
5
|
import jaxopt
|
6
6
|
from jaxopt._src.base import Solver
|
7
|
+
|
7
8
|
from ring import algebra
|
8
9
|
from ring import base
|
9
10
|
from ring import maths
|
@@ -171,7 +172,7 @@ def inverse_kinematics_endeffector(
|
|
171
172
|
|
172
173
|
# find result of best q0 initial value
|
173
174
|
best_q_index = jnp.argmin(values)
|
174
|
-
best_q, best_q_value = jax.
|
175
|
+
best_q, best_q_value = jax.tree.map(
|
175
176
|
lambda arr: jax.lax.dynamic_index_in_dim(
|
176
177
|
arr, best_q_index, keepdims=False
|
177
178
|
),
|
ring/algorithms/sensors.py
CHANGED
@@ -244,7 +244,7 @@ def imu(
|
|
244
244
|
measurements["mag"] = magnetometer(xs.rot, magvec)
|
245
245
|
|
246
246
|
if smoothen_degree is not None:
|
247
|
-
measurements = jax.
|
247
|
+
measurements = jax.tree.map(
|
248
248
|
lambda arr: _moving_average(arr, smoothen_degree),
|
249
249
|
measurements,
|
250
250
|
)
|
@@ -257,7 +257,7 @@ def imu(
|
|
257
257
|
delay = half_window
|
258
258
|
|
259
259
|
if delay is not None and delay > 0:
|
260
|
-
measurements = jax.
|
260
|
+
measurements = jax.tree.map(
|
261
261
|
lambda arr: (jnp.pad(arr, ((delay, 0), (0, 0)))[:-delay]), measurements
|
262
262
|
)
|
263
263
|
|
@@ -473,7 +473,7 @@ def _joint_axes_from_sys(sys: base.Transform, N: int) -> dict:
|
|
473
473
|
X[name] = {"joint_axes": joint_axes}
|
474
474
|
|
475
475
|
sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
|
476
|
-
X = jax.
|
476
|
+
X = jax.tree.map(lambda arr: jnp.repeat(arr[None], N, axis=0), X)
|
477
477
|
return X
|
478
478
|
|
479
479
|
|
@@ -498,12 +498,12 @@ _quasi_physical_sys_str = r"""
|
|
498
498
|
<x_xy>
|
499
499
|
<options gravity="0 0 0"/>
|
500
500
|
<worldbody>
|
501
|
-
<body name="IMU" joint="
|
502
|
-
<geom type="box" mass="
|
501
|
+
<body name="IMU" joint="free" damping="1 1 1 10 10 10" spring_stiff="20 20 20 500 500 500">
|
502
|
+
<geom type="box" mass="1" dim="0.01 0.01 0.01"/>
|
503
503
|
</body>
|
504
504
|
</worldbody>
|
505
505
|
</x_xy>
|
506
|
-
"""
|
506
|
+
""" # noqa: E501
|
507
507
|
|
508
508
|
|
509
509
|
def _quasi_physical_simulation_beautiful(
|
@@ -512,12 +512,14 @@ def _quasi_physical_simulation_beautiful(
|
|
512
512
|
sys = io.load_sys_from_str(_quasi_physical_sys_str).replace(dt=dt)
|
513
513
|
|
514
514
|
def step_dynamics(state: base.State, x):
|
515
|
-
state = algorithms.step(
|
515
|
+
state = algorithms.step(
|
516
|
+
sys.replace(link_spring_zeropoint=jnp.concatenate((x.rot, x.pos))), state
|
517
|
+
)
|
516
518
|
return state, state.q
|
517
519
|
|
518
|
-
state = base.State.create(sys, q=xs.pos[0])
|
519
|
-
_,
|
520
|
-
return xs.replace(pos=
|
520
|
+
state = base.State.create(sys, q=jnp.concatenate((xs.rot[0], xs.pos[0])))
|
521
|
+
_, qs = jax.lax.scan(step_dynamics, state, xs)
|
522
|
+
return xs.replace(rot=qs[:, :4], pos=qs[:, 4:])
|
521
523
|
|
522
524
|
|
523
525
|
_constants = {
|
ring/base.py
CHANGED
ring/io/xml/from_xml.py
CHANGED
@@ -252,7 +252,7 @@ def load_sys_from_str(xml_str: str, seed: int = 1) -> base.System:
|
|
252
252
|
|
253
253
|
# numpy -> jax
|
254
254
|
# we load using numpy in order to have float64 precision
|
255
|
-
sys = jax.
|
255
|
+
sys = jax.tree.map(jax.numpy.asarray, sys)
|
256
256
|
|
257
257
|
sys = jcalc._init_joint_params(jax.random.PRNGKey(seed), sys)
|
258
258
|
|
ring/ml/base.py
CHANGED
@@ -13,13 +13,13 @@ from ring.utils import pickle_save
|
|
13
13
|
def _to_3d(tree):
|
14
14
|
if tree is None:
|
15
15
|
return None
|
16
|
-
return jax.
|
16
|
+
return jax.tree.map(lambda arr: arr[None], tree)
|
17
17
|
|
18
18
|
|
19
19
|
def _to_2d(tree, i: int = 0):
|
20
20
|
if tree is None:
|
21
21
|
return None
|
22
|
-
return jax.
|
22
|
+
return jax.tree.map(lambda arr: arr[i], tree)
|
23
23
|
|
24
24
|
|
25
25
|
class AbstractFilter(ABC):
|
ring/ml/ml_utils.py
CHANGED
@@ -161,7 +161,7 @@ def _flatten_convert_filter_nested_dict(
|
|
161
161
|
metrices: NestedDict, filter_nan_inf: bool = True
|
162
162
|
):
|
163
163
|
metrices = _flatten_dict(metrices)
|
164
|
-
metrices = jax.
|
164
|
+
metrices = jax.tree.map(_to_float_if_not_string, metrices)
|
165
165
|
|
166
166
|
if not filter_nan_inf:
|
167
167
|
return metrices
|
@@ -216,7 +216,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
|
216
216
|
from jax.experimental import jax2tf
|
217
217
|
import tensorflow as tf
|
218
218
|
|
219
|
-
signature = jax.
|
219
|
+
signature = jax.tree.map(
|
220
220
|
lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
|
221
221
|
)
|
222
222
|
|
@@ -241,7 +241,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
|
241
241
|
if validate:
|
242
242
|
output_jax = jax_func(*input)
|
243
243
|
output_tf = tf.saved_model.load(path)(*input)
|
244
|
-
jax.
|
244
|
+
jax.tree.map(
|
245
245
|
lambda a1, a2: np.allclose(a1, a2, atol=1e-5, rtol=1e-5),
|
246
246
|
output_jax,
|
247
247
|
output_tf,
|
ring/ml/ringnet.py
CHANGED
@@ -248,7 +248,7 @@ class RING(ml_base.AbstractFilter):
|
|
248
248
|
params, state = self.forward_lam_factory(lam=lam).init(key, X)
|
249
249
|
|
250
250
|
if bs is not None:
|
251
|
-
state = jax.
|
251
|
+
state = jax.tree.map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
|
252
252
|
|
253
253
|
return params, state
|
254
254
|
|
ring/ml/train.py
CHANGED
@@ -50,7 +50,7 @@ def _build_step_fn(
|
|
50
50
|
# this vmap maps along batch-axis, not time-axis
|
51
51
|
# time-axis is handled by `metric_fn`
|
52
52
|
pipe = lambda q, qhat: jnp.mean(jax.vmap(metric_fn)(q, qhat))
|
53
|
-
error_tree = jax.
|
53
|
+
error_tree = jax.tree.map(pipe, y, yhat)
|
54
54
|
return jnp.mean(tree_utils.batch_concat(error_tree, 0)), state
|
55
55
|
|
56
56
|
@partial(
|
@@ -274,7 +274,7 @@ def _build_eval_fn(
|
|
274
274
|
), f"The metric identitifier {metric_name} is not unique"
|
275
275
|
|
276
276
|
pipe = lambda q, qhat: reduce_fn(jax.vmap(jax.vmap(metric_fn))(q, qhat))
|
277
|
-
values.update({metric_name: jax.
|
277
|
+
values.update({metric_name: jax.tree.map(pipe, y, yhat)})
|
278
278
|
|
279
279
|
return values
|
280
280
|
|
ring/rendering/mujoco_render.py
CHANGED
@@ -10,8 +10,8 @@ _skybox = """<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6
|
|
10
10
|
_skybox_white = """<texture name="skybox" type="skybox" builtin="gradient" rgb1="1 1 1" rgb2="1 1 1" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
|
11
11
|
|
12
12
|
|
13
|
-
def _floor(
|
14
|
-
return f"""<geom name="floor" pos="0 0 {
|
13
|
+
def _floor(z: float, material: str) -> str:
|
14
|
+
return f"""<geom name="floor" pos="0 0 {z}" size="0 0 1" type="plane" material="{material}" mass="0"/>""" # noqa: E501
|
15
15
|
|
16
16
|
|
17
17
|
def _build_model_of_geoms(
|
@@ -19,7 +19,7 @@ def _build_model_of_geoms(
|
|
19
19
|
cameras: dict[int, Sequence[str]],
|
20
20
|
lights: dict[int, Sequence[str]],
|
21
21
|
floor: bool,
|
22
|
-
|
22
|
+
floor_kwargs: dict,
|
23
23
|
stars: bool,
|
24
24
|
debug: bool,
|
25
25
|
) -> mujoco.MjModel:
|
@@ -77,10 +77,13 @@ def _build_model_of_geoms(
|
|
77
77
|
xml_str = f""" # noqa: E501
|
78
78
|
<mujoco>
|
79
79
|
<asset>
|
80
|
-
<texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".
|
80
|
+
<texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".3 .3 .3"/>
|
81
81
|
<material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
|
82
82
|
<texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.2 0.3 0.4" rgb2="0.1 0.2 0.3" markrgb="0.8 0.8 0.8" width="300" height="300"/>
|
83
83
|
<material name="groundplane" texture="groundplane" texuniform="true" texrepeat="2 2" reflectance="0.2"/>
|
84
|
+
<material name="beige" rgba="0.76 0.80 0.50 1.0" specular="0.3" shininess="0.1" />
|
85
|
+
<material name="white" rgba="0.9 0.9 0.9 1.0" reflectance="0"/>
|
86
|
+
<material name="gray" rgba="0.4 0.5 0.5 1.0" reflectance="0.25"/>
|
84
87
|
{_skybox if stars else ''}
|
85
88
|
<texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3" rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
|
86
89
|
<material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
|
@@ -98,7 +101,7 @@ def _build_model_of_geoms(
|
|
98
101
|
<camera pos="0 -1 1" name="target" mode="targetbodycom" target="{targetbody}"/>
|
99
102
|
<camera pos="0 -3 3" name="targetfar" mode="targetbodycom" target="{targetbody}"/>
|
100
103
|
<camera pos="0 -5 5" name="targetFar" mode="targetbodycom" target="{targetbody}"/>
|
101
|
-
{_floor(
|
104
|
+
{_floor(**floor_kwargs) if floor else ''}
|
102
105
|
{inside_worldbody_cameras}
|
103
106
|
{inside_worldbody_lights}
|
104
107
|
{inside_worldbody}
|
@@ -176,6 +179,7 @@ class MujocoScene:
|
|
176
179
|
show_stars: bool = True,
|
177
180
|
show_floor: bool = True,
|
178
181
|
floor_z: float = -0.84,
|
182
|
+
floor_material: str = "matplane",
|
179
183
|
debug: bool = False,
|
180
184
|
) -> None:
|
181
185
|
self.debug = debug
|
@@ -190,7 +194,7 @@ class MujocoScene:
|
|
190
194
|
self.add_cameras, self.add_lights = to_list(add_cameras), to_list(add_lights)
|
191
195
|
self.show_stars = show_stars
|
192
196
|
self.show_floor = show_floor
|
193
|
-
self.
|
197
|
+
self.floor_kwargs = dict(z=floor_z, material=floor_material)
|
194
198
|
|
195
199
|
def init(self, geoms: list[base.Geometry]):
|
196
200
|
self._parent_ids = list(set([geom.link_idx for geom in geoms]))
|
@@ -199,7 +203,7 @@ class MujocoScene:
|
|
199
203
|
self.add_cameras,
|
200
204
|
self.add_lights,
|
201
205
|
floor=self.show_floor,
|
202
|
-
|
206
|
+
floor_kwargs=self.floor_kwargs,
|
203
207
|
stars=self.show_stars,
|
204
208
|
debug=self.debug,
|
205
209
|
)
|
ring/rendering/vispy_render.py
CHANGED
@@ -7,14 +7,15 @@ from typing import Optional, TypeVar
|
|
7
7
|
import jax
|
8
8
|
import jax.numpy as jnp
|
9
9
|
import numpy as np
|
10
|
-
from ring import algebra
|
11
|
-
from ring import base
|
12
|
-
from ring import maths
|
13
10
|
from tree_utils import PyTree
|
14
11
|
from tree_utils import tree_batch
|
15
12
|
from vispy import scene
|
16
13
|
from vispy.scene import MatrixTransform
|
17
14
|
|
15
|
+
from ring import algebra
|
16
|
+
from ring import base
|
17
|
+
from ring import maths
|
18
|
+
|
18
19
|
from . import vispy_visuals
|
19
20
|
|
20
21
|
Camera = TypeVar("Camera")
|
@@ -192,7 +193,7 @@ class Scene(ABC):
|
|
192
193
|
|
193
194
|
# step 3: update visuals
|
194
195
|
for i, (visual, geom) in enumerate(zip(self.visuals, self.geoms)):
|
195
|
-
t = jax.
|
196
|
+
t = jax.tree.map(lambda arr: arr[i], transform_per_visual)
|
196
197
|
if self._fresh_init:
|
197
198
|
self._init_visual(visual, t, geom)
|
198
199
|
else:
|
ring/sys_composer/inject_sys.py
CHANGED
@@ -2,12 +2,13 @@ from typing import Optional
|
|
2
2
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
|
-
from ring import base
|
6
5
|
from tree_utils import tree_batch
|
7
6
|
|
7
|
+
from ring import base
|
8
|
+
|
8
9
|
|
9
10
|
def _tree_nan_like(tree, repeats: int):
|
10
|
-
return jax.
|
11
|
+
return jax.tree.map(
|
11
12
|
lambda arr: jnp.repeat(arr[0:1] * jnp.nan, repeats, axis=0), tree
|
12
13
|
)
|
13
14
|
|
ring/utils/batchsize.py
CHANGED
@@ -39,19 +39,19 @@ def merge_batchsize(
|
|
39
39
|
tree: PyTree, pmap_size: int, vmap_size: int, third_dim_also: bool = False
|
40
40
|
) -> PyTree:
|
41
41
|
if third_dim_also:
|
42
|
-
return jax.
|
42
|
+
return jax.tree.map(
|
43
43
|
lambda arr: arr.reshape(
|
44
44
|
(pmap_size * vmap_size * arr.shape[2],) + arr.shape[3:]
|
45
45
|
),
|
46
46
|
tree,
|
47
47
|
)
|
48
|
-
return jax.
|
48
|
+
return jax.tree.map(
|
49
49
|
lambda arr: arr.reshape((pmap_size * vmap_size,) + arr.shape[2:]), tree
|
50
50
|
)
|
51
51
|
|
52
52
|
|
53
53
|
def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
|
54
|
-
return jax.
|
54
|
+
return jax.tree.map(
|
55
55
|
lambda arr: arr.reshape(
|
56
56
|
(
|
57
57
|
pmap_size,
|
ring/utils/dataloader.py
CHANGED
@@ -4,14 +4,15 @@ from typing import Callable, Optional
|
|
4
4
|
|
5
5
|
import jax
|
6
6
|
import numpy as np
|
7
|
-
from ring.utils import parse_path
|
8
|
-
from ring.utils import pickle_load
|
9
7
|
import torch
|
10
8
|
from torch.utils.data import DataLoader
|
11
9
|
from torch.utils.data import Dataset
|
12
10
|
import tqdm
|
13
11
|
from tree_utils import PyTree
|
14
12
|
|
13
|
+
from ring.utils import parse_path
|
14
|
+
from ring.utils import pickle_load
|
15
|
+
|
15
16
|
|
16
17
|
def make_generator(
|
17
18
|
*paths,
|
@@ -103,7 +104,7 @@ def pytorch_generator(
|
|
103
104
|
dl_iter = iter(dl)
|
104
105
|
|
105
106
|
def to_numpy(tree: PyTree[torch.Tensor]):
|
106
|
-
return jax.
|
107
|
+
return jax.tree.map(lambda tensor: tensor.numpy(), tree)
|
107
108
|
|
108
109
|
def generator(_):
|
109
110
|
nonlocal dl, dl_iter
|
ring/utils/dataloader_torch.py
CHANGED
@@ -1,16 +1,25 @@
|
|
1
1
|
import os
|
2
|
+
import pickle
|
2
3
|
from typing import Any, Optional
|
3
4
|
import warnings
|
4
5
|
|
5
|
-
import jax
|
6
6
|
import numpy as np
|
7
7
|
import torch
|
8
8
|
from torch.utils.data import DataLoader
|
9
9
|
from torch.utils.data import Dataset
|
10
|
+
import tree
|
10
11
|
from tree_utils import PyTree
|
11
12
|
|
12
|
-
from ring.utils import parse_path
|
13
|
-
|
13
|
+
from ring.utils.path import parse_path
|
14
|
+
|
15
|
+
|
16
|
+
def pickle_load(
|
17
|
+
path,
|
18
|
+
):
|
19
|
+
path = parse_path(path, extension="pickle", require_is_file=True)
|
20
|
+
with open(path, "rb") as file:
|
21
|
+
obj = pickle.load(file)
|
22
|
+
return obj
|
14
23
|
|
15
24
|
|
16
25
|
class FolderOfFilesDataset(Dataset):
|
@@ -60,8 +69,8 @@ def dataset_to_generator(
|
|
60
69
|
)
|
61
70
|
dl_iter = iter(dl)
|
62
71
|
|
63
|
-
def to_numpy(
|
64
|
-
return
|
72
|
+
def to_numpy(data: PyTree[torch.Tensor]):
|
73
|
+
return tree.map_structure(lambda tensor: tensor.numpy(), data)
|
65
74
|
|
66
75
|
def generator(_):
|
67
76
|
nonlocal dl, dl_iter
|
ring/utils/hdf5.py
CHANGED
@@ -121,7 +121,7 @@ def _parse_path(
|
|
121
121
|
|
122
122
|
def _tree_concat(trees: list):
|
123
123
|
# otherwise scalar-arrays will lead to indexing error
|
124
|
-
trees = jax.
|
124
|
+
trees = jax.tree.map(lambda arr: np.atleast_1d(arr), trees)
|
125
125
|
|
126
126
|
if len(trees) == 0:
|
127
127
|
return trees
|
ring/utils/normalizer.py
CHANGED
@@ -3,9 +3,10 @@ from typing import Callable, TypeVar
|
|
3
3
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
6
|
-
from ring.algorithms.generator import types
|
7
6
|
import tree_utils
|
8
7
|
|
8
|
+
from ring.algorithms.generator import types
|
9
|
+
|
9
10
|
KEY = jax.random.PRNGKey(777)
|
10
11
|
KEY_PERMUTATION = jax.random.PRNGKey(888)
|
11
12
|
|
@@ -37,12 +38,12 @@ def make_normalizer_from_generator(
|
|
37
38
|
# permute 0-th axis, since batchsize of generator might be larger than
|
38
39
|
# `approx_with_large_batchsize`, then we would not get a representative
|
39
40
|
# subsample otherwise
|
40
|
-
Xs = jax.
|
41
|
+
Xs = jax.tree.map(lambda arr: jax.random.permutation(KEY_PERMUTATION, arr), Xs)
|
41
42
|
Xs = tree_utils.tree_slice(Xs, start=0, slice_size=approx_with_large_batchsize)
|
42
43
|
|
43
44
|
# obtain statistics
|
44
|
-
mean = jax.
|
45
|
-
std = jax.
|
45
|
+
mean = jax.tree.map(lambda arr: jnp.mean(arr, axis=(0, 1)), Xs)
|
46
|
+
std = jax.tree.map(lambda arr: jnp.std(arr, axis=(0, 1)), Xs)
|
46
47
|
|
47
48
|
if verbose:
|
48
49
|
print("Mean: ", mean)
|
@@ -51,6 +52,6 @@ def make_normalizer_from_generator(
|
|
51
52
|
eps = 1e-8
|
52
53
|
|
53
54
|
def normalizer(X):
|
54
|
-
return jax.
|
55
|
+
return jax.tree.map(lambda a, b, c: (a - b) / (c + eps), X, mean, std)
|
55
56
|
|
56
57
|
return normalizer
|
ring/utils/utils.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from importlib import import_module as _import_module
|
2
2
|
import io
|
3
|
+
from pathlib import Path
|
3
4
|
import pickle
|
4
5
|
import random
|
5
6
|
from typing import Optional
|
@@ -152,13 +153,28 @@ def import_lib(
|
|
152
153
|
|
153
154
|
def pickle_save(obj, path, overwrite: bool = False):
|
154
155
|
path = parse_path(path, extension="pickle", file_exists_ok=overwrite)
|
155
|
-
|
156
|
-
|
156
|
+
try:
|
157
|
+
with open(path, "wb") as file:
|
158
|
+
pickle.dump(obj, file, protocol=5)
|
159
|
+
except OSError as e:
|
160
|
+
print(
|
161
|
+
f"saving with `pickle` throws exception {e}. "
|
162
|
+
+ "Attempting to save using `joblib`"
|
163
|
+
)
|
164
|
+
path = parse_path(path, extension="joblib", file_exists_ok=overwrite)
|
165
|
+
import joblib
|
166
|
+
|
167
|
+
joblib.dump(obj, path)
|
157
168
|
|
158
169
|
|
159
170
|
def pickle_load(
|
160
171
|
path,
|
161
172
|
):
|
173
|
+
if Path(path).suffix == ".joblib":
|
174
|
+
import joblib
|
175
|
+
|
176
|
+
return joblib.load(path)
|
177
|
+
|
162
178
|
path = parse_path(path, extension="pickle", require_is_file=True)
|
163
179
|
with open(path, "rb") as file:
|
164
180
|
obj = pickle.load(file)
|
File without changes
|