imt-ring 1.6.37__py3-none-any.whl → 1.6.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.6.37.dist-info → imt_ring-1.6.39.dist-info}/METADATA +1 -1
- {imt_ring-1.6.37.dist-info → imt_ring-1.6.39.dist-info}/RECORD +27 -27
- ring/algorithms/custom_joints/suntay.py +1 -1
- ring/algorithms/dynamics.py +27 -1
- ring/algorithms/generator/base.py +82 -2
- ring/algorithms/generator/batch.py +2 -2
- ring/algorithms/generator/finalize_fns.py +1 -1
- ring/algorithms/generator/pd_control.py +1 -1
- ring/algorithms/jcalc.py +198 -0
- ring/algorithms/kinematics.py +2 -1
- ring/algorithms/sensors.py +12 -10
- ring/base.py +356 -27
- ring/io/xml/from_xml.py +1 -1
- ring/ml/base.py +4 -3
- ring/ml/ml_utils.py +3 -3
- ring/ml/ringnet.py +1 -1
- ring/ml/train.py +2 -2
- ring/rendering/mujoco_render.py +11 -7
- ring/rendering/vispy_render.py +5 -4
- ring/sys_composer/inject_sys.py +3 -2
- ring/utils/batchsize.py +3 -3
- ring/utils/dataloader.py +4 -3
- ring/utils/dataloader_torch.py +14 -5
- ring/utils/hdf5.py +1 -1
- ring/utils/normalizer.py +6 -5
- {imt_ring-1.6.37.dist-info → imt_ring-1.6.39.dist-info}/WHEEL +0 -0
- {imt_ring-1.6.37.dist-info → imt_ring-1.6.39.dist-info}/top_level.txt +0 -0
ring/rendering/mujoco_render.py
CHANGED
@@ -10,8 +10,8 @@ _skybox = """<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6
|
|
10
10
|
_skybox_white = """<texture name="skybox" type="skybox" builtin="gradient" rgb1="1 1 1" rgb2="1 1 1" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
|
11
11
|
|
12
12
|
|
13
|
-
def _floor(
|
14
|
-
return f"""<geom name="floor" pos="0 0 {
|
13
|
+
def _floor(z: float, material: str) -> str:
|
14
|
+
return f"""<geom name="floor" pos="0 0 {z}" size="0 0 1" type="plane" material="{material}" mass="0"/>""" # noqa: E501
|
15
15
|
|
16
16
|
|
17
17
|
def _build_model_of_geoms(
|
@@ -19,7 +19,7 @@ def _build_model_of_geoms(
|
|
19
19
|
cameras: dict[int, Sequence[str]],
|
20
20
|
lights: dict[int, Sequence[str]],
|
21
21
|
floor: bool,
|
22
|
-
|
22
|
+
floor_kwargs: dict,
|
23
23
|
stars: bool,
|
24
24
|
debug: bool,
|
25
25
|
) -> mujoco.MjModel:
|
@@ -77,10 +77,13 @@ def _build_model_of_geoms(
|
|
77
77
|
xml_str = f""" # noqa: E501
|
78
78
|
<mujoco>
|
79
79
|
<asset>
|
80
|
-
<texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".
|
80
|
+
<texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".3 .3 .3"/>
|
81
81
|
<material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
|
82
82
|
<texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.2 0.3 0.4" rgb2="0.1 0.2 0.3" markrgb="0.8 0.8 0.8" width="300" height="300"/>
|
83
83
|
<material name="groundplane" texture="groundplane" texuniform="true" texrepeat="2 2" reflectance="0.2"/>
|
84
|
+
<material name="beige" rgba="0.76 0.80 0.50 1.0" specular="0.3" shininess="0.1" />
|
85
|
+
<material name="white" rgba="0.9 0.9 0.9 1.0" reflectance="0"/>
|
86
|
+
<material name="gray" rgba="0.4 0.5 0.5 1.0" reflectance="0.25"/>
|
84
87
|
{_skybox if stars else ''}
|
85
88
|
<texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3" rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
|
86
89
|
<material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
|
@@ -98,7 +101,7 @@ def _build_model_of_geoms(
|
|
98
101
|
<camera pos="0 -1 1" name="target" mode="targetbodycom" target="{targetbody}"/>
|
99
102
|
<camera pos="0 -3 3" name="targetfar" mode="targetbodycom" target="{targetbody}"/>
|
100
103
|
<camera pos="0 -5 5" name="targetFar" mode="targetbodycom" target="{targetbody}"/>
|
101
|
-
{_floor(
|
104
|
+
{_floor(**floor_kwargs) if floor else ''}
|
102
105
|
{inside_worldbody_cameras}
|
103
106
|
{inside_worldbody_lights}
|
104
107
|
{inside_worldbody}
|
@@ -176,6 +179,7 @@ class MujocoScene:
|
|
176
179
|
show_stars: bool = True,
|
177
180
|
show_floor: bool = True,
|
178
181
|
floor_z: float = -0.84,
|
182
|
+
floor_material: str = "matplane",
|
179
183
|
debug: bool = False,
|
180
184
|
) -> None:
|
181
185
|
self.debug = debug
|
@@ -190,7 +194,7 @@ class MujocoScene:
|
|
190
194
|
self.add_cameras, self.add_lights = to_list(add_cameras), to_list(add_lights)
|
191
195
|
self.show_stars = show_stars
|
192
196
|
self.show_floor = show_floor
|
193
|
-
self.
|
197
|
+
self.floor_kwargs = dict(z=floor_z, material=floor_material)
|
194
198
|
|
195
199
|
def init(self, geoms: list[base.Geometry]):
|
196
200
|
self._parent_ids = list(set([geom.link_idx for geom in geoms]))
|
@@ -199,7 +203,7 @@ class MujocoScene:
|
|
199
203
|
self.add_cameras,
|
200
204
|
self.add_lights,
|
201
205
|
floor=self.show_floor,
|
202
|
-
|
206
|
+
floor_kwargs=self.floor_kwargs,
|
203
207
|
stars=self.show_stars,
|
204
208
|
debug=self.debug,
|
205
209
|
)
|
ring/rendering/vispy_render.py
CHANGED
@@ -7,14 +7,15 @@ from typing import Optional, TypeVar
|
|
7
7
|
import jax
|
8
8
|
import jax.numpy as jnp
|
9
9
|
import numpy as np
|
10
|
-
from ring import algebra
|
11
|
-
from ring import base
|
12
|
-
from ring import maths
|
13
10
|
from tree_utils import PyTree
|
14
11
|
from tree_utils import tree_batch
|
15
12
|
from vispy import scene
|
16
13
|
from vispy.scene import MatrixTransform
|
17
14
|
|
15
|
+
from ring import algebra
|
16
|
+
from ring import base
|
17
|
+
from ring import maths
|
18
|
+
|
18
19
|
from . import vispy_visuals
|
19
20
|
|
20
21
|
Camera = TypeVar("Camera")
|
@@ -192,7 +193,7 @@ class Scene(ABC):
|
|
192
193
|
|
193
194
|
# step 3: update visuals
|
194
195
|
for i, (visual, geom) in enumerate(zip(self.visuals, self.geoms)):
|
195
|
-
t = jax.
|
196
|
+
t = jax.tree.map(lambda arr: arr[i], transform_per_visual)
|
196
197
|
if self._fresh_init:
|
197
198
|
self._init_visual(visual, t, geom)
|
198
199
|
else:
|
ring/sys_composer/inject_sys.py
CHANGED
@@ -2,12 +2,13 @@ from typing import Optional
|
|
2
2
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
|
-
from ring import base
|
6
5
|
from tree_utils import tree_batch
|
7
6
|
|
7
|
+
from ring import base
|
8
|
+
|
8
9
|
|
9
10
|
def _tree_nan_like(tree, repeats: int):
|
10
|
-
return jax.
|
11
|
+
return jax.tree.map(
|
11
12
|
lambda arr: jnp.repeat(arr[0:1] * jnp.nan, repeats, axis=0), tree
|
12
13
|
)
|
13
14
|
|
ring/utils/batchsize.py
CHANGED
@@ -39,19 +39,19 @@ def merge_batchsize(
|
|
39
39
|
tree: PyTree, pmap_size: int, vmap_size: int, third_dim_also: bool = False
|
40
40
|
) -> PyTree:
|
41
41
|
if third_dim_also:
|
42
|
-
return jax.
|
42
|
+
return jax.tree.map(
|
43
43
|
lambda arr: arr.reshape(
|
44
44
|
(pmap_size * vmap_size * arr.shape[2],) + arr.shape[3:]
|
45
45
|
),
|
46
46
|
tree,
|
47
47
|
)
|
48
|
-
return jax.
|
48
|
+
return jax.tree.map(
|
49
49
|
lambda arr: arr.reshape((pmap_size * vmap_size,) + arr.shape[2:]), tree
|
50
50
|
)
|
51
51
|
|
52
52
|
|
53
53
|
def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
|
54
|
-
return jax.
|
54
|
+
return jax.tree.map(
|
55
55
|
lambda arr: arr.reshape(
|
56
56
|
(
|
57
57
|
pmap_size,
|
ring/utils/dataloader.py
CHANGED
@@ -4,14 +4,15 @@ from typing import Callable, Optional
|
|
4
4
|
|
5
5
|
import jax
|
6
6
|
import numpy as np
|
7
|
-
from ring.utils import parse_path
|
8
|
-
from ring.utils import pickle_load
|
9
7
|
import torch
|
10
8
|
from torch.utils.data import DataLoader
|
11
9
|
from torch.utils.data import Dataset
|
12
10
|
import tqdm
|
13
11
|
from tree_utils import PyTree
|
14
12
|
|
13
|
+
from ring.utils import parse_path
|
14
|
+
from ring.utils import pickle_load
|
15
|
+
|
15
16
|
|
16
17
|
def make_generator(
|
17
18
|
*paths,
|
@@ -103,7 +104,7 @@ def pytorch_generator(
|
|
103
104
|
dl_iter = iter(dl)
|
104
105
|
|
105
106
|
def to_numpy(tree: PyTree[torch.Tensor]):
|
106
|
-
return jax.
|
107
|
+
return jax.tree.map(lambda tensor: tensor.numpy(), tree)
|
107
108
|
|
108
109
|
def generator(_):
|
109
110
|
nonlocal dl, dl_iter
|
ring/utils/dataloader_torch.py
CHANGED
@@ -1,16 +1,25 @@
|
|
1
1
|
import os
|
2
|
+
import pickle
|
2
3
|
from typing import Any, Optional
|
3
4
|
import warnings
|
4
5
|
|
5
|
-
import jax
|
6
6
|
import numpy as np
|
7
7
|
import torch
|
8
8
|
from torch.utils.data import DataLoader
|
9
9
|
from torch.utils.data import Dataset
|
10
|
+
import tree
|
10
11
|
from tree_utils import PyTree
|
11
12
|
|
12
|
-
from ring.utils import parse_path
|
13
|
-
|
13
|
+
from ring.utils.path import parse_path
|
14
|
+
|
15
|
+
|
16
|
+
def pickle_load(
|
17
|
+
path,
|
18
|
+
):
|
19
|
+
path = parse_path(path, extension="pickle", require_is_file=True)
|
20
|
+
with open(path, "rb") as file:
|
21
|
+
obj = pickle.load(file)
|
22
|
+
return obj
|
14
23
|
|
15
24
|
|
16
25
|
class FolderOfFilesDataset(Dataset):
|
@@ -60,8 +69,8 @@ def dataset_to_generator(
|
|
60
69
|
)
|
61
70
|
dl_iter = iter(dl)
|
62
71
|
|
63
|
-
def to_numpy(
|
64
|
-
return
|
72
|
+
def to_numpy(data: PyTree[torch.Tensor]):
|
73
|
+
return tree.map_structure(lambda tensor: tensor.numpy(), data)
|
65
74
|
|
66
75
|
def generator(_):
|
67
76
|
nonlocal dl, dl_iter
|
ring/utils/hdf5.py
CHANGED
@@ -121,7 +121,7 @@ def _parse_path(
|
|
121
121
|
|
122
122
|
def _tree_concat(trees: list):
|
123
123
|
# otherwise scalar-arrays will lead to indexing error
|
124
|
-
trees = jax.
|
124
|
+
trees = jax.tree.map(lambda arr: np.atleast_1d(arr), trees)
|
125
125
|
|
126
126
|
if len(trees) == 0:
|
127
127
|
return trees
|
ring/utils/normalizer.py
CHANGED
@@ -3,9 +3,10 @@ from typing import Callable, TypeVar
|
|
3
3
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
6
|
-
from ring.algorithms.generator import types
|
7
6
|
import tree_utils
|
8
7
|
|
8
|
+
from ring.algorithms.generator import types
|
9
|
+
|
9
10
|
KEY = jax.random.PRNGKey(777)
|
10
11
|
KEY_PERMUTATION = jax.random.PRNGKey(888)
|
11
12
|
|
@@ -37,12 +38,12 @@ def make_normalizer_from_generator(
|
|
37
38
|
# permute 0-th axis, since batchsize of generator might be larger than
|
38
39
|
# `approx_with_large_batchsize`, then we would not get a representative
|
39
40
|
# subsample otherwise
|
40
|
-
Xs = jax.
|
41
|
+
Xs = jax.tree.map(lambda arr: jax.random.permutation(KEY_PERMUTATION, arr), Xs)
|
41
42
|
Xs = tree_utils.tree_slice(Xs, start=0, slice_size=approx_with_large_batchsize)
|
42
43
|
|
43
44
|
# obtain statistics
|
44
|
-
mean = jax.
|
45
|
-
std = jax.
|
45
|
+
mean = jax.tree.map(lambda arr: jnp.mean(arr, axis=(0, 1)), Xs)
|
46
|
+
std = jax.tree.map(lambda arr: jnp.std(arr, axis=(0, 1)), Xs)
|
46
47
|
|
47
48
|
if verbose:
|
48
49
|
print("Mean: ", mean)
|
@@ -51,6 +52,6 @@ def make_normalizer_from_generator(
|
|
51
52
|
eps = 1e-8
|
52
53
|
|
53
54
|
def normalizer(X):
|
54
|
-
return jax.
|
55
|
+
return jax.tree.map(lambda a, b, c: (a - b) / (c + eps), X, mean, std)
|
55
56
|
|
56
57
|
return normalizer
|
File without changes
|
File without changes
|