env-ssl-wrapper 0.0.4__tar.gz → 0.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {env_ssl_wrapper-0.0.4 → env_ssl_wrapper-0.0.5}/PKG-INFO +1 -1
- env_ssl_wrapper-0.0.5/env_ssl_wrapper/action_transform_wrapper.py +99 -0
- {env_ssl_wrapper-0.0.4 → env_ssl_wrapper-0.0.5}/pyproject.toml +1 -1
- env_ssl_wrapper-0.0.4/env_ssl_wrapper/action_transform_wrapper.py +0 -93
- {env_ssl_wrapper-0.0.4 → env_ssl_wrapper-0.0.5}/.gitignore +0 -0
- {env_ssl_wrapper-0.0.4 → env_ssl_wrapper-0.0.5}/LICENSE +0 -0
- {env_ssl_wrapper-0.0.4 → env_ssl_wrapper-0.0.5}/README.md +0 -0
- {env_ssl_wrapper-0.0.4 → env_ssl_wrapper-0.0.5}/env_ssl_wrapper/__init__.py +0 -0
- {env_ssl_wrapper-0.0.4 → env_ssl_wrapper-0.0.5}/env_ssl_wrapper/auto_batched_wrapper.py +0 -0
- {env_ssl_wrapper-0.0.4 → env_ssl_wrapper-0.0.5}/env_ssl_wrapper/image_wrapper.py +0 -0
- {env_ssl_wrapper-0.0.4 → env_ssl_wrapper-0.0.5}/env_ssl_wrapper/tensor_wrapper.py +0 -0
- {env_ssl_wrapper-0.0.4 → env_ssl_wrapper-0.0.5}/env_ssl_wrapper/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: env-ssl-wrapper
|
|
3
|
-
Version: 0.0.4
|
|
3
|
+
Version: 0.0.5
|
|
4
4
|
Summary: An RL environment wrapper for learning SSL in the background
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/env-ssl-wrapper/
|
|
6
6
|
Project-URL: Repository, https://codeberg.org/lucidrains/env-ssl-wrapper
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from torch import is_tensor
|
|
6
|
+
from torch.utils._pytree import tree_map
|
|
7
|
+
from functools import partial
|
|
8
|
+
|
|
9
|
+
# helpers
|
|
10
|
+
|
|
11
|
+
def exists(v):
    """Return True when `v` is anything other than None."""
    if v is None:
        return False
    return True
|
|
13
|
+
|
|
14
|
+
def default(v, d):
    """Return `v` unless it is None, in which case return the fallback `d`."""
    if exists(v):
        return v
    return d
|
|
16
|
+
|
|
17
|
+
def is_float_dtype(t):
    """Return whether `t` is a floating point tensor, a floating point numpy array, or a python float."""
    if is_tensor(t):
        result = t.is_floating_point()
    elif isinstance(t, np.ndarray):
        result = np.issubdtype(t.dtype, np.floating)
    else:
        # anything that is neither a tensor nor an ndarray only counts if it is a python float
        result = isinstance(t, float)

    return result
|
|
23
|
+
|
|
24
|
+
def copy(t):
    """Return an independent copy of a tensor or numpy array; any other value is returned as is."""
    if isinstance(t, np.ndarray):
        return np.copy(t)

    if is_tensor(t):
        return t.clone()

    # other values (e.g. python scalars) are handed back untouched
    return t
|
|
30
|
+
|
|
31
|
+
def clamp(t, min_val, max_val):
    """Limit `t` to the range [min_val, max_val] for tensors, numpy arrays and plain scalars."""
    if isinstance(t, np.ndarray):
        return np.clip(t, min_val, max_val)

    if is_tensor(t):
        return torch.clamp(t, min_val, max_val)

    # scalar fallback: bound from above first, then from below
    upper_bounded = min(t, max_val)
    return max(min_val, upper_bounded)
|
|
37
|
+
|
|
38
|
+
def rescale(
    t,
    from_range: tuple[float, float],
    to_range: tuple[float, float]
):
    """Linearly map `t` from the interval `from_range` onto the interval `to_range`.

    Each range is a `(min, max)` pair. `from_range[0]` maps to `to_range[0]`
    and `from_range[1]` maps to `to_range[1]`.
    """
    src_lo, src_hi = from_range
    dst_lo, dst_hi = to_range

    # position of t within the source interval, then stretched onto the target interval
    normalized = (t - src_lo) / (src_hi - src_lo)
    return normalized * (dst_hi - dst_lo) + dst_lo
|
|
46
|
+
|
|
47
|
+
# wrapper
|
|
48
|
+
|
|
49
|
+
class ActionTransformWrapper:
    """Environment wrapper that rescales and optionally clips actions before stepping.

    Every floating point leaf of the action pytree passed to `step` is copied,
    run through the configured transforms, optionally clamped to `clip`, and
    the result is forwarded to the wrapped environment. Non-float leaves are
    forwarded unchanged. Attributes not found on the wrapper are looked up on
    the wrapped environment.

    Args:
        env: the wrapped environment; `reset` and `step` are called on it.
        transforms: a transform dict, a list of transform dicts, or None for
            no transforms. Keys read from each dict:
            'rescale_from_to' - a `(from_range, to_range)` pair handed to
                `rescale`; a transform without this key is skipped.
            'indices' - index into the last dimension that the transform
                applies to. When the key is absent it falls back to the
                position of the transform in the list if there is more than
                one transform, otherwise the whole leaf is transformed.
        clip: optional `(min, max)` pair applied to every float leaf after
            the transforms.
    """

    def __init__(
        self,
        env,
        transforms = None,
        clip = None
    ):
        self.env = env
        self.clip = clip

        # a single transform dict is shorthand for a one element list
        if isinstance(transforms, dict):
            transforms = [transforms]

        self.transforms = default(transforms, [])

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails

        if name.startswith('_'):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")

        # `self.env` below would re-enter this method forever if `env` itself is missing
        # (e.g. an instance created without running __init__), so fail with a normal AttributeError

        if name == 'env':
            raise AttributeError("wrapped environment has not been set on this wrapper")

        return getattr(self.env, name)

    def reset(self, **kwargs):
        """Reset the wrapped environment, forwarding all keyword arguments."""
        return self.env.reset(**kwargs)

    def step(self, action):
        """Transform every float leaf of `action` and step the wrapped environment with the result."""
        transforms = self.transforms

        # with several transforms, each one defaults to the last-dimension index matching its position
        use_positional_indices = len(transforms) > 1

        def transform_action(t):
            if not is_float_dtype(t):
                return t

            # work on a copy so the indexed assignment below never mutates the caller's action
            t = copy(t)

            for ind, transform in enumerate(transforms):
                indices = transform.get('indices', ind if use_positional_indices else None)

                if 'rescale_from_to' not in transform:
                    continue

                from_range, to_range = transform['rescale_from_to']
                fn = partial(rescale, from_range = from_range, to_range = to_range)

                if exists(indices):
                    # NOTE(review): a plain python float leaf is not subscriptable,
                    # so this path raises TypeError for scalar float actions
                    t[..., indices] = fn(t[..., indices])
                else:
                    t = fn(t)

            if exists(self.clip):
                t = clamp(t, *self.clip)

            return t

        transformed_action = tree_map(transform_action, action)
        return self.env.step(transformed_action)
|
|
@@ -1,93 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import torch
|
|
5
|
-
from torch import is_tensor
|
|
6
|
-
from torch.utils._pytree import tree_map
|
|
7
|
-
|
|
8
|
-
# helpers
|
|
9
|
-
|
|
10
|
-
def exists(v):
    """Report whether a value was actually provided, i.e. is not None."""
    return not (v is None)
|
|
12
|
-
|
|
13
|
-
def default(v, d):
    """Pick `v` when it is set (not None), otherwise the fallback `d`."""
    return d if not exists(v) else v
|
|
15
|
-
|
|
16
|
-
def rescale(
    t,
    from_range: tuple[float, float],
    to_range: tuple[float, float]
):
    """Linearly remap `t` from `from_range` to `to_range`, both given as `(min, max)` pairs."""
    in_min, in_max = from_range
    out_min, out_max = to_range

    in_span = in_max - in_min
    out_span = out_max - out_min

    # same evaluation order as a single expression: offset, divide, multiply, shift
    return (t - in_min) / in_span * out_span + out_min
|
|
24
|
-
|
|
25
|
-
# classes
|
|
26
|
-
|
|
27
|
-
class ActionTransformWrapper:
    """Environment wrapper that rescales and optionally clips actions before stepping.

    Every floating point leaf of the action pytree passed to `step` (float
    tensor, float numpy array or python float) is copied where it is a tensor
    or array, run through the configured transforms, optionally clipped, and
    forwarded to the wrapped environment. Other leaves are forwarded
    unchanged. Attributes not found on the wrapper are looked up on the
    wrapped environment.

    Args:
        env: the wrapped environment; `reset` and `step` are called on it.
        transforms: a transform dict, a list of transform dicts, or None for
            no transforms. Keys read from each dict:
            'rescale_from_to' - a `(from_range, to_range)` pair handed to
                `rescale`; without it the transform does nothing.
            'indices' - index into the last dimension that the transform
                applies to. When missing or None it falls back to the
                position of the transform in the list if there is more than
                one transform, otherwise the whole leaf is transformed.
        clip: optional `(min, max)` pair applied to every float leaf after
            the transforms.
    """

    def __init__(
        self,
        env,
        transforms = None,
        clip = None
    ):
        self.env = env

        # a single transform dict is shorthand for a one element list
        if isinstance(transforms, dict):
            transforms = [transforms]

        self.transforms = default(transforms, [])
        self.clip = clip

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails

        if name.startswith('_'):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")

        # `self.env` below would re-enter this method forever if `env` itself is missing
        # (e.g. an instance created without running __init__), so fail with a normal AttributeError

        if name == 'env':
            raise AttributeError("wrapped environment has not been set on this wrapper")

        return getattr(self.env, name)

    def reset(self, **kwargs):
        """Reset the wrapped environment, forwarding all keyword arguments."""
        return self.env.reset(**kwargs)

    def step(self, action):
        """Transform every float leaf of `action` and step the wrapped environment with the result."""
        # with several transforms, each one defaults to the last-dimension index matching its position
        has_multiple_transforms = len(self.transforms) > 1

        def transform_action(t):
            is_torch_float = is_tensor(t) and t.is_floating_point()
            is_np_float = isinstance(t, np.ndarray) and np.issubdtype(t.dtype, np.floating)
            is_scalar_float = isinstance(t, float)

            if not (is_torch_float or is_np_float or is_scalar_float):
                return t

            # copy so the indexed assignment below never mutates the caller's action
            if is_tensor(t):
                t = t.clone()
            elif isinstance(t, np.ndarray):
                t = np.copy(t)

            for ind, transform in enumerate(self.transforms):
                indices = transform.get('indices')
                rescale_from_to = transform.get('rescale_from_to')

                if not exists(indices) and has_multiple_transforms:
                    indices = ind

                if not exists(rescale_from_to):
                    continue

                from_range, to_range = rescale_from_to

                if not exists(indices):
                    t = rescale(t, from_range, to_range)
                else:
                    # NOTE(review): a plain python float leaf is not subscriptable,
                    # so this path raises TypeError for scalar float actions
                    part = t[..., indices]
                    t[..., indices] = rescale(part, from_range, to_range)

            if exists(self.clip):
                min_clip, max_clip = self.clip

                if is_tensor(t):
                    t = torch.clamp(t, min_clip, max_clip)
                elif isinstance(t, np.ndarray):
                    t = np.clip(t, min_clip, max_clip)
                else:
                    t = max(min_clip, min(t, max_clip))

            return t

        action = tree_map(transform_action, action)
        return self.env.step(action)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|