imt-ring 1.6.37__py3-none-any.whl → 1.6.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,8 +10,8 @@ _skybox = """<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6
10
10
  _skybox_white = """<texture name="skybox" type="skybox" builtin="gradient" rgb1="1 1 1" rgb2="1 1 1" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
11
11
 
12
12
 
13
- def _floor(floor_z: float) -> str:
14
- return f"""<geom name="floor" pos="0 0 {floor_z}" size="0 0 1" type="plane" material="matplane" mass="0"/>""" # noqa: E501
13
+ def _floor(z: float, material: str) -> str:
14
+ return f"""<geom name="floor" pos="0 0 {z}" size="0 0 1" type="plane" material="{material}" mass="0"/>""" # noqa: E501
15
15
 
16
16
 
17
17
  def _build_model_of_geoms(
@@ -19,7 +19,7 @@ def _build_model_of_geoms(
19
19
  cameras: dict[int, Sequence[str]],
20
20
  lights: dict[int, Sequence[str]],
21
21
  floor: bool,
22
- floor_z: float,
22
+ floor_kwargs: dict,
23
23
  stars: bool,
24
24
  debug: bool,
25
25
  ) -> mujoco.MjModel:
@@ -77,10 +77,13 @@ def _build_model_of_geoms(
77
77
  xml_str = f""" # noqa: E501
78
78
  <mujoco>
79
79
  <asset>
80
- <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>
80
+ <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".3 .3 .3"/>
81
81
  <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
82
82
  <texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.2 0.3 0.4" rgb2="0.1 0.2 0.3" markrgb="0.8 0.8 0.8" width="300" height="300"/>
83
83
  <material name="groundplane" texture="groundplane" texuniform="true" texrepeat="2 2" reflectance="0.2"/>
84
+ <material name="beige" rgba="0.76 0.80 0.50 1.0" specular="0.3" shininess="0.1" />
85
+ <material name="white" rgba="0.9 0.9 0.9 1.0" reflectance="0"/>
86
+ <material name="gray" rgba="0.4 0.5 0.5 1.0" reflectance="0.25"/>
84
87
  {_skybox if stars else ''}
85
88
  <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3" rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
86
89
  <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
@@ -98,7 +101,7 @@ def _build_model_of_geoms(
98
101
  <camera pos="0 -1 1" name="target" mode="targetbodycom" target="{targetbody}"/>
99
102
  <camera pos="0 -3 3" name="targetfar" mode="targetbodycom" target="{targetbody}"/>
100
103
  <camera pos="0 -5 5" name="targetFar" mode="targetbodycom" target="{targetbody}"/>
101
- {_floor(floor_z) if floor else ''}
104
+ {_floor(**floor_kwargs) if floor else ''}
102
105
  {inside_worldbody_cameras}
103
106
  {inside_worldbody_lights}
104
107
  {inside_worldbody}
@@ -176,6 +179,7 @@ class MujocoScene:
176
179
  show_stars: bool = True,
177
180
  show_floor: bool = True,
178
181
  floor_z: float = -0.84,
182
+ floor_material: str = "matplane",
179
183
  debug: bool = False,
180
184
  ) -> None:
181
185
  self.debug = debug
@@ -190,7 +194,7 @@ class MujocoScene:
190
194
  self.add_cameras, self.add_lights = to_list(add_cameras), to_list(add_lights)
191
195
  self.show_stars = show_stars
192
196
  self.show_floor = show_floor
193
- self.floor_z = floor_z
197
+ self.floor_kwargs = dict(z=floor_z, material=floor_material)
194
198
 
195
199
  def init(self, geoms: list[base.Geometry]):
196
200
  self._parent_ids = list(set([geom.link_idx for geom in geoms]))
@@ -199,7 +203,7 @@ class MujocoScene:
199
203
  self.add_cameras,
200
204
  self.add_lights,
201
205
  floor=self.show_floor,
202
- floor_z=self.floor_z,
206
+ floor_kwargs=self.floor_kwargs,
203
207
  stars=self.show_stars,
204
208
  debug=self.debug,
205
209
  )
@@ -7,14 +7,15 @@ from typing import Optional, TypeVar
7
7
  import jax
8
8
  import jax.numpy as jnp
9
9
  import numpy as np
10
- from ring import algebra
11
- from ring import base
12
- from ring import maths
13
10
  from tree_utils import PyTree
14
11
  from tree_utils import tree_batch
15
12
  from vispy import scene
16
13
  from vispy.scene import MatrixTransform
17
14
 
15
+ from ring import algebra
16
+ from ring import base
17
+ from ring import maths
18
+
18
19
  from . import vispy_visuals
19
20
 
20
21
  Camera = TypeVar("Camera")
@@ -192,7 +193,7 @@ class Scene(ABC):
192
193
 
193
194
  # step 3: update visuals
194
195
  for i, (visual, geom) in enumerate(zip(self.visuals, self.geoms)):
195
- t = jax.tree_map(lambda arr: arr[i], transform_per_visual)
196
+ t = jax.tree.map(lambda arr: arr[i], transform_per_visual)
196
197
  if self._fresh_init:
197
198
  self._init_visual(visual, t, geom)
198
199
  else:
@@ -2,12 +2,13 @@ from typing import Optional
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
- from ring import base
6
5
  from tree_utils import tree_batch
7
6
 
7
+ from ring import base
8
+
8
9
 
9
10
  def _tree_nan_like(tree, repeats: int):
10
- return jax.tree_map(
11
+ return jax.tree.map(
11
12
  lambda arr: jnp.repeat(arr[0:1] * jnp.nan, repeats, axis=0), tree
12
13
  )
13
14
 
ring/utils/batchsize.py CHANGED
@@ -39,19 +39,19 @@ def merge_batchsize(
39
39
  tree: PyTree, pmap_size: int, vmap_size: int, third_dim_also: bool = False
40
40
  ) -> PyTree:
41
41
  if third_dim_also:
42
- return jax.tree_map(
42
+ return jax.tree.map(
43
43
  lambda arr: arr.reshape(
44
44
  (pmap_size * vmap_size * arr.shape[2],) + arr.shape[3:]
45
45
  ),
46
46
  tree,
47
47
  )
48
- return jax.tree_map(
48
+ return jax.tree.map(
49
49
  lambda arr: arr.reshape((pmap_size * vmap_size,) + arr.shape[2:]), tree
50
50
  )
51
51
 
52
52
 
53
53
  def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
54
- return jax.tree_map(
54
+ return jax.tree.map(
55
55
  lambda arr: arr.reshape(
56
56
  (
57
57
  pmap_size,
ring/utils/dataloader.py CHANGED
@@ -4,14 +4,15 @@ from typing import Callable, Optional
4
4
 
5
5
  import jax
6
6
  import numpy as np
7
- from ring.utils import parse_path
8
- from ring.utils import pickle_load
9
7
  import torch
10
8
  from torch.utils.data import DataLoader
11
9
  from torch.utils.data import Dataset
12
10
  import tqdm
13
11
  from tree_utils import PyTree
14
12
 
13
+ from ring.utils import parse_path
14
+ from ring.utils import pickle_load
15
+
15
16
 
16
17
  def make_generator(
17
18
  *paths,
@@ -103,7 +104,7 @@ def pytorch_generator(
103
104
  dl_iter = iter(dl)
104
105
 
105
106
  def to_numpy(tree: PyTree[torch.Tensor]):
106
- return jax.tree_map(lambda tensor: tensor.numpy(), tree)
107
+ return jax.tree.map(lambda tensor: tensor.numpy(), tree)
107
108
 
108
109
  def generator(_):
109
110
  nonlocal dl, dl_iter
@@ -1,16 +1,25 @@
1
1
  import os
2
+ import pickle
2
3
  from typing import Any, Optional
3
4
  import warnings
4
5
 
5
- import jax
6
6
  import numpy as np
7
7
  import torch
8
8
  from torch.utils.data import DataLoader
9
9
  from torch.utils.data import Dataset
10
+ import tree
10
11
  from tree_utils import PyTree
11
12
 
12
- from ring.utils import parse_path
13
- from ring.utils import pickle_load
13
+ from ring.utils.path import parse_path
14
+
15
+
16
+ def pickle_load(
17
+ path,
18
+ ):
19
+ path = parse_path(path, extension="pickle", require_is_file=True)
20
+ with open(path, "rb") as file:
21
+ obj = pickle.load(file)
22
+ return obj
14
23
 
15
24
 
16
25
  class FolderOfFilesDataset(Dataset):
@@ -60,8 +69,8 @@ def dataset_to_generator(
60
69
  )
61
70
  dl_iter = iter(dl)
62
71
 
63
- def to_numpy(tree: PyTree[torch.Tensor]):
64
- return jax.tree_map(lambda tensor: tensor.numpy(), tree)
72
+ def to_numpy(data: PyTree[torch.Tensor]):
73
+ return tree.map_structure(lambda tensor: tensor.numpy(), data)
65
74
 
66
75
  def generator(_):
67
76
  nonlocal dl, dl_iter
ring/utils/hdf5.py CHANGED
@@ -121,7 +121,7 @@ def _parse_path(
121
121
 
122
122
  def _tree_concat(trees: list):
123
123
  # otherwise scalar-arrays will lead to indexing error
124
- trees = jax.tree_map(lambda arr: np.atleast_1d(arr), trees)
124
+ trees = jax.tree.map(lambda arr: np.atleast_1d(arr), trees)
125
125
 
126
126
  if len(trees) == 0:
127
127
  return trees
ring/utils/normalizer.py CHANGED
@@ -3,9 +3,10 @@ from typing import Callable, TypeVar
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
6
- from ring.algorithms.generator import types
7
6
  import tree_utils
8
7
 
8
+ from ring.algorithms.generator import types
9
+
9
10
  KEY = jax.random.PRNGKey(777)
10
11
  KEY_PERMUTATION = jax.random.PRNGKey(888)
11
12
 
@@ -37,12 +38,12 @@ def make_normalizer_from_generator(
37
38
  # permute 0-th axis, since batchsize of generator might be larger than
38
39
  # `approx_with_large_batchsize`, then we would not get a representative
39
40
  # subsample otherwise
40
- Xs = jax.tree_map(lambda arr: jax.random.permutation(KEY_PERMUTATION, arr), Xs)
41
+ Xs = jax.tree.map(lambda arr: jax.random.permutation(KEY_PERMUTATION, arr), Xs)
41
42
  Xs = tree_utils.tree_slice(Xs, start=0, slice_size=approx_with_large_batchsize)
42
43
 
43
44
  # obtain statistics
44
- mean = jax.tree_map(lambda arr: jnp.mean(arr, axis=(0, 1)), Xs)
45
- std = jax.tree_map(lambda arr: jnp.std(arr, axis=(0, 1)), Xs)
45
+ mean = jax.tree.map(lambda arr: jnp.mean(arr, axis=(0, 1)), Xs)
46
+ std = jax.tree.map(lambda arr: jnp.std(arr, axis=(0, 1)), Xs)
46
47
 
47
48
  if verbose:
48
49
  print("Mean: ", mean)
@@ -51,6 +52,6 @@ def make_normalizer_from_generator(
51
52
  eps = 1e-8
52
53
 
53
54
  def normalizer(X):
54
- return jax.tree_map(lambda a, b, c: (a - b) / (c + eps), X, mean, std)
55
+ return jax.tree.map(lambda a, b, c: (a - b) / (c + eps), X, mean, std)
55
56
 
56
57
  return normalizer