env-ssl-wrapper 0.0.3__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: env-ssl-wrapper
3
- Version: 0.0.3
3
+ Version: 0.0.4
4
4
  Summary: An RL environment wrapper for learning SSL in the background
5
5
  Project-URL: Homepage, https://pypi.org/project/env-ssl-wrapper/
6
6
  Project-URL: Repository, https://codeberg.org/lucidrains/env-ssl-wrapper
@@ -1,5 +1,6 @@
1
1
  from .image_wrapper import ImageObservationWrapper
2
2
  from .auto_batched_wrapper import AutoBatchedWrapper
3
3
  from .tensor_wrapper import TensorWrapper
4
+ from .action_transform_wrapper import ActionTransformWrapper
4
5
 
5
6
  from .utils import wrap_env, compose_env
@@ -0,0 +1,93 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import is_tensor
6
+ from torch.utils._pytree import tree_map
7
+
8
+ # helpers
9
+
10
+ def exists(v):
11
+ return v is not None
12
+
13
+ def default(v, d):
14
+ return v if exists(v) else d
15
+
16
+ def rescale(
17
+ t,
18
+ from_range: tuple[float, float],
19
+ to_range: tuple[float, float]
20
+ ):
21
+ from_min, from_max = from_range
22
+ to_min, to_max = to_range
23
+ return (t - from_min) / (from_max - from_min) * (to_max - to_min) + to_min
24
+
25
+ # classes
26
+
27
+ class ActionTransformWrapper:
28
+ def __init__(
29
+ self,
30
+ env,
31
+ transforms = None,
32
+ clip = None
33
+ ):
34
+ self.env = env
35
+
36
+ if isinstance(transforms, dict):
37
+ transforms = [transforms]
38
+
39
+ self.transforms = default(transforms, [])
40
+ self.clip = clip
41
+
42
+ def __getattr__(self, name):
43
+ if name.startswith('_'):
44
+ raise AttributeError(f"attempted to get missing private attribute '{name}'")
45
+ return getattr(self.env, name)
46
+
47
+ def reset(self, **kwargs):
48
+ return self.env.reset(**kwargs)
49
+
50
+ def step(self, action):
51
+ def transform_action(t):
52
+ is_torch_float = is_tensor(t) and t.is_floating_point()
53
+ is_np_float = isinstance(t, np.ndarray) and np.issubdtype(t.dtype, np.floating)
54
+ is_scalar_float = isinstance(t, float)
55
+
56
+ if not (is_torch_float or is_np_float or is_scalar_float):
57
+ return t
58
+
59
+ if is_tensor(t):
60
+ t = t.clone()
61
+ elif isinstance(t, np.ndarray):
62
+ t = np.copy(t)
63
+
64
+ for ind, transform in enumerate(self.transforms):
65
+ indices = transform.get('indices')
66
+ rescale_from_to = transform.get('rescale_from_to')
67
+
68
+ if not exists(indices) and len(self.transforms) > 1:
69
+ indices = ind
70
+
71
+ if exists(rescale_from_to):
72
+ from_range, to_range = rescale_from_to
73
+
74
+ if not exists(indices):
75
+ t = rescale(t, from_range, to_range)
76
+ else:
77
+ part = t[..., indices]
78
+ t[..., indices] = rescale(part, from_range, to_range)
79
+
80
+ if exists(self.clip):
81
+ min_clip, max_clip = self.clip
82
+
83
+ if is_tensor(t):
84
+ t = torch.clamp(t, min_clip, max_clip)
85
+ elif isinstance(t, np.ndarray):
86
+ t = np.clip(t, min_clip, max_clip)
87
+ else:
88
+ t = max(min_clip, min(t, max_clip))
89
+
90
+ return t
91
+
92
+ action = tree_map(transform_action, action)
93
+ return self.env.step(action)
@@ -51,8 +51,8 @@ class AutoBatchedWrapper:
51
51
  def step(self, action):
52
52
  action = maybe_squeeze_dim(action) if not self.is_vector else action
53
53
  out = self.env.step(action)
54
-
54
+
55
55
  if self.is_vector:
56
56
  return out
57
-
57
+
58
58
  return *maybe_expand_dim(out[:4]), out[4]
@@ -67,7 +67,7 @@ class TensorWrapper:
67
67
  def step(self, action):
68
68
  action = torch_to_numpy(action, self.cast_float64_to_float32) if self.convert_in else action
69
69
  out = self.env.step(action)
70
-
70
+
71
71
  if not self.convert_out:
72
72
  return out
73
73
 
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+ from functools import partial
3
+
4
+ from .image_wrapper import ImageObservationWrapper
5
+ from .auto_batched_wrapper import AutoBatchedWrapper
6
+ from .tensor_wrapper import TensorWrapper
7
+ from .action_transform_wrapper import ActionTransformWrapper
8
+
9
+ WRAPPERS = dict(
10
+ image = ImageObservationWrapper,
11
+ auto_batch = AutoBatchedWrapper,
12
+ tensor = TensorWrapper,
13
+ action_transform = ActionTransformWrapper
14
+ )
15
+
16
+ def is_unique(arr):
17
+ return len(set(arr)) == len(arr)
18
+
19
+ def compose_env(env, *wrappers):
20
+ funcs = []
21
+ classes = []
22
+
23
+ for wrapper in wrappers:
24
+ if isinstance(wrapper, str):
25
+ wrapper = WRAPPERS[wrapper]
26
+
27
+ if isinstance(wrapper, tuple):
28
+ name, kwargs = wrapper
29
+ wrapper = partial(WRAPPERS.get(name, name), **kwargs)
30
+
31
+ cls = wrapper.func if isinstance(wrapper, partial) else wrapper
32
+
33
+ funcs.append(wrapper)
34
+ classes.append(cls)
35
+
36
+ assert is_unique(classes), 'duplicate wrappers found'
37
+
38
+ for func in funcs:
39
+ env = func(env)
40
+
41
+ return env
42
+
43
+ # alias
44
+
45
+ wrap_env = compose_env
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "env-ssl-wrapper"
3
- version = "0.0.3"
3
+ version = "0.0.4"
4
4
  description = "An RL environment wrapper for learning SSL in the background"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,30 +0,0 @@
1
- from __future__ import annotations
2
- from functools import partial
3
-
4
- from .image_wrapper import ImageObservationWrapper
5
- from .auto_batched_wrapper import AutoBatchedWrapper
6
- from .tensor_wrapper import TensorWrapper
7
-
8
- WRAPPERS = dict(
9
- image = ImageObservationWrapper,
10
- auto_batch = AutoBatchedWrapper,
11
- tensor = TensorWrapper
12
- )
13
-
14
- def compose_env(env, *wrappers):
15
- for wrapper in wrappers:
16
- if isinstance(wrapper, str):
17
- wrapper = WRAPPERS[wrapper]
18
-
19
- if isinstance(wrapper, tuple):
20
- name_or_fn, kwargs = wrapper
21
- fn = WRAPPERS.get(name_or_fn, name_or_fn)
22
- wrapper = partial(fn, **kwargs)
23
-
24
- env = wrapper(env)
25
-
26
- return env
27
-
28
- # alias
29
-
30
- wrap_env = compose_env
File without changes