env-ssl-wrapper 0.0.2__tar.gz → 0.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {env_ssl_wrapper-0.0.2 → env_ssl_wrapper-0.0.3}/PKG-INFO +1 -1
- env_ssl_wrapper-0.0.3/env_ssl_wrapper/__init__.py +5 -0
- env_ssl_wrapper-0.0.3/env_ssl_wrapper/auto_batched_wrapper.py +58 -0
- {env_ssl_wrapper-0.0.2 → env_ssl_wrapper-0.0.3}/env_ssl_wrapper/image_wrapper.py +15 -3
- env_ssl_wrapper-0.0.3/env_ssl_wrapper/tensor_wrapper.py +74 -0
- env_ssl_wrapper-0.0.3/env_ssl_wrapper/utils.py +30 -0
- {env_ssl_wrapper-0.0.2 → env_ssl_wrapper-0.0.3}/pyproject.toml +1 -1
- env_ssl_wrapper-0.0.2/env_ssl_wrapper/__init__.py +0 -1
- env_ssl_wrapper-0.0.2/env_ssl_wrapper/env_ssl_wrapper.py +0 -1
- {env_ssl_wrapper-0.0.2 → env_ssl_wrapper-0.0.3}/.gitignore +0 -0
- {env_ssl_wrapper-0.0.2 → env_ssl_wrapper-0.0.3}/LICENSE +0 -0
- {env_ssl_wrapper-0.0.2 → env_ssl_wrapper-0.0.3}/README.md +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: env-ssl-wrapper
|
|
3
|
-
Version: 0.0.2
|
|
3
|
+
Version: 0.0.3
|
|
4
4
|
Summary: An RL environment wrapper for learning SSL in the background
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/env-ssl-wrapper/
|
|
6
6
|
Project-URL: Repository, https://codeberg.org/lucidrains/env-ssl-wrapper
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from torch import is_tensor
|
|
5
|
+
from torch.utils._pytree import tree_map
|
|
6
|
+
from einops import rearrange
|
|
7
|
+
|
|
8
|
+
# helper functions
|
|
9
|
+
|
|
10
|
+
def is_vectorized(env) -> bool:
    """Return True when `env` looks like a vectorized (batched) environment.

    Two signals are checked in order:

    1. a `num_envs` attribute on `env` that compares greater than 0
    2. `env.unwrapped` (or `env` itself when it has no `unwrapped`) being an
       instance of `gymnasium.vector.VectorEnv`

    When gymnasium cannot be imported the second check is skipped and the
    result is False.
    """
    num_envs = getattr(env, 'num_envs', None)

    try:
        if num_envs is not None and num_envs > 0:
            return True
    except TypeError:
        # `num_envs` exists but cannot be compared with an int (e.g. a
        # placeholder object) - fall through to the type based check
        # instead of crashing
        pass

    try:
        from gymnasium.vector import VectorEnv
    except ImportError:
        return False

    return isinstance(getattr(env, 'unwrapped', env), VectorEnv)
|
|
18
|
+
|
|
19
|
+
def maybe_expand_dim(x):
    """Add a leading axis of size 1 to every array-like leaf of the pytree `x`.

    numpy arrays and torch tensors gain a new first dimension, python / numpy
    scalars are wrapped into a one element numpy array, and every other leaf
    is returned untouched.
    """
    scalar_types = (int, float, bool, np.number, np.bool_)

    def add_leading_axis(leaf):
        is_array_like = isinstance(leaf, np.ndarray) or is_tensor(leaf)

        if is_array_like:
            return rearrange(leaf, '... -> 1 ...')

        return np.array([leaf]) if isinstance(leaf, scalar_types) else leaf

    return tree_map(add_leading_axis, x)
|
|
27
|
+
|
|
28
|
+
def maybe_squeeze_dim(x):
    """Remove the leading axis from every array-like leaf of the pytree `x`.

    Only numpy arrays and torch tensors are touched; every other leaf is
    returned as is. The einops pattern `'1 ... -> ...'` spells the first axis
    as a literal 1, so array leaves are expected to have a leading axis of
    length 1.
    """
    def drop_leading_axis(leaf):
        is_array_like = isinstance(leaf, np.ndarray) or is_tensor(leaf)

        if not is_array_like:
            return leaf

        return rearrange(leaf, '1 ... -> ...')

    return tree_map(drop_leading_axis, x)
|
|
34
|
+
|
|
35
|
+
# classes
|
|
36
|
+
|
|
37
|
+
class AutoBatchedWrapper:
    """Give a single (non-vectorized) environment a batched interface.

    For a non-vector env, the observation from `reset` and the first four
    entries returned by `step` are passed through `maybe_expand_dim`, and
    incoming actions are passed through `maybe_squeeze_dim` before reaching
    the env. Vectorized envs are passed through unchanged. Any other public
    attribute is forwarded to the wrapped env.
    """

    def __init__(self, env, is_vector: bool | None = None):
        """
        env       - environment to wrap, `reset(**kwargs)` must return `(obs, info)`
                    and `step(action)` a sequence of at least 5 items
        is_vector - force vectorized (True) or single env (False) handling,
                    detected with `is_vectorized(env)` when None
        """
        self.env = env
        self.is_vector = is_vector if is_vector is not None else is_vectorized(env)

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails
        if name.startswith('_'):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")

        # read `env` straight from the instance dict - going through
        # `self.env` would re-enter __getattr__ without end when `env` has not
        # been assigned yet (instance built without running __init__)
        try:
            env = self.__dict__['env']
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}' (wrapped env is not set)"
            ) from None

        return getattr(env, name)

    def reset(self, **kwargs):
        """Reset the wrapped env, returns `(obs, info)` with `obs` batched for a non-vector env."""
        obs, info = self.env.reset(**kwargs)

        if self.is_vector:
            return obs, info

        return maybe_expand_dim(obs), info

    def step(self, action):
        """Step the wrapped env.

        For a non-vector env the action is squeezed before the call and the
        first four returned items are expanded afterwards, the fifth item is
        returned as is. A vectorized env's output is returned untouched.
        """
        action = maybe_squeeze_dim(action) if not self.is_vector else action
        out = self.env.step(action)

        if self.is_vector:
            return out

        return *maybe_expand_dim(out[:4]), out[4]
|
|
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|
|
3
3
|
import torch
|
|
4
4
|
import numpy as np
|
|
5
5
|
|
|
6
|
-
import gymnasium as gym
|
|
7
6
|
from PIL import Image
|
|
8
7
|
from einops import rearrange
|
|
9
8
|
|
|
@@ -14,7 +13,7 @@ def cast_tuple(t, length = 1):
|
|
|
14
13
|
|
|
15
14
|
# class
|
|
16
15
|
|
|
17
|
-
class ImageObservationWrapper
|
|
16
|
+
class ImageObservationWrapper:
|
|
18
17
|
def __init__(
|
|
19
18
|
self,
|
|
20
19
|
env,
|
|
@@ -24,13 +23,18 @@ class ImageObservationWrapper(gym.ObservationWrapper):
|
|
|
24
23
|
normalize = True,
|
|
25
24
|
normalize_divisor = 255.0
|
|
26
25
|
):
|
|
27
|
-
|
|
26
|
+
self.env = env
|
|
28
27
|
self.image_size = cast_tuple(image_size, 2)
|
|
29
28
|
self.image_key = image_key
|
|
30
29
|
self.resample_method = resample_method
|
|
31
30
|
self.normalize = normalize
|
|
32
31
|
self.normalize_divisor = normalize_divisor
|
|
33
32
|
|
|
33
|
+
def __getattr__(self, name):
|
|
34
|
+
if name.startswith('_'):
|
|
35
|
+
raise AttributeError(f"attempted to get missing private attribute '{name}'")
|
|
36
|
+
return getattr(self.env, name)
|
|
37
|
+
|
|
34
38
|
def render_frame(self):
|
|
35
39
|
img = self.env.render()
|
|
36
40
|
img = Image.fromarray(img).resize(self.image_size, resample = self.resample_method)
|
|
@@ -56,3 +60,11 @@ class ImageObservationWrapper(gym.ObservationWrapper):
|
|
|
56
60
|
obs.update({self.image_key: img_tensor})
|
|
57
61
|
|
|
58
62
|
return obs
|
|
63
|
+
|
|
64
|
+
def reset(self, **kwargs):
|
|
65
|
+
obs, info = self.env.reset(**kwargs)
|
|
66
|
+
return self.observation(obs), info
|
|
67
|
+
|
|
68
|
+
def step(self, action):
|
|
69
|
+
obs, reward, terminated, truncated, info = self.env.step(action)
|
|
70
|
+
return self.observation(obs), reward, terminated, truncated, info
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from torch import tensor, is_tensor, from_numpy, float64, device as torch_device
|
|
5
|
+
from torch.utils._pytree import tree_map
|
|
6
|
+
|
|
7
|
+
# helper functions
|
|
8
|
+
|
|
9
|
+
def numpy_to_torch(x, device, cast_float64_to_float32 = False):
    """Convert the numeric leaves of the pytree `x` to torch tensors on `device`.

    numpy arrays go through `from_numpy`, python / numpy scalars through
    `tensor`, existing tensors are kept. Leaves of any other type are returned
    unchanged. With `cast_float64_to_float32` set, float64 tensors are cast to
    float32 before being moved to `device`.
    """
    scalar_types = (int, float, bool, np.number, np.bool_)

    def convert_leaf(leaf):
        if isinstance(leaf, np.ndarray):
            converted = from_numpy(leaf)
        elif isinstance(leaf, scalar_types):
            converted = tensor(leaf)
        else:
            converted = leaf

        # anything that did not end up as a tensor is passed through as is
        if not is_tensor(converted):
            return converted

        needs_downcast = cast_float64_to_float32 and converted.dtype == float64

        if needs_downcast:
            converted = converted.float()

        return converted.to(device)

    return tree_map(convert_leaf, x)
|
|
24
|
+
|
|
25
|
+
def torch_to_numpy(x, cast_float64_to_float32 = False):
    """Convert the numeric leaves of the pytree `x` to numpy arrays.

    torch tensors are detached, moved to cpu and converted, python / numpy
    scalars are wrapped with `np.array`, existing numpy arrays are kept.
    Leaves of any other type are returned unchanged. With
    `cast_float64_to_float32` set, float64 arrays are cast to float32.
    """
    scalar_types = (int, float, bool, np.number, np.bool_)

    def convert_leaf(leaf):
        if is_tensor(leaf):
            array = leaf.detach().cpu().numpy()
        elif isinstance(leaf, scalar_types):
            array = np.array(leaf)
        else:
            array = leaf

        # anything that did not end up as an array is passed through as is
        if not isinstance(array, np.ndarray):
            return array

        needs_downcast = cast_float64_to_float32 and array.dtype == np.float64

        return array.astype(np.float32) if needs_downcast else array

    return tree_map(convert_leaf, x)
|
|
40
|
+
|
|
41
|
+
# classes
|
|
42
|
+
|
|
43
|
+
class TensorWrapper:
    """Bridge between a numpy based environment and torch tensors.

    When `convert_in` is set, actions given to `step` are converted with
    `torch_to_numpy` before reaching the env. When `convert_out` is set, the
    observation from `reset` and the first four items returned by `step` are
    converted with `numpy_to_torch` onto `device`. Any other public attribute
    is forwarded to the wrapped env.
    """

    def __init__(
        self,
        env,
        device: str | torch_device = 'cpu',
        convert_in: bool = True,
        convert_out: bool = True,
        cast_float64_to_float32: bool = False
    ):
        """
        env                     - environment to wrap, `reset(**kwargs)` must return `(obs, info)`
                                  and `step(action)` a sequence of at least 5 items
        device                  - device the outputs are moved to
        convert_in              - convert actions to numpy before stepping
        convert_out             - convert env outputs to tensors
        cast_float64_to_float32 - cast float64 data to float32 in both directions
        """
        self.env = env
        self.device = torch_device(device)
        self.convert_in = convert_in
        self.convert_out = convert_out
        self.cast_float64_to_float32 = cast_float64_to_float32

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails
        if name.startswith('_'):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")

        # read `env` straight from the instance dict - going through
        # `self.env` would re-enter __getattr__ without end when `env` has not
        # been assigned yet (instance built without running __init__)
        try:
            env = self.__dict__['env']
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}' (wrapped env is not set)"
            ) from None

        return getattr(env, name)

    def reset(self, **kwargs):
        """Reset the wrapped env, returns `(obs, info)` with `obs` converted when `convert_out` is set."""
        obs, info = self.env.reset(**kwargs)

        if not self.convert_out:
            return obs, info

        return numpy_to_torch(obs, self.device, self.cast_float64_to_float32), info

    def step(self, action):
        """Step the wrapped env.

        The action is converted to numpy first when `convert_in` is set. With
        `convert_out` set, the first four returned items are converted to
        tensors and the fifth item is returned as is; otherwise the env's
        output is returned untouched.
        """
        action = torch_to_numpy(action, self.cast_float64_to_float32) if self.convert_in else action
        out = self.env.step(action)

        if not self.convert_out:
            return out

        return *numpy_to_torch(out[:4], self.device, self.cast_float64_to_float32), out[4]
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
from .image_wrapper import ImageObservationWrapper
|
|
5
|
+
from .auto_batched_wrapper import AutoBatchedWrapper
|
|
6
|
+
from .tensor_wrapper import TensorWrapper
|
|
7
|
+
|
|
8
|
+
# registry of the short string names accepted in place of a wrapper class
WRAPPERS = {
    'image': ImageObservationWrapper,
    'auto_batch': AutoBatchedWrapper,
    'tensor': TensorWrapper,
}
|
|
13
|
+
|
|
14
|
+
def compose_env(env, *wrappers):
    """Apply `wrappers` to `env` in order and return the outermost result.

    Each entry of `wrappers` is one of

    - a string key of `WRAPPERS`
    - a callable taking the env and returning the wrapped env
    - a `(name_or_callable, kwargs)` tuple, where `kwargs` is a dict of
      keyword arguments bound to the wrapper before it is called with the env

    Raises KeyError when a string does not name a registered wrapper.
    """
    def resolve(name_or_fn):
        # only strings go through the registry, so callables are never hashed
        # and an unknown name fails with a message instead of an opaque error
        if not isinstance(name_or_fn, str):
            return name_or_fn

        if name_or_fn not in WRAPPERS:
            raise KeyError(f"unknown wrapper '{name_or_fn}', available: {sorted(WRAPPERS)}")

        return WRAPPERS[name_or_fn]

    for wrapper in wrappers:
        if isinstance(wrapper, tuple):
            name_or_fn, kwargs = wrapper
            wrapper = partial(resolve(name_or_fn), **kwargs)
        else:
            wrapper = resolve(wrapper)

        env = wrapper(env)

    return env
|
|
27
|
+
|
|
28
|
+
# alias
|
|
29
|
+
|
|
30
|
+
wrap_env = compose_env
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
from .image_wrapper import ImageObservationWrapper
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
|
|
File without changes
|
|
File without changes
|
|
File without changes
|