env-ssl-wrapper 0.0.1__tar.gz → 0.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: env-ssl-wrapper
3
- Version: 0.0.1
3
+ Version: 0.0.3
4
4
  Summary: An RL environment wrapper for learning SSL in the background
5
5
  Project-URL: Homepage, https://pypi.org/project/env-ssl-wrapper/
6
6
  Project-URL: Repository, https://codeberg.org/lucidrains/env-ssl-wrapper
@@ -34,10 +34,14 @@ Classifier: License :: OSI Approved :: MIT License
34
34
  Classifier: Programming Language :: Python :: 3.10
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.10
37
+ Requires-Dist: discrete-continuous-embed-readout
37
38
  Requires-Dist: einops>=0.8.1
38
39
  Requires-Dist: einx>=0.3.0
40
+ Requires-Dist: memmap-replay-buffer
39
41
  Requires-Dist: torch-einops-utils>=0.0.29
40
42
  Requires-Dist: torch>=2.5
43
+ Requires-Dist: x-mlps-pytorch
44
+ Requires-Dist: x-transformers
41
45
  Provides-Extra: examples
42
46
  Provides-Extra: test
43
47
  Requires-Dist: pytest; extra == 'test'
@@ -46,3 +50,63 @@ Description-Content-Type: text/markdown
46
50
  ## env-ssl-wrapper (wip)
47
51
 
48
52
  Wrappers around environments that will take care of providing representations from self-supervised learning automagically
53
+
54
+ ## Citations
55
+
56
+ ```bibtex
57
+ @misc{schwarzer2021dataefficientreinforcementlearningselfpredictive,
58
+ title = {Data-Efficient Reinforcement Learning with Self-Predictive Representations},
59
+ author = {Max Schwarzer and Ankesh Anand and Rishab Goel and R Devon Hjelm and Aaron Courville and Philip Bachman},
60
+ year = {2021},
61
+ eprint = {2007.05929},
62
+ archivePrefix = {arXiv},
63
+ primaryClass = {cs.LG},
64
+ url = {https://arxiv.org/abs/2007.05929},
65
+ }
66
+ ```
67
+
68
+ ```bibtex
69
+ @misc{schmidt2024learningactactions,
70
+ title = {Learning to Act without Actions},
71
+ author = {Dominik Schmidt and Minqi Jiang},
72
+ year = {2024},
73
+ eprint = {2312.10812},
74
+ archivePrefix = {arXiv},
75
+ primaryClass = {cs.LG},
76
+ url = {https://arxiv.org/abs/2312.10812},
77
+ }
78
+ ```
79
+
80
+ ```bibtex
81
+ @misc{eysenbach2023contrastivelearninggoalconditionedreinforcement,
82
+ title = {Contrastive Learning as Goal-Conditioned Reinforcement Learning},
83
+ author = {Benjamin Eysenbach and Tianjun Zhang and Ruslan Salakhutdinov and Sergey Levine},
84
+ year = {2023},
85
+ eprint = {2206.07568},
86
+ archivePrefix = {arXiv},
87
+ primaryClass = {cs.LG},
88
+ url = {https://arxiv.org/abs/2206.07568},
89
+ }
90
+ ```
91
+
92
+ ```bibtex
93
+ @misc{ashlag2025stateentropyregularizationrobust,
94
+ title = {State Entropy Regularization for Robust Reinforcement Learning},
95
+ author = {Yonatan Ashlag and Uri Koren and Mirco Mutti and Esther Derman and Pierre-Luc Bacon and Shie Mannor},
96
+ year = {2025},
97
+ eprint = {2506.07085},
98
+ archivePrefix = {arXiv},
99
+ primaryClass = {cs.LG},
100
+ url = {https://arxiv.org/abs/2506.07085},
101
+ }
102
+ ```
103
+
104
+ ```bibtex
105
+ @inproceedings{park2026dual,
106
+ title = {Dual Goal Representations},
107
+ author = {Seohong Park and Deepinder Mann and Sergey Levine},
108
+ booktitle = {The Fourteenth International Conference on Learning Representations},
109
+ year = {2026},
110
+ url = {https://openreview.net/forum?id=aMKFTidLSM}
111
+ }
112
+ ```
@@ -0,0 +1,63 @@
1
+ ## env-ssl-wrapper (wip)
2
+
3
+ Wrappers around environments that will take care of providing representations from self-supervised learning automagically
4
+
5
+ ## Citations
6
+
7
+ ```bibtex
8
+ @misc{schwarzer2021dataefficientreinforcementlearningselfpredictive,
9
+ title = {Data-Efficient Reinforcement Learning with Self-Predictive Representations},
10
+ author = {Max Schwarzer and Ankesh Anand and Rishab Goel and R Devon Hjelm and Aaron Courville and Philip Bachman},
11
+ year = {2021},
12
+ eprint = {2007.05929},
13
+ archivePrefix = {arXiv},
14
+ primaryClass = {cs.LG},
15
+ url = {https://arxiv.org/abs/2007.05929},
16
+ }
17
+ ```
18
+
19
+ ```bibtex
20
+ @misc{schmidt2024learningactactions,
21
+ title = {Learning to Act without Actions},
22
+ author = {Dominik Schmidt and Minqi Jiang},
23
+ year = {2024},
24
+ eprint = {2312.10812},
25
+ archivePrefix = {arXiv},
26
+ primaryClass = {cs.LG},
27
+ url = {https://arxiv.org/abs/2312.10812},
28
+ }
29
+ ```
30
+
31
+ ```bibtex
32
+ @misc{eysenbach2023contrastivelearninggoalconditionedreinforcement,
33
+ title = {Contrastive Learning as Goal-Conditioned Reinforcement Learning},
34
+ author = {Benjamin Eysenbach and Tianjun Zhang and Ruslan Salakhutdinov and Sergey Levine},
35
+ year = {2023},
36
+ eprint = {2206.07568},
37
+ archivePrefix = {arXiv},
38
+ primaryClass = {cs.LG},
39
+ url = {https://arxiv.org/abs/2206.07568},
40
+ }
41
+ ```
42
+
43
+ ```bibtex
44
+ @misc{ashlag2025stateentropyregularizationrobust,
45
+ title = {State Entropy Regularization for Robust Reinforcement Learning},
46
+ author = {Yonatan Ashlag and Uri Koren and Mirco Mutti and Esther Derman and Pierre-Luc Bacon and Shie Mannor},
47
+ year = {2025},
48
+ eprint = {2506.07085},
49
+ archivePrefix = {arXiv},
50
+ primaryClass = {cs.LG},
51
+ url = {https://arxiv.org/abs/2506.07085},
52
+ }
53
+ ```
54
+
55
+ ```bibtex
56
+ @inproceedings{park2026dual,
57
+ title = {Dual Goal Representations},
58
+ author = {Seohong Park and Deepinder Mann and Sergey Levine},
59
+ booktitle = {The Fourteenth International Conference on Learning Representations},
60
+ year = {2026},
61
+ url = {https://openreview.net/forum?id=aMKFTidLSM}
62
+ }
63
+ ```
@@ -0,0 +1,5 @@
1
+ from .image_wrapper import ImageObservationWrapper
2
+ from .auto_batched_wrapper import AutoBatchedWrapper
3
+ from .tensor_wrapper import TensorWrapper
4
+
5
+ from .utils import wrap_env, compose_env
@@ -0,0 +1,58 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ from torch import is_tensor
5
+ from torch.utils._pytree import tree_map
6
+ from einops import rearrange
7
+
8
+ # helper functions
9
+
10
def is_vectorized(env) -> bool:
    """Report whether `env` looks like a vectorized (batched) environment.

    An environment counts as vectorized when it exposes a positive
    `num_envs` attribute, or when its unwrapped form is an instance of
    `gymnasium.vector.VectorEnv`. If gymnasium cannot be imported the
    second check is skipped and the answer is `False`.
    """
    # cheap duck-typed check first, no gymnasium needed
    if hasattr(env, 'num_envs'):
        if getattr(env, 'num_envs', 0) > 0:
            return True

    try:
        from gymnasium.vector import VectorEnv
        base_env = getattr(env, 'unwrapped', env)
        vectorized = isinstance(base_env, VectorEnv)
    except ImportError:
        vectorized = False

    return vectorized
18
+
19
def maybe_expand_dim(x):
    """Prepend a length-1 leading axis to every array-like or scalar leaf of `x`.

    `x` may be any pytree (tuple, list, dict, ...). Numpy arrays and torch
    tensors gain a new first dimension; python / numpy scalars become a
    one-element numpy array; every other leaf is returned untouched.
    """
    scalar_types = (int, float, bool, np.number, np.bool_)

    def add_leading_axis(leaf):
        if is_tensor(leaf) or isinstance(leaf, np.ndarray):
            return rearrange(leaf, '... -> 1 ...')

        return np.array([leaf]) if isinstance(leaf, scalar_types) else leaf

    return tree_map(add_leading_axis, x)
27
+
28
def maybe_squeeze_dim(x):
    """Strip the length-1 leading axis from every array / tensor leaf of `x`.

    The einops pattern `'1 ... -> ...'` is used, so an array or tensor
    whose first dimension is not exactly 1 is rejected by `rearrange`.
    Leaves that are neither numpy arrays nor torch tensors pass through.
    """
    def drop_leading_axis(leaf):
        is_array = is_tensor(leaf) or isinstance(leaf, np.ndarray)
        return rearrange(leaf, '1 ... -> ...') if is_array else leaf

    return tree_map(drop_leading_axis, x)
34
+
35
+ # classes
36
+
37
+ class AutoBatchedWrapper:
38
+ def __init__(self, env, is_vector: bool | None = None):
39
+ self.env = env
40
+ self.is_vector = is_vector if is_vector is not None else is_vectorized(env)
41
+
42
+ def __getattr__(self, name):
43
+ if name.startswith('_'):
44
+ raise AttributeError(f"attempted to get missing private attribute '{name}'")
45
+ return getattr(self.env, name)
46
+
47
+ def reset(self, **kwargs):
48
+ obs, info = self.env.reset(**kwargs)
49
+ return (maybe_expand_dim(obs), info) if not self.is_vector else (obs, info)
50
+
51
+ def step(self, action):
52
+ action = maybe_squeeze_dim(action) if not self.is_vector else action
53
+ out = self.env.step(action)
54
+
55
+ if self.is_vector:
56
+ return out
57
+
58
+ return *maybe_expand_dim(out[:4]), out[4]
@@ -0,0 +1,70 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+ from PIL import Image
7
+ from einops import rearrange
8
+
9
+ # functions
10
+
11
def cast_tuple(t, length = 1):
    """Return `t` as is when it is already a tuple, else repeat it `length` times in a tuple."""
    if isinstance(t, tuple):
        return t

    return (t,) * length
13
+
14
+ # class
15
+
16
class ImageObservationWrapper:
    """Attach a rendered, resized frame of the environment to every observation.

    On each `reset` / `step` the wrapped env is rendered, the frame is
    resized with PIL, converted to a `(c, h, w)` torch tensor and stored
    under `image_key`. A non-dict observation is wrapped as
    `dict(state = obs, <image_key> = frame)`. Attributes not found on the
    wrapper are forwarded to `env`.

    Args:
        env: environment providing `reset`, `step` and `render`.
        image_size: target size, a 2-tuple or a single value repeated twice.
        image_key: dictionary key under which the frame is stored.
        resample_method: PIL resampling filter used by `Image.resize`.
        normalize: when true, the frame is cast to float and divided by
            `normalize_divisor`.
        normalize_divisor: divisor applied when `normalize` is set.
    """

    def __init__(
        self,
        env,
        image_size = (64, 64),
        image_key = 'image',
        resample_method = Image.BILINEAR,
        normalize = True,
        normalize_divisor = 255.0
    ):
        self.env = env
        self.image_size = cast_tuple(image_size, 2)
        self.image_key = image_key
        self.resample_method = resample_method
        self.normalize = normalize
        self.normalize_divisor = normalize_divisor

    def __getattr__(self, name):
        # only reached when normal attribute lookup has already failed
        if name.startswith('_'):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")

        # `self.env` below would re-enter this method forever if `env` itself
        # is missing (instance built without __init__), so fail fast instead
        if name == 'env':
            raise AttributeError(f"'{type(self).__name__}' object has no attribute 'env'")

        return getattr(self.env, name)

    def render_frame(self):
        """Render the env and return the frame as a `(1, c, h, w)` tensor.

        Raises:
            RuntimeError: if `env.render()` returns `None`.
        """
        frame = self.env.render()

        if frame is None:
            raise RuntimeError(
                "env.render() returned None - the environment must render to an image array "
                "(e.g. render_mode = 'rgb_array') to be used with ImageObservationWrapper"
            )

        # NOTE(review): PIL's `resize` takes (width, height), so a non-square
        # `image_size` ends up as height = image_size[1], width = image_size[0]
        # in the tensor - confirm this ordering is intended
        resized = Image.fromarray(frame).resize(self.image_size, resample = self.resample_method)
        img_tensor = torch.from_numpy(np.array(resized))

        # the pattern requires a 3 dimensional (h, w, c) image
        img = rearrange(img_tensor, 'h w c -> 1 c h w')

        if self.normalize:
            img = img.float() / self.normalize_divisor

        return img

    def observation(self, obs):
        """Return `obs` as a dict augmented with the current frame under `image_key`.

        Raises:
            ValueError: if `obs` is a dict that already contains `image_key`.
        """
        img_tensor = self.render_frame()
        img_tensor = rearrange(img_tensor, '1 c h w -> c h w')

        if not isinstance(obs, dict):
            return dict(state = obs, **{self.image_key: img_tensor})

        if self.image_key in obs:
            raise ValueError(f"Key '{self.image_key}' is already present in the observation dictionary.")

        # shallow copy so the dict handed out by the env is not mutated
        obs = dict(obs)
        obs.update({self.image_key: img_tensor})

        return obs

    def reset(self, **kwargs):
        """Reset the wrapped env and add the rendered frame to the observation."""
        obs, info = self.env.reset(**kwargs)
        return self.observation(obs), info

    def step(self, action):
        """Step the wrapped env and add the rendered frame to the observation."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self.observation(obs), reward, terminated, truncated, info
@@ -0,0 +1,74 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ from torch import tensor, is_tensor, from_numpy, float64, device as torch_device
5
+ from torch.utils._pytree import tree_map
6
+
7
+ # helper functions
8
+
9
def numpy_to_torch(x, device, cast_float64_to_float32 = False):
    """Convert every numpy array / scalar leaf of the pytree `x` to a torch tensor on `device`.

    Numpy arrays go through `from_numpy`, python and numpy scalars through
    `tensor`. Leaves that are already tensors are only cast / moved, and
    anything else is returned unchanged. With `cast_float64_to_float32`
    set, float64 tensors are downcast to float32 before the move.
    """
    scalar_types = (int, float, bool, np.number, np.bool_)

    def convert_leaf(leaf):
        if isinstance(leaf, np.ndarray):
            converted = from_numpy(leaf)
        elif isinstance(leaf, scalar_types):
            converted = tensor(leaf)
        else:
            converted = leaf

        # strings, None and other non-numeric leaves are left alone
        if not is_tensor(converted):
            return converted

        needs_downcast = cast_float64_to_float32 and converted.dtype == float64

        if needs_downcast:
            converted = converted.float()

        return converted.to(device)

    return tree_map(convert_leaf, x)
24
+
25
def torch_to_numpy(x, cast_float64_to_float32 = False):
    """Convert every tensor / scalar leaf of the pytree `x` to a numpy array.

    Tensors are detached, moved to cpu and turned into arrays; python and
    numpy scalars become 0-d arrays. Leaves that are already numpy arrays
    are only cast, and anything else is returned unchanged. With
    `cast_float64_to_float32` set, float64 arrays are downcast to float32.
    """
    scalar_types = (int, float, bool, np.number, np.bool_)

    def convert_leaf(leaf):
        if is_tensor(leaf):
            converted = leaf.detach().cpu().numpy()
        elif isinstance(leaf, scalar_types):
            converted = np.array(leaf)
        else:
            converted = leaf

        # strings, None and other non-numeric leaves are left alone
        if not isinstance(converted, np.ndarray):
            return converted

        needs_downcast = cast_float64_to_float32 and converted.dtype == np.float64

        return converted.astype(np.float32) if needs_downcast else converted

    return tree_map(convert_leaf, x)
40
+
41
+ # classes
42
+
43
+ class TensorWrapper:
44
+ def __init__(
45
+ self,
46
+ env,
47
+ device: str | torch_device = 'cpu',
48
+ convert_in: bool = True,
49
+ convert_out: bool = True,
50
+ cast_float64_to_float32: bool = False
51
+ ):
52
+ self.env = env
53
+ self.device = torch_device(device)
54
+ self.convert_in = convert_in
55
+ self.convert_out = convert_out
56
+ self.cast_float64_to_float32 = cast_float64_to_float32
57
+
58
+ def __getattr__(self, name):
59
+ if name.startswith('_'):
60
+ raise AttributeError(f"attempted to get missing private attribute '{name}'")
61
+ return getattr(self.env, name)
62
+
63
+ def reset(self, **kwargs):
64
+ obs, info = self.env.reset(**kwargs)
65
+ return (numpy_to_torch(obs, self.device, self.cast_float64_to_float32), info) if self.convert_out else (obs, info)
66
+
67
+ def step(self, action):
68
+ action = torch_to_numpy(action, self.cast_float64_to_float32) if self.convert_in else action
69
+ out = self.env.step(action)
70
+
71
+ if not self.convert_out:
72
+ return out
73
+
74
+ return *numpy_to_torch(out[:4], self.device, self.cast_float64_to_float32), out[4]
@@ -0,0 +1,30 @@
1
+ from __future__ import annotations
2
+ from functools import partial
3
+
4
+ from .image_wrapper import ImageObservationWrapper
5
+ from .auto_batched_wrapper import AutoBatchedWrapper
6
+ from .tensor_wrapper import TensorWrapper
7
+
8
# registry of the built-in wrappers, keyed by the short names accepted by `compose_env`
WRAPPERS = {
    'image': ImageObservationWrapper,
    'auto_batch': AutoBatchedWrapper,
    'tensor': TensorWrapper,
}
13
+
14
def compose_env(env, *wrappers):
    """Apply `wrappers` to `env` in order and return the outermost wrapper.

    Each entry of `wrappers` may be:
      * a callable taking the env and returning the wrapped env,
      * a string naming an entry of `WRAPPERS`,
      * a `(name_or_callable, kwargs)` tuple, where `kwargs` is a dict of
        keyword arguments passed to the wrapper along with the env.

    Raises:
        KeyError: if a string name is not registered in `WRAPPERS`.
    """
    for wrapper in wrappers:
        kwargs = {}

        if isinstance(wrapper, tuple):
            wrapper, kwargs = wrapper

        # only strings are resolved against the registry, so callables are
        # never hashed and a mistyped name fails with a clear message in
        # both the bare and the tuple form
        if isinstance(wrapper, str):
            if wrapper not in WRAPPERS:
                raise KeyError(f"unknown wrapper '{wrapper}', expected one of {sorted(WRAPPERS)}")

            wrapper = WRAPPERS[wrapper]

        env = wrapper(env, **kwargs)

    return env

# alias

wrap_env = compose_env
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "env-ssl-wrapper"
3
- version = "0.0.1"
3
+ version = "0.0.3"
4
4
  description = "An RL environment wrapper for learning SSL in the background"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -24,10 +24,14 @@ classifiers=[
24
24
  ]
25
25
 
26
26
  dependencies = [
27
+ "discrete-continuous-embed-readout",
27
28
  "einx>=0.3.0",
28
29
  "einops>=0.8.1",
30
+ "memmap-replay-buffer",
29
31
  "torch>=2.5",
30
32
  "torch-einops-utils>=0.0.29",
33
+ "x-transformers",
34
+ "x-mlps-pytorch",
31
35
  ]
32
36
 
33
37
  [project.urls]
@@ -1,3 +0,0 @@
1
- ## env-ssl-wrapper (wip)
2
-
3
- Wrappers around environments that will take care of providing representations from self supervised learning automagically
File without changes
File without changes
File without changes