dfa-gym 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dfa_gym/__init__.py +5 -15
- dfa_gym/dfa_bisim_env.py +121 -0
- dfa_gym/dfa_wrapper.py +185 -52
- dfa_gym/env.py +168 -0
- dfa_gym/maps/2buttons_2agents.pdf +0 -0
- dfa_gym/maps/2rooms_2agents.pdf +0 -0
- dfa_gym/maps/4buttons_4agents.pdf +0 -0
- dfa_gym/maps/4rooms_4agents.pdf +0 -0
- dfa_gym/robot.png +0 -0
- dfa_gym/spaces.py +156 -0
- dfa_gym/token_env.py +571 -0
- dfa_gym/utils.py +266 -0
- dfa_gym-0.2.0.dist-info/METADATA +93 -0
- dfa_gym-0.2.0.dist-info/RECORD +16 -0
- {dfa_gym-0.1.0.dist-info → dfa_gym-0.2.0.dist-info}/WHEEL +1 -1
- dfa_gym/dfa_env.py +0 -45
- dfa_gym-0.1.0.dist-info/METADATA +0 -11
- dfa_gym-0.1.0.dist-info/RECORD +0 -7
- {dfa_gym-0.1.0.dist-info → dfa_gym-0.2.0.dist-info}/licenses/LICENSE +0 -0
dfa_gym/__init__.py
CHANGED
```diff
@@ -1,16 +1,6 @@
-from dfa_gym.
+from dfa_gym.token_env import *
+from dfa_gym.dfa_bisim_env import *
 from dfa_gym.dfa_wrapper import *
-
-from
-from
-
-register(
-    id='DFAEnv-v0',
-    entry_point='dfa_gym.dfa_env:DFAEnv',
-    kwargs = {"sampler": RADSampler(n_tokens=12), "timeout": 75}
-)
-
-register(
-    id='DFAEnv-v1',
-    entry_point='dfa_gym.dfa_env:DFAEnv'
-)
+from dfa_gym.env import *
+from dfa_gym.spaces import *
+from dfa_gym.utils import *
```

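With 0.2.0 the `register(...)` entries for `DFAEnv-v0`/`DFAEnv-v1` are gone (the `dfa_gym/dfa_env.py` module they pointed at is removed), so environments are constructed directly from the package's exports. A minimal sketch, assuming `DFABisimEnv` is re-exported by the star imports above:

```python
# Hedged sketch, not part of the package: build the environment directly
# instead of going through gym.make("DFAEnv-v0") as in 0.1.0.
import jax
from dfa_gym import DFABisimEnv

env = DFABisimEnv()   # defaults shown in the diff: RADSampler(), max_steps_in_episode=100
obs, state = env.reset(jax.random.PRNGKey(0))
```
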
dfa_gym/dfa_bisim_env.py
ADDED
@@ -0,0 +1,121 @@

```python
import jax
import dfax
import chex
import jax.numpy as jnp
from flax import struct
from functools import partial
from typing import Tuple, Dict
from dfa_gym import spaces
from dfa_gym.env import MultiAgentEnv, State
from dfax.samplers import DFASampler, RADSampler


@struct.dataclass
class DFABisimState(State):
    dfa_l: dfax.DFAx
    dfa_r: dfax.DFAx
    time: int

class DFABisimEnv(MultiAgentEnv):

    def __init__(
        self,
        sampler: DFASampler = RADSampler(),
        max_steps_in_episode: int = 100
    ) -> None:
        super().__init__(num_agents=1)
        self.n_agents = self.num_agents
        self.sampler = sampler
        self.max_steps_in_episode = max_steps_in_episode

        self.agents = [f"agent_{i}" for i in range(self.n_agents)]

        self.action_spaces = {
            agent: spaces.Discrete(self.sampler.n_tokens)
            for agent in self.agents
        }
        max_dfa_size = self.sampler.max_size
        n_tokens = self.sampler.n_tokens
        self.observation_spaces = {
            agent: spaces.Dict({
                "graph_l": spaces.Dict({
                    "node_features": spaces.Box(low=0, high=1, shape=(max_dfa_size, 4), dtype=jnp.uint16),
                    "edge_features": spaces.Box(low=0, high=1, shape=(max_dfa_size*max_dfa_size, n_tokens + 8), dtype=jnp.uint16),
                    "edge_index": spaces.Box(low=0, high=max_dfa_size, shape=(2, max_dfa_size*max_dfa_size), dtype=jnp.uint16),
                    "current_state": spaces.Box(low=0, high=max_dfa_size, shape=(1,), dtype=jnp.uint16),
                    "n_states": spaces.Box(low=0, high=max_dfa_size, shape=(max_dfa_size,), dtype=jnp.uint16)
                }),
                "graph_r": spaces.Dict({
                    "node_features": spaces.Box(low=0, high=1, shape=(max_dfa_size, 4), dtype=jnp.uint16),
                    "edge_features": spaces.Box(low=0, high=1, shape=(max_dfa_size*max_dfa_size, n_tokens + 8), dtype=jnp.uint16),
                    "edge_index": spaces.Box(low=0, high=max_dfa_size, shape=(2, max_dfa_size*max_dfa_size), dtype=jnp.uint16),
                    "current_state": spaces.Box(low=0, high=max_dfa_size, shape=(1,), dtype=jnp.uint16),
                    "n_states": spaces.Box(low=0, high=max_dfa_size, shape=(max_dfa_size,), dtype=jnp.uint16)
                })
            })
            for agent in self.agents
        }

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self,
        key: chex.PRNGKey
    ) -> Tuple[Dict[str, chex.Array], DFABisimState]:

        def cond_fn(carry):
            _, dfa_l, dfa_r = carry
            return dfa_l == dfa_r

        def body_fn(carry):
            key, _, _ = carry
            key, kl, kr = jax.random.split(key, 3)
            dfa_l = self.sampler.sample(kl)
            dfa_r = self.sampler.sample(kr)
            return (key, dfa_l, dfa_r)

        init_carry = body_fn((key, None, None))
        _, dfa_l, dfa_r = jax.lax.while_loop(cond_fn, body_fn, init_carry)

        state = DFABisimState(dfa_l=dfa_l, dfa_r=dfa_r, time=0)
        obs = self.get_obs(state=state)

        return {self.agents[0]: obs}, state

    @partial(jax.jit, static_argnums=(0,))
    def step_env(
        self,
        key: chex.PRNGKey,
        state: DFABisimState,
        action: int
    ) -> Tuple[Dict[str, chex.Array], DFABisimState, Dict[str, float], Dict[str, bool], Dict]:

        dfa_l = state.dfa_l.advance(action[self.agents[0]]).minimize()
        dfa_r = state.dfa_r.advance(action[self.agents[0]]).minimize()

        reward_l = dfa_l.reward(binary=False)
        reward_r = dfa_r.reward(binary=False)
        reward = reward_l - reward_r

        new_state = DFABisimState(
            dfa_l=dfa_l,
            dfa_r=dfa_r,
            time=state.time+1
        )

        done = jnp.logical_or(jnp.logical_or(dfa_l.n_states <= 1, dfa_r.n_states <= 1), new_state.time >= self.max_steps_in_episode)

        obs = self.get_obs(state=new_state)
        info = {}

        return {self.agents[0]: obs}, new_state, {self.agents[0]: reward}, {self.agents[0]: done, "__all__": done}, info

    @partial(jax.jit, static_argnums=(0,))
    def get_obs(
        self,
        state: DFABisimState
    ) -> Dict[str, chex.Array]:
        return {
            "graph_l": state.dfa_l.to_graph(),
            "graph_r": state.dfa_r.to_graph()
        }
```

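`DFABisimEnv` samples two distinct random DFAs, feeds the single agent's action (a token) to both, and pays out the difference between their non-binary rewards each step; the episode ends once either DFA minimizes to a single state or the step budget is exhausted. A short rollout sketch, assuming the `dfax` sampler/DFA methods used above (`sample`, `advance`, `minimize`, `reward`, `to_graph`) behave as shown:

```python
# Hedged sketch, not part of the package: a random rollout of DFABisimEnv using
# the auto-resetting MultiAgentEnv.step defined in dfa_gym/env.py below.
import jax
from dfa_gym import DFABisimEnv

env = DFABisimEnv()
key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
obs, state = env.reset(reset_key)

for _ in range(10):
    key, act_key, step_key = jax.random.split(key, 3)
    # each action is one token drawn uniformly from the sampler's alphabet
    actions = {"agent_0": jax.random.randint(act_key, (), 0, env.sampler.n_tokens)}
    obs, state, rewards, dones, info = env.step(step_key, state, actions)
```
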
dfa_gym/dfa_wrapper.py
CHANGED
```diff
@@ -1,57 +1,190 @@
-import
-import
+import jax
+import dfax
+import chex
+import jax.numpy as jnp
+from flax import struct
+from dfa_gym import spaces
+from functools import partial
+from typing import Tuple, Dict, Callable
+from dfax.utils import list2batch, batch2graph
+from dfa_gym.env import MultiAgentEnv, State
+from dfax.samplers import DFASampler, RADSampler
 
-from typing import Any
 
+@struct.dataclass
+class DFAWrapperState(State):
+    dfas: Dict[str, dfax.DFAx]
+    init_dfas: Dict[str, dfax.DFAx]
+    env_obs: chex.Array
+    env_state: State
+
+class DFAWrapper(MultiAgentEnv):
 
-class DFAWrapper(gym.Wrapper):
     def __init__(
         self,
+        env: MultiAgentEnv,
+        gamma: float | None = None,
+        sampler: DFASampler = RADSampler(),
+        binary_reward: bool = True,
+        progress: bool = True,
+    ) -> None:
+        super().__init__(num_agents=env.num_agents)
+        self.env = env
+        self.gamma = gamma
+        self.sampler = sampler
+        self.binary_reward = binary_reward
+        self.progress = progress
+
+        assert self.sampler.n_tokens == self.env.n_tokens
+
+        self.agents = [f"agent_{i}" for i in range(self.num_agents)]
+
+        self.action_spaces = {
+            agent: self.env.action_space(agent)
+            for agent in self.agents
+        }
+        max_dfa_size = self.sampler.max_size
+        n_tokens = self.sampler.n_tokens
+        self.observation_spaces = {
+            agent: spaces.Dict({
+                "_id": spaces.Discrete(self.num_agents),
+                "obs": self.env.observation_space(agent),
+                "dfa": spaces.Dict({
+                    "node_features": spaces.Box(low=0, high=1, shape=(max_dfa_size*self.num_agents, 4), dtype=jnp.float32),
+                    "edge_features": spaces.Box(low=0, high=1, shape=(max_dfa_size*self.num_agents*max_dfa_size*self.num_agents, n_tokens + 8), dtype=jnp.float32),
+                    "edge_index": spaces.Box(low=0, high=max_dfa_size*self.num_agents, shape=(2, max_dfa_size*self.num_agents*max_dfa_size*self.num_agents), dtype=jnp.int32),
+                    "current_state": spaces.Box(low=0, high=max_dfa_size*self.num_agents, shape=(self.num_agents,), dtype=jnp.int32),
+                    "n_states": spaces.Box(low=0, high=max_dfa_size*self.num_agents, shape=(max_dfa_size*self.num_agents,), dtype=jnp.int32)
+                }),
+            })
+            for agent in self.agents
+        }
+
+    @partial(jax.jit, static_argnums=(0,))
+    def reset(
+        self,
+        key: chex.PRNGKey
+    ) -> Tuple[Dict[str, chex.Array], DFAWrapperState]:
+        keys = jax.random.split(key, 4 + self.num_agents)
+
+        env_obs, env_state = self.env.reset(keys[1])
+
+        n_trivial = jax.random.choice(keys[2], self.num_agents)
+        mask = jax.random.permutation(keys[3], jnp.arange(self.num_agents) < n_trivial)
+
+        def sample_dfa(dfa_key, sample_trivial):
+            return jax.tree_util.tree_map(
+                lambda t, s: jnp.where(sample_trivial, t, s),
+                self.sampler.trivial(True),
+                self.sampler.sample(dfa_key)
+            )
+
+        dfas_tree = jax.vmap(sample_dfa)(keys[4:], mask)
+
+        dfas = {
+            agent: jax.tree_util.tree_map(lambda x: x[i], dfas_tree)
+            for i, agent in enumerate(self.agents)
+        }
+
+        state = DFAWrapperState(
+            dfas=dfas,
+            init_dfas={agent: dfas[agent] for agent in self.agents},
+            env_obs=env_obs,
+            env_state=env_state
+        )
+        obs = self.get_obs(state=state)
+
+        return obs, state
+
+    @partial(jax.jit, static_argnums=(0,))
+    def step_env(
+        self,
+        key: chex.PRNGKey,
+        state: DFAWrapperState,
+        action: int,
+    ) -> Tuple[Dict[str, chex.Array], DFAWrapperState, Dict[str, float], Dict[str, bool], Dict]:
+
+        env_obs, env_state, env_rewards, env_dones, env_info = self.env.step_env(key, state.env_state, action)
+
+        symbols = self.env.label_f(env_state)
+
+        dfas = {
+            agent: state.dfas[agent].advance(symbols[agent]).minimize()
+            for agent in self.agents
+        }
+
+        dones = {
+            agent: jnp.logical_or(env_dones[agent], dfas[agent].n_states <= 1)
+            for agent in self.agents
+        }
+        _dones = jnp.array([dones[agent] for agent in self.agents])
+        dones.update({"__all__": jnp.all(_dones)})
+
+        dfa_rewards_min = jnp.min(jnp.array([dfas[agent].reward(binary=self.binary_reward) for agent in self.agents]))
+        rewards = {
+            agent: jax.lax.cond(
+                dones["__all__"],
+                lambda _: env_rewards[agent] + dfa_rewards_min,
+                lambda _: env_rewards[agent],
+                operand=None
+            )
+            for agent in self.agents
+        }
+
+        if self.gamma is not None:
+            rewards = {
+                agent: rewards[agent] + self.gamma * dfas[agent].reward(binary=self.binary_reward) - state.dfas[agent].reward(binary=self.binary_reward)
+                for agent in self.agents
+            }
+
+        infos = {}
+
+        state = DFAWrapperState(
+            dfas=dfas,
+            init_dfas=state.init_dfas,
+            env_obs=env_obs,
+            env_state=env_state
+        )
+
+        obs = self.get_obs(state=state)
+
+        return obs, state, rewards, dones, infos
+
+    @partial(jax.jit, static_argnums=(0,))
+    def get_obs(
+        self,
+        state: DFAWrapperState
+    ) -> Dict[str, chex.Array]:
+        if self.progress:
+            dfas = batch2graph(
+                list2batch(
+                    [state.dfas[agent].to_graph() for agent in self.agents]
+                )
+            )
+        else:
+            dfas = batch2graph(
+                list2batch(
+                    [state.init_dfas[agent].to_graph() for agent in self.agents]
+                )
+            )
+        return {
+            agent: {
+                "_id": i,
+                "obs": state.env_obs[agent],
+                "dfa": dfas
+            }
+            for i, agent in enumerate(self.agents)
+        }
+
+    def render(self, state: DFAWrapperState):
+        out = ""
+        for agent in self.agents:
+            out += "****\n"
+            out += f"{agent}'s DFA:\n"
+            if self.progress:
+                out += f"{state.dfas[agent]}\n"
+            else:
+                out += f"{state.init_dfas[agent]}\n"
+        self.env.render(state.env_state)
+        print(out)
```

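`DFAWrapper` is no longer a `gym.Wrapper`: it now sits on top of any `MultiAgentEnv`. At reset it samples one DFA task per agent (a random number of agents receive the trivially solved DFA), at each step it advances every agent's DFA with the symbols returned by the base environment's `label_f`, marks an agent done once its DFA collapses to a single state, adds the minimum DFA reward to every agent's reward when all are done, and, if `gamma` is set, applies a potential-style shaping term. A usage sketch; `TokenEnv` and its constructor are assumptions here (only the module name `token_env.py` appears in this release), and the base environment just has to expose `n_tokens`, `label_f`, `reset`, `step_env`, `action_space` and `observation_space`:

```python
# Hedged sketch, not part of the package: wrap a base environment in DFAWrapper.
# TokenEnv and its constructor arguments are assumed, not confirmed by this diff.
import jax
from dfa_gym import DFAWrapper, TokenEnv
from dfax.samplers import RADSampler

base_env = TokenEnv()                                 # assumed constructor
env = DFAWrapper(
    env=base_env,
    sampler=RADSampler(n_tokens=base_env.n_tokens),   # alphabet must match the base env
    gamma=0.99,                                       # optional potential-based shaping
    binary_reward=True,
    progress=True,                                    # observe progressed DFAs, not the initial ones
)

obs, state = env.reset(jax.random.PRNGKey(0))
```
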
dfa_gym/env.py
ADDED
@@ -0,0 +1,168 @@

```python
"""
Abstract base class for multi agent gym environments with JAX
Based on the JaxMARL APIs
"""

import jax
import jax.numpy as jnp
from typing import Dict
import chex
from functools import partial
from flax import struct
from typing import Tuple, Optional

from dfa_gym.spaces import Space

@struct.dataclass
class State:
    pass


class MultiAgentEnv(object):
    """Jittable abstract base class for all JaxMARL Environments."""

    def __init__(
        self,
        num_agents: int,
    ) -> None:
        """
        Args:
            num_agents (int): maximum number of agents within the environment, used to set array dimensions
        """
        self.num_agents = num_agents
        self.observation_spaces = dict()
        self.action_spaces = dict()

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
        """Performs resetting of the environment.

        Args:
            key (chex.PRNGKey): random key

        Returns:
            Observations (Dict[str, chex.Array]): observations for each agent, keyed by agent name
            State (State): environment state
        """
        raise NotImplementedError

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: State,
        actions: Dict[str, chex.Array],
        reset_state: Optional[State] = None,
    ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
        """Performs step transitions in the environment. Resets the environment if done.
        To control the reset state, pass `reset_state`. Otherwise, the environment will reset using `self.reset`.

        Args:
            key (chex.PRNGKey): random key
            state (State): environment state
            actions (Dict[str, chex.Array]): agent actions, keyed by agent name
            reset_state (Optional[State], optional): Optional environment state to reset to on episode completion. Defaults to None.

        Returns:
            Observations (Dict[str, chex.Array]): next observations
            State (State): next environment state
            Rewards (Dict[str, float]): rewards, keyed by agent name
            Dones (Dict[str, bool]): dones, keyed by agent name
            Info (Dict): info dictionary
        """

        key, key_reset = jax.random.split(key)
        obs_st, states_st, rewards, dones, infos = self.step_env(key, state, actions)

        if reset_state is None:
            obs_re, states_re = self.reset(key_reset)
        else:
            states_re = reset_state
            obs_re = self.get_obs(states_re)

        # Auto-reset environment based on termination
        states = jax.tree.map(
            lambda x, y: jax.lax.select(dones["__all__"], x, y), states_re, states_st
        )
        obs = jax.tree.map(
            lambda x, y: jax.lax.select(dones["__all__"], x, y), obs_re, obs_st
        )
        return obs, states, rewards, dones, infos

    def step_env(
        self, key: chex.PRNGKey, state: State, actions: Dict[str, chex.Array]
    ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
        """Environment-specific step transition.

        Args:
            key (chex.PRNGKey): random key
            state (State): environment state
            actions (Dict[str, chex.Array]): agent actions, keyed by agent name

        Returns:
            Observations (Dict[str, chex.Array]): next observations
            State (State): next environment state
            Rewards (Dict[str, float]): rewards, keyed by agent name
            Dones (Dict[str, bool]): dones, keyed by agent name
            Info (Dict): info dictionary
        """

        raise NotImplementedError

    def get_obs(self, state: State) -> Dict[str, chex.Array]:
        """Applies observation function to state.

        Args:
            state (State): environment state

        Returns:
            Observations (Dict[str, chex.Array]): observations keyed by agent names"""
        raise NotImplementedError

    def observation_space(self, agent: str) -> Space:
        """Observation space for a given agent.

        Args:
            agent (str): agent name

        Returns:
            space (Space): observation space
        """
        return self.observation_spaces[agent]

    def action_space(self, agent: str) -> Space:
        """Action space for a given agent.

        Args:
            agent (str): agent name

        Returns:
            space (Space): action space
        """
        return self.action_spaces[agent]

    @partial(jax.jit, static_argnums=(0,))
    def get_avail_actions(self, state: State) -> Dict[str, chex.Array]:
        """Returns the available actions for each agent.

        Args:
            state (State): environment state

        Returns:
            available actions (Dict[str, chex.Array]): available actions keyed by agent name
        """
        raise NotImplementedError

    @property
    def name(self) -> str:
        """Environment name."""
        return type(self).__name__

    @property
    def agent_classes(self) -> dict:
        """Returns a dictionary with agent classes

        Format:
            agent_names: [agent_base_name_1, agent_base_name_2, ...]
        """
        raise NotImplementedError
```

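`MultiAgentEnv` supplies the jitted, auto-resetting `step` (it selects between the stepped and freshly reset state/observations based on `dones["__all__"]`); concrete environments override `reset`, `step_env` and `get_obs` and fill in `agents`, `action_spaces` and `observation_spaces`, exactly as `DFABisimEnv` and `DFAWrapper` above do. A toy single-agent sketch of that contract; `CounterEnv`/`CounterState` are illustrative names, not part of dfa-gym:

```python
# Hedged sketch, not part of the package: a minimal MultiAgentEnv subclass.
import jax
import jax.numpy as jnp
import chex
from flax import struct
from functools import partial
from typing import Dict, Tuple
from dfa_gym import spaces
from dfa_gym.env import MultiAgentEnv, State


@struct.dataclass
class CounterState(State):
    total: chex.Array
    time: chex.Array


class CounterEnv(MultiAgentEnv):
    def __init__(self, horizon: int = 10) -> None:
        super().__init__(num_agents=1)
        self.horizon = horizon
        self.agents = ["agent_0"]
        self.action_spaces = {"agent_0": spaces.Discrete(2)}
        self.observation_spaces = {
            "agent_0": spaces.Box(low=0, high=horizon, shape=(1,), dtype=jnp.int32)
        }

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], CounterState]:
        state = CounterState(total=jnp.int32(0), time=jnp.int32(0))
        return self.get_obs(state), state

    @partial(jax.jit, static_argnums=(0,))
    def step_env(self, key, state, actions):
        # accumulate the chosen action; terminate after `horizon` steps
        new_state = CounterState(total=state.total + actions["agent_0"],
                                 time=state.time + 1)
        done = new_state.time >= self.horizon
        rewards = {"agent_0": jnp.asarray(actions["agent_0"], dtype=jnp.float32)}
        dones = {"agent_0": done, "__all__": done}
        return self.get_obs(new_state), new_state, rewards, dones, {}

    @partial(jax.jit, static_argnums=(0,))
    def get_obs(self, state: CounterState) -> Dict[str, chex.Array]:
        return {"agent_0": jnp.array([state.total], dtype=jnp.int32)}
```
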
dfa_gym/maps/2buttons_2agents.pdf
Binary file

dfa_gym/maps/2rooms_2agents.pdf
Binary file

dfa_gym/maps/4buttons_4agents.pdf
Binary file

dfa_gym/maps/4rooms_4agents.pdf
Binary file

dfa_gym/robot.png
ADDED
Binary file