env-ssl-wrapper 0.0.2__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: env-ssl-wrapper
3
- Version: 0.0.2
3
+ Version: 0.0.4
4
4
  Summary: An RL environment wrapper for learning SSL in the background
5
5
  Project-URL: Homepage, https://pypi.org/project/env-ssl-wrapper/
6
6
  Project-URL: Repository, https://codeberg.org/lucidrains/env-ssl-wrapper
@@ -0,0 +1,6 @@
1
+ from .image_wrapper import ImageObservationWrapper
2
+ from .auto_batched_wrapper import AutoBatchedWrapper
3
+ from .tensor_wrapper import TensorWrapper
4
+ from .action_transform_wrapper import ActionTransformWrapper
5
+
6
+ from .utils import wrap_env, compose_env
@@ -0,0 +1,93 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import is_tensor
6
+ from torch.utils._pytree import tree_map
7
+
8
+ # helpers
9
+
10
def exists(v):
    """Return True for any value other than None (falsy values such as 0 or '' still count)."""
    is_missing = v is None
    return not is_missing
12
+
13
def default(v, d):
    """Return `v`, falling back to `d` only when `v` is None."""
    return d if v is None else v
15
+
16
def rescale(
    t,
    from_range: tuple[float, float],
    to_range: tuple[float, float]
):
    """Linearly map `t` from `from_range` onto `to_range`.

    Args:
        t: value to rescale - anything supporting arithmetic with the range bounds
            (the wrapper in this file passes python floats, numpy arrays and torch tensors).
        from_range: `(min, max)` of the source interval.
        to_range: `(min, max)` of the target interval.

    Returns:
        The rescaled value. Inputs outside `from_range` are extrapolated, not clipped.

    Raises:
        ZeroDivisionError: if `from_range` has zero width.
    """
    from_min, from_max = from_range
    to_min, to_max = to_range

    # a zero-width source range already raises ZeroDivisionError for plain python floats,
    # but arrays and tensors would silently become inf / nan - fail the same way for all inputs.
    # the check is limited to plain numbers so array-valued bounds keep their original behavior

    is_plain_number = isinstance(from_min, (int, float)) and isinstance(from_max, (int, float))

    if is_plain_number and from_min == from_max:
        raise ZeroDivisionError(f'cannot rescale from a zero-width range {from_range}')

    return (t - from_min) / (from_max - from_min) * (to_max - to_min) + to_min
24
+
25
+ # classes
26
+
27
class ActionTransformWrapper:
    """Environment wrapper that rescales and / or clips floating point actions before stepping.

    `transforms` is a list of dicts (a lone dict is accepted as shorthand). Each dict may
    carry `indices` - which entries of the last action dimension to touch - and
    `rescale_from_to`, a `(from_range, to_range)` pair handed to `rescale`. `clip` is an
    optional `(min, max)` pair applied after all transforms. Public attributes that are
    not found on the wrapper are looked up on the wrapped env.
    """

    def __init__(
        self,
        env,
        transforms = None,
        clip = None
    ):
        self.env = env

        # a single transform dict is shorthand for a one element list

        if isinstance(transforms, dict):
            transforms = [transforms]

        self.transforms = [] if transforms is None else transforms
        self.clip = clip

    def __getattr__(self, name):
        # only reached when regular lookup fails - private names are never forwarded

        if not name.startswith('_'):
            return getattr(self.env, name)

        raise AttributeError(f"attempted to get missing private attribute '{name}'")

    def reset(self, **kwargs):
        """Reset the wrapped env, passing all keyword arguments straight through."""
        return self.env.reset(**kwargs)

    def _transform_leaf(self, leaf):
        """Apply the configured rescaling and clipping to one leaf of the action tree."""

        # only floating point tensors, arrays and python floats are transformed

        if is_tensor(leaf):
            is_float = leaf.is_floating_point()
        elif isinstance(leaf, np.ndarray):
            is_float = np.issubdtype(leaf.dtype, np.floating)
        else:
            is_float = isinstance(leaf, float)

        if not is_float:
            return leaf

        # work on a copy so the in-place indexed writes below never touch the caller's action

        if is_tensor(leaf):
            value = leaf.clone()
        elif isinstance(leaf, np.ndarray):
            value = np.copy(leaf)
        else:
            value = leaf

        has_many_transforms = len(self.transforms) > 1

        for position, spec in enumerate(self.transforms):
            target = spec.get('indices')
            ranges = spec.get('rescale_from_to')

            # with several transforms and no explicit indices, the position in the list is the index

            if target is None and has_many_transforms:
                target = position

            if ranges is None:
                continue

            source_range, target_range = ranges

            if target is None:
                value = rescale(value, source_range, target_range)
                continue

            selected = value[..., target]
            value[..., target] = rescale(selected, source_range, target_range)

        if self.clip is None:
            return value

        lower, upper = self.clip

        if is_tensor(value):
            return torch.clamp(value, lower, upper)

        if isinstance(value, np.ndarray):
            return np.clip(value, lower, upper)

        return max(lower, min(value, upper))

    def step(self, action):
        """Transform every leaf of `action`, then step the wrapped env with the result."""
        transformed = tree_map(self._transform_leaf, action)
        return self.env.step(transformed)
@@ -0,0 +1,58 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ from torch import is_tensor
5
+ from torch.utils._pytree import tree_map
6
+ from einops import rearrange
7
+
8
+ # helper functions
9
+
10
def is_vectorized(env) -> bool:
    """Best-effort check for whether `env` is a vectorized (batched) environment.

    An env counts as vectorized when it exposes a positive `num_envs`, or when gymnasium
    is importable and the env (unwrapped, if it has an `unwrapped` attribute) is a
    `gymnasium.vector.VectorEnv`.

    Returns:
        True if the env looks vectorized, otherwise False - including when gymnasium
        is not installed.
    """
    num_envs = getattr(env, 'num_envs', None)

    # an env may declare `num_envs` but leave it as None - comparing None with 0 would raise

    if num_envs is not None and num_envs > 0:
        return True

    # gymnasium is optional, so it is imported lazily

    try:
        from gymnasium.vector import VectorEnv
    except ImportError:
        return False

    return isinstance(getattr(env, 'unwrapped', env), VectorEnv)
18
+
19
def maybe_expand_dim(x):
    """Prepend a length-1 axis to every array-like or scalar leaf of the tree `x`.

    Numpy arrays and torch tensors gain a leading axis of size 1, python / numpy scalars
    become a one element numpy array, and every other leaf is returned unchanged.
    """
    scalar_types = (int, float, bool, np.number, np.bool_)

    def add_leading_axis(leaf):
        if is_tensor(leaf) or isinstance(leaf, np.ndarray):
            return rearrange(leaf, '... -> 1 ...')

        return np.array([leaf]) if isinstance(leaf, scalar_types) else leaf

    return tree_map(add_leading_axis, x)
27
+
28
def maybe_squeeze_dim(x):
    """Drop the leading length-1 axis from every numpy array / torch tensor leaf of `x`.

    Leaves that are neither arrays nor tensors are returned unchanged.
    """
    def drop_leading_axis(leaf):
        is_array_like = is_tensor(leaf) or isinstance(leaf, np.ndarray)

        if not is_array_like:
            return leaf

        # the pattern requires the first axis to have size exactly 1
        return rearrange(leaf, '1 ... -> ...')

    return tree_map(drop_leading_axis, x)
34
+
35
+ # classes
36
+
37
class AutoBatchedWrapper:
    """Give single environments a batch-of-one interface and leave vectorized ones untouched.

    For a non-vectorized env, observations from `reset` and the first four outputs of
    `step` get a leading axis of size 1, and incoming actions have that axis removed.
    When `is_vector` is not given, it is detected with `is_vectorized`. Public attributes
    that are not found on the wrapper are looked up on the wrapped env.
    """

    def __init__(self, env, is_vector: bool | None = None):
        self.env = env

        if is_vector is None:
            is_vector = is_vectorized(env)

        self.is_vector = is_vector

    def __getattr__(self, name):
        # only reached when regular lookup fails - private names are never forwarded

        if not name.startswith('_'):
            return getattr(self.env, name)

        raise AttributeError(f"attempted to get missing private attribute '{name}'")

    def reset(self, **kwargs):
        """Reset the wrapped env; the observation is batched unless the env is vectorized."""
        obs, info = self.env.reset(**kwargs)

        if self.is_vector:
            return obs, info

        return maybe_expand_dim(obs), info

    def step(self, action):
        """Step the wrapped env, unbatching the action and batching the outputs when needed."""
        if self.is_vector:
            return self.env.step(action)

        out = self.env.step(maybe_squeeze_dim(action))

        # the fifth output (info) is passed through as is
        batched = maybe_expand_dim(out[:4])
        return (*batched, out[4])
@@ -3,7 +3,6 @@ from __future__ import annotations
3
3
  import torch
4
4
  import numpy as np
5
5
 
6
- import gymnasium as gym
7
6
  from PIL import Image
8
7
  from einops import rearrange
9
8
 
@@ -14,7 +13,7 @@ def cast_tuple(t, length = 1):
14
13
 
15
14
  # class
16
15
 
17
- class ImageObservationWrapper(gym.ObservationWrapper):
16
+ class ImageObservationWrapper:
18
17
  def __init__(
19
18
  self,
20
19
  env,
@@ -24,13 +23,18 @@ class ImageObservationWrapper(gym.ObservationWrapper):
24
23
  normalize = True,
25
24
  normalize_divisor = 255.0
26
25
  ):
27
- super().__init__(env)
26
+ self.env = env
28
27
  self.image_size = cast_tuple(image_size, 2)
29
28
  self.image_key = image_key
30
29
  self.resample_method = resample_method
31
30
  self.normalize = normalize
32
31
  self.normalize_divisor = normalize_divisor
33
32
 
33
+ def __getattr__(self, name):
34
+ if name.startswith('_'):
35
+ raise AttributeError(f"attempted to get missing private attribute '{name}'")
36
+ return getattr(self.env, name)
37
+
34
38
  def render_frame(self):
35
39
  img = self.env.render()
36
40
  img = Image.fromarray(img).resize(self.image_size, resample = self.resample_method)
@@ -56,3 +60,11 @@ class ImageObservationWrapper(gym.ObservationWrapper):
56
60
  obs.update({self.image_key: img_tensor})
57
61
 
58
62
  return obs
63
+
64
+ def reset(self, **kwargs):
65
+ obs, info = self.env.reset(**kwargs)
66
+ return self.observation(obs), info
67
+
68
+ def step(self, action):
69
+ obs, reward, terminated, truncated, info = self.env.step(action)
70
+ return self.observation(obs), reward, terminated, truncated, info
@@ -0,0 +1,74 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ from torch import tensor, is_tensor, from_numpy, float64, device as torch_device
5
+ from torch.utils._pytree import tree_map
6
+
7
+ # helper functions
8
+
9
def numpy_to_torch(x, device, cast_float64_to_float32 = False):
    """Convert every numpy array and scalar leaf of the tree `x` into a torch tensor on `device`.

    Args:
        x: arbitrarily nested structure understood by `tree_map`.
        device: target device, passed to `Tensor.to`.
        cast_float64_to_float32: when True, float64 tensors are downcast to float32.

    Returns:
        The same structure with array / scalar leaves replaced by tensors. Leaves that are
        neither arrays, scalars nor tensors are returned unchanged.
    """
    def _to_torch(t):
        if isinstance(t, np.ndarray):
            # from_numpy rejects negatively strided views (e.g. a flipped `arr[::-1]`),
            # so materialize those first - all other arrays are converted without a copy

            if any(stride < 0 for stride in t.strides):
                t = t.copy()

            t = from_numpy(t)
        elif isinstance(t, (int, float, bool, np.number, np.bool_)):
            t = tensor(t)

        if not is_tensor(t):
            return t

        if cast_float64_to_float32 and t.dtype == float64:
            t = t.float()

        return t.to(device)

    return tree_map(_to_torch, x)
24
+
25
def torch_to_numpy(x, cast_float64_to_float32 = False):
    """Convert every tensor and scalar leaf of the tree `x` into a numpy array.

    Tensors are detached and moved to cpu first, python / numpy scalars become zero
    dimensional arrays, and existing numpy arrays are kept. With `cast_float64_to_float32`
    set, float64 arrays are downcast to float32. All other leaves are returned unchanged.
    """
    scalar_types = (int, float, bool, np.number, np.bool_)

    def _convert(leaf):
        if is_tensor(leaf):
            arr = leaf.detach().cpu().numpy()
        elif isinstance(leaf, scalar_types):
            arr = np.array(leaf)
        elif isinstance(leaf, np.ndarray):
            arr = leaf
        else:
            return leaf

        needs_downcast = cast_float64_to_float32 and arr.dtype == np.float64
        return arr.astype(np.float32) if needs_downcast else arr

    return tree_map(_convert, x)
40
+
41
+ # classes
42
+
43
class TensorWrapper:
    """Bridge a numpy based environment and torch based callers.

    With `convert_in`, actions passed to `step` go through `torch_to_numpy` before
    reaching the env. With `convert_out`, the observation from `reset` and the first four
    outputs of `step` go through `numpy_to_torch` and land on `device`. Public attributes
    that are not found on the wrapper are looked up on the wrapped env.
    """

    def __init__(
        self,
        env,
        device: str | torch_device = 'cpu',
        convert_in: bool = True,
        convert_out: bool = True,
        cast_float64_to_float32: bool = False
    ):
        self.env = env
        self.device = torch_device(device)

        self.convert_in = convert_in
        self.convert_out = convert_out
        self.cast_float64_to_float32 = cast_float64_to_float32

    def __getattr__(self, name):
        # only reached when regular lookup fails - private names are never forwarded

        if not name.startswith('_'):
            return getattr(self.env, name)

        raise AttributeError(f"attempted to get missing private attribute '{name}'")

    def reset(self, **kwargs):
        """Reset the wrapped env, converting the observation when `convert_out` is set."""
        obs, info = self.env.reset(**kwargs)

        if not self.convert_out:
            return obs, info

        return numpy_to_torch(obs, self.device, self.cast_float64_to_float32), info

    def step(self, action):
        """Step the wrapped env, converting the action and / or the outputs as configured."""
        if self.convert_in:
            action = torch_to_numpy(action, self.cast_float64_to_float32)

        out = self.env.step(action)

        if not self.convert_out:
            return out

        # the fifth output (info) is passed through as is
        converted = numpy_to_torch(out[:4], self.device, self.cast_float64_to_float32)
        return (*converted, out[4])
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+ from functools import partial
3
+
4
+ from .image_wrapper import ImageObservationWrapper
5
+ from .auto_batched_wrapper import AutoBatchedWrapper
6
+ from .tensor_wrapper import TensorWrapper
7
+ from .action_transform_wrapper import ActionTransformWrapper
8
+
9
# registry of wrapper classes by short name, looked up by compose_env for string specs

WRAPPERS = {
    'image': ImageObservationWrapper,
    'auto_batch': AutoBatchedWrapper,
    'tensor': TensorWrapper,
    'action_transform': ActionTransformWrapper,
}
15
+
16
def is_unique(arr):
    """Return True when the sized collection `arr` of hashable items has no duplicates."""
    distinct = set(arr)
    return len(distinct) == len(arr)
18
+
19
def compose_env(env, *wrappers):
    """Wrap `env` with each of `wrappers`, applied in the order given.

    Each wrapper may be a callable taking the env, a string key into `WRAPPERS`, or a
    `(name_or_callable, kwargs)` tuple whose kwargs are bound with `functools.partial`.
    The same wrapper class may only appear once.

    Returns:
        The env after every wrapper has been applied.
    """
    resolved = []

    for spec in wrappers:
        if isinstance(spec, str):
            spec = WRAPPERS[spec]

        if isinstance(spec, tuple):
            key, options = spec
            spec = partial(WRAPPERS.get(key, key), **options)

        resolved.append(spec)

    # compare the underlying classes so a kwargs-bound partial still counts as its wrapper

    wrapper_classes = [
        wrapper.func if isinstance(wrapper, partial) else wrapper
        for wrapper in resolved
    ]

    assert is_unique(wrapper_classes), 'duplicate wrappers found'

    for apply_wrapper in resolved:
        env = apply_wrapper(env)

    return env

# alias

wrap_env = compose_env
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "env-ssl-wrapper"
3
- version = "0.0.2"
3
+ version = "0.0.4"
4
4
  description = "An RL environment wrapper for learning SSL in the background"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1 +0,0 @@
1
- from .image_wrapper import ImageObservationWrapper
File without changes