dfa-gym 0.1.0__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -169,3 +169,6 @@ cython_debug/
 
  # PyPI configuration file
  .pypirc
+
+ # macos
+ .DS_Store
@@ -0,0 +1 @@
+ 3.10
dfa_gym-0.2.0/PKG-INFO ADDED
@@ -0,0 +1,93 @@
+ Metadata-Version: 2.4
+ Name: dfa-gym
+ Version: 0.2.0
+ Summary: Python library for playing DFA bisimulation games and wrapping other RL environments with DFA goals.
+ Author-email: Beyazit Yalcinkaya <beyazit@berkeley.edu>
+ License-File: LICENSE
+ Requires-Python: >=3.10
+ Requires-Dist: dfax>=0.1.1
+ Description-Content-Type: text/markdown
+
+ # dfa-gym
+
+ This repo implements (Multi-Agent) Reinforcement Learning environments in JAX for solving objectives given as Deterministic Finite Automata (DFAs). There are three environments:
+
+ 1. `TokenEnv` is a fully observable grid environment with tokens in cells. The grid can be created randomly or from a specific layout. It can be instantiated in both single- and multi-agent settings.
+ 2. `DFAWrapper` is an environment wrapper assigning tasks represented as Deterministic Finite Automata (DFAs) to the agents in the wrapped environment. DFAs are represented as [`DFAx`](https://github.com/rad-dfa/dfax) objects.
+ 3. `DFABisimEnv` is an environment for solving DFA bisimulation games to learn RAD Embeddings, a provably correct latent DFA representation, as described in [this paper](https://arxiv.org/pdf/2503.05042).
+
+
+ ## Installation
+
+ This package will soon be made pip-installable. In the meantime, clone the repo and install locally.
+
+ ```
+ git clone https://github.com/rad-dfa/dfa-gym.git
+ pip install -e dfa-gym
+ ```
+
+ ## TokenEnv
+
+ Create a grid world with token and agent positions assigned randomly.
+
+ ```python
+ from dfa_gym import TokenEnv
+
+ env = TokenEnv(
+     n_agents=1, # Single agent
+     n_tokens=10, # 10 different token types
+     n_token_repeat=2, # Each token repeated twice
+     grid_shape=(7, 7), # Shape of the grid
+     fixed_map_seed=None, # If not None, then samples the same map using the given seed
+     max_steps_in_episode=100, # Episode length is 100
+ )
+ ```
+
+ Create a grid world from a given layout.
+
+ ```python
+ layout = """
+ [ 0 ][ ][ ][ ][ # ][ # ][ # ][ # ][ # ]
+ [ ][ ][ a ][ ][#,a][ 0 ][ ][ 2 ][ # ]
+ [ A ][ ][ a ][ ][#,a][ ][ 8 ][ ][ # ]
+ [ ][ ][ a ][ ][#,a][ 6 ][ ][ 4 ][ # ]
+ [ 1 ][ ][ ][ 3 ][ # ][ # ][ # ][ # ][ # ]
+ [ ][ ][ b ][ ][#,b][ 1 ][ ][ 3 ][ # ]
+ [ B ][ ][ b ][ ][#,b][ ][ 9 ][ ][ # ]
+ [ ][ ][ b ][ ][#,b][ 7 ][ ][ 5 ][ # ]
+ [ 2 ][ ][ ][ ][ # ][ # ][ # ][ # ][ # ]
+ """
+ env = TokenEnv(
+     layout=layout, # Set layout, where each [] indicates a cell, uppercase letters are
+                    # agents, # are walls, and lower case letters are buttons when alone
+                    # and doors when paired with a wall. For example, [#,a] is a door
+                    # that is open if an agent is on a [ a ] cell and closed otherwise.
+ )
+ ```
+
+
+ ## DFAWrapper
+
+ Wrap a `TokenEnv` instance using `DFAWrapper`.
+
+ ```python
+ from dfa_gym import DFAWrapper
+ from dfax.samplers import ReachSampler
+
+ env = DFAWrapper(
+     env=TokenEnv(layout=layout),
+     sampler=ReachSampler()
+ )
+ ```
+
+ ## DFABisimEnv
+
+ Create a DFA bisimulation game.
+
+ ```python
+ from dfa_gym import DFABisimEnv
+ from dfax.samplers import RADSampler
+
+ env = DFABisimEnv(sampler=RADSampler())
+ ```
+
@@ -0,0 +1,83 @@
+ # dfa-gym
+
+ This repo implements (Multi-Agent) Reinforcement Learning environments in JAX for solving objectives given as Deterministic Finite Automata (DFAs). There are three environments:
+
+ 1. `TokenEnv` is a fully observable grid environment with tokens in cells. The grid can be created randomly or from a specific layout. It can be instantiated in both single- and multi-agent settings.
+ 2. `DFAWrapper` is an environment wrapper assigning tasks represented as Deterministic Finite Automata (DFAs) to the agents in the wrapped environment. DFAs are represented as [`DFAx`](https://github.com/rad-dfa/dfax) objects.
+ 3. `DFABisimEnv` is an environment for solving DFA bisimulation games to learn RAD Embeddings, a provably correct latent DFA representation, as described in [this paper](https://arxiv.org/pdf/2503.05042).
+
+
+ ## Installation
+
+ This package will soon be made pip-installable. In the meantime, clone the repo and install locally.
+
+ ```
+ git clone https://github.com/rad-dfa/dfa-gym.git
+ pip install -e dfa-gym
+ ```
+
+ ## TokenEnv
+
+ Create a grid world with token and agent positions assigned randomly.
+
+ ```python
+ from dfa_gym import TokenEnv
+
+ env = TokenEnv(
+     n_agents=1, # Single agent
+     n_tokens=10, # 10 different token types
+     n_token_repeat=2, # Each token repeated twice
+     grid_shape=(7, 7), # Shape of the grid
+     fixed_map_seed=None, # If not None, then samples the same map using the given seed
+     max_steps_in_episode=100, # Episode length is 100
+ )
+ ```
+
+ Create a grid world from a given layout.
+
+ ```python
+ layout = """
+ [ 0 ][ ][ ][ ][ # ][ # ][ # ][ # ][ # ]
+ [ ][ ][ a ][ ][#,a][ 0 ][ ][ 2 ][ # ]
+ [ A ][ ][ a ][ ][#,a][ ][ 8 ][ ][ # ]
+ [ ][ ][ a ][ ][#,a][ 6 ][ ][ 4 ][ # ]
+ [ 1 ][ ][ ][ 3 ][ # ][ # ][ # ][ # ][ # ]
+ [ ][ ][ b ][ ][#,b][ 1 ][ ][ 3 ][ # ]
+ [ B ][ ][ b ][ ][#,b][ ][ 9 ][ ][ # ]
+ [ ][ ][ b ][ ][#,b][ 7 ][ ][ 5 ][ # ]
+ [ 2 ][ ][ ][ ][ # ][ # ][ # ][ # ][ # ]
+ """
+ env = TokenEnv(
+     layout=layout, # Set layout, where each [] indicates a cell, uppercase letters are
+                    # agents, # are walls, and lower case letters are buttons when alone
+                    # and doors when paired with a wall. For example, [#,a] is a door
+                    # that is open if an agent is on a [ a ] cell and closed otherwise.
+ )
+ ```
+
+
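+ After construction, the environment follows the JaxMARL-style API defined by `MultiAgentEnv` in this release: `reset(key)` and `step(key, state, actions)` with per-agent dictionaries. A minimal usage sketch, assuming `TokenEnv` exposes an `agents` list like the other environments here and that `0` is a valid action index (the action encoding is not shown in this diff):
+
+ ```python
+ import jax
+ import jax.numpy as jnp
+
+ key = jax.random.PRNGKey(0)
+ key, reset_key = jax.random.split(key)
+ obs, state = env.reset(reset_key)  # per-agent observations and initial state
+
+ key, step_key = jax.random.split(key)
+ actions = {agent: jnp.array(0) for agent in env.agents}  # placeholder action per agent
+ obs, state, rewards, dones, info = env.step(step_key, state, actions)
+ done = dones["__all__"]  # episode-level termination flag
+ ```
+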
+ ## DFAWrapper
+
+ Wrap a `TokenEnv` instance using `DFAWrapper`.
+
+ ```python
+ from dfa_gym import DFAWrapper
+ from dfax.samplers import ReachSampler
+
+ env = DFAWrapper(
+     env=TokenEnv(layout=layout),
+     sampler=ReachSampler()
+ )
+ ```
+
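+ Each agent's observation from the wrapper is a dictionary with the agent's index under `_id`, the wrapped environment's observation under `obs`, and a graph encoding of the task DFAs under `dfa`. The wrapped environment's labelling function produces the symbols that advance each agent's DFA, an agent is done once its DFA collapses to a single state, and on episode completion the minimum DFA reward across agents is added to the wrapped rewards (passing `gamma` additionally adds a shaping term computed from the DFA rewards). A minimal sketch of inspecting an observation:
+
+ ```python
+ import jax
+
+ key = jax.random.PRNGKey(0)
+ obs, state = env.reset(key)
+
+ agent_obs = obs["agent_0"]
+ agent_obs["_id"]   # index of this agent
+ agent_obs["obs"]   # observation from the wrapped TokenEnv
+ agent_obs["dfa"]   # graph encoding of the task DFAs (batched over agents)
+ ```
+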
+ ## DFABisimEnv
+
+ Create a DFA bisimulation game.
+
+ ```python
+ from dfa_gym import DFABisimEnv
+ from dfax.samplers import RADSampler
+
+ env = DFABisimEnv(sampler=RADSampler())
+ ```
+
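+ The game has a single agent (`agent_0`) whose action is a token index; both sampled DFAs are advanced on that token, the reward is the difference between the two DFAs' rewards, and the episode ends when either DFA collapses to a single state or the step limit is reached. A minimal sketch of one step, assuming token `0` as the action:
+
+ ```python
+ import jax
+ import jax.numpy as jnp
+
+ key = jax.random.PRNGKey(0)
+ key, reset_key, step_key = jax.random.split(key, 3)
+
+ obs, state = env.reset(reset_key)  # two distinct randomly sampled DFAs
+ obs["agent_0"]["graph_l"]          # graph encoding of the left DFA
+ obs["agent_0"]["graph_r"]          # graph encoding of the right DFA
+
+ actions = {"agent_0": jnp.array(0)}  # play token 0 in both DFAs
+ obs, state, rewards, dones, info = env.step(step_key, state, actions)
+ ```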
@@ -0,0 +1,6 @@
+ from dfa_gym.token_env import *
+ from dfa_gym.dfa_bisim_env import *
+ from dfa_gym.dfa_wrapper import *
+ from dfa_gym.env import *
+ from dfa_gym.spaces import *
+ from dfa_gym.utils import *
@@ -0,0 +1,121 @@
+ import jax
+ import dfax
+ import chex
+ import jax.numpy as jnp
+ from flax import struct
+ from functools import partial
+ from typing import Tuple, Dict
+ from dfa_gym import spaces
+ from dfa_gym.env import MultiAgentEnv, State
+ from dfax.samplers import DFASampler, RADSampler
+
+
+ @struct.dataclass
+ class DFABisimState(State):
+     dfa_l: dfax.DFAx
+     dfa_r: dfax.DFAx
+     time: int
+
+ class DFABisimEnv(MultiAgentEnv):
+
+     def __init__(
+         self,
+         sampler: DFASampler = RADSampler(),
+         max_steps_in_episode: int = 100
+     ) -> None:
+         super().__init__(num_agents=1)
+         self.n_agents = self.num_agents
+         self.sampler = sampler
+         self.max_steps_in_episode = max_steps_in_episode
+
+         self.agents = [f"agent_{i}" for i in range(self.n_agents)]
+
+         self.action_spaces = {
+             agent: spaces.Discrete(self.sampler.n_tokens)
+             for agent in self.agents
+         }
+         max_dfa_size = self.sampler.max_size
+         n_tokens = self.sampler.n_tokens
+         self.observation_spaces = {
+             agent: spaces.Dict({
+                 "graph_l": spaces.Dict({
+                     "node_features": spaces.Box(low=0, high=1, shape=(max_dfa_size, 4), dtype=jnp.uint16),
+                     "edge_features": spaces.Box(low=0, high=1, shape=(max_dfa_size*max_dfa_size, n_tokens + 8), dtype=jnp.uint16),
+                     "edge_index": spaces.Box(low=0, high=max_dfa_size, shape=(2, max_dfa_size*max_dfa_size), dtype=jnp.uint16),
+                     "current_state": spaces.Box(low=0, high=max_dfa_size, shape=(1,), dtype=jnp.uint16),
+                     "n_states": spaces.Box(low=0, high=max_dfa_size, shape=(max_dfa_size,), dtype=jnp.uint16)
+                 }),
+                 "graph_r": spaces.Dict({
+                     "node_features": spaces.Box(low=0, high=1, shape=(max_dfa_size, 4), dtype=jnp.uint16),
+                     "edge_features": spaces.Box(low=0, high=1, shape=(max_dfa_size*max_dfa_size, n_tokens + 8), dtype=jnp.uint16),
+                     "edge_index": spaces.Box(low=0, high=max_dfa_size, shape=(2, max_dfa_size*max_dfa_size), dtype=jnp.uint16),
+                     "current_state": spaces.Box(low=0, high=max_dfa_size, shape=(1,), dtype=jnp.uint16),
+                     "n_states": spaces.Box(low=0, high=max_dfa_size, shape=(max_dfa_size,), dtype=jnp.uint16)
+                 })
+             })
+             for agent in self.agents
+         }
+
+     @partial(jax.jit, static_argnums=(0,))
+     def reset(
+         self,
+         key: chex.PRNGKey
+     ) -> Tuple[Dict[str, chex.Array], DFABisimState]:
+
+         def cond_fn(carry):
+             _, dfa_l, dfa_r = carry
+             return dfa_l == dfa_r
+
+         def body_fn(carry):
+             key, _, _ = carry
+             key, kl, kr = jax.random.split(key, 3)
+             dfa_l = self.sampler.sample(kl)
+             dfa_r = self.sampler.sample(kr)
+             return (key, dfa_l, dfa_r)
+
+         init_carry = body_fn((key, None, None))
+         _, dfa_l, dfa_r = jax.lax.while_loop(cond_fn, body_fn, init_carry)
+
+         state = DFABisimState(dfa_l=dfa_l, dfa_r=dfa_r, time=0)
+         obs = self.get_obs(state=state)
+
+         return {self.agents[0]: obs}, state
+
+     @partial(jax.jit, static_argnums=(0,))
+     def step_env(
+         self,
+         key: chex.PRNGKey,
+         state: DFABisimState,
+         action: int
+     ) -> Tuple[Dict[str, chex.Array], DFABisimState, Dict[str, float], Dict[str, bool], Dict]:
+
+         dfa_l = state.dfa_l.advance(action[self.agents[0]]).minimize()
+         dfa_r = state.dfa_r.advance(action[self.agents[0]]).minimize()
+
+         reward_l = dfa_l.reward(binary=False)
+         reward_r = dfa_r.reward(binary=False)
+         reward = reward_l - reward_r
+
+         new_state = DFABisimState(
+             dfa_l=dfa_l,
+             dfa_r=dfa_r,
+             time=state.time+1
+         )
+
+         done = jnp.logical_or(jnp.logical_or(dfa_l.n_states <= 1, dfa_r.n_states <= 1), new_state.time >= self.max_steps_in_episode)
+
+         obs = self.get_obs(state=new_state)
+         info = {}
+
+         return {self.agents[0]: obs}, new_state, {self.agents[0]: reward}, {self.agents[0]: done, "__all__": done}, info
+
+     @partial(jax.jit, static_argnums=(0,))
+     def get_obs(
+         self,
+         state: DFABisimState
+     ) -> Dict[str, chex.Array]:
+         return {
+             "graph_l": state.dfa_l.to_graph(),
+             "graph_r": state.dfa_r.to_graph()
+         }
+
@@ -0,0 +1,190 @@
+ import jax
+ import dfax
+ import chex
+ import jax.numpy as jnp
+ from flax import struct
+ from dfa_gym import spaces
+ from functools import partial
+ from typing import Tuple, Dict, Callable
+ from dfax.utils import list2batch, batch2graph
+ from dfa_gym.env import MultiAgentEnv, State
+ from dfax.samplers import DFASampler, RADSampler
+
+
+ @struct.dataclass
+ class DFAWrapperState(State):
+     dfas: Dict[str, dfax.DFAx]
+     init_dfas: Dict[str, dfax.DFAx]
+     env_obs: chex.Array
+     env_state: State
+
+ class DFAWrapper(MultiAgentEnv):
+
+     def __init__(
+         self,
+         env: MultiAgentEnv,
+         gamma: float | None = None,
+         sampler: DFASampler = RADSampler(),
+         binary_reward: bool = True,
+         progress: bool = True,
+     ) -> None:
+         super().__init__(num_agents=env.num_agents)
+         self.env = env
+         self.gamma = gamma
+         self.sampler = sampler
+         self.binary_reward = binary_reward
+         self.progress = progress
+
+         assert self.sampler.n_tokens == self.env.n_tokens
+
+         self.agents = [f"agent_{i}" for i in range(self.num_agents)]
+
+         self.action_spaces = {
+             agent: self.env.action_space(agent)
+             for agent in self.agents
+         }
+         max_dfa_size = self.sampler.max_size
+         n_tokens = self.sampler.n_tokens
+         self.observation_spaces = {
+             agent: spaces.Dict({
+                 "_id": spaces.Discrete(self.num_agents),
+                 "obs": self.env.observation_space(agent),
+                 "dfa": spaces.Dict({
+                     "node_features": spaces.Box(low=0, high=1, shape=(max_dfa_size*self.num_agents, 4), dtype=jnp.float32),
+                     "edge_features": spaces.Box(low=0, high=1, shape=(max_dfa_size*self.num_agents*max_dfa_size*self.num_agents, n_tokens + 8), dtype=jnp.float32),
+                     "edge_index": spaces.Box(low=0, high=max_dfa_size*self.num_agents, shape=(2, max_dfa_size*self.num_agents*max_dfa_size*self.num_agents), dtype=jnp.int32),
+                     "current_state": spaces.Box(low=0, high=max_dfa_size*self.num_agents, shape=(self.num_agents,), dtype=jnp.int32),
+                     "n_states": spaces.Box(low=0, high=max_dfa_size*self.num_agents, shape=(max_dfa_size*self.num_agents,), dtype=jnp.int32)
+                 }),
+             })
+             for agent in self.agents
+         }
+
+     @partial(jax.jit, static_argnums=(0,))
+     def reset(
+         self,
+         key: chex.PRNGKey
+     ) -> Tuple[Dict[str, chex.Array], DFAWrapperState]:
+         keys = jax.random.split(key, 4 + self.num_agents)
+
+         env_obs, env_state = self.env.reset(keys[1])
+
+         n_trivial = jax.random.choice(keys[2], self.num_agents)
+         mask = jax.random.permutation(keys[3], jnp.arange(self.num_agents) < n_trivial)
+
+         def sample_dfa(dfa_key, sample_trivial):
+             return jax.tree_util.tree_map(
+                 lambda t, s: jnp.where(sample_trivial, t, s),
+                 self.sampler.trivial(True),
+                 self.sampler.sample(dfa_key)
+             )
+
+         dfas_tree = jax.vmap(sample_dfa)(keys[4:], mask)
+
+         dfas = {
+             agent: jax.tree_util.tree_map(lambda x: x[i], dfas_tree)
+             for i, agent in enumerate(self.agents)
+         }
+
+         state = DFAWrapperState(
+             dfas=dfas,
+             init_dfas={agent: dfas[agent] for agent in self.agents},
+             env_obs=env_obs,
+             env_state=env_state
+         )
+         obs = self.get_obs(state=state)
+
+         return obs, state
+
+     @partial(jax.jit, static_argnums=(0,))
+     def step_env(
+         self,
+         key: chex.PRNGKey,
+         state: DFAWrapperState,
+         action: int,
+     ) -> Tuple[Dict[str, chex.Array], DFAWrapperState, Dict[str, float], Dict[str, bool], Dict]:
+
+         env_obs, env_state, env_rewards, env_dones, env_info = self.env.step_env(key, state.env_state, action)
+
+         symbols = self.env.label_f(env_state)
+
+         dfas = {
+             agent: state.dfas[agent].advance(symbols[agent]).minimize()
+             for agent in self.agents
+         }
+
+         dones = {
+             agent: jnp.logical_or(env_dones[agent], dfas[agent].n_states <= 1)
+             for agent in self.agents
+         }
+         _dones = jnp.array([dones[agent] for agent in self.agents])
+         dones.update({"__all__": jnp.all(_dones)})
+
+         dfa_rewards_min = jnp.min(jnp.array([dfas[agent].reward(binary=self.binary_reward) for agent in self.agents]))
+         rewards = {
+             agent: jax.lax.cond(
+                 dones["__all__"],
+                 lambda _: env_rewards[agent] + dfa_rewards_min,
+                 lambda _: env_rewards[agent],
+                 operand=None
+             )
+             for agent in self.agents
+         }
+
+         if self.gamma is not None:
+             rewards = {
+                 agent: rewards[agent] + self.gamma * dfas[agent].reward(binary=self.binary_reward) - state.dfas[agent].reward(binary=self.binary_reward)
+                 for agent in self.agents
+             }
+
+         infos = {}
+
+         state = DFAWrapperState(
+             dfas=dfas,
+             init_dfas=state.init_dfas,
+             env_obs=env_obs,
+             env_state=env_state
+         )
+
+         obs = self.get_obs(state=state)
+
+         return obs, state, rewards, dones, infos
+
+     @partial(jax.jit, static_argnums=(0,))
+     def get_obs(
+         self,
+         state: DFAWrapperState
+     ) -> Dict[str, chex.Array]:
+         if self.progress:
+             dfas = batch2graph(
+                 list2batch(
+                     [state.dfas[agent].to_graph() for agent in self.agents]
+                 )
+             )
+         else:
+             dfas = batch2graph(
+                 list2batch(
+                     [state.init_dfas[agent].to_graph() for agent in self.agents]
+                 )
+             )
+         return {
+             agent: {
+                 "_id": i,
+                 "obs": state.env_obs[agent],
+                 "dfa": dfas
+             }
+             for i, agent in enumerate(self.agents)
+         }
+
+     def render(self, state: DFAWrapperState):
+         out = ""
+         for agent in self.agents:
+             out += "****\n"
+             out += f"{agent}'s DFA:\n"
+             if self.progress:
+                 out += f"{state.dfas[agent]}\n"
+             else:
+                 out += f"{state.init_dfas[agent]}\n"
+         self.env.render(state.env_state)
+         print(out)
+
@@ -0,0 +1,168 @@
+ """
+ Abstract base class for multi-agent gym environments with JAX.
+ Based on the JaxMARL APIs.
+ """
+
+ import jax
+ import jax.numpy as jnp
+ from typing import Dict
+ import chex
+ from functools import partial
+ from flax import struct
+ from typing import Tuple, Optional
+
+ from dfa_gym.spaces import Space
+
+ @struct.dataclass
+ class State:
+     pass
+
+
+ class MultiAgentEnv(object):
+     """Jittable abstract base class for all JaxMARL Environments."""
+
+     def __init__(
+         self,
+         num_agents: int,
+     ) -> None:
+         """
+         Args:
+             num_agents (int): maximum number of agents within the environment, used to set array dimensions
+         """
+         self.num_agents = num_agents
+         self.observation_spaces = dict()
+         self.action_spaces = dict()
+
+     @partial(jax.jit, static_argnums=(0,))
+     def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
+         """Performs resetting of the environment.
+
+         Args:
+             key (chex.PRNGKey): random key
+
+         Returns:
+             Observations (Dict[str, chex.Array]): observations for each agent, keyed by agent name
+             State (State): environment state
+         """
+         raise NotImplementedError
+
+     @partial(jax.jit, static_argnums=(0,))
+     def step(
+         self,
+         key: chex.PRNGKey,
+         state: State,
+         actions: Dict[str, chex.Array],
+         reset_state: Optional[State] = None,
+     ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
+         """Performs step transitions in the environment. Resets the environment if done.
+         To control the reset state, pass `reset_state`. Otherwise, the environment will reset using `self.reset`.
+
+         Args:
+             key (chex.PRNGKey): random key
+             state (State): environment state
+             actions (Dict[str, chex.Array]): agent actions, keyed by agent name
+             reset_state (Optional[State], optional): Optional environment state to reset to on episode completion. Defaults to None.
+
+         Returns:
+             Observations (Dict[str, chex.Array]): next observations
+             State (State): next environment state
+             Rewards (Dict[str, float]): rewards, keyed by agent name
+             Dones (Dict[str, bool]): dones, keyed by agent name
+             Info (Dict): info dictionary
+         """
+
+         key, key_reset = jax.random.split(key)
+         obs_st, states_st, rewards, dones, infos = self.step_env(key, state, actions)
+
+         if reset_state is None:
+             obs_re, states_re = self.reset(key_reset)
+         else:
+             states_re = reset_state
+             obs_re = self.get_obs(states_re)
+
+         # Auto-reset environment based on termination
+         states = jax.tree.map(
+             lambda x, y: jax.lax.select(dones["__all__"], x, y), states_re, states_st
+         )
+         obs = jax.tree.map(
+             lambda x, y: jax.lax.select(dones["__all__"], x, y), obs_re, obs_st
+         )
+         return obs, states, rewards, dones, infos
+
+     def step_env(
+         self, key: chex.PRNGKey, state: State, actions: Dict[str, chex.Array]
+     ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
+         """Environment-specific step transition.
+
+         Args:
+             key (chex.PRNGKey): random key
+             state (State): environment state
+             actions (Dict[str, chex.Array]): agent actions, keyed by agent name
+
+         Returns:
+             Observations (Dict[str, chex.Array]): next observations
+             State (State): next environment state
+             Rewards (Dict[str, float]): rewards, keyed by agent name
+             Dones (Dict[str, bool]): dones, keyed by agent name
+             Info (Dict): info dictionary
+         """
+
+         raise NotImplementedError
+
+     def get_obs(self, state: State) -> Dict[str, chex.Array]:
+         """Applies observation function to state.
+
+         Args:
+             state (State): environment state
+
+         Returns:
+             Observations (Dict[str, chex.Array]): observations keyed by agent names"""
+         raise NotImplementedError
+
+     def observation_space(self, agent: str) -> Space:
+         """Observation space for a given agent.
+
+         Args:
+             agent (str): agent name
+
+         Returns:
+             space (Space): observation space
+         """
+         return self.observation_spaces[agent]
+
+     def action_space(self, agent: str) -> Space:
+         """Action space for a given agent.
+
+         Args:
+             agent (str): agent name
+
+         Returns:
+             space (Space): action space
+         """
+         return self.action_spaces[agent]
+
+     @partial(jax.jit, static_argnums=(0,))
+     def get_avail_actions(self, state: State) -> Dict[str, chex.Array]:
+         """Returns the available actions for each agent.
+
+         Args:
+             state (State): environment state
+
+         Returns:
+             available actions (Dict[str, chex.Array]): available actions keyed by agent name
+         """
+         raise NotImplementedError
+
+     @property
+     def name(self) -> str:
+         """Environment name."""
+         return type(self).__name__
+
+     @property
+     def agent_classes(self) -> dict:
+         """Returns a dictionary with agent classes
+
+         Format:
+             agent_names: [agent_base_name_1, agent_base_name_2, ...]
+         """
+         raise NotImplementedError
Binary file