env-ssl-wrapper 0.0.1__tar.gz → 0.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {env_ssl_wrapper-0.0.1 → env_ssl_wrapper-0.0.3}/PKG-INFO +65 -1
- env_ssl_wrapper-0.0.3/README.md +63 -0
- env_ssl_wrapper-0.0.3/env_ssl_wrapper/__init__.py +5 -0
- env_ssl_wrapper-0.0.3/env_ssl_wrapper/auto_batched_wrapper.py +58 -0
- env_ssl_wrapper-0.0.3/env_ssl_wrapper/image_wrapper.py +70 -0
- env_ssl_wrapper-0.0.3/env_ssl_wrapper/tensor_wrapper.py +74 -0
- env_ssl_wrapper-0.0.3/env_ssl_wrapper/utils.py +30 -0
- {env_ssl_wrapper-0.0.1 → env_ssl_wrapper-0.0.3}/pyproject.toml +5 -1
- env_ssl_wrapper-0.0.1/README.md +0 -3
- env_ssl_wrapper-0.0.1/env_ssl_wrapper/__init__.py +0 -0
- env_ssl_wrapper-0.0.1/env_ssl_wrapper/env_ssl_wrapper.py +0 -0
- {env_ssl_wrapper-0.0.1 → env_ssl_wrapper-0.0.3}/.gitignore +0 -0
- {env_ssl_wrapper-0.0.1 → env_ssl_wrapper-0.0.3}/LICENSE +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: env-ssl-wrapper
|
|
3
|
-
Version: 0.0.1
|
|
3
|
+
Version: 0.0.3
|
|
4
4
|
Summary: An RL environment wrapper for learning SSL in the background
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/env-ssl-wrapper/
|
|
6
6
|
Project-URL: Repository, https://codeberg.org/lucidrains/env-ssl-wrapper
|
|
@@ -34,10 +34,14 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
34
34
|
Classifier: Programming Language :: Python :: 3.10
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.10
|
|
37
|
+
Requires-Dist: discrete-continuous-embed-readout
|
|
37
38
|
Requires-Dist: einops>=0.8.1
|
|
38
39
|
Requires-Dist: einx>=0.3.0
|
|
40
|
+
Requires-Dist: memmap-replay-buffer
|
|
39
41
|
Requires-Dist: torch-einops-utils>=0.0.29
|
|
40
42
|
Requires-Dist: torch>=2.5
|
|
43
|
+
Requires-Dist: x-mlps-pytorch
|
|
44
|
+
Requires-Dist: x-transformers
|
|
41
45
|
Provides-Extra: examples
|
|
42
46
|
Provides-Extra: test
|
|
43
47
|
Requires-Dist: pytest; extra == 'test'
|
|
@@ -46,3 +50,63 @@ Description-Content-Type: text/markdown
|
|
|
46
50
|
## env-ssl-wrapper (wip)
|
|
47
51
|
|
|
48
52
|
Wrappers around environments that will take care of providing representations from self supervised learning automagically
|
|
53
|
+
|
|
54
|
+
## Citations
|
|
55
|
+
|
|
56
|
+
```bibtex
|
|
57
|
+
@misc{schwarzer2021dataefficientreinforcementlearningselfpredictive,
|
|
58
|
+
title = {Data-Efficient Reinforcement Learning with Self-Predictive Representations},
|
|
59
|
+
author = {Max Schwarzer and Ankesh Anand and Rishab Goel and R Devon Hjelm and Aaron Courville and Philip Bachman},
|
|
60
|
+
year = {2021},
|
|
61
|
+
eprint = {2007.05929},
|
|
62
|
+
archivePrefix = {arXiv},
|
|
63
|
+
primaryClass = {cs.LG},
|
|
64
|
+
url = {https://arxiv.org/abs/2007.05929},
|
|
65
|
+
}
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
```bibtex
|
|
69
|
+
@misc{schmidt2024learningactactions,
|
|
70
|
+
title = {Learning to Act without Actions},
|
|
71
|
+
author = {Dominik Schmidt and Minqi Jiang},
|
|
72
|
+
year = {2024},
|
|
73
|
+
eprint = {2312.10812},
|
|
74
|
+
archivePrefix = {arXiv},
|
|
75
|
+
primaryClass = {cs.LG},
|
|
76
|
+
url = {https://arxiv.org/abs/2312.10812},
|
|
77
|
+
}
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
```bibtex
|
|
81
|
+
@misc{eysenbach2023contrastivelearninggoalconditionedreinforcement,
|
|
82
|
+
title = {Contrastive Learning as Goal-Conditioned Reinforcement Learning},
|
|
83
|
+
author = {Benjamin Eysenbach and Tianjun Zhang and Ruslan Salakhutdinov and Sergey Levine},
|
|
84
|
+
year = {2023},
|
|
85
|
+
eprint = {2206.07568},
|
|
86
|
+
archivePrefix = {arXiv},
|
|
87
|
+
primaryClass = {cs.LG},
|
|
88
|
+
url = {https://arxiv.org/abs/2206.07568},
|
|
89
|
+
}
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
```bibtex
|
|
93
|
+
@misc{ashlag2025stateentropyregularizationrobust,
|
|
94
|
+
title = {State Entropy Regularization for Robust Reinforcement Learning},
|
|
95
|
+
author = {Yonatan Ashlag and Uri Koren and Mirco Mutti and Esther Derman and Pierre-Luc Bacon and Shie Mannor},
|
|
96
|
+
year = {2025},
|
|
97
|
+
eprint = {2506.07085},
|
|
98
|
+
archivePrefix = {arXiv},
|
|
99
|
+
primaryClass = {cs.LG},
|
|
100
|
+
url = {https://arxiv.org/abs/2506.07085},
|
|
101
|
+
}
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
```bibtex
|
|
105
|
+
@inproceedings{park2026dual,
|
|
106
|
+
title = {Dual Goal Representations},
|
|
107
|
+
author = {Seohong Park and Deepinder Mann and Sergey Levine},
|
|
108
|
+
booktitle = {The Fourteenth International Conference on Learning Representations},
|
|
109
|
+
year = {2026},
|
|
110
|
+
url = {https://openreview.net/forum?id=aMKFTidLSM}
|
|
111
|
+
}
|
|
112
|
+
```
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
## env-ssl-wrapper (wip)
|
|
2
|
+
|
|
3
|
+
Wrappers around environments that will take care of providing representations from self supervised learning automagically
|
|
4
|
+
|
|
5
|
+
## Citations
|
|
6
|
+
|
|
7
|
+
```bibtex
|
|
8
|
+
@misc{schwarzer2021dataefficientreinforcementlearningselfpredictive,
|
|
9
|
+
title = {Data-Efficient Reinforcement Learning with Self-Predictive Representations},
|
|
10
|
+
author = {Max Schwarzer and Ankesh Anand and Rishab Goel and R Devon Hjelm and Aaron Courville and Philip Bachman},
|
|
11
|
+
year = {2021},
|
|
12
|
+
eprint = {2007.05929},
|
|
13
|
+
archivePrefix = {arXiv},
|
|
14
|
+
primaryClass = {cs.LG},
|
|
15
|
+
url = {https://arxiv.org/abs/2007.05929},
|
|
16
|
+
}
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
```bibtex
|
|
20
|
+
@misc{schmidt2024learningactactions,
|
|
21
|
+
title = {Learning to Act without Actions},
|
|
22
|
+
author = {Dominik Schmidt and Minqi Jiang},
|
|
23
|
+
year = {2024},
|
|
24
|
+
eprint = {2312.10812},
|
|
25
|
+
archivePrefix = {arXiv},
|
|
26
|
+
primaryClass = {cs.LG},
|
|
27
|
+
url = {https://arxiv.org/abs/2312.10812},
|
|
28
|
+
}
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
```bibtex
|
|
32
|
+
@misc{eysenbach2023contrastivelearninggoalconditionedreinforcement,
|
|
33
|
+
title = {Contrastive Learning as Goal-Conditioned Reinforcement Learning},
|
|
34
|
+
author = {Benjamin Eysenbach and Tianjun Zhang and Ruslan Salakhutdinov and Sergey Levine},
|
|
35
|
+
year = {2023},
|
|
36
|
+
eprint = {2206.07568},
|
|
37
|
+
archivePrefix = {arXiv},
|
|
38
|
+
primaryClass = {cs.LG},
|
|
39
|
+
url = {https://arxiv.org/abs/2206.07568},
|
|
40
|
+
}
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
```bibtex
|
|
44
|
+
@misc{ashlag2025stateentropyregularizationrobust,
|
|
45
|
+
title = {State Entropy Regularization for Robust Reinforcement Learning},
|
|
46
|
+
author = {Yonatan Ashlag and Uri Koren and Mirco Mutti and Esther Derman and Pierre-Luc Bacon and Shie Mannor},
|
|
47
|
+
year = {2025},
|
|
48
|
+
eprint = {2506.07085},
|
|
49
|
+
archivePrefix = {arXiv},
|
|
50
|
+
primaryClass = {cs.LG},
|
|
51
|
+
url = {https://arxiv.org/abs/2506.07085},
|
|
52
|
+
}
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
```bibtex
|
|
56
|
+
@inproceedings{park2026dual,
|
|
57
|
+
title = {Dual Goal Representations},
|
|
58
|
+
author = {Seohong Park and Deepinder Mann and Sergey Levine},
|
|
59
|
+
booktitle = {The Fourteenth International Conference on Learning Representations},
|
|
60
|
+
year = {2026},
|
|
61
|
+
url = {https://openreview.net/forum?id=aMKFTidLSM}
|
|
62
|
+
}
|
|
63
|
+
```
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from torch import is_tensor
|
|
5
|
+
from torch.utils._pytree import tree_map
|
|
6
|
+
from einops import rearrange
|
|
7
|
+
|
|
8
|
+
# helper functions
|
|
9
|
+
|
|
10
|
+
def is_vectorized(env) -> bool:
    """Report whether `env` looks like a vectorized (batched) environment.

    An environment counts as vectorized when it exposes a positive `num_envs`
    attribute, or - when gymnasium can be imported - when its `unwrapped`
    attribute (falling back to `env` itself) is a `gymnasium.vector.VectorEnv`.
    Returns False otherwise, including when gymnasium is not importable.
    """
    # duck-typed check first, it needs no gymnasium

    if hasattr(env, 'num_envs'):
        if env.num_envs > 0:
            return True

    try:
        # imported lazily so a missing gymnasium only disables this check
        from gymnasium.vector import VectorEnv

        innermost = getattr(env, 'unwrapped', env)
        return isinstance(innermost, VectorEnv)
    except ImportError:
        return False
|
|
18
|
+
|
|
19
|
+
def maybe_expand_dim(x):
    """Add a leading axis of size 1 to the leaves of the pytree `x`.

    Numpy arrays and torch tensors gain a new first dimension, python / numpy
    scalars are turned into a one element numpy array, and every other leaf is
    handed back unchanged.
    """
    scalar_types = (int, float, bool, np.number, np.bool_)

    def add_leading_dim(leaf):
        is_array = isinstance(leaf, np.ndarray) or is_tensor(leaf)

        if is_array:
            return rearrange(leaf, '... -> 1 ...')

        if isinstance(leaf, scalar_types):
            return np.array([leaf])

        return leaf

    return tree_map(add_leading_dim, x)
|
|
27
|
+
|
|
28
|
+
def maybe_squeeze_dim(x):
    """Drop the leading axis from every array / tensor leaf of the pytree `x`.

    Leaves that are neither numpy arrays nor torch tensors pass through as is.
    The einops pattern names the first axis as `1`, so array leaves are
    expected to carry a leading dimension of exactly one.
    """
    def drop_leading_dim(leaf):
        if not (isinstance(leaf, np.ndarray) or is_tensor(leaf)):
            return leaf

        return rearrange(leaf, '1 ... -> ...')

    return tree_map(drop_leading_dim, x)
|
|
34
|
+
|
|
35
|
+
# classes
|
|
36
|
+
|
|
37
|
+
class AutoBatchedWrapper:
|
|
38
|
+
def __init__(self, env, is_vector: bool | None = None):
|
|
39
|
+
self.env = env
|
|
40
|
+
self.is_vector = is_vector if is_vector is not None else is_vectorized(env)
|
|
41
|
+
|
|
42
|
+
def __getattr__(self, name):
|
|
43
|
+
if name.startswith('_'):
|
|
44
|
+
raise AttributeError(f"attempted to get missing private attribute '{name}'")
|
|
45
|
+
return getattr(self.env, name)
|
|
46
|
+
|
|
47
|
+
def reset(self, **kwargs):
|
|
48
|
+
obs, info = self.env.reset(**kwargs)
|
|
49
|
+
return (maybe_expand_dim(obs), info) if not self.is_vector else (obs, info)
|
|
50
|
+
|
|
51
|
+
def step(self, action):
|
|
52
|
+
action = maybe_squeeze_dim(action) if not self.is_vector else action
|
|
53
|
+
out = self.env.step(action)
|
|
54
|
+
|
|
55
|
+
if self.is_vector:
|
|
56
|
+
return out
|
|
57
|
+
|
|
58
|
+
return *maybe_expand_dim(out[:4]), out[4]
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from PIL import Image
|
|
7
|
+
from einops import rearrange
|
|
8
|
+
|
|
9
|
+
# functions
|
|
10
|
+
|
|
11
|
+
def cast_tuple(t, length = 1):
    """Return `t` unchanged if it is a tuple, otherwise a tuple repeating `t` `length` times."""
    if isinstance(t, tuple):
        return t

    return (t,) * length
|
|
13
|
+
|
|
14
|
+
# class
|
|
15
|
+
|
|
16
|
+
class ImageObservationWrapper:
    """Attach a resized rendered frame to every observation of an environment.

    On `reset` and `step` the wrapped environment is rendered, the frame is
    resized with PIL and turned into a channel-first torch tensor, then stored
    in the observation under `image_key`. Observations that are not dicts are
    placed under the key `state` of a new dict. Attributes not defined here are
    looked up on the wrapped environment.

    Args:
        env: the environment to wrap, must provide `render`, `reset` and `step`
        image_size: target size handed to PIL's `resize`, a single value is repeated twice
        image_key: observation key under which the image tensor is stored
        resample_method: PIL resampling filter used when resizing
        normalize: whether to cast the image to float and divide by `normalize_divisor`
        normalize_divisor: divisor applied when `normalize` is True
    """

    def __init__(
        self,
        env,
        image_size = (64, 64),
        image_key = 'image',
        resample_method = Image.BILINEAR,
        normalize = True,
        normalize_divisor = 255.0
    ):
        self.env = env

        # NOTE(review): PIL's resize presumably reads this as (width, height) - confirm for non square sizes
        self.image_size = cast_tuple(image_size, 2)

        self.image_key = image_key
        self.resample_method = resample_method
        self.normalize = normalize
        self.normalize_divisor = normalize_divisor

    def __getattr__(self, name):
        # only reached when normal lookup fails - if `env` itself is missing
        # (instance created through `__new__`, before `__init__` ran) then
        # `self.env` below would re-enter this method without end, so refuse it

        if name == 'env':
            raise AttributeError("wrapped environment 'env' has not been set on this wrapper")

        if name.startswith('_'):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")

        return getattr(self.env, name)

    def render_frame(self):
        """Render the environment and return the frame as a `(1, c, h, w)` tensor."""
        frame = self.env.render()

        resized = Image.fromarray(frame).resize(self.image_size, resample = self.resample_method)
        img_tensor = torch.from_numpy(np.array(resized))

        # the pattern requires a three dimensional (h, w, c) frame
        img = rearrange(img_tensor, 'h w c -> 1 c h w')

        if self.normalize:
            img = img.float() / self.normalize_divisor

        return img

    def observation(self, obs):
        """Return `obs` as a dict that additionally holds the current frame under `image_key`.

        Raises:
            ValueError: if `obs` is a dict that already contains `image_key`
        """
        img_tensor = self.render_frame()
        img_tensor = rearrange(img_tensor, '1 c h w -> c h w')

        if not isinstance(obs, dict):
            return dict(state = obs, **{self.image_key: img_tensor})

        if self.image_key in obs:
            raise ValueError(f"Key '{self.image_key}' is already present in the observation dictionary.")

        # shallow copy so the dict handed out by the environment is not mutated
        obs = dict(obs)
        obs.update({self.image_key: img_tensor})

        return obs

    def reset(self, **kwargs):
        """Reset the wrapped environment and return `(obs, info)` with the image added to `obs`."""
        obs, info = self.env.reset(**kwargs)
        return self.observation(obs), info

    def step(self, action):
        """Step the wrapped environment and return its five outputs with the image added to `obs`."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self.observation(obs), reward, terminated, truncated, info
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from torch import tensor, is_tensor, from_numpy, float64, device as torch_device
|
|
5
|
+
from torch.utils._pytree import tree_map
|
|
6
|
+
|
|
7
|
+
# helper functions
|
|
8
|
+
|
|
9
|
+
def numpy_to_torch(x, device, cast_float64_to_float32 = False):
    """Convert the numpy arrays and scalars in the pytree `x` to torch tensors on `device`.

    Tensors already present are moved to `device` as well, leaves of any other
    type are returned untouched. With `cast_float64_to_float32` set, float64
    tensors are converted to float32 before being moved.
    """
    scalar_types = (int, float, bool, np.number, np.bool_)

    def convert_leaf(leaf):
        if isinstance(leaf, np.ndarray):
            converted = from_numpy(leaf)
        elif isinstance(leaf, scalar_types):
            converted = tensor(leaf)
        elif is_tensor(leaf):
            converted = leaf
        else:
            # not something that can be represented as a tensor here
            return leaf

        needs_downcast = cast_float64_to_float32 and converted.dtype == float64

        if needs_downcast:
            converted = converted.float()

        return converted.to(device)

    return tree_map(convert_leaf, x)
|
|
24
|
+
|
|
25
|
+
def torch_to_numpy(x, cast_float64_to_float32 = False):
    """Convert the torch tensors and scalars in the pytree `x` to numpy arrays.

    Tensors are detached and moved to cpu first, scalars become zero
    dimensional arrays, leaves of any other type are returned untouched. With
    `cast_float64_to_float32` set, float64 arrays are converted to float32.
    """
    scalar_types = (int, float, bool, np.number, np.bool_)

    def convert_leaf(leaf):
        if is_tensor(leaf):
            array = leaf.detach().cpu().numpy()
        elif isinstance(leaf, scalar_types):
            array = np.array(leaf)
        elif isinstance(leaf, np.ndarray):
            array = leaf
        else:
            # not something that can be represented as an array here
            return leaf

        if cast_float64_to_float32 and array.dtype == np.float64:
            return array.astype(np.float32)

        return array

    return tree_map(convert_leaf, x)
|
|
40
|
+
|
|
41
|
+
# classes
|
|
42
|
+
|
|
43
|
+
class TensorWrapper:
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
env,
|
|
47
|
+
device: str | torch_device = 'cpu',
|
|
48
|
+
convert_in: bool = True,
|
|
49
|
+
convert_out: bool = True,
|
|
50
|
+
cast_float64_to_float32: bool = False
|
|
51
|
+
):
|
|
52
|
+
self.env = env
|
|
53
|
+
self.device = torch_device(device)
|
|
54
|
+
self.convert_in = convert_in
|
|
55
|
+
self.convert_out = convert_out
|
|
56
|
+
self.cast_float64_to_float32 = cast_float64_to_float32
|
|
57
|
+
|
|
58
|
+
def __getattr__(self, name):
|
|
59
|
+
if name.startswith('_'):
|
|
60
|
+
raise AttributeError(f"attempted to get missing private attribute '{name}'")
|
|
61
|
+
return getattr(self.env, name)
|
|
62
|
+
|
|
63
|
+
def reset(self, **kwargs):
|
|
64
|
+
obs, info = self.env.reset(**kwargs)
|
|
65
|
+
return (numpy_to_torch(obs, self.device, self.cast_float64_to_float32), info) if self.convert_out else (obs, info)
|
|
66
|
+
|
|
67
|
+
def step(self, action):
|
|
68
|
+
action = torch_to_numpy(action, self.cast_float64_to_float32) if self.convert_in else action
|
|
69
|
+
out = self.env.step(action)
|
|
70
|
+
|
|
71
|
+
if not self.convert_out:
|
|
72
|
+
return out
|
|
73
|
+
|
|
74
|
+
return *numpy_to_torch(out[:4], self.device, self.cast_float64_to_float32), out[4]
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
from .image_wrapper import ImageObservationWrapper
|
|
5
|
+
from .auto_batched_wrapper import AutoBatchedWrapper
|
|
6
|
+
from .tensor_wrapper import TensorWrapper
|
|
7
|
+
|
|
8
|
+
# registry of the wrappers that can be referred to by name in `compose_env`

WRAPPERS = {
    'image': ImageObservationWrapper,
    'auto_batch': AutoBatchedWrapper,
    'tensor': TensorWrapper,
}
|
|
13
|
+
|
|
14
|
+
def compose_env(env, *wrappers):
    """Apply `wrappers` to `env` one after another, in the order given.

    Each wrapper may be
        - a callable taking the environment and returning the wrapped one
        - a string naming an entry of `WRAPPERS`
        - a tuple of (name or callable, kwargs dict), the kwargs being bound
          to the wrapper before it is called with the environment

    Returns the environment after the last wrapper was applied, or `env`
    itself when no wrappers are given.
    """
    def resolve(spec):
        # a bare name is swapped for its registry entry first
        if isinstance(spec, str):
            spec = WRAPPERS[spec]

        if isinstance(spec, tuple):
            target, options = spec
            # names go through the registry, anything else is used directly
            target = WRAPPERS.get(target, target)
            return partial(target, **options)

        return spec

    wrapped = env

    for spec in wrappers:
        wrapped = resolve(spec)(wrapped)

    return wrapped

# alias

wrap_env = compose_env
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "env-ssl-wrapper"
|
|
3
|
-
version = "0.0.1"
|
|
3
|
+
version = "0.0.3"
|
|
4
4
|
description = "An RL environment wrapper for learning SSL in the background"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -24,10 +24,14 @@ classifiers=[
|
|
|
24
24
|
]
|
|
25
25
|
|
|
26
26
|
dependencies = [
|
|
27
|
+
"discrete-continuous-embed-readout",
|
|
27
28
|
"einx>=0.3.0",
|
|
28
29
|
"einops>=0.8.1",
|
|
30
|
+
"memmap-replay-buffer",
|
|
29
31
|
"torch>=2.5",
|
|
30
32
|
"torch-einops-utils>=0.0.29",
|
|
33
|
+
"x-transformers",
|
|
34
|
+
"x-mlps-pytorch",
|
|
31
35
|
]
|
|
32
36
|
|
|
33
37
|
[project.urls]
|
env_ssl_wrapper-0.0.1/README.md
DELETED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|