env-ssl-wrapper 0.0.3__tar.gz → 0.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {env_ssl_wrapper-0.0.3 → env_ssl_wrapper-0.0.4}/PKG-INFO +1 -1
- {env_ssl_wrapper-0.0.3 → env_ssl_wrapper-0.0.4}/env_ssl_wrapper/__init__.py +1 -0
- env_ssl_wrapper-0.0.4/env_ssl_wrapper/action_transform_wrapper.py +93 -0
- {env_ssl_wrapper-0.0.3 → env_ssl_wrapper-0.0.4}/env_ssl_wrapper/auto_batched_wrapper.py +2 -2
- {env_ssl_wrapper-0.0.3 → env_ssl_wrapper-0.0.4}/env_ssl_wrapper/tensor_wrapper.py +1 -1
- env_ssl_wrapper-0.0.4/env_ssl_wrapper/utils.py +45 -0
- {env_ssl_wrapper-0.0.3 → env_ssl_wrapper-0.0.4}/pyproject.toml +1 -1
- env_ssl_wrapper-0.0.3/env_ssl_wrapper/utils.py +0 -30
- {env_ssl_wrapper-0.0.3 → env_ssl_wrapper-0.0.4}/.gitignore +0 -0
- {env_ssl_wrapper-0.0.3 → env_ssl_wrapper-0.0.4}/LICENSE +0 -0
- {env_ssl_wrapper-0.0.3 → env_ssl_wrapper-0.0.4}/README.md +0 -0
- {env_ssl_wrapper-0.0.3 → env_ssl_wrapper-0.0.4}/env_ssl_wrapper/image_wrapper.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: env-ssl-wrapper
|
|
3
|
-
Version: 0.0.3
|
|
3
|
+
Version: 0.0.4
|
|
4
4
|
Summary: An RL environment wrapper for learning SSL in the background
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/env-ssl-wrapper/
|
|
6
6
|
Project-URL: Repository, https://codeberg.org/lucidrains/env-ssl-wrapper
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from torch import is_tensor
|
|
6
|
+
from torch.utils._pytree import tree_map
|
|
7
|
+
|
|
8
|
+
# helpers
|
|
9
|
+
|
|
10
|
+
def exists(v):
    """Return True unless `v` is None (falsy values such as 0 or '' still count)."""
    return not (v is None)
|
|
12
|
+
|
|
13
|
+
def default(v, d):
    """Return `v`, or the fallback `d` only when `v` is None (falsy values are kept)."""
    if v is None:
        return d
    return v
|
|
15
|
+
|
|
16
|
+
def rescale(
    t,
    from_range: tuple[float, float],
    to_range: tuple[float, float]
):
    """Linearly map `t` from `from_range` = (min, max) onto `to_range` = (min, max).

    Uses only arithmetic operators, so it applies element-wise to scalars,
    NumPy arrays and torch tensors alike. Values outside `from_range` are
    extrapolated, not clipped.
    """
    (lo_in, hi_in), (lo_out, hi_out) = from_range, to_range

    # position of t within the source interval (0 at lo_in, 1 at hi_in)
    normalized = (t - lo_in) / (hi_in - lo_in)
    return normalized * (hi_out - lo_out) + lo_out
|
|
24
|
+
|
|
25
|
+
# classes
|
|
26
|
+
|
|
27
|
+
class ActionTransformWrapper:
    """Wrap an environment so float actions are rescaled and/or clipped before `step`.

    Args:
        env: the environment to wrap. Public attributes not found on the wrapper
            are looked up on `env`.
        transforms: None, a single dict, or a list of dicts. Each dict may hold
            'rescale_from_to': ((from_min, from_max), (to_min, to_max)), a linear
                remap applied with `rescale`
            'indices': index (or indices) into the last action dimension to remap.
            If a transform has no 'indices' and more than one transform is given,
            transform i targets last-dimension index i; a lone transform without
            'indices' remaps the whole action.
        clip: optional (min, max) applied after all rescaling. For Python / NumPy
            float scalars one bound may be None to leave that side open, matching
            what torch.clamp / np.clip already allow.

    Only floating-point leaves of the action pytree are touched (float torch
    tensors, float NumPy arrays, Python floats and NumPy float scalars); any other
    leaf is passed through unchanged. Arrays and tensors are copied before being
    modified, so the caller's action is never mutated in place.
    """

    def __init__(
        self,
        env,
        transforms = None,
        clip = None
    ):
        self.env = env

        # accept a single transform dict as shorthand for a one-element list
        if isinstance(transforms, dict):
            transforms = [transforms]

        self.transforms = default(transforms, [])
        self.clip = clip

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails
        if name.startswith('_'):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")

        # without this guard, an instance whose `env` was never set (e.g. built via
        # __new__, or mid-unpickling) would recurse forever through `self.env`
        if name == 'env':
            raise AttributeError("'env' has not been set on this wrapper")

        return getattr(self.env, name)

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def _clip(self, t):
        """Clamp `t` to `self.clip`, dispatching on tensor / array / scalar."""
        min_clip, max_clip = self.clip

        if is_tensor(t):
            return torch.clamp(t, min_clip, max_clip)

        if isinstance(t, np.ndarray):
            return np.clip(t, min_clip, max_clip)

        # scalar: same as max(min_clip, min(t, max_clip)) when both bounds are set,
        # but a None bound is skipped instead of raising a TypeError
        if exists(max_clip):
            t = min(t, max_clip)

        if exists(min_clip):
            t = max(min_clip, t)

        return t

    def _transform_leaf(self, t):
        """Rescale then clip a single pytree leaf; non-float leaves are returned as-is."""
        is_torch_float = is_tensor(t) and t.is_floating_point()
        is_np_float = isinstance(t, np.ndarray) and np.issubdtype(t.dtype, np.floating)

        # np.float32 / np.float16 scalars are not subclasses of Python float,
        # so np.floating must be checked explicitly for them to be transformed
        is_scalar_float = isinstance(t, (float, np.floating))

        if not (is_torch_float or is_np_float or is_scalar_float):
            return t

        # copy so the in-place index assignment below never mutates the caller's action
        if is_tensor(t):
            t = t.clone()
        elif isinstance(t, np.ndarray):
            t = np.copy(t)

        multiple_transforms = len(self.transforms) > 1

        for ind, transform in enumerate(self.transforms):
            indices = transform.get('indices')
            rescale_from_to = transform.get('rescale_from_to')

            # with several transforms, an unindexed one targets its own position
            if not exists(indices) and multiple_transforms:
                indices = ind

            if not exists(rescale_from_to):
                continue

            from_range, to_range = rescale_from_to

            if exists(indices):
                t[..., indices] = rescale(t[..., indices], from_range, to_range)
            else:
                t = rescale(t, from_range, to_range)

        if exists(self.clip):
            t = self._clip(t)

        return t

    def step(self, action):
        """Transform every float leaf of `action`, then forward it to `env.step`."""
        action = tree_map(self._transform_leaf, action)
        return self.env.step(action)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
from .image_wrapper import ImageObservationWrapper
|
|
5
|
+
from .auto_batched_wrapper import AutoBatchedWrapper
|
|
6
|
+
from .tensor_wrapper import TensorWrapper
|
|
7
|
+
from .action_transform_wrapper import ActionTransformWrapper
|
|
8
|
+
|
|
9
|
+
# registry of short string names (accepted by `compose_env`) -> wrapper classes
WRAPPERS = {
    'image': ImageObservationWrapper,
    'auto_batch': AutoBatchedWrapper,
    'tensor': TensorWrapper,
    'action_transform': ActionTransformWrapper,
}
|
|
15
|
+
|
|
16
|
+
def is_unique(arr):
    """Return True when no item of the sized collection `arr` occurs twice (items must be hashable)."""
    distinct = frozenset(arr)
    return len(distinct) == len(arr)
|
|
18
|
+
|
|
19
|
+
def compose_env(env, *wrappers):
    """Wrap `env` with each entry of `wrappers`, applied left to right.

    The first wrapper listed is therefore innermost and the last is outermost.

    Each entry may be:
        - a string key of `WRAPPERS` (unknown keys raise KeyError)
        - a `(name_or_callable, kwargs)` tuple; the name is looked up in
          `WRAPPERS`, falling back to using it as-is, and `kwargs` are bound
          with functools.partial
        - any callable taking an env and returning the wrapped env

    Raises:
        AssertionError: if the same wrapper class appears more than once.
            Validation happens before any wrapping is applied.
    """
    funcs = []
    classes = []

    for wrapper in wrappers:
        if isinstance(wrapper, str):
            wrapper = WRAPPERS[wrapper]

        if isinstance(wrapper, tuple):
            name, kwargs = wrapper
            wrapper = partial(WRAPPERS.get(name, name), **kwargs)

        # compare partials by their underlying wrapper class for the duplicate check
        cls = wrapper.func if isinstance(wrapper, partial) else wrapper

        funcs.append(wrapper)
        classes.append(cls)

    # explicit check instead of `assert`, so validation still runs under `python -O`;
    # AssertionError is kept so existing callers that catch it are unaffected
    if not is_unique(classes):
        raise AssertionError('duplicate wrappers found')

    for func in funcs:
        env = func(env)

    return env

# alias

wrap_env = compose_env
|
|
@@ -1,30 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
from functools import partial
|
|
3
|
-
|
|
4
|
-
from .image_wrapper import ImageObservationWrapper
|
|
5
|
-
from .auto_batched_wrapper import AutoBatchedWrapper
|
|
6
|
-
from .tensor_wrapper import TensorWrapper
|
|
7
|
-
|
|
8
|
-
# registry of short string names (accepted by `compose_env`) -> wrapper classes
WRAPPERS = {
    'image': ImageObservationWrapper,
    'auto_batch': AutoBatchedWrapper,
    'tensor': TensorWrapper,
}
|
|
13
|
-
|
|
14
|
-
def compose_env(env, *wrappers):
    """Wrap `env` with each entry of `wrappers`, innermost first.

    An entry may be a string key of `WRAPPERS`, a `(name_or_callable, kwargs)`
    tuple whose kwargs are bound with functools.partial, or a callable that
    takes the env and returns the wrapped env.
    """
    for spec in wrappers:
        if isinstance(spec, str):
            spec = WRAPPERS[spec]

        if isinstance(spec, tuple):
            target, options = spec
            spec = partial(WRAPPERS.get(target, target), **options)

        env = spec(env)

    return env

# alias

wrap_env = compose_env
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|