env-ssl-wrapper 0.0.4__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: env-ssl-wrapper
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: An RL environment wrapper for learning SSL in the background
5
5
  Project-URL: Homepage, https://pypi.org/project/env-ssl-wrapper/
6
6
  Project-URL: Repository, https://codeberg.org/lucidrains/env-ssl-wrapper
@@ -0,0 +1,99 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import is_tensor
6
+ from torch.utils._pytree import tree_map
7
+ from functools import partial
8
+
9
+ # helpers
10
+
11
def exists(v):
    """Return True when `v` is anything other than None."""
    if v is None:
        return False

    return True
13
+
14
def default(v, d):
    """Return `v`, falling back to `d` only when `v` is None."""
    if exists(v):
        return v

    return d
16
+
17
def is_float_dtype(t):
    """Report whether `t` carries floating point data.

    Handles torch tensors, NumPy arrays, Python floats and NumPy float
    scalars. Anything else (ints, bools, strings, ...) yields False.
    """
    if is_tensor(t):
        return t.is_floating_point()

    if isinstance(t, np.ndarray):
        return np.issubdtype(t.dtype, np.floating)

    # np.floating is included because NumPy scalars such as np.float32 and
    # np.float16 do not subclass the builtin float, so a bare
    # isinstance(t, float) check would treat them as non-float values
    return isinstance(t, (float, np.floating))
23
+
24
def copy(t):
    """Return an independent copy of a tensor or array.

    Values that are neither a torch tensor nor a NumPy array are handed
    back unchanged (the very same object).
    """
    if isinstance(t, np.ndarray):
        return np.copy(t)

    return t.clone() if is_tensor(t) else t
30
+
31
def clamp(t, min_val, max_val):
    """Limit `t` to the interval [min_val, max_val].

    Tensors are passed to `torch.clamp` and arrays to `np.clip`. Any other
    value is treated as a plain scalar and bounded with `min` / `max`.

    For the scalar case a bound given as None is skipped, so one-sided
    clamping works there as well instead of raising a TypeError when
    comparing against None.
    """
    if is_tensor(t):
        return torch.clamp(t, min_val, max_val)

    if isinstance(t, np.ndarray):
        return np.clip(t, min_val, max_val)

    # upper bound first, then lower bound - same order (and same argument
    # order to min / max) as the single expression max(min_val, min(t, max_val))
    if max_val is not None:
        t = min(t, max_val)

    if min_val is not None:
        t = max(min_val, t)

    return t
37
+
38
def rescale(
    t,
    from_range: tuple[float, float],
    to_range: tuple[float, float]
):
    """Linearly map `t` from `from_range` onto `to_range`.

    Each range is a (min, max) pair. The value is first expressed as a
    fraction of the source span and then stretched onto the target span.
    """
    src_lo, src_hi = from_range
    dst_lo, dst_hi = to_range

    fraction = (t - src_lo) / (src_hi - src_lo)
    return fraction * (dst_hi - dst_lo) + dst_lo
46
+
47
+ # wrapper
48
+
49
class ActionTransformWrapper:
    """Environment wrapper that rescales and clips actions before stepping.

    `transforms` is a dict, or a list of dicts, each optionally carrying:
      - 'rescale_from_to': a (from_range, to_range) pair handed to `rescale`
      - 'indices': which entries of the last axis the transform applies to

    When more than one transform is given and a transform has no 'indices'
    key, its position in the list is used as the index. `clip`, when given,
    is a (min, max) pair applied to the action after all transforms.

    Attributes not found on the wrapper are looked up on the wrapped env.
    """

    def __init__(
        self,
        env,
        transforms = None,
        clip = None
    ):
        self.env = env
        self.clip = clip

        # a single transform dict is shorthand for a one element list
        if isinstance(transforms, dict):
            transforms = [transforms]

        self.transforms = default(transforms, [])

    def __getattr__(self, name):
        # only invoked when regular attribute lookup has already failed
        if name.startswith('_'):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")

        # if `env` itself is missing (instance built without __init__),
        # reading self.env below would re-enter this method without end
        if name == 'env':
            raise AttributeError(f"'{type(self).__name__}' object has no attribute 'env'")

        return getattr(self.env, name)

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        """Transform every float leaf of `action`, then step the wrapped env."""

        def transform_action(t):
            # non float leaves (ints, bools, ...) pass through untouched
            if not is_float_dtype(t):
                return t

            # work on a copy so the caller's action is never modified in place
            t = copy(t)

            for ind, transform in enumerate(self.transforms):
                indices = transform.get('indices', ind if len(self.transforms) > 1 else None)

                if 'rescale_from_to' not in transform:
                    continue

                from_range, to_range = transform['rescale_from_to']
                fn = partial(rescale, from_range = from_range, to_range = to_range)

                if exists(indices):
                    # rescale only the selected entries of the last axis
                    t[..., indices] = fn(t[..., indices])
                else:
                    t = fn(t)

            if exists(self.clip):
                t = clamp(t, *self.clip)

            return t

        transformed_action = tree_map(transform_action, action)
        return self.env.step(transformed_action)
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "env-ssl-wrapper"
3
- version = "0.0.4"
3
+ version = "0.0.5"
4
4
  description = "An RL environment wrapper for learning SSL in the background"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,93 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import numpy as np
4
- import torch
5
- from torch import is_tensor
6
- from torch.utils._pytree import tree_map
7
-
8
- # helpers
9
-
10
def exists(v):
    """Tell whether `v` holds a value, i.e. is not None."""
    return not (v is None)
12
-
13
def default(v, d):
    """Pick `v` unless it is None, in which case `d` is returned."""
    if not exists(v):
        return d

    return v
15
-
16
def rescale(
    t,
    from_range: tuple[float, float],
    to_range: tuple[float, float]
):
    """Linearly remap `t` from the (min, max) pair `from_range` to `to_range`."""
    (lo_in, hi_in), (lo_out, hi_out) = from_range, to_range

    normalized = (t - lo_in) / (hi_in - lo_in)
    return normalized * (hi_out - lo_out) + lo_out
24
-
25
- # classes
26
-
27
class ActionTransformWrapper:
    """Environment wrapper that rescales and clips actions before stepping.

    `transforms` is a dict, or a list of dicts, each optionally carrying:
      - 'rescale_from_to': a (from_range, to_range) pair handed to `rescale`
      - 'indices': which entries of the last axis the transform applies to

    When more than one transform is given and a transform has no indices,
    its position in the list is used as the index. `clip`, when given, is a
    (min, max) pair applied to the action after all transforms.

    Attributes not found on the wrapper are looked up on the wrapped env.
    """

    def __init__(
        self,
        env,
        transforms = None,
        clip = None
    ):
        self.env = env

        # a single transform dict is shorthand for a one element list
        if isinstance(transforms, dict):
            transforms = [transforms]

        self.transforms = default(transforms, [])
        self.clip = clip

    def __getattr__(self, name):
        # only invoked when regular attribute lookup has already failed
        if name.startswith('_'):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")

        # if `env` itself is missing (instance built without __init__),
        # reading self.env below would re-enter this method without end
        if name == 'env':
            raise AttributeError(f"'{type(self).__name__}' object has no attribute 'env'")

        return getattr(self.env, name)

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        """Transform every float leaf of `action`, then step the wrapped env."""

        def transform_action(t):
            is_torch_float = is_tensor(t) and t.is_floating_point()
            is_np_float = isinstance(t, np.ndarray) and np.issubdtype(t.dtype, np.floating)
            is_scalar_float = isinstance(t, float)

            # non float leaves (ints, bools, ...) pass through untouched
            if not (is_torch_float or is_np_float or is_scalar_float):
                return t

            # work on a copy so the caller's action is never modified in place
            if is_tensor(t):
                t = t.clone()
            elif isinstance(t, np.ndarray):
                t = np.copy(t)

            for ind, transform in enumerate(self.transforms):
                indices = transform.get('indices')
                rescale_from_to = transform.get('rescale_from_to')

                # with several transforms, each one defaults to its own position
                if not exists(indices) and len(self.transforms) > 1:
                    indices = ind

                if exists(rescale_from_to):
                    from_range, to_range = rescale_from_to

                    if not exists(indices):
                        t = rescale(t, from_range, to_range)
                    else:
                        # rescale only the selected entries of the last axis
                        part = t[..., indices]
                        t[..., indices] = rescale(part, from_range, to_range)

            if exists(self.clip):
                min_clip, max_clip = self.clip

                if is_tensor(t):
                    t = torch.clamp(t, min_clip, max_clip)
                elif isinstance(t, np.ndarray):
                    t = np.clip(t, min_clip, max_clip)
                else:
                    t = max(min_clip, min(t, max_clip))

            return t

        action = tree_map(transform_action, action)
        return self.env.step(action)
File without changes