jax-envelope 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- envelope/__init__.py +54 -0
- envelope/compat/brax_envelope.py +5 -3
- envelope/compat/craftax_envelope.py +17 -2
- envelope/compat/gymnax_envelope.py +34 -7
- envelope/compat/jumanji_envelope.py +3 -2
- envelope/compat/kinetix_envelope.py +3 -2
- envelope/compat/mujoco_playground_envelope.py +1 -1
- envelope/compat/navix_envelope.py +1 -1
- envelope/environment.py +16 -9
- envelope/spaces.py +41 -21
- envelope/struct.py +10 -1
- envelope/typing.py +0 -16
- envelope/wrappers/__init__.py +36 -0
- envelope/wrappers/autoreset_wrapper.py +65 -21
- envelope/wrappers/clip_action_wrapper.py +27 -0
- envelope/wrappers/continuous_observation_wrapper.py +61 -0
- envelope/wrappers/episode_statistics_wrapper.py +29 -36
- envelope/wrappers/flatten_action_wrapper.py +75 -0
- envelope/wrappers/flatten_observation_wrapper.py +81 -0
- envelope/wrappers/normalization.py +1 -1
- envelope/wrappers/observation_normalization_wrapper.py +28 -16
- envelope/wrappers/pooled_init_vmap_wrapper.py +122 -0
- envelope/wrappers/state_injection_wrapper.py +18 -22
- envelope/wrappers/truncation_wrapper.py +18 -14
- envelope/wrappers/vmap_envs_wrapper.py +26 -21
- envelope/wrappers/vmap_wrapper.py +36 -21
- envelope/wrappers/wrapper.py +8 -8
- {jax_envelope-0.1.0.dist-info → jax_envelope-0.2.0.dist-info}/METADATA +3 -3
- jax_envelope-0.2.0.dist-info/RECORD +32 -0
- envelope/wrappers/timestep_wrapper.py +0 -22
- jax_envelope-0.1.0.dist-info/RECORD +0 -27
- {jax_envelope-0.1.0.dist-info → jax_envelope-0.2.0.dist-info}/WHEEL +0 -0
- {jax_envelope-0.1.0.dist-info → jax_envelope-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import override
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
|
|
7
|
+
from envelope.environment import Info, State
|
|
8
|
+
from envelope.spaces import BatchedSpace, Continuous, Discrete, Space, peel_batched
|
|
9
|
+
from envelope.typing import Key, PyTree
|
|
10
|
+
from envelope.wrappers.wrapper import Wrapper
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def to_float(obs: PyTree) -> PyTree:
    """Cast every leaf of an observation pytree to ``float32``.

    Args:
        obs: Observation pytree whose leaves are arrays or array-like
            values (JAX/NumPy arrays, Python scalars).

    Returns:
        A pytree with the same structure whose leaves are ``float32`` JAX
        arrays.
    """
    # ``jnp.asarray`` also accepts leaves that have no ``astype`` method
    # (for example Python ints or bools), which a method call would reject.
    return jax.tree.map(lambda leaf: jnp.asarray(leaf, dtype=jnp.float32), obs)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def to_continuous(space: Discrete | Continuous) -> Continuous:
    """Return a ``float32`` ``Continuous`` space covering ``space``.

    A ``Continuous`` input keeps its bounds, cast to ``float32``. A
    ``Discrete`` input with ``n`` values maps to the interval ``[0, n - 1]``.

    Raises:
        TypeError: If ``space`` is neither ``Discrete`` nor ``Continuous``.
    """
    if isinstance(space, Continuous):
        lower, upper = space.low, space.high
    elif isinstance(space, Discrete):
        counts = jnp.asarray(space.n)
        lower = jnp.zeros_like(counts, dtype=jnp.float32)
        upper = counts - 1
    else:
        raise TypeError(f"Expected Discrete or Continuous, got {type(space)}")

    return Continuous(
        low=jnp.asarray(lower, dtype=jnp.float32),
        high=jnp.asarray(upper, dtype=jnp.float32),
    )
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ContinuousObservationWrapper(Wrapper):
    """Wrapper that reports observations as ``float32`` values.

    ``init``, ``reset`` and ``step`` forward to the wrapped environment and
    replace ``info.obs`` with its ``to_float`` conversion. The advertised
    observation space is converted leaf by leaf with ``to_continuous``.
    """

    @override
    def init(self, key: Key) -> tuple[State, Info]:
        """Initialise the wrapped environment and cast its observation."""
        inner_state, inner_info = self.env.init(key)
        return inner_state, inner_info.update(obs=to_float(inner_info.obs))

    @override
    def reset(self, key: Key, state: State) -> tuple[State, Info]:
        """Reset the wrapped environment and cast its observation."""
        inner_state, inner_info = self.env.reset(key, state)
        return inner_state, inner_info.update(obs=to_float(inner_info.obs))

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        """Step the wrapped environment and cast its observation."""
        inner_state, inner_info = self.env.step(state, action)
        return inner_state, inner_info.update(obs=to_float(inner_info.obs))

    @override
    @cached_property
    def observation_space(self) -> Space:
        """Observation space with every leaf converted to ``Continuous``.

        The batch dimensions split off by ``peel_batched`` are re-applied,
        in the order returned, after the leaves have been converted.
        """
        batch_dims, base = peel_batched(self.env.observation_space)

        converted = jax.tree.map(
            to_continuous,
            base,
            # Stop at space objects so they are converted as a whole.
            is_leaf=lambda node: isinstance(node, (Discrete, Continuous)),
        )

        for batch_dim in batch_dims:
            converted = BatchedSpace(converted, batch_dim)
        return converted
|
|
@@ -1,47 +1,40 @@
|
|
|
1
|
-
from dataclasses import field
|
|
2
|
-
from envelope.wrappers import Wrapper
|
|
3
1
|
from typing import override
|
|
4
2
|
|
|
5
|
-
|
|
6
|
-
from envelope.typing import Key, PyTree, Array
|
|
3
|
+
import jax
|
|
7
4
|
|
|
5
|
+
from envelope.environment import Info, State
|
|
6
|
+
from envelope.struct import FrozenPyTreeNode, field
|
|
7
|
+
from envelope.typing import Key, PyTree
|
|
8
|
+
from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
8
9
|
|
|
9
|
-
class EpisodeStatisticsWrapper(Wrapper):
|
|
10
|
-
class StatisticsState(WappedState):
|
|
11
|
-
episode_reward: Array
|
|
12
|
-
episode_length: Array
|
|
13
|
-
_pointer: int = field(default=0)
|
|
14
10
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
info =
|
|
20
|
-
return state, info
|
|
11
|
+
class EpisodeStatistics(FrozenPyTreeNode):
    """Running reward total and step count published as ``info.stats``."""

    # Sum of the rewards added so far; a new instance starts at 0.
    reward: jax.Array = field(default=0)
    # Number of steps counted so far; a new instance starts at 0.
    length: jax.Array = field(default=0)
|
|
14
|
+
|
|
21
15
|
|
|
16
|
+
class EpisodeStatisticsWrapper(Wrapper):
|
|
17
|
+
class EpisodeStatisticsState(WrappedState):
|
|
18
|
+
stats: EpisodeStatistics = field(default=EpisodeStatistics())
|
|
22
19
|
|
|
23
20
|
@override
|
|
24
|
-
def
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
state, info
|
|
28
|
-
info =
|
|
29
|
-
return state, info
|
|
21
|
+
def init(self, key: Key) -> tuple[State, Info]:
    """Initialise the wrapped environment with default (zero) statistics.

    Returns:
        The wrapped state and the inner info extended with ``stats``.
    """
    env_state, info = self.env.init(key)
    wrapped = self.EpisodeStatisticsState(inner_state=env_state)
    return wrapped, info.update(stats=wrapped.stats)
|
|
30
25
|
|
|
31
26
|
@override
|
|
32
|
-
def
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
return
|
|
36
|
-
|
|
37
|
-
def _update_episode_statistics(self, info: Info) -> Info:
|
|
38
|
-
"""Update episode statistics in the info dictionary."""
|
|
39
|
-
if "episode_statistics" not in info:
|
|
40
|
-
info["episode_statistics"] = {
|
|
41
|
-
"reward": 0.0,
|
|
42
|
-
"length": 0,
|
|
43
|
-
}
|
|
44
|
-
info["episode_statistics"]["reward"] += info.get("reward", 0.0)
|
|
45
|
-
info["episode_statistics"]["length"] += 1
|
|
46
|
-
return info
|
|
27
|
+
def reset(self, key: Key, state: State) -> tuple[State, Info]:
    """Reset the wrapped environment and restart the episode statistics.

    Args:
        key: Random key forwarded to the wrapped environment's ``reset``.
        state: Current wrapped state; its ``inner_state`` is reset.

    Returns:
        The new wrapped state and the inner info extended with the zeroed
        ``stats``.
    """
    inner_state, info = self.env.reset(key, state.inner_state)
    # Zero the counters so the totals describe a single episode; without
    # this ``step`` keeps adding to the previous episode's totals.
    # ``zeros_like`` keeps the shape and dtype of the existing leaves.
    # NOTE(review): assumes a reset marks the start of a new episode -- confirm.
    stats = state.stats.replace(
        reward=jax.numpy.zeros_like(state.stats.reward),
        length=jax.numpy.zeros_like(state.stats.length),
    )
    state = state.replace(inner_state=inner_state, stats=stats)
    return state, info.update(stats=stats)
|
|
47
31
|
|
|
32
|
+
@override
|
|
33
|
+
def step(self, state: State, action: PyTree) -> tuple[State, Info]:
    """Step the wrapped environment and accumulate reward and length.

    Returns:
        The updated wrapped state and the inner info extended with the
        updated ``stats``.
    """
    next_inner, info = self.env.step(state.inner_state, action)
    previous = state.stats
    updated = previous.replace(
        reward=previous.reward + info.reward,
        length=previous.length + 1,
    )
    next_state = state.replace(inner_state=next_inner, stats=updated)
    return next_state, info.update(stats=updated)
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import override
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
|
|
7
|
+
from envelope.environment import Info, State
|
|
8
|
+
from envelope.spaces import (
|
|
9
|
+
BatchedSpace,
|
|
10
|
+
Continuous,
|
|
11
|
+
Discrete,
|
|
12
|
+
PyTreeSpace,
|
|
13
|
+
Space,
|
|
14
|
+
peel_batched,
|
|
15
|
+
)
|
|
16
|
+
from envelope.typing import PyTree
|
|
17
|
+
from envelope.wrappers.wrapper import Wrapper
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def flatten_space(space: PyTreeSpace | Continuous | Discrete):
    """Describe how the leaves of ``space`` map onto a flat vector.

    Args:
        space: Space whose ``shape`` attribute is a pytree of shape tuples
            (or a single shape tuple).

    Returns:
        A ``(treedef, shapes, dims)`` tuple: the tree structure of
        ``space.shape``, the list of leaf shapes, and the number of
        elements of each leaf as plain Python ints.
    """
    import math

    def is_shape(node):
        # Tuples containing only integers are shape tuples (leaves).
        # PyTreeSpace can only have tuples that contain at least a Space, so
        # tuples with only integers must be shape tuples from leaf spaces.
        return isinstance(node, tuple) and all(isinstance(i, int) for i in node)

    shapes, treedef = jax.tree.flatten(space.shape, is_leaf=is_shape)
    # Sizes are static, so compute them with Python arithmetic: they stay
    # concrete under ``jax.jit`` and ``()`` yields the integer 1.
    dims = [
        math.prod(shape) if isinstance(shape, tuple) else int(shape)
        for shape in shapes
    ]
    return treedef, shapes, dims
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def unflatten_x(x: jax.Array, treedef, shapes, dims):
    """Split a flat vector into leaves and rebuild the original pytree.

    Args:
        x: 1-D array holding all leaves back to back.
        treedef: Tree structure to rebuild.
        shapes: Shape of every leaf, in leaf order.
        dims: Number of elements of every leaf. Kept for interface
            compatibility; the split points are derived from ``shapes``.

    Returns:
        The pytree described by ``treedef`` with leaves reshaped to
        ``shapes``.
    """
    import itertools
    import math

    sizes = [
        math.prod(shape) if isinstance(shape, tuple) else int(shape)
        for shape in shapes
    ]
    # ``jnp.split`` needs concrete split points. Computing them with Python
    # ints (instead of ``jnp.cumsum``) keeps this usable under ``jax.jit``.
    # The last running total is dropped: the final piece is the remainder.
    offsets = list(itertools.accumulate(sizes))[:-1]
    pieces = jnp.split(x, offsets)
    leaves = [piece.reshape(shape) for piece, shape in zip(pieces, shapes)]
    return jax.tree.unflatten(treedef, leaves)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class FlattenActionWrapper(Wrapper):
    """Wrapper that accepts actions as one flat array.

    ``step`` rebuilds the wrapped environment's action pytree from the flat
    action; ``action_space`` reports the matching flat space.
    """

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        """Unflatten ``action`` and step the wrapped environment."""
        layout = flatten_space(self.env.action_space)
        return self.env.step(state, unflatten_x(action, *layout))

    @override
    @cached_property
    def action_space(self) -> Space:
        """Flat action space built from the leaves of the wrapped space.

        Raises:
            ValueError: If the leaf spaces are not all of one type, or that
                type is neither ``Continuous`` nor ``Discrete``.
        """
        batch_dims, base = peel_batched(self.env.action_space)

        leaves = jax.tree.leaves(
            base, is_leaf=lambda node: isinstance(node, (Continuous, Discrete))
        )
        leaf_cls = type(leaves[0])

        if any(not isinstance(leaf, leaf_cls) for leaf in leaves):
            raise ValueError("All spaces must be of the same type")

        def concat(attr):
            # Flatten the named bound of every leaf and join them in order.
            parts = [jnp.asarray(getattr(leaf, attr)).reshape(-1) for leaf in leaves]
            return jnp.concatenate(parts, axis=0)

        if leaf_cls == Continuous:
            flat = Continuous(low=concat("low"), high=concat("high"))
        elif leaf_cls == Discrete:
            flat = Discrete(n=concat("n"))
        else:
            raise ValueError(f"Unsupported space type: {leaf_cls}")

        # Re-apply the batch dimensions that peel_batched removed.
        for batch_dim in batch_dims:
            flat = BatchedSpace(flat, batch_dim)
        return flat
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import override
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
|
|
7
|
+
from envelope.environment import Info, State
|
|
8
|
+
from envelope.spaces import BatchedSpace, Continuous, Discrete, Space, peel_batched
|
|
9
|
+
from envelope.typing import Key, PyTree
|
|
10
|
+
from envelope.wrappers.wrapper import Wrapper
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def flatten_space(space: Space):
    """Describe how the leaves of ``space`` map onto a flat vector.

    Args:
        space: Space whose ``shape`` attribute is a pytree of shape tuples
            (or a single shape tuple).

    Returns:
        A ``(treedef, shapes, dims)`` tuple: the tree structure of
        ``space.shape``, the list of leaf shapes, and the number of
        elements of each leaf as plain Python ints.
    """
    import math

    def is_shape(node):
        # Tuples containing only integers are shape tuples (leaves).
        # PyTreeSpace can only have tuples that contain at least a Space, so
        # tuples with only integers must be shape tuples from leaf spaces.
        return isinstance(node, tuple) and all(isinstance(i, int) for i in node)

    shapes, treedef = jax.tree.flatten(space.shape, is_leaf=is_shape)
    # Sizes are static, so compute them with Python arithmetic: they stay
    # concrete under ``jax.jit`` and ``()`` yields the integer 1.
    dims = [
        math.prod(shape) if isinstance(shape, tuple) else int(shape)
        for shape in shapes
    ]
    return treedef, shapes, dims
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def flatten_x(x: PyTree):
    """Concatenate all leaves of ``x`` into a single 1-D array.

    Args:
        x: Pytree of arrays or array-like leaves.

    Returns:
        A 1-D array with every leaf flattened and joined in leaf order. A
        pytree without leaves yields an empty array.
    """
    flat_leaves = [jnp.asarray(leaf).reshape(-1) for leaf in jax.tree.leaves(x)]
    if not flat_leaves:
        # ``jnp.concatenate`` rejects an empty list; an empty pytree
        # flattens to an empty vector.
        return jnp.zeros((0,))
    return jnp.concatenate(flat_leaves, axis=0)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class FlattenObservationWrapper(Wrapper):
    """Wrapper that returns observations as one flat array.

    ``init``, ``reset`` and ``step`` pass the wrapped environment's
    observation through ``flatten_x``; ``observation_space`` reports the
    matching flat space.

    NOTE(review): ``flatten_x`` collapses every axis of every leaf, while
    ``observation_space`` re-applies the batch dimensions found by
    ``peel_batched`` -- confirm how batched observations are meant to be
    handled.
    """

    @override
    def init(self, key: Key) -> tuple[State, Info]:
        """Initialise the wrapped environment and flatten its observation."""
        inner_state, inner_info = self.env.init(key)
        return inner_state, inner_info.update(obs=flatten_x(inner_info.obs))

    @override
    def reset(self, key: Key, state: State) -> tuple[State, Info]:
        """Reset the wrapped environment and flatten its observation."""
        inner_state, inner_info = self.env.reset(key, state)
        return inner_state, inner_info.update(obs=flatten_x(inner_info.obs))

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        """Step the wrapped environment and flatten its observation."""
        inner_state, inner_info = self.env.step(state, action)
        return inner_state, inner_info.update(obs=flatten_x(inner_info.obs))

    @override
    @cached_property
    def observation_space(self) -> Space:
        """Flat observation space built from the leaves of the wrapped space.

        Raises:
            ValueError: If the leaf spaces are not all of one type, or that
                type is neither ``Continuous`` nor ``Discrete``.
        """
        batch_dims, base = peel_batched(self.env.observation_space)

        leaves = jax.tree.leaves(
            base, is_leaf=lambda node: isinstance(node, (Continuous, Discrete))
        )
        leaf_cls = type(leaves[0])

        if any(not isinstance(leaf, leaf_cls) for leaf in leaves):
            raise ValueError("All spaces must be of the same type")

        def concat(attr):
            # Flatten the named bound of every leaf and join them in order.
            parts = [jnp.asarray(getattr(leaf, attr)).reshape(-1) for leaf in leaves]
            return jnp.concatenate(parts, axis=0)

        if leaf_cls == Continuous:
            flat = Continuous(low=concat("low"), high=concat("high"))
        elif leaf_cls == Discrete:
            flat = Discrete(n=concat("n"))
        else:
            raise ValueError(f"Unsupported space type: {leaf_cls}")

        # Re-apply the batch dimensions that peel_batched removed.
        for batch_dim in batch_dims:
            flat = BatchedSpace(flat, batch_dim)
        return flat
|
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
from
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import cast, override
|
|
2
3
|
|
|
3
4
|
import jax
|
|
4
5
|
from jax import numpy as jnp
|
|
5
6
|
|
|
6
7
|
from envelope.environment import Info
|
|
7
|
-
from envelope.spaces import BatchedSpace, PyTreeSpace, Space
|
|
8
|
+
from envelope.spaces import BatchedSpace, Continuous, Discrete, PyTreeSpace, Space
|
|
8
9
|
from envelope.struct import field, static_field
|
|
9
10
|
from envelope.typing import Key, PyTree
|
|
10
11
|
from envelope.wrappers.normalization import RunningMeanVar, update_rmv
|
|
@@ -36,7 +37,7 @@ class ObservationNormalizationWrapper(Wrapper):
|
|
|
36
37
|
mean = jax.tree.map(zeros, self.stats_spec)
|
|
37
38
|
var = jax.tree.map(ones, self.stats_spec)
|
|
38
39
|
|
|
39
|
-
return RunningMeanVar(mean=mean, var=var, count=0)
|
|
40
|
+
return RunningMeanVar(mean=mean, var=var, count=jnp.asarray(0))
|
|
40
41
|
|
|
41
42
|
def _normalize_obs(self, obs: PyTree, rmv: RunningMeanVar) -> PyTree:
|
|
42
43
|
def norm_leaf(x, mean, std, spec):
|
|
@@ -66,29 +67,40 @@ class ObservationNormalizationWrapper(Wrapper):
|
|
|
66
67
|
return state, info
|
|
67
68
|
|
|
68
69
|
@override
|
|
69
|
-
def
|
|
70
|
-
|
|
71
|
-
) -> tuple[WrappedState, Info]:
|
|
72
|
-
inner_state = None
|
|
70
|
+
def init(self, key: Key) -> tuple[WrappedState, Info]:
    """Initialise the wrapped environment with fresh running statistics.

    Returns:
        The result of ``_normalize_and_update`` applied to the new wrapped
        state and the inner info.
    """
    env_state, info = self.env.init(key)
    wrapped = self.ObservationNormalizationState(
        inner_state=env_state, rmv_state=self._init_rmv_state()
    )
    return self._normalize_and_update(wrapped, info)
|
|
83
77
|
|
|
84
78
|
@override
|
|
85
|
-
def
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
79
|
+
def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
    """Reset the wrapped environment while keeping the running statistics.

    Returns:
        The result of ``_normalize_and_update`` applied to the new wrapped
        state and the inner info.
    """
    env_state, info = self.env.reset(key, state.inner_state)
    # The running statistics are carried over unchanged across resets.
    wrapped = self.ObservationNormalizationState(
        inner_state=env_state, rmv_state=state.rmv_state
    )
    return self._normalize_and_update(wrapped, info)
|
|
86
|
+
|
|
87
|
+
@override
|
|
88
|
+
def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
    """Step the wrapped environment and normalise the resulting info.

    Returns:
        The result of ``_normalize_and_update`` applied to the state with
        its ``inner_state`` replaced and the inner info.
    """
    env_state, info = self.env.step(state.inner_state, action)
    return self._normalize_and_update(state.replace(inner_state=env_state), info)
|
|
91
92
|
|
|
93
|
+
@override
|
|
94
|
+
@cached_property
|
|
95
|
+
def observation_space(self) -> Space:
    """Observation space with every leaf replaced by an unbounded one.

    Each ``Discrete`` or ``Continuous`` leaf of the wrapped space becomes a
    ``Continuous`` space of the same shape with bounds ``-inf`` and ``inf``.
    """

    def unbounded(leaf: Continuous | Discrete) -> Continuous:
        return Continuous.from_shape(low=-jnp.inf, high=jnp.inf, shape=leaf.shape)

    return jax.tree.map(
        unbounded,
        self.env.observation_space,
        # Stop at space objects so they are replaced as a whole.
        is_leaf=lambda node: isinstance(node, (Discrete, Continuous)),
    )
|
|
103
|
+
|
|
92
104
|
|
|
93
105
|
def _infer_stats_spec(space: Space) -> PyTree:
|
|
94
106
|
"""
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import override
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
|
|
7
|
+
from envelope import spaces
|
|
8
|
+
from envelope.environment import Info
|
|
9
|
+
from envelope.struct import field
|
|
10
|
+
from envelope.typing import Key, PyTree
|
|
11
|
+
from envelope.wrappers.vmap_wrapper import _split_or_keep_key
|
|
12
|
+
from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class PooledInitVmapWrapper(Wrapper):
    """Run ``batch_size`` copies of an environment with pooled re-initialisation.

    Every ``step`` creates ``pool_size`` fresh initial states with the
    wrapped environment's ``init`` and gives each finished copy (terminated
    or truncated) a randomly chosen one of them. The info recorded at each
    copy's most recent finished step is exposed as ``info.final``.
    """

    batch_size: int = field(kw_only=True)
    pool_size: int = field(kw_only=True)

    class PooledInitVmapState(WrappedState):
        # Key from which the next step derives its pool and assignment keys.
        init_key: Key = field()
        # Info recorded at each copy's latest finished step; NaN
        # placeholders until a copy has finished an episode.
        last_final: Info = field()

    @override
    def init(self, key: Key) -> tuple[WrappedState, Info]:
        """Initialise all copies and fill ``final`` with NaN placeholders."""
        keys = _split_or_keep_key(key, self.batch_size + 1)
        key_next, keys_pool = keys[0], keys[1:]
        inner_state, info = jax.vmap(self.env.init)(keys_pool)
        # Same structure as ``info``, every leaf replaced by float32 NaNs.
        pholder_info = jax.tree.map(
            lambda x: jnp.full_like(x, jnp.nan, dtype=jnp.float32), info
        )
        state = self.PooledInitVmapState(
            inner_state=inner_state,
            init_key=key_next,
            last_final=pholder_info,
        )
        return state, info.update(final=pholder_info)

    @override
    def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
        """Reset every copy with the wrapped environment's ``reset``.

        No pool is used here. Pooling would require initialising a pool of
        unwrapped environments and injecting their states into the stack of
        wrapped states, which the current wrapped-state structure does not
        allow without treating those states as non-opaque (for example by
        recursively checking ``hasattr(state, "inner_state")``). There is
        currently no use case that needs to carry state across episodes
        before vmapping, so pooled reset is left for later.
        """
        keys = _split_or_keep_key(key, self.batch_size + 1)
        key_next, keys_pool = keys[0], keys[1:]
        inner_state, info = jax.vmap(self.env.reset)(keys_pool, state.inner_state)
        state = state.replace(inner_state=inner_state, init_key=key_next)
        return state, info.update(final=state.last_final)

    @override
    def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
        """Step every copy and re-initialise the finished ones from a pool.

        Returns:
            The new state and the step info in which, for finished copies,
            ``obs`` is the observation of the assigned initial state.
            ``final`` holds the info of each copy's latest finished step.
        """
        inner_state, info = jax.vmap(self.env.step)(state.inner_state, action)
        done = info.terminated | info.truncated

        def select(if_done, otherwise):
            # Per copy, take ``if_done`` where the episode finished.
            return jax.tree.map(
                lambda a, b: jax.vmap(jnp.where)(done, a, b), if_done, otherwise
            )

        # Compute pool_size fresh init states.
        key_pool = jax.random.fold_in(state.init_key, 0)
        next_init_key = jax.random.fold_in(state.init_key, 1)
        keys_pool = jax.random.split(key_pool, self.pool_size)
        inner_states_pool, infos_pool = jax.vmap(self.env.init)(keys_pool)

        # Randomly assign each copy an init state from the pool.
        key_idxs = jax.random.fold_in(state.init_key, 2)
        pool_idxs = jax.random.randint(key_idxs, (self.batch_size,), 0, self.pool_size)

        # Expand the pool to batch_size via indexing. Only the observation
        # of the pooled info is used, so only that part is gathered.
        mapped_init_state = jax.tree.map(lambda x: x[pool_idxs], inner_states_pool)
        mapped_init_obs = jax.tree.map(lambda x: x[pool_idxs], infos_pool.obs)

        # Finished copies continue from their assigned init state.
        next_inner_state = select(mapped_init_state, inner_state)
        # Finished copies store this step's info; the others keep the
        # previous one. The same value is stored in the state and returned
        # as ``final``.
        last_final = select(info, state.last_final)
        # Finished copies report the observation of their new init state.
        next_obs = select(mapped_init_obs, info.obs)

        final_info = info.update(obs=next_obs, final=last_final)
        state = state.replace(
            inner_state=next_inner_state,
            init_key=next_init_key,
            last_final=last_final,
        )
        return state, final_info

    @override
    @cached_property
    def observation_space(self) -> spaces.Space:
        """Wrapped observation space batched over ``batch_size``."""
        return spaces.BatchedSpace(
            space=self.env.observation_space, batch_size=self.batch_size
        )

    @override
    @cached_property
    def action_space(self) -> spaces.Space:
        """Wrapped action space batched over ``batch_size``."""
        return spaces.BatchedSpace(
            space=self.env.action_space, batch_size=self.batch_size
        )
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import override
|
|
2
|
+
|
|
1
3
|
from envelope.environment import Info, InfoContainer
|
|
2
4
|
from envelope.struct import field
|
|
3
5
|
from envelope.typing import Key, PyTree
|
|
@@ -12,7 +14,7 @@ class StateInjectionWrapper(Wrapper):
|
|
|
12
14
|
|
|
13
15
|
Usage:
|
|
14
16
|
env = AutoResetWrapper(StateInjectionWrapper(env=base_env))
|
|
15
|
-
state, info = env.
|
|
17
|
+
state, info = env.init(key)
|
|
16
18
|
|
|
17
19
|
for outer_iter in range(num_outer_iters):
|
|
18
20
|
# Sample a new task and set it as the reset state
|
|
@@ -60,32 +62,26 @@ class StateInjectionWrapper(Wrapper):
|
|
|
60
62
|
|
|
61
63
|
return update_injected(state)
|
|
62
64
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
state = self.InjectedState(inner_state=None)
|
|
69
|
-
|
|
70
|
-
# If no reset state is set, reset wrapped environment
|
|
71
|
-
if state.reset_state is None and state.reset_obs is None:
|
|
72
|
-
inner_state, info = self.env.reset(key, state=state.inner_state, **kwargs)
|
|
65
|
+
@override
|
|
66
|
+
def init(self, key: Key) -> tuple[WrappedState, Info]:
    """Initialise the wrapped environment and wrap its state.

    Returns:
        An ``InjectedState`` holding the inner state, and the inner info
        unchanged.
    """
    env_state, info = self.env.init(key)
    return self.InjectedState(inner_state=env_state), info
|
|
73
70
|
|
|
74
|
-
|
|
75
|
-
|
|
71
|
+
@override
|
|
72
|
+
def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
    """Reset to the injected state if one is set, else reset the inner env.

    Raises:
        ValueError: If only one of ``reset_state`` and ``reset_obs`` is set.
    """
    has_state = state.reset_state is not None
    has_obs = state.reset_obs is not None
    if has_state != has_obs:
        raise ValueError("State must set both reset_state and reset_obs or neither")

    if has_state:
        # An injected reset state replaces the inner environment's reset.
        inner_state = state.reset_state
        info = InfoContainer(obs=state.reset_obs, reward=0.0, terminated=False)
    else:
        inner_state, info = self.env.reset(key, state.inner_state)

    return state.replace(inner_state=inner_state), info
|
|
86
83
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
inner_state, info = self.env.step(state.inner_state, action, **kwargs)
|
|
84
|
+
@override
|
|
85
|
+
def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
    """Step the wrapped environment and store its new state.

    Returns:
        The state with ``inner_state`` replaced, and the inner info
        unchanged.
    """
    env_state, info = self.env.step(state.inner_state, action)
    next_state = state.replace(inner_state=env_state)
    return next_state, info
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import override
|
|
2
|
+
|
|
1
3
|
import jax.numpy as jnp
|
|
2
4
|
|
|
3
5
|
from envelope.environment import Info
|
|
@@ -12,20 +14,22 @@ class TruncationWrapper(Wrapper):
|
|
|
12
14
|
class TruncationState(WrappedState):
|
|
13
15
|
steps: jnp.ndarray | int = field(default=0)
|
|
14
16
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
inner_state, info = self.env.reset(key)
|
|
17
|
+
@override
|
|
18
|
+
def init(self, key: Key) -> tuple[WrappedState, Info]:
    """Initialise the wrapped environment with a step counter of zero.

    Returns:
        The wrapped state and the inner info extended with ``truncated``,
        which is true only when ``max_steps`` is not positive.
    """
    env_state, info = self.env.init(key)
    wrapped = self.TruncationState(inner_state=env_state, steps=0)
    already_truncated = self.max_steps <= 0
    return wrapped, info.update(truncated=already_truncated)
|
|
21
22
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
23
|
+
@override
|
|
24
|
+
def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
    """Reset the wrapped environment and set the step counter to zero.

    Returns:
        The updated state and the inner info extended with ``truncated``,
        which is true only when ``max_steps`` is not positive.
    """
    env_state, info = self.env.reset(key, state.inner_state)
    next_state = state.replace(inner_state=env_state, steps=0)
    already_truncated = self.max_steps <= 0
    return next_state, info.update(truncated=already_truncated)
|
|
28
|
+
|
|
29
|
+
@override
|
|
30
|
+
def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
    """Step the wrapped environment and flag truncation at ``max_steps``.

    Returns:
        A new ``TruncationState`` with the incremented step counter, and
        the inner info extended with ``truncated``, which is true once the
        counter has reached ``max_steps``.
    """
    env_state, info = self.env.step(state.inner_state, action)
    elapsed = state.steps + 1
    next_state = self.TruncationState(inner_state=env_state, steps=elapsed)
    reached_limit = jnp.asarray(elapsed) >= self.max_steps
    return next_state, info.update(truncated=reached_limit)
|